Python tensorflow.python.framework.ops 模块,add_to_collection() 实例源码
我们从Python开源项目中,提取了以下50个代码示例,用于说明如何使用tensorflow.python.framework.ops.add_to_collection()。
def register_prior(variational, prior):
  """Associate a variational `StochasticTensor` with a `BaseDistribution` prior.

  This is a helper function used in conjunction with `elbo` that allows users
  to specify the mapping between variational distributions and their priors
  without having to pass in `variational_with_prior` explicitly.

  The `(variational, prior)` pair is appended to the `VI_PRIORS` graph
  collection.

  Args:
    variational: `StochasticTensor` q(Z). Approximating distribution.
    prior: `BaseDistribution` p(Z). Prior distribution.

  Returns:
    None

  Raises:
    TypeError: if `variational` is not a `StochasticTensor` or `prior` is not
      a `BaseDistribution`.
  """
  if not isinstance(variational, st.StochasticTensor):
    # The message names the class that is actually checked.
    raise TypeError("variational must be a StochasticTensor")
  if not isinstance(prior, distributions.BaseDistribution):
    raise TypeError("prior must be a BaseDistribution")
  ops.add_to_collection(VI_PRIORS, (variational, prior))
def initialize(self, table):
  """Build the op that loads `self._keys` / `self._values` into `table`.

  The resulting op is also registered in `GraphKeys.TABLE_INITIALIZERS`.

  Args:
    table: The table to initialize.

  Returns:
    The operation that initializes the table.

  Raises:
    TypeError: when the keys and values data types do not match the table
      key and value data types.
  """
  key_tensor = self._keys
  value_tensor = self._values
  # pylint: disable=protected-access
  table._check_table_dtypes(key_tensor.dtype, value_tensor.dtype)
  with ops.name_scope(self._name, values=[table]) as scope:
    init_op = gen_data_flow_ops._initialize_table(
        table.table_ref, key_tensor, value_tensor, name=scope)
  # pylint: enable=protected-access
  ops.add_to_collection(ops.GraphKeys.TABLE_INITIALIZERS, init_op)
  return init_op
def _get_concat_variable(name, shape, dtype, num_shards):
  """Get a sharded variable concatenated into one tensor.

  The concatenated tensor is cached in the `CONCATENATED_VARIABLES` graph
  collection. A later call that computes the same full name returns the
  cached tensor instead of building another concat op.

  Args:
    name: Base name of the sharded variable.
    shape: Forwarded unchanged to `_get_sharded_variable`.
    dtype: Forwarded unchanged to `_get_sharded_variable`.
    num_shards: Forwarded unchanged to `_get_sharded_variable`.

  Returns:
    The only shard when there is a single one; otherwise the concatenation
    of all shards along dimension 0.
  """
  sharded_variable = _get_sharded_variable(name, shape, dtype, num_shards)
  if len(sharded_variable) == 1:
    return sharded_variable[0]

  concat_name = name + "/concat"
  # At the root variable scope the scope name is "". Joining it with "/"
  # unconditionally would give "/<name>/concat:0", which never matches a
  # real tensor name, so the cache would never hit.
  scope_name = vs.get_variable_scope().name
  scope_prefix = scope_name + "/" if scope_name else ""
  concat_full_name = scope_prefix + concat_name + ":0"
  for value in ops.get_collection(ops.GraphKeys.CONCATENATED_VARIABLES):
    if value.name == concat_full_name:
      return value

  # Old-style `concat(concat_dim, values)` argument order.
  concat_variable = array_ops.concat(0, sharded_variable, name=concat_name)
  ops.add_to_collection(ops.GraphKeys.CONCATENATED_VARIABLES,
                        concat_variable)
  return concat_variable
def register_prior(variational, prior):
  """Associate a variational `StochasticTensor` with a `Distribution` prior.

  This is a helper function used in conjunction with `elbo` that allows users
  to specify the mapping between variational distributions and their priors
  without having to pass in `variational_with_prior` explicitly.

  The `(variational, prior)` pair is appended to the `VI_PRIORS` graph
  collection.

  Args:
    variational: `StochasticTensor` q(Z). Approximating distribution.
    prior: `Distribution` p(Z). Prior distribution.

  Returns:
    None

  Raises:
    TypeError: if `variational` is not a `StochasticTensor` or `prior` is not
      a `Distribution`.
  """
  if not isinstance(variational, st.StochasticTensor):
    raise TypeError("variational must be a StochasticTensor")
  if not isinstance(prior, distributions.Distribution):
    raise TypeError("prior must be a Distribution")
  ops.add_to_collection(VI_PRIORS, (variational, prior))
def initialize(self, table):
  """Create and register the op that populates `table` from keys and values.

  The op is added to `GraphKeys.TABLE_INITIALIZERS` before being returned.

  Args:
    table: The table to initialize.

  Returns:
    The operation that initializes the table.

  Raises:
    TypeError: when the keys and values data types do not match the table
      key and value data types.
  """
  keys_t, values_t = self._keys, self._values
  # pylint: disable=protected-access
  table._check_table_dtypes(keys_t.dtype, values_t.dtype)
  with ops.name_scope(self._name, values=[table]) as scope:
    op = gen_data_flow_ops._initialize_table(
        table.table_ref, keys_t, values_t, name=scope)
  # pylint: enable=protected-access
  ops.add_to_collection(ops.GraphKeys.TABLE_INITIALIZERS, op)
  return op
def test_train_override_saver(self):
  """A saver registered in `GraphKeys.SAVERS` is the one used by training.

  The saver is a `Mock`, so no checkpoint is expected on disk afterwards.
  The test checks that the mock's `build` was called and that `save` was
  called exactly once.
  """
  with tf.Graph().as_default() as g, self.test_session(g):
    # Register the mock before the training graph is built.
    saver = tf.test.mock.Mock()
    tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
    with tf.control_dependencies(self._build_inference_graph()):
      train_op = tf.assign_add(tf.contrib.framework.get_global_step(), 1)
    self._assert_ckpt(self._output_dir, False)
    loss = learn.graph_actions._monitored_train(  # pylint: disable=protected-access
        g,
        output_dir=self._output_dir,
        train_op=train_op,
        loss_op=tf.constant(2.0),
        steps=1)
    self.assertEqual(2.0, loss)
    # The mock writes nothing, so there is still no checkpoint on disk.
    self._assert_ckpt(self._output_dir, False)
    self.assertTrue(saver.build.called)
    self.assertEqual(1, saver.save.call_count)
  # TODO(ispir): remove following tests after deprecated train.
