Python tensorflow.python.framework.ops 模块，get_collection() 实例源码
我们从 Python 开源项目中，提取了以下 50 个代码示例，用于说明如何使用 tensorflow.python.framework.ops.get_collection()。
def get_variables(scope=None, suffix=None, collection=ops.GraphKeys.VARIABLES):
    """Return the variables of `collection`, optionally filtered by scope/suffix.

    The filter handed to `ops.get_collection` is built as follows:
    - With no suffix, the filter is just `scope`.
    - With a suffix, the filter is `<scope>.*<suffix>:`. The ':' is appended
      when `suffix` lacks one, presumably so the suffix lines up with the
      ':<output index>' part of a variable name.

    Args:
      scope: an optional scope for filtering the variables to return.
      suffix: an optional suffix for filtering the variables to return.
      collection: collection key to search. Defaults to GraphKeys.VARIABLES.

    Returns:
      The list returned by `ops.get_collection(collection, <filter>)`.
    """
    name_filter = scope
    if suffix is not None:
        tail = suffix if ':' in suffix else suffix + ':'
        name_filter = (scope or '') + '.*' + tail
    return ops.get_collection(collection, name_filter)
def _export_graph(graph, saver, checkpoint_path, export_dir,
                  default_graph_signature, named_graph_signatures,
                  exports_to_keep):
    """Exports graph via session_bundle, by creating a Session.

    Restores `checkpoint_path` into a new session on `graph`, then configures
    an `exporter.Exporter` and writes the export to `export_dir`.

    Args:
      graph: graph to export; it is made the default graph for the duration.
      saver: used both to restore the checkpoint and to build the Exporter.
      checkpoint_path: checkpoint restored into the session before exporting.
      export_dir: destination directory passed to `Exporter.export`.
      default_graph_signature: forwarded to `Exporter.init`.
      named_graph_signatures: forwarded to `Exporter.init`.
      exports_to_keep: forwarded to `Exporter.export`.

    Returns:
      Whatever `Exporter.export` returns (presumably the export path -- TODO
      confirm).
    """
    with graph.as_default():
        with tf_session.Session('') as session:
            # NOTE(review): these two calls only *build* initializer ops; they
            # are never run in `session`. Either they are dead graph nodes or
            # a `session.run(...)` is missing. Confirm the intent before
            # changing anything: removing or running them alters the exported
            # graph.
            variables.initialize_local_variables()
            data_flow_ops.initialize_all_tables()
            saver.restore(session, checkpoint_path)
            export = exporter.Exporter(saver)
            # The init op stored with the export groups freshly built
            # local-variable and table initializers.
            export.init(init_op=control_flow_ops.group(
                variables.initialize_local_variables(),
                data_flow_ops.initialize_all_tables()),
                default_graph_signature=default_graph_signature,
                named_graph_signatures=named_graph_signatures,
                assets_collection=ops.get_collection(
                    ops.GraphKeys.ASSET_FILEPATHS))
            return export.export(export_dir, contrib_variables.get_global_step(),
                                 session, exports_to_keep=exports_to_keep)
def _get_concat_variable(name, shape, dtype, num_shards):
    """Get a sharded variable concatenated into one tensor.

    The concatenated tensor is cached in the
    `GraphKeys.CONCATENATED_VARIABLES` collection. A later call whose
    expected tensor name matches a cached entry returns that entry instead of
    building another concat op.

    Args:
      name: base name of the sharded variable.
      shape: forwarded to `_get_sharded_variable`.
      dtype: forwarded to `_get_sharded_variable`.
      num_shards: forwarded to `_get_sharded_variable`.

    Returns:
      The only shard when `_get_sharded_variable` returns a single variable;
      otherwise the (possibly cached) concatenation of all shards along
      axis 0.
    """
    sharded_variable = _get_sharded_variable(name, shape, dtype, num_shards)
    if len(sharded_variable) == 1:
        return sharded_variable[0]

    concat_name = name + "/concat"
    scope_name = vs.get_variable_scope().name
    # The root variable scope has an empty name. Unconditionally prefixing
    # "/" would produce a name no tensor can have, so the cache lookup below
    # would never hit and a new concat op would be added on every call.
    scope_prefix = scope_name + "/" if scope_name else ""
    concat_full_name = scope_prefix + concat_name + ":0"
    for value in ops.get_collection(ops.GraphKeys.CONCATENATED_VARIABLES):
        if value.name == concat_full_name:
            return value

    concat_variable = array_ops.concat(0, sharded_variable, name=concat_name)
    ops.add_to_collection(ops.GraphKeys.CONCATENATED_VARIABLES,
                          concat_variable)
    return concat_variable
def get_variables(scope=None, suffix=None,
                  collection=ops.GraphKeys.GLOBAL_VARIABLES):
    """Return the variables of `collection`, optionally filtered by scope/suffix.

    Args:
      scope: an optional scope for filtering the variables to return. May be
        a `VariableScope` (its `.name` is used) or a string.
      suffix: an optional suffix for filtering the variables to return; a
        trailing ':' is appended when it has none.
      collection: collection key to search. Defaults to
        `GraphKeys.GLOBAL_VARIABLES`.

    Returns:
      The list returned by `ops.get_collection(collection, <filter>)`. The
      filter is `scope` alone, or `<scope>.*<suffix>:` when a suffix is given.
    """
    scope_name = (scope.name
                  if isinstance(scope, variable_scope.VariableScope)
                  else scope)
    if suffix is None:
        return ops.get_collection(collection, scope_name)
    tail = suffix if ':' in suffix else suffix + ':'
    return ops.get_collection(collection, (scope_name or '') + '.*' + tail)
