180 lines
5.6 KiB
Python
180 lines
5.6 KiB
Python
"""
BSD 3-Clause License

Copyright (c) Soumith Chintala 2016,
All rights reserved.

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:

* Redistributions of source code must retain the above copyright notice, this
  list of conditions and the following disclaimer.

* Redistributions in binary form must reproduce the above copyright notice,
  this list of conditions and the following disclaimer in the documentation
  and/or other materials provided with the distribution.

* Neither the name of the copyright holder nor the names of its
  contributors may be used to endorse or promote products derived from
  this software without specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""
|
|
|
|
import torch
|
|
from torch import nn
|
|
from torch.nn import functional as F
|
|
|
|
from ..base import Decoder
|
|
|
|
# Only the decoder is part of this module's public (star-import) API; the
# ASPP building blocks defined below remain importable by explicit name.
__all__ = [
    "DVCADecoder",
]
|
|
|
|
|
|
class DVCADecoder(Decoder):
    """Decoder that fuses the deepest features of two encoder streams.

    The last entry of the first two feature collections given to
    :meth:`forward` is combined with ``self.fusion`` (inherited from the
    ``Decoder`` base class, not visible here). The fused map is then refined by
    a single 3x3 conv -> BatchNorm -> ReLU stage that emits ``out_channels``
    channels.

    Args:
        in_channels: Channel count of the encoder features before fusion.
            It is doubled internally when ``fusion_form`` is one of
            ``self.FUSION_DIC["2to2_fusion"]``.
        out_channels: Channels produced by the decoder; also stored as
            ``self.out_channels``.
        atrous_rates: Currently unused. Accepted only for interface
            compatibility: an ``ASPP(in_channels, out_channels, atrous_rates)``
            stage was disabled in favour of the plain 3x3 conv.
        fusion_form: Name of the fusion strategy forwarded to ``self.fusion``;
            stored as ``self.fusion_form``.
    """

    def __init__(self, in_channels, out_channels=256, atrous_rates=(12, 24, 36), fusion_form="concat"):
        super().__init__()

        # Fusion forms grouped under "2to2_fusion" (FUSION_DIC lives on the
        # Decoder base class) yield a fused map twice as wide as one stream,
        # so the first conv must accept double the channels.
        widens_input = fusion_form in self.FUSION_DIC["2to2_fusion"]
        conv_in = in_channels * 2 if widens_input else in_channels

        # Despite its name, `aspp` is a plain conv block. The attribute name is
        # kept because it determines the state_dict keys ("aspp.0.weight", ...).
        refine_layers = [
            nn.Conv2d(conv_in, out_channels, 3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        ]
        self.aspp = nn.Sequential(*refine_layers)

        self.out_channels = out_channels
        self.fusion_form = fusion_form

    def forward(self, *features):
        """Fuse ``features[0][-1]`` with ``features[1][-1]`` and refine the result.

        Args:
            *features: At least two indexable collections of feature maps.
                Only the last element of the first two is used (presumably the
                deepest encoder stage of each input; verify against the
                encoder).

        Returns:
            The output of the conv -> BatchNorm -> ReLU stage, with
            ``out_channels`` channels.
        """
        stream_a, stream_b = features[0], features[1]
        fused = self.fusion(stream_a[-1], stream_b[-1], self.fusion_form)
        return self.aspp(fused)
|
|
|
|
|
|
class ASPPConv(nn.Sequential):
    """Atrous branch: 3x3 dilated conv (no bias) -> BatchNorm -> ReLU.

    Padding equals ``dilation``, so for a 3x3 kernel at stride 1 the spatial
    size of the input is preserved.

    Args:
        in_channels: Input channel count.
        out_channels: Output channel count.
        dilation: Dilation rate, also used as the padding.
    """

    def __init__(self, in_channels, out_channels, dilation):
        atrous = nn.Conv2d(
            in_channels, out_channels, kernel_size=3, padding=dilation, dilation=dilation, bias=False
        )
        super().__init__(atrous, nn.BatchNorm2d(out_channels), nn.ReLU())
|
|
|
|
|
|
class ASPPSeparableConv(nn.Sequential):
    """Atrous branch built on a depthwise-separable conv: SeparableConv2d -> BatchNorm -> ReLU.

    The separable conv uses a 3x3 kernel with padding equal to ``dilation``
    and no bias on its pointwise stage.

    Args:
        in_channels: Input channel count.
        out_channels: Output channel count.
        dilation: Dilation rate, also used as the padding.
    """

    def __init__(self, in_channels, out_channels, dilation):
        atrous = SeparableConv2d(
            in_channels, out_channels, kernel_size=3, padding=dilation, dilation=dilation, bias=False
        )
        super().__init__(atrous, nn.BatchNorm2d(out_channels), nn.ReLU())
|
|
|
|
|
|
class ASPPPooling(nn.Sequential):
    """Image-level ASPP branch.

    Global average pooling to 1x1 -> 1x1 conv (no bias) -> BatchNorm -> ReLU.
    The result is then bilinearly resized back to the input's spatial size.

    NOTE(review): in training mode with batch size 1, BatchNorm sees a single
    value per channel here, which PyTorch rejects. Confirm callers never hit
    this case.
    """

    def __init__(self, in_channels, out_channels):
        pool_then_project = (
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        )
        super().__init__(*pool_then_project)

    def forward(self, x):
        """Run the pooled branch on ``x`` and upsample it to ``x``'s height/width."""
        target_hw = x.shape[-2:]
        # nn.Sequential.forward applies the submodules in registration order.
        pooled = super().forward(x)
        return F.interpolate(pooled, size=target_hw, mode='bilinear', align_corners=False)
