Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 5630a835 authored by Chalard Jean's avatar Chalard Jean Committed by android-build-merger
Browse files

Move the concurrent interpreter out of TrackRecordTest.

am: 58992a59

Change-Id: I4b48f4de33201a9456daf86a5b248d2ffc95b64c
parents 0a287fcc 58992a59
Loading
Loading
Loading
Loading
+162 −0
Original line number Diff line number Diff line
package com.android.testutils

import android.os.SystemClock
import java.util.concurrent.CyclicBarrier
import kotlin.system.measureTimeMillis
import kotlin.test.assertEquals
import kotlin.test.assertNull
import kotlin.test.assertTrue

// The table contains pairs associating a regexp with the code to run. The statement is matched
// against each matcher in sequence and when a match is found the associated code is run, passing
// it the TrackRecord under test and the result of the regexp match.
typealias InterpretMatcher<T> = Pair<Regex, (ConcurrentIntepreter<T>, T, MatchResult) -> Any?>

// The default unit of time for interpreted tests
val INTERPRET_TIME_UNIT = 40L // ms

/**
 * A small interpreter for testing parallel code. The interpreter will read a list of lines
 * consisting of "|"-separated statements. Each column runs in a different concurrent thread
 * and all threads wait for each other in between lines. Each statement is split on ";" then
 * matched with regular expressions in the instructionTable constant, which contains the
 * code associated with each statement. The interpreter supports an object being passed to
 * the interpretTestSpec() method to be passed in each lambda (think about the object under
 * test), and an optional transform function to be executed on the object at the start of
 * every thread.
 *
 * The time unit is defined in milliseconds by the interpretTimeUnit member, which has a default
 * value but can be passed to the constructor. Whitespace is ignored.
 *
 * The interpretation table has to be passed as an argument. It's a table associating a regexp
 * with the code that should execute, as a function taking three arguments : the interpreter,
 * the regexp match, and the object. See the individual tests for the DSL of that test.
 * Implementors for new interpreting languages are encouraged to look at the defaultInterpretTable
 * constant below for an example of how to write an interpreting table.
 * Some expressions already exist by default and can be used by all interpreters. They include :
 * sleep(x) : sleeps for x time units and returns Unit ; sleep alone means sleep(1)
 * EXPR = VALUE : asserts that EXPR equals VALUE. EXPR is interpreted. VALUE can either be the
 *   string "null" or an int. Returns Unit.
 * EXPR time x..y : measures the time taken by EXPR and asserts it took at least x and at most
 *   y time units.
 * EXPR // any text : comments are ignored.
 */
