器→工具, 工具软件, 数据, 术→技巧

Scikit-Learn体系化学习之内置数据集

钱魏Way · · 182 次浏览
!文章内容如有错误或排版问题,请提交反馈,非常感谢!

sklearn.datasets 是 scikit-learn 中用于加载和生成数据集的工具模块,内置了多种经典数据集和灵活的数据生成方法,适合快速实验和算法验证。

内置数据集类型

模块中的数据集分为三类,通过不同函数加载:

小型内置数据集(Toy Datasets)

特点:数据量小,直接集成在库中,无需下载。

鸢尾花数据集(Iris)

数据集背景

  • 来源:由统计学家 Ronald Fisher 于 1936 年发表,是最早用于模式识别研究的公开数据集之一。
  • 用途:主要用于分类任务,尤其是监督学习中的多分类问题(如逻辑回归、决策树、SVM)。
  • 特点:数据简单清晰、无缺失值、类别平衡,适合快速验证模型。

数据结构

  • 样本数量:150 个样本(每个类别 50 个)。
  • 特征数量:4 个数值型特征(单位均为厘米)。
  • 目标类别:3 种鸢尾花(每个类别 50 个样本)。

特征说明

每个样本包含以下 4 个形态测量值:

  • 花萼长度(sepal length)
  • 花萼宽度(sepal width)
  • 花瓣长度(petal length)
  • 花瓣宽度(petal width)

目标类别

  • 0: 山鸢尾(Iris setosa)
  • 1: 变色鸢尾(Iris versicolor)
  • 2: 维吉尼亚鸢尾(Iris virginica)

数据示例

以下是数据集的片段展示(部分示例值):

sepal length sepal width petal length petal width target
5.1 3.5 1.4 0.2 0
7.0 3.2 4.7 1.4 1
6.3 3.3 6.0 2.5 2

加载方法

使用 Scikit-learn 的 load_iris 函数加载数据:

from sklearn.datasets import load_iris

# 加载数据集
iris = load_iris()
X = iris.data  # 特征矩阵 (150, 4)
y = iris.target  # 目标标签 (150,)

# 查看特征名和类别名
print("特征名:", iris.feature_names)
# 输出:['sepal length (cm)', 'sepal width (cm)', 'petal length (cm)', 'petal width (cm)']

print("类别名:", iris.target_names)
# 输出:['setosa' 'versicolor' 'virginica']

数据标准化与模型训练

由于特征量纲一致(均为厘米),标准化非必需,但部分模型(如 SVM、KNN)可能受益:

from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegression

# 数据标准化
scaler = StandardScaler()
X_scaled = scaler.fit_transform(X)

# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X_scaled, y, test_size=0.2, random_state=42)

# 训练逻辑回归模型
model = LogisticRegression(max_iter=200)
model.fit(X_train, y_train)

# 评估准确率
print("测试集准确率:", model.score(X_test, y_test))
# 输出:测试集准确率:1.0(完美分类)

主成分分析(PCA)降维

可通过 PCA 将 4 维特征降至 2 维,便于可视化:

from sklearn.decomposition import PCA

pca = PCA(n_components=2)
X_pca = pca.fit_transform(X_scaled)

plt.scatter(X_pca[:, 0], X_pca[:, 1], c=y, cmap='viridis')
plt.xlabel("主成分 1")
plt.ylabel("主成分 2")
plt.show()

关键特点

  • 类别线性可分性:Iris setosa与其他两类线性可分,但versicolor和virginica存在部分重叠。
  • 特征相关性:花瓣长度与宽度高度相关(相关系数>0.9),可能需处理多重共线性。

手写数字数据集(Digits)

Scikit-learn中的手写数字数据集(Digits)是一个经典的多分类数据集,广泛用于图像识别任务的入门和实践。

数据集概述

  • 用途:多分类任务(识别0-9的手写数字)。
  • 来源:从美国国家标准与技术研究院(NIST)的原始数据中提取并预处理。
  • 特点
    • 低分辨率灰度图像(8×8像素),适合快速训练和验证模型。
    • 数据已标准化(居中、去噪),无需复杂预处理。
    • 类别平衡,适合作为图像分类的基准测试集。

数据结构

  • 样本数量:1797个样本。
  • 特征数量:64个特征(对应8×8像素的灰度值,取值范围0-16)。
  • 目标类别:10类(数字0-9,每类约180个样本)。

特征说明

每个样本是一个展平的 8×8 像素矩阵,每个像素的灰度值为 0(白色)到 16(黑色)。

例如,数字 “7” 的像素矩阵可能如下(简化):

