Python caffe.proto.caffe_pb2 模块,TEST 实例源码

我们从Python开源项目中,提取了以下12个代码示例,用于说明如何使用caffe.proto.caffe_pb2.TEST

项目:channel-pruning    作者:yihui-he    | 项目源码 | 文件源码
def resnet(n=3, num_output = 16):
    """Write solver and net definitions for a CIFAR-10 ResNet.

    The depth is 6n+2, so n = 3, 9, 18 correspond to 20, 56 and 110
    layers. A folder path ``resnet-<depth>`` under the current working
    directory is passed as ``folder`` to every ``Solver`` and to
    ``Net.write``.

    Args:
        n: Depth multiplier; also forwarded to ``Net.resnet_cifar``.
        num_output: Forwarded to ``Net.resnet_cifar`` as ``num_output``.
    """
    net_name = "resnet-"
    depth = str(6*n+2)
    pt_folder = osp.join(osp.abspath(osp.curdir), net_name + depth)
    name = net_name + depth + '-cifar10'

    # For n > 18 an extra warm-up solver file is written first.
    if n > 18:
        warm_solver = Solver(solver_name="solver_warm.prototxt", folder=pt_folder, lr_policy=Solver.policy.fixed)
        warm_solver.p.base_lr = 0.01
        warm_solver.set_max_iter(500)
        warm_solver.write()
        del warm_solver

    # Solver with default settings for the same folder.
    Solver(folder=pt_folder).write()

    # One Data layer per phase; only the TRAIN layer is given a crop size.
    net = Net(name)
    net.Data('cifar-10-batches-py/train', phase='TRAIN', crop_size=32)
    net.Data('cifar-10-batches-py/test', phase='TEST')
    net.resnet_cifar(n, num_output=num_output)
    net.write(folder=pt_folder)
项目:channel-pruning    作者:yihui-he    | 项目源码 | 文件源码
def resnet_orth(n=3):
    """Write solver and net definitions for the ``orth`` CIFAR-10 ResNet.

    The depth is 6n+2, so n = 3, 9, 18 correspond to 20, 56 and 110
    layers. A folder path ``resnet-orth-<depth>`` under the current
    working directory is passed as ``folder`` to every ``Solver`` and to
    ``Net.write``.

    Args:
        n: Depth multiplier; also forwarded to ``Net.resnet_cifar``,
            which is called with ``orth=True``.
    """
    net_name = "resnet-orth-"
    depth = str(6*n+2)
    pt_folder = osp.join(osp.abspath(osp.curdir), net_name + depth)
    name = net_name + depth + '-cifar10'

    # For n > 18 an extra warm-up solver file is written first.
    if n > 18:
        warm_solver = Solver(solver_name="solver_warm.prototxt", folder=pt_folder, lr_policy=Solver.policy.fixed)
        warm_solver.p.base_lr = 0.01
        warm_solver.set_max_iter(500)
        warm_solver.write()
        del warm_solver

    # Solver with default settings for the same folder.
    Solver(folder=pt_folder).write()

    # One Data layer per phase; only the TRAIN layer is given a crop size.
    net = Net(name)
    net.Data('cifar-10-batches-py/train', phase='TRAIN', crop_size=32)
    net.Data('cifar-10-batches-py/test', phase='TEST')
    net.resnet_cifar(n, orth=True)
    net.write(folder=pt_folder)
项目:pre-resnet-gen-caffe    作者:Cysu    | 项目源码 | 文件源码
def _get_include(phase):
    """Build a ``NetStateRule`` restricted to one phase.

    Args:
        phase: ``'train'`` or ``'test'`` (lowercase).

    Returns:
        A ``caffe_pb2.NetStateRule`` whose ``phase`` is ``caffe_pb2.TRAIN``
        or ``caffe_pb2.TEST`` respectively.

    Raises:
        ValueError: If ``phase`` is any other value.
    """
    rule = caffe_pb2.NetStateRule()
    if phase not in ('train', 'test'):
        raise ValueError("Unknown phase {}".format(phase))
    rule.phase = caffe_pb2.TRAIN if phase == 'train' else caffe_pb2.TEST
    return rule
项目:resnet-cifar10-caffe    作者:yihui-he    | 项目源码 | 文件源码
def transform_param(self, mean_value=128, batch_size=128, scale=.0078125, mirror=1, crop_size=None, mean_file_size=None, phase=None):
        """Fill in the ``transform_param`` message of the current layer.

        ``scale`` and ``mean_value`` are always written. ``mirror`` and
        ``crop_size`` are written only when ``phase`` is not ``'TEST'``.

        Args:
            mean_value: A single mean, or a list/tuple of mean values,
                appended to the ``mean_value`` field.
            batch_size: Not used by this method; kept so existing calls
                keep working.
            scale: Value assigned to ``transform_param.scale``.
            mirror: Value assigned to ``transform_param.mirror``.
            crop_size: Assigned to ``transform_param.crop_size`` when not
                ``None``.
            mean_file_size: Not used by this method; kept so existing calls
                keep working.
            phase: When equal to ``'TEST'``, ``mirror`` and ``crop_size``
                are left untouched.
        """
        new_transform_param = self.this.transform_param
        new_transform_param.scale = scale

        # extend() needs a flat sequence: a sequence-valued mean is passed
        # through as-is, a scalar is wrapped. Wrapping a list again would
        # nest it.
        if isinstance(mean_value, (list, tuple)):
            new_transform_param.mean_value.extend(mean_value)
        else:
            new_transform_param.mean_value.extend([mean_value])

        if phase == 'TEST':
            return

        new_transform_param.mirror = mirror
        if crop_size is not None:
            new_transform_param.crop_size = crop_size
