Python tensorflow.python.framework.ops 模块,Graph() 实例源码
我们从Python开源项目中,提取了以下50个代码示例,用于说明如何使用tensorflow.python.framework.ops.Graph()。
def __init__(self, target='', graph=None, config=None):
    """Creates a new TensorFlow session.

    When `graph` is omitted, the default graph is launched in the session.
    If a process uses more than one graph (each created with `tf.Graph()`),
    every graph needs its own session, although one graph may be shared by
    several sessions. In that situation it is usually clearer to pass the
    graph to launch explicitly to the session constructor.

    Args:
      target: (Optional.) The execution engine to connect to.
        Defaults to using an in-process engine. See
        @{$distributed$Distributed TensorFlow}
        for more examples.
      graph: (Optional.) The `Graph` to be launched (described above).
      config: (Optional.) A [`ConfigProto`](https://www.tensorflow.org/code/tensorflow/core/protobuf/config.proto)
        protocol buffer with configuration options for the session.
    """
    super(Session, self).__init__(target, graph, config=config)
    # NOTE(mrry): both context managers are created lazily on the first
    # `__enter__`, to avoid a reference cycle.
    self._default_graph_context_manager = (
        self._default_session_context_manager) = None
def __copy__(self):
    """Return a new view over the same part of the tf.Graph.

    This class is only a "view": copying it creates another view and never
    copies the underlying part of the tf.Graph.

    Returns:
      A new instance of the same class with identical content.
    """
    klass = self.__class__
    duplicate = klass.__new__(klass)
    for attr_name, attr_value in self.__dict__.items():
        # The graph is shared by reference; every other attribute is
        # re-materialised with list() so the copy owns its own sequences.
        if attr_name != "_graph":
            attr_value = list(attr_value)
        setattr(duplicate, attr_name, attr_value)
    return duplicate
def remap_inputs(self, new_input_indices):
    """Return a copy of this view whose inputs are reordered/selected.

    If the inputs of the original subgraph are [t0, t1, t2], remapping with
    [2, 0] creates a new instance whose inputs are [t2, t0].
    Only the view changes; the underlying tf.Graph is not affected.

    Args:
      new_input_indices: an iterable of integers representing a mapping between
        the old inputs and the new ones. This mapping can be under-complete and
        must be without repetitions.
    Returns:
      A new modified instance of the original subgraph view with remapped
      inputs.
    """
    # Work on a copy so that `self` is left untouched.
    remapped = self.copy()
    remapped._remap_inputs(new_input_indices)  # pylint: disable=protected-access
    return remapped
def remap_outputs(self, new_output_indices):
    """Return a shallow copy of this view with its outputs remapped.

    If the outputs of the original subgraph are [t0, t1, t2], remapping with
    [1, 1, 0] creates a new instance whose outputs are [t1, t1, t0].
    Only the view changes; the underlying tf.Graph is not affected.

    Args:
      new_output_indices: an iterable of integers representing a mapping
        between the old outputs and the new ones. This mapping can be
        under-complete and can have repetitions.
    Returns:
      A new modified instance of the original subgraph view with remapped
      outputs.
    """
    remapped = copy.copy(self)
    remapped._remap_outputs(new_output_indices)  # pylint: disable=protected-access
    return remapped
def remap(self, new_input_indices=None, new_output_indices=None):
    """Return a shallow copy of this view with inputs and/or outputs remapped.

    Only the view changes; the underlying tf.Graph is not affected. A mapping
    left as None is not applied.

    Args:
      new_input_indices: an iterable of integers representing a mapping between
        the old inputs and the new ones. This mapping can be under-complete and
        must be without repetitions.
      new_output_indices: an iterable of integers representing a mapping
        between the old outputs and the new ones. This mapping can be
        under-complete and can have repetitions.
    Returns:
      A new modified instance of the original subgraph view with remapped
      inputs and outputs.
    """
    remapped = copy.copy(self)
    # Inputs first, then outputs; the method is looked up only when needed.
    for indices, method_name in ((new_input_indices, "_remap_inputs"),
                                 (new_output_indices, "_remap_outputs")):
        if indices is not None:
            getattr(remapped, method_name)(indices)
    return remapped
def _check_graph(sgv, graph):
    """Check if sgv belongs to the given graph.

    Args:
      sgv: a SubGraphView.
      graph: a graph or None.
    Returns:
      The SubGraphView sgv.
    Raises:
      TypeError: if sgv is not a SubGraphView or if graph is not None and not
        a tf.Graph.
      ValueError: if the graph of sgv and the given graph are not None and
        different.
    """
    if not isinstance(sgv, SubGraphView):
        # Report the type of the offending argument (sgv), not of `graph`.
        raise TypeError("Expected a SubGraphView, got: {}".format(type(sgv)))
    # Nothing to compare against: no graph requested, or sgv has no graph.
    if graph is None or not sgv.graph:
        return sgv
    if not isinstance(graph, tf_ops.Graph):
        raise TypeError("Expected a tf.Graph, got: {}".format(type(graph)))
    if sgv.graph is not graph:
        raise ValueError("Graph mismatch.")
    return sgv
