Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 84ec222e authored by Michael Butler's avatar Michael Butler Committed by Automerger Merge Worker
Browse files

Merge "Handle case where NN AIDL callback is null in IDevice::prepareModel*" am: 5d4f1b70

parents c798728d 5d4f1b70
Loading
Loading
Loading
Loading
+15 −5
Original line number Original line Diff line number Diff line
@@ -135,16 +135,26 @@ std::shared_ptr<PreparedModel> adaptPreparedModel(nn::SharedPreparedModel prepar
    return ndk::SharedRefBase::make<PreparedModel>(std::move(preparedModel));
    return ndk::SharedRefBase::make<PreparedModel>(std::move(preparedModel));
}
}


void notify(IPreparedModelCallback* callback, ErrorStatus status,
            const std::shared_ptr<IPreparedModel>& preparedModel) {
    if (callback != nullptr) {
        const auto ret = callback->notify(status, preparedModel);
        if (!ret.isOk()) {
            LOG(ERROR) << "IPreparedModelCallback::notify failed with " << ret.getDescription();
        }
    }
}

void notify(IPreparedModelCallback* callback, PrepareModelResult result) {
void notify(IPreparedModelCallback* callback, PrepareModelResult result) {
    if (!result.has_value()) {
    if (!result.has_value()) {
        const auto& [message, status] = result.error();
        const auto& [message, status] = result.error();
        LOG(ERROR) << message;
        LOG(ERROR) << message;
        const auto aidlCode = utils::convert(status).value_or(ErrorStatus::GENERAL_FAILURE);
        const auto aidlCode = utils::convert(status).value_or(ErrorStatus::GENERAL_FAILURE);
        callback->notify(aidlCode, nullptr);
        notify(callback, aidlCode, nullptr);
    } else {
    } else {
        auto preparedModel = std::move(result).value();
        auto preparedModel = std::move(result).value();
        auto aidlPreparedModel = adaptPreparedModel(std::move(preparedModel));
        auto aidlPreparedModel = adaptPreparedModel(std::move(preparedModel));
        callback->notify(ErrorStatus::NONE, std::move(aidlPreparedModel));
        notify(callback, ErrorStatus::NONE, std::move(aidlPreparedModel));
    }
    }
}
}


@@ -284,7 +294,7 @@ ndk::ScopedAStatus Device::prepareModel(const Model& model, ExecutionPreference
    if (!result.has_value()) {
    if (!result.has_value()) {
        const auto& [message, code] = result.error();
        const auto& [message, code] = result.error();
        const auto aidlCode = utils::convert(code).value_or(ErrorStatus::GENERAL_FAILURE);
        const auto aidlCode = utils::convert(code).value_or(ErrorStatus::GENERAL_FAILURE);
        callback->notify(aidlCode, nullptr);
        notify(callback.get(), aidlCode, nullptr);
        return ndk::ScopedAStatus::fromServiceSpecificErrorWithMessage(
        return ndk::ScopedAStatus::fromServiceSpecificErrorWithMessage(
                static_cast<int32_t>(aidlCode), message.c_str());
                static_cast<int32_t>(aidlCode), message.c_str());
    }
    }
@@ -300,7 +310,7 @@ ndk::ScopedAStatus Device::prepareModelFromCache(
    if (!result.has_value()) {
    if (!result.has_value()) {
        const auto& [message, code] = result.error();
        const auto& [message, code] = result.error();
        const auto aidlCode = utils::convert(code).value_or(ErrorStatus::GENERAL_FAILURE);
        const auto aidlCode = utils::convert(code).value_or(ErrorStatus::GENERAL_FAILURE);
        callback->notify(aidlCode, nullptr);
        notify(callback.get(), aidlCode, nullptr);
        return ndk::ScopedAStatus::fromServiceSpecificErrorWithMessage(
        return ndk::ScopedAStatus::fromServiceSpecificErrorWithMessage(
                static_cast<int32_t>(aidlCode), message.c_str());
                static_cast<int32_t>(aidlCode), message.c_str());
    }
    }
@@ -317,7 +327,7 @@ ndk::ScopedAStatus Device::prepareModelWithConfig(
    if (!result.has_value()) {
    if (!result.has_value()) {
        const auto& [message, code] = result.error();
        const auto& [message, code] = result.error();
        const auto aidlCode = utils::convert(code).value_or(ErrorStatus::GENERAL_FAILURE);
        const auto aidlCode = utils::convert(code).value_or(ErrorStatus::GENERAL_FAILURE);
        callback->notify(aidlCode, nullptr);
        notify(callback.get(), aidlCode, nullptr);
        return ndk::ScopedAStatus::fromServiceSpecificErrorWithMessage(
        return ndk::ScopedAStatus::fromServiceSpecificErrorWithMessage(
                static_cast<int32_t>(aidlCode), message.c_str());
                static_cast<int32_t>(aidlCode), message.c_str());
    }
    }