def _get_concat_variable(name, shape, dtype, num_shards):
  """Get a sharded variable concatenated into one tensor.

  The concatenated tensor is cached in the `CONCATENATED_VARIABLES` graph
  collection. A later call that computes the same full name returns the
  cached tensor instead of building another concat op.

  Args:
    name: Base name of the sharded variable.
    shape: Forwarded unchanged to `_get_sharded_variable`.
    dtype: Forwarded unchanged to `_get_sharded_variable`.
    num_shards: Forwarded unchanged to `_get_sharded_variable`.

  Returns:
    The only shard when there is a single one; otherwise the concatenation
    of all shards along dimension 0.
  """
  sharded_variable = _get_sharded_variable(name, shape, dtype, num_shards)
  if len(sharded_variable) == 1:
    return sharded_variable[0]

  concat_name = name + "/concat"
  # At the root variable scope the scope name is "". Joining it with "/"
  # unconditionally would give "/<name>/concat:0", which never matches a
  # real tensor name, so the cache would never hit.
  scope_name = vs.get_variable_scope().name
  scope_prefix = scope_name + "/" if scope_name else ""
  concat_full_name = scope_prefix + concat_name + ":0"
  for value in ops.get_collection(ops.GraphKeys.CONCATENATED_VARIABLES):
    if value.name == concat_full_name:
      return value

  # Old-style `concat(concat_dim, values)` argument order.
  concat_variable = array_ops.concat(0, sharded_variable, name=concat_name)
  ops.add_to_collection(ops.GraphKeys.CONCATENATED_VARIABLES,
                        concat_variable)
  return concat_variable
def _get_concat_variable(name, shape, dtype, num_shards):
  """Get a sharded variable concatenated into one tensor.

  The concatenated tensor is cached in the `CONCATENATED_VARIABLES` graph
  collection. A later call that computes the same full name returns the
  cached tensor instead of building another concat op.

  Args:
    name: Base name of the sharded variable.
    shape: Forwarded unchanged to `_get_sharded_variable`.
    dtype: Forwarded unchanged to `_get_sharded_variable`.
    num_shards: Forwarded unchanged to `_get_sharded_variable`.

  Returns:
    The only shard when there is a single one; otherwise the concatenation
    of all shards along axis 0.
  """
  sharded_variable = _get_sharded_variable(name, shape, dtype, num_shards)
  if len(sharded_variable) == 1:
    return sharded_variable[0]

  concat_name = name + "/concat"
  # At the root variable scope the scope name is "". Joining it with "/"
  # unconditionally would give "/<name>/concat:0", which never matches a
  # real tensor name, so the cache would never hit.
  scope_name = vs.get_variable_scope().name
  scope_prefix = scope_name + "/" if scope_name else ""
  concat_full_name = scope_prefix + concat_name + ":0"
  for value in ops.get_collection(ops.GraphKeys.CONCATENATED_VARIABLES):
    if value.name == concat_full_name:
      return value

  concat_variable = array_ops.concat(sharded_variable, 0, name=concat_name)
  ops.add_to_collection(ops.GraphKeys.CONCATENATED_VARIABLES,
                        concat_variable)
  return concat_variable
def _get_concat_variable(name, shape, dtype, num_shards):
  """Get a sharded variable concatenated into one tensor.

  The concatenated tensor is cached in the `CONCATENATED_VARIABLES` graph
  collection. A later call that computes the same full name returns the
  cached tensor instead of building another concat op.

  Args:
    name: Base name of the sharded variable.
    shape: Forwarded unchanged to `_get_sharded_variable`.
    dtype: Forwarded unchanged to `_get_sharded_variable`.
    num_shards: Forwarded unchanged to `_get_sharded_variable`.

  Returns:
    The only shard when there is a single one; otherwise the concatenation
    of all shards along axis 0.
  """
  sharded_variable = _get_sharded_variable(name, shape, dtype, num_shards)
  if len(sharded_variable) == 1:
    return sharded_variable[0]

  concat_name = name + "/concat"
  # At the root variable scope the scope name is "". Joining it with "/"
  # unconditionally would give "/<name>/concat:0", which never matches a
  # real tensor name, so the cache would never hit.
  scope_name = vs.get_variable_scope().name
  scope_prefix = scope_name + "/" if scope_name else ""
  concat_full_name = scope_prefix + concat_name + ":0"
  for value in ops.get_collection(ops.GraphKeys.CONCATENATED_VARIABLES):
    if value.name == concat_full_name:
      return value

  concat_variable = array_ops.concat(sharded_variable, 0, name=concat_name)
  ops.add_to_collection(ops.GraphKeys.CONCATENATED_VARIABLES,
                        concat_variable)
  return concat_variable
def _get_concat_variable(name, shape, dtype, num_shards):
  """Get a sharded variable concatenated into one tensor.

  The concatenated tensor is cached in the `CONCATENATED_VARIABLES` graph
  collection. A later call that computes the same full name returns the
  cached tensor instead of building another concat op.

  Args:
    name: Base name of the sharded variable.
    shape: Forwarded unchanged to `_get_sharded_variable`.
    dtype: Forwarded unchanged to `_get_sharded_variable`.
    num_shards: Forwarded unchanged to `_get_sharded_variable`.

  Returns:
    The only shard when there is a single one; otherwise the concatenation
    of all shards along dimension 0.
  """
  sharded_variable = _get_sharded_variable(name, shape, dtype, num_shards)
  if len(sharded_variable) == 1:
    return sharded_variable[0]

  concat_name = name + "/concat"
  # At the root variable scope the scope name is "". Joining it with "/"
  # unconditionally would give "/<name>/concat:0", which never matches a
  # real tensor name, so the cache would never hit.
  scope_name = vs.get_variable_scope().name
  scope_prefix = scope_name + "/" if scope_name else ""
  concat_full_name = scope_prefix + concat_name + ":0"
  for value in ops.get_collection(ops.GraphKeys.CONCATENATED_VARIABLES):
    if value.name == concat_full_name:
      return value

  # Old-style `concat(concat_dim, values)` argument order.
  concat_variable = array_ops.concat(0, sharded_variable, name=concat_name)
  ops.add_to_collection(ops.GraphKeys.CONCATENATED_VARIABLES,
                        concat_variable)
  return concat_variable
