Python 机器学习 数据集分布可视化

Python 的机器学习项目中,可视化是理解数据、模型和预测结果的重要工具。通过可视化可以观察数据集的分布情况,了解数据的特征和规律,可以评估模型的性能,发现模型的优缺点,分析预测结果,解释模型的预测过程。可视化数据集的分布和预测结果是整个过程中一个重要的步骤。通常可视化可以用Seaborn实现,它是基于 Matplotlib 的高级绘图库,提供了一些更高级的绘图功能。

1、加载数据集

load_iris()是scikit-learn库中的一个函数,用于加载一个著名的数据集,即鸢尾花(Iris)数据集。数据集通常用于机器学习和统计分类技术的示例、测试和实验。鸢尾花数据集包含了三种鸢尾花(Iris setosa、Iris virginica和Iris versicolor)的150个样本。每个样本有四个特征:萼片长度、萼片宽度、花瓣长度和花瓣宽度,这些特征的单位都是厘米。目标变量是花的种类。数据集常用于分类算法的教学和测试,特别是对于新手来说,它是理解机器学习概念的一个很好的入门数据集。可以用于各种分类算法,包括最简单的如K-近邻(KNN)算法,以及更复杂的如支持向量机(SVM)和神经网络。

from sklearn.datasets import load_iris

# 加载数据集
iris = load_iris()

# 特征矩阵,iris.data包含了150个样本的四个特征值
x = iris.data
print(x)

# 目标向量,iris.target包含了相应的种类标签(0, 1, 2分别代表三种不同的鸢尾花)。
y = iris.target
print(y)

# 特征名称
feature_names = iris.feature_names
print(feature_names)

# 目标名称(花的种类)
target_names = iris.target_names
print(target_names)

2、seaborn.lmplot()的使用

Seaborn 的 lmplot() 函数是用于绘制线性回归模型的强大工具,它结合了 regplot() 和 FacetGrid。这个函数适用于绘制数据集中变量间线性关系的图形,尤其是探索两个连续变量(或一个连续和一个分类变量)之间的关系。lmplot() 是一个功能强大的工具,适用于探索和呈现变量间的线性关系,特别是在数据集包含分类变量时。常用参数如下,

参数

描述

x

数据框架中的变量名称,将在 x 轴上绘制。

y

数据框架中的变量名称,将在 y 轴上绘制。

hue

数据框架中的分类变量,不同类别以不同颜色显示。

data

数据源,通常是 Pandas 的 DataFrame。

palette

设置不同类别的颜色。

col

用于在不同列展示数据框架中一个额外分类变量的不同级别。

row

用于在不同行展示数据框架中一个额外分类变量的不同级别。

markers

指定不同类别的数据点的标记类型。

fit_reg

布尔值,控制是否绘制回归模型

(对于 KNN,通常为 False)。

scatter_kws

传递额外的关键字参数到底层 Matplotlib 函数,

控制散点的样式。

line_kws

传递额外的关键字参数到底层 Matplotlib 函数,

控制线条的样式。

height

每个面板的高度大小。

aspect

每个面板的宽高比。

使用代码:

import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd
from sklearn.datasets import load_iris


# 加载鸢尾花数据集
#iris = sns.load_dataset('iris') #加载报错,可以直接使用sklearn.datasets的load_iris来加载数据集


# 加载鸢尾花数据集
iris = load_iris()

# 创建DataFrame
iris_df = pd.DataFrame(data=iris.data, columns=iris.feature_names)
iris_df['species'] = pd.Categorical.from_codes(iris.target, iris.target_names)


# 使用 lmplot() 函数绘制图表
sns.lmplot(
    x="sepal length (cm)",  # x 轴变量
    y="petal length (cm)",  # y 轴变量
    hue="species",     # 数据分类变量
    data=iris_df,         # 数据源
    palette="Set1",    # 为不同的 species 设置不同的颜色
    markers=["o", "s", "D"],  # 为不同的 species 设置不同的标记
    height=5,          # 图表高度
    aspect=1.5,        # 图表宽高比
    fit_reg=False,     # 关闭线性回归拟合线,因为我们更关注数据分布
    scatter_kws={"s": 50, "alpha": 0.8}  # 设置散点的大小和透明度
)

# 添加标题
plt.title("cjavapy")
plt.draw()

# 显示图形
plt.show()

3、数据集分布可视化

使用 Seaborn 的 lmplot() 创建的鸢尾花数据集的散点图。展示数据的分布而非探究变量间的线性关系。可视化是理解数据、模型和预测结果的重要工具。散点图用于展示数据点的分布情况,适用于数值型数据。

import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd
from sklearn.datasets import load_iris

# 加载鸢尾花数据集
iris = load_iris()

# 特征矩阵,iris.data包含了150个样本的四个特征值
x = iris.data
print(x)

# 目标向量,iris.target包含了相应的种类标签(0, 1, 2分别代表三种不同的鸢尾花)。
y = iris.target
print(y)

# 特征名称
feature_names = iris.feature_names
print(feature_names)

# 目标名称(花的种类)
target_names = iris.target_names
print(target_names)

iris_df = pd.DataFrame(iris.data, columns=iris.feature_names)
iris_df['species'] = pd.Categorical.from_codes(iris.target, iris.target_names)
# 使用 lmplot 绘制数据分布图
sns.lmplot(x="sepal length (cm)", y="petal length (cm)", hue="species", data=iris_df,
           palette="Set1", markers=["o", "s", "D"], height=5, aspect=1.5, fit_reg=False)

# 添加标题
plt.title("cjavapy")

plt.draw()
# 显示图形
plt.show()

推荐阅读
cjavapy编程之路首页