diff --git a/gomoku_rl/env.py b/gomoku_rl/env.py index 0c5ce73..a79a5fa 100644 --- a/gomoku_rl/env.py +++ b/gomoku_rl/env.py @@ -27,7 +27,8 @@ def make_transition( ) -> TensorDict: # if a player wins at time t, its opponent cannot win immediately after reset reward: torch.Tensor = ( - tensordict_t.get("win").float() - tensordict_t_plus_1.get("win").float() + tensordict_t.get("win").float() - + tensordict_t_plus_1.get("win").float() ).unsqueeze(-1) transition: TensorDict = tensordict_t_minus_1.select( "observation", @@ -57,7 +58,8 @@ def __init__( board_size: int, device=None, ): - self.gomoku = Gomoku(num_envs=num_envs, board_size=board_size, device=device) + self.gomoku = Gomoku( + num_envs=num_envs, board_size=board_size, device=device) self.observation_spec = CompositeSpec( { @@ -202,7 +204,8 @@ def _round( return_white_transitions: bool = True, is_last: bool = False, ) -> tuple[TensorDict | None, TensorDict | None, TensorDict, TensorDict]: - tensordict_t_plus_1 = self._step_and_maybe_reset(tensordict=tensordict_t) + tensordict_t_plus_1 = self._step_and_maybe_reset( + tensordict=tensordict_t) with set_interaction_type(type=InteractionType.RANDOM): tensordict_t_plus_1 = player_white(tensordict_t_plus_1) @@ -249,7 +252,8 @@ def _round( ) transition_black.set( "invalid", - torch.zeros(self.num_envs, device=self.device, dtype=torch.bool), + torch.zeros(self.num_envs, device=self.device, + dtype=torch.bool), ) else: transition_black = None @@ -466,6 +470,10 @@ def rollout_fixed_opponent( out_device=None, augment: bool = False, ) -> tuple[TensorDict, dict[str, float]]: + if not hasattr(self, "_t") and not hasattr(self, "_t_minus_1"): + self._t = None + self._t_minus_1 = None + info: defaultdict[str, float] = defaultdict(float) tensordicts: list[TensorDict] = [] @@ -473,15 +481,19 @@ def rollout_fixed_opponent( start = time.perf_counter() info_buffer = defaultdict(float) self._post_step = get_log_func(info_buffer) + + _tds, self._t_minus_1, self._t = self._rollout_fixed_opponent( + rounds=rounds, + player_black=player, + player_white=opponent, + return_black_transitions=True, + out_device=out_device, + augment=augment, + tensordict_t_minus_1=self._t_minus_1, + tensordict_t=self._t, + ) tensordicts.extend( - self._rollout_fixed_opponent( - rounds=rounds, - player_black=player, - player_white=opponent, - return_black_transitions=True, - out_device=out_device, - augment=augment, - ) + _tds ) info.update( @@ -493,16 +505,21 @@ def rollout_fixed_opponent( info_buffer.clear() self._post_step = get_log_func(info_buffer) + _tds, self._t_minus_1, self._t = self._rollout_fixed_opponent( + rounds=rounds, + player_black=opponent, + player_white=player, + return_black_transitions=False, + out_device=out_device, + augment=augment, + tensordict_t_minus_1=self._t_minus_1, + tensordict_t=self._t, + ) + tensordicts.extend( - self._rollout_fixed_opponent( - rounds=rounds, - player_black=opponent, - player_white=player, - return_black_transitions=False, - out_device=out_device, - augment=augment, - ) + _tds ) + end = time.perf_counter() self._fps = (2 * rounds * 2 * self.num_envs) / (end - start) info.update( diff --git a/tests/test_env.py b/tests/test_env.py index 9d7918e..cc81006 100644 --- a/tests/test_env.py +++ b/tests/test_env.py @@ -47,7 +47,8 @@ def assert_layer_transition( same = (layer > EPS) & (next_layer > EPS) assert_tensor_1d_all((same == (layer > EPS)).all(-1).all(-1) | done) assert_tensor_1d_all( - (layer.long().sum(-1).sum(-1) + 1 == next_layer.long().sum(-1).sum(-1)) | done + (layer.long().sum(-1).sum(-1) + 1 == + next_layer.long().sum(-1).sum(-1)) | done ) @@ -70,7 +71,8 @@ def assert_transition(tensordict: TensorDict, type: Type): layer1 = observation[:, 0] x = action // board_size y = action % board_size - assert_tensor_1d_all((layer1[torch.arange(num_envs, device=device), x, y] < EPS)) + assert_tensor_1d_all( + (layer1[torch.arange(num_envs, device=device), x, y] < EPS)) layer1 = layer1.clone() layer1[torch.arange(num_envs, device=device), x, y] = 1.0 assert_tensor_1d_all( @@ -78,17 +80,7 @@ def assert_transition(tensordict: TensorDict, type: Type): ) -def _debug_print(transition: TensorDict): - observation = transition["observation"] - next_observation = transition["next", "observation"] - done = transition["next", "done"] - print(observation) - print(next_observation) - print(done.item()) - exit() - - -def main(): +def test_rollout(): device = "cuda:0" num_envs = 256 board_size = 10 @@ -115,6 +107,15 @@ def main(): ): assert_transition(transition, type=Type.black) + +def test_rollout_fixed_opponent(): + device = "cuda:0" + num_envs = 256 + board_size = 10 + seed = 1234 + set_seed(seed) + env = GomokuEnv(num_envs=num_envs, board_size=board_size, device=device) + transitions, info = env.rollout_fixed_opponent( 50, player=uniform_policy, @@ -128,4 +129,5 @@ def main(): if __name__ == "__main__": - main() + # test_rollout() + test_rollout_fixed_opponent()