Proposal for Code Structure Improvement Using jax.lax.cond
#245
Comments
Hi, thanks for the suggestion! Agreed, it isn't as clean as the others, but the reason it needs to be done this way is that the signatures of `transition` and `termination` differ:

```python
def transition(
    reward: Array,
    observation: Observation,
    discount: Optional[Array] = None,
    extras: Optional[Dict] = None,
    shape: Union[int, Sequence[int]] = (),
) -> TimeStep:
    ...
```

```python
def termination(
    reward: Array,
    observation: Observation,
    extras: Optional[Dict] = None,
    shape: Union[int, Sequence[int]] = (),
) -> TimeStep:
    ...
```

If we simply passed in un-named arguments through the …
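The signature mismatch can be illustrated with a minimal sketch. `ToyTimeStep`, `toy_transition`, `toy_termination`, and `step_output` are hypothetical, simplified stand-ins (not Jumanji's API) that keep only the positional parameter order of `transition` and `termination`. `jax.lax.cond` passes its operands positionally to whichever branch it runs, so an unwrapped third operand would become `discount` in the transition branch but `extras` in the termination branch.

```python
from typing import Dict, NamedTuple, Optional

import jax
import jax.numpy as jnp


class ToyTimeStep(NamedTuple):
    """Hypothetical, simplified stand-in for Jumanji's TimeStep."""

    reward: jax.Array
    discount: jax.Array
    observation: jax.Array
    extras: Dict[str, jax.Array]


def toy_transition(reward, observation, discount=None, extras: Optional[Dict] = None):
    # Same positional order as `transition`: `discount` comes before `extras`.
    discount = jnp.ones_like(reward) if discount is None else discount
    return ToyTimeStep(reward, discount, observation, {} if extras is None else extras)


def toy_termination(reward, observation, extras: Optional[Dict] = None):
    # Same positional order as `termination`: there is no `discount` parameter,
    # and the discount is fixed to 0 because the episode has ended.
    return ToyTimeStep(
        reward, jnp.zeros_like(reward), observation, {} if extras is None else extras
    )


def step_output(done, reward, observation, extras):
    # lax.cond forwards its operands positionally to the selected branch.
    # Passing (reward, observation, extras) straight to toy_transition would
    # bind `extras` to `discount`, so each branch is wrapped in a lambda that
    # binds `extras` by keyword instead.
    return jax.lax.cond(
        done,
        lambda r, o, e: toy_termination(r, o, extras=e),
        lambda r, o, e: toy_transition(r, o, extras=e),
        reward,
        observation,
        extras,
    )


reward = jnp.float32(1.0)
observation = jnp.arange(3.0)
extras = {"steps": jnp.int32(5)}

print(step_output(jnp.array(True), reward, observation, extras).discount)
print(step_output(jnp.array(False), reward, observation, extras).discount)

# The positional pitfall: the extras dict silently lands in `discount`.
print(type(toy_transition(reward, observation, extras).discount))
```

The last call shows why passing un-named operands is unsafe here: with the transition-style signature, the dictionary meant for `extras` is bound to `discount` without any error at call time.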
Hi, thank you for your comment and your observation. Yes, we would rather have what you suggested, i.e. a simple …
Thank you @sash-a for replying so fast that I didn't see your comment! I agree with what you said. A solution would be to move the …
Yeah, that seems like a reasonable solution.
It seems that the cases including the … Since all instances calling …

I’m learning a lot from studying Jumanji’s code and can see the effort put into systematically building the JAX reinforcement learning environments throughout. Thank you very much for writing such excellent code. I hope @clement-bonnet’s solution is implemented (or that …
@helpingstar would you be interested in making a PR to fix this?
@sash-a I’d be glad to help with this. Is there a preferred timeline?
No rush honestly, whenever you have time! I think @clement-bonnet’s suggested fix should work well.
@sash-a Got it, thanks for letting me know!
Is your feature request related to a problem? Please describe
This is a simple question regarding code style. It is not related to any bugs.
jumanji/jumanji/environments/routing/connector/env.py
Lines 184 to 198 in fd511b4
jumanji/jumanji/environments/logic/game_2048/env.py
Lines 222 to 236 in fd511b4
Rather than repeatedly using `lambda` and duplicating variables as shown in the code above, it seems better to follow the functional style of `jax.lax.cond` and write it in the style of the solution code below. There appears to be little to no difference in performance between the two styles.
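A minimal sketch contrasting the two styles, assuming hypothetical branch functions `on_done` and `on_continue` (not Jumanji code) that share the same positional parameters. The closure style wraps each call in a zero-argument `lambda`; the functional style hands the branch functions and their operands to `jax.lax.cond` directly. Under `jax.jit` both should trace to a single conditional, so one would expect little performance difference.

```python
import jax
import jax.numpy as jnp


def on_done(reward, observation):
    # Hypothetical branch: the episode is over, so the discount is 0.
    return reward, jnp.zeros_like(reward), observation


def on_continue(reward, observation):
    # Hypothetical branch: the episode continues, so the discount is 1.
    return reward, jnp.ones_like(reward), observation


@jax.jit
def closure_style(done, reward, observation):
    # Zero-operand lambdas capture reward/observation from the enclosing scope.
    return jax.lax.cond(
        done,
        lambda: on_done(reward, observation),
        lambda: on_continue(reward, observation),
    )


@jax.jit
def operand_style(done, reward, observation):
    # Functional style: the branches are passed directly and lax.cond forwards
    # the operands positionally. This only works because both branches accept
    # the same positional parameters.
    return jax.lax.cond(done, on_done, on_continue, reward, observation)


reward = jnp.float32(2.0)
observation = jnp.arange(4.0)
for done in (jnp.array(True), jnp.array(False)):
    print(closure_style(done, reward, observation)[1], operand_style(done, reward, observation)[1])
```

The functional form stays this short only when every branch takes the same positional arguments, which is exactly the constraint the maintainers raise about `transition` and `termination`.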
If this is a minor issue, I will close it.
Describe the solution you'd like
Describe alternatives you've considered
None
Additional context
jumanji/jumanji/environments/logic/minesweeper/env.py
Lines 178 to 185 in fd511b4
jumanji/jumanji/environments/logic/graph_coloring/env.py
Lines 202 to 209 in fd511b4
jumanji/jumanji/environments/logic/rubiks_cube/env.py
Lines 169 to 176 in fd511b4
jumanji/jumanji/environments/logic/sudoku/env.py
Lines 124 to 132 in fd511b4
jumanji/jumanji/environments/packing/tetris/env.py
Lines 235 to 242 in fd511b4