Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit c03e3f5a authored by Issei Suzuki's avatar Issei Suzuki
Browse files

Hierarchical state machine.

Test: atest com.android.server.wm.utils.StateMachineTest
Bug: 242545520

Change-Id: Ib3d7d744ff034026e8b117b6bad0b1e0922ec1be
parent c259e3a7
Loading
Loading
Loading
Loading
+280 −0
Original line number Diff line number Diff line
/*
 * Copyright (C) 2022 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.android.server.wm.utils;

import android.annotation.IntRange;
import android.annotation.Nullable;
import android.util.IntArray;
import android.util.Slog;
import android.util.SparseArray;

import com.android.internal.util.AnnotationValidations;

import java.util.ArrayDeque;
import java.util.Queue;

/**
 * Simple hierarchical state machine.
 *
 * The state is represented by an integer value. The root state has a value {@code 0x0}, and top
 * level state has a value in range {@code 0x1} to {@code 0xF}. To indicate a state B is a sub state
 * of a state A, assign an integer state_value(B) = state_value(A) << 4 + (0x0 .. 0xF).
 */
public class StateMachine {
    private static final String TAG = "StateMachine";

    /**
     * Interface for implementing state specific actions.
     */
    public interface Handler {
        /**
         * Called when state machine changes its state to this state.
         */
        default void enter() {}

        /**
         * Called when state machine changes its state from this state to other state.
         */
        default void exit() {}

        /**
         * @param event type of this event.
         * @param param parameter passed to {@link StateMachine#handle(int, Object)}
         * @return {@code true} if the event was handled in this handler, so we don't need to
         *          check the parent state. Otherwise, handle() of the parent state is triggered.
         */
        default boolean handle(int event, @Nullable Object param) {
            return false;
        }
    }

    /**
     * The most recent state requested by transit() call.
     *
     * @note When transit() is called recursively, this might not be same value as mState until
     *       transit() finishes.
     */
    private int mLastRequestedState;

    /**
     * The current state of this state machine.
     */
    private int mState;

    private final IntArray mTmp = new IntArray();
    private final SparseArray<Handler> mStateHandlers = new SparseArray<>();

    /**
     * Actions which need to execute to finish requested transition.
     */
    private final Queue<Command> mCommands = new ArrayDeque<>();

    protected static class Command {
        static final int COMMIT = 1;
        static final int ENTER = 2;
        static final int EXIT = 3;

        final int mType;
        final int mState;

        private Command(int type, @IntRange(from = 0) int state) {
            mType = type;
            AnnotationValidations.validate(IntRange.class, null, state, "from", 0);
            mState = state;
        }

        static Command newCommit(int state) {
            return new Command(COMMIT, state);
        }

        static Command newEnter(int state) {
            return new Command(ENTER, state);
        }

        static Command newExit(int state) {
            return new Command(EXIT, state);
        }

        @Override
        public String toString() {
            StringBuilder sb = new StringBuilder();
            sb.append("Command{ type: ");
            switch (mType) {
                case COMMIT:
                    sb.append("commit");
                    break;
                case ENTER:
                    sb.append("enter");
                    break;
                case EXIT:
                    sb.append("exit");
                    break;
                default:
                    sb.append("UNKNOWN(");
                    sb.append(mType);
                    sb.append(")");
                    break;
            }
            sb.append(" state: ");
            sb.append(Integer.toHexString(mState));
            sb.append(" }");
            return sb.toString();
        }
    }

    public StateMachine() {
        this(0);
    }

    public StateMachine(@IntRange(from = 0) int initialState) {
        mState = initialState;
        AnnotationValidations.validate(IntRange.class, null, initialState, "from", 0);
        mLastRequestedState = initialState;
    }

    /**
     * @see #mLastRequestedState
     */
    public int getState() {
        return mLastRequestedState;
    }

    protected int getCurrentState() {
        return mState;
    }

    protected Command[] getCommands() {
        final Command[] commands = new Command[mCommands.size()];
        mCommands.toArray(commands);
        return commands;
    }

