Python torch.nn 模块,DataParallel() 实例源码

我们从Python开源项目中,提取了以下38个代码示例,用于说明如何使用torch.nn.DataParallel()

项目:MIL.pytorch    作者:gujiuxiang    | 项目源码 | 文件源码
def build_mil(opt):
    """Build a multiple-instance-learning model described by *opt*.

    Chooses resnet_mil vs vgg_mil from opt.model, optionally wraps it in
    nn.DataParallel for multi-GPU training, and resumes weights from
    opt.start_from when that path is non-empty.

    Returns:
        the model, moved to GPU and set to train mode.
    """
    opt.n_gpus = getattr(opt, 'n_gpus', 1)

    if 'resnet101' in opt.model:
        mil_model = resnet_mil(opt)
    else:
        mil_model = vgg_mil(opt)

    if opt.n_gpus > 1:
        print('Construct multi-gpu model ...')
        model = nn.DataParallel(mil_model, device_ids=opt.gpus, dim=0)
    else:
        model = mil_model
    # check compatibility if training is continued from previously saved model
    if len(opt.start_from) != 0:
        # check if all necessary files exist
        assert os.path.isdir(opt.start_from), "%s must be a path" % opt.start_from
        lm_info_path = os.path.join(opt.start_from, os.path.basename(opt.start_from) + '.infos-best.pkl')
        lm_pth_path = os.path.join(opt.start_from, os.path.basename(opt.start_from) + '.model-best.pth')
        assert os.path.isfile(lm_info_path), "infos.pkl file does not exist in path %s" % opt.start_from
        # fixed: also verify the weights file before handing it to torch.load
        assert os.path.isfile(lm_pth_path), "model file does not exist in path %s" % opt.start_from
        model.load_state_dict(torch.load(lm_pth_path))
    model.cuda()
    model.train()  # Assure in training mode
    return model
项目:ExperimentPackage_PyTorch    作者:ICEORY    | 项目源码 | 文件源码
def dataparallel(model, ngpus, gpu0=0):
    """Move a model (or a list of models) onto GPUs gpu0 .. gpu0+ngpus-1.

    With ngpus >= 2 each model is wrapped in nn.DataParallel (unless it
    already is one, in which case it is left untouched); with ngpus == 1 it
    is simply moved to the GPU. CPU mode (ngpus == 0) is not supported.

    Returns:
        the wrapped/moved model, or list of models (mutated in place).
    """
    if ngpus == 0:
        assert False, "only support gpu mode"
    gpu_list = list(range(gpu0, gpu0 + ngpus))
    assert torch.cuda.device_count() >= gpu0 + ngpus, "Invalid Number of GPUs"

    def _to_gpu(m):
        # One place for the wrap-or-move decision (was duplicated for the
        # list and the single-model branches).
        if ngpus >= 2:
            if isinstance(m, nn.DataParallel):
                return m  # already parallelized; leave as-is
            return torch.nn.DataParallel(m, gpu_list).cuda()
        return m.cuda()

    if isinstance(model, list):
        # mutate in place so callers holding the same list see the change
        for i in range(len(model)):
            model[i] = _to_gpu(model[i])
    else:
        model = _to_gpu(model)
    return model
项目:NeuralMT    作者:hlt-mt    | 项目源码 | 文件源码
def new_instance(src_dict, trg_dict, model_params=None, random_seed=None, gpu_ids=None, init_value=0.1):
        """Assemble a fresh encoder/decoder NMT model plus optimizer and
        return them wrapped in an NMTEngineTrainer.
        """
        if model_params is None:
            from nmmt import NMTEngine
            model_params = NMTEngine.Parameters()

        use_gpu = gpu_ids is not None and len(gpu_ids) > 0
        if use_gpu:
            torch.cuda.set_device(gpu_ids[0])

        encoder = Models.Encoder(model_params, src_dict)
        decoder = Models.Decoder(model_params, trg_dict)
        generator = nn.Sequential(nn.Linear(model_params.rnn_size, trg_dict.size()), nn.LogSoftmax())

        model = Models.NMTModel(encoder, decoder)

        if use_gpu:
            model.cuda()
            generator.cuda()
            if len(gpu_ids) > 1:
                # NOTE: the model scatters on dim 1 while the generator
                # scatters on dim 0 (kept exactly as in the original code).
                model = nn.DataParallel(model, device_ids=gpu_ids, dim=1)
                generator = nn.DataParallel(generator, device_ids=gpu_ids, dim=0)
        else:
            model.cpu()
            generator.cpu()

        model.generator = generator

        # uniform re-initialization of every parameter
        for weight in model.parameters():
            weight.data.uniform_(-init_value, init_value)

        optim = Optim(model_params.optim, model_params.learning_rate, model_params.max_grad_norm,
                      lr_decay=model_params.learning_rate_decay, start_decay_at=model_params.start_decay_at)
        optim.set_parameters(model.parameters())

        return NMTEngineTrainer(model, optim, src_dict, trg_dict,
                                model_params=model_params, gpu_ids=gpu_ids, random_seed=random_seed)
