Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 40357cc2 authored by Matthew Maurer's avatar Matthew Maurer Committed by Automerger Merge Worker
Browse files

DoH: Modularize main event loop am: ed78fdaf

Original change: https://android-review.googlesource.com/c/platform/packages/modules/DnsResolver/+/1870956

Change-Id: I5dacc7971113538954d316f61c6ac9b58c942ce4
parents a70ab5e8 ed78fdaf
Loading
Loading
Loading
Loading
+1 −0
Original line number Diff line number Diff line
@@ -334,6 +334,7 @@ doh_rust_deps = [
    "liblibc",
    "liblog_rust",
    "libring",
    "libthiserror",
    "libtokio",
    "liburl",
]
+7 −6
Original line number Diff line number Diff line
@@ -27,7 +27,8 @@
use quiche::{h3, Result};
use std::collections::HashMap;
use std::ops::DerefMut;
use std::sync::{Arc, Mutex, RwLock, Weak};
use std::sync::{Arc, RwLock, Weak};
use tokio::sync::Mutex;

type WeakConfig = Weak<Mutex<quiche::Config>>;

@@ -80,8 +81,8 @@ impl Config {

    /// Take the underlying config, usable as `&mut quiche::Config` for use
    /// with `quiche::connect`.
    pub fn take(&mut self) -> impl DerefMut<Target = quiche::Config> + '_ {
        self.0.lock().unwrap()
    pub async fn take(&mut self) -> impl DerefMut<Target = quiche::Config> + '_ {
        self.0.lock().await
    }
}

@@ -229,11 +230,11 @@ fn lifetimes() {
    assert_eq!(cache.state.read().unwrap().path_to_config.len(), 2);
}

