We extracted the following 50 code examples from open-source Python projects to illustrate how to use tensorflow.python.framework.ops.Graph().
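Before the extracted examples, here is a minimal sketch of the pattern they all share: building ops inside an explicitly created graph and launching that graph in a session. It assumes the TF 1.x API, where tf.Graph and tf.Session are the public aliases of the classes used below (ops.Graph is the same class under its internal path).

import tensorflow as tf

# Build ops inside an explicitly created graph (assumption: TF 1.x API).
g = tf.Graph()
with g.as_default():
  a = tf.constant(1.0, name="a")
  b = tf.constant(2.0, name="b")
  c = tf.add(a, b, name="c")

# Launch that specific graph in a session.
with tf.Session(graph=g) as sess:
  print(sess.run(c))  # 3.0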
def __init__(self, target='', graph=None, config=None):
  """Creates a new TensorFlow session.

  If no `graph` argument is specified when constructing the session, the
  default graph will be launched in the session. If you are using more than
  one graph (created with `tf.Graph()`) in the same process, you will have
  to use different sessions for each graph, but each graph can be used in
  multiple sessions. In this case, it is often clearer to pass the graph to
  be launched explicitly to the session constructor.

  Args:
    target: (Optional.) The execution engine to connect to. Defaults to
      using an in-process engine. See @{$distributed$Distributed TensorFlow}
      for more examples.
    graph: (Optional.) The `Graph` to be launched (described above).
    config: (Optional.) A
      [`ConfigProto`](https://www.tensorflow.org/code/tensorflow/core/protobuf/config.proto)
      protocol buffer with configuration options for the session.
  """
  super(Session, self).__init__(target, graph, config=config)
  # NOTE(mrry): Create these on first `__enter__` to avoid a reference cycle.
  self._default_graph_context_manager = None
  self._default_session_context_manager = None
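The docstring above recommends one session per graph when several graphs live in the same process; a hedged sketch of that usage (TF 1.x; the graphs g1 and g2 are illustrative):

import tensorflow as tf

g1 = tf.Graph()
with g1.as_default():
  x = tf.constant(1.0, name="x")

g2 = tf.Graph()
with g2.as_default():
  y = tf.constant(2.0, name="y")

# Pass each graph to its own session explicitly, as the docstring suggests.
with tf.Session(graph=g1) as sess1, tf.Session(graph=g2) as sess2:
  print(sess1.run(x), sess2.run(y))  # 1.0 2.0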
def __copy__(self):
  """Create a copy of this subgraph.

  Note that this class is a "view", so copying it only creates another view
  and does not copy the underlying part of the tf.Graph.

  Returns:
    A new identical instance of the original subgraph view.
  """
  cls = self.__class__
  result = cls.__new__(cls)
  for k, v in iteritems(self.__dict__):
    if k == "_graph":
      setattr(result, k, v)
    else:
      setattr(result, k, list(v))  # copy the list
  return result
def remap_inputs(self, new_input_indices):
  """Remap the inputs of the subgraph.

  If the inputs of the original subgraph are [t0, t1, t2], remapping to [2,0]
  will create a new instance whose inputs are [t2, t0].

  Note that this is only modifying the view: the underlying tf.Graph is not
  affected.

  Args:
    new_input_indices: an iterable of integers representing a mapping between
      the old inputs and the new ones. This mapping can be under-complete and
      must be without repetitions.

  Returns:
    A new modified instance of the original subgraph view with remapped
      inputs.
  """
  res = self.copy()
  res._remap_inputs(new_input_indices)  # pylint: disable=protected-access
  return res
def remap_outputs(self, new_output_indices):
  """Remap the outputs of the subgraph.

  If the outputs of the original subgraph are [t0, t1, t2], remapping to
  [1,1,0] will create a new instance whose outputs are [t1, t1, t0].

  Note that this is only modifying the view: the underlying tf.Graph is not
  affected.

  Args:
    new_output_indices: an iterable of integers representing a mapping between
      the old outputs and the new ones. This mapping can be under-complete and
      can have repetitions.

  Returns:
    A new modified instance of the original subgraph view with remapped
      outputs.
  """
  res = copy.copy(self)
  res._remap_outputs(new_output_indices)  # pylint: disable=protected-access
  return res
def remap(self, new_input_indices=None, new_output_indices=None):
  """Remap the inputs and outputs of the subgraph.

  Note that this is only modifying the view: the underlying tf.Graph is not
  affected.

  Args:
    new_input_indices: an iterable of integers representing a mapping between
      the old inputs and the new ones. This mapping can be under-complete and
      must be without repetitions.
    new_output_indices: an iterable of integers representing a mapping between
      the old outputs and the new ones. This mapping can be under-complete and
      can have repetitions.

  Returns:
    A new modified instance of the original subgraph view with remapped inputs
      and outputs.
  """
  res = copy.copy(self)
  if new_input_indices is not None:
    res._remap_inputs(new_input_indices)  # pylint: disable=protected-access
  if new_output_indices is not None:
    res._remap_outputs(new_output_indices)  # pylint: disable=protected-access
  return res
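A hedged usage sketch for the remap family above, assuming these methods live on the subgraph views returned by ge.sgv in the TF 1.x Graph Editor (the same helper the tests further below use):

import tensorflow as tf
from tensorflow.contrib import graph_editor as ge

g = tf.Graph()
with g.as_default():
  a = tf.constant(1.0, name="a")
  b = tf.constant(2.0, name="b")
  c = tf.add(a, b, name="c")

sgv = ge.sgv(c.op)                  # view over the add op; inputs are [a, b]
flipped = sgv.remap_inputs([1, 0])  # new view whose inputs are [b, a]
# Only the views differ; the underlying tf.Graph g is unchanged.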
def _check_graph(sgv, graph):
  """Check if sgv belongs to the given graph.

  Args:
    sgv: a SubGraphView.
    graph: a graph or None.

  Returns:
    The SubGraphView sgv.

  Raises:
    TypeError: if sgv is not a SubGraphView or if graph is not None and not
      a tf.Graph.
    ValueError: if the graph of sgv and the given graph are not None and
      different.
  """
  if not isinstance(sgv, SubGraphView):
    raise TypeError("Expected a SubGraphView, got: {}".format(type(sgv)))
  if graph is None or not sgv.graph:
    return sgv
  if not isinstance(graph, tf_ops.Graph):
    raise TypeError("Expected a tf.Graph, got: {}".format(type(graph)))
  if sgv.graph is not graph:
    raise ValueError("Graph mismatch.")
  return sgv
def select_ops_and_ts(*args, **kwargs):
  """Helper to select operations and tensors.

  Args:
    *args: list of 1) regular expressions (compiled or not) or 2) (array of)
      tf.Operation 3) (array of) tf.Tensor. Regular expressions matching
      tensors must start with the comment "(?#ts)", for instance:
      "(?#ts)^foo/.*".
    **kwargs: 'graph': tf.Graph in which to perform the regex query. This is
      required when using regex.
      'positive_filter': an elem is selected only if positive_filter(elem) is
        True. This is optional.

  Returns:
    A tuple `(ops, ts)` where:
      `ops` is a list of tf.Operation
      `ts` is a list of tf.Tensor

  Raises:
    TypeError: if the optional keyword argument graph is not a tf.Graph
      or if an argument in args is not an (array of) tf.Tensor
      or an (array of) tf.Operation or a string or a regular expression.
    ValueError: if one of the keyword arguments is unexpected or if a regular
      expression is used without passing a graph as a keyword argument.
  """
  ops = select_ops(*args, restrict_ops_regex=False, **kwargs)
  ts = select_ts(*args, restrict_ts_regex=True, **kwargs)
  return ops, ts
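A hedged sketch of calling this helper through the public Graph Editor namespace (assumption: it is exported as ge.select_ops_and_ts in tf.contrib.graph_editor), with the "(?#ts)" comment prefix marking the tensor query exactly as the docstring describes:

import tensorflow as tf
from tensorflow.contrib import graph_editor as ge

g = tf.Graph()
with g.as_default():
  x = tf.constant(1.0, name="foo/x")
  y = tf.add(x, x, name="foo/y")

# One regex selects ops, the "(?#ts)"-prefixed one selects tensors; the
# graph keyword argument is required because regexes are used.
ops_, ts_ = ge.select_ops_and_ts("^foo/.*", "(?#ts)^foo/.*", graph=g)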
def get_tensors(graph):
  """Get all the tensors which are input or output of an op in the graph.

  Args:
    graph: a tf.Graph.

  Returns:
    A list of tf.Tensor.

  Raises:
    TypeError: if graph is not a tf.Graph.
  """
  if not isinstance(graph, tf_ops.Graph):
    raise TypeError("Expected a graph, got: {}".format(type(graph)))
  ts = []
  for op in graph.get_operations():
    ts += op.outputs
  return ts
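Since get_tensors simply collects every op output, calling it via ge.util (the path the tests below use) might look like this sketch:

import tensorflow as tf
from tensorflow.contrib import graph_editor as ge

g = tf.Graph()
with g.as_default():
  a = tf.constant(1.0, name="a")
  b = tf.add(a, a, name="b")

all_ts = ge.util.get_tensors(g)    # every op-output tensor in g
print([t.name for t in all_ts])    # ["a:0", "b:0"]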
def make_placeholder_from_tensor(t, scope=None):
  """Create a tf.placeholder for the Graph Editor.

  Note that the correct graph scope must be set by the calling function.

  Args:
    t: a tf.Tensor whose name will be used to create the placeholder (see
      function placeholder_name).
    scope: absolute scope within which to create the placeholder. None means
      that the scope of t is preserved. "" means the root scope.

  Returns:
    A newly created tf.placeholder.

  Raises:
    TypeError: if t is not a tf.Tensor.
  """
  return tf_array_ops.placeholder(
      dtype=t.dtype,
      shape=t.get_shape(),
      name=placeholder_name(t, scope=scope))
def make_placeholder_from_dtype_and_shape(dtype, shape=None, scope=None):
  """Create a tf.placeholder for the Graph Editor.

  Note that the correct graph scope must be set by the calling function.
  The placeholder is named using the function placeholder_name (with no
  tensor argument).

  Args:
    dtype: the tensor type.
    shape: the tensor shape (optional).
    scope: absolute scope within which to create the placeholder. None means
      that the scope of t is preserved. "" means the root scope.

  Returns:
    A newly created tf.placeholder.
  """
  return tf_array_ops.placeholder(
      dtype=dtype, shape=shape, name=placeholder_name(scope=scope))
def infer_real_valued_columns_from_input_fn(input_fn):
  """Creates `FeatureColumn` objects for inputs defined by `input_fn`.

  This interprets all inputs as dense, fixed-length float values. This creates
  a local graph in which it calls `input_fn` to build the tensors, then
  discards it.

  Args:
    input_fn: Input function returning a tuple of:
        features - `Tensor` or dictionary of string feature name to `Tensor`.
        target - `Tensor` of target objects.

  Returns:
    List of `FeatureColumn` objects.
  """
  with ops.Graph().as_default():
    features, _ = input_fn()
    return layers.infer_real_valued_columns(features)
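A hedged usage sketch with a toy input_fn, assuming the function is exposed as tf.contrib.learn.infer_real_valued_columns_from_input_fn (TF 1.x); the feature name "x" and the values are illustrative:

import tensorflow as tf

def input_fn():
  # Two rows of a single two-dimensional dense float feature.
  features = {"x": tf.constant([[1.0, 2.0], [3.0, 4.0]])}
  target = tf.constant([[0.0], [1.0]])
  return features, target

# Builds the tensors in a throwaway graph (see the implementation above)
# and returns one FeatureColumn per dense float input.
cols = tf.contrib.learn.infer_real_valued_columns_from_input_fn(input_fn)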
def __copy__(self):
  """Create a copy of this subgraph.

  Note that this class is a "view", so copying it only creates another view
  and does not copy the underlying part of the `tf.Graph`.

  Returns:
    A new identical instance of the original subgraph view.
  """
  cls = self.__class__
  result = cls.__new__(cls)
  for k, v in iteritems(self.__dict__):
    if k == "_graph":
      setattr(result, k, v)
    else:
      setattr(result, k, list(v))  # copy the list
  return result
def remap_inputs(self, new_input_indices):
  """Remap the inputs of the subgraph.

  If the inputs of the original subgraph are [t0, t1, t2], remapping to [2,0]
  will create a new instance whose inputs are [t2, t0].

  Note that this is only modifying the view: the underlying `tf.Graph` is not
  affected.

  Args:
    new_input_indices: an iterable of integers representing a mapping between
      the old inputs and the new ones. This mapping can be under-complete and
      must be without repetitions.

  Returns:
    A new modified instance of the original subgraph view with remapped
      inputs.
  """
  res = self.copy()
  res._remap_inputs(new_input_indices)  # pylint: disable=protected-access
  return res
def get_tensors(graph):
  """Get all the tensors which are input or output of an op in the graph.

  Args:
    graph: a `tf.Graph`.

  Returns:
    A list of `tf.Tensor`.

  Raises:
    TypeError: if graph is not a `tf.Graph`.
  """
  if not isinstance(graph, tf_ops.Graph):
    raise TypeError("Expected a graph, got: {}".format(type(graph)))
  ts = []
  for op in graph.get_operations():
    ts += op.outputs
  return ts
def __init__(self, graph):
  """Create a dictionary of control-output dependencies.

  Args:
    graph: a `tf.Graph`.

  Returns:
    A dictionary where a key is a `tf.Operation` instance and the
      corresponding value is a list of all the ops which have the key as one
      of their control-input dependencies.

  Raises:
    TypeError: graph is not a `tf.Graph`.
  """
  if not isinstance(graph, tf_ops.Graph):
    raise TypeError("Expected a tf.Graph, got: {}".format(type(graph)))
  self._control_outputs = {}
  self._graph = graph
  self._version = None
  self._build()
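A sketch of building this control-output map; it assumes the class is exported as ge.ControlOutputs in tf.contrib.graph_editor and offers a get(op) lookup, so treat both names as unverified here:

import tensorflow as tf
from tensorflow.contrib import graph_editor as ge

g = tf.Graph()
with g.as_default():
  a = tf.constant(1.0, name="a")
  with g.control_dependencies([a.op]):
    b = tf.constant(2.0, name="b")  # a.op becomes a control input of b.op

co = ge.ControlOutputs(g)
print(co.get(a.op))  # ops that list a.op as a control input, e.g. [b.op]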
def make_placeholder_from_tensor(t, scope=None):
  """Create a `tf.placeholder` for the Graph Editor.

  Note that the correct graph scope must be set by the calling function.

  Args:
    t: a `tf.Tensor` whose name will be used to create the placeholder (see
      function placeholder_name).
    scope: absolute scope within which to create the placeholder. None means
      that the scope of `t` is preserved. `""` means the root scope.

  Returns:
    A newly created `tf.placeholder`.

  Raises:
    TypeError: if `t` is not a `tf.Tensor`.
  """
  return tf_array_ops.placeholder(
      dtype=t.dtype,
      shape=t.get_shape(),
      name=placeholder_name(t, scope=scope))
def make_placeholder_from_dtype_and_shape(dtype, shape=None, scope=None):
  """Create a tf.placeholder for the Graph Editor.

  Note that the correct graph scope must be set by the calling function.
  The placeholder is named using the function placeholder_name (with no
  tensor argument).

  Args:
    dtype: the tensor type.
    shape: the tensor shape (optional).
    scope: absolute scope within which to create the placeholder. None means
      that the scope of t is preserved. "" means the root scope.

  Returns:
    A newly created tf.placeholder.
  """
  return tf_array_ops.placeholder(
      dtype=dtype, shape=shape, name=placeholder_name(scope=scope))
def infer_real_valued_columns_from_input_fn(input_fn):
  """Creates `FeatureColumn` objects for inputs defined by `input_fn`.

  This interprets all inputs as dense, fixed-length float values. This creates
  a local graph in which it calls `input_fn` to build the tensors, then
  discards it.

  Args:
    input_fn: Input function returning a tuple of:
        features - `Tensor` or dictionary of string feature name to `Tensor`.
        labels - `Tensor` of label values.

  Returns:
    List of `FeatureColumn` objects.
  """
  with ops.Graph().as_default():
    features, _ = input_fn()
    return layers.infer_real_valued_columns(features)
def testInitFromRootCheckpoint(self):
  checkpoint_dir = self.get_temp_dir()
  with self.test_session() as session:
    v1, v2, v3, v4 = _create_checkpoints(session, checkpoint_dir)

  # New graph and session.
  with ops.Graph().as_default() as g:
    with self.test_session(graph=g) as session:
      with variable_scope.variable_scope("some_scope"):
        my1 = variable_scope.get_variable("var1", [1, 10])
        my2 = variable_scope.get_variable("var2", [10, 10])
        my3 = variable_scope.get_variable("var3", [100, 100])
        with variable_scope.variable_scope("useful_scope"):
          my4 = variable_scope.get_variable("var4", [9, 9])

      checkpoint_utils.init_from_checkpoint(checkpoint_dir,
                                            {"/": "some_scope/",})

      session.run(variables.global_variables_initializer())
      self.assertAllEqual(my1.eval(session), v1)
      self.assertAllEqual(my2.eval(session), v2)
      self.assertAllEqual(my3.eval(session), v3)
      self.assertAllEqual(my4.eval(session), v4)
def testDeviceFn(self):

  class DevFn(object):

    def __init__(self):
      self.counter = -1

    def __call__(self, op):
      self.counter += 1
      return '/cpu:%d' % self.counter

  with ops.Graph().as_default():
    with arg_scope([variables_lib2.model_variable], device=DevFn()):
      a = variables_lib2.model_variable('a', [5])
      b = variables_lib2.model_variable('b', [20])
      self.assertDeviceEqual(a.device, '/cpu:0')
      self.assertEqual(a.initial_value.op.colocation_groups(),
                       a.op.colocation_groups())
      self.assertDeviceEqual(b.device, '/cpu:1')
      self.assertEqual(b.initial_value.op.colocation_groups(),
                       b.op.colocation_groups())
def create_checkpoint_from_values(self,
                                  var_names_to_values,
                                  checkpoint_dir,
                                  global_step=None):
  """Creates a checkpoint from a mapping of name to values in model_dir.

  Args:
    var_names_to_values: a map from variable names to values.
    checkpoint_dir: the directory where the checkpoint will be saved.
    global_step: the global step used to save the checkpoint.

  Returns:
    the model_path to the checkpoint.
  """
  var_list = []
  with session.Session('', graph=ops.Graph()) as sess:
    # Create a set of variables to save in the checkpoint.
    for var_name in var_names_to_values:
      var_value = var_names_to_values[var_name]
      var_list.append(variables_lib.Variable(var_value, name=var_name))
    saver = saver_lib.Saver(var_list)
    init_op = variables_lib.variables_initializer(var_list)
    sess.run(init_op)
    # Save the initialized values in the file at 'checkpoint_dir'
    return saver.save(sess, checkpoint_dir, global_step=global_step)
def test_copy(self):
  graph = ops.Graph()
  _, info = ge.copy(self.graph, graph)
  self.assertEqual(
      set(op.name for op in self.graph.get_operations()),
      set(op.name for op in graph.get_operations()))

  src_ops = self.graph.get_operations()
  dst_ops = graph.get_operations()
  for op in src_ops:
    op_ = info.transformed(op)
    self.assertTrue(op_ in dst_ops)
    self.assertEqual(op.name, op_.name)
    self.assertEqual(info.original(op_), op)

  src_ts = ge.util.get_tensors(self.graph)
  dst_ts = ge.util.get_tensors(graph)
  for t in src_ts:
    t_ = info.transformed(t)
    self.assertTrue(t_ in dst_ts)
    self.assertEqual(t.name, t_.name)
    self.assertEqual(info.original(t_), t)
def test_placeholder(self):
  """Test placeholder functionalities."""
  g0 = ops.Graph()
  with g0.as_default():
    a0 = constant_op.constant(1, name="foo")
  # Test placeholder name.
  self.assertEqual(ge.util.placeholder_name(a0), "geph__foo_0")
  self.assertEqual(ge.util.placeholder_name(None), "geph")
  self.assertEqual(
      ge.util.placeholder_name(a0, scope="foo/"), "foo/geph__foo_0")
  self.assertEqual(
      ge.util.placeholder_name(a0, scope="foo"), "foo/geph__foo_0")
  self.assertEqual(ge.util.placeholder_name(None, scope="foo/"), "foo/geph")
  self.assertEqual(ge.util.placeholder_name(None, scope="foo"), "foo/geph")
  # Test placeholder creation.
  g0 = ops.Graph()
  with g0.as_default():
    a0 = constant_op.constant(1, dtype=dtypes.float32, name="a0")
    c0 = math_ops.add(
        ge.util.make_placeholder_from_tensor(a0),
        ge.util.make_placeholder_from_dtype_and_shape(dtype=dtypes.float32))
    self.assertEqual(c0.op.inputs[0].op.name, "geph__a0_0")
    self.assertEqual(c0.op.inputs[1].op.name, "geph")
def test_reroute_can_modify(self):
  graph = ops.Graph()
  # create a special graph where "a" is an ambiguous tensor. That is
  # it is both an input and an output of the ops in sgv0.
  with graph.as_default():
    a = constant_op.constant(1.0, shape=[2], name="a")
    b = constant_op.constant(2.0, shape=[2], name="b")
    c = math_ops.add(a, b, name="c")
    d = math_ops.add(a, c, name="d")

    e = constant_op.constant(1.0, shape=[2], name="e")
    f = constant_op.constant(2.0, shape=[2], name="f")
    g = math_ops.add(e, f, name="g")

  sgv0 = ge.sgv(a.op, b.op, c.op)
  sgv1 = ge.sgv(e.op, f.op)

  ge.swap_outputs(sgv0, sgv1)
  self.assertTrue(
      ge.OpMatcher("g").input_ops("a", ge.OpMatcher("c").input_ops("a", "b"))(
          g.op))
  self.assertTrue(ge.OpMatcher("d").input_ops("e", "f")(d.op))
def testGradientWithZeroWeight(self):
  with ops.Graph().as_default():
    random_seed.set_random_seed(0)

    inputs = array_ops.ones((2, 3))
    weights = variable_scope.get_variable(
        'weights',
        shape=[3, 4],
        initializer=init_ops.truncated_normal_initializer())
    predictions = math_ops.matmul(inputs, weights)

    optimizer = momentum_lib.MomentumOptimizer(
        learning_rate=0.001, momentum=0.9)
    loss = loss_ops.mean_pairwise_squared_error(predictions, predictions, 0)

    gradients_to_variables = optimizer.compute_gradients(loss)

    init_op = variables.global_variables_initializer()

    with self.test_session() as sess:
      sess.run(init_op)
      for grad, _ in gradients_to_variables:
        np_grad = sess.run(grad)
        self.assertFalse(np.isnan(np_grad).any())
def testNumpySource(self):
  batch_size = 3
  iterations = 1000
  array = np.arange(32).reshape([16, 2])
  numpy_source = in_memory_source.NumpySource(array, batch_size=batch_size)
  index_column = numpy_source().index
  value_column = numpy_source().value
  cache = {}
  with ops.Graph().as_default():
    value_tensor = value_column.build(cache)
    index_tensor = index_column.build(cache)
    with session.Session() as sess:
      coord = coordinator.Coordinator()
      threads = queue_runner_impl.start_queue_runners(sess=sess, coord=coord)
      for i in range(iterations):
        expected_index = [
            j % array.shape[0]
            for j in range(batch_size * i, batch_size * (i + 1))
        ]
        expected_value = get_rows(array, expected_index)
        actual_index, actual_value = sess.run([index_tensor, value_tensor])
        np.testing.assert_array_equal(expected_index, actual_index)
        np.testing.assert_array_equal(expected_value, actual_value)
      coord.request_stop()
      coord.join(threads)
def testArrayFeedingMultiThread(self):
  with ops.Graph().as_default():
    array = np.arange(256).reshape([128, 2])
    q = ff.enqueue_data(array, capacity=128, num_threads=8, shuffle=True)
    batch_size = 3
    dq_op = q.dequeue_many(batch_size)
    with session.Session() as sess:
      coord = coordinator.Coordinator()
      threads = queue_runner_impl.start_queue_runners(sess=sess, coord=coord)
      for _ in range(100):
        dq = sess.run(dq_op)
        indices = dq[0]
        expected_dq = get_rows(array, indices)
        np.testing.assert_array_equal(expected_dq, dq[1])
      coord.request_stop()
      coord.join(threads)
def testPandasFeedingMultiThread(self):
  if not HAS_PANDAS:
    return
  with ops.Graph().as_default():
    array1 = np.arange(128, 256)
    array2 = 2 * array1
    df = pd.DataFrame({"a": array1, "b": array2}, index=np.arange(128))
    q = ff.enqueue_data(df, capacity=128, num_threads=8, shuffle=True)
    batch_size = 5
    dq_op = q.dequeue_many(batch_size)
    with session.Session() as sess:
      coord = coordinator.Coordinator()
      threads = queue_runner_impl.start_queue_runners(sess=sess, coord=coord)
      for _ in range(100):
        dq = sess.run(dq_op)
        indices = dq[0]
        expected_rows = df.iloc[indices]
        for col_num, col in enumerate(df.columns):
          np.testing.assert_array_equal(expected_rows[col].values,
                                        dq[col_num + 1])
      coord.request_stop()
      coord.join(threads)
def test_evaluate_invalid_args(self):
  with ops.Graph().as_default() as g, self.test_session(g):
    self._assert_ckpt(self._output_dir, False)
    # 'utput directory' matches both "Output directory" and "output directory".
    with self.assertRaisesRegexp(ValueError, 'utput directory'):
      learn.graph_actions.evaluate(
          g,
          output_dir=None,
          checkpoint_path=None,
          eval_dict={'a': constant_op.constant(1.0)})
    with self.assertRaisesRegexp(ValueError, 'utput directory'):
      learn.graph_actions.evaluate(
          g,
          output_dir='',
          checkpoint_path=None,
          eval_dict={'a': constant_op.constant(1.0)})
    self._assert_ckpt(self._output_dir, False)
def test_evaluate_ready_for_local_init(self):
  with ops.Graph().as_default() as g, self.test_session(g):
    variables_lib.create_global_step()
    v = variables.Variable(1.0)
    w = variables.Variable(
        v + 1, collections=[ops.GraphKeys.LOCAL_VARIABLES], trainable=False)
    ready_for_local_init_op = variables.report_uninitialized_variables(
        variables.global_variables())
    ops.add_to_collection(ops.GraphKeys.READY_FOR_LOCAL_INIT_OP,
                          ready_for_local_init_op)
    _ = learn.graph_actions.evaluate(
        g,
        output_dir=self._output_dir,
        checkpoint_path=None,
        eval_dict={'a': v},
        max_steps=1)
def test_evaluate_feed_fn(self):
  with ops.Graph().as_default() as g, self.test_session(g):
    in0, _, out = self._build_inference_graph()
    writer = learn.graph_actions.get_summary_writer(self._output_dir)
    self._assert_summaries(self._output_dir, writer, expected_session_logs=[])
    self._assert_ckpt(self._output_dir, False)
    feeder = _Feeder(in0, 3)
    results = learn.graph_actions.evaluate(
        g,
        output_dir=self._output_dir,
        checkpoint_path=None,
        eval_dict={'a': out},
        feed_fn=feeder.feed_fn,
        max_steps=3)
    self.assertEqual(3, feeder.step)
    self.assertEqual(({'a': 25.0}, 0), results)
    self._assert_summaries(
        self._output_dir,
        writer,
        expected_summaries={0: {'a': 25.0}},
        expected_session_logs=[])
    self._assert_ckpt(self._output_dir, False)
def test_evaluate_feed_fn_with_exhaustion(self):
  with ops.Graph().as_default() as g, self.test_session(g):
    in0, _, out = self._build_inference_graph()
    writer = learn.graph_actions.get_summary_writer(self._output_dir)
    self._assert_summaries(self._output_dir, writer, expected_session_logs=[])
    feeder = _Feeder(in0, 2)
    results = learn.graph_actions.evaluate(
        g,
        output_dir=self._output_dir,
        checkpoint_path=None,
        eval_dict={'a': out},
        feed_fn=feeder.feed_fn,
        max_steps=3)
    self.assertEqual(2, feeder.step)
    self.assertEqual(({'a': 15.0}, 0), results)
    self._assert_summaries(
        self._output_dir,
        writer,
        expected_summaries={0: {'a': 15.0}},
        expected_session_logs=[])
def test_evaluate_with_saver(self):
  with ops.Graph().as_default() as g, self.test_session(g):
    _, _, out = self._build_inference_graph()
    ops.add_to_collection(ops.GraphKeys.SAVERS, saver_lib.Saver())
    writer = learn.graph_actions.get_summary_writer(self._output_dir)
    self._assert_summaries(self._output_dir, writer, expected_session_logs=[])
    results = learn.graph_actions.evaluate(
        g,
        output_dir=self._output_dir,
        checkpoint_path=None,
        eval_dict={'a': out},
        max_steps=1)
    self.assertEqual(({'a': 6.0}, 0), results)
    self._assert_summaries(
        self._output_dir,
        writer,
        expected_summaries={0: {'a': 6.0}},
        expected_session_logs=[])
def test_train_loss(self):
  with ops.Graph().as_default() as g, self.test_session(g):
    variables_lib.create_global_step()
    loss_var = variables_lib.local_variable(10.0)
    train_op = control_flow_ops.group(
        state_ops.assign_add(variables_lib.get_global_step(), 1),
        state_ops.assign_add(loss_var, -1.0))
    writer = learn.graph_actions.get_summary_writer(self._output_dir)
    self._assert_summaries(self._output_dir, writer)
    self._assert_ckpt(self._output_dir, False)
    loss = learn.graph_actions._monitored_train(  # pylint: disable=protected-access
        g,
        output_dir=self._output_dir,
        train_op=train_op,
        loss_op=loss_var.value(),
        steps=6)
    self.assertEqual(4.0, loss)
    self._assert_summaries(
        self._output_dir,
        writer,
        expected_graphs=[g],
        expected_meta_graphs=None)
    self._assert_ckpt(self._output_dir, True)
def test_train(self):
  with ops.Graph().as_default() as g, self.test_session(g):
    with ops.control_dependencies(self._build_inference_graph()):
      train_op = state_ops.assign_add(variables_lib.get_global_step(), 1)
    self._assert_summaries(self._output_dir)
    self._assert_ckpt(self._output_dir, False)
    loss = learn.graph_actions.train(
        g,
        output_dir=self._output_dir,
        train_op=train_op,
        loss_op=constant_op.constant(2.0),
        steps=1)
    # TODO(ebrevdo,ptucker,ispir): this meta_graph_def lacks the
    # SaverDef, so we can't add it to the summary assertion test below.
    # meta_graph_def = meta_graph.create_meta_graph_def()
    self.assertEqual(2.0, loss)
    self._assert_summaries(self._output_dir, expected_graphs=[g])
    self._assert_ckpt(self._output_dir, True)
def test_train_loss(self):
  with ops.Graph().as_default() as g, self.test_session(g):
    variables_lib.create_global_step()
    loss_var = variables_lib.local_variable(10.0)
    train_op = control_flow_ops.group(
        state_ops.assign_add(variables_lib.get_global_step(), 1),
        state_ops.assign_add(loss_var, -1.0))
    self._assert_summaries(self._output_dir)
    self._assert_ckpt(self._output_dir, False)
    loss = learn.graph_actions.train(
        g,
        output_dir=self._output_dir,
        train_op=train_op,
        loss_op=loss_var.value(),
        steps=6)
    # TODO(ebrevdo,ptucker,ispir): this meta_graph_def lacks the
    # SaverDef, so we can't add it to the summary assertion test below.
    # meta_graph_def = meta_graph.create_meta_graph_def()
    self.assertEqual(4.0, loss)
    self._assert_summaries(self._output_dir, expected_graphs=[g])
    self._assert_ckpt(self._output_dir, True)
def test_train_chief_monitor(self):
  with ops.Graph().as_default() as g, self.test_session(g):
    with ops.control_dependencies(self._build_inference_graph()):
      train_op = state_ops.assign_add(variables_lib.get_global_step(), 1)
    loss_op = constant_op.constant(2.0)
    summary.scalar('loss', loss_op)
    chief_exclusive_monitor = _BaseMonitorWrapper(False)
    all_workers_monitor = _BaseMonitorWrapper(True)
    loss = learn.graph_actions.train(
        g,
        output_dir=self._output_dir,
        train_op=train_op,
        loss_op=loss_op,
        supervisor_is_chief=True,
        steps=1,
        monitors=[chief_exclusive_monitor, all_workers_monitor])
    self.assertEqual(2.0, loss)
    self.assertTrue(chief_exclusive_monitor.is_active and
                    all_workers_monitor.is_active,
                    'All monitors must have been active.')
    self.assertTrue(chief_exclusive_monitor.has_step and
                    all_workers_monitor.has_step,
                    'All monitors must have a step.')
def testRegressionWithLogitsInput(self):
  head = head_lib._regression_head()
  with ops.Graph().as_default(), session.Session():
    model_fn_ops = head.create_model_fn_ops(
        {},
        labels=((0.,), (1.,), (1.,)),
        mode=model_fn.ModeKeys.TRAIN,
        train_op_fn=_noop_train_op,
        logits_input=((0., 0.), (0., 0.), (0., 0.)))
    self._assert_output_alternatives(model_fn_ops)
    w = ("regression_head/logits/weights:0",
         "regression_head/logits/biases:0")
    _assert_variables(
        self, expected_global=w, expected_model=w, expected_trainable=w)
    variables.global_variables_initializer().run()
    _assert_summary_tags(self, ["loss"])
    _assert_metrics(self, 2. / 3, {"loss": 2. / 3}, model_fn_ops)
def testRegressionWithCenteredBias(self):
  head = head_lib._regression_head(enable_centered_bias=True)
  with ops.Graph().as_default(), session.Session():
    model_fn_ops = head.create_model_fn_ops(
        {},
        labels=((0.,), (1.,), (1.,)),
        mode=model_fn.ModeKeys.TRAIN,
        train_op_fn=_noop_train_op,
        logits=((1.,), (1.,), (3.,)))
    self._assert_output_alternatives(model_fn_ops)
    _assert_variables(
        self,
        expected_global=(
            "regression_head/centered_bias_weight:0",
            "regression_head/regression_head/centered_bias_weight/Adagrad:0",
        ),
        expected_trainable=("regression_head/centered_bias_weight:0",))
    variables.global_variables_initializer().run()
    _assert_summary_tags(
        self, ["loss", "regression_head/centered_bias/bias_0"])
    _assert_metrics(self, 5. / 3, {"loss": 5. / 3}, model_fn_ops)