我们从 Python 开源项目中提取了以下 50 个代码示例，用于说明如何使用 `tensorflow.python.framework.ops.get_collection()`。
def get_variables(scope=None, suffix=None,
                  collection=ops.GraphKeys.GLOBAL_VARIABLES):
    """Gets the list of variables, filtered by scope and/or suffix.

    Args:
      scope: an optional scope for filtering the variables to return. Can be a
        variable scope (its `name` is used) or a string.
      suffix: an optional suffix for filtering the variables to return. A ':'
        is appended when the suffix does not already contain one, and the
        suffix is joined to the scope with '.*', so the combined filter is a
        regular-expression pattern.
      collection: the collection to search in. Defaults to
        `GraphKeys.GLOBAL_VARIABLES`.

    Returns:
      a list of variables in `collection` matching `scope` and `suffix`.
    """
    # NOTE(review): the default used to be the deprecated
    # `GraphKeys.VARIABLES` alias; `GLOBAL_VARIABLES` is the name the other
    # get_variables variant in this file uses.
    # Accept a VariableScope object as well as a plain string scope.
    if isinstance(scope, variable_scope.VariableScope):
        scope = scope.name
    if suffix is not None:
        if ':' not in suffix:
            suffix += ':'
        scope = (scope or '') + '.*' + suffix
    return ops.get_collection(collection, scope)
def _export_graph(graph, saver, checkpoint_path, export_dir,
                  default_graph_signature, named_graph_signatures,
                  exports_to_keep):
    """Exports graph via session_bundle, by creating a Session.

    Args:
      graph: the graph to export; it is made the default graph while exporting.
      saver: saver used to restore `checkpoint_path` and handed to
        `exporter.Exporter`.
      checkpoint_path: checkpoint restored into the freshly created session.
      export_dir: destination directory passed to `Exporter.export`.
      default_graph_signature: forwarded to `Exporter.init`.
      named_graph_signatures: forwarded to `Exporter.init`.
      exports_to_keep: forwarded to `Exporter.export`.

    Returns:
      Whatever `Exporter.export` returns.
    """
    with graph.as_default():
        with tf_session.Session('') as session:
            # `local_variables_initializer` replaces the deprecated
            # `initialize_local_variables` alias, matching the other
            # _export_graph in this file.
            # NOTE(review): these two initializer ops are built but never run
            # in `session`; confirm whether they were meant to be executed.
            variables.local_variables_initializer()
            data_flow_ops.initialize_all_tables()
            saver.restore(session, checkpoint_path)

            export = exporter.Exporter(saver)
            export.init(
                init_op=control_flow_ops.group(
                    variables.local_variables_initializer(),
                    data_flow_ops.initialize_all_tables()),
                default_graph_signature=default_graph_signature,
                named_graph_signatures=named_graph_signatures,
                assets_collection=ops.get_collection(
                    ops.GraphKeys.ASSET_FILEPATHS))
            return export.export(export_dir,
                                 contrib_variables.get_global_step(),
                                 session,
                                 exports_to_keep=exports_to_keep)
def _get_concat_variable(name, shape, dtype, num_shards):
    """Get a sharded variable concatenated into one tensor.

    Args:
      name: base name of the (possibly sharded) variable.
      shape: forwarded to `_get_sharded_variable`.
      dtype: forwarded to `_get_sharded_variable`.
      num_shards: forwarded to `_get_sharded_variable`.

    Returns:
      The only shard when `_get_sharded_variable` returns exactly one;
      otherwise a tensor concatenating all shards along axis 0. The
      concatenation is stored in the `CONCATENATED_VARIABLES` collection and
      reused by later calls that resolve to the same full name (current
      variable scope + "/<name>/concat:0").
    """
    sharded_variable = _get_sharded_variable(name, shape, dtype, num_shards)
    if len(sharded_variable) == 1:
        return sharded_variable[0]

    concat_name = name + "/concat"
    concat_full_name = vs.get_variable_scope().name + "/" + concat_name + ":0"
    # Reuse a previously built concatenation with the same full name.
    for value in ops.get_collection(ops.GraphKeys.CONCATENATED_VARIABLES):
        if value.name == concat_full_name:
            return value

    # Values first, axis second: the `concat(values, axis)` order used by the
    # other _get_concat_variable in this file (the old `concat(0, values)`
    # order is rejected by TF >= 1.0).
    concat_variable = array_ops.concat(sharded_variable, 0, name=concat_name)
    ops.add_to_collection(ops.GraphKeys.CONCATENATED_VARIABLES, concat_variable)
    return concat_variable
