Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 108ac3af authored by Jeff Sharkey's avatar Jeff Sharkey Committed by Android (Google) Code Review
Browse files

Merge "Add custom scalar/aggregate functions to SQLite."

parents 805f3c1d 03475d9a
Loading
Loading
Loading
Loading
+2 −0
Original line number Diff line number Diff line
@@ -13251,6 +13251,8 @@ package android.database.sqlite {
    method public static int releaseMemory();
    method public long replace(String, String, android.content.ContentValues);
    method public long replaceOrThrow(String, String, android.content.ContentValues) throws android.database.SQLException;
    method public void setCustomAggregateFunction(@NonNull String, @NonNull java.util.function.BinaryOperator<java.lang.String>) throws android.database.sqlite.SQLiteException;
    method public void setCustomScalarFunction(@NonNull String, @NonNull java.util.function.UnaryOperator<java.lang.String>) throws android.database.sqlite.SQLiteException;
    method public void setForeignKeyConstraintsEnabled(boolean);
    method public void setLocale(java.util.Locale);
    method @Deprecated public void setLockingEnabled(boolean);
+27 −23
Original line number Diff line number Diff line
@@ -39,6 +39,8 @@ import java.text.SimpleDateFormat;
import java.util.ArrayList;
import java.util.Date;
import java.util.Map;
import java.util.function.BinaryOperator;
import java.util.function.UnaryOperator;

/**
 * Represents a SQLite database connection.
@@ -123,8 +125,10 @@ public final class SQLiteConnection implements CancellationSignal.OnCancelListen
            boolean enableTrace, boolean enableProfile, int lookasideSlotSize,
            int lookasideSlotCount);
    private static native void nativeClose(long connectionPtr);
    private static native void nativeRegisterCustomFunction(long connectionPtr,
            SQLiteCustomFunction function);
    private static native void nativeRegisterCustomScalarFunction(long connectionPtr,
            String name, UnaryOperator<String> function);
    private static native void nativeRegisterCustomAggregateFunction(long connectionPtr,
            String name, BinaryOperator<String> function);
    private static native void nativeRegisterLocalizedCollators(long connectionPtr, String locale);
    private static native long nativePrepareStatement(long connectionPtr, String sql);
    private static native void nativeFinalizeStatement(long connectionPtr, long statementPtr);
@@ -225,13 +229,7 @@ public final class SQLiteConnection implements CancellationSignal.OnCancelListen
        setJournalSizeLimit();
        setAutoCheckpointInterval();
        setLocaleFromConfiguration();

        // Register custom functions.
        final int functionCount = mConfiguration.customFunctions.size();
        for (int i = 0; i < functionCount; i++) {
            SQLiteCustomFunction function = mConfiguration.customFunctions.get(i);
            nativeRegisterCustomFunction(mConnectionPtr, function);
        }
        setCustomFunctionsFromConfiguration();
    }

    private void dispose(boolean finalized) {
@@ -457,6 +455,19 @@ public final class SQLiteConnection implements CancellationSignal.OnCancelListen
        }
    }

    private void setCustomFunctionsFromConfiguration() {
        for (int i = 0; i < mConfiguration.customScalarFunctions.size(); i++) {
            nativeRegisterCustomScalarFunction(mConnectionPtr,
                    mConfiguration.customScalarFunctions.keyAt(i),
                    mConfiguration.customScalarFunctions.valueAt(i));
        }
        for (int i = 0; i < mConfiguration.customAggregateFunctions.size(); i++) {
            nativeRegisterCustomAggregateFunction(mConnectionPtr,
                    mConfiguration.customAggregateFunctions.keyAt(i),
                    mConfiguration.customAggregateFunctions.valueAt(i));
        }
    }

    private void checkDatabaseWiped() {
        if (!SQLiteGlobal.checkDbWipe()) {
            return;
@@ -491,15 +502,6 @@ public final class SQLiteConnection implements CancellationSignal.OnCancelListen
    void reconfigure(SQLiteDatabaseConfiguration configuration) {
        mOnlyAllowReadOnlyOperations = false;

        // Register custom functions.
        final int functionCount = configuration.customFunctions.size();
        for (int i = 0; i < functionCount; i++) {
            SQLiteCustomFunction function = configuration.customFunctions.get(i);
            if (!mConfiguration.customFunctions.contains(function)) {
                nativeRegisterCustomFunction(mConnectionPtr, function);
            }
        }

        // Remember what changed.
        boolean foreignKeyModeChanged = configuration.foreignKeyConstraintsEnabled
                != mConfiguration.foreignKeyConstraintsEnabled;
@@ -507,6 +509,10 @@ public final class SQLiteConnection implements CancellationSignal.OnCancelListen
                & (SQLiteDatabase.ENABLE_WRITE_AHEAD_LOGGING
                | SQLiteDatabase.ENABLE_LEGACY_COMPATIBILITY_WAL)) != 0;
        boolean localeChanged = !configuration.locale.equals(mConfiguration.locale);
        boolean customScalarFunctionsChanged = !configuration.customScalarFunctions
                .equals(mConfiguration.customScalarFunctions);
        boolean customAggregateFunctionsChanged = !configuration.customAggregateFunctions
                .equals(mConfiguration.customAggregateFunctions);

        // Update configuration parameters.
        mConfiguration.updateParametersFrom(configuration);
@@ -514,20 +520,18 @@ public final class SQLiteConnection implements CancellationSignal.OnCancelListen
        // Update prepared statement cache size.
        mPreparedStatementCache.resize(configuration.maxSqlCacheSize);

        // Update foreign key mode.
        if (foreignKeyModeChanged) {
            setForeignKeyModeFromConfiguration();
        }

        // Update WAL.
        if (walModeChanged) {
            setWalModeFromConfiguration();
        }

        // Update locale.
        if (localeChanged) {
            setLocaleFromConfiguration();
        }
        if (customScalarFunctionsChanged || customAggregateFunctionsChanged) {
            setCustomFunctionsFromConfiguration();
        }
    }

    // Called by SQLiteConnectionPool only.
+76 −13
Original line number Diff line number Diff line
@@ -62,6 +62,8 @@ import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.WeakHashMap;
import java.util.function.BinaryOperator;
import java.util.function.UnaryOperator;

/**
 * Exposes methods to manage a SQLite database.
@@ -958,26 +960,87 @@ public final class SQLiteDatabase extends SQLiteClosable {
    }

    /**
     * Registers a CustomFunction callback as a function that can be called from
     * SQLite database triggers.
     *
     * @param name the name of the sqlite3 function
     * @param numArgs the number of arguments for the function
     * @param function callback to call when the function is executed
     * @hide
     */
    public void addCustomFunction(String name, int numArgs, CustomFunction function) {
        // Create wrapper (also validates arguments).
        SQLiteCustomFunction wrapper = new SQLiteCustomFunction(name, numArgs, function);
     * Register a custom scalar function that can be called from SQL
     * expressions.
     * <p>
     * For example, registering a custom scalar function named {@code REVERSE}
     * could be used in a query like
     * {@code SELECT REVERSE(name) FROM employees}.
     * <p>
     * When attempting to register multiple functions with the same function
     * name, SQLite will replace any previously defined functions with the
     * latest definition, regardless of what function type they are. SQLite does
     * not support unregistering functions.
     *
     * @param functionName Case-insensitive name to register this function
     *            under, limited to 255 UTF-8 bytes in length.
     * @param scalarFunction Functional interface that will be invoked when the
     *            function name is used by a SQL statement. The argument values
     *            from the SQL statement are passed to the functional interface,
     *            and the return values from the functional interface are
     *            returned back into the SQL statement.
     * @throws SQLiteException if the custom function could not be registered.
     * @see #setCustomAggregateFunction(String, BinaryOperator)
     */
    public void setCustomScalarFunction(@NonNull String functionName,
            @NonNull UnaryOperator<String> scalarFunction) throws SQLiteException {
        Objects.requireNonNull(functionName);
        Objects.requireNonNull(scalarFunction);

        synchronized (mLock) {
            throwIfNotOpenLocked();

            mConfigurationLocked.customScalarFunctions.put(functionName, scalarFunction);
            try {
                mConnectionPoolLocked.reconfigure(mConfigurationLocked);
            } catch (RuntimeException ex) {
                mConfigurationLocked.customScalarFunctions.remove(functionName);
                throw ex;
            }
        }
    }

    /**
     * Register a custom aggregate function that can be called from SQL
     * expressions.
     * <p>
     * For example, registering a custom aggregation function named
     * {@code LONGEST} could be used in a query like
     * {@code SELECT LONGEST(name) FROM employees}.
     * <p>
     * The implementation of this method follows the reduction flow outlined in
     * {@link java.util.stream.Stream#reduce(BinaryOperator)}, and the custom
     * aggregation function is expected to be an associative accumulation
     * function, as defined by that class.
     * <p>
     * When attempting to register multiple functions with the same function
     * name, SQLite will replace any previously defined functions with the
     * latest definition, regardless of what function type they are. SQLite does
     * not support unregistering functions.
     *
     * @param functionName Case-insensitive name to register this function
     *            under, limited to 255 UTF-8 bytes in length.
     * @param aggregateFunction Functional interface that will be invoked when
     *            the function name is used by a SQL statement. The argument
     *            values from the SQL statement are passed to the functional
     *            interface, and the return values from the functional interface
     *            are returned back into the SQL statement.
     * @throws SQLiteException if the custom function could not be registered.
     * @see #setCustomScalarFunction(String, UnaryOperator)
     */
    public void setCustomAggregateFunction(@NonNull String functionName,
            @NonNull BinaryOperator<String> aggregateFunction) throws SQLiteException {
        Objects.requireNonNull(functionName);
        Objects.requireNonNull(aggregateFunction);

        synchronized (mLock) {
            throwIfNotOpenLocked();

            mConfigurationLocked.customFunctions.add(wrapper);
            mConfigurationLocked.customAggregateFunctions.put(functionName, aggregateFunction);
            try {
                mConnectionPoolLocked.reconfigure(mConfigurationLocked);
            } catch (RuntimeException ex) {
                mConfigurationLocked.customFunctions.remove(wrapper);
                mConfigurationLocked.customAggregateFunctions.remove(functionName);
                throw ex;
            }
        }
+17 −6
Original line number Diff line number Diff line
@@ -17,9 +17,12 @@
package android.database.sqlite;

import android.compat.annotation.UnsupportedAppUsage;
import android.util.ArrayMap;

import java.util.ArrayList;
import java.util.Locale;
import java.util.Map;
import java.util.function.BinaryOperator;
import java.util.function.UnaryOperator;
import java.util.regex.Pattern;

/**
@@ -87,10 +90,16 @@ public final class SQLiteDatabaseConfiguration {
    public boolean foreignKeyConstraintsEnabled;

    /**
     * The custom functions to register.
     * The custom scalar functions to register.
     */
    public final ArrayList<SQLiteCustomFunction> customFunctions =
            new ArrayList<SQLiteCustomFunction>();
    public final ArrayMap<String, UnaryOperator<String>> customScalarFunctions
            = new ArrayMap<>();

    /**
     * The custom aggregate functions to register.
     */
    public final ArrayMap<String, BinaryOperator<String>> customAggregateFunctions
            = new ArrayMap<>();

    /**
     * The size in bytes of each lookaside slot
@@ -181,8 +190,10 @@ public final class SQLiteDatabaseConfiguration {
        maxSqlCacheSize = other.maxSqlCacheSize;
        locale = other.locale;
        foreignKeyConstraintsEnabled = other.foreignKeyConstraintsEnabled;
        customFunctions.clear();
        customFunctions.addAll(other.customFunctions);
        customScalarFunctions.clear();
        customScalarFunctions.putAll(other.customScalarFunctions);
        customAggregateFunctions.clear();
        customAggregateFunctions.putAll(other.customAggregateFunctions);
        lookasideSlotSize = other.lookasideSlotSize;
        lookasideSlotCount = other.lookasideSlotCount;
        idleConnectionTimeoutMs = other.idleConnectionTimeoutMs;
+131 −61
Original line number Diff line number Diff line
@@ -59,14 +59,12 @@ namespace android {
static const int BUSY_TIMEOUT_MS = 2500;

static struct {
    jfieldID name;
    jfieldID numArgs;
    jmethodID dispatchCallback;
} gSQLiteCustomFunctionClassInfo;
    jmethodID apply;
} gUnaryOperator;

static struct {
    jclass clazz;
} gStringClassInfo;
    jmethodID apply;
} gBinaryOperator;

struct SQLiteConnection {
    // Open flags.
@@ -203,74 +201,146 @@ static void nativeClose(JNIEnv* env, jclass clazz, jlong connectionPtr) {
    }
}

// Called each time a custom function is evaluated.
static void sqliteCustomFunctionCallback(sqlite3_context *context,
static void sqliteCustomScalarFunctionCallback(sqlite3_context *context,
        int argc, sqlite3_value **argv) {
    JNIEnv* env = AndroidRuntime::getJNIEnv();

    // Get the callback function object.
    // Create a new local reference to it in case the callback tries to do something
    // dumb like unregister the function (thereby destroying the global ref) while it is running.
    jobject functionObjGlobal = reinterpret_cast<jobject>(sqlite3_user_data(context));
    jobject functionObj = env->NewLocalRef(functionObjGlobal);

    jobjectArray argsArray = env->NewObjectArray(argc, gStringClassInfo.clazz, NULL);
    if (argsArray) {
        for (int i = 0; i < argc; i++) {
            const jchar* arg = static_cast<const jchar*>(sqlite3_value_text16(argv[i]));
            if (!arg) {
                ALOGW("NULL argument in custom_function_callback.  This should not happen.");
    ScopedLocalRef<jobject> functionObj(env, env->NewLocalRef(functionObjGlobal));
    ScopedLocalRef<jstring> argString(env,
            env->NewStringUTF(reinterpret_cast<const char*>(sqlite3_value_text(argv[0]))));
    ScopedLocalRef<jstring> resString(env,
            (jstring) env->CallObjectMethod(functionObj.get(), gUnaryOperator.apply, argString.get()));

    if (env->ExceptionCheck()) {
        ALOGE("Exception thrown by custom scalar function");
        sqlite3_result_error(context, "Exception thrown by custom scalar function", -1);
        env->ExceptionDescribe();
        env->ExceptionClear();
        return;
    }

    if (resString.get() == nullptr) {
        sqlite3_result_null(context);
    } else {
                size_t argLen = sqlite3_value_bytes16(argv[i]) / sizeof(jchar);
                jstring argStr = env->NewString(arg, argLen);
                if (!argStr) {
                    goto error; // out of memory error
        ScopedUtfChars res(env, resString.get());
        sqlite3_result_text(context, res.c_str(), -1, SQLITE_TRANSIENT);
    }
                env->SetObjectArrayElement(argsArray, i, argStr);
                env->DeleteLocalRef(argStr);
}

static void sqliteCustomScalarFunctionDestructor(void* data) {
    jobject functionObjGlobal = reinterpret_cast<jobject>(data);

    JNIEnv* env = AndroidRuntime::getJNIEnv();
    env->DeleteGlobalRef(functionObjGlobal);
}

        // TODO: Support functions that return values.
        env->CallVoidMethod(functionObj,
                gSQLiteCustomFunctionClassInfo.dispatchCallback, argsArray);
static void nativeRegisterCustomScalarFunction(JNIEnv* env, jclass clazz, jlong connectionPtr,
        jstring functionName, jobject functionObj) {
    SQLiteConnection* connection = reinterpret_cast<SQLiteConnection*>(connectionPtr);

    jobject functionObjGlobal = env->NewGlobalRef(functionObj);
    ScopedUtfChars functionNameChars(env, functionName);
    int err = sqlite3_create_function_v2(connection->db,
            functionNameChars.c_str(), 1, SQLITE_UTF8,
            reinterpret_cast<void*>(functionObjGlobal),
            &sqliteCustomScalarFunctionCallback,
            nullptr,
            nullptr,
            &sqliteCustomScalarFunctionDestructor);

    if (err != SQLITE_OK) {
        ALOGE("sqlite3_create_function returned %d", err);
        env->DeleteGlobalRef(functionObjGlobal);
        throw_sqlite3_exception(env, connection->db);
        return;
    }
}

error:
        env->DeleteLocalRef(argsArray);
static void sqliteCustomAggregateFunctionStep(sqlite3_context *context,
        int argc, sqlite3_value **argv) {
    char** agg = reinterpret_cast<char**>(
            sqlite3_aggregate_context(context, sizeof(const char**)));
    if (agg == nullptr) {
        return;
    } else if (*agg == nullptr) {
        // During our first call the best we can do is allocate our result
        // holder and populate it with our first value; we'll reduce it
        // against any additional values in future calls
        const char* res = reinterpret_cast<const char*>(sqlite3_value_text(argv[0]));
        if (res == nullptr) {
            *agg = nullptr;
        } else {
            *agg = strdup(res);
        }
        return;
    }

    env->DeleteLocalRef(functionObj);
    JNIEnv* env = AndroidRuntime::getJNIEnv();
    jobject functionObjGlobal = reinterpret_cast<jobject>(sqlite3_user_data(context));
    ScopedLocalRef<jobject> functionObj(env, env->NewLocalRef(functionObjGlobal));
    ScopedLocalRef<jstring> arg0String(env,
            env->NewStringUTF(reinterpret_cast<const char*>(*agg)));
    ScopedLocalRef<jstring> arg1String(env,
            env->NewStringUTF(reinterpret_cast<const char*>(sqlite3_value_text(argv[0]))));
    ScopedLocalRef<jstring> resString(env,
            (jstring) env->CallObjectMethod(functionObj.get(), gBinaryOperator.apply,
                    arg0String.get(), arg1String.get()));

    if (env->ExceptionCheck()) {
        ALOGE("An exception was thrown by custom SQLite function.");
        LOGE_EX(env);
        ALOGE("Exception thrown by custom aggregate function");
        sqlite3_result_error(context, "Exception thrown by custom aggregate function", -1);
        env->ExceptionDescribe();
        env->ExceptionClear();
        return;
    }

    // One way or another, we have a new value to collect, and we need to
    // free our previous value
    if (*agg != nullptr) {
        free(*agg);
    }
    if (resString.get() == nullptr) {
        *agg = nullptr;
    } else {
        ScopedUtfChars res(env, resString.get());
        *agg = strdup(res.c_str());
    }
}

// Called when a custom function is destroyed.
static void sqliteCustomFunctionDestructor(void* data) {
static void sqliteCustomAggregateFunctionFinal(sqlite3_context *context) {
    // We pass zero size here to avoid allocating for empty sets
    char** agg = reinterpret_cast<char**>(
            sqlite3_aggregate_context(context, 0));
    if (agg == nullptr) {
        return;
    } else if (*agg == nullptr) {
        sqlite3_result_null(context);
    } else {
        sqlite3_result_text(context, *agg, -1, SQLITE_TRANSIENT);
        free(*agg);
    }
}

static void sqliteCustomAggregateFunctionDestructor(void* data) {
    jobject functionObjGlobal = reinterpret_cast<jobject>(data);

    JNIEnv* env = AndroidRuntime::getJNIEnv();
    env->DeleteGlobalRef(functionObjGlobal);
}

static void nativeRegisterCustomFunction(JNIEnv* env, jclass clazz, jlong connectionPtr,
        jobject functionObj) {
static void nativeRegisterCustomAggregateFunction(JNIEnv* env, jclass clazz, jlong connectionPtr,
        jstring functionName, jobject functionObj) {
    SQLiteConnection* connection = reinterpret_cast<SQLiteConnection*>(connectionPtr);

    jstring nameStr = jstring(env->GetObjectField(
            functionObj, gSQLiteCustomFunctionClassInfo.name));
    jint numArgs = env->GetIntField(functionObj, gSQLiteCustomFunctionClassInfo.numArgs);

    jobject functionObjGlobal = env->NewGlobalRef(functionObj);

    const char* name = env->GetStringUTFChars(nameStr, NULL);
    int err = sqlite3_create_function_v2(connection->db, name, numArgs, SQLITE_UTF16,
    ScopedUtfChars functionNameChars(env, functionName);
    int err = sqlite3_create_function_v2(connection->db,
            functionNameChars.c_str(), 1, SQLITE_UTF8,
            reinterpret_cast<void*>(functionObjGlobal),
            &sqliteCustomFunctionCallback, NULL, NULL, &sqliteCustomFunctionDestructor);
    env->ReleaseStringUTFChars(nameStr, name);
            nullptr,
            &sqliteCustomAggregateFunctionStep,
            &sqliteCustomAggregateFunctionFinal,
            &sqliteCustomAggregateFunctionDestructor);

    if (err != SQLITE_OK) {
        ALOGE("sqlite3_create_function returned %d", err);
@@ -812,8 +882,10 @@ static const JNINativeMethod sMethods[] =
            (void*)nativeOpen },
    { "nativeClose", "(J)V",
            (void*)nativeClose },
    { "nativeRegisterCustomFunction", "(JLandroid/database/sqlite/SQLiteCustomFunction;)V",
            (void*)nativeRegisterCustomFunction },
    { "nativeRegisterCustomScalarFunction", "(JLjava/lang/String;Ljava/util/function/UnaryOperator;)V",
            (void*)nativeRegisterCustomScalarFunction },
    { "nativeRegisterCustomAggregateFunction", "(JLjava/lang/String;Ljava/util/function/BinaryOperator;)V",
            (void*)nativeRegisterCustomAggregateFunction },
    { "nativeRegisterLocalizedCollators", "(JLjava/lang/String;)V",
            (void*)nativeRegisterLocalizedCollators },
    { "nativePrepareStatement", "(JLjava/lang/String;)J",
@@ -864,15 +936,13 @@ static const JNINativeMethod sMethods[] =

int register_android_database_SQLiteConnection(JNIEnv *env)
{
    jclass clazz = FindClassOrDie(env, "android/database/sqlite/SQLiteCustomFunction");

    gSQLiteCustomFunctionClassInfo.name = GetFieldIDOrDie(env, clazz, "name", "Ljava/lang/String;");
    gSQLiteCustomFunctionClassInfo.numArgs = GetFieldIDOrDie(env, clazz, "numArgs", "I");
    gSQLiteCustomFunctionClassInfo.dispatchCallback = GetMethodIDOrDie(env, clazz,
            "dispatchCallback", "([Ljava/lang/String;)V");
    jclass unaryClazz = FindClassOrDie(env, "java/util/function/UnaryOperator");
    gUnaryOperator.apply = GetMethodIDOrDie(env, unaryClazz,
            "apply", "(Ljava/lang/Object;)Ljava/lang/Object;");

    clazz = FindClassOrDie(env, "java/lang/String");
    gStringClassInfo.clazz = MakeGlobalRefOrDie(env, clazz);
    jclass binaryClazz = FindClassOrDie(env, "java/util/function/BinaryOperator");
    gBinaryOperator.apply = GetMethodIDOrDie(env, binaryClazz,
            "apply", "(Ljava/lang/Object;Ljava/lang/Object;)Ljava/lang/Object;");

    return RegisterMethodsOrDie(env, "android/database/sqlite/SQLiteConnection", sMethods,
                                NELEM(sMethods));