项目:vqa.pytorch    作者:Cadene    | 项目源码 | 文件源码
def factory(opt, vocab_words, vocab_answers, cuda=True, data_parallel=True):
    """Instantiate a VQA model by architecture name.

    Args:
        opt: options dict; opt['arch'] selects the constructor from this module.
        vocab_words, vocab_answers: vocabularies forwarded to the constructor.
        cuda: move the model to GPU.
        data_parallel: wrap the model in nn.DataParallel (requires cuda).

    Returns:
        the constructed (and optionally wrapped) model.

    Raises:
        ValueError: for an unknown architecture, or data_parallel without cuda.
    """
    opt = copy.copy(opt)

    if opt['arch'] not in model_names:
        raise ValueError('unknown architecture: %s' % opt['arch'])
    model = getattr(sys.modules[__name__], opt['arch'])(opt, vocab_words, vocab_answers)

    if data_parallel:
        # fixed: validate before wrapping/moving to GPU instead of after
        if not cuda:
            raise ValueError('data_parallel=True requires cuda=True')
        model = nn.DataParallel(model).cuda()

    if cuda:
        model.cuda()

    return model
项目:pytorch    作者:ezyang    | 项目源码 | 文件源码
def test_data_parallel_module_kwargs_only(self):
        """DataParallel must scatter inputs passed purely as keyword args."""
        class Wrapper(nn.Module):
            def __init__(self):
                super(Wrapper, self).__init__()
                self.l = l

            def forward(self, input):
                return self.l(input)

        l = nn.Linear(10, 5).float().cuda()
        i = Variable(torch.randn(20, 10).float().cuda())
        expected_out = l(i).data
        dp = nn.DataParallel(Wrapper())
        out = dp(input=i)
        self.assertEqual(out.get_device(), 0)
        self.assertEqual(out.data, expected_out)
项目:PytorchDL    作者:FredHuangBia    | 项目源码 | 文件源码
def save(epoch, model, criterion, optimState, bestModel, loss, opt):
    """Checkpoint model/criterion/optimizer state for *epoch*.

    Writes model_<epoch>.pth, criterion_<epoch>.pth, optimState_<epoch>.pth
    and a 'latest.pth' info record every opt.saveEpoch epochs (or when
    bestModel); additionally refreshes 'best.pth' / 'model_best.pth' when
    bestModel is True.
    """
    # fixed: nn.DataParallel has no .get() method -- the wrapped network
    # is exposed as .module.
    if isinstance(model, nn.DataParallel):
        model = model.module

    if bestModel or (epoch % opt.saveEpoch == 0):
        if opt.saveOne:
            # keep only the newest checkpoint files
            subprocess.call('rm ' + opt.resume + '/*.pth', shell=True)

        modelFile = 'model_' + str(epoch) + '.pth'
        criterionFile = 'criterion_' + str(epoch) + '.pth'
        optimFile = 'optimState_' + str(epoch) +'.pth'
        torch.save(model, os.path.join(opt.resume, modelFile))
        torch.save(criterion, os.path.join(opt.resume, criterionFile))
        torch.save(optimState, os.path.join(opt.resume, optimFile))
        info = {'epoch':epoch, 'modelFile':modelFile, 'criterionFile':criterionFile, 'optimFile':optimFile, 'loss':loss}
        torch.save(info, os.path.join(opt.resume, 'latest.pth'))

    if bestModel:
        # bestModel also triggers the branch above, so modelFile etc. exist
        info = {'epoch':epoch, 'modelFile':modelFile, 'criterionFile':criterionFile, 'optimFile':optimFile, 'loss':loss}
        torch.save(info, os.path.join(opt.resume, 'best.pth'))
        torch.save(model, os.path.join(opt.resume, 'model_best.pth'))
