Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit e4180dba authored by Fahim M. Choudhury's avatar Fahim M. Choudhury
Browse files

Merge branch '3743-improve-oauth-token-handling' into 'main'

feat: implement per-account token refresh synchronization

See merge request !184
parents 15a8ea4b e3837343
Loading
Loading
Loading
Loading
Loading
+154 −45
Original line number Diff line number Diff line
@@ -23,13 +23,12 @@ import android.content.Context
import at.bitfire.davdroid.BuildConfig
import at.bitfire.davdroid.network.HttpClient.HttpClientEntryPoint
import dagger.hilt.android.EntryPointAccessors
import kotlinx.coroutines.runBlocking
import net.openid.appauth.AuthState
import net.openid.appauth.AuthorizationException
import net.openid.appauth.ClientAuthentication
import org.jetbrains.annotations.Blocking
import java.util.concurrent.CompletableFuture
import java.util.concurrent.CompletionException
import java.util.concurrent.ConcurrentHashMap
import java.util.logging.Level
import java.util.logging.Logger

@@ -38,90 +37,200 @@ import java.util.logging.Logger
 */
object OidcTokenRefresher {
    private val logger: Logger = Logger.getGlobal()
    private val accountLocks = ConcurrentHashMap<String, Any>()
    private val ongoingRefreshOperations =
        ConcurrentHashMap<String, CompletableFuture<AuthState?>>()

    @JvmStatic
    @Throws(AuthorizationException::class)
    /**
     * Refreshes the current AuthState and updates it. Uses the current one if it's still valid,
     * or requests a new one if necessary.
     * Refreshes the AuthState for a given account if necessary.
     *
     * It will:
     * 1. Invoke the AppAuth library's authorization service to refresh tokens.
     * 2. Update AccountManager on successful refresh or log failures.
     * This method handles concurrent requests for the same account by ensuring that only one
     * token refresh operation is executed at a time. If a refresh is already in progress for an
     * account, subsequent requests for the same account will wait for the ongoing operation to
     * complete and receive its result. This prevents redundant network requests and potential
     * race conditions. Refresh operations for different accounts can run in parallel.
     *
     * **Threading:** This method uses [runBlocking] and therefore must **not** be
     * called from the Main/UI thread. It is annotated with `@Blocking` to signal
     * blocking behavior.
     * If the current tokens are still valid and do not need a refresh, the existing AuthState is
     * returned immediately without a network request.
     *
     * This method is synchronized / thread-safe so that it can be called for
     * multiple HTTP requests at the same time.
     *
     * Returns an updated AuthState if token refresh is successful;
     * Throws [AuthorizationException.TokenRequestErrors.INVALID_GRANT] for invalid grant or null otherwise.
     * @param context The application context.
     * @param account The account for which to refresh the token. If null, perform refresh for a
     * newly set up account.
     * @param getClientAuth A function that provides the [ClientAuthentication].
     * @param readAuthState A function that reads the current [AuthState].
     * @param writeAuthState A function that persists the updated [AuthState] to storage after a
     * successful refresh.
     * @return The updated [AuthState] if the refresh was successful or not needed, or `null` if an
     * error occurred.
     * @throws AuthorizationException.TokenRequestErrors.INVALID_GRANT if the refresh token is
     * invalid and re-authentication is required.
     */

