NumPy(Numerical Python的缩写)是一个开源的Python科学计算库。使用NumPy,就可以很自然地使用数组和矩阵。NumPy包含很多实用的数学函数,涵盖线性代数运算、傅里叶变换和随机数生成等功能。本文主要介绍一下NumPy中testing.assert_array_equal方法的使用。

numpy.testing.assert_array_equal

numpy.testing.assert_array_equal(x, y, err_msg='', verbose=True)    [source]

如果两个array_like对象不相等,则引发AssertionError。

给定两个array_like对象,检查形状是否相等以及这些对象的所有元素是否相等(但是请参阅注释,了解标量的特殊处理)。当形状不匹配或值冲突时引发异常。与numpy中的标准用法相反,将nan作为数字进行比较,如果两个对象的nan位于相同的位置,则不会产生断言。

建议使用浮点数验证是否相等时要格外小心。

参数 :

x :array_like

实际检查对象。

y :array_like

期望的预期对象。

err_msgstr, 可选

发生故障时要打印的错误消息。

verbosebool, 可选

如果为True,则冲突值将附加到错误消息中。

Raises :

AssertionError

如果实际对象与期望对象不相等。

Notes

当x和y中的一个是标量,另一个是array_like时,该函数将检查array_like对象的每个元素是否等于标量。

例子

第一个断言不会引发异常:

>>> np.testing.assert_array_equal([1.0,2.33333,np.nan],
...                               [np.exp(0),2.33333, np.nan])

声明由于浮点数值不精确而失败:

>>> np.testing.assert_array_equal([1.0,np.pi,np.nan],
...                               [1, np.sqrt(np.pi)**2, np.nan])
Traceback (most recent call last):
    ...
AssertionError:\nArrays are not equal
\nMismatched elements: 1 / 3 (33.3%)
Max absolute difference: 4.4408921e-16
Max relative difference: 1.41357986e-16
 x: array([1.      , 3.141593,      nan])
 y: array([1.      , 3.141593,      nan])

在这些情况下,使用assert_allclose或nulp(浮点值的数量)函数之一:

>>> np.testing.assert_allclose([1.0,np.pi,np.nan],
...                            [1, np.sqrt(np.pi)**2, np.nan],
...                            rtol=1e-10, atol=0)

如注释部分所述,assert_array_equal具有标量的特殊处理。在这里,测试检查x中的每个值是否为3:

>>> x = np.full((2, 5), fill_value=3)
>>> np.testing.assert_array_equal(x, 3)

推荐文档

相关文档

大家感兴趣的内容

随机列表