def swap_ts(ts0, ts1, can_modify=None, cannot_modify=None):
  """For each tensor's pair, swap the end of (t0,t1).

      B0 B1     B0 B1
      |  |    =>  X
      A0 A1     A0 A1

  Args:
    ts0: an object convertible to a list of tf.Tensor.
    ts1: an object convertible to a list of tf.Tensor.
    can_modify: iterable of operations which can be modified. Any operation
      outside can_modify will be left untouched by this function.
    cannot_modify: iterable of operations which cannot be modified. Any
      operation within cannot_modify will be left untouched by this function.
  Returns:
    The number of individual modifications made by the function.
  Raises:
    TypeError: if ts0 or ts1 cannot be converted to a list of tf.Tensor.
    TypeError: if can_modify or cannot_modify is not None and cannot be
      converted to a list of tf.Operation.
  """
  return _reroute_ts(ts0, ts1, _RerouteMode.swap, can_modify, cannot_modify)
def reroute_a2b_ts(ts0, ts1, can_modify=None, cannot_modify=None):
  """For each tensor's pair, replace the end of t1 by the end of t0.

      B0 B1     B0 B1
      |  |    =>  |/
      A0 A1     A0 A1

  The ends of the tensors in ts1 are left dangling.

  Args:
    ts0: an object convertible to a list of tf.Tensor.
    ts1: an object convertible to a list of tf.Tensor.
    can_modify: iterable of operations which can be modified. Any operation
      outside can_modify will be left untouched by this function.
    cannot_modify: iterable of operations which cannot be modified. Any
      operation within cannot_modify will be left untouched by this function.
  Returns:
    The number of individual modifications made by the function.
  Raises:
    TypeError: if ts0 or ts1 cannot be converted to a list of tf.Tensor.
    TypeError: if can_modify or cannot_modify is not None and cannot be
      converted to a list of tf.Operation.
  """
  return _reroute_ts(ts0, ts1, _RerouteMode.a2b, can_modify, cannot_modify)
def reroute_b2a_ts(ts0, ts1, can_modify=None, cannot_modify=None):
  r"""For each tensor's pair, replace the end of t0 by the end of t1.

      B0 B1     B0 B1
      |  |    =>   \|
      A0 A1     A0 A1

  The ends of the tensors in ts0 are left dangling.

  Args:
    ts0: an object convertible to a list of tf.Tensor.
    ts1: an object convertible to a list of tf.Tensor.
    can_modify: iterable of operations which can be modified. Any operation
      outside can_modify will be left untouched by this function.
    cannot_modify: iterable of operations which cannot be modified. Any
      operation within cannot_modify will be left untouched by this function.
  Returns:
    The number of individual modifications made by the function.
  Raises:
    TypeError: if ts0 or ts1 cannot be converted to a list of tf.Tensor.
    TypeError: if can_modify or cannot_modify is not None and cannot be
      converted to a list of tf.Operation.
  """
  return _reroute_ts(ts0, ts1, _RerouteMode.b2a, can_modify, cannot_modify)
def remove_control_inputs(op, cops):
  """Remove the control inputs cops from op.

  Warning: this function is directly manipulating the internals of the
  tf.Graph.

  All of cops are validated before op is modified, so on error op is left
  unchanged.

  Args:
    op: a tf.Operation from which to remove the control inputs.
    cops: an object convertible to a list of tf.Operation.
  Raises:
    TypeError: if op is not a tf.Operation
    ValueError: if any cop in cops is not a control input of op.
  """
  if not isinstance(op, tf_ops.Operation):
    # Format the message: passing type(op) as a second exception argument
    # would leave a literal "{}" in the error text.
    raise TypeError("Expected a tf.Operation, got: {}".format(type(op)))
  cops = util.make_list_of_op(cops, allow_graph=False)
  for cop in cops:
    if cop not in op.control_inputs:
      # The offending control input is cop; op is the operation lacking it.
      raise ValueError("{} is not a control_input of {}".format(cop.name,
                                                                op.name))
  # pylint: disable=protected-access
  op._control_inputs = [cop for cop in op._control_inputs if cop not in cops]
  op._recompute_node_def()
  # pylint: enable=protected-access
