from rscder.utils.icons import IconInstance
from ai_method.basic_cd import AIMethodDialog, AI_METHOD, AlgFrontend
import os
from PyQt5.QtWidgets import QFormLayout, QLabel, QLineEdit, QPushButton, QDialogButtonBox, QWidget
from PyQt5.QtCore import Qt


class STAParams(AlgFrontend):
    """Settings form for STA Net and its conversion to command-line arguments.

    ``options`` maps each QLineEdit object name to:
      - ``label``:   text shown next to the field,
      - ``opt``:     command-line flag the field's value is passed with,
      - ``dtype``:   declared value type (not currently used in this file),
      - ``default``: value used when the field is missing or left blank;
                     ``None`` means "omit the flag entirely".
    """

    options = dict(
        train_dir=dict(
            label='训练集根目录',
            opt='--dataroot',
            dtype='str',
            default=None,
        ),
        val_dir=dict(
            label='验证集根目录',
            opt='--val_dataroot',
            dtype='str',
            default=None,
        ),
        # A test-set entry is not wired in yet: it would need its own flag,
        # because '--dataroot' is already taken by train_dir.
        # test_dir = dict(
        #     label = '测试集根目录',
        #     opt='--dataroot',
        #     dtype='str',
        #     default=None
        # )
    )

    @staticmethod
    def get_widget(parent=None):
        """Build the settings form: one labelled QLineEdit per ``options`` entry.

        Each line edit's object name is its option key, which is how
        ``get_params`` locates it again.

        :param parent: optional parent widget for the returned QWidget.
        :return: a QWidget laid out with a QFormLayout.
        """
        widget = QWidget(parent)
        # Constructing the layout with ``widget`` already installs it on the
        # widget, so no separate setLayout() call is needed.
        form = QFormLayout(widget)
        for key, spec in STAParams.options.items():
            line_edit = QLineEdit()
            line_edit.setObjectName(key)
            form.addRow(QLabel(spec['label']), line_edit)

        # NOTE(review): this test-set field is displayed but get_params never
        # reads it (there is no 'test_dir' entry in options), so anything typed
        # here is currently ignored.
        test_dir_data = QLineEdit()
        test_dir_data.setObjectName('test_dir')
        form.addRow(QLabel('测试集根目录'), test_dir_data)
        return widget

    @staticmethod
    def get_params(widget: QWidget = None):
        """Turn the form contents into a flat ``[flag, value, ...]`` list.

        For every entry in ``options``, the flag is followed by the stripped
        text of the matching QLineEdit. If the line edit is missing or blank,
        the option's ``default`` is used instead, and the flag is omitted when
        that default is None. Blank fields therefore never produce an
        empty-string argument.

        :param widget: the widget returned by ``get_widget`` (or None).
        :return: list of argument strings, or None when ``widget`` is None.
        """
        if widget is None:
            return None
        opt = []
        for key, spec in STAParams.options.items():
            comp = widget.findChild(QLineEdit, name=key)
            value = comp.text().strip() if comp is not None else ''
            if not value:
                # Missing or blank field: fall back to the declared default.
                value = spec['default']
            if value is None:
                continue
            opt.append(spec['opt'])
            opt.append(value)
        return opt


class STAMethod(AIMethodDialog):
    """Dialog that runs the STA Net stage scripts found in ``STA_net/``."""

    # Environment identifier; presumably the name of a Python/conda env with
    # torch 1.12.1 + CUDA 11.3 -- TODO confirm against AIMethodDialog.
    ENV = 'torch1121cu113'
    setting_widget = STAParams
    # (stage id, display label) pairs.
    stages = [
        ('train', '训练'),
        ('test', '测试'),
        ('predict_batch', '批量预测'),
    ]
    name = 'STA Net'

    @property
    def workdir(self):
        """Absolute path of the ``STA_net`` directory next to this file."""
        return os.path.abspath(os.path.join(os.path.dirname(__file__), 'STA_net'))

    def stage_script(self, stage):
        """Return the script path for ``stage``, or None if it has no script.

        NOTE(review): 'predict_batch' is listed in ``stages`` but has no script
        mapped here; how AIMethodDialog handles a None script is not visible in
        this file.
        """
        if stage == 'train':
            return os.path.join(self.workdir, 'train.py')
        if stage == 'test':
            return os.path.join(self.workdir, 'test.py')
        return None


@AI_METHOD.register
class STANet(AlgFrontend):
    """Front-end entry for STA Net, registered with ``AI_METHOD``."""

    @staticmethod
    def get_name():
        """Display name of the method."""
        return 'STA Net'

    @staticmethod
    def get_icon():
        """Icon shown for the method."""
        return IconInstance().AI_DETECT

    @staticmethod
    def get_widget(parent=None):
        """Create the STA Net dialog as a child of ``parent``."""
        return STAMethod(parent)