2023-07-26 20:53:08 +08:00

39 lines
1.4 KiB
Python

from .models import create_model
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
class Front:
    """Inference wrapper that runs a two-input model on one pair of numpy arrays.

    The wrapped model is moved to CUDA when available, otherwise to the CPU.
    Each call:

    1. converts both arrays to float tensors divided by 255;
    2. prepends a batch axis;
    3. runs ``model(inp1, inp2)``;
    4. applies a sigmoid;
    5. returns channel 0 of batch item 0 as a numpy array.

    NOTE(review): the model is never switched to ``eval()`` here. If the
    caller does not do so, layers such as BatchNorm/Dropout run in training
    mode during inference. Confirm this is intended.
    """

    def __init__(self, model: nn.Module) -> None:
        """Store *model* and move it to the best available device.

        Args:
            model: a module that is called as ``model(inp1, inp2)``.
        """
        self.device = (
            torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
        )
        # nn.Module.to moves parameters in place and returns the same module,
        # so a single assignment is enough.
        self.model = model.to(self.device)

    def _to_batch(self, arr: np.ndarray) -> torch.Tensor:
        """Convert one numpy array into a batch-of-one float tensor on ``self.device``.

        Values are divided by 255. Presumably this maps 0-255 pixel values
        to [0, 1]; TODO confirm with callers. ``np.ascontiguousarray`` is
        needed because ``torch.from_numpy`` rejects arrays with negative
        strides (e.g. ``img[:, ::-1]``). It returns already-contiguous input
        unchanged, without a copy.
        """
        tensor = torch.from_numpy(np.ascontiguousarray(arr)).float().to(self.device) / 255.0
        return tensor.unsqueeze(0)

    def __call__(self, inp1: np.ndarray, inp2: np.ndarray) -> np.ndarray:
        """Run the model on one input pair and return the sigmoid of channel 0.

        Args:
            inp1: first input array. A batch axis is prepended, so presumably
                this is a single (C, H, W) image; confirm with callers.
            inp2: second input array, treated the same way as *inp1*.

        Returns:
            ``sigmoid(model(inp1, inp2))[0, 0]`` as a numpy array on the CPU.
        """
        batch1 = self._to_batch(inp1)
        batch2 = self._to_batch(inp2)
        # Pure inference: avoid building an autograd graph for this pass.
        with torch.no_grad():
            out = self.model(batch1, batch2).sigmoid()
        # Slice before moving to the CPU so only the needed plane is copied.
        return out[0, 0].detach().cpu().numpy()
def get_model(name, *, encoder_name="resnet34", encoder_weights="imagenet",
              in_channels=3, classes=1, siam_encoder=True, fusion_form='concat'):
    """Build model *name* with ``create_model`` and wrap it in ``Front``.

    The keyword-only arguments are forwarded to ``create_model``. Their
    defaults reproduce the original hard-coded configuration, so
    ``get_model(name)`` behaves as before.

    Args:
        name: model identifier, passed as the first argument to ``create_model``.
        encoder_name: encoder to use, e.g. ``"mobilenet_v2"`` or ``"efficientnet-b7"``.
        encoder_weights: pre-trained weights used to initialize the encoder
            (``"imagenet"`` by default).
        in_channels: model input channels (1 for gray-scale images, 3 for RGB, etc.).
        classes: model output channels (number of classes in your dataset).
        siam_encoder: whether to use a siamese encoder.
        fusion_form: how features from the two branches are fused, e.g.
            ``"concat"``, ``"sum"``, ``"diff"`` or ``"abs_diff"``.

    Returns:
        A ``Front`` wrapping the created model. Returns ``None`` if
        ``create_model`` raised; the exception is logged with its traceback.
    """
    import logging

    try:
        model = create_model(
            name,
            encoder_name=encoder_name,
            encoder_weights=encoder_weights,
            in_channels=in_channels,
            classes=classes,
            siam_encoder=siam_encoder,
            fusion_form=fusion_form,
        )
    except Exception:
        # Keep the best-effort "None on failure" contract, but record why it
        # failed instead of hiding it. Unlike the former bare ``except:``,
        # KeyboardInterrupt/SystemExit now propagate.
        logging.getLogger(__name__).exception(
            "get_model(%r): create_model failed", name
        )
        return None
    return Front(model)