项目:pytorch    作者:pytorch    | 项目源码 | 文件源码
def test_data_parallel_module_kwargs_only(self):
        """Kwarg-only inputs must produce the same output as the bare layer."""
        class Passthrough(nn.Module):
            def __init__(self):
                super(Passthrough, self).__init__()
                self.l = layer

            def forward(self, input):
                return self.l(input)

        layer = nn.Linear(10, 5).float().cuda()
        inp = Variable(torch.randn(20, 10).float().cuda())
        reference = layer(inp).data
        parallel = nn.DataParallel(Passthrough())
        result = parallel(input=inp)
        self.assertEqual(result.get_device(), 0)
        self.assertEqual(result.data, reference)
项目:pytorch    作者:pytorch    | 项目源码 | 文件源码
def test_data_parallel_module_kwargs_only_empty_list(self):
        """An empty list inside a kwarg dict must survive scatter/gather."""
        class DictNet(nn.Module):
            def __init__(self):
                super(DictNet, self).__init__()
                self.l = l

            def forward(self, input):
                return self.l(input['data'])

        l = nn.Linear(10, 5).float().cuda()
        i = Variable(torch.randn(20, 10).float().cuda())
        expected_out = l(i).data
        dp = nn.DataParallel(DictNet())
        out = dp(input={'data': i, 'unused': []})
        self.assertEqual(out.get_device(), 0)
        self.assertEqual(out.data, expected_out)
项目:pytorch    作者:pytorch    | 项目源码 | 文件源码
def test_data_parallel_module_kwargs_only_empty_dict(self):
        """An empty dict inside a kwarg dict must survive scatter/gather."""
        class DictNet(nn.Module):
            def __init__(self):
                super(DictNet, self).__init__()
                self.l = layer

            def forward(self, input):
                return self.l(input['data'])

        layer = nn.Linear(10, 5).float().cuda()
        sample = Variable(torch.randn(20, 10).float().cuda())
        reference = layer(sample).data
        parallel = nn.DataParallel(DictNet())
        result = parallel(input={'data': sample, 'unused': {}})
        self.assertEqual(result.get_device(), 0)
        self.assertEqual(result.data, reference)
项目:pytorch    作者:pytorch    | 项目源码 | 文件源码
def test_data_parallel_module_kwargs_only_empty_tuple(self):
        """An empty tuple inside a kwarg dict must survive scatter/gather."""
        class DictNet(nn.Module):
            def __init__(self):
                super(DictNet, self).__init__()
                self.l = l

            def forward(self, input):
                return self.l(input['data'])

        l = nn.Linear(10, 5).float().cuda()
        batch = Variable(torch.randn(20, 10).float().cuda())
        want = l(batch).data
        net = nn.DataParallel(DictNet())
        got = net(input={'data': batch, 'unused': ()})
        self.assertEqual(got.get_device(), 0)
        self.assertEqual(got.data, want)
项目:tutorials    作者:pytorch    | 项目源码 | 文件源码
def forward(self, input):
        """Apply the fully-connected layer, logging the per-replica sizes."""
        result = self.fc(input)
        print("  In Model: input size", input.size(), 
              "output size", result.size())

        return result


######################################################################
# Create Model and DataParallel
# -----------------------------
# 
# This is the core part of the tutorial. First, we need to make a model instance
# and check if we have multiple GPUs. If we have multiple GPUs, we can wrap 
# our model using ``nn.DataParallel``. Then we can put our model on GPUs by
# calling ``model.cuda()``
#
项目:vsepp    作者:fartashf    | 项目源码 | 文件源码
def get_cnn(self, arch, pretrained):
        """Load a pretrained CNN and parallelize over GPUs
        """
        if pretrained:
            print("=> using pre-trained model '{}'".format(arch))
            model = models.__dict__[arch](pretrained=True)
        else:
            print("=> creating model '{}'".format(arch))
            model = models.__dict__[arch]()

        # AlexNet/VGG: only the convolutional features are parallelized,
        # matching the torchvision ImageNet reference recipe.
        if arch.startswith(('alexnet', 'vgg')):
            model.features = nn.DataParallel(model.features)
            model.cuda()
            return model
        return nn.DataParallel(model).cuda()
