Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit ab772a92 authored by Hani Kazmi's avatar Hani Kazmi Committed by Android (Google) Code Review
Browse files

Merge "Extend SaferParcelChecker to Bundle/Intent APIs: Update lint check"

parents 4218c300 0468338b
Loading
Loading
Loading
Loading
+32 −12
Original line number Original line Diff line number Diff line
@@ -19,6 +19,7 @@ package com.google.android.lint.parcel
import com.android.tools.lint.detector.api.JavaContext
import com.android.tools.lint.detector.api.JavaContext
import com.android.tools.lint.detector.api.LintFix
import com.android.tools.lint.detector.api.LintFix
import com.android.tools.lint.detector.api.Location
import com.android.tools.lint.detector.api.Location
import com.intellij.psi.PsiArrayType
import com.intellij.psi.PsiCallExpression
import com.intellij.psi.PsiCallExpression
import com.intellij.psi.PsiClassType
import com.intellij.psi.PsiClassType
import com.intellij.psi.PsiIntersectionType
import com.intellij.psi.PsiIntersectionType
@@ -26,8 +27,8 @@ import com.intellij.psi.PsiMethod
import com.intellij.psi.PsiType
import com.intellij.psi.PsiType
import com.intellij.psi.PsiTypeParameter
import com.intellij.psi.PsiTypeParameter
import com.intellij.psi.PsiWildcardType
import com.intellij.psi.PsiWildcardType
import org.jetbrains.kotlin.utils.addToStdlib.cast
import org.jetbrains.uast.UCallExpression
import org.jetbrains.uast.UCallExpression
import org.jetbrains.uast.UElement
import org.jetbrains.uast.UExpression
import org.jetbrains.uast.UExpression
import org.jetbrains.uast.UVariable
import org.jetbrains.uast.UVariable


