Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 509e3705 authored by Luke Huang's avatar Luke Huang Committed by Gerrit Code Review
Browse files

Merge "Add some unit test for doh"

parents 44cb608c df84c209
Loading
Loading
Loading
Loading
+1 −0
Original line number Diff line number Diff line
{
    "presubmit": [
        { "name": "doh_unit_test" },
        { "name": "resolv_integration_test" },
        { "name": "resolv_gold_test" },
        { "name": "resolv_unit_test" },
+288 −2
Original line number Diff line number Diff line
@@ -84,7 +84,7 @@ type DnsRequestArg = [quiche::h3::Header];
type ValidationCallback =
    extern "C" fn(net_id: uint32_t, success: bool, ip_addr: *const c_char, host: *const c_char);

#[derive(Debug)]
#[derive(Eq, PartialEq, Debug)]
enum QueryError {
    BrokenServer,
    ConnectionError,
@@ -102,7 +102,7 @@ struct ServerInfo {
    cert_path: Option<String>,
}

#[derive(Debug)]
#[derive(Eq, PartialEq, Debug)]
enum Response {
    Error { error: QueryError },
    Success { answer: Vec<u8> },
@@ -982,3 +982,289 @@ pub extern "C" fn doh_net_delete(doh: &mut DohDispatcher, net_id: uint32_t) {
        error!("Failed to send the query: {:?}", e);
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use quiche::h3::NameValue;
    use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4};

    const TEST_NET_ID: u32 = 50;
    const PROBE_QUERY_SIZE: usize = 56;
    const H3_DNS_REQUEST_HEADER_SIZE: usize = 6;
    const TEST_MARK: u32 = 0xD0033;
    const LOOPBACK_ADDR: &str = "127.0.0.1:443";
    const LOCALHOST_URL: &str = "https://mylocal.com/dns-query";

    // TODO: Make some tests for DohConnection and QuicheConfigCache.

    fn make_testing_variables() -> (
        ServerInfo,
        HashMap<u32, (ServerInfo, Option<DohConnection>)>,
        QuicheConfigCache,
        Arc<Runtime>,
    ) {
        let test_map: HashMap<u32, (ServerInfo, Option<DohConnection>)> = HashMap::new();
        let info = ServerInfo {
            net_id: TEST_NET_ID,
            url: Url::parse(LOCALHOST_URL).unwrap(),
            peer_addr: LOOPBACK_ADDR.parse().unwrap(),
            domain: None,
            sk_mark: 0,
            cert_path: None,
        };
        let config: QuicheConfigCache = QuicheConfigCache { cert_path: None, config: None };

        let rt = Arc::new(
            Builder::new_current_thread()
                .thread_name("test-runtime")
                .enable_all()
                .build()
                .expect("Failed to create testing tokio runtime"),
        );
        (info, test_map, config, rt)
    }

    #[test]
    fn make_connection_if_needed() {
        let (info, mut test_map, mut config, rt) = make_testing_variables();
        rt.block_on(async {
            // Expect to make a new connection.
            let mut doh = super::make_connection_if_needed(&info, &mut test_map, &mut config)
                .unwrap()
                .unwrap();
            assert_eq!(doh.net_id, info.net_id);
            assert_eq!(doh.status, ConnectionStatus::Pending);
            doh.status = ConnectionStatus::Fail;
            test_map.insert(info.net_id, (info.clone(), Some(doh)));
            // Expect that we will get a connection with fail status that we added to the map before.
            let mut doh = super::make_connection_if_needed(&info, &mut test_map, &mut config)
                .unwrap()
                .unwrap();
            assert_eq!(doh.net_id, info.net_id);
            assert_eq!(doh.status, ConnectionStatus::Fail);
            doh.status = ConnectionStatus::Ready;
            test_map.insert(info.net_id, (info.clone(), Some(doh)));
            // Expect that we will get None because the map contains a connection with ready status.
            assert!(super::make_connection_if_needed(&info, &mut test_map, &mut config)
                .unwrap()
                .is_none());
        });
    }

    #[test]
    fn handle_query_cmd() {
        let (info, mut test_map, mut config, rt) = make_testing_variables();
        let t = Duration::from_millis(100);

        rt.block_on(async {
            // Test no available server cases.
            let (resp_tx, resp_rx) = oneshot::channel();
            let query = super::make_probe_query().unwrap();
            super::handle_query_cmd(
                info.net_id,
                query.clone(),
                t,
                resp_tx,
                &mut test_map,
                &mut config,
            )
            .await;
            assert_eq!(
                timeout(t, resp_rx).await.unwrap().unwrap(),
                Response::Error { error: QueryError::ServerNotReady }
            );

            let (resp_tx, resp_rx) = oneshot::channel();
            test_map.insert(info.net_id, (info.clone(), None));
            super::handle_query_cmd(
                info.net_id,
                query.clone(),
                t,
                resp_tx,
                &mut test_map,
                &mut config,
            )
            .await;
            assert_eq!(
                timeout(t, resp_rx).await.unwrap().unwrap(),
                Response::Error { error: QueryError::ServerNotReady }
            );

            // Test the connection broken case.
            test_map.clear();
            let (resp_tx, resp_rx) = oneshot::channel();
            let mut doh = super::make_connection_if_needed(&info, &mut test_map, &mut config)
                .unwrap()
                .unwrap();
            doh.status = ConnectionStatus::Fail;
            test_map.insert(info.net_id, (info.clone(), Some(doh)));
            super::handle_query_cmd(
                info.net_id,
                query.clone(),
                t,
                resp_tx,
                &mut test_map,
                &mut config,
            )
            .await;
            assert_eq!(
                timeout(t, resp_rx).await.unwrap().unwrap(),
                Response::Error { error: QueryError::BrokenServer }
            );
        });
    }

    extern "C" fn success_cb(
        net_id: uint32_t,
        success: bool,
        ip_addr: *const c_char,
        host: *const c_char,
    ) {
        assert!(success);
        unsafe {
            assert_validation_info(net_id, ip_addr, host);
        }
    }

    extern "C" fn fail_cb(
        net_id: uint32_t,
        success: bool,
        ip_addr: *const c_char,
        host: *const c_char,
    ) {
        assert!(!success);
        unsafe {
            assert_validation_info(net_id, ip_addr, host);
        }
    }

    // # Safety
    // `ip_addr`, `host` are null terminated strings
    unsafe fn assert_validation_info(
        net_id: uint32_t,
        ip_addr: *const c_char,
        host: *const c_char,
    ) {
        assert_eq!(net_id, TEST_NET_ID);
        let ip_addr = std::ffi::CStr::from_ptr(ip_addr).to_str().unwrap();
        let expected_addr: SocketAddr = LOOPBACK_ADDR.parse().unwrap();
        assert_eq!(ip_addr, expected_addr.ip().to_string());
        let host = std::ffi::CStr::from_ptr(host).to_str().unwrap();
        assert_eq!(host, "");
    }

    #[test]
    fn report_private_dns_validation() {
        let info = ServerInfo {
            net_id: TEST_NET_ID,
            url: Url::parse(LOCALHOST_URL).unwrap(),
            peer_addr: LOOPBACK_ADDR.parse().unwrap(),
            domain: None,
            sk_mark: 0,
            cert_path: None,
        };
        let rt = Arc::new(
            Builder::new_current_thread()
                .thread_name("test-runtime")
                .build()
                .expect("Failed to create testing tokio runtime"),
        );
        let default_panic = std::panic::take_hook();
        // Exit the test if the worker inside tokio runtime panicked.
        std::panic::set_hook(Box::new(move |info| {
            default_panic(info);
            std::process::exit(1);
        }));
        super::report_private_dns_validation(
            &info,
            &ConnectionStatus::Ready,
            rt.clone(),
            success_cb,
        );
        super::report_private_dns_validation(&info, &ConnectionStatus::Fail, rt.clone(), fail_cb);
        super::report_private_dns_validation(
            &info,
            &ConnectionStatus::Pending,
            rt.clone(),
            fail_cb,
        );
        super::report_private_dns_validation(&info, &ConnectionStatus::Idle, rt, fail_cb);
    }

    #[test]
    fn make_probe_query_and_request() {
        let probe_query = super::make_probe_query().unwrap();
        let url = Url::parse(LOCALHOST_URL).unwrap();
        let request = make_dns_request(&probe_query, &url).unwrap();
        // Verify H3 DNS request.
        assert_eq!(request.len(), H3_DNS_REQUEST_HEADER_SIZE);
        assert_eq!(request[0].name(), b":method");
        assert_eq!(request[0].value(), b"GET");
        assert_eq!(request[1].name(), b":scheme");
        assert_eq!(request[1].value(), b"https");
        assert_eq!(request[2].name(), b":authority");
        assert_eq!(request[2].value(), url.host_str().unwrap().as_bytes());
        assert_eq!(request[3].name(), b":path");
        let mut path = String::from(url.path());
        path.push_str("?dns=");
        path.push_str(&probe_query);
        assert_eq!(request[3].value(), path.as_bytes());
        assert_eq!(request[5].name(), b"accept");
        assert_eq!(request[5].value(), b"application/dns-message");

        // Verify DNS probe packet.
        let bytes = base64::decode_config(probe_query, base64::URL_SAFE_NO_PAD).unwrap();
        assert_eq!(bytes.len(), PROBE_QUERY_SIZE);
        // TODO: Parse the result to ensure it's a valid DNS packet.
    }

    #[test]
    fn create_quiche_config() {
        assert!(
            super::create_quiche_config(None).is_ok(),
            "quiche config without cert creating failed"
        );
        assert!(
            super::create_quiche_config(Some("data/local/tmp/")).is_ok(),
            "quiche config with cert creating failed"
        );
    }

    #[test]
    fn make_doh_udp_socket() {
        // Make a socket connecting to loopback with a test mark.
        let sk = super::make_doh_udp_socket(LOOPBACK_ADDR.parse().unwrap(), TEST_MARK).unwrap();
        // Check if the socket is connected to loopback.
        assert_eq!(
            sk.peer_addr().unwrap(),
            SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 1), DOH_PORT))
        );

        // Check if the socket mark is correct.
        let fd: RawFd = sk.as_raw_fd();

        let mut mark: u32 = 50;
        let mut size = std::mem::size_of::<u32>() as libc::socklen_t;
        unsafe {
            // Safety: It's fine since the fd belongs to this test.
            assert_eq!(
                libc::getsockopt(
                    fd,
                    libc::SOL_SOCKET,
                    libc::SO_MARK,
                    &mut mark as *mut _ as *mut libc::c_void,
                    &mut size as *mut libc::socklen_t,
                ),
                0
            );
        }
        assert_eq!(mark, TEST_MARK);

        // Check if the socket is non-blocking.
        unsafe {
            // Safety: It's fine since the fd belongs to this test.
            assert_eq!(libc::fcntl(fd, libc::F_GETFL, 0) & libc::O_NONBLOCK, libc::O_NONBLOCK);
        }
    }
}