Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 3a6b54a1 authored by Matthew Maurer's avatar Matthew Maurer
Browse files

DoH: Migrate C callback logic to ffi module

Bug: 202081046
Change-Id: I9cbba1b8181ce5c1b0276aebb64fbec194cef6a5
parent 30836670
Loading
Loading
Loading
Loading
+42 −144
Original line number Diff line number Diff line
@@ -17,18 +17,17 @@
//! DoH backend for the Android DnsResolver module.

use anyhow::{anyhow, bail, Context, Result};
use futures::future::join_all;
use futures::future::{join_all, BoxFuture};
use futures::stream::FuturesUnordered;
use futures::StreamExt;
use libc::{c_char, int32_t, uint32_t};
use log::{debug, error, info, trace, warn};
use quiche::h3;
use ring::rand::SecureRandom;
use std::collections::HashMap;
use std::ffi::CString;
use std::net::SocketAddr;
use std::os::unix::io::{AsRawFd, RawFd};
use std::pin::Pin;
use std::sync::Arc;
use tokio::net::UdpSocket;
use tokio::runtime::{Builder, Runtime};
use tokio::sync::{mpsc, oneshot};
@@ -46,15 +45,15 @@ use config::Config;
const MAX_BUFFERED_CMD_SIZE: usize = 400;
const DOH_PORT: u16 = 443;

type ValidationReporter = Box<dyn Fn(&ServerInfo, bool) -> BoxFuture<()> + Send + Sync>;
type SocketTagger = Arc<dyn Fn(&std::net::UdpSocket) -> BoxFuture<()> + Send + Sync>;

type SCID = [u8; quiche::MAX_CONN_ID_LEN];
type Base64Query = String;
type CmdSender = mpsc::Sender<DohCommand>;
type CmdReceiver = mpsc::Receiver<DohCommand>;
type QueryResponder = oneshot::Sender<Response>;
type DnsRequest = Vec<quiche::h3::Header>;
type ValidationCallback =
    extern "C" fn(net_id: uint32_t, success: bool, ip_addr: *const c_char, host: *const c_char);
type TagSocketCallback = extern "C" fn(sock: int32_t);