def _export_graph(graph, saver, checkpoint_path, export_dir,
                  default_graph_signature, named_graph_signatures,
                  exports_to_keep):
    """Exports graph via session_bundle, by creating a Session.

    Restores `checkpoint_path` into a new session on `graph`, then configures
    an `exporter.Exporter` and writes the export to `export_dir`.

    Args:
      graph: graph to export; it is made the default graph for the duration.
      saver: used both to restore the checkpoint and to build the Exporter.
      checkpoint_path: checkpoint restored into the session before exporting.
      export_dir: destination directory passed to `Exporter.export`.
      default_graph_signature: forwarded to `Exporter.init`.
      named_graph_signatures: forwarded to `Exporter.init`.
      exports_to_keep: forwarded to `Exporter.export`.

    Returns:
      Whatever `Exporter.export` returns (presumably the export path -- TODO
      confirm).
    """
    with graph.as_default():
        with tf_session.Session('') as session:
            # NOTE(review): these two calls only *build* initializer ops; they
            # are never run in `session`. Either they are dead graph nodes or
            # a `session.run(...)` is missing. Confirm the intent before
            # changing anything: removing or running them alters the exported
            # graph.
            variables.local_variables_initializer()
            data_flow_ops.initialize_all_tables()
            saver.restore(session, checkpoint_path)
            export = exporter.Exporter(saver)
            # The init op stored with the export groups freshly built
            # local-variable and table initializers.
            export.init(init_op=control_flow_ops.group(
                variables.local_variables_initializer(),
                data_flow_ops.initialize_all_tables()),
                default_graph_signature=default_graph_signature,
                named_graph_signatures=named_graph_signatures,
                assets_collection=ops.get_collection(
                    ops.GraphKeys.ASSET_FILEPATHS))
            return export.export(export_dir, contrib_variables.get_global_step(),
                                 session, exports_to_keep=exports_to_keep)
def _get_concat_variable(name, shape, dtype, num_shards):
    """Get a sharded variable concatenated into one tensor.

    The concatenated tensor is cached in the
    `GraphKeys.CONCATENATED_VARIABLES` collection. A later call whose
    expected tensor name matches a cached entry returns that entry instead of
    building another concat op.

    Args:
      name: base name of the sharded variable.
      shape: forwarded to `_get_sharded_variable`.
      dtype: forwarded to `_get_sharded_variable`.
      num_shards: forwarded to `_get_sharded_variable`.

    Returns:
      The only shard when `_get_sharded_variable` returns a single variable;
      otherwise the (possibly cached) concatenation of all shards along
      axis 0.
    """
    sharded_variable = _get_sharded_variable(name, shape, dtype, num_shards)
    if len(sharded_variable) == 1:
        return sharded_variable[0]

    concat_name = name + "/concat"
    scope_name = vs.get_variable_scope().name
    # The root variable scope has an empty name. Unconditionally prefixing
    # "/" would produce a name no tensor can have, so the cache lookup below
    # would never hit and a new concat op would be added on every call.
    scope_prefix = scope_name + "/" if scope_name else ""
    concat_full_name = scope_prefix + concat_name + ":0"
    for value in ops.get_collection(ops.GraphKeys.CONCATENATED_VARIABLES):
        if value.name == concat_full_name:
            return value

    concat_variable = array_ops.concat(0, sharded_variable, name=concat_name)
    ops.add_to_collection(ops.GraphKeys.CONCATENATED_VARIABLES,
                          concat_variable)
    return concat_variable
def _get_concat_variable(name, shape, dtype, num_shards):
    """Get a sharded variable concatenated into one tensor.

    The concatenated tensor is cached in the
    `GraphKeys.CONCATENATED_VARIABLES` collection. A later call whose
    expected tensor name matches a cached entry returns that entry instead of
    building another concat op.

    Args:
      name: base name of the sharded variable.
      shape: forwarded to `_get_sharded_variable`.
      dtype: forwarded to `_get_sharded_variable`.
      num_shards: forwarded to `_get_sharded_variable`.

    Returns:
      The only shard when `_get_sharded_variable` returns a single variable;
      otherwise the (possibly cached) concatenation of all shards along
      axis 0.
    """
    sharded_variable = _get_sharded_variable(name, shape, dtype, num_shards)
    if len(sharded_variable) == 1:
        return sharded_variable[0]

    concat_name = name + "/concat"
    scope_name = vs.get_variable_scope().name
    # The root variable scope has an empty name. Unconditionally prefixing
    # "/" would produce a name no tensor can have, so the cache lookup below
    # would never hit and a new concat op would be added on every call.
    scope_prefix = scope_name + "/" if scope_name else ""
    concat_full_name = scope_prefix + concat_name + ":0"
    for value in ops.get_collection(ops.GraphKeys.CONCATENATED_VARIABLES):
        if value.name == concat_full_name:
            return value

    concat_variable = array_ops.concat(sharded_variable, 0, name=concat_name)
    ops.add_to_collection(ops.GraphKeys.CONCATENATED_VARIABLES,
                          concat_variable)
    return concat_variable
def _get_concat_variable(name, shape, dtype, num_shards):
    """Get a sharded variable concatenated into one tensor.

    The concatenated tensor is cached in the
    `GraphKeys.CONCATENATED_VARIABLES` collection. A later call whose
    expected tensor name matches a cached entry returns that entry instead of
    building another concat op.

    Args:
      name: base name of the sharded variable.
      shape: forwarded to `_get_sharded_variable`.
      dtype: forwarded to `_get_sharded_variable`.
      num_shards: forwarded to `_get_sharded_variable`.

    Returns:
      The only shard when `_get_sharded_variable` returns a single variable;
      otherwise the (possibly cached) concatenation of all shards along
      axis 0.
    """
    sharded_variable = _get_sharded_variable(name, shape, dtype, num_shards)
    if len(sharded_variable) == 1:
        return sharded_variable[0]

    concat_name = name + "/concat"
    scope_name = vs.get_variable_scope().name
    # The root variable scope has an empty name. Unconditionally prefixing
    # "/" would produce a name no tensor can have, so the cache lookup below
    # would never hit and a new concat op would be added on every call.
    scope_prefix = scope_name + "/" if scope_name else ""
    concat_full_name = scope_prefix + concat_name + ":0"
    for value in ops.get_collection(ops.GraphKeys.CONCATENATED_VARIABLES):
        if value.name == concat_full_name:
            return value

    concat_variable = array_ops.concat(sharded_variable, 0, name=concat_name)
    ops.add_to_collection(ops.GraphKeys.CONCATENATED_VARIABLES,
                          concat_variable)
    return concat_variable
