Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 83b75c07 authored by Luke Huang's avatar Luke Huang Committed by Automerger Merge Worker
Browse files

Fix the bug of receiving long DNS answer on DoH am: 67936ef2 am: d29a72ba am: 3f02ce7b

Original change: https://android-review.googlesource.com/c/platform/packages/modules/DnsResolver/+/1816685

Change-Id: Id6c2f1f392c08f61ef785b6c47d9e6fe79da3dbb
parents 34fec704 3f02ce7b
Loading
Loading
Loading
Loading
+81 −46
Original line number Original line Diff line number Diff line
@@ -124,6 +124,12 @@ enum ConnectionStatus {
    Fail,
    Fail,
}
}


enum H3Result {
    Data { data: Vec<u8> },
    Finished,
    Ignore,
}

trait OptionDeref<T: Deref> {
trait OptionDeref<T: Deref> {
    fn as_deref(&self) -> Option<&T::Target>;
    fn as_deref(&self) -> Option<&T::Target>;
}
}
@@ -207,7 +213,7 @@ struct DohConnection {
    udp_sk: UdpSocket,
    udp_sk: UdpSocket,
    h3_conn: Option<h3::Connection>,
    h3_conn: Option<h3::Connection>,
    status: ConnectionStatus,
    status: ConnectionStatus,
    query_map: HashMap<u64, QueryResponder>,
    query_map: HashMap<u64, (Vec<u8>, QueryResponder)>,
    pending_queries: Vec<(DnsRequest, QueryResponder, Instant)>,
    pending_queries: Vec<(DnsRequest, QueryResponder, Instant)>,
    cached_session: Option<Vec<u8>>,
    cached_session: Option<Vec<u8>>,
    expired_time: Option<BootTime>,
    expired_time: Option<BootTime>,
@@ -257,14 +263,20 @@ impl DohConnection {
        loop {
        loop {
            self.recv_rx().await?;
            self.recv_rx().await?;
            self.flush_tx().await?;
            self.flush_tx().await?;
            if let Ok((stream_id, _buf)) = self.recv_query() {
            loop {
                match self.recv_h3() {
                    Ok((stream_id, H3Result::Finished)) => {
                        if stream_id == req_id {
                        if stream_id == req_id {
                            return Ok(());
                        }
                    }
                    // TODO: Verify the answer
                    // TODO: Verify the answer
                    break;
                    Ok((_stream_id, H3Result::Data { .. })) => {}
                    Ok((_stream_id, H3Result::Ignore)) => {}
                    Err(_) => break,
                }
                }
            }
            }
        }
        }
        Ok(())
    }
    }


    async fn connect(&mut self) -> Result<()> {
    async fn connect(&mut self) -> Result<()> {
@@ -302,7 +314,7 @@ impl DohConnection {
        match self.status {
        match self.status {
            ConnectionStatus::Ready => match self.send_dns_query(&req).await {
            ConnectionStatus::Ready => match self.send_dns_query(&req).await {
                Ok(req_id) => {
                Ok(req_id) => {
                    self.query_map.insert(req_id, resp);
                    self.query_map.insert(req_id, (Vec::new(), resp));
                }
                }
                Err(e) => {
                Err(e) => {
                    if let Ok(quiche::h3::Error::StreamBlocked) = e.downcast::<quiche::h3::Error>()
                    if let Ok(quiche::h3::Error::StreamBlocked) = e.downcast::<quiche::h3::Error>()
@@ -359,18 +371,32 @@ impl DohConnection {
            }
            }
            self.recv_rx().await?;
            self.recv_rx().await?;
            self.flush_tx().await?;
            self.flush_tx().await?;
            if let Ok((stream_id, buf)) = self.recv_query() {
            loop {
                if let Some(resp) = self.query_map.remove(&stream_id) {
                match self.recv_h3() {
                    Ok((stream_id, H3Result::Data { mut data })) => {
                        if let Some((answer, _)) = self.query_map.get_mut(&stream_id) {
                            answer.append(&mut data);
                        } else {
                            // Should not happen
                            warn!("No associated receiver found while receiving Data, Network {}, stream id: {}", self.net_id, stream_id);
                        }
                    }
                    Ok((stream_id, H3Result::Finished)) => {
                        if let Some((answer, resp)) = self.query_map.remove(&stream_id) {
                            debug!(
                            debug!(
                                "sending answer back to resolv, Network {}, stream id: {}",
                                "sending answer back to resolv, Network {}, stream id: {}",
                                self.net_id, stream_id
                                self.net_id, stream_id
                            );
                            );
                    resp.send(Response::Success { answer: buf }).unwrap_or_else(|e| {
                            resp.send(Response::Success { answer }).unwrap_or_else(|e| {
                                trace!("the receiver dropped {:?}, stream id: {}", e, stream_id);
                                trace!("the receiver dropped {:?}, stream id: {}", e, stream_id);
                            });
                            });
                        } else {
                        } else {
                            // Should not happen
                            // Should not happen
                    warn!("No associated receiver found");
                            warn!("No associated receiver found while receiving Finished, Network {}, stream id: {}", self.net_id, stream_id);
                        }
                    }
                    Ok((_stream_id, H3Result::Ignore)) => {}
                    Err(_) => break,
                }
                }
            }
            }
            if self.quic_conn.is_closed() || !self.quic_conn.is_established() {
            if self.quic_conn.is_closed() || !self.quic_conn.is_established() {
@@ -380,15 +406,15 @@ impl DohConnection {
        }
        }
    }
    }


    fn recv_query(&mut self) -> Result<(u64, Vec<u8>)> {
    fn recv_h3(&mut self) -> Result<(u64, H3Result)> {
        let h3_conn = self.h3_conn.as_mut().ok_or_else(|| anyhow!("h3 conn isn't available"))?;
        let h3_conn = self.h3_conn.as_mut().ok_or_else(|| anyhow!("h3 conn isn't available"))?;
        loop {
        match h3_conn.poll(&mut self.quic_conn) {
        match h3_conn.poll(&mut self.quic_conn) {
            // Process HTTP/3 events.
            // Process HTTP/3 events.
            Ok((stream_id, quiche::h3::Event::Data)) => {
            Ok((stream_id, quiche::h3::Event::Data)) => {
                debug!("quiche::h3::Event::Data");
                debug!("quiche::h3::Event::Data");
                let mut buf = vec![0; MAX_DATAGRAM_SIZE];
                let mut buf = vec![0; MAX_DATAGRAM_SIZE];
                    if let Ok(read) = h3_conn.recv_body(&mut self.quic_conn, stream_id, &mut buf) {
                match h3_conn.recv_body(&mut self.quic_conn, stream_id, &mut buf) {
                    Ok(read) => {
                        trace!(
                        trace!(
                            "got {} bytes of response data on stream {}: {:x?}",
                            "got {} bytes of response data on stream {}: {:x?}",
                            read,
                            read,
@@ -396,7 +422,12 @@ impl DohConnection {
                            &buf[..read]
                            &buf[..read]
                        );
                        );
                        buf.truncate(read);
                        buf.truncate(read);
                        return Ok((stream_id, buf));
                        Ok((stream_id, H3Result::Data { data: buf }))
                    }
                    Err(e) => {
                        warn!("recv_h3::recv_body {:?}", e);
                        bail!(e);
                    }
                }
                }
            }
            }
            Ok((stream_id, quiche::h3::Event::Headers { list, has_body })) => {
            Ok((stream_id, quiche::h3::Event::Headers { list, has_body })) => {
@@ -404,23 +435,27 @@ impl DohConnection {
                    "got response headers {:?} on stream id {} has_body {}",
                    "got response headers {:?} on stream id {} has_body {}",
                    list, stream_id, has_body
                    list, stream_id, has_body
                );
                );
                Ok((stream_id, H3Result::Ignore))
            }
            }
            Ok((stream_id, quiche::h3::Event::Finished)) => {
            Ok((stream_id, quiche::h3::Event::Finished)) => {
                debug!("quiche::h3::Event::Finished on stream id {}", stream_id);
                debug!("quiche::h3::Event::Finished on stream id {}", stream_id);
                Ok((stream_id, H3Result::Finished))
            }
            }
            Ok((stream_id, quiche::h3::Event::Datagram)) => {
            Ok((stream_id, quiche::h3::Event::Datagram)) => {
                debug!("quiche::h3::Event::Datagram on stream id {}", stream_id);
                debug!("quiche::h3::Event::Datagram on stream id {}", stream_id);
                Ok((stream_id, H3Result::Ignore))
            }
            }
            // TODO: Check if it's necessary to handle GoAway event.
            Ok((stream_id, quiche::h3::Event::GoAway)) => {
            Ok((stream_id, quiche::h3::Event::GoAway)) => {
                debug!("quiche::h3::Event::GoAway on stream id {}", stream_id);
                debug!("quiche::h3::Event::GoAway on stream id {}", stream_id);
                Ok((stream_id, H3Result::Ignore))
            }
            }
            Err(e) => {
            Err(e) => {
                    debug!("recv_query {:?}", e);
                debug!("recv_h3 {:?}", e);
                bail!(e);
                bail!(e);
            }
            }
        }
        }
    }
    }
    }


    async fn recv_rx(&mut self) -> Result<()> {
    async fn recv_rx(&mut self) -> Result<()> {
        // TODO: Evaluate if we could make the buffer smaller.
        // TODO: Evaluate if we could make the buffer smaller.
+28 −2
Original line number Original line Diff line number Diff line
@@ -50,11 +50,22 @@ pub struct Client {
    /// Queues the DNS queries being processed in backend.
    /// Queues the DNS queries being processed in backend.
    /// <Query ID, Stream ID>
    /// <Query ID, Stream ID>
    in_flight_queries: HashMap<[u8; 2], u64>,
    in_flight_queries: HashMap<[u8; 2], u64>,

    /// Queues the second part DNS answers needed to be sent after first part.
    /// <Stream ID, ans>
    pending_answers: Vec<(u64, Vec<u8>)>,
}
}


impl Client {
impl Client {
    fn new(conn: Pin<Box<quiche::Connection>>, addr: &SocketAddr, id: ConnectionID) -> Client {
    fn new(conn: Pin<Box<quiche::Connection>>, addr: &SocketAddr, id: ConnectionID) -> Client {
        Client { conn, h3_conn: None, addr: *addr, id, in_flight_queries: HashMap::new() }
        Client {
            conn,
            h3_conn: None,
            addr: *addr,
            id,
            in_flight_queries: HashMap::new(),
            pending_answers: Vec::new(),
        }
    }
    }


    fn create_http3_connection(&mut self) -> Result<()> {
    fn create_http3_connection(&mut self) -> Result<()> {
@@ -135,8 +146,23 @@ impl Client {
        info!("Preparing HTTP/3 response {:?} on stream {}", headers, stream_id);
        info!("Preparing HTTP/3 response {:?} on stream {}", headers, stream_id);


        h3_conn.send_response(&mut self.conn, stream_id, &headers, false)?;
        h3_conn.send_response(&mut self.conn, stream_id, &headers, false)?;
        h3_conn.send_body(&mut self.conn, stream_id, response, true)?;


        // In order to simulate the case that server send multiple packets for a DNS answer,
        // only send half of the answer here. The remaining one will be cached here and then
        // processed later in process_pending_answers().
        let (first, second) = response.split_at(len / 2);
        h3_conn.send_body(&mut self.conn, stream_id, first, false)?;
        self.pending_answers.push((stream_id, second.to_vec()));

        Ok(())
    }

    pub fn process_pending_answers(&mut self) -> Result<()> {
        if let Some((stream_id, ans)) = self.pending_answers.pop() {
            let h3_conn = self.h3_conn.as_mut().unwrap();
            info!("process the remaining response for stream {}", stream_id);
            h3_conn.send_body(&mut self.conn, stream_id, &ans, true)?;
        }
        Ok(())
        Ok(())
    }
    }


+1 −0
Original line number Original line Diff line number Diff line
@@ -298,6 +298,7 @@ async fn worker_thread(params: WorkerParams) -> Result<()> {
                                    error!("flush_egress failed: {}", e);
                                    error!("flush_egress failed: {}", e);
                                }
                                }
                            }
                            }
                            client.process_pending_answers().unwrap();
                        }
                        }
                    }
                    }
                }
                }