Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit f82ce150 authored by Mike Yu's avatar Mike Yu Committed by Automerger Merge Worker
Browse files

Use pull model to get the statistics from DohFrontend am: c293f29f

Original change: https://android-review.googlesource.com/c/platform/packages/modules/DnsResolver/+/1906164

Change-Id: I5267167dd3a6b079c3ea3eab260c968ba6dbea78
parents a746bc32 c293f29f
Loading
Loading
Loading
Loading
+3 −3
Original line number Diff line number Diff line
@@ -29,8 +29,8 @@ struct DohFrontend;
struct Stats {
    /// The number of accumulated DoH queries that are received.
    uint32_t queries_received;
    /// The number of accumulated QUIC connections.
    uint32_t connections;
    /// The number of accumulated QUIC connections accepted.
    uint32_t connections_accepted;
};

extern "C" {
@@ -99,7 +99,7 @@ bool frontend_set_max_streams_bidi(DohFrontend *doh, uint64_t value);
bool frontend_block_sending(DohFrontend *doh, bool block);

/// Gets the statistics of the `DohFrontend` and writes the result to |out|.
void frontend_stats(const DohFrontend *doh, Stats *out);
bool frontend_stats(DohFrontend *doh, Stats *out);

/// Resets `queries_received` field of `Stats` owned by the `DohFrontend`.
bool frontend_stats_clear_queries(const DohFrontend *doh);
+11 −8
Original line number Diff line number Diff line
@@ -253,14 +253,13 @@ impl ClientMap {
        &mut self,
        hdr: &quiche::Header,
        addr: &SocketAddr,
    ) -> Result<(&mut Client, bool)> {
        let dcid = hdr.dcid.as_ref();
        let is_new_client = !self.clients.contains_key(dcid);
        let client = if is_new_client {
    ) -> Result<&mut Client> {
        let dcid = hdr.dcid.as_ref().to_vec();
        let client = if !self.clients.contains_key(&dcid) {
            ensure!(hdr.ty == quiche::Type::Initial, "Packet is not Initial");
            ensure!(quiche::version_is_supported(hdr.version), "Protocol version not supported");

            let scid = generate_conn_id(&self.conn_id_seed, dcid);
            let scid = generate_conn_id(&self.conn_id_seed, &dcid);
            let conn = quiche::accept(
                &quiche::ConnectionId::from_ref(&scid),
                None, /* odcid */
@@ -273,19 +272,23 @@ impl ClientMap {
            self.clients.insert(scid.clone(), client);
            self.clients.get_mut(&scid).unwrap()
        } else {
            self.clients.get_mut(dcid).unwrap()
            self.clients.get_mut(&dcid).unwrap()
        };

        Ok((client, is_new_client))
        Ok(client)
    }

    pub fn get_mut(&mut self, id: &[u8]) -> Option<&mut Client> {
        self.clients.get_mut(&id.to_vec())
    }

    pub fn get_mut_iter(&mut self) -> hash_map::IterMut<ConnectionID, Client> {
    pub fn iter_mut(&mut self) -> hash_map::IterMut<ConnectionID, Client> {
        self.clients.iter_mut()
    }

    pub fn len(&mut self) -> usize {
        self.clients.len()
    }
}

fn generate_conn_id(conn_id_seed: &hmac::Key, dcid: &[u8]) -> ConnectionID {
+85 −30
Original line number Diff line number Diff line
@@ -29,7 +29,7 @@ use std::sync::{Arc, Mutex};
use std::time::Duration;
use tokio::net::UdpSocket;
use tokio::runtime::{Builder, Runtime};
use tokio::sync::mpsc::channel;
use tokio::sync::{mpsc, oneshot};
use tokio::task::JoinHandle;

lazy_static! {
@@ -44,11 +44,19 @@ lazy_static! {
    );
}

/// Command used by worker_thread itself.
#[derive(Debug)]
enum Command {
enum InternalCommand {
    MaybeWrite { connection_id: ConnectionID },
}

/// Commands that DohFrontend to ask its worker_thread for.
#[derive(Debug)]
enum ControlCommand {
    Stats { resp: oneshot::Sender<Stats> },
    StatsClearQueries,
}

/// Frontend object.
#[derive(Debug)]
pub struct DohFrontend {
@@ -73,10 +81,11 @@ pub struct DohFrontend {
    // TODO: use channel to update worker_thread configuration.
    config: Arc<Mutex<Config>>,

    // Stores some statistic to check DohFrontend status.
    // It's shared with the worker thread.
    // TODO: use channel to retrieve the stats from worker_thread.
    stats: Arc<Mutex<Stats>>,
    // Caches the latest stats so that the stats remains after worker_thread stops.
    latest_stats: Stats,

    // It is wrapped as Option because the channel is not created in DohFrontend construction.
    command_tx: Option<mpsc::UnboundedSender<ControlCommand>>,
}

/// The parameters passed to the worker thread.
@@ -85,7 +94,7 @@ struct WorkerParams {
    backend_socket: std::net::UdpSocket,
    clients: ClientMap,
    config: Arc<Mutex<Config>>,
    stats: Arc<Mutex<Stats>>,
    command_rx: mpsc::UnboundedReceiver<ControlCommand>,
}

impl DohFrontend {
@@ -100,7 +109,8 @@ impl DohFrontend {
            private_key: String::new(),
            worker_thread: None,
            config: Arc::new(Mutex::new(Config::new())),
            stats: Arc::new(Mutex::new(Stats::new())),
            latest_stats: Stats::new(),
            command_tx: None,
        });
        debug!("DohFrontend created: {:?}", doh);
        Ok(doh)
@@ -123,6 +133,9 @@ impl DohFrontend {

    pub fn stop(&mut self) -> Result<()> {
        if let Some(worker_thread) = self.worker_thread.take() {
            // Update latest_stats before stopping worker_thread.
            let _ = self.request_stats();

            worker_thread.abort();
        }

@@ -165,16 +178,48 @@ impl DohFrontend {
        Ok(())
    }

    pub fn stats(&self) -> Stats {
        self.stats.lock().unwrap().clone()
    pub fn request_stats(&mut self) -> Result<Stats> {
        ensure!(
            self.command_tx.is_some(),
            "command_tx is None because worker thread not yet initialized"
        );
        let command_tx = self.command_tx.as_ref().unwrap();

        if command_tx.is_closed() {
            return Ok(self.latest_stats.clone());
        }

        let (resp_tx, resp_rx) = oneshot::channel();
        command_tx.send(ControlCommand::Stats { resp: resp_tx })?;

        match RUNTIME_STATIC
            .block_on(async { tokio::time::timeout(Duration::from_secs(1), resp_rx).await })
        {
            Ok(v) => match v {
                Ok(stats) => {
                    self.latest_stats = stats.clone();
                    Ok(stats)
                }
                Err(e) => bail!(e),
            },
            Err(e) => bail!(e),
        }
    }

    pub fn stats_clear_queries(&self) -> Result<()> {
        self.stats.lock().unwrap().queries_received = 0;
        Ok(())
        ensure!(
            self.command_tx.is_some(),
            "command_tx is None because worker thread not yet initialized"
        );
        return self
            .command_tx
            .as_ref()
            .unwrap()
            .send(ControlCommand::StatsClearQueries)
            .or_else(|e| bail!(e));
    }

    fn init_worker_thread_params(&self) -> Result<WorkerParams> {
    fn init_worker_thread_params(&mut self) -> Result<WorkerParams> {
        let bind_addr =
            if self.backend_socket_addr.ip().is_ipv4() { "0.0.0.0:0" } else { "[::]:0" };
        let backend_socket = std::net::UdpSocket::bind(bind_addr)?;
@@ -190,12 +235,15 @@ impl DohFrontend {
            self.config.clone(),
        )?)?;

        let (command_tx, command_rx) = mpsc::unbounded_channel::<ControlCommand>();
        self.command_tx = Some(command_tx);

        Ok(WorkerParams {
            frontend_socket,
            backend_socket,
            clients,
            config: self.config.clone(),
            stats: self.stats.clone(),
            command_rx,
        })
    }
}
@@ -204,18 +252,19 @@ async fn worker_thread(params: WorkerParams) -> Result<()> {
    let backend_socket = into_tokio_udp_socket(params.backend_socket)?;
    let frontend_socket = into_tokio_udp_socket(params.frontend_socket)?;
    let config = params.config;
    let stats = params.stats;
    let (event_tx, mut event_rx) = channel::<Command>(100);
    let (event_tx, mut event_rx) = mpsc::unbounded_channel::<InternalCommand>();
    let mut command_rx = params.command_rx;
    let mut clients = params.clients;
    let mut frontend_buf = [0; 65535];
    let mut backend_buf = [0; 16384];
    let mut delay_queries_buffer: Vec<Vec<u8>> = vec![];
    let mut queries_received = 0;

    debug!("frontend={:?}, backend={:?}", frontend_socket, backend_socket);

    loop {
        let timeout = clients
            .get_mut_iter()
            .iter_mut()
            .filter_map(|(_, c)| c.timeout())
            .min()
            .unwrap_or_else(|| Duration::from_millis(QUICHE_IDLE_TIMEOUT_MS));
@@ -223,12 +272,12 @@ async fn worker_thread(params: WorkerParams) -> Result<()> {
        tokio::select! {
            _ = tokio::time::sleep(timeout) => {
                debug!("timeout");
                for (_, client) in clients.get_mut_iter() {
                for (_, client) in clients.iter_mut() {
                    // If no timeout has occurred it does nothing.
                    client.on_timeout();

                    let connection_id = client.connection_id().clone();
                    event_tx.send(Command::MaybeWrite{connection_id}).await?;
                    event_tx.send(InternalCommand::MaybeWrite{connection_id})?;
                }
            }

@@ -247,12 +296,7 @@ async fn worker_thread(params: WorkerParams) -> Result<()> {
                debug!("Got QUIC packet: {:?}", hdr);

                let client = match clients.get_or_create(&hdr, &src) {
                    Ok((client, is_new_client)) => {
                        if is_new_client {
                            stats.lock().unwrap().connections += 1;
                        }
                        client
                    }
                    Ok(v) => v,
                    Err(e) => {
                        error!("Failed to get the client by the hdr {:?}: {}", hdr, e);
                        continue;
@@ -263,7 +307,7 @@ async fn worker_thread(params: WorkerParams) -> Result<()> {
                match client.handle_frontend_message(pkt_buf) {
                    Ok(v) if !v.is_empty() => {
                        delay_queries_buffer.push(v);
                        stats.lock().unwrap().queries_received += 1;
                        queries_received += 1;
                    }
                    Err(e) => {
                        error!("Failed to process QUIC packet: {}", e);
@@ -280,7 +324,7 @@ async fn worker_thread(params: WorkerParams) -> Result<()> {
                }

                let connection_id = client.connection_id().clone();
                event_tx.send(Command::MaybeWrite{connection_id}).await?;
                event_tx.send(InternalCommand::MaybeWrite{connection_id})?;
            }

            Ok((len, src)) = backend_socket.recv_from(&mut backend_buf) => {
@@ -291,13 +335,13 @@ async fn worker_thread(params: WorkerParams) -> Result<()> {
                }

                let query_id = [backend_buf[0], backend_buf[1]];
                for (_, client) in clients.get_mut_iter() {
                for (_, client) in clients.iter_mut() {
                    if client.is_waiting_for_query(&query_id) {
                        if let Err(e) = client.handle_backend_message(&backend_buf[..len]) {
                            error!("Failed to handle message from backend: {}", e);
                        }
                        let connection_id = client.connection_id().clone();
                        event_tx.send(Command::MaybeWrite{connection_id}).await?;
                        event_tx.send(InternalCommand::MaybeWrite{connection_id})?;

                        // It's a bug if more than one client is waiting for this query.
                        break;
@@ -307,7 +351,7 @@ async fn worker_thread(params: WorkerParams) -> Result<()> {

            Some(command) = event_rx.recv(), if !config.lock().unwrap().block_sending => {
                match command {
                    Command::MaybeWrite {connection_id} => {
                    InternalCommand::MaybeWrite {connection_id} => {
                        if let Some(client) = clients.get_mut(&connection_id) {
                            while let Ok(v) = client.flush_egress() {
                                let addr = client.addr();
@@ -321,6 +365,17 @@ async fn worker_thread(params: WorkerParams) -> Result<()> {
                    }
                }
            }
            Some(command) = command_rx.recv() => {
                match command {
                    ControlCommand::Stats {resp} => {
                        let stats = Stats {queries_received, connections_accepted: clients.len() as u32};
                        if let Err(e) = resp.send(stats) {
                            error!("Failed to send ControlCommand::Stats response: {:?}", e);
                        }
                    }
                    ControlCommand::StatsClearQueries => queries_received = 0,
                }
            }
        }
    }
}
+9 −5
Original line number Diff line number Diff line
@@ -155,10 +155,14 @@ pub extern "C" fn frontend_block_sending(doh: &mut DohFrontend, block: bool) ->

/// Gets the statistics of the `DohFrontend` and writes the result to |out|.
#[no_mangle]
pub extern "C" fn frontend_stats(doh: &DohFrontend, out: &mut Stats) {
    let stats = doh.stats();
pub extern "C" fn frontend_stats(doh: &mut DohFrontend, out: &mut Stats) -> bool {
    doh.request_stats()
        .map(|stats| {
            out.queries_received = stats.queries_received;
    out.connections = stats.connections;
            out.connections_accepted = stats.connections_accepted;
        })
        .or_else(logging_and_return_err)
        .is_ok()
}

/// Resets `queries_received` field of `Stats` owned by the `DohFrontend`.
@@ -181,6 +185,6 @@ fn to_socket_addr(addr: &str, port: &str) -> Result<SocketAddr> {
}

fn logging_and_return_err<T, U: std::fmt::Debug>(e: U) -> Result<T> {
    warn!("{:?}", e);
    warn!("logging_and_return_err: {:?}", e);
    bail!("{:?}", e)
}
+2 −2
Original line number Diff line number Diff line
@@ -21,8 +21,8 @@
pub struct Stats {
    /// The number of accumulated DoH queries that are received.
    pub queries_received: u32,
    /// The number of accumulated QUIC connections.
    pub connections: u32,
    /// The number of accumulated QUIC connections accepted.
    pub connections_accepted: u32,
}

impl Stats {
Loading