refactor(server): add config events for clip (#11575)

use config events for clip, add tests

formatting
This commit is contained in:
Mert
2024-08-04 17:00:36 -04:00
committed by GitHub
parent 3f4b783889
commit 4ed75f2ac9
5 changed files with 288 additions and 53 deletions

View File

@@ -21,7 +21,6 @@ import {
} from 'src/interfaces/search.interface';
import { asVector, searchAssetBuilder } from 'src/utils/database';
import { Instrumentation } from 'src/utils/instrumentation';
import { getCLIPModelInfo } from 'src/utils/misc';
import { Paginated, PaginationMode, PaginationResult, paginatedBuilder } from 'src/utils/pagination';
import { isValidInteger } from 'src/validation';
import { Repository, SelectQueryBuilder } from 'typeorm';
@@ -55,17 +54,6 @@ export class SearchRepository implements ISearchRepository {
' INNER JOIN cte ON asset.id = cte."assetId" ORDER BY exif.city';
}
async init(modelName: string): Promise<void> {
const { dimSize } = getCLIPModelInfo(modelName);
const curDimSize = await this.getDimSize();
this.logger.verbose(`Current database CLIP dimension size is ${curDimSize}`);
if (dimSize != curDimSize) {
this.logger.log(`Dimension size of model ${modelName} is ${dimSize}, but database expects ${curDimSize}.`);
await this.updateDimSize(dimSize);
}
}
@GenerateSql({
params: [
{ page: 1, size: 100 },
@@ -300,32 +288,7 @@ export class SearchRepository implements ISearchRepository {
);
}
private async updateDimSize(dimSize: number): Promise<void> {
if (!isValidInteger(dimSize, { min: 1, max: 2 ** 16 })) {
throw new Error(`Invalid CLIP dimension size: ${dimSize}`);
}
const curDimSize = await this.getDimSize();
if (curDimSize === dimSize) {
return;
}
this.logger.log(`Updating database CLIP dimension size to ${dimSize}.`);
await this.smartSearchRepository.manager.transaction(async (manager) => {
await manager.clear(SmartSearchEntity);
await manager.query(`ALTER TABLE smart_search ALTER COLUMN embedding SET DATA TYPE vector(${dimSize})`);
await manager.query(`REINDEX INDEX clip_index`);
});
this.logger.log(`Successfully updated database CLIP dimension size from ${curDimSize} to ${dimSize}.`);
}
deleteAllSearchEmbeddings(): Promise<void> {
return this.smartSearchRepository.clear();
}
private async getDimSize(): Promise<number> {
async getDimensionSize(): Promise<number> {
const res = await this.smartSearchRepository.manager.query(`
SELECT atttypmod as dimsize
FROM pg_attribute f
@@ -342,6 +305,22 @@ export class SearchRepository implements ISearchRepository {
return dimSize;
}
setDimensionSize(dimSize: number): Promise<void> {
if (!isValidInteger(dimSize, { min: 1, max: 2 ** 16 })) {
throw new Error(`Invalid CLIP dimension size: ${dimSize}`);
}
return this.smartSearchRepository.manager.transaction(async (manager) => {
await manager.clear(SmartSearchEntity);
await manager.query(`ALTER TABLE smart_search ALTER COLUMN embedding SET DATA TYPE vector(${dimSize})`);
await manager.query(`REINDEX INDEX clip_index`);
});
}
async deleteAllSearchEmbeddings(): Promise<void> {
return this.smartSearchRepository.clear();
}
private getRuntimeConfig(numResults?: number): string {
if (getVectorExtension() === DatabaseExtension.VECTOR) {
return 'SET LOCAL hnsw.ef_search = 1000;'; // mitigate post-filter recall