74 lines
2.2 KiB
Python
74 lines
2.2 KiB
Python
from typing import Optional, Union
|
|
|
|
from ..base import ClassificationHead, SegmentationHead, SegmentationModel
|
|
from ..encoders import get_encoder
|
|
from .decoder import RCNNDecoder
|
|
|
|
|
|
class RCNN(SegmentationModel):
    """Model assembled from an encoder, an ``RCNNDecoder`` and output heads.

    The constructor wires together, in this order:

    1. ``self.encoder`` created by ``get_encoder``;
    2. ``self.encoder_non_siam`` (only when ``siam_encoder`` is false), a
       second encoder created with exactly the same arguments;
    3. ``self.decoder``, an ``RCNNDecoder`` fed with the encoder's
       ``out_channels``;
    4. ``self.segmentation_head``, a ``SegmentationHead`` with a 1x1 kernel;
    5. ``self.classification_head``, a ``ClassificationHead`` when
       ``aux_params`` is given, otherwise ``None``.

    Args:
        encoder_name: Encoder identifier handed to ``get_encoder``; also used
            to build ``self.name`` as ``"rcnn-<encoder_name>"``.
        encoder_depth: Forwarded as ``depth`` to ``get_encoder`` and as
            ``encoder_depth`` to ``RCNNDecoder``.
        encoder_weights: Forwarded as ``weights`` to ``get_encoder``.
        decoder_pyramid_channels: Forwarded as ``pyramid_channels`` to
            ``RCNNDecoder``.
        decoder_segmentation_channels: Forwarded as ``segmentation_channels``
            to ``RCNNDecoder``.
        decoder_merge_policy: Forwarded as ``merge_policy`` to
            ``RCNNDecoder``.
        decoder_dropout: Forwarded as ``dropout`` to ``RCNNDecoder``.
        in_channels: Forwarded as ``in_channels`` to ``get_encoder``.
        classes: Number of output channels of the segmentation head.
        activation: Forwarded as ``activation`` to ``SegmentationHead``.
        upsampling: Forwarded as ``upsampling`` to ``SegmentationHead``.
        aux_params: When not ``None``, unpacked as keyword arguments into
            ``ClassificationHead``; when ``None`` no classification head is
            built.
        siam_encoder: Stored on the instance. When false, an additional
            encoder is created as ``self.encoder_non_siam``.
            NOTE(review): presumably selects shared vs. separate encoder
            weights for two inputs -- confirm against the base class forward.
        fusion_form: Forwarded as ``fusion_form`` to ``RCNNDecoder``.
        **kwargs: Accepted but not used anywhere in this constructor.
    """

    def __init__(
        self,
        encoder_name: str = "resnet34",
        encoder_depth: int = 5,
        encoder_weights: Optional[str] = "imagenet",
        decoder_pyramid_channels: int = 256,
        decoder_segmentation_channels: int = 128,
        decoder_merge_policy: str = "add",
        decoder_dropout: float = 0.2,
        in_channels: int = 3,
        classes: int = 1,
        activation: Optional[str] = None,
        upsampling: int = 4,
        aux_params: Optional[dict] = None,
        siam_encoder: bool = True,
        fusion_form: str = "concat",
        **kwargs
    ):
        super().__init__()

        self.siam_encoder = siam_encoder

        # Both encoders (when two are built) share the exact same settings,
        # so the keyword arguments are gathered once.
        encoder_settings = dict(
            in_channels=in_channels,
            depth=encoder_depth,
            weights=encoder_weights,
        )
        self.encoder = get_encoder(encoder_name, **encoder_settings)
        if not siam_encoder:
            self.encoder_non_siam = get_encoder(encoder_name, **encoder_settings)

        encoder_channels = self.encoder.out_channels

        self.decoder = RCNNDecoder(
            encoder_channels=encoder_channels,
            encoder_depth=encoder_depth,
            pyramid_channels=decoder_pyramid_channels,
            segmentation_channels=decoder_segmentation_channels,
            dropout=decoder_dropout,
            merge_policy=decoder_merge_policy,
            fusion_form=fusion_form,
        )

        self.segmentation_head = SegmentationHead(
            in_channels=self.decoder.out_channels,
            out_channels=classes,
            activation=activation,
            kernel_size=1,
            upsampling=upsampling,
        )

        # The auxiliary classification head is optional and sized from the
        # last entry of the encoder's channel list.
        if aux_params is None:
            self.classification_head = None
        else:
            self.classification_head = ClassificationHead(
                in_channels=encoder_channels[-1], **aux_params
            )

        self.name = f"rcnn-{encoder_name}"
        # initialize() is not defined in this class; it is expected to come
        # from SegmentationModel.
        self.initialize()
|