Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor solution and simulation #108

Open
wants to merge 20 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 80 additions & 0 deletions src/lcm/conditional_continuation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
import functools
from collections.abc import Callable

import jax.numpy as jnp
from jax import Array

from lcm.argmax import argmax
from lcm.dispatchers import productmap
from lcm.typing import ParamsDict


def get_compute_conditional_continuation_value(
utility_and_feasibility: Callable[..., tuple[Array, Array]],
continuous_choice_variables: tuple[str, ...],
) -> Callable[..., Array]:
"""Get a function that computes the conditional continuation value.

This function solves the continuous choice problem conditional on a state-
(discrete-)choice combination; and is used in the model solution process.

Args:
utility_and_feasibility: A function that takes a state-choice combination and
return the utility of that combination (scalar) and whether the state-choice
combination is feasible (bool).
continuous_choice_variables: Tuple of choice variable names that are continuous.

Returns:
A function that takes a state-choice combination and returns the conditional
continuation value over the continuous choices.

"""
if continuous_choice_variables:
utility_and_feasibility = productmap(
func=utility_and_feasibility,
variables=continuous_choice_variables,
)

@functools.wraps(utility_and_feasibility)
def compute_ccv(params: ParamsDict, **kwargs: Array) -> Array:
u, f = utility_and_feasibility(params=params, **kwargs)
return u.max(where=f, initial=-jnp.inf)

return compute_ccv


def get_compute_conditional_continuation_policy(
utility_and_feasibility: Callable[..., tuple[Array, Array]],
continuous_choice_variables: tuple[str, ...],
) -> Callable[..., tuple[Array, Array]]:
"""Get a function that computes the conditional continuation policy.

This function solves the continuous choice problem conditional on a state-
(discrete-)choice combination; and is used in the model simulation process.

Args:
utility_and_feasibility: A function that takes a state-choice combination and
return the utility of that combination (scalar) and whether the state-choice
combination is feasible (bool).
continuous_choice_variables: Tuple of choice variable names that are
continuous.

Returns:
A function that takes a state-choice combination and returns the conditional
continuation value over the continuous choices, and the index that maximizes the
conditional continuation value.

"""
if continuous_choice_variables:
utility_and_feasibility = productmap(
func=utility_and_feasibility,
variables=continuous_choice_variables,
)

@functools.wraps(utility_and_feasibility)
def compute_ccp(params: ParamsDict, **kwargs: Array) -> tuple[Array, Array]:
u, f = utility_and_feasibility(params=params, **kwargs)
_argmax, _max = argmax(u, where=f, initial=-jnp.inf)
return _argmax, _max

return compute_ccp
70 changes: 64 additions & 6 deletions src/lcm/discrete_problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,21 @@
import pandas as pd
from jax import Array

from lcm.typing import DiscreteProblemSolverFunction, ParamsDict, ShockType
from lcm.argmax import argmax
from lcm.typing import (
DiscreteProblemPolicySolverFunction,
DiscreteProblemValueSolverFunction,
ParamsDict,
ShockType,
)


def get_solve_discrete_problem(
def get_solve_discrete_problem_value(
*,
random_utility_shock_type: ShockType,
variable_info: pd.DataFrame,
is_last_period: bool,
) -> DiscreteProblemSolverFunction:
) -> DiscreteProblemValueSolverFunction:
"""Get function that computes the expected max. of conditional continuation values.

The maximum is taken over the discrete choice variables in each state.
Expand All @@ -51,7 +57,7 @@ def get_solve_discrete_problem(
if is_last_period:
variable_info = variable_info.query("~is_auxiliary")

choice_axes = _determine_discrete_choice_axes(variable_info)
choice_axes = _determine_discrete_choice_axes_solution(variable_info)

if random_utility_shock_type == ShockType.NONE:
func = _solve_discrete_problem_no_shocks
Expand All @@ -63,6 +69,37 @@ def get_solve_discrete_problem(
return partial(func, choice_axes=choice_axes)


def get_solve_discrete_problem_policy(
*,
variable_info: pd.DataFrame,
) -> DiscreteProblemPolicySolverFunction:
"""Return a function that calculates the argmax and max of continuation values.

The argmax is taken over the discrete choice variables in each state.

Args:
variable_info (pd.DataFrame): DataFrame with information about the model
variables.

Returns:
callable: Function that calculates the argmax of the conditional continuation
values. The function depends on:
- values (jax.Array): Multidimensional jax array with conditional
continuation values.

"""
choice_axes = _determine_discrete_choice_axes_simulation(variable_info)

def _calculate_discrete_argmax(
values: Array,
choice_axes: tuple[int, ...],
params: ParamsDict, # noqa: ARG001
) -> tuple[Array, Array]:
return argmax(values, axis=choice_axes)

return partial(_calculate_discrete_argmax, choice_axes=choice_axes)


# ======================================================================================
# Discrete problem with no shocks
# ======================================================================================
Expand Down Expand Up @@ -129,10 +166,10 @@ def _calculate_emax_extreme_value_shocks(
# ======================================================================================


def _determine_discrete_choice_axes(
def _determine_discrete_choice_axes_solution(
variable_info: pd.DataFrame,
) -> tuple[int, ...]:
"""Get axes of a state-choice-space that correspond to discrete choices.
"""Get axes of state-choice-space that correspond to discrete choices in solution.

Args:
variable_info: DataFrame with information about the variables.
Expand All @@ -148,3 +185,24 @@ def _determine_discrete_choice_axes(
return tuple(
i for i, ax in enumerate(variable_info.index) if ax in discrete_choice_vars
)


def _determine_discrete_choice_axes_simulation(
variable_info: pd.DataFrame,
) -> tuple[int, ...]:
"""Get axes of state-choice-space that correspond to discrete choices in simulation.