#[derive(Eq, PartialEq, Debug)]
enum QueryError {
@@ -108,6 +107,15 @@ enum ConnectionState {
    Error,
}

impl ConnectionState {
    fn is_connected(&self) -> bool {
        matches!(*self, Self::Connected { .. })
    }
    fn is_error(&self) -> bool {
        matches!(*self, Self::Error)
    }
}

enum H3Result {
    Data { data: Vec<u8> },
    Finished,
@@ -124,10 +132,7 @@ pub struct DohDispatcher {

// DoH dispatcher
impl DohDispatcher {
    fn new(
        validation_fn: ValidationCallback,
        tag_socket_fn: TagSocketCallback,
    ) -> Result<DohDispatcher> {
    fn new(validation: ValidationReporter, tag_socket: SocketTagger) -> Result<DohDispatcher> {
        let (cmd_sender, cmd_receiver) = mpsc::channel::<DohCommand>(MAX_BUFFERED_CMD_SIZE);
        let runtime = Builder::new_multi_thread()
            .worker_threads(2)
@@ -135,7 +140,7 @@ impl DohDispatcher {
            .thread_name("doh-handler")
            .build()
            .expect("Failed to create tokio runtime");
        let join_handle = runtime.spawn(doh_handler(cmd_receiver, validation_fn, tag_socket_fn));
        let join_handle = runtime.spawn(doh_handler(cmd_receiver, validation, tag_socket));
        Ok(DohDispatcher { cmd_sender, join_handle, runtime })
    }

@@ -159,15 +164,11 @@ struct DohConnection {
    state: ConnectionState,
    pending_queries: Vec<(DnsRequest, QueryResponder, BootTime)>,
    cached_session: Option<Vec<u8>>,
    tag_socket_fn: TagSocketCallback,
    tag_socket: SocketTagger,
}

impl DohConnection {
    fn new(
        info: &ServerInfo,
        config: Config,
        tag_socket_fn: TagSocketCallback,
    ) -> Result<DohConnection> {
    fn new(info: &ServerInfo, config: Config, tag_socket: SocketTagger) -> Result<DohConnection> {
        let mut scid = [0; quiche::MAX_CONN_ID_LEN];
        ring::rand::SystemRandom::new().fill(&mut scid).context("failed to generate scid")?;
        Ok(DohConnection {
@@ -177,15 +178,18 @@ impl DohConnection {
            state: ConnectionState::Idle,
            pending_queries: Vec::new(),
            cached_session: None,
            tag_socket_fn,
            tag_socket,
        })
    }

    fn state_to_connecting(&mut self) -> Result<()> {
    async fn state_to_connecting(&mut self) -> Result<()> {
        if self.state.is_error() {
            self.state_to_idle();
        }
        self.state = match self.state {
            ConnectionState::Idle => {
                let udp_sk_std = make_doh_udp_socket(self.info.peer_addr, self.info.sk_mark)?;
                (self.tag_socket_fn)(udp_sk_std.as_raw_fd());
                (self.tag_socket)(&udp_sk_std).await;
                let udp_sk = UdpSocket::from_std(udp_sk_std)?;
                let connid = quiche::ConnectionId::from_ref(&self.scid);
                debug!("init the connection for Network {}", self.info.net_id);
@@ -206,10 +210,7 @@ impl DohConnection {
                    expired_time: None,
                }
            }
            ConnectionState::Error => {
                self.state_to_idle();
                return self.state_to_connecting();
            }
            ConnectionState::Error => panic!("state_to_idle did not transition"),
            ConnectionState::Connecting { .. } => return Ok(()),
            ConnectionState::Connected { .. } => {
                panic!("Invalid state transition to Connecting state!")
@@ -351,7 +352,7 @@ impl DohConnection {
        if matches!(self.state, ConnectionState::Connected { .. }) {
            return Ok(());
        }
        self.state_to_connecting()?;
        self.state_to_connecting().await?;
        debug!("connecting to Network {}", self.info.net_id);

        let (quic_conn, udp_sk, expired_time) = match &mut self.state {
@@ -636,34 +637,10 @@ async fn flush_tx(
    Ok(())
}

async fn report_private_dns_validation(
    info: &ServerInfo,
    state: &ConnectionState,
    validation_fn: ValidationCallback,
) {
    let (ip_addr, domain) = match (
        CString::new(info.peer_addr.ip().to_string()),
        CString::new(info.domain.clone().unwrap_or_default()),
    ) {
        (Ok(ip_addr), Ok(domain)) => (ip_addr, domain),
        _ => {
            error!("report_private_dns_validation bad input");
            return;
        }
    };
    let netd_id = info.net_id;
    let success = matches!(state, ConnectionState::Connected { .. });
    task::spawn_blocking(move || {
        validation_fn(netd_id, success, ip_addr.as_ptr(), domain.as_ptr())
    })
    .await
    .unwrap_or_else(|e| warn!("Validation function task failed: {}", e));
}

async fn handle_probe_result(
    result: (ServerInfo, Result<DohConnection, (anyhow::Error, DohConnection)>),
    doh_conn_map: &mut HashMap<u32, (ServerInfo, Option<DohConnection>)>,
    validation_fn: ValidationCallback,
    validation: &ValidationReporter,
) {
    let (info, doh_conn) = match result {
        (info, Ok(doh_conn)) => {
@@ -693,7 +670,7 @@ async fn handle_probe_result(
            return;
        }
    }
    report_private_dns_validation(&info, &doh_conn.state, validation_fn).await;
    validation(&info, doh_conn.state.is_connected()).await;
    doh_conn_map.insert(info.net_id, (info, Some(doh_conn)));
}

@@ -712,7 +689,7 @@ fn make_connection_if_needed(
    info: &ServerInfo,
    doh_conn_map: &mut HashMap<u32, (ServerInfo, Option<DohConnection>)>,
    config_cache: &config::Cache,
    tag_socket_fn: TagSocketCallback,
    tag_socket: SocketTagger,
) -> Result<Option<DohConnection>> {
    // Check if connection exists.
    match doh_conn_map.get(&info.net_id) {
@@ -729,7 +706,7 @@ fn make_connection_if_needed(
        _ => doh_conn_map.remove(&info.net_id),
    };
    let config = config_cache.from_cert_path(&info.cert_path)?;
    let doh = DohConnection::new(info, config, tag_socket_fn)?;
    let doh = DohConnection::new(info, config, tag_socket)?;
    doh_conn_map.insert(info.net_id, (info.clone(), None));
    Ok(Some(doh))
}
@@ -781,8 +758,8 @@ fn need_process_queries(doh_conn_map: &HashMap<u32, (ServerInfo, Option<DohConne

async fn doh_handler(
    mut cmd_rx: CmdReceiver,
    validation_fn: ValidationCallback,
    tag_socket_fn: TagSocketCallback,
    validation: ValidationReporter,
    tag_socket: SocketTagger,
) -> Result<()> {
    info!("doh_dispatcher entry");
    let config_cache = config::Cache::new();
@@ -802,14 +779,14 @@ async fn doh_handler(
                join_all(futures).await
            }, if need_process_queries(&doh_conn_map) => {},
            Some(result) = probe_futures.next() => {
                handle_probe_result(result, &mut doh_conn_map, validation_fn).await;
                handle_probe_result(result, &mut doh_conn_map, &validation).await;
                info!("probe_futures remaining size: {}", probe_futures.len());
            },
            Some(cmd) = cmd_rx.recv() => {
                trace!("recv {:?}", cmd);
                match cmd {
                    DohCommand::Probe { info, timeout: t } => {
                        match make_connection_if_needed(&info, &mut doh_conn_map, &config_cache, tag_socket_fn) {
                        match make_connection_if_needed(&info, &mut doh_conn_map, &config_cache, tag_socket.clone()) {
                            Ok(Some(doh)) => {
                                // Create a new async task associated to the DoH connection.
                                probe_futures.push(probe_task(info, doh, t));
@@ -822,7 +799,7 @@ async fn doh_handler(
                            }
                            Err(e) => {
                                error!("create connection for network {} error {:?}", info.net_id, e);
                                report_private_dns_validation(&info, &ConnectionState::Error, validation_fn).await;
                                validation(&info, false).await
                            }
                        }
                    },
@@ -879,6 +856,7 @@ fn mark_socket(fd: RawFd, mark: u32) -> Result<()> {
#[cfg(test)]
mod tests {
    use super::*;
    use futures::FutureExt;
    use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4};

    const TEST_NET_ID: u32 = 50;
@@ -909,8 +887,8 @@ mod tests {
        (info, test_map, config_cache, rt)
    }

    extern "C" fn tag_socket_cb(sock: int32_t) {
        assert!(sock >= 0);
    fn build_socket_tagger() -> SocketTagger {
        Arc::new(|_| async {}.boxed())
    }

    #[test]
@@ -922,7 +900,7 @@ mod tests {
                &info,
                &mut test_map,
                &config_cache,
                tag_socket_cb,
                build_socket_tagger(),
            )
            .unwrap()
            .unwrap();
@@ -935,7 +913,7 @@ mod tests {
                &info,
                &mut test_map,
                &config_cache,
                tag_socket_cb,
                build_socket_tagger(),
            )
            .unwrap()
            .unwrap();
@@ -948,7 +926,7 @@ mod tests {
                &info,
                &mut test_map,
                &config_cache,
                tag_socket_cb
                build_socket_tagger()
            )
            .unwrap()
            .is_none());
@@ -959,7 +937,6 @@ mod tests {
    fn handle_query_cmd() {
        let (info, mut test_map, config_cache, rt) = make_testing_variables();
        let t = Duration::from_millis(100);

        rt.block_on(async {
            // Test no available server cases.
            let (resp_tx, resp_rx) = oneshot::channel();
@@ -999,7 +976,7 @@ mod tests {
                &info,
                &mut test_map,
                &config_cache,
                tag_socket_cb,
                build_socket_tagger(),
            )
            .unwrap()
            .unwrap();
@@ -1020,45 +997,6 @@ mod tests {
        });
    }

    extern "C" fn success_cb(
        net_id: uint32_t,
        success: bool,
        ip_addr: *const c_char,
        host: *const c_char,
    ) {
        assert!(success);
        unsafe {
            assert_validation_info(net_id, ip_addr, host);
        }
    }

    extern "C" fn fail_cb(
        net_id: uint32_t,
        success: bool,
        ip_addr: *const c_char,
        host: *const c_char,
    ) {
        assert!(!success);
        unsafe {
            assert_validation_info(net_id, ip_addr, host);
        }
    }

    // # Safety
    // `ip_addr`, `host` are null terminated strings
    unsafe fn assert_validation_info(
        net_id: uint32_t,
        ip_addr: *const c_char,
        host: *const c_char,
    ) {
        assert_eq!(net_id, TEST_NET_ID);
        let ip_addr = std::ffi::CStr::from_ptr(ip_addr).to_str().unwrap();
        let expected_addr: SocketAddr = LOOPBACK_ADDR.parse().unwrap();
        assert_eq!(ip_addr, expected_addr.ip().to_string());
        let host = std::ffi::CStr::from_ptr(host).to_str().unwrap();
        assert_eq!(host, "");
    }

    fn make_testing_connection_variables() -> (Pin<Box<quiche::Connection>>, UdpSocket) {
        let sk = super::make_doh_udp_socket(LOOPBACK_ADDR.parse().unwrap(), TEST_MARK).unwrap();
        let udp_sk = UdpSocket::from_std(sk).unwrap();
@@ -1083,46 +1021,6 @@ mod tests {
        }
    }

    fn make_dummy_connecting_state() -> super::ConnectionState {
        let (quic_conn, udp_sk) = make_testing_connection_variables();
        ConnectionState::Connecting {
            quic_conn: Some(quic_conn),
            udp_sk: Some(udp_sk),
            expired_time: None,
        }
    }

    #[test]
    fn report_private_dns_validation() {
        let info = ServerInfo {
            net_id: TEST_NET_ID,
            url: Url::parse(LOCALHOST_URL).unwrap(),
            peer_addr: LOOPBACK_ADDR.parse().unwrap(),
            domain: None,
            sk_mark: 0,
            cert_path: None,
        };
        let rt = Builder::new_current_thread()
            .thread_name("test-runtime")
            .enable_io()
            .build()
            .expect("Failed to create testing tokio runtime");
        let default_panic = std::panic::take_hook();
        // Exit the test if the worker inside tokio runtime panicked.
        std::panic::set_hook(Box::new(move |info| {
            default_panic(info);
            std::process::exit(1);
        }));
        rt.block_on(async {
            super::report_private_dns_validation(&info, &make_dummy_connected_state(), success_cb)
                .await;
            super::report_private_dns_validation(&info, &ConnectionState::Error, fail_cb).await;
            super::report_private_dns_validation(&info, &make_dummy_connecting_state(), fail_cb)
                .await;
            super::report_private_dns_validation(&info, &ConnectionState::Idle, fail_cb).await;
        });
    }

    #[test]
    fn make_doh_udp_socket() {
        // Make a socket connecting to loopback with a test mark.
+128 −4
Original line number Diff line number Diff line
@@ -17,19 +17,67 @@
//! C API for the DoH backend for the Android DnsResolver module.

use crate::boot_time::{timeout, BootTime, Duration};
use futures::FutureExt;
use libc::{c_char, int32_t, size_t, ssize_t, uint32_t, uint64_t};
use log::error;
use log::{error, warn};
use std::ffi::CString;
use std::net::{IpAddr, SocketAddr};
use std::ops::DerefMut;
use std::os::unix::io::RawFd;
use std::str::FromStr;
use std::sync::Mutex;
use std::{ptr, slice};
use tokio::runtime::Runtime;
use tokio::sync::oneshot;
use tokio::task;
use url::Url;

use super::DohDispatcher as Dispatcher;
use super::{DohCommand, Response, ServerInfo, TagSocketCallback, ValidationCallback, DOH_PORT};
use super::{DohCommand, Response, ServerInfo, SocketTagger, ValidationReporter, DOH_PORT};

pub type ValidationCallback =
    extern "C" fn(net_id: uint32_t, success: bool, ip_addr: *const c_char, host: *const c_char);
pub type TagSocketCallback = extern "C" fn(sock: RawFd);

fn wrap_validation_callback(validation_fn: ValidationCallback) -> ValidationReporter {
    Box::new(move |info: &ServerInfo, success: bool| {
        async move {
            let (ip_addr, domain) = match (
                CString::new(info.peer_addr.ip().to_string()),
                CString::new(info.domain.clone().unwrap_or_default()),
            ) {
                (Ok(ip_addr), Ok(domain)) => (ip_addr, domain),
                _ => {
                    error!("validation_callback bad input");
                    return;
                }
            };
            let netd_id = info.net_id;
            task::spawn_blocking(move || {
                validation_fn(netd_id, success, ip_addr.as_ptr(), domain.as_ptr())
            })
            .await
            .unwrap_or_else(|e| warn!("Validation function task failed: {}", e))
        }
        .boxed()
    })
}

fn wrap_tag_socket_callback(tag_socket_fn: TagSocketCallback) -> SocketTagger {
    use std::os::unix::io::AsRawFd;
    use std::sync::Arc;
    Arc::new(move |udp_socket: &std::net::UdpSocket| {
        let fd = udp_socket.as_raw_fd();
        async move {
            task::spawn_blocking(move || {
                tag_socket_fn(fd);
            })
            .await
            .unwrap_or_else(|e| warn!("Socket tag function task failed: {}", e))
        }
        .boxed()
    })
}

pub struct DohDispatcher(Mutex<Dispatcher>);

@@ -97,7 +145,10 @@ pub extern "C" fn doh_dispatcher_new(
    validation_fn: ValidationCallback,
    tag_socket_fn: TagSocketCallback,
) -> *mut DohDispatcher {
    match Dispatcher::new(validation_fn, tag_socket_fn) {
    match Dispatcher::new(
        wrap_validation_callback(validation_fn),
        wrap_tag_socket_callback(tag_socket_fn),
    ) {
        Ok(c) => Box::into_raw(Box::new(DohDispatcher(Mutex::new(c)))),
        Err(e) => {
            error!("doh_dispatcher_new: failed: {:?}", e);
@@ -158,7 +209,7 @@ pub unsafe extern "C" fn doh_net_new(
        }
    };

    let (url, ip_addr) = match (url::Url::parse(url), IpAddr::from_str(&ip_addr)) {
    let (url, ip_addr) = match (Url::parse(url), IpAddr::from_str(&ip_addr)) {
        (Ok(url), Ok(ip_addr)) => (url, ip_addr),
        _ => {
            error!("bad ip or url"); // Should not happen
@@ -262,3 +313,76 @@ pub extern "C" fn doh_net_delete(doh: &DohDispatcher, net_id: uint32_t) {
        error!("Failed to send the query: {:?}", e);
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    const TEST_NET_ID: u32 = 50;
    const LOOPBACK_ADDR: &str = "127.0.0.1:443";
    const LOCALHOST_URL: &str = "https://mylocal.com/dns-query";

    extern "C" fn success_cb(
        net_id: uint32_t,
        success: bool,
        ip_addr: *const c_char,
        host: *const c_char,
    ) {
        assert!(success);
        unsafe {
            assert_validation_info(net_id, ip_addr, host);
        }
    }

    extern "C" fn fail_cb(
        net_id: uint32_t,
        success: bool,
        ip_addr: *const c_char,
        host: *const c_char,
    ) {
        assert!(!success);
        unsafe {
            assert_validation_info(net_id, ip_addr, host);
        }
    }

    // # Safety
    // `ip_addr`, `host` are null terminated strings
    unsafe fn assert_validation_info(
        net_id: uint32_t,
        ip_addr: *const c_char,
        host: *const c_char,
    ) {
        assert_eq!(net_id, TEST_NET_ID);
        let ip_addr = std::ffi::CStr::from_ptr(ip_addr).to_str().unwrap();
        let expected_addr: SocketAddr = LOOPBACK_ADDR.parse().unwrap();
        assert_eq!(ip_addr, expected_addr.ip().to_string());
        let host = std::ffi::CStr::from_ptr(host).to_str().unwrap();
        assert_eq!(host, "");
    }

    #[tokio::test]
    async fn wrap_validation_callback_converts_correctly() {
        let info = ServerInfo {
            net_id: TEST_NET_ID,
            url: Url::parse(LOCALHOST_URL).unwrap(),
            peer_addr: LOOPBACK_ADDR.parse().unwrap(),
            domain: None,
            sk_mark: 0,
            cert_path: None,
        };

        wrap_validation_callback(success_cb)(&info, true).await;
        wrap_validation_callback(fail_cb)(&info, false).await;
    }

    extern "C" fn tag_socket_cb(raw_fd: RawFd) {
        assert!(raw_fd > 0)
    }

    #[tokio::test]
    async fn wrap_tag_socket_callback_converts_correctly() {
        let sock = std::net::UdpSocket::bind("127.0.0.1:0").unwrap();
        wrap_tag_socket_callback(tag_socket_cb)(&sock).await;
    }
}