import { Injectable } from '@nestjs/common'; import AsyncLock from 'async-lock'; import { FileMigrationProvider, Kysely, Migrator, sql, Transaction } from 'kysely'; import { InjectKysely } from 'nestjs-kysely'; import { readdir } from 'node:fs/promises'; import { join } from 'node:path'; import semver from 'semver'; import { EXTENSION_NAMES, POSTGRES_VERSION_RANGE, VECTOR_EXTENSIONS, VECTOR_INDEX_TABLES, VECTOR_VERSION_RANGE, VECTORCHORD_LIST_SLACK_FACTOR, VECTORCHORD_VERSION_RANGE, VECTORS_VERSION_RANGE, } from 'src/constants'; import { GenerateSql } from 'src/decorators'; import { DatabaseExtension, DatabaseLock, VectorIndex } from 'src/enum'; import { ConfigRepository } from 'src/repositories/config.repository'; import { LoggingRepository } from 'src/repositories/logging.repository'; import { DB } from 'src/schema'; import { ExtensionVersion, VectorExtension, VectorUpdateResult } from 'src/types'; import { vectorIndexQuery } from 'src/utils/database'; import { isValidInteger } from 'src/validation'; export let cachedVectorExtension: VectorExtension | undefined; export async function getVectorExtension(runner: Kysely): Promise { if (cachedVectorExtension) { return cachedVectorExtension; } cachedVectorExtension = new ConfigRepository().getEnv().database.vectorExtension; if (cachedVectorExtension) { return cachedVectorExtension; } const query = `SELECT name FROM pg_available_extensions WHERE name IN (${VECTOR_EXTENSIONS.map((ext) => `'${ext}'`).join(', ')})`; const { rows: availableExtensions } = await sql.raw<{ name: VectorExtension }>(query).execute(runner); const extensionNames = new Set(availableExtensions.map((row) => row.name)); cachedVectorExtension = VECTOR_EXTENSIONS.find((ext) => extensionNames.has(ext)); if (!cachedVectorExtension) { throw new Error(`No vector extension found. Available extensions: ${VECTOR_EXTENSIONS.join(', ')}`); } return cachedVectorExtension; } export const probes: Record = { [VectorIndex.Clip]: 1, [VectorIndex.Face]: 1, }; @Injectable() export class DatabaseRepository { private readonly asyncLock = new AsyncLock(); constructor( @InjectKysely() private db: Kysely, private logger: LoggingRepository, private configRepository: ConfigRepository, ) { this.logger.setContext(DatabaseRepository.name); } async shutdown() { await this.db.destroy(); } getVectorExtension(): Promise { return getVectorExtension(this.db); } @GenerateSql({ params: [[DatabaseExtension.Vectors]] }) async getExtensionVersions(extensions: readonly DatabaseExtension[]): Promise { const { rows } = await sql` SELECT name, default_version as "availableVersion", installed_version as "installedVersion" FROM pg_available_extensions WHERE name in (${sql.join(extensions)}) `.execute(this.db); return rows; } getExtensionVersionRange(extension: VectorExtension): string { switch (extension) { case DatabaseExtension.VectorChord: { return VECTORCHORD_VERSION_RANGE; } case DatabaseExtension.Vectors: { return VECTORS_VERSION_RANGE; } case DatabaseExtension.Vector: { return VECTOR_VERSION_RANGE; } default: { throw new Error(`Unsupported vector extension: '${extension}'`); } } } @GenerateSql() async getPostgresVersion(): Promise { const { rows } = await sql<{ server_version: string }>`SHOW server_version`.execute(this.db); return rows[0].server_version; } getPostgresVersionRange(): string { return POSTGRES_VERSION_RANGE; } async createExtension(extension: DatabaseExtension): Promise { this.logger.log(`Creating ${EXTENSION_NAMES[extension]} extension`); await sql`CREATE EXTENSION IF NOT EXISTS ${sql.raw(extension)} CASCADE`.execute(this.db); if (extension === DatabaseExtension.VectorChord) { const dbName = sql.id(await this.getDatabaseName()); await sql`ALTER DATABASE ${dbName} SET vchordrq.probes = 1`.execute(this.db); await sql`SET vchordrq.probes = 1`.execute(this.db); } } async dropExtension(extension: DatabaseExtension): Promise { this.logger.log(`Dropping ${EXTENSION_NAMES[extension]} extension`); await sql`DROP EXTENSION IF EXISTS ${sql.raw(extension)}`.execute(this.db); } async updateVectorExtension(extension: VectorExtension, targetVersion?: string): Promise { const [{ availableVersion, installedVersion }] = await this.getExtensionVersions([extension]); if (!installedVersion) { throw new Error(`${EXTENSION_NAMES[extension]} extension is not installed`); } if (!availableVersion) { throw new Error(`No available version for ${EXTENSION_NAMES[extension]} extension`); } targetVersion ??= availableVersion; let restartRequired = false; const diff = semver.diff(installedVersion, targetVersion); if (!diff) { return { restartRequired: false }; } await Promise.all([ this.db.schema.dropIndex(VectorIndex.Clip).ifExists().execute(), this.db.schema.dropIndex(VectorIndex.Face).ifExists().execute(), ]); await this.db.transaction().execute(async (tx) => { await this.setSearchPath(tx); await sql`ALTER EXTENSION ${sql.raw(extension)} UPDATE TO ${sql.lit(targetVersion)}`.execute(tx); if (extension === DatabaseExtension.Vectors && (diff === 'major' || diff === 'minor')) { await sql`SELECT pgvectors_upgrade()`.execute(tx); restartRequired = true; } }); if (!restartRequired) { await Promise.all([this.reindexVectors(VectorIndex.Clip), this.reindexVectors(VectorIndex.Face)]); } return { restartRequired }; } async prewarm(index: VectorIndex): Promise { const vectorExtension = await getVectorExtension(this.db); if (vectorExtension !== DatabaseExtension.VectorChord) { return; } this.logger.debug(`Prewarming ${index}`); await sql`SELECT vchordrq_prewarm(${index})`.execute(this.db); } async reindexVectorsIfNeeded(names: VectorIndex[]): Promise { const { rows } = await sql<{ indexdef: string; indexname: string; }>`SELECT indexdef, indexname FROM pg_indexes WHERE indexname = ANY(ARRAY[${sql.join(names)}])`.execute(this.db); const vectorExtension = await getVectorExtension(this.db); const promises = []; for (const indexName of names) { const row = rows.find((index) => index.indexname === indexName); const table = VECTOR_INDEX_TABLES[indexName]; if (!row) { promises.push(this.reindexVectors(indexName)); continue; } switch (vectorExtension) { case DatabaseExtension.Vector: { if (!row.indexdef.toLowerCase().includes('using hnsw')) { promises.push(this.reindexVectors(indexName)); } break; } case DatabaseExtension.Vectors: { if (!row.indexdef.toLowerCase().includes('using vectors')) { promises.push(this.reindexVectors(indexName)); } break; } case DatabaseExtension.VectorChord: { const matches = row.indexdef.match(/(?<=lists = \[)\d+/g); const lists = matches && matches.length > 0 ? Number(matches[0]) : 1; promises.push( this.getRowCount(table).then((count) => { const targetLists = this.targetListCount(count); this.logger.log(`targetLists=${targetLists}, current=${lists} for ${indexName} of ${count} rows`); if ( !row.indexdef.toLowerCase().includes('using vchordrq') || // slack factor is to avoid frequent reindexing if the count is borderline (lists !== targetLists && lists !== this.targetListCount(count * VECTORCHORD_LIST_SLACK_FACTOR)) ) { probes[indexName] = this.targetProbeCount(targetLists); return this.reindexVectors(indexName, { lists: targetLists }); } else { probes[indexName] = this.targetProbeCount(lists); } }), ); break; } } } if (promises.length > 0) { await Promise.all(promises); } } private async reindexVectors(indexName: VectorIndex, { lists }: { lists?: number } = {}): Promise { this.logger.log(`Reindexing ${indexName}`); const table = VECTOR_INDEX_TABLES[indexName]; const vectorExtension = await getVectorExtension(this.db); const { rows } = await sql<{ columnName: string; }>`SELECT column_name as "columnName" FROM information_schema.columns WHERE table_name = ${table}`.execute(this.db); if (rows.length === 0) { this.logger.warn( `Table ${table} does not exist, skipping reindexing. This is only normal if this is a new Immich instance.`, ); return; } const dimSize = await this.getDimensionSize(table); lists ||= this.targetListCount(await this.getRowCount(table)); await this.db.schema.dropIndex(indexName).ifExists().execute(); if (table === 'smart_search') { await this.db.schema.alterTable(table).dropConstraint('dim_size_constraint').ifExists().execute(); } await this.db.transaction().execute(async (tx) => { if (!rows.some((row) => row.columnName === 'embedding')) { this.logger.warn(`Column 'embedding' does not exist in table '${table}', truncating and adding column.`); await sql`TRUNCATE TABLE ${sql.raw(table)}`.execute(tx); await sql`ALTER TABLE ${sql.raw(table)} ADD COLUMN embedding real[] NOT NULL`.execute(tx); } await sql`ALTER TABLE ${sql.raw(table)} ALTER COLUMN embedding SET DATA TYPE real[]`.execute(tx); const schema = vectorExtension === DatabaseExtension.Vectors ? 'vectors.' : ''; await sql` ALTER TABLE ${sql.raw(table)} ALTER COLUMN embedding SET DATA TYPE ${sql.raw(schema)}vector(${sql.raw(String(dimSize))})`.execute(tx); await sql.raw(vectorIndexQuery({ vectorExtension, table, indexName, lists })).execute(tx); }); try { await sql`VACUUM ANALYZE ${sql.raw(table)}`.execute(this.db); } catch (error: any) { this.logger.warn(`Failed to vacuum table '${table}'. The DB will temporarily use more disk space: ${error}`); } this.logger.log(`Reindexed ${indexName}`); } private async setSearchPath(tx: Transaction): Promise { await sql`SET search_path TO "$user", public, vectors`.execute(tx); } private async getDatabaseName(): Promise { const { rows } = await sql<{ db: string }>`SELECT current_database() as db`.execute(this.db); return rows[0].db; } async getDimensionSize(table: string, column = 'embedding'): Promise { const { rows } = await sql<{ dimsize: number }>` SELECT atttypmod as dimsize FROM pg_attribute f JOIN pg_class c ON c.oid = f.attrelid WHERE c.relkind = 'r'::char AND f.attnum > 0 AND c.relname = ${table}::text AND f.attname = ${column}::text `.execute(this.db); const dimSize = rows[0]?.dimsize; if (!isValidInteger(dimSize, { min: 1, max: 2 ** 16 })) { this.logger.warn(`Could not retrieve dimension size of column '${column}' in table '${table}', assuming 512`); return 512; } return dimSize; } async setDimensionSize(dimSize: number): Promise { if (!isValidInteger(dimSize, { min: 1, max: 2 ** 16 })) { throw new Error(`Invalid CLIP dimension size: ${dimSize}`); } // this is done in two transactions to handle concurrent writes await this.db.transaction().execute(async (trx) => { await sql`delete from ${sql.table('smart_search')}`.execute(trx); await trx.schema.alterTable('smart_search').dropConstraint('dim_size_constraint').ifExists().execute(); await sql`alter table ${sql.table('smart_search')} add constraint dim_size_constraint check (array_length(embedding::real[], 1) = ${sql.lit(dimSize)})`.execute( trx, ); }); const vectorExtension = await this.getVectorExtension(); await this.db.transaction().execute(async (trx) => { await sql`drop index if exists clip_index`.execute(trx); await trx.schema .alterTable('smart_search') .alterColumn('embedding', (col) => col.setDataType(sql.raw(`vector(${dimSize})`))) .execute(); await sql .raw(vectorIndexQuery({ vectorExtension, table: 'smart_search', indexName: VectorIndex.Clip })) .execute(trx); await trx.schema.alterTable('smart_search').dropConstraint('dim_size_constraint').ifExists().execute(); }); probes[VectorIndex.Clip] = 1; await sql`vacuum analyze ${sql.table('smart_search')}`.execute(this.db); } async deleteAllSearchEmbeddings(): Promise { await sql`truncate ${sql.table('smart_search')}`.execute(this.db); } private targetListCount(count: number) { if (count < 128_000) { return 1; } else if (count < 2_048_000) { return 1 << (32 - Math.clz32(count / 1000)); } else { return 1 << (33 - Math.clz32(Math.sqrt(count))); } } private targetProbeCount(lists: number) { return Math.ceil(lists / 8); } private async getRowCount(table: keyof DB): Promise { const { count } = await this.db .selectFrom(this.db.dynamic.table(table).as('t')) .select((eb) => eb.fn.countAll().as('count')) .executeTakeFirstOrThrow(); return count; } async runMigrations(): Promise { this.logger.debug('Running migrations'); const migrator = new Migrator({ db: this.db, migrationLockTableName: 'kysely_migrations_lock', allowUnorderedMigrations: this.configRepository.isDev(), migrationTableName: 'kysely_migrations', provider: new FileMigrationProvider({ fs: { readdir }, path: { join }, // eslint-disable-next-line unicorn/prefer-module migrationFolder: join(__dirname, '..', 'schema/migrations'), }), }); const { error, results } = await migrator.migrateToLatest(); for (const result of results ?? []) { if (result.status === 'Success') { this.logger.log(`Migration "${result.migrationName}" succeeded`); } if (result.status === 'Error') { this.logger.warn(`Migration "${result.migrationName}" failed`); } } if (error) { this.logger.error(`Migrations failed: ${error}`); throw error; } this.logger.debug('Finished running migrations'); } async migrateFilePaths(sourceFolder: string, targetFolder: string): Promise { // remove trailing slashes if (sourceFolder.endsWith('/')) { sourceFolder = sourceFolder.slice(0, -1); } if (targetFolder.endsWith('/')) { targetFolder = targetFolder.slice(0, -1); } // escaping regex special characters with a backslash const sourceRegex = '^' + sourceFolder.replaceAll(/[-[\]{}()*+?.,\\^$|#\s]/g, String.raw`\$&`); const source = sql.raw(`'${sourceRegex}'`); const target = sql.lit(targetFolder); await this.db.transaction().execute(async (tx) => { await tx .updateTable('asset') .set((eb) => ({ originalPath: eb.fn('REGEXP_REPLACE', ['originalPath', source, target]), encodedVideoPath: eb.fn('REGEXP_REPLACE', ['encodedVideoPath', source, target]), sidecarPath: eb.fn('REGEXP_REPLACE', ['sidecarPath', source, target]), })) .execute(); await tx .updateTable('asset_file') .set((eb) => ({ path: eb.fn('REGEXP_REPLACE', ['path', source, target]) })) .execute(); await tx .updateTable('person') .set((eb) => ({ thumbnailPath: eb.fn('REGEXP_REPLACE', ['thumbnailPath', source, target]) })) .execute(); await tx .updateTable('user') .set((eb) => ({ profileImagePath: eb.fn('REGEXP_REPLACE', ['profileImagePath', source, target]) })) .execute(); }); } async withLock(lock: DatabaseLock, callback: () => Promise): Promise { let res; await this.asyncLock.acquire(DatabaseLock[lock], async () => { await this.db.connection().execute(async (connection) => { try { await this.acquireLock(lock, connection); res = await callback(); } finally { await this.releaseLock(lock, connection); } }); }); return res as R; } tryLock(lock: DatabaseLock): Promise { return this.db.connection().execute(async (connection) => this.acquireTryLock(lock, connection)); } isBusy(lock: DatabaseLock): boolean { return this.asyncLock.isBusy(DatabaseLock[lock]); } async wait(lock: DatabaseLock): Promise { await this.asyncLock.acquire(DatabaseLock[lock], () => {}); } private async acquireLock(lock: DatabaseLock, connection: Kysely): Promise { await sql`SELECT pg_advisory_lock(${lock})`.execute(connection); } private async acquireTryLock(lock: DatabaseLock, connection: Kysely): Promise { const { rows } = await sql<{ pg_try_advisory_lock: boolean; }>`SELECT pg_try_advisory_lock(${lock})`.execute(connection); return rows[0].pg_try_advisory_lock; } private async releaseLock(lock: DatabaseLock, connection: Kysely): Promise { await sql`SELECT pg_advisory_unlock(${lock})`.execute(connection); } }