Skip to content

Commit

Permalink
some minor changes
Browse files Browse the repository at this point in the history
  • Loading branch information
hesic73 committed Jan 13, 2024
1 parent 87ddc4e commit 8462ad2
Show file tree
Hide file tree
Showing 7 changed files with 17 additions and 34 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

**Empirically, Independent RL is enough (and in fact much better than PSRO).** As mentioned in [[1]](#refer-anchor-1), due to Gomoku's asymmetry, it's hard to train a network to play both black and white.

![](/images/screenshot_0.gif)
![](/assets/images/screenshot_0.gif)

## Introduction

Expand Down
File renamed without changes
10 changes: 10 additions & 0 deletions examples/demo_jit_load.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
import torch
from gomoku_rl import GomokuEnv

# Demo: load a serialized TorchScript module and run it on one environment
# observation. "tsmodule.pt" is a relative path, so it is resolved against the
# current working directory.
m: torch.jit.ScriptModule = torch.jit.load("tsmodule.pt")
env = GomokuEnv(num_envs=16, board_size=15)
# Index the reset output with a list ([0]) rather than a bare int; presumably
# this selects the first environment while keeping a leading batch dimension
# -- TODO confirm against GomokuEnv.reset()'s return type.
td = env.reset()[[0]]
# Inference only, so gradient tracking is disabled for the forward call.
with torch.no_grad():
    td = m(td["observation"], td["action_mask"])

print(td)
22 changes: 0 additions & 22 deletions test_jit.py → examples/demo_jit_save.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,11 @@
from gomoku_rl.utils.eval import get_payoff_matrix
from gomoku_rl.utils.visual import annotate_heatmap, heatmap
import hydra
from omegaconf import DictConfig, OmegaConf
import os
from gomoku_rl import CONFIG_PATH
import torch

import matplotlib.pyplot as plt

from gomoku_rl.env import GomokuEnv
from gomoku_rl.utils.policy import uniform_policy
from gomoku_rl.utils.module import ActorNet
from gomoku_rl.policy import get_pretrained_policy
import numpy as np
import functools
import copy


@hydra.main(version_base=None, config_path=CONFIG_PATH, config_name="eval")
Expand Down Expand Up @@ -43,22 +34,9 @@ def main(cfg: DictConfig):

checkpoint = cfg.checkpoints[0]
player = make_player(checkpoint_path=checkpoint)
# encoder = player.actor.module[0].module
# policy_head = player.actor.module[1].module

# actor = ActorNet(
# encoder,
# out_features=env.action_spec.space.n,
# num_channels=cfg.algo.num_channels,
# )
# actor.policy_head = policy_head
# actor.eval().cpu()
actor=player.actor.eval().cpu()

tensordict = env.reset()[[0]].cpu()
# s = torch.jit.script(
# actor, example_inputs=[tensordict["observation"], tensordict["action_mask"]]
# )
s = torch.jit.trace(actor, (tensordict["observation"], tensordict["action_mask"]))
print(s)
torch.jit.save(s, "tsmodule.pt")
Expand Down
6 changes: 6 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,14 @@ dependencies = [
]
requires-python = ">=3.10"

classifiers = [
"Programming Language :: Python :: 3",
"License :: OSI Approved :: MIT License",
]

[project.optional-dependencies]
gui = ["PyQt5"]
test = ["pytest"]

[project.urls]
Repository = "https://github.com/hesic73/gomoku_rl.git"
11 changes: 0 additions & 11 deletions test_jit_load.py

This file was deleted.

File renamed without changes.

0 comments on commit 8462ad2

Please sign in to comment.