def select_ops_and_ts(*args, **kwargs):
    """Select operations and tensors in a single call.

    Args:
      *args: list of 1) regular expressions (compiled or not) or 2) (array of)
        tf.Operation 3) (array of) tf.Tensor. Regular expressions matching
        tensors must start with the comment "(?#ts)", for instance:
        "(?#ts)^foo/.*".
      **kwargs: 'graph': tf.Graph in which to perform the regex query. This is
        required when using regex.
        'positive_filter': an elem is selected only if positive_filter(elem)
        is True. This is optional.
    Returns:
      A tuple `(ops, ts)` where:
        `ops` is a list of tf.Operation
        `ts` is a list of tf.Tensor
    Raises:
      TypeError: if the optional keyword argument graph is not a tf.Graph
        or if an argument in args is not an (array of) tf.Tensor
        or an (array of) tf.Operation or a string or a regular expression.
      ValueError: if one of the keyword arguments is unexpected or if a regular
        expression is used without passing a graph as a keyword argument.
    """
    # Ops are selected first, then tensors (tuple items evaluate left to
    # right). Per the docstring, only "(?#ts)"-tagged regexes are meant to
    # select tensors, hence restrict_ts_regex=True.
    return (select_ops(*args, restrict_ops_regex=False, **kwargs),
            select_ts(*args, restrict_ts_regex=True, **kwargs))
def get_tensors(graph):
    """Collect the output tensors of every operation in the graph.

    Args:
      graph: a tf.Graph.
    Returns:
      A list of tf.Tensor, in operation order, then output order.
    Raises:
      TypeError: if graph is not a tf.Graph.
    """
    if not isinstance(graph, tf_ops.Graph):
        raise TypeError("Expected a graph, got: {}".format(type(graph)))
    return [tensor for op in graph.get_operations() for tensor in op.outputs]
def make_placeholder_from_tensor(t, scope=None):
    """Create a tf.placeholder mirroring the dtype and shape of `t`.

    The calling function must set the correct graph scope beforehand.

    Args:
      t: a tf.Tensor whose name will be used to create the placeholder
        (see function placeholder_name).
      scope: absolute scope within which to create the placeholder. None
        means that the scope of t is preserved. "" means the root scope.
    Returns:
      A newly created tf.placeholder.
    Raises:
      TypeError: if t is not None or a tf.Tensor.
        NOTE(review): `t.dtype` is read before placeholder_name is called,
        so a non-tensor `t` likely raises AttributeError first; confirm.
    """
    # Read dtype and shape before computing the name (original order).
    dtype = t.dtype
    shape = t.get_shape()
    return tf_array_ops.placeholder(dtype=dtype, shape=shape,
                                    name=placeholder_name(t, scope=scope))
def make_placeholder_from_dtype_and_shape(dtype, shape=None, scope=None):
    """Create a tf.placeholder with an explicit dtype and optional shape.

    The calling function must set the correct graph scope beforehand.
    The placeholder is named by placeholder_name, called with no tensor
    argument.

    Args:
      dtype: the tensor type.
      shape: the tensor shape (optional).
      scope: absolute scope within which to create the placeholder; it is
        forwarded unchanged to placeholder_name. "" means the root scope.
    Returns:
      A newly created tf.placeholder.
    """
    ph_name = placeholder_name(scope=scope)
    return tf_array_ops.placeholder(dtype=dtype, shape=shape, name=ph_name)
def infer_real_valued_columns_from_input_fn(input_fn):
    """Create `FeatureColumn` objects for the inputs produced by `input_fn`.

    Every input is interpreted as a dense, fixed-length float value. `input_fn`
    is called inside a throw-away graph, which is discarded afterwards.

    Args:
      input_fn: Input function returning a tuple of:
        features - Dictionary of string feature name to `Tensor`.
        target - `Tensor` of target objects.
    Returns:
      List of `FeatureColumn` objects.
    """
    scratch_graph = ops.Graph()
    with scratch_graph.as_default():
        # Only the features are needed; the target is ignored.
        features, _unused_target = input_fn()
        return layers.infer_real_valued_columns(features)
def __copy__(self):
    """Create a copy of this subgraph view.

    This class is a "view": copying it only creates another view and does
    not copy the underlying part of the `tf.Graph`.

    Returns:
      A new identical instance of the original subgraph view.
    """
    result = self.__class__.__new__(self.__class__)
    for key, value in self.__dict__.items():
        # Share the graph; give the copy its own list for everything else.
        setattr(result, key, value if key == "_graph" else list(value))
    return result
def remap_inputs(self, new_input_indices):
    """Remap the inputs of the subgraph into a new view.

    If the inputs of the original subgraph are [t0, t1, t2], remapping with
    [2, 0] creates a new instance whose inputs are [t2, t0].
    Only the view changes; the underlying `tf.Graph` is not affected.

    Args:
      new_input_indices: an iterable of integers representing a mapping between
        the old inputs and the new ones. This mapping can be under-complete and
        must be without repetitions.
    Returns:
      A new modified instance of the original subgraph view with remapped
      inputs.
    """
    new_view = self.copy()
    new_view._remap_inputs(new_input_indices)  # pylint: disable=protected-access
    return new_view
def remap_outputs(self, new_output_indices):
    """Remap the outputs of the subgraph into a new view.

    If the outputs of the original subgraph are [t0, t1, t2], remapping with
    [1, 1, 0] creates a new instance whose outputs are [t1, t1, t0].
    Only the view changes; the underlying tf.Graph is not affected.

    Args:
      new_output_indices: an iterable of integers representing a mapping
        between the old outputs and the new ones. This mapping can be
        under-complete and can have repetitions.
    Returns:
      A new modified instance of the original subgraph view with remapped
      outputs.
    """
    new_view = copy.copy(self)
    new_view._remap_outputs(new_output_indices)  # pylint: disable=protected-access
    return new_view
