mirror of
https://github.com/immich-app/immich.git
synced 2026-02-14 12:58:17 +03:00
feat: preload textual model
This commit is contained in:
@@ -120,6 +120,10 @@ export interface SystemConfig {
|
||||
clip: {
|
||||
enabled: boolean;
|
||||
modelName: string;
|
||||
loadTextualModelOnConnection: {
|
||||
enabled: boolean;
|
||||
ttl: number;
|
||||
};
|
||||
};
|
||||
duplicateDetection: {
|
||||
enabled: boolean;
|
||||
@@ -270,6 +274,10 @@ export const defaults = Object.freeze<SystemConfig>({
|
||||
clip: {
|
||||
enabled: true,
|
||||
modelName: 'ViT-B-32__openai',
|
||||
loadTextualModelOnConnection: {
|
||||
enabled: false,
|
||||
ttl: 300,
|
||||
},
|
||||
},
|
||||
duplicateDetection: {
|
||||
enabled: true,
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import { ApiProperty } from '@nestjs/swagger';
|
||||
import { Type } from 'class-transformer';
|
||||
import { IsNotEmpty, IsNumber, IsString, Max, Min } from 'class-validator';
|
||||
import { IsNotEmpty, IsNumber, IsObject, IsString, Max, Min, ValidateNested } from 'class-validator';
|
||||
import { ValidateBoolean } from 'src/validation';
|
||||
|
||||
export class TaskConfig {
|
||||
@@ -14,7 +14,20 @@ export class ModelConfig extends TaskConfig {
|
||||
modelName!: string;
|
||||
}
|
||||
|
||||
export class CLIPConfig extends ModelConfig {}
|
||||
/**
 * Validated settings for preloading the CLIP textual model when a client connects.
 *
 * Extends `TaskConfig`; presumably that base class contributes the `enabled` flag seen in the
 * `loadTextualModelOnConnection: { enabled, ttl }` config shape (TODO confirm against `TaskConfig`).
 */
export class LoadTextualModelOnConnection extends TaskConfig {
|
||||
// Must be a number that is at least 0; incoming values are converted with `Number` first.
// NOTE(review): the OpenAPI schema advertises `int64`, but `@IsNumber()` also accepts
// non-integer values - confirm whether an integer constraint is intended.
@IsNumber()
|
||||
@Min(0)
|
||||
@Type(() => Number)
|
||||
@ApiProperty({ type: 'number', format: 'int64' })
|
||||
// Time-to-live forwarded with the model load request (unit presumably seconds - TODO confirm).
ttl!: number;
|
||||
}
|
||||
|
||||
/**
 * CLIP model settings: everything from `ModelConfig` plus the nested
 * `loadTextualModelOnConnection` options.
 */
export class CLIPConfig extends ModelConfig {
|
||||
// `@Type` names the class used for the nested value so that `@ValidateNested` can apply the
// validators declared on `LoadTextualModelOnConnection`; `@IsObject` requires an object value.
@Type(() => LoadTextualModelOnConnection)
|
||||
@ValidateNested()
|
||||
@IsObject()
|
||||
loadTextualModelOnConnection!: LoadTextualModelOnConnection;
|
||||
}
|
||||
|
||||
export class DuplicateDetectionConfig extends TaskConfig {
|
||||
@IsNumber()
|
||||
|
||||
@@ -24,13 +24,17 @@ export type ModelPayload = { imagePath: string } | { text: string };
|
||||
|
||||
type ModelOptions = { modelName: string };
|
||||
|
||||
/**
 * Model options for a load request: the `ModelOptions` fields plus a `ttl`.
 * Accepted as an alternative to plain `ModelOptions` in `ClipTextualRequest`.
 */
export interface LoadModelOptions extends ModelOptions {
|
||||
// Time-to-live sent along with the model name (unit presumably seconds - TODO confirm).
ttl: number;
|
||||
}
|
||||
|
||||
export type FaceDetectionOptions = ModelOptions & { minScore: number };
|
||||
|
||||
type VisualResponse = { imageHeight: number; imageWidth: number };
|
||||
export type ClipVisualRequest = { [ModelTask.SEARCH]: { [ModelType.VISUAL]: ModelOptions } };
|
||||
export type ClipVisualResponse = { [ModelTask.SEARCH]: number[] } & VisualResponse;
|
||||
|
||||
export type ClipTextualRequest = { [ModelTask.SEARCH]: { [ModelType.TEXTUAL]: ModelOptions } };
|
||||
export type ClipTextualRequest = { [ModelTask.SEARCH]: { [ModelType.TEXTUAL]: ModelOptions | LoadModelOptions } };
|
||||
export type ClipTextualResponse = { [ModelTask.SEARCH]: number[] };
|
||||
|
||||
export type FacialRecognitionRequest = {
|
||||
@@ -54,4 +58,5 @@ export interface IMachineLearningRepository {
|
||||
encodeImage(url: string, imagePath: string, config: ModelOptions): Promise<number[]>;
|
||||
encodeText(url: string, text: string, config: ModelOptions): Promise<number[]>;
|
||||
detectFaces(url: string, imagePath: string, config: FaceDetectionOptions): Promise<DetectedFaces>;
|
||||
loadTextModel(url: string, config: ModelOptions): Promise<void>;
|
||||
}
|
||||
|
||||
@@ -9,6 +9,7 @@ import {
|
||||
WebSocketServer,
|
||||
} from '@nestjs/websockets';
|
||||
import { Server, Socket } from 'socket.io';
|
||||
import { SystemConfigCore } from 'src/cores/system-config.core';
|
||||
import {
|
||||
ArgsOf,
|
||||
ClientEventMap,
|
||||
@@ -19,6 +20,8 @@ import {
|
||||
ServerEventMap,
|
||||
} from 'src/interfaces/event.interface';
|
||||
import { ILoggerRepository } from 'src/interfaces/logger.interface';
|
||||
import { IMachineLearningRepository } from 'src/interfaces/machine-learning.interface';
|
||||
import { ISystemMetadataRepository } from 'src/interfaces/system-metadata.interface';
|
||||
import { AuthService } from 'src/services/auth.service';
|
||||
import { Instrumentation } from 'src/utils/instrumentation';
|
||||
|
||||
@@ -33,6 +36,7 @@ type EmitHandlers = Partial<{ [T in EmitEvent]: EmitHandler<T>[] }>;
|
||||
@Injectable()
|
||||
export class EventRepository implements OnGatewayConnection, OnGatewayDisconnect, OnGatewayInit, IEventRepository {
|
||||
private emitHandlers: EmitHandlers = {};
|
||||
private configCore: SystemConfigCore;
|
||||
|
||||
@WebSocketServer()
|
||||
private server?: Server;
|
||||
@@ -41,8 +45,11 @@ export class EventRepository implements OnGatewayConnection, OnGatewayDisconnect
|
||||
private moduleRef: ModuleRef,
|
||||
private eventEmitter: EventEmitter2,
|
||||
@Inject(ILoggerRepository) private logger: ILoggerRepository,
|
||||
@Inject(IMachineLearningRepository) private machineLearningRepository: IMachineLearningRepository,
|
||||
@Inject(ISystemMetadataRepository) systemMetadataRepository: ISystemMetadataRepository,
|
||||
) {
|
||||
this.logger.setContext(EventRepository.name);
|
||||
this.configCore = SystemConfigCore.create(systemMetadataRepository, this.logger);
|
||||
}
|
||||
|
||||
afterInit(server: Server) {
|
||||
@@ -68,6 +75,16 @@ export class EventRepository implements OnGatewayConnection, OnGatewayDisconnect
|
||||
queryParams: {},
|
||||
metadata: { adminRoute: false, sharedLinkRoute: false, uri: '/api/socket.io' },
|
||||
});
|
||||
if ('background' in client.handshake.query && client.handshake.query.background === 'false') {
|
||||
const { machineLearning } = await this.configCore.getConfig({ withCache: true });
|
||||
if (machineLearning.clip.loadTextualModelOnConnection.enabled) {
|
||||
try {
|
||||
this.machineLearningRepository.loadTextModel(machineLearning.url, machineLearning.clip);
|
||||
} catch (error) {
|
||||
this.logger.warn(error);
|
||||
}
|
||||
}
|
||||
}
|
||||
await client.join(auth.user.id);
|
||||
if (auth.session) {
|
||||
await client.join(auth.session.id);
|
||||
|
||||
@@ -20,13 +20,9 @@ const errorPrefix = 'Machine learning request';
|
||||
@Injectable()
|
||||
export class MachineLearningRepository implements IMachineLearningRepository {
|
||||
private async predict<T>(url: string, payload: ModelPayload, config: MachineLearningRequest): Promise<T> {
|
||||
const formData = await this.getFormData(payload, config);
|
||||
const formData = await this.getFormData(config, payload);
|
||||
|
||||
const res = await fetch(new URL('/predict', url), { method: 'POST', body: formData }).catch(
|
||||
(error: Error | any) => {
|
||||
throw new Error(`${errorPrefix} to "${url}" failed with ${error?.cause || error}`);
|
||||
},
|
||||
);
|
||||
const res = await this.fetchData(url, '/predict', formData);
|
||||
|
||||
if (res.status >= 400) {
|
||||
throw new Error(`${errorPrefix} '${JSON.stringify(config)}' failed with status ${res.status}: ${res.statusText}`);
|
||||
@@ -34,6 +30,25 @@ export class MachineLearningRepository implements IMachineLearningRepository {
|
||||
return res.json();
|
||||
}
|
||||
|
||||
/**
 * Sends a POST request, with an optional multipart body, to `path` resolved against `url`.
 *
 * @param url base URL the request is sent to
 * @param path endpoint path, resolved relative to `url`
 * @param formData optional request body
 * @returns the raw response; the HTTP status is not inspected here
 * @throws Error prefixed with `errorPrefix` when the request itself fails; the message
 *   carries the underlying `cause` when one is present, otherwise the error itself
 */
private async fetchData(url: string, path: string, formData?: FormData): Promise<Response> {
  // Built outside the try block so that an invalid URL is not re-wrapped below.
  const endpoint = new URL(path, url);

  try {
    return await fetch(endpoint, { method: 'POST', body: formData });
  } catch (error: any) {
    throw new Error(`${errorPrefix} to "${url}" failed with ${error?.cause || error}`);
  }
}
|
||||
|
||||
/**
 * Asks the machine learning service, via its `/load` endpoint, to load the CLIP textual model.
 *
 * Best-effort: this method never rejects. The websocket connection handler calls it without
 * awaiting the returned promise, so a rejection here would surface as an unhandled promise
 * rejection. Failures are reported as warnings instead of being silently discarded.
 *
 * @param url base URL of the machine learning service
 * @param config CLIP settings; `modelName` and `loadTextualModelOnConnection.ttl` are sent in the request
 */
async loadTextModel(url: string, { modelName, loadTextualModelOnConnection: { ttl } }: CLIPConfig): Promise<void> {
  try {
    const request = { [ModelTask.SEARCH]: { [ModelType.TEXTUAL]: { modelName, ttl } } };
    const formData = await this.getFormData(request);
    const res = await this.fetchData(url, '/load', formData);

    if (res.status >= 400) {
      throw new Error(`${errorPrefix} to load textual model failed with status ${res.status}: ${res.statusText}`);
    }
  } catch (error: unknown) {
    // Preloading is only an optimisation, so a failure must not propagate - but it should be visible.
    // NOTE(review): no logger is visible on this class; switch to the project logger if one is injected.
    console.warn(error instanceof Error ? error.message : error);
  }
}
|
||||
|
||||
async detectFaces(url: string, imagePath: string, { modelName, minScore }: FaceDetectionOptions) {
|
||||
const request = {
|
||||
[ModelTask.FACIAL_RECOGNITION]: {
|
||||
@@ -61,16 +76,17 @@ export class MachineLearningRepository implements IMachineLearningRepository {
|
||||
return response[ModelTask.SEARCH];
|
||||
}
|
||||
|
||||
private async getFormData(payload: ModelPayload, config: MachineLearningRequest): Promise<FormData> {
|
||||
private async getFormData(config: MachineLearningRequest, payload?: ModelPayload): Promise<FormData> {
|
||||
const formData = new FormData();
|
||||
formData.append('entries', JSON.stringify(config));
|
||||
|
||||
if ('imagePath' in payload) {
|
||||
formData.append('image', new Blob([await readFile(payload.imagePath)]));
|
||||
} else if ('text' in payload) {
|
||||
formData.append('text', payload.text);
|
||||
} else {
|
||||
throw new Error('Invalid input');
|
||||
if (payload) {
|
||||
if ('imagePath' in payload) {
|
||||
formData.append('image', new Blob([await readFile(payload.imagePath)]));
|
||||
} else if ('text' in payload) {
|
||||
formData.append('text', payload.text);
|
||||
} else {
|
||||
throw new Error('Invalid input');
|
||||
}
|
||||
}
|
||||
|
||||
return formData;
|
||||
|
||||
Reference in New Issue
Block a user