def fc_layers(net,
              scope,
              end_points_collection,
              num_classes=1000,
              is_training=True,
              dropout_keep_prob=0.5,
              spatial_squeeze=True,
              name_prefix=None):
  """Build the final 1x1-convolution layer ('fc8').

  Args:
    net: Input feature tensor.
    scope: Not referenced by this function; kept for interface compatibility.
    end_points_collection: Name of the graph collection that receives the
      layer outputs.
    num_classes: Number of output channels of the 1x1 convolution.
    is_training: Not referenced by this function; kept for interface
      compatibility.
    dropout_keep_prob: Not referenced by this function; kept for interface
      compatibility.
    spatial_squeeze: If True, squeeze dimensions 1 and 2 of the output.
    name_prefix: Optional prefix. When given, every layer scope name becomes
      '<name_prefix>_<scope_name>'.

  Returns:
    A `(net, end_points_collection)` tuple: the un-activated fc8 output
    (squeezed if requested) and the collection name that was passed in.
  """

  def full_scope_name(scope_name):
    if name_prefix is None:
      return scope_name
    return '%s_%s' % (name_prefix, scope_name)

  # Use conv2d instead of fully_connected layers.
  with slim.arg_scope([slim.conv2d],
                      weights_initializer=trunc_normal(0.005),
                      biases_initializer=tf.constant_initializer(0.1),
                      outputs_collections=[end_points_collection]):
    net = slim.conv2d(net, num_classes, [1, 1],
                      activation_fn=None,
                      normalizer_fn=None,
                      biases_initializer=tf.zeros_initializer(),
                      scope=full_scope_name('fc8'))
    if spatial_squeeze:
      net = tf.squeeze(net, [1, 2], name=full_scope_name('fc8/squeezed'))
      # `outputs_collections` already registered the conv2d output. Only the
      # new squeezed tensor needs adding; otherwise the unsqueezed output
      # would be registered twice.
      ops.add_to_collection(end_points_collection, net)
  return net, end_points_collection
def _get_concat_variable(name, shape, dtype, num_shards):
  """Get a sharded variable concatenated into one tensor.

  The concatenated tensor is cached in the `CONCATENATED_VARIABLES` graph
  collection. A later call that computes the same full name returns the
  cached tensor instead of building another concat op.

  Args:
    name: Base name of the sharded variable.
    shape: Forwarded unchanged to `_get_sharded_variable`.
    dtype: Forwarded unchanged to `_get_sharded_variable`.
    num_shards: Forwarded unchanged to `_get_sharded_variable`.

  Returns:
    The only shard when there is a single one; otherwise the concatenation
    of all shards along dimension 0.
  """
  sharded_variable = _get_sharded_variable(name, shape, dtype, num_shards)
  if len(sharded_variable) == 1:
    return sharded_variable[0]

  concat_name = name + "/concat"
  # At the root variable scope the scope name is "". Joining it with "/"
  # unconditionally would give "/<name>/concat:0", which never matches a
  # real tensor name, so the cache would never hit.
  scope_name = vs.get_variable_scope().name
  scope_prefix = scope_name + "/" if scope_name else ""
  concat_full_name = scope_prefix + concat_name + ":0"
  for value in ops.get_collection(ops.GraphKeys.CONCATENATED_VARIABLES):
    if value.name == concat_full_name:
      return value

  # Old-style `concat(concat_dim, values)` argument order.
  concat_variable = array_ops.concat(0, sharded_variable, name=concat_name)
  ops.add_to_collection(ops.GraphKeys.CONCATENATED_VARIABLES,
                        concat_variable)
  return concat_variable
def _get_concat_variable(name, shape, dtype, num_shards):
  """Get a sharded variable concatenated into one tensor.

  The concatenated tensor is cached in the `CONCATENATED_VARIABLES` graph
  collection. A later call that computes the same full name returns the
  cached tensor instead of building another concat op.

  Args:
    name: Base name of the sharded variable.
    shape: Forwarded unchanged to `_get_sharded_variable`.
    dtype: Forwarded unchanged to `_get_sharded_variable`.
    num_shards: Forwarded unchanged to `_get_sharded_variable`.

  Returns:
    The only shard when there is a single one; otherwise the concatenation
    of all shards along dimension 0.
  """
  sharded_variable = _get_sharded_variable(name, shape, dtype, num_shards)
  if len(sharded_variable) == 1:
    return sharded_variable[0]

  concat_name = name + "/concat"
  # At the root variable scope the scope name is "". Joining it with "/"
  # unconditionally would give "/<name>/concat:0", which never matches a
  # real tensor name, so the cache would never hit.
  scope_name = vs.get_variable_scope().name
  scope_prefix = scope_name + "/" if scope_name else ""
  concat_full_name = scope_prefix + concat_name + ":0"
  for value in ops.get_collection(ops.GraphKeys.CONCATENATED_VARIABLES):
    if value.name == concat_full_name:
      return value

  # Old-style `concat(concat_dim, values)` argument order.
  concat_variable = array_ops.concat(0, sharded_variable, name=concat_name)
  ops.add_to_collection(ops.GraphKeys.CONCATENATED_VARIABLES,
                        concat_variable)
  return concat_variable
def _get_concat_variable(name, shape, dtype, num_shards):
  """Get a sharded variable concatenated into one tensor.

  The concatenated tensor is cached in the `CONCATENATED_VARIABLES` graph
  collection. A later call that computes the same full name returns the
  cached tensor instead of building another concat op.

  Args:
    name: Base name of the sharded variable.
    shape: Forwarded unchanged to `_get_sharded_variable`.
    dtype: Forwarded unchanged to `_get_sharded_variable`.
    num_shards: Forwarded unchanged to `_get_sharded_variable`.

  Returns:
    The only shard when there is a single one; otherwise the concatenation
    of all shards along dimension 0.
  """
  sharded_variable = _get_sharded_variable(name, shape, dtype, num_shards)
  if len(sharded_variable) == 1:
    return sharded_variable[0]

  concat_name = name + "/concat"
  # At the root variable scope the scope name is "". Joining it with "/"
  # unconditionally would give "/<name>/concat:0", which never matches a
  # real tensor name, so the cache would never hit.
  scope_name = vs.get_variable_scope().name
  scope_prefix = scope_name + "/" if scope_name else ""
  concat_full_name = scope_prefix + concat_name + ":0"
  for value in ops.get_collection(ops.GraphKeys.CONCATENATED_VARIABLES):
    if value.name == concat_full_name:
      return value

  # Old-style `concat(concat_dim, values)` argument order.
  concat_variable = array_ops.concat(0, sharded_variable, name=concat_name)
  ops.add_to_collection(ops.GraphKeys.CONCATENATED_VARIABLES,
                        concat_variable)
  return concat_variable
