Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rename / remove dense and sparse variables and refactor #103

Merged
merged 16 commits into from
Feb 13, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 15 additions & 15 deletions .github/workflows/main.yml
Original file line number Diff line number Diff line change
Expand Up @@ -57,18 +57,18 @@ jobs:
- name: Run mypy
shell: bash {0}
run: pixi run mypy
run-explanation-notebooks:
name: Run explanation notebooks on Python 3.12
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: prefix-dev/[email protected]
with:
pixi-version: v0.40.3
cache: true
cache-write: ${{ github.event_name == 'push' && github.ref_name == 'main' }}
environments: test-cpu
frozen: true
- name: Run explanation notebooks
shell: bash {0}
run: pixi run -e test-cpu explanation-notebooks
# run-explanation-notebooks:
# name: Run explanation notebooks on Python 3.12
# runs-on: ubuntu-latest
# steps:
# - uses: actions/checkout@v4
# - uses: prefix-dev/[email protected]
# with:
# pixi-version: v0.40.3
# cache: true
# cache-write: ${{ github.event_name == 'push' && github.ref_name == 'main' }}
# environments: test-cpu
# frozen: true
# - name: Run explanation notebooks
# shell: bash {0}
# run: pixi run -e test-cpu explanation-notebooks
4 changes: 2 additions & 2 deletions src/lcm/argmax.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def _move_axes_to_back(a: Array, axes: tuple[int, ...]) -> Array:
axes (tuple): Axes to move to the back.

Returns:
jax.numpy.ndarray: Array a with shifted axes.
jax.Array: Array a with shifted axes.

