Skip to content

Commit

Permalink
Merge pull request #50 from alan-turing-institute/testing/battles
Browse files Browse the repository at this point in the history
Adds tests for battles
  • Loading branch information
phinate authored Jun 24, 2023
2 parents fd87a31 + d5f8cc5 commit 4125104
Show file tree
Hide file tree
Showing 8 changed files with 158 additions and 24 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,8 @@ jobs:
- name: Upload coverage report
uses: codecov/[email protected]
with:
token: ${{ secrets.CODECOV_TOKEN }}

dist:
needs: [pre-commit]
Expand Down
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -168,3 +168,4 @@ pokemon-showdown/

# Outputs
outputs/*
simple_match.py
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,11 @@ pip install -e .

- Have node installed and install showdown
- See: https://nodejs.dev/en/learn/how-to-install-nodejs/

```bash
npm install pokemon-showdown # -g flag might be needed on some systems
```

## Running

To run locally start the pokemon showdown server:
Expand Down
18 changes: 10 additions & 8 deletions src/p2lab/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,12 @@
import numpy as np

from p2lab.genetic.genetic import genetic_algorithm
from p2lab.pokemon.poke_factory import PokeFactory
from p2lab.pokemon.premade import gen_1_pokemon
from p2lab.pokemon.teams import generate_teams, import_pool


async def main_loop(num_teams, team_size, num_generations, unique):
# generate the pool
PokeFactory()
pool = import_pool(gen_1_pokemon())
seed_teams = generate_teams(pool, num_teams, team_size, unique=unique)
# crossover_fn = build_crossover_fn(locus_swap, locus=0)
Expand Down Expand Up @@ -56,10 +54,10 @@ def parse_args():
default=10,
)
parser.add_argument(
"--teamsize", help="Number of pokemon per team (max 6)", type=int, default=2
"--team-size", help="Number of pokemon per team (max 6)", type=int, default=2
)
parser.add_argument(
"--numteams",
"--teams",
help="Number of teams i.e., individuals per generation",
type=int,
default=30,
Expand All @@ -75,12 +73,16 @@ def parse_args():

def main():
args = parse_args()

if args["s"] is not None:
np.random.seed(args["s"])
if args["seed"] is not None:
np.random.seed(args["seed"])

asyncio.get_event_loop().run_until_complete(
main_loop(args["n"], args["t"], args["g"], args["u"])
main_loop(
num_teams=args["teams"],
team_size=args["team_size"],
num_generations=args["generations"],
unique=args["unique"],
)
)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,9 @@ def get_allowed_moves(self, dexnum, level=100):
allowed_moves.append(move)
return allowed_moves

def get_allowed_abilities(self, dexnum):
return self.dex.pokedex[self.dex2mon[dexnum]]["abilities"]

def make_pokemon(self, dexnum=None, generate_moveset=False, **kwargs):
"""
kwargs are passed to the TeambuilderPokemon constructor and can include:
Expand Down Expand Up @@ -72,6 +75,19 @@ def make_pokemon(self, dexnum=None, generate_moveset=False, **kwargs):
if len(poss_moves) > 3
else poss_moves
)
elif "moves" in kwargs:
return TeambuilderPokemon(species=self.dex2mon[dexnum], **kwargs)
return TeambuilderPokemon(species=self.dex2mon[dexnum], moves=moves, **kwargs)
kwargs["moves"] = moves
if "ivs" not in kwargs.keys():
ivs = [31] * 6
kwargs["ivs"] = ivs
if "evs" not in kwargs.keys():
# TODO: implement EV generation better
evs = [510 // 6] * 6
kwargs["evs"] = evs
if "level" not in kwargs.keys():
kwargs["level"] = 100
if "ability" not in kwargs.keys():
kwargs["ability"] = np.random.choice(
list(self.get_allowed_abilities(dexnum).values())
)

return TeambuilderPokemon(species=self.dex2mon[dexnum], **kwargs)
10 changes: 10 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from __future__ import annotations

import pytest

from p2lab.pokemon import pokefactory


@pytest.fixture()
def default_factory():
return pokefactory.PokeFactory()
101 changes: 101 additions & 0 deletions tests/test_battles.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
from __future__ import annotations

import inspect

import pytest
from poke_env import PlayerConfiguration
from poke_env.player import SimpleHeuristicsPlayer

from p2lab.pokemon.battles import run_battles
from p2lab.pokemon.teams import Team

pytest_plugins = ("pytest_asyncio",)


def test_battle_eevee_pikachu_pokes(event_loop, default_factory):
eevee = default_factory.make_pokemon(133, moves=["tackle", "growl"], level=5)
pikachu = default_factory.make_pokemon(25, moves=["thundershock", "growl"], level=5)
team1 = Team([eevee])
team2 = Team([pikachu])
teams = [team1, team2]
matches = [[0, 1]]

player_1 = SimpleHeuristicsPlayer(
PlayerConfiguration(inspect.stack()[0][3].split("test_")[1][:15] + " P1", None),
battle_format="gen7anythinggoes",
)
player_2 = SimpleHeuristicsPlayer(
PlayerConfiguration(inspect.stack()[0][3].split("test_")[1][:15] + " P2", None),
battle_format="gen7anythinggoes",
)
res = event_loop.run_until_complete(
run_battles(matches, teams, player_1, player_2, battles_per_match=1)
)

assert res is not None


def test_battle_mewtwo_obliterates_eevee(event_loop, default_factory):
eevee = default_factory.make_pokemon(133, moves=["tackle", "growl"], level=5)
mewtwo = default_factory.make_pokemon(150, moves=["psychic"], level=100)
team1 = Team([eevee])
team2 = Team([mewtwo])
teams = [team1, team2]
matches = [[0, 1]]
player_1 = SimpleHeuristicsPlayer(
PlayerConfiguration(inspect.stack()[0][3].split("test_")[1][:15] + " P1", None),
battle_format="gen7anythinggoes",
)
player_2 = SimpleHeuristicsPlayer(
PlayerConfiguration(inspect.stack()[0][3].split("test_")[1][:15] + " P2", None),
battle_format="gen7anythinggoes",
)
res = event_loop.run_until_complete(
run_battles(matches, teams, player_1, player_2, battles_per_match=1)
)
mewtwo_wins = res[0][1]
eevee_wins = res[0][0]
assert mewtwo_wins > eevee_wins


@pytest.mark.parametrize(
"battle_format",
[
("gen4anythinggoes"),
("gen6anythinggoes"),
("gen7anythinggoes"),
("gen8anythinggoes"),
],
)
def test_battle_eevee_pikachu_formats(event_loop, battle_format, default_factory):
eevee = default_factory.make_pokemon(133, moves=["tackle", "growl"], level=5)
pikachu = default_factory.make_pokemon(25, moves=["thundershock", "growl"], level=5)
team1 = Team([eevee])
team2 = Team([pikachu])
teams = [team1, team2]
matches = [[0, 1]]

player_1 = SimpleHeuristicsPlayer(
PlayerConfiguration(
inspect.stack()[0][3].split("test_")[1][:10]
+ "-"
+ battle_format[:4]
+ " P1",
None,
),
battle_format=battle_format,
)
player_2 = SimpleHeuristicsPlayer(
PlayerConfiguration(
inspect.stack()[0][3].split("test_")[1][:10]
+ "-"
+ battle_format[:4]
+ " P2",
None,
),
battle_format=battle_format,
)
res = event_loop.run_until_complete(
run_battles(matches, teams, player_1, player_2, battles_per_match=1)
)
assert res is not None
26 changes: 13 additions & 13 deletions tests/test_pokedex.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,52 +2,52 @@

import numpy as np

from p2lab.pokemon import poke_factory
from p2lab.pokemon import pokefactory


def test_pokedex():
p = poke_factory.PokeFactory()
p = pokefactory.PokeFactory()
assert p is not None


def test_eevee_fetch():
p = poke_factory.PokeFactory()
p = pokefactory.PokeFactory()
eevee = p.get_pokemon_by_dexnum(133)
assert eevee["baseSpecies"].lower() == "eevee"


def test_bulbasaur_fetch():
p = poke_factory.PokeFactory()
p = pokefactory.PokeFactory()
bulb = p.get_pokemon_by_dexnum(1)
assert bulb["baseSpecies"].lower() == "bulbasaur"


def test_eevee_moves():
p = poke_factory.PokeFactory()
p = pokefactory.PokeFactory()
eevee_moves = p.get_allowed_moves(133)
assert len(eevee_moves) > 0


def test_bulbasaur_moves():
p = poke_factory.PokeFactory()
p = pokefactory.PokeFactory()
bulb_moves = p.get_allowed_moves(1)
assert len(bulb_moves) > 0


def test_eevee_is_created():
p = poke_factory.PokeFactory()
p = pokefactory.PokeFactory()
eevee = p.make_pokemon(133)
assert eevee is not None


def test_eevee_is_created_with_moves():
p = poke_factory.PokeFactory()
p = pokefactory.PokeFactory()
eevee = p.make_pokemon(133, moves=["tackle", "growl"])
assert eevee is not None


def test_random_pokemon_is_created_with_moves():
p = poke_factory.PokeFactory()
p = pokefactory.PokeFactory()
dexnum = np.random.randint(1, 151)
while dexnum == 132:
dexnum = np.random.randint(1, 151)
Expand All @@ -56,20 +56,20 @@ def test_random_pokemon_is_created_with_moves():


def test_ditto_is_created_with_moves():
p = poke_factory.PokeFactory()
p = pokefactory.PokeFactory()
ditto = p.make_pokemon(132)
assert len(ditto.moves) == 1


def test_all_gen1_pokemon_can_be_created():
p = poke_factory.PokeFactory()
p = pokefactory.PokeFactory()
for dexnum in range(1, 152):
poke = p.make_pokemon(dexnum=dexnum, generate_moveset=True)
assert len(poke.moves) > 0


def test_invalid_dex_raised():
p = poke_factory.PokeFactory()
p = pokefactory.PokeFactory()
try:
p.make_pokemon(dexnum=0)
except ValueError:
Expand All @@ -79,6 +79,6 @@ def test_invalid_dex_raised():


def test_adding_item_to_pokemon():
p = poke_factory.PokeFactory()
p = pokefactory.PokeFactory()
ditto = p.make_pokemon(132, item="choice scarf")
assert ditto.item == "choice scarf"

0 comments on commit 4125104

Please sign in to comment.