in cpu device also ok

This commit is contained in:
copper 2023-09-12 15:52:49 +08:00
parent 042d134557
commit f4b994d30b
13 changed files with 36 additions and 32 deletions

3
.gitignore vendored
View File

@ -278,4 +278,5 @@ $RECYCLE.BIN/
# Custom rules (everything added below won't be overriden by 'Generate .gitignore File' if you use 'Update' option)
# pytorch pth
*.pth
*.pth
install/

View File

@ -4,9 +4,11 @@
from PyQt5 import QtWidgets
import os
os.environ['SAM_ANN_BASE_DIR'] = os.path.dirname(__file__)
from sam_ann.widgets.mainwindow import MainWindow
import sys
os.environ['SAM_ANN_BASE_DIR'] = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, os.path.join(os.environ['SAM_ANN_BASE_DIR'], 'lib'))
from sam_ann import MainWindow
if __name__ == '__main__':

View File

@ -0,0 +1 @@
from .widgets.mainwindow import MainWindow

View File

@ -40,7 +40,7 @@ class Annotation:
print('Warning: Except image has 2 or 3 ndim, but get {}.'.format(image.ndim))
del image
self.objects:List[Object,...] = []
self.objects:List[Object] = []
def load_annotation(self):
if os.path.exists(self.label_path):

View File

@ -3,9 +3,9 @@ from enum import Enum
import os
BASE_DIR = os.environ['SAM_ANN_BASE_DIR']
DEFAULT_CONFIG_FILE = 'default.yaml'
CONFIG_FILE = 'isat.yaml'
DEFAULT_CONFIG_FILE = os.path.join(BASE_DIR, 'sam_ann', 'default.yaml')
CONFIG_FILE = os.path.join(BASE_DIR, 'sam_ann', 'isat.yaml')
CHECKPOINTS = os.path.join(BASE_DIR, 'checkpoints')
def load_config(file):
with open(file, 'rb')as f:
cfg = yaml.load(f.read(), Loader=yaml.FullLoader)

View File

