假设我有一个DATA形状为 的张量(M, N, 2)。我还有另一个IND由零和一组成的形状为 (N) 的张量。
DATA
(M, N, 2)
IND
如果IND(i)==1是,DATA(:,i,0)那么DATA(:,i,1)就必须交换。如果IND(i)==0他们不愿意交换。
IND(i)==1
DATA(:,i,0)
DATA(:,i,1)
IND(i)==0
我该怎么做?我知道可以通过 来完成tf.gather_nd,但我不知道怎么做。
tf.gather_nd
以下是使用tf.equal、tf.where、和 的一个可能解决方案tf.scater_nd_update:tf.gather_ndtf.reverse_v2
tf.equal
tf.where
tf.scater_nd_update
tf.reverse_v2
data = tf.Variable([[[1, 2], [2, 3], [3, 4], [4, 5], [5, 6]]]) # shape=(1,5,2) # reverse elements where ind is 1 ind = tf.constant([1, 0, 1, 0, 1]) # shape(5,) cond = tf.where(tf.equal([ind], 1)) match_data = tf.gather_nd(data, cond) rev_match_data = tf.reverse_v2(match_data, axis=[-1]) data = tf.scatter_nd_update(data, cond, rev_match_data) with tf.Session() as sess: sess.run(tf.global_variables_initializer()) print(sess.run(data)) #[[[2 1] # [2 3] # [4 3] # [4 5] # [6 5]]]