From 3f64e42c488df7e6de6ef1dece80c40e22997453 Mon Sep 17 00:00:00 2001 From: Adrian Seyboldt Date: Thu, 14 Nov 2024 17:43:31 +0100 Subject: [PATCH] fix: Move mvscale import to be conditional --- python/nutpie/transform_adapter.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/nutpie/transform_adapter.py b/python/nutpie/transform_adapter.py index 136fbf8..f658ca6 100644 --- a/python/nutpie/transform_adapter.py +++ b/python/nutpie/transform_adapter.py @@ -18,7 +18,6 @@ def make_transform_adapter( import flowjax import flowjax.train import flowjax.flows - from flowjax.bijections import mvscale import optax import traceback from paramax import Parameterize, unwrap @@ -164,6 +163,8 @@ def make_layer(key, is_last=False): flow = flowjax.flows._add_default_permute(coupling, n_dim, key_permute) if scale_layer: + from flowjax.bijections import mvscale + bijections = list(flow.bijections) bijections.append(mvscale.MvScale4(jnp.ones(n_dim) * 1e-5)) # bijections.append(mvscale.MvScale3(jnp.ones(n_dim) * 1e-5))