feat: preload textual model

Author: martabal
Date: 2024-09-16 17:53:43 +02:00
parent 4735db8e79
commit 708a53a1eb
17 changed files with 301 additions and 19 deletions

View File

@@ -120,6 +120,10 @@ export interface SystemConfig {
clip: {
enabled: boolean;
modelName: string;
+loadTextualModelOnConnection: {
+enabled: boolean;
+ttl: number;
+};
};
duplicateDetection: {
enabled: boolean;
@@ -270,6 +274,10 @@ export const defaults = Object.freeze<SystemConfig>({
clip: {
enabled: true,
modelName: 'ViT-B-32__openai',
+loadTextualModelOnConnection: {
+enabled: false,
+ttl: 300,
+},
},
duplicateDetection: {
enabled: true,
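
For a sense of how these two hunks fit together: the interface hunk declares the nested shape and the defaults hunk ships it disabled with a 300-second TTL. A minimal sketch of the slice an admin would see once the option is switched on (the SystemConfig import path, the 600 value, and ttl being measured in seconds are assumptions for illustration):

import { SystemConfig } from 'src/config'; // assumed import path

// Illustrative only: machineLearning.clip with preloading enabled and a 10-minute TTL.
const clip: SystemConfig['machineLearning']['clip'] = {
  enabled: true,
  modelName: 'ViT-B-32__openai',
  loadTextualModelOnConnection: {
    enabled: true,
    ttl: 600, // presumably seconds the ML service keeps the textual model resident
  },
};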

View File

@@ -1,6 +1,6 @@
import { ApiProperty } from '@nestjs/swagger';
import { Type } from 'class-transformer';
-import { IsNotEmpty, IsNumber, IsString, Max, Min } from 'class-validator';
+import { IsNotEmpty, IsNumber, IsObject, IsString, Max, Min, ValidateNested } from 'class-validator';
import { ValidateBoolean } from 'src/validation';
export class TaskConfig {
@@ -14,7 +14,20 @@ export class ModelConfig extends TaskConfig {
modelName!: string;
}
-export class CLIPConfig extends ModelConfig {}
+export class LoadTextualModelOnConnection extends TaskConfig {
+@IsNumber()
+@Min(0)
+@Type(() => Number)
+@ApiProperty({ type: 'number', format: 'int64' })
+ttl!: number;
+}
+export class CLIPConfig extends ModelConfig {
+@Type(() => LoadTextualModelOnConnection)
+@ValidateNested()
+@IsObject()
+loadTextualModelOnConnection!: LoadTextualModelOnConnection;
+}
export class DuplicateDetectionConfig extends TaskConfig {
@IsNumber()
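
To make the effect of the new decorators concrete, here is a small hand-rolled validation sketch: @Type(() => LoadTextualModelOnConnection) tells class-transformer how to instantiate the child object, and @ValidateNested() plus @IsObject() make class-validator recurse into it. The import path for CLIPConfig and the standalone plainToInstance/validate usage are assumptions for the example, not part of this commit.

import { plainToInstance } from 'class-transformer';
import { validate } from 'class-validator';
import { CLIPConfig } from 'src/dtos/model-config.dto'; // assumed path

async function isValidClipConfig(raw: object): Promise<boolean> {
  // Without @Type(), the nested payload would stay a plain object and the @Min(0)
  // rule on ttl would never run; with it, validate() descends into the child DTO.
  const dto = plainToInstance(CLIPConfig, raw);
  const errors = await validate(dto);
  return errors.length === 0;
}

// A negative ttl should now be rejected by the nested @Min(0) rule.
void isValidClipConfig({
  enabled: true,
  modelName: 'ViT-B-32__openai',
  loadTextualModelOnConnection: { enabled: true, ttl: -1 },
});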

View File

@@ -24,13 +24,17 @@ export type ModelPayload = { imagePath: string } | { text: string };
type ModelOptions = { modelName: string };
+export interface LoadModelOptions extends ModelOptions {
+ttl: number;
+}
export type FaceDetectionOptions = ModelOptions & { minScore: number };
type VisualResponse = { imageHeight: number; imageWidth: number };
export type ClipVisualRequest = { [ModelTask.SEARCH]: { [ModelType.VISUAL]: ModelOptions } };
export type ClipVisualResponse = { [ModelTask.SEARCH]: number[] } & VisualResponse;
-export type ClipTextualRequest = { [ModelTask.SEARCH]: { [ModelType.TEXTUAL]: ModelOptions } };
+export type ClipTextualRequest = { [ModelTask.SEARCH]: { [ModelType.TEXTUAL]: ModelOptions | LoadModelOptions } };
export type ClipTextualResponse = { [ModelTask.SEARCH]: number[] };
export type FacialRecognitionRequest = {
@@ -54,4 +58,5 @@ export interface IMachineLearningRepository {
encodeImage(url: string, imagePath: string, config: ModelOptions): Promise<number[]>;
encodeText(url: string, text: string, config: ModelOptions): Promise<number[]>;
detectFaces(url: string, imagePath: string, config: FaceDetectionOptions): Promise<DetectedFaces>;
+loadTextModel(url: string, config: ModelOptions): Promise<void>;
}
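
The widened ClipTextualRequest union is what lets the same request shape serve both ordinary text encoding and the new preload call: LoadModelOptions simply adds a ttl on top of ModelOptions. A sketch of the two payloads (the import path is assumed; the model name and ttl are illustrative):

import { ClipTextualRequest, ModelTask, ModelType } from 'src/interfaces/machine-learning.interface'; // assumed path

// Plain text-search request: ModelOptions only.
const encodeRequest: ClipTextualRequest = {
  [ModelTask.SEARCH]: { [ModelType.TEXTUAL]: { modelName: 'ViT-B-32__openai' } },
};

// Preload request: LoadModelOptions adds how long the ML service should keep the model warm.
const preloadRequest: ClipTextualRequest = {
  [ModelTask.SEARCH]: { [ModelType.TEXTUAL]: { modelName: 'ViT-B-32__openai', ttl: 300 } },
};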

View File

@@ -9,6 +9,7 @@ import {
WebSocketServer,
} from '@nestjs/websockets';
import { Server, Socket } from 'socket.io';
+import { SystemConfigCore } from 'src/cores/system-config.core';
import {
ArgsOf,
ClientEventMap,
@@ -19,6 +20,8 @@ import {
ServerEventMap,
} from 'src/interfaces/event.interface';
import { ILoggerRepository } from 'src/interfaces/logger.interface';
+import { IMachineLearningRepository } from 'src/interfaces/machine-learning.interface';
+import { ISystemMetadataRepository } from 'src/interfaces/system-metadata.interface';
import { AuthService } from 'src/services/auth.service';
import { Instrumentation } from 'src/utils/instrumentation';
@@ -33,6 +36,7 @@ type EmitHandlers = Partial<{ [T in EmitEvent]: EmitHandler<T>[] }>;
@Injectable()
export class EventRepository implements OnGatewayConnection, OnGatewayDisconnect, OnGatewayInit, IEventRepository {
private emitHandlers: EmitHandlers = {};
+private configCore: SystemConfigCore;
@WebSocketServer()
private server?: Server;
@@ -41,8 +45,11 @@ export class EventRepository implements OnGatewayConnection, OnGatewayDisconnect
private moduleRef: ModuleRef,
private eventEmitter: EventEmitter2,
@Inject(ILoggerRepository) private logger: ILoggerRepository,
+@Inject(IMachineLearningRepository) private machineLearningRepository: IMachineLearningRepository,
+@Inject(ISystemMetadataRepository) systemMetadataRepository: ISystemMetadataRepository,
) {
this.logger.setContext(EventRepository.name);
+this.configCore = SystemConfigCore.create(systemMetadataRepository, this.logger);
}
afterInit(server: Server) {
@@ -68,6 +75,16 @@ export class EventRepository implements OnGatewayConnection, OnGatewayDisconnect
queryParams: {},
metadata: { adminRoute: false, sharedLinkRoute: false, uri: '/api/socket.io' },
});
+if ('background' in client.handshake.query && client.handshake.query.background === 'false') {
+const { machineLearning } = await this.configCore.getConfig({ withCache: true });
+if (machineLearning.clip.loadTextualModelOnConnection.enabled) {
+try {
+this.machineLearningRepository.loadTextModel(machineLearning.url, machineLearning.clip);
+} catch (error) {
+this.logger.warn(error);
+}
+}
+}
await client.join(auth.user.id);
if (auth.session) {
await client.join(auth.session.id);
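
For context on the query check above: the flag is expected to come from the connecting client, so a foreground app would open its socket roughly like the sketch below (the server URL is a placeholder, authentication is omitted, and the convention that background clients send background=true is an assumption):

import { io } from 'socket.io-client';

// A foreground client advertises background=false; per handleConnection above, that is
// what triggers the textual CLIP model preload when the option is enabled.
const socket = io('https://immich.example.com', {
  path: '/api/socket.io', // matches the uri recorded in the handshake metadata above
  query: { background: 'false' },
});

socket.on('connect', () => console.log('connected; textual model preload requested'));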

View File

@@ -20,13 +20,9 @@ const errorPrefix = 'Machine learning request';
@Injectable()
export class MachineLearningRepository implements IMachineLearningRepository {
private async predict<T>(url: string, payload: ModelPayload, config: MachineLearningRequest): Promise<T> {
-const formData = await this.getFormData(payload, config);
+const formData = await this.getFormData(config, payload);
-const res = await fetch(new URL('/predict', url), { method: 'POST', body: formData }).catch(
-(error: Error | any) => {
-throw new Error(`${errorPrefix} to "${url}" failed with ${error?.cause || error}`);
-},
-);
+const res = await this.fetchData(url, '/predict', formData);
if (res.status >= 400) {
throw new Error(`${errorPrefix} '${JSON.stringify(config)}' failed with status ${res.status}: ${res.statusText}`);
@@ -34,6 +30,25 @@ export class MachineLearningRepository implements IMachineLearningRepository {
return res.json();
}
+private async fetchData(url: string, path: string, formData?: FormData): Promise<Response> {
+const res = await fetch(new URL(path, url), { method: 'POST', body: formData }).catch((error: Error | any) => {
+throw new Error(`${errorPrefix} to "${url}" failed with ${error?.cause || error}`);
+});
+return res;
+}
+async loadTextModel(url: string, { modelName, loadTextualModelOnConnection: { ttl } }: CLIPConfig) {
+try {
+const request = { [ModelTask.SEARCH]: { [ModelType.TEXTUAL]: { modelName, ttl } } };
+const formData = await this.getFormData(request);
+const res = await this.fetchData(url, '/load', formData);
+if (res.status >= 400) {
+throw new Error(`${errorPrefix} Loading textual model failed with status ${res.status}: ${res.statusText}`);
+}
+} catch (error) {}
+}
async detectFaces(url: string, imagePath: string, { modelName, minScore }: FaceDetectionOptions) {
const request = {
[ModelTask.FACIAL_RECOGNITION]: {
@@ -61,16 +76,17 @@ export class MachineLearningRepository implements IMachineLearningRepository {
return response[ModelTask.SEARCH];
}
-private async getFormData(payload: ModelPayload, config: MachineLearningRequest): Promise<FormData> {
+private async getFormData(config: MachineLearningRequest, payload?: ModelPayload): Promise<FormData> {
const formData = new FormData();
formData.append('entries', JSON.stringify(config));
-if ('imagePath' in payload) {
-formData.append('image', new Blob([await readFile(payload.imagePath)]));
-} else if ('text' in payload) {
-formData.append('text', payload.text);
-} else {
-throw new Error('Invalid input');
+if (payload) {
+if ('imagePath' in payload) {
+formData.append('image', new Blob([await readFile(payload.imagePath)]));
+} else if ('text' in payload) {
+formData.append('text', payload.text);
+} else {
+throw new Error('Invalid input');
+}
+}
return formData;
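
Putting the pieces together, the preload path builds a multipart body that carries only the 'entries' field (no image or text part, thanks to the now-optional payload argument) and posts it to the ML service's /load endpoint. A rough sketch of what goes over the wire; the lowercase 'search'/'textual' strings stand in for the ModelTask/ModelType enum values and are an assumption, as are the concrete model name and ttl:

// Roughly what loadTextModel() ends up sending for POST <machineLearning.url>/load.
const entries = {
  search: { textual: { modelName: 'ViT-B-32__openai', ttl: 300 } },
};

const body = new FormData();
body.append('entries', JSON.stringify(entries));
// fetchData() then issues: fetch(new URL('/load', url), { method: 'POST', body })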