diff --git a/mobile/android/app/src/main/kotlin/app/alextran/immich/core/Network.g.kt b/mobile/android/app/src/main/kotlin/app/alextran/immich/core/Network.g.kt index 06e77d9b72..3e34ffa7ac 100644 --- a/mobile/android/app/src/main/kotlin/app/alextran/immich/core/Network.g.kt +++ b/mobile/android/app/src/main/kotlin/app/alextran/immich/core/Network.g.kt @@ -145,6 +145,37 @@ data class ClientCertPrompt ( override fun hashCode(): Int = toList().hashCode() } + +/** Generated class from Pigeon that represents data sent in messages. */ +data class WebSocketTaskResult ( + val taskPointer: Long, + val taskProtocol: String? = null +) + { + companion object { + fun fromList(pigeonVar_list: List): WebSocketTaskResult { + val taskPointer = pigeonVar_list[0] as Long + val taskProtocol = pigeonVar_list[1] as String? + return WebSocketTaskResult(taskPointer, taskProtocol) + } + } + fun toList(): List { + return listOf( + taskPointer, + taskProtocol, + ) + } + override fun equals(other: Any?): Boolean { + if (other !is WebSocketTaskResult) { + return false + } + if (this === other) { + return true + } + return NetworkPigeonUtils.deepEquals(toList(), other.toList()) } + + override fun hashCode(): Int = toList().hashCode() +} private open class NetworkPigeonCodec : StandardMessageCodec() { override fun readValueOfType(type: Byte, buffer: ByteBuffer): Any? { return when (type) { @@ -158,6 +189,11 @@ private open class NetworkPigeonCodec : StandardMessageCodec() { ClientCertPrompt.fromList(it) } } + 131.toByte() -> { + return (readValue(buffer) as? List)?.let { + WebSocketTaskResult.fromList(it) + } + } else -> super.readValueOfType(type, buffer) } } @@ -171,6 +207,10 @@ private open class NetworkPigeonCodec : StandardMessageCodec() { stream.write(130) writeValue(stream, value.toList()) } + is WebSocketTaskResult -> { + stream.write(131) + writeValue(stream, value.toList()) + } else -> super.writeValue(stream, value) } } @@ -183,6 +223,11 @@ interface NetworkApi { fun selectCertificate(promptText: ClientCertPrompt, callback: (Result) -> Unit) fun removeCertificate(callback: (Result) -> Unit) fun getClientPointer(): Long + /** + * Creates a WebSocket task and waits for connection to be established. + * iOS only - Android should use OkHttpWebSocket.connectWithClient directly. + */ + fun createWebSocketTask(url: String, protocols: List?, callback: (Result) -> Unit) companion object { /** The codec used by NetworkApi. */ @@ -264,6 +309,27 @@ interface NetworkApi { channel.setMessageHandler(null) } } + run { + val channel = BasicMessageChannel(binaryMessenger, "dev.flutter.pigeon.immich_mobile.NetworkApi.createWebSocketTask$separatedMessageChannelSuffix", codec) + if (api != null) { + channel.setMessageHandler { message, reply -> + val args = message as List + val urlArg = args[0] as String + val protocolsArg = args[1] as List? + api.createWebSocketTask(urlArg, protocolsArg) { result: Result -> + val error = result.exceptionOrNull() + if (error != null) { + reply.reply(NetworkPigeonUtils.wrapError(error)) + } else { + val data = result.getOrNull() + reply.reply(NetworkPigeonUtils.wrapResult(data)) + } + } + } + } else { + channel.setMessageHandler(null) + } + } } } } diff --git a/mobile/ios/Runner/Core/Network.g.swift b/mobile/ios/Runner/Core/Network.g.swift index 9dd81ee92b..9018f05f49 100644 --- a/mobile/ios/Runner/Core/Network.g.swift +++ b/mobile/ios/Runner/Core/Network.g.swift @@ -176,6 +176,35 @@ struct ClientCertPrompt: Hashable { } } +/// Generated class from Pigeon that represents data sent in messages. +struct WebSocketTaskResult: Hashable { + var taskPointer: Int64 + var taskProtocol: String? = nil + + + // swift-format-ignore: AlwaysUseLowerCamelCase + static func fromList(_ pigeonVar_list: [Any?]) -> WebSocketTaskResult? { + let taskPointer = pigeonVar_list[0] as! Int64 + let taskProtocol: String? = nilOrValue(pigeonVar_list[1]) + + return WebSocketTaskResult( + taskPointer: taskPointer, + taskProtocol: taskProtocol + ) + } + func toList() -> [Any?] { + return [ + taskPointer, + taskProtocol, + ] + } + static func == (lhs: WebSocketTaskResult, rhs: WebSocketTaskResult) -> Bool { + return deepEqualsNetwork(lhs.toList(), rhs.toList()) } + func hash(into hasher: inout Hasher) { + deepHashNetwork(value: toList(), hasher: &hasher) + } +} + private class NetworkPigeonCodecReader: FlutterStandardReader { override func readValue(ofType type: UInt8) -> Any? { switch type { @@ -183,6 +212,8 @@ private class NetworkPigeonCodecReader: FlutterStandardReader { return ClientCertData.fromList(self.readValue() as! [Any?]) case 130: return ClientCertPrompt.fromList(self.readValue() as! [Any?]) + case 131: + return WebSocketTaskResult.fromList(self.readValue() as! [Any?]) default: return super.readValue(ofType: type) } @@ -197,6 +228,9 @@ private class NetworkPigeonCodecWriter: FlutterStandardWriter { } else if let value = value as? ClientCertPrompt { super.writeByte(130) super.writeValue(value.toList()) + } else if let value = value as? WebSocketTaskResult { + super.writeByte(131) + super.writeValue(value.toList()) } else { super.writeValue(value) } @@ -224,6 +258,9 @@ protocol NetworkApi { func selectCertificate(promptText: ClientCertPrompt, completion: @escaping (Result) -> Void) func removeCertificate(completion: @escaping (Result) -> Void) func getClientPointer() throws -> Int64 + /// Creates a WebSocket task and waits for connection to be established. + /// iOS only - Android should use OkHttpWebSocket.connectWithClient directly. + func createWebSocketTask(url: String, protocols: [String]?, completion: @escaping (Result) -> Void) } /// Generated setup class from Pigeon to handle messages through the `binaryMessenger`. @@ -294,5 +331,25 @@ class NetworkApiSetup { } else { getClientPointerChannel.setMessageHandler(nil) } + /// Creates a WebSocket task and waits for connection to be established. + /// iOS only - Android should use OkHttpWebSocket.connectWithClient directly. + let createWebSocketTaskChannel = FlutterBasicMessageChannel(name: "dev.flutter.pigeon.immich_mobile.NetworkApi.createWebSocketTask\(channelSuffix)", binaryMessenger: binaryMessenger, codec: codec) + if let api = api { + createWebSocketTaskChannel.setMessageHandler { message, reply in + let args = message as! [Any?] + let urlArg = args[0] as! String + let protocolsArg: [String]? = nilOrValue(args[1]) + api.createWebSocketTask(url: urlArg, protocols: protocolsArg) { result in + switch result { + case .success(let res): + reply(wrapResult(res)) + case .failure(let error): + reply(wrapError(error)) + } + } + } + } else { + createWebSocketTaskChannel.setMessageHandler(nil) + } } } diff --git a/mobile/ios/Runner/Core/NetworkApiImpl.swift b/mobile/ios/Runner/Core/NetworkApiImpl.swift index b9faa89fce..51e746fa08 100644 --- a/mobile/ios/Runner/Core/NetworkApiImpl.swift +++ b/mobile/ios/Runner/Core/NetworkApiImpl.swift @@ -45,6 +45,28 @@ class NetworkApiImpl: NetworkApi { let pointer = URLSessionManager.shared.sessionPointer return Int64(Int(bitPattern: pointer)) } + + func createWebSocketTask( + url: String, + protocols: [String]?, + completion: @escaping (Result) -> Void + ) { + guard let wsUrl = URL(string: url) else { + completion(.failure(WebSocketError.invalidURL(url))) + return + } + + URLSessionManager.shared.createWebSocketTask(url: wsUrl, protocols: protocols) { result in + switch result { + case .success(let (task, proto)): + let pointer = Unmanaged.passUnretained(task).toOpaque() + let address = Int64(Int(bitPattern: pointer)) + completion(.success(WebSocketTaskResult(taskPointer: address, taskProtocol: proto))) + case .failure(let error): + completion(.failure(error)) + } + } + } } private class CertImporter: NSObject, UIDocumentPickerDelegate { diff --git a/mobile/ios/Runner/Core/URLSessionManager.swift b/mobile/ios/Runner/Core/URLSessionManager.swift index 1e127b312b..1c6c489e4f 100644 --- a/mobile/ios/Runner/Core/URLSessionManager.swift +++ b/mobile/ios/Runner/Core/URLSessionManager.swift @@ -7,6 +7,7 @@ class URLSessionManager: NSObject { static let shared = URLSessionManager() let session: URLSession + let delegate: URLSessionManagerDelegate private let configuration = { let config = URLSessionConfiguration.default @@ -36,12 +37,85 @@ class URLSessionManager: NSObject { } private override init() { - session = URLSession(configuration: configuration, delegate: URLSessionManagerDelegate(), delegateQueue: nil) + delegate = URLSessionManagerDelegate() + session = URLSession(configuration: configuration, delegate: delegate, delegateQueue: nil) super.init() } + + /// Creates a WebSocket task and waits for connection to be established. + func createWebSocketTask( + url: URL, + protocols: [String]?, + completion: @escaping (Result<(URLSessionWebSocketTask, String?), Error>) -> Void + ) { + let task: URLSessionWebSocketTask + if let protocols = protocols, !protocols.isEmpty { + task = session.webSocketTask(with: url, protocols: protocols) + } else { + task = session.webSocketTask(with: url) + } + + delegate.registerWebSocketTask(task) { result in + completion(result) + } + task.resume() + } } -class URLSessionManagerDelegate: NSObject, URLSessionTaskDelegate { +enum WebSocketError: Error { + case connectionFailed(String) + case invalidURL(String) +} + +class URLSessionManagerDelegate: NSObject, URLSessionTaskDelegate, URLSessionWebSocketDelegate { + private var webSocketCompletions: [URLSessionWebSocketTask: (Result<(URLSessionWebSocketTask, String?), Error>) -> Void] = [:] + private let lock = NSLock() + + func registerWebSocketTask( + _ task: URLSessionWebSocketTask, + completion: @escaping (Result<(URLSessionWebSocketTask, String?), Error>) -> Void + ) { + lock.lock() + webSocketCompletions[task] = completion + lock.unlock() + } + + func urlSession( + _ session: URLSession, + webSocketTask: URLSessionWebSocketTask, + didOpenWithProtocol protocol: String? + ) { + lock.lock() + let completion = webSocketCompletions.removeValue(forKey: webSocketTask) + lock.unlock() + completion?(.success((webSocketTask, `protocol`))) + } + + func urlSession( + _ session: URLSession, + webSocketTask: URLSessionWebSocketTask, + didCloseWith closeCode: URLSessionWebSocketTask.CloseCode, + reason: Data? + ) { + // Close events are handled by CupertinoWebSocket via task.closeCode/closeReason + } + + func urlSession( + _ session: URLSession, + task: URLSessionTask, + didCompleteWithError error: Error? + ) { + guard let webSocketTask = task as? URLSessionWebSocketTask else { return } + + lock.lock() + let completion = webSocketCompletions.removeValue(forKey: webSocketTask) + lock.unlock() + + if let error = error { + completion?(.failure(error)) + } + } + func urlSession( _ session: URLSession, didReceive challenge: URLAuthenticationChallenge, diff --git a/mobile/lib/infrastructure/repositories/network.repository.dart b/mobile/lib/infrastructure/repositories/network.repository.dart index 9029479af8..071027d414 100644 --- a/mobile/lib/infrastructure/repositories/network.repository.dart +++ b/mobile/lib/infrastructure/repositories/network.repository.dart @@ -3,21 +3,33 @@ import 'dart:io'; import 'package:cupertino_http/cupertino_http.dart'; import 'package:http/http.dart' as http; +import 'package:immich_mobile/extensions/platform_extensions.dart'; import 'package:immich_mobile/providers/infrastructure/platform.provider.dart'; import 'package:logging/logging.dart'; import 'package:ok_http/ok_http.dart'; +import 'package:web_socket/web_socket.dart'; class NetworkRepository { static final _log = Logger('NetworkRepository'); static http.Client? _client; + static late int _clientPointer; static Future init() async { - final pointer = await networkApi.getClientPointer(); + _clientPointer = await networkApi.getClientPointer(); _client?.close(); if (Platform.isIOS) { - _client = _createIOSClient(pointer); + _client = _createIOSClient(_clientPointer); } else { - _client = _createAndroidClient(pointer); + _client = _createAndroidClient(_clientPointer); + } + } + + // ignore: avoid-unused-parameters + static Future createWebSocket(Uri uri, {Map? headers, Iterable? protocols}) { + if (CurrentPlatform.isIOS) { + return _createIOSWebSocket(uri, protocols: protocols); + } else { + return _createAndroidWebSocket(uri, protocols: protocols); } } @@ -43,4 +55,16 @@ class NetworkRepository { _log.info('Using shared native OkHttpClient'); return OkHttpClient.fromJniGlobalRef(pointer); } + + static Future _createIOSWebSocket(Uri uri, {Iterable? protocols}) async { + final result = await networkApi.createWebSocketTask(uri.toString(), protocols?.toList()); + final pointer = Pointer.fromAddress(result.taskPointer); + final task = URLSessionWebSocketTask.fromRawPointer(pointer.cast()); + return CupertinoWebSocket.fromConnectedTask(task, protocol: result.protocol ?? ''); + } + + static Future _createAndroidWebSocket(Uri uri, {Iterable? protocols}) { + final pointer = Pointer.fromAddress(_clientPointer); + return OkHttpWebSocket.connectFromJniGlobalRef(pointer, uri, protocols: protocols); + } } diff --git a/mobile/lib/platform/network_api.g.dart b/mobile/lib/platform/network_api.g.dart index 92b454c591..6f3e860568 100644 --- a/mobile/lib/platform/network_api.g.dart +++ b/mobile/lib/platform/network_api.g.dart @@ -112,6 +112,43 @@ class ClientCertPrompt { int get hashCode => Object.hashAll(_toList()); } +class WebSocketTaskResult { + WebSocketTaskResult({required this.taskPointer, this.taskProtocol}); + + int taskPointer; + + String? taskProtocol; + + List _toList() { + return [taskPointer, taskProtocol]; + } + + Object encode() { + return _toList(); + } + + static WebSocketTaskResult decode(Object result) { + result as List; + return WebSocketTaskResult(taskPointer: result[0]! as int, taskProtocol: result[1] as String?); + } + + @override + // ignore: avoid_equals_and_hash_code_on_mutable_classes + bool operator ==(Object other) { + if (other is! WebSocketTaskResult || other.runtimeType != runtimeType) { + return false; + } + if (identical(this, other)) { + return true; + } + return _deepEquals(encode(), other.encode()); + } + + @override + // ignore: avoid_equals_and_hash_code_on_mutable_classes + int get hashCode => Object.hashAll(_toList()); +} + class _PigeonCodec extends StandardMessageCodec { const _PigeonCodec(); @override @@ -125,6 +162,9 @@ class _PigeonCodec extends StandardMessageCodec { } else if (value is ClientCertPrompt) { buffer.putUint8(130); writeValue(buffer, value.encode()); + } else if (value is WebSocketTaskResult) { + buffer.putUint8(131); + writeValue(buffer, value.encode()); } else { super.writeValue(buffer, value); } @@ -137,6 +177,8 @@ class _PigeonCodec extends StandardMessageCodec { return ClientCertData.decode(readValue(buffer)!); case 130: return ClientCertPrompt.decode(readValue(buffer)!); + case 131: + return WebSocketTaskResult.decode(readValue(buffer)!); default: return super.readValueOfType(type, buffer); } @@ -257,4 +299,34 @@ class NetworkApi { return (pigeonVar_replyList[0] as int?)!; } } + + /// Creates a WebSocket task and waits for connection to be established. + /// iOS only - Android should use OkHttpWebSocket.connectWithClient directly. + Future createWebSocketTask(String url, List? protocols) async { + final String pigeonVar_channelName = + 'dev.flutter.pigeon.immich_mobile.NetworkApi.createWebSocketTask$pigeonVar_messageChannelSuffix'; + final BasicMessageChannel pigeonVar_channel = BasicMessageChannel( + pigeonVar_channelName, + pigeonChannelCodec, + binaryMessenger: pigeonVar_binaryMessenger, + ); + final Future pigeonVar_sendFuture = pigeonVar_channel.send([url, protocols]); + final List? pigeonVar_replyList = await pigeonVar_sendFuture as List?; + if (pigeonVar_replyList == null) { + throw _createConnectionError(pigeonVar_channelName); + } else if (pigeonVar_replyList.length > 1) { + throw PlatformException( + code: pigeonVar_replyList[0]! as String, + message: pigeonVar_replyList[1] as String?, + details: pigeonVar_replyList[2], + ); + } else if (pigeonVar_replyList[0] == null) { + throw PlatformException( + code: 'null-error', + message: 'Host platform returned null value for non-null return value.', + ); + } else { + return (pigeonVar_replyList[0] as WebSocketTaskResult?)!; + } + } } diff --git a/mobile/lib/providers/websocket.provider.dart b/mobile/lib/providers/websocket.provider.dart index f9473ce440..d46428c50f 100644 --- a/mobile/lib/providers/websocket.provider.dart +++ b/mobile/lib/providers/websocket.provider.dart @@ -6,6 +6,7 @@ import 'package:hooks_riverpod/hooks_riverpod.dart'; import 'package:immich_mobile/domain/models/store.model.dart'; import 'package:immich_mobile/entities/asset.entity.dart'; import 'package:immich_mobile/entities/store.entity.dart'; +import 'package:immich_mobile/infrastructure/repositories/network.repository.dart'; import 'package:immich_mobile/models/server_info/server_version.model.dart'; import 'package:immich_mobile/providers/asset.provider.dart'; import 'package:immich_mobile/providers/auth.provider.dart'; @@ -111,6 +112,7 @@ class WebsocketNotifier extends StateNotifier { OptionBuilder() .setPath("${endpoint.path}/socket.io") .setTransports(['websocket']) + .setWebSocketConnector(NetworkRepository.createWebSocket) .enableReconnection() .enableForceNew() .enableForceNewConnection() diff --git a/mobile/pigeon/network_api.dart b/mobile/pigeon/network_api.dart index b9326e2a60..da4a0be3b2 100644 --- a/mobile/pigeon/network_api.dart +++ b/mobile/pigeon/network_api.dart @@ -16,6 +16,13 @@ class ClientCertPrompt { ClientCertPrompt(this.title, this.message, this.cancel, this.confirm); } +class WebSocketTaskResult { + int taskPointer; + String? taskProtocol; + + WebSocketTaskResult(this.taskPointer, this.taskProtocol); +} + @ConfigurePigeon( PigeonOptions( dartOut: 'lib/platform/network_api.g.dart', @@ -40,4 +47,9 @@ abstract class NetworkApi { void removeCertificate(); int getClientPointer(); + + /// Creates a WebSocket task and waits for connection to be established. + /// iOS only - Android should use OkHttpWebSocket.connectWithClient directly. + @async + WebSocketTaskResult createWebSocketTask(String url, List? protocols); }