From e89d174bd0e52c6628fcce30305eb5c82eb5db71 Mon Sep 17 00:00:00 2001 From: mertalev <101130780+mertalev@users.noreply.github.com> Date: Mon, 2 Feb 2026 16:39:51 -0500 Subject: [PATCH] websocket integration platform-side headers update comment consistent platform check tweak websocket handling support streaming --- .../alextran/immich/core/HttpClientManager.kt | 14 ++- .../app/alextran/immich/core/Network.g.kt | 82 ++++++++++++++++ .../alextran/immich/core/NetworkApiPlugin.kt | 11 +++ .../alextran/immich/images/RemoteImages.g.kt | 7 +- .../immich/images/RemoteImagesImpl.kt | 11 +-- mobile/ios/Runner/Core/Network.g.swift | 71 ++++++++++++++ mobile/ios/Runner/Core/NetworkApiImpl.swift | 26 +++++ .../ios/Runner/Core/URLSessionManager.swift | 82 +++++++++++++++- mobile/ios/Runner/Images/RemoteImages.g.swift | 7 +- .../ios/Runner/Images/RemoteImagesImpl.swift | 5 +- .../loaders/remote_image_request.dart | 5 +- .../repositories/network.repository.dart | 29 +++++- .../repositories/sync_api.repository.dart | 4 - .../pages/common/headers_settings.page.dart | 8 +- mobile/lib/platform/network_api.g.dart | 94 +++++++++++++++++++ mobile/lib/platform/remote_image_api.g.dart | 8 +- .../widgets/images/remote_image_provider.dart | 7 +- mobile/lib/providers/auth.provider.dart | 2 + mobile/lib/providers/websocket.provider.dart | 10 +- mobile/lib/services/api.service.dart | 10 +- mobile/lib/services/auth.service.dart | 2 +- mobile/pigeon/network_api.dart | 13 +++ mobile/pigeon/remote_image_api.dart | 9 +- mobile/pubspec.lock | 33 +++---- mobile/pubspec.yaml | 11 ++- 25 files changed, 468 insertions(+), 93 deletions(-) diff --git a/mobile/android/app/src/main/kotlin/app/alextran/immich/core/HttpClientManager.kt b/mobile/android/app/src/main/kotlin/app/alextran/immich/core/HttpClientManager.kt index ee92c2120e..f5346aba03 100644 --- a/mobile/android/app/src/main/kotlin/app/alextran/immich/core/HttpClientManager.kt +++ b/mobile/android/app/src/main/kotlin/app/alextran/immich/core/HttpClientManager.kt @@ -5,6 +5,7 @@ import app.alextran.immich.BuildConfig import okhttp3.Cache import okhttp3.ConnectionPool import okhttp3.Dispatcher +import okhttp3.Headers import okhttp3.OkHttpClient import java.io.ByteArrayInputStream import java.io.File @@ -39,6 +40,7 @@ object HttpClientManager { private val keyStore = KeyStore.getInstance("AndroidKeyStore").apply { load(null) } + var headers: Headers = Headers.headersOf("User-Agent", USER_AGENT) val isMtls: Boolean get() = keyStore.containsAlias(CERT_ALIAS) fun initialize(context: Context) { @@ -93,6 +95,12 @@ object HttpClientManager { synchronized(this) { clientChangedListeners.add(listener) } } + fun setRequestHeaders(headerMap: Map) { + val builder = Headers.Builder() + headerMap.forEach { (key, value) -> builder.add(key, value) } + headers = builder.build() + } + private fun build(cacheDir: File): OkHttpClient { val connectionPool = ConnectionPool( maxIdleConnections = KEEP_ALIVE_CONNECTIONS, @@ -109,8 +117,10 @@ object HttpClientManager { HttpsURLConnection.setDefaultSSLSocketFactory(sslContext.socketFactory) return OkHttpClient.Builder() - .addInterceptor { chain -> - chain.proceed(chain.request().newBuilder().header("User-Agent", USER_AGENT).build()) + .addInterceptor { + val builder = it.request().newBuilder() + headers.forEach { (key, value) -> builder.addHeader(key, value) } + it.proceed(builder.build()) } .connectionPool(connectionPool) .dispatcher(Dispatcher().apply { maxRequestsPerHost = MAX_REQUESTS_PER_HOST }) diff --git a/mobile/android/app/src/main/kotlin/app/alextran/immich/core/Network.g.kt b/mobile/android/app/src/main/kotlin/app/alextran/immich/core/Network.g.kt index 06e77d9b72..f9086554da 100644 --- a/mobile/android/app/src/main/kotlin/app/alextran/immich/core/Network.g.kt +++ b/mobile/android/app/src/main/kotlin/app/alextran/immich/core/Network.g.kt @@ -145,6 +145,37 @@ data class ClientCertPrompt ( override fun hashCode(): Int = toList().hashCode() } + +/** Generated class from Pigeon that represents data sent in messages. */ +data class WebSocketTaskResult ( + val taskPointer: Long, + val taskProtocol: String? = null +) + { + companion object { + fun fromList(pigeonVar_list: List): WebSocketTaskResult { + val taskPointer = pigeonVar_list[0] as Long + val taskProtocol = pigeonVar_list[1] as String? + return WebSocketTaskResult(taskPointer, taskProtocol) + } + } + fun toList(): List { + return listOf( + taskPointer, + taskProtocol, + ) + } + override fun equals(other: Any?): Boolean { + if (other !is WebSocketTaskResult) { + return false + } + if (this === other) { + return true + } + return NetworkPigeonUtils.deepEquals(toList(), other.toList()) } + + override fun hashCode(): Int = toList().hashCode() +} private open class NetworkPigeonCodec : StandardMessageCodec() { override fun readValueOfType(type: Byte, buffer: ByteBuffer): Any? { return when (type) { @@ -158,6 +189,11 @@ private open class NetworkPigeonCodec : StandardMessageCodec() { ClientCertPrompt.fromList(it) } } + 131.toByte() -> { + return (readValue(buffer) as? List)?.let { + WebSocketTaskResult.fromList(it) + } + } else -> super.readValueOfType(type, buffer) } } @@ -171,6 +207,10 @@ private open class NetworkPigeonCodec : StandardMessageCodec() { stream.write(130) writeValue(stream, value.toList()) } + is WebSocketTaskResult -> { + stream.write(131) + writeValue(stream, value.toList()) + } else -> super.writeValue(stream, value) } } @@ -183,6 +223,9 @@ interface NetworkApi { fun selectCertificate(promptText: ClientCertPrompt, callback: (Result) -> Unit) fun removeCertificate(callback: (Result) -> Unit) fun getClientPointer(): Long + /** iOS only - creates a WebSocket task and waits for connection to be established. */ + fun createWebSocketTask(url: String, protocols: List?, callback: (Result) -> Unit) + fun setRequestHeaders(headers: Map) companion object { /** The codec used by NetworkApi. */ @@ -264,6 +307,45 @@ interface NetworkApi { channel.setMessageHandler(null) } } + run { + val channel = BasicMessageChannel(binaryMessenger, "dev.flutter.pigeon.immich_mobile.NetworkApi.createWebSocketTask$separatedMessageChannelSuffix", codec) + if (api != null) { + channel.setMessageHandler { message, reply -> + val args = message as List + val urlArg = args[0] as String + val protocolsArg = args[1] as List? + api.createWebSocketTask(urlArg, protocolsArg) { result: Result -> + val error = result.exceptionOrNull() + if (error != null) { + reply.reply(NetworkPigeonUtils.wrapError(error)) + } else { + val data = result.getOrNull() + reply.reply(NetworkPigeonUtils.wrapResult(data)) + } + } + } + } else { + channel.setMessageHandler(null) + } + } + run { + val channel = BasicMessageChannel(binaryMessenger, "dev.flutter.pigeon.immich_mobile.NetworkApi.setRequestHeaders$separatedMessageChannelSuffix", codec) + if (api != null) { + channel.setMessageHandler { message, reply -> + val args = message as List + val headersArg = args[0] as Map + val wrapped: List = try { + api.setRequestHeaders(headersArg) + listOf(null) + } catch (exception: Throwable) { + NetworkPigeonUtils.wrapError(exception) + } + reply.reply(wrapped) + } + } else { + channel.setMessageHandler(null) + } + } } } } diff --git a/mobile/android/app/src/main/kotlin/app/alextran/immich/core/NetworkApiPlugin.kt b/mobile/android/app/src/main/kotlin/app/alextran/immich/core/NetworkApiPlugin.kt index ac71bc5ef6..861e07155b 100644 --- a/mobile/android/app/src/main/kotlin/app/alextran/immich/core/NetworkApiPlugin.kt +++ b/mobile/android/app/src/main/kotlin/app/alextran/immich/core/NetworkApiPlugin.kt @@ -104,6 +104,17 @@ private class NetworkApiImpl(private val context: Context) : NetworkApi { return NativeBuffer.createGlobalRef(client) } + // only used on iOS + override fun createWebSocketTask( + url: String, + protocols: List?, + callback: (Result) -> Unit + ) {} + + override fun setRequestHeaders(headers: Map) { + HttpClientManager.setRequestHeaders(headers) + } + private fun handlePickedFile(uri: Uri) { val callback = pendingCallback ?: return pendingCallback = null diff --git a/mobile/android/app/src/main/kotlin/app/alextran/immich/images/RemoteImages.g.kt b/mobile/android/app/src/main/kotlin/app/alextran/immich/images/RemoteImages.g.kt index 0e3cf19657..6d6864c8ba 100644 --- a/mobile/android/app/src/main/kotlin/app/alextran/immich/images/RemoteImages.g.kt +++ b/mobile/android/app/src/main/kotlin/app/alextran/immich/images/RemoteImages.g.kt @@ -47,7 +47,7 @@ private open class RemoteImagesPigeonCodec : StandardMessageCodec() { /** Generated interface from Pigeon that represents a handler of messages from Flutter. */ interface RemoteImageApi { - fun requestImage(url: String, headers: Map, requestId: Long, callback: (Result?>) -> Unit) + fun requestImage(url: String, requestId: Long, callback: (Result?>) -> Unit) fun cancelRequest(requestId: Long) fun clearCache(callback: (Result) -> Unit) @@ -66,9 +66,8 @@ interface RemoteImageApi { channel.setMessageHandler { message, reply -> val args = message as List val urlArg = args[0] as String - val headersArg = args[1] as Map - val requestIdArg = args[2] as Long - api.requestImage(urlArg, headersArg, requestIdArg) { result: Result?> -> + val requestIdArg = args[1] as Long + api.requestImage(urlArg, requestIdArg) { result: Result?> -> val error = result.exceptionOrNull() if (error != null) { reply.reply(RemoteImagesPigeonUtils.wrapError(error)) diff --git a/mobile/android/app/src/main/kotlin/app/alextran/immich/images/RemoteImagesImpl.kt b/mobile/android/app/src/main/kotlin/app/alextran/immich/images/RemoteImagesImpl.kt index 04a181cd6e..3dfdaaf6cc 100644 --- a/mobile/android/app/src/main/kotlin/app/alextran/immich/images/RemoteImagesImpl.kt +++ b/mobile/android/app/src/main/kotlin/app/alextran/immich/images/RemoteImagesImpl.kt @@ -49,7 +49,6 @@ class RemoteImagesImpl(context: Context) : RemoteImageApi { override fun requestImage( url: String, - headers: Map, requestId: Long, callback: (Result?>) -> Unit ) { @@ -58,7 +57,6 @@ class RemoteImagesImpl(context: Context) : RemoteImageApi { ImageFetcherManager.fetch( url, - headers, signal, onSuccess = { buffer -> requestMap.remove(requestId) @@ -119,12 +117,11 @@ private object ImageFetcherManager { fun fetch( url: String, - headers: Map, signal: CancellationSignal, onSuccess: (NativeByteBuffer) -> Unit, onFailure: (Exception) -> Unit, ) { - fetcher.fetch(url, headers, signal, onSuccess, onFailure) + fetcher.fetch(url, signal, onSuccess, onFailure) } fun clearCache(onCleared: (Result) -> Unit) { @@ -151,7 +148,6 @@ private object ImageFetcherManager { private sealed interface ImageFetcher { fun fetch( url: String, - headers: Map, signal: CancellationSignal, onSuccess: (NativeByteBuffer) -> Unit, onFailure: (Exception) -> Unit, @@ -178,7 +174,6 @@ private class CronetImageFetcher(context: Context, cacheDir: File) : ImageFetche override fun fetch( url: String, - headers: Map, signal: CancellationSignal, onSuccess: (NativeByteBuffer) -> Unit, onFailure: (Exception) -> Unit, @@ -193,7 +188,7 @@ private class CronetImageFetcher(context: Context, cacheDir: File) : ImageFetche val callback = FetchCallback(onSuccess, onFailure, ::onComplete) val requestBuilder = engine.newUrlRequestBuilder(url, callback, executor) - headers.forEach { (key, value) -> requestBuilder.addHeader(key, value) } + HttpClientManager.headers.forEach { (key, value) -> requestBuilder.addHeader(key, value) } val request = requestBuilder.build() signal.setOnCancelListener(request::cancel) request.start() @@ -390,7 +385,6 @@ private class OkHttpImageFetcher private constructor( override fun fetch( url: String, - headers: Map, signal: CancellationSignal, onSuccess: (NativeByteBuffer) -> Unit, onFailure: (Exception) -> Unit, @@ -403,7 +397,6 @@ private class OkHttpImageFetcher private constructor( } val requestBuilder = Request.Builder().url(url) - headers.forEach { (key, value) -> requestBuilder.addHeader(key, value) } val call = client.newCall(requestBuilder.build()) signal.setOnCancelListener(call::cancel) diff --git a/mobile/ios/Runner/Core/Network.g.swift b/mobile/ios/Runner/Core/Network.g.swift index 9dd81ee92b..3e6a435510 100644 --- a/mobile/ios/Runner/Core/Network.g.swift +++ b/mobile/ios/Runner/Core/Network.g.swift @@ -176,6 +176,35 @@ struct ClientCertPrompt: Hashable { } } +/// Generated class from Pigeon that represents data sent in messages. +struct WebSocketTaskResult: Hashable { + var taskPointer: Int64 + var taskProtocol: String? = nil + + + // swift-format-ignore: AlwaysUseLowerCamelCase + static func fromList(_ pigeonVar_list: [Any?]) -> WebSocketTaskResult? { + let taskPointer = pigeonVar_list[0] as! Int64 + let taskProtocol: String? = nilOrValue(pigeonVar_list[1]) + + return WebSocketTaskResult( + taskPointer: taskPointer, + taskProtocol: taskProtocol + ) + } + func toList() -> [Any?] { + return [ + taskPointer, + taskProtocol, + ] + } + static func == (lhs: WebSocketTaskResult, rhs: WebSocketTaskResult) -> Bool { + return deepEqualsNetwork(lhs.toList(), rhs.toList()) } + func hash(into hasher: inout Hasher) { + deepHashNetwork(value: toList(), hasher: &hasher) + } +} + private class NetworkPigeonCodecReader: FlutterStandardReader { override func readValue(ofType type: UInt8) -> Any? { switch type { @@ -183,6 +212,8 @@ private class NetworkPigeonCodecReader: FlutterStandardReader { return ClientCertData.fromList(self.readValue() as! [Any?]) case 130: return ClientCertPrompt.fromList(self.readValue() as! [Any?]) + case 131: + return WebSocketTaskResult.fromList(self.readValue() as! [Any?]) default: return super.readValue(ofType: type) } @@ -197,6 +228,9 @@ private class NetworkPigeonCodecWriter: FlutterStandardWriter { } else if let value = value as? ClientCertPrompt { super.writeByte(130) super.writeValue(value.toList()) + } else if let value = value as? WebSocketTaskResult { + super.writeByte(131) + super.writeValue(value.toList()) } else { super.writeValue(value) } @@ -224,6 +258,9 @@ protocol NetworkApi { func selectCertificate(promptText: ClientCertPrompt, completion: @escaping (Result) -> Void) func removeCertificate(completion: @escaping (Result) -> Void) func getClientPointer() throws -> Int64 + /// iOS only - creates a WebSocket task and waits for connection to be established. + func createWebSocketTask(url: String, protocols: [String]?, completion: @escaping (Result) -> Void) + func setRequestHeaders(headers: [String: String]) throws } /// Generated setup class from Pigeon to handle messages through the `binaryMessenger`. @@ -294,5 +331,39 @@ class NetworkApiSetup { } else { getClientPointerChannel.setMessageHandler(nil) } + /// iOS only - creates a WebSocket task and waits for connection to be established. + let createWebSocketTaskChannel = FlutterBasicMessageChannel(name: "dev.flutter.pigeon.immich_mobile.NetworkApi.createWebSocketTask\(channelSuffix)", binaryMessenger: binaryMessenger, codec: codec) + if let api = api { + createWebSocketTaskChannel.setMessageHandler { message, reply in + let args = message as! [Any?] + let urlArg = args[0] as! String + let protocolsArg: [String]? = nilOrValue(args[1]) + api.createWebSocketTask(url: urlArg, protocols: protocolsArg) { result in + switch result { + case .success(let res): + reply(wrapResult(res)) + case .failure(let error): + reply(wrapError(error)) + } + } + } + } else { + createWebSocketTaskChannel.setMessageHandler(nil) + } + let setRequestHeadersChannel = FlutterBasicMessageChannel(name: "dev.flutter.pigeon.immich_mobile.NetworkApi.setRequestHeaders\(channelSuffix)", binaryMessenger: binaryMessenger, codec: codec) + if let api = api { + setRequestHeadersChannel.setMessageHandler { message, reply in + let args = message as! [Any?] + let headersArg = args[0] as! [String: String] + do { + try api.setRequestHeaders(headers: headersArg) + reply(wrapResult(nil)) + } catch { + reply(wrapError(error)) + } + } + } else { + setRequestHeadersChannel.setMessageHandler(nil) + } } } diff --git a/mobile/ios/Runner/Core/NetworkApiImpl.swift b/mobile/ios/Runner/Core/NetworkApiImpl.swift index b9faa89fce..ac314d186a 100644 --- a/mobile/ios/Runner/Core/NetworkApiImpl.swift +++ b/mobile/ios/Runner/Core/NetworkApiImpl.swift @@ -45,6 +45,32 @@ class NetworkApiImpl: NetworkApi { let pointer = URLSessionManager.shared.sessionPointer return Int64(Int(bitPattern: pointer)) } + + func createWebSocketTask( + url: String, + protocols: [String]?, + completion: @escaping (Result) -> Void + ) { + guard let wsUrl = URL(string: url) else { + completion(.failure(WebSocketError.invalidURL(url))) + return + } + + URLSessionManager.shared.createWebSocketTask(url: wsUrl, protocols: protocols) { result in + switch result { + case .success(let (task, proto)): + let pointer = Unmanaged.passUnretained(task).toOpaque() + let address = Int64(Int(bitPattern: pointer)) + completion(.success(WebSocketTaskResult(taskPointer: address, taskProtocol: proto))) + case .failure(let error): + completion(.failure(error)) + } + } + } + + func setRequestHeaders(headers: [String : String]) throws { + URLSessionManager.shared.session.configuration.httpAdditionalHeaders = headers + } } private class CertImporter: NSObject, UIDocumentPickerDelegate { diff --git a/mobile/ios/Runner/Core/URLSessionManager.swift b/mobile/ios/Runner/Core/URLSessionManager.swift index 1e127b312b..06bd42bdb9 100644 --- a/mobile/ios/Runner/Core/URLSessionManager.swift +++ b/mobile/ios/Runner/Core/URLSessionManager.swift @@ -7,6 +7,7 @@ class URLSessionManager: NSObject { static let shared = URLSessionManager() let session: URLSession + let delegate: URLSessionManagerDelegate private let configuration = { let config = URLSessionConfiguration.default @@ -36,12 +37,89 @@ class URLSessionManager: NSObject { } private override init() { - session = URLSession(configuration: configuration, delegate: URLSessionManagerDelegate(), delegateQueue: nil) + delegate = URLSessionManagerDelegate() + session = URLSession(configuration: configuration, delegate: delegate, delegateQueue: nil) super.init() } + + /// Creates a WebSocket task and waits for connection to be established. + func createWebSocketTask( + url: URL, + protocols: [String]?, + completion: @escaping (Result<(URLSessionWebSocketTask, String?), Error>) -> Void + ) { + let task: URLSessionWebSocketTask + if let protocols = protocols, !protocols.isEmpty { + task = session.webSocketTask(with: url, protocols: protocols) + } else { + task = session.webSocketTask(with: url) + } + + delegate.registerWebSocketTask(task) { result in + completion(result) + } + task.resume() + } } -class URLSessionManagerDelegate: NSObject, URLSessionTaskDelegate { +enum WebSocketError: Error { + case connectionFailed(String) + case invalidURL(String) +} + +class URLSessionManagerDelegate: NSObject, URLSessionTaskDelegate, URLSessionWebSocketDelegate { + private var webSocketCompletions: [Int: (Result<(URLSessionWebSocketTask, String?), Error>) -> Void] = [:] + private let lock = { + let lock = UnsafeMutablePointer.allocate(capacity: 1) + lock.initialize(to: os_unfair_lock()) + return lock + }() + + func registerWebSocketTask( + _ task: URLSessionWebSocketTask, + completion: @escaping (Result<(URLSessionWebSocketTask, String?), Error>) -> Void + ) { + os_unfair_lock_lock(lock) + webSocketCompletions[task.taskIdentifier] = completion + os_unfair_lock_unlock(lock) + } + + func urlSession( + _ session: URLSession, + webSocketTask: URLSessionWebSocketTask, + didOpenWithProtocol protocol: String? + ) { + os_unfair_lock_lock(lock) + let completion = webSocketCompletions.removeValue(forKey: webSocketTask.taskIdentifier) + os_unfair_lock_unlock(lock) + completion?(.success((webSocketTask, `protocol`))) + } + + func urlSession( + _ session: URLSession, + webSocketTask: URLSessionWebSocketTask, + didCloseWith closeCode: URLSessionWebSocketTask.CloseCode, + reason: Data? + ) { + // Close events are handled by CupertinoWebSocket via task.closeCode/closeReason + } + + func urlSession( + _ session: URLSession, + task: URLSessionTask, + didCompleteWithError error: Error? + ) { + guard let webSocketTask = task as? URLSessionWebSocketTask else { return } + + os_unfair_lock_lock(lock) + let completion = webSocketCompletions.removeValue(forKey: webSocketTask.taskIdentifier) + os_unfair_lock_unlock(lock) + + if let error = error { + completion?(.failure(error)) + } + } + func urlSession( _ session: URLSession, didReceive challenge: URLAuthenticationChallenge, diff --git a/mobile/ios/Runner/Images/RemoteImages.g.swift b/mobile/ios/Runner/Images/RemoteImages.g.swift index fc83b09d4b..7781e6f02f 100644 --- a/mobile/ios/Runner/Images/RemoteImages.g.swift +++ b/mobile/ios/Runner/Images/RemoteImages.g.swift @@ -70,7 +70,7 @@ class RemoteImagesPigeonCodec: FlutterStandardMessageCodec, @unchecked Sendable /// Generated protocol from Pigeon that represents a handler of messages from Flutter. protocol RemoteImageApi { - func requestImage(url: String, headers: [String: String], requestId: Int64, completion: @escaping (Result<[String: Int64]?, Error>) -> Void) + func requestImage(url: String, requestId: Int64, completion: @escaping (Result<[String: Int64]?, Error>) -> Void) func cancelRequest(requestId: Int64) throws func clearCache(completion: @escaping (Result) -> Void) } @@ -86,9 +86,8 @@ class RemoteImageApiSetup { requestImageChannel.setMessageHandler { message, reply in let args = message as! [Any?] let urlArg = args[0] as! String - let headersArg = args[1] as! [String: String] - let requestIdArg = args[2] as! Int64 - api.requestImage(url: urlArg, headers: headersArg, requestId: requestIdArg) { result in + let requestIdArg = args[1] as! Int64 + api.requestImage(url: urlArg, requestId: requestIdArg) { result in switch result { case .success(let res): reply(wrapResult(res)) diff --git a/mobile/ios/Runner/Images/RemoteImagesImpl.swift b/mobile/ios/Runner/Images/RemoteImagesImpl.swift index 56e8938521..b886f35905 100644 --- a/mobile/ios/Runner/Images/RemoteImagesImpl.swift +++ b/mobile/ios/Runner/Images/RemoteImagesImpl.swift @@ -33,12 +33,9 @@ class RemoteImageApiImpl: NSObject, RemoteImageApi { kCGImageSourceCreateThumbnailFromImageAlways: true ] as CFDictionary - func requestImage(url: String, headers: [String : String], requestId: Int64, completion: @escaping (Result<[String : Int64]?, any Error>) -> Void) { + func requestImage(url: String, requestId: Int64, completion: @escaping (Result<[String : Int64]?, any Error>) -> Void) { var urlRequest = URLRequest(url: URL(string: url)!) urlRequest.cachePolicy = .returnCacheDataElseLoad - for (key, value) in headers { - urlRequest.setValue(value, forHTTPHeaderField: key) - } let task = URLSessionManager.shared.session.dataTask(with: urlRequest) { data, response, error in Self.handleCompletion(requestId: requestId, data: data, response: response, error: error) diff --git a/mobile/lib/infrastructure/loaders/remote_image_request.dart b/mobile/lib/infrastructure/loaders/remote_image_request.dart index 2da70c3ae1..2544aa088b 100644 --- a/mobile/lib/infrastructure/loaders/remote_image_request.dart +++ b/mobile/lib/infrastructure/loaders/remote_image_request.dart @@ -2,9 +2,8 @@ part of 'image_request.dart'; class RemoteImageRequest extends ImageRequest { final String uri; - final Map headers; - RemoteImageRequest({required this.uri, required this.headers}); + RemoteImageRequest({required this.uri}); @override Future load(ImageDecoderCallback decode, {double scale = 1.0}) async { @@ -12,7 +11,7 @@ class RemoteImageRequest extends ImageRequest { return null; } - final info = await remoteImageApi.requestImage(uri, headers: headers, requestId: requestId); + final info = await remoteImageApi.requestImage(uri, requestId: requestId); final frame = switch (info) { {'pointer': int pointer, 'length': int length} => await _fromEncodedPlatformImage(pointer, length), {'pointer': int pointer, 'width': int width, 'height': int height, 'rowBytes': int rowBytes} => diff --git a/mobile/lib/infrastructure/repositories/network.repository.dart b/mobile/lib/infrastructure/repositories/network.repository.dart index 9029479af8..60c9363c28 100644 --- a/mobile/lib/infrastructure/repositories/network.repository.dart +++ b/mobile/lib/infrastructure/repositories/network.repository.dart @@ -6,18 +6,29 @@ import 'package:http/http.dart' as http; import 'package:immich_mobile/providers/infrastructure/platform.provider.dart'; import 'package:logging/logging.dart'; import 'package:ok_http/ok_http.dart'; +import 'package:web_socket/web_socket.dart'; class NetworkRepository { static final _log = Logger('NetworkRepository'); static http.Client? _client; + static late int _clientPointer; static Future init() async { - final pointer = await networkApi.getClientPointer(); + _clientPointer = await networkApi.getClientPointer(); _client?.close(); if (Platform.isIOS) { - _client = _createIOSClient(pointer); + _client = _createIOSClient(_clientPointer); } else { - _client = _createAndroidClient(pointer); + _client = _createAndroidClient(_clientPointer); + } + } + + // ignore: avoid-unused-parameters + static Future createWebSocket(Uri uri, {Map? headers, Iterable? protocols}) { + if (Platform.isIOS) { + return _createIOSWebSocket(uri, protocols: protocols); + } else { + return _createAndroidWebSocket(uri, protocols: protocols); } } @@ -43,4 +54,16 @@ class NetworkRepository { _log.info('Using shared native OkHttpClient'); return OkHttpClient.fromJniGlobalRef(pointer); } + + static Future _createIOSWebSocket(Uri uri, {Iterable? protocols}) async { + final result = await networkApi.createWebSocketTask(uri.toString(), protocols?.toList()); + final pointer = Pointer.fromAddress(result.taskPointer); + final task = URLSessionWebSocketTask.fromRawPointer(pointer.cast()); + return CupertinoWebSocket.fromConnectedTask(task, protocol: result.taskProtocol ?? ''); + } + + static Future _createAndroidWebSocket(Uri uri, {Iterable? protocols}) { + final pointer = Pointer.fromAddress(_clientPointer); + return OkHttpWebSocket.connectFromJniGlobalRef(pointer, uri, protocols: protocols); + } } diff --git a/mobile/lib/infrastructure/repositories/sync_api.repository.dart b/mobile/lib/infrastructure/repositories/sync_api.repository.dart index 90ff1b1c2e..b2e9cde156 100644 --- a/mobile/lib/infrastructure/repositories/sync_api.repository.dart +++ b/mobile/lib/infrastructure/repositories/sync_api.repository.dart @@ -36,10 +36,6 @@ class SyncApiRepository { final headers = {'Content-Type': 'application/json', 'Accept': 'application/jsonlines+json'}; - final headerParams = {}; - await _api.applyToParams([], headerParams); - headers.addAll(headerParams); - final shouldReset = Store.get(StoreKey.shouldResetSync, false); final request = http.Request('POST', Uri.parse(endpoint)); request.headers.addAll(headers); diff --git a/mobile/lib/pages/common/headers_settings.page.dart b/mobile/lib/pages/common/headers_settings.page.dart index 1cfab355d6..adfcbff1aa 100644 --- a/mobile/lib/pages/common/headers_settings.page.dart +++ b/mobile/lib/pages/common/headers_settings.page.dart @@ -8,6 +8,8 @@ import 'package:hooks_riverpod/hooks_riverpod.dart'; import 'package:immich_mobile/domain/models/store.model.dart'; import 'package:immich_mobile/entities/store.entity.dart'; import 'package:immich_mobile/generated/intl_keys.g.dart'; +import 'package:immich_mobile/providers/infrastructure/platform.provider.dart'; +import 'package:immich_mobile/services/api.service.dart'; class SettingsHeader { String key = ""; @@ -20,7 +22,6 @@ class HeaderSettingsPage extends HookConsumerWidget { @override Widget build(BuildContext context, WidgetRef ref) { - // final apiService = ref.watch(apiServiceProvider); final headers = useState>([]); final setInitialHeaders = useState(false); @@ -87,7 +88,7 @@ class HeaderSettingsPage extends HookConsumerWidget { ); } - saveHeaders(List headers) { + saveHeaders(List headers) async { final headersMap = {}; for (var header in headers) { final key = header.key.trim(); @@ -98,7 +99,8 @@ class HeaderSettingsPage extends HookConsumerWidget { } var encoded = jsonEncode(headersMap); - Store.put(StoreKey.customHeaders, encoded); + await Store.put(StoreKey.customHeaders, encoded); + await networkApi.setRequestHeaders(ApiService.getRequestHeaders()); } } diff --git a/mobile/lib/platform/network_api.g.dart b/mobile/lib/platform/network_api.g.dart index 92b454c591..3db03a45a3 100644 --- a/mobile/lib/platform/network_api.g.dart +++ b/mobile/lib/platform/network_api.g.dart @@ -112,6 +112,43 @@ class ClientCertPrompt { int get hashCode => Object.hashAll(_toList()); } +class WebSocketTaskResult { + WebSocketTaskResult({required this.taskPointer, this.taskProtocol}); + + int taskPointer; + + String? taskProtocol; + + List _toList() { + return [taskPointer, taskProtocol]; + } + + Object encode() { + return _toList(); + } + + static WebSocketTaskResult decode(Object result) { + result as List; + return WebSocketTaskResult(taskPointer: result[0]! as int, taskProtocol: result[1] as String?); + } + + @override + // ignore: avoid_equals_and_hash_code_on_mutable_classes + bool operator ==(Object other) { + if (other is! WebSocketTaskResult || other.runtimeType != runtimeType) { + return false; + } + if (identical(this, other)) { + return true; + } + return _deepEquals(encode(), other.encode()); + } + + @override + // ignore: avoid_equals_and_hash_code_on_mutable_classes + int get hashCode => Object.hashAll(_toList()); +} + class _PigeonCodec extends StandardMessageCodec { const _PigeonCodec(); @override @@ -125,6 +162,9 @@ class _PigeonCodec extends StandardMessageCodec { } else if (value is ClientCertPrompt) { buffer.putUint8(130); writeValue(buffer, value.encode()); + } else if (value is WebSocketTaskResult) { + buffer.putUint8(131); + writeValue(buffer, value.encode()); } else { super.writeValue(buffer, value); } @@ -137,6 +177,8 @@ class _PigeonCodec extends StandardMessageCodec { return ClientCertData.decode(readValue(buffer)!); case 130: return ClientCertPrompt.decode(readValue(buffer)!); + case 131: + return WebSocketTaskResult.decode(readValue(buffer)!); default: return super.readValueOfType(type, buffer); } @@ -257,4 +299,56 @@ class NetworkApi { return (pigeonVar_replyList[0] as int?)!; } } + + /// iOS only - creates a WebSocket task and waits for connection to be established. + Future createWebSocketTask(String url, List? protocols) async { + final String pigeonVar_channelName = + 'dev.flutter.pigeon.immich_mobile.NetworkApi.createWebSocketTask$pigeonVar_messageChannelSuffix'; + final BasicMessageChannel pigeonVar_channel = BasicMessageChannel( + pigeonVar_channelName, + pigeonChannelCodec, + binaryMessenger: pigeonVar_binaryMessenger, + ); + final Future pigeonVar_sendFuture = pigeonVar_channel.send([url, protocols]); + final List? pigeonVar_replyList = await pigeonVar_sendFuture as List?; + if (pigeonVar_replyList == null) { + throw _createConnectionError(pigeonVar_channelName); + } else if (pigeonVar_replyList.length > 1) { + throw PlatformException( + code: pigeonVar_replyList[0]! as String, + message: pigeonVar_replyList[1] as String?, + details: pigeonVar_replyList[2], + ); + } else if (pigeonVar_replyList[0] == null) { + throw PlatformException( + code: 'null-error', + message: 'Host platform returned null value for non-null return value.', + ); + } else { + return (pigeonVar_replyList[0] as WebSocketTaskResult?)!; + } + } + + Future setRequestHeaders(Map headers) async { + final String pigeonVar_channelName = + 'dev.flutter.pigeon.immich_mobile.NetworkApi.setRequestHeaders$pigeonVar_messageChannelSuffix'; + final BasicMessageChannel pigeonVar_channel = BasicMessageChannel( + pigeonVar_channelName, + pigeonChannelCodec, + binaryMessenger: pigeonVar_binaryMessenger, + ); + final Future pigeonVar_sendFuture = pigeonVar_channel.send([headers]); + final List? pigeonVar_replyList = await pigeonVar_sendFuture as List?; + if (pigeonVar_replyList == null) { + throw _createConnectionError(pigeonVar_channelName); + } else if (pigeonVar_replyList.length > 1) { + throw PlatformException( + code: pigeonVar_replyList[0]! as String, + message: pigeonVar_replyList[1] as String?, + details: pigeonVar_replyList[2], + ); + } else { + return; + } + } } diff --git a/mobile/lib/platform/remote_image_api.g.dart b/mobile/lib/platform/remote_image_api.g.dart index 410db03ece..f393a89b1c 100644 --- a/mobile/lib/platform/remote_image_api.g.dart +++ b/mobile/lib/platform/remote_image_api.g.dart @@ -49,11 +49,7 @@ class RemoteImageApi { final String pigeonVar_messageChannelSuffix; - Future?> requestImage( - String url, { - required Map headers, - required int requestId, - }) async { + Future?> requestImage(String url, {required int requestId}) async { final String pigeonVar_channelName = 'dev.flutter.pigeon.immich_mobile.RemoteImageApi.requestImage$pigeonVar_messageChannelSuffix'; final BasicMessageChannel pigeonVar_channel = BasicMessageChannel( @@ -61,7 +57,7 @@ class RemoteImageApi { pigeonChannelCodec, binaryMessenger: pigeonVar_binaryMessenger, ); - final Future pigeonVar_sendFuture = pigeonVar_channel.send([url, headers, requestId]); + final Future pigeonVar_sendFuture = pigeonVar_channel.send([url, requestId]); final List? pigeonVar_replyList = await pigeonVar_sendFuture as List?; if (pigeonVar_replyList == null) { throw _createConnectionError(pigeonVar_channelName); diff --git a/mobile/lib/presentation/widgets/images/remote_image_provider.dart b/mobile/lib/presentation/widgets/images/remote_image_provider.dart index 6cb68c1442..8f4358ed68 100644 --- a/mobile/lib/presentation/widgets/images/remote_image_provider.dart +++ b/mobile/lib/presentation/widgets/images/remote_image_provider.dart @@ -6,7 +6,6 @@ import 'package:immich_mobile/domain/services/setting.service.dart'; import 'package:immich_mobile/infrastructure/loaders/image_request.dart'; import 'package:immich_mobile/presentation/widgets/images/image_provider.dart'; import 'package:immich_mobile/presentation/widgets/images/one_frame_multi_image_stream_completer.dart'; -import 'package:immich_mobile/services/api.service.dart'; import 'package:immich_mobile/utils/image_url_builder.dart'; import 'package:openapi/api.dart'; @@ -37,7 +36,7 @@ class RemoteImageProvider extends CancellableImageProvider } Stream _codec(RemoteImageProvider key, ImageDecoderCallback decode) { - final request = this.request = RemoteImageRequest(uri: key.url, headers: ApiService.getRequestHeaders()); + final request = this.request = RemoteImageRequest(uri: key.url); return loadRequest(request, decode); } @@ -88,10 +87,8 @@ class RemoteFullImageProvider extends CancellableImageProvider { Future saveAuthInfo({required String accessToken}) async { await _apiService.setAccessToken(accessToken); + await networkApi.setRequestHeaders(ApiService.getRequestHeaders()); final serverEndpoint = Store.get(StoreKey.serverEndpoint); final customHeaders = Store.tryGet(StoreKey.customHeaders); diff --git a/mobile/lib/providers/websocket.provider.dart b/mobile/lib/providers/websocket.provider.dart index f9473ce440..c31084b0ff 100644 --- a/mobile/lib/providers/websocket.provider.dart +++ b/mobile/lib/providers/websocket.provider.dart @@ -1,18 +1,17 @@ import 'dart:async'; -import 'dart:convert'; import 'package:collection/collection.dart'; import 'package:hooks_riverpod/hooks_riverpod.dart'; import 'package:immich_mobile/domain/models/store.model.dart'; import 'package:immich_mobile/entities/asset.entity.dart'; import 'package:immich_mobile/entities/store.entity.dart'; +import 'package:immich_mobile/infrastructure/repositories/network.repository.dart'; import 'package:immich_mobile/models/server_info/server_version.model.dart'; import 'package:immich_mobile/providers/asset.provider.dart'; import 'package:immich_mobile/providers/auth.provider.dart'; import 'package:immich_mobile/providers/background_sync.provider.dart'; import 'package:immich_mobile/providers/db.provider.dart'; import 'package:immich_mobile/providers/server_info.provider.dart'; -import 'package:immich_mobile/services/api.service.dart'; import 'package:immich_mobile/services/sync.service.dart'; import 'package:immich_mobile/utils/debounce.dart'; import 'package:immich_mobile/utils/debug_print.dart'; @@ -99,11 +98,6 @@ class WebsocketNotifier extends StateNotifier { if (authenticationState.isAuthenticated) { try { final endpoint = Uri.parse(Store.get(StoreKey.serverEndpoint)); - final headers = ApiService.getRequestHeaders(); - if (endpoint.userInfo.isNotEmpty) { - headers["Authorization"] = "Basic ${base64.encode(utf8.encode(endpoint.userInfo))}"; - } - dPrint(() => "Attempting to connect to websocket"); // Configure socket transports must be specified Socket socket = io( @@ -111,11 +105,11 @@ class WebsocketNotifier extends StateNotifier { OptionBuilder() .setPath("${endpoint.path}/socket.io") .setTransports(['websocket']) + .setWebSocketConnector(NetworkRepository.createWebSocket) .enableReconnection() .enableForceNew() .enableForceNewConnection() .enableAutoConnect() - .setExtraHeaders(headers) .build(), ); diff --git a/mobile/lib/services/api.service.dart b/mobile/lib/services/api.service.dart index 2c860e68c1..7b1cc0fa44 100644 --- a/mobile/lib/services/api.service.dart +++ b/mobile/lib/services/api.service.dart @@ -129,11 +129,8 @@ class ApiService implements Authentication { Future _getWellKnownEndpoint(String baseUrl) async { try { - var headers = {"Accept": "application/json"}; - headers.addAll(getRequestHeaders()); - final res = await NetworkRepository.client - .get(Uri.parse("$baseUrl/.well-known/immich"), headers: headers) + .get(Uri.parse("$baseUrl/.well-known/immich")) .timeout(const Duration(seconds: 5)); if (res.statusCode == 200) { @@ -197,10 +194,7 @@ class ApiService implements Authentication { @override Future applyToParams(List queryParams, Map headerParams) { - return Future(() { - var headers = ApiService.getRequestHeaders(); - headerParams.addAll(headers); - }); + return Future.value(); } ApiClient get apiClient => _apiClient; diff --git a/mobile/lib/services/auth.service.dart b/mobile/lib/services/auth.service.dart index 0ed9328d3e..c5f3fa6a4a 100644 --- a/mobile/lib/services/auth.service.dart +++ b/mobile/lib/services/auth.service.dart @@ -68,7 +68,7 @@ class AuthService { try { final uri = Uri.parse('$url/users/me'); - final response = await NetworkRepository.client.get(uri, headers: ApiService.getRequestHeaders()); + final response = await NetworkRepository.client.get(uri); if (response.statusCode == 200) { isValid = true; } diff --git a/mobile/pigeon/network_api.dart b/mobile/pigeon/network_api.dart index b9326e2a60..ed20920490 100644 --- a/mobile/pigeon/network_api.dart +++ b/mobile/pigeon/network_api.dart @@ -16,6 +16,13 @@ class ClientCertPrompt { ClientCertPrompt(this.title, this.message, this.cancel, this.confirm); } +class WebSocketTaskResult { + int taskPointer; + String? taskProtocol; + + WebSocketTaskResult(this.taskPointer, this.taskProtocol); +} + @ConfigurePigeon( PigeonOptions( dartOut: 'lib/platform/network_api.g.dart', @@ -40,4 +47,10 @@ abstract class NetworkApi { void removeCertificate(); int getClientPointer(); + + /// iOS only - creates a WebSocket task and waits for connection to be established. + @async + WebSocketTaskResult createWebSocketTask(String url, List? protocols); + + void setRequestHeaders(Map headers); } diff --git a/mobile/pigeon/remote_image_api.dart b/mobile/pigeon/remote_image_api.dart index 749deb828e..50f30456f1 100644 --- a/mobile/pigeon/remote_image_api.dart +++ b/mobile/pigeon/remote_image_api.dart @@ -5,8 +5,7 @@ import 'package:pigeon/pigeon.dart'; dartOut: 'lib/platform/remote_image_api.g.dart', swiftOut: 'ios/Runner/Images/RemoteImages.g.swift', swiftOptions: SwiftOptions(includeErrorClass: false), - kotlinOut: - 'android/app/src/main/kotlin/app/alextran/immich/images/RemoteImages.g.kt', + kotlinOut: 'android/app/src/main/kotlin/app/alextran/immich/images/RemoteImages.g.kt', kotlinOptions: KotlinOptions(package: 'app.alextran.immich.images', includeErrorClass: false), dartOptions: DartOptions(), dartPackageName: 'immich_mobile', @@ -15,11 +14,7 @@ import 'package:pigeon/pigeon.dart'; @HostApi() abstract class RemoteImageApi { @async - Map? requestImage( - String url, { - required Map headers, - required int requestId, - }); + Map? requestImage(String url, {required int requestId}); void cancelRequest(int requestId); diff --git a/mobile/pubspec.lock b/mobile/pubspec.lock index 99296a8ee3..00893823cd 100644 --- a/mobile/pubspec.lock +++ b/mobile/pubspec.lock @@ -333,8 +333,8 @@ packages: dependency: "direct main" description: path: "pkgs/cupertino_http" - ref: "114b2807bdeee641457b5703f411318d722b67b5" - resolved-ref: "114b2807bdeee641457b5703f411318d722b67b5" + ref: "6a28337a5d759bee3d198992e79d9b5c1e80fd3a" + resolved-ref: "6a28337a5d759bee3d198992e79d9b5c1e80fd3a" url: "https://github.com/mertalev/http" source: git version: "3.0.0-wip" @@ -1267,8 +1267,8 @@ packages: dependency: "direct main" description: path: "pkgs/ok_http" - ref: "114b2807bdeee641457b5703f411318d722b67b5" - resolved-ref: "114b2807bdeee641457b5703f411318d722b67b5" + ref: fc43a0bf108c4705a11511f403802528ab1db716 + resolved-ref: fc43a0bf108c4705a11511f403802528ab1db716 url: "https://github.com/mertalev/http" source: git version: "0.1.1-wip" @@ -1727,19 +1727,20 @@ packages: socket_io_client: dependency: "direct main" description: - name: socket_io_client - sha256: ede469f3e4c55e8528b4e023bdedbc20832e8811ab9b61679d1ba3ed5f01f23b - url: "https://pub.dev" - source: hosted - version: "2.0.3+1" + path: "." + ref: e1d813a240b5d5b7e2f141b2b605c5429b7cd006 + resolved-ref: e1d813a240b5d5b7e2f141b2b605c5429b7cd006 + url: "https://github.com/mertalev/socket.io-client-dart" + source: git + version: "3.1.4" socket_io_common: dependency: transitive description: name: socket_io_common - sha256: "2ab92f8ff3ebbd4b353bf4a98bee45cc157e3255464b2f90f66e09c4472047eb" + sha256: "162fbaecbf4bf9a9372a62a341b3550b51dcef2f02f3e5830a297fd48203d45b" url: "https://pub.dev" source: hosted - version: "2.0.3" + version: "3.1.1" source_gen: dependency: transitive description: @@ -2101,21 +2102,21 @@ packages: source: hosted version: "1.1.1" web_socket: - dependency: transitive + dependency: "direct main" description: name: web_socket - sha256: "3c12d96c0c9a4eec095246debcea7b86c0324f22df69893d538fcc6f1b8cce83" + sha256: "34d64019aa8e36bf9842ac014bb5d2f5586ca73df5e4d9bf5c936975cae6982c" url: "https://pub.dev" source: hosted - version: "0.1.6" + version: "1.0.1" web_socket_channel: dependency: transitive description: name: web_socket_channel - sha256: "0b8e2457400d8a859b7b2030786835a28a8e80836ef64402abef392ff4f1d0e5" + sha256: d645757fb0f4773d602444000a8131ff5d48c9e47adfe9772652dd1a4f2d45c8 url: "https://pub.dev" source: hosted - version: "3.0.2" + version: "3.0.3" webdriver: dependency: transitive description: diff --git a/mobile/pubspec.yaml b/mobile/pubspec.yaml index 6663f7c66f..6fa3ca3426 100644 --- a/mobile/pubspec.yaml +++ b/mobile/pubspec.yaml @@ -75,7 +75,6 @@ dependencies: share_handler: ^0.0.25 share_plus: ^10.1.4 sliver_tools: ^0.2.12 - socket_io_client: ^2.0.3+1 stream_transform: ^2.1.1 thumbhash: 0.1.0+1 timezone: ^0.9.4 @@ -83,16 +82,20 @@ dependencies: uuid: ^4.5.1 wakelock_plus: ^1.3.0 worker_manager: ^7.2.7 - # TODO: upstream these changes + web_socket: ^1.0.1 + socket_io_client: + git: + url: https://github.com/mertalev/socket.io-client-dart + ref: 'e1d813a240b5d5b7e2f141b2b605c5429b7cd006' # https://github.com/rikulo/socket.io-client-dart/pull/435 cupertino_http: git: url: https://github.com/mertalev/http - ref: '114b2807bdeee641457b5703f411318d722b67b5' + ref: '6a28337a5d759bee3d198992e79d9b5c1e80fd3a' # https://github.com/dart-lang/http/pull/1876 path: pkgs/cupertino_http/ ok_http: git: url: https://github.com/mertalev/http - ref: '114b2807bdeee641457b5703f411318d722b67b5' + ref: 'fc43a0bf108c4705a11511f403802528ab1db716' # https://github.com/dart-lang/http/pull/1877 path: pkgs/ok_http/ dev_dependencies: