一尘不染

Tensorflow:如何替换计算图中的节点?

python

如果您有两个不相交的图,并且想要链接它们,请执行以下操作:

x = tf.placeholder('float')
y = f(x)

y = tf.placeholder('float')
z = f(y)

到这个:

x = tf.placeholder('float')
y = f(x)
z = g(y)

有没有办法做到这一点?在某些情况下,这似乎可以使施工更容易。

例如,如果您有一个图,其输入图像为tf.placeholder,并且想要优化输入图像(深梦风格),是否有一种方法可以仅用tf.variable节点替换占位符?还是在构建图形之前必须考虑一下?


阅读 244

收藏
2020-12-20

共1个答案

一尘不染

TL;
DR:如果可以将这两个计算定义为Python函数,则应该这样做。如果不能,那么TensorFlow中有更多高级功能可用于序列化和导入图形,这使您可以从不同来源组成图形。

在TensorFlow中执行此操作的一种方法是将不相交的计算构建为单独的tf.Graph对象,然后使用以下命令将它们转换为序列化的协议缓冲区Graph.as_graph_def()

with tf.Graph().as_default() as g_1:
  input = tf.placeholder(tf.float32, name="input")
  y = f(input)
  # NOTE: using identity to get a known name for the output tensor.
  output = tf.identity(y, name="output")

gdef_1 = g_1.as_graph_def()

with tf.Graph().as_default() as g_2:  # NOTE: g_2 not g_1       
  input = tf.placeholder(tf.float32, name="input")
  z = g(input)
  output = tf.identity(y, name="output")

gdef_2 = g_2.as_graph_def()

然后,你可以撰写gdef_1gdef_2成第三曲线,使用tf.import_graph_def()

with tf.Graph().as_default() as g_combined:
  x = tf.placeholder(tf.float32, name="")

  # Import gdef_1, which performs f(x).
  # "input:0" and "output:0" are the names of tensors in gdef_1.
  y, = tf.import_graph_def(gdef_1, input_map={"input:0": x},
                           return_elements=["output:0"])

  # Import gdef_2, which performs g(y)
  z, = tf.import_graph_def(gdef_2, input_map={"input:0": y},
                           return_elements=["output:0"]
2020-12-20