Python numpy.squeeze函数方法的使用

numpy.squeeze 函数用于去掉 NumPy 数组中维度为 1 的轴。这在处理数组时很有用,尤其是在对高维数据进行操作时,可以简化数组的形状。本文主要介绍一下NumPy中squeeze方法的使用。

numpy.squeeze

numpy.squeeze(a, axis=None)     [source]

从数组shape中删除一维条目。

参数 :

a :array_like

输入数据。

axisNoneint 

或 int类型的tuple, 可选

1.7.0版中的新功能。

 选择形状中一维条目的子集。 

如果选择的形状输入

大于一个的轴,

则会引发错误。

返回值 :

squeezed :ndarray

输入数组,

但删除了长度为1的全部或部分维度。

 可能始终是其本身或a的视图。 

请注意,如果所有轴都受到squeeze

则结果为0d数组,而不是标量。

Raises :

ValueError

如果axis不为None

并且被压缩的轴的长度不为1

例子

1)去掉所有维度为 1 的轴

import numpy as np

# 创建一个具有维度 (1, 3, 1) 的数组
arr = np.array([[[1]], [[2]], [[3]]])
print("原始数组形状:", arr.shape)  # 输出: (3, 1, 1)

# 使用 squeeze 去掉维度为 1 的轴
squeezed_arr = np.squeeze(arr)
print("去掉维度后的数组形状:", squeezed_arr.shape)  # 输出: (3,)
print(squeezed_arr)  # 输出: [1 2 3]

2)指定轴去掉

import numpy as np

# 创建一个具有维度 (1, 3, 1) 的数组
arr = np.array([[[1]], [[2]], [[3]]])
print("原始数组形状:", arr.shape)  # 输出: (3, 1, 1)

# 使用 squeeze 指定轴去掉
squeezed_arr = np.squeeze(arr, axis=1)
print("去掉指定维度后的数组形状:", squeezed_arr.shape)  # 输出: (3, 1)
print(squeezed_arr)  # 输出: [[1]
                     #        [2]
                     #        [3]]

3)尝试去掉不存在的轴

import numpy as np

# 创建一个具有维度 (3, 2) 的数组
arr = np.array([[1, 2], [3, 4], [5, 6]])
print("原始数组形状:", arr.shape)  # 输出: (3, 2)

# 尝试去掉一个不存在的轴
squeezed_arr = np.squeeze(arr, axis=0)  # 轴 0 不是 1,保持不变
print("去掉不存在维度后的数组形状:", squeezed_arr.shape)  # 输出: (3, 2)
print(squeezed_arr)  # 输出: [[1 2]
                     #        [3 4]
                     #        [5 6]]

4)处理其他形状的数组

import numpy as np

# 创建一个具有维度 (1, 1) 的数组
x = np.array([[1234]])
print("原始数组形状:", x.shape)  # 输出: (1, 1)

# 去掉所有维度为 1 的轴
squeezed_x = np.squeeze(x)
print("去掉维度后的数组:", squeezed_x)  # 输出: 1234
print("去掉维度后的数组形状:", squeezed_x.shape)  # 输出: ()

# 访问单个元素
element = np.squeeze(x)[()]
print("单个元素:", element)  # 输出: 1234

推荐阅读
cjavapy编程之路首页