项目:RetinaNet    作者:c0nn3r    | 项目源码 | 文件源码
def train(model, cuda=False):
    """Run one epoch of RetinaNet training over the COCO train loader.

    Relies on module-level train_loader, scheduler, optimizer and criterion.

    Args:
        model: the detection network.
        cuda: move model and batches to GPU (model also gets DataParallel).
    """
    average_loss = 0

    if cuda:
        model.cuda()
        model = nn.DataParallel(model)

    for current_batch, (images, box_targets, class_targets) in enumerate(
            tqdm(train_loader, desc='Training on COCO', unit='epoch')):

        scheduler.step()

        optimizer.zero_grad()

        if cuda:
            # fixed: Tensor.cuda() returns a copy, so it must be re-assigned
            images = images.cuda()
            box_targets = box_targets.cuda()
            class_targets = class_targets.cuda()

        images = Variable(images)
        # fixed: the unpacked name must match the one the criterion uses
        # (was `classes_predictions`, a NameError at the criterion call)
        box_predictions, class_predictions = model(images)

        loss = criterion(box_predictions, box_targets, class_predictions, class_targets)
        loss.backward()

        average_loss += loss[0]

        optimizer.step()

        # fixed: parenthesize (current_batch + 1) so the running mean is
        # computed correctly (division binds tighter than addition)
        print(f'Batch: {current_batch}, Loss: {loss[0]}, Average Loss: {average_loss / (current_batch + 1)}')
项目:ExperimentPackage_PyTorch    作者:ICEORY    | 项目源码 | 文件源码
def model2list(model):
    """Convert a model container to a plain Python list.

    :param model: a list, nn.DataParallel, or nn.Sequential
    :return: the model's layers as a list (a list input is returned as-is)
    """
    if isinstance(model, nn.DataParallel):
        return list(model.module)
    if isinstance(model, nn.Sequential):
        return list(model)
    if isinstance(model, list):
        return model
    assert False, "model should be type of <nn.DataParallel> or <nn.Sequential> or <list>"
项目:pytorch    作者:tylergenter    | 项目源码 | 文件源码
def test_data_parallel_module(self):
        """A wrapped module must match the unwrapped module's output."""
        layer = nn.Linear(10, 5).float().cuda()
        inp = Variable(torch.randn(20, 10).float().cuda())
        baseline = layer(inp).data
        wrapped = nn.DataParallel(layer)
        result = wrapped(inp)
        self.assertEqual(result.get_device(), 0)
        self.assertEqual(result.data, baseline)
项目:ml-utils    作者:LinxiFan    | 项目源码 | 文件源码
def __init__(self):
        """Three-layer MLP whose middle layer runs under DataParallel."""
        super().__init__()
        self.block1 = nn.Linear(10, 20)
        # only block2 is spread across GPUs
        self.block2 = nn.DataParallel(nn.Linear(20, 20))
        self.block3 = nn.Linear(20, 20)
项目:ml-utils    作者:LinxiFan    | 项目源码 | 文件源码
def forward(self, x):
        """Feed *x* through the three blocks in sequence."""
        for stage in (self.block1, self.block2, self.block3):
            x = stage(x)
        return x

########################################################################
# The code does not need to be changed in CPU-mode.
#
# The documentation for DataParallel is
# `here <http://pytorch.org/docs/nn.html#torch.nn.DataParallel>`_.
#
# **Primitives on which DataParallel is implemented:**
#
#
# In general, pytorch’s `nn.parallel` primitives can be used independently.
# We have implemented simple MPI-like primitives:
#
# - replicate: replicate a Module on multiple devices
# - scatter: distribute the input in the first-dimension
# - gather: gather and concatenate the input in the first-dimension
# - parallel\_apply: apply a set of already-distributed inputs to a set of
#   already-distributed models.
#
# For better clarity, here is how the function ``data_parallel`` is composed
# using these collectives
项目:LIE    作者:EmbraceLife    | 项目源码 | 文件源码
def parallelize(self):
        """ Wrap the model in DataParallel when GPU use was requested.
        """
        if not self.gpu:
            return self.model
        # a bare True means "use all visible devices" (DataParallel default)
        devices = None if isinstance(self.gpu, bool) else self.gpu
        return nn.DataParallel(self.model, devices).cuda()

    ###########################################################################
项目:pytorch-coriander    作者:hughperkins    | 项目源码 | 文件源码
def test_data_parallel_module(self):
        """DataParallel around a bare layer reproduces the layer's output."""
        linear = nn.Linear(10, 5).float().cuda()
        batch = Variable(torch.randn(20, 10).float().cuda())
        want = linear(batch).data
        dp = nn.DataParallel(linear)
        got = dp(batch)
        self.assertEqual(got.get_device(), 0)
        self.assertEqual(got.data, want)
