diff --git a/machine-learning/immich_ml/models/ocr/detection.py b/machine-learning/immich_ml/models/ocr/detection.py index 235fcc677e..0a9d09b599 100644 --- a/machine-learning/immich_ml/models/ocr/detection.py +++ b/machine-learning/immich_ml/models/ocr/detection.py @@ -1,8 +1,10 @@ from typing import Any +import cv2 import numpy as np +from numpy.typing import NDArray from PIL import Image -from rapidocr.ch_ppocr_det import TextDetector as RapidTextDetector +from rapidocr.ch_ppocr_det.utils import DBPostProcess from rapidocr.inference_engine.base import FileInfo, InferSession from rapidocr.utils import DownloadFile, DownloadFileInput from rapidocr.utils.typings import EngineType, LangDet, OCRVersion, TaskType @@ -10,11 +12,10 @@ from rapidocr.utils.typings import ModelType as RapidModelType from immich_ml.config import log from immich_ml.models.base import InferenceModel -from immich_ml.models.transforms import decode_cv2 from immich_ml.schemas import ModelFormat, ModelSession, ModelTask, ModelType from immich_ml.sessions.ort import OrtSession -from .schemas import OcrOptions, TextDetectionOutput +from .schemas import TextDetectionOutput class TextDetector(InferenceModel): @@ -24,13 +25,20 @@ class TextDetector(InferenceModel): def __init__(self, model_name: str, **model_kwargs: Any) -> None: super().__init__(model_name, **model_kwargs, model_format=ModelFormat.ONNX) self.max_resolution = 736 - self.min_score = 0.5 - self.score_mode = "fast" + self.mean = np.array([0.5, 0.5, 0.5], dtype=np.float32) + self.std_inv = np.float32(1.0) / (np.array([0.5, 0.5, 0.5], dtype=np.float32) * 255.0) self._empty: TextDetectionOutput = { - "image": np.empty(0, dtype=np.float32), "boxes": np.empty(0, dtype=np.float32), "scores": np.empty(0, dtype=np.float32), } + self.postprocess = DBPostProcess( + thresh=0.3, + box_thresh=model_kwargs.get("minScore", 0.5), + max_candidates=1000, + unclip_ratio=1.6, + use_dilation=True, + score_mode="fast", + ) def _download(self) -> None: model_info = InferSession.get_model_url( @@ -52,35 +60,65 @@ class TextDetector(InferenceModel): def _load(self) -> ModelSession: # TODO: support other runtime sessions - session = OrtSession(self.model_path) - self.model = RapidTextDetector( - OcrOptions( - session=session.session, - limit_side_len=self.max_resolution, - limit_type="min", - box_thresh=self.min_score, - score_mode=self.score_mode, - ) - ) - return session + return OrtSession(self.model_path) - def _predict(self, inputs: bytes | Image.Image) -> TextDetectionOutput: - results = self.model(decode_cv2(inputs)) - if results.boxes is None or results.scores is None or results.img is None: + # partly adapted from RapidOCR + def _predict(self, inputs: Image.Image) -> TextDetectionOutput: + w, h = inputs.size + if w < 32 or h < 32: + return self._empty + out = self.session.run(None, {"x": self._transform(inputs)})[0] + boxes, scores = self.postprocess(out, (h, w)) + if len(boxes) == 0: return self._empty return { - "image": results.img, - "boxes": np.array(results.boxes, dtype=np.float32), - "scores": np.array(results.scores, dtype=np.float32), + "boxes": self.sorted_boxes(boxes), + "scores": np.array(scores, dtype=np.float32), } + # adapted from RapidOCR + def _transform(self, img: Image.Image) -> NDArray[np.float32]: + if img.height < img.width: + ratio = float(self.max_resolution) / img.height + else: + ratio = float(self.max_resolution) / img.width + + resize_h = int(img.height * ratio) + resize_w = int(img.width * ratio) + + resize_h = int(round(resize_h / 32) * 32) + resize_w = int(round(resize_w / 32) * 32) + resized_img = img.resize((int(resize_w), int(resize_h)), resample=Image.Resampling.LANCZOS) + + img_np: NDArray[np.float32] = cv2.cvtColor(np.array(resized_img, dtype=np.float32), cv2.COLOR_RGB2BGR) # type: ignore + img_np -= self.mean + img_np *= self.std_inv + img_np = np.transpose(img_np, (2, 0, 1)) + return np.expand_dims(img_np, axis=0) + + def sorted_boxes(self, dt_boxes: NDArray[np.float32]) -> NDArray[np.float32]: + if len(dt_boxes) == 0: + return dt_boxes + + # Sort by y, then identify lines, then sort by (line, x) + y_order = np.argsort(dt_boxes[:, 0, 1], kind="stable") + sorted_y = dt_boxes[y_order, 0, 1] + + line_ids = np.empty(len(dt_boxes), dtype=np.int32) + line_ids[0] = 0 + np.cumsum(np.abs(np.diff(sorted_y)) >= 10, out=line_ids[1:]) + + # Create composite sort key for final ordering + # Shift line_ids by large factor, add x for tie-breaking + sort_key = line_ids[y_order] * 1e6 + dt_boxes[y_order, 0, 0] + final_order = np.argsort(sort_key, kind="stable") + sorted_boxes: NDArray[np.float32] = dt_boxes[y_order[final_order]] + return sorted_boxes + def configure(self, **kwargs: Any) -> None: if (max_resolution := kwargs.get("maxResolution")) is not None: self.max_resolution = max_resolution - self.model.limit_side_len = max_resolution if (min_score := kwargs.get("minScore")) is not None: - self.min_score = min_score - self.model.postprocess_op.box_thresh = min_score + self.postprocess.box_thresh = min_score if (score_mode := kwargs.get("scoreMode")) is not None: - self.score_mode = score_mode - self.model.postprocess_op.score_mode = score_mode + self.postprocess.score_mode = score_mode diff --git a/machine-learning/immich_ml/models/ocr/recognition.py b/machine-learning/immich_ml/models/ocr/recognition.py index c3d39b0d70..0f91fc4105 100644 --- a/machine-learning/immich_ml/models/ocr/recognition.py +++ b/machine-learning/immich_ml/models/ocr/recognition.py @@ -1,9 +1,8 @@ from typing import Any -import cv2 import numpy as np from numpy.typing import NDArray -from PIL.Image import Image +from PIL import Image from rapidocr.ch_ppocr_rec import TextRecInput from rapidocr.ch_ppocr_rec import TextRecognizer as RapidTextRecognizer from rapidocr.inference_engine.base import FileInfo, InferSession @@ -14,6 +13,7 @@ from rapidocr.utils.vis_res import VisRes from immich_ml.config import log, settings from immich_ml.models.base import InferenceModel +from immich_ml.models.transforms import pil_to_cv2 from immich_ml.schemas import ModelFormat, ModelSession, ModelTask, ModelType from immich_ml.sessions.ort import OrtSession @@ -65,17 +65,16 @@ class TextRecognizer(InferenceModel): ) return session - def _predict(self, _: Image, texts: TextDetectionOutput) -> TextRecognitionOutput: - boxes, img, box_scores = texts["boxes"], texts["image"], texts["scores"] + def _predict(self, img: Image.Image, texts: TextDetectionOutput) -> TextRecognitionOutput: + boxes, box_scores = texts["boxes"], texts["scores"] if boxes.shape[0] == 0: return self._empty rec = self.model(TextRecInput(img=self.get_crop_img_list(img, boxes))) if rec.txts is None: return self._empty - height, width = img.shape[0:2] - boxes[:, :, 0] /= width - boxes[:, :, 1] /= height + boxes[:, :, 0] /= img.width + boxes[:, :, 1] /= img.height text_scores = np.array(rec.scores) valid_text_score_idx = text_scores > self.min_score @@ -87,7 +86,7 @@ class TextRecognizer(InferenceModel): "textScore": text_scores[valid_text_score_idx], } - def get_crop_img_list(self, img: NDArray[np.float32], boxes: NDArray[np.float32]) -> list[NDArray[np.float32]]: + def get_crop_img_list(self, img: Image.Image, boxes: NDArray[np.float32]) -> list[NDArray[np.uint8]]: img_crop_width = np.maximum( np.linalg.norm(boxes[:, 1] - boxes[:, 0], axis=1), np.linalg.norm(boxes[:, 2] - boxes[:, 3], axis=1) ).astype(np.int32) @@ -98,22 +97,55 @@ class TextRecognizer(InferenceModel): pts_std[:, 1:3, 0] = img_crop_width[:, None] pts_std[:, 2:4, 1] = img_crop_height[:, None] - img_crop_sizes = np.stack([img_crop_width, img_crop_height], axis=1).tolist() - imgs: list[NDArray[np.float32]] = [] - for box, pts_std, dst_size in zip(list(boxes), list(pts_std), img_crop_sizes): - M = cv2.getPerspectiveTransform(box, pts_std) - dst_img: NDArray[np.float32] = cv2.warpPerspective( - img, - M, - dst_size, - borderMode=cv2.BORDER_REPLICATE, - flags=cv2.INTER_CUBIC, - ) # type: ignore - dst_height, dst_width = dst_img.shape[0:2] + img_crop_sizes = np.stack([img_crop_width, img_crop_height], axis=1) + all_coeffs = self._get_perspective_transform(pts_std, boxes) + imgs: list[NDArray[np.uint8]] = [] + for coeffs, dst_size in zip(all_coeffs, img_crop_sizes): + dst_img = img.transform( + size=tuple(dst_size), + method=Image.Transform.PERSPECTIVE, + data=tuple(coeffs), + resample=Image.Resampling.BICUBIC, + ) + + dst_width, dst_height = dst_img.size if dst_height * 1.0 / dst_width >= 1.5: - dst_img = np.rot90(dst_img) - imgs.append(dst_img) + dst_img = dst_img.rotate(90, expand=True) + imgs.append(pil_to_cv2(dst_img)) + return imgs + def _get_perspective_transform(self, src: NDArray[np.float32], dst: NDArray[np.float32]) -> NDArray[np.float32]: + N = src.shape[0] + x, y = src[:, :, 0], src[:, :, 1] + u, v = dst[:, :, 0], dst[:, :, 1] + A = np.zeros((N, 8, 9), dtype=np.float32) + + # Fill even rows (0, 2, 4, 6): [x, y, 1, 0, 0, 0, -u*x, -u*y, -u] + A[:, ::2, 0] = x + A[:, ::2, 1] = y + A[:, ::2, 2] = 1 + A[:, ::2, 6] = -u * x + A[:, ::2, 7] = -u * y + A[:, ::2, 8] = -u + + # Fill odd rows (1, 3, 5, 7): [0, 0, 0, x, y, 1, -v*x, -v*y, -v] + A[:, 1::2, 3] = x + A[:, 1::2, 4] = y + A[:, 1::2, 5] = 1 + A[:, 1::2, 6] = -v * x + A[:, 1::2, 7] = -v * y + A[:, 1::2, 8] = -v + + # Solve using SVD for all matrices at once + _, _, Vt = np.linalg.svd(A) + H = Vt[:, -1, :].reshape(N, 3, 3) + H = H / H[:, 2:3, 2:3] + + # Extract the 8 coefficients for each transformation + return np.column_stack( + [H[:, 0, 0], H[:, 0, 1], H[:, 0, 2], H[:, 1, 0], H[:, 1, 1], H[:, 1, 2], H[:, 2, 0], H[:, 2, 1]] + ) # pyright: ignore[reportReturnType] + def configure(self, **kwargs: Any) -> None: self.min_score = kwargs.get("minScore", self.min_score) diff --git a/machine-learning/immich_ml/models/ocr/schemas.py b/machine-learning/immich_ml/models/ocr/schemas.py index 14a7d3cea0..a63c8dd8e5 100644 --- a/machine-learning/immich_ml/models/ocr/schemas.py +++ b/machine-learning/immich_ml/models/ocr/schemas.py @@ -7,7 +7,6 @@ from typing_extensions import TypedDict class TextDetectionOutput(TypedDict): - image: npt.NDArray[np.float32] boxes: npt.NDArray[np.float32] scores: npt.NDArray[np.float32]