我有一个嵌入矩阵e定义如下
e = tf.get_variable(name="embedding", shape=[n_e, d],
initializer=tf.contrib.layers.xavier_initializer(uniform=False))
其中n_e表示实体数,d表示潜在维数.对于这个例子,假设d = 10.
训练:
optimizer = tf.train.GradientDescentOptimizer(0.01)
grads_and_vars = optimizer.compute_gradients(loss)
train_op = optimizer.apply_gradients(grads_and_vars, global_step=global_step)
训练后保存模型.
稍后,添加新实体(例如,2),产生n_e_new.现在我想重新训练模型,但是保留已经训练的实体的嵌入,即仅重新训练delta(2个新实体).
我加载了保存的e和
init_e = np.zeros((n_e_new, d), dtype=np.float32)
r = list(range(n_e_new - 2))
init_e[r, :] = # load e from saved model
e = tf.get_variable(name="embedding", initializer=init_e)
gather_e = tf.nn.embedding_lookup(e, [n_e, n_e+1])
训练:
optimizer = tf.train.GradientDescentOptimizer(0.01)
grads_and_vars = optimizer.compute_gradients(loss, gather_e)
train_op = optimizer.apply_gradients(grads_and_vars, global_step=global_step)
我在compute_gradients上遇到错误:
NotImplementedError :(‘试图优化不支持的类型’,)
我知道第二个参数gather_e到compute_gradients不是变量,但无法弄清楚如何实现这种部分训练/更新.
P.S – 我也看过this post,但似乎也找不到解决方案.
编辑:
代码示例(根据@meruf建议的方法):
if new_data_available:
e = tf.get_variable(name="embedding", shape=[n_e_new, 1, d],
initializer=tf.contrib.layers.xavier_initializer(uniform=False))
e_old = tf.get_variable(name="embedding_old", initializer=<load e from saved model>, trainable=False)
e_new = tf.concat([e_old, e], 0)
else:
e = tf.get_variable(name="embedding", shape=[n_e, d],
initializer=tf.contrib.layers.xavier_initializer(uniform=False))
查询如下:
if new_data_available:
var_p = tf.nn.embedding_lookup(e_new, indices)
else:
var_p = tf.nn.embedding_lookup(e, indices)
loss = #some operations on var_p and other variabes that are a result of the lookup above
问题是当new_data_available为true时,e和e_new都不会在每个纪元期间发生变化.他们保持不变.
最佳答案 您不应该在优化程序级别更改代码.你可以很容易地告诉tensorflow哪个变量是可训练的.
我们来看看tf.getVariable()defination,
tf.get_variable(
name,
shape=None,
dtype=None,
initializer=None,
regularizer=None,
trainable=True,
collections=None,
caching_device=None,
partitioner=None,
validate_shape=True,
use_resource=None,
custom_getter=None,
constraint=None
)
这里可训练的参数表示如果参数是可训练的.如果您不想训练参数,则将其设为false.
为你的情况下做2组变量.一个是trainable = True而另一个是trainable = false.
假设您有100个预训练变量和10个新变量来训练.现在将预先训练的变量加载到A,将新变量加载到B.
注意:
有关实现的详细信息,您应该查看tf.cond
函数以了解运行时决策.主要用于查找.因为现在你的新B嵌入的索引从0开始.但是你可能已经从数据集或程序中预先训练的嵌入1的#中分配了它们.因此,在tensorflow中,您可以采取运行时决策
伪代码
if index_number is >= number of pretrained embedding
index_number = index_number - number of pretrained embedding
look_up on B matrix
else
look_up on A matrix
An Ipython Notebook of the example. (slightly different than the example given here.)
更新:
我们来看看我的意思,
首先加载库
import tensorflow as tf
声明占位符
y_ = tf.placeholder(tf.float32, [None, 2])
x = tf.placeholder(tf.int32, [None])
z = tf.placeholder(tf.bool, []) # is the example in the x contains new data or not
创建网络
e = tf.get_variable(name="embedding", shape=[5,10],initializer=tf.contrib.layers.xavier_initializer(uniform=False))
e_old = tf.get_variable(name="embedding1", shape=[5,10],initializer=tf.contrib.layers.xavier_initializer(uniform=False),trainable=False)
out = tf.cond(z,lambda : e, lambda : e_old)
lookup = tf.nn.embedding_lookup(out,x)
W = tf.get_variable(name="weight", shape=[10,2],initializer=tf.contrib.layers.xavier_initializer(uniform=False))
l = tf.nn.relu(tf.matmul(lookup,W))
y = tf.nn.softmax(l)
计算损失
cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y), reduction_indices=[1]))
优化损失
train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)
加载并运行图表
sess = tf.InteractiveSession()
tf.global_variables_initializer().run()
打印初始值
我们正在打印这些值,以便我们稍后可以检查我们的值是否发生变化.
e_out_tf,e_out_old_tf = sess.run([e,e_old])
print("New Data ", e_out_tf)
print("Old Data", e_out_old_tf)
('New Data ', array([[-0.38952214, -0.37217963, 0.11370762, -0.13024905, 0.11420489,
-0.09138191, 0.13781562, -0.1624797 , -0.27410012, -0.5404499 ],
[-0.0065698 , 0.04728106, 0.53637034, -0.13864517, -0.36171854,
0.40325132, 0.7172644 , -0.28067762, -0.0258827 , -0.5615116 ],
[-0.17240004, 0.3765518 , 0.4658525 , 0.16545495, -0.37515178,
-0.39557686, -0.50662124, -0.06570222, -0.3605038 , 0.13746035],
[ 0.19647208, -0.16588202, 0.5739292 , 0.43803877, -0.05350745,
0.71350956, 0.39937392, -0.45939735, 0.09050641, -0.18077391],
[-0.05588558, 0.7295865 , 0.42288807, 0.57227516, 0.7268311 ,
-0.1194113 , 0.28589466, 0.09422033, -0.10094754, 0.3942643 ]],
dtype=float32))
('Old Data', array([[ 0.5308224 , -0.14003026, -0.7685277 , 0.06644323, -0.02585996,
-0.1713268 , 0.04987739, 0.01220775, 0.33571896, 0.19891626],
[ 0.3288728 , -0.09298109, 0.14795913, 0.21343362, 0.14123142,
-0.19770677, 0.7366793 , 0.38711038, 0.37526497, 0.440099 ],
[-0.29200613, 0.4852043 , 0.55407804, -0.13675605, -0.2815263 ,
-0.00703347, 0.31396288, -0.7152872 , 0.0844975 , 0.4210107 ],
[ 0.5046112 , 0.3085646 , 0.19497707, -0.5193338 , -0.0429871 ,
-0.5231836 , -0.38976955, -0.2300536 , -0.00906788, -0.1689194 ],
[-0.1231837 , 0.54029703, 0.45702592, -0.07886257, -0.6420077 ,
-0.24090563, -0.02165782, -0.44103763, -0.20914222, 0.40911582]],
dtype=float32))
测试用例
现在我们将测试我们的理论是否
1.不可训练的变量变化与否
2.可训练变量与否.
我们声明了一个额外的占位符z来指示我们的输入是否包含新数据或旧数据.
这里,索引0包含如果z为True则可训练的新数据.
feed_dict={x: [0],z:True}
lookup_tf = sess.run([lookup], feed_dict=feed_dict)
检查该值是否与上述值匹配.
print(lookup_tf)
[array([[-0.38952214, -0.37217963, 0.11370762, -0.13024905, 0.11420489,
-0.09138191, 0.13781562, -0.1624797 , -0.27410012, -0.5404499 ]],
dtype=float32)]
我们将发送z = True以指示您要查找的嵌入.
因此,在发送批处理时,请确保批处理仅包含旧数据或新数据.
feed_dict={x: [0], y_: [[0,1]], z:True}
_, = sess.run([train_step], feed_dict=feed_dict)
lookup_tf = sess.run([lookup], feed_dict=feed_dict)
训练结束后,让我们检查它是否正常.
print(lookup_tf)
[array([[-0.559212 , -0.362611 , 0.06011545, -0.02056453, 0.26133284,
-0.24933788, 0.18598196, -0.00602196, -0.12775017, -0.6666256 ]],
dtype=float32)]
请参阅索引0包含可训练的新数据,并且由于SGD更新而从之前的值更改.
让我们尝试相反的事情
feed_dict={x: [0], y_: [[0,1]], z:False}
lookup_tf = sess.run([lookup], feed_dict=feed_dict)
print(lookup_tf)
_, = sess.run([train_step], feed_dict=feed_dict)
lookup_tf = sess.run([lookup], feed_dict=feed_dict)
print(lookup_tf)
[array([[ 0.5308224 , -0.14003026, -0.7685277 , 0.06644323, -0.02585996,
-0.1713268 , 0.04987739, 0.01220775, 0.33571896, 0.19891626]],
dtype=float32)]
[array([[ 0.5308224 , -0.14003026, -0.7685277 , 0.06644323, -0.02585996,
-0.1713268 , 0.04987739, 0.01220775, 0.33571896, 0.19891626]],
dtype=float32)]