2023-09-11 11:17:20 +08:00

568 lines
24 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# -*- coding: utf-8 -*-
# @Author : LG
from PyQt5 import QtWidgets, QtGui, QtCore
from widgets.polygon import Polygon, Vertex
from configs import STATUSMode, CLICKMode, DRAWMode, CONTOURMode
from PIL import Image
import numpy as np
import cv2
import time # 拖动鼠标描点
class AnnotationScene(QtWidgets.QGraphicsScene):
    """Graphics scene holding one image, its mask overlay and its polygons.

    The scene is always in one of three states (``STATUSMode.VIEW``,
    ``CREATE`` or ``EDIT``) and toggles the main window's actions to match.
    New annotations are drawn either by hand (``DRAWMode.POLYGON``) or from
    point prompts handed to ``mainwindow.segany``
    (``DRAWMode.SEGMENTANYTHING``).
    """

    def __init__(self, mainwindow):
        """Create an empty scene bound to *mainwindow*.

        Args:
            mainwindow: The application window. The scene reads and writes
                its actions, polygon list, category/group state and
                status-bar labels.
        """
        super(AnnotationScene, self).__init__()
        self.mainwindow = mainwindow
        self.image_item:QtWidgets.QGraphicsPixmapItem = None
        # Created in load_image(); initialised here so the attribute always exists.
        self.mask_item:QtWidgets.QGraphicsPixmapItem = None
        self.image_data = None
        self.current_graph:Polygon = None
        self.mode = STATUSMode.VIEW
        self.click = CLICKMode.POSITIVE
        # Segment-anything is the default drawing tool.
        self.draw_mode = DRAWMode.SEGMENTANYTHING
        # By default only the external contours of a mask are kept.
        self.contour_mode = CONTOURMode.SAVE_EXTERNAL
        self.click_points = []       # SAM point prompts, [x, y] each
        self.click_points_mode = []  # 1 for a left click, 0 for a right click
        self.masks:np.ndarray = None
        self.mask_alpha = 0.5
        self.top_layer = 1
        self.guide_line_x:QtWidgets.QGraphicsLineItem = None
        self.guide_line_y:QtWidgets.QGraphicsLineItem = None
        # State for dropping polygon points while the mouse is dragged.
        self.last_draw_time = time.time()
        self.draw_interval = 0.15  # minimum seconds between two dragged points
        self.pressd = False

    def load_image(self, image_path:str):
        """Load *image_path* into the scene and switch to view mode.

        Args:
            image_path: Path of the image file to display.
        """
        # NOTE(review): the indentation of this file was lost in transit; only
        # clear() is assumed to be conditional on image_changed -- confirm.
        if self.mainwindow.image_changed:
            self.clear()

        self.image_data = np.array(Image.open(image_path))

        self.image_item = QtWidgets.QGraphicsPixmapItem()
        self.image_item.setZValue(0)
        self.addItem(self.image_item)

        self.mask_item = QtWidgets.QGraphicsPixmapItem()
        self.mask_item.setZValue(1)
        self.addItem(self.mask_item)

        self.image_item.setPixmap(QtGui.QPixmap(image_path))
        self.setSceneRect(self.image_item.boundingRect())
        self.change_mode_to_view()

    def change_mode_to_create(self):
        """Enter create mode; does nothing when no image has been loaded."""
        if self.image_item is None:
            return
        self.mode = STATUSMode.CREATE
        self.image_item.setCursor(QtGui.QCursor(QtCore.Qt.CursorShape.CrossCursor))
        self.mainwindow.actionPrev.setEnabled(False)
        self.mainwindow.actionNext.setEnabled(False)

        self.mainwindow.actionSegment_anything.setEnabled(False)
        self.mainwindow.actionPolygon.setEnabled(False)
        self.mainwindow.actionBackspace.setEnabled(True)
        self.mainwindow.actionFinish.setEnabled(True)
        self.mainwindow.actionCancel.setEnabled(True)

        self.mainwindow.actionTo_top.setEnabled(False)
        self.mainwindow.actionTo_bottom.setEnabled(False)
        self.mainwindow.actionEdit.setEnabled(False)
        self.mainwindow.actionDelete.setEnabled(False)
        self.mainwindow.actionSave.setEnabled(False)

        self.mainwindow.set_labels_visible(False)
        self.mainwindow.annos_dock_widget.setEnabled(False)

    def change_mode_to_view(self):
        """Enter view mode and enable the actions that make sense in it."""
        self.mode = STATUSMode.VIEW
        # Guarded like change_mode_to_create(): there may be no image yet.
        if self.image_item is not None:
            self.image_item.setCursor(QtGui.QCursor(QtCore.Qt.CursorShape.ArrowCursor))
        self.mainwindow.actionPrev.setEnabled(True)
        self.mainwindow.actionNext.setEnabled(True)

        self.mainwindow.actionSegment_anything.setEnabled(
            self.mainwindow.use_segment_anything and self.mainwindow.can_be_annotated)
        # SAM cannot be started while the predictor has no image set.
        if self.mainwindow.use_segment_anything and self.mainwindow.segany.image is None:
            self.mainwindow.actionSegment_anything.setEnabled(False)
        self.mainwindow.actionPolygon.setEnabled(self.mainwindow.can_be_annotated)
        self.mainwindow.actionBackspace.setEnabled(False)
        self.mainwindow.actionFinish.setEnabled(False)
        self.mainwindow.actionCancel.setEnabled(False)

        self.mainwindow.actionTo_top.setEnabled(False)
        self.mainwindow.actionTo_bottom.setEnabled(False)
        self.mainwindow.actionEdit.setEnabled(False)
        self.mainwindow.actionDelete.setEnabled(False)
        self.mainwindow.actionSave.setEnabled(self.mainwindow.can_be_annotated)

        self.mainwindow.set_labels_visible(True)
        self.mainwindow.annos_dock_widget.setEnabled(True)

    def change_mode_to_edit(self):
        """Enter edit mode and enable the per-polygon actions."""
        self.mode = STATUSMode.EDIT
        # Guarded like change_mode_to_create(): there may be no image yet.
        if self.image_item is not None:
            self.image_item.setCursor(QtGui.QCursor(QtCore.Qt.CursorShape.CrossCursor))
        self.mainwindow.actionPrev.setEnabled(False)
        self.mainwindow.actionNext.setEnabled(False)

        self.mainwindow.actionSegment_anything.setEnabled(False)
        self.mainwindow.actionPolygon.setEnabled(False)
        self.mainwindow.actionBackspace.setEnabled(False)
        self.mainwindow.actionFinish.setEnabled(False)
        self.mainwindow.actionCancel.setEnabled(False)

        self.mainwindow.actionTo_top.setEnabled(True)
        self.mainwindow.actionTo_bottom.setEnabled(True)
        self.mainwindow.actionEdit.setEnabled(True)
        self.mainwindow.actionDelete.setEnabled(True)
        self.mainwindow.actionSave.setEnabled(True)

    def change_click_to_positive(self):
        """Set the click mode to ``CLICKMode.POSITIVE``."""
        self.click = CLICKMode.POSITIVE

    def change_click_to_negative(self):
        """Set the click mode to ``CLICKMode.NEGATIVE``."""
        self.click = CLICKMode.NEGATIVE

    def change_contour_mode_to_save_all(self):
        """Keep every contour of a mask when it is converted to polygons."""
        self.contour_mode = CONTOURMode.SAVE_ALL

    def change_contour_mode_to_save_max_only(self):
        """Keep only the contour with the most points."""
        self.contour_mode = CONTOURMode.SAVE_MAX_ONLY

    def change_contour_mode_to_save_external(self):
        """Keep only the external contours of a mask."""
        self.contour_mode = CONTOURMode.SAVE_EXTERNAL

    def start_segment_anything(self):
        """Start a new annotation driven by segment-anything point prompts."""
        self.draw_mode = DRAWMode.SEGMENTANYTHING
        self.start_draw()

    def start_draw_polygon(self):
        """Start a new hand-drawn polygon annotation."""
        self.draw_mode = DRAWMode.POLYGON
        self.start_draw()

    def start_draw(self):
        """Switch from view mode to create mode and add an empty polygon."""
        # Create mode can only be entered from view mode.
        if self.mode != STATUSMode.VIEW:
            return
        self.change_mode_to_create()
        # change_mode_to_create() is a no-op without an image, so re-check.
        if self.mode == STATUSMode.CREATE:
            self.current_graph = Polygon()
            self.addItem(self.current_graph)

    def finish_draw(self):
        """Commit the annotation in progress and return to view mode.

        In segment-anything mode the current mask is converted to one
        polygon per contour; in polygon mode the hand-drawn polygon is
        committed (two points are expanded to a rectangle, fewer are
        discarded). Click prompts and the mask overlay are reset afterwards.
        """
        if self.current_graph is None:
            return

        self.change_mode_to_view()

        is_crowd = False
        note = ''

        if self.draw_mode == DRAWMode.SEGMENTANYTHING:
            self._finish_segment_anything(is_crowd, note)
        elif self.draw_mode == DRAWMode.POLYGON:
            committed = self._finish_polygon(self.mainwindow.current_category,
                                             self.mainwindow.current_group,
                                             is_crowd,
                                             note)
            if not committed:
                return

        self.mainwindow.annos_dock_widget.update_listwidget()

        self.current_graph = None
        self.change_mode_to_view()

        # Reset the prompts and the mask overlay.
        self.click_points.clear()
        self.click_points_mode.clear()
        self.update_mask()

    def _finish_segment_anything(self, is_crowd, note):
        """Turn ``self.masks`` into committed polygons, one per contour."""
        if self.masks is None:
            # Nothing was segmented: drop the empty placeholder polygon.
            self._discard_current_graph()
            return

        masks = self.masks.astype('uint8') * 255
        h, w = masks.shape[-2:]
        masks = masks.reshape(h, w)

        if self.contour_mode == CONTOURMode.SAVE_ALL:
            # Keep all contours, organised as a two-level hierarchy.
            retrieval = cv2.RETR_CCOMP
        else:
            # External-only and max-only modes need the outer contours only.
            retrieval = cv2.RETR_EXTERNAL
        contours, hierarchy = cv2.findContours(masks, retrieval, cv2.CHAIN_APPROX_TC89_KCOS)

        # An empty mask yields no contours, so guard before picking the longest.
        if self.contour_mode == CONTOURMode.SAVE_MAX_ONLY and len(contours) > 0:
            contours = [max(contours, key=len)]

        for index, contour in enumerate(contours):
            if len(contour) < 3:
                continue

            if self.current_graph is None:
                self.current_graph = Polygon()
                self.addItem(self.current_graph)

            for point in contour:
                x, y = point[0]
                self.current_graph.addPoint(QtCore.QPointF(x, y))

            if self.contour_mode == CONTOURMode.SAVE_ALL and hierarchy[0][index][3] != -1:
                # A contour that has a parent is labelled as background.
                category = '__background__'
                group = 0
            else:
                category = self.mainwindow.current_category
                group = self.mainwindow.current_group

            self._commit_graph(self.current_graph, category, group, is_crowd, note)
            self.current_graph = None

        # If no contour consumed the placeholder polygon, remove it from the scene.
        self._discard_current_graph()
        self.mainwindow.current_group += 1

    def _finish_polygon(self, category, group, is_crowd, note):
        """Commit the hand-drawn polygon; return False when it was discarded."""
        graph = self.current_graph

        # Fewer than two points cannot form a shape: drop the polygon.
        if len(graph.points) < 2:
            self._discard_current_graph()
            return False

        # Exactly two points are treated as opposite corners of a rectangle.
        if len(graph.points) == 2:
            first_point = graph.points[0]
            last_point = graph.points[-1]
            graph.removePoint(len(graph.points) - 1)
            graph.addPoint(QtCore.QPointF(first_point.x(), last_point.y()))
            graph.addPoint(last_point)
            graph.addPoint(QtCore.QPointF(last_point.x(), first_point.y()))

        self._commit_graph(graph, category, group, is_crowd, note)
        self.mainwindow.current_group += 1
        return True

    def _commit_graph(self, graph, category, group, is_crowd, note):
        """Mark *graph* as drawn, register it and place it on the top layer."""
        graph.set_drawed(category,
                         group,
                         is_crowd,
                         note,
                         QtGui.QColor(self.mainwindow.category_color_dict[category]),
                         self.top_layer)
        self.mainwindow.polygons.append(graph)
        # The newest polygon goes on top of all existing ones.
        layer = len(self.mainwindow.polygons)
        graph.setZValue(layer)
        for vertex in graph.vertexs:
            vertex.setZValue(layer)

    def _discard_current_graph(self):
        """Delete the polygon in progress, if any, and forget it."""
        if self.current_graph is None:
            return
        self.current_graph.delete()
        self.removeItem(self.current_graph)
        self.current_graph = None

    def cancel_draw(self):
        """Abort the annotation in progress and return to view mode."""
        if self.current_graph is None:
            return
        self._discard_current_graph()
        self.change_mode_to_view()

        self.click_points.clear()
        self.click_points_mode.clear()
        self.update_mask()

    def delete_selected_graph(self):
        """Delete the selected polygons and vertices.

        A polygon left with fewer than three vertices is deleted as well.
        Polygons above the deleted layer are shifted down by one.
        """
        deleted_layer = None
        for item in self.selectedItems():
            if isinstance(item, Polygon) and (item in self.mainwindow.polygons):
                self.mainwindow.polygons.remove(item)
                item.delete()
                self.removeItem(item)
                deleted_layer = item.zValue()
            elif isinstance(item, Vertex):
                polygon = item.polygon
                # The vertex may already be gone if its polygon was handled
                # earlier in this pass.
                if item in polygon.vertexs:
                    polygon.removePoint(polygon.vertexs.index(item))
                # Fewer than three vertices left: delete the whole polygon.
                if len(polygon.vertexs) < 3 and polygon in self.mainwindow.polygons:
                    self.mainwindow.polygons.remove(polygon)
                    polygon.delete()
                    self.removeItem(polygon)
                    deleted_layer = polygon.zValue()

        if deleted_layer is not None:
            for p in self.mainwindow.polygons:
                if p.zValue() > deleted_layer:
                    p.setZValue(p.zValue() - 1)
            self.mainwindow.annos_dock_widget.update_listwidget()

    def edit_polygon(self):
        """Open the category edit widget for the first selected item."""
        selectd_items = self.selectedItems()
        if len(selectd_items) < 1:
            return
        item = selectd_items[0]
        if not item:
            return
        self.mainwindow.category_edit_widget.polygon = item
        self.mainwindow.category_edit_widget.load_cfg()
        self.mainwindow.category_edit_widget.show()

    def move_polygon_to_top(self):
        """Raise the first selected polygon above all other polygons."""
        selectd_items = self.selectedItems()
        if len(selectd_items) < 1:
            return
        current_polygon = selectd_items[0]
        max_layer = len(self.mainwindow.polygons)

        # Close the gap left behind by the polygon being raised.
        current_layer = current_polygon.zValue()
        for p in self.mainwindow.polygons:
            if p.zValue() > current_layer:
                p.setZValue(p.zValue() - 1)

        current_polygon.setZValue(max_layer)
        for vertex in current_polygon.vertexs:
            vertex.setZValue(max_layer)
        self.mainwindow.set_saved_state(False)

    def move_polygon_to_bottom(self):
        """Lower the first selected polygon below all other polygons."""
        selectd_items = self.selectedItems()
        if len(selectd_items) < 1:
            return
        current_polygon = selectd_items[0]

        if current_polygon is not None:
            # Make room at layer 1 by lifting everything that was below.
            current_layer = current_polygon.zValue()
            for p in self.mainwindow.polygons:
                if p.zValue() < current_layer:
                    p.setZValue(p.zValue() + 1)

            current_polygon.setZValue(1)
            for vertex in current_polygon.vertexs:
                vertex.setZValue(1)
        self.mainwindow.set_saved_state(False)

    def _clamp_to_scene(self, scene_pos):
        """Return ``(x, y)`` of *scene_pos* clamped to the scene rectangle."""
        x = min(max(scene_pos.x(), 0), self.width() - 1)
        y = min(max(scene_pos.y(), 0), self.height() - 1)
        return x, y

    def mousePressEvent(self, event: 'QtWidgets.QGraphicsSceneMouseEvent'):
        """Record a SAM prompt or a polygon point while in create mode.

        Raises:
            ValueError: If ``self.draw_mode`` is neither supported mode.
        """
        if self.mode == STATUSMode.CREATE:
            sceneX, sceneY = self._clamp_to_scene(event.scenePos())
            button = event.button()
            is_left = button == QtCore.Qt.MouseButton.LeftButton
            is_right = button == QtCore.Qt.MouseButton.RightButton

            if is_left or is_right:
                if self.draw_mode == DRAWMode.SEGMENTANYTHING:
                    # Left click is recorded as 1, right click as 0.
                    self.click_points.append([sceneX, sceneY])
                    self.click_points_mode.append(1 if is_left else 0)
                elif self.draw_mode == DRAWMode.POLYGON:
                    if is_left:
                        self._append_polygon_point(QtCore.QPointF(sceneX, sceneY))
                else:
                    raise ValueError(
                        'The draw mode named {} not supported.'.format(self.draw_mode))

            if self.draw_mode == DRAWMode.SEGMENTANYTHING:
                self.update_mask()

        # Start of a possible drag: restart the point-drop timer.
        self.last_draw_time = time.time()
        self.pressd = True
        super(AnnotationScene, self).mousePressEvent(event)

    def _append_polygon_point(self, point):
        """Fix *point* in the polygon in progress and re-add the moving point."""
        # Drop the trailing point that follows the mouse ...
        self.current_graph.removePoint(len(self.current_graph.points) - 1)
        # ... add the real point ...
        self.current_graph.addPoint(point)
        # ... and add a new trailing point that follows the mouse.
        self.current_graph.addPoint(point)

    def mouseReleaseEvent(self, event: 'QtWidgets.QGraphicsSceneMouseEvent'):
        """End a drag so that mouse moves stop dropping polygon points."""
        self.pressd = False
        super(AnnotationScene, self).mouseReleaseEvent(event)

    def mouseMoveEvent(self, event: 'QtWidgets.QGraphicsSceneMouseEvent'):
        """Track the mouse: dragged points, guide lines and status bar."""
        x, y = self._clamp_to_scene(event.scenePos())
        pos = QtCore.QPointF(x, y)

        # While dragging in polygon mode, drop a point at most once per
        # draw_interval. The throttle only gates point creation; guide lines,
        # the status bar and the base-class handling run on every move.
        if (self.pressd
                and self.current_graph is not None
                and self.draw_mode == DRAWMode.POLYGON):
            current_time = time.time()
            if (self.last_draw_time is None
                    or current_time - self.last_draw_time >= self.draw_interval):
                self.last_draw_time = current_time
                self._append_polygon_point(QtCore.QPointF(x, y))

        # Remove the previous guide lines before drawing the new ones.
        if self.guide_line_x is not None and self.guide_line_y is not None:
            if self.guide_line_x in self.items():
                self.removeItem(self.guide_line_x)
            if self.guide_line_y in self.items():
                self.removeItem(self.guide_line_y)
            self.guide_line_x = None
            self.guide_line_y = None

        if self.mode == STATUSMode.CREATE:
            if self.draw_mode == DRAWMode.POLYGON:
                # Keep the trailing polygon point under the mouse.
                self.current_graph.movePoint(len(self.current_graph.points)-1, pos)

        # Guide lines spanning the scene through the mouse position.
        if self.guide_line_x is None and self.width()>0 and self.height()>0:
            self.guide_line_x = QtWidgets.QGraphicsLineItem(QtCore.QLineF(pos.x(), 0, pos.x(), self.height()))
            self.guide_line_x.setZValue(1)
            self.addItem(self.guide_line_x)
        if self.guide_line_y is None and self.width()>0 and self.height()>0:
            self.guide_line_y = QtWidgets.QGraphicsLineItem(QtCore.QLineF(0, pos.y(), self.width(), pos.y()))
            self.guide_line_y.setZValue(1)
            self.addItem(self.guide_line_y)

        # Status bar: coordinates and pixel value under the mouse.
        if self.image_data is not None:
            x, y = round(pos.x()), round(pos.y())
            self.mainwindow.labelCoord.setText('xy: ({:>4d},{:>4d})'.format(x, y))

            data = self.image_data[y][x]
            if self.image_data.ndim == 2:
                self.mainwindow.labelData.setText('pix: [{:^3d}]'.format(data))
            elif self.image_data.ndim == 3:
                if len(data) == 3:
                    self.mainwindow.labelData.setText('rgb: [{:>3d},{:>3d},{:>3d}]'.format(data[0], data[1], data[2]))
                else:
                    self.mainwindow.labelData.setText('pix: [{}]'.format(data))

        super(AnnotationScene, self).mouseMoveEvent(event)

    def update_mask(self):
        """Recompute the SAM mask from the click prompts and redraw the overlay.

        Does nothing when segment-anything is disabled, no image is loaded,
        or the image is not a three-channel array. With no prompts the
        overlay shows the plain image and ``self.masks`` is cleared.
        """
        if not self.mainwindow.use_segment_anything:
            return
        if self.image_data is None:
            return
        if not (self.image_data.ndim == 3 and self.image_data.shape[-1] == 3):
            return

        if len(self.click_points) > 0 and len(self.click_points_mode) > 0:
            masks = self.mainwindow.segany.predict_with_point_prompt(self.click_points, self.click_points_mode)
            self.masks = masks
            color = np.array([0, 0, 255])
            h, w = masks.shape[-2:]
            mask_image = masks.reshape(h, w, 1) * color.reshape(1, 1, -1)
            mask_image = mask_image.astype("uint8")
            mask_image = cv2.cvtColor(mask_image, cv2.COLOR_BGR2RGB)
            # mask_alpha weights the original image; a lower value makes the
            # mask stand out more.
            mask_image = cv2.addWeighted(self.image_data, self.mask_alpha, mask_image, 1, 0)
        else:
            # No prompts left: forget the previous mask so that a later
            # finish_draw() cannot turn a stale mask into polygons.
            self.masks = None
            mask_image = np.zeros(self.image_data.shape, dtype=np.uint8)
            mask_image = cv2.addWeighted(self.image_data, 1, mask_image, 0, 0)

        self._show_mask_image(mask_image)

    def _show_mask_image(self, mask_image):
        """Display an ``(h, w, 3)`` RGB888 array on the mask overlay item."""
        mask_image = QtGui.QImage(mask_image[:], mask_image.shape[1], mask_image.shape[0], mask_image.shape[1] * 3,
                                  QtGui.QImage.Format_RGB888)
        mask_pixmap = QtGui.QPixmap(mask_image)
        self.mask_item.setPixmap(mask_pixmap)

    def backspace(self):
        """Undo the last prompt or polygon point; only acts in create mode."""
        if self.mode != STATUSMode.CREATE:
            return
        if self.draw_mode == DRAWMode.SEGMENTANYTHING:
            if len(self.click_points) > 0:
                self.click_points.pop()
            if len(self.click_points_mode) > 0:
                self.click_points_mode.pop()
            self.update_mask()
        elif self.draw_mode == DRAWMode.POLYGON:
            if len(self.current_graph.points) < 2:
                return
            # Remove the last fixed point; the final entry is the point that
            # follows the mouse and must stay.
            self.current_graph.removePoint(len(self.current_graph.points) - 2)
