feat(ml): composable ml (#9973)
* modularize model classes
* various fixes
* expose port
* change response
* round coordinates
* simplify preload
* update server
* simplify interface
* update tests
* composable endpoint
* cleanup; fixes; remove unnecessary interface; support text input
* ew camelcase
* update server; server fixes; fix typing
* ml fixes; update locustfile
* cleaner response
* better repo response
* update tests; formatting and typing; rename
* undo compose change
* linting; fix type; actually fix typing
* stricter typing; fix detection-only response; no need for defaultdict
* update spec file; update api; linting
* update e2e
* unnecessary dimension
* remove commented code
* remove duplicate code
* remove unused imports
* add batch dim
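The heart of the change is visible in the repository diff below: instead of one model type per call, the server now posts a single composable request object that names every task the ML service should run. A minimal sketch of such a payload, assuming illustrative enum string values and model names (neither is defined in this diff):

```typescript
// Sketch only: the enum string values below are assumptions for
// illustration; the real values live in machine-learning.interface.ts.
enum ModelTask {
  SEARCH = 'clip',
  FACIAL_RECOGNITION = 'facial-recognition',
}
enum ModelType {
  DETECTION = 'detection',
  RECOGNITION = 'recognition',
  VISUAL = 'visual',
  TEXTUAL = 'textual',
}

// One request can compose several tasks; each task nests per-model-type
// options, exactly as detectFaces/encodeImage build it in the diff below.
const request = {
  [ModelTask.FACIAL_RECOGNITION]: {
    [ModelType.DETECTION]: { modelName: 'buffalo_l', minScore: 0.7 },
    [ModelType.RECOGNITION]: { modelName: 'buffalo_l' },
  },
  [ModelTask.SEARCH]: {
    [ModelType.VISUAL]: { modelName: 'ViT-B-32__openai' },
  },
};
```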
```diff
@@ -1,13 +1,16 @@
 import { Injectable } from '@nestjs/common';
 import { readFile } from 'node:fs/promises';
-import { CLIPConfig, ModelConfig, RecognitionConfig } from 'src/dtos/model-config.dto';
+import { CLIPConfig } from 'src/dtos/model-config.dto';
 import {
-  CLIPMode,
-  DetectFaceResult,
+  ClipTextualResponse,
+  ClipVisualResponse,
+  FaceDetectionOptions,
+  FacialRecognitionResponse,
   IMachineLearningRepository,
+  MachineLearningRequest,
+  ModelPayload,
+  ModelTask,
   ModelType,
-  TextModelInput,
-  VisionModelInput,
 } from 'src/interfaces/machine-learning.interface';
 import { Instrumentation } from 'src/utils/instrumentation';
 
```
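The renamed types imported above are not defined in this diff. A rough reconstruction, inferred only from how this file uses them (the real definitions in src/interfaces/machine-learning.interface.ts may differ):

```typescript
// Inferred shapes, not the actual interface file. Reuses the illustrative
// ModelTask/ModelType enums sketched above.
type ModelPayload = { imagePath: string } | { text: string };

// detectFaces passes { modelName, minScore }; the CLIP helpers pass { modelName }.
type ModelOptions = { modelName: string; minScore?: number };

// Task -> model type -> options, as composed by the methods below.
type MachineLearningRequest = Partial<Record<ModelTask, Partial<Record<ModelType, ModelOptions>>>>;
```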
```diff
@@ -16,8 +19,8 @@ const errorPrefix = 'Machine learning request';
 @Instrumentation()
 @Injectable()
 export class MachineLearningRepository implements IMachineLearningRepository {
-  private async predict<T>(url: string, input: TextModelInput | VisionModelInput, config: ModelConfig): Promise<T> {
-    const formData = await this.getFormData(input, config);
+  private async predict<T>(url: string, payload: ModelPayload, config: MachineLearningRequest): Promise<T> {
+    const formData = await this.getFormData(payload, config);
 
     const res = await fetch(new URL('/predict', url), { method: 'POST', body: formData }).catch(
       (error: Error | any) => {
```
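With getFormData reworked (final hunk below), the multipart body that predict posts reduces to an 'entries' JSON field plus the raw payload. A sketch of the equivalent raw request, reusing the request object from the first sketch; the URL and file path are placeholders, while the field names ('entries', 'image', 'text') and the '/predict' route are taken from this file:

```typescript
import { readFile } from 'node:fs/promises';

// Placeholder endpoint and path; only the wire format mirrors the diff.
const formData = new FormData();
formData.append('entries', JSON.stringify(request));
formData.append('image', new Blob([await readFile('/photos/example.jpg')]));
const res = await fetch(new URL('/predict', 'http://immich-machine-learning:3003'), {
  method: 'POST',
  body: formData,
});
```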
```diff
@@ -26,50 +29,46 @@ export class MachineLearningRepository implements IMachineLearningRepository {
     );
 
     if (res.status >= 400) {
-      const modelType = config.modelType ? ` for ${config.modelType.replace('-', ' ')}` : '';
-      throw new Error(`${errorPrefix}${modelType} failed with status ${res.status}: ${res.statusText}`);
+      throw new Error(`${errorPrefix} '${JSON.stringify(config)}' failed with status ${res.status}: ${res.statusText}`);
     }
     return res.json();
   }
 
-  detectFaces(url: string, input: VisionModelInput, config: RecognitionConfig): Promise<DetectFaceResult[]> {
-    return this.predict<DetectFaceResult[]>(url, input, { ...config, modelType: ModelType.FACIAL_RECOGNITION });
+  async detectFaces(url: string, imagePath: string, { modelName, minScore }: FaceDetectionOptions) {
+    const request = {
+      [ModelTask.FACIAL_RECOGNITION]: {
+        [ModelType.DETECTION]: { modelName, minScore },
+        [ModelType.RECOGNITION]: { modelName },
+      },
+    };
+    const response = await this.predict<FacialRecognitionResponse>(url, { imagePath }, request);
+    return {
+      imageHeight: response.imageHeight,
+      imageWidth: response.imageWidth,
+      faces: response[ModelTask.FACIAL_RECOGNITION],
+    };
   }
 
-  encodeImage(url: string, input: VisionModelInput, config: CLIPConfig): Promise<number[]> {
-    return this.predict<number[]>(url, input, {
-      ...config,
-      modelType: ModelType.CLIP,
-      mode: CLIPMode.VISION,
-    } as CLIPConfig);
+  async encodeImage(url: string, imagePath: string, { modelName }: CLIPConfig) {
+    const request = { [ModelTask.SEARCH]: { [ModelType.VISUAL]: { modelName } } };
+    const response = await this.predict<ClipVisualResponse>(url, { imagePath }, request);
+    return response[ModelTask.SEARCH];
   }
 
-  encodeText(url: string, input: TextModelInput, config: CLIPConfig): Promise<number[]> {
-    return this.predict<number[]>(url, input, {
-      ...config,
-      modelType: ModelType.CLIP,
-      mode: CLIPMode.TEXT,
-    } as CLIPConfig);
+  async encodeText(url: string, text: string, { modelName }: CLIPConfig) {
+    const request = { [ModelTask.SEARCH]: { [ModelType.TEXTUAL]: { modelName } } };
+    const response = await this.predict<ClipTextualResponse>(url, { text }, request);
+    return response[ModelTask.SEARCH];
   }
 
-  private async getFormData(input: TextModelInput | VisionModelInput, config: ModelConfig): Promise<FormData> {
+  private async getFormData(payload: ModelPayload, config: MachineLearningRequest): Promise<FormData> {
     const formData = new FormData();
-    const { enabled, modelName, modelType, ...options } = config;
-    if (!enabled) {
-      throw new Error(`${modelType} is not enabled`);
-    }
+    formData.append('entries', JSON.stringify(config));
 
-    formData.append('modelName', modelName);
-    if (modelType) {
-      formData.append('modelType', modelType);
-    }
-    if (options) {
-      formData.append('options', JSON.stringify(options));
-    }
-    if ('imagePath' in input) {
-      formData.append('image', new Blob([await readFile(input.imagePath)]));
-    } else if ('text' in input) {
-      formData.append('text', input.text);
+    if ('imagePath' in payload) {
+      formData.append('image', new Blob([await readFile(payload.imagePath)]));
+    } else if ('text' in payload) {
+      formData.append('text', payload.text);
     } else {
       throw new Error('Invalid input');
     }
```
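Taken together, the public surface of the repository stays small. A hedged usage sketch based only on the signatures above; `ml` stands in for an injected MachineLearningRepository instance, and the URL and model names are placeholders, not values from this diff:

```typescript
const url = 'http://immich-machine-learning:3003'; // placeholder endpoint

// detectFaces now returns dimensions plus faces, per the hunk above.
const { imageWidth, imageHeight, faces } = await ml.detectFaces(url, '/photos/example.jpg', {
  modelName: 'buffalo_l',
  minScore: 0.7,
});

// Both CLIP helpers unwrap response[ModelTask.SEARCH], so callers receive
// the embedding directly.
const imageEmbedding = await ml.encodeImage(url, '/photos/example.jpg', { modelName: 'ViT-B-32__openai' });
const textEmbedding = await ml.encodeText(url, 'golden retriever on a beach', { modelName: 'ViT-B-32__openai' });
```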