Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 61aaa158 authored by Makoto Onuki's avatar Makoto Onuki Committed by Android (Google) Code Review
Browse files

Merge "Script to help convert AndroidTestCase tests" into main

parents 988b4f6f d1435467
Loading
Loading
Loading
Loading
+184 −0
Original line number Diff line number Diff line
#!/usr/bin/python3
# Copyright (C) 2024 The Android Open Source Project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# This script converts a legacy test class (using AndroidTestCase, TestCase or
# InstrumentationTestCase to a modern style test class, in a best-effort manner.
#
# Usage:
#  convert-androidtest.py TARGET-FILE [TARGET-FILE ...]
#
# Caveats:
#   - It adds all the extra imports, even if they're not needed.
#   - It won't sort imports.
#   - It also always adds getContext() and getTestContext().
#

import sys
import fileinput
import re
import subprocess

# Print message on console
def log(msg):
    print(msg, file=sys.stderr)


# Matches `extends AndroidTestCase` (or another similar base class)
re_extends = re.compile(
    r''' \b extends \s+ (AndroidTestCase|TestCase|InstrumentationTestCase) \s* ''',
    re.S + re.X)


# Look into given files and return the files that have `re_extends`.
def find_target_files(files):
    ret = []

    for file in files:
        try:
            with open(file, 'r') as f:
                data = f.read()

                if re_extends.search(data):
                    ret.append(file)

        except FileNotFoundError as e:
            log(f'Failed to open file {file}: {e}')

    return ret


def main(args):
    files = args

    # Find the files that should be processed.
    files = find_target_files(files)

    if len(files) == 0:
        log("No target files found.")
        return 0

    # Process the files.
    with fileinput.input(files=(files), inplace = True, backup = '.bak') as f:
        import_seen = False
        carry_over = ''
        class_body_started = False
        class_seen = False

        def on_file_start():
            nonlocal import_seen, carry_over, class_body_started, class_seen
            import_seen = False
            carry_over = ''
            class_body_started = False
            class_seen = False

        for line in f:
            if (fileinput.filelineno() == 1):
                log(f"Processing: {fileinput.filename()}")
                on_file_start()

            line = line.rstrip('\n')

            # Carry over a certain line to the next line.
            if re.search(r'''@Override\b''', line):
                carry_over = carry_over + line + '\n'
                continue

            if carry_over:
                line = carry_over + line
                carry_over = ''


            # Remove the base class from the class definition.
            line = re_extends.sub('', line)

            # Add a @RunWith.
            if not class_seen and re.search(r'''\b class \b''', line, re.X):
                class_seen = True
                print("@RunWith(AndroidJUnit4.class)")


            # Inject extra imports.
            if not import_seen and re.search(r'''^import\b''', line):
                import_seen = True
                print("""\
import android.content.Context;
import androidx.test.platform.app.InstrumentationRegistry;

import static junit.framework.TestCase.assertEquals;
import static junit.framework.TestCase.assertSame;
import static junit.framework.TestCase.assertNotSame;
import static junit.framework.TestCase.assertTrue;
import static junit.framework.TestCase.assertFalse;
import static junit.framework.TestCase.assertNull;
import static junit.framework.TestCase.assertNotNull;
import static junit.framework.TestCase.fail;

import org.junit.After;
import org.junit.Before;
import org.junit.runner.RunWith;
import org.junit.Test;

import androidx.test.ext.junit.runners.AndroidJUnit4;
""")

            # Add @Test to the test methods.
            if re.search(r'''^ \s* public \s* void \s* test''', line, re.X):
                print("    @Test")

            # Convert setUp/tearDown to @Before/@After.
            if re.search(r''' ^\s+ ( \@Override \s+ ) ? (public|protected) \s+ void \s+ (setUp|tearDown) ''',
                        line, re.X):
                if re.search('setUp', line):
                    print('    @Before')
                else:
                    print('    @After')

                line = re.sub(r''' \s* \@Override \s* \n ''', '', line, 0, re.X)
                line = re.sub(r'''protected''', 'public', line, 0, re.X)

            # Remove the super setUp / tearDown call.
            if re.search(r''' \b super \. (setUp|tearDown) \b ''', line, re.X):
                continue

            # Convert mContext to getContext().
            line = re.sub(r'''\b mContext \b ''', 'getContext()', line, 0, re.X)

            # Print the processed line.
            print(line)

            # Add getContext() / getTestContext() at the beginning of the class.
            if not class_body_started and re.search(r'''\{''', line):
                class_body_started = True
                print("""\
    private Context getContext() {
        return InstrumentationRegistry.getInstrumentation().getTargetContext();
    }

    private Context getTestContext() {
        return InstrumentationRegistry.getInstrumentation().getContext();
    }
""")


    # Run diff
    for file in files:
        subprocess.call(["diff", "-u", "--color=auto", f"{file}.bak", file])

    log(f'{len(files)} file(s) converted.')

    return 0

if __name__ == '__main__':
    sys.exit(main(sys.argv[1:]))