import torch.nn as nn
import torch.nn.functional as F
from copy import deepcopy
from collections import OrderedDict

from pretrainedmodels.models.torchvision_models import pretrained_settings

from ._base import EncoderMixin
from .swin_transformer_model import SwinTransformer


class SwinTransformerEncoder(SwinTransformer, EncoderMixin):
    def __init__(self, out_channels, depth=5, **kwargs):
        super().__init__(**kwargs)
        self._depth = depth
        self._out_channels = out_channels
        self._in_channels = 3

    def get_stages(self):
        return [nn.Identity()]
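
    # `feature_forward` runs the actual Swin pipeline: patch embedding, an
    # optional absolute position embedding, then the stacked transformer
    # layers, collecting a normalized NCHW feature map after every stage
    # listed in `self.out_indices`.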
    def feature_forward(self, x):
        x = self.patch_embed(x)
        Wh, Ww = x.size(2), x.size(3)
        if self.ape:
            # Interpolate the absolute position embedding to the current size.
            absolute_pos_embed = F.interpolate(self.absolute_pos_embed, size=(Wh, Ww), mode='bicubic')
            x = (x + absolute_pos_embed).flatten(2).transpose(1, 2)  # B Wh*Ww C
        else:
            x = x.flatten(2).transpose(1, 2)
        x = self.pos_drop(x)

        outs = []
        for i in range(self.num_layers):
            layer = self.layers[i]
            x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww)
            if i in self.out_indices:
                norm_layer = getattr(self, f'norm{i}')
                x_out = norm_layer(x_out)
                out = x_out.view(-1, H, W, self.num_features[i]).permute(0, 3, 1, 2).contiguous()
                outs.append(out)
        return outs

    def forward(self, x):
        stages = self.get_stages()

        features = []
        for stage in stages:
            x = stage(x)
            features.append(x)

        outs = self.feature_forward(x)
        # Note: an additional upsampled feature is inserted so that five-stage
        # decoders can be used; decoders with fewer stages simply ignore it.
        add_feature = F.interpolate(outs[0], scale_factor=2)
        features = features + [add_feature] + outs
        return features
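
    # For reference: with a 256x256 input and the Swin-T configuration used
    # below (embed_dim 96, and assuming the default patch size of 4), `forward`
    # is expected to return six feature maps:
    #   [B, 3, 256, 256]    raw input (identity stage)
    #   [B, 96, 128, 128]   upsampled copy of the stride-4 feature
    #   [B, 96, 64, 64]     stage 1 (stride 4)
    #   [B, 192, 32, 32]    stage 2 (stride 8)
    #   [B, 384, 16, 16]    stage 3 (stride 16)
    #   [B, 768, 8, 8]      stage 4 (stride 32)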

    def load_state_dict(self, state_dict, **kwargs):
        new_state_dict = OrderedDict()
        if 'state_dict' in state_dict:
            _state_dict = state_dict['state_dict']
        elif 'model' in state_dict:
            _state_dict = state_dict['model']
        else:
            _state_dict = state_dict
        for k, v in _state_dict.items():
            if k.startswith('backbone.'):
                new_state_dict[k[9:]] = v  # strip the 'backbone.' prefix
            else:
                new_state_dict[k] = v
        # Note: in the Swin segmentation model, `attn_mask` is no longer a class
        # attribute (to support multi-scale inputs), a norm layer is added for
        # each output, and the head layer is removed, so loading is non-strict.
        kwargs.update({'strict': False})
        super().load_state_dict(new_state_dict, **kwargs)
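

# A minimal checkpoint-loading sketch (hypothetical usage; it assumes only the
# standard PyTorch helper `torch.hub.load_state_dict_from_url` and the URLs in
# `new_settings` below):
#
#   import torch
#   weights = torch.hub.load_state_dict_from_url(
#       new_settings["Swin-T"]["imagenet"], map_location="cpu")
#   encoder = SwinTransformerEncoder(out_channels=(3, 96, 96, 192, 384, 768))
#   encoder.load_state_dict(weights)  # non-strict; see the note above
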
# Pretrained weights from https://github.com/microsoft/Swin-Transformer
new_settings = {
    "Swin-T": {
        "imagenet": "https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_tiny_patch4_window7_224.pth",
        "ADE20k": "https://github.com/SwinTransformer/storage/releases/download/v1.0.1/upernet_swin_tiny_patch4_window7_512x512.pth",
    },
    "Swin-S": {
        "imagenet": "https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_small_patch4_window7_224.pth",
        "ADE20k": "https://github.com/SwinTransformer/storage/releases/download/v1.0.1/upernet_swin_small_patch4_window7_512x512.pth",
    },
    "Swin-B": {
        "imagenet": "https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224.pth",
        "imagenet-22k": "https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224_22k.pth",
        "ADE20k": "https://github.com/SwinTransformer/storage/releases/download/v1.0.1/upernet_swin_base_patch4_window7_512x512.pth",
    },
    "Swin-L": {
        "imagenet-22k": "https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window7_224_22k.pth",
    },
}

# Merge the Swin sources into the `pretrained_settings` dict copied from
# pretrainedmodels, using the standard ImageNet preprocessing statistics.
pretrained_settings = deepcopy(pretrained_settings)
for model_name, sources in new_settings.items():
    if model_name not in pretrained_settings:
        pretrained_settings[model_name] = {}
    for source_name, source_url in sources.items():
        pretrained_settings[model_name][source_name] = {
            "url": source_url,
            "input_size": [3, 224, 224],
            "input_range": [0, 1],
            "mean": [0.485, 0.456, 0.406],
            "std": [0.229, 0.224, 0.225],
            "num_classes": 1000,
        }

swin_transformer_encoders = {
    "Swin-T": {
        "encoder": SwinTransformerEncoder,
        "pretrained_settings": pretrained_settings["Swin-T"],
        "params": {
            "embed_dim": 96,
            "out_channels": (3, 96, 96, 192, 384, 768),
            "depths": [2, 2, 6, 2],
            "num_heads": [3, 6, 12, 24],
            "window_size": 7,
            "ape": False,
            "drop_path_rate": 0.3,
            "patch_norm": True,
            "use_checkpoint": False,
        },
    },
    "Swin-S": {
        "encoder": SwinTransformerEncoder,
        "pretrained_settings": pretrained_settings["Swin-S"],
        "params": {
            "embed_dim": 96,
            "out_channels": (3, 96, 96, 192, 384, 768),
            "depths": [2, 2, 18, 2],
            "num_heads": [3, 6, 12, 24],
            "window_size": 7,
            "ape": False,
            "drop_path_rate": 0.3,
            "patch_norm": True,
            "use_checkpoint": False,
        },
    },
    "Swin-B": {
        "encoder": SwinTransformerEncoder,
        "pretrained_settings": pretrained_settings["Swin-B"],
        "params": {
            "embed_dim": 128,
            "out_channels": (3, 128, 128, 256, 512, 1024),
            "depths": [2, 2, 18, 2],
            "num_heads": [4, 8, 16, 32],
            "window_size": 7,
            "ape": False,
            "drop_path_rate": 0.3,
            "patch_norm": True,
            "use_checkpoint": False,
        },
    },
    "Swin-L": {
        "encoder": SwinTransformerEncoder,
        "pretrained_settings": pretrained_settings["Swin-L"],
        "params": {
            "embed_dim": 192,
            "out_channels": (3, 192, 192, 384, 768, 1536),
            "depths": [2, 2, 18, 2],
            "num_heads": [6, 12, 24, 48],
            "window_size": 7,
            "ape": False,
            "drop_path_rate": 0.3,
            "patch_norm": True,
            "use_checkpoint": False,
        },
    },
}
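
# Sketch of how a registry entry would typically be consumed (this mirrors the
# segmentation_models.pytorch encoder factory pattern; hypothetical usage):
#
#   entry = swin_transformer_encoders["Swin-T"]
#   encoder = entry["encoder"](**entry["params"])
#   weights_url = entry["pretrained_settings"]["imagenet"]["url"]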


if __name__ == "__main__":
    import torch

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    x = torch.randn(1, 3, 256, 256).to(device)
    # Smoke test: window_size=8 evenly divides every stage's feature grid for a
    # 256x256 input; `out_channels` matches the Swin-T entry above. The model
    # must be moved to the same device as the input.
    model = SwinTransformerEncoder(out_channels=(3, 96, 96, 192, 384, 768), window_size=8).to(device)
    res = model(x)
    for feature in res:
        print(feature.shape)