Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 0ec290cd authored by Matthew Sedam's avatar Matthew Sedam
Browse files

Context Hub default HAL: Use std::thread for endpoint callbacks

This will ensure the endpoint callbacks are called after
returning from the original initiating function in another thread.

This CL also updates the VtsAidlHalContextHubTargetTest to handle
the async callbacks.

Bug: 380335353
Change-Id: I29d932f8a4d8989c06cfa6007368a424c963c91f
Flag: TEST_ONLY
Test: atest VtsAidlHalContextHubTargetTest
parent 0d100347
Loading
Loading
Loading
Loading
+47 −26
Original line number Diff line number Diff line
@@ -15,6 +15,7 @@
 */

#include "contexthub-impl/ContextHub.h"
#include "aidl/android/hardware/contexthub/IContextHubCallback.h"

#ifndef LOG_TAG
#define LOG_TAG "CHRE"
@@ -22,6 +23,8 @@

#include <inttypes.h>
#include <log/log.h>
#include <optional>
#include <thread>

using ::ndk::ScopedAStatus;

@@ -62,6 +65,9 @@ const EndpointInfo kMockEndpointInfos[kMockEndpointCount] = {
        },
};

//! Mutex used to ensure callbacks are called after the initial function returns.
std::mutex gCallbackMutex;

}  // anonymous namespace

ScopedAStatus ContextHub::getContextHubs(std::vector<ContextHubInfo>* out_contextHubInfos) {
@@ -297,8 +303,8 @@ ScopedAStatus ContextHub::requestSessionIdRange(int32_t in_size,
        mMaxValidSessionId = in_size;
    }

    _aidl_return->at(0) = 0;
    _aidl_return->at(1) = in_size;
    (*_aidl_return)[0] = 0;
    (*_aidl_return)[1] = in_size;
    return ScopedAStatus::ok();
};

