From f7edb986d4f9b254c73e5fdb533d54cc3daa1a64 Mon Sep 17 00:00:00 2001
From: yatengLG <767624851@qq.com>
Date: Tue, 11 Jul 2023 13:50:17 +0800
Subject: [PATCH] SAM-HQ support, UI update, add model selection
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
example/images/000000000144.json | 805 +------------------
example/images/isat.yaml | 2 +-
isat.yaml | 2 +-
main.py | 2 +-
segment_any/segment_any.py | 1 +
segment_anything/__init__.py | 1 +
segment_anything/automatic_mask_generator.py | 14 +-
segment_anything/build_sam.py | 14 +-
segment_anything/build_sam_baseline.py | 107 +++
segment_anything/modeling/__init__.py | 1 +
segment_anything/modeling/image_encoder.py | 5 +-
segment_anything/modeling/mask_decoder.py | 2 +
segment_anything/modeling/mask_decoder_hq.py | 232 ++++++
segment_anything/modeling/sam.py | 9 +-
segment_anything/predictor.py | 11 +-
segment_anything/utils/onnx.py | 47 +-
ui/MainWindow.py | 4 +
ui/MainWindow.ui | 6 +
widgets/mainwindow.py | 74 +-
19 files changed, 473 insertions(+), 866 deletions(-)
create mode 100644 segment_anything/build_sam_baseline.py
create mode 100644 segment_anything/modeling/mask_decoder_hq.py
diff --git a/example/images/000000000144.json b/example/images/000000000144.json
index 043e62b..8acc9f0 100644
--- a/example/images/000000000144.json
+++ b/example/images/000000000144.json
@@ -1,813 +1,12 @@
{
"info": {
"description": "ISAT",
- "folder": "C:/Users/lg/PycharmProjects/ISAT_with_segment_anything/example/images",
+ "folder": "/home/super/PycharmProjects/ISAT_with_segment_anything/example/images",
"name": "000000000144.jpg",
"width": 640,
"height": 480,
"depth": 3,
"note": ""
},
- "objects": [
- {
- "category": "fence",
- "group": "1",
- "segmentation": [
- [
- 20,
- 239
- ],
- [
- 17,
- 240
- ],
- [
- 15,
- 241
- ],
- [
- 13,
- 242
- ],
- [
- 11,
- 244
- ],
- [
- 10,
- 244
- ],
- [
- 9,
- 245
- ],
- [
- 6,
- 249
- ],
- [
- 5,
- 251
- ],
- [
- 4,
- 255
- ],
- [
- 0,
- 261
- ],
- [
- 0,
- 280
- ],
- [
- 1,
- 277
- ],
- [
- 6,
- 268
- ],
- [
- 9,
- 259
- ],
- [
- 13,
- 254
- ],
- [
- 13,
- 252
- ],
- [
- 15,
- 250
- ],
- [
- 18,
- 248
- ],
- [
- 20,
- 247
- ],
- [
- 24,
- 247
- ],
- [
- 26,
- 248
- ],
- [
- 30,
- 252
- ],
- [
- 33,
- 256
- ],
- [
- 35,
- 260
- ],
- [
- 38,
- 266
- ],
- [
- 41,
- 273
- ],
- [
- 43,
- 280
- ],
- [
- 50,
- 294
- ],
- [
- 53,
- 297
- ],
- [
- 54,
- 297
- ],
- [
- 59,
- 301
- ],
- [
- 61,
- 303
- ],
- [
- 62,
- 316
- ],
- [
- 63,
- 341
- ],
- [
- 63,
- 382
- ],
- [
- 64,
- 405
- ],
- [
- 66,
- 414
- ],
- [
- 71,
- 413
- ],
- [
- 72,
- 402
- ],
- [
- 72,
- 399
- ],
- [
- 71,
- 359
- ],
- [
- 69,
- 324
- ],
- [
- 70,
- 300
- ],
- [
- 73,
- 299
- ],
- [
- 78,
- 294
- ],
- [
- 80,
- 291
- ],
- [
- 81,
- 289
- ],
- [
- 83,
- 283
- ],
- [
- 85,
- 278
- ],
- [
- 86,
- 272
- ],
- [
- 91,
- 258
- ],
- [
- 94,
- 252
- ],
- [
- 98,
- 248
- ],
- [
- 105,
- 248
- ],
- [
- 107,
- 249
- ],
- [
- 111,
- 253
- ],
- [
- 116,
- 261
- ],
- [
- 118,
- 265
- ],
- [
- 121,
- 271
- ],
- [
- 128,
- 288
- ],
- [
- 130,
- 291
- ],
- [
- 135,
- 296
- ],
- [
- 138,
- 298
- ],
- [
- 140,
- 299
- ],
- [
- 143,
- 300
- ],
- [
- 148,
- 300
- ],
- [
- 151,
- 299
- ],
- [
- 153,
- 298
- ],
- [
- 156,
- 296
- ],
- [
- 159,
- 293
- ],
- [
- 161,
- 290
- ],
- [
- 164,
- 284
- ],
- [
- 179,
- 255
- ],
- [
- 180,
- 253
- ],
- [
- 180,
- 251
- ],
- [
- 181,
- 249
- ],
- [
- 183,
- 247
- ],
- [
- 186,
- 248
- ],
- [
- 187,
- 259
- ],
- [
- 188,
- 282
- ],
- [
- 188,
- 303
- ],
- [
- 190,
- 320
- ],
- [
- 190,
- 343
- ],
- [
- 191,
- 355
- ],
- [
- 192,
- 357
- ],
- [
- 196,
- 358
- ],
- [
- 198,
- 357
- ],
- [
- 200,
- 353
- ],
- [
- 200,
- 341
- ],
- [
- 198,
- 321
- ],
- [
- 197,
- 300
- ],
- [
- 196,
- 279
- ],
- [
- 195,
- 252
- ],
- [
- 195,
- 250
- ],
- [
- 196,
- 248
- ],
- [
- 200,
- 248
- ],
- [
- 200,
- 249
- ],
- [
- 202,
- 251
- ],
- [
- 207,
- 260
- ],
- [
- 208,
- 262
- ],
- [
- 208,
- 264
- ],
- [
- 210,
- 266
- ],
- [
- 212,
- 273
- ],
- [
- 215,
- 279
- ],
- [
- 217,
- 281
- ],
- [
- 218,
- 283
- ],
- [
- 219,
- 287
- ],
- [
- 221,
- 290
- ],
- [
- 225,
- 294
- ],
- [
- 227,
- 295
- ],
- [
- 230,
- 296
- ],
- [
- 235,
- 296
- ],
- [
- 239,
- 295
- ],
- [
- 241,
- 294
- ],
- [
- 245,
- 290
- ],
- [
- 246,
- 288
- ],
- [
- 247,
- 286
- ],
- [
- 248,
- 283
- ],
- [
- 248,
- 279
- ],
- [
- 245,
- 281
- ],
- [
- 244,
- 283
- ],
- [
- 244,
- 285
- ],
- [
- 243,
- 287
- ],
- [
- 238,
- 292
- ],
- [
- 230,
- 292
- ],
- [
- 228,
- 291
- ],
- [
- 224,
- 287
- ],
- [
- 222,
- 284
- ],
- [
- 219,
- 278
- ],
- [
- 214,
- 268
- ],
- [
- 214,
- 266
- ],
- [
- 207,
- 252
- ],
- [
- 205,
- 249
- ],
- [
- 201,
- 245
- ],
- [
- 198,
- 243
- ],
- [
- 194,
- 242
- ],
- [
- 188,
- 242
- ],
- [
- 183,
- 243
- ],
- [
- 180,
- 244
- ],
- [
- 176,
- 246
- ],
- [
- 174,
- 249
- ],
- [
- 172,
- 252
- ],
- [
- 170,
- 256
- ],
- [
- 169,
- 260
- ],
- [
- 168,
- 262
- ],
- [
- 166,
- 264
- ],
- [
- 165,
- 266
- ],
- [
- 164,
- 268
- ],
- [
- 161,
- 276
- ],
- [
- 156,
- 286
- ],
- [
- 153,
- 290
- ],
- [
- 152,
- 291
- ],
- [
- 151,
- 291
- ],
- [
- 149,
- 293
- ],
- [
- 147,
- 293
- ],
- [
- 144,
- 293
- ],
- [
- 142,
- 293
- ],
- [
- 140,
- 292
- ],
- [
- 138,
- 290
- ],
- [
- 136,
- 287
- ],
- [
- 133,
- 281
- ],
- [
- 125,
- 263
- ],
- [
- 123,
- 259
- ],
- [
- 120,
- 253
- ],
- [
- 117,
- 248
- ],
- [
- 112,
- 243
- ],
- [
- 110,
- 242
- ],
- [
- 107,
- 241
- ],
- [
- 98,
- 241
- ],
- [
- 96,
- 242
- ],
- [
- 94,
- 243
- ],
- [
- 89,
- 248
- ],
- [
- 85,
- 256
- ],
- [
- 83,
- 261
- ],
- [
- 81,
- 269
- ],
- [
- 78,
- 274
- ],
- [
- 76,
- 282
- ],
- [
- 74,
- 287
- ],
- [
- 73,
- 289
- ],
- [
- 67,
- 295
- ],
- [
- 63,
- 295
- ],
- [
- 61,
- 294
- ],
- [
- 55,
- 288
- ],
- [
- 53,
- 284
- ],
- [
- 50,
- 278
- ],
- [
- 49,
- 273
- ],
- [
- 41,
- 254
- ],
- [
- 40,
- 252
- ],
- [
- 38,
- 249
- ],
- [
- 35,
- 246
- ],
- [
- 35,
- 245
- ],
- [
- 33,
- 243
- ],
- [
- 30,
- 241
- ],
- [
- 28,
- 240
- ],
- [
- 24,
- 239
- ]
- ],
- "area": 4485.0,
- "layer": 1.0,
- "bbox": [
- 0.0,
- 239.0,
- 248.0,
- 414.0
- ],
- "iscrowd": 0,
- "note": ""
- }
- ]
+ "objects": []
}
\ No newline at end of file
diff --git a/example/images/isat.yaml b/example/images/isat.yaml
index cf05e28..bec54bb 100644
--- a/example/images/isat.yaml
+++ b/example/images/isat.yaml
@@ -1,4 +1,4 @@
-contour_mode: external
+contour_mode: all
label:
- color: '#000000'
name: __background__
diff --git a/isat.yaml b/isat.yaml
index cf05e28..bec54bb 100644
--- a/isat.yaml
+++ b/isat.yaml
@@ -1,4 +1,4 @@
-contour_mode: external
+contour_mode: all
label:
- color: '#000000'
name: __background__
diff --git a/main.py b/main.py
index 6791910..cea8269 100644
--- a/main.py
+++ b/main.py
@@ -12,5 +12,5 @@ if __name__ == '__main__':
app = QtWidgets.QApplication([''])
mainwindow = MainWindow()
mainwindow.show()
- sys.exit(app.exec_())
+ sys.exit(app.exec())
diff --git a/segment_any/segment_any.py b/segment_any/segment_any.py
index ff80abe..9f8a2f5 100644
--- a/segment_any/segment_any.py
+++ b/segment_any/segment_any.py
@@ -16,6 +16,7 @@ class SegAny:
self.model_type = "vit_h"
else:
raise ValueError('The checkpoint named {} is not supported.'.format(checkpoint))
+ torch.cuda.empty_cache()
self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
sam = sam_model_registry[self.model_type](checkpoint=checkpoint)
diff --git a/segment_anything/__init__.py b/segment_anything/__init__.py
index 34383d8..d576507 100644
--- a/segment_anything/__init__.py
+++ b/segment_anything/__init__.py
@@ -11,5 +11,6 @@ from .build_sam import (
build_sam_vit_b,
sam_model_registry,
)
+from .build_sam_baseline import sam_model_registry_baseline
from .predictor import SamPredictor
from .automatic_mask_generator import SamAutomaticMaskGenerator
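
A usage sketch (not part of the patch) of the two registries now exported by the package: the default registry builds SAM with the new MaskDecoderHQ (non-strict checkpoint loading, HQ modules left trainable), while the baseline registry rebuilds the original SAM in eval mode with strict loading. The checkpoint paths below are placeholders, not files shipped with the patch.

    from segment_anything import sam_model_registry, sam_model_registry_baseline

    # HQ decoder: intended for SAM-HQ checkpoints; missing keys are only reported
    # because _build_sam loads with strict=False.
    sam_hq = sam_model_registry["vit_h"](checkpoint="segment_any/sam_hq_vit_h.pth")

    # Original decoder, strict loading, eval mode, identical to upstream SAM.
    sam_base = sam_model_registry_baseline["vit_h"](checkpoint="segment_any/sam_vit_h_4b8939.pth")
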
diff --git a/segment_anything/automatic_mask_generator.py b/segment_anything/automatic_mask_generator.py
index d5a8c96..427ebeb 100644
--- a/segment_anything/automatic_mask_generator.py
+++ b/segment_anything/automatic_mask_generator.py
@@ -134,7 +134,7 @@ class SamAutomaticMaskGenerator:
self.output_mode = output_mode
@torch.no_grad()
- def generate(self, image: np.ndarray) -> List[Dict[str, Any]]:
+ def generate(self, image: np.ndarray, multimask_output: bool = True) -> List[Dict[str, Any]]:
"""
Generates masks for the given image.
@@ -160,7 +160,7 @@ class SamAutomaticMaskGenerator:
"""
# Generate masks
- mask_data = self._generate_masks(image)
+ mask_data = self._generate_masks(image, multimask_output)
# Filter small disconnected regions and holes in masks
if self.min_mask_region_area > 0:
@@ -194,7 +194,7 @@ class SamAutomaticMaskGenerator:
return curr_anns
- def _generate_masks(self, image: np.ndarray) -> MaskData:
+ def _generate_masks(self, image: np.ndarray, multimask_output: bool = True) -> MaskData:
orig_size = image.shape[:2]
crop_boxes, layer_idxs = generate_crop_boxes(
orig_size, self.crop_n_layers, self.crop_overlap_ratio
@@ -203,7 +203,7 @@ class SamAutomaticMaskGenerator:
# Iterate over image crops
data = MaskData()
for crop_box, layer_idx in zip(crop_boxes, layer_idxs):
- crop_data = self._process_crop(image, crop_box, layer_idx, orig_size)
+ crop_data = self._process_crop(image, crop_box, layer_idx, orig_size, multimask_output)
data.cat(crop_data)
# Remove duplicate masks between crops
@@ -228,6 +228,7 @@ class SamAutomaticMaskGenerator:
crop_box: List[int],
crop_layer_idx: int,
orig_size: Tuple[int, ...],
+ multimask_output: bool = True,
) -> MaskData:
# Crop the image and calculate embeddings
x0, y0, x1, y1 = crop_box
@@ -242,7 +243,7 @@ class SamAutomaticMaskGenerator:
# Generate masks for this crop in batches
data = MaskData()
for (points,) in batch_iterator(self.points_per_batch, points_for_image):
- batch_data = self._process_batch(points, cropped_im_size, crop_box, orig_size)
+ batch_data = self._process_batch(points, cropped_im_size, crop_box, orig_size, multimask_output)
data.cat(batch_data)
del batch_data
self.predictor.reset_image()
@@ -269,6 +270,7 @@ class SamAutomaticMaskGenerator:
im_size: Tuple[int, ...],
crop_box: List[int],
orig_size: Tuple[int, ...],
+ multimask_output: bool = True,
) -> MaskData:
orig_h, orig_w = orig_size
@@ -279,7 +281,7 @@ class SamAutomaticMaskGenerator:
masks, iou_preds, _ = self.predictor.predict_torch(
in_points[:, None, :],
in_labels[:, None],
- multimask_output=True,
+ multimask_output=multimask_output,
return_logits=True,
)
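
The multimask_output flag now threads from generate() down through _generate_masks, _process_crop and _process_batch to predict_torch(). A minimal sketch of the new call, assuming `sam` is a model built as above and `image` is an HxWx3 RGB numpy array:

    from segment_anything import SamAutomaticMaskGenerator

    generator = SamAutomaticMaskGenerator(sam)
    masks = generator.generate(image)                                 # default behaviour, multimask_output=True
    masks_single = generator.generate(image, multimask_output=False)  # single-mask output per point
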
diff --git a/segment_anything/build_sam.py b/segment_anything/build_sam.py
index 37cd245..b280cf4 100644
--- a/segment_anything/build_sam.py
+++ b/segment_anything/build_sam.py
@@ -8,7 +8,7 @@ import torch
from functools import partial
-from .modeling import ImageEncoderViT, MaskDecoder, PromptEncoder, Sam, TwoWayTransformer
+from .modeling import ImageEncoderViT, MaskDecoderHQ, PromptEncoder, Sam, TwoWayTransformer
def build_sam_vit_h(checkpoint=None):
@@ -84,7 +84,7 @@ def _build_sam(
input_image_size=(image_size, image_size),
mask_in_chans=16,
),
- mask_decoder=MaskDecoder(
+ mask_decoder=MaskDecoderHQ(
num_multimask_outputs=3,
transformer=TwoWayTransformer(
depth=2,
@@ -95,13 +95,19 @@ def _build_sam(
transformer_dim=prompt_embed_dim,
iou_head_depth=3,
iou_head_hidden_dim=256,
+ vit_dim=encoder_embed_dim,
),
pixel_mean=[123.675, 116.28, 103.53],
pixel_std=[58.395, 57.12, 57.375],
)
- sam.eval()
+ # sam.eval()
if checkpoint is not None:
with open(checkpoint, "rb") as f:
state_dict = torch.load(f)
- sam.load_state_dict(state_dict)
+ info = sam.load_state_dict(state_dict, strict=False)
+ print(info)
+ for n, p in sam.named_parameters():
+ if 'hf_token' not in n and 'hf_mlp' not in n and 'compress_vit_feat' not in n and 'embedding_encoder' not in n and 'embedding_maskfeature' not in n:
+ p.requires_grad = False
+
return sam
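
For readability, the freeze loop added above is equivalent to the following sketch (same parameter names, only the condition is rewritten around a tuple of the HQ module name fragments):

    HQ_MODULES = ('hf_token', 'hf_mlp', 'compress_vit_feat',
                  'embedding_encoder', 'embedding_maskfeature')
    for n, p in sam.named_parameters():
        # only the HQ-specific modules stay trainable; the SAM backbone is frozen
        p.requires_grad = any(key in n for key in HQ_MODULES)
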
diff --git a/segment_anything/build_sam_baseline.py b/segment_anything/build_sam_baseline.py
new file mode 100644
index 0000000..8f14970
--- /dev/null
+++ b/segment_anything/build_sam_baseline.py
@@ -0,0 +1,107 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+
+from functools import partial
+
+from .modeling import ImageEncoderViT, MaskDecoder, PromptEncoder, Sam, TwoWayTransformer
+
+
+def build_sam_vit_h(checkpoint=None):
+ return _build_sam(
+ encoder_embed_dim=1280,
+ encoder_depth=32,
+ encoder_num_heads=16,
+ encoder_global_attn_indexes=[7, 15, 23, 31],
+ checkpoint=checkpoint,
+ )
+
+
+build_sam = build_sam_vit_h
+
+
+def build_sam_vit_l(checkpoint=None):
+ return _build_sam(
+ encoder_embed_dim=1024,
+ encoder_depth=24,
+ encoder_num_heads=16,
+ encoder_global_attn_indexes=[5, 11, 17, 23],
+ checkpoint=checkpoint,
+ )
+
+
+def build_sam_vit_b(checkpoint=None):
+ return _build_sam(
+ encoder_embed_dim=768,
+ encoder_depth=12,
+ encoder_num_heads=12,
+ encoder_global_attn_indexes=[2, 5, 8, 11],
+ checkpoint=checkpoint,
+ )
+
+
+sam_model_registry_baseline = {
+ "default": build_sam_vit_h,
+ "vit_h": build_sam_vit_h,
+ "vit_l": build_sam_vit_l,
+ "vit_b": build_sam_vit_b,
+}
+
+
+def _build_sam(
+ encoder_embed_dim,
+ encoder_depth,
+ encoder_num_heads,
+ encoder_global_attn_indexes,
+ checkpoint=None,
+):
+ prompt_embed_dim = 256
+ image_size = 1024
+ vit_patch_size = 16
+ image_embedding_size = image_size // vit_patch_size
+ sam = Sam(
+ image_encoder=ImageEncoderViT(
+ depth=encoder_depth,
+ embed_dim=encoder_embed_dim,
+ img_size=image_size,
+ mlp_ratio=4,
+ norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
+ num_heads=encoder_num_heads,
+ patch_size=vit_patch_size,
+ qkv_bias=True,
+ use_rel_pos=True,
+ global_attn_indexes=encoder_global_attn_indexes,
+ window_size=14,
+ out_chans=prompt_embed_dim,
+ ),
+ prompt_encoder=PromptEncoder(
+ embed_dim=prompt_embed_dim,
+ image_embedding_size=(image_embedding_size, image_embedding_size),
+ input_image_size=(image_size, image_size),
+ mask_in_chans=16,
+ ),
+ mask_decoder=MaskDecoder(
+ num_multimask_outputs=3,
+ transformer=TwoWayTransformer(
+ depth=2,
+ embedding_dim=prompt_embed_dim,
+ mlp_dim=2048,
+ num_heads=8,
+ ),
+ transformer_dim=prompt_embed_dim,
+ iou_head_depth=3,
+ iou_head_hidden_dim=256,
+ ),
+ pixel_mean=[123.675, 116.28, 103.53],
+ pixel_std=[58.395, 57.12, 57.375],
+ )
+ sam.eval()
+ if checkpoint is not None:
+ with open(checkpoint, "rb") as f:
+ state_dict = torch.load(f)
+ sam.load_state_dict(state_dict)
+ return sam
\ No newline at end of file
diff --git a/segment_anything/modeling/__init__.py b/segment_anything/modeling/__init__.py
index 38e9062..71172d2 100644
--- a/segment_anything/modeling/__init__.py
+++ b/segment_anything/modeling/__init__.py
@@ -6,6 +6,7 @@
from .sam import Sam
from .image_encoder import ImageEncoderViT
+from .mask_decoder_hq import MaskDecoderHQ
from .mask_decoder import MaskDecoder
from .prompt_encoder import PromptEncoder
from .transformer import TwoWayTransformer
diff --git a/segment_anything/modeling/image_encoder.py b/segment_anything/modeling/image_encoder.py
index 66351d9..7048651 100644
--- a/segment_anything/modeling/image_encoder.py
+++ b/segment_anything/modeling/image_encoder.py
@@ -108,12 +108,15 @@ class ImageEncoderViT(nn.Module):
if self.pos_embed is not None:
x = x + self.pos_embed
+ interm_embeddings=[]
for blk in self.blocks:
x = blk(x)
+ if blk.window_size == 0:
+ interm_embeddings.append(x)
x = self.neck(x.permute(0, 3, 1, 2))
- return x
+ return x, interm_embeddings
class Block(nn.Module):
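
The encoder's forward() now returns the outputs of the global-attention blocks alongside the final embedding, so every caller must unpack two values. A sketch of the new contract, assuming `input_images` is an already-preprocessed Bx3x1024x1024 tensor:

    image_embeddings, interm_embeddings = sam.image_encoder(input_images)
    # interm_embeddings is a list of (B, H, W, C) tensors, one per global-attention block;
    # MaskDecoderHQ consumes only the first (earliest) one:
    vit_features = interm_embeddings[0].permute(0, 3, 1, 2)
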
diff --git a/segment_anything/modeling/mask_decoder.py b/segment_anything/modeling/mask_decoder.py
index 5d2fdb0..242ecb7 100644
--- a/segment_anything/modeling/mask_decoder.py
+++ b/segment_anything/modeling/mask_decoder.py
@@ -75,6 +75,8 @@ class MaskDecoder(nn.Module):
sparse_prompt_embeddings: torch.Tensor,
dense_prompt_embeddings: torch.Tensor,
multimask_output: bool,
+ hq_token_only: bool,
+ interm_embeddings: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Predict masks given image and prompt embeddings.
diff --git a/segment_anything/modeling/mask_decoder_hq.py b/segment_anything/modeling/mask_decoder_hq.py
new file mode 100644
index 0000000..1e365e3
--- /dev/null
+++ b/segment_anything/modeling/mask_decoder_hq.py
@@ -0,0 +1,232 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# Modified by HQ-SAM team
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+from typing import List, Tuple, Type
+
+from .common import LayerNorm2d
+
+
+class MaskDecoderHQ(nn.Module):
+ def __init__(
+ self,
+ *,
+ transformer_dim: int,
+ transformer: nn.Module,
+ num_multimask_outputs: int = 3,
+ activation: Type[nn.Module] = nn.GELU,
+ iou_head_depth: int = 3,
+ iou_head_hidden_dim: int = 256,
+ vit_dim: int = 1024,
+ ) -> None:
+ """
+ Predicts masks given an image and prompt embeddings, using a
+ transformer architecture.
+
+ Arguments:
+ transformer_dim (int): the channel dimension of the transformer
+ transformer (nn.Module): the transformer used to predict masks
+ num_multimask_outputs (int): the number of masks to predict
+ when disambiguating masks
+ activation (nn.Module): the type of activation to use when
+ upscaling masks
+ iou_head_depth (int): the depth of the MLP used to predict
+ mask quality
+ iou_head_hidden_dim (int): the hidden dimension of the MLP
+ used to predict mask quality
+ """
+ super().__init__()
+ self.transformer_dim = transformer_dim
+ self.transformer = transformer
+
+ self.num_multimask_outputs = num_multimask_outputs
+
+ self.iou_token = nn.Embedding(1, transformer_dim)
+ self.num_mask_tokens = num_multimask_outputs + 1
+ self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim)
+
+ self.output_upscaling = nn.Sequential(
+ nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2),
+ LayerNorm2d(transformer_dim // 4),
+ activation(),
+ nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2),
+ activation(),
+ )
+ self.output_hypernetworks_mlps = nn.ModuleList(
+ [
+ MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3)
+ for i in range(self.num_mask_tokens)
+ ]
+ )
+
+ self.iou_prediction_head = MLP(
+ transformer_dim, iou_head_hidden_dim, self.num_mask_tokens, iou_head_depth
+ )
+
+ # HQ-SAM parameters
+        self.hf_token = nn.Embedding(1, transformer_dim) # HQ-Output-Token
+        self.hf_mlp = MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3) # corresponding new MLP layer for HQ-Output-Token
+ self.num_mask_tokens = self.num_mask_tokens + 1
+
+ # three conv fusion layers for obtaining HQ-Feature
+ self.compress_vit_feat = nn.Sequential(
+ nn.ConvTranspose2d(vit_dim, transformer_dim, kernel_size=2, stride=2),
+ LayerNorm2d(transformer_dim),
+ nn.GELU(),
+ nn.ConvTranspose2d(transformer_dim, transformer_dim // 8, kernel_size=2, stride=2))
+
+ self.embedding_encoder = nn.Sequential(
+ nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2),
+ LayerNorm2d(transformer_dim // 4),
+ nn.GELU(),
+ nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2),
+ )
+ self.embedding_maskfeature = nn.Sequential(
+ nn.Conv2d(transformer_dim // 8, transformer_dim // 4, 3, 1, 1),
+ LayerNorm2d(transformer_dim // 4),
+ nn.GELU(),
+ nn.Conv2d(transformer_dim // 4, transformer_dim // 8, 3, 1, 1))
+
+
+
+ def forward(
+ self,
+ image_embeddings: torch.Tensor,
+ image_pe: torch.Tensor,
+ sparse_prompt_embeddings: torch.Tensor,
+ dense_prompt_embeddings: torch.Tensor,
+ multimask_output: bool,
+ hq_token_only: bool,
+ interm_embeddings: torch.Tensor,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """
+ Predict masks given image and prompt embeddings.
+
+ Arguments:
+ image_embeddings (torch.Tensor): the embeddings from the ViT image encoder
+ image_pe (torch.Tensor): positional encoding with the shape of image_embeddings
+ sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes
+ dense_prompt_embeddings (torch.Tensor): the embeddings of the mask inputs
+ multimask_output (bool): Whether to return multiple masks or a single
+ mask.
+
+ Returns:
+ torch.Tensor: batched predicted masks
+ torch.Tensor: batched predictions of mask quality
+ """
+ vit_features = interm_embeddings[0].permute(0, 3, 1, 2) # early-layer ViT feature, after 1st global attention block in ViT
+ hq_features = self.embedding_encoder(image_embeddings) + self.compress_vit_feat(vit_features)
+
+ masks, iou_pred = self.predict_masks(
+ image_embeddings=image_embeddings,
+ image_pe=image_pe,
+ sparse_prompt_embeddings=sparse_prompt_embeddings,
+ dense_prompt_embeddings=dense_prompt_embeddings,
+ hq_features=hq_features,
+ )
+
+ # Select the correct mask or masks for output
+ if multimask_output:
+ # mask with highest score
+ mask_slice = slice(1,self.num_mask_tokens-1)
+ iou_pred = iou_pred[:, mask_slice]
+ iou_pred, max_iou_idx = torch.max(iou_pred,dim=1)
+ iou_pred = iou_pred.unsqueeze(1)
+ masks_multi = masks[:, mask_slice, :, :]
+ masks_sam = masks_multi[torch.arange(masks_multi.size(0)),max_iou_idx].unsqueeze(1)
+ else:
+            # single mask output, default
+ mask_slice = slice(0, 1)
+ iou_pred = iou_pred[:,mask_slice]
+ masks_sam = masks[:,mask_slice]
+
+ masks_hq = masks[:,slice(self.num_mask_tokens-1, self.num_mask_tokens)]
+ if hq_token_only:
+ masks = masks_hq
+ else:
+ masks = masks_sam + masks_hq
+ # Prepare output
+ return masks, iou_pred
+
+ def predict_masks(
+ self,
+ image_embeddings: torch.Tensor,
+ image_pe: torch.Tensor,
+ sparse_prompt_embeddings: torch.Tensor,
+ dense_prompt_embeddings: torch.Tensor,
+ hq_features: torch.Tensor,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Predicts masks. See 'forward' for more details."""
+ # Concatenate output tokens
+ output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight, self.hf_token.weight], dim=0)
+ output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1)
+ tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1)
+
+ # Expand per-image data in batch direction to be per-mask
+ src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)
+ src = src + dense_prompt_embeddings
+ pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0)
+ b, c, h, w = src.shape
+
+ # Run the transformer
+ hs, src = self.transformer(src, pos_src, tokens)
+ iou_token_out = hs[:, 0, :]
+ mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :]
+
+ # Upscale mask embeddings and predict masks using the mask tokens
+ src = src.transpose(1, 2).view(b, c, h, w)
+
+ upscaled_embedding_sam = self.output_upscaling(src)
+ upscaled_embedding_hq = self.embedding_maskfeature(upscaled_embedding_sam) + hq_features.repeat(b,1,1,1)
+
+ hyper_in_list: List[torch.Tensor] = []
+ for i in range(self.num_mask_tokens):
+ if i < self.num_mask_tokens - 1:
+ hyper_in_list.append(self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :]))
+ else:
+ hyper_in_list.append(self.hf_mlp(mask_tokens_out[:, i, :]))
+
+ hyper_in = torch.stack(hyper_in_list, dim=1)
+ b, c, h, w = upscaled_embedding_sam.shape
+
+ masks_sam = (hyper_in[:,:self.num_mask_tokens-1] @ upscaled_embedding_sam.view(b, c, h * w)).view(b, -1, h, w)
+ masks_sam_hq = (hyper_in[:,self.num_mask_tokens-1:] @ upscaled_embedding_hq.view(b, c, h * w)).view(b, -1, h, w)
+ masks = torch.cat([masks_sam,masks_sam_hq],dim=1)
+ # Generate mask quality predictions
+ iou_pred = self.iou_prediction_head(iou_token_out)
+
+ return masks, iou_pred
+
+
+# Lightly adapted from
+# https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py # noqa
+class MLP(nn.Module):
+ def __init__(
+ self,
+ input_dim: int,
+ hidden_dim: int,
+ output_dim: int,
+ num_layers: int,
+ sigmoid_output: bool = False,
+ ) -> None:
+ super().__init__()
+ self.num_layers = num_layers
+ h = [hidden_dim] * (num_layers - 1)
+ self.layers = nn.ModuleList(
+ nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])
+ )
+ self.sigmoid_output = sigmoid_output
+
+ def forward(self, x):
+ for i, layer in enumerate(self.layers):
+ x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
+ if self.sigmoid_output:
+ x = F.sigmoid(x)
+ return x
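
A sketch of constructing the HQ decoder on its own; vit_dim is the one new argument and must match the image encoder's embedding dimension (768 / 1024 / 1280 for vit_b / vit_l / vit_h), exactly as _build_sam now wires it:

    from segment_anything.modeling import MaskDecoderHQ, TwoWayTransformer

    decoder = MaskDecoderHQ(
        num_multimask_outputs=3,
        transformer=TwoWayTransformer(depth=2, embedding_dim=256, mlp_dim=2048, num_heads=8),
        transformer_dim=256,
        iou_head_depth=3,
        iou_head_hidden_dim=256,
        vit_dim=1280,  # vit_h
    )
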
diff --git a/segment_anything/modeling/sam.py b/segment_anything/modeling/sam.py
index 8074cff..b928dfd 100644
--- a/segment_anything/modeling/sam.py
+++ b/segment_anything/modeling/sam.py
@@ -50,11 +50,11 @@ class Sam(nn.Module):
def device(self) -> Any:
return self.pixel_mean.device
- @torch.no_grad()
def forward(
self,
batched_input: List[Dict[str, Any]],
multimask_output: bool,
+        hq_token_only: bool = False,
) -> List[Dict[str, torch.Tensor]]:
"""
Predicts masks end-to-end from provided images and prompts.
@@ -95,10 +95,11 @@ class Sam(nn.Module):
to subsequent iterations of prediction.
"""
input_images = torch.stack([self.preprocess(x["image"]) for x in batched_input], dim=0)
- image_embeddings = self.image_encoder(input_images)
+ image_embeddings, interm_embeddings = self.image_encoder(input_images)
+ interm_embeddings = interm_embeddings[0] # early layer
outputs = []
- for image_record, curr_embedding in zip(batched_input, image_embeddings):
+ for image_record, curr_embedding, curr_interm in zip(batched_input, image_embeddings, interm_embeddings):
if "point_coords" in image_record:
points = (image_record["point_coords"], image_record["point_labels"])
else:
@@ -114,6 +115,8 @@ class Sam(nn.Module):
sparse_prompt_embeddings=sparse_embeddings,
dense_prompt_embeddings=dense_embeddings,
multimask_output=multimask_output,
+ hq_token_only=hq_token_only,
+ interm_embeddings=curr_interm.unsqueeze(0).unsqueeze(0),
)
masks = self.postprocess_masks(
low_res_masks,
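
Sam.forward() gains an hq_token_only flag and now forwards the early-layer ViT feature to the decoder. A sketch of the batched interface; tensor shapes follow the Sam.forward docstring and the values are placeholders:

    batched_input = [{
        "image": image_tensor,          # 3xHxW, pixel values in [0, 255]
        "original_size": (480, 640),
        "point_coords": point_coords,   # BxNx2, in the original image frame
        "point_labels": point_labels,   # BxN
    }]
    outputs = sam(batched_input, multimask_output=False, hq_token_only=True)
    masks = outputs[0]["masks"]
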
diff --git a/segment_anything/predictor.py b/segment_anything/predictor.py
index 8a6e6d8..31458fb 100644
--- a/segment_anything/predictor.py
+++ b/segment_anything/predictor.py
@@ -7,7 +7,7 @@
import numpy as np
import torch
-from segment_anything.modeling import Sam
+from .modeling import Sam
from typing import Optional, Tuple
@@ -49,10 +49,12 @@ class SamPredictor:
"RGB",
"BGR",
], f"image_format must be in ['RGB', 'BGR'], is {image_format}."
+ # import pdb;pdb.set_trace()
if image_format != self.model.image_format:
image = image[..., ::-1]
# Transform the image to the form expected by the model
+ # import pdb;pdb.set_trace()
input_image = self.transform.apply_image(image)
input_image_torch = torch.as_tensor(input_image, device=self.device)
input_image_torch = input_image_torch.permute(2, 0, 1).contiguous()[None, :, :, :]
@@ -86,7 +88,7 @@ class SamPredictor:
self.original_size = original_image_size
self.input_size = tuple(transformed_image.shape[-2:])
input_image = self.model.preprocess(transformed_image)
- self.features = self.model.image_encoder(input_image)
+ self.features, self.interm_features = self.model.image_encoder(input_image)
self.is_image_set = True
def predict(
@@ -97,6 +99,7 @@ class SamPredictor:
mask_input: Optional[np.ndarray] = None,
multimask_output: bool = True,
return_logits: bool = False,
+        hq_token_only: bool = False,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
"""
Predict masks for the given input prompts, using the currently set image.
@@ -158,6 +161,7 @@ class SamPredictor:
mask_input_torch,
multimask_output,
return_logits=return_logits,
+ hq_token_only=hq_token_only,
)
masks_np = masks[0].detach().cpu().numpy()
@@ -174,6 +178,7 @@ class SamPredictor:
mask_input: Optional[torch.Tensor] = None,
multimask_output: bool = True,
return_logits: bool = False,
+        hq_token_only: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Predict masks for the given input prompts, using the currently set image.
@@ -232,6 +237,8 @@ class SamPredictor:
sparse_prompt_embeddings=sparse_embeddings,
dense_prompt_embeddings=dense_embeddings,
multimask_output=multimask_output,
+ hq_token_only=hq_token_only,
+ interm_embeddings=self.interm_features,
)
# Upscale the masks to the original image resolution
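
SamPredictor now caches the intermediate encoder features in set_image()/set_torch_image() and exposes hq_token_only on predict() and predict_torch(). A minimal interactive sketch, assuming `sam_hq` was built via the registry sketch above, `image` is an HxWx3 RGB numpy array, and the point prompt is illustrative:

    import numpy as np
    from segment_anything import SamPredictor

    predictor = SamPredictor(sam_hq)
    predictor.set_image(image)
    masks, scores, low_res_logits = predictor.predict(
        point_coords=np.array([[320, 240]]),
        point_labels=np.array([1]),
        multimask_output=False,
        hq_token_only=True,   # return only the HQ mask; otherwise SAM mask + HQ correction is returned
    )
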
diff --git a/segment_anything/utils/onnx.py b/segment_anything/utils/onnx.py
index 3196bdf..8013dc4 100644
--- a/segment_anything/utils/onnx.py
+++ b/segment_anything/utils/onnx.py
@@ -25,7 +25,8 @@ class SamOnnxModel(nn.Module):
def __init__(
self,
model: Sam,
- return_single_mask: bool,
+ hq_token_only: bool = False,
+ multimask_output: bool = False,
use_stability_score: bool = False,
return_extra_metrics: bool = False,
) -> None:
@@ -33,7 +34,8 @@ class SamOnnxModel(nn.Module):
self.mask_decoder = model.mask_decoder
self.model = model
self.img_size = model.image_encoder.img_size
- self.return_single_mask = return_single_mask
+ self.hq_token_only = hq_token_only
+ self.multimask_output = multimask_output
self.use_stability_score = use_stability_score
self.stability_score_offset = 1.0
self.return_extra_metrics = return_extra_metrics
@@ -89,25 +91,12 @@ class SamOnnxModel(nn.Module):
masks = F.interpolate(masks, size=(h, w), mode="bilinear", align_corners=False)
return masks
- def select_masks(
- self, masks: torch.Tensor, iou_preds: torch.Tensor, num_points: int
- ) -> Tuple[torch.Tensor, torch.Tensor]:
- # Determine if we should return the multiclick mask or not from the number of points.
- # The reweighting is used to avoid control flow.
- score_reweight = torch.tensor(
- [[1000] + [0] * (self.model.mask_decoder.num_mask_tokens - 1)]
- ).to(iou_preds.device)
- score = iou_preds + (num_points - 2.5) * score_reweight
- best_idx = torch.argmax(score, dim=1)
- masks = masks[torch.arange(masks.shape[0]), best_idx, :, :].unsqueeze(1)
- iou_preds = iou_preds[torch.arange(masks.shape[0]), best_idx].unsqueeze(1)
-
- return masks, iou_preds
@torch.no_grad()
def forward(
self,
image_embeddings: torch.Tensor,
+ interm_embeddings: torch.Tensor,
point_coords: torch.Tensor,
point_labels: torch.Tensor,
mask_input: torch.Tensor,
@@ -117,11 +106,15 @@ class SamOnnxModel(nn.Module):
sparse_embedding = self._embed_points(point_coords, point_labels)
dense_embedding = self._embed_masks(mask_input, has_mask_input)
+ vit_features = interm_embeddings[0].permute(0, 3, 1, 2) # early-layer ViT feature, after 1st global attention block in ViT
+ hq_features = self.model.mask_decoder.embedding_encoder(image_embeddings) + self.model.mask_decoder.compress_vit_feat(vit_features)
+
masks, scores = self.model.mask_decoder.predict_masks(
image_embeddings=image_embeddings,
image_pe=self.model.prompt_encoder.get_dense_pe(),
sparse_prompt_embeddings=sparse_embedding,
dense_prompt_embeddings=dense_embedding,
+ hq_features=hq_features,
)
if self.use_stability_score:
@@ -129,8 +122,26 @@ class SamOnnxModel(nn.Module):
masks, self.model.mask_threshold, self.stability_score_offset
)
- if self.return_single_mask:
- masks, scores = self.select_masks(masks, scores, point_coords.shape[1])
+ if self.multimask_output:
+ # mask with highest score
+ mask_slice = slice(1,self.model.mask_decoder.num_mask_tokens-1)
+ scores = scores[:, mask_slice]
+ scores, max_iou_idx = torch.max(scores,dim=1)
+ scores = scores.unsqueeze(1)
+ masks_multi = masks[:, mask_slice, :, :]
+ masks_sam = masks_multi[torch.arange(masks_multi.size(0)),max_iou_idx].unsqueeze(1)
+ else:
+            # single mask output, default
+ mask_slice = slice(0, 1)
+ scores = scores[:,mask_slice]
+ masks_sam = masks[:,mask_slice]
+
+ masks_hq = masks[:,slice(self.model.mask_decoder.num_mask_tokens-1, self.model.mask_decoder.num_mask_tokens)]
+
+ if self.hq_token_only:
+ masks = masks_hq
+ else:
+ masks = masks_sam + masks_hq
upscaled_masks = self.mask_postprocessing(masks, orig_im_size)
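
The ONNX wrapper drops return_single_mask/select_masks() in favour of the same multimask_output / hq_token_only selection used by the decoder, and its forward() now expects interm_embeddings as an additional input. A construction sketch (again assuming the `sam_hq` model from above):

    from segment_anything.utils.onnx import SamOnnxModel

    onnx_model = SamOnnxModel(
        model=sam_hq,
        hq_token_only=False,
        multimask_output=False,   # single-mask path, the default
    )
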
diff --git a/ui/MainWindow.py b/ui/MainWindow.py
index 8a4c574..656819d 100644
--- a/ui/MainWindow.py
+++ b/ui/MainWindow.py
@@ -88,6 +88,8 @@ class Ui_MainWindow(object):
self.menuMode.setObjectName("menuMode")
self.menuContour_mode = QtWidgets.QMenu(self.menuMode)
self.menuContour_mode.setObjectName("menuContour_mode")
+ self.menuSAM_model = QtWidgets.QMenu(self.menubar)
+ self.menuSAM_model.setObjectName("menuSAM_model")
MainWindow.setMenuBar(self.menubar)
self.statusbar = QtWidgets.QStatusBar(MainWindow)
self.statusbar.setLayoutDirection(QtCore.Qt.LeftToRight)
@@ -352,6 +354,7 @@ class Ui_MainWindow(object):
self.menubar.addAction(self.menuFile.menuAction())
self.menubar.addAction(self.menuEdit.menuAction())
self.menubar.addAction(self.menuView.menuAction())
+ self.menubar.addAction(self.menuSAM_model.menuAction())
self.menubar.addAction(self.menuMode.menuAction())
self.menubar.addAction(self.menuTools.menuAction())
self.menubar.addAction(self.menuAbout.menuAction())
@@ -391,6 +394,7 @@ class Ui_MainWindow(object):
self.menuEdit.setTitle(_translate("MainWindow", "Edit"))
self.menuMode.setTitle(_translate("MainWindow", "Mode"))
self.menuContour_mode.setTitle(_translate("MainWindow", "Contour mode"))
+ self.menuSAM_model.setTitle(_translate("MainWindow", "SAM"))
self.toolBar.setWindowTitle(_translate("MainWindow", "toolBar"))
self.info_dock.setWindowTitle(_translate("MainWindow", "Info"))
self.annos_dock.setWindowTitle(_translate("MainWindow", "Annos"))
diff --git a/ui/MainWindow.ui b/ui/MainWindow.ui
index a9e74ba..481b280 100644
--- a/ui/MainWindow.ui
+++ b/ui/MainWindow.ui
@@ -202,9 +202,15 @@
+  [six XML lines lost in extraction: the added markup defines the QMenu "menuSAM_model" with title "SAM" and adds its addaction entry to the menubar, matching the generated ui/MainWindow.py changes above]
diff --git a/widgets/mainwindow.py b/widgets/mainwindow.py
index 613c49f..99bad76 100644
--- a/widgets/mainwindow.py
+++ b/widgets/mainwindow.py
@@ -34,8 +34,6 @@ class MainWindow(QtWidgets.QMainWindow, Ui_MainWindow):
def __init__(self):
super(MainWindow, self).__init__()
self.setupUi(self)
- self.init_ui()
- self.init_segment_anything()
self.image_root: str = None
self.label_root:str = None
@@ -58,41 +56,51 @@ class MainWindow(QtWidgets.QMainWindow, Ui_MainWindow):
self.map_mode = MAPMode.LABEL
        # annotation target
self.current_label:Annotation = None
+ self.use_segment_anything = False
+ self.init_ui()
self.reload_cfg()
self.init_connect()
self.reset_action()
- def init_segment_anything(self):
- if os.path.exists('./segment_any/sam_vit_h_4b8939.pth'):
- self.statusbar.showMessage('Find the checkpoint named {}.'.format('sam_vit_h_4b8939.pth'))
- self.segany = SegAny('./segment_any/sam_vit_h_4b8939.pth')
- self.use_segment_anything = True
- elif os.path.exists('./segment_any/sam_vit_l_0b3195.pth'):
- self.statusbar.showMessage('Find the checkpoint named {}.'.format('sam_vit_l_0b3195.pth'))
- self.segany = SegAny('./segment_any/sam_vit_l_0b3195.pth')
- self.use_segment_anything = True
- elif os.path.exists('./segment_any/sam_vit_b_01ec64.pth'):
- self.statusbar.showMessage('Find the checkpoint named {}.'.format('sam_vit_b_01ec64.pth'))
- self.segany = SegAny('./segment_any/sam_vit_b_01ec64.pth')
- self.use_segment_anything = True
- else:
- QtWidgets.QMessageBox.warning(self, 'Warning', 'The checkpoint of [Segment anything] not existed. If you want use quick annotate, please download from {}'.format('https://github.com/facebookresearch/segment-anything#model-checkpoints'))
+ def init_segment_anything(self, model_name, reload=False):
+ if model_name == '':
self.use_segment_anything = False
+ for name, action in self.pths_actions.items():
+ action.setChecked(model_name == name)
+ return
+ model_path = os.path.join('segment_any', model_name)
+ if not os.path.exists(model_path):
+ QtWidgets.QMessageBox.warning(self, 'Warning',
+ 'The checkpoint of [Segment anything] not existed. If you want use quick annotate, please download from {}'.format(
+ 'https://github.com/facebookresearch/segment-anything#model-checkpoints'))
+ for name, action in self.pths_actions.items():
+ action.setChecked(model_name == name)
+ self.use_segment_anything = False
+ return
- if self.use_segment_anything:
- if self.segany.device != 'cpu':
- self.gpu_resource_thread = GPUResource_Thread()
- self.gpu_resource_thread.message.connect(self.labelGPUResource.setText)
- self.gpu_resource_thread.start()
+ self.segany = SegAny(model_path)
+ self.use_segment_anything = True
+ self.statusbar.showMessage('Use the checkpoint named {}.'.format(model_name), 3000)
+ for name, action in self.pths_actions.items():
+ action.setChecked(model_name==name)
+ if not reload:
+ if self.use_segment_anything:
+ if self.segany.device != 'cpu':
+ self.gpu_resource_thread = GPUResource_Thread()
+ self.gpu_resource_thread.message.connect(self.labelGPUResource.setText)
+ self.gpu_resource_thread.start()
+ else:
+ self.labelGPUResource.setText('cpu')
else:
- self.labelGPUResource.setText('cpu')
- else:
- self.labelGPUResource.setText('segment anything unused.')
+ self.labelGPUResource.setText('segment anything unused.')
+
+ if reload and self.current_index is not None:
+ self.show_image(self.current_index)
def init_ui(self):
- #
+ #q
self.setting_dialog = SettingDialog(parent=self, mainwindow=self)
self.categories_dock_widget = CategoriesDockWidget(mainwindow=self)
@@ -144,6 +152,17 @@ class MainWindow(QtWidgets.QMainWindow, Ui_MainWindow):
self.statusbar.addPermanentWidget(self.labelData)
#
+ model_names = sorted([pth for pth in os.listdir('segment_any') if pth.endswith('.pth')])
+ self.pths_actions = {}
+ for model_name in model_names:
+ action = QtWidgets.QAction(self)
+            action.setObjectName(model_name)
+ action.triggered.connect(functools.partial(self.init_segment_anything, model_name))
+ action.setText("{}".format(model_name))
+ action.setCheckable(True)
+
+ self.pths_actions[model_name] = action
+ self.menuSAM_model.addAction(action)
self.toolBar.addSeparator()
self.mask_aplha = QtWidgets.QSlider(QtCore.Qt.Orientation.Horizontal, self)
@@ -213,6 +232,9 @@ class MainWindow(QtWidgets.QMainWindow, Ui_MainWindow):
self.cfg['mask_alpha'] = mask_alpha
self.mask_aplha.setValue(mask_alpha*10)
+ model_name = self.cfg.get('model_name', '')
+ self.init_segment_anything(model_name)
+
self.categories_dock_widget.update_widget()
def set_saved_state(self, is_saved:bool):