Skip to content

Commit

Permalink
Switch to enums for distributions instead of strings
Browse files Browse the repository at this point in the history
  • Loading branch information
Sosnowsky committed Jan 29, 2025
1 parent 6683351 commit f91d773
Show file tree
Hide file tree
Showing 8 changed files with 127 additions and 86 deletions.
1 change: 1 addition & 0 deletions blobmodel/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,4 @@
from .stochasticality import BlobFactory, DefaultBlobFactory
from .geometry import Geometry
from .blob_shape import AbstractBlobShape, BlobShapeImpl
from .distributions import Distribution
77 changes: 77 additions & 0 deletions blobmodel/distributions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
from enum import Enum
from abc import ABC, abstractmethod
import numpy as np


class Distribution(Enum):
"""Enum class used to identify distribution functions."""

deg = 1
zeros = 2
exp = 3
gamma = 4
normal = 5
uniform = 6
rayleigh = 7


class AbstractDistribution(ABC):
"""Abstract class used to represent and implement a distribution function."""

@abstractmethod
def sample(
self,
num_blobs: int,
**kwargs,
) -> np.ndarray:
raise NotImplementedError


def _sample_deg(num_blobs, **kwargs):
free_param = kwargs["free_param"]
return free_param * np.ones(num_blobs).astype(np.float64)


def _sample_zeros(num_blobs, **kwargs):
return np.zeros(num_blobs).astype(np.float64)


def _sample_exp(num_blobs, **kwargs):
free_param = kwargs["free_param"]
return np.random.exponential(scale=free_param, size=num_blobs).astype(np.float64)


def _sample_gamma(num_blobs, **kwargs):
free_param = kwargs["free_param"]
return np.random.gamma(
shape=free_param, scale=1 / free_param, size=num_blobs
).astype(np.float64)


def _sample_normal(num_blobs, **kwargs):
free_param = kwargs["free_param"]
return np.random.normal(loc=0, scale=free_param, size=num_blobs).astype(np.float64)


def _sample_uniform(num_blobs, **kwargs):
free_param = kwargs["free_param"]
return np.random.uniform(
low=1 - free_param / 2, high=1 + free_param / 2, size=num_blobs
).astype(np.float64)


def _sample_rayleigh(num_blobs, **kwargs):
return np.random.rayleigh(scale=np.sqrt(2.0 / np.pi), size=num_blobs).astype(
np.float64
)


DISTRIBUTIONS = {
Distribution.deg: _sample_deg,
Distribution.zeros: _sample_zeros,
Distribution.exp: _sample_exp,
Distribution.gamma: _sample_gamma,
Distribution.normal: _sample_normal,
Distribution.uniform: _sample_uniform,
Distribution.rayleigh: _sample_rayleigh,
}
99 changes: 28 additions & 71 deletions blobmodel/stochasticality.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from typing import Any, List, Union, Dict
from .blobs import Blob
from .blob_shape import AbstractBlobShape
from .distributions import *


class BlobFactory(ABC):
Expand Down Expand Up @@ -40,13 +41,13 @@ class DefaultBlobFactory(BlobFactory):