def _get_concat_variable(name, shape, dtype, num_shards):
    """Get a sharded variable concatenated into one tensor.

    The concatenated tensor is cached in the
    `GraphKeys.CONCATENATED_VARIABLES` collection. A later call whose
    expected tensor name matches a cached entry returns that entry instead of
    building another concat op.

    Args:
      name: base name of the sharded variable.
      shape: forwarded to `_get_sharded_variable`.
      dtype: forwarded to `_get_sharded_variable`.
      num_shards: forwarded to `_get_sharded_variable`.

    Returns:
      The only shard when `_get_sharded_variable` returns a single variable;
      otherwise the (possibly cached) concatenation of all shards along
      axis 0.
    """
    sharded_variable = _get_sharded_variable(name, shape, dtype, num_shards)
    if len(sharded_variable) == 1:
        return sharded_variable[0]

    concat_name = name + "/concat"
    scope_name = vs.get_variable_scope().name
    # The root variable scope has an empty name. Unconditionally prefixing
    # "/" would produce a name no tensor can have, so the cache lookup below
    # would never hit and a new concat op would be added on every call.
    scope_prefix = scope_name + "/" if scope_name else ""
    concat_full_name = scope_prefix + concat_name + ":0"
    for value in ops.get_collection(ops.GraphKeys.CONCATENATED_VARIABLES):
        if value.name == concat_full_name:
            return value

    concat_variable = array_ops.concat(0, sharded_variable, name=concat_name)
    ops.add_to_collection(ops.GraphKeys.CONCATENATED_VARIABLES,
                          concat_variable)
    return concat_variable
def _get_concat_variable(name, shape, dtype, num_shards):
    """Get a sharded variable concatenated into one tensor.

    The concatenated tensor is cached in the
    `GraphKeys.CONCATENATED_VARIABLES` collection. A later call whose
    expected tensor name matches a cached entry returns that entry instead of
    building another concat op.

    Args:
      name: base name of the sharded variable.
      shape: forwarded to `_get_sharded_variable`.
      dtype: forwarded to `_get_sharded_variable`.
      num_shards: forwarded to `_get_sharded_variable`.

    Returns:
      The only shard when `_get_sharded_variable` returns a single variable;
      otherwise the (possibly cached) concatenation of all shards along
      axis 0.
    """
    sharded_variable = _get_sharded_variable(name, shape, dtype, num_shards)
    if len(sharded_variable) == 1:
        return sharded_variable[0]

    concat_name = name + "/concat"
    scope_name = vs.get_variable_scope().name
    # The root variable scope has an empty name. Unconditionally prefixing
    # "/" would produce a name no tensor can have, so the cache lookup below
    # would never hit and a new concat op would be added on every call.
    scope_prefix = scope_name + "/" if scope_name else ""
    concat_full_name = scope_prefix + concat_name + ":0"
    for value in ops.get_collection(ops.GraphKeys.CONCATENATED_VARIABLES):
        if value.name == concat_full_name:
            return value

    concat_variable = array_ops.concat(0, sharded_variable, name=concat_name)
    ops.add_to_collection(ops.GraphKeys.CONCATENATED_VARIABLES,
                          concat_variable)
    return concat_variable
def _get_concat_variable(name, shape, dtype, num_shards):
    """Get a sharded variable concatenated into one tensor.

    The concatenated tensor is cached in the
    `GraphKeys.CONCATENATED_VARIABLES` collection. A later call whose
    expected tensor name matches a cached entry returns that entry instead of
    building another concat op.

    Args:
      name: base name of the sharded variable.
      shape: forwarded to `_get_sharded_variable`.
      dtype: forwarded to `_get_sharded_variable`.
      num_shards: forwarded to `_get_sharded_variable`.

    Returns:
      The only shard when `_get_sharded_variable` returns a single variable;
      otherwise the (possibly cached) concatenation of all shards along
      axis 0.
    """
    sharded_variable = _get_sharded_variable(name, shape, dtype, num_shards)
    if len(sharded_variable) == 1:
        return sharded_variable[0]

    concat_name = name + "/concat"
    scope_name = vs.get_variable_scope().name
    # The root variable scope has an empty name. Unconditionally prefixing
    # "/" would produce a name no tensor can have, so the cache lookup below
    # would never hit and a new concat op would be added on every call.
    scope_prefix = scope_name + "/" if scope_name else ""
    concat_full_name = scope_prefix + concat_name + ":0"
    for value in ops.get_collection(ops.GraphKeys.CONCATENATED_VARIABLES):
        if value.name == concat_full_name:
            return value

    concat_variable = array_ops.concat(0, sharded_variable, name=concat_name)
    ops.add_to_collection(ops.GraphKeys.CONCATENATED_VARIABLES,
                          concat_variable)
    return concat_variable
def _get_concat_variable(name, shape, dtype, num_shards):
    """Get a sharded variable concatenated into one tensor.

    The concatenated tensor is cached in the
    `GraphKeys.CONCATENATED_VARIABLES` collection. A later call whose
    expected tensor name matches a cached entry returns that entry instead of
    building another concat op.

    Args:
      name: base name of the sharded variable.
      shape: forwarded to `_get_sharded_variable`.
      dtype: forwarded to `_get_sharded_variable`.
      num_shards: forwarded to `_get_sharded_variable`.

    Returns:
      The only shard when `_get_sharded_variable` returns a single variable;
      otherwise the (possibly cached) concatenation of all shards along
      axis 0.
    """
    sharded_variable = _get_sharded_variable(name, shape, dtype, num_shards)
    if len(sharded_variable) == 1:
        return sharded_variable[0]

    concat_name = name + "/concat"
    scope_name = vs.get_variable_scope().name
    # The root variable scope has an empty name. Unconditionally prefixing
    # "/" would produce a name no tensor can have, so the cache lookup below
    # would never hit and a new concat op would be added on every call.
    scope_prefix = scope_name + "/" if scope_name else ""
    concat_full_name = scope_prefix + concat_name + ":0"
    for value in ops.get_collection(ops.GraphKeys.CONCATENATED_VARIABLES):
        if value.name == concat_full_name:
            return value

    concat_variable = array_ops.concat(0, sharded_variable, name=concat_name)
    ops.add_to_collection(ops.GraphKeys.CONCATENATED_VARIABLES,
                          concat_variable)
    return concat_variable
