Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 93bfc46e authored by Matthew Sedam's avatar Matthew Sedam Committed by Android (Google) Code Review
Browse files

Merge "Context Hub default HAL: Use std::thread for endpoint callbacks" into main

parents fb8ad9ab 0ec290cd
Loading
Loading
Loading
Loading
+47 −26
Original line number Diff line number Diff line
@@ -15,6 +15,7 @@
 */

#include "contexthub-impl/ContextHub.h"
#include "aidl/android/hardware/contexthub/IContextHubCallback.h"

#ifndef LOG_TAG
#define LOG_TAG "CHRE"
@@ -22,6 +23,8 @@

#include <inttypes.h>
#include <log/log.h>
#include <optional>
#include <thread>

using ::ndk::ScopedAStatus;

@@ -62,6 +65,9 @@ const EndpointInfo kMockEndpointInfos[kMockEndpointCount] = {
        },
};

//! Mutex used to ensure callbacks are called after the initial function returns.
std::mutex gCallbackMutex;

}  // anonymous namespace

ScopedAStatus ContextHub::getContextHubs(std::vector<ContextHubInfo>* out_contextHubInfos) {
@@ -297,8 +303,8 @@ ScopedAStatus ContextHub::requestSessionIdRange(int32_t in_size,
        mMaxValidSessionId = in_size;
    }

    _aidl_return->at(0) = 0;
    _aidl_return->at(1) = in_size;
    (*_aidl_return)[0] = 0;
    (*_aidl_return)[1] = in_size;
    return ScopedAStatus::ok();
};

