Skip to content

Commit

Permalink
Correctly check for nested tuple in map_func_over_tuple_of_tuples
Browse files Browse the repository at this point in the history
  • Loading branch information
Michael-T-McCann committed Feb 7, 2024
1 parent 13ef0a8 commit eebda28
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 4 deletions.
6 changes: 3 additions & 3 deletions scico/numpy/_wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@

import jax.numpy as jnp

import scico.numpy as snp

from ._blockarray import BlockArray


Expand Down Expand Up @@ -83,9 +85,7 @@ def mapped(*args, **kwargs):

map_arg_val = bound_args.arguments.pop(map_arg_name)

if not isinstance(map_arg_val, tuple) or not all(
isinstance(x, tuple) for x in map_arg_val
): # not nested tuple
if not snp.util.is_nested(map_arg_val): # not nested tuple
return func(*args, **kwargs) # no mapping

# map
Expand Down
10 changes: 9 additions & 1 deletion scico/test/numpy/test_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,9 +248,17 @@ def test_ufunc_conj():
def test_create_zeros():
A = snp.zeros(2)
assert np.all(A == 0)
assert isinstance(A, jax.Array)

A = snp.zeros((2,))
assert isinstance(A, jax.Array)

A = snp.zeros(((2,), (2,)))
assert all(snp.all(A == 0))
assert isinstance(A, snp.BlockArray)

A = snp.zeros(())
assert isinstance(A, jax.Array) # from issue 499


def test_create_ones():
Expand All @@ -261,7 +269,7 @@ def test_create_ones():
assert all(snp.all(A == 1))


def test_create_zeros():
def test_create_empty():
A = snp.empty(2)
assert np.all(A == 0)

Expand Down

0 comments on commit eebda28

Please sign in to comment.