Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit af7b36de authored by Dan Albert's avatar Dan Albert
Browse files

Add type information to symbolfile and ndkstubgen.

Test: mypy symbolfile
Test: pytest
Bug: None
Change-Id: I6b1045d315e5a10e699d31de9fafc084d82768b2
parent 8bd50953
Loading
Loading
Loading
Loading
+20 −9
Original line number Diff line number Diff line
@@ -20,13 +20,16 @@ import json
import logging
import os
import sys
from typing import Iterable, TextIO

import symbolfile
from symbolfile import Arch, Version


class Generator:
    """Output generator that writes stub source files and version scripts."""
    def __init__(self, src_file, version_script, arch, api, llndk, apex):
    def __init__(self, src_file: TextIO, version_script: TextIO, arch: Arch,
                 api: int, llndk: bool, apex: bool) -> None:
        self.src_file = src_file
        self.version_script = version_script
        self.arch = arch
@@ -34,12 +37,12 @@ class Generator:
        self.llndk = llndk
        self.apex = apex

    def write(self, versions):
    def write(self, versions: Iterable[Version]) -> None:
        """Writes all symbol data to the output files."""
        for version in versions:
            self.write_version(version)

    def write_version(self, version):
    def write_version(self, version: Version) -> None:
        """Writes a single version block's data to the output files."""
        if symbolfile.should_omit_version(version, self.arch, self.api,
                                          self.llndk, self.apex):
@@ -84,7 +87,7 @@ class Generator:
                self.version_script.write('}' + base + ';\n')


def parse_args():
def parse_args() -> argparse.Namespace:
    """Parses and returns command line arguments."""
    parser = argparse.ArgumentParser()

@@ -100,23 +103,31 @@ def parse_args():
    parser.add_argument(
        '--apex', action='store_true', help='Use the APEX variant.')

    # https://github.com/python/mypy/issues/1317
    # mypy has issues with using os.path.realpath as an argument here.
    parser.add_argument(
        '--api-map', type=os.path.realpath, required=True,
        '--api-map',
        type=os.path.realpath,  # type: ignore
        required=True,
        help='Path to the API level map JSON file.')

    parser.add_argument(
        'symbol_file', type=os.path.realpath, help='Path to symbol file.')
        'symbol_file',
        type=os.path.realpath,  # type: ignore
        help='Path to symbol file.')
    parser.add_argument(
        'stub_src', type=os.path.realpath,
        'stub_src',
        type=os.path.realpath,  # type: ignore
        help='Path to output stub source file.')
    parser.add_argument(
        'version_script', type=os.path.realpath,
        'version_script',
        type=os.path.realpath,  # type: ignore
        help='Path to output version script.')

    return parser.parse_args()


def main():
def main() -> None:
    """Program entry point."""
    args = parse_args()

cc/ndkstubgen/mypy.ini

0 → 100644
+2 −0
Original line number Diff line number Diff line
[mypy]
disallow_untyped_defs = True
+41 −38
Original line number Diff line number Diff line
@@ -21,19 +21,20 @@ import unittest

import ndkstubgen
import symbolfile
from symbolfile import Arch, Tag


# pylint: disable=missing-docstring


