SAM-HQ支持,界面更新,添加模型选择
This commit is contained in:
parent
db139d2c02
commit
f7edb986d4
@ -1,813 +1,12 @@
|
||||
{
|
||||
"info": {
|
||||
"description": "ISAT",
|
||||
"folder": "C:/Users/lg/PycharmProjects/ISAT_with_segment_anything/example/images",
|
||||
"folder": "/home/super/PycharmProjects/ISAT_with_segment_anything/example/images",
|
||||
"name": "000000000144.jpg",
|
||||
"width": 640,
|
||||
"height": 480,
|
||||
"depth": 3,
|
||||
"note": ""
|
||||
},
|
||||
"objects": [
|
||||
{
|
||||
"category": "fence",
|
||||
"group": "1",
|
||||
"segmentation": [
|
||||
[
|
||||
20,
|
||||
239
|
||||
],
|
||||
[
|
||||
17,
|
||||
240
|
||||
],
|
||||
[
|
||||
15,
|
||||
241
|
||||
],
|
||||
[
|
||||
13,
|
||||
242
|
||||
],
|
||||
[
|
||||
11,
|
||||
244
|
||||
],
|
||||
[
|
||||
10,
|
||||
244
|
||||
],
|
||||
[
|
||||
9,
|
||||
245
|
||||
],
|
||||
[
|
||||
6,
|
||||
249
|
||||
],
|
||||
[
|
||||
5,
|
||||
251
|
||||
],
|
||||
[
|
||||
4,
|
||||
255
|
||||
],
|
||||
[
|
||||
0,
|
||||
261
|
||||
],
|
||||
[
|
||||
0,
|
||||
280
|
||||
],
|
||||
[
|
||||
1,
|
||||
277
|
||||
],
|
||||
[
|
||||
6,
|
||||
268
|
||||
],
|
||||
[
|
||||
9,
|
||||
259
|
||||
],
|
||||
[
|
||||
13,
|
||||
254
|
||||
],
|
||||
[
|
||||
13,
|
||||
252
|
||||
],
|
||||
[
|
||||
15,
|
||||
250
|
||||
],
|
||||
[
|
||||
18,
|
||||
248
|
||||
],
|
||||
[
|
||||
20,
|
||||
247
|
||||
],
|
||||
[
|
||||
24,
|
||||
247
|
||||
],
|
||||
[
|
||||
26,
|
||||
248
|
||||
],
|
||||
[
|
||||
30,
|
||||
252
|
||||
],
|
||||
[
|
||||
33,
|
||||
256
|
||||
],
|
||||
[
|
||||
35,
|
||||
260
|
||||
],
|
||||
[
|
||||
38,
|
||||
266
|
||||
],
|
||||
[
|
||||
41,
|
||||
273
|
||||
],
|
||||
[
|
||||
43,
|
||||
280
|
||||
],
|
||||
[
|
||||
50,
|
||||
294
|
||||
],
|
||||
[
|
||||
53,
|
||||
297
|
||||
],
|
||||
[
|
||||
54,
|
||||
297
|
||||
],
|
||||
[
|
||||
59,
|
||||
301
|
||||
],
|
||||
[
|
||||
61,
|
||||
303
|
||||
],
|
||||
[
|
||||
62,
|
||||
316
|
||||
],
|
||||
[
|
||||
63,
|
||||
341
|
||||
],
|
||||
[
|
||||
63,
|
||||
382
|
||||
],
|
||||
[
|
||||
64,
|
||||
405
|
||||
],
|
||||
[
|
||||
66,
|
||||
414
|
||||
],
|
||||
[
|
||||
71,
|
||||
413
|
||||
],
|
||||
[
|
||||
72,
|
||||
402
|
||||
],
|
||||
[
|
||||
72,
|
||||
399
|
||||
],
|
||||
[
|
||||
71,
|
||||
359
|
||||
],
|
||||
[
|
||||
69,
|
||||
324
|
||||
],
|
||||
[
|
||||
70,
|
||||
300
|
||||
],
|
||||
[
|
||||
73,
|
||||
299
|
||||
],
|
||||
[
|
||||
78,
|
||||
294
|
||||
],
|
||||
[
|
||||
80,
|
||||
291
|
||||
],
|
||||
[
|
||||
81,
|
||||
289
|
||||
],
|
||||
[
|
||||
83,
|
||||
283
|
||||
],
|
||||
[
|
||||
85,
|
||||
278
|
||||
],
|
||||
[
|
||||
86,
|
||||
272
|
||||
],
|
||||
[
|
||||
91,
|
||||
258
|
||||
],
|
||||
[
|
||||
94,
|
||||
252
|
||||
],
|
||||
[
|
||||
98,
|
||||
248
|
||||
],
|
||||
[
|
||||
105,
|
||||
248
|
||||
],
|
||||
[
|
||||
107,
|
||||
249
|
||||
],
|
||||
[
|
||||
111,
|
||||
253
|
||||
],
|
||||
[
|
||||
116,
|
||||
261
|
||||
],
|
||||
[
|
||||
118,
|
||||
265
|
||||
],
|
||||
[
|
||||
121,
|
||||
271
|
||||
],
|
||||
[
|
||||
128,
|
||||
288
|
||||
],
|
||||
[
|
||||
130,
|
||||
291
|
||||
],
|
||||
[
|
||||
135,
|
||||
296
|
||||
],
|
||||
[
|
||||
138,
|
||||
298
|
||||
],
|
||||
[
|
||||
140,
|
||||
299
|
||||
],
|
||||
[
|
||||
143,
|
||||
300
|
||||
],
|
||||
[
|
||||
148,
|
||||
300
|
||||
],
|
||||
[
|
||||
151,
|
||||
299
|
||||
],
|
||||
[
|
||||
153,
|
||||
298
|
||||
],
|
||||
[
|
||||
156,
|
||||
296
|
||||
],
|
||||
[
|
||||
159,
|
||||
293
|
||||
],
|
||||
[
|
||||
161,
|
||||
290
|
||||
],
|
||||
[
|
||||
164,
|
||||
284
|
||||
],
|
||||
[
|
||||
179,
|
||||
255
|
||||
],
|
||||
[
|
||||
180,
|
||||
253
|
||||
],
|
||||
[
|
||||
180,
|
||||
251
|
||||
],
|
||||
[
|
||||
181,
|
||||
249
|
||||
],
|
||||
[
|
||||
183,
|
||||
247
|
||||
],
|
||||
[
|
||||
186,
|
||||
248
|
||||
],
|
||||
[
|
||||
187,
|
||||
259
|
||||
],
|
||||
[
|
||||
188,
|
||||
282
|
||||
],
|
||||
[
|
||||
188,
|
||||
303
|
||||
],
|
||||
[
|
||||
190,
|
||||
320
|
||||
],
|
||||
[
|
||||
190,
|
||||
343
|
||||
],
|
||||
[
|
||||
191,
|
||||
355
|
||||
],
|
||||
[
|
||||
192,
|
||||
357
|
||||
],
|
||||
[
|
||||
196,
|
||||
358
|
||||
],
|
||||
[
|
||||
198,
|
||||
357
|
||||
],
|
||||
[
|
||||
200,
|
||||
353
|
||||
],
|
||||
[
|
||||
200,
|
||||
341
|
||||
],
|
||||
[
|
||||
198,
|
||||
321
|
||||
],
|
||||
[
|
||||
197,
|
||||
300
|
||||
],
|
||||
[
|
||||
196,
|
||||
279
|
||||
],
|
||||
[
|
||||
195,
|
||||
252
|
||||
],
|
||||
[
|
||||
195,
|
||||
250
|
||||
],
|
||||
[
|
||||
196,
|
||||
248
|
||||
],
|
||||
[
|
||||
200,
|
||||
248
|
||||
],
|
||||
[
|
||||
200,
|
||||
249
|
||||
],
|
||||
[
|
||||
202,
|
||||
251
|
||||
],
|
||||
[
|
||||
207,
|
||||
260
|
||||
],
|
||||
[
|
||||
208,
|
||||
262
|
||||
],
|
||||
[
|
||||
208,
|
||||
264
|
||||
],
|
||||
[
|
||||
210,
|
||||
266
|
||||
],
|
||||
[
|
||||
212,
|
||||
273
|
||||
],
|
||||
[
|
||||
215,
|
||||
279
|
||||
],
|
||||
[
|
||||
217,
|
||||
281
|
||||
],
|
||||
[
|
||||
218,
|
||||
283
|
||||
],
|
||||
[
|
||||
219,
|
||||
287
|
||||
],
|
||||
[
|
||||
221,
|
||||
290
|
||||
],
|
||||
[
|
||||
225,
|
||||
294
|
||||
],
|
||||
[
|
||||
227,
|
||||
295
|
||||
],
|
||||
[
|
||||
230,
|
||||
296
|
||||
],
|
||||
[
|
||||
235,
|
||||
296
|
||||
],
|
||||
[
|
||||
239,
|
||||
295
|
||||
],
|
||||
[
|
||||
241,
|
||||
294
|
||||
],
|
||||
[
|
||||
245,
|
||||
290
|
||||
],
|
||||
[
|
||||
246,
|
||||
288
|
||||
],
|
||||
[
|
||||
247,
|
||||
286
|
||||
],
|
||||
[
|
||||
248,
|
||||
283
|
||||
],
|
||||
[
|
||||
248,
|
||||
279
|
||||
],
|
||||
[
|
||||
245,
|
||||
281
|
||||
],
|
||||
[
|
||||
244,
|
||||
283
|
||||
],
|
||||
[
|
||||
244,
|
||||
285
|
||||
],
|
||||
[
|
||||
243,
|
||||
287
|
||||
],
|
||||
[
|
||||
238,
|
||||
292
|
||||
],
|
||||
[
|
||||
230,
|
||||
292
|
||||
],
|
||||
[
|
||||
228,
|
||||
291
|
||||
],
|
||||
[
|
||||
224,
|
||||
287
|
||||
],
|
||||
[
|
||||
222,
|
||||
284
|
||||
],
|
||||
[
|
||||
219,
|
||||
278
|
||||
],
|
||||
[
|
||||
214,
|
||||
268
|
||||
],
|
||||
[
|
||||
214,
|
||||
266
|
||||
],
|
||||
[
|
||||
207,
|
||||
252
|
||||
],
|
||||
[
|
||||
205,
|
||||
249
|
||||
],
|
||||
[
|
||||
201,
|
||||
245
|
||||
],
|
||||
[
|
||||
198,
|
||||
243
|
||||
],
|
||||
[
|
||||
194,
|
||||
242
|
||||
],
|
||||
[
|
||||
188,
|
||||
242
|
||||
],
|
||||
[
|
||||
183,
|
||||
243
|
||||
],
|
||||
[
|
||||
180,
|
||||
244
|
||||
],
|
||||
[
|
||||
176,
|
||||
246
|
||||
],
|
||||
[
|
||||
174,
|
||||
249
|
||||
],
|
||||
[
|
||||
172,
|
||||
252
|
||||
],
|
||||
[
|
||||
170,
|
||||
256
|
||||
],
|
||||
[
|
||||
169,
|
||||
260
|
||||
],
|
||||
[
|
||||
168,
|
||||
262
|
||||
],
|
||||
[
|
||||
166,
|
||||
264
|
||||
],
|
||||
[
|
||||
165,
|
||||
266
|
||||
],
|
||||
[
|
||||
164,
|
||||
268
|
||||
],
|
||||
[
|
||||
161,
|
||||
276
|
||||
],
|
||||
[
|
||||
156,
|
||||
286
|
||||
],
|
||||
[
|
||||
153,
|
||||
290
|
||||
],
|
||||
[
|
||||
152,
|
||||
291
|
||||
],
|
||||
[
|
||||
151,
|
||||
291
|
||||
],
|
||||
[
|
||||
149,
|
||||
293
|
||||
],
|
||||
[
|
||||
147,
|
||||
293
|
||||
],
|
||||
[
|
||||
144,
|
||||
293
|
||||
],
|
||||
[
|
||||
142,
|
||||
293
|
||||
],
|
||||
[
|
||||
140,
|
||||
292
|
||||
],
|
||||
[
|
||||
138,
|
||||
290
|
||||
],
|
||||
[
|
||||
136,
|
||||
287
|
||||
],
|
||||
[
|
||||
133,
|
||||
281
|
||||
],
|
||||
[
|
||||
125,
|
||||
263
|
||||
],
|
||||
[
|
||||
123,
|
||||
259
|
||||
],
|
||||
[
|
||||
120,
|
||||
253
|
||||
],
|
||||
[
|
||||
117,
|
||||
248
|
||||
],
|
||||
[
|
||||
112,
|
||||
243
|
||||
],
|
||||
[
|
||||
110,
|
||||
242
|
||||
],
|
||||
[
|
||||
107,
|
||||
241
|
||||
],
|
||||
[
|
||||
98,
|
||||
241
|
||||
],
|
||||
[
|
||||
96,
|
||||
242
|
||||
],
|
||||
[
|
||||
94,
|
||||
243
|
||||
],
|
||||
[
|
||||
89,
|
||||
248
|
||||
],
|
||||
[
|
||||
85,
|
||||
256
|
||||
],
|
||||
[
|
||||
83,
|
||||
261
|
||||
],
|
||||
[
|
||||
81,
|
||||
269
|
||||
],
|
||||
[
|
||||
78,
|
||||
274
|
||||
],
|
||||
[
|
||||
76,
|
||||
282
|
||||
],
|
||||
[
|
||||
74,
|
||||
287
|
||||
],
|
||||
[
|
||||
73,
|
||||
289
|
||||
],
|
||||
[
|
||||
67,
|
||||
295
|
||||
],
|
||||
[
|
||||
63,
|
||||
295
|
||||
],
|
||||
[
|
||||
61,
|
||||
294
|
||||
],
|
||||
[
|
||||
55,
|
||||
288
|
||||
],
|
||||
[
|
||||
53,
|
||||
284
|
||||
],
|
||||
[
|
||||
50,
|
||||
278
|
||||
],
|
||||
[
|
||||
49,
|
||||
273
|
||||
],
|
||||
[
|
||||
41,
|
||||
254
|
||||
],
|
||||
[
|
||||
40,
|
||||
252
|
||||
],
|
||||
[
|
||||
38,
|
||||
249
|
||||
],
|
||||
[
|
||||
35,
|
||||
246
|
||||
],
|
||||
[
|
||||
35,
|
||||
245
|
||||
],
|
||||
[
|
||||
33,
|
||||
243
|
||||
],
|
||||
[
|
||||
30,
|
||||
241
|
||||
],
|
||||
[
|
||||
28,
|
||||
240
|
||||
],
|
||||
[
|
||||
24,
|
||||
239
|
||||
]
|
||||
],
|
||||
"area": 4485.0,
|
||||
"layer": 1.0,
|
||||
"bbox": [
|
||||
0.0,
|
||||
239.0,
|
||||
248.0,
|
||||
414.0
|
||||
],
|
||||
"iscrowd": 0,
|
||||
"note": ""
|
||||
}
|
||||
]
|
||||
"objects": []
|
||||
}
|
@ -1,4 +1,4 @@
|
||||
contour_mode: external
|
||||
contour_mode: all
|
||||
label:
|
||||
- color: '#000000'
|
||||
name: __background__
|
||||
|
@ -1,4 +1,4 @@
|
||||
contour_mode: external
|
||||
contour_mode: all
|
||||
label:
|
||||
- color: '#000000'
|
||||
name: __background__
|
||||
|
2
main.py
2
main.py
@ -12,5 +12,5 @@ if __name__ == '__main__':
|
||||
app = QtWidgets.QApplication([''])
|
||||
mainwindow = MainWindow()
|
||||
mainwindow.show()
|
||||
sys.exit(app.exec_())
|
||||
sys.exit(app.exec())
|
||||
|
||||
|
@ -16,6 +16,7 @@ class SegAny:
|
||||
self.model_type = "vit_h"
|
||||
else:
|
||||
raise ValueError('The checkpoint named {} is not supported.'.format(checkpoint))
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
||||
sam = sam_model_registry[self.model_type](checkpoint=checkpoint)
|
||||
|
@ -11,5 +11,6 @@ from .build_sam import (
|
||||
build_sam_vit_b,
|
||||
sam_model_registry,
|
||||
)
|
||||
from .build_sam_baseline import sam_model_registry_baseline
|
||||
from .predictor import SamPredictor
|
||||
from .automatic_mask_generator import SamAutomaticMaskGenerator
|
||||
|
@ -134,7 +134,7 @@ class SamAutomaticMaskGenerator:
|
||||
self.output_mode = output_mode
|
||||
|
||||
@torch.no_grad()
|
||||
def generate(self, image: np.ndarray) -> List[Dict[str, Any]]:
|
||||
def generate(self, image: np.ndarray, multimask_output: bool = True) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Generates masks for the given image.
|
||||
|
||||
@ -160,7 +160,7 @@ class SamAutomaticMaskGenerator:
|
||||
"""
|
||||
|
||||
# Generate masks
|
||||
mask_data = self._generate_masks(image)
|
||||
mask_data = self._generate_masks(image, multimask_output)
|
||||
|
||||
# Filter small disconnected regions and holes in masks
|
||||
if self.min_mask_region_area > 0:
|
||||
@ -194,7 +194,7 @@ class SamAutomaticMaskGenerator:
|
||||
|
||||
return curr_anns
|
||||
|
||||
def _generate_masks(self, image: np.ndarray) -> MaskData:
|
||||
def _generate_masks(self, image: np.ndarray, multimask_output: bool = True) -> MaskData:
|
||||
orig_size = image.shape[:2]
|
||||
crop_boxes, layer_idxs = generate_crop_boxes(
|
||||
orig_size, self.crop_n_layers, self.crop_overlap_ratio
|
||||
@ -203,7 +203,7 @@ class SamAutomaticMaskGenerator:
|
||||
# Iterate over image crops
|
||||
data = MaskData()
|
||||
for crop_box, layer_idx in zip(crop_boxes, layer_idxs):
|
||||
crop_data = self._process_crop(image, crop_box, layer_idx, orig_size)
|
||||
crop_data = self._process_crop(image, crop_box, layer_idx, orig_size, multimask_output)
|
||||
data.cat(crop_data)
|
||||
|
||||
# Remove duplicate masks between crops
|
||||
@ -228,6 +228,7 @@ class SamAutomaticMaskGenerator:
|
||||
crop_box: List[int],
|
||||
crop_layer_idx: int,
|
||||
orig_size: Tuple[int, ...],
|
||||
multimask_output: bool = True,
|
||||
) -> MaskData:
|
||||
# Crop the image and calculate embeddings
|
||||
x0, y0, x1, y1 = crop_box
|
||||
@ -242,7 +243,7 @@ class SamAutomaticMaskGenerator:
|
||||
# Generate masks for this crop in batches
|
||||
data = MaskData()
|
||||
for (points,) in batch_iterator(self.points_per_batch, points_for_image):
|
||||
batch_data = self._process_batch(points, cropped_im_size, crop_box, orig_size)
|
||||
batch_data = self._process_batch(points, cropped_im_size, crop_box, orig_size, multimask_output)
|
||||
data.cat(batch_data)
|
||||
del batch_data
|
||||
self.predictor.reset_image()
|
||||
@ -269,6 +270,7 @@ class SamAutomaticMaskGenerator:
|
||||
im_size: Tuple[int, ...],
|
||||
crop_box: List[int],
|
||||
orig_size: Tuple[int, ...],
|
||||
multimask_output: bool = True,
|
||||
) -> MaskData:
|
||||
orig_h, orig_w = orig_size
|
||||
|
||||
@ -279,7 +281,7 @@ class SamAutomaticMaskGenerator:
|
||||
masks, iou_preds, _ = self.predictor.predict_torch(
|
||||
in_points[:, None, :],
|
||||
in_labels[:, None],
|
||||
multimask_output=True,
|
||||
multimask_output=multimask_output,
|
||||
return_logits=True,
|
||||
)
|
||||
|
||||
|
@ -8,7 +8,7 @@ import torch
|
||||
|
||||
from functools import partial
|
||||
|
||||
from .modeling import ImageEncoderViT, MaskDecoder, PromptEncoder, Sam, TwoWayTransformer
|
||||
from .modeling import ImageEncoderViT, MaskDecoderHQ, PromptEncoder, Sam, TwoWayTransformer
|
||||
|
||||
|
||||
def build_sam_vit_h(checkpoint=None):
|
||||
@ -84,7 +84,7 @@ def _build_sam(
|
||||
input_image_size=(image_size, image_size),
|
||||
mask_in_chans=16,
|
||||
),
|
||||
mask_decoder=MaskDecoder(
|
||||
mask_decoder=MaskDecoderHQ(
|
||||
num_multimask_outputs=3,
|
||||
transformer=TwoWayTransformer(
|
||||
depth=2,
|
||||
@ -95,13 +95,19 @@ def _build_sam(
|
||||
transformer_dim=prompt_embed_dim,
|
||||
iou_head_depth=3,
|
||||
iou_head_hidden_dim=256,
|
||||
vit_dim=encoder_embed_dim,
|
||||
),
|
||||
pixel_mean=[123.675, 116.28, 103.53],
|
||||
pixel_std=[58.395, 57.12, 57.375],
|
||||
)
|
||||
sam.eval()
|
||||
# sam.eval()
|
||||
if checkpoint is not None:
|
||||
with open(checkpoint, "rb") as f:
|
||||
state_dict = torch.load(f)
|
||||
sam.load_state_dict(state_dict)
|
||||
info = sam.load_state_dict(state_dict, strict=False)
|
||||
print(info)
|
||||
for n, p in sam.named_parameters():
|
||||
if 'hf_token' not in n and 'hf_mlp' not in n and 'compress_vit_feat' not in n and 'embedding_encoder' not in n and 'embedding_maskfeature' not in n:
|
||||
p.requires_grad = False
|
||||
|
||||
return sam
|
||||
|
107
segment_anything/build_sam_baseline.py
Normal file
107
segment_anything/build_sam_baseline.py
Normal file
@ -0,0 +1,107 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import torch
|
||||
|
||||
from functools import partial
|
||||
|
||||
from .modeling import ImageEncoderViT, MaskDecoder, PromptEncoder, Sam, TwoWayTransformer
|
||||
|
||||
|
||||
def build_sam_vit_h(checkpoint=None):
|
||||
return _build_sam(
|
||||
encoder_embed_dim=1280,
|
||||
encoder_depth=32,
|
||||
encoder_num_heads=16,
|
||||
encoder_global_attn_indexes=[7, 15, 23, 31],
|
||||
checkpoint=checkpoint,
|
||||
)
|
||||
|
||||
|
||||
build_sam = build_sam_vit_h
|
||||
|
||||
|
||||
def build_sam_vit_l(checkpoint=None):
|
||||
return _build_sam(
|
||||
encoder_embed_dim=1024,
|
||||
encoder_depth=24,
|
||||
encoder_num_heads=16,
|
||||
encoder_global_attn_indexes=[5, 11, 17, 23],
|
||||
checkpoint=checkpoint,
|
||||
)
|
||||
|
||||
|
||||
def build_sam_vit_b(checkpoint=None):
|
||||
return _build_sam(
|
||||
encoder_embed_dim=768,
|
||||
encoder_depth=12,
|
||||
encoder_num_heads=12,
|
||||
encoder_global_attn_indexes=[2, 5, 8, 11],
|
||||
checkpoint=checkpoint,
|
||||
)
|
||||
|
||||
|
||||
sam_model_registry_baseline = {
|
||||
"default": build_sam_vit_h,
|
||||
"vit_h": build_sam_vit_h,
|
||||
"vit_l": build_sam_vit_l,
|
||||
"vit_b": build_sam_vit_b,
|
||||
}
|
||||
|
||||
|
||||
def _build_sam(
|
||||
encoder_embed_dim,
|
||||
encoder_depth,
|
||||
encoder_num_heads,
|
||||
encoder_global_attn_indexes,
|
||||
checkpoint=None,
|
||||
):
|
||||
prompt_embed_dim = 256
|
||||
image_size = 1024
|
||||
vit_patch_size = 16
|
||||
image_embedding_size = image_size // vit_patch_size
|
||||
sam = Sam(
|
||||
image_encoder=ImageEncoderViT(
|
||||
depth=encoder_depth,
|
||||
embed_dim=encoder_embed_dim,
|
||||
img_size=image_size,
|
||||
mlp_ratio=4,
|
||||
norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
|
||||
num_heads=encoder_num_heads,
|
||||
patch_size=vit_patch_size,
|
||||
qkv_bias=True,
|
||||
use_rel_pos=True,
|
||||
global_attn_indexes=encoder_global_attn_indexes,
|
||||
window_size=14,
|
||||
out_chans=prompt_embed_dim,
|
||||
),
|
||||
prompt_encoder=PromptEncoder(
|
||||
embed_dim=prompt_embed_dim,
|
||||
image_embedding_size=(image_embedding_size, image_embedding_size),
|
||||
input_image_size=(image_size, image_size),
|
||||
mask_in_chans=16,
|
||||
),
|
||||
mask_decoder=MaskDecoder(
|
||||
num_multimask_outputs=3,
|
||||
transformer=TwoWayTransformer(
|
||||
depth=2,
|
||||
embedding_dim=prompt_embed_dim,
|
||||
mlp_dim=2048,
|
||||
num_heads=8,
|
||||
),
|
||||
transformer_dim=prompt_embed_dim,
|
||||
iou_head_depth=3,
|
||||
iou_head_hidden_dim=256,
|
||||
),
|
||||
pixel_mean=[123.675, 116.28, 103.53],
|
||||
pixel_std=[58.395, 57.12, 57.375],
|
||||
)
|
||||
sam.eval()
|
||||
if checkpoint is not None:
|
||||
with open(checkpoint, "rb") as f:
|
||||
state_dict = torch.load(f)
|
||||
sam.load_state_dict(state_dict)
|
||||
return sam
|
@ -6,6 +6,7 @@
|
||||
|
||||
from .sam import Sam
|
||||
from .image_encoder import ImageEncoderViT
|
||||
from .mask_decoder_hq import MaskDecoderHQ
|
||||
from .mask_decoder import MaskDecoder
|
||||
from .prompt_encoder import PromptEncoder
|
||||
from .transformer import TwoWayTransformer
|
||||
|
@ -108,12 +108,15 @@ class ImageEncoderViT(nn.Module):
|
||||
if self.pos_embed is not None:
|
||||
x = x + self.pos_embed
|
||||
|
||||
interm_embeddings=[]
|
||||
for blk in self.blocks:
|
||||
x = blk(x)
|
||||
if blk.window_size == 0:
|
||||
interm_embeddings.append(x)
|
||||
|
||||
x = self.neck(x.permute(0, 3, 1, 2))
|
||||
|
||||
return x
|
||||
return x, interm_embeddings
|
||||
|
||||
|
||||
class Block(nn.Module):
|
||||
|
@ -75,6 +75,8 @@ class MaskDecoder(nn.Module):
|
||||
sparse_prompt_embeddings: torch.Tensor,
|
||||
dense_prompt_embeddings: torch.Tensor,
|
||||
multimask_output: bool,
|
||||
hq_token_only: bool,
|
||||
interm_embeddings: torch.Tensor,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Predict masks given image and prompt embeddings.
|
||||
|
232
segment_anything/modeling/mask_decoder_hq.py
Normal file
232
segment_anything/modeling/mask_decoder_hq.py
Normal file
@ -0,0 +1,232 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# Modified by HQ-SAM team
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
from typing import List, Tuple, Type
|
||||
|
||||
from .common import LayerNorm2d
|
||||
|
||||
|
||||
class MaskDecoderHQ(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
transformer_dim: int,
|
||||
transformer: nn.Module,
|
||||
num_multimask_outputs: int = 3,
|
||||
activation: Type[nn.Module] = nn.GELU,
|
||||
iou_head_depth: int = 3,
|
||||
iou_head_hidden_dim: int = 256,
|
||||
vit_dim: int = 1024,
|
||||
) -> None:
|
||||
"""
|
||||
Predicts masks given an image and prompt embeddings, using a
|
||||
transformer architecture.
|
||||
|
||||
Arguments:
|
||||
transformer_dim (int): the channel dimension of the transformer
|
||||
transformer (nn.Module): the transformer used to predict masks
|
||||
num_multimask_outputs (int): the number of masks to predict
|
||||
when disambiguating masks
|
||||
activation (nn.Module): the type of activation to use when
|
||||
upscaling masks
|
||||
iou_head_depth (int): the depth of the MLP used to predict
|
||||
mask quality
|
||||
iou_head_hidden_dim (int): the hidden dimension of the MLP
|
||||
used to predict mask quality
|
||||
"""
|
||||
super().__init__()
|
||||
self.transformer_dim = transformer_dim
|
||||
self.transformer = transformer
|
||||
|
||||
self.num_multimask_outputs = num_multimask_outputs
|
||||
|
||||
self.iou_token = nn.Embedding(1, transformer_dim)
|
||||
self.num_mask_tokens = num_multimask_outputs + 1
|
||||
self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim)
|
||||
|
||||
self.output_upscaling = nn.Sequential(
|
||||
nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2),
|
||||
LayerNorm2d(transformer_dim // 4),
|
||||
activation(),
|
||||
nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2),
|
||||
activation(),
|
||||
)
|
||||
self.output_hypernetworks_mlps = nn.ModuleList(
|
||||
[
|
||||
MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3)
|
||||
for i in range(self.num_mask_tokens)
|
||||
]
|
||||
)
|
||||
|
||||
self.iou_prediction_head = MLP(
|
||||
transformer_dim, iou_head_hidden_dim, self.num_mask_tokens, iou_head_depth
|
||||
)
|
||||
|
||||
# HQ-SAM parameters
|
||||
self.hf_token = nn.Embedding(1, transformer_dim) # HQ-Ouptput-Token
|
||||
self.hf_mlp = MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3) # corresponding new MLP layer for HQ-Ouptput-Token
|
||||
self.num_mask_tokens = self.num_mask_tokens + 1
|
||||
|
||||
# three conv fusion layers for obtaining HQ-Feature
|
||||
self.compress_vit_feat = nn.Sequential(
|
||||
nn.ConvTranspose2d(vit_dim, transformer_dim, kernel_size=2, stride=2),
|
||||
LayerNorm2d(transformer_dim),
|
||||
nn.GELU(),
|
||||
nn.ConvTranspose2d(transformer_dim, transformer_dim // 8, kernel_size=2, stride=2))
|
||||
|
||||
self.embedding_encoder = nn.Sequential(
|
||||
nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2),
|
||||
LayerNorm2d(transformer_dim // 4),
|
||||
nn.GELU(),
|
||||
nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2),
|
||||
)
|
||||
self.embedding_maskfeature = nn.Sequential(
|
||||
nn.Conv2d(transformer_dim // 8, transformer_dim // 4, 3, 1, 1),
|
||||
LayerNorm2d(transformer_dim // 4),
|
||||
nn.GELU(),
|
||||
nn.Conv2d(transformer_dim // 4, transformer_dim // 8, 3, 1, 1))
|
||||
|
||||
|
||||
|
||||
def forward(
|
||||
self,
|
||||
image_embeddings: torch.Tensor,
|
||||
image_pe: torch.Tensor,
|
||||
sparse_prompt_embeddings: torch.Tensor,
|
||||
dense_prompt_embeddings: torch.Tensor,
|
||||
multimask_output: bool,
|
||||
hq_token_only: bool,
|
||||
interm_embeddings: torch.Tensor,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Predict masks given image and prompt embeddings.
|
||||
|
||||
Arguments:
|
||||
image_embeddings (torch.Tensor): the embeddings from the ViT image encoder
|
||||
image_pe (torch.Tensor): positional encoding with the shape of image_embeddings
|
||||
sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes
|
||||
dense_prompt_embeddings (torch.Tensor): the embeddings of the mask inputs
|
||||
multimask_output (bool): Whether to return multiple masks or a single
|
||||
mask.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: batched predicted masks
|
||||
torch.Tensor: batched predictions of mask quality
|
||||
"""
|
||||
vit_features = interm_embeddings[0].permute(0, 3, 1, 2) # early-layer ViT feature, after 1st global attention block in ViT
|
||||
hq_features = self.embedding_encoder(image_embeddings) + self.compress_vit_feat(vit_features)
|
||||
|
||||
masks, iou_pred = self.predict_masks(
|
||||
image_embeddings=image_embeddings,
|
||||
image_pe=image_pe,
|
||||
sparse_prompt_embeddings=sparse_prompt_embeddings,
|
||||
dense_prompt_embeddings=dense_prompt_embeddings,
|
||||
hq_features=hq_features,
|
||||
)
|
||||
|
||||
# Select the correct mask or masks for output
|
||||
if multimask_output:
|
||||
# mask with highest score
|
||||
mask_slice = slice(1,self.num_mask_tokens-1)
|
||||
iou_pred = iou_pred[:, mask_slice]
|
||||
iou_pred, max_iou_idx = torch.max(iou_pred,dim=1)
|
||||
iou_pred = iou_pred.unsqueeze(1)
|
||||
masks_multi = masks[:, mask_slice, :, :]
|
||||
masks_sam = masks_multi[torch.arange(masks_multi.size(0)),max_iou_idx].unsqueeze(1)
|
||||
else:
|
||||
# singale mask output, default
|
||||
mask_slice = slice(0, 1)
|
||||
iou_pred = iou_pred[:,mask_slice]
|
||||
masks_sam = masks[:,mask_slice]
|
||||
|
||||
masks_hq = masks[:,slice(self.num_mask_tokens-1, self.num_mask_tokens)]
|
||||
if hq_token_only:
|
||||
masks = masks_hq
|
||||
else:
|
||||
masks = masks_sam + masks_hq
|
||||
# Prepare output
|
||||
return masks, iou_pred
|
||||
|
||||
def predict_masks(
|
||||
self,
|
||||
image_embeddings: torch.Tensor,
|
||||
image_pe: torch.Tensor,
|
||||
sparse_prompt_embeddings: torch.Tensor,
|
||||
dense_prompt_embeddings: torch.Tensor,
|
||||
hq_features: torch.Tensor,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Predicts masks. See 'forward' for more details."""
|
||||
# Concatenate output tokens
|
||||
output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight, self.hf_token.weight], dim=0)
|
||||
output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1)
|
||||
tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1)
|
||||
|
||||
# Expand per-image data in batch direction to be per-mask
|
||||
src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)
|
||||
src = src + dense_prompt_embeddings
|
||||
pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0)
|
||||
b, c, h, w = src.shape
|
||||
|
||||
# Run the transformer
|
||||
hs, src = self.transformer(src, pos_src, tokens)
|
||||
iou_token_out = hs[:, 0, :]
|
||||
mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :]
|
||||
|
||||
# Upscale mask embeddings and predict masks using the mask tokens
|
||||
src = src.transpose(1, 2).view(b, c, h, w)
|
||||
|
||||
upscaled_embedding_sam = self.output_upscaling(src)
|
||||
upscaled_embedding_hq = self.embedding_maskfeature(upscaled_embedding_sam) + hq_features.repeat(b,1,1,1)
|
||||
|
||||
hyper_in_list: List[torch.Tensor] = []
|
||||
for i in range(self.num_mask_tokens):
|
||||
if i < self.num_mask_tokens - 1:
|
||||
hyper_in_list.append(self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :]))
|
||||
else:
|
||||
hyper_in_list.append(self.hf_mlp(mask_tokens_out[:, i, :]))
|
||||
|
||||
hyper_in = torch.stack(hyper_in_list, dim=1)
|
||||
b, c, h, w = upscaled_embedding_sam.shape
|
||||
|
||||
masks_sam = (hyper_in[:,:self.num_mask_tokens-1] @ upscaled_embedding_sam.view(b, c, h * w)).view(b, -1, h, w)
|
||||
masks_sam_hq = (hyper_in[:,self.num_mask_tokens-1:] @ upscaled_embedding_hq.view(b, c, h * w)).view(b, -1, h, w)
|
||||
masks = torch.cat([masks_sam,masks_sam_hq],dim=1)
|
||||
# Generate mask quality predictions
|
||||
iou_pred = self.iou_prediction_head(iou_token_out)
|
||||
|
||||
return masks, iou_pred
|
||||
|
||||
|
||||
# Lightly adapted from
|
||||
# https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py # noqa
|
||||
class MLP(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
input_dim: int,
|
||||
hidden_dim: int,
|
||||
output_dim: int,
|
||||
num_layers: int,
|
||||
sigmoid_output: bool = False,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.num_layers = num_layers
|
||||
h = [hidden_dim] * (num_layers - 1)
|
||||
self.layers = nn.ModuleList(
|
||||
nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])
|
||||
)
|
||||
self.sigmoid_output = sigmoid_output
|
||||
|
||||
def forward(self, x):
|
||||
for i, layer in enumerate(self.layers):
|
||||
x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
|
||||
if self.sigmoid_output:
|
||||
x = F.sigmoid(x)
|
||||
return x
|
@ -50,11 +50,11 @@ class Sam(nn.Module):
|
||||
def device(self) -> Any:
|
||||
return self.pixel_mean.device
|
||||
|
||||
@torch.no_grad()
|
||||
def forward(
|
||||
self,
|
||||
batched_input: List[Dict[str, Any]],
|
||||
multimask_output: bool,
|
||||
hq_token_only: bool =False,
|
||||
) -> List[Dict[str, torch.Tensor]]:
|
||||
"""
|
||||
Predicts masks end-to-end from provided images and prompts.
|
||||
@ -95,10 +95,11 @@ class Sam(nn.Module):
|
||||
to subsequent iterations of prediction.
|
||||
"""
|
||||
input_images = torch.stack([self.preprocess(x["image"]) for x in batched_input], dim=0)
|
||||
image_embeddings = self.image_encoder(input_images)
|
||||
image_embeddings, interm_embeddings = self.image_encoder(input_images)
|
||||
interm_embeddings = interm_embeddings[0] # early layer
|
||||
|
||||
outputs = []
|
||||
for image_record, curr_embedding in zip(batched_input, image_embeddings):
|
||||
for image_record, curr_embedding, curr_interm in zip(batched_input, image_embeddings, interm_embeddings):
|
||||
if "point_coords" in image_record:
|
||||
points = (image_record["point_coords"], image_record["point_labels"])
|
||||
else:
|
||||
@ -114,6 +115,8 @@ class Sam(nn.Module):
|
||||
sparse_prompt_embeddings=sparse_embeddings,
|
||||
dense_prompt_embeddings=dense_embeddings,
|
||||
multimask_output=multimask_output,
|
||||
hq_token_only=hq_token_only,
|
||||
interm_embeddings=curr_interm.unsqueeze(0).unsqueeze(0),
|
||||
)
|
||||
masks = self.postprocess_masks(
|
||||
low_res_masks,
|
||||
|
@ -7,7 +7,7 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from segment_anything.modeling import Sam
|
||||
from .modeling import Sam
|
||||
|
||||
from typing import Optional, Tuple
|
||||
|
||||
@ -49,10 +49,12 @@ class SamPredictor:
|
||||
"RGB",
|
||||
"BGR",
|
||||
], f"image_format must be in ['RGB', 'BGR'], is {image_format}."
|
||||
# import pdb;pdb.set_trace()
|
||||
if image_format != self.model.image_format:
|
||||
image = image[..., ::-1]
|
||||
|
||||
# Transform the image to the form expected by the model
|
||||
# import pdb;pdb.set_trace()
|
||||
input_image = self.transform.apply_image(image)
|
||||
input_image_torch = torch.as_tensor(input_image, device=self.device)
|
||||
input_image_torch = input_image_torch.permute(2, 0, 1).contiguous()[None, :, :, :]
|
||||
@ -86,7 +88,7 @@ class SamPredictor:
|
||||
self.original_size = original_image_size
|
||||
self.input_size = tuple(transformed_image.shape[-2:])
|
||||
input_image = self.model.preprocess(transformed_image)
|
||||
self.features = self.model.image_encoder(input_image)
|
||||
self.features, self.interm_features = self.model.image_encoder(input_image)
|
||||
self.is_image_set = True
|
||||
|
||||
def predict(
|
||||
@ -97,6 +99,7 @@ class SamPredictor:
|
||||
mask_input: Optional[np.ndarray] = None,
|
||||
multimask_output: bool = True,
|
||||
return_logits: bool = False,
|
||||
hq_token_only: bool =False,
|
||||
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
|
||||
"""
|
||||
Predict masks for the given input prompts, using the currently set image.
|
||||
@ -158,6 +161,7 @@ class SamPredictor:
|
||||
mask_input_torch,
|
||||
multimask_output,
|
||||
return_logits=return_logits,
|
||||
hq_token_only=hq_token_only,
|
||||
)
|
||||
|
||||
masks_np = masks[0].detach().cpu().numpy()
|
||||
@ -174,6 +178,7 @@ class SamPredictor:
|
||||
mask_input: Optional[torch.Tensor] = None,
|
||||
multimask_output: bool = True,
|
||||
return_logits: bool = False,
|
||||
hq_token_only: bool =False,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Predict masks for the given input prompts, using the currently set image.
|
||||
@ -232,6 +237,8 @@ class SamPredictor:
|
||||
sparse_prompt_embeddings=sparse_embeddings,
|
||||
dense_prompt_embeddings=dense_embeddings,
|
||||
multimask_output=multimask_output,
|
||||
hq_token_only=hq_token_only,
|
||||
interm_embeddings=self.interm_features,
|
||||
)
|
||||
|
||||
# Upscale the masks to the original image resolution
|
||||
|
@ -25,7 +25,8 @@ class SamOnnxModel(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
model: Sam,
|
||||
return_single_mask: bool,
|
||||
hq_token_only: bool = False,
|
||||
multimask_output: bool = False,
|
||||
use_stability_score: bool = False,
|
||||
return_extra_metrics: bool = False,
|
||||
) -> None:
|
||||
@ -33,7 +34,8 @@ class SamOnnxModel(nn.Module):
|
||||
self.mask_decoder = model.mask_decoder
|
||||
self.model = model
|
||||
self.img_size = model.image_encoder.img_size
|
||||
self.return_single_mask = return_single_mask
|
||||
self.hq_token_only = hq_token_only
|
||||
self.multimask_output = multimask_output
|
||||
self.use_stability_score = use_stability_score
|
||||
self.stability_score_offset = 1.0
|
||||
self.return_extra_metrics = return_extra_metrics
|
||||
@ -89,25 +91,12 @@ class SamOnnxModel(nn.Module):
|
||||
masks = F.interpolate(masks, size=(h, w), mode="bilinear", align_corners=False)
|
||||
return masks
|
||||
|
||||
def select_masks(
|
||||
self, masks: torch.Tensor, iou_preds: torch.Tensor, num_points: int
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
# Determine if we should return the multiclick mask or not from the number of points.
|
||||
# The reweighting is used to avoid control flow.
|
||||
score_reweight = torch.tensor(
|
||||
[[1000] + [0] * (self.model.mask_decoder.num_mask_tokens - 1)]
|
||||
).to(iou_preds.device)
|
||||
score = iou_preds + (num_points - 2.5) * score_reweight
|
||||
best_idx = torch.argmax(score, dim=1)
|
||||
masks = masks[torch.arange(masks.shape[0]), best_idx, :, :].unsqueeze(1)
|
||||
iou_preds = iou_preds[torch.arange(masks.shape[0]), best_idx].unsqueeze(1)
|
||||
|
||||
return masks, iou_preds
|
||||
|
||||
@torch.no_grad()
|
||||
def forward(
|
||||
self,
|
||||
image_embeddings: torch.Tensor,
|
||||
interm_embeddings: torch.Tensor,
|
||||
point_coords: torch.Tensor,
|
||||
point_labels: torch.Tensor,
|
||||
mask_input: torch.Tensor,
|
||||
@ -117,11 +106,15 @@ class SamOnnxModel(nn.Module):
|
||||
sparse_embedding = self._embed_points(point_coords, point_labels)
|
||||
dense_embedding = self._embed_masks(mask_input, has_mask_input)
|
||||
|
||||
vit_features = interm_embeddings[0].permute(0, 3, 1, 2) # early-layer ViT feature, after 1st global attention block in ViT
|
||||
hq_features = self.model.mask_decoder.embedding_encoder(image_embeddings) + self.model.mask_decoder.compress_vit_feat(vit_features)
|
||||
|
||||
masks, scores = self.model.mask_decoder.predict_masks(
|
||||
image_embeddings=image_embeddings,
|
||||
image_pe=self.model.prompt_encoder.get_dense_pe(),
|
||||
sparse_prompt_embeddings=sparse_embedding,
|
||||
dense_prompt_embeddings=dense_embedding,
|
||||
hq_features=hq_features,
|
||||
)
|
||||
|
||||
if self.use_stability_score:
|
||||
@ -129,8 +122,26 @@ class SamOnnxModel(nn.Module):
|
||||
masks, self.model.mask_threshold, self.stability_score_offset
|
||||
)
|
||||
|
||||
if self.return_single_mask:
|
||||
masks, scores = self.select_masks(masks, scores, point_coords.shape[1])
|
||||
if self.multimask_output:
|
||||
# mask with highest score
|
||||
mask_slice = slice(1,self.model.mask_decoder.num_mask_tokens-1)
|
||||
scores = scores[:, mask_slice]
|
||||
scores, max_iou_idx = torch.max(scores,dim=1)
|
||||
scores = scores.unsqueeze(1)
|
||||
masks_multi = masks[:, mask_slice, :, :]
|
||||
masks_sam = masks_multi[torch.arange(masks_multi.size(0)),max_iou_idx].unsqueeze(1)
|
||||
else:
|
||||
# singale mask output, default
|
||||
mask_slice = slice(0, 1)
|
||||
scores = scores[:,mask_slice]
|
||||
masks_sam = masks[:,mask_slice]
|
||||
|
||||
masks_hq = masks[:,slice(self.model.mask_decoder.num_mask_tokens-1, self.model.mask_decoder.num_mask_tokens)]
|
||||
|
||||
if self.hq_token_only:
|
||||
masks = masks_hq
|
||||
else:
|
||||
masks = masks_sam + masks_hq
|
||||
|
||||
upscaled_masks = self.mask_postprocessing(masks, orig_im_size)
|
||||
|
||||
|
@ -88,6 +88,8 @@ class Ui_MainWindow(object):
|
||||
self.menuMode.setObjectName("menuMode")
|
||||
self.menuContour_mode = QtWidgets.QMenu(self.menuMode)
|
||||
self.menuContour_mode.setObjectName("menuContour_mode")
|
||||
self.menuSAM_model = QtWidgets.QMenu(self.menubar)
|
||||
self.menuSAM_model.setObjectName("menuSAM_model")
|
||||
MainWindow.setMenuBar(self.menubar)
|
||||
self.statusbar = QtWidgets.QStatusBar(MainWindow)
|
||||
self.statusbar.setLayoutDirection(QtCore.Qt.LeftToRight)
|
||||
@ -352,6 +354,7 @@ class Ui_MainWindow(object):
|
||||
self.menubar.addAction(self.menuFile.menuAction())
|
||||
self.menubar.addAction(self.menuEdit.menuAction())
|
||||
self.menubar.addAction(self.menuView.menuAction())
|
||||
self.menubar.addAction(self.menuSAM_model.menuAction())
|
||||
self.menubar.addAction(self.menuMode.menuAction())
|
||||
self.menubar.addAction(self.menuTools.menuAction())
|
||||
self.menubar.addAction(self.menuAbout.menuAction())
|
||||
@ -391,6 +394,7 @@ class Ui_MainWindow(object):
|
||||
self.menuEdit.setTitle(_translate("MainWindow", "Edit"))
|
||||
self.menuMode.setTitle(_translate("MainWindow", "Mode"))
|
||||
self.menuContour_mode.setTitle(_translate("MainWindow", "Contour mode"))
|
||||
self.menuSAM_model.setTitle(_translate("MainWindow", "SAM"))
|
||||
self.toolBar.setWindowTitle(_translate("MainWindow", "toolBar"))
|
||||
self.info_dock.setWindowTitle(_translate("MainWindow", "Info"))
|
||||
self.annos_dock.setWindowTitle(_translate("MainWindow", "Annos"))
|
||||
|
@ -202,9 +202,15 @@
|
||||
</widget>
|
||||
<addaction name="menuContour_mode"/>
|
||||
</widget>
|
||||
<widget class="QMenu" name="menuSAM_model">
|
||||
<property name="title">
|
||||
<string>SAM</string>
|
||||
</property>
|
||||
</widget>
|
||||
<addaction name="menuFile"/>
|
||||
<addaction name="menuEdit"/>
|
||||
<addaction name="menuView"/>
|
||||
<addaction name="menuSAM_model"/>
|
||||
<addaction name="menuMode"/>
|
||||
<addaction name="menuTools"/>
|
||||
<addaction name="menuAbout"/>
|
||||
|
@ -34,8 +34,6 @@ class MainWindow(QtWidgets.QMainWindow, Ui_MainWindow):
|
||||
def __init__(self):
|
||||
super(MainWindow, self).__init__()
|
||||
self.setupUi(self)
|
||||
self.init_ui()
|
||||
self.init_segment_anything()
|
||||
|
||||
self.image_root: str = None
|
||||
self.label_root:str = None
|
||||
@ -58,41 +56,51 @@ class MainWindow(QtWidgets.QMainWindow, Ui_MainWindow):
|
||||
self.map_mode = MAPMode.LABEL
|
||||
# 标注目标
|
||||
self.current_label:Annotation = None
|
||||
self.use_segment_anything = False
|
||||
|
||||
self.init_ui()
|
||||
self.reload_cfg()
|
||||
|
||||
self.init_connect()
|
||||
self.reset_action()
|
||||
|
||||
def init_segment_anything(self):
|
||||
if os.path.exists('./segment_any/sam_vit_h_4b8939.pth'):
|
||||
self.statusbar.showMessage('Find the checkpoint named {}.'.format('sam_vit_h_4b8939.pth'))
|
||||
self.segany = SegAny('./segment_any/sam_vit_h_4b8939.pth')
|
||||
self.use_segment_anything = True
|
||||
elif os.path.exists('./segment_any/sam_vit_l_0b3195.pth'):
|
||||
self.statusbar.showMessage('Find the checkpoint named {}.'.format('sam_vit_l_0b3195.pth'))
|
||||
self.segany = SegAny('./segment_any/sam_vit_l_0b3195.pth')
|
||||
self.use_segment_anything = True
|
||||
elif os.path.exists('./segment_any/sam_vit_b_01ec64.pth'):
|
||||
self.statusbar.showMessage('Find the checkpoint named {}.'.format('sam_vit_b_01ec64.pth'))
|
||||
self.segany = SegAny('./segment_any/sam_vit_b_01ec64.pth')
|
||||
self.use_segment_anything = True
|
||||
else:
|
||||
QtWidgets.QMessageBox.warning(self, 'Warning', 'The checkpoint of [Segment anything] not existed. If you want use quick annotate, please download from {}'.format('https://github.com/facebookresearch/segment-anything#model-checkpoints'))
|
||||
def init_segment_anything(self, model_name, reload=False):
|
||||
if model_name == '':
|
||||
self.use_segment_anything = False
|
||||
for name, action in self.pths_actions.items():
|
||||
action.setChecked(model_name == name)
|
||||
return
|
||||
model_path = os.path.join('segment_any', model_name)
|
||||
if not os.path.exists(model_path):
|
||||
QtWidgets.QMessageBox.warning(self, 'Warning',
|
||||
'The checkpoint of [Segment anything] not existed. If you want use quick annotate, please download from {}'.format(
|
||||
'https://github.com/facebookresearch/segment-anything#model-checkpoints'))
|
||||
for name, action in self.pths_actions.items():
|
||||
action.setChecked(model_name == name)
|
||||
self.use_segment_anything = False
|
||||
return
|
||||
|
||||
if self.use_segment_anything:
|
||||
if self.segany.device != 'cpu':
|
||||
self.gpu_resource_thread = GPUResource_Thread()
|
||||
self.gpu_resource_thread.message.connect(self.labelGPUResource.setText)
|
||||
self.gpu_resource_thread.start()
|
||||
self.segany = SegAny(model_path)
|
||||
self.use_segment_anything = True
|
||||
self.statusbar.showMessage('Use the checkpoint named {}.'.format(model_name), 3000)
|
||||
for name, action in self.pths_actions.items():
|
||||
action.setChecked(model_name==name)
|
||||
if not reload:
|
||||
if self.use_segment_anything:
|
||||
if self.segany.device != 'cpu':
|
||||
self.gpu_resource_thread = GPUResource_Thread()
|
||||
self.gpu_resource_thread.message.connect(self.labelGPUResource.setText)
|
||||
self.gpu_resource_thread.start()
|
||||
else:
|
||||
self.labelGPUResource.setText('cpu')
|
||||
else:
|
||||
self.labelGPUResource.setText('cpu')
|
||||
else:
|
||||
self.labelGPUResource.setText('segment anything unused.')
|
||||
self.labelGPUResource.setText('segment anything unused.')
|
||||
|
||||
if reload and self.current_index is not None:
|
||||
self.show_image(self.current_index)
|
||||
|
||||
def init_ui(self):
|
||||
#
|
||||
#q
|
||||
self.setting_dialog = SettingDialog(parent=self, mainwindow=self)
|
||||
|
||||
self.categories_dock_widget = CategoriesDockWidget(mainwindow=self)
|
||||
@ -144,6 +152,17 @@ class MainWindow(QtWidgets.QMainWindow, Ui_MainWindow):
|
||||
self.statusbar.addPermanentWidget(self.labelData)
|
||||
|
||||
#
|
||||
model_names = sorted([pth for pth in os.listdir('segment_any') if pth.endswith('.pth')])
|
||||
self.pths_actions = {}
|
||||
for model_name in model_names:
|
||||
action = QtWidgets.QAction(self)
|
||||
action.setObjectName("actionZoom_in")
|
||||
action.triggered.connect(functools.partial(self.init_segment_anything, model_name))
|
||||
action.setText("{}".format(model_name))
|
||||
action.setCheckable(True)
|
||||
|
||||
self.pths_actions[model_name] = action
|
||||
self.menuSAM_model.addAction(action)
|
||||
|
||||
self.toolBar.addSeparator()
|
||||
self.mask_aplha = QtWidgets.QSlider(QtCore.Qt.Orientation.Horizontal, self)
|
||||
@ -213,6 +232,9 @@ class MainWindow(QtWidgets.QMainWindow, Ui_MainWindow):
|
||||
self.cfg['mask_alpha'] = mask_alpha
|
||||
self.mask_aplha.setValue(mask_alpha*10)
|
||||
|
||||
model_name = self.cfg.get('model_name', '')
|
||||
self.init_segment_anything(model_name)
|
||||
|
||||
self.categories_dock_widget.update_widget()
|
||||
|
||||
def set_saved_state(self, is_saved:bool):
|
||||
|
Loading…
x
Reference in New Issue
Block a user