def testNoUpdatesWhenIsTrainingFalse(self):
    """batch_norm(is_training=False) must leave its moving statistics alone."""
    height = width = 3
    with self.test_session() as sess:
        shape = (10, height, width, 3)
        inputs = constant_op.constant(
            np.random.rand(*shape), shape=shape, dtype=dtypes.float32)
        normalized = _layers.batch_norm(inputs, decay=0.1, is_training=False)
        # No update ops may be registered in the UPDATE_OPS collection.
        registered = ops.get_collection(ops.GraphKeys.UPDATE_OPS)
        self.assertEqual(len(registered), 0)
        sess.run(variables_lib.global_variables_initializer())
        mean_var = variables.get_variables('BatchNorm/moving_mean')[0]
        variance_var = variables.get_variables('BatchNorm/moving_variance')[0]
        start_mean, start_variance = sess.run([mean_var, variance_var])
        # Freshly initialized: mean is all zeros and variance all ones.
        self.assertAllClose(start_mean, [0] * 3)
        self.assertAllClose(start_variance, [1] * 3)
        # Repeatedly evaluating the output must not move the statistics.
        for _ in range(10):
            sess.run([normalized])
        self.assertAllClose(mean_var.eval(), [0] * 3)
        self.assertAllClose(variance_var.eval(), [1] * 3)
def testNoneUpdatesCollectionNoTraining(self):
    """updates_collections=None with is_training=False keeps stats fixed."""
    height = width = 3
    with self.test_session() as sess:
        shape = (10, height, width, 3)
        inputs = constant_op.constant(
            np.random.rand(*shape), shape=shape, dtype=dtypes.float32)
        normalized = _layers.batch_norm(
            inputs, decay=0.1, updates_collections=None, is_training=False)
        # Nothing may be registered in the UPDATE_OPS collection.
        self.assertEqual(ops.get_collection(ops.GraphKeys.UPDATE_OPS), [])
        sess.run(variables_lib.global_variables_initializer())
        mean_var = variables.get_variables('BatchNorm/moving_mean')[0]
        variance_var = variables.get_variables('BatchNorm/moving_variance')[0]
        start_mean, start_variance = sess.run([mean_var, variance_var])
        # Freshly initialized: mean is all zeros and variance all ones.
        self.assertAllClose(start_mean, [0] * 3)
        self.assertAllClose(start_variance, [1] * 3)
        # Repeatedly evaluating the output must not move the statistics.
        for _ in range(10):
            sess.run([normalized])
        self.assertAllClose(mean_var.eval(), [0] * 3)
        self.assertAllClose(variance_var.eval(), [1] * 3)
def testCreateConvWithWeightDecay(self):
    """Both separable-conv kernels get their own small L2 regularization loss."""
    random_seed.set_random_seed(0)
    height = width = 3
    with self.test_session() as sess:
        images = random_ops.random_uniform((5, height, width, 3), seed=1)
        l2_reg = regularizers.l2_regularizer(0.01)
        layers_lib.separable_conv2d(
            images, 32, [3, 3], 2, weights_regularizer=l2_reg)
        reg_losses = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)
        self.assertEqual(len(reg_losses), 2)

        depthwise_loss = reg_losses[0]
        self.assertEqual(
            depthwise_loss.op.name,
            'SeparableConv2d/depthwise_kernel/Regularizer/l2_regularizer')
        sess.run(variables_lib.global_variables_initializer())
        self.assertLessEqual(sess.run(depthwise_loss), 0.05)

        pointwise_loss = reg_losses[1]
        self.assertEqual(
            pointwise_loss.op.name,
            'SeparableConv2d/pointwise_kernel/Regularizer/l2_regularizer')
        self.assertLessEqual(sess.run(pointwise_loss), 0.05)
def testReuseConvWithWeightDecay(self):
    """Reusing the 'conv1' scope must not add extra regularization losses."""
    height = width = 3

    def reg_loss_count():
        return len(ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES))

    with self.test_session():
        images = random_ops.random_uniform((5, height, width, 3), seed=1)
        l2_reg = regularizers.l2_regularizer(0.01)
        layers_lib.separable_conv2d(
            images, 32, [3, 3], 2, weights_regularizer=l2_reg, scope='conv1')
        self.assertEqual(reg_loss_count(), 2)
        # Second build with reuse=True shares the variables and their losses.
        layers_lib.separable_conv2d(
            images, 32, [3, 3], 2,
            weights_regularizer=l2_reg, scope='conv1', reuse=True)
        self.assertEqual(reg_loss_count(), 2)
def test_relu_layer_basic_use(self):
    """legacy_relu needs initialization and yields a non-negative [2, 8] output."""
    relu_out = layers_lib.legacy_relu(self.input, 8)
    with session.Session() as sess:
        # Evaluating before the variables are initialized must fail.
        with self.assertRaises(errors_impl.FailedPreconditionError):
            sess.run(relu_out)
        variables_lib.global_variables_initializer().run()
        values = sess.run(relu_out)
    self.assertEqual(relu_out.get_shape().as_list(), [2, 8])
    self.assertTrue(np.all(values >= 0), 'Relu should have all values >= 0.')
    trainable = ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
    reg_losses = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)
    self.assertEqual(2, len(trainable))
    self.assertEqual(0, len(reg_losses))
def test_regularizer_with_variable_reuse(self):
    """A regularizer runs once even when the layer's variables are reused."""
    calls = [0]
    penalty = constant_op.constant(5.0)

    def counting_regularizer(_):
        calls[0] += 1
        return penalty

    with variable_scope.variable_scope('test') as outer_scope:
        _layers.legacy_fully_connected(
            self.input, 2, weight_regularizer=counting_regularizer)
    with variable_scope.variable_scope(outer_scope, reuse=True):
        _layers.legacy_fully_connected(
            self.input, 2, weight_regularizer=counting_regularizer)
    self.assertEqual([penalty],
                     ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES))
    self.assertEqual(1, calls[0])
