Skip to content
This repository has been archived by the owner on Dec 18, 2023. It is now read-only.

[WIP] Pytest: Split & parametrise model/inference (intra-module) #1809

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions .pyre_configuration
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@
"src/beanmachine/ppl/compiler/runtime.py",
"src/beanmachine/ppl/inference/base_inference.py",
"src/beanmachine/ppl/inference/bmg_inference.py",
"src/beanmachine/ppl/testlib/abstract_conjugate.py",
"src/beanmachine/ppl/testlib/hypothesis_testing.py",
"src/beanmachine/ppl/experimental/torch_jit_backend.py",
"src/beanmachine/ppl/diagnostics/tools/utils/diagnostic_tool_base.py"
],
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ If you would like to run the builtin unit tests:

```bash
python -m pip install "beanmachine[test]"
pytest .
pytest
```

## License
Expand Down
146 changes: 146 additions & 0 deletions src/beanmachine/ppl/examples/conjugate_models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from abc import ABC, abstractmethod

import beanmachine.ppl as bm
import torch
import torch.distributions as dist
from torch import Tensor


class ConjugateModel(ABC):
"""
The Bean Machine models in this module are examples of conjugacy. Conjugacy
means the posterior will also be in the same family as the prior. The random
variable names theta and x follow the typical presentation of the conjugate
prior relation in the form of p(theta|x) = p(x|theta) * p(theta)/p(x).

See:
https://en.wikipedia.org/wiki/Conjugate_prior
"""

x_dim = 0
"""
Number of indices in likelihood.
"""

@abstractmethod
def theta(self) -> dist.Distribution:
"""
Prior of a conjugate model.
"""
pass

@abstractmethod
def x(self, *args) -> dist.Distribution:
"""
Likelihood of a conjugate model.
"""
pass


class BetaBernoulliModel(ConjugateModel):
x_dim = 1

def __init__(self, alpha: Tensor, beta: Tensor) -> None:
self.alpha_ = alpha
self.beta_ = beta

@bm.random_variable
def theta(self) -> dist.Distribution:
return dist.Beta(self.alpha_, self.beta_)

@bm.random_variable
def x(self, i: int) -> dist.Distribution:
return dist.Bernoulli(self.theta())


class BetaBinomialModel(ConjugateModel):
def __init__(self, alpha: Tensor, beta: Tensor, n: Tensor) -> None:
self.alpha_ = alpha
self.beta_ = beta
self.n_ = n

@bm.random_variable
def theta(self) -> dist.Distribution:
return dist.Beta(self.alpha_, self.beta_)

@bm.random_variable
def x(self) -> dist.Distribution:
return dist.Binomial(self.n_, self.theta())


class CategoricalDirichletModel(ConjugateModel):
def __init__(self, alpha: Tensor) -> None:
self.alpha_ = alpha

@bm.random_variable
def theta(self) -> dist.Distribution:
return dist.Dirichlet(self.alpha_)

@bm.random_variable
def x(self) -> dist.Distribution:
return dist.Categorical(self.theta())


class GammaGammaModel(ConjugateModel):
def __init__(self, shape: Tensor, rate: Tensor, alpha: Tensor) -> None:
self.shape_ = shape
self.rate_ = rate
self.alpha_ = alpha

@bm.random_variable
def theta(self) -> dist.Distribution:
return dist.Gamma(self.shape_, self.rate_)

@bm.random_variable
def x(self) -> dist.Distribution:
return dist.Gamma(self.alpha_, self.theta())


class GammaNormalModel(ConjugateModel):
def __init__(self, shape: Tensor, rate: Tensor, mu: Tensor) -> None:
self.shape_ = shape
self.rate_ = rate
self.mu_ = mu

@bm.random_variable
def theta(self) -> dist.Distribution:
return dist.Gamma(self.shape_, self.rate_)

@bm.random_variable
def x(self) -> dist.Distribution:
return dist.Normal(self.mu_, torch.tensor(1) / torch.sqrt(self.theta()))


class NormalNormalModel(ConjugateModel):
def __init__(self, mu: Tensor, sigma: Tensor, std: Tensor) -> None:
self.mu = mu
self.sigma = sigma
self.std = std

@bm.random_variable
def theta(self) -> dist.Distribution:
return dist.Normal(self.mu, self.sigma)

@bm.random_variable
def x(self) -> dist.Distribution:
return dist.Normal(self.theta(), self.std)


class MV_NormalNormalModel(ConjugateModel):
def __init__(self, mu, sigma, std) -> None:
self.mu = mu
self.sigma = sigma
self.std = std

@bm.random_variable
def theta(self):
return dist.MultivariateNormal(self.mu, self.sigma)

@bm.random_variable
def x(self):
return dist.MultivariateNormal(self.theta(), self.std)
21 changes: 0 additions & 21 deletions src/beanmachine/ppl/examples/conjugate_models/__init__.py

This file was deleted.

22 changes: 0 additions & 22 deletions src/beanmachine/ppl/examples/conjugate_models/beta_bernoulli.py

This file was deleted.

34 changes: 0 additions & 34 deletions src/beanmachine/ppl/examples/conjugate_models/beta_binomial.py

This file was deleted.

This file was deleted.

23 changes: 0 additions & 23 deletions src/beanmachine/ppl/examples/conjugate_models/gamma_gamma.py

This file was deleted.

25 changes: 0 additions & 25 deletions src/beanmachine/ppl/examples/conjugate_models/gamma_normal.py

This file was deleted.

23 changes: 0 additions & 23 deletions src/beanmachine/ppl/examples/conjugate_models/normal_normal.py

This file was deleted.

Loading