机器学习, 法→原理

机器学习常用损失函数

钱魏Way · · 5 次浏览

机器学习中的损失函数用于衡量模型预测值与真实值之间的差异,是模型优化的关键。以下按任务类型分类介绍常见损失函数:

回归任务损失函数

均方误差 (Mean Squared Error, MSE / L2 Loss)

公式

$$\text{MSE} = \frac{1}{n} \sum_{i=1}^{n} (y_i – \hat{y}_i)^2$$

导数(梯度)

$$\frac{\partial \text{MSE}}{\partial \hat{y}_i} = -2(y_i – \hat{y}_i)$$

特性

  • 凸函数:具有唯一全局最小值
  • 惩罚特性:对较大误差给予二次方惩罚
  • 最优解:最小化MSE得到条件期望$\mathbb{E}[y|x]$

优点

  • 处处可导,优化稳定高效
  • 对高斯噪声假设下的最大似然估计

缺点

  • 对异常值敏感
  • 损失单位与目标单位不一致

平均绝对误差 (Mean Absolute Error, MAE / L1 Loss)

公式

$$\text{MAE} = \frac{1}{n} \sum_{i=1}^{n} |y_i – \hat{y}_i|$$

导数(梯度)

$$\frac{\partial \text{MAE}}{\partial \hat{y}_i} = \begin{cases} 1, & \hat{y}_i < y_i \\ -1, & \hat{y}_i > y_i \\ \text{未定义}, & \hat{y}_i = y_i \end{cases}$$

实际使用中常采用次梯度:$\text{sign}(y_i – \hat{y}_i)$

特性

  • 鲁棒性:对异常值不敏感
  • 最优解:最小化MAE得到条件中位数
  • 分位数特性:当$\tau = 0.5$时是分位数损失的特例

优点

  • 对异常值具有鲁棒性
  • 损失单位与目标单位一致

缺点

  • 在零点不可导
  • 优化效率相对较低

Huber Loss

公式

$$L_\delta(y, \hat{y}) = \begin{cases} \frac{1}{2}(y – \hat{y})^2, & \text{当} |y – \hat{y}| \leq \delta \\ \delta|y – \hat{y}| – \frac{1}{2}\delta^2, & \text{当 } |y – \hat{y}| > \delta \end{cases}$$

其中$\delta$是超参数,通常通过交叉验证选择。

导数(梯度)

$$\frac{\partial L_\delta}{\partial \hat{y}} = \begin{cases} -(y – \hat{y}), & |y – \hat{y}| \leq \delta \\ -\delta \cdot \text{sign}(y – \hat{y}), & |y – \hat{y}| > \delta \end{cases}$$

特性

  • 平滑过渡:在小误差区域表现如MSE,在大误差区域表现如MAE
  • 可导性:处处一阶连续可导
  • 超参数:$\delta$控制从二次到线性的转换点

优点

  • 兼具MSE和MAE的优点
  • 对异常值鲁棒且优化高效

缺点

  • 需要调整超参数$\delta$
  • 计算复杂度略高于MSE和MAE

分位数损失 (Quantile Loss)

公式

$$L_\tau(y, \hat{y}) = \begin{cases} \tau \cdot |y – \hat{y}|, & y \geq \hat{y} \\ (1 – \tau) \cdot |y – \hat{y}|, & y < \hat{y} \end{cases}$$

或等价表示为:

$$L_\tau(y, \hat{y}) = (y – \hat{y}) \cdot (\tau – \mathbb{I}_{y < \hat{y}})$$

其中$\tau \in (0, 1)$是目标分位数,$\mathbb{I}$是指示函数。

特殊情形

  • 当$\tau = 0.5$时,退化为MAE
  • 当$\tau = 0.9$时,高估惩罚为1,低估惩罚为0.9

导数(梯度)

$$\frac{\partial L_\tau}{\partial \hat{y}} = \begin{cases} -\tau, & y \geq \hat{y} \\ 1 – \tau, & y < \hat{y} \end{cases}$$

