Skip to content

Commit

Permalink
Merge pull request #44 from alan-turing-institute/pokefactory
Browse files Browse the repository at this point in the history
Pokefactory - generates Pokemon
  • Loading branch information
phinate authored Jun 23, 2023
2 parents 811778e + e442412 commit fd87a31
Show file tree
Hide file tree
Showing 3 changed files with 163 additions and 0 deletions.
2 changes: 2 additions & 0 deletions src/p2lab/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,14 @@
import numpy as np

from p2lab.genetic.genetic import genetic_algorithm
from p2lab.pokemon.poke_factory import PokeFactory
from p2lab.pokemon.premade import gen_1_pokemon
from p2lab.pokemon.teams import generate_teams, import_pool


async def main_loop(num_teams, team_size, num_generations, unique):
# generate the pool
PokeFactory()
pool = import_pool(gen_1_pokemon())
seed_teams = generate_teams(pool, num_teams, team_size, unique=unique)
# crossover_fn = build_crossover_fn(locus_swap, locus=0)
Expand Down
77 changes: 77 additions & 0 deletions src/p2lab/pokemon/poke_factory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
"""
This class is used to generate Pokemon teams using the pokedex provided from
Pokemon Showdown (via poke-env)
This is directly inspired by poke-env's diagostic_tools folder
"""
from __future__ import annotations

import numpy as np
from poke_env.data import GenData as POKEDEX
from poke_env.teambuilder import TeambuilderPokemon


class PokeFactory:
def __init__(self, gen=1, drop_forms=True):
self.gen = (
gen if gen > 4 else 3
) # because this seems to be the minimum gen for the pokedex
self.dex = POKEDEX(self.gen)
if drop_forms:
self.dex2mon = {
int(self.dex.pokedex[m]["num"]): m
for m in self.dex.pokedex
if "forme" not in self.dex.pokedex[m].keys()
}
else:
self.dex2mon = {
int(self.dex.pokedex[m]["num"]): m for m in self.dex.pokedex
}

def get_pokemon_by_dexnum(self, dexnum):
return self.dex.pokedex[self.dex2mon[dexnum]]

def get_allowed_moves(self, dexnum, level=100):
pot_moves = self.dex.learnset[self.dex2mon[dexnum]]["learnset"]
allowed_moves = []
for move, lims in pot_moves.items():
gens = [int(lim[0]) for lim in lims]
# TODO: write logic here to check if level is allowed
if level != 100:
msg = "Level checking not implemented yet"
raise NotImplementedError(msg)
# lvls = [int(l[1:]) for l in lims if l[1:].isdigit()]
if self.gen in gens:
allowed_moves.append(move)
return allowed_moves

def make_pokemon(self, dexnum=None, generate_moveset=False, **kwargs):
"""
kwargs are passed to the TeambuilderPokemon constructor and can include:
- nickname
- item
- ability
- moves
- nature
- evs
- ivs
- level
- happiness
- hiddenpowertype
- gmax
"""
if dexnum < 1:
msg = "Dex number must be greater than 0"
raise ValueError(msg)
if dexnum is None:
dexnum = np.random.choice(list(self.dex2mon.keys()))
if generate_moveset or "moves" not in kwargs.keys():
poss_moves = self.get_allowed_moves(dexnum)
moves = (
np.random.choice(poss_moves, 4, replace=False)
if len(poss_moves) > 3
else poss_moves
)
elif "moves" in kwargs:
return TeambuilderPokemon(species=self.dex2mon[dexnum], **kwargs)
return TeambuilderPokemon(species=self.dex2mon[dexnum], moves=moves, **kwargs)
84 changes: 84 additions & 0 deletions tests/test_pokedex.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
from __future__ import annotations

import numpy as np

from p2lab.pokemon import poke_factory


def test_pokedex():
p = poke_factory.PokeFactory()
assert p is not None


def test_eevee_fetch():
p = poke_factory.PokeFactory()
eevee = p.get_pokemon_by_dexnum(133)
assert eevee["baseSpecies"].lower() == "eevee"


def test_bulbasaur_fetch():
p = poke_factory.PokeFactory()
bulb = p.get_pokemon_by_dexnum(1)
assert bulb["baseSpecies"].lower() == "bulbasaur"


def test_eevee_moves():
p = poke_factory.PokeFactory()
eevee_moves = p.get_allowed_moves(133)
assert len(eevee_moves) > 0


def test_bulbasaur_moves():
p = poke_factory.PokeFactory()
bulb_moves = p.get_allowed_moves(1)
assert len(bulb_moves) > 0


def test_eevee_is_created():
p = poke_factory.PokeFactory()
eevee = p.make_pokemon(133)
assert eevee is not None


def test_eevee_is_created_with_moves():
p = poke_factory.PokeFactory()
eevee = p.make_pokemon(133, moves=["tackle", "growl"])
assert eevee is not None


def test_random_pokemon_is_created_with_moves():
p = poke_factory.PokeFactory()
dexnum = np.random.randint(1, 151)
while dexnum == 132:
dexnum = np.random.randint(1, 151)
poke = p.make_pokemon(dexnum=dexnum, generate_moveset=True)
assert len(poke.moves) == 4


def test_ditto_is_created_with_moves():
p = poke_factory.PokeFactory()
ditto = p.make_pokemon(132)
assert len(ditto.moves) == 1


def test_all_gen1_pokemon_can_be_created():
p = poke_factory.PokeFactory()
for dexnum in range(1, 152):
poke = p.make_pokemon(dexnum=dexnum, generate_moveset=True)
assert len(poke.moves) > 0


def test_invalid_dex_raised():
p = poke_factory.PokeFactory()
try:
p.make_pokemon(dexnum=0)
except ValueError:
assert True
else:
raise AssertionError()


def test_adding_item_to_pokemon():
p = poke_factory.PokeFactory()
ditto = p.make_pokemon(132, item="choice scarf")
assert ditto.item == "choice scarf"

0 comments on commit fd87a31

Please sign in to comment.