From a5d462000175d36dd6599e6f4e18c991a28bddaf Mon Sep 17 00:00:00 2001 From: Will Handley Date: Thu, 29 Jun 2023 09:17:05 +0100 Subject: [PATCH] dlogX -> logdX (#306) * dlogX -> logdX * bumped version * Actually added the error this time * Added logdX tests --- README.rst | 2 +- anesthetic/_version.py | 2 +- anesthetic/samples.py | 33 ++++++++++++++++++++------------- tests/test_samples.py | 27 +++++++++++++++++++++++++++ 4 files changed, 49 insertions(+), 15 deletions(-) diff --git a/README.rst b/README.rst index d9825869..15021611 100644 --- a/README.rst +++ b/README.rst @@ -2,7 +2,7 @@ anesthetic: nested sampling post-processing =========================================== :Authors: Will Handley and Lukas Hergt -:Version: 2.0.0-beta.40 +:Version: 2.0.0-beta.41 :Homepage: https://github.com/handley-lab/anesthetic :Documentation: http://anesthetic.readthedocs.io/ diff --git a/anesthetic/_version.py b/anesthetic/_version.py index 0376b8da..a9c286ee 100644 --- a/anesthetic/_version.py +++ b/anesthetic/_version.py @@ -1 +1 @@ -__version__ = '2.0.0b40' +__version__ = '2.0.0b41' diff --git a/anesthetic/samples.py b/anesthetic/samples.py index c3bdd216..6cf00467 100644 --- a/anesthetic/samples.py +++ b/anesthetic/samples.py @@ -872,7 +872,14 @@ def logX(self, nsamples=None): logX.name = 'logX' return logX + # TODO: remove this in version >= 2.1 def dlogX(self, nsamples=None): + # noqa: disable=D102 + raise NotImplementedError( + "This is anesthetic 1.0 syntax. You should instead use logdX." + ) + + def logdX(self, nsamples=None): """Compute volume of shell of loglikelihood. Parameters @@ -892,10 +899,10 @@ def dlogX(self, nsamples=None): logX = self.logX(nsamples) logXp = logX.shift(1, fill_value=0) logXm = logX.shift(-1, fill_value=-np.inf) - dlogX = np.log(1 - np.exp(logXm-logXp)) + logXp - np.log(2) - dlogX.name = 'dlogX' + logdX = np.log(1 - np.exp(logXm-logXp)) + logXp - np.log(2) + logdX.name = 'logdX' - return dlogX + return logdX def _betalogL(self, beta=None): """Log(L**beta) convenience function. @@ -958,20 +965,20 @@ def logw(self, nsamples=None, beta=None): if np.ndim(nsamples) > 0: return nsamples - dlogX = self.dlogX(nsamples) + logdX = self.logdX(nsamples) betalogL = self._betalogL(beta) - if dlogX.ndim == 1 and betalogL.ndim == 1: - logw = dlogX + betalogL - elif dlogX.ndim > 1 and betalogL.ndim == 1: - logw = dlogX.add(betalogL, axis=0) - elif dlogX.ndim == 1 and betalogL.ndim > 1: - logw = betalogL.add(dlogX, axis=0) + if logdX.ndim == 1 and betalogL.ndim == 1: + logw = logdX + betalogL + elif logdX.ndim > 1 and betalogL.ndim == 1: + logw = logdX.add(betalogL, axis=0) + elif logdX.ndim == 1 and betalogL.ndim > 1: + logw = betalogL.add(logdX, axis=0) else: - cols = MultiIndex.from_product([betalogL.columns, dlogX.columns]) - dlogX = dlogX.reindex(columns=cols, level='samples') + cols = MultiIndex.from_product([betalogL.columns, logdX.columns]) + logdX = logdX.reindex(columns=cols, level='samples') betalogL = betalogL.reindex(columns=cols, level='beta') - logw = betalogL+dlogX + logw = betalogL+logdX return logw def logZ(self, nsamples=None, beta=None): diff --git a/tests/test_samples.py b/tests/test_samples.py index 07d9bfc9..42e4f308 100644 --- a/tests/test_samples.py +++ b/tests/test_samples.py @@ -506,6 +506,30 @@ def test_logX(): assert (abs(logX.mean(axis=1) - pc.logX()) < logX.std(axis=1) * 3).all() +def test_logdX(): + np.random.seed(3) + pc = read_chains('./tests/example_data/pc') + + logdX = pc.logdX() + assert isinstance(logdX, WeightedSeries) + assert_array_equal(logdX.index, pc.index) + + nsamples = 10 + + logdX = pc.logdX(nsamples=nsamples) + assert isinstance(logdX, WeightedDataFrame) + assert_array_equal(logdX.index, pc.index) + assert_array_equal(logdX.columns, np.arange(nsamples)) + assert logdX.columns.name == 'samples' + + assert not (logdX > 0).to_numpy().any() + + n = 1000 + logdX = pc.logdX(n) + + assert (abs(logdX.mean(axis=1) - pc.logdX()) < logdX.std(axis=1) * 3).all() + + def test_logbetaL(): np.random.seed(3) pc = read_chains('./tests/example_data/pc') @@ -1347,6 +1371,9 @@ def test_old_gui(): with pytest.raises(NotImplementedError): make_1d_axes(['x0', 'y0'], tex={'x0': '$x_0$', 'y0': '$y_0$'}) + with pytest.raises(NotImplementedError): + samples.dlogX(1000) + def test_groupby_stats(): mcmc = read_chains('./tests/example_data/cb')