open class ConcurrentIntepreter<T>(
    localInterpretTable: List<InterpretMatcher<T>>,
    val interpretTimeUnit: Long = INTERPRET_TIME_UNIT
) {
    private val interpretTable: List<InterpretMatcher<T>> =
            localInterpretTable + getDefaultInstructions()

    // Split the line into multiple statements separated by ";" and execute them. Return whatever
    // the last statement returned.
    fun interpretMultiple(instr: String, r: T): Any? {
        return instr.split(";").map { interpret(it.trim(), r) }.last()
    }

    // Match the statement to a regex and interpret it.
    fun interpret(instr: String, r: T): Any? {
        val (matcher, code) =
                interpretTable.find { instr matches it.first } ?: throw SyntaxException(instr)
        val match = matcher.matchEntire(instr) ?: throw SyntaxException(instr)
        return code(this, r, match)
    }

    // Spins as many threads as needed by the test spec and interpret each program concurrently,
    // having all threads waiting on a CyclicBarrier after each line.
    fun interpretTestSpec(spec: String, initial: T, threadTransform: (T) -> T = { it }) {
        // For nice stack traces
        val callSite = getCallingMethod()
        val lines = spec.trim().trim('\n').split("\n").map { it.split("|") }
        // |threads| contains arrays of strings that make up the statements of a thread : in other
        // words, it's an array that contains a list of statements for each column in the spec.
        val threadCount = lines[0].size
        assertTrue(lines.all { it.size == threadCount })
        val threadInstructions = (0 until threadCount).map { i -> lines.map { it[i].trim() } }
        val barrier = CyclicBarrier(threadCount)
        var crash: InterpretException? = null
        threadInstructions.mapIndexed { threadIndex, instructions ->
            Thread {
                val threadLocal = threadTransform(initial)
                barrier.await()
                var lineNum = 0
                instructions.forEach {
                    if (null != crash) return@Thread
                    lineNum += 1
                    try {
                        interpretMultiple(it, threadLocal)
                    } catch (e: Throwable) {
                        // If fail() or some exception was called, the thread will come here ; if
                        // the exception isn't caught the process will crash, which is not nice for
                        // testing. Instead, catch the exception, cancel other threads, and report
                        // nicely. Catch throwable because fail() is AssertionError, which inherits
                        // from Error.
                        crash = InterpretException(threadIndex, callSite.lineNumber + lineNum,
                                callSite.className, callSite.methodName, callSite.fileName, e)
                    }
                    barrier.await()
                }
            }.also { it.start() }
        }.forEach { it.join() }
        // If the test failed, crash with line number
        crash?.let { throw it }
    }

    // Helper to get the stack trace for a calling method
    protected fun getCallingMethod(depth: Int): StackTraceElement {
        try {
            throw RuntimeException()
        } catch (e: RuntimeException) {
            return e.stackTrace[depth]
        }
    }

    // Override this if you don't call interpretTestSpec directly to get the correct file
    // and line for failure in the error message.
    // 0 is this method here, 1 is getCallingMethod(int), 2 is interpretTestSpec, 3 the lambda
    open fun getCallingMethod() = getCallingMethod(4)
}

private fun <T> getDefaultInstructions() = listOf<InterpretMatcher<T>>(
    // Interpret an empty line as doing nothing.
    Regex("") to { _, _, _ -> null },
    // Ignore comments.
    Regex("(.*)//.*") to { i, t, r -> i.interpret(r.strArg(1), t) },
    // Interpret "XXX time x..y" : run XXX and check it took at least x and not more than y
    Regex("""(.*)\s*time\s*(\d+)\.\.(\d+)""") to { i, t, r ->
        assertTrue(measureTimeMillis { i.interpret(r.strArg(1), t) } in r.timeArg(2)..r.timeArg(3))
    },
    // Interpret "XXX = YYY" : run XXX and assert its return value is equal to YYY. "null" supported
    Regex("""(.*)\s*=\s*(null|\d+)""") to { i, t, r ->
        i.interpret(r.strArg(1), t).also {
            if ("null" == r.strArg(2)) assertNull(it) else assertEquals(r.intArg(2), it)
        }
    },
    // Interpret sleep. Optional argument for the count, in INTERPRET_TIME_UNIT units.
    Regex("""sleep(\((\d+)\))?""") to { i, t, r ->
        SystemClock.sleep(if (r.strArg(2).isEmpty()) i.interpretTimeUnit else r.timeArg(2))
    }
)

class SyntaxException(msg: String, cause: Throwable? = null) : RuntimeException(msg, cause)
class InterpretException(
    threadIndex: Int,
    lineNum: Int,
    className: String,
    methodName: String,
    fileName: String,
    cause: Throwable
) : RuntimeException(cause) {
    init {
        stackTrace = arrayOf(StackTraceElement(
                className,
                "$methodName:thread$threadIndex",
                fileName,
                lineNum)) + super.getStackTrace()
    }
}

// Some small helpers to avoid to say the large ".groupValues[index].trim()" every time
fun MatchResult.strArg(index: Int) = this.groupValues[index].trim()
fun MatchResult.intArg(index: Int) = strArg(index).toInt()
fun MatchResult.timeArg(index: Int) = INTERPRET_TIME_UNIT * intArg(index)
+50 −168
Original line number Diff line number Diff line
@@ -16,11 +16,17 @@

package android.net.testutils

