Loading PrivateDnsConfiguration.cpp +2 −1 Original line number Diff line number Diff line Loading @@ -428,7 +428,8 @@ void PrivateDnsConfiguration::initDohLocked() { [](uint32_t net_id, bool success, const char* ip_addr, const char* host) { android::net::PrivateDnsConfiguration::getInstance().onDohStatusUpdate( net_id, success, ip_addr, host); }); }, [](int32_t sock) { resolv_tag_socket(sock, AID_DNS, NET_CONTEXT_INVALID_PID); }); } int PrivateDnsConfiguration::setDoh(int32_t netId, uint32_t mark, Loading doh.h +5 −2 Original line number Diff line number Diff line Loading @@ -20,7 +20,7 @@ #pragma once /* Generated with cbindgen:0.17.0 */ /* Generated with cbindgen:0.20.0 */ #include <stdint.h> #include <sys/types.h> Loading Loading @@ -55,6 +55,8 @@ struct DohDispatcher; using ValidationCallback = void (*)(uint32_t net_id, bool success, const char* ip_addr, const char* host); using TagSocketCallback = void (*)(int32_t sock); extern "C" { /// Performs static initialization for android logger. Loading @@ -65,7 +67,8 @@ void doh_set_log_level(uint32_t level); /// Performs the initialization for the DoH engine. /// Creates and returns a DoH engine instance. DohDispatcher* doh_dispatcher_new(ValidationCallback ptr); DohDispatcher* doh_dispatcher_new(ValidationCallback validation_fn, TagSocketCallback tag_socket_fn); /// Deletes a DoH engine created by doh_dispatcher_new(). /// # Safety Loading doh.rs +56 −18 Original line number Diff line number Diff line Loading @@ -82,6 +82,7 @@ type QueryResponder = oneshot::Sender<Response>; type DnsRequest = Vec<quiche::h3::Header>; type ValidationCallback = extern "C" fn(net_id: uint32_t, success: bool, ip_addr: *const c_char, host: *const c_char); type TagSocketCallback = extern "C" fn(sock: int32_t); #[derive(Eq, PartialEq, Debug)] enum QueryError { Loading Loading @@ -190,7 +191,10 @@ pub struct DohDispatcher { // DoH dispatcher impl DohDispatcher { fn new(validation_fn: ValidationCallback) -> Result<Box<DohDispatcher>> { fn new( validation_fn: ValidationCallback, tag_socket_fn: TagSocketCallback, ) -> Result<Box<DohDispatcher>> { let (cmd_sender, cmd_receiver) = mpsc::channel::<DohCommand>(MAX_BUFFERED_CMD_SIZE); let runtime = Arc::new( Builder::new_multi_thread() Loading @@ -200,7 +204,8 @@ impl DohDispatcher { .build() .expect("Failed to create tokio runtime"), ); let join_handle = runtime.spawn(doh_handler(cmd_receiver, runtime.clone(), validation_fn)); let join_handle = runtime.spawn(doh_handler(cmd_receiver, runtime.clone(), validation_fn, tag_socket_fn)); Ok(Box::new(DohDispatcher { cmd_sender, join_handle, runtime })) } Loading @@ -224,12 +229,14 @@ struct DohConnection { state: ConnectionState, pending_queries: Vec<(DnsRequest, QueryResponder, Instant)>, cached_session: Option<Vec<u8>>, tag_socket_fn: TagSocketCallback, } impl DohConnection { fn new( info: &ServerInfo, shared_config: Arc<Mutex<QuicheConfigCache>>, tag_socket_fn: TagSocketCallback, ) -> Result<DohConnection> { let mut scid = [0; quiche::MAX_CONN_ID_LEN]; ring::rand::SystemRandom::new().fill(&mut scid).context("failed to generate scid")?; Loading @@ -240,6 +247,7 @@ impl DohConnection { state: ConnectionState::Idle, pending_queries: Vec::new(), cached_session: None, tag_socket_fn, }) } Loading @@ -247,6 +255,7 @@ impl DohConnection { self.state = match self.state { ConnectionState::Idle => { let udp_sk_std = make_doh_udp_socket(self.info.peer_addr, self.info.sk_mark)?; (self.tag_socket_fn)(udp_sk_std.as_raw_fd()); let udp_sk = UdpSocket::from_std(udp_sk_std)?; let connid = quiche::ConnectionId::from_ref(&self.scid); let mut cache = self.shared_config.lock().unwrap(); Loading Loading @@ -775,6 +784,7 @@ fn make_connection_if_needed( info: &ServerInfo, doh_conn_map: &mut HashMap<u32, (ServerInfo, Option<DohConnection>)>, shared_config: Arc<Mutex<QuicheConfigCache>>, tag_socket_fn: TagSocketCallback, ) -> Result<Option<DohConnection>> { // Check if connection exists. match doh_conn_map.get(&info.net_id) { Loading @@ -790,7 +800,7 @@ fn make_connection_if_needed( // TODO: change the inner connection instead of removing? _ => doh_conn_map.remove(&info.net_id), }; let doh = DohConnection::new(info, shared_config)?; let doh = DohConnection::new(info, shared_config, tag_socket_fn)?; doh_conn_map.insert(info.net_id, (info.clone(), None)); Ok(Some(doh)) } Loading Loading @@ -861,6 +871,7 @@ async fn doh_handler( mut cmd_rx: CmdReceiver, runtime: Arc<Runtime>, validation_fn: ValidationCallback, tag_socket_fn: TagSocketCallback, ) -> Result<()> { info!("doh_dispatcher entry"); let config_cache = Arc::new(Mutex::new(QuicheConfigCache { cert_path: None, config: None })); Loading Loading @@ -888,7 +899,7 @@ async fn doh_handler( trace!("recv {:?}", cmd); match cmd { DohCommand::Probe { info, timeout: t } => { match make_connection_if_needed(&info, &mut doh_conn_map, config_cache.clone()) { match make_connection_if_needed(&info, &mut doh_conn_map, config_cache.clone(), tag_socket_fn) { Ok(Some(doh)) => { // Create a new async task associated to the DoH connection. probe_futures.push(probe_task(info, doh, t)); Loading Loading @@ -1051,8 +1062,11 @@ pub extern "C" fn doh_set_log_level(level: u32) { /// Performs the initialization for the DoH engine. /// Creates and returns a DoH engine instance. #[no_mangle] pub extern "C" fn doh_dispatcher_new(ptr: ValidationCallback) -> *mut DohDispatcher { match DohDispatcher::new(ptr) { pub extern "C" fn doh_dispatcher_new( validation_fn: ValidationCallback, tag_socket_fn: TagSocketCallback, ) -> *mut DohDispatcher { match DohDispatcher::new(validation_fn, tag_socket_fn) { Ok(c) => Box::into_raw(c), Err(e) => { error!("doh_dispatcher_new: failed: {:?}", e); Loading Loading @@ -1261,12 +1275,21 @@ mod tests { (info, test_map, config_cache, rt) } extern "C" fn tag_socket_cb(sock: int32_t) { assert!(sock >= 0); } #[test] fn make_connection_if_needed() { let (info, mut test_map, config, rt) = make_testing_variables(); rt.block_on(async { // Expect to make a new connection. let mut doh = super::make_connection_if_needed(&info, &mut test_map, config.clone()) let mut doh = super::make_connection_if_needed( &info, &mut test_map, config.clone(), tag_socket_cb, ) .unwrap() .unwrap(); assert_eq!(doh.info.net_id, info.net_id); Loading @@ -1274,7 +1297,12 @@ mod tests { doh.state = ConnectionState::Error; test_map.insert(info.net_id, (info.clone(), Some(doh))); // Expect that we will get a connection with fail status that we added to the map before. let mut doh = super::make_connection_if_needed(&info, &mut test_map, config.clone()) let mut doh = super::make_connection_if_needed( &info, &mut test_map, config.clone(), tag_socket_cb, ) .unwrap() .unwrap(); assert_eq!(doh.info.net_id, info.net_id); Loading @@ -1282,7 +1310,12 @@ mod tests { doh.state = make_dummy_connected_state(); test_map.insert(info.net_id, (info.clone(), Some(doh))); // Expect that we will get None because the map contains a connection with ready status. assert!(super::make_connection_if_needed(&info, &mut test_map, config.clone()) assert!(super::make_connection_if_needed( &info, &mut test_map, config.clone(), tag_socket_cb ) .unwrap() .is_none()); }); Loading Loading @@ -1328,7 +1361,12 @@ mod tests { // Test the connection broken case. test_map.clear(); let (resp_tx, resp_rx) = oneshot::channel(); let mut doh = super::make_connection_if_needed(&info, &mut test_map, config.clone()) let mut doh = super::make_connection_if_needed( &info, &mut test_map, config.clone(), tag_socket_cb, ) .unwrap() .unwrap(); doh.state = ConnectionState::Error; Loading tests/doh_ffi_test.cpp +5 −2 Original line number Diff line number Diff line Loading @@ -40,14 +40,17 @@ TEST(DoHFFITest, SmokeTest) { // To ensure that we have a real network. ASSERT_GE(dnsNetId, MINIMAL_NET_ID) << "No available networks"; auto callback = [](uint32_t netId, bool success, const char* ip_addr, const char* host) { auto validation_cb = [](uint32_t netId, bool success, const char* ip_addr, const char* host) { EXPECT_EQ(netId, dnsNetId); EXPECT_TRUE(success); EXPECT_STREQ(ip_addr, GOOGLE_SERVER_IP); EXPECT_STREQ(host, ""); cv.notify_one(); }; DohDispatcher* doh = doh_dispatcher_new(callback); auto tag_socket_cb = [](int32_t sock) { EXPECT_GE(sock, 0); }; DohDispatcher* doh = doh_dispatcher_new(validation_cb, tag_socket_cb); EXPECT_TRUE(doh != nullptr); // TODO: Use a local server instead of dns.google. Loading Loading
PrivateDnsConfiguration.cpp +2 −1 Original line number Diff line number Diff line Loading @@ -428,7 +428,8 @@ void PrivateDnsConfiguration::initDohLocked() { [](uint32_t net_id, bool success, const char* ip_addr, const char* host) { android::net::PrivateDnsConfiguration::getInstance().onDohStatusUpdate( net_id, success, ip_addr, host); }); }, [](int32_t sock) { resolv_tag_socket(sock, AID_DNS, NET_CONTEXT_INVALID_PID); }); } int PrivateDnsConfiguration::setDoh(int32_t netId, uint32_t mark, Loading
doh.h +5 −2 Original line number Diff line number Diff line Loading @@ -20,7 +20,7 @@ #pragma once /* Generated with cbindgen:0.17.0 */ /* Generated with cbindgen:0.20.0 */ #include <stdint.h> #include <sys/types.h> Loading Loading @@ -55,6 +55,8 @@ struct DohDispatcher; using ValidationCallback = void (*)(uint32_t net_id, bool success, const char* ip_addr, const char* host); using TagSocketCallback = void (*)(int32_t sock); extern "C" { /// Performs static initialization for android logger. Loading @@ -65,7 +67,8 @@ void doh_set_log_level(uint32_t level); /// Performs the initialization for the DoH engine. /// Creates and returns a DoH engine instance. DohDispatcher* doh_dispatcher_new(ValidationCallback ptr); DohDispatcher* doh_dispatcher_new(ValidationCallback validation_fn, TagSocketCallback tag_socket_fn); /// Deletes a DoH engine created by doh_dispatcher_new(). /// # Safety Loading
doh.rs +56 −18 Original line number Diff line number Diff line Loading @@ -82,6 +82,7 @@ type QueryResponder = oneshot::Sender<Response>; type DnsRequest = Vec<quiche::h3::Header>; type ValidationCallback = extern "C" fn(net_id: uint32_t, success: bool, ip_addr: *const c_char, host: *const c_char); type TagSocketCallback = extern "C" fn(sock: int32_t); #[derive(Eq, PartialEq, Debug)] enum QueryError { Loading Loading @@ -190,7 +191,10 @@ pub struct DohDispatcher { // DoH dispatcher impl DohDispatcher { fn new(validation_fn: ValidationCallback) -> Result<Box<DohDispatcher>> { fn new( validation_fn: ValidationCallback, tag_socket_fn: TagSocketCallback, ) -> Result<Box<DohDispatcher>> { let (cmd_sender, cmd_receiver) = mpsc::channel::<DohCommand>(MAX_BUFFERED_CMD_SIZE); let runtime = Arc::new( Builder::new_multi_thread() Loading @@ -200,7 +204,8 @@ impl DohDispatcher { .build() .expect("Failed to create tokio runtime"), ); let join_handle = runtime.spawn(doh_handler(cmd_receiver, runtime.clone(), validation_fn)); let join_handle = runtime.spawn(doh_handler(cmd_receiver, runtime.clone(), validation_fn, tag_socket_fn)); Ok(Box::new(DohDispatcher { cmd_sender, join_handle, runtime })) } Loading @@ -224,12 +229,14 @@ struct DohConnection { state: ConnectionState, pending_queries: Vec<(DnsRequest, QueryResponder, Instant)>, cached_session: Option<Vec<u8>>, tag_socket_fn: TagSocketCallback, } impl DohConnection { fn new( info: &ServerInfo, shared_config: Arc<Mutex<QuicheConfigCache>>, tag_socket_fn: TagSocketCallback, ) -> Result<DohConnection> { let mut scid = [0; quiche::MAX_CONN_ID_LEN]; ring::rand::SystemRandom::new().fill(&mut scid).context("failed to generate scid")?; Loading @@ -240,6 +247,7 @@ impl DohConnection { state: ConnectionState::Idle, pending_queries: Vec::new(), cached_session: None, tag_socket_fn, }) } Loading @@ -247,6 +255,7 @@ impl DohConnection { self.state = match self.state { ConnectionState::Idle => { let udp_sk_std = make_doh_udp_socket(self.info.peer_addr, self.info.sk_mark)?; (self.tag_socket_fn)(udp_sk_std.as_raw_fd()); let udp_sk = UdpSocket::from_std(udp_sk_std)?; let connid = quiche::ConnectionId::from_ref(&self.scid); let mut cache = self.shared_config.lock().unwrap(); Loading Loading @@ -775,6 +784,7 @@ fn make_connection_if_needed( info: &ServerInfo, doh_conn_map: &mut HashMap<u32, (ServerInfo, Option<DohConnection>)>, shared_config: Arc<Mutex<QuicheConfigCache>>, tag_socket_fn: TagSocketCallback, ) -> Result<Option<DohConnection>> { // Check if connection exists. match doh_conn_map.get(&info.net_id) { Loading @@ -790,7 +800,7 @@ fn make_connection_if_needed( // TODO: change the inner connection instead of removing? _ => doh_conn_map.remove(&info.net_id), }; let doh = DohConnection::new(info, shared_config)?; let doh = DohConnection::new(info, shared_config, tag_socket_fn)?; doh_conn_map.insert(info.net_id, (info.clone(), None)); Ok(Some(doh)) } Loading Loading @@ -861,6 +871,7 @@ async fn doh_handler( mut cmd_rx: CmdReceiver, runtime: Arc<Runtime>, validation_fn: ValidationCallback, tag_socket_fn: TagSocketCallback, ) -> Result<()> { info!("doh_dispatcher entry"); let config_cache = Arc::new(Mutex::new(QuicheConfigCache { cert_path: None, config: None })); Loading Loading @@ -888,7 +899,7 @@ async fn doh_handler( trace!("recv {:?}", cmd); match cmd { DohCommand::Probe { info, timeout: t } => { match make_connection_if_needed(&info, &mut doh_conn_map, config_cache.clone()) { match make_connection_if_needed(&info, &mut doh_conn_map, config_cache.clone(), tag_socket_fn) { Ok(Some(doh)) => { // Create a new async task associated to the DoH connection. probe_futures.push(probe_task(info, doh, t)); Loading Loading @@ -1051,8 +1062,11 @@ pub extern "C" fn doh_set_log_level(level: u32) { /// Performs the initialization for the DoH engine. /// Creates and returns a DoH engine instance. #[no_mangle] pub extern "C" fn doh_dispatcher_new(ptr: ValidationCallback) -> *mut DohDispatcher { match DohDispatcher::new(ptr) { pub extern "C" fn doh_dispatcher_new( validation_fn: ValidationCallback, tag_socket_fn: TagSocketCallback, ) -> *mut DohDispatcher { match DohDispatcher::new(validation_fn, tag_socket_fn) { Ok(c) => Box::into_raw(c), Err(e) => { error!("doh_dispatcher_new: failed: {:?}", e); Loading Loading @@ -1261,12 +1275,21 @@ mod tests { (info, test_map, config_cache, rt) } extern "C" fn tag_socket_cb(sock: int32_t) { assert!(sock >= 0); } #[test] fn make_connection_if_needed() { let (info, mut test_map, config, rt) = make_testing_variables(); rt.block_on(async { // Expect to make a new connection. let mut doh = super::make_connection_if_needed(&info, &mut test_map, config.clone()) let mut doh = super::make_connection_if_needed( &info, &mut test_map, config.clone(), tag_socket_cb, ) .unwrap() .unwrap(); assert_eq!(doh.info.net_id, info.net_id); Loading @@ -1274,7 +1297,12 @@ mod tests { doh.state = ConnectionState::Error; test_map.insert(info.net_id, (info.clone(), Some(doh))); // Expect that we will get a connection with fail status that we added to the map before. let mut doh = super::make_connection_if_needed(&info, &mut test_map, config.clone()) let mut doh = super::make_connection_if_needed( &info, &mut test_map, config.clone(), tag_socket_cb, ) .unwrap() .unwrap(); assert_eq!(doh.info.net_id, info.net_id); Loading @@ -1282,7 +1310,12 @@ mod tests { doh.state = make_dummy_connected_state(); test_map.insert(info.net_id, (info.clone(), Some(doh))); // Expect that we will get None because the map contains a connection with ready status. assert!(super::make_connection_if_needed(&info, &mut test_map, config.clone()) assert!(super::make_connection_if_needed( &info, &mut test_map, config.clone(), tag_socket_cb ) .unwrap() .is_none()); }); Loading Loading @@ -1328,7 +1361,12 @@ mod tests { // Test the connection broken case. test_map.clear(); let (resp_tx, resp_rx) = oneshot::channel(); let mut doh = super::make_connection_if_needed(&info, &mut test_map, config.clone()) let mut doh = super::make_connection_if_needed( &info, &mut test_map, config.clone(), tag_socket_cb, ) .unwrap() .unwrap(); doh.state = ConnectionState::Error; Loading
tests/doh_ffi_test.cpp +5 −2 Original line number Diff line number Diff line Loading @@ -40,14 +40,17 @@ TEST(DoHFFITest, SmokeTest) { // To ensure that we have a real network. ASSERT_GE(dnsNetId, MINIMAL_NET_ID) << "No available networks"; auto callback = [](uint32_t netId, bool success, const char* ip_addr, const char* host) { auto validation_cb = [](uint32_t netId, bool success, const char* ip_addr, const char* host) { EXPECT_EQ(netId, dnsNetId); EXPECT_TRUE(success); EXPECT_STREQ(ip_addr, GOOGLE_SERVER_IP); EXPECT_STREQ(host, ""); cv.notify_one(); }; DohDispatcher* doh = doh_dispatcher_new(callback); auto tag_socket_cb = [](int32_t sock) { EXPECT_GE(sock, 0); }; DohDispatcher* doh = doh_dispatcher_new(validation_cb, tag_socket_cb); EXPECT_TRUE(doh != nullptr); // TODO: Use a local server instead of dns.google. Loading