Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds tests for battles #50

Merged
merged 10 commits into from
Jun 24, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,8 @@ jobs:

- name: Upload coverage report
uses: codecov/[email protected]
with:
token: ${{ secrets.CODECOV_TOKEN }}

dist:
needs: [pre-commit]
Expand Down
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -168,3 +168,4 @@ pokemon-showdown/

# Outputs
outputs/*
simple_match.py
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,11 @@ pip install -e .

- Have node installed and install showdown
- See: https://nodejs.dev/en/learn/how-to-install-nodejs/

```bash
npm install pokemon-showdown # -g flag might be needed on some systems
```

## Running

To run locally start the pokemon showdown server:
Expand Down
18 changes: 10 additions & 8 deletions src/p2lab/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,12 @@
import numpy as np

from p2lab.genetic.genetic import genetic_algorithm
from p2lab.pokemon.poke_factory import PokeFactory
from p2lab.pokemon.premade import gen_1_pokemon
from p2lab.pokemon.teams import generate_teams, import_pool


async def main_loop(num_teams, team_size, num_generations, unique):
# generate the pool
PokeFactory()
pool = import_pool(gen_1_pokemon())
seed_teams = generate_teams(pool, num_teams, team_size, unique=unique)
# crossover_fn = build_crossover_fn(locus_swap, locus=0)
Expand Down Expand Up @@ -56,10 +54,10 @@ def parse_args():
default=10,
)
parser.add_argument(
"--teamsize", help="Number of pokemon per team (max 6)", type=int, default=2
"--team-size", help="Number of pokemon per team (max 6)", type=int, default=2
)
parser.add_argument(
"--numteams",
"--teams",
help="Number of teams i.e., individuals per generation",
type=int,
default=30,
Expand All @@ -75,12 +73,16 @@ def parse_args():

def main():
args = parse_args()

if args["s"] is not None:
np.random.seed(args["s"])
if args["seed"] is not None:
np.random.seed(args["seed"])

asyncio.get_event_loop().run_until_complete(
main_loop(args["n"], args["t"], args["g"], args["u"])
main_loop(
num_teams=args["teams"],
team_size=args["team_size"],
num_generations=args["generations"],
unique=args["unique"],
)
)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,9 @@ def get_allowed_moves(self, dexnum, level=100):
allowed_moves.append(move)
return allowed_moves

def get_allowed_abilities(self, dexnum):
return self.dex.pokedex[self.dex2mon[dexnum]]["abilities"]

def make_pokemon(self, dexnum=None, generate_moveset=False, **kwargs):
"""
kwargs are passed to the TeambuilderPokemon constructor and can include:
Expand Down Expand Up @@ -72,6 +75,19 @@ def make_pokemon(self, dexnum=None, generate_moveset=False, **kwargs):
if len(poss_moves) > 3
else poss_moves
)
elif "moves" in kwargs:
return TeambuilderPokemon(species=self.dex2mon[dexnum], **kwargs)
return TeambuilderPokemon(species=self.dex2mon[dexnum], moves=moves, **kwargs)
kwargs["moves"] = moves
if "ivs" not in kwargs.keys():
ivs = [31] * 6
kwargs["ivs"] = ivs
if "evs" not in kwargs.keys():
# TODO: implement EV generation better
evs = [510 // 6] * 6
kwargs["evs"] = evs
if "level" not in kwargs.keys():
kwargs["level"] = 100
if "ability" not in kwargs.keys():
kwargs["ability"] = np.random.choice(
list(self.get_allowed_abilities(dexnum).values())
)

return TeambuilderPokemon(species=self.dex2mon[dexnum], **kwargs)
10 changes: 10 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from __future__ import annotations

import pytest

from p2lab.pokemon import pokefactory


@pytest.fixture()
def default_factory():
return pokefactory.PokeFactory()
101 changes: 101 additions & 0 deletions tests/test_battles.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
from __future__ import annotations

import inspect

import pytest
from poke_env import PlayerConfiguration
from poke_env.player import SimpleHeuristicsPlayer

from p2lab.pokemon.battles import run_battles
from p2lab.pokemon.teams import Team

pytest_plugins = ("pytest_asyncio",)


def test_battle_eevee_pikachu_pokes(event_loop, default_factory):
eevee = default_factory.make_pokemon(133, moves=["tackle", "growl"], level=5)
pikachu = default_factory.make_pokemon(25, moves=["thundershock", "growl"], level=5)
team1 = Team([eevee])
team2 = Team([pikachu])
teams = [team1, team2]
matches = [[0, 1]]

player_1 = SimpleHeuristicsPlayer(
PlayerConfiguration(inspect.stack()[0][3].split("test_")[1][:15] + " P1", None),
battle_format="gen7anythinggoes",
)
player_2 = SimpleHeuristicsPlayer(
PlayerConfiguration(inspect.stack()[0][3].split("test_")[1][:15] + " P2", None),
battle_format="gen7anythinggoes",
)
res = event_loop.run_until_complete(
run_battles(matches, teams, player_1, player_2, battles_per_match=1)
)

assert res is not None


def test_battle_mewtwo_obliterates_eevee(event_loop, default_factory):
eevee = default_factory.make_pokemon(133, moves=["tackle", "growl"], level=5)
mewtwo = default_factory.make_pokemon(150, moves=["psychic"], level=100)
team1 = Team([eevee])
team2 = Team([mewtwo])
teams = [team1, team2]
matches = [[0, 1]]
player_1 = SimpleHeuristicsPlayer(
PlayerConfiguration(inspect.stack()[0][3].split("test_")[1][:15] + " P1", None),
battle_format="gen7anythinggoes",
)
player_2 = SimpleHeuristicsPlayer(
PlayerConfiguration(inspect.stack()[0][3].split("test_")[1][:15] + " P2", None),
battle_format="gen7anythinggoes",
)
res = event_loop.run_until_complete(
run_battles(matches, teams, player_1, player_2, battles_per_match=1)
)
mewtwo_wins = res[0][1]
eevee_wins = res[0][0]
assert mewtwo_wins > eevee_wins


@pytest.mark.parametrize(
"battle_format",
[
("gen4anythinggoes"),
("gen6anythinggoes"),
("gen7anythinggoes"),
("gen8anythinggoes"),
],
)
def test_battle_eevee_pikachu_formats(event_loop, battle_format, default_factory):
eevee = default_factory.make_pokemon(133, moves=["tackle", "growl"], level=5)
pikachu = default_factory.make_pokemon(25, moves=["thundershock", "growl"], level=5)
team1 = Team([eevee])
team2 = Team([pikachu])
teams = [team1, team2]
matches = [[0, 1]]

player_1 = SimpleHeuristicsPlayer(
PlayerConfiguration(
inspect.stack()[0][3].split("test_")[1][:10]
+ "-"
+ battle_format[:4]
+ " P1",
None,
),
battle_format=battle_format,
)
player_2 = SimpleHeuristicsPlayer(
PlayerConfiguration(
inspect.stack()[0][3].split("test_")[1][:10]
+ "-"
+ battle_format[:4]
+ " P2",
None,
),
battle_format=battle_format,
)
res = event_loop.run_until_complete(
run_battles(matches, teams, player_1, player_2, battles_per_match=1)
)
assert res is not None
26 changes: 13 additions & 13 deletions tests/test_pokedex.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,52 +2,52 @@

import numpy as np

from p2lab.pokemon import poke_factory
from p2lab.pokemon import pokefactory


def test_pokedex():
p = poke_factory.PokeFactory()
p = pokefactory.PokeFactory()
assert p is not None


def test_eevee_fetch():
p = poke_factory.PokeFactory()
p = pokefactory.PokeFactory()
eevee = p.get_pokemon_by_dexnum(133)
assert eevee["baseSpecies"].lower() == "eevee"


def test_bulbasaur_fetch():
p = poke_factory.PokeFactory()
p = pokefactory.PokeFactory()
bulb = p.get_pokemon_by_dexnum(1)
assert bulb["baseSpecies"].lower() == "bulbasaur"


def test_eevee_moves():
p = poke_factory.PokeFactory()
p = pokefactory.PokeFactory()
eevee_moves = p.get_allowed_moves(133)
assert len(eevee_moves) > 0


def test_bulbasaur_moves():
p = poke_factory.PokeFactory()
p = pokefactory.PokeFactory()
bulb_moves = p.get_allowed_moves(1)
assert len(bulb_moves) > 0


def test_eevee_is_created():
p = poke_factory.PokeFactory()
p = pokefactory.PokeFactory()
eevee = p.make_pokemon(133)
assert eevee is not None


def test_eevee_is_created_with_moves():
p = poke_factory.PokeFactory()
p = pokefactory.PokeFactory()
eevee = p.make_pokemon(133, moves=["tackle", "growl"])
assert eevee is not None


def test_random_pokemon_is_created_with_moves():
p = poke_factory.PokeFactory()
p = pokefactory.PokeFactory()
dexnum = np.random.randint(1, 151)
while dexnum == 132:
dexnum = np.random.randint(1, 151)
Expand All @@ -56,20 +56,20 @@ def test_random_pokemon_is_created_with_moves():


def test_ditto_is_created_with_moves():
p = poke_factory.PokeFactory()
p = pokefactory.PokeFactory()
ditto = p.make_pokemon(132)
assert len(ditto.moves) == 1


def test_all_gen1_pokemon_can_be_created():
p = poke_factory.PokeFactory()
p = pokefactory.PokeFactory()
for dexnum in range(1, 152):
poke = p.make_pokemon(dexnum=dexnum, generate_moveset=True)
assert len(poke.moves) > 0


def test_invalid_dex_raised():
p = poke_factory.PokeFactory()
p = pokefactory.PokeFactory()
try:
p.make_pokemon(dexnum=0)
except ValueError:
Expand All @@ -79,6 +79,6 @@ def test_invalid_dex_raised():


def test_adding_item_to_pokemon():
p = poke_factory.PokeFactory()
p = pokefactory.PokeFactory()
ditto = p.make_pokemon(132, item="choice scarf")
assert ditto.item == "choice scarf"