From c55c72e98b82100d9c2a86f7d5bca6f0f3366543 Mon Sep 17 00:00:00 2001 From: Sotetsu KOYAMADA Date: Mon, 8 Jan 2024 19:36:03 +0900 Subject: [PATCH] extract game state --- pgx/_src/dwg/connect_four.py | 2 +- pgx/connect_four.py | 44 ++++++++++++++++++++---------------- tests/test_connect_four.py | 2 +- 3 files changed, 27 insertions(+), 21 deletions(-) diff --git a/pgx/_src/dwg/connect_four.py b/pgx/_src/dwg/connect_four.py index 86af72cde..40f293b58 100644 --- a/pgx/_src/dwg/connect_four.py +++ b/pgx/_src/dwg/connect_four.py @@ -77,7 +77,7 @@ def _make_connect_four_dwg(dwg, state: ConnectFourState, config): ) # stones - board = state._board + board = state._x._board for xy, stone in enumerate(board): if stone == -1: continue diff --git a/pgx/connect_four.py b/pgx/connect_four.py index 498845edf..244cb4eb5 100644 --- a/pgx/connect_four.py +++ b/pgx/connect_four.py @@ -24,15 +24,7 @@ @dataclass -class State(core.State): - current_player: Array = jnp.int32(0) - observation: Array = jnp.zeros((6, 7, 2), dtype=jnp.bool_) - rewards: Array = jnp.float32([0.0, 0.0]) - terminated: Array = FALSE - truncated: Array = FALSE - legal_action_mask: Array = jnp.ones(7, dtype=jnp.bool_) - _step_count: Array = jnp.int32(0) - # --- Connect Four specific --- +class GameState: _turn: Array = jnp.int32(0) # 6x7 board # [[ 0, 1, 2, 3, 4, 5, 6], @@ -44,6 +36,18 @@ class State(core.State): _board: Array = -jnp.ones(42, jnp.int32) # -1 (empty), 0, 1 _blank_row: Array = jnp.full(7, 5) + +@dataclass +class State(core.State): + current_player: Array = jnp.int32(0) + observation: Array = jnp.zeros((6, 7, 2), dtype=jnp.bool_) + rewards: Array = jnp.float32([0.0, 0.0]) + terminated: Array = FALSE + truncated: Array = FALSE + legal_action_mask: Array = jnp.ones(7, dtype=jnp.bool_) + _step_count: Array = jnp.int32(0) + _x: GameState = GameState() + @property def env_id(self) -> core.EnvId: return "connect_four" @@ -112,11 +116,11 @@ def _init(rng: PRNGKey) -> State: def _step(state: State, action: Array) -> State: - board = state._board - row = state._blank_row[action] - blank_row = state._blank_row.at[action].set(row - 1) - board = board.at[_to_idx(row, action)].set(state._turn) - won = _win_check(board, state._turn) + board = state._x._board + row = state._x._blank_row[action] + blank_row = state._x._blank_row.at[action].set(row - 1) + board = board.at[_to_idx(row, action)].set(state._x._turn) + won = _win_check(board, state._x._turn) reward = jax.lax.cond( won, lambda: jnp.float32([-1, -1]).at[state.current_player].set(1), @@ -125,11 +129,13 @@ def _step(state: State, action: Array) -> State: return state.replace( # type: ignore current_player=1 - state.current_player, legal_action_mask=blank_row >= 0, - _turn=1 - state._turn, - _board=board, - _blank_row=blank_row, terminated=won | jnp.all(blank_row == -1), rewards=reward, + _x=state._x.replace( # type: ignore + _turn=1 - state._x._turn, + _board=board, + _blank_row=blank_row, + ), ) @@ -142,7 +148,7 @@ def _win_check(board, turn) -> Array: def _observe(state: State, player_id: Array) -> Array: - turns = jnp.int32([state._turn, 1 - state._turn]) + turns = jnp.int32([state._x._turn, 1 - state._x._turn]) turns = jax.lax.cond( player_id == state.current_player, lambda: turns, @@ -150,6 +156,6 @@ def _observe(state: State, player_id: Array) -> Array: ) def make(turn): - return state._board.reshape(6, 7) == turn + return state._x._board.reshape(6, 7) == turn return jnp.stack(jax.vmap(make)(turns), -1) diff --git a/tests/test_connect_four.py b/tests/test_connect_four.py index 642d81e12..23075340b 100644 --- a/tests/test_connect_four.py +++ b/tests/test_connect_four.py @@ -31,7 +31,7 @@ def test_step(): @@..... """ # fmt: off - assert (state._board == jnp.array( + assert (state._x._board == jnp.array( [1, 1, -1, -1, -1, -1, -1, 0, 0, -1, -1, -1, -1, -1, 1, 1, -1, -1, -1, -1, -1,