def _create_slots(self, grads_and_vars):
  """Create the per-variable optimizer slots.

  For every `(gradient, variable)` pair it creates:
    * an 's' slot via `_ones_slot` when `self._save_step` is truthy;
    * a 'g' slot via `_ones_slot` when `self._save_grad` is truthy;
    * when `self._chi > 0`, an 'x' slot via `_zeros_slot`, which is also
      added to `GraphKeys.MOVING_AVERAGE_VARIABLES`, plus an 'x/tm1' slot.
      The 'x/tm1' slot is built with `_zeros_slot` when the gradient is an
      `ops.Tensor` and with `_zeros_idx_slot` otherwise.

  Args:
    grads_and_vars: Iterable of `(gradient, variable)` pairs.
  """
  for g_t, x_tm1 in grads_and_vars:
    if self._save_step:
      self._ones_slot(x_tm1, 's', self._name)
    if self._save_grad:
      self._ones_slot(x_tm1, 'g', self._name)
    if self._chi > 0:
      x_slot = self._zeros_slot(x_tm1, 'x', self._name)
      # add_to_collection's signature is (collection_name, value).
      ops.add_to_collection(ops.GraphKeys.MOVING_AVERAGE_VARIABLES, x_slot)
      if isinstance(g_t, ops.Tensor):
        self._zeros_slot(x_tm1, 'x/tm1', self._name)
      else:
        self._zeros_idx_slot(x_tm1, 'x/tm1', self._name)
#=============================================================
def _create_slots(self, grads_and_vars):
  """Create the per-variable optimizer slots.

  For every `(gradient, variable)` pair it creates:
    * an 's' slot via `_ones_slot` when `self._save_step` is truthy;
    * a 'g' slot via `_ones_slot` when `self._save_grad` is truthy;
    * when `self._chi > 0`, an 'x' slot via `_zeros_slot`, which is also
      added to `GraphKeys.MOVING_AVERAGE_VARIABLES`, plus an 'x/tm1' slot.
      The 'x/tm1' slot is built with `_zeros_slot` when the gradient is an
      `ops.Tensor` and with `_zeros_idx_slot` otherwise.

  Args:
    grads_and_vars: Iterable of `(gradient, variable)` pairs.
  """
  for g_t, x_tm1 in grads_and_vars:
    if self._save_step:
      self._ones_slot(x_tm1, 's', self._name)
    if self._save_grad:
      self._ones_slot(x_tm1, 'g', self._name)
    if self._chi > 0:
      x_slot = self._zeros_slot(x_tm1, 'x', self._name)
      # add_to_collection's signature is (collection_name, value).
      ops.add_to_collection(ops.GraphKeys.MOVING_AVERAGE_VARIABLES, x_slot)
      if isinstance(g_t, ops.Tensor):
        self._zeros_slot(x_tm1, 'x/tm1', self._name)
      else:
        self._zeros_idx_slot(x_tm1, 'x/tm1', self._name)
#=============================================================
def register_prior(variational, prior):
  """Associate a variational `StochasticTensor` with a `Distribution` prior.

  This is a helper function used in conjunction with `elbo` that allows users
  to specify the mapping between variational distributions and their priors
  without having to pass in `variational_with_prior` explicitly.

  The `(variational, prior)` pair is appended to the `VI_PRIORS` graph
  collection.

  Args:
    variational: `StochasticTensor` q(Z). Approximating distribution.
    prior: `Distribution` p(Z). Prior distribution.

  Returns:
    None

  Raises:
    TypeError: if `variational` is not a `StochasticTensor` or `prior` is not
      a `Distribution`.
  """
  if not isinstance(variational, st.StochasticTensor):
    raise TypeError("variational must be a StochasticTensor")
  if not isinstance(prior, distribution.Distribution):
    raise TypeError("prior must be a Distribution")
  ops.add_to_collection(VI_PRIORS, (variational, prior))
def initialize(self, table):
  """Create the op that fills `table` with `self._keys` and `self._values`.

  The op is recorded in `GraphKeys.TABLE_INITIALIZERS` and returned.

  Args:
    table: The table to initialize.

  Returns:
    The operation that initializes the table.

  Raises:
    TypeError: when the keys and values data types do not match the table
      key and value data types.
  """
  key_tensor = self._keys
  value_tensor = self._values
  table.check_table_dtypes(key_tensor.dtype, value_tensor.dtype)
  with ops.name_scope(self._name, values=[table]) as scope:
    # pylint: disable=protected-access
    init_op = gen_data_flow_ops._initialize_table(
        table.table_ref, key_tensor, value_tensor, name=scope)
    # pylint: enable=protected-access
  ops.add_to_collection(ops.GraphKeys.TABLE_INITIALIZERS, init_op)
  return init_op
def test_evaluate_ready_for_local_init(self):
  """`evaluate` runs with a user-supplied READY_FOR_LOCAL_INIT_OP registered.

  No result is asserted; the test only checks that `evaluate` completes.
  """
  with ops.Graph().as_default() as g, self.test_session(g):
    variables_lib.create_global_step()
    v = variables.Variable(1.0)
    # A local variable whose initial value (v + 1) reads the global `v`.
    # `w` is never referenced again; only its creation matters.
    w = variables.Variable(
        v + 1, collections=[ops.GraphKeys.LOCAL_VARIABLES], trainable=False)
    # Local init is allowed once every global variable is initialized.
    ready_for_local_init_op = variables.report_uninitialized_variables(
        variables.global_variables())
    ops.add_to_collection(ops.GraphKeys.READY_FOR_LOCAL_INIT_OP,
                          ready_for_local_init_op)
    _ = learn.graph_actions.evaluate(
        g,
        output_dir=self._output_dir,
        checkpoint_path=None,
        eval_dict={'a': v},
        max_steps=1)
