CLOVER🍀

That was when it all began.

Spring Statemachineのガードを試してみる

これは、なにをしたくて書いたもの?

Spring Statemachineの、ガードというものを試してみたいなということで。

ガード

ガードについては、用語集に以下のように記載されています。

A boolean expression evaluated dynamically based on the value of extended state variables and event parameters. Guard conditions affect the behavior of a state machine by enabling actions or transitions only when they evaluate to TRUE and disabling them when they evaluate to FALSE.

Spring Statemachine / Appendices / Appendix B: State Machine Concepts / Glossary

拡張ステート変数とイベントパラメーターに基づいて動的に評価される、boolean式のことがガードだそうです。

ガード条件は、アクションまたは遷移に対してTRUEと評価された場合のみ有効となり、FALSEと評価された場合は無効となることで
ステートマシンの振る舞いに影響を与えます。

クラッシュコースでも、似たような説明が書かれています。

Spring Statemachine / Appendices / Appendix B: State Machine Concepts / A State Machine Crash Course / Guard Conditions

説明からは、ステートマシンにガードを関連付け、アクションや遷移の発生時にガード式がTRUEを返すと状態遷移が有効になり、
FALSEを返すと無効化(=状態遷移しない)というように読める気がします。

ガードの実装方法や設定については、このあたりに記載されているのですが。

Spring Statemachine / Using Spring Statemachine / Statemachine Configuration / Configuring Guards

Spring Statemachine / Using Spring Statemachine / Using Guards

どうもガードそのものに関する説明は、用語集やクラッシュコース以上には書かれていなさそうです。

というわけで、実際に動かして試してみようかなと思います。

環境

今回の環境は、こちら。

$ java --version
openjdk 17.0.4 2022-07-19
OpenJDK Runtime Environment (build 17.0.4+8-Ubuntu-120.04)
OpenJDK 64-Bit Server VM (build 17.0.4+8-Ubuntu-120.04, mixed mode, sharing)


$ mvn --version
Apache Maven 3.8.6 (84538c9988a25aec085021c365c560670ad80f63)
Maven home: $HOME/.sdkman/candidates/maven/current
Java version: 17.0.4, vendor: Private Build, runtime: /usr/lib/jvm/java-17-openjdk-amd64
Default locale: ja_JP, platform encoding: UTF-8
OS name: "linux", version: "5.4.0-125-generic", arch: "amd64", family: "unix"

Spring Bootプロジェクトを作成する

では、まずはSpring Bootプロジェクトを作成します。

$ curl -s https://start.spring.io/starter.tgz \
  -d bootVersion=2.6.7 \
  -d javaVersion=17 \
  -d name=statemachine-guard \
  -d groupId=org.littlewings \
  -d artifactId=statemachine-guard \
  -d version=0.0.1-SNAPSHOT \
  -d packageName=org.littlewings.spring.statemachine \
  -d baseDir=statemachine-guard | tar zxvf -

Spring Bootが2.6.7なのは、ドキュメントに記載のバージョンと合わせているからです。

Spring Statemachine / Getting started / Using Maven

プロジェクト内に移動。

$ cd statemachine-guard

自動生成されたソースコードは、今回は削除しておきます。

$ rm src/main/java/org/littlewings/spring/statemachine/StatemachineGuardApplication.java src/test/java/org/littlewings/spring/statemachine/StatemachineGuardApplicationTests.java

Maven依存関係など。

        <properties>
                <java.version>17</java.version>
        </properties>
        <dependencies>
                <dependency>
                        <groupId>org.springframework.boot</groupId>
                        <artifactId>spring-boot-starter</artifactId>
                </dependency>

                <dependency>
                        <groupId>org.springframework.boot</groupId>
                        <artifactId>spring-boot-starter-test</artifactId>
                        <scope>test</scope>
                </dependency>
        </dependencies>

        <build>
                <plugins>
                        <plugin>
                                <groupId>org.springframework.boot</groupId>
                                <artifactId>spring-boot-maven-plugin</artifactId>
                        </plugin>
                </plugins>
        </build>

