From 10ff00620ef3423f968a8f67fdedfa606d7ef528 Mon Sep 17 00:00:00 2001 From: Markus Schmitt Date: Mon, 3 Jul 2023 08:35:29 +0000 Subject: [PATCH 1/2] Fixed bug. --- jVMC/stats.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/jVMC/stats.py b/jVMC/stats.py index 78909cc..494e03b 100644 --- a/jVMC/stats.py +++ b/jVMC/stats.py @@ -42,16 +42,16 @@ def jit_my_stuff(): statsPmapDevices = global_defs.myPmapDevices - _mean_helper = pmap_for_my_devices(lambda data, w: jnp.tensordot(w, data, axes=(0,0)), in_axes=(0, 0)) - _data_prep = pmap_for_my_devices(lambda data, w, mean: jax.vmap(lambda d, w, m: jnp.sqrt(w) * (d - m), in_axes=(0,0,None))(data, w, mean), in_axes=(0, 0, None)) - _covar_helper = pmap_for_my_devices( + _mean_helper = jVMC.global_defs.pmap_for_my_devices(lambda data, w: jnp.tensordot(w, data, axes=(0,0)), in_axes=(0, 0)) + _data_prep = jVMC.global_defs.pmap_for_my_devices(lambda data, w, mean: jax.vmap(lambda d, w, m: jnp.sqrt(w) * (d - m), in_axes=(0,0,None))(data, w, mean), in_axes=(0, 0, None)) + _covar_helper = jVMC.global_defs.pmap_for_my_devices( lambda data1, data2: jnp.tensordot( jnp.conj(data1), data2, axes=(0,0)), in_axes=(0, 0) ) - _covar_var_helper = pmap_for_my_devices( + _covar_var_helper = jVMC.global_defs.pmap_for_my_devices( lambda data1, data2, w: jnp.sum( jnp.abs( @@ -60,14 +60,14 @@ def jit_my_stuff(): axis=0), in_axes=(0, 0, 0) ) - _covar_data_helper = pmap_for_my_devices(lambda data1, data2, w: jax.vmap(lambda a,b,w: jnp.outer(a,b) / w)(jnp.conj(data1), data2, w), in_axes=(0, 0, 0)) - _trafo_helper_1 = pmap_for_my_devices( + _covar_data_helper = jVMC.global_defs.pmap_for_my_devices(lambda data1, data2, w: jax.vmap(lambda a,b,w: jnp.outer(a,b) / w)(jnp.conj(data1), data2, w), in_axes=(0, 0, 0)) + _trafo_helper_1 = jVMC.global_defs.pmap_for_my_devices( lambda data, w, mean, f: f( jax.vmap(lambda x,y: x/jnp.sqrt(y), in_axes=(0,0))(data, w) + mean ), in_axes=(0, 0, None), static_broadcasted_argnums=(3,)) - _trafo_helper_2 = pmap_for_my_devices( + _trafo_helper_2 = jVMC.global_defs.pmap_for_my_devices( lambda data, w, mean, v, f: jnp.matmul(v, f( @@ -76,10 +76,10 @@ def jit_my_stuff(): ) ), in_axes=(0, 0, None, None), static_broadcasted_argnums=(4,)) - _select_helper = pmap_for_my_devices( lambda ix,g: jax.vmap(lambda ix,g: g[ix], in_axes=(None, 0))(ix,g), in_axes=(None, 0) ) - _get_subset_helper = pmap_for_my_devices(lambda x, ixs: x[slice(*ixs)], in_axes=(0,), static_broadcasted_argnums=(1,)) - _subset_mean_helper = pmap_for_my_devices(lambda d, w, m: jnp.tensordot(jnp.sqrt(w), d, axes=(0,0)) + m, in_axes=(0,0,None)) - _subset_data_prep = pmap_for_my_devices(jax.vmap(lambda d, w, m1, m2: d+jnp.sqrt(w)*(m1-m2), in_axes=(0,0,None,None)), in_axes=(0,0,None,None)) + _select_helper = jVMC.global_defs.pmap_for_my_devices( lambda ix,g: jax.vmap(lambda ix,g: g[ix], in_axes=(None, 0))(ix,g), in_axes=(None, 0) ) + _get_subset_helper = jVMC.global_defs.pmap_for_my_devices(lambda x, ixs: x[slice(*ixs)], in_axes=(0,), static_broadcasted_argnums=(1,)) + _subset_mean_helper = jVMC.global_defs.pmap_for_my_devices(lambda d, w, m: jnp.tensordot(jnp.sqrt(w), d, axes=(0,0)) + m, in_axes=(0,0,None)) + _subset_data_prep = jVMC.global_defs.pmap_for_my_devices(jax.vmap(lambda d, w, m1, m2: d+jnp.sqrt(w)*(m1-m2), in_axes=(0,0,None,None)), in_axes=(0,0,None,None)) class SampledObs(): From e2b5b0e1f6790a79e163adf2635a014b81b68fd6 Mon Sep 17 00:00:00 2001 From: Markus Schmitt Date: Mon, 3 Jul 2023 08:45:05 +0000 Subject: [PATCH 2/2] version number. --- jVMC/version.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jVMC/version.py b/jVMC/version.py index 6be37dd..3311566 100644 --- a/jVMC/version.py +++ b/jVMC/version.py @@ -1,2 +1,2 @@ """Current jVMC version at head on Github.""" -__version__ = "1.2.2" +__version__ = "1.2.3"