SAM-HQ support, UI updates, added model selection
parent db139d2c02
commit f7edb986d4
@@ -1,813 +1,12 @@
 {
     "info": {
         "description": "ISAT",
-        "folder": "C:/Users/lg/PycharmProjects/ISAT_with_segment_anything/example/images",
+        "folder": "/home/super/PycharmProjects/ISAT_with_segment_anything/example/images",
         "name": "000000000144.jpg",
         "width": 640,
         "height": 480,
         "depth": 3,
         "note": ""
     },
-    "objects": [
-        {
-            "category": "fence",
-            "group": "1",
-            "segmentation": [
-                [20, 239], [17, 240], [15, 241], [13, 242], [11, 244], [10, 244], [9, 245], [6, 249],
-                [5, 251], [4, 255], [0, 261], [0, 280], [1, 277], [6, 268], [9, 259], [13, 254],
-                [13, 252], [15, 250], [18, 248], [20, 247], [24, 247], [26, 248], [30, 252], [33, 256],
-                [35, 260], [38, 266], [41, 273], [43, 280], [50, 294], [53, 297], [54, 297], [59, 301],
-                [61, 303], [62, 316], [63, 341], [63, 382], [64, 405], [66, 414], [71, 413], [72, 402],
-                [72, 399], [71, 359], [69, 324], [70, 300], [73, 299], [78, 294], [80, 291], [81, 289],
-                [83, 283], [85, 278], [86, 272], [91, 258], [94, 252], [98, 248], [105, 248], [107, 249],
-                [111, 253], [116, 261], [118, 265], [121, 271], [128, 288], [130, 291], [135, 296], [138, 298],
-                [140, 299], [143, 300], [148, 300], [151, 299], [153, 298], [156, 296], [159, 293], [161, 290],
-                [164, 284], [179, 255], [180, 253], [180, 251], [181, 249], [183, 247], [186, 248], [187, 259],
-                [188, 282], [188, 303], [190, 320], [190, 343], [191, 355], [192, 357], [196, 358], [198, 357],
-                [200, 353], [200, 341], [198, 321], [197, 300], [196, 279], [195, 252], [195, 250], [196, 248],
-                [200, 248], [200, 249], [202, 251], [207, 260], [208, 262], [208, 264], [210, 266], [212, 273],
-                [215, 279], [217, 281], [218, 283], [219, 287], [221, 290], [225, 294], [227, 295], [230, 296],
-                [235, 296], [239, 295], [241, 294], [245, 290], [246, 288], [247, 286], [248, 283], [248, 279],
-                [245, 281], [244, 283], [244, 285], [243, 287], [238, 292], [230, 292], [228, 291], [224, 287],
-                [222, 284], [219, 278], [214, 268], [214, 266], [207, 252], [205, 249], [201, 245], [198, 243],
-                [194, 242], [188, 242], [183, 243], [180, 244], [176, 246], [174, 249], [172, 252], [170, 256],
-                [169, 260], [168, 262], [166, 264], [165, 266], [164, 268], [161, 276], [156, 286], [153, 290],
-                [152, 291], [151, 291], [149, 293], [147, 293], [144, 293], [142, 293], [140, 292], [138, 290],
-                [136, 287], [133, 281], [125, 263], [123, 259], [120, 253], [117, 248], [112, 243], [110, 242],
-                [107, 241], [98, 241], [96, 242], [94, 243], [89, 248], [85, 256], [83, 261], [81, 269],
-                [78, 274], [76, 282], [74, 287], [73, 289], [67, 295], [63, 295], [61, 294], [55, 288],
-                [53, 284], [50, 278], [49, 273], [41, 254], [40, 252], [38, 249], [35, 246], [35, 245],
-                [33, 243], [30, 241], [28, 240], [24, 239]
-            ],
-            "area": 4485.0,
-            "layer": 1.0,
-            "bbox": [0.0, 239.0, 248.0, 414.0],
-            "iscrowd": 0,
-            "note": ""
-        }
-    ]
+    "objects": []
 }
@@ -1,4 +1,4 @@
-contour_mode: external
+contour_mode: all
 label:
 - color: '#000000'
   name: __background__
@@ -1,4 +1,4 @@
-contour_mode: external
+contour_mode: all
 label:
 - color: '#000000'
   name: __background__
main.py (2 changes)
@@ -12,5 +12,5 @@ if __name__ == '__main__':
     app = QtWidgets.QApplication([''])
     mainwindow = MainWindow()
     mainwindow.show()
-    sys.exit(app.exec_())
+    sys.exit(app.exec())
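A note on the one-line main.py change: `exec` was a reserved word in Python 2, so Qt bindings exposed the event loop as `exec_()`; under Python 3, PyQt5 also accepts plain `exec()`, which is the only spelling kept by PyQt6 and PySide6. A minimal sketch of the same entry point (assuming PyQt5, as the surrounding code suggests; the QLabel is a stand-in for the app's MainWindow):

    import sys
    from PyQt5 import QtWidgets

    if __name__ == '__main__':
        app = QtWidgets.QApplication([''])
        window = QtWidgets.QLabel('placeholder window')  # stand-in for MainWindow
        window.show()
        sys.exit(app.exec())  # equivalent to app.exec_() in PyQt5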
@@ -16,6 +16,7 @@ class SegAny:
             self.model_type = "vit_h"
         else:
             raise ValueError('The checkpoint named {} is not supported.'.format(checkpoint))
+        torch.cuda.empty_cache()

         self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
         sam = sam_model_registry[self.model_type](checkpoint=checkpoint)
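The added `torch.cuda.empty_cache()` supports the model-selection feature introduced further down: each time the user picks a different checkpoint a fresh SegAny is constructed, and flushing the CUDA caching allocator first frees room for the new weights. A sketch of that reload pattern, under the assumption that the old instance is simply dropped (the helper itself is hypothetical; `model_factory` would be e.g. this repo's SegAny class):

    import gc
    import torch

    def switch_checkpoint(window, checkpoint_path, model_factory):
        if getattr(window, 'segany', None) is not None:
            del window.segany        # drop the old model's tensors
        gc.collect()                 # let Python release them first
        torch.cuda.empty_cache()     # hand cached blocks back to the driver
        window.segany = model_factory(checkpoint_path)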
@@ -11,5 +11,6 @@ from .build_sam import (
     build_sam_vit_b,
     sam_model_registry,
 )
+from .build_sam_baseline import sam_model_registry_baseline
 from .predictor import SamPredictor
 from .automatic_mask_generator import SamAutomaticMaskGenerator
@@ -134,7 +134,7 @@ class SamAutomaticMaskGenerator:
         self.output_mode = output_mode

     @torch.no_grad()
-    def generate(self, image: np.ndarray) -> List[Dict[str, Any]]:
+    def generate(self, image: np.ndarray, multimask_output: bool = True) -> List[Dict[str, Any]]:
         """
         Generates masks for the given image.

@@ -160,7 +160,7 @@ class SamAutomaticMaskGenerator:
         """

         # Generate masks
-        mask_data = self._generate_masks(image)
+        mask_data = self._generate_masks(image, multimask_output)

         # Filter small disconnected regions and holes in masks
         if self.min_mask_region_area > 0:
@@ -194,7 +194,7 @@ class SamAutomaticMaskGenerator:

         return curr_anns

-    def _generate_masks(self, image: np.ndarray) -> MaskData:
+    def _generate_masks(self, image: np.ndarray, multimask_output: bool = True) -> MaskData:
         orig_size = image.shape[:2]
         crop_boxes, layer_idxs = generate_crop_boxes(
             orig_size, self.crop_n_layers, self.crop_overlap_ratio
@@ -203,7 +203,7 @@ class SamAutomaticMaskGenerator:
         # Iterate over image crops
         data = MaskData()
         for crop_box, layer_idx in zip(crop_boxes, layer_idxs):
-            crop_data = self._process_crop(image, crop_box, layer_idx, orig_size)
+            crop_data = self._process_crop(image, crop_box, layer_idx, orig_size, multimask_output)
             data.cat(crop_data)

         # Remove duplicate masks between crops
@@ -228,6 +228,7 @@ class SamAutomaticMaskGenerator:
         crop_box: List[int],
         crop_layer_idx: int,
         orig_size: Tuple[int, ...],
+        multimask_output: bool = True,
     ) -> MaskData:
         # Crop the image and calculate embeddings
         x0, y0, x1, y1 = crop_box
@@ -242,7 +243,7 @@ class SamAutomaticMaskGenerator:
         # Generate masks for this crop in batches
         data = MaskData()
         for (points,) in batch_iterator(self.points_per_batch, points_for_image):
-            batch_data = self._process_batch(points, cropped_im_size, crop_box, orig_size)
+            batch_data = self._process_batch(points, cropped_im_size, crop_box, orig_size, multimask_output)
             data.cat(batch_data)
             del batch_data
         self.predictor.reset_image()
@@ -269,6 +270,7 @@ class SamAutomaticMaskGenerator:
         im_size: Tuple[int, ...],
         crop_box: List[int],
         orig_size: Tuple[int, ...],
+        multimask_output: bool = True,
     ) -> MaskData:
         orig_h, orig_w = orig_size

@@ -279,7 +281,7 @@ class SamAutomaticMaskGenerator:
         masks, iou_preds, _ = self.predictor.predict_torch(
             in_points[:, None, :],
             in_labels[:, None],
-            multimask_output=True,
+            multimask_output=multimask_output,
             return_logits=True,
         )
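Threading `multimask_output` from `generate` down through `_generate_masks`, `_process_crop`, and `_process_batch` lets callers request one mask per sampled point instead of three candidates. A hedged usage sketch (the checkpoint path is illustrative, not shipped with the repo):

    import cv2
    from segment_anything import sam_model_registry, SamAutomaticMaskGenerator

    sam = sam_model_registry['vit_b'](checkpoint='segment_any/sam_vit_b_01ec64.pth')  # illustrative path
    generator = SamAutomaticMaskGenerator(sam)
    image = cv2.cvtColor(cv2.imread('example/images/000000000144.jpg'), cv2.COLOR_BGR2RGB)
    masks = generator.generate(image, multimask_output=False)  # single mask per sampled point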
@@ -8,7 +8,7 @@ import torch

 from functools import partial

-from .modeling import ImageEncoderViT, MaskDecoder, PromptEncoder, Sam, TwoWayTransformer
+from .modeling import ImageEncoderViT, MaskDecoderHQ, PromptEncoder, Sam, TwoWayTransformer


 def build_sam_vit_h(checkpoint=None):
@@ -84,7 +84,7 @@ def _build_sam(
             input_image_size=(image_size, image_size),
             mask_in_chans=16,
         ),
-        mask_decoder=MaskDecoder(
+        mask_decoder=MaskDecoderHQ(
             num_multimask_outputs=3,
             transformer=TwoWayTransformer(
                 depth=2,
@@ -95,13 +95,19 @@ def _build_sam(
             transformer_dim=prompt_embed_dim,
             iou_head_depth=3,
             iou_head_hidden_dim=256,
+            vit_dim=encoder_embed_dim,
         ),
         pixel_mean=[123.675, 116.28, 103.53],
         pixel_std=[58.395, 57.12, 57.375],
     )
-    sam.eval()
+    # sam.eval()
     if checkpoint is not None:
         with open(checkpoint, "rb") as f:
             state_dict = torch.load(f)
-        sam.load_state_dict(state_dict)
+        info = sam.load_state_dict(state_dict, strict=False)
+        print(info)
+    for n, p in sam.named_parameters():
+        if 'hf_token' not in n and 'hf_mlp' not in n and 'compress_vit_feat' not in n and 'embedding_encoder' not in n and 'embedding_maskfeature' not in n:
+            p.requires_grad = False

     return sam
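Two effects of this hunk are easy to miss: loading with `strict=False` means a vanilla SAM checkpoint can be poured into the HQ architecture (the HQ-only modules then appear in the printed `missing_keys`), and the loop afterwards freezes everything except the five HQ-specific module groups, which matches SAM-HQ's fine-tuning setup. A quick hedged check of what stays trainable (checkpoint path illustrative):

    from segment_anything import sam_model_registry

    sam = sam_model_registry['vit_b'](checkpoint='segment_any/sam_hq_vit_b.pth')  # illustrative
    trainable = [n for n, p in sam.named_parameters() if p.requires_grad]
    # Expect only hf_token / hf_mlp / compress_vit_feat /
    # embedding_encoder / embedding_maskfeature entries here:
    print(len(trainable), trainable[:3])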
segment_anything/build_sam_baseline.py (new file, 107 lines)
@@ -0,0 +1,107 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import torch

from functools import partial

from .modeling import ImageEncoderViT, MaskDecoder, PromptEncoder, Sam, TwoWayTransformer


def build_sam_vit_h(checkpoint=None):
    return _build_sam(
        encoder_embed_dim=1280,
        encoder_depth=32,
        encoder_num_heads=16,
        encoder_global_attn_indexes=[7, 15, 23, 31],
        checkpoint=checkpoint,
    )


build_sam = build_sam_vit_h


def build_sam_vit_l(checkpoint=None):
    return _build_sam(
        encoder_embed_dim=1024,
        encoder_depth=24,
        encoder_num_heads=16,
        encoder_global_attn_indexes=[5, 11, 17, 23],
        checkpoint=checkpoint,
    )


def build_sam_vit_b(checkpoint=None):
    return _build_sam(
        encoder_embed_dim=768,
        encoder_depth=12,
        encoder_num_heads=12,
        encoder_global_attn_indexes=[2, 5, 8, 11],
        checkpoint=checkpoint,
    )


sam_model_registry_baseline = {
    "default": build_sam_vit_h,
    "vit_h": build_sam_vit_h,
    "vit_l": build_sam_vit_l,
    "vit_b": build_sam_vit_b,
}


def _build_sam(
    encoder_embed_dim,
    encoder_depth,
    encoder_num_heads,
    encoder_global_attn_indexes,
    checkpoint=None,
):
    prompt_embed_dim = 256
    image_size = 1024
    vit_patch_size = 16
    image_embedding_size = image_size // vit_patch_size
    sam = Sam(
        image_encoder=ImageEncoderViT(
            depth=encoder_depth,
            embed_dim=encoder_embed_dim,
            img_size=image_size,
            mlp_ratio=4,
            norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
            num_heads=encoder_num_heads,
            patch_size=vit_patch_size,
            qkv_bias=True,
            use_rel_pos=True,
            global_attn_indexes=encoder_global_attn_indexes,
            window_size=14,
            out_chans=prompt_embed_dim,
        ),
        prompt_encoder=PromptEncoder(
            embed_dim=prompt_embed_dim,
            image_embedding_size=(image_embedding_size, image_embedding_size),
            input_image_size=(image_size, image_size),
            mask_in_chans=16,
        ),
        mask_decoder=MaskDecoder(
            num_multimask_outputs=3,
            transformer=TwoWayTransformer(
                depth=2,
                embedding_dim=prompt_embed_dim,
                mlp_dim=2048,
                num_heads=8,
            ),
            transformer_dim=prompt_embed_dim,
            iou_head_depth=3,
            iou_head_hidden_dim=256,
        ),
        pixel_mean=[123.675, 116.28, 103.53],
        pixel_std=[58.395, 57.12, 57.375],
    )
    sam.eval()
    if checkpoint is not None:
        with open(checkpoint, "rb") as f:
            state_dict = torch.load(f)
        sam.load_state_dict(state_dict)
    return sam
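With both builders exported, the caller can presumably pick the decoder family to match the weights at hand: `sam_model_registry` now constructs the HQ variant (MaskDecoderHQ, tolerant `strict=False` loading), while `sam_model_registry_baseline` reproduces the unmodified upstream builder for plain SAM checkpoints. A small selection sketch (the helper and paths are illustrative):

    from segment_anything import sam_model_registry, sam_model_registry_baseline

    def build(model_type, checkpoint, hq=True):
        registry = sam_model_registry if hq else sam_model_registry_baseline
        return registry[model_type](checkpoint=checkpoint)

    sam = build('vit_b', 'segment_any/sam_vit_b_01ec64.pth', hq=False)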
@@ -6,6 +6,7 @@

 from .sam import Sam
 from .image_encoder import ImageEncoderViT
+from .mask_decoder_hq import MaskDecoderHQ
 from .mask_decoder import MaskDecoder
 from .prompt_encoder import PromptEncoder
 from .transformer import TwoWayTransformer
@@ -108,12 +108,15 @@ class ImageEncoderViT(nn.Module):
         if self.pos_embed is not None:
             x = x + self.pos_embed

+        interm_embeddings = []
         for blk in self.blocks:
             x = blk(x)
+            if blk.window_size == 0:
+                interm_embeddings.append(x)

         x = self.neck(x.permute(0, 3, 1, 2))

-        return x
+        return x, interm_embeddings


 class Block(nn.Module):
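Because the encoder now returns a 2-tuple, every call site must unpack it, which is exactly what the sam.py and predictor.py hunks below do. For the stock configurations one embedding is collected per global-attention block (`window_size == 0`), i.e. four tensors of shape (B, H/16, W/16, embed_dim), of which the HQ decoder consumes only the first:

    features, interm = image_encoder(pixel_batch)   # was: features = image_encoder(pixel_batch)
    early = interm[0].permute(0, 3, 1, 2)           # (B, embed_dim, H/16, W/16), the layout the HQ decoder expects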
@@ -75,6 +75,8 @@ class MaskDecoder(nn.Module):
         sparse_prompt_embeddings: torch.Tensor,
         dense_prompt_embeddings: torch.Tensor,
         multimask_output: bool,
+        hq_token_only: bool,
+        interm_embeddings: torch.Tensor,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """
         Predict masks given image and prompt embeddings.
segment_anything/modeling/mask_decoder_hq.py (new file, 232 lines)
@@ -0,0 +1,232 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# Modified by HQ-SAM team
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import torch
from torch import nn
from torch.nn import functional as F

from typing import List, Tuple, Type

from .common import LayerNorm2d


class MaskDecoderHQ(nn.Module):
    def __init__(
        self,
        *,
        transformer_dim: int,
        transformer: nn.Module,
        num_multimask_outputs: int = 3,
        activation: Type[nn.Module] = nn.GELU,
        iou_head_depth: int = 3,
        iou_head_hidden_dim: int = 256,
        vit_dim: int = 1024,
    ) -> None:
        """
        Predicts masks given an image and prompt embeddings, using a
        transformer architecture.

        Arguments:
          transformer_dim (int): the channel dimension of the transformer
          transformer (nn.Module): the transformer used to predict masks
          num_multimask_outputs (int): the number of masks to predict
            when disambiguating masks
          activation (nn.Module): the type of activation to use when
            upscaling masks
          iou_head_depth (int): the depth of the MLP used to predict
            mask quality
          iou_head_hidden_dim (int): the hidden dimension of the MLP
            used to predict mask quality
        """
        super().__init__()
        self.transformer_dim = transformer_dim
        self.transformer = transformer

        self.num_multimask_outputs = num_multimask_outputs

        self.iou_token = nn.Embedding(1, transformer_dim)
        self.num_mask_tokens = num_multimask_outputs + 1
        self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim)

        self.output_upscaling = nn.Sequential(
            nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2),
            LayerNorm2d(transformer_dim // 4),
            activation(),
            nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2),
            activation(),
        )
        self.output_hypernetworks_mlps = nn.ModuleList(
            [
                MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3)
                for i in range(self.num_mask_tokens)
            ]
        )

        self.iou_prediction_head = MLP(
            transformer_dim, iou_head_hidden_dim, self.num_mask_tokens, iou_head_depth
        )

        # HQ-SAM parameters
        self.hf_token = nn.Embedding(1, transformer_dim)  # HQ-Output-Token
        self.hf_mlp = MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3)  # corresponding new MLP layer for the HQ-Output-Token
        self.num_mask_tokens = self.num_mask_tokens + 1

        # three conv fusion layers for obtaining the HQ-Feature
        self.compress_vit_feat = nn.Sequential(
            nn.ConvTranspose2d(vit_dim, transformer_dim, kernel_size=2, stride=2),
            LayerNorm2d(transformer_dim),
            nn.GELU(),
            nn.ConvTranspose2d(transformer_dim, transformer_dim // 8, kernel_size=2, stride=2))

        self.embedding_encoder = nn.Sequential(
            nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2),
            LayerNorm2d(transformer_dim // 4),
            nn.GELU(),
            nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2),
        )
        self.embedding_maskfeature = nn.Sequential(
            nn.Conv2d(transformer_dim // 8, transformer_dim // 4, 3, 1, 1),
            LayerNorm2d(transformer_dim // 4),
            nn.GELU(),
            nn.Conv2d(transformer_dim // 4, transformer_dim // 8, 3, 1, 1))

    def forward(
        self,
        image_embeddings: torch.Tensor,
        image_pe: torch.Tensor,
        sparse_prompt_embeddings: torch.Tensor,
        dense_prompt_embeddings: torch.Tensor,
        multimask_output: bool,
        hq_token_only: bool,
        interm_embeddings: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Predict masks given image and prompt embeddings.

        Arguments:
          image_embeddings (torch.Tensor): the embeddings from the ViT image encoder
          image_pe (torch.Tensor): positional encoding with the shape of image_embeddings
          sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes
          dense_prompt_embeddings (torch.Tensor): the embeddings of the mask inputs
          multimask_output (bool): Whether to return multiple masks or a single
            mask.

        Returns:
          torch.Tensor: batched predicted masks
          torch.Tensor: batched predictions of mask quality
        """
        vit_features = interm_embeddings[0].permute(0, 3, 1, 2)  # early-layer ViT feature, after the 1st global attention block in the ViT
        hq_features = self.embedding_encoder(image_embeddings) + self.compress_vit_feat(vit_features)

        masks, iou_pred = self.predict_masks(
            image_embeddings=image_embeddings,
            image_pe=image_pe,
            sparse_prompt_embeddings=sparse_prompt_embeddings,
            dense_prompt_embeddings=dense_prompt_embeddings,
            hq_features=hq_features,
        )

        # Select the correct mask or masks for output
        if multimask_output:
            # mask with the highest score
            mask_slice = slice(1, self.num_mask_tokens - 1)
            iou_pred = iou_pred[:, mask_slice]
            iou_pred, max_iou_idx = torch.max(iou_pred, dim=1)
            iou_pred = iou_pred.unsqueeze(1)
            masks_multi = masks[:, mask_slice, :, :]
            masks_sam = masks_multi[torch.arange(masks_multi.size(0)), max_iou_idx].unsqueeze(1)
        else:
            # single mask output, default
            mask_slice = slice(0, 1)
            iou_pred = iou_pred[:, mask_slice]
            masks_sam = masks[:, mask_slice]

        masks_hq = masks[:, slice(self.num_mask_tokens - 1, self.num_mask_tokens)]
        if hq_token_only:
            masks = masks_hq
        else:
            masks = masks_sam + masks_hq
        # Prepare output
        return masks, iou_pred

    def predict_masks(
        self,
        image_embeddings: torch.Tensor,
        image_pe: torch.Tensor,
        sparse_prompt_embeddings: torch.Tensor,
        dense_prompt_embeddings: torch.Tensor,
        hq_features: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Predicts masks. See 'forward' for more details."""
        # Concatenate output tokens
        output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight, self.hf_token.weight], dim=0)
        output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1)
        tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1)

        # Expand per-image data in the batch direction to be per-mask
        src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)
        src = src + dense_prompt_embeddings
        pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0)
        b, c, h, w = src.shape

        # Run the transformer
        hs, src = self.transformer(src, pos_src, tokens)
        iou_token_out = hs[:, 0, :]
        mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :]

        # Upscale mask embeddings and predict masks using the mask tokens
        src = src.transpose(1, 2).view(b, c, h, w)

        upscaled_embedding_sam = self.output_upscaling(src)
        upscaled_embedding_hq = self.embedding_maskfeature(upscaled_embedding_sam) + hq_features.repeat(b, 1, 1, 1)

        hyper_in_list: List[torch.Tensor] = []
        for i in range(self.num_mask_tokens):
            if i < self.num_mask_tokens - 1:
                hyper_in_list.append(self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :]))
            else:
                hyper_in_list.append(self.hf_mlp(mask_tokens_out[:, i, :]))

        hyper_in = torch.stack(hyper_in_list, dim=1)
        b, c, h, w = upscaled_embedding_sam.shape

        masks_sam = (hyper_in[:, : self.num_mask_tokens - 1] @ upscaled_embedding_sam.view(b, c, h * w)).view(b, -1, h, w)
        masks_sam_hq = (hyper_in[:, self.num_mask_tokens - 1 :] @ upscaled_embedding_hq.view(b, c, h * w)).view(b, -1, h, w)
        masks = torch.cat([masks_sam, masks_sam_hq], dim=1)
        # Generate mask quality predictions
        iou_pred = self.iou_prediction_head(iou_token_out)

        return masks, iou_pred


# Lightly adapted from
# https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py # noqa
class MLP(nn.Module):
    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        output_dim: int,
        num_layers: int,
        sigmoid_output: bool = False,
    ) -> None:
        super().__init__()
        self.num_layers = num_layers
        h = [hidden_dim] * (num_layers - 1)
        self.layers = nn.ModuleList(
            nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])
        )
        self.sigmoid_output = sigmoid_output

    def forward(self, x):
        for i, layer in enumerate(self.layers):
            x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
        if self.sigmoid_output:
            x = F.sigmoid(x)
        return x
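The slicing in forward() follows from the token layout: with `num_multimask_outputs=3` the decoder ends up with `num_mask_tokens = 5` mask channels — index 0 is SAM's single-mask output, indices 1–3 its multimask candidates, and index 4 the new HQ token. A small illustration of the indexing (shapes are illustrative):

    import torch

    masks = torch.randn(2, 5, 256, 256)   # (batch, num_mask_tokens, H, W)
    masks_sam_multi = masks[:, 1:4]       # slice(1, num_mask_tokens - 1)
    masks_hq = masks[:, 4:5]              # slice(num_mask_tokens - 1, num_mask_tokens)
    # forward() returns masks_hq alone when hq_token_only=True,
    # otherwise the selected SAM mask plus masks_hq, summed in logit space.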
@@ -50,11 +50,11 @@ class Sam(nn.Module):
     def device(self) -> Any:
         return self.pixel_mean.device

-    @torch.no_grad()
     def forward(
         self,
         batched_input: List[Dict[str, Any]],
         multimask_output: bool,
+        hq_token_only: bool = False,
     ) -> List[Dict[str, torch.Tensor]]:
         """
         Predicts masks end-to-end from provided images and prompts.
@@ -95,10 +95,11 @@ class Sam(nn.Module):
           to subsequent iterations of prediction.
         """
         input_images = torch.stack([self.preprocess(x["image"]) for x in batched_input], dim=0)
-        image_embeddings = self.image_encoder(input_images)
+        image_embeddings, interm_embeddings = self.image_encoder(input_images)
+        interm_embeddings = interm_embeddings[0]  # early layer

         outputs = []
-        for image_record, curr_embedding in zip(batched_input, image_embeddings):
+        for image_record, curr_embedding, curr_interm in zip(batched_input, image_embeddings, interm_embeddings):
             if "point_coords" in image_record:
                 points = (image_record["point_coords"], image_record["point_labels"])
             else:
@@ -114,6 +115,8 @@ class Sam(nn.Module):
                 sparse_prompt_embeddings=sparse_embeddings,
                 dense_prompt_embeddings=dense_embeddings,
                 multimask_output=multimask_output,
+                hq_token_only=hq_token_only,
+                interm_embeddings=curr_interm.unsqueeze(0).unsqueeze(0),
             )
             masks = self.postprocess_masks(
                 low_res_masks,
@@ -7,7 +7,7 @@
 import numpy as np
 import torch

-from segment_anything.modeling import Sam
+from .modeling import Sam

 from typing import Optional, Tuple

@@ -49,10 +49,12 @@ class SamPredictor:
             "RGB",
             "BGR",
         ], f"image_format must be in ['RGB', 'BGR'], is {image_format}."
+        # import pdb;pdb.set_trace()
         if image_format != self.model.image_format:
             image = image[..., ::-1]

         # Transform the image to the form expected by the model
+        # import pdb;pdb.set_trace()
         input_image = self.transform.apply_image(image)
         input_image_torch = torch.as_tensor(input_image, device=self.device)
         input_image_torch = input_image_torch.permute(2, 0, 1).contiguous()[None, :, :, :]
@@ -86,7 +88,7 @@ class SamPredictor:
         self.original_size = original_image_size
         self.input_size = tuple(transformed_image.shape[-2:])
         input_image = self.model.preprocess(transformed_image)
-        self.features = self.model.image_encoder(input_image)
+        self.features, self.interm_features = self.model.image_encoder(input_image)
         self.is_image_set = True

     def predict(
@@ -97,6 +99,7 @@ class SamPredictor:
         mask_input: Optional[np.ndarray] = None,
         multimask_output: bool = True,
         return_logits: bool = False,
+        hq_token_only: bool = False,
     ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
         """
         Predict masks for the given input prompts, using the currently set image.
@@ -158,6 +161,7 @@ class SamPredictor:
             mask_input_torch,
             multimask_output,
             return_logits=return_logits,
+            hq_token_only=hq_token_only,
         )

         masks_np = masks[0].detach().cpu().numpy()
@@ -174,6 +178,7 @@ class SamPredictor:
         mask_input: Optional[torch.Tensor] = None,
         multimask_output: bool = True,
         return_logits: bool = False,
+        hq_token_only: bool = False,
     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         """
         Predict masks for the given input prompts, using the currently set image.
@@ -232,6 +237,8 @@ class SamPredictor:
             sparse_prompt_embeddings=sparse_embeddings,
             dense_prompt_embeddings=dense_embeddings,
             multimask_output=multimask_output,
+            hq_token_only=hq_token_only,
+            interm_embeddings=self.interm_features,
         )

         # Upscale the masks to the original image resolution
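Taken together, the predictor now caches the early ViT features once per `set_image` and reuses them for every prompt, so interactive clicking stays as cheap as before. A hedged end-to-end sketch (checkpoint path illustrative):

    import cv2
    import numpy as np
    from segment_anything import sam_model_registry, SamPredictor

    sam = sam_model_registry['vit_h'](checkpoint='segment_any/sam_hq_vit_h.pth')  # illustrative
    predictor = SamPredictor(sam)
    image = cv2.cvtColor(cv2.imread('example/images/000000000144.jpg'), cv2.COLOR_BGR2RGB)
    predictor.set_image(image)                      # encodes once, caches interm_features too
    masks, scores, logits = predictor.predict(
        point_coords=np.array([[320, 240]]),
        point_labels=np.array([1]),
        multimask_output=False,
        hq_token_only=True,                         # return only the high-quality mask
    )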
@@ -25,7 +25,8 @@ class SamOnnxModel(nn.Module):
     def __init__(
         self,
         model: Sam,
-        return_single_mask: bool,
+        hq_token_only: bool = False,
+        multimask_output: bool = False,
         use_stability_score: bool = False,
         return_extra_metrics: bool = False,
     ) -> None:
@@ -33,7 +34,8 @@ class SamOnnxModel(nn.Module):
         self.mask_decoder = model.mask_decoder
         self.model = model
         self.img_size = model.image_encoder.img_size
-        self.return_single_mask = return_single_mask
+        self.hq_token_only = hq_token_only
+        self.multimask_output = multimask_output
         self.use_stability_score = use_stability_score
         self.stability_score_offset = 1.0
         self.return_extra_metrics = return_extra_metrics
@@ -89,25 +91,12 @@ class SamOnnxModel(nn.Module):
         masks = F.interpolate(masks, size=(h, w), mode="bilinear", align_corners=False)
         return masks

-    def select_masks(
-        self, masks: torch.Tensor, iou_preds: torch.Tensor, num_points: int
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        # Determine if we should return the multiclick mask or not from the number of points.
-        # The reweighting is used to avoid control flow.
-        score_reweight = torch.tensor(
-            [[1000] + [0] * (self.model.mask_decoder.num_mask_tokens - 1)]
-        ).to(iou_preds.device)
-        score = iou_preds + (num_points - 2.5) * score_reweight
-        best_idx = torch.argmax(score, dim=1)
-        masks = masks[torch.arange(masks.shape[0]), best_idx, :, :].unsqueeze(1)
-        iou_preds = iou_preds[torch.arange(masks.shape[0]), best_idx].unsqueeze(1)
-
-        return masks, iou_preds
-
     @torch.no_grad()
     def forward(
         self,
         image_embeddings: torch.Tensor,
+        interm_embeddings: torch.Tensor,
         point_coords: torch.Tensor,
         point_labels: torch.Tensor,
         mask_input: torch.Tensor,
@@ -117,11 +106,15 @@ class SamOnnxModel(nn.Module):
         sparse_embedding = self._embed_points(point_coords, point_labels)
         dense_embedding = self._embed_masks(mask_input, has_mask_input)

+        vit_features = interm_embeddings[0].permute(0, 3, 1, 2)  # early-layer ViT feature, after the 1st global attention block in the ViT
+        hq_features = self.model.mask_decoder.embedding_encoder(image_embeddings) + self.model.mask_decoder.compress_vit_feat(vit_features)
+
         masks, scores = self.model.mask_decoder.predict_masks(
             image_embeddings=image_embeddings,
             image_pe=self.model.prompt_encoder.get_dense_pe(),
             sparse_prompt_embeddings=sparse_embedding,
             dense_prompt_embeddings=dense_embedding,
+            hq_features=hq_features,
         )

         if self.use_stability_score:
@@ -129,8 +122,26 @@ class SamOnnxModel(nn.Module):
                 masks, self.model.mask_threshold, self.stability_score_offset
             )

-        if self.return_single_mask:
-            masks, scores = self.select_masks(masks, scores, point_coords.shape[1])
+        if self.multimask_output:
+            # mask with the highest score
+            mask_slice = slice(1, self.model.mask_decoder.num_mask_tokens - 1)
+            scores = scores[:, mask_slice]
+            scores, max_iou_idx = torch.max(scores, dim=1)
+            scores = scores.unsqueeze(1)
+            masks_multi = masks[:, mask_slice, :, :]
+            masks_sam = masks_multi[torch.arange(masks_multi.size(0)), max_iou_idx].unsqueeze(1)
+        else:
+            # single mask output, default
+            mask_slice = slice(0, 1)
+            scores = scores[:, mask_slice]
+            masks_sam = masks[:, mask_slice]
+
+        masks_hq = masks[:, slice(self.model.mask_decoder.num_mask_tokens - 1, self.model.mask_decoder.num_mask_tokens)]
+
+        if self.hq_token_only:
+            masks = masks_hq
+        else:
+            masks = masks_sam + masks_hq

         upscaled_masks = self.mask_postprocessing(masks, orig_im_size)
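For export, the wrapper now takes the extra `interm_embeddings` graph input and bakes the mask choice in via `multimask_output`/`hq_token_only` instead of the removed runtime `select_masks` reweighting. A hedged export sketch — shapes assume the ViT-B/1024-pixel configuration, `sam` is a model built as above, and the file name is illustrative:

    import torch
    from segment_anything.utils.onnx import SamOnnxModel  # same path as upstream SAM

    onnx_model = SamOnnxModel(model=sam, hq_token_only=True, multimask_output=False)
    dummy_inputs = {
        'image_embeddings': torch.randn(1, 256, 64, 64),
        'interm_embeddings': torch.randn(1, 1, 64, 64, 768),  # indexed as interm_embeddings[0]
        'point_coords': torch.randint(0, 1024, (1, 2, 2)).float(),
        'point_labels': torch.tensor([[1.0, -1.0]]),          # one click plus a padding point
        'mask_input': torch.randn(1, 1, 256, 256),
        'has_mask_input': torch.tensor([0.0]),
        'orig_im_size': torch.tensor([480.0, 640.0]),
    }
    torch.onnx.export(onnx_model, tuple(dummy_inputs.values()), 'sam_hq.onnx',
                      input_names=list(dummy_inputs.keys()), opset_version=17)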
@@ -88,6 +88,8 @@ class Ui_MainWindow(object):
         self.menuMode.setObjectName("menuMode")
         self.menuContour_mode = QtWidgets.QMenu(self.menuMode)
         self.menuContour_mode.setObjectName("menuContour_mode")
+        self.menuSAM_model = QtWidgets.QMenu(self.menubar)
+        self.menuSAM_model.setObjectName("menuSAM_model")
         MainWindow.setMenuBar(self.menubar)
         self.statusbar = QtWidgets.QStatusBar(MainWindow)
         self.statusbar.setLayoutDirection(QtCore.Qt.LeftToRight)
@@ -352,6 +354,7 @@ class Ui_MainWindow(object):
         self.menubar.addAction(self.menuFile.menuAction())
         self.menubar.addAction(self.menuEdit.menuAction())
         self.menubar.addAction(self.menuView.menuAction())
+        self.menubar.addAction(self.menuSAM_model.menuAction())
         self.menubar.addAction(self.menuMode.menuAction())
         self.menubar.addAction(self.menuTools.menuAction())
         self.menubar.addAction(self.menuAbout.menuAction())
@@ -391,6 +394,7 @@ class Ui_MainWindow(object):
         self.menuEdit.setTitle(_translate("MainWindow", "Edit"))
         self.menuMode.setTitle(_translate("MainWindow", "Mode"))
         self.menuContour_mode.setTitle(_translate("MainWindow", "Contour mode"))
+        self.menuSAM_model.setTitle(_translate("MainWindow", "SAM"))
         self.toolBar.setWindowTitle(_translate("MainWindow", "toolBar"))
         self.info_dock.setWindowTitle(_translate("MainWindow", "Info"))
         self.annos_dock.setWindowTitle(_translate("MainWindow", "Annos"))
@@ -202,9 +202,15 @@
     </widget>
     <addaction name="menuContour_mode"/>
    </widget>
+   <widget class="QMenu" name="menuSAM_model">
+    <property name="title">
+     <string>SAM</string>
+    </property>
+   </widget>
    <addaction name="menuFile"/>
    <addaction name="menuEdit"/>
    <addaction name="menuView"/>
+   <addaction name="menuSAM_model"/>
    <addaction name="menuMode"/>
    <addaction name="menuTools"/>
    <addaction name="menuAbout"/>
@@ -34,8 +34,6 @@ class MainWindow(QtWidgets.QMainWindow, Ui_MainWindow):
     def __init__(self):
         super(MainWindow, self).__init__()
         self.setupUi(self)
-        self.init_ui()
-        self.init_segment_anything()

         self.image_root: str = None
         self.label_root:str = None
@@ -58,41 +56,51 @@ class MainWindow(QtWidgets.QMainWindow, Ui_MainWindow):
         self.map_mode = MAPMode.LABEL
         # annotation target
         self.current_label:Annotation = None
+        self.use_segment_anything = False
+
+        self.init_ui()
         self.reload_cfg()

         self.init_connect()
         self.reset_action()

-    def init_segment_anything(self):
-        if os.path.exists('./segment_any/sam_vit_h_4b8939.pth'):
-            self.statusbar.showMessage('Find the checkpoint named {}.'.format('sam_vit_h_4b8939.pth'))
-            self.segany = SegAny('./segment_any/sam_vit_h_4b8939.pth')
-            self.use_segment_anything = True
-        elif os.path.exists('./segment_any/sam_vit_l_0b3195.pth'):
-            self.statusbar.showMessage('Find the checkpoint named {}.'.format('sam_vit_l_0b3195.pth'))
-            self.segany = SegAny('./segment_any/sam_vit_l_0b3195.pth')
-            self.use_segment_anything = True
-        elif os.path.exists('./segment_any/sam_vit_b_01ec64.pth'):
-            self.statusbar.showMessage('Find the checkpoint named {}.'.format('sam_vit_b_01ec64.pth'))
-            self.segany = SegAny('./segment_any/sam_vit_b_01ec64.pth')
-            self.use_segment_anything = True
-        else:
-            QtWidgets.QMessageBox.warning(self, 'Warning', 'The checkpoint of [Segment anything] not existed. If you want use quick annotate, please download from {}'.format('https://github.com/facebookresearch/segment-anything#model-checkpoints'))
-            self.use_segment_anything = False
-
-        if self.use_segment_anything:
-            if self.segany.device != 'cpu':
-                self.gpu_resource_thread = GPUResource_Thread()
-                self.gpu_resource_thread.message.connect(self.labelGPUResource.setText)
-                self.gpu_resource_thread.start()
-            else:
-                self.labelGPUResource.setText('cpu')
-        else:
-            self.labelGPUResource.setText('segment anything unused.')
+    def init_segment_anything(self, model_name, reload=False):
+        if model_name == '':
+            self.use_segment_anything = False
+            for name, action in self.pths_actions.items():
+                action.setChecked(model_name == name)
+            return
+        model_path = os.path.join('segment_any', model_name)
+        if not os.path.exists(model_path):
+            QtWidgets.QMessageBox.warning(self, 'Warning',
+                                          'The checkpoint of [Segment anything] not existed. If you want use quick annotate, please download from {}'.format(
+                                              'https://github.com/facebookresearch/segment-anything#model-checkpoints'))
+            for name, action in self.pths_actions.items():
+                action.setChecked(model_name == name)
+            self.use_segment_anything = False
+            return
+
+        self.segany = SegAny(model_path)
+        self.use_segment_anything = True
+        self.statusbar.showMessage('Use the checkpoint named {}.'.format(model_name), 3000)
+        for name, action in self.pths_actions.items():
+            action.setChecked(model_name == name)
+        if not reload:
+            if self.use_segment_anything:
+                if self.segany.device != 'cpu':
+                    self.gpu_resource_thread = GPUResource_Thread()
+                    self.gpu_resource_thread.message.connect(self.labelGPUResource.setText)
+                    self.gpu_resource_thread.start()
+                else:
+                    self.labelGPUResource.setText('cpu')
+            else:
+                self.labelGPUResource.setText('segment anything unused.')
+
+        if reload and self.current_index is not None:
+            self.show_image(self.current_index)

     def init_ui(self):
-        #
+        #q
         self.setting_dialog = SettingDialog(parent=self, mainwindow=self)

         self.categories_dock_widget = CategoriesDockWidget(mainwindow=self)
@@ -144,6 +152,17 @@ class MainWindow(QtWidgets.QMainWindow, Ui_MainWindow):
         self.statusbar.addPermanentWidget(self.labelData)

         #
+        model_names = sorted([pth for pth in os.listdir('segment_any') if pth.endswith('.pth')])
+        self.pths_actions = {}
+        for model_name in model_names:
+            action = QtWidgets.QAction(self)
+            action.setObjectName("actionZoom_in")
+            action.triggered.connect(functools.partial(self.init_segment_anything, model_name))
+            action.setText("{}".format(model_name))
+            action.setCheckable(True)
+
+            self.pths_actions[model_name] = action
+            self.menuSAM_model.addAction(action)
+
         self.toolBar.addSeparator()
         self.mask_aplha = QtWidgets.QSlider(QtCore.Qt.Orientation.Horizontal, self)
@@ -213,6 +232,9 @@
         self.cfg['mask_alpha'] = mask_alpha
         self.mask_aplha.setValue(mask_alpha*10)

+        model_name = self.cfg.get('model_name', '')
+        self.init_segment_anything(model_name)
+
         self.categories_dock_widget.update_widget()

     def set_saved_state(self, is_saved:bool):
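One wiring subtlety in the menu code above: `QAction.triggered` emits a `checked` boolean, and since each action is bound with `functools.partial(self.init_segment_anything, model_name)`, that boolean arrives in the `reload` parameter. A standalone demonstration of the binding pattern (the function here is a stand-in, not the app's):

    import functools

    def init_segment_anything(model_name, reload=False):
        print('model:', model_name, '| reload:', reload)

    handler = functools.partial(init_segment_anything, 'sam_vit_b_01ec64.pth')
    handler()      # model: sam_vit_b_01ec64.pth | reload: False
    handler(True)  # how triggered(checked=True) would invoke it: reload becomes True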