diff --git a/examples/ex7_fermions.ipynb b/examples/ex7_fermions.ipynb new file mode 100644 index 0000000..ee4100a --- /dev/null +++ b/examples/ex7_fermions.ipynb @@ -0,0 +1,306 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "#########################\n", + "# Example for using fermionic operators\n", + "# in the jVMC framework\n", + "#########################\n", + "\n", + "# jVMC\n", + "import jVMC\n", + "import jVMC.nets as nets\n", + "import jVMC.operator as op\n", + "from jVMC.operator import number, creation, annihilation\n", + "import jVMC.sampler\n", + "from jVMC.util import ground_state_search, measure\n", + "from jVMC.vqs import NQS\n", + "from jVMC.stats import SampledObs\n", + "from jVMC import global_defs\n", + "\n", + "# python stuff\n", + "import functools\n", + "\n", + "# jax\n", + "import jax\n", + "from jax.config import config\n", + "config.update(\"jax_enable_x64\", True)\n", + "import jax.numpy as jnp\n", + "import flax.linen as nn\n", + "\n", + "import jax.random as random\n", + "\n", + "# numpy\n", + "import numpy as np\n", + "\n", + "# plotting\n", + "import matplotlib.pyplot as plt\n", + "\n", + "#########################\n", + "# check against openfermion\n", + "#########################\n", + "import openfermion as of\n", + "from openfermion.ops import FermionOperator as fop\n", + "from openfermion.linalg import get_sparse_operator" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "##########################\n", + "# custon tarbget wave function\n", + "# specific to openfermion compatibility\n", + "class Target(nn.Module):\n", + " \"\"\"Target wave function, returns a vector with the same dimension as the Hilbert space\n", + "\n", + " Initialization arguments:\n", + " * ``L``: System size\n", + " * ``d``: local Hilbert space dimension\n", + " * ``delta``: small number to avoid log(0)\n", + "\n", + " \"\"\"\n", + " L: int\n", + " d: float = 2.00\n", + " delta: float = 1e-15\n", + "\n", + " @nn.compact\n", + " def __call__(self, s):\n", + " kernel = self.param('kernel',\n", + " nn.initializers.constant(1),\n", + " (int(self.d**self.L)))\n", + " # return amplitude for state s\n", + " idx = ((self.d**jnp.arange(self.L)).dot(s[::-1])).astype(int) # NOTE that the state is reversed to account for different bit conventions used in openfermion\n", + " return jnp.log(abs(kernel[idx]+self.delta)) + 1.j*jnp.angle(kernel[idx]) " + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Fermionic operators\n", + "\n", + "Fermionic operators have to satisfy the following condition\n", + "$$\n", + "\\lbrace \\hat{c}^\\dagger_i, \\hat{c}_j\\rbrace = \\delta_{ij} \\; ,\n", + "$$\n", + "where $i,j$ are so-called *flavours*.As is done in 'openfermoin' we do not allow for a spin quantum number. In other words, all our fermionic operators can only carry a single flavour. For higher flavour indeces one has combine several distinct fermionis.\n", + "\n", + "The key to realizing fermions is the Jordan-Wigner factor. Every fermionic state is constructed using a filling order, then we have to count how many craetion operatros a given operator has to commute thorugh to arrive at his filling order position.\n", + "We can achieve this as follows.\n", + "$$\n", + "\\hat{c}^\\dagger\\vert 1, 0 \\rangle = (-1)^\\Omega \\vert 1,1\\rangle\n", + "$$\n", + "with \n", + "$$\n", + "\\Omega = \\sum^{j-1}_{i=0} s_i \\; .\n", + "$$\n", + "In the following we construct the repulsive Hubbard Model on a chain as an example and compare it to openfermion\n", + "$$\n", + "H = U \\sum^N_{i=1} \\hat{n}_{\\uparrow i}\\hat{n}_{\\downarrow i} + t \\sum^{N-1}_{i=1} \\hat{c}^\\dagger_{\\sigma i} \\hat{c}_{\\sigma i+1} + h.c. \\;.\n", + "$$ " + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "#########################\n", + "# jVMC hamiltonian\n", + "#########################\n", + "t = - 1.0 # hopping\n", + "mu = -2.0 # chemical potential\n", + "V = 4.0 # interaction\n", + "L = 4 # number of sites\n", + "flavour = 2 # number of flavours\n", + "flavourL = flavour*L # number of spins times sites\n", + "\n", + "# initalize the Hamitonian\n", + "hamiltonian = op.BranchFreeOperator()\n", + "# impurity definitions\n", + "site1UP = 0\n", + "site1DO = flavourL-1#//flavour\n", + "# loop over the 1d lattice\n", + "for i in range(0,flavourL//flavour):\n", + " # interaction\n", + " hamiltonian.add(op.scal_opstr( V, ( number(site1UP + i) , number(site1DO - i) ) ) )\n", + " # chemical potential\n", + " hamiltonian.add(op.scal_opstr(mu , ( number(site1UP + i) ,) ) )\n", + " hamiltonian.add(op.scal_opstr(mu , ( number(site1DO - i) ,) ) )\n", + " if i == flavourL//flavour-1:\n", + " continue\n", + " # up chain hopping\n", + " hamiltonian.add(op.scal_opstr( t, ( annihilation(site1UP + i) ,creation(site1UP + i + 1) ) ) )\n", + " hamiltonian.add(op.scal_opstr( t, ( annihilation(site1UP + i + 1) ,creation(site1UP + i) ) ) )\n", + " # down chain hopping\n", + " hamiltonian.add(op.scal_opstr( t, ( annihilation(site1DO - i) ,creation(site1DO - i - 1) ) ) )\n", + " hamiltonian.add(op.scal_opstr( t, ( annihilation(site1DO - i - 1) ,creation(site1DO - i) ) ) )" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "#########################\n", + "# openfermion\n", + "#########################\n", + "\n", + "H = 0.0*fop()\n", + "# loop over the 1d lattice\n", + "for i in range(0,flavourL//flavour):\n", + " H += fop(((site1UP + i,1),(site1UP + i,0),(site1DO - i,1),(site1DO - i,0)),V) \n", + " H += fop(((site1UP + i,1),(site1UP + i,0)),mu) + fop(((site1DO - i,1),(site1DO - i,0)),mu)\n", + " if i == flavourL//flavour-1:\n", + " continue\n", + " # up chain\n", + " H += (fop(((site1UP + i,1),(site1UP + i + 1,0)),t) + fop(((site1UP + i + 1,1),(site1UP + i,0)),t))\n", + " # down chain\n", + " H += (fop(((site1DO - i,1),(site1DO - i - 1,0)),t) + fop(((site1DO - i - 1,1),(site1DO - i,0)),t))" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "#########################\n", + "# diagonalize the Openfermion Hamiltonain\n", + "#########################\n", + "\n", + "ham = get_sparse_operator(H)\n", + "a, b = np.linalg.eigh(ham.toarray())\n", + "\n", + "chi_model = Target(L=flavourL, d=2)\n", + "chi = NQS(chi_model)\n", + "chi(jnp.array(jnp.ones((1, 1, flavourL))))\n", + "chi.set_parameters(b[:,0]+1e-14)\n", + "chiSampler = jVMC.sampler.ExactSampler(chi, (flavourL,))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "s, logPsi, p = chiSampler.sample()\n", + "sPrime, _ = hamiltonian.get_s_primes(s)\n", + "Oloc = hamiltonian.get_O_loc(s, chi, logPsi)\n", + "Omean = jVMC.mpi_wrapper.global_mean(Oloc,p)\n", + "\n", + "print(\"Ground state energy: \\njVMC: %.8f, Openfermion: %.8f\"%(Omean.real,a[0]))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Finding the ground state brute force" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Set up variational wave function\n", + "all_states = Target(L=flavourL, d=2)\n", + "psi = NQS(all_states)\n", + "# initialize NQS\n", + "print(\"Net init: \",psi(jnp.array(jnp.ones((1, 1, flavourL)))))\n", + "# Set up exact sampler\n", + "exactSampler = jVMC.sampler.ExactSampler(psi, flavourL)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Set up sampler\n", + "sampler = jVMC.sampler.ExactSampler(psi, flavourL)\n", + "\n", + "# Set up TDVP\n", + "tdvpEquation = jVMC.util.tdvp.TDVP(sampler, rhsPrefactor=1.,diagonalShift=10, makeReal='real')\n", + "\n", + "stepper = jVMC.util.stepper.Euler(timeStep=5e-1) # ODE integrator\n", + "\n", + "n_steps = 500\n", + "res = []\n", + "for n in range(n_steps):\n", + "\n", + " dp, _ = stepper.step(0, tdvpEquation, psi.get_parameters(), hamiltonian=hamiltonian, psi=psi, numSamples=None)\n", + " psi.set_parameters(dp)\n", + "\n", + " print(n, jax.numpy.real(tdvpEquation.ElocMean0), tdvpEquation.ElocVar0)\n", + "\n", + " res.append([n, jax.numpy.real(tdvpEquation.ElocMean0), tdvpEquation.ElocVar0])\n", + "\n", + "res = np.array(res)\n", + "\n", + "fig, ax = plt.subplots(2, 1, sharex=True, figsize=[4.8, 4.8])\n", + "ax[0].semilogy(res[:, 0], res[:, 1] - a[0], '-', label=r\"$L=\" + str(L) + \"$\")\n", + "ax[0].set_ylabel(r'$(E-E_0)/L$')\n", + "\n", + "ax[1].semilogy(res[:, 0], res[:, 2], '-')\n", + "ax[1].set_ylabel(r'Var$(E)/L$')\n", + "ax[0].legend()\n", + "plt.xlabel('iteration')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "s, logPsi, _ = exactSampler.sample()\n", + "var_wf = np.real(np.exp(logPsi))[0]\n", + "# normalizing the wave function\n", + "var_wf /= var_wf.dot(var_wf)**0.5\n", + "\n", + "figure = plt.figure(dpi=100)\n", + "plt.xlabel('state')\n", + "plt.ylabel('amplitude')\n", + "plt.plot(var_wf,label='jVMC')\n", + "plt.plot(np.exp(chi(s)).real[0],'--',label='openfermion')\n", + "plt.legend()\n", + "plt.show()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "jvmc", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.19" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/jVMC/operator/branch_free.py b/jVMC/operator/branch_free.py index 860c5ad..555ba6d 100644 --- a/jVMC/operator/branch_free.py +++ b/jVMC/operator/branch_free.py @@ -107,6 +107,73 @@ def Sm(idx): return {'idx': idx, 'map': jnp.array([0, 0], dtype=np.int32), 'matEls': jnp.array([0.0, 1.0], dtype=opDtype), 'diag': False} +###################### +# fermionic number operator +def number(idx): + """Returns a :math:`c^\dagger c` femrioic number operator + + Args: + + * ``idx``: Index of the local Hilbert space. + + Returns: + Dictionary defining :math:`c^\dagger c` femrioic number operator + + """ + + return { + 'idx': idx, + 'map': jax.numpy.array([0,1],dtype=np.int32), + 'matEls': jax.numpy.array([0.,1.],dtype=opDtype), + 'diag': True, + 'fermionic': False + } + +###################### +# fermionic creation operator +def creation(idx): + """Returns a :math:`c^\dagger` femrioic creation operator + + Args: + + * ``idx``: Index of the local Hilbert space. + + Returns: + Dictionary defining :math:`c^\dagger` femrioic creation operator + + """ + + return { + 'idx': idx, + 'map': jax.numpy.array([1,0],dtype=np.int32), + 'matEls': jax.numpy.array([1.,0.],dtype=opDtype), + 'diag': False, + "fermionic": True + } + +###################### +# fermionic annihilation operator +def annihilation(idx): + """Returns a :math:`c` femrioic creation operator + + Args: + + * ``idx``: Index of the local Hilbert space. + + Returns: + Dictionary defining :math:`c` femrioic creation operator + + """ + + return { + 'idx': idx, + 'map': jax.numpy.array([1,0],dtype=np.int32), + 'matEls': jax.numpy.array([0.,1.],dtype=opDtype), + 'diag': False, + "fermionic": True + } + + import copy @jax.jit @@ -179,6 +246,9 @@ def compile(self): self.matEls = [] self.diag = [] self.prefactor = [] + ######## fermions ######## + self.fermionic = [] + ########################## self.maxOpStrLength = 0 for op in self.ops: tmpLen = len(op) @@ -192,6 +262,7 @@ def compile(self): self.idx.append([]) self.map.append([]) self.matEls.append([]) + self.fermionic.append([]) # check whether string contains prefactor k0=0 if callable(op[0]): @@ -207,10 +278,20 @@ def compile(self): self.idx[o].append(op[k]['idx']) self.map[o].append(op[k]['map']) self.matEls[o].append(op[k]['matEls']) + ######## fermions ######## + fermi_check = True + if "fermionic" in op[k]: + if op[k]["fermionic"]: + fermi_check = False + self.fermionic[o].append(1.) + if fermi_check: + self.fermionic[o].append(0.) + ########################## else: self.idx[o].append(IdOp['idx']) self.map[o].append(IdOp['map']) self.matEls[o].append(IdOp['matEls']) + self.fermionic[o].append(0.) if isDiagonal: self.diag.append(o) @@ -219,6 +300,9 @@ def compile(self): self.idxC = jnp.array(self.idx, dtype=np.int32) self.mapC = jnp.array(self.map, dtype=np.int32) self.matElsC = jnp.array(self.matEls, dtype=opDtype) + ######## fermions ######## + self.fermionicC = jnp.array(self.fermionic, dtype=np.int32) + ########################## self.diag = jnp.array(self.diag, dtype=np.int32) def arg_fun(*args, prefactor, init): @@ -243,10 +327,10 @@ def arg_fun(*args, prefactor, init): return (jnp.array(res), ) - return functools.partial(self._get_s_primes, idxC=self.idxC, mapC=self.mapC, matElsC=self.matElsC, diag=self.diag, prefactor=self.prefactor),\ + return functools.partial(self._get_s_primes, idxC=self.idxC, mapC=self.mapC, matElsC=self.matElsC, diag=self.diag, fermiC=self.fermionicC, prefactor=self.prefactor),\ functools.partial(arg_fun, prefactor=self.prefactor, init=np.ones(self.idxC.shape[0], dtype=self.matElsC.dtype)) - def _get_s_primes(self, s, *args, idxC, mapC, matElsC, diag, prefactor): + def _get_s_primes(self, s, *args, idxC, mapC, matElsC, diag, fermiC, prefactor): numOps = idxC.shape[0] #matEl = jnp.ones(numOps, dtype=matElsC.dtype) @@ -254,28 +338,37 @@ def _get_s_primes(self, s, *args, idxC, mapC, matElsC, diag, prefactor): sp = jnp.array([s] * numOps) + ######## fermions ######## + mask = jnp.tril(jnp.ones((s.shape[-1],s.shape[-1]),dtype=int),-1).T + ########################## + def apply_fun(c, x): config, configMatEl = c - idx, sMap, matEls = x + idx, sMap, matEls, fermi = x configShape = config.shape config = config.ravel() - configMatEl = configMatEl * matEls[config[idx]] + ######## fermions ######## + configMatEl = configMatEl * matEls[config[idx]] * jnp.prod((1 - 2 * fermi) * \ + (2 * fermi * mask[idx] +\ + (1 - 2 * fermi)) * config + \ + (1 - abs(config))) + ########################## config = config.at[idx].set(sMap[config[idx]]) return (config.reshape(configShape), configMatEl), None #def apply_multi(config, configMatEl, opIdx, opMap, opMatEls, prefactor): - def apply_multi(config, configMatEl, opIdx, opMap, opMatEls): + def apply_multi(config, configMatEl, opIdx, opMap, opMatEls, opFermi): - (config, configMatEl), _ = jax.lax.scan(apply_fun, (config, configMatEl), (opIdx, opMap, opMatEls)) + (config, configMatEl), _ = jax.lax.scan(apply_fun, (config, configMatEl), (opIdx, opMap, opMatEls, opFermi)) #return config, prefactor*configMatEl return config, configMatEl # vmap over operators #sp, matEl = vmap(apply_multi, in_axes=(0, 0, 0, 0, 0, 0))(sp, matEl, idxC, mapC, matElsC, jnp.array([f(*args) for f in prefactor])) - sp, matEl = vmap(apply_multi, in_axes=(0, 0, 0, 0, 0))(sp, matEl, idxC, mapC, matElsC) + sp, matEl = vmap(apply_multi, in_axes=(0, 0, 0, 0, 0, 0))(sp, matEl, idxC, mapC, matElsC, fermiC) if len(diag) > 1: matEl = matEl.at[diag[0]].set(jnp.sum(matEl[diag], axis=0)) matEl = matEl.at[diag[1:]].set(jnp.zeros((diag.shape[0] - 1,), dtype=matElsC.dtype)) diff --git a/setup.py b/setup.py index f0c0b40..ec61dc1 100644 --- a/setup.py +++ b/setup.py @@ -7,8 +7,7 @@ with open("README.md", "r") as fh: long_description = fh.read() - -DEFAULT_DEPENDENCIES = ["setuptools", "wheel", "numpy", "jax>=0.4.1,<=0.4.20", "jaxlib>=0.4.1,<=0.4.20", "flax>=0.6.4,<=0.6.11", "mpi4py", "h5py", "PyYAML", "matplotlib", "scipy<1.13"] # Scipy version restricted, because jax is currently incompatible with new function namespace scipy.sparse.tril +DEFAULT_DEPENDENCIES = ["setuptools", "wheel", "numpy", "openfermion", "jax>=0.4.1,<=0.4.20", "jaxlib>=0.4.1,<=0.4.20", "flax>=0.6.4,<=0.6.11", "mpi4py", "h5py", "PyYAML", "matplotlib", "scipy<1.13"] # Scipy version restricted, because jax is currently incompatible with new function namespace scipy.sparse.tril #CUDA_DEPENDENCIES = ["setuptools", "wheel", "numpy", "jax[cuda]>=0.2.11,<=0.2.25", "flax>=0.3.6,<=0.3.6", "mpi4py", "h5py"] DEV_DEPENDENCIES = DEFAULT_DEPENDENCIES + ["sphinx", "mock", "sphinx_rtd_theme", "pytest", "pytest-mpi"] diff --git a/tests/operator_test.py b/tests/operator_test.py index da19ca2..8757790 100644 --- a/tests/operator_test.py +++ b/tests/operator_test.py @@ -132,6 +132,44 @@ def test_td_prefactor(self): hamiltonian.compile() + def test_fermionic_operators(self): + L = 2 + + rbm = nets.CpxRBM(numHidden=2, bias=True) + psi = NQS(rbm) + + sampler = jVMC.sampler.ExactSampler(psi, (L,)) + + def commutator(i,j): + Comm = op.BranchFreeOperator() + Comm.add(op.scal_opstr( 1., (op.annihilation(i), op.creation(j), ) ) ) + Comm.add(op.scal_opstr( 1., (op.creation(j), op.annihilation(i), ) ) ) + return Comm + + observalbes_dict = { + "same_site": [commutator(0,0),commutator(1,1)], + "distinct_site": [commutator(0,1),commutator(1,0)] + } + out_dict = jVMC.util.util.measure(observalbes_dict, psi, sampler) + + self.assertTrue( + jnp.allclose( + jnp.concatenate( + (out_dict["same_site"]['mean'], + out_dict["distinct_site"]['mean'])), + jnp.array([1.,1.,0.,0.]), + rtol=1e-15) + ) + + self.assertTrue( + jnp.allclose( + jnp.concatenate( + (out_dict["same_site"]['variance'], + out_dict["distinct_site"]['variance'])), + jnp.array([0.,0.,0.,0.]), + rtol=1e-15) + ) + if __name__ == "__main__": unittest.main()