我们从 Python 开源项目中提取了以下 42 个代码示例，用于说明如何使用 six.moves.cPickle.loads()。
def load(fname):
    """Load an embedding dump generated by `save`.

    The dump is a pickled ``(vocabulary_state, vectors)`` pair. When the
    vocabulary state has exactly two entries it is treated as
    ``(words, counts)`` and rebuilt as a ``CountedVocabulary``; otherwise it
    is handed to ``OrderedVocabulary`` unchanged.

    :param fname: path understood by the module-level ``_open`` helper.
    :return: an ``Embedding`` built from the restored vocabulary and vectors.

    Security: the file contents are unpickled -- only load trusted dumps.
    """
    handle = _open(fname)
    try:
        content = handle.read()
    finally:
        # The handle returned by ``_open`` used to be left open.
        handle.close()
    if PY2:
        state = pickle.loads(content)
    else:
        # presumably so that dumps written under Python 2 (byte strings)
        # can still be decoded on Python 3
        state = pickle.loads(content, encoding='latin1')
    voc, vec = state
    if len(voc) == 2:
        words, counts = voc
        word_count = dict(zip(words, counts))
        vocab = CountedVocabulary(word_count=word_count)
    else:
        vocab = OrderedVocabulary(voc)
    return Embedding(vocabulary=vocab, vectors=vec)
def check_pickling(self, x_data):
    """Assert ``self.link`` gives bit-identical output after a pickle round trip.

    The link is evaluated on ``x_data``, serialized with the highest pickle
    protocol, replaced by its unpickled copy and evaluated again; the two
    outputs must match with zero tolerance.
    """
    before = self.link(chainer.Variable(x_data)).data

    serialized = pickle.dumps(self.link, -1)
    del self.link
    self.link = pickle.loads(serialized)

    after = self.link(chainer.Variable(x_data)).data
    gradient_check.assert_allclose(before, after, atol=0, rtol=0)
def test_transform_then_prediction(self):
    """Pipeline predictions and probabilities survive a cPickle round trip."""
    with TemporaryDirectory() as temp:
        from sklearn.pipeline import Pipeline
        path = os.path.join(temp, 'audio.sph')
        urlretrieve(filename=path,
                    url='https://s3.amazonaws.com/ai-datasets/sw02001.sph')
        pipeline = Pipeline([
            ('mspec', model.SpeechTransform('mspec', fs=8000, vad=False)),
            ('slice', model.Transform(lambda x: x[:, :40])),
            ('pred', model.SequentialModel(
                N.Dropout(0.3),
                N.Dense(20, activation=K.relu),
                N.Dense(10, activation=K.softmax))),
        ])
        expected_labels = pipeline.predict(path)
        expected_proba = pipeline.predict_proba(path)

        restored = cPickle.loads(cPickle.dumps(pipeline))
        actual_labels = restored.predict(path)
        actual_proba = restored.predict_proba(path)

        self.assertTrue(np.array_equal(expected_labels, actual_labels))
        self.assertTrue(np.array_equal(expected_proba, actual_proba))
def test_complex_transform(self):
    """Chained Transform steps produce identical output after a cPickle round trip."""
    with TemporaryDirectory() as temp:
        from sklearn.pipeline import Pipeline
        path = os.path.join(temp, 'audio.sph')
        urlretrieve(filename=path,
                    url='https://s3.amazonaws.com/ai-datasets/sw02001.sph')
        steps = [
            ('step1', model.SpeechTransform('mspec', fs=8000, vad=True)),
            ('step2', model.Transform(
                lambda x: (x[0][:, :40], x[1].astype(str)))),
            ('step3', model.Transform(
                lambda x: (np.sum(x[0]), ''.join(x[1].tolist())))),
        ]
        pipeline = Pipeline(steps)
        before = pipeline.transform(path)
        after = cPickle.loads(cPickle.dumps(pipeline)).transform(path)

        self.assertEqual(before[0], after[0])
        self.assertEqual(after[0], -3444229.0)
        self.assertEqual(before[1], after[1])
def test_load_save1(self):
    """A Dense layer keeps its weights and configuration across cPickle."""
    K.set_training(True)
    X = K.placeholder((None, 1, 28, 28))
    layer = N.Dense(128, activation=K.relu)
    output = layer(X)
    W, b = [K.get_value(p).sum()
            for p in K.ComputationGraph(output).parameters]
    num_units, W_init, b_init, activation = (
        layer.num_units, layer.W_init, layer.b_init, layer.activation)

    layer = cPickle.loads(cPickle.dumps(layer))
    W1, b1 = [K.get_value(p).sum() for p in layer.parameters]
    num_units1, W_init1, b_init1, activation1 = (
        layer.num_units, layer.W_init, layer.b_init, layer.activation)

    self.assertEqual(W1, W)
    self.assertEqual(b1, b)
    self.assertEqual(num_units1, num_units)
    self.assertEqual(W_init1.__name__, W_init.__name__)
    self.assertEqual(b_init.__name__, b_init1.__name__)
    self.assertEqual(activation1, activation)
def test_load_save2(self):
    """Forward and transposed Dense graphs agree before and after cPickle."""
    K.set_training(True)
    X = K.placeholder((None, 1, 28, 28))

    def build(layer):
        # Same call order as before: forward, transpose, then compile both.
        out = layer(X)
        out_t = layer.T(out)
        return K.function(X, out), K.function(X, out_t)

    layer = N.Dense(128, activation=K.relu)
    f1, f2 = build(layer)
    f3, f4 = build(cPickle.loads(cPickle.dumps(layer)))

    x = np.random.rand(12, 1, 28, 28)
    self.assertEqual(f1(x).sum(), f3(x).sum())
    self.assertEqual(f2(x).sum(), f4(x).sum())
