Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit be74a9fa authored by Paul Duffin's avatar Paul Duffin Committed by Gerrit Code Review
Browse files

Merge "Allow traversal over the trie structure"

parents f8ccd165 92532e72
Loading
Loading
Loading
Loading
+79 −9
Original line number Diff line number Diff line
@@ -22,6 +22,19 @@ from itertools import chain

@dataclasses.dataclass()
class Node:
    """A node in the signature trie."""

    # The type of the node.
    #
    # Leaf nodes are of type "member".
    # Interior nodes can be either "package", or "class".
    type: str

    # The selector of the node.
    #
    # That is a string that can be used to select the node, e.g. in a pattern
    # that is passed to InteriorNode.get_matching_rows().
    selector: str

    def values(self, selector):
        """Get the values from a set of selected nodes.
@@ -48,6 +61,10 @@ class Node:
        """
        raise NotImplementedError("Please Implement this method")

    def child_nodes(self):
        """Get an iterable of the child nodes of this node."""
        raise NotImplementedError("Please Implement this method")


# pylint: disable=line-too-long
@dataclasses.dataclass()
@@ -173,22 +190,68 @@ class InteriorNode(Node):
        element_type, _ = InteriorNode.split_element(element)
        return element_type

    def add(self, signature, value):
    @staticmethod
    def elements_to_selector(elements):
        """Compute a selector for a set of elements.

        A selector uniquely identifies a specific Node in the trie. It is
        essentially a prefix of a signature (without the leading L).

        e.g. a trie containing "Ljava/lang/Object;->String()Ljava/lang/String;"
        would contain nodes with the following selectors:
        * "java"
        * "java/lang"
        * "java/lang/Object"
        * "java/lang/Object;->String()Ljava/lang/String;"
        """
        signature = ""
        preceding_type = ""
        for element in elements:
            element_type, element_value = InteriorNode.split_element(element)
            separator = ""
            if element_type == "package":
                separator = "/"
            elif element_type == "class":
                if preceding_type == "class":
                    separator = "$"
                else:
                    separator = "/"
            elif element_type == "wildcard":
                separator = "/"
            elif element_type == "member":
                separator += ";->"

            if signature:
                signature += separator

            signature += element_value

            preceding_type = element_type

        return signature

    def add(self, signature, value, only_if_matches=False):
        """Associate the value with the specific signature.

        :param signature: the member signature
        :param value: the value to associated with the signature
        :param only_if_matches: True if the value is added only if the signature
             matches at least one of the existing top level packages.
        :return: n/a
        """
        # Split the signature into elements.
        elements = self.signature_to_elements(signature)
        # Find the Node associated with the deepest class.
        node = self
        for element in elements[:-1]:
        for index, element in enumerate(elements[:-1]):
            if element in node.nodes:
                node = node.nodes[element]
            elif only_if_matches and index == 0:
                return
            else:
                next_node = InteriorNode()
                selector = self.elements_to_selector(elements[0:index + 1])
                next_node = InteriorNode(
                    type=InteriorNode.element_type(element), selector=selector)
                node.nodes[element] = next_node
                node = next_node
        # Add a Leaf containing the value and associate it with the member