このうち、spring-boot-starterspring-statemachine-starterに変更します。

 <dependencies>
        <dependency>
            <groupId>org.springframework.statemachine</groupId>
            <artifactId>spring-statemachine-starter</artifactId>
            <version>3.2.0</version>
        </dependency>

        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-test</artifactId>
            <scope>test</scope>
        </dependency>
    </dependencies>

続いて、ソースコードを作成していきます。

ステートを定義したenum

src/main/java/org/littlewings/spring/statemachine/States.java

package org.littlewings.spring.statemachine;

public enum States {
    INITIAL_STATE,
    STATE1,
    STATE2,
    END_STATE
}

イベントを定義したenum

src/main/java/org/littlewings/spring/statemachine/Events.java

package org.littlewings.spring.statemachine;

public enum Events {
    EVENT1,
    EVENT2,
    EVENT3
}

ステートマシンの定義。こちらの詳細は、また後で説明します。

src/main/java/org/littlewings/spring/statemachine/StateMachineConfig.java

package org.littlewings.spring.statemachine;

import java.time.LocalDateTime;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;

import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.statemachine.StateContext;
import org.springframework.statemachine.action.Action;
import org.springframework.statemachine.config.EnableStateMachine;
import org.springframework.statemachine.config.EnumStateMachineConfigurerAdapter;
import org.springframework.statemachine.config.builders.StateMachineConfigurationConfigurer;
import org.springframework.statemachine.config.builders.StateMachineStateConfigurer;
import org.springframework.statemachine.config.builders.StateMachineTransitionConfigurer;
import org.springframework.statemachine.guard.Guard;

@Configuration
@EnableStateMachine
public class StateMachineConfig extends EnumStateMachineConfigurerAdapter<States, Events> {
    @Override
    public void configure(StateMachineConfigurationConfigurer<States, Events> config)
            throws Exception {
        config
                .withConfiguration()
                .autoStartup(true)
                .machineId("my-statemachine");
    }

    @Override
    public void configure(StateMachineStateConfigurer<States, Events> states)
            throws Exception {
        states
                .withStates()
                .initial(States.INITIAL_STATE)
                .state(States.STATE1)
                .state(States.STATE2)
                .end(States.END_STATE);
    }

    @Override
    public void configure(StateMachineTransitionConfigurer<States, Events> transitions)
            throws Exception {
        transitions
                .withExternal()
                .source(States.INITIAL_STATE).target(States.STATE1)
                .event(Events.EVENT1)
                .guard(guard1())
                .action(loggingAction())
                .and()
                .withExternal()
                .source(States.STATE1).target(States.STATE2)
                .event(Events.EVENT2)
                .guard(guard2())
                .action(loggingAction())
                .and()
                .withExternal()
                .source(States.STATE2).target(States.END_STATE)
                .timer(TimeUnit.SECONDS.toMillis(2L))
                .guard(guard3())
                .action(loggingAction());
    }

    @Bean
    public Guard<States, Events> guard1() {
        return new Guard<>() {
            AtomicInteger counter = new AtomicInteger(0);

            @Override
            public boolean evaluate(StateContext<States, Events> context) {
                boolean evaluated = counter.incrementAndGet() >= 2;

                System.out.printf(
                        "[%s] guard1, stage = %s, state = %s, trigger type = %s, event = %s, guard evaluated = %b%n",
                        LocalDateTime.now(),
                        context.getStage(),
                        context.getTarget().getId(),
                        context.getTransition().getTrigger() != null ? context.getTransition().getTrigger().getClass().getSimpleName() : "[no trigger]",
                        context.getMessage() != null ? context.getMessage().getPayload() : "[none]",
                        evaluated
                );

                return evaluated;
            }
        };
    }

