Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 03475d9a authored by Jeff Sharkey's avatar Jeff Sharkey
Browse files

Add custom scalar/aggregate functions to SQLite.

SQLite ships with a handful of basic functions, such as UPPER() as
a scalar function and MAX() as a aggregate function.  We now have
several use-cases for adding custom functions, where it's otherwise
prohibitively expensive to perform post-processing on the returned
Cursor, as that requires copying processed data to yet another
MatrixCursor before returning to apps.

This change adds the ability for developers to register custom
scalar and aggregate functions on databases that they've opened;
some contrived examples are scalar functions like REVERSE() for
reversing a string, or aggregate functions like STDDEV().

To give developers the most flexibility, we use the Java functional
interfaces for defining these operations, as developers may already
be familiar with the contracts of those methods.  This also opens
the door to quickly adapting existing code through utility methods
like BinaryOperator.minBy(Comparator).

Bug: 142564473
Test: atest CtsDatabaseTestCases:android.database.sqlite.cts.SQLiteDatabaseTest
Change-Id: I9fa0e60ec77bab676396729cc9cb8ba8aaf56224
parent e709c6a4
Loading
Loading
Loading
Loading
+2 −0
Original line number Diff line number Diff line
@@ -13207,6 +13207,8 @@ package android.database.sqlite {
    method public static int releaseMemory();
    method public long replace(String, String, android.content.ContentValues);
    method public long replaceOrThrow(String, String, android.content.ContentValues) throws android.database.SQLException;
    method public void setCustomAggregateFunction(@NonNull String, @NonNull java.util.function.BinaryOperator<java.lang.String>) throws android.database.sqlite.SQLiteException;
    method public void setCustomScalarFunction(@NonNull String, @NonNull java.util.function.UnaryOperator<java.lang.String>) throws android.database.sqlite.SQLiteException;
    method public void setForeignKeyConstraintsEnabled(boolean);
    method public void setLocale(java.util.Locale);
    method @Deprecated public void setLockingEnabled(boolean);
+27 −23
Original line number Diff line number Diff line
@@ -39,6 +39,8 @@ import java.text.SimpleDateFormat;
import java.util.ArrayList;
import java.util.Date;
import java.util.Map;
import java.util.function.BinaryOperator;
import java.util.function.UnaryOperator;

/**
 * Represents a SQLite database connection.
@@ -123,8 +125,10 @@ public final class SQLiteConnection implements CancellationSignal.OnCancelListen
            boolean enableTrace, boolean enableProfile, int lookasideSlotSize,
            int lookasideSlotCount);
    private static native void nativeClose(long connectionPtr);
    private static native void nativeRegisterCustomFunction(long connectionPtr,
            SQLiteCustomFunction function);
    private static native void nativeRegisterCustomScalarFunction(long connectionPtr,
            String name, UnaryOperator<String> function);
    private static native void nativeRegisterCustomAggregateFunction(long connectionPtr,
            String name, BinaryOperator<String> function);
    private static native void nativeRegisterLocalizedCollators(long connectionPtr, String locale);
    private static native long nativePrepareStatement(long connectionPtr, String sql);
    private static native void nativeFinalizeStatement(long connectionPtr, long statementPtr);
@@ -225,13 +229,7 @@ public final class SQLiteConnection implements CancellationSignal.OnCancelListen
        setJournalSizeLimit();
        setAutoCheckpointInterval();
        setLocaleFromConfiguration();

        // Register custom functions.
        final int functionCount = mConfiguration.customFunctions.size();
        for (int i = 0; i < functionCount; i++) {
            SQLiteCustomFunction function = mConfiguration.customFunctions.get(i);
            nativeRegisterCustomFunction(mConnectionPtr, function);
        }
        setCustomFunctionsFromConfiguration();
    }

    private void dispose(boolean finalized) {
@@ -457,6 +455,19 @@ public final class SQLiteConnection implements CancellationSignal.OnCancelListen
        }
    }

    private void setCustomFunctionsFromConfiguration() {
        for (int i = 0; i < mConfiguration.customScalarFunctions.size(); i++) {
            nativeRegisterCustomScalarFunction(mConnectionPtr,
                    mConfiguration.customScalarFunctions.keyAt(i),
                    mConfiguration.customScalarFunctions.valueAt(i));
        }
        for (int i = 0; i < mConfiguration.customAggregateFunctions.size(); i++) {
            nativeRegisterCustomAggregateFunction(mConnectionPtr,
                    mConfiguration.customAggregateFunctions.keyAt(i),
                    mConfiguration.customAggregateFunctions.valueAt(i));
        }
    }

    private void checkDatabaseWiped() {
        if (!SQLiteGlobal.checkDbWipe()) {
            return;
@@ -491,15 +502,6 @@ public final class SQLiteConnection implements CancellationSignal.OnCancelListen
    void reconfigure(SQLiteDatabaseConfiguration configuration) {
        mOnlyAllowReadOnlyOperations = false;

        // Register custom functions.
        final int functionCount = configuration.customFunctions.size();
        for (int i = 0; i < functionCount; i++) {
            SQLiteCustomFunction function = configuration.customFunctions.get(i);
            if (!mConfiguration.customFunctions.contains(function)) {
                nativeRegisterCustomFunction(mConnectionPtr, function);
            }
        }

        // Remember what changed.
        boolean foreignKeyModeChanged = configuration.foreignKeyConstraintsEnabled
                != mConfiguration.foreignKeyConstraintsEnabled;
@@ -507,6 +509,10 @@ public final class SQLiteConnection implements CancellationSignal.OnCancelListen
                & (SQLiteDatabase.ENABLE_WRITE_AHEAD_LOGGING
                | SQLiteDatabase.ENABLE_LEGACY_COMPATIBILITY_WAL)) != 0;
        boolean localeChanged = !configuration.locale.equals(mConfiguration.locale);
        boolean customScalarFunctionsChanged = !configuration.customScalarFunctions
                .equals(mConfiguration.customScalarFunctions);
        boolean customAggregateFunctionsChanged = !configuration.customAggregateFunctions
                .equals(mConfiguration.customAggregateFunctions);

        // Update configuration parameters.
        mConfiguration.updateParametersFrom(configuration);
@@ -514,20 +520,18 @@ public final class SQLiteConnection implements CancellationSignal.OnCancelListen
        // Update prepared statement cache size.
        mPreparedStatementCache.resize(configuration.maxSqlCacheSize);

        // Update foreign key mode.
        if (foreignKeyModeChanged) {
            setForeignKeyModeFromConfiguration();
        }

        // Update WAL.
        if (walModeChanged) {
            setWalModeFromConfiguration();
        }

        // Update locale.
        if (localeChanged) {
            setLocaleFromConfiguration();
        }
        if (customScalarFunctionsChanged || customAggregateFunctionsChanged) {
            setCustomFunctionsFromConfiguration();
        }
    }

    // Called by SQLiteConnectionPool only.
+76 −13
Original line number Diff line number Diff line
@@ -62,6 +62,8 @@ import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.WeakHashMap;
import java.util.function.BinaryOperator;
import java.util.function.UnaryOperator;

/**
 * Exposes methods to manage a SQLite database.
@@ -958,26 +960,87 @@ public final class SQLiteDatabase extends SQLiteClosable {
    }

    /**
     * Registers a CustomFunction callback as a function that can be called from
     * SQLite database triggers.
     *
     * @param name the name of the sqlite3 function
     * @param numArgs the number of arguments for the function
     * @param function callback to call when the function is executed
     * @hide
     */
    public void addCustomFunction(String name, int numArgs, CustomFunction function) {
        // Create wrapper (also validates arguments).
        SQLiteCustomFunction wrapper = new SQLiteCustomFunction(name, numArgs, function);
     * Register a custom scalar function that can be called from SQL
     * expressions.
     * <p>
     * For example, registering a custom scalar function named {@code REVERSE}
     * could be used in a query like
     * {@code SELECT REVERSE(name) FROM employees}.
     * <p>
     * When attempting to register multiple functions with the same function
     * name, SQLite will replace any previously defined functions with the
     * latest definition, regardless of what function type they are. SQLite does
     * not support unregistering functions.
     *
     * @param functionName Case-insensitive name to register this function
     *            under, limited to 255 UTF-8 bytes in length.
     * @param scalarFunction Functional interface that will be invoked when the
     *            function name is used by a SQL statement. The argument values
     *            from the SQL statement are passed to the functional interface,
     *            and the return values from the functional interface are
     *            returned back into the SQL statement.
     * @throws SQLiteException if the custom function could not be registered.
     * @see #setCustomAggregateFunction(String, BinaryOperator)
     */
    public void setCustomScalarFunction(@NonNull String functionName,
            @NonNull UnaryOperator<String> scalarFunction) throws SQLiteException {
        Objects.requireNonNull(functionName);
        Objects.requireNonNull(scalarFunction);

        synchronized (mLock) {
            throwIfNotOpenLocked();

            mConfigurationLocked.customScalarFunctions.put(functionName, scalarFunction);
            try {
                mConnectionPoolLocked.reconfigure(mConfigurationLocked);
            } catch (RuntimeException ex) {
                mConfigurationLocked.customScalarFunctions.remove(functionName);
                throw ex;
            }
        }
    }

    /**
     * Register a custom aggregate function that can be called from SQL
     * expressions.
     * <p>
     * For example, registering a custom aggregation function named
     * {@code LONGEST} could be used in a query like
     * {@code SELECT LONGEST(name) FROM employees}.
     * <p>
     * The implementation of this method follows the reduction flow outlined in
     * {@link java.util.stream.Stream#reduce(BinaryOperator)}, and the custom
     * aggregation function is expected to be an associative accumulation
     * function, as defined by that class.
     * <p>
     * When attempting to register multiple functions with the same function
     * name, SQLite will replace any previously defined functions with the
     * latest definition, regardless of what function type they are. SQLite does
     * not support unregistering functions.
     *
     * @param functionName Case-insensitive name to register this function
     *            under, limited to 255 UTF-8 bytes in length.
     * @param aggregateFunction Functional interface that will be invoked when
     *            the function name is used by a SQL statement. The argument
     *            values from the SQL statement are passed to the functional
     *            interface, and the return values from the functional interface
     *            are returned back into the SQL statement.
     * @throws SQLiteException if the custom function could not be registered.
     * @see #setCustomScalarFunction(String, UnaryOperator)
     */
    public void setCustomAggregateFunction(@NonNull String functionName,
            @NonNull BinaryOperator<String> aggregateFunction) throws SQLiteException {
        Objects.requireNonNull(functionName);
        Objects.requireNonNull(aggregateFunction);

        synchronized (mLock) {
            throwIfNotOpenLocked();

            mConfigurationLocked.customFunctions.add(wrapper);
            mConfigurationLocked.customAggregateFunctions.put(functionName, aggregateFunction);
            try {
                mConnectionPoolLocked.reconfigure(mConfigurationLocked);
            } catch (RuntimeException ex) {
                mConfigurationLocked.customFunctions.remove(wrapper);
                mConfigurationLocked.customAggregateFunctions.remove(functionName);
                throw ex;
            }
        }
+17 −6
Original line number Diff line number Diff line
@@ -17,9 +17,12 @@
package android.database.sqlite;

import android.compat.annotation.UnsupportedAppUsage;
import android.util.ArrayMap;

import java.util.ArrayList;
import java.util.Locale;
import java.util.Map;
import java.util.function.BinaryOperator;
import java.util.function.UnaryOperator;
import java.util.regex.Pattern;

/**
@@ -87,10 +90,16 @@ public final class SQLiteDatabaseConfiguration {
    public boolean foreignKeyConstraintsEnabled;

    /**
     * The custom functions to register.
     * The custom scalar functions to register.
     */
    public final ArrayList<SQLiteCustomFunction> customFunctions =
            new ArrayList<SQLiteCustomFunction>();
    public final ArrayMap<String, UnaryOperator<String>> customScalarFunctions
            = new ArrayMap<>();

    /**
     * The custom aggregate functions to register.
     */
    public final ArrayMap<String, BinaryOperator<String>> customAggregateFunctions
            = new ArrayMap<>();

    /**
     * The size in bytes of each lookaside slot
@@ -181,8 +190,10 @@ public final class SQLiteDatabaseConfiguration {
        maxSqlCacheSize = other.maxSqlCacheSize;
        locale = other.locale;
        foreignKeyConstraintsEnabled = other.foreignKeyConstraintsEnabled;
        customFunctions.clear();
        customFunctions.addAll(other.customFunctions);
        customScalarFunctions.clear();
        customScalarFunctions.putAll(other.customScalarFunctions);
        customAggregateFunctions.clear();
        customAggregateFunctions.putAll(other.customAggregateFunctions);
        lookasideSlotSize = other.lookasideSlotSize;
        lookasideSlotCount = other.lookasideSlotCount;
        idleConnectionTimeoutMs = other.idleConnectionTimeoutMs;
+131 −61
Original line number Diff line number Diff line
@@ -59,14 +59,12 @@ namespace android {
static const int BUSY_TIMEOUT_MS = 2500;

static struct {
    jfieldID name;
    jfieldID numArgs;
    jmethodID dispatchCallback;
} gSQLiteCustomFunctionClassInfo;
    jmethodID apply;
} gUnaryOperator;

static struct {
    jclass clazz;
} gStringClassInfo;
    jmethodID apply;
} gBinaryOperator;

struct SQLiteConnection {
    // Open flags.
@@ -203,74 +201,146 @@ static void nativeClose(JNIEnv* env, jclass clazz, jlong connectionPtr) {
    }
}

// Called each time a custom function is evaluated.
static void sqliteCustomFunctionCallback(sqlite3_context *context,
static void sqliteCustomScalarFunctionCallback(sqlite3_context *context,
        int argc, sqlite3_value **argv) {
    JNIEnv* env = AndroidRuntime::getJNIEnv();

    // Get the callback function object.
    // Create a new local reference to it in case the callback tries to do something
    // dumb like unregister the function (thereby destroying the global ref) while it is running.
    jobject functionObjGlobal = reinterpret_cast<jobject>(sqlite3_user_data(context));
    jobject functionObj = env->NewLocalRef(functionObjGlobal);

    jobjectArray argsArray = env->NewObjectArray(argc, gStringClassInfo.clazz, NULL);
    if (argsArray) {
        for (int i = 0; i < argc; i++) {
            const jchar* arg = static_cast<const jchar*>(sqlite3_value_text16(argv[i]));
            if (!arg) {
                ALOGW("NULL argument in custom_function_callback.  This should not happen.");
    ScopedLocalRef<jobject> functionObj(env, env->NewLocalRef(functionObjGlobal));
    ScopedLocalRef<jstring> argString(env,
            env->NewStringUTF(reinterpret_cast<const char*>(sqlite3_value_text(argv[0]))));
    ScopedLocalRef<jstring> resString(env,
            (jstring) env->CallObjectMethod(functionObj.get(), gUnaryOperator.apply, argString.get()));

    if (env->ExceptionCheck()) {
        ALOGE("Exception thrown by custom scalar function");
        sqlite3_result_error(context, "Exception thrown by custom scalar function", -1);
        env->ExceptionDescribe();
        env->ExceptionClear();
        return;
    }

    if (resString.get() == nullptr) {
        sqlite3_result_null(context);
    } else {
                size_t argLen = sqlite3_value_bytes16(argv[i]) / sizeof(jchar);
                jstring argStr = env->NewString(arg, argLen);
                if (!argStr) {
                    goto error; // out of memory error
        ScopedUtfChars res(env, resString.get());
        sqlite3_result_text(context, res.c_str(), -1, SQLITE_TRANSIENT);
    }
                env->SetObjectArrayElement(argsArray, i, argStr);
                env->DeleteLocalRef(argStr);
}

static void sqliteCustomScalarFunctionDestructor(void* data) {
    jobject functionObjGlobal = reinterpret_cast<jobject>(data);

    JNIEnv* env = AndroidRuntime::getJNIEnv();
    env->DeleteGlobalRef(functionObjGlobal);
}

        // TODO: Support functions that return values.
        env->CallVoidMethod(functionObj,
                gSQLiteCustomFunctionClassInfo.dispatchCallback, argsArray);
static void nativeRegisterCustomScalarFunction(JNIEnv* env, jclass clazz, jlong connectionPtr,
        jstring functionName, jobject functionObj) {
    SQLiteConnection* connection = reinterpret_cast<SQLiteConnection*>(connectionPtr);

    jobject functionObjGlobal = env->NewGlobalRef(functionObj);
    ScopedUtfChars functionNameChars(env, functionName);
    int err = sqlite3_create_function_v2(connection->db,
            functionNameChars.c_str(), 1, SQLITE_UTF8,
            reinterpret_cast<void*>(functionObjGlobal),
            &sqliteCustomScalarFunctionCallback,
            nullptr,
            nullptr,
            &sqliteCustomScalarFunctionDestructor);

    if (err != SQLITE_OK) {
        ALOGE("sqlite3_create_function returned %d", err);
        env->DeleteGlobalRef(functionObjGlobal);
        throw_sqlite3_exception(env, connection->db);
        return;
    }
}

error:
        env->DeleteLocalRef(argsArray);
static void sqliteCustomAggregateFunctionStep(sqlite3_context *context,
        int argc, sqlite3_value **argv) {
    char** agg = reinterpret_cast<char**>(
            sqlite3_aggregate_context(context, sizeof(const char**)));
    if (agg == nullptr) {
        return;
    } else if (*agg == nullptr) {
        // During our first call the best we can do is allocate our result
        // holder and populate it with our first value; we'll reduce it
        // against any additional values in future calls
        const char* res = reinterpret_cast<const char*>(sqlite3_value_text(argv[0]));
        if (res == nullptr) {
            *agg = nullptr;
        } else {
            *agg = strdup(res);
        }
        return;
    }

    env->DeleteLocalRef(functionObj);
    JNIEnv* env = AndroidRuntime::getJNIEnv();
    jobject functionObjGlobal = reinterpret_cast<jobject>(sqlite3_user_data(context));
    ScopedLocalRef<jobject> functionObj(env, env->NewLocalRef(functionObjGlobal));
    ScopedLocalRef<jstring> arg0String(env,
            env->NewStringUTF(reinterpret_cast<const char*>(*agg)));
    ScopedLocalRef<jstring> arg1String(env,
            env->NewStringUTF(reinterpret_cast<const char*>(sqlite3_value_text(argv[0]))));
    ScopedLocalRef<jstring> resString(env,
            (jstring) env->CallObjectMethod(functionObj.get(), gBinaryOperator.apply,
                    arg0String.get(), arg1String.get()));

    if (env->ExceptionCheck()) {
        ALOGE("An exception was thrown by custom SQLite function.");
        LOGE_EX(env);
        ALOGE("Exception thrown by custom aggregate function");
        sqlite3_result_error(context, "Exception thrown by custom aggregate function", -1);
        env->ExceptionDescribe();
        env->ExceptionClear();
        return;
    }

    // One way or another, we have a new value to collect, and we need to
    // free our previous value
    if (*agg != nullptr) {
        free(*agg);
    }
    if (resString.get() == nullptr) {
        *agg = nullptr;
    } else {
        ScopedUtfChars res(env, resString.get());
        *agg = strdup(res.c_str());
    }
}

// Called when a custom function is destroyed.
static void sqliteCustomFunctionDestructor(void* data) {
static void sqliteCustomAggregateFunctionFinal(sqlite3_context *context) {
    // We pass zero size here to avoid allocating for empty sets
    char** agg = reinterpret_cast<char**>(
            sqlite3_aggregate_context(context, 0));
    if (agg == nullptr) {
        return;
    } else if (*agg == nullptr) {
        sqlite3_result_null(context);
    } else {
        sqlite3_result_text(context, *agg, -1, SQLITE_TRANSIENT);
        free(*agg);
    }
}

static void sqliteCustomAggregateFunctionDestructor(void* data) {
    jobject functionObjGlobal = reinterpret_cast<jobject>(data);

    JNIEnv* env = AndroidRuntime::getJNIEnv();
    env->DeleteGlobalRef(functionObjGlobal);
}

static void nativeRegisterCustomFunction(JNIEnv* env, jclass clazz, jlong connectionPtr,
        jobject functionObj) {
static void nativeRegisterCustomAggregateFunction(JNIEnv* env, jclass clazz, jlong connectionPtr,
        jstring functionName, jobject functionObj) {
    SQLiteConnection* connection = reinterpret_cast<SQLiteConnection*>(connectionPtr);

    jstring nameStr = jstring(env->GetObjectField(
            functionObj, gSQLiteCustomFunctionClassInfo.name));
    jint numArgs = env->GetIntField(functionObj, gSQLiteCustomFunctionClassInfo.numArgs);

    jobject functionObjGlobal = env->NewGlobalRef(functionObj);

    const char* name = env->GetStringUTFChars(nameStr, NULL);
    int err = sqlite3_create_function_v2(connection->db, name, numArgs, SQLITE_UTF16,
    ScopedUtfChars functionNameChars(env, functionName);
    int err = sqlite3_create_function_v2(connection->db,
            functionNameChars.c_str(), 1, SQLITE_UTF8,
            reinterpret_cast<void*>(functionObjGlobal),
            &sqliteCustomFunctionCallback, NULL, NULL, &sqliteCustomFunctionDestructor);
    env->ReleaseStringUTFChars(nameStr, name);
            nullptr,
            &sqliteCustomAggregateFunctionStep,
            &sqliteCustomAggregateFunctionFinal,
            &sqliteCustomAggregateFunctionDestructor);

    if (err != SQLITE_OK) {
        ALOGE("sqlite3_create_function returned %d", err);
@@ -812,8 +882,10 @@ static const JNINativeMethod sMethods[] =
            (void*)nativeOpen },
    { "nativeClose", "(J)V",
            (void*)nativeClose },
    { "nativeRegisterCustomFunction", "(JLandroid/database/sqlite/SQLiteCustomFunction;)V",
            (void*)nativeRegisterCustomFunction },
    { "nativeRegisterCustomScalarFunction", "(JLjava/lang/String;Ljava/util/function/UnaryOperator;)V",
            (void*)nativeRegisterCustomScalarFunction },
    { "nativeRegisterCustomAggregateFunction", "(JLjava/lang/String;Ljava/util/function/BinaryOperator;)V",
            (void*)nativeRegisterCustomAggregateFunction },
    { "nativeRegisterLocalizedCollators", "(JLjava/lang/String;)V",
            (void*)nativeRegisterLocalizedCollators },
    { "nativePrepareStatement", "(JLjava/lang/String;)J",
@@ -864,15 +936,13 @@ static const JNINativeMethod sMethods[] =

int register_android_database_SQLiteConnection(JNIEnv *env)
{
    jclass clazz = FindClassOrDie(env, "android/database/sqlite/SQLiteCustomFunction");

    gSQLiteCustomFunctionClassInfo.name = GetFieldIDOrDie(env, clazz, "name", "Ljava/lang/String;");
    gSQLiteCustomFunctionClassInfo.numArgs = GetFieldIDOrDie(env, clazz, "numArgs", "I");
    gSQLiteCustomFunctionClassInfo.dispatchCallback = GetMethodIDOrDie(env, clazz,
            "dispatchCallback", "([Ljava/lang/String;)V");
    jclass unaryClazz = FindClassOrDie(env, "java/util/function/UnaryOperator");
    gUnaryOperator.apply = GetMethodIDOrDie(env, unaryClazz,
            "apply", "(Ljava/lang/Object;)Ljava/lang/Object;");

    clazz = FindClassOrDie(env, "java/lang/String");
    gStringClassInfo.clazz = MakeGlobalRefOrDie(env, clazz);
    jclass binaryClazz = FindClassOrDie(env, "java/util/function/BinaryOperator");
    gBinaryOperator.apply = GetMethodIDOrDie(env, binaryClazz,
            "apply", "(Ljava/lang/Object;Ljava/lang/Object;)Ljava/lang/Object;");

    return RegisterMethodsOrDie(env, "android/database/sqlite/SQLiteConnection", sMethods,
                                NELEM(sMethods));