Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 9dd2c4d5 authored by Treehugger Robot's avatar Treehugger Robot Committed by Gerrit Code Review
Browse files

Merge "Add type information to symbolfile and ndkstubgen."

parents e0510d7a af7b36de
Loading
Loading
Loading
Loading
+20 −9
Original line number Diff line number Diff line
@@ -20,13 +20,16 @@ import json
import logging
import os
import sys
from typing import Iterable, TextIO

import symbolfile
from symbolfile import Arch, Version


class Generator:
    """Output generator that writes stub source files and version scripts."""
    def __init__(self, src_file, version_script, arch, api, llndk, apex):
    def __init__(self, src_file: TextIO, version_script: TextIO, arch: Arch,
                 api: int, llndk: bool, apex: bool) -> None:
        self.src_file = src_file
        self.version_script = version_script
        self.arch = arch
@@ -34,12 +37,12 @@ class Generator:
        self.llndk = llndk
        self.apex = apex

    def write(self, versions):
    def write(self, versions: Iterable[Version]) -> None:
        """Writes all symbol data to the output files."""
        for version in versions:
            self.write_version(version)

    def write_version(self, version):
    def write_version(self, version: Version) -> None:
        """Writes a single version block's data to the output files."""
        if symbolfile.should_omit_version(version, self.arch, self.api,
                                          self.llndk, self.apex):
@@ -84,7 +87,7 @@ class Generator:
                self.version_script.write('}' + base + ';\n')


def parse_args():
def parse_args() -> argparse.Namespace:
    """Parses and returns command line arguments."""
    parser = argparse.ArgumentParser()

@@ -100,23 +103,31 @@ def parse_args():
    parser.add_argument(
        '--apex', action='store_true', help='Use the APEX variant.')

    # https://github.com/python/mypy/issues/1317
    # mypy has issues with using os.path.realpath as an argument here.
    parser.add_argument(
        '--api-map', type=os.path.realpath, required=True,
        '--api-map',
        type=os.path.realpath,  # type: ignore
        required=True,
        help='Path to the API level map JSON file.')

    parser.add_argument(
        'symbol_file', type=os.path.realpath, help='Path to symbol file.')
        'symbol_file',
        type=os.path.realpath,  # type: ignore
        help='Path to symbol file.')
    parser.add_argument(
        'stub_src', type=os.path.realpath,
        'stub_src',
        type=os.path.realpath,  # type: ignore
        help='Path to output stub source file.')
    parser.add_argument(
        'version_script', type=os.path.realpath,
        'version_script',
        type=os.path.realpath,  # type: ignore
        help='Path to output version script.')

    return parser.parse_args()


def main():
def main() -> None:
    """Program entry point."""
    args = parse_args()

cc/ndkstubgen/mypy.ini

0 → 100644
+2 −0
Original line number Diff line number Diff line
[mypy]
disallow_untyped_defs = True
+41 −38
Original line number Diff line number Diff line
@@ -21,19 +21,20 @@ import unittest

import ndkstubgen
import symbolfile
from symbolfile import Arch, Tag


# pylint: disable=missing-docstring


