SAM-HQ支持,界面更新,添加模型选择

This commit is contained in:
yatengLG 2023-07-11 13:50:17 +08:00
parent db139d2c02
commit f7edb986d4
19 changed files with 473 additions and 866 deletions

View File

@ -1,813 +1,12 @@
{
"info": {
"description": "ISAT",
"folder": "C:/Users/lg/PycharmProjects/ISAT_with_segment_anything/example/images",
"folder": "/home/super/PycharmProjects/ISAT_with_segment_anything/example/images",
"name": "000000000144.jpg",
"width": 640,
"height": 480,
"depth": 3,
"note": ""
},
"objects": [
{
"category": "fence",
"group": "1",
"segmentation": [
[
20,
239
],
[
17,
240
],
[
15,
241
],
[
13,
242
],
[
11,
244
],
[
10,
244
],
[
9,
245
],
[
6,
249
],
[
5,
251
],
[
4,
255
],
[
0,
261
],
[
0,
280
],
[
1,
277
],
[
6,
268
],
[
9,
259
],
[
13,
254
],
[
13,
252
],
[
15,
250
],
[
18,
248
],
[
20,
247
],
[
24,
247
],
[
26,
248
],
[
30,
252
],
[
33,
256
],
[
35,
260
],
[
38,
266
],
[
41,
273
],
[
43,
280
],
[
50,
294
],
[
53,
297
],
[
54,
297
],
[
59,
301
],
[
61,
303
],
[
62,
316
],
[
63,
341
],
[
63,
382
],
[
64,
405
],
[
66,
414
],
[
71,
413
],
[
72,
402
],
[
72,
399
],
[
71,
359
],
[
69,
324
],
[
70,
300
],
[
73,
299
],
[
78,
294
],
[
80,
291
],
[
81,
289
],
[
83,
283
],
[
85,
278
],
[
86,
272
],
[
91,
258
],
[
94,
252
],
[
98,
248
],
[
105,
248
],
[
107,
249
],
[
111,
253
],
[
116,
261
],
[
118,
265
],
[
121,
271
],
[
128,
288
],
[
130,
291
],
[
135,
296
],
[
138,
298
],
[
140,
299
],
[
143,
300
],
[
148,
300
],
[
151,
299
],
[
153,
298
],
[
156,
296
],
[
159,
293
],
[
161,
290
],
[
164,
284
],
[
179,
255
],
[
180,
253
],
[
180,
251
],
[
181,
249
],
[
183,
247
],
[
186,
248
],
[
187,
259
],
[
188,
282
],
[
188,
303
],
[
190,
320
],
[
190,
343
],
[
191,
355
],
[
192,
357
],
[
196,
358
],
[
198,
357
],
[
200,
353
],
[
200,
341
],
[
198,
321
],
[
197,
300
],
[
196,
279
],
[
195,
252
],
[
195,
250
],
[
196,
248
],
[
200,
248
],
[
200,
249
],
[
202,
251
],
[
207,
260
],
[
208,
262
],
[
208,
264
],
[
210,
266
],
[
212,
273
],
[
215,
279
],
[
217,
281
],
[
218,
283
],
[
219,
287
],
[
221,
290
],
[
225,
294
],
[
227,
295
],
[
230,
296
],
[
235,
296
],
[
239,
295
],
[
241,
294
],
[
245,
290
],
[
246,
288
],
[
247,
286
],
[
248,
283
],
[
248,
279
],
[
245,
281
],
[
244,
283
],
[
244,
285
],
[
243,
287
],
[
238,
292
],
[
230,
292
],
[
228,
291
],
[
224,
287
],
[
222,
284
],
[
219,
278
],
[
214,
268
],
[
214,
266
],
[
207,
252
],
[
205,
249
],
[
201,
245
],
[
198,
243
],
[
194,
242
],
[
188,
242
],
[
183,
243
],
[
180,
244
],
[
176,
246
],
[
174,
249
],
[
172,
252
],
[
170,
256
],
[
169,
260
],
[
168,
262
],
[
166,
264
],
[
165,
266
],
[
164,
268
],
[
161,
276
],
[
156,
286
],
[
153,
290
],
[
152,
291
],
[
151,
291
],
[
149,
293
],
[
147,
293
],
[
144,
293
],
[
142,
293
],
[
140,
292
],
[
138,
290
],
[
136,
287
],
[
133,
281
],
[
125,
263
],
[
123,
259
],
[
120,
253
],
[
117,
248
],
[
112,
243
],
[
110,
242
],
[
107,
241
],
[
98,
241
],
[
96,
242
],
[
94,
243
],
[
89,
248
],
[
85,
256
],
[
83,
261
],
[
81,
269
],
[
78,
274
],
[
76,
282
],
[
74,
287
],
[
73,
289
],
[
67,
295
],
[
63,
295
],
[
61,
294
],
[
55,
288
],
[
53,
284
],
[
50,
278
],
[
49,
273
],
[
41,
254
],
[
40,
252
],
[
38,
249
],
[
35,
246
],
[
35,
245
],
[
33,
243
],
[
30,
241
],
[
28,
240
],
[
24,
239
]
],
"area": 4485.0,
"layer": 1.0,
"bbox": [
0.0,
239.0,
248.0,
414.0
],
"iscrowd": 0,
"note": ""
}
]
"objects": []
}