def get_variables(scope=None, suffix=None,
                  collection=ops.GraphKeys.GLOBAL_VARIABLES):
    """Return the variables of `collection` selected by scope and/or suffix.

    Args:
      scope: optional filter; either a variable scope (its `name` is used) or
        a string.
      suffix: optional suffix filter. A ':' is appended when the suffix has
        none, and it is joined to the scope with '.*' to form the pattern.
      collection: collection to look in; `GraphKeys.GLOBAL_VARIABLES` by
        default.

    Returns:
      the list of matching variables from `collection`.
    """
    if isinstance(scope, variable_scope.VariableScope):
        pattern = scope.name
    else:
        pattern = scope
    if suffix is None:
        return ops.get_collection(collection, pattern)
    if ':' not in suffix:
        suffix = suffix + ':'
    return ops.get_collection(collection, (pattern or '') + '.*' + suffix)
def _export_graph(graph, saver, checkpoint_path, export_dir,
                  default_graph_signature, named_graph_signatures,
                  exports_to_keep):
    """Restore `checkpoint_path` into a new Session and export it via session_bundle.

    Args:
      graph: graph to export (installed as the default graph meanwhile).
      saver: restores the checkpoint and backs the `exporter.Exporter`.
      checkpoint_path: checkpoint to restore.
      export_dir: output directory given to `Exporter.export`.
      default_graph_signature: passed through to `Exporter.init`.
      named_graph_signatures: passed through to `Exporter.init`.
      exports_to_keep: passed through to `Exporter.export`.

    Returns:
      The value returned by `Exporter.export`.
    """
    with graph.as_default(), tf_session.Session('') as session:
        # NOTE(review): these initializer ops are only constructed, never run
        # in `session`; confirm whether that is intended.
        variables.local_variables_initializer()
        data_flow_ops.initialize_all_tables()
        saver.restore(session, checkpoint_path)

        bundle_exporter = exporter.Exporter(saver)
        init_op = control_flow_ops.group(
            variables.local_variables_initializer(),
            data_flow_ops.initialize_all_tables())
        asset_paths = ops.get_collection(ops.GraphKeys.ASSET_FILEPATHS)
        bundle_exporter.init(
            init_op=init_op,
            default_graph_signature=default_graph_signature,
            named_graph_signatures=named_graph_signatures,
            assets_collection=asset_paths)
        step = contrib_variables.get_global_step()
        return bundle_exporter.export(
            export_dir, step, session, exports_to_keep=exports_to_keep)
def _get_concat_variable(name, shape, dtype, num_shards):
    """Return the shards of `name` as one tensor, reusing a cached concatenation.

    A single shard is returned unchanged. Otherwise the shards are joined along
    axis 0; the result is recorded in the CONCATENATED_VARIABLES collection so
    that a later call resolving to the same full name returns it instead of
    building a new concat.
    """
    shards = _get_sharded_variable(name, shape, dtype, num_shards)
    if len(shards) == 1:
        return shards[0]

    concat_name = name + "/concat"
    wanted = "/".join([vs.get_variable_scope().name, concat_name]) + ":0"
    key = ops.GraphKeys.CONCATENATED_VARIABLES
    existing = next(
        (candidate for candidate in ops.get_collection(key)
         if candidate.name == wanted),
        None)
    if existing is not None:
        return existing

    joined = array_ops.concat(shards, 0, name=concat_name)
    ops.add_to_collection(key, joined)
    return joined