[[0. 0. 5. 13. 9. 1. 0. 0.]
 [0. 0. 13. 15. 10. 15. 5. 0.]
 [0. 3. 15. 2. 0. 11. 8. 0.]
 [0. 4. 12. 0. 0. 8. 8. 0.]
 [0. 5. 8. 0. 0. 9. 8. 0.]
 [0. 4. 11. 0. 1. 12. 7. 0.]
 [0. 2. 14. 5. 10. 12. 0. 0.]
 [0. 0. 6. 13. 10. 0. 0. 0.]]

目标类别

  • 0~9:对应手写数字 0 到 9。

加载方法

使用 Scikit-learn 的 load_digits 函数加载数据:

from sklearn.datasets import load_digits

# 加载数据集
digits = load_digits()
X = digits.data  # 特征矩阵 (1797, 64)
y = digits.target  # 目标标签 (1797,)
images = digits.images  # 原始图像矩阵 (1797, 8, 8)

# 查看特征名和类别名
print("目标类别:", digits.target_names.tolist())  # [0, 1, 2, ..., 9]

数据可视化

显示单个数字图像

import matplotlib.pyplot as plt

# 显示前 10 个样本
fig, axes = plt.subplots(2, 5, figsize=(10, 4))
for i, ax in enumerate(axes.flat):
    ax.imshow(images[i], cmap='binary')  # 使用黑白配色
    ax.set_title(f"Label: {y[i]}")
    ax.axis('off')
plt.show()

预处理与建模示例

数据标准化与分类

虽然像素值范围较小(0-16),但某些模型(如 SVM)可能受益于标准化:

from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler
from sklearn.svm import SVC

# 将像素值归一化到 [0,1]
scaler = MinMaxScaler()
X_scaled = scaler.fit_transform(X)

# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X_scaled, y, test_size=0.2, stratify=y)

# 训练 SVM 模型
model = SVC(kernel='rbf', gamma=0.001, C=10)
model.fit(X_train, y_train)

# 评估准确率
print("测试集准确率:", model.score(X_test, y_test))  # 通常可达 ~0.98

主成分分析(PCA)降维

from sklearn.decomposition import PCA

# 降至 2 维可视化
pca = PCA(n_components=2)
X_pca = pca.fit_transform(X_scaled)

plt.scatter(X_pca[:, 0], X_pca[:, 1], c=y, cmap='tab10', alpha=0.6)
plt.colorbar(ticks=range(10), label='数字')
plt.xlabel("主成分 1")
plt.ylabel("主成分 2")
plt.title("PCA 降维可视化")
plt.show()

关键特点

  • 低分辨率限制:8×8 像素可能导致部分数字(如 4 和 9)边界模糊,分类难度增加。
  • 适用场景:
    • 验证传统机器学习模型(如 SVM、KNN、决策树)的性能。
    • 教学演示图像分类流程(特征提取、模型训练、评估)。
  • 扩展性不足:对于深度学习模型(如 CNN),建议使用更高分辨率的 MNIST 数据集(28×28 像素)。

威斯康星乳腺癌数据集(BreastCancer)

Scikit-learn 中的威斯康星乳腺癌数据集(BreastCancer Wisconsin Dataset)是一个经典的二分类数据集,用于预测乳腺肿瘤的恶性或良性。

数据集背景

  • 来源:由威斯康星大学医院收集,基于乳腺肿块的细针穿刺(FNA)数字化图像分析。
  • 历史意义:自 1990 年代起广泛用于机器学习研究,是医疗诊断领域的重要基准数据集。
  • 伦理声明:数据已匿名化,不包含患者个人信息,适合学术研究。

数据结构

  • 样本数量:569 个样本(良性 357 例,恶性 212 例)。
  • 特征数量:30 个数值特征(由 10 个原始特征分别计算均值、标准差和最大值)。
  • 目标变量:二分类标签,0 表示恶性(malignant),1 表示良性(benign)。

特征详细说明

原始测量特征(共 10 个)

每个特征通过数字化图像分析得到,描述细胞核的形态学特性:

  • 半径(Radius):从中心到边缘的平均距离。
  • 纹理(Texture):灰度值的标准差,反映图像局部对比度。
  • 周长(Perimeter):细胞核边界的长度。
  • 面积(Area):细胞核覆盖的像素面积。
  • 平滑度(Smoothness):周长变化的局部差异,量化边界的光滑程度。
  • 紧密度(Compactness):计算公式为 周长²/面积−1,值越大表示形状越不规则。
  • 凹陷度(Concavity):轮廓凹陷部分的严重程度,反映边界凹陷的深度。
  • 凹点数量(Concave Points):轮廓中凹陷区域的数量。
  • 对称性(Symmetry):细胞核的对称程度。
  • 分形维度(Fractal Dimension):边界复杂度的度量,值越接近 1 表示边界越接近平滑曲线。