View File

@ -1,4 +1,4 @@
contour_mode: external
contour_mode: all
label:
- color: '#000000'
name: __background__

View File

@ -1,4 +1,4 @@
contour_mode: external
contour_mode: all
label:
- color: '#000000'
name: __background__

View File

@ -12,5 +12,5 @@ if __name__ == '__main__':
app = QtWidgets.QApplication([''])
mainwindow = MainWindow()
mainwindow.show()
sys.exit(app.exec_())
sys.exit(app.exec())

View File

@ -16,6 +16,7 @@ class SegAny:
self.model_type = "vit_h"
else:
raise ValueError('The checkpoint named {} is not supported.'.format(checkpoint))
torch.cuda.empty_cache()
self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
sam = sam_model_registry[self.model_type](checkpoint=checkpoint)

View File

@ -11,5 +11,6 @@ from .build_sam import (
build_sam_vit_b,
sam_model_registry,
)
from .build_sam_baseline import sam_model_registry_baseline
from .predictor import SamPredictor
from .automatic_mask_generator import SamAutomaticMaskGenerator

View File

@ -134,7 +134,7 @@ class SamAutomaticMaskGenerator:
self.output_mode = output_mode
@torch.no_grad()
def generate(self, image: np.ndarray) -> List[Dict[str, Any]]:
def generate(self, image: np.ndarray, multimask_output: bool = True) -> List[Dict[str, Any]]:
"""
Generates masks for the given image.
@ -160,7 +160,7 @@ class SamAutomaticMaskGenerator:
"""
# Generate masks
mask_data = self._generate_masks(image)
mask_data = self._generate_masks(image, multimask_output)
# Filter small disconnected regions and holes in masks
if self.min_mask_region_area > 0:
@ -194,7 +194,7 @@ class SamAutomaticMaskGenerator:
return curr_anns
def _generate_masks(self, image: np.ndarray) -> MaskData:
def _generate_masks(self, image: np.ndarray, multimask_output: bool = True) -> MaskData:
orig_size = image.shape[:2]
crop_boxes, layer_idxs = generate_crop_boxes(
orig_size, self.crop_n_layers, self.crop_overlap_ratio
@ -203,7 +203,7 @@ class SamAutomaticMaskGenerator:
# Iterate over image crops
data = MaskData()
for crop_box, layer_idx in zip(crop_boxes, layer_idxs):
crop_data = self._process_crop(image, crop_box, layer_idx, orig_size)
crop_data = self._process_crop(image, crop_box, layer_idx, orig_size, multimask_output)
data.cat(crop_data)
# Remove duplicate masks between crops
@ -228,6 +228,7 @@ class SamAutomaticMaskGenerator:
crop_box: List[int],
crop_layer_idx: int,
orig_size: Tuple[int, ...],
multimask_output: bool = True,
) -> MaskData:
# Crop the image and calculate embeddings
x0, y0, x1, y1 = crop_box
@ -242,7 +243,7 @@ class SamAutomaticMaskGenerator:
# Generate masks for this crop in batches
data = MaskData()
for (points,) in batch_iterator(self.points_per_batch, points_for_image):
batch_data = self._process_batch(points, cropped_im_size, crop_box, orig_size)
batch_data = self._process_batch(points, cropped_im_size, crop_box, orig_size, multimask_output)
data.cat(batch_data)
del batch_data
self.predictor.reset_image()
@ -269,6 +270,7 @@ class SamAutomaticMaskGenerator:
im_size: Tuple[int, ...],
crop_box: List[int],
orig_size: Tuple[int, ...],
multimask_output: bool = True,
) -> MaskData:
orig_h, orig_w = orig_size
@ -279,7 +281,7 @@ class SamAutomaticMaskGenerator:
masks, iou_preds, _ = self.predictor.predict_torch(
in_points[:, None, :],
in_labels[:, None],
multimask_output=True,
multimask_output=multimask_output,
return_logits=True,
)

