我们从Python开源项目中，提取了以下6个代码示例，用于说明如何使用chainer.training.extensions.PlotReport()。
def set_trainer(self, out_dir, gpu, n_epoch, g_clip, opt_name, lr=None):
    """Build the trainer and store it on ``self.trainer``.

    The optimizer class is looked up by name on ``optimizers`` (presumably
    ``chainer.optimizers`` -- confirm against the file's imports), set up on
    ``self.model`` and given a gradient-clipping hook.  The trainer then gets
    an evaluator on ``self.test_iter``, a graph dump, a snapshot, log/plot/
    print reports and a progress bar.

    Args:
        out_dir: Output directory handed to the trainer (``out=``).
        gpu: Device id handed to the updater and the evaluator (``device=``).
        n_epoch: Number of epochs to train; also the snapshot interval, so
            the snapshot trigger fires when training reaches ``n_epoch``.
        g_clip: Threshold passed to ``GradientClipping``.
        opt_name: Name of the optimizer class to look up on ``optimizers``.
        lr: Optional first positional hyper-parameter for the optimizer.
            Ignored for ``"Adam"``.  When ``None`` the optimizer is built
            with its own defaults.
    """
    optimizer_cls = getattr(optimizers, opt_name)
    if opt_name == "Adam" or lr is None:
        # Previously ``None`` was forwarded as the learning rate whenever the
        # caller omitted ``lr``; fall back to the optimizer's defaults instead.
        opt = optimizer_cls()
    else:
        opt = optimizer_cls(lr)
    opt.setup(self.model)
    opt.add_hook(optimizer.GradientClipping(g_clip))

    updater = training.StandardUpdater(self.train_iter, opt, device=gpu)
    self.trainer = training.Trainer(updater, (n_epoch, 'epoch'), out=out_dir)

    self.trainer.extend(
        extensions.Evaluator(self.test_iter, self.model, device=gpu))
    self.trainer.extend(extensions.dump_graph('main/loss'))
    self.trainer.extend(extensions.snapshot(), trigger=(n_epoch, 'epoch'))
    self.trainer.extend(extensions.LogReport())
    # One figure per metric, each plotting training against validation.
    self.trainer.extend(extensions.PlotReport(
        ['main/loss', 'validation/main/loss'],
        'epoch', file_name='loss.png'))
    self.trainer.extend(extensions.PlotReport(
        ['main/accuracy', 'validation/main/accuracy'],
        'epoch', file_name='accuracy.png'))
    self.trainer.extend(extensions.PrintReport(
        ['epoch', 'main/loss', 'validation/main/loss',
         'main/accuracy', 'validation/main/accuracy', 'elapsed_time']))
    self.trainer.extend(extensions.ProgressBar())
def __init__(self, **kwargs):
    """Initialise through the parent constructor with a fixed key list.

    No key is required.  The optional keys are the names listed below.  The
    received keyword arguments are handed to the parent as one dict, along
    with this class's name.
    """
    mandatory = []
    permitted = [
        'dump_graph',
        'Evaluator',
        'ExponentialShift',
        'LinearShift',
        'LogReport',
        'observe_lr',
        'observe_value',
        'snapshot',
        'PlotReport',
        'PrintReport',
    ]
    owner_name = self.__class__.__name__
    super().__init__(mandatory, permitted, kwargs, owner_name)
def fit(self, X, y=None, **kwargs):
    """Train the model by delegating to ``self.fit_core``.

    The keyword arguments are first passed through
    ``self.filter_sk_params`` together with ``self.fit_core``; whatever that
    returns is forwarded to ``fit_core`` along with ``X`` and ``y``.

    As documented by the original author: a hyper parameter left as ``None``
    falls back to the instance's own variable (which is what makes grid
    search via ``set_params`` work), and if that is not set either,
    ``_default_hyperparam`` is used.

    Usage:
        ``model.fit(train_dataset)`` or ``model.fit(X, y)``

    Keyword options (as listed by the original documentation):
        train: training dataset, assumed to be a chainer dataset class.
        test: test dataset for evaluation, assumed to be a chainer dataset
            class.
        batchsize: batch size for both training and evaluation.
        iterator_class: iterator class used for training; currently assumed
            to be SerialIterator or MultiProcessIterator.
        optimizer: optimizer instance used to update parameters.
        epoch: number of training epochs.
        out: directory path where results are saved.
        snapshot_frequency (int): snapshot frequency in epochs; a negative
            value means no snapshot is taken.
        dump_graph: whether to save computational graph info (default
            False).
        log_report: enable LogReport or not.
        plot_report: enable PlotReport or not.
        print_report: enable PrintReport or not.
        progress_report: enable ProgressReport or not.
        resume: path of a saved trainer to resume training from.

    Returns:
        Whatever ``self.fit_core`` returns.
    """
    accepted_options = self.filter_sk_params(self.fit_core, kwargs)
    return self.fit_core(X, y, **accepted_options)
