Skip to content

Commit

Permalink
added Reward class that will allow us to pass in reward functions via…
Browse files Browse the repository at this point in the history
… experiment scripts
  • Loading branch information
mginoya committed Sep 20, 2024
1 parent 24c3f87 commit 10de7c1
Showing 1 changed file with 31 additions and 0 deletions.
31 changes: 31 additions & 0 deletions alfredo/rewards/reward.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from brax import base
from jax import numpy as jp

class Reward:

def __init__(self, f, sc, ps):
"""
:param f: A function handle (ie. function that computes this reward)
:param sc: A float that gets multiplied to base computation provided by f
:param ps: A dictionary of parameters required for the reward computation
"""

self.f = f
self.scale = sc
self.params = ps

def add_param(self, p_name, p_value):
"""
Updates self.params dictionary with provided key and value
"""

self.params[p_name] = p_value

def compute(self):
"""
computes reward as specified by self.f given
scale and general parameters are set.
Otherwise, this errors out quite spectacularly
"""

return self.scale*self.f(**self.params)

0 comments on commit 10de7c1

Please sign in to comment.