1、当indices=[0,2],axis=0
input =[ [[[1, 1, 1], [2, 2, 2]],
[[3, 3, 3], [4, 4, 4]],
[[5, 5, 5], [6, 6, 6]]],
[[[7, 7, 7], [8, 8, 8]],
[[9, 9, 9], [10, 10, 10]],
[[11, 11, 11], [12, 12, 12]]],
[[[13, 13, 13], [14, 14, 14]],
[[15, 15, 15], [16, 16, 16]],
[[17, 17, 17], [18, 18, 18]]]
]
print(tf.shape(input))
with tf.Session() as sess:
output=tf.gather(input, [0,2],axis=0)#其实默认axis=0
print(sess.run(output))
输出结果
[[[[ 1 1 1]
[ 2 2 2]]
[[ 3 3 3]
[ 4 4 4]]
[[ 5 5 5]
[ 6 6 6]]]
[[[13 13 13]
[14 14 14]]
[[15 15 15]
[16 16 16]]
[[17 17 17]
[18 18 18]]]]
解释:
右中括号就暂时不理会他先了。
第一个[ 是列表语法需要的括号,剩下的最里面的三个[[[是axis=0需要搜寻的中括号。这里一共有3个[[[。
indices的[0,2]即取第0个[[[和第2个[[[,也就是第0个和第2个三维立体。
2、当indices=[0,2],axis=1
input =[ [[[1, 1, 1], [2, 2, 2]],
[[3, 3, 3], [4, 4, 4]],
[[5, 5, 5], [6, 6, 6]]],
[[[7, 7, 7], [8, 8, 8]],
[[9, 9, 9], [10, 10, 10]],
[[11, 11, 11], [12, 12, 12]]],
[[[13, 13, 13], [14, 14, 14]],
[[15, 15, 15], [16, 16, 16]],
[[17, 17, 17], [18, 18, 18]]]
]
print(tf.shape(input))
with tf.Session() as sess:
output=tf.gather(input, [0,2],axis=1)#默认axis=0
print(sess.run(output))
输出结果
[[[[ 1 1 1]
[ 2 2 2]]
[[ 5 5 5]
[ 6 6 6]]]
[[[ 7 7 7]
[ 8 8 8]]
[[11 11 11]
[12 12 12]]]
[[[13 13 13]
[14 14 14]]
[[17 17 17]
[18 18 18]]]]
解释:
第一个[ 是列表语法需要的括号,先把这个干扰去掉,剩下的所有内侧的 [[ 是axis=1搜索的中括号。
然后[0,2]即再取每个[[[体内的第0个[[和第2个[[,也就是去每个三维体的第0个面和第2个面。