    /**
     * Add a handler for a specific state.
     *
     * @param state State which the given handler processes.
     * @param handler A handler which runs entry, exit actions and processes events.
     * @return Previous state handler if it's already registered, or {@code null}.
     */
    @Nullable public Handler addStateHandler(int state, @Nullable Handler handler) {
        final Handler handlerOld = mStateHandlers.get(state);
        mStateHandlers.put(state, handler);
        return handlerOld;
    }

    /**
     * Process an event. Search handler for a given event and {@link Handler#handle(int)}. If the
     * handler cannot handle the event, delegate it to a handler for a parent of the given state.
     *
     * @param event Type of an event.
     */
    public void handle(int event, @Nullable Object param) {
        int state = mState;
        while (state != 0) {
            final Handler h = mStateHandlers.get(state);
            if (h != null && h.handle(event, param)) {
                return;
            }
            state >>= 4;
        }
    }

    protected void enter(@IntRange(from = 0) int state) {
        AnnotationValidations.validate(IntRange.class, null, state, "from", 0);
        final Handler h = mStateHandlers.get(state);
        if (h != null) {
            h.enter();
        }
    }

    protected void exit(@IntRange(from = 0) int state) {
        AnnotationValidations.validate(IntRange.class, null, state, "from", 0);
        final Handler h = mStateHandlers.get(state);
        if (h != null) {
            h.exit();
        }
    }

    /**
     * @return {@code true} if a given sub state is a descendant of a given super state.
     */
    public static boolean isIn(int subState, int superState) {
        while (subState > superState) {
            subState >>= 4;
        }
        return subState == superState;
    }

    /**
     * Check if the last requested state is a sub state of a given state.
     *
     * @return {@code true} if the last requested state (via {@link #transit(int)}) is a sub state
     *         of a given state.
     */
    public boolean isIn(int state) {
        return isIn(mLastRequestedState, state);
    }

    /**
     * Change state to the requested state.
     *
     * @param newState The new state that the state machine should be changed.
     */
    public void transit(@IntRange(from = 0) int newState) {
        AnnotationValidations.validate(IntRange.class, null, newState, "from", 0);

        // entry and exit action might start another transition, so this transit() function can be
        // called recursively. In order to guarantee entry and exit actions in expected order,
        // we first compute the sequence and push them into a queue, then process them later.
        mCommands.add(Command.newCommit(newState));
        if (mLastRequestedState == newState) {
            mCommands.add(Command.newExit(newState));
            mCommands.add(Command.newEnter(newState));
        } else {
            // mLastRequestedState to least common ancestor
            for (int s = mLastRequestedState; !isIn(newState, s); s >>= 4) {
                mCommands.add(Command.newExit(s));
            }
            // least common ancestor to newState
            mTmp.clear();
            for (int s = newState; !isIn(mLastRequestedState, s); s >>= 4) {
                mTmp.add(s);
            }
            for (int i = mTmp.size() - 1; i >= 0; --i) {
                mCommands.add(Command.newEnter(mTmp.get(i)));
            }
        }
        mLastRequestedState = newState;
        while (!mCommands.isEmpty()) {
            final Command cmd = mCommands.remove();
            switch (cmd.mType) {
                case Command.EXIT:
                    exit(cmd.mState);
                    break;
                case Command.ENTER:
                    enter(cmd.mState);
                    break;
                case Command.COMMIT:
                    mState = cmd.mState;
                    break;
                default:
                    Slog.e(TAG, "Unknown command type: " + cmd.mType);
                    break;
            }
        }
    }
}
+18 −0
Original line number Diff line number Diff line
{
  "presubmit": [
    {
      "name": "WmTests",
      "options": [
        {
          "include-filter": "com.android.server.wm.utils"
        },
        {
          "include-annotation": "android.platform.test.annotations.Presubmit"
        },
        {
          "exclude-annotation": "androidx.test.filters.FlakyTest"
        }
      ]
    }
  ]
}
+237 −0
Original line number Diff line number Diff line
/*
 * Copyright (C) 2022 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.android.server.wm.utils;

import static com.android.server.wm.utils.StateMachine.isIn;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;

import android.platform.test.annotations.Presubmit;

import androidx.test.filters.SmallTest;

import org.junit.Test;

/**
 * Build/Install/Run:
 *  atest WmTests:StateMachineTest
 */
@SmallTest
@Presubmit
public class StateMachineTest {
    static class LoggingHandler implements StateMachine.Handler {
        final int mState;
        final StringBuffer mStringBuffer;
        // True if process #handle
        final boolean mHandleSelf;

        LoggingHandler(int state, StringBuffer sb, boolean handleSelf) {
            mHandleSelf = handleSelf;
            mState = state;
            mStringBuffer = sb;
        }

        LoggingHandler(int state, StringBuffer sb) {
            this(state, sb, true /* handleSelf */);
        }

        @Override
        public void enter() {
            mStringBuffer.append('i');
            mStringBuffer.append(Integer.toHexString(mState));
            mStringBuffer.append(';');
        }

        @Override
        public void exit() {
            mStringBuffer.append('o');
            mStringBuffer.append(Integer.toHexString(mState));
            mStringBuffer.append(';');
        }

        @Override
        public boolean handle(int event, Object param) {
            if (mHandleSelf) {
                mStringBuffer.append('h');
                mStringBuffer.append(Integer.toHexString(mState));
                mStringBuffer.append(';');
            }
            return mHandleSelf;
        }
    }

    static class LoggingHandlerTransferInExit extends LoggingHandler {
        final StateMachine mStateMachine;
        final int mStateToTransit;

        LoggingHandlerTransferInExit(int state, StringBuffer sb, StateMachine stateMachine,
                int stateToTransit) {
            super(state, sb);
            mStateMachine = stateMachine;
            mStateToTransit = stateToTransit;
        }

        @Override
        public void exit() {
            super.exit();
            mStateMachine.transit(mStateToTransit);
        }
    }

    @Test
    public void testStateMachineIsIn() {
        assertTrue(isIn(0x112, 0x1));
        assertTrue(isIn(0x112, 0x11));
        assertTrue(isIn(0x112, 0x112));

        assertFalse(isIn(0x1, 0x112));
        assertFalse(isIn(0x12, 0x2));
    }

    @Test
    public void testStateMachineInitialState() {
        StateMachine stateMachine = new StateMachine();
        assertEquals(0, stateMachine.getState());

        stateMachine = new StateMachine(0x23);
        assertEquals(0x23, stateMachine.getState());
    }

    @Test
    public void testStateMachineTransitToChild() {
        final StringBuffer log = new StringBuffer();

        StateMachine stateMachine = new StateMachine();
        stateMachine.addStateHandler(0x1, new LoggingHandler(0x1, log));
        stateMachine.addStateHandler(0x12, new LoggingHandler(0x12, log));
        stateMachine.addStateHandler(0x123, new LoggingHandler(0x123, log));
        stateMachine.addStateHandler(0x1233, new LoggingHandler(0x1233, log));

        // 0x0 -> 0x12
        stateMachine.transit(0x12);
        assertEquals("i1;i12;", log.toString());
        assertEquals(0x12, stateMachine.getState());

        // 0x12 -> 0x1233
        log.setLength(0);
        stateMachine.transit(0x1233);
        assertEquals(0x1233, stateMachine.getState());
        assertEquals("i123;i1233;", log.toString());
    }

    @Test
    public void testStateMachineTransitToParent() {
        final StringBuffer log = new StringBuffer();

        StateMachine stateMachine = new StateMachine(0x253);
        stateMachine.addStateHandler(0x2, new LoggingHandler(0x2, log));
        stateMachine.addStateHandler(0x25, new LoggingHandler(0x25, log));
        stateMachine.addStateHandler(0x253, new LoggingHandler(0x253, log));

        // 0x253 -> 0x2
        stateMachine.transit(0x2);
        assertEquals(0x2, stateMachine.getState());
        assertEquals("o253;o25;", log.toString());
    }

