from .models import create_model

import torch

class Front:
    """Thin inference wrapper around a change-detection model."""

    def __init__(self, model) -> None:
        self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
        self.model = model.to(self.device)
        self.model.eval()

    def __call__(self, inp1, inp2):
        # Inputs are expected to be CHW uint8 numpy arrays; scale to [0, 1]
        # and add a batch dimension.
        inp1 = torch.from_numpy(inp1).float().to(self.device) / 255.0
        inp1 = inp1.unsqueeze(0)
        inp2 = torch.from_numpy(inp2).float().to(self.device) / 255.0
        inp2 = inp2.unsqueeze(0)

        with torch.no_grad():
            out = self.model(inp1, inp2)

        # Single-class output: apply sigmoid and return the HxW probability map.
        out = out.sigmoid()
        out = out.cpu().numpy()[0, 0]
        return out

def get_model(name):
    try:
        model = create_model(
            name,
            encoder_name="resnet34",     # choose encoder, e.g. mobilenet_v2 or efficientnet-b7
            encoder_weights="imagenet",  # use `imagenet` pre-trained weights for encoder initialization
            in_channels=3,               # model input channels (1 for gray-scale images, 3 for RGB, etc.)
            classes=1,                   # model output channels (number of classes in your dataset)
            siam_encoder=True,           # whether to use a siamese encoder
            fusion_form='concat',        # how to fuse features from the two branches, e.g. concat, sum, diff, or abs_diff
        )
    except Exception:
        return None

    return Front(model)
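

# Minimal usage sketch (illustrative only): the architecture name "Unet" and the
# 3x256x256 CHW uint8 random inputs below are assumptions, not part of this module.
if __name__ == "__main__":
    import numpy as np

    detector = get_model("Unet")
    if detector is not None:
        img_a = np.random.randint(0, 256, (3, 256, 256), dtype=np.uint8)
        img_b = np.random.randint(0, 256, (3, 256, 256), dtype=np.uint8)
        change_map = detector(img_a, img_b)  # HxW probability map in [0, 1]
        print(change_map.shape)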