def __init__(
self,
A_dist: str = "exp",
wx_dist: str = "deg",
wy_dist: str = "deg",
vx_dist: str = "deg",
vy_dist: str = "deg",
spx_dist: str = "deg",
spy_dist: str = "deg",
A_dist: Distribution = Distribution.exp,
wx_dist: Distribution = Distribution.deg,
wy_dist: Distribution = Distribution.deg,
vx_dist: Distribution = Distribution.deg,
vy_dist: Distribution = Distribution.deg,
spx_dist: Distribution = Distribution.deg,
spy_dist: Distribution = Distribution.deg,
A_parameter: float = 1.0,
wx_parameter: float = 1.0,
wy_parameter: float = 1.0,
Expand All @@ -64,20 +65,20 @@ def __init__(
Parameters
----------
A_dist : str, optional
Distribution type for amplitude, by default "exp"
wx_dist : str, optional
Distribution type for width in the x-direction, by default "deg"
wy_dist : str, optional
Distribution type for width in the y-direction, by default "deg"
vx_dist : str, optional
Distribution type for velocity in the x-direction, by default "deg"
vy_dist : str, optional
Distribution type for velocity in the y-direction, by default "deg"
spx_dist : str, optional
Distribution type for shape parameter in the x-direction, by default "deg"
spy_dist : str, optional
Distribution type for shape parameter in the y-direction, by default "deg"
A_dist : Distribution, optional
Distribution type for amplitude, by default "Distribution.exp"
wx_dist : Distribution, optional
Distribution type for width in the x-direction, by default "Distribution.deg"
wy_dist : Distribution, optional
Distribution type for width in the y-direction, by default "Distribution.deg"
vx_dist : Distribution, optional
Distribution type for velocity in the x-direction, by default "Distribution.deg"
vy_dist : Distribution, optional
Distribution type for velocity in the y-direction, by default "Distribution.deg"
spx_dist : Distribution, optional
Distribution type for shape parameter in the x-direction, by default "Distribution.deg"
spy_dist : Distribution, optional
Distribution type for shape parameter in the y-direction, by default "Distribution.deg"
A_parameter : float, optional
Free parameter for the amplitude distribution, by default 1.0
wx_parameter : float, optional
Expand Down Expand Up @@ -126,54 +127,10 @@ def __init__(
self.theta_setter = lambda: 0

def _draw_random_variables(
self,
dist_type: str,
free_parameter: float,
num_blobs: int,
self, dist: Distribution, free_parameter: float, num_blobs: int
) -> np.ndarray:
"""
Draws random variables from a specified distribution.
Parameters
----------
dist_type : str
Type of distribution.
free_parameter : float
Free parameter for the distribution.
num_blobs : int
Number of random variables to draw.
Returns
-------
NDArray[Any, Float[64]]
Array of random variables drawn from the specified distribution.
"""
if dist_type == "exp":
return np.random.exponential(scale=1, size=num_blobs).astype(np.float64)
elif dist_type == "gamma":
return np.random.gamma(
shape=free_parameter, scale=1 / free_parameter, size=num_blobs
).astype(np.float64)
elif dist_type == "normal":
return np.random.normal(loc=0, scale=free_parameter, size=num_blobs).astype(
np.float64
)
elif dist_type == "uniform":
return np.random.uniform(
low=1 - free_parameter / 2, high=1 + free_parameter / 2, size=num_blobs
).astype(np.float64)
elif dist_type == "ray":
return np.random.rayleigh(
scale=np.sqrt(2.0 / np.pi), size=num_blobs
).astype(np.float64)
elif dist_type == "deg":
return free_parameter * np.ones(num_blobs).astype(np.float64)
elif dist_type == "zeros":
return np.zeros(num_blobs).astype(np.float64)
else:
raise NotImplementedError(
self.__class__.__name__ + ".distribution function not implemented"
)
"""Draws random variables from a specified distribution."""
return DISTRIBUTIONS[dist](num_blobs, free_param=free_parameter)

def sample_blobs(
self,
Expand Down Expand Up @@ -205,9 +162,9 @@ def sample_blobs(
List of Blob objects generated for the Model.
"""
amps = self._draw_random_variables(
dist_type=self.amplitude_dist,
free_parameter=self.amplitude_parameter,
num_blobs=num_blobs,
self.amplitude_dist,
self.amplitude_parameter,
num_blobs,
)
wxs = self._draw_random_variables(
self.width_x_dist, self.width_x_parameter, num_blobs
Expand Down
4 changes: 2 additions & 2 deletions tests/test_analytical.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from blobmodel import Model, DefaultBlobFactory
from blobmodel import Model, DefaultBlobFactory, Distribution
import xarray as xr
import numpy as np


# use DefaultBlobFactory to define distribution functions fo random variables
bf = DefaultBlobFactory(A_dist="deg", wx_dist="deg", vx_dist="deg", vy_dist="zeros")
bf = DefaultBlobFactory(A_dist=Distribution.deg, vy_dist=Distribution.zeros)

tmp = Model(
Nx=100,
Expand Down
6 changes: 3 additions & 3 deletions tests/test_blob.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from blobmodel import Model, DefaultBlobFactory, Blob, BlobShapeImpl
from blobmodel import Model, DefaultBlobFactory, Blob, BlobShapeImpl, Distribution
import numpy as np
from unittest.mock import MagicMock

Expand Down Expand Up @@ -204,7 +204,7 @@ def test_kwargs():


def test_get_blobs():
bf = DefaultBlobFactory(A_dist="deg", wx_dist="deg", vx_dist="deg", vy_dist="deg")
bf = DefaultBlobFactory(A_dist=Distribution.deg)
one_blob = Model(
Nx=100,
Ny=100,
Expand All @@ -218,6 +218,6 @@ def test_get_blobs():
num_blobs=3,
blob_factory=bf,
)
ds = one_blob.make_realization()
one_blob.make_realization()
blob_list = one_blob.get_blobs()
assert len(blob_list) == 3
4 changes: 2 additions & 2 deletions tests/test_changing_t_drain.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from blobmodel import Model, DefaultBlobFactory
from blobmodel import Model, DefaultBlobFactory, Distribution
import xarray as xr
import numpy as np


# use DefaultBlobFactory to define distribution functions fo random variables
bf = DefaultBlobFactory(A_dist="deg", wx_dist="deg", vx_dist="deg", vy_dist="zeros")
bf = DefaultBlobFactory(A_dist=Distribution.deg, vy_dist=Distribution.zeros)

t_drain = np.linspace(2, 1, 10)

Expand Down
4 changes: 2 additions & 2 deletions tests/test_one_dim.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import pytest
from blobmodel import Model, DefaultBlobFactory
from blobmodel import Model, DefaultBlobFactory, Distribution
import numpy as np


# use DefaultBlobFactory to define distribution functions fo random variables
bf = DefaultBlobFactory(A_dist="deg", wx_dist="deg", vx_dist="deg", vy_dist="zeros")
bf = DefaultBlobFactory(A_dist=Distribution.deg, vy_dist=Distribution.zeros)

one_dim_model = Model(
Nx=100,
Expand Down
18 changes: 12 additions & 6 deletions tests/test_stochasticality.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,37 @@
import pytest
from blobmodel import DefaultBlobFactory, BlobShapeImpl, BlobFactory
from blobmodel import DefaultBlobFactory, BlobShapeImpl, BlobFactory, Distribution


def test_mean_of_distribution():
bf = DefaultBlobFactory()
distributions_mean_1 = ["exp", "gamma", "uniform", "ray", "deg"]
distributions_mean_0 = ["normal", "zeros"]
distributions_mean_1 = [
Distribution.exp,
Distribution.gamma,
Distribution.uniform,
Distribution.rayleigh,
Distribution.deg,
]
distributions_mean_0 = [Distribution.normal, Distribution.zeros]

for dist in distributions_mean_1:
tmp = bf._draw_random_variables(
dist_type=dist,
dist=dist,
free_parameter=1,
num_blobs=10000,
)
assert 0.95 <= tmp.mean() <= 1.05

for dist in distributions_mean_0:
tmp = bf._draw_random_variables(
dist_type=dist,
dist=dist,
free_parameter=1,
num_blobs=10000,
)
assert -0.05 <= tmp.mean() <= 0.05


def test_not_implemented_distribution():
with pytest.raises(NotImplementedError):
with pytest.raises(KeyError):
bf = DefaultBlobFactory(A_dist="something_different")
bf.sample_blobs(1, 1, 1, BlobShapeImpl("gauss"), 1)

Expand Down

0 comments on commit f91d773

Please sign in to comment.