@@ -42,11 +43,11 @@ abstract class CallMigrator(
) {
) {
    open fun report(context: JavaContext, call: UCallExpression, method: PsiMethod) {
    open fun report(context: JavaContext, call: UCallExpression, method: PsiMethod) {
        val location = context.getLocation(call)
        val location = context.getLocation(call)
        val itemType = getBoundingClass(context, call, method)
        val itemType = filter(getBoundingClass(context, call, method))
        val fix = (itemType as? PsiClassType)?.let { type ->
        val fix = (itemType as? PsiClassType)?.let { type ->
            getParcelFix(location, this.method.name, getArgumentSuffix(type))
            getParcelFix(location, this.method.name, getArgumentSuffix(type))
        }
        }
        val message = "Unsafe `Parcel.${this.method.name}()` API usage"
        val message = "Unsafe `${this.method.className}.${this.method.name}()` API usage"
        context.report(SaferParcelChecker.ISSUE_UNSAFE_API_USAGE, call, location, message, fix)
        context.report(SaferParcelChecker.ISSUE_UNSAFE_API_USAGE, call, location, message, fix)
    }
    }


@@ -73,14 +74,14 @@ abstract class CallMigrator(
    }
    }


    /**
    /**
     * Tries to obtain the type expected by the "receiving" end given a certain {@link UExpression}.
     * Tries to obtain the type expected by the "receiving" end given a certain {@link UElement}.
     *
     *
     * This could be an assignment, an argument passed to a method call, to a constructor call, a
     * This could be an assignment, an argument passed to a method call, to a constructor call, a
     * type cast, etc. If no receiving end is found, the type of the UExpression itself is returned.
     * type cast, etc. If no receiving end is found, the type of the UExpression itself is returned.
     */
     */
    protected fun getReceivingType(expression: UExpression): PsiType? {
    protected fun getReceivingType(expression: UElement): PsiType? {
        val parent = expression.uastParent
        val parent = expression.uastParent
        val type = when (parent) {
        var type = when (parent) {
            is UCallExpression -> {
            is UCallExpression -> {
                val i = parent.valueArguments.indexOf(expression)
                val i = parent.valueArguments.indexOf(expression)
                val psiCall = parent.sourcePsi as? PsiCallExpression ?: return null
                val psiCall = parent.sourcePsi as? PsiCallExpression ?: return null
@@ -92,10 +93,13 @@ abstract class CallMigrator(
            is UExpression -> parent.getExpressionType()
            is UExpression -> parent.getExpressionType()
            else -> null
            else -> null
        }
        }
        return filter(type ?: expression.getExpressionType())
        if (type == null && expression is UExpression) {
            type = expression.getExpressionType()
        }
        return type
    }
    }


    private fun filter(type: PsiType?): PsiType? {
    protected fun filter(type: PsiType?): PsiType? {
        // It's important that PsiIntersectionType case is above the one that check the type in
        // It's important that PsiIntersectionType case is above the one that check the type in
        // rejects, because for intersect types, the canonicalText is one of the terms.
        // rejects, because for intersect types, the canonicalText is one of the terms.
        if (type is PsiIntersectionType) {
        if (type is PsiIntersectionType) {
@@ -169,7 +173,7 @@ class ContainerReturnMigrator(
    override fun getBoundingClass(
    override fun getBoundingClass(
            context: JavaContext, call: UCallExpression, method: PsiMethod
            context: JavaContext, call: UCallExpression, method: PsiMethod
    ): PsiType? {
    ): PsiType? {
        val type = getReceivingType(call.uastParent as UExpression) ?: return null
        val type = getReceivingType(call.uastParent!!) ?: return null
        return getItemType(type, container)
        return getItemType(type, container)
    }
    }
}
}
@@ -184,7 +188,7 @@ class ReturnMigrator(
    override fun getBoundingClass(
    override fun getBoundingClass(
            context: JavaContext, call: UCallExpression, method: PsiMethod
            context: JavaContext, call: UCallExpression, method: PsiMethod
    ): PsiType? {
    ): PsiType? {
        return getReceivingType(call.uastParent as UExpression)
        return getReceivingType(call.uastParent!!)
    }
    }
}
}


@@ -199,7 +203,7 @@ class ReturnMigratorWithClassLoader(
    override fun getBoundingClass(
    override fun getBoundingClass(
            context: JavaContext, call: UCallExpression, method: PsiMethod
            context: JavaContext, call: UCallExpression, method: PsiMethod
    ): PsiType? {
    ): PsiType? {
        return getReceivingType(call.uastParent as UExpression)
        return getReceivingType(call.uastParent!!)
    }
    }


    override fun getArgumentSuffix(type: PsiClassType): String =
    override fun getArgumentSuffix(type: PsiClassType): String =
@@ -207,3 +211,19 @@ class ReturnMigratorWithClassLoader(
                    "${type.rawType().canonicalText}.class"
                    "${type.rawType().canonicalText}.class"


}
}

/**
 * This class derives the type to be appended by inferring the expected array type
 * for the method result.
 */
class ArrayReturnMigrator(
    method: Method,
    rejects: Set<String> = emptySet(),
) : CallMigrator(method, rejects) {
    override fun getBoundingClass(
           context: JavaContext, call: UCallExpression, method: PsiMethod
    ): PsiType? {
        val type = getReceivingType(call.uastParent!!)
        return (type as? PsiArrayType)?.componentType
    }
}
+5 −1
Original line number Original line Diff line number Diff line
@@ -35,4 +35,8 @@ data class Method(
            val prefix = if (params.isEmpty()) "" else "${params.joinToString(", ", "<", ">")} "
            val prefix = if (params.isEmpty()) "" else "${params.joinToString(", ", "<", ">")} "
            return "$prefix$clazz.$name(${parameters.joinToString()})"
            return "$prefix$clazz.$name(${parameters.joinToString()})"
        }
        }

    val className: String by lazy {
        clazz.split(".").last()
    }
}
}
+46 −18
Original line number Original line Diff line number Diff line
@@ -24,6 +24,7 @@ import com.intellij.psi.PsiTypeParameter
import org.jetbrains.uast.UCallExpression
import org.jetbrains.uast.UCallExpression
import java.util.*
import java.util.*


@Suppress("UnstableApiUsage")
class SaferParcelChecker : Detector(), SourceCodeScanner {
class SaferParcelChecker : Detector(), SourceCodeScanner {
    override fun getApplicableMethodNames(): List<String> =
    override fun getApplicableMethodNames(): List<String> =
            MIGRATORS
            MIGRATORS
@@ -65,9 +66,9 @@ class SaferParcelChecker : Detector(), SourceCodeScanner {
        @JvmField
        @JvmField
        val ISSUE_UNSAFE_API_USAGE: Issue = Issue.create(
        val ISSUE_UNSAFE_API_USAGE: Issue = Issue.create(
                id = "UnsafeParcelApi",
                id = "UnsafeParcelApi",
                briefDescription = "Use of unsafe Parcel API",
                briefDescription = "Use of unsafe deserialization API",
                explanation = """
                explanation = """
                    You are using a deprecated Parcel API that doesn't accept the expected class as\
                    You are using a deprecated deserialization API that doesn't accept the expected class as\
                     a parameter. This means that unexpected classes could be instantiated and\
                     a parameter. This means that unexpected classes could be instantiated and\
                     unexpected code executed.
                     unexpected code executed.


@@ -83,25 +84,52 @@ class SaferParcelChecker : Detector(), SourceCodeScanner {
                )
                )
        )
        )


        private val METHOD_READ_SERIALIZABLE = Method("android.os.Parcel", "readSerializable", listOf())
        // Parcel
        private val METHOD_READ_ARRAY_LIST = Method("android.os.Parcel", "readArrayList", listOf("java.lang.ClassLoader"))
        private val PARCEL_METHOD_READ_SERIALIZABLE = Method("android.os.Parcel", "readSerializable", listOf())
        private val METHOD_READ_LIST = Method("android.os.Parcel", "readList", listOf("java.util.List", "java.lang.ClassLoader"))
        private val PARCEL_METHOD_READ_ARRAY_LIST = Method("android.os.Parcel", "readArrayList", listOf("java.lang.ClassLoader"))
        private val METHOD_READ_PARCELABLE = Method(listOf("T"), "android.os.Parcel", "readParcelable", listOf("java.lang.ClassLoader"))
        private val PARCEL_METHOD_READ_LIST = Method("android.os.Parcel", "readList", listOf("java.util.List", "java.lang.ClassLoader"))
        private val METHOD_READ_PARCELABLE_LIST = Method(listOf("T"), "android.os.Parcel", "readParcelableList", listOf("java.util.List<T>", "java.lang.ClassLoader"))
        private val PARCEL_METHOD_READ_PARCELABLE = Method(listOf("T"), "android.os.Parcel", "readParcelable", listOf("java.lang.ClassLoader"))
        private val METHOD_READ_SPARSE_ARRAY = Method(listOf("T"), "android.os.Parcel", "readSparseArray", listOf("java.lang.ClassLoader"))
        private val PARCEL_METHOD_READ_PARCELABLE_LIST = Method(listOf("T"), "android.os.Parcel", "readParcelableList", listOf("java.util.List<T>", "java.lang.ClassLoader"))
        private val PARCEL_METHOD_READ_SPARSE_ARRAY = Method(listOf("T"), "android.os.Parcel", "readSparseArray", listOf("java.lang.ClassLoader"))
        private val PARCEL_METHOD_READ_ARRAY = Method("android.os.Parcel", "readArray", listOf("java.lang.ClassLoader"))
        private val PARCEL_METHOD_READ_PARCELABLE_ARRAY = Method("android.os.Parcel", "readParcelableArray", listOf("java.lang.ClassLoader"))

        // Bundle
        private val BUNDLE_METHOD_GET_SERIALIZABLE = Method("android.os.Bundle", "getSerializable", listOf("java.lang.String"))
        private val BUNDLE_METHOD_GET_PARCELABLE = Method(listOf("T"), "android.os.Bundle", "getParcelable", listOf("java.lang.String"))
        private val BUNDLE_METHOD_GET_PARCELABLE_ARRAY_LIST = Method(listOf("T"), "android.os.Bundle", "getParcelableArrayList", listOf("java.lang.String"))
        private val BUNDLE_METHOD_GET_PARCELABLE_ARRAY = Method("android.os.Bundle", "getParcelableArray", listOf("java.lang.String"))
        private val BUNDLE_METHOD_GET_SPARSE_PARCELABLE_ARRAY = Method(listOf("T"), "android.os.Bundle", "getSparseParcelableArray", listOf("java.lang.String"))

        // Intent
        private val INTENT_METHOD_GET_SERIALIZABLE_EXTRA = Method("android.content.Intent", "getSerializableExtra", listOf("java.lang.String"))
        private val INTENT_METHOD_GET_PARCELABLE_EXTRA = Method(listOf("T"), "android.content.Intent", "getParcelableExtra", listOf("java.lang.String"))
        private val INTENT_METHOD_GET_PARCELABLE_ARRAY_EXTRA = Method("android.content.Intent", "getParcelableArrayExtra", listOf("java.lang.String"))
        private val INTENT_METHOD_GET_PARCELABLE_ARRAY_LIST_EXTRA = Method(listOf("T"), "android.content.Intent", "getParcelableArrayListExtra", listOf("java.lang.String"))


        // TODO: Write migrators for methods below
        // TODO: Write migrators for methods below
        private val METHOD_READ_ARRAY = Method("android.os.Parcel", "readArray", listOf("java.lang.ClassLoader"))
        private val PARCEL_METHOD_READ_PARCELABLE_CREATOR = Method("android.os.Parcel", "readParcelableCreator", listOf("java.lang.ClassLoader"))
        private val METHOD_READ_PARCELABLE_ARRAY = Method("android.os.Parcel", "readParcelableArray", listOf("java.lang.ClassLoader"))
        private val METHOD_READ_PARCELABLE_CREATOR = Method("android.os.Parcel", "readParcelableCreator", listOf("java.lang.ClassLoader"))


        private val MIGRATORS = listOf(
        private val MIGRATORS = listOf(
                ReturnMigrator(METHOD_READ_PARCELABLE, setOf("android.os.Parcelable")),
            ReturnMigrator(PARCEL_METHOD_READ_PARCELABLE, setOf("android.os.Parcelable")),
                ContainerArgumentMigrator(METHOD_READ_LIST, 0, "java.util.List"),
            ContainerArgumentMigrator(PARCEL_METHOD_READ_LIST, 0, "java.util.List"),
                ContainerReturnMigrator(METHOD_READ_ARRAY_LIST, "java.util.Collection"),
            ContainerReturnMigrator(PARCEL_METHOD_READ_ARRAY_LIST, "java.util.Collection"),
                ContainerReturnMigrator(METHOD_READ_SPARSE_ARRAY, "android.util.SparseArray"),
            ContainerReturnMigrator(PARCEL_METHOD_READ_SPARSE_ARRAY, "android.util.SparseArray"),
                ContainerArgumentMigrator(METHOD_READ_PARCELABLE_LIST, 0, "java.util.List"),
            ContainerArgumentMigrator(PARCEL_METHOD_READ_PARCELABLE_LIST, 0, "java.util.List"),
                ReturnMigratorWithClassLoader(METHOD_READ_SERIALIZABLE),
            ReturnMigratorWithClassLoader(PARCEL_METHOD_READ_SERIALIZABLE),
            ArrayReturnMigrator(PARCEL_METHOD_READ_ARRAY, setOf("java.lang.Object")),
            ArrayReturnMigrator(PARCEL_METHOD_READ_PARCELABLE_ARRAY, setOf("android.os.Parcelable")),

            ReturnMigrator(BUNDLE_METHOD_GET_PARCELABLE, setOf("android.os.Parcelable")),
            ContainerReturnMigrator(BUNDLE_METHOD_GET_PARCELABLE_ARRAY_LIST, "java.util.Collection", setOf("android.os.Parcelable")),
            ArrayReturnMigrator(BUNDLE_METHOD_GET_PARCELABLE_ARRAY, setOf("android.os.Parcelable")),
            ContainerReturnMigrator(BUNDLE_METHOD_GET_SPARSE_PARCELABLE_ARRAY, "android.util.SparseArray", setOf("android.os.Parcelable")),
            ReturnMigrator(BUNDLE_METHOD_GET_SERIALIZABLE, setOf("java.io.Serializable")),

            ReturnMigrator(INTENT_METHOD_GET_PARCELABLE_EXTRA, setOf("android.os.Parcelable")),
            ContainerReturnMigrator(INTENT_METHOD_GET_PARCELABLE_ARRAY_LIST_EXTRA, "java.util.Collection", setOf("android.os.Parcelable")),
            ArrayReturnMigrator(INTENT_METHOD_GET_PARCELABLE_ARRAY_EXTRA, setOf("android.os.Parcelable")),
            ReturnMigrator(INTENT_METHOD_GET_SERIALIZABLE_EXTRA, setOf("java.io.Serializable")),
        )
        )
    }
    }
}
}
+537 −142

File changed.

Preview size limit exceeded, changes collapsed.