def str_to_func(s, sandbox=None):
    """Rebuild a function serialized as a ``(code, closure, defaults)`` triple.

    :param s: one of
        * the triple itself (tuple or list),
        * a path to a file containing the pickled triple,
        * the pickled triple as a string.
        ``code`` is a pickled buffer object (e.g. a byte array) whose raw
        bytes are a ``marshal``-ed code object.
    :param sandbox: mapping used as the new function's globals; when it is
        not a ``Mapping`` this module's ``globals()`` is used instead.
    :return: a new ``types.FunctionType``.
    :raises ValueError: if ``s`` is of any other type.

    Security: this unpickles and unmarshals its input -- trusted data only.
    """
    if isinstance(s, (tuple, list)):
        code, closure, defaults = s
    elif isinstance(s, string_types):
        # path to file
        if os.path.isfile(s):
            with open(s, 'rb') as f:
                code, closure, defaults = cPickle.load(f)
        # pickled string
        else:
            code, closure, defaults = cPickle.loads(s)
    else:
        raise ValueError("Unsupport str_to_func for type:%s" % type(s))
    buf = cPickle.loads(code)
    # `tostring()` was removed from array.array in Python 3.9 and is deprecated
    # in NumPy; `tobytes()` is the same operation on both. Fall back for old
    # objects that only provide `tostring()`.
    raw = buf.tobytes() if hasattr(buf, 'tobytes') else buf.tostring()
    code = marshal.loads(raw)
    func = types.FunctionType(
        code=code, name=code.co_name,
        globals=sandbox if isinstance(sandbox, Mapping) else globals(),
        closure=closure, argdefs=defaults)
    return func
def test_pickling(self):
    """Pickling keeps the entity state; the restored ``_sav`` must be a weak ref."""
    original = ex.SomeObj(minlen=5)
    assert original._sav.entity.minlen == 5
    payload = pickle.dumps(original)
    del original

    restored = pickle.loads(payload)
    assert restored._sav.entity.minlen == 5

    # make sure it's a weakref: once the owner is collected, reading the
    # entity through the handle is expected to raise EntityRefMissing.
    handle = restored._sav
    del restored
    gc.collect()
    try:
        handle.entity
    except EntityRefMissing:
        pass
    else:
        assert False, 'expected exception'
def _deserialise_args(shared_objects, local_objects, serialised_args):  # pragma: no cover
    """Rebuild an argument tuple, resolving ``_SharedRef`` placeholders.

    A ``_SharedRef`` is resolved by its ``key``: a value already present in
    ``local_objects`` is reused; otherwise ``shared_objects[key]`` is decoded
    with ``loads`` and memoized into ``local_objects``. Any other argument
    is passed through unchanged.
    """
    def resolve(arg):
        if not isinstance(arg, _SharedRef):
            return arg
        key = arg.key
        if key in local_objects:
            return local_objects[key]
        value = loads(shared_objects[key])
        local_objects[key] = value
        return value

    return tuple([resolve(arg) for arg in serialised_args])  # pragma: no cover
def check_pickling(self, x_data):
    """Assert ``self.link`` gives exactly the same output after being pickled.

    Output on ``x_data`` is captured, the link is replaced by an unpickled
    copy (highest protocol), and the output is recomputed and compared with
    zero tolerance.
    """
    reference = self.link(chainer.Variable(x_data)).data

    blob = pickle.dumps(self.link, -1)
    del self.link
    self.link = pickle.loads(blob)

    recomputed = self.link(chainer.Variable(x_data)).data
    testing.assert_allclose(reference, recomputed, atol=0, rtol=0)
def test_map(self):
    """Futures from ``wrenexec.map`` still resolve after a pickle round trip."""
    def plus_one(x):
        return x + 1

    N = 10
    x = np.arange(N)
    originals = self.wrenexec.map(plus_one, x)
    restored = pickle.loads(pickle.dumps(originals))

    done_count = 0
    while done_count < N:
        finished, _pending = pywren.wait(restored)
        done_count = len(finished)

    results = np.array([fut.result() for fut in restored])
    np.testing.assert_array_equal(results, x + 1)
def from_bytes(bytes_graph, check_version=True):
    """Reads a graph from bytes (the result of pickling the graph).

    :param bytes bytes_graph: The pickled bytes of a BEL graph
    :param bool check_version: Checks if the graph was produced by this version of PyBEL
    :return: A BEL graph
    :rtype: BELGraph
    :raises: whatever ``raise_for_not_bel`` raises when the unpickled object
        is not a BEL graph, and (when ``check_version`` is true) whatever
        ``raise_for_old_graph`` raises for graphs from older versions.

    .. warning:: ``bytes_graph`` is unpickled; never pass untrusted data.
    """
    graph = loads(bytes_graph)

    raise_for_not_bel(graph)
    if check_version:
        raise_for_old_graph(graph)

    return graph
