from .RCNN import RCNN
from .DVCA import DVCA
from .DPFCN import DPFCN

from . import encoders
# from . import utils
# from . import losses
# from . import datasets

# from .__version__ import __version__

from typing import Optional
import torch


def create_model(
    arch: str,
    encoder_name: str = "resnet34",
    encoder_weights: Optional[str] = "imagenet",
    in_channels: int = 3,
    classes: int = 1,
    **kwargs,
) -> torch.nn.Module:
    """Build one of the package's models from its architecture name.

    The architecture is looked up case-insensitively by the lowercased class
    name of each supported model (``DVCA``, ``DPFCN``, ``RCNN``), and the
    matching class is instantiated with the remaining arguments.

    Args:
        arch: Name of the architecture, e.g. ``"dvca"``, ``"DPFCN"`` or
            ``"rcnn"``; letter case is ignored.
        encoder_name: Forwarded unchanged to the model constructor as
            ``encoder_name``.
        encoder_weights: Forwarded unchanged to the model constructor as
            ``encoder_weights``.
        in_channels: Forwarded unchanged to the model constructor as
            ``in_channels``.
        classes: Forwarded unchanged to the model constructor as ``classes``.
        **kwargs: Any extra keyword arguments, forwarded unchanged to the
            model constructor.

    Returns:
        An instance of the selected model class.

    Raises:
        KeyError: If ``arch`` does not match any supported architecture. The
            message lists the accepted names.
    """
    archs = [DVCA, DPFCN, RCNN]
    archs_dict = {a.__name__.lower(): a for a in archs}

    try:
        model_class = archs_dict[arch.lower()]
    except KeyError:
        # `from None` hides the internal dict-lookup KeyError so that callers
        # see only the descriptive error instead of a chained traceback.
        raise KeyError(
            f"Wrong architecture type `{arch}`. "
            f"Available options are: {list(archs_dict)}"
        ) from None

    return model_class(
        encoder_name=encoder_name,
        encoder_weights=encoder_weights,
        in_channels=in_channels,
        classes=classes,
        **kwargs,
    )