def get_consuming_ops(ts):
  """Return all the consuming ops of the tensors in ts.

  Each consumer appears once, in first-seen order.

  Args:
    ts: a list of tf.Tensor
  Returns:
    A list of all the consuming tf.Operation of the tensors in ts.
  Raises:
    TypeError: if ts cannot be converted to a list of tf.Tensor.
  """
  ts = make_list_of_t(ts, allow_graph=False)
  ops = []
  # A set gives O(1) de-duplication instead of scanning the result list.
  seen = set()
  for t in ts:
    for op in t.consumers():
      if op not in seen:
        seen.add(op)
        ops.append(op)
  return ops
def get_copied_op(org_instance, graph, scope=""):
  """Given an `Operation` instance from some `Graph`, returns
  its namesake from `graph`, under the specified scope
  (default `""`).

  If a copy of `org_instance` is present in `graph` under the given
  `scope`, it will be returned.

  Args:
    org_instance: An `Operation` from some `Graph`.
    graph: The `Graph` to be searched for a copy of `org_instance`.
    scope: The scope `org_instance` is present in. An empty string (or
      `None`) means the name is looked up without any scope prefix.

  Returns:
    The `Operation` copy from `graph`.
  """
  # The name of the copied instance: prefix it with the scope, if any.
  # A falsy scope (including None) previously crashed on `None + '/'`.
  if scope:
    new_name = scope + '/' + org_instance.name
  else:
    new_name = org_instance.name
  return graph.as_graph_element(new_name, allow_tensor=True,
                                allow_operation=True)
def testKFeatureTrainingConstruction(self):
  """Builds the K-feature training graph and checks it is an Operation."""
  data = constant_op.constant(
      [[random.uniform(-1, 1) for _ in range(self.params.num_features)]
       for _ in range(100)])
  labels = [1] * 100

  with variable_scope.variable_scope(
      "KFeatureDecisionsToDataThenNNTest.testKFeatureTrainingContruction"):
    graph_builder = (
        k_feature_decisions_to_data_then_nn.KFeatureDecisionsToDataThenNN(
            self.params))
    graph = graph_builder.training_graph(data, labels, None)

    # assertIsInstance reports the actual type on failure.
    self.assertIsInstance(graph, Operation)
def reroute_a2b_ts(ts0, ts1, can_modify=None, cannot_modify=None):
  """For each tensor's pair, replace the end of t1 by the end of t0.

      B0 B1     B0 B1
      |  |    =>  |/
      A0 A1     A0 A1

  The ends of the tensors in ts1 are left dangling.

  Args:
    ts0: an object convertible to a list of `tf.Tensor`.
    ts1: an object convertible to a list of `tf.Tensor`.
    can_modify: iterable of operations which can be modified. Any operation
      outside `can_modify` will be left untouched by this function.
    cannot_modify: iterable of operations which cannot be modified. Any
      operation within `cannot_modify` will be left untouched by this
      function.
  Returns:
    The number of individual modifications made by the function.
  Raises:
    TypeError: if ts0 or ts1 cannot be converted to a list of tf.Tensor.
    TypeError: if can_modify or cannot_modify is not None and cannot be
      converted to a list of tf.Operation.
  """
  return _reroute_ts(ts0, ts1, _RerouteMode.a2b, can_modify, cannot_modify)
def reroute_b2a_ts(ts0, ts1, can_modify=None, cannot_modify=None):
  r"""For each tensor's pair, replace the end of t0 by the end of t1.

      B0 B1     B0 B1
      |  |    =>   \|
      A0 A1     A0 A1

  The ends of the tensors in ts0 are left dangling.

  Args:
    ts0: an object convertible to a list of `tf.Tensor`.
    ts1: an object convertible to a list of `tf.Tensor`.
    can_modify: iterable of operations which can be modified. Any operation
      outside `can_modify` will be left untouched by this function.
    cannot_modify: iterable of operations which cannot be modified. Any
      operation within `cannot_modify` will be left untouched by this
      function.
  Returns:
    The number of individual modifications made by the function.
  Raises:
    TypeError: if ts0 or ts1 cannot be converted to a list of tf.Tensor.
    TypeError: if can_modify or cannot_modify is not None and cannot be
      converted to a list of tf.Operation.
  """
  return _reroute_ts(ts0, ts1, _RerouteMode.b2a, can_modify, cannot_modify)
