Skip to content

Commit

Permalink
revert benchmark code to include all agents/envs
Browse files Browse the repository at this point in the history
  • Loading branch information
cpnota committed Mar 6, 2024
1 parent 67d6904 commit 83ce920
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 10 deletions.
2 changes: 1 addition & 1 deletion all/environments/pybullet.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@ class PybulletEnvironment(GymEnvironment):
short_names = {
"ant": "AntBulletEnv-v0",
"cheetah": "HalfCheetahBulletEnv-v0",
"humanoid": "HumanoidBulletEnv-v0",
"hopper": "HopperBulletEnv-v0",
"humanoid": "HumanoidBulletEnv-v0",
"walker": "Walker2DBulletEnv-v0",
}

Expand Down
12 changes: 5 additions & 7 deletions benchmarks/mujoco_v4.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,14 @@ def main():

agents = [ddpg, ppo, sac]

agents = [sac]

envs = [
MujocoEnvironment(env, device="cuda")
for env in [
# "Ant-v4",
"Ant-v4",
"HalfCheetah-v4",
# "Hopper-v4",
# "Humanoid-v4",
# "Walker2d-v4",
"Hopper-v4",
"Humanoid-v4",
"Walker2d-v4",
]
]

Expand All @@ -27,7 +25,7 @@ def main():
frames,
logdir="benchmarks/mujoco_v4",
sbatch_args={
"partition": "gypsum-2080ti",
"partition": "gpu-long",
},
)

Expand Down
12 changes: 10 additions & 2 deletions benchmarks/pybullet.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,23 @@ def main():

envs = [
PybulletEnvironment(env, device="cuda")
for env in PybulletEnvironment.short_names
for env in [
"AntBulletEnv-v0",
"HalfCheetahBulletEnv-v0",
"HopperBulletEnv-v0",
"HumanoidBulletEnv-v0",
"Walker2DBulletEnv-v0",
]
]

SlurmExperiment(
agents,
envs,
frames,
logdir="benchmarks/pybullet",
sbatch_args={"partition": "gpu-long"},
sbatch_args={
"partition": "gpu-long",
},
)


Expand Down

0 comments on commit 83ce920

Please sign in to comment.