    @Test
    public void testStateMachineTransitSelf() {
        final StringBuffer log = new StringBuffer();

        StateMachine stateMachine = new StateMachine(0x253);
        stateMachine.addStateHandler(0x2, new LoggingHandler(0x2, log));
        stateMachine.addStateHandler(0x25, new LoggingHandler(0x25, log));
        stateMachine.addStateHandler(0x253, new LoggingHandler(0x253, log));

        // 0x253 -> 0x253
        stateMachine.transit(0x253);
        assertEquals(0x253, stateMachine.getState());
        assertEquals("o253;i253;", log.toString());
    }

    @Test
    public void testStateMachineTransitGeneral() {
        final StringBuffer log = new StringBuffer();

        StateMachine stateMachine = new StateMachine(0x1351);
        stateMachine.addStateHandler(0x1, new LoggingHandler(0x1, log));
        stateMachine.addStateHandler(0x13, new LoggingHandler(0x13, log));
        stateMachine.addStateHandler(0x132, new LoggingHandler(0x132, log));
        stateMachine.addStateHandler(0x1322, new LoggingHandler(0x1322, log));
        stateMachine.addStateHandler(0x1322, new LoggingHandler(0x1322, log));
        stateMachine.addStateHandler(0x135, new LoggingHandler(0x135, log));
        stateMachine.addStateHandler(0x1351, new LoggingHandler(0x1351, log));

        // 0x1351 -> 0x1322
        // least common ancestor = 0x13
        stateMachine.transit(0x1322);
        assertEquals(0x1322, stateMachine.getState());
        assertEquals("o1351;o135;i132;i1322;", log.toString());
    }

    @Test
    public void testStateMachineTriggerStateAction() {
        final StringBuffer log = new StringBuffer();

        StateMachine stateMachine = new StateMachine(0x253);
        stateMachine.addStateHandler(0x2, new LoggingHandler(0x2, log));
        stateMachine.addStateHandler(0x25, new LoggingHandler(0x25, log));
        stateMachine.addStateHandler(0x253, new LoggingHandler(0x253, log));

        // state 0x253 handles the message itself
        stateMachine.handle(0, null);
        assertEquals("h253;", log.toString());
    }

    @Test
    public void testStateMachineTriggerStateActionDelegate() {
        final StringBuffer log = new StringBuffer();

        StateMachine stateMachine = new StateMachine(0x253);
        stateMachine.addStateHandler(0x2, new LoggingHandler(0x2, log));
        stateMachine.addStateHandler(0x25, new LoggingHandler(0x25, log));
        stateMachine.addStateHandler(0x253,
                new LoggingHandler(0x253, log, false /* handleSelf */));

        // state 0x253 delegate the message handling to its parent state
        stateMachine.handle(0, null);
        assertEquals("h25;", log.toString());
    }

    @Test
    public void testStateMachineNestedTransition() {
        final StringBuffer log = new StringBuffer();

        StateMachine stateMachine = new StateMachine(0x25);
        stateMachine.addStateHandler(0x1, new LoggingHandler(0x1, log));

        // Force transit to state 0x3 in exit()
        stateMachine.addStateHandler(0x2,
                new LoggingHandlerTransferInExit(0x2, log, stateMachine, 0x3));
        stateMachine.addStateHandler(0x25, new LoggingHandler(0x25, log));
        stateMachine.addStateHandler(0x3, new LoggingHandler(0x3, log));

        stateMachine.transit(0x1);
        // Start transit to 0x1
        //  0x25 -> 0x2 [transit(0x3) requested] -> 0x1
        //  0x1 -> 0x3
        // Immediately set the status to 0x1, no enter/exit
        assertEquals("o25;o2;i1;o1;i3;", log.toString());
    }
}