Skip to content

Commit

Permalink
cuex: implement custom_dce rule (#88)
Browse files Browse the repository at this point in the history
* wip

* wip

* working

* done

* remove print

* clean

* clean dce code

* add missing import
  • Loading branch information
mariogeiger authored Mar 1, 2025
1 parent bd5c605 commit 9884278
Showing 1 changed file with 45 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import jax.extend
import jax.lax
import jax.numpy as jnp
from jax.interpreters import ad, batching, mlir, xla
from jax.interpreters import ad, batching, mlir, partial_eval, xla

import cuequivariance as cue
from cuequivariance_jax.primitives.primitives_utils import reshape
Expand Down Expand Up @@ -562,6 +562,49 @@ def flatten_index(x: jax.Array) -> jax.Array:
return outputs, (0,) * len(outputs)


def segmented_polynomial_dce(
    used_outputs: list[bool],
    eqn: jax.extend.core.JaxprEqn,
) -> tuple[list[bool], jax.extend.core.JaxprEqn | None]:
    """Dead-code-elimination rule for the segmented polynomial primitive.

    Args:
        used_outputs: One flag per entry of ``eqn.outvars``; a truthy flag
            marks an output that must be kept.
        eqn: The equation to prune. Its params must contain ``polynomial``,
            ``buffer_index`` and ``outputs_shape_dtype``.

    Returns:
        A pair ``(used_invars, new_eqn)``.

        * When no output is used and the equation has no effects, the result
          is ``([False] * len(eqn.invars), None)``, i.e. the equation is
          dropped entirely.
        * Otherwise ``used_invars`` is the concatenation of the polynomial's
          used-input flags and the flags returned by ``_dce_helper`` for the
          remaining operands, and ``new_eqn`` is a copy of ``eqn`` restricted
          to the flagged invars/outvars with updated params.
    """
    assert len(used_outputs) == len(eqn.outvars)

    params = eqn.params
    poly: cue.SegmentedPolynomial = params["polynomial"]
    indexing = params["buffer_index"]
    out_specs = params["outputs_shape_dtype"]

    # Nothing is consumed and nothing observable happens: remove the equation.
    if not (any(used_outputs) or eqn.effects):
        return [False] * len(eqn.invars), None

    # Operands beyond the polynomial's own inputs are handed to _dce_helper
    # (presumably the index arrays referenced by buffer_index -- confirm).
    n_poly_inputs = poly.num_inputs
    n_extra_operands = len(eqn.invars) - n_poly_inputs

    # Restrict the polynomial to the requested outputs, then ask it which of
    # its inputs are still needed.
    poly = poly.compute_only(used_outputs)
    input_mask: list[bool] = poly.used_inputs()

    extra_mask, indexing = _dce_helper(
        input_mask, used_outputs, indexing, n_extra_operands
    )
    invar_mask = input_mask + extra_mask

    def keep(items, mask):
        # Keep the items whose paired flag is truthy (zip truncates to the
        # shorter of the two sequences).
        return [item for item, flag in zip(items, mask) if flag]

    kept_invars = keep(eqn.invars, invar_mask)
    kept_outvars = keep(eqn.outvars, used_outputs)
    new_params = {
        **params,
        "polynomial": poly.select_buffers(input_mask + used_outputs),
        "buffer_index": indexing,
        "outputs_shape_dtype": tuple(keep(out_specs, used_outputs)),
    }

    new_eqn = jax.extend.core.JaxprEqn(
        kept_invars,
        kept_outvars,
        eqn.primitive,
        new_params,
        eqn.effects,
        eqn.source_info,
        eqn.ctx,
    )

    return invar_mask, new_eqn


# Attach the abstract-evaluation rule and the implementation to the primitive.
# The implementation is delegated to xla.apply_primitive, bound to this
# primitive via partial.
segmented_polynomial_p.def_abstract_eval(segmented_polynomial_abstract_eval)
segmented_polynomial_p.def_impl(partial(xla.apply_primitive, segmented_polynomial_p))
mlir.register_lowering(
Expand All @@ -583,3 +626,4 @@ def flatten_index(x: jax.Array) -> jax.Array:
# Register the primitive's rules in JAX's per-primitive tables: JVP and
# transpose (ad), batching, and dead-code elimination (partial_eval).
ad.primitive_jvps[segmented_polynomial_p] = segmented_polynomial_jvp
ad.primitive_transposes[segmented_polynomial_p] = segmented_polynomial_transpose
batching.primitive_batchers[segmented_polynomial_p] = segmented_polynomial_batching
partial_eval.dce_rules[segmented_polynomial_p] = segmented_polynomial_dce

0 comments on commit 9884278

Please sign in to comment.