小能豆

如何在图构建时获取张量(在 TensorFlow 中)的维度?

python

我正在尝试一个行为不符合预期的操作。

graph = tf.Graph()
with graph.as_default():
  train_dataset = tf.placeholder(tf.int32, shape=[128, 2])
  embeddings = tf.Variable(
    tf.random_uniform([50000, 64], -1.0, 1.0))
  embed = tf.nn.embedding_lookup(embeddings, train_dataset)
  embed = tf.reduce_sum(embed, reduction_indices=0)

所以我需要知道 Tensor 的尺寸embed。我知道它可以在运行时完成,但对于这样一个简单的操作来说工作量太大了。有什么更简单的方法可以做到这一点?


阅读 72

收藏
2024-05-17

共1个答案

小能豆

在TensorFlow中,您可以使用get_shape()方法获取张量的形状信息,这样可以避免在运行时执行操作。在您的代码中,您可以像这样使用get_shape()方法获取张量embed的形状:

print(embed.get_shape())

这将输出embed张量的形状信息,而不需要在运行时执行操作。

2024-05-17