The following six code examples, extracted from open-source Python projects, illustrate how to use the chainer.dataset module.
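Before the extracted examples, here is a minimal self-contained sketch (not taken from any of the projects below; SquaresDataset is a made-up name) of the two chainer.dataset members the examples rely on most, DatasetMixin and concat_examples:

import numpy
from chainer.dataset import DatasetMixin, concat_examples


class SquaresDataset(DatasetMixin):
    """Toy dataset that computes (x, x**2) pairs on demand."""

    def __len__(self):
        return 100

    def get_example(self, i):
        x = numpy.array(i, dtype=numpy.float32)  # 0-d array
        return x, x * x


dataset = SquaresDataset()
batch = [dataset[i] for i in range(4)]   # DatasetMixin provides __getitem__
xs, ys = concat_examples(batch)          # stacks each tuple column into an array
print(xs.shape, ys.shape)                # (4,) (4,)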
def __call__(self, trainer):
    print('## Calculate BLEU')
    # Disable backprop and switch to test mode for evaluation.
    with chainer.no_backprop_mode():
        with chainer.using_config('train', False):
            references = []
            hypotheses = []
            for i in range(0, len(self.test_data), self.batch):
                sources, targets = zip(*self.test_data[i:i + self.batch])
                references.extend([[t.tolist()] for t in targets])
                sources = [
                    chainer.dataset.to_device(self.device, x)
                    for x in sources]
                ys = [y.tolist()
                      for y in self.model.translate(sources, self.max_length)]
                hypotheses.extend(ys)
    # Corpus-level BLEU with smoothing, scaled to 0-100.
    bleu = bleu_score.corpus_bleu(
        references, hypotheses,
        smoothing_function=bleu_score.SmoothingFunction().method1) * 100
    print('BLEU:', bleu)
    reporter.report({self.key: bleu})
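This __call__ is the body of a trainer extension; the class header is not part of the excerpt (in Chainer's seq2seq example it belongs to a class conventionally named CalculateBleu). A hedged sketch of how such an extension is typically wired into a trainer; the class name and constructor signature are assumptions inferred from the attributes the method reads:

# Hypothetical wiring; `CalculateBleu` and its constructor arguments stand in
# for the class whose __call__ is shown above. `trainer`, `model`, and
# `test_data` are assumed to exist in the surrounding script.
trainer.extend(
    CalculateBleu(model, test_data, key='validation/main/bleu',
                  batch=100, device=-1, max_length=100),
    trigger=(1, 'epoch'))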
def is_dataset(obj):
    """Check if obj is a Chainer dataset instance or not."""
    return isinstance(obj, (DictDataset, ImageDataset, LabeledImageDataset,
                            TupleDataset, DatasetMixin))
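A quick usage sketch for is_dataset (assuming the dataset classes imported by the helper above are in scope):

import numpy
from chainer.datasets import TupleDataset

x = numpy.random.rand(10, 3).astype(numpy.float32)
y = numpy.random.randint(0, 2, size=10).astype(numpy.int32)

print(is_dataset(TupleDataset(x, y)))  # True: TupleDataset is a chainer dataset
print(is_dataset(x))                   # False: a bare ndarray is not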
def _check_X_y(self, X, y=None): """Check type of X and y. It updates the format of X and y (such as dtype, convert sparse matrix to matrix format etc) if necessary. `X` and `y` might be array (numpy.ndarray or sparse matrix) for sklearn interface, but `X` might be chainer dataset. :param X: chainer dataset type or array :param y: None or array :return: """ return X, y
def fit(self, X, y=None, **kwargs):
    """Fit the model.

    If a hyperparameter is set to None, the instance's attribute is
    used instead; this supports grid search via the `set_params`
    method. If the instance's attribute is not set either,
    `_default_hyperparam` is used.

    Usage: `model.fit(train_dataset)` or `model.fit(X, y)`

    Args:
        train: training dataset; assumes chainer's dataset class
        test: test dataset for evaluation; assumes chainer's dataset class
        batchsize: batch size for both training and evaluation
        iterator_class: iterator class used for this training; currently
            assumes SerialIterator or MultiprocessIterator
        optimizer: optimizer instance used to update parameters
        epoch: number of training epochs
        out: directory path to save the results
        snapshot_frequency (int): snapshot frequency in epochs. A negative
            value indicates not to take snapshots.
        dump_graph: save computational graph info or not; default is False
        log_report: enable LogReport or not
        plot_report: enable PlotReport or not
        print_report: enable PrintReport or not
        progress_report: enable ProgressReport or not
        resume: path of a saved trainer snapshot to resume training from
    """
    kwargs = self.filter_sk_params(self.fit_core, kwargs)
    return self.fit_core(X, y, **kwargs)
def _check_X_y(self, X, y=None):
    # print('check_X_y', type(X), type(y))
    if not is_dataset(X) and not isinstance(X, list):
        if isinstance(X, numpy.ndarray):
            X = check_array(X, dtype=self._data_x_dtype)
        else:
            print('[WARNING] skipping type check for X with type {}'
                  .format(type(X)))
    if y is not None:
        y = check_array(y, dtype=self._data_y_dtype, ensure_2d=False)
    return X, y
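To make the branching above concrete, here is an illustration of what each branch amounts to for typical inputs; the float32/int32 dtypes are assumptions, since the real wrapper reads them from its own _data_x_dtype/_data_y_dtype attributes:

import numpy
from sklearn.utils import check_array

# ndarray branch: validated and cast by sklearn's check_array.
X = numpy.random.rand(5, 4)
X = check_array(X, dtype=numpy.float32)

# Label branch: 1-d targets are allowed via ensure_2d=False.
y = numpy.array([0, 1, 0, 1, 1])
y = check_array(y, dtype=numpy.int32, ensure_2d=False)

# A chainer dataset (or a plain list) skips check_array entirely and is
# returned unchanged by _check_X_y.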
def main():
    parser = argparse.ArgumentParser(description='Chainer example: MNIST')
    parser.add_argument('--batchsize', '-b', type=int, default=100,
                        help='Number of images in each mini-batch')
    parser.add_argument('--epoch', '-e', type=int, default=20,
                        help='Number of sweeps over the dataset to train')
    parser.add_argument('--frequency', '-f', type=int, default=-1,
                        help='Frequency of taking a snapshot')
    parser.add_argument('--gpu', '-g', type=int, default=-1,
                        help='GPU ID (negative value indicates CPU)')
    parser.add_argument('--out', '-o', default='result',
                        help='Directory to output the result')
    parser.add_argument('--resume', '-r', default='',
                        help='Resume the training from snapshot')
    parser.add_argument('--unit', '-u', type=int, default=50,
                        help='Number of units')
    parser.add_argument('--example', '-ex', type=int, default=3,
                        help='Example mode')
    args = parser.parse_args()

    print('GPU: {}'.format(args.gpu))
    print('# unit: {}'.format(args.unit))
    print('# Minibatch-size: {}'.format(args.batchsize))
    print('# epoch: {}'.format(args.epoch))
    print('')

    # Load the MNIST dataset
    train, test = chainer.datasets.get_mnist()

    model = SklearnWrapperClassifier(MLP(args.unit, 10), device=args.gpu)

    if args.example == 1:
        print("Example 1. fit with x, y numpy array (same as sklearn's fit)")
        x, y = concat_examples(train)
        model.fit(x, y)
    elif args.example == 2:
        print("Example 2. Train with Chainer's dataset")
        # `train` is a TupleDataset in this example.
        # Even this one line works! (but no validation)
        model.fit(train)
    else:
        print("Example 3. Train with configuration")
        model.fit(
            train,
            test=test,
            batchsize=args.batchsize,
            # iterator_class=chainer.iterators.SerialIterator,
            optimizer=chainer.optimizers.Adam(),
            epoch=args.epoch,
            out=args.out,
            snapshot_frequency=1,
            # dump_graph=False,
            # log_report=True,
            plot_report=False,
            # print_report=True,
            progress_report=False,
            resume=args.resume
        )

    # Save the trained model
    serializers.save_npz('{}/mlp.model'.format(args.out), model)
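The parameters written by serializers.save_npz can later be restored into an identically constructed model with serializers.load_npz; this is standard Chainer, though SklearnWrapperClassifier and MLP are this example's own classes, and the constructor arguments below simply mirror the defaults above:

from chainer import serializers

# Rebuild a model with the same architecture, then load the saved parameters.
model = SklearnWrapperClassifier(MLP(50, 10))
serializers.load_npz('result/mlp.model', model)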