refactor(server): use kysely (#12857)

This commit is contained in:
Mert
2025-01-09 11:15:41 -05:00
committed by GitHub
parent 1489d69f81
commit 2e12c46980
59 changed files with 2891 additions and 3289 deletions

View File

@@ -153,7 +153,6 @@ export class AlbumRepository implements IAlbumRepository {
await this.repository.delete({ ownerId: userId });
}
@GenerateSql({ params: [DummyValue.UUID] })
async removeAsset(assetId: string): Promise<void> {
// Using dataSource, because there is no direct access to albums_assets_assets.
await this.dataSource
@@ -164,7 +163,6 @@ export class AlbumRepository implements IAlbumRepository {
.execute();
}
@GenerateSql({ params: [DummyValue.UUID, [DummyValue.UUID]] })
@Chunked({ paramIndex: 1 })
async removeAssetIds(albumId: string, assetIds: string[]): Promise<void> {
if (assetIds.length === 0) {
@@ -207,7 +205,6 @@ export class AlbumRepository implements IAlbumRepository {
return new Set(results.map(({ assetId }) => assetId));
}
@GenerateSql({ params: [DummyValue.UUID, [DummyValue.UUID]] })
async addAssetIds(albumId: string, assetIds: string[]): Promise<void> {
await this.addAssets(this.dataSource.manager, albumId, assetIds);
}
@@ -272,7 +269,6 @@ export class AlbumRepository implements IAlbumRepository {
*
* @returns Amount of updated album thumbnails or undefined when unknown
*/
@GenerateSql()
async updateThumbnails(): Promise<number | undefined> {
// Subquery for getting a new thumbnail.

File diff suppressed because it is too large Load Diff

View File

@@ -1,3 +1,4 @@
import { PostgresJSDialect } from 'kysely-postgres-js';
import { ImmichTelemetry } from 'src/enum';
import { clearEnvCache, ConfigRepository } from 'src/repositories/config.repository';
@@ -79,14 +80,20 @@ describe('getEnv', () => {
it('should use defaults', () => {
const { database } = getEnv();
expect(database).toEqual({
config: expect.objectContaining({
type: 'postgres',
host: 'database',
port: 5432,
database: 'immich',
username: 'postgres',
password: 'postgres',
}),
config: {
kysely: {
dialect: expect.any(PostgresJSDialect),
log: ['error'],
},
typeorm: expect.objectContaining({
type: 'postgres',
host: 'database',
port: 5432,
database: 'immich',
username: 'postgres',
password: 'postgres',
}),
},
skipMigrations: false,
vectorExtension: 'vectors',
});

View File

@@ -2,8 +2,10 @@ import { Inject, Injectable, Optional } from '@nestjs/common';
import { plainToInstance } from 'class-transformer';
import { validateSync } from 'class-validator';
import { Request, Response } from 'express';
import { PostgresJSDialect } from 'kysely-postgres-js';
import { CLS_ID } from 'nestjs-cls';
import { join, resolve } from 'node:path';
import postgres from 'postgres';
import { citiesFile, excludePaths, IWorker } from 'src/constants';
import { Telemetry } from 'src/decorators';
import { EnvDto } from 'src/dtos/env.dto';
@@ -96,6 +98,33 @@ const getEnv = (): EnvData => {
}
}
const driverOptions = {
max: 10,
types: {
date: {
to: 1184,
from: [1082, 1114, 1184],
serialize: (x: Date | string) => (x instanceof Date ? x.toISOString() : x),
parse: (x: string) => new Date(x),
},
bigint: {
to: 20,
from: [20],
parse: (value: string) => Number.parseInt(value),
serialize: (value: number) => value.toString(),
},
},
};
const parts = {
connectionType: 'parts',
host: dto.DB_HOSTNAME || 'database',
port: dto.DB_PORT || 5432,
username: dto.DB_USERNAME || 'postgres',
password: dto.DB_PASSWORD || 'postgres',
database: dto.DB_DATABASE_NAME || 'immich',
} as const;
return {
host: dto.IMMICH_HOST,
port: dto.IMMICH_PORT || 2283,
@@ -150,24 +179,23 @@ const getEnv = (): EnvData => {
database: {
config: {
type: 'postgres',
entities: [`${folders.dist}/entities` + '/*.entity.{js,ts}'],
migrations: [`${folders.dist}/migrations` + '/*.{js,ts}'],
subscribers: [`${folders.dist}/subscribers` + '/*.{js,ts}'],
migrationsRun: false,
synchronize: false,
connectTimeoutMS: 10_000, // 10 seconds
parseInt8: true,
...(databaseUrl
? { connectionType: 'url', url: databaseUrl }
: {
connectionType: 'parts',
host: dto.DB_HOSTNAME || 'database',
port: dto.DB_PORT || 5432,
username: dto.DB_USERNAME || 'postgres',
password: dto.DB_PASSWORD || 'postgres',
database: dto.DB_DATABASE_NAME || 'immich',
}),
typeorm: {
type: 'postgres',
entities: [`${folders.dist}/entities` + '/*.entity.{js,ts}'],
migrations: [`${folders.dist}/migrations` + '/*.{js,ts}'],
subscribers: [`${folders.dist}/subscribers` + '/*.{js,ts}'],
migrationsRun: false,
synchronize: false,
connectTimeoutMS: 10_000, // 10 seconds
parseInt8: true,
...(databaseUrl ? { connectionType: 'url', url: databaseUrl } : parts),
},
kysely: {
dialect: new PostgresJSDialect({
postgres: databaseUrl ? postgres(databaseUrl, driverOptions) : postgres({ ...parts, ...driverOptions }),
}),
log: ['error'] as const,
},
},
skipMigrations: dto.DB_SKIP_MIGRATIONS ?? false,

View File

@@ -1,8 +1,10 @@
import { Inject, Injectable } from '@nestjs/common';
import { InjectDataSource } from '@nestjs/typeorm';
import AsyncLock from 'async-lock';
import { sql } from 'kysely';
import semver from 'semver';
import { POSTGRES_VERSION_RANGE, VECTOR_VERSION_RANGE, VECTORS_VERSION_RANGE } from 'src/constants';
import { DB } from 'src/db';
import { IConfigRepository } from 'src/interfaces/config.interface';
import {
DatabaseExtension,
@@ -15,13 +17,14 @@ import {
VectorUpdateResult,
} from 'src/interfaces/database.interface';
import { ILoggerRepository } from 'src/interfaces/logger.interface';
import { UPSERT_COLUMNS } from 'src/utils/database';
import { isValidInteger } from 'src/validation';
import { DataSource, EntityManager, QueryRunner } from 'typeorm';
import { DataSource, EntityManager, EntityMetadata, QueryRunner } from 'typeorm';
@Injectable()
export class DatabaseRepository implements IDatabaseRepository {
private vectorExtension: VectorExtension;
readonly asyncLock = new AsyncLock();
private readonly asyncLock = new AsyncLock();
constructor(
@InjectDataSource() private dataSource: DataSource,
@@ -32,6 +35,13 @@ export class DatabaseRepository implements IDatabaseRepository {
this.logger.setContext(DatabaseRepository.name);
}
init() {
for (const metadata of this.dataSource.entityMetadatas) {
const table = metadata.tableName as keyof DB;
UPSERT_COLUMNS[table] = this.getUpsertColumns(metadata);
}
}
async reconnect() {
try {
if (this.dataSource.isInitialized) {
@@ -249,4 +259,10 @@ export class DatabaseRepository implements IDatabaseRepository {
private async releaseLock(lock: DatabaseLock, queryRunner: QueryRunner): Promise<void> {
return queryRunner.query('SELECT pg_advisory_unlock($1)', [lock]);
}
private getUpsertColumns(metadata: EntityMetadata) {
return Object.fromEntries(
metadata.ownColumns.map((column) => [column.propertyName, sql<string>`excluded.${sql.ref(column.propertyName)}`]),
) as any;
}
}

View File

@@ -64,7 +64,6 @@ export class MemoryRepository implements IMemoryRepository {
return new Set(results.map(({ assetId }) => assetId));
}
@GenerateSql({ params: [DummyValue.UUID, [DummyValue.UUID]] })
async addAssetIds(id: string, assetIds: string[]): Promise<void> {
await this.dataSource
.createQueryBuilder()
@@ -74,7 +73,6 @@ export class MemoryRepository implements IMemoryRepository {
.execute();
}
@GenerateSql({ params: [DummyValue.UUID, [DummyValue.UUID]] })
@Chunked({ paramIndex: 1 })
async removeAssetIds(id: string, assetIds: string[]): Promise<void> {
await this.dataSource

View File

@@ -1,22 +1,17 @@
import { Inject, Injectable } from '@nestjs/common';
import { InjectRepository } from '@nestjs/typeorm';
import { Kysely, OrderByDirectionExpression, sql } from 'kysely';
import { InjectKysely } from 'nestjs-kysely';
import { randomUUID } from 'node:crypto';
import { DB } from 'src/db';
import { DummyValue, GenerateSql } from 'src/decorators';
import { AssetFaceEntity } from 'src/entities/asset-face.entity';
import { AssetEntity } from 'src/entities/asset.entity';
import { ExifEntity } from 'src/entities/exif.entity';
import { AssetEntity, searchAssetBuilder } from 'src/entities/asset.entity';
import { GeodataPlacesEntity } from 'src/entities/geodata-places.entity';
import { SmartSearchEntity } from 'src/entities/smart-search.entity';
import { AssetType, PaginationMode } from 'src/enum';
import { IConfigRepository } from 'src/interfaces/config.interface';
import { DatabaseExtension, VectorExtension } from 'src/interfaces/database.interface';
import { AssetType } from 'src/enum';
import { ILoggerRepository } from 'src/interfaces/logger.interface';
import {
AssetDuplicateResult,
AssetDuplicateSearch,
AssetSearchOptions,
FaceEmbeddingSearch,
FaceSearchResult,
GetCameraMakesOptions,
GetCameraModelsOptions,
GetCitiesOptions,
@@ -25,40 +20,17 @@ import {
SearchPaginationOptions,
SmartSearchOptions,
} from 'src/interfaces/search.interface';
import { asVector, searchAssetBuilder } from 'src/utils/database';
import { Paginated, PaginationResult, paginatedBuilder } from 'src/utils/pagination';
import { anyUuid, asUuid, asVector } from 'src/utils/database';
import { Paginated } from 'src/utils/pagination';
import { isValidInteger } from 'src/validation';
import { Repository } from 'typeorm';
@Injectable()
export class SearchRepository implements ISearchRepository {
private vectorExtension: VectorExtension;
private faceColumns: string[];
private assetsByCityQuery: string;
constructor(
@InjectRepository(AssetEntity) private assetRepository: Repository<AssetEntity>,
@InjectRepository(ExifEntity) private exifRepository: Repository<ExifEntity>,
@InjectRepository(AssetFaceEntity) private assetFaceRepository: Repository<AssetFaceEntity>,
@InjectRepository(SmartSearchEntity) private smartSearchRepository: Repository<SmartSearchEntity>,
@InjectRepository(GeodataPlacesEntity) private geodataPlacesRepository: Repository<GeodataPlacesEntity>,
@Inject(ILoggerRepository) private logger: ILoggerRepository,
@Inject(IConfigRepository) configRepository: IConfigRepository,
@InjectKysely() private db: Kysely<DB>,
) {
this.vectorExtension = configRepository.getEnv().database.vectorExtension;
this.logger.setContext(SearchRepository.name);
this.faceColumns = this.assetFaceRepository.manager.connection
.getMetadata(AssetFaceEntity)
.ownColumns.map((column) => column.propertyName)
.filter((propertyName) => propertyName !== 'embedding');
this.assetsByCityQuery =
assetsByCityCte +
this.assetRepository
.createQueryBuilder('asset')
.innerJoinAndSelect('asset.exifInfo', 'exif')
.withDeleted()
.getQuery() +
' INNER JOIN cte ON asset.id = cte."assetId" ORDER BY exif.city';
}
@GenerateSql({
@@ -74,14 +46,15 @@ export class SearchRepository implements ISearchRepository {
],
})
async searchMetadata(pagination: SearchPaginationOptions, options: AssetSearchOptions): Paginated<AssetEntity> {
let builder = this.assetRepository.createQueryBuilder('asset');
builder = searchAssetBuilder(builder, options).orderBy('asset.fileCreatedAt', options.orderDirection ?? 'DESC');
return paginatedBuilder<AssetEntity>(builder, {
mode: PaginationMode.SKIP_TAKE,
skip: (pagination.page - 1) * pagination.size,
take: pagination.size,
});
const orderDirection = (options.orderDirection?.toLowerCase() || 'desc') as OrderByDirectionExpression;
const items = await searchAssetBuilder(this.db, options)
.orderBy('assets.fileCreatedAt', orderDirection)
.limit(pagination.size + 1)
.offset((pagination.page - 1) * pagination.size)
.execute();
const hasNextPage = items.length > pagination.size;
items.splice(pagination.size);
return { items: items as any as AssetEntity[], hasNextPage };
}
@GenerateSql({
@@ -96,21 +69,12 @@ export class SearchRepository implements ISearchRepository {
},
],
})
async searchRandom(size: number, options: AssetSearchOptions): Promise<AssetEntity[]> {
const builder1 = searchAssetBuilder(this.assetRepository.createQueryBuilder('asset'), options);
const builder2 = builder1.clone();
searchRandom(size: number, options: AssetSearchOptions): Promise<AssetEntity[]> {
const uuid = randomUUID();
builder1.andWhere('asset.id > :uuid', { uuid }).orderBy('asset.id').take(size);
builder2.andWhere('asset.id < :uuid', { uuid }).orderBy('asset.id').take(size);
const [assets1, assets2] = await Promise.all([builder1.getMany(), builder2.getMany()]);
const missingCount = size - assets1.length;
for (let i = 0; i < missingCount && i < assets2.length; i++) {
assets1.push(assets2[i]);
}
return assets1;
const builder = searchAssetBuilder(this.db, options);
const lessThan = builder.where('assets.id', '<', uuid).orderBy('assets.id').limit(size);
const greaterThan = builder.where('assets.id', '>', uuid).orderBy('assets.id').limit(size);
return sql`${lessThan} union all ${greaterThan}`.execute(this.db) as any as Promise<AssetEntity[]>;
}
@GenerateSql({
@@ -126,76 +90,58 @@ export class SearchRepository implements ISearchRepository {
},
],
})
async searchSmart(
pagination: SearchPaginationOptions,
{ embedding, userIds, ...options }: SmartSearchOptions,
): Paginated<AssetEntity> {
let results: PaginationResult<AssetEntity> = { items: [], hasNextPage: false };
async searchSmart(pagination: SearchPaginationOptions, options: SmartSearchOptions): Paginated<AssetEntity> {
if (!isValidInteger(pagination.size, { min: 1, max: 1000 })) {
throw new Error(`Invalid value for 'size': ${pagination.size}`);
}
await this.assetRepository.manager.transaction(async (manager) => {
let builder = manager.createQueryBuilder(AssetEntity, 'asset');
builder = searchAssetBuilder(builder, options);
builder
.innerJoin('asset.smartSearch', 'search')
.andWhere('asset.ownerId IN (:...userIds )')
.orderBy('search.embedding <=> :embedding')
.setParameters({ userIds, embedding: asVector(embedding) });
const items = (await searchAssetBuilder(this.db, options)
.innerJoin('smart_search', 'assets.id', 'smart_search.assetId')
.orderBy(sql`smart_search.embedding <=> ${asVector(options.embedding)}`)
.limit(pagination.size + 1)
.offset((pagination.page - 1) * pagination.size)
.execute()) as any as AssetEntity[];
const runtimeConfig = this.getRuntimeConfig(pagination.size);
if (runtimeConfig) {
await manager.query(runtimeConfig);
}
results = await paginatedBuilder<AssetEntity>(builder, {
mode: PaginationMode.LIMIT_OFFSET,
skip: (pagination.page - 1) * pagination.size,
take: pagination.size,
});
});
return results;
const hasNextPage = items.length > pagination.size;
items.splice(pagination.size);
return { items, hasNextPage };
}
@GenerateSql({
params: [
{
assetId: DummyValue.UUID,
embedding: Array.from({ length: 512 }, Math.random),
maxDistance: 0.6,
type: AssetType.IMAGE,
userIds: [DummyValue.UUID],
},
],
})
searchDuplicates({
assetId,
embedding,
maxDistance,
type,
userIds,
}: AssetDuplicateSearch): Promise<AssetDuplicateResult[]> {
const cte = this.assetRepository.createQueryBuilder('asset');
cte
.select('search.assetId', 'assetId')
.addSelect('asset.duplicateId', 'duplicateId')
.addSelect(`search.embedding <=> :embedding`, 'distance')
.innerJoin('asset.smartSearch', 'search')
.where('asset.ownerId IN (:...userIds )')
.andWhere('asset.id != :assetId')
.andWhere('asset.isVisible = :isVisible')
.andWhere('asset.type = :type')
.orderBy('search.embedding <=> :embedding')
.limit(64)
.setParameters({ assetId, embedding: asVector(embedding), isVisible: true, type, userIds });
const builder = this.assetRepository.manager
.createQueryBuilder()
.addCommonTableExpression(cte, 'cte')
.from('cte', 'res')
.select('res.*');
if (maxDistance) {
builder.where('res.distance <= :maxDistance', { maxDistance });
}
return builder.getRawMany() as Promise<AssetDuplicateResult[]>;
searchDuplicates({ assetId, embedding, maxDistance, type, userIds }: AssetDuplicateSearch) {
const vector = asVector(embedding);
return this.db
.with('cte', (qb) =>
qb
.selectFrom('assets')
.select([
'assets.id as assetId',
'assets.duplicateId',
sql<number>`smart_search.embedding <=> ${vector}`.as('distance'),
])
.innerJoin('smart_search', 'assets.id', 'smart_search.assetId')
.where('assets.ownerId', '=', anyUuid(userIds))
.where('assets.deletedAt', 'is', null)
.where('assets.isVisible', '=', true)
.where('assets.type', '=', type)
.where('assets.id', '!=', asUuid(assetId))
.orderBy(sql`smart_search.embedding <=> ${vector}`)
.limit(64),
)
.selectFrom('cte')
.selectAll()
.where('cte.distance', '<=', maxDistance as number)
.execute();
}
@GenerateSql({
@@ -208,120 +154,131 @@ export class SearchRepository implements ISearchRepository {
},
],
})
async searchFaces({
userIds,
embedding,
numResults,
maxDistance,
hasPerson,
}: FaceEmbeddingSearch): Promise<FaceSearchResult[]> {
if (!isValidInteger(numResults, { min: 1 })) {
searchFaces({ userIds, embedding, numResults, maxDistance, hasPerson }: FaceEmbeddingSearch) {
if (!isValidInteger(numResults, { min: 1, max: 1000 })) {
throw new Error(`Invalid value for 'numResults': ${numResults}`);
}
// setting this too low messes with prefilter recall
numResults = Math.max(numResults, 64);
let results: Array<AssetFaceEntity & { distance: number }> = [];
await this.assetRepository.manager.transaction(async (manager) => {
const cte = manager
.createQueryBuilder(AssetFaceEntity, 'faces')
.select('search.embedding <=> :embedding', 'distance')
.innerJoin('faces.asset', 'asset')
.innerJoin('faces.faceSearch', 'search')
.where('asset.ownerId IN (:...userIds )')
.orderBy('search.embedding <=> :embedding')
.setParameters({ userIds, embedding: asVector(embedding) });
cte.limit(numResults);
if (hasPerson) {
cte.andWhere('faces."personId" IS NOT NULL');
}
for (const col of this.faceColumns) {
cte.addSelect(`faces.${col}`, col);
}
const runtimeConfig = this.getRuntimeConfig(numResults);
if (runtimeConfig) {
await manager.query(runtimeConfig);
}
results = await manager
.createQueryBuilder()
.select('res.*')
.addCommonTableExpression(cte, 'cte')
.from('cte', 'res')
.where('res.distance <= :maxDistance', { maxDistance })
.orderBy('res.distance')
.getRawMany();
});
return results.map((row) => ({
face: this.assetFaceRepository.create(row),
distance: row.distance,
}));
const vector = asVector(embedding);
return this.db
.with('cte', (qb) =>
qb
.selectFrom('asset_faces')
.select([
'asset_faces.id',
'asset_faces.personId',
sql<number>`face_search.embedding <=> ${vector}`.as('distance'),
])
.innerJoin('assets', 'assets.id', 'asset_faces.assetId')
.innerJoin('face_search', 'face_search.faceId', 'asset_faces.id')
.where('assets.ownerId', '=', anyUuid(userIds))
.where('assets.deletedAt', 'is', null)
.$if(!!hasPerson, (qb) => qb.where('asset_faces.personId', 'is not', null))
.orderBy(sql`face_search.embedding <=> ${vector}`)
.limit(numResults),
)
.selectFrom('cte')
.selectAll()
.where('cte.distance', '<=', maxDistance)
.execute();
}
@GenerateSql({ params: [DummyValue.STRING] })
async searchPlaces(placeName: string): Promise<GeodataPlacesEntity[]> {
return await this.geodataPlacesRepository
.createQueryBuilder('geoplaces')
.where(`f_unaccent(name) %>> f_unaccent(:placeName)`)
.orWhere(`f_unaccent("admin2Name") %>> f_unaccent(:placeName)`)
.orWhere(`f_unaccent("admin1Name") %>> f_unaccent(:placeName)`)
.orWhere(`f_unaccent("alternateNames") %>> f_unaccent(:placeName)`)
searchPlaces(placeName: string): Promise<GeodataPlacesEntity[]> {
return this.db
.selectFrom('geodata_places')
.selectAll()
.where(
() =>
// kysely doesn't support trigram %>> or <->>> operators
sql`
f_unaccent(name) %>> f_unaccent(${placeName}) or
f_unaccent("admin2Name") %>> f_unaccent(${placeName}) or
f_unaccent("admin1Name") %>> f_unaccent(${placeName}) or
f_unaccent("alternateNames") %>> f_unaccent(${placeName})
`,
)
.orderBy(
`
COALESCE(f_unaccent(name) <->>> f_unaccent(:placeName), 0.1) +
COALESCE(f_unaccent("admin2Name") <->>> f_unaccent(:placeName), 0.1) +
COALESCE(f_unaccent("admin1Name") <->>> f_unaccent(:placeName), 0.1) +
COALESCE(f_unaccent("alternateNames") <->>> f_unaccent(:placeName), 0.1)
sql`
coalesce(f_unaccent(name) <->>> f_unaccent(${placeName}), 0.1) +
coalesce(f_unaccent("admin2Name") <->>> f_unaccent(${placeName}), 0.1) +
coalesce(f_unaccent("admin1Name") <->>> f_unaccent(${placeName}), 0.1) +
coalesce(f_unaccent("alternateNames") <->>> f_unaccent(${placeName}), 0.1)
`,
)
.setParameters({ placeName })
.limit(20)
.getMany();
.execute() as Promise<GeodataPlacesEntity[]>;
}
@GenerateSql({ params: [[DummyValue.UUID]] })
async getAssetsByCity(userIds: string[]): Promise<AssetEntity[]> {
const parameters = [userIds, true, false, AssetType.IMAGE];
const rawRes = await this.assetRepository.query(this.assetsByCityQuery, parameters);
getAssetsByCity(userIds: string[]): Promise<AssetEntity[]> {
return this.db
.withRecursive('cte', (qb) => {
const base = qb
.selectFrom('exif')
.select(['city', 'assetId'])
.innerJoin('assets', 'assets.id', 'exif.assetId')
.where('assets.ownerId', '=', anyUuid(userIds))
.where('assets.isVisible', '=', true)
.where('assets.isArchived', '=', false)
.where('assets.type', '=', 'IMAGE')
.where('assets.deletedAt', 'is', null)
.orderBy('city')
.limit(1);
const items: AssetEntity[] = [];
for (const res of rawRes) {
const item = { exifInfo: {} as Record<string, any> } as Record<string, any>;
for (const [key, value] of Object.entries(res)) {
if (key.startsWith('exif_')) {
item.exifInfo[key.replace('exif_', '')] = value;
} else {
item[key.replace('asset_', '')] = value;
}
}
items.push(item as AssetEntity);
}
const recursive = qb
.selectFrom('cte')
.select(['l.city', 'l.assetId'])
.innerJoinLateral(
(qb) =>
qb
.selectFrom('exif')
.select(['city', 'assetId'])
.innerJoin('assets', 'assets.id', 'exif.assetId')
.where('assets.ownerId', '=', anyUuid(userIds))
.where('assets.isVisible', '=', true)
.where('assets.isArchived', '=', false)
.where('assets.type', '=', 'IMAGE')
.where('assets.deletedAt', 'is', null)
.whereRef('exif.city', '>', 'cte.city')
.orderBy('city')
.limit(1)
.as('l'),
(join) => join.onTrue(),
);
return items;
return sql<{ city: string; assetId: string }>`(${base} union all ${recursive})`;
})
.selectFrom('assets')
.innerJoin('exif', 'assets.id', 'exif.assetId')
.innerJoin('cte', 'assets.id', 'cte.assetId')
.selectAll('assets')
.select((eb) => eb.fn('to_jsonb', [eb.table('exif')]).as('exifInfo'))
.orderBy('exif.city')
.execute() as any as Promise<AssetEntity[]>;
}
async upsert(assetId: string, embedding: number[]): Promise<void> {
await this.smartSearchRepository.upsert(
{ assetId, embedding: () => asVector(embedding, true) },
{ conflictPaths: ['assetId'] },
);
const vector = asVector(embedding);
await this.db
.insertInto('smart_search')
.values({ assetId: asUuid(assetId), embedding: vector } as any)
.onConflict((oc) => oc.column('assetId').doUpdateSet({ embedding: vector } as any))
.execute();
}
async getDimensionSize(): Promise<number> {
const res = await this.smartSearchRepository.manager.query(`
SELECT atttypmod as dimsize
FROM pg_attribute f
JOIN pg_class c ON c.oid = f.attrelid
WHERE c.relkind = 'r'::char
AND f.attnum > 0
AND c.relname = 'smart_search'
AND f.attname = 'embedding'`);
const { rows } = await sql<{ dimsize: number }>`
select atttypmod as dimsize
from pg_attribute f
join pg_class c ON c.oid = f.attrelid
where c.relkind = 'r'::char
and f.attnum > 0
and c.relname = 'smart_search'
and f.attname = 'embedding'
`.execute(this.db);
const dimSize = res[0]['dimsize'];
const dimSize = rows[0]['dimsize'];
if (!isValidInteger(dimSize, { min: 1, max: 2 ** 16 })) {
throw new Error(`Could not retrieve CLIP dimension size`);
}
@@ -333,146 +290,71 @@ export class SearchRepository implements ISearchRepository {
throw new Error(`Invalid CLIP dimension size: ${dimSize}`);
}
return this.smartSearchRepository.manager.transaction(async (manager) => {
await manager.clear(SmartSearchEntity);
await manager.query(`ALTER TABLE smart_search ALTER COLUMN embedding SET DATA TYPE vector(${dimSize})`);
await manager.query(`REINDEX INDEX clip_index`);
return this.db.transaction().execute(async (trx) => {
await sql`truncate ${sql.table('smart_search')}`.execute(trx);
await trx.schema
.alterTable('smart_search')
.alterColumn('embedding', (col) => col.setDataType(sql.lit(`vector(${dimSize})`)))
.execute();
await sql`reindex index clip_index`.execute(trx);
});
}
async deleteAllSearchEmbeddings(): Promise<void> {
return this.smartSearchRepository.clear();
await sql`truncate ${sql.table('smart_search')}`.execute(this.db);
}
@GenerateSql({ params: [[DummyValue.UUID]] })
async getCountries(userIds: string[]): Promise<string[]> {
const query = this.exifRepository
.createQueryBuilder('exif')
.innerJoin('exif.asset', 'asset')
.where('asset.ownerId IN (:...userIds )', { userIds })
.andWhere(`exif.country != ''`)
.andWhere('exif.country IS NOT NULL')
.select('exif.country', 'country')
.distinctOn(['exif.country']);
const results = await query.getRawMany<{ country: string }>();
return results.map(({ country }) => country);
const res = await this.getExifField('country', userIds).execute();
return res.map((row) => row.country!);
}
@GenerateSql({ params: [[DummyValue.UUID], DummyValue.STRING] })
async getStates(userIds: string[], { country }: GetStatesOptions): Promise<string[]> {
const query = this.exifRepository
.createQueryBuilder('exif')
.innerJoin('exif.asset', 'asset')
.where('asset.ownerId IN (:...userIds )', { userIds })
.andWhere(`exif.state != ''`)
.andWhere('exif.state IS NOT NULL')
.select('exif.state', 'state')
.distinctOn(['exif.state']);
const res = await this.getExifField('state', userIds)
.$if(!!country, (qb) => qb.where('country', '=', country!))
.execute();
if (country) {
query.andWhere('exif.country = :country', { country });
}
const result = await query.getRawMany<{ state: string }>();
return result.map(({ state }) => state);
return res.map((row) => row.state!);
}
@GenerateSql({ params: [[DummyValue.UUID], DummyValue.STRING, DummyValue.STRING] })
async getCities(userIds: string[], { country, state }: GetCitiesOptions): Promise<string[]> {
const query = this.exifRepository
.createQueryBuilder('exif')
.innerJoin('exif.asset', 'asset')
.where('asset.ownerId IN (:...userIds )', { userIds })
.andWhere(`exif.city != ''`)
.andWhere('exif.city IS NOT NULL')
.select('exif.city', 'city')
.distinctOn(['exif.city']);
const res = await this.getExifField('city', userIds)
.$if(!!country, (qb) => qb.where('country', '=', country!))
.$if(!!state, (qb) => qb.where('state', '=', state!))
.execute();
if (country) {
query.andWhere('exif.country = :country', { country });
}
if (state) {
query.andWhere('exif.state = :state', { state });
}
const results = await query.getRawMany<{ city: string }>();
return results.map(({ city }) => city);
return res.map((row) => row.city!);
}
@GenerateSql({ params: [[DummyValue.UUID], DummyValue.STRING] })
async getCameraMakes(userIds: string[], { model }: GetCameraMakesOptions): Promise<string[]> {
const query = this.exifRepository
.createQueryBuilder('exif')
.innerJoin('exif.asset', 'asset')
.where('asset.ownerId IN (:...userIds )', { userIds })
.andWhere(`exif.make != ''`)
.andWhere('exif.make IS NOT NULL')
.select('exif.make', 'make')
.distinctOn(['exif.make']);
const res = await this.getExifField('make', userIds)
.$if(!!model, (qb) => qb.where('model', '=', model!))
.execute();
if (model) {
query.andWhere('exif.model = :model', { model });
}
const results = await query.getRawMany<{ make: string }>();
return results.map(({ make }) => make);
return res.map((row) => row.make!);
}
@GenerateSql({ params: [[DummyValue.UUID], DummyValue.STRING] })
async getCameraModels(userIds: string[], { make }: GetCameraModelsOptions): Promise<string[]> {
const query = this.exifRepository
.createQueryBuilder('exif')
.innerJoin('exif.asset', 'asset')
.where('asset.ownerId IN (:...userIds )', { userIds })
.andWhere(`exif.model != ''`)
.andWhere('exif.model IS NOT NULL')
.select('exif.model', 'model')
.distinctOn(['exif.model']);
const res = await this.getExifField('model', userIds)
.$if(!!make, (qb) => qb.where('make', '=', make!))
.execute();
if (make) {
query.andWhere('exif.make = :make', { make });
}
const results = await query.getRawMany<{ model: string }>();
return results.map(({ model }) => model);
return res.map((row) => row.model!);
}
private getRuntimeConfig(numResults?: number): string | undefined {
if (this.vectorExtension === DatabaseExtension.VECTOR) {
return 'SET LOCAL hnsw.ef_search = 1000;'; // mitigate post-filter recall
}
if (numResults && numResults !== 100) {
return `SET LOCAL vectors.hnsw_ef_search = ${Math.max(numResults, 100)};`;
}
private getExifField<K extends 'city' | 'state' | 'country' | 'make' | 'model'>(field: K, userIds: string[]) {
return this.db
.selectFrom('exif')
.select(field)
.distinctOn(field)
.innerJoin('assets', 'assets.id', 'exif.assetId')
.where('ownerId', '=', anyUuid(userIds))
.where('isVisible', '=', true)
.where('deletedAt', 'is', null)
.where(field, 'is not', null);
}
}
// the performance difference between this and the normal way is too huge to ignore, e.g. 3s vs 4ms
const assetsByCityCte = `
WITH RECURSIVE cte AS (
(
SELECT city, "assetId"
FROM exif
INNER JOIN assets ON exif."assetId" = assets.id
WHERE "ownerId" = ANY($1::uuid[]) AND "isVisible" = $2 AND "isArchived" = $3 AND type = $4
ORDER BY city
LIMIT 1
)
UNION ALL
SELECT l.city, l."assetId"
FROM cte c
, LATERAL (
SELECT city, "assetId"
FROM exif
INNER JOIN assets ON exif."assetId" = assets.id
WHERE city > c.city AND "ownerId" = ANY($1::uuid[]) AND "isVisible" = $2 AND "isArchived" = $3 AND type = $4
ORDER BY city
LIMIT 1
) l
)
`;

View File

@@ -108,7 +108,6 @@ export class TagRepository implements ITagRepository {
return new Set(results.map(({ assetId }) => assetId));
}
@GenerateSql({ params: [DummyValue.UUID, [DummyValue.UUID]] })
async addAssetIds(tagId: string, assetIds: string[]): Promise<void> {
if (assetIds.length === 0) {
return;
@@ -122,7 +121,6 @@ export class TagRepository implements ITagRepository {
.execute();
}
@GenerateSql({ params: [DummyValue.UUID, [DummyValue.UUID]] })
@Chunked({ paramIndex: 1 })
async removeAssetIds(tagId: string, assetIds: string[]): Promise<void> {
if (assetIds.length === 0) {
@@ -140,7 +138,6 @@ export class TagRepository implements ITagRepository {
.execute();
}
@GenerateSql({ params: [[{ assetId: DummyValue.UUID, tagId: DummyValue.UUID }]] })
@Chunked()
async upsertAssetIds(items: AssetTagItem[]): Promise<AssetTagItem[]> {
if (items.length === 0) {

View File

@@ -1,48 +1,47 @@
import { InjectRepository } from '@nestjs/typeorm';
import { Kysely } from 'kysely';
import { InjectKysely } from 'nestjs-kysely';
import { DB } from 'src/db';
import { DummyValue, GenerateSql } from 'src/decorators';
import { AssetEntity } from 'src/entities/asset.entity';
import { AssetEntity, withExif } from 'src/entities/asset.entity';
import { IViewRepository } from 'src/interfaces/view.interface';
import { Brackets, Repository } from 'typeorm';
import { asUuid } from 'src/utils/database';
export class ViewRepository implements IViewRepository {
constructor(@InjectRepository(AssetEntity) private assetRepository: Repository<AssetEntity>) {}
constructor(@InjectKysely() private db: Kysely<DB>) {}
@GenerateSql({ params: [DummyValue.UUID] })
async getUniqueOriginalPaths(userId: string): Promise<string[]> {
const results = await this.assetRepository
.createQueryBuilder('asset')
.where({
isVisible: true,
isArchived: false,
ownerId: userId,
})
.select("DISTINCT substring(asset.originalPath FROM '^(.*/)[^/]*$')", 'directoryPath')
.getRawMany();
const results = await this.db
.selectFrom('assets')
.select((eb) => eb.fn<string>('substring', ['assets.originalPath', eb.val('^(.*/)[^/]*$')]).as('directoryPath'))
.distinct()
.where('ownerId', '=', asUuid(userId))
.where('isVisible', '=', true)
.where('isArchived', '=', false)
.where('deletedAt', 'is', null)
.execute();
return results.map((row: { directoryPath: string }) => row.directoryPath.replaceAll(/^\/|\/$/g, ''));
return results.map((row) => row.directoryPath.replaceAll(/^\/|\/$/g, ''));
}
@GenerateSql({ params: [DummyValue.UUID, DummyValue.STRING] })
async getAssetsByOriginalPath(userId: string, partialPath: string): Promise<AssetEntity[]> {
const normalizedPath = partialPath.replaceAll(/^\/|\/$/g, '');
const assets = await this.assetRepository
.createQueryBuilder('asset')
.where({
isVisible: true,
isArchived: false,
ownerId: userId,
})
.leftJoinAndSelect('asset.exifInfo', 'exifInfo')
.andWhere(
new Brackets((qb) => {
qb.where('asset.originalPath LIKE :likePath', { likePath: `%${normalizedPath}/%` }).andWhere(
'asset.originalPath NOT LIKE :notLikePath',
{ notLikePath: `%${normalizedPath}/%/%` },
);
}),
)
.orderBy(String.raw`regexp_replace(asset.originalPath, '.*/(.+)', '\1')`, 'ASC')
.getMany();
return assets;
return this.db
.selectFrom('assets')
.selectAll('assets')
.$call(withExif)
.where('ownerId', '=', asUuid(userId))
.where('isVisible', '=', true)
.where('isArchived', '=', false)
.where('deletedAt', 'is', null)
.where('originalPath', 'like', `%${normalizedPath}/%`)
.where('originalPath', 'not like', `%${normalizedPath}/%/%`)
.orderBy(
(eb) => eb.fn('regexp_replace', ['assets.originalPath', eb.val('.*/(.+)'), eb.val(String.raw`\1`)]),
'asc',
)
.execute() as any as Promise<AssetEntity[]>;
}
}