def remove_control_inputs(op, cops):
  """Remove the control inputs cops from op.

  Warning: this function is directly manipulating the internals of the
  `tf.Graph`.

  All of `cops` are validated before `op` is modified, so on error `op` is
  left unchanged.

  Args:
    op: a `tf.Operation` from which to remove the control inputs.
    cops: an object convertible to a list of `tf.Operation`.
  Raises:
    TypeError: if op is not a `tf.Operation`.
    ValueError: if any cop in cops is not a control input of op.
  """
  if not isinstance(op, tf_ops.Operation):
    # Format the message: passing type(op) as a second exception argument
    # would leave a literal "{}" in the error text.
    raise TypeError("Expected a tf.Operation, got: {}".format(type(op)))
  cops = util.make_list_of_op(cops, allow_graph=False)
  for cop in cops:
    if cop not in op.control_inputs:
      # The offending control input is cop; op is the operation lacking it.
      raise ValueError("{} is not a control_input of {}".format(cop.name,
                                                                op.name))
  # pylint: disable=protected-access
  op._control_inputs = [cop for cop in op._control_inputs if cop not in cops]
  op._recompute_node_def()
  # pylint: enable=protected-access
def add_control_inputs(op, cops):
  """Add the control inputs cops to op.

  Warning: this function is directly manipulating the internals of the
  tf.Graph.

  All of cops are validated before op is modified, so on error op is left
  unchanged.

  Args:
    op: a tf.Operation to which the control inputs are added.
    cops: an object convertible to a list of `tf.Operation`.
  Raises:
    TypeError: if op is not a tf.Operation
    ValueError: if any cop in cops is already a control input of op.
  """
  if not isinstance(op, tf_ops.Operation):
    # Format the message: passing type(op) as a second exception argument
    # would leave a literal "{}" in the error text.
    raise TypeError("Expected a tf.Operation, got: {}".format(type(op)))
  cops = util.make_list_of_op(cops, allow_graph=False)
  for cop in cops:
    if cop in op.control_inputs:
      # The offending control input is cop; op is the operation that has it.
      raise ValueError("{} is already a control_input of {}".format(cop.name,
                                                                    op.name))
  # pylint: disable=protected-access
  op._control_inputs += cops
  op._recompute_node_def()
  # pylint: enable=protected-access
def get_consuming_ops(ts):
  """Return all the consuming ops of the tensors in ts.

  Each consumer appears once, in first-seen order.

  Args:
    ts: a list of `tf.Tensor`
  Returns:
    A list of all the consuming `tf.Operation` of the tensors in `ts`.
  Raises:
    TypeError: if ts cannot be converted to a list of `tf.Tensor`.
  """
  ts = make_list_of_t(ts, allow_graph=False)
  ops = []
  # A set gives O(1) de-duplication instead of scanning the result list.
  seen = set()
  for t in ts:
    for op in t.consumers():
      if op not in seen:
        seen.add(op)
        ops.append(op)
  return ops
def __init__(self, graph):
  """Create a dictionary of control-output dependencies.

  The dictionary starts empty in `self._control_outputs` and `self._build()`
  is then called to populate it. Once built, a key is a `tf.Operation`
  instance and the corresponding value is a list of all the ops which have
  the key as one of their control-input dependencies.

  Args:
    graph: a `tf.Graph`.
  Raises:
    TypeError: graph is not a `tf.Graph`.
  """
  if not isinstance(graph, tf_ops.Graph):
    raise TypeError("Expected a tf.Graph, got: {}".format(type(graph)))
  self._control_outputs = {}
  self._graph = graph
  self._version = None
  self._build()
