Commit
Merge pull request #65 from markusschmitt/general_mc_estimator
Add general gradient estimator to SampledObs class.
markusschmitt authored Jun 3, 2024
2 parents 67cef3d + 8099c5b commit d17d618
Showing 5 changed files with 212 additions and 15 deletions.
31 changes: 31 additions & 0 deletions jVMC/operator/base.py
@@ -279,3 +279,34 @@ def compile():
        corresponding matrix elements.
        """

    def get_estimator_function(self, psi, *args):
        """Get a function that computes :math:`O_{loc}(\\theta, s)`.

        Returns a function that computes
        :math:`O_{loc}(\\theta, s)=\\sum_{s'} O_{s,s'}\\frac{\\psi_\\theta(s')}{\\psi_\\theta(s)}`
        for a given configuration :math:`s` and parameters :math:`\\theta` of a given ansatz :math:`\\psi_\\theta(s)`.

        Arguments:
            * ``psi``: Neural quantum state.
            * ``*args``: Further positional arguments for the operator.

        Returns:
            A function :math:`O_{loc}(\\theta, s)`.
        """

        # compile() may return either the matrix-element function alone or a
        # tuple of that function and a factory for its additional arguments.
        op_fun = self.compile()
        op_fun_args = ()  # default to no extra arguments
        if type(op_fun) is tuple:
            op_fun_args = op_fun[1](*args)
            op_fun = op_fun[0]
        net_fun = psi.net.apply

        def op_estimator(params, config):

            # Connected configurations s' and matrix elements O_{s,s'}
            sp, matEls = op_fun(config, *op_fun_args)

            # Log-amplitudes of s and of all connected configurations s'
            log_psi_s = net_fun(params, config)
            log_psi_sp = jax.vmap(lambda s: net_fun(params, s))(sp)

            return jnp.sum(matEls * jnp.exp(log_psi_sp - log_psi_s))

        return op_estimator
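
A minimal usage sketch, assuming ``h``, ``psi``, ``configs``, and ``p`` are set up as in the test added to tests/stats_test.py below:

    from jVMC.stats import SampledObs

    op_estimator = h.get_estimator_function(psi)
    obs = SampledObs(configs, p, estimator=op_estimator)
    energy = obs.mean(psi.parameters)  # re-evaluates O_loc at the given parameters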
134 changes: 122 additions & 12 deletions jVMC/stats.py
@@ -1,11 +1,16 @@
import jax
import jax.numpy as jnp
from jax.tree_util import tree_flatten, tree_unflatten, tree_map

import jVMC
import jVMC.mpi_wrapper as mpi
import jVMC.global_defs as global_defs
from jVMC.global_defs import pmap_for_my_devices

import numpy as np

from functools import partial

_mean_helper = None
_data_prep = None
_covar_helper = None
@@ -82,46 +87,151 @@ def jit_my_stuff():
_subset_data_prep = jVMC.global_defs.pmap_for_my_devices(jax.vmap(lambda d, w, m1, m2: d+jnp.sqrt(w)*(m1-m2), in_axes=(0,0,None,None)), in_axes=(0,0,None,None))


def flat_grad(fun):
    """Wraps ``jax.grad`` so that the gradient is returned as a single flat, real vector."""

    def grad_fun(*args):
        grad_tree = jax.grad(fun)(*args)

        dtypes = [a.dtype for a in tree_flatten(args[0])[0]]
        if dtypes[0] == np.single or dtypes[0] == np.double:
            # Real parameters: just ravel and concatenate the gradient tree.
            grad_vec = tree_flatten(
                tree_map(
                    lambda x: x.ravel(),
                    grad_tree
                )
            )[0]
        else:
            # Complex parameters: split each leaf into its real part and its
            # negated imaginary part so that the result is a real vector.
            grad_vec = tree_flatten(
                tree_map(
                    lambda x: [jnp.real(x.ravel()), -jnp.imag(x.ravel())],
                    grad_tree
                )
            )[0]

        return jnp.concatenate(grad_vec)

    return grad_fun
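
# A minimal usage sketch of flat_grad (hypothetical parameters, not part of
# the module): for complex parameters the result is [Re(g), -Im(g)] per leaf,
# concatenated into one flat real vector.
#
#   params = {"w": jnp.array([1.0 + 2.0j, 3.0 - 1.0j])}
#   f = lambda p: jnp.real(jnp.sum(p["w"] * p["w"]))
#   g = flat_grad(f)(params)  # shape (4,): two real parts, then two negated imaginary parts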


class SampledObs():
    """This class implements the computation of statistics from Monte Carlo or exact samples.

    Initializer arguments:
        * ``observations``: Observations :math:`O_n` in the sample. This can be the value of an observable \
        :math:`O(s_n)` or the plain configuration :math:`s_n`. The array must have a leading device dimension \
        plus a batch dimension.
        * ``weights``: Weights :math:`w_n` associated with observation :math:`O_n`.
        * ``estimator``: [optional] Function :math:`O(\\theta, s)` that computes an estimator parametrized by :math:`\\theta`.
        * ``params``: [optional] A set of parameters for the estimator function.
    """

    def __init__(
        self,
        observations=None,
        weights=None,
        estimator=None,
        params=None
    ):
        """Initializes SampledObs class.

        Args:
            * ``observations``: Observations :math:`O_n` in the sample. This can be the value of an observable \
            :math:`O(s_n)` or the plain configuration :math:`s_n`. The array must have a leading device dimension \
            plus a batch dimension.
            * ``weights``: Weights :math:`w_n` associated with observation :math:`O_n`.
            * ``estimator``: [optional] Function :math:`O(\\theta, s)` that computes an estimator parametrized by :math:`\\theta`.
            * ``params``: [optional] A set of parameters for the estimator function.
        """

        jit_my_stuff()

        self._weights = weights
        self._data = observations
        self._mean = None
        self._configs = None

        def estimator_not_implemented(p, s):
            raise Exception("No estimator function given.")

        self._estimator = estimator_not_implemented
        self._estimator_grad = estimator_not_implemented

        if estimator is not None:

            # In estimator mode, ``observations`` holds the raw configurations.
            self._configs = observations

            # Vectorize the estimator over the batch dimension and distribute
            # it across devices; parameters are broadcast.
            self._estimator = jVMC.global_defs.pmap_for_my_devices(
                jax.vmap(lambda p, s: estimator(p, s), in_axes=(None, 0)),
                in_axes=(None, 0)
            )

            # Gradient of the (complex) estimator, assembled from the flat
            # gradients of its real and imaginary parts.
            self._estimator_grad = jVMC.global_defs.pmap_for_my_devices(
                jax.vmap(lambda p, s: flat_grad(lambda a, b: jnp.real(estimator(a, b)))(p, s) + 1.j * flat_grad(lambda a, b: jnp.imag(estimator(a, b)))(p, s), in_axes=(None, 0)),
                in_axes=(None, 0)
            )

            if params is not None:
                observations = self._estimator(params, self._configs)

        self._compute_data_and_mean(observations)


    def _compute_data_and_mean(self, observations):

        if (observations is not None) and (self._weights is not None):
            if len(observations.shape) == 2:
                observations = observations[..., None]

            self._mean = mpi.global_sum(_mean_helper(observations, self._weights)[:, None, ...])
            self._data = _data_prep(observations, self._weights, self._mean)
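
    # In estimator mode the raw configurations are kept in ``self._configs``;
    # ``self._data`` and ``self._mean`` are recomputed from the estimator
    # whenever parameters are passed to ``mean()`` or ``mean_and_grad()``.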


    def mean(self, params=None):
        """Returns the mean.

        If ``params`` is given, the observations are first recomputed by
        evaluating the estimator function at these parameters.
        """

        if params is not None:
            observations = self._estimator(params, self._configs)
            self._compute_data_and_mean(observations)

        return self._mean


    def mean_and_grad(self, psi, params):
        """Returns the mean of the estimator and its gradient with respect to the parameters.
        """

        obs = self._estimator(params, self._configs)
        self._compute_data_and_mean(obs)

        # Explicit parameter dependence of the estimator
        obsGrad = self._estimator_grad(params, self._configs)
        obsGradMean = mpi.global_sum(_mean_helper(obsGrad, self._weights)[None, ...])

        # Parameter dependence of the sampling distribution |psi|^2
        psiGrad = SampledObs(2.0 * jnp.real(psi.gradients(self._configs)), self._weights)

        return self._mean, psiGrad.covar(self).ravel() + obsGradMean
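
        # The gradient above follows the standard Monte Carlo identity (a
        # sketch of the reasoning, not an additional code path):
        #
        #   d<O>/dtheta = 2 Re[ cov( d(log psi)/dtheta , O_loc ) ] + < d(O_loc)/dtheta >
        #
        # The covariance term captures the theta-dependence of the sampling
        # distribution |psi_theta|^2; the second term is the explicit
        # theta-dependence of the estimator O_loc(theta, s).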



    def covar(self, other=None):
        """Returns the covariance.
1 change: 0 additions & 1 deletion jVMC/vqs.py
@@ -44,7 +44,6 @@ def scan_fun(c, x):

    return res[:s.shape[0]]


def flat_gradient(fun, params, arg):
    gr = grad(lambda p, y: jnp.real(fun.apply(p, y)))(params, arg)["params"]
    gr = tree_flatten(tree_map(lambda x: x.ravel(), gr))[0]
2 changes: 1 addition & 1 deletion setup.py
@@ -8,7 +8,7 @@
    long_description = fh.read()


DEFAULT_DEPENDENCIES = ["setuptools", "wheel", "numpy", "jax>=0.4.1,<=0.4.20", "jaxlib>=0.4.1,<=0.4.20", "flax>=0.6.4,<=0.6.11", "mpi4py", "h5py", "PyYAML", "matplotlib", "scipy<1.13"]  # scipy version restricted because the pinned jax versions are incompatible with the new scipy function namespace (e.g., scipy.sparse.tril)
#CUDA_DEPENDENCIES = ["setuptools", "wheel", "numpy", "jax[cuda]>=0.2.11,<=0.2.25", "flax>=0.3.6,<=0.3.6", "mpi4py", "h5py"]
DEV_DEPENDENCIES = DEFAULT_DEPENDENCIES + ["sphinx", "mock", "sphinx_rtd_theme", "pytest", "pytest-mpi"]

59 changes: 58 additions & 1 deletion tests/stats_test.py
@@ -3,8 +3,9 @@
import jax
import jax.numpy as jnp

import jVMC
from jVMC.stats import SampledObs
import jVMC.mpi_wrapper as mpi
import jVMC.operator as op
from jVMC.global_defs import device_count


@@ -34,6 +35,62 @@ def test_sampled_obs(self):
        O = obs2._data.reshape((-1, 2))
        self.assertTrue(jnp.allclose(obs2.tangent_kernel(), jnp.matmul(O, jnp.conj(jnp.transpose(O)))))


    def test_estimator(self):

        L = 4

        for rbm in [jVMC.nets.CpxRBM(numHidden=1, bias=False), jVMC.nets.RBM(numHidden=1, bias=False)]:

            # Set up variational wave function
            orbit = jVMC.util.symmetries.get_orbit_1D(L, "translation", "reflection")
            net = jVMC.nets.sym_wrapper.SymNet(net=rbm, orbit=orbit)
            psi = jVMC.vqs.NQS(net)

            # Set up sampler (an MCMC alternative is kept for reference)
            # mcSampler = jVMC.sampler.MCSampler(psi, (L,), jax.random.PRNGKey(0), updateProposer=jVMC.sampler.propose_spin_flip, sweepSteps=L+1, numChains=777)
            exactSampler = jVMC.sampler.ExactSampler(psi, (L,))

            p0 = psi.parameters

            # configs, configsLogPsi, p = mcSampler.sample(numSamples=40000)
            configs, configsLogPsi, p = exactSampler.sample()

            # Test operator
            h = op.BranchFreeOperator()
            for i in range(L):
                h.add(op.scal_opstr(2., (op.Sx(i),)))
                h.add(op.scal_opstr(2., (op.Sy(i), op.Sz((i + 1) % L))))

            op_estimator = h.get_estimator_function(psi)

            obs1 = SampledObs(configs, p, estimator=op_estimator)

            # The estimator-based mean must match the conventional O_loc evaluation.
            Oloc = h.get_O_loc(configs, psi)
            obs2 = SampledObs(Oloc, p)
            self.assertTrue(jnp.allclose(obs1.mean(p0), obs2.mean()))

            # The gradient from mean_and_grad must match the covariance-based
            # energy gradient.
            psiGrads = SampledObs(psi.gradients(configs), p)
            Eloc = h.get_O_loc(configs, psi, configsLogPsi)
            Eloc = SampledObs(Eloc, p)

            Egrad2 = 2 * jnp.real(psiGrads.covar(Eloc))

            self.assertTrue(jnp.allclose(jnp.real(obs1.mean_and_grad(psi, p0)[1]), Egrad2.ravel()))


            # Finite-difference consistency check (disabled):
            # E0 = obs1.mean(psi.parameters)
            # dp = 1e-6
            # p1 = psi.get_parameters().at[0].add(dp)
            # psi.set_parameters(p1)
            # configs, _, p = exactSampler.sample()
            # obs1 = SampledObs(configs, p, estimator=op_estimator)
            # E1 = obs1.mean(psi.parameters)
            # print((E1 - E0) / dp)

    def test_subset_function(self):

        N = 10