def _check_graph(sgv, graph):
    """Check if sgv belongs to the given graph.

    Args:
      sgv: a SubGraphView.
      graph: a graph or None.
    Returns:
      The SubGraphView sgv.
    Raises:
      TypeError: if sgv is not a SubGraphView or if graph is not None and not
        a tf.Graph.
      ValueError: if the graph of sgv and the given graph are not None and
        different.
    """
    if not isinstance(sgv, SubGraphView):
        # The message must describe sgv, the argument that failed the check.
        raise TypeError("Expected a SubGraphView, got: {}".format(type(sgv)))
    # Skip the comparison when either side has no graph.
    if graph is None or not sgv.graph:
        return sgv
    if not isinstance(graph, tf_ops.Graph):
        raise TypeError("Expected a tf.Graph, got: {}".format(type(graph)))
    if sgv.graph is not graph:
        raise ValueError("Graph mismatch.")
    return sgv
def get_tensors(graph):
    """Collect the output tensors of every operation in the graph.

    Args:
      graph: a `tf.Graph`.
    Returns:
      A list of `tf.Tensor`.
    Raises:
      TypeError: if graph is not a `tf.Graph`.
    """
    if not isinstance(graph, tf_ops.Graph):
        raise TypeError("Expected a graph, got: {}".format(type(graph)))
    collected = []
    for operation in graph.get_operations():
        collected.extend(operation.outputs)
    return collected
def __init__(self, graph):
    """Create a dictionary of control-output dependencies for `graph`.

    The mapping is filled in by `self._build()`. Each key is a `tf.Operation`
    instance, and its value is a list of all the ops which have the key as
    one of their control-input dependencies.

    Args:
      graph: a `tf.Graph`.
    Raises:
      TypeError: graph is not a `tf.Graph`.
    """
    if not isinstance(graph, tf_ops.Graph):
        raise TypeError("Expected a tf.Graph, got: {}".format(type(graph)))
    self._graph = graph
    self._control_outputs = {}
    self._version = None
    # Populate the mapping now that all attributes exist.
    self._build()
def make_placeholder_from_tensor(t, scope=None):
    """Create a `tf.placeholder` with the dtype and static shape of `t`.

    The calling function must set the correct graph scope beforehand.

    Args:
      t: a `tf.Tensor` whose name will be used to create the placeholder
        (see function placeholder_name).
      scope: absolute scope within which to create the placeholder. None
        means that the scope of `t` is preserved. `""` means the root scope.
    Returns:
      A newly created `tf.placeholder`.
    Raises:
      TypeError: if `t` is not `None` or a `tf.Tensor`.
        NOTE(review): `t.dtype` is evaluated first, so a non-tensor may
        surface as AttributeError instead; confirm.
    """
    tensor_dtype = t.dtype
    tensor_shape = t.get_shape()
    return tf_array_ops.placeholder(
        dtype=tensor_dtype,
        shape=tensor_shape,
        name=placeholder_name(t, scope=scope))
def make_placeholder_from_dtype_and_shape(dtype, shape=None, scope=None):
    """Create a tf.placeholder from an explicit dtype and optional shape.

    The calling function must set the correct graph scope beforehand.
    The name comes from placeholder_name, called with no tensor argument.

    Args:
      dtype: the tensor type.
      shape: the tensor shape (optional).
      scope: absolute scope within which to create the placeholder, passed
        straight to placeholder_name. "" means the root scope.
    Returns:
      A newly created tf.placeholder.
    """
    generated_name = placeholder_name(scope=scope)
    return tf_array_ops.placeholder(dtype=dtype, shape=shape,
                                    name=generated_name)
def infer_real_valued_columns_from_input_fn(input_fn):
    """Create `FeatureColumn` objects for the inputs defined by `input_fn`.

    All inputs are treated as dense, fixed-length float values. `input_fn`
    runs inside a local graph that is discarded once the columns are built.

    Args:
      input_fn: Input function returning a tuple of:
        features - Dictionary of string feature name to `Tensor`.
        labels - `Tensor` of label values.
    Returns:
      List of `FeatureColumn` objects.
    """
    local_graph = ops.Graph()
    with local_graph.as_default():
        # Labels are not needed to infer the feature columns.
        features, _unused_labels = input_fn()
        return layers.infer_real_valued_columns(features)
def testInitFromRootCheckpoint(self):
    """Restores the checkpoint root "/" under the "some_scope/" prefix."""
    checkpoint_dir = self.get_temp_dir()
    with self.test_session() as sess:
        v1, v2, v3, v4 = _create_checkpoints(sess, checkpoint_dir)

    # Restore into a brand-new graph with its own session.
    with ops.Graph().as_default() as g:
        with self.test_session(graph=g) as sess:
            with variable_scope.variable_scope("some_scope"):
                my1 = variable_scope.get_variable("var1", [1, 10])
                my2 = variable_scope.get_variable("var2", [10, 10])
                my3 = variable_scope.get_variable("var3", [100, 100])
                with variable_scope.variable_scope("useful_scope"):
                    my4 = variable_scope.get_variable("var4", [9, 9])

            # Everything at the checkpoint root lands under "some_scope/".
            checkpoint_utils.init_from_checkpoint(checkpoint_dir,
                                                  {"/": "some_scope/"})

            sess.run(variables.global_variables_initializer())
            pairs = zip((my1, my2, my3, my4), (v1, v2, v3, v4))
            for restored, expected in pairs:
                self.assertAllEqual(restored.eval(sess), expected)