项目:resnet-cifar10-caffe    作者:yihui-he    | 项目源码 | 文件源码
def include(self, phase='TRAIN'):
        """Add an include rule that restricts the current layer to a phase.

        Args:
            phase: ``'TRAIN'``, ``'TEST'`` or ``None``. With ``None`` no
                rule is added and the layer is left as it is.

        Raises:
            NotImplementedError: If ``phase`` is any other value. The check
                runs before a rule is added, so an unsupported phase never
                leaves behind an include rule with no phase set.
        """
        if phase is None:
            return
        if phase not in ('TRAIN', 'TEST'):
            raise NotImplementedError("Unknown phase {!r}".format(phase))

        includes = self.this.include.add()
        includes.phase = caffe_pb2.TRAIN if phase == 'TRAIN' else caffe_pb2.TEST


    #************************** inplace **************************
项目:channel-pruning    作者:yihui-he    | 项目源码 | 文件源码
def transform_param(self, mean_value=128, batch_size=128, scale=.0078125, mirror=1, crop_size=None, mean_file_size=None, phase=None):
        """Fill in the ``transform_param`` message of the current layer.

        ``scale`` is written only when it differs from 1. ``mean_value`` is
        always appended. ``mirror`` and ``crop_size`` are written only when
        ``phase`` is not ``'TEST'``.

        Args:
            mean_value: A single mean or a list of means, appended to the
                ``mean_value`` field.
            batch_size: Not used by this method.
            scale: Assigned to ``transform_param.scale`` unless equal to 1.
            mirror: Value assigned to ``transform_param.mirror``.
            crop_size: Assigned to ``transform_param.crop_size`` when not
                ``None``.
            mean_file_size: Not used by this method.
            phase: When equal to ``'TEST'``, ``mirror`` and ``crop_size``
                are left untouched.
        """
        params = self.this.transform_param

        if scale != 1:
            params.scale = scale

        # A scalar mean is wrapped so extend() always gets a list.
        means = mean_value if isinstance(mean_value, list) else [mean_value]
        params.mean_value.extend(means)

        if phase == 'TEST':
            return

        params.mirror = mirror
        if crop_size is not None:
            params.crop_size = crop_size
项目:channel-pruning    作者:yihui-he    | 项目源码 | 文件源码
def include(self, phase='TRAIN'):
        """Add an include rule that restricts the current layer to a phase.

        Args:
            phase: ``'TRAIN'``, ``'TEST'`` or ``None``. With ``None`` no
                rule is added and the layer is left as it is.

        Raises:
            NotImplementedError: If ``phase`` is any other value. The check
                runs before a rule is added, so an unsupported phase never
                leaves behind an include rule with no phase set.
        """
        if phase is None:
            return
        if phase not in ('TRAIN', 'TEST'):
            raise NotImplementedError("Unknown phase {!r}".format(phase))

        includes = self.this.include.add()
        includes.phase = caffe_pb2.TRAIN if phase == 'TRAIN' else caffe_pb2.TEST


    #************************** inplace **************************
项目:channel-pruning    作者:yihui-he    | 项目源码 | 文件源码
def BatchNorm(self, name=None, inplace=True,eps=1e-5):
        """Add a BatchNorm layer and return its name.

        The layer name comes from ``self.suffix('bn', name)``. The layer is
        registered through ``self.setup`` and given three param specs, each
        with ``lr_mult=0`` and ``decay_mult=0``.

        Args:
            name: Passed to ``self.suffix`` to derive the layer name.
            inplace: Forwarded to ``self.setup``.
            eps: Written to ``batch_norm_param.eps`` only when it differs
                from 1e-5.

        Returns:
            The layer name produced by ``self.suffix``.
        """
        bn_name = self.suffix('bn', name)
        self.setup(bn_name, 'BatchNorm', inplace=inplace)

        # NOTE(review): three specs presumably match the layer's three
        # blobs -- confirm against the Caffe BatchNorm layer.
        for _ in range(3):
            self.param(lr_mult=0, decay_mult=0)

        batch_norm_param = self.this.batch_norm_param
        # Only write eps when the caller asked for a value other than 1e-5.
        if eps != 1e-5:
            batch_norm_param.eps = eps

        # NOTE: a disabled (commented-out) variant used to follow the return
        # here; it emitted a second BatchNorm layer for phase='TEST' with
        # use_global_stats = True. It was unreachable and has been removed.
        return bn_name
