numpy.expand_dims 是一个用于在指定位置插入新的轴(即维度)的函数。这个函数的主要用途是将一个数组的形状扩展,以便与其他数组在特定维度上进行操作,尤其是在广播和矩阵操作时使用。本文主要介绍一下NumPy中expand_dims方法的使用。

numpy.expand_dims

numpy.expand_dims(a, axis)     [source]

Expand数组的shape。

插入一个新轴,该轴将出现在expand数组shape的轴位置上。

参数 :

a :array_like

输入数组。

axis :int 或  int类型的tuple

在扩展轴上放置新轴的位置。 

从1.13.0版开始不推荐使用:

传递一个将axis>a.ndim视为axis == a.ndim的轴,

并传递axis<-a .ndim-1将被视为axis == 0。 

不建议使用此行为。 在版本1.18.0中更改:

现在支持轴元组。

 如上所述,超出范围的轴现在被禁止,

并引发AxisError

返回值 :

result :ndarray

视图a随维数增加。

例子

1)在第 0 维插入新的轴

import numpy as np

arr = np.array([1, 2, 3])
print("原数组形状:", arr.shape)

# 在第 0 维插入新轴
expanded_arr = np.expand_dims(arr, axis=0)
print("新数组形状:", expanded_arr.shape)
print(expanded_arr)

2)在最后一维插入新轴

import numpy as np

arr = np.array([1, 2, 3])
print("原数组形状:", arr.shape)

# 在最后一维插入新轴
expanded_arr = np.expand_dims(arr, axis=-1)
print("新数组形状:", expanded_arr.shape)
print(expanded_arr)

3)对多维数组进行扩展

import numpy as np

arr = np.array([[1, 2], [3, 4]])
print("原数组形状:", arr.shape)

# 在第 1 维插入新轴
expanded_arr = np.expand_dims(arr, axis=1)
print("新数组形状:", expanded_arr.shape)
print(expanded_arr)

4)在 axis=1 插入一个新轴

import numpy as np

x = np.array([1, 2])
print("原数组形状:", x.shape)  # 输出: (2,)

# 在 axis=1 插入一个新轴
y = np.expand_dims(x, axis=1)
print("新数组:", y)
print("新数组形状:", y.shape)  # 输出: (2, 1)

5)在 axis=(0, 1) 插入新轴

import numpy as np

x = np.array([1, 2])

# 在 axis=(0, 1) 插入新轴
y = np.expand_dims(x, axis=(0, 1))
print("新数组:", y)
print("新数组形状:", y.shape)  # 输出: (1, 1, 2)

推荐文档

相关文档

大家感兴趣的内容

随机列表