项目:pytorch    作者:ezyang    | 项目源码 | 文件源码
def test_data_parallel_module(self):
        """Wrapping must be transparent: same device-0 output as the module."""
        module = nn.Linear(10, 5).float().cuda()
        data_in = Variable(torch.randn(20, 10).float().cuda())
        expected = module(data_in).data
        parallel = nn.DataParallel(module)
        actual = parallel(data_in)
        self.assertEqual(actual.get_device(), 0)
        self.assertEqual(actual.data, expected)
项目:PytorchDL    作者:FredHuangBia    | 项目源码 | 文件源码
def setup(opt, checkpoint):
    """Create a new model or resume one from a checkpoint.

    Args:
        opt: run options (resume dir, netType, cudnn mode, GPU settings, ...).
        checkpoint: info dict from a saved 'latest.pth', or None for a fresh model.

    Returns:
        (model, optimState) -- optimState is None unless resuming.
    """
    model = None
    if checkpoint is not None:
        modelPath = os.path.join(opt.resume, checkpoint['modelFile'])
        assert os.path.exists(modelPath), 'Saved model not found: '+modelPath
        print('=> Resuming model from ' + modelPath)
        model = torch.load(modelPath)
    else:
        print('=> Creating new model')
        models = importlib.import_module('models.' + opt.netType)
        model = models.createModel(opt)

    # fixed: nn.DataParallel exposes the wrapped network as .module
    # (it has no .get() method)
    if isinstance(model, nn.DataParallel):
        model = model.module

    if opt.resetClassifier and not checkpoint:
        pass
        #TODO

    if opt.cudnn == 'fastest':
        cudnn.fastest = True
        cudnn.benchmark = True
    elif opt.cudnn == 'deterministic':
        cudnn.fastest = False
        cudnn.benchmark = False
        #TODO

    if opt.nGPUs > 1:
        gpus = opt.GPUs
        fastest, benchmark = cudnn.fastest, cudnn.benchmark
        # TODO  make a dataparallel to split data on different GPUs

    optimState = None
    if checkpoint is not None:
        optimPath = os.path.join(opt.resume, checkpoint['optimFile'])
        assert os.path.exists(optimPath), 'Saved optimState not found: ' + optimPath
        print('=> Resuming optimState from ' + optimPath)
        optimState = torch.load(optimPath)

    return model, optimState
项目:PytorchDL    作者:FredHuangBia    | 项目源码 | 文件源码
def __init__(self, opt):
        """Build the wrapped network and spread it across opt.GPUs."""
        super().__init__()
        self.opt = opt
        self.model = nn.DataParallel(myModel(opt), opt.GPUs)
项目:PytorchDL    作者:FredHuangBia    | 项目源码 | 文件源码
def __init__(self, opt):
        """Construct myModel and parallelize it over opt.GPUs."""
        super().__init__()
        self.opt = opt
        self.model = nn.DataParallel(myModel(opt), opt.GPUs)
项目:PytorchDL    作者:FredHuangBia    | 项目源码 | 文件源码
def __init__(self, opt):
        """Instantiate the network, wrapped for multi-GPU execution."""
        super().__init__()
        self.opt = opt
        self.model = nn.DataParallel(myModel(opt), opt.GPUs)
项目:PytorchDL    作者:FredHuangBia    | 项目源码 | 文件源码
def __init__(self, opt):
        """Build an ENet and spread it across opt.GPUs via DataParallel."""
        super().__init__()
        self.opt = opt
        self.model = nn.DataParallel(ENet(opt), opt.GPUs)
项目:PytorchDL    作者:FredHuangBia    | 项目源码 | 文件源码
def __init__(self, opt):
        """Create myModel and hand it to DataParallel on opt.GPUs."""
        super().__init__()
        self.opt = opt
        self.model = nn.DataParallel(myModel(opt), opt.GPUs)
项目:PytorchDL    作者:FredHuangBia    | 项目源码 | 文件源码
def __init__(self, opt):
        """Keep the options and a DataParallel-wrapped myModel instance."""
        super().__init__()
        self.opt = opt
        self.model = nn.DataParallel(myModel(opt), opt.GPUs)
