Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit a7691414 authored by Luke Huang's avatar Luke Huang Committed by Automerger Merge Worker
Browse files

Tags the socket for DoH3 am: 10eed9da

Original change: https://android-review.googlesource.com/c/platform/packages/modules/DnsResolver/+/1839853

Change-Id: I0140647f9c168e510bfb629b730ede541278ebfe
parents e56f59b9 10eed9da
Loading
Loading
Loading
Loading
+2 −1
Original line number Diff line number Diff line
@@ -428,7 +428,8 @@ void PrivateDnsConfiguration::initDohLocked() {
            [](uint32_t net_id, bool success, const char* ip_addr, const char* host) {
                android::net::PrivateDnsConfiguration::getInstance().onDohStatusUpdate(
                        net_id, success, ip_addr, host);
            });
            },
            [](int32_t sock) { resolv_tag_socket(sock, AID_DNS, NET_CONTEXT_INVALID_PID); });
}

int PrivateDnsConfiguration::setDoh(int32_t netId, uint32_t mark,
+5 −2
Original line number Diff line number Diff line
@@ -20,7 +20,7 @@

#pragma once

/* Generated with cbindgen:0.17.0 */
/* Generated with cbindgen:0.20.0 */

#include <stdint.h>
#include <sys/types.h>
@@ -55,6 +55,8 @@ struct DohDispatcher;
using ValidationCallback = void (*)(uint32_t net_id, bool success, const char* ip_addr,
                                    const char* host);

using TagSocketCallback = void (*)(int32_t sock);

extern "C" {

/// Performs static initialization for android logger.
@@ -65,7 +67,8 @@ void doh_set_log_level(uint32_t level);

/// Performs the initialization for the DoH engine.
/// Creates and returns a DoH engine instance.
DohDispatcher* doh_dispatcher_new(ValidationCallback ptr);
DohDispatcher* doh_dispatcher_new(ValidationCallback validation_fn,
                                  TagSocketCallback tag_socket_fn);

/// Deletes a DoH engine created by doh_dispatcher_new().
/// # Safety
+56 −18
Original line number Diff line number Diff line
@@ -82,6 +82,7 @@ type QueryResponder = oneshot::Sender<Response>;
type DnsRequest = Vec<quiche::h3::Header>;
type ValidationCallback =
    extern "C" fn(net_id: uint32_t, success: bool, ip_addr: *const c_char, host: *const c_char);
type TagSocketCallback = extern "C" fn(sock: int32_t);

#[derive(Eq, PartialEq, Debug)]
enum QueryError {
@@ -190,7 +191,10 @@ pub struct DohDispatcher {

// DoH dispatcher
impl DohDispatcher {
    fn new(validation_fn: ValidationCallback) -> Result<Box<DohDispatcher>> {
    fn new(
        validation_fn: ValidationCallback,
        tag_socket_fn: TagSocketCallback,
    ) -> Result<Box<DohDispatcher>> {
        let (cmd_sender, cmd_receiver) = mpsc::channel::<DohCommand>(MAX_BUFFERED_CMD_SIZE);
        let runtime = Arc::new(
            Builder::new_multi_thread()
@@ -200,7 +204,8 @@ impl DohDispatcher {
                .build()
                .expect("Failed to create tokio runtime"),
        );
        let join_handle = runtime.spawn(doh_handler(cmd_receiver, runtime.clone(), validation_fn));
        let join_handle =
            runtime.spawn(doh_handler(cmd_receiver, runtime.clone(), validation_fn, tag_socket_fn));
        Ok(Box::new(DohDispatcher { cmd_sender, join_handle, runtime }))
    }

@@ -224,12 +229,14 @@ struct DohConnection {
    state: ConnectionState,
    pending_queries: Vec<(DnsRequest, QueryResponder, Instant)>,
    cached_session: Option<Vec<u8>>,
    tag_socket_fn: TagSocketCallback,
}

impl DohConnection {
    fn new(
        info: &ServerInfo,
        shared_config: Arc<Mutex<QuicheConfigCache>>,
        tag_socket_fn: TagSocketCallback,
    ) -> Result<DohConnection> {
        let mut scid = [0; quiche::MAX_CONN_ID_LEN];
        ring::rand::SystemRandom::new().fill(&mut scid).context("failed to generate scid")?;
@@ -240,6 +247,7 @@ impl DohConnection {
            state: ConnectionState::Idle,
            pending_queries: Vec::new(),
            cached_session: None,
            tag_socket_fn,
        })
    }

@@ -247,6 +255,7 @@ impl DohConnection {
        self.state = match self.state {
            ConnectionState::Idle => {
                let udp_sk_std = make_doh_udp_socket(self.info.peer_addr, self.info.sk_mark)?;
                (self.tag_socket_fn)(udp_sk_std.as_raw_fd());
                let udp_sk = UdpSocket::from_std(udp_sk_std)?;
                let connid = quiche::ConnectionId::from_ref(&self.scid);
                let mut cache = self.shared_config.lock().unwrap();
@@ -775,6 +784,7 @@ fn make_connection_if_needed(
    info: &ServerInfo,
    doh_conn_map: &mut HashMap<u32, (ServerInfo, Option<DohConnection>)>,
    shared_config: Arc<Mutex<QuicheConfigCache>>,
    tag_socket_fn: TagSocketCallback,
) -> Result<Option<DohConnection>> {
    // Check if connection exists.
    match doh_conn_map.get(&info.net_id) {
@@ -790,7 +800,7 @@ fn make_connection_if_needed(
        // TODO: change the inner connection instead of removing?
        _ => doh_conn_map.remove(&info.net_id),
    };
    let doh = DohConnection::new(info, shared_config)?;
    let doh = DohConnection::new(info, shared_config, tag_socket_fn)?;
    doh_conn_map.insert(info.net_id, (info.clone(), None));
    Ok(Some(doh))
}
@@ -861,6 +871,7 @@ async fn doh_handler(
    mut cmd_rx: CmdReceiver,
    runtime: Arc<Runtime>,
    validation_fn: ValidationCallback,
    tag_socket_fn: TagSocketCallback,
) -> Result<()> {
    info!("doh_dispatcher entry");
    let config_cache = Arc::new(Mutex::new(QuicheConfigCache { cert_path: None, config: None }));
@@ -888,7 +899,7 @@ async fn doh_handler(
                trace!("recv {:?}", cmd);
                match cmd {
                    DohCommand::Probe { info, timeout: t } => {
                        match make_connection_if_needed(&info, &mut doh_conn_map, config_cache.clone()) {
                        match make_connection_if_needed(&info, &mut doh_conn_map, config_cache.clone(), tag_socket_fn) {
                            Ok(Some(doh)) => {
                                // Create a new async task associated to the DoH connection.
                                probe_futures.push(probe_task(info, doh, t));
@@ -1051,8 +1062,11 @@ pub extern "C" fn doh_set_log_level(level: u32) {
/// Performs the initialization for the DoH engine.
/// Creates and returns a DoH engine instance.
#[no_mangle]
pub extern "C" fn doh_dispatcher_new(ptr: ValidationCallback) -> *mut DohDispatcher {
    match DohDispatcher::new(ptr) {
pub extern "C" fn doh_dispatcher_new(
    validation_fn: ValidationCallback,
    tag_socket_fn: TagSocketCallback,
) -> *mut DohDispatcher {
    match DohDispatcher::new(validation_fn, tag_socket_fn) {
        Ok(c) => Box::into_raw(c),
        Err(e) => {
            error!("doh_dispatcher_new: failed: {:?}", e);
@@ -1261,12 +1275,21 @@ mod tests {
        (info, test_map, config_cache, rt)
    }

    extern "C" fn tag_socket_cb(sock: int32_t) {
        assert!(sock >= 0);
    }

    #[test]
    fn make_connection_if_needed() {
        let (info, mut test_map, config, rt) = make_testing_variables();
        rt.block_on(async {
            // Expect to make a new connection.
            let mut doh = super::make_connection_if_needed(&info, &mut test_map, config.clone())
            let mut doh = super::make_connection_if_needed(
                &info,
                &mut test_map,
                config.clone(),
                tag_socket_cb,
            )
            .unwrap()
            .unwrap();
            assert_eq!(doh.info.net_id, info.net_id);
@@ -1274,7 +1297,12 @@ mod tests {
            doh.state = ConnectionState::Error;
            test_map.insert(info.net_id, (info.clone(), Some(doh)));
            // Expect that we will get a connection with fail status that we added to the map before.
            let mut doh = super::make_connection_if_needed(&info, &mut test_map, config.clone())
            let mut doh = super::make_connection_if_needed(
                &info,
                &mut test_map,
                config.clone(),
                tag_socket_cb,
            )
            .unwrap()
            .unwrap();
            assert_eq!(doh.info.net_id, info.net_id);
@@ -1282,7 +1310,12 @@ mod tests {
            doh.state = make_dummy_connected_state();
            test_map.insert(info.net_id, (info.clone(), Some(doh)));
            // Expect that we will get None because the map contains a connection with ready status.
            assert!(super::make_connection_if_needed(&info, &mut test_map, config.clone())
            assert!(super::make_connection_if_needed(
                &info,
                &mut test_map,
                config.clone(),
                tag_socket_cb
            )
            .unwrap()
            .is_none());
        });
@@ -1328,7 +1361,12 @@ mod tests {
            // Test the connection broken case.
            test_map.clear();
            let (resp_tx, resp_rx) = oneshot::channel();
            let mut doh = super::make_connection_if_needed(&info, &mut test_map, config.clone())
            let mut doh = super::make_connection_if_needed(
                &info,
                &mut test_map,
                config.clone(),
                tag_socket_cb,
            )
            .unwrap()
            .unwrap();
            doh.state = ConnectionState::Error;
+5 −2
Original line number Diff line number Diff line
@@ -40,14 +40,17 @@ TEST(DoHFFITest, SmokeTest) {
    // To ensure that we have a real network.
    ASSERT_GE(dnsNetId, MINIMAL_NET_ID) << "No available networks";

    auto callback = [](uint32_t netId, bool success, const char* ip_addr, const char* host) {
    auto validation_cb = [](uint32_t netId, bool success, const char* ip_addr, const char* host) {
        EXPECT_EQ(netId, dnsNetId);
        EXPECT_TRUE(success);
        EXPECT_STREQ(ip_addr, GOOGLE_SERVER_IP);
        EXPECT_STREQ(host, "");
        cv.notify_one();
    };
    DohDispatcher* doh = doh_dispatcher_new(callback);

    auto tag_socket_cb = [](int32_t sock) { EXPECT_GE(sock, 0); };

    DohDispatcher* doh = doh_dispatcher_new(validation_cb, tag_socket_cb);
    EXPECT_TRUE(doh != nullptr);

    // TODO: Use a local server instead of dns.google.