View File

@ -8,7 +8,7 @@ import torch
from functools import partial
from .modeling import ImageEncoderViT, MaskDecoder, PromptEncoder, Sam, TwoWayTransformer
from .modeling import ImageEncoderViT, MaskDecoderHQ, PromptEncoder, Sam, TwoWayTransformer
def build_sam_vit_h(checkpoint=None):
@ -84,7 +84,7 @@ def _build_sam(
input_image_size=(image_size, image_size),
mask_in_chans=16,
),
mask_decoder=MaskDecoder(
mask_decoder=MaskDecoderHQ(
num_multimask_outputs=3,
transformer=TwoWayTransformer(
depth=2,
@ -95,13 +95,19 @@ def _build_sam(
transformer_dim=prompt_embed_dim,
iou_head_depth=3,
iou_head_hidden_dim=256,
vit_dim=encoder_embed_dim,
),
pixel_mean=[123.675, 116.28, 103.53],
pixel_std=[58.395, 57.12, 57.375],
)
sam.eval()
# sam.eval()
if checkpoint is not None:
with open(checkpoint, "rb") as f:
state_dict = torch.load(f)
sam.load_state_dict(state_dict)
info = sam.load_state_dict(state_dict, strict=False)
print(info)
for n, p in sam.named_parameters():
if 'hf_token' not in n and 'hf_mlp' not in n and 'compress_vit_feat' not in n and 'embedding_encoder' not in n and 'embedding_maskfeature' not in n:
p.requires_grad = False
return sam

View File

@ -0,0 +1,107 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import torch
from functools import partial
from .modeling import ImageEncoderViT, MaskDecoder, PromptEncoder, Sam, TwoWayTransformer
def build_sam_vit_h(checkpoint=None):
return _build_sam(
encoder_embed_dim=1280,
encoder_depth=32,
encoder_num_heads=16,
encoder_global_attn_indexes=[7, 15, 23, 31],
checkpoint=checkpoint,
)
build_sam = build_sam_vit_h
def build_sam_vit_l(checkpoint=None):
return _build_sam(
encoder_embed_dim=1024,
encoder_depth=24,
encoder_num_heads=16,
encoder_global_attn_indexes=[5, 11, 17, 23],
checkpoint=checkpoint,
)
def build_sam_vit_b(checkpoint=None):
return _build_sam(
encoder_embed_dim=768,
encoder_depth=12,
encoder_num_heads=12,
encoder_global_attn_indexes=[2, 5, 8, 11],
checkpoint=checkpoint,
)
sam_model_registry_baseline = {
"default": build_sam_vit_h,
"vit_h": build_sam_vit_h,
"vit_l": build_sam_vit_l,
"vit_b": build_sam_vit_b,
}
def _build_sam(
encoder_embed_dim,
encoder_depth,
encoder_num_heads,
encoder_global_attn_indexes,
checkpoint=None,
):
prompt_embed_dim = 256
image_size = 1024
vit_patch_size = 16
image_embedding_size = image_size // vit_patch_size
sam = Sam(
image_encoder=ImageEncoderViT(
depth=encoder_depth,
embed_dim=encoder_embed_dim,
img_size=image_size,
mlp_ratio=4,
norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
num_heads=encoder_num_heads,
patch_size=vit_patch_size,
qkv_bias=True,
use_rel_pos=True,
global_attn_indexes=encoder_global_attn_indexes,
window_size=14,
out_chans=prompt_embed_dim,
),
prompt_encoder=PromptEncoder(
embed_dim=prompt_embed_dim,
image_embedding_size=(image_embedding_size, image_embedding_size),
input_image_size=(image_size, image_size),
mask_in_chans=16,
),
mask_decoder=MaskDecoder(
num_multimask_outputs=3,
transformer=TwoWayTransformer(
depth=2,
embedding_dim=prompt_embed_dim,
mlp_dim=2048,
num_heads=8,
),
transformer_dim=prompt_embed_dim,
iou_head_depth=3,
iou_head_hidden_dim=256,
),
pixel_mean=[123.675, 116.28, 103.53],
pixel_std=[58.395, 57.12, 57.375],
)
sam.eval()
if checkpoint is not None:
with open(checkpoint, "rb") as f:
state_dict = torch.load(f)
sam.load_state_dict(state_dict)
return sam

View File

@ -6,6 +6,7 @@
from .sam import Sam
from .image_encoder import ImageEncoderViT
from .mask_decoder_hq import MaskDecoderHQ
from .mask_decoder import MaskDecoder
from .prompt_encoder import PromptEncoder
from .transformer import TwoWayTransformer

View File

@ -108,12 +108,15 @@ class ImageEncoderViT(nn.Module):
if self.pos_embed is not None:
x = x + self.pos_embed
interm_embeddings=[]
for blk in self.blocks:
x = blk(x)
if blk.window_size == 0:
interm_embeddings.append(x)
x = self.neck(x.permute(0, 3, 1, 2))
return x
return x, interm_embeddings
class Block(nn.Module):

View File

@ -75,6 +75,8 @@ class MaskDecoder(nn.Module):
sparse_prompt_embeddings: torch.Tensor,
dense_prompt_embeddings: torch.Tensor,
multimask_output: bool,
hq_token_only: bool,
interm_embeddings: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Predict masks given image and prompt embeddings.

View File

@ -0,0 +1,232 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# Modified by HQ-SAM team
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import torch
from torch import nn
from torch.nn import functional as F
from typing import List, Tuple, Type
from .common import LayerNorm2d
class MaskDecoderHQ(nn.Module):
def __init__(
self,
*,
transformer_dim: int,
transformer: nn.Module,
num_multimask_outputs: int = 3,
activation: Type[nn.Module] = nn.GELU,
iou_head_depth: int = 3,
iou_head_hidden_dim: int = 256,
vit_dim: int = 1024,
) -> None:
"""
Predicts masks given an image and prompt embeddings, using a
transformer architecture.
Arguments:
transformer_dim (int): the channel dimension of the transformer
transformer (nn.Module): the transformer used to predict masks
num_multimask_outputs (int): the number of masks to predict
when disambiguating masks
activation (nn.Module): the type of activation to use when
upscaling masks
iou_head_depth (int): the depth of the MLP used to predict
mask quality
iou_head_hidden_dim (int): the hidden dimension of the MLP
used to predict mask quality
"""
super().__init__()
self.transformer_dim = transformer_dim
self.transformer = transformer
self.num_multimask_outputs = num_multimask_outputs
self.iou_token = nn.Embedding(1, transformer_dim)
self.num_mask_tokens = num_multimask_outputs + 1
self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim)
self.output_upscaling = nn.Sequential(
nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2),
LayerNorm2d(transformer_dim // 4),
activation(),
nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2),
activation(),
)
self.output_hypernetworks_mlps = nn.ModuleList(
[
MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3)
for i in range(self.num_mask_tokens)
]
)
self.iou_prediction_head = MLP(
transformer_dim, iou_head_hidden_dim, self.num_mask_tokens, iou_head_depth
)
# HQ-SAM parameters
self.hf_token = nn.Embedding(1, transformer_dim) # HQ-Ouptput-Token
self.hf_mlp = MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3) # corresponding new MLP layer for HQ-Ouptput-Token
self.num_mask_tokens = self.num_mask_tokens + 1
# three conv fusion layers for obtaining HQ-Feature
self.compress_vit_feat = nn.Sequential(
nn.ConvTranspose2d(vit_dim, transformer_dim, kernel_size=2, stride=2),
LayerNorm2d(transformer_dim),
nn.GELU(),
nn.ConvTranspose2d(transformer_dim, transformer_dim // 8, kernel_size=2, stride=2))
self.embedding_encoder = nn.Sequential(
nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2),
LayerNorm2d(transformer_dim // 4),
nn.GELU(),
nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2),
)
self.embedding_maskfeature = nn.Sequential(
nn.Conv2d(transformer_dim // 8, transformer_dim // 4, 3, 1, 1),
LayerNorm2d(transformer_dim // 4),
nn.GELU(),
nn.Conv2d(transformer_dim // 4, transformer_dim // 8, 3, 1, 1))
def forward(
self,
image_embeddings: torch.Tensor,
image_pe: torch.Tensor,
sparse_prompt_embeddings: torch.Tensor,
dense_prompt_embeddings: torch.Tensor,
multimask_output: bool,
hq_token_only: bool,
interm_embeddings: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Predict masks given image and prompt embeddings.
Arguments:
image_embeddings (torch.Tensor): the embeddings from the ViT image encoder
image_pe (torch.Tensor): positional encoding with the shape of image_embeddings
sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes
dense_prompt_embeddings (torch.Tensor): the embeddings of the mask inputs
multimask_output (bool): Whether to return multiple masks or a single
mask.
Returns:
torch.Tensor: batched predicted masks
torch.Tensor: batched predictions of mask quality
"""
vit_features = interm_embeddings[0].permute(0, 3, 1, 2) # early-layer ViT feature, after 1st global attention block in ViT
hq_features = self.embedding_encoder(image_embeddings) + self.compress_vit_feat(vit_features)
masks, iou_pred = self.predict_masks(
image_embeddings=image_embeddings,
image_pe=image_pe,
sparse_prompt_embeddings=sparse_prompt_embeddings,
dense_prompt_embeddings=dense_prompt_embeddings,
hq_features=hq_features,
)
# Select the correct mask or masks for output
if multimask_output:
# mask with highest score
mask_slice = slice(1,self.num_mask_tokens-1)
iou_pred = iou_pred[:, mask_slice]
iou_pred, max_iou_idx = torch.max(iou_pred,dim=1)
iou_pred = iou_pred.unsqueeze(1)
masks_multi = masks[:, mask_slice, :, :]
masks_sam = masks_multi[torch.arange(masks_multi.size(0)),max_iou_idx].unsqueeze(1)
else:
# singale mask output, default
mask_slice = slice(0, 1)
iou_pred = iou_pred[:,mask_slice]
masks_sam = masks[:,mask_slice]
masks_hq = masks[:,slice(self.num_mask_tokens-1, self.num_mask_tokens)]
if hq_token_only:
masks = masks_hq
else:
masks = masks_sam + masks_hq
# Prepare output
return masks, iou_pred
def predict_masks(
self,
image_embeddings: torch.Tensor,
image_pe: torch.Tensor,
sparse_prompt_embeddings: torch.Tensor,
dense_prompt_embeddings: torch.Tensor,
hq_features: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Predicts masks. See 'forward' for more details."""
# Concatenate output tokens
output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight, self.hf_token.weight], dim=0)
output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1)
tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1)
# Expand per-image data in batch direction to be per-mask
src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)
src = src + dense_prompt_embeddings
pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0)
b, c, h, w = src.shape
# Run the transformer
hs, src = self.transformer(src, pos_src, tokens)
iou_token_out = hs[:, 0, :]
mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :]
# Upscale mask embeddings and predict masks using the mask tokens
src = src.transpose(1, 2).view(b, c, h, w)
upscaled_embedding_sam = self.output_upscaling(src)
upscaled_embedding_hq = self.embedding_maskfeature(upscaled_embedding_sam) + hq_features.repeat(b,1,1,1)
hyper_in_list: List[torch.Tensor] = []
for i in range(self.num_mask_tokens):
if i < self.num_mask_tokens - 1:
hyper_in_list.append(self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :]))
else:
hyper_in_list.append(self.hf_mlp(mask_tokens_out[:, i, :]))
hyper_in = torch.stack(hyper_in_list, dim=1)
b, c, h, w = upscaled_embedding_sam.shape
masks_sam = (hyper_in[:,:self.num_mask_tokens-1] @ upscaled_embedding_sam.view(b, c, h * w)).view(b, -1, h, w)
masks_sam_hq = (hyper_in[:,self.num_mask_tokens-1:] @ upscaled_embedding_hq.view(b, c, h * w)).view(b, -1, h, w)
masks = torch.cat([masks_sam,masks_sam_hq],dim=1)
# Generate mask quality predictions
iou_pred = self.iou_prediction_head(iou_token_out)
return masks, iou_pred
# Lightly adapted from
# https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py # noqa
class MLP(nn.Module):
def __init__(
self,
input_dim: int,
hidden_dim: int,
output_dim: int,
num_layers: int,
sigmoid_output: bool = False,
) -> None:
super().__init__()
self.num_layers = num_layers
h = [hidden_dim] * (num_layers - 1)
self.layers = nn.ModuleList(
nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])
)
self.sigmoid_output = sigmoid_output
def forward(self, x):
for i, layer in enumerate(self.layers):
x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
if self.sigmoid_output:
x = F.sigmoid(x)
return x