def assign_renamed_collections_handler(info, elem, elem_):
  """Add the transformed elem to the (renamed) collections of elem.

  A collection whose key is one of the predefined `tf.GraphKeys` names keeps
  its name; any other collection name is mapped through `info.new_name`.

  Args:
    info: Transform._TmpInfo instance.
    elem: the original element (`tf.Tensor` or `tf.Operation`)
    elem_: the transformed element
  """
  predefined = util.get_predefined_collection_names()
  for collection_name, members in iteritems(info.collections):
    if elem in members:
      target_name = (collection_name if collection_name in predefined
                     else info.new_name(collection_name))
      info.graph_.add_to_collection(target_name, elem_)
def swap_ts(ts0, ts1, can_modify=None, cannot_modify=None):
  """For each tensor's pair, swap the end of (t0,t1).

      B0 B1     B0 B1
      |  |    =>  X
      A0 A1     A0 A1

  Args:
    ts0: an object convertible to a list of `tf.Tensor`.
    ts1: an object convertible to a list of `tf.Tensor`.
    can_modify: iterable of operations which can be modified. Any operation
      outside `can_modify` will be left untouched by this function.
    cannot_modify: iterable of operations which cannot be modified. Any
      operation within `cannot_modify` will be left untouched by this
      function.
  Returns:
    The number of individual modifications made by the function.
  Raises:
    TypeError: if ts0 or ts1 cannot be converted to a list of tf.Tensor.
    TypeError: if can_modify or cannot_modify is not None and cannot be
      converted to a list of tf.Operation.
  """
  return _reroute_ts(ts0, ts1, _RerouteMode.swap, can_modify, cannot_modify)
def reroute_ts(ts0, ts1, can_modify=None, cannot_modify=None):
  """For each tensor's pair, replace the end of t1 by the end of t0.

      B0 B1     B0 B1
      |  |    =>  |/
      A0 A1     A0 A1

  The ends of the tensors in ts1 are left dangling.

  Args:
    ts0: an object convertible to a list of `tf.Tensor`.
    ts1: an object convertible to a list of `tf.Tensor`.
    can_modify: iterable of operations which can be modified. Any operation
      outside `can_modify` will be left untouched by this function.
    cannot_modify: iterable of operations which cannot be modified. Any
      operation within `cannot_modify` will be left untouched by this
      function.
  Returns:
    The number of individual modifications made by the function.
  Raises:
    TypeError: if ts0 or ts1 cannot be converted to a list of tf.Tensor.
    TypeError: if can_modify or cannot_modify is not None and cannot be
      converted to a list of tf.Operation.
  """
  return _reroute_ts(ts0, ts1, _RerouteMode.a2b, can_modify, cannot_modify)
def build_results(self, values):
  """Apply the contraction function to `values`, or return None if empty.

  Empty `values` is the 'Operation' case, which produces no result.
  """
  return self._contraction_fn(values) if values else None
def __init__(self, graph, fetches, feeds, feed_handles=None):
  """Creates a fetch handler.

  Args:
    graph: Graph of the fetches.   Used to check for fetchability
      and to convert all fetches to tensors or ops as needed.
    fetches: An arbitrary fetch structure: singleton, list, tuple,
      namedtuple, or dict.
    feeds: A feed dict where keys are Tensors.
    feed_handles: A dict from feed Tensors to TensorHandle objects used as
      direct feeds.
  """
  # Only the fetch-structure conversion runs with `graph` as default.
  with graph.as_default():
    self._fetch_mapper = _FetchMapper.for_fetch(fetches)
  self._fetches = []
  self._targets = []
  self._feeds = feeds
  self._feed_handles = feed_handles or {}
  # One bool per unique fetch: True for an Operation target, False for a
  # tensor fetch.
  self._ops = []
  self._fetch_handles = {}
  handle_op_types = ('GetSessionHandle', 'GetSessionHandleV2')
  for fetch in self._fetch_mapper.unique_fetches():
    if isinstance(fetch, ops.Operation):
      self._assert_fetchable(graph, fetch)
      self._targets.append(fetch)
      self._ops.append(True)
      continue
    self._assert_fetchable(graph, fetch.op)
    self._fetches.append(fetch)
    self._ops.append(False)
    # Remember the fetch if it is for a tensor handle.
    if isinstance(fetch, ops.Tensor) and fetch.op.type in handle_op_types:
      self._fetch_handles[fetch] = fetch.op.inputs[0].dtype
  self._final_fetches = [x for x in self._fetches if x not in feeds]
