fix(ml): batch size setting (#26524)

This commit is contained in:
Mert
2026-03-05 12:01:47 -05:00
committed by GitHub
parent 09fabb36b6
commit 35a521c6ec
6 changed files with 89 additions and 8 deletions

View File

@@ -29,7 +29,7 @@ class FaceRecognizer(InferenceModel):
def __init__(self, model_name: str, **model_kwargs: Any) -> None:
super().__init__(model_name, **model_kwargs)
max_batch_size = settings.max_batch_size.facial_recognition if settings.max_batch_size else None
max_batch_size = settings.max_batch_size and settings.max_batch_size.facial_recognition
self.batch_size = max_batch_size if max_batch_size else self._batch_size_default
def _load(self) -> ModelSession:

View File

@@ -22,7 +22,7 @@ class TextDetector(InferenceModel):
depends = []
identity = (ModelType.DETECTION, ModelTask.OCR)
def __init__(self, model_name: str, **model_kwargs: Any) -> None:
def __init__(self, model_name: str, min_score: float = 0.5, **model_kwargs: Any) -> None:
super().__init__(model_name.split("__")[-1], **model_kwargs, model_format=ModelFormat.ONNX)
self.max_resolution = 736
self.mean = np.array([0.5, 0.5, 0.5], dtype=np.float32)
@@ -33,7 +33,7 @@ class TextDetector(InferenceModel):
}
self.postprocess = DBPostProcess(
thresh=0.3,
box_thresh=model_kwargs.get("minScore", 0.5),
box_thresh=model_kwargs.get("minScore", min_score),
max_candidates=1000,
unclip_ratio=1.6,
use_dilation=True,

View File

@@ -24,9 +24,9 @@ class TextRecognizer(InferenceModel):
depends = [(ModelType.DETECTION, ModelTask.OCR)]
identity = (ModelType.RECOGNITION, ModelTask.OCR)
def __init__(self, model_name: str, **model_kwargs: Any) -> None:
def __init__(self, model_name: str, min_score: float = 0.9, **model_kwargs: Any) -> None:
self.language = LangRec[model_name.split("__")[0]] if "__" in model_name else LangRec.CH
self.min_score = model_kwargs.get("minScore", 0.9)
self.min_score = model_kwargs.get("minScore", min_score)
self._empty: TextRecognitionOutput = {
"box": np.empty(0, dtype=np.float32),
"boxScore": np.empty(0, dtype=np.float32),
@@ -57,10 +57,11 @@ class TextRecognizer(InferenceModel):
def _load(self) -> ModelSession:
# TODO: support other runtimes
session = OrtSession(self.model_path)
max_batch_size = settings.max_batch_size and settings.max_batch_size.ocr
self.model = RapidTextRecognizer(
OcrOptions(
session=session.session,
rec_batch_num=settings.max_batch_size.text_recognition if settings.max_batch_size is not None else 6,
rec_batch_num=max_batch_size if max_batch_size else 6,
rec_img_shape=(3, 48, 320),
lang_type=self.language,
)