|
|
|
|
|
|
class ASPP(nn.Module):
    """Atrous Spatial Pyramid Pooling head.

    Several parallel branches run over the same input, and their outputs are
    concatenated along the channel axis (``dim=1``):

    * a 1x1 conv -> BatchNorm -> ReLU branch,
    * one 3x3 atrous branch per entry of ``atrous_rates``, in the given order
      (``ASPPConv``, or ``ASPPSeparableConv`` when ``separable`` is True),
    * an image-level ``ASPPPooling`` branch.

    The concatenation is projected back to ``out_channels`` by a 1x1 conv,
    BatchNorm, ReLU and Dropout(0.5).

    Args:
        in_channels: Channel count of the input feature map.
        out_channels: Channels produced by every branch and by the projection.
        atrous_rates: Iterable of dilation rates, e.g. ``(12, 24, 36)``. One
            atrous branch is built per rate. Any number of rates is accepted;
            with three rates the module layout and state_dict keys are the
            same as before this was generalized.
        separable: Use depthwise-separable convolutions for the atrous
            branches.
    """

    def __init__(self, in_channels, out_channels, atrous_rates, separable=False):
        super().__init__()
        rates = tuple(atrous_rates)
        branch_cls = ASPPSeparableConv if separable else ASPPConv

        # Branch order (1x1, atrous branches in rate order, pooling) fixes the
        # ModuleList indices and therefore the state_dict keys ("convs.<i>...").
        modules = [
            nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 1, bias=False),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(),
            )
        ]
        modules.extend(branch_cls(in_channels, out_channels, rate) for rate in rates)
        modules.append(ASPPPooling(in_channels, out_channels))
        self.convs = nn.ModuleList(modules)

        # Each branch emits out_channels channels, so the projection input
        # width follows the branch count instead of a hard-coded 5.
        self.project = nn.Sequential(
            nn.Conv2d(len(self.convs) * out_channels, out_channels, kernel_size=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.Dropout(0.5),
        )

    def forward(self, x):
        """Apply every branch to ``x``, concatenate on dim 1, and project."""
        branch_outputs = [conv(x) for conv in self.convs]
        return self.project(torch.cat(branch_outputs, dim=1))
|
|
|
|
|
|
class SeparableConv2d(nn.Sequential):
    """Depthwise-separable 2D convolution.

    The first stage is a depthwise conv: a ``kernel_size`` conv with
    ``groups == in_channels`` and no bias. The second is a 1x1 pointwise conv
    that maps ``in_channels`` to ``out_channels``.

    ``stride``, ``padding`` and ``dilation`` apply only to the depthwise
    stage. ``bias`` applies only to the pointwise stage.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        padding=0,
        dilation=1,
        bias=True,
    ):
        depthwise = nn.Conv2d(
            in_channels, in_channels, kernel_size,
            stride=stride, padding=padding, dilation=dilation,
            groups=in_channels, bias=False,
        )
        pointwise = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=bias)
        stages = [depthwise, pointwise]
        super().__init__(*stages)
|