Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Verified Commit 2350cf23 authored by Romain Hunault's avatar Romain Hunault 🚴🏻
Browse files

fix(workspace): handle transient OIDC discovery failures

Classify discovery HTTP responses and retry only transient failures. Add tests covering transient retry, not-configured handling, and IOException retry exhaustion.
parent 800ead1c
Loading
Loading
Loading
Loading
Loading
+77 −15
Original line number Diff line number Diff line
@@ -19,9 +19,11 @@ package at.bitfire.davdroid.workspace

import at.bitfire.davdroid.log.Logger
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.delay
import kotlinx.coroutines.withContext
import okhttp3.OkHttpClient
import okhttp3.Request
import java.io.IOException

/**
 * Discovers the OIDC issuer for a Murena Workspace by following the unauthenticated redirect
@@ -38,6 +40,8 @@ object MurenaOidcDiscovery {

    private const val OIDC_LOGIN_PATH = "/apps/oidc_login/oidc"
    private const val KEYCLOAK_OIDC_PATH = "/protocol/openid-connect"
    private const val MAX_TRANSIENT_RETRIES = 2
    private const val RETRY_DELAY_MS = 300L

    /**
     * Outcome of an OIDC discovery attempt.
@@ -75,25 +79,83 @@ object MurenaOidcDiscovery {
            .followSslRedirects(false)
            .build()

        var attempt = 0
        var lastFailure: Exception? = null
        while (attempt <= MAX_TRANSIENT_RETRIES) {
            try {
            noRedirectClient.newCall(request).execute().use { response ->
                val location = response.header("Location")
                val (outcome, statusCode) = noRedirectClient.newCall(request).execute().use { response ->
                    classifyResponse(response.code, response.header("Location")) to response.code
                }

                when (outcome) {
                    is ResponseOutcome.Discovered ->
                        return@withContext Result.Discovered(descriptor.copy(oidcIssuer = outcome.issuer))
                    ResponseOutcome.NotConfigured ->
                        return@withContext Result.NotConfigured
                    is ResponseOutcome.Failed -> {
                        val shouldRetry = outcome.transient && attempt < MAX_TRANSIENT_RETRIES
                        if (shouldRetry) {
                            attempt += 1
                            lastFailure = IOException("OIDC discovery failed for ${descriptor.workspaceDomain} with HTTP $statusCode")
                            Logger.log.info("OIDC discovery transient HTTP failure for ${descriptor.workspaceDomain} (code=$statusCode), retry $attempt/$MAX_TRANSIENT_RETRIES")
                            delay(RETRY_DELAY_MS * attempt)
                            continue
                        }
                        val cause = IOException("OIDC discovery failed for ${descriptor.workspaceDomain} with HTTP $statusCode")
                        return@withContext Result.Failed(cause)
                    }
                }
            } catch (e: kotlinx.coroutines.CancellationException) {
                throw e
            } catch (e: Exception) {
                val isTransient = e is IOException
                val shouldRetry = isTransient && attempt < MAX_TRANSIENT_RETRIES
                if (shouldRetry) {
                    attempt += 1
                    lastFailure = e
                    Logger.log.info("OIDC discovery transient error for ${descriptor.workspaceDomain}: $e, retry $attempt/$MAX_TRANSIENT_RETRIES")
                    delay(RETRY_DELAY_MS * attempt)
                    continue
                }
                Logger.log.warning("OIDC discovery failed for ${descriptor.workspaceDomain}: $e")
                return@withContext Result.Failed(e)
            }
        }

        Result.Failed(lastFailure ?: IOException("OIDC discovery failed for ${descriptor.workspaceDomain} after retries"))
    }

    /**
     * Classification of a single discovery HTTP response, as returned by [classifyResponse].
     *
     * The retry loop in the discovery routine switches on this type: [Discovered] and
     * [NotConfigured] end discovery immediately, while [Failed] may trigger another attempt.
     */
    internal sealed class ResponseOutcome {
        /** The response yielded an OIDC issuer; [issuer] is the extracted issuer string. */
        data class Discovered(val issuer: String) : ResponseOutcome()

        /** The response indicates that no OIDC login is configured for the workspace. */
        object NotConfigured : ResponseOutcome()

        /**
         * The response could not be used.
         *
         * [transient] is `true` when the failure is considered temporary; the caller only
         * retries failures that are transient and only while retries remain.
         */
        data class Failed(val transient: Boolean) : ResponseOutcome()
    }

    /**
     * Maps the HTTP status code and `Location` header of a discovery response to a
     * [ResponseOutcome].
     *
     * Classification visible in this listing:
     * - 3xx with a `Location` header: [ResponseOutcome.Discovered] when [extractIssuer] returns
     *   an issuer for that location, otherwise [ResponseOutcome.NotConfigured].
     * - 401, 403, 404: [ResponseOutcome.NotConfigured].
     * - 408, 429 and 5xx: [ResponseOutcome.Failed] with `transient = true`.
     * - any other status: logged as unexpected and reported as a non-transient
     *   [ResponseOutcome.Failed].
     *
     * NOTE(review): this listing comes from a rendered diff, so lines of the previous
     * implementation (the ones referring to `Result.*`, `$url`, `descriptor` and the stray
     * `catch`) are interleaved with the current ones. A 3xx response without a `Location`
     * header presumably yields a non-transient [ResponseOutcome.Failed] — confirm against the
     * actual file.
     *
     * @param statusCode HTTP status code of the discovery response.
     * @param location value of the `Location` response header, or `null` when absent.
     * @return the classification of the response.
     */
    internal fun classifyResponse(statusCode: Int, location: String?): ResponseOutcome {
        if (statusCode in 300..399) {
            if (location == null) {
                    Logger.log.info("OIDC discovery: $url did not redirect, assuming OIDC not configured")
                    Result.NotConfigured
                } else {
                Logger.log.info("OIDC discovery: redirect status without Location header (code=$statusCode)")
                return ResponseOutcome.Failed(transient = false)
            }

            val issuer = extractIssuer(location)
                    if (issuer != null) {
                        Result.Discovered(descriptor.copy(oidcIssuer = issuer))
            return if (issuer != null) {
                ResponseOutcome.Discovered(issuer)
            } else {
                        Result.NotConfigured
                ResponseOutcome.NotConfigured
            }
        }

        if (statusCode == 401 || statusCode == 403 || statusCode == 404) {
            return ResponseOutcome.NotConfigured
        }
        } catch (e: Exception) {
            Logger.log.warning("OIDC discovery failed for ${descriptor.workspaceDomain}: $e")
            Result.Failed(e)

        if (statusCode == 408 || statusCode == 429 || statusCode in 500..599) {
            return ResponseOutcome.Failed(transient = true)
        }

        Logger.log.info("OIDC discovery: unexpected HTTP status $statusCode")
        return ResponseOutcome.Failed(transient = false)
    }

    /**
+82 −0
Original line number Diff line number Diff line
package at.bitfire.davdroid.workspace

import kotlinx.coroutines.runBlocking
import okhttp3.OkHttpClient
import okhttp3.Protocol
import okhttp3.Request
import okhttp3.Response
import okhttp3.ResponseBody.Companion.toResponseBody
import org.junit.Assert.assertEquals
import org.junit.Assert.assertNull
import org.junit.Assert.assertSame
import org.junit.Assert.assertTrue
import org.junit.Test
import java.io.IOException
import java.util.concurrent.atomic.AtomicInteger

class MurenaOidcDiscoveryTest {

@@ -75,4 +84,77 @@ class MurenaOidcDiscoveryTest {
            MurenaOidcDiscovery.Result.NotConfigured === MurenaOidcDiscovery.Result.NotConfigured
        )
    }

    // --- discover HTTP classification + retry ---

    /**
     * A single HTTP 503 must be retried: the first request fails, the second answers with a
     * 302 whose `Location` points at the OpenID Connect auth endpoint, and discovery is
     * expected to report the issuer derived from that location after exactly two requests.
     */
    @Test
    fun `discover retries transient HTTP failure then discovers issuer`() = runBlocking {
        val requestCount = AtomicInteger(0)
        val redirectTarget =
            "https://accounts.murena.io/auth/realms/murena/protocol/openid-connect/auth?client_id=murena.io"
        val httpClient = OkHttpClient.Builder()
            .addInterceptor { chain ->
                // Only the very first request fails; every later one redirects.
                when (requestCount.incrementAndGet()) {
                    1 -> response(chain.request(), 503)
                    else -> response(chain.request(), 302, redirectTarget)
                }
            }
            .build()

        val outcome = MurenaOidcDiscovery.discover(descriptor, httpClient)

        assertTrue(outcome is MurenaOidcDiscovery.Result.Discovered)
        assertEquals(2, requestCount.get())

        val discovered = outcome as MurenaOidcDiscovery.Result.Discovered
        assertEquals("https://accounts.murena.io/auth/realms/murena", discovered.descriptor.oidcIssuer)
    }

    /**
     * An HTTP 404 answer must be reported as [MurenaOidcDiscovery.Result.NotConfigured] after
     * exactly one request, i.e. without any retry.
     */
    @Test
    fun `discover does not retry non-transient not-configured HTTP status`() = runBlocking {
        val requestCount = AtomicInteger(0)
        val httpClient = OkHttpClient.Builder()
            .addInterceptor { chain ->
                requestCount.incrementAndGet()
                response(chain.request(), 404)
            }
            .build()

        val outcome = MurenaOidcDiscovery.discover(descriptor, httpClient)

        assertSame(MurenaOidcDiscovery.Result.NotConfigured, outcome)
        assertEquals(1, requestCount.get())
    }

    /**
     * When every request throws an [IOException], discovery must give up with
     * [MurenaOidcDiscovery.Result.Failed] after three requests in total.
     */
    @Test
    fun `discover retries transient IO failures and eventually returns Failed`() = runBlocking {
        val requestCount = AtomicInteger(0)
        val httpClient = OkHttpClient.Builder()
            .addInterceptor { _ ->
                // Count the attempt, then simulate a network-level failure.
                requestCount.incrementAndGet()
                throw IOException("timeout")
            }
            .build()

        val outcome = MurenaOidcDiscovery.discover(descriptor, httpClient)

        assertTrue(outcome is MurenaOidcDiscovery.Result.Failed)
        assertEquals(3, requestCount.get())
    }

    /**
     * Builds a canned HTTP/1.1 [Response] for [request] with the given status [code], the
     * message `HTTP <code>` and an empty body.
     *
     * A `Location` header is added only when [location] is non-null.
     */
    private fun response(request: Request, code: Int, location: String? = null): Response =
        Response.Builder()
            .request(request)
            .protocol(Protocol.HTTP_1_1)
            .code(code)
            .message("HTTP $code")
            .body("".toResponseBody())
            .apply {
                if (location != null) {
                    header("Location", location)
                }
            }
            .build()
}