2023-02-10 21:51:03 +08:00

109 lines
3.2 KiB
Python

from rscder.utils.icons import IconInstance
from ai_method.basic_cd import AIMethodDialog, AI_METHOD, AlgFrontend
import os
from PyQt5.QtWidgets import QFormLayout, QLabel, QLineEdit, QPushButton, QDialogButtonBox, QWidget
from PyQt5.QtCore import Qt
class STAParams(AlgFrontend):
    """Parameter front-end for the STA Net scripts.

    ``options`` maps the objectName of each QLineEdit created by
    get_widget() to the command-line flag get_params() emits for it.
    """

    # Each entry is keyed by the objectName of its QLineEdit:
    #   label   -> text shown next to the input field
    #   opt     -> command-line flag emitted by get_params()
    #   dtype   -> declared value type (not read by this class)
    #   default -> value used when the field is missing or blank;
    #              None means the flag is omitted entirely
    options = dict(
        train_dir=dict(
            label='训练集根目录',  # "training set root directory"
            opt='--dataroot',
            dtype='str',
            default=None,
        ),
        val_dir=dict(
            label='验证集根目录',  # "validation set root directory"
            opt='--val_dataroot',
            dtype='str',
            default=None,
        ),
        # A 'test_dir' option was drafted here and disabled. It reused
        # '--dataroot', which would clash with train_dir if both were emitted.
    )

    @staticmethod
    def get_widget(parent=None):
        """Build the settings form.

        One labelled QLineEdit is created per entry in ``options``, followed
        by a test-set directory field. Each QLineEdit's objectName is set to
        its key so that get_params() can locate it with ``findChild``.

        Args:
            parent: optional parent widget for the returned QWidget.

        Returns:
            The QWidget holding the form.
        """
        widget = QWidget(parent)
        # Constructing the layout with the widget as parent already installs
        # it on that widget, so no separate setLayout() call is needed.
        form = QFormLayout(widget)
        for key, spec in STAParams.options.items():
            field = QLineEdit()
            field.setObjectName(key)
            form.addRow(QLabel(spec['label']), field)
        # NOTE(review): this field ("test set root directory") has no entry in
        # ``options``, so get_params() never forwards its value. It is kept for
        # UI compatibility; confirm whether the test stage reads it elsewhere.
        test_dir_data = QLineEdit()
        test_dir_data.setObjectName('test_dir')
        form.addRow(QLabel('测试集根目录'), test_dir_data)
        return widget

    @staticmethod
    def get_params(widget: QWidget = None):
        """Collect command-line arguments from a widget built by get_widget().

        Args:
            widget: the settings widget. If None, None is returned.

        Returns:
            A flat list ``[flag, value, flag, value, ...]`` of strings. Each
            option's flag is paired with the whitespace-stripped text of its
            QLineEdit. If the field is missing or blank, the option's
            ``default`` (converted to str) is used instead. If that default is
            None, the flag is omitted so the script's own default applies.
        """
        if widget is None:
            return None
        args = []
        for key, spec in STAParams.options.items():
            field = widget.findChild(QLineEdit, key)
            value = field.text().strip() if field is not None else ''
            if not value:
                # Do not pass an empty-string argument to the script; fall back
                # to the configured default, or drop the flag when there is none.
                if spec['default'] is None:
                    continue
                value = str(spec['default'])
            args.extend((spec['opt'], value))
        return args
class STAMethod(AIMethodDialog):
    """Dialog configuration for running the STA Net scripts in ``STA_net``."""

    # Name of the environment the scripts run in (presumably a conda env
    # with torch 1.12.1 / CUDA 11.3; confirm against the launcher).
    ENV = 'torch1121cu113'
    # Front-end class that builds the settings form and collects CLI args.
    setting_widget = STAParams
    # (stage key, display label) pairs: training, testing, batch prediction.
    stages = [('train', '训练'), ('test', '测试'), ('predict_batch', '批量预测')]
    name = 'STA Net'

    @property
    def workdir(self):
        """Absolute path of the ``STA_net`` directory beside this module."""
        module_dir = os.path.dirname(__file__)
        return os.path.abspath(os.path.join(module_dir, 'STA_net'))

    def stage_script(self, stage):
        """Return the script path for *stage*, or None if it has no script.

        Only 'train' (train.py) and 'test' (test.py) map to scripts inside
        ``workdir``. Every other stage, including 'predict_batch', yields None.
        """
        for stage_key, script_name in (('train', 'train.py'), ('test', 'test.py')):
            if stage == stage_key:
                return os.path.join(self.workdir, script_name)
        return None
@AI_METHOD.register
class STANet(AlgFrontend):
    """Front-end entry for STA Net, registered through ``AI_METHOD.register``."""

    @staticmethod
    def get_name():
        """Return the method name ``'STA Net'``."""
        method_name = 'STA Net'
        return method_name

    @staticmethod
    def get_icon():
        """Return the ``AI_DETECT`` attribute of ``IconInstance()``."""
        icons = IconInstance()
        return icons.AI_DETECT

    @staticmethod
    def get_widget(parent=None):
        """Create an STAMethod dialog, passing *parent* to its constructor."""
        dialog = STAMethod(parent)
        return dialog