Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 92378eb7 authored by Philip Quinn's avatar Philip Quinn Committed by Android (Google) Code Review
Browse files

Merge changes from topic "stylus-prediction"

* changes:
  Use mmap to read TFLite model.
  Replace shared libtflite dependency with static library.
parents 82a999b8 cb3229aa
Loading
Loading
Loading
Loading
+5 −3
Original line number Diff line number Diff line
@@ -22,8 +22,8 @@
#include <memory>
#include <optional>
#include <span>
#include <string>

#include <android-base/mapped_file.h>
#include <input/RingBuffer.h>

#include <tensorflow/lite/core/api/error_reporter.h>
@@ -101,6 +101,8 @@ public:
    // Creates a model from an encoded Flatbuffer model.
    static std::unique_ptr<TfLiteMotionPredictorModel> create(const char* modelPath);

    ~TfLiteMotionPredictorModel();

    // Returns the length of the model's input buffers.
    size_t inputLength() const;

@@ -122,7 +124,7 @@ public:
    std::span<const float> outputPressure() const;

private:
    explicit TfLiteMotionPredictorModel(std::string model);
    explicit TfLiteMotionPredictorModel(std::unique_ptr<android::base::MappedFile> model);

    void allocateTensors();
    void attachInputTensors();
@@ -138,7 +140,7 @@ private:
    const TfLiteTensor* mOutputPhi = nullptr;
    const TfLiteTensor* mOutputPressure = nullptr;

    std::string mFlatBuffer;
    std::unique_ptr<android::base::MappedFile> mFlatBuffer;
    std::unique_ptr<tflite::ErrorReporter> mErrorReporter;
    std::unique_ptr<tflite::FlatBufferModel> mModel;
    std::unique_ptr<tflite::Interpreter> mInterpreter;
+5 −1
Original line number Diff line number Diff line
@@ -73,11 +73,15 @@ cc_library {
        "liblog",
        "libPlatformProperties",
        "libvintf",
        "libtflite",
    ],

    ldflags: [
        "-Wl,--exclude-libs=libtflite_static.a",
    ],

    static_libs: [
        "libui-types",
        "libtflite_static",
    ],

    export_static_lib_headers: [
+45 −15
Original line number Diff line number Diff line
@@ -17,27 +17,31 @@
#define LOG_TAG "TfLiteMotionPredictor"
#include <input/TfLiteMotionPredictor.h>

#include <fcntl.h>
#include <sys/mman.h>
#include <unistd.h>

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <fstream>
#include <ios>
#include <iterator>
#include <memory>
#include <span>
#include <string>
#include <type_traits>
#include <utility>

#include <android-base/logging.h>
#include <android-base/mapped_file.h>
#define ATRACE_TAG ATRACE_TAG_INPUT
#include <cutils/trace.h>
#include <log/log.h>

#include "tensorflow/lite/core/api/error_reporter.h"
#include "tensorflow/lite/core/api/op_resolver.h"
#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/kernels/builtin_op_kernels.h"
#include "tensorflow/lite/model.h"
#include "tensorflow/lite/mutable_op_resolver.h"

namespace android {
namespace {
@@ -102,6 +106,15 @@ void checkTensor(const TfLiteTensor* tensor) {
    LOG_ALWAYS_FATAL_IF(buffer.empty(), "No buffer for tensor '%s'", tensor->name);
}

std::unique_ptr<tflite::OpResolver> createOpResolver() {
    auto resolver = std::make_unique<tflite::MutableOpResolver>();
    resolver->AddBuiltin(::tflite::BuiltinOperator_CONCATENATION,
                         ::tflite::ops::builtin::Register_CONCATENATION());
    resolver->AddBuiltin(::tflite::BuiltinOperator_FULLY_CONNECTED,
                         ::tflite::ops::builtin::Register_FULLY_CONNECTED());
    return resolver;
}

} // namespace

TfLiteMotionPredictorBuffers::TfLiteMotionPredictorBuffers(size_t inputLength)
@@ -195,27 +208,42 @@ void TfLiteMotionPredictorBuffers::pushSample(int64_t timestamp,

std::unique_ptr<TfLiteMotionPredictorModel> TfLiteMotionPredictorModel::create(
        const char* modelPath) {
    std::ifstream f(modelPath, std::ios::binary);
    LOG_ALWAYS_FATAL_IF(!f, "Could not read model from %s", modelPath);
    const int fd = open(modelPath, O_RDONLY);
    if (fd == -1) {
        PLOG(FATAL) << "Could not read model from " << modelPath;
    }

    std::string data;
    data.assign(std::istreambuf_iterator<char>(f), std::istreambuf_iterator<char>());
    const off_t fdSize = lseek(fd, 0, SEEK_END);
    if (fdSize == -1) {
        PLOG(FATAL) << "Failed to determine file size";
    }

    std::unique_ptr<android::base::MappedFile> modelBuffer =
            android::base::MappedFile::FromFd(fd, /*offset=*/0, fdSize, PROT_READ);
    if (!modelBuffer) {
        PLOG(FATAL) << "Failed to mmap model";
    }
    if (close(fd) == -1) {
        PLOG(FATAL) << "Failed to close model fd";
    }

    return std::unique_ptr<TfLiteMotionPredictorModel>(
            new TfLiteMotionPredictorModel(std::move(data)));
            new TfLiteMotionPredictorModel(std::move(modelBuffer)));
}

TfLiteMotionPredictorModel::TfLiteMotionPredictorModel(std::string model)
TfLiteMotionPredictorModel::TfLiteMotionPredictorModel(
        std::unique_ptr<android::base::MappedFile> model)
      : mFlatBuffer(std::move(model)) {
    CHECK(mFlatBuffer);
    mErrorReporter = std::make_unique<LoggingErrorReporter>();
    mModel = tflite::FlatBufferModel::VerifyAndBuildFromBuffer(mFlatBuffer.data(),
                                                               mFlatBuffer.length(),
    mModel = tflite::FlatBufferModel::VerifyAndBuildFromBuffer(mFlatBuffer->data(),
                                                               mFlatBuffer->size(),
                                                               /*extra_verifier=*/nullptr,
                                                               mErrorReporter.get());
    LOG_ALWAYS_FATAL_IF(!mModel);

    tflite::ops::builtin::BuiltinOpResolver resolver;
    tflite::InterpreterBuilder builder(*mModel, resolver);
    auto resolver = createOpResolver();
    tflite::InterpreterBuilder builder(*mModel, *resolver);

    if (builder(&mInterpreter) != kTfLiteOk || !mInterpreter) {
        LOG_ALWAYS_FATAL("Failed to build interpreter");
@@ -227,6 +255,8 @@ TfLiteMotionPredictorModel::TfLiteMotionPredictorModel(std::string model)
    allocateTensors();
}

TfLiteMotionPredictorModel::~TfLiteMotionPredictorModel() {}

void TfLiteMotionPredictorModel::allocateTensors() {
    if (mRunner->AllocateTensors() != kTfLiteOk) {
        LOG_ALWAYS_FATAL("Failed to allocate tensors");
+1 −1
Original line number Diff line number Diff line
@@ -34,6 +34,7 @@ cc_test {
        "libgmock",
        "libgui_window_info_static",
        "libinput",
        "libtflite_static",
        "libui-types",
    ],
    cflags: [
@@ -48,7 +49,6 @@ cc_test {
        "libcutils",
        "liblog",
        "libPlatformProperties",
        "libtflite",
        "libutils",
        "libvintf",
    ],