def testDeviceFn(self):
    """A callable `device` in arg_scope places successive variables on successive CPUs."""

    class _CountingDeviceFn(object):
        """Returns '/cpu:0', '/cpu:1', ... on successive calls."""

        def __init__(self):
            self.counter = -1

        def __call__(self, op):
            self.counter += 1
            return '/cpu:%d' % self.counter

    with ops.Graph().as_default():
        with arg_scope([variables_lib2.model_variable],
                       device=_CountingDeviceFn()):
            a = variables_lib2.model_variable('a', [5])
            b = variables_lib2.model_variable('b', [20])
            for var, expected_device in ((a, '/cpu:0'), (b, '/cpu:1')):
                self.assertDeviceEqual(var.device, expected_device)
                # The initial value must share the variable's colocation group.
                self.assertEqual(var.initial_value.op.colocation_groups(),
                                 var.op.colocation_groups())
def create_checkpoint_from_values(self,
                                  var_names_to_values,
                                  checkpoint_dir,
                                  global_step=None):
    """Write a checkpoint holding one variable per name/value pair.

    Args:
      var_names_to_values: a map from variable names to values.
      checkpoint_dir: the directory where the checkpoint will be saved.
      global_step: the global step used to save the checkpoint.
    Returns:
      The model_path to the checkpoint, as returned by `Saver.save`.
    """
    with session.Session('', graph=ops.Graph()) as sess:
        # The variables live in the session's private, freshly created graph.
        var_list = [
            variables_lib.Variable(var_names_to_values[name], name=name)
            for name in var_names_to_values
        ]
        saver = saver_lib.Saver(var_list)
        sess.run(variables_lib.variables_initializer(var_list))
        # Persist the initialized values.
        return saver.save(sess, checkpoint_dir, global_step=global_step)
def create_checkpoint_from_values(self,
                                  var_names_to_values,
                                  checkpoint_dir,
                                  global_step=None):
    """Save the given name -> value mapping as variables in a checkpoint.

    Args:
      var_names_to_values: a map from variable names to values.
      checkpoint_dir: the directory where the checkpoint will be saved.
      global_step: the global step used to save the checkpoint.
    Returns:
      The model_path to the checkpoint, as returned by `Saver.save`.
    """
    with session.Session('', graph=ops.Graph()) as sess:
        var_list = []
        for name in var_names_to_values:
            value = var_names_to_values[name]
            var_list.append(variables_lib.Variable(value, name=name))
        checkpoint_saver = saver_lib.Saver(var_list)
        initializer = variables_lib.variables_initializer(var_list)
        sess.run(initializer)
        return checkpoint_saver.save(sess, checkpoint_dir,
                                     global_step=global_step)
def __copy__(self):
    """Create a copy of this subgraph view.

    Only another view is created; the underlying part of the `tf.Graph` is
    not copied.

    Returns:
      A new identical instance of the original subgraph view.
    """
    cls = self.__class__
    clone = cls.__new__(cls)
    for name, value in self.__dict__.items():
        if name == "_graph":
            # The graph itself is shared between views.
            setattr(clone, name, value)
        else:
            # Other attributes get a fresh list so the views stay independent.
            setattr(clone, name, list(value))
    return clone
def remap_inputs(self, new_input_indices):
    """Build a new view whose inputs are reordered/selected by index.

    If the inputs of the original subgraph are [t0, t1, t2], remapping with
    [2, 0] creates a new instance whose inputs are [t2, t0].
    Only the view changes; the underlying `tf.Graph` is not affected.

    Args:
      new_input_indices: an iterable of integers representing a mapping between
        the old inputs and the new ones. This mapping can be under-complete and
        must be without repetitions.
    Returns:
      A new modified instance of the original subgraph view with remapped
      inputs.
    """
    result = self.copy()
    result._remap_inputs(new_input_indices)  # pylint: disable=protected-access
    return result
def remap(self, new_input_indices=None, new_output_indices=None):
    """Remap the inputs and/or outputs of the subgraph into a new view.

    Only the view changes; the underlying tf.Graph is not affected. A mapping
    that is None is not applied.

    Args:
      new_input_indices: an iterable of integers representing a mapping between
        the old inputs and the new ones. This mapping can be under-complete and
        must be without repetitions.
      new_output_indices: an iterable of integers representing a mapping
        between the old outputs and the new ones. This mapping can be
        under-complete and can have repetitions.
    Returns:
      A new modified instance of the original subgraph view with remapped
      inputs and outputs.
    """
    result = copy.copy(self)
    steps = (("_remap_inputs", new_input_indices),
             ("_remap_outputs", new_output_indices))
    for method_name, indices in steps:
        if indices is None:
            continue
        getattr(result, method_name)(indices)
    return result
