diff --git a/alfredo/agents/A1/__init__.py b/alfredo/agents/A1/__init__.py deleted file mode 100644 index 3d6835f..0000000 --- a/alfredo/agents/A1/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .alfredo_1 import * diff --git a/alfredo/agents/A1/a1.xml b/alfredo/agents/A1/a1.xml deleted file mode 100644 index 3614f37..0000000 --- a/alfredo/agents/A1/a1.xml +++ /dev/null @@ -1,102 +0,0 @@ - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - diff --git a/alfredo/agents/A1/alfredo_1.py b/alfredo/agents/A1/alfredo_1.py deleted file mode 100644 index 1d9efb9..0000000 --- a/alfredo/agents/A1/alfredo_1.py +++ /dev/null @@ -1,244 +0,0 @@ -# pylint:disable=g-multiple-import -"""Trains Alfredo to run/walk/move in the +x direction.""" -from typing import Tuple - -import jax -from brax import actuator, base, math -from brax.envs import PipelineEnv, State -from brax.io import mjcf -from etils import epath -from jax import numpy as jp - -from alfredo.tools import compose_scene -from alfredo.rewards import rConstant -from alfredo.rewards import rHealthy_simple_z -from alfredo.rewards import rSpeed_X -from alfredo.rewards import rControl_act_ss - -class Alfredo(PipelineEnv): - # pyformat: disable - """ """ - # pyformat: enable - - def __init__( - self, - forward_reward_weight=1.25, - ctrl_cost_weight=0.1, - healthy_reward=5.0, - terminate_when_unhealthy=True, - healthy_z_range=(1.0, 2.0), - reset_noise_scale=1e-2, - exclude_current_positions_from_observation=True, - backend="generalized", - **kwargs, - ): - - # forcing this model to need an input scene_xml_path or - # the combination of env_xml_path and agent_xml_path - # if none of these options are present, an error will be thrown - path="" - - if "env_xml_path" and "agent_xml_path" in kwargs: - env_xp = kwargs["env_xml_path"] - agent_xp = kwargs["agent_xml_path"] - xml_scene = compose_scene(env_xp, agent_xp) - del kwargs["env_xml_path"] - del kwargs["agent_xml_path"] - - sys = mjcf.loads(xml_scene) - - # this is vestigial - get rid of this someday soon - if "scene_xml_path" in kwargs: - path = kwargs["scene_xml_path"] - del kwargs["scene_xml_path"] - - sys = mjcf.load(path) - - n_frames = 5 - - if backend in ["spring", "positional"]: - sys = sys.replace(dt=0.0015) - n_frames = 10 - gear = jp.array( - [ - 150.0, - 150.0, - 150.0, - 150.0, - 150.0, - 150.0, - 150.0, - 150.0, - 150.0, - 150.0, - 150.0, - 150.0, - 150.0, - 100.0, - 100.0, - 100.0, - 100.0, - 100.0, - 100.0, - ] - ) # pyformat: disable - sys = sys.replace(actuator=sys.actuator.replace(gear=gear)) - - kwargs["n_frames"] = kwargs.get("n_frames", n_frames) - - super().__init__(sys=sys, backend=backend, **kwargs) - - self._forward_reward_weight = forward_reward_weight - self._ctrl_cost_weight = ctrl_cost_weight - self._healthy_reward = healthy_reward - - self._terminate_when_unhealthy = terminate_when_unhealthy - self._healthy_z_range = healthy_z_range - self._reset_noise_scale = reset_noise_scale - - self._exclude_current_positions_from_observation = ( - exclude_current_positions_from_observation - ) - - def reset(self, rng: jp.ndarray) -> State: - """Resets the environment to an initial state.""" - rng, rng1, rng2 = jax.random.split(rng, 3) - - low, hi = -self._reset_noise_scale, self._reset_noise_scale - - qpos = self.sys.init_q + jax.random.uniform( - rng1, (self.sys.q_size(),), minval=low, maxval=hi - ) - qvel = jax.random.uniform(rng2, 
(self.sys.qd_size(),), minval=low, maxval=hi) - - pipeline_state = self.pipeline_init(qpos, qvel) - - obs = self._get_obs(pipeline_state, jp.zeros(self.sys.act_size())) - - reward, done, zero = jp.zeros(3) - metrics = { - "reward_ctrl": zero, - "reward_alive": zero, - "reward_velocity": zero, - "agent_x_position": zero, - "agent_y_position": zero, - "agent_x_velocity": zero, - "agent_y_velocity": zero, - } - - return State(pipeline_state, obs, reward, done, metrics) - - def step(self, state: State, action: jp.ndarray) -> State: - """Runs one timestep of the environment's dynamics.""" - prev_pipeline_state = state.pipeline_state - pipeline_state = self.pipeline_step(prev_pipeline_state, action) - obs = self._get_obs(pipeline_state, action) - - com_before, *_ = self._com(prev_pipeline_state) - com_after, *_ = self._com(pipeline_state) - - x_speed_reward = rSpeed_X(self.sys, - state.pipeline_state, - CoM_prev=com_before, - CoM_now=com_after, - dt=self.dt, - weight=self._forward_reward_weight) - - ctrl_cost = rControl_act_ss(self.sys, - state.pipeline_state, - action, - weight=-self._ctrl_cost_weight) - - healthy_reward = rHealthy_simple_z(self.sys, - state.pipeline_state, - self._healthy_z_range, - early_terminate=self._terminate_when_unhealthy, - weight=self._healthy_reward, - focus_idx_range=(0, 2)) - - reward = healthy_reward[0] + ctrl_cost + x_speed_reward[0] - - done = 1.0 - healthy_reward[1] if self._terminate_when_unhealthy else 0.0 - - state.metrics.update( - reward_ctrl=ctrl_cost, - reward_alive=healthy_reward[0], - reward_velocity=x_speed_reward[0], - agent_x_position=com_after[0], - agent_y_position=com_after[1], - agent_x_velocity=x_speed_reward[1], - agent_y_velocity=x_speed_reward[2], - ) - - return state.replace( - pipeline_state=pipeline_state, obs=obs, reward=reward, done=done - ) - - def _get_obs(self, pipeline_state: base.State, action: jp.ndarray) -> jp.ndarray: - """Observes Alfredo's body position, velocities, and angles.""" - - a_positions = pipeline_state.q - a_velocities = pipeline_state.qd - #print(f"a_positions = {a_positions}") - #print(f"a_velocities = {a_velocities}") - - if self._exclude_current_positions_from_observation: - a_positions = a_positions[2:] - - com, inertia, mass_sum, x_i = self._com(pipeline_state) - cinr = x_i.replace(pos=x_i.pos - com).vmap().do(inertia) - com_inertia = jp.hstack( - [cinr.i.reshape((cinr.i.shape[0], -1)), inertia.mass[:, None]] - ) - - xd_i = ( - base.Transform.create(pos=x_i.pos - pipeline_state.x.pos) - .vmap() - .do(pipeline_state.xd) - ) - - com_vel = inertia.mass[:, None] * xd_i.vel / mass_sum - com_ang = xd_i.ang - com_velocity = jp.hstack([com_vel, com_ang]) - - qfrc_actuator = actuator.to_tau( - self.sys, action, pipeline_state.q, pipeline_state.qd - ) - - # external_contact_forces are excluded - return jp.concatenate( - [ - a_positions, - a_velocities, - com_inertia.ravel(), - com_velocity.ravel(), - qfrc_actuator, - ] - ) - - def _com(self, pipeline_state: base.State) -> jp.ndarray: - """Computes Center of Mass of Alfredo""" - - inertia = self.sys.link.inertia - - if self.backend in ["spring", "positional"]: - inertia = inertia.replace( - i=jax.vmap(jp.diag)( - jax.vmap(jp.diagonal)(inertia.i) - ** (1 - self.sys.spring_inertia_scale) - ), - mass=inertia.mass ** (1 - self.sys.spring_mass_scale), - ) - - mass_sum = jp.sum(inertia.mass) - x_i = pipeline_state.x.vmap().do(inertia.transform) - - com = jp.sum(jax.vmap(jp.multiply)(inertia.mass, x_i.pos), axis=0) / mass_sum - - return ( - com, - inertia, - mass_sum, - x_i, - 
) # pytype: disable=bad-return-type # jax-ndarray - diff --git a/alfredo/agents/__init__.py b/alfredo/agents/__init__.py index 88a07ac..53cd490 100644 --- a/alfredo/agents/__init__.py +++ b/alfredo/agents/__init__.py @@ -1 +1 @@ -from . import A1 +from . import aant diff --git a/alfredo/agents/aant/__init__.py b/alfredo/agents/aant/__init__.py new file mode 100644 index 0000000..145872e --- /dev/null +++ b/alfredo/agents/aant/__init__.py @@ -0,0 +1 @@ +from .aant import * diff --git a/alfredo/agents/aant/aant.py b/alfredo/agents/aant/aant.py new file mode 100644 index 0000000..83daa0c --- /dev/null +++ b/alfredo/agents/aant/aant.py @@ -0,0 +1,224 @@ +from brax import base +from brax import math +from brax.envs.base import PipelineEnv, State +from brax.io import mjcf +from etils import epath +import jax +from jax import numpy as jp + +from alfredo.tools import compose_scene + +class AAnt(PipelineEnv): + """ """ + + def __init__(self, + rewards = {}, + env_xml_path = "", + agent_xml_path = "", + terminate_when_unhealthy=False, + reset_noise_scale=0.1, + exclude_current_positions_from_observation=False, + backend='generalized', + **kwargs,): + + # env_xml_path and agent_xml_path must be provided + if env_xml_path and agent_xml_path: + self._env_xml_path = env_xml_path + self._agent_xml_path = agent_xml_path + + xml_scene = compose_scene(self._env_xml_path, self._agent_xml_path) + sys = mjcf.loads(xml_scene) + else: + raise Exception("env_xml_path & agent_xml_path both must be provided") + + # reward dictionary must be provided + if rewards: + self._rewards = rewards + else: + raise Exception("reward_Structure must be in kwargs") + + # TODO: clean this up in the future & + # make n_frames a function of input dt + n_frames = 5 + + if backend in ['spring', 'positional']: + sys = sys.replace(dt=0.005) + n_frames = 10 + + if backend == 'positional': + # TODO: does the same actuator strength work as in spring + sys = sys.replace( + actuator=sys.actuator.replace( + gear=200 * jp.ones_like(sys.actuator.gear) + ) + ) + + kwargs['n_frames'] = kwargs.get('n_frames', n_frames) + + # Initialize the superclass "PipelineEnv" + super().__init__(sys=sys, backend=backend, **kwargs) + + # Setting other object parameters based on input params + self._terminate_when_unhealthy = terminate_when_unhealthy + self._reset_noise_scale = reset_noise_scale + self._exclude_current_positions_from_observation = ( + exclude_current_positions_from_observation + ) + + + def reset(self, rng: jax.Array) -> State: + + rng, rng1, rng2, rng3 = jax.random.split(rng, 4) + low, hi = -self._reset_noise_scale, self._reset_noise_scale + + # initialize position vector with minor randomization in pose + q = self.sys.init_q + jax.random.uniform( + rng1, (self.sys.q_size(),), minval=low, maxval=hi + ) + + # initialize velocity vector with minor randomization + qd = hi * jax.random.normal(rng2, (self.sys.qd_size(),)) + + # generate sample commands + jcmd = self._sample_command(rng3) + wcmd = jp.array([0.0, 0.0]) + + # initialize pipeline_state (the physics state) + pipeline_state = self.pipeline_init(q, qd) + + # reset values and metrics + reward, done, zero = jp.zeros(3) + + state_info = { + 'jcmd':jcmd, + 'wcmd':wcmd, + 'rewards': {k: 0.0 for k in self._rewards.keys()}, + 'step': 0, + } + + metrics = {'pos_x_world_abs': zero, + 'pos_y_world_abs': zero, + 'pos_z_world_abs': zero,} + + for rn, r in self._rewards.items(): + metrics[rn] = state_info['rewards'][rn] + + # get initial observation vector + obs = self._get_obs(pipeline_state, 
state_info) + + return State(pipeline_state, obs, reward, done, metrics, state_info) + + def step(self, state: State, action: jax.Array) -> State: + """Run one timestep of the environment's dynamics.""" + + # Save the previous physics state and step physics forward + pipeline_state0 = state.pipeline_state + pipeline_state = self.pipeline_step(pipeline_state0, action) + + # Add all additional parameters to compute rewards + self._rewards['r_lin_vel'].add_param('jcmd', state.info['jcmd']) + self._rewards['r_yaw_vel'].add_param('jcmd', state.info['jcmd']) + + # Compute all rewards and accumulate total reward + total_reward = 0.0 + for rn, r in self._rewards.items(): + r.add_param('sys', self.sys) + r.add_param('pipeline_state', pipeline_state) + + reward_value = r.compute() + state.info['rewards'][rn] = reward_value[0] + total_reward += reward_value[0] + # print(f'{rn} reward_val = {reward_value}\n') + + # Computing additional metrics as necessary + pos_world = pipeline_state.x.pos[0] + abs_pos_world = jp.abs(pos_world) + + # Compute observations + obs = self._get_obs(pipeline_state, state.info) + done = 0.0 + + # State management + state.info['step'] += 1 + + state.metrics.update(state.info['rewards']) + + state.metrics.update( + pos_x_world_abs = abs_pos_world[0], + pos_y_world_abs = abs_pos_world[1], + pos_z_world_abs = abs_pos_world[2], + ) + + return state.replace( + pipeline_state=pipeline_state, obs=obs, reward=total_reward, done=done + ) + + def _get_obs(self, pipeline_state, state_info) -> jax.Array: + """Observe ant body position and velocities.""" + qpos = pipeline_state.q + qvel = pipeline_state.qd + + inv_torso_rot = math.quat_inv(pipeline_state.x.rot[0]) + local_rpyrate = math.rotate(pipeline_state.xd.ang[0], inv_torso_rot) + torso_pos = pipeline_state.x.pos[0] + + jcmd = state_info['jcmd'] + #wcmd = state_info['wcmd'] + + if self._exclude_current_positions_from_observation: + qpos = pipeline_state.q[2:] + + obs = jp.concatenate([ + jp.array(qpos), + jp.array(qvel), + jp.array(local_rpyrate), + jp.array(jcmd), + ]) + + return obs + + def _sample_waypoint(self, rng: jax.Array) -> jax.Array: + x_range = [-25, 25] + y_range = [-25, 25] + z_range = [0, 2] + + _, key1, key2, key3 = jax.random.split(rng, 4) + + x = jax.random.uniform( + key1, (1,), minval=x_range[0], maxval=x_range[1] + ) + + y = jax.random.uniform( + key2, (1,), minval=y_range[0], maxval=y_range[1] + ) + + z = jax.random.uniform( + key3, (1,), minval=z_range[0], maxval=z_range[1] + ) + + wcmd = jp.array([x[0], y[0]]) + + return wcmd + + def _sample_command(self, rng: jax.Array) -> jax.Array: + lin_vel_x_range = [-3.0, 3.0] #[m/s] + lin_vel_y_range = [-3.0, 3.0] #[m/s] + yaw_vel_range = [-1.0, 1.0] #[rad/s] + + _, key1, key2, key3 = jax.random.split(rng, 4) + + lin_vel_x = jax.random.uniform( + key1, (1,), minval=lin_vel_x_range[0], maxval=lin_vel_x_range[1] + ) + + lin_vel_y = jax.random.uniform( + key2, (1,), minval=lin_vel_y_range[0], maxval=lin_vel_y_range[1] + ) + + yaw_vel = jax.random.uniform( + key3, (1,), minval=yaw_vel_range[0], maxval=yaw_vel_range[1] + ) + + jcmd = jp.array([lin_vel_x[0], lin_vel_y[0], yaw_vel[0]]) + + return jcmd diff --git a/alfredo/agents/aant/aant.xml b/alfredo/agents/aant/aant.xml new file mode 100644 index 0000000..43a5fc3 --- /dev/null +++ b/alfredo/agents/aant/aant.xml @@ -0,0 +1,94 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git 
a/alfredo/rewards/__init__.py b/alfredo/rewards/__init__.py
index d5908fa..d8e5d93 100644
--- a/alfredo/rewards/__init__.py
+++ b/alfredo/rewards/__init__.py
@@ -1,4 +1,7 @@
+from .reward import Reward
+
 from .rConstant import *
-from .rSpeed import *
 from .rHealthy import *
 from .rControl import *
+from .rEnergy import *
+from .rOrientation import *
diff --git a/alfredo/rewards/rConstant.py b/alfredo/rewards/rConstant.py
index 18d4f7b..74f6c06 100644
--- a/alfredo/rewards/rConstant.py
+++ b/alfredo/rewards/rConstant.py
@@ -8,8 +8,7 @@ from jax import numpy as jp
 
 def rConstant(sys: base.System,
-              pipeline_state: base.State,
-              weight=1.0,
-              focus_idx_range=(1, -1)) -> jp.ndarray:
+              pipeline_state: base.State,
+              focus_idx_range=(0, -1)) -> jax.Array:
 
-    return jp.array([weight])
+    return jp.array([1.0])
 
diff --git a/alfredo/rewards/rControl.py b/alfredo/rewards/rControl.py
index 7e10532..6e332a1 100644
--- a/alfredo/rewards/rControl.py
+++ b/alfredo/rewards/rControl.py
@@ -10,9 +10,59 @@ def rControl_act_ss(sys: base.System,
                     pipeline_state: base.State,
                     action: jp.ndarray,
-                    weight=1.0,
-                    focus_idx_range=(1, -1)) -> jp.ndarray:
+                    focus_idx_range=(0, -1)) -> jax.Array:
 
-    ctrl_cost = weight * jp.sum(jp.square(action))
+    ctrl_cost = jp.sum(jp.square(action))
 
-    return ctrl_cost
+    return jp.array([ctrl_cost])
+
+def rTracking_lin_vel(sys: base.System,
+                      pipeline_state: base.State,
+                      jcmd: jax.Array,
+                      sigma=0.25,
+                      focus_idx_range=(0, 1)) -> jax.Array:
+
+    local_vel = math.rotate(pipeline_state.xd.vel[0],
+                            math.quat_inv(pipeline_state.x.rot[0]))
+
+    lv_error = jp.sum(jp.square(jcmd[:2] - local_vel[:2]))  # only the x, y velocities are tracked
+    lv_reward = jp.exp(-lv_error/sigma)
+
+    return jp.array([lv_reward])
+
+def rTracking_yaw_vel(sys: base.System,
+                      pipeline_state: base.State,
+                      jcmd: jax.Array,
+                      sigma=0.25,
+                      focus_idx_range=(0, 1)) -> jax.Array:
+
+    local_ang_vel = math.rotate(pipeline_state.xd.ang[0],
+                                math.quat_inv(pipeline_state.x.rot[0]))
+
+    yaw_vel_error = jp.square(jcmd[2] - local_ang_vel[2])
+    yv_reward = jp.exp(-yaw_vel_error/sigma)
+
+    return jp.array([yv_reward])
+
+
+def rTracking_Waypoint(sys: base.System,
+                       pipeline_state: base.State,
+                       wcmd: jax.Array,
+                       focus_idx_range=0) -> jax.Array:
+
+    torso_pos = pipeline_state.x.pos[focus_idx_range]
+    pos_goal_diff = torso_pos[0:2] - wcmd[0:2]
+    pos_sum_abs_diff = -jp.sum(jp.abs(pos_goal_diff))
+
+    return jp.array([pos_sum_abs_diff])
+
+def rStand_still(sys: base.System,
+                 pipeline_state: base.State,
+                 jcmd: jax.Array,
+                 joint_angles: jax.Array,
+                 default_pose: jax.Array,
+                 focus_idx_range=0) -> jax.Array:
+
+    close_to_still = jp.sum(jp.abs(joint_angles - default_pose)) * math.normalize(jcmd[:2])[1] < 0.1
+
+    return jp.array([close_to_still])
diff --git a/alfredo/rewards/rEnergy.py b/alfredo/rewards/rEnergy.py
new file mode 100644
index 0000000..760d2d8
--- /dev/null
+++ b/alfredo/rewards/rEnergy.py
@@ -0,0 +1,27 @@
+from typing import Tuple
+
+import jax
+from brax import actuator, base, math
+from brax.envs import PipelineEnv, State
+from brax.io import mjcf
+from etils import epath
+from jax import numpy as jp
+
+
+def rTorques(sys: base.System,
+             pipeline_state: base.State,
+             action: jp.ndarray,
+             focus_idx_range=(0, -1)) -> jax.Array:
+
+    s_idx = focus_idx_range[0]
+    e_idx = focus_idx_range[1]
+
+    torque = actuator.to_tau(sys,
+                             action,
+                             pipeline_state.q[s_idx:e_idx],
+                             pipeline_state.qd[s_idx:e_idx])
+
+    tr = jp.sqrt(jp.sum(jp.square(torque))) + jp.sum(jp.abs(torque))
+
+    return jp.array([tr])
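Both tracking terms above push a squared error through the same exponential kernel: a command that is matched exactly earns a reward of 1.0, and the value decays smoothly toward 0 as the error grows, at a rate set by sigma. A minimal sketch of that kernel outside of Brax (the sigma=0.25 default is taken from the signatures above; the sample errors are made up):

```python
import jax.numpy as jp

def tracking_reward(squared_error, sigma=0.25):
    """Exponential tracking kernel used by rTracking_lin_vel and rTracking_yaw_vel:
    zero error gives 1.0, and the reward decays smoothly as the error grows."""
    return jp.exp(-squared_error / sigma)

print(tracking_reward(0.0))              # 1.0   (command matched exactly)
print(tracking_reward(0.5**2 + 0.5**2))  # ~0.14 (0.5 m/s error in both x and y)
```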
diff --git a/alfredo/rewards/rHealthy.py b/alfredo/rewards/rHealthy.py
index fc1f0be..2a696ca 100644
--- a/alfredo/rewards/rHealthy.py
+++ b/alfredo/rewards/rHealthy.py
@@ -11,9 +11,8 @@ def rHealthy_simple_z(sys: base.System,
                       pipeline_state: base.State,
                       z_range: Tuple,
                       early_terminate: True,
-                      weight=1.0,
-                      focus_idx_range=(1, -1)) -> jp.ndarray:
-
+                      focus_idx_range=(0, -1)) -> jax.Array:
+
     min_z, max_z = z_range
     focus_s = focus_idx_range[0]
     focus_e = focus_idx_range[-1]
@@ -24,8 +23,8 @@ def rHealthy_simple_z(sys: base.System,
     is_healthy = jp.where(focus_x_pos > max_z, x=0.0, y=is_healthy)
 
     if early_terminate:
-        hr = weight
+        hr = 1.0
     else:
-        hr = weight * is_healthy
+        hr = 1.0 * is_healthy
 
     return jp.array([hr, is_healthy])
diff --git a/alfredo/rewards/rOrientation.py b/alfredo/rewards/rOrientation.py
new file mode 100644
index 0000000..340c171
--- /dev/null
+++ b/alfredo/rewards/rOrientation.py
@@ -0,0 +1,17 @@
+from typing import Tuple
+
+import jax
+from brax import actuator, base, math
+from brax.envs import PipelineEnv, State
+from brax.io import mjcf
+from etils import epath
+from jax import numpy as jp
+
+def rUpright(sys: base.System,
+             pipeline_state: base.State,
+             focus_idx_range=(0, 0)) -> jax.Array:
+
+    up = jp.array([0.0, 0.0, 1.0])
+    rot_up = math.rotate(up, pipeline_state.x.rot[0])
+
+    return jp.array([jp.dot(up, rot_up)])
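The upright reward above is purely geometric: the world up axis is rotated by the torso quaternion, and its dot product with world up is 1.0 when the torso is level, near 0.0 when it lies on its side, and -1.0 when it is flipped over. A small self-contained sketch of that check (the quaternions are illustrative values in Brax's [w, x, y, z] convention, not taken from this repository):

```python
import jax.numpy as jp
from brax import math

up = jp.array([0.0, 0.0, 1.0])

upright = jp.array([1.0, 0.0, 0.0, 0.0])        # identity rotation: torso level
on_side = jp.array([0.7071, 0.7071, 0.0, 0.0])  # ~90 degree roll about x

print(jp.dot(up, math.rotate(up, upright)))  # 1.0  -> full upright reward
print(jp.dot(up, math.rotate(up, on_side)))  # ~0.0 -> no upright reward
```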
diff --git a/alfredo/rewards/rSpeed.py b/alfredo/rewards/rSpeed.py
deleted file mode 100644
index f500c52..0000000
--- a/alfredo/rewards/rSpeed.py
+++ /dev/null
@@ -1,44 +0,0 @@
-from typing import Tuple
-
-import jax
-from brax import actuator, base, math
-from brax.envs import PipelineEnv, State
-from brax.io import mjcf
-from etils import epath
-from jax import numpy as jp
-
-def rSpeed_X(sys: base.System,
-             pipeline_state: base.State,
-             CoM_prev: jp.ndarray,
-             CoM_now: jp.ndarray,
-             dt,
-             weight=1.0,
-             focus_idx_range=(1, -1)) -> jp.ndarray:
-
-    velocity = (CoM_now - CoM_prev) / dt
-
-    focus_s = focus_idx_range[0]
-    focus_e = focus_idx_range[-1]
-
-    sxr = weight * velocity[0]
-
-    return jp.array([sxr, velocity[0], velocity[1]])
-
-def rSpeed_Y(sys: base.System,
-             pipeline_state: base.State,
-             CoM_prev: jp.ndarray,
-             CoM_now: jp.ndarray,
-             dt,
-             weight=1.0,
-             focus_idx_range=(1, -1)) -> jp.ndarray:
-
-    velocity = (CoM_now - CoM_prev) / dt
-
-    focus_s = focus_idx_range[0]
-    focus_e = focus_idx_range[-1]
-
-    syr = weight * velocity[1]
-
-    return jp.array([syr, velocity[0], velocity[1]])
diff --git a/alfredo/rewards/reward.py b/alfredo/rewards/reward.py
new file mode 100644
index 0000000..384dadb
--- /dev/null
+++ b/alfredo/rewards/reward.py
@@ -0,0 +1,41 @@
+from brax import base
+from jax import numpy as jp
+
+class Reward:
+
+    def __init__(self, f, sc, ps):
+        """
+        :param f: A function handle (i.e. the function that computes this reward)
+        :param sc: A float that the base computation provided by f gets multiplied by
+        :param ps: A dictionary of parameters required for the reward computation
+        """
+
+        self.f = f
+        self.scale = sc
+        self.params = ps
+
+    def add_param(self, p_name, p_value):
+        """
+        Updates the self.params dictionary with the provided key and value
+        """
+
+        self.params[p_name] = p_value
+
+    def compute(self):
+        """
+        Computes the reward as specified by self.f, given that the scale
+        and required parameters have been set.
+        Otherwise, this errors out quite spectacularly.
+        """
+
+        res = self.f(**self.params)
+        res = res.at[0].multiply(self.scale)  # may not be the best way to do this
+
+        return res
+
+    def __str__(self):
+        """
+        Provides a standard string representation
+        """
+
+        return f'reward: {self.f}, scale: {self.scale}'
diff --git a/alfredo/scenes/flatworld/flatworld_A1_env.xml b/alfredo/scenes/flatworld/flatworld_A1_env.xml
index a30eb64..ef9dd55 100644
--- a/alfredo/scenes/flatworld/flatworld_A1_env.xml
+++ b/alfredo/scenes/flatworld/flatworld_A1_env.xml
@@ -8,21 +8,7 @@
 [XML hunk body not preserved in this capture: 21 lines of scene markup reduced to 7]
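Taken together, the new pieces fit like this: AAnt expects a dict of Reward objects keyed by name, and its step() explicitly looks up 'r_lin_vel' and 'r_yaw_vel' to hand them the sampled joystick command before calling compute() with sys and pipeline_state filled in. A minimal sketch of how a training script might wire this up; the scale values and XML paths are placeholders, not taken from this repository:

```python
import jax
from jax import numpy as jp

from alfredo.agents.aant import AAnt
from alfredo.rewards import Reward, rTracking_lin_vel, rTracking_yaw_vel, rUpright

# Scales here are illustrative; the 'r_lin_vel' and 'r_yaw_vel' keys are required
# because AAnt.step() adds the sampled joystick command to those two rewards.
rewards = {
    'r_lin_vel': Reward(rTracking_lin_vel, 1.5, {}),
    'r_yaw_vel': Reward(rTracking_yaw_vel, 0.8, {}),
    'r_upright': Reward(rUpright, 0.5, {}),
}

env = AAnt(rewards=rewards,
           env_xml_path='path/to/some_env.xml',   # placeholder path
           agent_xml_path='path/to/aant.xml')     # placeholder path

state = jax.jit(env.reset)(jax.random.PRNGKey(0))
state = jax.jit(env.step)(state, jp.zeros(env.action_size))
print(state.reward, state.metrics['r_upright'])
```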