Loading tools/processors/immutability/Android.bp +9 −0 Original line number Diff line number Diff line Loading @@ -54,5 +54,14 @@ java_test_host { "ImmutabilityAnnotationProcessorHostLibrary", ], // Bundle the source file so it can be loaded into the test compiler java_resources: [":ImmutabilityAnnotationJavaSource"], test_suites: ["general-tests"], } filegroup { name: "ImmutabilityAnnotationJavaSource", srcs: ["src/android/processor/immutability/Immutable.java"], path: "src/android/processor/immutability/", } tools/processors/immutability/src/android/processor/immutability/ImmutabilityProcessor.kt +177 −60 Original line number Diff line number Diff line Loading @@ -51,6 +51,11 @@ class ImmutabilityProcessor : AbstractProcessor() { "java.lang.Short", "java.lang.String", "java.lang.Void", "android.os.Parcelable.Creator", ) private val IGNORED_METHODS = listOf( "writeToParcel", ) } Loading @@ -59,7 +64,7 @@ class ImmutabilityProcessor : AbstractProcessor() { private lateinit var ignoredTypes: List<TypeMirror> private val seenTypes = mutableSetOf<Type>() private val seenTypesByPolicy = mutableMapOf<Set<Immutable.Policy.Exception>, Set<Type>>() override fun getSupportedSourceVersion() = SourceVersion.latest()!! Loading @@ -67,9 +72,9 @@ class ImmutabilityProcessor : AbstractProcessor() { override fun init(processingEnv: ProcessingEnvironment) { super.init(processingEnv) collectionType = processingEnv.erasedType("java.util.Collection") mapType = processingEnv.erasedType("java.util.Map") ignoredTypes = IGNORED_TYPES.map { processingEnv.elementUtils.getTypeElement(it).asType() } collectionType = processingEnv.erasedType("java.util.Collection")!! mapType = processingEnv.erasedType("java.util.Map")!! ignoredTypes = IGNORED_TYPES.mapNotNull { processingEnv.erasedType(it) } } override fun process( Loading @@ -80,72 +85,146 @@ class ImmutabilityProcessor : AbstractProcessor() { it.qualifiedName.toString() == IMMUTABLE_ANNOTATION_NAME } ?: return false roundEnvironment.getElementsAnnotatedWith(Immutable::class.java) .forEach { visitClass(emptyList(), seenTypes, it, it as Symbol.TypeSymbol) } .forEach { visitClass( parentChain = emptyList(), seenTypesByPolicy = seenTypesByPolicy, elementToPrint = it, classType = it as Symbol.TypeSymbol, parentPolicyExceptions = emptySet() ) } return true } /** * @return true if any error was encountered at this level or any child level */ private fun visitClass( parentChain: List<String>, seenTypes: MutableSet<Type>, seenTypesByPolicy: MutableMap<Set<Immutable.Policy.Exception>, Set<Type>>, elementToPrint: Element, classType: Symbol.TypeSymbol, ) { if (!seenTypes.add(classType.asType())) return if (classType.getAnnotation(Immutable.Ignore::class.java) != null) return parentPolicyExceptions: Set<Immutable.Policy.Exception>, ): Boolean { if (classType.getAnnotation(Immutable.Ignore::class.java) != null) return false if (classType.getAnnotation(Immutable::class.java) == null) { printError(parentChain, elementToPrint, MessageUtils.classNotImmutableFailure(classType.simpleName.toString())) } val policyAnnotation = classType.getAnnotation(Immutable.Policy::class.java) val newPolicyExceptions = parentPolicyExceptions + policyAnnotation?.exceptions.orEmpty() if (classType.getKind() != ElementKind.INTERFACE) { printError(parentChain, elementToPrint, MessageUtils.nonInterfaceClassFailure()) } // If already seen this type with the same policies applied, skip it val seenTypes = seenTypesByPolicy[newPolicyExceptions] val type = classType.asType() if (seenTypes?.contains(type) == true) return false seenTypesByPolicy[newPolicyExceptions] = seenTypes.orEmpty() + type val allowFinalClassesFinalFields = newPolicyExceptions.contains(Immutable.Policy.Exception.FINAL_CLASSES_WITH_FINAL_FIELDS) val filteredElements = classType.enclosedElements .filterNot(::isIgnored) filteredElements val hasFieldError = filteredElements .filter { it.getKind() == ElementKind.FIELD } .forEach { if (it.isStatic) { if (!it.isPrivate) { if (!it.modifiers.contains(Modifier.FINAL)) { printError(parentChain, it, MessageUtils.staticNonFinalFailure()) .fold(false) { anyError, field -> if (field.isStatic) { if (!field.isPrivate) { var finalityError = !field.modifiers.contains(Modifier.FINAL) if (finalityError) { printError(parentChain, field, MessageUtils.staticNonFinalFailure()) } visitType(parentChain, seenTypes, it, it.type) // Must call visitType first so it doesn't get short circuited by the || visitType( parentChain = parentChain, seenTypesByPolicy = seenTypesByPolicy, symbol = field, type = field.type, parentPolicyExceptions = parentPolicyExceptions ) || anyError || finalityError } return@fold anyError } else { printError(parentChain, it, MessageUtils.memberNotMethodFailure()) val isFinal = field.modifiers.contains(Modifier.FINAL) if (!isFinal || !allowFinalClassesFinalFields) { printError(parentChain, field, MessageUtils.memberNotMethodFailure()) return@fold true } return@fold anyError } } // Scan inner classes before methods so that any violations isolated to the file prints // the error on the class declaration rather than on the method that returns the type. // Although it doesn't matter too much either way. filteredElements val hasClassError = filteredElements .filter { it.getKind() == ElementKind.CLASS } .map { it as Symbol.ClassSymbol } .forEach { visitClass(parentChain, seenTypes, it, it) .fold(false) { anyError, innerClass -> // Must call visitClass first so it doesn't get short circuited by the || visitClass( parentChain, seenTypesByPolicy, innerClass, innerClass, newPolicyExceptions ) || anyError } val newChain = parentChain + "$classType" filteredElements val hasMethodError = filteredElements .filter { it.getKind() == ElementKind.METHOD } .map { it as Symbol.MethodSymbol } .forEach { visitMethod(newChain, seenTypes, it) .filterNot { IGNORED_METHODS.contains(it.name.toString()) } .fold(false) { anyError, method -> // Must call visitMethod first so it doesn't get short circuited by the || visitMethod(newChain, seenTypesByPolicy, method, newPolicyExceptions) || anyError } val className = classType.simpleName.toString() val isRegularClass = classType.getKind() == ElementKind.CLASS var anyError = hasFieldError || hasClassError || hasMethodError // If final classes are not considered OR there's a non-field failure, also check for // interface/@Immutable, assuming the class is malformed if ((isRegularClass && !allowFinalClassesFinalFields) || hasMethodError || hasClassError) { if (classType.getAnnotation(Immutable::class.java) == null) { printError( parentChain, elementToPrint, MessageUtils.classNotImmutableFailure(className) ) anyError = true } if (classType.getKind() != ElementKind.INTERFACE) { printError(parentChain, elementToPrint, MessageUtils.nonInterfaceClassFailure()) anyError = true } } if (isRegularClass && !anyError && allowFinalClassesFinalFields && !classType.modifiers.contains(Modifier.FINAL) ) { printError(parentChain, elementToPrint, MessageUtils.classNotFinalFailure(className)) return true } return anyError } /** * @return true if any error was encountered at this level or any child level */ private fun visitMethod( parentChain: List<String>, seenTypes: MutableSet<Type>, seenTypesByPolicy: MutableMap<Set<Immutable.Policy.Exception>, Set<Type>>, method: Symbol.MethodSymbol, ) { parentPolicyExceptions: Set<Immutable.Policy.Exception>, ): Boolean { val returnType = method.returnType val typeName = returnType.toString() when (returnType.kind) { Loading @@ -164,13 +243,21 @@ class ImmutabilityProcessor : AbstractProcessor() { TypeKind.VOID -> { if (!method.isConstructor) { printError(parentChain, method, MessageUtils.voidReturnFailure()) return true } } TypeKind.ARRAY -> { printError(parentChain, method, MessageUtils.arrayFailure()) return true } TypeKind.DECLARED -> { visitType(parentChain, seenTypes, method, method.returnType) return visitType( parentChain, seenTypesByPolicy, method, method.returnType, parentPolicyExceptions ) } TypeKind.ERROR, TypeKind.TYPEVAR, Loading @@ -182,54 +269,84 @@ class ImmutabilityProcessor : AbstractProcessor() { TypeKind.INTERSECTION, // Java 9+ // TypeKind.MODULE, null -> printError(parentChain, method, MessageUtils.genericTypeKindFailure(typeName = typeName)) else -> printError(parentChain, method, MessageUtils.genericTypeKindFailure(typeName = typeName)) null -> { printError( parentChain, method, MessageUtils.genericTypeKindFailure(typeName = typeName) ) return true } else -> { printError( parentChain, method, MessageUtils.genericTypeKindFailure(typeName = typeName) ) return true } } return false } /** * @return true if any error was encountered at this level or any child level */ private fun visitType( parentChain: List<String>, seenTypes: MutableSet<Type>, seenTypesByPolicy: MutableMap<Set<Immutable.Policy.Exception>, Set<Type>>, symbol: Symbol, type: Type, parentPolicyExceptions: Set<Immutable.Policy.Exception>, nonInterfaceClassFailure: () -> String = { MessageUtils.nonInterfaceReturnFailure() }, ) { if (type.isPrimitive) return ): Boolean { if (type.isPrimitive) return false if (type.isPrimitiveOrVoid) { printError(parentChain, symbol, MessageUtils.voidReturnFailure()) return return true } if (ignoredTypes.any { processingEnv.typeUtils.isSameType(it, type) }) { return if (ignoredTypes.any { processingEnv.typeUtils.isAssignable(type, it) }) { return false } val policyAnnotation = symbol.getAnnotation(Immutable.Policy::class.java) val newPolicyExceptions = parentPolicyExceptions + policyAnnotation?.exceptions.orEmpty() // Collection (and Map) types are ignored for the interface check as they have immutability // enforced through a runtime exception which must be verified in a separate runtime test val isMap = processingEnv.typeUtils.isAssignable(type, mapType) if (!processingEnv.typeUtils.isAssignable(type, collectionType) && !isMap) { if (type.isInterface) { visitClass(parentChain, seenTypes, symbol, processingEnv.typeUtils.asElement(type) as Symbol.TypeSymbol) } else { if (!type.isInterface && !newPolicyExceptions .contains(Immutable.Policy.Exception.FINAL_CLASSES_WITH_FINAL_FIELDS) ) { printError(parentChain, symbol, nonInterfaceClassFailure()) // If the type already isn't an interface, don't scan deeper children // to avoid printing an excess amount of errors for a known bad type. return return true } else { return visitClass( parentChain, seenTypesByPolicy, symbol, processingEnv.typeUtils.asElement(type) as Symbol.TypeSymbol, newPolicyExceptions, ) } } var anyError = false type.typeArguments.forEachIndexed { index, typeArg -> visitType(parentChain, seenTypes, symbol, typeArg) { MessageUtils.nonInterfaceReturnFailure(prefix = when { val argError = visitType(parentChain, seenTypesByPolicy, symbol, typeArg, newPolicyExceptions) { MessageUtils.nonInterfaceReturnFailure( prefix = when { !isMap -> "" index == 0 -> "Key " + typeArg.asElement().simpleName else -> "Value " + typeArg.asElement().simpleName }, index = index) }, index = index ) } anyError = anyError || argError } return anyError } private fun printError( Loading @@ -246,7 +363,7 @@ class ImmutabilityProcessor : AbstractProcessor() { ) private fun ProcessingEnvironment.erasedType(typeName: String) = typeUtils.erasure(elementUtils.getTypeElement(typeName).asType()) elementUtils.getTypeElement(typeName)?.asType()?.let(typeUtils::erasure) private fun isIgnored(symbol: Symbol) = symbol.getAnnotation(Immutable.Ignore::class.java) != null Loading tools/processors/immutability/src/android/processor/immutability/Immutable.java +23 −0 Original line number Diff line number Diff line Loading @@ -48,4 +48,27 @@ public @interface Immutable { @interface Ignore { String reason() default ""; } /** * Marks an element and its reachable children with a specific policy. */ @Retention(RetentionPolicy.CLASS) @Target({ElementType.TYPE, ElementType.FIELD, ElementType.METHOD}) @interface Policy { Exception[] exceptions() default {}; enum Exception { /** * Allow final classes with only final fields. By default these are not allowed because * direct field access disallows hard removal of APIs (by having their getters return * mocks/stubs) and also prevents field compaction, which can occur with booleans * stuffed into a number as flags. * * This exception is allowed though because several framework classes are built around * the final field access model and it would be unnecessarily difficult to migrate or * wrap each type. */ FINAL_CLASSES_WITH_FINAL_FIELDS, } } } tools/processors/immutability/src/android/processor/immutability/MessageUtils.kt +2 −0 Original line number Diff line number Diff line Loading @@ -20,6 +20,8 @@ object MessageUtils { fun classNotImmutableFailure(className: String) = "$className should be marked @Immutable" fun classNotFinalFailure(className: String) = "$className should be marked final" fun memberNotMethodFailure() = "Member must be a method" fun nonInterfaceClassFailure() = "Class was not an interface" Loading tools/processors/immutability/test/android/processor/ImmutabilityProcessorTest.kt +123 −30 Original line number Diff line number Diff line Loading @@ -25,6 +25,7 @@ import com.google.testing.compile.Compiler.javac import com.google.testing.compile.JavaFileObjects import org.junit.Rule import org.junit.Test import java.util.* import javax.tools.JavaFileObject class ImmutabilityProcessorTest { Loading @@ -32,20 +33,56 @@ class ImmutabilityProcessorTest { companion object { private const val PACKAGE_PREFIX = "android.processor.immutability" private const val DATA_CLASS_NAME = "DataClass" private val ANNOTATION = JavaFileObjects.forSourceString(IMMUTABLE_ANNOTATION_NAME, private val ANNOTATION = JavaFileObjects.forResource("Immutable.java") private val FINAL_CLASSES = listOf( JavaFileObjects.forSourceString( "$PACKAGE_PREFIX.NonFinalClassFinalFields", /* language=JAVA */ """ package $PACKAGE_PREFIX; public class NonFinalClassFinalFields { private final String finalField; public NonFinalClassFinalFields(String value) { this.finalField = value; } } """.trimIndent() ), JavaFileObjects.forSourceString( "$PACKAGE_PREFIX.NonFinalClassNonFinalFields", /* language=JAVA */ """ package $PACKAGE_PREFIX; public class NonFinalClassNonFinalFields { private String nonFinalField; } """.trimIndent() ), JavaFileObjects.forSourceString( "$PACKAGE_PREFIX.FinalClassFinalFields", /* language=JAVA */ """ package $PACKAGE_PREFIX; import java.lang.annotation.Retention; import java.lang.annotation.RetentionPolicy; public final class FinalClassFinalFields { private final String finalField; public FinalClassFinalFields(String value) { this.finalField = value; } } """.trimIndent() ), JavaFileObjects.forSourceString( "$PACKAGE_PREFIX.FinalClassNonFinalFields", /* language=JAVA */ """ package $PACKAGE_PREFIX; @Retention(RetentionPolicy.SOURCE) public @interface Immutable { @Retention(RetentionPolicy.SOURCE) @interface Ignore {} public final class FinalClassNonFinalFields { private String nonFinalField; } """.trimIndent() ) ) } @get:Rule Loading @@ -53,7 +90,8 @@ class ImmutabilityProcessorTest { @Test fun validInterface() = test( JavaFileObjects.forSourceString("$PACKAGE_PREFIX.$DATA_CLASS_NAME", JavaFileObjects.forSourceString( "$PACKAGE_PREFIX.$DATA_CLASS_NAME", /* language=JAVA */ """ package $PACKAGE_PREFIX; Loading Loading @@ -86,11 +124,13 @@ class ImmutabilityProcessorTest { } } """.trimIndent() ), errors = emptyList()) ), errors = emptyList() ) @Test fun abstractClass() = test( JavaFileObjects.forSourceString("$PACKAGE_PREFIX.$DATA_CLASS_NAME", JavaFileObjects.forSourceString( "$PACKAGE_PREFIX.$DATA_CLASS_NAME", /* language=JAVA */ """ package $PACKAGE_PREFIX; Loading Loading @@ -140,30 +180,77 @@ class ImmutabilityProcessorTest { arrayFailure(line = 17), nonInterfaceReturnFailure(line = 18), nonInterfaceReturnFailure(line = 19), classNotImmutableFailure(line = 22, className = "InnerInterface"), nonInterfaceReturnFailure(line = 25, prefix = "Key InnerClass"), nonInterfaceReturnFailure(line = 25, prefix = "Value InnerClass"), classNotImmutableFailure(line = 27, className = "InnerClass"), nonInterfaceClassFailure(line = 27), memberNotMethodFailure(line = 28), arrayFailure(line = 29), classNotImmutableFailure(line = 22, className = "InnerInterface"), arrayFailure(line = 33), nonInterfaceReturnFailure(line = 34), )) ) ) @Test fun finalClasses() = test( JavaFileObjects.forSourceString( "$PACKAGE_PREFIX.$DATA_CLASS_NAME", /* language=JAVA */ """ package $PACKAGE_PREFIX; import java.util.List; @Immutable public interface $DATA_CLASS_NAME { NonFinalClassFinalFields getNonFinalFinal(); List<NonFinalClassNonFinalFields> getNonFinalNonFinal(); FinalClassFinalFields getFinalFinal(); List<FinalClassNonFinalFields> getFinalNonFinal(); @Immutable.Policy(exceptions = {Immutable.Policy.Exception.FINAL_CLASSES_WITH_FINAL_FIELDS}) NonFinalClassFinalFields getPolicyNonFinalFinal(); private fun test(source: JavaFileObject, errors: List<CompilationError>) { @Immutable.Policy(exceptions = {Immutable.Policy.Exception.FINAL_CLASSES_WITH_FINAL_FIELDS}) List<NonFinalClassNonFinalFields> getPolicyNonFinalNonFinal(); @Immutable.Policy(exceptions = {Immutable.Policy.Exception.FINAL_CLASSES_WITH_FINAL_FIELDS}) FinalClassFinalFields getPolicyFinalFinal(); @Immutable.Policy(exceptions = {Immutable.Policy.Exception.FINAL_CLASSES_WITH_FINAL_FIELDS}) List<FinalClassNonFinalFields> getPolicyFinalNonFinal(); } """.trimIndent() ), errors = listOf( nonInterfaceReturnFailure(line = 7), nonInterfaceReturnFailure(line = 8, index = 0), nonInterfaceReturnFailure(line = 9), nonInterfaceReturnFailure(line = 10, index = 0), classNotFinalFailure(line = 13, "NonFinalClassFinalFields"), ), otherErrors = listOf( memberNotMethodFailure(line = 4) to FINAL_CLASSES[1], memberNotMethodFailure(line = 4) to FINAL_CLASSES[3], ) ) private fun test( source: JavaFileObject, errors: List<CompilationError>, otherErrors: List<Pair<CompilationError, JavaFileObject>> = emptyList(), ) { val compilation = javac() .withProcessors(ImmutabilityProcessor()) .compile(listOf(source) + ANNOTATION) errors.forEach { .compile(FINAL_CLASSES + ANNOTATION + listOf(source)) val allErrors = otherErrors + errors.map { it to source } allErrors.forEach { (error, file) -> try { assertThat(compilation) .hadErrorContaining(it.message) .inFile(source) .onLine(it.line) .hadErrorContaining(error.message) .inFile(file) .onLine(error.line) } catch (e: AssertionError) { // Wrap the exception so that the line number is logged val wrapped = AssertionError("Expected $it, ${e.message}").apply { val wrapped = AssertionError("Expected $error, ${e.message}").apply { stackTrace = e.stackTrace } Loading @@ -175,11 +262,14 @@ class ImmutabilityProcessorTest { } try { assertThat(compilation).hadErrorCount(errors.size) assertThat(compilation).hadErrorCount(allErrors.size) } catch (e: AssertionError) { if (expect.hasFailures()) { expect.that(e).isNull() } else throw e expect.withMessage( compilation.errors() .joinToString(separator = "\n") { "${it.lineNumber}: ${it.getMessage(Locale.ENGLISH)?.trim()}" } ).that(e).isNull() } } Loading @@ -192,7 +282,7 @@ class ImmutabilityProcessorTest { private fun nonInterfaceReturnFailure(line: Long) = CompilationError(line = line, message = MessageUtils.nonInterfaceReturnFailure()) private fun nonInterfaceReturnFailure(line: Long, prefix: String, index: Int = -1) = private fun nonInterfaceReturnFailure(line: Long, prefix: String = "", index: Int = -1) = CompilationError( line = line, message = MessageUtils.nonInterfaceReturnFailure(prefix = prefix, index = index) Loading @@ -210,6 +300,9 @@ class ImmutabilityProcessorTest { private fun arrayFailure(line: Long) = CompilationError(line = line, message = MessageUtils.arrayFailure()) private fun classNotFinalFailure(line: Long, className: String) = CompilationError(line = line, message = MessageUtils.classNotFinalFailure(className)) data class CompilationError( val line: Long, val message: String, Loading Loading
tools/processors/immutability/Android.bp +9 −0 Original line number Diff line number Diff line Loading @@ -54,5 +54,14 @@ java_test_host { "ImmutabilityAnnotationProcessorHostLibrary", ], // Bundle the source file so it can be loaded into the test compiler java_resources: [":ImmutabilityAnnotationJavaSource"], test_suites: ["general-tests"], } filegroup { name: "ImmutabilityAnnotationJavaSource", srcs: ["src/android/processor/immutability/Immutable.java"], path: "src/android/processor/immutability/", }
tools/processors/immutability/src/android/processor/immutability/ImmutabilityProcessor.kt +177 −60 Original line number Diff line number Diff line Loading @@ -51,6 +51,11 @@ class ImmutabilityProcessor : AbstractProcessor() { "java.lang.Short", "java.lang.String", "java.lang.Void", "android.os.Parcelable.Creator", ) private val IGNORED_METHODS = listOf( "writeToParcel", ) } Loading @@ -59,7 +64,7 @@ class ImmutabilityProcessor : AbstractProcessor() { private lateinit var ignoredTypes: List<TypeMirror> private val seenTypes = mutableSetOf<Type>() private val seenTypesByPolicy = mutableMapOf<Set<Immutable.Policy.Exception>, Set<Type>>() override fun getSupportedSourceVersion() = SourceVersion.latest()!! Loading @@ -67,9 +72,9 @@ class ImmutabilityProcessor : AbstractProcessor() { override fun init(processingEnv: ProcessingEnvironment) { super.init(processingEnv) collectionType = processingEnv.erasedType("java.util.Collection") mapType = processingEnv.erasedType("java.util.Map") ignoredTypes = IGNORED_TYPES.map { processingEnv.elementUtils.getTypeElement(it).asType() } collectionType = processingEnv.erasedType("java.util.Collection")!! mapType = processingEnv.erasedType("java.util.Map")!! ignoredTypes = IGNORED_TYPES.mapNotNull { processingEnv.erasedType(it) } } override fun process( Loading @@ -80,72 +85,146 @@ class ImmutabilityProcessor : AbstractProcessor() { it.qualifiedName.toString() == IMMUTABLE_ANNOTATION_NAME } ?: return false roundEnvironment.getElementsAnnotatedWith(Immutable::class.java) .forEach { visitClass(emptyList(), seenTypes, it, it as Symbol.TypeSymbol) } .forEach { visitClass( parentChain = emptyList(), seenTypesByPolicy = seenTypesByPolicy, elementToPrint = it, classType = it as Symbol.TypeSymbol, parentPolicyExceptions = emptySet() ) } return true } /** * @return true if any error was encountered at this level or any child level */ private fun visitClass( parentChain: List<String>, seenTypes: MutableSet<Type>, seenTypesByPolicy: MutableMap<Set<Immutable.Policy.Exception>, Set<Type>>, elementToPrint: Element, classType: Symbol.TypeSymbol, ) { if (!seenTypes.add(classType.asType())) return if (classType.getAnnotation(Immutable.Ignore::class.java) != null) return parentPolicyExceptions: Set<Immutable.Policy.Exception>, ): Boolean { if (classType.getAnnotation(Immutable.Ignore::class.java) != null) return false if (classType.getAnnotation(Immutable::class.java) == null) { printError(parentChain, elementToPrint, MessageUtils.classNotImmutableFailure(classType.simpleName.toString())) } val policyAnnotation = classType.getAnnotation(Immutable.Policy::class.java) val newPolicyExceptions = parentPolicyExceptions + policyAnnotation?.exceptions.orEmpty() if (classType.getKind() != ElementKind.INTERFACE) { printError(parentChain, elementToPrint, MessageUtils.nonInterfaceClassFailure()) } // If already seen this type with the same policies applied, skip it val seenTypes = seenTypesByPolicy[newPolicyExceptions] val type = classType.asType() if (seenTypes?.contains(type) == true) return false seenTypesByPolicy[newPolicyExceptions] = seenTypes.orEmpty() + type val allowFinalClassesFinalFields = newPolicyExceptions.contains(Immutable.Policy.Exception.FINAL_CLASSES_WITH_FINAL_FIELDS) val filteredElements = classType.enclosedElements .filterNot(::isIgnored) filteredElements val hasFieldError = filteredElements .filter { it.getKind() == ElementKind.FIELD } .forEach { if (it.isStatic) { if (!it.isPrivate) { if (!it.modifiers.contains(Modifier.FINAL)) { printError(parentChain, it, MessageUtils.staticNonFinalFailure()) .fold(false) { anyError, field -> if (field.isStatic) { if (!field.isPrivate) { var finalityError = !field.modifiers.contains(Modifier.FINAL) if (finalityError) { printError(parentChain, field, MessageUtils.staticNonFinalFailure()) } visitType(parentChain, seenTypes, it, it.type) // Must call visitType first so it doesn't get short circuited by the || visitType( parentChain = parentChain, seenTypesByPolicy = seenTypesByPolicy, symbol = field, type = field.type, parentPolicyExceptions = parentPolicyExceptions ) || anyError || finalityError } return@fold anyError } else { printError(parentChain, it, MessageUtils.memberNotMethodFailure()) val isFinal = field.modifiers.contains(Modifier.FINAL) if (!isFinal || !allowFinalClassesFinalFields) { printError(parentChain, field, MessageUtils.memberNotMethodFailure()) return@fold true } return@fold anyError } } // Scan inner classes before methods so that any violations isolated to the file prints // the error on the class declaration rather than on the method that returns the type. // Although it doesn't matter too much either way. filteredElements val hasClassError = filteredElements .filter { it.getKind() == ElementKind.CLASS } .map { it as Symbol.ClassSymbol } .forEach { visitClass(parentChain, seenTypes, it, it) .fold(false) { anyError, innerClass -> // Must call visitClass first so it doesn't get short circuited by the || visitClass( parentChain, seenTypesByPolicy, innerClass, innerClass, newPolicyExceptions ) || anyError } val newChain = parentChain + "$classType" filteredElements val hasMethodError = filteredElements .filter { it.getKind() == ElementKind.METHOD } .map { it as Symbol.MethodSymbol } .forEach { visitMethod(newChain, seenTypes, it) .filterNot { IGNORED_METHODS.contains(it.name.toString()) } .fold(false) { anyError, method -> // Must call visitMethod first so it doesn't get short circuited by the || visitMethod(newChain, seenTypesByPolicy, method, newPolicyExceptions) || anyError } val className = classType.simpleName.toString() val isRegularClass = classType.getKind() == ElementKind.CLASS var anyError = hasFieldError || hasClassError || hasMethodError // If final classes are not considered OR there's a non-field failure, also check for // interface/@Immutable, assuming the class is malformed if ((isRegularClass && !allowFinalClassesFinalFields) || hasMethodError || hasClassError) { if (classType.getAnnotation(Immutable::class.java) == null) { printError( parentChain, elementToPrint, MessageUtils.classNotImmutableFailure(className) ) anyError = true } if (classType.getKind() != ElementKind.INTERFACE) { printError(parentChain, elementToPrint, MessageUtils.nonInterfaceClassFailure()) anyError = true } } if (isRegularClass && !anyError && allowFinalClassesFinalFields && !classType.modifiers.contains(Modifier.FINAL) ) { printError(parentChain, elementToPrint, MessageUtils.classNotFinalFailure(className)) return true } return anyError } /** * @return true if any error was encountered at this level or any child level */ private fun visitMethod( parentChain: List<String>, seenTypes: MutableSet<Type>, seenTypesByPolicy: MutableMap<Set<Immutable.Policy.Exception>, Set<Type>>, method: Symbol.MethodSymbol, ) { parentPolicyExceptions: Set<Immutable.Policy.Exception>, ): Boolean { val returnType = method.returnType val typeName = returnType.toString() when (returnType.kind) { Loading @@ -164,13 +243,21 @@ class ImmutabilityProcessor : AbstractProcessor() { TypeKind.VOID -> { if (!method.isConstructor) { printError(parentChain, method, MessageUtils.voidReturnFailure()) return true } } TypeKind.ARRAY -> { printError(parentChain, method, MessageUtils.arrayFailure()) return true } TypeKind.DECLARED -> { visitType(parentChain, seenTypes, method, method.returnType) return visitType( parentChain, seenTypesByPolicy, method, method.returnType, parentPolicyExceptions ) } TypeKind.ERROR, TypeKind.TYPEVAR, Loading @@ -182,54 +269,84 @@ class ImmutabilityProcessor : AbstractProcessor() { TypeKind.INTERSECTION, // Java 9+ // TypeKind.MODULE, null -> printError(parentChain, method, MessageUtils.genericTypeKindFailure(typeName = typeName)) else -> printError(parentChain, method, MessageUtils.genericTypeKindFailure(typeName = typeName)) null -> { printError( parentChain, method, MessageUtils.genericTypeKindFailure(typeName = typeName) ) return true } else -> { printError( parentChain, method, MessageUtils.genericTypeKindFailure(typeName = typeName) ) return true } } return false } /** * @return true if any error was encountered at this level or any child level */ private fun visitType( parentChain: List<String>, seenTypes: MutableSet<Type>, seenTypesByPolicy: MutableMap<Set<Immutable.Policy.Exception>, Set<Type>>, symbol: Symbol, type: Type, parentPolicyExceptions: Set<Immutable.Policy.Exception>, nonInterfaceClassFailure: () -> String = { MessageUtils.nonInterfaceReturnFailure() }, ) { if (type.isPrimitive) return ): Boolean { if (type.isPrimitive) return false if (type.isPrimitiveOrVoid) { printError(parentChain, symbol, MessageUtils.voidReturnFailure()) return return true } if (ignoredTypes.any { processingEnv.typeUtils.isSameType(it, type) }) { return if (ignoredTypes.any { processingEnv.typeUtils.isAssignable(type, it) }) { return false } val policyAnnotation = symbol.getAnnotation(Immutable.Policy::class.java) val newPolicyExceptions = parentPolicyExceptions + policyAnnotation?.exceptions.orEmpty() // Collection (and Map) types are ignored for the interface check as they have immutability // enforced through a runtime exception which must be verified in a separate runtime test val isMap = processingEnv.typeUtils.isAssignable(type, mapType) if (!processingEnv.typeUtils.isAssignable(type, collectionType) && !isMap) { if (type.isInterface) { visitClass(parentChain, seenTypes, symbol, processingEnv.typeUtils.asElement(type) as Symbol.TypeSymbol) } else { if (!type.isInterface && !newPolicyExceptions .contains(Immutable.Policy.Exception.FINAL_CLASSES_WITH_FINAL_FIELDS) ) { printError(parentChain, symbol, nonInterfaceClassFailure()) // If the type already isn't an interface, don't scan deeper children // to avoid printing an excess amount of errors for a known bad type. return return true } else { return visitClass( parentChain, seenTypesByPolicy, symbol, processingEnv.typeUtils.asElement(type) as Symbol.TypeSymbol, newPolicyExceptions, ) } } var anyError = false type.typeArguments.forEachIndexed { index, typeArg -> visitType(parentChain, seenTypes, symbol, typeArg) { MessageUtils.nonInterfaceReturnFailure(prefix = when { val argError = visitType(parentChain, seenTypesByPolicy, symbol, typeArg, newPolicyExceptions) { MessageUtils.nonInterfaceReturnFailure( prefix = when { !isMap -> "" index == 0 -> "Key " + typeArg.asElement().simpleName else -> "Value " + typeArg.asElement().simpleName }, index = index) }, index = index ) } anyError = anyError || argError } return anyError } private fun printError( Loading @@ -246,7 +363,7 @@ class ImmutabilityProcessor : AbstractProcessor() { ) private fun ProcessingEnvironment.erasedType(typeName: String) = typeUtils.erasure(elementUtils.getTypeElement(typeName).asType()) elementUtils.getTypeElement(typeName)?.asType()?.let(typeUtils::erasure) private fun isIgnored(symbol: Symbol) = symbol.getAnnotation(Immutable.Ignore::class.java) != null Loading
tools/processors/immutability/src/android/processor/immutability/Immutable.java +23 −0 Original line number Diff line number Diff line Loading @@ -48,4 +48,27 @@ public @interface Immutable { @interface Ignore { String reason() default ""; } /** * Marks an element and its reachable children with a specific policy. */ @Retention(RetentionPolicy.CLASS) @Target({ElementType.TYPE, ElementType.FIELD, ElementType.METHOD}) @interface Policy { Exception[] exceptions() default {}; enum Exception { /** * Allow final classes with only final fields. By default these are not allowed because * direct field access disallows hard removal of APIs (by having their getters return * mocks/stubs) and also prevents field compaction, which can occur with booleans * stuffed into a number as flags. * * This exception is allowed though because several framework classes are built around * the final field access model and it would be unnecessarily difficult to migrate or * wrap each type. */ FINAL_CLASSES_WITH_FINAL_FIELDS, } } }
tools/processors/immutability/src/android/processor/immutability/MessageUtils.kt +2 −0 Original line number Diff line number Diff line Loading @@ -20,6 +20,8 @@ object MessageUtils { fun classNotImmutableFailure(className: String) = "$className should be marked @Immutable" fun classNotFinalFailure(className: String) = "$className should be marked final" fun memberNotMethodFailure() = "Member must be a method" fun nonInterfaceClassFailure() = "Class was not an interface" Loading
tools/processors/immutability/test/android/processor/ImmutabilityProcessorTest.kt +123 −30 Original line number Diff line number Diff line Loading @@ -25,6 +25,7 @@ import com.google.testing.compile.Compiler.javac import com.google.testing.compile.JavaFileObjects import org.junit.Rule import org.junit.Test import java.util.* import javax.tools.JavaFileObject class ImmutabilityProcessorTest { Loading @@ -32,20 +33,56 @@ class ImmutabilityProcessorTest { companion object { private const val PACKAGE_PREFIX = "android.processor.immutability" private const val DATA_CLASS_NAME = "DataClass" private val ANNOTATION = JavaFileObjects.forSourceString(IMMUTABLE_ANNOTATION_NAME, private val ANNOTATION = JavaFileObjects.forResource("Immutable.java") private val FINAL_CLASSES = listOf( JavaFileObjects.forSourceString( "$PACKAGE_PREFIX.NonFinalClassFinalFields", /* language=JAVA */ """ package $PACKAGE_PREFIX; public class NonFinalClassFinalFields { private final String finalField; public NonFinalClassFinalFields(String value) { this.finalField = value; } } """.trimIndent() ), JavaFileObjects.forSourceString( "$PACKAGE_PREFIX.NonFinalClassNonFinalFields", /* language=JAVA */ """ package $PACKAGE_PREFIX; public class NonFinalClassNonFinalFields { private String nonFinalField; } """.trimIndent() ), JavaFileObjects.forSourceString( "$PACKAGE_PREFIX.FinalClassFinalFields", /* language=JAVA */ """ package $PACKAGE_PREFIX; import java.lang.annotation.Retention; import java.lang.annotation.RetentionPolicy; public final class FinalClassFinalFields { private final String finalField; public FinalClassFinalFields(String value) { this.finalField = value; } } """.trimIndent() ), JavaFileObjects.forSourceString( "$PACKAGE_PREFIX.FinalClassNonFinalFields", /* language=JAVA */ """ package $PACKAGE_PREFIX; @Retention(RetentionPolicy.SOURCE) public @interface Immutable { @Retention(RetentionPolicy.SOURCE) @interface Ignore {} public final class FinalClassNonFinalFields { private String nonFinalField; } """.trimIndent() ) ) } @get:Rule Loading @@ -53,7 +90,8 @@ class ImmutabilityProcessorTest { @Test fun validInterface() = test( JavaFileObjects.forSourceString("$PACKAGE_PREFIX.$DATA_CLASS_NAME", JavaFileObjects.forSourceString( "$PACKAGE_PREFIX.$DATA_CLASS_NAME", /* language=JAVA */ """ package $PACKAGE_PREFIX; Loading Loading @@ -86,11 +124,13 @@ class ImmutabilityProcessorTest { } } """.trimIndent() ), errors = emptyList()) ), errors = emptyList() ) @Test fun abstractClass() = test( JavaFileObjects.forSourceString("$PACKAGE_PREFIX.$DATA_CLASS_NAME", JavaFileObjects.forSourceString( "$PACKAGE_PREFIX.$DATA_CLASS_NAME", /* language=JAVA */ """ package $PACKAGE_PREFIX; Loading Loading @@ -140,30 +180,77 @@ class ImmutabilityProcessorTest { arrayFailure(line = 17), nonInterfaceReturnFailure(line = 18), nonInterfaceReturnFailure(line = 19), classNotImmutableFailure(line = 22, className = "InnerInterface"), nonInterfaceReturnFailure(line = 25, prefix = "Key InnerClass"), nonInterfaceReturnFailure(line = 25, prefix = "Value InnerClass"), classNotImmutableFailure(line = 27, className = "InnerClass"), nonInterfaceClassFailure(line = 27), memberNotMethodFailure(line = 28), arrayFailure(line = 29), classNotImmutableFailure(line = 22, className = "InnerInterface"), arrayFailure(line = 33), nonInterfaceReturnFailure(line = 34), )) ) ) @Test fun finalClasses() = test( JavaFileObjects.forSourceString( "$PACKAGE_PREFIX.$DATA_CLASS_NAME", /* language=JAVA */ """ package $PACKAGE_PREFIX; import java.util.List; @Immutable public interface $DATA_CLASS_NAME { NonFinalClassFinalFields getNonFinalFinal(); List<NonFinalClassNonFinalFields> getNonFinalNonFinal(); FinalClassFinalFields getFinalFinal(); List<FinalClassNonFinalFields> getFinalNonFinal(); @Immutable.Policy(exceptions = {Immutable.Policy.Exception.FINAL_CLASSES_WITH_FINAL_FIELDS}) NonFinalClassFinalFields getPolicyNonFinalFinal(); private fun test(source: JavaFileObject, errors: List<CompilationError>) { @Immutable.Policy(exceptions = {Immutable.Policy.Exception.FINAL_CLASSES_WITH_FINAL_FIELDS}) List<NonFinalClassNonFinalFields> getPolicyNonFinalNonFinal(); @Immutable.Policy(exceptions = {Immutable.Policy.Exception.FINAL_CLASSES_WITH_FINAL_FIELDS}) FinalClassFinalFields getPolicyFinalFinal(); @Immutable.Policy(exceptions = {Immutable.Policy.Exception.FINAL_CLASSES_WITH_FINAL_FIELDS}) List<FinalClassNonFinalFields> getPolicyFinalNonFinal(); } """.trimIndent() ), errors = listOf( nonInterfaceReturnFailure(line = 7), nonInterfaceReturnFailure(line = 8, index = 0), nonInterfaceReturnFailure(line = 9), nonInterfaceReturnFailure(line = 10, index = 0), classNotFinalFailure(line = 13, "NonFinalClassFinalFields"), ), otherErrors = listOf( memberNotMethodFailure(line = 4) to FINAL_CLASSES[1], memberNotMethodFailure(line = 4) to FINAL_CLASSES[3], ) ) private fun test( source: JavaFileObject, errors: List<CompilationError>, otherErrors: List<Pair<CompilationError, JavaFileObject>> = emptyList(), ) { val compilation = javac() .withProcessors(ImmutabilityProcessor()) .compile(listOf(source) + ANNOTATION) errors.forEach { .compile(FINAL_CLASSES + ANNOTATION + listOf(source)) val allErrors = otherErrors + errors.map { it to source } allErrors.forEach { (error, file) -> try { assertThat(compilation) .hadErrorContaining(it.message) .inFile(source) .onLine(it.line) .hadErrorContaining(error.message) .inFile(file) .onLine(error.line) } catch (e: AssertionError) { // Wrap the exception so that the line number is logged val wrapped = AssertionError("Expected $it, ${e.message}").apply { val wrapped = AssertionError("Expected $error, ${e.message}").apply { stackTrace = e.stackTrace } Loading @@ -175,11 +262,14 @@ class ImmutabilityProcessorTest { } try { assertThat(compilation).hadErrorCount(errors.size) assertThat(compilation).hadErrorCount(allErrors.size) } catch (e: AssertionError) { if (expect.hasFailures()) { expect.that(e).isNull() } else throw e expect.withMessage( compilation.errors() .joinToString(separator = "\n") { "${it.lineNumber}: ${it.getMessage(Locale.ENGLISH)?.trim()}" } ).that(e).isNull() } } Loading @@ -192,7 +282,7 @@ class ImmutabilityProcessorTest { private fun nonInterfaceReturnFailure(line: Long) = CompilationError(line = line, message = MessageUtils.nonInterfaceReturnFailure()) private fun nonInterfaceReturnFailure(line: Long, prefix: String, index: Int = -1) = private fun nonInterfaceReturnFailure(line: Long, prefix: String = "", index: Int = -1) = CompilationError( line = line, message = MessageUtils.nonInterfaceReturnFailure(prefix = prefix, index = index) Loading @@ -210,6 +300,9 @@ class ImmutabilityProcessorTest { private fun arrayFailure(line: Long) = CompilationError(line = line, message = MessageUtils.arrayFailure()) private fun classNotFinalFailure(line: Long, className: String) = CompilationError(line = line, message = MessageUtils.classNotFinalFailure(className)) data class CompilationError( val line: Long, val message: String, Loading