Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 92532e72 authored by Paul Duffin's avatar Paul Duffin
Browse files

Allow traversal over the trie structure

Previously, there was no way to traverse the trie structure and no way
to identify specific nodes in the trie. That made it impossible to
analyze the trie structure resulting from loading a set of flags. This
change adds type and selector properties to nodes as well as access to
the child nodes of a node to allow for the structure to be analyzed.

Bug: 202154151
Test: m out/soong/hiddenapi/hiddenapi-flags.csv
      atest --host signature_trie_test verify_overlaps_test
      pyformat -s 4 --force_quote_type double -i scripts/hiddenapi/signature_trie*
      /usr/bin/pylint --rcfile $ANDROID_BUILD_TOP/tools/repohooks/tools/pylintrc scripts/hiddenapi/signature_trie*
Change-Id: Ia4714dbf59f6fd143aa3bf3ad1a59cd073d2175b
parent ea93542e
Loading
Loading
Loading
Loading
+79 −9
Original line number Diff line number Diff line
@@ -22,6 +22,19 @@ from itertools import chain

@dataclasses.dataclass()
class Node:
    """A node in the signature trie."""

    # The type of the node.
    #
    # Leaf nodes are of type "member".
    # Interior nodes can be either "package", or "class".
    type: str

    # The selector of the node.
    #
    # That is a string that can be used to select the node, e.g. in a pattern
    # that is passed to InteriorNode.get_matching_rows().
    selector: str

    def values(self, selector):
        """Get the values from a set of selected nodes.
@@ -48,6 +61,10 @@ class Node:
        """
        raise NotImplementedError("Please Implement this method")

    def child_nodes(self):
        """Get an iterable of the child nodes of this node."""
        raise NotImplementedError("Please Implement this method")


# pylint: disable=line-too-long
@dataclasses.dataclass()
@@ -173,22 +190,68 @@ class InteriorNode(Node):
        element_type, _ = InteriorNode.split_element(element)
        return element_type

    def add(self, signature, value):
    @staticmethod
    def elements_to_selector(elements):
        """Compute a selector for a set of elements.

        A selector uniquely identifies a specific Node in the trie. It is
        essentially a prefix of a signature (without the leading L).

        e.g. a trie containing "Ljava/lang/Object;->String()Ljava/lang/String;"
        would contain nodes with the following selectors:
        * "java"
        * "java/lang"
        * "java/lang/Object"
        * "java/lang/Object;->String()Ljava/lang/String;"
        """
        signature = ""
        preceding_type = ""
        for element in elements:
            element_type, element_value = InteriorNode.split_element(element)
            separator = ""
            if element_type == "package":
                separator = "/"
            elif element_type == "class":
                if preceding_type == "class":
                    separator = "$"
                else:
                    separator = "/"
            elif element_type == "wildcard":
                separator = "/"
            elif element_type == "member":
                separator += ";->"

            if signature:
                signature += separator

            signature += element_value

            preceding_type = element_type

        return signature

    def add(self, signature, value, only_if_matches=False):
        """Associate the value with the specific signature.

        :param signature: the member signature
        :param value: the value to associated with the signature
        :param only_if_matches: True if the value is added only if the signature
             matches at least one of the existing top level packages.
        :return: n/a
        """
        # Split the signature into elements.
        elements = self.signature_to_elements(signature)
        # Find the Node associated with the deepest class.
        node = self
        for element in elements[:-1]:
        for index, element in enumerate(elements[:-1]):
            if element in node.nodes:
                node = node.nodes[element]
            elif only_if_matches and index == 0:
                return
            else:
                next_node = InteriorNode()
                selector = self.elements_to_selector(elements[0:index + 1])
                next_node = InteriorNode(
                    type=InteriorNode.element_type(element), selector=selector)
                node.nodes[element] = next_node
                node = next_node
        # Add a Leaf containing the value and associate it with the member
