Skip to content

Commit

Permalink
sdd
Browse files Browse the repository at this point in the history
  • Loading branch information
Ruturaj4 committed Dec 12, 2024
1 parent 4d89227 commit a3e251e
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 35 deletions.
105 changes: 71 additions & 34 deletions build/rocm/README.md
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
# JAX on ROCm
This directory provides setup instructions and necessary files to build, test, and run JAX with ROCm support in a Docker environment, suitable for both runtime and CI workflows. Explore the following methods to use or build JAX on ROCm!

## Build ROCm JAX Using Prebuilt Docker Images
## 1. Using Prebuilt Docker Images

The simplest way to use ROCm JAX is by leveraging prebuilt Docker images. These images are available on Docker Hub and come with JAX configured for ROCm. You can browse the available tags and pull images directly:

https://hub.docker.com/r/rocm/jax-community/tags
The ROCm JAX team provides prebuilt Docker images, which the simplest way to use JAX on ROCm. These images are available on Docker Hub and come with JAX configured for ROCm.

To pull the latest ROCm JAX Docker image, run:

Expand All @@ -17,14 +15,18 @@ Once the image is downloaded, launch a container using the following command:

```Bash
> docker run -it -d --network=host --device=/dev/kfd --device=/dev/dri --ipc=host --shm-size 64G --group-add video --cap-add=SYS_PTRACE --security-opt seccomp=unconfined -v $(pwd):/jax_dir --name rocm_jax rocm/jax-community:latest /bin/bash

> docker attach rocm_jax
```

### Notes:
1. Ensure the Docker image `rocm/jax-community:latest` exists by checking with `docker images`.
2. The `--shm-size` parameter allocates shared memory for the container. Adjust it based on your system's resources if needed.
3. Replace `$(pwd)` with the absolute path to the directory you want to mount inside the container.
1. The `--shm-size` parameter allocates shared memory for the container. Adjust it based on your system's resources if needed.
2. Replace `$(pwd)` with the absolute path to the directory you want to mount inside the container.