def test_evaluate_with_saver(self):
  """`evaluate` works when a `Saver` is already registered in SAVERS.

  Checks the returned `({'a': 6.0}, 0)` and the summary written for step 0.
  """
  with ops.Graph().as_default() as g, self.test_session(g):
    _, _, out = self._build_inference_graph()
    ops.add_to_collection(ops.GraphKeys.SAVERS, saver_lib.Saver())
    writer = learn.graph_actions.get_summary_writer(self._output_dir)
    # Nothing has been written before evaluation.
    self._assert_summaries(self._output_dir, writer, expected_session_logs=[])
    results = learn.graph_actions.evaluate(
        g,
        output_dir=self._output_dir,
        checkpoint_path=None,
        eval_dict={'a': out},
        max_steps=1)
    self.assertEqual(({'a': 6.0}, 0), results)
    self._assert_summaries(
        self._output_dir,
        writer,
        expected_summaries={0: {
            'a': 6.0
        }},
        expected_session_logs=[])
def test_train_override_saver(self):
  """A saver registered in `GraphKeys.SAVERS` is the one used by training.

  The mock wraps a real `Saver`, so a checkpoint is expected on disk after
  one step. The mock's `build` must have been called and `save` called once.
  """
  with ops.Graph().as_default() as g, self.test_session(g):
    with ops.control_dependencies(self._build_inference_graph()):
      train_op = state_ops.assign_add(variables_lib.get_global_step(), 1)
    self._assert_ckpt(self._output_dir, False)
    # Delegate to a real saver while still recording calls; `saver_def` is
    # copied so the mock exposes the real saver's definition.
    real_saver = saver_lib.Saver()
    saver = test.mock.Mock(wraps=real_saver, saver_def=real_saver.saver_def)
    ops.add_to_collection(ops.GraphKeys.SAVERS, saver)
    loss = learn.graph_actions._monitored_train(  # pylint: disable=protected-access
        g,
        output_dir=self._output_dir,
        train_op=train_op,
        loss_op=constant_op.constant(2.0),
        steps=1)
    self.assertEqual(2.0, loss)
    self._assert_ckpt(self._output_dir, True)
    self.assertTrue(saver.build.called)
    self.assertEqual(1, saver.save.call_count)
  # TODO(ispir): remove following tests after deprecated train.
def testUpdateOpFromCollection(self):
  """Ops in `GraphKeys.UPDATE_OPS` run as part of the `optimize_loss` op.

  Repeated for an optimizer given as a name, as a class, and as an instance.
  """
  optimizers = [
      "SGD", gradient_descent.GradientDescentOptimizer,
      gradient_descent.GradientDescentOptimizer(learning_rate=0.1)
  ]
  for optimizer in optimizers:
    with ops.Graph().as_default() as g, self.test_session(graph=g) as session:
      x, var, loss, global_step = _setup_model()
      update_var = variable_scope.get_variable(
          "update", [], initializer=init_ops.constant_initializer(10))
      update_op = state_ops.assign(update_var, 20)
      # Only registered in the collection; never run directly by the test.
      ops.add_to_collection(ops.GraphKeys.UPDATE_OPS, update_op)
      train = optimizers_lib.optimize_loss(
          loss, global_step, learning_rate=0.1, optimizer=optimizer)
      variables.global_variables_initializer().run()
      session.run(train, feed_dict={x: 5})
      var_value, update_var_value, global_step_value = session.run(
          [var, update_var, global_step])
      self.assertEqual(var_value, 9.5)
      # 20 means the collected update op was executed by `train`.
      self.assertEqual(update_var_value, 20)
      self.assertEqual(global_step_value, 1)
def _get_arg_stack():
  """Return the arg-scope stack kept in the `_ARGSTACK_KEY` collection.

  On first use, a fresh stack `[{}]` is stored in the collection. Later calls
  return that same list object.
  """
  registered = ops.get_collection(_ARGSTACK_KEY)
  if not registered:
    fresh_stack = [{}]
    ops.add_to_collection(_ARGSTACK_KEY, fresh_stack)
    return fresh_stack
  return registered[0]
def _get_arg_stack():
  """Return the arg-scope stack kept in the `_ARGSTACK_KEY` collection.

  On first use, a fresh stack `[{}]` is stored in the collection. Later calls
  return that same list object.
  """
  registered = ops.get_collection(_ARGSTACK_KEY)
  if not registered:
    fresh_stack = [{}]
    ops.add_to_collection(_ARGSTACK_KEY, fresh_stack)
    return fresh_stack
  return registered[0]
def _get_arg_stack():
  """Return the arg-scope stack kept in the `_ARGSTACK_KEY` collection.

  On first use, a fresh stack `[{}]` is stored in the collection. Later calls
  return that same list object.
  """
  registered = ops.get_collection(_ARGSTACK_KEY)
  if not registered:
    fresh_stack = [{}]
    ops.add_to_collection(_ARGSTACK_KEY, fresh_stack)
    return fresh_stack
  return registered[0]
def build_image_summary(self):
  """Build a small graph that emits one image summary.

  Returns:
    A `(log_image, log_image_data, log_image_name)` tuple: the image-summary
    op (also added to `GraphKeys.SUMMARIES`), a `uint8` placeholder of shape
    `[None, None, 3]` for the image, and a string placeholder used as the
    summary tag.
  """
  # The data placeholder is created before the name placeholder so that
  # automatic op naming is unchanged.
  log_image_data = tf.placeholder(tf.uint8, [None, None, 3])
  log_image_name = tf.placeholder(tf.string)
  # pylint: disable=g-import-not-at-top
  from tensorflow.python.framework import ops as _ops
  from tensorflow.python.ops import gen_logging_ops
  # pylint: enable=g-import-not-at-top
  # Add a leading batch dimension of size 1.
  batched_image = tf.expand_dims(log_image_data, 0)
  log_image = gen_logging_ops._image_summary(  # pylint: disable=protected-access
      log_image_name, batched_image, max_images=1)
  _ops.add_to_collection(_ops.GraphKeys.SUMMARIES, log_image)
  return log_image, log_image_data, log_image_name
def add_model_variable(var):
  """Add `var` to `GraphKeys.MODEL_VARIABLES` unless it is already there.

  Args:
    var: a variable.
  """
  registered = ops.get_collection(ops.GraphKeys.MODEL_VARIABLES)
  if var in registered:
    return
  ops.add_to_collection(ops.GraphKeys.MODEL_VARIABLES, var)
def __init__(self):
  """Register this object in the graph's stochastic tensor collection."""
  # Registration lets a later surrogate-loss computation find this tensor.
  collection_key = STOCHASTIC_TENSOR_COLLECTION
  ops.add_to_collection(collection_key, self)
