Loading tests/dns_responder/dns_responder_client_ndk.cpp +22 −19 Original line number Diff line number Diff line Loading @@ -21,8 +21,6 @@ #include <android/binder_manager.h> #include "NetdClient.h" // TODO: make this dynamic and stop depending on implementation details. #define TEST_OEM_NETWORK "oem29" #define TEST_NETID 30 // TODO: move this somewhere shared. Loading Loading @@ -172,43 +170,48 @@ void DnsResponderClient::SetupDNSServers(unsigned numServers, const std::vector< } } int DnsResponderClient::SetupOemNetwork() { mNetdSrv->networkDestroy(TEST_NETID); mDnsResolvSrv->destroyNetworkCache(TEST_NETID); int DnsResponderClient::SetupOemNetwork(int oemNetId) { mNetdSrv->networkDestroy(oemNetId); mDnsResolvSrv->destroyNetworkCache(oemNetId); ::ndk::ScopedAStatus ret; if (DnsResponderClient::isRemoteVersionSupported(mNetdSrv, 6)) { const auto& config = DnsResponderClient::makeNativeNetworkConfig( TEST_NETID, NativeNetworkType::PHYSICAL, INetd::PERMISSION_NONE, /*secure=*/false); oemNetId, NativeNetworkType::PHYSICAL, INetd::PERMISSION_NONE, /*secure=*/false); ret = mNetdSrv->networkCreate(config); } else { // Only for presubmit tests that run mainline module (and its tests) on R or earlier images. #pragma clang diagnostic push #pragma clang diagnostic ignored "-Wdeprecated-declarations" ret = mNetdSrv->networkCreatePhysical(TEST_NETID, INetd::PERMISSION_NONE); ret = mNetdSrv->networkCreatePhysical(oemNetId, INetd::PERMISSION_NONE); #pragma clang diagnostic pop } if (!ret.isOk()) { fprintf(stderr, "Creating physical network %d failed, %s\n", TEST_NETID, ret.getMessage()); fprintf(stderr, "Creating physical network %d failed, %s\n", oemNetId, ret.getMessage()); return -1; } ret = mDnsResolvSrv->createNetworkCache(TEST_NETID); ret = mDnsResolvSrv->createNetworkCache(oemNetId); if (!ret.isOk()) { fprintf(stderr, "Creating network cache %d failed, %s\n", TEST_NETID, ret.getMessage()); fprintf(stderr, "Creating network cache %d failed, %s\n", oemNetId, ret.getMessage()); return -1; } setNetworkForProcess(TEST_NETID); if ((unsigned)TEST_NETID != getNetworkForProcess()) { setNetworkForProcess(oemNetId); if ((unsigned)oemNetId != getNetworkForProcess()) { return -1; } return TEST_NETID; return 0; } void DnsResponderClient::TearDownOemNetwork(int oemNetId) { if (oemNetId != -1) { mNetdSrv->networkDestroy(oemNetId); mDnsResolvSrv->destroyNetworkCache(oemNetId); int DnsResponderClient::TearDownOemNetwork(int oemNetId) { if (auto status = mNetdSrv->networkDestroy(oemNetId); !status.isOk()) { fprintf(stderr, "Removing network %d failed, %s\n", oemNetId, status.getMessage()); return -1; } if (auto status = mDnsResolvSrv->destroyNetworkCache(oemNetId); !status.isOk()) { fprintf(stderr, "Removing network cache %d failed, %s\n", oemNetId, status.getMessage()); return -1; } return 0; } void DnsResponderClient::SetUp() { Loading @@ -228,11 +231,11 @@ void DnsResponderClient::SetUp() { // Ensure resolutions go via proxy. setenv(ANDROID_DNS_MODE, "", 1); mOemNetId = SetupOemNetwork(); SetupOemNetwork(TEST_NETID); } void DnsResponderClient::TearDown() { TearDownOemNetwork(mOemNetId); TearDownOemNetwork(TEST_NETID); } NativeNetworkConfig DnsResponderClient::makeNativeNetworkConfig(int netId, Loading tests/dns_responder/dns_responder_client_ndk.h +3 −4 Original line number Diff line number Diff line Loading @@ -128,9 +128,9 @@ class DnsResponderClient { const std::vector<std::string>& domains, const std::string& tlsHostname, const std::vector<std::string>& tlsServers, const std::string& caCert = ""); int SetupOemNetwork(); void TearDownOemNetwork(int oemNetId); // Returns 0 on success and a negative value on failure. int SetupOemNetwork(int oemNetId); int TearDownOemNetwork(int oemNetId); virtual void SetUp(); virtual void TearDown(); Loading @@ -141,5 +141,4 @@ class DnsResponderClient { private: std::shared_ptr<aidl::android::net::INetd> mNetdSrv; std::shared_ptr<aidl::android::net::IDnsResolver> mDnsResolvSrv; int mOemNetId = -1; }; tests/doh/include/lib.rs.h +3 −3 Original line number Diff line number Diff line Loading @@ -29,8 +29,8 @@ struct DohFrontend; struct Stats { /// The number of accumulated DoH queries that are received. uint32_t queries_received; /// The number of accumulated QUIC connections. uint32_t connections; /// The number of accumulated QUIC connections accepted. uint32_t connections_accepted; }; extern "C" { Loading Loading @@ -99,7 +99,7 @@ bool frontend_set_max_streams_bidi(DohFrontend *doh, uint64_t value); bool frontend_block_sending(DohFrontend *doh, bool block); /// Gets the statistics of the `DohFrontend` and writes the result to |out|. void frontend_stats(const DohFrontend *doh, Stats *out); bool frontend_stats(DohFrontend *doh, Stats *out); /// Resets `queries_received` field of `Stats` owned by the `DohFrontend`. bool frontend_stats_clear_queries(const DohFrontend *doh); Loading tests/doh/src/client.rs +11 −8 Original line number Diff line number Diff line Loading @@ -253,14 +253,13 @@ impl ClientMap { &mut self, hdr: &quiche::Header, addr: &SocketAddr, ) -> Result<(&mut Client, bool)> { let dcid = hdr.dcid.as_ref(); let is_new_client = !self.clients.contains_key(dcid); let client = if is_new_client { ) -> Result<&mut Client> { let dcid = hdr.dcid.as_ref().to_vec(); let client = if !self.clients.contains_key(&dcid) { ensure!(hdr.ty == quiche::Type::Initial, "Packet is not Initial"); ensure!(quiche::version_is_supported(hdr.version), "Protocol version not supported"); let scid = generate_conn_id(&self.conn_id_seed, dcid); let scid = generate_conn_id(&self.conn_id_seed, &dcid); let conn = quiche::accept( &quiche::ConnectionId::from_ref(&scid), None, /* odcid */ Loading @@ -273,19 +272,23 @@ impl ClientMap { self.clients.insert(scid.clone(), client); self.clients.get_mut(&scid).unwrap() } else { self.clients.get_mut(dcid).unwrap() self.clients.get_mut(&dcid).unwrap() }; Ok((client, is_new_client)) Ok(client) } pub fn get_mut(&mut self, id: &[u8]) -> Option<&mut Client> { self.clients.get_mut(&id.to_vec()) } pub fn get_mut_iter(&mut self) -> hash_map::IterMut<ConnectionID, Client> { pub fn iter_mut(&mut self) -> hash_map::IterMut<ConnectionID, Client> { self.clients.iter_mut() } pub fn len(&mut self) -> usize { self.clients.len() } } fn generate_conn_id(conn_id_seed: &hmac::Key, dcid: &[u8]) -> ConnectionID { Loading tests/doh/src/dns_https_frontend.rs +85 −30 Original line number Diff line number Diff line Loading @@ -29,7 +29,7 @@ use std::sync::{Arc, Mutex}; use std::time::Duration; use tokio::net::UdpSocket; use tokio::runtime::{Builder, Runtime}; use tokio::sync::mpsc::channel; use tokio::sync::{mpsc, oneshot}; use tokio::task::JoinHandle; lazy_static! { Loading @@ -44,11 +44,19 @@ lazy_static! { ); } /// Command used by worker_thread itself. #[derive(Debug)] enum Command { enum InternalCommand { MaybeWrite { connection_id: ConnectionID }, } /// Commands that DohFrontend to ask its worker_thread for. #[derive(Debug)] enum ControlCommand { Stats { resp: oneshot::Sender<Stats> }, StatsClearQueries, } /// Frontend object. #[derive(Debug)] pub struct DohFrontend { Loading @@ -73,10 +81,11 @@ pub struct DohFrontend { // TODO: use channel to update worker_thread configuration. config: Arc<Mutex<Config>>, // Stores some statistic to check DohFrontend status. // It's shared with the worker thread. // TODO: use channel to retrieve the stats from worker_thread. stats: Arc<Mutex<Stats>>, // Caches the latest stats so that the stats remains after worker_thread stops. latest_stats: Stats, // It is wrapped as Option because the channel is not created in DohFrontend construction. command_tx: Option<mpsc::UnboundedSender<ControlCommand>>, } /// The parameters passed to the worker thread. Loading @@ -85,7 +94,7 @@ struct WorkerParams { backend_socket: std::net::UdpSocket, clients: ClientMap, config: Arc<Mutex<Config>>, stats: Arc<Mutex<Stats>>, command_rx: mpsc::UnboundedReceiver<ControlCommand>, } impl DohFrontend { Loading @@ -100,7 +109,8 @@ impl DohFrontend { private_key: String::new(), worker_thread: None, config: Arc::new(Mutex::new(Config::new())), stats: Arc::new(Mutex::new(Stats::new())), latest_stats: Stats::new(), command_tx: None, }); debug!("DohFrontend created: {:?}", doh); Ok(doh) Loading @@ -123,6 +133,9 @@ impl DohFrontend { pub fn stop(&mut self) -> Result<()> { if let Some(worker_thread) = self.worker_thread.take() { // Update latest_stats before stopping worker_thread. let _ = self.request_stats(); worker_thread.abort(); } Loading Loading @@ -165,16 +178,48 @@ impl DohFrontend { Ok(()) } pub fn stats(&self) -> Stats { self.stats.lock().unwrap().clone() pub fn request_stats(&mut self) -> Result<Stats> { ensure!( self.command_tx.is_some(), "command_tx is None because worker thread not yet initialized" ); let command_tx = self.command_tx.as_ref().unwrap(); if command_tx.is_closed() { return Ok(self.latest_stats.clone()); } let (resp_tx, resp_rx) = oneshot::channel(); command_tx.send(ControlCommand::Stats { resp: resp_tx })?; match RUNTIME_STATIC .block_on(async { tokio::time::timeout(Duration::from_secs(1), resp_rx).await }) { Ok(v) => match v { Ok(stats) => { self.latest_stats = stats.clone(); Ok(stats) } Err(e) => bail!(e), }, Err(e) => bail!(e), } } pub fn stats_clear_queries(&self) -> Result<()> { self.stats.lock().unwrap().queries_received = 0; Ok(()) ensure!( self.command_tx.is_some(), "command_tx is None because worker thread not yet initialized" ); return self .command_tx .as_ref() .unwrap() .send(ControlCommand::StatsClearQueries) .or_else(|e| bail!(e)); } fn init_worker_thread_params(&self) -> Result<WorkerParams> { fn init_worker_thread_params(&mut self) -> Result<WorkerParams> { let bind_addr = if self.backend_socket_addr.ip().is_ipv4() { "0.0.0.0:0" } else { "[::]:0" }; let backend_socket = std::net::UdpSocket::bind(bind_addr)?; Loading @@ -190,12 +235,15 @@ impl DohFrontend { self.config.clone(), )?)?; let (command_tx, command_rx) = mpsc::unbounded_channel::<ControlCommand>(); self.command_tx = Some(command_tx); Ok(WorkerParams { frontend_socket, backend_socket, clients, config: self.config.clone(), stats: self.stats.clone(), command_rx, }) } } Loading @@ -204,18 +252,19 @@ async fn worker_thread(params: WorkerParams) -> Result<()> { let backend_socket = into_tokio_udp_socket(params.backend_socket)?; let frontend_socket = into_tokio_udp_socket(params.frontend_socket)?; let config = params.config; let stats = params.stats; let (event_tx, mut event_rx) = channel::<Command>(100); let (event_tx, mut event_rx) = mpsc::unbounded_channel::<InternalCommand>(); let mut command_rx = params.command_rx; let mut clients = params.clients; let mut frontend_buf = [0; 65535]; let mut backend_buf = [0; 16384]; let mut delay_queries_buffer: Vec<Vec<u8>> = vec![]; let mut queries_received = 0; debug!("frontend={:?}, backend={:?}", frontend_socket, backend_socket); loop { let timeout = clients .get_mut_iter() .iter_mut() .filter_map(|(_, c)| c.timeout()) .min() .unwrap_or_else(|| Duration::from_millis(QUICHE_IDLE_TIMEOUT_MS)); Loading @@ -223,12 +272,12 @@ async fn worker_thread(params: WorkerParams) -> Result<()> { tokio::select! { _ = tokio::time::sleep(timeout) => { debug!("timeout"); for (_, client) in clients.get_mut_iter() { for (_, client) in clients.iter_mut() { // If no timeout has occurred it does nothing. client.on_timeout(); let connection_id = client.connection_id().clone(); event_tx.send(Command::MaybeWrite{connection_id}).await?; event_tx.send(InternalCommand::MaybeWrite{connection_id})?; } } Loading @@ -247,12 +296,7 @@ async fn worker_thread(params: WorkerParams) -> Result<()> { debug!("Got QUIC packet: {:?}", hdr); let client = match clients.get_or_create(&hdr, &src) { Ok((client, is_new_client)) => { if is_new_client { stats.lock().unwrap().connections += 1; } client } Ok(v) => v, Err(e) => { error!("Failed to get the client by the hdr {:?}: {}", hdr, e); continue; Loading @@ -263,7 +307,7 @@ async fn worker_thread(params: WorkerParams) -> Result<()> { match client.handle_frontend_message(pkt_buf) { Ok(v) if !v.is_empty() => { delay_queries_buffer.push(v); stats.lock().unwrap().queries_received += 1; queries_received += 1; } Err(e) => { error!("Failed to process QUIC packet: {}", e); Loading @@ -280,7 +324,7 @@ async fn worker_thread(params: WorkerParams) -> Result<()> { } let connection_id = client.connection_id().clone(); event_tx.send(Command::MaybeWrite{connection_id}).await?; event_tx.send(InternalCommand::MaybeWrite{connection_id})?; } Ok((len, src)) = backend_socket.recv_from(&mut backend_buf) => { Loading @@ -291,13 +335,13 @@ async fn worker_thread(params: WorkerParams) -> Result<()> { } let query_id = [backend_buf[0], backend_buf[1]]; for (_, client) in clients.get_mut_iter() { for (_, client) in clients.iter_mut() { if client.is_waiting_for_query(&query_id) { if let Err(e) = client.handle_backend_message(&backend_buf[..len]) { error!("Failed to handle message from backend: {}", e); } let connection_id = client.connection_id().clone(); event_tx.send(Command::MaybeWrite{connection_id}).await?; event_tx.send(InternalCommand::MaybeWrite{connection_id})?; // It's a bug if more than one client is waiting for this query. break; Loading @@ -307,7 +351,7 @@ async fn worker_thread(params: WorkerParams) -> Result<()> { Some(command) = event_rx.recv(), if !config.lock().unwrap().block_sending => { match command { Command::MaybeWrite {connection_id} => { InternalCommand::MaybeWrite {connection_id} => { if let Some(client) = clients.get_mut(&connection_id) { while let Ok(v) = client.flush_egress() { let addr = client.addr(); Loading @@ -321,6 +365,17 @@ async fn worker_thread(params: WorkerParams) -> Result<()> { } } } Some(command) = command_rx.recv() => { match command { ControlCommand::Stats {resp} => { let stats = Stats {queries_received, connections_accepted: clients.len() as u32}; if let Err(e) = resp.send(stats) { error!("Failed to send ControlCommand::Stats response: {:?}", e); } } ControlCommand::StatsClearQueries => queries_received = 0, } } } } } Loading Loading
tests/dns_responder/dns_responder_client_ndk.cpp +22 −19 Original line number Diff line number Diff line Loading @@ -21,8 +21,6 @@ #include <android/binder_manager.h> #include "NetdClient.h" // TODO: make this dynamic and stop depending on implementation details. #define TEST_OEM_NETWORK "oem29" #define TEST_NETID 30 // TODO: move this somewhere shared. Loading Loading @@ -172,43 +170,48 @@ void DnsResponderClient::SetupDNSServers(unsigned numServers, const std::vector< } } int DnsResponderClient::SetupOemNetwork() { mNetdSrv->networkDestroy(TEST_NETID); mDnsResolvSrv->destroyNetworkCache(TEST_NETID); int DnsResponderClient::SetupOemNetwork(int oemNetId) { mNetdSrv->networkDestroy(oemNetId); mDnsResolvSrv->destroyNetworkCache(oemNetId); ::ndk::ScopedAStatus ret; if (DnsResponderClient::isRemoteVersionSupported(mNetdSrv, 6)) { const auto& config = DnsResponderClient::makeNativeNetworkConfig( TEST_NETID, NativeNetworkType::PHYSICAL, INetd::PERMISSION_NONE, /*secure=*/false); oemNetId, NativeNetworkType::PHYSICAL, INetd::PERMISSION_NONE, /*secure=*/false); ret = mNetdSrv->networkCreate(config); } else { // Only for presubmit tests that run mainline module (and its tests) on R or earlier images. #pragma clang diagnostic push #pragma clang diagnostic ignored "-Wdeprecated-declarations" ret = mNetdSrv->networkCreatePhysical(TEST_NETID, INetd::PERMISSION_NONE); ret = mNetdSrv->networkCreatePhysical(oemNetId, INetd::PERMISSION_NONE); #pragma clang diagnostic pop } if (!ret.isOk()) { fprintf(stderr, "Creating physical network %d failed, %s\n", TEST_NETID, ret.getMessage()); fprintf(stderr, "Creating physical network %d failed, %s\n", oemNetId, ret.getMessage()); return -1; } ret = mDnsResolvSrv->createNetworkCache(TEST_NETID); ret = mDnsResolvSrv->createNetworkCache(oemNetId); if (!ret.isOk()) { fprintf(stderr, "Creating network cache %d failed, %s\n", TEST_NETID, ret.getMessage()); fprintf(stderr, "Creating network cache %d failed, %s\n", oemNetId, ret.getMessage()); return -1; } setNetworkForProcess(TEST_NETID); if ((unsigned)TEST_NETID != getNetworkForProcess()) { setNetworkForProcess(oemNetId); if ((unsigned)oemNetId != getNetworkForProcess()) { return -1; } return TEST_NETID; return 0; } void DnsResponderClient::TearDownOemNetwork(int oemNetId) { if (oemNetId != -1) { mNetdSrv->networkDestroy(oemNetId); mDnsResolvSrv->destroyNetworkCache(oemNetId); int DnsResponderClient::TearDownOemNetwork(int oemNetId) { if (auto status = mNetdSrv->networkDestroy(oemNetId); !status.isOk()) { fprintf(stderr, "Removing network %d failed, %s\n", oemNetId, status.getMessage()); return -1; } if (auto status = mDnsResolvSrv->destroyNetworkCache(oemNetId); !status.isOk()) { fprintf(stderr, "Removing network cache %d failed, %s\n", oemNetId, status.getMessage()); return -1; } return 0; } void DnsResponderClient::SetUp() { Loading @@ -228,11 +231,11 @@ void DnsResponderClient::SetUp() { // Ensure resolutions go via proxy. setenv(ANDROID_DNS_MODE, "", 1); mOemNetId = SetupOemNetwork(); SetupOemNetwork(TEST_NETID); } void DnsResponderClient::TearDown() { TearDownOemNetwork(mOemNetId); TearDownOemNetwork(TEST_NETID); } NativeNetworkConfig DnsResponderClient::makeNativeNetworkConfig(int netId, Loading
tests/dns_responder/dns_responder_client_ndk.h +3 −4 Original line number Diff line number Diff line Loading @@ -128,9 +128,9 @@ class DnsResponderClient { const std::vector<std::string>& domains, const std::string& tlsHostname, const std::vector<std::string>& tlsServers, const std::string& caCert = ""); int SetupOemNetwork(); void TearDownOemNetwork(int oemNetId); // Returns 0 on success and a negative value on failure. int SetupOemNetwork(int oemNetId); int TearDownOemNetwork(int oemNetId); virtual void SetUp(); virtual void TearDown(); Loading @@ -141,5 +141,4 @@ class DnsResponderClient { private: std::shared_ptr<aidl::android::net::INetd> mNetdSrv; std::shared_ptr<aidl::android::net::IDnsResolver> mDnsResolvSrv; int mOemNetId = -1; };
tests/doh/include/lib.rs.h +3 −3 Original line number Diff line number Diff line Loading @@ -29,8 +29,8 @@ struct DohFrontend; struct Stats { /// The number of accumulated DoH queries that are received. uint32_t queries_received; /// The number of accumulated QUIC connections. uint32_t connections; /// The number of accumulated QUIC connections accepted. uint32_t connections_accepted; }; extern "C" { Loading Loading @@ -99,7 +99,7 @@ bool frontend_set_max_streams_bidi(DohFrontend *doh, uint64_t value); bool frontend_block_sending(DohFrontend *doh, bool block); /// Gets the statistics of the `DohFrontend` and writes the result to |out|. void frontend_stats(const DohFrontend *doh, Stats *out); bool frontend_stats(DohFrontend *doh, Stats *out); /// Resets `queries_received` field of `Stats` owned by the `DohFrontend`. bool frontend_stats_clear_queries(const DohFrontend *doh); Loading
tests/doh/src/client.rs +11 −8 Original line number Diff line number Diff line Loading @@ -253,14 +253,13 @@ impl ClientMap { &mut self, hdr: &quiche::Header, addr: &SocketAddr, ) -> Result<(&mut Client, bool)> { let dcid = hdr.dcid.as_ref(); let is_new_client = !self.clients.contains_key(dcid); let client = if is_new_client { ) -> Result<&mut Client> { let dcid = hdr.dcid.as_ref().to_vec(); let client = if !self.clients.contains_key(&dcid) { ensure!(hdr.ty == quiche::Type::Initial, "Packet is not Initial"); ensure!(quiche::version_is_supported(hdr.version), "Protocol version not supported"); let scid = generate_conn_id(&self.conn_id_seed, dcid); let scid = generate_conn_id(&self.conn_id_seed, &dcid); let conn = quiche::accept( &quiche::ConnectionId::from_ref(&scid), None, /* odcid */ Loading @@ -273,19 +272,23 @@ impl ClientMap { self.clients.insert(scid.clone(), client); self.clients.get_mut(&scid).unwrap() } else { self.clients.get_mut(dcid).unwrap() self.clients.get_mut(&dcid).unwrap() }; Ok((client, is_new_client)) Ok(client) } pub fn get_mut(&mut self, id: &[u8]) -> Option<&mut Client> { self.clients.get_mut(&id.to_vec()) } pub fn get_mut_iter(&mut self) -> hash_map::IterMut<ConnectionID, Client> { pub fn iter_mut(&mut self) -> hash_map::IterMut<ConnectionID, Client> { self.clients.iter_mut() } pub fn len(&mut self) -> usize { self.clients.len() } } fn generate_conn_id(conn_id_seed: &hmac::Key, dcid: &[u8]) -> ConnectionID { Loading
tests/doh/src/dns_https_frontend.rs +85 −30 Original line number Diff line number Diff line Loading @@ -29,7 +29,7 @@ use std::sync::{Arc, Mutex}; use std::time::Duration; use tokio::net::UdpSocket; use tokio::runtime::{Builder, Runtime}; use tokio::sync::mpsc::channel; use tokio::sync::{mpsc, oneshot}; use tokio::task::JoinHandle; lazy_static! { Loading @@ -44,11 +44,19 @@ lazy_static! { ); } /// Command used by worker_thread itself. #[derive(Debug)] enum Command { enum InternalCommand { MaybeWrite { connection_id: ConnectionID }, } /// Commands that DohFrontend to ask its worker_thread for. #[derive(Debug)] enum ControlCommand { Stats { resp: oneshot::Sender<Stats> }, StatsClearQueries, } /// Frontend object. #[derive(Debug)] pub struct DohFrontend { Loading @@ -73,10 +81,11 @@ pub struct DohFrontend { // TODO: use channel to update worker_thread configuration. config: Arc<Mutex<Config>>, // Stores some statistic to check DohFrontend status. // It's shared with the worker thread. // TODO: use channel to retrieve the stats from worker_thread. stats: Arc<Mutex<Stats>>, // Caches the latest stats so that the stats remains after worker_thread stops. latest_stats: Stats, // It is wrapped as Option because the channel is not created in DohFrontend construction. command_tx: Option<mpsc::UnboundedSender<ControlCommand>>, } /// The parameters passed to the worker thread. Loading @@ -85,7 +94,7 @@ struct WorkerParams { backend_socket: std::net::UdpSocket, clients: ClientMap, config: Arc<Mutex<Config>>, stats: Arc<Mutex<Stats>>, command_rx: mpsc::UnboundedReceiver<ControlCommand>, } impl DohFrontend { Loading @@ -100,7 +109,8 @@ impl DohFrontend { private_key: String::new(), worker_thread: None, config: Arc::new(Mutex::new(Config::new())), stats: Arc::new(Mutex::new(Stats::new())), latest_stats: Stats::new(), command_tx: None, }); debug!("DohFrontend created: {:?}", doh); Ok(doh) Loading @@ -123,6 +133,9 @@ impl DohFrontend { pub fn stop(&mut self) -> Result<()> { if let Some(worker_thread) = self.worker_thread.take() { // Update latest_stats before stopping worker_thread. let _ = self.request_stats(); worker_thread.abort(); } Loading Loading @@ -165,16 +178,48 @@ impl DohFrontend { Ok(()) } pub fn stats(&self) -> Stats { self.stats.lock().unwrap().clone() pub fn request_stats(&mut self) -> Result<Stats> { ensure!( self.command_tx.is_some(), "command_tx is None because worker thread not yet initialized" ); let command_tx = self.command_tx.as_ref().unwrap(); if command_tx.is_closed() { return Ok(self.latest_stats.clone()); } let (resp_tx, resp_rx) = oneshot::channel(); command_tx.send(ControlCommand::Stats { resp: resp_tx })?; match RUNTIME_STATIC .block_on(async { tokio::time::timeout(Duration::from_secs(1), resp_rx).await }) { Ok(v) => match v { Ok(stats) => { self.latest_stats = stats.clone(); Ok(stats) } Err(e) => bail!(e), }, Err(e) => bail!(e), } } pub fn stats_clear_queries(&self) -> Result<()> { self.stats.lock().unwrap().queries_received = 0; Ok(()) ensure!( self.command_tx.is_some(), "command_tx is None because worker thread not yet initialized" ); return self .command_tx .as_ref() .unwrap() .send(ControlCommand::StatsClearQueries) .or_else(|e| bail!(e)); } fn init_worker_thread_params(&self) -> Result<WorkerParams> { fn init_worker_thread_params(&mut self) -> Result<WorkerParams> { let bind_addr = if self.backend_socket_addr.ip().is_ipv4() { "0.0.0.0:0" } else { "[::]:0" }; let backend_socket = std::net::UdpSocket::bind(bind_addr)?; Loading @@ -190,12 +235,15 @@ impl DohFrontend { self.config.clone(), )?)?; let (command_tx, command_rx) = mpsc::unbounded_channel::<ControlCommand>(); self.command_tx = Some(command_tx); Ok(WorkerParams { frontend_socket, backend_socket, clients, config: self.config.clone(), stats: self.stats.clone(), command_rx, }) } } Loading @@ -204,18 +252,19 @@ async fn worker_thread(params: WorkerParams) -> Result<()> { let backend_socket = into_tokio_udp_socket(params.backend_socket)?; let frontend_socket = into_tokio_udp_socket(params.frontend_socket)?; let config = params.config; let stats = params.stats; let (event_tx, mut event_rx) = channel::<Command>(100); let (event_tx, mut event_rx) = mpsc::unbounded_channel::<InternalCommand>(); let mut command_rx = params.command_rx; let mut clients = params.clients; let mut frontend_buf = [0; 65535]; let mut backend_buf = [0; 16384]; let mut delay_queries_buffer: Vec<Vec<u8>> = vec![]; let mut queries_received = 0; debug!("frontend={:?}, backend={:?}", frontend_socket, backend_socket); loop { let timeout = clients .get_mut_iter() .iter_mut() .filter_map(|(_, c)| c.timeout()) .min() .unwrap_or_else(|| Duration::from_millis(QUICHE_IDLE_TIMEOUT_MS)); Loading @@ -223,12 +272,12 @@ async fn worker_thread(params: WorkerParams) -> Result<()> { tokio::select! { _ = tokio::time::sleep(timeout) => { debug!("timeout"); for (_, client) in clients.get_mut_iter() { for (_, client) in clients.iter_mut() { // If no timeout has occurred it does nothing. client.on_timeout(); let connection_id = client.connection_id().clone(); event_tx.send(Command::MaybeWrite{connection_id}).await?; event_tx.send(InternalCommand::MaybeWrite{connection_id})?; } } Loading @@ -247,12 +296,7 @@ async fn worker_thread(params: WorkerParams) -> Result<()> { debug!("Got QUIC packet: {:?}", hdr); let client = match clients.get_or_create(&hdr, &src) { Ok((client, is_new_client)) => { if is_new_client { stats.lock().unwrap().connections += 1; } client } Ok(v) => v, Err(e) => { error!("Failed to get the client by the hdr {:?}: {}", hdr, e); continue; Loading @@ -263,7 +307,7 @@ async fn worker_thread(params: WorkerParams) -> Result<()> { match client.handle_frontend_message(pkt_buf) { Ok(v) if !v.is_empty() => { delay_queries_buffer.push(v); stats.lock().unwrap().queries_received += 1; queries_received += 1; } Err(e) => { error!("Failed to process QUIC packet: {}", e); Loading @@ -280,7 +324,7 @@ async fn worker_thread(params: WorkerParams) -> Result<()> { } let connection_id = client.connection_id().clone(); event_tx.send(Command::MaybeWrite{connection_id}).await?; event_tx.send(InternalCommand::MaybeWrite{connection_id})?; } Ok((len, src)) = backend_socket.recv_from(&mut backend_buf) => { Loading @@ -291,13 +335,13 @@ async fn worker_thread(params: WorkerParams) -> Result<()> { } let query_id = [backend_buf[0], backend_buf[1]]; for (_, client) in clients.get_mut_iter() { for (_, client) in clients.iter_mut() { if client.is_waiting_for_query(&query_id) { if let Err(e) = client.handle_backend_message(&backend_buf[..len]) { error!("Failed to handle message from backend: {}", e); } let connection_id = client.connection_id().clone(); event_tx.send(Command::MaybeWrite{connection_id}).await?; event_tx.send(InternalCommand::MaybeWrite{connection_id})?; // It's a bug if more than one client is waiting for this query. break; Loading @@ -307,7 +351,7 @@ async fn worker_thread(params: WorkerParams) -> Result<()> { Some(command) = event_rx.recv(), if !config.lock().unwrap().block_sending => { match command { Command::MaybeWrite {connection_id} => { InternalCommand::MaybeWrite {connection_id} => { if let Some(client) = clients.get_mut(&connection_id) { while let Ok(v) = client.flush_egress() { let addr = client.addr(); Loading @@ -321,6 +365,17 @@ async fn worker_thread(params: WorkerParams) -> Result<()> { } } } Some(command) = command_rx.recv() => { match command { ControlCommand::Stats {resp} => { let stats = Stats {queries_received, connections_accepted: clients.len() as u32}; if let Err(e) = resp.send(stats) { error!("Failed to send ControlCommand::Stats response: {:?}", e); } } ControlCommand::StatsClearQueries => queries_received = 0, } } } } } Loading