def _check_graph(sgv, graph):
    """Check if sgv belongs to the given graph.

    Args:
      sgv: a SubGraphView.
      graph: a graph or None.
    Returns:
      The SubGraphView sgv.
    Raises:
      TypeError: if sgv is not a SubGraphView or if graph is not None and not
        a tf.Graph.
      ValueError: if the graph of sgv and the given graph are not None and
        different.
    """
    if not isinstance(sgv, SubGraphView):
        # Previously reported type(graph), which hid the actual bad argument.
        raise TypeError("Expected a SubGraphView, got: {}".format(type(sgv)))
    if graph is None or not sgv.graph:
        return sgv
    if not isinstance(graph, tf_ops.Graph):
        raise TypeError("Expected a tf.Graph, got: {}".format(type(graph)))
    if sgv.graph is not graph:
        raise ValueError("Graph mismatch.")
    return sgv
def test_copy(self):
    """ge.copy reproduces every op and tensor and records both directions."""
    dst_graph = ops.Graph()
    _, info = ge.copy(self.graph, dst_graph)
    src_names = set(op.name for op in self.graph.get_operations())
    dst_names = set(op.name for op in dst_graph.get_operations())
    self.assertEqual(src_names, dst_names)

    dst_ops = dst_graph.get_operations()
    for src_op in self.graph.get_operations():
        copied_op = info.transformed(src_op)
        self.assertTrue(copied_op in dst_ops)
        self.assertEqual(src_op.name, copied_op.name)
        self.assertEqual(info.original(copied_op), src_op)

    dst_ts = ge.util.get_tensors(dst_graph)
    for src_t in ge.util.get_tensors(self.graph):
        copied_t = info.transformed(src_t)
        self.assertTrue(copied_t in dst_ts)
        self.assertEqual(src_t.name, copied_t.name)
        self.assertEqual(info.original(copied_t), src_t)
def test_placeholder(self):
    """Test placeholder functionalities."""
    g0 = ops.Graph()
    with g0.as_default():
        a0 = constant_op.constant(1, name="foo")

    # Placeholder naming: (tensor, scope or None if omitted, expected name).
    naming_cases = [
        (a0, None, "geph__foo_0"),
        (None, None, "geph"),
        (a0, "foo/", "foo/geph__foo_0"),
        (a0, "foo", "foo/geph__foo_0"),
        (None, "foo/", "foo/geph"),
        (None, "foo", "foo/geph"),
    ]
    for tensor, scope, expected in naming_cases:
        if scope is None:
            actual = ge.util.placeholder_name(tensor)
        else:
            actual = ge.util.placeholder_name(tensor, scope=scope)
        self.assertEqual(actual, expected)

    # Placeholder creation inside a fresh graph.
    g0 = ops.Graph()
    with g0.as_default():
        a0 = constant_op.constant(1, dtype=dtypes.float32, name="a0")
        from_tensor = ge.util.make_placeholder_from_tensor(a0)
        from_dtype = ge.util.make_placeholder_from_dtype_and_shape(
            dtype=dtypes.float32)
        c0 = math_ops.add(from_tensor, from_dtype)
        first_input, second_input = c0.op.inputs[0], c0.op.inputs[1]
        self.assertEqual(first_input.op.name, "geph__a0_0")
        self.assertEqual(second_input.op.name, "geph")
def test_reroute_can_modify(self):
    """swap_outputs works when a tensor is both input and output of a view."""
    graph = ops.Graph()
    # "a" is deliberately ambiguous: it is both an input and an output of
    # the ops in sgv0.
    with graph.as_default():
        a = constant_op.constant(1.0, shape=[2], name="a")
        b = constant_op.constant(2.0, shape=[2], name="b")
        c = math_ops.add(a, b, name="c")
        d = math_ops.add(a, c, name="d")

        e = constant_op.constant(1.0, shape=[2], name="e")
        f = constant_op.constant(2.0, shape=[2], name="f")
        g = math_ops.add(e, f, name="g")

    sgv0 = ge.sgv(a.op, b.op, c.op)
    sgv1 = ge.sgv(e.op, f.op)
    ge.swap_outputs(sgv0, sgv1)

    g_matcher = ge.OpMatcher("g").input_ops(
        "a", ge.OpMatcher("c").input_ops("a", "b"))
    self.assertTrue(g_matcher(g.op))
    d_matcher = ge.OpMatcher("d").input_ops("e", "f")
    self.assertTrue(d_matcher(d.op))
def get_tensors(graph):
    """Gather the output tensors of every operation in `graph`.

    Args:
      graph: a `tf.Graph`.
    Returns:
      A list of `tf.Tensor`.
    Raises:
      TypeError: if graph is not a `tf.Graph`.
    """
    if not isinstance(graph, tf_ops.Graph):
        raise TypeError("Expected a graph, got: {}".format(type(graph)))
    outputs = []
    for op in graph.get_operations():
        outputs += op.outputs
    return list(outputs)
def __init__(self, graph):
    """Build the control-output dependency mapping of `graph`.

    After construction `self._build()` has been called to fill the mapping.
    Each key is a `tf.Operation` instance, and its value is a list of all the
    ops which have the key as one of their control-input dependencies.

    Args:
      graph: a `tf.Graph`.
    Raises:
      TypeError: graph is not a `tf.Graph`.
    """
    if not isinstance(graph, tf_ops.Graph):
        raise TypeError("Expected a tf.Graph, got: {}".format(type(graph)))
    self._version = None
    self._graph = graph
    self._control_outputs = {}
    self._build()
