Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update binding_utils.py to collect mocap info #475

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 66 additions & 12 deletions robosuite/utils/binding_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,43 @@ def flatten(self):
return np.concatenate([[self.time], self.qpos, self.qvel], axis=0)


class MjSimStateMocap:
def __init__(self, time, qpos, qvel, mocap_pos, mocap_quat):
self.time = time
self.qpos = qpos
self.qvel = qvel
self.mocap_pos = mocap_pos
self.mocap_quat = mocap_quat

@classmethod
def from_flattened(cls, array, sim):
"""
Takes flat mjstate array and MjSim instance and
returns MjSimState.
"""
idx_time = 0
idx_qpos = idx_time + 1
idx_qvel = idx_qpos + sim.model.nq
idx_mocap_pos = idx_qvel + sim.model.nv
idx_mocap_quat = idx_mocap_pos + sim.model.nmocap * 3

time = array[idx_time]
qpos = array[idx_qpos : idx_qpos + sim.model.nq]
qvel = array[idx_qvel : idx_qvel + sim.model.nv]
mocap_pos = array[idx_mocap_pos : idx_mocap_pos + sim.model.nmocap * 3]
mocap_quat = array[idx_mocap_quat : idx_mocap_quat + sim.model.nmocap * 4]
assert sim.model.na == 0 and sim.model.nmocap != 0

return cls(
time=time, qpos=qpos, qvel=qvel, mocap_pos=mocap_pos.reshape(-1, 3), mocap_quat=mocap_quat.reshape(-1, 4)
)

def flatten(self):
return np.concatenate(
[[self.time], self.qpos, self.qvel, self.mocap_pos.flatten(), self.mocap_quat.flatten()], axis=0
)


class _MjModelMeta(type):
"""
Metaclass which allows MjModel below to delegate to mujoco.MjModel.
Expand Down Expand Up @@ -1137,13 +1174,22 @@ def add_render_context(self, render_context):
del self._render_context_offscreen
self._render_context_offscreen = render_context

def get_state(self):
def get_state(self, mocap=False):
"""Return MjSimState instance for current state."""
return MjSimState(
time=self.data.time,
qpos=np.copy(self.data.qpos),
qvel=np.copy(self.data.qvel),
)
if mocap:
return MjSimStateMocap(
time=self.data.time,
qpos=np.copy(self.data.qpos),
qvel=np.copy(self.data.qvel),
mocap_pos=np.copy(self.data.mocap_pos),
mocap_quat=np.copy(self.data.mocap_quat),
)
else:
return MjSimState(
time=self.data.time,
qpos=np.copy(self.data.qpos),
qvel=np.copy(self.data.qvel),
)

def set_state(self, value):
"""
Expand All @@ -1154,19 +1200,27 @@ def set_state(self, value):
self.data.qpos[:] = np.copy(value.qpos)
self.data.qvel[:] = np.copy(value.qvel)

def set_state_from_flattened(self, value):
def set_state_from_flattened(self, value, mocap=False):
"""
Set internal mujoco state using flat mjstate array. Should
call @forward afterwards to synchronize derived quantities.

See https://github.com/openai/mujoco-py/blob/4830435a169c1f3e3b5f9b58a7c3d9c39bdf4acb/mujoco_py/mjsimstate.pyx#L54
"""
state = MjSimState.from_flattened(value, self)
if mocap:
state = MjSimStateMocap.from_flattened(value, self)
self.data.time = state.time
self.data.qpos[:] = state.qpos
self.data.qvel[:] = state.qvel
self.data.mocap_pos[:] = state.mocap_pos
self.data.mocap_quat[:] = state.mocap_quat
else:
state = MjSimState.from_flattened(value, self)

# do this instead of @set_state to avoid extra copy of qpos and qvel
self.data.time = state.time
self.data.qpos[:] = state.qpos
self.data.qvel[:] = state.qvel
# do this instead of @set_state to avoid extra copy of qpos and qvel
self.data.time = state.time
self.data.qpos[:] = state.qpos
self.data.qvel[:] = state.qvel

def free(self):
# clean up here to prevent memory leaks
Expand Down