Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Unverified Commit ef1d90f7 authored by Ricki Hirner's avatar Ricki Hirner Committed by GitHub
Browse files

Cache SSLSocketFactories to allow okhttp HTTPS connection reuse (#1942)

* Reuse CustomCertManager

- Update bitfire-cert4android to 75cc6913fd
- Refactor HttpClientBuilder to use Optional for customTrustManager and customHostnameVerifier
- Add CustomCertManagerModule for dependency injection

* Implement connection security manager for HTTP client

- Introduce `ConnectionSecurityManager` and `ConnectionSecurityContext` classes
- Refactor `HttpClientBuilder` to use the new security manager for SSL context setup

* [WIP] Cache SSLContext by certificate alias

- Add context cache using Guava CacheBuilder
- Cache SSLContext in getContext method

* Update comments in HttpClientBuilder.kt for clarity

* Update ConnectionSecurityManager to use SSLSocketFactory caching

* Refactor socket factory caching logic for better clarity

* Add tests

* Refactor socket factory cache to store only SSLSocketFactory

* Minor changes
- Change socketFactoryCache to use LinkedHashMap instead of ConcurrentHashMap
- Update cache key handling to use String? instead of Optional<String>

* Add tests for caching

* Add logging

* Indenting

* Minor simplification

* Fix tests
parent 5efcbfc5
Loading
Loading
Loading
Loading
+64 −0
Original line number Diff line number Diff line
/*
 * Copyright © All Contributors. See LICENSE and AUTHORS in the root directory for details.
 */

package at.bitfire.davdroid.di

import android.content.Context
import at.bitfire.cert4android.CustomCertManager
import at.bitfire.cert4android.CustomCertStore
import at.bitfire.cert4android.SettingsProvider
import at.bitfire.davdroid.BuildConfig
import at.bitfire.davdroid.settings.Settings
import at.bitfire.davdroid.settings.SettingsManager
import at.bitfire.davdroid.ui.ForegroundTracker
import dagger.Module
import dagger.Provides
import dagger.hilt.InstallIn
import dagger.hilt.android.qualifiers.ApplicationContext
import dagger.hilt.components.SingletonComponent
import okhttp3.internal.tls.OkHostnameVerifier
import java.util.Optional
import javax.inject.Singleton

@Module
@InstallIn(SingletonComponent::class)
/**
 * cert4android integration module
 */
class CustomCertManagerModule {

    @Provides
    @Singleton
    fun customCertManager(
        @ApplicationContext context: Context,
        settings: SettingsManager
    ): Optional<CustomCertManager> =
        if (BuildConfig.allowCustomCerts)
            Optional.of(CustomCertManager(
                certStore = CustomCertStore.getInstance(context),
                settings = object : SettingsProvider {

                    override val appInForeground: Boolean
                        get() = ForegroundTracker.inForeground.value

                    override val trustSystemCerts: Boolean
                        get() = !settings.getBoolean(Settings.DISTRUST_SYSTEM_CERTIFICATES)

                }
            ))
        else
            Optional.empty()

    @Provides
    @Singleton
    fun customHostnameVerifier(
        customCertManager: Optional<CustomCertManager>
    ): Optional<CustomCertManager.HostnameVerifier> =
        if (BuildConfig.allowCustomCerts && customCertManager.isPresent) {
            val hostnameVerifier = customCertManager.get().HostnameVerifier(OkHostnameVerifier)
            Optional.of(hostnameVerifier)
        } else
            Optional.empty()

}
 No newline at end of file
+24 −0
Original line number Diff line number Diff line
/*
 * Copyright © All Contributors. See LICENSE and AUTHORS in the root directory for details.
 */

package at.bitfire.davdroid.network

import javax.net.ssl.HostnameVerifier
import javax.net.ssl.SSLSocketFactory
import javax.net.ssl.X509TrustManager

/**
 * Holds information that shall be used to create TLS connections.
 *
 * @param sslSocketFactory  the socket factory that shall be used
 * @param trustManager      the trust manager that shall be used
 * @param hostnameVerifier  the hostname verifier that shall be used
 * @param disableHttp2      whether HTTP/2 shall be disabled
 */
class ConnectionSecurityContext(
    val sslSocketFactory: SSLSocketFactory?,
    val trustManager: X509TrustManager?,
    val hostnameVerifier: HostnameVerifier?,
    val disableHttp2: Boolean
)
 No newline at end of file
+108 −0
Original line number Diff line number Diff line
/*
 * Copyright © All Contributors. See LICENSE and AUTHORS in the root directory for details.
 */

package at.bitfire.davdroid.network

import androidx.annotation.VisibleForTesting
import at.bitfire.cert4android.CustomCertManager
import java.lang.ref.SoftReference
import java.security.KeyStore
import java.util.Optional
import java.util.logging.Logger
import javax.inject.Inject
import javax.inject.Singleton
import javax.net.ssl.SSLContext
import javax.net.ssl.SSLSocketFactory
import javax.net.ssl.TrustManagerFactory
import javax.net.ssl.X509TrustManager
import kotlin.jvm.optionals.getOrNull

/**
 * Caching provider for [ConnectionSecurityContext].
 */
@Singleton
class ConnectionSecurityManager @Inject constructor(
    private val customHostnameVerifier: Optional<CustomCertManager.HostnameVerifier>,
    private val customTrustManager: Optional<CustomCertManager>,
    private val keyManagerFactory: ClientCertKeyManager.Factory,
    private val logger: Logger
) {

    /**
     * Maps client certificate aliases (or `null` if no client authentication is used) to their SSLSocketFactory.
     * Uses soft references for the values so that they can be garbage-collected when not used anymore.
     *
     * Not thread-safe, access must be synchronized by caller.
     */
    private val socketFactoryCache: MutableMap<String?, SoftReference<SSLSocketFactory>> =
        LinkedHashMap(2)    // usually not more than: one for no client certificates + one for a certain certificate alias

    /**
     * The default TrustManager to use for connections. If [customTrustManager] provides a value, that value is
     * used. Otherwise, the platform's default trust manager is used.
     */
    private val trustManager by lazy { customTrustManager.getOrNull() ?: defaultTrustManager() }

    /**
     * Provides the [ConnectionSecurityContext] for a given [certificateAlias].
     *
     * Uses [socketFactoryCache] to cache the entries (per [certificateAlias]).
     *
     * @param certificateAlias  alias of the client certificate that shall be used for authentication (`null` for none)
     * @return the connection security context
     */
    fun getContext(certificateAlias: String?): ConnectionSecurityContext {
        /* We only need a custom socket factory for
           - client certificates and/or
           - when cert4android is active (= there's a custom trustManager). */
        val socketFactory = if (certificateAlias != null || customTrustManager.isPresent)
            getSocketFactory(certificateAlias)
        else
            null

        return ConnectionSecurityContext(
            sslSocketFactory = socketFactory,
            trustManager = if (socketFactory != null) trustManager else null,   // when there's a customTrustManager, there's always a socketFactory, too
            hostnameVerifier = customHostnameVerifier.getOrNull(),
            disableHttp2 = certificateAlias != null
        )
    }

    @VisibleForTesting
    internal fun getSocketFactory(certificateAlias: String?): SSLSocketFactory = synchronized(socketFactoryCache) {
        // look up cache first
        val cachedFactory = socketFactoryCache[certificateAlias]?.get()
        if (cachedFactory != null) {
            logger.fine("Using cached SSLSocketFactory (certificateAlias=$certificateAlias)")
            return cachedFactory
        } else
            logger.fine("Creating new SSLSocketFactory (certificateAlias=$certificateAlias)")
        // no cached value, calculate and store into cache

        // when a client certificate alias is given, create and use the respective ClientKeyManager
        val clientKeyManager = certificateAlias?.let { keyManagerFactory.create(it) }

        // create SSLContext that provides the SSLSocketFactory
        val sslContext = SSLContext.getInstance("TLS").apply {
            init(
                /* km = */ clientKeyManager?.let { arrayOf(it) },
                /* tm = */ arrayOf(trustManager),
                /* random = */ null /* default RNG */
            )
        }

        // cache reference and return socket factory
        return sslContext.socketFactory.also { socketFactory ->
            socketFactoryCache[certificateAlias] = SoftReference(socketFactory)
        }
    }

    @VisibleForTesting
    internal fun defaultTrustManager(): X509TrustManager {
        val factory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm())
        factory.init(null as KeyStore?)
        return factory.trustManagers.filterIsInstance<X509TrustManager>().first()
    }

}
 No newline at end of file