class GeneratorTest(unittest.TestCase):
    def test_omit_version(self):
    def test_omit_version(self) -> None:
        # Thorough testing of the cases involved here is handled by
        # OmitVersionTest, PrivateVersionTest, and SymbolPresenceTest.
        src_file = io.StringIO()
        version_file = io.StringIO()
        generator = ndkstubgen.Generator(src_file, version_file, 'arm', 9,
                                         False, False)
        generator = ndkstubgen.Generator(src_file, version_file, Arch('arm'),
                                         9, False, False)

        version = symbolfile.Version('VERSION_PRIVATE', None, [], [
            symbolfile.Symbol('foo', []),
@@ -42,74 +43,75 @@ class GeneratorTest(unittest.TestCase):
        self.assertEqual('', src_file.getvalue())
        self.assertEqual('', version_file.getvalue())

        version = symbolfile.Version('VERSION', None, ['x86'], [
        version = symbolfile.Version('VERSION', None, [Tag('x86')], [
            symbolfile.Symbol('foo', []),
        ])
        generator.write_version(version)
        self.assertEqual('', src_file.getvalue())
        self.assertEqual('', version_file.getvalue())

        version = symbolfile.Version('VERSION', None, ['introduced=14'], [
        version = symbolfile.Version('VERSION', None, [Tag('introduced=14')], [
            symbolfile.Symbol('foo', []),
        ])
        generator.write_version(version)
        self.assertEqual('', src_file.getvalue())
        self.assertEqual('', version_file.getvalue())

    def test_omit_symbol(self):
    def test_omit_symbol(self) -> None:
        # Thorough testing of the cases involved here is handled by
        # SymbolPresenceTest.
        src_file = io.StringIO()
        version_file = io.StringIO()
        generator = ndkstubgen.Generator(src_file, version_file, 'arm', 9,
                                         False, False)
        generator = ndkstubgen.Generator(src_file, version_file, Arch('arm'),
                                         9, False, False)

        version = symbolfile.Version('VERSION_1', None, [], [
            symbolfile.Symbol('foo', ['x86']),
            symbolfile.Symbol('foo', [Tag('x86')]),
        ])
        generator.write_version(version)
        self.assertEqual('', src_file.getvalue())
        self.assertEqual('', version_file.getvalue())

        version = symbolfile.Version('VERSION_1', None, [], [
            symbolfile.Symbol('foo', ['introduced=14']),
            symbolfile.Symbol('foo', [Tag('introduced=14')]),
        ])
        generator.write_version(version)
        self.assertEqual('', src_file.getvalue())
        self.assertEqual('', version_file.getvalue())

        version = symbolfile.Version('VERSION_1', None, [], [
            symbolfile.Symbol('foo', ['llndk']),
            symbolfile.Symbol('foo', [Tag('llndk')]),
        ])
        generator.write_version(version)
        self.assertEqual('', src_file.getvalue())
        self.assertEqual('', version_file.getvalue())

        version = symbolfile.Version('VERSION_1', None, [], [
            symbolfile.Symbol('foo', ['apex']),
            symbolfile.Symbol('foo', [Tag('apex')]),
        ])
        generator.write_version(version)
        self.assertEqual('', src_file.getvalue())
        self.assertEqual('', version_file.getvalue())

    def test_write(self):
    def test_write(self) -> None:
        src_file = io.StringIO()
        version_file = io.StringIO()
        generator = ndkstubgen.Generator(src_file, version_file, 'arm', 9,
                                         False, False)
        generator = ndkstubgen.Generator(src_file, version_file, Arch('arm'),
                                         9, False, False)

        versions = [
            symbolfile.Version('VERSION_1', None, [], [
                symbolfile.Symbol('foo', []),
                symbolfile.Symbol('bar', ['var']),
                symbolfile.Symbol('woodly', ['weak']),
                symbolfile.Symbol('doodly', ['weak', 'var']),
                symbolfile.Symbol('bar', [Tag('var')]),
                symbolfile.Symbol('woodly', [Tag('weak')]),
                symbolfile.Symbol('doodly',
                                  [Tag('weak'), Tag('var')]),
            ]),
            symbolfile.Version('VERSION_2', 'VERSION_1', [], [
                symbolfile.Symbol('baz', []),
            ]),
            symbolfile.Version('VERSION_3', 'VERSION_1', [], [
                symbolfile.Symbol('qux', ['versioned=14']),
                symbolfile.Symbol('qux', [Tag('versioned=14')]),
            ]),
        ]

@@ -141,7 +143,7 @@ class GeneratorTest(unittest.TestCase):


class IntegrationTest(unittest.TestCase):
    def test_integration(self):
    def test_integration(self) -> None:
        api_map = {
            'O': 9000,
            'P': 9001,
@@ -178,14 +180,14 @@ class IntegrationTest(unittest.TestCase):
                wobble;
            } VERSION_4;
        """))
        parser = symbolfile.SymbolFileParser(input_file, api_map, 'arm', 9,
                                             False, False)
        parser = symbolfile.SymbolFileParser(input_file, api_map, Arch('arm'),
                                             9, False, False)
        versions = parser.parse()

        src_file = io.StringIO()
        version_file = io.StringIO()
        generator = ndkstubgen.Generator(src_file, version_file, 'arm', 9,
                                         False, False)
        generator = ndkstubgen.Generator(src_file, version_file, Arch('arm'),
                                         9, False, False)
        generator.write(versions)

        expected_src = textwrap.dedent("""\
@@ -213,7 +215,7 @@ class IntegrationTest(unittest.TestCase):
        """)
        self.assertEqual(expected_version, version_file.getvalue())

    def test_integration_future_api(self):
    def test_integration_future_api(self) -> None:
        api_map = {
            'O': 9000,
            'P': 9001,
@@ -230,14 +232,14 @@ class IntegrationTest(unittest.TestCase):
                    *;
            };
        """))
        parser = symbolfile.SymbolFileParser(input_file, api_map, 'arm', 9001,
                                             False, False)
        parser = symbolfile.SymbolFileParser(input_file, api_map, Arch('arm'),
                                             9001, False, False)
        versions = parser.parse()

        src_file = io.StringIO()
        version_file = io.StringIO()
        generator = ndkstubgen.Generator(src_file, version_file, 'arm', 9001,
                                         False, False)
        generator = ndkstubgen.Generator(src_file, version_file, Arch('arm'),
                                         9001, False, False)
        generator.write(versions)

        expected_src = textwrap.dedent("""\
@@ -255,7 +257,7 @@ class IntegrationTest(unittest.TestCase):
        """)
        self.assertEqual(expected_version, version_file.getvalue())

    def test_multiple_definition(self):
    def test_multiple_definition(self) -> None:
        input_file = io.StringIO(textwrap.dedent("""\
            VERSION_1 {
                global:
@@ -280,8 +282,8 @@ class IntegrationTest(unittest.TestCase):
            } VERSION_2;

        """))
        parser = symbolfile.SymbolFileParser(input_file, {}, 'arm', 16, False,
                                             False)
        parser = symbolfile.SymbolFileParser(input_file, {}, Arch('arm'), 16,
                                             False, False)

        with self.assertRaises(
                symbolfile.MultiplyDefinedSymbolError) as ex_context:
@@ -289,7 +291,7 @@ class IntegrationTest(unittest.TestCase):
        self.assertEqual(['bar', 'foo'],
                         ex_context.exception.multiply_defined_symbols)

    def test_integration_with_apex(self):
    def test_integration_with_apex(self) -> None:
        api_map = {
            'O': 9000,
            'P': 9001,
@@ -328,14 +330,14 @@ class IntegrationTest(unittest.TestCase):
                wobble;
            } VERSION_4;
        """))
        parser = symbolfile.SymbolFileParser(input_file, api_map, 'arm', 9,
                                             False, True)
        parser = symbolfile.SymbolFileParser(input_file, api_map, Arch('arm'),
                                             9, False, True)
        versions = parser.parse()

        src_file = io.StringIO()
        version_file = io.StringIO()
        generator = ndkstubgen.Generator(src_file, version_file, 'arm', 9,
                                         False, True)
        generator = ndkstubgen.Generator(src_file, version_file, Arch('arm'),
                                         9, False, True)
        generator.write(versions)

        expected_src = textwrap.dedent("""\
@@ -369,7 +371,8 @@ class IntegrationTest(unittest.TestCase):
        """)
        self.assertEqual(expected_version, version_file.getvalue())

def main():

def main() -> None:
    suite = unittest.TestLoader().loadTestsFromName(__name__)
    unittest.TextTestRunner(verbosity=3).run(suite)

+74 −65
Original line number Diff line number Diff line
@@ -14,15 +14,31 @@
# limitations under the License.
#
"""Parser for Android's version script information."""
from dataclasses import dataclass
import logging
import re
from typing import (
    Dict,
    Iterable,
    List,
    Mapping,
    NewType,
    Optional,
    TextIO,
    Tuple,
)


ApiMap = Mapping[str, int]
Arch = NewType('Arch', str)
Tag = NewType('Tag', str)


ALL_ARCHITECTURES = (
    'arm',
    'arm64',
    'x86',
    'x86_64',
    Arch('arm'),
    Arch('arm64'),
    Arch('x86'),
    Arch('x86_64'),
)


@@ -30,18 +46,36 @@ ALL_ARCHITECTURES = (
FUTURE_API_LEVEL = 10000


def logger():
def logger() -> logging.Logger:
    """Return the main logger for this module."""
    return logging.getLogger(__name__)


def get_tags(line):
@dataclass
class Symbol:
    """A symbol definition from a symbol file."""

    name: str
    tags: List[Tag]


@dataclass
class Version:
    """A version block of a symbol file."""

    name: str
    base: Optional[str]
    tags: List[Tag]
    symbols: List[Symbol]


def get_tags(line: str) -> List[Tag]:
    """Returns a list of all tags on this line."""
    _, _, all_tags = line.strip().partition('#')
    return [e for e in re.split(r'\s+', all_tags) if e.strip()]
    return [Tag(e) for e in re.split(r'\s+', all_tags) if e.strip()]


def is_api_level_tag(tag):
def is_api_level_tag(tag: Tag) -> bool:
    """Returns true if this tag has an API level that may need decoding."""
    if tag.startswith('introduced='):
        return True
@@ -52,7 +86,7 @@ def is_api_level_tag(tag):
    return False


def decode_api_level(api, api_map):
def decode_api_level(api: str, api_map: ApiMap) -> int:
    """Decodes the API level argument into the API level number.

    For the average case, this just decodes the integer value from the string,
@@ -70,12 +104,13 @@ def decode_api_level(api, api_map):
    return api_map[api]


def decode_api_level_tags(tags, api_map):
def decode_api_level_tags(tags: Iterable[Tag], api_map: ApiMap) -> List[Tag]:
    """Decodes API level code names in a list of tags.

    Raises:
        ParseError: An unknown version name was found in a tag.
    """
    decoded_tags = list(tags)
    for idx, tag in enumerate(tags):
        if not is_api_level_tag(tag):
            continue
@@ -83,13 +118,13 @@ def decode_api_level_tags(tags, api_map):

        try:
            decoded = str(decode_api_level(value, api_map))
            tags[idx] = '='.join([name, decoded])
            decoded_tags[idx] = Tag('='.join([name, decoded]))
        except KeyError:
            raise ParseError('Unknown version name in tag: {}'.format(tag))
    return tags
            raise ParseError(f'Unknown version name in tag: {tag}')
    return decoded_tags


def split_tag(tag):
def split_tag(tag: Tag) -> Tuple[str, str]:
    """Returns a key/value tuple of the tag.

    Raises:
@@ -103,7 +138,7 @@ def split_tag(tag):
    return key, value


def get_tag_value(tag):
def get_tag_value(tag: Tag) -> str:
    """Returns the value of a key/value tag.

    Raises:
@@ -114,12 +149,13 @@ def get_tag_value(tag):
    return split_tag(tag)[1]


def version_is_private(version):
def version_is_private(version: str) -> bool:
    """Returns True if the version name should be treated as private."""
    return version.endswith('_PRIVATE') or version.endswith('_PLATFORM')


def should_omit_version(version, arch, api, llndk, apex):
def should_omit_version(version: Version, arch: Arch, api: int, llndk: bool,
                        apex: bool) -> bool:
    """Returns True if the version section should be ommitted.

    We want to omit any sections that do not have any symbols we'll have in the
@@ -145,7 +181,8 @@ def should_omit_version(version, arch, api, llndk, apex):
    return False


def should_omit_symbol(symbol, arch, api, llndk, apex):
def should_omit_symbol(symbol: Symbol, arch: Arch, api: int, llndk: bool,
                       apex: bool) -> bool:
    """Returns True if the symbol should be omitted."""
    no_llndk_no_apex = 'llndk' not in symbol.tags and 'apex' not in symbol.tags
    keep = no_llndk_no_apex or \
@@ -160,7 +197,7 @@ def should_omit_symbol(symbol, arch, api, llndk, apex):
    return False


def symbol_in_arch(tags, arch):
def symbol_in_arch(tags: Iterable[Tag], arch: Arch) -> bool:
    """Returns true if the symbol is present for the given architecture."""
    has_arch_tags = False
    for tag in tags:
@@ -175,7 +212,7 @@ def symbol_in_arch(tags, arch):
    return not has_arch_tags


def symbol_in_api(tags, arch, api):
def symbol_in_api(tags: Iterable[Tag], arch: Arch, api: int) -> bool:
    """Returns true if the symbol is present for the given API level."""
    introduced_tag = None
    arch_specific = False
@@ -197,7 +234,7 @@ def symbol_in_api(tags, arch, api):
    return api >= int(get_tag_value(introduced_tag))


def symbol_versioned_in_api(tags, api):
def symbol_versioned_in_api(tags: Iterable[Tag], api: int) -> bool:
    """Returns true if the symbol should be versioned for the given API.

    This models the `versioned=API` tag. This should be a very uncommonly
@@ -223,68 +260,40 @@ class ParseError(RuntimeError):

class MultiplyDefinedSymbolError(RuntimeError):
    """A symbol name was multiply defined."""
    def __init__(self, multiply_defined_symbols):
        super(MultiplyDefinedSymbolError, self).__init__(
    def __init__(self, multiply_defined_symbols: Iterable[str]) -> None:
        super().__init__(
            'Version script contains multiple definitions for: {}'.format(
                ', '.join(multiply_defined_symbols)))
        self.multiply_defined_symbols = multiply_defined_symbols


class Version:
    """A version block of a symbol file."""
    def __init__(self, name, base, tags, symbols):
        self.name = name
        self.base = base
        self.tags = tags
        self.symbols = symbols

    def __eq__(self, other):
        if self.name != other.name:
            return False
        if self.base != other.base:
            return False
        if self.tags != other.tags:
            return False
        if self.symbols != other.symbols:
            return False
        return True


class Symbol:
    """A symbol definition from a symbol file."""
    def __init__(self, name, tags):
        self.name = name
        self.tags = tags

    def __eq__(self, other):
        return self.name == other.name and set(self.tags) == set(other.tags)


class SymbolFileParser:
    """Parses NDK symbol files."""
    def __init__(self, input_file, api_map, arch, api, llndk, apex):
    def __init__(self, input_file: TextIO, api_map: ApiMap, arch: Arch,
                 api: int, llndk: bool, apex: bool) -> None:
        self.input_file = input_file
        self.api_map = api_map
        self.arch = arch
        self.api = api
        self.llndk = llndk
        self.apex = apex
        self.current_line = None
        self.current_line: Optional[str] = None

    def parse(self):
    def parse(self) -> List[Version]:
        """Parses the symbol file and returns a list of Version objects."""
        versions = []
        while self.next_line() != '':
            assert self.current_line is not None
            if '{' in self.current_line:
                versions.append(self.parse_version())
            else:
                raise ParseError(
                    'Unexpected contents at top level: ' + self.current_line)
                    f'Unexpected contents at top level: {self.current_line}')

        self.check_no_duplicate_symbols(versions)
        return versions

    def check_no_duplicate_symbols(self, versions):
    def check_no_duplicate_symbols(self, versions: Iterable[Version]) -> None:
        """Raises errors for multiply defined symbols.

        This situation is the normal case when symbol versioning is actually
@@ -312,12 +321,13 @@ class SymbolFileParser:
            raise MultiplyDefinedSymbolError(
                sorted(list(multiply_defined_symbols)))

    def parse_version(self):
    def parse_version(self) -> Version:
        """Parses a single version section and returns a Version object."""
        assert self.current_line is not None
        name = self.current_line.split('{')[0].strip()
        tags = get_tags(self.current_line)
        tags = decode_api_level_tags(tags, self.api_map)
        symbols = []
        symbols: List[Symbol] = []
        global_scope = True
        cpp_symbols = False
        while self.next_line() != '':
@@ -333,9 +343,7 @@ class SymbolFileParser:
                    cpp_symbols = False
                else:
                    base = base.rstrip(';').rstrip()
                    if base == '':
                        base = None
                    return Version(name, base, tags, symbols)
                    return Version(name, base or None, tags, symbols)
            elif 'extern "C++" {' in self.current_line:
                cpp_symbols = True
            elif not cpp_symbols and ':' in self.current_line:
@@ -354,8 +362,9 @@ class SymbolFileParser:
                pass
        raise ParseError('Unexpected EOF in version block.')

    def parse_symbol(self):
    def parse_symbol(self) -> Symbol:
        """Parses a single symbol line and returns a Symbol object."""
        assert self.current_line is not None
        if ';' not in self.current_line:
            raise ParseError(
                'Expected ; to terminate symbol: ' + self.current_line)
@@ -368,7 +377,7 @@ class SymbolFileParser:
        tags = decode_api_level_tags(tags, self.api_map)
        return Symbol(name, tags)

    def next_line(self):
    def next_line(self) -> str:
        """Returns the next non-empty non-comment line.

        A return value of '' indicates EOF.

cc/symbolfile/mypy.ini

0 → 100644
+2 −0
Original line number Diff line number Diff line
[mypy]
disallow_untyped_defs = True
Loading