Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

CWL now compatible with partially nondifferentiable body functions. #548

Merged
merged 5 commits into from
Oct 9, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 15 additions & 4 deletions equinox/debug/_announce_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def announce_transform(

def _impl(*x, stack, name, intermediates, announce):
del intermediates
stack = stack + ("impl,")
stack = stack + ("impl",)
announce(name + ":".join(stack))
return x

Expand All @@ -81,9 +81,20 @@ def _jvp(p, t, *, stack, name, intermediates, announce):
p_out = announce_jaxpr_p.bind(
*p, stack=p_stack, name=name, intermediates=intermediates, announce=announce
)
t_out = announce_jaxpr_p.bind(
*t, stack=t_stack, name=name, intermediates=intermediates, announce=announce
)
t_nonzero = [ti for ti in t if type(ti) is not ad.Zero]
if len(t_nonzero) > 0:
t_nonzero_out = announce_jaxpr_p.bind(
*t_nonzero,
stack=t_stack,
name=name,
intermediates=intermediates,
announce=announce,
)
else:
t_nonzero_out = []
t_nonzero_out = iter(t_nonzero_out)
t_out = [ti if type(ti) is ad.Zero else next(t_nonzero_out) for ti in t]
assert next(t_nonzero_out, None) is None
return p_out, t_out


Expand Down
186 changes: 158 additions & 28 deletions equinox/internal/_loop/checkpointed.py
Original file line number Diff line number Diff line change
Expand Up @@ -660,6 +660,10 @@ def _is_symbolic_zero_or_none(x):
return isinstance(x, jax.custom_derivatives.SymbolicZero) or x is None


def _not_symbolic_zero(x):
return not isinstance(x, jax.custom_derivatives.SymbolicZero)


def _materialise_symbolic_zero(is_zero, ct_x, x):
if is_zero:
assert ct_x is None
Expand Down Expand Up @@ -753,7 +757,7 @@ def _checkpointed_while_loop_bwd(
loading values from checkpoints and using treeverse to toggle between forward and
backward steps.
"""
_, perturb_body_fun = perturbed
perturb_init_val, perturb_body_fun = perturbed
_, body_fun = vjp_arg
grad_final_body_fun = jtu.tree_map(
lambda x, p: jnp.zeros_like(x) if p else None, body_fun, perturb_body_fun
Expand Down Expand Up @@ -915,46 +919,166 @@ def _to_vjp(_diff_body, _diff_val):
init_step_val, init_step_grad_val, init_index, checkpoints
)

# Fixed-point iteration: promote symbolic zeros to materialised zeros if necessary.
# Do it inside the dynamic context of a `jax.eval_shape` so that it's as cheap as
#
# Okay, this next bit is a bit complicated.
# We need to figure out exactly which parts of the carry that we're going to compute
# cotangents for.
# There are three interacting pieces:
# (1) Only some inputs even need their cotangents computed (those that are marked
# as perturbed). We should avoid computing intermediate cotangents that aren't
# needed.
# (2) Some cotangents may be symbolic zeros. We should avoid materialising any that
# don't need to be materialised.
# (3) Some parts of the body function may not be differentiable. (E.g. they call to
# `eqxi.nondifferentiable` or `eqxi.nondifferentiable_backward`.) In particular,
# some inexact arrays may be marked with these. (Anything else is automatically
# nondifferentiable.) As these raise an error eagerly, we can't just take an
# approach of "differentiate everything, then DCE".
#
# Our strategy is as follows. First, consider those inputs that are perturbed, and
# iterate to find all the carries that become perturbed as a result. We obtain a
# fixed point.
    # * This mimics what JAX does in JVP rules for scans/whiles. And in fact we also
# track perturbation information forward via JVPs (which input tangents affect
# which output tangents), rather than backward via VJPs (which output cotangents
# affect which input cotangents), as this is what JAX does here and we'd like to
# be consistent. (Also, it seems like there might be a catch-22 trying to do this
# with VJPs -- how do you decide which inputs to perturb in your VJP?) This is
# basically handling criteria (1).
    # * This also means that we're perturbing the minimal amount possible. We don't even
# run the JVP rule for something that isn't differentiated. This is part of
# addressing criteria (3).
#
# Second, we ignore all cotangents (of our forward output) that aren't on a
# perturbed carry. Clearly they're not going to affect anything. So we switch them
# to symbolic zeros.
# * This is part of addressing criteria (3) -- we avoid propagating cotangents
# that we don't need to. Maybe there's an `eqxi.nondifferentiable_backward` in
# there.
#
# Third, we find another fixed point, this time for the symbolic zeros. Track which
# cotangents affect which cotangents, and materialise the minimal amount possible.
# * This addresses criteria (2).
# * Note that symbolic zeros are doing double-duty here. We avoid materialising them
# for simple efficiency reasons, sure -- but we also want to keep as many
# unmaterialised as possible, in case there's an `eqxi.nondifferentiable_backward`
# in there, which only accepts symbolic zero cotangents.
#
# Phew! At that point, job done: we know which cotangents we're propagating and so
# we can run our loop.
#
    # The first fixed point basically mimics what JAX does in the JVP rule for a while
# or scan, figuring out which values are perturbed. The second fixed point basically
    # mimics what JAX does in the transpose rule of a scan, figuring out which
# cotangents are symbolic zeros.
#

# Do this inside the dynamic context of a `jax.eval_shape` so that it's as cheap as
# possible, without extra tracing.
def _resolve_perturb_val():
init_val_buffers = tree_at(
buffers(_is_none), init_val, filled_buffers, is_leaf=_is_none
)
perturb_val = perturb_init_val
assert jtu.tree_structure(init_val_buffers) == jtu.tree_structure(perturb_val)

while True:
            # `body_fun` is included so that we also track perturbations on that.
perturb_pair = (perturb_body_fun, perturb_val)
dynamic, static = partition((body_fun, init_val_buffers), perturb_pair)
new_perturb_val = sentinel = object()

@jax.custom_jvp
def _record_symbolic_zeros(_dynamic_out):
return _dynamic_out

def _record_symbolic_zeros_jvp(primals, tangents):
(primals,) = primals
(tangents,) = tangents
nonlocal new_perturb_val
new_perturb_val = jtu.tree_map(_not_symbolic_zero, tangents)
return primals, tangents

_record_symbolic_zeros.defjvp(
_record_symbolic_zeros_jvp, symbolic_zeros=True
)

def _to_linearize(_dynamic):
_body_fun, _val = combine(_dynamic, static)
_out = _body_fun(_val)
_dynamic_out, _static_out = partition(_out, is_inexact_array)
_dynamic_out = _record_symbolic_zeros(_dynamic_out)
_out = combine(_dynamic_out, _static_out)
return _out

# Not `jax.jvp`, so as not to error if `body_fun` has any `custom_vjp`s.
jax.linearize(_to_linearize, dynamic)
assert new_perturb_val is not sentinel
assert jtu.tree_structure(
perturb_val, is_leaf=_is_none
) == jtu.tree_structure(new_perturb_val, is_leaf=_is_none)
new_perturb_val = jtu.tree_map(
lambda x, y: False if (x is False and y is None) else y,
perturb_val,
new_perturb_val,
)
assert jtu.tree_structure(perturb_val) == jtu.tree_structure(
new_perturb_val
)
if tree_equal(perturb_val, new_perturb_val):
jtu.tree_map(_assert_bool, perturb_val)
if getattr(typing, "TESTING", False):
print("perturb_val", perturb_val)
return Static(perturb_val)
else:
perturb_val = jtu.tree_map(operator.or_, perturb_val, new_perturb_val)

# Find a fixed point of which values we need to perturb. (Not the same as whether
# or not they have symbolic zero cotangents! All unperturbed values must have
# symbolic zero cotangents, but some perturbed values may have these as well.)
perturb_val = jax.eval_shape(_resolve_perturb_val).value

# Do this inside the dynamic context of a `jax.eval_shape` so that it's as cheap as
# possible, without extra tracing.
def _resolve_symbolic_zeros(grad_val):
symbolic_zero_gradient = jtu.tree_map(_is_none, grad_val, is_leaf=_is_none)
new_symbolic_zero_gradient = None
init_val_buffers = tree_at(
buffers(_is_none), init_val, filled_buffers, is_leaf=_is_none
)
dynamic_init_val, static_init_val = partition(
init_val_buffers, is_inexact_array
)
dynamic_init_val, static_init_val = partition(init_val_buffers, perturb_val)

# Some hackery to extract which inputs still have symbolic zero cotangents.
@jax.custom_vjp
def _record_symbolic_zero(_dynamic_val):
return _dynamic_val
while True:
new_symbolic_zero_gradient = sentinel = object()

def _record_symbolic_zero_fwd(_dynamic_val):
return jtu.tree_map(_get_value, _dynamic_val), None
# Some hackery to extract which inputs still have symbolic zero cotangents.
@jax.custom_vjp
def _record_symbolic_zero(_dynamic_val):
return _dynamic_val

def _record_symbolic_zero_bwd(_, grad_dynamic_val):
nonlocal new_symbolic_zero_gradient
new_symbolic_zero_gradient = jtu.tree_map(
_is_symbolic_zero_or_none, grad_dynamic_val, is_leaf=_is_none
)
return (grad_dynamic_val,)
def _record_symbolic_zero_fwd(_dynamic_val):
return jtu.tree_map(_get_value, _dynamic_val), None

_record_symbolic_zero.defvjp(
_record_symbolic_zero_fwd, _record_symbolic_zero_bwd, symbolic_zeros=True
)
def _record_symbolic_zero_bwd(_, grad_dynamic_val):
nonlocal new_symbolic_zero_gradient
new_symbolic_zero_gradient = jtu.tree_map(
_is_symbolic_zero_or_none, grad_dynamic_val, is_leaf=_is_none
)
return (grad_dynamic_val,)

_record_symbolic_zero.defvjp(
_record_symbolic_zero_fwd,
_record_symbolic_zero_bwd,
symbolic_zeros=True,
)

def _to_vjp(_dynamic_val):
_dynamic_val = _record_symbolic_zero(_dynamic_val)
val = combine(_dynamic_val, static_init_val)
return filter(body_fun(val), symbolic_zero_gradient, inverse=True)
def _to_vjp(_dynamic_val):
_dynamic_val = _record_symbolic_zero(_dynamic_val)
val = combine(_dynamic_val, static_init_val)
return filter(body_fun(val), symbolic_zero_gradient, inverse=True)

while True:
_, vjp_fn = jax.vjp(_to_vjp, dynamic_init_val)
vjp_fn(grad_val) # get new_symbolic_zero_gradient via nonlocal
vjp_fn(grad_val)
assert new_symbolic_zero_gradient is not sentinel
if tree_equal(symbolic_zero_gradient, new_symbolic_zero_gradient):
jtu.tree_map(_assert_bool, symbolic_zero_gradient)
if getattr(typing, "TESTING", False):
Expand All @@ -971,6 +1095,12 @@ def _to_vjp(_dynamic_val):
init_val_buffers,
)

grad_final_val = filter(grad_final_val, perturb_val)
# If all provided cotangents end up being irrelevant -- i.e. the perturbed inputs
# depend only on those outputs which have symbolic zero cotangents -- then we can
# skip the whole computation.
if len(jtu.tree_leaves(grad_final_val)) == 0:
return jtu.tree_map(lambda _: None, (grad_final_val, body_fun))
symbolic_zero_gradient = jax.eval_shape(
_resolve_symbolic_zeros, grad_final_val
).value
Expand Down
27 changes: 13 additions & 14 deletions equinox/internal/_loop/common.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Union
from typing import Any

import jax
import jax.core
Expand All @@ -8,7 +8,7 @@
import jax.lax as lax
import jax.numpy as jnp
import jax.tree_util as jtu
from jaxtyping import Array, Bool, Shaped
from jaxtyping import Array, Bool

from ..._filters import combine, is_array, partition
from ..._module import field, Module
Expand Down Expand Up @@ -248,7 +248,8 @@ def _maybe_set(pred, xs, x, i, *, kwargs, makes_false_steps):


class _Buffer(Module):
_array: Union[Shaped[Array, "..."], "_Buffer"]
# annotation removed because beartype can't handle the forward reference.
_array: Any # Union[Shaped[Array, "..."], _Buffer]
_pred: Bool[Array, ""]
_tag: object = field(static=True)
_makes_false_steps: bool = field(static=True)
Expand Down Expand Up @@ -321,17 +322,11 @@ def _unwrap_buffers(x):
return x


# Work around JAX issue #15676.
# This issue arises with both JVP tracing and make_jaxpr tracing. The former can
# be handled with a custom_jvp, but the latter cannot. So we need to just call
# `jnp.array` instead.
# (The scrape retained both the old `custom_jvp`-decorated definition and this
# replacement assignment; only the replacement is kept, since the assignment
# would shadow the decorated function anyway.)
fixed_asarray = jnp.array


def common_rewrite(cond_fun, body_fun, init_val, max_steps, buffers, makes_false_steps):
Expand Down Expand Up @@ -416,6 +411,10 @@ def unwrap_and_select(leaf, leaf2):
step, pred, _, val = val
buffer_val = _wrap_buffers(val, pred, tag)
buffer_val2 = body_fun(buffer_val)
# Needed to work with `disable_jit`, as then we lose the automatic
# ArrayLike->Array cast provided by JAX's while loops.
# The input `val` is already cast to Array below, so this matches that.
buffer_val2 = jtu.tree_map(fixed_asarray, buffer_val2)
# Strip `.named_shape`; c.f. Diffrax issue #246
struct = jax.eval_shape(lambda: buffer_val)
struct2 = jax.eval_shape(lambda: buffer_val2)
Expand Down
2 changes: 1 addition & 1 deletion equinox/nn/_normalisation.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ class LayerNorm(Module):

"""

shape: tuple[int] = field(static=True)
shape: tuple[int, ...] = field(static=True)
eps: float = field(static=True)
use_weight: bool = field(static=True)
use_bias: bool = field(static=True)
Expand Down
Loading
Loading