我有3个张量
X形(1,c,h,w),假设(1,20,40,50)
Fx形状(num,w,N),假设(1000,50,10)
Fy形状(num,N,h),假设(1000,10,40)
我想做的是Fy *(X * Fx)(*表示matmul)
X * Fx形状(num,c,h,N),假设(1000,20,40,10)
Fy *(X * Fx)形状(num,c,N,N),假设(1000,20,10,10)
我正在使用tf.tile和tf.expand_dims来完成它
但我认为它使用了大量内存(平铺复制数据对吗?),而且速度慢
试着找到更好的方式,更快,并使用小内存来完成
# X: (1, c, h, w)
# Fx: (num, w, N)
# Fy: (num, N, h)
X = tf.tile(X, [tf.shape(Fx)[0], 1, 1, 1]) # (num, c, h, w)
Fx_ex = tf.expand_dims(Fx, axis=1) # (num, 1, w, N)
Fx_ex = tf.tile(Fx_ex, [1, c, 1, 1]) # (num, c, w, N)
tmp = tf.matmul(X, Fxt_ex) # (num, c, h, N)
Fy_ex = tf.expand_dims(Fy, axis=1) # (num, 1, N, h)
Fy_ex = tf.tile(Fy_ex, [1, c, 1, 1]) # (num, c, N, h)
res = tf.matmul(Fy_ex, tmp) # (num, c, N, N)
最佳答案 我想是一个
mythical einsum
的案例:
>>> import numpy as np
>>> X = np.random.rand(1, 20, 40, 50)
>>> Fx = np.random.rand(100, 50, 10)
>>> Fy = np.random.rand(100, 10, 40)
>>> np.einsum('nMh,uchw,nwN->ncMN', Fy, X, Fx).shape
(100, 20, 10, 10)
它应该在tf和numpy中的工作方式几乎相同(在某些tf版本中不允许使用大写索引,我看到了).虽然如果你以前从未见过这种表示法,这肯定超过了不可读性的正则表达式.