class GeneratorTest(unittest.TestCase):
    def test_omit_version(self):
    def test_omit_version(self) -> None:
        # Thorough testing of the cases involved here is handled by
        # OmitVersionTest, PrivateVersionTest, and SymbolPresenceTest.
        src_file = io.StringIO()
        version_file = io.StringIO()
        generator = ndkstubgen.Generator(src_file, version_file, 'arm', 9,
                                         False, False)
        generator = ndkstubgen.Generator(src_file, version_file, Arch('arm'),
                                         9, False, False)

        version = symbolfile.Version('VERSION_PRIVATE', None, [], [
            symbolfile.Symbol('foo', []),
@@ -42,74 +43,75 @@ class GeneratorTest(unittest.TestCase):
        self.assertEqual('', src_file.getvalue())
        self.assertEqual('', version_file.getvalue())

        version = symbolfile.Version('VERSION', None, ['x86'], [
        version = symbolfile.Version('VERSION', None, [Tag('x86')], [
            symbolfile.Symbol('foo', []),
        ])
        generator.write_version(version)
        self.assertEqual('', src_file.getvalue())
        self.assertEqual('', version_file.getvalue())

        version = symbolfile.Version('VERSION', None, ['introduced=14'], [
        version = symbolfile.Version('VERSION', None, [Tag('introduced=14')], [
            symbolfile.Symbol('foo', []),
        ])
        generator.write_version(version)
        self.assertEqual('', src_file.getvalue())
        self.assertEqual('', version_file.getvalue())

    def test_omit_symbol(self):
    def test_omit_symbol(self) -> None:
        # Thorough testing of the cases involved here is handled by
        # SymbolPresenceTest.
        src_file = io.StringIO()
        version_file = io.StringIO()
        generator = ndkstubgen.Generator(src_file, version_file, 'arm', 9,
                                         False, False)
        generator = ndkstubgen.Generator(src_file, version_file, Arch('arm'),
                                         9, False, False)

        version = symbolfile.Version('VERSION_1', None, [], [
            symbolfile.Symbol('foo', ['x86']),
            symbolfile.Symbol('foo', [Tag('x86')]),
        ])
        generator.write_version(version)
        self.assertEqual('', src_file.getvalue())
        self.assertEqual('', version_file.getvalue())

        version = symbolfile.Version('VERSION_1', None, [], [
            symbolfile.Symbol('foo', ['introduced=14']),
            symbolfile.Symbol('foo', [Tag('introduced=14')]),
        ])
        generator.write_version(version)
        self.assertEqual('', src_file.getvalue())
        self.assertEqual('', version_file.getvalue())

        version = symbolfile.Version('VERSION_1', None, [], [
            symbolfile.Symbol('foo', ['llndk']),
            symbolfile.Symbol('foo', [Tag('llndk')]),
        ])
        generator.write_version(version)
        self.assertEqual('', src_file.getvalue())
        self.assertEqual('', version_file.getvalue())

        version = symbolfile.Version('VERSION_1', None, [], [
            symbolfile.Symbol('foo', ['apex']),
            symbolfile.Symbol('foo', [Tag('apex')]),
        ])
        generator.write_version(version)
        self.assertEqual('', src_file.getvalue())
        self.assertEqual('', version_file.getvalue())

    def test_write(self):
    def test_write(self) -> None:
        src_file = io.StringIO()
        version_file = io.StringIO()
        generator = ndkstubgen.Generator(src_file, version_file, 'arm', 9,
                                         False, False)
        generator = ndkstubgen.Generator(src_file, version_file, Arch('arm'),
                                         9, False, False)

        versions = [
            symbolfile.Version('VERSION_1', None, [], [
                symbolfile.Symbol('foo', []),
                symbolfile.Symbol('bar', ['var']),
                symbolfile.Symbol('woodly', ['weak']),
                symbolfile.Symbol('doodly', ['weak', 'var']),
                symbolfile.Symbol('bar', [Tag('var')]),
                symbolfile.Symbol('woodly', [Tag('weak')]),
                symbolfile.Symbol('doodly',
                                  [Tag('weak'), Tag('var')]),
            ]),
            symbolfile.Version('VERSION_2', 'VERSION_1', [], [
                symbolfile.Symbol('baz', []),
            ]),
            symbolfile.Version('VERSION_3', 'VERSION_1', [], [
                symbolfile.Symbol('qux', ['versioned=14']),
                symbolfile.Symbol('qux', [Tag('versioned=14')]),
            ]),
        ]

@@ -141,7 +143,7 @@ class GeneratorTest(unittest.TestCase):


class IntegrationTest(unittest.TestCase):
    def test_integration(self):
    def test_integration(self) -> None:
        api_map = {
            'O': 9000,
            'P': 9001,
@@ -178,14 +180,14 @@ class IntegrationTest(unittest.TestCase):
                wobble;
            } VERSION_4;
        """))
        parser = symbolfile.SymbolFileParser(input_file, api_map, 'arm', 9,
                                             False, False)
        parser = symbolfile.SymbolFileParser(input_file, api_map, Arch('arm'),
                                             9, False, False)
        versions = parser.parse()

        src_file = io.StringIO()
        version_file = io.StringIO()
        generator = ndkstubgen.Generator(src_file, version_file, 'arm', 9,
                                         False, False)
        generator = ndkstubgen.Generator(src_file, version_file, Arch('arm'),
                                         9, False, False)
        generator.write(versions)

        expected_src = textwrap.dedent("""\
@@ -213,7 +215,7 @@ class IntegrationTest(unittest.TestCase):
        """)
        self.assertEqual(expected_version, version_file.getvalue())

    def test_integration_future_api(self):
    def test_integration_future_api(self) -> None:
        api_map = {
            'O': 9000,
            'P': 9001,
@@ -230,14 +232,14 @@ class IntegrationTest(unittest.TestCase):
                    *;
            };
        """))
        parser = symbolfile.SymbolFileParser(input_file, api_map, 'arm', 9001,
                                             False, False)
        parser = symbolfile.SymbolFileParser(input_file, api_map, Arch('arm'),
                                             9001, False, False)
        versions = parser.parse()

        src_file = io.StringIO()
        version_file = io.StringIO()
        generator = ndkstubgen.Generator(src_file, version_file, 'arm', 9001,
                                         False, False)
        generator = ndkstubgen.Generator(src_file, version_file, Arch('arm'),
                                         9001, False, False)
        generator.write(versions)

        expected_src = textwrap.dedent("""\
@@ -255,7 +257,7 @@ class IntegrationTest(unittest.TestCase):
        """)
        self.assertEqual(expected_version, version_file.getvalue())

    def test_multiple_definition(self):
    def test_multiple_definition(self) -> None:
        input_file = io.StringIO(textwrap.dedent("""\
            VERSION_1 {
                global:
@@ -280,8 +282,8 @@ class IntegrationTest(unittest.TestCase):
            } VERSION_2;

        """))
        parser = symbolfile.SymbolFileParser(input_file, {}, 'arm', 16, False,
                                             False)
        parser = symbolfile.SymbolFileParser(input_file, {}, Arch('arm'), 16,
                                             False, False)

        with self.assertRaises(
                symbolfile.MultiplyDefinedSymbolError) as ex_context:
@@ -289,7 +291,7 @@ class IntegrationTest(unittest.TestCase):
        self.assertEqual(['bar', 'foo'],
                         ex_context.exception.multiply_defined_symbols)

    def test_integration_with_apex(self):
    def test_integration_with_apex(self) -> None:
        api_map = {
            'O': 9000,
            'P': 9001,
@@ -328,14 +330,14 @@ class IntegrationTest(unittest.TestCase):
                wobble;
            } VERSION_4;
        """))
        parser = symbolfile.SymbolFileParser(input_file, api_map, 'arm', 9,
                                             False, True)
        parser = symbolfile.SymbolFileParser(input_file, api_map, Arch('arm'),
                                             9, False, True)
        versions = parser.parse()

        src_file = io.StringIO()
        version_file = io.StringIO()
        generator = ndkstubgen.Generator(src_file, version_file, 'arm', 9,
                                         False, True)
        generator = ndkstubgen.Generator(src_file, version_file, Arch('arm'),
                                         9, False, True)
        generator.write(versions)

        expected_src = textwrap.dedent("""\
@@ -369,7 +371,8 @@ class IntegrationTest(unittest.TestCase):
        """)
        self.assertEqual(expected_version, version_file.getvalue())

def main():

def main() -> None:
    suite = unittest.TestLoader().loadTestsFromName(__name__)
    unittest.TextTestRunner(verbosity=3).run(suite)

+74 −65
Original line number Diff line number Diff line
@@ -14,15 +14,31 @@
# limitations under the License.
#
"""Parser for Android's version script information."""
from dataclasses import dataclass
import logging
import re
from typing import (
    Dict,
    Iterable,
    List,
    Mapping,
    NewType,
    Optional,
    TextIO,
    Tuple,
)


ApiMap = Mapping[str, int]
Arch = NewType('Arch', str)
Tag = NewType('Tag', str)


ALL_ARCHITECTURES = (
    'arm',
    'arm64',
    'x86',
    'x86_64',
    Arch('arm'),
    Arch('arm64'),
    Arch('x86'),
    Arch('x86_64'),
)


@@ -30,18 +46,36 @@ ALL_ARCHITECTURES = (
FUTURE_API_LEVEL = 10000


def logger():
def logger() -> logging.Logger:
    """Return the main logger for this module."""
    return logging.getLogger(__name__)


def get_tags(line):
@dataclass
class Symbol:
    """A symbol definition from a symbol file."""

    name: str
    tags: List[Tag]


@dataclass
class Version:
    """A version block of a symbol file."""

    name: str
    base: Optional[str]
    tags: List[Tag]
    symbols: List[Symbol]


def get_tags(line: str) -> List[Tag]:
    """Returns a list of all tags on this line."""
    _, _, all_tags = line.strip().partition('#')
    return [e for e in re.split(r'\s+', all_tags) if e.strip()]
    return [Tag(e) for e in re.split(r'\s+', all_tags) if e.strip()]


def is_api_level_tag(tag):
def is_api_level_tag(tag: Tag) -> bool:
    """Returns true if this tag has an API level that may need decoding."""
    if tag.startswith('introduced='):
        return True
@@ -52,7 +86,7 @@ def is_api_level_tag(tag):
    return False


def decode_api_level(api, api_map):
def decode_api_level(api: str, api_map: ApiMap) -> int:
    """Decodes the API level argument into the API level number.

    For the average case, this just decodes the integer value from the string,
@@ -70,12 +104,13 @@ def decode_api_level(api, api_map):
    return api_map[api]


def decode_api_level_tags(tags, api_map):
def decode_api_level_tags(tags: Iterable[Tag], api_map: ApiMap) -> List[Tag]:
    """Decodes API level code names in a list of tags.

    Raises:
        ParseError: An unknown version name was found in a tag.
    """
    decoded_tags = list(tags)
    for idx, tag in enumerate(tags):
        if not is_api_level_tag(tag):
            continue
@@ -83,13 +118,13 @@ def decode_api_level_tags(tags, api_map):

        try:
            decoded = str(decode_api_level(value, api_map))
            tags[idx] = '='.join([name, decoded])
            decoded_tags[idx] = Tag('='.join([name, decoded]))
        except KeyError:
            raise ParseError('Unknown version name in tag: {}'.format(tag))
    return tags
            raise ParseError(f'Unknown version name in tag: {tag}')
    return decoded_tags


def split_tag(tag):
def split_tag(tag: Tag) -> Tuple[str, str]:
    """Returns a key/value tuple of the tag.

    Raises:
@@ -103,7 +138,7 @@ def split_tag(tag):
    return key, value


def get_tag_value(tag):
def get_tag_value(tag: Tag) -> str:
    """Returns the value of a key/value tag.

    Raises:
@@ -114,12 +149,13 @@ def get_tag_value(tag):
    return split_tag(tag)[1]


def version_is_private(version):
def version_is_private(version: str) -> bool:
    """Returns True if the version name should be treated as private."""
    return version.endswith('_PRIVATE') or version.endswith('_PLATFORM')


def should_omit_version(version, arch, api, llndk, apex):
def should_omit_version(version: Version, arch: Arch, api: int, llndk: bool,
                        apex: bool) -> bool:
    """Returns True if the version section should be ommitted.

    We want to omit any sections that do not have any symbols we'll have in the
@@ -145,7 +181,8 @@ def should_omit_version(version, arch, api, llndk, apex):
    return False


def should_omit_symbol(symbol, arch, api, llndk, apex):
def should_omit_symbol(symbol: Symbol, arch: Arch, api: int, llndk: bool,
                       apex: bool) -> bool:
    """Returns True if the symbol should be omitted."""
    no_llndk_no_apex = 'llndk' not in symbol.tags and 'apex' not in symbol.tags
    keep = no_llndk_no_apex or \
@@ -160,7 +197,7 @@ def should_omit_symbol(symbol, arch, api, llndk, apex):
    return False


def symbol_in_arch(tags, arch):
def symbol_in_arch(tags: Iterable[Tag], arch: Arch) -> bool:
    """Returns true if the symbol is present for the given architecture."""
    has_arch_tags = False
    for tag in tags:
@@ -175,7 +212,7 @@ def symbol_in_arch(tags, arch):
    return not has_arch_tags


def symbol_in_api(tags, arch, api):
def symbol_in_api(tags: Iterable[Tag], arch: Arch, api: int) -> bool:
    """Returns true if the symbol is present for the given API level."""
    introduced_tag = None
    arch_specific = False
@@ -197,7 +234,7 @@ def symbol_in_api(tags, arch, api):
    return api >= int(get_tag_value(introduced_tag))


def symbol_versioned_in_api(tags, api):
def symbol_versioned_in_api(tags: Iterable[Tag], api: int) -> bool:
    """Returns true if the symbol should be versioned for the given API.

    This models the `versioned=API` tag. This should be a very uncommonly
@@ -223,68 +260,40 @@ class ParseError(RuntimeError):

class MultiplyDefinedSymbolError(RuntimeError):
    """A symbol name was multiply defined."""
    def __init__(self, multiply_defined_symbols):
        super(MultiplyDefinedSymbolError, self).__init__(
    def __init__(self, multiply_defined_symbols: Iterable[str]) -> None:
        super().__init__(
            'Version script contains multiple definitions for: {}'.format(
                ', '.join(multiply_defined_symbols)))
        self.multiply_defined_symbols = multiply_defined_symbols


class Version:
    """A version block of a symbol file."""
    def __init__(self, name, base, tags, symbols):
        self.name = name
        self.base = base
        self.tags = tags
        self.symbols = symbols

    def __eq__(self, other):
        if self.name != other.name:
            return False
        if self.base != other.base:
            return False
        if self.tags != other.tags:
            return False
        if self.symbols != other.symbols:
            return False
        return True


class Symbol:
    """A symbol definition from a symbol file."""
    def __init__(self, name, tags):
        self.name = name
        self.tags = tags

    def __eq__(self, other):
        return self.name == other.name and set(self.tags) == set(other.tags)


class SymbolFileParser:
    """Parses NDK symbol files."""
    def __init__(self, input_file, api_map, arch, api, llndk, apex):
    def __init__(self, input_file: TextIO, api_map: ApiMap, arch: Arch,
                 api: int, llndk: bool, apex: bool) -> None:
        self.input_file = input_file
        self.api_map = api_map
        self.arch = arch
        self.api = api
        self.llndk = llndk
        self.apex = apex
        self.current_line = None
        self.current_line: Optional[str] = None

    def parse(self):
    def parse(self) -> List[Version]:
        """Parses the symbol file and returns a list of Version objects."""
        versions = []
        while self.next_line() != '':
            assert self.current_line is not None
            if '{' in self.current_line:
                versions.append(self.parse_version())
            else:
                raise ParseError(
                    'Unexpected contents at top level: ' + self.current_line)
                    f'Unexpected contents at top level: {self.current_line}')

        self.check_no_duplicate_symbols(versions)
        return versions

    def check_no_duplicate_symbols(self, versions):
    def check_no_duplicate_symbols(self, versions: Iterable[Version]) -> None:
        """Raises errors for multiply defined symbols.

        This situation is the normal case when symbol versioning is actually
@@ -312,12 +321,13 @@ class SymbolFileParser:
            raise MultiplyDefinedSymbolError(
                sorted(list(multiply_defined_symbols)))

    def parse_version(self):
    def parse_version(self) -> Version:
        """Parses a single version section and returns a Version object."""
        assert self.current_line is not None
        name = self.current_line.split('{')[0].strip()
        tags = get_tags(self.current_line)
        tags = decode_api_level_tags(tags, self.api_map)
        symbols = []
        symbols: List[Symbol] = []
        global_scope = True
        cpp_symbols = False
        while self.next_line() != '':
@@ -333,9 +343,7 @@ class SymbolFileParser:
                    cpp_symbols = False
                else:
                    base = base.rstrip(';').rstrip()
                    if base == '':
                        base = None
                    return Version(name, base, tags, symbols)
                    return Version(name, base or None, tags, symbols)
            elif 'extern "C++" {' in self.current_line:
                cpp_symbols = True
            elif not cpp_symbols and ':' in self.current_line:
@@ -354,8 +362,9 @@ class SymbolFileParser:
                pass
        raise ParseError('Unexpected EOF in version block.')

    def parse_symbol(self):
    def parse_symbol(self) -> Symbol:
        """Parses a single symbol line and returns a Symbol object."""
        assert self.current_line is not None
        if ';' not in self.current_line:
            raise ParseError(
                'Expected ; to terminate symbol: ' + self.current_line)
@@ -368,7 +377,7 @@ class SymbolFileParser:
        tags = decode_api_level_tags(tags, self.api_map)
        return Symbol(name, tags)

    def next_line(self):
    def next_line(self) -> str:
        """Returns the next non-empty non-comment line.

        A return value of '' indicates EOF.

cc/symbolfile/mypy.ini

0 → 100644
+2 −0
Original line number Diff line number Diff line
[mypy]
disallow_untyped_defs = True
Loading