diff --git a/python/nutpie/normalizing_flow.py b/python/nutpie/normalizing_flow.py
index 021a687..3bdd524 100644
--- a/python/nutpie/normalizing_flow.py
+++ b/python/nutpie/normalizing_flow.py
@@ -9,7 +9,7 @@
 import flowjax.distributions
 import flowjax.flows
 import numpy as np
-from paramax import Parameterize, unwrap
+from paramax import Parameterize


 def _generate_sequences(k, r_vals):
@@ -35,6 +35,7 @@ def _generate_sequences(k, r_vals):
         all_sequences.append(sequences)
     return np.concatenate(all_sequences, axis=0)

+
 def _max_run_length(seq):
     """
     Given a 1D boolean NumPy array 'seq', compute the maximum run length of consecutive
@@ -68,6 +69,7 @@ def _max_run_length(seq):
     run_lengths = np.diff(boundaries)
     return int(run_lengths.max())

+
 def _filter_sequences(sequences, m):
     """
     Filter a 2D NumPy boolean array 'sequences' (each row a binary sequence) so that
@@ -103,7 +105,9 @@ def _generate_permutations(rng, n_dim, n_layers, max_run=3):
     all_sequences = _generate_sequences(n_layers, r)
     valid_sequences = _filter_sequences(all_sequences, max_run)
-    valid_sequences = np.repeat(valid_sequences, n_dim // len(valid_sequences) + 1, axis=0)
+    valid_sequences = np.repeat(
+        valid_sequences, n_dim // len(valid_sequences) + 1, axis=0
+    )
     rng.shuffle(valid_sequences, axis=0)
     is_in_first = valid_sequences[:n_dim]
     rng = np.random.default_rng(42)
@@ -149,7 +153,6 @@ def __init__(
         Likewise `out_features` can also be a string `"scalar"`, in which case the
         output from the layer will have shape `()`.
         """
-        #dtype = default_floating_dtype() if dtype is None else dtype
         dtype = np.float32 if dtype is None else dtype
         wkey, bkey = jax.random.split(key, 2)
         in_features_ = 1 if in_features == "scalar" else in_features
@@ -161,7 +164,9 @@ def __init__(
         wshape = (out_features_, in_features_)
         self.weight = eqx.nn._misc.default_init(wkey, wshape, dtype, lim)
         bshape = (out_features_,)
-        self.bias = eqx.nn._misc.default_init(bkey, bshape, dtype, lim) if use_bias else None
+        self.bias = (
+            eqx.nn._misc.default_init(bkey, bshape, dtype, lim) if use_bias else None
+        )
         self.in_features = in_features
         self.out_features = out_features
@@ -205,6 +210,7 @@ def __call__(self, x: jax.Array, *, key=None) -> jax.Array:
             x = jnp.squeeze(x)
         return x

+
 class FactoredMLP(eqx.Module, strict=True):
     """Standard Multi-Layer Perceptron; also known as a feed-forward network.
@@ -268,7 +274,6 @@ def __init__(
         Likewise `out_size` can also be a string `"scalar"`, in which case the
         output from the module will have shape `()`.
         """
-        #dtype = default_floating_dtype() if dtype is None else dtype
         keys = jax.random.split(key, depth + 1)
         layers = []
         if isinstance(width_size, int):
@@ -290,9 +295,7 @@ def __init__(
                 layers.append((U, K))
             else:
                 k = width_size[0]
-                layers.append(
-                    Linear(in_size, k, use_bias, dtype=dtype, key=keys[0])
-                )
+                layers.append(Linear(in_size, k, use_bias, dtype=dtype, key=keys[0]))
             activations.append(eqx.filter_vmap(lambda: activation, axis_size=k)())

             for i in range(depth - 1):
@@ -331,9 +334,6 @@ def __init__(
         # In case `activation` or `final_activation` are learnt, then make a separate
         # copy of their weights for every neuron.
         self.activation = tuple(activations)
-        #self.activation = eqx.filter_vmap(
-        #    eqx.filter_vmap(lambda: activation), axis_size=depth
-        #)()
         if out_size == "scalar":
             self.final_activation = final_activation
         else:
@@ -344,7 +344,7 @@ def __init__(
         self.use_final_bias = use_final_bias

     @jax.named_scope("eqx.nn.MLP")
-    def __call__(self, x: jax.Array, *, key = None) -> jax.Array:
+    def __call__(self, x: jax.Array, *, key=None) -> jax.Array:
         """**Arguments:**

         - `x`: A JAX array with shape `(in_size,)`. (Or shape `()` if
@@ -382,7 +382,6 @@ def __call__(self, x: jax.Array, *, key = None) -> jax.Array:
         return x


-
 def make_mvscale(key, n_dim, size, randomize_base=False):
     def make_single_hh(key, idx):
         key1, key2 = jax.random.split(key)
@@ -399,7 +398,10 @@ def make_single_hh(key, idx):
     else:
         indices = [val % n_dim for val in range(size)]

-    return bijections.Chain([make_single_hh(key, idx) for key, idx in zip(keys, indices)])
+    return bijections.Chain(
+        [make_single_hh(key, idx) for key, idx in zip(keys, indices)]
+    )
+

 def make_hh(key, n_dim, size, randomize_base=False):
     def make_single_hh(key, idx):
@@ -415,19 +417,16 @@ def make_single_hh(key, idx):
     else:
         indices = [val % n_dim for val in range(size)]

-    return bijections.Chain([make_single_hh(key, idx) for key, idx in zip(keys, indices)])
+    return bijections.Chain(
+        [make_single_hh(key, idx) for key, idx in zip(keys, indices)]
+    )
+

 def make_elemwise_trafo(key, n_dim, *, count=1):
     def make_elemwise(key, loc):
         key1, key2 = jax.random.split(key)
-        scale = Parameterize(
-            lambda x: x + jnp.sqrt(1 + x**2),
-            jnp.zeros(())
-        )
-        theta = Parameterize(
-            lambda x: x + jnp.sqrt(1 + x**2),
-            jnp.zeros(())
-        )
+        scale = Parameterize(lambda x: x + jnp.sqrt(1 + x**2), jnp.zeros(()))
+        theta = Parameterize(lambda x: x + jnp.sqrt(1 + x**2), jnp.zeros(()))

         affine = bijections.AsymmetricAffine(
             loc,
@@ -459,6 +458,7 @@ def make(key):
     make_affine = eqx.filter_vmap(make, axis_size=n_dim)(keys)
     return bijections.Vmap(make_affine, in_axes=eqx.if_array(0))

+
 def make_elemwise_trafo_(key, n_dim, *, count=1):
     def make_elemwise(key):
         scale = Parameterize(
@@ -497,6 +497,7 @@ def make(key):
     make_affine = eqx.filter_vmap(make)(keys)
     return bijections.Vmap(make_affine())

+
 def make_coupling(key, dim, n_untransformed, **kwargs):
     n_transformed = dim - n_untransformed
@@ -510,10 +511,12 @@ def make_coupling(key, dim, n_untransformed, **kwargs):
     else:
         nn_width = 2 * dim

-    transformer = bijections.Chain([
-        make_elemwise_trafo(key, n_transformed, count=3),
-        mvscale,
-    ])
+    transformer = bijections.Chain(
+        [
+            make_elemwise_trafo(key, n_transformed, count=3),
+            mvscale,
+        ]
+    )

     def make_mlp(out_size):
         if isinstance(nn_width, tuple):
@@ -541,6 +544,7 @@ def make_mlp(out_size):
         **kwargs,
     )

+
 def make_flow(
     seed,
     positions,
@@ -601,16 +605,6 @@ def make_flow(
     if n_layers == 0:
         return bijections.Chain(flows)

-    scale = Parameterize(
-        lambda x: x + jnp.sqrt(1 + x**2),
-        jnp.zeros(n_dim),
-    )
-    affine = eqx.tree_at(
-        where=lambda aff: aff.scale,
-        pytree=bijections.Affine(jnp.zeros(n_dim), jnp.ones(n_dim)),
-        replace=scale,
-    )
-
     def make_layer(key, untransformed_dim: int | None, permutation=None):
         key, key_couple, key_permute, key_hh = jax.random.split(key, 4)
@@ -625,7 +619,7 @@ def make_layer(key, untransformed_dim: int | None, permutation=None):
             n_dim,
             untransformed_dim,
             nn_activation=jax.nn.gelu,
-            nn_width=nn_width
+            nn_width=nn_width,
         )

         if zero_init:
@@ -646,9 +640,7 @@ def add_default_permute(bijection, dim, key):
     if dim == 2:
         outer = bijections.Flip((dim,))
     else:
-        outer = bijections.Permute(
-            jax.random.permutation(key, jnp.arange(dim))
-        )
+        outer = bijections.Permute(jax.random.permutation(key, jnp.arange(dim)))

     return bijections.Sandwich(outer, bijection)
@@ -698,6 +690,7 @@ def add_default_permute(bijection, dim, key):

     return bijections.Chain([bijection, *flows])

+
 def extend_flow(
     key,
     base,
@@ -714,6 +707,8 @@ def extend_flow(
     dct: bool = False,
     extension_var_trafo_count=2,
     verbose: bool = False,
+    nn_width=None,
+    nn_depth=None,
 ):
     n_draws, n_dim = positions.shape
@@ -871,9 +866,7 @@ def extend_flow(
             inner.outer,
             bijections.Chain(
                 [
-                    bijections.Sandwich(
-                        bijections.Flip(shape=(n_dim,)), coupling
-                    ),
+                    bijections.Sandwich(bijections.Flip(shape=(n_dim,)), coupling),
                     inner.inner,
                 ]
             ),
diff --git a/python/nutpie/transform_adapter.py b/python/nutpie/transform_adapter.py
index 9721d7c..02eef36 100644
--- a/python/nutpie/transform_adapter.py
+++ b/python/nutpie/transform_adapter.py
@@ -1,5 +1,3 @@
-from typing import Callable, Literal, Union, cast
-import math
 from functools import partial

 import numpy as np
@@ -10,9 +8,8 @@
 import flowjax
 import flowjax.flows
 import flowjax.train
-from flowjax import bijections
 import optax
-from paramax import Parameterize, unwrap
+from paramax import unwrap

 from nutpie.normalizing_flow import extend_flow, make_flow
@@ -66,9 +63,7 @@ def compute_loss(bijection, draw, grad, logp):
         if self._gamma is None:

             def compute_loss(bijection, draw, grad, logp):
-                draw, grad, logp = bijection.inverse_gradient_and_val_(
-                    draw, grad, logp
-                )
+                draw, grad, logp = bijection.inverse_gradient_and_val_(draw, grad, logp)
                 cost = ((draw + grad) ** 2).sum()
                 return cost
@@ -138,6 +133,7 @@ def _init_from_transformed_position(logp_fn, bijection, transformed_position):
         transformed_gradient,
     )

+
 @eqx.filter_jit
 def _init_from_transformed_position_part1(logp_fn, bijection, transformed_position):
     bijection = unwrap(bijection)
@@ -147,6 +143,7 @@ def _init_from_transformed_position_part1(logp_fn, bijection, transformed_positi

     return (logdet, untransformed_position)

+
 @eqx.filter_jit
 def _init_from_transformed_position_part2(
     bijection,
@@ -162,6 +159,7 @@ def _init_from_transformed_position_part2(
         transformed_gradient,
     )

+
 @eqx.filter_jit
 def _init_from_untransformed_position(logp_fn, bijection, untransformed_position):
     logp, untransformed_gradient = jax.value_and_grad(lambda x: logp_fn(x)[0])(
@@ -178,6 +176,7 @@ def _init_from_untransformed_position(logp_fn, bijection, untransformed_position
         transformed_gradient,
     )

+
 @eqx.filter_jit
 def _inv_transform(bijection, untransformed_position, untransformed_gradient):
     bijection = unwrap(bijection)
@@ -188,6 +187,7 @@ def _inv_transform(bijection, untransformed_position, untransformed_gradient):
     )
     return logdet, transformed_position, transformed_gradient

+
 class TransformAdapter:
     def __init__(
         self,
@@ -228,9 +228,7 @@ def __init__(
         self._num_layers = num_layers
         if make_optimizer is None:
             self._make_optimizer = lambda: optax.apply_if_finite(
-                #optax.adamw(learning_rate), 50
-                optax.adabelief(learning_rate), 50
-                #optax.adam(learning_rate), 50
+                optax.adamw(learning_rate), 50
             )
         else:
             self._make_optimizer = make_optimizer
@@ -303,9 +301,7 @@ def update(self, seed, positions, gradients, logps):
                     flowjax.distributions.StandardNormal(fit.shape), fit
                 )
                 params, static = eqx.partition(flow, eqx.is_inexact_array)
-                new_loss = self._loss_fn(
-                    params, static, positions, gradients, logps
-                )
+                new_loss = self._loss_fn(params, static, positions, gradients, logps)
                 if self._verbose:
                     print("loss from diag:", new_loss)
@@ -316,12 +312,8 @@ def update(self, seed, positions, gradients, logps):
                 return

-        positions = np.array(
-            positions[self._initial_skip :][-self._window_size :]
-        )
-        gradients = np.array(
-            gradients[self._initial_skip :][-self._window_size :]
-        )
+        positions = np.array(positions[self._initial_skip :][-self._window_size :])
+        gradients = np.array(gradients[self._initial_skip :][-self._window_size :])
         logps = np.array(logps[self._initial_skip :][-self._window_size :])

         if len(positions) < 10:
@@ -431,9 +423,7 @@ def update(self, seed, positions, gradients, logps):
             )

             if self._verbose:
-                print(
-                    f"Chain {self._chain}: New loss {new_loss}, old loss {old_loss}"
-                )
+                print(f"Chain {self._chain}: New loss {new_loss}, old loss {old_loss}")

             if not np.isfinite(old_loss):
                 flow = flowjax.flows.Transformed(
@@ -473,14 +463,17 @@ def update(self, seed, positions, gradients, logps):
                 (self.index, fit, (positions, gradients, logps))
             )

-
         def valid_new_logp():
             logdet, pos, grad = _inv_transform(
                 fit,
                 jnp.array(positions[-1]),
                 jnp.array(gradients[-1]),
             )
-            return np.isfinite(logdet) and np.isfinite(pos[0]).all() and np.isfinite(grad[0]).all()
+            return (
+                np.isfinite(logdet)
+                and np.isfinite(pos[0]).all()
+                and np.isfinite(grad[0]).all()
+            )

         if (not np.isfinite(old_loss)) and (not np.isfinite(new_loss)):
             self._bijection = self._make_flow_fn(
@@ -555,9 +548,7 @@ def init_from_transformed_position_part2(
                 part1,
                 untransformed_gradient,
             )
-            return float(logdet), *[
-                np.array(val, dtype="float64") for val in arrays
-            ]
+            return float(logdet), *[np.array(val, dtype="float64") for val in arrays]
         except Exception as e:
             print(e)
             print(traceback.format_exc())
@@ -588,6 +579,7 @@ def inv_transform(self, position, gradient):
             print(traceback.format_exc())
             raise

+
 def make_transform_adapter(
     *,
     verbose=False,
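Aside (not part of the patch): a minimal, self-contained sketch of the boundary/run-length idea that the `_max_run_length` docstring describes, for readers reviewing the helpers touched above. The padding and slicing details below are illustrative assumptions, not the patched implementation.

```python
import numpy as np


def max_run_length(seq):
    """Length of the longest run of consecutive True values in a 1D boolean array."""
    # Pad with False on both sides so every run of True values has a
    # detectable start and end boundary.
    padded = np.concatenate(([False], np.asarray(seq, dtype=bool), [False]))
    # +1 marks a False->True transition (run start), -1 a True->False one (run end).
    boundaries = np.flatnonzero(np.diff(padded.astype(int)))
    if boundaries.size == 0:
        return 0
    # Starts sit at even positions, ends at odd positions; their pairwise
    # differences are the run lengths.
    run_lengths = boundaries[1::2] - boundaries[::2]
    return int(run_lengths.max())


# Example: max_run_length(np.array([True, True, False, True, True, True])) == 3
```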