Skip to content

Commit

Permalink
Rename _positive_ints_like and add test
Browse files Browse the repository at this point in the history
  • Loading branch information
dylanhmorris committed Feb 19, 2025
1 parent 35ba838 commit e9d3001
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 8 deletions.
14 changes: 7 additions & 7 deletions pyrenew/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,10 @@
from pyrenew.distutil import validate_discrete_dist_vector


def _arange_like(vec: ArrayLike) -> jnp.ndarray:
def _positive_ints_like(vec: ArrayLike) -> jnp.ndarray:
"""
Given an array of size n, return the array
[1, ... n].
Given an array of size n, return the 1D Jax array
``[1, ... n]``.
Parameters
----------
Expand All @@ -25,7 +25,7 @@ def _arange_like(vec: ArrayLike) -> jnp.ndarray:
Returns
-------
jnp.ndarray
The resulting array.
The resulting array ``[1, ..., n]``.
"""
return jnp.arange(1, jnp.size(vec) + 1)

Expand All @@ -49,7 +49,7 @@ def _neg_MGF(r: float, w: ArrayLike) -> float:
The value of the negative MGF evaluated at ``r``
and ``w``.
"""
return jnp.sum(w * jnp.exp(-r * _arange_like(w)))
return jnp.sum(w * jnp.exp(-r * _positive_ints_like(w)))


def _neg_MGF_del_r(r: float, w: ArrayLike) -> float:
Expand All @@ -72,7 +72,7 @@ def _neg_MGF_del_r(r: float, w: ArrayLike) -> float:
The value of the partial derivative evaluated at ``r``
and ``w``.
"""
t_vec = _arange_like(w)
t_vec = _positive_ints_like(w)
return -jnp.sum(w * t_vec * jnp.exp(-r * t_vec))


Expand Down Expand Up @@ -101,7 +101,7 @@ def r_approx_from_R(R: float, G: ArrayLike, n_newton_steps: int) -> ArrayLike:
float
The approximate value of ``r``.
"""
mean_gi = jnp.dot(G, _arange_like(G))
mean_gi = jnp.dot(G, _positive_ints_like(G))
init_r = (R - 1) / (R * mean_gi)

def _r_next(_iter, r): # numpydoc ignore=GL08
Expand Down
24 changes: 23 additions & 1 deletion test/test_math.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,38 @@
Unit tests for the pyrenew.math module.
"""

import jax.numpy as jnp
import numpy as np
import pytest
from numpy.random import RandomState
from numpy.testing import assert_almost_equal, assert_array_almost_equal
from numpy.testing import (
assert_almost_equal,
assert_array_almost_equal,
assert_array_equal,
)

import pyrenew.math as pmath

rng = RandomState(5)


@pytest.mark.parametrize(
"arr, arr_len",
[
([3, 1, 2], 3),
(np.ones(50), 50),
((jnp.nan * jnp.ones(250)).reshape((50, -1)), 250),
],
)
def test_positive_ints_like(arr, arr_len):
"""
Test the _positive_ints_like helper function.
"""
result = pmath._positive_ints_like(arr)
expected = jnp.arange(1, arr_len + 1)
assert_array_equal(result, expected)


@pytest.mark.parametrize(
"R, G",
[
Expand Down

0 comments on commit e9d3001

Please sign in to comment.