统计量扩展

每个原始特征计算了以下 3 种统计量,形成 30 维特征:

  • 均值(Mean):所有细胞核测量值的平均。
  • 标准差(Standard Error):测量值的离散程度。
  • 最大值(Worst):最异常(通常最大)的单个细胞核测量值。

示例特征名:mean radius, radius error, worst radius数据加载与探索

from sklearn.datasets import load_breast_cancer

# 加载数据集
cancer = load_breast_cancer()
X = cancer.data  # 特征矩阵 (569, 30)
y = cancer.target  # 目标标签 (569,)

# 查看元数据
print("特征名:", cancer.feature_names)
print("类别名:", cancer.target_names)  # ['malignant' 'benign']
print("数据集描述:", cancer.DESCR)  # 详细统计信息和背景

预处理建议

  • 标准化:特征量纲差异大(如 area 范围在 143-2501,smoothness 在 0.5-0.16),需使用 StandardScaler 或 MinMaxScaler。
  • 类别平衡:良性样本较多,可考虑过采样(如 SMOTE)或在训练时设置类别权重(如 class_weight=’balanced’)。
  • 特征选择:高维特征可能导致过拟合,可使用:
    • 单变量统计(如 ANOVA F 值)
    • 递归特征消除(RFE)
    • 基于模型的特征重要性(如随机森林)

关键挑战与解决方案

  • 特征相关性高:如 mean radius 与 mean perimeter 高度相关(r>0.99),可通过 PCA 降维或手动剔除冗余特征。
  • 类别不平衡:良性样本占比约 63%,使用分层抽样或调整模型参数(如 SVM 的 class_weight)。
  • 过拟合风险:使用正则化(L1/L2)、交叉验证或简化模型复杂度。

扩展数据集

葡萄酒数据集(Wine)

Scikit-learn 中的葡萄酒数据集(Wine Dataset)是一个经典的多分类数据集,用于根据化学成分区分三种不同品种的葡萄酒。

数据集背景

  • 来源:数据源自意大利同一地区种植的葡萄酿造的红酒,但来自三个不同的品种(由研究者 Forina 等人收集)。
  • 用途:多分类任务(识别葡萄酒品种),适合验证分类算法(如 SVM、决策树、随机森林)。
  • 特点:特征均为化学分析结果,无缺失值,适合直接用于机器学习。

数据结构

  • 样本数量:178 个样本。
  • 特征数量:13 个数值型特征(描述葡萄酒的化学成分)。
  • 目标类别:3 种葡萄酒品种,分布如下:
    • 类别 0:59 个样本
    • 类别 1:71 个样本
    • 类别 2:48 个样本

特征详细说明

所有特征均为连续型数值,代表葡萄酒的化学分析指标:

  • Alcohol(酒精含量):体积百分比。
  • Malic acid(苹果酸):苹果酸浓度(g/L)。
  • Ash(灰分):高温灼烧后的无机物残留(g/L)。
  • Alcalinity of ash(灰分碱度):灰分的碱度(以碳酸钾计)。
  • Magnesium(镁):镁含量(mg/L)。
  • Total phenols(总酚):酚类物质总量(以没食子酸计)。
  • Flavanoids(类黄酮):类黄酮含量(mg/L)。
  • Nonflavanoid phenols(非类黄酮酚):非类黄酮酚含量(mg/L)。
  • Proanthocyanins(原花青素):原花青素浓度(mg/L)。
  • Color intensity(颜色强度):颜色深浅(光学测量值)。
  • Hue(色调):葡萄酒的色调(基于光谱分析)。
  • OD280/OD315 of diluted wines(稀释吸光度比值):稀释后OD280与OD315的吸光度比,反映蛋白质含量。
  • Proline(脯氨酸):脯氨酸含量(mg/L),一种氨基酸。

数据加载与探索

from sklearn.datasets import load_wine

# 加载数据集
wine = load_wine()
X = wine.data  # 特征矩阵 (178, 13)
y = wine.target  # 目标标签 (178,)

# 查看元数据
print("特征名:", wine.feature_names)
print("类别名:", wine.target_names)  # ['class_0' 'class_1' 'class_2']
print("数据集描述:", wine.DESCR)  # 包含详细统计信息

预处理建议

  • 标准化:特征量纲差异大(如 proline 范围在 278-1680,magnesium 在 70-162),需使用 StandardScaler。
  • 类别平衡:类别分布略微不均衡(类别 2 样本较少),可采用分层抽样或调整类别权重。