import android.os.SystemClock
import com.android.testutils.ArrayTrackRecord
import com.android.testutils.ConcurrentIntepreter
import com.android.testutils.InterpretException
import com.android.testutils.InterpretMatcher
import com.android.testutils.SyntaxException
import com.android.testutils.TrackRecord
import com.android.testutils.__FILE__
import com.android.testutils.__LINE__
import com.android.testutils.intArg
import com.android.testutils.strArg
import com.android.testutils.timeArg
import org.junit.Test
import org.junit.runner.RunWith
import org.junit.runners.JUnit4
@@ -43,9 +49,6 @@ const val SHORT_TIMEOUT = 40L // ms
const val TEST_TIMEOUT = 200L // ms
const val LONG_TIMEOUT = 5000L // ms

// The unit of time for interpreted tests
const val INTERPRET_TIME_UNIT = SHORT_TIMEOUT

@RunWith(JUnit4::class)
class TrackRecordTest {
    @Test
@@ -217,7 +220,7 @@ class TrackRecordTest {
    fun testInterpreter() {
        val interpretLine = __LINE__ + 2
        try {
            interpretTestSpec(useReadHeads = true, spec = """
            TRTInterpreter.interpretTestSpec(useReadHeads = true, spec = """
                add(4) | poll(1, 0) = 5
            """)
            fail("This spec should have thrown")
@@ -232,7 +235,7 @@ class TrackRecordTest {

    @Test
    fun testMultipleAdds() {
        interpretTestSpec(useReadHeads = false, spec = """
        TRTInterpreter.interpretTestSpec(useReadHeads = false, spec = """
            add(2)         |                |                |
                           | add(4)         |                |
                           |                | add(6)         |
@@ -246,7 +249,7 @@ class TrackRecordTest {

    @Test
    fun testConcurrentAdds() {
        interpretTestSpec(useReadHeads = false, spec = """
        TRTInterpreter.interpretTestSpec(useReadHeads = false, spec = """
            add(2)             | add(4)             | add(6)             | add(8)
            add(1)             | add(3)             | add(5)             | add(7)
            poll(0, 1) is even | poll(0, 0) is even | poll(0, 3) is even | poll(0, 2) is even
@@ -256,7 +259,7 @@ class TrackRecordTest {

    @Test
    fun testMultiplePoll() {
        interpretTestSpec(useReadHeads = false, spec = """
        TRTInterpreter.interpretTestSpec(useReadHeads = false, spec = """
            add(4)         | poll(1, 0) = 4
                           | poll(0, 1) = null time 0..1
                           | poll(1, 1) = null time 1..2
@@ -267,7 +270,7 @@ class TrackRecordTest {

    @Test
    fun testMultiplePollWithPredicate() {
        interpretTestSpec(useReadHeads = false, spec = """
        TRTInterpreter.interpretTestSpec(useReadHeads = false, spec = """
                     | poll(1, 0) = null          | poll(1, 0) = null
            add(6)   | poll(1, 0) = 6             |
            add(11)  | poll(1, 0) { > 20 } = null | poll(1, 0) { = 11 } = 11
@@ -277,7 +280,7 @@ class TrackRecordTest {

    @Test
    fun testMultipleReadHeads() {
        interpretTestSpec(useReadHeads = true, spec = """
        TRTInterpreter.interpretTestSpec(useReadHeads = true, spec = """
                   | poll() = null | poll() = null | poll() = null
            add(5) |               | poll() = 5    |
                   | poll() = 5    |               |
@@ -291,7 +294,7 @@ class TrackRecordTest {

    @Test
    fun testReadHeadPollWithPredicate() {
        interpretTestSpec(useReadHeads = true, spec = """
        TRTInterpreter.interpretTestSpec(useReadHeads = true, spec = """
            add(5)  | poll() { < 0 } = null
                    | poll() { > 5 } = null
            add(10) |
@@ -303,7 +306,7 @@ class TrackRecordTest {

    @Test
    fun testPollImmediatelyAdvancesReadhead() {
        interpretTestSpec(useReadHeads = true, spec = """
        TRTInterpreter.interpretTestSpec(useReadHeads = true, spec = """
            add(1)                  | add(2)              | add(3)   | add(4)
            mark = 0                | poll(0) { > 3 } = 4 |          |
            poll(0) { > 10 } = null |                     |          |
@@ -314,7 +317,7 @@ class TrackRecordTest {

    @Test
    fun testParallelReadHeads() {
        interpretTestSpec(useReadHeads = true, spec = """
        TRTInterpreter.interpretTestSpec(useReadHeads = true, spec = """
            mark = 0   | mark = 0   | mark = 0   | mark = 0
            add(2)     |            |            |
                       | add(4)     |            |
@@ -330,7 +333,7 @@ class TrackRecordTest {

    @Test
    fun testPeek() {
        interpretTestSpec(useReadHeads = true, spec = """
        TRTInterpreter.interpretTestSpec(useReadHeads = true, spec = """
            add(2)     |            |               |
                       | add(4)     |               |
                       |            | add(6)        |
@@ -345,29 +348,20 @@ class TrackRecordTest {
            peek() = 6 | peek() = 6 | peek() = null | mark = 3
        """)
    }
    /**
     * // TODO : don't submit without this.
     * Test poll()
     *   - Check that it immediately finds added stuff that matches
     * Test ReadHead#poll()
     *   - All of the above, and:
     *   - Put stuff, check that it timeouts when it doesn't match the predicate, and the read head
     *     has advanced
     *   - Check that it immediately advances the read head
     *   - Check multiple read heads in different threads
     * Test ReadHead#peek()
     */
}

/**
 * A small interpreter for testing parallel code. The interpreter will read a list of lines
 * consisting of "|"-separated statements. Each column runs in a different concurrent thread
 * and all threads wait for each other in between lines. Each statement is split on ";" then
 * matched with regular expressions in the instructionTable constant, which contains the
 * code associated with each statement.
 *
 * The time unit is defined in milliseconds by the INTERPRET_TIME_UNIT constant. Whitespace is
 * ignored. Quick ref of supported expressions :
private object TRTInterpreter : ConcurrentIntepreter<TrackRecord<Int>>(interpretTable) {
    fun interpretTestSpec(spec: String, useReadHeads: Boolean) = if (useReadHeads) {
        interpretTestSpec(spec, ArrayTrackRecord(), { (it as ArrayTrackRecord).newReadHead() })
    } else {
        interpretTestSpec(spec, ArrayTrackRecord())
    }

    override fun getCallingMethod() = getCallingMethod(4)
}

/*
 * Quick ref of supported expressions :
 * sleep(x) : sleeps for x time units and returns Unit ; sleep alone means sleep(1)
 * add(x) : calls and returns TrackRecord#add.
 * poll(time, pos) [{ predicate }] : calls and returns TrackRecord#poll(x time units, pos).
@@ -380,160 +374,48 @@ class TrackRecordTest {
 *   y time units.
 * predicate must be one of "= x", "< x" or "> x".
 */
class SyntaxException(msg: String, cause: Throwable? = null) : RuntimeException(msg, cause)
class InterpretException(
    threadIndex: Int,
    lineNum: Int,
    className: String,
    methodName: String,
    fileName: String,
    cause: Throwable
) : RuntimeException(cause) {
    init {
        stackTrace = arrayOf(StackTraceElement(
                className,
                "$methodName:thread$threadIndex",
                fileName,
                lineNum)) + super.getStackTrace()
    }
}

// Some small helpers to avoid to say the large ".groupValues[index].trim()" every time
private fun MatchResult.strArg(index: Int) = this.groupValues[index].trim()
private fun MatchResult.intArg(index: Int) = strArg(index).toInt()
private fun MatchResult.timeArg(index: Int) = INTERPRET_TIME_UNIT * intArg(index)

// Parses a { = x } or { < x } or { > x } string and returns the corresponding predicate
// Returns an always-true predicate for empty and null arguments
private fun makePredicate(spec: String?): (Int) -> Boolean {
    if (spec.isNullOrEmpty()) return { true }
    val match = Regex("""\{\s*([<>=])\s*(\d+)\s*\}""").matchEntire(spec)
    if (null == match) throw SyntaxException("Predicate \"${spec}\"")
    val arg = match.intArg(2)
    return when (match.strArg(1)) {
        ">" -> { i -> i > arg }
        "<" -> { i -> i < arg }
        "=" -> { i -> i == arg }
        else -> throw RuntimeException("How did \"${spec}\" match this regexp ?")
    }
}

const val DEBUG_INTERPRETER = true

// The table contains pairs associating a regexp with the code to run. The statement is matched
// against each matcher in sequence and when a match is found the associated code is run, passing
// it the TrackRecord under test and the result of the regexp match.
typealias InterpretMatcher = Pair<Regex, (TrackRecord<Int>, MatchResult) -> Any?>

val interpretTable = listOf<InterpretMatcher>(
    // Interpret an empty line as doing nothing.
    Regex("") to { _, _ -> null },
    Regex("(.*)//.*") to { t, r -> interpret(r.strArg(1), t) },
    // Interpret "XXX time x..y" : run XXX and check it took at least x and not more than y
    Regex("""(.*)\s*time\s*(\d+)\.\.(\d+)""") to { t, r ->
        assertTrue(measureTimeMillis { interpret(r.strArg(1), t) } in r.timeArg(2)..r.timeArg(3))
    },
    // Interpret "XXX = YYY" : run XXX and assert its return value is equal to YYY. "null" supported
    Regex("""(.*)\s*=\s*(null|\d+)""") to { t, r ->
        interpret(r.strArg(1), t).also {
            if ("null" == r.strArg(2)) assertNull(it) else assertEquals(r.intArg(2), it)
        }
    },
private val interpretTable = listOf<InterpretMatcher<TrackRecord<Int>>>(
    // Interpret "XXX is odd" : run XXX and assert its return value is odd ("even" works too)
    Regex("(.*)\\s+is\\s+(even|odd)") to { t, r ->
        interpret(r.strArg(1), t).also {
    Regex("(.*)\\s+is\\s+(even|odd)") to { i, t, r ->
        i.interpret(r.strArg(1), t).also {
            assertEquals((it as Int) % 2, if ("even" == r.strArg(2)) 0 else 1)
        }
    },
    // Interpret sleep. Optional argument for the count, in INTERPRET_TIME_UNIT units.
    Regex("""sleep(\((\d+)\))?""") to { t, r ->
        SystemClock.sleep(if (r.strArg(2).isEmpty()) INTERPRET_TIME_UNIT else r.timeArg(2))
    },
    // Interpret "add(XXX)" as TrackRecord#add(int)
    Regex("""add\((\d+)\)""") to { t, r ->
    Regex("""add\((\d+)\)""") to { i, t, r ->
        t.add(r.intArg(1))
    },
    // Interpret "poll(x, y)" as TrackRecord#poll(timeout = x * INTERPRET_TIME_UNIT, pos = y)
    // Accepts an optional {} argument for the predicate (see makePredicate for syntax)
    Regex("""poll\((\d+),\s*(\d+)\)\s*(\{.*\})?""") to { t, r ->
    Regex("""poll\((\d+),\s*(\d+)\)\s*(\{.*\})?""") to { i, t, r ->
        t.poll(r.timeArg(1), r.intArg(2), makePredicate(r.strArg(3)))
    },
    // ReadHead#poll. If this throws in the cast, the code is malformed and has passed "poll()"
    // in a test that takes a TrackRecord that is not a ReadHead. It's technically possible to get
    // the test code to not compile instead of throw, but it's vastly more complex and this will
    // fail 100% at runtime any test that would not have compiled.
    Regex("""poll\((\d+)?\)\s*(\{.*\})?""") to { t, r ->
        (if (r.strArg(1).isEmpty()) INTERPRET_TIME_UNIT else r.timeArg(1)).let { time ->
    Regex("""poll\((\d+)?\)\s*(\{.*\})?""") to { i, t, r ->
        (if (r.strArg(1).isEmpty()) i.interpretTimeUnit else r.timeArg(1)).let { time ->
            (t as ArrayTrackRecord<Int>.ReadHead).poll(time, makePredicate(r.strArg(2)))
        }
    },
    // ReadHead#mark. The same remarks apply as with ReadHead#poll.
    Regex("mark") to { t, _ -> (t as ArrayTrackRecord<Int>.ReadHead).mark },
    Regex("mark") to { i, t, _ -> (t as ArrayTrackRecord<Int>.ReadHead).mark },
    // ReadHead#peek. The same remarks apply as with ReadHead#poll.
    Regex("peek\\(\\)") to { t, _ -> (t as ArrayTrackRecord<Int>.ReadHead).peek() }
    Regex("peek\\(\\)") to { i, t, _ -> (t as ArrayTrackRecord<Int>.ReadHead).peek() }
)

// Split the line into multiple statements separated by ";" and execute them. Return whatever
// the last statement returned.
private fun <T : TrackRecord<Int>> interpretMultiple(instruction: String, r: T): Any? {
    return instruction.split(";").map { interpret(it.trim(), r) }.last()
}
// Match the statement to a regex and interpret it.
private fun <T : TrackRecord<Int>> interpret(instr: String, r: T): Any? {
    val (matcher, code) =
            interpretTable.find { instr matches it.first } ?: throw SyntaxException(instr)
    val match = matcher.matchEntire(instr) ?: throw SyntaxException(instr)
    return code(r, match)
}

// Create the ArrayTrackRecord<Int> under test, then spins as many threads as needed by the test
// spec and interpret each program concurrently, having all threads waiting on a CyclicBarrier
// after each line. If |useReadHeads| is true, it will create a ReadHead over the ArrayTrackRecord
// in each thread and call the interpreted methods on that ; if it's false, it will call the
// interpreted methods on the ArrayTrackRecord directly. Be careful that some instructions may
// only be supported on ReadHead, and will throw if called when using useReadHeads = false.
private fun interpretTestSpec(useReadHeads: Boolean, spec: String) {
    // For nice stack traces
    val callSite = getCallingMethod()
    val lines = spec.trim().trim('\n').split("\n").map { it.split("|") }
    // |threads| contains arrays of strings that make up the statements of a thread : in other
    // words, it's an array that contains a list of statements for each column in the spec.
    val threadCount = lines[0].size
    assertTrue(lines.all { it.size == threadCount })
    val threadInstructions = (0 until threadCount).map { i -> lines.map { it[i].trim() } }
    val barrier = CyclicBarrier(threadCount)
    val rec = ArrayTrackRecord<Int>()
    var crash: InterpretException? = null
    threadInstructions.mapIndexed { threadIndex, instructions ->
        Thread {
            val rh = if (useReadHeads) rec.newReadHead() else rec
            barrier.await()
            var lineNum = 0
            instructions.forEach {
                if (null != crash) return@Thread
                lineNum += 1
                try {
                    interpretMultiple(it, rh)
                } catch (e: Throwable) {
                    // If fail() or some exception was called, the thread will come here ; if the
                    // exception isn't caught the process will crash, which is not nice for testing.
                    // Instead, catch the exception, cancel other threads, and report nicely.
                    // Catch throwable because fail() is AssertionError, which inherits from Error.
                    crash = InterpretException(threadIndex, callSite.lineNumber + lineNum,
                            callSite.className, callSite.methodName, callSite.fileName, e)
                }
                barrier.await()
            }
        }.also { it.start() }
    }.forEach { it.join() }
    // If the test failed, crash with line number
    crash?.let { throw it }
}

private fun getCallingMethod(): StackTraceElement {
    try {
        throw RuntimeException()
    } catch (e: RuntimeException) {
        return e.stackTrace[3] // 0 is this method here, 1 is interpretTestSpec, 2 the lambda
// Parses a { = x } or { < x } or { > x } string and returns the corresponding predicate
// Returns an always-true predicate for empty and null arguments
private fun makePredicate(spec: String?): (Int) -> Boolean {
    if (spec.isNullOrEmpty()) return { true }
    val match = Regex("""\{\s*([<>=])\s*(\d+)\s*\}""").matchEntire(spec)
            ?: throw SyntaxException("Predicate \"${spec}\"")
    val arg = match.intArg(2)
    return when (match.strArg(1)) {
        ">" -> { i -> i > arg }
        "<" -> { i -> i < arg }
        "=" -> { i -> i == arg }
        else -> throw RuntimeException("How did \"${spec}\" match this regexp ?")
    }
}