def add_loss(loss, loss_collection=ops.GraphKeys.LOSSES):
  """Add an externally defined loss to a loss collection.

  Args:
    loss: A loss `Tensor`.
    loss_collection: Collection to add the loss to. A falsy value (for
      example None) skips registration.
  """
  if not loss_collection:
    return
  ops.add_to_collection(loss_collection, loss)
def test_evaluate_with_saver(self):
  """`evaluate` works when a `tf.train.Saver` is already registered in SAVERS.

  Checks the returned `({'a': 6.0}, 0)` and the summary written for step 0.
  """
  with tf.Graph().as_default() as g, self.test_session(g):
    _, _, out = self._build_inference_graph()
    tf.add_to_collection(tf.GraphKeys.SAVERS, tf.train.Saver())
    # Nothing has been written before evaluation.
    self._assert_summaries(self._output_dir, expected_session_logs=[])
    results = learn.graph_actions.evaluate(
        g, output_dir=self._output_dir, checkpoint_path=None,
        eval_dict={'a': out}, max_steps=1)
    self.assertEqual(({'a': 6.0}, 0), results)
    self._assert_summaries(
        self._output_dir, expected_summaries={0: {'a': 6.0}},
        expected_session_logs=[])
def test_train_worker_monitor(self):
  """On a non-chief worker only the all-workers monitor becomes active.

  Trains one step with `supervisor_is_chief=False` and checks which of the
  two monitor wrappers was activated and received a step.
  """
  # We need to explicitly set device due to check on non-chief workers
  # requiring all variables to have a device assigned.
  with tf.Graph().as_default() as g, g.device('/cpu:0'):
    global_step = tf.contrib.framework.create_global_step(g)
    train_op = tf.assign_add(global_step, 1)
    loss_op = tf.constant(2.0)
    tf.scalar_summary('loss', loss_op)
    # Add explicit "local" init op to initialize all variables
    # as there's no chief to init here.
    init_op = variables.initialize_all_variables()
    ops.add_to_collection(ops.GraphKeys.LOCAL_INIT_OP, init_op)
    # Create worker monitors where one should be active on the worker
    # and the other chief exclusive.
    chief_exclusive_monitor = _BaseMonitorWrapper(False)
    all_workers_monitor = _BaseMonitorWrapper(True)
    with self.test_session(g):
      loss = learn.graph_actions.train(
          g, output_dir=self._output_dir,
          global_step_tensor=global_step,
          train_op=train_op, loss_op=loss_op,
          supervisor_is_chief=False, steps=1,
          monitors=[chief_exclusive_monitor, all_workers_monitor])
      self.assertEqual(2.0, loss)
      self.assertTrue(not chief_exclusive_monitor.is_active and
                      all_workers_monitor.is_active,
                      'Only non-chief runnable monitor must have been active.')
      self.assertTrue(not chief_exclusive_monitor.has_step and
                      all_workers_monitor.has_step,
                      'Only non-chief runnable monitor must have a step.')
def _get_saver():
  """Lazy init and return saver.

  Returns the saver found in `GraphKeys.SAVERS`. If none is found and the
  graph has variables, a new `Saver` is created and added to that
  collection. Returns None when there is neither a saver nor any variable.
  """
  saver = _get_first_op_from_collection(ops.GraphKeys.SAVERS)
  # `_get_first_op_from_collection` presumably already returns a single
  # element (verify against its definition). Indexing a `Saver` would raise
  # TypeError, so only unwrap when a list/tuple actually comes back.
  if isinstance(saver, (list, tuple)):
    saver = saver[0] if saver else None
  if saver is None and variables.all_variables():
    saver = tf_saver.Saver()
    ops.add_to_collection(ops.GraphKeys.SAVERS, saver)
  return saver
def _get_saver():
  """Return the graph's saver, creating one lazily when needed.

  The value from `_get_first_op_from_collection(GraphKeys.SAVERS)` is reused
  when it is not None. Otherwise, if the graph has any variables, a new
  `Saver` is built and registered in that collection. With neither, None is
  returned.
  """
  saver = _get_first_op_from_collection(ops.GraphKeys.SAVERS)
  if saver is not None or not variables.all_variables():
    return saver
  saver = tf_saver.Saver()
  ops.add_to_collection(ops.GraphKeys.SAVERS, saver)
  return saver
def _get_local_init_op():
  """Return the graph's local init op, building and registering it if absent.

  When `GraphKeys.LOCAL_INIT_OP` yields nothing, a group of the
  local-variable initializer and the all-tables initializer is created and
  added to that collection.
  """
  local_init_op = _get_first_op_from_collection(ops.GraphKeys.LOCAL_INIT_OP)
  if local_init_op is not None:
    return local_init_op
  # Both initializers are always built, so the group always has members.
  local_init_op = control_flow_ops.group(
      variables.initialize_local_variables(),
      data_flow_ops.initialize_all_tables())
  ops.add_to_collection(ops.GraphKeys.LOCAL_INIT_OP, local_init_op)
  return local_init_op
def _get_or_default(arg_name, collection_key, default_constructor):
  """Get from cache or create a default operation.

  Args:
    arg_name: Name of the `Scaffold` constructor argument; used only in the
      error message.
    collection_key: Graph collection that caches the item.
    default_constructor: Zero-argument callable that builds the default
      item. It may return None.

  Returns:
    The single item already in `collection_key`. Otherwise the result of
    `default_constructor()`, which is added to the collection unless it is
    None.

  Raises:
    RuntimeError: if `collection_key` holds more than one item.
  """
  elements = ops.get_collection(collection_key)
  if elements:
    if len(elements) > 1:
      # Interpolate with `%`. Passing the values as extra RuntimeError
      # arguments would leave the %s placeholders unformatted.
      raise RuntimeError('More than one item in the collection "%s". '
                         'Please indicate which one to use by passing it to '
                         'the tf.Scaffold constructor as: '
                         'tf.Scaffold(%s=item to use)' %
                         (collection_key, arg_name))
    return elements[0]
  op = default_constructor()
  if op is not None:
    ops.add_to_collection(collection_key, op)
  return op
