diff --git a/example/images/000000000144.json b/example/images/000000000144.json index 043e62b..8acc9f0 100644 --- a/example/images/000000000144.json +++ b/example/images/000000000144.json @@ -1,813 +1,12 @@ { "info": { "description": "ISAT", - "folder": "C:/Users/lg/PycharmProjects/ISAT_with_segment_anything/example/images", + "folder": "/home/super/PycharmProjects/ISAT_with_segment_anything/example/images", "name": "000000000144.jpg", "width": 640, "height": 480, "depth": 3, "note": "" }, - "objects": [ - { - "category": "fence", - "group": "1", - "segmentation": [ - [ - 20, - 239 - ], - [ - 17, - 240 - ], - [ - 15, - 241 - ], - [ - 13, - 242 - ], - [ - 11, - 244 - ], - [ - 10, - 244 - ], - [ - 9, - 245 - ], - [ - 6, - 249 - ], - [ - 5, - 251 - ], - [ - 4, - 255 - ], - [ - 0, - 261 - ], - [ - 0, - 280 - ], - [ - 1, - 277 - ], - [ - 6, - 268 - ], - [ - 9, - 259 - ], - [ - 13, - 254 - ], - [ - 13, - 252 - ], - [ - 15, - 250 - ], - [ - 18, - 248 - ], - [ - 20, - 247 - ], - [ - 24, - 247 - ], - [ - 26, - 248 - ], - [ - 30, - 252 - ], - [ - 33, - 256 - ], - [ - 35, - 260 - ], - [ - 38, - 266 - ], - [ - 41, - 273 - ], - [ - 43, - 280 - ], - [ - 50, - 294 - ], - [ - 53, - 297 - ], - [ - 54, - 297 - ], - [ - 59, - 301 - ], - [ - 61, - 303 - ], - [ - 62, - 316 - ], - [ - 63, - 341 - ], - [ - 63, - 382 - ], - [ - 64, - 405 - ], - [ - 66, - 414 - ], - [ - 71, - 413 - ], - [ - 72, - 402 - ], - [ - 72, - 399 - ], - [ - 71, - 359 - ], - [ - 69, - 324 - ], - [ - 70, - 300 - ], - [ - 73, - 299 - ], - [ - 78, - 294 - ], - [ - 80, - 291 - ], - [ - 81, - 289 - ], - [ - 83, - 283 - ], - [ - 85, - 278 - ], - [ - 86, - 272 - ], - [ - 91, - 258 - ], - [ - 94, - 252 - ], - [ - 98, - 248 - ], - [ - 105, - 248 - ], - [ - 107, - 249 - ], - [ - 111, - 253 - ], - [ - 116, - 261 - ], - [ - 118, - 265 - ], - [ - 121, - 271 - ], - [ - 128, - 288 - ], - [ - 130, - 291 - ], - [ - 135, - 296 - ], - [ - 138, - 298 - ], - [ - 140, - 299 - ], - [ - 143, - 300 - ], - [ - 148, - 300 - ], - [ - 151, - 299 - ], - [ - 153, - 298 - ], - [ - 156, - 296 - ], - [ - 159, - 293 - ], - [ - 161, - 290 - ], - [ - 164, - 284 - ], - [ - 179, - 255 - ], - [ - 180, - 253 - ], - [ - 180, - 251 - ], - [ - 181, - 249 - ], - [ - 183, - 247 - ], - [ - 186, - 248 - ], - [ - 187, - 259 - ], - [ - 188, - 282 - ], - [ - 188, - 303 - ], - [ - 190, - 320 - ], - [ - 190, - 343 - ], - [ - 191, - 355 - ], - [ - 192, - 357 - ], - [ - 196, - 358 - ], - [ - 198, - 357 - ], - [ - 200, - 353 - ], - [ - 200, - 341 - ], - [ - 198, - 321 - ], - [ - 197, - 300 - ], - [ - 196, - 279 - ], - [ - 195, - 252 - ], - [ - 195, - 250 - ], - [ - 196, - 248 - ], - [ - 200, - 248 - ], - [ - 200, - 249 - ], - [ - 202, - 251 - ], - [ - 207, - 260 - ], - [ - 208, - 262 - ], - [ - 208, - 264 - ], - [ - 210, - 266 - ], - [ - 212, - 273 - ], - [ - 215, - 279 - ], - [ - 217, - 281 - ], - [ - 218, - 283 - ], - [ - 219, - 287 - ], - [ - 221, - 290 - ], - [ - 225, - 294 - ], - [ - 227, - 295 - ], - [ - 230, - 296 - ], - [ - 235, - 296 - ], - [ - 239, - 295 - ], - [ - 241, - 294 - ], - [ - 245, - 290 - ], - [ - 246, - 288 - ], - [ - 247, - 286 - ], - [ - 248, - 283 - ], - [ - 248, - 279 - ], - [ - 245, - 281 - ], - [ - 244, - 283 - ], - [ - 244, - 285 - ], - [ - 243, - 287 - ], - [ - 238, - 292 - ], - [ - 230, - 292 - ], - [ - 228, - 291 - ], - [ - 224, - 287 - ], - [ - 222, - 284 - ], - [ - 219, - 278 - ], - [ - 214, - 268 - ], - [ - 214, - 266 - ], - [ - 207, - 252 - ], - [ - 205, - 249 - ], - [ - 201, - 245 - ], - [ - 198, - 243 - ], - [ - 194, - 242 - ], - [ - 188, - 242 - ], - [ - 
183, - 243 - ], - [ - 180, - 244 - ], - [ - 176, - 246 - ], - [ - 174, - 249 - ], - [ - 172, - 252 - ], - [ - 170, - 256 - ], - [ - 169, - 260 - ], - [ - 168, - 262 - ], - [ - 166, - 264 - ], - [ - 165, - 266 - ], - [ - 164, - 268 - ], - [ - 161, - 276 - ], - [ - 156, - 286 - ], - [ - 153, - 290 - ], - [ - 152, - 291 - ], - [ - 151, - 291 - ], - [ - 149, - 293 - ], - [ - 147, - 293 - ], - [ - 144, - 293 - ], - [ - 142, - 293 - ], - [ - 140, - 292 - ], - [ - 138, - 290 - ], - [ - 136, - 287 - ], - [ - 133, - 281 - ], - [ - 125, - 263 - ], - [ - 123, - 259 - ], - [ - 120, - 253 - ], - [ - 117, - 248 - ], - [ - 112, - 243 - ], - [ - 110, - 242 - ], - [ - 107, - 241 - ], - [ - 98, - 241 - ], - [ - 96, - 242 - ], - [ - 94, - 243 - ], - [ - 89, - 248 - ], - [ - 85, - 256 - ], - [ - 83, - 261 - ], - [ - 81, - 269 - ], - [ - 78, - 274 - ], - [ - 76, - 282 - ], - [ - 74, - 287 - ], - [ - 73, - 289 - ], - [ - 67, - 295 - ], - [ - 63, - 295 - ], - [ - 61, - 294 - ], - [ - 55, - 288 - ], - [ - 53, - 284 - ], - [ - 50, - 278 - ], - [ - 49, - 273 - ], - [ - 41, - 254 - ], - [ - 40, - 252 - ], - [ - 38, - 249 - ], - [ - 35, - 246 - ], - [ - 35, - 245 - ], - [ - 33, - 243 - ], - [ - 30, - 241 - ], - [ - 28, - 240 - ], - [ - 24, - 239 - ] - ], - "area": 4485.0, - "layer": 1.0, - "bbox": [ - 0.0, - 239.0, - 248.0, - 414.0 - ], - "iscrowd": 0, - "note": "" - } - ] + "objects": [] } \ No newline at end of file diff --git a/example/images/isat.yaml b/example/images/isat.yaml index cf05e28..bec54bb 100644 --- a/example/images/isat.yaml +++ b/example/images/isat.yaml @@ -1,4 +1,4 @@ -contour_mode: external +contour_mode: all label: - color: '#000000' name: __background__ diff --git a/isat.yaml b/isat.yaml index cf05e28..bec54bb 100644 --- a/isat.yaml +++ b/isat.yaml @@ -1,4 +1,4 @@ -contour_mode: external +contour_mode: all label: - color: '#000000' name: __background__ diff --git a/main.py b/main.py index 6791910..cea8269 100644 --- a/main.py +++ b/main.py @@ -12,5 +12,5 @@ if __name__ == '__main__': app = QtWidgets.QApplication(['']) mainwindow = MainWindow() mainwindow.show() - sys.exit(app.exec_()) + sys.exit(app.exec()) diff --git a/segment_any/segment_any.py b/segment_any/segment_any.py index ff80abe..9f8a2f5 100644 --- a/segment_any/segment_any.py +++ b/segment_any/segment_any.py @@ -16,6 +16,7 @@ class SegAny: self.model_type = "vit_h" else: raise ValueError('The checkpoint named {} is not supported.'.format(checkpoint)) + torch.cuda.empty_cache() self.device = 'cuda' if torch.cuda.is_available() else 'cpu' sam = sam_model_registry[self.model_type](checkpoint=checkpoint) diff --git a/segment_anything/__init__.py b/segment_anything/__init__.py index 34383d8..d576507 100644 --- a/segment_anything/__init__.py +++ b/segment_anything/__init__.py @@ -11,5 +11,6 @@ from .build_sam import ( build_sam_vit_b, sam_model_registry, ) +from .build_sam_baseline import sam_model_registry_baseline from .predictor import SamPredictor from .automatic_mask_generator import SamAutomaticMaskGenerator diff --git a/segment_anything/automatic_mask_generator.py b/segment_anything/automatic_mask_generator.py index d5a8c96..427ebeb 100644 --- a/segment_anything/automatic_mask_generator.py +++ b/segment_anything/automatic_mask_generator.py @@ -134,7 +134,7 @@ class SamAutomaticMaskGenerator: self.output_mode = output_mode @torch.no_grad() - def generate(self, image: np.ndarray) -> List[Dict[str, Any]]: + def generate(self, image: np.ndarray, multimask_output: bool = True) -> List[Dict[str, Any]]: """ Generates masks for the given 
        image.

@@ -160,7 +160,7 @@
         """

         # Generate masks
-        mask_data = self._generate_masks(image)
+        mask_data = self._generate_masks(image, multimask_output)

         # Filter small disconnected regions and holes in masks
         if self.min_mask_region_area > 0:
@@ -194,7 +194,7 @@
         return curr_anns

-    def _generate_masks(self, image: np.ndarray) -> MaskData:
+    def _generate_masks(self, image: np.ndarray, multimask_output: bool = True) -> MaskData:
         orig_size = image.shape[:2]
         crop_boxes, layer_idxs = generate_crop_boxes(
             orig_size, self.crop_n_layers, self.crop_overlap_ratio
         )
@@ -203,7 +203,7 @@
         # Iterate over image crops
         data = MaskData()
         for crop_box, layer_idx in zip(crop_boxes, layer_idxs):
-            crop_data = self._process_crop(image, crop_box, layer_idx, orig_size)
+            crop_data = self._process_crop(image, crop_box, layer_idx, orig_size, multimask_output)
             data.cat(crop_data)

         # Remove duplicate masks between crops
@@ -228,6 +228,7 @@
         crop_box: List[int],
         crop_layer_idx: int,
         orig_size: Tuple[int, ...],
+        multimask_output: bool = True,
     ) -> MaskData:
         # Crop the image and calculate embeddings
         x0, y0, x1, y1 = crop_box
@@ -242,7 +243,7 @@
         # Generate masks for this crop in batches
         data = MaskData()
         for (points,) in batch_iterator(self.points_per_batch, points_for_image):
-            batch_data = self._process_batch(points, cropped_im_size, crop_box, orig_size)
+            batch_data = self._process_batch(points, cropped_im_size, crop_box, orig_size, multimask_output)
             data.cat(batch_data)
             del batch_data
         self.predictor.reset_image()
@@ -269,6 +270,7 @@
         im_size: Tuple[int, ...],
         crop_box: List[int],
         orig_size: Tuple[int, ...],
+        multimask_output: bool = True,
     ) -> MaskData:
         orig_h, orig_w = orig_size

@@ -279,7 +281,7 @@
         masks, iou_preds, _ = self.predictor.predict_torch(
             in_points[:, None, :],
             in_labels[:, None],
-            multimask_output=True,
+            multimask_output=multimask_output,
             return_logits=True,
         )
diff --git a/segment_anything/build_sam.py b/segment_anything/build_sam.py
index 37cd245..b280cf4 100644
--- a/segment_anything/build_sam.py
+++ b/segment_anything/build_sam.py
@@ -8,7 +8,7 @@ import torch

 from functools import partial

-from .modeling import ImageEncoderViT, MaskDecoder, PromptEncoder, Sam, TwoWayTransformer
+from .modeling import ImageEncoderViT, MaskDecoderHQ, PromptEncoder, Sam, TwoWayTransformer


 def build_sam_vit_h(checkpoint=None):
@@ -84,7 +84,7 @@ def _build_sam(
             input_image_size=(image_size, image_size),
             mask_in_chans=16,
         ),
-        mask_decoder=MaskDecoder(
+        mask_decoder=MaskDecoderHQ(
             num_multimask_outputs=3,
             transformer=TwoWayTransformer(
                 depth=2,
@@ -95,13 +95,19 @@ def _build_sam(
             transformer_dim=prompt_embed_dim,
             iou_head_depth=3,
             iou_head_hidden_dim=256,
+            vit_dim=encoder_embed_dim,
         ),
         pixel_mean=[123.675, 116.28, 103.53],
         pixel_std=[58.395, 57.12, 57.375],
     )
-    sam.eval()
+    # sam.eval()
     if checkpoint is not None:
         with open(checkpoint, "rb") as f:
             state_dict = torch.load(f)
-        sam.load_state_dict(state_dict)
+        info = sam.load_state_dict(state_dict, strict=False)
+        print(info)
+    for n, p in sam.named_parameters():
+        if 'hf_token' not in n and 'hf_mlp' not in n and 'compress_vit_feat' not in n and 'embedding_encoder' not in n and 'embedding_maskfeature' not in n:
+            p.requires_grad = False
+
     return sam
diff --git a/segment_anything/build_sam_baseline.py b/segment_anything/build_sam_baseline.py
new file mode 100644
index 0000000..8f14970
--- /dev/null
+++ b/segment_anything/build_sam_baseline.py
@@ -0,0 +1,107 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+
+from functools import partial
+
+from .modeling import ImageEncoderViT, MaskDecoder, PromptEncoder, Sam, TwoWayTransformer
+
+
+def build_sam_vit_h(checkpoint=None):
+    return _build_sam(
+        encoder_embed_dim=1280,
+        encoder_depth=32,
+        encoder_num_heads=16,
+        encoder_global_attn_indexes=[7, 15, 23, 31],
+        checkpoint=checkpoint,
+    )
+
+
+build_sam = build_sam_vit_h
+
+
+def build_sam_vit_l(checkpoint=None):
+    return _build_sam(
+        encoder_embed_dim=1024,
+        encoder_depth=24,
+        encoder_num_heads=16,
+        encoder_global_attn_indexes=[5, 11, 17, 23],
+        checkpoint=checkpoint,
+    )
+
+
+def build_sam_vit_b(checkpoint=None):
+    return _build_sam(
+        encoder_embed_dim=768,
+        encoder_depth=12,
+        encoder_num_heads=12,
+        encoder_global_attn_indexes=[2, 5, 8, 11],
+        checkpoint=checkpoint,
+    )
+
+
+sam_model_registry_baseline = {
+    "default": build_sam_vit_h,
+    "vit_h": build_sam_vit_h,
+    "vit_l": build_sam_vit_l,
+    "vit_b": build_sam_vit_b,
+}
+
+
+def _build_sam(
+    encoder_embed_dim,
+    encoder_depth,
+    encoder_num_heads,
+    encoder_global_attn_indexes,
+    checkpoint=None,
+):
+    prompt_embed_dim = 256
+    image_size = 1024
+    vit_patch_size = 16
+    image_embedding_size = image_size // vit_patch_size
+    sam = Sam(
+        image_encoder=ImageEncoderViT(
+            depth=encoder_depth,
+            embed_dim=encoder_embed_dim,
+            img_size=image_size,
+            mlp_ratio=4,
+            norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
+            num_heads=encoder_num_heads,
+            patch_size=vit_patch_size,
+            qkv_bias=True,
+            use_rel_pos=True,
+            global_attn_indexes=encoder_global_attn_indexes,
+            window_size=14,
+            out_chans=prompt_embed_dim,
+        ),
+        prompt_encoder=PromptEncoder(
+            embed_dim=prompt_embed_dim,
+            image_embedding_size=(image_embedding_size, image_embedding_size),
+            input_image_size=(image_size, image_size),
+            mask_in_chans=16,
+        ),
+        mask_decoder=MaskDecoder(
+            num_multimask_outputs=3,
+            transformer=TwoWayTransformer(
+                depth=2,
+                embedding_dim=prompt_embed_dim,
+                mlp_dim=2048,
+                num_heads=8,
+            ),
+            transformer_dim=prompt_embed_dim,
+            iou_head_depth=3,
+            iou_head_hidden_dim=256,
+        ),
+        pixel_mean=[123.675, 116.28, 103.53],
+        pixel_std=[58.395, 57.12, 57.375],
+    )
+    sam.eval()
+    if checkpoint is not None:
+        with open(checkpoint, "rb") as f:
+            state_dict = torch.load(f)
+        sam.load_state_dict(state_dict)
+    return sam
\ No newline at end of file
diff --git a/segment_anything/modeling/__init__.py b/segment_anything/modeling/__init__.py
index 38e9062..71172d2 100644
--- a/segment_anything/modeling/__init__.py
+++ b/segment_anything/modeling/__init__.py
@@ -6,6 +6,7 @@

 from .sam import Sam
 from .image_encoder import ImageEncoderViT
+from .mask_decoder_hq import MaskDecoderHQ
 from .mask_decoder import MaskDecoder
 from .prompt_encoder import PromptEncoder
 from .transformer import TwoWayTransformer
diff --git a/segment_anything/modeling/image_encoder.py b/segment_anything/modeling/image_encoder.py
index 66351d9..7048651 100644
--- a/segment_anything/modeling/image_encoder.py
+++ b/segment_anything/modeling/image_encoder.py
@@ -108,12 +108,15 @@ class ImageEncoderViT(nn.Module):
         if self.pos_embed is not None:
             x = x + self.pos_embed

+        interm_embeddings=[]
         for blk in self.blocks:
             x = blk(x)
+            if blk.window_size == 0:
+                interm_embeddings.append(x)

         x = self.neck(x.permute(0, 3, 1, 2))

-        return x
+        return x, interm_embeddings


 class Block(nn.Module):
diff --git a/segment_anything/modeling/mask_decoder.py b/segment_anything/modeling/mask_decoder.py
index 5d2fdb0..242ecb7 100644
--- a/segment_anything/modeling/mask_decoder.py
+++ b/segment_anything/modeling/mask_decoder.py
@@ -75,6 +75,8 @@ class MaskDecoder(nn.Module):
         sparse_prompt_embeddings: torch.Tensor,
         dense_prompt_embeddings: torch.Tensor,
         multimask_output: bool,
+        hq_token_only: bool,
+        interm_embeddings: torch.Tensor,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """
         Predict masks given image and prompt embeddings.
diff --git a/segment_anything/modeling/mask_decoder_hq.py b/segment_anything/modeling/mask_decoder_hq.py
new file mode 100644
index 0000000..1e365e3
--- /dev/null
+++ b/segment_anything/modeling/mask_decoder_hq.py
@@ -0,0 +1,232 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# Modified by HQ-SAM team
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+from typing import List, Tuple, Type
+
+from .common import LayerNorm2d
+
+
+class MaskDecoderHQ(nn.Module):
+    def __init__(
+        self,
+        *,
+        transformer_dim: int,
+        transformer: nn.Module,
+        num_multimask_outputs: int = 3,
+        activation: Type[nn.Module] = nn.GELU,
+        iou_head_depth: int = 3,
+        iou_head_hidden_dim: int = 256,
+        vit_dim: int = 1024,
+    ) -> None:
+        """
+        Predicts masks given an image and prompt embeddings, using a
+        transformer architecture.
+
+        Arguments:
+          transformer_dim (int): the channel dimension of the transformer
+          transformer (nn.Module): the transformer used to predict masks
+          num_multimask_outputs (int): the number of masks to predict
+            when disambiguating masks
+          activation (nn.Module): the type of activation to use when
+            upscaling masks
+          iou_head_depth (int): the depth of the MLP used to predict
+            mask quality
+          iou_head_hidden_dim (int): the hidden dimension of the MLP
+            used to predict mask quality
+        """
+        super().__init__()
+        self.transformer_dim = transformer_dim
+        self.transformer = transformer
+
+        self.num_multimask_outputs = num_multimask_outputs
+
+        self.iou_token = nn.Embedding(1, transformer_dim)
+        self.num_mask_tokens = num_multimask_outputs + 1
+        self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim)
+
+        self.output_upscaling = nn.Sequential(
+            nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2),
+            LayerNorm2d(transformer_dim // 4),
+            activation(),
+            nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2),
+            activation(),
+        )
+        self.output_hypernetworks_mlps = nn.ModuleList(
+            [
+                MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3)
+                for i in range(self.num_mask_tokens)
+            ]
+        )
+
+        self.iou_prediction_head = MLP(
+            transformer_dim, iou_head_hidden_dim, self.num_mask_tokens, iou_head_depth
+        )
+
+        # HQ-SAM parameters
+        self.hf_token = nn.Embedding(1, transformer_dim) # HQ-Ouptput-Token
+        self.hf_mlp = MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3) # corresponding new MLP layer for HQ-Ouptput-Token
+        self.num_mask_tokens = self.num_mask_tokens + 1
+
+        # three conv fusion layers for obtaining HQ-Feature
+        self.compress_vit_feat = nn.Sequential(
+            nn.ConvTranspose2d(vit_dim, transformer_dim, kernel_size=2, stride=2),
+            LayerNorm2d(transformer_dim),
+            nn.GELU(),
+            nn.ConvTranspose2d(transformer_dim, transformer_dim // 8, kernel_size=2, stride=2))
+        self.embedding_encoder = nn.Sequential(
+            nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2),
+            LayerNorm2d(transformer_dim // 4),
+            nn.GELU(),
+            nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2),
+        )
+        self.embedding_maskfeature = nn.Sequential(
+            nn.Conv2d(transformer_dim // 8, transformer_dim // 4, 3, 1, 1),
+            LayerNorm2d(transformer_dim // 4),
+            nn.GELU(),
+            nn.Conv2d(transformer_dim // 4, transformer_dim // 8, 3, 1, 1))
+
+
+
+    def forward(
+        self,
+        image_embeddings: torch.Tensor,
+        image_pe: torch.Tensor,
+        sparse_prompt_embeddings: torch.Tensor,
+        dense_prompt_embeddings: torch.Tensor,
+        multimask_output: bool,
+        hq_token_only: bool,
+        interm_embeddings: torch.Tensor,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Predict masks given image and prompt embeddings.
+
+        Arguments:
+          image_embeddings (torch.Tensor): the embeddings from the ViT image encoder
+          image_pe (torch.Tensor): positional encoding with the shape of image_embeddings
+          sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes
+          dense_prompt_embeddings (torch.Tensor): the embeddings of the mask inputs
+          multimask_output (bool): Whether to return multiple masks or a single
+            mask.
+
+        Returns:
+          torch.Tensor: batched predicted masks
+          torch.Tensor: batched predictions of mask quality
+        """
+        vit_features = interm_embeddings[0].permute(0, 3, 1, 2) # early-layer ViT feature, after 1st global attention block in ViT
+        hq_features = self.embedding_encoder(image_embeddings) + self.compress_vit_feat(vit_features)
+
+        masks, iou_pred = self.predict_masks(
+            image_embeddings=image_embeddings,
+            image_pe=image_pe,
+            sparse_prompt_embeddings=sparse_prompt_embeddings,
+            dense_prompt_embeddings=dense_prompt_embeddings,
+            hq_features=hq_features,
+        )
+
+        # Select the correct mask or masks for output
+        if multimask_output:
+            # mask with highest score
+            mask_slice = slice(1,self.num_mask_tokens-1)
+            iou_pred = iou_pred[:, mask_slice]
+            iou_pred, max_iou_idx = torch.max(iou_pred,dim=1)
+            iou_pred = iou_pred.unsqueeze(1)
+            masks_multi = masks[:, mask_slice, :, :]
+            masks_sam = masks_multi[torch.arange(masks_multi.size(0)),max_iou_idx].unsqueeze(1)
+        else:
+            # singale mask output, default
+            mask_slice = slice(0, 1)
+            iou_pred = iou_pred[:,mask_slice]
+            masks_sam = masks[:,mask_slice]
+
+        masks_hq = masks[:,slice(self.num_mask_tokens-1, self.num_mask_tokens)]
+        if hq_token_only:
+            masks = masks_hq
+        else:
+            masks = masks_sam + masks_hq
+        # Prepare output
+        return masks, iou_pred
+
+    def predict_masks(
+        self,
+        image_embeddings: torch.Tensor,
+        image_pe: torch.Tensor,
+        sparse_prompt_embeddings: torch.Tensor,
+        dense_prompt_embeddings: torch.Tensor,
+        hq_features: torch.Tensor,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Predicts masks. See 'forward' for more details."""
+        # Concatenate output tokens
+        output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight, self.hf_token.weight], dim=0)
+        output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1)
+        tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1)
+
+        # Expand per-image data in batch direction to be per-mask
+        src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)
+        src = src + dense_prompt_embeddings
+        pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0)
+        b, c, h, w = src.shape
+
+        # Run the transformer
+        hs, src = self.transformer(src, pos_src, tokens)
+        iou_token_out = hs[:, 0, :]
+        mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :]
+
+        # Upscale mask embeddings and predict masks using the mask tokens
+        src = src.transpose(1, 2).view(b, c, h, w)
+
+        upscaled_embedding_sam = self.output_upscaling(src)
+        upscaled_embedding_hq = self.embedding_maskfeature(upscaled_embedding_sam) + hq_features.repeat(b,1,1,1)
+
+        hyper_in_list: List[torch.Tensor] = []
+        for i in range(self.num_mask_tokens):
+            if i < self.num_mask_tokens - 1:
+                hyper_in_list.append(self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :]))
+            else:
+                hyper_in_list.append(self.hf_mlp(mask_tokens_out[:, i, :]))
+
+        hyper_in = torch.stack(hyper_in_list, dim=1)
+        b, c, h, w = upscaled_embedding_sam.shape
+
+        masks_sam = (hyper_in[:,:self.num_mask_tokens-1] @ upscaled_embedding_sam.view(b, c, h * w)).view(b, -1, h, w)
+        masks_sam_hq = (hyper_in[:,self.num_mask_tokens-1:] @ upscaled_embedding_hq.view(b, c, h * w)).view(b, -1, h, w)
+        masks = torch.cat([masks_sam,masks_sam_hq],dim=1)
+        # Generate mask quality predictions
+        iou_pred = self.iou_prediction_head(iou_token_out)
+
+        return masks, iou_pred
+
+
+# Lightly adapted from
+# https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py # noqa
+class MLP(nn.Module):
+    def __init__(
+        self,
+        input_dim: int,
+        hidden_dim: int,
+        output_dim: int,
+        num_layers: int,
+        sigmoid_output: bool = False,
+    ) -> None:
+        super().__init__()
+        self.num_layers = num_layers
+        h = [hidden_dim] * (num_layers - 1)
+        self.layers = nn.ModuleList(
+            nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])
+        )
+        self.sigmoid_output = sigmoid_output
+
+    def forward(self, x):
+        for i, layer in enumerate(self.layers):
+            x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
+        if self.sigmoid_output:
+            x = F.sigmoid(x)
+        return x
diff --git a/segment_anything/modeling/sam.py b/segment_anything/modeling/sam.py
index 8074cff..b928dfd 100644
--- a/segment_anything/modeling/sam.py
+++ b/segment_anything/modeling/sam.py
@@ -50,11 +50,11 @@ class Sam(nn.Module):
     def device(self) -> Any:
         return self.pixel_mean.device

-    @torch.no_grad()
     def forward(
         self,
         batched_input: List[Dict[str, Any]],
         multimask_output: bool,
+        hq_token_only: bool =False,
     ) -> List[Dict[str, torch.Tensor]]:
         """
         Predicts masks end-to-end from provided images and prompts.
@@ -95,10 +95,11 @@ class Sam(nn.Module):
                 to subsequent iterations of prediction.
""" input_images = torch.stack([self.preprocess(x["image"]) for x in batched_input], dim=0) - image_embeddings = self.image_encoder(input_images) + image_embeddings, interm_embeddings = self.image_encoder(input_images) + interm_embeddings = interm_embeddings[0] # early layer outputs = [] - for image_record, curr_embedding in zip(batched_input, image_embeddings): + for image_record, curr_embedding, curr_interm in zip(batched_input, image_embeddings, interm_embeddings): if "point_coords" in image_record: points = (image_record["point_coords"], image_record["point_labels"]) else: @@ -114,6 +115,8 @@ class Sam(nn.Module): sparse_prompt_embeddings=sparse_embeddings, dense_prompt_embeddings=dense_embeddings, multimask_output=multimask_output, + hq_token_only=hq_token_only, + interm_embeddings=curr_interm.unsqueeze(0).unsqueeze(0), ) masks = self.postprocess_masks( low_res_masks, diff --git a/segment_anything/predictor.py b/segment_anything/predictor.py index 8a6e6d8..31458fb 100644 --- a/segment_anything/predictor.py +++ b/segment_anything/predictor.py @@ -7,7 +7,7 @@ import numpy as np import torch -from segment_anything.modeling import Sam +from .modeling import Sam from typing import Optional, Tuple @@ -49,10 +49,12 @@ class SamPredictor: "RGB", "BGR", ], f"image_format must be in ['RGB', 'BGR'], is {image_format}." + # import pdb;pdb.set_trace() if image_format != self.model.image_format: image = image[..., ::-1] # Transform the image to the form expected by the model + # import pdb;pdb.set_trace() input_image = self.transform.apply_image(image) input_image_torch = torch.as_tensor(input_image, device=self.device) input_image_torch = input_image_torch.permute(2, 0, 1).contiguous()[None, :, :, :] @@ -86,7 +88,7 @@ class SamPredictor: self.original_size = original_image_size self.input_size = tuple(transformed_image.shape[-2:]) input_image = self.model.preprocess(transformed_image) - self.features = self.model.image_encoder(input_image) + self.features, self.interm_features = self.model.image_encoder(input_image) self.is_image_set = True def predict( @@ -97,6 +99,7 @@ class SamPredictor: mask_input: Optional[np.ndarray] = None, multimask_output: bool = True, return_logits: bool = False, + hq_token_only: bool =False, ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: """ Predict masks for the given input prompts, using the currently set image. @@ -158,6 +161,7 @@ class SamPredictor: mask_input_torch, multimask_output, return_logits=return_logits, + hq_token_only=hq_token_only, ) masks_np = masks[0].detach().cpu().numpy() @@ -174,6 +178,7 @@ class SamPredictor: mask_input: Optional[torch.Tensor] = None, multimask_output: bool = True, return_logits: bool = False, + hq_token_only: bool =False, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Predict masks for the given input prompts, using the currently set image. 
@@ -232,6 +237,8 @@
             sparse_prompt_embeddings=sparse_embeddings,
             dense_prompt_embeddings=dense_embeddings,
             multimask_output=multimask_output,
+            hq_token_only=hq_token_only,
+            interm_embeddings=self.interm_features,
         )

         # Upscale the masks to the original image resolution
diff --git a/segment_anything/utils/onnx.py b/segment_anything/utils/onnx.py
index 3196bdf..8013dc4 100644
--- a/segment_anything/utils/onnx.py
+++ b/segment_anything/utils/onnx.py
@@ -25,7 +25,8 @@ class SamOnnxModel(nn.Module):
     def __init__(
         self,
         model: Sam,
-        return_single_mask: bool,
+        hq_token_only: bool = False,
+        multimask_output: bool = False,
         use_stability_score: bool = False,
         return_extra_metrics: bool = False,
     ) -> None:
@@ -33,7 +34,8 @@
         self.mask_decoder = model.mask_decoder
         self.model = model
         self.img_size = model.image_encoder.img_size
-        self.return_single_mask = return_single_mask
+        self.hq_token_only = hq_token_only
+        self.multimask_output = multimask_output
         self.use_stability_score = use_stability_score
         self.stability_score_offset = 1.0
         self.return_extra_metrics = return_extra_metrics
@@ -89,25 +91,12 @@
         masks = F.interpolate(masks, size=(h, w), mode="bilinear", align_corners=False)
         return masks

-    def select_masks(
-        self, masks: torch.Tensor, iou_preds: torch.Tensor, num_points: int
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        # Determine if we should return the multiclick mask or not from the number of points.
-        # The reweighting is used to avoid control flow.
-        score_reweight = torch.tensor(
-            [[1000] + [0] * (self.model.mask_decoder.num_mask_tokens - 1)]
-        ).to(iou_preds.device)
-        score = iou_preds + (num_points - 2.5) * score_reweight
-        best_idx = torch.argmax(score, dim=1)
-        masks = masks[torch.arange(masks.shape[0]), best_idx, :, :].unsqueeze(1)
-        iou_preds = iou_preds[torch.arange(masks.shape[0]), best_idx].unsqueeze(1)
-
-        return masks, iou_preds

     @torch.no_grad()
     def forward(
         self,
         image_embeddings: torch.Tensor,
+        interm_embeddings: torch.Tensor,
         point_coords: torch.Tensor,
         point_labels: torch.Tensor,
         mask_input: torch.Tensor,
@@ -117,11 +106,15 @@
         sparse_embedding = self._embed_points(point_coords, point_labels)
         dense_embedding = self._embed_masks(mask_input, has_mask_input)

+        vit_features = interm_embeddings[0].permute(0, 3, 1, 2) # early-layer ViT feature, after 1st global attention block in ViT
+        hq_features = self.model.mask_decoder.embedding_encoder(image_embeddings) + self.model.mask_decoder.compress_vit_feat(vit_features)
+
         masks, scores = self.model.mask_decoder.predict_masks(
             image_embeddings=image_embeddings,
             image_pe=self.model.prompt_encoder.get_dense_pe(),
             sparse_prompt_embeddings=sparse_embedding,
             dense_prompt_embeddings=dense_embedding,
+            hq_features=hq_features,
         )

         if self.use_stability_score:
@@ -129,8 +122,26 @@
                 masks, self.model.mask_threshold, self.stability_score_offset
             )

-        if self.return_single_mask:
-            masks, scores = self.select_masks(masks, scores, point_coords.shape[1])
+        if self.multimask_output:
+            # mask with highest score
+            mask_slice = slice(1,self.model.mask_decoder.num_mask_tokens-1)
+            scores = scores[:, mask_slice]
+            scores, max_iou_idx = torch.max(scores,dim=1)
+            scores = scores.unsqueeze(1)
+            masks_multi = masks[:, mask_slice, :, :]
+            masks_sam = masks_multi[torch.arange(masks_multi.size(0)),max_iou_idx].unsqueeze(1)
+        else:
+            # singale mask output, default
+            mask_slice = slice(0, 1)
+            scores = scores[:,mask_slice]
+            masks_sam = masks[:,mask_slice]
+
+        masks_hq = masks[:,slice(self.model.mask_decoder.num_mask_tokens-1, self.model.mask_decoder.num_mask_tokens)]
+
+        if self.hq_token_only:
+            masks = masks_hq
+        else:
+            masks = masks_sam + masks_hq

         upscaled_masks = self.mask_postprocessing(masks, orig_im_size)
diff --git a/ui/MainWindow.py b/ui/MainWindow.py
index 8a4c574..656819d 100644
--- a/ui/MainWindow.py
+++ b/ui/MainWindow.py
@@ -88,6 +88,8 @@ class Ui_MainWindow(object):
         self.menuMode.setObjectName("menuMode")
         self.menuContour_mode = QtWidgets.QMenu(self.menuMode)
         self.menuContour_mode.setObjectName("menuContour_mode")
+        self.menuSAM_model = QtWidgets.QMenu(self.menubar)
+        self.menuSAM_model.setObjectName("menuSAM_model")
         MainWindow.setMenuBar(self.menubar)
         self.statusbar = QtWidgets.QStatusBar(MainWindow)
         self.statusbar.setLayoutDirection(QtCore.Qt.LeftToRight)
@@ -352,6 +354,7 @@
         self.menubar.addAction(self.menuFile.menuAction())
         self.menubar.addAction(self.menuEdit.menuAction())
         self.menubar.addAction(self.menuView.menuAction())
+        self.menubar.addAction(self.menuSAM_model.menuAction())
         self.menubar.addAction(self.menuMode.menuAction())
         self.menubar.addAction(self.menuTools.menuAction())
         self.menubar.addAction(self.menuAbout.menuAction())
@@ -391,6 +394,7 @@
         self.menuEdit.setTitle(_translate("MainWindow", "Edit"))
         self.menuMode.setTitle(_translate("MainWindow", "Mode"))
         self.menuContour_mode.setTitle(_translate("MainWindow", "Contour mode"))
+        self.menuSAM_model.setTitle(_translate("MainWindow", "SAM"))
         self.toolBar.setWindowTitle(_translate("MainWindow", "toolBar"))
         self.info_dock.setWindowTitle(_translate("MainWindow", "Info"))
         self.annos_dock.setWindowTitle(_translate("MainWindow", "Annos"))
diff --git a/ui/MainWindow.ui b/ui/MainWindow.ui
index a9e74ba..481b280 100644
--- a/ui/MainWindow.ui
+++ b/ui/MainWindow.ui
@@ -202,9 +202,15 @@
+
+
+     SAM
+
+
+
diff --git a/widgets/mainwindow.py b/widgets/mainwindow.py
index 613c49f..99bad76 100644
--- a/widgets/mainwindow.py
+++ b/widgets/mainwindow.py
@@ -34,8 +34,6 @@ class MainWindow(QtWidgets.QMainWindow, Ui_MainWindow):
     def __init__(self):
         super(MainWindow, self).__init__()
         self.setupUi(self)
-        self.init_ui()
-        self.init_segment_anything()

         self.image_root: str = None
         self.label_root:str = None
@@ -58,41 +56,51 @@
         self.map_mode = MAPMode.LABEL # 标注目标
         self.current_label:Annotation = None

+        self.use_segment_anything = False
+        self.init_ui()
         self.reload_cfg()

         self.init_connect()
         self.reset_action()

-    def init_segment_anything(self):
-        if os.path.exists('./segment_any/sam_vit_h_4b8939.pth'):
-            self.statusbar.showMessage('Find the checkpoint named {}.'.format('sam_vit_h_4b8939.pth'))
-            self.segany = SegAny('./segment_any/sam_vit_h_4b8939.pth')
-            self.use_segment_anything = True
-        elif os.path.exists('./segment_any/sam_vit_l_0b3195.pth'):
-            self.statusbar.showMessage('Find the checkpoint named {}.'.format('sam_vit_l_0b3195.pth'))
-            self.segany = SegAny('./segment_any/sam_vit_l_0b3195.pth')
-            self.use_segment_anything = True
-        elif os.path.exists('./segment_any/sam_vit_b_01ec64.pth'):
-            self.statusbar.showMessage('Find the checkpoint named {}.'.format('sam_vit_b_01ec64.pth'))
-            self.segany = SegAny('./segment_any/sam_vit_b_01ec64.pth')
-            self.use_segment_anything = True
-        else:
-            QtWidgets.QMessageBox.warning(self, 'Warning', 'The checkpoint of [Segment anything] not existed. If you want use quick annotate, please download from {}'.format('https://github.com/facebookresearch/segment-anything#model-checkpoints'))
+    def init_segment_anything(self, model_name, reload=False):
+        if model_name == '':
             self.use_segment_anything = False
+            for name, action in self.pths_actions.items():
+                action.setChecked(model_name == name)
+            return
+        model_path = os.path.join('segment_any', model_name)
+        if not os.path.exists(model_path):
+            QtWidgets.QMessageBox.warning(self, 'Warning',
+                                          'The checkpoint of [Segment anything] not existed. If you want use quick annotate, please download from {}'.format(
+                                              'https://github.com/facebookresearch/segment-anything#model-checkpoints'))
+            for name, action in self.pths_actions.items():
+                action.setChecked(model_name == name)
+            self.use_segment_anything = False
+            return

-        if self.use_segment_anything:
-            if self.segany.device != 'cpu':
-                self.gpu_resource_thread = GPUResource_Thread()
-                self.gpu_resource_thread.message.connect(self.labelGPUResource.setText)
-                self.gpu_resource_thread.start()
+        self.segany = SegAny(model_path)
+        self.use_segment_anything = True
+        self.statusbar.showMessage('Use the checkpoint named {}.'.format(model_name), 3000)
+        for name, action in self.pths_actions.items():
+            action.setChecked(model_name==name)
+        if not reload:
+            if self.use_segment_anything:
+                if self.segany.device != 'cpu':
+                    self.gpu_resource_thread = GPUResource_Thread()
+                    self.gpu_resource_thread.message.connect(self.labelGPUResource.setText)
+                    self.gpu_resource_thread.start()
+                else:
+                    self.labelGPUResource.setText('cpu')
             else:
-                self.labelGPUResource.setText('cpu')
-        else:
-            self.labelGPUResource.setText('segment anything unused.')
+                self.labelGPUResource.setText('segment anything unused.')
+
+        if reload and self.current_index is not None:
+            self.show_image(self.current_index)

     def init_ui(self):
-        #
+        #q
         self.setting_dialog = SettingDialog(parent=self, mainwindow=self)

         self.categories_dock_widget = CategoriesDockWidget(mainwindow=self)
@@ -144,6 +152,17 @@
         self.statusbar.addPermanentWidget(self.labelData)

         #
+        model_names = sorted([pth for pth in os.listdir('segment_any') if pth.endswith('.pth')])
+        self.pths_actions = {}
+        for model_name in model_names:
+            action = QtWidgets.QAction(self)
+            action.setObjectName("actionZoom_in")
+            action.triggered.connect(functools.partial(self.init_segment_anything, model_name))
+            action.setText("{}".format(model_name))
+            action.setCheckable(True)
+
+            self.pths_actions[model_name] = action
+            self.menuSAM_model.addAction(action)
         self.toolBar.addSeparator()

         self.mask_aplha = QtWidgets.QSlider(QtCore.Qt.Orientation.Horizontal, self)
@@ -213,6 +232,9 @@
         self.cfg['mask_alpha'] = mask_alpha
         self.mask_aplha.setValue(mask_alpha*10)

+        model_name = self.cfg.get('model_name', '')
+        self.init_segment_anything(model_name)
+
         self.categories_dock_widget.update_widget()

     def set_saved_state(self, is_saved:bool):
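
For context, a minimal sketch of how the reworked pieces above would be exercised outside the GUI. This is illustrative only and not part of the diff: the checkpoint filename and device handling are assumptions, and the checkpoint should contain HQ-SAM weights, since build_sam.py now loads with strict=False and freezes every parameter except the HQ-specific layers (hf_token, hf_mlp, compress_vit_feat, embedding_encoder, embedding_maskfeature).

import numpy as np
import torch
import cv2
from segment_anything import sam_model_registry, SamPredictor

# Assumed HQ-SAM checkpoint name/location under segment_any/ (hypothetical filename).
checkpoint = 'segment_any/sam_hq_vit_b.pth'
sam = sam_model_registry['vit_b'](checkpoint=checkpoint)   # Sam built with MaskDecoderHQ
sam.to('cuda' if torch.cuda.is_available() else 'cpu')

predictor = SamPredictor(sam)
image = cv2.cvtColor(cv2.imread('example/images/000000000144.jpg'), cv2.COLOR_BGR2RGB)
predictor.set_image(image)   # now also caches predictor.interm_features from the ViT encoder

# Single click prompt; hq_token_only is the flag threaded through predict() in this diff.
masks, scores, logits = predictor.predict(
    point_coords=np.array([[320, 240]]),
    point_labels=np.array([1]),
    multimask_output=False,
    hq_token_only=True,
)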