View File

@ -50,11 +50,11 @@ class Sam(nn.Module):
def device(self) -> Any:
return self.pixel_mean.device
@torch.no_grad()
def forward(
self,
batched_input: List[Dict[str, Any]],
multimask_output: bool,
hq_token_only: bool =False,
) -> List[Dict[str, torch.Tensor]]:
"""
Predicts masks end-to-end from provided images and prompts.
@ -95,10 +95,11 @@ class Sam(nn.Module):
to subsequent iterations of prediction.
"""
input_images = torch.stack([self.preprocess(x["image"]) for x in batched_input], dim=0)
image_embeddings = self.image_encoder(input_images)
image_embeddings, interm_embeddings = self.image_encoder(input_images)
interm_embeddings = interm_embeddings[0] # early layer
outputs = []
for image_record, curr_embedding in zip(batched_input, image_embeddings):
for image_record, curr_embedding, curr_interm in zip(batched_input, image_embeddings, interm_embeddings):
if "point_coords" in image_record:
points = (image_record["point_coords"], image_record["point_labels"])
else:
@ -114,6 +115,8 @@ class Sam(nn.Module):
sparse_prompt_embeddings=sparse_embeddings,
dense_prompt_embeddings=dense_embeddings,
multimask_output=multimask_output,
hq_token_only=hq_token_only,
interm_embeddings=curr_interm.unsqueeze(0).unsqueeze(0),
)
masks = self.postprocess_masks(
low_res_masks,

View File

@ -7,7 +7,7 @@
import numpy as np
import torch
from segment_anything.modeling import Sam
from .modeling import Sam
from typing import Optional, Tuple
@ -49,10 +49,12 @@ class SamPredictor:
"RGB",
"BGR",
], f"image_format must be in ['RGB', 'BGR'], is {image_format}."
# import pdb;pdb.set_trace()
if image_format != self.model.image_format:
image = image[..., ::-1]
# Transform the image to the form expected by the model
# import pdb;pdb.set_trace()
input_image = self.transform.apply_image(image)
input_image_torch = torch.as_tensor(input_image, device=self.device)
input_image_torch = input_image_torch.permute(2, 0, 1).contiguous()[None, :, :, :]
@ -86,7 +88,7 @@ class SamPredictor:
self.original_size = original_image_size
self.input_size = tuple(transformed_image.shape[-2:])
input_image = self.model.preprocess(transformed_image)
self.features = self.model.image_encoder(input_image)
self.features, self.interm_features = self.model.image_encoder(input_image)
self.is_image_set = True
def predict(
@ -97,6 +99,7 @@ class SamPredictor:
mask_input: Optional[np.ndarray] = None,
multimask_output: bool = True,
return_logits: bool = False,
hq_token_only: bool =False,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
"""
Predict masks for the given input prompts, using the currently set image.
@ -158,6 +161,7 @@ class SamPredictor:
mask_input_torch,
multimask_output,
return_logits=return_logits,
hq_token_only=hq_token_only,
)
masks_np = masks[0].detach().cpu().numpy()
@ -174,6 +178,7 @@ class SamPredictor:
mask_input: Optional[torch.Tensor] = None,
multimask_output: bool = True,
return_logits: bool = False,
hq_token_only: bool =False,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Predict masks for the given input prompts, using the currently set image.
@ -232,6 +237,8 @@ class SamPredictor:
sparse_prompt_embeddings=sparse_embeddings,
dense_prompt_embeddings=dense_embeddings,
multimask_output=multimask_output,
hq_token_only=hq_token_only,
interm_embeddings=self.interm_features,
)
# Upscale the masks to the original image resolution

View File

@ -25,7 +25,8 @@ class SamOnnxModel(nn.Module):
def __init__(
self,
model: Sam,
return_single_mask: bool,
hq_token_only: bool = False,
multimask_output: bool = False,
use_stability_score: bool = False,
return_extra_metrics: bool = False,
) -> None:
@ -33,7 +34,8 @@ class SamOnnxModel(nn.Module):
self.mask_decoder = model.mask_decoder
self.model = model
self.img_size = model.image_encoder.img_size
self.return_single_mask = return_single_mask
self.hq_token_only = hq_token_only
self.multimask_output = multimask_output
self.use_stability_score = use_stability_score
self.stability_score_offset = 1.0
self.return_extra_metrics = return_extra_metrics
@ -89,25 +91,12 @@ class SamOnnxModel(nn.Module):
masks = F.interpolate(masks, size=(h, w), mode="bilinear", align_corners=False)
return masks
def select_masks(
self, masks: torch.Tensor, iou_preds: torch.Tensor, num_points: int
) -> Tuple[torch.Tensor, torch.Tensor]:
# Determine if we should return the multiclick mask or not from the number of points.
# The reweighting is used to avoid control flow.
score_reweight = torch.tensor(
[[1000] + [0] * (self.model.mask_decoder.num_mask_tokens - 1)]
).to(iou_preds.device)
score = iou_preds + (num_points - 2.5) * score_reweight
best_idx = torch.argmax(score, dim=1)
masks = masks[torch.arange(masks.shape[0]), best_idx, :, :].unsqueeze(1)
iou_preds = iou_preds[torch.arange(masks.shape[0]), best_idx].unsqueeze(1)
return masks, iou_preds
@torch.no_grad()
def forward(
self,
image_embeddings: torch.Tensor,
interm_embeddings: torch.Tensor,
point_coords: torch.Tensor,
point_labels: torch.Tensor,
mask_input: torch.Tensor,
@ -117,11 +106,15 @@ class SamOnnxModel(nn.Module):
sparse_embedding = self._embed_points(point_coords, point_labels)
dense_embedding = self._embed_masks(mask_input, has_mask_input)
vit_features = interm_embeddings[0].permute(0, 3, 1, 2) # early-layer ViT feature, after 1st global attention block in ViT
hq_features = self.model.mask_decoder.embedding_encoder(image_embeddings) + self.model.mask_decoder.compress_vit_feat(vit_features)
masks, scores = self.model.mask_decoder.predict_masks(
image_embeddings=image_embeddings,
image_pe=self.model.prompt_encoder.get_dense_pe(),
sparse_prompt_embeddings=sparse_embedding,
dense_prompt_embeddings=dense_embedding,
hq_features=hq_features,
)
if self.use_stability_score:
@ -129,8 +122,26 @@ class SamOnnxModel(nn.Module):
masks, self.model.mask_threshold, self.stability_score_offset
)
if self.return_single_mask:
masks, scores = self.select_masks(masks, scores, point_coords.shape[1])
if self.multimask_output:
# mask with highest score
mask_slice = slice(1,self.model.mask_decoder.num_mask_tokens-1)
scores = scores[:, mask_slice]
scores, max_iou_idx = torch.max(scores,dim=1)
scores = scores.unsqueeze(1)
masks_multi = masks[:, mask_slice, :, :]
masks_sam = masks_multi[torch.arange(masks_multi.size(0)),max_iou_idx].unsqueeze(1)
else:
# singale mask output, default
mask_slice = slice(0, 1)
scores = scores[:,mask_slice]
masks_sam = masks[:,mask_slice]
masks_hq = masks[:,slice(self.model.mask_decoder.num_mask_tokens-1, self.model.mask_decoder.num_mask_tokens)]
if self.hq_token_only:
masks = masks_hq
else:
masks = masks_sam + masks_hq
upscaled_masks = self.mask_postprocessing(masks, orig_im_size)

View File

@ -88,6 +88,8 @@ class Ui_MainWindow(object):
self.menuMode.setObjectName("menuMode")
self.menuContour_mode = QtWidgets.QMenu(self.menuMode)
self.menuContour_mode.setObjectName("menuContour_mode")
self.menuSAM_model = QtWidgets.QMenu(self.menubar)
self.menuSAM_model.setObjectName("menuSAM_model")
MainWindow.setMenuBar(self.menubar)
self.statusbar = QtWidgets.QStatusBar(MainWindow)
self.statusbar.setLayoutDirection(QtCore.Qt.LeftToRight)
@ -352,6 +354,7 @@ class Ui_MainWindow(object):
self.menubar.addAction(self.menuFile.menuAction())
self.menubar.addAction(self.menuEdit.menuAction())
self.menubar.addAction(self.menuView.menuAction())
self.menubar.addAction(self.menuSAM_model.menuAction())
self.menubar.addAction(self.menuMode.menuAction())
self.menubar.addAction(self.menuTools.menuAction())
self.menubar.addAction(self.menuAbout.menuAction())
@ -391,6 +394,7 @@ class Ui_MainWindow(object):
self.menuEdit.setTitle(_translate("MainWindow", "Edit"))
self.menuMode.setTitle(_translate("MainWindow", "Mode"))
self.menuContour_mode.setTitle(_translate("MainWindow", "Contour mode"))
self.menuSAM_model.setTitle(_translate("MainWindow", "SAM"))
self.toolBar.setWindowTitle(_translate("MainWindow", "toolBar"))
self.info_dock.setWindowTitle(_translate("MainWindow", "Info"))
self.annos_dock.setWindowTitle(_translate("MainWindow", "Annos"))

View File

@ -202,9 +202,15 @@
</widget>
<addaction name="menuContour_mode"/>
</widget>
<widget class="QMenu" name="menuSAM_model">
<property name="title">
<string>SAM</string>
</property>
</widget>
<addaction name="menuFile"/>
<addaction name="menuEdit"/>
<addaction name="menuView"/>
<addaction name="menuSAM_model"/>
<addaction name="menuMode"/>
<addaction name="menuTools"/>
<addaction name="menuAbout"/>

View File

@ -34,8 +34,6 @@ class MainWindow(QtWidgets.QMainWindow, Ui_MainWindow):
def __init__(self):
super(MainWindow, self).__init__()
self.setupUi(self)
self.init_ui()
self.init_segment_anything()
self.image_root: str = None
self.label_root:str = None
@ -58,41 +56,51 @@ class MainWindow(QtWidgets.QMainWindow, Ui_MainWindow):
self.map_mode = MAPMode.LABEL
# 标注目标
self.current_label:Annotation = None
self.use_segment_anything = False
self.init_ui()
self.reload_cfg()
self.init_connect()
self.reset_action()
def init_segment_anything(self):
if os.path.exists('./segment_any/sam_vit_h_4b8939.pth'):
self.statusbar.showMessage('Find the checkpoint named {}.'.format('sam_vit_h_4b8939.pth'))
self.segany = SegAny('./segment_any/sam_vit_h_4b8939.pth')
self.use_segment_anything = True
elif os.path.exists('./segment_any/sam_vit_l_0b3195.pth'):
self.statusbar.showMessage('Find the checkpoint named {}.'.format('sam_vit_l_0b3195.pth'))
self.segany = SegAny('./segment_any/sam_vit_l_0b3195.pth')
self.use_segment_anything = True
elif os.path.exists('./segment_any/sam_vit_b_01ec64.pth'):
self.statusbar.showMessage('Find the checkpoint named {}.'.format('sam_vit_b_01ec64.pth'))
self.segany = SegAny('./segment_any/sam_vit_b_01ec64.pth')
self.use_segment_anything = True
else:
QtWidgets.QMessageBox.warning(self, 'Warning', 'The checkpoint of [Segment anything] not existed. If you want use quick annotate, please download from {}'.format('https://github.com/facebookresearch/segment-anything#model-checkpoints'))
def init_segment_anything(self, model_name, reload=False):
if model_name == '':
self.use_segment_anything = False
for name, action in self.pths_actions.items():
action.setChecked(model_name == name)
return
model_path = os.path.join('segment_any', model_name)
if not os.path.exists(model_path):
QtWidgets.QMessageBox.warning(self, 'Warning',
'The checkpoint of [Segment anything] not existed. If you want use quick annotate, please download from {}'.format(
'https://github.com/facebookresearch/segment-anything#model-checkpoints'))
for name, action in self.pths_actions.items():
action.setChecked(model_name == name)
self.use_segment_anything = False
return
if self.use_segment_anything:
if self.segany.device != 'cpu':
self.gpu_resource_thread = GPUResource_Thread()
self.gpu_resource_thread.message.connect(self.labelGPUResource.setText)
self.gpu_resource_thread.start()
self.segany = SegAny(model_path)
self.use_segment_anything = True
self.statusbar.showMessage('Use the checkpoint named {}.'.format(model_name), 3000)
for name, action in self.pths_actions.items():
action.setChecked(model_name==name)
if not reload:
if self.use_segment_anything:
if self.segany.device != 'cpu':
self.gpu_resource_thread = GPUResource_Thread()
self.gpu_resource_thread.message.connect(self.labelGPUResource.setText)
self.gpu_resource_thread.start()
else:
self.labelGPUResource.setText('cpu')
else:
self.labelGPUResource.setText('cpu')
else:
self.labelGPUResource.setText('segment anything unused.')
self.labelGPUResource.setText('segment anything unused.')
if reload and self.current_index is not None:
self.show_image(self.current_index)
def init_ui(self):
#
#q
self.setting_dialog = SettingDialog(parent=self, mainwindow=self)
self.categories_dock_widget = CategoriesDockWidget(mainwindow=self)
@ -144,6 +152,17 @@ class MainWindow(QtWidgets.QMainWindow, Ui_MainWindow):
self.statusbar.addPermanentWidget(self.labelData)
#
model_names = sorted([pth for pth in os.listdir('segment_any') if pth.endswith('.pth')])
self.pths_actions = {}
for model_name in model_names:
action = QtWidgets.QAction(self)
action.setObjectName("actionZoom_in")
action.triggered.connect(functools.partial(self.init_segment_anything, model_name))
action.setText("{}".format(model_name))
action.setCheckable(True)
self.pths_actions[model_name] = action
self.menuSAM_model.addAction(action)
self.toolBar.addSeparator()
self.mask_aplha = QtWidgets.QSlider(QtCore.Qt.Orientation.Horizontal, self)
@ -213,6 +232,9 @@ class MainWindow(QtWidgets.QMainWindow, Ui_MainWindow):
self.cfg['mask_alpha'] = mask_alpha
self.mask_aplha.setValue(mask_alpha*10)
model_name = self.cfg.get('model_name', '')
self.init_segment_anything(model_name)
self.categories_dock_widget.update_widget()
def set_saved_state(self, is_saved:bool):