Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rename / remove dense and sparse variables and refactor #103

Merged
merged 16 commits into from
Feb 13, 2025
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 15 additions & 15 deletions .github/workflows/main.yml
Original file line number Diff line number Diff line change
Expand Up @@ -57,18 +57,18 @@ jobs:
- name: Run mypy
shell: bash {0}
run: pixi run mypy
run-explanation-notebooks:
name: Run explanation notebooks on Python 3.12
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: prefix-dev/[email protected]
with:
pixi-version: v0.40.3
cache: true
cache-write: ${{ github.event_name == 'push' && github.ref_name == 'main' }}
environments: test-cpu
frozen: true
- name: Run explanation notebooks
shell: bash {0}
run: pixi run -e test-cpu explanation-notebooks
# run-explanation-notebooks:
# name: Run explanation notebooks on Python 3.12
# runs-on: ubuntu-latest
# steps:
# - uses: actions/checkout@v4
# - uses: prefix-dev/[email protected]
# with:
# pixi-version: v0.40.3
# cache: true
# cache-write: ${{ github.event_name == 'push' && github.ref_name == 'main' }}
# environments: test-cpu
# frozen: true
# - name: Run explanation notebooks
# shell: bash {0}
# run: pixi run -e test-cpu explanation-notebooks
4 changes: 2 additions & 2 deletions src/lcm/argmax.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def _move_axes_to_back(a: Array, axes: tuple[int, ...]) -> Array:
axes (tuple): Axes to move to the back.

Returns:
jax.numpy.ndarray: Array a with shifted axes.
jax.Array: Array a with shifted axes.