def testNoUpdatesWhenIsTrainingFalse(self):
    """batch_norm(is_training=False) registers no update ops and keeps moving stats."""
    height, width = 3, 3
    with self.test_session() as sess:
        shape = (10, height, width, 3)
        images = constant_op.constant(
            np.random.rand(*shape), shape=shape, dtype=dtypes.float32)
        output = _layers.batch_norm(images, decay=0.1, is_training=False)
        # Nothing should have been added to the UPDATE_OPS collection.
        registered = ops.get_collection(ops.GraphKeys.UPDATE_OPS)
        self.assertEqual(len(registered), 0)

        sess.run(variables_lib.global_variables_initializer())
        mean_var, variance_var = (
            variables.get_variables('BatchNorm/' + stat)[0]
            for stat in ('moving_mean', 'moving_variance'))
        initial_mean, initial_variance = sess.run([mean_var, variance_var])
        # Freshly initialized: mean is all zeros, variance all ones.
        self.assertAllClose(initial_mean, [0] * 3)
        self.assertAllClose(initial_variance, [1] * 3)

        # Inference-mode runs must leave the moving statistics untouched.
        for _ in range(10):
            sess.run([output])
        self.assertAllClose(mean_var.eval(), [0] * 3)
        self.assertAllClose(variance_var.eval(), [1] * 3)
def testNoneUpdatesCollectionNoTraining(self):
    """updates_collections=None with is_training=False leaves moving stats fixed."""
    height, width = 3, 3
    with self.test_session() as sess:
        input_shape = (10, height, width, 3)
        raw_pixels = np.random.rand(*input_shape)
        images = constant_op.constant(
            raw_pixels, shape=input_shape, dtype=dtypes.float32)
        output = _layers.batch_norm(
            images, decay=0.1, updates_collections=None, is_training=False)
        # The UPDATE_OPS collection must stay empty.
        self.assertEqual(ops.get_collection(ops.GraphKeys.UPDATE_OPS), [])

        sess.run(variables_lib.global_variables_initializer())
        stats = [variables.get_variables(path)[0]
                 for path in ('BatchNorm/moving_mean',
                              'BatchNorm/moving_variance')]
        moving_mean, moving_variance = stats
        mean, variance = sess.run(stats)
        # Right after initialization: mean == 0 and variance == 1.
        self.assertAllClose(mean, [0] * 3)
        self.assertAllClose(variance, [1] * 3)

        # Running the inference graph repeatedly must not change them.
        for _ in range(10):
            sess.run([output])
        self.assertAllClose(moving_mean.eval(), [0] * 3)
        self.assertAllClose(moving_variance.eval(), [1] * 3)
def testCreateConvWithWeightDecay(self):
    """Each separable-conv kernel gets its own (small) L2 regularization loss."""
    random_seed.set_random_seed(0)
    height, width = 3, 3
    with self.test_session() as sess:
        images = random_ops.random_uniform((5, height, width, 3), seed=1)
        layers_lib.separable_conv2d(
            images, 32, [3, 3], 2,
            weights_regularizer=regularizers.l2_regularizer(0.01))
        losses = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)
        self.assertEqual(len(losses), 2)
        depthwise_loss, pointwise_loss = losses

        self.assertEqual(
            depthwise_loss.op.name,
            'SeparableConv2d/depthwise_kernel/Regularizer/l2_regularizer')
        sess.run(variables_lib.global_variables_initializer())
        self.assertLessEqual(sess.run(depthwise_loss), 0.05)

        self.assertEqual(
            pointwise_loss.op.name,
            'SeparableConv2d/pointwise_kernel/Regularizer/l2_regularizer')
        self.assertLessEqual(sess.run(pointwise_loss), 0.05)
def testReuseConvWithWeightDecay(self):
    """Reusing a regularized separable conv must not add regularization losses."""
    height, width = 3, 3
    with self.test_session():
        images = random_ops.random_uniform((5, height, width, 3), seed=1)
        regularizer = regularizers.l2_regularizer(0.01)
        # The first pass creates the layer, the second reuses its variables;
        # the loss count must be 2 after each.
        for extra_kwargs in ({}, {'reuse': True}):
            layers_lib.separable_conv2d(
                images, 32, [3, 3], 2,
                weights_regularizer=regularizer,
                scope='conv1',
                **extra_kwargs)
            self.assertEqual(
                len(ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)),
                2)
