style: Reformat code
aseyboldt committed Feb 13, 2025
1 parent db139d6 commit 2e347ee
Showing 2 changed files with 56 additions and 71 deletions.
python/nutpie/normalizing_flow.py (81 changes: 37 additions & 44 deletions)
@@ -9,7 +9,7 @@
import flowjax.distributions
import flowjax.flows
import numpy as np
-from paramax import Parameterize, unwrap
+from paramax import Parameterize


def _generate_sequences(k, r_vals):
@@ -35,6 +35,7 @@ def _generate_sequences(k, r_vals):
all_sequences.append(sequences)
return np.concatenate(all_sequences, axis=0)


def _max_run_length(seq):
"""
Given a 1D boolean NumPy array 'seq', compute the maximum run length of consecutive
@@ -68,6 +69,7 @@ def _max_run_length(seq):
run_lengths = np.diff(boundaries)
return int(run_lengths.max())


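Note: the body of _max_run_length is mostly elided by the diff view, but the visible tail (boundaries, np.diff, run_lengths.max()) points at the standard boundary-diff trick: mark the indices where the value flips, and the gaps between flips are the run lengths. A minimal standalone version of that trick for runs of True values (a sketch, not the exact elided code):

    import numpy as np

    def max_true_run_length(seq):
        # Pad with False so every run of True has a start and an end edge.
        padded = np.concatenate(([False], seq, [False]))
        # diff of the 0/1 view is +1 at False->True edges, -1 at True->False.
        edges = np.diff(padded.astype(int))
        starts = np.flatnonzero(edges == 1)
        ends = np.flatnonzero(edges == -1)
        if len(starts) == 0:
            return 0  # no True anywhere
        return int((ends - starts).max())

    assert max_true_run_length(np.array([True, True, False, True])) == 2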
def _filter_sequences(sequences, m):
"""
Filter a 2D NumPy boolean array 'sequences' (each row a binary sequence) so that
@@ -103,7 +105,9 @@ def _generate_permutations(rng, n_dim, n_layers, max_run=3):
all_sequences = _generate_sequences(n_layers, r)
valid_sequences = _filter_sequences(all_sequences, max_run)

-    valid_sequences = np.repeat(valid_sequences, n_dim // len(valid_sequences) + 1, axis=0)
+    valid_sequences = np.repeat(
+        valid_sequences, n_dim // len(valid_sequences) + 1, axis=0
+    )
rng.shuffle(valid_sequences, axis=0)
is_in_first = valid_sequences[:n_dim]
rng = np.random.default_rng(42)
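Note: the visible tail of _generate_permutations tiles the filtered sequences with np.repeat until there are at least n_dim rows, shuffles the rows, and keeps the first n_dim, so row i says, per layer, whether dimension i lands in the untransformed block. A toy run of that pattern (the sequence values are invented for illustration):

    import numpy as np

    rng = np.random.default_rng(0)
    n_dim, n_layers = 5, 4
    valid_sequences = np.array(
        [[True, False, True, False], [False, True, False, True]]
    )
    # Tile rows until there are at least n_dim of them, then shuffle and slice.
    valid_sequences = np.repeat(
        valid_sequences, n_dim // len(valid_sequences) + 1, axis=0
    )
    rng.shuffle(valid_sequences, axis=0)
    is_in_first = valid_sequences[:n_dim]  # shape (n_dim, n_layers)
    assert is_in_first.shape == (n_dim, n_layers)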
@@ -149,7 +153,6 @@ def __init__(
Likewise `out_features` can also be a string `"scalar"`, in which case the
output from the layer will have shape `()`.
"""
-        #dtype = default_floating_dtype() if dtype is None else dtype
dtype = np.float32 if dtype is None else dtype
wkey, bkey = jax.random.split(key, 2)
in_features_ = 1 if in_features == "scalar" else in_features
@@ -161,7 +164,9 @@
wshape = (out_features_, in_features_)
self.weight = eqx.nn._misc.default_init(wkey, wshape, dtype, lim)
bshape = (out_features_,)
-        self.bias = eqx.nn._misc.default_init(bkey, bshape, dtype, lim) if use_bias else None
+        self.bias = (
+            eqx.nn._misc.default_init(bkey, bshape, dtype, lim) if use_bias else None
+        )

self.in_features = in_features
self.out_features = out_features
@@ -205,6 +210,7 @@ def __call__(self, x: jax.Array, *, key=None) -> jax.Array:
x = jnp.squeeze(x)
return x


class FactoredMLP(eqx.Module, strict=True):
"""Standard Multi-Layer Perceptron; also known as a feed-forward network.
@@ -268,7 +274,6 @@ def __init__(
Likewise `out_size` can also be a string `"scalar"`, in which case the
output from the module will have shape `()`.
"""
-        #dtype = default_floating_dtype() if dtype is None else dtype
keys = jax.random.split(key, depth + 1)
layers = []
if isinstance(width_size, int):
@@ -290,9 +295,7 @@
layers.append((U, K))
else:
k = width_size[0]
-            layers.append(
-                Linear(in_size, k, use_bias, dtype=dtype, key=keys[0])
-            )
+            layers.append(Linear(in_size, k, use_bias, dtype=dtype, key=keys[0]))
activations.append(eqx.filter_vmap(lambda: activation, axis_size=k)())

for i in range(depth - 1):
@@ -331,9 +334,6 @@ def __init__(
# In case `activation` or `final_activation` are learnt, then make a separate
# copy of their weights for every neuron.
self.activation = tuple(activations)
-        #self.activation = eqx.filter_vmap(
-        #    eqx.filter_vmap(lambda: activation), axis_size=depth
-        #)()
if out_size == "scalar":
self.final_activation = final_activation
else:
@@ -344,7 +344,7 @@
self.use_final_bias = use_final_bias

@jax.named_scope("eqx.nn.MLP")
-    def __call__(self, x: jax.Array, *, key = None) -> jax.Array:
+    def __call__(self, x: jax.Array, *, key=None) -> jax.Array:
"""**Arguments:**
- `x`: A JAX array with shape `(in_size,)`. (Or shape `()` if
@@ -382,7 +382,6 @@ def __call__(self, x: jax.Array, *, key = None) -> jax.Array:
return x



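Note: judging from the visible branches of FactoredMLP, when width_size is a tuple a hidden layer's weight is stored as a pair (U, K) and applied as a product instead of one dense matrix, a low-rank factorization that cuts parameters and FLOPs when the inner dimension k is much smaller than the layer width. A plain-JAX sketch of why that helps (sizes invented; not the class above):

    import jax
    import jax.numpy as jnp

    m, k = 512, 16
    key_u, key_k = jax.random.split(jax.random.PRNGKey(0))
    U = jax.random.normal(key_u, (m, k)) / jnp.sqrt(k)
    K = jax.random.normal(key_k, (k, m)) / jnp.sqrt(m)

    def factored_apply(x):
        # Same result shape as x @ W.T with W = U @ K, but W is never
        # materialized: 2*m*k parameters instead of m*m.
        return U @ (K @ x)

    y = factored_apply(jnp.ones(m))
    assert y.shape == (m,)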
def make_mvscale(key, n_dim, size, randomize_base=False):
def make_single_hh(key, idx):
key1, key2 = jax.random.split(key)
@@ -399,7 +398,10 @@ def make_single_hh(key, idx):
else:
indices = [val % n_dim for val in range(size)]

-    return bijections.Chain([make_single_hh(key, idx) for key, idx in zip(keys, indices)])
+    return bijections.Chain(
+        [make_single_hh(key, idx) for key, idx in zip(keys, indices)]
+    )


def make_hh(key, n_dim, size, randomize_base=False):
def make_single_hh(key, idx):
@@ -415,19 +417,16 @@ def make_single_hh(key, idx):
else:
indices = [val % n_dim for val in range(size)]

-    return bijections.Chain([make_single_hh(key, idx) for key, idx in zip(keys, indices)])
+    return bijections.Chain(
+        [make_single_hh(key, idx) for key, idx in zip(keys, indices)]
+    )


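Note: make_hh and make_mvscale chain Householder-style reflections. A Householder map H = I - 2 v vᵀ / (vᵀ v) is orthogonal and its own inverse, so it is a cheap, norm-preserving linear bijection that can be applied as a vector operation without ever building the matrix. A standalone numerical check of those properties (a sketch of the general construction, not the elided make_single_hh):

    import jax
    import jax.numpy as jnp

    def householder(x, v):
        # Reflect x across the hyperplane orthogonal to v.
        v = v / jnp.linalg.norm(v)
        return x - 2.0 * jnp.dot(v, x) * v

    v = jax.random.normal(jax.random.PRNGKey(0), (4,))
    x = jax.random.normal(jax.random.PRNGKey(1), (4,))
    y = householder(x, v)
    # Orthogonal: norms are preserved; involutive: applying twice is identity.
    assert jnp.allclose(jnp.linalg.norm(y), jnp.linalg.norm(x), atol=1e-5)
    assert jnp.allclose(householder(y, v), x, atol=1e-5)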
def make_elemwise_trafo(key, n_dim, *, count=1):
def make_elemwise(key, loc):
key1, key2 = jax.random.split(key)
-        scale = Parameterize(
-            lambda x: x + jnp.sqrt(1 + x**2),
-            jnp.zeros(())
-        )
-        theta = Parameterize(
-            lambda x: x + jnp.sqrt(1 + x**2),
-            jnp.zeros(())
-        )
+        scale = Parameterize(lambda x: x + jnp.sqrt(1 + x**2), jnp.zeros(()))
+        theta = Parameterize(lambda x: x + jnp.sqrt(1 + x**2), jnp.zeros(()))

affine = bijections.AsymmetricAffine(
loc,
@@ -459,6 +458,7 @@ def make(key):
make_affine = eqx.filter_vmap(make, axis_size=n_dim)(keys)
return bijections.Vmap(make_affine, in_axes=eqx.if_array(0))


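Note: Parameterize(lambda x: x + jnp.sqrt(1 + x**2), jnp.zeros(())) keeps scale and theta positive without clipping: x ↦ x + sqrt(1 + x²) is a smooth bijection from ℝ onto (0, ∞) with value exactly 1 at x = 0, so the zero initialization corresponds to a unit scale. Its inverse has the closed form y ↦ (y² − 1) / (2y), which makes the claim easy to verify:

    import jax.numpy as jnp

    def to_positive(x):
        # Smooth bijection R -> (0, inf); equals 1 at x = 0.
        return x + jnp.sqrt(1 + x**2)

    def from_positive(y):
        # Closed-form inverse of to_positive.
        return (y**2 - 1) / (2 * y)

    xs = jnp.linspace(-5.0, 5.0, 11)
    ys = to_positive(xs)
    assert jnp.isclose(to_positive(0.0), 1.0)
    assert bool((ys > 0).all())
    assert jnp.allclose(from_positive(ys), xs, atol=1e-4)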
def make_elemwise_trafo_(key, n_dim, *, count=1):
def make_elemwise(key):
scale = Parameterize(
@@ -497,6 +497,7 @@ def make(key):
make_affine = eqx.filter_vmap(make)(keys)
return bijections.Vmap(make_affine())


def make_coupling(key, dim, n_untransformed, **kwargs):
n_transformed = dim - n_untransformed

@@ -510,10 +511,12 @@ def make_coupling(key, dim, n_untransformed, **kwargs):
else:
nn_width = 2 * dim

-    transformer = bijections.Chain([
-        make_elemwise_trafo(key, n_transformed, count=3),
-        mvscale,
-    ])
+    transformer = bijections.Chain(
+        [
+            make_elemwise_trafo(key, n_transformed, count=3),
+            mvscale,
+        ]
+    )

def make_mlp(out_size):
if isinstance(nn_width, tuple):
@@ -541,6 +544,7 @@ def make_mlp(out_size):
**kwargs,
)


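Note: make_coupling wires these pieces into a coupling bijection: the first n_untransformed coordinates pass through unchanged and feed a conditioner MLP whose output parameterizes the transformer chain applied elementwise to the remaining n_transformed coordinates, which keeps the Jacobian triangular. A dependency-free sketch of the mechanics (an affine transformer and a toy conditioner stand in for the real ones):

    import jax.numpy as jnp

    def coupling_forward(x, n_untransformed, conditioner):
        x1, x2 = x[:n_untransformed], x[n_untransformed:]
        log_scale, shift = conditioner(x1)    # depends only on x1
        y2 = x2 * jnp.exp(log_scale) + shift  # elementwise transform of x2
        log_det = jnp.sum(log_scale)          # triangular Jacobian => cheap det
        return jnp.concatenate([x1, y2]), log_det

    def toy_conditioner(x1):
        # Stand-in for the MLP conditioner; one (log_scale, shift) per output.
        s = jnp.tanh(jnp.sum(x1)) * jnp.ones(2)
        return s, jnp.zeros(2)

    y, log_det = coupling_forward(jnp.arange(4.0), 2, toy_conditioner)
    assert y.shape == (4,)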
def make_flow(
seed,
positions,
@@ -601,16 +605,6 @@
if n_layers == 0:
return bijections.Chain(flows)

-    scale = Parameterize(
-        lambda x: x + jnp.sqrt(1 + x**2),
-        jnp.zeros(n_dim),
-    )
-    affine = eqx.tree_at(
-        where=lambda aff: aff.scale,
-        pytree=bijections.Affine(jnp.zeros(n_dim), jnp.ones(n_dim)),
-        replace=scale,
-    )
-
def make_layer(key, untransformed_dim: int | None, permutation=None):
key, key_couple, key_permute, key_hh = jax.random.split(key, 4)

Expand All @@ -625,7 +619,7 @@ def make_layer(key, untransformed_dim: int | None, permutation=None):
n_dim,
untransformed_dim,
nn_activation=jax.nn.gelu,
-            nn_width=nn_width
+            nn_width=nn_width,
)

if zero_init:
@@ -646,9 +640,7 @@ def add_default_permute(bijection, dim, key):
if dim == 2:
outer = bijections.Flip((dim,))
else:
-        outer = bijections.Permute(
-            jax.random.permutation(key, jnp.arange(dim))
-        )
+        outer = bijections.Permute(jax.random.permutation(key, jnp.arange(dim)))

return bijections.Sandwich(outer, bijection)

@@ -698,6 +690,7 @@

return bijections.Chain([bijection, *flows])


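Note: the Chain returned by make_flow is a plain flowjax bijection, so it can be wrapped around a standard normal base to get a trainable distribution. A hedged usage sketch; make_flow's keyword arguments are only partly visible in this diff, so the ones below are assumptions:

    import jax
    import jax.numpy as jnp
    import flowjax.distributions as dist

    positions = jnp.zeros((100, 3))  # (n_draws, n_dim) initialization draws
    bijection = make_flow(1, positions, n_layers=4, nn_width=16)  # kwargs assumed
    flow = dist.Transformed(dist.StandardNormal(shape=(3,)), bijection)
    sample = flow.sample(jax.random.PRNGKey(0))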
def extend_flow(
key,
base,
@@ -714,6 +707,8 @@
dct: bool = False,
extension_var_trafo_count=2,
verbose: bool = False,
+    nn_width=None,
+    nn_depth=None,
):
n_draws, n_dim = positions.shape

@@ -871,9 +866,7 @@
inner.outer,
bijections.Chain(
[
-                    bijections.Sandwich(
-                        bijections.Flip(shape=(n_dim,)), coupling
-                    ),
+                    bijections.Sandwich(bijections.Flip(shape=(n_dim,)), coupling),
inner.inner,
]
),