PyTorch

import torch
import torch.nn as nn
import torch.nn.functional as F

class FocalLoss(nn.Module):
    
    def __init__(self, weight=None, 
                 gamma=2., reduction='none'):
        nn.Module.__init__(self)
        self.weight = weight
        self.gamma = gamma
        self.reduction = reduction
        
    def forward(self, input_tensor, target_tensor):
        log_prob = F.log_softmax(input_tensor, dim=-1)
        prob = torch.exp(log_prob)
        return F.nll_loss(
            ((1 - prob) ** self.gamma) * log_prob, 
            target_tensor, 
            weight=self.weight,
            reduction = self.reduction
        )
import torch
cross_entropy_loss = torch.nn.CrossEntropyLoss()

# Input: f_q (BxCxS) and sampled features from H(G_enc(x))
# Input: f_k (BxCxS) are sampled features from H(G_enc(G(x))
# Input: tau is the temperature used in PatchNCE loss.
# Output: PatchNCE loss
def PatchNCELoss(f_q, f_k, tau=0.07):
    # batch size, channel size, and number of sample locations
    B, C, S = f_q.shape

    # calculate v * v+: BxSx1
    l_pos = (f_k * f_q).sum(dim=1)[:, :, None]

    # calculate v * v-: BxSxS
    l_neg = torch.bmm(f_q.transpose(1, 2), f_k)

    # The diagonal entries are not negatives. Remove them.
    identity_matrix = torch.eye(S)[None, :, :]
    l_neg.masked_fill_(identity_matrix, -float('inf'))

    # calculate logits: (B)x(S)x(S+1)
    logits = torch.cat((l_pos, l_neg), dim=2) / tau

    # return PatchNCE loss
    predictions = logits.flatten(0, 1)
    targets = torch.zeros(B * S, dtype=torch.long)
    return cross_entropy_loss(predictions, targets)

Similiar Collections

Python strftime reference pandas.Period.strftime python - Formatting Quarter time in pandas columns - Stack Overflow python - Pandas: Change day - Stack Overflow python - Check if multiple columns exist in a df - Stack Overflow Pandas DataFrame apply() - sending arguments examples python - How to filter a dataframe of dates by a particular month/day? - Stack Overflow python - replace a value in the entire pandas data frame - Stack Overflow python - Replacing blank values (white space) with NaN in pandas - Stack Overflow python - get list from pandas dataframe column - Stack Overflow python - How to drop rows of Pandas DataFrame whose value in a certain column is NaN - Stack Overflow python - How to drop rows of Pandas DataFrame whose value in a certain column is NaN - Stack Overflow python - How to lowercase a pandas dataframe string column if it has missing values? - Stack Overflow How to Convert Integers to Strings in Pandas DataFrame - Data to Fish How to Convert Integers to Strings in Pandas DataFrame - Data to Fish create a dictionary of two pandas Dataframe columns? - Stack Overflow python - ValueError: No axis named node2 for object type <class 'pandas.core.frame.DataFrame'> - Stack Overflow Python Pandas iterate over rows and access column names - Stack Overflow python - Creating dataframe from a dictionary where entries have different lengths - Stack Overflow python - Deleting DataFrame row in Pandas based on column value - Stack Overflow python - How to check if a column exists in Pandas - Stack Overflow python - Import pandas dataframe column as string not int - Stack Overflow python - What is the most efficient way to create a dictionary of two pandas Dataframe columns? - Stack Overflow Python Loop through Excel sheets, place into one df - Stack Overflow python - How do I get the row count of a Pandas DataFrame? - Stack Overflow python - How to save a new sheet in an existing excel file, using Pandas? - Stack Overflow Python Loop through Excel sheets, place into one df - Stack Overflow How do I select a subset of a DataFrame? — pandas 1.2.4 documentation python - Delete column from pandas DataFrame - Stack Overflow python - Convert list of dictionaries to a pandas DataFrame - Stack Overflow How to Add or Insert Row to Pandas DataFrame? - Python Examples python - Check if a value exists in pandas dataframe index - Stack Overflow python - Set value for particular cell in pandas DataFrame using index - Stack Overflow python - Pandas Dataframe How to cut off float decimal points without rounding? - Stack Overflow python - Pandas: Change day - Stack Overflow python - Clean way to convert quarterly periods to datetime in pandas - Stack Overflow Pandas - Number of Months Between Two Dates - Stack Overflow python - MonthEnd object result in <11 * MonthEnds> instead of number - Stack Overflow python - Extracting the first day of month of a datetime type column in pandas - Stack Overflow