Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 58992a59 authored by Chalard Jean's avatar Chalard Jean
Browse files

Move the concurrent interpreter out of TrackRecordTest.

This is to use it in TestableCallback tests.

Test: atest TrackRecordTest
Change-Id: I47e85725b39fe7d7cbc2508837968ae3775c6be9
parent e300aca6
Loading
Loading
Loading
Loading
+162 −0
Original line number Diff line number Diff line
package com.android.testutils

import android.os.SystemClock
import java.util.concurrent.CyclicBarrier
import kotlin.system.measureTimeMillis
import kotlin.test.assertEquals
import kotlin.test.assertNull
import kotlin.test.assertTrue

// The table contains pairs associating a regexp with the code to run. The statement is matched
// against each matcher in sequence and when a match is found the associated code is run, passing
// it the TrackRecord under test and the result of the regexp match.
typealias InterpretMatcher<T> = Pair<Regex, (ConcurrentIntepreter<T>, T, MatchResult) -> Any?>

// The default unit of time for interpreted tests
val INTERPRET_TIME_UNIT = 40L // ms

/**
 * A small interpreter for testing parallel code. The interpreter will read a list of lines
 * consisting of "|"-separated statements. Each column runs in a different concurrent thread
 * and all threads wait for each other in between lines. Each statement is split on ";" then
 * matched with regular expressions in the instructionTable constant, which contains the
 * code associated with each statement. The interpreter supports an object being passed to
 * the interpretTestSpec() method to be passed in each lambda (think about the object under
 * test), and an optional transform function to be executed on the object at the start of
 * every thread.
 *
 * The time unit is defined in milliseconds by the interpretTimeUnit member, which has a default
 * value but can be passed to the constructor. Whitespace is ignored.
 *
 * The interpretation table has to be passed as an argument. It's a table associating a regexp
 * with the code that should execute, as a function taking three arguments : the interpreter,
 * the regexp match, and the object. See the individual tests for the DSL of that test.
 * Implementors for new interpreting languages are encouraged to look at the defaultInterpretTable
 * constant below for an example of how to write an interpreting table.
 * Some expressions already exist by default and can be used by all interpreters. They include :
 * sleep(x) : sleeps for x time units and returns Unit ; sleep alone means sleep(1)
 * EXPR = VALUE : asserts that EXPR equals VALUE. EXPR is interpreted. VALUE can either be the
 *   string "null" or an int. Returns Unit.
 * EXPR time x..y : measures the time taken by EXPR and asserts it took at least x and at most
 *   y time units.
 * EXPR // any text : comments are ignored.
 */
