交叉熵的一些理解、疑问
概念
交叉熵损失函数是一种常用于分类问题的损失函数,其公式如下:
CrossEntropy(yi,yi^)=−k=1∑Kyi,klog(yi,k^)
其中,yi代表样本i的真实标签,yi^代表样本i属于每个类别的概率分布,K是类别数。交叉熵函数可以通过比较真实标签和模型预测标签的差异来衡量模型的错误程度。当模型的输出与真实标签完全匹配时,交叉熵为0;当模型的输出与真实标签不匹配时,交叉熵会增加。在机器学习中,我们通常通过最小化交叉熵损失函数来训练分类模型。
原理
交叉熵损失函数的含义可以从信息论的角度来解释。在信息论中,熵是衡量信息量的一个指标,表示随机变量不确定性的度量。对于一个离散型随机变量X,其熵的计算公式为:
H(X)=−i=1∑np(xi)logp(xi)
其中,p(xi)表示X取值xi的概率。
当我们将交叉熵应用于分类问题时,我们可以将真实标签看作一个分布p,将模型预测标签看作一个分布q。假设有K个类,那么真实标签和模型预测标签都是K维的概率向量,每一维表示该样本属于该类别的概率。交叉熵可以用来衡量真实标签p和模型预测标签q之间的距离,其具体计算公式为:
CrossEntropy(p,q)=−i=1∑Kpilogqi
交叉熵越小,表示模型的输出和真实标签越接近,即模型的预测越准确。在训练神经网络时,我们希望最小化交叉熵损失函数来调整网络参数,使得模型能够更好地预测样本的类别。
pytorch中实现
torch.nn.CrossEntropyLoss相当于softmax + log + nllloss
使用nll loss时,可以这样操作
1 2 3 4 5 6 7
| nllloss = nn.NLLLoss() predict = torch.Tensor([[2, 3, 1], [3, 7, 9]]) predict = torch.log(torch.softmax(predict, dim=-1)) label = torch.tensor([1, 2]) nllloss(predict, label) # output: tensor(0.2684)
|
而使用torch.nn.CrossEntropyLoss可以省去softmax + log
1 2 3 4 5 6 7
| cross_loss = nn.CrossEntropyLoss()
predict = torch.Tensor([[2, 3, 1], [3, 7, 9]]) label = torch.tensor([1, 2]) cross_loss(predict, label) # output: tensor(0.2684)
|
疑问
1. 为什么样本分布和真实分布越接近,交叉熵就会越小,且二者完全一致时,交叉熵最小,且等于真实分布的熵
使用分解KL散度证明: 连续概率分布
在机器学习中,交叉熵(Cross Entropy)是一种衡量两个概率分布之间差异的指标。当样本分布和真实分布越接近时,它们的交叉熵就越小。
首先,我们定义样本分布为p(x),真实分布为q(x),其中x表示概率分布的变量。交叉熵可以用以下公式表示:
H(p, q) = -Σp(x) * log(q(x))
现在我们来推导为什么样本分布和真实分布越接近,交叉熵越小。
根据交叉熵的定义,我们可以将其展开为:
H(p,q)=−x∑p(x)log(q(x))=−x∑p(x)[log(q(x))−log(p(x))+log(p(x))]=−x∑p(x)log(q(x))+x∑p(x)log(p(x))−x∑p(x)log(p(x))=−x∑p(x)[ log(q(x))−log(p(x)) ]+H(p)=x∑p(x)[ log(p(x))−log(q(x)) ]+H(p)=DKL(p∣∣q)+H(p)
其中,H(p)表示真实分布p(x)的熵,DKL(p∣∣q)表示真实分布p(x)相对于样本分布q(x)的KL散度。
熵 (信息论) - 维基百科,自由的百科全书 (wikipedia.org)
(KL散度)相对熵 - 维基百科,自由的百科全书 (wikipedia.org)
KL散度的公式如下:
DKL(p∣∣q)=x∑p(x)log(q(x)p(x))
其中,p(x)和q(x)分别表示真实分布和样本分布在样本点x处的概率。DKL(p∣∣q)表示真实分布p(x)相对于样本分布q(x)的KL散度,用来衡量两个分布之间的差异。
KL散度是一个非对称的度量,即DKL(p∣∣q)=DKL(q∣∣p)。这意味着交换真实分布和样本分布的位置会得到不同的结果。因此,DKL(p∣∣q)表示当使用样本分布q(x)来近似真实分布p(x)时,产生的信息损失量。
需要注意的是,KL散度始终大于等于0,当且仅当p(x)=q(x)时,KL散度才为0。当样本分布和真实分布越接近时,KL散度逐渐减小,达到最小值0表示两个分布完全一致。
可以看出,当p(x)>0且q(x)=0时,log(q(x)p(x))会趋近于负无穷,从而DKL(p∣∣q)也会趋近于正无穷。这意味着,当样本分布q(x)在真实分布p(x)的非零概率处取值为0时,真实分布p(x)相对于样本分布q(x)的KL散度就会变成无穷大。
此外,KL散度还有一个重要性质,即KL散度具有非负性,即DKL(p∣∣q)≥0。
这是由于log(x)是一个凸函数,因此根据Jensen不等式,
x∑p(x)log(q(x)p(x))⩾log(x∑p(x)q(x)p(x))=log(1)=0
由于KL散度始终大于等于0,当真实分布p(x)和样本分布q(x)完全一致时,KL散度为0,因此交叉熵达到最小值。
而当样本分布q(x)与真实分布p(x)不同步时,KL散度大于0,导致交叉熵也大于0。当样本分布和真实分布越接近时,KL散度逐渐减小,从而交叉熵也逐渐减小。
举个例子来说明。假设我们有一个二分类问题,样本分布为[0.7,0.3],真实分布为[0.8,0.2],表示两个类别的概率分布。计算它们之间的交叉熵如下:
H(p,q)=−[0.7log(0.8)+0.3log(0.2)]≈0.3567
可以看到,交叉熵是一个正值,表示样本分布和真实分布之间的差异程度。如果样本分布和真实分布完全一致,交叉熵将为0,表示两个分布完全相同。
因此,当样本分布和真实分布越接近时,交叉熵越小,当两者完全一致时,交叉熵最小且等于真实分布的熵。
Reference
- AIchatOS (yqcloud.top)
- 琴生不等式 - 维基百科,自由的百科全书 (wikipedia.org)
- 熵 (信息论) - 维基百科,自由的百科全书 (wikipedia.org)
- (KL散度)相对熵 - 维基百科,自由的百科全书 (wikipedia.org)
- 详解torch.nn.NLLLOSS - 知乎 (zhihu.com)