Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit da6a448e authored by Philip Quinn's avatar Philip Quinn
Browse files

Replace shared libtflite dependency with static library.

This allows us to only include the ops required to run the model and
have the linker strip the rest out, reducing memory overhead.

Bug: 267050081
Test: atest libinput_tests
Change-Id: I4055a0c8971ed4308ccfa425ab5e5ba560deb58c
parent b7f27919
Loading
Loading
Loading
Loading
+2 −0
Original line number Diff line number Diff line
@@ -101,6 +101,8 @@ public:
    // Creates a model from an encoded Flatbuffer model.
    static std::unique_ptr<TfLiteMotionPredictorModel> create(const char* modelPath);

    ~TfLiteMotionPredictorModel();

    // Returns the length of the model's input buffers.
    size_t inputLength() const;

+5 −1
Original line number Diff line number Diff line
@@ -73,11 +73,15 @@ cc_library {
        "liblog",
        "libPlatformProperties",
        "libvintf",
        "libtflite",
    ],

    ldflags: [
        "-Wl,--exclude-libs=libtflite_static.a",
    ],

    static_libs: [
        "libui-types",
        "libtflite_static",
    ],

    export_static_lib_headers: [
+16 −3
Original line number Diff line number Diff line
@@ -35,9 +35,11 @@
#include <log/log.h>

#include "tensorflow/lite/core/api/error_reporter.h"
#include "tensorflow/lite/core/api/op_resolver.h"
#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/kernels/builtin_op_kernels.h"
#include "tensorflow/lite/model.h"
#include "tensorflow/lite/mutable_op_resolver.h"

namespace android {
namespace {
@@ -102,6 +104,15 @@ void checkTensor(const TfLiteTensor* tensor) {
    LOG_ALWAYS_FATAL_IF(buffer.empty(), "No buffer for tensor '%s'", tensor->name);
}

std::unique_ptr<tflite::OpResolver> createOpResolver() {
    auto resolver = std::make_unique<tflite::MutableOpResolver>();
    resolver->AddBuiltin(::tflite::BuiltinOperator_CONCATENATION,
                         ::tflite::ops::builtin::Register_CONCATENATION());
    resolver->AddBuiltin(::tflite::BuiltinOperator_FULLY_CONNECTED,
                         ::tflite::ops::builtin::Register_FULLY_CONNECTED());
    return resolver;
}

} // namespace

TfLiteMotionPredictorBuffers::TfLiteMotionPredictorBuffers(size_t inputLength)
@@ -214,8 +225,8 @@ TfLiteMotionPredictorModel::TfLiteMotionPredictorModel(std::string model)
                                                               mErrorReporter.get());
    LOG_ALWAYS_FATAL_IF(!mModel);

    tflite::ops::builtin::BuiltinOpResolver resolver;
    tflite::InterpreterBuilder builder(*mModel, resolver);
    auto resolver = createOpResolver();
    tflite::InterpreterBuilder builder(*mModel, *resolver);

    if (builder(&mInterpreter) != kTfLiteOk || !mInterpreter) {
        LOG_ALWAYS_FATAL("Failed to build interpreter");
@@ -227,6 +238,8 @@ TfLiteMotionPredictorModel::TfLiteMotionPredictorModel(std::string model)
    allocateTensors();
}

TfLiteMotionPredictorModel::~TfLiteMotionPredictorModel() {}

void TfLiteMotionPredictorModel::allocateTensors() {
    if (mRunner->AllocateTensors() != kTfLiteOk) {
        LOG_ALWAYS_FATAL("Failed to allocate tensors");
+1 −1
Original line number Diff line number Diff line
@@ -34,6 +34,7 @@ cc_test {
        "libgmock",
        "libgui_window_info_static",
        "libinput",
        "libtflite_static",
        "libui-types",
    ],
    cflags: [
@@ -48,7 +49,6 @@ cc_test {
        "libcutils",
        "liblog",
        "libPlatformProperties",
        "libtflite",
        "libutils",
        "libvintf",
    ],