websocket integration

This commit is contained in:
mertalev
2026-02-02 16:39:51 -05:00
parent aff7e7e199
commit 577da2bc85
8 changed files with 334 additions and 5 deletions

View File

@@ -145,6 +145,37 @@ data class ClientCertPrompt (
override fun hashCode(): Int = toList().hashCode()
}
/** Generated class from Pigeon that represents data sent in messages. */
data class WebSocketTaskResult (
val taskPointer: Long,
val taskProtocol: String? = null
)
{
companion object {
fun fromList(pigeonVar_list: List<Any?>): WebSocketTaskResult {
val taskPointer = pigeonVar_list[0] as Long
val taskProtocol = pigeonVar_list[1] as String?
return WebSocketTaskResult(taskPointer, taskProtocol)
}
}
fun toList(): List<Any?> {
return listOf(
taskPointer,
taskProtocol,
)
}
override fun equals(other: Any?): Boolean {
if (other !is WebSocketTaskResult) {
return false
}
if (this === other) {
return true
}
return NetworkPigeonUtils.deepEquals(toList(), other.toList()) }
override fun hashCode(): Int = toList().hashCode()
}
private open class NetworkPigeonCodec : StandardMessageCodec() {
override fun readValueOfType(type: Byte, buffer: ByteBuffer): Any? {
return when (type) {
@@ -158,6 +189,11 @@ private open class NetworkPigeonCodec : StandardMessageCodec() {
ClientCertPrompt.fromList(it)
}
}
131.toByte() -> {
return (readValue(buffer) as? List<Any?>)?.let {
WebSocketTaskResult.fromList(it)
}
}
else -> super.readValueOfType(type, buffer)
}
}
@@ -171,6 +207,10 @@ private open class NetworkPigeonCodec : StandardMessageCodec() {
stream.write(130)
writeValue(stream, value.toList())
}
is WebSocketTaskResult -> {
stream.write(131)
writeValue(stream, value.toList())
}
else -> super.writeValue(stream, value)
}
}
@@ -183,6 +223,11 @@ interface NetworkApi {
fun selectCertificate(promptText: ClientCertPrompt, callback: (Result<ClientCertData>) -> Unit)
fun removeCertificate(callback: (Result<Unit>) -> Unit)
fun getClientPointer(): Long
/**
* Creates a WebSocket task and waits for connection to be established.
* iOS only - Android should use OkHttpWebSocket.connectWithClient directly.
*/
fun createWebSocketTask(url: String, protocols: List<String>?, callback: (Result<WebSocketTaskResult>) -> Unit)
companion object {
/** The codec used by NetworkApi. */
@@ -264,6 +309,27 @@ interface NetworkApi {
channel.setMessageHandler(null)
}
}
run {
val channel = BasicMessageChannel<Any?>(binaryMessenger, "dev.flutter.pigeon.immich_mobile.NetworkApi.createWebSocketTask$separatedMessageChannelSuffix", codec)
if (api != null) {
channel.setMessageHandler { message, reply ->
val args = message as List<Any?>
val urlArg = args[0] as String
val protocolsArg = args[1] as List<String>?
api.createWebSocketTask(urlArg, protocolsArg) { result: Result<WebSocketTaskResult> ->
val error = result.exceptionOrNull()
if (error != null) {
reply.reply(NetworkPigeonUtils.wrapError(error))
} else {
val data = result.getOrNull()
reply.reply(NetworkPigeonUtils.wrapResult(data))
}
}
}
} else {
channel.setMessageHandler(null)
}
}
}
}
}

View File

