Tensorflow新手在这里.我正在尝试建立一个RNN.我的输入数据是一组大小为instance_size的矢量实例,表示每个时间步的一组粒子的(x,y)位置. (由于实例已经具有语义内容,因此它们不需要嵌入.)目标是学习在下一步预测粒子的位置.
在RNN tutorial之后并略微调整包含的RNN代码,我创建了一个或多或少像这样的模型(省略一些细节):
inputs, self._input_data = tf.placeholder(tf.float32, [batch_size, num_steps, instance_size])
self._targets = tf.placeholder(tf.float32, [batch_size, num_steps, instance_size])
with tf.variable_scope("lstm_cell", reuse=True):
lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(hidden_size, forget_bias=0.0)
if is_training and config.keep_prob < 1:
lstm_cell = tf.nn.rnn_cell.DropoutWrapper(
lstm_cell, output_keep_prob=config.keep_prob)
cell = tf.nn.rnn_cell.MultiRNNCell([lstm_cell] * config.num_layers)
self._initial_state = cell.zero_state(batch_size, tf.float32)
from tensorflow.models.rnn import rnn
inputs = [tf.squeeze(input_, [1])
for input_ in tf.split(1, num_steps, inputs)]
outputs, state = rnn.rnn(cell, inputs, initial_state=self._initial_state)
output = tf.reshape(tf.concat(1, outputs), [-1, hidden_size])
softmax_w = tf.get_variable("softmax_w", [hidden_size, instance_size])
softmax_b = tf.get_variable("softmax_b", [instance_size])
logits = tf.matmul(output, softmax_w) + softmax_b
loss = position_squared_error_loss(
tf.reshape(logits, [-1]),
tf.reshape(self._targets, [-1]),
)
self._cost = cost = tf.reduce_sum(loss) / batch_size
self._final_state = state
然后我创建一个saver = tf.train.Saver(),迭代数据以使用给定的run_epoch()方法训练它,并用saver.save()写出参数.到现在为止还挺好.
但是,我如何实际使用经过训练的模型?教程此时停止.从the docs on tf.train.Saver.restore()
开始,为了回读变量,我需要设置与我保存变量时运行的完全相同的图,或选择性地恢复特定变量.无论哪种方式,这意味着我的新模型将需要输入大小batch_size x num_steps x instance_size.但是,我现在想要的是在大小为num_steps x instance_size的输入上对模型进行单个正向传递,并读出单个instance_size大小的结果(下一个时间步的预测);换句话说,我想创建一个接受不同尺寸张量的模型,而不是我训练过的张量.我可以通过将现有模型传递给我的预期数据batch_size次来对其进行处理,但这似乎不是最佳做法.最好的方法是什么?
最佳答案 您必须创建一个具有相同结构但使用batch_size = 1的新图形,并使用tf.train.Saver.restore()导入已保存的变量.您可以在ptb_word_lm.py:
https://tensorflow.googlesource.com/tensorflow/+/master/tensorflow/models/rnn/ptb/ptb_word_lm.py中查看它们如何定义具有可变批量大小的多个模型
因此,您可以使用单独的文件,在其中使用所需的batch_size实例化图形,然后还原已保存的变量.然后你可以执行你的图表.