    @Bean
    public Guard<States, Events> guard2() {
        AtomicInteger counter = new AtomicInteger(0);

        return context -> {
            boolean evaluated = counter.incrementAndGet() >= 2;

            System.out.printf(
                    "[%s] guard2, stage = %s, state = %s, trigger type = %s, event = %s, guard evaluated = %b%n",
                    LocalDateTime.now(),
                    context.getStage(),
                    context.getTarget().getId(),
                    context.getTransition().getTrigger() != null ? context.getTransition().getTrigger().getClass().getSimpleName() : "[no trigger]",
                    context.getMessage() != null ? context.getMessage().getPayload() : "[none]",
                    evaluated
            );

            return evaluated;
        };
    }

    @Bean
    public Guard<States, Events> guard3() {
        AtomicInteger counter = new AtomicInteger(0);

        return context -> {
            boolean evaluated = counter.incrementAndGet() >= 2;

            System.out.printf(
                    "[%s] guard3, stage = %s, state = %s, trigger type = %s, event = %s, guard evaluated = %b%n",
                    LocalDateTime.now(),
                    context.getStage(),
                    context.getTarget().getId(),
                    context.getTransition().getTrigger() != null ? context.getTransition().getTrigger().getClass().getSimpleName() : "[no trigger]",
                    context.getMessage() != null ? context.getMessage().getPayload() : "[none]",
                    evaluated
            );

            return evaluated;
        };
    }

    @Bean
    public Action<States, Events> loggingAction() {
        return stateContext ->
                System.out.printf(
                        "[%s] state action, stage = %s, state = %s, trigger type = %s, event = %s%n",
                        LocalDateTime.now(),
                        stateContext.getStage(),
                        stateContext.getTarget().getId(),
                        stateContext.getTransition().getTrigger() != null ? stateContext.getTransition().getTrigger().getClass().getSimpleName() : "[no trigger]",
                        stateContext.getMessage() != null ? stateContext.getMessage().getPayload() : "[none]"
                );
    }
}

ステートマシンを使うクラス。

src/main/java/org/littlewings/spring/statemachine/Runner.java

package org.littlewings.spring.statemachine;

import java.util.concurrent.TimeUnit;

import org.springframework.boot.ApplicationArguments;
import org.springframework.boot.ApplicationRunner;
import org.springframework.messaging.support.MessageBuilder;
import org.springframework.statemachine.StateMachine;
import org.springframework.stereotype.Component;
import reactor.core.publisher.Mono;

@Component
public class Runner implements ApplicationRunner {
    StateMachine<States, Events> stateMachine;

    public Runner(StateMachine<States, Events> stateMachine) {
        this.stateMachine = stateMachine;
    }

    @Override
    public void run(ApplicationArguments args) throws Exception {
        stateMachine
                .sendEvent(Mono.just(MessageBuilder.withPayload(Events.EVENT1).build()))
                .blockFirst();

        TimeUnit.MILLISECONDS.sleep(500L);

        stateMachine
                .sendEvent(Mono.just(MessageBuilder.withPayload(Events.EVENT1).build()))
                .blockFirst();

        TimeUnit.MILLISECONDS.sleep(500L);

        stateMachine
                .sendEvent(Mono.just(MessageBuilder.withPayload(Events.EVENT2).build()))
                .blockFirst();

        TimeUnit.MILLISECONDS.sleep(500L);

        stateMachine
                .sendEvent(Mono.just(MessageBuilder.withPayload(Events.EVENT2).build()))
                .blockFirst();

        TimeUnit.SECONDS.sleep(5L);

        System.out.printf("StateMachine complete? %b%n", stateMachine.isComplete());
    }
}

同じイベントを2回ずつ、スリープしながら送り込みます。

最後に、ステートマシンが完了したかどうかを出力して完了ですね。

mainクラス。

src/main/java/org/littlewings/spring/statemachine/App.java

package org.littlewings.spring.statemachine;

