class LabelSmoothingLoss(torch.nn.Module):
    """Cross-entropy loss with uniform label smoothing.

    The returned loss is ``smoothing * L_uniform + (1 - smoothing) * L_nll``:

    * ``L_nll`` is ``F.nll_loss`` of the log-softmaxed predictions. It is
      optionally class-weighted via ``weight``.
    * ``L_uniform`` is the negative log-probability summed over all classes
      and divided by the number of classes. It is reduced with
      ``reduction`` but is NOT class-weighted.

    Args:
        smoothing: Fraction of the target mass spread uniformly over all
            classes. Must satisfy ``0 <= smoothing < 1``; this is checked
            on every forward call.
        reduction: ``"mean"``, ``"sum"`` or ``"none"``. It is forwarded to
            ``F.nll_loss`` and applied to the smoothing term as well.
        weight: Optional per-class weight tensor forwarded to ``F.nll_loss``.
    """

    def __init__(self, smoothing: float = 0.1, reduction="mean", weight=None):
        super().__init__()
        self.smoothing = smoothing
        self.reduction = reduction
        self.weight = weight

    def reduce_loss(self, loss):
        """Reduce ``loss`` according to ``self.reduction``.

        ``"mean"`` averages all elements and ``"sum"`` sums them. Any other
        value returns ``loss`` unchanged; invalid strings are rejected
        afterwards by ``F.nll_loss`` in ``forward``.
        """
        if self.reduction == "mean":
            return loss.mean()
        if self.reduction == "sum":
            return loss.sum()
        return loss

    def linear_combination(self, x, y):
        """Return ``smoothing * x + (1 - smoothing) * y``."""
        return self.smoothing * x + (1 - self.smoothing) * y

    def forward(self, preds, target):
        """Compute the label-smoothed loss.

        Args:
            preds: Unnormalized logits in any layout ``F.nll_loss`` accepts:
                ``(C,)``, ``(N, C)`` or ``(N, C, d1, ...)``. The class
                dimension is dim 1, or dim 0 for 1-D input.
            target: Class indices laid out as ``F.nll_loss`` expects for
                ``preds``.

        Returns:
            The loss, reduced according to ``self.reduction``.

        Raises:
            ValueError: If ``self.smoothing`` is not in ``[0, 1)``.
        """
        if not 0 <= self.smoothing < 1:
            raise ValueError(
                f"smoothing must be in [0, 1), got {self.smoothing!r}"
            )

        # Move the weight to the input's device without mutating module state.
        weight = self.weight
        if weight is not None:
            weight = weight.to(preds.device)

        # F.nll_loss uses dim 1 as the class dimension (dim 0 for 1-D input).
        # Normalize and sum over the same dim so both terms agree for
        # K-dimensional inputs; for (N, C) this is identical to dim=-1.
        class_dim = 1 if preds.dim() > 1 else 0
        n_classes = preds.size(class_dim)

        log_preds = F.log_softmax(preds, dim=class_dim)
        smooth_loss = self.reduce_loss(-log_preds.sum(dim=class_dim))
        nll = F.nll_loss(
            log_preds, target, reduction=self.reduction, weight=weight
        )
        return self.linear_combination(smooth_loss / n_classes, nll)
Preview:
Download PNG
Download JPEG
Download SVG
Tip: You can change the style, width & colours of the snippet with the inspect tool before clicking Download!
Click to optimize width for Twitter