小能豆

NLLLoss 只是一个普通的负函数?

py

我不太理解nn.NLLLoss()

由于下面的代码总是打印,那么和使用负号(-)True之间有什么区别?nn.NLLLoss()

import torch
while 1:
   b = torch.randn(1)
   print(torch.nn.NLLLoss()(b, torch.tensor([0])) == -b[0])

阅读 18

收藏
2024-12-30

共1个答案

小能豆

在您的情况下,每个批次元素只有一个输出值,目标是0nn.NLLLoss损失将选择与目标张量中包含的索引相对应的预测张量的值。这是一个更一般的例子,其中您总共有五个批次元素,每个元素有三个 logit 值:

>>> logits = torch.randn(5, 3, requires_grad=True)
>>> y = torch.tensor([1, 0, 2, 0, 1])
>>> y_hat = torch.softmax(b, -1)

张量yy_hat分别对应于目标张量和估计分布。你可以nn.NLLLoss使用以下命令实现:

>>> -y_hat[torch.arange(len(y_hat)), y]
tensor([-0.2195, -0.1015, -0.3699, -0.5203, -0.1171], grad_fn=<NegBackward>)

与内置函数相比:

>>> F.nll_loss(y_hat, y, reduction='none')
tensor([-0.2195, -0.1015, -0.3699, -0.5203, -0.1171], grad_fn=<NllLossBackward>)

这与独自一人有很大不同-y_hat

2024-12-30