diff --git a/.gitignore b/.gitignore index 3c5acb4..7269ba6 100644 --- a/.gitignore +++ b/.gitignore @@ -278,4 +278,5 @@ $RECYCLE.BIN/ # Custom rules (everything added below won't be overriden by 'Generate .gitignore File' if you use 'Update' option) # pytorch pth -*.pth \ No newline at end of file +*.pth +install/ \ No newline at end of file diff --git a/main.py b/main.py index 4398988..07a2a7b 100644 --- a/main.py +++ b/main.py @@ -4,9 +4,11 @@ from PyQt5 import QtWidgets import os -os.environ['SAM_ANN_BASE_DIR'] = os.path.dirname(__file__) -from sam_ann.widgets.mainwindow import MainWindow import sys +os.environ['SAM_ANN_BASE_DIR'] = os.path.dirname(os.path.abspath(__file__)) +sys.path.insert(0, os.path.join(os.environ['SAM_ANN_BASE_DIR'], 'lib')) +from sam_ann import MainWindow + if __name__ == '__main__': diff --git a/sam_ann/__init__.py b/sam_ann/__init__.py index e69de29..6f0a62e 100644 --- a/sam_ann/__init__.py +++ b/sam_ann/__init__.py @@ -0,0 +1 @@ +from .widgets.mainwindow import MainWindow \ No newline at end of file diff --git a/sam_ann/annotation.py b/sam_ann/annotation.py index 1992622..003d00c 100644 --- a/sam_ann/annotation.py +++ b/sam_ann/annotation.py @@ -40,7 +40,7 @@ class Annotation: print('Warning: Except image has 2 or 3 ndim, but get {}.'.format(image.ndim)) del image - self.objects:List[Object,...] = [] + self.objects:List[Object] = [] def load_annotation(self): if os.path.exists(self.label_path): diff --git a/sam_ann/configs.py b/sam_ann/configs.py index f51ed17..e9738d4 100644 --- a/sam_ann/configs.py +++ b/sam_ann/configs.py @@ -3,9 +3,9 @@ from enum import Enum import os BASE_DIR = os.environ['SAM_ANN_BASE_DIR'] -DEFAULT_CONFIG_FILE = 'default.yaml' -CONFIG_FILE = 'isat.yaml' - +DEFAULT_CONFIG_FILE = os.path.join(BASE_DIR, 'sam_ann', 'default.yaml') +CONFIG_FILE = os.path.join(BASE_DIR, 'sam_ann', 'isat.yaml') +CHECKPOINTS = os.path.join(BASE_DIR, 'checkpoints') def load_config(file): with open(file, 'rb')as f: cfg = yaml.load(f.read(), Loader=yaml.FullLoader) diff --git a/default.yaml b/sam_ann/default.yaml similarity index 100% rename from default.yaml rename to sam_ann/default.yaml diff --git a/isat.yaml b/sam_ann/isat.yaml similarity index 100% rename from isat.yaml rename to sam_ann/isat.yaml diff --git a/sam_ann/mobile_sam/build_sam.py b/sam_ann/mobile_sam/build_sam.py index 9a52c50..f2e4e94 100644 --- a/sam_ann/mobile_sam/build_sam.py +++ b/sam_ann/mobile_sam/build_sam.py @@ -88,7 +88,7 @@ def build_sam_vit_t(checkpoint=None): mobile_sam.eval() if checkpoint is not None: with open(checkpoint, "rb") as f: - state_dict = torch.load(f) + state_dict = torch.load(f, map_location=torch.device('cpu')) mobile_sam.load_state_dict(state_dict) return mobile_sam @@ -152,7 +152,7 @@ def _build_sam( sam.eval() if checkpoint is not None: with open(checkpoint, "rb") as f: - state_dict = torch.load(f) + state_dict = torch.load(f, map_location=torch.device('cpu')) sam.load_state_dict(state_dict) return sam diff --git a/sam_ann/segment_any/segment_any.py b/sam_ann/segment_any/segment_any.py index 5c1f7aa..6db6192 100644 --- a/sam_ann/segment_any/segment_any.py +++ b/sam_ann/segment_any/segment_any.py @@ -8,37 +8,37 @@ import torch import numpy as np import timm import os - +from sam_ann.configs import CHECKPOINTS class SegAny: def __init__(self, checkpoint): - if 'mobile_sam' in os.path.basename(checkpoint): + if 'mobile_sam' in checkpoint: # mobile sam - from mobile_sam import sam_model_registry, SamPredictor + from sam_ann.mobile_sam import sam_model_registry, SamPredictor print('- mobile sam!') self.model_type = "vit_t" - elif 'sam_hq_vit' in os.path.basename(checkpoint): + elif 'sam_hq_vit' in checkpoint: # sam hq - from segment_anything_hq import sam_model_registry, SamPredictor + from sam_ann.segment_anything_hq import sam_model_registry, SamPredictor print('- sam hq!') - if 'vit_b' in os.path.basename(checkpoint): + if 'vit_b' in checkpoint: self.model_type = "vit_b" - elif 'vit_l' in os.path.basename(checkpoint): + elif 'vit_l' in checkpoint: self.model_type = "vit_l" - elif 'vit_h' in os.path.basename(checkpoint): + elif 'vit_h' in checkpoint: self.model_type = "vit_h" - elif 'vit_tiny' in os.path.basename(checkpoint): + elif 'vit_tiny' in checkpoint: self.model_type = "vit_tiny" else: raise ValueError('The checkpoint named {} is not supported.'.format(checkpoint)) - elif 'sam_vit' in os.path.basename(checkpoint): + elif 'sam_vit' in checkpoint: # sam - from segment_anything import sam_model_registry, SamPredictor + from sam_ann.segment_anything import sam_model_registry, SamPredictor print('- sam!') - if 'vit_b' in os.path.basename(checkpoint): + if 'vit_b' in checkpoint: self.model_type = "vit_b" - elif 'vit_l' in os.path.basename(checkpoint): + elif 'vit_l' in checkpoint: self.model_type = "vit_l" - elif 'vit_h' in os.path.basename(checkpoint): + elif 'vit_h' in checkpoint: self.model_type = "vit_h" else: raise ValueError('The checkpoint named {} is not supported.'.format(checkpoint)) @@ -47,7 +47,7 @@ class SegAny: torch.cuda.empty_cache() self.device = 'cuda' if torch.cuda.is_available() else 'cpu' - sam = sam_model_registry[self.model_type](checkpoint=checkpoint) + sam = sam_model_registry[self.model_type](checkpoint=os.path.join(CHECKPOINTS, checkpoint)) sam.to(device=self.device) self.predictor_with_point_prompt = SamPredictor(sam) self.image = None diff --git a/sam_ann/segment_anything/build_sam.py b/sam_ann/segment_anything/build_sam.py index 37cd245..9039785 100644 --- a/sam_ann/segment_anything/build_sam.py +++ b/sam_ann/segment_anything/build_sam.py @@ -102,6 +102,6 @@ def _build_sam( sam.eval() if checkpoint is not None: with open(checkpoint, "rb") as f: - state_dict = torch.load(f) + state_dict = torch.load(f, map_location=torch.device('cpu')) sam.load_state_dict(state_dict) return sam diff --git a/sam_ann/segment_anything_hq/build_sam.py b/sam_ann/segment_anything_hq/build_sam.py index bd04173..19d60b2 100644 --- a/sam_ann/segment_anything_hq/build_sam.py +++ b/sam_ann/segment_anything_hq/build_sam.py @@ -89,7 +89,7 @@ def build_sam_vit_t(checkpoint=None): mobile_sam.eval() if checkpoint is not None: with open(checkpoint, "rb") as f: - state_dict = torch.load(f) + state_dict = torch.load(f, map_location=torch.device('cpu')) info = mobile_sam.load_state_dict(state_dict, strict=False) print(info) for n, p in mobile_sam.named_parameters(): @@ -157,7 +157,7 @@ def _build_sam( sam.eval() if checkpoint is not None: with open(checkpoint, "rb") as f: - state_dict = torch.load(f) + state_dict = torch.load(f, map_location=torch.device('cpu')) info = sam.load_state_dict(state_dict, strict=False) print(info) for n, p in sam.named_parameters(): diff --git a/sam_ann/segment_anything_hq/build_sam_baseline.py b/sam_ann/segment_anything_hq/build_sam_baseline.py index b1d34d7..710048b 100644 --- a/sam_ann/segment_anything_hq/build_sam_baseline.py +++ b/sam_ann/segment_anything_hq/build_sam_baseline.py @@ -88,7 +88,7 @@ def build_sam_vit_t(checkpoint=None): mobile_sam.eval() if checkpoint is not None: with open(checkpoint, "rb") as f: - state_dict = torch.load(f) + state_dict = torch.load(f, map_location=torch.device('cpu')) mobile_sam.load_state_dict(state_dict) return mobile_sam @@ -151,6 +151,6 @@ def _build_sam( sam.eval() if checkpoint is not None: with open(checkpoint, "rb") as f: - state_dict = torch.load(f) + state_dict = torch.load(f, map_location=torch.device('cpu')) sam.load_state_dict(state_dict) return sam \ No newline at end of file diff --git a/sam_ann/widgets/mainwindow.py b/sam_ann/widgets/mainwindow.py index c2b53f1..fb2afad 100644 --- a/sam_ann/widgets/mainwindow.py +++ b/sam_ann/widgets/mainwindow.py @@ -18,7 +18,7 @@ from sam_ann.widgets.ISAT_to_COCO_dialog import ISATtoCOCODialog from sam_ann.widgets.ISAT_to_LABELME_dialog import ISATtoLabelMeDialog from sam_ann.widgets.COCO_to_ISAT_dialog import COCOtoISATDialog from sam_ann.widgets.canvas import AnnotationScene, AnnotationView -from sam_ann.configs import STATUSMode, MAPMode, load_config, save_config, CONFIG_FILE, DEFAULT_CONFIG_FILE +from sam_ann.configs import STATUSMode, MAPMode, load_config, save_config, CONFIG_FILE, DEFAULT_CONFIG_FILE, CHECKPOINTS from sam_ann.annotation import Object, Annotation from sam_ann.widgets.polygon import Polygon from sam_ann.configs import BASE_DIR @@ -77,7 +77,7 @@ class MainWindow(QtWidgets.QMainWindow, Ui_MainWindow): for name, action in self.pths_actions.items(): action.setChecked(model_name == name) return - model_path = os.path.join('segment_any', model_name) + model_path = os.path.join(CHECKPOINTS, model_name) if not os.path.exists(model_path): QtWidgets.QMessageBox.warning(self, 'Warning', 'The checkpoint of [Segment anything] not existed. If you want use quick annotate, please download from {}'.format( @@ -189,7 +189,7 @@ class MainWindow(QtWidgets.QMainWindow, Ui_MainWindow): self.statusbar.addPermanentWidget(self.labelData) # - model_names = sorted([pth for pth in os.listdir(os.path.join(BASE_DIR, 'checkpoints')) if pth.endswith('.pth') or pth.endswith('.pt')]) + model_names = sorted([pth for pth in os.listdir(CHECKPOINTS) if pth.endswith('.pth') or pth.endswith('.pt')]) self.pths_actions = {} for model_name in model_names: action = QtWidgets.QAction(self)