def test_relu_layer_basic_use(self):
    """legacy_relu needs initialization, yields non-negative output, adds 2 trainables."""
    relu_out = layers_lib.legacy_relu(self.input, 8)

    with session.Session() as sess:
        # Evaluating before variable initialization must fail.
        with self.assertRaises(errors_impl.FailedPreconditionError):
            sess.run(relu_out)
        variables_lib.global_variables_initializer().run()
        result = sess.run(relu_out)

    self.assertEqual(relu_out.get_shape().as_list(), [2, 8])
    self.assertTrue(np.all(result >= 0), 'Relu should have all values >= 0.')

    trainables = ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
    reg_losses = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)
    self.assertEqual(2, len(trainables))
    self.assertEqual(0, len(reg_losses))
def test_regularizer_with_variable_reuse(self):
    """The weight regularizer runs once even when the layer is rebuilt with reuse."""
    calls = [0]
    penalty = constant_op.constant(5.0)

    def counting_regularizer(_):
        calls[0] += 1
        return penalty

    with variable_scope.variable_scope('test') as outer_scope:
        _layers.legacy_fully_connected(
            self.input, 2, weight_regularizer=counting_regularizer)
    with variable_scope.variable_scope(outer_scope, reuse=True):
        _layers.legacy_fully_connected(
            self.input, 2, weight_regularizer=counting_regularizer)

    collected = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)
    self.assertEqual([penalty], collected)
    self.assertEqual(1, calls[0])
def testScatteredEmbeddingColumnSucceedsForDNN(self):
    """Gradients flow from a scattered embedding column to its collected weights."""
    wire_tensor = sparse_tensor.SparseTensor(
        values=["omar", "stringer", "marlo", "omar"],
        indices=[[0, 0], [1, 0], [1, 1], [2, 0]],
        dense_shape=[3, 2])
    features = {"wire": wire_tensor}
    # Big enough hash space so that hopefully there is no collision
    embedded_sparse = feature_column.scattered_embedding_column(
        "wire", 1000, 3, layers.SPARSE_FEATURE_CROSS_DEFAULT_HASH_KEY)
    output = feature_column_ops.input_from_feature_columns(
        features, [embedded_sparse], weight_collections=["my_collection"])
    weights = ops.get_collection("my_collection")
    grad = gradients_impl.gradients(output, weights)
    with self.test_session():
        variables_lib.global_variables_initializer().run()
        # Collect the gradient values from every partition (one in this test),
        # iterating the gradients directly instead of via range(len(...)).
        gradient_values = sorted(
            value
            for partition_grad in grad
            for value in partition_grad.values.eval())
        self.assertAllEqual(gradient_values, [0.5] * 6 + [2] * 3)
def testInputLayerWithCollectionsForDNN(self):
    """Only the embedding column places a variable in the custom weight collection."""
    price_column = feature_column.real_valued_column("price")
    price_buckets = feature_column.bucketized_column(
        price_column, boundaries=[0., 10., 100.])
    wire_hashed = feature_column.sparse_column_with_hash_bucket("wire", 10)
    features = {
        "price": constant_op.constant([[20.], [110], [-3]]),
        "wire": sparse_tensor.SparseTensor(
            values=["omar", "stringer", "marlo"],
            indices=[[0, 0], [1, 0], [2, 0]],
            dense_shape=[3, 1]),
    }
    wire_embedding = feature_column.embedding_column(wire_hashed, 10)
    columns = [price_column, price_buckets, wire_embedding]
    feature_column_ops.input_from_feature_columns(
        features, columns, weight_collections=["my_collection"])
    # A single variable, owned by the embedded sparse column.
    self.assertEqual(1, len(ops.get_collection("my_collection")))
def testVariablesAddedToCollection(self):
    """weighted_sum_from_feature_columns puts bias and column weights in the collection."""
    price = feature_column.real_valued_column("price")
    price_bucket = feature_column.bucketized_column(
        price, boundaries=[0., 10., 100.])
    country = feature_column.sparse_column_with_hash_bucket(
        "country", hash_bucket_size=5)
    country_price = feature_column.crossed_column(
        [country, price_bucket], hash_bucket_size=10)
    with ops.Graph().as_default():
        features = {
            "price": constant_op.constant([[20.]]),
            "country": sparse_tensor.SparseTensor(
                values=["US", "SV"],
                indices=[[0, 0], [0, 1]],
                dense_shape=[1, 2]),
        }
        feature_column_ops.weighted_sum_from_feature_columns(
            features, [country_price, price_bucket],
            num_outputs=1,
            weight_collections=["my_collection"])
        collected = ops.get_collection("my_collection")
        # 3 = bias + price_bucket + country_price
        self.assertEqual(3, len(collected))
