Python tensorflow.python.framework.ops module: add_to_collection() example source code

The following 50 code examples, extracted from open-source Python projects, illustrate how to use tensorflow.python.framework.ops.add_to_collection().
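
Before the project snippets, here is a minimal self-contained sketch (assuming TensorFlow 1.x graph mode) of the round trip every example below builds on: add_to_collection() appends a value to a named, graph-scoped list, and get_collection() reads that list back.

from tensorflow.python.framework import ops
import tensorflow as tf

# Register two tensors under an arbitrary collection name.
x = tf.constant(1.0, name="x")
y = tf.constant(2.0, name="y")
ops.add_to_collection("my_collection", x)
ops.add_to_collection("my_collection", y)

# get_collection returns everything registered so far on the default
# graph, in insertion order.
assert ops.get_collection("my_collection") == [x, y]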

Project: lsdc    Author: febert    | Project source | File source
def register_prior(variational, prior):
  """Associate a variational `DistributionTensor` with a `Distribution` prior.

  This is a helper function used in conjunction with `elbo` that allows users
  to specify the mapping between variational distributions and their priors
  without having to pass in `variational_with_prior` explicitly.

  Args:
    variational: `DistributionTensor` q(Z). Approximating distribution.
    prior: `Distribution` p(Z). Prior distribution.

  Returns:
    None

  Raises:
    TypeError: if `variational` is not a `DistributionTensor` or `prior` is
      not a `Distribution`.
  """
  if not isinstance(variational, st.StochasticTensor):
    raise TypeError("variational must be a DistributionTensor")
  if not isinstance(prior, distributions.BaseDistribution):
    raise TypeError("prior must be a BaseDistribution")
  ops.add_to_collection(VI_PRIORS, (variational, prior))
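
Editorial sketch (not part of the project source): the pairs registered above can be read back from the same VI_PRIORS collection key, which is how the `elbo` machinery consumes them.

# Iterate over every (variational, prior) pair registered so far.
for variational_q, prior_p in ops.get_collection(VI_PRIORS):
  print(variational_q, prior_p)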
Project: lsdc    Author: febert    | Project source | File source
def initialize(self, table):
    """Initializes the given `table` with `keys` and `values` tensors.

    Args:
      table: The table to initialize.

    Returns:
      The operation that initializes the table.

    Raises:
      TypeError: when the keys and values data types do not match the table
      key and value data types.
    """
    # pylint: disable=protected-access
    table._check_table_dtypes(self._keys.dtype, self._values.dtype)
    with ops.name_scope(self._name, values=[table]) as scope:
      init_op = gen_data_flow_ops._initialize_table(table.table_ref,
                                                    self._keys,
                                                    self._values,
                                                    name=scope)
    # pylint: enable=protected-access
    ops.add_to_collection(ops.GraphKeys.TABLE_INITIALIZERS, init_op)
    return init_op
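
For context, a hedged sketch of the consumer side: ops accumulated in `GraphKeys.TABLE_INITIALIZERS` are typically run once before training (assuming a TF 1.x `Session`; later 1.x versions wrap this same collection in `tf.tables_initializer()`).

with tf.Session() as sess:
  # Run every table initializer registered on the graph.
  for table_init_op in ops.get_collection(ops.GraphKeys.TABLE_INITIALIZERS):
    sess.run(table_init_op)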
Project: lsdc    Author: febert    | Project source | File source
def _get_concat_variable(name, shape, dtype, num_shards):
  """Get a sharded variable concatenated into one tensor."""
  sharded_variable = _get_sharded_variable(name, shape, dtype, num_shards)
  if len(sharded_variable) == 1:
    return sharded_variable[0]

  concat_name = name + "/concat"
  concat_full_name = vs.get_variable_scope().name + "/" + concat_name + ":0"
  for value in ops.get_collection(ops.GraphKeys.CONCATENATED_VARIABLES):
    if value.name == concat_full_name:
      return value

  concat_variable = array_ops.concat(0, sharded_variable, name=concat_name)
  ops.add_to_collection(ops.GraphKeys.CONCATENATED_VARIABLES,
                        concat_variable)
  return concat_variable
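
The function above uses a collection as a per-graph memo cache: look up by full name first, build and register on a miss. A stripped-down, hypothetical sketch of the same pattern (reusing the `vs` variable_scope alias from the snippet):

def get_or_create_cached(name, build_fn, collection="memo_cache"):
  # Collection-as-cache: reuse a previously built tensor when one with the
  # expected name already exists, otherwise build it once and register it.
  target_name = vs.get_variable_scope().name + "/" + name + ":0"
  for value in ops.get_collection(collection):
    if value.name == target_name:
      return value
  value = build_fn(name)
  ops.add_to_collection(collection, value)
  return value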
Project: lsdc    Author: febert    | Project source | File source
def register_prior(variational, prior):
  """Associate a variational `StochasticTensor` with a `Distribution` prior.

  This is a helper function used in conjunction with `elbo` that allows users
  to specify the mapping between variational distributions and their priors
  without having to pass in `variational_with_prior` explicitly.

  Args:
    variational: `StochasticTensor` q(Z). Approximating distribution.
    prior: `Distribution` p(Z). Prior distribution.

  Returns:
    None

  Raises:
    TypeError: if `variational` is not a `StochasticTensor` or `prior` is
      not a `Distribution`.
  """
  if not isinstance(variational, st.StochasticTensor):
    raise TypeError("variational must be a StochasticTensor")
  if not isinstance(prior, distributions.Distribution):
    raise TypeError("prior must be a Distribution")
  ops.add_to_collection(VI_PRIORS, (variational, prior))
Project: lsdc    Author: febert    | Project source | File source
def initialize(self, table):
    """Initializes the given `table` with `keys` and `values` tensors.

    Args:
      table: The table to initialize.

    Returns:
      The operation that initializes the table.

    Raises:
      TypeError: when the keys and values data types do not match the table
      key and value data types.
    """
    # pylint: disable=protected-access
    table._check_table_dtypes(self._keys.dtype, self._values.dtype)
    with ops.name_scope(self._name, values=[table]) as scope:
      init_op = gen_data_flow_ops._initialize_table(table.table_ref,
                                                    self._keys,
                                                    self._values,
                                                    name=scope)
    # pylint: enable=protected-access
    ops.add_to_collection(ops.GraphKeys.TABLE_INITIALIZERS, init_op)
    return init_op
Project: lsdc    Author: febert    | Project source | File source
def test_train_override_saver(self):
    with tf.Graph().as_default() as g, self.test_session(g):
      saver = tf.test.mock.Mock()
      tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
      with tf.control_dependencies(self._build_inference_graph()):
        train_op = tf.assign_add(tf.contrib.framework.get_global_step(), 1)
      self._assert_ckpt(self._output_dir, False)
      loss = learn.graph_actions._monitored_train(  # pylint: disable=protected-access
          g,
          output_dir=self._output_dir,
          train_op=train_op,
          loss_op=tf.constant(2.0),
          steps=1)
      self.assertEqual(2.0, loss)
      self._assert_ckpt(self._output_dir, False)
      self.assertTrue(saver.build.called)
      self.assertEqual(1, saver.save.call_count)


# TODO(ispir): remove the following tests once the deprecated train path is removed.
Project: lsdc    Author: febert    | Project source | File source
def _get_concat_variable(name, shape, dtype, num_shards):
  """Get a sharded variable concatenated into one tensor."""
  sharded_variable = _get_sharded_variable(name, shape, dtype, num_shards)
  if len(sharded_variable) == 1:
    return sharded_variable[0]

  concat_name = name + "/concat"
  concat_full_name = vs.get_variable_scope().name + "/" + concat_name + ":0"
  for value in ops.get_collection(ops.GraphKeys.CONCATENATED_VARIABLES):
    if value.name == concat_full_name:
      return value

  concat_variable = array_ops.concat(0, sharded_variable, name=concat_name)
  ops.add_to_collection(ops.GraphKeys.CONCATENATED_VARIABLES,
                        concat_variable)
  return concat_variable
Project: ChineseNER    Author: zjy-ucas    | Project source | File source
def _get_concat_variable(name, shape, dtype, num_shards):
  """Get a sharded variable concatenated into one tensor."""
  sharded_variable = _get_sharded_variable(name, shape, dtype, num_shards)
  if len(sharded_variable) == 1:
    return sharded_variable[0]

  concat_name = name + "/concat"
  concat_full_name = vs.get_variable_scope().name + "/" + concat_name + ":0"
  for value in ops.get_collection(ops.GraphKeys.CONCATENATED_VARIABLES):
    if value.name == concat_full_name:
      return value

  concat_variable = array_ops.concat(sharded_variable, 0, name=concat_name)
  ops.add_to_collection(ops.GraphKeys.CONCATENATED_VARIABLES,
                        concat_variable)
  return concat_variable
Project: LSTM-CRF-For-Named-Entity-Recognition    Author: zpppy    | Project source | File source
def _get_concat_variable(name, shape, dtype, num_shards):
  """Get a sharded variable concatenated into one tensor."""
  sharded_variable = _get_sharded_variable(name, shape, dtype, num_shards)
  if len(sharded_variable) == 1:
    return sharded_variable[0]

  concat_name = name + "/concat"
  concat_full_name = vs.get_variable_scope().name + "/" + concat_name + ":0"
  for value in ops.get_collection(ops.GraphKeys.CONCATENATED_VARIABLES):
    if value.name == concat_full_name:
      return value

  concat_variable = array_ops.concat(sharded_variable, 0, name=concat_name)
  ops.add_to_collection(ops.GraphKeys.CONCATENATED_VARIABLES,
                        concat_variable)
  return concat_variable
Project: DL-Benchmarks    Author: DL-Benchmarks    | Project source | File source
def _get_concat_variable(name, shape, dtype, num_shards):
  """Get a sharded variable concatenated into one tensor."""
  sharded_variable = _get_sharded_variable(name, shape, dtype, num_shards)
  if len(sharded_variable) == 1:
    return sharded_variable[0]

  concat_name = name + "/concat"
  concat_full_name = vs.get_variable_scope().name + "/" + concat_name + ":0"
  for value in ops.get_collection(ops.GraphKeys.CONCATENATED_VARIABLES):
    if value.name == concat_full_name:
      return value

  concat_variable = array_ops.concat(0, sharded_variable, name=concat_name)
  ops.add_to_collection(ops.GraphKeys.CONCATENATED_VARIABLES,
                        concat_variable)
  return concat_variable
Project: num-seq-recognizer    Author: gmlove    | Project source | File source
def fc_layers(net,
              scope,
              end_points_collection,
              num_classes=1000,
              is_training=True,
              dropout_keep_prob=0.5,
              spatial_squeeze=True,
              name_prefix=None):
  full_scope_name = lambda scope_name: scope_name if name_prefix is None else '%s_%s' % (name_prefix, scope_name)
  # Use conv2d instead of fully_connected layers.
  with slim.arg_scope([slim.conv2d],
                      weights_initializer=trunc_normal(0.005),
                      biases_initializer=tf.constant_initializer(0.1),
                      outputs_collections=[end_points_collection]):
    net = slim.conv2d(net, num_classes, [1, 1],
                      activation_fn=None,
                      normalizer_fn=None,
                      biases_initializer=tf.zeros_initializer(),
                      scope=full_scope_name('fc8'))

  if spatial_squeeze:
    net = tf.squeeze(net, [1, 2], name=full_scope_name('fc8/squeezed'))
    ops.add_to_collection(end_points_collection, net)
  return net, end_points_collection
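
Editorial note: both the `outputs_collections` argument and the explicit `add_to_collection` call feed the same `end_points_collection`, so the registered tensors can be gathered back into a dict afterwards, much as `slim.utils.convert_collection_to_dict` does. A one-line sketch:

end_points = {t.name: t for t in ops.get_collection(end_points_collection)}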
Project: PLSTM    Author: Enny1991    | Project source | File source
def _get_concat_variable(name, shape, dtype, num_shards):
    """Get a sharded variable concatenated into one tensor."""
    sharded_variable = _get_sharded_variable(name, shape, dtype, num_shards)
    if len(sharded_variable) == 1:
        return sharded_variable[0]

    concat_name = name + "/concat"
    concat_full_name = vs.get_variable_scope().name + "/" + concat_name + ":0"
    for value in ops.get_collection(ops.GraphKeys.CONCATENATED_VARIABLES):
        if value.name == concat_full_name:
            return value

    concat_variable = array_ops.concat(0, sharded_variable, name=concat_name)
    ops.add_to_collection(ops.GraphKeys.CONCATENATED_VARIABLES,
                          concat_variable)
    return concat_variable
Project: diversity_based_attention    Author: PrekshaNema25    | Project source | File source
def _get_concat_variable(name, shape, dtype, num_shards):
  """Get a sharded variable concatenated into one tensor."""
  sharded_variable = _get_sharded_variable(name, shape, dtype, num_shards)
  if len(sharded_variable) == 1:
    return sharded_variable[0]

  concat_name = name + "/concat"
  concat_full_name = vs.get_variable_scope().name + "/" + concat_name + ":0"
  for value in ops.get_collection(ops.GraphKeys.CONCATENATED_VARIABLES):
    if value.name == concat_full_name:
      return value

  concat_variable = array_ops.concat(0, sharded_variable, name=concat_name)
  ops.add_to_collection(ops.GraphKeys.CONCATENATED_VARIABLES,
                        concat_variable)
  return concat_variable
Project: ROLO    Author: Guanghan    | Project source | File source
def _get_concat_variable(name, shape, dtype, num_shards):
  """Get a sharded variable concatenated into one tensor."""
  sharded_variable = _get_sharded_variable(name, shape, dtype, num_shards)
  if len(sharded_variable) == 1:
    return sharded_variable[0]

  concat_name = name + "/concat"
  concat_full_name = vs.get_variable_scope().name + "/" + concat_name + ":0"
  for value in ops.get_collection(ops.GraphKeys.CONCATENATED_VARIABLES):
    if value.name == concat_full_name:
      return value

  concat_variable = array_ops.concat(0, sharded_variable, name=concat_name)
  ops.add_to_collection(ops.GraphKeys.CONCATENATED_VARIABLES,
                        concat_variable)
  return concat_variable
Project: Optimization    Author: tdozat    | Project source | File source
def _create_slots(self, grads_and_vars):
    """"""

    for g_t, x_tm1 in grads_and_vars:
      if self._save_step:
        self._ones_slot(x_tm1, 's', self._name)
      if self._save_grad:
        self._ones_slot(x_tm1, 'g', self._name)
      if self._chi > 0:
        # Note: add_to_collection takes (name, value); the slot variable is
        # the value being registered under MOVING_AVERAGE_VARIABLES.
        ops.add_to_collection(ops.GraphKeys.MOVING_AVERAGE_VARIABLES,
                              self._zeros_slot(x_tm1, 'x', self._name))
        if isinstance(g_t, ops.Tensor):
          self._zero_slot(x_tm1, 'x/tm1', self._name)
        else:
          self._zeros_idx_slot(x_tm1, 'x/tm1', self._name)

  #=============================================================
Project: Optimization    Author: tdozat    | Project source | File source
def _create_slots(self, grads_and_vars):
    """"""

    for g_t, x_tm1 in grads_and_vars:
      if self._save_step:
        self._ones_slot(x_tm1, 's', self._name)
      if self._save_grad:
        self._ones_slot(x_tm1, 'g', self._name)
      if self._chi > 0:
        # Note: add_to_collection takes (name, value); the slot variable is
        # the value being registered under MOVING_AVERAGE_VARIABLES.
        ops.add_to_collection(ops.GraphKeys.MOVING_AVERAGE_VARIABLES,
                              self._zeros_slot(x_tm1, 'x', self._name))
        if isinstance(g_t, ops.Tensor):
          self._zero_slot(x_tm1, 'x/tm1', self._name)
        else:
          self._zeros_idx_slot(x_tm1, 'x/tm1', self._name)

  #=============================================================
Project: DeepLearning_VirtualReality_BigData_Project    Author: rashmitripathi    | Project source | File source
def register_prior(variational, prior):
  """Associate a variational `StochasticTensor` with a `Distribution` prior.

  This is a helper function used in conjunction with `elbo` that allows users
  to specify the mapping between variational distributions and their priors
  without having to pass in `variational_with_prior` explicitly.

  Args:
    variational: `StochasticTensor` q(Z). Approximating distribution.
    prior: `Distribution` p(Z). Prior distribution.

  Returns:
    None

  Raises:
    TypeError: if `variational` is not a `StochasticTensor` or `prior` is
      not a `Distribution`.
  """
  if not isinstance(variational, st.StochasticTensor):
    raise TypeError("variational must be a StochasticTensor")
  if not isinstance(prior, distribution.Distribution):
    raise TypeError("prior must be a Distribution")
  ops.add_to_collection(VI_PRIORS, (variational, prior))
Project: DeepLearning_VirtualReality_BigData_Project    Author: rashmitripathi    | Project source | File source
def initialize(self, table):
    """Initializes the given `table` with `keys` and `values` tensors.

    Args:
      table: The table to initialize.

    Returns:
      The operation that initializes the table.

    Raises:
      TypeError: when the keys and values data types do not match the table
      key and value data types.
    """
    table.check_table_dtypes(self._keys.dtype, self._values.dtype)
    with ops.name_scope(self._name, values=[table]) as scope:
      # pylint: disable=protected-access
      init_op = gen_data_flow_ops._initialize_table(table.table_ref,
                                                    self._keys,
                                                    self._values,
                                                    name=scope)
      # pylint: enable=protected-access
    ops.add_to_collection(ops.GraphKeys.TABLE_INITIALIZERS, init_op)
    return init_op
Project: DeepLearning_VirtualReality_BigData_Project    Author: rashmitripathi    | Project source | File source
def test_evaluate_ready_for_local_init(self):
    with ops.Graph().as_default() as g, self.test_session(g):
      variables_lib.create_global_step()
      v = variables.Variable(1.0)
      w = variables.Variable(
          v + 1, collections=[ops.GraphKeys.LOCAL_VARIABLES], trainable=False)
      ready_for_local_init_op = variables.report_uninitialized_variables(
          variables.global_variables())
      ops.add_to_collection(ops.GraphKeys.READY_FOR_LOCAL_INIT_OP,
                            ready_for_local_init_op)
      _ = learn.graph_actions.evaluate(
          g,
          output_dir=self._output_dir,
          checkpoint_path=None,
          eval_dict={'a': v},
          max_steps=1)
Project: DeepLearning_VirtualReality_BigData_Project    Author: rashmitripathi    | Project source | File source
def test_evaluate_with_saver(self):
    with ops.Graph().as_default() as g, self.test_session(g):
      _, _, out = self._build_inference_graph()
      ops.add_to_collection(ops.GraphKeys.SAVERS, saver_lib.Saver())
      writer = learn.graph_actions.get_summary_writer(self._output_dir)
      self._assert_summaries(self._output_dir, writer, expected_session_logs=[])
      results = learn.graph_actions.evaluate(
          g,
          output_dir=self._output_dir,
          checkpoint_path=None,
          eval_dict={'a': out},
          max_steps=1)
      self.assertEqual(({'a': 6.0}, 0), results)
      self._assert_summaries(
          self._output_dir,
          writer,
          expected_summaries={0: {
              'a': 6.0
          }},
          expected_session_logs=[])
Project: DeepLearning_VirtualReality_BigData_Project    Author: rashmitripathi    | Project source | File source
def test_train_override_saver(self):
    with ops.Graph().as_default() as g, self.test_session(g):
      with ops.control_dependencies(self._build_inference_graph()):
        train_op = state_ops.assign_add(variables_lib.get_global_step(), 1)
      self._assert_ckpt(self._output_dir, False)
      real_saver = saver_lib.Saver()
      saver = test.mock.Mock(wraps=real_saver, saver_def=real_saver.saver_def)
      ops.add_to_collection(ops.GraphKeys.SAVERS, saver)
      loss = learn.graph_actions._monitored_train(  # pylint: disable=protected-access
          g,
          output_dir=self._output_dir,
          train_op=train_op,
          loss_op=constant_op.constant(2.0),
          steps=1)
      self.assertEqual(2.0, loss)
      self._assert_ckpt(self._output_dir, True)
      self.assertTrue(saver.build.called)
      self.assertEqual(1, saver.save.call_count)


# TODO(ispir): remove the following tests once the deprecated train path is removed.
Project: DeepLearning_VirtualReality_BigData_Project    Author: rashmitripathi    | Project source | File source
def testUpdateOpFromCollection(self):
    optimizers = [
        "SGD", gradient_descent.GradientDescentOptimizer,
        gradient_descent.GradientDescentOptimizer(learning_rate=0.1)
    ]
    for optimizer in optimizers:
      with ops.Graph().as_default() as g, self.test_session(graph=g) as session:
        x, var, loss, global_step = _setup_model()
        update_var = variable_scope.get_variable(
            "update", [], initializer=init_ops.constant_initializer(10))
        update_op = state_ops.assign(update_var, 20)
        ops.add_to_collection(ops.GraphKeys.UPDATE_OPS, update_op)
        train = optimizers_lib.optimize_loss(
            loss, global_step, learning_rate=0.1, optimizer=optimizer)
        variables.global_variables_initializer().run()
        session.run(train, feed_dict={x: 5})
        var_value, update_var_value, global_step_value = session.run(
            [var, update_var, global_step])
        self.assertEqual(var_value, 9.5)
        self.assertEqual(update_var_value, 20)
        self.assertEqual(global_step_value, 1)
Project: Tensormodels    Author: asheshjain399    | Project source | File source
def _get_arg_stack():
  stack = ops.get_collection(_ARGSTACK_KEY)
  if stack:
    return stack[0]
  else:
    stack = [{}]
    ops.add_to_collection(_ARGSTACK_KEY, stack)
    return stack
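
Because collections live on the graph, this is also a per-graph singleton idiom: each `tf.Graph` gets its own arg stack. A hedged sketch of that isolation:

# Each graph carries an independent _ARGSTACK_KEY collection, so a stack
# created in one graph is never visible from another.
with tf.Graph().as_default():
  stack_a = _get_arg_stack()
  assert _get_arg_stack() is stack_a      # same graph -> cached stack
with tf.Graph().as_default():
  assert _get_arg_stack() is not stack_a  # fresh graph -> fresh stack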
Project: piecewisecrf    Author: Vaan5    | Project source | File source
def _get_arg_stack():
  stack = ops.get_collection(_ARGSTACK_KEY)
  if stack:
    return stack[0]
  else:
    stack = [{}]
    ops.add_to_collection(_ARGSTACK_KEY, stack)
    return stack
Project: terngrad    Author: wenwei202    | Project source | File source
def _get_arg_stack():
  stack = ops.get_collection(_ARGSTACK_KEY)
  if stack:
    return stack[0]
  else:
    stack = [{}]
    ops.add_to_collection(_ARGSTACK_KEY, stack)
    return stack
Project: FPN    Author: xmyqsh    | Project source | File source
def build_image_summary(self):
        """
        A simple graph for writing an image summary
        :return:
        """
        log_image_data = tf.placeholder(tf.uint8, [None, None, 3])
        log_image_name = tf.placeholder(tf.string)
        # import tensorflow.python.ops.gen_logging_ops as logging_ops
        from tensorflow.python.ops import gen_logging_ops
        from tensorflow.python.framework import ops as _ops
        log_image = gen_logging_ops._image_summary(log_image_name, tf.expand_dims(log_image_data, 0), max_images=1)
        _ops.add_to_collection(_ops.GraphKeys.SUMMARIES, log_image)
        # log_image = tf.summary.image(log_image_name, tf.expand_dims(log_image_data, 0), max_outputs=1)
        return log_image, log_image_data, log_image_name
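
Because the op is appended to `GraphKeys.SUMMARIES` by hand, `tf.summary.merge_all()` (which reads that same collection) will pick it up alongside summaries created through the public API; a one-line sketch:

# merge_all() scans GraphKeys.SUMMARIES, so the manually registered
# image summary is merged together with all tf.summary.* ops.
merged_summary_op = tf.summary.merge_all()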
Project: lsdc    Author: febert    | Project source | File source
def add_model_variable(var):
  """Adds a variable to the `GraphKeys.MODEL_VARIABLES` collection.

  Args:
    var: a variable.
  """
  if var not in ops.get_collection(ops.GraphKeys.MODEL_VARIABLES):
    ops.add_to_collection(ops.GraphKeys.MODEL_VARIABLES, var)
Project: lsdc    Author: febert    | Project source | File source
def __init__(self):
    # Add self to this graph's Stochastic Tensor collection for
    # purposes of later performing correct surrogate loss calculation.
    ops.add_to_collection(STOCHASTIC_TENSOR_COLLECTION, self)
Project: lsdc    Author: febert    | Project source | File source
def add_loss(loss, loss_collection=ops.GraphKeys.LOSSES):
  """Adds a externally defined loss to the collection of losses.

  Args:
    loss: A loss `Tensor`.
    loss_collection: Optional collection to add the loss to.
  """
  if loss_collection:
    ops.add_to_collection(loss_collection, loss)
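
On the consumer side, losses registered this way are usually summed back out of the same collection; a minimal sketch, assuming the `math_ops` import used elsewhere in this module:

# Sum every loss in GraphKeys.LOSSES into a single scalar, as
# tf.losses.get_total_loss does in later TF versions.
registered_losses = ops.get_collection(ops.GraphKeys.LOSSES)
if registered_losses:
  total_loss = math_ops.add_n(registered_losses, name="total_loss")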
Project: lsdc    Author: febert    | Project source | File source
def test_evaluate_with_saver(self):
    with tf.Graph().as_default() as g, self.test_session(g):
      _, _, out = self._build_inference_graph()
      tf.add_to_collection(tf.GraphKeys.SAVERS, tf.train.Saver())

      self._assert_summaries(self._output_dir, expected_session_logs=[])
      results = learn.graph_actions.evaluate(
          g, output_dir=self._output_dir, checkpoint_path=None,
          eval_dict={'a': out}, max_steps=1)
      self.assertEqual(({'a': 6.0}, 0), results)
      self._assert_summaries(
          self._output_dir, expected_summaries={0: {'a': 6.0}},
          expected_session_logs=[])
Project: lsdc    Author: febert    | Project source | File source
def test_train_worker_monitor(self):
    # We need to explicitly set device due to check on non-chief workers
    # requiring all variables to have a device assigned.
    with tf.Graph().as_default() as g, g.device('/cpu:0'):
      global_step = tf.contrib.framework.create_global_step(g)
      train_op = tf.assign_add(global_step, 1)
      loss_op = tf.constant(2.0)
      tf.scalar_summary('loss', loss_op)
      # Add explicit "local" init op to initialize all variables
      # as there's no chief to init here.
      init_op = variables.initialize_all_variables()
      ops.add_to_collection(ops.GraphKeys.LOCAL_INIT_OP, init_op)
      # Create worker monitors where one should be active on the worker
      # and the other chief exclusive.
      chief_exclusive_monitor = _BaseMonitorWrapper(False)
      all_workers_monitor = _BaseMonitorWrapper(True)
      with self.test_session(g):
        loss = learn.graph_actions.train(
            g, output_dir=self._output_dir,
            global_step_tensor=global_step,
            train_op=train_op, loss_op=loss_op,
            supervisor_is_chief=False, steps=1,
            monitors=[chief_exclusive_monitor, all_workers_monitor])
      self.assertEqual(2.0, loss)
      self.assertTrue(not chief_exclusive_monitor.is_active and
                      all_workers_monitor.is_active,
                      'Only non-chief runnable monitor must have been active.')
      self.assertTrue(not chief_exclusive_monitor.has_step and
                      all_workers_monitor.has_step,
                      'Only non-chief runnable monitor must have a step.')
Project: lsdc    Author: febert    | Project source | File source
def _get_saver():
  """Lazy init and return saver."""
  saver = _get_first_op_from_collection(ops.GraphKeys.SAVERS)
  if saver is not None:
    if saver:
      saver = saver[0]
    else:
      saver = None
  if saver is None and variables.all_variables():
    saver = tf_saver.Saver()
    ops.add_to_collection(ops.GraphKeys.SAVERS, saver)
  return saver
Project: lsdc    Author: febert    | Project source | File source
def _get_saver():
  """Lazy init and return saver."""
  saver = _get_first_op_from_collection(ops.GraphKeys.SAVERS)
  if saver is None and variables.all_variables():
    saver = tf_saver.Saver()
    ops.add_to_collection(ops.GraphKeys.SAVERS, saver)
  return saver
Project: lsdc    Author: febert    | Project source | File source
def _get_local_init_op():
  local_init_op = _get_first_op_from_collection(
      ops.GraphKeys.LOCAL_INIT_OP)
  if local_init_op is None:
    op_list = [variables.initialize_local_variables(),
               data_flow_ops.initialize_all_tables()]
    if op_list:
      local_init_op = control_flow_ops.group(*op_list)
      ops.add_to_collection(ops.GraphKeys.LOCAL_INIT_OP, local_init_op)
  return local_init_op
Project: lsdc    Author: febert    | Project source | File source
def _get_or_default(arg_name, collection_key, default_constructor):
    """Get from cache or create a default operation."""
    elements = ops.get_collection(collection_key)
    if elements:
      if len(elements) > 1:
        raise RuntimeError('More than one item in the collection "%s". '
                           'Please indicate which one to use by passing it to '
                           'the tf.Scaffold constructor as: '
                           'tf.Scaffold(%s=item to use)' %
                           (collection_key, arg_name))
      return elements[0]
    op = default_constructor()
    if op is not None:
      ops.add_to_collection(collection_key, op)
    return op
Project: lsdc    Author: febert    | Project source | File source
def apply_regularization(regularizer, weights_list=None):
  """Returns the summed penalty by applying `regularizer` to the `weights_list`.

  Adding a regularization penalty over the layer weights and embedding weights
  can help prevent overfitting the training data. Regularization over layer
  biases is less common/useful, but assuming proper data preprocessing/mean
  subtraction, it usually shouldn't hurt much either.

  Args:
    regularizer: A function that takes a single `Tensor` argument and returns
      a scalar `Tensor` output.
    weights_list: List of weights `Tensors` or `Variables` to apply
      `regularizer` over. Defaults to the `GraphKeys.WEIGHTS` collection if
      `None`.

  Returns:
    A scalar representing the overall regularization penalty.

  Raises:
    ValueError: If `regularizer` does not return a scalar output, or if we find
        no weights.
  """
  if not weights_list:
    weights_list = ops.get_collection(ops.GraphKeys.WEIGHTS)
  if not weights_list:
    raise ValueError('No weights to regularize.')
  with ops.name_scope('get_regularization_penalty',
                      values=weights_list) as scope:
    penalties = [regularizer(w) for w in weights_list]
    for p in penalties:
      if p.get_shape().ndims != 0:
        raise ValueError('regularizer must return a scalar Tensor instead of a '
                         'Tensor with rank %d.' % p.get_shape().ndims)

    summed_penalty = math_ops.add_n(penalties, name=scope)
    ops.add_to_collection(ops.GraphKeys.REGULARIZATION_LOSSES, summed_penalty)
    return summed_penalty
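
A hedged usage sketch (names hypothetical): with weights registered under `GraphKeys.WEIGHTS`, `apply_regularization` both returns the penalty and appends it to `GraphKeys.REGULARIZATION_LOSSES`.

# Hypothetical usage: L2-regularize everything in the WEIGHTS collection.
w = tf.Variable(tf.ones([3, 3]), name="w")
ops.add_to_collection(ops.GraphKeys.WEIGHTS, w)
l2 = lambda t: 0.01 * tf.nn.l2_loss(t)
penalty = apply_regularization(l2)  # scalar; also in REGULARIZATION_LOSSES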
Project: lsdc    Author: febert    | Project source | File source
def add_model_variable(var):
  """Adds a variable to the `GraphKeys.MODEL_VARIABLES` collection.

  Args:
    var: a variable.
  """
  if var not in ops.get_collection(ops.GraphKeys.MODEL_VARIABLES):
    ops.add_to_collection(ops.GraphKeys.MODEL_VARIABLES, var)
Project: lsdc    Author: febert    | Project source | File source
def add_loss(loss, loss_collection=ops.GraphKeys.LOSSES):
  """Adds a externally defined loss to the collection of losses.

  Args:
    loss: A loss `Tensor`.
    loss_collection: Optional collection to add the loss to.
  """
  if loss_collection:
    ops.add_to_collection(loss_collection, loss)
Project: lsdc    Author: febert    | Project source | File source
def initialize(self, table):
    """Initializes the table from a text file.

    Args:
      table: The table to be initialized.

    Returns:
      The operation that initializes the table.

    Raises:
      TypeError: when the keys and values data types do not match the table
      key and value data types.
    """
    # pylint: disable=protected-access
    table._check_table_dtypes(self.key_dtype, self.value_dtype)
    with ops.name_scope(self._name, "text_file_init", [table]) as scope:
      filename = ops.convert_to_tensor(self._filename,
                                       dtypes.string,
                                       name="asset_filepath")
      init_op = gen_data_flow_ops._initialize_table_from_text_file(
          table.table_ref,
          filename,
          self._key_index,
          self._value_index,
          -1 if self._vocab_size is None else self._vocab_size,
          self._delimiter,
          name=scope)
    # pylint: enable=protected-access
    ops.add_to_collection(ops.GraphKeys.TABLE_INITIALIZERS, init_op)
    ops.add_to_collection(ops.GraphKeys.ASSET_FILEPATHS, filename)
    return init_op
Project: lsdc    Author: febert    | Project source | File source
def test_evaluate_with_saver(self):
    with tf.Graph().as_default() as g, self.test_session(g):
      _, _, out = self._build_inference_graph()
      tf.add_to_collection(tf.GraphKeys.SAVERS, tf.train.Saver())
      writer = learn.graph_actions.get_summary_writer(self._output_dir)
      self._assert_summaries(self._output_dir, writer, expected_session_logs=[])
      results = learn.graph_actions.evaluate(
          g, output_dir=self._output_dir, checkpoint_path=None,
          eval_dict={'a': out}, max_steps=1)
      self.assertEqual(({'a': 6.0}, 0), results)
      self._assert_summaries(
          self._output_dir, writer, expected_summaries={0: {'a': 6.0}},
          expected_session_logs=[])
Project: lsdc    Author: febert    | Project source | File source
def test_train_worker_monitor(self):
    # We need to explicitly set device due to check on non-chief workers
    # requiring all variables to have a device assigned.
    with tf.Graph().as_default() as g, g.device('/cpu:0'):
      global_step = tf.contrib.framework.create_global_step(g)
      train_op = tf.assign_add(global_step, 1)
      loss_op = tf.constant(2.0)
      tf.summary.scalar('loss', loss_op)
      # Add explicit "local" init op to initialize all variables
      # as there's no chief to init here.
      init_op = variables.global_variables_initializer()
      ops.add_to_collection(ops.GraphKeys.LOCAL_INIT_OP, init_op)
      # Create worker monitors where one should be active on the worker
      # and the other chief exclusive.
      chief_exclusive_monitor = _BaseMonitorWrapper(False)
      all_workers_monitor = _BaseMonitorWrapper(True)
      with self.test_session(g):
        loss = learn.graph_actions.train(
            g, output_dir=self._output_dir,
            global_step_tensor=global_step,
            train_op=train_op, loss_op=loss_op,
            supervisor_is_chief=False, steps=1,
            monitors=[chief_exclusive_monitor, all_workers_monitor])
      self.assertEqual(2.0, loss)
      self.assertTrue(not chief_exclusive_monitor.is_active and
                      all_workers_monitor.is_active,
                      'Only non-chief runnable monitor must have been active.')
      self.assertTrue(not chief_exclusive_monitor.has_step and
                      all_workers_monitor.has_step,
                      'Only non-chief runnable monitor must have a step.')
Project: lsdc    Author: febert    | Project source | File source
def _get_saver():
  """Lazy init and return saver."""
  saver = _get_first_op_from_collection(ops.GraphKeys.SAVERS)
  if saver is not None:
    if saver:
      saver = saver[0]
    else:
      saver = None
  if saver is None and variables.global_variables():
    saver = tf_saver.Saver(write_version=saver_pb2.SaverDef.V1)
    ops.add_to_collection(ops.GraphKeys.SAVERS, saver)
  return saver
Project: lsdc    Author: febert    | Project source | File source
def _get_saver():
  """Lazy init and return saver."""
  saver = _get_first_op_from_collection(ops.GraphKeys.SAVERS)
  if saver is None and variables.global_variables():
    saver = tf_saver.Saver()
    ops.add_to_collection(ops.GraphKeys.SAVERS, saver)
  return saver
Project: lsdc    Author: febert    | Project source | File source
def _get_local_init_op():
  local_init_op = _get_first_op_from_collection(
      ops.GraphKeys.LOCAL_INIT_OP)
  if local_init_op is None:
    op_list = [variables.local_variables_initializer(),
               data_flow_ops.initialize_all_tables()]
    if op_list:
      local_init_op = control_flow_ops.group(*op_list)
      ops.add_to_collection(ops.GraphKeys.LOCAL_INIT_OP, local_init_op)
  return local_init_op
Project: lsdc    Author: febert    | Project source | File source
def apply_regularization(regularizer, weights_list=None):
  """Returns the summed penalty by applying `regularizer` to the `weights_list`.

  Adding a regularization penalty over the layer weights and embedding weights
  can help prevent overfitting the training data. Regularization over layer
  biases is less common/useful, but assuming proper data preprocessing/mean
  subtraction, it usually shouldn't hurt much either.

  Args:
    regularizer: A function that takes a single `Tensor` argument and returns
      a scalar `Tensor` output.
    weights_list: List of weights `Tensors` or `Variables` to apply
      `regularizer` over. Defaults to the `GraphKeys.WEIGHTS` collection if
      `None`.

  Returns:
    A scalar representing the overall regularization penalty.

  Raises:
    ValueError: If `regularizer` does not return a scalar output, or if we find
        no weights.
  """
  if not weights_list:
    weights_list = ops.get_collection(ops.GraphKeys.WEIGHTS)
  if not weights_list:
    raise ValueError('No weights to regularize.')
  with ops.name_scope('get_regularization_penalty',
                      values=weights_list) as scope:
    penalties = [regularizer(w) for w in weights_list]
    penalties = [
        p if p is not None else constant_op.constant(0.0) for p in penalties
    ]
    for p in penalties:
      if p.get_shape().ndims != 0:
        raise ValueError('regularizer must return a scalar Tensor instead of a '
                         'Tensor with rank %d.' % p.get_shape().ndims)

    summed_penalty = math_ops.add_n(penalties, name=scope)
    ops.add_to_collection(ops.GraphKeys.REGULARIZATION_LOSSES, summed_penalty)
    return summed_penalty
Project: TF_Deformable_Net    Author: Zardinality    | Project source | File source
def build_image_summary(self):
        """
        A simple graph for writing an image summary
        :return:
        """
        log_image_data = tf.placeholder(tf.uint8, [None, None, 3])
        log_image_name = tf.placeholder(tf.string)
        # import tensorflow.python.ops.gen_logging_ops as logging_ops
        from tensorflow.python.ops import gen_logging_ops
        from tensorflow.python.framework import ops as _ops
        log_image = gen_logging_ops._image_summary(log_image_name, tf.expand_dims(log_image_data, 0), max_images=1)
        _ops.add_to_collection(_ops.GraphKeys.SUMMARIES, log_image)
        # log_image = tf.summary.image(log_image_name, tf.expand_dims(log_image_data, 0), max_outputs=1)
        return log_image, log_image_data, log_image_name
Project: the-neural-perspective    Author: GokuMohandas    | Project source | File source
def _get_arg_stack():
  stack = ops.get_collection(_ARGSTACK_KEY)
  if stack:
    return stack[0]
  else:
    stack = [{}]
    ops.add_to_collection(_ARGSTACK_KEY, stack)
    return stack
Project: InceptionV3_TensorFlow    Author: MasazI    | Project source | File source
def _get_arg_stack():
  stack = ops.get_collection(_ARGSTACK_KEY)
  if stack:
    return stack[0]
  else:
    stack = [{}]
    ops.add_to_collection(_ARGSTACK_KEY, stack)
    return stack
Project: darkskies-challenge    Author: LiberiFatali    | Project source | File source
def _get_arg_stack():
  stack = ops.get_collection(_ARGSTACK_KEY)
  if stack:
    return stack[0]
  else:
    stack = [{}]
    ops.add_to_collection(_ARGSTACK_KEY, stack)
    return stack
Project: dcn.tf    Author: beopst    | Project source | File source
def _get_arg_stack():
  stack = ops.get_collection(_ARGSTACK_KEY)
  if stack:
    return stack[0]
  else:
    stack = [{}]
    ops.add_to_collection(_ARGSTACK_KEY, stack)
    return stack