From 8462ad276ccd1762eda360fe0049420e86cad3f0 Mon Sep 17 00:00:00 2001 From: hesic73 Date: Sun, 14 Jan 2024 03:20:30 +0800 Subject: [PATCH] some minor changes --- README.md | 2 +- {images => assets/images}/screenshot_0.gif | Bin examples/demo_jit_load.py | 10 ++++++++++ test_jit.py => examples/demo_jit_save.py | 22 --------------------- pyproject.toml | 6 ++++++ test_jit_load.py | 11 ----------- {examples => tests}/test_env.py | 0 7 files changed, 17 insertions(+), 34 deletions(-) rename {images => assets/images}/screenshot_0.gif (100%) create mode 100644 examples/demo_jit_load.py rename test_jit.py => examples/demo_jit_save.py (63%) delete mode 100644 test_jit_load.py rename {examples => tests}/test_env.py (100%) diff --git a/README.md b/README.md index 82a2683..0fd3c26 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,7 @@ **Empirically, Independent RL is enough (and in fact much better than PSRO).** As mentioned in [[1]](#refer-anchor-1), due to Gomoku's asymmetry, it's hard to train a network to play both black and white. 
-![](/images/screenshot_0.gif) +![](/assets/images/screenshot_0.gif) ## Introduction diff --git a/images/screenshot_0.gif b/assets/images/screenshot_0.gif similarity index 100% rename from images/screenshot_0.gif rename to assets/images/screenshot_0.gif diff --git a/examples/demo_jit_load.py b/examples/demo_jit_load.py new file mode 100644 index 0000000..386256c --- /dev/null +++ b/examples/demo_jit_load.py @@ -0,0 +1,10 @@ +import torch +from gomoku_rl import GomokuEnv + +m: torch.jit.ScriptModule = torch.jit.load("tsmodule.pt") +env = GomokuEnv(num_envs=16, board_size=15) +td = env.reset()[[0]] +with torch.no_grad(): + td = m(td["observation"], td["action_mask"]) + +print(td) diff --git a/test_jit.py b/examples/demo_jit_save.py similarity index 63% rename from test_jit.py rename to examples/demo_jit_save.py index 9eeb957..1424652 100644 --- a/test_jit.py +++ b/examples/demo_jit_save.py @@ -1,20 +1,11 @@ -from gomoku_rl.utils.eval import get_payoff_matrix -from gomoku_rl.utils.visual import annotate_heatmap, heatmap import hydra from omegaconf import DictConfig, OmegaConf -import os from gomoku_rl import CONFIG_PATH import torch - -import matplotlib.pyplot as plt - from gomoku_rl.env import GomokuEnv -from gomoku_rl.utils.policy import uniform_policy -from gomoku_rl.utils.module import ActorNet from gomoku_rl.policy import get_pretrained_policy import numpy as np import functools -import copy @hydra.main(version_base=None, config_path=CONFIG_PATH, config_name="eval") @@ -43,22 +34,9 @@ def main(cfg: DictConfig): checkpoint = cfg.checkpoints[0] player = make_player(checkpoint_path=checkpoint) - # encoder = player.actor.module[0].module - # policy_head = player.actor.module[1].module - - # actor = ActorNet( - # encoder, - # out_features=env.action_spec.space.n, - # num_channels=cfg.algo.num_channels, - # ) - # actor.policy_head = policy_head - # actor.eval().cpu() actor=player.actor.eval().cpu() tensordict = env.reset()[[0]].cpu() - # s = torch.jit.script( - # 
actor, example_inputs=[tensordict["observation"], tensordict["action_mask"]] - # ) s = torch.jit.trace(actor, (tensordict["observation"], tensordict["action_mask"])) print(s) torch.jit.save(s, "tsmodule.pt") diff --git a/pyproject.toml b/pyproject.toml index aa16a07..213e4a4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,8 +24,14 @@ dependencies = [ ] requires-python = ">=3.10" +classifiers = [ + "Programming Language :: Python :: 3", + "License :: OSI Approved :: MIT License", +] + [project.optional-dependencies] gui = ["PyQt5"] +test = ["pytest"] [project.urls] Repository = "https://github.com/hesic73/gomoku_rl.git" diff --git a/test_jit_load.py b/test_jit_load.py deleted file mode 100644 index bc60873..0000000 --- a/test_jit_load.py +++ /dev/null @@ -1,11 +0,0 @@ -import torch -from gomoku_rl import GomokuEnv - -m = torch.jit.load("tsmodule.pt") -print(isinstance(m,torch.jit.ScriptModule)) -env=GomokuEnv(num_envs=16,board_size=15) -td=env.reset()[[0]] -with torch.no_grad(): - td=m(td["observation"], td["action_mask"]) - -print(td) \ No newline at end of file diff --git a/examples/test_env.py b/tests/test_env.py similarity index 100% rename from examples/test_env.py rename to tests/test_env.py