mirror of
https://github.com/immich-app/immich.git
synced 2026-02-13 04:17:56 +03:00
refactor
This commit is contained in:
@@ -226,9 +226,9 @@ async def load(model: InferenceModel) -> InferenceModel:
|
||||
except FileNotFoundError as e:
|
||||
if model.model_format == ModelFormat.ONNX:
|
||||
raise e
|
||||
log.exception(e)
|
||||
log.warning(
|
||||
f"{model.model_format.upper()} is available, but model '{model.model_name}' does not support it."
|
||||
f"{model.model_format.upper()} is available, but model '{model.model_name}' does not support it.",
|
||||
exc_info=e,
|
||||
)
|
||||
model.model_format = ModelFormat.ONNX
|
||||
model.load()
|
||||
|
||||
@@ -8,9 +8,8 @@ from typing import Any, ClassVar
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
import ann.ann
|
||||
import rknn.rknnpool
|
||||
import app.sessions.rknn as rknn
|
||||
from app.sessions.ort import OrtSession
|
||||
from app.sessions.rknn import RknnSession
|
||||
|
||||
from ..config import clean_name, log, settings
|
||||
from ..schemas import ModelFormat, ModelIdentity, ModelSession, ModelTask, ModelType
|
||||
@@ -34,6 +33,7 @@ class InferenceModel(ABC):
|
||||
self.model_name = clean_name(model_name)
|
||||
self.cache_dir = Path(cache_dir) if cache_dir is not None else self._cache_dir_default
|
||||
self.model_format = model_format if model_format is not None else self._model_format_default
|
||||
self.model_path_prefix = rknn.model_prefix if self.model_format == ModelFormat.RKNN else None
|
||||
if session is not None:
|
||||
self.session = session
|
||||
|
||||
@@ -116,7 +116,7 @@ class InferenceModel(ABC):
|
||||
case ".onnx":
|
||||
session = OrtSession(model_path)
|
||||
case ".rknn":
|
||||
session = RknnSession(model_path)
|
||||
session = rknn.RknnSession(model_path)
|
||||
case _:
|
||||
raise ValueError(f"Unsupported model file type: {model_path.suffix}")
|
||||
return session
|
||||
@@ -127,6 +127,8 @@ class InferenceModel(ABC):
|
||||
|
||||
@property
|
||||
def model_path(self) -> Path:
|
||||
if self.model_path_prefix:
|
||||
return self.model_dir / self.model_path_prefix / f"model.{self.model_format}"
|
||||
return self.model_dir / f"model.{self.model_format}"
|
||||
|
||||
@property
|
||||
@@ -164,7 +166,7 @@ class InferenceModel(ABC):
|
||||
|
||||
@property
|
||||
def _model_format_default(self) -> ModelFormat:
|
||||
if rknn.rknnpool.is_available and settings.rknn:
|
||||
if rknn.is_available:
|
||||
return ModelFormat.RKNN
|
||||
elif ann.ann.is_available and settings.ann:
|
||||
return ModelFormat.ARMNN
|
||||
|
||||
@@ -6,15 +6,17 @@ from typing import Any, NamedTuple
|
||||
import numpy as np
|
||||
from numpy.typing import NDArray
|
||||
|
||||
from app.config import log, settings
|
||||
from app.schemas import SessionNode
|
||||
from rknn.rknnpool import RknnPoolExecutor, soc_name
|
||||
|
||||
from ..config import log, settings
|
||||
from .rknnpool import RknnPoolExecutor, is_available, soc_name
|
||||
|
||||
is_available = is_available and settings.rknn
|
||||
model_prefix = Path("rknpu") / soc_name if is_available else None
|
||||
|
||||
|
||||
def runInference(rknn_lite: Any, input: list[NDArray[np.float32]]) -> list[NDArray[np.float32]]:
|
||||
def run_inference(rknn_lite: Any, input: list[NDArray[np.float32]]) -> list[NDArray[np.float32]]:
|
||||
outputs: list[NDArray[np.float32]] = rknn_lite.inference(inputs=input, data_format="nchw")
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
@@ -38,17 +40,13 @@ input_output_mapping: dict[str, dict[str, Any]] = {
|
||||
|
||||
|
||||
class RknnSession:
|
||||
def __init__(self, model_path: Path | str):
|
||||
self.model_path = Path(str(model_path).replace("model", soc_name))
|
||||
self.model_type = "detection" if "detection" in self.model_path.as_posix() else "recognition"
|
||||
def __init__(self, model_path: Path) -> None:
|
||||
self.model_type = "detection" if "detection" in model_path.parts else "recognition"
|
||||
self.tpe = settings.rknn_threads
|
||||
|
||||
log.info(f"Loading RKNN model from {self.model_path} with {self.tpe} threads.")
|
||||
self.rknnpool = RknnPoolExecutor(rknnModel=self.model_path.as_posix(), tpes=self.tpe, func=runInference)
|
||||
log.info(f"Loaded RKNN model from {self.model_path} with {self.tpe} threads.")
|
||||
|
||||
def __del__(self) -> None:
|
||||
self.rknnpool.release()
|
||||
log.info(f"Loading RKNN model from {model_path} with {self.tpe} threads.")
|
||||
self.rknnpool = RknnPoolExecutor(model_path=model_path.as_posix(), tpes=self.tpe, func=run_inference)
|
||||
log.info(f"Loaded RKNN model from {model_path} with {self.tpe} threads.")
|
||||
|
||||
def get_inputs(self) -> list[SessionNode]:
|
||||
return [RknnNode(name=k, shape=v) for k, v in input_output_mapping[self.model_type]["input"].items()]
|
||||
85
machine-learning/app/sessions/rknn/rknnpool.py
Normal file
85
machine-learning/app/sessions/rknn/rknnpool.py
Normal file
@@ -0,0 +1,85 @@
|
||||
# This code is from leafqycc/rknn-multi-threaded
|
||||
# Following Apache License 2.0
|
||||
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from pathlib import Path
|
||||
from queue import Queue
|
||||
|
||||
import numpy as np
|
||||
from numpy.typing import NDArray
|
||||
|
||||
from app.config import log
|
||||
|
||||
supported_socs = ["rk3566", "rk3588"]
|
||||
coremask_supported_socs = ["rk3576", "rk3588"]
|
||||
|
||||
|
||||
def get_soc(device_tree_path: Path | str) -> str | None:
|
||||
try:
|
||||
with Path(device_tree_path).open() as f:
|
||||
device_compatible_str = f.read()
|
||||
for soc in supported_socs:
|
||||
if soc in device_compatible_str:
|
||||
return soc
|
||||
log.warning("Device is not supported for RKNN")
|
||||
except OSError as e:
|
||||
log.warning("Could not read /proc/device-tree/compatible. Reason: %s", e.msg)
|
||||
return None
|
||||
|
||||
|
||||
soc_name = None
|
||||
is_available = False
|
||||
try:
|
||||
from rknnlite.api import RKNNLite
|
||||
|
||||
soc_name = get_soc("/proc/device-tree/compatible")
|
||||
is_available = soc_name is not None
|
||||
except ImportError:
|
||||
log.debug("RKNN is not available")
|
||||
|
||||
|
||||
def init_rknn(model_path: str) -> RKNNLite:
|
||||
if not is_available:
|
||||
raise RuntimeError("rknn is not available!")
|
||||
rknn_lite = RKNNLite()
|
||||
ret = rknn_lite.load_rknn(model_path)
|
||||
if ret != 0:
|
||||
raise RuntimeError("Load RKNN rknnModel failed")
|
||||
|
||||
if soc_name in coremask_supported_socs:
|
||||
ret = rknn_lite.init_runtime(core_mask=RKNNLite.NPU_CORE_AUTO)
|
||||
else:
|
||||
ret = rknn_lite.init_runtime() # Please do not set this parameter on other platforms.
|
||||
|
||||
if ret != 0:
|
||||
raise RuntimeError("Init runtime environment failed")
|
||||
|
||||
return rknn_lite
|
||||
|
||||
|
||||
class RknnPoolExecutor:
|
||||
def __init__(self, model_path: str, tpes: int, func):
|
||||
self.tpes = tpes
|
||||
self.queue = Queue()
|
||||
self.rknn_pool = [init_rknn(model_path) for _ in range(tpes)]
|
||||
self.pool = ThreadPoolExecutor(max_workers=tpes)
|
||||
self.func = func
|
||||
self.num = 0
|
||||
|
||||
def put(self, inputs) -> None:
|
||||
self.queue.put(self.pool.submit(self.func, self.rknn_pool[self.num % self.tpes], inputs))
|
||||
self.num += 1
|
||||
|
||||
def get(self) -> list[list[NDArray[np.float32]], bool]:
|
||||
if self.queue.empty():
|
||||
return None
|
||||
fut = self.queue.get()
|
||||
return fut.result()
|
||||
|
||||
def release(self) -> None:
|
||||
self.pool.shutdown()
|
||||
for rknn_lite in self.rknn_pool:
|
||||
rknn_lite.release()
|
||||
|
||||
def __del__(self) -> None:
|
||||
self.release()
|
||||
Reference in New Issue
Block a user