class AnnotationView(QtWidgets.QGraphicsView):
    """Graphics view with mouse-wheel zoom and hand-drag scrolling."""

    def __init__(self, parent=None):
        """Create the view with mouse tracking and always-on scroll bars.

        Args:
            parent: Optional parent widget.
        """
        super(AnnotationView, self).__init__(parent)
        self.setMouseTracking(True)
        self.setHorizontalScrollBarPolicy(QtCore.Qt.ScrollBarPolicy.ScrollBarAlwaysOn)
        self.setVerticalScrollBarPolicy(QtCore.Qt.ScrollBarPolicy.ScrollBarAlwaysOn)
        self.setDragMode(QtWidgets.QGraphicsView.DragMode.ScrollHandDrag)
        self.factor = 1.2  # scale step applied per zoom action

    def wheelEvent(self, event: QtGui.QWheelEvent):
        """Zoom in or out around the mouse position on a vertical wheel step."""
        delta_y = event.angleDelta().y()
        point = event.pos()  # current mouse position in view coordinates
        if delta_y > 0:
            self.zoom(self.factor, point)
        elif delta_y < 0:
            self.zoom(1 / self.factor, point)
        # A zero vertical delta (e.g. a purely horizontal scroll) does not zoom.

    def zoom_in(self):
        """Zoom in by one step."""
        self.zoom(self.factor)

    def zoom_out(self):
        """Zoom out by one step."""
        self.zoom(1/self.factor)

    def zoomfit(self):
        """Scale the view so that the whole scene fits, keeping aspect ratio."""
        scene = self.scene()
        if scene is None:
            return
        self.fitInView(0, 0, scene.width(), scene.height(), QtCore.Qt.AspectRatioMode.KeepAspectRatio)

    def zoom(self, factor, point=None):
        """Scale the view by *factor*, optionally anchored at *point*.

        Args:
            factor: Multiplicative scale step; values above 1 zoom in.
            point: View position that should stay over the same scene
                position after scaling, or ``None`` for no anchoring.
        """
        mouse_old = self.mapToScene(point) if point is not None else None

        # Width, in view units, that one scene unit would have after scaling.
        pix_widget = self.transform().scale(factor, factor).mapRect(QtCore.QRectF(0, 0, 1, 1)).width()
        # Refuse to zoom past the upper or lower magnification limit.
        if pix_widget > 30 and factor > 1: return
        if pix_widget < 0.01 and factor < 1: return

        self.scale(factor, factor)
        if point is not None:
            # Re-centre so that the scene point under the mouse stays put.
            mouse_now = self.mapToScene(point)
            center_now = self.mapToScene(self.viewport().width() // 2, self.viewport().height() // 2)
            center_new = mouse_old - mouse_now + center_now
            self.centerOn(center_new)