109 lines
3.2 KiB
Python
109 lines
3.2 KiB
Python
from rscder.utils.icons import IconInstance
|
|
from ai_method.basic_cd import AIMethodDialog, AI_METHOD, AlgFrontend
|
|
import os
|
|
from PyQt5.QtWidgets import QFormLayout, QLabel, QLineEdit, QPushButton, QDialogButtonBox, QWidget
|
|
from PyQt5.QtCore import Qt
|
|
|
|
class STAParams(AlgFrontend):
    """Settings form for STA Net and conversion of its fields into CLI arguments.

    ``options`` maps each setting key (also used as the object name of its
    ``QLineEdit`` in the form) to a spec dict:

    * ``label``   - text shown next to the input field,
    * ``opt``     - command-line flag emitted for this setting,
    * ``dtype``   - declared value type (not read by this class),
    * ``default`` - value emitted when the field is missing or left empty;
      ``None`` means the flag is omitted entirely in that case.
    """

    options = dict(
        train_dir=dict(
            label='训练集根目录',
            opt='--dataroot',
            dtype='str',
            default=None,
        ),
        val_dir=dict(
            label='验证集根目录',
            opt='--val_dataroot',
            dtype='str',
            default=None,
        ),
        # NOTE: a test_dir option (label '测试集根目录', flag '--dataroot')
        # was drafted here but is disabled; note it reused train_dir's flag.
    )

    @staticmethod
    def get_widget(parent=None):
        """Build the settings form.

        One labelled ``QLineEdit`` row is created per entry in ``options``;
        each line edit's object name is the option key, which is how
        ``get_params`` finds it again. An extra '测试集根目录' row (object
        name ``test_dir``) is also shown.

        :param parent: optional parent widget for the form.
        :return: the ``QWidget`` that owns the form layout.
        """
        widget = QWidget(parent)
        # Passing the widget to the layout constructor installs the layout on
        # it, so no separate setLayout() call is needed.
        form = QFormLayout(widget)

        for key, spec in STAParams.options.items():
            line_edit = QLineEdit()
            line_edit.setObjectName(key)
            form.addRow(QLabel(spec['label']), line_edit)

        # NOTE(review): 'test_dir' has no entry in ``options``, so get_params
        # never forwards its value; confirm whether the test root is needed.
        test_dir_data = QLineEdit()
        test_dir_data.setObjectName('test_dir')
        form.addRow(QLabel('测试集根目录'), test_dir_data)

        return widget

    @staticmethod
    def get_params(widget: QWidget = None):
        """Collect command-line arguments from a form built by ``get_widget``.

        For every key in ``options`` the matching ``QLineEdit`` is looked up
        by object name, and its text (surrounding whitespace stripped) is
        emitted after the option's flag. When the field is missing or empty,
        the option's ``default`` is emitted instead, or the flag is skipped
        if that default is ``None``.

        :param widget: the form widget; ``None`` yields ``None``.
        :return: flat list ``[flag, value, flag, value, ...]``, or ``None``.
        """
        if widget is None:
            return None

        args = []
        for key, spec in STAParams.options.items():
            comp = widget.findChild(QLineEdit, name=key)
            value = comp.text().strip() if comp is not None else ''
            if not value:
                # Emitting ``--flag ''`` would override the script's own
                # default with an empty path, so use the configured default
                # (or drop the flag) instead.
                if spec['default'] is None:
                    continue
                value = str(spec['default'])
            args.extend((spec['opt'], value))
        return args
|
|
|
|
|
|
class STAMethod(AIMethodDialog):
    """Dialog that launches the STA Net scripts from the ``STA_net`` folder.

    The class attributes below are presumably consumed by ``AIMethodDialog``
    (verify there): the environment name, the settings widget class, the
    selectable ``(key, label)`` stages and the display name.
    """

    ENV = 'torch1121cu113'
    setting_widget = STAParams
    stages = [('train', '训练'), ('test', '测试'), ('predict_batch', '批量预测')]
    name = 'STA Net'

    @property
    def workdir(self):
        """Absolute path of the ``STA_net`` directory beside this module."""
        module_dir = os.path.dirname(__file__)
        return os.path.abspath(os.path.join(module_dir, 'STA_net'))

    def stage_script(self, stage):
        """Return the script path for *stage*, or ``None`` when it has none.

        Only the 'train' and 'test' stages are mapped to scripts here.
        """
        # NOTE(review): 'predict_batch' appears in ``stages`` but has no
        # script mapping, so it resolves to None; confirm that is intended.
        for stage_key, script_file in (('train', 'train.py'), ('test', 'test.py')):
            if stage == stage_key:
                return os.path.join(self.workdir, script_file)
        return None
|
|
|
|
|
|
@AI_METHOD.register
class STANet(AlgFrontend):
    """Front-end registered with ``AI_METHOD`` that supplies STA Net's
    display name, icon and dialog."""

    @staticmethod
    def get_name():
        """Return the display name of the method."""
        return 'STA Net'

    @staticmethod
    def get_icon():
        """Return the ``AI_DETECT`` icon from the shared icon instance."""
        icons = IconInstance()
        return icons.AI_DETECT

    @staticmethod
    def get_widget(parent=None):
        """Create a ``STAMethod`` dialog with the given *parent*."""
        dialog = STAMethod(parent)
        return dialog
|
|
|