Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 22d025b6 authored by Abhishek Pandit-Subedi's avatar Abhishek Pandit-Subedi
Browse files

libs/usb4: Refactor UserId and tests

Add a newtype, UserId, in apis for better type safety. Also re-write
tests to reduce dependency on specific sleep durations (opting for
polling + a long timeout).

Bug: 433329064
Test: atest frameworks/base/libs/usb4/tests/pci_authorizer_test.rs
Flag: com.android.server.usb.flags.enable_usb4
Change-Id: Iabad66adf7c6e18d010fcf20c971a73ed5c2d8e9
parent a8f74139
Loading
Loading
Loading
Loading
+5 −2
Original line number Diff line number Diff line
@@ -18,7 +18,10 @@ use jni::sys::{jboolean, jint};
use jni::JNIEnv;
use log::trace;
use std::sync::{Arc, LazyLock, Mutex};
use usb4_policies::{common::TunnelControl, policy_engine::PolicyEngine};
use usb4_policies::{
    common::{TunnelControl, UserId},
    policy_engine::PolicyEngine,
};

// Singleton of PolicyEngine to use for JNI. Will get created on first use.
static POLICY_ENGINE: LazyLock<Arc<Mutex<PolicyEngine>>> =
@@ -75,5 +78,5 @@ pub extern "system" fn Java_com_android_server_usb_Usb4Manager_updateLoggedInSta
) {
    trace!("updateLoggedInstate with {} = {}", user_id as usize, logged_in != 0);
    let mut engine = POLICY_ENGINE.lock().unwrap();
    engine.update_logged_in_state(logged_in != 0, user_id as usize);
    engine.update_logged_in_state(logged_in != 0, UserId(user_id as usize));
}
+6 −2
Original line number Diff line number Diff line
@@ -18,6 +18,10 @@

use std::collections::HashSet;

/// Newtype to hold user ids.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub struct UserId(pub usize);