关键特点

  • 化学特征可解释性:特征直接反映葡萄酒的化学成分,便于分析关键区分因素(如 proline 和 flavanoids)。
  • 线性可分性:部分类别在 PCA 降维后呈现线性可分趋势,适合线性模型(如逻辑回归)。
  • 小样本挑战:数据量较小(178 样本),需谨慎避免过拟合(如使用正则化或交叉验证)。

扩展与注意事项

  • 扩展数据集:UCI 提供更详细的Wine Quality Dataset,包含红葡萄酒和白葡萄酒的感官评分。
  • 注意事项
    • 避免在预处理前划分数据集以防止数据泄露。
    • 对于高维特征,建议结合特征选择(如 ANOVA F 值)或降维(如 PCA)提升模型效率。

葡萄酒数据集凭借其清晰的化学特征和适中的规模,成为多分类任务的理想选择,尤其适合探索特征分析与模型解释的结合应用。

糖尿病数据集(Diabetes)

Scikit-learn 中的糖尿病数据集(Diabetes Dataset)是一个经典的回归任务数据集,用于预测糖尿病患者的疾病进展。

数据集背景

  • 来源:数据源自 1990 年代的研究,旨在探索糖尿病患者的生理指标与疾病发展的关联。
  • 用途:回归任务(预测糖尿病进展),适合验证回归模型(如线性回归、岭回归、随机森林回归)。
  • 特点:
    • 特征未标准化,需预处理。
    • 无缺失值,均为数值型特征。
    • 目标变量为连续值,反映疾病进展程度。

数据结构

  • 样本数量:442 个患者样本。
  • 特征数量:10个数值型特征(基线生理指标)。
  • 目标变量:连续值(范围25-346),表示糖尿病在基线后一年内的疾病进展。

特征详细说明

所有特征均经过标准化处理(均值=0,标准差=1),但原始测量指标如下:

  • Age(年龄):患者的年龄(岁)。
  • Sex(性别):0表示女性,1表示男性。
  • BMI(身体质量指数):体重(kg)除以身高(m)的平方。
  • BP(平均血压):平均动脉压(mmHg)。
  • S1(血清测定1):总胆固醇(TC,mg/dL)。
  • S2(血清测定2):低密度脂蛋白(LDL,mg/dL)。
  • S3(血清测定3):高密度脂蛋白(HDL,mg/dL)。
  • S4(血清测定4):总胆固醇/HDL比值。
  • S5(血清测定5):可能为甘油三酯(对数转换值)。
  • S6(血清测定6):血糖水平(mg/dL)。

目标变量

  • 疾病进展指标:基于基线后一年内患者的疾病发展评估,值越大表示病情恶化越严重。

数据加载与探索

from sklearn.datasets import load_diabetes

# 加载数据集
diabetes = load_diabetes()
X = diabetes.data  # 特征矩阵 (442, 10)
y = diabetes.target  # 目标变量 (442,)

# 查看元数据
print("特征名:", diabetes.feature_names)
print("数据集描述:", diabetes.DESCR)  # 包含详细统计信息

预处理建议

  • 标准化:虽然特征已中心化,但不同特征的尺度可能仍需调整(如使用StandardScaler)。
  • 特征工程:探索交互项(如 BMI × 血压)或多项式特征。
  • 特征选择:通过相关系数或 Lasso 回归筛选重要特征。

关键挑战与解决方案

  • 低解释方差:R² 通常较低(约 4-0.5),表明需更复杂模型(如梯度提升树)或引入额外特征。
  • 共线性问题:血清指标(S1-S6)可能存在相关性,可通过 VIF 检测或使用正则化模型(如岭回归)。
  • 小样本限制:使用交叉验证(如 5 折)优化超参数,避免过拟合。

大型下载数据集(Real-World Datasets)

特点:数据量较大,首次使用需从网络下载,缓存到本地。

使用 fetch_* 函数加载(首次使用需下载):

加州房价数据集(California Housing)

Scikit-learn 中的加州房价数据集(California Housing Dataset)是一个经典的回归任务数据集,用于预测美国加利福尼亚州各区域的房屋价格中位数。

数据集概述

  • 来源:基于 1990 年美国人口普查数据,由美国人口普查局(S. Census Bureau)收集。
  • 用途:回归任务(预测房价),适合验证回归模型(如线性回归、梯度提升树、神经网络)。
  • 特点:
    • 包含地理和社会经济特征。
    • 数据规模较大(约 2 万样本),适合训练复杂模型。
    • 无缺失值,特征已部分预处理。