特性

  • 不对称惩罚:对正负误差给予不同权重
  • 分位数预测:最小化得到条件分布的$\tau$-分位数
  • 预测区间:可用两个不同$\tau$的模型构建预测区间

应用场景

  • 金融风险价值(VaR)计算
  • 库存管理中的安全库存设定
  • 医疗预后中的风险区间估计

对比总结

特性 MSE MAE Huber Quantile
公式 $\frac{1}{n}\sum(y-\hat{y})^2$ $\frac{1}{n}\sum\|y-\hat{y}\|$ 分段函数 $\tau\|y-\hat{y}\|$ (if $y \geq \hat{y}$)
对异常值 敏感 鲁棒 鲁棒(通过δ控制) 鲁棒
可导性 处处可导 零点不可导 处处可导 零点不可导
最优解 条件均值 条件中位数 条件均值与中位数间 条件分位数
计算效率
超参数 δ τ
输出解释 点估计 点估计 点估计 区间估计

选择指南

  • 数据干净、重视精度 → MSE
    • 信号处理、物理实验等
  • 存在异常值、需要鲁棒性 → MAE 或 Huber
    • MAE:简单直接,不调参
    • Huber:平衡鲁棒性与效率,需调δ
  • 需要不确定性量化 → Quantile Loss
    • 金融风险评估、医疗预后、供应链管理
  • 实践建议:
    • 从MSE开始基准测试
    • 检查残差分布,如有重尾则考虑MAE/Huber
    • 根据业务需求选择是否使用分位数损失

这些损失函数在实际应用中常常组合使用或进行改进,例如在梯度提升树中,分位数损失被广泛用于构建预测区间。理解它们的数学特性和适用场景,有助于在实际问题中选择合适的损失函数。

分类任务损失函数

交叉熵损失(Cross-Entropy Loss)

交叉熵损失是分类任务中最核心的损失函数,衡量模型预测概率分布与真实分布之间的差异。

二分类交叉熵(Binary Cross-Entropy)

公式:

$$L_{\text{BCE}} = -\frac{1}{N} \sum_{i=1}^{N} [ y_i \log(\hat{y}_i) + (1 – y_i) \log(1 – \hat{y}_i)]$$

其中:

  • $y_i \in \{0,1\}$:真实标签
  • $\hat{y}_i \in [0,1]$:模型预测的正类概率
  • $N$:样本数量

导数:

$$\frac{\partial L_{\text{BCE}}}{\partial \hat{y}_i} = -\frac{y_i}{\hat{y}_i} + \frac{1-y_i}{1-\hat{y}_i}$$

激活函数:通常与Sigmoid 函数结合:$$\sigma(z) = \frac{1}{1 + e^{-z}}$$

多分类交叉熵(Categorical Cross-Entropy)

公式:

$$L_{\text{CCE}} = -\frac{1}{N} \sum_{i=1}^{N} \sum_{c=1}^{C} y_{i,c} \log(\hat{y}_{i,c})$$

其中:

  • $y_{i,c} \in \{0,1\}$:真实标签的one-hot编码
  • $\hat{y}_{i,c} \in [0,1]$:模型预测的类别c的概率
  • $C$:类别数量
  • 约束:$\sum_{c=1}^{C} \hat{y}_{i,c} = 1$

导数:当与Softmax结合时,梯度简化为:

$$\frac{\partial L_{\text{CCE}}}{\partial z_c} = \hat{y}_c – y_c$$

其中$z_c$是类别c的logit值。

激活函数:必须与Softmax 函数结合:

$$\text{Softmax}(z_c) = \frac{e^{z_c}}{\sum_{j=1}^{C} e^{z_j}}$$

特点与应用场景

  • 信息论基础:最小化交叉熵等价于最小化KL散度,衡量两个分布间的差异
  • 概率解释:在贝叶斯框架下,交叉熵损失对应最大似然估计
  • 梯度友好:与Softmax结合时梯度计算高效,避免梯度消失问题
  • 应用场景:
    • 神经网络分类任务(图像分类、文本分类等)
    • 语言模型(如GPT系列)
    • 推荐系统