项目:iffse    作者:kendricktan    | 项目源码 | 文件源码
def load_openface_net(checkpoint_pth, cuda=True, gpu_id=0, multi_gpu=False):
    """
    Build an OpenFace network and restore its weights from the
    checkpoint file (openface.pth).
    """
    net = netOpenFace(cuda, gpu_id)
    net.load_state_dict(torch.load(checkpoint_pth))
    return nn.DataParallel(net) if multi_gpu else net
项目:pytorch    作者:pytorch    | 项目源码 | 文件源码
def test_data_parallel_module(self):
        """Output of DataParallel(l) must equal l's own output on device 0."""
        base = nn.Linear(10, 5).float().cuda()
        tensor_in = Variable(torch.randn(20, 10).float().cuda())
        reference = base(tensor_in).data
        replicated = nn.DataParallel(base)
        output = replicated(tensor_in)
        self.assertEqual(output.get_device(), 0)
        self.assertEqual(output.data, reference)
项目:pytorch-vqa    作者:Cyanogenoid    | 项目源码 | 文件源码
def main():
    """Train a VQA model, validating and checkpointing after every epoch."""
    from datetime import datetime

    cli_args = sys.argv[1:]
    # run name: CLI words if given, otherwise a timestamp
    name = ' '.join(cli_args) if cli_args else datetime.now().strftime("%Y-%m-%d_%H:%M:%S")
    target_name = os.path.join('logs', '{}.pth'.format(name))
    print('will save to {}'.format(target_name))

    cudnn.benchmark = True

    train_loader = data.get_loader(train=True)
    val_loader = data.get_loader(val=True)

    net = nn.DataParallel(model.Net(train_loader.dataset.num_tokens)).cuda()
    optimizer = optim.Adam([p for p in net.parameters() if p.requires_grad])

    tracker = utils.Tracker()
    config_as_dict = {k: v for k, v in vars(config).items() if not k.startswith('__')}

    for epoch in range(config.epochs):
        run(net, train_loader, optimizer, tracker, train=True, prefix='train', epoch=epoch)
        r = run(net, val_loader, optimizer, tracker, train=False, prefix='val', epoch=epoch)

        # overwrite the checkpoint each epoch with the latest eval results
        results = {
            'name': name,
            'tracker': tracker.to_dict(),
            'config': config_as_dict,
            'weights': net.state_dict(),
            'eval': {
                'answers': r[0],
                'accuracies': r[1],
                'idx': r[2],
            },
            'vocab': train_loader.dataset.vocab,
        }
        torch.save(results, target_name)
项目:kur    作者:deepgram    | 项目源码 | 文件源码
def parallelize(self):
        """ Apply DataParallel to the model if a GPU setting was provided.
        """
        if not self.gpu:
            return self.model
        # bool True -> let DataParallel pick all visible devices
        devices = None if isinstance(self.gpu, bool) else self.gpu
        return nn.DataParallel(self.model, devices).cuda()

    ###########################################################################
项目:tutorials    作者:pytorch    | 项目源码 | 文件源码
def __init__(self):
        """Three linear blocks; the middle one is wrapped in DataParallel."""
        super().__init__()
        self.block1 = nn.Linear(10, 20)
        # block2 is the only parallelized stage
        self.block2 = nn.DataParallel(nn.Linear(20, 20))
        self.block3 = nn.Linear(20, 20)
项目:tutorials    作者:pytorch    | 项目源码 | 文件源码
def forward(self, x):
        """Pipe *x* through block1, block2 and block3 in order."""
        out = self.block1(x)
        out = self.block2(out)
        return self.block3(out)