def apply_regularization(regularizer, weights_list=None):
  """Returns the summed penalty by applying `regularizer` to the `weights_list`.

  Adding a regularization penalty over the layer weights and embedding weights
  can help prevent overfitting the training data. Regularization over layer
  biases is less common/useful, but assuming proper data preprocessing/mean
  subtraction, it usually shouldn't hurt much either.

  The summed penalty is also added to `GraphKeys.REGULARIZATION_LOSSES`.

  Args:
    regularizer: A function that takes a single `Tensor` argument and returns
      a scalar `Tensor` output, or None (treated as a zero penalty).
    weights_list: List of weights `Tensors` or `Variables` to apply
      `regularizer` over. Defaults to the `GraphKeys.WEIGHTS` collection if
      `None` or empty.

  Returns:
    A scalar representing the overall regularization penalty.

  Raises:
    ValueError: If `regularizer` does not return a scalar output, or if we find
      no weights.
  """
  # pylint: disable=g-import-not-at-top
  from tensorflow.python.framework import constant_op
  # pylint: enable=g-import-not-at-top
  if not weights_list:
    weights_list = ops.get_collection(ops.GraphKeys.WEIGHTS)
  if not weights_list:
    raise ValueError('No weights to regularize.')
  with ops.name_scope('get_regularization_penalty',
                      values=weights_list) as scope:
    penalties = [regularizer(w) for w in weights_list]
    # A None penalty counts as zero instead of failing on `None.get_shape()`.
    penalties = [
        p if p is not None else constant_op.constant(0.0) for p in penalties
    ]
    for p in penalties:
      ndims = p.get_shape().ndims
      if ndims != 0:
        # `%s`, not `%d`: ndims is None for unknown rank, and `%d` would
        # raise TypeError instead of the intended ValueError.
        raise ValueError('regularizer must return a scalar Tensor instead of a '
                         'Tensor with rank %s.' % ndims)
    summed_penalty = math_ops.add_n(penalties, name=scope)
  ops.add_to_collection(ops.GraphKeys.REGULARIZATION_LOSSES, summed_penalty)
  return summed_penalty
def add_model_variable(var):
  """Register `var` in `GraphKeys.MODEL_VARIABLES`, skipping duplicates.

  Args:
    var: a variable.
  """
  already_present = var in ops.get_collection(ops.GraphKeys.MODEL_VARIABLES)
  if not already_present:
    ops.add_to_collection(ops.GraphKeys.MODEL_VARIABLES, var)
def add_loss(loss, loss_collection=ops.GraphKeys.LOSSES):
  """Record an externally defined loss in a loss collection.

  Args:
    loss: A loss `Tensor`.
    loss_collection: Collection to add the loss to. When falsy (for example
      None), nothing is recorded.
  """
  if not loss_collection:
    return
  ops.add_to_collection(loss_collection, loss)
def initialize(self, table):
  """Build and register the op that fills `table` from `self._filename`.

  The init op is added to `GraphKeys.TABLE_INITIALIZERS`, and the filename
  tensor is added to `GraphKeys.ASSET_FILEPATHS`.

  Args:
    table: The table to be initialized.

  Returns:
    The operation that initializes the table.

  Raises:
    TypeError: when the keys and values data types do not match the table
      key and value data types.
  """
  # pylint: disable=protected-access
  table._check_table_dtypes(self.key_dtype, self.value_dtype)
  # -1 is passed to the op when no vocabulary size is configured.
  vocab_size = self._vocab_size if self._vocab_size is not None else -1
  with ops.name_scope(self._name, "text_file_init", [table]) as scope:
    filename = ops.convert_to_tensor(
        self._filename, dtypes.string, name="asset_filepath")
    init_op = gen_data_flow_ops._initialize_table_from_text_file(
        table.table_ref, filename, self._key_index, self._value_index,
        vocab_size, self._delimiter, name=scope)
  # pylint: enable=protected-access
  ops.add_to_collection(ops.GraphKeys.TABLE_INITIALIZERS, init_op)
  ops.add_to_collection(ops.GraphKeys.ASSET_FILEPATHS, filename)
  return init_op
def test_evaluate_with_saver(self):
  """`evaluate` works when a `tf.train.Saver` is already registered in SAVERS.

  Checks the returned `({'a': 6.0}, 0)` and the summary written for step 0.
  """
  with tf.Graph().as_default() as g, self.test_session(g):
    _, _, out = self._build_inference_graph()
    tf.add_to_collection(tf.GraphKeys.SAVERS, tf.train.Saver())
    writer = learn.graph_actions.get_summary_writer(self._output_dir)
    # Nothing has been written before evaluation.
    self._assert_summaries(self._output_dir, writer, expected_session_logs=[])
    results = learn.graph_actions.evaluate(
        g, output_dir=self._output_dir, checkpoint_path=None,
        eval_dict={'a': out}, max_steps=1)
    self.assertEqual(({'a': 6.0}, 0), results)
    self._assert_summaries(
        self._output_dir, writer, expected_summaries={0: {'a': 6.0}},
        expected_session_logs=[])
def test_train_worker_monitor(self):
  """On a non-chief worker only the all-workers monitor becomes active.

  Trains one step with `supervisor_is_chief=False` and checks which of the
  two monitor wrappers was activated and received a step.
  """
  # We need to explicitly set device due to check on non-chief workers
  # requiring all variables to have a device assigned.
  with tf.Graph().as_default() as g, g.device('/cpu:0'):
    global_step = tf.contrib.framework.create_global_step(g)
    train_op = tf.assign_add(global_step, 1)
    loss_op = tf.constant(2.0)
    tf.summary.scalar('loss', loss_op)
    # Add explicit "local" init op to initialize all variables
    # as there's no chief to init here.
    init_op = variables.global_variables_initializer()
    ops.add_to_collection(ops.GraphKeys.LOCAL_INIT_OP, init_op)
    # Create worker monitors where one should be active on the worker
    # and the other chief exclusive.
    chief_exclusive_monitor = _BaseMonitorWrapper(False)
    all_workers_monitor = _BaseMonitorWrapper(True)
    with self.test_session(g):
      loss = learn.graph_actions.train(
          g, output_dir=self._output_dir,
          global_step_tensor=global_step,
          train_op=train_op, loss_op=loss_op,
          supervisor_is_chief=False, steps=1,
          monitors=[chief_exclusive_monitor, all_workers_monitor])
      self.assertEqual(2.0, loss)
      self.assertTrue(not chief_exclusive_monitor.is_active and
                      all_workers_monitor.is_active,
                      'Only non-chief runnable monitor must have been active.')
      self.assertTrue(not chief_exclusive_monitor.has_step and
                      all_workers_monitor.has_step,
                      'Only non-chief runnable monitor must have a step.')
