Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 6f8d9243 authored by Hawkwood Glazier's avatar Hawkwood Glazier Committed by Android (Google) Code Review
Browse files

Merge "Protect base types by delegating to a proxy implementation" into main

parents 569c9945 aa6c7fb7
Loading
Loading
Loading
Loading
+2 −1
Original line number Diff line number Diff line
@@ -22,13 +22,14 @@ import com.android.compose.animation.scene.ElementKey
import com.android.compose.animation.scene.Key
import com.android.compose.animation.scene.MovableElementContentScope
import com.android.compose.animation.scene.MovableElementKey
import com.android.systemui.plugins.annotations.ProtectedBaseInterface
import com.android.systemui.plugins.annotations.ProtectedInterface
import com.android.systemui.plugins.annotations.SimpleProperty
import com.android.systemui.plugins.annotations.ThrowsOnFailure

/** Element Composable together with some metadata about the function. */
@Stable
@ProtectedInterface
@ProtectedBaseInterface
interface BaseLockscreenElement {
    @get:SimpleProperty
    /** Key of identifying this lockscreen element */
+7 −1
Original line number Diff line number Diff line
@@ -136,6 +136,9 @@ class JavaClassWriter(val writer: TabbedWriter, val className: String) : TabbedW
            }
        ) {
            writer.args()
            if (!writer.isFirstArg) {
                completeLine("")
            }
        }

        braceBlock { writer.contents() }
