Skip to content

Commit

Permalink
dlogX -> logdX (#306)
Browse files Browse the repository at this point in the history
* dlogX -> logdX

* bumped version

* Actually added the error this time

* Added logdX tests
  • Loading branch information
williamjameshandley authored Jun 29, 2023
1 parent 98a4eb2 commit a5d4620
Show file tree
Hide file tree
Showing 4 changed files with 49 additions and 15 deletions.
2 changes: 1 addition & 1 deletion README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
anesthetic: nested sampling post-processing
===========================================
:Authors: Will Handley and Lukas Hergt
:Version: 2.0.0-beta.40
:Version: 2.0.0-beta.41
:Homepage: https://github.com/handley-lab/anesthetic
:Documentation: http://anesthetic.readthedocs.io/

Expand Down
2 changes: 1 addition & 1 deletion anesthetic/_version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '2.0.0b40'
__version__ = '2.0.0b41'
33 changes: 20 additions & 13 deletions anesthetic/samples.py
Original file line number Diff line number Diff line change
Expand Up @@ -872,7 +872,14 @@ def logX(self, nsamples=None):
logX.name = 'logX'
return logX

# TODO: remove this in version >= 2.1
def dlogX(self, nsamples=None):
# noqa: disable=D102
raise NotImplementedError(
"This is anesthetic 1.0 syntax. You should instead use logdX."
)

def logdX(self, nsamples=None):
"""Compute volume of shell of loglikelihood.
Parameters
Expand All @@ -892,10 +899,10 @@ def dlogX(self, nsamples=None):
logX = self.logX(nsamples)
logXp = logX.shift(1, fill_value=0)
logXm = logX.shift(-1, fill_value=-np.inf)
dlogX = np.log(1 - np.exp(logXm-logXp)) + logXp - np.log(2)
dlogX.name = 'dlogX'
logdX = np.log(1 - np.exp(logXm-logXp)) + logXp - np.log(2)
logdX.name = 'logdX'

return dlogX
return logdX

def _betalogL(self, beta=None):
"""Log(L**beta) convenience function.
Expand Down Expand Up @@ -958,20 +965,20 @@ def logw(self, nsamples=None, beta=None):
if np.ndim(nsamples) > 0:
return nsamples

dlogX = self.dlogX(nsamples)
logdX = self.logdX(nsamples)
betalogL = self._betalogL(beta)

if dlogX.ndim == 1 and betalogL.ndim == 1:
logw = dlogX + betalogL
elif dlogX.ndim > 1 and betalogL.ndim == 1:
logw = dlogX.add(betalogL, axis=0)
elif dlogX.ndim == 1 and betalogL.ndim > 1:
logw = betalogL.add(dlogX, axis=0)
if logdX.ndim == 1 and betalogL.ndim == 1:
logw = logdX + betalogL
elif logdX.ndim > 1 and betalogL.ndim == 1:
logw = logdX.add(betalogL, axis=0)
elif logdX.ndim == 1 and betalogL.ndim > 1:
logw = betalogL.add(logdX, axis=0)
else:
cols = MultiIndex.from_product([betalogL.columns, dlogX.columns])
dlogX = dlogX.reindex(columns=cols, level='samples')
cols = MultiIndex.from_product([betalogL.columns, logdX.columns])
logdX = logdX.reindex(columns=cols, level='samples')
betalogL = betalogL.reindex(columns=cols, level='beta')
logw = betalogL+dlogX
logw = betalogL+logdX
return logw

def logZ(self, nsamples=None, beta=None):
Expand Down
27 changes: 27 additions & 0 deletions tests/test_samples.py
Original file line number Diff line number Diff line change
Expand Up @@ -506,6 +506,30 @@ def test_logX():
assert (abs(logX.mean(axis=1) - pc.logX()) < logX.std(axis=1) * 3).all()


def test_logdX():
np.random.seed(3)
pc = read_chains('./tests/example_data/pc')

logdX = pc.logdX()
assert isinstance(logdX, WeightedSeries)
assert_array_equal(logdX.index, pc.index)

nsamples = 10

logdX = pc.logdX(nsamples=nsamples)
assert isinstance(logdX, WeightedDataFrame)
assert_array_equal(logdX.index, pc.index)
assert_array_equal(logdX.columns, np.arange(nsamples))
assert logdX.columns.name == 'samples'

assert not (logdX > 0).to_numpy().any()

n = 1000
logdX = pc.logdX(n)

assert (abs(logdX.mean(axis=1) - pc.logdX()) < logdX.std(axis=1) * 3).all()


def test_logbetaL():
np.random.seed(3)
pc = read_chains('./tests/example_data/pc')
Expand Down Expand Up @@ -1347,6 +1371,9 @@ def test_old_gui():
with pytest.raises(NotImplementedError):
make_1d_axes(['x0', 'y0'], tex={'x0': '$x_0$', 'y0': '$y_0$'})

with pytest.raises(NotImplementedError):
samples.dlogX(1000)


def test_groupby_stats():
mcmc = read_chains('./tests/example_data/cb')
Expand Down

0 comments on commit a5d4620

Please sign in to comment.