+16 −73
Original line number Diff line number Diff line
@@ -5,22 +5,16 @@
package at.bitfire.davdroid.network

import android.accounts.Account
import android.content.Context
import androidx.annotation.WorkerThread
import at.bitfire.cert4android.CustomCertManager
import at.bitfire.cert4android.CustomCertStore
import at.bitfire.dav4jvm.okhttp.BasicDigestAuthHandler
import at.bitfire.dav4jvm.okhttp.UrlUtils
import at.bitfire.davdroid.BuildConfig
import at.bitfire.davdroid.di.IoDispatcher
import at.bitfire.davdroid.settings.AccountSettings
import at.bitfire.davdroid.settings.Credentials
import at.bitfire.davdroid.settings.Settings
import at.bitfire.davdroid.settings.SettingsManager
import at.bitfire.davdroid.ui.ForegroundTracker
import com.google.common.net.HttpHeaders
import com.google.errorprone.annotations.MustBeClosed
import dagger.hilt.android.qualifiers.ApplicationContext
import io.ktor.client.HttpClient
import io.ktor.client.engine.okhttp.OkHttp
import io.ktor.client.plugins.contentnegotiation.ContentNegotiation
@@ -35,20 +29,13 @@ import okhttp3.Interceptor
import okhttp3.OkHttpClient
import okhttp3.Protocol
import okhttp3.brotli.BrotliInterceptor
import okhttp3.internal.tls.OkHostnameVerifier
import okhttp3.logging.HttpLoggingInterceptor
import java.net.InetSocketAddress
import java.net.Proxy
import java.security.KeyStore
import java.util.concurrent.TimeUnit
import java.util.logging.Level
import java.util.logging.Logger
import javax.inject.Inject
import javax.net.ssl.HostnameVerifier
import javax.net.ssl.KeyManager
import javax.net.ssl.SSLContext
import javax.net.ssl.TrustManagerFactory
import javax.net.ssl.X509TrustManager

