我们从 Python 开源项目中提取了以下 3 个代码示例,用于说明如何使用 `chainer.training.extensions.observe_lr()`。
def __init__(self, **kwargs):
    """Set up the instance from keyword options.

    No key is mandatory. The optional keys listed below (they appear to
    mirror names from ``chainer.training.extensions``) are handed to the
    parent initializer together with the caller's keyword arguments and
    this instance's class name.
    """
    mandatory = []
    accepted = [
        'dump_graph',
        'Evaluator',
        'ExponentialShift',
        'LinearShift',
        'LogReport',
        'observe_lr',
        'observe_value',
        'snapshot',
        'PlotReport',
        'PrintReport',
    ]
    owner_name = self.__class__.__name__
    super().__init__(mandatory, accepted, kwargs, owner_name)
def start(self):
    """Train the pose network with a Chainer ``Trainer``.

    Steps, in order: seed the random generators when ``self.seed`` is set,
    build the model (optionally restoring weights), move it to the GPU when
    ``self.gpu >= 0``, create the datasets and iterators, set up the
    optimizer (optionally restoring its state), register the evaluation,
    snapshot and reporting extensions, optionally restore a trainer
    snapshot, and finally run training.
    """
    # Seed every random source that is in use.
    seed = self.seed
    if seed is not None:
        random.seed(seed)
        np.random.seed(seed)
        if self.gpu >= 0:
            chainer.cuda.cupy.random.seed(seed)

    # Build the network and optionally restore previously saved weights.
    net = AlexNet(self.Nj, self.use_visibility)
    if self.resume_model:
        serializers.load_npz(self.resume_model, net)

    # Select the device and move the network onto it.
    gpu_id = self.gpu
    if gpu_id >= 0:
        chainer.cuda.get_device(gpu_id).use()
        net.to_gpu()

    # Datasets: augmentation follows the instance setting for training and
    # is always disabled for validation.
    train_set = PoseDataset(
        self.train, data_augmentation=self.data_augmentation)
    val_set = PoseDataset(self.val, data_augmentation=False)

    # The validation iterator makes a single ordered pass.
    train_iterator = chainer.iterators.MultiprocessIterator(
        train_set, self.batchsize)
    val_iterator = chainer.iterators.MultiprocessIterator(
        val_set, self.batchsize, repeat=False, shuffle=False)

    # Optimizer, with optional restoration of its saved state.
    optimizer = self._get_optimizer()
    optimizer.setup(net)
    if self.resume_opt:
        chainer.serializers.load_npz(self.resume_opt, optimizer)

    # Updater and trainer; output goes to the 'chainer' directory under
    # self.out.
    updater = training.StandardUpdater(
        train_iterator, optimizer, device=gpu_id)
    stop_trigger = (self.epoch, 'epoch')
    out_dir = os.path.join(self.out, 'chainer')
    trainer = training.Trainer(updater, stop_trigger, out_dir)

    # Standard extensions: computational graph dump and periodic validation.
    trainer.extend(extensions.dump_graph('main/loss'))
    val_interval = (10, 'epoch')
    evaluator = TestModeEvaluator(val_iterator, net, device=gpu_id)
    trainer.extend(evaluator, trigger=val_interval)

    # Periodically save the model, the optimizer state and the trainer.
    # NOTE(review): `/` is true division, so under Python 3 the period is
    # a float -- confirm that this is intended.
    resume_interval = (self.epoch / 10, 'epoch')
    snapshot_targets = (
        (net, "epoch-{.updater.epoch}.model"),
        (optimizer, "epoch-{.updater.epoch}.state"),
    )
    for target, filename in snapshot_targets:
        trainer.extend(
            extensions.snapshot_object(target, filename),
            trigger=resume_interval)
    trainer.extend(
        extensions.snapshot(filename="epoch-{.updater.epoch}.iter"),
        trigger=resume_interval)

    # Logging: write a log entry, observe the learning rate and print the
    # selected columns every 10 iterations.
    log_interval = (10, "iteration")
    trainer.extend(extensions.LogReport(trigger=log_interval))
    trainer.extend(extensions.observe_lr(), trigger=log_interval)
    report_columns = ['epoch', 'main/loss', 'validation/main/loss', 'lr']
    trainer.extend(
        extensions.PrintReport(report_columns), trigger=log_interval)
    trainer.extend(extensions.ProgressBar(update_interval=10))

    # Optionally restore a full trainer snapshot, then start training.
    if self.resume:
        chainer.serializers.load_npz(self.resume, trainer)
    trainer.run()