open class ConcurrentIntepreter<T>(
    localInterpretTable: List<InterpretMatcher<T>>,
    val interpretTimeUnit: Long = INTERPRET_TIME_UNIT
) {
    private val interpretTable: List<InterpretMatcher<T>> =
            localInterpretTable + getDefaultInstructions()

    // Split the line into multiple statements separated by ";" and execute them. Return whatever
    // the last statement returned.
    fun interpretMultiple(instr: String, r: T): Any? {
        return instr.split(";").map { interpret(it.trim(), r) }.last()
    }

    // Match the statement to a regex and interpret it.
    fun interpret(instr: String, r: T): Any? {
        val (matcher, code) =
                interpretTable.find { instr matches it.first } ?: throw SyntaxException(instr)
        val match = matcher.matchEntire(instr) ?: throw SyntaxException(instr)
        return code(this, r, match)
    }

    // Spins as many threads as needed by the test spec and interpret each program concurrently,
    // having all threads waiting on a CyclicBarrier after each line.
    fun interpretTestSpec(spec: String, initial: T, threadTransform: (T) -> T = { it }) {
        // For nice stack traces
        val callSite = getCallingMethod()
        val lines = spec.trim().trim('\n').split("\n").map { it.split("|") }
        // |threads| contains arrays of strings that make up the statements of a thread : in other
        // words, it's an array that contains a list of statements for each column in the spec.
        val threadCount = lines[0].size
        assertTrue(lines.all { it.size == threadCount })
        val threadInstructions = (0 until threadCount).map { i -> lines.map { it[i].trim() } }
        val barrier = CyclicBarrier(threadCount)
        var crash: InterpretException? = null
        threadInstructions.mapIndexed { threadIndex, instructions ->
            Thread {
                val threadLocal = threadTransform(initial)
                barrier.await()
                var lineNum = 0
                instructions.forEach {
                    if (null != crash) return@Thread
                    lineNum += 1
                    try {
                        interpretMultiple(it, threadLocal)
                    } catch (e: Throwable) {
                        // If fail() or some exception was called, the thread will come here ; if
                        // the exception isn't caught the process will crash, which is not nice for
                        // testing. Instead, catch the exception, cancel other threads, and report
                        // nicely. Catch throwable because fail() is AssertionError, which inherits
                        // from Error.
                        crash = InterpretException(threadIndex, callSite.lineNumber + lineNum,
                                callSite.className, callSite.methodName, callSite.fileName, e)
                    }
                    barrier.await()
                }
            }.also { it.start() }
        }.forEach { it.join() }
        // If the test failed, crash with line number
        crash?.let { throw it }
    }

    // Helper to get the stack trace for a calling method
    protected fun getCallingMethod(depth: Int): StackTraceElement {
        try {
            throw RuntimeException()
        } catch (e: RuntimeException) {
            return e.stackTrace[depth]
        }
    }

    // Override this if you don't call interpretTestSpec directly to get the correct file
    // and line for failure in the error message.
    // 0 is this method here, 1 is getCallingMethod(int), 2 is interpretTestSpec, 3 the lambda
    open fun getCallingMethod() = getCallingMethod(4)
}

private fun <T> getDefaultInstructions() = listOf<InterpretMatcher<T>>(
    // Interpret an empty line as doing nothing.
    Regex("") to { _, _, _ -> null },
    // Ignore comments.
    Regex("(.*)//.*") to { i, t, r -> i.interpret(r.strArg(1), t) },
    // Interpret "XXX time x..y" : run XXX and check it took at least x and not more than y
    Regex("""(.*)\s*time\s*(\d+)\.\.(\d+)""") to { i, t, r ->
        assertTrue(measureTimeMillis { i.interpret(r.strArg(1), t) } in r.timeArg(2)..r.timeArg(3))
    },
    // Interpret "XXX = YYY" : run XXX and assert its return value is equal to YYY. "null" supported
    Regex("""(.*)\s*=\s*(null|\d+)""") to { i, t, r ->
        i.interpret(r.strArg(1), t).also {
            if ("null" == r.strArg(2)) assertNull(it) else assertEquals(r.intArg(2), it)
        }
    },
    // Interpret sleep. Optional argument for the count, in INTERPRET_TIME_UNIT units.
    Regex("""sleep(\((\d+)\))?""") to { i, t, r ->
        SystemClock.sleep(if (r.strArg(2).isEmpty()) i.interpretTimeUnit else r.timeArg(2))
    }
)

class SyntaxException(msg: String, cause: Throwable? = null) : RuntimeException(msg, cause)
class InterpretException(
    threadIndex: Int,
    lineNum: Int,
    className: String,
    methodName: String,
    fileName: String,
    cause: Throwable
) : RuntimeException(cause) {
    init {
        stackTrace = arrayOf(StackTraceElement(
                className,
                "$methodName:thread$threadIndex",
                fileName,
                lineNum)) + super.getStackTrace()
    }
}

// Some small helpers to avoid to say the large ".groupValues[index].trim()" every time
fun MatchResult.strArg(index: Int) = this.groupValues[index].trim()
fun MatchResult.intArg(index: Int) = strArg(index).toInt()
fun MatchResult.timeArg(index: Int) = INTERPRET_TIME_UNIT * intArg(index)
+50 −168
Original line number Diff line number Diff line
@@ -16,11 +16,17 @@

package android.net.testutils