@@ -154,7 +157,10 @@ class JavaMethodWriter(
    val isVoid: Boolean = returnType == "void"

    var isFirstArg = true
        private set

    var callArgs = StringBuilder()
        private set

    fun arg(name: String, type: String) {
        if (!isFirstArg) {
+84 −28
Original line number Diff line number Diff line
@@ -55,21 +55,37 @@ class ProtectedPluginProcessor : AbstractProcessor() {
    }

    override fun getSupportedAnnotationTypes(): Set<String> =
        setOf("com.android.systemui.plugins.annotations.ProtectedInterface")
        setOf(
            "com.android.systemui.plugins.annotations.ProtectedInterface",
            "com.android.systemui.plugins.annotations.ProtectedBaseInterface",
        )

    private data class TargetData(
        val attribute: TypeElement,
        val sourceType: Element,
        val sourcePkg: String,
        val sourceName: String,
        val outputName: String,
        val exTypeAttr: ProtectedInterface,
        val writeProxyType: Boolean,
    )

    override fun process(annotations: Set<TypeElement>, roundEnv: RoundEnvironment): Boolean {
        val targets = mutableMapOf<String, TargetData>() // keyed by fully-qualified source name
        val genImports = mutableListOf<String>()
        for (attr in annotations) {
            val writeProxy =
                when (attr.simpleName.toString()) {
                    "ProtectedInterface" -> true
                    "ProtectedBaseInterface" -> false
                    else -> {
                        procEnv.messager.printMessage(
                            Kind.ERROR,
                            "${attr.qualifiedName} is not recognized by this processor",
                        )
                        false
                    }
                }

            for (target in roundEnv.getElementsAnnotatedWith(attr)) {
                // Find the target exception types to be used
                var exTypeAttr = target.getAnnotation(ProtectedInterface::class.java)
@@ -82,7 +98,14 @@ class ProtectedPluginProcessor : AbstractProcessor() {
                val pkg = (target.getEnclosingElement() as PackageElement).qualifiedName.toString()
                targets.put(
                    "$target",
                    TargetData(attr, target, pkg, sourceName, outputName, exTypeAttr),
                    TargetData(
                        sourceType = target,
                        sourcePkg = pkg,
                        sourceName = sourceName,
                        outputName = outputName,
                        exTypeAttr = exTypeAttr,
                        writeProxyType = writeProxy,
                    ),
                )

                // This creates excessive imports, but it should be fine
@@ -92,11 +115,12 @@ class ProtectedPluginProcessor : AbstractProcessor() {
        }

        if (targets.size <= 0) return false
        for ((_, sourceType, sourcePkg, sourceName, outputName, exTypeAttr) in targets.values) {
        for (target in targets.values) {
            // Find all methods in this type and all super types to that need to be implemented
            val types = ArrayDeque<TypeMirror>().apply { addLast(sourceType.asType()) }
            val impAttrs = mutableListOf<GeneratedImport>()
            val methods = mutableListOf<ExecutableElement>()
            if (target.writeProxyType) {
                val types = ArrayDeque<TypeMirror>().apply { addLast(target.sourceType.asType()) }
                while (types.size > 0) {
                    val typeMirror = types.removeLast()
                    if (typeMirror.toString() == "java.lang.Object") continue
@@ -115,17 +139,46 @@ class ProtectedPluginProcessor : AbstractProcessor() {
                    impAttrs.addAll(type.getAnnotationsByType(GeneratedImport::class.java))
                    types.addAll(procEnv.typeUtils.directSupertypes(typeMirror))
                }
            }

            val sourceName = target.sourceName
            val outputName = target.outputName
            val file = procEnv.filer.createSourceFile("$outputName")
            JavaFileWriter.writeTo(file.openWriter()) {
                pkg(sourcePkg)
                pkg(target.sourcePkg)
                imports(
                    BASIC_IMPORTS,
                    genImports,
                    exTypeAttr.exTypes.toList(),
                    target.exTypeAttr.exTypes.toList(),
                    impAttrs.map { it.extraImport },
                )

                if (!target.writeProxyType) {
                    cls(outputName) {
                        line("private static final String CLASS = \"$sourceName\";")
                        line("private static final String TAG = \"$outputName\";")
                        constructor(visibility = "private")

                        method(
                            "protect",
                            isStatic = true,
                            returnType = sourceName,
                            args = {
                                arg("src", "$sourceName")
                                arg("listener", "ProtectedPluginListener")
                            },
                        ) {
                            line("$sourceName result = PluginProtector.tryProtect(src, listener);")
                            line("if (result != null)")
                            line("    return result;")
                            line()
                            line("Log.wtf(TAG, \"Failed to protect: \" + src);")
                            line("return src;")
                        }
                    }
                    return@writeTo
                }

                cls(outputName, interfaces = listOf(sourceName, "PluginWrapper<$sourceName>")) {
                    line("private static final String CLASS = \"$sourceName\";")
                    line("private static final String TAG = \"$outputName\";")
@@ -136,13 +189,13 @@ class ProtectedPluginProcessor : AbstractProcessor() {
                        isStatic = true,
                        returnType = outputName,
                        args = {
                            arg("instance", "$sourceName")
                            arg("src", "$sourceName")
                            arg("listener", "ProtectedPluginListener")
                        },
                    ) {
                        line("if (instance instanceof $outputName)")
                        line("    return ($outputName)instance;")
                        line("return new $outputName(instance, listener);")
                        line("if (src instanceof $outputName)")
                        line("    return ($outputName)src;")
                        line("return new $outputName(src, listener);")
                    }

                    // Member Fields
@@ -190,7 +243,7 @@ class ProtectedPluginProcessor : AbstractProcessor() {
                            line("/*")
                        }

                        writeProxyMethodImpl(sourceName, method, exTypeAttr, targets)
                        writeProxyMethodImpl(sourceName, method, target.exTypeAttr, targets)

                        if (skipReason != null) {
                            line("*/")
@@ -223,6 +276,7 @@ class ProtectedPluginProcessor : AbstractProcessor() {
                parenBlock("private static final Map<Class, Factory> sFactories = Map.ofEntries") {
                    var isFirst = true
                    for (target in targets.values) {
                        if (!target.writeProxyType) continue
                        if (!isFirst) completeLine(",")
                        target.apply {
                            startLine("entry($sourceName.class, ")
@@ -346,7 +400,8 @@ class ProtectedPluginProcessor : AbstractProcessor() {
                        targets.containsKey(returnGenericArgs[0]) -> {
                        val listArg = returnGenericArgs[0].substringAfterLast(".")
                        val targetType = targets.get(returnGenericArgs[0])!!.outputName
                        line("$returnType source = $nestedCall;")
                        val listType = returnType.substringBefore("<").substringAfterLast(".")
                        line("$listType<$listArg> source = $nestedCall;")
                        line("ArrayList<$listArg> dest = new ArrayList<$listArg>();")
                        braceBlock("for ($listArg item : source)") {
                            line("dest.add($targetType.protect(item, mListener));")
@@ -446,6 +501,7 @@ class ProtectedPluginProcessor : AbstractProcessor() {
            listOf(
                "android.util.Log",
                "com.android.systemui.plugins.PluginWrapper",
                "com.android.systemui.plugins.PluginProtector",
                "com.android.systemui.plugins.ProtectedPluginListener",
            ) + LIST_TYPES

+10 −9
Original line number Diff line number Diff line
@@ -61,7 +61,6 @@ class TabbedWriterImpl(private val target: BufferedWriter) : TabbedWriter {

    override fun completeLine(str: String) {
        if (!isInProgress) {
            target.newLine()
            target.append("    ".repeat(tabCount))
        }

@@ -90,31 +89,33 @@ class TabbedWriterImpl(private val target: BufferedWriter) : TabbedWriter {
    }

    override fun braceBlock(str: String, write: TabbedWriter.() -> Unit) {
        block(str, " {", "}", true, write)
        block(str, " {", "}", newLine = true, write)
    }

    override fun parenBlock(str: String, write: TabbedWriter.() -> Unit) {
        block(str, "(", ")", false, write)
        block(str, "(", ")", newLine = false, write)
    }

    private fun block(
        str: String,
        start: String,
        end: String,
        newLineForEnd: Boolean,
        newLine: Boolean,
        write: TabbedWriter.() -> Unit,
    ) {
        appendLine(str)
        completeLine(start)
        if (str != "") {
            startLine(str)
        }
        appendLine(start)

        tabCount++
        this.write()
        tabCount--

        if (newLineForEnd) {
            line(end)
        if (newLine) {
            completeLine(end)
        } else {
            startLine(end)
            appendLine(end)
        }
    }
}
+2 −0
Original line number Diff line number Diff line
@@ -15,6 +15,7 @@ package com.android.systemui.plugins;

import android.content.Context;

import com.android.systemui.plugins.annotations.ProtectedBaseInterface;
import com.android.systemui.plugins.annotations.ProtectedReturn;
import com.android.systemui.plugins.annotations.Requires;

@@ -111,6 +112,7 @@ import com.android.systemui.plugins.annotations.Requires;
 * }
 * </pre>
 */
@ProtectedBaseInterface
public interface Plugin {

    /**
Loading