Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 7ca7a30d authored by TreeHugger Robot's avatar TreeHugger Robot Committed by Android (Google) Code Review
Browse files

Merge "Handle super classes for @Immutable"

parents fd8bf04f 3d7578b7
Loading
Loading
Loading
Loading
+58 −16
Original line number Diff line number Diff line
@@ -37,10 +37,11 @@ val IMMUTABLE_ANNOTATION_NAME = Immutable::class.qualifiedName
class ImmutabilityProcessor : AbstractProcessor() {

    companion object {

        /**
         * Types that are already immutable.
         * Types that are already immutable. Will also ignore subclasses.
         */
        private val IGNORED_TYPES = listOf(
        private val IGNORED_SUPER_TYPES = listOf(
            "java.io.File",
            "java.lang.Boolean",
            "java.lang.Byte",
@@ -56,6 +57,15 @@ class ImmutabilityProcessor : AbstractProcessor() {
            "android.os.Parcelable.Creator",
        )

        /**
         * Types that are already immutable. Must be an exact match, does not include any super
         * or sub classes.
         */
        private val IGNORED_EXACT_TYPES = listOf(
            "java.lang.Class",
            "java.lang.Object",
        )

        private val IGNORED_METHODS = listOf(
            "writeToParcel",
        )
@@ -64,7 +74,8 @@ class ImmutabilityProcessor : AbstractProcessor() {
    private lateinit var collectionType: TypeMirror
    private lateinit var mapType: TypeMirror

    private lateinit var ignoredTypes: List<TypeMirror>
    private lateinit var ignoredSuperTypes: List<TypeMirror>
    private lateinit var ignoredExactTypes: List<TypeMirror>

    private val seenTypesByPolicy = mutableMapOf<Set<Immutable.Policy.Exception>, Set<Type>>()

@@ -76,7 +87,8 @@ class ImmutabilityProcessor : AbstractProcessor() {
        super.init(processingEnv)
        collectionType = processingEnv.erasedType("java.util.Collection")!!
        mapType = processingEnv.erasedType("java.util.Map")!!
        ignoredTypes = IGNORED_TYPES.mapNotNull { processingEnv.erasedType(it) }
        ignoredSuperTypes = IGNORED_SUPER_TYPES.mapNotNull { processingEnv.erasedType(it) }
        ignoredExactTypes = IGNORED_EXACT_TYPES.mapNotNull { processingEnv.erasedType(it) }
    }

    override fun process(
@@ -109,7 +121,7 @@ class ImmutabilityProcessor : AbstractProcessor() {
        classType: Symbol.TypeSymbol,
        parentPolicyExceptions: Set<Immutable.Policy.Exception>,
    ): Boolean {
        if (classType.getAnnotation(Immutable.Ignore::class.java) != null) return false
        if (isIgnored(classType)) return false

        val policyAnnotation = classType.getAnnotation(Immutable.Policy::class.java)
        val newPolicyExceptions = parentPolicyExceptions + policyAnnotation?.exceptions.orEmpty()
@@ -131,7 +143,7 @@ class ImmutabilityProcessor : AbstractProcessor() {
            .fold(false) { anyError, field ->
                if (field.isStatic) {
                    if (!field.isPrivate) {
                        var finalityError = !field.modifiers.contains(Modifier.FINAL)
                        val finalityError = !field.modifiers.contains(Modifier.FINAL)
                        if (finalityError) {
                            printError(parentChain, field, MessageUtils.staticNonFinalFailure())
                        }
@@ -177,8 +189,10 @@ class ImmutabilityProcessor : AbstractProcessor() {
        val newChain = parentChain + "$classType"

        val hasMethodError = filteredElements
            .asSequence()
            .filter { it.getKind() == ElementKind.METHOD }
            .map { it as Symbol.MethodSymbol }
            .filterNot { it.isStatic }
            .filterNot { IGNORED_METHODS.contains(it.name.toString()) }
            .fold(false) { anyError, method ->
                // Must call visitMethod first so it doesn't get short circuited by the ||
@@ -208,6 +222,14 @@ class ImmutabilityProcessor : AbstractProcessor() {
            }
        }

        // Check all of the super classes, since methods in those classes are also accessible
        (classType as? Symbol.ClassSymbol)?.run {
            (interfaces + superclass).forEach {
                val element = it.asElement() ?: return@forEach
                visitClass(parentChain, seenTypesByPolicy, element, element, newPolicyExceptions)
            }
        }

        if (isRegularClass && !anyError && allowFinalClassesFinalFields &&
            !classType.modifiers.contains(Modifier.FINAL)
        ) {
@@ -301,16 +323,14 @@ class ImmutabilityProcessor : AbstractProcessor() {
        parentPolicyExceptions: Set<Immutable.Policy.Exception>,
        nonInterfaceClassFailure: () -> String = { MessageUtils.nonInterfaceReturnFailure() },
    ): Boolean {
        if (isIgnored(symbol)) return false
        if (isIgnored(type)) return false
        if (type.isPrimitive) return false
        if (type.isPrimitiveOrVoid) {
            printError(parentChain, symbol, MessageUtils.voidReturnFailure())
            return true
        }

        if (ignoredTypes.any { processingEnv.typeUtils.isAssignable(type, it) }) {
            return false
        }

        val policyAnnotation = symbol.getAnnotation(Immutable.Policy::class.java)
        val newPolicyExceptions = parentPolicyExceptions + policyAnnotation?.exceptions.orEmpty()

@@ -357,16 +377,38 @@ class ImmutabilityProcessor : AbstractProcessor() {
        message: String,
    ) = processingEnv.messager.printMessage(
        Diagnostic.Kind.ERROR,
        // Drop one from the parent chain so that the directly enclosing class isn't logged.
        // It exists in the list at this point in the traversal so that further children can
        // include the right reference.
        parentChain.dropLast(1).joinToString() + "\n\t" + message,
        parentChain.plus(element.simpleName).joinToString() + "\n\t " + message,
        element,
    )

    private fun ProcessingEnvironment.erasedType(typeName: String) =
        elementUtils.getTypeElement(typeName)?.asType()?.let(typeUtils::erasure)

    private fun isIgnored(symbol: Symbol) =
        symbol.getAnnotation(Immutable.Ignore::class.java) != null
    private fun isIgnored(type: Type) =
        (type.getAnnotation(Immutable.Ignore::class.java) != null)
                || (ignoredSuperTypes.any { type.isAssignable(it) })
                || (ignoredExactTypes.any { type.isSameType(it) })

    private fun isIgnored(symbol: Symbol) = when {
        // Anything annotated as @Ignore is always ignored
        symbol.getAnnotation(Immutable.Ignore::class.java) != null -> true
        // Then ignore exact types, regardless of what kind they are
        ignoredExactTypes.any { symbol.type.isSameType(it) } -> true
        // Then only allow methods through, since other types (fields) are usually a failure
        symbol.getKind() != ElementKind.METHOD -> false
        // Finally, check for any ignored super types
        else -> ignoredSuperTypes.any { symbol.type.isAssignable(it) }
    }

    private fun TypeMirror.isAssignable(type: TypeMirror) = try {
        processingEnv.typeUtils.isAssignable(this, type)
    } catch (ignored: Exception) {
        false
    }

    private fun TypeMirror.isSameType(type: TypeMirror) = try {
        processingEnv.typeUtils.isSameType(this, type)
    } catch (ignored: Exception) {
        false
    }
}
+92 −27
Original line number Diff line number Diff line
@@ -90,7 +90,7 @@ class ImmutabilityProcessorTest {

    @Test
    fun validInterface() = test(
        JavaFileObjects.forSourceString(
        source = JavaFileObjects.forSourceString(
            "$PACKAGE_PREFIX.$DATA_CLASS_NAME",
            /* language=JAVA */ """
                package $PACKAGE_PREFIX;
@@ -227,22 +227,85 @@ class ImmutabilityProcessorTest {
            nonInterfaceReturnFailure(line = 9),
            nonInterfaceReturnFailure(line = 10, index = 0),
            classNotFinalFailure(line = 13, "NonFinalClassFinalFields"),
        ), otherErrors = listOf(
            memberNotMethodFailure(line = 4) to FINAL_CLASSES[1],
            memberNotMethodFailure(line = 4) to FINAL_CLASSES[3],
        ), otherErrors = mapOf(
            FINAL_CLASSES[1] to listOf(
                memberNotMethodFailure(line = 4),
            ),
            FINAL_CLASSES[3] to listOf(
                memberNotMethodFailure(line = 4),
            ),
        )
    )

    @Test
    fun superClass() {
        val superClass = JavaFileObjects.forSourceString(
            "$PACKAGE_PREFIX.SuperClass",
            /* language=JAVA */ """
            package $PACKAGE_PREFIX;

            import java.util.List;

            public interface SuperClass {
                InnerClass getInnerClassOne();

                final class InnerClass {
                    public String innerField;
                }
            }
            """.trimIndent()
        )

        val dataClass = JavaFileObjects.forSourceString(
            "$PACKAGE_PREFIX.$DATA_CLASS_NAME",
            /* language=JAVA */ """
            package $PACKAGE_PREFIX;

            import java.util.List;

            @Immutable
            public interface $DATA_CLASS_NAME extends SuperClass {
                String[] getArray();
            }
            """.trimIndent()
        )

        test(
            sources = arrayOf(superClass, dataClass),
            fileToErrors = mapOf(
                superClass to listOf(
                    classNotImmutableFailure(line = 5, className = "SuperClass"),
                    nonInterfaceReturnFailure(line = 6),
                    nonInterfaceClassFailure(8),
                    classNotImmutableFailure(line = 8, className = "InnerClass"),
                    memberNotMethodFailure(line = 9),
                ),
                dataClass to listOf(
                    arrayFailure(line = 7),
                )
            )
        )
    }

    private fun test(
        source: JavaFileObject,
        errors: List<CompilationError>,
        otherErrors: List<Pair<CompilationError, JavaFileObject>> = emptyList(),
        otherErrors: Map<JavaFileObject, List<CompilationError>> = emptyMap(),
    ) = test(
        sources = arrayOf(source),
        fileToErrors = otherErrors + (source to errors),
    )

    private fun test(
        vararg sources: JavaFileObject,
        fileToErrors: Map<JavaFileObject, List<CompilationError>> = emptyMap(),
    ) {
        val compilation = javac()
            .withProcessors(ImmutabilityProcessor())
            .compile(FINAL_CLASSES + ANNOTATION + listOf(source))
        val allErrors = otherErrors + errors.map { it to source }
        allErrors.forEach { (error, file) ->
            .compile(FINAL_CLASSES + ANNOTATION + sources)

        fileToErrors.forEach { (file, errors) ->
            errors.forEach { error ->
                try {
                    assertThat(compilation)
                        .hadErrorContaining(error.message)
@@ -260,16 +323,18 @@ class ImmutabilityProcessorTest {
                    expect.that(wrapped).isNull()
                }
            }
        }

        try {
            assertThat(compilation).hadErrorCount(allErrors.size)
        } catch (e: AssertionError) {
        expect.that(compilation.errors().size).isEqualTo(fileToErrors.values.sumOf { it.size })

        if (expect.hasFailures()) {
            expect.withMessage(
                compilation.errors()
                    .sortedBy { it.lineNumber }
                    .joinToString(separator = "\n") {
                        "${it.lineNumber}: ${it.getMessage(Locale.ENGLISH)?.trim()}"
                    }
            ).that(e).isNull()
            ).fail()
        }
    }