def benchmarkTfRNNLSTMBlockCellTraining(self):
    """Benchmark a gradient (training) op of a stacked LSTMBlockCell RNN per test config."""
    for config_name, config in self._GetTestConfig().items():
        num_layers = config["num_layers"]
        num_units = config["num_units"]
        batch_size = config["batch_size"]
        seq_length = config["seq_length"]

        with ops.Graph().as_default(), ops.device("/gpu:0"):
            step_input = array_ops.zeros([batch_size, num_units], dtypes.float32)
            # The same zeros tensor is fed at every time step.
            inputs = [step_input] * seq_length
            stacked_cell = core_rnn_cell_impl.MultiRNNCell(
                [lstm_ops.LSTMBlockCell(num_units=num_units)
                 for _ in range(num_layers)])
            outputs, final_state = core_rnn.static_rnn(
                stacked_cell, inputs, dtype=dtypes.float32)

            params = ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
            grads = gradients_impl.gradients([outputs, final_state], params)
            training_op = control_flow_ops.group(*grads)
            description = "tf_rnn_lstm_block_cell %s %s" % (
                config_name, self._GetConfigDesc(config))
            self._BenchmarkOp(training_op, description)
def _get_arg_stack():
    """Return the stack (a list of dicts) stored under `_ARGSTACK_KEY`.

    On first use the collection is empty, so a new stack `[{}]` is created,
    registered in the collection, and returned.
    """
    existing = ops.get_collection(_ARGSTACK_KEY)
    if not existing:
        new_stack = [{}]
        ops.add_to_collection(_ARGSTACK_KEY, new_stack)
        return new_stack
    return existing[0]
def get_global_step(graph=None):
    """Get the global step tensor.

    The global step tensor must be an integer variable. It is looked up first
    in the `GLOBAL_STEP` collection and, if that collection is empty, by the
    name `global_step:0`.

    Args:
      graph: The graph to search. Defaults to the default graph.

    Returns:
      The global step variable, or `None` if none was found or if the
      collection holds more than one tensor (an error is logged then).

    Raises:
      TypeError: If the global step tensor has a non-integer type, or if it is
        not a `Variable`.
    """
    if graph is None:
        graph = ops.get_default_graph()

    candidates = graph.get_collection(ops.GraphKeys.GLOBAL_STEP)
    if len(candidates) > 1:
        logging.error('Multiple tensors in global_step collection.')
        return None

    if candidates:
        step = candidates[0]
    else:
        try:
            step = graph.get_tensor_by_name('global_step:0')
        except KeyError:
            return None

    assert_global_step(step)
    return step
def add_model_variable(var):
    """Register `var` in `GraphKeys.MODEL_VARIABLES` unless it is already there.

    Args:
      var: a variable.
    """
    key = ops.GraphKeys.MODEL_VARIABLES
    if var in ops.get_collection(key):
        return
    ops.add_to_collection(key, var)