def get_remote_messages(config, queue, fill=True, block=False):
    """
    Get all messages from queue without removing from it.

    Messages are drained with ``queue.get`` and yielded after being decoded
    with ``loads``. Every successfully decoded message is put back on the
    queue (pickle serializer) when draining stops -- including when the
    consumer stops iterating early or ``queue.get`` raises something other
    than ``Empty`` -- so those messages are not lost.

    :param config: not used here; kept for interface compatibility.
    :param queue: object providing ``get(block=..., timeout=...)`` (raising
        ``Empty`` when drained) and ``put(obj, serializer=...)``.
    :param fill: when truthy, re-inject the drained messages afterwards.
    :param block: forwarded to ``queue.get``; it was previously ignored
        (always ``False``), which is still the default.
    :return: yield raw deserialized messages
    """
    to_inject = []
    try:
        while 1:
            message = queue.get(block=block, timeout=1)

            # --------------------------------------------------------------------------
            # Try to deserialize
            # --------------------------------------------------------------------------
            # Is Pickle info?
            try:
                deserialized = loads(message.body)
            except SerializationError:
                # Previously this fell through and yielded an unbound name
                # (first message) or the previous message a second time.
                # NOTE(review): undecodable messages are dropped, not re-queued.
                continue

            # Record before yielding so the message is re-injected even if the
            # consumer stops iterating right after receiving it.
            to_inject.append(deserialized)
            yield deserialized
    except Empty:
        pass
    finally:
        # When Queue is Empty (or iteration ends for any other reason)
        # -> reinject all removed messages
        if fill:
            for x in to_inject:
                queue.put(x, serializer="pickle")

# ----------------------------------------------------------------------
def _loads(s):
    """Deserialize the pickled payload *s* with ``cPickle`` and return the object."""
    obj = cPickle.loads(s)
    return obj
def test_pickle_cpu(self):
    """fs2 and its nested link parameters round-trip through pickle on the CPU."""
    restored = pickle.loads(pickle.dumps(self.fs2))
    pairs = (
        (self.fs2.b.p, restored.b.p),
        (self.fs2.fs1.a.p, restored.fs1.a.p),
    )
    for expected, actual in pairs:
        self.assertTrue((expected.data == actual.data).all())
def test_pickle_gpu(self):
    """fs2 pickled while on the GPU restores to matching data once on the CPU."""
    self.fs2.to_gpu()
    restored = pickle.loads(pickle.dumps(self.fs2))
    restored.to_cpu()
    self.fs2.to_cpu()
    param_getters = (lambda m: m.b.p, lambda m: m.fs1.a.p)
    for get_param in param_getters:
        self.assertTrue(
            (get_param(self.fs2).data == get_param(restored).data).all())
def test_pickle_cpu(self):
    """Pickling ``self.fs`` on the CPU yields an equal copy."""
    clone = pickle.loads(pickle.dumps(self.fs))
    self.check_equal_fs(self.fs, clone)
def test_pickle_gpu(self):
    """Pickling ``self.fs`` on the GPU yields an equal copy once both are on the CPU."""
    self.fs.to_gpu()
    clone = pickle.loads(pickle.dumps(self.fs))
    for fs in (self.fs, clone):
        fs.to_cpu()
    self.check_equal_fs(self.fs, clone)
def test_get(self):
    """``get`` returns the stored value for a known key and None for an unknown one."""
    # Simulate the cache's stored data: a mapping whose keys are pickled,
    # round-tripped through pickle (protocol 2) before being installed.
    test_pickle = pickle.dumps(
        {pickle.dumps(self.test_key): self.test_value}, protocol=2)
    self.test_cache.data = pickle.loads(test_pickle)

    # assertEquals is a deprecated alias that was removed in Python 3.12.
    self.assertEqual(self.test_cache.get(self.test_key), self.test_value)
    self.assertEqual(self.test_cache.get(self.bad_key), None)
def _restore_dict(self, path, read_only, cache_size):
    """Open (or create) the backing file at *path* and memory-map it.

    Existing-file layout as read below: ``MmapDict.HEADER``, then two
    fixed-width (``MmapDict.SIZE_BYTES``) ASCII integers -- the offset of the
    pickled indices dictionary and its length -- with the pickled indices
    dictionary stored at that offset.

    A missing file is created containing the header, an offset pointing just
    past the two size fields, and a zero index length.

    :param path: location of the backing file.
    :param read_only: if True, an existing zero-size file is rejected.
    :param cache_size: not used by this method.
    :raises Exception: if an existing file is empty in read-only mode or does
        not start with ``MmapDict.HEADER``.
    """
    # ====== already exist ====== #
    if os.path.exists(path):
        if os.path.getsize(path) == 0 and read_only:
            raise Exception('File at path:"%s" has zero size, no data '
                            'found in (read-only mode).' % path)
        fh = open(str(path), mode='rb+')
        try:
            if fh.read(len(MmapDict.HEADER)) != MmapDict.HEADER:
                raise Exception('Given file is not in the right format '
                                'for MmapDict.')
            # SIZE_BYTES-wide field: offset of the pickled indices dictionary
            max_position = int(fh.read(MmapDict.SIZE_BYTES))
            # SIZE_BYTES-wide field: length of the pickled indices dictionary
            dict_size = int(fh.read(MmapDict.SIZE_BYTES))
            # read dictionary
            fh.seek(max_position)
            pickled_indices = fh.read(dict_size)
            # `async` is a reserved keyword since Python 3.7, so the former
            # `async(lambda: cPickle.loads(...))()` is a SyntaxError there.
            # Load eagerly instead; the new-file branch below already stores
            # a plain dict in `_indices_dict`. A zero length (as written for a
            # brand-new file) cannot be unpickled and means "no entries".
            if dict_size > 0:
                self._indices_dict = cPickle.loads(pickled_indices)
            else:
                self._indices_dict = {}
        except Exception:
            # don't leak the handle when the file is rejected or corrupt
            fh.close()
            raise
    # ====== create new file from scratch ====== #
    else:
        fh = open(str(path), mode='wb+')
        fh.write(MmapDict.HEADER)
        # just write the header
        header = ('%' + str(MmapDict.SIZE_BYTES) + 'd') % \
            (len(MmapDict.HEADER) + MmapDict.SIZE_BYTES * 2)
        fh.write(header.encode())
        # write the length of Pickled indices dictionary
        data_size = ('%' + str(MmapDict.SIZE_BYTES) + 'd') % 0
        fh.write(data_size.encode())
        fh.flush()
        # init indices dict
        self._indices_dict = {}
    # ====== create Mmap from offset file ====== #
    self._file = fh
    self._mmap = mmap.mmap(fh.fileno(), length=0, offset=0,
                           flags=mmap.MAP_SHARED)
    self._increased_indices_size = 0.  # in MB
    # store all the (key, value) recently added
    self._cache_dict = {}
