我希望能够获得提交特定类型错误的实例的索引(或数组掩码),而不仅仅是获取混淆矩阵.因此,例如,我想看到当它属于0级等时预测为2级的实例.
我可以使用列表推导轻松获取数组掩码:
import numpy as np
y_true, y_pred = np.array([0, 1, 0, 2, 1, 1]), np.array([0, 0, 0, 2, 1, 2])
np.array([[np.logical_and(y_true==r, y_pred==c) for c in xrange(3)] for r in xrange(3)])
这会产生:
[[[ True False True False False False]
[False False False False False False]
[False False False False False False]]
[[False True False False False False]
[False False False False True False]
[False False False False False True]]
[[False False False False False False]
[False False False False False False]
[False False False True False False]]]
(为了得到索引,我可以使用np.where()).以上对应于混淆矩阵:
[[2 0 0]
[1 1 1]
[0 0 1]]
但是,我想问一下是否有一个numpy-thonic单行程来帮助我取消嵌套列表理解?
最佳答案 要添加其中一个令人困惑的花哨索引解决方案,您还可以:
>>> y_true = np.array([0, 1, 0, 2, 1, 1])
>>> y_pred = np.array([0, 0, 0, 2, 1, 2])
>>> out = np.zeros((3, 3, len(y_true)), dtype=np.bool)
>>> out[y_true, y_pred, np.arange(len(y_true))] = True
>>> out
array([[[ True, False, True, False, False, False],
[False, False, False, False, False, False],
[False, False, False, False, False, False]],
[[False, True, False, False, False, False],
[False, False, False, False, True, False],
[False, False, False, False, False, True]],
[[False, False, False, False, False, False],
[False, False, False, False, False, False],
[False, False, False, True, False, False]]], dtype=bool)
你可以得到混淆矩阵在最后一个轴上总结上面的矩阵,但是如果这就是你所追求的,那么用np.bincount直接构建它可能更好:
>>> np.bincount(y_pred + 3*y_true, minlength=9).reshape(3,3)
array([[2, 0, 0],
[1, 1, 1],
[0, 0, 1]], dtype=int64)
SciPy的sparse_coo矩阵将重复索引加在一起,因此以下内容也适用:
>>> sps.coo_matrix((np.ones_like(y_true, dtype=np.intp),
--- (y_true, y_pred)), shape=(3, 3)).A
array([[2, 0, 0],
[1, 1, 1],
[0, 0, 1]], dtype=int64)