在sklearn,我习惯拥有一个可以运行的模型然后预测.但是,使用TensorFlow时,我在调用预测时无法从拟合中加载学习参数.归结为我不知道如何在会话之间重用变量的值.例如,
import tensorflow as tf
x = tf.Variable(0.0)
# fit code
with tf.Session() as sess1:
sess1.run(tf.global_variables_initializer())
sess1.run(tf.assign(x, 1.0)) # at end of training, x = 1.0
# predict code
with tf.Session() as sess2:
sess2.run(tf.global_variables_initializer())
print(sess2.run(x)) # want this to be 1.0, but is 0.0
我可以想到一个解决方法,但它似乎真的很hacky,如果我想重用几个变量会很烦人:
import tensorflow as tf
x = tf.Variable(0.0)
# fit code
with tf.Session() as sess1:
sess1.run(tf.global_variables_initializer())
sess1.run(tf.assign(x, 1.0)) # at end of training, x = 1.0
learned_x = sess1.run(x) # remember value of learned x at end of session
# predict code
with tf.Session() as sess2:
sess2.run(tf.global_variables_initializer())
sess2.run(tf.assign(x, learned_x))
print(sess2.run(x)) # prints 1.0
如何在不写入磁盘的情况下在会话之间重用变量(即使用tf.train.Saver)?我上面写的解决方法是正确的方法吗?
最佳答案 要模仿sklearn的模型,只需将会话包装到一个类中,以便您可以在方法之间共享它.
class Model:
def __init__(self):
self.graph = self.build_graph()
self.session = tf.Session()
self.session.run(tf.global_variables_initializer())
def build_graph(self):
return {'x': tf.Variable(0.0)}
def fit(self):
self.session.run(tf.assign(self.graph['x'], 1.0))
def predict(self):
print(self.session.run(self.graph['x']))
def close(self):
tf.reset_default_graph()
self.session.close()
m = Model()
m.fit()
m.predict()
m.close()
确保手动关闭会话并相应地处理异常.