求这组数据的最大值?
[[0.06251886 0.2645436 0.04882399 0.09480914 0.04890436 0.15327263
0.0369646 0.22686356 0.0089916 0.05430767]]
这时候就是用tf.argmax的最好时候,测试代码
from __future__ import print_function
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
sess = tf.Session()
a = tf.constant([1.,2.,3.,0.,9.,])
b = tf.constant([[1,2,3],
[3,2,1],
[4,5,6],
[6,5,4]])
col_max0 = sess.run(tf.argmax(a, 0))
print (col_max0)
# 4
col_max = sess.run(tf.argmax(b, 0) ) #当axis=0时返回每一列的最大值的位置索引
print (col_max)
# [3 2 2]
row_max = sess.run(tf.argmax(b, 1) ) #当axis=1时返回每一行中的最大值的位置索引
print (row_max)
# [2 0 2 0]