tf.argmax(data, axis=None)
用tensorflow 做 mnist分类时,用到这个接口,于是就研究了下这个接口的用法:
如果是一维数组呢?
data = tf.constant([1,2,3])
with tf.Session() as sess:
print(sess.run(tf.argmax(data, 0)))
print(sess.run(tf.argmax(data)))
>>> 2
>>> 2
这个很好理解,因为data是一维数组,axis只能为0(如果是1就会报错),结果返回数组中最大值的下标,所以是2
如果是二维数组呢?
data = tf.constant([[1,2,3]])
with tf.Session() as sess:
print(sess.run(tf.argmax(data, 0)))
print(sess.run(tf.argmax(data, 1)))
>>> [0,0,0]
>>> [2]
是不是有点晕了?
我是这么理解的:
Axis = 0时:
只有data[0] = [1,2,3], 按照对应位置比较,因为只有data[0],1的对应位置为空,所以1是最大值,2的对应位置为空所以2是最大值,3的对应位置为空所以3是最大值。
而argmax函数返回的是最大值的索引,因为1, 2,3 都属于data[0],所以返回值是 [0, 0, 0].
Axis = 1时:
Data[0][0] = 1
Data[0][1] = 2
Data[0][2] = 3
1和2 和 3比较显然是 3最大,3的索引为2,所以返回[2]
再看一个二维数组,可能就明白了:
data = tf.constant([[1,2,3], [4,5,6]])
with tf.Session() as sess:
print(sess.run(tf.argmax(data, 0)))
print(sess.run(tf.argmax(data, 1)))
>>> [1, 1, 1]
>>> [2, 2]
Axis = 0时:
Data[0] = [1,2,3]
Data[1] = [4,5,6]
对应位置比较:4 > 1, 5>2, 6>3, 所以返回 4,5,6所在的索引位置[1,1,1]
Axis = 1时:
Data[0][0] = 1
Data[0][1] = 2
Data[0][2] = 3
对应位置比较 3最大,3的索引为2
Data[1][0] = 4
Data[1][1] = 5
Data[1][2] = 6
对应位置比较6最大,6的索引为2
所以最后返回[2,2].
同样如果是三维数组:
data = tf.constant([[[1,2,3]],
[[7, 1,9]]])
with tf.Session() as sess:
print(sess.run(tf.argmax(data, 0)))
print(sess.run(tf.argmax(data, 1)))
print(sess.run(tf.argmax(data, 2)))
同样步骤分析:
Axis = 0时:
Data[0] = [[1, 2, 3]]
Data[1] = [[7, 1, 9]]
对应位置比较 7>1, 2 >1, 9> 3, 7属于索引1,2属于索引0,9属于索引1,所以返回[[1, 0,1]].
Axis = 1时:
Data[0][0] = [1,2,3]
1 2 3,对应位置分别为空,所以1,2,3在对应位置都是最大,1,2,3,都属于索引为0,返回[0,0,0]。
Data[1][0] = [7, 1,9]
7 1 9,对应位置分别为空,所以7,1,9在对应位置都是最大,7,1,9,都属于索引为0,返回[0,0,0]
所以最后返回[[0,0,0],[0,0,0]]。
Axis = 2时:
Data[0][0][0] = 1
Data[0][0][1] = 2
Data[0][0][2] = 3
3比较最大,3所在的索引为2,返回 2,
Data[1][0][0] = 7
Data[1][0][1] = 1
Data[1][0][2] = 9
9 最大,9所在的索引为2,返回2
所以最后返回[[2],[2]].
如果是四维或者更高维度,都是按照同样的方法。
Original: https://blog.csdn.net/xufeng930325/article/details/122895016
Author: xufengzxcvbnm
Title: tf.argmax()的详细用法

windows下完全离线安装Anaconda+Tensorflow

ADAS相关名词解释

《三英战吕布》 – 图像模板匹配 【Python-Open_CV系列(八)】

python3 语音合成 pyttsx3 介绍 windows, 树莓派

浅谈 USB Audio(1)—— Feedback端点作用

python离线录音转文字软件_语音转文字工具(音频转文字助手)V2.1 最新版

kaldi中文语音识别(一):multi_cn

机器学习算法一之基于K均值聚类算法实现数据聚类及二维图像像素分割

还没有女朋友的朋友们,你们有福了,学会CycleGAN把男朋友变成女朋友

windows使用conda命令安装tensorflow-gpu,并查看程序是否调用GPU

MobileNet系列(4):MobileNetv3网络详解

Python OpenCv 实现实时人脸识别及面部距离测量

深度学习YOLOv4模型简单部署,flask框架搭建以及web开发

ROS图像的Deeplab v3+实时语义分割(ROS+Pytorch)