@@ -201,7 +264,12 @@ class InteriorNode(Node):
                "specific member")
        if last_element in node.nodes:
            raise Exception(f"Duplicate signature: {signature}")
        node.nodes[last_element] = Leaf(value)
        leaf = Leaf(
            type=last_element_type,
            selector=signature,
            value=value,
        )
        node.nodes[last_element] = leaf

    def get_matching_rows(self, pattern):
        """Get the values (plural) associated with the pattern.
@@ -212,10 +280,6 @@ class InteriorNode(Node):
        If the pattern is a class then this will return a list containing the
        values associated with all members of that class.

        If the pattern is a package then this will return a list containing the
        values associated with all the members of all the classes in that
        package and sub-packages.

        If the pattern ends with "*" then the preceding part is treated as a
        package and this will return a list containing the values associated
        with all the members of all the classes in that package.
@@ -261,6 +325,9 @@ class InteriorNode(Node):
            if selector(key):
                node.append_values(values, lambda x: True)

    def child_nodes(self):
        return self.nodes.values()


@dataclasses.dataclass()
class Leaf(Node):
@@ -275,6 +342,9 @@ class Leaf(Node):
    def append_values(self, values, selector):
        values.append([self.value])

    def child_nodes(self):
        return []


def signature_trie():
    return InteriorNode()
    return InteriorNode(type="root", selector="")
+25 −0
Original line number Diff line number Diff line
@@ -27,6 +27,10 @@ class TestSignatureToElements(unittest.TestCase):
    def signature_to_elements(signature):
        return InteriorNode.signature_to_elements(signature)

    @staticmethod
    def elements_to_signature(elements):
        return InteriorNode.elements_to_selector(elements)

    def test_nested_inner_classes(self):
        elements = [
            ("package", "java"),
@@ -38,6 +42,7 @@ class TestSignatureToElements(unittest.TestCase):
        ]
        signature = "Ljava/lang/ProcessBuilder$Redirect$1;-><init>()V"
        self.assertEqual(elements, self.signature_to_elements(signature))
        self.assertEqual(signature, "L" + self.elements_to_signature(elements))

    def test_basic_member(self):
        elements = [
@@ -48,6 +53,7 @@ class TestSignatureToElements(unittest.TestCase):
        ]
        signature = "Ljava/lang/Object;->hashCode()I"
        self.assertEqual(elements, self.signature_to_elements(signature))
        self.assertEqual(signature, "L" + self.elements_to_signature(elements))

    def test_double_dollar_class(self):
        elements = [
@@ -61,6 +67,7 @@ class TestSignatureToElements(unittest.TestCase):
        signature = "Ljava/lang/CharSequence$$ExternalSyntheticLambda0;" \
                    "-><init>(Ljava/lang/CharSequence;)V"
        self.assertEqual(elements, self.signature_to_elements(signature))
        self.assertEqual(signature, "L" + self.elements_to_signature(elements))

    def test_no_member(self):
        elements = [
@@ -72,6 +79,7 @@ class TestSignatureToElements(unittest.TestCase):
        ]
        signature = "Ljava/lang/CharSequence$$ExternalSyntheticLambda0"
        self.assertEqual(elements, self.signature_to_elements(signature))
        self.assertEqual(signature, "L" + self.elements_to_signature(elements))

    def test_wildcard(self):
        elements = [
@@ -81,6 +89,7 @@ class TestSignatureToElements(unittest.TestCase):
        ]
        signature = "java/lang/*"
        self.assertEqual(elements, self.signature_to_elements(signature))
        self.assertEqual(signature, self.elements_to_signature(elements))

    def test_recursive_wildcard(self):
        elements = [
@@ -90,6 +99,7 @@ class TestSignatureToElements(unittest.TestCase):
        ]
        signature = "java/lang/**"
        self.assertEqual(elements, self.signature_to_elements(signature))
        self.assertEqual(signature, self.elements_to_signature(elements))

    def test_no_packages_wildcard(self):
        elements = [
@@ -97,6 +107,7 @@ class TestSignatureToElements(unittest.TestCase):
        ]
        signature = "*"
        self.assertEqual(elements, self.signature_to_elements(signature))
        self.assertEqual(signature, self.elements_to_signature(elements))

    def test_no_packages_recursive_wildcard(self):
        elements = [
@@ -104,6 +115,7 @@ class TestSignatureToElements(unittest.TestCase):
        ]
        signature = "**"
        self.assertEqual(elements, self.signature_to_elements(signature))
        self.assertEqual(signature, self.elements_to_signature(elements))

    def test_invalid_no_class_or_wildcard(self):
        signature = "java/lang"
@@ -121,6 +133,7 @@ class TestSignatureToElements(unittest.TestCase):
        ]
        signature = "Ljavax/crypto/extObjectInputStream"
        self.assertEqual(elements, self.signature_to_elements(signature))
        self.assertEqual(signature, "L" + self.elements_to_signature(elements))

    def test_invalid_pattern_wildcard(self):
        pattern = "Ljava/lang/Class*"
@@ -200,6 +213,18 @@ Ljava/util/zip/ZipFile;-><clinit>()V
            "Ljava/util/zip/ZipFile;-><clinit>()V",
        ])

    def test_node_wildcard(self):
        trie = self.read_trie()
        node = list(trie.child_nodes())[0]
        self.check_node_patterns(node, "**", [
            "Ljava/lang/Character$UnicodeScript;->of(I)Ljava/lang/Character$UnicodeScript;",
            "Ljava/lang/Character;->serialVersionUID:J",
            "Ljava/lang/Object;->hashCode()I",
            "Ljava/lang/Object;->toString()Ljava/lang/String;",
            "Ljava/lang/ProcessBuilder$Redirect$1;-><init>()V",
            "Ljava/util/zip/ZipFile;-><clinit>()V",
        ])

    # pylint: enable=line-too-long