import org.springframework.boot.SpringApplication;
import org.springframework.boot.autoconfigure.SpringBootApplication;

@SpringBootApplication
public class App {
    public static void main(String... args) {
        SpringApplication.run(App.class, args);
    }
}
ガードを作成する

ステートマシンの定義をしているクラスで、ガードに関する部分をもう少し掘り下げて見てみましょう。

参照しているドキュメントは、以下あたりです。

Spring Statemachine / Using Spring Statemachine / Statemachine Configuration / Configuring Guards

Spring Statemachine / Using Spring Statemachine / Using Guards

今回、3つの遷移を定義していますが、それぞれにガードを紐づけています。

    @Override
    public void configure(StateMachineTransitionConfigurer<States, Events> transitions)
            throws Exception {
        transitions
                .withExternal()
                .source(States.INITIAL_STATE).target(States.STATE1)
                .event(Events.EVENT1)
                .guard(guard1())
                .action(loggingAction())
                .and()
                .withExternal()
                .source(States.STATE1).target(States.STATE2)
                .event(Events.EVENT2)
                .guard(guard2())
                .action(loggingAction())
                .and()
                .withExternal()
                .source(States.STATE2).target(States.END_STATE)
                .timer(TimeUnit.SECONDS.toMillis(2L))
                .guard(guard3())
                .action(loggingAction());
    }

ガードは、Guardインターフェースを実装して作成します。

Guard (Spring State Machine 3.2.0 API)

ひとつ目の遷移に紐づけるガードは、こんな感じで作成。

    @Bean
    public Guard<States, Events> guard1() {
        return new Guard<>() {
            AtomicInteger counter = new AtomicInteger(0);

            @Override
            public boolean evaluate(StateContext<States, Events> context) {
                boolean evaluated = counter.incrementAndGet() >= 2;

                System.out.printf(
                        "[%s] guard1, stage = %s, state = %s, trigger type = %s, event = %s, guard evaluated = %b%n",
                        LocalDateTime.now(),
                        context.getStage(),
                        context.getTarget().getId(),
                        context.getTransition().getTrigger() != null ? context.getTransition().getTrigger().getClass().getSimpleName() : "[no trigger]",
                        context.getMessage() != null ? context.getMessage().getPayload() : "[none]",
                        evaluated
                );

                return evaluated;
            }
        };
    }

ガードは、Guard#evaluatetrueを返すと遷移が有効になり、falseを返すと無効になります。今回は、ガードが2回以上呼び出されると
trueを返すように作成しました。

こちらを遷移の定義に紐づけます。

                .withExternal()
                .source(States.INITIAL_STATE).target(States.STATE1)
                .event(Events.EVENT1)
                .guard(guard1())
                .action(loggingAction())

2つ目、3つ目は、ガードをLambda式で作成したくらいで、やっていることは同じです。

    @Bean
    public Guard<States, Events> guard2() {
        AtomicInteger counter = new AtomicInteger(0);

        return context -> {
            boolean evaluated = counter.incrementAndGet() >= 2;

            System.out.printf(
                    "[%s] guard2, stage = %s, state = %s, trigger type = %s, event = %s, guard evaluated = %b%n",
                    LocalDateTime.now(),
                    context.getStage(),
                    context.getTarget().getId(),
                    context.getTransition().getTrigger() != null ? context.getTransition().getTrigger().getClass().getSimpleName() : "[no trigger]",
                    context.getMessage() != null ? context.getMessage().getPayload() : "[none]",
                    evaluated
            );

            return evaluated;
        };
    }

    @Bean
    public Guard<States, Events> guard3() {
        AtomicInteger counter = new AtomicInteger(0);

        return context -> {
            boolean evaluated = counter.incrementAndGet() >= 2;

            System.out.printf(
                    "[%s] guard3, stage = %s, state = %s, trigger type = %s, event = %s, guard evaluated = %b%n",
                    LocalDateTime.now(),
                    context.getStage(),
                    context.getTarget().getId(),
                    context.getTransition().getTrigger() != null ? context.getTransition().getTrigger().getClass().getSimpleName() : "[no trigger]",
                    context.getMessage() != null ? context.getMessage().getPayload() : "[none]",
                    evaluated
            );

            return evaluated;
        };
    }

