feat(ml): composable ml (#9973)

* modularize model classes

* various fixes

* expose port

* change response

* round coordinates

* simplify preload

* update server

* simplify interface

simplify

* update tests

* composable endpoint

* cleanup

fixes

remove unnecessary interface

support text input, cleanup

* ew camelcase

* update server

server fixes

fix typing

* ml fixes

update locustfile

fixes

* cleaner response

* better repo response

* update tests

formatting and typing

rename

* undo compose change

* linting

fix type

actually fix typing

* stricter typing

fix detection-only response

no need for defaultdict

* update spec file

update api

linting

* update e2e

* unnecessary dimension

* remove commented code

* remove duplicate code

* remove unused imports

* add batch dim
This commit is contained in:
Mert
2024-06-06 23:09:47 -04:00
committed by GitHub
parent 7a46f80ddc
commit 2b1b43a7e4
39 changed files with 982 additions and 999 deletions

View File

@@ -1,13 +1,16 @@
import { Injectable } from '@nestjs/common';
import { readFile } from 'node:fs/promises';
import { CLIPConfig, ModelConfig, RecognitionConfig } from 'src/dtos/model-config.dto';
import { CLIPConfig } from 'src/dtos/model-config.dto';
import {
CLIPMode,
DetectFaceResult,
ClipTextualResponse,
ClipVisualResponse,
FaceDetectionOptions,
FacialRecognitionResponse,
IMachineLearningRepository,
MachineLearningRequest,
ModelPayload,
ModelTask,
ModelType,
TextModelInput,
VisionModelInput,
} from 'src/interfaces/machine-learning.interface';
import { Instrumentation } from 'src/utils/instrumentation';
@@ -16,8 +19,8 @@ const errorPrefix = 'Machine learning request';
@Instrumentation()
@Injectable()
export class MachineLearningRepository implements IMachineLearningRepository {
private async predict<T>(url: string, input: TextModelInput | VisionModelInput, config: ModelConfig): Promise<T> {
const formData = await this.getFormData(input, config);
private async predict<T>(url: string, payload: ModelPayload, config: MachineLearningRequest): Promise<T> {
const formData = await this.getFormData(payload, config);
const res = await fetch(new URL('/predict', url), { method: 'POST', body: formData }).catch(
(error: Error | any) => {
@@ -26,50 +29,46 @@ export class MachineLearningRepository implements IMachineLearningRepository {
);
if (res.status >= 400) {
const modelType = config.modelType ? ` for ${config.modelType.replace('-', ' ')}` : '';
throw new Error(`${errorPrefix}${modelType} failed with status ${res.status}: ${res.statusText}`);
throw new Error(`${errorPrefix} '${JSON.stringify(config)}' failed with status ${res.status}: ${res.statusText}`);
}
return res.json();
}
detectFaces(url: string, input: VisionModelInput, config: RecognitionConfig): Promise<DetectFaceResult[]> {
return this.predict<DetectFaceResult[]>(url, input, { ...config, modelType: ModelType.FACIAL_RECOGNITION });
async detectFaces(url: string, imagePath: string, { modelName, minScore }: FaceDetectionOptions) {
const request = {
[ModelTask.FACIAL_RECOGNITION]: {
[ModelType.DETECTION]: { modelName, minScore },
[ModelType.RECOGNITION]: { modelName },
},
};
const response = await this.predict<FacialRecognitionResponse>(url, { imagePath }, request);
return {
imageHeight: response.imageHeight,
imageWidth: response.imageWidth,
faces: response[ModelTask.FACIAL_RECOGNITION],
};
}
encodeImage(url: string, input: VisionModelInput, config: CLIPConfig): Promise<number[]> {
return this.predict<number[]>(url, input, {
...config,
modelType: ModelType.CLIP,
mode: CLIPMode.VISION,
} as CLIPConfig);
async encodeImage(url: string, imagePath: string, { modelName }: CLIPConfig) {
const request = { [ModelTask.SEARCH]: { [ModelType.VISUAL]: { modelName } } };
const response = await this.predict<ClipVisualResponse>(url, { imagePath }, request);
return response[ModelTask.SEARCH];
}
encodeText(url: string, input: TextModelInput, config: CLIPConfig): Promise<number[]> {
return this.predict<number[]>(url, input, {
...config,
modelType: ModelType.CLIP,
mode: CLIPMode.TEXT,
} as CLIPConfig);
async encodeText(url: string, text: string, { modelName }: CLIPConfig) {
const request = { [ModelTask.SEARCH]: { [ModelType.TEXTUAL]: { modelName } } };
const response = await this.predict<ClipTextualResponse>(url, { text }, request);
return response[ModelTask.SEARCH];
}
private async getFormData(input: TextModelInput | VisionModelInput, config: ModelConfig): Promise<FormData> {
private async getFormData(payload: ModelPayload, config: MachineLearningRequest): Promise<FormData> {
const formData = new FormData();
const { enabled, modelName, modelType, ...options } = config;
if (!enabled) {
throw new Error(`${modelType} is not enabled`);
}
formData.append('entries', JSON.stringify(config));
formData.append('modelName', modelName);
if (modelType) {
formData.append('modelType', modelType);
}
if (options) {
formData.append('options', JSON.stringify(options));
}
if ('imagePath' in input) {
formData.append('image', new Blob([await readFile(input.imagePath)]));
} else if ('text' in input) {
formData.append('text', input.text);
if ('imagePath' in payload) {
formData.append('image', new Blob([await readFile(payload.imagePath)]));
} else if ('text' in payload) {
formData.append('text', payload.text);
} else {
throw new Error('Invalid input');
}