Loading doh/doh.rs +35 −61 Original line number Diff line number Diff line Loading @@ -166,7 +166,7 @@ pub struct DohDispatcher { /// Used to submit cmds to the I/O task. cmd_sender: CmdSender, join_handle: task::JoinHandle<Result<()>>, runtime: Arc<Runtime>, runtime: Runtime, } // DoH dispatcher Loading @@ -176,16 +176,13 @@ impl DohDispatcher { tag_socket_fn: TagSocketCallback, ) -> Result<DohDispatcher> { let (cmd_sender, cmd_receiver) = mpsc::channel::<DohCommand>(MAX_BUFFERED_CMD_SIZE); let runtime = Arc::new( Builder::new_multi_thread() let runtime = Builder::new_multi_thread() .worker_threads(2) .enable_all() .thread_name("doh-handler") .build() .expect("Failed to create tokio runtime"), ); let join_handle = runtime.spawn(doh_handler(cmd_receiver, runtime.clone(), validation_fn, tag_socket_fn)); .expect("Failed to create tokio runtime"); let join_handle = runtime.spawn(doh_handler(cmd_receiver, validation_fn, tag_socket_fn)); Ok(DohDispatcher { cmd_sender, join_handle, runtime }) } Loading Loading @@ -689,10 +686,9 @@ async fn flush_tx( Ok(()) } fn report_private_dns_validation( async fn report_private_dns_validation( info: &ServerInfo, state: &ConnectionState, runtime: Arc<Runtime>, validation_fn: ValidationCallback, ) { let (ip_addr, domain) = match ( Loading @@ -707,14 +703,16 @@ fn report_private_dns_validation( }; let netd_id = info.net_id; let success = matches!(state, ConnectionState::Connected { .. }); runtime .spawn_blocking(move || validation_fn(netd_id, success, ip_addr.as_ptr(), domain.as_ptr())); task::spawn_blocking(move || { validation_fn(netd_id, success, ip_addr.as_ptr(), domain.as_ptr()) }) .await .unwrap_or_else(|e| warn!("Validation function task failed: {}", e)); } fn handle_probe_result( async fn handle_probe_result( result: (ServerInfo, Result<DohConnection, (anyhow::Error, DohConnection)>), doh_conn_map: &mut HashMap<u32, (ServerInfo, Option<DohConnection>)>, runtime: Arc<Runtime>, validation_fn: ValidationCallback, ) { let (info, doh_conn) = match result { Loading Loading @@ -745,7 +743,7 @@ fn handle_probe_result( return; } } report_private_dns_validation(&info, &doh_conn.state, runtime, validation_fn); report_private_dns_validation(&info, &doh_conn.state, validation_fn).await; doh_conn_map.insert(info.net_id, (info, Some(doh_conn))); } Loading Loading @@ -849,7 +847,6 @@ fn need_process_queries(doh_conn_map: &HashMap<u32, (ServerInfo, Option<DohConne async fn doh_handler( mut cmd_rx: CmdReceiver, runtime: Arc<Runtime>, validation_fn: ValidationCallback, tag_socket_fn: TagSocketCallback, ) -> Result<()> { Loading @@ -871,8 +868,7 @@ async fn doh_handler( join_all(futures).await }, if need_process_queries(&doh_conn_map) => {}, Some(result) = probe_futures.next() => { let runtime_clone = runtime.clone(); handle_probe_result(result, &mut doh_conn_map, runtime_clone, validation_fn); handle_probe_result(result, &mut doh_conn_map, validation_fn).await; info!("probe_futures remaining size: {}", probe_futures.len()); }, Some(cmd) = cmd_rx.recv() => { Loading @@ -892,7 +888,7 @@ async fn doh_handler( } Err(e) => { error!("create connection for network {} error {:?}", info.net_id, e); report_private_dns_validation(&info, &ConnectionState::Error, runtime.clone(), validation_fn); report_private_dns_validation(&info, &ConnectionState::Error, validation_fn).await; } } }, Loading Loading @@ -1031,7 +1027,7 @@ mod tests { ServerInfo, HashMap<u32, (ServerInfo, Option<DohConnection>)>, Arc<Mutex<QuicheConfigCache>>, Arc<Runtime>, Runtime, ) { let test_map: HashMap<u32, (ServerInfo, Option<DohConnection>)> = HashMap::new(); let info = ServerInfo { Loading @@ -1045,13 +1041,11 @@ mod tests { let config_cache = Arc::new(Mutex::new(QuicheConfigCache { cert_path: None, config: None })); let rt = Arc::new( Builder::new_current_thread() let rt = Builder::new_current_thread() .thread_name("test-runtime") .enable_all() .build() .expect("Failed to create testing tokio runtime"), ); .expect("Failed to create testing tokio runtime"); (info, test_map, config_cache, rt) } Loading Loading @@ -1247,13 +1241,11 @@ mod tests { sk_mark: 0, cert_path: None, }; let rt = Arc::new( Builder::new_current_thread() let rt = Builder::new_current_thread() .thread_name("test-runtime") .enable_io() .build() .expect("Failed to create testing tokio runtime"), ); .expect("Failed to create testing tokio runtime"); let default_panic = std::panic::take_hook(); // Exit the test if the worker inside tokio runtime panicked. std::panic::set_hook(Box::new(move |info| { Loading @@ -1261,30 +1253,12 @@ mod tests { std::process::exit(1); })); rt.block_on(async { super::report_private_dns_validation( &info, &make_dummy_connected_state(), rt.clone(), success_cb, ); super::report_private_dns_validation( &info, &ConnectionState::Error, rt.clone(), fail_cb, ); super::report_private_dns_validation( &info, &make_dummy_connecting_state(), rt.clone(), fail_cb, ); super::report_private_dns_validation( &info, &ConnectionState::Idle, rt.clone(), fail_cb, ); super::report_private_dns_validation(&info, &make_dummy_connected_state(), success_cb) .await; super::report_private_dns_validation(&info, &ConnectionState::Error, fail_cb).await; super::report_private_dns_validation(&info, &make_dummy_connecting_state(), fail_cb) .await; super::report_private_dns_validation(&info, &ConnectionState::Idle, fail_cb).await; }); } Loading Loading
doh/doh.rs +35 −61 Original line number Diff line number Diff line Loading @@ -166,7 +166,7 @@ pub struct DohDispatcher { /// Used to submit cmds to the I/O task. cmd_sender: CmdSender, join_handle: task::JoinHandle<Result<()>>, runtime: Arc<Runtime>, runtime: Runtime, } // DoH dispatcher Loading @@ -176,16 +176,13 @@ impl DohDispatcher { tag_socket_fn: TagSocketCallback, ) -> Result<DohDispatcher> { let (cmd_sender, cmd_receiver) = mpsc::channel::<DohCommand>(MAX_BUFFERED_CMD_SIZE); let runtime = Arc::new( Builder::new_multi_thread() let runtime = Builder::new_multi_thread() .worker_threads(2) .enable_all() .thread_name("doh-handler") .build() .expect("Failed to create tokio runtime"), ); let join_handle = runtime.spawn(doh_handler(cmd_receiver, runtime.clone(), validation_fn, tag_socket_fn)); .expect("Failed to create tokio runtime"); let join_handle = runtime.spawn(doh_handler(cmd_receiver, validation_fn, tag_socket_fn)); Ok(DohDispatcher { cmd_sender, join_handle, runtime }) } Loading Loading @@ -689,10 +686,9 @@ async fn flush_tx( Ok(()) } fn report_private_dns_validation( async fn report_private_dns_validation( info: &ServerInfo, state: &ConnectionState, runtime: Arc<Runtime>, validation_fn: ValidationCallback, ) { let (ip_addr, domain) = match ( Loading @@ -707,14 +703,16 @@ fn report_private_dns_validation( }; let netd_id = info.net_id; let success = matches!(state, ConnectionState::Connected { .. }); runtime .spawn_blocking(move || validation_fn(netd_id, success, ip_addr.as_ptr(), domain.as_ptr())); task::spawn_blocking(move || { validation_fn(netd_id, success, ip_addr.as_ptr(), domain.as_ptr()) }) .await .unwrap_or_else(|e| warn!("Validation function task failed: {}", e)); } fn handle_probe_result( async fn handle_probe_result( result: (ServerInfo, Result<DohConnection, (anyhow::Error, DohConnection)>), doh_conn_map: &mut HashMap<u32, (ServerInfo, Option<DohConnection>)>, runtime: Arc<Runtime>, validation_fn: ValidationCallback, ) { let (info, doh_conn) = match result { Loading Loading @@ -745,7 +743,7 @@ fn handle_probe_result( return; } } report_private_dns_validation(&info, &doh_conn.state, runtime, validation_fn); report_private_dns_validation(&info, &doh_conn.state, validation_fn).await; doh_conn_map.insert(info.net_id, (info, Some(doh_conn))); } Loading Loading @@ -849,7 +847,6 @@ fn need_process_queries(doh_conn_map: &HashMap<u32, (ServerInfo, Option<DohConne async fn doh_handler( mut cmd_rx: CmdReceiver, runtime: Arc<Runtime>, validation_fn: ValidationCallback, tag_socket_fn: TagSocketCallback, ) -> Result<()> { Loading @@ -871,8 +868,7 @@ async fn doh_handler( join_all(futures).await }, if need_process_queries(&doh_conn_map) => {}, Some(result) = probe_futures.next() => { let runtime_clone = runtime.clone(); handle_probe_result(result, &mut doh_conn_map, runtime_clone, validation_fn); handle_probe_result(result, &mut doh_conn_map, validation_fn).await; info!("probe_futures remaining size: {}", probe_futures.len()); }, Some(cmd) = cmd_rx.recv() => { Loading @@ -892,7 +888,7 @@ async fn doh_handler( } Err(e) => { error!("create connection for network {} error {:?}", info.net_id, e); report_private_dns_validation(&info, &ConnectionState::Error, runtime.clone(), validation_fn); report_private_dns_validation(&info, &ConnectionState::Error, validation_fn).await; } } }, Loading Loading @@ -1031,7 +1027,7 @@ mod tests { ServerInfo, HashMap<u32, (ServerInfo, Option<DohConnection>)>, Arc<Mutex<QuicheConfigCache>>, Arc<Runtime>, Runtime, ) { let test_map: HashMap<u32, (ServerInfo, Option<DohConnection>)> = HashMap::new(); let info = ServerInfo { Loading @@ -1045,13 +1041,11 @@ mod tests { let config_cache = Arc::new(Mutex::new(QuicheConfigCache { cert_path: None, config: None })); let rt = Arc::new( Builder::new_current_thread() let rt = Builder::new_current_thread() .thread_name("test-runtime") .enable_all() .build() .expect("Failed to create testing tokio runtime"), ); .expect("Failed to create testing tokio runtime"); (info, test_map, config_cache, rt) } Loading Loading @@ -1247,13 +1241,11 @@ mod tests { sk_mark: 0, cert_path: None, }; let rt = Arc::new( Builder::new_current_thread() let rt = Builder::new_current_thread() .thread_name("test-runtime") .enable_io() .build() .expect("Failed to create testing tokio runtime"), ); .expect("Failed to create testing tokio runtime"); let default_panic = std::panic::take_hook(); // Exit the test if the worker inside tokio runtime panicked. std::panic::set_hook(Box::new(move |info| { Loading @@ -1261,30 +1253,12 @@ mod tests { std::process::exit(1); })); rt.block_on(async { super::report_private_dns_validation( &info, &make_dummy_connected_state(), rt.clone(), success_cb, ); super::report_private_dns_validation( &info, &ConnectionState::Error, rt.clone(), fail_cb, ); super::report_private_dns_validation( &info, &make_dummy_connecting_state(), rt.clone(), fail_cb, ); super::report_private_dns_validation( &info, &ConnectionState::Idle, rt.clone(), fail_cb, ); super::report_private_dns_validation(&info, &make_dummy_connected_state(), success_cb) .await; super::report_private_dns_validation(&info, &ConnectionState::Error, fail_cb).await; super::report_private_dns_validation(&info, &make_dummy_connecting_state(), fail_cb) .await; super::report_private_dns_validation(&info, &ConnectionState::Idle, fail_cb).await; }); } Loading