一尘不染

比较包含NaN的numpy数组

python

对于我的单元测试,我想检查两个数组是否相同。简化示例:

a = np.array([1, 2, np.NaN])
b = np.array([1, 2, np.NaN])
if np.all(a==b):
    print 'arrays are equal'

这是行不通的,因为nan != nan。最好的进行方法是什么?


阅读 146

收藏
2020-12-20

共1个答案

一尘不染

或者您可以使用numpy.testing.assert_equalnumpy.testing.assert_array_equaltry/except

In : import numpy as np

In : def nan_equal(a,b):
...:     try:
...:         np.testing.assert_equal(a,b)
...:     except AssertionError:
...:         return False
...:     return True

In : a=np.array([1, 2, np.NaN])

In : b=np.array([1, 2, np.NaN])

In : nan_equal(a,b)
Out: True

In : a=np.array([1, 2, np.NaN])

In : b=np.array([3, 2, np.NaN])

In : nan_equal(a,b)
Out: False

编辑

由于您正在使用它进行单元测试,因此裸露assert(而不是将其包装成get True/False)可能更自然。

2020-12-20