"""
front_axes = sorted(set(range(a.ndim)) - set(axes))
Expand All @@ -88,7 +88,7 @@ def _flatten_last_n_axes(a: Array, n: int) -> Array:
n (int): Number of axes to flatten.

Returns:
jax.numpy.ndarray: Array a with flattened last n axes.
jax.Array: Array a with flattened last n axes.

"""
return a.reshape(*a.shape[:-n], -1)
40 changes: 16 additions & 24 deletions src/lcm/discrete_problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def get_solve_discrete_problem(
if is_last_period:
variable_info = variable_info.query("~is_auxiliary")

choice_axes = _determine_dense_discrete_choice_axes(variable_info)
choice_axes = _determine_discrete_choice_axes(variable_info)

if random_utility_shock_type == ShockType.NONE:
func = _solve_discrete_problem_no_shocks
Expand All @@ -71,7 +71,7 @@ def get_solve_discrete_problem(

def _solve_discrete_problem_no_shocks(
cc_values: Array,
choice_axes: tuple[int, ...] | None,
choice_axes: tuple[int, ...],
params: ParamsDict, # noqa: ARG001
) -> Array:
"""Reduce conditional continuation values over discrete choices.
Expand All @@ -90,11 +90,7 @@ def _solve_discrete_problem_no_shocks(
if choice_segments is not None.

"""
out = cc_values
if choice_axes is not None:
out = out.max(axis=choice_axes)

return out
return cc_values.max(axis=choice_axes)


# ======================================================================================
Expand All @@ -108,18 +104,18 @@ def _calculate_emax_extreme_value_shocks(values, choice_axes, params):
"""Aggregate conditional continuation values over discrete choices.

Args:
values (jax.numpy.ndarray): Multidimensional jax array with conditional
values (jax.Array): Multidimensional jax array with conditional
continuation values.
choice_axes (int or tuple): Int or tuple of int, specifying which axes in
values correspond to dense choice variables.
values correspond to the discrete choice variables.
choice_segments (dict): Dictionary with the entries "segment_ids"
and "num_segments". segment_ids are a 1d integer array that partitions the
first dimension of values into choice sets over which we need to aggregate.
"num_segments" is the number of choice sets.
params (dict): Params dict that contains the shock_scale if necessary.

Returns:
jax.numpy.ndarray: Multidimensional jax array with aggregated continuation
jax.Array: Multidimensional jax array with aggregated continuation
values. Has fewer dimensions than values if choice_axes is not None and
is shorter in the first dimension if choice_segments is not None.

Expand All @@ -137,26 +133,22 @@ def _calculate_emax_extreme_value_shocks(values, choice_axes, params):
# ======================================================================================


def _determine_dense_discrete_choice_axes(
def _determine_discrete_choice_axes(
variable_info: pd.DataFrame,
) -> tuple[int, ...] | None:
) -> tuple[int, ...]:
"""Get axes of a state-choice-space that correspond to discrete choices.

Args:
variable_info: DataFrame with information about the variables.

Returns:
tuple[int, ...] | None: A tuple of indices representing the axes' positions in
the value function that correspond to discrete choices. Returns None if
there are no discrete choice axes.
A tuple of indices representing the axes' positions in the value function that
correspond to discrete choices.

"""
# List of dense variables excluding continuous choice variables.
axes = variable_info.query("is_state | is_discrete").index.tolist()

choice_vars = set(variable_info.query("is_choice").index.tolist())

choice_indices = tuple(i for i, ax in enumerate(axes) if ax in choice_vars)

# Return None if there are no discrete choice axes, otherwise return the indices.
return choice_indices if choice_indices else None
discrete_choice_vars = set(
variable_info.query("is_choice & is_discrete").index.tolist()
)
return tuple(
i for i, ax in enumerate(variable_info.index) if ax in discrete_choice_vars
)
143 changes: 64 additions & 79 deletions src/lcm/dispatchers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,88 +5,76 @@
from jax import Array, vmap

from lcm.functools import allow_args, allow_only_kwargs
from lcm.utils import find_duplicates

F = TypeVar("F", bound=Callable[..., Array])


def spacemap(
func: F,
dense_vars: list[str],
sparse_vars: list[str] | None = None,
product_vars: tuple[str, ...],
combination_vars: tuple[str, ...],
) -> F:
"""Apply vmap such that func is evaluated on a space of dense and sparse variables.

This is achieved by applying _base_productmap for all dense variables and vmap_1d
for the sparse variables.
"""Apply vmap such that func can be evaluated on product and combination variables.

Product variables are used to create a Cartesian product of possible values. I.e.,
for each product variable, we create a new leading dimension in the output object,
with the size of the dimension being the number of possible values in the grid. The
i-th entries of the combination variables correspond to one valid combination. For
the combination variables, a single dimension is thus added to the output object,
with the size of the dimension being the number of possible combinations. This means
that all combination variables must have the same size (e.g., in the simulation the
states act as combination variables, and their size equals the number of
simulations).

spacemap preserves the function signature and allows the function to be called with
keyword arguments.

Args:
func: The function to be dispatched.
dense_vars: Names of the dense variables, i.e. those that are stored as arrays
of possible values in the grid.
sparse_vars: Names of the sparse variables, i.e. those that are stored as arrays
of possible combinations of variables in the grid.
put_dense_first: Whether the dense or sparse dimensions should come first in the
output of the dispatched function.

product_vars: Names of the product variables, i.e. those that are stored as
arrays of possible values in the grid, over which we create a Cartesian
product.
combination_vars: Names of the combination variables, i.e. those that are
stored as arrays of possible combinations.

Returns:
A callable with the same arguments as func (but with an additional leading
dimension) that returns a jax.numpy.ndarray or pytree of arrays. If ``func``
returns a scalar, the dispatched function returns a jax.numpy.ndarray with 1
jax.numpy.ndarray with k + 1 dimensions, where k is the length of ``dense_vars``
and the additional dimension corresponds to the ``sparse_vars``. The order of
the dimensions is determined by the order of ``dense_vars`` as well as the
``put_dense_first`` argument. If the output of ``func`` is a jax pytree, the
usual jax behavior applies, i.e. the leading dimensions of all arrays in the
pytree are as described above but there might be additional dimensions.
dimension) that returns a jax.Array or pytree of arrays. If `func` returns a
scalar, the dispatched function returns a jax.Array with k + 1 dimensions, where
k is the length of `product_vars` and the additional dimension corresponds to
the `combination_vars`. The order of the dimensions is determined by the order
of `product_vars`. If the output of `func` is a jax pytree, the usual jax
behavior applies, i.e. the leading dimensions of all arrays in the pytree are as
described above but there might be additional dimensions.

"""
# Check inputs and prepare function
# ==================================================================================
duplicates = {v for v in dense_vars if dense_vars.count(v) > 1}
if duplicates:
raise ValueError(
f"Same argument provided more than once in dense variables: {duplicates}",
if duplicates := find_duplicates(product_vars, combination_vars):
msg = (
"Same argument provided more than once in product variables or combination "
f"variables, or is present in both: {duplicates}"
)
raise ValueError(msg)

if sparse_vars:
overlap = set(dense_vars).intersection(sparse_vars)
if overlap:
raise ValueError(
f"Dense and sparse variables must be disjoint. Overlap: {overlap}",
)

duplicates = {v for v in sparse_vars if sparse_vars.count(v) > 1}
if duplicates:
raise ValueError(
"Same argument provided more than once in sparse variables: "
f"{duplicates}",
)

# jax.vmap cannot deal with keyword-only arguments
func = allow_args(func)

# Apply vmap_1d for sparse and _base_productmap for dense variables
# ==================================================================================
if not sparse_vars:
vmapped = _base_productmap(func, dense_vars)
else:
vmapped = _base_productmap(func, dense_vars)
vmapped = vmap_1d(vmapped, variables=sparse_vars, callable_with="only_args")
func_callable_with_args = allow_args(func)

vmapped = _base_productmap(func_callable_with_args, product_vars)

if combination_vars:
vmapped = vmap_1d(
vmapped, variables=combination_vars, callable_with="only_args"
)

# This raises a mypy error but is perfectly fine to do. See
# https://github.com/python/mypy/issues/12472
vmapped.__signature__ = inspect.signature(func) # type: ignore[attr-defined]
vmapped.__signature__ = inspect.signature(func_callable_with_args) # type: ignore[attr-defined]

return allow_only_kwargs(vmapped)


def vmap_1d(
func: F,
variables: list[str],
variables: tuple[str, ...],
*,
callable_with: Literal["only_args", "only_kwargs"] = "only_kwargs",
) -> F:
Expand All @@ -98,26 +86,25 @@ def vmap_1d(

Args:
func: The function to be dispatched.
variables: List with names of arguments that over which we map.
variables: Tuple with names of arguments over which we map.
callable_with: Whether to apply the allow_kwargs decorator to the dispatched
function. If "only_args", the returned function can only be called with
positional arguments. If "only_kwargs", the returned function can only be
called with keyword arguments.

Returns:
A callable with the same arguments as func (but with an additional leading
dimension) that returns a jax.numpy.ndarray or pytree of arrays. If ``func``
returns a scalar, the dispatched function returns a jax.numpy.ndarray with 1
jax.numpy.ndarray with 1 dimension and length k, where k is the length of one of
the mapped inputs in ``variables``. The order of the dimensions is determined by
the order of ``variables`` which can be different to the order of ``funcs``
arguments. If the output of ``func`` is a jax pytree, the usual jax behavior
dimension) that returns a jax.Array or pytree of arrays. If `func`
returns a scalar, the dispatched function returns a
jax.Array with 1 dimension and length k, where k is the length of one of
the mapped inputs in `variables`. The order of the dimensions is determined by
the order of `variables` which can be different to the order of `func`'s
arguments. If the output of `func` is a jax pytree, the usual jax behavior
applies, i.e. the leading dimensions of all arrays in the pytree are as
described above but there might be additional dimensions.

"""
duplicates = {v for v in variables if variables.count(v) > 1}
if duplicates:
if duplicates := find_duplicates(variables):
raise ValueError(
f"Same argument provided more than once in variables: {duplicates}",
)
Expand Down Expand Up @@ -154,7 +141,7 @@ def vmap_1d(
return out


def productmap(func: F, variables: list[str]) -> F:
def productmap(func: F, variables: tuple[str, ...]) -> F:
"""Apply vmap such that func is evaluated on the Cartesian product of variables.

This is achieved by an iterative application of vmap.
Expand All @@ -164,50 +151,48 @@ def productmap(func: F, variables: list[str]) -> F:

Args:
func: The function to be dispatched.
variables: List with names of arguments that over which the Cartesian product
variables: Tuple with names of arguments over which the Cartesian product
should be formed.

Returns:
A callable with the same arguments as func (but with an additional leading
dimension) that returns a jax.numpy.ndarray or pytree of arrays. If ``func``
returns a scalar, the dispatched function returns a jax.numpy.ndarray with k
dimensions, where k is the length of ``variables``. The order of the dimensions
is determined by the order of ``variables`` which can be different to the order
of ``funcs`` arguments. If the output of ``func`` is a jax pytree, the usual jax
dimension) that returns a jax.Array or pytree of arrays. If `func`
returns a scalar, the dispatched function returns a jax.Array with k
dimensions, where k is the length of `variables`. The order of the dimensions
is determined by the order of `variables` which can be different to the order
of `func`'s arguments. If the output of `func` is a jax pytree, the usual jax
behavior applies, i.e. the leading dimensions of all arrays in the pytree are as
described above but there might be additional dimensions.

"""
func = allow_args(func) # jax.vmap cannot deal with keyword-only arguments

duplicates = {v for v in variables if variables.count(v) > 1}
if duplicates:
if duplicates := find_duplicates(variables):
raise ValueError(
f"Same argument provided more than once in variables: {duplicates}",
)

signature = inspect.signature(func)
vmapped = _base_productmap(func, variables)
func_callable_with_args = allow_args(func)

vmapped = _base_productmap(func_callable_with_args, variables)

# This raises a mypy error but is perfectly fine to do. See
# https://github.com/python/mypy/issues/12472
vmapped.__signature__ = signature # type: ignore[attr-defined]
vmapped.__signature__ = inspect.signature(func_callable_with_args) # type: ignore[attr-defined]

return allow_only_kwargs(vmapped)


def _base_productmap(func: F, product_axes: list[str]) -> F:
def _base_productmap(func: F, product_axes: tuple[str, ...]) -> F:
"""Map func over the Cartesian product of product_axes.

Like vmap, this function does not preserve the function signature and does not allow
the function to be called with keyword arguments.

Args:
func: The function to be dispatched. Cannot have keyword-only arguments.
product_axes: List with names of arguments over which we apply vmap.
product_axes: Tuple with names of arguments over which we apply vmap.

Returns:
A callable with the same arguments as func. See ``product_map`` for details.
A callable with the same arguments as func. See `productmap` for details.

"""
signature = inspect.signature(func)
Expand Down
Loading
Loading