合页损失(Hinge Loss)

合页损失主要用于支持向量机(SVM),追求最大间隔分类。

公式

二分类合页损失:

$$L_{\text{Hinge}} = \frac{1}{N} \sum_{i=1}^{N} \max(0, 1 – y_i f(x_i))$$

其中:

  • $y_i \in \{-1, +1\}$:真实标签
  • $f(x_i)$:模型决策函数值(未归一化)

多分类合页损失(Crammer-Singer形式):

$$L_{\text{MultiHinge}} = \frac{1}{N} \sum_{i=1}^{N} \sum_{j \neq y_i} \max(0, 1 – f_{y_i}(x_i) + f_j(x_i))$$

几何解释与特性

间隔概念:

  • 当$y_i f(x_i) \geq 1$时,损失为0(样本被正确分类且离决策边界足够远)
  • 当$0 < y_i f(x_i) < 1$时,产生线性惩罚(样本在间隔内)
  • 当$y_i f(x_i) \leq 0$时,样本被错误分类

优化特性:

  • 凸函数:保证全局最优解
  • 次梯度优化:在不可导点(间隔=1处)使用次梯度
  • 稀疏性:只有支持向量(间隔内的样本)影响模型

SVM对偶问题:

  • 原始SVM优化问题:$\min_{w,b} \frac{1}{2} \|w\|^2 + C \sum_{i=1}^{N} \xi_i$
  • 约束:$y_i(w^T x_i + b) \geq 1 – \xi_i, \quad \xi_i \geq 0$,其中$\xi_i = \max(0, 1 – y_i f(x_i))$

应用场景

  • 传统SVM分类器
  • 线性可分的分类问题
  • 需要明确决策边界解释的场景

指数损失(Exponential Loss)

指数损失是AdaBoost算法的核心,对错误分类的样本施加指数级惩罚。

公式

$$L_{\text{Exp}} = \frac{1}{N} \sum_{i=1}^{N} \exp(-y_i f(x_i))$$

其中$y_i \in \{-1, +1\}$

AdaBoost算法框架

AdaBoost通过加权组合弱分类器构建强分类器:

$$F(x) = \sum_{t=1}^{T} \alpha_t h_t(x)$$

每轮迭代中:

  • 更新样本权重:$w_i^{(t+1)} = w_i^{(t)} \exp(-\alpha_t y_i h_t(x_i))$
  • 选择使加权误差最小的弱分类器 $h_t$
  • 计算分类器权重:$\alpha_t = \frac{1}{2} \ln (\frac{1-\epsilon_t}{\epsilon_t})$

特性分析

  • 指数增长惩罚:误分类样本的损失呈指数增长,迫使模型聚焦于难例
  • 与交叉熵关系:可视为交叉熵损失的近似(通过泰勒展开)
  • 对异常值敏感:误分类的异常点会产生极大的损失值

理论保证:

当弱分类器仅比随机猜测略好时,AdaBoost的训练误差以指数速率下降:

$$\frac{1}{N} \sum_{i=1}^{N} \mathbb{I}(y_i \neq F(x_i)) \leq \exp (-2 \sum_{t=1}^{T} \gamma_t^2)$$

其中$\gamma_t = \frac{1}{2} – \epsilon_t$

应用场景

  • AdaBoost及其变体
  • 集成学习基学习器的组合
  • 特征选择

Focal Loss

Focal Loss是为解决类别不平衡问题设计的,特别针对目标检测中前景-背景极度不平衡的场景。

公式推导

  • 标准交叉熵损失:$CE(p, y) = \begin{cases} -\log(p), & y = 1 \\ -\log(1-p), & y = 0 \end{cases}$
  • 可统一写为:$CE(p_t) = -\log(p_t)$,其中$p_t = \begin{cases} p, & y = 1 \\ 1-p, & y = 0 \end{cases}$
  • Focal Loss定义:$FL(p_t) = -\alpha_t (1 – p_t)^\gamma \log(p_t)$,其中:
    • $\gamma \geq 0$:聚焦参数(focusing parameter)
    • $\alpha_t \in [0,1]$:类别平衡权重