项目:channel-pruning    作者:yihui-he    | 项目源码 | 文件源码
def plain(n=3):
    """Write solver and net definitions for a plain CIFAR-10 network.

    The depth is 6n+2, so n = 3, 9, 18 correspond to 20, 56 and 110
    layers. A folder path ``plain<depth>`` under the current working
    directory is passed as ``folder`` to the ``Solver`` and to
    ``Net.write``.

    Args:
        n: Depth multiplier; also forwarded to ``Net.plain_cifar``, which
            is called with ``num_output=16``.
    """
    net_name = "plain"
    depth = str(6*n+2)
    pt_folder = osp.join(osp.abspath(osp.curdir), net_name + depth)
    name = net_name + depth + '-cifar10'

    Solver(folder=pt_folder).write()

    # One Data layer per phase; only the TRAIN layer is given a crop size.
    net = Net(name)
    net.Data('cifar-10-batches-py/train', phase='TRAIN', crop_size=32)
    net.Data('cifar-10-batches-py/test', phase='TEST')
    net.plain_cifar(n, num_output = 16)
    net.write(folder=pt_folder)
项目:channel-pruning    作者:yihui-he    | 项目源码 | 文件源码
def plain_orth(n=3):
    """Write solver and net definitions for the ``orth`` plain CIFAR-10 network.

    The depth is 6n+2, so n = 3, 5, 7, 9, 18 correspond to 20, 32, 44, 56
    and 110 layers. A folder path ``plain-orth<depth>`` under the current
    working directory is passed as ``folder`` to the ``Solver`` and to
    ``Net.write``.

    Args:
        n: Depth multiplier; also forwarded to ``Net.plain_cifar``, which
            is called with ``orth=True``.
    """
    net_name = "plain-orth"
    depth = str(6*n+2)
    pt_folder = osp.join(osp.abspath(osp.curdir), net_name + depth)
    name = net_name + depth + '-cifar10'

    Solver(folder=pt_folder).write()

    # One Data layer per phase; only the TRAIN layer is given a crop size.
    net = Net(name)
    net.Data('cifar-10-batches-py/train', phase='TRAIN', crop_size=32)
    net.Data('cifar-10-batches-py/test', phase='TEST')
    net.plain_cifar(n, orth=True)
    net.write(folder=pt_folder)
项目:channel-pruning    作者:yihui-he    | 项目源码 | 文件源码
def plain_orth_v1(n=3):
    """Write solver and net definitions for the v1 ``orth`` plain CIFAR-10 network.

    The depth is 6n+2, so n = 3, 5, 7, 9, 18 correspond to 20, 32, 44, 56
    and 110 layers. A folder path ``plain-orth-v1-<depth>`` under the
    current working directory is passed as ``folder`` to the ``Solver``
    and to ``Net.write``.

    Args:
        n: Depth multiplier; also forwarded to ``Net.plain_cifar``, which
            is called with ``orth=True``, ``inplace=False`` and
            ``num_output=16``.
    """
    net_name = "plain-orth-v1-"
    depth = str(6*n+2)
    pt_folder = osp.join(osp.abspath(osp.curdir), net_name + depth)
    name = net_name + depth + '-cifar10'

    Solver(folder=pt_folder).write()

    # One Data layer per phase; only the TRAIN layer is given a crop size.
    net = Net(name)
    net.Data('cifar-10-batches-py/train', phase='TRAIN', crop_size=32)
    net.Data('cifar-10-batches-py/test', phase='TEST')
    net.plain_cifar(n, orth=True, inplace=False, num_output = 16)
    net.write(folder=pt_folder)
项目:channel-pruning    作者:yihui-he    | 项目源码 | 文件源码
def acc(n=3):
    """Write solver and net definitions for a non-inplace plain CIFAR-10 network.

    The depth is 6n+2, so n = 3, 9, 18 correspond to 20, 56 and 110
    layers. A folder path ``plain<depth>`` under the current working
    directory is passed as ``folder`` to the ``Solver`` and to
    ``Net.write``.

    Args:
        n: Depth multiplier; also forwarded to ``Net.plain_cifar``, which
            is called with ``num_output=16`` and ``inplace=False``.
    """
    net_name = "plain"
    depth = str(6*n+2)
    pt_folder = osp.join(osp.abspath(osp.curdir), net_name + depth)
    name = net_name + depth + '-cifar10'

    Solver(folder=pt_folder).write()

    # One Data layer per phase; only the TRAIN layer is given a crop size.
    net = Net(name)
    net.Data('cifar-10-batches-py/train', phase='TRAIN', crop_size=32)
    net.Data('cifar-10-batches-py/test', phase='TEST')
    net.plain_cifar(n, num_output = 16, inplace=False)
    net.write(folder=pt_folder)