@@ -308,7 +314,7 @@ ScopedAStatus ContextHub::openEndpointSession(
    // We are not calling onCloseEndpointSession on failure because the remote endpoints (our
    // mock endpoints) always accept the session.

    std::shared_ptr<IEndpointCallback> callback = nullptr;
    std::weak_ptr<IEndpointCallback> callback;
    {
        std::unique_lock<std::mutex> lock(mEndpointMutex);
        if (in_sessionId > mMaxValidSessionId) {
@@ -355,23 +361,27 @@ ScopedAStatus ContextHub::openEndpointSession(
                .serviceDescriptor = in_serviceDescriptor,
        });

        if (mEndpointCallback != nullptr) {
            callback = mEndpointCallback;
        if (mEndpointCallback == nullptr) {
            return ScopedAStatus::ok();
        }
        callback = mEndpointCallback;
    }

    if (callback != nullptr) {
        callback->onEndpointSessionOpenComplete(in_sessionId);
    std::unique_lock<std::mutex> lock(gCallbackMutex);
    std::thread{[callback, in_sessionId]() {
        std::unique_lock<std::mutex> lock(gCallbackMutex);
        if (auto cb = callback.lock(); cb != nullptr) {
            cb->onEndpointSessionOpenComplete(in_sessionId);
        }
    }}.detach();
    return ScopedAStatus::ok();
};

ScopedAStatus ContextHub::sendMessageToEndpoint(int32_t in_sessionId, const Message& in_msg) {
    bool foundSession = false;
    std::shared_ptr<IEndpointCallback> callback = nullptr;
    std::weak_ptr<IEndpointCallback> callback;
    {
        std::unique_lock<std::mutex> lock(mEndpointMutex);

        bool foundSession = false;
        for (const EndpointSession& session : mEndpointSessions) {
            if (session.sessionId == in_sessionId) {
                foundSession = true;
@@ -379,27 +389,38 @@ ScopedAStatus ContextHub::sendMessageToEndpoint(int32_t in_sessionId, const Mess
            }
        }

        if (mEndpointCallback != nullptr) {
            callback = mEndpointCallback;
        }
    }

        if (!foundSession) {
            ALOGE("sendMessageToEndpoint: session ID %" PRId32 " is invalid", in_sessionId);
            return ScopedAStatus::fromExceptionCode(EX_ILLEGAL_ARGUMENT);
        }

    if (callback != nullptr) {
        if (in_msg.flags & Message::FLAG_REQUIRES_DELIVERY_STATUS) {
        if (mEndpointCallback == nullptr) {
            return ScopedAStatus::ok();
        }
        callback = mEndpointCallback;
    }

    std::unique_lock<std::mutex> lock(gCallbackMutex);
    if ((in_msg.flags & Message::FLAG_REQUIRES_DELIVERY_STATUS) != 0) {
        MessageDeliveryStatus msgStatus = {};
        msgStatus.messageSequenceNumber = in_msg.sequenceNumber;
        msgStatus.errorCode = ErrorCode::OK;
            callback->onMessageDeliveryStatusReceived(in_sessionId, msgStatus);

        std::thread{[callback, in_sessionId, msgStatus]() {
            std::unique_lock<std::mutex> lock(gCallbackMutex);
            if (auto cb = callback.lock(); cb != nullptr) {
                cb->onMessageDeliveryStatusReceived(in_sessionId, msgStatus);
            }
        }}.detach();
    }

    // Echo the message back
        callback->onMessageReceived(in_sessionId, in_msg);
    std::thread{[callback, in_sessionId, in_msg]() {
        std::unique_lock<std::mutex> lock(gCallbackMutex);
        if (auto cb = callback.lock(); cb != nullptr) {
            cb->onMessageReceived(in_sessionId, in_msg);
        }
    }}.detach();
    return ScopedAStatus::ok();
};

+21 −6
Original line number Diff line number Diff line
@@ -492,7 +492,11 @@ class TestEndpointCallback : public BnEndpointCallback {
    }

    Status onMessageReceived(int32_t /* sessionId */, const Message& message) override {
        {
            std::unique_lock<std::mutex> lock(mMutex);
            mMessages.push_back(message);
        }
        mCondVar.notify_one();
        return Status::ok();
    }

@@ -513,21 +517,30 @@ class TestEndpointCallback : public BnEndpointCallback {
    }

    Status onEndpointSessionOpenComplete(int32_t /* sessionId */) override {
        {
            std::unique_lock<std::mutex> lock(mMutex);
            mWasOnEndpointSessionOpenCompleteCalled = true;
        }
        mCondVar.notify_one();
        return Status::ok();
    }

    std::vector<Message> getMessages() { return mMessages; }

    bool wasOnEndpointSessionOpenCompleteCalled() {
        return mWasOnEndpointSessionOpenCompleteCalled;
    }

    void resetWasOnEndpointSessionOpenCompleteCalled() {
        mWasOnEndpointSessionOpenCompleteCalled = false;
    }

    std::mutex& getMutex() { return mMutex; }
    std::condition_variable& getCondVar() { return mCondVar; }
    std::vector<Message> getMessages() { return mMessages; }

  private:
    std::vector<Message> mMessages;
    std::mutex mMutex;
    std::condition_variable mCondVar;
    bool mWasOnEndpointSessionOpenCompleteCalled = false;
};

@@ -690,14 +703,12 @@ TEST_P(ContextHubAidl, OpenEndpointSessionInvalidRange) {
    EXPECT_GE(range[1] - range[0] + 1, requestedRange);

    // Open the session
    cb->resetWasOnEndpointSessionOpenCompleteCalled();
    int32_t sessionId = range[1] + 10;  // invalid
    EXPECT_FALSE(contextHub
                         ->openEndpointSession(sessionId, destinationEndpoint->id,
                                               initiatorEndpoint.id,
                                               /* in_serviceDescriptor= */ String16("ECHO"))
                         .isOk());
    EXPECT_FALSE(cb->wasOnEndpointSessionOpenCompleteCalled());
}

TEST_P(ContextHubAidl, OpenEndpointSessionAndSendMessageEchoesBack) {
@@ -710,6 +721,8 @@ TEST_P(ContextHubAidl, OpenEndpointSessionAndSendMessageEchoesBack) {
        EXPECT_TRUE(status.isOk());
    }

    std::unique_lock<std::mutex> lock(cb->getMutex());

    // Register the endpoint
    EndpointInfo initiatorEndpoint;
    initiatorEndpoint.id.id = 8;
@@ -750,6 +763,7 @@ TEST_P(ContextHubAidl, OpenEndpointSessionAndSendMessageEchoesBack) {
                                              initiatorEndpoint.id,
                                              /* in_serviceDescriptor= */ String16("ECHO"))
                        .isOk());
    cb->getCondVar().wait(lock);
    EXPECT_TRUE(cb->wasOnEndpointSessionOpenCompleteCalled());

    // Send the message
@@ -760,6 +774,7 @@ TEST_P(ContextHubAidl, OpenEndpointSessionAndSendMessageEchoesBack) {
    ASSERT_TRUE(contextHub->sendMessageToEndpoint(sessionId, message).isOk());

    // Check for echo
    cb->getCondVar().wait(lock);
    EXPECT_FALSE(cb->getMessages().empty());
    EXPECT_EQ(cb->getMessages().back().content.back(), 42);
}