-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #44 from alan-turing-institute/pokefactory
Pokefactory - generates Pokemon
- Loading branch information
Showing
3 changed files
with
163 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,77 @@ | ||
""" | ||
This class is used to generate Pokemon teams using the pokedex provided from | ||
Pokemon Showdown (via poke-env) | ||
This is directly inspired by poke-env's diagostic_tools folder | ||
""" | ||
from __future__ import annotations | ||
|
||
import numpy as np | ||
from poke_env.data import GenData as POKEDEX | ||
from poke_env.teambuilder import TeambuilderPokemon | ||
|
||
|
||
class PokeFactory: | ||
def __init__(self, gen=1, drop_forms=True): | ||
self.gen = ( | ||
gen if gen > 4 else 3 | ||
) # because this seems to be the minimum gen for the pokedex | ||
self.dex = POKEDEX(self.gen) | ||
if drop_forms: | ||
self.dex2mon = { | ||
int(self.dex.pokedex[m]["num"]): m | ||
for m in self.dex.pokedex | ||
if "forme" not in self.dex.pokedex[m].keys() | ||
} | ||
else: | ||
self.dex2mon = { | ||
int(self.dex.pokedex[m]["num"]): m for m in self.dex.pokedex | ||
} | ||
|
||
def get_pokemon_by_dexnum(self, dexnum): | ||
return self.dex.pokedex[self.dex2mon[dexnum]] | ||
|
||
def get_allowed_moves(self, dexnum, level=100): | ||
pot_moves = self.dex.learnset[self.dex2mon[dexnum]]["learnset"] | ||
allowed_moves = [] | ||
for move, lims in pot_moves.items(): | ||
gens = [int(lim[0]) for lim in lims] | ||
# TODO: write logic here to check if level is allowed | ||
if level != 100: | ||
msg = "Level checking not implemented yet" | ||
raise NotImplementedError(msg) | ||
# lvls = [int(l[1:]) for l in lims if l[1:].isdigit()] | ||
if self.gen in gens: | ||
allowed_moves.append(move) | ||
return allowed_moves | ||
|
||
def make_pokemon(self, dexnum=None, generate_moveset=False, **kwargs): | ||
""" | ||
kwargs are passed to the TeambuilderPokemon constructor and can include: | ||
- nickname | ||
- item | ||
- ability | ||
- moves | ||
- nature | ||
- evs | ||
- ivs | ||
- level | ||
- happiness | ||
- hiddenpowertype | ||
- gmax | ||
""" | ||
if dexnum < 1: | ||
msg = "Dex number must be greater than 0" | ||
raise ValueError(msg) | ||
if dexnum is None: | ||
dexnum = np.random.choice(list(self.dex2mon.keys())) | ||
if generate_moveset or "moves" not in kwargs.keys(): | ||
poss_moves = self.get_allowed_moves(dexnum) | ||
moves = ( | ||
np.random.choice(poss_moves, 4, replace=False) | ||
if len(poss_moves) > 3 | ||
else poss_moves | ||
) | ||
elif "moves" in kwargs: | ||
return TeambuilderPokemon(species=self.dex2mon[dexnum], **kwargs) | ||
return TeambuilderPokemon(species=self.dex2mon[dexnum], moves=moves, **kwargs) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,84 @@ | ||
from __future__ import annotations | ||
|
||
import numpy as np | ||
|
||
from p2lab.pokemon import poke_factory | ||
|
||
|
||
def test_pokedex(): | ||
p = poke_factory.PokeFactory() | ||
assert p is not None | ||
|
||
|
||
def test_eevee_fetch(): | ||
p = poke_factory.PokeFactory() | ||
eevee = p.get_pokemon_by_dexnum(133) | ||
assert eevee["baseSpecies"].lower() == "eevee" | ||
|
||
|
||
def test_bulbasaur_fetch(): | ||
p = poke_factory.PokeFactory() | ||
bulb = p.get_pokemon_by_dexnum(1) | ||
assert bulb["baseSpecies"].lower() == "bulbasaur" | ||
|
||
|
||
def test_eevee_moves(): | ||
p = poke_factory.PokeFactory() | ||
eevee_moves = p.get_allowed_moves(133) | ||
assert len(eevee_moves) > 0 | ||
|
||
|
||
def test_bulbasaur_moves(): | ||
p = poke_factory.PokeFactory() | ||
bulb_moves = p.get_allowed_moves(1) | ||
assert len(bulb_moves) > 0 | ||
|
||
|
||
def test_eevee_is_created(): | ||
p = poke_factory.PokeFactory() | ||
eevee = p.make_pokemon(133) | ||
assert eevee is not None | ||
|
||
|
||
def test_eevee_is_created_with_moves(): | ||
p = poke_factory.PokeFactory() | ||
eevee = p.make_pokemon(133, moves=["tackle", "growl"]) | ||
assert eevee is not None | ||
|
||
|
||
def test_random_pokemon_is_created_with_moves(): | ||
p = poke_factory.PokeFactory() | ||
dexnum = np.random.randint(1, 151) | ||
while dexnum == 132: | ||
dexnum = np.random.randint(1, 151) | ||
poke = p.make_pokemon(dexnum=dexnum, generate_moveset=True) | ||
assert len(poke.moves) == 4 | ||
|
||
|
||
def test_ditto_is_created_with_moves(): | ||
p = poke_factory.PokeFactory() | ||
ditto = p.make_pokemon(132) | ||
assert len(ditto.moves) == 1 | ||
|
||
|
||
def test_all_gen1_pokemon_can_be_created(): | ||
p = poke_factory.PokeFactory() | ||
for dexnum in range(1, 152): | ||
poke = p.make_pokemon(dexnum=dexnum, generate_moveset=True) | ||
assert len(poke.moves) > 0 | ||
|
||
|
||
def test_invalid_dex_raised(): | ||
p = poke_factory.PokeFactory() | ||
try: | ||
p.make_pokemon(dexnum=0) | ||
except ValueError: | ||
assert True | ||
else: | ||
raise AssertionError() | ||
|
||
|
||
def test_adding_item_to_pokemon(): | ||
p = poke_factory.PokeFactory() | ||
ditto = p.make_pokemon(132, item="choice scarf") | ||
assert ditto.item == "choice scarf" |