########################################################################
# The code does not need to be changed in CPU-mode.
#
# The documentation for DataParallel is
# `here <http://pytorch.org/docs/nn.html#torch.nn.DataParallel>`_.
#
# **Primitives on which DataParallel is implemented:**
#
#
# In general, pytorch’s `nn.parallel` primitives can be used independently.
# We have implemented simple MPI-like primitives:
#
# - replicate: replicate a Module on multiple devices
# - scatter: distribute the input in the first-dimension
# - gather: gather and concatenate the input in the first-dimension
# - parallel\_apply: apply a set of already-distributed inputs to a set of
#   already-distributed models.
#
# For better clarity, here is how the function ``data_parallel`` is composed
# using these collectives
项目:age    作者:ly015    | 项目源码 | 文件源码
def extract_feat(model_id, model = None):
    """Extract per-frame CNN features for the video-age dataset and save them.

    Args:
        model_id: identifier used to locate the checkpoint
            ('models/<model_id>/best.pth') and the output directory
            ('output/feature/<model_id>').
        model: optional pre-loaded GANModel; when None it is loaded from the
            checkpoint belonging to *model_id*.
    """
    if model is None:
        model = GANModel(fn = 'models/%s/best.pth' % model_id)

    # Only the CNN sub-module is parallelized across the visible GPUs.
    if torch.cuda.device_count() > 1:
        model.cnn = nn.DataParallel(model.cnn)
    model.cuda()
    model.eval()

    output_dir = os.path.join('output', 'feature', model_id)
    io.mkdir_if_missing(output_dir)

    for subset in ['train', 'test']:
        dset = dataset.load_video_age_dataset(version = '2.0', subset = subset,
            crop_size = 128, age_rng = [model.opts.min_age, model.opts.max_age])

        # Batch size scales with the GPU count (32 sequences per device).
        loader = torch.utils.data.DataLoader(dset, batch_size = torch.cuda.device_count() * 32, shuffle = False, 
            num_workers = 4, pin_memory = True)

        feats = []
        for batch_idx, data in enumerate(loader):
            img_seq, seq_len, _, _ = data
            # volatile=True: legacy (pre-0.4 PyTorch) inference-only mode
            img_seq = Variable(img_seq, volatile = True).cuda()
            seq_len = Variable(seq_len, volatile = True).cuda()

            age_out, _, feat = model.forward_video(img_seq, seq_len)
            feats.append(feat.data.cpu())
            # '\r' keeps the progress report on a single console line
            print('\r[extract CNN feature] %s: %.2f%%' % (subset, 100.*batch_idx/len(loader)), end = '')
            sys.stdout.flush()
        print('\n')

        # Concatenate all batches along dim 0 and persist as a pickle.
        feats = torch.cat(feats, dim = 0).numpy()
        id_lst = dset.id_lst
        out = {'feat': feats, 'id_lst': id_lst}
        io.save_data(out, os.path.join(output_dir, subset + '.pkl'))
项目:OpenNMT-py    作者:OpenNMT    | 项目源码 | 文件源码
def drop_checkpoint(self, opt, epoch, fields, valid_stats):
        """ Save a resumable checkpoint.

        Args:
            opt (dict): option object
            epoch (int): epoch number
            fields (dict): fields and vocabulary
            valid_stats : statistics of last validation run
        """
        # Unwrap DataParallel containers so the state dicts use plain keys.
        base_model = (self.model.module
                      if isinstance(self.model, nn.DataParallel)
                      else self.model)
        base_generator = (base_model.generator.module
                          if isinstance(base_model.generator, nn.DataParallel)
                          else base_model.generator)

        # The generator is checkpointed separately, so strip its entries
        # from the model's state dict.
        model_state_dict = {k: v for k, v in base_model.state_dict().items()
                            if 'generator' not in k}
        checkpoint = {
            'model': model_state_dict,
            'generator': base_generator.state_dict(),
            'vocab': onmt.io.save_fields_to_vocab(fields),
            'opt': opt,
            'epoch': epoch,
            'optim': self.optim,
        }
        filename = ('%s_acc_%.2f_ppl_%.2f_e%d.pt'
                    % (opt.save_model, valid_stats.accuracy(),
                       valid_stats.ppl(), epoch))
        torch.save(checkpoint, filename)
项目:OpenNMT-py    作者:OpenNMT    | 项目源码 | 文件源码
def build_model(model_opt, opt, fields, checkpoint):
    """Construct the NMT model, spreading it over GPUs when several are given."""
    print('Building model...')
    model = onmt.ModelConstructor.make_base_model(model_opt, fields,
                                                  use_gpu(opt), checkpoint)
    if len(opt.gpuid) > 1:
        print('Multi gpu training: ', opt.gpuid)
        # NOTE(review): scatter is on dim=1 here -- presumably the batch
        # dimension is second in this model's input layout; confirm.
        model = nn.DataParallel(model, device_ids=opt.gpuid, dim=1)
    print(model)

    return model
