Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit fd22b3e5 authored by Alex Vakulenko's avatar Alex Vakulenko
Browse files

libpdx_uds: Serialize access to connection socket between threads

Added a mutex to allow only one client thread to perform atomic
send-request/receive-responce actions.

Also added a unit test that perfroms multiple parallel client requests
to the same service to ensure it can handle multithreaded access
correctly.

Bug: 37443070
Test: `libpdx_uds_tests` pass
Change-Id: Ica516f7806f9146fb530b5cb371d2ee89146fed7
parent 6eefa42d
Loading
Loading
Loading
Loading
+1 −0
Original line number Diff line number Diff line
@@ -35,6 +35,7 @@ cc_test {
        "-Werror",
    ],
    srcs: [
        "client_channel_tests.cpp",
        "ipc_helper_tests.cpp",
        "remote_method_tests.cpp",
        "service_framework_tests.cpp",
+2 −0
Original line number Diff line number Diff line
@@ -156,6 +156,7 @@ void ClientChannel::FreeTransactionState(void* state) {

Status<void> ClientChannel::SendImpulse(int opcode, const void* buffer,
                                        size_t length) {
  std::unique_lock<std::mutex> lock(socket_mutex_);
  Status<void> status;
  android::pdx::uds::RequestHeader<BorrowedHandle> request;
  if (length > request.impulse_payload.size() ||
@@ -174,6 +175,7 @@ Status<int> ClientChannel::SendAndReceive(void* transaction_state, int opcode,
                                          size_t send_count,
                                          const iovec* receive_vector,
                                          size_t receive_count) {
  std::unique_lock<std::mutex> lock(socket_mutex_);
  Status<int> result;
  if ((send_vector == nullptr && send_count != 0) ||
      (receive_vector == nullptr && receive_count != 0)) {
+162 −0
Original line number Diff line number Diff line
#include <uds/client_channel.h>

#include <sys/socket.h>

#include <algorithm>
#include <limits>
#include <random>
#include <thread>

#include <gmock/gmock.h>
#include <gtest/gtest.h>

#include <pdx/client.h>
#include <pdx/rpc/remote_method.h>
#include <pdx/service.h>

#include <uds/client_channel_factory.h>
#include <uds/service_endpoint.h>

using testing::Return;
using testing::_;

using android::pdx::ClientBase;
using android::pdx::LocalChannelHandle;
using android::pdx::LocalHandle;
using android::pdx::Message;
using android::pdx::ServiceBase;
using android::pdx::ServiceDispatcher;
using android::pdx::Status;
using android::pdx::rpc::DispatchRemoteMethod;
using android::pdx::uds::ClientChannel;
using android::pdx::uds::ClientChannelFactory;
using android::pdx::uds::Endpoint;

namespace {

struct TestProtocol {
  using DataType = int8_t;
  enum {
    kOpSum = 0,
  };
  PDX_REMOTE_METHOD(Sum, kOpSum, int64_t(const std::vector<DataType>&));
};

class TestService : public ServiceBase<TestService> {
 public:
  TestService(std::unique_ptr<Endpoint> endpoint)
      : ServiceBase{"TestService", std::move(endpoint)} {}

  Status<void> HandleMessage(Message& message) override {
    switch (message.GetOp()) {
      case TestProtocol::kOpSum:
        DispatchRemoteMethod<TestProtocol::Sum>(*this, &TestService::OnSum,
                                                message);
        return {};

      default:
        return Service::HandleMessage(message);
    }
  }

  int64_t OnSum(Message& /*message*/,
                const std::vector<TestProtocol::DataType>& data) {
    return std::accumulate(data.begin(), data.end(), int64_t{0});
  }
};

class TestClient : public ClientBase<TestClient> {
 public:
  using ClientBase::ClientBase;

  int64_t Sum(const std::vector<TestProtocol::DataType>& data) {
    auto status = InvokeRemoteMethod<TestProtocol::Sum>(data);
    return status ? status.get() : -1;
  }
};

class TestServiceRunner {
 public:
  TestServiceRunner(LocalHandle channel_socket) {
    auto endpoint = Endpoint::CreateFromSocketFd(LocalHandle{});
    endpoint->RegisterNewChannelForTests(std::move(channel_socket));
    service_ = TestService::Create(std::move(endpoint));
    dispatcher_ = android::pdx::uds::ServiceDispatcher::Create();
    dispatcher_->AddService(service_);
    dispatch_thread_ = std::thread(
        std::bind(&ServiceDispatcher::EnterDispatchLoop, dispatcher_.get()));
  }

  ~TestServiceRunner() {
    dispatcher_->SetCanceled(true);
    dispatch_thread_.join();
    dispatcher_->RemoveService(service_);
  }

 private:
  std::shared_ptr<TestService> service_;
  std::unique_ptr<ServiceDispatcher> dispatcher_;
  std::thread dispatch_thread_;
};

class ClientChannelTest : public testing::Test {
 public:
  void SetUp() override {
    int channel_sockets[2] = {};
    ASSERT_EQ(
        0, socketpair(AF_UNIX, SOCK_STREAM | SOCK_CLOEXEC, 0, channel_sockets));
    LocalHandle service_channel{channel_sockets[0]};
    LocalHandle client_channel{channel_sockets[1]};

    service_runner_.reset(new TestServiceRunner{std::move(service_channel)});
    auto factory = ClientChannelFactory::Create(std::move(client_channel));
    auto status = factory->Connect(android::pdx::Client::kInfiniteTimeout);
    ASSERT_TRUE(status);
    client_ = TestClient::Create(status.take());
  }

  void TearDown() override {
    service_runner_.reset();
    client_.reset();
  }

 protected:
  std::unique_ptr<TestServiceRunner> service_runner_;
  std::shared_ptr<TestClient> client_;
};

TEST_F(ClientChannelTest, MultithreadedClient) {
  constexpr int kNumTestThreads = 8;
  constexpr size_t kDataSize = 1000;  // Try to keep RPC buffer size below 4K.

  std::random_device rd;
  std::mt19937 gen{rd()};
  std::uniform_int_distribution<TestProtocol::DataType> dist{
      std::numeric_limits<TestProtocol::DataType>::min(),
      std::numeric_limits<TestProtocol::DataType>::max()};

  auto worker = [](std::shared_ptr<TestClient> client,
                   std::vector<TestProtocol::DataType> data) {
    constexpr int kMaxIterations = 500;
    int64_t expected = std::accumulate(data.begin(), data.end(), int64_t{0});
    for (int i = 0; i < kMaxIterations; i++) {
      ASSERT_EQ(expected, client->Sum(data));
    }
  };

  // Start client threads.
  std::vector<TestProtocol::DataType> data;
  data.resize(kDataSize);
  std::vector<std::thread> threads;
  for (int i = 0; i < kNumTestThreads; i++) {
    std::generate(data.begin(), data.end(),
                  [&dist, &gen]() { return dist(gen); });
    threads.emplace_back(worker, client_, data);
  }

  // Wait for threads to finish.
  for (auto& thread : threads)
    thread.join();
}

}  // namespace
+3 −0
Original line number Diff line number Diff line
@@ -3,6 +3,8 @@

#include <pdx/client_channel.h>

#include <mutex>

#include <uds/channel_event_set.h>
#include <uds/channel_manager.h>
#include <uds/service_endpoint.h>
@@ -73,6 +75,7 @@ class ClientChannel : public pdx::ClientChannel {

  LocalChannelHandle channel_handle_;
  ChannelManager::ChannelData* channel_data_;
  std::mutex socket_mutex_;
};

}  // namespace uds