The difference between nn.NLLLoss(), nn.CrossEntropyLoss(), and nn.KLDivLoss()

import torch
import torch.nn as nn
import torch.nn.functional as F
torch.manual_seed(1)
<torch._C.Generator at 0x1ea98f1deb0>
input = torch.randn(3, 4)  # simulated network outputs: 3 samples, 4 classes
input
tensor([[ 0.6614,  0.2669,  0.0617,  0.6213],
        [-0.4519, -0.1661, -1.5228,  0.3817],
        [-1.0276, -0.5631, -0.8923, -0.0583]])

$$\text{Softmax}(x_{i}) = \frac{\exp(x_i)}{\sum_j \exp(x_j)}$$

softmax = nn.Softmax(dim=1)  # dim=1: the probabilities along each row sum to 1
softmax(input)
tensor([[0.5820, 0.1406, 0.1137, 0.1637],
        [0.6070, 0.2923, 0.0541, 0.0466],
        [0.0815, 0.3453, 0.5165, 0.0567]])
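A quick sanity check, not part of the original notebook, just a sketch: with dim=1 each row of probabilities should sum to 1.

softmax(input).sum(dim=1)  # expected: a row of ones, e.g. tensor([1., 1., 1.])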

$$\log(\text{Softmax}(x))$$
$$\text{LogSoftmax}(x_{i}) = \log\left(\frac{\exp(x_i) }{ \sum_j \exp(x_j)} \right)$$

torch.log(softmax(input))
tensor([[-0.5413, -1.9622, -2.1739, -1.8095],
        [-0.4993, -1.2300, -2.9168, -3.0653],
        [-2.5078, -1.0633, -0.6606, -2.8699]])
F.log_softmax(input, dim=1)  # the same computation, done directly with the functional API
tensor([[-0.5413, -1.9622, -2.1739, -1.8095],
        [-0.4993, -1.2300, -2.9168, -3.0653],
        [-2.5078, -1.0633, -0.6606, -2.8699]])
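The two routes should agree up to floating-point error; a minimal check (my addition, not in the original post):

torch.allclose(torch.log(softmax(input)), F.log_softmax(input, dim=1))  # True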
target = torch.tensor([0, 2, 3])  # target class index for each of the three samples

$$\text{loss}(x, class) = -\log\left(\frac{\exp(x[class])}{\sum_j \exp(x[j])}\right)
= -x[class] + \log\left(\sum_j \exp(x[j])\right)$$
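As an illustration of the right-hand form, here is a sketch (not from the original post) that evaluates it per sample with gather and logsumexp on the input and target defined above:

# -x[class] + log(sum_j exp(x[j])), evaluated per sample
-input.gather(1, target.unsqueeze(1)).squeeze(1) + input.logsumexp(dim=1)
# ≈ tensor([0.5413, 2.9168, 2.8699]) -- the values averaged in the next cell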

(0.5413 + 2.9168 + 2.8699) / 3  # the targets are classes 0, 2, 3, so average the corresponding negative log-probabilities
2.1093333333333333
loss = nn.NLLLoss()  # instantiate the NLLLoss criterion
loss(torch.log(softmax(input)), target)  # matches the manual average above
tensor(2.1093)
loss = nn.CrossEntropyLoss()  # takes raw logits and class indices directly
loss(input, target)
tensor(2.1093)
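The same equivalence through the functional API (a sketch added here, not from the original post); both calls should return the same value:

F.cross_entropy(input, target)                    # tensor(2.1093)
F.nll_loss(F.log_softmax(input, dim=1), target)   # tensor(2.1093)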

$$\text{LogSoftmax}(x_{i}) = \log\left(\frac{\exp(x_i) }{ \sum_j \exp(x_j)} \right)$$
$$l(x,y) = L := \{ l_1,\dots,l_N \}, \quad
l_n = y_n \cdot \left( \log y_n - x_n \right)$$

loss = nn.KLDivLoss(reduction='batchmean')  # 'batchmean' sums the element-wise KL terms and divides by the batch size
target = torch.tensor([[1, 0, 0, 0],
                       [0, 0, 1, 0],
                       [0, 0, 0, 1]], dtype=torch.float)  # one-hot version of the index targets [0, 2, 3]
loss(F.log_softmax(input, dim=1), target)  # with a one-hot target, the KL divergence matches the two results above
tensor(2.1093)
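To make the equivalence explicit (my sketch, not in the original post): for a one-hot target the y·log y term is 0 everywhere, so only the −log p[class] terms survive. torch.xlogy, which defines 0·log 0 = 0, lets us reproduce the 'batchmean' reduction by hand; the one-hot matrix above could equally be built with F.one_hot(torch.tensor([0, 2, 3]), num_classes=4).float().

log_probs = F.log_softmax(input, dim=1)
(torch.xlogy(target, target) - target * log_probs).sum() / target.size(0)  # ≈ tensor(2.1093)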

In summary, nn.CrossEntropyLoss() fuses Softmax → Log → NLLLoss into a single step. nn.NLLLoss() takes log-probabilities and returns the averaged negative log-likelihood at the target classes, where target holds class indices. nn.KLDivLoss() follows Softmax → Log → convert the index targets to one-hot → KL divergence; because y·log y = 0 for a one-hot target, only the −log p[class] terms remain, so its result equals that of nn.CrossEntropyLoss().
