Skip to content

Commit

Permalink
eqxi.while_loop now has better error messages for mismatched structures
Browse files Browse the repository at this point in the history
  • Loading branch information
patrick-kidger committed Dec 11, 2023
1 parent 27bc1a4 commit 1e2d8c2
Showing 1 changed file with 22 additions and 3 deletions.
25 changes: 22 additions & 3 deletions equinox/internal/_loop/common.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import itertools as it
from typing import Any, TYPE_CHECKING, Union

import jax
Expand All @@ -12,6 +13,7 @@

from ..._filters import combine, is_array, partition
from ..._module import field, Module
from ..._pretty_print import tree_pformat
from ..._tree import tree_at, tree_equal
from ..._unvmap import unvmap_any
from .._nontraceable import nonbatchable
Expand Down Expand Up @@ -456,10 +458,27 @@ def unwrap_and_select(leaf, leaf2):
# Strip `.named_shape`; c.f. Diffrax issue #246
struct = jax.eval_shape(lambda: buffer_val)
struct2 = jax.eval_shape(lambda: buffer_val2)
struct = jtu.tree_map(lambda x: (x.shape, x.dtype), struct)
struct2 = jtu.tree_map(lambda x: (x.shape, x.dtype), struct2)
struct = jtu.tree_map(lambda x: jax.ShapeDtypeStruct(x.shape, x.dtype), struct)
struct2 = jtu.tree_map(
lambda x: jax.ShapeDtypeStruct(x.shape, x.dtype), struct2
)
if not tree_equal(struct, struct2):
raise ValueError("`body_fun` must have the same input and output structure")
string = tree_pformat(struct, struct_as_array=True)
string2 = tree_pformat(struct2, struct_as_array=True)
out = []
for line, line2 in it.zip_longest(
string.split("\n"), string2.split("\n"), fillvalue=""
):
if line == line2:
out.append(" " + line)
else:
out.append("- " + line)
out.append("+ " + line2)
out = "\n".join(out)
raise ValueError(
"`body_fun` must have the same input and output structure. Difference "
"is:\n" + out
)
val2 = jtu.tree_map(
unwrap_and_select, buffer_val, buffer_val2, is_leaf=is_our_buffer
)
Expand Down

0 comments on commit 1e2d8c2

Please sign in to comment.