def testScatteredEmbeddingColumnSucceedsForDNN(self):
    """Gradients of a scattered embedding column reach its weight partitions."""
    wire_tensor = sparse_tensor.SparseTensor(
        values=["omar", "stringer", "marlo", "omar"],
        indices=[[0, 0], [1, 0], [1, 1], [2, 0]],
        dense_shape=[3, 2])
    features = {"wire": wire_tensor}
    # Big enough hash space so that hopefully there is no collision
    embedded_sparse = feature_column.scattered_embedding_column(
        "wire", 1000, 3, layers.SPARSE_FEATURE_CROSS_DEFAULT_HASH_KEY)
    output = feature_column_ops.input_from_feature_columns(
        features, [embedded_sparse], weight_collections=["my_collection"])
    weights = ops.get_collection("my_collection")
    grad = gradients_impl.gradients(output, weights)
    with self.test_session():
        variables_lib.global_variables_initializer().run()
        gradient_values = []
        # Collect the gradient from the different partitions (one in this
        # test). Each gradient exposes `.values` (presumably IndexedSlices).
        for partition_grad in grad:
            gradient_values.extend(partition_grad.values.eval())
        gradient_values.sort()
        self.assertAllEqual(gradient_values, [0.5] * 6 + [2] * 3)
def testInputLayerWithCollectionsForDNN(self):
    """Only the embedding column adds a variable to the custom collection."""
    price = feature_column.real_valued_column("price")
    price_buckets = feature_column.bucketized_column(
        price, boundaries=[0., 10., 100.])
    wire_hashed = feature_column.sparse_column_with_hash_bucket("wire", 10)
    features = {
        "price": constant_op.constant([[20.], [110], [-3]]),
        "wire": sparse_tensor.SparseTensor(
            values=["omar", "stringer", "marlo"],
            indices=[[0, 0], [1, 0], [2, 0]],
            dense_shape=[3, 1]),
    }
    wire_embedding = feature_column.embedding_column(wire_hashed, 10)
    feature_column_ops.input_from_feature_columns(
        features, [price, price_buckets, wire_embedding],
        weight_collections=["my_collection"])
    collected = ops.get_collection("my_collection")
    # One variable, for the embedded sparse column.
    self.assertEqual(1, len(collected))
def testVariablesAddedToCollection(self):
    """weighted_sum adds bias + one weight per column to the custom collection."""
    price_buckets = feature_column.bucketized_column(
        feature_column.real_valued_column("price"), boundaries=[0., 10., 100.])
    country_col = feature_column.sparse_column_with_hash_bucket(
        "country", hash_bucket_size=5)
    country_x_price = feature_column.crossed_column(
        [country_col, price_buckets], hash_bucket_size=10)
    with ops.Graph().as_default():
        features = {
            "price": constant_op.constant([[20.]]),
            "country": sparse_tensor.SparseTensor(
                values=["US", "SV"],
                indices=[[0, 0], [0, 1]],
                dense_shape=[1, 2]),
        }
        feature_column_ops.weighted_sum_from_feature_columns(
            features, [country_x_price, price_buckets],
            num_outputs=1,
            weight_collections=["my_collection"])
        collected = ops.get_collection("my_collection")
        # 3 = bias + price_bucket + country_price
        self.assertEqual(3, len(collected))
def benchmarkTfRNNLSTMBlockCellTraining(self):
    """Benchmark the gradient (training) graph of a multi-layer LSTMBlockCell RNN."""
    for config_name, config in self._GetTestConfig().items():
        num_layers = config["num_layers"]
        num_units = config["num_units"]
        batch_size = config["batch_size"]
        seq_length = config["seq_length"]
        with ops.Graph().as_default(), ops.device("/gpu:0"):
            # A single zeros tensor repeated seq_length times (one op, shared).
            step_input = array_ops.zeros([batch_size, num_units], dtypes.float32)
            inputs = seq_length * [step_input]
            stacked_cell = core_rnn_cell_impl.MultiRNNCell(
                [lstm_ops.LSTMBlockCell(num_units=num_units)
                 for _ in range(num_layers)])
            outputs, final_state = core_rnn.static_rnn(
                stacked_cell, inputs, dtype=dtypes.float32)
            trainables = ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
            grads = gradients_impl.gradients([outputs, final_state], trainables)
            train_op = control_flow_ops.group(*grads)
            description = "tf_rnn_lstm_block_cell %s %s" % (
                config_name, self._GetConfigDesc(config))
            self._BenchmarkOp(train_op, description)
def _get_arg_stack():
    """Return the stack stored under `_ARGSTACK_KEY`, creating it on first use.

    The stack is kept as the first item of that graph collection. When the
    collection is empty, a new stack holding one empty dict is registered and
    returned.
    """
    existing = ops.get_collection(_ARGSTACK_KEY)
    if not existing:
        fresh_stack = [{}]
        ops.add_to_collection(_ARGSTACK_KEY, fresh_stack)
        return fresh_stack
    return existing[0]
def _get_arg_stack():
    """Return the stack stored under `_ARGSTACK_KEY`, creating it on first use.

    The stack is kept as the first item of that graph collection. When the
    collection is empty, a new stack holding one empty dict is registered and
    returned.
    """
    existing = ops.get_collection(_ARGSTACK_KEY)
    if not existing:
        fresh_stack = [{}]
        ops.add_to_collection(_ARGSTACK_KEY, fresh_stack)
        return fresh_stack
    return existing[0]
def _get_arg_stack():
    """Return the stack stored under `_ARGSTACK_KEY`, creating it on first use.

    The stack is kept as the first item of that graph collection. When the
    collection is empty, a new stack holding one empty dict is registered and
    returned.
    """
    existing = ops.get_collection(_ARGSTACK_KEY)
    if not existing:
        fresh_stack = [{}]
        ops.add_to_collection(_ARGSTACK_KEY, fresh_stack)
        return fresh_stack
    return existing[0]
def get_global_step(graph=None):
    """Get the global step tensor.

    The global step tensor must be an integer variable. It is looked up first
    in the `GLOBAL_STEP` collection and, if that collection is empty, by the
    name `global_step:0`.

    Args:
      graph: The graph to find the global step in. If missing, use default graph.

    Returns:
      The global step variable, or `None` if none was found. `None` is also
      returned (after logging an error) when the collection holds more than
      one tensor.

    Raises:
      TypeError: If the global step tensor has a non-integer type, or if it is not
        a `Variable`.
    """
    if graph is None:
        graph = ops.get_default_graph()
    candidates = graph.get_collection(ops.GraphKeys.GLOBAL_STEP)
    if len(candidates) > 1:
        logging.error('Multiple tensors in global_step collection.')
        return None
    if candidates:
        step_tensor = candidates[0]
    else:
        try:
            step_tensor = graph.get_tensor_by_name('global_step:0')
        except KeyError:
            return None
    assert_global_step(step_tensor)
    return step_tensor
