Skip to content

Commit

Permalink
fixed env
Browse files Browse the repository at this point in the history
  • Loading branch information
vonHartz committed Sep 8, 2023
1 parent ea4d3ea commit bb50f7d
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 27 deletions.
4 changes: 2 additions & 2 deletions src/env/franka.py
Original file line number Diff line number Diff line change
Expand Up @@ -423,8 +423,8 @@ def get_obs(self):
camera_obs[cam] = SingleCamObservation(**{
'rgb': torch.Tensor(img_frames[i]),
'depth': torch.Tensor(depth_frames[i]),
'extrinsics': torch.Tensor(extrinsics[i]),
'intrinsics': self.intrinsics[i],
'extr': torch.Tensor(extrinsics[i]),
'intr': self.intrinsics[i],
}, batch_size=empty_batchsize)

multicam_obs = dict_to_tensordict(
Expand Down
47 changes: 23 additions & 24 deletions src/env/rlbench.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@
from loguru import logger
from pyrep.const import RenderMode
from pyrep.errors import IKError
import rlbench
from rlbench.action_modes import ActionMode, ArmActionMode
from rlbench.backend.observation import Observation as RLBenchObservation
from rlbench.environment import Environment as RLBenchEnvironment
from rlbench.observation_config import CameraConfig, ObservationConfig
from rlbench.task_environment import InvalidActionError
from rlbench.tasks import (ArmScan, CloseMicrowave, PhoneBase, PhoneOnBase,
Expand Down Expand Up @@ -101,7 +101,7 @@ def launch_simulation_env(self, config: dict,
action_mode = ActionMode(ArmActionMode.EE_POSE_EE_FRAME)
self.do_postprocess_actions = True

self.env = RLBenchEnvironment(
self.env = rlbench.environment.Environment(
action_mode,
obs_config=obs_config,
static_positions=config["static_env"],
Expand All @@ -118,16 +118,16 @@ def close(self):
def setup_camera_controls(self, config):
self.camera_pose = config["camera_pose"]

self.camera_assoc = {}
if config["shoulders_on"]:
self.camera_assoc["shoulder_left"] = \
self.env._scene._cam_over_shoulder_left
self.camera_assoc["shoulder_right"] = \
self.env._scene._cam_over_shoulder_right
if config["wrist_on"]:
self.camera_assoc["wrist"] = self.env._scene._cam_wrist
if config["overhead_on"]:
self.camera_assoc["overhead"] = self.env._scene._cam_overhead
camera_map = {
"left_shoulder": self.env._scene._cam_over_shoulder_left,
"right_shoulder": self.env._scene._cam_over_shoulder_right,
"wrist": self.env._scene._cam_wrist,
"overhead": self.env._scene._cam_overhead,
"front": self.env._scene._cam_front,
}

self.camera_map = {k: v for k, v in camera_map.items()
if k in self.cameras}

def reset(self):
super().reset()
Expand Down Expand Up @@ -186,7 +186,7 @@ def _step(self, action: np.ndarray, postprocess: bool = True,
if postprocess:
action_delayed = self.postprocess_action(
action, scale_action=scale_action, delay_gripper=delay_gripper,
return_euler=True)
return_euler=False)
else:
action_delayed = action

Expand Down Expand Up @@ -218,13 +218,13 @@ def _step(self, action: np.ndarray, postprocess: bool = True,
def get_camera_pose(self) -> dict[str, np.ndarray]:

return {
name: cam.get_pose() for name, cam in self.camera_assoc.items()
name: cam.get_pose() for name, cam in self.camera_map.items()
}

def set_camera_pose(self, pos_dict: dict[str, np.ndarray]) -> None:
for camera_name, pos in pos_dict.items():
if camera_name in self.camera_assoc:
camera = self.camera_assoc[camera_name]
if camera_name in self.camera_map:
camera = self.camera_map[camera_name]
camera.set_pose(pos)

def process_observation(self, obs: RLBenchObservation) -> SceneObservation:
Expand All @@ -247,29 +247,28 @@ def process_observation(self, obs: RLBenchObservation) -> SceneObservation:
rgb = getattr(obs, cam + "_rgb").transpose((2, 0, 1)) / 255
depth = getattr(obs, cam + "_depth")
mask = getattr(obs, cam + "_mask").astype(int)
ext = getattr(obs.misc, cam + "_camera_extrinsics")
intr = getattr(obs.misc, cam + "_camera_intrinsics").float()
extr = obs.misc[cam + "_camera_extrinsics"]
intr = obs.misc[cam + "_camera_intrinsics"].astype(float)

camera_obs[cam] = SingleCamObservation(**{
"rgb": torch.Tensor(rgb),
"depth": torch.Tensor(depth),
"mask": torch.Tensor(mask).to(torch.uint8),
"ext": torch.Tensor(ext),
"int": intr,
"extr": torch.Tensor(extr),
"intr": torch.Tensor(intr),
}, batch_size=empty_batchsize)

multicam_obs = dict_to_tensordict(
{'_order ': CameraOrder._create(self.cameras)} | camera_obs)


joint_pos = torch.Tensor(obs.joint_positions)
joint_vel = torch.Tensor(obs.joint_velocities)

ee_pose = torch.Tensor(obs.gripper_pose)
gripper_open = torch.Tensor(obs.gripper_open)

gripper_open = torch.Tensor([obs.gripper_open])

object_poses = {'stacked': torch.Tensor(obs.task_low_dim_state)}
object_poses = dict_to_tensordict(
{'stacked': torch.Tensor(obs.task_low_dim_state)})

obs = SceneObservation(cameras=multicam_obs, ee_pose=ee_pose,
object_poses=object_poses,
Expand Down
4 changes: 3 additions & 1 deletion src/utils/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,9 @@ def get_full_task_name(config):
if config["env_config"]["background"]:
task_name += "-" + config["env_config"]["background"]

task_name += "-" + "_".join([m for m in config["env_config"]["model_ids"]])
if config["env_config"]["model_ids"]:
task_name += "-" + "_".join(
[m for m in config["env_config"]["model_ids"]])

return task_name

Expand Down

0 comments on commit bb50f7d

Please sign in to comment.