def __getitem__(self, key):
    """Return the value for *key*, preferring the in-memory cache.

    Values not in ``self._cache_dict`` are read from the mmap using the
    ``(start, size)`` span stored in ``self.indices`` and decoded with
    ``marshal``.
    """
    cache = self._cache_dict
    if key not in cache:
        # ====== load from mmap ====== #
        offset, length = self.indices[key]
        self._mmap.seek(offset)
        raw = self._mmap.read(length)
        return marshal.loads(raw)
    return cache[key]
def values(self):
    """Yield every value: mmap-backed entries first, then the cached ones.

    Mmap entries are located via ``self.indices`` ``(start, size)`` spans
    and decoded with ``marshal``.
    """
    for span in self.indices.values():
        start, size = span
        self._mmap.seek(start)
        yield marshal.loads(self._mmap.read(size))
    for cached_value in self._cache_dict.values():
        yield cached_value
def items(self):
    """Yield ``(key, value)`` pairs: mmap-backed entries, then cached ones.

    Mmap entries are located via the ``(start, size)`` spans in
    ``self.indices`` and decoded with ``marshal``; cached entries come from
    ``self._cache_dict`` as stored.
    """
    for name, (start, size) in self.indices.items():
        self._mmap.seek(start)
        yield name, marshal.loads(self._mmap.read(size))
    # Was `.values()`, which unpacked each cached *value* as a pair and so
    # raised or produced garbage; `.items()` yields the (key, value) pairs.
    for key, val in self._cache_dict.items():
        yield key, val

# ===========================================================================
# SQLiteDict
# ===========================================================================
def __getitem__(self, key):
    """Look up one key, or several keys at once, in the current table.

    :param key: a single key, or a tuple/list/``np.ndarray`` of keys. Keys are
        compared as ``str(k)``.
    :return: the value for a single key, or a list of values in the same
        order as the requested keys (duplicates allowed).
    :raises KeyError: if any requested key is missing.

    Values in ``self.current_cache`` are returned as stored; values read from
    the database are decoded with ``marshal``. Keys are bound as SQL
    parameters: they used to be pasted in double quotes, which allowed SQL
    injection and made a key such as ``key`` resolve to the column name.
    """
    cache = self.current_cache
    # ====== multiple keys select ====== #
    if isinstance(key, (tuple, list, np.ndarray)):
        keys = [str(k) for k in key]
        # unique, order-preserving list of keys that must come from the DB
        pending = [k for k in dict.fromkeys(keys) if k not in cache]
        fetched = {}
        # chunk to stay below SQLite's limit on bound variables per statement
        for i in range(0, len(pending), 500):
            chunk = pending[i:i + 500]
            # the table name cannot be bound as a parameter
            query = """SELECT key, value FROM {tb} WHERE key IN ({marks});""".format(
                tb=self._current_table, marks=', '.join(['?'] * len(chunk)))
            self.cursor.execute(query, chunk)
            fetched.update((row[0], row[1]) for row in self.cursor.fetchall())
        absent = [k for k in pending if k not in fetched]
        if absent:
            keyval = '(' + ', '.join(['"%s"' % k for k in absent]) + ')'
            raise KeyError("Cannot find all `key`='%s' in the dictionary." %
                           keyval)
        # load binary data, in the order requested
        return [cache[k] if k in cache else marshal.loads(fetched[k])
                for k in keys]
    # ====== single key select ====== #
    key = str(key)
    if key in cache:
        return cache[key]
    query = """SELECT value FROM {tb} WHERE key=? LIMIT 1;""".format(
        tb=self._current_table)
    row = self.connection.execute(query, (key,)).fetchone()
    if row is None:
        raise KeyError("Cannot find `key`='%s' in the dictionary." % key)
    return marshal.loads(row[0])
def items(self):
    """Yield ``(key, value)`` for each row of the current table, then for each
    entry of ``self.current_cache``. Table values are ``marshal``-decoded."""
    rows = self.cursor.execute(
        """SELECT key, value from {tb};""".format(tb=self._current_table))
    for row in rows:
        stored_key, blob = row[0], row[1]
        yield (stored_key, marshal.loads(blob))
    for cached_key, cached_value in self.current_cache.items():
        yield cached_key, cached_value