项目:age    作者:ly015    | 项目源码 | 文件源码
def test_model_video(model, test_opts):
    """Run pose inference over a video-age dataset and save per-sequence results.

    Args:
        model: pose model exposing .cnn and forward_video(img_seq, seq_len).
        test_opts: options with dataset_version, subset, crop_size, batch_size
            and id (a model id, or a .pth path used to derive the output dir).
    """
    print('[PoseModel.test_video] test options: %s' % test_opts)

    # move model to GPU and set to eval mode.
    if torch.cuda.device_count() > 1:
        model.cnn = nn.DataParallel(model.cnn)
    model.cuda()
    model.eval()

    # create dataloader
    test_dset = dataset.load_video_age_dataset(version = test_opts.dataset_version, subset = test_opts.subset, 
        crop_size = test_opts.crop_size, age_rng = [0, 70])
    test_loader = torch.utils.data.DataLoader(test_dset, batch_size = test_opts.batch_size, num_workers = 4)

    pose_pred = []

    for batch_idx, data in enumerate(test_loader):

        img_seq, seq_len, _, _  = data
        # volatile=True: legacy (pre-0.4 PyTorch) inference-only mode
        img_seq = Variable(img_seq, volatile = True).cuda()
        seq_len = Variable(seq_len, volatile = True).cuda()

        pose = model.forward_video(img_seq, seq_len)

        # keep only the first seq_len[i] frames of each padded sequence
        for i, l in enumerate(seq_len):
            l = int(l.data[0])
            pose_pred.append(pose.data.cpu()[i, 0:l, :].numpy().tolist())
        print('\rTesting %d/%d (%.2f%%)' % (batch_idx, len(test_loader), 100.*batch_idx/len(test_loader)), end = '')
        sys.stdout.flush()
    print('\n')


    # result
    id_lst = test_dset.id_lst
    rst = {s_id:{'pose': p} for s_id, p in zip(id_lst, pose_pred)}

    # output result
    if test_opts.id.endswith('.pth'):
        # test_opts.id is a file name
        output_dir = os.path.dirname(test_opts.id)
    else:
        # test_opts.id is a model id
        output_dir = os.path.join('models', test_opts.id)

    assert os.path.isdir(output_dir)

    fn_rst = os.path.join(output_dir, 'video_age_v%s_test_rst.pkl' % test_opts.dataset_version)
    io.save_data(rst, fn_rst)
项目:age    作者:ly015    | 项目源码 | 文件源码
def test_model_video(model, test_opts):
    """Run attribute inference over a video-age dataset and save per-sequence results.

    Args:
        model: attribute model exposing .cnn and forward_video(img_seq, seq_len).
        test_opts: options with dataset_version, subset, crop_size, batch_size
            and id (a model id, or a .pth path used to derive the output dir).
    """
    print('[AttributeModel.test] test options: %s' % test_opts)

    # move model to GPU and set to eval mode.
    if torch.cuda.device_count() > 1:
        model.cnn = nn.DataParallel(model.cnn)
    model.cuda()
    model.eval()

    # create dataloader
    test_dset = dataset.load_video_age_dataset(version = test_opts.dataset_version, subset = test_opts.subset, 
        crop_size = test_opts.crop_size, age_rng = [0, 70])
    test_loader = torch.utils.data.DataLoader(test_dset, batch_size = test_opts.batch_size, num_workers = 4)

    attr_pred = []
    for batch_idx, data in enumerate(test_loader):

        img_seq, seq_len, _, _  = data
        # volatile=True: legacy (pre-0.4 PyTorch) inference-only mode
        img_seq = Variable(img_seq, volatile = True).cuda()
        seq_len = Variable(seq_len, volatile = True).cuda()

        attr = model.forward_video(img_seq, seq_len)

        # keep only the first seq_len[i] frames of each padded sequence
        for i, l in enumerate(seq_len):
            l = int(l.data[0])
            attr_pred.append(attr.data.cpu()[i, 0:l, :].numpy().tolist())
        print('\rTesting %d/%d (%.2f%%)' % (batch_idx, len(test_loader), 100.*batch_idx/len(test_loader)), end = '')
        sys.stdout.flush()
    print('\n')

    id_lst = test_dset.id_lst
    rst = {s_id:{'attr': p} for s_id, p in zip(id_lst, attr_pred)}

    # output result
    if test_opts.id.endswith('.pth'):
        # test_opts.id is a file name
        output_dir = os.path.dirname(test_opts.id)
    else:
        # test_opts.id is a model id
        output_dir = os.path.join('models', test_opts.id)

    assert os.path.isdir(output_dir)

    fn_rst = os.path.join(output_dir, 'video_age_v%s_test_rst.pkl' % test_opts.dataset_version)
    io.save_data(rst, fn_rst)