def add_model_variable(var):
    """Register `var` in the `GraphKeys.MODEL_VARIABLES` collection once.

    Args:
      var: a variable. Nothing happens if it is already in the collection.
    """
    key = ops.GraphKeys.MODEL_VARIABLES
    if var in ops.get_collection(key):
        return
    ops.add_to_collection(key, var)
def _stochastic_dependencies_map(fixed_losses, stochastic_tensors=None):
"""Map stochastic tensors to the fixed losses that depend on them.
Args:
fixed_losses: a list of `Tensor`s.
stochastic_tensors: a list of `StochasticTensor`s to map to fixed losses.
If `None`, all `StochasticTensor`s in the graph will be used.
Returns:
A dict `dependencies` that maps `StochasticTensor` objects to subsets of
`fixed_losses`.
If `loss in dependencies[st]`, for some `loss` in `fixed_losses` then there
is a direct path from `st.value()` to `loss` in the graph.
"""
stoch_value_collection = stochastic_tensors or ops.get_collection(
stochastic_tensor.STOCHASTIC_TENSOR_COLLECTION)
if not stoch_value_collection:
return {}
stoch_value_map = dict(
(node.value(), node) for node in stoch_value_collection)
# Step backwards through the graph to see which surrogate losses correspond
# to which fixed_losses.
#
# TODO(ebrevdo): Ensure that fixed_losses and stochastic values are in the
# same frame.
stoch_dependencies_map = collections.defaultdict(set)
for loss in fixed_losses:
boundary = set([loss])
while boundary:
edge = boundary.pop()
edge_stoch_node = stoch_value_map.get(edge, None)
if edge_stoch_node:
stoch_dependencies_map[edge_stoch_node].add(loss)
boundary.update(edge.op.inputs)
return stoch_dependencies_map
def get_losses(scope=None, loss_collection=ops.GraphKeys.LOSSES):
    """Return the tensors stored in `loss_collection`.

    Args:
      scope: optional filter forwarded to `ops.get_collection` as `scope`.
      loss_collection: collection key to read; defaults to GraphKeys.LOSSES.

    Returns:
      The list produced by `ops.get_collection`.
    """
    losses = ops.get_collection(loss_collection, scope=scope)
    return losses
def get_regularization_losses(scope=None):
    """Return the entries of the `REGULARIZATION_LOSSES` collection.

    Args:
      scope: optional filter forwarded to `ops.get_collection` as `scope`.

    Returns:
      The list produced by `ops.get_collection` for that collection.
    """
    key = ops.GraphKeys.REGULARIZATION_LOSSES
    return ops.get_collection(key, scope=scope)
def QueueRunners(session):
    """Creates a context manager that handles starting and stopping queue runners.

    Every queue runner in the `QUEUE_RUNNERS` collection gets daemon threads
    started against `session`. On exit the coordinator is asked to stop and
    the threads are joined (grace period 120 s). The module lock that
    prevents nesting is always released, even if thread creation or the join
    raises.

    NOTE(review): this is a generator; it only behaves as a context manager
    when wrapped with `contextlib.contextmanager` -- confirm the decorator is
    applied where this is defined.

    Args:
      session: the currently running session.

    Yields:
      a context in which queues are run.

    Raises:
      NestedQueueRunnerError: if a QueueRunners context is nested within another.
    """
    if not _queue_runner_lock.acquire(False):
        raise NestedQueueRunnerError('QueueRunners cannot be nested')
    try:
        coord = coordinator.Coordinator()
        threads = []
        try:
            for qr in ops.get_collection(ops.GraphKeys.QUEUE_RUNNERS):
                threads.extend(qr.create_threads(session,
                                                 coord=coord,
                                                 daemon=True,
                                                 start=True))
            yield
        finally:
            # Also runs when thread creation fails partway, so threads that
            # were already started are stopped too.
            coord.request_stop()
            coord.join(threads, stop_grace_period_secs=120)
    finally:
        # `coord.join` can re-raise errors reported by the threads; release
        # the lock regardless, or every later QueueRunners call would fail.
        _queue_runner_lock.release()
def dnn(tensor_in, hidden_units, activation=nn.relu, dropout=None):
    """Creates fully connected deep neural network subgraph.

    This is deprecated. Please use contrib.layers.dnn instead.

    Args:
      tensor_in: tensor or placeholder for input features.
      hidden_units: list of counts of hidden units in each layer.
      activation: activation function between layers. Can be None.
      dropout: if not None, will add a dropout layer with given probability.
        The layer is switched on or off by the (squeezed) value of the
        'IS_TRAINING' graph collection; `dropout_ops.dropout` receives
        `1.0 - dropout` as its `prob` argument.

    Returns:
      A tensor which would be a deep neural network.
    """
    logging.warning("learn.ops.dnn is deprecated, "
                    "please use contrib.layers.dnn.")
    # The IS_TRAINING flag is the same for every layer. It is built once, on
    # first use, instead of adding a new squeeze op per layer.
    is_training = None
    with vs.variable_scope('dnn'):
        for i, n_units in enumerate(hidden_units):
            with vs.variable_scope('layer%d' % i):
                # Weight initializer was set to None to replicate the behavior of
                # rnn_cell.linear. Using fully_connected's default initializer gets
                # slightly worse quality results on unit tests.
                tensor_in = layers.legacy_fully_connected(
                    tensor_in,
                    n_units,
                    weight_init=None,
                    weight_collections=['dnn_weights'],
                    bias_collections=['dnn_biases'])
                if activation is not None:
                    tensor_in = activation(tensor_in)
                if dropout is not None:
                    if is_training is None:
                        is_training = array_ops_.squeeze(
                            ops.get_collection('IS_TRAINING'))
                    # Bind this layer's output as a default argument so the
                    # branch functions do not rely on late binding of the loop
                    # variable.
                    tensor_in = control_flow_ops.cond(
                        is_training,
                        lambda x=tensor_in: dropout_ops.dropout(
                            x, prob=(1.0 - dropout)),
                        lambda x=tensor_in: x)
    return tensor_in
