Skip to content

Commit

Permalink
add test_env.py; fix a bug in rollout_fixed_opponent
Browse files Browse the repository at this point in the history
  • Loading branch information
hesic73 committed Jan 13, 2024
1 parent 8462ad2 commit 079d599
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 34 deletions.
57 changes: 37 additions & 20 deletions gomoku_rl/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@ def make_transition(
) -> TensorDict:
# if a player wins at time t, its opponent cannot win immediately after reset
reward: torch.Tensor = (
tensordict_t.get("win").float() - tensordict_t_plus_1.get("win").float()
tensordict_t.get("win").float() -
tensordict_t_plus_1.get("win").float()
).unsqueeze(-1)
transition: TensorDict = tensordict_t_minus_1.select(
"observation",
Expand Down Expand Up @@ -57,7 +58,8 @@ def __init__(
board_size: int,
device=None,
):
self.gomoku = Gomoku(num_envs=num_envs, board_size=board_size, device=device)
self.gomoku = Gomoku(
num_envs=num_envs, board_size=board_size, device=device)

self.observation_spec = CompositeSpec(
{
Expand Down Expand Up @@ -202,7 +204,8 @@ def _round(
return_white_transitions: bool = True,
is_last: bool = False,
) -> tuple[TensorDict | None, TensorDict | None, TensorDict, TensorDict]:
tensordict_t_plus_1 = self._step_and_maybe_reset(tensordict=tensordict_t)
tensordict_t_plus_1 = self._step_and_maybe_reset(
tensordict=tensordict_t)

with set_interaction_type(type=InteractionType.RANDOM):
tensordict_t_plus_1 = player_white(tensordict_t_plus_1)
Expand Down Expand Up @@ -249,7 +252,8 @@ def _round(
)
transition_black.set(
"invalid",
torch.zeros(self.num_envs, device=self.device, dtype=torch.bool),
torch.zeros(self.num_envs, device=self.device,
dtype=torch.bool),
)
else:
transition_black = None
Expand Down Expand Up @@ -466,22 +470,30 @@ def rollout_fixed_opponent(
out_device=None,
augment: bool = False,
) -> tuple[TensorDict, dict[str, float]]:
if not hasattr(self, "_t") and not hasattr(self, "_t_minus_1"):
self._t = None
self._t_minus_1 = None

info: defaultdict[str, float] = defaultdict(float)

tensordicts: list[TensorDict] = []

start = time.perf_counter()
info_buffer = defaultdict(float)
self._post_step = get_log_func(info_buffer)

_tds, self._t_minus_1, self._t = self._rollout_fixed_opponent(
rounds=rounds,
player_black=player,
player_white=opponent,
return_black_transitions=True,
out_device=out_device,
augment=augment,
tensordict_t_minus_1=self._t_minus_1,
tensordict_t=self._t,
)
tensordicts.extend(
self._rollout_fixed_opponent(
rounds=rounds,
player_black=player,
player_white=opponent,
return_black_transitions=True,
out_device=out_device,
augment=augment,
)
_tds
)

info.update(
Expand All @@ -493,16 +505,21 @@ def rollout_fixed_opponent(
info_buffer.clear()

self._post_step = get_log_func(info_buffer)
_tds, self._t_minus_1, self._t = self._rollout_fixed_opponent(
rounds=rounds,
player_black=opponent,
player_white=player,
return_black_transitions=False,
out_device=out_device,
augment=augment,
tensordict_t_minus_1=self._t_minus_1,
tensordict_t=self._t,
)

tensordicts.extend(
self._rollout_fixed_opponent(
rounds=rounds,
player_black=opponent,
player_white=player,
return_black_transitions=False,
out_device=out_device,
augment=augment,
)
_tds
)

end = time.perf_counter()
self._fps = (2 * rounds * 2 * self.num_envs) / (end - start)
info.update(
Expand Down
30 changes: 16 additions & 14 deletions tests/test_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,8 @@ def assert_layer_transition(
same = (layer > EPS) & (next_layer > EPS)
assert_tensor_1d_all((same == (layer > EPS)).all(-1).all(-1) | done)
assert_tensor_1d_all(
(layer.long().sum(-1).sum(-1) + 1 == next_layer.long().sum(-1).sum(-1)) | done
(layer.long().sum(-1).sum(-1) + 1 ==
next_layer.long().sum(-1).sum(-1)) | done
)


Expand All @@ -70,25 +71,16 @@ def assert_transition(tensordict: TensorDict, type: Type):
layer1 = observation[:, 0]
x = action // board_size
y = action % board_size
assert_tensor_1d_all((layer1[torch.arange(num_envs, device=device), x, y] < EPS))
assert_tensor_1d_all(
(layer1[torch.arange(num_envs, device=device), x, y] < EPS))
layer1 = layer1.clone()
layer1[torch.arange(num_envs, device=device), x, y] = 1.0
assert_tensor_1d_all(
torch.isclose(layer1, next_observation[:, 0]).all(-1).all(-1) | done
)


def _debug_print(transition: TensorDict):
    """Dump one transition to stdout and terminate the process.

    Debugging aid: prints the current observation, the next observation and
    the ``done`` flag stored under ``("next", "done")``, then exits.

    Args:
        transition: Transition to inspect. ``done.item()`` is called on the
            flag, so this presumably expects a transition whose ``done``
            entry holds a single element -- TODO confirm against callers.

    Raises:
        SystemExit: Always, once the three values have been printed.
    """
    observation = transition["observation"]
    next_observation = transition["next", "observation"]
    done = transition["next", "done"]
    print(observation)
    print(next_observation)
    print(done.item())
    # ``exit()`` is injected by the ``site`` module for interactive sessions
    # and is unavailable under ``python -S`` or in frozen builds; raising
    # SystemExit directly terminates the same way and is always available.
    raise SystemExit


def main():
def test_rollout():
device = "cuda:0"
num_envs = 256
board_size = 10
Expand All @@ -115,6 +107,15 @@ def main():
):
assert_transition(transition, type=Type.black)


def test_rollout_fixed_opponent():
device = "cuda:0"
num_envs = 256
board_size = 10
seed = 1234
set_seed(seed)
env = GomokuEnv(num_envs=num_envs, board_size=board_size, device=device)

transitions, info = env.rollout_fixed_opponent(
50,
player=uniform_policy,
Expand All @@ -128,4 +129,5 @@ def main():


if __name__ == "__main__":
main()
# test_rollout()
test_rollout_fixed_opponent()

0 comments on commit 079d599

Please sign in to comment.