python – 查找列表是否包含特定的numpy数组

import numpy as np

a = np.eye(2)
b = np.array([1,1],[0,1])

my_list = [a, b]

my_list中的a返回true,但是my_list中的b返回“ValueError:具有多个元素的数组的真值是不明确的.使用a.any()或a.all()”.我可以通过首先将数组转换为字符串或列表来解决这个问题,但是有更好的(更多Pythonic)方法吗?

最佳答案 问题是在numpy中,==运算符返回一个数组:

>>> a == b
array([[ True, False],
       [ True,  True]], dtype=bool)

您使用 .array_equal() 将数组与纯布尔值进行比较.

>>> any(np.array_equal(a, x) for x in my_list)
True
>>> any(np.array_equal(b, x) for x in my_list)
True
>>> any(np.array_equal(np.array([a, a]), x) for x in my_list)
False
>>> any(np.array_equal(np.array([[0,0],[0,0]]), x) for x in my_list)
False
点赞