#[test]
fn quiche_connect() {
#[tokio::test]
async fn quiche_connect() {
    use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4};
    let mut config = Config::from_cert_path(None).unwrap();
    let socket_addr = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 1), 42));
    let conn_id = quiche::ConnectionId::from_ref(&[]);
    quiche::connect(None, &conn_id, socket_addr, &mut config.take()).unwrap();
    quiche::connect(None, &conn_id, socket_addr, config.take().await.deref_mut()).unwrap();
}
+369 −0
Original line number Diff line number Diff line
/*
* Copyright (C) 2021 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*      http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

//! Defines a backing task to keep a HTTP/3 connection running

use crate::boot_time;
use crate::boot_time::BootTime;
use log::warn;
use quiche::h3;
use std::collections::HashMap;
use std::default::Default;
use std::future;
use std::io;
use std::pin::Pin;
use thiserror::Error;
use tokio::net::UdpSocket;
use tokio::select;
use tokio::sync::{mpsc, oneshot};

/// Reasons the connection driver in this file can stop running.
///
/// Most variants are built through the `#[from]` conversions, so `?` on the underlying
/// operation produces them directly.
#[derive(Error, Debug)]
pub enum Error {
    /// An I/O operation failed, e.g. sending a datagram on the UDP socket.
    #[error("network IO error: {0}")]
    Network(#[from] io::Error),
    /// quiche reported a QUIC-level error.
    #[error("QUIC error: {0}")]
    Quic(#[from] quiche::Error),
    /// quiche reported an HTTP/3-level error.
    #[error("HTTP/3 error: {0}")]
    H3(#[from] h3::Error),
    /// Sending a `Stream` over an `mpsc` channel failed.
    // NOTE(review): nothing visible in this file produces this variant - confirm it is needed.
    #[error("Response delivery error: {0}")]
    StreamSend(#[from] mpsc::error::SendError<Stream>),
    /// The QUIC connection reported itself as closed.
    #[error("Connection closed")]
    Closed,
}

/// Result type used throughout the connection driver, with this module's `Error` as the
/// error type.
pub type Result<T> = std::result::Result<T, Error>;

#[derive(Debug)]
/// HTTP/3 request to be sent on the connection, together with the channel on which the
/// response is to be delivered.
pub struct Request {
    /// Request headers
    pub headers: Vec<h3::Header>,
    /// Expiry time for the request, relative to `CLOCK_BOOTTIME`. If this time has already
    /// passed when the driver gets to the request, the request is dropped without being sent.
    /// `None` means the request never expires.
    pub expiry: Option<BootTime>,
    /// Channel to send the response to. The driver ignores a failed send, treating it as the
    /// requestor having gone away.
    pub response_tx: oneshot::Sender<Stream>,
}

#[derive(Debug)]
/// HTTP/3 response, accumulated by the driver as events arrive for a stream
pub struct Stream {
    /// Response headers
    pub headers: Vec<h3::Header>,
    /// Response body, appended to as body data is read from the stream
    pub data: Vec<u8>,
    /// Error code if stream was reset. The only code in this file that would set it is
    /// commented out, so it currently stays `None`.
    pub error: Option<u64>,
}

impl Stream {
    /// Starts a response from its `headers`, with an empty body and no error code recorded.
    fn new(headers: Vec<h3::Header>) -> Self {
        let data = Vec::new();
        let error = None;
        Self { headers, data, error }
    }
}

/// Size in bytes of the driver's scratch buffer, which is used both to receive datagrams
/// from the socket and to stage outgoing packets produced by quiche.
const MAX_UDP_PACKET_SIZE: usize = 65536;

/// State needed to run the QUIC connection before HTTP/3 has been brought up.
/// Once the QUIC connection is established this is wrapped in an `H3Driver`.
struct Driver {
    // Incoming requests from the owner of the connection. Not read by `Driver` itself;
    // it is consumed once the `H3Driver` takes over.
    request_rx: mpsc::Receiver<Request>,
    // The QUIC connection being driven.
    quiche_conn: Pin<Box<quiche::Connection>>,
    // UDP socket used to exchange packets with the peer.
    socket: UdpSocket,
    // This buffer is large, boxing it will keep it
    // off the stack and prevent it being copied during
    // moves of the driver.
    buffer: Box<[u8; MAX_UDP_PACKET_SIZE]>,
}

/// State needed to run the HTTP/3 layer on top of an established QUIC connection.
struct H3Driver {
    // The underlying QUIC driver (socket, QUIC connection, request channel and buffer).
    driver: Driver,
    // h3_conn sometimes can't "fit" a request in its available windows.
    // This value holds a peeked request in that case, waiting for
    // transmission to become possible.
    buffered_request: Option<Request>,
    // We can't check if a receiver is dead without potentially receiving a message, and if we poll
    // on a dead receiver in a select! it will immediately return None. As a result, we need this
    // to gate whether or not to include .recv() in our select!
    closing: bool,
    // The HTTP/3 connection layered over `driver.quiche_conn`.
    h3_conn: h3::Connection,
    // Requests that have been sent and are awaiting a response, keyed by stream ID.
    requests: HashMap<u64, Request>,
    // Partially received responses, keyed by stream ID.
    streams: HashMap<u64, Stream>,
}

/// Sleeps for `timeout` if one is given; with `None` the returned future never completes.
///
/// This lets a `select!` always include a timer branch, even when there is no deadline.
async fn optional_timeout(timeout: Option<boot_time::Duration>) {
    if let Some(duration) = timeout {
        boot_time::sleep(duration).await
    } else {
        future::pending().await
    }
}

/// Builds the event loop for a HTTP/3 connection as a future.
///
/// Nothing happens until the future is polled. It keeps servicing the connection until the
/// connection terminates, and the error it resolves to says why it terminated.
pub async fn drive(
    request_rx: mpsc::Receiver<Request>,
    quiche_conn: Pin<Box<quiche::Connection>>,
    socket: UdpSocket,
) -> Result<()> {
    let driver = Driver::new(request_rx, quiche_conn, socket);
    driver.drive().await
}

impl Driver {
    /// Bundles the request channel, QUIC connection and socket with a freshly allocated,
    /// zeroed packet buffer.
    fn new(
        request_rx: mpsc::Receiver<Request>,
        quiche_conn: Pin<Box<quiche::Connection>>,
        socket: UdpSocket,
    ) -> Self {
        let buffer = Box::new([0; MAX_UDP_PACKET_SIZE]);
        Self { request_rx, quiche_conn, socket, buffer }
    }

    /// Runs the connection until an error occurs. There is no successful exit from the loop.
    async fn drive(mut self) -> Result<()> {
        // Send whatever quiche already has queued so the handshake can start
        self.flush_tx().await?;
        loop {
            self = self.drive_once().await?
        }
    }

    /// Returns `Error::Closed` once the QUIC connection reports itself closed.
    fn handle_closed(&self) -> Result<()> {
        if !self.quiche_conn.is_closed() {
            return Ok(());
        }
        Err(Error::Closed)
    }

    /// Handles a single event (timer expiry or incoming packet) and then flushes outgoing
    /// packets. Once the QUIC connection is established, control is handed to an `H3Driver`.
    async fn drive_once(mut self) -> Result<Self> {
        let quiche_timer = optional_timeout(self.quiche_conn.timeout());
        select! {
            // The quiche timer expired: let quiche run its timeout handling
            _ = quiche_timer => self.quiche_conn.on_timeout(),
            // The peer sent us packets: feed them to quiche
            Ok((len, from)) = self.socket.recv_from(self.buffer.as_mut()) => {
                let recv_info = quiche::RecvInfo { from };
                self.quiche_conn.recv(&mut self.buffer[..len], recv_info)?;
            }
        };
        // Either event may have left quiche with packets to send to the peer
        self.flush_tx().await?;

        if !self.quiche_conn.is_established() {
            // Still handshaking: stop if the connection has closed, otherwise keep going
            self.handle_closed()?;
            return Ok(self);
        }

        // QUIC is up but HTTP/3 is not yet; bring it up and let the H3 driver take over
        let h3_config = h3::Config::new()?;
        let h3_conn = h3::Connection::with_transport(&mut self.quiche_conn, &h3_config)?;
        H3Driver::new(self, h3_conn).drive().await
    }

    /// Sends every packet quiche currently has queued, stopping when quiche reports `Done`.
    async fn flush_tx(&mut self) -> Result<()> {
        let send_buf = self.buffer.as_mut();
        loop {
            let (valid_len, send_info) = match self.quiche_conn.send(send_buf) {
                Ok(packet) => packet,
                Err(quiche::Error::Done) => return Ok(()),
                Err(e) => return Err(e.into()),
            };
            self.socket.send_to(&send_buf[..valid_len], send_info.to).await?;
        }
    }
}

impl H3Driver {
    /// Wraps `driver` and its HTTP/3 connection, starting with no outstanding requests,
    /// no partial responses, no buffered request, and not closing.
    fn new(driver: Driver, h3_conn: h3::Connection) -> Self {
        Self {
            driver,
            h3_conn,
            closing: false,
            requests: HashMap::new(),
            streams: HashMap::new(),
            buffered_request: None,
        }
    }

    /// Runs `drive_once` repeatedly. The loop has no exit other than `?`, so this only
    /// returns by propagating an error.
    async fn drive(mut self) -> Result<Driver> {
        loop {
            self.drive_once().await?;
        }
    }

    /// Handles a single event: a new request, a quiche timer expiry, or incoming packets.
    /// Afterwards it flushes outgoing packets, processes pending HTTP/3 events, and returns
    /// `Error::Closed` if the QUIC connection has closed.
    async fn drive_once(&mut self) -> Result<()> {
        // We can't call self.driver.drive_once at the same time as
        // self.driver.request_rx.recv() due to ownership
        // Note that the timer is created from quiche's timeout before any buffered request
        // is retried below.
        let timer = optional_timeout(self.driver.quiche_conn.timeout());
        // If we've buffered a request (due to the connection being full)
        // try to resend that first
        if let Some(request) = self.buffered_request.take() {
            // If the connection is still full, handle_request puts the request back into
            // buffered_request, which disables the recv() branch below.
            self.handle_request(request)?;
        }
        select! {
            // Only attempt to enqueue new requests if we have no buffered request and aren't
            // closing
            msg = self.driver.request_rx.recv(), if !self.closing && self.buffered_request.is_none() => match msg {
                Some(request) => self.handle_request(request)?,
                // All senders are gone, so shut the connection down
                None => self.shutdown(true, b"DONE").await?,
            },
            // If a quiche timer would fire, call their callback
            _ = timer => self.driver.quiche_conn.on_timeout(),
            // If we got packets from our peer, pass them to quiche
            Ok((size, from)) = self.driver.socket.recv_from(self.driver.buffer.as_mut()) => {
                self.driver.quiche_conn.recv(&mut self.driver.buffer[..size], quiche::RecvInfo { from })?;
            }
        };

        // Any of the actions in the select could require us to send packets to the peer
        self.driver.flush_tx().await?;

        // Process any incoming HTTP/3 events
        self.flush_h3().await?;

        // If the connection has closed, tear down
        self.driver.handle_closed()
    }

    /// Sends `request` on a new HTTP/3 stream and records it in `requests` under the
    /// stream ID it was assigned.
    ///
    /// An already expired request is dropped without being sent. If the connection cannot
    /// accept the request right now, it is stored in `buffered_request` for a later retry.
    /// Any other send failure is returned as an error.
    fn handle_request(&mut self, request: Request) -> Result<()> {
        // If the request has already timed out, don't issue it to the server.
        if let Some(expiry) = request.expiry {
            if BootTime::now() > expiry {
                return Ok(());
            }
        }
        let stream_id =
            // If h3_conn says the stream is blocked, this error is recoverable just by trying
            // again once the stream has made progress. Buffer the request for a later retry.
            // NOTE(review): the trailing `true` is presumably quiche's `fin` flag, meaning
            // no request body follows - confirm against the quiche API.
            match self.h3_conn.send_request(&mut self.driver.quiche_conn, &request.headers, true) {
                Err(h3::Error::StreamBlocked) | Err(h3::Error::TransportError(quiche::Error::StreamLimit)) => {
                    // We only call handle_request on a value that has just come out of
                    // buffered_request, or when buffered_request is empty. This assert just
                    // validates that we don't break that assumption later, as it could result in
                    // requests being dropped on the floor under high load.
                    assert!(self.buffered_request.is_none());
                    self.buffered_request = Some(request);
                    return Ok(())
                }
                result => result?,
            };
        self.requests.insert(stream_id, request);
        Ok(())
    }

    /// Appends the body data currently available on `stream_id` to the tracked response
    /// for that stream, reading in chunks of up to `STREAM_READ_CHUNK` bytes until quiche
    /// reports `Done`. Logs a warning and does nothing if the stream is not tracked.
    async fn recv_body(&mut self, stream_id: u64) -> Result<()> {
        const STREAM_READ_CHUNK: usize = 4096;
        if let Some(stream) = self.streams.get_mut(&stream_id) {
            loop {
                // Grow the body by one chunk so quiche can write directly into it, then
                // shrink it back to the number of bytes actually received.
                let base_len = stream.data.len();
                stream.data.resize(base_len + STREAM_READ_CHUNK, 0);
                match self.h3_conn.recv_body(
                    &mut self.driver.quiche_conn,
                    stream_id,
                    &mut stream.data[base_len..],
                ) {
                    Err(h3::Error::Done) => {
                        stream.data.truncate(base_len);
                        return Ok(());
                    }
                    Err(e) => {
                        stream.data.truncate(base_len);
                        return Err(e.into());
                    }
                    Ok(recvd) => stream.data.truncate(base_len + recvd),
                }
            }
        } else {
            warn!("Received body for untracked stream ID {}", stream_id);
        }
        Ok(())
    }

    /// Reads pending datagrams into the driver's scratch buffer, discarding their contents,
    /// until quiche reports `Done`. `_flow_id` is unused.
    fn discard_datagram(&mut self, _flow_id: u64) -> Result<()> {
        loop {
            match self.h3_conn.recv_dgram(&mut self.driver.quiche_conn, self.driver.buffer.as_mut())
            {
                Err(h3::Error::Done) => return Ok(()),
                Err(e) => return Err(e.into()),
                _ => (),
            }
        }
    }

    /// Polls the HTTP/3 connection and processes each event until quiche reports `Done`.
    async fn flush_h3(&mut self) -> Result<()> {
        loop {
            match self.h3_conn.poll(&mut self.driver.quiche_conn) {
                Err(h3::Error::Done) => return Ok(()),
                Err(e) => return Err(e.into()),
                Ok((stream_id, event)) => self.process_h3_event(stream_id, event).await?,
            }
        }
    }

    /// Handles one HTTP/3 event for `stream_id`.
    ///
    /// Headers start a new tracked response, data is appended to it, and a finished stream
    /// (or headers without a body) delivers the response to the requestor. Datagrams are
    /// discarded, and a GOAWAY from the server shuts the connection down.
    async fn process_h3_event(&mut self, stream_id: u64, event: h3::Event) -> Result<()> {
        // This only warns; the event is still processed below.
        if !self.requests.contains_key(&stream_id) {
            warn!("Received event {:?} for stream_id {} without a request.", event, stream_id);
        }
        match event {
            h3::Event::Headers { list, has_body } => {
                let stream = Stream::new(list);
                if self.streams.insert(stream_id, stream).is_some() {
                    warn!("Re-using stream ID {} before it was completed.", stream_id)
                }
                // With no body to wait for, the response is already complete
                if !has_body {
                    self.respond(stream_id);
                }
            }
            h3::Event::Data => {
                self.recv_body(stream_id).await?;
            }
            h3::Event::Finished => self.respond(stream_id),
            // This clause is for quiche 0.10.x, we're still on 0.9.x
            //h3::Event::Reset(e) => {
            //    self.streams.get_mut(&stream_id).map(|stream| stream.error = Some(e));
            //    self.respond(stream_id);
            //}
            h3::Event::Datagram => {
                warn!("Unexpected Datagram received");
                // We don't care if something went wrong with the datagram, we didn't
                // want it anyways.
                let _ = self.discard_datagram(stream_id);
            }
            h3::Event::GoAway => self.shutdown(false, b"SERVER GOAWAY").await?,
        }
        Ok(())
    }

    /// Stops accepting requests and closes the connection.
    ///
    /// The request channel is closed and drained, dropping any requests still queued in it.
    /// If `send_goaway` is set, an HTTP/3 GOAWAY is sent first. The QUIC connection is then
    /// closed with `msg` as the reason; a failure to close it is only logged.
    async fn shutdown(&mut self, send_goaway: bool, msg: &[u8]) -> Result<()> {
        self.driver.request_rx.close();
        // Drain requests that were already queued before the channel was closed
        while self.driver.request_rx.recv().await.is_some() {}
        // From here on drive_once must not poll request_rx (see the `closing` field)
        self.closing = true;
        if send_goaway {
            self.h3_conn.send_goaway(&mut self.driver.quiche_conn, 0)?;
        }
        if self.driver.quiche_conn.close(true, 0, msg).is_err() {
            warn!("Trying to close already closed QUIC connection");
        }
        Ok(())
    }

    /// Removes both the accumulated response and the pending request for `stream_id` and,
    /// if both exist, delivers the response to the requestor. Logs a warning if either is
    /// missing.
    fn respond(&mut self, stream_id: u64) {
        match (self.streams.remove(&stream_id), self.requests.remove(&stream_id)) {
            (Some(stream), Some(request)) => {
                // We don't care about the error, because it means the requestor has left.
                let _ = request.response_tx.send(stream);
            }
            (None, _) => warn!("Tried to deliver untracked stream {}", stream_id),
            (_, None) => warn!("Tried to deliver stream {} to untracked requestor", stream_id),
        }
    }
}

doh/connection/mod.rs

0 → 100644
+150 −0
Original line number Diff line number Diff line
/*
* Copyright (C) 2021 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*      http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

//! Module providing an async abstraction around a quiche HTTP/3 connection

use crate::boot_time::BootTime;
use crate::network::SocketTagger;
use log::error;
use quiche::h3;
use std::future::Future;
use std::io;
use std::net::SocketAddr;
use thiserror::Error;
use tokio::net::UdpSocket;
use tokio::sync::mpsc;
use tokio::sync::oneshot;
use tokio::task;

mod driver;

pub use driver::Stream;
use driver::{drive, Request};

/// Quiche HTTP/3 connection
///
/// This is a handle to a background task which performs the actual IO; requests are
/// passed to that task over a channel.
pub struct Connection {
    // Sending half of the channel used to submit requests to the background driver task.
    request_tx: mpsc::Sender<Request>,
}

/// Generates a random source connection ID of the maximum length quiche allows.
///
/// Panics if the system random number generator fails.
fn new_scid() -> [u8; quiche::MAX_CONN_ID_LEN] {
    use ring::rand::{SecureRandom, SystemRandom};
    let rng = SystemRandom::new();
    let mut id_bytes = [0; quiche::MAX_CONN_ID_LEN];
    rng.fill(&mut id_bytes).unwrap();
    id_bytes
}

/// Sets the `SO_MARK` socket option on `socket` to `socket_mark`.
///
/// Returns the OS error for the `setsockopt` call if it fails.
fn mark_socket(socket: &std::net::UdpSocket, socket_mark: u32) -> io::Result<()> {
    use std::os::unix::io::AsRawFd;
    let mark_ptr = &socket_mark as *const _ as *const libc::c_void;
    let mark_len = std::mem::size_of::<u32>() as libc::socklen_t;
    // SAFETY: libc::setsockopt is a wrapper function calling into bionic setsockopt.
    // The only pointer being passed in points at socket_mark, which is valid for the whole
    // call because it was derived from a reference, and the foreign function doesn't take
    // ownership or a reference to that memory after completion.
    let rc = unsafe {
        libc::setsockopt(
            socket.as_raw_fd(),
            libc::SOL_SOCKET,
            libc::SO_MARK,
            mark_ptr,
            mark_len,
        )
    };
    if rc != 0 {
        return Err(io::Error::last_os_error());
    }
    Ok(())
}

async fn build_socket(
    peer_addr: SocketAddr,
    socket_mark: u32,
    tag_socket: &SocketTagger,
) -> io::Result<UdpSocket> {
    let bind_addr = match peer_addr {
        SocketAddr::V4(_) => "0.0.0.0:0",
        SocketAddr::V6(_) => "[::]:0",
    };

    let socket = UdpSocket::bind(bind_addr).await?;
    let std_socket = socket.into_std()?;
    mark_socket(&std_socket, socket_mark)
        .unwrap_or_else(|e| error!("Unable to mark socket : {:?}", e));
    tag_socket(&std_socket).await;
    let socket = UdpSocket::from_std(std_socket)?;
    socket.connect(peer_addr).await?;
    Ok(socket)
}

/// Error type for HTTP/3 connection
///
/// Every variant carries a `#[from]` conversion, so `?` on the underlying operation
/// produces it directly.
#[derive(Debug, Error)]
pub enum Error {
    /// QUIC protocol error
    #[error("QUIC error: {0}")]
    Quic(#[from] quiche::Error),
    /// HTTP/3 protocol error
    #[error("HTTP/3 error: {0}")]
    H3(#[from] h3::Error),
    /// Unable to send the request to the driver. This likely means the
    /// backing task has died.
    #[error("Unable to send request")]
    SendRequest(#[from] mpsc::error::SendError<Request>),
    /// IO failed. This is most likely to occur while trying to set up the
    /// UDP socket for use by the connection.
    #[error("IO error: {0}")]
    Io(#[from] io::Error),
    /// The request is no longer being serviced. This could mean that the
    /// request was dropped for an unspecified reason, or that the connection
    /// was closed prematurely and it can no longer be serviced.
    #[error("Driver dropped request")]
    RecvResponse(#[from] oneshot::error::RecvError),
}

/// Common result type for working with a HTTP/3 connection, with this module's `Error`
/// as the error type
pub type Result<T> = std::result::Result<T, Error>;

impl Connection {
    /// Capacity of the channel carrying requests to the background task.
    const MAX_PENDING_REQUESTS: usize = 10;

    /// Sets up a QUIC connection to `to` and spawns a background task which performs its IO.
    ///
    /// Fails if quiche cannot create the connection or the UDP socket cannot be set up.
    pub async fn new(
        server_name: Option<&str>,
        to: SocketAddr,
        socket_mark: u32,
        tag_socket: &SocketTagger,
        config: &mut quiche::Config,
    ) -> Result<Self> {
        let scid_bytes = new_scid();
        let quiche_conn = quiche::connect(
            server_name,
            &quiche::ConnectionId::from_ref(&scid_bytes),
            to,
            config,
        )?;
        let socket = build_socket(to, socket_mark, tag_socket).await?;
        let (request_tx, request_rx) = mpsc::channel(Self::MAX_PENDING_REQUESTS);
        // The task is detached; the request channel is our only link to it.
        task::spawn(drive(request_rx, quiche_conn, socket));
        Ok(Self { request_tx })
    }

    /// Submits a query and hands back a future for its response.
    ///
    /// The response future is returned instead of being awaited here so that the caller can
    /// wait on it without keeping the `Connection` borrowed. It resolves to `None` if the
    /// response channel is dropped without a response being sent.
    pub async fn query(
        &self,
        headers: Vec<h3::Header>,
        expiry: Option<BootTime>,
    ) -> Result<impl Future<Output = Option<Stream>>> {
        let (response_tx, response_rx) = oneshot::channel();
        let request = Request { headers, response_tx, expiry };
        self.request_tx.send(request).await?;
        Ok(async move { response_rx.await.ok() })
    }
}
+127 −0
Original line number Diff line number Diff line
/*
 * Copyright (C) 2021 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

//! Provides a backing task to implement a Dispatcher

use crate::boot_time::{BootTime, Duration};
use anyhow::{bail, Result};
use log::{debug, trace, warn};
use std::collections::HashMap;
use tokio::sync::{mpsc, oneshot};

use super::{Command, QueryError, Response};
use crate::network::{Network, ServerInfo, SocketTagger, ValidationReporter};
use crate::{config, network};

/// Backing task state for the dispatcher: receives commands and applies them to the
/// networks it manages.
pub struct Driver {
    // Incoming commands; the driver exits when this channel closes.
    command_rx: mpsc::Receiver<Command>,
    // Known networks, keyed by net_id.
    networks: HashMap<u32, Network>,
    // Cloned into every `Network` created by a probe.
    validation: ValidationReporter,
    // Cloned into every `Network` created by a probe.
    tagger: SocketTagger,
    // Source of the config handed to new networks, looked up by certificate path.
    // Garbage collected whenever a network is cleared.
    config_cache: config::Cache,
}

/// Logs the error in `r`, if any, at debug level and otherwise discards the result.
fn debug_err(r: Result<()>) {
    match r {
        Ok(()) => {}
        Err(e) => debug!("Dispatcher loop got {:?}", e),
    }
}

impl Driver {
    pub fn new(
        command_rx: mpsc::Receiver<Command>,
        validation: ValidationReporter,
        tagger: SocketTagger,
    ) -> Self {
        Self {
            command_rx,
            networks: HashMap::new(),
            validation,
            tagger,
            config_cache: config::Cache::new(),
        }
    }

    pub async fn drive(mut self) -> Result<()> {
        loop {
            self.drive_once().await?
        }
    }

    async fn drive_once(&mut self) -> Result<()> {
        if let Some(command) = self.command_rx.recv().await {
            trace!("dispatch command: {:?}", command);
            match command {
                Command::Probe { info, timeout } => debug_err(self.probe(info, timeout).await),
                Command::Query { net_id, base64_query, expired_time, resp } => {
                    debug_err(self.query(net_id, base64_query, expired_time, resp).await)
                }
                Command::Clear { net_id } => {
                    self.networks.remove(&net_id);
                    self.config_cache.garbage_collect();
                }
                Command::Exit => {
                    bail!("Death due to Exit")
                }
            }
            Ok(())
        } else {
            bail!("Death due to command_tx dying")
        }
    }

    async fn query(
        &mut self,
        net_id: u32,
        query: String,
        expiry: BootTime,
        response: oneshot::Sender<Response>,
    ) -> Result<()> {
        if let Some(network) = self.networks.get_mut(&net_id) {
            network.query(network::Query { query, response, expiry }).await?;
        } else {
            warn!("Tried to send a query to non-existent network net_id={}", net_id);
            response.send(Response::Error { error: QueryError::Unexpected }).unwrap_or_else(|_| {
                warn!("Unable to send reply for non-existent network net_id={}", net_id);
            })
        }
        Ok(())
    }

    async fn probe(&mut self, info: ServerInfo, timeout: Duration) -> Result<()> {
        use std::collections::hash_map::Entry;
        if !self.networks.get(&info.net_id).map_or(true, |net| net.get_info() == &info) {
            // If we have a network registered to the provided net_id, but the server info doesn't
            // match, our API has been used incorrectly. Attempt to recover by deleting the old
            // network and recreating it according to the probe request.
            warn!("Probing net_id={} with mismatched server info", info.net_id);
            self.networks.remove(&info.net_id);
        }
        // Can't use or_insert_with because creating a network may fail
        let net = match self.networks.entry(info.net_id) {
            Entry::Occupied(network) => network.into_mut(),
            Entry::Vacant(vacant) => {
                let config = self.config_cache.from_cert_path(&info.cert_path)?;
                vacant.insert(
                    Network::new(info, config, self.validation.clone(), self.tagger.clone())
                        .await?,
                )
            }
        };
        net.probe(timeout).await?;
        Ok(())
    }
}
Loading