Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit c8017d11 authored by Luke Huang's avatar Luke Huang Committed by Automerger Merge Worker
Browse files

Merge "Add some unit test for doh" am: 509e3705

Original change: https://android-review.googlesource.com/c/platform/packages/modules/DnsResolver/+/1778905

Change-Id: I5cb8a99b88eb690b475f1429e8f49cd948270090
parents b8cb39be 509e3705
Loading
Loading
Loading
Loading
+1 −0
Original line number Diff line number Diff line
{
    "presubmit": [
        { "name": "doh_unit_test" },
        { "name": "resolv_integration_test" },
        { "name": "resolv_gold_test" },
        { "name": "resolv_unit_test" },
+288 −2
Original line number Diff line number Diff line
@@ -84,7 +84,7 @@ type DnsRequestArg = [quiche::h3::Header];
type ValidationCallback =
    extern "C" fn(net_id: uint32_t, success: bool, ip_addr: *const c_char, host: *const c_char);

#[derive(Debug)]
#[derive(Eq, PartialEq, Debug)]
enum QueryError {
    BrokenServer,
    ConnectionError,
@@ -102,7 +102,7 @@ struct ServerInfo {
    cert_path: Option<String>,
}

#[derive(Debug)]
#[derive(Eq, PartialEq, Debug)]
enum Response {
    Error { error: QueryError },
    Success { answer: Vec<u8> },
@@ -982,3 +982,289 @@ pub extern "C" fn doh_net_delete(doh: &mut DohDispatcher, net_id: uint32_t) {
        error!("Failed to send the query: {:?}", e);
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use quiche::h3::NameValue;
    use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4};

    const TEST_NET_ID: u32 = 50;
    const PROBE_QUERY_SIZE: usize = 56;
    const H3_DNS_REQUEST_HEADER_SIZE: usize = 6;
    const TEST_MARK: u32 = 0xD0033;
    const LOOPBACK_ADDR: &str = "127.0.0.1:443";
    const LOCALHOST_URL: &str = "https://mylocal.com/dns-query";

    // TODO: Make some tests for DohConnection and QuicheConfigCache.

    fn make_testing_variables() -> (
        ServerInfo,
        HashMap<u32, (ServerInfo, Option<DohConnection>)>,
        QuicheConfigCache,
        Arc<Runtime>,
    ) {
        let test_map: HashMap<u32, (ServerInfo, Option<DohConnection>)> = HashMap::new();
        let info = ServerInfo {
            net_id: TEST_NET_ID,
            url: Url::parse(LOCALHOST_URL).unwrap(),
            peer_addr: LOOPBACK_ADDR.parse().unwrap(),
            domain: None,
            sk_mark: 0,
            cert_path: None,
        };
        let config: QuicheConfigCache = QuicheConfigCache { cert_path: None, config: None };

        let rt = Arc::new(
            Builder::new_current_thread()
                .thread_name("test-runtime")
                .enable_all()
                .build()
                .expect("Failed to create testing tokio runtime"),
        );
        (info, test_map, config, rt)
    }

    #[test]
    fn make_connection_if_needed() {
        let (info, mut test_map, mut config, rt) = make_testing_variables();
        rt.block_on(async {
            // Expect to make a new connection.
            let mut doh = super::make_connection_if_needed(&info, &mut test_map, &mut config)
                .unwrap()
                .unwrap();
            assert_eq!(doh.net_id, info.net_id);
            assert_eq!(doh.status, ConnectionStatus::Pending);
            doh.status = ConnectionStatus::Fail;
            test_map.insert(info.net_id, (info.clone(), Some(doh)));
            // Expect that we will get a connection with fail status that we added to the map before.
            let mut doh = super::make_connection_if_needed(&info, &mut test_map, &mut config)
                .unwrap()
                .unwrap();
            assert_eq!(doh.net_id, info.net_id);
            assert_eq!(doh.status, ConnectionStatus::Fail);
            doh.status = ConnectionStatus::Ready;
            test_map.insert(info.net_id, (info.clone(), Some(doh)));
            // Expect that we will get None because the map contains a connection with ready status.
            assert!(super::make_connection_if_needed(&info, &mut test_map, &mut config)
                .unwrap()
                .is_none());
        });
    }

    #[test]
    fn handle_query_cmd() {
        let (info, mut test_map, mut config, rt) = make_testing_variables();
        let t = Duration::from_millis(100);

        rt.block_on(async {
            // Test no available server cases.
            let (resp_tx, resp_rx) = oneshot::channel();
            let query = super::make_probe_query().unwrap();
            super::handle_query_cmd(
                info.net_id,
                query.clone(),
                t,
                resp_tx,
                &mut test_map,
                &mut config,
            )
            .await;
            assert_eq!(
                timeout(t, resp_rx).await.unwrap().unwrap(),
                Response::Error { error: QueryError::ServerNotReady }
            );

            let (resp_tx, resp_rx) = oneshot::channel();
            test_map.insert(info.net_id, (info.clone(), None));
            super::handle_query_cmd(
                info.net_id,
                query.clone(),
                t,
                resp_tx,
                &mut test_map,
                &mut config,
            )
            .await;
            assert_eq!(
                timeout(t, resp_rx).await.unwrap().unwrap(),
                Response::Error { error: QueryError::ServerNotReady }
            );

            // Test the connection broken case.
            test_map.clear();
            let (resp_tx, resp_rx) = oneshot::channel();
            let mut doh = super::make_connection_if_needed(&info, &mut test_map, &mut config)
                .unwrap()
                .unwrap();
            doh.status = ConnectionStatus::Fail;
            test_map.insert(info.net_id, (info.clone(), Some(doh)));
            super::handle_query_cmd(
                info.net_id,
                query.clone(),
                t,
                resp_tx,
                &mut test_map,
                &mut config,
            )
            .await;
            assert_eq!(
                timeout(t, resp_rx).await.unwrap().unwrap(),
                Response::Error { error: QueryError::BrokenServer }
            );
        });
    }

    extern "C" fn success_cb(
        net_id: uint32_t,
        success: bool,
        ip_addr: *const c_char,
        host: *const c_char,
    ) {
        assert!(success);
        unsafe {
            assert_validation_info(net_id, ip_addr, host);
        }
    }

    extern "C" fn fail_cb(
        net_id: uint32_t,
        success: bool,
        ip_addr: *const c_char,
        host: *const c_char,
    ) {
        assert!(!success);
        unsafe {
            assert_validation_info(net_id, ip_addr, host);
        }
    }

    // # Safety
    // `ip_addr`, `host` are null terminated strings
    unsafe fn assert_validation_info(
        net_id: uint32_t,
        ip_addr: *const c_char,
        host: *const c_char,
    ) {
        assert_eq!(net_id, TEST_NET_ID);
        let ip_addr = std::ffi::CStr::from_ptr(ip_addr).to_str().unwrap();
        let expected_addr: SocketAddr = LOOPBACK_ADDR.parse().unwrap();
        assert_eq!(ip_addr, expected_addr.ip().to_string());
        let host = std::ffi::CStr::from_ptr(host).to_str().unwrap();
        assert_eq!(host, "");
    }

    #[test]
    fn report_private_dns_validation() {
        let info = ServerInfo {
            net_id: TEST_NET_ID,
            url: Url::parse(LOCALHOST_URL).unwrap(),
            peer_addr: LOOPBACK_ADDR.parse().unwrap(),
            domain: None,
            sk_mark: 0,
            cert_path: None,
        };
        let rt = Arc::new(
            Builder::new_current_thread()
                .thread_name("test-runtime")
                .build()
                .expect("Failed to create testing tokio runtime"),
        );
        let default_panic = std::panic::take_hook();
        // Exit the test if the worker inside tokio runtime panicked.
        std::panic::set_hook(Box::new(move |info| {
            default_panic(info);
            std::process::exit(1);
        }));
        super::report_private_dns_validation(
            &info,
            &ConnectionStatus::Ready,
            rt.clone(),
            success_cb,
        );
        super::report_private_dns_validation(&info, &ConnectionStatus::Fail, rt.clone(), fail_cb);
        super::report_private_dns_validation(
            &info,
            &ConnectionStatus::Pending,
            rt.clone(),
            fail_cb,
        );
        super::report_private_dns_validation(&info, &ConnectionStatus::Idle, rt, fail_cb);
    }

    #[test]
    fn make_probe_query_and_request() {
        let probe_query = super::make_probe_query().unwrap();
        let url = Url::parse(LOCALHOST_URL).unwrap();
        let request = make_dns_request(&probe_query, &url).unwrap();
        // Verify H3 DNS request.
        assert_eq!(request.len(), H3_DNS_REQUEST_HEADER_SIZE);
        assert_eq!(request[0].name(), b":method");
        assert_eq!(request[0].value(), b"GET");
        assert_eq!(request[1].name(), b":scheme");
        assert_eq!(request[1].value(), b"https");
        assert_eq!(request[2].name(), b":authority");
        assert_eq!(request[2].value(), url.host_str().unwrap().as_bytes());
        assert_eq!(request[3].name(), b":path");
        let mut path = String::from(url.path());
        path.push_str("?dns=");
        path.push_str(&probe_query);
        assert_eq!(request[3].value(), path.as_bytes());
        assert_eq!(request[5].name(), b"accept");
        assert_eq!(request[5].value(), b"application/dns-message");

        // Verify DNS probe packet.
        let bytes = base64::decode_config(probe_query, base64::URL_SAFE_NO_PAD).unwrap();
        assert_eq!(bytes.len(), PROBE_QUERY_SIZE);
        // TODO: Parse the result to ensure it's a valid DNS packet.
    }

    #[test]
    fn create_quiche_config() {
        assert!(
            super::create_quiche_config(None).is_ok(),
            "quiche config without cert creating failed"
        );
        assert!(
            super::create_quiche_config(Some("data/local/tmp/")).is_ok(),
            "quiche config with cert creating failed"
        );
    }

    #[test]
    fn make_doh_udp_socket() {
        // Make a socket connecting to loopback with a test mark.
        let sk = super::make_doh_udp_socket(LOOPBACK_ADDR.parse().unwrap(), TEST_MARK).unwrap();
        // Check if the socket is connected to loopback.
        assert_eq!(
            sk.peer_addr().unwrap(),
            SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 1), DOH_PORT))
        );

        // Check if the socket mark is correct.
        let fd: RawFd = sk.as_raw_fd();

        let mut mark: u32 = 50;
        let mut size = std::mem::size_of::<u32>() as libc::socklen_t;
        unsafe {
            // Safety: It's fine since the fd belongs to this test.
            assert_eq!(
                libc::getsockopt(
                    fd,
                    libc::SOL_SOCKET,
                    libc::SO_MARK,
                    &mut mark as *mut _ as *mut libc::c_void,
                    &mut size as *mut libc::socklen_t,
                ),
                0
            );
        }
        assert_eq!(mark, TEST_MARK);

        // Check if the socket is non-blocking.
        unsafe {
            // Safety: It's fine since the fd belongs to this test.
            assert_eq!(libc::fcntl(fd, libc::F_GETFL, 0) & libc::O_NONBLOCK, libc::O_NONBLOCK);
        }
    }
}