34 lines
1.1 KiB
Python
34 lines
1.1 KiB
Python
import torch
|
|
|
|
|
|
class Decoder(torch.nn.Module):
    """Base decoder module that provides feature-fusion helpers.

    No ``forward`` is defined in this class; presumably concrete decoders
    subclass it and call ``fusion`` / ``aggregation_layer``
    (NOTE(review): confirm against subclasses).
    """

    # TODO: support learnable fusion modules

    def __init__(self):
        super().__init__()
        # Catalogue of fusion form names, grouped by kind. It is not consulted
        # by ``fusion`` itself. The "2to1" forms are elementwise ops; the
        # "2to2" form ("concat") joins both inputs along dim 1. The meaning of
        # the key names is a guess -- confirm with callers.
        fusion_table = {
            "2to1_fusion": ["sum", "diff", "abs_diff"],
            "2to2_fusion": ["concat"],
        }
        self.FUSION_DIC = fusion_table

    def fusion(self, x1, x2, fusion_form="concat"):
        """Combine two feature tensors according to ``fusion_form``.

        Args:
            x1: First feature tensor.
            x2: Second feature tensor.
            fusion_form: One of ``"concat"`` (``torch.cat`` along dim 1),
                ``"sum"`` (``x1 + x2``), ``"diff"`` (``x2 - x1`` -- note the
                operand order) or ``"abs_diff"`` (``|x1 - x2|``).

        Returns:
            The fused tensor.

        Raises:
            ValueError: If ``fusion_form`` matches none of the names above.
        """
        # Ordered (name, op) pairs, compared with ``==`` in this order.
        operations = (
            ("concat", lambda a, b: torch.cat([a, b], dim=1)),
            ("sum", lambda a, b: a + b),
            ("diff", lambda a, b: b - a),
            ("abs_diff", lambda a, b: torch.abs(a - b)),
        )
        for name, operation in operations:
            if fusion_form == name:
                return operation(x1, x2)
        raise ValueError('the fusion form "{}" is not defined'.format(fusion_form))

    def aggregation_layer(self, fea1, fea2, fusion_form="concat", ignore_original_img=True):
        """Fuse two per-level feature lists level by level.

        Args:
            fea1: Indexable sequence of feature tensors from the first branch.
                Its length sets how many levels are fused.
            fea2: Indexable sequence from the second branch. It must have at
                least ``len(fea1)`` entries, otherwise indexing raises
                ``IndexError``. Any extra entries are ignored.
            fusion_form: Passed through to ``fusion`` for every level.
            ignore_original_img: If truthy, index 0 is skipped. Per the name,
                entry 0 is presumably the raw input image (confirm with the
                encoder).

        Returns:
            A list with one fused tensor per processed level.
        """
        skip = 1 if ignore_original_img else 0
        fused = []
        for level in range(skip, len(fea1)):
            fused.append(self.fusion(fea1[level], fea2[level], fusion_form))
        return fused
|