diff --git a/e2e/docker-compose.yml b/e2e/docker-compose.yml index 8ae5762a1b..07a1c7e859 100644 --- a/e2e/docker-compose.yml +++ b/e2e/docker-compose.yml @@ -11,6 +11,7 @@ services: immich-server: container_name: immich-e2e-server image: immich-server:latest + shm_size: 128mb build: context: ../ dockerfile: server/Dockerfile diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index 3dfb4375bd..fb12e7a898 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -409,9 +409,12 @@ importers: '@react-email/render': specifier: ^1.1.2 version: 1.4.0(react-dom@19.2.4(react@19.2.4))(react@19.2.4) - '@socket.io/redis-adapter': - specifier: ^8.3.0 - version: 8.3.0(socket.io-adapter@2.5.6) + '@socket.io/postgres-adapter': + specifier: ^0.5.0 + version: 0.5.0(socket.io-adapter@2.5.6) + '@types/pg': + specifier: ^8.16.0 + version: 8.16.0 ajv: specifier: ^8.17.1 version: 8.18.0 @@ -565,6 +568,9 @@ importers: socket.io: specifier: ^4.8.1 version: 4.8.3 + socket.io-adapter: + specifier: ^2.5.6 + version: 2.5.6 tailwindcss-preset-email: specifier: ^1.4.0 version: 1.4.1(tailwindcss@3.4.19(tsx@4.21.0)(yaml@2.8.2)) @@ -3383,6 +3389,10 @@ packages: '@microsoft/tsdoc@0.16.0': resolution: {integrity: sha512-xgAyonlVVS+q7Vc7qLW0UrJU7rSFcETRWsqdXZtjzRU8dF+6CkozTK4V4y1LwOX7j8r/vHphjDeMeGI4tNGeGA==} + '@msgpack/msgpack@2.8.0': + resolution: {integrity: sha512-h9u4u/jiIRKbq25PM+zymTyW6bhTzELvOoUd+AvYriWOAKpLGnIamaET3pnHYoI5iYphAHBI4ayx0MehR+VVPQ==} + engines: {node: '>= 10'} + '@msgpackr-extract/msgpackr-extract-darwin-arm64@3.0.3': resolution: {integrity: sha512-QZHtlVgbAdy2zAqNA9Gu1UpIuI8Xvsd1v8ic6B2pZmeFnFcMWiPLfWXh7TVw4eGEZ/C9TH281KwhVoeQUKbyjw==} cpu: [arm64] @@ -4291,9 +4301,9 @@ packages: '@socket.io/component-emitter@3.1.2': resolution: {integrity: sha512-9BCxFwvbGg/RsZK9tjXd8s4UcwR0MWeFQ1XEKIQVVvAGJyINdrqKMcTRyLoK8Rse1GjzLV9cwjWV1olXRWEXVA==} - '@socket.io/redis-adapter@8.3.0': - resolution: {integrity: sha512-ly0cra+48hDmChxmIpnESKrc94LjRL80TEmZVscuQ/WWkRP81nNj8W8cCGMqbI4L6NCuAaPRSzZF1a9GlAxxnA==} - engines: {node: '>=10.0.0'} + '@socket.io/postgres-adapter@0.5.0': + resolution: {integrity: sha512-s1vFsatB4lS429ZbeAi8ju+mZMgtgdSmi9UsZsdcEG++vVtX5z10yDEt4TV8saePscvvGjs6uXvJfMCxz8+M2Q==} + engines: {node: '>=12.0.0'} peerDependencies: socket.io-adapter: ^2.5.4 @@ -9247,9 +9257,6 @@ packages: not@0.1.0: resolution: {integrity: sha512-5PDmaAsVfnWUgTUbJ3ERwn7u79Z0dYxN9ErxCpVJJqe2RK0PJ3z+iFUxuqjwtlDDegXvtWoxD/3Fzxox7tFGWA==} - notepack.io@3.0.1: - resolution: {integrity: sha512-TKC/8zH5pXIAMVQio2TvVDTtPRX+DJPHDqjRbxogtFiByHyzKmy96RA0JtCQJ+WouyyL4A10xomQzgbUT+1jCg==} - npm-run-path@4.0.1: resolution: {integrity: sha512-S48WzZW777zhNIrn7gxOlISNAqi9ZC/uQFnRdbeIHhZhCA6UqpkOT8T1G7BvfdgP4Er8gF4sUbaS0i7QvIfCWw==} engines: {node: '>=8'} @@ -11554,10 +11561,6 @@ packages: engines: {node: '>=0.8.0'} hasBin: true - uid2@1.0.0: - resolution: {integrity: sha512-+I6aJUv63YAcY9n4mQreLUt0d4lvwkkopDNmpomkAUz0fAkEMV9pRWxN0EjhW1YfRhcuyHg2v3mwddCDW1+LFQ==} - engines: {node: '>= 4.0.0'} - uid@2.0.2: resolution: {integrity: sha512-u3xV3X7uzvi5b1MncmZo3i2Aw222Zk1keqLA1YkHldREkAhAqi65wuPfe7lHx8H/Wzy+8CE7S7uS3jekIM5s8g==} engines: {node: '>=8'} @@ -15315,6 +15318,8 @@ snapshots: '@microsoft/tsdoc@0.16.0': {} + '@msgpack/msgpack@2.8.0': {} + '@msgpackr-extract/msgpackr-extract-darwin-arm64@3.0.3': optional: true @@ -16193,13 +16198,15 @@ snapshots: '@socket.io/component-emitter@3.1.2': {} - '@socket.io/redis-adapter@8.3.0(socket.io-adapter@2.5.6)': + '@socket.io/postgres-adapter@0.5.0(socket.io-adapter@2.5.6)': dependencies: + '@msgpack/msgpack': 2.8.0 + '@types/pg': 8.16.0 debug: 4.3.7 - notepack.io: 3.0.1 + pg: 8.18.0 socket.io-adapter: 2.5.6 - uid2: 1.0.0 transitivePeerDependencies: + - pg-native - supports-color '@sphinxxxx/color-conversion@2.2.2': {} @@ -22099,8 +22106,6 @@ snapshots: not@0.1.0: {} - notepack.io@3.0.1: {} - npm-run-path@4.0.1: dependencies: path-key: 3.1.1 @@ -24828,8 +24833,6 @@ snapshots: uglify-js@3.19.3: optional: true - uid2@1.0.0: {} - uid@2.0.2: dependencies: '@lukeed/csprng': 1.1.0 diff --git a/server/package.json b/server/package.json index 5578f242ae..ce6e6a3620 100644 --- a/server/package.json +++ b/server/package.json @@ -57,7 +57,8 @@ "@opentelemetry/semantic-conventions": "^1.34.0", "@react-email/components": "^0.5.0", "@react-email/render": "^1.1.2", - "@socket.io/redis-adapter": "^8.3.0", + "@socket.io/postgres-adapter": "^0.5.0", + "@types/pg": "^8.16.0", "ajv": "^8.17.1", "archiver": "^7.0.0", "async-lock": "^1.4.0", @@ -109,6 +110,7 @@ "sharp": "^0.34.5", "sirv": "^3.0.0", "socket.io": "^4.8.1", + "socket.io-adapter": "^2.5.6", "tailwindcss-preset-email": "^1.4.0", "thumbhash": "^0.1.1", "transformation-matrix": "^3.1.0", diff --git a/server/src/app.common.ts b/server/src/app.common.ts index 934c13343f..c9c0156b00 100644 --- a/server/src/app.common.ts +++ b/server/src/app.common.ts @@ -5,8 +5,9 @@ import cookieParser from 'cookie-parser'; import { existsSync } from 'node:fs'; import sirv from 'sirv'; import { excludePaths, serverVersion } from 'src/constants'; +import { SocketIoAdapter } from 'src/enum'; import { MaintenanceWorkerService } from 'src/maintenance/maintenance-worker.service'; -import { WebSocketAdapter } from 'src/middleware/websocket.adapter'; +import { createWebSocketAdapter } from 'src/middleware/websocket.adapter'; import { ConfigRepository } from 'src/repositories/config.repository'; import { LoggingRepository } from 'src/repositories/logging.repository'; import { bootstrapTelemetry } from 'src/repositories/telemetry.repository'; @@ -25,6 +26,7 @@ export async function configureExpress( { permitSwaggerWrite = true, ssr, + socketIoAdapter, }: { /** * Whether to allow swagger module to write to the specs.json @@ -36,6 +38,10 @@ export async function configureExpress( * Service to use for server-side rendering */ ssr: typeof ApiService | typeof MaintenanceWorkerService; + /** + * Override the Socket.IO adapter. If not specified, uses the adapter from config. + */ + socketIoAdapter?: SocketIoAdapter; }, ) { const configRepository = app.get(ConfigRepository); @@ -55,7 +61,7 @@ export async function configureExpress( } app.setGlobalPrefix('api', { exclude: excludePaths }); - app.useWebSocketAdapter(new WebSocketAdapter(app)); + app.useWebSocketAdapter(await createWebSocketAdapter(app, socketIoAdapter)); useSwagger(app, { write: configRepository.isDev() && permitSwaggerWrite }); diff --git a/server/src/controllers/index.ts b/server/src/controllers/index.ts index dc3754ce24..716a86c22f 100644 --- a/server/src/controllers/index.ts +++ b/server/src/controllers/index.ts @@ -10,6 +10,7 @@ import { DatabaseBackupController } from 'src/controllers/database-backup.contro import { DownloadController } from 'src/controllers/download.controller'; import { DuplicateController } from 'src/controllers/duplicate.controller'; import { FaceController } from 'src/controllers/face.controller'; +import { InternalController } from 'src/controllers/internal.controller'; import { JobController } from 'src/controllers/job.controller'; import { LibraryController } from 'src/controllers/library.controller'; import { MaintenanceController } from 'src/controllers/maintenance.controller'; @@ -51,6 +52,7 @@ export const controllers = [ DownloadController, DuplicateController, FaceController, + InternalController, JobController, LibraryController, MaintenanceController, diff --git a/server/src/controllers/internal.controller.ts b/server/src/controllers/internal.controller.ts new file mode 100644 index 0000000000..4a97012d46 --- /dev/null +++ b/server/src/controllers/internal.controller.ts @@ -0,0 +1,22 @@ +import { Body, Controller, NotFoundException, Post, Req } from '@nestjs/common'; +import { ApiExcludeController } from '@nestjs/swagger'; +import { Request } from 'express'; +import { AppRestartEvent, EventRepository } from 'src/repositories/event.repository'; + +const LOCALHOST_ADDRESSES = new Set(['127.0.0.1', '::1', '::ffff:127.0.0.1']); + +@ApiExcludeController() +@Controller('internal') +export class InternalController { + constructor(private eventRepository: EventRepository) {} + + @Post('restart') + async restart(@Req() req: Request, @Body() dto: AppRestartEvent): Promise { + const remoteAddress = req.socket.remoteAddress; + if (!remoteAddress || !LOCALHOST_ADDRESSES.has(remoteAddress)) { + throw new NotFoundException(); + } + + await this.eventRepository.emit('AppRestart', dto); + } +} diff --git a/server/src/dtos/env.dto.ts b/server/src/dtos/env.dto.ts index b04366c273..f8b41d2c3c 100644 --- a/server/src/dtos/env.dto.ts +++ b/server/src/dtos/env.dto.ts @@ -1,6 +1,6 @@ import { Transform, Type } from 'class-transformer'; import { IsEnum, IsInt, IsString, Matches } from 'class-validator'; -import { ImmichEnvironment, LogFormat, LogLevel } from 'src/enum'; +import { ImmichEnvironment, LogFormat, LogLevel, SocketIoAdapter } from 'src/enum'; import { IsIPRange, Optional, ValidateBoolean } from 'src/validation'; // TODO import from sql-tools once the swagger plugin supports external enums @@ -149,6 +149,11 @@ export class EnvDto { @Optional() IMMICH_WORKERS_EXCLUDE?: string; + @IsEnum(SocketIoAdapter) + @Optional() + @Transform(({ value }) => (value ? String(value).toLowerCase().trim() : value)) + IMMICH_SOCKETIO_ADAPTER?: SocketIoAdapter; + @IsString() @Optional() DB_DATABASE_NAME?: string; diff --git a/server/src/enum.ts b/server/src/enum.ts index 2aa9bd2aa6..8e3fca71fe 100644 --- a/server/src/enum.ts +++ b/server/src/enum.ts @@ -518,6 +518,11 @@ export enum ImmichTelemetry { Job = 'job', } +export enum SocketIoAdapter { + BroadcastChannel = 'broadcastchannel', + Postgres = 'postgres', +} + export enum ExifOrientation { Horizontal = 1, MirrorHorizontal = 2, diff --git a/server/src/main.ts b/server/src/main.ts index f2491f07bc..9e11e4c681 100644 --- a/server/src/main.ts +++ b/server/src/main.ts @@ -1,6 +1,5 @@ import { Kysely, sql } from 'kysely'; import { CommandFactory } from 'nest-commander'; -import { ChildProcess, fork } from 'node:child_process'; import { dirname, join } from 'node:path'; import { Worker } from 'node:worker_threads'; import { PostgresError } from 'postgres'; @@ -18,7 +17,7 @@ class Workers { /** * Currently running workers */ - workers: Partial Promise | void }>> = {}; + workers: Partial Promise | void }>> = {}; /** * Fail-safe in case anything dies during restart @@ -101,25 +100,23 @@ class Workers { const basePath = dirname(__filename); const workerFile = join(basePath, 'workers', `${name}.js`); - let anyWorker: Worker | ChildProcess; - let kill: (signal?: NodeJS.Signals) => Promise | void; + const inspectArg = process.execArgv.find((arg) => arg.startsWith('--inspect')); + const workerData: { inspectorPort?: number } = {}; - if (name === ImmichWorker.Api) { - const worker = fork(workerFile, [], { - execArgv: process.execArgv.map((arg) => (arg.startsWith('--inspect') ? '--inspect=0.0.0.0:9231' : arg)), - }); - - kill = (signal) => void worker.kill(signal); - anyWorker = worker; - } else { - const worker = new Worker(workerFile); - - kill = async () => void (await worker.terminate()); - anyWorker = worker; + if (inspectArg) { + const inspectorPorts: Record = { + [ImmichWorker.Api]: 9230, + [ImmichWorker.Microservices]: 9231, + [ImmichWorker.Maintenance]: 9232, + }; + workerData.inspectorPort = inspectorPorts[name]; } - anyWorker.on('error', (error) => this.onError(name, error)); - anyWorker.on('exit', (exitCode) => this.onExit(name, exitCode)); + const worker = new Worker(workerFile, { workerData }); + const kill = async () => void (await worker.terminate()); + + worker.on('error', (error) => this.onError(name, error)); + worker.on('exit', (exitCode) => this.onExit(name, exitCode)); this.workers[name] = { kill }; } @@ -152,8 +149,8 @@ class Workers { console.error(`${name} worker exited with code ${exitCode}`); if (this.workers[ImmichWorker.Api] && name !== ImmichWorker.Api) { - console.error('Killing api process'); - void this.workers[ImmichWorker.Api].kill('SIGTERM'); + console.error('Terminating api worker'); + void this.workers[ImmichWorker.Api].kill(); } } diff --git a/server/src/maintenance/maintenance-worker.controller.ts b/server/src/maintenance/maintenance-worker.controller.ts index 162fa27257..c854971573 100644 --- a/server/src/maintenance/maintenance-worker.controller.ts +++ b/server/src/maintenance/maintenance-worker.controller.ts @@ -4,6 +4,7 @@ import { Delete, Get, Next, + NotFoundException, Param, Post, Req, @@ -25,12 +26,15 @@ import { ImmichCookie } from 'src/enum'; import { MaintenanceRoute } from 'src/maintenance/maintenance-auth.guard'; import { MaintenanceWorkerService } from 'src/maintenance/maintenance-worker.service'; import { GetLoginDetails } from 'src/middleware/auth.guard'; +import { AppRestartEvent } from 'src/repositories/event.repository'; import { LoggingRepository } from 'src/repositories/logging.repository'; import { LoginDetails } from 'src/services/auth.service'; import { sendFile } from 'src/utils/file'; import { respondWithCookie } from 'src/utils/response'; import { FilenameParamDto } from 'src/validation'; +const LOCALHOST_ADDRESSES = new Set(['127.0.0.1', '::1', '::ffff:127.0.0.1']); + import type { DatabaseBackupController as _DatabaseBackupController } from 'src/controllers/database-backup.controller'; import type { ServerController as _ServerController } from 'src/controllers/server.controller'; import { DatabaseBackupDeleteDto, DatabaseBackupListResponseDto } from 'src/dtos/database-backup.dto'; @@ -131,4 +135,14 @@ export class MaintenanceWorkerController { setMaintenanceMode(@Body() dto: SetMaintenanceModeDto): void { void this.service.setAction(dto); } + + @Post('internal/restart') + internalRestart(@Req() req: Request, @Body() dto: AppRestartEvent): void { + const remoteAddress = req.socket.remoteAddress; + if (!remoteAddress || !LOCALHOST_ADDRESSES.has(remoteAddress)) { + throw new NotFoundException(); + } + + this.service.handleInternalRestart(dto); + } } diff --git a/server/src/maintenance/maintenance-worker.service.ts b/server/src/maintenance/maintenance-worker.service.ts index 9ceb3caa43..d73eb8bf4e 100644 --- a/server/src/maintenance/maintenance-worker.service.ts +++ b/server/src/maintenance/maintenance-worker.service.ts @@ -19,6 +19,7 @@ import { MaintenanceWebsocketRepository } from 'src/maintenance/maintenance-webs import { AppRepository } from 'src/repositories/app.repository'; import { ConfigRepository } from 'src/repositories/config.repository'; import { DatabaseRepository } from 'src/repositories/database.repository'; +import { AppRestartEvent } from 'src/repositories/event.repository'; import { LoggingRepository } from 'src/repositories/logging.repository'; import { ProcessRepository } from 'src/repositories/process.repository'; import { StorageRepository } from 'src/repositories/storage.repository'; @@ -290,6 +291,9 @@ export class MaintenanceWorkerService { const lock = await this.databaseRepository.tryLock(DatabaseLock.MaintenanceOperation); if (!lock) { + // Another maintenance worker has the lock - poll until maintenance mode ends + this.logger.log('Another worker has the maintenance lock, polling for maintenance mode changes...'); + await this.pollForMaintenanceEnd(); return; } @@ -351,4 +355,25 @@ export class MaintenanceWorkerService { this.maintenanceWebsocketRepository.serverSend('AppRestart', state); this.appRepository.exitApp(); } + + handleInternalRestart(state: AppRestartEvent): void { + this.maintenanceWebsocketRepository.clientBroadcast('AppRestartV1', state); + this.maintenanceWebsocketRepository.serverSend('AppRestart', state); + this.appRepository.exitApp(); + } + + private async pollForMaintenanceEnd(): Promise { + const pollIntervalMs = 5000; + + while (true) { + await new Promise((resolve) => setTimeout(resolve, pollIntervalMs)); + + const state = await this.systemMetadataRepository.get(SystemMetadataKey.MaintenanceMode); + if (!state?.isMaintenanceMode) { + this.logger.log('Maintenance mode ended, restarting...'); + this.appRepository.exitApp(); + return; + } + } + } } diff --git a/server/src/middleware/broadcast-channel.adapter.ts b/server/src/middleware/broadcast-channel.adapter.ts new file mode 100644 index 0000000000..af55448fb2 --- /dev/null +++ b/server/src/middleware/broadcast-channel.adapter.ts @@ -0,0 +1,80 @@ +import { + ClusterAdapterWithHeartbeat, + type ClusterAdapterOptions, + type ClusterMessage, + type ClusterResponse, + type ServerId, +} from 'socket.io-adapter'; + +const BC_CHANNEL_NAME = 'immich:socketio'; + +interface BroadcastChannelPayload { + type: 'message' | 'response'; + sourceUid: string; + targetUid?: string; + data: unknown; +} + +/** + * Socket.IO adapter using Node.js BroadcastChannel + * + * Relays messages between worker_threads within a single OS process. + * Zero external dependencies. Does NOT work across containers — use + * the Postgres adapter for multi-replica deployments. + */ +class BroadcastChannelAdapter extends ClusterAdapterWithHeartbeat { + private readonly channel: BroadcastChannel; + + constructor(nsp: any, opts?: Partial) { + super(nsp, opts ?? {}); + + this.channel = new BroadcastChannel(BC_CHANNEL_NAME); + this.channel.addEventListener('message', (event: MessageEvent) => { + const msg = event.data; + if (msg.sourceUid === this.uid) { + return; + } + if (msg.type === 'message') { + this.onMessage(msg.data as ClusterMessage); + } else if (msg.type === 'response' && msg.targetUid === this.uid) { + this.onResponse(msg.data as ClusterResponse); + } + }); + + this.init(); + } + + override doPublish(message: ClusterMessage): Promise { + this.channel.postMessage({ + type: 'message', + sourceUid: this.uid, + data: message, + }); + return Promise.resolve(''); + } + + override doPublishResponse(requesterUid: ServerId, response: ClusterResponse): Promise { + this.channel.postMessage({ + type: 'response', + sourceUid: this.uid, + targetUid: requesterUid, + data: response, + }); + return Promise.resolve(); + } + + override close(): void { + super.close(); + this.channel.close(); + } +} + +export function createBroadcastChannelAdapter(opts?: Partial) { + const options: Partial = { + ...opts, + }; + + return function (nsp: any) { + return new BroadcastChannelAdapter(nsp, options); + }; +} diff --git a/server/src/middleware/websocket.adapter.ts b/server/src/middleware/websocket.adapter.ts index 64bb1f9ea5..bf7ae60865 100644 --- a/server/src/middleware/websocket.adapter.ts +++ b/server/src/middleware/websocket.adapter.ts @@ -1,21 +1,103 @@ -import { INestApplicationContext } from '@nestjs/common'; +import { INestApplication, Logger } from '@nestjs/common'; import { IoAdapter } from '@nestjs/platform-socket.io'; -import { createAdapter } from '@socket.io/redis-adapter'; -import { Redis } from 'ioredis'; -import { ServerOptions } from 'socket.io'; +import { Pool, PoolConfig } from 'pg'; +import type { ServerOptions } from 'socket.io'; +import { SocketIoAdapter } from 'src/enum'; +import { createBroadcastChannelAdapter } from 'src/middleware/broadcast-channel.adapter'; import { ConfigRepository } from 'src/repositories/config.repository'; +import { asPostgresConnectionConfig } from 'src/utils/database'; -export class WebSocketAdapter extends IoAdapter { - constructor(private app: INestApplicationContext) { +export type Ssl = 'require' | 'allow' | 'prefer' | 'verify-full' | boolean | object; + +export function asPgPoolSsl(ssl?: Ssl): PoolConfig['ssl'] { + if (ssl === undefined || ssl === false || ssl === 'allow') { + return false; + } + + if (ssl === true || ssl === 'prefer' || ssl === 'require') { + return { rejectUnauthorized: false }; + } + + if (ssl === 'verify-full') { + return { rejectUnauthorized: true }; + } + + return ssl; +} + +class BroadcastChannelSocketAdapter extends IoAdapter { + private adapterConstructor: ReturnType; + + constructor(app: INestApplication) { super(app); + this.adapterConstructor = createBroadcastChannelAdapter(); } createIOServer(port: number, options?: ServerOptions): any { - const { redis } = this.app.get(ConfigRepository).getEnv(); const server = super.createIOServer(port, options); - const pubClient = new Redis(redis); - const subClient = pubClient.duplicate(); - server.adapter(createAdapter(pubClient, subClient)); + server.adapter(this.adapterConstructor); return server; } } + +class PostgresSocketAdapter extends IoAdapter { + private adapterConstructor: any; + + constructor(app: INestApplication, adapterConstructor: any) { + super(app); + this.adapterConstructor = adapterConstructor; + } + + createIOServer(port: number, options?: ServerOptions): any { + const server = super.createIOServer(port, options); + server.adapter(this.adapterConstructor); + return server; + } +} + +export async function createWebSocketAdapter( + app: INestApplication, + adapterOverride?: SocketIoAdapter, +): Promise { + const logger = new Logger('WebSocketAdapter'); + const config = new ConfigRepository(); + const { database, socketIo } = config.getEnv(); + const adapter = adapterOverride ?? socketIo.adapter; + + switch (adapter) { + case SocketIoAdapter.Postgres: { + logger.log('Using Postgres Socket.IO adapter'); + const { createAdapter } = await import('@socket.io/postgres-adapter'); + const config = asPostgresConnectionConfig(database.config); + const pool = new Pool({ + host: config.host, + port: config.port, + user: config.username, + password: config.password, + database: config.database, + ssl: asPgPoolSsl(config.ssl), + max: 2, + }); + + await pool.query(` + CREATE TABLE IF NOT EXISTS socket_io_attachments ( + id bigserial UNIQUE, + created_at timestamptz DEFAULT NOW(), + payload bytea + ); + `); + + pool.on('error', (error) => { + logger.error(' Postgres pool error', error); + }); + + const adapterConstructor = createAdapter(pool); + return new PostgresSocketAdapter(app, adapterConstructor); + } + + case SocketIoAdapter.BroadcastChannel: { + logger.log('Using BroadcastChannel Socket.IO adapter'); + return new BroadcastChannelSocketAdapter(app); + } + } +} diff --git a/server/src/repositories/app.repository.ts b/server/src/repositories/app.repository.ts index 96e413232f..3e934153f5 100644 --- a/server/src/repositories/app.repository.ts +++ b/server/src/repositories/app.repository.ts @@ -1,7 +1,4 @@ import { Injectable } from '@nestjs/common'; -import { createAdapter } from '@socket.io/redis-adapter'; -import Redis from 'ioredis'; -import { Server as SocketIO } from 'socket.io'; import { ExitCode } from 'src/enum'; import { ConfigRepository } from 'src/repositories/config.repository'; import { AppRestartEvent } from 'src/repositories/event.repository'; @@ -24,24 +21,17 @@ export class AppRepository { } async sendOneShotAppRestart(state: AppRestartEvent): Promise { - const server = new SocketIO(); - const { redis } = new ConfigRepository().getEnv(); - const pubClient = new Redis({ ...redis, lazyConnect: true }); - const subClient = pubClient.duplicate(); + const { port } = new ConfigRepository().getEnv(); + const url = `http://127.0.0.1:${port}/api/internal/restart`; - await Promise.all([pubClient.connect(), subClient.connect()]); - - server.adapter(createAdapter(pubClient, subClient)); - - // => corresponds to notification.service.ts#onAppRestart - server.emit('AppRestartV1', state, async () => { - const responses = await server.serverSideEmitWithAck('AppRestart', state); - if (responses.some((response) => response !== 'ok')) { - throw new Error("One or more node(s) returned a non-'ok' response to our restart request!"); - } - - pubClient.disconnect(); - subClient.disconnect(); + const response = await fetch(url, { + method: 'POST', + headers: { 'Content-Type': 'application/json' }, + body: JSON.stringify(state), }); + + if (!response.ok) { + throw new Error(`Failed to trigger app restart: ${response.status} ${response.statusText}`); + } } } diff --git a/server/src/repositories/config.repository.ts b/server/src/repositories/config.repository.ts index 7e8082a582..60847811f1 100644 --- a/server/src/repositories/config.repository.ts +++ b/server/src/repositories/config.repository.ts @@ -21,6 +21,7 @@ import { LogFormat, LogLevel, QueueName, + SocketIoAdapter, } from 'src/enum'; import { VectorExtension } from 'src/types'; import { setDifference } from 'src/utils/set'; @@ -117,6 +118,10 @@ export interface EnvData { }; }; + socketIo: { + adapter: SocketIoAdapter; + }; + noColor: boolean; nodeVersion?: string; } @@ -347,6 +352,10 @@ const getEnv = (): EnvData => { }, }, + socketIo: { + adapter: dto.IMMICH_SOCKETIO_ADAPTER ?? SocketIoAdapter.Postgres, + }, + noColor: !!dto.NO_COLOR, }; }; diff --git a/server/src/utils/maintenance.ts b/server/src/utils/maintenance.ts index 47abb0ab89..44b98f3f75 100644 --- a/server/src/utils/maintenance.ts +++ b/server/src/utils/maintenance.ts @@ -1,60 +1,11 @@ -import { createAdapter } from '@socket.io/redis-adapter'; -import Redis from 'ioredis'; import { SignJWT } from 'jose'; import { randomBytes } from 'node:crypto'; import { join } from 'node:path'; -import { Server as SocketIO } from 'socket.io'; import { StorageCore } from 'src/cores/storage.core'; import { MaintenanceAuthDto, MaintenanceDetectInstallResponseDto } from 'src/dtos/maintenance.dto'; import { StorageFolder } from 'src/enum'; -import { ConfigRepository } from 'src/repositories/config.repository'; -import { AppRestartEvent } from 'src/repositories/event.repository'; import { StorageRepository } from 'src/repositories/storage.repository'; -export function sendOneShotAppRestart(state: AppRestartEvent): void { - const server = new SocketIO(); - const { redis } = new ConfigRepository().getEnv(); - const pubClient = new Redis(redis); - const subClient = pubClient.duplicate(); - server.adapter(createAdapter(pubClient, subClient)); - - /** - * Keep trying until we manage to stop Immich - * - * Sometimes there appear to be communication - * issues between to the other servers. - * - * This issue only occurs with this method. - */ - async function tryTerminate() { - while (true) { - try { - const responses = await server.serverSideEmitWithAck('AppRestart', state); - if (responses.length > 0) { - return; - } - } catch (error) { - console.error(error); - console.error('Encountered an error while telling Immich to stop.'); - } - - console.info( - "\nIt doesn't appear that Immich stopped, trying again in a moment.\nIf Immich is already not running, you can ignore this error.", - ); - - await new Promise((r) => setTimeout(r, 1e3)); - } - } - - // => corresponds to notification.service.ts#onAppRestart - server.emit('AppRestartV1', state, () => { - void tryTerminate().finally(() => { - pubClient.disconnect(); - subClient.disconnect(); - }); - }); -} - export async function createMaintenanceLoginUrl( baseUrl: string, auth: MaintenanceAuthDto, diff --git a/server/src/workers/api.ts b/server/src/workers/api.ts index 99c08c0fa7..3f095f40de 100644 --- a/server/src/workers/api.ts +++ b/server/src/workers/api.ts @@ -1,14 +1,21 @@ import { NestFactory } from '@nestjs/core'; import { NestExpressApplication } from '@nestjs/platform-express'; +import inspector from 'node:inspector'; +import { isMainThread, workerData } from 'node:worker_threads'; import { configureExpress, configureTelemetry } from 'src/app.common'; import { ApiModule } from 'src/app.module'; import { AppRepository } from 'src/repositories/app.repository'; import { ApiService } from 'src/services/api.service'; import { isStartUpError } from 'src/utils/misc'; -async function bootstrap() { +export async function bootstrap() { process.title = 'immich-api'; + const { inspectorPort } = workerData ?? {}; + if (inspectorPort) { + inspector.open(inspectorPort, '0.0.0.0', false); + } + configureTelemetry(); const app = await NestFactory.create(ApiModule, { bufferLogs: true }); @@ -19,10 +26,12 @@ async function bootstrap() { }); } -bootstrap().catch((error) => { - if (!isStartUpError(error)) { - console.error(error); - } - // eslint-disable-next-line unicorn/no-process-exit - process.exit(1); -}); +if (!isMainThread || process.send) { + bootstrap().catch((error) => { + if (!isStartUpError(error)) { + console.error(error); + } + + process.exit(1); + }); +} diff --git a/server/src/workers/maintenance.ts b/server/src/workers/maintenance.ts index 035ec600af..5bc466f837 100644 --- a/server/src/workers/maintenance.ts +++ b/server/src/workers/maintenance.ts @@ -1,13 +1,22 @@ import { NestFactory } from '@nestjs/core'; import { NestExpressApplication } from '@nestjs/platform-express'; +import inspector from 'node:inspector'; +import { isMainThread, workerData } from 'node:worker_threads'; import { configureExpress, configureTelemetry } from 'src/app.common'; import { MaintenanceModule } from 'src/app.module'; +import { SocketIoAdapter } from 'src/enum'; import { MaintenanceWorkerService } from 'src/maintenance/maintenance-worker.service'; import { AppRepository } from 'src/repositories/app.repository'; import { isStartUpError } from 'src/utils/misc'; -async function bootstrap() { +export async function bootstrap() { process.title = 'immich-maintenance'; + + const { inspectorPort } = workerData ?? {}; + if (inspectorPort) { + inspector.open(inspectorPort, '0.0.0.0', false); + } + configureTelemetry(); const app = await NestFactory.create(MaintenanceModule, { bufferLogs: true }); @@ -16,13 +25,18 @@ async function bootstrap() { void configureExpress(app, { permitSwaggerWrite: false, ssr: MaintenanceWorkerService, + // Use BroadcastChannel instead of Postgres adapter to avoid crash when + // pg_terminate_backend() kills all database connections during restore + socketIoAdapter: SocketIoAdapter.BroadcastChannel, }); } -bootstrap().catch((error) => { - if (!isStartUpError(error)) { - console.error(error); - } - // eslint-disable-next-line unicorn/no-process-exit - process.exit(1); -}); +if (!isMainThread) { + bootstrap().catch((error) => { + if (!isStartUpError(error)) { + console.error(error); + } + + process.exit(1); + }); +} diff --git a/server/src/workers/microservices.ts b/server/src/workers/microservices.ts index 8f06b4b0b1..19b82a8d5e 100644 --- a/server/src/workers/microservices.ts +++ b/server/src/workers/microservices.ts @@ -1,8 +1,9 @@ import { NestFactory } from '@nestjs/core'; -import { isMainThread } from 'node:worker_threads'; +import inspector from 'node:inspector'; +import { isMainThread, workerData } from 'node:worker_threads'; import { MicroservicesModule } from 'src/app.module'; import { serverVersion } from 'src/constants'; -import { WebSocketAdapter } from 'src/middleware/websocket.adapter'; +import { createWebSocketAdapter } from 'src/middleware/websocket.adapter'; import { AppRepository } from 'src/repositories/app.repository'; import { ConfigRepository } from 'src/repositories/config.repository'; import { LoggingRepository } from 'src/repositories/logging.repository'; @@ -10,6 +11,11 @@ import { bootstrapTelemetry } from 'src/repositories/telemetry.repository'; import { isStartUpError } from 'src/utils/misc'; export async function bootstrap() { + const { inspectorPort } = workerData ?? {}; + if (inspectorPort) { + inspector.open(inspectorPort, '0.0.0.0', false); + } + const { telemetry } = new ConfigRepository().getEnv(); if (telemetry.metrics.size > 0) { bootstrapTelemetry(telemetry.microservicesPort); @@ -24,7 +30,7 @@ export async function bootstrap() { logger.setContext('Bootstrap'); app.useLogger(logger); - app.useWebSocketAdapter(new WebSocketAdapter(app)); + app.useWebSocketAdapter(await createWebSocketAdapter(app)); await (host ? app.listen(0, host) : app.listen(0)); diff --git a/server/test/medium/specs/middleware/broadcast-channel.adapter.spec.ts b/server/test/medium/specs/middleware/broadcast-channel.adapter.spec.ts new file mode 100644 index 0000000000..0512dda827 --- /dev/null +++ b/server/test/medium/specs/middleware/broadcast-channel.adapter.spec.ts @@ -0,0 +1,276 @@ +import { ClusterMessage, ClusterResponse } from 'socket.io-adapter'; +import { createBroadcastChannelAdapter } from 'src/middleware/broadcast-channel.adapter'; +import { vi } from 'vitest'; + +const createMockNamespace = () => ({ + name: '/', + sockets: new Map(), + adapter: null, + server: { + encoder: { + encode: vi.fn().mockReturnValue([]), + }, + _opts: {}, + sockets: { + sockets: new Map(), + }, + }, +}); + +describe('BroadcastChannelAdapter', () => { + describe('createBroadcastChannelAdapter', () => { + it('should return a factory function', () => { + const factory = createBroadcastChannelAdapter(); + expect(typeof factory).toBe('function'); + }); + + it('should create adapter instance when factory is called', () => { + const mockNamespace = createMockNamespace(); + const factory = createBroadcastChannelAdapter(); + const adapter = factory(mockNamespace); + + expect(adapter).toBeDefined(); + expect(adapter.doPublish).toBeDefined(); + expect(adapter.doPublishResponse).toBeDefined(); + + adapter.close(); + }); + }); + + describe('BroadcastChannelAdapter message passing', () => { + it('should actually send and receive messages between two adapters', async () => { + const factory1 = createBroadcastChannelAdapter(); + const factory2 = createBroadcastChannelAdapter(); + + const namespace1 = createMockNamespace(); + const namespace2 = createMockNamespace(); + + const adapter1 = factory1(namespace1); + const adapter2 = factory2(namespace2); + + await new Promise((resolve) => setTimeout(resolve, 100)); + + const receivedMessages: ClusterMessage[] = []; + const messageReceived = new Promise((resolve) => { + const originalOnMessage = adapter2.onMessage.bind(adapter2); + adapter2.onMessage = (message: ClusterMessage) => { + receivedMessages.push(message); + resolve(); + return originalOnMessage(message); + }; + }); + + const testMessage = { + type: 2, + data: { + opts: { rooms: new Set(['room1']) }, + rooms: ['room1'], + }, + nsp: '/', + }; + + void adapter1.doPublish(testMessage as any); + + await Promise.race([messageReceived, new Promise((resolve) => setTimeout(resolve, 500))]); + + expect(receivedMessages.length).toBeGreaterThan(0); + + adapter1.close(); + adapter2.close(); + }); + + it('should send ConfigUpdate-style event and receive it on another adapter', async () => { + const factory1 = createBroadcastChannelAdapter(); + const factory2 = createBroadcastChannelAdapter(); + + const namespace1 = createMockNamespace(); + const namespace2 = createMockNamespace(); + + const adapter1 = factory1(namespace1); + const adapter2 = factory2(namespace2); + + await new Promise((resolve) => setTimeout(resolve, 100)); + + const receivedMessages: ClusterMessage[] = []; + const messageReceived = new Promise((resolve) => { + const originalOnMessage = adapter2.onMessage.bind(adapter2); + adapter2.onMessage = (message: ClusterMessage) => { + receivedMessages.push(message); + if ((message as any)?.data?.event === 'ConfigUpdate') { + resolve(); + } + return originalOnMessage(message); + }; + }); + + const configUpdateMessage = { + type: 2, + data: { + event: 'ConfigUpdate', + payload: { newConfig: { ffmpeg: { crf: 23 } }, oldConfig: { ffmpeg: { crf: 20 } } }, + opts: { rooms: new Set() }, + rooms: [], + }, + nsp: '/', + }; + + void adapter1.doPublish(configUpdateMessage as any); + + await Promise.race([messageReceived, new Promise((resolve) => setTimeout(resolve, 500))]); + + const configMessages = receivedMessages.filter((m) => (m as any)?.data?.event === 'ConfigUpdate'); + expect(configMessages.length).toBeGreaterThan(0); + expect((configMessages[0] as any).data.payload.newConfig.ffmpeg.crf).toBe(23); + + adapter1.close(); + adapter2.close(); + }); + + it('should send AppRestart-style event and receive it on another adapter', async () => { + const factory1 = createBroadcastChannelAdapter(); + const factory2 = createBroadcastChannelAdapter(); + + const namespace1 = createMockNamespace(); + const namespace2 = createMockNamespace(); + + const adapter1 = factory1(namespace1); + const adapter2 = factory2(namespace2); + + await new Promise((resolve) => setTimeout(resolve, 100)); + + const receivedMessages: ClusterMessage[] = []; + const messageReceived = new Promise((resolve) => { + const originalOnMessage = adapter2.onMessage.bind(adapter2); + adapter2.onMessage = (message: ClusterMessage) => { + receivedMessages.push(message); + if ((message as any)?.data?.event === 'AppRestart') { + resolve(); + } + return originalOnMessage(message); + }; + }); + + const appRestartMessage = { + type: 2, + data: { + event: 'AppRestart', + payload: { isMaintenanceMode: true }, + opts: { rooms: new Set() }, + rooms: [], + }, + nsp: '/', + }; + + void adapter1.doPublish(appRestartMessage as any); + + await Promise.race([messageReceived, new Promise((resolve) => setTimeout(resolve, 500))]); + + const restartMessages = receivedMessages.filter((m) => (m as any)?.data?.event === 'AppRestart'); + expect(restartMessages.length).toBeGreaterThan(0); + expect((restartMessages[0] as any).data.payload.isMaintenanceMode).toBe(true); + + adapter1.close(); + adapter2.close(); + }); + + it('should not receive its own messages (echo prevention)', async () => { + const factory = createBroadcastChannelAdapter(); + const namespace = createMockNamespace(); + const adapter = factory(namespace); + + await new Promise((resolve) => setTimeout(resolve, 100)); + + const receivedOwnMessages: ClusterMessage[] = []; + const uniqueMarker = `test-${Date.now()}-${Math.random()}`; + + const originalOnMessage = adapter.onMessage.bind(adapter); + adapter.onMessage = (message: ClusterMessage) => { + if ((message as any)?.data?.marker === uniqueMarker) { + receivedOwnMessages.push(message); + } + return originalOnMessage(message); + }; + + const testMessage = { + type: 2, + data: { + marker: uniqueMarker, + opts: { rooms: new Set() }, + rooms: [], + }, + nsp: '/', + }; + + void adapter.doPublish(testMessage as any); + + await new Promise((resolve) => setTimeout(resolve, 200)); + + expect(receivedOwnMessages.length).toBe(0); + + adapter.close(); + }); + + it('should send and receive response messages between adapters', async () => { + const factory1 = createBroadcastChannelAdapter(); + const factory2 = createBroadcastChannelAdapter(); + + const namespace1 = createMockNamespace(); + const namespace2 = createMockNamespace(); + + const adapter1 = factory1(namespace1); + const adapter2 = factory2(namespace2); + + await new Promise((resolve) => setTimeout(resolve, 100)); + + const receivedResponses: ClusterResponse[] = []; + const responseReceived = new Promise((resolve) => { + const originalOnResponse = adapter1.onResponse.bind(adapter1); + adapter1.onResponse = (response: ClusterResponse) => { + receivedResponses.push(response); + resolve(); + return originalOnResponse(response); + }; + }); + + const responseMessage = { + type: 3, + data: { result: 'success', count: 42 }, + }; + + void adapter2.doPublishResponse((adapter1 as any).uid, responseMessage as any); + + await Promise.race([responseReceived, new Promise((resolve) => setTimeout(resolve, 500))]); + + expect(receivedResponses.length).toBeGreaterThan(0); + + adapter1.close(); + adapter2.close(); + }); + }); + + describe('BroadcastChannelAdapter lifecycle', () => { + it('should close cleanly without errors', () => { + const factory = createBroadcastChannelAdapter(); + const namespace = createMockNamespace(); + const adapter = factory(namespace); + + expect(() => adapter.close()).not.toThrow(); + }); + + it('should handle multiple adapters closing in sequence', () => { + const factory1 = createBroadcastChannelAdapter(); + const factory2 = createBroadcastChannelAdapter(); + const factory3 = createBroadcastChannelAdapter(); + + const adapter1 = factory1(createMockNamespace()); + const adapter2 = factory2(createMockNamespace()); + const adapter3 = factory3(createMockNamespace()); + + expect(() => { + adapter1.close(); + adapter2.close(); + adapter3.close(); + }).not.toThrow(); + }); + }); +}); diff --git a/server/test/medium/specs/middleware/websocket-integration.spec.ts b/server/test/medium/specs/middleware/websocket-integration.spec.ts new file mode 100644 index 0000000000..ca10fbad54 --- /dev/null +++ b/server/test/medium/specs/middleware/websocket-integration.spec.ts @@ -0,0 +1,159 @@ +import { Server } from 'socket.io'; +import { createBroadcastChannelAdapter } from 'src/middleware/broadcast-channel.adapter'; +import { EventRepository } from 'src/repositories/event.repository'; +import { LoggingRepository } from 'src/repositories/logging.repository'; +import { WebsocketRepository } from 'src/repositories/websocket.repository'; +import { automock } from 'test/utils'; +import { vi } from 'vitest'; + +describe('WebSocket Integration - serverSend with adapters', () => { + describe('BroadcastChannel adapter', () => { + it('should broadcast ConfigUpdate event through BroadcastChannel adapter', async () => { + const createMockNamespace = () => ({ + name: '/', + sockets: new Map(), + adapter: null, + server: { + encoder: { encode: vi.fn().mockReturnValue([]) }, + _opts: {}, + sockets: { sockets: new Map() }, + }, + }); + + const factory1 = createBroadcastChannelAdapter(); + const factory2 = createBroadcastChannelAdapter(); + + const namespace1 = createMockNamespace(); + const namespace2 = createMockNamespace(); + + const adapter1 = factory1(namespace1); + const adapter2 = factory2(namespace2); + + await new Promise((resolve) => setTimeout(resolve, 100)); + + const receivedMessages: any[] = []; + vi.spyOn(adapter2, 'onMessage').mockImplementation((message: any) => { + receivedMessages.push(message); + }); + + const configUpdatePayload = { + type: 5, + data: { + event: 'ConfigUpdate', + args: [{ newConfig: { ffmpeg: { crf: 23 } }, oldConfig: { ffmpeg: { crf: 20 } } }], + }, + nsp: '/', + }; + + void adapter1.doPublish(configUpdatePayload as any); + + await new Promise((resolve) => setTimeout(resolve, 100)); + + const configMessages = receivedMessages.filter((m) => m?.data?.event === 'ConfigUpdate'); + expect(configMessages.length).toBeGreaterThan(0); + + adapter1.close(); + adapter2.close(); + }); + + it('should broadcast AppRestart event through BroadcastChannel adapter', async () => { + const createMockNamespace = () => ({ + name: '/', + sockets: new Map(), + adapter: null, + server: { + encoder: { encode: vi.fn().mockReturnValue([]) }, + _opts: {}, + sockets: { sockets: new Map() }, + }, + }); + + const factory1 = createBroadcastChannelAdapter(); + const factory2 = createBroadcastChannelAdapter(); + + const namespace1 = createMockNamespace(); + const namespace2 = createMockNamespace(); + + const adapter1 = factory1(namespace1); + const adapter2 = factory2(namespace2); + + await new Promise((resolve) => setTimeout(resolve, 100)); + + const receivedMessages: any[] = []; + vi.spyOn(adapter2, 'onMessage').mockImplementation((message: any) => { + receivedMessages.push(message); + }); + + const appRestartPayload = { + type: 5, + data: { + event: 'AppRestart', + args: [{ isMaintenanceMode: true }], + }, + nsp: '/', + }; + + void adapter1.doPublish(appRestartPayload as any); + + await new Promise((resolve) => setTimeout(resolve, 100)); + + const restartMessages = receivedMessages.filter((m) => m?.data?.event === 'AppRestart'); + expect(restartMessages.length).toBeGreaterThan(0); + + adapter1.close(); + adapter2.close(); + }); + }); + + describe('WebsocketRepository with adapter', () => { + it('should call serverSideEmit when serverSend is called', () => { + const mockServer = { + serverSideEmit: vi.fn(), + on: vi.fn(), + } as unknown as Server; + + const eventRepository = automock(EventRepository, { + args: [undefined, undefined, { setContext: () => {} }], + }); + const loggingRepository = automock(LoggingRepository, { + args: [undefined, { getEnv: () => ({ noColor: false }) }], + strict: false, + }); + + const websocketRepository = new WebsocketRepository(eventRepository, loggingRepository); + (websocketRepository as any).server = mockServer; + + websocketRepository.serverSend('ConfigUpdate', { + newConfig: { ffmpeg: { crf: 23 } } as any, + oldConfig: { ffmpeg: { crf: 20 } } as any, + }); + + expect(mockServer.serverSideEmit).toHaveBeenCalledWith('ConfigUpdate', { + newConfig: { ffmpeg: { crf: 23 } }, + oldConfig: { ffmpeg: { crf: 20 } }, + }); + }); + + it('should call serverSideEmit for AppRestart event', () => { + const mockServer = { + serverSideEmit: vi.fn(), + on: vi.fn(), + } as unknown as Server; + + const eventRepository = automock(EventRepository, { + args: [undefined, undefined, { setContext: () => {} }], + }); + const loggingRepository = automock(LoggingRepository, { + args: [undefined, { getEnv: () => ({ noColor: false }) }], + strict: false, + }); + + const websocketRepository = new WebsocketRepository(eventRepository, loggingRepository); + (websocketRepository as any).server = mockServer; + + websocketRepository.serverSend('AppRestart', { isMaintenanceMode: true }); + + expect(mockServer.serverSideEmit).toHaveBeenCalledWith('AppRestart', { isMaintenanceMode: true }); + }); + }); +}); diff --git a/server/test/medium/specs/middleware/websocket.adapter.spec.ts b/server/test/medium/specs/middleware/websocket.adapter.spec.ts new file mode 100644 index 0000000000..5b7f894d26 --- /dev/null +++ b/server/test/medium/specs/middleware/websocket.adapter.spec.ts @@ -0,0 +1,70 @@ +import { INestApplication } from '@nestjs/common'; +import { IoAdapter } from '@nestjs/platform-socket.io'; +import { SocketIoAdapter } from 'src/enum'; +import { asPgPoolSsl, createWebSocketAdapter } from 'src/middleware/websocket.adapter'; +import { Mocked, vi } from 'vitest'; + +describe('asPgPoolSsl', () => { + it('should return false for undefined ssl', () => { + expect(asPgPoolSsl()).toBe(false); + }); + + it('should return false for ssl = false', () => { + expect(asPgPoolSsl(false)).toBe(false); + }); + + it('should return false for ssl = "allow"', () => { + expect(asPgPoolSsl('allow')).toBe(false); + }); + + it('should return { rejectUnauthorized: false } for ssl = true', () => { + expect(asPgPoolSsl(true)).toEqual({ rejectUnauthorized: false }); + }); + + it('should return { rejectUnauthorized: false } for ssl = "prefer"', () => { + expect(asPgPoolSsl('prefer')).toEqual({ rejectUnauthorized: false }); + }); + + it('should return { rejectUnauthorized: false } for ssl = "require"', () => { + expect(asPgPoolSsl('require')).toEqual({ rejectUnauthorized: false }); + }); + + it('should return { rejectUnauthorized: true } for ssl = "verify-full"', () => { + expect(asPgPoolSsl('verify-full')).toEqual({ rejectUnauthorized: true }); + }); + + it('should pass through object ssl config unchanged', () => { + const sslConfig = { ca: 'certificate', rejectUnauthorized: true }; + expect(asPgPoolSsl(sslConfig)).toBe(sslConfig); + }); +}); + +describe('createWebSocketAdapter', () => { + let mockApp: Mocked; + + beforeEach(() => { + vi.clearAllMocks(); + + mockApp = { + getHttpServer: vi.fn().mockReturnValue({}), + } as unknown as Mocked; + }); + + describe('BroadcastChannel adapter', () => { + it('should create BroadcastChannel adapter when configured', async () => { + const adapter = await createWebSocketAdapter(mockApp, SocketIoAdapter.BroadcastChannel); + + expect(adapter).toBeDefined(); + expect(adapter).toBeInstanceOf(IoAdapter); + }); + }); + + describe('Postgres adapter', () => { + it('should create Postgres adapter when configured', async () => { + const adapter = await createWebSocketAdapter(mockApp, SocketIoAdapter.Postgres); + + expect(adapter).toBeDefined(); + expect(adapter).toBeInstanceOf(IoAdapter); + }); + }); +}); diff --git a/server/test/repositories/config.repository.mock.ts b/server/test/repositories/config.repository.mock.ts index 62e498372e..1094828719 100644 --- a/server/test/repositories/config.repository.mock.ts +++ b/server/test/repositories/config.repository.mock.ts @@ -1,4 +1,4 @@ -import { DatabaseExtension, ImmichEnvironment, ImmichWorker, LogFormat } from 'src/enum'; +import { DatabaseExtension, ImmichEnvironment, ImmichWorker, LogFormat, SocketIoAdapter } from 'src/enum'; import { ConfigRepository, EnvData } from 'src/repositories/config.repository'; import { RepositoryInterface } from 'src/types'; import { Mocked, vitest } from 'vitest'; @@ -99,6 +99,10 @@ const envData: EnvData = { }, }, + socketIo: { + adapter: SocketIoAdapter.Postgres, + }, + noColor: false, };