def make_placeholder_from_tensor(t, scope=None):
    """Create a `tf.placeholder` shaped and typed like `t`.

    The calling function must set the correct graph scope beforehand.

    Args:
      t: a `tf.Tensor` whose name will be used to create the placeholder
        (see function placeholder_name).
      scope: absolute scope within which to create the placeholder. None
        means that the scope of `t` is preserved. `""` means the root scope.
    Returns:
      A newly created `tf.placeholder`.
    Raises:
      TypeError: if `t` is not `None` or a `tf.Tensor`.
        NOTE(review): `t.dtype` is accessed first, so a non-tensor may raise
        AttributeError before placeholder_name runs; confirm.
    """
    dtype, shape = t.dtype, t.get_shape()
    name = placeholder_name(t, scope=scope)
    return tf_array_ops.placeholder(dtype=dtype, shape=shape, name=name)
def make_placeholder_from_dtype_and_shape(dtype, shape=None, scope=None):
    """Create a tf.placeholder of the given dtype and optional shape.

    The calling function must set the correct graph scope beforehand.
    The name is produced by placeholder_name with no tensor argument.

    Args:
      dtype: the tensor type.
      shape: the tensor shape (optional).
      scope: absolute scope within which to create the placeholder; handed
        unchanged to placeholder_name. "" means the root scope.
    Returns:
      A newly created tf.placeholder.
    """
    return tf_array_ops.placeholder(
        dtype=dtype,
        shape=shape,
        name=placeholder_name(scope=scope),
    )
def testGradientWithZeroWeight(self):
    """A zero loss weight must not produce NaN gradients."""
    with ops.Graph().as_default():
        random_seed.set_random_seed(0)

        # Op creation order is kept as-is: with a graph-level seed it
        # influences the seeded initializer.
        inputs = array_ops.ones((2, 3))
        weights = variable_scope.get_variable(
            'weights',
            shape=[3, 4],
            initializer=init_ops.truncated_normal_initializer())
        predictions = math_ops.matmul(inputs, weights)

        optimizer = momentum_lib.MomentumOptimizer(
            learning_rate=0.001, momentum=0.9)
        loss = loss_ops.mean_pairwise_squared_error(predictions, predictions, 0)
        gradients_to_variables = optimizer.compute_gradients(loss)
        init_op = variables.global_variables_initializer()

        with self.test_session() as sess:
            sess.run(init_op)
            for grad, _unused_var in gradients_to_variables:
                self.assertFalse(np.any(np.isnan(sess.run(grad))))
def testNumpySource(self):
    """NumpySource yields consecutive batches that wrap around the array."""
    batch_size = 3
    iterations = 1000
    array = np.arange(32).reshape([16, 2])
    numpy_source = in_memory_source.NumpySource(array, batch_size=batch_size)
    index_column = numpy_source().index
    value_column = numpy_source().value
    cache = {}
    with ops.Graph().as_default():
        value_tensor = value_column.build(cache)
        index_tensor = index_column.build(cache)
        with session.Session() as sess:
            coord = coordinator.Coordinator()
            threads = queue_runner_impl.start_queue_runners(sess=sess,
                                                            coord=coord)
            num_rows = array.shape[0]
            for step in range(iterations):
                start = batch_size * step
                # Row indices wrap modulo the number of rows.
                expected_index = [
                    row % num_rows for row in range(start, start + batch_size)
                ]
                expected_value = get_rows(array, expected_index)
                actual_index, actual_value = sess.run(
                    [index_tensor, value_tensor])
                np.testing.assert_array_equal(expected_index, actual_index)
                np.testing.assert_array_equal(expected_value, actual_value)
            coord.request_stop()
            coord.join(threads)
def testArrayFeedingMultiThread(self):
    """Each dequeued batch's values match the rows named by its indices."""
    with ops.Graph().as_default():
        array = np.arange(256).reshape([128, 2])
        q = ff.enqueue_data(array, capacity=128, num_threads=8, shuffle=True)
        batch_size = 3
        dq_op = q.dequeue_many(batch_size)
        with session.Session() as sess:
            coord = coordinator.Coordinator()
            threads = queue_runner_impl.start_queue_runners(sess=sess,
                                                            coord=coord)
            for _ in range(100):
                batch = sess.run(dq_op)
                # Component 0 holds row indices, component 1 the row values.
                row_indices, row_values = batch[0], batch[1]
                np.testing.assert_array_equal(get_rows(array, row_indices),
                                              row_values)
            coord.request_stop()
            coord.join(threads)
def testPandasFeedingMultiThread(self):
    """Each dequeued batch's columns match the DataFrame rows it indexes."""
    if not HAS_PANDAS:
        return
    with ops.Graph().as_default():
        array1 = np.arange(128, 256)
        array2 = 2 * array1
        df = pd.DataFrame({"a": array1, "b": array2}, index=np.arange(128))
        q = ff.enqueue_data(df, capacity=128, num_threads=8, shuffle=True)
        batch_size = 5
        dq_op = q.dequeue_many(batch_size)
        with session.Session() as sess:
            coord = coordinator.Coordinator()
            threads = queue_runner_impl.start_queue_runners(sess=sess,
                                                            coord=coord)
            for _ in range(100):
                batch = sess.run(dq_op)
                expected_rows = df.iloc[batch[0]]
                # Component 0 is the index; columns follow from position 1.
                for position, column in enumerate(df.columns, start=1):
                    np.testing.assert_array_equal(expected_rows[column].values,
                                                  batch[position])
            coord.request_stop()
            coord.join(threads)