    @JvmStatic
    @Blocking
    @Throws(AuthorizationException::class)
    fun refreshAuthState(
        context: Context,
        account: Account?,
        getClientAuth: () -> ClientAuthentication?,
        readAuthState: () -> AuthState?,
        writeAuthState: (AuthState) -> Unit
    ): AuthState? = synchronized(javaClass) {
        val authState = readAuthState() ?: return null
        // Use cached authState if possible
        if (authState.isAuthorized && authState.accessToken != null && !authState.needsTokenRefresh) {
    ): AuthState? {
        if (account == null) {
            return synchronized(this) {
                performTokenRefresh(context, null, getClientAuth, readAuthState, writeAuthState)
            }
        }

        val accountKey = generateAccountKey(account.name, account.type)
        val accountLock = accountLocks.computeIfAbsent(accountKey) { Any() }

        // Variables to communicate work outside the synchronized block
        var existingOperation: CompletableFuture<AuthState?>? = null
        var starterOperation: CompletableFuture<AuthState?>? = null
        var capturedAuthState: AuthState? = null

        // Short synchronized section: decide what to do
        synchronized(accountLock) {
            if (BuildConfig.DEBUG) {
                logger.finest("$account is using cached AuthState: ${authState.jsonSerializeString()}")
                logger.finest("[Thread-${Thread.currentThread().id}] Entered sync for $account")
            }
            return authState

            // If someone already started a refresh, capture that future and wait outside.
            existingOperation = ongoingRefreshOperations[accountKey]
            if (existingOperation != null) {
                if (BuildConfig.DEBUG) {
                    logger.info("[Thread-${Thread.currentThread().id}] Observed ongoing refresh for $account")
                }
                return@synchronized
            }

        // Check for AuthorizationException
            // No ongoing operation. Read auth state and decide if refresh is needed.
            val authState = readAuthState() ?: return@synchronized
            capturedAuthState = authState

            val authorizationException = authState.authorizationException
            if (authorizationException != null && isInvalidGrant(authorizationException)) {
                throw AuthorizationException.TokenRequestErrors.INVALID_GRANT
            }

        logger.info("$account is requesting fresh access token")
            if (authState.isAuthorized && authState.accessToken != null && !authState.needsTokenRefresh) {
                if (BuildConfig.DEBUG) {
                    logger.finest("[Thread-${Thread.currentThread().id}] Using cached AuthState for $account")
                }
                // no refresh needed. return current authState.
                starterOperation = null
                return@synchronized
            }

            // Need a refresh. Create and register a future, then perform refresh outside the lock.
            val newOperation = CompletableFuture<AuthState?>()
            ongoingRefreshOperations[accountKey] = newOperation
            starterOperation = newOperation
            if (BuildConfig.DEBUG) {
                logger.finest("[Thread-${Thread.currentThread().id}] Registered new refresh operation for $account")
            }
        }

        // If there was an existing operation, wait for it (outside synchronized).
        existingOperation?.let {
            return try {
                it.join()
            } catch (e: CompletionException) {
                logger.log(
                    Level.SEVERE,
                    "[Thread-${Thread.currentThread().id}] Ongoing refresh failed for $account",
                    e
                )
                null
            }
        }

        // If starterOp is null here it means we returned cached state inside synchronized.
        val refreshOperation = starterOperation ?: return capturedAuthState

        // Perform the actual refresh outside lock. Complete the future and clean up.
        try {
            val result = performTokenRefresh(
                context,
                account,
                getClientAuth,
                { capturedAuthState }, // pass the captured one to avoid re-reading under race
                writeAuthState
            )
            refreshOperation.complete(result)
            return result
        } catch (e: Throwable) {
            refreshOperation.completeExceptionally(e)
            throw e
        } finally {
            // remove only if the same future is still registered
            ongoingRefreshOperations.remove(accountKey, refreshOperation)
            if (BuildConfig.DEBUG) {
            logger.finest("AuthState before update = ${authState.jsonSerializeString()}")
                logger.finest("[Thread-${Thread.currentThread().id}] Cleanup done for $account")
            }
        }
    }

    private fun generateAccountKey(accountName: String, accountType: String) =
        "$accountName|$accountType"

    private fun performTokenRefresh(
        context: Context,
        account: Account?,
        getClientAuth: () -> ClientAuthentication?,
        readAuthState: () -> AuthState?,
        writeAuthState: (AuthState) -> Unit
    ): AuthState? {
        val authState = readAuthState() ?: return null
        val clientAuth = getClientAuth() ?: return null

        logger.info("[Thread-${Thread.currentThread().id}] $account is requesting fresh access token")

        val authService =
            EntryPointAccessors.fromApplication(context, HttpClientEntryPoint::class.java)
                .authorizationService()
        val authStateFuture = CompletableFuture<AuthState>()
        val authStateFuture = CompletableFuture<AuthState?>()

        return@synchronized try {
        try {
            authState.performActionWithFreshTokens(
                authService, clientAuth
                authService,
                clientAuth
            ) { accessToken, _, exception ->
                writeAuthState(authState)
                when {
                    accessToken != null -> {
                        logger.info("Token refreshed for $account")
                        if (BuildConfig.DEBUG) {
                            logger.finest("Updated authState = ${authState.jsonSerializeString()}")
                        }
                        // Persist only on success.
                        writeAuthState(authState)
                        logger.info("[Thread-${Thread.currentThread().id}] Token refreshed for $account")
                        authStateFuture.complete(authState)
                    }

                    exception != null -> {
                        logger.warning("[Thread-${Thread.currentThread().id}] Token refresh failed for $account: $exception")
                        authStateFuture.completeExceptionally(exception)
                    }
                    else -> {
                        // Unexpected: neither token nor exception. treat as failure.
                        authStateFuture.completeExceptionally(IllegalStateException("No token, no exception"))
                    }
                }
            }
            authStateFuture.join()
            return authStateFuture.join()
        } catch (e: CompletionException) {
            logger.log(Level.SEVERE, "Couldn't obtain access token", e)
            null
            logger.log(
                Level.SEVERE,
                "[Thread-${Thread.currentThread().id}] Couldn't obtain access token for $account",
                e
            )
            return null
        } finally {
            authService.dispose()
        }
    }

    // Checks whether the given AuthorizationException indicates an invalid grant (requires re-login).
    fun removeAccountLock(accountName: String, accountType: String) {
        val accountKey = generateAccountKey(accountName, accountType)
        accountLocks.remove(accountKey)
        ongoingRefreshOperations.remove(accountKey)
    }

    private fun isInvalidGrant(ex: AuthorizationException?): Boolean {
        val invalidGrant = AuthorizationException.TokenRequestErrors.INVALID_GRANT
        return ex?.code == invalidGrant.code && ex.error == invalidGrant.error
+6 −0
Original line number Diff line number Diff line
@@ -25,6 +25,7 @@ import at.bitfire.davdroid.settings.AccountSettings
import at.bitfire.davdroid.syncadapter.AccountUtils
import at.bitfire.davdroid.ui.signout.OpenIdEndSessionActivity
import at.bitfire.davdroid.util.AuthStatePrefUtils
import com.nextcloud.android.sso.OidcTokenRefresher
import com.owncloud.android.lib.common.OwnCloudClientManagerFactory

class AccountRemovedReceiver : BroadcastReceiver() {
@@ -39,6 +40,11 @@ class AccountRemovedReceiver : BroadcastReceiver() {
        val ownCloudClientManager = OwnCloudClientManagerFactory.getDefaultSingleton()
        ownCloudClientManager.removeClientForByName(accountName)

        // Clean up token refresh locks and ongoing operations
        intent.extras?.getString(AccountManager.KEY_ACCOUNT_TYPE)?.let { accountType ->
            OidcTokenRefresher.removeAccountLock(accountName, accountType)
        }

        clearOidcSession(
            intent = intent,
            context = context,