Skip to content

Commit

Permalink
update to gymnasium 1.0 seeding
Browse files Browse the repository at this point in the history
  • Loading branch information
jjshoots committed Oct 30, 2024
1 parent 361059b commit 5241ffc
Show file tree
Hide file tree
Showing 19 changed files with 63 additions and 52 deletions.
10 changes: 5 additions & 5 deletions PyFlyt/core/abstractions/base_drone.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ class DroneClass(ABC):
physics_hz (int): an integer representing the physics looprate of the `Aviary`.
drone_model (str): name of the drone itself, must be the same name as the folder where the URDF and YAML files are located.
model_dir (None | str = None): directory where the drone model folder is located, if none is provided, defaults to the directory of the default drones.
np_random (None | np.random.RandomState = None): random number generator of the simulation.
np_random (None | np.random.Generator = None): random number generator of the simulation.
Example Implementation:
>>> def __init__(
Expand All @@ -38,7 +38,7 @@ class DroneClass(ABC):
>>> physics_hz: int = 240,
>>> drone_model: str = "rocket_brick",
>>> model_dir: None | str = os.path.dirname(os.path.realpath(__file__)),
>>> np_random: None | np.random.RandomState = None,
>>> np_random: None | np.random.Generator = None,
>>> use_camera: bool = False,
>>> use_gimbal: bool = False,
>>> camera_angle_degrees: int = 0,
Expand Down Expand Up @@ -77,7 +77,7 @@ def __init__(
physics_hz: int,
drone_model: str,
model_dir: None | str = None,
np_random: None | np.random.RandomState = None,
np_random: None | np.random.Generator = None,
):
"""Defines the default configuration for UAVs, to be used in conjunction with the Aviary class.
Expand All @@ -89,7 +89,7 @@ def __init__(
physics_hz (int): an integer representing the physics looprate of the `Aviary`.
drone_model (str): name of the drone itself, must be the same name as the folder where the URDF and YAML files are located.
model_dir (None | str = None): directory where the drone model folder is located, if none is provided, defaults to the directory of the default drones.
np_random (None | np.random.RandomState = None): random number generator of the simulation.
np_random (None | np.random.Generator = None): random number generator of the simulation.
"""
if physics_hz % control_hz != 0:
Expand All @@ -98,7 +98,7 @@ def __init__(
)

self.p = p
self.np_random = np.random.RandomState() if np_random is None else np_random
self.np_random = np.random.default_rng() if np_random is None else np_random
self.physics_control_ratio = int(physics_hz / control_hz)
self.physics_period = 1.0 / physics_hz
self.control_period = 1.0 / control_hz
Expand Down
6 changes: 3 additions & 3 deletions PyFlyt/core/abstractions/base_wind_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ class WindFieldClass(ABC):
>>>
>>> # define the wind field
>>> class MyWindField(WindFieldClass):
>>> def __init__(self, my_parameter=1.0, np_random: None | np.random.RandomState = None):
>>> def __init__(self, my_parameter=1.0, np_random: None | np.random.Generator = None):
>>> super().__init__(np_random)
>>> self.strength = my_parameter
>>>
Expand All @@ -39,9 +39,9 @@ class WindFieldClass(ABC):
>>> ...
"""

def __init__(self, np_random: None | np.random.RandomState = None):
def __init__(self, np_random: None | np.random.Generator = None):
"""Initializes the wind_field."""
self.np_random = np.random.RandomState() if np_random is None else np_random
self.np_random = np.random.default_rng() if np_random is None else np_random

@abstractmethod
def __call__(self, time: float, position: np.ndarray) -> np.ndarray:
Expand Down
8 changes: 4 additions & 4 deletions PyFlyt/core/abstractions/boosters.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ class Boosters:
Args:
p (bullet_client.BulletClient): PyBullet physics client ID.
physics_period (float): physics period of the simulation.
np_random (np.random.RandomState): random number generator of the simulation.
np_random (np.random.Generator): random number generator of the simulation.
uav_id (int): ID of the drone.
booster_ids (np.ndarray | list[int]): list of integers representing the link index that each booster should be attached to.
fueltank_ids (np.ndarray | list[None | int]): list of integers representing the link index for the fuel tank that each booster is attached to.
Expand All @@ -39,7 +39,7 @@ def __init__(
self,
p: bullet_client.BulletClient,
physics_period: float,
np_random: np.random.RandomState,
np_random: np.random.Generator,
uav_id: int,
booster_ids: np.ndarray | list[int],
fueltank_ids: np.ndarray | list[None | int],
Expand All @@ -58,7 +58,7 @@ def __init__(
Args:
p (bullet_client.BulletClient): PyBullet physics client ID.
physics_period (float): physics period of the simulation.
np_random (np.random.RandomState): random number generator of the simulation.
np_random (np.random.Generator): random number generator of the simulation.
uav_id (int): ID of the drone.
booster_ids (np.ndarray | list[int]): list of integers representing the link index that each booster should be attached to.
fueltank_ids (np.ndarray | list[None | int]): list of integers representing the link index for the fuel tank that each booster is attached to.
Expand Down Expand Up @@ -239,7 +239,7 @@ def _compute_thrust_mass_inertia(

# noise in the motor
self.throttle += (
self.np_random.randn(*self.throttle.shape)
self.np_random.normal(*self.throttle.shape)
* self.throttle
* self.noise_ratio
)
Expand Down
6 changes: 3 additions & 3 deletions PyFlyt/core/abstractions/boring_bodies.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ class BoringBodies:
Args:
p (bullet_client.BulletClient): PyBullet physics client ID.
physics_period (float): physics period of the simulation.
np_random (np.random.RandomState): random number generator of the simulation.
np_random (np.random.Generator): random number generator of the simulation.
uav_id (int): ID of the drone.
body_ids (np.ndarray | Sequence[int]): (n,) array of IDs for the links representing the bodies.
drag_coefs (np.ndarray): (n, 3) array of drag coefficients for each body in the link-referenced XYZ directions.
Expand All @@ -28,7 +28,7 @@ def __init__(
self,
p: bullet_client.BulletClient,
physics_period: float,
np_random: np.random.RandomState,
np_random: np.random.Generator,
uav_id: int,
body_ids: np.ndarray | Sequence[int],
drag_coefs: np.ndarray,
Expand All @@ -39,7 +39,7 @@ def __init__(
Args:
p (bullet_client.BulletClient): PyBullet physics client ID.
physics_period (float): physics period of the simulation.
np_random (np.random.RandomState): random number generator of the simulation.
np_random (np.random.Generator): random number generator of the simulation.
uav_id (int): ID of the drone.
body_ids (np.ndarray | Sequence[int]): (n,) array of IDs for the links representing the bodies.
drag_coefs (np.ndarray): (n, 3) array of drag coefficients for each body in the link-referenced XYZ directions.
Expand Down
6 changes: 3 additions & 3 deletions PyFlyt/core/abstractions/gimbals.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ class Gimbals:
Args:
p (bullet_client.BulletClient): PyBullet physics client ID.
physics_period (float): physics period of the simulation.
np_random (np.random.RandomState): random number generator of the simulation.
np_random (np.random.Generator): random number generator of the simulation.
gimbal_unit_1 (np.ndarray): first unit vector that the gimbal rotates around.
gimbal_unit_2 (np.ndarray): second unit vector that the gimbal rotates around.
gimbal_tau (np.ndarray): gimbal actuation time constant.
Expand All @@ -31,7 +31,7 @@ def __init__(
self,
p: bullet_client.BulletClient,
physics_period: float,
np_random: np.random.RandomState,
np_random: np.random.Generator,
gimbal_unit_1: np.ndarray,
gimbal_unit_2: np.ndarray,
gimbal_tau: np.ndarray,
Expand All @@ -42,7 +42,7 @@ def __init__(
Args:
p (bullet_client.BulletClient): PyBullet physics client ID.
physics_period (float): physics period of the simulation.
np_random (np.random.RandomState): random number generator of the simulation.
np_random (np.random.Generator): random number generator of the simulation.
gimbal_unit_1 (np.ndarray): first unit vector that the gimbal rotates around.
gimbal_unit_2 (np.ndarray): second unit vector that the gimbal rotates around.
gimbal_tau (np.ndarray): gimbal actuation time constant.
Expand Down
6 changes: 3 additions & 3 deletions PyFlyt/core/abstractions/lifting_surfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ class LiftingSurface:
Args:
p (bullet_client.BulletClient): PyBullet physics client ID.
physics_period (float): physics period of the simulation.
np_random (np.random.RandomState): random number generator of the simulation.
np_random (np.random.Generator): random number generator of the simulation.
uav_id (int): ID of the drone.
surface_id (int): an integer for the link ID for this lifting surface.
lifting_unit (np.ndarray): (3,) unit vector representing the direction of lift.
Expand All @@ -141,7 +141,7 @@ def __init__(
self,
p: bullet_client.BulletClient,
physics_period: float,
np_random: np.random.RandomState,
np_random: np.random.Generator,
uav_id: int,
surface_id: int,
lifting_unit: np.ndarray,
Expand All @@ -163,7 +163,7 @@ def __init__(
Args:
p (bullet_client.BulletClient): PyBullet physics client ID.
physics_period (float): physics period of the simulation.
np_random (np.random.RandomState): random number generator of the simulation.
np_random (np.random.Generator): random number generator of the simulation.
uav_id (int): ID of the drone.
surface_id (int): an integer for the link ID for this lifting surface.
lifting_unit (np.ndarray): (3,) unit vector representing the direction of lift.
Expand Down
8 changes: 4 additions & 4 deletions PyFlyt/core/abstractions/motors.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ class Motors:
Args:
p (bullet_client.BulletClient): PyBullet physics client ID.
physics_period (float): physics period of the simulation.
np_random (np.random.RandomState): random number generator of the simulation.
np_random (np.random.Generator): random number generator of the simulation.
uav_id (int): ID of the drone.
motor_ids (list[int]): a (n,) list of integers representing the link IDs for n motors.
tau (np.ndarray): an (n,) of floats array representing the ramp time constant of each motor.
Expand All @@ -36,7 +36,7 @@ def __init__(
self,
p: bullet_client.BulletClient,
physics_period: float,
np_random: np.random.RandomState,
np_random: np.random.Generator,
uav_id: np.ndarray | int,
motor_ids: np.ndarray | list[int],
tau: np.ndarray,
Expand All @@ -51,7 +51,7 @@ def __init__(
Args:
p (bullet_client.BulletClient): PyBullet physics client ID.
physics_period (float): physics period of the simulation.
np_random (np.random.RandomState): random number generator of the simulation.
np_random (np.random.Generator): random number generator of the simulation.
uav_id (int): ID of the drone.
motor_ids (list[int]): a (n,) list of integers representing the link IDs for n motors.
tau (np.ndarray): an (n,) of floats array representing the ramp time constant of each motor.
Expand Down Expand Up @@ -132,7 +132,7 @@ def physics_update(

# noise in the motor
self.throttle += (
self.np_random.randn(*self.throttle.shape)
self.np_random.normal(*self.throttle.shape)
* self.throttle
* self.noise_ratio
)
Expand Down
24 changes: 18 additions & 6 deletions PyFlyt/core/aviary.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ class Aviary(bullet_client.BulletClient):
physics_hz (int): physics looprate (not recommended to be changed).
world_scale (float): how big to spawn the floor.
seed (None | int): optional int for seeding the simulation RNG.
np_random (None | np.random.Generator): a numpy random number generator to be used for RNG.
"""

Expand All @@ -78,6 +79,7 @@ def __init__(
physics_hz: int = 240,
world_scale: float = 1.0,
seed: None | int = None,
np_random: None | np.random.Generator = None,
):
"""Initializes a PyBullet environment that hosts UAVs and other entities.
Expand All @@ -96,11 +98,24 @@ def __init__(
physics_hz (int): physics looprate (not recommended to be changed).
world_scale (float): how big to spawn the floor.
seed (None | int): optional int for seeding the simulation RNG.
np_random (None | np.random.Generator): a numpy random number generator to be used for RNG.
"""
super().__init__(p.GUI if render else p.DIRECT)
print("\033[A \033[A")

# set random state
if seed and np_random:
raise AviaryInitException(
"Cannot set both `seed` and `np_random` arguments together."
)
elif seed and not np_random:
self.np_random = np.random.default_rng(seed)
elif not seed and np_random:
self.np_random = np_random
else:
self.np_random = np.random.default_rng()

# check for starting position and orientation shapes
if len(start_pos.shape) != 2:
raise AviaryInitException(
Expand Down Expand Up @@ -197,9 +212,10 @@ def __init__(
text="RTF here", textPosition=[0, 0, 0], textColorRGB=[1, 0, 0]
)

self.reset(seed)
# initialize the environment
self.reset()

def reset(self, seed: None | int = None) -> None:
def reset(self) -> None:
"""Resets the simulation.
Args:
Expand All @@ -220,9 +236,6 @@ def reset(self, seed: None | int = None) -> None:
cameraTargetPosition=[0, 0, 1],
)

# define new RNG
self.np_random = np.random.RandomState(seed=seed)

# construct the world
self.planeId = self.loadURDF(
"plane.urdf", useFixedBase=True, globalScaling=self.world_scale
Expand Down Expand Up @@ -303,7 +316,6 @@ def register_all_new_bodies(self) -> None:
Call this when there is an update in the number of bodies in the environment.
"""
# the collision array is a scipy sparse, upper triangle array
num_bodies = (
np.max([self.getBodyUniqueId(i) for i in range(self.getNumBodies())]) + 1
)
Expand Down
4 changes: 2 additions & 2 deletions PyFlyt/core/drones/fixedwing.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def __init__(
physics_hz: int = 240,
drone_model: str = "fixedwing",
model_dir: None | str = None,
np_random: None | np.random.RandomState = None,
np_random: None | np.random.Generator = None,
use_camera: bool = False,
use_gimbal: bool = False,
camera_angle_degrees: int = 0,
Expand All @@ -44,7 +44,7 @@ def __init__(
physics_hz (int): physics_hz
drone_model (str): drone_model
model_dir (None | str): model_dir
np_random (None | np.random.RandomState): np_random
np_random (None | np.random.Generator): np_random
use_camera (bool): use_camera
use_gimbal (bool): use_gimbal
camera_angle_degrees (int): camera_angle_degrees
Expand Down
4 changes: 2 additions & 2 deletions PyFlyt/core/drones/quadx.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def __init__(
physics_hz: int = 240,
drone_model: str = "cf2x",
model_dir: None | str = None,
np_random: None | np.random.RandomState = None,
np_random: None | np.random.Generator = None,
use_camera: bool = False,
use_gimbal: bool = False,
camera_angle_degrees: int = 20,
Expand All @@ -47,7 +47,7 @@ def __init__(
physics_hz (int): physics_hz
drone_model (str): drone_model
model_dir (None | str): model_dir
np_random (None | np.random.RandomState): np_random
np_random (None | np.random.Generator): np_random
use_camera (bool): use_camera
use_gimbal (bool): use_gimbal
camera_angle_degrees (int): camera_angle_degrees
Expand Down
4 changes: 2 additions & 2 deletions PyFlyt/core/drones/rocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def __init__(
physics_hz: int = 240,
drone_model: str = "rocket",
model_dir: None | str = None,
np_random: None | np.random.RandomState = None,
np_random: None | np.random.Generator = None,
use_camera: bool = False,
use_gimbal: bool = False,
camera_angle_degrees: int = 30,
Expand All @@ -56,7 +56,7 @@ def __init__(
control_hz (int): control_hz
drone_model (str): drone_model
model_dir (None | str): model_dir
np_random (None | np.random.RandomState): np_random
np_random (None | np.random.Generator): np_random
use_camera (bool): use_camera
use_gimbal (bool): use_gimbal
camera_angle_degrees (int): camera_angle_degrees
Expand Down
1 change: 1 addition & 0 deletions PyFlyt/core/utils/compile_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ def jitter(func: Callable, **kwargs):

def check_numpy():
"""Checks that numpy is installed."""
return
if not p.isNumpyEnabled():
warnings.warn(
colorize(
Expand Down
4 changes: 2 additions & 2 deletions PyFlyt/gym_envs/fixedwing_envs/fixedwing_base_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ def begin_reset(
drone_type="fixedwing",
render=self.render_mode == "human",
drone_options=drone_options,
seed=seed,
np_random=self.np_random,
)

if self.render_mode == "human":
Expand Down Expand Up @@ -292,7 +292,7 @@ def render(self) -> np.ndarray:
projectionMatrix=self.env.drones[0].camera.proj_mat,
)

rgbaImg = np.asarray(rgbaImg).reshape(
rgbaImg = np.asarray(rgbaImg, dtype=np.uint8).reshape(
self.render_resolution[0], self.render_resolution[1], -1
)

Expand Down
4 changes: 2 additions & 2 deletions PyFlyt/gym_envs/quadx_envs/quadx_base_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ def begin_reset(
drone_type="quadx",
render=self.render_mode == "human",
drone_options=drone_options,
seed=seed,
np_random=self.np_random,
)

if self.render_mode == "human":
Expand Down Expand Up @@ -327,7 +327,7 @@ def render(self) -> np.ndarray:
f"Unknown render mode {self.render_mode}, should not have ended up here"
)

rgbaImg = np.asarray(rgbaImg).reshape(
rgbaImg = np.asarray(rgbaImg, dtype=np.uint8).reshape(
self.render_resolution[0], self.render_resolution[1], -1
)

Expand Down
Loading

0 comments on commit 5241ffc

Please sign in to comment.