核心机制

调制因子$(1-p_t)^\gamma$的作用:

  • 当样本易分类($p_t \to 1$)时,$(1-p_t)^\gamma \to 0$,损失权重降低
  • 当样本难分类($p_t \to 0$)时,$(1-p_t)^\gamma \to 1$,损失权重基本不变

参数影响分析:

  • γ=0:退化为标准交叉熵损失
  • γ增大:易分类样本的损失贡献被进一步抑制
  • α参数:用于调节类别不平衡,通常对稀有类别设置较大的α

在RetinaNet中的应用

RetinaNet采用Focal Loss解决单阶段检测器的类别不平衡问题:

RetinaNet架构特点:

  • 特征金字塔网络(FPN):提取多尺度特征
  • 两个任务特定子网络:
    • 分类子网络:预测每个锚框的类别概率
    • 回归子网络:预测边界框偏移量
  • Focal Loss应用:仅用于分类子网络,回归任务仍使用Smooth L1损失

性能提升:

  • 在COCO数据集上,RetinaNet首次使单阶段检测器达到两阶段检测器的精度水平
  • 有效缓解了简单负样本主导梯度的问题

数学性质

梯度分析:

$$\frac{\partial FL}{\partial z} = \begin{cases} -\alpha (1-p)^\gamma [\gamma p \log(p) + p – 1], & y=1 \\ (1-\alpha) p^\gamma [\gamma (1-p) \log(1-p) + p], & y=0 \end{cases}$$

其中$z$是logit值,$p = \sigma(z)$

与类别平衡交叉熵的关系:

Focal Loss可视为对类别平衡交叉熵的动态加权:

$$FL(p_t) = w_t \cdot CE(p_t)$$

其中权重$w_t = \alpha_t (1-p_t)^\gamma $随预测置信度动态调整

应用场景

  • 目标检测(特别是单阶段检测器如RetinaNet、YOLO变体)
  • 医学图像分析(病灶检测中的类别不平衡)
  • 任何存在极端类别不平衡的分类任务

损失函数对比总结

特性 交叉熵损失 合页损失 指数损失 Focal Loss
公式 $-\sum y \log(\hat{y})$ $\max(0, 1-yf(x))$ $\exp(-yf(x))$ $-\alpha (1-p)^\gamma \log(p)$
输出类型 概率分布 决策函数值 决策函数值 概率分布
优化目标 最小化分布差异 最大化分类间隔 指数级惩罚错分 关注难例,平衡类别
对异常值 敏感 相对鲁棒(线性惩罚) 非常敏感 通过调制因子缓解
梯度特性 平滑,易于优化 分段常数,使用次梯度 指数增长,不稳定 动态调整,关注难例
主要应用 深度学习分类 SVM AdaBoost 类别不平衡任务

选择指南

  • 标准分类任务 → 交叉熵损失
    • 深度学习中的默认选择
    • 与Softmax/Sigmoid自然结合
  • 需要最大间隔解释 → 合页损失
    • 传统SVM分类器
    • 线性可分或近似线性可分数据
  • 集成学习 → 指数损失
    • AdaBoost算法
    • 需要逐步聚焦难例的场景
  • 类别极度不平衡 → Focal Loss
    • 目标检测(前景-背景不平衡)
    • 医学图像分析(病灶检测)
    • 欺诈检测(正样本极少)
  • 实际建议:
    • 深度学习优先使用交叉熵损失
    • 遇到类别不平衡时,先尝试加权交叉熵,效果不佳再考虑Focal Loss
    • 传统机器学习中根据模型特性选择(SVM用合页损失,AdaBoost用指数损失)

这些损失函数代表了分类任务中不同的设计哲学和优化目标,理解它们的数学特性和适用场景对于构建有效的分类模型至关重要。

其他任务损失函数

