Scan
复发(Recurrence)的一种常用形式,可以用于循环(looping)
Reduction和map是scan的特例
可以根据一些输出序列scan一个函数(function),每一步都会生成一个输出
可以查看之前k步的输出
给定一个初始状态z=0,可以通过scan函数z + x(i)计算一个列表的和sum(a_list)
通常一个for循环可以用scan()操作符进行实现
使用scan的优点:
迭代次数为符号图的一部分
最大限度地减少GPU传输(如果用到了GPU)
通过序列步长计算梯度
运行速率比python内置的for循环稍微快些
可以通过检测需要的实际内存量,来降低整体内存使用量
例子:对应元素计算tanh(x(t).dot(W) + b)
import theano
import theano.tensor as T
import numpy as np
# 定义张量变量
X = T.matrix('X')
W = T.matrix('W')
b_sym = T.vector('b_sym')
results, updates = theano.scan(lambda v: T.tanh(T.dot(v, W) + b_sym), sequences=X)
compute_elementwise = theano.function([X, W, b_sym], results)
# 测试
x = np.eye(2, dtype=theano.config.floatX)
w = np.ones((2, 2), dtype=theano.config.floatX)
b = np.ones((2), dtype=theano.config.floatX)
b[1] = 2
compute_elementwise(x, w, b)
# 和numpy相比较
np.tanh(x.dot(w) + b)
例子: 计算序列x(t) = tanh(x(t-1).dot(W) + y(t).dot(U) + p(T-t).dot(V))
import theano
import theano.tensor as T
import numpy as np
# 定义张量变量
X = T.vector('X')
W = T.matrix('W')
b_sym = T.vector('b_sym')
U, Y, V, P = T.matrices('U', 'Y', 'V', 'P')
result, update = theano.scan(lambda y, p, x_tml: T.tanh(T.dot(x_tml, W) + T.dot(y, U) + T.dot(p, V)),
sequences=[Y, P[::-1]], outputs_info=[X])
compute_seq = theano.function(inputs=[X, W, Y, U, P, V], outputs=result)
# 测试
x = np.zeros((2), dtype=theano.config.floatX)
x[1] = 1
w = np.ones((2, 2), dtype=theano.config.floatX)
y = np.ones((5, 2), dtype=theano.config.floatX)
y[0, :] = -3
u = np.ones((2, 2), dtype=theano.config.floatX)
p = np.ones((5, 2), dtype=theano.config.floatX)
p[0, :] = 3
v = np.ones((2, 2), dtype=theano.config.floatX)
print(compute_seq(x, w, y, u, p, v))
# 与Numpy对比
x_res = np.zeros((5, 2), dtype=theano.config.floatX)
x_res[0] = np.tanh(x.dot(w) + y[0].dot(u) + p[4].dot(v))
for i in range(1, 5):
x_res[i] = np.tanh(x_res[i - 1].dot(w) + y[i].dot(u) + p[4-i].dot(v))
print(x_res)
例子: 计算X的行范式
import theano
import theano.tensor as T
import numpy as np
# 定义张量变量
X = T.matrix('X')
results, updates = theano.scan(lambda x_i: T.sqrt((x_i ** 2)).sum(), sequences=[X])
compute_norm_lines = theano.function(inputs=[X], outputs=results)
# 测试
x = np.diag(np.arange(1, 6, dtype=theano.config.floatX), 1)
print(compute_norm_lines(x))
# 和Numpy对比
print(np.sqrt((x ** 2).sum(1)))
例子: 计算X的列范式
import theano
import theano.tensor as T
import numpy as np
# 定义张量变量
X = T.matrix("X")
results, updates = theano.scan(lambda x_i: T.sqrt((x_i ** 2).sum()), sequences=[X.T])
compute_norm_cols = theano.function(inputs=[X], outputs=results)
# 测试
x = np.diag(np.arange(1, 6, dtype=theano.config.floatX), 1)
print(compute_norm_cols(x))
# 和Numpy对比
print(np.sqrt((x ** 2).sum(0)))