import torch
import torch.nn as nn
from typing import List
from collections import OrderedDict

from . import _utils as utils


class EncoderMixin:
    """Mixin that adds shared encoder behaviour.

    Provides:
      - the channel count of each feature tensor produced by the encoder
      - patching of the first convolution for an arbitrary number of input
        channels
      - replacement of late-stage strides with dilation (``make_dilated``)

    The concrete encoder is expected to define ``_out_channels`` and
    ``_depth`` (both are read here but never initialised here) and to
    override ``get_stages``.
    """

    @property
    def out_channels(self):
        """Channel counts for each tensor of the encoder's forward output.

        Only the first ``self._depth + 1`` entries of ``self._out_channels``
        are reported.
        """
        channels = self._out_channels
        return channels[: self._depth + 1]

    def set_in_channels(self, in_channels, pretrained=True):
        """Adapt the first convolution to accept ``in_channels`` inputs.

        Does nothing when ``in_channels`` is 3. Otherwise it:
          - stores the new count on ``self._in_channels``
          - replaces the leading entry of ``self._out_channels`` when that
            entry is 3
          - delegates the convolution change to ``utils.patch_first_conv``

        Args:
            in_channels: number of input channels the encoder should accept.
            pretrained: forwarded unchanged to ``utils.patch_first_conv``.
        """
        if in_channels == 3:
            return

        self._in_channels = in_channels

        # NOTE(review): entry 0 presumably describes the raw input tensor
        # (depth 0), which is why it follows the input channel count --
        # confirm against the concrete encoders.
        if self._out_channels[0] == 3:
            tail = tuple(self._out_channels)[1:]
            self._out_channels = (in_channels,) + tail

        utils.patch_first_conv(
            model=self,
            new_in_channels=in_channels,
            pretrained=pretrained,
        )

    def get_stages(self):
        """Return the encoder's indexable stages; must be overridden.

        ``make_dilated`` indexes the result at positions 4 and 5.
        """
        raise NotImplementedError

    def make_dilated(self, output_stride):
        """Swap strides for dilation in late stages to reach ``output_stride``.

        With ``output_stride == 16``, stage 5 is dilated with rate 2. With
        ``output_stride == 8``, stage 4 is dilated with rate 2 and stage 5
        with rate 4. The per-stage work is done by
        ``utils.replace_strides_with_dilation``.

        Raises:
            ValueError: if ``output_stride`` is neither 16 nor 8.
        """
        # Each entry pairs a stage index with the dilation rate applied to it.
        if output_stride == 16:
            schedule = ((5, 2),)
        elif output_stride == 8:
            schedule = ((4, 2), (5, 4))
        else:
            raise ValueError("Output stride should be 16 or 8, got {}.".format(output_stride))

        stages = self.get_stages()
        for stage_idx, rate in schedule:
            utils.replace_strides_with_dilation(
                module=stages[stage_idx],
                dilation_rate=rate,
            )