### Testing JAX with ROCm:
***For older versions please review the periodically pushed docker images at:
[ROCm JAX Community DockerHub](https://hub.docker.com/r/rocm/jax-community/tags).***

### Testing your ROCm environment with JAX:

After launching the container, test whether JAX detects ROCm devices as expected:

Expand All @@ -35,7 +37,7 @@ After launching the container, test whether JAX detects ROCm devices as expected

If the setup is successful, the output should list all available ROCm devices.

## Using a ROCm Docker Image and Installing JAX
## 2. Using a ROCm Docker Image and Installing JAX

If you prefer to use the ROCm Ubuntu image or already have a ROCm Ubuntu container, follow these steps to install JAX in the container.

Expand All @@ -53,14 +55,15 @@ After pulling the image, launch a container using this command:

```Bash
> docker run -it -d --network=host --device=/dev/kfd --device=/dev/dri --ipc=host --shm-size 64G --group-add video --cap-add=SYS_PTRACE --security-opt seccomp=unconfined -v $(pwd):/jax_dir --name rocm_jax rocm/dev-ubuntu-22.04:6.3-complete /bin/bash
> docker attach rocm_jax
```

### Step 3: Install a Specific Version of JAX
### Step 3: Install the Latest Version of JAX

Inside the running container, install the required version of JAX with ROCm support using pip:

```Bash
> pip3 install jax==0.4.35 jax-rocm60-plugin==0.4.35 jax-rocm60-pjrt==0.4.35
> pip3 install jax[rocm]
```

### Step 4: Verify the Installed JAX Version
Expand All @@ -77,7 +80,7 @@ jaxlib==0.4.35

### Step 5: Set the `LLVM_PATH` Environment Variable

Explicitly set the `LLVM_PATH` environment variable. This helps XLA find `ld.lld` in the PATH during runtime:
Explicitly set the `LLVM_PATH` environment variable (This helps XLA find `ld.lld` in the PATH during runtime):

```Bash
> export LLVM_PATH=/opt/rocm/llvm
Expand All @@ -88,50 +91,86 @@ Explicitly set the `LLVM_PATH` environment variable. This helps XLA find `ld.lld
Run the following command to verify that ROCm JAX is installed correctly:

```Bash
> python3.10 -c "import jax.numpy as jnp; x = jnp.arange(5); print(x)"
> python3 -c "import jax; print(jax.devices())"
[RocmDevice(id=0), RocmDevice(id=1), RocmDevice(id=2), RocmDevice(id=3)]

> python3 -c "import jax.numpy as jnp; x = jnp.arange(5); print(x)"
[0 1 2 3 4]
```

## Install ROCm Outside the Container or in a Custom Container
## 3. Install JAX On Bare-metal or A Custom Container

Follow these steps if you prefer to install ROCm manually on your host system or in a custom container.

### Step 1: Install Docker
### Installing ROCm Libraries Manually

Ensure Docker is installed on your system. Detailed instructions can be found on the [Docker website](https://docs.docker.com/engine/installation/).
### Step 1: Install ROCm

### Step 2: Build a Custom ROCm JAX Runtime Container
Please follow [ROCm installation guide](https://rocm.docs.amd.com/en/latest/deploy/linux/quick_start.html) to install ROCm on your system.

To build a Docker image for ROCm JAX with your custom configuration, use the provided script. For example, to build an image with Python 3.12, run:
Once installed, verify ROCm installation using:

```Bash
> WORKSPACE=/jax BUILD_TAG=latest ./build/rocm/ci_build.sh --use_clang=true --py_version 3.12
> rocm-smi

========================================== ROCm System Management Interface ==========================================
==================================================== Concise Info ====================================================
Device [Model : Revision] Temp Power Partitions SCLK MCLK Fan Perf PwrCap VRAM% GPU%
Name (20 chars) (Junction) (Socket) (Mem, Compute)
======================================================================================================================
0 [0x74a1 : 0x00] 50.0°C 170.0W NPS1, SPX 131Mhz 900Mhz 0% auto 750.0W 0% 0%
AMD Instinct MI300X
1 [0x74a1 : 0x00] 51.0°C 176.0W NPS1, SPX 132Mhz 900Mhz 0% auto 750.0W 0% 0%
AMD Instinct MI300X
2 [0x74a1 : 0x00] 50.0°C 177.0W NPS1, SPX 132Mhz 900Mhz 0% auto 750.0W 0% 0%
AMD Instinct MI300X
3 [0x74a1 : 0x00] 53.0°C 176.0W NPS1, SPX 132Mhz 900Mhz 0% auto 750.0W 0% 0%
AMD Instinct MI300X
======================================================================================================================
================================================ End of ROCm SMI Log =================================================
```

### Step 3: Launch a ROCm JAX Docker Container
### Step 2: Install the Latest Version of JAX

After successfully building the image, you can launch a container using the following command:
Install the required version of JAX with ROCm support using pip:

```Bash
> docker run -it -d --network=host --device=/dev/kfd --device=/dev/dri --ipc=host --shm-size 64G --group-add video --cap-add=SYS_PTRACE --security-opt seccomp=unconfined -v $(pwd):/jax --name rocm_jax jax-rocm:latest /bin/bash
> pip3 install jax[rocm]
```

### Installing ROCm Libraries Manually
### Step 3: Verify the Installed JAX Version

Check whether the correct version of JAX and its ROCm plugins are installed:

```Bash
> pip3 freeze | grep jax
jax==0.4.35
jax-rocm60-pjrt==0.4.35
jax-rocm60-plugin==0.4.35
jaxlib==0.4.35
```

### Step 4: Set the `LLVM_PATH` Environment Variable

If you prefer to install ROCm libraries yourself, ensure the necessary dependencies are installed. For Ubuntu systems with [AMD's `apt` repositories configured](https://rocm.docs.amd.com/en/latest/deploy/linux/quick_start.html), you can use the following command:
Explicitly set the `LLVM_PATH` environment variable (This helps XLA find `ld.lld` in the PATH during runtime):

```Bash
> sudo apt install miopen-hip hipfft-dev rocrand-dev hipsparse-dev hipsolver-dev \
rccl-dev rccl hip-dev rocfft-dev roctracer-dev hipblas-dev rocm-device-libs
> export LLVM_PATH=/opt/rocm/llvm
```

For environments already set up, including containers, we recommend using the script provided in the repository:
### Step 5: Verify the Installation of ROCm JAX

Run the following command to verify that ROCm JAX is installed correctly:

```Bash
> jax/build/rocm/tools/get_rocm.py
> python3 -c "import jax; print(jax.devices())"
[RocmDevice(id=0), RocmDevice(id=1), RocmDevice(id=2), RocmDevice(id=3)]

> python3 -c "import jax.numpy as jnp; x = jnp.arange(5); print(x)"
[0 1 2 3 4]
```

## Build ROCm JAX from Source
## 4. Build ROCm JAX from Source

Follow these steps to build JAX with ROCm support from source:

Expand All @@ -150,17 +189,15 @@ Run the following command to build the necessary wheels:

```Bash
> python3 ./build/build.py build --wheels=jaxlib,jax-rocm-plugin,jax-rocm-pjrt \
--rocm_version=60 --rocm_path=/opt/rocm-6.3.0
--rocm_version=60 --rocm_path=/opt/rocm-[version]
```

This will generate three wheels:
This will generate three wheels in the `dist/` directory:

* jaxlib (generic, without ROCm support)
* jaxlib (generic, device agnostic library)
* jax-rocm-plugin (ROCm-specific plugin)
* jax-rocm-pjrt (ROCm-specific runtime)

The generated wheels will be located in the `dist/` directory.

### Step 3: Then install custom JAX using:

```Bash
Expand Down
2 changes: 1 addition & 1 deletion docs/developer.md
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ To build with debug information, add the flag `--bazel_options='--copt=/Z7'`.
### Additional notes for building a ROCM `jaxlib` for AMD GPUs

For detailed instructions on building `jaxlib` with ROCm support, refer to the official guide:
[JAX on ROCm](https://github.com/jax-ml/jax/blob/main/build/rocm/README.md)
[Build ROCm JAX from Source](https://github.com/jax-ml/jax/blob/main/build/rocm/README.md)

## Managing hermetic Python

Expand Down

0 comments on commit a3e251e

Please sign in to comment.