Python chainer module, dataset() example source code

We extracted the following 6 code examples from open-source Python projects; they illustrate how to use the chainer.dataset module.
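
A minimal sketch of the chainer.dataset helpers these examples lean on (concat_examples and to_device); the toy arrays here are illustrative, not taken from any of the projects below:

import numpy
import chainer

# A TupleDataset pairs inputs with labels; indexing yields (x, y) tuples.
train = chainer.datasets.TupleDataset(
    numpy.random.rand(10, 3).astype(numpy.float32),
    numpy.arange(10, dtype=numpy.int32))

# concat_examples stacks a list of examples into batched arrays.
xs, ys = chainer.dataset.concat_examples(train[:4])
print(xs.shape, ys.shape)  # (4, 3) (4,)

# to_device sends an array to a device; a negative ID keeps it on the CPU.
xs = chainer.dataset.to_device(-1, xs)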

Project: convolutional_seq2seq    Author: soskek
def __call__(self, trainer):
    # `bleu_score` is nltk.translate.bleu_score; `reporter` is chainer.reporter.
    print('## Calculate BLEU')
    # Disable backprop and switch to test mode while translating.
    with chainer.no_backprop_mode(), chainer.using_config('train', False):
        references = []
        hypotheses = []
        for i in range(0, len(self.test_data), self.batch):
            sources, targets = zip(*self.test_data[i:i + self.batch])
            references.extend([[t.tolist()] for t in targets])

            sources = [
                chainer.dataset.to_device(self.device, x) for x in sources]
            ys = [y.tolist()
                  for y in self.model.translate(sources, self.max_length)]
            hypotheses.extend(ys)

    bleu = bleu_score.corpus_bleu(
        references, hypotheses,
        smoothing_function=bleu_score.SmoothingFunction().method1) * 100
    print('BLEU:', bleu)
    reporter.report({self.key: bleu})
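
An extension like this is typically registered on a Trainer. A hedged sketch, assuming the snippet above is the __call__ of a class named CalculateBleu (the constructor arguments are guesses based on the attributes it uses):

# Hypothetical wiring; run the BLEU evaluation once per epoch.
trainer.extend(
    CalculateBleu(model, test_data, key='validation/main/bleu',
                  batch=64, device=-1, max_length=100),
    trigger=(1, 'epoch'))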
Project: chainer_sklearn    Author: corochann
def is_dataset(obj):
    """Check whether ``obj`` is a Chainer dataset instance."""
    return isinstance(obj, (DictDataset, ImageDataset, LabeledImageDataset,
                            TupleDataset, DatasetMixin))
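
A quick usage check, assuming the dataset classes above were imported from chainer.datasets and DatasetMixin from chainer.dataset:

import numpy
from chainer.datasets import TupleDataset

train = TupleDataset(numpy.zeros((5, 2), dtype=numpy.float32),
                     numpy.zeros(5, dtype=numpy.int32))
print(is_dataset(train))           # True
print(is_dataset(numpy.zeros(5)))  # False: a plain array is not a dataset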
Project: chainer_sklearn    Author: corochann
def _check_X_y(self, X, y=None):
    """Check the types of X and y.

    Updates the format of X and y if necessary (e.g. casting the dtype,
    converting a sparse matrix to a dense one).

    `X` and `y` may be arrays (numpy.ndarray or sparse matrix) for the
    sklearn interface, but `X` may also be a Chainer dataset.

    :param X: Chainer dataset or array
    :param y: None or array
    :return: checked X and y
    """
    return X, y
Project: chainer_sklearn    Author: corochann
def fit(self, X, y=None, **kwargs):
    """Fit the model.

    If a hyperparameter is set to None, the instance's attribute is used
    instead; this supports grid search via the `set_params` method.
    If the instance's attribute is not set either, `_default_hyperparam`
    is used.

    Usage: model.fit(train_dataset) or model.fit(X, y)

    Args:
        train: training dataset; assumes Chainer's dataset class
        test: test dataset for evaluation; assumes Chainer's dataset class
        batchsize: batch size for both training and evaluation
        iterator_class: iterator class used for this training;
                        currently assumes SerialIterator or MultiprocessIterator
        optimizer: optimizer instance used to update parameters
        epoch: number of training epochs
        out: directory path where results are saved
        snapshot_frequency (int): snapshot frequency in epochs;
                            a negative value disables snapshots
        dump_graph: whether to save computational-graph info (default False)
        log_report: enable LogReport or not
        plot_report: enable PlotReport or not
        print_report: enable PrintReport or not
        progress_report: enable ProgressReport or not
        resume: path of a saved trainer from which to resume training
    """
    kwargs = self.filter_sk_params(self.fit_core, kwargs)
    return self.fit_core(X, y, **kwargs)
Project: chainer_sklearn    Author: corochann
def _check_X_y(self, X, y=None):
    # Chainer datasets and plain lists are passed through without checks;
    # `check_array` comes from sklearn.utils.
    if not is_dataset(X) and not isinstance(X, list):
        if isinstance(X, numpy.ndarray):
            X = check_array(X, dtype=self._data_x_dtype)
        else:
            print('[WARNING] skipping type check for X with type {}'
                  .format(type(X)))
    if y is not None:
        y = check_array(y, dtype=self._data_y_dtype, ensure_2d=False)
    return X, y
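
For reference, a minimal sketch of the sklearn.utils.check_array behavior this method relies on:

import numpy
from sklearn.utils import check_array

X = check_array(numpy.arange(6).reshape(3, 2), dtype=numpy.float32)
print(X.dtype)  # float32: cast to the requested dtype

y = check_array(numpy.array([0, 1, 2]), dtype=numpy.int32, ensure_2d=False)
print(y.shape)  # (3,): ensure_2d=False allows 1-D label arrays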
Project: chainer_sklearn    Author: corochann
def main():
    # Requires: argparse, chainer, chainer.serializers and
    # chainer.dataset.concat_examples; MLP and SklearnWrapperClassifier
    # are defined elsewhere in this project.
    parser = argparse.ArgumentParser(description='Chainer example: MNIST')
    parser.add_argument('--batchsize', '-b', type=int, default=100,
                        help='Number of images in each mini-batch')
    parser.add_argument('--epoch', '-e', type=int, default=20,
                        help='Number of sweeps over the dataset to train')
    parser.add_argument('--frequency', '-f', type=int, default=-1,
                        help='Frequency of taking a snapshot')
    parser.add_argument('--gpu', '-g', type=int, default=-1,
                        help='GPU ID (negative value indicates CPU)')
    parser.add_argument('--out', '-o', default='result',
                        help='Directory to output the result')
    parser.add_argument('--resume', '-r', default='',
                        help='Resume the training from snapshot')
    parser.add_argument('--unit', '-u', type=int, default=50,
                        help='Number of units')
    parser.add_argument('--example', '-ex', type=int, default=3,
                        help='Example mode')
    args = parser.parse_args()

    print('GPU: {}'.format(args.gpu))
    print('# unit: {}'.format(args.unit))
    print('# Minibatch-size: {}'.format(args.batchsize))
    print('# epoch: {}'.format(args.epoch))
    print('')

    # Load the MNIST dataset
    train, test = chainer.datasets.get_mnist()

    model = SklearnWrapperClassifier(MLP(args.unit, 10), device=args.gpu)

    if args.example == 1:
        print("Example 1. fit with x, y numpy array (same with sklearn's fit)")
        x, y = concat_examples(train)
        model.fit(x, y)
    elif args.example == 2:
        print("Example 2. Train with Chainer's dataset")
        # `train` is a TupleDataset in this example.
        # Even this single line works (though with no validation).
        model.fit(train)
    else:
        print("Example 3. Train with configuration")
        model.fit(
            train,
            test=test,
            batchsize=args.batchsize,
            #iterator_class=chainer.iterators.SerialIterator,
            optimizer=chainer.optimizers.Adam(),
            epoch=args.epoch,
            out=args.out,
            snapshot_frequency=1,
            #dump_graph=False,
            #log_report=True,
            plot_report=False,
            #print_report=True,
            progress_report=False,
            resume=args.resume
        )

    # Save trained model
    serializers.save_npz('{}/mlp.model'.format(args.out), model)
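
To reuse the saved weights later, the counterpart call is serializers.load_npz; a short sketch, assuming the same MLP and SklearnWrapperClassifier definitions are in scope and the default output directory was used:

from chainer import serializers

model = SklearnWrapperClassifier(MLP(50, 10), device=-1)
serializers.load_npz('result/mlp.model', model)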