import os
import torch
from collections import OrderedDict
from abc import ABC, abstractmethod
from torch.optim import lr_scheduler


def get_scheduler(optimizer, opt):
    """Return a learning rate scheduler

    Parameters:
        optimizer          -- the optimizer of the network
        opt (option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions.
                              opt.lr_policy is the name of learning rate policy: linear | step | plateau | cosine

    For 'linear', we keep the same learning rate for the first <opt.niter> epochs
    and linearly decay the rate to zero over the next <opt.niter_decay> epochs.
    For other schedulers (step, plateau, and cosine), we use the default PyTorch schedulers.
    See https://pytorch.org/docs/stable/optim.html for more details.
    """
    if opt.lr_policy == 'linear':
        def lambda_rule(epoch):
            lr_l = 1.0 - max(0, epoch + opt.epoch_count - opt.niter) / float(opt.niter_decay + 1)
            return lr_l
        scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
    elif opt.lr_policy == 'step':
        scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_iters, gamma=0.1)
    elif opt.lr_policy == 'plateau':
        scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5)
    elif opt.lr_policy == 'cosine':
        scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=opt.niter, eta_min=0)
    else:
        raise NotImplementedError('learning rate policy [%s] is not implemented' % opt.lr_policy)
    return scheduler
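

# Usage sketch (hypothetical training loop; `opt` is assumed to come from a
# BaseOptions subclass providing lr, lr_policy, epoch_count, niter, niter_decay
# and lr_decay_iters):
#   optimizer = torch.optim.Adam(net.parameters(), lr=opt.lr)
#   scheduler = get_scheduler(optimizer, opt)
#   for epoch in range(opt.epoch_count, opt.niter + opt.niter_decay + 1):
#       train_one_epoch()  # user-defined
#       scheduler.step()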


class BaseModel(ABC):
    """This class is an abstract base class (ABC) for models.
    To create a subclass, you need to implement the following five functions:
        -- <__init__>: initialize the class; first call BaseModel.__init__(self, opt).
        -- <set_input>: unpack data from dataset and apply preprocessing.
        -- <forward>: produce intermediate results.
        -- <optimize_parameters>: calculate losses, gradients, and update network weights.
        -- <modify_commandline_options>: (optionally) add model-specific options and set default options.
    """

    def __init__(self, opt):
        """Initialize the BaseModel class.

        Parameters:
            opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions

        When creating your custom class, you need to implement your own initialization.
        In this function, you should first call <BaseModel.__init__(self, opt)>
        Then, you need to define four lists:
            -- self.loss_names (str list): specify the training losses that you want to plot and save.
            -- self.model_names (str list): define networks used in our training.
            -- self.visual_names (str list): specify the images that you want to display and save.
            -- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example.
        """
        self.opt = opt
        self.gpu_ids = opt.gpu_ids
        self.isTrain = opt.isTrain
        self.device = torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu')  # get device name: CPU or GPU
        self.save_dir = os.path.join(opt.checkpoints_dir, opt.name)  # save all the checkpoints to save_dir
        if opt.preprocess != 'scale_width':  # with [scale_width], input images might have different sizes, which hurts the performance of cudnn.benchmark.
            torch.backends.cudnn.benchmark = True
        self.loss_names = []
        self.model_names = []
        self.visual_names = []
        self.visual_features = []
        self.optimizers = []
        self.image_paths = []
        self.metric = 0  # used for learning rate policy 'plateau'
        self.istest = opt.phase == 'test'  # in the test phase there are no labeled samples

    @staticmethod
    def modify_commandline_options(parser, is_train):
        """Add new model-specific options, and rewrite default values for existing options.

        Parameters:
            parser          -- original option parser
            is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.

        Returns:
            the modified parser.
        """
        return parser

    @abstractmethod
    def set_input(self, input):
        """Unpack input data from the dataloader and perform necessary pre-processing steps.

        Parameters:
            input (dict): includes the data itself and its metadata information.
        """
        pass

    @abstractmethod
    def forward(self):
        """Run forward pass; called by both functions <optimize_parameters> and <test>."""
        pass

    @abstractmethod
    def optimize_parameters(self):
        """Calculate losses, gradients, and update network weights; called in every training iteration"""
        pass

    def setup(self, opt):
        """Load and print networks; create schedulers

        Parameters:
            opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
        """
        if self.isTrain:
            self.schedulers = [get_scheduler(optimizer, opt) for optimizer in self.optimizers]
        if not self.isTrain or opt.continue_train:
            load_suffix = 'iter_%d' % opt.load_iter if opt.load_iter > 0 else opt.epoch
            self.load_networks(load_suffix)
        self.print_networks(opt.verbose)

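    # Typical call order (a sketch; MyModel is a hypothetical subclass):
    #   model = MyModel(opt)
    #   model.setup(opt)  # creates schedulers and/or loads checkpoints
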
    def eval(self):
        """Make models eval mode during test time"""
        for name in self.model_names:
            if isinstance(name, str):
                net = getattr(self, 'net' + name)
                net.eval()

    def train(self):
        """Make models train mode during train time"""
        for name in self.model_names:
            if isinstance(name, str):
                net = getattr(self, 'net' + name)
                net.train()

    def test(self):
        """Forward function used in test time.

        This function wraps <forward> function in no_grad() so we don't save intermediate steps for backprop
        It also calls <compute_visuals> to produce additional visualization results
        """
        with torch.no_grad():
            self.forward()
            self.compute_visuals()

    def compute_visuals(self):
        """Calculate additional output images for visdom and HTML visualization"""
        pass

    def get_image_paths(self):
        """Return image paths that are used to load current data"""
        return self.image_paths

    def update_learning_rate(self):
        """Update learning rates for all the networks; called at the end of every epoch"""
        for scheduler in self.schedulers:
            if self.opt.lr_policy == 'plateau':
                scheduler.step(self.metric)
            else:
                scheduler.step()

        lr = self.optimizers[0].param_groups[0]['lr']
        print('learning rate = %.7f' % lr)

    def get_current_visuals(self):
        """Return visualization images. train.py will display these images with visdom, and save them to an HTML file"""
        visual_ret = OrderedDict()
        for name in self.visual_names:
            if isinstance(name, str):
                visual_ret[name] = getattr(self, name)
        return visual_ret

    def get_current_losses(self):
        """Return training losses / errors. train.py will print out these errors on console, and save them to a file"""
        errors_ret = OrderedDict()
        for name in self.loss_names:
            if isinstance(name, str):
                errors_ret[name] = float(getattr(self, 'loss_' + name))  # float(...) works for both scalar tensor and float number
        return errors_ret

    def save_networks(self, epoch):
        """Save all the networks to the disk.

        Parameters:
            epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name)
        """
        for name in self.model_names:
            if isinstance(name, str):
                save_filename = '%s_net_%s.pth' % (epoch, name)
                save_path = os.path.join(self.save_dir, save_filename)
                net = getattr(self, 'net' + name)

                if len(self.gpu_ids) > 0 and torch.cuda.is_available():
                    # torch.save(net.module.cpu().state_dict(), save_path)
                    torch.save(net.cpu().state_dict(), save_path)
                    net.cuda(self.gpu_ids[0])
                else:
                    torch.save(net.cpu().state_dict(), save_path)

    def __patch_instance_norm_state_dict(self, state_dict, module, keys, i=0):
        """Fix InstanceNorm checkpoints incompatibility (prior to 0.4)"""
        key = keys[i]
        if i + 1 == len(keys):  # at the end, pointing to a parameter/buffer
            if module.__class__.__name__.startswith('InstanceNorm') and \
                    (key == 'running_mean' or key == 'running_var'):
                if getattr(module, key) is None:
                    state_dict.pop('.'.join(keys))
            if module.__class__.__name__.startswith('InstanceNorm') and \
                    (key == 'num_batches_tracked'):
                state_dict.pop('.'.join(keys))
        else:
            self.__patch_instance_norm_state_dict(state_dict, getattr(module, key), keys, i + 1)

    def get_visual(self, name):
        """Return a single visualization tensor by attribute name"""
        visual_ret = {}
        visual_ret[name] = getattr(self, name)
        return visual_ret

    def pred_large(self, A, B, input_size=256, stride=0):
        """Predict on a pair of large bi-temporal images with a sliding window.

        The center part of each patch prediction is assumed to be accurate;
        the edge padding is (input_size - stride) / 2.
        :param A: tensor, N*C*H*W, pre-change image
        :param B: tensor, N*C*H*W, post-change image
        :param input_size: int, size of the patches fed to the network
        :param stride: int, stride of the sliding window (must be positive)
        :return: pred, tensor, N*1*H*W
        """
        import math
        n, c, h, w = A.shape
        assert A.shape == B.shape
        assert stride > 0, 'stride must be a positive integer'
        # number of tiles along each axis
        n_h = math.ceil((h - input_size) / stride) + 1
        n_w = math.ceil((w - input_size) / stride) + 1
        # recompute the padded height and width
        new_h = (n_h - 1) * stride + input_size
        new_w = (n_w - 1) * stride + input_size
        print("new_h: ", new_h)
        print("new_w: ", new_w)
        print("n_h: ", n_h)
        print("n_w: ", n_w)
        new_A = torch.zeros([n, c, new_h, new_w], dtype=torch.float32)
        new_B = torch.zeros([n, c, new_h, new_w], dtype=torch.float32)
        new_A[:, :, :h, :w] = A
        new_B[:, :, :h, :w] = B
        new_pred = torch.zeros([n, 1, new_h, new_w], dtype=torch.uint8)
        del A
        del B

        for i in range(0, new_h - input_size + 1, stride):
            for j in range(0, new_w - input_size + 1, stride):
                left = j
                right = input_size + j
                top = i
                bottom = input_size + i
                patch_A = new_A[:, :, top:bottom, left:right]
                patch_B = new_B[:, :, top:bottom, left:right]
                # print(left, ' ', right, ' ', top, ' ', bottom)
                self.A = patch_A.to(self.device)
                self.B = patch_B.to(self.device)
                with torch.no_grad():
                    patch_pred = self.forward()
                new_pred[:, :, top:bottom, left:right] = patch_pred.detach().cpu()
        pred = new_pred[:, :, :h, :w]
        return pred

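    # Usage sketch (hypothetical; assumes a subclass whose forward() returns the
    # N*1*H*W prediction for self.A / self.B):
    #   pred = model.pred_large(A, B, input_size=256, stride=128)
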
    def load_networks(self, epoch):
        """Load all the networks from the disk.

        Parameters:
            epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name)
        """
        for name in self.model_names:
            if isinstance(name, str):
                load_filename = '%s_net_%s.pth' % (epoch, name)
                load_path = os.path.join(self.save_dir, load_filename)
                net = getattr(self, 'net' + name)
                # if isinstance(net, torch.nn.DataParallel):
                #     net = net.module
                # net = net.module  # adapt to checkpoints saved from a wrapped module
                print('loading the model from %s' % load_path)
                # if you are using PyTorch newer than 0.4 (e.g., built from
                # GitHub source), you can remove str() on self.device
                state_dict = torch.load(load_path, map_location=str(self.device))

                if hasattr(state_dict, '_metadata'):
                    del state_dict._metadata
                # patch InstanceNorm checkpoints prior to 0.4
                for key in list(state_dict.keys()):  # need to copy keys here because we mutate in loop
                    self.__patch_instance_norm_state_dict(state_dict, net, key.split('.'))
                    # print(key)
                # strict=False tolerates missing/unexpected keys, so partially
                # matching checkpoints load without raising
                net.load_state_dict(state_dict, strict=False)

    def print_networks(self, verbose):
        """Print the total number of parameters in the network and (if verbose) network architecture

        Parameters:
            verbose (bool) -- if verbose: print the network architecture
        """
        print('---------- Networks initialized -------------')
        for name in self.model_names:
            if isinstance(name, str):
                net = getattr(self, 'net' + name)
                num_params = 0
                for param in net.parameters():
                    num_params += param.numel()
                if verbose:
                    print(net)
                print('[Network %s] Total number of parameters : %.3f M' % (name, num_params / 1e6))
        print('-----------------------------------------------')

    def set_requires_grad(self, nets, requires_grad=False):
        """Set requires_grad=False for all the networks to avoid unnecessary computations
        Parameters:
            nets (network list)   -- a list of networks
            requires_grad (bool)  -- whether the networks require gradients or not
        """
        if not isinstance(nets, list):
            nets = [nets]
        for net in nets:
            if net is not None:
                for param in net.parameters():
                    param.requires_grad = requires_grad
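

# Usage sketch (hypothetical names): freeze a discriminator while the generator
# is updated, as in CycleGAN-style training loops:
#   model.set_requires_grad(model.netD, False)  # D needs no gradients when optimizing G
#   ...update G...
#   model.set_requires_grad(model.netD, True)

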
if __name__ == '__main__':

    A = torch.rand([1, 3, 512, 512], dtype=torch.float32)
    B = torch.rand([1, 3, 512, 512], dtype=torch.float32)
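    # Smoke test for the tiling arithmetic used by pred_large (a sketch only:
    # BaseModel is abstract, so the full sliding-window prediction needs a subclass).
    import math
    input_size, stride = 256, 128
    n, c, h, w = A.shape
    n_h = math.ceil((h - input_size) / stride) + 1
    n_w = math.ceil((w - input_size) / stride) + 1
    new_h = (n_h - 1) * stride + input_size
    new_w = (n_w - 1) * stride + input_size
    assert new_h >= h and new_w >= w
    print('tiles: %d x %d, padded size: %d x %d' % (n_h, n_w, new_h, new_w))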