We extracted the following 43 code examples from open-source Python projects to illustrate how to use tensorflow.python.framework.ops.get_default_graph().
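Before diving into the extracted examples, here is a minimal, self-contained sketch of the call being illustrated. It only assumes a TensorFlow 1.x installation; the op names (`a`, `b`, `c`) are made up for illustration.

```python
# Minimal sketch (TensorFlow 1.x): every op created without an explicit
# graph context lands in the graph returned by ops.get_default_graph().
import tensorflow as tf
from tensorflow.python.framework import ops

a = tf.constant(1.0, name='a')
b = tf.constant(2.0, name='b')
c = tf.add(a, b, name='c')

graph = ops.get_default_graph()
assert c.graph is graph                       # ops attach to the default graph
print(graph.get_tensor_by_name('c:0'))        # tensors can be looked up by name
print([op.name for op in graph.get_operations()])  # ['a', 'b', 'c']
```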
def get_uid(prefix=''):
  """Associates a string prefix with an integer counter in a TensorFlow graph.

  Arguments:
      prefix: String prefix to index.

  Returns:
      Unique integer ID.

  Example:

  ```
    >>> get_uid('dense')
    1
    >>> get_uid('dense')
    2
  ```
  """
  graph = ops.get_default_graph()
  layer_name_uids = tf_base_layers.PER_GRAPH_LAYER_NAME_UIDS[graph]
  layer_name_uids[prefix] += 1
  return layer_name_uids[prefix]
def create_global_step(graph=None):
  """Create global step tensor in graph.

  Args:
    graph: The graph in which to create the global step. If missing, use
      default graph.

  Returns:
    Global step tensor.

  Raises:
    ValueError: if global step key is already defined.
  """
  graph = ops.get_default_graph() if graph is None else graph
  if get_global_step(graph) is not None:
    raise ValueError('"global_step" already exists.')
  # Create in proper graph and base name_scope.
  with graph.as_default() as g, g.name_scope(None):
    collections = [ops.GraphKeys.VARIABLES, ops.GraphKeys.GLOBAL_STEP]
    return variable(ops.GraphKeys.GLOBAL_STEP, shape=[], dtype=dtypes.int64,
                    initializer=init_ops.zeros_initializer, trainable=False,
                    collections=collections)
def _as_graph_element(obj):
  """Retrieves Graph element."""
  graph = ops.get_default_graph()
  if not isinstance(obj, six.string_types):
    if not hasattr(obj, "graph") or obj.graph != graph:
      raise ValueError("Passed %s should have graph attribute that is equal "
                       "to current graph %s." % (obj, graph))
    return obj
  if ":" in obj:
    element = graph.as_graph_element(obj)
  else:
    element = graph.as_graph_element(obj + ":0")
    # Check that there is no :1 (e.g. it's single output).
    try:
      graph.as_graph_element(obj + ":1")
    except (KeyError, ValueError):
      pass
    else:
      raise ValueError("Name %s is ambiguous, "
                       "as this `Operation` has multiple outputs "
                       "(at least 2)." % obj)
  return element
def get(logdir):
  """Returns the SummaryWriter for the specified directory.

  Args:
    logdir: str, name of the directory.

  Returns:
    A `SummaryWriter`.
  """
  with SummaryWriterCache._lock:
    if logdir not in SummaryWriterCache._cache:
      SummaryWriterCache._cache[logdir] = summary_io.SummaryWriter(
          logdir, graph=ops.get_default_graph())
    return SummaryWriterCache._cache[logdir]

# Backward compatible interface.  Remove?
def create_global_step(graph=None):
  """Create global step tensor in graph.

  Args:
    graph: The graph in which to create the global step. If missing, use
      default graph.

  Returns:
    Global step tensor.

  Raises:
    ValueError: if global step key is already defined.
  """
  graph = ops.get_default_graph() if graph is None else graph
  if get_global_step(graph) is not None:
    raise ValueError('"global_step" already exists.')
  # Create in proper graph and base name_scope.
  with graph.as_default() as g, g.name_scope(None):
    collections = [ops.GraphKeys.GLOBAL_VARIABLES, ops.GraphKeys.GLOBAL_STEP]
    return variable(ops.GraphKeys.GLOBAL_STEP, shape=[], dtype=dtypes.int64,
                    initializer=init_ops.zeros_initializer, trainable=False,
                    collections=collections)
def get_or_create_eval_step():
  """Gets or creates the eval step `Tensor`.

  Returns:
    A `Tensor` representing a counter for the evaluation step.

  Raises:
    ValueError: If multiple `Tensors` have been added to the
      `tf.GraphKeys.EVAL_STEP` collection.
  """
  graph = ops.get_default_graph()
  eval_steps = graph.get_collection(ops.GraphKeys.EVAL_STEP)
  if len(eval_steps) == 1:
    return eval_steps[0]
  elif len(eval_steps) > 1:
    raise ValueError('Multiple tensors added to tf.GraphKeys.EVAL_STEP')
  else:
    counter = variables.local_variable(0.0, name='eval_step')
    graph.add_to_collection(ops.GraphKeys.EVAL_STEP, counter)
    return counter
def _get_train_ops(self, features, labels):
  """See base class."""
  features = self._get_feature_dict(features)
  features, labels = self._feature_engineering_fn(features, labels)
  logits = self._logits(features, is_training=True)

  def _make_training_op(training_loss):
    global_step = contrib_variables.get_global_step()
    assert global_step

    linear_train_step = self._linear_model.get_train_step(training_loss)
    dnn_train_step = (self._dnn_model.get_train_step(training_loss)
                      if self._dnn_model else [])

    with ops.control_dependencies(linear_train_step + dnn_train_step):
      with ops.get_default_graph().colocate_with(global_step):
        return state_ops.assign_add(global_step, 1).op

  return self._head.head_ops(features, labels,
                             model_fn.ModeKeys.TRAIN,
                             _make_training_op,
                             logits=logits)
def export_meta_graph(self, filename=None, collection_list=None,
                      as_text=False):
  """Writes `MetaGraphDef` to save_path/filename.

  Args:
    filename: Optional meta_graph filename including the path.
    collection_list: List of string keys to collect.
    as_text: If `True`, writes the meta_graph as an ASCII proto.

  Returns:
    A `MetaGraphDef` proto.
  """
  return export_meta_graph(filename=filename,
                           graph_def=ops.get_default_graph().as_graph_def(),
                           saver_def=self.saver_def,
                           collection_list=collection_list,
                           as_text=as_text)
def _assert_summaries(self, output_dir, writer, expected_summaries=None,
                      expected_graphs=None, expected_meta_graphs=None,
                      expected_session_logs=None):
  self.assertTrue(isinstance(writer, testing.FakeSummaryWriter))
  writer.assert_summaries(
      self,
      expected_logdir=output_dir,
      expected_graph=ops.get_default_graph(),
      expected_summaries=expected_summaries,
      expected_added_graphs=expected_graphs,
      expected_added_meta_graphs=expected_meta_graphs,
      expected_session_logs=expected_session_logs)

  # TODO(ptucker): Test number and contents of checkpoint files.
def testCustomConfig(self):
  test_random_seed = 5783452

  class TestInput(object):

    def __init__(self):
      self.random_seed = 0

    def config_test_input_fn(self):
      self.random_seed = ops.get_default_graph().seed
      return constant_op.constant([[1.]]), constant_op.constant([1.])

  config = run_config.RunConfig(tf_random_seed=test_random_seed)
  test_input = TestInput()
  est = estimator.Estimator(model_fn=linear_model_fn, config=config)
  est.fit(input_fn=test_input.config_test_input_fn, steps=1)
  # If input_fn ran, it will have given us the random seed set on the graph.
  self.assertEquals(test_random_seed, test_input.random_seed)
def __init__(self, fetches, contraction_fn):
  """Creates an _ElementFetchMapper.

  This is the fetch mapper used for leaves in the fetch struct.  Because of
  the expansions mechanism, a leaf can actually fetch more than one tensor.

  Also note that the fetches here can be just strings (tensor or op names) or
  any other object that the graph knows how to convert to a tensor, such as a
  Variable.  So we have to run each fetch through `as_graph_element()` to get
  the corresponding tensor or op.

  Args:
    fetches: List of objects, as returned by a fetch_fn defined
      in _REGISTERED_EXPANSIONS.
    contraction_fn: Callable as returned by a fetch_fn.
  """
  self._unique_fetches = []
  for fetch in fetches:
    try:
      self._unique_fetches.append(ops.get_default_graph().as_graph_element(
          fetch, allow_tensor=True, allow_operation=True))
    except TypeError as e:
      raise TypeError('Fetch argument %r has invalid type %r, '
                      'must be a string or Tensor. (%s)'
                      % (fetch, type(fetch), str(e)))
    except ValueError as e:
      raise ValueError('Fetch argument %r cannot be interpreted as a '
                       'Tensor. (%s)' % (fetch, str(e)))
    except KeyError as e:
      raise ValueError('Fetch argument %r cannot be interpreted as a '
                       'Tensor. (%s)' % (fetch, str(e)))
  self._contraction_fn = contraction_fn
def clear_session():
  """Destroys the current TF graph and creates a new one.

  Useful to avoid clutter from old models / layers.
  """
  global _SESSION
  global _GRAPH_LEARNING_PHASES  # pylint: disable=global-variable-not-assigned
  ops.reset_default_graph()
  reset_uids()
  _SESSION = None
  phase = array_ops.placeholder(dtype='bool', name='keras_learning_phase')
  _GRAPH_LEARNING_PHASES = {}
  _GRAPH_LEARNING_PHASES[ops.get_default_graph()] = phase
def learning_phase():
  """Returns the learning phase flag.

  The learning phase flag is a bool tensor (0 = test, 1 = train)
  to be passed as input to any Keras function
  that uses a different behavior at train time and test time.

  Returns:
      Learning phase (scalar integer tensor or Python integer).
  """
  graph = ops.get_default_graph()
  if graph not in _GRAPH_LEARNING_PHASES:
    phase = array_ops.placeholder(dtype='bool', name='keras_learning_phase')
    _GRAPH_LEARNING_PHASES[graph] = phase
  return _GRAPH_LEARNING_PHASES[graph]
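The snippet above keys one learning-phase placeholder per graph, using the graph object itself as the dictionary key. A minimal sketch of the same pattern outside Keras; the names `_GRAPH_LEARNING_PHASES` and `my_learning_phase` are ours, not the library's:

```python
import tensorflow as tf
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops

_GRAPH_LEARNING_PHASES = {}  # graph -> bool placeholder, as in the snippet above

def my_learning_phase():
  graph = ops.get_default_graph()
  if graph not in _GRAPH_LEARNING_PHASES:
    _GRAPH_LEARNING_PHASES[graph] = array_ops.placeholder(
        dtype='bool', name='my_learning_phase')
  return _GRAPH_LEARNING_PHASES[graph]

phase = my_learning_phase()
# Branch on the flag; feed True for training, False for inference.
out = tf.cond(phase, lambda: tf.constant('train'), lambda: tf.constant('test'))
with tf.Session() as sess:
  print(sess.run(out, feed_dict={phase: True}))   # b'train'
  print(sess.run(out, feed_dict={phase: False}))  # b'test'
```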
def __init__(self, layers=None, name=None):
  self.layers = []  # Stack of layers.
  self.model = None  # Internal Model instance.
  self.inputs = []  # List of input tensors
  self.outputs = []  # List of length 1: the output tensor (unique).
  self._trainable = True
  self._initial_weights = None

  # Model attributes.
  self.inbound_nodes = []
  self.outbound_nodes = []
  self.built = False

  # Set model name.
  if not name:
    prefix = 'sequential_'
    name = prefix + str(K.get_uid(prefix))
  self.name = name

  # The following properties are not actually used by Keras;
  # they exist for compatibility with TF's variable scoping mechanism.
  self._updates = []
  self._scope = None
  self._reuse = None
  self._base_name = name
  self._graph = ops.get_default_graph()

  # Add to the model any layers passed to the constructor.
  if layers:
    for layer in layers:
      self.add(layer)
def get_global_step(graph=None):
  """Get the global step tensor.

  The global step tensor must be an integer variable. We first try to find it
  in the collection `GLOBAL_STEP`, or by name `global_step:0`.

  Args:
    graph: The graph to find the global step in. If missing, use default graph.

  Returns:
    The global step variable, or `None` if none was found.

  Raises:
    TypeError: If the global step tensor has a non-integer type, or if it is
      not a `Variable`.
  """
  graph = ops.get_default_graph() if graph is None else graph
  global_step_tensor = None
  global_step_tensors = graph.get_collection(ops.GraphKeys.GLOBAL_STEP)
  if len(global_step_tensors) == 1:
    global_step_tensor = global_step_tensors[0]
  elif not global_step_tensors:
    try:
      global_step_tensor = graph.get_tensor_by_name('global_step:0')
    except KeyError:
      return None
  else:
    logging.error('Multiple tensors in global_step collection.')
    return None

  assert_global_step(global_step_tensor)
  return global_step_tensor
def get_or_create_global_step(graph=None):
  """Returns and create (if necessary) the global step variable.

  Args:
    graph: The graph in which to create the global step. If missing, use
      default graph.

  Returns:
    the tensor representing the global step variable.
  """
  graph = ops.get_default_graph() if graph is None else graph
  globalstep = get_global_step(graph)
  if globalstep is None:
    globalstep = create_global_step(graph)
  return globalstep
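A short usage sketch of the helper above in a training setup. Depending on the TensorFlow 1.x release, the public wrapper lives under tf.train or tf.contrib.framework; tf.train.get_or_create_global_step is assumed here, and the variable name `w` is arbitrary:

```python
import tensorflow as tf
from tensorflow.python.framework import ops

global_step = tf.train.get_or_create_global_step()

loss = tf.reduce_mean(tf.square(tf.get_variable('w', shape=[3])))
train_op = tf.train.GradientDescentOptimizer(0.1).minimize(
    loss, global_step=global_step)  # increments the step on every update

# The variable is registered in the default graph's GLOBAL_STEP collection,
# which is how get_global_step() above finds it later.
assert ops.get_default_graph().get_collection(ops.GraphKeys.GLOBAL_STEP)
```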
def begin(self, max_steps=None):
  super(GraphDump, self).begin(max_steps=max_steps)
  self._tensors = []
  graph = ops.get_default_graph()
  graph_def = graph.as_graph_def()
  for node in graph_def.node:
    if node.op in self._ignore_ops:
      continue
    logging.info("op=%s name=%s.", node.op, node.name)
    try:
      self._tensors.append(graph.get_tensor_by_name(node.name + ":0"))
    except KeyError:
      pass
def _get_train_ops(self, features, targets):
  """See base class."""
  global_step = contrib_variables.get_global_step()
  assert global_step

  features = self._get_feature_dict(features)
  logits = self._logits(features, is_training=True)
  if self._enable_centered_bias:
    centered_bias_step = [self._centered_bias_step(targets, features)]
  else:
    centered_bias_step = []
  with ops.control_dependencies(centered_bias_step):
    training_loss = self._target_column.training_loss(logits, targets,
                                                      features)
    weighted_average_loss = self._target_column.loss(logits, targets,
                                                     features)

  logging_ops.scalar_summary("loss", weighted_average_loss)

  linear_train_step = self._linear_model.get_train_step(training_loss)
  dnn_train_step = (self._dnn_model.get_train_step(training_loss)
                    if self._dnn_model else [])

  with ops.control_dependencies(linear_train_step + dnn_train_step):
    with ops.get_default_graph().colocate_with(global_step):
      return state_ops.assign_add(global_step, 1).op, weighted_average_loss
def _get_model_fn(self, model_fn):
  """Backward compatibility way of adding class weight and IS_TRAINING.

  TODO(ipolosukhin): Remove this function after new layers are available.
  Specifically:
   * dropout and batch norm should work via update ops.
   * class weights should be retrieved from weights column or hparams.

  Args:
    model_fn: Core model function.

  Returns:
    Model function.
  """
  def _model_fn(features, targets, mode):
    """Model function."""
    ops.get_default_graph().add_to_collection('IS_TRAINING', mode == 'train')
    if self.class_weight is not None:
      constant_op.constant(self.class_weight, name='class_weight')
    predictions, loss = model_fn(features, targets)
    if isinstance(self.learning_rate, types.FunctionType):
      learning_rate = self.learning_rate(contrib_framework.get_global_step())
    else:
      learning_rate = self.learning_rate
    if isinstance(self.optimizer, types.FunctionType):
      optimizer = self.optimizer(learning_rate)
    else:
      optimizer = self.optimizer
    train_op = layers.optimize_loss(
        loss,
        contrib_framework.get_global_step(),
        learning_rate=learning_rate,
        optimizer=optimizer,
        clip_gradients=self.clip_gradients)
    return predictions, loss, train_op
  return _model_fn
def _loss_to_train_op(self, loss):
  """Map `loss` to a training op."""
  with ops.name_scope('loss_to_train_op'):
    trainable_variables = ops.get_default_graph().get_collection(
        ops.GraphKeys.TRAINABLE_VARIABLES)
    global_step = contrib_framework.get_global_step()
    gradients = self._optimizer.compute_gradients(
        loss=loss, var_list=trainable_variables)
    processed_gradients = self._process_gradients(gradients)
    return self._optimizer.apply_gradients(
        processed_gradients, global_step=global_step)
def finalize(self):
  """Creates operations if needed and finalizes the graph."""
  if self._init_op is None:
    self._init_op = Scaffold._get_or_default(
        'init_op', ops.GraphKeys.INIT_OP,
        variables.initialize_all_variables)
  if self._ready_op is None:
    self._ready_op = Scaffold._get_or_default(
        'ready_op', ops.GraphKeys.READY_OP,
        variables.report_uninitialized_variables)
  if self._local_init_op is None:
    self._local_init_op = Scaffold._get_or_default(
        'local_init_op', ops.GraphKeys.LOCAL_INIT_OP,
        Scaffold._default_local_init_op)
  if self._summary_op is None:
    self._summary_op = Scaffold._get_or_default(
        'summary_op', ops.GraphKeys.SUMMARY_OP,
        logging_ops.merge_all_summaries)
  # pylint: disable=g-long-lambda
  if self._saver is None:
    self._saver = Scaffold._get_or_default(
        'saver',
        ops.GraphKeys.SAVERS,
        lambda: training_saver.Saver(sharded=True, allow_empty=True))
  # pylint: enable=g-long-lambda
  self._saver.build()

  ops.get_default_graph().finalize()
  return self
def _get_session_manager(self):
  if self._session_manager:
    return self._session_manager

  self._session_manager = sm.SessionManager(
      local_init_op=self._scaffold.local_init_op,
      ready_op=self._scaffold.ready_op,
      graph=ops.get_default_graph())
  return self._session_manager
def before_run(self, run_context):  # pylint: disable=unused-argument
  if self._last_saved_time is None:
    # Write graph in the first call.
    training_util.write_graph(
        ops.get_default_graph().as_graph_def(add_shapes=True),
        self._checkpoint_dir,
        "graph.pbtxt")
    self._summary_writer.add_graph(ops.get_default_graph())
  return SessionRunArgs(self._global_step_tensor)
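The hook above dumps the default graph to disk and to a summary writer the first time it runs. A standalone sketch of the same dump-the-graph step using the public TF 1.x wrappers; the directory and file name are arbitrary:

```python
import tensorflow as tf
from tensorflow.python.framework import ops

a = tf.constant(1.0, name='a')
b = a * 2.0

# Serialize the default graph as a text proto on disk.
graph_def = ops.get_default_graph().as_graph_def(add_shapes=True)
tf.train.write_graph(graph_def, '/tmp/ckpt_dir', 'graph.pbtxt')

# Also attach the graph to an event file so TensorBoard can display it.
writer = tf.summary.FileWriter('/tmp/ckpt_dir')
writer.add_graph(ops.get_default_graph())
writer.close()
```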
def _get_train_ops(self, features, labels):
  global_step = contrib_variables.get_global_step()
  assert global_step
  logits = self._model.build_model(
      features, self._feature_columns, is_training=True)
  model_fn_ops = self._head.head_ops(features, labels,
                                     tf.contrib.learn.ModeKeys.TRAIN,
                                     _noop_training_fn, logits=logits)
  train_step = self._model.get_train_step(model_fn_ops.loss)

  with ops.control_dependencies(train_step):
    with ops.get_default_graph().colocate_with(global_step):
      return state_ops.assign_add(global_step, 1).op, model_fn_ops.loss
def average_name(self, var):
  """Returns the name of the `Variable` holding the average for `var`.

  The typical scenario for `ExponentialMovingAverage` is to compute moving
  averages of variables during training, and restore the variables from the
  computed moving averages during evaluations.

  To restore variables, you have to know the name of the shadow variables.
  That name and the original variable can then be passed to a `Saver()` object
  to restore the variable from the moving average value with:
    `saver = tf.train.Saver({ema.average_name(var): var})`

  `average_name()` can be called whether or not `apply()` has been called.

  Args:
    var: A `Variable` object.

  Returns:
    A string: The name of the variable that will be used or was used
    by the `ExponentialMovingAverage class` to hold the moving average of
    `var`.
  """
  if var in self._averages:
    return self._averages[var].op.name
  return ops.get_default_graph().unique_name(
      var.op.name + "/" + self._name, mark_as_used=False)
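The docstring above already names the restore pattern; here is a minimal end-to-end sketch of it. The variable name `w` and the decay value are arbitrary:

```python
import tensorflow as tf

w = tf.Variable(0.0, name='w')
ema = tf.train.ExponentialMovingAverage(decay=0.99)
maintain_op = ema.apply([w])  # creates the shadow variable for w

# average_name() works whether or not apply() has been called, so an eval job
# can map the shadow variable's checkpoint name back onto the original variable:
saver = tf.train.Saver({ema.average_name(w): w})
# saver.restore(sess, checkpoint_path) would then load the moving average into w.
```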
def __init__(self, session_init_fn, graph=None):
  self._session_init_fn = session_init_fn
  if graph is None:
    graph = ops.get_default_graph()
  self._graph = graph
def restore_graph(s):
  log.info('restore_graph')
  g = ops.get_default_graph()
  graph_def = graph_pb2.GraphDef()
  graph_def.ParseFromString(s)
  # print_nodes(graph_def)
  # print ('before', len(g.as_graph_def().node))
  importer.import_graph_def(graph_def, name='restore')
  # print ('after', len(g.as_graph_def().node))
  # print_nodes(g.as_graph_def())
  # t = g.get_tensor_by_name('restore/y1:0')
  return graph_def
def create_global_step(graph=None):
  """Create global step tensor in graph.

  Args:
    graph: The graph in which to create the global step. If missing, use
      default graph.

  Returns:
    Global step tensor.

  Raises:
    ValueError: if global step key is already defined.
  """
  graph = ops.get_default_graph() if graph is None else graph
  if get_global_step(graph) is not None:
    raise ValueError('"global_step" already exists.')
  # Create in proper graph and base name_scope.
  with graph.as_default() as g, g.name_scope(None):
    collections = [ops.GraphKeys.GLOBAL_VARIABLES, ops.GraphKeys.GLOBAL_STEP]
    return variable(
        ops.GraphKeys.GLOBAL_STEP,
        shape=[],
        dtype=dtypes.int64,
        initializer=init_ops.zeros_initializer(),
        trainable=False,
        collections=collections)
def _base_model_fn(features, labels, mode, params):
  model = params['model']
  feature_columns = params['feature_columns']
  head = params['head']

  if mode == model_fn_lib.ModeKeys.TRAIN:
    logits = model.build_model(features, feature_columns, is_training=True)
  elif mode == model_fn_lib.ModeKeys.EVAL:
    logits = model.build_model(features, feature_columns, is_training=False)
  else:
    raise NotImplementedError

  def _train_op_fn(loss):
    global_step = contrib_variables.get_global_step()
    assert global_step
    train_step = model.get_train_step(loss)

    with ops.control_dependencies(train_step):
      with ops.get_default_graph().colocate_with(global_step):
        return state_ops.assign_add(global_step, 1).op

  return head.create_model_fn_ops(
      features=features,
      mode=mode,
      labels=labels,
      train_op_fn=_train_op_fn,
      logits=logits)
def begin(self):
  self._loss_tensor = ops.get_default_graph().get_tensor_by_name(
      KMeansClustering.LOSS_OP_NAME + ':0')
  assert self._loss_tensor is not None
def as_default(self):
  """Returns a context manager that makes this object the default session.

  Use with the `with` keyword to specify that calls to
  @{tf.Operation.run} or @{tf.Tensor.eval} should be executed in
  this session.

  ```python
  c = tf.constant(..)
  sess = tf.Session()

  with sess.as_default():
    assert tf.get_default_session() is sess
    print(c.eval())
  ```

  To get the current default session, use @{tf.get_default_session}.

  *N.B.* The `as_default` context manager *does not* close the
  session when you exit the context, and you must close the session
  explicitly.

  ```python
  c = tf.constant(...)
  sess = tf.Session()
  with sess.as_default():
    print(c.eval())
  # ...
  with sess.as_default():
    print(c.eval())

  sess.close()
  ```

  Alternatively, you can use `with tf.Session():` to create a
  session that is automatically closed on exiting the context,
  including when an uncaught exception is raised.

  *N.B.* The default session is a property of the current thread. If you
  create a new thread, and wish to use the default session in that
  thread, you must explicitly add a `with sess.as_default():` in that
  thread's function.

  *N.B.* Entering a `with sess.as_default():` block does not affect
  the current default graph. If you are using multiple graphs, and
  `sess.graph` is different from the value of @{tf.get_default_graph},
  you must explicitly enter a `with sess.graph.as_default():` block
  to make `sess.graph` the default graph.

  Returns:
    A context manager using this session as the default session.
  """
  return ops.default_session(self)
def constant(value, dtype=None, shape=None, name="Const", verify_shape=False):
  """Creates a constant tensor.

  The resulting tensor is populated with values of type `dtype`, as
  specified by arguments `value` and (optionally) `shape` (see examples
  below).

  The argument `value` can be a constant value, or a list of values of type
  `dtype`. If `value` is a list, then the length of the list must be less
  than or equal to the number of elements implied by the `shape` argument (if
  specified). In the case where the list length is less than the number of
  elements specified by `shape`, the last element in the list will be used
  to fill the remaining entries.

  The argument `shape` is optional. If present, it specifies the dimensions of
  the resulting tensor. If not present, the shape of `value` is used.

  If the argument `dtype` is not specified, then the type is inferred from
  the type of `value`.

  For example:

  ```python
  # Constant 1-D Tensor populated with value list.
  tensor = tf.constant([1, 2, 3, 4, 5, 6, 7]) => [1 2 3 4 5 6 7]

  # Constant 2-D tensor populated with scalar value -1.
  tensor = tf.constant(-1.0, shape=[2, 3]) => [[-1. -1. -1.]
                                               [-1. -1. -1.]]
  ```

  Args:
    value: A constant value (or list) of output type dtype.
    dtype: The type of the elements of the resulting tensor.
    shape: Optional dimensions of resulting tensor.
    name: Optional name for the tensor.
    verify_shape: Boolean that enables verification of a shape of values.

  Returns:
    A Constant Tensor.
  """
  g = ops.get_default_graph()
  tensor_value = attr_value_pb2.AttrValue()
  tensor_value.tensor.CopyFrom(
      tensor_util.make_tensor_proto(value, dtype=dtype, shape=shape,
                                    verify_shape=verify_shape))
  dtype_value = attr_value_pb2.AttrValue(type=tensor_value.tensor.dtype)
  const_tensor = g.create_op(
      "Const", [], [dtype_value.type],
      attrs={"value": tensor_value, "dtype": dtype_value},
      name=name).outputs[0]
  return const_tensor
def run(self, num_batches=None, graph=None, session=None, start_queues=True,
        initialize_variables=True, **kwargs):
  """Builds and runs the columns of the `DataFrame` and yields batches.

  This is a generator that yields a dictionary mapping column names to
  evaluated columns.

  Args:
    num_batches: the maximum number of batches to produce. If none specified,
      the returned value will iterate through infinite batches.
    graph: the `Graph` in which the `DataFrame` should be built.
    session: the `Session` in which to run the columns of the `DataFrame`.
    start_queues: if true, queues will be started before running and halted
      after producing `n` batches.
    initialize_variables: if true, variables will be initialized.
    **kwargs: Additional keyword arguments e.g. `num_epochs`.

  Yields:
    A dictionary, mapping column names to the values resulting from running
    each column for a single batch.
  """
  if graph is None:
    graph = ops.get_default_graph()
  with graph.as_default():
    if session is None:
      session = sess.Session()
    self_built = self.build(**kwargs)
    keys = list(self_built.keys())
    cols = list(self_built.values())
    if initialize_variables:
      if variables.local_variables():
        session.run(variables.initialize_local_variables())
      if variables.all_variables():
        session.run(variables.initialize_all_variables())
    if start_queues:
      coord = coordinator.Coordinator()
      threads = qr.start_queue_runners(sess=session, coord=coord)
    i = 0
    while num_batches is None or i < num_batches:
      i += 1
      try:
        values = session.run(cols)
        yield collections.OrderedDict(zip(keys, values))
      except errors.OutOfRangeError:
        break
    if start_queues:
      coord.request_stop()
      coord.join(threads)
def _as_meta_graph_def(meta_info_def=None, graph_def=None, saver_def=None,
                       collection_list=None):
  """Construct and returns a `MetaGraphDef` protocol buffer.

  Args:
    meta_info_def: `MetaInfoDef` protocol buffer.
    graph_def: `GraphDef` protocol buffer.
    saver_def: `SaverDef` protocol buffer.
    collection_list: List of string keys to collect.

  Returns:
    MetaGraphDef protocol buffer.

  Raises:
    TypeError: If the arguments are not of the correct proto buffer type.
  """
  # Type check.
  if meta_info_def and not isinstance(meta_info_def,
                                      meta_graph_pb2.MetaGraphDef.MetaInfoDef):
    raise TypeError("meta_info_def must be of type MetaInfoDef, not %s",
                    type(meta_info_def))
  if graph_def and not isinstance(graph_def, graph_pb2.GraphDef):
    raise TypeError("graph_def must be of type GraphDef, not %s",
                    type(graph_def))
  if saver_def and not isinstance(saver_def, saver_pb2.SaverDef):
    raise TypeError("saver_def must be of type SaverDef, not %s",
                    type(saver_def))

  # Creates a MetaGraphDef proto.
  meta_graph_def = meta_graph_pb2.MetaGraphDef()
  # Adds meta_info_def.
  if meta_info_def:
    meta_graph_def.meta_info_def.MergeFrom(meta_info_def)

  # Adds graph_def or the default.
  if not graph_def:
    meta_graph_def.graph_def.MergeFrom(ops.get_default_graph().as_graph_def())
  else:
    meta_graph_def.graph_def.MergeFrom(graph_def)

  # Fills in meta_info_def.stripped_op_list using the ops from graph_def.
  # pylint: disable=g-explicit-length-test
  if len(meta_graph_def.meta_info_def.stripped_op_list.op) == 0:
    meta_graph_def.meta_info_def.stripped_op_list.MergeFrom(
        stripped_op_list_for_graph(meta_graph_def.graph_def))
  # pylint: enable=g-explicit-length-test

  # Adds saver_def.
  if saver_def:
    meta_graph_def.saver_def.MergeFrom(saver_def)

  # Adds collection_list.
  if collection_list:
    clist = collection_list
  else:
    clist = ops.get_all_collection_keys()
  for ctype in clist:
    _add_collection_def(meta_graph_def, ctype)
  return meta_graph_def
def _import_meta_graph_def(meta_graph_def):
  """Recreates a Graph saved in a `MetaGraphDef` proto.

  This function adds all the nodes from the meta graph def proto to the
  current graph, recreates all the collections, and returns a saver from
  saver_def.

  Args:
    meta_graph_def: `MetaGraphDef` protocol buffer.

  Returns:
    A saver constructed from `saver_def` in `meta_graph_def` or None.

    A None value is returned if no variables exist in the `meta_graph_def`
    (i.e., no variables to restore).
  """
  # Gathers the list of nodes we are interested in.
  importer.import_graph_def(meta_graph_def.graph_def, name="")

  # Restores all the other collections.
  for key, col_def in meta_graph_def.collection_def.items():
    kind = col_def.WhichOneof("kind")
    if kind is None:
      logging.error("Cannot identify data type for collection %s. Skipping."
                    % key)
      continue
    from_proto = ops.get_from_proto_function(key)
    if from_proto:
      assert kind == "bytes_list"
      proto_type = ops.get_collection_proto_type(key)
      for value in col_def.bytes_list.value:
        proto = proto_type()
        proto.ParseFromString(value)
        ops.add_to_collection(key, from_proto(proto))
    else:
      field = getattr(col_def, kind)
      if kind == "node_list":
        for value in field.value:
          col_op = ops.get_default_graph().as_graph_element(value)
          ops.add_to_collection(key, col_op)
      elif kind == "int64_list":
        # NOTE(opensource): This force conversion is to work around the fact
        # that Python2 distinguishes between int and long, while Python3 has
        # only int.
        for value in field.value:
          ops.add_to_collection(key, int(value))
      else:
        for value in field.value:
          ops.add_to_collection(key, value)

  if meta_graph_def.HasField("saver_def"):
    return Saver(saver_def=meta_graph_def.saver_def)
  else:
    if variables.all_variables():
      # Return the default saver instance for all graph variables.
      return Saver()
    else:
      # If no graph variables exist, then a Saver cannot be constructed.
      logging.info("Saver not created because there are no variables in the"
                   " graph to restore")
      return None
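The two helpers above are the internals behind the public meta-graph round trip. A minimal sketch using the public TF 1.x wrappers; the variable name and file path are made up for illustration:

```python
import tensorflow as tf

# Build and save: export_meta_graph serializes the default graph plus its
# collections (including the Saver's saver_def) into a MetaGraphDef file.
v = tf.Variable(1.0, name='v')
saver = tf.train.Saver()
tf.train.export_meta_graph(filename='/tmp/example.meta')

# Restore into a fresh graph: import_meta_graph re-creates the nodes and
# collections in the current default graph and returns a Saver built from
# the saved saver_def (or None if there were no variables).
with tf.Graph().as_default():
  restored_saver = tf.train.import_meta_graph('/tmp/example.meta')
  v_restored = tf.get_default_graph().get_tensor_by_name('v:0')
```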
def run(self, num_batches=None, graph=None, session=None, start_queues=True,
        initialize_variables=True, **kwargs):
  """Builds and runs the columns of the `DataFrame` and yields batches.

  This is a generator that yields a dictionary mapping column names to
  evaluated columns.

  Args:
    num_batches: the maximum number of batches to produce. If none specified,
      the returned value will iterate through infinite batches.
    graph: the `Graph` in which the `DataFrame` should be built.
    session: the `Session` in which to run the columns of the `DataFrame`.
    start_queues: if true, queues will be started before running and halted
      after producing `n` batches.
    initialize_variables: if true, variables will be initialized.
    **kwargs: Additional keyword arguments e.g. `num_epochs`.

  Yields:
    A dictionary, mapping column names to the values resulting from running
    each column for a single batch.
  """
  if graph is None:
    graph = ops.get_default_graph()
  with graph.as_default():
    if session is None:
      session = sess.Session()
    self_built = self.build(**kwargs)
    keys = list(self_built.keys())
    cols = list(self_built.values())
    if initialize_variables:
      if variables.local_variables():
        session.run(variables.local_variables_initializer())
      if variables.global_variables():
        session.run(variables.global_variables_initializer())
    if start_queues:
      coord = coordinator.Coordinator()
      threads = qr.start_queue_runners(sess=session, coord=coord)
    i = 0
    while num_batches is None or i < num_batches:
      i += 1
      try:
        values = session.run(cols)
        yield collections.OrderedDict(zip(keys, values))
      except errors.OutOfRangeError:
        break
    if start_queues:
      coord.request_stop()
      coord.join(threads)
def experimental_jit_scope(compile_ops=True):
  """Enable or disable JIT compilation of operators within the scope.

  NOTE: This is an experimental feature.

  The compilation is a hint and only supported on a best-effort basis.

  Example usage:

    with tf.contrib.compiler.experimental_jit_scope():
      c = tf.matmul(a, b)  # compiled
    with tf.contrib.compiler.experimental_jit_scope(compile_ops=False):
      d = tf.matmul(a, c)  # not compiled
    with tf.contrib.compiler.experimental_jit_scope(
        compile_ops=lambda node_def: 'matmul' in node_def.op.lower()):
      e = tf.matmul(a, b) + d  # matmul is compiled, the addition is not.

  Args:
    compile_ops: Whether to enable or disable compilation in the scope.
      Either a Python bool, or a callable that accepts the parameter
      `node_def` and returns a python bool.

  Yields:
    The current scope, enabling or disabling compilation.
  """
  if callable(compile_ops):
    def xla_compile(node_def):
      return attr_value_pb2.AttrValue(b=compile_ops(node_def))
  else:
    xla_compile = attr_value_pb2.AttrValue(b=compile_ops)
  attrs = {"_XlaCompile": xla_compile}

  # TODO(ebrevdo): Keep a global XlaScope counter and here create a
  # special scope that checks if already within a xla scope or creates
  # a new one with a new scope string.  Add a new attr _XlaScope
  # taking this string.  Modify the xla fusion to respect scope
  # boundaries.  Modify gradients_impl to either create a new gradient
  # scope with a suffix from the fw scope or to try to fuse with
  # the fw scope of the given op.  Should be backwards compatible to
  # avoid having to modify Defun compilation attributes.

  # pylint: disable=protected-access
  with ops.get_default_graph()._attr_scope(attrs):
    yield
  # pylint: enable=protected-access