From 1e5a016509aad5ac0e397ec91de3852999baed52 Mon Sep 17 00:00:00 2001 From: zichunxx <1019856685@qq.com> Date: Mon, 22 Apr 2024 21:46:28 +0800 Subject: [PATCH] collect mocap info --- robosuite/utils/binding_utils.py | 78 +++++++++++++++++++++++++++----- 1 file changed, 66 insertions(+), 12 deletions(-) diff --git a/robosuite/utils/binding_utils.py b/robosuite/utils/binding_utils.py index 563477eebb..a8e13fb157 100644 --- a/robosuite/utils/binding_utils.py +++ b/robosuite/utils/binding_utils.py @@ -241,6 +241,43 @@ def flatten(self): return np.concatenate([[self.time], self.qpos, self.qvel], axis=0) +class MjSimStateMocap: + def __init__(self, time, qpos, qvel, mocap_pos, mocap_quat): + self.time = time + self.qpos = qpos + self.qvel = qvel + self.mocap_pos = mocap_pos + self.mocap_quat = mocap_quat + + @classmethod + def from_flattened(cls, array, sim): + """ + Takes flat mjstate array and MjSim instance and + returns MjSimState. + """ + idx_time = 0 + idx_qpos = idx_time + 1 + idx_qvel = idx_qpos + sim.model.nq + idx_mocap_pos = idx_qvel + sim.model.nv + idx_mocap_quat = idx_mocap_pos + sim.model.nmocap * 3 + + time = array[idx_time] + qpos = array[idx_qpos : idx_qpos + sim.model.nq] + qvel = array[idx_qvel : idx_qvel + sim.model.nv] + mocap_pos = array[idx_mocap_pos : idx_mocap_pos + sim.model.nmocap * 3] + mocap_quat = array[idx_mocap_quat : idx_mocap_quat + sim.model.nmocap * 4] + assert sim.model.na == 0 and sim.model.nmocap != 0 + + return cls( + time=time, qpos=qpos, qvel=qvel, mocap_pos=mocap_pos.reshape(-1, 3), mocap_quat=mocap_quat.reshape(-1, 4) + ) + + def flatten(self): + return np.concatenate( + [[self.time], self.qpos, self.qvel, self.mocap_pos.flatten(), self.mocap_quat.flatten()], axis=0 + ) + + class _MjModelMeta(type): """ Metaclass which allows MjModel below to delegate to mujoco.MjModel. @@ -1137,13 +1174,22 @@ def add_render_context(self, render_context): del self._render_context_offscreen self._render_context_offscreen = render_context - def get_state(self): + def get_state(self, mocap=False): """Return MjSimState instance for current state.""" - return MjSimState( - time=self.data.time, - qpos=np.copy(self.data.qpos), - qvel=np.copy(self.data.qvel), - ) + if mocap: + return MjSimStateMocap( + time=self.data.time, + qpos=np.copy(self.data.qpos), + qvel=np.copy(self.data.qvel), + mocap_pos=np.copy(self.data.mocap_pos), + mocap_quat=np.copy(self.data.mocap_quat), + ) + else: + return MjSimState( + time=self.data.time, + qpos=np.copy(self.data.qpos), + qvel=np.copy(self.data.qvel), + ) def set_state(self, value): """ @@ -1154,19 +1200,27 @@ def set_state(self, value): self.data.qpos[:] = np.copy(value.qpos) self.data.qvel[:] = np.copy(value.qvel) - def set_state_from_flattened(self, value): + def set_state_from_flattened(self, value, mocap=False): """ Set internal mujoco state using flat mjstate array. Should call @forward afterwards to synchronize derived quantities. See https://github.com/openai/mujoco-py/blob/4830435a169c1f3e3b5f9b58a7c3d9c39bdf4acb/mujoco_py/mjsimstate.pyx#L54 """ - state = MjSimState.from_flattened(value, self) + if mocap: + state = MjSimStateMocap.from_flattened(value, self) + self.data.time = state.time + self.data.qpos[:] = state.qpos + self.data.qvel[:] = state.qvel + self.data.mocap_pos[:] = state.mocap_pos + self.data.mocap_quat[:] = state.mocap_quat + else: + state = MjSimState.from_flattened(value, self) - # do this instead of @set_state to avoid extra copy of qpos and qvel - self.data.time = state.time - self.data.qpos[:] = state.qpos - self.data.qvel[:] = state.qvel + # do this instead of @set_state to avoid extra copy of qpos and qvel + self.data.time = state.time + self.data.qpos[:] = state.qpos + self.data.qvel[:] = state.qvel def free(self): # clean up here to prevent memory leaks