数据结构

  • 样本数量:20,640 个区域(样本)。
  • 特征数量:8 个数值型特征。
  • 目标变量:连续值,表示区域内房屋价格中位数(单位:10 万美元)。

特征详细说明

所有特征均为数值型,描述地理和人口统计信息:

  • MedInc(收入中位数):区域内家庭收入中位数(单位:万美元)。
  • HouseAge(房龄中位数):房屋年龄中位数(单位:年)。
  • AveRooms(平均房间数):每户平均房间数。
  • AveBedrms(平均卧室数):每户平均卧室数。
  • Population(人口):区域内人口总数。
  • AveOccup(平均入住率):每户平均居住人数。
  • Latitude(纬度):区域中心的纬度。
  • Longitude(经度):区域中心的经度。

目标变量

  • MedHouseVal(房价中位数):标化后的房屋价格中位数(范围 0-5,对应 0-50 万美元)。

数据加载与探索

from sklearn.datasets import fetch_california_housing

# 加载数据集
housing = fetch_california_housing()
X = housing.data  # 特征矩阵 (20640, 8)
y = housing.target  # 目标变量 (20640,)

# 查看元数据
print("特征名:", housing.feature_names)
print("数据集描述:", housing.DESCR)  # 包含详细统计信息

预处理建议

  • 标准化/归一化:特征量纲差异大(如 Population 范围 0-35,682,MedInc 0-15),需使用 StandardScaler。
  • 地理特征工程:
    • 将经纬度转换为区域编码(如 K-Means 聚类)。
    • 计算到海岸线或城市的距离(需外部数据)。
  • 处理偏态分布:对右偏特征(如 Population)进行对数转换。

关键挑战与解决方案

  • 高维度与稀疏性:地理特征(经纬度)直接使用效果有限,可转换为空间聚类或距离特征。
  • 非线性关系:使用树模型(如 GBDT)或引入多项式特征(如 MedInc²)。
  • 数据时效性:1990 年数据已过时,可通过外部数据(如 Zillow API)获取最新房价。

扩展与注意事项

  • 扩展数据集:
    • 通过 fetch_openml(name=’california_housing’) 获取更详细的版本(包含更多特征)。
    • 结合外部数据(如犯罪率、学校评分)提升预测精度。
  • 注意事项:
    • 避免将经纬度直接输入线性模型(需转换为空间特征)。
    • 目标变量已截断(最高为 0),需注意模型预测值的范围限制。

Olivetti 人脸数据集

Scikit-learn 中的 Olivetti 人脸数据集(Olivetti Faces Dataset)是一个经典的人脸识别数据集,常用于图像分类、降维和聚类任务。

数据集概述

  • 来源:由剑桥大学 AT&T 实验室收集,包含 40 位志愿者的面部图像。
  • 用途:人脸识别、图像降维(如 PCA、LDA)、小样本学习。
  • 特点:
      • 小样本但高维度:每位志愿者仅 10 张图像,但每张图像为 64×64 像素(4096 维)。

    灰度图像,背景统一,适合快速验证算法。

  • 同一人物的图像包含不同表情、光照和面部细节变化。

数据结构

  • 样本数量:400 张图像(40 人 × 10 张)。
  • 特征数量:4096 个特征(64×64 像素的灰度值,范围 0-1)。
  • 目标变量:40 个类别(0-39,每位志愿者对应一个类别)。

数据示例

图像内容

  • 人物:40 位不同志愿者(20-50 岁,混合性别)。
  • 拍摄条件:
  • 表情:自然、微笑、惊讶等。
  • 姿态:轻微头部转动。
  • 光照:部分图像存在明暗变化。
  • 细节:是否佩戴眼镜(部分人物)。

数据格式

  • 每张图像被展平为一维向量(64×64=4096维),例如:
# 加载数据集后,获取第一张图像
image_sample = olivetti.images[0] # 形状(64,64)
flattened_sample = olivetti.data[0] # 形状(4096,)

数据加载与探索

from sklearn.datasets import fetch_olivetti_faces

# 加载数据集(首次使用需下载)
olivetti = fetch_olivetti_faces()
X = olivetti.data # 特征矩阵(400,4096)
y = olivetti.target # 目标标签(400,)
images = olivetti.images # 原始图像矩阵(400,64,64)

# 查看元数据
print("目标类别数:", len(np.unique(y))) # 40
print("图像分辨率:", images[0].shape) # (64,64)

数据可视化

显示同一人物的多张图像

import matplotlib.pyplot as plt

# 选择第5位志愿者的所有图像(标签为4)
person_id = 4
person_images = images[y == person_id]

