Python sklearn.model_selection module: StratifiedShuffleSplit() example source code

The following code examples, extracted from open-source Python projects, illustrate how to use sklearn.model_selection.StratifiedShuffleSplit().
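As a quick orientation before the project excerpts, here is a minimal, self-contained sketch of the basic API (synthetic data and parameter values chosen purely for illustration):

from sklearn.model_selection import StratifiedShuffleSplit
import numpy as np

X = np.arange(20).reshape(10, 2)   # 10 samples, 2 features
y = np.array([0] * 5 + [1] * 5)    # balanced binary labels

# Each split is an independent stratified random train/test partition.
sss = StratifiedShuffleSplit(n_splits=3, test_size=0.3, random_state=0)
for train_index, test_index in sss.split(X, y):
    X_train, X_test = X[train_index], X[test_index]
    y_train, y_test = y[train_index], y[test_index]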

Project: dac-training | Author: jlonij
def validate(data, labels):
    '''
    Cross-validation over ten stratified shuffle splits.
    '''
    accuracy_scores = []
    precision_scores = []
    recall_scores = []
    f1_scores = []

    sss = StratifiedShuffleSplit(n_splits=10)
    for train_index, test_index in sss.split(data, labels):
        x_train, x_test = data[train_index], data[test_index]
        y_train, y_test = labels[train_index], labels[test_index]
        clf.fit(x_train, y_train)
        y_pred = clf.predict(x_test)
        accuracy_scores.append(accuracy_score(y_test, y_pred))
        precision_scores.append(precision_score(y_test, y_pred))
        recall_scores.append(recall_score(y_test, y_pred))
        f1_scores.append(f1_score(y_test, y_pred))

    print('Accuracy', np.mean(accuracy_scores))
    print('Precision', np.mean(precision_scores))
    print('Recall', np.mean(recall_scores))
    print('F1-measure', np.mean(f1_scores))
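The excerpt above relies on module-level names; a minimal setup under that assumption (the classifier actually used in dac-training may well differ) could be:

import numpy as np
from sklearn.svm import SVC
from sklearn.metrics import (accuracy_score, precision_score,
                             recall_score, f1_score)
from sklearn.model_selection import StratifiedShuffleSplit

clf = SVC()  # hypothetical placeholder, not the project's actual model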
Project: tencent_social_algo | Author: Folieshell
def check_log_loss(max_depth, n_splits, test_size):
    # feature_train (presumably a pandas DataFrame) and trainY are
    # module-level globals defined elsewhere in the project
    model = RandomForestClassifier(max_depth=max_depth, n_jobs=-1, random_state=777)
    trn_scores = []
    vld_scores = []
    sss = StratifiedShuffleSplit(n_splits=n_splits, test_size=test_size, random_state=777)
    for i, (t_ind, v_ind) in enumerate(sss.split(feature_train, trainY)):
        print('# Iter {} / {}'.format(i + 1, n_splits))
        x_trn = feature_train.values[t_ind]
        y_trn = trainY[t_ind]
        x_vld = feature_train.values[v_ind]
        y_vld = trainY[v_ind]

        model.fit(x_trn, y_trn)

        score = log_loss(y_trn, model.predict_proba(x_trn))
        trn_scores.append(score)

        score = log_loss(y_vld, model.predict_proba(x_vld))
        vld_scores.append(score)

    print("max_depth: %d   n_splits: %d    test_size: %f" % (max_depth, n_splits, test_size))
    print('# TRN logloss: {}'.format(np.mean(trn_scores)))
    print('# VLD logloss: {}'.format(np.mean(vld_scores)))
Project: Parallel-SGD | Author: angadgill
def test_stratified_shuffle_split_iter():
    ys = [np.array([1, 1, 1, 1, 2, 2, 2, 3, 3, 3, 3, 3]),
          np.array([0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3]),
          np.array([0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2]),
          np.array([1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4]),
          np.array([-1] * 800 + [1] * 50)
          ]

    for y in ys:
        sss = StratifiedShuffleSplit(n_splits=6, test_size=0.33,
                                     random_state=0).split(np.ones(len(y)), y)
        for train, test in sss:
            assert_array_equal(np.unique(y[train]), np.unique(y[test]))
            # Checks if folds keep classes proportions
            p_train = (np.bincount(np.unique(y[train],
                                   return_inverse=True)[1]) /
                       float(len(y[train])))
            p_test = (np.bincount(np.unique(y[test],
                                  return_inverse=True)[1]) /
                      float(len(y[test])))
            assert_array_almost_equal(p_train, p_test, 1)
            assert_equal(y[train].size + y[test].size, y.size)
            assert_array_equal(np.intersect1d(train, test), [])
Project: website-fingerprinting | Author: AxelGoetz
def main(_):
    paths, labels = None, None
    dirname, _ = ospath.split(ospath.abspath(__file__))

    try:
        data_dir = dirname + '/../../data/cells'
        paths, labels = import_data(data_dir=data_dir, in_memory=False, extension=args.extension)

        monitored_data, monitored_label, unmonitored_data = split_mon_unmon(paths, labels)
        monitored_data, monitored_label, unmonitored_data = np.array(monitored_data), np.array(monitored_label), np.array(unmonitored_data)

        helpers.shuffle_data(unmonitored_data)
        split_point = int((1 - TEST_SIZE) * len(unmonitored_data))
        unmon_train, unmon_test = unmonitored_data[:split_point], unmonitored_data[split_point:]

        sss = StratifiedShuffleSplit(n_splits=1, test_size=TEST_SIZE, random_state=123)
        sss.get_n_splits(monitored_data, monitored_label)

        for train_index, test_index in sss.split(monitored_data, monitored_label):
            X_train, X_test = monitored_data[train_index], monitored_data[test_index]
            y_train, y_test = monitored_label[train_index], monitored_label[test_index]

            X_train = np.append(X_train, unmon_train)
            X_test = np.append(X_test, unmon_test)

            # unmonitored traces get the catch-all label -1
            y_train = np.append(y_train, [-1] * len(unmon_train))
            y_test = np.append(y_test, [-1] * len(unmon_test))

            store_data(X_test, 'X_test')
            store_data(y_test, 'y_test')

            stdout.write("Training on data...\n")
            run_model(X_train, in_memory=False)
            stdout.write("Finished running model.")
            break

    except KeyboardInterrupt:
        stdout.write("Interrupted, this might take a while...\n")
        exit(0)
Project: keras-text | Author: raghakot
def update_test_indices(self, test_size=0.1):
        """Updates `test_indices` property with indices of `test_size` proportion.

        Args:
            test_size: The test proportion in [0, 1] (Default value: 0.1)
        """
        if self.is_multi_label:
            self._train_indices, self._test_indices = sampling.multi_label_train_test_split(self.y, test_size)
        else:
            sss = StratifiedShuffleSplit(n_splits=1, test_size=test_size)
            self._train_indices, self._test_indices = next(sss.split(self.X, self.y))
Project: keras-text | Author: raghakot
def train_val_split(self, split_ratio=0.1):
        """Generates train and validation sets from the training indices.

        Args:
            split_ratio: The split proportion in [0, 1] (Default value: 0.1)

        Returns:
            The stratified train and val subsets. Multi-label outputs are handled as well.
        """
        if self.is_multi_label:
            train_indices, val_indices = sampling.multi_label_train_test_split(self.y, split_ratio)
        else:
            sss = StratifiedShuffleSplit(n_splits=1, test_size=split_ratio)
            train_indices, val_indices = next(sss.split(self.X, self.y))
        return self.X[train_indices], self.X[val_indices], self.y[train_indices], self.y[val_indices]
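Hypothetical usage, assuming ds is a populated keras-text Dataset-like object with X and y set (the instance name is illustrative):

x_train, x_val, y_train, y_val = ds.train_val_split(split_ratio=0.1)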
Project: MLClass | Author: bm2-lab
def split_testing_data_c(y):
    sss = StratifiedShuffleSplit(n_splits=1, test_size=0.2)
    tri = None
    tei = None
    for itr, ite in sss.split(np.zeros(len(y)), y):
        tri = itr
        tei = ite
    return tri, tei
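Since n_splits=1 here, the loop above only ever sees a single split; an equivalent, arguably clearer sketch of the same function uses next():

def split_testing_data_c(y):
    sss = StratifiedShuffleSplit(n_splits=1, test_size=0.2)
    tri, tei = next(sss.split(np.zeros(len(y)), y))
    return tri, tei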
Project: pydl | Author: rafaeltg
def get_cv_method(method, **kwargs):

    if method == 'kfold':
        return KFold(**kwargs)
    elif method == 'skfold':
        return StratifiedKFold(**kwargs)
    elif method == 'loo':
        return LeaveOneOut()
    elif method == 'shuffle_split':
        return ShuffleSplit(**kwargs)
    elif method == 'split':
        return TrainTestSplit(**kwargs)
    elif method == 's_shuffle_split':
        return StratifiedShuffleSplit(**kwargs)
    elif method == 'time_series':
        return TimeSeriesSplit(**kwargs)
    else:
        raise AttributeError('Invalid CV method - %s!' % method)
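For example, the stratified variant might be requested like this (hypothetical keyword arguments, forwarded unchecked to the sklearn constructor):

cv = get_cv_method('s_shuffle_split', n_splits=5, test_size=0.2, random_state=0)
# cv.split(X, y) then yields stratified train/test index pairs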
Project: thesis | Author: jonvet
def balance_data(data):
    # data is expected to be a numpy array of sentence strings
    lengths = [len(s.split(' ')) for s in data]
    data = data[np.array(lengths) <= 70]
    lengths = [len(s.split(' ')) for s in data]
    bins = np.array([0, 5, 8, 12, 17, 21, 26, 70])
    share_dev = 0.05
    labels = np.digitize(lengths, bins) - np.ones_like(lengths)
    sss = StratifiedShuffleSplit(n_splits=2, test_size=share_dev, random_state=0)
    for train_index, test_index in sss.split(lengths, labels):
        X_train, X_test = data[train_index], data[test_index]
    return X_train, X_test
Project: Spam-Message-Classifier-sklearn | Author: ZPdesu
def learn_best_param(self):
        C_range = np.logspace(-2, 10, 13)
        param_grid = dict(C=C_range)
        cv = StratifiedShuffleSplit(n_splits=5, test_size=0.2, random_state=42)
        grid = GridSearchCV(SVC(), param_grid=param_grid, cv=cv)
        grid.fit(self.training_data, self.training_target)
        self.clf.set_params(C=grid.best_params_['C'])
        print("The best parameters are %s with a score of %0.2f"
              % (grid.best_params_, grid.best_score_))
Project: Spam-Message-Classifier-sklearn | Author: ZPdesu
def learn_best_param(self):
        C_range = np.logspace(-2, 10, 13)
        gamma_range = np.logspace(-9, 3, 13)
        param_grid = dict(gamma=gamma_range, C=C_range)
        cv = StratifiedShuffleSplit(n_splits=5, test_size=0.2, random_state=42)
        grid = GridSearchCV(SVC(), param_grid=param_grid, cv=cv)
        grid.fit(self.training_data, self.training_target)
        self.clf.set_params(C=grid.best_params_['C'], gamma=grid.best_params_['gamma'])
        print("The best parameters are %s with a score of %0.2f"
              % (grid.best_params_, grid.best_score_))
        self.draw_visualization_param_effect(grid, C_range, gamma_range)
Project: dac-training | Author: jlonij
def validate(data, labels):
    '''
    Cross-validation over ten stratified shuffle splits.
    '''
    accuracy_scores = []
    precision_scores = []
    recall_scores = []
    f1_scores = []

    sss = StratifiedShuffleSplit(n_splits=10)
    for train_index, test_index in sss.split(data, labels):
        x_train, x_test = data[train_index], data[test_index]
        y_train, y_test = labels[train_index], labels[test_index]

        model = load_model(data)
        model.fit(x_train, y_train, epochs=100, batch_size=128,
            class_weight=class_weight)
        y_pred = model.predict_classes(x_test, batch_size=128)

        accuracy_scores.append(accuracy_score(y_test, y_pred))
        precision_scores.append(precision_score(y_test, y_pred))
        recall_scores.append(recall_score(y_test, y_pred))
        f1_scores.append(f1_score(y_test, y_pred))

    print('')
    print('Accuracy', np.mean(accuracy_scores))
    print('Precision', np.mean(precision_scores))
    print('Recall', np.mean(recall_scores))
    print('F1-measure', np.mean(f1_scores))
Project: dac-training | Author: jlonij
def validate(data, labels):
    '''
    Cross-validation over ten stratified shuffle splits.
    '''
    accuracy_scores = []
    precision_scores = []
    recall_scores = []
    f1_scores = []

    sss = StratifiedShuffleSplit(n_splits=10)

    for train_index, test_index in sss.split(data[0], labels):
        x_train_0, x_test_0 = data[0][train_index], data[0][test_index]
        x_train_1, x_test_1 = data[1][train_index], data[1][test_index]
        x_train_2, x_test_2 = data[2][train_index], data[2][test_index]

        y_train, y_test = labels[train_index], labels[test_index]

        model = create_model(data)
        model.fit([x_train_0, x_train_1, x_train_2], y_train,
            epochs=100, batch_size=128, class_weight=class_weight)
        #y_pred = model.predict_classes(x_test, batch_size=128)

        y_pred = model.predict([x_test_0, x_test_1, x_test_2], batch_size=128)
        y_pred = [1 if y[0] > 0.5 else 0 for y in y_pred]

        accuracy_scores.append(accuracy_score(y_test, y_pred))
        precision_scores.append(precision_score(y_test, y_pred))
        recall_scores.append(recall_score(y_test, y_pred))
        f1_scores.append(f1_score(y_test, y_pred))

    print('')
    print('Accuracy', np.mean(accuracy_scores))
    print('Precision', np.mean(precision_scores))
    print('Recall', np.mean(recall_scores))
    print('F1-measure', np.mean(f1_scores))
Project: planet-amazon-deforestation | Author: EKami
def _get_validation_split(self):
        train = pd.read_csv(self.train_csv_file)
        # mapping labels to integer classes
        flatten = lambda l: [item for sublist in l for item in sublist]
        labels = list(set(flatten([l.split(' ') for l in train['tags'].values])))
        label_map = {l: i for i, l in enumerate(labels)}

        y_train = []
    for f, tags in train.values:
            targets = np.zeros(len(label_map))
            for t in tags.split(' '):
                targets[label_map[t]] = 1
            y_train.append(targets)

        y_train = np.array(y_train, np.uint8)
        trn_index = []
        val_index = []
        index = np.arange(len(train))
    for i in range(len(label_map)):
        sss = StratifiedShuffleSplit(n_splits=2, test_size=self.validation_split, random_state=i)
        for train_index, test_index in sss.split(index, y_train[:, i]):
            X_train, X_test = index[train_index], index[test_index]
        # ensure there is no repetition within each split and between the splits
        trn_index = trn_index + list(set(X_train) - set(trn_index) - set(val_index))
        val_index = val_index + list(set(X_test) - set(val_index) - set(trn_index))
        return np.array(trn_index), np.array(val_index)
Project: torchsample | Author: ncullen93
def gen_sample_array(self):
        try:
            from sklearn.model_selection import StratifiedShuffleSplit
        except ImportError:
            raise ImportError('Need scikit-learn for this functionality')
        import numpy as np

        s = StratifiedShuffleSplit(n_splits=self.n_splits, test_size=0.5)
        # dummy 2-D features: only the class labels drive a stratified split
        X = th.randn(self.class_vector.size(0), 2).numpy()
        y = self.class_vector.numpy()
        s.get_n_splits(X, y)

        train_index, test_index = next(s.split(X, y))
        return np.hstack([train_index, test_index])
Project: genre-classifier | Author: jlonij
def validate():
    # Load an existing training set
    X_train, y_train = dataset.load_training('data/training.txt')

    # Cross-validation over ten stratified shuffle splits
    cv = StratifiedShuffleSplit(n_splits=10)
    scores = cross_val_score(clf, X_train, y_train, cv=cv)
    print("Accuracy: %0.4f (+/- %0.2f)" % (scores.mean(), scores.std() * 2))
Project: skorch | Author: dnouri
def _is_stratified(self, cv):
        return isinstance(cv, (StratifiedKFold, StratifiedShuffleSplit))
Project: skorch | Author: dnouri
def _check_cv_float(self):
        cv_cls = StratifiedShuffleSplit if self.stratified else ShuffleSplit
        return cv_cls(test_size=self.cv, random_state=self.random_state)
Project: Parallel-SGD | Author: angadgill
def test_stratified_shuffle_split_init():
    X = np.arange(7)
    y = np.asarray([0, 1, 1, 1, 2, 2, 2])
    # Check that error is raised if there is a class with only one sample
    assert_raises(ValueError, next,
                  StratifiedShuffleSplit(3, 0.2).split(X, y))

    # Check that error is raised if the test set size is smaller than n_classes
    assert_raises(ValueError, next, StratifiedShuffleSplit(3, 2).split(X, y))
    # Check that error is raised if the train set size is smaller than
    # n_classes
    assert_raises(ValueError, next,
                  StratifiedShuffleSplit(3, 3, 2).split(X, y))

    X = np.arange(9)
    y = np.asarray([0, 0, 0, 1, 1, 1, 2, 2, 2])
    # Check that errors are raised if there are not enough samples
    assert_raises(ValueError, StratifiedShuffleSplit, 3, 0.5, 0.6)
    assert_raises(ValueError, next,
                  StratifiedShuffleSplit(3, 8, 0.6).split(X, y))
    assert_raises(ValueError, next,
                  StratifiedShuffleSplit(3, 0.6, 8).split(X, y))

    # Train size or test size too small
    assert_raises(ValueError, next,
                  StratifiedShuffleSplit(train_size=2).split(X, y))
    assert_raises(ValueError, next,
                  StratifiedShuffleSplit(test_size=2).split(X, y))
Project: Parallel-SGD | Author: angadgill
def test_stratified_shuffle_split_overlap_train_test_bug():
    # See https://github.com/scikit-learn/scikit-learn/issues/6121 for
    # the original bug report
    y = [0, 1, 2, 3] * 3 + [4, 5] * 5
    X = np.ones_like(y)

    splits = StratifiedShuffleSplit(n_splits=1,
                                    test_size=0.5, random_state=0)

    train, test = next(iter(splits.split(X=X, y=y)))

    assert_array_equal(np.intersect1d(train, test), [])
Project: Parallel-SGD | Author: angadgill
def test_nested_cv():
    # Test if nested cross validation works with different combinations of cv
    rng = np.random.RandomState(0)

    X, y = make_classification(n_samples=15, n_classes=2, random_state=0)
    labels = rng.randint(0, 5, 15)

    cvs = [LeaveOneLabelOut(), LeaveOneOut(), LabelKFold(), StratifiedKFold(),
           StratifiedShuffleSplit(n_splits=3, random_state=0)]

    for inner_cv, outer_cv in combinations_with_replacement(cvs, 2):
        gs = GridSearchCV(Ridge(), param_grid={'alpha': [1, .1]},
                          cv=inner_cv)
        cross_val_score(gs, X=X, y=y, labels=labels, cv=outer_cv,
                        fit_params={'labels': labels})
Project: palladio | Author: slipguru
def _check_cv(cv=3, y=None, classifier=False, **kwargs):
    """Input checker utility for building a cross-validator.

    Parameters
    ----------
    cv : int, cross-validation generator or an iterable, optional
        Determines the cross-validation splitting strategy.
        Possible inputs for cv are:
          - None, to use the default of 10 splits,
          - integer, to specify the number of folds.
          - An object to be used as a cross-validation generator.
          - An iterable yielding train/test splits.

        For integer/None inputs, if classifier is True and ``y`` is either
        binary or multiclass, :class:`StratifiedShuffleSplit` is used. In all
        other cases, :class:`ShuffleSplit` is used.

        Refer :ref:`User Guide <cross_validation>` for the various
        cross-validation strategies that can be used here.

    y : array-like, optional
        The target variable for supervised learning problems.

    classifier : boolean, optional, default False
        Whether the task is a classification task, in which case
        a stratified splitter will be used.

    kwargs : dict
        Other parameters for StratifiedShuffleSplit or ShuffleSplit.

    Returns
    -------
    checked_cv : a cross-validator instance.
        The return value is a cross-validator which generates the train/test
        splits via the ``split`` method.
    """
    if cv is None:
        cv = kwargs.pop('n_splits', 0) or 10

    if isinstance(cv, numbers.Integral):
        if (classifier and (y is not None) and
                (type_of_target(y) in ('binary', 'multiclass'))):
            return StratifiedShuffleSplit(cv, **kwargs)
        else:
            return ShuffleSplit(cv, **kwargs)

    if not hasattr(cv, 'split') or isinstance(cv, str):
        if not isinstance(cv, Iterable) or isinstance(cv, str):
            raise ValueError("Expected cv as an integer, cross-validation "
                             "object (from sklearn.model_selection) "
                             "or an iterable. Got %s." % cv)
        return _CVIterableWrapper(cv)

    return cv  # New style cv objects are passed without any modification
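A brief sketch of the behavior described in the docstring (assumed values):

y = [0, 0, 1, 1, 0, 1]
cv = _check_cv(cv=5, y=y, classifier=True, test_size=0.2)
# y is binary, so this returns StratifiedShuffleSplit(n_splits=5, test_size=0.2)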
Project: qtim_ROP | Author: QTIM-Lab
def random_forest(df_features, df_ground_truth, out_dir, n_splits=5):

    X = df_features.values
    y_true = np.argmax(df_ground_truth.values, axis=1)

    print("~~ Class distribution ~~")
    for k, v in sorted(CLASS_LABELS.items(), key=lambda x: x[1]):
        print("{}: {:.2f}%".format(k.capitalize(), (len(y_true[y_true == v]) / float(len(y_true))) * 100))

    # Use stratified shuffle-split cross-validation
    skf = StratifiedShuffleSplit(n_splits=n_splits, test_size=.2)

    auc_results = []
    for i, (train, test) in enumerate(skf.split(X, y_true)):

        X_train, y_train = X[train, :], y_true[train]
        rf = RandomForestClassifier(class_weight='balanced')
        rf.fit(X_train, y_train)

        X_test, y_test = X[test, :], y_true[test]
        y_pred_prob = rf.predict_proba(X_test)
        auc = roc_auc(y_pred_prob, to_categorical(y_test), dict_reverse(CLASS_LABELS),
                      join(out_dir, 'roc_auc_split_{}.svg'.format(i)))

        auc_results.append(auc['micro'])

    print "\n~~ Average AUC over {} splits ~~\n{}".format(n_splits, np.mean(auc_results))

    #X_train, X_test, y_train, y_test = train_test_split(X, y_true, train_size=.7, test_size=.3, random_state=4)
    #
    # rf.fit(X_train, y_train)
    # # joblib.dump(rf, join(out_dir, 'classifier.pkl'))
    #
    # y_pred = rf.predict(X_test)
    #
    # print classification_report(y_true, y_pred)
    # print confusion_matrix(y_true, y_pred)
    # print accuracy_score(y_true, y_pred)
    #
    # return
    #
    # y_pred_prob = rf.predict_proba(X_test)
    # roc_auc(y_pred_prob, to_categorical(y_test), CLASS_LABELS, join(out_dir, 'roc_auc.svg'))
Project: Parallel-SGD | Author: angadgill
def test_stratified_shuffle_split_even():
    # Test that StratifiedShuffleSplit draws each index with equal chance
    n_folds = 5
    n_iter = 1000

    def assert_counts_are_ok(idx_counts, p):
        # Here we test that the distribution of the counts
        # per index is close enough to a binomial
        threshold = 0.05 / n_splits
        bf = stats.binom(n_splits, p)
        for count in idx_counts:
            p = bf.pmf(count)
            assert_true(p > threshold,
                        "An index is not drawn with chance corresponding "
                        "to even draws")

    for n_samples in (6, 22):
        labels = np.array((n_samples // 2) * [0, 1])
        splits = StratifiedShuffleSplit(n_splits=n_iter,
                                        test_size=1. / n_folds,
                                        random_state=0)

        train_counts = [0] * n_samples
        test_counts = [0] * n_samples
        n_splits = 0
        for train, test in splits.split(X=np.ones(n_samples), y=labels):
            n_splits += 1
            for counter, ids in [(train_counts, train), (test_counts, test)]:
                for id in ids:
                    counter[id] += 1
        assert_equal(n_splits, n_iter)

        n_train, n_test = _validate_shuffle_split(n_samples,
                                                  test_size=1./n_folds,
                                                  train_size=1.-(1./n_folds))

        assert_equal(len(train), n_train)
        assert_equal(len(test), n_test)
        assert_equal(len(set(train).intersection(test)), 0)

        label_counts = np.unique(labels)
        assert_equal(splits.test_size, 1.0 / n_folds)
        assert_equal(n_train + n_test, len(labels))
        assert_equal(len(label_counts), 2)
        ex_test_p = float(n_test) / n_samples
        ex_train_p = float(n_train) / n_samples

        assert_counts_are_ok(train_counts, ex_train_p)
        assert_counts_are_ok(test_counts, ex_test_p)