@@ -201,7 +264,12 @@ class InteriorNode(Node):
                "specific member")
        if last_element in node.nodes:
            raise Exception(f"Duplicate signature: {signature}")
        node.nodes[last_element] = Leaf(value)
        leaf = Leaf(
            type=last_element_type,
            selector=signature,
            value=value,
        )
        node.nodes[last_element] = leaf

    def get_matching_rows(self, pattern):
        """Get the values (plural) associated with the pattern.
@@ -212,10 +280,6 @@ class InteriorNode(Node):
        If the pattern is a class then this will return a list containing the
        values associated with all members of that class.

        If the pattern is a package then this will return a list containing the
        values associated with all the members of all the classes in that
        package and sub-packages.

        If the pattern ends with "*" then the preceding part is treated as a
        package and this will return a list containing the values associated
        with all the members of all the classes in that package.
@@ -261,6 +325,9 @@ class InteriorNode(Node):
            if selector(key):
                node.append_values(values, lambda x: True)

    def child_nodes(self):
        return self.nodes.values()


@dataclasses.dataclass()
class Leaf(Node):
@@ -275,6 +342,9 @@ class Leaf(Node):
    def append_values(self, values, selector):
        values.append([self.value])

    def child_nodes(self):
        return []


def signature_trie():
    return InteriorNode()
    return InteriorNode(type="root", selector="")
+25 −0
Original line number Diff line number Diff line
@@ -27,6 +27,10 @@ class TestSignatureToElements(unittest.TestCase):
    def signature_to_elements(signature):
        return InteriorNode.signature_to_elements(signature)

    @staticmethod
    def elements_to_signature(elements):
        return InteriorNode.elements_to_selector(elements)

    def test_nested_inner_classes(self):
        elements = [
            ("package", "java"),
@@ -38,6 +42,7 @@ class TestSignatureToElements(unittest.TestCase):
        ]
        signature = "Ljava/lang/ProcessBuilder$Redirect$1;-><init>()V"
        self.assertEqual(elements, self.signature_to_elements(signature))
        self.assertEqual(signature, "L" + self.elements_to_signature(elements))

    def test_basic_member(self):
        elements = [
@@ -48,6 +53,7 @@ class TestSignatureToElements(unittest.TestCase):
        ]
        signature = "Ljava/lang/Object;->hashCode()I"
        self.assertEqual(elements, self.signature_to_elements(signature))
        self.assertEqual(signature, "L" + self.elements_to_signature(elements))

    def test_double_dollar_class(self):
        elements = [
@@ -61,6 +67,7 @@ class TestSignatureToElements(unittest.TestCase):
        signature = "Ljava/lang/CharSequence$$ExternalSyntheticLambda0;" \
                    "-><init>(Ljava/lang/CharSequence;)V"
        self.assertEqual(elements, self.signature_to_elements(signature))
        self.assertEqual(signature, "L" + self.elements_to_signature(elements))

    def test_no_member(self):
        elements = [
@@ -72,6 +79,7 @@ class TestSignatureToElements(unittest.TestCase):
        ]
        signature = "Ljava/lang/CharSequence$$ExternalSyntheticLambda0"
        self.assertEqual(elements, self.signature_to_elements(signature))
        self.assertEqual(signature, "L" + self.elements_to_signature(elements))

    def test_wildcard(self):
        elements = [
@@ -81,6 +89,7 @@ class TestSignatureToElements(unittest.TestCase):
        ]
        signature = "java/lang/*"
        self.assertEqual(elements, self.signature_to_elements(signature))
        self.assertEqual(signature, self.elements_to_signature(elements))

    def test_recursive_wildcard(self):
        elements = [
@@ -90,6 +99,7 @@ class TestSignatureToElements(unittest.TestCase):
        ]
        signature = "java/lang/**"
        self.assertEqual(elements, self.signature_to_elements(signature))
        self.assertEqual(signature, self.elements_to_signature(elements))

    def test_no_packages_wildcard(self):
        elements = [
@@ -97,6 +107,7 @@ class TestSignatureToElements(unittest.TestCase):
        ]
        signature = "*"
        self.assertEqual(elements, self.signature_to_elements(signature))
        self.assertEqual(signature, self.elements_to_signature(elements))

    def test_no_packages_recursive_wildcard(self):
        elements = [
@@ -104,6 +115,7 @@ class TestSignatureToElements(unittest.TestCase):
        ]
        signature = "**"
        self.assertEqual(elements, self.signature_to_elements(signature))
        self.assertEqual(signature, self.elements_to_signature(elements))

    def test_invalid_no_class_or_wildcard(self):
        signature = "java/lang"
@@ -121,6 +133,7 @@ class TestSignatureToElements(unittest.TestCase):
        ]
        signature = "Ljavax/crypto/extObjectInputStream"
        self.assertEqual(elements, self.signature_to_elements(signature))
        self.assertEqual(signature, "L" + self.elements_to_signature(elements))

    def test_invalid_pattern_wildcard(self):
        pattern = "Ljava/lang/Class*"
@@ -200,6 +213,18 @@ Ljava/util/zip/ZipFile;-><clinit>()V
            "Ljava/util/zip/ZipFile;-><clinit>()V",
        ])

    def test_node_wildcard(self):
        trie = self.read_trie()
        node = list(trie.child_nodes())[0]
        self.check_node_patterns(node, "**", [
            "Ljava/lang/Character$UnicodeScript;->of(I)Ljava/lang/Character$UnicodeScript;",
            "Ljava/lang/Character;->serialVersionUID:J",
            "Ljava/lang/Object;->hashCode()I",
            "Ljava/lang/Object;->toString()Ljava/lang/String;",
            "Ljava/lang/ProcessBuilder$Redirect$1;-><init>()V",
            "Ljava/util/zip/ZipFile;-><clinit>()V",
        ])

    # pylint: enable=line-too-long