diff --git a/src/lcm/conditional_continuation.py b/src/lcm/conditional_continuation.py new file mode 100644 index 00000000..1f769e67 --- /dev/null +++ b/src/lcm/conditional_continuation.py @@ -0,0 +1,80 @@ +import functools +from collections.abc import Callable + +import jax.numpy as jnp +from jax import Array + +from lcm.argmax import argmax +from lcm.dispatchers import productmap +from lcm.typing import ParamsDict + + +def get_compute_conditional_continuation_value( + utility_and_feasibility: Callable[..., tuple[Array, Array]], + continuous_choice_variables: tuple[str, ...], +) -> Callable[..., Array]: + """Get a function that computes the conditional continuation value. + + This function solves the continuous choice problem conditional on a state- + (discrete-)choice combination; and is used in the model solution process. + + Args: + utility_and_feasibility: A function that takes a state-choice combination and + return the utility of that combination (scalar) and whether the state-choice + combination is feasible (bool). + continuous_choice_variables: Tuple of choice variable names that are continuous. + + Returns: + A function that takes a state-choice combination and returns the conditional + continuation value over the continuous choices. + + """ + if continuous_choice_variables: + utility_and_feasibility = productmap( + func=utility_and_feasibility, + variables=continuous_choice_variables, + ) + + @functools.wraps(utility_and_feasibility) + def compute_ccv(params: ParamsDict, **kwargs: Array) -> Array: + u, f = utility_and_feasibility(params=params, **kwargs) + return u.max(where=f, initial=-jnp.inf) + + return compute_ccv + + +def get_compute_conditional_continuation_policy( + utility_and_feasibility: Callable[..., tuple[Array, Array]], + continuous_choice_variables: tuple[str, ...], +) -> Callable[..., tuple[Array, Array]]: + """Get a function that computes the conditional continuation policy. 
+ + This function solves the continuous choice problem conditional on a state- + (discrete-)choice combination; and is used in the model simulation process. + + Args: + utility_and_feasibility: A function that takes a state-choice combination and + return the utility of that combination (scalar) and whether the state-choice + combination is feasible (bool). + continuous_choice_variables: Tuple of choice variable names that are + continuous. + + Returns: + A function that takes a state-choice combination and returns the conditional + continuation value over the continuous choices, and the index that maximizes the + conditional continuation value. + + """ + if continuous_choice_variables: + utility_and_feasibility = productmap( + func=utility_and_feasibility, + variables=continuous_choice_variables, + ) + + @functools.wraps(utility_and_feasibility) + def compute_ccp(params: ParamsDict, **kwargs: Array) -> tuple[Array, Array]: + u, f = utility_and_feasibility(params=params, **kwargs) + _argmax, _max = argmax(u, where=f, initial=-jnp.inf) + return _argmax, _max + + return compute_ccp diff --git a/src/lcm/discrete_problem.py b/src/lcm/discrete_problem.py index dd405c81..fb245022 100644 --- a/src/lcm/discrete_problem.py +++ b/src/lcm/discrete_problem.py @@ -23,15 +23,21 @@ import pandas as pd from jax import Array -from lcm.typing import DiscreteProblemSolverFunction, ParamsDict, ShockType +from lcm.argmax import argmax +from lcm.typing import ( + DiscreteProblemPolicySolverFunction, + DiscreteProblemValueSolverFunction, + ParamsDict, + ShockType, +) -def get_solve_discrete_problem( +def get_solve_discrete_problem_value( *, random_utility_shock_type: ShockType, variable_info: pd.DataFrame, is_last_period: bool, -) -> DiscreteProblemSolverFunction: +) -> DiscreteProblemValueSolverFunction: """Get function that computes the expected max. of conditional continuation values. The maximum is taken over the discrete choice variables in each state. 
@@ -51,7 +57,7 @@ def get_solve_discrete_problem( if is_last_period: variable_info = variable_info.query("~is_auxiliary") - choice_axes = _determine_discrete_choice_axes(variable_info) + choice_axes = _determine_discrete_choice_axes_solution(variable_info) if random_utility_shock_type == ShockType.NONE: func = _solve_discrete_problem_no_shocks @@ -63,6 +69,37 @@ def get_solve_discrete_problem( return partial(func, choice_axes=choice_axes) +def get_solve_discrete_problem_policy( + *, + variable_info: pd.DataFrame, +) -> DiscreteProblemPolicySolverFunction: + """Return a function that calculates the argmax and max of continuation values. + + The argmax is taken over the discrete choice variables in each state. + + Args: + variable_info (pd.DataFrame): DataFrame with information about the model + variables. + + Returns: + callable: Function that calculates the argmax of the conditional continuation + values. The function depends on: + - values (jax.Array): Multidimensional jax array with conditional + continuation values. + + """ + choice_axes = _determine_discrete_choice_axes_simulation(variable_info) + + def _calculate_discrete_argmax( + values: Array, + choice_axes: tuple[int, ...], + params: ParamsDict, # noqa: ARG001 + ) -> tuple[Array, Array]: + return argmax(values, axis=choice_axes) + + return partial(_calculate_discrete_argmax, choice_axes=choice_axes) + + # ====================================================================================== # Discrete problem with no shocks # ====================================================================================== @@ -129,10 +166,10 @@ def _calculate_emax_extreme_value_shocks( # ====================================================================================== -def _determine_discrete_choice_axes( +def _determine_discrete_choice_axes_solution( variable_info: pd.DataFrame, ) -> tuple[int, ...]: - """Get axes of a state-choice-space that correspond to discrete choices. 
+ """Get axes of state-choice-space that correspond to discrete choices in solution. Args: variable_info: DataFrame with information about the variables. @@ -148,3 +185,24 @@ def _determine_discrete_choice_axes( return tuple( i for i, ax in enumerate(variable_info.index) if ax in discrete_choice_vars ) + + +def _determine_discrete_choice_axes_simulation( + variable_info: pd.DataFrame, +) -> tuple[int, ...]: + """Get axes of state-choice-space that correspond to discrete choices in simulation. + + Args: + variable_info: DataFrame with information about the variables. + + Returns: + A tuple of indices representing the axes' positions in the value function that + correspond to discrete choices. + + """ + discrete_choice_vars = set( + variable_info.query("is_choice & is_discrete").index.tolist() + ) + + # The first dimension corresponds to the simulated states, so add 1. + return tuple(1 + i for i in range(len(discrete_choice_vars))) diff --git a/src/lcm/dispatchers.py b/src/lcm/dispatchers.py index 35ba50b4..85b64c61 100644 --- a/src/lcm/dispatchers.py +++ b/src/lcm/dispatchers.py @@ -12,64 +12,59 @@ ) -def spacemap( +def simulation_spacemap( func: FunctionWithArrayReturn, - product_vars: tuple[str, ...], - combination_vars: tuple[str, ...], + choices_var_names: tuple[str, ...], + states_var_names: tuple[str, ...], ) -> FunctionWithArrayReturn: - """Apply vmap such that func can be evaluated on product and combination variables. + """Apply vmap such that func can be evaluated on choices and simulation states. - Product variables are used to create a Cartesian product of possible values. I.e., - for each product variable, we create a new leading dimension in the output object, - with the size of the dimension being the number of possible values in the grid. The - i-th entries of the combination variables, correspond to one valid combination. 
For - the combination variables, a single dimension is thus added to the output object, - with the size of the dimension being the number of possible combinations. This means - that all combination variables must have the same size (e.g., in the simulation the - states act as combination variables, and their size equals the number of - simulations). + This function maps the function `func` over the simulation state-choice-space. That + is, it maps `func` over the Cartesian product of the choice variables, and over the + fixed simulation states. For each choice variable, a leading dimension is added to + the output object, with the length of the axis being the number of possible values + in the grid. Importantly, it does not create a Cartesian product over the state + variables, since these are fixed during the simulation. For the state variables, + a single dimension is added to the output object, with the length of the axis + being the number of simulated states. - spacemap preserves the function signature and allows the function to be called with - keyword arguments. + simulation_spacemap preserves the function signature and allows the function to be + called with keyword arguments. Args: func: The function to be dispatched. - product_vars: Names of the product variables, i.e. those that are stored as + choices_var_names: Names of the choice variables, i.e. those that are stored as arrays of possible values in the grid, over which we create a Cartesian product. - combination_vars: Names of the combination variables, i.e. those that are - stored as arrays of possible combinations. + states_var_names: Names of the state variables, i.e. those that are stored as + arrays of possible states. Returns: A callable with the same arguments as func (but with an additional leading - dimension) that returns a jax.Array or pytree of arrays. 
If `func` returns a - scalar, the dispatched function returns a jax.Array with k + 1 dimensions, where - k is the length of `product_vars` and the additional dimension corresponds to - the `combination_vars`. The order of the dimensions is determined by the order - of `product_vars`. If the output of `func` is a jax pytree, the usual jax + dimension) that returns an Array or pytree of Arrays. If `func` returns a + scalar, the dispatched function returns an Array with k + 1 dimensions, where k + is the length of `choices_var_names` and the additional dimension corresponds to + the `states_var_names`. The order of the dimensions is determined by the order + of `choices_var_names`. If the output of `func` is a jax pytree, the usual jax behavior applies, i.e. the leading dimensions of all arrays in the pytree are as described above but there might be additional dimensions. """ - if duplicates := find_duplicates(product_vars, combination_vars): + if duplicates := find_duplicates(choices_var_names, states_var_names): msg = ( - "Same argument provided more than once in product variables or combination " - f"variables, or is present in both: {duplicates}" + "Same argument provided more than once in choices or states variables, " + f"or is present in both: {duplicates}" ) raise ValueError(msg) - func_callable_with_args = allow_args(func) - - vmapped = _base_productmap(func_callable_with_args, product_vars) + mappable_func = allow_args(func) - if combination_vars: - vmapped = vmap_1d( - vmapped, variables=combination_vars, callable_with="only_args" - ) + vmapped = _base_productmap(mappable_func, choices_var_names) + vmapped = vmap_1d(vmapped, variables=states_var_names, callable_with="only_args") # This raises a mypy error but is perfectly fine to do. 
See # https://github.com/python/mypy/issues/12472 - vmapped.__signature__ = inspect.signature(func_callable_with_args) # type: ignore[attr-defined] + vmapped.__signature__ = inspect.signature(mappable_func) # type: ignore[attr-defined] return cast(FunctionWithArrayReturn, allow_only_kwargs(vmapped)) @@ -210,7 +205,7 @@ def _base_productmap( # We iterate in reverse order such that the output dimensions are in the same order # as the input dimensions. for pos in reversed(positions): - spec = [None] * len(parameters) # type: list[int | None] + spec: list[int | None] = [None] * len(parameters) spec[pos] = 0 vmap_specs.append(spec) diff --git a/src/lcm/entry_point.py b/src/lcm/entry_point.py index 0d89ddca..32048b7a 100644 --- a/src/lcm/entry_point.py +++ b/src/lcm/entry_point.py @@ -1,27 +1,31 @@ -import functools from collections.abc import Callable from functools import partial from typing import Literal import jax -import jax.numpy as jnp import pandas as pd from jax import Array -from lcm.argmax import argmax -from lcm.discrete_problem import get_solve_discrete_problem -from lcm.dispatchers import productmap +from lcm.conditional_continuation import ( + get_compute_conditional_continuation_policy, + get_compute_conditional_continuation_value, +) +from lcm.discrete_problem import get_solve_discrete_problem_value from lcm.input_processing import process_model +from lcm.interfaces import StateChoiceSpace, StateSpaceInfo from lcm.logging import get_logger -from lcm.model_functions import ( - get_utility_and_feasibility_function, -) from lcm.next_state import get_next_state_function -from lcm.simulation.simulate import simulate +from lcm.simulation.simulate import simulate, solve_and_simulate from lcm.solution.solve_brute import solve -from lcm.solution.state_choice_space import create_state_choice_space -from lcm.typing import ParamsDict, Target +from lcm.state_choice_space import ( + create_state_choice_space, + create_state_space_info, +) +from lcm.typing import 
DiscreteProblemValueSolverFunction, ParamsDict, Target from lcm.user_model import Model +from lcm.utility_and_feasibility import ( + get_utility_and_feasibility_function, +) def get_lcm_function( @@ -30,7 +34,7 @@ def get_lcm_function( targets: Literal["solve", "simulate", "solve_and_simulate"], debug_mode: bool = True, jit: bool = True, -) -> tuple[Callable[..., list[Array] | pd.DataFrame], ParamsDict]: +) -> tuple[Callable[..., dict[int, Array] | pd.DataFrame], ParamsDict]: """Entry point for users to get high level functions generated by lcm. Return the function to solve and/or simulate a model along with a template for the @@ -59,86 +63,66 @@ def get_lcm_function( if targets not in {"solve", "simulate", "solve_and_simulate"}: raise NotImplementedError - _mod = process_model(model) - last_period = _mod.n_periods - 1 + internal_model = process_model(model) + last_period = internal_model.n_periods - 1 logger = get_logger(debug_mode=debug_mode) # ================================================================================== - # create list of continuous choice grids + # Create model functions and state-choice-spaces # ================================================================================== - # for now they are the same in all periods but this will change. 
- _subset = _mod.variable_info.query("is_continuous & is_choice").index.tolist() - _choice_grids = {k: _mod.grids[k] for k in _subset} - continuous_choice_grids = [_choice_grids] * _mod.n_periods + state_choice_spaces: dict[int, StateChoiceSpace] = {} + state_space_infos: dict[int, StateSpaceInfo] = {} + compute_ccv_functions: dict[int, Callable[[Array, Array], Array]] = {} + compute_ccp_functions: dict[int, Callable[..., tuple[Array, Array]]] = {} + solve_discrete_problem_functions: dict[int, DiscreteProblemValueSolverFunction] = {} - # ================================================================================== - # Initialize other argument lists - # ================================================================================== - state_choice_spaces = [] - state_space_infos = [] - compute_ccv_functions = [] - compute_ccv_policy_functions = [] - choice_segments = [] # type: ignore[var-annotated] - emax_calculators = [] - - # ================================================================================== - # Create stace choice space for each period - # ================================================================================== - for period in range(_mod.n_periods): + for period in reversed(range(internal_model.n_periods)): is_last_period = period == last_period - # call state space creation function, append trivial items to their lists - # ============================================================================== - state_choice_space, state_space_info = create_state_choice_space( - model=_mod, + state_choice_space = create_state_choice_space( + model=internal_model, is_last_period=is_last_period, ) - state_choice_spaces.append(state_choice_space) - choice_segments.append(None) - state_space_infos.append(state_space_info) - - # ================================================================================== - # Shift space info (in period t we require the space info of period t+1) - # 
================================================================================== - state_space_infos = state_space_infos[1:] + [{}] # type: ignore[list-item] + state_space_info = create_state_space_info( + model=internal_model, + is_last_period=is_last_period, + ) - # ================================================================================== - # Create model functions - # ================================================================================== - for period in range(_mod.n_periods): - is_last_period = period == last_period + if is_last_period: + next_state_space_info = LastPeriodsNextStateSpaceInfo + else: + next_state_space_info = state_space_infos[period + 1] - # create the compute conditional continuation value functions and append to list - # ============================================================================== u_and_f = get_utility_and_feasibility_function( - model=_mod, - state_space_info=state_space_infos[period], + model=internal_model, + next_state_space_info=next_state_space_info, period=period, is_last_period=is_last_period, ) - compute_ccv = create_compute_conditional_continuation_value( + compute_ccv = get_compute_conditional_continuation_value( utility_and_feasibility=u_and_f, - continuous_choice_variables=list(_choice_grids), + continuous_choice_variables=tuple(state_choice_space.continuous_choices), ) - compute_ccv_functions.append(compute_ccv) - compute_ccv_argmax = create_compute_conditional_continuation_policy( + compute_ccp = get_compute_conditional_continuation_policy( utility_and_feasibility=u_and_f, - continuous_choice_variables=list(_choice_grids), + continuous_choice_variables=tuple(state_choice_space.continuous_choices), ) - compute_ccv_policy_functions.append(compute_ccv_argmax) - # create list of emax_calculators - # ============================================================================== - calculator = get_solve_discrete_problem( - random_utility_shock_type=_mod.random_utility_shocks, - 
variable_info=_mod.variable_info, + solve_discrete_problem = get_solve_discrete_problem_value( + random_utility_shock_type=internal_model.random_utility_shocks, + variable_info=internal_model.variable_info, is_last_period=is_last_period, ) - emax_calculators.append(calculator) + + state_choice_spaces[period] = state_choice_space + state_space_infos[period] = state_space_info + compute_ccv_functions[period] = compute_ccv + compute_ccp_functions[period] = compute_ccp + solve_discrete_problem_functions[period] = solve_discrete_problem # ================================================================================== # select requested solver and partial arguments into it @@ -146,109 +130,49 @@ def get_lcm_function( _solve_model = partial( solve, state_choice_spaces=state_choice_spaces, - continuous_choice_grids=continuous_choice_grids, compute_ccv_functions=compute_ccv_functions, - emax_calculators=emax_calculators, + emax_calculators=solve_discrete_problem_functions, logger=logger, ) + _next_state_simulate = get_next_state_function( + model=internal_model, target=Target.SIMULATE + ) + solve_model = jax.jit(_solve_model) if jit else _solve_model + next_state_simulate = jax.jit(_next_state_simulate) if jit else _next_state_simulate - _next_state_simulate = get_next_state_function(model=_mod, target=Target.SIMULATE) simulate_model = partial( simulate, - continuous_choice_grids=continuous_choice_grids, - compute_ccv_policy_functions=compute_ccv_policy_functions, - model=_mod, - next_state=jax.jit(_next_state_simulate), + compute_ccv_policy_functions=compute_ccp_functions, + model=internal_model, + next_state=next_state_simulate, # type: ignore[arg-type] + logger=logger, + ) + + solve_and_simulate_model = partial( + solve_and_simulate, + compute_ccv_policy_functions=compute_ccp_functions, + model=internal_model, + next_state=next_state_simulate, # type: ignore[arg-type] logger=logger, + solve_model=solve_model, ) - target_func: Callable[..., list[Array] | pd.DataFrame] 
+ target_func: Callable[..., dict[int, Array] | pd.DataFrame] if targets == "solve": target_func = solve_model elif targets == "simulate": target_func = simulate_model elif targets == "solve_and_simulate": - target_func = partial(simulate_model, solve_model=solve_model) - - return target_func, _mod.params - - -def create_compute_conditional_continuation_value( - utility_and_feasibility: Callable[..., tuple[Array, Array]], - continuous_choice_variables: list[str], -) -> Callable[..., Array]: - """Create a function that computes the conditional continuation value. - - Note: - ----- - This function solves the continuous choice problem conditional on a state- - (discrete-)choice combination. - - Args: - utility_and_feasibility (callable): A function that takes a state-choice - combination and return the utility of that combination (float) and whether - the state-choice combination is feasible (bool). - continuous_choice_variables (list): List of choice variable names that are - continuous. - - Returns: - A function that takes a state-choice combination and returns the conditional - continuation value over the continuous choices. - - """ - if continuous_choice_variables: - utility_and_feasibility = productmap( - func=utility_and_feasibility, - variables=tuple(continuous_choice_variables), - ) + target_func = solve_and_simulate_model - @functools.wraps(utility_and_feasibility) - def compute_ccv(*args: Array | ParamsDict, **kwargs: Array | ParamsDict) -> Array: - u, f = utility_and_feasibility(*args, **kwargs) - return u.max(where=f, initial=-jnp.inf) + return target_func, internal_model.params - return compute_ccv - -def create_compute_conditional_continuation_policy( - utility_and_feasibility: Callable[..., tuple[Array, Array]], - continuous_choice_variables: list[str], -) -> Callable[..., tuple[Array, Array]]: - """Create a function that computes the conditional continuation policy. 
- - Note: - ----- - This function solves the continuous choice problem conditional on a state- - (discrete-)choice combination. - - Args: - utility_and_feasibility (callable): A function that takes a state-choice - combination and return the utility of that combination (float) and whether - the state-choice combination is feasible (bool). - continuous_choice_variables (list): List of choice variable names that are - continuous. - - Returns: - A function that takes a state-choice combination and returns the conditional - continuation value over the continuous choices, and the index that maximizes the - conditional continuation value. - - """ - if continuous_choice_variables: - utility_and_feasibility = productmap( - func=utility_and_feasibility, - variables=tuple(continuous_choice_variables), - ) - - @functools.wraps(utility_and_feasibility) - def compute_ccv_policy( - *args: Array | ParamsDict, **kwargs: Array | ParamsDict - ) -> tuple[Array, Array]: - u, f = utility_and_feasibility(*args, **kwargs) - _argmax, _max = argmax(u, where=f, initial=-jnp.inf) - return _argmax, _max - - return compute_ccv_policy +LastPeriodsNextStateSpaceInfo = StateSpaceInfo( + states_names=(), + discrete_states={}, + continuous_states={}, +) diff --git a/src/lcm/input_processing/util.py b/src/lcm/input_processing/util.py index 72734d1b..f1fdfc7b 100644 --- a/src/lcm/input_processing/util.py +++ b/src/lcm/input_processing/util.py @@ -25,7 +25,8 @@ def get_function_info(model: Model) -> pd.DataFrame: info["is_constraint"] = info.index.str.endswith(("_constraint", "_filter")) info["is_next"] = info.index.str.startswith("next_") & ~info["is_constraint"] info["is_stochastic_next"] = [ - hasattr(func, "_stochastic_info") for func in model.functions.values() + hasattr(func, "_stochastic_info") and info.loc[func_name]["is_next"] + for func_name, func in model.functions.items() ] return info diff --git a/src/lcm/interfaces.py b/src/lcm/interfaces.py index f3ede382..51f2c60d 100644 --- 
a/src/lcm/interfaces.py +++ b/src/lcm/interfaces.py @@ -1,3 +1,4 @@ +import dataclasses as dc from collections.abc import Mapping from dataclasses import dataclass @@ -6,6 +7,7 @@ from lcm.grids import ContinuousGrid, DiscreteGrid, Grid from lcm.typing import InternalUserFunction, ParamsDict, ShockType +from lcm.utils import first_non_none @dataclass(frozen=True) @@ -24,18 +26,56 @@ class StateChoiceSpace: The state-choice space becomes the product of state-combinations with the full Cartesian product of the choice variables. + Note: + ----- + We store discrete and continuous choices separately since these are handled during + different stages of the solution and simulation processes. + Attributes: states: Dictionary containing the values of the state variables. - choices: Dictionary containing the values of the choice variables. + discrete_choices: Dictionary containing the values of the discrete choice + variables. + continuous_choices: Dictionary containing the values of the continuous choice + variables. ordered_var_names: Tuple with names of state and choice variables in the order they appear in the variable info table. """ states: dict[str, Array] - choices: dict[str, Array] + discrete_choices: dict[str, Array] + continuous_choices: dict[str, Array] ordered_var_names: tuple[str, ...] + def replace( + self, + states: dict[str, Array] | None = None, + discrete_choices: dict[str, Array] | None = None, + continuous_choices: dict[str, Array] | None = None, + ) -> "StateChoiceSpace": + """Replace the states or choices in the state-choice space. + + Args: + states: Dictionary with new states. If None, the existing states are used. + discrete_choices: Dictionary with new discrete choices. If None, the + existing discrete choices are used. + continuous_choices: Dictionary with new continuous choices. If None, the + existing continuous choices are used. + + Returns: + New state-choice space with the replaced states or choices. 
+ + """ + states = first_non_none(states, self.states) + discrete_choices = first_non_none(discrete_choices, self.discrete_choices) + continuous_choices = first_non_none(continuous_choices, self.continuous_choices) + return dc.replace( + self, + states=states, + discrete_choices=discrete_choices, + continuous_choices=continuous_choices, + ) + @dataclass(frozen=True) class StateSpaceInfo: @@ -95,3 +135,12 @@ class InternalModel: n_periods: int # Not properly processed yet random_utility_shocks: ShockType + + +@dataclass(frozen=True) +class InternalSimulationPeriodResults: + """The results of a simulation for one period.""" + + value: Array + choices: dict[str, Array] + states: dict[str, Array] diff --git a/src/lcm/model_functions.py b/src/lcm/model_functions.py deleted file mode 100644 index 6e14b029..00000000 --- a/src/lcm/model_functions.py +++ /dev/null @@ -1,211 +0,0 @@ -import inspect -from collections.abc import Callable - -import jax.numpy as jnp -from dags import concatenate_functions -from dags.signature import with_signature -from jax import Array - -from lcm.dispatchers import productmap -from lcm.function_representation import get_value_function_representation -from lcm.functools import ( - all_as_args, - all_as_kwargs, - get_union_of_arguments, -) -from lcm.interfaces import InternalModel, StateSpaceInfo -from lcm.next_state import get_next_state_function -from lcm.typing import InternalUserFunction, ParamsDict, Scalar, Target - - -def get_utility_and_feasibility_function( - model: InternalModel, - state_space_info: StateSpaceInfo, - period: int, - *, - is_last_period: bool, -) -> Callable[..., tuple[Array, Array]]: - # ================================================================================== - # Gather information on the model variables - # ================================================================================== - state_variables = model.variable_info.query("is_state").index.tolist() - choice_variables = 
model.variable_info.query("is_choice").index.tolist() - stochastic_variables = model.variable_info.query("is_stochastic").index.tolist() - - # ================================================================================== - # Generate dynamic functions - # ================================================================================== - current_u_and_f = get_current_u_and_f(model) - - if is_last_period: - relevant_functions: list[ - Callable[..., Scalar] - | Callable[..., tuple[Scalar, Scalar]] - | Callable[..., dict[str, Scalar]] - ] = [current_u_and_f] - - else: - next_state = get_next_state_function(model, target=Target.SOLVE) - next_weights = get_next_weights_function(model) - - scalar_value_function = get_value_function_representation(state_space_info) - - multiply_weights = get_multiply_weights(stochastic_variables) - - relevant_functions = [ - current_u_and_f, - next_state, - next_weights, - scalar_value_function, - ] - - value_function_arguments = list( - inspect.signature(scalar_value_function).parameters, - ) - - # ================================================================================== - # Create the utility and feasability function - # ================================================================================== - - arg_names_set = {"vf_arr"} | get_union_of_arguments(relevant_functions) - { - "_period" - } - arg_names = [arg for arg in arg_names_set if "next_" not in arg] - - if is_last_period: - - @with_signature(args=arg_names) - def u_and_f( - *args: Scalar, params: ParamsDict, **kwargs: Scalar - ) -> tuple[Scalar, Scalar]: - kwargs = all_as_kwargs(args, kwargs, arg_names=arg_names) - - states = {k: v for k, v in kwargs.items() if k in state_variables} - choices = {k: v for k, v in kwargs.items() if k in choice_variables} - - return current_u_and_f( - **states, - **choices, - _period=period, - params=params, - ) - - else: - - @with_signature(args=arg_names) - def u_and_f( - *args: Scalar, params: ParamsDict, **kwargs: Scalar - ) 
-> tuple[Scalar, Scalar]: - kwargs = all_as_kwargs(args, kwargs, arg_names=arg_names) - - states = {k: v for k, v in kwargs.items() if k in state_variables} - choices = {k: v for k, v in kwargs.items() if k in choice_variables} - - u, f = current_u_and_f( - **states, - **choices, - _period=period, - params=params, - ) - - _next_state = next_state( - **states, - **choices, - _period=period, - params=params, - ) - weights = next_weights( - **states, - **choices, - _period=period, - params=params, - ) - - value_function = productmap( - scalar_value_function, - variables=tuple(f"next_{var}" for var in stochastic_variables), - ) - - ccvs_at_nodes = value_function( - **_next_state, - **{k: v for k, v in kwargs.items() if k in value_function_arguments}, - ) - - node_weights = multiply_weights(**weights) - - ccv = (ccvs_at_nodes * node_weights).sum() - - big_u = u + params["beta"] * ccv - return big_u, f - - return u_and_f - - -def get_multiply_weights(stochastic_variables: list[str]) -> Callable[..., Array]: - """Get multiply_weights function. - - Args: - stochastic_variables (list): List of stochastic variables. - - Returns: - A function that multiplies the weights of the stochastic variables. - - """ - arg_names = [f"weight_next_{var}" for var in stochastic_variables] - - @with_signature(args=arg_names) - def _outer(*args: Array, **kwargs: Array) -> Array: - args = all_as_args(args, kwargs, arg_names=arg_names) - return jnp.prod(jnp.array(args)) - - return productmap(_outer, variables=tuple(arg_names)) - - -def get_combined_constraint(model: InternalModel) -> InternalUserFunction: - """Create a function that combines all constraint functions into a single one. - - Args: - model: The internal model object. - - Returns: - The combined constraint function. 
- - """ - targets = model.function_info.query("is_constraint").index.tolist() - - if targets: - combined_constraint = concatenate_functions( - functions=model.functions, - targets=targets, - aggregator=jnp.logical_and, - ) - else: - - def combined_constraint() -> None: - return None - - return combined_constraint - - -def get_current_u_and_f(model: InternalModel) -> Callable[..., tuple[Scalar, Scalar]]: - functions = {"feasibility": get_combined_constraint(model), **model.functions} - - return concatenate_functions( - functions=functions, - targets=["utility", "feasibility"], - enforce_signature=False, - ) - - -def get_next_weights_function(model: InternalModel) -> Callable[..., dict[str, Scalar]]: - targets = [ - f"weight_{name}" - for name in model.function_info.query("is_stochastic_next").index.tolist() - ] - - return concatenate_functions( - functions=model.functions, - targets=targets, - return_type="dict", - enforce_signature=False, - ) diff --git a/src/lcm/next_state.py b/src/lcm/next_state.py index 88dbed7f..dc6dd5e4 100644 --- a/src/lcm/next_state.py +++ b/src/lcm/next_state.py @@ -1,11 +1,4 @@ -"""Generate functions that compute the next states of the model. - -For the solution, we simply concatenate the functions that compute the next states. For -the simulation, we generate functions that simulate the next states of stochastic -variables. We then concatenate these functions with the functions that compute the -deteministic next states. 
- -""" +"""Generate function that compute the next states for solution and simulation.""" from collections.abc import Callable @@ -13,53 +6,63 @@ from dags.signature import with_signature from jax import Array -from lcm.functools import all_as_args from lcm.interfaces import InternalModel -from lcm.random_choice import random_choice -from lcm.typing import Scalar, Target +from lcm.random import random_choice +from lcm.typing import Scalar, StochasticNextFunction, Target def get_next_state_function( - model: InternalModel, target: Target + model: InternalModel, + target: Target, ) -> Callable[..., dict[str, Scalar]]: - """Get function that computes the next states of the model. + """Get function that computes the next states during the solution. Args: - model: Internal model. - target: Target of the function. + model: Internal model instance. + target: Whether to generate the function for the solve or simulate target. Returns: - Function that computes the next states of the model. + Function that computes the next states. Depends on states and choices of the + current period, and the model parameters ("params"). If target is "simulate", + the function also depends on the dictionary of random keys ("keys"), which + corresponds to the names of stochastic next functions. """ - if target == Target.SOLVE: - return _get_next_state_function_for_solution(model) - - if target == Target.SIMULATE: - return _get_next_state_function_for_simulation(model) - - raise ValueError(f"Invalid target: {target}") + targets = model.function_info.query("is_next").index.tolist() + if target == Target.SOLVE: + functions_dict = model.functions + elif target == Target.SIMULATE: + # For the simulation target, we need to extend the functions dictionary with + # stochastic next states functions and their weights. 
+ functions_dict = _extend_functions_dict_for_simulation(model) + else: + raise ValueError(f"Invalid target: {target}") -# ====================================================================================== -# Solution -# ====================================================================================== + return concatenate_functions( + functions=functions_dict, + targets=targets, + return_type="dict", + enforce_signature=False, + ) -def _get_next_state_function_for_solution( +def get_next_stochastic_weights_function( model: InternalModel, -) -> Callable[..., dict[str, Scalar]]: - """Get function that computes the next states for the solution. +) -> Callable[..., dict[str, Array]]: + """Get function that computes the weights for the next stochastic states. Args: - model: Model instance. + model: Internal model instance. Returns: - Function that computes the next states. Depends on states and choices of the - current period, and the model parameters. + Function that computes the weights for the next stochastic states. """ - targets = model.function_info.query("is_next").index.tolist() + targets = [ + f"weight_{name}" + for name in model.function_info.query("is_stochastic_next").index.tolist() + ] return concatenate_functions( functions=model.functions, @@ -69,100 +72,70 @@ def _get_next_state_function_for_solution( ) -# ====================================================================================== -# Simulation -# ====================================================================================== - - -def _get_next_state_function_for_simulation( +def _extend_functions_dict_for_simulation( model: InternalModel, -) -> Callable[..., dict[str, Scalar]]: - """Get function that computes the next states for the simulation. +) -> dict[str, Callable[..., Scalar]]: + """Extend the functions dictionary for the simulation target. Args: - model: Model instance. + model: Internal model instance. Returns: - Function that computes the next states. 
Depends on states and choices of the - current period, and the model parameters. Additionaly, it depends on: - - key (dict): Dictionary with PRNG keys. Keys are the names of stochastic next - functions, e.g. 'next_health'. + Extended functions dictionary. """ - # ================================================================================== - # Get targets - # ================================================================================== - targets = model.function_info.query("is_next").index.tolist() - - stochastic_targets = model.function_info.query( - "is_next & is_stochastic_next", - ).index + stochastic_targets = model.function_info.query("is_stochastic_next").index - # ================================================================================== # Handle stochastic next states functions # ---------------------------------------------------------------------------------- # We generate stochastic next states functions that simulate the next state given - # a PRNG key and the weights of the stochastic variable. The corresponding weights - # are computed using the stochastic weight functions, which we add the to functions - # dict. `dags.concatenate_functions` then generates a function that computes the - # weights and simulates the next state. - # ================================================================================== + # a random key (think of a seed) and the weights corresponding to the labels of the + # stochastic variable. The weights are computed using the stochastic weight + # functions, which we add the to functions dict. `dags.concatenate_functions` then + # generates a function that computes the weights and simulates the next state in + # one go. 
+ # ---------------------------------------------------------------------------------- stochastic_next = { - name: _get_stochastic_next_func(name, grids=model.grids) + name: _create_stochastic_next_func( + name, labels=model.grids[name.removeprefix("next_")] + ) for name in stochastic_targets } - stochastic_weights_names = [ - f"weight_{name}" - for name in model.function_info.query("is_stochastic_next").index.tolist() - ] - stochastic_weights = { - name: model.functions[name] for name in stochastic_weights_names + f"weight_{name}": model.functions[f"weight_{name}"] + for name in stochastic_targets } - # ================================================================================== # Overwrite model.functions with generated stochastic next states functions - # ================================================================================== - functions_dict = model.functions | stochastic_next | stochastic_weights - - return concatenate_functions( - functions=functions_dict, - targets=targets, - return_type="dict", - enforce_signature=False, - ) + # ---------------------------------------------------------------------------------- + return model.functions | stochastic_next | stochastic_weights -def _get_stochastic_next_func( - name: str, grids: dict[str, Array] -) -> Callable[[Array, Array], Array]: +def _create_stochastic_next_func(name: str, labels: Array) -> StochasticNextFunction: """Get function that simulates the next state of a stochastic variable. Args: name: Name of the stochastic variable. - grids: Dict with grids. + labels: 1d array of labels. Returns: - A function that simulates the next state of the stochastic variable. Depends on - variables: - - key (dict): Dictionary with PRNG keys. Keys are the names of stochastic next - functions, e.g. 'next_health'. - - weight_{name} (jax.numpy.array): 2d array of weights. The first dimension - corresponds to the number of simulation units. 
The second dimension - corresponds to the number of grid points (labels). + A function that simulates the next state of the stochastic variable. The + function must be called with keyword arguments: + - weight_{name}: 2d array of weights. The first dimension corresponds to the + number of simulation units. The second dimension corresponds to the number of + grid points (labels). + - keys: Dictionary with random key arrays. Dictionary keys correspond to the + names of stochastic next functions, e.g. 'next_health'. """ - arg_names = ["keys", f"weight_{name}"] - labels = grids[name.removeprefix("next_")] - @with_signature(args=arg_names) - def _next_stochastic_state(*args: Array, **kwargs: Array) -> Array: - keys, weights = all_as_args(args, kwargs, arg_names=arg_names) + @with_signature(args=[f"weight_{name}", "keys"]) + def next_stochastic_state(keys: dict[str, Array], **kwargs: Array) -> Array: return random_choice( - key=keys[name], - probs=weights, labels=labels, + probs=kwargs[f"weight_{name}"], + key=keys[name], ) - return _next_stochastic_state + return next_stochastic_state diff --git a/src/lcm/random.py b/src/lcm/random.py new file mode 100644 index 00000000..29feb751 --- /dev/null +++ b/src/lcm/random.py @@ -0,0 +1,73 @@ +import os +from functools import partial + +import jax +from jax import Array + + +def random_choice( + labels: jax.Array, + probs: jax.Array, + key: jax.Array, +) -> jax.Array: + """Draw multiple random choices. + + Args: + labels: 1d array of labels. + probs: 2d array of probabilities. Second dimension must be + the same length as the first dimension of labels. + key: Random key. + + Returns: + Selected labels. 1d array of length len(probs). 
+ + """ + keys = jax.random.split(key, probs.shape[0]) + return _vmapped_choice(keys, probs, labels) + + +@partial(jax.vmap, in_axes=(0, 0, None)) +def _vmapped_choice(key: jax.Array, probs: jax.Array, labels: jax.Array) -> jax.Array: + return jax.random.choice(key, a=labels, p=probs) + + +def generate_simulation_keys( + key: Array, ids: list[str] +) -> tuple[Array, dict[str, Array]]: + """Generate pseudo-random number generator keys (PRNG keys) for simulation. + + PRNG keys in JAX are immutable objects used to control random number generation. + A key can be used to generate a stream of random numbers, e.g., given a key, one can + call jax.random.normal(key) to generate a stream of normal random numbers. In order + to ensure that each simulation is based on a different stream of random numbers, we + split the key into one key per stochastic variable, and one key that will be passed + to the next iteration in order to generate new keys. + + See the JAX documentation for more details: + https://docs.jax.dev/en/latest/random-numbers.html#random-numbers-in-jax + + Args: + key: Random key. + ids: List of names for which a key is to be generated. + + Returns: + - Updated random key. + - Dict with random keys for each id in ids. + + """ + keys = jax.random.split(key, num=len(ids) + 1) + + key = keys[0] + simulation_keys = dict(zip(ids, keys[1:], strict=True)) + + return key, simulation_keys + + +def draw_random_seed() -> int: + """Generate a random seed using the operating system's secure entropy pool. + + Returns: + Random seed. + + """ + return int.from_bytes(os.urandom(4), "little") diff --git a/src/lcm/random_choice.py b/src/lcm/random_choice.py deleted file mode 100644 index dceb1663..00000000 --- a/src/lcm/random_choice.py +++ /dev/null @@ -1,31 +0,0 @@ -from functools import partial - -import jax - - -def random_choice( - key: jax.Array, - probs: jax.Array, - labels: jax.Array, -) -> jax.Array: - """Draw multiple random choices. - - Args: - key: Random key. 
- probs: 2d array of probabilities. Second dimension must be - the same length as the first dimension of labels. - labels: 1d array of labels. - - Returns: - Selected labels. 1d array of length len(probs). - - """ - keys = jax.random.split(key, probs.shape[0]) - return _vmapped_random_choice(keys, probs, labels) - - -@partial(jax.vmap, in_axes=(0, 0, None)) -def _vmapped_random_choice( - key: jax.Array, probs: jax.Array, labels: jax.Array -) -> jax.Array: - return jax.random.choice(key, a=labels, p=probs) diff --git a/src/lcm/sandbox/state_space_jax_versus_numba.ipynb b/src/lcm/sandbox/state_space_jax_versus_numba.ipynb deleted file mode 100644 index a8f034a6..00000000 --- a/src/lcm/sandbox/state_space_jax_versus_numba.ipynb +++ /dev/null @@ -1,835 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": null, - "id": "0", - "metadata": {}, - "outputs": [], - "source": [ - "import itertools\n", - "import math\n", - "\n", - "import jax.numpy as jnp\n", - "import matplotlib.pyplot as plt\n", - "import numpy as np\n", - "from jax import jit\n", - "from jax.config import config\n", - "from numba import njit\n", - "\n", - "config.update(\"jax_enable_x64\", True)\n", - "\n", - "from dags import concatenate_functions\n", - "from lcm.dispatchers import gridmap, productmap\n", - "from numpy.testing import assert_array_almost_equal as aaae" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "1", - "metadata": {}, - "outputs": [], - "source": [ - "def plot_runtime(n_points, runtime_numba, runtime_jax, func=\"product_map\"):\n", - " plt.figure(figsize=[10, 6])\n", - " plt.plot(n_points, [i * 1000 for i in runtime_numba], label=\"Numba\")\n", - " plt.plot(n_points, [i * 1000 for i in runtime_jax], label=\"JAX\")\n", - "\n", - " plt.xlabel(\"Number of Grid Points\")\n", - " plt.ylabel(\"Runtime (in milliseconds)\")\n", - " plt.title(f\"Runtime of hardcoded Numba function versus {func} with JAX\")\n", - "\n", - " plt.legend()\n", - " plt.show()" 
- ] - }, - { - "cell_type": "markdown", - "id": "2", - "metadata": {}, - "source": [ - "## product_map: 2 grids benchmark" - ] - }, - { - "cell_type": "markdown", - "id": "3", - "metadata": {}, - "source": [ - "### Trivial setup" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "4", - "metadata": {}, - "outputs": [], - "source": [ - "def _utility(_consumption, _leisure):\n", - " return _consumption + _leisure\n", - "\n", - "\n", - "def _leisure(working):\n", - " return 24 - working\n", - "\n", - "\n", - "def _consumption(working, wage):\n", - " return wage * working" - ] - }, - { - "cell_type": "markdown", - "id": "5", - "metadata": {}, - "source": [ - "### Numba: Dimensions hardcoded" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "6", - "metadata": {}, - "outputs": [], - "source": [ - "@njit(fastmath=True)\n", - "def product_map_numba_2d(wage, working):\n", - " n_wage = wage.shape[0]\n", - " n_working = working.shape[0]\n", - "\n", - " cross_product = np.empty((n_wage, n_working))\n", - "\n", - " for i in range(n_wage):\n", - " for j in range(n_working):\n", - " cross_product[i, j] = wage[i] * working[j] + (24 - working[j])\n", - "\n", - " return cross_product\n", - "\n", - "\n", - "wage = np.linspace(1, 10, 101)\n", - "working = np.linspace(0, 24, 25)\n", - "rslt_numba = product_map_numba_2d(wage, working)" - ] - }, - { - "cell_type": "markdown", - "id": "7", - "metadata": {}, - "source": [ - "### LCM productmap" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "8", - "metadata": {}, - "outputs": [], - "source": [ - "grids = {\n", - " \"wage\": jnp.linspace(1, 10, 101),\n", - " \"working\": jnp.linspace(0, 24, 25),\n", - "}\n", - "\n", - "utility_concat = concatenate_functions(\n", - " functions=[_utility, _leisure, _consumption],\n", - " targets=\"_utility\",\n", - ")\n", - "_decorated_func = productmap(utility_concat, [\"wage\", \"working\"])\n", - "decorated_func_jit = 
jit(_decorated_func)\n", - "rslt_jax = decorated_func_jit(**grids)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "9", - "metadata": {}, - "outputs": [], - "source": [ - "### Assert equality\n", - "\n", - "aaae(rslt_jax, rslt_numba)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "10", - "metadata": {}, - "outputs": [], - "source": [ - "def get_numba_runtime_2d():\n", - " total_points = []\n", - " runtime = []\n", - "\n", - " for len_grid in grid_space:\n", - " total_points.append(len_grid**2)\n", - "\n", - " grids = {\n", - " \"wage\": np.linspace(1, 10, len_grid),\n", - " \"working\": np.linspace(0, 24, len_grid),\n", - " }\n", - "\n", - " product_map_numba_2d(**grids)\n", - "\n", - " timeit_res = %timeit -r 7 -n 1_000 -o product_map_numba_2d(**grids)\n", - " runtime.append(timeit_res.average)\n", - "\n", - " return total_points, runtime\n", - "\n", - "\n", - "def get_jax_runtime_2d():\n", - " utility_concat = concatenate_functions(\n", - " functions=[_utility, _leisure, _consumption],\n", - " targets=\"_utility\",\n", - " )\n", - " _decorated_func = productmap(utility_concat, [\"wage\", \"working\"])\n", - " decorated_func_jit = jit(_decorated_func)\n", - "\n", - " total_points = []\n", - " runtime = []\n", - "\n", - " for len_grid in grid_space:\n", - " total_points.append(len_grid**2)\n", - "\n", - " grids = {\n", - " \"wage\": jnp.linspace(1, 10, len_grid),\n", - " \"working\": jnp.linspace(0, 24, len_grid),\n", - " }\n", - "\n", - " decorated_func_jit(**grids)\n", - "\n", - " timeit_res = %timeit -r 7 -n 1_000 -o decorated_func_jit(**grids).block_until_ready()\n", - " runtime.append(timeit_res.average)\n", - "\n", - " return total_points, runtime\n", - "\n", - "\n", - "grid_space = [10, 25, 50] + np.arange(100, 1100, 100).tolist()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "11", - "metadata": {}, - "outputs": [], - "source": [ - "n_points_2_grids, runtime_numba_2d = 
get_numba_runtime_2d()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "12", - "metadata": {}, - "outputs": [], - "source": [ - "_, runtime_jax_2d = get_jax_runtime_2d()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "13", - "metadata": {}, - "outputs": [], - "source": [ - "plot_runtime(n_points_2_grids, runtime_numba_2d, runtime_jax_2d)" - ] - }, - { - "cell_type": "markdown", - "id": "14", - "metadata": {}, - "source": [ - "## *product_map*: 4 grids benchmark, more complex computations" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "15", - "metadata": {}, - "outputs": [], - "source": [ - "def _utility_4d(_consumption_4d, _leisure_4d):\n", - " return _consumption_4d + _leisure_4d\n", - "\n", - "\n", - "def _leisure_4d(d):\n", - " return jnp.cos(d)\n", - "\n", - "\n", - "def _consumption_4d(a, b, c):\n", - " return jnp.log(a) + jnp.sqrt(jnp.square(b) * c)\n", - "\n", - "\n", - "@njit(fastmath=True)\n", - "def product_map_numba_4d(a, b, c, d):\n", - " n_a = a.shape[0]\n", - " n_b = b.shape[0]\n", - " n_c = c.shape[0]\n", - " n_d = d.shape[0]\n", - "\n", - " cross_product = np.empty((n_a, n_b, n_c, n_d))\n", - "\n", - " for i in range(n_a):\n", - " for j in range(n_b):\n", - " for k in range(n_c):\n", - " for l in range(n_d):\n", - " cross_product[i, j, k, l] = (\n", - " np.log(a[i]) + np.sqrt(b[j] ** 2 * c[k]) + np.cos(d[l])\n", - " )\n", - "\n", - " return cross_product" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "16", - "metadata": {}, - "outputs": [], - "source": [ - "grids = {\n", - " \"a\": jnp.linspace(1, 10, 101),\n", - " \"b\": jnp.linspace(0, 24, 25),\n", - " \"c\": jnp.linspace(1, 5, 5),\n", - " \"d\": jnp.linspace(-7, 2, 10),\n", - "}\n", - "\n", - "a = np.linspace(1, 10, 101)\n", - "b = np.linspace(0, 24, 25)\n", - "c = np.linspace(1, 5, 5)\n", - "d = np.linspace(-7, 2, 10)\n", - "\n", - "utility_concat_4d = concatenate_functions(\n", - " functions=[\n", 
- " _utility_4d,\n", - " _leisure_4d,\n", - " _consumption_4d,\n", - " ],\n", - " targets=\"_utility_4d\",\n", - ")\n", - "\n", - "_decorated_func = productmap(utility_concat_4d, [\"a\", \"b\", \"c\", \"d\"])\n", - "decorated_func_jit = jit(_decorated_func)\n", - "rslt_jax = decorated_func_jit(**grids)\n", - "\n", - "rslt_numba = product_map_numba_4d(a, b, c, d)\n", - "\n", - "aaae(rslt_jax, rslt_numba)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "17", - "metadata": {}, - "outputs": [], - "source": [ - "def get_numba_runtime_4d(len_grids_arr):\n", - " total_points = []\n", - " runtime = []\n", - "\n", - " for len_grids in len_grids_arr.T:\n", - " len_a, len_b, len_c, len_d = len_grids.astype(int)\n", - " total_points.append(math.prod(len_grids))\n", - "\n", - " a = np.linspace(1, 10, len_a)\n", - " b = np.linspace(0, 24, len_b)\n", - " c = np.linspace(1, 5, len_c)\n", - " d = np.linspace(-7, 2, len_d)\n", - "\n", - " product_map_numba_4d(a, b, c, d)\n", - "\n", - " timeit_res = %timeit -r 7 -n 1_000 -o product_map_numba_4d(a, b, c, d)\n", - " runtime.append(timeit_res.average)\n", - "\n", - " return total_points, runtime\n", - "\n", - "\n", - "def get_jax_runtime_4d(len_grids_arr):\n", - " utility_concat_4d = concatenate_functions(\n", - " functions=[\n", - " _utility_4d,\n", - " _leisure_4d,\n", - " _consumption_4d,\n", - " ],\n", - " targets=\"_utility_4d\",\n", - " )\n", - "\n", - " _decorated_func = productmap(utility_concat_4d, [\"a\", \"b\", \"c\", \"d\"])\n", - " decorated_func_jit = jit(_decorated_func)\n", - "\n", - " total_points = []\n", - " runtime = []\n", - "\n", - " for len_grids in len_grids_arr.T:\n", - " len_a, len_b, len_c, len_d = len_grids.astype(int)\n", - " total_points.append(math.prod(len_grids))\n", - "\n", - " grids = {\n", - " \"a\": jnp.linspace(1, 10, len_a),\n", - " \"b\": jnp.linspace(0, 24, len_b),\n", - " \"c\": jnp.linspace(1, 5, len_c),\n", - " \"d\": jnp.linspace(-7, 2, len_d),\n", - " }\n", - "\n", - 
" decorated_func_jit(**grids)\n", - "\n", - " timeit_res = %timeit -r 7 -n 1_000 -o decorated_func_jit(**grids).block_until_ready()\n", - " runtime.append(timeit_res.average)\n", - "\n", - " return total_points, runtime\n", - "\n", - "\n", - "len_grids_arr = np.vstack(\n", - " (\n", - " [5, 10, 20, 60, 100],\n", - " np.linspace(5, 25, 5),\n", - " np.linspace(1, 5, 5),\n", - " np.linspace(2, 10, 5),\n", - " ),\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "18", - "metadata": {}, - "outputs": [], - "source": [ - "n_points_4d, runtime_numba_4d = get_numba_runtime_4d(len_grids_arr)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "19", - "metadata": {}, - "outputs": [], - "source": [ - "_, runtime_jax_4d = get_jax_runtime_4d(len_grids_arr)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "20", - "metadata": {}, - "outputs": [], - "source": [ - "plot_runtime(n_points_4d, runtime_numba_4d, runtime_jax_4d)" - ] - }, - { - "cell_type": "markdown", - "id": "21", - "metadata": {}, - "source": [ - "## *gridmap*: 4 grids benchmark" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "22", - "metadata": {}, - "outputs": [], - "source": [ - "@njit(fastmath=True)\n", - "def state_space_map_numba(a, b, c, d):\n", - " n_a = a.shape[0]\n", - " n_b = b.shape[0]\n", - " n_c = c.shape[0]\n", - "\n", - " state_space_product = np.empty((n_a, n_b, n_c))\n", - "\n", - " for i in range(n_a):\n", - " for j in range(n_b):\n", - " for k in range(n_c):\n", - " state_space_product[i, j, k] = (\n", - " np.log(a[i]) + np.sqrt(b[j] ** 2 * c[k]) + np.cos(d[k])\n", - " )\n", - "\n", - " return state_space_product" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "23", - "metadata": {}, - "outputs": [], - "source": [ - "# JAX\n", - "dense_variables = {\n", - " \"a\": jnp.linspace(1, 10, 101),\n", - " \"b\": jnp.linspace(0, 24, 25),\n", - "}\n", - "combined_grids = {\"c\": 
jnp.linspace(1, 5, 5), \"d\": jnp.linspace(-7, 2, 10)}\n", - "helper = jnp.array(list(itertools.product(*combined_grids.values()))).T\n", - "\n", - "contingent_variables = {\n", - " \"c\": helper[0],\n", - " \"d\": helper[1],\n", - "}\n", - "\n", - "# Numpy\n", - "dense_variables_numpy = {\n", - " \"a\": np.linspace(1, 10, 101),\n", - " \"b\": np.linspace(0, 24, 25),\n", - "}\n", - "combined_grids_numpy = {\"c\": np.linspace(1, 5, 5), \"d\": np.linspace(-7, 2, 10)}\n", - "helper = np.array(list(itertools.product(*combined_grids_numpy.values()))).T\n", - "\n", - "contingent_variables_numpy = {\n", - " \"c\": helper[0],\n", - " \"d\": helper[1],\n", - "}\n", - "all_grids = {**dense_variables_numpy, **contingent_variables_numpy}" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "24", - "metadata": {}, - "outputs": [], - "source": [ - "utility_concat_4d = concatenate_functions(\n", - " functions=[\n", - " _utility_4d,\n", - " _leisure_4d,\n", - " _consumption_4d,\n", - " ],\n", - " targets=\"_utility_4d\",\n", - ")\n", - "\n", - "_decorated_func = gridmap(\n", - " utility_concat_4d,\n", - " list(dense_variables),\n", - " list(contingent_variables),\n", - ")\n", - "decorated_func_jit = jit(_decorated_func)\n", - "rslt_jax = decorated_func_jit(**dense_variables, **contingent_variables)\n", - "\n", - "rslt_numba = state_space_map_numba(*all_grids.values())\n", - "\n", - "aaae(rslt_jax, rslt_numba)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "25", - "metadata": {}, - "outputs": [], - "source": [ - "def get_numba_runtime(len_grids_arr):\n", - " total_points = []\n", - " runtime = []\n", - "\n", - " for len_grids in len_grids_arr.T:\n", - " len_a, len_b, len_c, len_d = len_grids.astype(int)\n", - " total_points.append(math.prod(len_grids))\n", - "\n", - " dense_variables = {\n", - " \"a\": np.linspace(1, 10, len_a),\n", - " \"b\": np.linspace(0, 24, len_b),\n", - " }\n", - " complex_grids = {\"c\": np.linspace(1, 5, len_c), 
\"d\": np.linspace(-7, 2, len_d)}\n", - " helper = np.array(list(itertools.product(*complex_grids.values()))).T\n", - "\n", - " complex_variables = {\n", - " \"c\": helper[0],\n", - " \"d\": helper[1],\n", - " }\n", - " all_grids = {**dense_variables, **complex_variables}\n", - "\n", - " state_space_map_numba(*all_grids.values())\n", - "\n", - " timeit_res = %timeit -r 7 -n 1_000 -o state_space_map_numba(*all_grids.values())\n", - " runtime.append(timeit_res.average)\n", - "\n", - " return total_points, runtime\n", - "\n", - "\n", - "def get_jax_runtime(len_grids_arr):\n", - " utility_concat_4d = concatenate_functions(\n", - " functions=[_utility_4d, _leisure_4d, _consumption_4d],\n", - " targets=\"_utility_4d\",\n", - " )\n", - "\n", - " _decorated_func = gridmap(utility_concat_4d, [\"a\", \"b\"], [\"c\", \"d\"])\n", - " decorated_func_jit = jit(_decorated_func)\n", - "\n", - " total_points = []\n", - " runtime = []\n", - "\n", - " for len_grids in len_grids_arr.T:\n", - " len_a, len_b, len_c, len_d = len_grids.astype(int)\n", - " total_points.append(math.prod(len_grids))\n", - "\n", - " dense_variables = {\n", - " \"a\": jnp.linspace(1, 10, len_a),\n", - " \"b\": jnp.linspace(0, 24, len_b),\n", - " }\n", - " complex_grids = {\n", - " \"c\": jnp.linspace(1, 5, len_c),\n", - " \"d\": jnp.linspace(-7, 2, len_d),\n", - " }\n", - " helper = jnp.array(list(itertools.product(*complex_grids.values()))).T\n", - "\n", - " complex_variables = {\n", - " \"c\": helper[0],\n", - " \"d\": helper[1],\n", - " }\n", - "\n", - " decorated_func_jit(**dense_variables, **complex_variables)\n", - "\n", - " timeit_res = %timeit -r 7 -n 1_000 -o decorated_func_jit(**dense_variables, **complex_variables).block_until_ready()\n", - " runtime.append(timeit_res.average)\n", - "\n", - " return total_points, runtime" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "26", - "metadata": {}, - "outputs": [], - "source": [ - "n_points, runtime_numba = 
get_numba_runtime(len_grids_arr)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "27", - "metadata": {}, - "outputs": [], - "source": [ - "_, runtime_jax = get_jax_runtime(len_grids_arr)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "28", - "metadata": {}, - "outputs": [], - "source": [ - "plot_runtime(n_points, runtime_numba, runtime_jax, func=\"state_space_map\")" - ] - }, - { - "cell_type": "markdown", - "id": "29", - "metadata": {}, - "source": [ - "## 4d Benchmark with multiple targets and product_map" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "30", - "metadata": {}, - "outputs": [], - "source": [ - "def _utility_4d(_consumption_4d, _leisure_4d):\n", - " return _consumption_4d + _leisure_4d\n", - "\n", - "\n", - "def _leisure_4d(d):\n", - " return jnp.cos(d)\n", - "\n", - "\n", - "def _consumption_4d(a, b, c):\n", - " return jnp.log(a) + jnp.sqrt(jnp.square(b) * c)\n", - "\n", - "\n", - "@njit(fastmath=True)\n", - "def product_map_numba_4d(a, b, c, d):\n", - " n_a = a.shape[0]\n", - " n_b = b.shape[0]\n", - " n_c = c.shape[0]\n", - " n_d = d.shape[0]\n", - "\n", - " out_utility = np.empty((n_a, n_b, n_c, n_d))\n", - " out_leisure = np.empty((n_a, n_b, n_c, n_d))\n", - "\n", - " for i in range(n_a):\n", - " for j in range(n_b):\n", - " for k in range(n_c):\n", - " for l in range(n_d):\n", - " _leis = np.cos(d[l])\n", - "\n", - " out_leisure[i, j, k, l] = _leis\n", - " out_utility[i, j, k, l] = (\n", - " np.log(a[i]) + np.sqrt(b[j] ** 2 * c[k]) + _leis\n", - " )\n", - "\n", - " return out_leisure, out_utility" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "31", - "metadata": {}, - "outputs": [], - "source": [ - "grids = {\n", - " \"a\": jnp.linspace(1, 10, 101),\n", - " \"b\": jnp.linspace(0, 24, 25),\n", - " \"c\": jnp.linspace(1, 5, 5),\n", - " \"d\": jnp.linspace(-7, 2, 10),\n", - "}\n", - "\n", - "a = np.linspace(1, 10, 101)\n", - "b = np.linspace(0, 24, 
25)\n", - "c = np.linspace(1, 5, 5)\n", - "d = np.linspace(-7, 2, 10)\n", - "\n", - "utility_concat_4d = concatenate_functions(\n", - " functions=[_utility_4d, _leisure_4d, _consumption_4d],\n", - " targets=[\"_leisure_4d\", \"_utility_4d\"],\n", - ")\n", - "\n", - "_decorated_func = productmap(utility_concat_4d, [\"a\", \"b\", \"c\", \"d\"])\n", - "decorated_func_jit = jit(_decorated_func)\n", - "rslt_jax = decorated_func_jit(**grids)\n", - "\n", - "rslt_numba = product_map_numba_4d(a, b, c, d)\n", - "\n", - "aaae(rslt_jax, rslt_numba)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "32", - "metadata": {}, - "outputs": [], - "source": [ - "def get_numba_runtime_4d(len_grids_arr):\n", - " total_points = []\n", - " runtime = []\n", - "\n", - " for len_grids in len_grids_arr.T:\n", - " len_a, len_b, len_c, len_d = len_grids.astype(int)\n", - " total_points.append(math.prod(len_grids))\n", - "\n", - " a = np.linspace(1, 10, len_a)\n", - " b = np.linspace(0, 24, len_b)\n", - " c = np.linspace(1, 5, len_c)\n", - " d = np.linspace(-7, 2, len_d)\n", - "\n", - " product_map_numba_4d(a, b, c, d)\n", - "\n", - " timeit_res = %timeit -r 7 -n 1_000 -o product_map_numba_4d(a, b, c, d)\n", - " runtime.append(timeit_res.average)\n", - "\n", - " return total_points, runtime\n", - "\n", - "\n", - "def get_jax_runtime_4d(len_grids_arr):\n", - " utility_concat_4d = concatenate_functions(\n", - " functions=[_utility_4d, _leisure_4d, _consumption_4d],\n", - " targets=\"_utility_4d\",\n", - " )\n", - "\n", - " _decorated_func = productmap(utility_concat_4d, [\"a\", \"b\", \"c\", \"d\"])\n", - " decorated_func_jit = jit(_decorated_func)\n", - "\n", - " total_points = []\n", - " runtime = []\n", - "\n", - " for len_grids in len_grids_arr.T:\n", - " len_a, len_b, len_c, len_d = len_grids.astype(int)\n", - " total_points.append(math.prod(len_grids))\n", - "\n", - " grids = {\n", - " \"a\": jnp.linspace(1, 10, len_a),\n", - " \"b\": jnp.linspace(0, 24, len_b),\n", - " 
\"c\": jnp.linspace(1, 5, len_c),\n", - " \"d\": jnp.linspace(-7, 2, len_d),\n", - " }\n", - "\n", - " decorated_func_jit(**grids)\n", - "\n", - " timeit_res = %timeit -r 7 -n 1_000 -o decorated_func_jit(**grids).block_until_ready()\n", - " runtime.append(timeit_res.average)\n", - "\n", - " return total_points, runtime\n", - "\n", - "\n", - "len_grids_arr = np.vstack(\n", - " (\n", - " [5, 10, 20, 60, 100],\n", - " np.linspace(5, 25, 5),\n", - " np.linspace(1, 5, 5),\n", - " np.linspace(2, 10, 5),\n", - " ),\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "33", - "metadata": {}, - "outputs": [], - "source": [ - "n_points_4d, runtime_numba_4d = get_numba_runtime_4d(len_grids_arr)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "34", - "metadata": {}, - "outputs": [], - "source": [ - "_, runtime_jax_4d = get_jax_runtime_4d(len_grids_arr)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "35", - "metadata": {}, - "outputs": [], - "source": [ - "plot_runtime(n_points_4d, runtime_numba_4d, runtime_jax_4d)" - ] - } - ], - "metadata": { - "interpreter": { - "hash": "4ad1f9a7057235603a8e0891a7ecf17fdcab16a40a59f91c8779290453820c05" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/src/lcm/simulation/processing.py b/src/lcm/simulation/processing.py new file mode 100644 index 00000000..a14c2248 --- /dev/null +++ b/src/lcm/simulation/processing.py @@ -0,0 +1,120 @@ +import inspect + +import jax.numpy as jnp +import pandas as pd +from dags import concatenate_functions +from jax import Array + +from lcm.dispatchers import vmap_1d +from lcm.interfaces import InternalModel, InternalSimulationPeriodResults +from lcm.typing import InternalUserFunction, 
ParamsDict + + +def process_simulated_data( + results: dict[int, InternalSimulationPeriodResults], + model: InternalModel, + params: ParamsDict, + additional_targets: list[str] | None = None, +) -> dict[str, Array]: + """Process and flatten the simulation results. + + This function produces a dict of arrays for each var with dimension (n_periods * + n_initial_states,). The arrays are flattened, so that the resulting dictionary has a + one-dimensional array for each variable. The length of this array is the number of + periods times the number of initial states. The order of array elements is given by + an outer level of periods and an inner level of initial states ids. + + Args: + results: Dict with simulation results. Each dict contains the value, + choices, and states for one period. Choices and states are stored in a + nested dictionary. + model: Model. + params: Parameters. + additional_targets: List of additional targets to compute. + + Returns: + Dict with processed simulation results. The keys are the variable names and the + values are the flattened arrays, with dimension (n_periods * n_initial_states,). + Additionally, the _period variable is added. + + """ + n_periods = len(results) + n_initial_states = len(results[0].value) + + list_of_dicts = [ + {"value": d.value, **d.choices, **d.states} for d in results.values() + ] + dict_of_lists = { + key: [d[key] for d in list_of_dicts] for key in list(list_of_dicts[0]) + } + out = {key: jnp.concatenate(values) for key, values in dict_of_lists.items()} + out["_period"] = jnp.repeat(jnp.arange(n_periods), n_initial_states) + + if additional_targets is not None: + calculated_targets = _compute_targets( + out, + targets=additional_targets, + model_functions=model.functions, + params=params, + ) + out = {**out, **calculated_targets} + + return out + + +def as_data_frame(processed: dict[str, Array], n_periods: int) -> pd.DataFrame: + """Convert processed simulation results to DataFrame. 
+ + Args: + processed: Dict with processed simulation results. + n_periods: Number of periods. + + Returns: + DataFrame with the simulation results. The index is a multi-index with the first + level corresponding to the period and the second level corresponding to the + initial state id. The columns correspond to the value, and the choice and state + variables, and potentially auxiliary variables. + + """ + n_initial_states = len(processed["value"]) // n_periods + index = pd.MultiIndex.from_product( + [range(n_periods), range(n_initial_states)], + names=["period", "initial_state_id"], + ) + return pd.DataFrame(processed, index=index) + + +def _compute_targets( + processed_results: dict[str, Array], + targets: list[str], + model_functions: dict[str, InternalUserFunction], + params: ParamsDict, +) -> dict[str, Array]: + """Compute targets. + + Args: + processed_results: Dict with processed simulation results. Values must be + one-dimensional arrays. + targets: List of targets to compute. + model_functions: Dict with model functions. + params: Dict with model parameters. + + Returns: + Dict with computed targets. 
+ + """ + target_func = concatenate_functions( + functions=model_functions, + targets=targets, + return_type="dict", + ) + + # get list of variables over which we want to vectorize the target function + variables = tuple( + p for p in list(inspect.signature(target_func).parameters) if p != "params" + ) + + target_func = vmap_1d(target_func, variables=variables) + + kwargs = {k: v for k, v in processed_results.items() if k in variables} + return target_func(params=params, **kwargs) diff --git a/src/lcm/simulation/simulate.py b/src/lcm/simulation/simulate.py index 758e50b3..75ebe461 100644 --- a/src/lcm/simulation/simulate.py +++ b/src/lcm/simulation/simulate.py @@ -1,93 +1,108 @@ -import inspect import logging from collections.abc import Callable from functools import partial -from typing import Any import jax import jax.numpy as jnp import pandas as pd -from dags import concatenate_functions from jax import Array, vmap -from lcm.argmax import argmax -from lcm.dispatchers import spacemap, vmap_1d -from lcm.interfaces import InternalModel, StateChoiceSpace -from lcm.typing import InternalUserFunction, ParamsDict +from lcm.discrete_problem import get_solve_discrete_problem_policy +from lcm.dispatchers import simulation_spacemap, vmap_1d +from lcm.interfaces import ( + InternalModel, + InternalSimulationPeriodResults, + StateChoiceSpace, +) +from lcm.random import draw_random_seed, generate_simulation_keys +from lcm.simulation.processing import as_data_frame, process_simulated_data +from lcm.state_choice_space import create_state_choice_space +from lcm.typing import ParamsDict + + +def solve_and_simulate( + params: ParamsDict, + initial_states: dict[str, Array], + compute_ccv_policy_functions: dict[int, Callable[..., tuple[Array, Array]]], + model: InternalModel, + next_state: Callable[..., dict[str, Array]], + logger: logging.Logger, + solve_model: Callable[..., dict[int, Array]], + *, + additional_targets: list[str] | None = None, + seed: int | None = None, +) -> 
pd.DataFrame: + """First solve the model and then simulate the model forward in time. + + Same docstring as `simulate` mutatis mutandis. + + """ + vf_arr_dict = solve_model(params) + return simulate( + params=params, + initial_states=initial_states, + compute_ccv_policy_functions=compute_ccv_policy_functions, + model=model, + next_state=next_state, + logger=logger, + vf_arr_dict=vf_arr_dict, + additional_targets=additional_targets, + seed=seed, + ) def simulate( params: ParamsDict, initial_states: dict[str, Array], - continuous_choice_grids: list[dict[str, Array]], - compute_ccv_policy_functions: list[Callable[..., tuple[Array, Array]]], + compute_ccv_policy_functions: dict[int, Callable[..., tuple[Array, Array]]], model: InternalModel, next_state: Callable[..., dict[str, Array]], logger: logging.Logger, - solve_model: Callable[..., list[Array]] | None = None, - pre_computed_vf_arr_list: list[Array] | None = None, + vf_arr_dict: dict[int, Array], + *, additional_targets: list[str] | None = None, - seed: int = 12345, + seed: int | None = None, ) -> pd.DataFrame: - """Simulate the model forward in time. + """Simulate the model forward in time given pre-computed value function arrays. Args: params: Dict of model parameters. initial_states: List of initial states to start from. Typically from the observed dataset. - continuous_choice_grids: List of dicts of length n_periods. Each dict - contains 1d grids for continuous choice variables. - compute_ccv_policy_functions: List of functions of length n_periods. Each - function computes the conditional continuation value dependent on the - discrete choices. + compute_ccv_policy_functions: Dict of length n_periods. Each function computes + the conditional continuation value dependent on the discrete choices. next_state: Function that returns the next state given the current state and choice variables. For stochastic variables, it returns a random draw from the distribution of the next state. model: Model instance. 
logger: Logger that logs to stdout. - solve_model: Function that solves the model. Is only required if - vf_arr_list is not provided. - pre_computed_vf_arr_list: List of value function arrays of length - n_periods. This is the output of the model's `solve` function. If not - provided, the model is solved first. + vf_arr_dict: Dict of value function arrays of length n_periods. additional_targets: List of targets to compute. If provided, the targets are computed and added to the simulation results. - seed: Random number seed; will be passed to `jax.random.key`. + seed: Random number seed; will be passed to `jax.random.key`. If not provided, + a random seed will be generated. Returns: DataFrame with the simulation results. """ - if pre_computed_vf_arr_list is None: - if solve_model is None: - raise ValueError( - "You need to provide either vf_arr_list or solve_model.", - ) - # We do not need to convert the params here, because the solve_model function - # will do it. - vf_arr_list = solve_model(params) - else: - vf_arr_list = pre_computed_vf_arr_list + if seed is None: + seed = draw_random_seed() logger.info("Starting simulation") - # Update the vf_arr_list - # ---------------------------------------------------------------------------------- - # We drop the value function array for the first period, because it is not needed - # for the simulation. This is because in the first period the agents only consider - # the current utility and the value function of next period. Similarly, the last - # value function array is not required, as the agents only consider the current - # utility in the last period. 
- # ================================================================================== - vf_arr_list = vf_arr_list[1:] + [jnp.empty(0)] - # Preparations - # ================================================================================== - n_periods = len(vf_arr_list) + # ---------------------------------------------------------------------------------- + n_periods = len(vf_arr_dict) n_initial_states = len(next(iter(initial_states.values()))) - discrete_policy_calculator = get_discrete_policy_calculator( - variable_info=model.variable_info, + state_choice_space = create_state_choice_space( + model=model, + initial_states=initial_states, + ) + + discrete_policy_calculator = get_solve_discrete_problem_policy( + variable_info=model.variable_info ) # The following variables are updated during the forward simulation @@ -95,124 +110,111 @@ def simulate( key = jax.random.key(seed=seed) # Forward simulation - # ================================================================================== - _simulation_results = [] + # ---------------------------------------------------------------------------------- + simulation_results = {} for period in range(n_periods): - # Create data state choice space - # ------------------------------------------------------------------------------ - # Initial states are treated as combination variables, so that the combination - # variables in the data-state-choice-space correspond to the feasible product - # of combination variables and initial states. The space has to be created in - # each iteration because the states change over time. 
- # ============================================================================== - data_scs = create_data_scs( - states=states, - model=model, - ) + state_choice_space = state_choice_space.replace(states) - # Compute objects dependent on data-state-choice-space - # ============================================================================== - choices_grid_shape = tuple(len(grid) for grid in data_scs.choices.values()) - cont_choices_grid_shape = tuple( - len(grid) for grid in continuous_choice_grids[period].values() + # We compute these grid shapes in the loop because they can change over time. + # TODO (@timmens): This could still be pre-computed in the entry point. # noqa: TD003,E501 + discrete_choices_grid_shape = tuple( + len(grid) for grid in state_choice_space.discrete_choices.values() + ) + continuous_choices_grid_shape = tuple( + len(grid) for grid in state_choice_space.continuous_choices.values() ) # Compute optimal continuous choice conditional on discrete choices - # ============================================================================== - ccv_policy, ccv = solve_continuous_problem( - data_scs=data_scs, - compute_ccv=compute_ccv_policy_functions[period], - continuous_choice_grids=continuous_choice_grids[period], - vf_arr=vf_arr_list[period], - params=params, + # ------------------------------------------------------------------------------ + # We need to pass the value function array of the next period to the continuous + # choice problem solver. If we are at the last period, we pass an empty array. 
+ next_period_vf_arr = vf_arr_dict.get(period + 1, jnp.empty(0)) + + conditional_continuous_choice_argmax, conditional_continuous_choice_max = ( + solve_continuous_problem( + data_scs=state_choice_space, + compute_ccv=compute_ccv_policy_functions[period], + vf_arr=next_period_vf_arr, + params=params, + ) ) # Get optimal discrete choice given the optimal conditional continuous choices - # ============================================================================== - discrete_argmax, value = discrete_policy_calculator(ccv) + # ------------------------------------------------------------------------------ + discrete_argmax, choice_value = discrete_policy_calculator( + conditional_continuous_choice_max, params=params + ) - # Select optimal continuous choice corresponding to optimal discrete choice + # Get optimal continuous choice index given optimal discrete choice # ------------------------------------------------------------------------------ - # The conditional continuous choice argmax is computed for each discrete choice - # in the data-state-choice-space. Here we select the the optimal continuous - # choice corresponding to the optimal discrete choice. 
- # ============================================================================== - cont_choice_argmax = filter_ccv_policy( - ccv_policy=ccv_policy, + continuous_choice_argmax = get_continuous_choice_argmax_given_discrete( + conditional_continuous_choice_argmax=conditional_continuous_choice_argmax, discrete_argmax=discrete_argmax, - vars_grid_shape=choices_grid_shape, + discrete_choices_grid_shape=discrete_choices_grid_shape, ) - # Convert optimal choice indices to actual choice values - # ============================================================================== - choices = retrieve_choices( + # Convert choice indices to choice values + # ------------------------------------------------------------------------------ + discrete_choices = get_values_from_indices( flat_indices=discrete_argmax, - grids=data_scs.choices, - grids_shapes=choices_grid_shape, + grids=state_choice_space.discrete_choices, + grids_shapes=discrete_choices_grid_shape, ) - cont_choices = retrieve_choices( - flat_indices=cont_choice_argmax, - grids=continuous_choice_grids[period], - grids_shapes=cont_choices_grid_shape, + continuous_choices = get_values_from_indices( + flat_indices=continuous_choice_argmax, + grids=state_choice_space.continuous_choices, + grids_shapes=continuous_choices_grid_shape, ) # Store results - # ============================================================================== - choices = {**choices, **cont_choices} - - _simulation_results.append( - { - "value": value, - "choices": choices, - "states": states, - }, + # ------------------------------------------------------------------------------ + choices = {**discrete_choices, **continuous_choices} + + simulation_results[period] = InternalSimulationPeriodResults( + value=choice_value, + choices=choices, + states=states, ) # Update states - # ============================================================================== - key, sim_keys = _generate_simulation_keys( + # 
------------------------------------------------------------------------------ + key, stochastic_variables_keys = generate_simulation_keys( key=key, ids=model.function_info.query("is_stochastic_next").index.tolist(), ) - states = next_state( + states_with_prefix = next_state( **states, **choices, _period=jnp.repeat(period, n_initial_states), params=params, - keys=sim_keys, + keys=stochastic_variables_keys, ) - # 'next_' prefix is added by the next_state function, but needs to be removed - # because in the next period, next states are current states. - states = {k.removeprefix("next_"): v for k, v in states.items()} + # because in the next period, next states will be current states. + states = {k.removeprefix("next_"): v for k, v in states_with_prefix.items()} logger.info("Period: %s", period) - processed = _process_simulated_data(_simulation_results) - - if additional_targets is not None: - calculated_targets = _compute_targets( - processed, - targets=additional_targets, - model_functions=model.functions, - params=params, - ) - processed = {**processed, **calculated_targets} + processed = process_simulated_data( + simulation_results, + model=model, + params=params, + additional_targets=additional_targets, + ) - return _as_data_frame(processed, n_periods=n_periods) + return as_data_frame(processed, n_periods=n_periods) def solve_continuous_problem( data_scs: StateChoiceSpace, compute_ccv: Callable[..., tuple[Array, Array]], - continuous_choice_grids: dict[str, Array], vf_arr: Array, params: ParamsDict, ) -> tuple[Array, Array]: - """Solve the agent's continuous choices problem problem. + """Solve the agents' continuous choice problem. Args: data_scs: Class with entries choices and states. @@ -223,8 +225,6 @@ def solve_continuous_problem( - discrete and continuous choice variables - vf_arr - params - continuous_choice_grids: List of dicts with 1d grids for continuous - choice variables. vf_arr: Value function array. params: Dict of model parameters. 
@@ -236,195 +236,50 @@ def solve_continuous_problem( `gridmap` function. """ - _gridmapped = spacemap( + _gridmapped = simulation_spacemap( func=compute_ccv, - product_vars=tuple(data_scs.choices), - combination_vars=tuple(data_scs.states), + choices_var_names=tuple(data_scs.discrete_choices), + states_var_names=tuple(data_scs.states), ) gridmapped = jax.jit(_gridmapped) return gridmapped( - **data_scs.choices, **data_scs.states, - **continuous_choice_grids, + **data_scs.discrete_choices, + **data_scs.continuous_choices, vf_arr=vf_arr, params=params, ) -# ====================================================================================== -# Output processing -# ====================================================================================== - - -def _as_data_frame(processed: dict[str, Array], n_periods: int) -> pd.DataFrame: - """Convert processed simulation results to DataFrame. - - Args: - processed: Dict with processed simulation results. - n_periods: Number of periods. - - Returns: - DataFrame with the simulation results. The index is a multi-index with the first - level corresponding to the period and the second level corresponding to the - initial state id. The columns correspond to the value, and the choice and state - variables, and potentially auxiliary variables. - - """ - n_initial_states = len(processed["value"]) // n_periods - index = pd.MultiIndex.from_product( - [range(n_periods), range(n_initial_states)], - names=["period", "initial_state_id"], - ) - return pd.DataFrame(processed, index=index) - - -def _compute_targets( - processed_results: dict[str, Array], - targets: list[str], - model_functions: dict[str, InternalUserFunction], - params: ParamsDict, -) -> dict[str, Array]: - """Compute targets. - - Args: - processed_results: Dict with processed simulation results. Values must be - one-dimensional arrays. - targets: List of targets to compute. - model_functions: Dict with model functions. - params: Dict with model parameters. 
- - Returns: - Dict with computed targets. - - """ - target_func = concatenate_functions( - functions=model_functions, - targets=targets, - return_type="dict", - ) - - # get list of variables over which we want to vectorize the target function - variables = tuple( - p for p in list(inspect.signature(target_func).parameters) if p != "params" - ) - - target_func = vmap_1d(target_func, variables=variables) - - kwargs = {k: v for k, v in processed_results.items() if k in variables} - return target_func(params=params, **kwargs) - - -def _process_simulated_data(results: list[dict[str, Any]]) -> dict[str, Array]: - """Process and flatten the simulation results. - - This function produces a dict of arrays for each var with dimension (n_periods * - n_initial_states,). The arrays are flattened, so that the resulting dictionary has a - one-dimensional array for each variable. The length of this array is the number of - periods times the number of initial states. The order of array elements is given by - an outer level of periods and an inner level of initial states ids. - - - Args: - results (list): List of dicts with simulation results. Each dict contains the - value, choices, and states for one period. Choices and states are stored in - a nested dictionary. - - Returns: - Dict with processed simulation results. The keys are the variable names and the - values are the flattened arrays, with dimension (n_periods * n_initial_states,). - Additionally, the _period variable is added. 
- - """ - n_periods = len(results) - n_initial_states = len(results[0]["value"]) - - list_of_dicts = [ - {"value": d["value"], **d["choices"], **d["states"]} for d in results - ] - dict_of_lists = { - key: [d[key] for d in list_of_dicts] for key in list(list_of_dicts[0]) - } - out = {key: jnp.concatenate(values) for key, values in dict_of_lists.items()} - out["_period"] = jnp.repeat(jnp.arange(n_periods), n_initial_states) - - return out - - -# ====================================================================================== -# Simulation keys -# ====================================================================================== - - -def _generate_simulation_keys( - key: Array, ids: list[str] -) -> tuple[Array, dict[str, Array]]: - """Generate pseudo-random number generator keys (PRNG keys) for simulation. - - PRNG keys in JAX are immutable objects used to control random number generation. - A key can be used to generate a stream of random numbers, e.g., given a key, one can - call jax.random.normal(key) to generate a stream of normal random numbers. In order - to ensure that each simulation is based on a different stream of random numbers, we - split the key into one key per simulation unit, and one key that will be passed to - the next iteration in order to generate new keys. - - See the JAX documentation for more details: - https://docs.jax.dev/en/latest/random-numbers.html#random-numbers-in-jax - - Args: - key: PRNG key. - ids: List of names for which a key is to be generated. - - Returns: - - Updated PRNG key. - - Dict with PRNG keys for each id in ids. 
- - """ - keys = jax.random.split(key, num=len(ids) + 1) - - key = keys[0] - simulation_keys = dict(zip(ids, keys[1:], strict=True)) - - return key, simulation_keys - - -# ====================================================================================== -# Filter policy -# ====================================================================================== - - -@partial(vmap_1d, variables=("ccv_policy", "discrete_argmax")) -def filter_ccv_policy( - ccv_policy: Array, +@partial(vmap_1d, variables=("conditional_continuous_choice_argmax", "discrete_argmax")) +def get_continuous_choice_argmax_given_discrete( + conditional_continuous_choice_argmax: Array, discrete_argmax: Array, - vars_grid_shape: tuple[int, ...], + discrete_choices_grid_shape: tuple[int, ...], ) -> Array: """Select optimal continuous choice index given optimal discrete choice. Args: - ccv_policy: Index array of optimal continous choices + conditional_continuous_choice_argmax: Index array of optimal continous choices conditional on discrete choices. discrete_argmax: Index array of optimal discrete choices. - vars_grid_shape: Shape of the variables grid. + discrete_choices_grid_shape: Shape of the discrete choices grid. Returns: Index array of optimal continuous choices. """ - if discrete_argmax is None: - out = ccv_policy - else: - indices = jnp.unravel_index(discrete_argmax, shape=vars_grid_shape) - out = ccv_policy[indices] - return out + indices = jnp.unravel_index(discrete_argmax, shape=discrete_choices_grid_shape) + return conditional_continuous_choice_argmax[indices] -def retrieve_choices( +def get_values_from_indices( flat_indices: Array, grids: dict[str, Array], grids_shapes: tuple[int, ...], ) -> dict[str, Array]: - """Retrieve choices given flat indices. + """Retrieve values from indices. Args: flat_indices: General indices. Represents the index of the flattened grid. @@ -432,7 +287,7 @@ def retrieve_choices( grids_shapes: Shape of the grids. Is used to unravel the index. 
Returns: - Dictionary of choices. + Dictionary of values. """ nd_indices = vmapped_unravel_index(flat_indices, grids_shapes) @@ -445,110 +300,3 @@ def retrieve_choices( # vmap jnp.unravel_index over the first axis of the `indices` argument, while holding # the `shape` argument constant (in_axes = (0, None)). vmapped_unravel_index = vmap(jnp.unravel_index, in_axes=(0, None)) - - -# ====================================================================================== -# Data State Choice Space -# ====================================================================================== - - -def create_data_scs( - states: dict[str, Array], - model: InternalModel, -) -> StateChoiceSpace: - """Create data state choice space. - - Args: - states: Dict with initial states. - model: Model instance. - - Returns: - Data state choice space. - - """ - # preparations - # ================================================================================== - vi = model.variable_info - - # check that all states have an initial value - # ================================================================================== - state_names = set(vi.query("is_state").index) - - if state_names != set(states.keys()): - missing = state_names - set(states.keys()) - too_many = set(states.keys()) - state_names - raise ValueError( - "You need to provide an initial value for each state variable in the model." 
- f"\n\nMissing initial states: {missing}\n", - f"Provided variables that are not states: {too_many}", - ) - - # get choices - # ================================================================================== - choices = { - name: grid - for name, grid in model.grids.items() - if name in vi.query("is_choice & is_discrete").index.tolist() - } - - return StateChoiceSpace( - states=states, - choices=choices, - ordered_var_names=tuple(vi.query("is_state | is_discrete").index.tolist()), - ) - - -# ====================================================================================== -# Discrete policy -# ====================================================================================== - - -def get_discrete_policy_calculator( - variable_info: pd.DataFrame, -) -> Callable[..., tuple[Array, Array]]: - """Return a function that calculates the argmax and max of continuation values. - - The argmax is taken over the discrete choice variables in each state. - - Args: - variable_info (pd.DataFrame): DataFrame with information about the model - variables. - - Returns: - callable: Function that calculates the argmax of the conditional continuation - values. The function depends on: - - values (jax.Array): Multidimensional jax array with conditional - continuation values. - - """ - choice_axes = determine_discrete_choice_axes(variable_info) - - def _calculate_discrete_argmax( - values: Array, choice_axes: tuple[int, ...] - ) -> tuple[Array, Array]: - return argmax(values, axis=choice_axes) - - return partial(_calculate_discrete_argmax, choice_axes=choice_axes) - - -def determine_discrete_choice_axes(variable_info: pd.DataFrame) -> tuple[int, ...]: - """Determine which axes correspond to discrete choices. - - Args: - variable_info (pd.DataFrame): DataFrame with information about the variables. - - Returns: - tuple: Tuple of ints, specifying which axes in a value function correspond to - discrete choices. 
- - """ - discrete_choice_vars = variable_info.query( - "is_choice & is_discrete", - ).index.tolist() - - choice_vars = set(variable_info.query("is_choice").index.tolist()) - - # The first dimension corresponds to the simulated states, so add 1. - return tuple( - 1 + i for i, ax in enumerate(discrete_choice_vars) if ax in choice_vars - ) diff --git a/src/lcm/solution/solve_brute.py b/src/lcm/solution/solve_brute.py index 167beb9b..9b27c8af 100644 --- a/src/lcm/solution/solve_brute.py +++ b/src/lcm/solution/solve_brute.py @@ -4,19 +4,18 @@ import jax from jax import Array -from lcm.dispatchers import spacemap +from lcm.dispatchers import productmap from lcm.interfaces import StateChoiceSpace -from lcm.typing import DiscreteProblemSolverFunction, ParamsDict +from lcm.typing import DiscreteProblemValueSolverFunction, ParamsDict def solve( params: ParamsDict, - state_choice_spaces: list[StateChoiceSpace], - continuous_choice_grids: list[dict[str, Array]], - compute_ccv_functions: list[Callable[[Array, Array], Array]], - emax_calculators: list[DiscreteProblemSolverFunction], + state_choice_spaces: dict[int, StateChoiceSpace], + compute_ccv_functions: dict[int, Callable[[Array, Array], Array]], + emax_calculators: dict[int, DiscreteProblemValueSolverFunction], logger: logging.Logger, -) -> list[Array]: +) -> dict[int, Array]: """Solve a model by brute force. Notes: @@ -30,10 +29,8 @@ def solve( Args: params: Dict of model parameters. - state_choice_spaces: List with one state_choice_space per period. - continuous_choice_grids: List of dicts with 1d grids for continuous - choice variables. - compute_ccv_functions: List of functions needed to solve the agent's + state_choice_spaces: Dict with one state_choice_space per period. + compute_ccv_functions: Dict with one function needed to solve the agent's problem. 
Each function depends on: - discrete and continuous state variables - discrete and continuous choice variables @@ -45,12 +42,11 @@ def solve( logger: Logger that logs to stdout. Returns: - List with one value function array per period. + Dict with one value function array per period. """ - # extract information n_periods = len(state_choice_spaces) - reversed_solution = [] + solution = {} vf_arr = None logger.info("Starting solution") @@ -61,7 +57,6 @@ def solve( conditional_continuation_values = solve_continuous_problem( state_choice_space=state_choice_spaces[period], compute_ccv=compute_ccv_functions[period], - continuous_choice_grids=continuous_choice_grids[period], vf_arr=vf_arr, params=params, ) @@ -69,17 +64,16 @@ def solve( # solve discrete problem by calculating expected maximum over discrete choices calculate_emax = emax_calculators[period] vf_arr = calculate_emax(conditional_continuation_values, params=params) - reversed_solution.append(vf_arr) + solution[period] = vf_arr logger.info("Period: %s", period) - return list(reversed(reversed_solution)) + return solution def solve_continuous_problem( state_choice_space: StateChoiceSpace, compute_ccv: Callable[..., Array], - continuous_choice_grids: dict[str, Array], vf_arr: Array | None, params: ParamsDict, ) -> Array: @@ -94,8 +88,6 @@ def solve_continuous_problem( - discrete and continuous choice variables - vf_arr - params - continuous_choice_grids: List of dicts with 1d grids for continuous - choice variables. vf_arr: Value function array. params: Dict of model parameters. @@ -105,17 +97,16 @@ def solve_continuous_problem( by the `gridmap` function. 
""" - _gridmapped = spacemap( + _gridmapped = productmap( func=compute_ccv, - product_vars=state_choice_space.ordered_var_names, - combination_vars=(), + variables=state_choice_space.ordered_var_names, ) gridmapped = jax.jit(_gridmapped) return gridmapped( **state_choice_space.states, - **state_choice_space.choices, - **continuous_choice_grids, + **state_choice_space.discrete_choices, + **state_choice_space.continuous_choices, vf_arr=vf_arr, params=params, ) diff --git a/src/lcm/solution/state_choice_space.py b/src/lcm/solution/state_choice_space.py deleted file mode 100644 index 5bdb9caf..00000000 --- a/src/lcm/solution/state_choice_space.py +++ /dev/null @@ -1,55 +0,0 @@ -"""Create a state space for a given model.""" - -from lcm.interfaces import InternalModel, StateChoiceSpace, StateSpaceInfo - - -def create_state_choice_space( - model: InternalModel, - *, - is_last_period: bool, -) -> tuple[StateChoiceSpace, StateSpaceInfo]: - """Create a state-choice-space for the model solution. - - A state-choice-space is a compressed representation of all feasible states and the - feasible discrete choices within that state. - - Args: - model: A processed model. - is_last_period: Whether the function is created for the last period. - - Returns: - - An object containing the variable values of all variables in the - state-choice-space, the grid specifications for the state variables, and the - names of the state variables. Continuous choice variables are not included. - - The state-space information. 
- - """ - vi = model.variable_info - if is_last_period: - vi = vi.query("~is_auxiliary") - - discrete_states_names = vi.query("is_discrete & is_state").index.tolist() - continuous_states_names = vi.query("is_continuous & is_state").index.tolist() - - discrete_states = {sn: model.gridspecs[sn] for sn in discrete_states_names} - continuous_states = {sn: model.gridspecs[sn] for sn in continuous_states_names} - - state_grids = {sn: model.grids[sn] for sn in vi.query("is_state").index.tolist()} - choice_grids = { - sn: model.grids[sn] for sn in vi.query("is_choice & is_discrete").index.tolist() - } - ordered_var_names = tuple(vi.query("is_state | is_discrete").index.tolist()) - - state_space_info = StateSpaceInfo( - states_names=tuple(discrete_states_names + continuous_states_names), - discrete_states=discrete_states, # type: ignore[arg-type] - continuous_states=continuous_states, # type: ignore[arg-type] - ) - - state_choice_space = StateChoiceSpace( - states=state_grids, - choices=choice_grids, - ordered_var_names=ordered_var_names, - ) - - return state_choice_space, state_space_info diff --git a/src/lcm/state_choice_space.py b/src/lcm/state_choice_space.py new file mode 100644 index 00000000..db3b6208 --- /dev/null +++ b/src/lcm/state_choice_space.py @@ -0,0 +1,117 @@ +"""Create a state space for a given model.""" + +import pandas as pd +from jax import Array + +from lcm.grids import ContinuousGrid, DiscreteGrid +from lcm.interfaces import InternalModel, StateChoiceSpace, StateSpaceInfo + + +def create_state_choice_space( + model: InternalModel, + *, + initial_states: dict[str, Array] | None = None, + is_last_period: bool = False, +) -> StateChoiceSpace: + """Create a state-choice-space. + + Creates the state-choice-space for the solution and simulation of a model. In the + simulation, initial states must be provided. + + Args: + model: A processed model. + initial_states: A dictionary with the initial values of the state variables. 
+ If None, the full state grids of the model are used instead. + is_last_period: Whether the state-choice-space is created for the last period, + in which case auxiliary variables are not included. + + Returns: + A state-choice-space. Contains the grids of the discrete and continuous choices, + the grids of the state variables, or the initial values of the state variables, + and the names of the state and choice variables in the order they appear in the + variable info table. + + """ + vi = model.variable_info + if is_last_period: + vi = vi.query("~is_auxiliary") + + if initial_states is None: + states = {sn: model.grids[sn] for sn in vi.query("is_state").index} + else: + _validate_initial_states(initial_states, variable_info=vi) + states = initial_states + + discrete_choices = { + name: model.grids[name] for name in vi.query("is_choice & is_discrete").index + } + continuous_choices = { + name: model.grids[name] for name in vi.query("is_choice & is_continuous").index + } + ordered_var_names = tuple(vi.query("is_state | is_discrete").index) + + return StateChoiceSpace( + states=states, + discrete_choices=discrete_choices, + continuous_choices=continuous_choices, + ordered_var_names=ordered_var_names, + ) + + +def create_state_space_info( + model: InternalModel, + *, + is_last_period: bool, +) -> StateSpaceInfo: + """Create the state-space information for the model solution. + + The state-space information is a compressed representation of all feasible states. + + Args: + model: A processed model. + is_last_period: Whether the function is created for the last period. + + Returns: + The state-space information.
+ + """ + vi = model.variable_info + if is_last_period: + vi = vi.query("~is_auxiliary") + + state_names = vi.query("is_state").index.tolist() + + discrete_states = { + name: grid_spec + for name, grid_spec in model.gridspecs.items() + if name in state_names and isinstance(grid_spec, DiscreteGrid) + } + + continuous_states = { + name: grid_spec + for name, grid_spec in model.gridspecs.items() + if name in state_names and isinstance(grid_spec, ContinuousGrid) + } + + return StateSpaceInfo( + states_names=tuple(state_names), + discrete_states=discrete_states, + continuous_states=continuous_states, + ) + + +def _validate_initial_states( + initial_states: dict[str, Array], variable_info: pd.DataFrame +) -> None: + """Checks if each model-state has an initial value.""" + states_names_in_model = set(variable_info.query("is_state").index) + provided_states_names = set(initial_states) + + if states_names_in_model != provided_states_names: + missing = states_names_in_model - provided_states_names + too_many = provided_states_names - states_names_in_model + raise ValueError( + "You need to provide an initial array for each state variable in the model." + f"\n\nMissing initial states: {missing}\n" + f"Provided variables that are not states: {too_many}", + ) diff --git a/src/lcm/typing.py b/src/lcm/typing.py index abf3355a..a9894712 100644 --- a/src/lcm/typing.py +++ b/src/lcm/typing.py @@ -33,8 +33,8 @@ def __call__( # noqa: D102 ) -> Scalar: ... -class DiscreteProblemSolverFunction(Protocol): - """The function that solves the discrete problem. +class DiscreteProblemValueSolverFunction(Protocol): + """The function that solves for the value of the discrete problem. Only used for type checking. @@ -43,6 +43,26 @@ class DiscreteProblemSolverFunction(Protocol): def __call__(self, values: Array, params: ParamsDict) -> Array: ... # noqa: D102 +class DiscreteProblemPolicySolverFunction(Protocol): + """The function that solves for the policy of the discrete problem.
+ + Only used for type checking. + + """ + + def __call__(self, values: Array, params: ParamsDict) -> tuple[Array, Array]: ... # noqa: D102 + + +class StochasticNextFunction(Protocol): + """The function that simulates the next state of a stochastic variable. + + Only used for type checking. + + """ + + def __call__(self, keys: dict[str, Array], **kwargs: Array) -> Array: ... # noqa: D102 + + class ShockType(Enum): """Type of shocks.""" diff --git a/src/lcm/utility_and_feasibility.py b/src/lcm/utility_and_feasibility.py new file mode 100644 index 00000000..f86f1113 --- /dev/null +++ b/src/lcm/utility_and_feasibility.py @@ -0,0 +1,280 @@ +from collections.abc import Callable +from typing import Any + +import jax.numpy as jnp +from dags import concatenate_functions +from dags.signature import with_signature +from jax import Array + +from lcm.dispatchers import productmap +from lcm.function_representation import get_value_function_representation +from lcm.functools import get_union_of_arguments +from lcm.interfaces import InternalModel, StateSpaceInfo +from lcm.next_state import get_next_state_function, get_next_stochastic_weights_function +from lcm.typing import InternalUserFunction, ParamsDict, Scalar, Target + + +def get_utility_and_feasibility_function( + model: InternalModel, + next_state_space_info: StateSpaceInfo, + period: int, + *, + is_last_period: bool, +) -> Callable[..., tuple[Array, Array]]: + """Create the utility and feasibility function for a given period. + + Args: + model: The internal model object. + next_state_space_info: The state space information of the next period. + period: The period to create the utility and feasibility function for. + is_last_period: Whether the period is the last period. + + Returns: + A function that computes the expected forward-looking utility and feasibility + for the given period. 
+ + """ + if is_last_period: + return get_utility_and_feasibility_function_last_period(model, period=period) + return get_utility_and_feasibility_function_before_last_period( + model, next_state_space_info=next_state_space_info, period=period + ) + + +def get_utility_and_feasibility_function_before_last_period( + model: InternalModel, + next_state_space_info: StateSpaceInfo, + period: int, +) -> Callable[..., tuple[Array, Array]]: + """Create the utility and feasibility function for a period before the last period. + + Args: + model: The internal model object. + next_state_space_info: The state space information of the next period. + period: The period to create the utility and feasibility function for. + + Returns: + A function that computes the utility and feasibility for the given period. + + """ + stochastic_variables = model.variable_info.query("is_stochastic").index.tolist() + + # ---------------------------------------------------------------------------------- + # Generate dynamic functions + # ---------------------------------------------------------------------------------- + # TODO (@timmens): This can be done outside this function, since it # noqa: TD003 + # does not depend on the period. 
+ # ---------------------------------------------------------------------------------- + + # Functions required to calculate the expected continuation values + calculate_state_transition = get_next_state_function(model, target=Target.SOLVE) + calculate_next_weights = get_next_stochastic_weights_function(model) + calculate_node_weights = _get_node_weights_function(stochastic_variables) + _scalar_value_function = get_value_function_representation(next_state_space_info) + value_function = productmap( + _scalar_value_function, + variables=tuple(f"next_{var}" for var in stochastic_variables), + ) + + # Function required to calculate today's utility and feasibility + calculate_todays_u_and_f = _get_current_u_and_f(model) + + # ---------------------------------------------------------------------------------- + # Create the utility and feasibility function + # ---------------------------------------------------------------------------------- + + arg_names = _get_required_arg_names_of_u_and_f( + [ + calculate_todays_u_and_f, + calculate_state_transition, + calculate_next_weights, + ] + ) + + @with_signature(args=arg_names) + def utility_and_feasibility( + params: ParamsDict, vf_arr: Array, **states_and_choices: Scalar + ) -> tuple[Scalar, Scalar]: + """Calculate the expected forward-looking utility and feasibility. + + Args: + params: The parameters. + vf_arr: The value function array. + **states_and_choices: Today's states and choices. + + Returns: + A tuple containing the utility and feasibility for the given period.
+ + """ + # ------------------------------------------------------------------------------ + # Calculate the expected continuation values + # ------------------------------------------------------------------------------ + next_states = calculate_state_transition( + **states_and_choices, + _period=period, + params=params, + ) + + weights = calculate_next_weights( + **states_and_choices, + _period=period, + params=params, + ) + + node_weights = calculate_node_weights(**weights) + + # As we productmap'd the value function over the stochastic variables, the + # resulting continuation values get a new dimension for each stochastic + # variable. + continuation_values_at_nodes = value_function(**next_states, vf_arr=vf_arr) + + # We then weight these continuation values with the joint node weights and sum + # them up to get the expected continuation values. + expected_continuation_values = ( + continuation_values_at_nodes * node_weights + ).sum() + + # ------------------------------------------------------------------------------ + # Calculate the expected forward-looking utility. + # ------------------------------------------------------------------------------ + # This is not the value function yet, as it still depends on the choices. + # ------------------------------------------------------------------------------ + period_utility, period_feasibility = calculate_todays_u_and_f( + **states_and_choices, + _period=period, + params=params, + ) + + expected_forward_utility = ( + period_utility + params["beta"] * expected_continuation_values + ) + + return expected_forward_utility, period_feasibility + + return utility_and_feasibility + + +def get_utility_and_feasibility_function_last_period( + model: InternalModel, + period: int, +) -> Callable[..., tuple[Array, Array]]: + """Create the utility and feasibility function for the last period. + + Args: + model: The internal model object. + period: The period to create the utility and feasibility function for. 
This is + still relevant for the last period, as some functions might depend on the + actual period value. + + Returns: + A function that computes the utility and feasibility for the given period. + + """ + calculate_todays_u_and_f = _get_current_u_and_f(model) + + arg_names = _get_required_arg_names_of_u_and_f([calculate_todays_u_and_f]) + + @with_signature(args=arg_names) + def utility_and_feasibility( + params: ParamsDict, **kwargs: Scalar + ) -> tuple[Scalar, Scalar]: + return calculate_todays_u_and_f( + **kwargs, + _period=period, + params=params, + ) + + return utility_and_feasibility + + +# ====================================================================================== +# Helper functions +# ====================================================================================== + + +def _get_required_arg_names_of_u_and_f( + model_functions: list[Callable[..., Any]], +) -> list[str]: + """Get the argument names of the utility and feasibility function. + + Args: + model_functions: The list of functions that are used to calculate the utility + and feasibility. + + Returns: + The argument names of the utility and feasibility function. + + """ + dynamic_arg_names = get_union_of_arguments(model_functions) - {"_period"} + static_arg_names = {"params", "vf_arr"} + + return list(static_arg_names | dynamic_arg_names) + + +def _get_node_weights_function(stochastic_variables: list[str]) -> Callable[..., Array]: + """Get joint weights function. + + This function takes the weights of the individual stochastic variables and + multiplies them together to get the joint weights on the product space of the + stochastic variables. + + Args: + stochastic_variables: List of stochastic variables. + + Returns: + A function that multiplies the weights of the stochastic variables. 
+ + """ + arg_names = [f"weight_next_{var}" for var in stochastic_variables] + + @with_signature(args=arg_names) + def _outer(**kwargs: Array) -> Array: + weights = jnp.array(list(kwargs.values())) + return jnp.prod(weights) + + return productmap(_outer, variables=tuple(arg_names)) + + +def _get_current_u_and_f(model: InternalModel) -> Callable[..., tuple[Scalar, Scalar]]: + """Get the current utility and feasibility function. + + Args: + model: The internal model object. + + Returns: + The current utility and feasibility function. + + """ + functions = {"feasibility": _get_feasibility(model), **model.functions} + return concatenate_functions( + functions=functions, + targets=["utility", "feasibility"], + enforce_signature=False, + ) + + +def _get_feasibility(model: InternalModel) -> InternalUserFunction: + """Create a function that combines all constraint functions into a single one. + + Args: + model: The internal model object. + + Returns: + The combined constraint function (feasibility). + + """ + targets = model.function_info.query("is_constraint").index.tolist() + + if targets: + combined_constraint = concatenate_functions( + functions=model.functions, + targets=targets, + aggregator=jnp.logical_and, + ) + else: + + def combined_constraint(**kwargs: Scalar) -> bool: # noqa: ARG001 + """Dummy feasibility function that always returns True.""" + return True + + return combined_constraint diff --git a/src/lcm/utils.py b/src/lcm/utils.py index 86d2fd47..6b31fdc9 100644 --- a/src/lcm/utils.py +++ b/src/lcm/utils.py @@ -10,3 +10,22 @@ def find_duplicates(*containers: Iterable[T]) -> set[T]: combined = chain.from_iterable(containers) counts = Counter(combined) return {v for v, count in counts.items() if count > 1} + + +def first_non_none(*args: T | None) -> T: + """Return the first non-None argument. + + Args: + *args: Arguments to check. + + Returns: + The first non-None argument. + + Raises: + ValueError: If all arguments are None. 
+ + """ + for arg in args: + if arg is not None: + return arg + raise ValueError("All arguments are None") diff --git a/tests/simulation/test_processing.py b/tests/simulation/test_processing.py new file mode 100644 index 00000000..0bd14a4b --- /dev/null +++ b/tests/simulation/test_processing.py @@ -0,0 +1,95 @@ +import jax.numpy as jnp +import pandas as pd +from pybaum import tree_equal + +from lcm.interfaces import InternalSimulationPeriodResults +from lcm.simulation.processing import ( + _compute_targets, + as_data_frame, + process_simulated_data, +) + + +def test_compute_targets(): + processed_results = { + "a": jnp.arange(3), + "b": 1 + jnp.arange(3), + "c": 2 + jnp.arange(3), + } + + def f_a(a, params): + return a + params["disutility_of_work"] + + def f_b(b, params): # noqa: ARG001 + return b + + def f_c(params): # noqa: ARG001 + return None + + model_functions = {"fa": f_a, "fb": f_b, "fc": f_c} + + got = _compute_targets( + processed_results=processed_results, + targets=["fa", "fb"], + model_functions=model_functions, # type: ignore[arg-type] + params={"disutility_of_work": -1.0}, + ) + expected = { + "fa": jnp.arange(3) - 1.0, + "fb": 1 + jnp.arange(3), + } + assert tree_equal(expected, got) + + +def test_as_data_frame(): + processed = { + "value": -6 + jnp.arange(6), + "a": jnp.arange(6), + "b": 6 + jnp.arange(6), + } + got = as_data_frame(processed, n_periods=2) + expected = pd.DataFrame( + { + "period": [0, 0, 0, 1, 1, 1], + "initial_state_id": [0, 1, 2, 0, 1, 2], + **processed, + }, + ).set_index(["period", "initial_state_id"]) + pd.testing.assert_frame_equal(got, expected) + + +def test_process_simulated_data(): + simulated = { + 0: InternalSimulationPeriodResults( + value=jnp.array([0.1, 0.2]), + states={"a": jnp.array([1, 2]), "b": jnp.array([-1, -2])}, + choices={"c": jnp.array([5, 6]), "d": jnp.array([-5, -6])}, + ), + 1: InternalSimulationPeriodResults( + value=jnp.array([0.3, 0.4]), + states={ + "b": jnp.array([-3, -4]), + "a": jnp.array([3, 
4]), + }, + choices={ + "d": jnp.array([-7, -8]), + "c": jnp.array([7, 8]), + }, + ), + } + expected = { + "value": jnp.array([0.1, 0.2, 0.3, 0.4]), + "c": jnp.array([5, 6, 7, 8]), + "d": jnp.array([-5, -6, -7, -8]), + "a": jnp.array([1, 2, 3, 4]), + "b": jnp.array([-1, -2, -3, -4]), + } + + got = process_simulated_data( + simulated, + # Rest is none, since we are not computing any additional targets + model=None, # type: ignore[arg-type] + params=None, # type: ignore[arg-type] + additional_targets=None, + ) + assert tree_equal(expected, got) diff --git a/tests/simulation/test_simulate.py b/tests/simulation/test_simulate.py index 5fd918c7..aae30ca6 100644 --- a/tests/simulation/test_simulate.py +++ b/tests/simulation/test_simulate.py @@ -1,34 +1,33 @@ +from typing import TYPE_CHECKING + import jax.numpy as jnp -import pandas as pd import pytest from numpy.testing import assert_array_almost_equal, assert_array_equal -from pybaum import tree_equal -from lcm.entry_point import ( - create_compute_conditional_continuation_policy, - get_lcm_function, +from lcm.conditional_continuation import ( + get_compute_conditional_continuation_policy, ) +from lcm.entry_point import get_lcm_function from lcm.input_processing import process_model from lcm.logging import get_logger -from lcm.model_functions import get_utility_and_feasibility_function -from lcm.next_state import _get_next_state_function_for_simulation +from lcm.next_state import get_next_state_function from lcm.simulation.simulate import ( - _as_data_frame, - _compute_targets, - _generate_simulation_keys, - _process_simulated_data, - create_data_scs, - determine_discrete_choice_axes, - filter_ccv_policy, - retrieve_choices, + get_continuous_choice_argmax_given_discrete, + get_values_from_indices, simulate, ) -from lcm.solution.state_choice_space import create_state_choice_space +from lcm.state_choice_space import create_state_space_info +from lcm.typing import Target +from lcm.utility_and_feasibility import 
get_utility_and_feasibility_function from tests.test_models import ( get_model_config, get_params, ) +if TYPE_CHECKING: + import pandas as pd + + # ====================================================================================== # Test simulate using raw inputs # ====================================================================================== @@ -37,36 +36,34 @@ @pytest.fixture def simulate_inputs(): model_config = get_model_config("iskhakov_et_al_2017_stripped_down", n_periods=1) + choices = model_config.choices + choices["consumption"] = choices["consumption"].replace(stop=100) # type: ignore[attr-defined] + model_config = model_config.replace(choices=choices) model = process_model(model_config) - state_space_info = create_state_choice_space( + state_space_info = create_state_space_info( model=model, is_last_period=False, - )[1] + ) compute_ccv_policy_functions = [] for period in range(model.n_periods): u_and_f = get_utility_and_feasibility_function( model=model, - state_space_info=state_space_info, + next_state_space_info=state_space_info, period=period, is_last_period=True, ) - compute_ccv = create_compute_conditional_continuation_policy( + compute_ccv = get_compute_conditional_continuation_policy( utility_and_feasibility=u_and_f, - continuous_choice_variables=["consumption"], + continuous_choice_variables=("consumption",), ) compute_ccv_policy_functions.append(compute_ccv) - n_grid_points = model_config.choices["consumption"].n_points # type: ignore[attr-defined] - return { - "continuous_choice_grids": [ - {"consumption": jnp.linspace(1, 100, num=n_grid_points)}, - ], "compute_ccv_policy_functions": compute_ccv_policy_functions, "model": model, - "next_state": _get_next_state_function_for_simulation(model), + "next_state": get_next_state_function(model, target=Target.SIMULATE), } @@ -81,7 +78,7 @@ def test_simulate_using_raw_inputs(simulate_inputs): got = simulate( params=params, - pre_computed_vf_arr_list=[jnp.empty(0)], + vf_arr_dict={0: 
jnp.empty(0)}, initial_states={"wealth": jnp.array([1.0, 50.400803])}, logger=get_logger(debug_mode=False), **simulate_inputs, @@ -113,8 +110,8 @@ def _model_solution(n_periods): solve_model, _ = get_lcm_function(model_config, targets="solve") params = get_params() - vf_arr_list = solve_model(params=params) - return vf_arr_list, params, model_config + vf_arr_dict = solve_model(params=params) + return vf_arr_dict, params, model_config return _model_solution @@ -123,7 +120,7 @@ def test_simulate_using_get_lcm_function( iskhakov_et_al_2017_stripped_down_model_solution, ): n_periods = 3 - vf_arr_list, params, model = iskhakov_et_al_2017_stripped_down_model_solution( + vf_arr_dict, params, model = iskhakov_et_al_2017_stripped_down_model_solution( n_periods=n_periods, ) @@ -131,7 +128,7 @@ def test_simulate_using_get_lcm_function( res: pd.DataFrame = simulate_model( # type: ignore[assignment] params, - pre_computed_vf_arr_list=vf_arr_list, + vf_arr_dict=vf_arr_dict, initial_states={ "wealth": jnp.array([20.0, 150, 250, 320]), }, @@ -207,13 +204,13 @@ def test_effect_of_beta_on_last_period(): res_low: pd.DataFrame = simulate_model( # type: ignore[assignment] params_low, - pre_computed_vf_arr_list=solution_low, + vf_arr_dict=solution_low, initial_states={"wealth": initial_wealth}, ) res_high: pd.DataFrame = simulate_model( # type: ignore[assignment] params_high, - pre_computed_vf_arr_list=solution_high, + vf_arr_dict=solution_high, initial_states={"wealth": initial_wealth}, ) @@ -251,13 +248,13 @@ def test_effect_of_disutility_of_work(): res_low: pd.DataFrame = simulate_model( # type: ignore[assignment] params_low, - pre_computed_vf_arr_list=solution_low, + vf_arr_dict=solution_low, initial_states={"wealth": initial_wealth}, ) res_high: pd.DataFrame = simulate_model( # type: ignore[assignment] params_high, - pre_computed_vf_arr_list=solution_high, + vf_arr_dict=solution_high, initial_states={"wealth": initial_wealth}, ) @@ -281,96 +278,8 @@ def 
test_effect_of_disutility_of_work(): # ====================================================================================== -def test_generate_simulation_keys(): - key = jnp.arange(2, dtype="uint32") # PRNG dtype - stochastic_next_functions = ["a", "b"] - got = _generate_simulation_keys(key, stochastic_next_functions) - # assert that all generated keys are different from each other - matrix = jnp.array([key, got[0], got[1]["a"], got[1]["b"]]) - assert jnp.linalg.matrix_rank(matrix) == 2 - - -def test_as_data_frame(): - processed = { - "value": -6 + jnp.arange(6), - "a": jnp.arange(6), - "b": 6 + jnp.arange(6), - } - got = _as_data_frame(processed, n_periods=2) - expected = pd.DataFrame( - { - "period": [0, 0, 0, 1, 1, 1], - "initial_state_id": [0, 1, 2, 0, 1, 2], - **processed, - }, - ).set_index(["period", "initial_state_id"]) - pd.testing.assert_frame_equal(got, expected) - - -def test_compute_targets(): - processed_results = { - "a": jnp.arange(3), - "b": 1 + jnp.arange(3), - "c": 2 + jnp.arange(3), - } - - def f_a(a, params): - return a + params["disutility_of_work"] - - def f_b(b, params): # noqa: ARG001 - return b - - def f_c(params): # noqa: ARG001 - return None - - model_functions = {"fa": f_a, "fb": f_b, "fc": f_c} - - got = _compute_targets( - processed_results=processed_results, - targets=["fa", "fb"], - model_functions=model_functions, # type: ignore[arg-type] - params={"disutility_of_work": -1.0}, - ) - expected = { - "fa": jnp.arange(3) - 1.0, - "fb": 1 + jnp.arange(3), - } - assert tree_equal(expected, got) - - -def test_process_simulated_data(): - simulated = [ - { - "value": jnp.array([0.1, 0.2]), - "states": {"a": jnp.array([1, 2]), "b": jnp.array([-1, -2])}, - "choices": {"c": jnp.array([5, 6]), "d": jnp.array([-5, -6])}, - }, - { - "value": jnp.array([0.3, 0.4]), - "states": { - "b": jnp.array([-3, -4]), - "a": jnp.array([3, 4]), - }, - "choices": { - "d": jnp.array([-7, -8]), - "c": jnp.array([7, 8]), - }, - }, - ] - expected = { - "value": 
jnp.array([0.1, 0.2, 0.3, 0.4]), - "c": jnp.array([5, 6, 7, 8]), - "d": jnp.array([-5, -6, -7, -8]), - "a": jnp.array([1, 2, 3, 4]), - "b": jnp.array([-1, -2, -3, -4]), - } - - got = _process_simulated_data(simulated) - assert tree_equal(expected, got) - - def test_retrieve_choices(): - got = retrieve_choices( + got = get_values_from_indices( flat_indices=jnp.array([0, 3, 7]), grids={"a": jnp.linspace(0, 1, 5), "b": jnp.linspace(10, 20, 6)}, grids_shapes=(5, 6), @@ -388,37 +297,9 @@ def test_filter_ccv_policy(): ) argmax = jnp.array([0, 1]) vars_grid_shape = (2,) - got = filter_ccv_policy( - ccv_policy=ccc_policy, + got = get_continuous_choice_argmax_given_discrete( + conditional_continuous_choice_argmax=ccc_policy, discrete_argmax=argmax, - vars_grid_shape=vars_grid_shape, + discrete_choices_grid_shape=vars_grid_shape, ) assert jnp.all(got == jnp.array([0, 0])) - - -def test_create_data_state_choice_space(): - model_config = get_model_config("iskhakov_et_al_2017", n_periods=3) - model = process_model(model_config) - got_space = create_data_scs( - states={ - "wealth": jnp.array([10.0, 20.0]), - "lagged_retirement": jnp.array([0, 1]), - }, - model=model, - ) - assert_array_equal(got_space.choices["retirement"], jnp.array([0, 1])) - assert_array_equal(got_space.states["wealth"], jnp.array([10.0, 20.0])) - assert_array_equal(got_space.states["lagged_retirement"], jnp.array([0, 1])) - - -def test_determine_discrete_choice_axes(): - variable_info = pd.DataFrame( - { - "is_state": [True, True, False, True, False, False], - "is_choice": [False, False, True, True, True, True], - "is_discrete": [True, True, True, True, True, False], - "is_continuous": [False, True, False, False, False, True], - }, - ) - got = determine_discrete_choice_axes(variable_info) - assert got == (1, 2, 3) diff --git a/tests/solution/test_solve_brute.py b/tests/solution/test_solve_brute.py index 5b0450b6..e2d1f4e4 100644 --- a/tests/solution/test_solve_brute.py +++ 
b/tests/solution/test_solve_brute.py @@ -2,7 +2,7 @@ import numpy as np from numpy.testing import assert_array_almost_equal as aaae -from lcm.entry_point import create_compute_conditional_continuation_value +from lcm.conditional_continuation import get_compute_conditional_continuation_value from lcm.interfaces import StateChoiceSpace from lcm.logging import get_logger from lcm.ndimage import map_coordinates @@ -26,28 +26,22 @@ def test_solve_brute(): # create the list of state_choice_spaces # ================================================================================== _scs = StateChoiceSpace( - choices={ + discrete_choices={ # pick [0, 1] such that no label translation is needed # lazy is like a type, it influences utility but is not affected by choices "lazy": jnp.array([0, 1]), "working": jnp.array([0, 1]), }, + continuous_choices={ + "consumption": jnp.array([0, 1, 2, 3]), + }, states={ # pick [0, 1, 2] such that no coordinate mapping is needed "wealth": jnp.array([0.0, 1.0, 2.0]), }, ordered_var_names=("lazy", "working", "wealth"), ) - state_choice_spaces = [_scs] * 2 - - # ================================================================================== - # create continuous choice grids - # ================================================================================== - - # you 1 if working and have at most 2 in existing wealth, so - _ccg = {"consumption": jnp.array([0, 1, 2, 3])} - - continuous_choice_grids = [_ccg] * 2 + state_choice_spaces = {0: _scs, 1: _scs} # ================================================================================== # create the utility_and_feasibility functions @@ -79,12 +73,12 @@ def _get_continuation_value(lazy, wealth, vf_arr): coordinates=jnp.array([wealth]), ) - compute_ccv = create_compute_conditional_continuation_value( + compute_ccv = get_compute_conditional_continuation_value( utility_and_feasibility=_utility_and_feasibility, - continuous_choice_variables=["consumption"], + 
continuous_choice_variables=("consumption",), ) - utility_and_feasibility_functions = [compute_ccv] * 2 + compute_ccv_functions = {0: compute_ccv, 1: compute_ccv} # ================================================================================== # create emax aggregators and choice segments @@ -94,7 +88,7 @@ def calculate_emax(values, params): # noqa: ARG001 """Take max over axis that corresponds to working.""" return values.max(axis=1) - emax_calculators = [calculate_emax] * 2 + emax_calculators = {0: calculate_emax, 1: calculate_emax} # ================================================================================== # call solve function @@ -103,22 +97,24 @@ def calculate_emax(values, params): # noqa: ARG001 solution = solve( params=params, state_choice_spaces=state_choice_spaces, - continuous_choice_grids=continuous_choice_grids, - compute_ccv_functions=utility_and_feasibility_functions, + compute_ccv_functions=compute_ccv_functions, emax_calculators=emax_calculators, logger=get_logger(debug_mode=False), ) - assert isinstance(solution, list) + assert isinstance(solution, dict) def test_solve_continuous_problem_no_vf_arr(): state_choice_space = StateChoiceSpace( - choices={ + discrete_choices={ "a": jnp.array([0, 1.0]), "b": jnp.array([2, 3.0]), "c": jnp.array([4, 5, 6]), }, + continuous_choices={ + "d": jnp.arange(12.0), + }, states={}, ordered_var_names=("a", "b", "c"), ) @@ -128,11 +124,9 @@ def _utility_and_feasibility(a, c, b, d, vf_arr, params): # noqa: ARG001 feasib = d <= a + b + c return util, feasib - continuous_choice_grids = {"d": jnp.arange(12.0)} - - compute_ccv = create_compute_conditional_continuation_value( + compute_ccv = get_compute_conditional_continuation_value( utility_and_feasibility=_utility_and_feasibility, - continuous_choice_variables=["d"], + continuous_choice_variables=("d",), ) expected = np.array([[[6.0, 7, 8], [7, 8, 9]], [[7, 8, 9], [8, 9, 10]]]) @@ -140,7 +134,6 @@ def _utility_and_feasibility(a, c, b, d, vf_arr, params): # 
noqa: ARG001 got = solve_continuous_problem( state_choice_space, compute_ccv, - continuous_choice_grids, vf_arr=None, params={}, ) diff --git a/tests/solution/test_state_space.py b/tests/solution/test_state_space.py deleted file mode 100644 index f4fe1341..00000000 --- a/tests/solution/test_state_space.py +++ /dev/null @@ -1,32 +0,0 @@ -import jax.numpy as jnp - -from lcm.input_processing import process_model -from lcm.interfaces import StateChoiceSpace, StateSpaceInfo -from lcm.solution.state_choice_space import ( - create_state_choice_space, -) -from tests.test_models import get_model_config - - -def test_create_state_choice_space(): - model = get_model_config("iskhakov_et_al_2017_stripped_down", n_periods=3) - internal_model = process_model(model) - - state_choice_space, state_space_info = create_state_choice_space( - model=internal_model, - is_last_period=False, - ) - - assert isinstance(state_choice_space, StateChoiceSpace) - assert isinstance(state_space_info, StateSpaceInfo) - - assert jnp.array_equal( - state_choice_space.choices["retirement"], model.choices["retirement"].to_jax() - ) - assert jnp.array_equal( - state_choice_space.states["wealth"], model.states["wealth"].to_jax() - ) - - assert state_space_info.states_names == ("wealth",) - assert state_space_info.discrete_states == {} - assert state_space_info.continuous_states == model.states diff --git a/tests/test_analytical_solution.py b/tests/test_analytical_solution.py index c8a0fa1c..7c20c469 100644 --- a/tests/test_analytical_solution.py +++ b/tests/test_analytical_solution.py @@ -68,7 +68,9 @@ def test_analytical_solution(model_name, model_and_params): # ================================================================================== solve_model, _ = get_lcm_function(model=model_and_params["model"], targets="solve") - vf_arr_list: list[Array] = solve_model(params=model_and_params["params"]) # type: ignore[assignment] + vf_arr_dict: dict[int, Array] = 
solve_model(params=model_and_params["params"]) # type: ignore[assignment] + vf_arr_list = list(dict(sorted(vf_arr_dict.items(), key=lambda x: x[0])).values()) + _numerical = np.stack(vf_arr_list) numerical = { "worker": _numerical[:, 0, :], diff --git a/tests/test_discrete_problem.py b/tests/test_discrete_problem.py index 30e11a44..141d3c65 100644 --- a/tests/test_discrete_problem.py +++ b/tests/test_discrete_problem.py @@ -5,9 +5,10 @@ from lcm.discrete_problem import ( _calculate_emax_extreme_value_shocks, - _determine_discrete_choice_axes, + _determine_discrete_choice_axes_simulation, + _determine_discrete_choice_axes_solution, _solve_discrete_problem_no_shocks, - get_solve_discrete_problem, + get_solve_discrete_problem_value, ) from lcm.typing import ShockType @@ -27,7 +28,7 @@ def test_get_solve_discrete_problem_illustrative(): }, ) # leads to choice_axes = [1] - solve_discrete_problem = get_solve_discrete_problem( + solve_discrete_problem = get_solve_discrete_problem_value( random_utility_shock_type=ShockType.NONE, variable_info=variable_info, is_last_period=False, @@ -130,7 +131,7 @@ def test_determine_discrete_choice_axes_illustrative_one_var(): }, ) - assert _determine_discrete_choice_axes(variable_info) == (1,) + assert _determine_discrete_choice_axes_solution(variable_info) == (1,) @pytest.mark.illustrative @@ -144,4 +145,17 @@ def test_determine_discrete_choice_axes_illustrative_three_var(): }, ) - assert _determine_discrete_choice_axes(variable_info) == (1, 2, 3) + assert _determine_discrete_choice_axes_solution(variable_info) == (1, 2, 3) + + +def test_determine_discrete_choice_axes(): + variable_info = pd.DataFrame( + { + "is_state": [True, True, False, True, False, False], + "is_choice": [False, False, True, True, True, True], + "is_discrete": [True, True, True, True, True, False], + "is_continuous": [False, True, False, False, False, True], + }, + ) + got = _determine_discrete_choice_axes_simulation(variable_info) + assert got == (1, 2, 3) diff 
--git a/tests/test_dispatchers.py b/tests/test_dispatchers.py index d9c467f5..1534e740 100644 --- a/tests/test_dispatchers.py +++ b/tests/test_dispatchers.py @@ -6,7 +6,7 @@ from lcm.dispatchers import ( productmap, - spacemap, + simulation_spacemap, vmap_1d, ) from lcm.functools import allow_args @@ -231,7 +231,7 @@ def test_spacemap_all_arguments_mapped( ): product_vars, combination_vars = setup_spacemap - decorated = spacemap( + decorated = simulation_spacemap( g, tuple(product_vars), tuple(combination_vars), @@ -245,12 +245,12 @@ def test_spacemap_all_arguments_mapped( ("error_msg", "product_vars", "combination_vars"), [ ( - "Same argument provided more than once in product variables or combination", + "Same argument provided more than once in choices or states variables", ["a", "b"], ["a", "c", "d"], ), ( - "Same argument provided more than once in product variables or combination", + "Same argument provided more than once in choices or states variables", ["a", "a", "b"], ["c", "d"], ), @@ -258,7 +258,9 @@ def test_spacemap_all_arguments_mapped( ) def test_spacemap_arguments_overlap(error_msg, product_vars, combination_vars): with pytest.raises(ValueError, match=error_msg): - spacemap(g, product_vars=product_vars, combination_vars=combination_vars) + simulation_spacemap( + g, choices_var_names=product_vars, states_var_names=combination_vars + ) # ====================================================================================== diff --git a/tests/test_entry_point.py b/tests/test_entry_point.py index f6c4d07c..3a4e1d03 100644 --- a/tests/test_entry_point.py +++ b/tests/test_entry_point.py @@ -2,14 +2,14 @@ import pytest from pybaum import tree_equal, tree_map -from lcm.entry_point import ( - create_compute_conditional_continuation_policy, - create_compute_conditional_continuation_value, - get_lcm_function, +from lcm.conditional_continuation import ( + get_compute_conditional_continuation_policy, + get_compute_conditional_continuation_value, ) +from 
lcm.entry_point import get_lcm_function from lcm.input_processing import process_model -from lcm.model_functions import get_utility_and_feasibility_function -from lcm.solution.state_choice_space import create_state_choice_space +from lcm.state_choice_space import create_state_space_info +from lcm.utility_and_feasibility import get_utility_and_feasibility_function from tests.test_models import get_model_config from tests.test_models.deterministic import RetirementStatus from tests.test_models.deterministic import utility as iskhakov_et_al_2017_utility @@ -103,14 +103,14 @@ def test_get_lcm_function_with_simulation_is_coherent(model): # solve solve_model, params_template = get_lcm_function(model=model, targets="solve") params = tree_map(lambda _: 0.2, params_template) - vf_arr_list = solve_model(params) + vf_arr_dict = solve_model(params) # simulate using solution simulate_model, _ = get_lcm_function(model=model, targets="simulate") solve_then_simulate = simulate_model( params, - pre_computed_vf_arr_list=vf_arr_list, + vf_arr_dict=vf_arr_dict, initial_states={ "wealth": jnp.array([0.0, 10.0, 50.0]), }, @@ -142,14 +142,14 @@ def test_get_lcm_function_with_simulation_target_iskhakov_et_al_2017(model): # solve model solve_model, params_template = get_lcm_function(model=model, targets="solve") params = tree_map(lambda _: 0.2, params_template) - vf_arr_list = solve_model(params) + vf_arr_dict = solve_model(params) # simulate using solution simulate_model, _ = get_lcm_function(model=model, targets="simulate") simulate_model( params, - pre_computed_vf_arr_list=vf_arr_list, + vf_arr_dict=vf_arr_dict, initial_states={ "wealth": jnp.array([10.0, 10.0, 20.0]), "lagged_retirement": jnp.array( @@ -182,21 +182,21 @@ def test_create_compute_conditional_continuation_value(): }, } - state_space_info = create_state_choice_space( + state_space_info = create_state_space_info( model=model, is_last_period=False, - )[1] + ) u_and_f = get_utility_and_feasibility_function( model=model, - 
state_space_info=state_space_info, + next_state_space_info=state_space_info, period=model.n_periods - 1, is_last_period=True, ) - compute_ccv = create_compute_conditional_continuation_value( + compute_ccv = get_compute_conditional_continuation_value( utility_and_feasibility=u_and_f, - continuous_choice_variables=["consumption"], + continuous_choice_variables=("consumption",), ) val = compute_ccv( @@ -227,21 +227,21 @@ def test_create_compute_conditional_continuation_value_with_discrete_model(): }, } - state_space_info = create_state_choice_space( + state_space_info = create_state_space_info( model=model, is_last_period=False, - )[1] + ) u_and_f = get_utility_and_feasibility_function( model=model, - state_space_info=state_space_info, + next_state_space_info=state_space_info, period=model.n_periods - 1, is_last_period=True, ) - compute_ccv = create_compute_conditional_continuation_value( + compute_ccv = get_compute_conditional_continuation_value( utility_and_feasibility=u_and_f, - continuous_choice_variables=[], + continuous_choice_variables=(), ) val = compute_ccv( @@ -277,21 +277,21 @@ def test_create_compute_conditional_continuation_policy(): }, } - state_space_info = create_state_choice_space( + state_space_info = create_state_space_info( model=model, is_last_period=False, - )[1] + ) u_and_f = get_utility_and_feasibility_function( model=model, - state_space_info=state_space_info, + next_state_space_info=state_space_info, period=model.n_periods - 1, is_last_period=True, ) - compute_ccv_policy = create_compute_conditional_continuation_policy( + compute_ccv_policy = get_compute_conditional_continuation_policy( utility_and_feasibility=u_and_f, - continuous_choice_variables=["consumption"], + continuous_choice_variables=("consumption",), ) policy, val = compute_ccv_policy( @@ -323,21 +323,21 @@ def test_create_compute_conditional_continuation_policy_with_discrete_model(): }, } - state_space_info = create_state_choice_space( + state_space_info = 
create_state_space_info( model=model, is_last_period=False, - )[1] + ) u_and_f = get_utility_and_feasibility_function( model=model, - state_space_info=state_space_info, + next_state_space_info=state_space_info, period=model.n_periods - 1, is_last_period=True, ) - compute_ccv_policy = create_compute_conditional_continuation_policy( + compute_ccv_policy = get_compute_conditional_continuation_policy( utility_and_feasibility=u_and_f, - continuous_choice_variables=[], + continuous_choice_variables=(), ) policy, val = compute_ccv_policy( diff --git a/tests/test_models/get_model.py b/tests/test_models/get_model.py index 2cf09371..66649f1c 100644 --- a/tests/test_models/get_model.py +++ b/tests/test_models/get_model.py @@ -89,9 +89,20 @@ def get_params( ], # Transition from period 1 to period 2 [ - # Description is the same as above - [[0, 1.0], [1.0, 0]], - [[0, 1.0], [0.0, 1.0]], + # Current working decision 0 + [ + # Current partner state 0 + [0, 1.0], + # Current partner state 1 + [1.0, 0], + ], + # Current working decision 1 + [ + # Current partner state 0 + [0, 1.0], + # Current partner state 1 + [0.0, 1.0], + ], ], ], ) diff --git a/tests/test_next_state.py b/tests/test_next_state.py index 40230efa..e37b6d12 100644 --- a/tests/test_next_state.py +++ b/tests/test_next_state.py @@ -5,14 +5,10 @@ from lcm.input_processing import process_model from lcm.interfaces import InternalModel -from lcm.next_state import _get_stochastic_next_func, get_next_state_function +from lcm.next_state import _create_stochastic_next_func, get_next_state_function from lcm.typing import ParamsDict, Scalar, ShockType, Target from tests.test_models import get_model_config -# ====================================================================================== -# Solve target -# ====================================================================================== - def test_get_next_state_function_with_solve_target(): model = process_model( @@ -35,11 +31,6 @@ def 
test_get_next_state_function_with_solve_target(): assert got == {"next_wealth": 1.05 * (20 - 10)} -# ====================================================================================== -# Simulate target -# ====================================================================================== - - def test_get_next_state_function_with_simulate_target(): def f_a(state: Array, params: ParamsDict) -> Scalar: # noqa: ARG001 return state[0] @@ -86,12 +77,12 @@ def f_weight_b(state: Scalar, params: ParamsDict) -> Array: # noqa: ARG001 assert tree_equal(expected, got) -def test_get_stochastic_next_func(): - grids = {"a": jnp.arange(2)} - got_func = _get_stochastic_next_func(name="a", grids=grids) +def test_create_stochastic_next_func(): + labels = jnp.arange(2) + got_func = _create_stochastic_next_func(name="a", labels=labels) keys = {"a": jnp.arange(2, dtype="uint32")} # PRNG dtype weights = jnp.array([[0.0, 1], [1, 0]]) - got = got_func(keys=keys, weight_a=weights) # type: ignore[call-arg] + got = got_func(keys=keys, weight_a=weights) assert jnp.array_equal(got, jnp.array([1, 0])) diff --git a/tests/test_random.py b/tests/test_random.py new file mode 100644 index 00000000..f6cc7141 --- /dev/null +++ b/tests/test_random.py @@ -0,0 +1,21 @@ +import jax +import jax.numpy as jnp + +from lcm.random import generate_simulation_keys, random_choice + + +def test_random_choice(): + key = jax.random.key(0) + probs = jnp.array([[0.0, 0, 1], [1, 0, 0], [0, 1, 0]]) + labels = jnp.array([1, 2, 3]) + got = random_choice(labels=labels, probs=probs, key=key) + assert jnp.array_equal(got, jnp.array([3, 1, 2])) + + +def test_generate_simulation_keys(): + key = jnp.arange(2, dtype="uint32") # PRNG dtype + stochastic_next_functions = ["a", "b"] + got = generate_simulation_keys(key, stochastic_next_functions) + # assert that all generated keys are different from each other + matrix = jnp.array([key, got[0], got[1]["a"], got[1]["b"]]) + assert jnp.linalg.matrix_rank(matrix) == 2 diff --git 
a/tests/test_random_choice.py b/tests/test_random_choice.py deleted file mode 100644 index 3c1ab25d..00000000 --- a/tests/test_random_choice.py +++ /dev/null @@ -1,12 +0,0 @@ -import jax -import jax.numpy as jnp - -from lcm.random_choice import random_choice - - -def test_random_choice(): - key = jax.random.key(0) - probs = jnp.array([[0.0, 0, 1], [1, 0, 0], [0, 1, 0]]) - labels = jnp.array([1, 2, 3]) - got = random_choice(key, probs=probs, labels=labels) - assert jnp.array_equal(got, jnp.array([3, 1, 2])) diff --git a/tests/test_regression_test.py b/tests/test_regression_test.py index ce77c7aa..295fff9b 100644 --- a/tests/test_regression_test.py +++ b/tests/test_regression_test.py @@ -1,5 +1,6 @@ import jax.numpy as jnp import pandas as pd +from jax import Array from numpy.testing import assert_array_almost_equal as aaae from pandas.testing import assert_frame_equal @@ -31,7 +32,7 @@ def test_regression_test(): disutility_of_work=1.0, interest_rate=0.05, ) - got_solve = solve(params) + got_solve: dict[int, Array] = solve(params) # type: ignore[assignment] solve_and_simulate, _ = get_lcm_function( model=model_config, @@ -47,5 +48,5 @@ def test_regression_test(): # Compare # ================================================================================== - aaae(expected_solve, got_solve, decimal=5) + aaae(expected_solve, list(got_solve.values()), decimal=5) assert_frame_equal(expected_simulate, got_simulate) # type: ignore[arg-type] diff --git a/tests/test_state_choice_space.py b/tests/test_state_choice_space.py new file mode 100644 index 00000000..11de9801 --- /dev/null +++ b/tests/test_state_choice_space.py @@ -0,0 +1,75 @@ +import jax.numpy as jnp +from numpy.testing import assert_array_equal + +from lcm.input_processing import process_model +from lcm.interfaces import StateChoiceSpace, StateSpaceInfo +from lcm.state_choice_space import ( + create_state_choice_space, + create_state_space_info, +) +from tests.test_models import get_model_config + + +def 
test_create_state_choice_space_solution(): + model = get_model_config("iskhakov_et_al_2017_stripped_down", n_periods=3) + internal_model = process_model(model) + + state_choice_space = create_state_choice_space( + model=internal_model, + is_last_period=False, + ) + + assert isinstance(state_choice_space, StateChoiceSpace) + assert jnp.array_equal( + state_choice_space.discrete_choices["retirement"], + model.choices["retirement"].to_jax(), + ) + assert jnp.array_equal( + state_choice_space.states["wealth"], model.states["wealth"].to_jax() + ) + + +def test_create_state_choice_space_simulation(): + model_config = get_model_config("iskhakov_et_al_2017", n_periods=3) + model = process_model(model_config) + got_space = create_state_choice_space( + model=model, + initial_states={ + "wealth": jnp.array([10.0, 20.0]), + "lagged_retirement": jnp.array([0, 1]), + }, + ) + assert_array_equal(got_space.discrete_choices["retirement"], jnp.array([0, 1])) + assert_array_equal(got_space.states["wealth"], jnp.array([10.0, 20.0])) + assert_array_equal(got_space.states["lagged_retirement"], jnp.array([0, 1])) + + +def test_create_state_space_info(): + model = get_model_config("iskhakov_et_al_2017_stripped_down", n_periods=3) + internal_model = process_model(model) + + state_space_info = create_state_space_info( + model=internal_model, + is_last_period=False, + ) + + assert isinstance(state_space_info, StateSpaceInfo) + assert state_space_info.states_names == ("wealth",) + assert state_space_info.discrete_states == {} + assert state_space_info.continuous_states == model.states + + +def test_create_state_choice_space_replace(): + model_config = get_model_config("iskhakov_et_al_2017", n_periods=3) + model = process_model(model_config) + space = create_state_choice_space( + model=model, + initial_states={ + "wealth": jnp.array([10.0, 20.0]), + "lagged_retirement": jnp.array([0, 1]), + }, + ) + new_space = space.replace( + states={"wealth": jnp.array([10.0, 30.0])}, + ) + 
assert_array_equal(new_space.states["wealth"], jnp.array([10.0, 30.0])) diff --git a/tests/test_stochastic.py b/tests/test_stochastic.py index 2ae3b24d..581188f9 100644 --- a/tests/test_stochastic.py +++ b/tests/test_stochastic.py @@ -1,6 +1,7 @@ import jax.numpy as jnp import pandas as pd import pytest +from jax import Array import lcm from lcm.entry_point import ( @@ -19,7 +20,7 @@ def test_get_lcm_function_with_simulate_target(): targets="solve_and_simulate", ) - res = simulate_model( + res: pd.DataFrame = simulate_model( # type: ignore[assignment] params=get_params(), initial_states={ "health": jnp.array([1, 1, 0, 0]), @@ -28,21 +29,17 @@ def test_get_lcm_function_with_simulate_target(): }, ) - expected_partner = [ - 0, - 0, - 1, - 0, # period 0 - 1, - 1, - 1, - 1, # period 1 - 1, - 1, - 1, - 0, # period 2 - ] - assert jnp.array_equal(res["partner"].values, expected_partner) # type: ignore[call-overload, arg-type] + # This is derived from the partner transition in get_params. + expected_next_partner = ( + (res.working.astype(bool) | ~res.partner.astype(bool)).astype(int).loc[:1] + ) + + pd.testing.assert_series_equal( + res["partner"].loc[1:], + expected_next_partner, + check_index=False, + check_names=False, + ) # ====================================================================================== @@ -98,7 +95,7 @@ def next_health_deterministic(health): return model_deterministic, model_stochastic, params -def test_compare_deterministic_and_stochastic_results(model_and_params): +def test_compare_deterministic_and_stochastic_results_value_function(model_and_params): """Test that the deterministic and stochastic models produce the same results.""" model_deterministic, model_stochastic, params = model_and_params @@ -114,10 +111,14 @@ def test_compare_deterministic_and_stochastic_results(model_and_params): targets="solve", ) - solution_deterministic = solve_model_deterministic(params) - solution_stochastic = solve_model_stochastic(params) + 
solution_deterministic: dict[int, Array] = solve_model_deterministic(params) # type: ignore[assignment] + solution_stochastic: dict[int, Array] = solve_model_stochastic(params) # type: ignore[assignment] - assert jnp.array_equal(solution_deterministic, solution_stochastic, equal_nan=True) # type: ignore[arg-type] + assert jnp.array_equal( + jnp.array(list(solution_deterministic.values())), + jnp.array(list(solution_stochastic.values())), + equal_nan=True, + ) # ================================================================================== # Compare simulation results @@ -139,12 +140,12 @@ def test_compare_deterministic_and_stochastic_results(model_and_params): simulation_deterministic = simulate_model_deterministic( params, - pre_computed_vf_arr_list=solution_deterministic, + vf_arr_dict=solution_deterministic, initial_states=initial_states, ) simulation_stochastic = simulate_model_stochastic( params, - pre_computed_vf_arr_list=solution_stochastic, + vf_arr_dict=solution_stochastic, initial_states=initial_states, ) pd.testing.assert_frame_equal(simulation_deterministic, simulation_stochastic) # type: ignore[arg-type] diff --git a/tests/test_model_functions.py b/tests/test_utility_and_feasibility.py similarity index 92% rename from tests/test_model_functions.py rename to tests/test_utility_and_feasibility.py index 25e1327b..cf71bdec 100644 --- a/tests/test_model_functions.py +++ b/tests/test_utility_and_feasibility.py @@ -6,13 +6,13 @@ from lcm.input_processing import process_model from lcm.interfaces import InternalModel -from lcm.model_functions import ( - get_combined_constraint, - get_multiply_weights, +from lcm.state_choice_space import create_state_space_info +from lcm.typing import ShockType +from lcm.utility_and_feasibility import ( + _get_feasibility, + _get_node_weights_function, get_utility_and_feasibility_function, ) -from lcm.solution.state_choice_space import create_state_choice_space -from lcm.typing import ShockType from tests.test_models import 
get_model_config from tests.test_models.deterministic import utility @@ -32,14 +32,14 @@ def test_get_utility_and_feasibility_function(): }, } - state_space_info = create_state_choice_space( + state_space_info = create_state_space_info( model=model, is_last_period=False, - )[1] + ) u_and_f = get_utility_and_feasibility_function( model=model, - state_space_info=state_space_info, + next_state_space_info=state_space_info, period=model.n_periods - 1, is_last_period=True, ) @@ -119,7 +119,7 @@ def absorbing_retirement_constraint(retirement, lagged_retirement, params): # n @pytest.mark.illustrative def test_get_combined_constraint_illustrative(internal_model_illustrative): - combined_constraint = get_combined_constraint(internal_model_illustrative) + combined_constraint = _get_feasibility(internal_model_illustrative) age, retirement, lagged_retirement = jnp.array( [ @@ -148,7 +148,7 @@ def test_get_combined_constraint_illustrative(internal_model_illustrative): def test_get_multiply_weights(): - multiply_weights = get_multiply_weights( + multiply_weights = _get_node_weights_function( stochastic_variables=["a", "b"], ) @@ -184,6 +184,6 @@ def h(params): # noqa: ARG001 random_utility_shocks=ShockType.NONE, n_periods=0, ) - combined_constraint = get_combined_constraint(model) + combined_constraint = _get_feasibility(model) feasibility: Array = combined_constraint(params={}) # type: ignore[assignment] assert feasibility.item() is False