Args:
variable_info: DataFrame with information about the variables.

Returns:
A tuple of indices representing the axes' positions in the value function that
correspond to discrete choices.

"""
discrete_choice_vars = set(
variable_info.query("is_choice & is_discrete").index.tolist()
)

# The first dimension corresponds to the simulated states, so add 1.
return tuple(1 + i for i in range(len(discrete_choice_vars)))
65 changes: 30 additions & 35 deletions src/lcm/dispatchers.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,64 +12,59 @@
)


def spacemap(
def simulation_spacemap(
func: FunctionWithArrayReturn,
product_vars: tuple[str, ...],
combination_vars: tuple[str, ...],
choices_var_names: tuple[str, ...],
states_var_names: tuple[str, ...],
) -> FunctionWithArrayReturn:
"""Apply vmap such that func can be evaluated on product and combination variables.
"""Apply vmap such that func can be evaluated on choices and simulation states.

Product variables are used to create a Cartesian product of possible values. I.e.,
for each product variable, we create a new leading dimension in the output object,
with the size of the dimension being the number of possible values in the grid. The
i-th entries of the combination variables, correspond to one valid combination. For
the combination variables, a single dimension is thus added to the output object,
with the size of the dimension being the number of possible combinations. This means
that all combination variables must have the same size (e.g., in the simulation the
states act as combination variables, and their size equals the number of
simulations).
This function maps the function `func` over the simulation state-choice-space. That
is, it maps `func` over the Cartesian product of the choice variables, and over the
fixed simulation states. For each choice variable, a leading dimension is added to
the output object, with the length of the axis being the number of possible values
in the grid. Importantly, it does not create a Cartesian product over the state
variables, since these are fixed during the simulation. For the state variables,
a single dimension is added to the output object, with the length of the axis
being the number of simulated states.

spacemap preserves the function signature and allows the function to be called with
keyword arguments.
simulation_spacemap preserves the function signature and allows the function to be
called with keyword arguments.

Args:
func: The function to be dispatched.
product_vars: Names of the product variables, i.e. those that are stored as
choices_var_names: Names of the choice variables, i.e. those that are stored as
arrays of possible values in the grid, over which we create a Cartesian
product.
combination_vars: Names of the combination variables, i.e. those that are
stored as arrays of possible combinations.
states_var_names: Names of the state variables, i.e. those that are stored as
arrays of possible states.

Returns:
A callable with the same arguments as func (but with an additional leading
dimension) that returns a jax.Array or pytree of arrays. If `func` returns a
scalar, the dispatched function returns a jax.Array with k + 1 dimensions, where
k is the length of `product_vars` and the additional dimension corresponds to
the `combination_vars`. The order of the dimensions is determined by the order
of `product_vars`. If the output of `func` is a jax pytree, the usual jax
dimension) that returns an Array or pytree of Arrays. If `func` returns a
scalar, the dispatched function returns an Array with k + 1 dimensions, where k
is the length of `choices_var_names` and the additional dimension corresponds to
the `states_var_names`. The order of the dimensions is determined by the order
of `choices_var_names`. If the output of `func` is a jax pytree, the usual jax
behavior applies, i.e. the leading dimensions of all arrays in the pytree are as
described above but there might be additional dimensions.

"""
if duplicates := find_duplicates(product_vars, combination_vars):
if duplicates := find_duplicates(choices_var_names, states_var_names):
msg = (
"Same argument provided more than once in product variables or combination "
f"variables, or is present in both: {duplicates}"
"Same argument provided more than once in choices or states variables, "
f"or is present in both: {duplicates}"
)
raise ValueError(msg)

func_callable_with_args = allow_args(func)

vmapped = _base_productmap(func_callable_with_args, product_vars)
mappable_func = allow_args(func)

if combination_vars:
vmapped = vmap_1d(
vmapped, variables=combination_vars, callable_with="only_args"
)
vmapped = _base_productmap(mappable_func, choices_var_names)
vmapped = vmap_1d(vmapped, variables=states_var_names, callable_with="only_args")

# This raises a mypy error but is perfectly fine to do. See
# https://github.com/python/mypy/issues/12472
vmapped.__signature__ = inspect.signature(func_callable_with_args) # type: ignore[attr-defined]
vmapped.__signature__ = inspect.signature(mappable_func) # type: ignore[attr-defined]

return cast(FunctionWithArrayReturn, allow_only_kwargs(vmapped))

Expand Down Expand Up @@ -210,7 +205,7 @@ def _base_productmap(
# We iterate in reverse order such that the output dimensions are in the same order
# as the input dimensions.
for pos in reversed(positions):
spec = [None] * len(parameters) # type: list[int | None]
spec: list[int | None] = [None] * len(parameters)
spec[pos] = 0
vmap_specs.append(spec)

Expand Down
Loading