def _get_saver():
  """Lazy init and return saver.

  Returns the saver found in `GraphKeys.SAVERS`. If none is found and the
  graph has global variables, a new V1-format `Saver` is created and added
  to that collection. Returns None when there is neither a saver nor any
  global variable.
  """
  saver = _get_first_op_from_collection(ops.GraphKeys.SAVERS)
  # `_get_first_op_from_collection` presumably already returns a single
  # element (verify against its definition). Indexing a `Saver` would raise
  # TypeError, so only unwrap when a list/tuple actually comes back.
  if isinstance(saver, (list, tuple)):
    saver = saver[0] if saver else None
  if saver is None and variables.global_variables():
    saver = tf_saver.Saver(write_version=saver_pb2.SaverDef.V1)
    ops.add_to_collection(ops.GraphKeys.SAVERS, saver)
  return saver
def _get_saver():
  """Return the graph's saver, creating one lazily when needed.

  The value from `_get_first_op_from_collection(GraphKeys.SAVERS)` is reused
  when it is not None. Otherwise, if the graph has global variables, a new
  `Saver` is built and registered in that collection. With neither, None is
  returned.
  """
  saver = _get_first_op_from_collection(ops.GraphKeys.SAVERS)
  if saver is not None or not variables.global_variables():
    return saver
  saver = tf_saver.Saver()
  ops.add_to_collection(ops.GraphKeys.SAVERS, saver)
  return saver
def _get_local_init_op():
  """Return the graph's local init op, building and registering it if absent.

  When `GraphKeys.LOCAL_INIT_OP` yields nothing, a group of the
  local-variables initializer and the all-tables initializer is created and
  added to that collection.
  """
  local_init_op = _get_first_op_from_collection(ops.GraphKeys.LOCAL_INIT_OP)
  if local_init_op is not None:
    return local_init_op
  # Both initializers are always built, so the group always has members.
  local_init_op = control_flow_ops.group(
      variables.local_variables_initializer(),
      data_flow_ops.initialize_all_tables())
  ops.add_to_collection(ops.GraphKeys.LOCAL_INIT_OP, local_init_op)
  return local_init_op
def apply_regularization(regularizer, weights_list=None):
  """Returns the summed penalty by applying `regularizer` to the `weights_list`.

  Adding a regularization penalty over the layer weights and embedding weights
  can help prevent overfitting the training data. Regularization over layer
  biases is less common/useful, but assuming proper data preprocessing/mean
  subtraction, it usually shouldn't hurt much either.

  The summed penalty is also added to `GraphKeys.REGULARIZATION_LOSSES`.

  Args:
    regularizer: A function that takes a single `Tensor` argument and returns
      a scalar `Tensor` output, or None (treated as a zero penalty).
    weights_list: List of weights `Tensors` or `Variables` to apply
      `regularizer` over. Defaults to the `GraphKeys.WEIGHTS` collection if
      `None` or empty.

  Returns:
    A scalar representing the overall regularization penalty.

  Raises:
    ValueError: If `regularizer` does not return a scalar output, or if we find
      no weights.
  """
  if not weights_list:
    weights_list = ops.get_collection(ops.GraphKeys.WEIGHTS)
  if not weights_list:
    raise ValueError('No weights to regularize.')
  with ops.name_scope('get_regularization_penalty',
                      values=weights_list) as scope:
    penalties = [regularizer(w) for w in weights_list]
    # A None penalty counts as zero.
    penalties = [
        p if p is not None else constant_op.constant(0.0) for p in penalties
    ]
    for p in penalties:
      ndims = p.get_shape().ndims
      if ndims != 0:
        # `%s`, not `%d`: ndims is None for unknown rank, and `%d` would
        # raise TypeError instead of the intended ValueError.
        raise ValueError('regularizer must return a scalar Tensor instead of a '
                         'Tensor with rank %s.' % ndims)
    summed_penalty = math_ops.add_n(penalties, name=scope)
  ops.add_to_collection(ops.GraphKeys.REGULARIZATION_LOSSES, summed_penalty)
  return summed_penalty
def build_image_summary(self):
  """Build a small graph that emits one image summary.

  Returns:
    A `(log_image, log_image_data, log_image_name)` tuple: the image-summary
    op (also added to `GraphKeys.SUMMARIES`), a `uint8` placeholder of shape
    `[None, None, 3]` for the image, and a string placeholder used as the
    summary tag.
  """
  # The data placeholder is created before the name placeholder so that
  # automatic op naming is unchanged.
  log_image_data = tf.placeholder(tf.uint8, [None, None, 3])
  log_image_name = tf.placeholder(tf.string)
  # pylint: disable=g-import-not-at-top
  from tensorflow.python.framework import ops as _ops
  from tensorflow.python.ops import gen_logging_ops
  # pylint: enable=g-import-not-at-top
  # Add a leading batch dimension of size 1.
  batched_image = tf.expand_dims(log_image_data, 0)
  log_image = gen_logging_ops._image_summary(  # pylint: disable=protected-access
      log_image_name, batched_image, max_images=1)
  _ops.add_to_collection(_ops.GraphKeys.SUMMARIES, log_image)
  return log_image, log_image_data, log_image_name
def _get_arg_stack():
  """Return the arg-scope stack kept in the `_ARGSTACK_KEY` collection.

  On first use, a fresh stack `[{}]` is stored in the collection. Later calls
  return that same list object.
  """
  registered = ops.get_collection(_ARGSTACK_KEY)
  if not registered:
    fresh_stack = [{}]
    ops.add_to_collection(_ARGSTACK_KEY, fresh_stack)
    return fresh_stack
  return registered[0]
def _get_arg_stack():
  """Return the arg-scope stack kept in the `_ARGSTACK_KEY` collection.

  On first use, a fresh stack `[{}]` is stored in the collection. Later calls
  return that same list object.
  """
  registered = ops.get_collection(_ARGSTACK_KEY)
  if not registered:
    fresh_stack = [{}]
    ops.add_to_collection(_ARGSTACK_KEY, fresh_stack)
    return fresh_stack
  return registered[0]
def _get_arg_stack():
  """Return the arg-scope stack kept in the `_ARGSTACK_KEY` collection.

  On first use, a fresh stack `[{}]` is stored in the collection. Later calls
  return that same list object.
  """
  registered = ops.get_collection(_ARGSTACK_KEY)
  if not registered:
    fresh_stack = [{}]
    ops.add_to_collection(_ARGSTACK_KEY, fresh_stack)
    return fresh_stack
  return registered[0]