feat: vectorchord (#18042)

* wip

auto-detect available extensions

auto-recovery, fix reindexing check

use original image for ml

* set probes

* update image for sql checker

update images for gha

* cascade

* fix new instance

* accurate dummy vector

* simplify dummy

* preexisiting pg docs

* handle different db name

* maybe fix sql generation

* revert refreshfaces sql change

* redundant switch

* outdated message

* update docker compose files

* Update docs/docs/administration/postgres-standalone.md

Co-authored-by: Daniel Dietzler <36593685+danieldietzler@users.noreply.github.com>

* tighten range

* avoid always printing "vector reindexing complete"

* remove nesting

* use new images

* add vchord to unit tests

* debug e2e image

* mention 1.107.2 in startup error

* support new vchord versions

---------

Co-authored-by: Daniel Dietzler <36593685+danieldietzler@users.noreply.github.com>
This commit is contained in:
Mert
2025-05-20 09:36:43 -04:00
committed by GitHub
parent fe71894308
commit 0d773af6c3
35 changed files with 572 additions and 444 deletions

View File

@@ -89,7 +89,7 @@ describe('getEnv', () => {
password: 'postgres',
},
skipMigrations: false,
vectorExtension: 'vectors',
vectorExtension: undefined,
});
});

View File

@@ -58,7 +58,7 @@ export interface EnvData {
database: {
config: DatabaseConnectionParams;
skipMigrations: boolean;
vectorExtension: VectorExtension;
vectorExtension?: VectorExtension;
};
licensePublicKey: {
@@ -196,6 +196,22 @@ const getEnv = (): EnvData => {
ssl: dto.DB_SSL_MODE || undefined,
};
let vectorExtension: VectorExtension | undefined;
switch (dto.DB_VECTOR_EXTENSION) {
case 'pgvector': {
vectorExtension = DatabaseExtension.VECTOR;
break;
}
case 'pgvecto.rs': {
vectorExtension = DatabaseExtension.VECTORS;
break;
}
case 'vectorchord': {
vectorExtension = DatabaseExtension.VECTORCHORD;
break;
}
}
return {
host: dto.IMMICH_HOST,
port: dto.IMMICH_PORT || 2283,
@@ -251,7 +267,7 @@ const getEnv = (): EnvData => {
database: {
config: databaseConnection,
skipMigrations: dto.DB_SKIP_MIGRATIONS ?? false,
vectorExtension: dto.DB_VECTOR_EXTENSION === 'pgvector' ? DatabaseExtension.VECTOR : DatabaseExtension.VECTORS,
vectorExtension,
},
licensePublicKey: isProd ? productionKeys : stagingKeys,

View File

@@ -5,7 +5,16 @@ import { InjectKysely } from 'nestjs-kysely';
import { readdir } from 'node:fs/promises';
import { join, resolve } from 'node:path';
import semver from 'semver';
import { EXTENSION_NAMES, POSTGRES_VERSION_RANGE, VECTOR_VERSION_RANGE, VECTORS_VERSION_RANGE } from 'src/constants';
import {
EXTENSION_NAMES,
POSTGRES_VERSION_RANGE,
VECTOR_EXTENSIONS,
VECTOR_INDEX_TABLES,
VECTOR_VERSION_RANGE,
VECTORCHORD_LIST_SLACK_FACTOR,
VECTORCHORD_VERSION_RANGE,
VECTORS_VERSION_RANGE,
} from 'src/constants';
import { DB } from 'src/db';
import { GenerateSql } from 'src/decorators';
import { DatabaseExtension, DatabaseLock, VectorIndex } from 'src/enum';
@@ -14,11 +23,42 @@ import { LoggingRepository } from 'src/repositories/logging.repository';
import { ExtensionVersion, VectorExtension, VectorUpdateResult } from 'src/types';
import { vectorIndexQuery } from 'src/utils/database';
import { isValidInteger } from 'src/validation';
import { DataSource } from 'typeorm';
import { DataSource, QueryRunner } from 'typeorm';
export let cachedVectorExtension: VectorExtension | undefined;
export async function getVectorExtension(runner: Kysely<DB> | QueryRunner): Promise<VectorExtension> {
if (cachedVectorExtension) {
return cachedVectorExtension;
}
cachedVectorExtension = new ConfigRepository().getEnv().database.vectorExtension;
if (cachedVectorExtension) {
return cachedVectorExtension;
}
let availableExtensions: { name: VectorExtension }[];
const query = `SELECT name FROM pg_available_extensions WHERE name IN (${VECTOR_EXTENSIONS.map((ext) => `'${ext}'`).join(', ')})`;
if (runner instanceof Kysely) {
const { rows } = await sql.raw<{ name: VectorExtension }>(query).execute(runner);
availableExtensions = rows;
} else {
availableExtensions = (await runner.query(query)) as { name: VectorExtension }[];
}
const extensionNames = new Set(availableExtensions.map((row) => row.name));
cachedVectorExtension = VECTOR_EXTENSIONS.find((ext) => extensionNames.has(ext));
if (!cachedVectorExtension) {
throw new Error(`No vector extension found. Available extensions: ${VECTOR_EXTENSIONS.join(', ')}`);
}
return cachedVectorExtension;
}
export const probes: Record<VectorIndex, number> = {
[VectorIndex.CLIP]: 1,
[VectorIndex.FACE]: 1,
};
@Injectable()
export class DatabaseRepository {
private vectorExtension: VectorExtension;
private readonly asyncLock = new AsyncLock();
constructor(
@@ -26,7 +66,6 @@ export class DatabaseRepository {
private logger: LoggingRepository,
private configRepository: ConfigRepository,
) {
this.vectorExtension = configRepository.getEnv().database.vectorExtension;
this.logger.setContext(DatabaseRepository.name);
}
@@ -34,6 +73,10 @@ export class DatabaseRepository {
await this.db.destroy();
}
getVectorExtension(): Promise<VectorExtension> {
return getVectorExtension(this.db);
}
@GenerateSql({ params: [DatabaseExtension.VECTORS] })
async getExtensionVersion(extension: DatabaseExtension): Promise<ExtensionVersion> {
const { rows } = await sql<ExtensionVersion>`
@@ -45,7 +88,20 @@ export class DatabaseRepository {
}
getExtensionVersionRange(extension: VectorExtension): string {
return extension === DatabaseExtension.VECTORS ? VECTORS_VERSION_RANGE : VECTOR_VERSION_RANGE;
switch (extension) {
case DatabaseExtension.VECTORCHORD: {
return VECTORCHORD_VERSION_RANGE;
}
case DatabaseExtension.VECTORS: {
return VECTORS_VERSION_RANGE;
}
case DatabaseExtension.VECTOR: {
return VECTOR_VERSION_RANGE;
}
default: {
throw new Error(`Unsupported vector extension: '${extension}'`);
}
}
}
@GenerateSql()
@@ -59,7 +115,14 @@ export class DatabaseRepository {
}
async createExtension(extension: DatabaseExtension): Promise<void> {
await sql`CREATE EXTENSION IF NOT EXISTS ${sql.raw(extension)}`.execute(this.db);
await sql`CREATE EXTENSION IF NOT EXISTS ${sql.raw(extension)} CASCADE`.execute(this.db);
if (extension === DatabaseExtension.VECTORCHORD) {
const dbName = sql.table(await this.getDatabaseName());
await sql`ALTER DATABASE ${dbName} SET vchordrq.prewarm_dim = '512,640,768,1024,1152,1536'`.execute(this.db);
await sql`SET vchordrq.prewarm_dim = '512,640,768,1024,1152,1536'`.execute(this.db);
await sql`ALTER DATABASE ${dbName} SET vchordrq.probes = 1`.execute(this.db);
await sql`SET vchordrq.probes = 1`.execute(this.db);
}
}
async updateVectorExtension(extension: VectorExtension, targetVersion?: string): Promise<VectorUpdateResult> {
@@ -78,120 +141,201 @@ export class DatabaseRepository {
await this.db.transaction().execute(async (tx) => {
await this.setSearchPath(tx);
if (isVectors && installedVersion === '0.1.1') {
await this.setExtVersion(tx, DatabaseExtension.VECTORS, '0.1.11');
}
const isSchemaUpgrade = semver.satisfies(installedVersion, '0.1.1 || 0.1.11');
if (isSchemaUpgrade && isVectors) {
await this.updateVectorsSchema(tx);
}
await sql`ALTER EXTENSION ${sql.raw(extension)} UPDATE TO ${sql.lit(targetVersion)}`.execute(tx);
const diff = semver.diff(installedVersion, targetVersion);
if (isVectors && diff && ['minor', 'major'].includes(diff)) {
if (isVectors && (diff === 'major' || diff === 'minor')) {
await sql`SELECT pgvectors_upgrade()`.execute(tx);
restartRequired = true;
} else {
await this.reindex(VectorIndex.CLIP);
await this.reindex(VectorIndex.FACE);
} else if (diff) {
await Promise.all([this.reindexVectors(VectorIndex.CLIP), this.reindexVectors(VectorIndex.FACE)]);
}
});
return { restartRequired };
}
async reindex(index: VectorIndex): Promise<void> {
try {
await sql`REINDEX INDEX ${sql.raw(index)}`.execute(this.db);
} catch (error) {
if (this.vectorExtension !== DatabaseExtension.VECTORS) {
throw error;
}
this.logger.warn(`Could not reindex index ${index}. Attempting to auto-fix.`);
async prewarm(index: VectorIndex): Promise<void> {
const vectorExtension = await getVectorExtension(this.db);
if (vectorExtension !== DatabaseExtension.VECTORCHORD) {
return;
}
this.logger.debug(`Prewarming ${index}`);
await sql`SELECT vchordrq_prewarm(${index})`.execute(this.db);
}
const table = await this.getIndexTable(index);
const dimSize = await this.getDimSize(table);
await this.db.transaction().execute(async (tx) => {
await this.setSearchPath(tx);
await sql`DROP INDEX IF EXISTS ${sql.raw(index)}`.execute(tx);
await sql`ALTER TABLE ${sql.raw(table)} ALTER COLUMN embedding SET DATA TYPE real[]`.execute(tx);
await sql`ALTER TABLE ${sql.raw(table)} ALTER COLUMN embedding SET DATA TYPE vector(${sql.raw(String(dimSize))})`.execute(
tx,
);
await sql.raw(vectorIndexQuery({ vectorExtension: this.vectorExtension, table, indexName: index })).execute(tx);
});
async reindexVectorsIfNeeded(names: VectorIndex[]): Promise<void> {
const { rows } = await sql<{
indexdef: string;
indexname: string;
}>`SELECT indexdef, indexname FROM pg_indexes WHERE indexname = ANY(ARRAY[${sql.join(names)}])`.execute(this.db);
const vectorExtension = await getVectorExtension(this.db);
const promises = [];
for (const indexName of names) {
const row = rows.find((index) => index.indexname === indexName);
const table = VECTOR_INDEX_TABLES[indexName];
if (!row) {
promises.push(this.reindexVectors(indexName));
continue;
}
switch (vectorExtension) {
case DatabaseExtension.VECTOR: {
if (!row.indexdef.toLowerCase().includes('using hnsw')) {
promises.push(this.reindexVectors(indexName));
}
break;
}
case DatabaseExtension.VECTORS: {
if (!row.indexdef.toLowerCase().includes('using vectors')) {
promises.push(this.reindexVectors(indexName));
}
break;
}
case DatabaseExtension.VECTORCHORD: {
const matches = row.indexdef.match(/(?<=lists = \[)\d+/g);
const lists = matches && matches.length > 0 ? Number(matches[0]) : 1;
promises.push(
this.db
.selectFrom(this.db.dynamic.table(table).as('t'))
.select((eb) => eb.fn.countAll<number>().as('count'))
.executeTakeFirstOrThrow()
.then(({ count }) => {
const targetLists = this.targetListCount(count);
this.logger.log(`targetLists=${targetLists}, current=${lists} for ${indexName} of ${count} rows`);
if (
!row.indexdef.toLowerCase().includes('using vchordrq') ||
// slack factor is to avoid frequent reindexing if the count is borderline
(lists !== targetLists && lists !== this.targetListCount(count * VECTORCHORD_LIST_SLACK_FACTOR))
) {
probes[indexName] = this.targetProbeCount(targetLists);
return this.reindexVectors(indexName, { lists: targetLists });
} else {
probes[indexName] = this.targetProbeCount(lists);
}
}),
);
break;
}
}
}
if (promises.length > 0) {
await Promise.all(promises);
}
}
@GenerateSql({ params: [VectorIndex.CLIP] })
async shouldReindex(name: VectorIndex): Promise<boolean> {
if (this.vectorExtension !== DatabaseExtension.VECTORS) {
return false;
private async reindexVectors(indexName: VectorIndex, { lists }: { lists?: number } = {}): Promise<void> {
this.logger.log(`Reindexing ${indexName}`);
const table = VECTOR_INDEX_TABLES[indexName];
const vectorExtension = await getVectorExtension(this.db);
const { rows } = await sql<{
columnName: string;
}>`SELECT column_name as "columnName" FROM information_schema.columns WHERE table_name = ${table}`.execute(this.db);
if (rows.length === 0) {
this.logger.warn(
`Table ${table} does not exist, skipping reindexing. This is only normal if this is a new Immich instance.`,
);
return;
}
try {
const { rows } = await sql<{
idx_status: string;
}>`SELECT idx_status FROM pg_vector_index_stat WHERE indexname = ${name}`.execute(this.db);
return rows[0]?.idx_status === 'UPGRADE';
} catch (error) {
const message: string = (error as any).message;
if (message.includes('index is not existing')) {
return true;
} else if (message.includes('relation "pg_vector_index_stat" does not exist')) {
return false;
const dimSize = await this.getDimensionSize(table);
await this.db.transaction().execute(async (tx) => {
await sql`DROP INDEX IF EXISTS ${sql.raw(indexName)}`.execute(tx);
if (!rows.some((row) => row.columnName === 'embedding')) {
this.logger.warn(`Column 'embedding' does not exist in table '${table}', truncating and adding column.`);
await sql`TRUNCATE TABLE ${sql.raw(table)}`.execute(tx);
await sql`ALTER TABLE ${sql.raw(table)} ADD COLUMN embedding real[] NOT NULL`.execute(tx);
}
throw error;
}
await sql`ALTER TABLE ${sql.raw(table)} ALTER COLUMN embedding SET DATA TYPE real[]`.execute(tx);
const schema = vectorExtension === DatabaseExtension.VECTORS ? 'vectors.' : '';
await sql`
ALTER TABLE ${sql.raw(table)}
ALTER COLUMN embedding
SET DATA TYPE ${sql.raw(schema)}vector(${sql.raw(String(dimSize))})`.execute(tx);
await sql.raw(vectorIndexQuery({ vectorExtension, table, indexName, lists })).execute(tx);
});
this.logger.log(`Reindexed ${indexName}`);
}
private async setSearchPath(tx: Transaction<DB>): Promise<void> {
await sql`SET search_path TO "$user", public, vectors`.execute(tx);
}
private async setExtVersion(tx: Transaction<DB>, extName: DatabaseExtension, version: string): Promise<void> {
await sql`UPDATE pg_catalog.pg_extension SET extversion = ${version} WHERE extname = ${extName}`.execute(tx);
private async getDatabaseName(): Promise<string> {
const { rows } = await sql<{ db: string }>`SELECT current_database() as db`.execute(this.db);
return rows[0].db;
}
private async getIndexTable(index: VectorIndex): Promise<string> {
const { rows } = await sql<{
relname: string | null;
}>`SELECT relname FROM pg_stat_all_indexes WHERE indexrelname = ${index}`.execute(this.db);
const table = rows[0]?.relname;
if (!table) {
throw new Error(`Could not find table for index ${index}`);
}
return table;
}
private async updateVectorsSchema(tx: Transaction<DB>): Promise<void> {
const extension = DatabaseExtension.VECTORS;
await sql`CREATE SCHEMA IF NOT EXISTS ${extension}`.execute(tx);
await sql`UPDATE pg_catalog.pg_extension SET extrelocatable = true WHERE extname = ${extension}`.execute(tx);
await sql`ALTER EXTENSION vectors SET SCHEMA vectors`.execute(tx);
await sql`UPDATE pg_catalog.pg_extension SET extrelocatable = false WHERE extname = ${extension}`.execute(tx);
}
private async getDimSize(table: string, column = 'embedding'): Promise<number> {
async getDimensionSize(table: string, column = 'embedding'): Promise<number> {
const { rows } = await sql<{ dimsize: number }>`
SELECT atttypmod as dimsize
FROM pg_attribute f
JOIN pg_class c ON c.oid = f.attrelid
WHERE c.relkind = 'r'::char
AND f.attnum > 0
AND c.relname = ${table}
AND f.attname = '${column}'
AND c.relname = ${table}::text
AND f.attname = ${column}::text
`.execute(this.db);
const dimSize = rows[0]?.dimsize;
if (!isValidInteger(dimSize, { min: 1, max: 2 ** 16 })) {
throw new Error(`Could not retrieve dimension size`);
this.logger.warn(`Could not retrieve dimension size of column '${column}' in table '${table}', assuming 512`);
return 512;
}
return dimSize;
}
async setDimensionSize(dimSize: number): Promise<void> {
if (!isValidInteger(dimSize, { min: 1, max: 2 ** 16 })) {
throw new Error(`Invalid CLIP dimension size: ${dimSize}`);
}
// this is done in two transactions to handle concurrent writes
await this.db.transaction().execute(async (trx) => {
await sql`delete from ${sql.table('smart_search')}`.execute(trx);
await trx.schema.alterTable('smart_search').dropConstraint('dim_size_constraint').ifExists().execute();
await sql`alter table ${sql.table('smart_search')} add constraint dim_size_constraint check (array_length(embedding::real[], 1) = ${sql.lit(dimSize)})`.execute(
trx,
);
});
const vectorExtension = await this.getVectorExtension();
await this.db.transaction().execute(async (trx) => {
await sql`drop index if exists clip_index`.execute(trx);
await trx.schema
.alterTable('smart_search')
.alterColumn('embedding', (col) => col.setDataType(sql.raw(`vector(${dimSize})`)))
.execute();
await sql
.raw(vectorIndexQuery({ vectorExtension, table: 'smart_search', indexName: VectorIndex.CLIP }))
.execute(trx);
await trx.schema.alterTable('smart_search').dropConstraint('dim_size_constraint').ifExists().execute();
});
probes[VectorIndex.CLIP] = 1;
await sql`vacuum analyze ${sql.table('smart_search')}`.execute(this.db);
}
async deleteAllSearchEmbeddings(): Promise<void> {
await sql`truncate ${sql.table('smart_search')}`.execute(this.db);
}
private targetListCount(count: number) {
if (count < 128_000) {
return 1;
} else if (count < 2_048_000) {
return 1 << (32 - Math.clz32(count / 1000));
} else {
return 1 << (33 - Math.clz32(Math.sqrt(count)));
}
}
private targetProbeCount(lists: number) {
return Math.ceil(lists / 8);
}
async runMigrations(options?: { transaction?: 'all' | 'none' | 'each' }): Promise<void> {
const { database } = this.configRepository.getEnv();

View File

@@ -398,6 +398,7 @@ export class PersonRepository {
return results.map(({ id }) => id);
}
@GenerateSql({ params: [[], [], [{ faceId: DummyValue.UUID, embedding: DummyValue.VECTOR }]] })
async refreshFaces(
facesToAdd: (Insertable<AssetFaces> & { assetId: string })[],
faceIdsToRemove: string[],

View File

@@ -5,9 +5,9 @@ import { randomUUID } from 'node:crypto';
import { DB, Exif } from 'src/db';
import { DummyValue, GenerateSql } from 'src/decorators';
import { MapAsset } from 'src/dtos/asset-response.dto';
import { AssetStatus, AssetType, AssetVisibility } from 'src/enum';
import { ConfigRepository } from 'src/repositories/config.repository';
import { anyUuid, asUuid, searchAssetBuilder, vectorIndexQuery } from 'src/utils/database';
import { AssetStatus, AssetType, AssetVisibility, VectorIndex } from 'src/enum';
import { probes } from 'src/repositories/database.repository';
import { anyUuid, asUuid, searchAssetBuilder } from 'src/utils/database';
import { paginationHelper } from 'src/utils/pagination';
import { isValidInteger } from 'src/validation';
@@ -168,10 +168,7 @@ export interface GetCameraMakesOptions {
@Injectable()
export class SearchRepository {
constructor(
@InjectKysely() private db: Kysely<DB>,
private configRepository: ConfigRepository,
) {}
constructor(@InjectKysely() private db: Kysely<DB>) {}
@GenerateSql({
params: [
@@ -236,19 +233,21 @@ export class SearchRepository {
},
],
})
async searchSmart(pagination: SearchPaginationOptions, options: SmartSearchOptions) {
searchSmart(pagination: SearchPaginationOptions, options: SmartSearchOptions) {
if (!isValidInteger(pagination.size, { min: 1, max: 1000 })) {
throw new Error(`Invalid value for 'size': ${pagination.size}`);
}
const items = await searchAssetBuilder(this.db, options)
.innerJoin('smart_search', 'assets.id', 'smart_search.assetId')
.orderBy(sql`smart_search.embedding <=> ${options.embedding}`)
.limit(pagination.size + 1)
.offset((pagination.page - 1) * pagination.size)
.execute();
return paginationHelper(items, pagination.size);
return this.db.transaction().execute(async (trx) => {
await sql`set local vchordrq.probes = ${sql.lit(probes[VectorIndex.CLIP])}`.execute(trx);
const items = await searchAssetBuilder(trx, options)
.innerJoin('smart_search', 'assets.id', 'smart_search.assetId')
.orderBy(sql`smart_search.embedding <=> ${options.embedding}`)
.limit(pagination.size + 1)
.offset((pagination.page - 1) * pagination.size)
.execute();
return paginationHelper(items, pagination.size);
});
}
@GenerateSql({
@@ -263,29 +262,32 @@ export class SearchRepository {
],
})
searchDuplicates({ assetId, embedding, maxDistance, type, userIds }: AssetDuplicateSearch) {
return this.db
.with('cte', (qb) =>
qb
.selectFrom('assets')
.select([
'assets.id as assetId',
'assets.duplicateId',
sql<number>`smart_search.embedding <=> ${embedding}`.as('distance'),
])
.innerJoin('smart_search', 'assets.id', 'smart_search.assetId')
.where('assets.ownerId', '=', anyUuid(userIds))
.where('assets.deletedAt', 'is', null)
.where('assets.visibility', '!=', AssetVisibility.HIDDEN)
.where('assets.type', '=', type)
.where('assets.id', '!=', asUuid(assetId))
.where('assets.stackId', 'is', null)
.orderBy(sql`smart_search.embedding <=> ${embedding}`)
.limit(64),
)
.selectFrom('cte')
.selectAll()
.where('cte.distance', '<=', maxDistance as number)
.execute();
return this.db.transaction().execute(async (trx) => {
await sql`set local vchordrq.probes = ${sql.lit(probes[VectorIndex.CLIP])}`.execute(trx);
return await trx
.with('cte', (qb) =>
qb
.selectFrom('assets')
.select([
'assets.id as assetId',
'assets.duplicateId',
sql<number>`smart_search.embedding <=> ${embedding}`.as('distance'),
])
.innerJoin('smart_search', 'assets.id', 'smart_search.assetId')
.where('assets.ownerId', '=', anyUuid(userIds))
.where('assets.deletedAt', 'is', null)
.where('assets.visibility', '!=', AssetVisibility.HIDDEN)
.where('assets.type', '=', type)
.where('assets.id', '!=', asUuid(assetId))
.where('assets.stackId', 'is', null)
.orderBy('distance')
.limit(64),
)
.selectFrom('cte')
.selectAll()
.where('cte.distance', '<=', maxDistance as number)
.execute();
});
}
@GenerateSql({
@@ -303,31 +305,36 @@ export class SearchRepository {
throw new Error(`Invalid value for 'numResults': ${numResults}`);
}
return this.db
.with('cte', (qb) =>
qb
.selectFrom('asset_faces')
.select([
'asset_faces.id',
'asset_faces.personId',
sql<number>`face_search.embedding <=> ${embedding}`.as('distance'),
])
.innerJoin('assets', 'assets.id', 'asset_faces.assetId')
.innerJoin('face_search', 'face_search.faceId', 'asset_faces.id')
.leftJoin('person', 'person.id', 'asset_faces.personId')
.where('assets.ownerId', '=', anyUuid(userIds))
.where('assets.deletedAt', 'is', null)
.$if(!!hasPerson, (qb) => qb.where('asset_faces.personId', 'is not', null))
.$if(!!minBirthDate, (qb) =>
qb.where((eb) => eb.or([eb('person.birthDate', 'is', null), eb('person.birthDate', '<=', minBirthDate!)])),
)
.orderBy(sql`face_search.embedding <=> ${embedding}`)
.limit(numResults),
)
.selectFrom('cte')
.selectAll()
.where('cte.distance', '<=', maxDistance)
.execute();
return this.db.transaction().execute(async (trx) => {
await sql`set local vchordrq.probes = ${sql.lit(probes[VectorIndex.FACE])}`.execute(trx);
return await trx
.with('cte', (qb) =>
qb
.selectFrom('asset_faces')
.select([
'asset_faces.id',
'asset_faces.personId',
sql<number>`face_search.embedding <=> ${embedding}`.as('distance'),
])
.innerJoin('assets', 'assets.id', 'asset_faces.assetId')
.innerJoin('face_search', 'face_search.faceId', 'asset_faces.id')
.leftJoin('person', 'person.id', 'asset_faces.personId')
.where('assets.ownerId', '=', anyUuid(userIds))
.where('assets.deletedAt', 'is', null)
.$if(!!hasPerson, (qb) => qb.where('asset_faces.personId', 'is not', null))
.$if(!!minBirthDate, (qb) =>
qb.where((eb) =>
eb.or([eb('person.birthDate', 'is', null), eb('person.birthDate', '<=', minBirthDate!)]),
),
)
.orderBy('distance')
.limit(numResults),
)
.selectFrom('cte')
.selectAll()
.where('cte.distance', '<=', maxDistance)
.execute();
});
}
@GenerateSql({ params: [DummyValue.STRING] })
@@ -416,56 +423,6 @@ export class SearchRepository {
.execute();
}
async getDimensionSize(): Promise<number> {
const { rows } = await sql<{ dimsize: number }>`
select atttypmod as dimsize
from pg_attribute f
join pg_class c ON c.oid = f.attrelid
where c.relkind = 'r'::char
and f.attnum > 0
and c.relname = 'smart_search'
and f.attname = 'embedding'
`.execute(this.db);
const dimSize = rows[0]['dimsize'];
if (!isValidInteger(dimSize, { min: 1, max: 2 ** 16 })) {
throw new Error(`Could not retrieve CLIP dimension size`);
}
return dimSize;
}
async setDimensionSize(dimSize: number): Promise<void> {
if (!isValidInteger(dimSize, { min: 1, max: 2 ** 16 })) {
throw new Error(`Invalid CLIP dimension size: ${dimSize}`);
}
// this is done in two transactions to handle concurrent writes
await this.db.transaction().execute(async (trx) => {
await sql`delete from ${sql.table('smart_search')}`.execute(trx);
await trx.schema.alterTable('smart_search').dropConstraint('dim_size_constraint').ifExists().execute();
await sql`alter table ${sql.table('smart_search')} add constraint dim_size_constraint check (array_length(embedding::real[], 1) = ${sql.lit(dimSize)})`.execute(
trx,
);
});
const vectorExtension = this.configRepository.getEnv().database.vectorExtension;
await this.db.transaction().execute(async (trx) => {
await sql`drop index if exists clip_index`.execute(trx);
await trx.schema
.alterTable('smart_search')
.alterColumn('embedding', (col) => col.setDataType(sql.raw(`vector(${dimSize})`)))
.execute();
await sql.raw(vectorIndexQuery({ vectorExtension, table: 'smart_search', indexName: 'clip_index' })).execute(trx);
await trx.schema.alterTable('smart_search').dropConstraint('dim_size_constraint').ifExists().execute();
});
await sql`vacuum analyze ${sql.table('smart_search')}`.execute(this.db);
}
async deleteAllSearchEmbeddings(): Promise<void> {
await sql`truncate ${sql.table('smart_search')}`.execute(this.db);
}
async getCountries(userIds: string[]): Promise<string[]> {
const res = await this.getExifField('country', userIds).execute();
return res.map((row) => row.country!);