どの遷移に紐づけたガードなのかは、出力するメッセージでわかるようにしています。

最終的に、遷移の定義はこんな感じになりました、と。

        transitions
                .withExternal()
                .source(States.INITIAL_STATE).target(States.STATE1)
                .event(Events.EVENT1)
                .guard(guard1())
                .action(loggingAction())
                .and()
                .withExternal()
                .source(States.STATE1).target(States.STATE2)
                .event(Events.EVENT2)
                .guard(guard2())
                .action(loggingAction())
                .and()
                .withExternal()
                .source(States.STATE2).target(States.END_STATE)
                .timer(TimeUnit.SECONDS.toMillis(2L))
                .guard(guard3())
                .action(loggingAction());

なお、最初の2つは遷移のトリガーがイベントになっていますが、最後のひとつはタイマーとしています。どちらの種類のトリガーに対しても
ガードが効くことを確認します。

あと、ログ出力用にアクションもつけています。

    @Bean
    public Action<States, Events> loggingAction() {
        return stateContext ->
                System.out.printf(
                        "[%s] state action, stage = %s, state = %s, trigger type = %s, event = %s%n",
                        LocalDateTime.now(),
                        stateContext.getStage(),
                        stateContext.getTarget().getId(),
                        stateContext.getTransition().getTrigger() != null ? stateContext.getTransition().getTrigger().getClass().getSimpleName() : "[no trigger]",
                        stateContext.getMessage() != null ? stateContext.getMessage().getPayload() : "[none]"
                );
    }

動かしてみる

では、動かしてみましょう。

$ mvn spring-boot:run

結果。

2022-09-24 16:42:22.942  INFO 22500 --- [           main] org.littlewings.spring.statemachine.App  : Started App in 2.606 seconds (JVM running for 3.389)
[2022-09-24T16:42:22.959521198] guard1, stage = TRANSITION, state = STATE1, trigger type = EventTrigger, event = EVENT1, guard evaluated = false
[2022-09-24T16:42:23.479315697] guard1, stage = TRANSITION, state = STATE1, trigger type = EventTrigger, event = EVENT1, guard evaluated = true
[2022-09-24T16:42:23.483381085] state action, stage = TRANSITION, state = STATE1, trigger type = EventTrigger, event = EVENT1
[2022-09-24T16:42:23.990134943] guard2, stage = TRANSITION, state = STATE2, trigger type = EventTrigger, event = EVENT2, guard evaluated = false
[2022-09-24T16:42:24.492337074] guard2, stage = TRANSITION, state = STATE2, trigger type = EventTrigger, event = EVENT2, guard evaluated = true
[2022-09-24T16:42:24.492937108] state action, stage = TRANSITION, state = STATE2, trigger type = EventTrigger, event = EVENT2
[2022-09-24T16:42:24.670700080] guard3, stage = TRANSITION, state = END_STATE, trigger type = TimerTrigger, event = [none], guard evaluated = false
[2022-09-24T16:42:26.670818707] guard3, stage = TRANSITION, state = END_STATE, trigger type = TimerTrigger, event = [none], guard evaluated = true
[2022-09-24T16:42:26.671592307] state action, stage = TRANSITION, state = END_STATE, trigger type = TimerTrigger, event = [none]

イベント、タイマーいずれのトリガーも1回目の呼び出し時はガードにfalseと判定され遷移が進まず(アクションの呼び出されない)、
2回目の呼び出し時にtrueと判定され遷移が進んでいます。