对比损失(Contrastive Loss)

对比损失用于度量学习,目标是学习一个嵌入空间,使得相似样本对距离小,不相似样本对距离大。

公式与定义

核心公式:$$L = \frac{1}{2N} \sum_{i=1}^N [ y_i d_i^2 + (1-y_i) \max(0, m – d_i)^2]$$

其中:

  • $d_i = \| f(x_i^{(1)}) – f(x_i^{(2)}) \|_2$:样本对在嵌入空间中的欧氏距离
  • $y_i \in \{0,1\}$:样本对标签(1表示相似,0表示不相似)
  • $m > 0$:边界参数(margin),控制不相似样本对的最小距离
  • $N$:样本对数量
  • $f(\cdot)$:嵌入函数(编码器)

几何解释与优化目标

相似样本对($y=1$):

  • 损失项:$d^2$)
  • 目标:最小化距离d
  • 梯度:$\frac{\partial L}{\partial d} = d$

不相似样本对(\(y=0\)):

  • 损失项:$\max(0, m – d)^2$
  • 条件1:当$d < m$时,损失为$(m-d)^2$,梯度为$2(d-m)$
  • 条件2:当$d \geq m$时,损失为0
  • 目标:将距离推至至少为m

训练策略与技巧

样本对构建:

  • 正样本对:同一类别的样本、数据增强的同一图像
  • 负样本对:不同类别的样本

边界参数m的选择:

  • 太小:约束不足,嵌入空间区分度不够
  • 太大:优化困难,可能导致训练不稳定
  • 经验值:通常在5-2.0之间,需通过实验调整

梯度计算:

$$\frac{\partial L}{\partial f(x^{(1)})} = \begin{cases} \frac{f(x^{(1)}) – f(x^{(2)})}{\|f(x^{(1)}) – f(x^{(2)})\|} \cdot d, & y=1 \\ \frac{f(x^{(2)}) – f(x^{(1)})}{\|f(x^{(1)}) – f(x^{(2)})\|} \cdot \max(0, m-d), & y=0 \text{ 且 } d < m \\ 0, & y=0 \text{ 且 } d \geq m \end{cases}$$

应用场景

  • 人脸验证(如Siamese Networks)
  • 签名验证
  • 图像检索
  • 少样本学习

变体与扩展

  • 增强对比损失:加入温度参数控制相似度分布的尖锐程度
  • 多重对比损失:同时考虑多个负样本

三元组损失(Triplet Loss)

三元组损失是度量学习的经典方法,通过同时考虑锚点、正样本、负样本的三元组关系学习嵌入表示。

公式与定义

核心公式:

$$L = \sum_{i=1}^N \max(0, \|f(a_i) – f(p_i)\|_2^2 – \|f(a_i) – f(n_i)\|_2^2 + m)$$

其中:

  • $a_i$:锚点样本(Anchor)
  • $p_i$:正样本(Positive,与锚点同类)
  • $n_i$:负样本(Negative,与锚点不同类)
  • $m > 0$:边界参数
  • $f(\cdot)$:嵌入函数

优化目标与几何解释

不等式约束:$\|f(a) – f(p)\|_2^2 + m < \|f(a) – f(n)\|_2^2$

损失函数的三部分:

  • Anchor-Positive距离:$d_{ap} = \|f(a) – f(p)\|_2$
  • Anchor-Negative距离:$d_{an} = \|f(a) – f(n)\|_2$
  • 相对距离差:$d_{ap}^2 – d_{an}^2 + m$

三元组类型:

  • 简单三元组:$d_{ap} + m < d_{an}$,损失为0
  • 困难三元组:$d_{an} < d_{ap} + m$,产生正损失
  • 半困难三元组:$d_{ap} < d_{an} < d_{ap} + m$

采样策略(关键因素)

随机采样:

  • 随机选择锚点、正样本、负样本
  • 问题:大量三元组满足约束,梯度很小,收敛慢

