2023-07-26 20:53:08 +08:00

74 lines
2.2 KiB
Python

from typing import Optional, Union
from ..base import ClassificationHead, SegmentationHead, SegmentationModel
from ..encoders import get_encoder
from .decoder import RCNNDecoder
class RCNN(SegmentationModel):
    """Encoder -> RCNNDecoder -> SegmentationHead model.

    Construction wires together a backbone obtained from ``get_encoder``, an
    ``RCNNDecoder`` fed with that backbone's ``out_channels``, a 1x1-kernel
    ``SegmentationHead``, and an optional ``ClassificationHead``. How these
    pieces are invoked at run time is defined by ``SegmentationModel``, not
    here.

    Args:
        encoder_name: Backbone identifier handed to ``get_encoder``; also used
            to build ``self.name`` as ``"rcnn-<encoder_name>"``.
        encoder_depth: Passed as ``depth`` to ``get_encoder`` and as
            ``encoder_depth`` to ``RCNNDecoder``.
        encoder_weights: Passed as ``weights`` to ``get_encoder``.
        decoder_pyramid_channels: ``RCNNDecoder`` ``pyramid_channels``.
        decoder_segmentation_channels: ``RCNNDecoder`` ``segmentation_channels``.
        decoder_merge_policy: ``RCNNDecoder`` ``merge_policy``.
        decoder_dropout: ``RCNNDecoder`` ``dropout``.
        in_channels: Input channel count for every encoder that is built.
        classes: Output channel count of the segmentation head.
        activation: Activation passed to the segmentation head.
        upsampling: Upsampling passed to the segmentation head.
        aux_params: When not ``None``, keyword arguments for a
            ``ClassificationHead`` built on the last encoder channel count;
            otherwise ``classification_head`` is ``None``.
        siam_encoder: When falsy, a second, identically configured encoder
            is created as ``encoder_non_siam`` (presumably for the second
            input; confirm in ``SegmentationModel.forward``).
        fusion_form: Forwarded unchanged to ``RCNNDecoder``; its meaning is
            defined there.
        **kwargs: Accepted and ignored.
    """

    def __init__(
        self,
        encoder_name: str = "resnet34",
        encoder_depth: int = 5,
        encoder_weights: Optional[str] = "imagenet",
        decoder_pyramid_channels: int = 256,
        decoder_segmentation_channels: int = 128,
        decoder_merge_policy: str = "add",
        decoder_dropout: float = 0.2,
        in_channels: int = 3,
        classes: int = 1,
        activation: Optional[str] = None,
        upsampling: int = 4,
        aux_params: Optional[dict] = None,
        siam_encoder: bool = True,
        fusion_form: str = "concat",
        **kwargs
    ):
        super().__init__()
        self.siam_encoder = siam_encoder

        def _make_backbone():
            # Every encoder this model owns shares one configuration.
            return get_encoder(
                encoder_name,
                in_channels=in_channels,
                depth=encoder_depth,
                weights=encoder_weights,
            )

        # Submodules are created in a fixed order; keeping it stable keeps
        # parameter registration and random initialisation reproducible.
        self.encoder = _make_backbone()
        if not self.siam_encoder:
            self.encoder_non_siam = _make_backbone()

        self.decoder = RCNNDecoder(
            encoder_channels=self.encoder.out_channels,
            encoder_depth=encoder_depth,
            pyramid_channels=decoder_pyramid_channels,
            segmentation_channels=decoder_segmentation_channels,
            dropout=decoder_dropout,
            merge_policy=decoder_merge_policy,
            fusion_form=fusion_form,
        )

        self.segmentation_head = SegmentationHead(
            in_channels=self.decoder.out_channels,
            out_channels=classes,
            activation=activation,
            kernel_size=1,
            upsampling=upsampling,
        )

        # Optional auxiliary head on the deepest encoder stage.
        self.classification_head = (
            ClassificationHead(
                in_channels=self.encoder.out_channels[-1], **aux_params
            )
            if aux_params is not None
            else None
        )

        self.name = f"rcnn-{encoder_name}"
        self.initialize()