import tensorflow as tf
import numpy as np
# Random float64 sample in [0, 1) with shape (1, 3, 3, 4); the last axis is
# the one that gets split and re-joined below.
X = np.random.random((1, 3, 3, 4))
X
array([[[[ 0.82959287, 0.97123702, 0.28140139, 0.27116128],
[ 0.17657325, 0.95732474, 0.69869441, 0.68558369],
[ 0.27456733, 0.75242884, 0.00578983, 0.36427501]],
[[ 0.55055599, 0.27293508, 0.58177528, 0.60010759],
[ 0.49096017, 0.03448037, 0.77094952, 0.72902519],
[ 0.72496438, 0.57176329, 0.9313365 , 0.81825572]],
[[ 0.35645042, 0.79323193, 0.08155452, 0.75811829],
[ 0.24662546, 0.20411053, 0.19005582, 0.72657277],
[ 0.84135906, 0.77598372, 0.26645642, 0.69704092]]]])
# Cut X into two equally sized pieces along axis 3 (the last axis), so each
# piece has shape (1, 3, 3, 2).
splits = tf.split(X, num_or_size_splits=2, axis=3)
splits
[<tf.Tensor 'split_8:0' shape=(1, 3, 3, 2) dtype=float64>,
<tf.Tensor 'split_8:1' shape=(1, 3, 3, 2) dtype=float64>]
# Evaluate the symbolic split tensors; `sess` is kept open because it is
# reused further down to evaluate the concatenation.
sess = tf.Session()
splits_res = sess.run(splits)

# Show each evaluated piece, followed by two blank lines as a visual separator.
for piece in splits_res:
    print(piece)
    print()
    print()
[[[[ 0.82959287 0.97123702]
[ 0.17657325 0.95732474]
[ 0.27456733 0.75242884]]
[[ 0.55055599 0.27293508]
[ 0.49096017 0.03448037]
[ 0.72496438 0.57176329]]
[[ 0.35645042 0.79323193]
[ 0.24662546 0.20411053]
[ 0.84135906 0.77598372]]]]
[[[[ 0.28140139 0.27116128]
[ 0.69869441 0.68558369]
[ 0.00578983 0.36427501]]
[[ 0.58177528 0.60010759]
[ 0.77094952 0.72902519]
[ 0.9313365 0.81825572]]
[[ 0.08155452 0.75811829]
[ 0.19005582 0.72657277]
[ 0.26645642 0.69704092]]]]
# Join the evaluated pieces back together along axis 3. Note that this uses
# `splits_res` (the arrays returned by sess.run), not the symbolic `splits`.
# The result should reproduce the original X.
splits_concat = tf.concat(splits_res, axis=3)
splits_concat_res = sess.run(splits_concat)
splits_concat_res
array([[[[ 0.82959287, 0.97123702, 0.28140139, 0.27116128],
[ 0.17657325, 0.95732474, 0.69869441, 0.68558369],
[ 0.27456733, 0.75242884, 0.00578983, 0.36427501]],
[[ 0.55055599, 0.27293508, 0.58177528, 0.60010759],
[ 0.49096017, 0.03448037, 0.77094952, 0.72902519],
[ 0.72496438, 0.57176329, 0.9313365 , 0.81825572]],
[[ 0.35645042, 0.79323193, 0.08155452, 0.75811829],
[ 0.24662546, 0.20411053, 0.19005582, 0.72657277],
[ 0.84135906, 0.77598372, 0.26645642, 0.69704092]]]])
# Shape check: the re-joined array should be back to X's shape, (1, 3, 3, 4).
np.shape(splits_concat_res)
(1, 3, 3, 4)