参见英文答案 >
Tensorflow.strided_slice missing argument ‘strides’? 2个
我已经从教程
here中的链接下载了CIFAR10代码,并且我正在尝试运行该教程.我用命令运行它
python cifar10_train.py
它启动正常并按预期下载数据文件.当它尝试打开输入文件时,它失败并带有以下跟踪:
Traceback (most recent call last):
File "cifar10_train.py", line 120, in <module>
tf.app.run()
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/platform/app.py", line 43, in run
sys.exit(main(sys.argv[:1] + flags_passthrough))
File "cifar10_train.py", line 116, in main
train()
File "cifar10_train.py", line 63, in train
images, labels = cifar10.distorted_inputs()
File "/notebooks/Python Scripts/tensorflowModels/tutorials/image/cifar10/cifar10.py", line 157, in distorted_inputs
batch_size=FLAGS.batch_size)
File "/notebooks/Python Scripts/tensorflowModels/tutorials/image/cifar10/cifar10_input.py", line 161, in distorted_inputs
read_input = read_cifar10(filename_queue)
File "/notebooks/Python Scripts/tensorflowModels/tutorials/image/cifar10/cifar10_input.py", line 87, in read_cifar10
tf.strided_slice(record_bytes, [0], [label_bytes]), tf.int32)
TypeError: strided_slice() takes at least 4 arguments (3 given)
果然,当我调查代码时,在cifar10_input.py中调用strided_slice()只有3个参数:
tf.strided_slice(record_bytes, [0], [label_bytes])
而tensorflow文档确实表明必须至少有4个参数.
出了什么问题?我已经下载了最新的张量流(0.12),我正在运行cifar代码的主分支.
最佳答案 在
github的一些讨论后,我已经进行了以下更改,似乎使它工作:
在cifar10_input.py中
- result.label = tf.cast(tf.strided_slice(record_bytes, [0], [label_bytes]), tf.int32)
+ result.label = tf.cast(tf.slice(record_bytes, [0], [label_bytes]), tf.int32)
- depth_major = tf.reshape( tf.strided_slice(record_bytes, [label_bytes], [label_bytes + image_bytes]), [result.depth, result.height, result.width])
+ depth_major = tf.reshape(tf.slice(record_bytes, [label_bytes], [image_bytes]), [result.depth, result.height, result.width])
然后在cifar10_input.py和cifar10.py中我不得不搜索“不赞成”,无论我在哪里找到它,都要根据我在api指南中读到的内容替换它(希望正确).这方面的例子:
- tf.contrib.deprecated.image_summary('images', images)
+ tf.summary.image('images', images)
和
- tf.contrib.deprecated.histogram_summary(tensor_name + '/activations', x)
- tf.contrib.deprecated.scalar_summary(tensor_name + '/sparsity',
+ tf.summary.histogram(tensor_name + '/activations', x)
+ tf.summary.scalar(tensor_name + '/sparsity',
现在好像很开心.我会看看它是否完成正常,如果我在上面输入的更改给出了所需的诊断输出.
我仍然希望听到更接近代码的人的确切答案.