def _stochastic_dependencies_map(fixed_losses, stochastic_tensors=None): """Map stochastic tensors to the fixed losses that depend on them. Args: fixed_losses: a list of `Tensor`s. stochastic_tensors: a list of `StochasticTensor`s to map to fixed losses. If `None`, all `StochasticTensor`s in the graph will be used. Returns: A dict `dependencies` that maps `StochasticTensor` objects to subsets of `fixed_losses`. If `loss in dependencies[st]`, for some `loss` in `fixed_losses` then there is a direct path from `st.value()` to `loss` in the graph. """ stoch_value_collection = stochastic_tensors or ops.get_collection( stochastic_tensor.STOCHASTIC_TENSOR_COLLECTION) if not stoch_value_collection: return {} stoch_value_map = dict( (node.value(), node) for node in stoch_value_collection) # Step backwards through the graph to see which surrogate losses correspond # to which fixed_losses. # # TODO(ebrevdo): Ensure that fixed_losses and stochastic values are in the # same frame. stoch_dependencies_map = collections.defaultdict(set) for loss in fixed_losses: boundary = set([loss]) while boundary: edge = boundary.pop() edge_stoch_node = stoch_value_map.get(edge, None) if edge_stoch_node: stoch_dependencies_map[edge_stoch_node].add(loss) boundary.update(edge.op.inputs) return stoch_dependencies_map
def get_losses(scope=None, loss_collection=ops.GraphKeys.LOSSES):
    """Return the loss tensors held in `loss_collection`.

    Args:
      scope: optional scope used to filter the returned losses.
      loss_collection: collection to read; `GraphKeys.LOSSES` by default.

    Returns:
      a list of loss tensors.
    """
    losses = ops.get_collection(loss_collection, scope=scope)
    return losses
def get_regularization_losses(scope=None):
    """Return the tensors in `GraphKeys.REGULARIZATION_LOSSES`.

    Args:
      scope: optional scope used to filter the returned losses.

    Returns:
      A list of loss variables.
    """
    key = ops.GraphKeys.REGULARIZATION_LOSSES
    regularization_losses = ops.get_collection(key, scope=scope)
    return regularization_losses
def QueueRunners(session):
    """Creates a context manager that handles starting and stopping queue runners.

    This is a generator; it is presumably wrapped with
    `contextlib.contextmanager` at its definition site (not visible here).

    Every queue runner in `GraphKeys.QUEUE_RUNNERS` gets its threads created
    as daemon threads and started, sharing one `Coordinator`. On exit the
    coordinator is asked to stop and the threads are joined with a 120 s grace
    period.

    Args:
      session: the currently running session.

    Yields:
      a context in which queues are run.

    Raises:
      NestedQueueRunnerError: if a QueueRunners context is nested within another.
    """
    if not _queue_runner_lock.acquire(False):
        raise NestedQueueRunnerError('QueueRunners cannot be nested')

    # The outer try guarantees the lock is released even if creating the
    # coordinator/threads fails or if coord.join re-raises a thread error.
    try:
        coord = coordinator.Coordinator()
        threads = []
        # The inner try stops and joins whatever threads were started, even
        # when starting a later queue runner fails.
        try:
            for qr in ops.get_collection(ops.GraphKeys.QUEUE_RUNNERS):
                threads.extend(qr.create_threads(session, coord=coord,
                                                 daemon=True, start=True))
            yield
        finally:
            coord.request_stop()
            coord.join(threads, stop_grace_period_secs=120)
    finally:
        _queue_runner_lock.release()
def dnn(tensor_in, hidden_units, activation=nn.relu, dropout=None):
    """Creates fully connected deep neural network subgraph.

    This is deprecated. Please use contrib.layers.dnn instead.

    Args:
      tensor_in: tensor or placeholder for input features.
      hidden_units: list of counts of hidden units in each layer.
      activation: activation function between layers. Can be None.
      dropout: if not None, will add a dropout layer with given probability.
        `1.0 - dropout` is what gets passed to `dropout_ops.dropout` as
        `prob`, so `dropout` is presumably the drop probability.

    Returns:
      A tensor which would be a deep neural network.
    """
    # Adjacent string literals instead of a backslash continuation inside the
    # literal, which leaked the continuation/indentation into the log text.
    logging.warning("learn.ops.dnn is deprecated, "
                    "please use contrib.layers.dnn.")
    with vs.variable_scope('dnn'):
        for i, n_units in enumerate(hidden_units):
            with vs.variable_scope('layer%d' % i):
                # Weight initializer was set to None to replicate the behavior of
                # rnn_cell.linear. Using fully_connected's default initializer gets
                # slightly worse quality results on unit tests.
                tensor_in = layers.legacy_fully_connected(
                    tensor_in,
                    n_units,
                    weight_init=None,
                    weight_collections=['dnn_weights'],
                    bias_collections=['dnn_biases'])
                if activation is not None:
                    tensor_in = activation(tensor_in)
                if dropout is not None:
                    # NOTE(review): assumes the 'IS_TRAINING' collection holds a
                    # single boolean tensor; squeeze collapses it to a scalar.
                    is_training = array_ops_.squeeze(
                        ops.get_collection('IS_TRAINING'))
                    # NOTE(review): both lambdas read `tensor_in` late; this is
                    # only correct if `cond` calls them before `tensor_in` is
                    # rebound below — confirm for the TF version in use.
                    tensor_in = control_flow_ops.cond(
                        is_training,
                        lambda: dropout_ops.dropout(tensor_in,
                                                    prob=(1.0 - dropout)),
                        lambda: tensor_in)
        return tensor_in
