Loading Android.bp +3 −2 Original line number Original line Diff line number Diff line Loading @@ -70,7 +70,7 @@ cc_library { // on system ABIs // on system ABIs stl: "libc++_static", stl: "libc++_static", static_libs: [ static_libs: [ "dnsresolver_aidl_interface-V2-ndk_platform", "dnsresolver_aidl_interface-ndk_platform", "libbase", "libbase", "libcrypto", "libcrypto", "libcutils", "libcutils", Loading Loading @@ -100,6 +100,7 @@ cc_library { debuggable: { debuggable: { cppflags: [ cppflags: [ "-DRESOLV_ALLOW_VERBOSE_LOGGING=1", "-DRESOLV_ALLOW_VERBOSE_LOGGING=1", "-DRESOLV_INJECT_CA_CERTIFICATE=1", ], ], }, }, }, }, Loading Loading @@ -171,6 +172,7 @@ cc_test { "libutils", "libutils", ], ], static_libs: [ static_libs: [ "dnsresolver_aidl_interface-cpp", "libgmock", "libgmock", "libnetd_test_dnsresponder", "libnetd_test_dnsresponder", "libnetd_test_metrics_listener", "libnetd_test_metrics_listener", Loading @@ -180,7 +182,6 @@ cc_test { "libnetdutils", "libnetdutils", "netd_aidl_interface-V2-cpp", "netd_aidl_interface-V2-cpp", "netd_event_listener_interface-V1-cpp", "netd_event_listener_interface-V1-cpp", "dnsresolver_aidl_interface-V2-cpp", ], ], compile_multilib: "both", compile_multilib: "both", sanitize: { sanitize: { Loading DnsResolverService.cpp +2 −43 Original line number Original line Diff line number Diff line Loading @@ -27,8 +27,6 @@ #include <android/binder_manager.h> #include <android/binder_manager.h> #include <android/binder_process.h> #include <android/binder_process.h> #include <netdutils/DumpWriter.h> #include <netdutils/DumpWriter.h> #include <netdutils/NetworkConstants.h> // SHA256_SIZE #include <openssl/base64.h> #include <private/android_filesystem_config.h> // AID_SYSTEM #include <private/android_filesystem_config.h> // AID_SYSTEM #include "DnsResolver.h" #include "DnsResolver.h" Loading Loading @@ -164,33 +162,6 @@ binder_status_t DnsResolverService::dump(int fd, const char**, uint32_t) { return ::ndk::ScopedAStatus(AStatus_fromExceptionCodeWithMessage(EX_SECURITY, err.c_str())); return ::ndk::ScopedAStatus(AStatus_fromExceptionCodeWithMessage(EX_SECURITY, err.c_str())); } } namespace { // Parse a base64 encoded string into a vector of bytes. // On failure, return an empty vector. static std::vector<uint8_t> parseBase64(const std::string& input) { std::vector<uint8_t> decoded; size_t out_len; if (EVP_DecodedLength(&out_len, input.size()) != 1) { return decoded; } // out_len is now an upper bound on the output length. decoded.resize(out_len); if (EVP_DecodeBase64(decoded.data(), &out_len, decoded.size(), reinterpret_cast<const uint8_t*>(input.data()), input.size()) == 1) { // Possibly shrink the vector if the actual output was smaller than the bound. decoded.resize(out_len); } else { decoded.clear(); } if (out_len != android::netdutils::SHA256_SIZE) { decoded.clear(); } return decoded; } } // namespace ::ndk::ScopedAStatus DnsResolverService::setResolverConfiguration( ::ndk::ScopedAStatus DnsResolverService::setResolverConfiguration( const ResolverParamsParcel& resolverParams) { const ResolverParamsParcel& resolverParams) { // Locking happens in PrivateDnsConfiguration and res_* functions. // Locking happens in PrivateDnsConfiguration and res_* functions. Loading @@ -203,21 +174,9 @@ static std::vector<uint8_t> parseBase64(const std::string& input) { resolverParams.sampleValiditySeconds, resolverParams.successThreshold, resolverParams.sampleValiditySeconds, resolverParams.successThreshold, resolverParams.minSamples, resolverParams.maxSamples, resolverParams.minSamples, resolverParams.maxSamples, resolverParams.baseTimeoutMsec, resolverParams.retryCount, resolverParams.baseTimeoutMsec, resolverParams.retryCount, resolverParams.tlsServers, resolverParams.tlsFingerprints); resolverParams.tlsName, resolverParams.tlsServers); std::set<std::vector<uint8_t>> decoded_fingerprints; for (const std::string& fingerprint : resolverParams.tlsFingerprints) { std::vector<uint8_t> decoded = parseBase64(fingerprint); if (decoded.empty()) { return ::ndk::ScopedAStatus(AStatus_fromServiceSpecificErrorWithMessage( EINVAL, "ResolverController error: bad fingerprint")); } decoded_fingerprints.emplace(decoded); } int res = gDnsResolv->resolverCtrl.setResolverConfiguration(resolverParams, decoded_fingerprints); int res = gDnsResolv->resolverCtrl.setResolverConfiguration(resolverParams); gResNetdCallbacks.log(entry.returns(res).withAutomaticDuration().toString().c_str()); gResNetdCallbacks.log(entry.returns(res).withAutomaticDuration().toString().c_str()); return statusFromErrcode(res); return statusFromErrcode(res); Loading DnsTlsServer.cpp +2 −3 Original line number Original line Diff line number Diff line Loading @@ -88,7 +88,7 @@ static bool operator ==(const sockaddr_storage& x, const sockaddr_storage& y) { namespace android { namespace android { namespace net { namespace net { // This comparison ignores ports and fingerprints. // This comparison ignores ports and certificates. bool AddressComparator::operator() (const DnsTlsServer& x, const DnsTlsServer& y) const { bool AddressComparator::operator() (const DnsTlsServer& x, const DnsTlsServer& y) const { if (x.ss.ss_family != y.ss.ss_family) { if (x.ss.ss_family != y.ss.ss_family) { return x.ss.ss_family < y.ss.ss_family; return x.ss.ss_family < y.ss.ss_family; Loading @@ -112,7 +112,6 @@ auto make_tie(const DnsTlsServer& s) { return std::tie( return std::tie( s.ss, s.ss, s.name, s.name, s.fingerprints, s.protocol s.protocol ); ); } } Loading @@ -126,7 +125,7 @@ bool DnsTlsServer::operator ==(const DnsTlsServer& other) const { } } bool DnsTlsServer::wasExplicitlyConfigured() const { bool DnsTlsServer::wasExplicitlyConfigured() const { return !name.empty() || !fingerprints.empty(); return !name.empty(); } } } // namespace net } // namespace net Loading DnsTlsServer.h +4 −6 Original line number Original line Diff line number Diff line Loading @@ -47,16 +47,14 @@ struct DnsTlsServer { // The server location, including IP and port. // The server location, including IP and port. sockaddr_storage ss = {}; sockaddr_storage ss = {}; // A set of SHA256 public key fingerprints. If this set is nonempty, the server // must present a self-consistent certificate chain that contains a certificate // whose public key matches one of these fingerprints. Otherwise, the client will // terminate the connection. std::set<std::vector<uint8_t>> fingerprints; // The server's hostname. If this string is nonempty, the server must present a // The server's hostname. If this string is nonempty, the server must present a // certificate that indicates this name and has a valid chain to a trusted root CA. // certificate that indicates this name and has a valid chain to a trusted root CA. std::string name; std::string name; // The certificate of the CA that signed the server's certificate. // It is used to store temporary test CA certificate for internal tests. std::string certificate; // Placeholder. More protocols might be defined in the future. // Placeholder. More protocols might be defined in the future. int protocol = IPPROTO_TCP; int protocol = IPPROTO_TCP; Loading DnsTlsSocket.cpp +33 −73 Original line number Original line Diff line number Diff line Loading @@ -39,6 +39,11 @@ #include "netdutils/SocketOption.h" #include "netdutils/SocketOption.h" #include "private/android_filesystem_config.h" // AID_DNS #include "private/android_filesystem_config.h" // AID_DNS // NOTE: Inject CA certificate for internal testing -- do NOT enable in production builds #ifndef RESOLV_INJECT_CA_CERTIFICATE #define RESOLV_INJECT_CA_CERTIFICATE 0 #endif namespace android { namespace android { using netdutils::enableSockopt; using netdutils::enableSockopt; Loading @@ -51,7 +56,6 @@ namespace net { namespace { namespace { constexpr const char kCaCertDir[] = "/system/etc/security/cacerts"; constexpr const char kCaCertDir[] = "/system/etc/security/cacerts"; constexpr size_t SHA256_SIZE = SHA256_DIGEST_LENGTH; int waitForReading(int fd) { int waitForReading(int fd) { struct pollfd fds = { .fd = fd, .events = POLLIN }; struct pollfd fds = { .fd = fd, .events = POLLIN }; Loading Loading @@ -121,31 +125,27 @@ Status DnsTlsSocket::tcpConnect() { return netdutils::status::ok; return netdutils::status::ok; } } bool getSPKIDigest(const X509* cert, std::vector<uint8_t>* out) { bool DnsTlsSocket::setTestCaCertificate() { int spki_len = i2d_X509_PUBKEY(X509_get_X509_PUBKEY(cert), nullptr); bssl::UniquePtr<BIO> bio( unsigned char spki[spki_len]; BIO_new_mem_buf(mServer.certificate.data(), mServer.certificate.size())); unsigned char* temp = spki; bssl::UniquePtr<X509> cert(PEM_read_bio_X509(bio.get(), nullptr, nullptr, nullptr)); if (spki_len != i2d_X509_PUBKEY(X509_get_X509_PUBKEY(cert), &temp)) { if (!cert) { LOG(WARNING) << "SPKI length mismatch"; LOG(ERROR) << "Failed to read cert"; return false; } out->resize(SHA256_SIZE); unsigned int digest_len = 0; int ret = EVP_Digest(spki, spki_len, out->data(), &digest_len, EVP_sha256(), nullptr); if (ret != 1) { LOG(WARNING) << "Server cert digest extraction failed"; return false; return false; } } if (digest_len != out->size()) { LOG(WARNING) << "Wrong digest length: " << digest_len; X509_STORE* cert_store = SSL_CTX_get_cert_store(mSslCtx.get()); if (!X509_STORE_add_cert(cert_store, cert.get())) { LOG(ERROR) << "Failed to add cert"; return false; return false; } } return true; return true; } } // TODO: Try to use static sSslCtx instead of mSslCtx bool DnsTlsSocket::initialize() { bool DnsTlsSocket::initialize() { // This method should only be called once, at the beginning, so locking should be // This method is called every time when a new SSL connection is created. // unnecessary. This lock only serves to help catch bugs in code that calls this method. // This lock only serves to help catch bugs in code that calls this method. std::lock_guard guard(mLock); std::lock_guard guard(mLock); if (mSslCtx) { if (mSslCtx) { // This is a bug in the caller. // This is a bug in the caller. Loading @@ -156,13 +156,22 @@ bool DnsTlsSocket::initialize() { return false; return false; } } // Load system CA certs for hostname verification. // Load system CA certs from CAPath for hostname verification. // // // For discussion of alternative, sustainable approaches see b/71909242. // For discussion of alternative, sustainable approaches see b/71909242. if (RESOLV_INJECT_CA_CERTIFICATE && !mServer.certificate.empty()) { // Inject test CA certs from ResolverParamsParcel.caCertificate for internal testing. LOG(WARNING) << "test CA certificate is valid"; if (!setTestCaCertificate()) { LOG(ERROR) << "Failed to set test CA certificate"; return false; } } else { if (SSL_CTX_load_verify_locations(mSslCtx.get(), nullptr, kCaCertDir) != 1) { if (SSL_CTX_load_verify_locations(mSslCtx.get(), nullptr, kCaCertDir) != 1) { LOG(ERROR) << "Failed to load CA cert dir: " << kCaCertDir; LOG(ERROR) << "Failed to load CA cert dir: " << kCaCertDir; return false; return false; } } } // Enable TLS false start // Enable TLS false start SSL_CTX_set_false_start_allowed_without_alpn(mSslCtx.get(), 1); SSL_CTX_set_false_start_allowed_without_alpn(mSslCtx.get(), 1); Loading Loading @@ -210,8 +219,9 @@ bssl::UniquePtr<SSL> DnsTlsSocket::sslConnect(int fd) { } } if (!mServer.name.empty()) { if (!mServer.name.empty()) { LOG(VERBOSE) << "Checking DNS over TLS hostname = " << mServer.name.c_str(); if (SSL_set_tlsext_host_name(ssl.get(), mServer.name.c_str()) != 1) { if (SSL_set_tlsext_host_name(ssl.get(), mServer.name.c_str()) != 1) { LOG(ERROR) << "ailed to set SNI to " << mServer.name; LOG(ERROR) << "Failed to set SNI to " << mServer.name; return nullptr; return nullptr; } } X509_VERIFY_PARAM* param = SSL_get0_param(ssl.get()); X509_VERIFY_PARAM* param = SSL_get0_param(ssl.get()); Loading Loading @@ -258,56 +268,6 @@ bssl::UniquePtr<SSL> DnsTlsSocket::sslConnect(int fd) { } } } } // TODO: Call SSL_shutdown before discarding the session if validation fails. if (!mServer.fingerprints.empty()) { LOG(DEBUG) << "Checking DNS over TLS fingerprint"; // We only care that the chain is internally self-consistent, not that // it chains to a trusted root, so we can ignore some kinds of errors. // TODO: Add a CA root verification mode that respects these errors. int verify_result = SSL_get_verify_result(ssl.get()); switch (verify_result) { case X509_V_OK: case X509_V_ERR_DEPTH_ZERO_SELF_SIGNED_CERT: case X509_V_ERR_SELF_SIGNED_CERT_IN_CHAIN: case X509_V_ERR_CERT_UNTRUSTED: break; default: LOG(WARNING) << "Invalid certificate chain, error " << verify_result; return nullptr; } STACK_OF(X509) *chain = SSL_get_peer_cert_chain(ssl.get()); if (!chain) { LOG(WARNING) << "Server has null certificate"; return nullptr; } // Chain and its contents are owned by ssl, so we don't need to free explicitly. bool matched = false; for (size_t i = 0; i < sk_X509_num(chain); ++i) { // This appears to be O(N^2), but there doesn't seem to be a straightforward // way to walk a STACK_OF nondestructively in linear time. X509* cert = sk_X509_value(chain, i); std::vector<uint8_t> digest; if (!getSPKIDigest(cert, &digest)) { LOG(ERROR) << "Digest computation failed"; return nullptr; } if (mServer.fingerprints.count(digest) > 0) { matched = true; break; } } if (!matched) { LOG(WARNING) << "No matching fingerprint"; return nullptr; } LOG(DEBUG) << "DNS over TLS fingerprint is correct"; } LOG(DEBUG) << mMark << " handshake complete"; LOG(DEBUG) << mMark << " handshake complete"; return ssl; return ssl; Loading Loading
Android.bp +3 −2 Original line number Original line Diff line number Diff line Loading @@ -70,7 +70,7 @@ cc_library { // on system ABIs // on system ABIs stl: "libc++_static", stl: "libc++_static", static_libs: [ static_libs: [ "dnsresolver_aidl_interface-V2-ndk_platform", "dnsresolver_aidl_interface-ndk_platform", "libbase", "libbase", "libcrypto", "libcrypto", "libcutils", "libcutils", Loading Loading @@ -100,6 +100,7 @@ cc_library { debuggable: { debuggable: { cppflags: [ cppflags: [ "-DRESOLV_ALLOW_VERBOSE_LOGGING=1", "-DRESOLV_ALLOW_VERBOSE_LOGGING=1", "-DRESOLV_INJECT_CA_CERTIFICATE=1", ], ], }, }, }, }, Loading Loading @@ -171,6 +172,7 @@ cc_test { "libutils", "libutils", ], ], static_libs: [ static_libs: [ "dnsresolver_aidl_interface-cpp", "libgmock", "libgmock", "libnetd_test_dnsresponder", "libnetd_test_dnsresponder", "libnetd_test_metrics_listener", "libnetd_test_metrics_listener", Loading @@ -180,7 +182,6 @@ cc_test { "libnetdutils", "libnetdutils", "netd_aidl_interface-V2-cpp", "netd_aidl_interface-V2-cpp", "netd_event_listener_interface-V1-cpp", "netd_event_listener_interface-V1-cpp", "dnsresolver_aidl_interface-V2-cpp", ], ], compile_multilib: "both", compile_multilib: "both", sanitize: { sanitize: { Loading
DnsResolverService.cpp +2 −43 Original line number Original line Diff line number Diff line Loading @@ -27,8 +27,6 @@ #include <android/binder_manager.h> #include <android/binder_manager.h> #include <android/binder_process.h> #include <android/binder_process.h> #include <netdutils/DumpWriter.h> #include <netdutils/DumpWriter.h> #include <netdutils/NetworkConstants.h> // SHA256_SIZE #include <openssl/base64.h> #include <private/android_filesystem_config.h> // AID_SYSTEM #include <private/android_filesystem_config.h> // AID_SYSTEM #include "DnsResolver.h" #include "DnsResolver.h" Loading Loading @@ -164,33 +162,6 @@ binder_status_t DnsResolverService::dump(int fd, const char**, uint32_t) { return ::ndk::ScopedAStatus(AStatus_fromExceptionCodeWithMessage(EX_SECURITY, err.c_str())); return ::ndk::ScopedAStatus(AStatus_fromExceptionCodeWithMessage(EX_SECURITY, err.c_str())); } } namespace { // Parse a base64 encoded string into a vector of bytes. // On failure, return an empty vector. static std::vector<uint8_t> parseBase64(const std::string& input) { std::vector<uint8_t> decoded; size_t out_len; if (EVP_DecodedLength(&out_len, input.size()) != 1) { return decoded; } // out_len is now an upper bound on the output length. decoded.resize(out_len); if (EVP_DecodeBase64(decoded.data(), &out_len, decoded.size(), reinterpret_cast<const uint8_t*>(input.data()), input.size()) == 1) { // Possibly shrink the vector if the actual output was smaller than the bound. decoded.resize(out_len); } else { decoded.clear(); } if (out_len != android::netdutils::SHA256_SIZE) { decoded.clear(); } return decoded; } } // namespace ::ndk::ScopedAStatus DnsResolverService::setResolverConfiguration( ::ndk::ScopedAStatus DnsResolverService::setResolverConfiguration( const ResolverParamsParcel& resolverParams) { const ResolverParamsParcel& resolverParams) { // Locking happens in PrivateDnsConfiguration and res_* functions. // Locking happens in PrivateDnsConfiguration and res_* functions. Loading @@ -203,21 +174,9 @@ static std::vector<uint8_t> parseBase64(const std::string& input) { resolverParams.sampleValiditySeconds, resolverParams.successThreshold, resolverParams.sampleValiditySeconds, resolverParams.successThreshold, resolverParams.minSamples, resolverParams.maxSamples, resolverParams.minSamples, resolverParams.maxSamples, resolverParams.baseTimeoutMsec, resolverParams.retryCount, resolverParams.baseTimeoutMsec, resolverParams.retryCount, resolverParams.tlsServers, resolverParams.tlsFingerprints); resolverParams.tlsName, resolverParams.tlsServers); std::set<std::vector<uint8_t>> decoded_fingerprints; for (const std::string& fingerprint : resolverParams.tlsFingerprints) { std::vector<uint8_t> decoded = parseBase64(fingerprint); if (decoded.empty()) { return ::ndk::ScopedAStatus(AStatus_fromServiceSpecificErrorWithMessage( EINVAL, "ResolverController error: bad fingerprint")); } decoded_fingerprints.emplace(decoded); } int res = gDnsResolv->resolverCtrl.setResolverConfiguration(resolverParams, decoded_fingerprints); int res = gDnsResolv->resolverCtrl.setResolverConfiguration(resolverParams); gResNetdCallbacks.log(entry.returns(res).withAutomaticDuration().toString().c_str()); gResNetdCallbacks.log(entry.returns(res).withAutomaticDuration().toString().c_str()); return statusFromErrcode(res); return statusFromErrcode(res); Loading
DnsTlsServer.cpp +2 −3 Original line number Original line Diff line number Diff line Loading @@ -88,7 +88,7 @@ static bool operator ==(const sockaddr_storage& x, const sockaddr_storage& y) { namespace android { namespace android { namespace net { namespace net { // This comparison ignores ports and fingerprints. // This comparison ignores ports and certificates. bool AddressComparator::operator() (const DnsTlsServer& x, const DnsTlsServer& y) const { bool AddressComparator::operator() (const DnsTlsServer& x, const DnsTlsServer& y) const { if (x.ss.ss_family != y.ss.ss_family) { if (x.ss.ss_family != y.ss.ss_family) { return x.ss.ss_family < y.ss.ss_family; return x.ss.ss_family < y.ss.ss_family; Loading @@ -112,7 +112,6 @@ auto make_tie(const DnsTlsServer& s) { return std::tie( return std::tie( s.ss, s.ss, s.name, s.name, s.fingerprints, s.protocol s.protocol ); ); } } Loading @@ -126,7 +125,7 @@ bool DnsTlsServer::operator ==(const DnsTlsServer& other) const { } } bool DnsTlsServer::wasExplicitlyConfigured() const { bool DnsTlsServer::wasExplicitlyConfigured() const { return !name.empty() || !fingerprints.empty(); return !name.empty(); } } } // namespace net } // namespace net Loading
DnsTlsServer.h +4 −6 Original line number Original line Diff line number Diff line Loading @@ -47,16 +47,14 @@ struct DnsTlsServer { // The server location, including IP and port. // The server location, including IP and port. sockaddr_storage ss = {}; sockaddr_storage ss = {}; // A set of SHA256 public key fingerprints. If this set is nonempty, the server // must present a self-consistent certificate chain that contains a certificate // whose public key matches one of these fingerprints. Otherwise, the client will // terminate the connection. std::set<std::vector<uint8_t>> fingerprints; // The server's hostname. If this string is nonempty, the server must present a // The server's hostname. If this string is nonempty, the server must present a // certificate that indicates this name and has a valid chain to a trusted root CA. // certificate that indicates this name and has a valid chain to a trusted root CA. std::string name; std::string name; // The certificate of the CA that signed the server's certificate. // It is used to store temporary test CA certificate for internal tests. std::string certificate; // Placeholder. More protocols might be defined in the future. // Placeholder. More protocols might be defined in the future. int protocol = IPPROTO_TCP; int protocol = IPPROTO_TCP; Loading
DnsTlsSocket.cpp +33 −73 Original line number Original line Diff line number Diff line Loading @@ -39,6 +39,11 @@ #include "netdutils/SocketOption.h" #include "netdutils/SocketOption.h" #include "private/android_filesystem_config.h" // AID_DNS #include "private/android_filesystem_config.h" // AID_DNS // NOTE: Inject CA certificate for internal testing -- do NOT enable in production builds #ifndef RESOLV_INJECT_CA_CERTIFICATE #define RESOLV_INJECT_CA_CERTIFICATE 0 #endif namespace android { namespace android { using netdutils::enableSockopt; using netdutils::enableSockopt; Loading @@ -51,7 +56,6 @@ namespace net { namespace { namespace { constexpr const char kCaCertDir[] = "/system/etc/security/cacerts"; constexpr const char kCaCertDir[] = "/system/etc/security/cacerts"; constexpr size_t SHA256_SIZE = SHA256_DIGEST_LENGTH; int waitForReading(int fd) { int waitForReading(int fd) { struct pollfd fds = { .fd = fd, .events = POLLIN }; struct pollfd fds = { .fd = fd, .events = POLLIN }; Loading Loading @@ -121,31 +125,27 @@ Status DnsTlsSocket::tcpConnect() { return netdutils::status::ok; return netdutils::status::ok; } } bool getSPKIDigest(const X509* cert, std::vector<uint8_t>* out) { bool DnsTlsSocket::setTestCaCertificate() { int spki_len = i2d_X509_PUBKEY(X509_get_X509_PUBKEY(cert), nullptr); bssl::UniquePtr<BIO> bio( unsigned char spki[spki_len]; BIO_new_mem_buf(mServer.certificate.data(), mServer.certificate.size())); unsigned char* temp = spki; bssl::UniquePtr<X509> cert(PEM_read_bio_X509(bio.get(), nullptr, nullptr, nullptr)); if (spki_len != i2d_X509_PUBKEY(X509_get_X509_PUBKEY(cert), &temp)) { if (!cert) { LOG(WARNING) << "SPKI length mismatch"; LOG(ERROR) << "Failed to read cert"; return false; } out->resize(SHA256_SIZE); unsigned int digest_len = 0; int ret = EVP_Digest(spki, spki_len, out->data(), &digest_len, EVP_sha256(), nullptr); if (ret != 1) { LOG(WARNING) << "Server cert digest extraction failed"; return false; return false; } } if (digest_len != out->size()) { LOG(WARNING) << "Wrong digest length: " << digest_len; X509_STORE* cert_store = SSL_CTX_get_cert_store(mSslCtx.get()); if (!X509_STORE_add_cert(cert_store, cert.get())) { LOG(ERROR) << "Failed to add cert"; return false; return false; } } return true; return true; } } // TODO: Try to use static sSslCtx instead of mSslCtx bool DnsTlsSocket::initialize() { bool DnsTlsSocket::initialize() { // This method should only be called once, at the beginning, so locking should be // This method is called every time when a new SSL connection is created. // unnecessary. This lock only serves to help catch bugs in code that calls this method. // This lock only serves to help catch bugs in code that calls this method. std::lock_guard guard(mLock); std::lock_guard guard(mLock); if (mSslCtx) { if (mSslCtx) { // This is a bug in the caller. // This is a bug in the caller. Loading @@ -156,13 +156,22 @@ bool DnsTlsSocket::initialize() { return false; return false; } } // Load system CA certs for hostname verification. // Load system CA certs from CAPath for hostname verification. // // // For discussion of alternative, sustainable approaches see b/71909242. // For discussion of alternative, sustainable approaches see b/71909242. if (RESOLV_INJECT_CA_CERTIFICATE && !mServer.certificate.empty()) { // Inject test CA certs from ResolverParamsParcel.caCertificate for internal testing. LOG(WARNING) << "test CA certificate is valid"; if (!setTestCaCertificate()) { LOG(ERROR) << "Failed to set test CA certificate"; return false; } } else { if (SSL_CTX_load_verify_locations(mSslCtx.get(), nullptr, kCaCertDir) != 1) { if (SSL_CTX_load_verify_locations(mSslCtx.get(), nullptr, kCaCertDir) != 1) { LOG(ERROR) << "Failed to load CA cert dir: " << kCaCertDir; LOG(ERROR) << "Failed to load CA cert dir: " << kCaCertDir; return false; return false; } } } // Enable TLS false start // Enable TLS false start SSL_CTX_set_false_start_allowed_without_alpn(mSslCtx.get(), 1); SSL_CTX_set_false_start_allowed_without_alpn(mSslCtx.get(), 1); Loading Loading @@ -210,8 +219,9 @@ bssl::UniquePtr<SSL> DnsTlsSocket::sslConnect(int fd) { } } if (!mServer.name.empty()) { if (!mServer.name.empty()) { LOG(VERBOSE) << "Checking DNS over TLS hostname = " << mServer.name.c_str(); if (SSL_set_tlsext_host_name(ssl.get(), mServer.name.c_str()) != 1) { if (SSL_set_tlsext_host_name(ssl.get(), mServer.name.c_str()) != 1) { LOG(ERROR) << "ailed to set SNI to " << mServer.name; LOG(ERROR) << "Failed to set SNI to " << mServer.name; return nullptr; return nullptr; } } X509_VERIFY_PARAM* param = SSL_get0_param(ssl.get()); X509_VERIFY_PARAM* param = SSL_get0_param(ssl.get()); Loading Loading @@ -258,56 +268,6 @@ bssl::UniquePtr<SSL> DnsTlsSocket::sslConnect(int fd) { } } } } // TODO: Call SSL_shutdown before discarding the session if validation fails. if (!mServer.fingerprints.empty()) { LOG(DEBUG) << "Checking DNS over TLS fingerprint"; // We only care that the chain is internally self-consistent, not that // it chains to a trusted root, so we can ignore some kinds of errors. // TODO: Add a CA root verification mode that respects these errors. int verify_result = SSL_get_verify_result(ssl.get()); switch (verify_result) { case X509_V_OK: case X509_V_ERR_DEPTH_ZERO_SELF_SIGNED_CERT: case X509_V_ERR_SELF_SIGNED_CERT_IN_CHAIN: case X509_V_ERR_CERT_UNTRUSTED: break; default: LOG(WARNING) << "Invalid certificate chain, error " << verify_result; return nullptr; } STACK_OF(X509) *chain = SSL_get_peer_cert_chain(ssl.get()); if (!chain) { LOG(WARNING) << "Server has null certificate"; return nullptr; } // Chain and its contents are owned by ssl, so we don't need to free explicitly. bool matched = false; for (size_t i = 0; i < sk_X509_num(chain); ++i) { // This appears to be O(N^2), but there doesn't seem to be a straightforward // way to walk a STACK_OF nondestructively in linear time. X509* cert = sk_X509_value(chain, i); std::vector<uint8_t> digest; if (!getSPKIDigest(cert, &digest)) { LOG(ERROR) << "Digest computation failed"; return nullptr; } if (mServer.fingerprints.count(digest) > 0) { matched = true; break; } } if (!matched) { LOG(WARNING) << "No matching fingerprint"; return nullptr; } LOG(DEBUG) << "DNS over TLS fingerprint is correct"; } LOG(DEBUG) << mMark << " handshake complete"; LOG(DEBUG) << mMark << " handshake complete"; return ssl; return ssl; Loading