Merge pull request #172 from masa-su/fix/new_bernoulli
Support old usage of bernoulli distribution
masa-su authored Mar 19, 2021
2 parents 57a1fd9 + 7c44419 commit a0d43a1
Showing 4 changed files with 21 additions and 7 deletions.
2 changes: 1 addition & 1 deletion .travis.yml
```diff
@@ -2,6 +2,6 @@ language: python
 python:
 - 3.6.5
 install:
-- pip install -e "."[test]
+- pip install -e ".[test]"
 script:
 - pytest
```
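A note on this one-character fix: in `pip install -e "."[test]`, only the dot is quoted, so the shell treats `[test]` as a bare glob; the resulting pattern `.[test]` can match a hidden file such as `.t` or `.s` and mangle the argument (and zsh aborts outright on no match). Quoting the whole requirement as `".[test]"` guarantees pip receives the extras specifier literally.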
22 changes: 18 additions & 4 deletions pixyz/distributions/exponential_distributions.py
```diff
@@ -9,6 +9,8 @@
 from torch.distributions import Beta as BetaTorch
 from torch.distributions import Laplace as LaplaceTorch
 from torch.distributions import Gamma as GammaTorch
+from torch.distributions.utils import broadcast_all
+from torch.nn.functional import binary_cross_entropy_with_logits
 
 from ..utils import get_dict_values, sum_samples
 from .distributions import DistributionBase
@@ -40,6 +42,12 @@ def has_reparam(self):
         return True
 
 
+class BernoulliTorchOld(BernoulliTorch):
+    def log_prob(self, value):
+        logits, value = broadcast_all(self.logits, value)
+        return -binary_cross_entropy_with_logits(logits, value, reduction='none')
+
+
 class Bernoulli(DistributionBase):
     """Bernoulli distribution parameterized by :attr:`probs`."""
     def __init__(self, var=['x'], cond_var=[], name='p', features_shape=torch.Size(), probs=None):
```
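For context on the `BernoulliTorchOld` shim: the stock `torch.distributions.Bernoulli.log_prob` runs `_validate_sample`, which rejects any value outside the support `{0, 1}` once argument validation is on by default (as of torch 1.8). The override keeps the same BCE-with-logits computation but skips the check, restoring the old pixyz behaviour of scoring continuous values in `[0, 1]`, e.g. pixel intensities in a VAE reconstruction term. A minimal, self-contained sketch (the probability and value below are illustrative):

```python
import torch
from torch.distributions import Bernoulli
from torch.distributions.utils import broadcast_all
from torch.nn.functional import binary_cross_entropy_with_logits


class BernoulliTorchOld(Bernoulli):
    def log_prob(self, value):
        # Same BCE-with-logits formula as the stock log_prob, minus the
        # _validate_sample call that restricts values to {0, 1}.
        logits, value = broadcast_all(self.logits, value)
        return -binary_cross_entropy_with_logits(logits, value, reduction='none')


p_old = BernoulliTorchOld(probs=torch.tensor([0.7]))
x = torch.tensor([0.3])   # continuous value -- the "old usage"
print(p_old.log_prob(x))  # ~tensor([-0.9498]); scored via BCE, no support check

p_new = Bernoulli(probs=torch.tensor([0.7]), validate_args=True)
# p_new.log_prob(x)  # raises ValueError: 0.3 is outside the {0, 1} support
```

For hard 0/1 values the override returns the same log-probabilities as upstream; only the validation behaviour differs.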
```diff
@@ -51,7 +59,7 @@ def params_keys(self):
 
     @property
     def distribution_torch_class(self):
-        return BernoulliTorch
+        return BernoulliTorchOld
 
     @property
     def distribution_name(self):
@@ -111,7 +119,7 @@ def set_dist(self, x_dict={}, batch_n=None, sampling=False, **kwargs):
             self._dist = self.distribution_torch_class(**params)
         else:
             hard_params_keys = ["probs"]
-            self._dist = BernoulliTorch(**get_dict_values(params, hard_params_keys, return_dict=True))
+            self._dist = BernoulliTorchOld(**get_dict_values(params, hard_params_keys, return_dict=True))
 
         # expand batch_n
         if batch_n:
```
```diff
@@ -176,6 +184,12 @@ def get_log_prob(self, x_dict, sum_features=True, feature_dims=None, **kwargs):
         return log_prob
 
 
+class CategoricalTorchOld(CategoricalTorch):
+    def log_prob(self, value):
+        indices = value.max(-1)[1]
+        return self._categorical.log_prob(indices)
+
+
 class Categorical(DistributionBase):
     """Categorical distribution parameterized by :attr:`probs`."""
     def __init__(self, var=['x'], cond_var=[], name='p', features_shape=torch.Size(), probs=None):
```
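`CategoricalTorchOld` does the analogous thing for one-hot input. Assuming `CategoricalTorch` aliases `torch.distributions.OneHotCategorical` (the `self._categorical` attribute used in the override exists only on that class), the stock `log_prob` validates that `value` is exactly one-hot, which rejects the relaxed (soft) one-hot vectors the old API accepted; the override reduces `value` to its argmax index and scores that directly. A sketch under that assumption, with illustrative tensors:

```python
import torch
from torch.distributions import OneHotCategorical


# Assumption: pixyz's CategoricalTorch is torch.distributions.OneHotCategorical.
class CategoricalTorchOld(OneHotCategorical):
    def log_prob(self, value):
        # Reduce a (possibly soft) one-hot vector to its argmax index and
        # score it with the inner Categorical, skipping support validation.
        indices = value.max(-1)[1]
        return self._categorical.log_prob(indices)


p = CategoricalTorchOld(probs=torch.tensor([0.2, 0.5, 0.3]))
soft = torch.tensor([0.1, 0.7, 0.2])  # relaxed sample, not exactly one-hot
print(p.log_prob(soft))               # tensor(-0.6931) == log(0.5)
# OneHotCategorical(probs=..., validate_args=True).log_prob(soft) would raise
```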
```diff
@@ -188,7 +202,7 @@ def params_keys(self):
 
     @property
     def distribution_torch_class(self):
-        return CategoricalTorch
+        return CategoricalTorchOld
 
     @property
     def distribution_name(self):
@@ -251,7 +265,7 @@ def set_dist(self, x_dict={}, batch_n=None, sampling=False, **kwargs):
             self._dist = self.distribution_torch_class(**params)
         else:
             hard_params_keys = ["probs"]
-            self._dist = BernoulliTorch(**get_dict_values(params, hard_params_keys, return_dict=True))
+            self._dist = BernoulliTorchOld(**get_dict_values(params, hard_params_keys, return_dict=True))
 
         # expand batch_n
         if batch_n:
```
2 changes: 1 addition & 1 deletion tests/distributions/test_distribution.py
```diff
@@ -120,7 +120,7 @@ def test_unknown_option(self, dist):
 
 class TestMixtureDistribution:
     def test_sample_mean(self):
-        dist = MixtureModel([Normal(loc=0, scale=1), Normal(loc=1, scale=1)], Categorical(probs=torch.tensor([1, 2])))
+        dist = MixtureModel([Normal(loc=0, scale=1), Normal(loc=1, scale=1)], Categorical(probs=torch.tensor([1., 2.])))
         assert dist.sample(sample_mean=True)['x'] == torch.ones(1)
```
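My reading of this test fix: `torch.tensor([1, 2])` is an `int64` tensor, and newer torch releases are stricter about distribution arguments, so the mixture weights are now written as floats; `torch.tensor([1., 2.])` gives `float32` weights that normalize cleanly to 1/3 and 2/3.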


2 changes: 1 addition & 1 deletion tests/distributions/test_expornential_distributions.py
```diff
@@ -20,6 +20,6 @@ def nearly_eq(self, tensor1, tensor2):
         return abs(tensor1.item() - tensor2.item()) < 0.001
 
     def test_sample_mean(self):
-        rb = RelaxedBernoulli(var=['x'], temperature=torch.tensor(0.5), probs=torch.tensor([1, 2]))
+        rb = RelaxedBernoulli(var=['x'], temperature=torch.tensor(0.5), probs=torch.tensor([0.5, 0.8]))
         with pytest.raises(NotImplementedError):
             rb.sample(sample_mean=True)
```
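This fix is about validity rather than dtype: Bernoulli probabilities must lie in `[0, 1]`, so `probs=torch.tensor([1, 2])` fails once argument validation is on by default (torch 1.8); the in-range `[0.5, 0.8]` lets the test reach the intended `NotImplementedError` from `sample(sample_mean=True)`.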
