Loading TEST_MAPPING +1 −0 Original line number Diff line number Diff line { "presubmit": [ { "name": "doh_unit_test" }, { "name": "resolv_integration_test" }, { "name": "resolv_gold_test" }, { "name": "resolv_unit_test" }, Loading doh.rs +288 −2 Original line number Diff line number Diff line Loading @@ -84,7 +84,7 @@ type DnsRequestArg = [quiche::h3::Header]; type ValidationCallback = extern "C" fn(net_id: uint32_t, success: bool, ip_addr: *const c_char, host: *const c_char); #[derive(Debug)] #[derive(Eq, PartialEq, Debug)] enum QueryError { BrokenServer, ConnectionError, Loading @@ -102,7 +102,7 @@ struct ServerInfo { cert_path: Option<String>, } #[derive(Debug)] #[derive(Eq, PartialEq, Debug)] enum Response { Error { error: QueryError }, Success { answer: Vec<u8> }, Loading Loading @@ -983,3 +983,289 @@ pub extern "C" fn doh_net_delete(doh: &mut DohDispatcher, net_id: uint32_t) { error!("Failed to send the query: {:?}", e); } } #[cfg(test)] mod tests { use super::*; use quiche::h3::NameValue; use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4}; const TEST_NET_ID: u32 = 50; const PROBE_QUERY_SIZE: usize = 56; const H3_DNS_REQUEST_HEADER_SIZE: usize = 6; const TEST_MARK: u32 = 0xD0033; const LOOPBACK_ADDR: &str = "127.0.0.1:443"; const LOCALHOST_URL: &str = "https://mylocal.com/dns-query"; // TODO: Make some tests for DohConnection and QuicheConfigCache. fn make_testing_variables() -> ( ServerInfo, HashMap<u32, (ServerInfo, Option<DohConnection>)>, QuicheConfigCache, Arc<Runtime>, ) { let test_map: HashMap<u32, (ServerInfo, Option<DohConnection>)> = HashMap::new(); let info = ServerInfo { net_id: TEST_NET_ID, url: Url::parse(LOCALHOST_URL).unwrap(), peer_addr: LOOPBACK_ADDR.parse().unwrap(), domain: None, sk_mark: 0, cert_path: None, }; let config: QuicheConfigCache = QuicheConfigCache { cert_path: None, config: None }; let rt = Arc::new( Builder::new_current_thread() .thread_name("test-runtime") .enable_all() .build() .expect("Failed to create testing tokio runtime"), ); (info, test_map, config, rt) } #[test] fn make_connection_if_needed() { let (info, mut test_map, mut config, rt) = make_testing_variables(); rt.block_on(async { // Expect to make a new connection. let mut doh = super::make_connection_if_needed(&info, &mut test_map, &mut config) .unwrap() .unwrap(); assert_eq!(doh.net_id, info.net_id); assert_eq!(doh.status, ConnectionStatus::Pending); doh.status = ConnectionStatus::Fail; test_map.insert(info.net_id, (info.clone(), Some(doh))); // Expect that we will get a connection with fail status that we added to the map before. let mut doh = super::make_connection_if_needed(&info, &mut test_map, &mut config) .unwrap() .unwrap(); assert_eq!(doh.net_id, info.net_id); assert_eq!(doh.status, ConnectionStatus::Fail); doh.status = ConnectionStatus::Ready; test_map.insert(info.net_id, (info.clone(), Some(doh))); // Expect that we will get None because the map contains a connection with ready status. assert!(super::make_connection_if_needed(&info, &mut test_map, &mut config) .unwrap() .is_none()); }); } #[test] fn handle_query_cmd() { let (info, mut test_map, mut config, rt) = make_testing_variables(); let t = Duration::from_millis(100); rt.block_on(async { // Test no available server cases. let (resp_tx, resp_rx) = oneshot::channel(); let query = super::make_probe_query().unwrap(); super::handle_query_cmd( info.net_id, query.clone(), t, resp_tx, &mut test_map, &mut config, ) .await; assert_eq!( timeout(t, resp_rx).await.unwrap().unwrap(), Response::Error { error: QueryError::ServerNotReady } ); let (resp_tx, resp_rx) = oneshot::channel(); test_map.insert(info.net_id, (info.clone(), None)); super::handle_query_cmd( info.net_id, query.clone(), t, resp_tx, &mut test_map, &mut config, ) .await; assert_eq!( timeout(t, resp_rx).await.unwrap().unwrap(), Response::Error { error: QueryError::ServerNotReady } ); // Test the connection broken case. test_map.clear(); let (resp_tx, resp_rx) = oneshot::channel(); let mut doh = super::make_connection_if_needed(&info, &mut test_map, &mut config) .unwrap() .unwrap(); doh.status = ConnectionStatus::Fail; test_map.insert(info.net_id, (info.clone(), Some(doh))); super::handle_query_cmd( info.net_id, query.clone(), t, resp_tx, &mut test_map, &mut config, ) .await; assert_eq!( timeout(t, resp_rx).await.unwrap().unwrap(), Response::Error { error: QueryError::BrokenServer } ); }); } extern "C" fn success_cb( net_id: uint32_t, success: bool, ip_addr: *const c_char, host: *const c_char, ) { assert!(success); unsafe { assert_validation_info(net_id, ip_addr, host); } } extern "C" fn fail_cb( net_id: uint32_t, success: bool, ip_addr: *const c_char, host: *const c_char, ) { assert!(!success); unsafe { assert_validation_info(net_id, ip_addr, host); } } // # Safety // `ip_addr`, `host` are null terminated strings unsafe fn assert_validation_info( net_id: uint32_t, ip_addr: *const c_char, host: *const c_char, ) { assert_eq!(net_id, TEST_NET_ID); let ip_addr = std::ffi::CStr::from_ptr(ip_addr).to_str().unwrap(); let expected_addr: SocketAddr = LOOPBACK_ADDR.parse().unwrap(); assert_eq!(ip_addr, expected_addr.ip().to_string()); let host = std::ffi::CStr::from_ptr(host).to_str().unwrap(); assert_eq!(host, ""); } #[test] fn report_private_dns_validation() { let info = ServerInfo { net_id: TEST_NET_ID, url: Url::parse(LOCALHOST_URL).unwrap(), peer_addr: LOOPBACK_ADDR.parse().unwrap(), domain: None, sk_mark: 0, cert_path: None, }; let rt = Arc::new( Builder::new_current_thread() .thread_name("test-runtime") .build() .expect("Failed to create testing tokio runtime"), ); let default_panic = std::panic::take_hook(); // Exit the test if the worker inside tokio runtime panicked. std::panic::set_hook(Box::new(move |info| { default_panic(info); std::process::exit(1); })); super::report_private_dns_validation( &info, &ConnectionStatus::Ready, rt.clone(), success_cb, ); super::report_private_dns_validation(&info, &ConnectionStatus::Fail, rt.clone(), fail_cb); super::report_private_dns_validation( &info, &ConnectionStatus::Pending, rt.clone(), fail_cb, ); super::report_private_dns_validation(&info, &ConnectionStatus::Idle, rt, fail_cb); } #[test] fn make_probe_query_and_request() { let probe_query = super::make_probe_query().unwrap(); let url = Url::parse(LOCALHOST_URL).unwrap(); let request = make_dns_request(&probe_query, &url).unwrap(); // Verify H3 DNS request. assert_eq!(request.len(), H3_DNS_REQUEST_HEADER_SIZE); assert_eq!(request[0].name(), b":method"); assert_eq!(request[0].value(), b"GET"); assert_eq!(request[1].name(), b":scheme"); assert_eq!(request[1].value(), b"https"); assert_eq!(request[2].name(), b":authority"); assert_eq!(request[2].value(), url.host_str().unwrap().as_bytes()); assert_eq!(request[3].name(), b":path"); let mut path = String::from(url.path()); path.push_str("?dns="); path.push_str(&probe_query); assert_eq!(request[3].value(), path.as_bytes()); assert_eq!(request[5].name(), b"accept"); assert_eq!(request[5].value(), b"application/dns-message"); // Verify DNS probe packet. let bytes = base64::decode_config(probe_query, base64::URL_SAFE_NO_PAD).unwrap(); assert_eq!(bytes.len(), PROBE_QUERY_SIZE); // TODO: Parse the result to ensure it's a valid DNS packet. } #[test] fn create_quiche_config() { assert!( super::create_quiche_config(None).is_ok(), "quiche config without cert creating failed" ); assert!( super::create_quiche_config(Some("data/local/tmp/")).is_ok(), "quiche config with cert creating failed" ); } #[test] fn make_doh_udp_socket() { // Make a socket connecting to loopback with a test mark. let sk = super::make_doh_udp_socket(LOOPBACK_ADDR.parse().unwrap(), TEST_MARK).unwrap(); // Check if the socket is connected to loopback. assert_eq!( sk.peer_addr().unwrap(), SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 1), DOH_PORT)) ); // Check if the socket mark is correct. let fd: RawFd = sk.as_raw_fd(); let mut mark: u32 = 50; let mut size = std::mem::size_of::<u32>() as libc::socklen_t; unsafe { // Safety: It's fine since the fd belongs to this test. assert_eq!( libc::getsockopt( fd, libc::SOL_SOCKET, libc::SO_MARK, &mut mark as *mut _ as *mut libc::c_void, &mut size as *mut libc::socklen_t, ), 0 ); } assert_eq!(mark, TEST_MARK); // Check if the socket is non-blocking. unsafe { // Safety: It's fine since the fd belongs to this test. assert_eq!(libc::fcntl(fd, libc::F_GETFL, 0) & libc::O_NONBLOCK, libc::O_NONBLOCK); } } } Loading
TEST_MAPPING +1 −0 Original line number Diff line number Diff line { "presubmit": [ { "name": "doh_unit_test" }, { "name": "resolv_integration_test" }, { "name": "resolv_gold_test" }, { "name": "resolv_unit_test" }, Loading
doh.rs +288 −2 Original line number Diff line number Diff line Loading @@ -84,7 +84,7 @@ type DnsRequestArg = [quiche::h3::Header]; type ValidationCallback = extern "C" fn(net_id: uint32_t, success: bool, ip_addr: *const c_char, host: *const c_char); #[derive(Debug)] #[derive(Eq, PartialEq, Debug)] enum QueryError { BrokenServer, ConnectionError, Loading @@ -102,7 +102,7 @@ struct ServerInfo { cert_path: Option<String>, } #[derive(Debug)] #[derive(Eq, PartialEq, Debug)] enum Response { Error { error: QueryError }, Success { answer: Vec<u8> }, Loading Loading @@ -983,3 +983,289 @@ pub extern "C" fn doh_net_delete(doh: &mut DohDispatcher, net_id: uint32_t) { error!("Failed to send the query: {:?}", e); } } #[cfg(test)] mod tests { use super::*; use quiche::h3::NameValue; use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4}; const TEST_NET_ID: u32 = 50; const PROBE_QUERY_SIZE: usize = 56; const H3_DNS_REQUEST_HEADER_SIZE: usize = 6; const TEST_MARK: u32 = 0xD0033; const LOOPBACK_ADDR: &str = "127.0.0.1:443"; const LOCALHOST_URL: &str = "https://mylocal.com/dns-query"; // TODO: Make some tests for DohConnection and QuicheConfigCache. fn make_testing_variables() -> ( ServerInfo, HashMap<u32, (ServerInfo, Option<DohConnection>)>, QuicheConfigCache, Arc<Runtime>, ) { let test_map: HashMap<u32, (ServerInfo, Option<DohConnection>)> = HashMap::new(); let info = ServerInfo { net_id: TEST_NET_ID, url: Url::parse(LOCALHOST_URL).unwrap(), peer_addr: LOOPBACK_ADDR.parse().unwrap(), domain: None, sk_mark: 0, cert_path: None, }; let config: QuicheConfigCache = QuicheConfigCache { cert_path: None, config: None }; let rt = Arc::new( Builder::new_current_thread() .thread_name("test-runtime") .enable_all() .build() .expect("Failed to create testing tokio runtime"), ); (info, test_map, config, rt) } #[test] fn make_connection_if_needed() { let (info, mut test_map, mut config, rt) = make_testing_variables(); rt.block_on(async { // Expect to make a new connection. let mut doh = super::make_connection_if_needed(&info, &mut test_map, &mut config) .unwrap() .unwrap(); assert_eq!(doh.net_id, info.net_id); assert_eq!(doh.status, ConnectionStatus::Pending); doh.status = ConnectionStatus::Fail; test_map.insert(info.net_id, (info.clone(), Some(doh))); // Expect that we will get a connection with fail status that we added to the map before. let mut doh = super::make_connection_if_needed(&info, &mut test_map, &mut config) .unwrap() .unwrap(); assert_eq!(doh.net_id, info.net_id); assert_eq!(doh.status, ConnectionStatus::Fail); doh.status = ConnectionStatus::Ready; test_map.insert(info.net_id, (info.clone(), Some(doh))); // Expect that we will get None because the map contains a connection with ready status. assert!(super::make_connection_if_needed(&info, &mut test_map, &mut config) .unwrap() .is_none()); }); } #[test] fn handle_query_cmd() { let (info, mut test_map, mut config, rt) = make_testing_variables(); let t = Duration::from_millis(100); rt.block_on(async { // Test no available server cases. let (resp_tx, resp_rx) = oneshot::channel(); let query = super::make_probe_query().unwrap(); super::handle_query_cmd( info.net_id, query.clone(), t, resp_tx, &mut test_map, &mut config, ) .await; assert_eq!( timeout(t, resp_rx).await.unwrap().unwrap(), Response::Error { error: QueryError::ServerNotReady } ); let (resp_tx, resp_rx) = oneshot::channel(); test_map.insert(info.net_id, (info.clone(), None)); super::handle_query_cmd( info.net_id, query.clone(), t, resp_tx, &mut test_map, &mut config, ) .await; assert_eq!( timeout(t, resp_rx).await.unwrap().unwrap(), Response::Error { error: QueryError::ServerNotReady } ); // Test the connection broken case. test_map.clear(); let (resp_tx, resp_rx) = oneshot::channel(); let mut doh = super::make_connection_if_needed(&info, &mut test_map, &mut config) .unwrap() .unwrap(); doh.status = ConnectionStatus::Fail; test_map.insert(info.net_id, (info.clone(), Some(doh))); super::handle_query_cmd( info.net_id, query.clone(), t, resp_tx, &mut test_map, &mut config, ) .await; assert_eq!( timeout(t, resp_rx).await.unwrap().unwrap(), Response::Error { error: QueryError::BrokenServer } ); }); } extern "C" fn success_cb( net_id: uint32_t, success: bool, ip_addr: *const c_char, host: *const c_char, ) { assert!(success); unsafe { assert_validation_info(net_id, ip_addr, host); } } extern "C" fn fail_cb( net_id: uint32_t, success: bool, ip_addr: *const c_char, host: *const c_char, ) { assert!(!success); unsafe { assert_validation_info(net_id, ip_addr, host); } } // # Safety // `ip_addr`, `host` are null terminated strings unsafe fn assert_validation_info( net_id: uint32_t, ip_addr: *const c_char, host: *const c_char, ) { assert_eq!(net_id, TEST_NET_ID); let ip_addr = std::ffi::CStr::from_ptr(ip_addr).to_str().unwrap(); let expected_addr: SocketAddr = LOOPBACK_ADDR.parse().unwrap(); assert_eq!(ip_addr, expected_addr.ip().to_string()); let host = std::ffi::CStr::from_ptr(host).to_str().unwrap(); assert_eq!(host, ""); } #[test] fn report_private_dns_validation() { let info = ServerInfo { net_id: TEST_NET_ID, url: Url::parse(LOCALHOST_URL).unwrap(), peer_addr: LOOPBACK_ADDR.parse().unwrap(), domain: None, sk_mark: 0, cert_path: None, }; let rt = Arc::new( Builder::new_current_thread() .thread_name("test-runtime") .build() .expect("Failed to create testing tokio runtime"), ); let default_panic = std::panic::take_hook(); // Exit the test if the worker inside tokio runtime panicked. std::panic::set_hook(Box::new(move |info| { default_panic(info); std::process::exit(1); })); super::report_private_dns_validation( &info, &ConnectionStatus::Ready, rt.clone(), success_cb, ); super::report_private_dns_validation(&info, &ConnectionStatus::Fail, rt.clone(), fail_cb); super::report_private_dns_validation( &info, &ConnectionStatus::Pending, rt.clone(), fail_cb, ); super::report_private_dns_validation(&info, &ConnectionStatus::Idle, rt, fail_cb); } #[test] fn make_probe_query_and_request() { let probe_query = super::make_probe_query().unwrap(); let url = Url::parse(LOCALHOST_URL).unwrap(); let request = make_dns_request(&probe_query, &url).unwrap(); // Verify H3 DNS request. assert_eq!(request.len(), H3_DNS_REQUEST_HEADER_SIZE); assert_eq!(request[0].name(), b":method"); assert_eq!(request[0].value(), b"GET"); assert_eq!(request[1].name(), b":scheme"); assert_eq!(request[1].value(), b"https"); assert_eq!(request[2].name(), b":authority"); assert_eq!(request[2].value(), url.host_str().unwrap().as_bytes()); assert_eq!(request[3].name(), b":path"); let mut path = String::from(url.path()); path.push_str("?dns="); path.push_str(&probe_query); assert_eq!(request[3].value(), path.as_bytes()); assert_eq!(request[5].name(), b"accept"); assert_eq!(request[5].value(), b"application/dns-message"); // Verify DNS probe packet. let bytes = base64::decode_config(probe_query, base64::URL_SAFE_NO_PAD).unwrap(); assert_eq!(bytes.len(), PROBE_QUERY_SIZE); // TODO: Parse the result to ensure it's a valid DNS packet. } #[test] fn create_quiche_config() { assert!( super::create_quiche_config(None).is_ok(), "quiche config without cert creating failed" ); assert!( super::create_quiche_config(Some("data/local/tmp/")).is_ok(), "quiche config with cert creating failed" ); } #[test] fn make_doh_udp_socket() { // Make a socket connecting to loopback with a test mark. let sk = super::make_doh_udp_socket(LOOPBACK_ADDR.parse().unwrap(), TEST_MARK).unwrap(); // Check if the socket is connected to loopback. assert_eq!( sk.peer_addr().unwrap(), SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 1), DOH_PORT)) ); // Check if the socket mark is correct. let fd: RawFd = sk.as_raw_fd(); let mut mark: u32 = 50; let mut size = std::mem::size_of::<u32>() as libc::socklen_t; unsafe { // Safety: It's fine since the fd belongs to this test. assert_eq!( libc::getsockopt( fd, libc::SOL_SOCKET, libc::SO_MARK, &mut mark as *mut _ as *mut libc::c_void, &mut size as *mut libc::socklen_t, ), 0 ); } assert_eq!(mark, TEST_MARK); // Check if the socket is non-blocking. unsafe { // Safety: It's fine since the fd belongs to this test. assert_eq!(libc::fcntl(fd, libc::F_GETFL, 0) & libc::O_NONBLOCK, libc::O_NONBLOCK); } } }