def _assert_fetchable(self, graph, op): if not graph.is_fetchable(op): raise ValueError( 'Operation %r has been marked as not fetchable.' % op.name)
def _finalize_positive_filter(self, elem):
  """Convert to a filter function.

  The checks run in priority order: regex-like elements first, then a
  concrete tf.Operation, then any callable, then the literal True.
  """
  if select.can_be_regex(elem):
    compiled = select.make_regex(elem)
    # Captured through a default argument so the lambda keeps this value.
    return lambda op, regex=compiled: regex.search(op.name) is not None
  if isinstance(elem, tf_ops.Operation):
    return lambda op, match_op=elem: op is match_op
  if callable(elem):
    return elem
  if elem is True:
    return lambda op: True
  raise ValueError("Cannot finalize the positive filter: {}".format(elem))
def __call__(self, op):
  """Evaluate if the op matches or not.

  Each optional constraint list (inputs, control inputs, outputs) must have
  the same length as the corresponding op sequence; a None entry inside a
  list accepts anything at that position.
  """
  if not isinstance(op, tf_ops.Operation):
    raise TypeError("Expect tf.Operation, got: {}".format(type(op)))

  def _seq_matches(items, matchers, key):
    # True when matchers is None, or lengths agree and every non-None
    # matcher accepts key(item) at the same position (checked in order).
    if matchers is None:
      return True
    if len(items) != len(matchers):
      return False
    return all(m is None or m(key(item)) for item, m in zip(items, matchers))

  if not all(accept(op) for accept in self.positive_filters):
    return False
  if not _seq_matches(op.inputs, self.input_op_matches, lambda t: t.op):
    return False
  if not _seq_matches(op.control_inputs, self.control_input_op_matches,
                      lambda c: c):
    return False

  if self.output_op_matches is None:
    return True
  if len(op.outputs) != len(self.output_op_matches):
    return False
  for output_t, consumer_matchers in zip(op.outputs, self.output_op_matches):
    if consumer_matchers is None:
      continue
    if not _seq_matches(output_t.consumers(), consumer_matchers, lambda c: c):
      return False
  return True
def assign_renamed_collections_handler(info, elem, elem_):
  """Add the transformed elem to the (renamed) collections of elem.

  Every collection of elem's graph that contains elem gets a renamed
  counterpart (via info.transformer.new_name) in info.graph_ holding elem_.

  Args:
    info: Transform._Info instance.
    elem: the original element (tf.Tensor or tf.Operation)
    elem_: the transformed element
  """
  # TODO(fkp): handle known special cases
  # Iterate over a snapshot: if info.graph_ is elem.graph, add_to_collection
  # below can insert new keys into this very dict, which would otherwise
  # raise "dictionary changed size during iteration".
  collections = list(elem.graph._collections.items())  # pylint: disable=protected-access
  for name, collection in collections:
    if elem not in collection:
      continue
    collection_name_ = info.transformer.new_name(name)
    info.graph_.add_to_collection(collection_name_, elem_)
def _get_transformed_map(self, top):
  """Return the memo container matching the kind of `top`.

  Operations map to `self._transformed_ops` and tensors map to
  `self._transformed_ts`; anything else is rejected with TypeError.
  """
  if isinstance(top, tf_ops.Operation):
    return self._transformed_ops
  if isinstance(top, tf_ops.Tensor):
    return self._transformed_ts
  raise TypeError(
      "Expected a tf.Tensor or a tf.Operation, got a {}".format(
          type(top)))
