# 2023-07-26 20:53:08 +08:00
#
# 34 lines
# 1.1 KiB
# Python

import torch
class Decoder(torch.nn.Module):
    """Decoder base class offering parameter-free fusion of two feature streams.

    The class itself registers no layers; it only provides helpers that
    combine a pair of tensors (``fusion``) or two sequences of tensors
    (``aggregation_layer``).
    """

    # TODO: support learnable fusion modules
    def __init__(self):
        super().__init__()
        # Catalogue of the fusion forms understood by ``fusion``. It is not
        # consulted by the methods below; presumably subclasses or callers
        # read it to pick channel sizes ("2to1" = output keeps the size of one
        # input, "2to2" = both inputs are kept side by side) -- TODO confirm.
        self.FUSION_DIC = {
            "2to1_fusion": ["sum", "diff", "abs_diff"],
            "2to2_fusion": ["concat"],
        }

    def fusion(self, x1, x2, fusion_form="concat"):
        """Fuse two tensors according to ``fusion_form``.

        Args:
            x1: First input tensor.
            x2: Second input tensor.
            fusion_form: One of
                ``"concat"``   -- ``torch.cat([x1, x2], dim=1)``;
                ``"sum"``      -- ``x1 + x2``;
                ``"diff"``     -- ``x2 - x1`` (note the operand order);
                ``"abs_diff"`` -- ``torch.abs(x1 - x2)``.

        Returns:
            The fused tensor.

        Raises:
            ValueError: If ``fusion_form`` is not one of the forms above.
        """
        if fusion_form == "concat":
            return torch.cat([x1, x2], dim=1)
        if fusion_form == "sum":
            return x1 + x2
        if fusion_form == "diff":
            # Signed difference is second-minus-first, not first-minus-second.
            return x2 - x1
        if fusion_form == "abs_diff":
            return torch.abs(x1 - x2)
        raise ValueError('the fusion form "{}" is not defined'.format(fusion_form))

    def aggregation_layer(self, fea1, fea2, fusion_form="concat", ignore_original_img=True):
        """Aggregate features from siamese or non-siamese branches.

        Entries of ``fea1`` and ``fea2`` are paired by index and each pair is
        passed to ``fusion``.

        Args:
            fea1: Indexable sequence of features from the first branch; its
                length decides how many pairs are fused.
            fea2: Indexable sequence of features from the second branch. It
                must provide at least ``len(fea1)`` entries; any extra entries
                are ignored.
            fusion_form: Fusion form forwarded to ``fusion``.
            ignore_original_img: When truthy, index 0 of both sequences is
                skipped (presumably the slot holding the original-resolution
                image feature -- confirm against the encoder).

        Returns:
            A list with one fused tensor per processed index, in index order.
        """
        start_idx = 1 if ignore_original_img else 0
        aggregate_fea = [
            self.fusion(fea1[idx], fea2[idx], fusion_form)
            for idx in range(start_idx, len(fea1))
        ]
        return aggregate_fea