This commit is contained in:
mertalev
2025-03-11 18:35:21 -04:00
parent f5e44f12e1
commit ec0fa4d52b
22 changed files with 132 additions and 105 deletions

View File

@@ -8,9 +8,8 @@ from typing import Any, ClassVar
from huggingface_hub import snapshot_download
import ann.ann
import rknn.rknnpool
import app.sessions.rknn as rknn
from app.sessions.ort import OrtSession
from app.sessions.rknn import RknnSession
from ..config import clean_name, log, settings
from ..schemas import ModelFormat, ModelIdentity, ModelSession, ModelTask, ModelType
@@ -34,6 +33,7 @@ class InferenceModel(ABC):
self.model_name = clean_name(model_name)
self.cache_dir = Path(cache_dir) if cache_dir is not None else self._cache_dir_default
self.model_format = model_format if model_format is not None else self._model_format_default
self.model_path_prefix = rknn.model_prefix if self.model_format == ModelFormat.RKNN else None
if session is not None:
self.session = session
@@ -116,7 +116,7 @@ class InferenceModel(ABC):
case ".onnx":
session = OrtSession(model_path)
case ".rknn":
session = RknnSession(model_path)
session = rknn.RknnSession(model_path)
case _:
raise ValueError(f"Unsupported model file type: {model_path.suffix}")
return session
@@ -127,6 +127,8 @@ class InferenceModel(ABC):
@property
def model_path(self) -> Path:
if self.model_path_prefix:
return self.model_dir / self.model_path_prefix / f"model.{self.model_format}"
return self.model_dir / f"model.{self.model_format}"
@property
@@ -164,7 +166,7 @@ class InferenceModel(ABC):
@property
def _model_format_default(self) -> ModelFormat:
if rknn.rknnpool.is_available and settings.rknn:
if rknn.is_available:
return ModelFormat.RKNN
elif ann.ann.is_available and settings.ann:
return ModelFormat.ARMNN