from typing import Optional

import torch.nn as nn

from ..base import ClassificationHead, SegmentationHead, SegmentationModel
from ..encoders import get_encoder
from .decoder import DVCADecoder


class DVCA(SegmentationModel):
    """DVCA model built from an (optionally Siamese) encoder, a ``DVCADecoder``
    that fuses the two feature streams, and a ``SegmentationHead`` that produces
    the output mask. An auxiliary ``ClassificationHead`` can be attached to the
    deepest encoder features via ``aux_params``.
    """

    def __init__(
        self,
        encoder_name: str = "resnet34",
        encoder_depth: int = 5,
        encoder_weights: Optional[str] = "imagenet",
        decoder_channels: int = 256,
        in_channels: int = 3,
        classes: int = 1,
        activation: Optional[str] = None,
        upsampling: int = 8,
        aux_params: Optional[dict] = None,
        siam_encoder: bool = True,
        fusion_form: str = "concat",
        **kwargs
    ):
        super().__init__()

        # When True, a single encoder (shared weights) processes both inputs.
        self.siam_encoder = siam_encoder

        self.encoder = get_encoder(
            encoder_name,
            in_channels=in_channels,
            depth=encoder_depth,
            weights=encoder_weights,
            output_stride=8,  # keep the deepest features at 1/8 resolution
        )

        if not self.siam_encoder:
            # Second, independently weighted encoder for the non-Siamese case.
            self.encoder_non_siam = get_encoder(
                encoder_name,
                in_channels=in_channels,
                depth=encoder_depth,
                weights=encoder_weights,
                output_stride=8,
            )

        # The decoder consumes only the deepest encoder feature map and fuses
        # the two streams according to ``fusion_form`` (e.g. "concat").
        self.decoder = DVCADecoder(
            in_channels=self.encoder.out_channels[-1],
            out_channels=decoder_channels,
            fusion_form=fusion_form,
        )

        # 1x1 convolution followed by upsampling back towards input resolution.
        self.segmentation_head = SegmentationHead(
            in_channels=self.decoder.out_channels,
            out_channels=classes,
            activation=activation,
            kernel_size=1,
            upsampling=upsampling,
        )

        # Optional auxiliary classification head on the deepest encoder features.
        if aux_params is not None:
            self.classification_head = ClassificationHead(
                in_channels=self.encoder.out_channels[-1], **aux_params
            )
        else:
            self.classification_head = None
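

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the library above). Assumptions: the
# module is executed as part of its package (e.g. ``python -m <package>.dvca``)
# so the relative imports resolve, and the base ``SegmentationModel`` forward
# pass accepts the two images to compare, which is what the Siamese-encoder /
# fusion design above suggests. Adjust the call to your installation.
if __name__ == "__main__":
    import torch

    model = DVCA(
        encoder_name="resnet34",
        encoder_weights=None,  # skip the ImageNet download for a quick smoke test
        in_channels=3,
        classes=2,
        fusion_form="concat",
    )
    model.eval()

    x1 = torch.randn(2, 3, 256, 256)  # e.g. pre-change image batch
    x2 = torch.randn(2, 3, 256, 256)  # e.g. post-change image batch
    with torch.no_grad():
        mask = model(x1, x2)  # assumed two-input forward; expected output
                              # shape (2, classes, 256, 256) after upsampling
    print(mask.shape)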