Skip to content

Commit

Permalink
[ROCm] Implement RNN support
Browse files Browse the repository at this point in the history
  • Loading branch information
Ruturaj4 committed Dec 16, 2024
1 parent 5cda053 commit 33ca0e5
Show file tree
Hide file tree
Showing 7 changed files with 446 additions and 125 deletions.
49 changes: 48 additions & 1 deletion jax/experimental/rnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,50 @@ def init_lstm_weight(rng: PRNGKeyArray, input_size: int, hidden_size: int,
rng, shape=(param_count,), dtype=jnp.float32, minval=-k, maxval=k)


def swap_lstm_gates(weights, input_size, hidden_size, num_layers, bidirectional):
"""
Swaps the weights for the input and output gates in an LSTM model's parameters.
This function is specifically designed for compatibility with MIOpen, where the gate ordering
differs from CuDNN. In CuDNN, the gates are ordered as:
- 0: Forget gate (f)
- 1: Input gate (i)
- 2: New memory gate (g)
- 3: Output gate (o)
However, in MIOpen, the ordering of the new memory (g) and output (o) gates is swapped:
- 0: Forget gate (f)
- 1: Input gate (i)
- 2: Output gate (o)
- 3: New memory gate (g)
This function rearranges the weights and biases for the gates to ensure that the model
operates correctly with MIOpen by swapping the third (new memory) and fourth (output) gates.
"""
weights = jnp.asarray(weights)
flat_shapes = _get_params_shapes_in_lstm(input_size, hidden_size, num_layers, bidirectional)
num_directions = 2 if bidirectional else 1

w_offsets = 0
for l in range(num_layers):
for direction in range(num_directions):
# Iterate through all weight and bias gate names to swap gates in both weights and biases.
for gate_name in ["W_ih", "W_hh", "b_ih", "b_hh"]:
shape = flat_shapes.pop(0)
num_elems = math.prod(shape)
matrix = weights[w_offsets:w_offsets + num_elems].reshape(shape)

# Swap between the input and output gates (third and fourth gates).
gates = jnp.split(matrix, 4, axis=0)
swapped_matrix = jnp.concatenate([gates[0], gates[1], gates[3], gates[2]], axis=0)

# Update the weights with swapped matrix.
weights = weights.at[w_offsets:w_offsets + num_elems].set(swapped_matrix.flatten())
w_offsets += num_elems

return weights


def unpack_lstm_weights(
weights: Array, input_size: int, hidden_size: int, num_layers: int,
bidirectional: bool
Expand Down Expand Up @@ -437,7 +481,8 @@ def _gpu_lowering_strip_tf32(fn, *args, cudnn_allow_tf32, **kw):
rnn_fwd_p.def_impl(partial(xla.apply_primitive, rnn_fwd_p))
rnn_fwd_p.def_abstract_eval(rnn_abstract_eval)
if gpu_rnn:
mlir.register_lowering(rnn_fwd_p, gpu_rnn.cudnn_rnn_lowering, platform='cuda')
mlir.register_lowering(rnn_fwd_p, gpu_rnn.cudnn_rnn_fwd_lowering, platform='cuda')
mlir.register_lowering(rnn_fwd_p, gpu_rnn.miopen_rnn_fwd_lowering, platform='rocm')


def lstm_bwd(input_size: int, hidden_size: int, num_layers: int, dropout: float,
Expand Down Expand Up @@ -481,5 +526,7 @@ def rnn_bwd_abstract_eval(dy_aval, dhn_aval, dcn_aval, x_aval, h0_aval, c0_aval,
if gpu_rnn:
mlir.register_lowering(
rnn_bwd_p, gpu_rnn.cudnn_rnn_bwd_lowering, platform='cuda')
mlir.register_lowering(
rnn_bwd_p, gpu_rnn.miopen_rnn_bwd_lowering, platform='rocm')

lstm.defvjp(lstm_fwd, lstm_bwd)
Loading

0 comments on commit 33ca0e5

Please sign in to comment.