Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 19171c18 authored by TreeHugger Robot's avatar TreeHugger Robot Committed by Android (Google) Code Review
Browse files

Merge "Support final classes with final fields for @Immutable"

parents 0d011bdb 944239c5
Loading
Loading
Loading
Loading
+9 −0
Original line number Diff line number Diff line
@@ -54,5 +54,14 @@ java_test_host {
        "ImmutabilityAnnotationProcessorHostLibrary",
    ],

    // Bundle the source file so it can be loaded into the test compiler
    java_resources: [":ImmutabilityAnnotationJavaSource"],

    test_suites: ["general-tests"],
}

filegroup {
    name: "ImmutabilityAnnotationJavaSource",
    srcs: ["src/android/processor/immutability/Immutable.java"],
    path: "src/android/processor/immutability/",
}
+177 −60
Original line number Diff line number Diff line
@@ -51,6 +51,11 @@ class ImmutabilityProcessor : AbstractProcessor() {
            "java.lang.Short",
            "java.lang.String",
            "java.lang.Void",
            "android.os.Parcelable.Creator",
        )

        private val IGNORED_METHODS = listOf(
            "writeToParcel",
        )
    }

@@ -59,7 +64,7 @@ class ImmutabilityProcessor : AbstractProcessor() {

    private lateinit var ignoredTypes: List<TypeMirror>

    private val seenTypes = mutableSetOf<Type>()
    private val seenTypesByPolicy = mutableMapOf<Set<Immutable.Policy.Exception>, Set<Type>>()

    override fun getSupportedSourceVersion() = SourceVersion.latest()!!

@@ -67,9 +72,9 @@ class ImmutabilityProcessor : AbstractProcessor() {

    override fun init(processingEnv: ProcessingEnvironment) {
        super.init(processingEnv)
        collectionType = processingEnv.erasedType("java.util.Collection")
        mapType = processingEnv.erasedType("java.util.Map")
        ignoredTypes = IGNORED_TYPES.map { processingEnv.elementUtils.getTypeElement(it).asType() }
        collectionType = processingEnv.erasedType("java.util.Collection")!!
        mapType = processingEnv.erasedType("java.util.Map")!!
        ignoredTypes = IGNORED_TYPES.mapNotNull { processingEnv.erasedType(it) }
    }

    override fun process(
@@ -80,72 +85,146 @@ class ImmutabilityProcessor : AbstractProcessor() {
            it.qualifiedName.toString() == IMMUTABLE_ANNOTATION_NAME
        } ?: return false
        roundEnvironment.getElementsAnnotatedWith(Immutable::class.java)
            .forEach { visitClass(emptyList(), seenTypes, it, it as Symbol.TypeSymbol) }
            .forEach {
                visitClass(
                    parentChain = emptyList(),
                    seenTypesByPolicy = seenTypesByPolicy,
                    elementToPrint = it,
                    classType = it as Symbol.TypeSymbol,
                    parentPolicyExceptions = emptySet()
                )
            }
        return true
    }

    /**
     * @return true if any error was encountered at this level or any child level
     */
    private fun visitClass(
        parentChain: List<String>,
        seenTypes: MutableSet<Type>,
        seenTypesByPolicy: MutableMap<Set<Immutable.Policy.Exception>, Set<Type>>,
        elementToPrint: Element,
        classType: Symbol.TypeSymbol,
    ) {
        if (!seenTypes.add(classType.asType())) return
        if (classType.getAnnotation(Immutable.Ignore::class.java) != null) return
        parentPolicyExceptions: Set<Immutable.Policy.Exception>,
    ): Boolean {
        if (classType.getAnnotation(Immutable.Ignore::class.java) != null) return false

        if (classType.getAnnotation(Immutable::class.java) == null) {
            printError(parentChain, elementToPrint,
                MessageUtils.classNotImmutableFailure(classType.simpleName.toString()))
        }
        val policyAnnotation = classType.getAnnotation(Immutable.Policy::class.java)
        val newPolicyExceptions = parentPolicyExceptions + policyAnnotation?.exceptions.orEmpty()

        if (classType.getKind() != ElementKind.INTERFACE) {
            printError(parentChain, elementToPrint, MessageUtils.nonInterfaceClassFailure())
        }
        // If already seen this type with the same policies applied, skip it
        val seenTypes = seenTypesByPolicy[newPolicyExceptions]
        val type = classType.asType()
        if (seenTypes?.contains(type) == true) return false
        seenTypesByPolicy[newPolicyExceptions] = seenTypes.orEmpty() + type

        val allowFinalClassesFinalFields =
            newPolicyExceptions.contains(Immutable.Policy.Exception.FINAL_CLASSES_WITH_FINAL_FIELDS)

        val filteredElements = classType.enclosedElements
            .filterNot(::isIgnored)

        filteredElements
        val hasFieldError = filteredElements
            .filter { it.getKind() == ElementKind.FIELD }
            .forEach {
                if (it.isStatic) {
                    if (!it.isPrivate) {
                        if (!it.modifiers.contains(Modifier.FINAL)) {
                            printError(parentChain, it, MessageUtils.staticNonFinalFailure())
            .fold(false) { anyError, field ->
                if (field.isStatic) {
                    if (!field.isPrivate) {
                        var finalityError = !field.modifiers.contains(Modifier.FINAL)
                        if (finalityError) {
                            printError(parentChain, field, MessageUtils.staticNonFinalFailure())
                        }

                        visitType(parentChain, seenTypes, it, it.type)
                        // Must call visitType first so it doesn't get short circuited by the ||
                        visitType(
                            parentChain = parentChain,
                            seenTypesByPolicy = seenTypesByPolicy,
                            symbol = field,
                            type = field.type,
                            parentPolicyExceptions = parentPolicyExceptions
                        ) || anyError || finalityError
                    }
                    return@fold anyError
                } else {
                    printError(parentChain, it, MessageUtils.memberNotMethodFailure())
                    val isFinal = field.modifiers.contains(Modifier.FINAL)
                    if (!isFinal || !allowFinalClassesFinalFields) {
                        printError(parentChain, field, MessageUtils.memberNotMethodFailure())
                        return@fold true
                    }

                    return@fold anyError
                }
            }

        // Scan inner classes before methods so that any violations isolated to the file prints
        // the error on the class declaration rather than on the method that returns the type.
        // Although it doesn't matter too much either way.
        filteredElements
        val hasClassError = filteredElements
            .filter { it.getKind() == ElementKind.CLASS }
            .map { it as Symbol.ClassSymbol }
            .forEach {
                visitClass(parentChain, seenTypes, it, it)
            .fold(false) { anyError, innerClass ->
                // Must call visitClass first so it doesn't get short circuited by the ||
                visitClass(
                    parentChain,
                    seenTypesByPolicy,
                    innerClass,
                    innerClass,
                    newPolicyExceptions
                ) || anyError
            }

        val newChain = parentChain + "$classType"

        filteredElements
        val hasMethodError = filteredElements
            .filter { it.getKind() == ElementKind.METHOD }
            .map { it as Symbol.MethodSymbol }
            .forEach {
                visitMethod(newChain, seenTypes, it)
            .filterNot { IGNORED_METHODS.contains(it.name.toString()) }
            .fold(false) { anyError, method ->
                // Must call visitMethod first so it doesn't get short circuited by the ||
                visitMethod(newChain, seenTypesByPolicy, method, newPolicyExceptions) || anyError
            }

        val className = classType.simpleName.toString()
        val isRegularClass = classType.getKind() == ElementKind.CLASS

        var anyError = hasFieldError || hasClassError || hasMethodError

        // If final classes are not considered OR there's a non-field failure, also check for
        // interface/@Immutable, assuming the class is malformed
        if ((isRegularClass && !allowFinalClassesFinalFields) || hasMethodError || hasClassError) {
            if (classType.getAnnotation(Immutable::class.java) == null) {
                printError(
                    parentChain,
                    elementToPrint,
                    MessageUtils.classNotImmutableFailure(className)
                )
                anyError = true
            }

            if (classType.getKind() != ElementKind.INTERFACE) {
                printError(parentChain, elementToPrint, MessageUtils.nonInterfaceClassFailure())
                anyError = true
            }
        }

        if (isRegularClass && !anyError && allowFinalClassesFinalFields
            && !classType.modifiers.contains(Modifier.FINAL)
        ) {
            printError(parentChain, elementToPrint, MessageUtils.classNotFinalFailure(className))
            return true
        }

        return anyError
    }

    /**
     * @return true if any error was encountered at this level or any child level
     */
    private fun visitMethod(
        parentChain: List<String>,
        seenTypes: MutableSet<Type>,
        seenTypesByPolicy: MutableMap<Set<Immutable.Policy.Exception>, Set<Type>>,
        method: Symbol.MethodSymbol,
    ) {
        parentPolicyExceptions: Set<Immutable.Policy.Exception>,
    ): Boolean {
        val returnType = method.returnType
        val typeName = returnType.toString()
        when (returnType.kind) {
@@ -164,13 +243,21 @@ class ImmutabilityProcessor : AbstractProcessor() {
            TypeKind.VOID -> {
                if (!method.isConstructor) {
                    printError(parentChain, method, MessageUtils.voidReturnFailure())
                    return true
                }
            }
            TypeKind.ARRAY -> {
                printError(parentChain, method, MessageUtils.arrayFailure())
                return true
            }
            TypeKind.DECLARED -> {
                visitType(parentChain, seenTypes, method, method.returnType)
                return visitType(
                    parentChain,
                    seenTypesByPolicy,
                    method,
                    method.returnType,
                    parentPolicyExceptions
                )
            }
            TypeKind.ERROR,
            TypeKind.TYPEVAR,
@@ -182,54 +269,84 @@ class ImmutabilityProcessor : AbstractProcessor() {
            TypeKind.INTERSECTION,
                // Java 9+
                // TypeKind.MODULE,
            null -> printError(parentChain, method,
                MessageUtils.genericTypeKindFailure(typeName = typeName))
            else -> printError(parentChain, method,
                MessageUtils.genericTypeKindFailure(typeName = typeName))
            null -> {
                printError(
                    parentChain, method,
                    MessageUtils.genericTypeKindFailure(typeName = typeName)
                )
                return true
            }
            else -> {
                printError(
                    parentChain, method,
                    MessageUtils.genericTypeKindFailure(typeName = typeName)
                )
                return true
            }
        }

        return false
    }

    /**
     * @return true if any error was encountered at this level or any child level
     */
    private fun visitType(
        parentChain: List<String>,
        seenTypes: MutableSet<Type>,
        seenTypesByPolicy: MutableMap<Set<Immutable.Policy.Exception>, Set<Type>>,
        symbol: Symbol,
        type: Type,
        parentPolicyExceptions: Set<Immutable.Policy.Exception>,
        nonInterfaceClassFailure: () -> String = { MessageUtils.nonInterfaceReturnFailure() },
    ) {
        if (type.isPrimitive) return
    ): Boolean {
        if (type.isPrimitive) return false
        if (type.isPrimitiveOrVoid) {
            printError(parentChain, symbol, MessageUtils.voidReturnFailure())
            return
            return true
        }

        if (ignoredTypes.any { processingEnv.typeUtils.isSameType(it, type) }) {
            return
        if (ignoredTypes.any { processingEnv.typeUtils.isAssignable(type, it) }) {
            return false
        }

        val policyAnnotation = symbol.getAnnotation(Immutable.Policy::class.java)
        val newPolicyExceptions = parentPolicyExceptions + policyAnnotation?.exceptions.orEmpty()

        // Collection (and Map) types are ignored for the interface check as they have immutability
        // enforced through a runtime exception which must be verified in a separate runtime test
        val isMap = processingEnv.typeUtils.isAssignable(type, mapType)
        if (!processingEnv.typeUtils.isAssignable(type, collectionType) && !isMap) {
            if (type.isInterface) {
                visitClass(parentChain, seenTypes, symbol,
                    processingEnv.typeUtils.asElement(type) as Symbol.TypeSymbol)
            } else {
            if (!type.isInterface && !newPolicyExceptions
                    .contains(Immutable.Policy.Exception.FINAL_CLASSES_WITH_FINAL_FIELDS)
            ) {
                printError(parentChain, symbol, nonInterfaceClassFailure())
                // If the type already isn't an interface, don't scan deeper children
                // to avoid printing an excess amount of errors for a known bad type.
                return
                return true
            } else {
                return visitClass(
                    parentChain, seenTypesByPolicy, symbol,
                    processingEnv.typeUtils.asElement(type) as Symbol.TypeSymbol,
                    newPolicyExceptions,
                )
            }
        }

        var anyError = false

        type.typeArguments.forEachIndexed { index, typeArg ->
            visitType(parentChain, seenTypes, symbol, typeArg) {
                MessageUtils.nonInterfaceReturnFailure(prefix = when {
            val argError =
                visitType(parentChain, seenTypesByPolicy, symbol, typeArg, newPolicyExceptions) {
                    MessageUtils.nonInterfaceReturnFailure(
                        prefix = when {
                            !isMap -> ""
                            index == 0 -> "Key " + typeArg.asElement().simpleName
                            else -> "Value " + typeArg.asElement().simpleName
                }, index = index)
                        }, index = index
                    )
                }
            anyError = anyError || argError
        }

        return anyError
    }

    private fun printError(
@@ -246,7 +363,7 @@ class ImmutabilityProcessor : AbstractProcessor() {
    )

    private fun ProcessingEnvironment.erasedType(typeName: String) =
        typeUtils.erasure(elementUtils.getTypeElement(typeName).asType())
        elementUtils.getTypeElement(typeName)?.asType()?.let(typeUtils::erasure)

    private fun isIgnored(symbol: Symbol) =
        symbol.getAnnotation(Immutable.Ignore::class.java) != null
+23 −0
Original line number Diff line number Diff line
@@ -48,4 +48,27 @@ public @interface Immutable {
    @interface Ignore {
        String reason() default "";
    }

    /**
     * Marks an element and its reachable children with a specific policy.
     */
    @Retention(RetentionPolicy.CLASS)
    @Target({ElementType.TYPE, ElementType.FIELD, ElementType.METHOD})
    @interface Policy {
        Exception[] exceptions() default {};

        enum Exception {
            /**
             * Allow final classes with only final fields. By default these are not allowed because
             * direct field access disallows hard removal of APIs (by having their getters return
             * mocks/stubs) and also prevents field compaction, which can occur with booleans
             * stuffed into a number as flags.
             *
             * This exception is allowed though because several framework classes are built around
             * the final field access model and it would be unnecessarily difficult to migrate or
             * wrap each type.
             */
            FINAL_CLASSES_WITH_FINAL_FIELDS,
        }
    }
}
+2 −0
Original line number Diff line number Diff line
@@ -20,6 +20,8 @@ object MessageUtils {

    fun classNotImmutableFailure(className: String) = "$className should be marked @Immutable"

    fun classNotFinalFailure(className: String) = "$className should be marked final"

    fun memberNotMethodFailure() = "Member must be a method"

    fun nonInterfaceClassFailure() = "Class was not an interface"
+123 −30
Original line number Diff line number Diff line
@@ -25,6 +25,7 @@ import com.google.testing.compile.Compiler.javac
import com.google.testing.compile.JavaFileObjects
import org.junit.Rule
import org.junit.Test
import java.util.*
import javax.tools.JavaFileObject

class ImmutabilityProcessorTest {
@@ -32,20 +33,56 @@ class ImmutabilityProcessorTest {
    companion object {
        private const val PACKAGE_PREFIX = "android.processor.immutability"
        private const val DATA_CLASS_NAME = "DataClass"
        private val ANNOTATION = JavaFileObjects.forSourceString(IMMUTABLE_ANNOTATION_NAME,
        private val ANNOTATION = JavaFileObjects.forResource("Immutable.java")

        private val FINAL_CLASSES = listOf(
            JavaFileObjects.forSourceString(
                "$PACKAGE_PREFIX.NonFinalClassFinalFields",
                /* language=JAVA */ """
                    package $PACKAGE_PREFIX;

                    public class NonFinalClassFinalFields {
                        private final String finalField;
                        public NonFinalClassFinalFields(String value) {
                            this.finalField = value;
                        }
                    }
                """.trimIndent()
            ),
            JavaFileObjects.forSourceString(
                "$PACKAGE_PREFIX.NonFinalClassNonFinalFields",
                /* language=JAVA */ """
                    package $PACKAGE_PREFIX;

                    public class NonFinalClassNonFinalFields {
                        private String nonFinalField;
                    }
                """.trimIndent()
            ),
            JavaFileObjects.forSourceString(
                "$PACKAGE_PREFIX.FinalClassFinalFields",
                /* language=JAVA */ """
                    package $PACKAGE_PREFIX;

                import java.lang.annotation.Retention;
                import java.lang.annotation.RetentionPolicy;
                    public final class FinalClassFinalFields {
                        private final String finalField;
                        public FinalClassFinalFields(String value) {
                            this.finalField = value;
                        }
                    }
                """.trimIndent()
            ),
            JavaFileObjects.forSourceString(
                "$PACKAGE_PREFIX.FinalClassNonFinalFields",
                /* language=JAVA */ """
                    package $PACKAGE_PREFIX;

                @Retention(RetentionPolicy.SOURCE)
                public @interface Immutable {
                    @Retention(RetentionPolicy.SOURCE)
                    @interface Ignore {}
                    public final class FinalClassNonFinalFields {
                        private String nonFinalField;
                    }
                """.trimIndent()
            )
        )
    }

    @get:Rule
@@ -53,7 +90,8 @@ class ImmutabilityProcessorTest {

    @Test
    fun validInterface() = test(
        JavaFileObjects.forSourceString("$PACKAGE_PREFIX.$DATA_CLASS_NAME",
        JavaFileObjects.forSourceString(
            "$PACKAGE_PREFIX.$DATA_CLASS_NAME",
            /* language=JAVA */ """
                package $PACKAGE_PREFIX;

@@ -86,11 +124,13 @@ class ImmutabilityProcessorTest {
                    }
                }
                """.trimIndent()
        ), errors = emptyList())
        ), errors = emptyList()
    )

    @Test
    fun abstractClass() = test(
        JavaFileObjects.forSourceString("$PACKAGE_PREFIX.$DATA_CLASS_NAME",
        JavaFileObjects.forSourceString(
            "$PACKAGE_PREFIX.$DATA_CLASS_NAME",
            /* language=JAVA */ """
                package $PACKAGE_PREFIX;

@@ -140,30 +180,77 @@ class ImmutabilityProcessorTest {
            arrayFailure(line = 17),
            nonInterfaceReturnFailure(line = 18),
            nonInterfaceReturnFailure(line = 19),
            classNotImmutableFailure(line = 22, className = "InnerInterface"),
            nonInterfaceReturnFailure(line = 25, prefix = "Key InnerClass"),
            nonInterfaceReturnFailure(line = 25, prefix = "Value InnerClass"),
            classNotImmutableFailure(line = 27, className = "InnerClass"),
            nonInterfaceClassFailure(line = 27),
            memberNotMethodFailure(line = 28),
            arrayFailure(line = 29),
            classNotImmutableFailure(line = 22, className = "InnerInterface"),
            arrayFailure(line = 33),
            nonInterfaceReturnFailure(line = 34),
        ))
        )
    )

    @Test
    fun finalClasses() = test(
        JavaFileObjects.forSourceString(
            "$PACKAGE_PREFIX.$DATA_CLASS_NAME",
            /* language=JAVA */ """
            package $PACKAGE_PREFIX;

            import java.util.List;

            @Immutable
            public interface $DATA_CLASS_NAME {
                NonFinalClassFinalFields getNonFinalFinal();
                List<NonFinalClassNonFinalFields> getNonFinalNonFinal();
                FinalClassFinalFields getFinalFinal();
                List<FinalClassNonFinalFields> getFinalNonFinal();

                @Immutable.Policy(exceptions = {Immutable.Policy.Exception.FINAL_CLASSES_WITH_FINAL_FIELDS})
                NonFinalClassFinalFields getPolicyNonFinalFinal();

    private fun test(source: JavaFileObject, errors: List<CompilationError>) {
                @Immutable.Policy(exceptions = {Immutable.Policy.Exception.FINAL_CLASSES_WITH_FINAL_FIELDS})
                List<NonFinalClassNonFinalFields> getPolicyNonFinalNonFinal();

                @Immutable.Policy(exceptions = {Immutable.Policy.Exception.FINAL_CLASSES_WITH_FINAL_FIELDS})
                FinalClassFinalFields getPolicyFinalFinal();

                @Immutable.Policy(exceptions = {Immutable.Policy.Exception.FINAL_CLASSES_WITH_FINAL_FIELDS})
                List<FinalClassNonFinalFields> getPolicyFinalNonFinal();
            }
            """.trimIndent()
        ), errors = listOf(
            nonInterfaceReturnFailure(line = 7),
            nonInterfaceReturnFailure(line = 8, index = 0),
            nonInterfaceReturnFailure(line = 9),
            nonInterfaceReturnFailure(line = 10, index = 0),
            classNotFinalFailure(line = 13, "NonFinalClassFinalFields"),
        ), otherErrors = listOf(
            memberNotMethodFailure(line = 4) to FINAL_CLASSES[1],
            memberNotMethodFailure(line = 4) to FINAL_CLASSES[3],
        )
    )

    private fun test(
        source: JavaFileObject,
        errors: List<CompilationError>,
        otherErrors: List<Pair<CompilationError, JavaFileObject>> = emptyList(),
    ) {
        val compilation = javac()
            .withProcessors(ImmutabilityProcessor())
            .compile(listOf(source) + ANNOTATION)
        errors.forEach {
            .compile(FINAL_CLASSES + ANNOTATION + listOf(source))
        val allErrors = otherErrors + errors.map { it to source }
        allErrors.forEach { (error, file) ->
            try {
                assertThat(compilation)
                    .hadErrorContaining(it.message)
                    .inFile(source)
                    .onLine(it.line)
                    .hadErrorContaining(error.message)
                    .inFile(file)
                    .onLine(error.line)
            } catch (e: AssertionError) {
                // Wrap the exception so that the line number is logged
                val wrapped = AssertionError("Expected $it, ${e.message}").apply {
                val wrapped = AssertionError("Expected $error, ${e.message}").apply {
                    stackTrace = e.stackTrace
                }

@@ -175,11 +262,14 @@ class ImmutabilityProcessorTest {
        }

        try {
            assertThat(compilation).hadErrorCount(errors.size)
            assertThat(compilation).hadErrorCount(allErrors.size)
        } catch (e: AssertionError) {
            if (expect.hasFailures()) {
                expect.that(e).isNull()
            } else throw e
            expect.withMessage(
                compilation.errors()
                    .joinToString(separator = "\n") {
                        "${it.lineNumber}: ${it.getMessage(Locale.ENGLISH)?.trim()}"
                    }
            ).that(e).isNull()
        }
    }

@@ -192,7 +282,7 @@ class ImmutabilityProcessorTest {
    private fun nonInterfaceReturnFailure(line: Long) =
        CompilationError(line = line, message = MessageUtils.nonInterfaceReturnFailure())

    private fun nonInterfaceReturnFailure(line: Long, prefix: String, index: Int = -1) =
    private fun nonInterfaceReturnFailure(line: Long, prefix: String = "", index: Int = -1) =
        CompilationError(
            line = line,
            message = MessageUtils.nonInterfaceReturnFailure(prefix = prefix, index = index)
@@ -210,6 +300,9 @@ class ImmutabilityProcessorTest {
    private fun arrayFailure(line: Long) =
        CompilationError(line = line, message = MessageUtils.arrayFailure())

    private fun classNotFinalFailure(line: Long, className: String) =
        CompilationError(line = line, message = MessageUtils.classNotFinalFailure(className))

    data class CompilationError(
        val line: Long,
        val message: String,