def test_seq(self):
    """A layer Sequence gives the same forward/transposed results after cPickle."""
    X = K.placeholder((None, 28, 28, 1))
    f = N.Sequence([
        N.Conv(8, (3, 3), strides=1, pad='same'),
        N.Dimshuffle(pattern=(0, 3, 1, 2)),
        N.Flatten(outdim=2),
        N.Noise(level=0.3, noise_dims=None, noise_type='gaussian'),
        N.Dense(128, activation=tf.nn.relu),
        N.Dropout(level=0.3, noise_dims=None),
        N.Dense(10, activation=tf.nn.softmax)
    ])
    y = f(X)
    yT = f.T(y)
    f1 = K.function(X, y, defaults={K.is_training(): True})
    f2 = K.function(X, yT, defaults={K.is_training(): False})

    f = cPickle.loads(cPickle.dumps(f))
    y = f(X)
    yT = f.T(y)
    f3 = K.function(X, y, defaults={K.is_training(): True})
    f4 = K.function(X, yT, defaults={K.is_training(): False})

    x = np.random.rand(12, 28, 28, 1)
    # assertEquals is a deprecated alias that was removed in Python 3.12.
    self.assertEqual(f1(x).shape, (2688, 10))
    self.assertEqual(f3(x).shape, (2688, 10))
    self.assertEqual(np.round(f1(x).sum(), 4), np.round(f3(x).sum(), 4))
    # TensorShape.as_list() returns a list, so comparing it with a tuple
    # could never succeed; compare with a list instead.
    self.assertEqual(y.get_shape().as_list(), [None, 10])

    self.assertEqual(f2(x).shape, (12, 28, 28, 1))
    self.assertEqual(f4(x).shape, (12, 28, 28, 1))
    self.assertEqual(str(f2(x).sum())[:4], str(f4(x).sum())[:4])
    self.assertEqual(yT.get_shape().as_list(), [None, 28, 28, 1])
def test_simple_rnn(self):
    """RNN layer behaviour:
    - its output does not depend on which placeholder it is applied to;
    - it survives pickling;
    - it rejects an input with the wrong feature size.
    """
    np.random.seed(12082518)
    x = np.random.rand(128, 8, 32)
    # ====== placeholders ====== #
    X = K.placeholder(shape=(None, 8, 32))
    X1 = K.placeholder(shape=(None, 8, 32))
    X2 = K.placeholder(shape=(None, 8, 32))
    X3 = K.placeholder(shape=(None, 8, 33))
    f = N.RNN(32, activation=K.relu, input_mode='skip')
    # ====== reference output (all-ones mask) ====== #
    y = f(X, mask=K.ones(shape=(128, 8)))
    graph = K.ComputationGraph(y)
    self.assertEqual(len(graph.inputs), 1)
    f1 = K.function([X], y)
    x1 = f1(x)
    # ====== different placeholder ====== #
    y = f(X1)
    f2 = K.function([X1], y)
    # Evaluate the function built on X1; this used to call f1 again, so the
    # new placeholder was never exercised.
    x2 = f2(x)
    self.assertEqual(np.sum(x1[0] == x2[0]), np.prod(x1[0].shape))
    # ====== pickle load ====== #
    f = cPickle.loads(cPickle.dumps(f))
    y = f(X2)
    f3 = K.function([X2], y)
    x3 = f3(x)
    self.assertEqual(np.sum(x2[0] == x3[0]), np.prod(x2[0].shape))
    # ====== other input shape ====== #
    # 33 input features instead of 32 must be rejected at some stage.
    with self.assertRaises(Exception):
        y = f(X3)
        f4 = K.function([X3], y)
        f4(np.random.rand(128, 8, 33))
def __setstate__(self, states):
    """Restore from the pickled ``(sandbox, source, argsmap)`` triple.

    A string sandbox is itself a pickled function; any other sandbox is
    rebuilt with ``_deserialize_function_sandbox``, whose first result is
    the function. A missing function is an error.
    """
    self._sandbox, self._source, self._argsmap = states
    # ====== deserialize the function ====== #
    if not isinstance(self._sandbox, string_types):
        self._function, _restored_env = _deserialize_function_sandbox(self._sandbox)
    else:
        self._function = cPickle.loads(self._sandbox)
    if self._function is None:
        raise RuntimeError('[funtionable] Cannot find function in sandbox.')

# ==================== properties ==================== #
def _read_meta_data(data_dir_path, metadata_file_path, max_number_length, rand_bbox_count):
    """Load image metadata from a ``.mat`` annotation file or a ``.pickle`` cache.

    :return: for ``.pickle`` files, ``(filenames, labels, bboxes, sep_bboxes)``
        taken from the pickled dict; for ``.mat`` files, whatever
        ``parse_data`` returns.
    :raises ValueError: for any other extension (previously returned None,
        which only failed later when the caller unpacked it).

    Security: ``.pickle`` files are unpickled -- only load trusted metadata.
    """
    if metadata_file_path.endswith('.mat'):
        return parse_data(metadata_file_path, max_number_length, data_dir_path, rand_bbox_count)
    if metadata_file_path.endswith('.pickle'):
        # Close the file instead of leaking the handle.
        with open(metadata_file_path, 'rb') as f:
            metadata = pickle.load(f)
        return metadata['filenames'], metadata['labels'], metadata['bboxes'], metadata['sep_bboxes']
    raise ValueError('Unsupported metadata file (expected .mat or .pickle): %s'
                     % metadata_file_path)