def _transform_op(self, op):
  """Transform a tf.Operation.

  Results are memoized in `self._info.transformed_ops`: an op that was
  already transformed is returned from there without calling the handler
  again.

  Args:
    op: the operation to be transformed.
  Returns:
    The transformed operation.
  """
  if op in self._info.transformed_ops:
    return self._info.transformed_ops[op]

  op_ = self.transform_op_handler(self._info, op)

  # Add to all the active control dependencies
  # pylint: disable=protected-access
  self._info.graph_._record_op_seen_by_control_dependencies(op_)

  # Add to all the active devices: walk the destination graph's device
  # function stack from its end backwards, stopping at the first None entry.
  for device_function in reversed(self._info.graph_._device_function_stack):
    if device_function is None:
      break
    op_._set_device(device_function(op_))
  # pylint: enable=protected-access

  # TODO(fkp): Establish clear policy about what context managers are allowed.

  # assign to collection (only when the handler produced a new object)
  if op is not op_:
    self.assign_collections_handler(self._info, op, op_)

  self._info.transformed_ops[op] = op_
  return op_
def get_unique_graph(tops, check_types=None, none_if_empty=False):
  """Return the unique graph used by all the elements in tops.

  Args:
    tops: list of elements to check (usually a list of tf.Operation and/or
      tf.Tensor). Or a tf.Graph.
    check_types: check that the element in tops are of given type(s). May be
      a single type or any iterable of types. If None, the types
      (tf.Operation, tf.Tensor) are used.
    none_if_empty: don't raise an error if tops is an empty list, just return
      None.
  Returns:
    The unique graph used by all the tops.
  Raises:
    TypeError: if tops is not an iterable of tf.Operation.
    ValueError: if the graph is not unique.
  """
  if isinstance(tops, tf_ops.Graph):
    return tops
  if not is_iterable(tops):
    raise TypeError("{} is not iterable".format(type(tops)))
  if check_types is None:
    check_types = (tf_ops.Operation, tf_ops.Tensor)
  elif not is_iterable(check_types):
    check_types = (check_types,)
  else:
    # isinstance() only accepts a type or a tuple of types, so normalize
    # lists, sets or generators of types to a tuple.
    check_types = tuple(check_types)
  g = None
  for op in tops:
    if not isinstance(op, check_types):
      raise TypeError("Expected a type in ({}), got: {}".format(", ".join([str(
          t) for t in check_types]), type(op)))
    if g is None:
      g = op.graph
    elif g is not op.graph:
      raise ValueError("Operation {} does not belong to given graph".format(op))
  if g is None and not none_if_empty:
    raise ValueError("Can't find the unique graph of an empty list")
  return g
def make_list_of_op(ops, check_graph=True, allow_graph=True, ignore_ts=False):
  """Convert ops to a list of tf.Operation.

  Args:
    ops: can be an iterable of tf.Operation, a tf.Graph or a single
      operation. One-shot iterators (e.g. generators) are accepted.
    check_graph: if True check if all the operations belong to the same graph.
    allow_graph: if False a tf.Graph cannot be converted.
    ignore_ts: if True, silently ignore tf.Tensor.
  Returns:
    A newly created list of tf.Operation.
  Raises:
    TypeError: if ops cannot be converted to a list of tf.Operation or,
     if check_graph is True, if all the ops do not belong to the
     same graph.
  """
  if isinstance(ops, tf_ops.Graph):
    if allow_graph:
      return ops.get_operations()
    else:
      raise TypeError("allow_graph is False: cannot convert a tf.Graph.")
  else:
    if not is_iterable(ops):
      ops = [ops]
    else:
      # Materialize once: a one-shot iterator would otherwise be exhausted
      # by the graph check below (leaving an empty result), and `not ops`
      # is always False for iterators.
      ops = list(ops)
    if not ops:
      return []
    if check_graph:
      check_types = None if ignore_ts else tf_ops.Operation
      get_unique_graph(ops, check_types=check_types)
    return [op for op in ops if isinstance(op, tf_ops.Operation)]