# 绘制前5张图像
fig, axes = plt.subplots(1,5, figsize=(12,3))
for i, ax in enumerate(axes):
    ax.imshow(person_images[i], cmap='gray')
    ax.axis('off')
    ax.set_title(f"Person {person_id}")
plt.show()

预处理建议

  • 标准化:像素值已归一化到[0,1],无需额外处理。
  • 降维:使用PCA或LDA减少4096维特征,提升计算效率。
  • 数据增强:对小样本(每人10张)应用旋转、平移或噪声注入(需手动扩展)。

关键挑战与解决方案

  • 小样本问题:每人仅10张图像,可结合迁移学习(如VGG-Face预训练模型)。
  • 高维度灾难:使用PCA将4096维降至50-100维,再输入传统分类器。
  • 类内差异:同一人物的表情和姿态变化需模型具备鲁棒性,可引入数据增强或注意力机制。

注意事项

  • 数据规模限制:400样本仅适合教学和小规模实验,工业级应用需更大数据集(如LFW、CelebA)。
  • 图像分辨率:64×64像素较低,无法捕捉细节,可尝试超分辨率重建预处理。
  • 隐私保护:尽管数据已匿名,实际应用中需遵守隐私法规。

Olivetti人脸数据集为入门级人脸识别任务提供了便捷的实验平台,尤其适合验证特征提取与降维技术的有效性。其简洁的结构和明确的类别划分,使其成为图像处理与模式识别教学的经典案例。

20类新闻文本数据集(20Newsgroups)

Scikit-learn中的20类新闻文本数据集(20Newsgroups)是一个经典的文本分类数据集,广泛用于自然语言处理(NLP)任务。

数据集概述

  • 来源:收集自20个不同主题的新闻组(Usenet讨论组),涵盖技术、科学、娱乐等多个领域。
  • 用途:文本分类、主题建模、信息检索、迁移学习(如预训练模型微调)。
  • 特点:
    • 纯英文文本,包含邮件头(发件人、主题等)和正文。
    • 原始文本未预处理,需进行分词、去噪等操作。
    • 提供训练集和测试集的预设划分,适合基准测试。

数据结构

  • 样本数量:共约18,846篇文档(实际数量因版本和参数而异)。
  • 类别数量:20个主题,每个主题约600-1,000篇文档。
  • 数据划分:
    • 训练集:约11,314篇文档。
    • 测试集:约7,532篇文档。
  • 类别示例:graphics, rec.sport.baseball, sci.med, talk.politics.mideast, alt.atheism。

数据加载与探索

from sklearn.datasets import fetch_20newsgroups

# 加载训练集和测试集(全类别)
news_train = fetch_20newsgroups(subset='train', shuffle=True, random_state=42)
news_test = fetch_20newsgroups(subset='test', shuffle=True, random_state=42)

X_train, y_train = news_train.data, news_train.target
X_test, y_test = news_test.data, news_test.target

# 查看类别名
print("类别名:", news_train.target_names)

参数说明

    • subset: ‘train’(训练集)、’test’(测试集)或’all’(合并所有数据)。
    • categories: 指定加载的类别列表(默认加载全部20类)。

remove: 去除邮件头中的特定部分(如 (‘headers’, ‘footers’, ‘quotes’))。

数据示例

# 打印第一篇训练文档的文本和类别
print("文本内容:\n", X_train[0][:500]) # 显示前500字符
print("类别ID:", y_train[0])
print("类别名:", news_train.target_names[y_train[0]])

输出示例:

