训练网络时,经常要对矩阵进行拼接
、拆分
、减少纬度
、扩充纬度
、改变shape
、转置
、乱序等操作
,这里把常用到的方法总结归纳出来。
tf.concat(values, axis, name="concat")
tf.stack(values, axis=0, name="stack”)
tf.unstack(value, num=None, axis=0, name="unstack”)
tf.tile(input, multiples, name=None)
tf.split(value, num_or_size_splits, axis=0, num=None, name="split”)
tf.slice(input_, begin, size, name=None)
tf.expand_dims(input, axis=None, name=None, dim=None)
tf.squeeze(input, axis=None, name=None, squeeze_dims=None)
tf.reshape(tensor, shape, name=None)
tf.transpose(a, perm=None, name="transpose", conjugate=False)
tf.random_shuffle(value, seed=None, name=None)
- tf.concat(values, axis, name=”concat”)
把values
中的矩阵沿着纬度axis
拼接起来,纬度不变
a = tf.constant([[4, 2, 1], [1, 2, 3]], dtype=tf.float32)
b = tf.constant([[3, 2, 1], [0, 2, 1]], dtype=tf.float32)
c = tf.concat([a, b], axis=0)
c =
[[4. 2. 1.]
[1. 2. 3.]
[3. 2. 1.]
[0. 2. 1.]]
- tf.stack(values, axis=0, name=”stack”)
把values
中的矩阵沿着纬度axis
拼接起来,纬度+1
,相比而言,tf.concat
在网络中用的更多
a = tf.constant([[4, 2, 1], [1, 2, 3]], dtype=tf.float32)
b = tf.constant([[3, 2, 1], [0, 2, 1]], dtype=tf.float32)
c = tf.stack([a, b], axis=0)
c =
[[[4. 2. 1.]
[1. 2. 3.]]
[[3. 2. 1.]
[0. 2. 1.]]]
- tf.unstack(value, num=None, axis=0, name=”unstack”)
把value
中的矩阵沿着纬度axis
拆分出来,纬度-1
a = tf.constant([[4, 2, 1], [1, 2, 3]], dtype=tf.float32)
b = tf.constant([[3, 2, 1], [0, 2, 1]], dtype=tf.float32)
c = tf.stack([a, b], axis=0)
d1, d2 = tf.unstack(c, axis=0)
d1 =
[[4. 2. 1.]
[1. 2. 3.]]
d2 =
[[3. 2. 1.]
[0. 2. 1.]]
- tf.tile(input, multiples, name=None)
复制矩阵input
,multiples中指定每个纬度上复制次数
a = tf.constant([[1, 2], [3, 4]], dtype=tf.float32)
b = tf.tile(a, [2, 2])
b =
[[1. 2. 1. 2.]
[3. 4. 3. 4.]
[1. 2. 1. 2.]
[3. 4. 3. 4.]]
- tf.split(value, num_or_size_splits, axis=0, num=None, name=”split”)
把矩阵value
,沿着axis
纬度拆分成num_or_size_splits
个小矩阵
a = tf.constant([[1, 2], [3, 4], [3, 5]], dtype=tf.float32)
b1, b2, b3 = tf.split(a, num_or_size_splits=3, axis=0)
b1 =
[[1. 2.]]
b2 =
[[3. 4.]]
b3 =
[[3. 5.]]
- tf.slice(input_, begin, size, name=None)
把矩阵input_
,沿着每个纬度指定开始位置begin
截取size
大小的内容,这里-1
代表到当前纬度结尾处,和python中list分块[begin:end]
是类似的
a = tf.constant([[1, 2], [3, 4], [3, 5]], dtype=tf.float32)
b = tf.slice(a, [1, 0], [2, -1])
b =
[[3. 4.]
[3. 5.]]
- tf.expand_dims(input, axis=None, name=None, dim=None)
把矩阵input
的纬度增加1,axis
为曾加的纬度位置,例如当前a的shape为[2]
,最后一个位置增加一个纬度shape变成[2,1]
a = tf.constant([1, 2], dtype=tf.float32)
b = tf.expand_dims(a, -1)
b =
[[1.]
[2.]]
- squeeze(input, axis=None, name=None, squeeze_dims=None)
把矩阵input
中size大小为1
的纬度减去,axis
为减少的纬度位置,例如当前a的shape为[2,1]
,最后一个位置减去一个纬度shape变成[2]
a = tf.constant([[1], [2]], dtype=tf.float32)
b = tf.squeeze(a, 1)
b =
[1. 2.]
- tf.reshape(tensor, shape, name=None)
把矩阵tensor
改变成指定shape
的样子,需要注意的是,改变shape之前和之后的元素个数总和应该一样
a = tf.constant([[1, 2, 3, 4]], dtype=tf.float32)
b = tf.reshape(a, [2, 2])
b =
[[1. 2.]
[3. 4.]]
- tf.transpose(a, perm=None, name=”transpose”, conjugate=False)
调整矩阵a
纬度的顺序,按照perm
中指定顺序排列,列如把a的行列交换顺序,也就是经常所见的二维矩阵转置过程
a = tf.constant([[1, 2], [3, 4], [3, 5]], dtype=tf.float32)
b = tf.transpose(a, perm=[1, 0])
b =
[[1. 3. 3.]
[2. 4. 5.]]
- tf.random_shuffle(value, seed=None, name=None)
把矩阵value
中的元素顺序打乱,需要注意的是被打乱顺序的纬度索引为0
a = tf.constant([[1, 2], [3, 4], [3, 5]], dtype=tf.float32)
b = tf.random_shuffle(a)
b =
[[3. 4.]
[1. 2.]
[3. 5.]]