Python tensorflow.python.framework.ops module: Operation() example source code
The following 50 code examples, extracted from open-source Python projects, illustrate how to use tensorflow.python.framework.ops.Operation().
def swap_ts(ts0, ts1, can_modify=None, cannot_modify=None):
  """For each tensor's pair, swap the end of (t0,t1).

  B0 B1     B0 B1
  |  |    =>  X
  A0 A1     A0 A1

  Args:
    ts0: an object convertible to a list of tf.Tensor.
    ts1: an object convertible to a list of tf.Tensor.
    can_modify: iterable of operations which can be modified. Any operation
      outside can_modify will be left untouched by this function.
    cannot_modify: iterable of operations which cannot be modified.
      Any operation within cannot_modify will be left untouched by this
      function.
  Returns:
    The number of individual modifications made by the function.
  Raises:
    TypeError: if ts0 or ts1 cannot be converted to a list of tf.Tensor.
    TypeError: if can_modify or cannot_modify is not None and cannot be
      converted to a list of tf.Operation.
  """
  return _reroute_ts(ts0, ts1, _RerouteMode.swap, can_modify, cannot_modify)
def reroute_a2b_ts(ts0, ts1, can_modify=None, cannot_modify=None):
  """For each tensor's pair, replace the end of t1 by the end of t0.

  B0 B1     B0 B1
  |  |    => |/
  A0 A1     A0 A1

  The ends of the tensors in ts1 are left dangling.

  Args:
    ts0: an object convertible to a list of tf.Tensor.
    ts1: an object convertible to a list of tf.Tensor.
    can_modify: iterable of operations which can be modified. Any operation
      outside can_modify will be left untouched by this function.
    cannot_modify: iterable of operations which cannot be modified. Any
      operation within cannot_modify will be left untouched by this function.
  Returns:
    The number of individual modifications made by the function.
  Raises:
    TypeError: if ts0 or ts1 cannot be converted to a list of tf.Tensor.
    TypeError: if can_modify or cannot_modify is not None and cannot be
      converted to a list of tf.Operation.
  """
  return _reroute_ts(ts0, ts1, _RerouteMode.a2b, can_modify, cannot_modify)
def reroute_b2a_ts(ts0, ts1, can_modify=None, cannot_modify=None):
  r"""For each tensor's pair, replace the end of t0 by the end of t1.

  B0 B1     B0 B1
  |  |    =>  \|
  A0 A1     A0 A1

  The ends of the tensors in ts0 are left dangling.

  Args:
    ts0: an object convertible to a list of tf.Tensor.
    ts1: an object convertible to a list of tf.Tensor.
    can_modify: iterable of operations which can be modified. Any operation
      outside can_modify will be left untouched by this function.
    cannot_modify: iterable of operations which cannot be modified.
      Any operation within cannot_modify will be left untouched by this
      function.
  Returns:
    The number of individual modifications made by the function.
  Raises:
    TypeError: if ts0 or ts1 cannot be converted to a list of tf.Tensor.
    TypeError: if can_modify or cannot_modify is not None and cannot be
      converted to a list of tf.Operation.
  """
  return _reroute_ts(ts0, ts1, _RerouteMode.b2a, can_modify, cannot_modify)
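The three rerouting modes differ only in which consumers end up reading which producer's tensor. As a hedged illustration, the stdlib-only sketch below models this with hypothetical `ToyTensor`/`ToyOp` classes and a hypothetical `reroute` function rather than TensorFlow's `_reroute_ts`. It counts one modification per rewired consumer input, which is an assumption about what the docstrings' "individual modifications" means. It also ignores `can_modify` and `cannot_modify`.

```python
# Toy model of tensor rerouting; ToyTensor, ToyOp and reroute are
# hypothetical stand-ins, not the tf.contrib.graph_editor API.


class ToyTensor:
    """A hypothetical tensor identified only by its name."""

    def __init__(self, name):
        self.name = name


class ToyOp:
    """A hypothetical op whose inputs are a mutable list of ToyTensor."""

    def __init__(self, name, inputs):
        self.name = name
        self.inputs = list(inputs)


def reroute(t0, t1, consumers, mode):
    """Rewire consumer inputs between t0 and t1; return the number of changes.

    "swap": consumers of t0 read t1 and vice versa.
    "a2b":  consumers of t1 read t0 (t1 is left dangling).
    "b2a":  consumers of t0 read t1 (t0 is left dangling).
    """
    mappings = {"swap": {t0: t1, t1: t0}, "a2b": {t1: t0}, "b2a": {t0: t1}}
    if mode not in mappings:
        raise ValueError("unknown mode: {}".format(mode))
    mapping = mappings[mode]
    changes = 0
    for op in consumers:
        for i, t in enumerate(op.inputs):
            if t in mapping:
                op.inputs[i] = mapping[t]  # each input slot is rewired once
                changes += 1
    return changes


a0, a1 = ToyTensor("a0"), ToyTensor("a1")
b0, b1 = ToyOp("b0", [a0]), ToyOp("b1", [a1])
print(reroute(a0, a1, [b0, b1], "a2b"))  # 1: only b1 is rewired
print([t.name for t in b1.inputs])       # ['a0']
```

Swapping uses one mapping that is applied once per input slot, so a consumer of both tensors ends up with its two inputs exchanged rather than rewired twice.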
def remove_control_inputs(op, cops):
  """Remove the control inputs cops from op.

  Warning: this function directly manipulates the internals of the tf.Graph.

  Args:
    op: a tf.Operation from which to remove the control inputs.
    cops: an object convertible to a list of tf.Operation.
  Raises:
    TypeError: if op is not a tf.Operation
    ValueError: if any cop in cops is not a control input of op.
  """
  if not isinstance(op, tf_ops.Operation):
    raise TypeError("Expected a tf.Operation, got: {}".format(type(op)))
  cops = util.make_list_of_op(cops, allow_graph=False)
  for cop in cops:
    if cop not in op.control_inputs:
      raise ValueError("{} is not a control_input of {}".format(cop.name,
                                                                op.name))
  # pylint: disable=protected-access
  op._control_inputs = [cop for cop in op._control_inputs if cop not in cops]
  op._recompute_node_def()
  # pylint: enable=protected-access
def get_consuming_ops(ts):
  """Return all the consuming ops of the tensors in ts.

  Args:
    ts: a list of tf.Tensor
  Returns:
    A list of all the consuming tf.Operation of the tensors in ts.
  Raises:
    TypeError: if ts cannot be converted to a list of tf.Tensor.
  """
  ts = make_list_of_t(ts, allow_graph=False)
  ops = []
  for t in ts:
    for op in t.consumers():
      if op not in ops:
        ops.append(op)
  return ops
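`get_consuming_ops` deduplicates with a list. This preserves first-seen order, but every membership test is linear in the number of ops collected so far. Operations are hashable: the `transformed_ops` dictionary in `_transform_op` already uses them as keys. An order-preserving alternative is therefore `dict.fromkeys`, whose insertion order is guaranteed since Python 3.7. In this sketch, the hypothetical `consumers_of` callable stands in for `Tensor.consumers()` and plain strings stand in for ops.

```python
def unique_consumers(tensors, consumers_of):
    """Return every consumer of `tensors` once, in first-seen order.

    consumers_of: a callable standing in for tf.Tensor.consumers().
    """
    # dict.fromkeys keeps only the first occurrence of each key, in order.
    return list(dict.fromkeys(op for t in tensors for op in consumers_of(t)))


# Toy graph: tensor name -> names of the ops that consume it.
consumers = {"x:0": ["add", "mul"], "y:0": ["mul", "sub"]}
print(unique_consumers(["x:0", "y:0"], consumers.get))  # ['add', 'mul', 'sub']
```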
def get_copied_op(org_instance, graph, scope=""):
  """Given an `Operation` instance from some `Graph`, returns its namesake
  from `graph`, under the specified scope (default `""`).

  If a copy of `org_instance` is present in `graph` under the given
  `scope`, it will be returned.

  Args:
    org_instance: An `Operation` from some `Graph`.
    graph: The `Graph` to be searched for a copy of `org_instance`.
    scope: The scope `org_instance` is present in.
  Returns:
    The `Operation` copy from `graph`.
  """
  # The name of the copied instance
  if scope != '':
    new_name = scope + '/' + org_instance.name
  else:
    new_name = org_instance.name
  return graph.as_graph_element(new_name, allow_tensor=True,
                                allow_operation=True)
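The lookup in `get_copied_op` depends only on name composition. The copy is expected under `scope + '/' + org_instance.name`, or under the original name when no scope is given. The sketch below isolates that rule. The plain dict is a hypothetical stand-in for `Graph.as_graph_element`, which in TensorFlow also resolves tensor names such as `op:0`.

```python
def copied_name(org_name, scope=""):
    """Name under which a copy of `org_name` is expected in the target graph."""
    # An empty (or None) scope means the copy keeps the original name.
    return scope + "/" + org_name if scope else org_name


def get_copied(elements, org_name, scope=""):
    """Look up the copy; `elements` is a dict standing in for a tf.Graph."""
    name = copied_name(org_name, scope)
    if name not in elements:
        raise KeyError("No copy of {!r} under scope {!r}".format(org_name, scope))
    return elements[name]


toy_graph = {"copy/dense/MatMul": "<copied MatMul op>"}
print(copied_name("dense/MatMul", "copy"))            # copy/dense/MatMul
print(get_copied(toy_graph, "dense/MatMul", "copy"))  # <copied MatMul op>
```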
def testKFeatureTrainingConstruction(self):
  # pylint: disable=W0612
  data = constant_op.constant(
      [[random.uniform(-1, 1) for i in range(self.params.num_features)]
       for _ in range(100)])
  labels = [1 for _ in range(100)]
  with variable_scope.variable_scope(
      "KFeatureDecisionsToDataThenNNTest.testKFeatureTrainingConstruction"):
    graph_builder = (
        k_feature_decisions_to_data_then_nn.KFeatureDecisionsToDataThenNN(
            self.params))
    graph = graph_builder.training_graph(data, labels, None)
    self.assertIsInstance(graph, Operation)
def reroute_a2b_ts(ts0, ts1, can_modify=None, cannot_modify=None):
  """For each tensor's pair, replace the end of t1 by the end of t0.

  B0 B1     B0 B1
  |  |    => |/
  A0 A1     A0 A1

  The ends of the tensors in ts1 are left dangling.

  Args:
    ts0: an object convertible to a list of `tf.Tensor`.
    ts1: an object convertible to a list of `tf.Tensor`.
    can_modify: iterable of operations which can be modified. Any operation
      outside can_modify will be left untouched by this function.
    cannot_modify: iterable of operations which cannot be modified. Any
      operation within cannot_modify will be left untouched by this function.
  Returns:
    The number of individual modifications made by the function.
  Raises:
    TypeError: if ts0 or ts1 cannot be converted to a list of tf.Tensor.
    TypeError: if can_modify or cannot_modify is not None and cannot be
      converted to a list of tf.Operation.
  """
  return _reroute_ts(ts0, ts1, _RerouteMode.a2b, can_modify, cannot_modify)
def reroute_b2a_ts(ts0, ts1, can_modify=None, cannot_modify=None):
  r"""For each tensor's pair, replace the end of t0 by the end of t1.

  B0 B1     B0 B1
  |  |    =>  \|
  A0 A1     A0 A1

  The ends of the tensors in ts0 are left dangling.

  Args:
    ts0: an object convertible to a list of `tf.Tensor`.
    ts1: an object convertible to a list of `tf.Tensor`.
    can_modify: iterable of operations which can be modified. Any operation
      outside can_modify will be left untouched by this function.
    cannot_modify: iterable of operations which cannot be modified.
      Any operation within cannot_modify will be left untouched by this
      function.
  Returns:
    The number of individual modifications made by the function.
  Raises:
    TypeError: if ts0 or ts1 cannot be converted to a list of tf.Tensor.
    TypeError: if can_modify or cannot_modify is not None and cannot be
      converted to a list of tf.Operation.
  """
  return _reroute_ts(ts0, ts1, _RerouteMode.b2a, can_modify, cannot_modify)
def remove_control_inputs(op, cops):
  """Remove the control inputs cops from op.

  Warning: this function directly manipulates the internals of the
  `tf.Graph`.

  Args:
    op: a `tf.Operation` from which to remove the control inputs.
    cops: an object convertible to a list of `tf.Operation`.
  Raises:
    TypeError: if op is not a `tf.Operation`.
    ValueError: if any cop in cops is not a control input of op.
  """
  if not isinstance(op, tf_ops.Operation):
    raise TypeError("Expected a tf.Operation, got: {}".format(type(op)))
  cops = util.make_list_of_op(cops, allow_graph=False)
  for cop in cops:
    if cop not in op.control_inputs:
      raise ValueError("{} is not a control_input of {}".format(cop.name,
                                                                op.name))
  # pylint: disable=protected-access
  op._control_inputs = [cop for cop in op._control_inputs if cop not in cops]
  op._recompute_node_def()
  # pylint: enable=protected-access
def add_control_inputs(op, cops):
  """Add the control inputs cops to op.

  Warning: this function directly manipulates the internals of the tf.Graph.

  Args:
    op: a tf.Operation to which the control inputs are added.
    cops: an object convertible to a list of `tf.Operation`.
  Raises:
    TypeError: if op is not a tf.Operation
    ValueError: if any cop in cops is already a control input of op.
  """
  if not isinstance(op, tf_ops.Operation):
    raise TypeError("Expected a tf.Operation, got: {}".format(type(op)))
  cops = util.make_list_of_op(cops, allow_graph=False)
  for cop in cops:
    if cop in op.control_inputs:
      raise ValueError("{} is already a control_input of {}".format(cop.name,
                                                                    op.name))
  # pylint: disable=protected-access
  op._control_inputs += cops
  op._recompute_node_def()
  # pylint: enable=protected-access
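Both `add_control_inputs` and `remove_control_inputs` validate every element of `cops` before mutating anything, so a bad argument leaves `op` unchanged. The stdlib-only sketch below reproduces that validate-then-mutate pattern. `ToyOp` and the two functions are hypothetical stand-ins that operate on a public list. They do not model `_recompute_node_def()`, which these snippets call after the private `_control_inputs` list changes.

```python
class ToyOp:
    """Hypothetical op with a public list of control inputs."""

    def __init__(self, name, control_inputs=()):
        self.name = name
        self.control_inputs = list(control_inputs)


def add_control_inputs(op, cops):
    cops = list(cops)
    for cop in cops:  # validate all of cops before mutating op
        if cop in op.control_inputs:
            raise ValueError(
                "{} is already a control_input of {}".format(cop.name, op.name))
    op.control_inputs += cops


def remove_control_inputs(op, cops):
    cops = list(cops)
    for cop in cops:  # validate all of cops before mutating op
        if cop not in op.control_inputs:
            raise ValueError(
                "{} is not a control_input of {}".format(cop.name, op.name))
    op.control_inputs = [c for c in op.control_inputs if c not in cops]


init, train = ToyOp("init"), ToyOp("train")
add_control_inputs(train, [init])
print([c.name for c in train.control_inputs])  # ['init']
remove_control_inputs(train, [init])
print(train.control_inputs)                    # []
```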
def get_consuming_ops(ts):
  """Return all the consuming ops of the tensors in ts.

  Args:
    ts: a list of `tf.Tensor`
  Returns:
    A list of all the consuming `tf.Operation` of the tensors in `ts`.
  Raises:
    TypeError: if ts cannot be converted to a list of `tf.Tensor`.
  """
  ts = make_list_of_t(ts, allow_graph=False)
  ops = []
  for t in ts:
    for op in t.consumers():
      if op not in ops:
        ops.append(op)
  return ops
def __init__(self, graph):
  """Create a dictionary of control-output dependencies.

  Args:
    graph: a `tf.Graph`.
  Returns:
    A dictionary where a key is a `tf.Operation` instance and the
    corresponding value is a list of all the ops which have the key
    as one of their control-input dependencies.
  Raises:
    TypeError: if graph is not a `tf.Graph`.
  """
  if not isinstance(graph, tf_ops.Graph):
    raise TypeError("Expected a tf.Graph, got: {}".format(type(graph)))
  self._control_outputs = {}
  self._graph = graph
  self._version = None
  self._build()
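The constructor only validates and stores the graph. The dictionary itself is filled in by `_build()`, which is not shown here. Assuming `_build()` simply inverts every op's control inputs, as the docstring describes, the resulting mapping can be sketched with stdlib types. `build_control_outputs` and the `(op_name, control_input_names)` pairs are hypothetical.

```python
from collections import defaultdict


def build_control_outputs(ops):
    """Invert control dependencies.

    ops: iterable of (op_name, [control_input_names]) pairs (hypothetical
    stand-in for iterating over graph.get_operations()).
    Returns a dict mapping each control input to the ops that depend on it.
    """
    outputs = defaultdict(list)
    for op_name, control_inputs in ops:
        for cop in control_inputs:
            if op_name not in outputs[cop]:
                outputs[cop].append(op_name)
    return dict(outputs)


toy_ops = [("init", []), ("train", ["init"]), ("eval", ["init", "train"])]
print(build_control_outputs(toy_ops))
# {'init': ['train', 'eval'], 'train': ['eval']}
```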
def assign_renamed_collections_handler(info, elem, elem_):
  """Add the transformed elem to the (renamed) collections of elem.

  A collection is renamed only if it is not a known key, as described in
  `tf.GraphKeys`.

  Args:
    info: Transform._TmpInfo instance.
    elem: the original element (`tf.Tensor` or `tf.Operation`)
    elem_: the transformed element
  """
  known_collection_names = util.get_predefined_collection_names()
  for name, collection in iteritems(info.collections):
    if elem not in collection:
      continue

    if name in known_collection_names:
      transformed_name = name
    else:
      transformed_name = info.new_name(name)
    info.graph_.add_to_collection(transformed_name, elem_)
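The handler keeps predefined collection names unchanged and passes every other name through `info.new_name`. A stdlib-only sketch of that rule follows. `KNOWN` and `new_name` are hypothetical stand-ins for `util.get_predefined_collection_names()` and the transform's renaming function. The two names in `KNOWN` are the string values of `tf.GraphKeys.GLOBAL_VARIABLES` and `tf.GraphKeys.TRAINABLE_VARIABLES`, not the full set.

```python
KNOWN = {"variables", "trainable_variables"}  # hypothetical subset of keys


def new_name(name, scope="copy"):
    """Hypothetical renaming rule: prefix the name with a scope."""
    return scope + "/" + name


def renamed_collections(collections, elem, known=KNOWN):
    """Return the collection names that should receive the copy of elem."""
    result = []
    for name, members in collections.items():
        if elem not in members:
            continue
        # Known keys keep their name; any other collection is renamed.
        result.append(name if name in known else new_name(name))
    return result


cols = {"variables": ["w"], "my_losses": ["w", "b"], "other": ["b"]}
print(renamed_collections(cols, "w"))  # ['variables', 'copy/my_losses']
```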
def swap_ts(ts0, ts1, can_modify=None, cannot_modify=None):
  """For each tensor's pair, swap the end of (t0,t1).

  B0 B1     B0 B1
  |  |    =>  X
  A0 A1     A0 A1

  Args:
    ts0: an object convertible to a list of `tf.Tensor`.
    ts1: an object convertible to a list of `tf.Tensor`.
    can_modify: iterable of operations which can be modified. Any operation
      outside can_modify will be left untouched by this function.
    cannot_modify: iterable of operations which cannot be modified.
      Any operation within cannot_modify will be left untouched by this
      function.
  Returns:
    The number of individual modifications made by the function.
  Raises:
    TypeError: if ts0 or ts1 cannot be converted to a list of tf.Tensor.
    TypeError: if can_modify or cannot_modify is not None and cannot be
      converted to a list of tf.Operation.
  """
  return _reroute_ts(ts0, ts1, _RerouteMode.swap, can_modify, cannot_modify)
def reroute_ts(ts0, ts1, can_modify=None, cannot_modify=None):
  """For each tensor's pair, replace the end of t1 by the end of t0.

  B0 B1     B0 B1
  |  |    => |/
  A0 A1     A0 A1

  The ends of the tensors in ts1 are left dangling.

  Args:
    ts0: an object convertible to a list of `tf.Tensor`.
    ts1: an object convertible to a list of `tf.Tensor`.
    can_modify: iterable of operations which can be modified. Any operation
      outside can_modify will be left untouched by this function.
    cannot_modify: iterable of operations which cannot be modified. Any
      operation within cannot_modify will be left untouched by this function.
  Returns:
    The number of individual modifications made by the function.
  Raises:
    TypeError: if ts0 or ts1 cannot be converted to a list of tf.Tensor.
    TypeError: if can_modify or cannot_modify is not None and cannot be
      converted to a list of tf.Operation.
  """
  return _reroute_ts(ts0, ts1, _RerouteMode.a2b, can_modify, cannot_modify)
def build_results(self, values):
  if not values:
    # 'Operation' case
    return None
  else:
    return self._contraction_fn(values)
def __init__(self, graph, fetches, feeds, feed_handles=None):
  """Creates a fetch handler.

  Args:
    graph: Graph of the fetches. Used to check for fetchability
      and to convert all fetches to tensors or ops as needed.
    fetches: An arbitrary fetch structure: singleton, list, tuple,
      namedtuple, or dict.
    feeds: A feed dict where keys are Tensors.
    feed_handles: A dict from feed Tensors to TensorHandle objects used as
      direct feeds.
  """
  with graph.as_default():
    self._fetch_mapper = _FetchMapper.for_fetch(fetches)
  self._fetches = []
  self._targets = []
  self._feeds = feeds
  self._feed_handles = feed_handles or {}
  self._ops = []
  self._fetch_handles = {}
  for fetch in self._fetch_mapper.unique_fetches():
    if isinstance(fetch, ops.Operation):
      self._assert_fetchable(graph, fetch)
      self._targets.append(fetch)
      self._ops.append(True)
    else:
      self._assert_fetchable(graph, fetch.op)
      self._fetches.append(fetch)
      self._ops.append(False)
    # Remember the fetch if it is for a tensor handle.
    if (isinstance(fetch, ops.Tensor) and
        (fetch.op.type == 'GetSessionHandle' or
         fetch.op.type == 'GetSessionHandleV2')):
      self._fetch_handles[fetch] = fetch.op.inputs[0].dtype
  self._final_fetches = [x for x in self._fetches if x not in feeds]
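The constructor sorts each unique fetch into one of two lists. Operations become targets, which are run for their side effects and yield no value (compare the `'Operation'` case in `build_results`). Tensors become value fetches. `_ops` records which kind each position was, and `_final_fetches` drops tensors that are also fed. A stdlib-only sketch of that bookkeeping follows. `ToyOperation`, `ToyTensor`, and `split_fetches` are hypothetical, and the fetchability checks and session-handle tracking are omitted.

```python
class ToyOperation:
    """Hypothetical stand-in for ops.Operation."""

    def __init__(self, name):
        self.name = name


class ToyTensor:
    """Hypothetical stand-in for ops.Tensor."""

    def __init__(self, name):
        self.name = name


def split_fetches(fetches, feeds):
    """Return (targets, tensor_fetches, is_op, final_fetches)."""
    targets, tensor_fetches, is_op = [], [], []
    for fetch in fetches:
        if isinstance(fetch, ToyOperation):
            targets.append(fetch)         # run only, no value returned
            is_op.append(True)
        else:
            tensor_fetches.append(fetch)  # a value will be returned
            is_op.append(False)
    # Fed tensors are not requested from the runtime again.
    final = [t for t in tensor_fetches if t not in feeds]
    return targets, tensor_fetches, is_op, final


train_op, loss, x = ToyOperation("train"), ToyTensor("loss:0"), ToyTensor("x:0")
targets, fetched, is_op, final = split_fetches([train_op, loss, x], feeds={x: 1.0})
print(is_op)                    # [True, False, False]
print([t.name for t in final])  # ['loss:0']
```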
def _assert_fetchable(self, graph, op):
  if not graph.is_fetchable(op):
    raise ValueError(
        'Operation %r has been marked as not fetchable.' % op.name)
def _finalize_positive_filter(self, elem):
  """Convert to a filter function."""
  if select.can_be_regex(elem):
    regex_ = select.make_regex(elem)
    return lambda op, regex=regex_: regex.search(op.name) is not None
  elif isinstance(elem, tf_ops.Operation):
    return lambda op, match_op=elem: op is match_op
  elif callable(elem):
    return elem
  elif elem is True:
    return lambda op: True
  else:
    raise ValueError("Cannot finalize the positive filter: {}".format(elem))
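`_finalize_positive_filter` normalizes four kinds of spec into a single predicate on ops: a regex-like value, a specific `tf.Operation`, an arbitrary callable, or `True`. The stdlib-only sketch below assumes that `select.can_be_regex` accepts strings and compiled patterns, and it uses `re.compile` in place of `select.make_regex`. `ToyOp` is hypothetical.

```python
import re

_PATTERN_TYPE = type(re.compile(""))  # compiled-pattern type, any Python 3


class ToyOp:
    """Hypothetical stand-in for tf.Operation."""

    def __init__(self, name):
        self.name = name


def finalize_positive_filter(elem):
    """Turn a filter spec into a predicate op -> bool."""
    if isinstance(elem, (str, _PATTERN_TYPE)):
        regex = re.compile(elem)  # a compiled pattern is returned unchanged
        return lambda op: regex.search(op.name) is not None
    if isinstance(elem, ToyOp):
        return lambda op: op is elem  # identity, not name equality
    if callable(elem):
        return elem
    if elem is True:
        return lambda op: True
    raise ValueError("Cannot finalize the positive filter: {}".format(elem))


matmul, add = ToyOp("dense/MatMul"), ToyOp("dense/add")
f = finalize_positive_filter(r"MatMul$")
print([op.name for op in (matmul, add) if f(op)])  # ['dense/MatMul']
```

Matching a specific op uses identity (`is`), so a different op that happens to share the same name does not match.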
def __call__(self, op):
  """Evaluate if the op matches or not."""
  if not isinstance(op, tf_ops.Operation):
    raise TypeError("Expected a tf.Operation, got: {}".format(type(op)))
  for positive_filter in self.positive_filters:
    if not positive_filter(op):
      return False
  if self.input_op_matches is not None:
    if len(op.inputs) != len(self.input_op_matches):
      return False
    for input_t, input_op_match in zip(op.inputs, self.input_op_matches):
      if input_op_match is None:
        continue
      if not input_op_match(input_t.op):
        return False
  if self.control_input_op_matches is not None:
    if len(op.control_inputs) != len(self.control_input_op_matches):
      return False
    for cinput_op, cinput_op_match in zip(op.control_inputs,
                                          self.control_input_op_matches):
      if cinput_op_match is None:
        continue
      if not cinput_op_match(cinput_op):
        return False
  if self.output_op_matches is not None:
    if len(op.outputs) != len(self.output_op_matches):
      return False
    for output_t, output_op_matches in zip(op.outputs,
                                           self.output_op_matches):
      if output_op_matches is None:
        continue
      if len(output_t.consumers()) != len(output_op_matches):
        return False
      for consumer_op, consumer_op_match in zip(output_t.consumers(),
                                                output_op_matches):
        if consumer_op_match is None:
          continue
        if not consumer_op_match(consumer_op):
          return False
  return True
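`__call__` accepts an op only if every positive filter passes. In addition, for each optional list of sub-matchers (inputs, control inputs, output consumers), the lengths must agree and every non-`None` sub-matcher must accept the op in the same position. The reduced sketch below covers positive filters and input matching only. `ToyOp` and `match` are hypothetical, and `input_ops` stands in for `[t.op for t in op.inputs]`.

```python
class ToyOp:
    """Hypothetical op that records the producers of its inputs."""

    def __init__(self, name, type_, input_ops=()):
        self.name = name
        self.type = type_
        self.input_ops = list(input_ops)


def match(op, positive_filters, input_op_matches=None):
    """Return True if op passes all filters and its inputs match positionally."""
    if not all(f(op) for f in positive_filters):
        return False
    if input_op_matches is not None:
        if len(op.input_ops) != len(input_op_matches):
            return False
        for input_op, sub in zip(op.input_ops, input_op_matches):
            if sub is not None and not sub(input_op):  # None is a wildcard
                return False
    return True


def is_matmul(op):
    return op.type == "MatMul"


x = ToyOp("x", "Placeholder")
w = ToyOp("w", "VariableV2")
mm = ToyOp("mm", "MatMul", [x, w])
print(match(mm, [is_matmul], [None, lambda op: op.type == "VariableV2"]))  # True
```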
def assign_renamed_collections_handler(info, elem, elem_):
  """Add the transformed elem to the (renamed) collections of elem.

  Args:
    info: Transform._Info instance.
    elem: the original element (tf.Tensor or tf.Operation)
    elem_: the transformed element
  """
  # TODO(fkp): handle known special cases
  for name, collection in iteritems(
      elem.graph._collections):  # pylint: disable=protected-access
    if elem not in collection:
      continue
    collection_name_ = info.transformer.new_name(name)
    info.graph_.add_to_collection(collection_name_, elem_)
def _get_transformed_map(self, top):
  """Return the correct container depending on the type of `top`."""
  if isinstance(top, tf_ops.Operation):
    return self._transformed_ops
  elif isinstance(top, tf_ops.Tensor):
    return self._transformed_ts
  else:
    raise TypeError(
        "Expected a tf.Tensor or a tf.Operation, got a {}".format(
            type(top)))
def _transform_op(self, op):
  """Transform a tf.Operation.

  Args:
    op: the operation to be transformed.
  Returns:
    The transformed operation.
  """
  if op in self._info.transformed_ops:
    return self._info.transformed_ops[op]

  op_ = self.transform_op_handler(self._info, op)

  # Add to all the active control dependencies
  # pylint: disable=protected-access
  self._info.graph_._record_op_seen_by_control_dependencies(op_)

  # Add to all the active devices
  for device_function in reversed(self._info.graph_._device_function_stack):
    if device_function is None:
      break
    op_._set_device(device_function(op_))
  # pylint: enable=protected-access

  # TODO(fkp): Establish clear policy about what context managers are allowed.

  # assign to collection
  if op is not op_:
    self.assign_collections_handler(self._info, op, op_)

  self._info.transformed_ops[op] = op_
  return op_
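`_transform_op` memoizes. Each op is transformed at most once, and later requests return the cached copy from `transformed_ops`. The minimal sketch below isolates that caching pattern. `MemoTransformer` and `copy_handler` are hypothetical, and the control-dependency, device, and collection bookkeeping is omitted.

```python
class MemoTransformer:
    """Hypothetical transformer that caches one result per original op."""

    def __init__(self, handler):
        self.handler = handler
        self.transformed_ops = {}  # original op -> transformed op
        self.calls = 0             # how many times the handler actually ran

    def transform_op(self, op):
        if op in self.transformed_ops:
            return self.transformed_ops[op]
        self.calls += 1
        op_ = self.handler(op)
        self.transformed_ops[op] = op_
        return op_


def copy_handler(op):
    """Hypothetical handler: "copy" an op by prefixing its name."""
    return "copy/" + op


t = MemoTransformer(copy_handler)
print(t.transform_op("MatMul"), t.transform_op("MatMul"), t.calls)
# copy/MatMul copy/MatMul 1
```

Caching matters in a graph transform because an op reachable along several paths (for example, a shared input) must map to a single copy. Otherwise the transformed graph would contain duplicates.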
def get_unique_graph(tops, check_types=None, none_if_empty=False):
  """Return the unique graph used by all the elements in tops.

  Args:
    tops: list of elements to check (usually a list of tf.Operation and/or
      tf.Tensor). Or a tf.Graph.
    check_types: check that the elements in tops are of the given type(s). If
      None, the types (tf.Operation, tf.Tensor) are used.
    none_if_empty: don't raise an error if tops is an empty list, just return
      None.
  Returns:
    The unique graph used by all the tops.
  Raises:
    TypeError: if tops is not an iterable of tf.Operation.
    ValueError: if the graph is not unique.
  """
  if isinstance(tops, tf_ops.Graph):
    return tops
  if not is_iterable(tops):
    raise TypeError("{} is not iterable".format(type(tops)))
  if check_types is None:
    check_types = (tf_ops.Operation, tf_ops.Tensor)
  elif not is_iterable(check_types):
    check_types = (check_types,)
  g = None
  for op in tops:
    if not isinstance(op, check_types):
      raise TypeError("Expected a type in ({}), got: {}".format(
          ", ".join(str(t) for t in check_types), type(op)))
    if g is None:
      g = op.graph
    elif g is not op.graph:
      raise ValueError("Operation {} does not belong to given graph".format(op))
  if g is None and not none_if_empty:
    raise ValueError("Can't find the unique graph of an empty list")
  return g
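`get_unique_graph` walks the elements once, checks each element's type, and compares every element's `graph` by identity against the first one found. The stdlib-only sketch below runs the same check on hypothetical `ToyGraph`/`ToyElem` objects. `check_types` defaults to `ToyElem` here rather than `(tf.Operation, tf.Tensor)`.

```python
class ToyGraph:
    """Hypothetical stand-in for tf.Graph."""


class ToyElem:
    """Hypothetical op or tensor that knows which graph it belongs to."""

    def __init__(self, name, graph):
        self.name = name
        self.graph = graph


def unique_graph(tops, check_types=(ToyElem,), none_if_empty=False):
    g = None
    for top in tops:
        if not isinstance(top, check_types):
            raise TypeError("Expected one of {}, got {}".format(
                check_types, type(top)))
        if g is None:
            g = top.graph
        elif g is not top.graph:  # identity, not equality
            raise ValueError("{} does not belong to the same graph".format(
                top.name))
    if g is None and not none_if_empty:
        raise ValueError("Can't find the unique graph of an empty list")
    return g


g1, g2 = ToyGraph(), ToyGraph()
print(unique_graph([ToyElem("a", g1), ToyElem("b", g1)]) is g1)  # True
print(unique_graph([], none_if_empty=True))                      # None
```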
def make_list_of_op(ops, check_graph=True, allow_graph=True, ignore_ts=False):
  """Convert ops to a list of tf.Operation.

  Args:
    ops: can be an iterable of tf.Operation, a tf.Graph or a single operation.
    check_graph: if True check if all the operations belong to the same graph.
    allow_graph: if False a tf.Graph cannot be converted.
    ignore_ts: if True, silently ignore tf.Tensor.
  Returns:
    A newly created list of tf.Operation.
  Raises:
    TypeError: if ops cannot be converted to a list of tf.Operation.
    ValueError: if check_graph is True and the ops do not all belong to the
      same graph.
  """
  if isinstance(ops, tf_ops.Graph):
    if allow_graph:
      return ops.get_operations()
    else:
      raise TypeError("allow_graph is False: cannot convert a tf.Graph.")
  else:
    if not is_iterable(ops):
      ops = [ops]
    if not ops:
      return []
    if check_graph:
      check_types = None if ignore_ts else tf_ops.Operation
      get_unique_graph(ops, check_types=check_types)
    return [op for op in ops if isinstance(op, tf_ops.Operation)]
# TODO(fkp): move this function in tf.Graph?
def make_list_of_t(ts, check_graph=True, allow_graph=True, ignore_ops=False):
  """Convert ts to a list of tf.Tensor.

  Args:
    ts: can be an iterable of tf.Tensor, a tf.Graph or a single tensor.
    check_graph: if True check if all the tensors belong to the same graph.
    allow_graph: if False a tf.Graph cannot be converted.
    ignore_ops: if True, silently ignore tf.Operation.
  Returns:
    A newly created list of tf.Tensor.
  Raises:
    TypeError: if ts cannot be converted to a list of tf.Tensor.
    ValueError: if check_graph is True and the tensors do not all belong to
      the same graph.
  """
  if isinstance(ts, tf_ops.Graph):
    if allow_graph:
      return get_tensors(ts)
    else:
      raise TypeError("allow_graph is False: cannot convert a tf.Graph.")
  else:
    if not is_iterable(ts):
      ts = [ts]
    if not ts:
      return []
    if check_graph:
      check_types = None if ignore_ops else tf_ops.Tensor
      get_unique_graph(ts, check_types=check_types)
    return [t for t in ts if isinstance(t, tf_ops.Tensor)]
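`make_list_of_op` and `make_list_of_t` share one normalization recipe. Each expands a graph (if allowed), wraps a single element in a list, optionally verifies a common graph, and then filters by type. The stdlib-only sketch below shows only the wrap-and-filter steps, which corresponds to the `ignore_ts=True`/`ignore_ops=True` path where mismatched types are silently dropped. `ToyOp`, `ToyTensor`, and `make_list_of` are hypothetical. The `is_iterable` helper is written here as an assumption about what the one these snippets call does.

```python
class ToyOp:
    """Hypothetical stand-in for tf.Operation."""

    def __init__(self, name):
        self.name = name


class ToyTensor:
    """Hypothetical stand-in for tf.Tensor."""

    def __init__(self, name):
        self.name = name


def is_iterable(obj):
    """Assumed behavior: anything accepted by iter() counts as iterable."""
    try:
        iter(obj)
    except TypeError:
        return False
    return True


def make_list_of(items, kind):
    """Normalize a single item or an iterable into a new list of `kind`."""
    if not is_iterable(items):
        items = [items]
    return [x for x in items if isinstance(x, kind)]


op, t = ToyOp("add"), ToyTensor("add:0")
print([x.name for x in make_list_of(op, ToyOp)])           # ['add']
print([x.name for x in make_list_of([op, t], ToyTensor)])  # ['add:0']
```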
def get_generating_ops(ts):
  """Return all the generating ops of the tensors in ts.

  Args:
    ts: a list of tf.Tensor
  Returns:
    A list of all the generating tf.Operation of the tensors in ts.
  Raises:
    TypeError: if ts cannot be converted to a list of tf.Tensor.
  """
  ts = make_list_of_t(ts, allow_graph=False)
  return [t.op for t in ts]
def testTrainingConstruction(self):
  # pylint: disable=W0612
  data = constant_op.constant(
      [[random.uniform(-1, 1) for i in range(self.params.num_features)]
       for _ in range(100)])
  labels = [1 for _ in range(100)]
  with variable_scope.variable_scope(
      "DecisionsToDataThenNNTest_testTrainingConstruction"):
    graph_builder = decisions_to_data_then_nn.DecisionsToDataThenNN(
        self.params)
    graph = graph_builder.training_graph(data, labels, None)
    self.assertIsInstance(graph, Operation)
def testTrainingConstruction(self):
  # pylint: disable=W0612
  data = constant_op.constant(
      [[random.uniform(-1, 1) for i in range(self.params.num_features)]
       for _ in range(100)])
  labels = [1 for _ in range(100)]
  with variable_scope.variable_scope(
      "ForestToDataThenNNTest.testTrainingConstruction"):
    graph_builder = forest_to_data_then_nn.ForestToDataThenNN(self.params)
    graph = graph_builder.training_graph(data, labels, None)
    self.assertIsInstance(graph, Operation)
def assign_renamed_collections_handler(info, elem, elem_):
  """Add the transformed elem to the (renamed) collections of elem.

  Args:
    info: Transform._Info instance.
    elem: the original element (`tf.Tensor` or `tf.Operation`)
    elem_: the transformed element
  """
  # TODO(fkp): handle known special cases
  for name, collection in iteritems(
      elem.graph._collections):  # pylint: disable=protected-access
    if elem not in collection:
      continue
    collection_name_ = info.transformer.new_name(name)
    info.graph_.add_to_collection(collection_name_, elem_)
def __init__(self):
  """Transformer constructor.

  The following members can be modified:
  transform_op_handler: handle the transformation of a `tf.Operation`.
    This handler defaults to a simple copy.
  assign_collections_handler: handle the assignment of collections.
    This handler defaults to assigning new collections created under the
    given name-scope.
  transform_external_input_handler: handle the transform of the inputs to
    the given subgraph. This handler defaults to creating placeholders
    instead of the ops just before the input tensors of the subgraph.
  transform_external_hidden_input_handler: handle the transform of the
    hidden inputs of the subgraph, that is, the inputs which are not listed
    in sgv.inputs. This handler defaults to a transform which keeps the same
    input if the source and destination graphs are the same, and otherwise
    uses placeholders.
  transform_original_op_handler: handle the transform of original_op. This
    handler defaults to transforming original_op only if it is in the
    subgraph; otherwise it is ignored.
  """

  # handlers
  self.transform_op_handler = copy_op_handler
  self.transform_control_input_handler = transform_op_if_inside_handler
  self.assign_collections_handler = assign_renamed_collections_handler
  self.transform_external_input_handler = replace_t_with_placeholder_handler
  self.transform_external_hidden_input_handler = keep_t_if_possible_handler
  self.transform_original_op_handler = transform_op_if_inside_handler

  # temporary per-call variable
  self._info = None
def make_list_of_t(ts, check_graph=True, allow_graph=True, ignore_ops=False):
  """Convert ts to a list of `tf.Tensor`.

  Args:
    ts: can be an iterable of `tf.Tensor`, a `tf.Graph` or a single tensor.
    check_graph: if `True` check if all the tensors belong to the same graph.
    allow_graph: if `False` a `tf.Graph` cannot be converted.
    ignore_ops: if `True`, silently ignore `tf.Operation`.
  Returns:
    A newly created list of `tf.Tensor`.
  Raises:
    TypeError: if `ts` cannot be converted to a list of `tf.Tensor`.
    ValueError: if `check_graph` is `True` and the tensors do not all belong
      to the same graph.
  """
  if isinstance(ts, tf_ops.Graph):
    if allow_graph:
      return get_tensors(ts)
    else:
      raise TypeError("allow_graph is False: cannot convert a tf.Graph.")
  else:
    if not is_iterable(ts):
      ts = [ts]
    if not ts:
      return []
    if check_graph:
      check_types = None if ignore_ops else tf_ops.Tensor
      get_unique_graph(ts, check_types=check_types)
    return [t for t in ts if isinstance(t, tf_ops.Tensor)]
def get_generating_ops(ts):
  """Return all the generating ops of the tensors in `ts`.

  Args:
    ts: a list of `tf.Tensor`
  Returns:
    A list of all the generating `tf.Operation` of the tensors in `ts`.
  Raises:
    TypeError: if `ts` cannot be converted to a list of `tf.Tensor`.
  """
  ts = make_list_of_t(ts, allow_graph=False)
  return [t.op for t in ts]
def op(self):
  """Method for compatibility with Tensor."""
  node_def = graph_pb2.NodeDef()
  node_def.name = "imperative-dummy-node"
  node_def.input.extend(["dummy1", "dummy2", "dummy3"])

  dummy_input1 = array_ops.placeholder(self.dtype)
  dummy_input2 = array_ops.placeholder(self.dtype)
  dummy_input3 = array_ops.placeholder(self.dtype)
  dummy_op = tf_ops.Operation(node_def, tf_ops.Graph(), inputs=[dummy_input1,
                                                                dummy_input2,
                                                                dummy_input3])

  return dummy_op
def _check_is_tensor_or_operation(x, name):
  if not isinstance(x, (ops.Operation, ops.Tensor)):
    raise TypeError('{} must be Operation or Tensor, given: {}'.format(name, x))