30 lines
772 B
Python
30 lines
772 B
Python
import torch.nn as nn
|
||
import torch
|
||
|
||
|
||
class BCL(nn.Module):
    """Batch-balanced contrastive loss.

    Label convention:
        1   -> no-change pair (its distance is pulled toward 0)
        -1  -> change pair (its distance is pushed out to at least ``margin``)
        255 -> ignored; contributes nothing to either term or to the counts

    Each term is divided by the number of pixels of its own class in the
    batch (plus a small epsilon), so the two classes are weighted equally
    regardless of how imbalanced the labels are.
    """

    def __init__(self, margin=2.0):
        super().__init__()
        # Change pairs whose distance is already >= margin incur no penalty.
        self.margin = margin

    def forward(self, distance, label):
        """Compute the loss.

        Args:
            distance: Tensor of pairwise distances. It is assumed to have the
                same shape as ``label`` (or be broadcastable to it); TODO
                confirm against the caller.
            label: Tensor with values in {1, -1, 255}. It is not modified.

        Returns:
            Scalar tensor equal to
            ``sum(distance**2 over label==1) / (#label==1 + 1e-4)``
            plus
            ``sum(clamp(margin - distance, 0)**2 over label==-1) / (#label==-1 + 1e-4)``.
        """
        # The ignore mask must come from the raw labels. Remapping 255 first
        # would make it all ones, and ignored pixels would leak into the loss.
        valid = label != 255
        mask = valid.float()

        # Count each class on the raw labels so ignored pixels are excluded.
        pos_num = torch.sum((label == 1).float()) + 0.0001
        neg_num = torch.sum((label == -1).float()) + 0.0001

        # Remap out-of-place so the caller's tensor is untouched. Setting the
        # ignored entries to 1 zeroes the (1 - label) / 2 factor for them. The
        # (1 + label) / 2 term is zeroed for them by masking the distance.
        label = label.masked_fill(~valid, 1)
        distance = distance * mask

        loss_1 = torch.sum((1 + label) / 2 * torch.pow(distance, 2)) / pos_num
        loss_2 = torch.sum(
            (1 - label) / 2 * mask
            * torch.pow(torch.clamp(self.margin - distance, min=0.0), 2)
        ) / neg_num
        return loss_1 + loss_2
|
||
|