import android.os.SystemClock
import com.android.testutils.ArrayTrackRecord
import com.android.testutils.ConcurrentIntepreter
import com.android.testutils.InterpretException
import com.android.testutils.InterpretMatcher
import com.android.testutils.SyntaxException
import com.android.testutils.TrackRecord
import com.android.testutils.__FILE__
import com.android.testutils.__LINE__
import com.android.testutils.intArg
import com.android.testutils.strArg
import com.android.testutils.timeArg
import org.junit.Test
import org.junit.runner.RunWith
import org.junit.runners.JUnit4
@@ -43,9 +49,6 @@ const val SHORT_TIMEOUT = 40L // ms
const val TEST_TIMEOUT = 200L // ms
const val LONG_TIMEOUT = 5000L // ms

// The unit of time for interpreted tests
const val INTERPRET_TIME_UNIT = SHORT_TIMEOUT

@RunWith(JUnit4::class)
class TrackRecordTest {
    @Test
@@ -217,7 +220,7 @@ class TrackRecordTest {
    fun testInterpreter() {
        val interpretLine = __LINE__ + 2
        try {
            interpretTestSpec(useReadHeads = true, spec = """
            TRTInterpreter.interpretTestSpec(useReadHeads = true, spec = """
                add(4) | poll(1, 0) = 5
            """)
            fail("This spec should have thrown")
@@ -232,7 +235,7 @@ class TrackRecordTest {

    @Test
    fun testMultipleAdds() {
        interpretTestSpec(useReadHeads = false, spec = """
        TRTInterpreter.interpretTestSpec(useReadHeads = false, spec = """
            add(2)         |                |                |
                           | add(4)         |                |
                           |                | add(6)         |
@@ -246,7 +249,7 @@ class TrackRecordTest {

    @Test
    fun testConcurrentAdds() {
        interpretTestSpec(useReadHeads = false, spec = """
        TRTInterpreter.interpretTestSpec(useReadHeads = false, spec = """
            add(2)             | add(4)             | add(6)             | add(8)
            add(1)             | add(3)             | add(5)             | add(7)
            poll(0, 1) is even | poll(0, 0) is even | poll(0, 3) is even | poll(0, 2) is even
@@ -256,7 +259,7 @@ class TrackRecordTest {

    @Test
    fun testMultiplePoll() {
        interpretTestSpec(useReadHeads = false, spec = """
        TRTInterpreter.interpretTestSpec(useReadHeads = false, spec = """
            add(4)         | poll(1, 0) = 4
                           | poll(0, 1) = null time 0..1
                           | poll(1, 1) = null time 1..2
@@ -267,7 +270,7 @@ class TrackRecordTest {

    @Test
    fun testMultiplePollWithPredicate() {
        interpretTestSpec(useReadHeads = false, spec = """
        TRTInterpreter.interpretTestSpec(useReadHeads = false, spec = """
                     | poll(1, 0) = null          | poll(1, 0) = null
            add(6)   | poll(1, 0) = 6             |
            add(11)  | poll(1, 0) { > 20 } = null | poll(1, 0) { = 11 } = 11
@@ -277,7 +280,7 @@ class TrackRecordTest {

    @Test
    fun testMultipleReadHeads() {
        interpretTestSpec(useReadHeads = true, spec = """
        TRTInterpreter.interpretTestSpec(useReadHeads = true, spec = """
                   | poll() = null | poll() = null | poll() = null
            add(5) |               | poll() = 5    |
                   | poll() = 5    |               |
@@ -291,7 +294,7 @@ class TrackRecordTest {

    @Test
    fun testReadHeadPollWithPredicate() {
        interpretTestSpec(useReadHeads = true, spec = """
        TRTInterpreter.interpretTestSpec(useReadHeads = true, spec = """
            add(5)  | poll() { < 0 } = null
                    | poll() { > 5 } = null
            add(10) |
@@ -303,7 +306,7 @@ class TrackRecordTest {

    @Test
    fun testPollImmediatelyAdvancesReadhead() {
        interpretTestSpec(useReadHeads = true, spec = """
        TRTInterpreter.interpretTestSpec(useReadHeads = true, spec = """
            add(1)                  | add(2)              | add(3)   | add(4)
            mark = 0                | poll(0) { > 3 } = 4 |          |
            poll(0) { > 10 } = null |                     |          |
@@ -314,7 +317,7 @@ class TrackRecordTest {

    @Test
    fun testParallelReadHeads() {
        interpretTestSpec(useReadHeads = true, spec = """
        TRTInterpreter.interpretTestSpec(useReadHeads = true, spec = """
            mark = 0   | mark = 0   | mark = 0   | mark = 0
            add(2)     |            |            |
                       | add(4)     |            |
@@ -330,7 +333,7 @@ class TrackRecordTest {

    @Test
    fun testPeek() {
        interpretTestSpec(useReadHeads = true, spec = """
        TRTInterpreter.interpretTestSpec(useReadHeads = true, spec = """
            add(2)     |            |               |
                       | add(4)     |               |
                       |            | add(6)        |
@@ -345,29 +348,20 @@ class TrackRecordTest {
            peek() = 6 | peek() = 6 | peek() = null | mark = 3
        """)
    }
    /**
     * // TODO : don't submit without this.
     * Test poll()
     *   - Check that it immediately finds added stuff that matches
     * Test ReadHead#poll()
     *   - All of the above, and:
     *   - Put stuff, check that it timeouts when it doesn't match the predicate, and the read head
     *     has advanced
     *   - Check that it immediately advances the read head
     *   - Check multiple read heads in different threads
     * Test ReadHead#peek()
     */
}

/**
 * A small interpreter for testing parallel code. The interpreter will read a list of lines
 * consisting of "|"-separated statements. Each column runs in a different concurrent thread
 * and all threads wait for each other in between lines. Each statement is split on ";" then
 * matched with regular expressions in the instructionTable constant, which contains the
 * code associated with each statement.
 *
 * The time unit is defined in milliseconds by the INTERPRET_TIME_UNIT constant. Whitespace is
 * ignored. Quick ref of supported expressions :
private object TRTInterpreter : ConcurrentIntepreter<TrackRecord<Int>>(interpretTable) {
    fun interpretTestSpec(spec: String, useReadHeads: Boolean) = if (useReadHeads) {
        interpretTestSpec(spec, ArrayTrackRecord(), { (it as ArrayTrackRecord).newReadHead() })
    } else {
        interpretTestSpec(spec, ArrayTrackRecord())
    }

    override fun getCallingMethod() = getCallingMethod(4)
}

/*
 * Quick ref of supported expressions :
 * sleep(x) : sleeps for x time units and returns Unit ; sleep alone means sleep(1)
 * add(x) : calls and returns TrackRecord#add.
 * poll(time, pos) [{ predicate }] : calls and returns TrackRecord#poll(x time units, pos).
@@ -380,160 +374,48 @@ class TrackRecordTest {
 *   y time units.
 * predicate must be one of "= x", "< x" or "> x".
 */
class SyntaxException(msg: String, cause: Throwable? = null) : RuntimeException(msg, cause)
class InterpretException(
    threadIndex: Int,
    lineNum: Int,
    className: String,
    methodName: String,
    fileName: String,
    cause: Throwable
) : RuntimeException(cause) {
    init {
        stackTrace = arrayOf(StackTraceElement(
                className,
                "$methodName:thread$threadIndex",
                fileName,
                lineNum)) + super.getStackTrace()
    }
}

// Some small helpers to avoid to say the large ".groupValues[index].trim()" every time
private fun MatchResult.strArg(index: Int) = this.groupValues[index].trim()
private fun MatchResult.intArg(index: Int) = strArg(index).toInt()
private fun MatchResult.timeArg(index: Int) = INTERPRET_TIME_UNIT * intArg(index)

// Parses a { = x } or { < x } or { > x } string and returns the corresponding predicate
// Returns an always-true predicate for empty and null arguments
private fun makePredicate(spec: String?): (Int) -> Boolean {
    if (spec.isNullOrEmpty()) return { true }
    val match = Regex("""\{\s*([<>=])\s*(\d+)\s*\}""").matchEntire(spec)
    if (null == match) throw SyntaxException("Predicate \"${spec}\"")
    val arg = match.intArg(2)
    return when (match.strArg(1)) {
        ">" -> { i -> i > arg }
        "<" -> { i -> i < arg }
        "=" -> { i -> i == arg }
        else -> throw RuntimeException("How did \"${spec}\" match this regexp ?")
    }
}

const val DEBUG_INTERPRETER = true

// The table contains pairs associating a regexp with the code to run. The statement is matched
// against each matcher in sequence and when a match is found the associated code is run, passing
// it the TrackRecord under test and the result of the regexp match.
typealias InterpretMatcher = Pair<Regex, (TrackRecord<Int>, MatchResult) -> Any?>

val interpretTable = listOf<InterpretMatcher>(
    // Interpret an empty line as doing nothing.
    Regex("") to { _, _ -> null },
    Regex("(.*)//.*") to { t, r -> interpret(r.strArg(1), t) },
    // Interpret "XXX time x..y" : run XXX and check it took at least x and not more than y
    Regex("""(.*)\s*time\s*(\d+)\.\.(\d+)""") to { t, r ->
        assertTrue(measureTimeMillis { interpret(r.strArg(1), t) } in r.timeArg(2)..r.timeArg(3))
    },
    // Interpret "XXX = YYY" : run XXX and assert its return value is equal to YYY. "null" supported
    Regex("""(.*)\s*=\s*(null|\d+)""") to { t, r ->
        interpret(r.strArg(1), t).also {
            if ("null" == r.strArg(2)) assertNull(it) else assertEquals(r.intArg(2), it)
        }
    },
private val interpretTable = listOf<InterpretMatcher<TrackRecord<Int>>>(
    // Interpret "XXX is odd" : run XXX and assert its return value is odd ("even" works too)
    Regex("(.*)\\s+is\\s+(even|odd)") to { t, r ->
        interpret(r.strArg(1), t).also {
    Regex("(.*)\\s+is\\s+(even|odd)") to { i, t, r ->
        i.interpret(r.strArg(1), t).also {
            assertEquals((it as Int) % 2, if ("even" == r.strArg(2)) 0 else 1)
        }
    },
    // Interpret sleep. Optional argument for the count, in INTERPRET_TIME_UNIT units.
    Regex("""sleep(\((\d+)\))?""") to { t, r ->
        SystemClock.sleep(if (r.strArg(2).isEmpty()) INTERPRET_TIME_UNIT else r.timeArg(2))
    },
    // Interpret "add(XXX)" as TrackRecord#add(int)
    Regex("""add\((\d+)\)""") to { t, r ->
    Regex("""add\((\d+)\)""") to { i, t, r ->
        t.add(r.intArg(1))
    },
    // Interpret "poll(x, y)" as TrackRecord#poll(timeout = x * INTERPRET_TIME_UNIT, pos = y)
    // Accepts an optional {} argument for the predicate (see makePredicate for syntax)
    Regex("""poll\((\d+),\s*(\d+)\)\s*(\{.*\})?""") to { t, r ->
    Regex("""poll\((\d+),\s*(\d+)\)\s*(\{.*\})?""") to { i, t, r ->
        t.poll(r.timeArg(1), r.intArg(2), makePredicate(r.strArg(3)))
    },
    // ReadHead#poll. If this throws in the cast, the code is malformed and has passed "poll()"
    // in a test that takes a TrackRecord that is not a ReadHead. It's technically possible to get
    // the test code to not compile instead of throw, but it's vastly more complex and this will
    // fail 100% at runtime any test that would not have compiled.
    Regex("""poll\((\d+)?\)\s*(\{.*\})?""") to { t, r ->
        (if (r.strArg(1).isEmpty()) INTERPRET_TIME_UNIT else r.timeArg(1)).let { time ->
    Regex("""poll\((\d+)?\)\s*(\{.*\})?""") to { i, t, r ->
        (if (r.strArg(1).isEmpty()) i.interpretTimeUnit else r.timeArg(1)).let { time ->
            (t as ArrayTrackRecord<Int>.ReadHead).poll(time, makePredicate(r.strArg(2)))
        }
    },
    // ReadHead#mark. The same remarks apply as with ReadHead#poll.
    Regex("mark") to { t, _ -> (t as ArrayTrackRecord<Int>.ReadHead).mark },
    Regex("mark") to { i, t, _ -> (t as ArrayTrackRecord<Int>.ReadHead).mark },
    // ReadHead#peek. The same remarks apply as with ReadHead#poll.
    Regex("peek\\(\\)") to { t, _ -> (t as ArrayTrackRecord<Int>.ReadHead).peek() }
    Regex("peek\\(\\)") to { i, t, _ -> (t as ArrayTrackRecord<Int>.ReadHead).peek() }
)

// Split the line into multiple statements separated by ";" and execute them. Return whatever
// the last statement returned.
private fun <T : TrackRecord<Int>> interpretMultiple(instruction: String, r: T): Any? {
    return instruction.split(";").map { interpret(it.trim(), r) }.last()
}
// Match the statement to a regex and interpret it.
private fun <T : TrackRecord<Int>> interpret(instr: String, r: T): Any? {
    val (matcher, code) =
            interpretTable.find { instr matches it.first } ?: throw SyntaxException(instr)
    val match = matcher.matchEntire(instr) ?: throw SyntaxException(instr)
    return code(r, match)
}

// Create the ArrayTrackRecord<Int> under test, then spins as many threads as needed by the test
// spec and interpret each program concurrently, having all threads waiting on a CyclicBarrier
// after each line. If |useReadHeads| is true, it will create a ReadHead over the ArrayTrackRecord
// in each thread and call the interpreted methods on that ; if it's false, it will call the
// interpreted methods on the ArrayTrackRecord directly. Be careful that some instructions may
// only be supported on ReadHead, and will throw if called when using useReadHeads = false.
private fun interpretTestSpec(useReadHeads: Boolean, spec: String) {
    // For nice stack traces
    val callSite = getCallingMethod()
    val lines = spec.trim().trim('\n').split("\n").map { it.split("|") }
    // |threads| contains arrays of strings that make up the statements of a thread : in other
    // words, it's an array that contains a list of statements for each column in the spec.
    val threadCount = lines[0].size
    assertTrue(lines.all { it.size == threadCount })
    val threadInstructions = (0 until threadCount).map { i -> lines.map { it[i].trim() } }
    val barrier = CyclicBarrier(threadCount)
    val rec = ArrayTrackRecord<Int>()
    var crash: InterpretException? = null
    threadInstructions.mapIndexed { threadIndex, instructions ->
        Thread {
            val rh = if (useReadHeads) rec.newReadHead() else rec
            barrier.await()
            var lineNum = 0
            instructions.forEach {
                if (null != crash) return@Thread
                lineNum += 1
                try {
                    interpretMultiple(it, rh)
                } catch (e: Throwable) {
                    // If fail() or some exception was called, the thread will come here ; if the
                    // exception isn't caught the process will crash, which is not nice for testing.
                    // Instead, catch the exception, cancel other threads, and report nicely.
                    // Catch throwable because fail() is AssertionError, which inherits from Error.
                    crash = InterpretException(threadIndex, callSite.lineNumber + lineNum,
                            callSite.className, callSite.methodName, callSite.fileName, e)
                }
                barrier.await()
            }
        }.also { it.start() }
    }.forEach { it.join() }
    // If the test failed, crash with line number
    crash?.let { throw it }
}

private fun getCallingMethod(): StackTraceElement {
    try {
        throw RuntimeException()
    } catch (e: RuntimeException) {
        return e.stackTrace[3] // 0 is this method here, 1 is interpretTestSpec, 2 the lambda
// Parses a { = x } or { < x } or { > x } string and returns the corresponding predicate
// Returns an always-true predicate for empty and null arguments
private fun makePredicate(spec: String?): (Int) -> Boolean {
    if (spec.isNullOrEmpty()) return { true }
    val match = Regex("""\{\s*([<>=])\s*(\d+)\s*\}""").matchEntire(spec)
            ?: throw SyntaxException("Predicate \"${spec}\"")
    val arg = match.intArg(2)
    return when (match.strArg(1)) {
        ">" -> { i -> i > arg }
        "<" -> { i -> i < arg }
        "=" -> { i -> i == arg }
        else -> throw RuntimeException("How did \"${spec}\" match this regexp ?")
    }
}