mirror of
https://github.com/immich-app/immich.git
synced 2026-03-22 17:59:48 +03:00
fix(ml): batch size setting (#26524)
This commit is contained in:
@@ -29,7 +29,7 @@ class FaceRecognizer(InferenceModel):
|
||||
|
||||
def __init__(self, model_name: str, **model_kwargs: Any) -> None:
|
||||
super().__init__(model_name, **model_kwargs)
|
||||
max_batch_size = settings.max_batch_size.facial_recognition if settings.max_batch_size else None
|
||||
max_batch_size = settings.max_batch_size and settings.max_batch_size.facial_recognition
|
||||
self.batch_size = max_batch_size if max_batch_size else self._batch_size_default
|
||||
|
||||
def _load(self) -> ModelSession:
|
||||
|
||||
@@ -22,7 +22,7 @@ class TextDetector(InferenceModel):
|
||||
depends = []
|
||||
identity = (ModelType.DETECTION, ModelTask.OCR)
|
||||
|
||||
def __init__(self, model_name: str, **model_kwargs: Any) -> None:
|
||||
def __init__(self, model_name: str, min_score: float = 0.5, **model_kwargs: Any) -> None:
|
||||
super().__init__(model_name.split("__")[-1], **model_kwargs, model_format=ModelFormat.ONNX)
|
||||
self.max_resolution = 736
|
||||
self.mean = np.array([0.5, 0.5, 0.5], dtype=np.float32)
|
||||
@@ -33,7 +33,7 @@ class TextDetector(InferenceModel):
|
||||
}
|
||||
self.postprocess = DBPostProcess(
|
||||
thresh=0.3,
|
||||
box_thresh=model_kwargs.get("minScore", 0.5),
|
||||
box_thresh=model_kwargs.get("minScore", min_score),
|
||||
max_candidates=1000,
|
||||
unclip_ratio=1.6,
|
||||
use_dilation=True,
|
||||
|
||||
@@ -24,9 +24,9 @@ class TextRecognizer(InferenceModel):
|
||||
depends = [(ModelType.DETECTION, ModelTask.OCR)]
|
||||
identity = (ModelType.RECOGNITION, ModelTask.OCR)
|
||||
|
||||
def __init__(self, model_name: str, **model_kwargs: Any) -> None:
|
||||
def __init__(self, model_name: str, min_score: float = 0.9, **model_kwargs: Any) -> None:
|
||||
self.language = LangRec[model_name.split("__")[0]] if "__" in model_name else LangRec.CH
|
||||
self.min_score = model_kwargs.get("minScore", 0.9)
|
||||
self.min_score = model_kwargs.get("minScore", min_score)
|
||||
self._empty: TextRecognitionOutput = {
|
||||
"box": np.empty(0, dtype=np.float32),
|
||||
"boxScore": np.empty(0, dtype=np.float32),
|
||||
@@ -57,10 +57,11 @@ class TextRecognizer(InferenceModel):
|
||||
def _load(self) -> ModelSession:
|
||||
# TODO: support other runtimes
|
||||
session = OrtSession(self.model_path)
|
||||
max_batch_size = settings.max_batch_size and settings.max_batch_size.ocr
|
||||
self.model = RapidTextRecognizer(
|
||||
OcrOptions(
|
||||
session=session.session,
|
||||
rec_batch_num=settings.max_batch_size.text_recognition if settings.max_batch_size is not None else 6,
|
||||
rec_batch_num=max_batch_size if max_batch_size else 6,
|
||||
rec_img_shape=(3, 48, 320),
|
||||
lang_type=self.language,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user