# 2023-07-26 20:53:08 +08:00
#
# 70 lines
# 1.9 KiB
# Python

from typing import Optional
import torch.nn as nn
from ..base import ClassificationHead, SegmentationHead, SegmentationModel
from ..encoders import get_encoder
from .decoder import DVCADecoder
class DVCA(SegmentationModel):
    """Segmentation model built from ``get_encoder`` backbone(s), a
    ``DVCADecoder`` and a ``SegmentationHead``, plus an optional
    ``ClassificationHead``.

    The ``siam_encoder`` / ``fusion_form`` options suggest a two-input model
    (e.g. change detection). How the encoder(s) are applied to the inputs is
    decided by ``SegmentationModel``, which is not visible here.
    NOTE(review): confirm against the base class ``forward``.

    Args:
        encoder_name: Backbone name forwarded to ``get_encoder``.
        encoder_depth: Forwarded to ``get_encoder`` as ``depth``.
        encoder_weights: Forwarded to ``get_encoder`` as ``weights``.
        decoder_channels: ``out_channels`` of the ``DVCADecoder``.
        in_channels: Input channel count given to every encoder built here.
        classes: ``out_channels`` of the segmentation head.
        activation: Forwarded to ``SegmentationHead`` as ``activation``.
        upsampling: Forwarded to ``SegmentationHead`` as ``upsampling``.
            The default of 8 presumably undoes the encoder's fixed
            ``output_stride=8``; confirm.
        aux_params: When not ``None``, keyword arguments for an auxiliary
            ``ClassificationHead`` whose ``in_channels`` is the last entry of
            ``encoder.out_channels``. When ``None``, ``classification_head``
            is set to ``None``.
        siam_encoder: When ``False``, a second encoder with the same
            configuration is built and stored as ``encoder_non_siam``.
        fusion_form: Forwarded unchanged to ``DVCADecoder``.
        **kwargs: Accepted and ignored.
    """

    def __init__(
        self,
        encoder_name: str = "resnet34",
        encoder_depth: int = 5,
        encoder_weights: Optional[str] = "imagenet",
        decoder_channels: int = 256,
        in_channels: int = 3,
        classes: int = 1,
        activation: Optional[str] = None,
        upsampling: int = 8,
        aux_params: Optional[dict] = None,
        siam_encoder: bool = True,
        fusion_form: str = "concat",
        **kwargs
    ):
        super().__init__()
        self.siam_encoder = siam_encoder

        # One configuration shared by every encoder built below, so the two
        # encoders (when there are two) cannot drift apart. The output stride
        # is fixed at 8.
        encoder_cfg = dict(
            in_channels=in_channels,
            depth=encoder_depth,
            weights=encoder_weights,
            output_stride=8,
        )
        self.encoder = get_encoder(encoder_name, **encoder_cfg)
        if not siam_encoder:
            # Second encoder, built by its own get_encoder call with the same
            # configuration.
            self.encoder_non_siam = get_encoder(encoder_name, **encoder_cfg)

        # Channel count of the last encoder output; it feeds both the decoder
        # and the optional classification head.
        last_channels = self.encoder.out_channels[-1]

        self.decoder = DVCADecoder(
            in_channels=last_channels,
            out_channels=decoder_channels,
            fusion_form=fusion_form,
        )
        self.segmentation_head = SegmentationHead(
            in_channels=self.decoder.out_channels,
            out_channels=classes,
            activation=activation,
            kernel_size=1,
            upsampling=upsampling,
        )
        self.classification_head = (
            None
            if aux_params is None
            else ClassificationHead(in_channels=last_channels, **aux_params)
        )