def test_evaluate_invalid_args(self):
    """evaluate() rejects a None or empty output directory."""
    with ops.Graph().as_default() as g, self.test_session(g):
        self._assert_ckpt(self._output_dir, False)
        for bad_output_dir in (None, ''):
            # 'utput directory' matches both "Output" and "output".
            with self.assertRaisesRegexp(ValueError, 'utput directory'):
                learn.graph_actions.evaluate(
                    g,
                    output_dir=bad_output_dir,
                    checkpoint_path=None,
                    eval_dict={'a': constant_op.constant(1.0)})
        self._assert_ckpt(self._output_dir, False)
def test_evaluate_ready_for_local_init(self):
    """evaluate() succeeds with a READY_FOR_LOCAL_INIT_OP and a dependent local var."""
    with ops.Graph().as_default() as g, self.test_session(g):
        variables_lib.create_global_step()
        v = variables.Variable(1.0)
        # Local, non-trainable variable initialized from `v`; it is only
        # created, never read.
        _dependent_local = variables.Variable(
            v + 1, collections=[ops.GraphKeys.LOCAL_VARIABLES], trainable=False)
        ops.add_to_collection(
            ops.GraphKeys.READY_FOR_LOCAL_INIT_OP,
            variables.report_uninitialized_variables(
                variables.global_variables()))
        learn.graph_actions.evaluate(
            g,
            output_dir=self._output_dir,
            checkpoint_path=None,
            eval_dict={'a': v},
            max_steps=1)
def test_evaluate_feed_fn(self):
    """evaluate() calls feed_fn for each of 3 steps and logs a step-0 summary."""
    with ops.Graph().as_default() as g, self.test_session(g):
        in0, _, out = self._build_inference_graph()
        output_dir = self._output_dir
        writer = learn.graph_actions.get_summary_writer(output_dir)
        self._assert_summaries(output_dir, writer, expected_session_logs=[])
        self._assert_ckpt(output_dir, False)
        feeder = _Feeder(in0, 3)
        results = learn.graph_actions.evaluate(
            g,
            output_dir=output_dir,
            checkpoint_path=None,
            eval_dict={'a': out},
            feed_fn=feeder.feed_fn,
            max_steps=3)
        self.assertEqual(3, feeder.step)
        expected_value = 25.0
        self.assertEqual(({'a': expected_value}, 0), results)
        self._assert_summaries(
            output_dir,
            writer,
            expected_summaries={0: {'a': expected_value}},
            expected_session_logs=[])
        self._assert_ckpt(output_dir, False)
def test_evaluate_feed_fn_with_exhaustion(self):
    """Evaluation stops after 2 feeds even though max_steps is 3."""
    with ops.Graph().as_default() as g, self.test_session(g):
        in0, _, out = self._build_inference_graph()
        output_dir = self._output_dir
        writer = learn.graph_actions.get_summary_writer(output_dir)
        self._assert_summaries(output_dir, writer, expected_session_logs=[])
        feeder = _Feeder(in0, 2)
        results = learn.graph_actions.evaluate(
            g,
            output_dir=output_dir,
            checkpoint_path=None,
            eval_dict={'a': out},
            feed_fn=feeder.feed_fn,
            max_steps=3)
        self.assertEqual(2, feeder.step)
        expected_value = 15.0
        self.assertEqual(({'a': expected_value}, 0), results)
        self._assert_summaries(
            output_dir,
            writer,
            expected_summaries={0: {'a': expected_value}},
            expected_session_logs=[])
def test_evaluate_with_saver(self):
    """evaluate() works when a Saver is registered in the SAVERS collection."""
    with ops.Graph().as_default() as g, self.test_session(g):
        _, _, out = self._build_inference_graph()
        ops.add_to_collection(ops.GraphKeys.SAVERS, saver_lib.Saver())
        output_dir = self._output_dir
        writer = learn.graph_actions.get_summary_writer(output_dir)
        self._assert_summaries(output_dir, writer, expected_session_logs=[])
        results = learn.graph_actions.evaluate(
            g,
            output_dir=output_dir,
            checkpoint_path=None,
            eval_dict={'a': out},
            max_steps=1)
        expected_value = 6.0
        self.assertEqual(({'a': expected_value}, 0), results)
        self._assert_summaries(
            output_dir,
            writer,
            expected_summaries={0: {'a': expected_value}},
            expected_session_logs=[])
def test_train_loss(self):
    """_monitored_train returns the loss after 6 steps: 10 - 6 * 1 = 4."""
    with ops.Graph().as_default() as g, self.test_session(g):
        variables_lib.create_global_step()
        loss_var = variables_lib.local_variable(10.0)
        # Every step bumps the global step and decrements the loss by 1.
        train_op = control_flow_ops.group(
            state_ops.assign_add(variables_lib.get_global_step(), 1),
            state_ops.assign_add(loss_var, -1.0))
        output_dir = self._output_dir
        writer = learn.graph_actions.get_summary_writer(output_dir)
        self._assert_summaries(output_dir, writer)
        self._assert_ckpt(output_dir, False)
        loss = learn.graph_actions._monitored_train(  # pylint: disable=protected-access
            g,
            output_dir=output_dir,
            train_op=train_op,
            loss_op=loss_var.value(),
            steps=6)
        self.assertEqual(4.0, loss)
        self._assert_summaries(output_dir, writer,
                               expected_graphs=[g],
                               expected_meta_graphs=None)
        self._assert_ckpt(output_dir, True)