def every_n_step_begin(self, step):
    """Rebuild the map from trainable-variable name to value-tensor name.

    The lookup happens here rather than in __init__ because the train op has
    not been generated yet when the monitor is constructed.

    Returns:
      A list of the `var.value().name` strings just collected.
    """
    super(LoggingTrainable, self).every_n_step_begin(step)
    trainable_vars = ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES,
                                        scope=self._scope)
    self._names = {}
    self._names.update((v.name, v.value().name) for v in trainable_vars)
    return list(self._names.values())
def _centered_bias_step(self, targets, features):
    """Build a train op that fits only the centered-bias variable(s).

    The first variable of `self._centered_bias_weight_collection` is tiled
    per example and reshaped into logits. It is then trained with Adagrad on
    the target column's training loss.

    Args:
      targets: target tensor; its leading dimension sets the batch size.
      features: forwarded to the target column's `training_loss`.

    Returns:
      The op returned by `AdagradOptimizer.minimize`, restricted to the
      variables in the centered-bias collection.
    """
    bias_vars = ops.get_collection(self._centered_bias_weight_collection)
    num_rows = array_ops.shape(targets)[0]
    tiled_bias = array_ops.tile(bias_vars[0], [num_rows])
    logits = array_ops.reshape(
        tiled_bias, [num_rows, self._target_column.num_label_columns])
    with ops.name_scope(None, "centered_bias", (targets, features)):
        loss = self._target_column.training_loss(logits, targets, features)
        # Learn the central bias with its own optimizer; 0.1 is a
        # conservative learning rate for a single variable.
        optimizer = training.AdagradOptimizer(0.1)
        return optimizer.minimize(loss, var_list=bias_vars)
def _centered_bias_step(targets, loss_fn, num_label_columns):
    """Build an Adagrad train op over the "centered_bias" collection only.

    Args:
      targets: target tensor; its leading dimension sets the batch size.
      loss_fn: callable `(logits, targets) -> loss` to minimize.
      num_label_columns: width of the logits matrix.

    Returns:
      The op returned by `AdagradOptimizer(0.1).minimize`.
    """
    bias_vars = ops.get_collection("centered_bias")
    num_rows = array_ops.shape(targets)[0]
    tiled_bias = array_ops.tile(bias_vars[0], [num_rows])
    logits = array_ops.reshape(tiled_bias, [num_rows, num_label_columns])
    loss = loss_fn(logits, targets)
    optimizer = train.AdagradOptimizer(0.1)
    return optimizer.minimize(loss, var_list=bias_vars)
def _get_vars(self):
if self._get_feature_columns():
return ops.get_collection(self._scope)
return []
def _get_first_op_from_collection(collection_name):
    """Get first element from the collection, or None if it is empty."""
    items = ops.get_collection(collection_name)
    return items[0] if items else None
def _get_first_op_from_collection(collection_name):
    """Return the first item of `collection_name`, or None if it is empty."""
    items = ops.get_collection(collection_name)
    return items[0] if items else None
def _get_or_default(arg_name, collection_key, default_constructor):
    """Get from cache or create a default operation.

    Args:
      arg_name: name of the Scaffold constructor argument; used only in the
        error message.
      collection_key: graph collection used as the cache.
      default_constructor: zero-argument callable that builds the item when
        the collection is empty.

    Returns:
      The single cached item. Otherwise the result of `default_constructor()`,
      which is added to the collection unless it is None.

    Raises:
      RuntimeError: if the collection holds more than one item.
    """
    elements = ops.get_collection(collection_key)
    if elements:
        if len(elements) > 1:
            # Interpolate with `%`: exception constructors do not apply
            # logging-style format arguments, so passing them as extra args
            # left a literal "%s" in the message.
            raise RuntimeError('More than one item in the collection "%s". '
                               'Please indicate which one to use by passing it to '
                               'the tf.Scaffold constructor as: '
                               'tf.Scaffold(%s=item to use)' %
                               (collection_key, arg_name))
        return elements[0]
    op = default_constructor()
    if op is not None:
        ops.add_to_collection(collection_key, op)
    return op
def apply_regularization(regularizer, weights_list=None):
    """Returns the summed penalty by applying `regularizer` to the `weights_list`.

    Adding a regularization penalty over the layer weights and embedding weights
    can help prevent overfitting the training data. Regularization over layer
    biases is less common/useful, but assuming proper data preprocessing/mean
    subtraction, it usually shouldn't hurt much either.

    Args:
      regularizer: A function that takes a single `Tensor` argument and returns
        a scalar `Tensor` output.
      weights_list: List of weights `Tensors` or `Variables` to apply
        `regularizer` over. Defaults to the `GraphKeys.WEIGHTS` collection if
        `None` (an empty list also falls back to that collection).

    Returns:
      A scalar representing the overall regularization penalty. It is also
      added to the `GraphKeys.REGULARIZATION_LOSSES` collection.

    Raises:
      ValueError: If `regularizer` does not return a scalar output (including
        output of unknown rank), or if we find no weights.
    """
    if not weights_list:
        weights_list = ops.get_collection(ops.GraphKeys.WEIGHTS)
    if not weights_list:
        raise ValueError('No weights to regularize.')
    with ops.name_scope('get_regularization_penalty',
                        values=weights_list) as scope:
        penalties = [regularizer(w) for w in weights_list]
        for p in penalties:
            rank = p.get_shape().ndims
            if rank != 0:
                # `%s`, not `%d`: an unknown rank is None, and `%d` would
                # raise TypeError instead of the intended ValueError.
                raise ValueError('regularizer must return a scalar Tensor instead of a '
                                 'Tensor with rank %s.' % rank)
        summed_penalty = math_ops.add_n(penalties, name=scope)
        ops.add_to_collection(ops.GraphKeys.REGULARIZATION_LOSSES, summed_penalty)
        return summed_penalty