def test_combined_infer(self):
    """Smoke test: run the bbox model, then the number model on cropped images.

    Results are printed for inspection; nothing is asserted.
    """
    from nsrec.nets import iclr_mnr, lenet_v2
    from six.moves import cPickle as pickle
    # Close the metadata file instead of leaking the handle.
    with open(test_helper.train_data_dir_path + '/metadata.pickle', 'rb') as f:
        metadata = pickle.loads(f.read())

    def test_img_data_generator(new_size, crop_bbox=False):
        # Yields (normalized image, (width, height) of the uncropped image,
        # metadata bbox, label) for the first ten numbered test images.
        for i in range(10):
            filename = '%s.png' % (i + 1)
            img_idx = metadata['filenames'].index(filename)
            bbox, label = metadata['bboxes'][img_idx], metadata['labels'][img_idx]
            input_data = inputs.read_img(os.path.join(test_helper.train_data_dir_path, filename))
            width, height = input_data.shape[1], input_data.shape[0]
            if crop_bbox:
                input_data = inputs.read_img(os.path.join(test_helper.train_data_dir_path, filename), bbox)
            input_data = inputs.normalize_img(input_data, [new_size[0], new_size[1]])
            yield (input_data, (width, height), bbox, label)

    bbox_model = Inferrable(test_helper.output_bbox_graph_file, 'initializer-bbox', 'input-bbox', 'output-bbox')
    for input_data, (width, height), bbox, _ in test_img_data_generator(
            [lenet_v2.image_width, lenet_v2.image_height]):
        bbox_in_rate = bbox_model.infer(np.array([input_data]))
        print(width, height)
        print('label bbox: %s, bbox: %s' % (bbox, [bbox_in_rate[0] * width, bbox_in_rate[1] * height,
                                                   bbox_in_rate[2] * width, bbox_in_rate[3] * height]))

    nsr_model = Inferrable(test_helper.output_graph_file, 'initializer', 'input', 'output')
    for input_data, _, _, label in test_img_data_generator(
            [iclr_mnr.image_width, iclr_mnr.image_height], True):
        pb = nsr_model.infer(np.array([input_data]))
        print('actual: %s, length pb: %s, numbers: %s' % (
            label, np.argmax(pb[:5]), np.argmax(pb[5:].reshape([5, 11]), axis=1)))
def pickle_loads(s):
    """Unpickle *s*; under Python 3, decode legacy 8-bit strings as ISO-8859-1."""
    extra = {'encoding': 'iso-8859-1'} if six.PY3 else {}
    return pickle.loads(s, **extra)
def test_pickle_unpickle_with_reoptimization():
    """A compiled function unpickled with re-optimization enabled gives the same result."""
    mode = theano.config.mode
    if mode in ["DEBUG_MODE", "DebugMode"]:
        # DebugMode functions cannot be pickled.
        mode = "FAST_RUN"
    x1 = T.fmatrix('x1')
    x2 = T.fmatrix('x2')
    x3 = theano.shared(numpy.ones((10, 10), dtype=floatX))
    x4 = theano.shared(numpy.ones((10, 10), dtype=floatX))
    inner = T.sum(x1 ** 2 + x2) + x3
    y = T.sum(T.sum(inner) + x4)
    updates = OrderedDict([(x3, x3 + 1), (x4, x4 + 1)])
    f = theano.function([x1, x2], y, updates=updates, mode=mode)

    # now pickle the compiled theano fn
    string_pkl = pickle.dumps(f, -1)
    in1, in2 = (numpy.ones((10, 10), dtype=floatX) for _ in range(2))

    saved_flag = theano.config.reoptimize_unpickled_function
    try:
        # the default is True
        theano.config.reoptimize_unpickled_function = True
        f_ = pickle.loads(string_pkl)
        assert f(in1, in2) == f_(in1, in2)
    finally:
        theano.config.reoptimize_unpickled_function = saved_flag
def test_pickle_unpickle_without_reoptimization():
    """A compiled function unpickled without re-optimization gives the same result."""
    mode = theano.config.mode
    if mode in ["DEBUG_MODE", "DebugMode"]:
        # DebugMode functions cannot be pickled.
        mode = "FAST_RUN"
    x1, x2 = T.fmatrix('x1'), T.fmatrix('x2')
    x3 = theano.shared(numpy.ones((10, 10), dtype=floatX))
    x4 = theano.shared(numpy.ones((10, 10), dtype=floatX))
    y = T.sum(T.sum(T.sum(x1**2 + x2) + x3) + x4)
    updates = OrderedDict()
    for shared_var in (x3, x4):
        updates[shared_var] = shared_var + 1
    f = theano.function([x1, x2], y, updates=updates, mode=mode)

    # now pickle the compiled theano fn
    string_pkl = pickle.dumps(f, -1)
    # compute f value
    in1 = numpy.ones((10, 10), dtype=floatX)
    in2 = numpy.ones((10, 10), dtype=floatX)

    previous = theano.config.reoptimize_unpickled_function
    try:
        # the default is True; disable re-optimization for this check
        theano.config.reoptimize_unpickled_function = False
        f_ = pickle.loads(string_pkl)
        assert f(in1, in2) == f_(in1, in2)
    finally:
        theano.config.reoptimize_unpickled_function = previous