@@ -308,7 +314,7 @@ ScopedAStatus ContextHub::openEndpointSession(
    // We are not calling onCloseEndpointSession on failure because the remote endpoints (our
    // mock endpoints) always accept the session.

    std::shared_ptr<IEndpointCallback> callback = nullptr;
    std::weak_ptr<IEndpointCallback> callback;
    {
        std::unique_lock<std::mutex> lock(mEndpointMutex);
        if (in_sessionId > mMaxValidSessionId) {
@@ -355,23 +361,27 @@ ScopedAStatus ContextHub::openEndpointSession(
                .serviceDescriptor = in_serviceDescriptor,
        });

        if (mEndpointCallback != nullptr) {
            callback = mEndpointCallback;
        if (mEndpointCallback == nullptr) {
            return ScopedAStatus::ok();
        }
        callback = mEndpointCallback;
    }

    if (callback != nullptr) {
        callback->onEndpointSessionOpenComplete(in_sessionId);
    std::unique_lock<std::mutex> lock(gCallbackMutex);
    std::thread{[callback, in_sessionId]() {
        std::unique_lock<std::mutex> lock(gCallbackMutex);
        if (auto cb = callback.lock(); cb != nullptr) {
            cb->onEndpointSessionOpenComplete(in_sessionId);
        }
    }}.detach();
    return ScopedAStatus::ok();
};

ScopedAStatus ContextHub::sendMessageToEndpoint(int32_t in_sessionId, const Message& in_msg) {
    bool foundSession = false;
    std::shared_ptr<IEndpointCallback> callback = nullptr;
    std::weak_ptr<IEndpointCallback> callback;
    {
        std::unique_lock<std::mutex> lock(mEndpointMutex);

        bool foundSession = false;
        for (const EndpointSession& session : mEndpointSessions) {
            if (session.sessionId == in_sessionId) {
                foundSession = true;
@@ -379,27 +389,38 @@ ScopedAStatus ContextHub::sendMessageToEndpoint(int32_t in_sessionId, const Mess
            }
        }

        if (mEndpointCallback != nullptr) {
            callback = mEndpointCallback;
        }
    }

        if (!foundSession) {
            ALOGE("sendMessageToEndpoint: session ID %" PRId32 " is invalid", in_sessionId);
            return ScopedAStatus::fromExceptionCode(EX_ILLEGAL_ARGUMENT);
        }

    if (callback != nullptr) {
        if (in_msg.flags & Message::FLAG_REQUIRES_DELIVERY_STATUS) {
        if (mEndpointCallback == nullptr) {
            return ScopedAStatus::ok();
        }
        callback = mEndpointCallback;
    }

    std::unique_lock<std::mutex> lock(gCallbackMutex);
    if ((in_msg.flags & Message::FLAG_REQUIRES_DELIVERY_STATUS) != 0) {
        MessageDeliveryStatus msgStatus = {};
        msgStatus.messageSequenceNumber = in_msg.sequenceNumber;
        msgStatus.errorCode = ErrorCode::OK;
            callback->onMessageDeliveryStatusReceived(in_sessionId, msgStatus);

        std::thread{[callback, in_sessionId, msgStatus]() {
            std::unique_lock<std::mutex> lock(gCallbackMutex);
            if (auto cb = callback.lock(); cb != nullptr) {
                cb->onMessageDeliveryStatusReceived(in_sessionId, msgStatus);
            }
        }}.detach();
    }

    // Echo the message back
        callback->onMessageReceived(in_sessionId, in_msg);
    std::thread{[callback, in_sessionId, in_msg]() {
        std::unique_lock<std::mutex> lock(gCallbackMutex);
        if (auto cb = callback.lock(); cb != nullptr) {
            cb->onMessageReceived(in_sessionId, in_msg);
        }
    }}.detach();
    return ScopedAStatus::ok();
};

+21 −6
Original line number Diff line number Diff line
@@ -492,7 +492,11 @@ class TestEndpointCallback : public BnEndpointCallback {
    }

    Status onMessageReceived(int32_t /* sessionId */, const Message& message) override {
        {
            std::unique_lock<std::mutex> lock(mMutex);
            mMessages.push_back(message);
        }
        mCondVar.notify_one();
        return Status::ok();
    }

@@ -513,21 +517,30 @@ class TestEndpointCallback : public BnEndpointCallback {
    }

    Status onEndpointSessionOpenComplete(int32_t /* sessionId */) override {
        {
            std::unique_lock<std::mutex> lock(mMutex);
            mWasOnEndpointSessionOpenCompleteCalled = true;
        }
        mCondVar.notify_one();
        return Status::ok();
    }

    std::vector<Message> getMessages() { return mMessages; }

    bool wasOnEndpointSessionOpenCompleteCalled() {
        return mWasOnEndpointSessionOpenCompleteCalled;
    }

    void resetWasOnEndpointSessionOpenCompleteCalled() {
        mWasOnEndpointSessionOpenCompleteCalled = false;
    }

    std::mutex& getMutex() { return mMutex; }
    std::condition_variable& getCondVar() { return mCondVar; }
    std::vector<Message> getMessages() { return mMessages; }

  private:
    std::vector<Message> mMessages;
    std::mutex mMutex;
    std::condition_variable mCondVar;
    bool mWasOnEndpointSessionOpenCompleteCalled = false;
};

@@ -690,14 +703,12 @@ TEST_P(ContextHubAidl, OpenEndpointSessionInvalidRange) {
    EXPECT_GE(range[1] - range[0] + 1, requestedRange);

    // Open the session
    cb->resetWasOnEndpointSessionOpenCompleteCalled();
    int32_t sessionId = range[1] + 10;  // invalid
    EXPECT_FALSE(contextHub
                         ->openEndpointSession(sessionId, destinationEndpoint->id,
                                               initiatorEndpoint.id,
                                               /* in_serviceDescriptor= */ String16("ECHO"))
                         .isOk());
    EXPECT_FALSE(cb->wasOnEndpointSessionOpenCompleteCalled());
}

TEST_P(ContextHubAidl, OpenEndpointSessionAndSendMessageEchoesBack) {
@@ -710,6 +721,8 @@ TEST_P(ContextHubAidl, OpenEndpointSessionAndSendMessageEchoesBack) {
        EXPECT_TRUE(status.isOk());
    }

    std::unique_lock<std::mutex> lock(cb->getMutex());

    // Register the endpoint
    EndpointInfo initiatorEndpoint;
    initiatorEndpoint.id.id = 8;
@@ -750,6 +763,7 @@ TEST_P(ContextHubAidl, OpenEndpointSessionAndSendMessageEchoesBack) {
                                              initiatorEndpoint.id,
                                              /* in_serviceDescriptor= */ String16("ECHO"))
                        .isOk());
    cb->getCondVar().wait(lock);
    EXPECT_TRUE(cb->wasOnEndpointSessionOpenCompleteCalled());

    // Send the message
@@ -760,6 +774,7 @@ TEST_P(ContextHubAidl, OpenEndpointSessionAndSendMessageEchoesBack) {
    ASSERT_TRUE(contextHub->sendMessageToEndpoint(sessionId, message).isOk());

    // Check for echo
    cb->getCondVar().wait(lock);
    EXPECT_FALSE(cb->getMessages().empty());
    EXPECT_EQ(cb->getMessages().back().content.back(), 42);
}