@ -88,7 +88,7 @@ def build_sam_vit_t(checkpoint=None):
mobile_sam.eval()
if checkpoint is not None:
with open(checkpoint, "rb") as f:
state_dict = torch.load(f)
state_dict = torch.load(f, map_location=torch.device('cpu'))
mobile_sam.load_state_dict(state_dict)
return mobile_sam
@ -152,7 +152,7 @@ def _build_sam(
sam.eval()
if checkpoint is not None:
with open(checkpoint, "rb") as f:
state_dict = torch.load(f)
state_dict = torch.load(f, map_location=torch.device('cpu'))
sam.load_state_dict(state_dict)
return sam

View File

@ -8,37 +8,37 @@ import torch
import numpy as np
import timm
import os
from sam_ann.configs import CHECKPOINTS
class SegAny:
def __init__(self, checkpoint):
if 'mobile_sam' in os.path.basename(checkpoint):
if 'mobile_sam' in checkpoint:
# mobile sam
from mobile_sam import sam_model_registry, SamPredictor
from sam_ann.mobile_sam import sam_model_registry, SamPredictor
print('- mobile sam!')
self.model_type = "vit_t"
elif 'sam_hq_vit' in os.path.basename(checkpoint):
elif 'sam_hq_vit' in checkpoint:
# sam hq
from segment_anything_hq import sam_model_registry, SamPredictor
from sam_ann.segment_anything_hq import sam_model_registry, SamPredictor
print('- sam hq!')
if 'vit_b' in os.path.basename(checkpoint):
if 'vit_b' in checkpoint:
self.model_type = "vit_b"
elif 'vit_l' in os.path.basename(checkpoint):
elif 'vit_l' in checkpoint:
self.model_type = "vit_l"
elif 'vit_h' in os.path.basename(checkpoint):
elif 'vit_h' in checkpoint:
self.model_type = "vit_h"
elif 'vit_tiny' in os.path.basename(checkpoint):
elif 'vit_tiny' in checkpoint:
self.model_type = "vit_tiny"
else:
raise ValueError('The checkpoint named {} is not supported.'.format(checkpoint))
elif 'sam_vit' in os.path.basename(checkpoint):
elif 'sam_vit' in checkpoint:
# sam
from segment_anything import sam_model_registry, SamPredictor
from sam_ann.segment_anything import sam_model_registry, SamPredictor
print('- sam!')
if 'vit_b' in os.path.basename(checkpoint):
if 'vit_b' in checkpoint:
self.model_type = "vit_b"
elif 'vit_l' in os.path.basename(checkpoint):
elif 'vit_l' in checkpoint:
self.model_type = "vit_l"
elif 'vit_h' in os.path.basename(checkpoint):
elif 'vit_h' in checkpoint:
self.model_type = "vit_h"
else:
raise ValueError('The checkpoint named {} is not supported.'.format(checkpoint))
@ -47,7 +47,7 @@ class SegAny:
torch.cuda.empty_cache()
self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
sam = sam_model_registry[self.model_type](checkpoint=checkpoint)
sam = sam_model_registry[self.model_type](checkpoint=os.path.join(CHECKPOINTS, checkpoint))
sam.to(device=self.device)
self.predictor_with_point_prompt = SamPredictor(sam)
self.image = None

View File

@ -102,6 +102,6 @@ def _build_sam(
sam.eval()
if checkpoint is not None:
with open(checkpoint, "rb") as f:
state_dict = torch.load(f)
state_dict = torch.load(f, map_location=torch.device('cpu'))
sam.load_state_dict(state_dict)
return sam

View File

@ -89,7 +89,7 @@ def build_sam_vit_t(checkpoint=None):
mobile_sam.eval()
if checkpoint is not None:
with open(checkpoint, "rb") as f:
state_dict = torch.load(f)
state_dict = torch.load(f, map_location=torch.device('cpu'))
info = mobile_sam.load_state_dict(state_dict, strict=False)
print(info)
for n, p in mobile_sam.named_parameters():
@ -157,7 +157,7 @@ def _build_sam(
sam.eval()
if checkpoint is not None:
with open(checkpoint, "rb") as f:
state_dict = torch.load(f)
state_dict = torch.load(f, map_location=torch.device('cpu'))
info = sam.load_state_dict(state_dict, strict=False)
print(info)
for n, p in sam.named_parameters():

View File

@ -88,7 +88,7 @@ def build_sam_vit_t(checkpoint=None):
mobile_sam.eval()
if checkpoint is not None:
with open(checkpoint, "rb") as f:
state_dict = torch.load(f)
state_dict = torch.load(f, map_location=torch.device('cpu'))
mobile_sam.load_state_dict(state_dict)
return mobile_sam
@ -151,6 +151,6 @@ def _build_sam(
sam.eval()
if checkpoint is not None:
with open(checkpoint, "rb") as f:
state_dict = torch.load(f)
state_dict = torch.load(f, map_location=torch.device('cpu'))
sam.load_state_dict(state_dict)
return sam

View File

@ -18,7 +18,7 @@ from sam_ann.widgets.ISAT_to_COCO_dialog import ISATtoCOCODialog
from sam_ann.widgets.ISAT_to_LABELME_dialog import ISATtoLabelMeDialog
from sam_ann.widgets.COCO_to_ISAT_dialog import COCOtoISATDialog
from sam_ann.widgets.canvas import AnnotationScene, AnnotationView
from sam_ann.configs import STATUSMode, MAPMode, load_config, save_config, CONFIG_FILE, DEFAULT_CONFIG_FILE
from sam_ann.configs import STATUSMode, MAPMode, load_config, save_config, CONFIG_FILE, DEFAULT_CONFIG_FILE, CHECKPOINTS
from sam_ann.annotation import Object, Annotation
from sam_ann.widgets.polygon import Polygon
from sam_ann.configs import BASE_DIR
@ -77,7 +77,7 @@ class MainWindow(QtWidgets.QMainWindow, Ui_MainWindow):
for name, action in self.pths_actions.items():
action.setChecked(model_name == name)
return
model_path = os.path.join('segment_any', model_name)
model_path = os.path.join(CHECKPOINTS, model_name)
if not os.path.exists(model_path):
QtWidgets.QMessageBox.warning(self, 'Warning',
'The checkpoint of [Segment anything] not existed. If you want use quick annotate, please download from {}'.format(
@ -189,7 +189,7 @@ class MainWindow(QtWidgets.QMainWindow, Ui_MainWindow):
self.statusbar.addPermanentWidget(self.labelData)
#
model_names = sorted([pth for pth in os.listdir(os.path.join(BASE_DIR, 'checkpoints')) if pth.endswith('.pth') or pth.endswith('.pt')])
model_names = sorted([pth for pth in os.listdir(CHECKPOINTS) if pth.endswith('.pth') or pth.endswith('.pt')])
self.pths_actions = {}
for model_name in model_names:
action = QtWidgets.QAction(self)