websocket integration

platform-side headers

update comment

consistent platform check

tweak websocket handling

support streaming
This commit is contained in:
mertalev
2026-02-02 16:39:51 -05:00
parent 3228dcb70b
commit e89d174bd0
25 changed files with 468 additions and 93 deletions

View File

@@ -5,6 +5,7 @@ import app.alextran.immich.BuildConfig
import okhttp3.Cache
import okhttp3.ConnectionPool
import okhttp3.Dispatcher
import okhttp3.Headers
import okhttp3.OkHttpClient
import java.io.ByteArrayInputStream
import java.io.File
@@ -39,6 +40,7 @@ object HttpClientManager {
private val keyStore = KeyStore.getInstance("AndroidKeyStore").apply { load(null) }
var headers: Headers = Headers.headersOf("User-Agent", USER_AGENT)
val isMtls: Boolean get() = keyStore.containsAlias(CERT_ALIAS)
fun initialize(context: Context) {
@@ -93,6 +95,12 @@ object HttpClientManager {
synchronized(this) { clientChangedListeners.add(listener) }
}
fun setRequestHeaders(headerMap: Map<String, String>) {
val builder = Headers.Builder()
headerMap.forEach { (key, value) -> builder.add(key, value) }
headers = builder.build()
}
private fun build(cacheDir: File): OkHttpClient {
val connectionPool = ConnectionPool(
maxIdleConnections = KEEP_ALIVE_CONNECTIONS,
@@ -109,8 +117,10 @@ object HttpClientManager {
HttpsURLConnection.setDefaultSSLSocketFactory(sslContext.socketFactory)
return OkHttpClient.Builder()
.addInterceptor { chain ->
chain.proceed(chain.request().newBuilder().header("User-Agent", USER_AGENT).build())
.addInterceptor {
val builder = it.request().newBuilder()
headers.forEach { (key, value) -> builder.addHeader(key, value) }
it.proceed(builder.build())
}
.connectionPool(connectionPool)
.dispatcher(Dispatcher().apply { maxRequestsPerHost = MAX_REQUESTS_PER_HOST })

View File

@@ -145,6 +145,37 @@ data class ClientCertPrompt (
override fun hashCode(): Int = toList().hashCode()
}
/** Generated class from Pigeon that represents data sent in messages. */
data class WebSocketTaskResult (
val taskPointer: Long,
val taskProtocol: String? = null
)
{
companion object {
fun fromList(pigeonVar_list: List<Any?>): WebSocketTaskResult {
val taskPointer = pigeonVar_list[0] as Long
val taskProtocol = pigeonVar_list[1] as String?
return WebSocketTaskResult(taskPointer, taskProtocol)
}
}
fun toList(): List<Any?> {
return listOf(
taskPointer,
taskProtocol,
)
}
override fun equals(other: Any?): Boolean {
if (other !is WebSocketTaskResult) {
return false
}
if (this === other) {
return true
}
return NetworkPigeonUtils.deepEquals(toList(), other.toList()) }
override fun hashCode(): Int = toList().hashCode()
}
private open class NetworkPigeonCodec : StandardMessageCodec() {
override fun readValueOfType(type: Byte, buffer: ByteBuffer): Any? {
return when (type) {
@@ -158,6 +189,11 @@ private open class NetworkPigeonCodec : StandardMessageCodec() {
ClientCertPrompt.fromList(it)
}
}
131.toByte() -> {
return (readValue(buffer) as? List<Any?>)?.let {
WebSocketTaskResult.fromList(it)
}
}
else -> super.readValueOfType(type, buffer)
}
}
@@ -171,6 +207,10 @@ private open class NetworkPigeonCodec : StandardMessageCodec() {
stream.write(130)
writeValue(stream, value.toList())
}
is WebSocketTaskResult -> {
stream.write(131)
writeValue(stream, value.toList())
}
else -> super.writeValue(stream, value)
}
}
@@ -183,6 +223,9 @@ interface NetworkApi {
fun selectCertificate(promptText: ClientCertPrompt, callback: (Result<ClientCertData>) -> Unit)
fun removeCertificate(callback: (Result<Unit>) -> Unit)
fun getClientPointer(): Long
/** iOS only - creates a WebSocket task and waits for connection to be established. */
fun createWebSocketTask(url: String, protocols: List<String>?, callback: (Result<WebSocketTaskResult>) -> Unit)
fun setRequestHeaders(headers: Map<String, String>)
companion object {
/** The codec used by NetworkApi. */
@@ -264,6 +307,45 @@ interface NetworkApi {
channel.setMessageHandler(null)
}
}
run {
val channel = BasicMessageChannel<Any?>(binaryMessenger, "dev.flutter.pigeon.immich_mobile.NetworkApi.createWebSocketTask$separatedMessageChannelSuffix", codec)
if (api != null) {
channel.setMessageHandler { message, reply ->
val args = message as List<Any?>
val urlArg = args[0] as String
val protocolsArg = args[1] as List<String>?
api.createWebSocketTask(urlArg, protocolsArg) { result: Result<WebSocketTaskResult> ->
val error = result.exceptionOrNull()
if (error != null) {
reply.reply(NetworkPigeonUtils.wrapError(error))
} else {
val data = result.getOrNull()
reply.reply(NetworkPigeonUtils.wrapResult(data))
}
}
}
} else {
channel.setMessageHandler(null)
}
}
run {
val channel = BasicMessageChannel<Any?>(binaryMessenger, "dev.flutter.pigeon.immich_mobile.NetworkApi.setRequestHeaders$separatedMessageChannelSuffix", codec)
if (api != null) {
channel.setMessageHandler { message, reply ->
val args = message as List<Any?>
val headersArg = args[0] as Map<String, String>
val wrapped: List<Any?> = try {
api.setRequestHeaders(headersArg)
listOf(null)
} catch (exception: Throwable) {
NetworkPigeonUtils.wrapError(exception)
}
reply.reply(wrapped)
}
} else {
channel.setMessageHandler(null)
}
}
}
}
}

View File

@@ -104,6 +104,17 @@ private class NetworkApiImpl(private val context: Context) : NetworkApi {
return NativeBuffer.createGlobalRef(client)
}
// only used on iOS
override fun createWebSocketTask(
url: String,
protocols: List<String>?,
callback: (Result<WebSocketTaskResult>) -> Unit
) {}
override fun setRequestHeaders(headers: Map<String, String>) {
HttpClientManager.setRequestHeaders(headers)
}
private fun handlePickedFile(uri: Uri) {
val callback = pendingCallback ?: return
pendingCallback = null

View File

@@ -47,7 +47,7 @@ private open class RemoteImagesPigeonCodec : StandardMessageCodec() {
/** Generated interface from Pigeon that represents a handler of messages from Flutter. */
interface RemoteImageApi {
fun requestImage(url: String, headers: Map<String, String>, requestId: Long, callback: (Result<Map<String, Long>?>) -> Unit)
fun requestImage(url: String, requestId: Long, callback: (Result<Map<String, Long>?>) -> Unit)
fun cancelRequest(requestId: Long)
fun clearCache(callback: (Result<Long>) -> Unit)
@@ -66,9 +66,8 @@ interface RemoteImageApi {
channel.setMessageHandler { message, reply ->
val args = message as List<Any?>
val urlArg = args[0] as String
val headersArg = args[1] as Map<String, String>
val requestIdArg = args[2] as Long
api.requestImage(urlArg, headersArg, requestIdArg) { result: Result<Map<String, Long>?> ->
val requestIdArg = args[1] as Long
api.requestImage(urlArg, requestIdArg) { result: Result<Map<String, Long>?> ->
val error = result.exceptionOrNull()
if (error != null) {
reply.reply(RemoteImagesPigeonUtils.wrapError(error))

View File

@@ -49,7 +49,6 @@ class RemoteImagesImpl(context: Context) : RemoteImageApi {
override fun requestImage(
url: String,
headers: Map<String, String>,
requestId: Long,
callback: (Result<Map<String, Long>?>) -> Unit
) {
@@ -58,7 +57,6 @@ class RemoteImagesImpl(context: Context) : RemoteImageApi {
ImageFetcherManager.fetch(
url,
headers,
signal,
onSuccess = { buffer ->
requestMap.remove(requestId)
@@ -119,12 +117,11 @@ private object ImageFetcherManager {
fun fetch(
url: String,
headers: Map<String, String>,
signal: CancellationSignal,
onSuccess: (NativeByteBuffer) -> Unit,
onFailure: (Exception) -> Unit,
) {
fetcher.fetch(url, headers, signal, onSuccess, onFailure)
fetcher.fetch(url, signal, onSuccess, onFailure)
}
fun clearCache(onCleared: (Result<Long>) -> Unit) {
@@ -151,7 +148,6 @@ private object ImageFetcherManager {
private sealed interface ImageFetcher {
fun fetch(
url: String,
headers: Map<String, String>,
signal: CancellationSignal,
onSuccess: (NativeByteBuffer) -> Unit,
onFailure: (Exception) -> Unit,
@@ -178,7 +174,6 @@ private class CronetImageFetcher(context: Context, cacheDir: File) : ImageFetche
override fun fetch(
url: String,
headers: Map<String, String>,
signal: CancellationSignal,
onSuccess: (NativeByteBuffer) -> Unit,
onFailure: (Exception) -> Unit,
@@ -193,7 +188,7 @@ private class CronetImageFetcher(context: Context, cacheDir: File) : ImageFetche
val callback = FetchCallback(onSuccess, onFailure, ::onComplete)
val requestBuilder = engine.newUrlRequestBuilder(url, callback, executor)
headers.forEach { (key, value) -> requestBuilder.addHeader(key, value) }
HttpClientManager.headers.forEach { (key, value) -> requestBuilder.addHeader(key, value) }
val request = requestBuilder.build()
signal.setOnCancelListener(request::cancel)
request.start()
@@ -390,7 +385,6 @@ private class OkHttpImageFetcher private constructor(
override fun fetch(
url: String,
headers: Map<String, String>,
signal: CancellationSignal,
onSuccess: (NativeByteBuffer) -> Unit,
onFailure: (Exception) -> Unit,
@@ -403,7 +397,6 @@ private class OkHttpImageFetcher private constructor(
}
val requestBuilder = Request.Builder().url(url)
headers.forEach { (key, value) -> requestBuilder.addHeader(key, value) }
val call = client.newCall(requestBuilder.build())
signal.setOnCancelListener(call::cancel)