Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 6276d530 authored by Matthew Maurer's avatar Matthew Maurer Committed by Automerger Merge Worker
Browse files

Merge changes I9cbba1b8,I77f8697c,I6d4c296f am: 09a42197 am: f76c9616

Original change: https://android-review.googlesource.com/c/platform/packages/modules/DnsResolver/+/1847137

Change-Id: Icadc4ec15b08cf58ea579d1388fa226e3637c7ec
parents 5fcd7dc8 f76c9616
Loading
Loading
Loading
Loading

doh/config.rs

0 → 100644
+239 −0
Original line number Diff line number Diff line
/*
 * Copyright (C) 2021 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

//! Quiche Config support
//!
//! Quiche config objects are needed mutably for constructing a Quiche
//! connection object, but not when they are actually being used. As these
//! objects include a `SSL_CTX` which can be somewhat expensive and large when
//! using a certificate path, it can be beneficial to cache them.
//!
//! This module provides a caching layer for loading and constructing
//! these configurations.

use quiche::{h3, Result};
use std::collections::HashMap;
use std::ops::DerefMut;
use std::sync::{Arc, Mutex, RwLock, Weak};

type WeakConfig = Weak<Mutex<quiche::Config>>;

/// A cheaply clonable `quiche::Config`
#[derive(Clone)]
pub struct Config(Arc<Mutex<quiche::Config>>);

const MAX_INCOMING_BUFFER_SIZE_WHOLE: u64 = 10000000;
const MAX_INCOMING_BUFFER_SIZE_EACH: u64 = 1000000;
const MAX_CONCURRENT_STREAM_SIZE: u64 = 100;
/// Maximum datagram size we will accept.
pub const MAX_DATAGRAM_SIZE: usize = 1350;
/// How long with no packets before we assume a connection is dead, in milliseconds.
pub const QUICHE_IDLE_TIMEOUT_MS: u64 = 180000;

impl Config {
    fn from_weak(weak: &WeakConfig) -> Option<Self> {
        weak.upgrade().map(Self)
    }

    fn to_weak(&self) -> WeakConfig {
        Arc::downgrade(&self.0)
    }

    /// Construct a `Config` object from certificate path. If no path
    /// is provided, peers will not be verified.
    pub fn from_cert_path(cert_path: Option<&str>) -> Result<Self> {
        let mut config = quiche::Config::new(quiche::PROTOCOL_VERSION)?;
        config.set_application_protos(h3::APPLICATION_PROTOCOL)?;
        match cert_path {
            Some(path) => {
                config.verify_peer(true);
                config.load_verify_locations_from_directory(path)?;
            }
            None => config.verify_peer(false),
        }

        // Some of these configs are necessary, or the server can't respond the HTTP/3 request.
        config.set_max_idle_timeout(QUICHE_IDLE_TIMEOUT_MS);
        config.set_max_recv_udp_payload_size(MAX_DATAGRAM_SIZE);
        config.set_initial_max_data(MAX_INCOMING_BUFFER_SIZE_WHOLE);
        config.set_initial_max_stream_data_bidi_local(MAX_INCOMING_BUFFER_SIZE_EACH);
        config.set_initial_max_stream_data_bidi_remote(MAX_INCOMING_BUFFER_SIZE_EACH);
        config.set_initial_max_stream_data_uni(MAX_INCOMING_BUFFER_SIZE_EACH);
        config.set_initial_max_streams_bidi(MAX_CONCURRENT_STREAM_SIZE);
        config.set_initial_max_streams_uni(MAX_CONCURRENT_STREAM_SIZE);
        config.set_disable_active_migration(true);
        Ok(Self(Arc::new(Mutex::new(config))))
    }

    /// Take the underlying config, usable as `&mut quiche::Config` for use
    /// with `quiche::connect`.
    pub fn take(&mut self) -> impl DerefMut<Target = quiche::Config> + '_ {
        self.0.lock().unwrap()
    }
}

#[derive(Clone, Default)]
struct State {
    // Mapping from cert_path to configs
    path_to_config: HashMap<Option<String>, WeakConfig>,
    // Keep latest config alive to minimize reparsing when flapping
    // If more keep-alive is needed, replace with a LRU LinkedList
    latest: Option<Config>,
}

impl State {
    fn get_config(&self, cert_path: &Option<String>) -> Option<Config> {
        self.path_to_config.get(cert_path).and_then(Config::from_weak)
    }

    fn keep_alive(&mut self, config: Config) {
        self.latest = Some(config);
    }

    fn garbage_collect(&mut self) {
        self.path_to_config.retain(|_, config| config.strong_count() != 0)
    }
}

/// Cache of Quiche Config objects
///
/// Cloning this cache will create another handle to the same cache.
///
/// Loading a config object through this caching layer will only keep the
/// latest config loaded alive directly, but will still act as a cache
/// for any configurations still in use - if the returned `Config` is still
/// live, queries to `Cache` will not reconstruct it.
#[derive(Clone, Default)]
pub struct Cache {
    // Shared state amongst cache handles
    state: Arc<RwLock<State>>,
}

impl Cache {
    /// Creates a fresh empty cache
    pub fn new() -> Self {
        Default::default()
    }

    /// Behaves as `Config::from_cert_path`, but with a cache.
    /// If any object previously given out by this cache is still live,
    /// a duplicate will not be made.
    pub fn from_cert_path(&self, cert_path: &Option<String>) -> Result<Config> {
        // Fast path - read-only access to state retrieves config
        if let Some(config) = self.state.read().unwrap().get_config(cert_path) {
            return Ok(config);
        }

        // Unlocked, calculate config. If we have two racing attempts to load
        // the cert path, we'll arbitrate that in the next step, but this
        // makes sure loading a new cert path doesn't block other loads to
        // refresh connections.
        let config = Config::from_cert_path(cert_path.as_deref())?;

        let mut state = self.state.write().unwrap();
        // We now have exclusive access to the state.
        // If someone else calculated a config at the same time as us, we
        // want to discard ours and use theirs, since it will result in
        // less total memory used.
        if let Some(config) = state.get_config(cert_path) {
            return Ok(config);
        }

        // We have exclusive access and a fresh config. Install it into
        // the cache.
        state.keep_alive(config.clone());
        state.path_to_config.insert(cert_path.to_owned(), config.to_weak());
        Ok(config)
    }

    /// Purges any config paths which no longer point to a config entry.
    pub fn garbage_collect(&self) {
        self.state.write().unwrap().garbage_collect();
    }
}

#[test]
fn create_quiche_config() {
    assert!(Config::from_cert_path(None).is_ok(), "quiche config without cert creating failed");
    assert!(
        Config::from_cert_path(Some("data/local/tmp/")).is_ok(),
        "quiche config with cert creating failed"
    );
}

#[test]
fn shared_cache() {
    let cache_a = Cache::new();
    let cache_b = cache_a.clone();
    let config_a = cache_a.from_cert_path(&None).unwrap();
    assert_eq!(Arc::strong_count(&config_a.0), 2);
    let _config_b = cache_b.from_cert_path(&None).unwrap();
    assert_eq!(Arc::strong_count(&config_a.0), 3);
}

#[test]
fn lifetimes() {
    let cache = Cache::new();
    let config_none = cache.from_cert_path(&None).unwrap();
    let config_a = cache.from_cert_path(&Some("a".to_string())).unwrap();
    let config_b = cache.from_cert_path(&Some("b".to_string())).unwrap();
    // The first two we created should have a strong count of one - those handles are the only
    // thing keeping them alive.
    assert_eq!(Arc::strong_count(&config_none.0), 1);
    assert_eq!(Arc::strong_count(&config_a.0), 1);

    // If we try to get another handle we already have, it should be the same one.
    let _config_a2 = cache.from_cert_path(&Some("a".to_string())).unwrap();
    assert_eq!(Arc::strong_count(&config_a.0), 2);

    // config_b was most recently created, so it should have a keep-alive
    // inside the cache.
    assert_eq!(Arc::strong_count(&config_b.0), 2);

    // If we weaken one of the first handles, then drop it, the weak handle should break
    let config_none_weak = Config::to_weak(&config_none);
    assert_eq!(config_none_weak.strong_count(), 1);
    drop(config_none);
    assert_eq!(config_none_weak.strong_count(), 0);
    assert!(Config::from_weak(&config_none_weak).is_none());

    // If we weaken the most *recent* handle, it should keep working
    let config_b_weak = Config::to_weak(&config_b);
    assert_eq!(config_b_weak.strong_count(), 2);
    drop(config_b);
    assert_eq!(config_b_weak.strong_count(), 1);
    assert!(Config::from_weak(&config_b_weak).is_some());
    assert_eq!(config_b_weak.strong_count(), 1);

    // If we try to get a config which is still kept alive by the cache, we should get the same
    // one.
    let _config_b2 = cache.from_cert_path(&Some("b".to_string())).unwrap();
    assert_eq!(config_b_weak.strong_count(), 2);

    // We broke None, but "a" and "b" should still both be alive. Check that
    // this is still the case in the mapping after garbage collection.
    cache.garbage_collect();
    assert_eq!(cache.state.read().unwrap().path_to_config.len(), 2);
}

#[test]
fn quiche_connect() {
    use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4};
    let mut config = Config::from_cert_path(None).unwrap();
    let socket_addr = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 1), 42));
    let conn_id = quiche::ConnectionId::from_ref(&[]);
    quiche::connect(None, &conn_id, socket_addr, &mut config.take()).unwrap();
}
+73 −326

File changed.

Preview size limit exceeded, changes collapsed.

doh/encoding.rs

0 → 100644
+115 −0
Original line number Diff line number Diff line
/*
 * Copyright (C) 2021 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

//! Format DoH requests

use anyhow::{anyhow, Context, Result};
use quiche::h3;
use ring::rand::SecureRandom;
use url::Url;

pub type DnsRequest = Vec<quiche::h3::Header>;

const NS_T_AAAA: u8 = 28;
const NS_C_IN: u8 = 1;
// Used to randomly generate query prefix and query id.
const CHARSET: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZ\
                         abcdefghijklmnopqrstuvwxyz\
                         0123456789";

/// Produces a DNS query with randomized query ID and random 6-byte charset-legal prefix to produce
/// a request for a domain of the form:
/// ??????-dnsohttps-ds.metric.gstatic.com
#[rustfmt::skip]
pub fn probe_query() -> Result<String> {
    let mut rnd = [0; 8];
    ring::rand::SystemRandom::new().fill(&mut rnd).context("failed to generate probe rnd")?;
    let c = |byte| CHARSET[(byte as usize) % CHARSET.len()];
    let query = vec![
        rnd[6], rnd[7],  // [0-1]   query ID
        1,      0,       // [2-3]   flags; query[2] = 1 for recursion desired (RD).
        0,      1,       // [4-5]   QDCOUNT (number of queries)
        0,      0,       // [6-7]   ANCOUNT (number of answers)
        0,      0,       // [8-9]   NSCOUNT (number of name server records)
        0,      0,       // [10-11] ARCOUNT (number of additional records)
        19,     c(rnd[0]), c(rnd[1]), c(rnd[2]), c(rnd[3]), c(rnd[4]), c(rnd[5]), b'-', b'd', b'n',
        b's',   b'o',      b'h',      b't',      b't',      b'p',      b's',      b'-', b'd', b's',
        6,      b'm',      b'e',      b't',      b'r',      b'i',      b'c',      7,    b'g', b's',
        b't',   b'a',      b't',      b'i',      b'c',      3,         b'c',      b'o', b'm',
        0,                  // null terminator of FQDN (root TLD)
        0,      NS_T_AAAA,  // QTYPE
        0,      NS_C_IN     // QCLASS
    ];
    Ok(base64::encode_config(query, base64::URL_SAFE_NO_PAD))
}

/// Takes in a base64-encoded copy of a traditional DNS request and a
/// URL at which the DoH server is running and produces a set of HTTP/3 headers
/// corresponding to a DoH request for it.
pub fn dns_request(base64_query: &str, url: &Url) -> Result<DnsRequest> {
    let mut path = String::from(url.path());
    path.push_str("?dns=");
    path.push_str(base64_query);
    let req = vec![
        h3::Header::new(b":method", b"GET"),
        h3::Header::new(b":scheme", b"https"),
        h3::Header::new(
            b":authority",
            url.host_str().ok_or_else(|| anyhow!("failed to get host"))?.as_bytes(),
        ),
        h3::Header::new(b":path", path.as_bytes()),
        h3::Header::new(b"user-agent", b"quiche"),
        h3::Header::new(b"accept", b"application/dns-message"),
    ];

    Ok(req)
}

#[cfg(test)]
mod tests {
    use quiche::h3::NameValue;
    use url::Url;

    const PROBE_QUERY_SIZE: usize = 56;
    const H3_DNS_REQUEST_HEADER_SIZE: usize = 6;
    const LOCALHOST_URL: &str = "https://mylocal.com/dns-query";

    #[test]
    fn make_probe_query_and_request() {
        let probe_query = super::probe_query().unwrap();
        let url = Url::parse(LOCALHOST_URL).unwrap();
        let request = super::dns_request(&probe_query, &url).unwrap();
        // Verify H3 DNS request.
        assert_eq!(request.len(), H3_DNS_REQUEST_HEADER_SIZE);
        assert_eq!(request[0].name(), b":method");
        assert_eq!(request[0].value(), b"GET");
        assert_eq!(request[1].name(), b":scheme");
        assert_eq!(request[1].value(), b"https");
        assert_eq!(request[2].name(), b":authority");
        assert_eq!(request[2].value(), url.host_str().unwrap().as_bytes());
        assert_eq!(request[3].name(), b":path");
        let mut path = String::from(url.path());
        path.push_str("?dns=");
        path.push_str(&probe_query);
        assert_eq!(request[3].value(), path.as_bytes());
        assert_eq!(request[5].name(), b"accept");
        assert_eq!(request[5].value(), b"application/dns-message");

        // Verify DNS probe packet.
        let bytes = base64::decode_config(probe_query, base64::URL_SAFE_NO_PAD).unwrap();
        assert_eq!(bytes.len(), PROBE_QUERY_SIZE);
    }
}
+128 −4
Original line number Diff line number Diff line
@@ -17,19 +17,67 @@
//! C API for the DoH backend for the Android DnsResolver module.

use crate::boot_time::{timeout, BootTime, Duration};
use futures::FutureExt;
use libc::{c_char, int32_t, size_t, ssize_t, uint32_t, uint64_t};
use log::error;
use log::{error, warn};
use std::ffi::CString;
use std::net::{IpAddr, SocketAddr};
use std::ops::DerefMut;
use std::os::unix::io::RawFd;
use std::str::FromStr;
use std::sync::Mutex;
use std::{ptr, slice};
use tokio::runtime::Runtime;
use tokio::sync::oneshot;
use tokio::task;
use url::Url;

use super::DohDispatcher as Dispatcher;
use super::{DohCommand, Response, ServerInfo, TagSocketCallback, ValidationCallback, DOH_PORT};
use super::{DohCommand, Response, ServerInfo, SocketTagger, ValidationReporter, DOH_PORT};

pub type ValidationCallback =
    extern "C" fn(net_id: uint32_t, success: bool, ip_addr: *const c_char, host: *const c_char);
pub type TagSocketCallback = extern "C" fn(sock: RawFd);

fn wrap_validation_callback(validation_fn: ValidationCallback) -> ValidationReporter {
    Box::new(move |info: &ServerInfo, success: bool| {
        async move {
            let (ip_addr, domain) = match (
                CString::new(info.peer_addr.ip().to_string()),
                CString::new(info.domain.clone().unwrap_or_default()),
            ) {
                (Ok(ip_addr), Ok(domain)) => (ip_addr, domain),
                _ => {
                    error!("validation_callback bad input");
                    return;
                }
            };
            let netd_id = info.net_id;
            task::spawn_blocking(move || {
                validation_fn(netd_id, success, ip_addr.as_ptr(), domain.as_ptr())
            })
            .await
            .unwrap_or_else(|e| warn!("Validation function task failed: {}", e))
        }
        .boxed()
    })
}

fn wrap_tag_socket_callback(tag_socket_fn: TagSocketCallback) -> SocketTagger {
    use std::os::unix::io::AsRawFd;
    use std::sync::Arc;
    Arc::new(move |udp_socket: &std::net::UdpSocket| {
        let fd = udp_socket.as_raw_fd();
        async move {
            task::spawn_blocking(move || {
                tag_socket_fn(fd);
            })
            .await
            .unwrap_or_else(|e| warn!("Socket tag function task failed: {}", e))
        }
        .boxed()
    })
}

pub struct DohDispatcher(Mutex<Dispatcher>);

@@ -97,7 +145,10 @@ pub extern "C" fn doh_dispatcher_new(
    validation_fn: ValidationCallback,
    tag_socket_fn: TagSocketCallback,
) -> *mut DohDispatcher {
    match Dispatcher::new(validation_fn, tag_socket_fn) {
    match Dispatcher::new(
        wrap_validation_callback(validation_fn),
        wrap_tag_socket_callback(tag_socket_fn),
    ) {
        Ok(c) => Box::into_raw(Box::new(DohDispatcher(Mutex::new(c)))),
        Err(e) => {
            error!("doh_dispatcher_new: failed: {:?}", e);
@@ -158,7 +209,7 @@ pub unsafe extern "C" fn doh_net_new(
        }
    };

    let (url, ip_addr) = match (url::Url::parse(url), IpAddr::from_str(&ip_addr)) {
    let (url, ip_addr) = match (Url::parse(url), IpAddr::from_str(&ip_addr)) {
        (Ok(url), Ok(ip_addr)) => (url, ip_addr),
        _ => {
            error!("bad ip or url"); // Should not happen
@@ -262,3 +313,76 @@ pub extern "C" fn doh_net_delete(doh: &DohDispatcher, net_id: uint32_t) {
        error!("Failed to send the query: {:?}", e);
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    const TEST_NET_ID: u32 = 50;
    const LOOPBACK_ADDR: &str = "127.0.0.1:443";
    const LOCALHOST_URL: &str = "https://mylocal.com/dns-query";

    extern "C" fn success_cb(
        net_id: uint32_t,
        success: bool,
        ip_addr: *const c_char,
        host: *const c_char,
    ) {
        assert!(success);
        unsafe {
            assert_validation_info(net_id, ip_addr, host);
        }
    }

    extern "C" fn fail_cb(
        net_id: uint32_t,
        success: bool,
        ip_addr: *const c_char,
        host: *const c_char,
    ) {
        assert!(!success);
        unsafe {
            assert_validation_info(net_id, ip_addr, host);
        }
    }

    // # Safety
    // `ip_addr`, `host` are null terminated strings
    unsafe fn assert_validation_info(
        net_id: uint32_t,
        ip_addr: *const c_char,
        host: *const c_char,
    ) {
        assert_eq!(net_id, TEST_NET_ID);
        let ip_addr = std::ffi::CStr::from_ptr(ip_addr).to_str().unwrap();
        let expected_addr: SocketAddr = LOOPBACK_ADDR.parse().unwrap();
        assert_eq!(ip_addr, expected_addr.ip().to_string());
        let host = std::ffi::CStr::from_ptr(host).to_str().unwrap();
        assert_eq!(host, "");
    }

    #[tokio::test]
    async fn wrap_validation_callback_converts_correctly() {
        let info = ServerInfo {
            net_id: TEST_NET_ID,
            url: Url::parse(LOCALHOST_URL).unwrap(),
            peer_addr: LOOPBACK_ADDR.parse().unwrap(),
            domain: None,
            sk_mark: 0,
            cert_path: None,
        };

        wrap_validation_callback(success_cb)(&info, true).await;
        wrap_validation_callback(fail_cb)(&info, false).await;
    }

    extern "C" fn tag_socket_cb(raw_fd: RawFd) {
        assert!(raw_fd > 0)
    }

    #[tokio::test]
    async fn wrap_tag_socket_callback_converts_correctly() {
        let sock = std::net::UdpSocket::bind("127.0.0.1:0").unwrap();
        wrap_tag_socket_callback(tag_socket_cb)(&sock).await;
    }
}