def test_pickle_bug(self):
    """Regression test for bug fixed in 24d4fd291054: an unpickled ``Prod``
    op must itself be picklable again."""
    restored = pickle.loads(pickle.dumps(Prod(), protocol=-1))
    pickle.dumps(restored)
def test_none_Constant():
    """
    Tests equals

    We had an error in the past with unpickling: ``equals`` between NoneConst
    constants must hold in both directions, and a compiled function using
    argmax (which exercised the faulty ``equals``) must pickle and unpickle.
    """
    o1 = Constant(NoneTypeT(), None, name='NoneConst')
    o2 = Constant(NoneTypeT(), None, name='NoneConst')
    for left, right in ((o1, o2), (NoneConst, o1), (o1, NoneConst),
                        (NoneConst, o2), (o2, NoneConst)):
        assert left.equals(right)

    # This trigger equals that returned the wrong answer in the past.
    import six.moves.cPickle as pickle
    import theano
    from theano import tensor
    x = tensor.vector('x')
    y = tensor.argmax(x)
    # We can't pickle DebugMode
    if theano.config.mode in ["DebugMode", "DEBUG_MODE"]:
        kwargs = {'mode': 'FAST_RUN'}
    else:
        kwargs = {}
    f = theano.function([x], [y], **kwargs)
    pickle.loads(pickle.dumps(f))
def test_pickle(self):
    """Pickled functions get fresh storage while keeping equal default values."""
    a = T.scalar()  # the a is for 'anonymous' (un-named).
    x, s = T.scalars('xs')

    f = function([x, In(a, value=1.0, name='a'),
                  In(s, value=0.0, update=s + a * x, mutable=True)], s + a * x)

    try:
        # Note that here we also test protocol 0 on purpose, since it
        # should work (even though one should not use it).
        g = pickle.loads(pickle.dumps(f, protocol=0))
        g = pickle.loads(pickle.dumps(f, protocol=-1))
    except NotImplementedError as e:
        # Exceptions are not indexable on Python 3 (`e[0]` raised TypeError
        # and masked the real error); read the message via str(e).
        if str(e).startswith('DebugMode is not picklable'):
            return
        raise
    # if they both return, assume that they return equivalent things.

    self.assertFalse(g.container[0].storage is f.container[0].storage)
    self.assertFalse(g.container[1].storage is f.container[1].storage)
    self.assertFalse(g.container[2].storage is f.container[2].storage)
    self.assertFalse(x in g.container)
    self.assertFalse(x in g.value)

    self.assertFalse(g.value[1] is f.value[1])  # should not have been copied
    self.assertFalse(g.value[2] is f.value[2])  # should have been copied because it is mutable.
    self.assertFalse((g.value[2] != f.value[2]).any())  # its contents should be identical

    self.assertTrue(f(2, 1) == g(2))  # they should be in sync, default value should be copied.
    self.assertTrue(f(2, 1) == g(2))  # they should be in sync, default value should be copied.
    f(1, 2)  # put them out of sync
    self.assertFalse(f(1, 2) == g(1, 2))  # they should not be equal anymore.
def test_consistent_inner_fct(self):
    """Scan must not falsely detect inconsistencies in a valid inner graph.

    The scan output must be picklable, and after compilation with the GPU
    mode exactly one Scan node must remain and it must run on the GPU.
    """
    rs = theano.sandbox.rng_mrg.MRG_RandomStreams(use_cuda=True)
    output, _ = theano.scan(lambda : rs.uniform((3,), dtype="float32"),
                            n_steps=3)
    pickle.loads(pickle.dumps(output))

    # Also ensure that, after compilation, the Scan has been moved
    # on the gpu
    compiled = theano.function([], output, mode=self.mode_with_gpu)
    nodes = scan_nodes_from_fct(compiled)
    assert len(nodes) == 1
    assert self.is_scan_on_gpu(nodes[0])
def test_pickle(self):
    """A FileWrapper survives pickling with its file path intact."""
    self.test_file_name_property()
    name = "file"
    file1 = os.path.join(self.tmp_dir, name)
    wrap = FileWrapper(file1)
    pickled_data = pickle.dumps(wrap)
    wrap2 = pickle.loads(pickled_data)
    # The test previously only printed the path and so could never fail.
    self.assertEqual(wrap2.file_path, wrap.file_path)
def test_pickle(self):
    """An RpmWrapper keeps its package name across a pickle round trip.

    Skipped for rpm < 4.10, where RPM header pickling is unsupported.
    """
    version_string = getattr(rpm, '__version__', '0.0')
    # Only major/minor are compared; later components (e.g. release
    # suffixes) need not be numeric, so don't try to parse them.
    rpm_version = [int(v) for v in version_string.split('.')[:2]]
    if rpm_version < [4, 10]:
        # Report as skipped instead of warning and silently passing.
        self.skipTest('RPM header pickling unsupported in rpm %s' % version_string)
    wrap = RpmWrapper(self.file_path)
    pickled_data = pickle.dumps(wrap)
    wrap2 = pickle.loads(pickled_data)
    self.assertEqual(wrap.name, wrap2.name)
def test_pickle(self):
    """A SimpleRpmWrapper keeps its name across a pickle round trip."""
    original = SimpleRpmWrapper(self.file_path)
    restored = pickle.loads(pickle.dumps(original))
    self.assertEqual(original.name, restored.name)