"""
front_axes = sorted(set(range(a.ndim)) - set(axes))
Expand All @@ -88,7 +88,7 @@ def _flatten_last_n_axes(a: Array, n: int) -> Array:
n (int): Number of axes to flatten.

Returns:
jax.numpy.ndarray: Array a with flattened last n axes.
jax.Array: Array a with flattened last n axes.

"""
return a.reshape(*a.shape[:-n], -1)
40 changes: 16 additions & 24 deletions src/lcm/discrete_problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def get_solve_discrete_problem(
if is_last_period:
variable_info = variable_info.query("~is_auxiliary")

choice_axes = _determine_dense_discrete_choice_axes(variable_info)
choice_axes = _determine_discrete_choice_axes(variable_info)

if random_utility_shock_type == ShockType.NONE:
func = _solve_discrete_problem_no_shocks
Expand All @@ -71,7 +71,7 @@ def get_solve_discrete_problem(

def _solve_discrete_problem_no_shocks(
cc_values: Array,
choice_axes: tuple[int, ...] | None,
choice_axes: tuple[int, ...],
params: ParamsDict, # noqa: ARG001
) -> Array:
"""Reduce conditional continuation values over discrete choices.
Expand All @@ -90,11 +90,7 @@ def _solve_discrete_problem_no_shocks(
if choice_segments is not None.

"""
out = cc_values
if choice_axes is not None:
out = out.max(axis=choice_axes)

return out
return cc_values.max(axis=choice_axes)


# ======================================================================================
Expand All @@ -108,18 +104,18 @@ def _calculate_emax_extreme_value_shocks(values, choice_axes, params):
"""Aggregate conditional continuation values over discrete choices.

Args:
values (jax.numpy.ndarray): Multidimensional jax array with conditional
values (jax.Array): Multidimensional jax array with conditional
continuation values.
choice_axes (int or tuple): Int or tuple of int, specifying which axes in
values correspond to dense choice variables.
values correspond to the discrete choice variables.
choice_segments (dict): Dictionary with the entries "segment_ids"
and "num_segments". segment_ids are a 1d integer array that partitions the
first dimension of values into choice sets over which we need to aggregate.
"num_segments" is the number of choice sets.
params (dict): Params dict that contains the shock_scale if necessary.

Returns:
jax.numpy.ndarray: Multidimensional jax array with aggregated continuation
jax.Array: Multidimensional jax array with aggregated continuation
values. Has fewer dimensions than values if choice_axes is not None and
is shorter in the first dimension if choice_segments is not None.

Expand All @@ -137,26 +133,22 @@ def _calculate_emax_extreme_value_shocks(values, choice_axes, params):
# ======================================================================================


def _determine_dense_discrete_choice_axes(
def _determine_discrete_choice_axes(
variable_info: pd.DataFrame,
) -> tuple[int, ...] | None:
) -> tuple[int, ...]:
"""Get axes of a state-choice-space that correspond to discrete choices.

Args:
variable_info: DataFrame with information about the variables.

Returns:
tuple[int, ...] | None: A tuple of indices representing the axes' positions in
the value function that correspond to discrete choices. Returns None if
there are no discrete choice axes.
A tuple of indices representing the axes' positions in the value function that
correspond to discrete choices.

"""
# List of dense variables excluding continuous choice variables.
axes = variable_info.query("is_state | is_discrete").index.tolist()

choice_vars = set(variable_info.query("is_choice").index.tolist())

choice_indices = tuple(i for i, ax in enumerate(axes) if ax in choice_vars)

# Return None if there are no discrete choice axes, otherwise return the indices.
return choice_indices if choice_indices else None
discrete_choice_vars = set(
variable_info.query("is_choice & is_discrete").index.tolist()
)
return tuple(
i for i, ax in enumerate(variable_info.index) if ax in discrete_choice_vars
)
97 changes: 52 additions & 45 deletions src/lcm/dispatchers.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,71 +11,78 @@

def spacemap(
func: F,
dense_vars: list[str],
sparse_vars: list[str] | None = None,
product_vars: list[str],
combination_vars: list[str] | None = None,
) -> F:
"""Apply vmap such that func is evaluated on a space of dense and sparse variables.

This is achieved by applying _base_productmap for all dense variables and vmap_1d
for the sparse variables.
"""Apply vmap such that func can be evaluated on product and combination variables.

Product variables are used to create a Cartesian product of possible values. I.e.,
for each product variable, we create a new leading dimension in the output object,
with the size of the dimension being the number of possible values in the grid. The
i-th entries of the combination variables correspond to one valid combination. For
the combination variables, a single dimension is thus added to the output object,
with the size of the dimension being the number of possible combinations. This means
that all combination variables must have the same size (e.g., in the simulation the
states act as combination variables, and their size equals the number of
simulations).

spacemap preserves the function signature and allows the function to be called with
keyword arguments.

Args:
func: The function to be dispatched.
dense_vars: Names of the dense variables, i.e. those that are stored as arrays
of possible values in the grid.
sparse_vars: Names of the sparse variables, i.e. those that are stored as arrays
of possible combinations of variables in the grid.
put_dense_first: Whether the dense or sparse dimensions should come first in the
output of the dispatched function.

product_vars: Names of the product variables, i.e. those that are stored as
arrays of possible values in the grid, over which we create a Cartesian
product.
combination_vars: Names of the combination variables, i.e. those that are
stored as arrays of possible combinations.

Returns:
A callable with the same arguments as func (but with an additional leading
dimension) that returns a jax.numpy.ndarray or pytree of arrays. If ``func``
returns a scalar, the dispatched function returns a jax.numpy.ndarray with 1
jax.numpy.ndarray with k + 1 dimensions, where k is the length of ``dense_vars``
and the additional dimension corresponds to the ``sparse_vars``. The order of
the dimensions is determined by the order of ``dense_vars`` as well as the
``put_dense_first`` argument. If the output of ``func`` is a jax pytree, the
usual jax behavior applies, i.e. the leading dimensions of all arrays in the
pytree are as described above but there might be additional dimensions.
dimension) that returns a jax.Array or pytree of arrays. If `func` returns a
scalar, the dispatched function returns a jax.Array with k + 1 dimensions, where
k is the length of `product_vars` and the additional dimension corresponds to
the `combination_vars`. The order of the dimensions is determined by the order
of `product_vars`. If the output of `func` is a jax pytree, the usual jax
behavior applies, i.e. the leading dimensions of all arrays in the pytree are as
described above but there might be additional dimensions.

"""
# Check inputs and prepare function
# ==================================================================================
duplicates = {v for v in dense_vars if dense_vars.count(v) > 1}
duplicates = {v for v in product_vars if product_vars.count(v) > 1}
if duplicates:
raise ValueError(
f"Same argument provided more than once in dense variables: {duplicates}",
f"Same argument provided more than once in product variables: {duplicates}",
)

if sparse_vars:
overlap = set(dense_vars).intersection(sparse_vars)
if combination_vars:
overlap = set(product_vars).intersection(combination_vars)
if overlap:
raise ValueError(
f"Dense and sparse variables must be disjoint. Overlap: {overlap}",
"Product and combination variables must be disjoint. Overlap: "
f"{overlap}",
)

duplicates = {v for v in sparse_vars if sparse_vars.count(v) > 1}
duplicates = {v for v in combination_vars if combination_vars.count(v) > 1}
if duplicates:
raise ValueError(
"Same argument provided more than once in sparse variables: "
"Same argument provided more than once in combination variables: "
f"{duplicates}",
)

# jax.vmap cannot deal with keyword-only arguments
func = allow_args(func)

# Apply vmap_1d for sparse and _base_productmap for dense variables
# Apply vmap_1d for combination variables and _base_productmap for product variables
# ==================================================================================
if not sparse_vars:
vmapped = _base_productmap(func, dense_vars)
if not combination_vars:
vmapped = _base_productmap(func, product_vars)
else:
vmapped = _base_productmap(func, dense_vars)
vmapped = vmap_1d(vmapped, variables=sparse_vars, callable_with="only_args")
vmapped = _base_productmap(func, product_vars)
vmapped = vmap_1d(
vmapped, variables=combination_vars, callable_with="only_args"
)

# This raises a mypy error but is perfectly fine to do. See
# https://github.com/python/mypy/issues/12472
Expand Down Expand Up @@ -106,12 +113,12 @@ def vmap_1d(

Returns:
A callable with the same arguments as func (but with an additional leading
dimension) that returns a jax.numpy.ndarray or pytree of arrays. If ``func``
returns a scalar, the dispatched function returns a jax.numpy.ndarray with 1
jax.numpy.ndarray with 1 dimension and length k, where k is the length of one of
the mapped inputs in ``variables``. The order of the dimensions is determined by
the order of ``variables`` which can be different to the order of ``funcs``
arguments. If the output of ``func`` is a jax pytree, the usual jax behavior
dimension) that returns a jax.Array or pytree of arrays. If `func`
returns a scalar, the dispatched function returns a
jax.Array with 1 dimension and length k, where k is the length of one of
the mapped inputs in `variables`. The order of the dimensions is determined by
the order of `variables` which can be different to the order of `funcs`
arguments. If the output of `func` is a jax pytree, the usual jax behavior
applies, i.e. the leading dimensions of all arrays in the pytree are as
described above but there might be additional dimensions.

Expand Down Expand Up @@ -169,11 +176,11 @@ def productmap(func: F, variables: list[str]) -> F:

Returns:
A callable with the same arguments as func (but with an additional leading
dimension) that returns a jax.numpy.ndarray or pytree of arrays. If ``func``
returns a scalar, the dispatched function returns a jax.numpy.ndarray with k
dimensions, where k is the length of ``variables``. The order of the dimensions
is determined by the order of ``variables`` which can be different to the order
of ``funcs`` arguments. If the output of ``func`` is a jax pytree, the usual jax
dimension) that returns a jax.Array or pytree of arrays. If `func`
returns a scalar, the dispatched function returns a jax.Array with k
dimensions, where k is the length of `variables`. The order of the dimensions
is determined by the order of `variables` which can be different to the order
of `funcs` arguments. If the output of `func` is a jax pytree, the usual jax
behavior applies, i.e. the leading dimensions of all arrays in the pytree are as
described above but there might be additional dimensions.

Expand Down Expand Up @@ -207,7 +214,7 @@ def _base_productmap(func: F, product_axes: list[str]) -> F:
product_axes: List with names of arguments over which we apply vmap.

Returns:
A callable with the same arguments as func. See ``product_map`` for details.
A callable with the same arguments as func. See `productmap` for details.

"""
signature = inspect.signature(func)
Expand Down
33 changes: 9 additions & 24 deletions src/lcm/entry_point.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@
get_utility_and_feasibility_function,
)
from lcm.next_state import get_next_state_function
from lcm.simulate import simulate
from lcm.solve_brute import solve
from lcm.state_space import create_state_choice_space
from lcm.simulation.simulate import simulate
from lcm.solution.solve_brute import solve
from lcm.solution.state_choice_space import create_state_choice_space
from lcm.typing import ParamsDict
from lcm.user_model import Model

Expand All @@ -38,11 +38,6 @@ def get_lcm_function(
source code of this function to see how the lower level components are meant to be
used.

Notes:
-----
- There is a hack to make the state_indexers empty in the last period which needs
to be replaced by a better solution, when we want to allow for bequest motives.

Args:
model: User model specification.
targets: The requested function types. Currently only "solve", "simulate" and
Expand Down Expand Up @@ -79,8 +74,7 @@ def get_lcm_function(
# Initialize other argument lists
# ==================================================================================
state_choice_spaces = []
state_indexers = [] # type:ignore[var-annotated]
space_infos = []
state_space_infos = []
compute_ccv_functions = []
compute_ccv_policy_functions = []
choice_segments = [] # type: ignore[var-annotated]
Expand All @@ -94,25 +88,19 @@ def get_lcm_function(

# call state space creation function, append trivial items to their lists
# ==============================================================================
sc_space, space_info = create_state_choice_space(
state_choice_space, state_space_info = create_state_choice_space(
model=_mod,
is_last_period=is_last_period,
)

state_choice_spaces.append(sc_space)
state_choice_spaces.append(state_choice_space)
choice_segments.append(None)

if is_last_period:
state_indexers.append({})
else:
state_indexers.append({})

space_infos.append(space_info)
state_space_infos.append(state_space_info)

# ==================================================================================
# Shift space info (in period t we require the space info of period t+1)
# ==================================================================================
space_infos = space_infos[1:] + [{}] # type: ignore[list-item]
state_space_infos = state_space_infos[1:] + [{}] # type: ignore[list-item]

# ==================================================================================
# Create model functions
Expand All @@ -124,8 +112,7 @@ def get_lcm_function(
# ==============================================================================
u_and_f = get_utility_and_feasibility_function(
model=_mod,
space_info=space_infos[period],
name_of_values_on_grid="vf_arr",
state_space_info=state_space_infos[period],
period=period,
is_last_period=is_last_period,
)
Expand Down Expand Up @@ -157,7 +144,6 @@ def get_lcm_function(
_solve_model = partial(
solve,
state_choice_spaces=state_choice_spaces,
state_indexers=state_indexers,
continuous_choice_grids=continuous_choice_grids,
compute_ccv_functions=compute_ccv_functions,
emax_calculators=emax_calculators,
Expand All @@ -169,7 +155,6 @@ def get_lcm_function(
_next_state_simulate = get_next_state_function(model=_mod, target="simulate")
simulate_model = partial(
simulate,
state_indexers=state_indexers,
continuous_choice_grids=continuous_choice_grids,
compute_ccv_policy_functions=compute_ccv_policy_functions,
model=_mod,
Expand Down
Loading
Loading