import logging
from typing import Optional

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from .models import create_model

logger = logging.getLogger(__name__)


class Front:
    """Callable inference wrapper around a two-input model.

    The wrapped model is moved to CUDA when available (CPU otherwise).
    Calling the wrapper converts two NumPy arrays to batched float
    tensors, runs the model on them and returns the sigmoid of the
    output as a NumPy array.

    NOTE(review): this wrapper does not switch the model between
    train/eval mode; callers that need inference-mode BatchNorm/Dropout
    behavior should call ``front.model.eval()`` themselves -- confirm
    against callers.
    """

    def __init__(self, model) -> None:
        """Store *model* on the selected device.

        Args:
            model: A ``torch.nn.Module``-like object supporting ``.to()``
                and being called with two tensors.
        """
        self.device = (
            torch.device('cuda') if torch.cuda.is_available()
            else torch.device('cpu')
        )
        self.model = model.to(self.device)

    def _to_batch(self, arr: np.ndarray) -> torch.Tensor:
        """Convert a NumPy array to a float tensor scaled by 1/255 with a leading batch axis."""
        tensor = torch.from_numpy(arr).float().to(self.device) / 255.0
        return tensor.unsqueeze(0)

    def __call__(self, inp1: np.ndarray, inp2: np.ndarray) -> np.ndarray:
        """Run the model on a pair of inputs and return the sigmoid output.

        Args:
            inp1: First input array; divided by 255 before inference.
                Presumably a channel-first image with values in
                [0, 255] -- TODO confirm against callers.
            inp2: Second input array, processed the same way as *inp1*.

        Returns:
            The ``[0, 0]`` slice of the sigmoid-activated model output
            (first batch item, first index of the second axis), as a
            NumPy array on the CPU.
        """
        batch1 = self._to_batch(inp1)
        batch2 = self._to_batch(inp2)
        # The result is detached right away, so there is no reason to
        # record an autograd graph during the forward pass.
        with torch.no_grad():
            out = self.model(batch1, batch2).sigmoid()
        return out.cpu().detach().numpy()[0, 0]


def get_model(name) -> Optional[Front]:
    """Build the model called *name* and wrap it in a :class:`Front`.

    Args:
        name: Model identifier forwarded to ``create_model``.

    Returns:
        A ``Front`` instance wrapping the created model, or ``None`` if
        model creation failed (the failure is logged with its traceback).
    """
    try:
        model = create_model(
            name,
            encoder_name="resnet34",     # choose encoder, e.g. mobilenet_v2 or efficientnet-b7
            encoder_weights="imagenet",  # use `imagenet` pre-trained weights for encoder initialization
            in_channels=3,               # model input channels (1 for gray-scale images, 3 for RGB, etc.)
            classes=1,                   # model output channels (number of classes in your datasets)
            siam_encoder=True,           # whether to use a siamese encoder
            fusion_form='concat',        # how features from the two branches are fused, e.g. concat, sum, diff, or abs_diff
        )
    except Exception:
        # Keep the best-effort contract (return None on failure) but do not
        # swallow KeyboardInterrupt/SystemExit, and record why creation failed.
        logger.exception("Failed to create model %r", name)
        return None
    return Front(model)