Works on a CPU-only device as well
This commit is contained in: parent 042d134557, commit f4b994d30b
.gitignore (vendored) | 3
@@ -278,4 +278,5 @@ $RECYCLE.BIN/
 
 # Custom rules (everything added below won't be overriden by 'Generate .gitignore File' if you use 'Update' option)
 # pytorch pth
-*.pth
+*.pth
+install/
main.py | 6
@@ -4,9 +4,11 @@
 
 from PyQt5 import QtWidgets
 import os
-os.environ['SAM_ANN_BASE_DIR'] = os.path.dirname(__file__)
-from sam_ann.widgets.mainwindow import MainWindow
+import sys
+os.environ['SAM_ANN_BASE_DIR'] = os.path.dirname(os.path.abspath(__file__))
+sys.path.insert(0, os.path.join(os.environ['SAM_ANN_BASE_DIR'], 'lib'))
+from sam_ann import MainWindow
 
 
 
 if __name__ == '__main__':
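What main.py switches to here is a small bootstrap pattern: record the project root in an environment variable before any sam_ann import happens, and put the vendored lib directory on sys.path so the bundled packages resolve from the install location. A minimal sketch of the same idea, assuming the same BASE_DIR/lib layout shown in the hunk above:

import os
import sys

# Resolve the project root from this file's absolute location,
# independent of the current working directory.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))

# Expose the root to the rest of the application through an environment variable.
os.environ['SAM_ANN_BASE_DIR'] = BASE_DIR

# Make vendored packages under <root>/lib importable before the app package loads.
sys.path.insert(0, os.path.join(BASE_DIR, 'lib'))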
@@ -0,0 +1 @@
+from .widgets.mainwindow import MainWindow
@@ -40,7 +40,7 @@ class Annotation:
             print('Warning: Except image has 2 or 3 ndim, but get {}.'.format(image.ndim))
         del image
 
-        self.objects:List[Object,...] = []
+        self.objects:List[Object] = []
 
     def load_annotation(self):
         if os.path.exists(self.label_path):
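The type-hint fix above is more than cosmetic: List takes exactly one element type, while the 'type, ...' form belongs to Tuple (a variable-length homogeneous tuple), so List[Object, ...] is not valid typing syntax. A tiny illustration of the two forms (Object here is just a stand-in class for this sketch, not the annotation module's):

from typing import List, Tuple

class Object:
    # stand-in for the project's Object class, only for this sketch
    pass

objects: List[Object] = []           # a list: one element type inside the brackets
points: Tuple[int, ...] = (1, 2, 3)  # the 'type, ...' spelling is Tuple's, not List's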
@@ -3,9 +3,9 @@ from enum import Enum
 import os
 BASE_DIR = os.environ['SAM_ANN_BASE_DIR']
 
-DEFAULT_CONFIG_FILE = 'default.yaml'
-CONFIG_FILE = 'isat.yaml'
+DEFAULT_CONFIG_FILE = os.path.join(BASE_DIR, 'sam_ann', 'default.yaml')
+CONFIG_FILE = os.path.join(BASE_DIR, 'sam_ann', 'isat.yaml')
+CHECKPOINTS = os.path.join(BASE_DIR, 'checkpoints')
 
 def load_config(file):
     with open(file, 'rb')as f:
         cfg = yaml.load(f.read(), Loader=yaml.FullLoader)
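Building the config and checkpoint paths on top of BASE_DIR makes them resolve the same way no matter which directory the program is launched from; the bare names 'default.yaml' and 'isat.yaml' used to depend on the current working directory. A rough sketch of the effect (the fallback to os.getcwd() is only so this snippet runs on its own):

import os

BASE_DIR = os.environ.get('SAM_ANN_BASE_DIR', os.getcwd())

# Absolute paths anchored on the install location rather than the working directory.
CONFIG_FILE = os.path.join(BASE_DIR, 'sam_ann', 'isat.yaml')
CHECKPOINTS = os.path.join(BASE_DIR, 'checkpoints')

print(CONFIG_FILE)
print(CHECKPOINTS)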
@@ -88,7 +88,7 @@ def build_sam_vit_t(checkpoint=None):
     mobile_sam.eval()
     if checkpoint is not None:
         with open(checkpoint, "rb") as f:
-            state_dict = torch.load(f)
+            state_dict = torch.load(f, map_location=torch.device('cpu'))
         mobile_sam.load_state_dict(state_dict)
     return mobile_sam
 

@@ -152,7 +152,7 @@ def _build_sam(
     sam.eval()
     if checkpoint is not None:
         with open(checkpoint, "rb") as f:
-            state_dict = torch.load(f)
+            state_dict = torch.load(f, map_location=torch.device('cpu'))
         sam.load_state_dict(state_dict)
     return sam
 
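These torch.load changes are the point of the commit: a checkpoint saved from a CUDA device records CUDA storage locations, so torch.load without map_location tries to restore the tensors onto a GPU and fails on a CPU-only machine. Mapping everything to the CPU first always works, and the model can still be moved to a GPU afterwards if one exists. A standalone sketch of that loading pattern (load_checkpoint and its arguments are illustrative, not names from the repository):

import torch
import torch.nn as nn

def load_checkpoint(model: nn.Module, path: str) -> nn.Module:
    # Deserialize onto the CPU so CUDA-saved checkpoints load without a GPU.
    with open(path, "rb") as f:
        state_dict = torch.load(f, map_location=torch.device('cpu'))
    model.load_state_dict(state_dict)
    # Move to the GPU only if one is actually available.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    return model.to(device)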
@@ -8,37 +8,37 @@ import torch
 import numpy as np
 import timm
 import os
 
+from sam_ann.configs import CHECKPOINTS
 class SegAny:
     def __init__(self, checkpoint):
-        if 'mobile_sam' in os.path.basename(checkpoint):
+        if 'mobile_sam' in checkpoint:
             # mobile sam
-            from mobile_sam import sam_model_registry, SamPredictor
+            from sam_ann.mobile_sam import sam_model_registry, SamPredictor
             print('- mobile sam!')
             self.model_type = "vit_t"
-        elif 'sam_hq_vit' in os.path.basename(checkpoint):
+        elif 'sam_hq_vit' in checkpoint:
             # sam hq
-            from segment_anything_hq import sam_model_registry, SamPredictor
+            from sam_ann.segment_anything_hq import sam_model_registry, SamPredictor
             print('- sam hq!')
-            if 'vit_b' in os.path.basename(checkpoint):
+            if 'vit_b' in checkpoint:
                 self.model_type = "vit_b"
-            elif 'vit_l' in os.path.basename(checkpoint):
+            elif 'vit_l' in checkpoint:
                 self.model_type = "vit_l"
-            elif 'vit_h' in os.path.basename(checkpoint):
+            elif 'vit_h' in checkpoint:
                 self.model_type = "vit_h"
-            elif 'vit_tiny' in os.path.basename(checkpoint):
+            elif 'vit_tiny' in checkpoint:
                 self.model_type = "vit_tiny"
             else:
                 raise ValueError('The checkpoint named {} is not supported.'.format(checkpoint))
-        elif 'sam_vit' in os.path.basename(checkpoint):
+        elif 'sam_vit' in checkpoint:
             # sam
-            from segment_anything import sam_model_registry, SamPredictor
+            from sam_ann.segment_anything import sam_model_registry, SamPredictor
             print('- sam!')
-            if 'vit_b' in os.path.basename(checkpoint):
+            if 'vit_b' in checkpoint:
                 self.model_type = "vit_b"
-            elif 'vit_l' in os.path.basename(checkpoint):
+            elif 'vit_l' in checkpoint:
                 self.model_type = "vit_l"
-            elif 'vit_h' in os.path.basename(checkpoint):
+            elif 'vit_h' in checkpoint:
                 self.model_type = "vit_h"
             else:
                 raise ValueError('The checkpoint named {} is not supported.'.format(checkpoint))

@@ -47,7 +47,7 @@ class SegAny:
         torch.cuda.empty_cache()
 
         self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
-        sam = sam_model_registry[self.model_type](checkpoint=checkpoint)
+        sam = sam_model_registry[self.model_type](checkpoint=os.path.join(CHECKPOINTS, checkpoint))
         sam.to(device=self.device)
         self.predictor_with_point_prompt = SamPredictor(sam)
         self.image = None
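SegAny chooses both the backend package and the model type purely from the checkpoint file name, which fits the new calling convention of passing a bare file name and joining it with CHECKPOINTS only at load time. A condensed sketch of that naming dispatch (simplified to the model-type part; it mirrors the convention above rather than reproducing the class):

import os

def infer_model_type(checkpoint: str) -> str:
    # File names encode the backbone, e.g. 'mobile_sam.pt' or 'sam_hq_vit_b.pth'.
    name = os.path.basename(checkpoint)
    if 'mobile_sam' in name:
        return 'vit_t'
    for vit in ('vit_b', 'vit_l', 'vit_h', 'vit_tiny'):
        if vit in name:
            return vit
    raise ValueError('The checkpoint named {} is not supported.'.format(name))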
@@ -102,6 +102,6 @@ def _build_sam(
     sam.eval()
     if checkpoint is not None:
         with open(checkpoint, "rb") as f:
-            state_dict = torch.load(f)
+            state_dict = torch.load(f, map_location=torch.device('cpu'))
         sam.load_state_dict(state_dict)
     return sam
@@ -89,7 +89,7 @@ def build_sam_vit_t(checkpoint=None):
     mobile_sam.eval()
     if checkpoint is not None:
         with open(checkpoint, "rb") as f:
-            state_dict = torch.load(f)
+            state_dict = torch.load(f, map_location=torch.device('cpu'))
         info = mobile_sam.load_state_dict(state_dict, strict=False)
         print(info)
     for n, p in mobile_sam.named_parameters():

@@ -157,7 +157,7 @@ def _build_sam(
     sam.eval()
     if checkpoint is not None:
         with open(checkpoint, "rb") as f:
-            state_dict = torch.load(f)
+            state_dict = torch.load(f, map_location=torch.device('cpu'))
         info = sam.load_state_dict(state_dict, strict=False)
         print(info)
     for n, p in sam.named_parameters():
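Unlike the plain builders, these HQ variants load with strict=False and print the result: load_state_dict then reports missing and unexpected keys instead of raising, which is the behavior you want when a checkpoint only covers part of the model. A small illustration with toy modules (deliberately mismatched; these are not the SAM classes):

import torch.nn as nn

saved = nn.Linear(4, 2)                 # pretend this produced the checkpoint
model = nn.Sequential(nn.Linear(4, 2))  # the structure being loaded into differs

state_dict = saved.state_dict()
# strict=False tolerates key mismatches and reports them instead of raising.
info = model.load_state_dict(state_dict, strict=False)
print(info.missing_keys)     # keys the model expected but the checkpoint lacked
print(info.unexpected_keys)  # keys the checkpoint carried that the model ignored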
@@ -88,7 +88,7 @@ def build_sam_vit_t(checkpoint=None):
     mobile_sam.eval()
     if checkpoint is not None:
         with open(checkpoint, "rb") as f:
-            state_dict = torch.load(f)
+            state_dict = torch.load(f, map_location=torch.device('cpu'))
         mobile_sam.load_state_dict(state_dict)
     return mobile_sam

@@ -151,6 +151,6 @@ def _build_sam(
     sam.eval()
     if checkpoint is not None:
         with open(checkpoint, "rb") as f:
-            state_dict = torch.load(f)
+            state_dict = torch.load(f, map_location=torch.device('cpu'))
         sam.load_state_dict(state_dict)
     return sam
@@ -18,7 +18,7 @@ from sam_ann.widgets.ISAT_to_COCO_dialog import ISATtoCOCODialog
 from sam_ann.widgets.ISAT_to_LABELME_dialog import ISATtoLabelMeDialog
 from sam_ann.widgets.COCO_to_ISAT_dialog import COCOtoISATDialog
 from sam_ann.widgets.canvas import AnnotationScene, AnnotationView
-from sam_ann.configs import STATUSMode, MAPMode, load_config, save_config, CONFIG_FILE, DEFAULT_CONFIG_FILE
+from sam_ann.configs import STATUSMode, MAPMode, load_config, save_config, CONFIG_FILE, DEFAULT_CONFIG_FILE, CHECKPOINTS
 from sam_ann.annotation import Object, Annotation
 from sam_ann.widgets.polygon import Polygon
 from sam_ann.configs import BASE_DIR

@@ -77,7 +77,7 @@ class MainWindow(QtWidgets.QMainWindow, Ui_MainWindow):
             for name, action in self.pths_actions.items():
                 action.setChecked(model_name == name)
             return
-        model_path = os.path.join('segment_any', model_name)
+        model_path = os.path.join(CHECKPOINTS, model_name)
         if not os.path.exists(model_path):
             QtWidgets.QMessageBox.warning(self, 'Warning',
                 'The checkpoint of [Segment anything] not existed. If you want use quick annotate, please download from {}'.format(

@@ -189,7 +189,7 @@ class MainWindow(QtWidgets.QMainWindow, Ui_MainWindow):
         self.statusbar.addPermanentWidget(self.labelData)
 
         #
-        model_names = sorted([pth for pth in os.listdir(os.path.join(BASE_DIR, 'checkpoints')) if pth.endswith('.pth') or pth.endswith('.pt')])
+        model_names = sorted([pth for pth in os.listdir(CHECKPOINTS) if pth.endswith('.pth') or pth.endswith('.pt')])
         self.pths_actions = {}
         for model_name in model_names:
             action = QtWidgets.QAction(self)
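With the checkpoint folder centralized in CHECKPOINTS, the model menu is built by listing that one directory for weight files. A sketch of just the listing step (the environment-variable fallback is only so the snippet runs on its own):

import os

CHECKPOINTS = os.path.join(os.environ.get('SAM_ANN_BASE_DIR', '.'), 'checkpoints')

model_names = []
if os.path.isdir(CHECKPOINTS):
    # Collect every PyTorch weight file in the shared checkpoint directory.
    model_names = sorted(
        pth for pth in os.listdir(CHECKPOINTS)
        if pth.endswith('.pth') or pth.endswith('.pt')
    )
print(model_names)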