イベント用のトリガーは500ミリ秒ずつ間隔を空けて実行するようにして

    @Override
    public void run(ApplicationArguments args) throws Exception {
        stateMachine
                .sendEvent(Mono.just(MessageBuilder.withPayload(Events.EVENT1).build()))
                .blockFirst();

        TimeUnit.MILLISECONDS.sleep(500L);

        stateMachine
                .sendEvent(Mono.just(MessageBuilder.withPayload(Events.EVENT1).build()))
                .blockFirst();

        TimeUnit.MILLISECONDS.sleep(500L);

        stateMachine
                .sendEvent(Mono.just(MessageBuilder.withPayload(Events.EVENT2).build()))
                .blockFirst();

        TimeUnit.MILLISECONDS.sleep(500L);

        stateMachine
                .sendEvent(Mono.just(MessageBuilder.withPayload(Events.EVENT2).build()))
                .blockFirst();

        TimeUnit.SECONDS.sleep(5L);

        System.out.printf("StateMachine complete? %b%n", stateMachine.isComplete());
    }

タイマー用のトリガーは2秒おきに実行するようにしているので、これらをログの出力時刻と突き合わせるとトリガーが動いたタイミングで
評価されていることがわかります。

                .withExternal()
                .source(States.STATE2).target(States.END_STATE)
                .timer(TimeUnit.SECONDS.toMillis(2L))
                .guard(guard3())
                .action(loggingAction());

簡単ですが、確認はこんな感じでしょう。

上手くいかなかったこと

実は、もうちょっと凝った感じでガードを試してみようと思っていたのですが、こちらはうまくいかなかったので。

最初はイベントトリガー+イベントトリガーの組み合わせで、

  • 最初はイベントでトリガーを起動させガードはfalseを返す
  • 時間が経過したらタイマーでトリガーを起動させ、ガードはtrueを返して遷移を進める

といったシナリオを考えていたのですが、こちらはうまくいかず。

        transitions
                .withExternal()
                .source(States.INITIAL_STATE).target(States.STATE1)
                .event(Events.EVENT1)
                .timer(TimeUnit.SECONDS.toMillis(2L))
                .guard(guard1())
                .action(loggingAction())

最初にガードが反応したトリガーのみしか評価されず、その後別の種類のトリガーが起動してもガードが評価されなかったからです。
どうも、最初に起動したトリガーにガードが紐づいている感じがします。ここは深追いしていませんけど。

よって、今回のサンプルも同じ種類のトリガーを繰り返し実行するものになりました。

あと、そもそもガードの評価結果をtruefalseに切り替えているサンプルがドキュメント上になく(いつもtruefalseを固定で返している)、
どういう使い方がよいのだろうと思ってテストコードを見てみたのですが。

こちらはCountDownLatchを使っていて、今回のサンプルと似たような感じになっていました。

https://github.com/spring-projects/spring-statemachine/blob/v3.2.0/spring-statemachine-core/src/test/java/org/springframework/statemachine/guard/GuardTests.java

   public static class TestGuard implements Guard<TestStates, TestEvents> {

        public CountDownLatch onEvaluateLatch = new CountDownLatch(1);
        boolean evaluationResult = true;

        public TestGuard() {
        }

        public TestGuard(boolean evaluationResult) {
            this.evaluationResult = evaluationResult;
        }

        @Override
        public boolean evaluate(StateContext<TestStates, TestEvents> context) {
            onEvaluateLatch.countDown();
            return evaluationResult;
        }
    }

https://github.com/spring-projects/spring-statemachine/blob/v3.2.0/spring-statemachine-core/src/test/java/org/springframework/statemachine/AbstractStateMachineTests.java#L178-L195

まとめ

Spring Statemachineのガードを試してみました。

ガードを使う自体は簡単だったのですが、複数のトリガーに紐づけて場合分けするみたいな考え方で使うのはちょっと違うみたいですね。
このあたりは、チョイスやジャンクションを使ったりするのかなと思うのですが、どうなんでしょう。

また見ていきましょうか。