def train_task(args, train_name, model, epoch_num,
               train_dataset, test_dataset_dict, batch_size):
    """Train ``model`` with SGD while evaluating on several named test sets.

    Args:
        args: Object whose ``out`` attribute is used as the trainer's output
            directory.
        train_name: Base name of the accuracy plot (``<train_name>.png``).
        model: Model handed to the optimizer and to every evaluator.
        epoch_num: Number of epochs to train.
        train_dataset: Dataset for the training iterator.
        test_dataset_dict: Mapping of evaluator name to test dataset; one
            evaluator is registered per entry under that name.
        batch_size: Batch size for the training and all test iterators.
    """
    sgd = optimizers.SGD()
    sgd.setup(model)

    train_iter = iterators.SerialIterator(train_dataset, batch_size)
    # Test iterators make a single, unshuffled pass per evaluation.
    test_iters = {}
    for test_name, dataset in test_dataset_dict.items():
        test_iters[test_name] = iterators.SerialIterator(
            dataset, batch_size, repeat=False, shuffle=False)

    updater = training.StandardUpdater(train_iter, sgd)
    trainer = training.Trainer(updater, (epoch_num, 'epoch'), out=args.out)

    for test_name, test_iter in test_iters.items():
        trainer.extend(extensions.Evaluator(test_iter, model), test_name)

    # Report keys are prefixed with the name each evaluator was registered as.
    test_names = list(test_dataset_dict)
    loss_keys = [test_name + '/main/loss' for test_name in test_names]
    accuracy_keys = [test_name + "/main/accuracy" for test_name in test_names]

    trainer.extend(extensions.LogReport())
    trainer.extend(extensions.PrintReport(
        ['epoch', 'main/loss'] + loss_keys
        + ['main/accuracy'] + accuracy_keys))
    trainer.extend(extensions.ProgressBar())
    trainer.extend(extensions.PlotReport(
        accuracy_keys, file_name=train_name + ".png"))

    trainer.run()
def train_main(args):
    """Train the model specified in ``args`` and return it.

    Main method for the train subcommand.  Reads the corpus, restores or
    builds the model, writes the model arguments and initial weights to the
    checkpoint path, runs the trainer, then generates a text sample.
    """
    # Read the whole training corpus into memory.
    with open(args.text_path) as corpus_file:
        text = corpus_file.read()
    logger.info("corpus length: %s.", len(text))

    # Batches are produced by the project's own iterator.
    data_iter = DataIterator(text, args.batch_size, args.seq_len)

    # Build a fresh model unless a restore was requested.
    if not args.restore:
        network = Network(vocab_size=VOCAB_SIZE,
                          embedding_size=args.embedding_size,
                          rnn_size=args.rnn_size,
                          num_layers=args.num_layers,
                          drop_rate=args.drop_rate)
        model = L.Classifier(network)
    else:
        logger.info("restoring model.")
        # ``restore`` is either the literal True (reuse the checkpoint path)
        # or an explicit path to load from.
        if args.restore is True:
            load_path = args.checkpoint_path
        else:
            load_path = args.restore
        model = load_model(load_path)

    # Prepare the checkpoint location; persist model arguments and weights.
    log_dir = make_dirs(args.checkpoint_path)
    with open("{}.json".format(args.checkpoint_path), "w") as args_file:
        json.dump(model.predictor.args, args_file, indent=2)
    chainer.serializers.save_npz(args.checkpoint_path, model)
    logger.info("model saved: %s.", args.checkpoint_path)

    # Adam optimizer with gradient-norm clipping.
    optimizer = chainer.optimizers.Adam(alpha=args.learning_rate)
    optimizer.setup(model)
    optimizer.add_hook(chainer.optimizer.GradientClipping(args.clip_norm))

    # Assemble the trainer and its extensions.
    updater = BpttUpdater(data_iter, optimizer)
    trainer = chainer.training.Trainer(
        updater, (args.num_epochs, 'epoch'), out=log_dir)
    snapshot_name = os.path.basename(args.checkpoint_path)
    trainer.extend(extensions.snapshot_object(model, filename=snapshot_name))
    trainer.extend(extensions.ProgressBar(update_interval=1))
    trainer.extend(extensions.LogReport())
    trainer.extend(extensions.PlotReport(y_keys=["main/loss"]))
    trainer.extend(LoggerExtension(text))

    # Run training and time it.
    model.predictor.reset_state()
    logger.info("start of training.")
    started_at = time.time()
    trainer.run()

    elapsed = time.time() - started_at
    logger.info("end of training, duration: %ds.", elapsed)

    # Generate a sample from the trained model.
    seed = generate_seed(text)
    generate_text(model, seed, 1024, 3)
    return model