Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 373b7f62 authored by Matthew Maurer's avatar Matthew Maurer
Browse files

DoH: Factor out BootTime

* Move `BootTime` into the `boot_time` module
* Change `elapsed()` to match `Instant` API
* Add timeout + sleep functions, relative to `CLOCK_BOOTTIME`
* Change everywhere in DoH to use `boot_time` instead of `time`
* Add tests for boot_time module

BYPASS_INCLUSIVE_LANGUAGE_REASON="man is referring to the unix manual command, not a person"
Bug: 202081046
Bug: 200694560

Change-Id: I719965ff75abb0223ba20829ca0a3a4be1d07f40
parent 6b17842f
Loading
Loading
Loading
Loading

doh/boot_time.rs

0 → 100644
+206 −0
Original line number Diff line number Diff line
/*
 * Copyright (C) 2021 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

//! This module provides a time hack to work around the broken `Instant` type in the standard
//! library.
//!
//! `BootTime` looks like `Instant`, but represents `CLOCK_BOOTTIME` instead of `CLOCK_MONOTONIC`.
//! This means the clock increments correctly during suspend.

pub use std::time::Duration;

use std::io;

use futures::future::pending;
use std::convert::TryInto;
use std::fmt;
use std::future::Future;
use std::os::unix::io::{AsRawFd, RawFd};
use tokio::io::unix::AsyncFd;
use tokio::select;

/// Represents a moment in time, with differences including time spent in suspend. Only valid for
/// a single boot - numbers from different boots are incomparable.
#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct BootTime {
    d: Duration,
}

// Return an error with the same structure as tokio::time::timeout to facilitate migration off it,
// and hopefully some day back to it.
/// Error returned by timeout
#[derive(Debug, PartialEq)]
pub struct Elapsed(());

impl fmt::Display for Elapsed {
    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
        "deadline has elapsed".fmt(fmt)
    }
}

impl std::error::Error for Elapsed {}

impl BootTime {
    /// Gets a `BootTime` representing the current moment in time.
    pub fn now() -> BootTime {
        let mut t = libc::timespec { tv_sec: 0, tv_nsec: 0 };
        // # Safety
        // clock_gettime's only action will be to possibly write to the pointer provided,
        // and no borrows exist from that object other than the &mut used to construct the pointer
        // itself.
        if unsafe { libc::clock_gettime(libc::CLOCK_BOOTTIME, &mut t as *mut libc::timespec) } != 0
        {
            panic!(
                "libc::clock_gettime(libc::CLOCK_BOOTTIME) failed: {:?}",
                io::Error::last_os_error()
            );
        }
        BootTime { d: Duration::new(t.tv_sec as u64, t.tv_nsec as u32) }
    }

    /// Determines how long has elapsed since the provided `BootTime`.
    pub fn elapsed(&self) -> Duration {
        BootTime::now().checked_duration_since(*self).unwrap()
    }

    /// Add a specified time delta to a moment in time. If this would overflow the representation,
    /// returns `None`.
    pub fn checked_add(&self, duration: Duration) -> Option<BootTime> {
        Some(BootTime { d: self.d.checked_add(duration)? })
    }

    /// Finds the difference from an earlier point in time. If the provided time is later, returns
    /// `None`.
    pub fn checked_duration_since(&self, earlier: BootTime) -> Option<Duration> {
        self.d.checked_sub(earlier.d)
    }
}

struct TimerFd(RawFd);

impl Drop for TimerFd {
    fn drop(&mut self) {
        // # Safety
        // The fd is owned by the TimerFd struct, and no memory access occurs as a result of this
        // call.
        unsafe {
            libc::close(self.0);
        }
    }
}

impl AsRawFd for TimerFd {
    fn as_raw_fd(&self) -> RawFd {
        self.0
    }
}

impl TimerFd {
    fn create() -> io::Result<Self> {
        // # Unsafe
        // This libc call will either give us back a file descriptor or fail, it does not act on
        // memory or resources.
        let raw = unsafe {
            libc::timerfd_create(libc::CLOCK_BOOTTIME, libc::TFD_NONBLOCK | libc::TFD_CLOEXEC)
        };
        if raw < 0 {
            return Err(io::Error::last_os_error());
        }
        Ok(Self(raw))
    }

    fn set(&self, duration: Duration) {
        let timer = libc::itimerspec {
            it_interval: libc::timespec { tv_sec: 0, tv_nsec: 0 },
            it_value: libc::timespec {
                tv_sec: duration.as_secs().try_into().unwrap(),
                tv_nsec: duration.subsec_nanos().try_into().unwrap(),
            },
        };
        // # Unsafe
        // We own `timer` and there are no borrows to it other than the pointer we pass to
        // timerfd_settime. timerfd_settime is explicitly documented to handle a null output
        // parameter for its fourth argument by not filling out the output. The fd passed in at
        // self.0 is owned by the `TimerFd` struct, so we aren't breaking anyone else's invariants.
        if unsafe { libc::timerfd_settime(self.0, 0, &timer, std::ptr::null_mut()) } != 0 {
            panic!("timerfd_settime failed: {:?}", io::Error::last_os_error());
        }
    }
}

/// Runs the provided future until completion or `duration` has passed on the `CLOCK_BOOTTIME`
/// clock. In the event of a timeout, returns the elapsed time as an error.
pub async fn timeout<T>(duration: Duration, future: impl Future<Output = T>) -> Result<T, Elapsed> {
    // Ideally, all timeouts in a runtime would share a timerfd. That will be much more
    // straightforwards to implement when moving this functionality into `tokio`.

    // The failure conditions for this are rare (see `man 2 timerfd_create`) and the caller would
    // not be able to do much in response to them. When integrated into tokio, this would be called
    // during runtime setup.
    let timer_fd = TimerFd::create().unwrap();
    timer_fd.set(duration);
    let async_fd = AsyncFd::new(timer_fd).unwrap();
    select! {
        v = future => Ok(v),
        _ = async_fd.readable() => Err(Elapsed(())),
    }
}

/// Provides a future which will complete once the provided duration has passed, as measured by the
/// `CLOCK_BOOTTIME` clock.
pub async fn sleep(duration: Duration) {
    assert!(timeout(duration, pending::<()>()).await.is_err());
}

#[test]
fn monotonic_smoke() {
    for _ in 0..1000 {
        // If BootTime is not monotonic, .elapsed() will panic on the unwrap.
        BootTime::now().elapsed();
    }
}

#[test]
fn round_trip() {
    use std::thread::sleep;
    for _ in 0..10 {
        let start = BootTime::now();
        sleep(Duration::from_millis(1));
        let end = BootTime::now();
        let delta = end.checked_duration_since(start).unwrap();
        assert_eq!(start.checked_add(delta).unwrap(), end);
    }
}

#[tokio::test]
async fn timeout_drift() {
    let delta = Duration::from_millis(20);
    for _ in 0..10 {
        let start = BootTime::now();
        assert!(timeout(delta, pending::<()>()).await.is_err());
        let taken = start.elapsed();
        let drift = if taken > delta { taken - delta } else { delta - taken };
        assert!(drift < Duration::from_millis(5));
    }

    for _ in 0..10 {
        let start = BootTime::now();
        sleep(delta).await;
        let taken = start.elapsed();
        let drift = if taken > delta { taken - delta } else { delta - taken };
        assert!(drift < Duration::from_millis(5));
    }
}
+14 −41
Original line number Diff line number Diff line
@@ -35,11 +35,13 @@ use tokio::net::UdpSocket;
use tokio::runtime::{Builder, Runtime};
use tokio::sync::{mpsc, oneshot};
use tokio::task;
use tokio::time::{timeout, Duration, Instant};
use url::Url;

pub mod boot_time;
mod ffi;

use boot_time::{timeout, BootTime, Duration};

const MAX_BUFFERED_CMD_SIZE: usize = 400;
const MAX_INCOMING_BUFFER_SIZE_WHOLE: u64 = 10000000;
const MAX_INCOMING_BUFFER_SIZE_EACH: u64 = 1000000;
@@ -91,7 +93,7 @@ enum Response {
#[derive(Debug)]
enum DohCommand {
    Probe { info: ServerInfo, timeout: Duration },
    Query { net_id: u32, base64_query: Base64Query, expired_time: Instant, resp: QueryResponder },
    Query { net_id: u32, base64_query: Base64Query, expired_time: BootTime, resp: QueryResponder },
    Clear { net_id: u32 },
    Exit,
}
@@ -132,35 +134,6 @@ impl<T: Deref> OptionDeref<T> for Option<T> {
    }
}

#[derive(Copy, Clone, Debug)]
struct BootTime {
    d: Duration,
}

impl BootTime {
    fn now() -> BootTime {
        unsafe {
            let mut t = libc::timespec { tv_sec: 0, tv_nsec: 0 };
            if libc::clock_gettime(libc::CLOCK_BOOTTIME, &mut t as *mut libc::timespec) != 0 {
                panic!("get boot time failed: {:?}", std::io::Error::last_os_error());
            }
            BootTime { d: Duration::new(t.tv_sec as u64, t.tv_nsec as u32) }
        }
    }

    fn elapsed(&self) -> Option<Duration> {
        BootTime::now().duration_since(*self)
    }

    fn checked_add(&self, duration: Duration) -> Option<BootTime> {
        Some(BootTime { d: self.d.checked_add(duration)? })
    }

    fn duration_since(&self, earlier: BootTime) -> Option<Duration> {
        self.d.checked_sub(earlier.d)
    }
}

/// Context for a running DoH engine.
pub struct DohDispatcher {
    /// Used to submit cmds to the I/O task.
@@ -204,7 +177,7 @@ struct DohConnection {
    shared_config: Arc<Mutex<QuicheConfigCache>>,
    scid: SCID,
    state: ConnectionState,
    pending_queries: Vec<(DnsRequest, QueryResponder, Instant)>,
    pending_queries: Vec<(DnsRequest, QueryResponder, BootTime)>,
    cached_session: Option<Vec<u8>>,
    tag_socket_fn: TagSocketCallback,
}
@@ -333,7 +306,7 @@ impl DohConnection {
        };

        if let Some(expired_time) = expired_time {
            if let Some(elapsed) = expired_time.elapsed() {
            if let Some(elapsed) = BootTime::now().checked_duration_since(*expired_time) {
                warn!(
                    "Change the state to Idle due to connection timeout, {:?}, {}",
                    elapsed, self.info.net_id
@@ -429,7 +402,7 @@ impl DohConnection {
        &mut self,
        req: DnsRequest,
        resp: QueryResponder,
        expired_time: Instant,
        expired_time: BootTime,
    ) -> Result<()> {
        self.handle_if_connection_expired();
        match &mut self.state {
@@ -472,7 +445,7 @@ impl DohConnection {
                while !self.pending_queries.is_empty() {
                    if let Some((req, resp, exp_time)) = self.pending_queries.pop() {
                        // Ignore the expired queries.
                        if Instant::now().checked_duration_since(exp_time).is_some() {
                        if BootTime::now().checked_duration_since(exp_time).is_some() {
                            warn!("Drop the obsolete query for network {}", self.info.net_id);
                            continue;
                        }
@@ -596,9 +569,9 @@ async fn send_dns_query(
    udp_sk: &mut UdpSocket,
    h3_conn: &mut h3::Connection,
    query_map: &mut HashMap<u64, (Vec<u8>, QueryResponder)>,
    pending_queries: &mut Vec<(DnsRequest, QueryResponder, Instant)>,
    pending_queries: &mut Vec<(DnsRequest, QueryResponder, BootTime)>,
    resp: QueryResponder,
    expired_time: Instant,
    expired_time: BootTime,
    req: DnsRequest,
) -> Result<()> {
    if !quic_conn.is_established() {
@@ -803,7 +776,7 @@ impl QuicheConfigCache {
async fn handle_query_cmd(
    net_id: u32,
    base64_query: Base64Query,
    expired_time: Instant,
    expired_time: BootTime,
    resp: QueryResponder,
    doh_conn_map: &mut HashMap<u32, (ServerInfo, Option<DohConnection>)>,
) {
@@ -1107,7 +1080,7 @@ mod tests {
            super::handle_query_cmd(
                info.net_id,
                query.clone(),
                Instant::now().checked_add(t).unwrap(),
                BootTime::now().checked_add(t).unwrap(),
                resp_tx,
                &mut test_map,
            )
@@ -1122,7 +1095,7 @@ mod tests {
            super::handle_query_cmd(
                info.net_id,
                query.clone(),
                Instant::now().checked_add(t).unwrap(),
                BootTime::now().checked_add(t).unwrap(),
                resp_tx,
                &mut test_map,
            )
@@ -1148,7 +1121,7 @@ mod tests {
            super::handle_query_cmd(
                info.net_id,
                query.clone(),
                Instant::now().checked_add(t).unwrap(),
                BootTime::now().checked_add(t).unwrap(),
                resp_tx,
                &mut test_map,
            )
+2 −2
Original line number Diff line number Diff line
@@ -16,6 +16,7 @@

//! C API for the DoH backend for the Android DnsResolver module.

use crate::boot_time::{timeout, BootTime, Duration};
use libc::{c_char, int32_t, size_t, ssize_t, uint32_t, uint64_t};
use log::error;
use std::net::{IpAddr, SocketAddr};
@@ -26,7 +27,6 @@ use std::{ptr, slice};
use tokio::runtime::Runtime;
use tokio::sync::oneshot;
use tokio::task;
use tokio::time::{timeout, Duration, Instant};

use super::DohDispatcher as Dispatcher;
use super::{DohCommand, Response, ServerInfo, TagSocketCallback, ValidationCallback, DOH_PORT};
@@ -205,7 +205,7 @@ pub unsafe extern "C" fn doh_query(

    let (resp_tx, resp_rx) = oneshot::channel();
    let t = Duration::from_millis(timeout_ms);
    if let Some(expired_time) = Instant::now().checked_add(t) {
    if let Some(expired_time) = BootTime::now().checked_add(t) {
        let cmd = DohCommand::Query {
            net_id,
            base64_query: base64::encode_config(q, base64::URL_SAFE_NO_PAD),