SAM-HQ support, UI update, add model selection

This commit is contained in:
parent db139d2c02
commit f7edb986d4
@@ -1,813 +1,12 @@
 {
     "info": {
         "description": "ISAT",
-        "folder": "C:/Users/lg/PycharmProjects/ISAT_with_segment_anything/example/images",
+        "folder": "/home/super/PycharmProjects/ISAT_with_segment_anything/example/images",
         "name": "000000000144.jpg",
         "width": 640,
         "height": 480,
         "depth": 3,
         "note": ""
     },
-    "objects": [
-        {
-            "category": "fence",
-            "group": "1",
-            "segmentation": [
-                [20, 239], [17, 240], [15, 241], [13, 242], [11, 244], [10, 244], [9, 245],
-                [6, 249], [5, 251], [4, 255], [0, 261], [0, 280], [1, 277], [6, 268],
-                [9, 259], [13, 254], [13, 252], [15, 250], [18, 248], [20, 247], [24, 247],
-                [26, 248], [30, 252], [33, 256], [35, 260], [38, 266], [41, 273], [43, 280],
-                [50, 294], [53, 297], [54, 297], [59, 301], [61, 303], [62, 316], [63, 341],
-                [63, 382], [64, 405], [66, 414], [71, 413], [72, 402], [72, 399], [71, 359],
-                [69, 324], [70, 300], [73, 299], [78, 294], [80, 291], [81, 289], [83, 283],
-                [85, 278], [86, 272], [91, 258], [94, 252], [98, 248], [105, 248], [107, 249],
-                [111, 253], [116, 261], [118, 265], [121, 271], [128, 288], [130, 291], [135, 296],
-                [138, 298], [140, 299], [143, 300], [148, 300], [151, 299], [153, 298], [156, 296],
-                [159, 293], [161, 290], [164, 284], [179, 255], [180, 253], [180, 251], [181, 249],
-                [183, 247], [186, 248], [187, 259], [188, 282], [188, 303], [190, 320], [190, 343],
-                [191, 355], [192, 357], [196, 358], [198, 357], [200, 353], [200, 341], [198, 321],
-                [197, 300], [196, 279], [195, 252], [195, 250], [196, 248], [200, 248], [200, 249],
-                [202, 251], [207, 260], [208, 262], [208, 264], [210, 266], [212, 273], [215, 279],
-                [217, 281], [218, 283], [219, 287], [221, 290], [225, 294], [227, 295], [230, 296],
-                [235, 296], [239, 295], [241, 294], [245, 290], [246, 288], [247, 286], [248, 283],
-                [248, 279], [245, 281], [244, 283], [244, 285], [243, 287], [238, 292], [230, 292],
-                [228, 291], [224, 287], [222, 284], [219, 278], [214, 268], [214, 266], [207, 252],
-                [205, 249], [201, 245], [198, 243], [194, 242], [188, 242], [183, 243], [180, 244],
-                [176, 246], [174, 249], [172, 252], [170, 256], [169, 260], [168, 262], [166, 264],
-                [165, 266], [164, 268], [161, 276], [156, 286], [153, 290], [152, 291], [151, 291],
-                [149, 293], [147, 293], [144, 293], [142, 293], [140, 292], [138, 290], [136, 287],
-                [133, 281], [125, 263], [123, 259], [120, 253], [117, 248], [112, 243], [110, 242],
-                [107, 241], [98, 241], [96, 242], [94, 243], [89, 248], [85, 256], [83, 261],
-                [81, 269], [78, 274], [76, 282], [74, 287], [73, 289], [67, 295], [63, 295],
-                [61, 294], [55, 288], [53, 284], [50, 278], [49, 273], [41, 254], [40, 252],
-                [38, 249], [35, 246], [35, 245], [33, 243], [30, 241], [28, 240], [24, 239]
-            ],
-            "area": 4485.0,
-            "layer": 1.0,
-            "bbox": [0.0, 239.0, 248.0, 414.0],
-            "iscrowd": 0,
-            "note": ""
-        }
-    ]
+    "objects": []
 }
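For reference, the ISAT annotation format visible in this hunk keeps image metadata under "info" and a list of polygon objects under "objects". A minimal reader sketch in Python; the file name is an assumption for illustration, the commit only shows the file contents:

    import json

    # Path is a placeholder; ISAT stores one such JSON per annotated image.
    with open('000000000144.json', encoding='utf-8') as f:
        ann = json.load(f)

    info = ann['info']
    print(info['name'], info['width'], info['height'])
    for obj in ann['objects']:
        # Each object carries: category, group, segmentation ([x, y] points),
        # area, layer, bbox, iscrowd, note.
        print(obj['category'], obj['bbox'], len(obj['segmentation']), 'points')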
@@ -1,4 +1,4 @@
-contour_mode: external
+contour_mode: all
 label:
 - color: '#000000'
   name: __background__
@@ -1,4 +1,4 @@
-contour_mode: external
+contour_mode: all
 label:
 - color: '#000000'
   name: __background__

main.py
@@ -12,5 +12,5 @@ if __name__ == '__main__':
     app = QtWidgets.QApplication([''])
     mainwindow = MainWindow()
     mainwindow.show()
-    sys.exit(app.exec_())
+    sys.exit(app.exec())
 
@@ -16,6 +16,7 @@ class SegAny:
             self.model_type = "vit_h"
         else:
             raise ValueError('The checkpoint named {} is not supported.'.format(checkpoint))
+        torch.cuda.empty_cache()
 
         self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
         sam = sam_model_registry[self.model_type](checkpoint=checkpoint)
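The added torch.cuda.empty_cache() call matters once the UI lets the user switch models: cached GPU memory from the previous SAM should be handed back before the next checkpoint is built. A minimal sketch of that pattern, assuming a hypothetical reload_checkpoint helper that is not part of this diff:

    import gc
    import torch
    from segment_anything import sam_model_registry

    def reload_checkpoint(model_type, checkpoint, old_sam=None):
        """Hypothetical helper: free the previous SAM before building the next one."""
        if old_sam is not None:
            del old_sam              # drop our reference to the old model
        gc.collect()                 # collect anything no longer referenced
        torch.cuda.empty_cache()     # return cached GPU blocks, as the diff does
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        return sam_model_registry[model_type](checkpoint=checkpoint).to(device)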
@@ -11,5 +11,6 @@ from .build_sam import (
     build_sam_vit_b,
     sam_model_registry,
 )
+from .build_sam_baseline import sam_model_registry_baseline
 from .predictor import SamPredictor
 from .automatic_mask_generator import SamAutomaticMaskGenerator
@@ -134,7 +134,7 @@ class SamAutomaticMaskGenerator:
         self.output_mode = output_mode
 
     @torch.no_grad()
-    def generate(self, image: np.ndarray) -> List[Dict[str, Any]]:
+    def generate(self, image: np.ndarray, multimask_output: bool = True) -> List[Dict[str, Any]]:
         """
         Generates masks for the given image.
 
@@ -160,7 +160,7 @@ class SamAutomaticMaskGenerator:
         """
 
         # Generate masks
-        mask_data = self._generate_masks(image)
+        mask_data = self._generate_masks(image, multimask_output)
 
         # Filter small disconnected regions and holes in masks
         if self.min_mask_region_area > 0:
@@ -194,7 +194,7 @@ class SamAutomaticMaskGenerator:
 
         return curr_anns
 
-    def _generate_masks(self, image: np.ndarray) -> MaskData:
+    def _generate_masks(self, image: np.ndarray, multimask_output: bool = True) -> MaskData:
         orig_size = image.shape[:2]
         crop_boxes, layer_idxs = generate_crop_boxes(
             orig_size, self.crop_n_layers, self.crop_overlap_ratio
@@ -203,7 +203,7 @@ class SamAutomaticMaskGenerator:
         # Iterate over image crops
         data = MaskData()
         for crop_box, layer_idx in zip(crop_boxes, layer_idxs):
-            crop_data = self._process_crop(image, crop_box, layer_idx, orig_size)
+            crop_data = self._process_crop(image, crop_box, layer_idx, orig_size, multimask_output)
             data.cat(crop_data)
 
         # Remove duplicate masks between crops
@@ -228,6 +228,7 @@ class SamAutomaticMaskGenerator:
         crop_box: List[int],
         crop_layer_idx: int,
         orig_size: Tuple[int, ...],
+        multimask_output: bool = True,
     ) -> MaskData:
         # Crop the image and calculate embeddings
         x0, y0, x1, y1 = crop_box
@@ -242,7 +243,7 @@ class SamAutomaticMaskGenerator:
         # Generate masks for this crop in batches
         data = MaskData()
         for (points,) in batch_iterator(self.points_per_batch, points_for_image):
-            batch_data = self._process_batch(points, cropped_im_size, crop_box, orig_size)
+            batch_data = self._process_batch(points, cropped_im_size, crop_box, orig_size, multimask_output)
             data.cat(batch_data)
             del batch_data
         self.predictor.reset_image()
@@ -269,6 +270,7 @@ class SamAutomaticMaskGenerator:
         im_size: Tuple[int, ...],
         crop_box: List[int],
         orig_size: Tuple[int, ...],
+        multimask_output: bool = True,
     ) -> MaskData:
         orig_h, orig_w = orig_size
 
@@ -279,7 +281,7 @@ class SamAutomaticMaskGenerator:
         masks, iou_preds, _ = self.predictor.predict_torch(
             in_points[:, None, :],
             in_labels[:, None],
-            multimask_output=True,
+            multimask_output=multimask_output,
             return_logits=True,
         )
 
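These hunks thread a multimask_output flag from SamAutomaticMaskGenerator.generate() down to predict_torch(), so callers can ask for a single mask per point prompt. A short usage sketch; the checkpoint path and image path are placeholders, not part of this commit:

    import cv2
    from segment_anything import sam_model_registry, SamAutomaticMaskGenerator

    sam = sam_model_registry['vit_h'](checkpoint='sam_hq_vit_h.pth')  # placeholder path
    sam.to('cuda')

    generator = SamAutomaticMaskGenerator(sam)
    image = cv2.cvtColor(cv2.imread('example/images/000000000144.jpg'), cv2.COLOR_BGR2RGB)

    # New keyword introduced by this commit: one mask per prompt instead of three.
    masks = generator.generate(image, multimask_output=False)
    print(len(masks), 'masks')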
@@ -8,7 +8,7 @@ import torch
 
 from functools import partial
 
-from .modeling import ImageEncoderViT, MaskDecoder, PromptEncoder, Sam, TwoWayTransformer
+from .modeling import ImageEncoderViT, MaskDecoderHQ, PromptEncoder, Sam, TwoWayTransformer
 
 
 def build_sam_vit_h(checkpoint=None):
@@ -84,7 +84,7 @@ def _build_sam(
             input_image_size=(image_size, image_size),
             mask_in_chans=16,
         ),
-        mask_decoder=MaskDecoder(
+        mask_decoder=MaskDecoderHQ(
             num_multimask_outputs=3,
             transformer=TwoWayTransformer(
                 depth=2,
@@ -95,13 +95,19 @@ def _build_sam(
             transformer_dim=prompt_embed_dim,
             iou_head_depth=3,
             iou_head_hidden_dim=256,
+            vit_dim=encoder_embed_dim,
         ),
         pixel_mean=[123.675, 116.28, 103.53],
         pixel_std=[58.395, 57.12, 57.375],
     )
-    sam.eval()
+    # sam.eval()
     if checkpoint is not None:
         with open(checkpoint, "rb") as f:
             state_dict = torch.load(f)
-        sam.load_state_dict(state_dict)
+        info = sam.load_state_dict(state_dict, strict=False)
+        print(info)
+    for n, p in sam.named_parameters():
+        if 'hf_token' not in n and 'hf_mlp' not in n and 'compress_vit_feat' not in n and 'embedding_encoder' not in n and 'embedding_maskfeature' not in n:
+            p.requires_grad = False
+
     return sam
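load_state_dict(..., strict=False) returns the missing and unexpected keys, which the added print(info) surfaces when a plain SAM checkpoint is loaded into the HQ model; the loop then freezes every parameter that does not belong to the HQ-specific heads. The same filtering logic shown in isolation, as a sketch rather than code from this repository:

    import torch.nn as nn

    HQ_MARKERS = ('hf_token', 'hf_mlp', 'compress_vit_feat',
                  'embedding_encoder', 'embedding_maskfeature')

    def freeze_non_hq(model: nn.Module) -> int:
        """Freeze every parameter whose name lacks an HQ marker; return trainable count."""
        trainable = 0
        for name, param in model.named_parameters():
            if any(marker in name for marker in HQ_MARKERS):
                trainable += 1              # HQ head stays trainable
            else:
                param.requires_grad = False
        return trainable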
									
									
									
									
									
segment_anything/build_sam_baseline.py (new file, 107 lines)
@@ -0,0 +1,107 @@
 | 
				
			|||||||
 | 
					# Copyright (c) Meta Platforms, Inc. and affiliates.
 | 
				
			||||||
 | 
					# All rights reserved.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# This source code is licensed under the license found in the
 | 
				
			||||||
 | 
					# LICENSE file in the root directory of this source tree.
 | 
				
			||||||
 | 
					
 | 
				
import torch

from functools import partial

from .modeling import ImageEncoderViT, MaskDecoder, PromptEncoder, Sam, TwoWayTransformer


def build_sam_vit_h(checkpoint=None):
    return _build_sam(
        encoder_embed_dim=1280,
        encoder_depth=32,
        encoder_num_heads=16,
        encoder_global_attn_indexes=[7, 15, 23, 31],
        checkpoint=checkpoint,
    )


build_sam = build_sam_vit_h


def build_sam_vit_l(checkpoint=None):
    return _build_sam(
        encoder_embed_dim=1024,
        encoder_depth=24,
        encoder_num_heads=16,
        encoder_global_attn_indexes=[5, 11, 17, 23],
        checkpoint=checkpoint,
    )


def build_sam_vit_b(checkpoint=None):
    return _build_sam(
        encoder_embed_dim=768,
        encoder_depth=12,
        encoder_num_heads=12,
        encoder_global_attn_indexes=[2, 5, 8, 11],
        checkpoint=checkpoint,
    )


sam_model_registry_baseline = {
    "default": build_sam_vit_h,
    "vit_h": build_sam_vit_h,
    "vit_l": build_sam_vit_l,
    "vit_b": build_sam_vit_b,
}


def _build_sam(
    encoder_embed_dim,
    encoder_depth,
    encoder_num_heads,
    encoder_global_attn_indexes,
    checkpoint=None,
):
    prompt_embed_dim = 256
    image_size = 1024
    vit_patch_size = 16
    image_embedding_size = image_size // vit_patch_size
    sam = Sam(
        image_encoder=ImageEncoderViT(
            depth=encoder_depth,
            embed_dim=encoder_embed_dim,
            img_size=image_size,
            mlp_ratio=4,
            norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
            num_heads=encoder_num_heads,
            patch_size=vit_patch_size,
            qkv_bias=True,
            use_rel_pos=True,
            global_attn_indexes=encoder_global_attn_indexes,
            window_size=14,
            out_chans=prompt_embed_dim,
        ),
        prompt_encoder=PromptEncoder(
            embed_dim=prompt_embed_dim,
            image_embedding_size=(image_embedding_size, image_embedding_size),
            input_image_size=(image_size, image_size),
            mask_in_chans=16,
        ),
        mask_decoder=MaskDecoder(
            num_multimask_outputs=3,
            transformer=TwoWayTransformer(
                depth=2,
                embedding_dim=prompt_embed_dim,
                mlp_dim=2048,
                num_heads=8,
            ),
            transformer_dim=prompt_embed_dim,
            iou_head_depth=3,
            iou_head_hidden_dim=256,
        ),
        pixel_mean=[123.675, 116.28, 103.53],
        pixel_std=[58.395, 57.12, 57.375],
    )
    sam.eval()
    if checkpoint is not None:
        with open(checkpoint, "rb") as f:
            state_dict = torch.load(f)
        sam.load_state_dict(state_dict)
    return sam
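For orientation, a minimal usage sketch of the baseline registry above; the checkpoint path is only an example borrowed from the old hard-coded paths elsewhere in this commit, not something this file defines:

# Minimal sketch: build a baseline (non-HQ) SAM and load its weights.
checkpoint = './segment_any/sam_vit_b_01ec64.pth'          # example path, adjust as needed
sam = sam_model_registry_baseline['vit_b'](checkpoint=checkpoint)
# _build_sam() already calls sam.eval(); move to GPU if available:
# sam.to('cuda')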
@@ -6,6 +6,7 @@
 from .sam import Sam
 from .image_encoder import ImageEncoderViT
+from .mask_decoder_hq import MaskDecoderHQ
 from .mask_decoder import MaskDecoder
 from .prompt_encoder import PromptEncoder
 from .transformer import TwoWayTransformer
@@ -108,12 +108,15 @@ class ImageEncoderViT(nn.Module):
         if self.pos_embed is not None:
             x = x + self.pos_embed
 
+        interm_embeddings = []
         for blk in self.blocks:
             x = blk(x)
+            if blk.window_size == 0:
+                interm_embeddings.append(x)
 
         x = self.neck(x.permute(0, 3, 1, 2))
 
-        return x
+        return x, interm_embeddings
 
 
 class Block(nn.Module):
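Since the encoder now returns a pair, every caller has to unpack two values. A minimal sketch, assuming a sam model built as in the earlier sketch (shapes assume the default 1024-pixel input and 16-pixel patches):

import torch

x = torch.zeros(1, 3, 1024, 1024)                    # placeholder preprocessed batch
features, interm_embeddings = sam.image_encoder(x)   # two values now, instead of one
# features: (1, 256, 64, 64) neck output, unchanged from the original encoder
# interm_embeddings: list of (1, 64, 64, embed_dim) activations, one per
#   global-attention block; the HQ decoder only consumes interm_embeddings[0]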
@@ -75,6 +75,8 @@ class MaskDecoder(nn.Module):
         sparse_prompt_embeddings: torch.Tensor,
         dense_prompt_embeddings: torch.Tensor,
         multimask_output: bool,
+        hq_token_only: bool,
+        interm_embeddings: torch.Tensor,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """
         Predict masks given image and prompt embeddings.
segment_anything/modeling/mask_decoder_hq.py (new file, 232 lines)
@@ -0,0 +1,232 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# Modified by HQ-SAM team
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import torch
from torch import nn
from torch.nn import functional as F

from typing import List, Tuple, Type

from .common import LayerNorm2d


class MaskDecoderHQ(nn.Module):
    def __init__(
        self,
        *,
        transformer_dim: int,
        transformer: nn.Module,
        num_multimask_outputs: int = 3,
        activation: Type[nn.Module] = nn.GELU,
        iou_head_depth: int = 3,
        iou_head_hidden_dim: int = 256,
        vit_dim: int = 1024,
    ) -> None:
        """
        Predicts masks given an image and prompt embeddings, using a
        transformer architecture.

        Arguments:
          transformer_dim (int): the channel dimension of the transformer
          transformer (nn.Module): the transformer used to predict masks
          num_multimask_outputs (int): the number of masks to predict
            when disambiguating masks
          activation (nn.Module): the type of activation to use when
            upscaling masks
          iou_head_depth (int): the depth of the MLP used to predict
            mask quality
          iou_head_hidden_dim (int): the hidden dimension of the MLP
            used to predict mask quality
        """
        super().__init__()
        self.transformer_dim = transformer_dim
        self.transformer = transformer

        self.num_multimask_outputs = num_multimask_outputs

        self.iou_token = nn.Embedding(1, transformer_dim)
        self.num_mask_tokens = num_multimask_outputs + 1
        self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim)

        self.output_upscaling = nn.Sequential(
            nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2),
            LayerNorm2d(transformer_dim // 4),
            activation(),
            nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2),
            activation(),
        )
        self.output_hypernetworks_mlps = nn.ModuleList(
            [
                MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3)
                for i in range(self.num_mask_tokens)
            ]
        )

        self.iou_prediction_head = MLP(
            transformer_dim, iou_head_hidden_dim, self.num_mask_tokens, iou_head_depth
        )

        # HQ-SAM parameters
        self.hf_token = nn.Embedding(1, transformer_dim)  # HQ-Output-Token
        self.hf_mlp = MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3)  # corresponding new MLP layer for HQ-Output-Token
        self.num_mask_tokens = self.num_mask_tokens + 1

        # three conv fusion layers for obtaining HQ-Feature
        self.compress_vit_feat = nn.Sequential(
            nn.ConvTranspose2d(vit_dim, transformer_dim, kernel_size=2, stride=2),
            LayerNorm2d(transformer_dim),
            nn.GELU(),
            nn.ConvTranspose2d(transformer_dim, transformer_dim // 8, kernel_size=2, stride=2),
        )

        self.embedding_encoder = nn.Sequential(
            nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2),
            LayerNorm2d(transformer_dim // 4),
            nn.GELU(),
            nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2),
        )
        self.embedding_maskfeature = nn.Sequential(
            nn.Conv2d(transformer_dim // 8, transformer_dim // 4, 3, 1, 1),
            LayerNorm2d(transformer_dim // 4),
            nn.GELU(),
            nn.Conv2d(transformer_dim // 4, transformer_dim // 8, 3, 1, 1),
        )

    def forward(
        self,
        image_embeddings: torch.Tensor,
        image_pe: torch.Tensor,
        sparse_prompt_embeddings: torch.Tensor,
        dense_prompt_embeddings: torch.Tensor,
        multimask_output: bool,
        hq_token_only: bool,
        interm_embeddings: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Predict masks given image and prompt embeddings.

        Arguments:
          image_embeddings (torch.Tensor): the embeddings from the ViT image encoder
          image_pe (torch.Tensor): positional encoding with the shape of image_embeddings
          sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes
          dense_prompt_embeddings (torch.Tensor): the embeddings of the mask inputs
          multimask_output (bool): Whether to return multiple masks or a single
            mask.

        Returns:
          torch.Tensor: batched predicted masks
          torch.Tensor: batched predictions of mask quality
        """
        vit_features = interm_embeddings[0].permute(0, 3, 1, 2)  # early-layer ViT feature, after 1st global attention block in ViT
        hq_features = self.embedding_encoder(image_embeddings) + self.compress_vit_feat(vit_features)

        masks, iou_pred = self.predict_masks(
            image_embeddings=image_embeddings,
            image_pe=image_pe,
            sparse_prompt_embeddings=sparse_prompt_embeddings,
            dense_prompt_embeddings=dense_prompt_embeddings,
            hq_features=hq_features,
        )

        # Select the correct mask or masks for output
        if multimask_output:
            # mask with highest score
            mask_slice = slice(1, self.num_mask_tokens - 1)
            iou_pred = iou_pred[:, mask_slice]
            iou_pred, max_iou_idx = torch.max(iou_pred, dim=1)
            iou_pred = iou_pred.unsqueeze(1)
            masks_multi = masks[:, mask_slice, :, :]
            masks_sam = masks_multi[torch.arange(masks_multi.size(0)), max_iou_idx].unsqueeze(1)
        else:
            # single mask output, default
            mask_slice = slice(0, 1)
            iou_pred = iou_pred[:, mask_slice]
            masks_sam = masks[:, mask_slice]

        masks_hq = masks[:, slice(self.num_mask_tokens - 1, self.num_mask_tokens)]
        if hq_token_only:
            masks = masks_hq
        else:
            masks = masks_sam + masks_hq
        # Prepare output
        return masks, iou_pred

    def predict_masks(
        self,
        image_embeddings: torch.Tensor,
        image_pe: torch.Tensor,
        sparse_prompt_embeddings: torch.Tensor,
        dense_prompt_embeddings: torch.Tensor,
        hq_features: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Predicts masks. See 'forward' for more details."""
        # Concatenate output tokens
        output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight, self.hf_token.weight], dim=0)
        output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1)
        tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1)

        # Expand per-image data in batch direction to be per-mask
        src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)
        src = src + dense_prompt_embeddings
        pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0)
        b, c, h, w = src.shape

        # Run the transformer
        hs, src = self.transformer(src, pos_src, tokens)
        iou_token_out = hs[:, 0, :]
        mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :]

        # Upscale mask embeddings and predict masks using the mask tokens
        src = src.transpose(1, 2).view(b, c, h, w)

        upscaled_embedding_sam = self.output_upscaling(src)
        upscaled_embedding_hq = self.embedding_maskfeature(upscaled_embedding_sam) + hq_features.repeat(b, 1, 1, 1)

        hyper_in_list: List[torch.Tensor] = []
        for i in range(self.num_mask_tokens):
            if i < self.num_mask_tokens - 1:
                hyper_in_list.append(self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :]))
            else:
                hyper_in_list.append(self.hf_mlp(mask_tokens_out[:, i, :]))

        hyper_in = torch.stack(hyper_in_list, dim=1)
        b, c, h, w = upscaled_embedding_sam.shape

        masks_sam = (hyper_in[:, :self.num_mask_tokens - 1] @ upscaled_embedding_sam.view(b, c, h * w)).view(b, -1, h, w)
        masks_sam_hq = (hyper_in[:, self.num_mask_tokens - 1:] @ upscaled_embedding_hq.view(b, c, h * w)).view(b, -1, h, w)
        masks = torch.cat([masks_sam, masks_sam_hq], dim=1)
        # Generate mask quality predictions
        iou_pred = self.iou_prediction_head(iou_token_out)

        return masks, iou_pred


# Lightly adapted from
# https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py # noqa
class MLP(nn.Module):
    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        output_dim: int,
        num_layers: int,
        sigmoid_output: bool = False,
    ) -> None:
        super().__init__()
        self.num_layers = num_layers
        h = [hidden_dim] * (num_layers - 1)
        self.layers = nn.ModuleList(
            nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])
        )
        self.sigmoid_output = sigmoid_output

    def forward(self, x):
        for i, layer in enumerate(self.layers):
            x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
        if self.sigmoid_output:
            x = F.sigmoid(x)
        return x
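To make the slicing in forward() and predict_masks() easier to follow, here is a small sketch of the token layout and feature shapes with the defaults above (num_multimask_outputs=3, transformer_dim=256, a 64x64 image embedding); the shape comments are derived from the modules above, not stated explicitly in this commit:

# Token layout with num_multimask_outputs = 3:
#   4 SAM mask tokens (1 single-mask + 3 multimask), then the HQ token -> num_mask_tokens = 5.
num_mask_tokens = 3 + 1 + 1
multimask_slice = slice(1, num_mask_tokens - 1)            # channels 1..3: multimask outputs
hq_slice = slice(num_mask_tokens - 1, num_mask_tokens)     # channel 4: the HQ output

# HQ-Feature fusion (shapes for a 64x64 embedding, transformer_dim = 256):
#   image_embeddings (B, 256, 64, 64)     -> embedding_encoder  -> (B, 32, 256, 256)
#   vit_features     (B, vit_dim, 64, 64) -> compress_vit_feat  -> (B, 32, 256, 256)
#   hq_features is their sum; it is added to embedding_maskfeature(upscaled_embedding_sam)
#   before the HQ hypernetwork output produces masks_sam_hq.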
@@ -50,11 +50,11 @@ class Sam(nn.Module):
     def device(self) -> Any:
         return self.pixel_mean.device
 
-    @torch.no_grad()
     def forward(
         self,
         batched_input: List[Dict[str, Any]],
         multimask_output: bool,
+        hq_token_only: bool = False,
     ) -> List[Dict[str, torch.Tensor]]:
         """
         Predicts masks end-to-end from provided images and prompts.
@@ -95,10 +95,11 @@ class Sam(nn.Module):
                 to subsequent iterations of prediction.
         """
         input_images = torch.stack([self.preprocess(x["image"]) for x in batched_input], dim=0)
-        image_embeddings = self.image_encoder(input_images)
+        image_embeddings, interm_embeddings = self.image_encoder(input_images)
+        interm_embeddings = interm_embeddings[0]  # early layer
 
         outputs = []
-        for image_record, curr_embedding in zip(batched_input, image_embeddings):
+        for image_record, curr_embedding, curr_interm in zip(batched_input, image_embeddings, interm_embeddings):
             if "point_coords" in image_record:
                 points = (image_record["point_coords"], image_record["point_labels"])
             else:
@@ -114,6 +115,8 @@ class Sam(nn.Module):
                 sparse_prompt_embeddings=sparse_embeddings,
                 dense_prompt_embeddings=dense_embeddings,
                 multimask_output=multimask_output,
+                hq_token_only=hq_token_only,
+                interm_embeddings=curr_interm.unsqueeze(0).unsqueeze(0),
             )
             masks = self.postprocess_masks(
                 low_res_masks,
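A minimal sketch of the updated end-to-end call, assuming sam is a Sam instance built with the HQ decoder; the image tensor and box prompt are placeholders:

import torch

batched_input = [{
    'image': torch.zeros(3, 1024, 1024),                # placeholder, already resized for the model
    'boxes': torch.tensor([[100., 100., 400., 400.]]),  # example box prompt
    'original_size': (480, 640),
}]
outputs = sam(batched_input, multimask_output=False, hq_token_only=True)
masks = outputs[0]['masks']   # HQ-only masks, since hq_token_only=True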
@@ -7,7 +7,7 @@
 import numpy as np
 import torch
 
-from segment_anything.modeling import Sam
+from .modeling import Sam
 
 from typing import Optional, Tuple
 
@@ -49,10 +49,12 @@ class SamPredictor:
             "RGB",
             "BGR",
         ], f"image_format must be in ['RGB', 'BGR'], is {image_format}."
+        # import pdb;pdb.set_trace()
         if image_format != self.model.image_format:
             image = image[..., ::-1]
 
         # Transform the image to the form expected by the model
+        # import pdb;pdb.set_trace()
         input_image = self.transform.apply_image(image)
         input_image_torch = torch.as_tensor(input_image, device=self.device)
         input_image_torch = input_image_torch.permute(2, 0, 1).contiguous()[None, :, :, :]
@@ -86,7 +88,7 @@ class SamPredictor:
         self.original_size = original_image_size
         self.input_size = tuple(transformed_image.shape[-2:])
         input_image = self.model.preprocess(transformed_image)
-        self.features = self.model.image_encoder(input_image)
+        self.features, self.interm_features = self.model.image_encoder(input_image)
         self.is_image_set = True
 
     def predict(
@@ -97,6 +99,7 @@ class SamPredictor:
         mask_input: Optional[np.ndarray] = None,
         multimask_output: bool = True,
         return_logits: bool = False,
+        hq_token_only: bool = False,
     ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
         """
         Predict masks for the given input prompts, using the currently set image.
@@ -158,6 +161,7 @@ class SamPredictor:
             mask_input_torch,
             multimask_output,
             return_logits=return_logits,
+            hq_token_only=hq_token_only,
         )
 
         masks_np = masks[0].detach().cpu().numpy()
@@ -174,6 +178,7 @@ class SamPredictor:
         mask_input: Optional[torch.Tensor] = None,
         multimask_output: bool = True,
         return_logits: bool = False,
+        hq_token_only: bool = False,
     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         """
         Predict masks for the given input prompts, using the currently set image.
@@ -232,6 +237,8 @@ class SamPredictor:
             sparse_prompt_embeddings=sparse_embeddings,
             dense_prompt_embeddings=dense_embeddings,
             multimask_output=multimask_output,
+            hq_token_only=hq_token_only,
+            interm_embeddings=self.interm_features,
         )
 
         # Upscale the masks to the original image resolution
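The interactive path mirrors this. A minimal sketch of SamPredictor with the new flag; the image array and click coordinates are placeholders:

import numpy as np

predictor = SamPredictor(sam)        # sam: a Sam instance with the HQ decoder
predictor.set_image(image)           # image: HxWx3 uint8 array; also caches self.interm_features
masks, scores, logits = predictor.predict(
    point_coords=np.array([[320, 240]]),
    point_labels=np.array([1]),
    multimask_output=False,
    hq_token_only=True,              # return only the HQ mask instead of the SAM + HQ sum
)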
@@ -25,7 +25,8 @@ class SamOnnxModel(nn.Module):
     def __init__(
         self,
         model: Sam,
-        return_single_mask: bool,
+        hq_token_only: bool = False,
+        multimask_output: bool = False,
         use_stability_score: bool = False,
         return_extra_metrics: bool = False,
     ) -> None:
@@ -33,7 +34,8 @@ class SamOnnxModel(nn.Module):
         self.mask_decoder = model.mask_decoder
         self.model = model
         self.img_size = model.image_encoder.img_size
-        self.return_single_mask = return_single_mask
+        self.hq_token_only = hq_token_only
+        self.multimask_output = multimask_output
         self.use_stability_score = use_stability_score
         self.stability_score_offset = 1.0
         self.return_extra_metrics = return_extra_metrics
@@ -89,25 +91,12 @@ class SamOnnxModel(nn.Module):
         masks = F.interpolate(masks, size=(h, w), mode="bilinear", align_corners=False)
         return masks
 
-    def select_masks(
-        self, masks: torch.Tensor, iou_preds: torch.Tensor, num_points: int
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        # Determine if we should return the multiclick mask or not from the number of points.
-        # The reweighting is used to avoid control flow.
-        score_reweight = torch.tensor(
-            [[1000] + [0] * (self.model.mask_decoder.num_mask_tokens - 1)]
-        ).to(iou_preds.device)
-        score = iou_preds + (num_points - 2.5) * score_reweight
-        best_idx = torch.argmax(score, dim=1)
-        masks = masks[torch.arange(masks.shape[0]), best_idx, :, :].unsqueeze(1)
-        iou_preds = iou_preds[torch.arange(masks.shape[0]), best_idx].unsqueeze(1)
-
-        return masks, iou_preds
-
     @torch.no_grad()
     def forward(
         self,
         image_embeddings: torch.Tensor,
+        interm_embeddings: torch.Tensor,
         point_coords: torch.Tensor,
         point_labels: torch.Tensor,
         mask_input: torch.Tensor,
@@ -117,11 +106,15 @@ class SamOnnxModel(nn.Module):
         sparse_embedding = self._embed_points(point_coords, point_labels)
         dense_embedding = self._embed_masks(mask_input, has_mask_input)
 
+        vit_features = interm_embeddings[0].permute(0, 3, 1, 2)  # early-layer ViT feature, after 1st global attention block in ViT
+        hq_features = self.model.mask_decoder.embedding_encoder(image_embeddings) + self.model.mask_decoder.compress_vit_feat(vit_features)
+
         masks, scores = self.model.mask_decoder.predict_masks(
             image_embeddings=image_embeddings,
             image_pe=self.model.prompt_encoder.get_dense_pe(),
             sparse_prompt_embeddings=sparse_embedding,
             dense_prompt_embeddings=dense_embedding,
+            hq_features=hq_features,
         )
 
         if self.use_stability_score:
@@ -129,8 +122,26 @@ class SamOnnxModel(nn.Module):
                 masks, self.model.mask_threshold, self.stability_score_offset
             )
 
-        if self.return_single_mask:
-            masks, scores = self.select_masks(masks, scores, point_coords.shape[1])
+        if self.multimask_output:
+            # mask with highest score
+            mask_slice = slice(1, self.model.mask_decoder.num_mask_tokens - 1)
+            scores = scores[:, mask_slice]
+            scores, max_iou_idx = torch.max(scores, dim=1)
+            scores = scores.unsqueeze(1)
+            masks_multi = masks[:, mask_slice, :, :]
+            masks_sam = masks_multi[torch.arange(masks_multi.size(0)), max_iou_idx].unsqueeze(1)
+        else:
+            # single mask output, default
+            mask_slice = slice(0, 1)
+            scores = scores[:, mask_slice]
+            masks_sam = masks[:, mask_slice]
+
+        masks_hq = masks[:, slice(self.model.mask_decoder.num_mask_tokens - 1, self.model.mask_decoder.num_mask_tokens)]
+
+        if self.hq_token_only:
+            masks = masks_hq
+        else:
+            masks = masks_sam + masks_hq
+
         upscaled_masks = self.mask_postprocessing(masks, orig_im_size)
 
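For the ONNX wrapper, the old return_single_mask flag is replaced by multimask_output plus hq_token_only, and forward() now also takes interm_embeddings, so an exported graph would gain that extra input. A minimal construction sketch, assuming sam is a loaded HQ model:

onnx_model = SamOnnxModel(
    model=sam,
    hq_token_only=False,      # default: output the SAM mask plus the HQ correction, summed
    multimask_output=False,   # single-mask branch, matching the decoder's default
    use_stability_score=False,
)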
@@ -88,6 +88,8 @@ class Ui_MainWindow(object):
         self.menuMode.setObjectName("menuMode")
         self.menuContour_mode = QtWidgets.QMenu(self.menuMode)
         self.menuContour_mode.setObjectName("menuContour_mode")
+        self.menuSAM_model = QtWidgets.QMenu(self.menubar)
+        self.menuSAM_model.setObjectName("menuSAM_model")
         MainWindow.setMenuBar(self.menubar)
         self.statusbar = QtWidgets.QStatusBar(MainWindow)
         self.statusbar.setLayoutDirection(QtCore.Qt.LeftToRight)
@@ -352,6 +354,7 @@ class Ui_MainWindow(object):
         self.menubar.addAction(self.menuFile.menuAction())
         self.menubar.addAction(self.menuEdit.menuAction())
         self.menubar.addAction(self.menuView.menuAction())
+        self.menubar.addAction(self.menuSAM_model.menuAction())
         self.menubar.addAction(self.menuMode.menuAction())
         self.menubar.addAction(self.menuTools.menuAction())
         self.menubar.addAction(self.menuAbout.menuAction())
@@ -391,6 +394,7 @@ class Ui_MainWindow(object):
         self.menuEdit.setTitle(_translate("MainWindow", "Edit"))
         self.menuMode.setTitle(_translate("MainWindow", "Mode"))
         self.menuContour_mode.setTitle(_translate("MainWindow", "Contour mode"))
+        self.menuSAM_model.setTitle(_translate("MainWindow", "SAM"))
         self.toolBar.setWindowTitle(_translate("MainWindow", "toolBar"))
         self.info_dock.setWindowTitle(_translate("MainWindow", "Info"))
         self.annos_dock.setWindowTitle(_translate("MainWindow", "Annos"))
@@ -202,9 +202,15 @@
     </widget>
     <addaction name="menuContour_mode"/>
    </widget>
+   <widget class="QMenu" name="menuSAM_model">
+    <property name="title">
+     <string>SAM</string>
+    </property>
+   </widget>
    <addaction name="menuFile"/>
    <addaction name="menuEdit"/>
    <addaction name="menuView"/>
+   <addaction name="menuSAM_model"/>
    <addaction name="menuMode"/>
    <addaction name="menuTools"/>
    <addaction name="menuAbout"/>
@@ -34,8 +34,6 @@ class MainWindow(QtWidgets.QMainWindow, Ui_MainWindow):
     def __init__(self):
         super(MainWindow, self).__init__()
         self.setupUi(self)
-        self.init_ui()
-        self.init_segment_anything()
 
         self.image_root: str = None
         self.label_root:str = None
@@ -58,41 +56,51 @@ class MainWindow(QtWidgets.QMainWindow, Ui_MainWindow):
         self.map_mode = MAPMode.LABEL
         # annotation target
         self.current_label:Annotation = None
+        self.use_segment_anything = False
+
+        self.init_ui()
         self.reload_cfg()
 
         self.init_connect()
         self.reset_action()
 
-    def init_segment_anything(self):
-        if os.path.exists('./segment_any/sam_vit_h_4b8939.pth'):
-            self.statusbar.showMessage('Find the checkpoint named {}.'.format('sam_vit_h_4b8939.pth'))
-            self.segany = SegAny('./segment_any/sam_vit_h_4b8939.pth')
-            self.use_segment_anything = True
-        elif os.path.exists('./segment_any/sam_vit_l_0b3195.pth'):
-            self.statusbar.showMessage('Find the checkpoint named {}.'.format('sam_vit_l_0b3195.pth'))
-            self.segany = SegAny('./segment_any/sam_vit_l_0b3195.pth')
-            self.use_segment_anything = True
-        elif os.path.exists('./segment_any/sam_vit_b_01ec64.pth'):
-            self.statusbar.showMessage('Find the checkpoint named {}.'.format('sam_vit_b_01ec64.pth'))
-            self.segany = SegAny('./segment_any/sam_vit_b_01ec64.pth')
-            self.use_segment_anything = True
-        else:
-            QtWidgets.QMessageBox.warning(self, 'Warning', 'The checkpoint of [Segment anything] not existed. If you want use quick annotate, please download from {}'.format('https://github.com/facebookresearch/segment-anything#model-checkpoints'))
-            self.use_segment_anything = False
-
-        if self.use_segment_anything:
-            if self.segany.device != 'cpu':
-                self.gpu_resource_thread = GPUResource_Thread()
-                self.gpu_resource_thread.message.connect(self.labelGPUResource.setText)
-                self.gpu_resource_thread.start()
-            else:
-                self.labelGPUResource.setText('cpu')
-        else:
-            self.labelGPUResource.setText('segment anything unused.')
+    def init_segment_anything(self, model_name, reload=False):
+        if model_name == '':
+            self.use_segment_anything = False
+            for name, action in self.pths_actions.items():
+                action.setChecked(model_name == name)
+            return
+        model_path = os.path.join('segment_any', model_name)
+        if not os.path.exists(model_path):
+            QtWidgets.QMessageBox.warning(self, 'Warning',
+                                          'The checkpoint of [Segment anything] not existed. If you want use quick annotate, please download from {}'.format(
+                                              'https://github.com/facebookresearch/segment-anything#model-checkpoints'))
+            for name, action in self.pths_actions.items():
+                action.setChecked(model_name == name)
+            self.use_segment_anything = False
+            return
+
+        self.segany = SegAny(model_path)
+        self.use_segment_anything = True
+        self.statusbar.showMessage('Use the checkpoint named {}.'.format(model_name), 3000)
+        for name, action in self.pths_actions.items():
+            action.setChecked(model_name == name)
+        if not reload:
+            if self.use_segment_anything:
+                if self.segany.device != 'cpu':
+                    self.gpu_resource_thread = GPUResource_Thread()
+                    self.gpu_resource_thread.message.connect(self.labelGPUResource.setText)
+                    self.gpu_resource_thread.start()
+                else:
+                    self.labelGPUResource.setText('cpu')
+            else:
+                self.labelGPUResource.setText('segment anything unused.')
+
+        if reload and self.current_index is not None:
+            self.show_image(self.current_index)
 
     def init_ui(self):
-        #
+        #q
         self.setting_dialog = SettingDialog(parent=self, mainwindow=self)
 
         self.categories_dock_widget = CategoriesDockWidget(mainwindow=self)
@@ -144,6 +152,17 @@ class MainWindow(QtWidgets.QMainWindow, Ui_MainWindow):
         self.statusbar.addPermanentWidget(self.labelData)
 
         #
+        model_names = sorted([pth for pth in os.listdir('segment_any') if pth.endswith('.pth')])
+        self.pths_actions = {}
+        for model_name in model_names:
+            action = QtWidgets.QAction(self)
+            action.setObjectName("actionZoom_in")
+            action.triggered.connect(functools.partial(self.init_segment_anything, model_name))
+            action.setText("{}".format(model_name))
+            action.setCheckable(True)
+
+            self.pths_actions[model_name] = action
+            self.menuSAM_model.addAction(action)
+
         self.toolBar.addSeparator()
         self.mask_aplha = QtWidgets.QSlider(QtCore.Qt.Orientation.Horizontal, self)
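Each checkpoint action pre-binds its file name with functools.partial; a standalone sketch of that binding outside Qt (the file name is just an example). Because the actions are checkable, the checked state that QAction.triggered emits is what arrives in the reload parameter of init_segment_anything.

import functools

def init_segment_anything(model_name, reload=False):
    print('switching to', model_name, '| reload =', reload)

callback = functools.partial(init_segment_anything, 'sam_vit_b_01ec64.pth')
callback()       # -> switching to sam_vit_b_01ec64.pth | reload = False
callback(True)   # what a checked QAction.triggered(True) emission would produce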
@@ -213,6 +232,9 @@ class MainWindow(QtWidgets.QMainWindow, Ui_MainWindow):
         self.cfg['mask_alpha'] = mask_alpha
         self.mask_aplha.setValue(mask_alpha*10)
 
+        model_name = self.cfg.get('model_name', '')
+        self.init_segment_anything(model_name)
+
         self.categories_dock_widget.update_widget()
 
     def set_saved_state(self, is_saved:bool):