Error on training Transformer baseline #5

Open
Sangboom opened this issue Jan 20, 2025 · 0 comments

Sangboom commented Jan 20, 2025

I tried to follow the training tutorial for the Transformer baseline, but I ran into many errors.

First, the function `_to_transition_trajectories` in `Transformer/orca/data/rlds/rlds_dataset.py` raises an error because the dataset's format does not match what the function expects at all. So I modified it to match the dataset builder's format, as below:

```python
def _to_transition_trajectories(trajectory: Dict[str, Any]) -> Dict[str, Any]:
  # return transition dataset in convention of Bridge dataset
  observations = {}
  for key in self._image_obs_key:
      observations[key] = tf.cast(trajectory["observation"][key], tf.float32) / 127.5 - 1.0
  observations["proprio"] = trajectory["observation"]["joint_pos"]
  # observations["proprio"] = trajectory["observation"]["state"]
  observations["ee_ft"] = trajectory["observation"]["eef_force"]
  # observations["ee_ft"] = trajectory["observation"]["ee_ft"]
  observations["primitive_id"] = tf.constant([0], dtype=tf.uint8)
  # observations["primitive_id"] = tf.cast(trajectory["observation"]["primitive_id"], tf.uint8)
  observations["primitive"] = trajectory["observation"]["primitive"]
  observations["peg_id"] = tf.constant([0], dtype=tf.uint8)
  # observations["peg_id"] = trajectory["observation"]["peg_id"]
  observations["ee_pose"] = trajectory["observation"]["eef_pose"]
  # observations["ee_pose"] = trajectory["observation"]["ee_pose"]
  observations["ee_vel"] = trajectory["observation"]["eef_vel"]
  # observations["ee_vel"] = trajectory["observation"]["ee_vel"]
  return {
      "observations": {**observations},
      "next_observations": {**observations}, # FMB doesn't need next obs
      **(
          {"language": trajectory["language_instruction"]}
          if self.load_language
          else {}
      ),
      "actions": trajectory["action"],
      "terminals": trajectory["is_terminal"],
      "truncates": tf.math.logical_and(
          trajectory["is_last"],
          tf.math.logical_not(trajectory["is_terminal"]),
      ),
  }
```
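The renames in the function above (`joint_pos` to `proprio`, `eef_force` to `ee_ft`, `eef_pose` to `ee_pose`, `eef_vel` to `ee_vel`, with `primitive` passed through) could also be kept in one mapping table, so that a key missing from the dataset builder's output fails with a readable message instead of deep inside the pipeline. This is only an illustrative sketch: `BUILDER_TO_LOADER_KEYS` and `remap_observation` are hypothetical names, not part of the FMB codebase, and the sketch operates on plain Python dicts with dummy values rather than TensorFlow tensors.

```python
from typing import Any, Dict, Mapping

# Hypothetical table: key in the dataset builder's output -> key the loader expects.
# The pairs are copied from the renames in _to_transition_trajectories above.
BUILDER_TO_LOADER_KEYS: Dict[str, str] = {
    "joint_pos": "proprio",
    "eef_force": "ee_ft",
    "primitive": "primitive",
    "eef_pose": "ee_pose",
    "eef_vel": "ee_vel",
}


def remap_observation(
    observation: Mapping[str, Any],
    key_map: Mapping[str, str] = BUILDER_TO_LOADER_KEYS,
) -> Dict[str, Any]:
    """Rename builder keys to loader keys, raising KeyError if any are missing."""
    missing = [src for src in key_map if src not in observation]
    if missing:
        raise KeyError(f"observation is missing builder keys: {missing}")
    return {dst: observation[src] for src, dst in key_map.items()}


if __name__ == "__main__":
    # Dummy values only; real entries would be per-timestep tensors.
    builder_obs = {key: 0.0 for key in BUILDER_TO_LOADER_KEYS}
    print(sorted(remap_observation(builder_obs)))
    # -> ['ee_ft', 'ee_pose', 'ee_vel', 'primitive', 'proprio']
```

Keeping the mapping as data means switching back to the commented-out key names (`state`, `ee_ft`, and so on) is a one-line table change rather than an edit to every assignment.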

After matching the dataset's format, I get the following error:

```
Traceback (most recent call last):
  File "/media/sblee/170d6766-97d9-4917-8fc6-7d6ae84df8961/SSD2/workspaces/fmb/Transformer/orca/data/rlds/rlds_dataset.py", line 281, in <module>
    print(next(iterator).numpy())
  File "/home/sblee/miniconda3/envs/fmb_transformer_ori/lib/python3.10/site-packages/tensorflow/python/data/ops/iterator_ops.py", line 814, in __next__
    return self._next_internal()
  File "/home/sblee/miniconda3/envs/fmb_transformer_ori/lib/python3.10/site-packages/tensorflow/python/data/ops/iterator_ops.py", line 777, in _next_internal
    ret = gen_dataset_ops.iterator_get_next(
  File "/home/sblee/miniconda3/envs/fmb_transformer_ori/lib/python3.10/site-packages/tensorflow/python/ops/gen_dataset_ops.py", line 3028, in iterator_get_next
    _ops.raise_from_not_ok_status(e, name)
  File "/home/sblee/miniconda3/envs/fmb_transformer_ori/lib/python3.10/site-packages/tensorflow/python/framework/ops.py", line 6656, in raise_from_not_ok_status
    raise core._status_to_exception(e) from None  # pylint: disable=protected-access
tensorflow.python.framework.errors_impl.InvalidArgumentError: {{function_node __wrapped__IteratorGetNext_output_types_43_device_/job:localhost/replica:0/task:0/device:CPU:0}} indices[1,0] = 1 is not in [0, 1)
	 [[{{node GatherV2_16}}]] [Op:IteratorGetNext] name:
```
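One possible reading of `indices[1,0] = 1 is not in [0, 1)` (an assumption, not confirmed in this thread): a later step gathers every observation along the time axis, for example to build history windows, while `primitive_id` and `peg_id` are created with `tf.constant([0], dtype=tf.uint8)` and therefore have length 1 instead of one entry per timestep. The sketch below reproduces the same class of failure with NumPy instead of TensorFlow; `window_indices` is a hypothetical stand-in for whatever windowing the loader actually performs.

```python
import numpy as np


def window_indices(traj_len: int, window: int) -> np.ndarray:
    """Hypothetical history-window indices with shape [traj_len, window].

    Row t holds timesteps t-window+1 .. t, clipped at 0 so the first frame repeats.
    """
    offsets = np.arange(-window + 1, 1)           # e.g. [-1, 0] for window=2
    idx = np.arange(traj_len)[:, None] + offsets  # shape [traj_len, window]
    return np.maximum(idx, 0)


if __name__ == "__main__":
    traj_len, window = 5, 2
    idx = window_indices(traj_len, window)

    per_step = np.zeros((traj_len, 1), dtype=np.uint8)  # one entry per timestep
    constant = np.zeros((1,), dtype=np.uint8)           # like tf.constant([0], dtype=tf.uint8)

    # Gathering along time works when the leading axis has traj_len entries.
    print(np.take(per_step, idx, axis=0).shape)  # -> (5, 2, 1)

    # A length-1 placeholder fails as soon as any index is >= 1.
    try:
        np.take(constant, idx, axis=0)
    except IndexError as err:
        print("IndexError:", err)
```

If this reading is right, the analogous change in the loader would be to give the placeholders a leading time dimension, for example `tf.zeros([tf.shape(trajectory["action"])[0], 1], dtype=tf.uint8)`, assuming `action` has time as its first axis.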

How can I resolve this? I would also like to know whether training the Transformer baseline is currently supported.

Thank you
