This commit is contained in:
copper 2023-05-23 11:49:52 +08:00
parent 804a3064f6
commit dab53f5dfa
18 changed files with 740 additions and 488 deletions

239
env.yaml Normal file
View File

@@ -0,0 +1,239 @@
name: cveo_ss
channels:
- conda-forge
- defaults
dependencies:
- appdirs=1.4.4=pyh9f0ad1d_0
- astroid=2.8.6=py37h03978a9_1
- blosc=1.21.1=h74325e0_3
- boost-cpp=1.78.0=h9f4b32c_1
- brotli=1.0.9=h8ffe710_7
- brotli-bin=1.0.9=h8ffe710_7
- brotlipy=0.7.0=py37hcc03f2d_1004
- bzip2=1.0.8=h8ffe710_4
- c-blosc2=2.3.1=hdf67494_0
- ca-certificates=2022.6.15.1=h5b45459_0
- cairo=1.16.0=hb19e0ff_1008
- certifi=2022.6.15.1=pyhd8ed1ab_0
- cffi=1.15.1=py37hd8e9650_0
- cfitsio=3.470=h0af3d06_7
- charls=2.2.0=h39d44d4_0
- charset-normalizer=2.1.1=pyhd8ed1ab_0
- cloudpickle=2.2.0=pyhd8ed1ab_0
- colorama=0.4.5=pyhd8ed1ab_0
- coverage=6.4.4=py37hcc03f2d_0
- cryptography=3.4.7=py37h20c650d_0
- curl=7.83.1=h789b8ee_0
- cycler=0.11.0=pyhd8ed1ab_0
- cytoolz=0.12.0=py37hcc03f2d_0
- dask-core=2021.9.1=pyhd8ed1ab_0
- decorator=5.1.0=pyhd8ed1ab_0
- draco=1.5.3=h5362a0b_0
- exiv2=0.27.5=h02b4549_0
- expat=2.4.8=h39d44d4_0
- fontconfig=2.14.0=hce3cb01_0
- fonttools=4.37.1=py37hcc03f2d_0
- freeglut=3.2.2=h0e60522_1
- freetype=2.12.1=h546665d_0
- freexl=1.0.6=ha8e266a_0
- fsspec=2022.8.2=pyhd8ed1ab_0
- future=0.18.2=py37h03978a9_3
- gdal=3.3.1=py37hb11e9a8_2
- geos=3.9.1=h39d44d4_2
- geotiff=1.6.0=ha8a8a2d_6
- gettext=0.19.8.1=ha2e2712_1008
- giflib=5.2.1=h8d14728_2
- gsl=2.7=hdfb1a43_0
- hdf4=4.2.15=h0e5069d_4
- hdf5=1.10.6=nompi_h5268f04_1114
- httplib2=0.20.4=pyhd8ed1ab_0
- icu=68.2=h0e60522_0
- idna=3.3=pyhd8ed1ab_0
- imagecodecs=2021.8.26=py37h91eda04_1
- imageio=2.21.3=pyhfa7a67d_0
- intel-openmp=2022.1.0=h57928b3_3787
- isort=5.10.1=pyhd8ed1ab_0
- jasper=2.0.33=h77af90b_0
- jbig=2.1=h8d14728_2003
- jinja2=3.1.2=pyhd8ed1ab_1
- jpeg=9e=h8ffe710_2
- jsoncpp=1.9.4=h2d74725_3
- jxrlib=1.1=h8ffe710_2
- kealib=1.4.14=h96bfa42_2
- kiwisolver=1.4.4=py37h8c56517_0
- krb5=1.19.3=h1176d77_0
- laszip=3.4.3=h6538335_1
- laz-perf=3.0.0=h2d74725_0
- lazy-object-proxy=1.7.1=py37hcc03f2d_1
- lcms2=2.12=h2a16943_0
- lerc=3.0=h0e60522_0
- libaec=1.0.6=h39d44d4_0
- libblas=3.9.0=16_win64_mkl
- libbrotlicommon=1.0.9=h8ffe710_7
- libbrotlidec=1.0.9=h8ffe710_7
- libbrotlienc=1.0.9=h8ffe710_7
- libcblas=3.9.0=16_win64_mkl
- libclang=11.1.0=default_h5c34c98_1
- libcurl=7.83.1=h789b8ee_0
- libdeflate=1.8=h8ffe710_0
- libffi=3.4.2=h8ffe710_5
- libgdal=3.3.1=h7e75cf7_2
- libglib=2.72.1=h3be07f2_0
- libiconv=1.16=he774522_0
- libkml=1.3.0=hf2ab4e4_1015
- liblapack=3.9.0=16_win64_mkl
- liblapacke=3.9.0=16_win64_mkl
- libnetcdf=4.8.1=nompi_hf689e7d_100
- libopencv=4.5.3=py37h6700db3_1
- libpng=1.6.37=h1d00b33_4
- libpq=13.5=hfcc5ef8_1
- libprotobuf=3.16.0=h7755175_0
- librttopo=1.1.0=hb340de5_6
- libspatialindex=1.9.3=h39d44d4_4
- libspatialite=5.0.1=h762a7f4_6
- libsqlite=3.39.3=hcfcfb64_0
- libssh2=1.10.0=h680486a_3
- libtiff=4.3.0=hd413186_2
- libwebp=1.2.4=h8ffe710_0
- libwebp-base=1.2.4=h8ffe710_0
- libxcb=1.13=hcd874cb_1004
- libxml2=2.9.14=hf5bbc77_4
- libxslt=1.1.35=h34f844d_0
- libzip=1.9.2=hfed4ece_1
- libzlib=1.2.12=hcfcfb64_3
- libzopfli=1.0.3=h0e60522_0
- locket=1.0.0=pyhd8ed1ab_0
- lz4-c=1.9.3=h8ffe710_1
- m2w64-gcc-libgfortran=5.3.0=6
- m2w64-gcc-libs=5.3.0=7
- m2w64-gcc-libs-core=5.3.0=7
- m2w64-gmp=6.1.0=2
- m2w64-libwinpthread-git=5.0.0.4634.697f757=2
- markupsafe=2.1.1=py37hcc03f2d_1
- matplotlib-base=3.5.3=py37hbaab90a_2
- mccabe=0.6.1=py_1
- mkl=2022.1.0=h6a75c08_874
- mock=4.0.3=py37h03978a9_3
- msys2-conda-epoch=20160418=1
- munkres=1.1.4=pyh9f0ad1d_0
- networkx=2.7.1=pyhd8ed1ab_0
- nitro=2.7.dev6=h39d44d4_5
- nose2=0.9.2=py_0
- numpy=1.21.2=py37h940b05c_0
- opencv=4.5.3=py37h03978a9_1
- openjpeg=2.4.0=hb211442_1
- openssl=1.1.1q=h8ffe710_0
- owslib=0.27.2=pyhd8ed1ab_1
- packaging=21.3=pyhd8ed1ab_0
- pandas=1.3.5=py37h9386db6_0
- partd=1.3.0=pyhd8ed1ab_0
- pcre=8.45=h0e60522_0
- pdal=2.3.0=hde8ebe7_6
- pip=21.2.4=pyhd8ed1ab_0
- pixman=0.40.0=h8ffe710_0
- pkg-config=0.29.2=h2bf4dc2_1008
- pkgconfig=1.5.5=py37h03978a9_2
- platformdirs=2.5.2=pyhd8ed1ab_1
- plotly=5.10.0=pyhd8ed1ab_0
- pooch=1.6.0=pyhd8ed1ab_0
- poppler=21.03.0=h9ff6ed8_0
- poppler-data=0.4.11=hd8ed1ab_0
- postgresql=13.5=h1c22c4f_1
- proj=8.0.1=h1cfcee9_0
- psycopg2=2.9.2=py37hd8e9650_0
- pthread-stubs=0.4=hcd874cb_1001
- py-opencv=4.5.3=py37h4038f58_1
- pycparser=2.20=pyh9f0ad1d_2
- pygments=2.10.0=pyhd8ed1ab_0
- pylint=2.11.1=pyhd8ed1ab_0
- pyopenssl=21.0.0=pyhd8ed1ab_0
- pyparsing=3.0.9=pyhd8ed1ab_0
- pyproj=3.2.1=py37h9f67652_0
- pyqt=5.12.3=py37h03978a9_8
- pyqt-impl=5.12.3=py37hf2a7229_8
- pyqt5-sip=4.19.18=py37hf2a7229_8
- pyqtads=3.8.2=py37hf2a7229_0
- pyqtchart=5.12=py37hf2a7229_8
- pyqtwebengine=5.12.1=py37hf2a7229_8
- pyqtwebkit=5.212=py37h9e7b984_2
- pysocks=1.7.1=py37h03978a9_5
- python=3.7.10=h7840368_101_cpython
- python-dateutil=2.8.2=pyhd8ed1ab_0
- python_abi=3.7=2_cp37m
- pytz=2021.1=pyhd8ed1ab_0
- pywavelets=1.1.1=py37hec80d1f_3
- pywin32=303=py37hcc03f2d_0
- pyyaml=5.4.1=py37hcc03f2d_1
- qca=2.2.1=hd7ce7fb_3
- qgis=3.18.3=py37h3dc7164_2
- qjson=0.9.0=hd7ce7fb_1006
- qscintilla2=2.11.2=py37h9e7b984_6
- qt=5.12.9=h5909a2a_4
- qtkeychain=0.12.0=h552f0f6_0
- qtlocation=5.12.9=ha925a31_0
- qtwebkit=5.212=hb258684_1
- qwt=6.1.6=h552f0f6_0
- qwtpolar=1.1.1=hd7ce7fb_7
- requests=2.28.1=pyhd8ed1ab_1
- scikit-image=0.18.3=py37h9386db6_0
- scipy=1.7.1=py37hb6553fb_0
- setuptools=59.8.0=py37h03978a9_1
- six=1.16.0=pyh6c4a22f_0
- snappy=1.1.9=h82413e6_1
- sqlite=3.39.3=hcfcfb64_0
- tbb=2021.5.0=h91493d7_2
- tenacity=8.0.1=pyhd8ed1ab_0
- tifffile=2021.11.2=pyhd8ed1ab_0
- tiledb=2.3.4=h78dabda_0
- tk=8.6.12=h8ffe710_0
- toml=0.10.2=pyhd8ed1ab_0
- tomli=2.0.1=pyhd8ed1ab_0
- toolz=0.12.0=pyhd8ed1ab_0
- typed-ast=1.5.4=py37hcc03f2d_0
- typing-extensions=4.3.0=hd8ed1ab_0
- typing_extensions=4.3.0=pyha770c72_0
- ucrt=10.0.20348.0=h57928b3_0
- unicodedata2=14.0.0=py37hcc03f2d_1
- urllib3=1.26.11=pyhd8ed1ab_0
- vc=14.2=hb210afc_7
- vs2015_runtime=14.29.30139=h890b9b1_7
- wheel=0.37.1=pyhd8ed1ab_0
- win_inet_pton=1.1.0=py37h03978a9_4
- wrapt=1.13.3=py37hcc03f2d_1
- xerces-c=3.2.3=h0e60522_5
- xorg-libxau=1.0.9=hcd874cb_0
- xorg-libxdmcp=1.1.3=hcd874cb_0
- xz=5.2.6=h8d14728_0
- yaml=0.2.5=he774522_0
- zfp=0.5.5=h0e60522_8
- zlib=1.2.12=hcfcfb64_3
- zlib-ng=2.0.6=h8ffe710_0
- zstd=1.5.2=h7755175_4
- pip:
  - attrs==21.4.0
  - autopep8==2.0.0
  - conda-pack==0.6.0
  - cython==0.29.24
  - efficientnet-pytorch==0.7.1
  - filelock==3.12.0
  - huggingface-hub==0.14.1
  - importlib-metadata==4.8.1
  - joblib==1.1.0
  - munch==2.5.0
  - nuitka==0.8.3
  - opencv-python==4.5.3.56
  - ordered-set==4.1.0
  - pathlib==1.0.1
  - pillow==6.2.2
  - pretrainedmodels==0.7.4
  - pycodestyle==2.9.1
  - pycryptodome==3.14.1
  - rios==0.0.0.0.dev20200902
  - scikit-learn==1.0.2
  - sklearn==0.0
  - threadpoolctl==3.1.0
  - timm==0.6.13
  - torch==1.13.1
  - torchvision==0.14.1
  - tqdm==4.65.0
  - zipp==3.8.1
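
For anyone recreating this environment, the file feeds straight into conda; a minimal sketch, assuming conda is on PATH and this env.yaml sits in the working directory:

import subprocess

# Create the `cveo_ss` environment defined above, then smoke-test the torch install.
subprocess.run(["conda", "env", "create", "-f", "env.yaml"], check=True)
subprocess.run(["conda", "run", "-n", "cveo_ss", "python", "-c",
                "import torch; print(torch.__version__)"], check=True)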

View File

@@ -1 +1 @@
-vd4FiYncytyziGH9GNCAA8hGGr1/79Xmphtc5+PHPJDpxvqj1hP7+985QMojYO4M5Qn/aqEAvFgeDN3CA8x1YAK8SdCgSXSBJpRBK8wqPQjBY1ak96QfdPCrTLunr+xuPxK3Gxe772adTTsee2+ot7WePYUsC4y4NcS5+rlP1if87xtYqVeSwx3c64cOmAGP
+IaqFuRlbPMtYTReB0p+cxn8sffVeOjbq+d46I2texZIVIeLbwCxJ7w3mqlramQy3p0totEfoSkjIrQV1GjtrOOWsuYRqv5ZZ5A+/PdTd7ZU8WlMAl7sknJGJFWvciG1VL9n9XtJUG+CJg4oLYkdwR5WePYUsC4y4NcS5+rlP1if87xtYqVeSwx3c64cOmAGP

View File

@@ -8,7 +8,7 @@ from rscder.utils.project import Project
from rscder.utils.geomath import geo2imageRC, imageRC2geo
import math
from .packages import get_model
import numpy as np

class BasicAICD(AlgFrontend):
    @staticmethod
@@ -62,9 +62,9 @@ class BasicAICD(AlgFrontend):
        for i in range(xblocks + 1):
-            block_xy1 = (start1x + i * cell_size[0], start1y + j * cell_size[1])
-            block_xy2 = (start2x + i * cell_size[0], start2y + j * cell_size[1])
-            block_xy = (i * cell_size[0], j * cell_size[1])
+            block_xy1 = [start1x + i * cell_size[0], start1y + j * cell_size[1]]
+            block_xy2 = [start2x + i * cell_size[0], start2y + j * cell_size[1]]
+            block_xy = [i * cell_size[0], j * cell_size[1]]
            if block_xy1[1] > end1y or block_xy2[1] > end2y:
                break
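
The hunk above walks the two rasters block by block; switching the coordinates from tuples to lists presumably lets later code clamp them in place at the image border (an assumption — that code is outside this hunk). A standalone sketch of the same tiling pattern, with all sizes illustrative:

# Illustrative tiling of a 1000x800 raster into 256x256 cells (hypothetical sizes).
cell_size = (256, 256)
width, height = 1000, 800
xblocks = width // cell_size[0]
yblocks = height // cell_size[1]
for j in range(yblocks + 1):
    for i in range(xblocks + 1):
        block_xy = [i * cell_size[0], j * cell_size[1]]  # mutable, so it can be clamped later
        block_w = min(cell_size[0], width - block_xy[0])
        block_h = min(cell_size[1], height - block_xy[1])
        if block_w <= 0 or block_h <= 0:
            continue  # cell falls entirely outside the raster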

View File

@@ -1,39 +1,39 @@
-from .models import create_model
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-import numpy as np
-
-class Front:
-    def __init__(self, model) -> None:
-        self.model = model
-        self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
-        self.model = model.to(self.device)
-
-    def __call__(self, inp1, inp2):
-        inp1 = torch.from_numpy(inp1).to(self.device)
-        inp1 = torch.transpose(inp1, 0, 2).transpose(1, 2).unsqueeze(0)
-        inp2 = torch.from_numpy(inp2).to(self.device)
-        inp2 = torch.transpose(inp2, 0, 2).transpose(1, 2).unsqueeze(0)
-        out = self.model(inp1, inp2)
-        out = out.sigmoid()
-        out = out.cpu().detach().numpy()[0, 0]
-        return out
-
-def get_model(name):
-    try:
-        model = create_model(name, encoder_name="resnet34",  # choose encoder, e.g. mobilenet_v2 or efficientnet-b7
-                             encoder_weights="imagenet",  # use `imagenet` pre-trained weights for encoder initialization
-                             in_channels=3,  # model input channels (1 for gray-scale images, 3 for RGB, etc.)
-                             classes=1,  # model output channels (number of classes in your dataset)
-                             siam_encoder=True,  # whether to use a siamese encoder
-                             fusion_form='concat',  # how to fuse features from the two branches: concat, sum, diff, or abs_diff
-                             )
-    except:
-        return None
+from .models import create_model
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import numpy as np
+
+class Front:
+    def __init__(self, model) -> None:
+        self.model = model
+        self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
+        self.model = model.to(self.device)
+
+    def __call__(self, inp1, inp2):
+        inp1 = torch.from_numpy(inp1).float().to(self.device) / 255.0
+        inp1 = inp1.unsqueeze(0)
+        inp2 = torch.from_numpy(inp2).float().to(self.device) / 255.0
+        inp2 = inp2.unsqueeze(0)
+        out = self.model(inp1, inp2)
+        out = out.sigmoid()
+        out = out.cpu().detach().numpy()[0, 0]
+        return out
+
+def get_model(name):
+    try:
+        model = create_model(name, encoder_name="resnet34",  # choose encoder, e.g. mobilenet_v2 or efficientnet-b7
+                             encoder_weights="imagenet",  # use `imagenet` pre-trained weights for encoder initialization
+                             in_channels=3,  # model input channels (1 for gray-scale images, 3 for RGB, etc.)
+                             classes=1,  # model output channels (number of classes in your dataset)
+                             siam_encoder=True,  # whether to use a siamese encoder
+                             fusion_form='concat',  # how to fuse features from the two branches: concat, sum, diff, or abs_diff
+                             )
+    except:
+        return None
    return Front(model)
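
A minimal usage sketch of the rewritten wrapper. Note the new __call__ no longer transposes, so inputs are assumed to already be channel-first (C, H, W) uint8 arrays; shapes here are illustrative:

import numpy as np

# Hypothetical co-registered 3-band patches, channel-first, values 0-255.
inp1 = np.random.randint(0, 256, (3, 256, 256), dtype=np.uint8)
inp2 = np.random.randint(0, 256, (3, 256, 256), dtype=np.uint8)

front = get_model('dvca')      # one of the architectures registered in create_model
if front is not None:          # get_model returns None when construction fails
    prob = front(inp1, inp2)   # (256, 256) array of sigmoid change scores in [0, 1]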

View File

@@ -48,8 +48,8 @@ class DVCADecoder(Decoder):
            in_channels = in_channels * 2
        self.aspp = nn.Sequential(
-            ASPP(in_channels, out_channels, atrous_rates),
-            nn.Conv2d(out_channels, out_channels, 3, padding=1, bias=False),
+            # ASPP(in_channels, out_channels, atrous_rates),
+            nn.Conv2d(in_channels, out_channels, 3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        )
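
With the ASPP stage commented out, the 3x3 convolution now receives the fused encoder features directly, which is why its first argument changed from out_channels to in_channels. A quick shape check of the revised head, with illustrative channel counts:

import torch
import torch.nn as nn

in_channels, out_channels = 512, 256  # illustrative; real values depend on the encoder and fusion_form
head = nn.Sequential(
    nn.Conv2d(in_channels, out_channels, 3, padding=1, bias=False),
    nn.BatchNorm2d(out_channels),
    nn.ReLU(),
)
x = torch.randn(1, in_channels, 32, 32)  # stand-in for the concat-fused encoder features
print(head(x).shape)                     # torch.Size([1, 256, 32, 32])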

View File

@@ -1,42 +1,42 @@
-from .RCNN import RCNN
-from .DVCA import DVCA
-from .DPFCN import DPFCN
-from . import encoders
-from . import utils
-from . import losses
-from . import datasets
-from .__version__ import __version__
-
-from typing import Optional
-import torch
-
-def create_model(
-    arch: str,
-    encoder_name: str = "resnet34",
-    encoder_weights: Optional[str] = "imagenet",
-    in_channels: int = 3,
-    classes: int = 1,
-    **kwargs,
-) -> torch.nn.Module:
-    """Model wrapper: allows creating any supported model just from its name and parameters."""
-    archs = [DVCA, DPFCN, RCNN]
-    archs_dict = {a.__name__.lower(): a for a in archs}
-    try:
-        model_class = archs_dict[arch.lower()]
-    except KeyError:
-        raise KeyError("Wrong architecture type `{}`. Available options are: {}".format(
-            arch, list(archs_dict.keys()),
-        ))
-    return model_class(
-        encoder_name=encoder_name,
-        encoder_weights=encoder_weights,
-        in_channels=in_channels,
-        classes=classes,
-        **kwargs,
-    )
+from .RCNN import RCNN
+from .DVCA import DVCA
+from .DPFCN import DPFCN
+from . import encoders
+# from . import utils
+# from . import losses
+# from . import datasets
+# from .__version__ import __version__
+
+from typing import Optional
+import torch
+
+def create_model(
+    arch: str,
+    encoder_name: str = "resnet34",
+    encoder_weights: Optional[str] = "imagenet",
+    in_channels: int = 3,
+    classes: int = 1,
+    **kwargs,
+) -> torch.nn.Module:
+    """Model wrapper: allows creating any supported model just from its name and parameters."""
+    archs = [DVCA, DPFCN, RCNN]
+    archs_dict = {a.__name__.lower(): a for a in archs}
+    try:
+        model_class = archs_dict[arch.lower()]
+    except KeyError:
+        raise KeyError("Wrong architecture type `{}`. Available options are: {}".format(
+            arch, list(archs_dict.keys()),
+        ))
+    return model_class(
+        encoder_name=encoder_name,
+        encoder_weights=encoder_weights,
+        in_channels=in_channels,
+        classes=classes,
+        **kwargs,
+    )
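
A hedged usage sketch of the wrapper; the architecture and encoder names come from this file, everything else is illustrative:

import torch

model = create_model(
    "dvca",                  # case-insensitive: dvca, dpfcn, or rcnn
    encoder_name="resnet34",
    encoder_weights=None,    # skip the pretrained download for a quick smoke test
    in_channels=3,
    classes=1,
)
t1 = torch.randn(1, 3, 256, 256)
t2 = torch.randn(1, 3, 256, 256)
out = model(t1, t2)          # expected (1, 1, 256, 256) for a two-input change model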

View File

@@ -1,113 +1,113 @@
-import functools
-import torch
-import torch.utils.model_zoo as model_zoo
-
-from ._preprocessing import preprocess_input
-from .densenet import densenet_encoders
-from .dpn import dpn_encoders
-from .efficientnet import efficient_net_encoders
-from .inceptionresnetv2 import inceptionresnetv2_encoders
-from .inceptionv4 import inceptionv4_encoders
-from .mobilenet import mobilenet_encoders
-from .resnet import resnet_encoders
-from .senet import senet_encoders
-from .timm_efficientnet import timm_efficientnet_encoders
-from .timm_gernet import timm_gernet_encoders
-from .timm_mobilenetv3 import timm_mobilenetv3_encoders
-from .timm_regnet import timm_regnet_encoders
-from .timm_res2net import timm_res2net_encoders
-from .timm_resnest import timm_resnest_encoders
-from .timm_sknet import timm_sknet_encoders
-from .timm_universal import TimmUniversalEncoder
-from .vgg import vgg_encoders
-from .xception import xception_encoders
-from .swin_transformer import swin_transformer_encoders
-from .mit_encoder import mit_encoders
-# from .hrnet import hrnet_encoders
-
-DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
-
-encoders = {}
-encoders.update(resnet_encoders)
-encoders.update(dpn_encoders)
-encoders.update(vgg_encoders)
-encoders.update(senet_encoders)
-encoders.update(densenet_encoders)
-encoders.update(inceptionresnetv2_encoders)
-encoders.update(inceptionv4_encoders)
-encoders.update(efficient_net_encoders)
-encoders.update(mobilenet_encoders)
-encoders.update(xception_encoders)
-encoders.update(timm_efficientnet_encoders)
-encoders.update(timm_resnest_encoders)
-encoders.update(timm_res2net_encoders)
-encoders.update(timm_regnet_encoders)
-encoders.update(timm_sknet_encoders)
-encoders.update(timm_mobilenetv3_encoders)
-encoders.update(timm_gernet_encoders)
-encoders.update(swin_transformer_encoders)
-encoders.update(mit_encoders)
-# encoders.update(hrnet_encoders)
-
-def get_encoder(name, in_channels=3, depth=5, weights=None, output_stride=32, **kwargs):
-    if name.startswith("tu-"):
-        name = name[3:]
-        encoder = TimmUniversalEncoder(
-            name=name,
-            in_channels=in_channels,
-            depth=depth,
-            output_stride=output_stride,
-            pretrained=weights is not None,
-            **kwargs
-        )
-        return encoder
-
-    try:
-        Encoder = encoders[name]["encoder"]
-    except KeyError:
-        raise KeyError("Wrong encoder name `{}`, supported encoders: {}".format(name, list(encoders.keys())))
-
-    params = encoders[name]["params"]
-    params.update(depth=depth)
-    encoder = Encoder(**params)
-
-    if weights is not None:
-        try:
-            settings = encoders[name]["pretrained_settings"][weights]
-        except KeyError:
-            raise KeyError("Wrong pretrained weights `{}` for encoder `{}`. Available options are: {}".format(
-                weights, name, list(encoders[name]["pretrained_settings"].keys()),
-            ))
-        encoder.load_state_dict(model_zoo.load_url(settings["url"], map_location=torch.device(DEVICE)))
-
-    encoder.set_in_channels(in_channels, pretrained=weights is not None)
-    if output_stride != 32:
-        encoder.make_dilated(output_stride)
-
-    return encoder
-
-def get_encoder_names():
-    return list(encoders.keys())
-
-def get_preprocessing_params(encoder_name, pretrained="imagenet"):
-    settings = encoders[encoder_name]["pretrained_settings"]
-    if pretrained not in settings.keys():
-        raise ValueError("Available pretrained options {}".format(settings.keys()))
-
-    formatted_settings = {}
-    formatted_settings["input_space"] = settings[pretrained].get("input_space")
-    formatted_settings["input_range"] = settings[pretrained].get("input_range")
-    formatted_settings["mean"] = settings[pretrained].get("mean")
-    formatted_settings["std"] = settings[pretrained].get("std")
-    return formatted_settings
-
-def get_preprocessing_fn(encoder_name, pretrained="imagenet"):
-    params = get_preprocessing_params(encoder_name, pretrained=pretrained)
-    return functools.partial(preprocess_input, **params)
+import functools
+import torch
+import torch.utils.model_zoo as model_zoo
+
+from ._preprocessing import preprocess_input
+# from .densenet import densenet_encoders
+# from .dpn import dpn_encoders
+# from .efficientnet import efficient_net_encoders
+# from .inceptionresnetv2 import inceptionresnetv2_encoders
+# from .inceptionv4 import inceptionv4_encoders
+# from .mobilenet import mobilenet_encoders
+from .resnet import resnet_encoders
+# from .senet import senet_encoders
+# from .timm_efficientnet import timm_efficientnet_encoders
+# from .timm_gernet import timm_gernet_encoders
+# from .timm_mobilenetv3 import timm_mobilenetv3_encoders
+# from .timm_regnet import timm_regnet_encoders
+# from .timm_res2net import timm_res2net_encoders
+# from .timm_resnest import timm_resnest_encoders
+# from .timm_sknet import timm_sknet_encoders
+# from .timm_universal import TimmUniversalEncoder
+# from .vgg import vgg_encoders
+# from .xception import xception_encoders
+# from .swin_transformer import swin_transformer_encoders
+# from .mit_encoder import mit_encoders
+# from .hrnet import hrnet_encoders
+
+DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
+
+encoders = {}
+encoders.update(resnet_encoders)
+# encoders.update(dpn_encoders)
+# encoders.update(vgg_encoders)
+# encoders.update(senet_encoders)
+# encoders.update(densenet_encoders)
+# encoders.update(inceptionresnetv2_encoders)
+# encoders.update(inceptionv4_encoders)
+# encoders.update(efficient_net_encoders)
+# encoders.update(mobilenet_encoders)
+# encoders.update(xception_encoders)
+# encoders.update(timm_efficientnet_encoders)
+# encoders.update(timm_resnest_encoders)
+# encoders.update(timm_res2net_encoders)
+# encoders.update(timm_regnet_encoders)
+# encoders.update(timm_sknet_encoders)
+# encoders.update(timm_mobilenetv3_encoders)
+# encoders.update(timm_gernet_encoders)
+# encoders.update(swin_transformer_encoders)
+# encoders.update(mit_encoders)
+# encoders.update(hrnet_encoders)
+
+def get_encoder(name, in_channels=3, depth=5, weights=None, output_stride=32, **kwargs):
+    if name.startswith("tu-"):
+        name = name[3:]
+        # encoder = TimmUniversalEncoder(
+        #     name=name,
+        #     in_channels=in_channels,
+        #     depth=depth,
+        #     output_stride=output_stride,
+        #     pretrained=weights is not None,
+        #     **kwargs
+        # )
+        # return encoder
+
+    try:
+        Encoder = encoders[name]["encoder"]
+    except KeyError:
+        raise KeyError("Wrong encoder name `{}`, supported encoders: {}".format(name, list(encoders.keys())))
+
+    params = encoders[name]["params"]
+    params.update(depth=depth)
+    encoder = Encoder(**params)
+
+    if weights is not None:
+        try:
+            settings = encoders[name]["pretrained_settings"][weights]
+        except KeyError:
+            raise KeyError("Wrong pretrained weights `{}` for encoder `{}`. Available options are: {}".format(
+                weights, name, list(encoders[name]["pretrained_settings"].keys()),
+            ))
+        encoder.load_state_dict(model_zoo.load_url(settings["url"], map_location=torch.device(DEVICE)))
+
+    encoder.set_in_channels(in_channels, pretrained=weights is not None)
+    if output_stride != 32:
+        encoder.make_dilated(output_stride)
+
+    return encoder
+
+def get_encoder_names():
+    return list(encoders.keys())
+
+def get_preprocessing_params(encoder_name, pretrained="imagenet"):
+    settings = encoders[encoder_name]["pretrained_settings"]
+    if pretrained not in settings.keys():
+        raise ValueError("Available pretrained options {}".format(settings.keys()))
+
+    formatted_settings = {}
+    formatted_settings["input_space"] = settings[pretrained].get("input_space")
+    formatted_settings["input_range"] = settings[pretrained].get("input_range")
+    formatted_settings["mean"] = settings[pretrained].get("mean")
+    formatted_settings["std"] = settings[pretrained].get("std")
+    return formatted_settings
+
+def get_preprocessing_fn(encoder_name, pretrained="imagenet"):
+    params = get_preprocessing_params(encoder_name, pretrained=pretrained)
+    return functools.partial(preprocess_input, **params)
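
After this trim, only the resnet family stays registered, so get_encoder_names() shrinks accordingly. A short, hedged usage sketch (the per-stage feature-list output is assumed from the segmentation-models-style encoder API this file follows):

import torch

print(get_encoder_names())                 # resnet variants only, after this change
enc = get_encoder('resnet34', in_channels=3, depth=5, weights=None)
stages = enc(torch.randn(1, 3, 224, 224))  # assumed: a list of per-stage feature maps
print([tuple(s.shape) for s in stages])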

View File

@@ -8,6 +8,7 @@ from PyQt5.QtGui import QIcon
from PyQt5.QtCore import Qt
from osgeo import gdal, gdal_array
import numpy as np
+import threading

class ExportDialog(QDialog):
    TXT = '*.txt'
@@ -122,12 +123,12 @@ class ExportPlugin(BasicPlugin):
    def export_rgb_action(self):
        dialog = ExportDialog(self.mainwindow, ExportDialog.TIF)
        btn = QPushButton('选择样式文件')
-        style_path = ''
+        self.style_path = ''
        def select_style():
-            select_file = QFileDialog.getOpenFileName(self, '选择样式文件', '*.*')
+            select_file = QFileDialog.getOpenFileName(self.mainwindow, '选择样式文件', '*.*')
            if select_file[0]:
                # style_path.setText(select_file[0])
-                style_path = select_file[0]
+                self.style_path = select_file[0]
                btn.setText(select_file[0])
        btn.clicked.connect(select_style)
        dialog.layout().addWidget(btn)
@@ -137,22 +138,26 @@ class ExportPlugin(BasicPlugin):
            [1, 255, 0, 0]
        ]
-        if style_path is '':
+        if self.style_path == '':
            style = default_style
        else:
-            style = np.loadtxt(style_path, comments='#', delimiter=',')
+            style = np.loadtxt(self.style_path, comments='#', delimiter=',')
            style = style.tolist()
        if dialog.exec_():
-            result = dialog.result_layer.path
+            result = dialog.result_layer.result_path['cmi']
            out = dialog.out_path
-            self.render(style, result, out)
+            t = threading.Thread(target=self.render, args=(style, result, out))  # create a thread to run render off the UI
+            # self.render(style, result, out)
+            t.start()

    def render(self, style: list, path, out):
        self.send_message.emit('正在导出')
        data = gdal_array.LoadFile(path)
        if len(data.shape) == 3:
-            data = data[..., 0]
+            data = data[0]
+        data = data / 255.0  # scale to 0-1 range for gdal_array.SaveArray()
        def get_color(v):
            first_color = []
            second_color = []
@@ -166,7 +171,7 @@ class ExportPlugin(BasicPlugin):
                    second_value = s[0]
                    second_color = s[1:]
                    break
-            if second_value == -1:
+            if second_value == 1:
                return np.array(style[-1][1:])

            first_dis = (v - first_value) / (second_value - first_value)
@@ -174,7 +179,7 @@ class ExportPlugin(BasicPlugin):
            first_color = np.array(first_color)
            second_color = np.array(second_color)
            color = first_color * first_dis + second_color * second_dis
-            return color
+            return np.floor(color)

        get_color = np.frompyfunc(get_color, nin=1, nout=1)
@@ -195,4 +200,6 @@ class ExportPlugin(BasicPlugin):
        for i in range(3):
            out_ds.GetRasterBand(i + 1).WriteArray(rgbs[..., i])
-        del out_ds
+        del out_ds
+        self.send_message.emit('导出成功,结果保存至:' + out)  # send the completion message to the console for display
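
The get_color logic above linearly interpolates between style stops. A standalone sketch of the same idea, with stop values assumed to be in the scaled 0-1 range and a hypothetical two-stop ramp:

import numpy as np

# Each stop is [value, R, G, B].
style = [[0.0, 0, 0, 0],
         [1.0, 255, 0, 0]]

def interp_color(v, style):
    # Find the bracketing stops and blend their colors linearly.
    for lo, hi in zip(style, style[1:]):
        if lo[0] <= v <= hi[0]:
            t = (v - lo[0]) / (hi[0] - lo[0])
            return np.floor((1 - t) * np.array(lo[1:]) + t * np.array(hi[1:]))
    return np.array(style[-1][1:])  # clamp values beyond the last stop

print(interp_color(0.5, style))     # [127.   0.   0.]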

View File

@@ -1,140 +1,141 @@
-from misc import AlgFrontend
-from misc.utils import format_now
-from osgeo import gdal, gdal_array
-from skimage.filters import rank
-from skimage.morphology import rectangle
-from filter_collection import FILTER
-from PyQt5.QtWidgets import QDialog, QAction
-from PyQt5 import QtCore, QtGui, QtWidgets
-from rscder.utils.project import PairLayer, Project, RasterLayer, ResultPointLayer
-from rscder.utils.icons import IconInstance
-import os
-from datetime import datetime
-import numpy as np
-import torch
-import torch.nn.functional as F
-
-def adaptiveMedianDeNoise(count, original):
-    # initial window size
-    startWindow = 3
-    # border margin: half of the maximum window size
-    c = count // 2
-    rows, cols = original.shape
-    newI = np.zeros(original.shape)
-    # median =
-    for i in range(c, rows - c):
-        for j in range(c, cols - c):
-            k = int(startWindow / 2)
-            median = np.median(original[i - k:i + k + 1, j - k:j + k + 1])
-            mi = np.min(original[i - k:i + k + 1, j - k:j + k + 1])
-            ma = np.max(original[i - k:i + k + 1, j - k:j + k + 1])
-            if mi < median < ma:
-                if mi < original[i, j] < ma:
-                    newI[i, j] = original[i, j]
-                else:
-                    newI[i, j] = median
-            else:
-                while True:
-                    startWindow = startWindow + 2
-                    k = int(startWindow / 2)
-                    median = np.median(original[i - k:i + k + 1, j - k:j + k + 1])
-                    mi = np.min(original[i - k:i + k + 1, j - k:j + k + 1])
-                    ma = np.max(original[i - k:i + k + 1, j - k:j + k + 1])
-                    if mi < median < ma or startWindow > count:
-                        break
-                if mi < median < ma or startWindow > count:
-                    if mi < original[i, j] < ma:
-                        newI[i, j] = original[i, j]
-                    else:
-                        newI[i, j] = median
-    return newI
-
-@FILTER.register
-class AdaptiveFilter(AlgFrontend):
-
-    @staticmethod
-    def get_name():
-        return '自适应滤波'
-
-    @staticmethod
-    def get_icon():
-        return IconInstance().ARITHMETIC2
-
-    @staticmethod
-    def get_widget(parent=None):
-        widget = QtWidgets.QWidget(parent)
-        x_size_input = QtWidgets.QLineEdit(widget)
-        x_size_input.setText('3')
-        x_size_input.setValidator(QtGui.QIntValidator())
-        x_size_input.setObjectName('xinput')
-        # y_size_input = QtWidgets.QLineEdit(widget)
-        # y_size_input.setValidator(QtGui.QIntValidator())
-        # y_size_input.setObjectName('yinput')
-        # y_size_input.setText('3')
-        size_label = QtWidgets.QLabel(widget)
-        size_label.setText('窗口大小:')
-        # time_label = QtWidgets.QLabel(widget)
-        # time_label.setText('X')
-        hlayout1 = QtWidgets.QHBoxLayout()
-        hlayout1.addWidget(size_label)
-        hlayout1.addWidget(x_size_input)
-        # hlayout1.addWidget(time_label)
-        # hlayout1.addWidget(y_size_input)
-        widget.setLayout(hlayout1)
-        return widget
-
-    @staticmethod
-    def get_params(widget: QtWidgets.QWidget = None):
-        if widget is None:
-            return dict(x_size=3)
-        x_input = widget.findChild(QtWidgets.QLineEdit, 'xinput')
-        # y_input = widget.findChild(QtWidgets.QLineEdit, 'yinput')
-        if x_input is None:
-            return dict(x_size=3)
-        x_size = int(x_input.text())
-        # y_size = int(y_input.text())
-        return dict(x_size=x_size)
-
-    @staticmethod
-    def run_alg(pth, x_size, *args, **kargs):
-        x_size = int(x_size)
-        # y_size = int(y_size)
-        # pth = layer.path
-        if pth is None:
-            return
-        ds = gdal.Open(pth)
-        band_count = ds.RasterCount
-        name = os.path.splitext(os.path.basename(pth))[0]
-        out_path = os.path.join(Project().other_path, '{}_{}_{}.tif'.format(name, AdaptiveFilter.get_name(), format_now()))
-        out_ds = gdal.GetDriverByName('GTiff').Create(out_path, ds.RasterXSize, ds.RasterYSize, band_count, ds.GetRasterBand(1).DataType)
-        out_ds.SetProjection(ds.GetProjection())
-        out_ds.SetGeoTransform(ds.GetGeoTransform())
-        for i in range(band_count):
-            band = ds.GetRasterBand(i + 1)
-            data = band.ReadAsArray()
-            data = adaptiveMedianDeNoise(x_size, data)
-            out_band = out_ds.GetRasterBand(i + 1)
-            out_band.WriteArray(data)
-        out_ds.FlushCache()
-        del out_ds
-        del ds
-        return out_path
+from misc import AlgFrontend
+from misc.utils import format_now
+from osgeo import gdal, gdal_array
+from skimage.filters import rank
+from skimage.morphology import rectangle
+from filter_collection import FILTER
+from PyQt5.QtWidgets import QDialog, QAction
+from PyQt5 import QtCore, QtGui, QtWidgets
+from rscder.utils.project import PairLayer, Project, RasterLayer, ResultPointLayer
+from rscder.utils.icons import IconInstance
+import os
+from datetime import datetime
+import numpy as np
+import torch
+import torch.nn.functional as F
+
+def adaptiveMedianDeNoise(count, original):
+    # initial window size
+    startWindow = 3
+    # border margin: half of the maximum window size
+    c = count // 2
+    rows, cols = original.shape
+    newI = np.zeros(original.shape)
+    # median =
+    for i in range(c, rows - c):
+        for j in range(c, cols - c):
+            startWindow = 3
+            k = int(startWindow / 2)
+            median = np.median(original[i - k:i + k + 1, j - k:j + k + 1])
+            mi = np.min(original[i - k:i + k + 1, j - k:j + k + 1])
+            ma = np.max(original[i - k:i + k + 1, j - k:j + k + 1])
+            if mi < median < ma:
+                if mi < original[i, j] < ma:
+                    newI[i, j] = original[i, j]
+                else:
+                    newI[i, j] = median
+            else:
+                while True:
+                    startWindow = startWindow + 2
+                    k = int(startWindow / 2)
+                    median = np.median(original[i - k:i + k + 1, j - k:j + k + 1])
+                    mi = np.min(original[i - k:i + k + 1, j - k:j + k + 1])
+                    ma = np.max(original[i - k:i + k + 1, j - k:j + k + 1])
+                    if mi < median < ma or startWindow >= count:
+                        break
+                if mi < median < ma or startWindow > count:
+                    if mi < original[i, j] < ma:
+                        newI[i, j] = original[i, j]
+                    else:
+                        newI[i, j] = median
+    return newI
+
+@FILTER.register
+class AdaptiveFilter(AlgFrontend):
+
+    @staticmethod
+    def get_name():
+        return '自适应滤波'
+
+    @staticmethod
+    def get_icon():
+        return IconInstance().ARITHMETIC2
+
+    @staticmethod
+    def get_widget(parent=None):
+        widget = QtWidgets.QWidget(parent)
+        x_size_input = QtWidgets.QLineEdit(widget)
+        x_size_input.setText('9')
+        x_size_input.setValidator(QtGui.QIntValidator(6, 21))
+        x_size_input.setObjectName('xinput')
+        # y_size_input = QtWidgets.QLineEdit(widget)
+        # y_size_input.setValidator(QtGui.QIntValidator())
+        # y_size_input.setObjectName('yinput')
+        # y_size_input.setText('3')
+        size_label = QtWidgets.QLabel(widget)
+        size_label.setText('窗口大小:')
+        # time_label = QtWidgets.QLabel(widget)
+        # time_label.setText('X')
+        hlayout1 = QtWidgets.QHBoxLayout()
+        hlayout1.addWidget(size_label)
+        hlayout1.addWidget(x_size_input)
+        # hlayout1.addWidget(time_label)
+        # hlayout1.addWidget(y_size_input)
+        widget.setLayout(hlayout1)
+        return widget
+
+    @staticmethod
+    def get_params(widget: QtWidgets.QWidget = None):
+        if widget is None:
+            return dict(x_size=9)
+        x_input = widget.findChild(QtWidgets.QLineEdit, 'xinput')
+        # y_input = widget.findChild(QtWidgets.QLineEdit, 'yinput')
+        if x_input is None:
+            return dict(x_size=9)
+        x_size = int(x_input.text())
+        # y_size = int(y_input.text())
+        return dict(x_size=x_size)
+
+    @staticmethod
+    def run_alg(pth, x_size, *args, **kargs):
+        x_size = int(x_size)
+        # y_size = int(y_size)
+        # pth = layer.path
+        if pth is None:
+            return
+        ds = gdal.Open(pth)
+        band_count = ds.RasterCount
+        name = os.path.splitext(os.path.basename(pth))[0]
+        out_path = os.path.join(Project().other_path, '{}_{}_{}.tif'.format(name, AdaptiveFilter.get_name(), format_now()))
+        out_ds = gdal.GetDriverByName('GTiff').Create(out_path, ds.RasterXSize, ds.RasterYSize, band_count, ds.GetRasterBand(1).DataType)
+        out_ds.SetProjection(ds.GetProjection())
+        out_ds.SetGeoTransform(ds.GetGeoTransform())
+        for i in range(band_count):
+            band = ds.GetRasterBand(i + 1)
+            data = band.ReadAsArray()
+            data = adaptiveMedianDeNoise(x_size, data)
+            out_band = out_ds.GetRasterBand(i + 1)
+            out_band.WriteArray(data)
+        out_ds.FlushCache()
+        del out_ds
+        del ds
+        return out_path
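
For a quick sanity check of the filter outside the GUI, a hedged sketch on synthetic data; the smooth ramp keeps the min < median < max condition satisfiable, and note the function leaves a count//2 border at zero:

import numpy as np

rng = np.random.default_rng(0)
truth = np.add.outer(np.arange(64.0), np.arange(64.0))  # smooth ramp image
img = truth.copy()
mask = rng.random(img.shape) < 0.05
img[mask] = rng.choice([0.0, 255.0], size=mask.sum())   # salt-and-pepper impulses

clean = adaptiveMedianDeNoise(9, img)                   # max window 9, the new UI default
inner = (slice(4, -4), slice(4, -4))                    # skip the zeroed 4-pixel border
print(np.abs(clean[inner] - truth[inner]).mean(),
      np.abs(img[inner] - truth[inner]).mean())         # filtered error should be much lower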

View File

@@ -1,128 +1,128 @@
-from misc import AlgFrontend
-from misc.utils import format_now
-from osgeo import gdal, gdal_array
-from skimage.filters import rank
-from skimage.morphology import rectangle
-from filter_collection import FILTER
-from PyQt5.QtWidgets import QDialog, QAction
-from PyQt5 import QtCore, QtGui, QtWidgets
-from rscder.utils.project import PairLayer, Project, RasterLayer, ResultPointLayer
-from rscder.utils.icons import IconInstance
-import os
-from datetime import datetime
-import numpy as np
-import torch
-import torch.nn.functional as F
-
-def adaptiveMedianDeNoise(count, original):
-    # initial window size
-    startWindow = 3
-    # border margin: half of the maximum window size
-    c = count // 2
-    rows, cols = original.shape
-    newI = np.zeros(original.shape)
-    # median =
-    for i in range(c, rows - c):
-        for j in range(c, cols - c):
-            k = int(startWindow / 2)
-            median = np.median(original[i - k:i + k + 1, j - k:j + k + 1])
-            mi = np.min(original[i - k:i + k + 1, j - k:j + k + 1])
-            ma = np.max(original[i - k:i + k + 1, j - k:j + k + 1])
-            if mi < median < ma:
-                if mi < original[i, j] < ma:
-                    newI[i, j] = original[i, j]
-                else:
-                    newI[i, j] = median
-            else:
-                while True:
-                    startWindow = startWindow + 2
-                    k = int(startWindow / 2)
-                    median = np.median(original[i - k:i + k + 1, j - k:j + k + 1])
-                    mi = np.min(original[i - k:i + k + 1, j - k:j + k + 1])
-                    ma = np.max(original[i - k:i + k + 1, j - k:j + k + 1])
-                    if mi < median < ma or startWindow > count:
-                        break
-                if mi < median < ma or startWindow > count:
-                    if mi < original[i, j] < ma:
-                        newI[i, j] = original[i, j]
-                    else:
-                        newI[i, j] = median
-    return newI
-
-@FILTER.register
-class AdaptiveNPFilter(AlgFrontend):
-
-    @staticmethod
-    def get_name():
-        return '自动滤波(无参自适应滤波)'
-
-    @staticmethod
-    def get_icon():
-        return IconInstance().ARITHMETIC2
-
-    @staticmethod
-    def get_widget(parent=None):
-        # widget = QtWidgets.QWidget(parent)
-        # x_size_input = QtWidgets.QLineEdit(widget)
-        # x_size_input.setText('3')
-        # x_size_input.setValidator(QtGui.QIntValidator())
-        # x_size_input.setObjectName('xinput')
-        # # y_size_input = QtWidgets.QLineEdit(widget)
-        # # y_size_input.setValidator(QtGui.QIntValidator())
-        # # y_size_input.setObjectName('yinput')
-        # # y_size_input.setText('3')
-        # size_label = QtWidgets.QLabel(widget)
-        # size_label.setText('窗口大小:')
-        # # time_label = QtWidgets.QLabel(widget)
-        # # time_label.setText('X')
-        # hlayout1 = QtWidgets.QHBoxLayout()
-        # hlayout1.addWidget(size_label)
-        # hlayout1.addWidget(x_size_input)
-        # # hlayout1.addWidget(time_label)
-        # # hlayout1.addWidget(y_size_input)
-        # widget.setLayout(hlayout1)
-        return None
-
-    @staticmethod
-    def get_params(widget: QtWidgets.QWidget = None):
-        return dict()
-
-    @staticmethod
-    def run_alg(pth, x_size, *args, **kargs):
-        # x_size = int(x_size)
-        # y_size = int(y_size)
-        # pth = layer.path
-        if pth is None:
-            return
-        ds = gdal.Open(pth)
-        band_count = ds.RasterCount
-        name = os.path.splitext(os.path.basename(pth))[0]
-        out_path = os.path.join(Project().other_path, '{}_{}_{}.tif'.format(name, AdaptiveNPFilter.get_name(), format_now()))
-        out_ds = gdal.GetDriverByName('GTiff').Create(out_path, ds.RasterXSize, ds.RasterYSize, band_count, ds.GetRasterBand(1).DataType)
-        out_ds.SetProjection(ds.GetProjection())
-        out_ds.SetGeoTransform(ds.GetGeoTransform())
-        for i in range(band_count):
-            band = ds.GetRasterBand(i + 1)
-            data = band.ReadAsArray()
-            data = adaptiveMedianDeNoise(5, data)
-            out_band = out_ds.GetRasterBand(i + 1)
-            out_band.WriteArray(data)
-        out_ds.FlushCache()
-        del out_ds
-        del ds
-        return out_path
+from misc import AlgFrontend
+from misc.utils import format_now
+from osgeo import gdal, gdal_array
+from skimage.filters import rank
+from skimage.morphology import rectangle
+from filter_collection import FILTER
+from PyQt5.QtWidgets import QDialog, QAction
+from PyQt5 import QtCore, QtGui, QtWidgets
+from rscder.utils.project import PairLayer, Project, RasterLayer, ResultPointLayer
+from rscder.utils.icons import IconInstance
+import os
+from datetime import datetime
+import numpy as np
+import torch
+import torch.nn.functional as F
+
+def adaptiveMedianDeNoise(count, original):
+    # initial window size
+    startWindow = 3
+    # border margin: half of the maximum window size
+    c = count // 2
+    rows, cols = original.shape
+    newI = np.zeros(original.shape)
+    # median =
+    for i in range(c, rows - c):
+        for j in range(c, cols - c):
+            k = int(startWindow / 2)
+            median = np.median(original[i - k:i + k + 1, j - k:j + k + 1])
+            mi = np.min(original[i - k:i + k + 1, j - k:j + k + 1])
+            ma = np.max(original[i - k:i + k + 1, j - k:j + k + 1])
+            if mi < median < ma:
+                if mi < original[i, j] < ma:
+                    newI[i, j] = original[i, j]
+                else:
+                    newI[i, j] = median
+            else:
+                while True:
+                    startWindow = startWindow + 2
+                    k = int(startWindow / 2)
+                    median = np.median(original[i - k:i + k + 1, j - k:j + k + 1])
+                    mi = np.min(original[i - k:i + k + 1, j - k:j + k + 1])
+                    ma = np.max(original[i - k:i + k + 1, j - k:j + k + 1])
+                    if mi < median < ma or startWindow >= count:
+                        break
+                if mi < median < ma or startWindow >= count:
+                    if mi < original[i, j] < ma:
+                        newI[i, j] = original[i, j]
+                    else:
+                        newI[i, j] = median
+    return newI
+
+@FILTER.register
+class AdaptiveNPFilter(AlgFrontend):
+
+    @staticmethod
+    def get_name():
+        return '自动滤波(无参自适应滤波)'
+
+    @staticmethod
+    def get_icon():
+        return IconInstance().ARITHMETIC2
+
+    @staticmethod
+    def get_widget(parent=None):
+        # widget = QtWidgets.QWidget(parent)
+        # x_size_input = QtWidgets.QLineEdit(widget)
+        # x_size_input.setText('3')
+        # x_size_input.setValidator(QtGui.QIntValidator())
+        # x_size_input.setObjectName('xinput')
+        # # y_size_input = QtWidgets.QLineEdit(widget)
+        # # y_size_input.setValidator(QtGui.QIntValidator())
+        # # y_size_input.setObjectName('yinput')
+        # # y_size_input.setText('3')
+        # size_label = QtWidgets.QLabel(widget)
+        # size_label.setText('窗口大小:')
+        # # time_label = QtWidgets.QLabel(widget)
+        # # time_label.setText('X')
+        # hlayout1 = QtWidgets.QHBoxLayout()
+        # hlayout1.addWidget(size_label)
+        # hlayout1.addWidget(x_size_input)
+        # # hlayout1.addWidget(time_label)
+        # # hlayout1.addWidget(y_size_input)
+        # widget.setLayout(hlayout1)
+        return None
+
+    @staticmethod
+    def get_params(widget: QtWidgets.QWidget = None):
+        return dict()
+
+    @staticmethod
+    def run_alg(pth, *args, **kargs):
+        # x_size = int(x_size)
+        # y_size = int(y_size)
+        # pth = layer.path
+        if pth is None:
+            return
+        ds = gdal.Open(pth)
+        band_count = ds.RasterCount
+        name = os.path.splitext(os.path.basename(pth))[0]
+        out_path = os.path.join(Project().other_path, '{}_{}_{}.tif'.format(name, AdaptiveNPFilter.get_name(), format_now()))
+        out_ds = gdal.GetDriverByName('GTiff').Create(out_path, ds.RasterXSize, ds.RasterYSize, band_count, ds.GetRasterBand(1).DataType)
+        out_ds.SetProjection(ds.GetProjection())
+        out_ds.SetGeoTransform(ds.GetGeoTransform())
+        for i in range(band_count):
+            band = ds.GetRasterBand(i + 1)
+            data = band.ReadAsArray()
+            data = adaptiveMedianDeNoise(5, data)
+            out_band = out_ds.GetRasterBand(i + 1)
+            out_band.WriteArray(data)
+        out_ds.FlushCache()
+        del out_ds
+        del ds
+        return out_path

View File

@@ -114,12 +114,17 @@ class MainPlugin(BasicPlugin):
        pth = layer.path
        if pth is None:
            return
-        out_path = alg.run_alg(pth, **p, send_message=self.send_message)
-        self.send_message.emit('滤波完成')
-        self.send_message.emit('结果在:' + out_path)
-        self.alg_ok.emit(layer, out_path)
+        self.send_message.emit('滤波开始')
+        try:
+            out_path = alg.run_alg(pth, **p, send_message=self.send_message)
+            self.send_message.emit('滤波完成')
+            self.send_message.emit('结果在:' + out_path)
+            self.alg_ok.emit(layer, out_path)
+        except Exception as e:
+            self.send_message.emit('滤波出现异常:' + str(e))

    def run(self, key):
        if key not in FILTER:

View File

@@ -157,6 +157,6 @@ class FollowPlugin(BasicPlugin):
        self.current_widget = FOLLOW[self.combox.currentData()].get_widget(ActionManager().follow_box)
        self.layout.addWidget(self.current_widget)
-        self.layout.addLayout(self.btn_box)
+        self.layout.addWidget(self.btn_box)

View File

@@ -12419,6 +12419,6 @@ Size = Size2i
def AHT(file1, file2, outfile):
-    return _AHT.AHT(file1, file2, outfile)
+    return _AHT.LHBA(file1, file2, outfile)

Binary file not shown.

View File

@@ -2,7 +2,6 @@ from rscder.utils.icons import IconInstance
from .SH import SH
from .LHBA import LHBA
from .OCD import OCD
-from .AHT import AHT
from .ACD import ACD
import numpy as np
from datetime import datetime
@@ -686,7 +685,7 @@ class AHTAlg(AlgFrontend):
        send_message.emit('图像一提取完成')

        # computation
-        send_message.emit('开始LHBA计算.....')
+        send_message.emit('开始AHT计算.....')
        time.sleep(0.1)
        out_normal_tif = os.path.join(Project().cmi_path, '{}_{}_cmi.tif'.format(
            layer_parent.name, int(np.random.rand() * 100000)))

1
test.lic Normal file
View File

@@ -0,0 +1 @@
+IaqFuRlbPMtYTReB0p+cxn8sffVeOjbq+d46I2texZIVIeLbwCxJ7w3mqlramQy3p0totEfoSkjIrQV1GjtrOOWsuYRqv5ZZ5A+/PdTd7ZU8WlMAl7sknJGJFWvciG1VL9n9XtJUG+CJg4oLYkdwR5WePYUsC4y4NcS5+rlP1if87xtYqVeSwx3c64cOmAGP