困难三元组挖掘(Hard Triplet Mining):

  • 离线挖掘:每轮训练后,计算所有三元组,选择损失最大的
  • 在线挖掘:在批次内动态选择困难三元组
  • 半困难挖掘:选择$d_{ap} < d_{an} < d_{ap} + m$的三元组

批次内采样策略:

  • Batch Hard:选择批次内最困难的正负样本
  • Batch All:考虑批次内所有有效三元组
  • Batch Semi-Hard:选择满足$d_{ap} < d_{an} < d_{ap} + m$的三元组

梯度计算

$$\frac{\partial L}{\partial f(a)} = 2[(f(n) – f(p)) \cdot \mathbb{I}(d_{ap}^2 – d_{an}^2 + m > 0)]$$

$$\frac{\partial L}{\partial f(p)} = 2(f(p) – f(a)) \cdot \mathbb{I}(d_{ap}^2 – d_{an}^2 + m > 0)$$

$$\frac{\partial L}{\partial f(n)} = 2(f(a) – f(n)) \cdot \mathbb{I}(d_{ap}^2 – d_{an}^2 + m > 0)$$

其中$\mathbb{I}(\cdot)$是指示函数。

参数设置与优化技巧

边界参数 \(m\) 的选择:

  • 常用值:2-1.0
  • 太大:可能导致训练不稳定
  • 太小:嵌入空间区分度不足

嵌入空间归一化:

  • 通常对嵌入向量进行L2归一化:$f(x) \leftarrow \frac{f(x)}{\|f(x)\|_2}$
  • 好处:距离限制在固定范围,简化超参数选择

应用场景:

  • 人脸识别(如FaceNet)
  • 行人重识别
  • 商品图像检索
  • 语音识别中的说话人验证

对比:三元组损失 vs 对比损失

特性 三元组损失 对比损失
样本结构 三元组(锚点、正、负) 样本对(正对、负对)
约束类型 相对距离约束 绝对距离约束
梯度信息 同时比较正负样本 分别处理正负对
计算复杂度 较高(需构造三元组) 较低

Dice Loss

Dice Loss源自Dice系数,是一种用于评估分割结果重叠度的指标,特别适用于医学图像分割等类别不平衡场景。

Dice系数定义

二分类Dice系数:

$$\text{Dice} = \frac{2|X \cap Y|}{|X| + |Y|} = \frac{2TP}{2TP + FP + FN}$$

其中:

  • X:预测分割区域
  • Y:真实分割区域
  • TP:真阳性,FP:假阳性,FN:假阴性

连续概率形式:

设预测概率图P,真实二值标签G:

$$\text{Dice} = \frac{2\sum_{i=1}^N p_i g_i}{\sum_{i=1}^N p_i + \sum_{i=1}^N g_i}$$

其中$p_i \in [0,1]$是第i个像素的预测概率,$g_i \in \{0,1\}$是真实标签。

Dice Loss公式

基本形式:

$$L_{\text{Dice}} = 1 – \frac{2\sum_{i=1}^N p_i g_i + \epsilon}{\sum_{i=1}^N p_i + \sum_{i=1}^N g_i + \epsilon}$$

其中$\epsilon$是平滑项(通常取$10^{-6}$),防止分母为零。

多分类扩展:

  • 宏平均Dice Loss:$L_{\text{macro}} = 1 – \frac{1}{C} \sum_{c=1}^C \frac{2\sum_{i=1}^N p_{i,c} g_{i,c} + \epsilon}{\sum_{i=1}^N p_{i,c} + \sum_{i=1}^N g_{i,c} + \epsilon}$
  • 加权Dice Loss:$L_{\text{weighted}} = 1 – \sum_{c=1}^C w_c \frac{2\sum_{i=1}^N p_{i,c} g_{i,c} + \epsilon}{\sum_{i=1}^N p_{i,c} + \sum_{i=1}^N g_{i,c} + \epsilon}$

其中$w_c$是类别权重,可基于类别频率设置。

数学特性与梯度分析

梯度计算:

