Loading doh.rs +46 −46 Original line number Diff line number Diff line Loading @@ -111,7 +111,7 @@ enum Response { #[derive(Debug)] enum DohCommand { Probe { info: ServerInfo, timeout: Duration }, Query { net_id: u32, base64_query: Base64Query, timeout: Duration, resp: QueryResponder }, Query { net_id: u32, base64_query: Base64Query, expired_time: Instant, resp: QueryResponder }, Clear { net_id: u32 }, Exit, } Loading Loading @@ -179,7 +179,7 @@ struct DohConnection { h3_conn: Option<h3::Connection>, status: ConnectionStatus, query_map: HashMap<u64, QueryResponder>, pending_queries: Vec<(DnsRequest, QueryResponder, Option<Instant>)>, pending_queries: Vec<(DnsRequest, QueryResponder, Instant)>, cached_session: Option<Vec<u8>>, } Loading Loading @@ -250,30 +250,34 @@ impl DohConnection { async fn try_send_doh_query( &mut self, req: DnsRequest, timeout: Duration, resp: QueryResponder, ) { expired_time: Instant, ) -> Result<()> { match self.status { ConnectionStatus::Ready => { // Send an query to probe the server. match self.send_dns_query(&req).await { ConnectionStatus::Ready => match self.send_dns_query(&req).await { Ok(req_id) => { self.query_map.insert(req_id, resp); } Err(e) => { error!("send querry error {:?}", e); if let Ok(quiche::h3::Error::StreamBlocked) = e.downcast::<quiche::h3::Error>() { warn!("try to send query but error on StreamBlocked"); self.pending_queries.push((req, resp, expired_time)); bail!(quiche::h3::Error::StreamBlocked); } else { resp.send(Response::Error { error: QueryError::ConnectionError }).ok(); } } } }, ConnectionStatus::Pending => { self.pending_queries.push((req, resp, Instant::now().checked_add(timeout))); self.pending_queries.push((req, resp, expired_time)); } // Should not happen _ => { error!("Try to send query but status error {}", self.net_id); } } Ok(()) } fn resume_connection(&mut self, quic_conn: Pin<Box<quiche::Connection>>) { Loading @@ -295,22 +299,12 @@ impl DohConnection { loop { while !self.pending_queries.is_empty() { if let Some((req, resp, exp_time)) = self.pending_queries.pop() { // TODO: check if req is expired. match self.send_dns_query(&req).await { Ok(req_id) => { self.query_map.insert(req_id, resp); // Ignore the expired queries. if Instant::now().checked_duration_since(exp_time).is_some() { continue; } Err(e) => { if let Ok(quiche::h3::Error::StreamBlocked) = e.downcast::<quiche::h3::Error>() { self.pending_queries.push((req, resp, exp_time)); if self.try_send_doh_query(req, resp, exp_time).await.is_err() { break; } else { resp.send(Response::Error { error: QueryError::ConnectionError }) .ok(); } } } } } Loading Loading @@ -588,7 +582,7 @@ fn resume_connection( async fn handle_query_cmd( net_id: u32, base64_query: Base64Query, timeout: Duration, expired_time: Instant, resp: QueryResponder, doh_conn_map: &mut HashMap<u32, (ServerInfo, Option<DohConnection>)>, config_cache: &mut QuicheConfigCache, Loading Loading @@ -617,7 +611,7 @@ async fn handle_query_cmd( } if let Ok(req) = make_dns_request(&base64_query, &info.url) { debug!("Try to send query"); quic_conn.try_send_doh_query(req, timeout, resp).await; let _ = quic_conn.try_send_doh_query(req, resp, expired_time).await; } else { let _ = resp.send(Response::Error { error: QueryError::Unexpected }); } Loading Loading @@ -679,8 +673,8 @@ async fn doh_handler( } } }, DohCommand::Query { net_id, base64_query, timeout, resp } => { handle_query_cmd(net_id, base64_query, timeout, resp, &mut doh_conn_map, &mut config_cache).await; DohCommand::Query { net_id, base64_query, expired_time, resp } => { handle_query_cmd(net_id, base64_query, expired_time, resp, &mut doh_conn_map, &mut config_cache).await; }, DohCommand::Clear { net_id } => { doh_conn_map.remove(&net_id); Loading Loading @@ -932,12 +926,13 @@ pub unsafe extern "C" fn doh_query( ) -> ssize_t { let q = slice::from_raw_parts_mut(dns_query, dns_query_len); let t = Duration::from_millis(timeout_ms); let (resp_tx, resp_rx) = oneshot::channel(); let t = Duration::from_millis(timeout_ms); if let Some(expired_time) = Instant::now().checked_add(t) { let cmd = DohCommand::Query { net_id, base64_query: base64::encode_config(q, base64::URL_SAFE_NO_PAD), timeout: t, expired_time, resp: resp_tx, }; Loading @@ -945,6 +940,11 @@ pub unsafe extern "C" fn doh_query( error!("Failed to send the query: {:?}", e); return RESULT_CAN_NOT_SEND; } } else { error!("Bad timeout parameter: {}", timeout_ms); return RESULT_CAN_NOT_SEND; } if let Ok(rt) = Runtime::new() { let local = task::LocalSet::new(); match local.block_on(&rt, async { timeout(t, resp_rx).await }) { Loading Loading @@ -1068,7 +1068,7 @@ mod tests { super::handle_query_cmd( info.net_id, query.clone(), t, Instant::now().checked_add(t).unwrap(), resp_tx, &mut test_map, &mut config, Loading @@ -1084,7 +1084,7 @@ mod tests { super::handle_query_cmd( info.net_id, query.clone(), t, Instant::now().checked_add(t).unwrap(), resp_tx, &mut test_map, &mut config, Loading @@ -1106,7 +1106,7 @@ mod tests { super::handle_query_cmd( info.net_id, query.clone(), t, Instant::now().checked_add(t).unwrap(), resp_tx, &mut test_map, &mut config, Loading Loading
doh.rs +46 −46 Original line number Diff line number Diff line Loading @@ -111,7 +111,7 @@ enum Response { #[derive(Debug)] enum DohCommand { Probe { info: ServerInfo, timeout: Duration }, Query { net_id: u32, base64_query: Base64Query, timeout: Duration, resp: QueryResponder }, Query { net_id: u32, base64_query: Base64Query, expired_time: Instant, resp: QueryResponder }, Clear { net_id: u32 }, Exit, } Loading Loading @@ -179,7 +179,7 @@ struct DohConnection { h3_conn: Option<h3::Connection>, status: ConnectionStatus, query_map: HashMap<u64, QueryResponder>, pending_queries: Vec<(DnsRequest, QueryResponder, Option<Instant>)>, pending_queries: Vec<(DnsRequest, QueryResponder, Instant)>, cached_session: Option<Vec<u8>>, } Loading Loading @@ -250,30 +250,34 @@ impl DohConnection { async fn try_send_doh_query( &mut self, req: DnsRequest, timeout: Duration, resp: QueryResponder, ) { expired_time: Instant, ) -> Result<()> { match self.status { ConnectionStatus::Ready => { // Send an query to probe the server. match self.send_dns_query(&req).await { ConnectionStatus::Ready => match self.send_dns_query(&req).await { Ok(req_id) => { self.query_map.insert(req_id, resp); } Err(e) => { error!("send querry error {:?}", e); if let Ok(quiche::h3::Error::StreamBlocked) = e.downcast::<quiche::h3::Error>() { warn!("try to send query but error on StreamBlocked"); self.pending_queries.push((req, resp, expired_time)); bail!(quiche::h3::Error::StreamBlocked); } else { resp.send(Response::Error { error: QueryError::ConnectionError }).ok(); } } } }, ConnectionStatus::Pending => { self.pending_queries.push((req, resp, Instant::now().checked_add(timeout))); self.pending_queries.push((req, resp, expired_time)); } // Should not happen _ => { error!("Try to send query but status error {}", self.net_id); } } Ok(()) } fn resume_connection(&mut self, quic_conn: Pin<Box<quiche::Connection>>) { Loading @@ -295,22 +299,12 @@ impl DohConnection { loop { while !self.pending_queries.is_empty() { if let Some((req, resp, exp_time)) = self.pending_queries.pop() { // TODO: check if req is expired. match self.send_dns_query(&req).await { Ok(req_id) => { self.query_map.insert(req_id, resp); // Ignore the expired queries. if Instant::now().checked_duration_since(exp_time).is_some() { continue; } Err(e) => { if let Ok(quiche::h3::Error::StreamBlocked) = e.downcast::<quiche::h3::Error>() { self.pending_queries.push((req, resp, exp_time)); if self.try_send_doh_query(req, resp, exp_time).await.is_err() { break; } else { resp.send(Response::Error { error: QueryError::ConnectionError }) .ok(); } } } } } Loading Loading @@ -588,7 +582,7 @@ fn resume_connection( async fn handle_query_cmd( net_id: u32, base64_query: Base64Query, timeout: Duration, expired_time: Instant, resp: QueryResponder, doh_conn_map: &mut HashMap<u32, (ServerInfo, Option<DohConnection>)>, config_cache: &mut QuicheConfigCache, Loading Loading @@ -617,7 +611,7 @@ async fn handle_query_cmd( } if let Ok(req) = make_dns_request(&base64_query, &info.url) { debug!("Try to send query"); quic_conn.try_send_doh_query(req, timeout, resp).await; let _ = quic_conn.try_send_doh_query(req, resp, expired_time).await; } else { let _ = resp.send(Response::Error { error: QueryError::Unexpected }); } Loading Loading @@ -679,8 +673,8 @@ async fn doh_handler( } } }, DohCommand::Query { net_id, base64_query, timeout, resp } => { handle_query_cmd(net_id, base64_query, timeout, resp, &mut doh_conn_map, &mut config_cache).await; DohCommand::Query { net_id, base64_query, expired_time, resp } => { handle_query_cmd(net_id, base64_query, expired_time, resp, &mut doh_conn_map, &mut config_cache).await; }, DohCommand::Clear { net_id } => { doh_conn_map.remove(&net_id); Loading Loading @@ -932,12 +926,13 @@ pub unsafe extern "C" fn doh_query( ) -> ssize_t { let q = slice::from_raw_parts_mut(dns_query, dns_query_len); let t = Duration::from_millis(timeout_ms); let (resp_tx, resp_rx) = oneshot::channel(); let t = Duration::from_millis(timeout_ms); if let Some(expired_time) = Instant::now().checked_add(t) { let cmd = DohCommand::Query { net_id, base64_query: base64::encode_config(q, base64::URL_SAFE_NO_PAD), timeout: t, expired_time, resp: resp_tx, }; Loading @@ -945,6 +940,11 @@ pub unsafe extern "C" fn doh_query( error!("Failed to send the query: {:?}", e); return RESULT_CAN_NOT_SEND; } } else { error!("Bad timeout parameter: {}", timeout_ms); return RESULT_CAN_NOT_SEND; } if let Ok(rt) = Runtime::new() { let local = task::LocalSet::new(); match local.block_on(&rt, async { timeout(t, resp_rx).await }) { Loading Loading @@ -1068,7 +1068,7 @@ mod tests { super::handle_query_cmd( info.net_id, query.clone(), t, Instant::now().checked_add(t).unwrap(), resp_tx, &mut test_map, &mut config, Loading @@ -1084,7 +1084,7 @@ mod tests { super::handle_query_cmd( info.net_id, query.clone(), t, Instant::now().checked_add(t).unwrap(), resp_tx, &mut test_map, &mut config, Loading @@ -1106,7 +1106,7 @@ mod tests { super::handle_query_cmd( info.net_id, query.clone(), t, Instant::now().checked_add(t).unwrap(), resp_tx, &mut test_map, &mut config, Loading