def _deserialize_function_sandbox(sandbox):
    '''Rebuild a function and its environment from a serialized sandbox.

    sandbox : dictionary created by `serialize_sandbox`, mapping each name
        to a ``(typ, val)`` pair. ``typ`` is either a string tag or a value
        whose type is one of ``_primitives`` (then ``val`` is kept as is).

    Returns ``(main_func, environment)``:
    - ``environment`` maps every name to its restored value;
    - ``main_func`` is the imported or defined function whose tag contains
      ``'_main'``, or None if there is none.

    Raises ImportError if an edward distribution is needed but edward is not
    installed, and ValueError for an unsupported ``typ``.
    '''
    import importlib
    environment = {}
    defined_function = []
    main_func = None
    # first pass: deserialize every entry; functions defined in the sandbox
    # are rebuilt here but their globals are wired up in the second pass
    for name, (typ, val) in sandbox.items():
        if isinstance(typ, string_types):
            if typ == 'None':
                val = None
            elif typ == 'edward_distribution':
                try:
                    import edward
                    val = getattr(edward.models, val)
                except ImportError:
                    raise ImportError("Cannot import 'edward' library to deserialize "
                                      "the function.")
            elif typ == 'function_type':
                val = types.FunctionType
            elif typ == 'Mapping':
                val = cPickle.loads(val)
            elif typ == 'ndarray':
                # np.fromstring on binary data is deprecated; frombuffer is
                # the replacement. copy() because frombuffer returns a
                # read-only view of the bytes, whereas fromstring returned a
                # new writable array.
                val = np.frombuffer(val[0], dtype=val[1]).copy()
            elif typ == 'module':
                val = importlib.import_module(val)
            elif 'imported_function' in typ:
                # `in` (was `==`): with an exact match the '_main' check below
                # could never be true, so a main imported function was lost.
                val = getattr(importlib.import_module(val[1]), val[0])
                if '_main' in typ:
                    main_func = val
            elif 'defined_function' in typ:
                val = str_to_func(val, globals())
                if '_main' in typ:
                    main_func = val
                defined_function.append(name)
        elif builtins.any(isinstance(typ, i) for i in _primitives):
            pass
        else:
            raise ValueError('Unsupport deserializing type: {}, '
                             'value: {}'.format(typ, val))
        environment[name] = val
    # ====== create all defined function ====== #
    # second pass: make every defined function see the whole restored
    # environment through its globals
    for name in defined_function:
        func = environment[name]
        func.__globals__.update(environment)
    return main_func, environment
def inference(label_fn, bboxes=False, flags=None):
    """
    Used to infer against a nsr bbox model or a nsr model.

    Args:
        label_fn: Accept parameters (label, metadata_bbox), and return the real label
        bboxes: Use metadata bbox if False. Do not crop image if None.
            Otherwise indexed per input file (``bboxes[i]`` is the bbox used
            to read the i-th file).
        flags: Settings object; falls back to the module-level ``_FLAGS``.
            Read here: ``checkpoint_dir``, ``input_files`` (comma-separated
            file names), ``metadata_file_path``, ``data_dir_path`` and
            ``bbox_expand``; ``create_model`` may read more.

    Returns:
        The labels produced by ``model.infer``, or None when no checkpoint
        is found in ``flags.checkpoint_dir``.

    Note: the metadata file is unpickled -- only use trusted files.
    """
    flags = flags or _FLAGS
    # Build the inference graph.
    g = tf.Graph()
    with g.as_default(), tf.device('/cpu:0'):
        model = create_model(flags, 'inference')
        model.build()
        saver = tf.train.Saver()
    g.finalize()

    model_path = tf.train.latest_checkpoint(flags.checkpoint_dir)
    if not model_path:
        tf.logging.info("Skipping inference. No checkpoint found in: %s",
                        flags.checkpoint_dir)
        return

    with tf.Session(graph=g) as sess:
        # Load the model from checkpoint.
        tf.logging.info("Loading model from checkpoint: %s", flags.checkpoint_dir)
        saver.restore(sess, model_path)

        files = [s.strip() for s in flags.input_files.split(',')]
        # Close the metadata file instead of leaking the handle.
        with open(flags.metadata_file_path, 'rb') as f:
            metadata = pickle.loads(f.read())
        real_labels = []
        sep_bboxes = []
        file_paths = [os.path.join(flags.data_dir_path, f) for f in files]
        data = []
        for i, f in enumerate(files):
            metadata_idx = metadata['filenames'].index(f)
            label, metadata_bbox = metadata['labels'][metadata_idx], metadata['bboxes'][metadata_idx]
            sep_bboxes.append(metadata['sep_bboxes'][metadata_idx])
            real_labels.append(label_fn(label, metadata_bbox))
            if bboxes is False:
                bbox = metadata_bbox
            elif bboxes is None:
                bbox = None
            else:
                bbox = bboxes[i]
            data.append(inputs.read_img(file_paths[i], bbox, flags.bbox_expand))

        labels = model.infer(sess, data)
        for i in range(len(files)):
            tf.logging.info('inferred image %s(%s, %s): %s',
                            files[i], real_labels[i], sep_bboxes[i], labels[i])
        correct_count = sum(1 for i in range(len(files))
                            if real_labels[i] == labels[i][0])
        # float() so the rate is not truncated by integer division on Python 2
        tf.logging.info('correct count: %s, rate: %.4f', correct_count,
                        correct_count / float(len(files)))
        return labels