$$\frac{\partial L_{\text{Dice}}}{\partial p_j} = -\frac{2g_j(S_p + S_g) – 2S_{pg}}{(S_p + S_g)^2}$$

其中:

  • $S_p = \sum_i p_i + \epsilon$
  • $S_g = \sum_i g_i + \epsilon$
  • $S_{pg} = \sum_i p_i g_i + \epsilon$

梯度特性:

  • 分母中的平方项使梯度在预测接近0或1时较小
  • 对假阴性(FN)和假阳性(FP)同等惩罚
  • 梯度幅度与当前Dice系数相关

与交叉熵损失的比较

特性 Dice Loss 交叉熵损失
优化目标 最大化重叠区域 最小化概率分布差异
类别不平衡 鲁棒性好 需要加权或Focal Loss
梯度特性 基于区域统计 逐像素梯度
对FP/FN 同等惩罚 依赖预测概率
训练稳定性 可能不稳定 通常更稳定

改进变体

  • Soft Dice Loss:使用平滑的预测概率,避免硬阈值:$L_{\text{SoftDice}} = 1 – \frac{2\sum_i p_i g_i + \epsilon}{\sum_i (p_i^2 + g_i^2) + \epsilon}$
  • Generalized Dice Loss:为每个类别分配权重,处理类别不平衡:$L_{\text{GDL}} = 1 – 2\frac{\sum_{c=1}^C w_c \sum_i p_{i,c} g_{i,c}}{\sum_{c=1}^C w_c \sum_i (p_{i,c} + g_{i,c})}$,其中$w_c = 1/(\sum_i g_{i,c})^2$
  • Dice + BCE组合损失:结合Dice Loss和二元交叉熵,兼顾全局和局部信息:$L = \alpha L_{\text{Dice}} + (1-\alpha) L_{\text{BCE}}$,通常$\alpha = 0.5$。

应用场景

  • 医学图像分割(肿瘤、器官分割)
  • 遥感图像分割
  • 任何存在极端类别不平衡的分割任务

KL散度(Kullback-Leibler Divergence)

KL散度用于衡量两个概率分布的差异,在生成模型、变分推断、模型蒸馏中有广泛应用。

定义与公式

离散分布

设P和Q是离散概率分布:$D_{KL}(P \| Q) = \sum_{x \in \mathcal{X}} P(x) \log \frac{P(x)}{Q(x)}$

连续分布

设p(x)和q(x)是概率密度函数:$D_{KL}(P \| Q) = \int_{-\infty}^{\infty} p(x) \log \frac{p(x)}{q(x)} dx$

信息论解释

与交叉熵、熵的关系:$D_{KL}(P \| Q) = H(P, Q) – H(P)$

其中:

  • $H(P) = -\sum P(x) \log P(x)$:分布 P的熵
  • $H(P, Q) = -\sum P(x) \log Q(x)$:P和 Q的交叉熵

性质:

  • 非负性:$D_{KL}(P \| Q) \geq 0$,等号成立当且仅当P = Q几乎处处成立
  • 不对称性:$D_{KL}(P \| Q) \neq D_{KL}(Q \| P)$,不是真正的距离度量
  • 不满足三角不等式

在变分自编码器(VAE)中的应用

VAE的损失函数包含重构损失和KL散度正则项:

VAE损失函数:$L_{\text{VAE}} = \mathbb{E}_{q_\phi(z|x)}[\log p_\theta(x|z)] – D_{KL}(q_\phi(z|x) \| p(z))$

各项含义:

  • 重构项:$\mathbb{E}_{q_\phi(z|x)}[\log p_\theta(x|z)]$,鼓励解码器重构输入
  • KL散度项:$D_{KL}(q_\phi(z|x) \| p(z))$,约束后验分布$q_\phi(z|x)$接近先验p(z)(通常为标准正态分布)

具体计算(高斯假设下):

设:

  • 先验:$p(z) = \mathcal{N}(0, I)$
  • 后验:$q_\phi(z|x) = \mathcal{N}(\mu_\phi(x), \sigma_\phi^2(x) I)$

