Skip to content

Commit

Permalink
push
Browse files Browse the repository at this point in the history
  • Loading branch information
jayscoder committed Apr 14, 2024
1 parent 28ac3e3 commit 9848566
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 3 deletions.
6 changes: 4 additions & 2 deletions pybts/rl/on_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,12 @@ def obs_as_tensor(obs: Union[np.ndarray, Dict[str, np.ndarray]], device: th.devi
:param device: PyTorch device
:return: PyTorch tensor of the observation on a desired device.
"""
dtype = th.float32 if device == 'mps' else th.float64

if isinstance(obs, np.ndarray):
return th.as_tensor(obs, device=device, dtype=th.float32)
return th.as_tensor(obs, device=device, dtype=dtype)
elif isinstance(obs, dict):
return { key: th.as_tensor(_obs, device=device, dtype=th.float32) for (key, _obs) in obs.items() }
return { key: th.as_tensor(_obs, device=device, dtype=dtype) for (key, _obs) in obs.items() }
else:
raise Exception(f"Unrecognized type of observation {type(obs)}")

Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ python = ">=3.9,<4.0"
py-trees = "^2.2.3"
jinja2 = "^3.1.3"
flask = "^3.0.2"
tqdm = "^4.66.2"
tqdm = "^4.64"
gymnasium = {version = "^0.29.1", optional = true}
torch = {version = "^2.2.2", optional = true}
stable-baselines3 = {version = "^2.3.0", optional = true}
Expand Down

0 comments on commit 9848566

Please sign in to comment.