def every_n_step_begin(self, step):
    """Rebuild the variable-name -> value-tensor-name map and return the value names.

    Args:
      step: current step, forwarded to the parent implementation.

    Returns:
      A list of `var.value().name` strings, one per trainable variable under
      `self._scope`.
    """
    super(LoggingTrainable, self).every_n_step_begin(step)
    # The trainable variables are looked up at the beginning of every N steps
    # rather than in __init__, because train_op has not been generated then.
    trainables = ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES,
                                    scope=self._scope)
    self._names = {}
    self._names.update(
        (variable.name, variable.value().name) for variable in trainables)
    return list(self._names.values())
def _centered_bias_step(self, targets, features):
    """Build an Adagrad step that fits the centered-bias variable(s) alone.

    The first variable of `self._centered_bias_weight_collection` is tiled
    once per example and reshaped to
    `[batch_size, self._target_column.num_label_columns]` to serve as logits.
    NOTE(review): this assumes the bias holds `num_label_columns` elements.

    Returns:
      The op returned by `AdagradOptimizer.minimize`, restricted to the
      centered-bias variables.
    """
    bias_vars = ops.get_collection(self._centered_bias_weight_collection)
    rows = array_ops.shape(targets)[0]
    tiled_bias = array_ops.tile(bias_vars[0], [rows])
    logits = array_ops.reshape(
        tiled_bias, [rows, self._target_column.num_label_columns])
    with ops.name_scope(None, "centered_bias", (targets, features)):
        loss = self._target_column.training_loss(logits, targets, features)
        # Learn the central bias with its own optimizer; 0.1 is a conservative
        # learning rate for a single variable.
        optimizer = training.AdagradOptimizer(0.1)
        return optimizer.minimize(loss, var_list=bias_vars)
def _centered_bias_step(targets, loss_fn, num_label_columns):
    """Build an Adagrad(0.1) step that fits the "centered_bias" variable(s) alone.

    The first variable in the "centered_bias" collection is tiled once per
    example and reshaped to `[batch_size, num_label_columns]` to serve as
    logits for `loss_fn(logits, targets)`.
    NOTE(review): assumes the bias holds `num_label_columns` elements.
    """
    bias_vars = ops.get_collection("centered_bias")
    rows = array_ops.shape(targets)[0]
    tiled_bias = array_ops.tile(bias_vars[0], [rows])
    logits = array_ops.reshape(tiled_bias, [rows, num_label_columns])
    bias_loss = loss_fn(logits, targets)
    optimizer = train.AdagradOptimizer(0.1)
    return optimizer.minimize(bias_loss, var_list=bias_vars)
def _get_vars(self):
    """Return the collection keyed by `self._scope`, or [] when there are no feature columns."""
    if not self._get_feature_columns():
        return []
    return ops.get_collection(self._scope)
def _get_first_op_from_collection(collection_name):
    """Get first element from the collection.

    Args:
      collection_name: key of the graph collection to read.

    Returns:
      The first element of the collection, or None if the collection is empty.
    """
    # ops.get_collection returns a (possibly empty) list, so the former
    # `is not None` guard was dead code; truthiness covers the empty case.
    elements = ops.get_collection(collection_name)
    if elements:
        return elements[0]
    return None
def _get_first_op_from_collection(collection_name):
    """Return the first item of collection `collection_name`, or None if it is empty."""
    items = ops.get_collection(collection_name)
    return items[0] if items else None