# TODO(fkp): move this function in tf.Graph?
def make_list_of_t(ts, check_graph=True, allow_graph=True, ignore_ops=False):
  """Convert ts to a list of tf.Tensor.

  Args:
    ts: can be an iterable of tf.Tensor, a tf.Graph or a single tensor.
      One-shot iterators (e.g. generators) are accepted.
    check_graph: if True check if all the tensors belong to the same graph.
    allow_graph: if False a tf.Graph cannot be converted.
    ignore_ops: if True, silently ignore tf.Operation.
  Returns:
    A newly created list of tf.Tensor.
  Raises:
    TypeError: if ts cannot be converted to a list of tf.Tensor or,
     if check_graph is True, if all the ops do not belong to the same graph.
  """
  if isinstance(ts, tf_ops.Graph):
    if allow_graph:
      return get_tensors(ts)
    else:
      raise TypeError("allow_graph is False: cannot convert a tf.Graph.")
  else:
    if not is_iterable(ts):
      ts = [ts]
    else:
      # Materialize once: a one-shot iterator would otherwise be exhausted
      # by the graph check below (leaving an empty result), and `not ts`
      # is always False for iterators.
      ts = list(ts)
    if not ts:
      return []
    if check_graph:
      check_types = None if ignore_ops else tf_ops.Tensor
      get_unique_graph(ts, check_types=check_types)
    return [t for t in ts if isinstance(t, tf_ops.Tensor)]
def get_generating_ops(ts):
  """Collect the producing operation of every tensor in ts.

  Args:
    ts: a list of tf.Tensor
  Returns:
    A list holding, for each tensor in ts and in the same order, the
    tf.Operation that generates it.
  Raises:
    TypeError: if ts cannot be converted to a list of tf.Tensor.
  """
  tensors = make_list_of_t(ts, allow_graph=False)
  return [tensor.op for tensor in tensors]
def testTrainingConstruction(self):
  """Builds the training graph and checks it is an Operation."""
  data = constant_op.constant(
      [[random.uniform(-1, 1) for _ in range(self.params.num_features)]
       for _ in range(100)])
  labels = [1] * 100

  with variable_scope.variable_scope(
      "DecisionsToDataThenNNTest_testTrainingContruction"):
    graph_builder = decisions_to_data_then_nn.DecisionsToDataThenNN(
        self.params)
    graph = graph_builder.training_graph(data, labels, None)

    # assertIsInstance reports the actual type on failure.
    self.assertIsInstance(graph, Operation)
def testTrainingConstruction(self):
  """Builds the forest training graph and checks it is an Operation."""
  data = constant_op.constant(
      [[random.uniform(-1, 1) for _ in range(self.params.num_features)]
       for _ in range(100)])
  labels = [1] * 100

  with variable_scope.variable_scope(
      "ForestToDataThenNNTest.testTrainingContruction"):
    graph_builder = forest_to_data_then_nn.ForestToDataThenNN(self.params)
    graph = graph_builder.training_graph(data, labels, None)

    # assertIsInstance reports the actual type on failure.
    self.assertIsInstance(graph, Operation)
def assign_renamed_collections_handler(info, elem, elem_):
  """Add the transformed elem to the (renamed) collections of elem.

  Every collection of elem's graph that contains `elem` gets a renamed
  counterpart (via `info.transformer.new_name`) in `info.graph_` holding
  `elem_`.

  Args:
    info: Transform._Info instance.
    elem: the original element (`tf.Tensor` or `tf.Operation`)
    elem_: the transformed element
  """
  # TODO(fkp): handle known special cases
  # Iterate over a snapshot: if info.graph_ is elem.graph, add_to_collection
  # below can insert new keys into this very dict, which would otherwise
  # raise "dictionary changed size during iteration".
  collections = list(elem.graph._collections.items())  # pylint: disable=protected-access
  for name, collection in collections:
    if elem not in collection:
      continue
    collection_name_ = info.transformer.new_name(name)
    info.graph_.add_to_collection(collection_name_, elem_)