/// Holds the live state variables that determine the authorization policy.
pub struct PolicySourceData {
    /// A flag indicating if the PCI tunneling feature is globally enabled.
@@ -25,7 +29,7 @@ pub struct PolicySourceData {
    /// A flag indicating if the user's screen is currently locked.
    pub is_locked: bool,
    /// A set tracking the IDs of all currently logged-in users.
    pub logged_in_users: HashSet<usize>,
    pub logged_in_users: HashSet<UserId>,
}

impl PolicySourceData {
@@ -56,5 +60,5 @@ pub trait TunnelControl {
    fn update_lock_state(&mut self, locked: bool);

    /// Notifies the engine of a user login or logout event.
    fn update_logged_in_state(&mut self, logged_in: bool, user_id: usize);
    fn update_logged_in_state(&mut self, logged_in: bool, user_id: UserId);
}
+3 −3
Original line number Diff line number Diff line
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use crate::common::{PolicySourceData, TunnelControl};
use crate::common::{PolicySourceData, TunnelControl, UserId};
use crate::sysfs::SysfsUtils;
use anyhow::Result;
use kobject_uevent::ActionType;
@@ -43,7 +43,7 @@ pub enum PciAuthState {
enum PciServiceEvent {
    EnablePciTunnels(bool),
    UpdateLockState(bool),
    UpdateLoggedInState { logged_in: bool, user_id: usize },
    UpdateLoggedInState { logged_in: bool, user_id: UserId },
    Shutdown,
}

@@ -229,7 +229,7 @@ impl TunnelControl for PciAuthorizer {
        self.send_event(PciServiceEvent::UpdateLockState(locked));
    }

    fn update_logged_in_state(&mut self, logged_in: bool, user_id: usize) {
    fn update_logged_in_state(&mut self, logged_in: bool, user_id: UserId) {
        self.send_event(PciServiceEvent::UpdateLoggedInState { logged_in, user_id });
    }
}
+2 −2
Original line number Diff line number Diff line
@@ -19,7 +19,7 @@
//! The `PolicyEngine` struct is the primary entry point for consumers of this
//! crate. It encapsulates the `PciAuthorizer`.

use crate::common::TunnelControl;
use crate::common::{TunnelControl, UserId};
use crate::pci_authorizer::PciAuthorizer;
use tokio::runtime::Runtime;

@@ -64,7 +64,7 @@ impl TunnelControl for PolicyEngine {
    }

    /// Notifies the engine of a user login or logout event.
    fn update_logged_in_state(&mut self, logged_in: bool, user_id: usize) {
    fn update_logged_in_state(&mut self, logged_in: bool, user_id: UserId) {
        self.pci_authorizer.update_logged_in_state(logged_in, user_id);
    }
}
+73 −55
Original line number Diff line number Diff line
@@ -18,15 +18,19 @@ mod pci_authorizer_tests {
    use std::os::unix::fs::symlink;
    use std::path::{Path, PathBuf};
    use std::sync::Arc;
    use std::time::Instant;
    use tempfile::TempDir;
    use tokio::time::{sleep, Duration};
    use uevent::netlink::AsyncUEventSocket;
    use usb4_policies::common::TunnelControl;
    use usb4_policies::common::{TunnelControl, UserId};
    use usb4_policies::pci_authorizer::PciAuthorizer;
    use usb4_policies::sysfs::SysfsUtils;

    const POLL_DURATION: Duration = Duration::from_millis(30); // Increased slightly for CI
    const SHUTDOWN_WAIT_DURATION: Duration = Duration::from_millis(150); // Wait for task shutdown
    // Time between file reads.
    const POLL_DURATION: Duration = Duration::from_millis(30);

    // Wait for this duration for paths to be updated to desired value.
    const WAIT_FOR_PATH_DURATION: Duration = Duration::from_millis(500);

    fn setup_environment_for_pci_authorizer_new(
    ) -> (TempDir, SysfsUtils, Arc<dyn AsyncUEventSocket>) {
@@ -80,6 +84,23 @@ mod pci_authorizer_tests {
        dev_path
    }

    async fn assert_wait_for_path_eq(path: PathBuf, expected_value: &str, assert_why: &str) {
        let start = Instant::now();
        let mut read_value: String = Default::default();

        // Wait for value to become expected value.
        while Instant::now().duration_since(start) < WAIT_FOR_PATH_DURATION {
            read_value = fs::read_to_string(&path).unwrap();
            if read_value.trim() == expected_value {
                break;
            }

            sleep(POLL_DURATION).await;
        }

        assert_eq!(read_value.trim(), expected_value, "{}", assert_why);
    }

    #[tokio::test]
    async fn test_full_authorization_flow() {
        let _ = env_logger::try_init();
@@ -89,42 +110,41 @@ mod pci_authorizer_tests {

        let tbt_dev_path = create_mock_tbt_device(root, "0-0", "0");

        sleep(POLL_DURATION).await; // Allow task to initialize
        assert_eq!(
            fs::read_to_string(tbt_dev_path.join("authorized")).unwrap().trim(),
        assert_wait_for_path_eq(
            tbt_dev_path.join("authorized"),
            "0",
            "TBT device should initially be deauthorized"
        );
            "TBT device should initially be deauthorized",
        )
        .await;

        // 1. Enable PCI Tunnels (State -> DenyNoUser)
        pci_authorizer.enable_pci_tunnels(true);
        sleep(POLL_DURATION * 2).await; // Allow event processing
        assert_eq!(
            fs::read_to_string(tbt_dev_path.join("authorized")).unwrap().trim(),
        assert_wait_for_path_eq(
            tbt_dev_path.join("authorized"),
            "0",
            "TBT device should remain deauthorized on DenyNoUser"
        );
            "TBT device should remain deauthorized on DenyNoUser",
        )
        .await;

        // 2. User logs in (State -> DeferNewDevices)
        pci_authorizer.update_logged_in_state(true, 1);
        sleep(POLL_DURATION * 2).await;
        assert_eq!(
            fs::read_to_string(tbt_dev_path.join("authorized")).unwrap().trim(),
        pci_authorizer.update_logged_in_state(true, UserId(1));
        assert_wait_for_path_eq(
            tbt_dev_path.join("authorized"),
            "0",
            "TBT device should remain deauthorized on DeferNewDevices"
        );
            "TBT device should remain deauthorized on DeferNewDevices",
        )
        .await;

        // 3. Screen unlocks (State -> Authorized)
        pci_authorizer.update_lock_state(false);
        sleep(POLL_DURATION * 2).await;
        assert_eq!(
            fs::read_to_string(tbt_dev_path.join("authorized")).unwrap().trim(),
        assert_wait_for_path_eq(
            tbt_dev_path.join("authorized"),
            "1",
            "TBT device should be authorized on Authorized state"
        );
            "TBT device should be authorized on Authorized state",
        )
        .await;

        drop(pci_authorizer);
        sleep(SHUTDOWN_WAIT_DURATION).await;
    }

    #[tokio::test]
@@ -139,14 +159,14 @@ mod pci_authorizer_tests {

        // Setup: Go to Authorized state first
        pci_authorizer.enable_pci_tunnels(true);
        pci_authorizer.update_logged_in_state(true, 1);
        pci_authorizer.update_logged_in_state(true, UserId(1));
        pci_authorizer.update_lock_state(false);
        sleep(POLL_DURATION * 3).await; // Allow transition to Authorized
        assert_eq!(
            fs::read_to_string(tbt_dev_path.join("authorized")).unwrap().trim(),
        assert_wait_for_path_eq(
            tbt_dev_path.join("authorized"),
            "1",
            "TBT device should be authorized"
        );
            "TBT device should be authorized",
        )
        .await;
        assert_eq!(
            fs::read_to_string(removable_pci_dev_path.join("remove")).unwrap().trim(),
            "1",
@@ -155,21 +175,21 @@ mod pci_authorizer_tests {

        // 1. Screen locks (State -> DeferNewDevices)
        pci_authorizer.update_lock_state(true);
        sleep(POLL_DURATION * 2).await;
        assert_eq!(
            fs::read_to_string(tbt_dev_path.join("authorized")).unwrap().trim(),
        assert_wait_for_path_eq(
            tbt_dev_path.join("authorized"),
            "1",
            "TBT device should remain authorized on DeferNewDevices"
        );
            "TBT device should remain authorized on DeferNewDevices",
        )
        .await;

        // 2. User logs out (State -> DenyNoUser)
        pci_authorizer.update_logged_in_state(false, 1); // Last user logs out
        sleep(POLL_DURATION * 2).await;
        assert_eq!(
            fs::read_to_string(tbt_dev_path.join("authorized")).unwrap().trim(),
        pci_authorizer.update_logged_in_state(false, UserId(1)); // Last user logs out
        assert_wait_for_path_eq(
            tbt_dev_path.join("authorized"),
            "0",
            "TBT device should be deauthorized on DenyNoUser"
        );
            "TBT device should be deauthorized on DenyNoUser",
        )
        .await;
        assert_eq!(
            fs::read_to_string(removable_pci_dev_path.join("remove")).unwrap().trim(),
            "1",
@@ -178,23 +198,23 @@ mod pci_authorizer_tests {

        // Re-setup to Authorized state for the next step
        fs::write(removable_pci_dev_path.join("remove"), "0").unwrap(); // Reset remove state
        pci_authorizer.update_logged_in_state(true, 1); // Log back in
        pci_authorizer.update_logged_in_state(true, UserId(1)); // Log back in
        pci_authorizer.update_lock_state(false); // Unlock screen (State -> Authorized)
        sleep(POLL_DURATION * 3).await;
        assert_eq!(
            fs::read_to_string(tbt_dev_path.join("authorized")).unwrap().trim(),
        assert_wait_for_path_eq(
            tbt_dev_path.join("authorized"),
            "1",
            "TBT device should be re-authorized"
        );
            "TBT device should be re-authorized",
        )
        .await;

        // 3. Disable PCI Tunnels (State -> Disabled)
        pci_authorizer.enable_pci_tunnels(false);
        sleep(POLL_DURATION * 2).await;
        assert_eq!(
            fs::read_to_string(tbt_dev_path.join("authorized")).unwrap().trim(),
        assert_wait_for_path_eq(
            tbt_dev_path.join("authorized"),
            "0",
            "TBT device should be deauthorized when tunnels are disabled"
        );
            "TBT device should be deauthorized when tunnels are disabled",
        )
        .await;
        assert_eq!(
            fs::read_to_string(removable_pci_dev_path.join("remove")).unwrap().trim(),
            "1",
@@ -202,7 +222,6 @@ mod pci_authorizer_tests {
        );

        drop(pci_authorizer);
        sleep(SHUTDOWN_WAIT_DURATION).await;
    }

    #[tokio::test]
@@ -217,6 +236,5 @@ mod pci_authorizer_tests {
        // The test passes if drop completes without panic.
        // A panic in the task during shutdown would be propagated by the await in Drop.
        // Allow a bit of time for async runtime to fully process the drop and task completion.
        sleep(SHUTDOWN_WAIT_DURATION).await;
    }
}