/**
 * Builder for the HTTP client.
@@ -60,10 +47,9 @@ import javax.net.ssl.X509TrustManager
 */
class HttpClientBuilder @Inject constructor(
    private val accountSettingsFactory: AccountSettings.Factory,
    @ApplicationContext private val context: Context,
    private val connectionSecurityManager: ConnectionSecurityManager,
    defaultLogger: Logger,
    @IoDispatcher private val ioDispatcher: CoroutineDispatcher,
    private val keyManagerFactory: ClientCertKeyManager.Factory,
    private val oAuthInterceptorFactory: OAuthInterceptor.Factory,
    private val settingsManager: SettingsManager
) {
@@ -283,72 +269,29 @@ class HttpClientBuilder @Inject constructor(
    }

    private fun buildConnectionSecurity(okBuilder: OkHttpClient.Builder) {
        // allow cleartext and TLS 1.2+
        // Allow cleartext and TLS 1.2+
        okBuilder.connectionSpecs(listOf(
            ConnectionSpec.CLEARTEXT,
            ConnectionSpec.MODERN_TLS
        ))

        // client certificate
        val clientKeyManager: KeyManager? = certificateAlias?.let { alias ->
            try {
                val manager = keyManagerFactory.create(alias)
                logger.fine("Using certificate $alias for authentication")
        /* Set SSLSocketFactory, TrustManager and HostnameVerifier (if needed).
         * We shouldn't create these things here, because
         *
         * a. it involves complex logic that should be the responsibility of a dedicated class, and
         * b. we need to cache the instances because otherwise, HTTPS connection are not used
         *    correctly. okhttp checks the SSLSocketFactory/TrustManager of a connection in the pool
         *    and creates a new connection when they have changed. */
        val securityContext = connectionSecurityManager.getContext(certificateAlias)

                // HTTP/2 doesn't support client certificates (yet)
                // see https://datatracker.ietf.org/doc/draft-ietf-httpbis-secondary-server-certs/
        if (securityContext.disableHttp2)
            okBuilder.protocols(listOf(Protocol.HTTP_1_1))

                manager
            } catch (e: IllegalArgumentException) {
                logger.log(Level.SEVERE, "Couldn't create KeyManager for certificate $alias", e)
                null
            }
        }

        // select trust manager and hostname verifier depending on whether custom certificates are allowed
        val customTrustManager: X509TrustManager?
        val customHostnameVerifier: HostnameVerifier?

        if (BuildConfig.allowCustomCerts) {
            // use cert4android for custom certificate handling
            customTrustManager = CustomCertManager(
                certStore = CustomCertStore.getInstance(context),
                trustSystemCerts = !settingsManager.getBoolean(Settings.DISTRUST_SYSTEM_CERTIFICATES),
                appInForeground = ForegroundTracker.inForeground
            )
            // allow users to accept certificates with wrong host names
            customHostnameVerifier = customTrustManager.HostnameVerifier(OkHostnameVerifier)

        } else {
            // no custom certificates, use default trust manager and hostname verifier
            customTrustManager = null
            customHostnameVerifier = null
        }

        // change settings only if we have at least only one custom component
        if (clientKeyManager != null || customTrustManager != null) {
            val trustManager = customTrustManager ?: defaultTrustManager()

            // use trust manager and client key manager (if defined) for TLS connections
            val sslContext = SSLContext.getInstance("TLS")
            sslContext.init(
                /* km = */ if (clientKeyManager != null) arrayOf(clientKeyManager) else null,
                /* tm = */ arrayOf(trustManager),
                /* random = */ null
            )
            okBuilder.sslSocketFactory(sslContext.socketFactory, trustManager)
        }

        // also add the custom hostname verifier (if defined)
        if (customHostnameVerifier != null)
            okBuilder.hostnameVerifier(customHostnameVerifier)
    }
        if (securityContext.sslSocketFactory != null && securityContext.trustManager != null)
            okBuilder.sslSocketFactory(securityContext.sslSocketFactory, securityContext.trustManager)

    private fun defaultTrustManager(): X509TrustManager {
        val factory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm())
        factory.init(null as KeyStore?)
        return factory.trustManagers.filterIsInstance<X509TrustManager>().first()
        if (securityContext.hostnameVerifier != null)
            okBuilder.hostnameVerifier(securityContext.hostnameVerifier)
    }

    private fun buildProxy(okBuilder: OkHttpClient.Builder) {
+210 −0
Original line number Diff line number Diff line
/*
 * Copyright © All Contributors. See LICENSE and AUTHORS in the root directory for details.
 */

package at.bitfire.davdroid.network

import at.bitfire.cert4android.CustomCertManager
import io.mockk.mockk
import io.mockk.verify
import org.junit.Assert.assertEquals
import org.junit.Assert.assertFalse
import org.junit.Assert.assertNotNull
import org.junit.Assert.assertNull
import org.junit.Assert.assertTrue
import org.junit.Test
import java.util.Optional
import java.util.logging.Logger

class ConnectionSecurityManagerTest {

    private val logger = Logger.getLogger(javaClass.name)

    @Test
    fun `getContext(no customTrustManager, no client certificate)`() {
        val manager = ConnectionSecurityManager(
            customTrustManager = Optional.empty(),
            customHostnameVerifier = Optional.empty(),
            keyManagerFactory = mockk(),
            logger = logger
        )
        val context = manager.getContext(null)
        assertNull(context.sslSocketFactory)
        assertNull(context.trustManager)
        assertNull(context.hostnameVerifier)
        assertFalse(context.disableHttp2)
    }

    @Test
    fun `getContext(no customTrustManager, with client certificate)`() {
        val kmf: ClientCertKeyManager.Factory = mockk(relaxed = true)
        val manager = ConnectionSecurityManager(
            customTrustManager = Optional.empty(),
            customHostnameVerifier = Optional.empty(),
            keyManagerFactory = kmf,
            logger = logger
        )
        val context = manager.getContext("alias")
        assertNotNull(context.sslSocketFactory)
        assertEquals(manager.defaultTrustManager().javaClass, context.trustManager?.javaClass)
        assertNull(context.hostnameVerifier)
        assertTrue(context.disableHttp2)
        verify(exactly = 1) {
            kmf.create("alias")
        }
    }

    @Test
    fun `getContext(with customTrustManager, no client certificate)`() {
        val customTrustManager: CustomCertManager = mockk()
        val customHostnameVerifier: CustomCertManager.HostnameVerifier = mockk()
        val manager = ConnectionSecurityManager(
            customTrustManager = Optional.of(customTrustManager),
            customHostnameVerifier = Optional.of(customHostnameVerifier),
            keyManagerFactory = mockk(),
            logger = logger
        )
        val context = manager.getContext(null)
        assertNotNull(context.sslSocketFactory)
        assertEquals(customTrustManager, context.trustManager)
        assertEquals(customHostnameVerifier, context.hostnameVerifier)
        assertFalse(context.disableHttp2)
    }

    @Test
    fun `getContext(with customTrustManager, with client certificate)`() {
        val customTrustManager: CustomCertManager = mockk()
        val customHostnameVerifier: CustomCertManager.HostnameVerifier = mockk()
        val kmf: ClientCertKeyManager.Factory = mockk(relaxed = true)
        val manager = ConnectionSecurityManager(
            customTrustManager = Optional.of(customTrustManager),
            customHostnameVerifier = Optional.of(customHostnameVerifier),
            keyManagerFactory = kmf,
            logger = logger
        )
        val context = manager.getContext("alias")
        assertNotNull(context.sslSocketFactory)
        assertEquals(customTrustManager, context.trustManager)
        assertEquals(customHostnameVerifier, context.hostnameVerifier)
        assertTrue(context.disableHttp2)
        verify(exactly = 1) {
            kmf.create("alias")
        }
    }

    @Test
    fun `getSocketFactory(no customTrustManager, no client certificate)`() {
        val manager = ConnectionSecurityManager(
            customTrustManager = Optional.empty(),
            customHostnameVerifier = Optional.empty(),
            keyManagerFactory = mockk(),
            logger = logger
        )
        val socketFactory = manager.getSocketFactory(null)
        assertNotNull(socketFactory.javaClass)
    }

    @Test
    fun `getSocketFactory(no customTrustManager, with client certificate)`() {
        val kmf: ClientCertKeyManager.Factory = mockk(relaxed = true)
        val manager = ConnectionSecurityManager(
            customTrustManager = Optional.empty(),
            customHostnameVerifier = Optional.empty(),
            keyManagerFactory = kmf,
            logger = logger
        )
        val socketFactory = manager.getSocketFactory("alias")
        assertNotNull(socketFactory.javaClass)
        verify(exactly = 1) {
            kmf.create("alias")
        }
    }

    @Test
    fun `getSocketFactory(with customTrustManager, no client certificate)`() {
        val customTrustManager: CustomCertManager = mockk()
        val customHostnameVerifier: CustomCertManager.HostnameVerifier = mockk()
        val manager = ConnectionSecurityManager(
            customTrustManager = Optional.of(customTrustManager),
            customHostnameVerifier = Optional.of(customHostnameVerifier),
            keyManagerFactory = mockk(),
            logger = logger
        )
        val socketFactory = manager.getSocketFactory(null)
        assertNotNull(socketFactory.javaClass)
    }

    @Test
    fun `getSocketFactory(with customTrustManager, with client certificate)`() {
        val customTrustManager: CustomCertManager = mockk()
        val customHostnameVerifier: CustomCertManager.HostnameVerifier = mockk()
        val kmf: ClientCertKeyManager.Factory = mockk(relaxed = true)
        val manager = ConnectionSecurityManager(
            customTrustManager = Optional.of(customTrustManager),
            customHostnameVerifier = Optional.of(customHostnameVerifier),
            keyManagerFactory = kmf,
            logger = logger
        )
        val socketFactory = manager.getSocketFactory("alias")
        assertNotNull(socketFactory.javaClass)
        verify(exactly = 1) {
            kmf.create("alias")
        }
    }

    @Test
    fun `getContext caches socket factories for same certificateAlias`() {
        val kmf: ClientCertKeyManager.Factory = mockk(relaxed = true)
        val manager = ConnectionSecurityManager(
            customTrustManager = Optional.empty(),
            customHostnameVerifier = Optional.empty(),
            keyManagerFactory = kmf,
            logger = logger
        )
        
        // First call - should create new socket factory
        val context1 = manager.getContext("alias")
        assertNotNull(context1.sslSocketFactory)
        
        // Second call with same alias - should return cached socket factory
        val context2 = manager.getContext("alias")
        assertNotNull(context2.sslSocketFactory)
        
        // Both should return the same socket factory instance
        assert(context1.sslSocketFactory === context2.sslSocketFactory)
        
        // Should only create key manager once
        verify(exactly = 1) {
            kmf.create("alias")
        }
    }

    @Test
    fun `getContext does not cache socket factories for different certificateAlias`() {
        val kmf: ClientCertKeyManager.Factory = mockk(relaxed = true)
        val manager = ConnectionSecurityManager(
            customTrustManager = Optional.empty(),
            customHostnameVerifier = Optional.empty(),
            keyManagerFactory = kmf,
            logger = logger
        )
        
        // Get context for first alias
        val context1 = manager.getContext("alias1")
        assertNotNull(context1.sslSocketFactory)
        
        // Get context for different alias
        val context2 = manager.getContext("alias2")
        assertNotNull(context2.sslSocketFactory)
        
        // Should be different instances
        assert(context1.sslSocketFactory !== context2.sslSocketFactory)
        
        // Should create key managers for both aliases
        verify(exactly = 1) {
            kmf.create("alias1")
            kmf.create("alias2")
        }
    }

}
 No newline at end of file
Loading