def __init__(self):
  """Transformer constructor.

  The following members can be modified:
  transform_op_handler: handle the transformation of a `tf.Operation`.
    This handler defaults to a simple copy.
  transform_control_input_handler: handle the transform of the control
    inputs. This handler defaults to `transform_op_if_inside_handler`, the
    same handler used for transform_original_op_handler below.
  assign_collections_handler: handle the assignment of collections.
    This handler defaults to assigning new collections created under the
    given name-scope.
  transform_external_input_handler: handle the transform of the inputs to
    the given subgraph. This handler defaults to creating placeholders
    instead of the ops just before the input tensors of the subgraph.
  transform_external_hidden_input_handler: handle the transform of the
    hidden inputs of the subgraph, that is, the inputs which are not listed
    in sgv.inputs. This handler defaults to a transform which keeps the same
    input if the source and destination graphs are the same, otherwise
    uses placeholders.
  transform_original_op_handler: handle the transform of original_op. This
    handler defaults to transforming original_op only if they are in the
    subgraph, otherwise they are ignored.
  """

  # handlers
  self.transform_op_handler = copy_op_handler
  self.transform_control_input_handler = transform_op_if_inside_handler
  self.assign_collections_handler = assign_renamed_collections_handler
  self.transform_external_input_handler = replace_t_with_placeholder_handler
  self.transform_external_hidden_input_handler = keep_t_if_possible_handler
  self.transform_original_op_handler = transform_op_if_inside_handler

  # temporary per-call variable
  self._info = None
def make_list_of_t(ts, check_graph=True, allow_graph=True, ignore_ops=False):
  """Convert ts to a list of `tf.Tensor`.

  Args:
    ts: can be an iterable of `tf.Tensor`, a `tf.Graph` or a single tensor.
      One-shot iterators (e.g. generators) are accepted.
    check_graph: if `True` check if all the tensors belong to the same graph.
    allow_graph: if `False` a `tf.Graph` cannot be converted.
    ignore_ops: if `True`, silently ignore `tf.Operation`.
  Returns:
    A newly created list of `tf.Tensor`.
  Raises:
    TypeError: if `ts` cannot be converted to a list of `tf.Tensor` or,
     if `check_graph` is `True`, if all the ops do not belong to the same
     graph.
  """
  if isinstance(ts, tf_ops.Graph):
    if allow_graph:
      return get_tensors(ts)
    else:
      raise TypeError("allow_graph is False: cannot convert a tf.Graph.")
  else:
    if not is_iterable(ts):
      ts = [ts]
    else:
      # Materialize once: a one-shot iterator would otherwise be exhausted
      # by the graph check below (leaving an empty result), and `not ts`
      # is always False for iterators.
      ts = list(ts)
    if not ts:
      return []
    if check_graph:
      check_types = None if ignore_ops else tf_ops.Tensor
      get_unique_graph(ts, check_types=check_types)
    return [t for t in ts if isinstance(t, tf_ops.Tensor)]
def get_generating_ops(ts):
  """Collect the producing operation of every tensor in `ts`.

  Args:
    ts: a list of `tf.Tensor`
  Returns:
    A list holding, for each tensor in `ts` and in the same order, the
    `tf.Operation` that generates it.
  Raises:
    TypeError: if `ts` cannot be converted to a list of `tf.Tensor`.
  """
  tensors = make_list_of_t(ts, allow_graph=False)
  return [tensor.op for tensor in tensors]
def op(self):
  """Method for compatibility with Tensor.

  Builds and returns a throwaway `tf_ops.Operation` named
  "imperative-dummy-node" with three placeholder inputs of `self.dtype`,
  attached to a freshly created `tf_ops.Graph`.
  """
  node_def = graph_pb2.NodeDef()
  node_def.name = "imperative-dummy-node"
  node_def.input.extend(["dummy1", "dummy2", "dummy3"])

  # NOTE(review): the placeholders are created in the current default graph,
  # not in the new Graph handed to Operation below -- confirm intended.
  dummy_inputs = [array_ops.placeholder(self.dtype) for _ in range(3)]
  return tf_ops.Operation(node_def, tf_ops.Graph(), inputs=dummy_inputs)
def _check_is_tensor_or_operation(x, name):
  """Raise TypeError unless `x` is an `ops.Operation` or an `ops.Tensor`.

  `name` is used only to label the argument in the error message.
  """
  if isinstance(x, (ops.Operation, ops.Tensor)):
    return
  raise TypeError('{} must be Operation or Tensor, given: {}'.format(name, x))