def is_summary_tag_unique(tag):
    """Checks if a summary tag is unique.

    The constant value of each summary's first input (converted to a Python
    object when it is an ndarray) is compared with `tag.encode()`.

    Args:
      tag: The tag to use (a `str`; it is encoded before comparison).

    Returns:
      True if no existing summary in `GraphKeys.SUMMARIES` carries this tag.
    """
    seen_tags = []
    for summary_op in ops.get_collection(ops.GraphKeys.SUMMARIES):
        value = tensor_util.constant_value(summary_op.op.inputs[0])
        seen_tags.append(value.tolist() if isinstance(value, np.ndarray) else value)
    return tag.encode() not in seen_tags
def summarize_collection(collection, name_filter=None,
                         summarizer=summarize_tensor):
    """Summarize a graph collection of tensors, possibly filtered by name.

    Args:
      collection: collection key whose tensors are summarized.
      name_filter: optional regex; only tensors whose op name `re.match`es it
        are kept. None keeps everything.
      summarizer: per-tensor summary function passed to `summarize_tensors`.

    Returns:
      Whatever `summarize_tensors` returns for the selected tensors.
    """
    selected = [t for t in ops.get_collection(collection)
                if name_filter is None or re.match(name_filter, t.op.name)]
    return summarize_tensors(selected, summarizer)
# Utility functions for commonly used collections
def add_model_variable(var):
    """Register `var` in the `GraphKeys.MODEL_VARIABLES` collection once.

    Args:
      var: a variable. Nothing happens if it is already in the collection.
    """
    key = ops.GraphKeys.MODEL_VARIABLES
    if var in ops.get_collection(key):
        return
    ops.add_to_collection(key, var)
def _stochastic_dependencies_map(fixed_losses, stochastic_tensors=None):
"""Map stochastic tensors to the fixed losses that depend on them.
Args:
fixed_losses: a list of `Tensor`s.
stochastic_tensors: a list of `StochasticTensor`s to map to fixed losses.
If `None`, all `StochasticTensor`s in the graph will be used.
Returns:
A dict `dependencies` that maps `StochasticTensor` objects to subsets of
`fixed_losses`.
If `loss in dependencies[st]`, for some `loss` in `fixed_losses` then there
is a direct path from `st.value()` to `loss` in the graph.
"""
stoch_value_collection = stochastic_tensors or ops.get_collection(
stochastic_tensor.STOCHASTIC_TENSOR_COLLECTION)
if not stoch_value_collection:
return {}
stoch_value_map = dict(
(node.value(), node) for node in stoch_value_collection)
# Step backwards through the graph to see which surrogate losses correspond
# to which fixed_losses.
#
# TODO(ebrevdo): Ensure that fixed_losses and stochastic values are in the
# same frame.
stoch_dependencies_map = collections.defaultdict(set)
for loss in fixed_losses:
boundary = set([loss])
while boundary:
edge = boundary.pop()
edge_stoch_node = stoch_value_map.get(edge, None)
if edge_stoch_node:
stoch_dependencies_map[edge_stoch_node].add(loss)
boundary.update(edge.op.inputs)
return stoch_dependencies_map
def get_losses(scope=None, loss_collection=ops.GraphKeys.LOSSES):
    """Return the tensors stored in `loss_collection`.

    Args:
      scope: optional filter forwarded to `ops.get_collection` as `scope`.
      loss_collection: collection key to read; defaults to GraphKeys.LOSSES.

    Returns:
      The list produced by `ops.get_collection`.
    """
    losses = ops.get_collection(loss_collection, scope=scope)
    return losses
def get_regularization_losses(scope=None):
    """Return the entries of the `REGULARIZATION_LOSSES` collection.

    Args:
      scope: optional filter forwarded to `ops.get_collection` as `scope`.

    Returns:
      The list produced by `ops.get_collection` for that collection.
    """
    key = ops.GraphKeys.REGULARIZATION_LOSSES
    return ops.get_collection(key, scope=scope)
def QueueRunners(session):
    """Creates a context manager that handles starting and stopping queue runners.

    Every queue runner in the `QUEUE_RUNNERS` collection gets daemon threads
    started against `session`. On exit the coordinator is asked to stop and
    the threads are joined (grace period 120 s). The module lock that
    prevents nesting is always released, even if thread creation or the join
    raises.

    NOTE(review): this is a generator; it only behaves as a context manager
    when wrapped with `contextlib.contextmanager` -- confirm the decorator is
    applied where this is defined.

    Args:
      session: the currently running session.

    Yields:
      a context in which queues are run.

    Raises:
      NestedQueueRunnerError: if a QueueRunners context is nested within another.
    """
    if not _queue_runner_lock.acquire(False):
        raise NestedQueueRunnerError('QueueRunners cannot be nested')
    try:
        coord = coordinator.Coordinator()
        threads = []
        try:
            for qr in ops.get_collection(ops.GraphKeys.QUEUE_RUNNERS):
                threads.extend(qr.create_threads(session,
                                                 coord=coord,
                                                 daemon=True,
                                                 start=True))
            yield
        finally:
            # Also runs when thread creation fails partway, so threads that
            # were already started are stopped too.
            coord.request_stop()
            coord.join(threads, stop_grace_period_secs=120)
    finally:
        # `coord.join` can re-raise errors reported by the threads; release
        # the lock regardless, or every later QueueRunners call would fail.
        _queue_runner_lock.release()
def every_n_step_begin(self, step):
    """Rebuild the map from trainable-variable name to value-tensor name.

    The lookup happens here rather than in __init__ because the train op has
    not been generated yet when the monitor is constructed.

    Returns:
      A list of the `var.value().name` strings just collected.
    """
    super(LoggingTrainable, self).every_n_step_begin(step)
    trainable_vars = ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES,
                                        scope=self._scope)
    self._names = {}
    self._names.update((v.name, v.value().name) for v in trainable_vars)
    return list(self._names.values())
def _get_qr(self, name):
    """Return the first queue runner in QUEUE_RUNNERS whose `.name` equals `name`.

    NOTE(review): no explicit fallback is visible for the no-match case; if
    nothing follows the loop, a miss implicitly returns None -- confirm.
    """
    for qr in ops.get_collection(ops.GraphKeys.QUEUE_RUNNERS):
        if qr.name == name:
            return qr