则KL散度有闭合形式:

$$D_{KL} = \frac{1}{2} \sum_{j=1}^J ( \sigma_j^2 + \mu_j^2 – 1 – \log \sigma_j^2)$$

其中J是隐变量维度。

在模型蒸馏中的应用

知识蒸馏框架:

  • 教师模型(复杂、高精度):分布$P^T$
  • 学生模型(简单、高效):分布$P^S$
  • 蒸馏损失:$L_{\text{distill}} = D_{KL}(P^T \| P^S)$

温度缩放:

使用带温度参数的Softmax:

$$p_i^T = \frac{\exp(z_i/T)}{\sum_j \exp(z_j/T)}$$

其中T > 0是温度参数,提高T会软化概率分布,传递更多暗知识。

蒸馏损失函数:

$$L = \alpha \cdot T^2 \cdot D_{KL}(P^T \| P^S) + (1-\alpha) \cdot L_{\text{task}}$$

其中$\alpha$是蒸馏权重。

与其他散度的比较

散度度量 公式 对称性 三角不等式 主要应用
KL散度 $D_{KL}(P\|Q)$ 不对称 不满足 VAE、模型蒸馏
JS散度 $\frac{1}{2}D_{KL}(P\|M)+\frac{1}{2}D_{KL}(Q\|M)$ 对称 不满足 GAN的早期变体
Wasserstein距离 $\inf_{\gamma \in \Pi(P,Q)} \mathbb{E}_{(x,y)\sim\gamma}[\|x-y\|]$ 对称 满足 WGAN
总变分距离 $\sup_A |P(A)-Q(A)|$ P(A)-Q(A) ) 对称

梯度分析

KL散度的梯度:

$$\frac{\partial D_{KL}(P \| Q)}{\partial \theta} = -\mathbb{E}_{x \sim P} [ \frac{\partial \log q_\theta(x)}{\partial \theta}]$$

其中$q_\theta$是以$\theta$为参数的分布。

在VAE中的梯度:

  • 重构项梯度:通过重参数化技巧计算
  • KL散度项梯度:通常有解析解,可直接求导

应用场景

  • 变分推断(VAE、贝叶斯神经网络)
  • 模型蒸馏(知识迁移)
  • 强化学习(策略优化中的信任区域方法)
  • 信息论(率失真理论、最小描述长度)
  • 自然语言处理(语言模型评估)

总结对比

损失函数 主要用途 关键公式 核心特点
对比损失 度量学习 $L = \frac{1}{2N}\sum [y d^2 + (1-y)\max(0,m-d)^2]$ 拉近正对,推开负对
三元组损失 度量学习 $L = \sum \max(0, d_{ap}^2 – d_{an}^2 + m)$ 相对距离比较,困难样本挖掘
Dice Loss 图像分割 $L = 1 – \frac{2\sum p_i g_i}{\sum p_i + \sum g_i}$ 直接优化重叠区域,处理类别不平衡
KL散度 分布匹配 $D_{KL}(P\|Q) = \sum P\log\frac{P}{Q}$ 衡量分布差异,非对称,VAE核心

选择指南

  • 度量学习任务:
    • 简单相似性学习 → 对比损失
    • 高质量嵌入学习 → 三元组损失(配合困难样本挖掘)
  • 图像分割任务:
    • 类别平衡 → 交叉熵损失
    • 类别不平衡 → Dice Loss 或 Dice+BCE组合
    • 小目标分割 → Focal Loss 或 Generalized Dice Loss
  • 生成模型:
    • 变分自编码器 → KL散度(作为正则项)
    • 知识蒸馏 → KL散度(带温度参数)
  • 实际建议:
    • 从任务标准损失开始,针对问题特性选择改进
    • 多任务可结合不同损失(如分割中的Dice+BCE)
    • 注意数值稳定性,添加平滑项防止除零

这些损失函数代表了特定任务中的核心优化目标,理解它们的数学特性和适用场景对于解决复杂机器学习问题至关重要。

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注