def test_train(self):
    """train() runs one step, returns the constant loss and writes a checkpoint."""
    with ops.Graph().as_default() as g, self.test_session(g):
        with ops.control_dependencies(self._build_inference_graph()):
            train_op = state_ops.assign_add(variables_lib.get_global_step(), 1)

        output_dir = self._output_dir
        self._assert_summaries(output_dir)
        self._assert_ckpt(output_dir, False)
        loss = learn.graph_actions.train(
            g,
            output_dir=output_dir,
            train_op=train_op,
            loss_op=constant_op.constant(2.0),
            steps=1)
        # TODO(ebrevdo,ptucker,ispir): this meta_graph_def lacks the
        # SaverDef, so we can't add it to the summary assertion test below.
        # meta_graph_def = meta_graph.create_meta_graph_def()
        self.assertEqual(2.0, loss)
        self._assert_summaries(output_dir, expected_graphs=[g])
        self._assert_ckpt(output_dir, True)
def test_train_loss(self):
    """train() returns the loss after 6 steps: 10 - 6 * 1 = 4."""
    with ops.Graph().as_default() as g, self.test_session(g):
        variables_lib.create_global_step()
        loss_var = variables_lib.local_variable(10.0)
        # Every step bumps the global step and decrements the loss by 1.
        train_op = control_flow_ops.group(
            state_ops.assign_add(variables_lib.get_global_step(), 1),
            state_ops.assign_add(loss_var, -1.0))
        output_dir = self._output_dir
        self._assert_summaries(output_dir)
        self._assert_ckpt(output_dir, False)
        loss = learn.graph_actions.train(
            g,
            output_dir=output_dir,
            train_op=train_op,
            loss_op=loss_var.value(),
            steps=6)
        # TODO(ebrevdo,ptucker,ispir): this meta_graph_def lacks the
        # SaverDef, so we can't add it to the summary assertion test below.
        # meta_graph_def = meta_graph.create_meta_graph_def()
        self.assertEqual(4.0, loss)
        self._assert_summaries(output_dir, expected_graphs=[g])
        self._assert_ckpt(output_dir, True)
def test_train_chief_monitor(self):
    """On the chief, both chief-only and all-worker monitors are activated."""
    with ops.Graph().as_default() as g, self.test_session(g):
        with ops.control_dependencies(self._build_inference_graph()):
            train_op = state_ops.assign_add(variables_lib.get_global_step(), 1)
        loss_op = constant_op.constant(2.0)
        summary.scalar('loss', loss_op)
        chief_exclusive_monitor = _BaseMonitorWrapper(False)
        all_workers_monitor = _BaseMonitorWrapper(True)
        monitors = [chief_exclusive_monitor, all_workers_monitor]
        loss = learn.graph_actions.train(
            g,
            output_dir=self._output_dir,
            train_op=train_op,
            loss_op=loss_op,
            supervisor_is_chief=True,
            steps=1,
            monitors=monitors)
        self.assertEqual(2.0, loss)
        self.assertTrue(all(m.is_active for m in monitors),
                        'All monitors must have been active.')
        self.assertTrue(all(m.has_step for m in monitors),
                        'All monitors must have a step.')
def testRegressionWithLogitsInput(self):
    """Regression head builds its own logits layer from `logits_input`."""
    head = head_lib._regression_head()
    labels = ((0.,), (1.,), (1.,))
    logits_input = ((0., 0.), (0., 0.), (0., 0.))
    expected_loss = 2. / 3
    with ops.Graph().as_default(), session.Session():
        model_fn_ops = head.create_model_fn_ops(
            {},
            labels=labels,
            mode=model_fn.ModeKeys.TRAIN,
            train_op_fn=_noop_train_op,
            logits_input=logits_input)
        self._assert_output_alternatives(model_fn_ops)
        logits_vars = ("regression_head/logits/weights:0",
                       "regression_head/logits/biases:0")
        _assert_variables(self,
                          expected_global=logits_vars,
                          expected_model=logits_vars,
                          expected_trainable=logits_vars)
        variables.global_variables_initializer().run()
        _assert_summary_tags(self, ["loss"])
        _assert_metrics(self, expected_loss, {"loss": expected_loss},
                        model_fn_ops)
def testRegressionWithCenteredBias(self):
    """Centered bias adds a bias variable, its Adagrad slot and a summary."""
    head = head_lib._regression_head(enable_centered_bias=True)
    labels = ((0.,), (1.,), (1.,))
    logits = ((1.,), (1.,), (3.,))
    expected_loss = 5. / 3
    with ops.Graph().as_default(), session.Session():
        model_fn_ops = head.create_model_fn_ops(
            {},
            labels=labels,
            mode=model_fn.ModeKeys.TRAIN,
            train_op_fn=_noop_train_op,
            logits=logits)
        self._assert_output_alternatives(model_fn_ops)
        bias_var = "regression_head/centered_bias_weight:0"
        _assert_variables(
            self,
            expected_global=(
                bias_var,
                "regression_head/regression_head/centered_bias_weight/Adagrad:0",
            ),
            expected_trainable=(bias_var,))
        variables.global_variables_initializer().run()
        _assert_summary_tags(
            self, ["loss", "regression_head/centered_bias/bias_0"])
        _assert_metrics(self, expected_loss, {"loss": expected_loss},
                        model_fn_ops)