From 5d81a122f97b038ae16d0a1686f0bc72eef0fa56 Mon Sep 17 00:00:00 2001
From: ktaaaki
Date: Tue, 16 Mar 2021 18:10:25 +0900
Subject: [PATCH 1/3] support old usage of bernoulli distribution

---
 .travis.yml                            |  2 +-
 .../exponential_distributions.py       | 22 +++++++++++++++----
 .../test_expornential_distributions.py |  2 +-
 3 files changed, 20 insertions(+), 6 deletions(-)

diff --git a/.travis.yml b/.travis.yml
index 927c8826..b34f30c7 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -2,6 +2,6 @@ language: python
 python:
   - 3.6.5
 install:
-  - pip install -e "."[test]
+  - pip install -e ".[test]"
 script:
   - pytest
\ No newline at end of file
diff --git a/pixyz/distributions/exponential_distributions.py b/pixyz/distributions/exponential_distributions.py
index 9e9e4857..58b2136d 100644
--- a/pixyz/distributions/exponential_distributions.py
+++ b/pixyz/distributions/exponential_distributions.py
@@ -9,6 +9,8 @@
 from torch.distributions import Beta as BetaTorch
 from torch.distributions import Laplace as LaplaceTorch
 from torch.distributions import Gamma as GammaTorch
+from torch.distributions.utils import broadcast_all
+from torch.nn.functional import binary_cross_entropy_with_logits
 
 from ..utils import get_dict_values, sum_samples
 from .distributions import DistributionBase
@@ -40,6 +42,12 @@ def has_reparam(self):
         return True
 
 
+class BernoulliTorchOld(BernoulliTorch):
+    def log_prob(self, value):
+        logits, value = broadcast_all(self.logits, value)
+        return -binary_cross_entropy_with_logits(logits, value, reduction='none')
+
+
 class Bernoulli(DistributionBase):
     """Bernoulli distribution parameterized by :attr:`probs`."""
     def __init__(self, var=['x'], cond_var=[], name='p', features_shape=torch.Size(), probs=None):
@@ -51,7 +59,7 @@ def params_keys(self):
 
     @property
     def distribution_torch_class(self):
-        return BernoulliTorch
+        return BernoulliTorchOld
 
     @property
     def distribution_name(self):
@@ -111,7 +119,7 @@ def set_dist(self, x_dict={}, batch_n=None, sampling=False, **kwargs):
             self._dist = self.distribution_torch_class(**params)
         else:
             hard_params_keys = ["probs"]
-            self._dist = BernoulliTorch(**get_dict_values(params, hard_params_keys, return_dict=True))
+            self._dist = BernoulliTorchOld(**get_dict_values(params, hard_params_keys, return_dict=True))
 
         # expand batch_n
         if batch_n:
@@ -176,6 +184,12 @@ def get_log_prob(self, x_dict, sum_features=True, feature_dims=None, **kwargs):
         return log_prob
 
 
+class CategoricalTorchOld(CategoricalTorch):
+    def log_prob(self, value):
+        indices = value.max(-1)[1]
+        return self._categorical.log_prob(indices)
+
+
 class Categorical(DistributionBase):
     """Categorical distribution parameterized by :attr:`probs`."""
     def __init__(self, var=['x'], cond_var=[], name='p', features_shape=torch.Size(), probs=None):
@@ -188,7 +202,7 @@ def params_keys(self):
 
     @property
     def distribution_torch_class(self):
-        return CategoricalTorch
+        return CategoricalTorchOld
 
     @property
     def distribution_name(self):
@@ -251,7 +265,7 @@ def set_dist(self, x_dict={}, batch_n=None, sampling=False, **kwargs):
             self._dist = self.distribution_torch_class(**params)
         else:
             hard_params_keys = ["probs"]
-            self._dist = BernoulliTorch(**get_dict_values(params, hard_params_keys, return_dict=True))
+            self._dist = BernoulliTorchOld(**get_dict_values(params, hard_params_keys, return_dict=True))
 
         # expand batch_n
         if batch_n:
diff --git a/tests/distributions/test_expornential_distributions.py b/tests/distributions/test_expornential_distributions.py
index 09d68838..1190e90d 100644
--- a/tests/distributions/test_expornential_distributions.py
+++ b/tests/distributions/test_expornential_distributions.py
@@ -20,6 +20,6 @@ def nearly_eq(self, tensor1, tensor2):
         return abs(tensor1.item() - tensor2.item()) < 0.001
 
     def test_sample_mean(self):
-        rb = RelaxedBernoulli(var=['x'], temperature=torch.tensor(0.5), probs=torch.tensor([1, 2]))
+        rb = RelaxedBernoulli(var=['x'], temperature=torch.tensor(0.5), probs=torch.tensor([1 / 3., 2 / 3.]))
         with pytest.raises(NotImplementedError):
             rb.sample(sample_mean=True)

From 0cf2b897d15d2eeeb2039923a1c0ff1738e90a77 Mon Sep 17 00:00:00 2001
From: ktaaaki
Date: Tue, 16 Mar 2021 18:34:38 +0900
Subject: [PATCH 2/3] fix type of tensor

---
 tests/distributions/test_distribution.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/distributions/test_distribution.py b/tests/distributions/test_distribution.py
index e3e50a47..732ee1c6 100644
--- a/tests/distributions/test_distribution.py
+++ b/tests/distributions/test_distribution.py
@@ -120,7 +120,7 @@ def test_unknown_option(self, dist):
 
 class TestMixtureDistribution:
     def test_sample_mean(self):
-        dist = MixtureModel([Normal(loc=0, scale=1), Normal(loc=1, scale=1)], Categorical(probs=torch.tensor([1, 2])))
+        dist = MixtureModel([Normal(loc=0, scale=1), Normal(loc=1, scale=1)], Categorical(probs=torch.tensor([1., 2.])))
         assert dist.sample(sample_mean=True)['x'] == torch.ones(1)
 
 

From 7c444199e36d1556a7766e489813b60596d3be0a Mon Sep 17 00:00:00 2001
From: ktaaaki
Date: Fri, 19 Mar 2021 12:51:59 +0900
Subject: [PATCH 3/3] Fixed confusing initial values such as probs in Categorical distribution.

---
 tests/distributions/test_expornential_distributions.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/distributions/test_expornential_distributions.py b/tests/distributions/test_expornential_distributions.py
index 1190e90d..cd0725a2 100644
--- a/tests/distributions/test_expornential_distributions.py
+++ b/tests/distributions/test_expornential_distributions.py
@@ -20,6 +20,6 @@ def nearly_eq(self, tensor1, tensor2):
         return abs(tensor1.item() - tensor2.item()) < 0.001
 
     def test_sample_mean(self):
-        rb = RelaxedBernoulli(var=['x'], temperature=torch.tensor(0.5), probs=torch.tensor([1 / 3., 2 / 3.]))
+        rb = RelaxedBernoulli(var=['x'], temperature=torch.tensor(0.5), probs=torch.tensor([0.5, 0.8]))
         with pytest.raises(NotImplementedError):
             rb.sample(sample_mean=True)
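Note: the snippet below is a minimal, standalone sketch (not part of the patches above; the
variable names and values are illustrative) of the behaviour that BernoulliTorchOld in
PATCH 1/3 restores. torch.distributions.Bernoulli rejects non-binary values in log_prob when
argument validation is on, whereas the binary_cross_entropy_with_logits form also accepts
"soft" targets in [0, 1], which is presumably what the older usage relied on (e.g. real-valued
pixel intensities as Bernoulli targets).

import torch
from torch.distributions import Bernoulli
from torch.distributions.utils import broadcast_all
from torch.nn.functional import binary_cross_entropy_with_logits


class BernoulliTorchOld(Bernoulli):
    """Same override as in PATCH 1/3: log_prob as element-wise -BCE, without support validation."""

    def log_prob(self, value):
        logits, value = broadcast_all(self.logits, value)
        return -binary_cross_entropy_with_logits(logits, value, reduction='none')


probs = torch.tensor([0.3, 0.7])    # illustrative parameters
soft_x = torch.tensor([0.2, 0.9])   # non-binary ("old usage") targets

# The relaxed subclass returns an element-wise negative BCE even for soft targets.
print(BernoulliTorchOld(probs=probs).log_prob(soft_x))

# Stock torch.distributions.Bernoulli validates that values lie in {0, 1}.
try:
    Bernoulli(probs=probs, validate_args=True).log_prob(soft_x)
except ValueError as err:
    print("plain Bernoulli rejects soft targets:", err)

CategoricalTorchOld in the same patch is analogous: it converts a one-hot style value back to
class indices with value.max(-1)[1] and evaluates the log-probability of those indices.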