Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 07198260 authored by Abhishek Aggarwal's avatar Abhishek Aggarwal
Browse files

fix: prevent stale session refresh from overwriting logout state

parent 185567a2
Loading
Loading
Loading
Loading
+49 −17
Original line number Diff line number Diff line
@@ -54,6 +54,7 @@ class SessionStateHolder @Inject constructor(
    private val refreshMutex = Mutex()
    private val faultyTokenReportMutex = Mutex()
    private var inFlightRefresh: InFlightRefresh? = null
    private var refreshEpoch: Long = 0

    private var faultyTokenReportVersion: Long = 0
    private var latestFaultyTokenReport: FaultyTokenReportPayload? = null
@@ -103,7 +104,10 @@ class SessionStateHolder @Inject constructor(
            val currentRefresh = inFlightRefresh
            when {
                currentRefresh == null -> {
                    val newRefreshEpoch = refreshEpoch + 1
                    refreshEpoch = newRefreshEpoch
                    val newRefresh = InFlightRefresh(
                        epoch = newRefreshEpoch,
                        storesToReset = normalizedStoresToReset,
                        completion = CompletableDeferred(),
                    )
@@ -121,8 +125,10 @@ class SessionStateHolder @Inject constructor(
    }

    private suspend fun executeRefresh(refresh: InFlightRefresh) {
        val refreshFailure = normalizeRefreshFailure(runRefresh(refresh))
        publishRefreshFailureIfNeeded(refreshFailure)
        val refreshResult = runRefresh(refresh)
        val refreshFailure = normalizeRefreshFailure(refreshResult.exceptionOrNull())

        publishRefreshState(refresh, refreshResult.getOrNull(), refreshFailure)

        try {
            refreshMutex.withLock {
@@ -137,14 +143,34 @@ class SessionStateHolder @Inject constructor(
        rethrowRefreshFailure(refreshFailure)
    }

    private suspend fun runRefresh(refresh: InFlightRefresh): Throwable? {
    private suspend fun runRefresh(refresh: InFlightRefresh): Result<AuthRefreshOutcome> {
        return runCatching {
            val authRefreshOutcome = authRefreshRepository.refreshSessions(refresh.storesToReset)
            publishRefreshOutcome(authRefreshOutcome)
        }.exceptionOrNull()
            authRefreshRepository.refreshSessions(refresh.storesToReset)
        }
    }

    private suspend fun publishRefreshState(
        refresh: InFlightRefresh,
        authRefreshOutcome: AuthRefreshOutcome?,
        refreshFailure: Throwable?,
    ) {
        refreshMutex.lock()
        try {
            if (refresh.epoch != refreshEpoch) {
                return
            }

            when {
                authRefreshOutcome != null -> publishRefreshOutcomeLocked(authRefreshOutcome)
                refreshFailure is CancellationException -> Unit
                refreshFailure != null -> publishRefreshFailureLocked(refreshFailure)
            }
        } finally {
            refreshMutex.unlock()
        }
    }

    private suspend fun publishRefreshOutcome(authRefreshOutcome: AuthRefreshOutcome) {
    private suspend fun publishRefreshOutcomeLocked(authRefreshOutcome: AuthRefreshOutcome) {
        replaceFaultyTokenReport(authRefreshOutcome.faultyTokenReport)
        val authRefreshSnapshot = authRefreshOutcome.snapshot
        _activeSessions.value = authRefreshSnapshot.activeSessions
@@ -168,18 +194,12 @@ class SessionStateHolder @Inject constructor(
        }
    }

    private fun publishRefreshFailureIfNeeded(refreshFailure: Throwable?) {
        when (refreshFailure) {
            null -> Unit
            is CancellationException -> Unit
            else -> {
    private fun publishRefreshFailureLocked(refreshFailure: Throwable) {
        if (_authRefreshState.value !is AuthRefreshState.Completed) {
            _authRefreshState.value = AuthRefreshState.Failed
        }
        Timber.e(refreshFailure, "Auth refresh failed before publishing a snapshot")
    }
        }
    }

    private fun completeRefresh(refresh: InFlightRefresh, refreshFailure: Throwable?) {
        when (refreshFailure) {
@@ -217,6 +237,8 @@ class SessionStateHolder @Inject constructor(
    }

    override suspend fun clearLoadedSessions() {
        invalidateRefreshAndClearCurrent()

        clearFaultyTokenReport()
        _activeSessions.value = emptyList()
        _authRefreshState.value = AuthRefreshState.Pending
@@ -225,12 +247,21 @@ class SessionStateHolder @Inject constructor(

    override suspend fun markLoggedOut() {
        val loggedOutSnapshot = AuthRefreshSnapshot(emptyList())
        invalidateRefreshAndClearCurrent()

        clearFaultyTokenReport()
        _activeSessions.value = emptyList()
        _authRefreshState.value = AuthRefreshState.Completed(loggedOutSnapshot)
        _authRefreshSnapshot.value = loggedOutSnapshot
    }

    private suspend fun invalidateRefreshAndClearCurrent() {
        refreshMutex.withLock {
            refreshEpoch += 1
            inFlightRefresh = null
        }
    }

    private suspend fun replaceFaultyTokenReport(payload: FaultyTokenReportPayload?) {
        faultyTokenReportMutex.withLock {
            faultyTokenReportVersion += 1
@@ -271,6 +302,7 @@ class SessionStateHolder @Inject constructor(
    }

    private data class InFlightRefresh(
        val epoch: Long,
        val storesToReset: List<AuthStore>,
        val completion: CompletableDeferred<Unit>,
    )
+69 −0
Original line number Diff line number Diff line
@@ -122,6 +122,75 @@ class SessionStateHolderTest {
        assertThat(faultyTokenReporter.reportedPayloads).isEmpty()
    }

    @Test
    fun `markLoggedOut while refresh is in flight does not republish stale sessions`() = runTest {
        val authRefreshOutcome = AuthRefreshOutcome(
            snapshot = AuthRefreshSnapshot(
                entries = listOf(
                    AuthRefreshEntry(
                        store = AuthStore.PLAY_STORE,
                        result = AuthResult.Success(
                            AuthSession.PlayStoreSession(loginMode = PlayStoreLoginMode.GOOGLE)
                        ),
                    ),
                ),
            ),
        )
        val authRefreshRepository = BlockingAuthRefreshRepository(authRefreshOutcome)
        val sessionStateHolder = SessionStateHolder(
            authRefreshRepository = authRefreshRepository,
            faultyTokenReporter = FaultyTokenReporter { },
        )

        val refresh = async {
            sessionStateHolder.refreshSessions(listOf(AuthStore.PLAY_STORE))
        }
        authRefreshRepository.started.await()

        sessionStateHolder.markLoggedOut()
        authRefreshRepository.release.complete(Unit)
        refresh.await()

        assertThat(sessionStateHolder.activeSessions.value).isEmpty()
        assertThat(sessionStateHolder.authRefreshState.value)
            .isEqualTo(AuthRefreshState.Completed(AuthRefreshSnapshot(emptyList())))
        assertThat(sessionStateHolder.authRefreshSnapshot.value?.entries).isEmpty()
    }

    @Test
    fun `clearLoadedSessions while refresh is in flight does not republish stale sessions`() = runTest {
        val authRefreshOutcome = AuthRefreshOutcome(
            snapshot = AuthRefreshSnapshot(
                entries = listOf(
                    AuthRefreshEntry(
                        store = AuthStore.PLAY_STORE,
                        result = AuthResult.Success(
                            AuthSession.PlayStoreSession(loginMode = PlayStoreLoginMode.GOOGLE)
                        ),
                    ),
                ),
            ),
        )
        val authRefreshRepository = BlockingAuthRefreshRepository(authRefreshOutcome)
        val sessionStateHolder = SessionStateHolder(
            authRefreshRepository = authRefreshRepository,
            faultyTokenReporter = FaultyTokenReporter { },
        )

        val refresh = async {
            sessionStateHolder.refreshSessions(listOf(AuthStore.PLAY_STORE))
        }
        authRefreshRepository.started.await()

        sessionStateHolder.clearLoadedSessions()
        authRefreshRepository.release.complete(Unit)
        refresh.await()

        assertThat(sessionStateHolder.activeSessions.value).isEmpty()
        assertThat(sessionStateHolder.authRefreshState.value).isEqualTo(AuthRefreshState.Pending)
        assertThat(sessionStateHolder.authRefreshSnapshot.value).isNull()
    }

    @Test
    fun `reportFaultyTokenIfNeeded forwards latest domain payload once`() = runTest {
        val faultyTokenReporter = RecordingFaultyTokenReporter()