from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F


class FocalLoss(nn.Module):
    """Multi-class focal loss built on top of ``F.nll_loss``.

    The log-probability of every class is scaled by the modulating factor
    ``(1 - p) ** gamma`` before the negative log-likelihood of the target
    class is taken, so well-classified examples (``p`` close to 1) contribute
    less to the loss. With ``gamma == 0`` the result equals the ordinary
    cross-entropy loss.

    Args:
        weight: Optional per-class rescaling weights, forwarded to
            ``F.nll_loss``. Must be a tensor or ``None``.
        gamma: Exponent of the modulating factor.
        reduction: Reduction mode forwarded to ``F.nll_loss``
            (``'none'``, ``'mean'`` or ``'sum'``).
    """

    def __init__(
        self,
        weight: Optional[torch.Tensor] = None,
        gamma: float = 2.,
        reduction: str = 'none',
    ) -> None:
        super().__init__()
        # Registered as a buffer so the weights follow the module through
        # .to() / .cuda() / .double(). Non-persistent so the module's
        # state_dict keys stay the same as when this was a plain attribute.
        self.register_buffer('weight', weight, persistent=False)
        self.gamma = gamma
        self.reduction = reduction

    def forward(
        self, input_tensor: torch.Tensor, target_tensor: torch.Tensor
    ) -> torch.Tensor:
        """Compute the focal loss of raw scores against class indices.

        Args:
            input_tensor: Unnormalized class scores laid out the way
                ``F.nll_loss`` expects them: classes on dim 1 for batched
                input, or on the only dim for a 1-D input.
            target_tensor: Target class indices, as accepted by
                ``F.nll_loss``.

        Returns:
            The loss reduced according to ``self.reduction``.
        """
        # F.nll_loss reads classes from dim 1 of batched input, so the
        # softmax must normalize over that same axis; dim=-1 only coincides
        # with it for 1-D and 2-D inputs.
        class_dim = 1 if input_tensor.dim() > 1 else 0
        log_prob = F.log_softmax(input_tensor, dim=class_dim)
        prob = torch.exp(log_prob)
        modulated_log_prob = ((1 - prob) ** self.gamma) * log_prob
        return F.nll_loss(
            modulated_log_prob,
            target_tensor,
            weight=self.weight,
            reduction=self.reduction,
        )
Preview:
Download PNG
Download JPEG
Download SVG
Tip: You can change the style, width & colours of the snippet with the inspect tool before clicking Download!
Click to optimize width for Twitter