From: lerxst@wam.umd.edu (where's my thing)
Subject: WHAT car is this!?
Nntp-Posting-Host: rac3.wam.umd.edu
Organization: University of Maryland, College Park
Lines: 15

I was wondering if anyone out there could enlighten me on this car I saw
the other day. It was a 2-door sports car, looked to be from the late 60s/
early 70s. It was called a Bricklin. The doors were really small. In addition,
the front bumper was separate from the rest of the body. This is
...
类别ID:6
类别名:rec.autos

预处理与特征工程

预处理步骤

  • 去除噪声:
    • 删除邮件头、签名、引用(remove=(‘headers’, ‘footers’, ‘quotes’))。
    • 移除标点、数字、特殊字符。
  • 分词与词干提取:
from nltk.stem import PorterStemmer
from nltk.tokenize import word_tokenize

stemmer = PorterStemmer()
def preprocess(text):
    tokens = word_tokenize(text.lower()) # 小写化并分词
    tokens = [stemmer.stem(t) for t in tokens if t.isalpha()] # 仅保留字母词并词干提取
    return ' '.join(tokens)
  • 停用词过滤:使用 corpus.stopwords 或 sklearn 内置停用词表。
  • 向量化(文本转数值特征)
from sklearn.feature_extraction.text import TfidfVectorizer

# 使用 TF-IDF 转换(限制最大特征数)
vectorizer = TfidfVectorizer(max_features=10_000, stop_words='english')
X_train_tfidf = vectorizer.fit_transform(X_train)
X_test_tfidf = vectorizer.transform(X_test)

挑战与解决方案

  • 类别不平衡:使用 class_weight=’balanced’ 或过采样(如 SMOTE)。
  • 高维稀疏特征:使用 TruncatedSVD 或 NMF 降维。
  • 文本噪声:结合正则表达式和 NLP 工具包(如 nltk、spaCy)清洗数据。

注意事项

  • 版本差异:不同 Scikit-learn 版本的数据划分可能略有不同,建议固定随机种子。
  • 计算资源:TF-IDF 转换后特征维度高(数万维),需注意内存消耗。
  • 语言限制:仅支持英文,处理其他语言需额外预处理。

20类新闻文本数据集为文本分类任务提供了标准化的实验平台,尤其适合探索从原始文本到分类模型的端到端流程。其多样的主题和真实的文本噪声,使其成为 NLP 研究和教学的理想选择。

MNIST 手写数字数据集

Scikit-learn 中的 MNIST 手写数字数据集(MNIST Dataset)是一个经典的图像分类数据集,广泛用于验证机器学习模型(尤其是深度学习模型)在图像识别任务中的性能。

数据集概述

  • 来源:由美国国家标准与技术研究院(NIST)收集,包含手写数字的灰度图像。
  • 用途:图像分类(识别 0-9 手写数字)、降维、生成对抗网络(GAN)训练。
  • 特点:
    • 高分辨率(相比 Scikit-learn 内置的 Digits 数据集):28×28 像素。
    • 样本量大,包含训练集和测试集的官方划分。
    • 数据已标准化(居中、去噪),适合直接使用。

数据结构

  • 样本数量:70,000 张图像。
    • 训练集:60,000 张。
    • 测试集:10,000 张。
  • 特征数量:784 个特征(28×28 像素展平后的灰度值,范围 0-255)。
  • 目标变量:10 类(数字 0-9)。

数据加载与探索

MNIST 未直接集成在 Scikit-learn 的 datasets 模块中,但可通过 fetch_openml 加载:

from sklearn.datasets import fetch_openml

# 加载MNIST数据集(首次使用需下载)
mnist = fetch_openml('mnist_784', version=1, as_frame=False)
X = mnist.data  # 特征矩阵 (70000, 784)
y = mnist.target.astype(int)  # 目标标签 (70000,)

# 查看数据形状
print("特征维度:", X.shape)  # (70000, 784)
print("目标类别:", np.unique(y))  # [0 1 2 3 4 5 6 7 8 9]

数据划分

# 官方训练集和测试集划分
X_train, X_test = X[:60000], X[60000:]
y_train, y_test = y[:60000], y[60000:]

数据可视化

显示单个数字图像

import matplotlib.pyplot as plt

# 显示第5个训练样本
digit_image = X_train[4].reshape(28, 28)
plt.imshow(digit_image, cmap='binary')
plt.title(f"Label: {y_train[4]}")
plt.axis('off')
plt.show()

预处理建议

归一化:将像素值从0-255缩放到0-1:

X_train_scaled = X_train / 255.0
X_test_scaled = X_test / 255.0

数据增强(针对深度学习):随机旋转、平移、缩放(需使用 tensorflow 或 keras 工具)。

关键挑战与解决方案

    相似数字混淆(如 4/9、5/3):

    • 使用数据增强生成更多样本。
    • 改进模型结构(如引入残差连接)。
  • 过拟合:
    • 添加 Dropout 层或 L2 正则化。
  • 计算资源限制:
    • 使用轻量级模型(如 LeNet-5)或迁移学习。

扩展数据集

  • Fashion-MNIST:替代 MNIST 的 10 类服装图像数据集(相同格式)。
  • EMNIST:扩展 MNIST,包含字母和数字。
  • KMNIST:日文手写字符数据集。

注意事项

  • 数据泄露:使用官方划分的训练集和测试集,避免自定义划分。
  • 性能天花板:传统模型(如 SVM)准确率约 95-98%,CNN 可达 5% 以上,剩余错误多为模糊书写。

MNIST 手写数字数据集是机器学习领域的“Hello World”,凭借其简洁的结构和丰富的应用场景,成为算法验证和教育的最佳选择。尽管现代研究已转向更复杂的数据集(如 CIFAR-10、ImageNet),MNIST 仍为理解图像处理基础提供了不可替代的价值。

生成模拟数据集(Synthetic Datasets)

特点:按需生成可控数据,适用于算法测试。

Scikit-learn(sklearn)提供了多种生成模拟数据集的方法,适用于分类、回归、聚类、降维等任务。这些数据集便于算法测试、模型调试和教学示例。

分类数据集

make_classification

生成适用于分类的多维数据集,可控制类别间分离度、特征类型等。

参数:

  • n_samples: 样本数(默认 100)
  • n_features: 特征数(默认 20)
  • n_informative: 有效特征数(默认 2)
  • n_redundant: 冗余特征数(默认 2,由有效特征线性组合生成)
  • n_classes: 类别数(默认 2)
  • random_state: 随机种子

示例:

from sklearn.datasets import make_classification
X, y = make_classification(n_samples=100, n_features=4, n_classes=2, random_state=42)

make_blobs

生成各向同性高斯分布的聚类数据,适用于聚类和分类。

参数:

  • centers: 中心点数量或坐标
  • cluster_std: 簇的标准差

示例:

from sklearn.datasets import make_blobs
X, y = make_blobs(n_samples=100, centers=3, n_features=2, random_state=42)

非线性数据集

make_circlesmake_moons: 生成环形或半月形数据,测试非线性分类器。

from sklearn.datasets import make_circles, make_moons
X_circle, y_circle = make_circles(n_samples=100, noise=0.1, factor=0.5)
X_moon, y_moon = make_moons(n_samples=100, noise=0.1)

回归数据集

make_regression

生成回归模型数据,可控制噪声、有效特征等。

参数:

  • n_targets: 目标变量数
  • noise: 噪声标准差

示例:

from sklearn.datasets import make_regression
X, y = make_regression(n_samples=100, n_features=3, noise=0.5)

聚类数据集

make_blobs

(同上,常用于 K-means 等聚类算法)

make_s_curve

生成 S 型曲线,用于流形学习。

示例:

from sklearn.datasets import make_s_curve
X, t = make_s_curve(n_samples=100, noise=0.1)

多标签与多输出

make_multilabel_classification

生成多标签分类数据。

参数:

  • n_labels: 每个样本的标签数

示例:

from sklearn.datasets import make_multilabel_classification
X, y = make_multilabel_classification(n_samples=100, n_features=5, n_classes=3)

其他数据集

make_sparse_uncorrelated

生成稀疏且无相关性的回归数据。

make_biclusters

生成双聚类数据。

使用场景与建议

  • 测试模型: 快速验证算法在特定数据分布下的表现。
  • 教学演示: 直观展示不同算法(如 SVM、决策树)如何处理线性/非线性数据。
  • 参数调整: 通过调整噪声、特征数量等,研究模型鲁棒性。

数据集对象结构

加载后的数据集返回一个 Bunch 对象(类似字典),包含以下键:

  • data:特征数据(NumPy 数组或 SciPy 稀疏矩阵)。
  • target:标签数据(分类/回归目标)。
  • feature_names:特征名称列表(如鸢尾花的[‘sepal length(cm)’,…])。
  • target_names:类别名称列表(如鸢尾花的[‘setosa’,’versicolor’,’virginica’])。
  • DESCR:数据集的完整描述文本。

示例:转换为 Pandas DataFrame

import pandas as pd

iris = load_iris()
df = pd.DataFrame(iris.data, columns=iris.feature_names)
df['target'] = iris.target #添加标签列
df['species'] = df['target'].apply(lambda x: iris.target_names[x])
print(df.head())

核心 API 方法

加载函数

  • load_*:加载小型内置数据集(如 load_iris)。
  • fetch_*:下载大型数据集(如 fetch_california_housing)。
  • make_*:生成模拟数据(如 make_moons 生成月牙形聚类数据)。

参数说明

  • return_X_y:直接返回 (data, target) 而非 Bunch 对象。
X, y = load_iris(return_X_y=True)
  • as_frame:将 data 和 target 转为 Pandas DataFrame(要求数据集支持)。
housing = fetch_california_housing(as_frame=True)
X_df = housing.frame # 获取完整 DataFrame

注意事项

  • 数据集缓存路径:下载的数据集默认保存在 ~/scikit_learn_data,可通过 SKLEARN_DATA 环境变量修改。
  • 伦理数据集:部分数据集(如波士顿房价)因伦理争议已弃用,建议使用替代数据。
  • 数据预处理:生成的数据默认无缺失值,但真实数据集可能需要处理缺失或标准化。

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注