@@ -176,6 +176,35 @@ struct ClientCertPrompt: Hashable {
}
}
/// Generated class from Pigeon that represents data sent in messages.
struct WebSocketTaskResult: Hashable {
var taskPointer: Int64
var taskProtocol: String? = nil
// swift-format-ignore: AlwaysUseLowerCamelCase
static func fromList(_ pigeonVar_list: [Any?]) -> WebSocketTaskResult? {
let taskPointer = pigeonVar_list[0] as! Int64
let taskProtocol: String? = nilOrValue(pigeonVar_list[1])
return WebSocketTaskResult(
taskPointer: taskPointer,
taskProtocol: taskProtocol
)
}
func toList() -> [Any?] {
return [
taskPointer,
taskProtocol,
]
}
static func == (lhs: WebSocketTaskResult, rhs: WebSocketTaskResult) -> Bool {
return deepEqualsNetwork(lhs.toList(), rhs.toList()) }
func hash(into hasher: inout Hasher) {
deepHashNetwork(value: toList(), hasher: &hasher)
}
}
private class NetworkPigeonCodecReader: FlutterStandardReader {
override func readValue(ofType type: UInt8) -> Any? {
switch type {
@@ -183,6 +212,8 @@ private class NetworkPigeonCodecReader: FlutterStandardReader {
return ClientCertData.fromList(self.readValue() as! [Any?])
case 130:
return ClientCertPrompt.fromList(self.readValue() as! [Any?])
case 131:
return WebSocketTaskResult.fromList(self.readValue() as! [Any?])
default:
return super.readValue(ofType: type)
}
@@ -197,6 +228,9 @@ private class NetworkPigeonCodecWriter: FlutterStandardWriter {
} else if let value = value as? ClientCertPrompt {
super.writeByte(130)
super.writeValue(value.toList())
} else if let value = value as? WebSocketTaskResult {
super.writeByte(131)
super.writeValue(value.toList())
} else {
super.writeValue(value)
}
@@ -224,6 +258,9 @@ protocol NetworkApi {
func selectCertificate(promptText: ClientCertPrompt, completion: @escaping (Result<ClientCertData, Error>) -> Void)
func removeCertificate(completion: @escaping (Result<Void, Error>) -> Void)
func getClientPointer() throws -> Int64
/// Creates a WebSocket task and waits for connection to be established.
/// iOS only - Android should use OkHttpWebSocket.connectWithClient directly.
func createWebSocketTask(url: String, protocols: [String]?, completion: @escaping (Result<WebSocketTaskResult, Error>) -> Void)
}
/// Generated setup class from Pigeon to handle messages through the `binaryMessenger`.
@@ -294,5 +331,25 @@ class NetworkApiSetup {
} else {
getClientPointerChannel.setMessageHandler(nil)
}
/// Creates a WebSocket task and waits for connection to be established.
/// iOS only - Android should use OkHttpWebSocket.connectWithClient directly.
let createWebSocketTaskChannel = FlutterBasicMessageChannel(name: "dev.flutter.pigeon.immich_mobile.NetworkApi.createWebSocketTask\(channelSuffix)", binaryMessenger: binaryMessenger, codec: codec)
if let api = api {
createWebSocketTaskChannel.setMessageHandler { message, reply in
let args = message as! [Any?]
let urlArg = args[0] as! String
let protocolsArg: [String]? = nilOrValue(args[1])
api.createWebSocketTask(url: urlArg, protocols: protocolsArg) { result in
switch result {
case .success(let res):
reply(wrapResult(res))
case .failure(let error):
reply(wrapError(error))
}
}
}
} else {
createWebSocketTaskChannel.setMessageHandler(nil)
}
}
}

View File

@@ -45,6 +45,28 @@ class NetworkApiImpl: NetworkApi {
let pointer = URLSessionManager.shared.sessionPointer
return Int64(Int(bitPattern: pointer))
}
func createWebSocketTask(
url: String,
protocols: [String]?,
completion: @escaping (Result<WebSocketTaskResult, any Error>) -> Void
) {
guard let wsUrl = URL(string: url) else {
completion(.failure(WebSocketError.invalidURL(url)))
return
}
URLSessionManager.shared.createWebSocketTask(url: wsUrl, protocols: protocols) { result in
switch result {
case .success(let (task, proto)):
let pointer = Unmanaged.passUnretained(task).toOpaque()
let address = Int64(Int(bitPattern: pointer))
completion(.success(WebSocketTaskResult(taskPointer: address, taskProtocol: proto)))
case .failure(let error):
completion(.failure(error))
}
}
}
}
private class CertImporter: NSObject, UIDocumentPickerDelegate {

View File

@@ -7,6 +7,7 @@ class URLSessionManager: NSObject {
static let shared = URLSessionManager()
let session: URLSession
let delegate: URLSessionManagerDelegate
private let configuration = {
let config = URLSessionConfiguration.default
@@ -36,12 +37,85 @@ class URLSessionManager: NSObject {
}
private override init() {
session = URLSession(configuration: configuration, delegate: URLSessionManagerDelegate(), delegateQueue: nil)
delegate = URLSessionManagerDelegate()
session = URLSession(configuration: configuration, delegate: delegate, delegateQueue: nil)
super.init()
}
/// Creates a WebSocket task and waits for connection to be established.
func createWebSocketTask(
url: URL,
protocols: [String]?,
completion: @escaping (Result<(URLSessionWebSocketTask, String?), Error>) -> Void
) {
let task: URLSessionWebSocketTask
if let protocols = protocols, !protocols.isEmpty {
task = session.webSocketTask(with: url, protocols: protocols)
} else {
task = session.webSocketTask(with: url)
}
delegate.registerWebSocketTask(task) { result in
completion(result)
}
task.resume()
}
}
class URLSessionManagerDelegate: NSObject, URLSessionTaskDelegate {
enum WebSocketError: Error {
case connectionFailed(String)
case invalidURL(String)
}
class URLSessionManagerDelegate: NSObject, URLSessionTaskDelegate, URLSessionWebSocketDelegate {
private var webSocketCompletions: [URLSessionWebSocketTask: (Result<(URLSessionWebSocketTask, String?), Error>) -> Void] = [:]
private let lock = NSLock()
func registerWebSocketTask(
_ task: URLSessionWebSocketTask,
completion: @escaping (Result<(URLSessionWebSocketTask, String?), Error>) -> Void
) {
lock.lock()
webSocketCompletions[task] = completion
lock.unlock()
}
func urlSession(
_ session: URLSession,
webSocketTask: URLSessionWebSocketTask,
didOpenWithProtocol protocol: String?
) {
lock.lock()
let completion = webSocketCompletions.removeValue(forKey: webSocketTask)
lock.unlock()
completion?(.success((webSocketTask, `protocol`)))
}
func urlSession(
_ session: URLSession,
webSocketTask: URLSessionWebSocketTask,
didCloseWith closeCode: URLSessionWebSocketTask.CloseCode,
reason: Data?
) {
// Close events are handled by CupertinoWebSocket via task.closeCode/closeReason
}
func urlSession(
_ session: URLSession,
task: URLSessionTask,
didCompleteWithError error: Error?
) {
guard let webSocketTask = task as? URLSessionWebSocketTask else { return }
lock.lock()
let completion = webSocketCompletions.removeValue(forKey: webSocketTask)
lock.unlock()
if let error = error {
completion?(.failure(error))
}
}
func urlSession(
_ session: URLSession,
didReceive challenge: URLAuthenticationChallenge,

View File

@@ -3,21 +3,33 @@ import 'dart:io';
import 'package:cupertino_http/cupertino_http.dart';
import 'package:http/http.dart' as http;
import 'package:immich_mobile/extensions/platform_extensions.dart';
import 'package:immich_mobile/providers/infrastructure/platform.provider.dart';
import 'package:logging/logging.dart';
import 'package:ok_http/ok_http.dart';
import 'package:web_socket/web_socket.dart';
class NetworkRepository {
static final _log = Logger('NetworkRepository');
static http.Client? _client;
static late int _clientPointer;
static Future<void> init() async {
final pointer = await networkApi.getClientPointer();
_clientPointer = await networkApi.getClientPointer();
_client?.close();
if (Platform.isIOS) {
_client = _createIOSClient(pointer);
_client = _createIOSClient(_clientPointer);
} else {
_client = _createAndroidClient(pointer);
_client = _createAndroidClient(_clientPointer);
}
}
// ignore: avoid-unused-parameters
static Future<WebSocket> createWebSocket(Uri uri, {Map<String, String>? headers, Iterable<String>? protocols}) {
if (CurrentPlatform.isIOS) {
return _createIOSWebSocket(uri, protocols: protocols);
} else {
return _createAndroidWebSocket(uri, protocols: protocols);
}
}
@@ -43,4 +55,16 @@ class NetworkRepository {
_log.info('Using shared native OkHttpClient');
return OkHttpClient.fromJniGlobalRef(pointer);
}
static Future<WebSocket> _createIOSWebSocket(Uri uri, {Iterable<String>? protocols}) async {
final result = await networkApi.createWebSocketTask(uri.toString(), protocols?.toList());
final pointer = Pointer.fromAddress(result.taskPointer);
final task = URLSessionWebSocketTask.fromRawPointer(pointer.cast());
return CupertinoWebSocket.fromConnectedTask(task, protocol: result.protocol ?? '');
}
static Future<WebSocket> _createAndroidWebSocket(Uri uri, {Iterable<String>? protocols}) {
final pointer = Pointer<Void>.fromAddress(_clientPointer);
return OkHttpWebSocket.connectFromJniGlobalRef(pointer, uri, protocols: protocols);
}
}

View File

@@ -112,6 +112,43 @@ class ClientCertPrompt {
int get hashCode => Object.hashAll(_toList());
}
class WebSocketTaskResult {
WebSocketTaskResult({required this.taskPointer, this.taskProtocol});
int taskPointer;
String? taskProtocol;
List<Object?> _toList() {
return <Object?>[taskPointer, taskProtocol];
}
Object encode() {
return _toList();
}
static WebSocketTaskResult decode(Object result) {
result as List<Object?>;
return WebSocketTaskResult(taskPointer: result[0]! as int, taskProtocol: result[1] as String?);
}
@override
// ignore: avoid_equals_and_hash_code_on_mutable_classes
bool operator ==(Object other) {
if (other is! WebSocketTaskResult || other.runtimeType != runtimeType) {
return false;
}
if (identical(this, other)) {
return true;
}
return _deepEquals(encode(), other.encode());
}
@override
// ignore: avoid_equals_and_hash_code_on_mutable_classes
int get hashCode => Object.hashAll(_toList());
}
class _PigeonCodec extends StandardMessageCodec {
const _PigeonCodec();
@override
@@ -125,6 +162,9 @@ class _PigeonCodec extends StandardMessageCodec {
} else if (value is ClientCertPrompt) {
buffer.putUint8(130);
writeValue(buffer, value.encode());
} else if (value is WebSocketTaskResult) {
buffer.putUint8(131);
writeValue(buffer, value.encode());
} else {
super.writeValue(buffer, value);
}
@@ -137,6 +177,8 @@ class _PigeonCodec extends StandardMessageCodec {
return ClientCertData.decode(readValue(buffer)!);
case 130:
return ClientCertPrompt.decode(readValue(buffer)!);
case 131:
return WebSocketTaskResult.decode(readValue(buffer)!);
default:
return super.readValueOfType(type, buffer);
}
@@ -257,4 +299,34 @@ class NetworkApi {
return (pigeonVar_replyList[0] as int?)!;
}
}
/// Creates a WebSocket task and waits for connection to be established.
/// iOS only - Android should use OkHttpWebSocket.connectWithClient directly.
Future<WebSocketTaskResult> createWebSocketTask(String url, List<String>? protocols) async {
final String pigeonVar_channelName =
'dev.flutter.pigeon.immich_mobile.NetworkApi.createWebSocketTask$pigeonVar_messageChannelSuffix';
final BasicMessageChannel<Object?> pigeonVar_channel = BasicMessageChannel<Object?>(
pigeonVar_channelName,
pigeonChannelCodec,
binaryMessenger: pigeonVar_binaryMessenger,
);
final Future<Object?> pigeonVar_sendFuture = pigeonVar_channel.send(<Object?>[url, protocols]);
final List<Object?>? pigeonVar_replyList = await pigeonVar_sendFuture as List<Object?>?;
if (pigeonVar_replyList == null) {
throw _createConnectionError(pigeonVar_channelName);
} else if (pigeonVar_replyList.length > 1) {
throw PlatformException(
code: pigeonVar_replyList[0]! as String,
message: pigeonVar_replyList[1] as String?,
details: pigeonVar_replyList[2],
);
} else if (pigeonVar_replyList[0] == null) {
throw PlatformException(
code: 'null-error',
message: 'Host platform returned null value for non-null return value.',
);
} else {
return (pigeonVar_replyList[0] as WebSocketTaskResult?)!;
}
}
}

View File

@@ -6,6 +6,7 @@ import 'package:hooks_riverpod/hooks_riverpod.dart';
import 'package:immich_mobile/domain/models/store.model.dart';
import 'package:immich_mobile/entities/asset.entity.dart';
import 'package:immich_mobile/entities/store.entity.dart';
import 'package:immich_mobile/infrastructure/repositories/network.repository.dart';
import 'package:immich_mobile/models/server_info/server_version.model.dart';
import 'package:immich_mobile/providers/asset.provider.dart';
import 'package:immich_mobile/providers/auth.provider.dart';
@@ -111,6 +112,7 @@ class WebsocketNotifier extends StateNotifier<WebsocketState> {
OptionBuilder()
.setPath("${endpoint.path}/socket.io")
.setTransports(['websocket'])
.setWebSocketConnector(NetworkRepository.createWebSocket)
.enableReconnection()
.enableForceNew()
.enableForceNewConnection()

View File

@@ -16,6 +16,13 @@ class ClientCertPrompt {
ClientCertPrompt(this.title, this.message, this.cancel, this.confirm);
}
class WebSocketTaskResult {
int taskPointer;
String? taskProtocol;
WebSocketTaskResult(this.taskPointer, this.taskProtocol);
}
@ConfigurePigeon(
PigeonOptions(
dartOut: 'lib/platform/network_api.g.dart',
@@ -40,4 +47,9 @@ abstract class NetworkApi {
void removeCertificate();
int getClientPointer();
/// Creates a WebSocket task and waits for connection to be established.
/// iOS only - Android should use OkHttpWebSocket.connectWithClient directly.
@async
WebSocketTaskResult createWebSocketTask(String url, List<String>? protocols);
}