import torch

cross_entropy_loss = torch.nn.CrossEntropyLoss()


# Input: f_q (BxCxS) are sampled features from H(G_enc(x))
# Input: f_k (BxCxS) are sampled features from H(G_enc(G(x)))
# Input: tau is the temperature used in PatchNCE loss.
# Output: PatchNCE loss
def PatchNCELoss(f_q, f_k, tau=0.07):
    """Compute the PatchNCE contrastive loss between two sets of patch features.

    For every sample location ``i``, the pair ``f_q[b, :, i]`` and
    ``f_k[b, :, i]`` is the positive pair. The vectors ``f_k[b, :, j]`` at all
    other locations ``j != i`` of the same batch element are the negatives.
    Similarities are raw dot products. This function does not normalize the
    features (presumably callers L2-normalize them first; confirm at call sites).

    Args:
        f_q: Tensor of shape (B, C, S): sampled features from H(G_enc(x)).
        f_k: Tensor of shape (B, C, S): sampled features from H(G_enc(G(x))).
        tau: Temperature; all logits are divided by it.

    Returns:
        Scalar tensor: the cross-entropy, averaged over all B * S locations,
        of picking the positive (index 0) among S candidates (1 positive and
        S - 1 negatives).

    Raises:
        ValueError: if ``f_q`` is not 3-D or ``f_k`` has a different shape.
    """
    if f_q.dim() != 3 or f_q.shape != f_k.shape:
        raise ValueError(
            f"f_q and f_k must both have shape (B, C, S); "
            f"got {tuple(f_q.shape)} and {tuple(f_k.shape)}"
        )

    # Batch size and number of sample locations (the channel size is unused).
    B, _, S = f_q.shape

    # Positive logits: dot product at matching locations -> (B, S, 1).
    l_pos = (f_k * f_q).sum(dim=1)[:, :, None]

    # All-pairs dot products -> (B, S, S). Entry [b, i, j] is
    # f_q[b, :, i] . f_k[b, :, j].
    l_neg = torch.bmm(f_q.transpose(1, 2), f_k)

    # The diagonal entries are the positives, not negatives. Fill them with
    # -inf so they add exp(-inf) = 0 to the softmax denominator.
    # masked_fill requires a *boolean* mask on the same device as the data.
    diagonal = torch.eye(S, dtype=torch.bool, device=f_q.device)[None, :, :]
    l_neg = l_neg.masked_fill(diagonal, -float('inf'))

    # Logits (B, S, S + 1): column 0 holds each location's positive.
    logits = torch.cat((l_pos, l_neg), dim=2) / tau

    # Each (batch, location) row is one classification whose target is class 0.
    predictions = logits.flatten(0, 1)
    targets = torch.zeros(B * S, dtype=torch.long, device=f_q.device)
    return cross_entropy_loss(predictions, targets)
Preview:
Download PNG
Download JPEG
Download SVG
Tip: You can change the style, width & colours of the snippet with the inspect tool before clicking Download!
Click to optimize width for Twitter