2023-02-10 21:51:03 +08:00

30 lines
772 B
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import torch.nn as nn
import torch
class BCL(nn.Module):
    """Batch-balanced contrastive loss.

    Label convention:
        * ``1``  -- no-change pair: its distance is pulled toward 0.
        * ``-1`` -- change pair: its distance is pushed to at least ``margin``.
        * ``ignore_index`` (default 255) -- pixel excluded from the loss.

    The no-change term and the change term are each divided by their own
    pixel count (plus a small epsilon), so both classes contribute on the
    same scale regardless of how imbalanced the batch is.
    """

    def __init__(self, margin=2.0, ignore_index=255):
        """
        Args:
            margin: Hinge margin; change pairs with ``distance < margin``
                are penalised by ``(margin - distance) ** 2``.
            ignore_index: Label value whose pixels are excluded from the loss.
        """
        super().__init__()
        self.margin = margin
        self.ignore_index = ignore_index

    def forward(self, distance, label):
        """Compute the batch-balanced contrastive loss.

        Args:
            distance: Tensor of pairwise feature distances; it is combined
                element-wise with ``label`` (shapes must broadcast).
            label: Tensor of labels in {1, -1, ignore_index}. Not modified.

        Returns:
            Scalar tensor. For labels in {1, -1} this is
            ``sum(d**2 over no-change) / n_no_change
            + sum(max(margin - d, 0)**2 over change) / n_change``.
        """
        # Build the validity mask BEFORE remapping ignored labels. Remapping
        # first (as the previous version did) makes the mask all ones and
        # silently trains ignored pixels as "no-change".
        valid = label != self.ignore_index
        mask = valid.float()
        # Remap ignored pixels to 1 so that (1 - label) / 2 is 0 for them.
        # This is done out of place so the caller's tensor is left untouched.
        label = label.masked_fill(~valid, 1)
        # Zeroing the distance removes ignored pixels from the d**2 term.
        distance = distance * mask

        # Per-class counts; ignored pixels are excluded from the no-change
        # count even though they were remapped to 1. The epsilon avoids
        # division by zero when a class is absent from the batch.
        pos_num = torch.sum(((label == 1) & valid).float()) + 0.0001
        neg_num = torch.sum((label == -1).float()) + 0.0001

        loss_1 = torch.sum((1 + label) / 2 * torch.pow(distance, 2)) / pos_num
        loss_2 = torch.sum(
            (1 - label) / 2 * mask
            * torch.pow(torch.clamp(self.margin - distance, min=0.0), 2)
        ) / neg_num
        return loss_1 + loss_2