def _get_or_default(arg_name, collection_key, default_constructor):
    """Get from cache or create a default operation.

    Args:
      arg_name: name of the Scaffold constructor argument; used only in the
        error message.
      collection_key: graph collection that caches the operation.
      default_constructor: zero-argument callable building the default op.

    Returns:
      The single cached element, or the newly constructed op (which is added
      to the collection unless it is None).

    Raises:
      RuntimeError: if the collection holds more than one element.
    """
    elements = ops.get_collection(collection_key)
    if elements:
        if len(elements) > 1:
            # Interpolate the arguments; previously they were passed as extra
            # exception args, leaving the '%s' placeholders unformatted.
            raise RuntimeError(
                'More than one item in the collection "%s". '
                'Please indicate which one to use by passing it to '
                'the tf.Scaffold constructor as: '
                'tf.Scaffold(%s=item to use)' % (collection_key, arg_name))
        return elements[0]
    op = default_constructor()
    if op is not None:
        ops.add_to_collection(collection_key, op)
    return op
def apply_regularization(regularizer, weights_list=None):
    """Returns the summed penalty by applying `regularizer` to the `weights_list`.

    Adding a regularization penalty over the layer weights and embedding weights
    can help prevent overfitting the training data. Regularization over layer
    biases is less common/useful, but assuming proper data preprocessing/mean
    subtraction, it usually shouldn't hurt much either.

    The sum is also added to `GraphKeys.REGULARIZATION_LOSSES`.

    Args:
      regularizer: A function that takes a single `Tensor` argument and returns
        a scalar `Tensor` output.
      weights_list: List of weights `Tensors` or `Variables` to apply
        `regularizer` over. Defaults to the `GraphKeys.WEIGHTS` collection if
        `None` (or empty).

    Returns:
      A scalar representing the overall regularization penalty.

    Raises:
      ValueError: If `regularizer` does not return a scalar output, or if we find
          no weights.
    """
    if not weights_list:
        weights_list = ops.get_collection(ops.GraphKeys.WEIGHTS)
    if not weights_list:
        raise ValueError('No weights to regularize.')
    with ops.name_scope('get_regularization_penalty',
                        values=weights_list) as scope:
        penalties = list(map(regularizer, weights_list))
        for penalty in penalties:
            rank = penalty.get_shape().ndims
            if rank != 0:
                raise ValueError('regularizer must return a scalar Tensor instead of a '
                                 'Tensor with rank %d.' % rank)
        total = math_ops.add_n(penalties, name=scope)
        ops.add_to_collection(ops.GraphKeys.REGULARIZATION_LOSSES, total)
        return total
def is_summary_tag_unique(tag):
    """Checks if a summary tag is unique.

    The constant value of the first input of every op in `GraphKeys.SUMMARIES`
    is compared against `tag.encode()`; numpy arrays are converted with
    `tolist()` first.

    Args:
      tag: The tag to use

    Returns:
      True if the summary tag is unique.
    """
    seen_tags = []
    for summary in ops.get_collection(ops.GraphKeys.SUMMARIES):
        value = tensor_util.constant_value(summary.op.inputs[0])
        if isinstance(value, np.ndarray):
            value = value.tolist()
        seen_tags.append(value)
    return tag.encode() not in seen_tags
def summarize_collection(collection, name_filter=None,
                         summarizer=summarize_tensor):
    """Summarize a graph collection of tensors, possibly filtered by name.

    Args:
      collection: key of the graph collection to summarize.
      name_filter: optional regex; a tensor is kept when `re.match` succeeds
        on its op name. None keeps every tensor.
      summarizer: per-tensor summary function handed to `summarize_tensors`.

    Returns:
      The result of `summarize_tensors` on the selected tensors.
    """
    def _wanted(tensor):
        return name_filter is None or re.match(name_filter, tensor.op.name)

    selected = [tensor for tensor in ops.get_collection(collection)
                if _wanted(tensor)]
    return summarize_tensors(selected, summarizer)


# Utility functions for commonly used collections
def _get_qr(self, name):
    """Return the first registered queue runner whose `name` equals `name`, else None."""
    matches = (runner
               for runner in ops.get_collection(ops.GraphKeys.QUEUE_RUNNERS)
               if runner.name == name)
    return next(matches, None)