Theano - 循环

Scan

  • 复发(Recurrence)的一种常用形式,可以用于循环(looping)

  • Reduction和map是scan的特例

  • 可以根据一些输出序列scan一个函数(function),每一步都会生成一个输出

  • 可以查看之前k步的输出

  • 给定一个初始状态z=0,可以通过scan函数z + x(i)计算一个列表的和sum(a_list)

  • 通常一个for循环可以用scan()操作符进行实现

使用scan的优点:

  • 迭代次数为符号图的一部分

  • 最大限度地减少GPU传输(如果用到了GPU)

  • 通过序列步长计算梯度

  • 运行速率比python内置的for循环稍微快些

  • 可以通过检测需要的实际内存量,来降低整体内存使用量

例子:对应元素计算tanh(x(t).dot(W) + b)

import theano
import theano.tensor as T
import numpy as np

# 定义张量变量
X = T.matrix('X')
W = T.matrix('W')
b_sym = T.vector('b_sym')

results, updates = theano.scan(lambda v: T.tanh(T.dot(v, W) + b_sym), sequences=X)
compute_elementwise = theano.function([X, W, b_sym], results)

# 测试
x = np.eye(2, dtype=theano.config.floatX)
w = np.ones((2, 2), dtype=theano.config.floatX)
b = np.ones((2), dtype=theano.config.floatX)
b[1] = 2

compute_elementwise(x, w, b)
# 和numpy相比较
np.tanh(x.dot(w) + b)

例子: 计算序列x(t) = tanh(x(t-1).dot(W) + y(t).dot(U) + p(T-t).dot(V))

import theano
import theano.tensor as T
import numpy as np
# 定义张量变量
X = T.vector('X')
W = T.matrix('W')
b_sym = T.vector('b_sym')
U, Y, V, P = T.matrices('U', 'Y', 'V', 'P')

result, update = theano.scan(lambda y, p, x_tml: T.tanh(T.dot(x_tml, W) + T.dot(y, U) + T.dot(p, V)),
                             sequences=[Y, P[::-1]], outputs_info=[X])

compute_seq = theano.function(inputs=[X, W, Y, U, P, V], outputs=result)

# 测试
x = np.zeros((2), dtype=theano.config.floatX)
x[1] = 1
w = np.ones((2, 2), dtype=theano.config.floatX)
y = np.ones((5, 2), dtype=theano.config.floatX)
y[0, :] = -3
u = np.ones((2, 2), dtype=theano.config.floatX)
p = np.ones((5, 2), dtype=theano.config.floatX)
p[0, :] = 3
v = np.ones((2, 2), dtype=theano.config.floatX)

print(compute_seq(x, w, y, u, p, v))

# 与Numpy对比
x_res = np.zeros((5, 2), dtype=theano.config.floatX)
x_res[0] = np.tanh(x.dot(w) + y[0].dot(u) + p[4].dot(v))
for i in range(1, 5):
    x_res[i] = np.tanh(x_res[i - 1].dot(w) + y[i].dot(u) + p[4-i].dot(v))
print(x_res)

例子: 计算X的行范式

import theano
import theano.tensor as T
import numpy as np

# 定义张量变量
X = T.matrix('X')
results, updates = theano.scan(lambda x_i: T.sqrt((x_i ** 2)).sum(), sequences=[X])
compute_norm_lines = theano.function(inputs=[X], outputs=results)

# 测试
x = np.diag(np.arange(1, 6, dtype=theano.config.floatX), 1)
print(compute_norm_lines(x))

# 和Numpy对比
print(np.sqrt((x ** 2).sum(1)))

例子: 计算X的列范式

import theano
import theano.tensor as T
import numpy as np

# 定义张量变量
X = T.matrix("X")
results, updates = theano.scan(lambda x_i: T.sqrt((x_i ** 2).sum()), sequences=[X.T])
compute_norm_cols = theano.function(inputs=[X], outputs=results)

# 测试
x = np.diag(np.arange(1, 6, dtype=theano.config.floatX), 1)
print(compute_norm_cols(x))

# 和Numpy对比
print(np.sqrt((x ** 2).sum(0)))
    原文作者:xiao蜗牛
    原文地址: https://segmentfault.com/a/1190000009965404
    本文转自网络文章,转载此文章仅为分享知识,如有侵权,请联系博主进行删除。
点赞