[BUG] XLA Segmentation Fault #283

Closed
ethanluoyc opened this issue Oct 25, 2023 · 8 comments · Fixed by #284

ethanluoyc (Contributor) commented Oct 25, 2023

Describe the bug

Running the envpool XLA interface on GPU crashes with a segmentation fault.

To Reproduce

The following code, which uses the XLA interface, crashes when run on the GPU:

import os
import time
from typing import Any, NamedTuple

from absl import app
from absl import logging

# envpool only accepts double-type action input
os.environ["JAX_DEFAULT_DTYPE_BITS"] = "32"
# see https://github.com/google/jax/discussions/6332#discussioncomment-1279991
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.7"
import envpool
import jax
import flax
import jax.numpy as jnp


class RolloutCarry(NamedTuple):
    handle: Any
    state: Any
    key: Any


@flax.struct.dataclass
class RolloutOutput:
    actions: jnp.ndarray
    timestep: Any


def rollout(env_step_fn, policy, agent_state, rollout_carry, max_steps):
    def _step(carry, timestep):
        del timestep

        rollout_carry = carry
        action_key, key = jax.random.split(rollout_carry.key)
        action = policy(agent_state, rollout_carry.state.observation, action_key)
        handle, next_state = env_step_fn(rollout_carry.handle, action)
        output = RolloutOutput(
            actions=action,
            timestep=rollout_carry.state,
        )
        new_rollout_carry = RolloutCarry(handle=handle, state=next_state, key=key)
        return (new_rollout_carry, output)

    new_rollout_carry, output = jax.lax.scan(_step, rollout_carry, (), length=max_steps)
    return (new_rollout_carry, output)


def main(_):
    num_envs = 64
    num_steps = 32
    total_timesteps = int(3e6)
    num_updates = total_timesteps // (num_envs * num_steps)

    envs = envpool.make("HalfCheetah-v3", env_type="dm", num_envs=num_envs, seed=1)
    # envs = envpool.make("CheetahRun-v1", env_type="dm", num_envs=num_envs, seed=1)
    action_spec = envs.action_spec()

    handle, _, _, step_env = envs.xla()
    state = envs.reset()

    def process_states(states):
        # I am converting the observation to single precision here but that seems to be the line that causes the crash.
        return states._replace(
            observation=jnp.array(states.observation.obs, dtype=jnp.float32, copy=True)
        )

    params = ()
    carry = RolloutCarry(
        handle=handle,
        state=process_states(state),
        key=jax.random.PRNGKey(0),
    )

    def wrapped_step_env(handle, action):
        handle, state = step_env(handle, action)
        return handle, process_states(state)

    def policy(params, obs, key):
        # I do use float64 for actions.
        return jax.random.uniform(
            key, (num_envs, action_spec.shape[0]), dtype=jnp.float64
        )

    @jax.jit
    def rollout_fn(agent_state, rollout_carry):
        return rollout(wrapped_step_env, policy, agent_state, rollout_carry, num_steps)

    global_step = 0
    for _ in range(1, num_updates + 1):
        update_time_start = time.time()
        carry, experience = rollout_fn(params, carry)
        global_step += num_steps * num_envs
        sps_update = int(num_envs * num_steps / (time.time() - update_time_start))

        jax.block_until_ready(experience)
        # logging.info("global_step=%d, SPS_update=%d", global_step, sps_update)

    envs.close()


if __name__ == "__main__":
    jax.config.update("jax_default_dtype_bits", "32")
    jax.config.update("jax_enable_x64", True)
    jax.config.config_with_absl()
    app.run(main)
Output:

I1025 11:09:56.088881 139954427008832 xla_bridge.py:455] Unable to initialize backend 'rocm': NOT_FOUND: Could not find registered platform with name: "rocm". Available platform names are: CUDA Interpreter Host
I1025 11:09:56.089233 139954427008832 xla_bridge.py:455] Unable to initialize backend 'tpu': module 'jaxlib.xla_extension' has no attribute 'get_tpu_client'
I1025 11:09:56.089317 139954427008832 xla_bridge.py:455] Unable to initialize backend 'plugin': xla_extension has no attributes named get_plugin_device_client. Compile TensorFlow with //tensorflow/compiler/xla/python:enable_plugin_device set to true (defaults to false) to enable this.
Fatal Python error: Segmentation fault

Current thread 0x00007f49ade81740 (most recent call first):
  File "/home/yicheng/projects/corax-mjx/ppo_jax/debug.py", line 93 in main
  File "/home/yicheng/projects/corax-mjx/.venv/lib/python3.10/site-packages/absl/app.py", line 254 in _run_main
  File "/home/yicheng/projects/corax-mjx/.venv/lib/python3.10/site-packages/absl/app.py", line 308 in run
  File "/home/yicheng/projects/corax-mjx/ppo_jax/debug.py", line 107 in <module>

Expected behavior

The script runs all rollout updates to completion without crashing.


System info

import envpool, numpy, sys
print(envpool.__version__, numpy.__version__, sys.version, sys.platform)

0.8.3 1.26.1 3.10.5 (main, Jun 19 2023, 14:30:29) [GCC 9.4.0] linux

JAX 0.4.10.

Additional context

I ran the program under gdb; this is the backtrace:

0x00007fff6856c196 in void AsyncEnvPool<mujoco_gym::HalfCheetahEnv>::SendImpl<std::vector<Array, std::allocator<Array> > const&>(std::vector<Array, std::allocator<Array> > const&) () from /home/yicheng/projects/corax-mjx/.venv/lib/python3.10/site-packages/envpool/mujoco/mujoco_gym_envpool.so
(gdb) bt
#0  0x00007fff6856c196 in void AsyncEnvPool<mujoco_gym::HalfCheetahEnv>::SendImpl<std::vector<Array, std::allocator<Array> > const&>(std::vector<Array, std::allocator<Array> > const&) ()
   from /home/yicheng/projects/corax-mjx/.venv/lib/python3.10/site-packages/envpool/mujoco/mujoco_gym_envpool.so
#1  0x00007fff6856ca18 in CustomCall<AsyncEnvPool<mujoco_gym::HalfCheetahEnv>, XlaSend<AsyncEnvPool<mujoco_gym::HalfCheetahEnv> > >::Gpu(CUstream_st*, void**, char const*, unsigned long) ()
   from /home/yicheng/projects/corax-mjx/.venv/lib/python3.10/site-packages/envpool/mujoco/mujoco_gym_envpool.so
#2  0x00007fff70a158e0 in xla::runtime::CustomCallHandler<(xla::runtime::CustomCall::RuntimeChecks)1, xla::runtime::CustomCall::FunctionWrapper<&xla::gpu::XlaCustomCallImpl>, xla::runtime::internal::UserData<xla::ServiceExecutableRunOptions const*>, xla::runtime::internal::UserData<xla::DebugOptions const*>, xla::runtime::CustomCall::RemainingArgs, xla::runtime::internal::Attr<std::basic_string_view<char, std::char_traits<char> > >, xla::runtime::internal::Attr<int>, xla::runtime::internal::Attr<std::basic_string_view<char, std::char_traits<char> > > >::call(void**, void**, void**, xla::runtime::PtrMapByType<xla::runtime::CustomCall, 16u> const*, xla::runtime::DiagnosticEngine const*) const ()
   from /home/yicheng/projects/corax-mjx/.venv/lib/python3.10/site-packages/jaxlib/xla_extension.so
#3  0x00007fff70a16502 in xla::gpu::XlaCustomCall(xla::runtime::ExecutionContext*, void**, void**, void**) ()
   from /home/yicheng/projects/corax-mjx/.venv/lib/python3.10/site-packages/jaxlib/xla_extension.so
#4  0x00007fff600f21b0 in __xla__main.192 ()
#5  0x00007fff70c90557 in xla::runtime::Executable::Execute(unsigned int, xla::runtime::Executable::CallFrame&, xla::runtime::Executable::ExecuteOpts const&) const () from /home/yicheng/projects/corax-mjx/.venv/lib/python3.10/site-packages/jaxlib/xla_extension.so
#6  0x00007fff7093e5b3 in xla::gpu::GpuRuntimeExecutable::Execute(xla::ServiceExecutableRunOptions const*, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, std::vector<unsigned char, std::allocator<unsigned char> > const&, xla::gpu::BufferAllocations const&, xla::gpu::NonAtomicallyUpgradeableRWLock&, xla::BufferAllocation const*) ()
   from /home/yicheng/projects/corax-mjx/.venv/lib/python3.10/site-packages/jaxlib/xla_extension.so
#7  0x00007fff70924327 in xla::gpu::GpuExecutable::ExecuteThunksOrXlaRuntime(xla::ServiceExecutableRunOptions const*, xla::gpu::BufferAllocations const&, bool, xla::gpu::NonAtomicallyUpgradeableRWLock&) ()
   from /home/yicheng/projects/corax-mjx/.venv/lib/python3.10/site-packages/jaxlib/xla_extension.so
#8  0x00007fff7092868a in xla::gpu::GpuExecutable::ExecuteAsyncOnStreamImpl(xla::ServiceExecutableRunOptions const*, std::variant<absl::lts_20230125::Span<xla::ShapedBuffer const* const>, absl::lts_20230125::Span<xla::ExecutionInput> >) ()
   from /home/yicheng/projects/corax-mjx/.venv/lib/python3.10/site-packages/jaxlib/xla_extension.so
#9  0x00007fff70929630 in xla::gpu::GpuExecutable::ExecuteAsyncOnStream(xla::ServiceExecutableRunOptions const*, std::vector<xla::ExecutionInput, std::allocator<xla::ExecutionInput> >, xla::HloExecutionProfile*) ()
   from /home/yicheng/projects/corax-mjx/.venv/lib/python3.10/site-packages/jaxlib/xla_extension.so
#10 0x00007fff7227f2e7 in xla::Executable::ExecuteAsyncOnStreamWrapper(xla::ServiceExecutableRunOptions const*, std::vector<xla::ExecutionInput, std::allocator<xla::ExecutionInput> >) () from /home/yicheng/projects/corax-mjx/.venv/lib/python3.10/site-packages/jaxlib/xla_extension.so
#11 0x00007fff6f677f43 in xla::LocalExecutable::RunAsync(absl::lts_20230125::Span<xla::Shape const* const>, std::vector<xla::ExecutionInput, std::allocator<xla::ExecutionInput> >, xla::ExecutableRunOptions) ()
   from /home/yicheng/projects/corax-mjx/.venv/lib/python3.10/site-packages/jaxlib/xla_extension.so
#12 0x00007fff6f678ba5 in xla::LocalExecutable::RunAsync(std::vector<xla::ExecutionInput, std::allocator<xla::ExecutionInput> >, xla::ExecutableRunOptions) () from /home/yicheng/projects/corax-mjx/.venv/lib/python3.10/site-packages/jaxlib/xla_extension.so


Checklist

  • I have checked that there is no similar issue in the repo (required)
  • I have read the documentation (required)
  • I have provided a minimal working example to reproduce the bug (required)
mavenlin (Member) commented

> I am converting the observation to single precision here but that seems to be the line that causes the crash.

Do you mean that if this line is removed, the crash goes away?

mavenlin added a commit that referenced this issue Oct 26, 2023
mavenlin mentioned this issue Oct 26, 2023
ethanluoyc (Contributor, Author) commented

> I am converting the observation to single precision here but that seems to be the line that causes the crash.

> Do you mean that if this line is removed, the crash goes away?

Initially I thought so, but it looks like it's flaky. But I guess you have found the issue?

mavenlin (Member) commented

Yep, #284 should fix it.

Trinkle23897 pushed a commit that referenced this issue Oct 26, 2023
This closes #283

The `XlaSend` call requires `envpool` to make a copy of the `action` to
prevent `action` from being recycled by the XLA runtime before `envpool`
finishes using it. Originally, I used `cudaMemcpy` to make sure the copy
finished synchronously. However, that seems to have caused the problem
reported in issue #283.

Here, I replace the original `cudaMemcpy` call with the async version and
an explicit `cudaStreamSynchronize`.

It is not clear how `cudaMemcpy` on the default stream in a custom call
interacts with the stream managed by pjrt. However, from the documentation
[here](https://github.com/tensorflow/tensorflow/blob/0d2d79e84c9bdf71c737ad17a7b1dc04d9efc24f/tensorflow/compiler/xla/g3doc/custom_call.md),
I hypothesize that an explicit stream synchronization in the custom call is safe.
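
For illustration, here is a minimal CUDA sketch of the pattern the commit describes (hypothetical names and a simplified signature, not envpool's actual code; the real logic lives in the `XlaSend` custom call visible in the backtrace):

#include <cstddef>
#include <cuda_runtime.h>

// Simplified stand-in for envpool's XlaSend custom call. `stream` is the
// CUDA stream the XLA runtime passes to the custom call; `dst` stands in
// for envpool's internal copy of the action buffer.
void xla_send_sketch(cudaStream_t stream, void* dst,
                     const void* action, size_t num_bytes) {
  // Before the fix: a blocking cudaMemcpy on the legacy default stream,
  // whose ordering relative to the pjrt-managed stream is unclear:
  //   cudaMemcpy(dst, action, num_bytes, cudaMemcpyDeviceToDevice);

  // After the fix: enqueue the copy on the stream XLA provides, then
  // synchronize explicitly so the XLA runtime cannot recycle `action`
  // before the copy has finished.
  cudaMemcpyAsync(dst, action, num_bytes, cudaMemcpyDeviceToDevice, stream);
  cudaStreamSynchronize(stream);
}
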
ethanluoyc (Contributor, Author) commented Oct 27, 2023

@mavenlin Hmm, I tried it on my side but the issue seems to persist. I will take a closer look at my setup.

mavenlin (Member) commented

> Hmm, I tried it on my side but the issue seems to persist. I will take a closer look at my setup.

I tested the wheel from here. I can run your code above without an issue.

ethanluoyc (Contributor, Author) commented

Yeah, it seems to work. I was experimenting with PDM, and that seems to have messed up my pip installation somehow. Many thanks for fixing this! It would be super cool if there were a new release on PyPI.

Trinkle23897 (Collaborator) commented

Will do this weekend, sorry for the delay.

Trinkle23897 (Collaborator) commented

Done, `pip install envpool` will now install 0.8.4.
