Loading doh/doh.rs +42 −144 Original line number Diff line number Diff line Loading @@ -17,18 +17,17 @@ //! DoH backend for the Android DnsResolver module. use anyhow::{anyhow, bail, Context, Result}; use futures::future::join_all; use futures::future::{join_all, BoxFuture}; use futures::stream::FuturesUnordered; use futures::StreamExt; use libc::{c_char, int32_t, uint32_t}; use log::{debug, error, info, trace, warn}; use quiche::h3; use ring::rand::SecureRandom; use std::collections::HashMap; use std::ffi::CString; use std::net::SocketAddr; use std::os::unix::io::{AsRawFd, RawFd}; use std::pin::Pin; use std::sync::Arc; use tokio::net::UdpSocket; use tokio::runtime::{Builder, Runtime}; use tokio::sync::{mpsc, oneshot}; Loading @@ -46,15 +45,15 @@ use config::Config; const MAX_BUFFERED_CMD_SIZE: usize = 400; const DOH_PORT: u16 = 443; type ValidationReporter = Box<dyn Fn(&ServerInfo, bool) -> BoxFuture<()> + Send + Sync>; type SocketTagger = Arc<dyn Fn(&std::net::UdpSocket) -> BoxFuture<()> + Send + Sync>; type SCID = [u8; quiche::MAX_CONN_ID_LEN]; type Base64Query = String; type CmdSender = mpsc::Sender<DohCommand>; type CmdReceiver = mpsc::Receiver<DohCommand>; type QueryResponder = oneshot::Sender<Response>; type DnsRequest = Vec<quiche::h3::Header>; type ValidationCallback = extern "C" fn(net_id: uint32_t, success: bool, ip_addr: *const c_char, host: *const c_char); type TagSocketCallback = extern "C" fn(sock: int32_t); #[derive(Eq, PartialEq, Debug)] enum QueryError { Loading Loading @@ -108,6 +107,15 @@ enum ConnectionState { Error, } impl ConnectionState { fn is_connected(&self) -> bool { matches!(*self, Self::Connected { .. }) } fn is_error(&self) -> bool { matches!(*self, Self::Error) } } enum H3Result { Data { data: Vec<u8> }, Finished, Loading @@ -124,10 +132,7 @@ pub struct DohDispatcher { // DoH dispatcher impl DohDispatcher { fn new( validation_fn: ValidationCallback, tag_socket_fn: TagSocketCallback, ) -> Result<DohDispatcher> { fn new(validation: ValidationReporter, tag_socket: SocketTagger) -> Result<DohDispatcher> { let (cmd_sender, cmd_receiver) = mpsc::channel::<DohCommand>(MAX_BUFFERED_CMD_SIZE); let runtime = Builder::new_multi_thread() .worker_threads(2) Loading @@ -135,7 +140,7 @@ impl DohDispatcher { .thread_name("doh-handler") .build() .expect("Failed to create tokio runtime"); let join_handle = runtime.spawn(doh_handler(cmd_receiver, validation_fn, tag_socket_fn)); let join_handle = runtime.spawn(doh_handler(cmd_receiver, validation, tag_socket)); Ok(DohDispatcher { cmd_sender, join_handle, runtime }) } Loading @@ -159,15 +164,11 @@ struct DohConnection { state: ConnectionState, pending_queries: Vec<(DnsRequest, QueryResponder, BootTime)>, cached_session: Option<Vec<u8>>, tag_socket_fn: TagSocketCallback, tag_socket: SocketTagger, } impl DohConnection { fn new( info: &ServerInfo, config: Config, tag_socket_fn: TagSocketCallback, ) -> Result<DohConnection> { fn new(info: &ServerInfo, config: Config, tag_socket: SocketTagger) -> Result<DohConnection> { let mut scid = [0; quiche::MAX_CONN_ID_LEN]; ring::rand::SystemRandom::new().fill(&mut scid).context("failed to generate scid")?; Ok(DohConnection { Loading @@ -177,15 +178,18 @@ impl DohConnection { state: ConnectionState::Idle, pending_queries: Vec::new(), cached_session: None, tag_socket_fn, tag_socket, }) } fn state_to_connecting(&mut self) -> Result<()> { async fn state_to_connecting(&mut self) -> Result<()> { if self.state.is_error() { self.state_to_idle(); } self.state = match self.state { ConnectionState::Idle => { let udp_sk_std = make_doh_udp_socket(self.info.peer_addr, self.info.sk_mark)?; (self.tag_socket_fn)(udp_sk_std.as_raw_fd()); (self.tag_socket)(&udp_sk_std).await; let udp_sk = UdpSocket::from_std(udp_sk_std)?; let connid = quiche::ConnectionId::from_ref(&self.scid); debug!("init the connection for Network {}", self.info.net_id); Loading @@ -206,10 +210,7 @@ impl DohConnection { expired_time: None, } } ConnectionState::Error => { self.state_to_idle(); return self.state_to_connecting(); } ConnectionState::Error => panic!("state_to_idle did not transition"), ConnectionState::Connecting { .. } => return Ok(()), ConnectionState::Connected { .. } => { panic!("Invalid state transition to Connecting state!") Loading Loading @@ -351,7 +352,7 @@ impl DohConnection { if matches!(self.state, ConnectionState::Connected { .. }) { return Ok(()); } self.state_to_connecting()?; self.state_to_connecting().await?; debug!("connecting to Network {}", self.info.net_id); let (quic_conn, udp_sk, expired_time) = match &mut self.state { Loading Loading @@ -636,34 +637,10 @@ async fn flush_tx( Ok(()) } async fn report_private_dns_validation( info: &ServerInfo, state: &ConnectionState, validation_fn: ValidationCallback, ) { let (ip_addr, domain) = match ( CString::new(info.peer_addr.ip().to_string()), CString::new(info.domain.clone().unwrap_or_default()), ) { (Ok(ip_addr), Ok(domain)) => (ip_addr, domain), _ => { error!("report_private_dns_validation bad input"); return; } }; let netd_id = info.net_id; let success = matches!(state, ConnectionState::Connected { .. }); task::spawn_blocking(move || { validation_fn(netd_id, success, ip_addr.as_ptr(), domain.as_ptr()) }) .await .unwrap_or_else(|e| warn!("Validation function task failed: {}", e)); } async fn handle_probe_result( result: (ServerInfo, Result<DohConnection, (anyhow::Error, DohConnection)>), doh_conn_map: &mut HashMap<u32, (ServerInfo, Option<DohConnection>)>, validation_fn: ValidationCallback, validation: &ValidationReporter, ) { let (info, doh_conn) = match result { (info, Ok(doh_conn)) => { Loading Loading @@ -693,7 +670,7 @@ async fn handle_probe_result( return; } } report_private_dns_validation(&info, &doh_conn.state, validation_fn).await; validation(&info, doh_conn.state.is_connected()).await; doh_conn_map.insert(info.net_id, (info, Some(doh_conn))); } Loading @@ -712,7 +689,7 @@ fn make_connection_if_needed( info: &ServerInfo, doh_conn_map: &mut HashMap<u32, (ServerInfo, Option<DohConnection>)>, config_cache: &config::Cache, tag_socket_fn: TagSocketCallback, tag_socket: SocketTagger, ) -> Result<Option<DohConnection>> { // Check if connection exists. match doh_conn_map.get(&info.net_id) { Loading @@ -729,7 +706,7 @@ fn make_connection_if_needed( _ => doh_conn_map.remove(&info.net_id), }; let config = config_cache.from_cert_path(&info.cert_path)?; let doh = DohConnection::new(info, config, tag_socket_fn)?; let doh = DohConnection::new(info, config, tag_socket)?; doh_conn_map.insert(info.net_id, (info.clone(), None)); Ok(Some(doh)) } Loading Loading @@ -781,8 +758,8 @@ fn need_process_queries(doh_conn_map: &HashMap<u32, (ServerInfo, Option<DohConne async fn doh_handler( mut cmd_rx: CmdReceiver, validation_fn: ValidationCallback, tag_socket_fn: TagSocketCallback, validation: ValidationReporter, tag_socket: SocketTagger, ) -> Result<()> { info!("doh_dispatcher entry"); let config_cache = config::Cache::new(); Loading @@ -802,14 +779,14 @@ async fn doh_handler( join_all(futures).await }, if need_process_queries(&doh_conn_map) => {}, Some(result) = probe_futures.next() => { handle_probe_result(result, &mut doh_conn_map, validation_fn).await; handle_probe_result(result, &mut doh_conn_map, &validation).await; info!("probe_futures remaining size: {}", probe_futures.len()); }, Some(cmd) = cmd_rx.recv() => { trace!("recv {:?}", cmd); match cmd { DohCommand::Probe { info, timeout: t } => { match make_connection_if_needed(&info, &mut doh_conn_map, &config_cache, tag_socket_fn) { match make_connection_if_needed(&info, &mut doh_conn_map, &config_cache, tag_socket.clone()) { Ok(Some(doh)) => { // Create a new async task associated to the DoH connection. probe_futures.push(probe_task(info, doh, t)); Loading @@ -822,7 +799,7 @@ async fn doh_handler( } Err(e) => { error!("create connection for network {} error {:?}", info.net_id, e); report_private_dns_validation(&info, &ConnectionState::Error, validation_fn).await; validation(&info, false).await } } }, Loading Loading @@ -879,6 +856,7 @@ fn mark_socket(fd: RawFd, mark: u32) -> Result<()> { #[cfg(test)] mod tests { use super::*; use futures::FutureExt; use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4}; const TEST_NET_ID: u32 = 50; Loading Loading @@ -909,8 +887,8 @@ mod tests { (info, test_map, config_cache, rt) } extern "C" fn tag_socket_cb(sock: int32_t) { assert!(sock >= 0); fn build_socket_tagger() -> SocketTagger { Arc::new(|_| async {}.boxed()) } #[test] Loading @@ -922,7 +900,7 @@ mod tests { &info, &mut test_map, &config_cache, tag_socket_cb, build_socket_tagger(), ) .unwrap() .unwrap(); Loading @@ -935,7 +913,7 @@ mod tests { &info, &mut test_map, &config_cache, tag_socket_cb, build_socket_tagger(), ) .unwrap() .unwrap(); Loading @@ -948,7 +926,7 @@ mod tests { &info, &mut test_map, &config_cache, tag_socket_cb build_socket_tagger() ) .unwrap() .is_none()); Loading @@ -959,7 +937,6 @@ mod tests { fn handle_query_cmd() { let (info, mut test_map, config_cache, rt) = make_testing_variables(); let t = Duration::from_millis(100); rt.block_on(async { // Test no available server cases. let (resp_tx, resp_rx) = oneshot::channel(); Loading Loading @@ -999,7 +976,7 @@ mod tests { &info, &mut test_map, &config_cache, tag_socket_cb, build_socket_tagger(), ) .unwrap() .unwrap(); Loading @@ -1020,45 +997,6 @@ mod tests { }); } extern "C" fn success_cb( net_id: uint32_t, success: bool, ip_addr: *const c_char, host: *const c_char, ) { assert!(success); unsafe { assert_validation_info(net_id, ip_addr, host); } } extern "C" fn fail_cb( net_id: uint32_t, success: bool, ip_addr: *const c_char, host: *const c_char, ) { assert!(!success); unsafe { assert_validation_info(net_id, ip_addr, host); } } // # Safety // `ip_addr`, `host` are null terminated strings unsafe fn assert_validation_info( net_id: uint32_t, ip_addr: *const c_char, host: *const c_char, ) { assert_eq!(net_id, TEST_NET_ID); let ip_addr = std::ffi::CStr::from_ptr(ip_addr).to_str().unwrap(); let expected_addr: SocketAddr = LOOPBACK_ADDR.parse().unwrap(); assert_eq!(ip_addr, expected_addr.ip().to_string()); let host = std::ffi::CStr::from_ptr(host).to_str().unwrap(); assert_eq!(host, ""); } fn make_testing_connection_variables() -> (Pin<Box<quiche::Connection>>, UdpSocket) { let sk = super::make_doh_udp_socket(LOOPBACK_ADDR.parse().unwrap(), TEST_MARK).unwrap(); let udp_sk = UdpSocket::from_std(sk).unwrap(); Loading @@ -1083,46 +1021,6 @@ mod tests { } } fn make_dummy_connecting_state() -> super::ConnectionState { let (quic_conn, udp_sk) = make_testing_connection_variables(); ConnectionState::Connecting { quic_conn: Some(quic_conn), udp_sk: Some(udp_sk), expired_time: None, } } #[test] fn report_private_dns_validation() { let info = ServerInfo { net_id: TEST_NET_ID, url: Url::parse(LOCALHOST_URL).unwrap(), peer_addr: LOOPBACK_ADDR.parse().unwrap(), domain: None, sk_mark: 0, cert_path: None, }; let rt = Builder::new_current_thread() .thread_name("test-runtime") .enable_io() .build() .expect("Failed to create testing tokio runtime"); let default_panic = std::panic::take_hook(); // Exit the test if the worker inside tokio runtime panicked. std::panic::set_hook(Box::new(move |info| { default_panic(info); std::process::exit(1); })); rt.block_on(async { super::report_private_dns_validation(&info, &make_dummy_connected_state(), success_cb) .await; super::report_private_dns_validation(&info, &ConnectionState::Error, fail_cb).await; super::report_private_dns_validation(&info, &make_dummy_connecting_state(), fail_cb) .await; super::report_private_dns_validation(&info, &ConnectionState::Idle, fail_cb).await; }); } #[test] fn make_doh_udp_socket() { // Make a socket connecting to loopback with a test mark. Loading doh/ffi.rs +128 −4 Original line number Diff line number Diff line Loading @@ -17,19 +17,67 @@ //! C API for the DoH backend for the Android DnsResolver module. use crate::boot_time::{timeout, BootTime, Duration}; use futures::FutureExt; use libc::{c_char, int32_t, size_t, ssize_t, uint32_t, uint64_t}; use log::error; use log::{error, warn}; use std::ffi::CString; use std::net::{IpAddr, SocketAddr}; use std::ops::DerefMut; use std::os::unix::io::RawFd; use std::str::FromStr; use std::sync::Mutex; use std::{ptr, slice}; use tokio::runtime::Runtime; use tokio::sync::oneshot; use tokio::task; use url::Url; use super::DohDispatcher as Dispatcher; use super::{DohCommand, Response, ServerInfo, TagSocketCallback, ValidationCallback, DOH_PORT}; use super::{DohCommand, Response, ServerInfo, SocketTagger, ValidationReporter, DOH_PORT}; pub type ValidationCallback = extern "C" fn(net_id: uint32_t, success: bool, ip_addr: *const c_char, host: *const c_char); pub type TagSocketCallback = extern "C" fn(sock: RawFd); fn wrap_validation_callback(validation_fn: ValidationCallback) -> ValidationReporter { Box::new(move |info: &ServerInfo, success: bool| { async move { let (ip_addr, domain) = match ( CString::new(info.peer_addr.ip().to_string()), CString::new(info.domain.clone().unwrap_or_default()), ) { (Ok(ip_addr), Ok(domain)) => (ip_addr, domain), _ => { error!("validation_callback bad input"); return; } }; let netd_id = info.net_id; task::spawn_blocking(move || { validation_fn(netd_id, success, ip_addr.as_ptr(), domain.as_ptr()) }) .await .unwrap_or_else(|e| warn!("Validation function task failed: {}", e)) } .boxed() }) } fn wrap_tag_socket_callback(tag_socket_fn: TagSocketCallback) -> SocketTagger { use std::os::unix::io::AsRawFd; use std::sync::Arc; Arc::new(move |udp_socket: &std::net::UdpSocket| { let fd = udp_socket.as_raw_fd(); async move { task::spawn_blocking(move || { tag_socket_fn(fd); }) .await .unwrap_or_else(|e| warn!("Socket tag function task failed: {}", e)) } .boxed() }) } pub struct DohDispatcher(Mutex<Dispatcher>); Loading Loading @@ -97,7 +145,10 @@ pub extern "C" fn doh_dispatcher_new( validation_fn: ValidationCallback, tag_socket_fn: TagSocketCallback, ) -> *mut DohDispatcher { match Dispatcher::new(validation_fn, tag_socket_fn) { match Dispatcher::new( wrap_validation_callback(validation_fn), wrap_tag_socket_callback(tag_socket_fn), ) { Ok(c) => Box::into_raw(Box::new(DohDispatcher(Mutex::new(c)))), Err(e) => { error!("doh_dispatcher_new: failed: {:?}", e); Loading Loading @@ -158,7 +209,7 @@ pub unsafe extern "C" fn doh_net_new( } }; let (url, ip_addr) = match (url::Url::parse(url), IpAddr::from_str(&ip_addr)) { let (url, ip_addr) = match (Url::parse(url), IpAddr::from_str(&ip_addr)) { (Ok(url), Ok(ip_addr)) => (url, ip_addr), _ => { error!("bad ip or url"); // Should not happen Loading Loading @@ -262,3 +313,76 @@ pub extern "C" fn doh_net_delete(doh: &DohDispatcher, net_id: uint32_t) { error!("Failed to send the query: {:?}", e); } } #[cfg(test)] mod tests { use super::*; const TEST_NET_ID: u32 = 50; const LOOPBACK_ADDR: &str = "127.0.0.1:443"; const LOCALHOST_URL: &str = "https://mylocal.com/dns-query"; extern "C" fn success_cb( net_id: uint32_t, success: bool, ip_addr: *const c_char, host: *const c_char, ) { assert!(success); unsafe { assert_validation_info(net_id, ip_addr, host); } } extern "C" fn fail_cb( net_id: uint32_t, success: bool, ip_addr: *const c_char, host: *const c_char, ) { assert!(!success); unsafe { assert_validation_info(net_id, ip_addr, host); } } // # Safety // `ip_addr`, `host` are null terminated strings unsafe fn assert_validation_info( net_id: uint32_t, ip_addr: *const c_char, host: *const c_char, ) { assert_eq!(net_id, TEST_NET_ID); let ip_addr = std::ffi::CStr::from_ptr(ip_addr).to_str().unwrap(); let expected_addr: SocketAddr = LOOPBACK_ADDR.parse().unwrap(); assert_eq!(ip_addr, expected_addr.ip().to_string()); let host = std::ffi::CStr::from_ptr(host).to_str().unwrap(); assert_eq!(host, ""); } #[tokio::test] async fn wrap_validation_callback_converts_correctly() { let info = ServerInfo { net_id: TEST_NET_ID, url: Url::parse(LOCALHOST_URL).unwrap(), peer_addr: LOOPBACK_ADDR.parse().unwrap(), domain: None, sk_mark: 0, cert_path: None, }; wrap_validation_callback(success_cb)(&info, true).await; wrap_validation_callback(fail_cb)(&info, false).await; } extern "C" fn tag_socket_cb(raw_fd: RawFd) { assert!(raw_fd > 0) } #[tokio::test] async fn wrap_tag_socket_callback_converts_correctly() { let sock = std::net::UdpSocket::bind("127.0.0.1:0").unwrap(); wrap_tag_socket_callback(tag_socket_cb)(&sock).await; } } Loading
doh/doh.rs +42 −144 Original line number Diff line number Diff line Loading @@ -17,18 +17,17 @@ //! DoH backend for the Android DnsResolver module. use anyhow::{anyhow, bail, Context, Result}; use futures::future::join_all; use futures::future::{join_all, BoxFuture}; use futures::stream::FuturesUnordered; use futures::StreamExt; use libc::{c_char, int32_t, uint32_t}; use log::{debug, error, info, trace, warn}; use quiche::h3; use ring::rand::SecureRandom; use std::collections::HashMap; use std::ffi::CString; use std::net::SocketAddr; use std::os::unix::io::{AsRawFd, RawFd}; use std::pin::Pin; use std::sync::Arc; use tokio::net::UdpSocket; use tokio::runtime::{Builder, Runtime}; use tokio::sync::{mpsc, oneshot}; Loading @@ -46,15 +45,15 @@ use config::Config; const MAX_BUFFERED_CMD_SIZE: usize = 400; const DOH_PORT: u16 = 443; type ValidationReporter = Box<dyn Fn(&ServerInfo, bool) -> BoxFuture<()> + Send + Sync>; type SocketTagger = Arc<dyn Fn(&std::net::UdpSocket) -> BoxFuture<()> + Send + Sync>; type SCID = [u8; quiche::MAX_CONN_ID_LEN]; type Base64Query = String; type CmdSender = mpsc::Sender<DohCommand>; type CmdReceiver = mpsc::Receiver<DohCommand>; type QueryResponder = oneshot::Sender<Response>; type DnsRequest = Vec<quiche::h3::Header>; type ValidationCallback = extern "C" fn(net_id: uint32_t, success: bool, ip_addr: *const c_char, host: *const c_char); type TagSocketCallback = extern "C" fn(sock: int32_t); #[derive(Eq, PartialEq, Debug)] enum QueryError { Loading Loading @@ -108,6 +107,15 @@ enum ConnectionState { Error, } impl ConnectionState { fn is_connected(&self) -> bool { matches!(*self, Self::Connected { .. }) } fn is_error(&self) -> bool { matches!(*self, Self::Error) } } enum H3Result { Data { data: Vec<u8> }, Finished, Loading @@ -124,10 +132,7 @@ pub struct DohDispatcher { // DoH dispatcher impl DohDispatcher { fn new( validation_fn: ValidationCallback, tag_socket_fn: TagSocketCallback, ) -> Result<DohDispatcher> { fn new(validation: ValidationReporter, tag_socket: SocketTagger) -> Result<DohDispatcher> { let (cmd_sender, cmd_receiver) = mpsc::channel::<DohCommand>(MAX_BUFFERED_CMD_SIZE); let runtime = Builder::new_multi_thread() .worker_threads(2) Loading @@ -135,7 +140,7 @@ impl DohDispatcher { .thread_name("doh-handler") .build() .expect("Failed to create tokio runtime"); let join_handle = runtime.spawn(doh_handler(cmd_receiver, validation_fn, tag_socket_fn)); let join_handle = runtime.spawn(doh_handler(cmd_receiver, validation, tag_socket)); Ok(DohDispatcher { cmd_sender, join_handle, runtime }) } Loading @@ -159,15 +164,11 @@ struct DohConnection { state: ConnectionState, pending_queries: Vec<(DnsRequest, QueryResponder, BootTime)>, cached_session: Option<Vec<u8>>, tag_socket_fn: TagSocketCallback, tag_socket: SocketTagger, } impl DohConnection { fn new( info: &ServerInfo, config: Config, tag_socket_fn: TagSocketCallback, ) -> Result<DohConnection> { fn new(info: &ServerInfo, config: Config, tag_socket: SocketTagger) -> Result<DohConnection> { let mut scid = [0; quiche::MAX_CONN_ID_LEN]; ring::rand::SystemRandom::new().fill(&mut scid).context("failed to generate scid")?; Ok(DohConnection { Loading @@ -177,15 +178,18 @@ impl DohConnection { state: ConnectionState::Idle, pending_queries: Vec::new(), cached_session: None, tag_socket_fn, tag_socket, }) } fn state_to_connecting(&mut self) -> Result<()> { async fn state_to_connecting(&mut self) -> Result<()> { if self.state.is_error() { self.state_to_idle(); } self.state = match self.state { ConnectionState::Idle => { let udp_sk_std = make_doh_udp_socket(self.info.peer_addr, self.info.sk_mark)?; (self.tag_socket_fn)(udp_sk_std.as_raw_fd()); (self.tag_socket)(&udp_sk_std).await; let udp_sk = UdpSocket::from_std(udp_sk_std)?; let connid = quiche::ConnectionId::from_ref(&self.scid); debug!("init the connection for Network {}", self.info.net_id); Loading @@ -206,10 +210,7 @@ impl DohConnection { expired_time: None, } } ConnectionState::Error => { self.state_to_idle(); return self.state_to_connecting(); } ConnectionState::Error => panic!("state_to_idle did not transition"), ConnectionState::Connecting { .. } => return Ok(()), ConnectionState::Connected { .. } => { panic!("Invalid state transition to Connecting state!") Loading Loading @@ -351,7 +352,7 @@ impl DohConnection { if matches!(self.state, ConnectionState::Connected { .. }) { return Ok(()); } self.state_to_connecting()?; self.state_to_connecting().await?; debug!("connecting to Network {}", self.info.net_id); let (quic_conn, udp_sk, expired_time) = match &mut self.state { Loading Loading @@ -636,34 +637,10 @@ async fn flush_tx( Ok(()) } async fn report_private_dns_validation( info: &ServerInfo, state: &ConnectionState, validation_fn: ValidationCallback, ) { let (ip_addr, domain) = match ( CString::new(info.peer_addr.ip().to_string()), CString::new(info.domain.clone().unwrap_or_default()), ) { (Ok(ip_addr), Ok(domain)) => (ip_addr, domain), _ => { error!("report_private_dns_validation bad input"); return; } }; let netd_id = info.net_id; let success = matches!(state, ConnectionState::Connected { .. }); task::spawn_blocking(move || { validation_fn(netd_id, success, ip_addr.as_ptr(), domain.as_ptr()) }) .await .unwrap_or_else(|e| warn!("Validation function task failed: {}", e)); } async fn handle_probe_result( result: (ServerInfo, Result<DohConnection, (anyhow::Error, DohConnection)>), doh_conn_map: &mut HashMap<u32, (ServerInfo, Option<DohConnection>)>, validation_fn: ValidationCallback, validation: &ValidationReporter, ) { let (info, doh_conn) = match result { (info, Ok(doh_conn)) => { Loading Loading @@ -693,7 +670,7 @@ async fn handle_probe_result( return; } } report_private_dns_validation(&info, &doh_conn.state, validation_fn).await; validation(&info, doh_conn.state.is_connected()).await; doh_conn_map.insert(info.net_id, (info, Some(doh_conn))); } Loading @@ -712,7 +689,7 @@ fn make_connection_if_needed( info: &ServerInfo, doh_conn_map: &mut HashMap<u32, (ServerInfo, Option<DohConnection>)>, config_cache: &config::Cache, tag_socket_fn: TagSocketCallback, tag_socket: SocketTagger, ) -> Result<Option<DohConnection>> { // Check if connection exists. match doh_conn_map.get(&info.net_id) { Loading @@ -729,7 +706,7 @@ fn make_connection_if_needed( _ => doh_conn_map.remove(&info.net_id), }; let config = config_cache.from_cert_path(&info.cert_path)?; let doh = DohConnection::new(info, config, tag_socket_fn)?; let doh = DohConnection::new(info, config, tag_socket)?; doh_conn_map.insert(info.net_id, (info.clone(), None)); Ok(Some(doh)) } Loading Loading @@ -781,8 +758,8 @@ fn need_process_queries(doh_conn_map: &HashMap<u32, (ServerInfo, Option<DohConne async fn doh_handler( mut cmd_rx: CmdReceiver, validation_fn: ValidationCallback, tag_socket_fn: TagSocketCallback, validation: ValidationReporter, tag_socket: SocketTagger, ) -> Result<()> { info!("doh_dispatcher entry"); let config_cache = config::Cache::new(); Loading @@ -802,14 +779,14 @@ async fn doh_handler( join_all(futures).await }, if need_process_queries(&doh_conn_map) => {}, Some(result) = probe_futures.next() => { handle_probe_result(result, &mut doh_conn_map, validation_fn).await; handle_probe_result(result, &mut doh_conn_map, &validation).await; info!("probe_futures remaining size: {}", probe_futures.len()); }, Some(cmd) = cmd_rx.recv() => { trace!("recv {:?}", cmd); match cmd { DohCommand::Probe { info, timeout: t } => { match make_connection_if_needed(&info, &mut doh_conn_map, &config_cache, tag_socket_fn) { match make_connection_if_needed(&info, &mut doh_conn_map, &config_cache, tag_socket.clone()) { Ok(Some(doh)) => { // Create a new async task associated to the DoH connection. probe_futures.push(probe_task(info, doh, t)); Loading @@ -822,7 +799,7 @@ async fn doh_handler( } Err(e) => { error!("create connection for network {} error {:?}", info.net_id, e); report_private_dns_validation(&info, &ConnectionState::Error, validation_fn).await; validation(&info, false).await } } }, Loading Loading @@ -879,6 +856,7 @@ fn mark_socket(fd: RawFd, mark: u32) -> Result<()> { #[cfg(test)] mod tests { use super::*; use futures::FutureExt; use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4}; const TEST_NET_ID: u32 = 50; Loading Loading @@ -909,8 +887,8 @@ mod tests { (info, test_map, config_cache, rt) } extern "C" fn tag_socket_cb(sock: int32_t) { assert!(sock >= 0); fn build_socket_tagger() -> SocketTagger { Arc::new(|_| async {}.boxed()) } #[test] Loading @@ -922,7 +900,7 @@ mod tests { &info, &mut test_map, &config_cache, tag_socket_cb, build_socket_tagger(), ) .unwrap() .unwrap(); Loading @@ -935,7 +913,7 @@ mod tests { &info, &mut test_map, &config_cache, tag_socket_cb, build_socket_tagger(), ) .unwrap() .unwrap(); Loading @@ -948,7 +926,7 @@ mod tests { &info, &mut test_map, &config_cache, tag_socket_cb build_socket_tagger() ) .unwrap() .is_none()); Loading @@ -959,7 +937,6 @@ mod tests { fn handle_query_cmd() { let (info, mut test_map, config_cache, rt) = make_testing_variables(); let t = Duration::from_millis(100); rt.block_on(async { // Test no available server cases. let (resp_tx, resp_rx) = oneshot::channel(); Loading Loading @@ -999,7 +976,7 @@ mod tests { &info, &mut test_map, &config_cache, tag_socket_cb, build_socket_tagger(), ) .unwrap() .unwrap(); Loading @@ -1020,45 +997,6 @@ mod tests { }); } extern "C" fn success_cb( net_id: uint32_t, success: bool, ip_addr: *const c_char, host: *const c_char, ) { assert!(success); unsafe { assert_validation_info(net_id, ip_addr, host); } } extern "C" fn fail_cb( net_id: uint32_t, success: bool, ip_addr: *const c_char, host: *const c_char, ) { assert!(!success); unsafe { assert_validation_info(net_id, ip_addr, host); } } // # Safety // `ip_addr`, `host` are null terminated strings unsafe fn assert_validation_info( net_id: uint32_t, ip_addr: *const c_char, host: *const c_char, ) { assert_eq!(net_id, TEST_NET_ID); let ip_addr = std::ffi::CStr::from_ptr(ip_addr).to_str().unwrap(); let expected_addr: SocketAddr = LOOPBACK_ADDR.parse().unwrap(); assert_eq!(ip_addr, expected_addr.ip().to_string()); let host = std::ffi::CStr::from_ptr(host).to_str().unwrap(); assert_eq!(host, ""); } fn make_testing_connection_variables() -> (Pin<Box<quiche::Connection>>, UdpSocket) { let sk = super::make_doh_udp_socket(LOOPBACK_ADDR.parse().unwrap(), TEST_MARK).unwrap(); let udp_sk = UdpSocket::from_std(sk).unwrap(); Loading @@ -1083,46 +1021,6 @@ mod tests { } } fn make_dummy_connecting_state() -> super::ConnectionState { let (quic_conn, udp_sk) = make_testing_connection_variables(); ConnectionState::Connecting { quic_conn: Some(quic_conn), udp_sk: Some(udp_sk), expired_time: None, } } #[test] fn report_private_dns_validation() { let info = ServerInfo { net_id: TEST_NET_ID, url: Url::parse(LOCALHOST_URL).unwrap(), peer_addr: LOOPBACK_ADDR.parse().unwrap(), domain: None, sk_mark: 0, cert_path: None, }; let rt = Builder::new_current_thread() .thread_name("test-runtime") .enable_io() .build() .expect("Failed to create testing tokio runtime"); let default_panic = std::panic::take_hook(); // Exit the test if the worker inside tokio runtime panicked. std::panic::set_hook(Box::new(move |info| { default_panic(info); std::process::exit(1); })); rt.block_on(async { super::report_private_dns_validation(&info, &make_dummy_connected_state(), success_cb) .await; super::report_private_dns_validation(&info, &ConnectionState::Error, fail_cb).await; super::report_private_dns_validation(&info, &make_dummy_connecting_state(), fail_cb) .await; super::report_private_dns_validation(&info, &ConnectionState::Idle, fail_cb).await; }); } #[test] fn make_doh_udp_socket() { // Make a socket connecting to loopback with a test mark. Loading
doh/ffi.rs +128 −4 Original line number Diff line number Diff line Loading @@ -17,19 +17,67 @@ //! C API for the DoH backend for the Android DnsResolver module. use crate::boot_time::{timeout, BootTime, Duration}; use futures::FutureExt; use libc::{c_char, int32_t, size_t, ssize_t, uint32_t, uint64_t}; use log::error; use log::{error, warn}; use std::ffi::CString; use std::net::{IpAddr, SocketAddr}; use std::ops::DerefMut; use std::os::unix::io::RawFd; use std::str::FromStr; use std::sync::Mutex; use std::{ptr, slice}; use tokio::runtime::Runtime; use tokio::sync::oneshot; use tokio::task; use url::Url; use super::DohDispatcher as Dispatcher; use super::{DohCommand, Response, ServerInfo, TagSocketCallback, ValidationCallback, DOH_PORT}; use super::{DohCommand, Response, ServerInfo, SocketTagger, ValidationReporter, DOH_PORT}; pub type ValidationCallback = extern "C" fn(net_id: uint32_t, success: bool, ip_addr: *const c_char, host: *const c_char); pub type TagSocketCallback = extern "C" fn(sock: RawFd); fn wrap_validation_callback(validation_fn: ValidationCallback) -> ValidationReporter { Box::new(move |info: &ServerInfo, success: bool| { async move { let (ip_addr, domain) = match ( CString::new(info.peer_addr.ip().to_string()), CString::new(info.domain.clone().unwrap_or_default()), ) { (Ok(ip_addr), Ok(domain)) => (ip_addr, domain), _ => { error!("validation_callback bad input"); return; } }; let netd_id = info.net_id; task::spawn_blocking(move || { validation_fn(netd_id, success, ip_addr.as_ptr(), domain.as_ptr()) }) .await .unwrap_or_else(|e| warn!("Validation function task failed: {}", e)) } .boxed() }) } fn wrap_tag_socket_callback(tag_socket_fn: TagSocketCallback) -> SocketTagger { use std::os::unix::io::AsRawFd; use std::sync::Arc; Arc::new(move |udp_socket: &std::net::UdpSocket| { let fd = udp_socket.as_raw_fd(); async move { task::spawn_blocking(move || { tag_socket_fn(fd); }) .await .unwrap_or_else(|e| warn!("Socket tag function task failed: {}", e)) } .boxed() }) } pub struct DohDispatcher(Mutex<Dispatcher>); Loading Loading @@ -97,7 +145,10 @@ pub extern "C" fn doh_dispatcher_new( validation_fn: ValidationCallback, tag_socket_fn: TagSocketCallback, ) -> *mut DohDispatcher { match Dispatcher::new(validation_fn, tag_socket_fn) { match Dispatcher::new( wrap_validation_callback(validation_fn), wrap_tag_socket_callback(tag_socket_fn), ) { Ok(c) => Box::into_raw(Box::new(DohDispatcher(Mutex::new(c)))), Err(e) => { error!("doh_dispatcher_new: failed: {:?}", e); Loading Loading @@ -158,7 +209,7 @@ pub unsafe extern "C" fn doh_net_new( } }; let (url, ip_addr) = match (url::Url::parse(url), IpAddr::from_str(&ip_addr)) { let (url, ip_addr) = match (Url::parse(url), IpAddr::from_str(&ip_addr)) { (Ok(url), Ok(ip_addr)) => (url, ip_addr), _ => { error!("bad ip or url"); // Should not happen Loading Loading @@ -262,3 +313,76 @@ pub extern "C" fn doh_net_delete(doh: &DohDispatcher, net_id: uint32_t) { error!("Failed to send the query: {:?}", e); } } #[cfg(test)] mod tests { use super::*; const TEST_NET_ID: u32 = 50; const LOOPBACK_ADDR: &str = "127.0.0.1:443"; const LOCALHOST_URL: &str = "https://mylocal.com/dns-query"; extern "C" fn success_cb( net_id: uint32_t, success: bool, ip_addr: *const c_char, host: *const c_char, ) { assert!(success); unsafe { assert_validation_info(net_id, ip_addr, host); } } extern "C" fn fail_cb( net_id: uint32_t, success: bool, ip_addr: *const c_char, host: *const c_char, ) { assert!(!success); unsafe { assert_validation_info(net_id, ip_addr, host); } } // # Safety // `ip_addr`, `host` are null terminated strings unsafe fn assert_validation_info( net_id: uint32_t, ip_addr: *const c_char, host: *const c_char, ) { assert_eq!(net_id, TEST_NET_ID); let ip_addr = std::ffi::CStr::from_ptr(ip_addr).to_str().unwrap(); let expected_addr: SocketAddr = LOOPBACK_ADDR.parse().unwrap(); assert_eq!(ip_addr, expected_addr.ip().to_string()); let host = std::ffi::CStr::from_ptr(host).to_str().unwrap(); assert_eq!(host, ""); } #[tokio::test] async fn wrap_validation_callback_converts_correctly() { let info = ServerInfo { net_id: TEST_NET_ID, url: Url::parse(LOCALHOST_URL).unwrap(), peer_addr: LOOPBACK_ADDR.parse().unwrap(), domain: None, sk_mark: 0, cert_path: None, }; wrap_validation_callback(success_cb)(&info, true).await; wrap_validation_callback(fail_cb)(&info, false).await; } extern "C" fn tag_socket_cb(raw_fd: RawFd) { assert!(raw_fd > 0) } #[tokio::test] async fn wrap_tag_socket_callback_converts_correctly() { let sock = std::net::UdpSocket::bind("127.0.0.1:0").unwrap(); wrap_tag_socket_callback(tag_socket_cb)(&sock).await; } }