From 19aeee17e4feeb5848c510ad42e7e50c67c703ff Mon Sep 17 00:00:00 2001 From: Tim Mensinger Date: Wed, 26 Feb 2025 12:46:42 +0100 Subject: [PATCH 01/20] Make sure that stochastic next funcs are next funcs --- src/lcm/input_processing/util.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/lcm/input_processing/util.py b/src/lcm/input_processing/util.py index 72734d1..f1fdfc7 100644 --- a/src/lcm/input_processing/util.py +++ b/src/lcm/input_processing/util.py @@ -25,7 +25,8 @@ def get_function_info(model: Model) -> pd.DataFrame: info["is_constraint"] = info.index.str.endswith(("_constraint", "_filter")) info["is_next"] = info.index.str.startswith("next_") & ~info["is_constraint"] info["is_stochastic_next"] = [ - hasattr(func, "_stochastic_info") for func in model.functions.values() + hasattr(func, "_stochastic_info") and info.loc[func_name]["is_next"] + for func_name, func in model.functions.items() ] return info From 58d5ae74e57c60fb485a97c4962f1b68fdd99d94 Mon Sep 17 00:00:00 2001 From: Tim Mensinger Date: Wed, 26 Feb 2025 15:02:30 +0100 Subject: [PATCH 02/20] Refactor next_state.py --- src/lcm/next_state.py | 156 ++++++++++-------------------- src/lcm/random_choice.py | 8 +- src/lcm/typing.py | 10 ++ tests/simulation/test_simulate.py | 5 +- tests/test_next_state.py | 19 +--- tests/test_random_choice.py | 2 +- 6 files changed, 75 insertions(+), 125 deletions(-) diff --git a/src/lcm/next_state.py b/src/lcm/next_state.py index 88dbed7..c9d9e5c 100644 --- a/src/lcm/next_state.py +++ b/src/lcm/next_state.py @@ -1,11 +1,4 @@ -"""Generate functions that compute the next states of the model. - -For the solution, we simply concatenate the functions that compute the next states. For -the simulation, we generate functions that simulate the next states of stochastic -variables. We then concatenate these functions with the functions that compute the -deteministic next states. - -""" +"""Generate function that compute the next states for solution and simulation.""" from collections.abc import Callable @@ -13,156 +6,111 @@ from dags.signature import with_signature from jax import Array -from lcm.functools import all_as_args from lcm.interfaces import InternalModel from lcm.random_choice import random_choice -from lcm.typing import Scalar, Target +from lcm.typing import Scalar, StochasticNextFunction, Target def get_next_state_function( - model: InternalModel, target: Target -) -> Callable[..., dict[str, Scalar]]: - """Get function that computes the next states of the model. - - Args: - model: Internal model. - target: Target of the function. - - Returns: - Function that computes the next states of the model. - - """ - if target == Target.SOLVE: - return _get_next_state_function_for_solution(model) - - if target == Target.SIMULATE: - return _get_next_state_function_for_simulation(model) - - raise ValueError(f"Invalid target: {target}") - - -# ====================================================================================== -# Solution -# ====================================================================================== - - -def _get_next_state_function_for_solution( model: InternalModel, + target: Target, ) -> Callable[..., dict[str, Scalar]]: - """Get function that computes the next states for the solution. + """Get function that computes the next states during the solution. Args: - model: Model instance. + model: Internal model instance. + target: Whether to generate the function for the solve or simulate target. Returns: Function that computes the next states. Depends on states and choices of the - current period, and the model parameters. + current period, and the model parameters ("params"). If target is "simulate", + the function also depends on the dictionary of random keys ("keys"), which + corresponds to the names of stochastic next functions. """ targets = model.function_info.query("is_next").index.tolist() + if target == Target.SOLVE: + functions_dict = model.functions + elif target == Target.SIMULATE: + # For the simulation target, we need to extend the functions dictionary with + # stochastic next states functions and their weights. + functions_dict = _extend_functions_dict_for_simulation(model) + else: + raise ValueError(f"Invalid target: {target}") + return concatenate_functions( - functions=model.functions, + functions=functions_dict, targets=targets, return_type="dict", enforce_signature=False, ) -# ====================================================================================== -# Simulation -# ====================================================================================== - - -def _get_next_state_function_for_simulation( +def _extend_functions_dict_for_simulation( model: InternalModel, -) -> Callable[..., dict[str, Scalar]]: - """Get function that computes the next states for the simulation. +) -> dict[str, Callable[..., Scalar]]: + """Extend the functions dictionary for the simulation target. Args: - model: Model instance. + model: Internal model instance. Returns: - Function that computes the next states. Depends on states and choices of the - current period, and the model parameters. Additionaly, it depends on: - - key (dict): Dictionary with PRNG keys. Keys are the names of stochastic next - functions, e.g. 'next_health'. + Extended functions dictionary. """ - # ================================================================================== - # Get targets - # ================================================================================== - targets = model.function_info.query("is_next").index.tolist() - - stochastic_targets = model.function_info.query( - "is_next & is_stochastic_next", - ).index + stochastic_targets = model.function_info.query("is_stochastic_next").index - # ================================================================================== # Handle stochastic next states functions # ---------------------------------------------------------------------------------- # We generate stochastic next states functions that simulate the next state given - # a PRNG key and the weights of the stochastic variable. The corresponding weights - # are computed using the stochastic weight functions, which we add the to functions - # dict. `dags.concatenate_functions` then generates a function that computes the - # weights and simulates the next state. - # ================================================================================== + # a random key (think of a seed) and the weights corresponding to the labels of the + # stochastic variable. The weights are computed using the stochastic weight + # functions, which we add the to functions dict. `dags.concatenate_functions` then + # generates a function that computes the weights and simulates the next state in + # one go. + # ---------------------------------------------------------------------------------- stochastic_next = { - name: _get_stochastic_next_func(name, grids=model.grids) + name: _create_stochastic_next_func( + name, labels=model.grids[name.removeprefix("next_")] + ) for name in stochastic_targets } - stochastic_weights_names = [ - f"weight_{name}" - for name in model.function_info.query("is_stochastic_next").index.tolist() - ] - stochastic_weights = { - name: model.functions[name] for name in stochastic_weights_names + f"weight_{name}": model.functions[f"weight_{name}"] + for name in stochastic_targets } - # ================================================================================== # Overwrite model.functions with generated stochastic next states functions - # ================================================================================== - functions_dict = model.functions | stochastic_next | stochastic_weights - - return concatenate_functions( - functions=functions_dict, - targets=targets, - return_type="dict", - enforce_signature=False, - ) + # ---------------------------------------------------------------------------------- + return model.functions | stochastic_next | stochastic_weights -def _get_stochastic_next_func( - name: str, grids: dict[str, Array] -) -> Callable[[Array, Array], Array]: +def _create_stochastic_next_func(name: str, labels: Array) -> StochasticNextFunction: """Get function that simulates the next state of a stochastic variable. Args: name: Name of the stochastic variable. - grids: Dict with grids. + labels: 1d array of labels. Returns: - A function that simulates the next state of the stochastic variable. Depends on - variables: - - key (dict): Dictionary with PRNG keys. Keys are the names of stochastic next - functions, e.g. 'next_health'. - - weight_{name} (jax.numpy.array): 2d array of weights. The first dimension - corresponds to the number of simulation units. The second dimension - corresponds to the number of grid points (labels). + A function that simulates the next state of the stochastic variable. The + function must be called with keyword arguments: + - weight_{name}: 2d array of weights. The first dimension corresponds to the + number of simulation units. The second dimension corresponds to the number of + grid points (labels). + - keys: Dictionary with random key arrays. Dictionary keys correspond to the + names of stochastic next functions, e.g. 'next_health'. """ - arg_names = ["keys", f"weight_{name}"] - labels = grids[name.removeprefix("next_")] - @with_signature(args=arg_names) - def _next_stochastic_state(*args: Array, **kwargs: Array) -> Array: - keys, weights = all_as_args(args, kwargs, arg_names=arg_names) + @with_signature(args=[f"weight_{name}", "keys"]) + def next_stochastic_state(keys: dict[str, Array], **kwargs: Array) -> Array: return random_choice( - key=keys[name], - probs=weights, labels=labels, + probs=kwargs[f"weight_{name}"], + key=keys[name], ) - return _next_stochastic_state + return next_stochastic_state diff --git a/src/lcm/random_choice.py b/src/lcm/random_choice.py index dceb166..6e37d58 100644 --- a/src/lcm/random_choice.py +++ b/src/lcm/random_choice.py @@ -4,17 +4,17 @@ def random_choice( - key: jax.Array, - probs: jax.Array, labels: jax.Array, + probs: jax.Array, + key: jax.Array, ) -> jax.Array: """Draw multiple random choices. Args: - key: Random key. + labels: 1d array of labels. probs: 2d array of probabilities. Second dimension must be the same length as the first dimension of labels. - labels: 1d array of labels. + key: Random key. Returns: Selected labels. 1d array of length len(probs). diff --git a/src/lcm/typing.py b/src/lcm/typing.py index abf3355..1df6e16 100644 --- a/src/lcm/typing.py +++ b/src/lcm/typing.py @@ -43,6 +43,16 @@ class DiscreteProblemSolverFunction(Protocol): def __call__(self, values: Array, params: ParamsDict) -> Array: ... # noqa: D102 +class StochasticNextFunction(Protocol): + """The function that simulates the next state of a stochastic variable. + + Only used for type checking. + + """ + + def __call__(self, keys: dict[str, Array], **kwargs: Array) -> Array: ... # noqa: D102 + + class ShockType(Enum): """Type of shocks.""" diff --git a/tests/simulation/test_simulate.py b/tests/simulation/test_simulate.py index 5fd918c..135b267 100644 --- a/tests/simulation/test_simulate.py +++ b/tests/simulation/test_simulate.py @@ -11,7 +11,7 @@ from lcm.input_processing import process_model from lcm.logging import get_logger from lcm.model_functions import get_utility_and_feasibility_function -from lcm.next_state import _get_next_state_function_for_simulation +from lcm.next_state import get_next_state_function from lcm.simulation.simulate import ( _as_data_frame, _compute_targets, @@ -24,6 +24,7 @@ simulate, ) from lcm.solution.state_choice_space import create_state_choice_space +from lcm.typing import Target from tests.test_models import ( get_model_config, get_params, @@ -66,7 +67,7 @@ def simulate_inputs(): ], "compute_ccv_policy_functions": compute_ccv_policy_functions, "model": model, - "next_state": _get_next_state_function_for_simulation(model), + "next_state": get_next_state_function(model, target=Target.SIMULATE), } diff --git a/tests/test_next_state.py b/tests/test_next_state.py index 40230ef..e37b6d1 100644 --- a/tests/test_next_state.py +++ b/tests/test_next_state.py @@ -5,14 +5,10 @@ from lcm.input_processing import process_model from lcm.interfaces import InternalModel -from lcm.next_state import _get_stochastic_next_func, get_next_state_function +from lcm.next_state import _create_stochastic_next_func, get_next_state_function from lcm.typing import ParamsDict, Scalar, ShockType, Target from tests.test_models import get_model_config -# ====================================================================================== -# Solve target -# ====================================================================================== - def test_get_next_state_function_with_solve_target(): model = process_model( @@ -35,11 +31,6 @@ def test_get_next_state_function_with_solve_target(): assert got == {"next_wealth": 1.05 * (20 - 10)} -# ====================================================================================== -# Simulate target -# ====================================================================================== - - def test_get_next_state_function_with_simulate_target(): def f_a(state: Array, params: ParamsDict) -> Scalar: # noqa: ARG001 return state[0] @@ -86,12 +77,12 @@ def f_weight_b(state: Scalar, params: ParamsDict) -> Array: # noqa: ARG001 assert tree_equal(expected, got) -def test_get_stochastic_next_func(): - grids = {"a": jnp.arange(2)} - got_func = _get_stochastic_next_func(name="a", grids=grids) +def test_create_stochastic_next_func(): + labels = jnp.arange(2) + got_func = _create_stochastic_next_func(name="a", labels=labels) keys = {"a": jnp.arange(2, dtype="uint32")} # PRNG dtype weights = jnp.array([[0.0, 1], [1, 0]]) - got = got_func(keys=keys, weight_a=weights) # type: ignore[call-arg] + got = got_func(keys=keys, weight_a=weights) assert jnp.array_equal(got, jnp.array([1, 0])) diff --git a/tests/test_random_choice.py b/tests/test_random_choice.py index 3c1ab25..2315788 100644 --- a/tests/test_random_choice.py +++ b/tests/test_random_choice.py @@ -8,5 +8,5 @@ def test_random_choice(): key = jax.random.key(0) probs = jnp.array([[0.0, 0, 1], [1, 0, 0], [0, 1, 0]]) labels = jnp.array([1, 2, 3]) - got = random_choice(key, probs=probs, labels=labels) + got = random_choice(labels=labels, probs=probs, key=key) assert jnp.array_equal(got, jnp.array([3, 1, 2])) From f2d6c53c511662ac991673058ea61971193f85b3 Mon Sep 17 00:00:00 2001 From: Tim Mensinger Date: Wed, 26 Feb 2025 16:52:31 +0100 Subject: [PATCH 03/20] Refactor utility and feasibility --- src/lcm/entry_point.py | 6 +- src/lcm/model_functions.py | 211 ------------- src/lcm/next_state.py | 25 ++ src/lcm/utility_and_feasibility.py | 280 ++++++++++++++++++ tests/simulation/test_simulate.py | 2 +- tests/test_entry_point.py | 2 +- ...ons.py => test_utility_and_feasibility.py} | 16 +- 7 files changed, 318 insertions(+), 224 deletions(-) delete mode 100644 src/lcm/model_functions.py create mode 100644 src/lcm/utility_and_feasibility.py rename tests/{test_model_functions.py => test_utility_and_feasibility.py} (95%) diff --git a/src/lcm/entry_point.py b/src/lcm/entry_point.py index 0d89ddc..0b55868 100644 --- a/src/lcm/entry_point.py +++ b/src/lcm/entry_point.py @@ -13,15 +13,15 @@ from lcm.dispatchers import productmap from lcm.input_processing import process_model from lcm.logging import get_logger -from lcm.model_functions import ( - get_utility_and_feasibility_function, -) from lcm.next_state import get_next_state_function from lcm.simulation.simulate import simulate from lcm.solution.solve_brute import solve from lcm.solution.state_choice_space import create_state_choice_space from lcm.typing import ParamsDict, Target from lcm.user_model import Model +from lcm.utility_and_feasibility import ( + get_utility_and_feasibility_function, +) def get_lcm_function( diff --git a/src/lcm/model_functions.py b/src/lcm/model_functions.py deleted file mode 100644 index 6e14b02..0000000 --- a/src/lcm/model_functions.py +++ /dev/null @@ -1,211 +0,0 @@ -import inspect -from collections.abc import Callable - -import jax.numpy as jnp -from dags import concatenate_functions -from dags.signature import with_signature -from jax import Array - -from lcm.dispatchers import productmap -from lcm.function_representation import get_value_function_representation -from lcm.functools import ( - all_as_args, - all_as_kwargs, - get_union_of_arguments, -) -from lcm.interfaces import InternalModel, StateSpaceInfo -from lcm.next_state import get_next_state_function -from lcm.typing import InternalUserFunction, ParamsDict, Scalar, Target - - -def get_utility_and_feasibility_function( - model: InternalModel, - state_space_info: StateSpaceInfo, - period: int, - *, - is_last_period: bool, -) -> Callable[..., tuple[Array, Array]]: - # ================================================================================== - # Gather information on the model variables - # ================================================================================== - state_variables = model.variable_info.query("is_state").index.tolist() - choice_variables = model.variable_info.query("is_choice").index.tolist() - stochastic_variables = model.variable_info.query("is_stochastic").index.tolist() - - # ================================================================================== - # Generate dynamic functions - # ================================================================================== - current_u_and_f = get_current_u_and_f(model) - - if is_last_period: - relevant_functions: list[ - Callable[..., Scalar] - | Callable[..., tuple[Scalar, Scalar]] - | Callable[..., dict[str, Scalar]] - ] = [current_u_and_f] - - else: - next_state = get_next_state_function(model, target=Target.SOLVE) - next_weights = get_next_weights_function(model) - - scalar_value_function = get_value_function_representation(state_space_info) - - multiply_weights = get_multiply_weights(stochastic_variables) - - relevant_functions = [ - current_u_and_f, - next_state, - next_weights, - scalar_value_function, - ] - - value_function_arguments = list( - inspect.signature(scalar_value_function).parameters, - ) - - # ================================================================================== - # Create the utility and feasability function - # ================================================================================== - - arg_names_set = {"vf_arr"} | get_union_of_arguments(relevant_functions) - { - "_period" - } - arg_names = [arg for arg in arg_names_set if "next_" not in arg] - - if is_last_period: - - @with_signature(args=arg_names) - def u_and_f( - *args: Scalar, params: ParamsDict, **kwargs: Scalar - ) -> tuple[Scalar, Scalar]: - kwargs = all_as_kwargs(args, kwargs, arg_names=arg_names) - - states = {k: v for k, v in kwargs.items() if k in state_variables} - choices = {k: v for k, v in kwargs.items() if k in choice_variables} - - return current_u_and_f( - **states, - **choices, - _period=period, - params=params, - ) - - else: - - @with_signature(args=arg_names) - def u_and_f( - *args: Scalar, params: ParamsDict, **kwargs: Scalar - ) -> tuple[Scalar, Scalar]: - kwargs = all_as_kwargs(args, kwargs, arg_names=arg_names) - - states = {k: v for k, v in kwargs.items() if k in state_variables} - choices = {k: v for k, v in kwargs.items() if k in choice_variables} - - u, f = current_u_and_f( - **states, - **choices, - _period=period, - params=params, - ) - - _next_state = next_state( - **states, - **choices, - _period=period, - params=params, - ) - weights = next_weights( - **states, - **choices, - _period=period, - params=params, - ) - - value_function = productmap( - scalar_value_function, - variables=tuple(f"next_{var}" for var in stochastic_variables), - ) - - ccvs_at_nodes = value_function( - **_next_state, - **{k: v for k, v in kwargs.items() if k in value_function_arguments}, - ) - - node_weights = multiply_weights(**weights) - - ccv = (ccvs_at_nodes * node_weights).sum() - - big_u = u + params["beta"] * ccv - return big_u, f - - return u_and_f - - -def get_multiply_weights(stochastic_variables: list[str]) -> Callable[..., Array]: - """Get multiply_weights function. - - Args: - stochastic_variables (list): List of stochastic variables. - - Returns: - A function that multiplies the weights of the stochastic variables. - - """ - arg_names = [f"weight_next_{var}" for var in stochastic_variables] - - @with_signature(args=arg_names) - def _outer(*args: Array, **kwargs: Array) -> Array: - args = all_as_args(args, kwargs, arg_names=arg_names) - return jnp.prod(jnp.array(args)) - - return productmap(_outer, variables=tuple(arg_names)) - - -def get_combined_constraint(model: InternalModel) -> InternalUserFunction: - """Create a function that combines all constraint functions into a single one. - - Args: - model: The internal model object. - - Returns: - The combined constraint function. - - """ - targets = model.function_info.query("is_constraint").index.tolist() - - if targets: - combined_constraint = concatenate_functions( - functions=model.functions, - targets=targets, - aggregator=jnp.logical_and, - ) - else: - - def combined_constraint() -> None: - return None - - return combined_constraint - - -def get_current_u_and_f(model: InternalModel) -> Callable[..., tuple[Scalar, Scalar]]: - functions = {"feasibility": get_combined_constraint(model), **model.functions} - - return concatenate_functions( - functions=functions, - targets=["utility", "feasibility"], - enforce_signature=False, - ) - - -def get_next_weights_function(model: InternalModel) -> Callable[..., dict[str, Scalar]]: - targets = [ - f"weight_{name}" - for name in model.function_info.query("is_stochastic_next").index.tolist() - ] - - return concatenate_functions( - functions=model.functions, - targets=targets, - return_type="dict", - enforce_signature=False, - ) diff --git a/src/lcm/next_state.py b/src/lcm/next_state.py index c9d9e5c..a82958a 100644 --- a/src/lcm/next_state.py +++ b/src/lcm/next_state.py @@ -47,6 +47,31 @@ def get_next_state_function( ) +def get_next_stochastic_weights_function( + model: InternalModel, +) -> Callable[..., dict[str, Array]]: + """Get function that computes the weights for the next stochastic states. + + Args: + model: Internal model instance. + + Returns: + Function that computes the weights for the next stochastic states. + + """ + targets = [ + f"weight_{name}" + for name in model.function_info.query("is_stochastic_next").index.tolist() + ] + + return concatenate_functions( + functions=model.functions, + targets=targets, + return_type="dict", + enforce_signature=False, + ) + + def _extend_functions_dict_for_simulation( model: InternalModel, ) -> dict[str, Callable[..., Scalar]]: diff --git a/src/lcm/utility_and_feasibility.py b/src/lcm/utility_and_feasibility.py new file mode 100644 index 0000000..b510c86 --- /dev/null +++ b/src/lcm/utility_and_feasibility.py @@ -0,0 +1,280 @@ +from collections.abc import Callable +from typing import Any + +import jax.numpy as jnp +from dags import concatenate_functions +from dags.signature import with_signature +from jax import Array + +from lcm.dispatchers import productmap +from lcm.function_representation import get_value_function_representation +from lcm.functools import get_union_of_arguments +from lcm.interfaces import InternalModel, StateSpaceInfo +from lcm.next_state import get_next_state_function, get_next_stochastic_weights_function +from lcm.typing import InternalUserFunction, ParamsDict, Scalar, Target + + +def get_utility_and_feasibility_function( + model: InternalModel, + state_space_info: StateSpaceInfo, + period: int, + *, + is_last_period: bool, +) -> Callable[..., tuple[Array, Array]]: + """Create the utility and feasibility function for a given period. + + Args: + model: The internal model object. + state_space_info: The state space information. + period: The period to create the utility and feasibility function for. + is_last_period: Whether the period is the last period. + + Returns: + A function that computes the expected forward-looking utility and feasibility + for the given period. + + """ + if is_last_period: + return get_utility_and_feasibility_function_last_period(model, period=period) + return get_utility_and_feasibility_function_before_last_period( + model, state_space_info=state_space_info, period=period + ) + + +def get_utility_and_feasibility_function_before_last_period( + model: InternalModel, + state_space_info: StateSpaceInfo, + period: int, +) -> Callable[..., tuple[Array, Array]]: + """Create the utility and feasibility function for a period before the last period. + + Args: + model: The internal model object. + state_space_info: The state space information. + period: The period to create the utility and feasibility function for. + + Returns: + A function that computes the utility and feasibility for the given period. + + """ + stochastic_variables = model.variable_info.query("is_stochastic").index.tolist() + + # ---------------------------------------------------------------------------------- + # Generate dynamic functions + # ---------------------------------------------------------------------------------- + # TODO (@timmens): This can be done outside this function, since it # noqa: TD003 + # does not depend on the period. + # ---------------------------------------------------------------------------------- + + # Functions required to calculate the expected continuation values + calculate_state_transition = get_next_state_function(model, target=Target.SOLVE) + calculate_next_weights = get_next_stochastic_weights_function(model) + calculate_node_weights = _get_node_weights_function(stochastic_variables) + _scalar_value_function = get_value_function_representation(state_space_info) + value_function = productmap( + _scalar_value_function, + variables=tuple(f"next_{var}" for var in stochastic_variables), + ) + + # Function required to calculate todays utility and feasibility + calculate_todays_u_and_f = _get_current_u_and_f(model) + + # ---------------------------------------------------------------------------------- + # Create the utility and feasability function + # ---------------------------------------------------------------------------------- + + arg_names = _get_required_arg_names_of_u_and_f( + [ + calculate_todays_u_and_f, + calculate_state_transition, + calculate_next_weights, + ] + ) + + @with_signature(args=arg_names) + def utility_and_feasibility( + params: ParamsDict, vf_arr: Array, **states_and_choices: Scalar + ) -> tuple[Scalar, Scalar]: + """Calculate the expected forward-looking utility and feasibility. + + Args: + params: The parameters. + vf_arr: The value function array. + **states_and_choices: Todays states and choices. + + Returns: + A tuple containing the utility and feasibility for the given period. + + """ + # ------------------------------------------------------------------------------ + # Calculate the expected continuation values + # ------------------------------------------------------------------------------ + next_states = calculate_state_transition( + **states_and_choices, + _period=period, + params=params, + ) + + weights = calculate_next_weights( + **states_and_choices, + _period=period, + params=params, + ) + + node_weights = calculate_node_weights(**weights) + + # As we productmap'd the value function over the stochastic variables, the + # resulting continuation values get a new dimension for each stochastic + # variable. + continuation_values_at_nodes = value_function(**next_states, vf_arr=vf_arr) + + # We then weight these continuation values with the joint node weights and sum + # them up to get the expected continuation values. + expected_continuation_values = ( + continuation_values_at_nodes * node_weights + ).sum() + + # ------------------------------------------------------------------------------ + # Calculate the expected forward-looking utility. + # ------------------------------------------------------------------------------ + # This is not the value function yet, as it still depends on the choices. + # ------------------------------------------------------------------------------ + period_utility, period_feasibility = calculate_todays_u_and_f( + **states_and_choices, + _period=period, + params=params, + ) + + expected_forward_utility = ( + period_utility + params["beta"] * expected_continuation_values + ) + + return expected_forward_utility, period_feasibility + + return utility_and_feasibility + + +def get_utility_and_feasibility_function_last_period( + model: InternalModel, + period: int, +) -> Callable[..., tuple[Array, Array]]: + """Create the utility and feasibility function for the last period. + + Args: + model: The internal model object. + period: The period to create the utility and feasibility function for. This is + still relevant for the last period, as some functions might depend on the + actual period value. + + Returns: + A function that computes the utility and feasibility for the given period. + + """ + calculate_todays_u_and_f = _get_current_u_and_f(model) + + arg_names = _get_required_arg_names_of_u_and_f([calculate_todays_u_and_f]) + + @with_signature(args=arg_names) + def utility_and_feasibility( + params: ParamsDict, **kwargs: Scalar + ) -> tuple[Scalar, Scalar]: + return calculate_todays_u_and_f( + **kwargs, + _period=period, + params=params, + ) + + return utility_and_feasibility + + +# ====================================================================================== +# Helper functions +# ====================================================================================== + + +def _get_required_arg_names_of_u_and_f( + model_functions: list[Callable[..., Any]], +) -> list[str]: + """Get the argument names of the utility and feasibility function. + + Args: + model_functions: The list of functions that are used to calculate the utility + and feasibility. + + Returns: + The argument names of the utility and feasibility function. + + """ + dynamic_arg_names = get_union_of_arguments(model_functions) - {"_period"} + static_arg_names = {"params", "vf_arr"} + + return list(static_arg_names | dynamic_arg_names) + + +def _get_node_weights_function(stochastic_variables: list[str]) -> Callable[..., Array]: + """Get joint weights function. + + This function takes the weights of the individual stochastic variables and + multiplies them together to get the joint weights on the product space of the + stochastic variables. + + Args: + stochastic_variables: List of stochastic variables. + + Returns: + A function that multiplies the weights of the stochastic variables. + + """ + arg_names = [f"weight_next_{var}" for var in stochastic_variables] + + @with_signature(args=arg_names) + def _outer(**kwargs: Array) -> Array: + weights = jnp.array(list(kwargs.values())) + return jnp.prod(weights) + + return productmap(_outer, variables=tuple(arg_names)) + + +def _get_current_u_and_f(model: InternalModel) -> Callable[..., tuple[Scalar, Scalar]]: + """Get the current utility and feasibility function. + + Args: + model: The internal model object. + + Returns: + The current utility and feasibility function. + + """ + functions = {"feasibility": _get_feasibility(model), **model.functions} + return concatenate_functions( + functions=functions, + targets=["utility", "feasibility"], + enforce_signature=False, + ) + + +def _get_feasibility(model: InternalModel) -> InternalUserFunction: + """Create a function that combines all constraint functions into a single one. + + Args: + model: The internal model object. + + Returns: + The combined constraint function (feasibility). + + """ + targets = model.function_info.query("is_constraint").index.tolist() + + if targets: + combined_constraint = concatenate_functions( + functions=model.functions, + targets=targets, + aggregator=jnp.logical_and, + ) + else: + + def combined_constraint(**kwargs: Scalar) -> bool: # noqa: ARG001 + """Dummy feasibility function that always returns True.""" + return True + + return combined_constraint diff --git a/tests/simulation/test_simulate.py b/tests/simulation/test_simulate.py index 135b267..32aa03f 100644 --- a/tests/simulation/test_simulate.py +++ b/tests/simulation/test_simulate.py @@ -10,7 +10,6 @@ ) from lcm.input_processing import process_model from lcm.logging import get_logger -from lcm.model_functions import get_utility_and_feasibility_function from lcm.next_state import get_next_state_function from lcm.simulation.simulate import ( _as_data_frame, @@ -25,6 +24,7 @@ ) from lcm.solution.state_choice_space import create_state_choice_space from lcm.typing import Target +from lcm.utility_and_feasibility import get_utility_and_feasibility_function from tests.test_models import ( get_model_config, get_params, diff --git a/tests/test_entry_point.py b/tests/test_entry_point.py index f6c4d07..77e887d 100644 --- a/tests/test_entry_point.py +++ b/tests/test_entry_point.py @@ -8,8 +8,8 @@ get_lcm_function, ) from lcm.input_processing import process_model -from lcm.model_functions import get_utility_and_feasibility_function from lcm.solution.state_choice_space import create_state_choice_space +from lcm.utility_and_feasibility import get_utility_and_feasibility_function from tests.test_models import get_model_config from tests.test_models.deterministic import RetirementStatus from tests.test_models.deterministic import utility as iskhakov_et_al_2017_utility diff --git a/tests/test_model_functions.py b/tests/test_utility_and_feasibility.py similarity index 95% rename from tests/test_model_functions.py rename to tests/test_utility_and_feasibility.py index 25e1327..9a89380 100644 --- a/tests/test_model_functions.py +++ b/tests/test_utility_and_feasibility.py @@ -6,13 +6,13 @@ from lcm.input_processing import process_model from lcm.interfaces import InternalModel -from lcm.model_functions import ( - get_combined_constraint, - get_multiply_weights, - get_utility_and_feasibility_function, -) from lcm.solution.state_choice_space import create_state_choice_space from lcm.typing import ShockType +from lcm.utility_and_feasibility import ( + _get_feasibility, + _get_node_weights_function, + get_utility_and_feasibility_function, +) from tests.test_models import get_model_config from tests.test_models.deterministic import utility @@ -119,7 +119,7 @@ def absorbing_retirement_constraint(retirement, lagged_retirement, params): # n @pytest.mark.illustrative def test_get_combined_constraint_illustrative(internal_model_illustrative): - combined_constraint = get_combined_constraint(internal_model_illustrative) + combined_constraint = _get_feasibility(internal_model_illustrative) age, retirement, lagged_retirement = jnp.array( [ @@ -148,7 +148,7 @@ def test_get_combined_constraint_illustrative(internal_model_illustrative): def test_get_multiply_weights(): - multiply_weights = get_multiply_weights( + multiply_weights = _get_node_weights_function( stochastic_variables=["a", "b"], ) @@ -184,6 +184,6 @@ def h(params): # noqa: ARG001 random_utility_shocks=ShockType.NONE, n_periods=0, ) - combined_constraint = get_combined_constraint(model) + combined_constraint = _get_feasibility(model) feasibility: Array = combined_constraint(params={}) # type: ignore[assignment] assert feasibility.item() is False From a7b1f36d6ff0c4c4fc51ad214e86b1365c8613d2 Mon Sep 17 00:00:00 2001 From: Tim Mensinger Date: Wed, 26 Feb 2025 17:00:07 +0100 Subject: [PATCH 04/20] Delete sandbox folder --- .../state_space_jax_versus_numba.ipynb | 835 ------------------ 1 file changed, 835 deletions(-) delete mode 100644 src/lcm/sandbox/state_space_jax_versus_numba.ipynb diff --git a/src/lcm/sandbox/state_space_jax_versus_numba.ipynb b/src/lcm/sandbox/state_space_jax_versus_numba.ipynb deleted file mode 100644 index a8f034a..0000000 --- a/src/lcm/sandbox/state_space_jax_versus_numba.ipynb +++ /dev/null @@ -1,835 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": null, - "id": "0", - "metadata": {}, - "outputs": [], - "source": [ - "import itertools\n", - "import math\n", - "\n", - "import jax.numpy as jnp\n", - "import matplotlib.pyplot as plt\n", - "import numpy as np\n", - "from jax import jit\n", - "from jax.config import config\n", - "from numba import njit\n", - "\n", - "config.update(\"jax_enable_x64\", True)\n", - "\n", - "from dags import concatenate_functions\n", - "from lcm.dispatchers import gridmap, productmap\n", - "from numpy.testing import assert_array_almost_equal as aaae" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "1", - "metadata": {}, - "outputs": [], - "source": [ - "def plot_runtime(n_points, runtime_numba, runtime_jax, func=\"product_map\"):\n", - " plt.figure(figsize=[10, 6])\n", - " plt.plot(n_points, [i * 1000 for i in runtime_numba], label=\"Numba\")\n", - " plt.plot(n_points, [i * 1000 for i in runtime_jax], label=\"JAX\")\n", - "\n", - " plt.xlabel(\"Number of Grid Points\")\n", - " plt.ylabel(\"Runtime (in milliseconds)\")\n", - " plt.title(f\"Runtime of hardcoded Numba function versus {func} with JAX\")\n", - "\n", - " plt.legend()\n", - " plt.show()" - ] - }, - { - "cell_type": "markdown", - "id": "2", - "metadata": {}, - "source": [ - "## product_map: 2 grids benchmark" - ] - }, - { - "cell_type": "markdown", - "id": "3", - "metadata": {}, - "source": [ - "### Trivial setup" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "4", - "metadata": {}, - "outputs": [], - "source": [ - "def _utility(_consumption, _leisure):\n", - " return _consumption + _leisure\n", - "\n", - "\n", - "def _leisure(working):\n", - " return 24 - working\n", - "\n", - "\n", - "def _consumption(working, wage):\n", - " return wage * working" - ] - }, - { - "cell_type": "markdown", - "id": "5", - "metadata": {}, - "source": [ - "### Numba: Dimensions hardcoded" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "6", - "metadata": {}, - "outputs": [], - "source": [ - "@njit(fastmath=True)\n", - "def product_map_numba_2d(wage, working):\n", - " n_wage = wage.shape[0]\n", - " n_working = working.shape[0]\n", - "\n", - " cross_product = np.empty((n_wage, n_working))\n", - "\n", - " for i in range(n_wage):\n", - " for j in range(n_working):\n", - " cross_product[i, j] = wage[i] * working[j] + (24 - working[j])\n", - "\n", - " return cross_product\n", - "\n", - "\n", - "wage = np.linspace(1, 10, 101)\n", - "working = np.linspace(0, 24, 25)\n", - "rslt_numba = product_map_numba_2d(wage, working)" - ] - }, - { - "cell_type": "markdown", - "id": "7", - "metadata": {}, - "source": [ - "### LCM productmap" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "8", - "metadata": {}, - "outputs": [], - "source": [ - "grids = {\n", - " \"wage\": jnp.linspace(1, 10, 101),\n", - " \"working\": jnp.linspace(0, 24, 25),\n", - "}\n", - "\n", - "utility_concat = concatenate_functions(\n", - " functions=[_utility, _leisure, _consumption],\n", - " targets=\"_utility\",\n", - ")\n", - "_decorated_func = productmap(utility_concat, [\"wage\", \"working\"])\n", - "decorated_func_jit = jit(_decorated_func)\n", - "rslt_jax = decorated_func_jit(**grids)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "9", - "metadata": {}, - "outputs": [], - "source": [ - "### Assert equality\n", - "\n", - "aaae(rslt_jax, rslt_numba)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "10", - "metadata": {}, - "outputs": [], - "source": [ - "def get_numba_runtime_2d():\n", - " total_points = []\n", - " runtime = []\n", - "\n", - " for len_grid in grid_space:\n", - " total_points.append(len_grid**2)\n", - "\n", - " grids = {\n", - " \"wage\": np.linspace(1, 10, len_grid),\n", - " \"working\": np.linspace(0, 24, len_grid),\n", - " }\n", - "\n", - " product_map_numba_2d(**grids)\n", - "\n", - " timeit_res = %timeit -r 7 -n 1_000 -o product_map_numba_2d(**grids)\n", - " runtime.append(timeit_res.average)\n", - "\n", - " return total_points, runtime\n", - "\n", - "\n", - "def get_jax_runtime_2d():\n", - " utility_concat = concatenate_functions(\n", - " functions=[_utility, _leisure, _consumption],\n", - " targets=\"_utility\",\n", - " )\n", - " _decorated_func = productmap(utility_concat, [\"wage\", \"working\"])\n", - " decorated_func_jit = jit(_decorated_func)\n", - "\n", - " total_points = []\n", - " runtime = []\n", - "\n", - " for len_grid in grid_space:\n", - " total_points.append(len_grid**2)\n", - "\n", - " grids = {\n", - " \"wage\": jnp.linspace(1, 10, len_grid),\n", - " \"working\": jnp.linspace(0, 24, len_grid),\n", - " }\n", - "\n", - " decorated_func_jit(**grids)\n", - "\n", - " timeit_res = %timeit -r 7 -n 1_000 -o decorated_func_jit(**grids).block_until_ready()\n", - " runtime.append(timeit_res.average)\n", - "\n", - " return total_points, runtime\n", - "\n", - "\n", - "grid_space = [10, 25, 50] + np.arange(100, 1100, 100).tolist()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "11", - "metadata": {}, - "outputs": [], - "source": [ - "n_points_2_grids, runtime_numba_2d = get_numba_runtime_2d()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "12", - "metadata": {}, - "outputs": [], - "source": [ - "_, runtime_jax_2d = get_jax_runtime_2d()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "13", - "metadata": {}, - "outputs": [], - "source": [ - "plot_runtime(n_points_2_grids, runtime_numba_2d, runtime_jax_2d)" - ] - }, - { - "cell_type": "markdown", - "id": "14", - "metadata": {}, - "source": [ - "## *product_map*: 4 grids benchmark, more complex computations" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "15", - "metadata": {}, - "outputs": [], - "source": [ - "def _utility_4d(_consumption_4d, _leisure_4d):\n", - " return _consumption_4d + _leisure_4d\n", - "\n", - "\n", - "def _leisure_4d(d):\n", - " return jnp.cos(d)\n", - "\n", - "\n", - "def _consumption_4d(a, b, c):\n", - " return jnp.log(a) + jnp.sqrt(jnp.square(b) * c)\n", - "\n", - "\n", - "@njit(fastmath=True)\n", - "def product_map_numba_4d(a, b, c, d):\n", - " n_a = a.shape[0]\n", - " n_b = b.shape[0]\n", - " n_c = c.shape[0]\n", - " n_d = d.shape[0]\n", - "\n", - " cross_product = np.empty((n_a, n_b, n_c, n_d))\n", - "\n", - " for i in range(n_a):\n", - " for j in range(n_b):\n", - " for k in range(n_c):\n", - " for l in range(n_d):\n", - " cross_product[i, j, k, l] = (\n", - " np.log(a[i]) + np.sqrt(b[j] ** 2 * c[k]) + np.cos(d[l])\n", - " )\n", - "\n", - " return cross_product" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "16", - "metadata": {}, - "outputs": [], - "source": [ - "grids = {\n", - " \"a\": jnp.linspace(1, 10, 101),\n", - " \"b\": jnp.linspace(0, 24, 25),\n", - " \"c\": jnp.linspace(1, 5, 5),\n", - " \"d\": jnp.linspace(-7, 2, 10),\n", - "}\n", - "\n", - "a = np.linspace(1, 10, 101)\n", - "b = np.linspace(0, 24, 25)\n", - "c = np.linspace(1, 5, 5)\n", - "d = np.linspace(-7, 2, 10)\n", - "\n", - "utility_concat_4d = concatenate_functions(\n", - " functions=[\n", - " _utility_4d,\n", - " _leisure_4d,\n", - " _consumption_4d,\n", - " ],\n", - " targets=\"_utility_4d\",\n", - ")\n", - "\n", - "_decorated_func = productmap(utility_concat_4d, [\"a\", \"b\", \"c\", \"d\"])\n", - "decorated_func_jit = jit(_decorated_func)\n", - "rslt_jax = decorated_func_jit(**grids)\n", - "\n", - "rslt_numba = product_map_numba_4d(a, b, c, d)\n", - "\n", - "aaae(rslt_jax, rslt_numba)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "17", - "metadata": {}, - "outputs": [], - "source": [ - "def get_numba_runtime_4d(len_grids_arr):\n", - " total_points = []\n", - " runtime = []\n", - "\n", - " for len_grids in len_grids_arr.T:\n", - " len_a, len_b, len_c, len_d = len_grids.astype(int)\n", - " total_points.append(math.prod(len_grids))\n", - "\n", - " a = np.linspace(1, 10, len_a)\n", - " b = np.linspace(0, 24, len_b)\n", - " c = np.linspace(1, 5, len_c)\n", - " d = np.linspace(-7, 2, len_d)\n", - "\n", - " product_map_numba_4d(a, b, c, d)\n", - "\n", - " timeit_res = %timeit -r 7 -n 1_000 -o product_map_numba_4d(a, b, c, d)\n", - " runtime.append(timeit_res.average)\n", - "\n", - " return total_points, runtime\n", - "\n", - "\n", - "def get_jax_runtime_4d(len_grids_arr):\n", - " utility_concat_4d = concatenate_functions(\n", - " functions=[\n", - " _utility_4d,\n", - " _leisure_4d,\n", - " _consumption_4d,\n", - " ],\n", - " targets=\"_utility_4d\",\n", - " )\n", - "\n", - " _decorated_func = productmap(utility_concat_4d, [\"a\", \"b\", \"c\", \"d\"])\n", - " decorated_func_jit = jit(_decorated_func)\n", - "\n", - " total_points = []\n", - " runtime = []\n", - "\n", - " for len_grids in len_grids_arr.T:\n", - " len_a, len_b, len_c, len_d = len_grids.astype(int)\n", - " total_points.append(math.prod(len_grids))\n", - "\n", - " grids = {\n", - " \"a\": jnp.linspace(1, 10, len_a),\n", - " \"b\": jnp.linspace(0, 24, len_b),\n", - " \"c\": jnp.linspace(1, 5, len_c),\n", - " \"d\": jnp.linspace(-7, 2, len_d),\n", - " }\n", - "\n", - " decorated_func_jit(**grids)\n", - "\n", - " timeit_res = %timeit -r 7 -n 1_000 -o decorated_func_jit(**grids).block_until_ready()\n", - " runtime.append(timeit_res.average)\n", - "\n", - " return total_points, runtime\n", - "\n", - "\n", - "len_grids_arr = np.vstack(\n", - " (\n", - " [5, 10, 20, 60, 100],\n", - " np.linspace(5, 25, 5),\n", - " np.linspace(1, 5, 5),\n", - " np.linspace(2, 10, 5),\n", - " ),\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "18", - "metadata": {}, - "outputs": [], - "source": [ - "n_points_4d, runtime_numba_4d = get_numba_runtime_4d(len_grids_arr)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "19", - "metadata": {}, - "outputs": [], - "source": [ - "_, runtime_jax_4d = get_jax_runtime_4d(len_grids_arr)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "20", - "metadata": {}, - "outputs": [], - "source": [ - "plot_runtime(n_points_4d, runtime_numba_4d, runtime_jax_4d)" - ] - }, - { - "cell_type": "markdown", - "id": "21", - "metadata": {}, - "source": [ - "## *gridmap*: 4 grids benchmark" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "22", - "metadata": {}, - "outputs": [], - "source": [ - "@njit(fastmath=True)\n", - "def state_space_map_numba(a, b, c, d):\n", - " n_a = a.shape[0]\n", - " n_b = b.shape[0]\n", - " n_c = c.shape[0]\n", - "\n", - " state_space_product = np.empty((n_a, n_b, n_c))\n", - "\n", - " for i in range(n_a):\n", - " for j in range(n_b):\n", - " for k in range(n_c):\n", - " state_space_product[i, j, k] = (\n", - " np.log(a[i]) + np.sqrt(b[j] ** 2 * c[k]) + np.cos(d[k])\n", - " )\n", - "\n", - " return state_space_product" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "23", - "metadata": {}, - "outputs": [], - "source": [ - "# JAX\n", - "dense_variables = {\n", - " \"a\": jnp.linspace(1, 10, 101),\n", - " \"b\": jnp.linspace(0, 24, 25),\n", - "}\n", - "combined_grids = {\"c\": jnp.linspace(1, 5, 5), \"d\": jnp.linspace(-7, 2, 10)}\n", - "helper = jnp.array(list(itertools.product(*combined_grids.values()))).T\n", - "\n", - "contingent_variables = {\n", - " \"c\": helper[0],\n", - " \"d\": helper[1],\n", - "}\n", - "\n", - "# Numpy\n", - "dense_variables_numpy = {\n", - " \"a\": np.linspace(1, 10, 101),\n", - " \"b\": np.linspace(0, 24, 25),\n", - "}\n", - "combined_grids_numpy = {\"c\": np.linspace(1, 5, 5), \"d\": np.linspace(-7, 2, 10)}\n", - "helper = np.array(list(itertools.product(*combined_grids_numpy.values()))).T\n", - "\n", - "contingent_variables_numpy = {\n", - " \"c\": helper[0],\n", - " \"d\": helper[1],\n", - "}\n", - "all_grids = {**dense_variables_numpy, **contingent_variables_numpy}" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "24", - "metadata": {}, - "outputs": [], - "source": [ - "utility_concat_4d = concatenate_functions(\n", - " functions=[\n", - " _utility_4d,\n", - " _leisure_4d,\n", - " _consumption_4d,\n", - " ],\n", - " targets=\"_utility_4d\",\n", - ")\n", - "\n", - "_decorated_func = gridmap(\n", - " utility_concat_4d,\n", - " list(dense_variables),\n", - " list(contingent_variables),\n", - ")\n", - "decorated_func_jit = jit(_decorated_func)\n", - "rslt_jax = decorated_func_jit(**dense_variables, **contingent_variables)\n", - "\n", - "rslt_numba = state_space_map_numba(*all_grids.values())\n", - "\n", - "aaae(rslt_jax, rslt_numba)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "25", - "metadata": {}, - "outputs": [], - "source": [ - "def get_numba_runtime(len_grids_arr):\n", - " total_points = []\n", - " runtime = []\n", - "\n", - " for len_grids in len_grids_arr.T:\n", - " len_a, len_b, len_c, len_d = len_grids.astype(int)\n", - " total_points.append(math.prod(len_grids))\n", - "\n", - " dense_variables = {\n", - " \"a\": np.linspace(1, 10, len_a),\n", - " \"b\": np.linspace(0, 24, len_b),\n", - " }\n", - " complex_grids = {\"c\": np.linspace(1, 5, len_c), \"d\": np.linspace(-7, 2, len_d)}\n", - " helper = np.array(list(itertools.product(*complex_grids.values()))).T\n", - "\n", - " complex_variables = {\n", - " \"c\": helper[0],\n", - " \"d\": helper[1],\n", - " }\n", - " all_grids = {**dense_variables, **complex_variables}\n", - "\n", - " state_space_map_numba(*all_grids.values())\n", - "\n", - " timeit_res = %timeit -r 7 -n 1_000 -o state_space_map_numba(*all_grids.values())\n", - " runtime.append(timeit_res.average)\n", - "\n", - " return total_points, runtime\n", - "\n", - "\n", - "def get_jax_runtime(len_grids_arr):\n", - " utility_concat_4d = concatenate_functions(\n", - " functions=[_utility_4d, _leisure_4d, _consumption_4d],\n", - " targets=\"_utility_4d\",\n", - " )\n", - "\n", - " _decorated_func = gridmap(utility_concat_4d, [\"a\", \"b\"], [\"c\", \"d\"])\n", - " decorated_func_jit = jit(_decorated_func)\n", - "\n", - " total_points = []\n", - " runtime = []\n", - "\n", - " for len_grids in len_grids_arr.T:\n", - " len_a, len_b, len_c, len_d = len_grids.astype(int)\n", - " total_points.append(math.prod(len_grids))\n", - "\n", - " dense_variables = {\n", - " \"a\": jnp.linspace(1, 10, len_a),\n", - " \"b\": jnp.linspace(0, 24, len_b),\n", - " }\n", - " complex_grids = {\n", - " \"c\": jnp.linspace(1, 5, len_c),\n", - " \"d\": jnp.linspace(-7, 2, len_d),\n", - " }\n", - " helper = jnp.array(list(itertools.product(*complex_grids.values()))).T\n", - "\n", - " complex_variables = {\n", - " \"c\": helper[0],\n", - " \"d\": helper[1],\n", - " }\n", - "\n", - " decorated_func_jit(**dense_variables, **complex_variables)\n", - "\n", - " timeit_res = %timeit -r 7 -n 1_000 -o decorated_func_jit(**dense_variables, **complex_variables).block_until_ready()\n", - " runtime.append(timeit_res.average)\n", - "\n", - " return total_points, runtime" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "26", - "metadata": {}, - "outputs": [], - "source": [ - "n_points, runtime_numba = get_numba_runtime(len_grids_arr)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "27", - "metadata": {}, - "outputs": [], - "source": [ - "_, runtime_jax = get_jax_runtime(len_grids_arr)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "28", - "metadata": {}, - "outputs": [], - "source": [ - "plot_runtime(n_points, runtime_numba, runtime_jax, func=\"state_space_map\")" - ] - }, - { - "cell_type": "markdown", - "id": "29", - "metadata": {}, - "source": [ - "## 4d Benchmark with multiple targets and product_map" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "30", - "metadata": {}, - "outputs": [], - "source": [ - "def _utility_4d(_consumption_4d, _leisure_4d):\n", - " return _consumption_4d + _leisure_4d\n", - "\n", - "\n", - "def _leisure_4d(d):\n", - " return jnp.cos(d)\n", - "\n", - "\n", - "def _consumption_4d(a, b, c):\n", - " return jnp.log(a) + jnp.sqrt(jnp.square(b) * c)\n", - "\n", - "\n", - "@njit(fastmath=True)\n", - "def product_map_numba_4d(a, b, c, d):\n", - " n_a = a.shape[0]\n", - " n_b = b.shape[0]\n", - " n_c = c.shape[0]\n", - " n_d = d.shape[0]\n", - "\n", - " out_utility = np.empty((n_a, n_b, n_c, n_d))\n", - " out_leisure = np.empty((n_a, n_b, n_c, n_d))\n", - "\n", - " for i in range(n_a):\n", - " for j in range(n_b):\n", - " for k in range(n_c):\n", - " for l in range(n_d):\n", - " _leis = np.cos(d[l])\n", - "\n", - " out_leisure[i, j, k, l] = _leis\n", - " out_utility[i, j, k, l] = (\n", - " np.log(a[i]) + np.sqrt(b[j] ** 2 * c[k]) + _leis\n", - " )\n", - "\n", - " return out_leisure, out_utility" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "31", - "metadata": {}, - "outputs": [], - "source": [ - "grids = {\n", - " \"a\": jnp.linspace(1, 10, 101),\n", - " \"b\": jnp.linspace(0, 24, 25),\n", - " \"c\": jnp.linspace(1, 5, 5),\n", - " \"d\": jnp.linspace(-7, 2, 10),\n", - "}\n", - "\n", - "a = np.linspace(1, 10, 101)\n", - "b = np.linspace(0, 24, 25)\n", - "c = np.linspace(1, 5, 5)\n", - "d = np.linspace(-7, 2, 10)\n", - "\n", - "utility_concat_4d = concatenate_functions(\n", - " functions=[_utility_4d, _leisure_4d, _consumption_4d],\n", - " targets=[\"_leisure_4d\", \"_utility_4d\"],\n", - ")\n", - "\n", - "_decorated_func = productmap(utility_concat_4d, [\"a\", \"b\", \"c\", \"d\"])\n", - "decorated_func_jit = jit(_decorated_func)\n", - "rslt_jax = decorated_func_jit(**grids)\n", - "\n", - "rslt_numba = product_map_numba_4d(a, b, c, d)\n", - "\n", - "aaae(rslt_jax, rslt_numba)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "32", - "metadata": {}, - "outputs": [], - "source": [ - "def get_numba_runtime_4d(len_grids_arr):\n", - " total_points = []\n", - " runtime = []\n", - "\n", - " for len_grids in len_grids_arr.T:\n", - " len_a, len_b, len_c, len_d = len_grids.astype(int)\n", - " total_points.append(math.prod(len_grids))\n", - "\n", - " a = np.linspace(1, 10, len_a)\n", - " b = np.linspace(0, 24, len_b)\n", - " c = np.linspace(1, 5, len_c)\n", - " d = np.linspace(-7, 2, len_d)\n", - "\n", - " product_map_numba_4d(a, b, c, d)\n", - "\n", - " timeit_res = %timeit -r 7 -n 1_000 -o product_map_numba_4d(a, b, c, d)\n", - " runtime.append(timeit_res.average)\n", - "\n", - " return total_points, runtime\n", - "\n", - "\n", - "def get_jax_runtime_4d(len_grids_arr):\n", - " utility_concat_4d = concatenate_functions(\n", - " functions=[_utility_4d, _leisure_4d, _consumption_4d],\n", - " targets=\"_utility_4d\",\n", - " )\n", - "\n", - " _decorated_func = productmap(utility_concat_4d, [\"a\", \"b\", \"c\", \"d\"])\n", - " decorated_func_jit = jit(_decorated_func)\n", - "\n", - " total_points = []\n", - " runtime = []\n", - "\n", - " for len_grids in len_grids_arr.T:\n", - " len_a, len_b, len_c, len_d = len_grids.astype(int)\n", - " total_points.append(math.prod(len_grids))\n", - "\n", - " grids = {\n", - " \"a\": jnp.linspace(1, 10, len_a),\n", - " \"b\": jnp.linspace(0, 24, len_b),\n", - " \"c\": jnp.linspace(1, 5, len_c),\n", - " \"d\": jnp.linspace(-7, 2, len_d),\n", - " }\n", - "\n", - " decorated_func_jit(**grids)\n", - "\n", - " timeit_res = %timeit -r 7 -n 1_000 -o decorated_func_jit(**grids).block_until_ready()\n", - " runtime.append(timeit_res.average)\n", - "\n", - " return total_points, runtime\n", - "\n", - "\n", - "len_grids_arr = np.vstack(\n", - " (\n", - " [5, 10, 20, 60, 100],\n", - " np.linspace(5, 25, 5),\n", - " np.linspace(1, 5, 5),\n", - " np.linspace(2, 10, 5),\n", - " ),\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "33", - "metadata": {}, - "outputs": [], - "source": [ - "n_points_4d, runtime_numba_4d = get_numba_runtime_4d(len_grids_arr)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "34", - "metadata": {}, - "outputs": [], - "source": [ - "_, runtime_jax_4d = get_jax_runtime_4d(len_grids_arr)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "35", - "metadata": {}, - "outputs": [], - "source": [ - "plot_runtime(n_points_4d, runtime_numba_4d, runtime_jax_4d)" - ] - } - ], - "metadata": { - "interpreter": { - "hash": "4ad1f9a7057235603a8e0891a7ecf17fdcab16a40a59f91c8779290453820c05" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} From 580746e1f6248a48a7b93c812bfc5f7afa02e137 Mon Sep 17 00:00:00 2001 From: Tim Mensinger Date: Wed, 26 Feb 2025 17:07:32 +0100 Subject: [PATCH 05/20] Use productmap instead of spacemap in solution --- src/lcm/solution/solve_brute.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/lcm/solution/solve_brute.py b/src/lcm/solution/solve_brute.py index 167beb9..e3622d2 100644 --- a/src/lcm/solution/solve_brute.py +++ b/src/lcm/solution/solve_brute.py @@ -4,7 +4,7 @@ import jax from jax import Array -from lcm.dispatchers import spacemap +from lcm.dispatchers import productmap from lcm.interfaces import StateChoiceSpace from lcm.typing import DiscreteProblemSolverFunction, ParamsDict @@ -105,10 +105,9 @@ def solve_continuous_problem( by the `gridmap` function. """ - _gridmapped = spacemap( + _gridmapped = productmap( func=compute_ccv, - product_vars=state_choice_space.ordered_var_names, - combination_vars=(), + variables=state_choice_space.ordered_var_names, ) gridmapped = jax.jit(_gridmapped) From fda2df72bc32a7dd95692c456a67d6261a4d5751 Mon Sep 17 00:00:00 2001 From: Tim Mensinger Date: Wed, 26 Feb 2025 17:21:16 +0100 Subject: [PATCH 06/20] Use spacemap only for simulation --- src/lcm/dispatchers.py | 60 +++++++++++++++------------------- src/lcm/simulation/simulate.py | 8 ++--- tests/test_dispatchers.py | 12 ++++--- 3 files changed, 38 insertions(+), 42 deletions(-) diff --git a/src/lcm/dispatchers.py b/src/lcm/dispatchers.py index 35ba50b..c217602 100644 --- a/src/lcm/dispatchers.py +++ b/src/lcm/dispatchers.py @@ -12,64 +12,58 @@ ) -def spacemap( +def simulation_spacemap( func: FunctionWithArrayReturn, - product_vars: tuple[str, ...], - combination_vars: tuple[str, ...], + choices_var_names: tuple[str, ...], + states_var_names: tuple[str, ...], ) -> FunctionWithArrayReturn: - """Apply vmap such that func can be evaluated on product and combination variables. + """Apply vmap such that func can be evaluated on choices and simulation states. - Product variables are used to create a Cartesian product of possible values. I.e., - for each product variable, we create a new leading dimension in the output object, + Choice variables are used to create a Cartesian product of possible values. I.e., + for each choice variable, we create a new leading dimension in the output object, with the size of the dimension being the number of possible values in the grid. The - i-th entries of the combination variables, correspond to one valid combination. For - the combination variables, a single dimension is thus added to the output object, - with the size of the dimension being the number of possible combinations. This means - that all combination variables must have the same size (e.g., in the simulation the - states act as combination variables, and their size equals the number of - simulations). + i-th entries of the state variables, correspond to one simulation state. For + the state variables, a single dimension is thus added to the output object, + with the size of the dimension being the number of simulations. This means + that all state variables must have the same size. - spacemap preserves the function signature and allows the function to be called with - keyword arguments. + simulation_spacemap preserves the function signature and allows the function to be + called with keyword arguments. Args: func: The function to be dispatched. - product_vars: Names of the product variables, i.e. those that are stored as + choices_var_names: Names of the choice variables, i.e. those that are stored as arrays of possible values in the grid, over which we create a Cartesian product. - combination_vars: Names of the combination variables, i.e. those that are - stored as arrays of possible combinations. + states_var_names: Names of the state variables, i.e. those that are stored as + arrays of possible states. Returns: A callable with the same arguments as func (but with an additional leading - dimension) that returns a jax.Array or pytree of arrays. If `func` returns a - scalar, the dispatched function returns a jax.Array with k + 1 dimensions, where - k is the length of `product_vars` and the additional dimension corresponds to - the `combination_vars`. The order of the dimensions is determined by the order - of `product_vars`. If the output of `func` is a jax pytree, the usual jax + dimension) that returns an Array or pytree of Arrays. If `func` returns a + scalar, the dispatched function returns an Array with k + 1 dimensions, where k + is the length of `choices_var_names` and the additional dimension corresponds to + the `states_var_names`. The order of the dimensions is determined by the order + of `choices_var_names`. If the output of `func` is a jax pytree, the usual jax behavior applies, i.e. the leading dimensions of all arrays in the pytree are as described above but there might be additional dimensions. """ - if duplicates := find_duplicates(product_vars, combination_vars): + if duplicates := find_duplicates(choices_var_names, states_var_names): msg = ( - "Same argument provided more than once in product variables or combination " - f"variables, or is present in both: {duplicates}" + "Same argument provided more than once in choices or states variables, " + f"or is present in both: {duplicates}" ) raise ValueError(msg) - func_callable_with_args = allow_args(func) - - vmapped = _base_productmap(func_callable_with_args, product_vars) + mappable_func = allow_args(func) - if combination_vars: - vmapped = vmap_1d( - vmapped, variables=combination_vars, callable_with="only_args" - ) + vmapped = _base_productmap(mappable_func, choices_var_names) + vmapped = vmap_1d(vmapped, variables=states_var_names, callable_with="only_args") # This raises a mypy error but is perfectly fine to do. See # https://github.com/python/mypy/issues/12472 - vmapped.__signature__ = inspect.signature(func_callable_with_args) # type: ignore[attr-defined] + vmapped.__signature__ = inspect.signature(mappable_func) # type: ignore[attr-defined] return cast(FunctionWithArrayReturn, allow_only_kwargs(vmapped)) diff --git a/src/lcm/simulation/simulate.py b/src/lcm/simulation/simulate.py index 758e50b..9f7a079 100644 --- a/src/lcm/simulation/simulate.py +++ b/src/lcm/simulation/simulate.py @@ -11,7 +11,7 @@ from jax import Array, vmap from lcm.argmax import argmax -from lcm.dispatchers import spacemap, vmap_1d +from lcm.dispatchers import simulation_spacemap, vmap_1d from lcm.interfaces import InternalModel, StateChoiceSpace from lcm.typing import InternalUserFunction, ParamsDict @@ -236,10 +236,10 @@ def solve_continuous_problem( `gridmap` function. """ - _gridmapped = spacemap( + _gridmapped = simulation_spacemap( func=compute_ccv, - product_vars=tuple(data_scs.choices), - combination_vars=tuple(data_scs.states), + choices_var_names=tuple(data_scs.choices), + states_var_names=tuple(data_scs.states), ) gridmapped = jax.jit(_gridmapped) diff --git a/tests/test_dispatchers.py b/tests/test_dispatchers.py index d9c467f..1534e74 100644 --- a/tests/test_dispatchers.py +++ b/tests/test_dispatchers.py @@ -6,7 +6,7 @@ from lcm.dispatchers import ( productmap, - spacemap, + simulation_spacemap, vmap_1d, ) from lcm.functools import allow_args @@ -231,7 +231,7 @@ def test_spacemap_all_arguments_mapped( ): product_vars, combination_vars = setup_spacemap - decorated = spacemap( + decorated = simulation_spacemap( g, tuple(product_vars), tuple(combination_vars), @@ -245,12 +245,12 @@ def test_spacemap_all_arguments_mapped( ("error_msg", "product_vars", "combination_vars"), [ ( - "Same argument provided more than once in product variables or combination", + "Same argument provided more than once in choices or states variables", ["a", "b"], ["a", "c", "d"], ), ( - "Same argument provided more than once in product variables or combination", + "Same argument provided more than once in choices or states variables", ["a", "a", "b"], ["c", "d"], ), @@ -258,7 +258,9 @@ def test_spacemap_all_arguments_mapped( ) def test_spacemap_arguments_overlap(error_msg, product_vars, combination_vars): with pytest.raises(ValueError, match=error_msg): - spacemap(g, product_vars=product_vars, combination_vars=combination_vars) + simulation_spacemap( + g, choices_var_names=product_vars, states_var_names=combination_vars + ) # ====================================================================================== From 30b9682fcd41463a7f72b594a210ee0c3c8b9151 Mon Sep 17 00:00:00 2001 From: Tim Mensinger Date: Wed, 26 Feb 2025 18:03:48 +0100 Subject: [PATCH 07/20] Integrate comments from previous reviews --- src/lcm/dispatchers.py | 17 +++++++++-------- src/lcm/simulation/simulate.py | 28 ++++++++++++++++------------ src/lcm/utils.py | 11 +++++++++++ 3 files changed, 36 insertions(+), 20 deletions(-) diff --git a/src/lcm/dispatchers.py b/src/lcm/dispatchers.py index c217602..85b64c6 100644 --- a/src/lcm/dispatchers.py +++ b/src/lcm/dispatchers.py @@ -19,13 +19,14 @@ def simulation_spacemap( ) -> FunctionWithArrayReturn: """Apply vmap such that func can be evaluated on choices and simulation states. - Choice variables are used to create a Cartesian product of possible values. I.e., - for each choice variable, we create a new leading dimension in the output object, - with the size of the dimension being the number of possible values in the grid. The - i-th entries of the state variables, correspond to one simulation state. For - the state variables, a single dimension is thus added to the output object, - with the size of the dimension being the number of simulations. This means - that all state variables must have the same size. + This function maps the function `func` over the simulation state-choice-space. That + is, it maps `func` over the Cartesian product of the choice variables, and over the + fixed simulation states. For each choice variable, a leading dimension is added to + the output object, with the length of the axis being the number of possible values + in the grid. Importantly, it does not create a Cartesian product over the state + variables, since these are fixed during the simulation. For the state variables, + a single dimension is added to the output object, with the length of the axis + being the number of simulated states. simulation_spacemap preserves the function signature and allows the function to be called with keyword arguments. @@ -204,7 +205,7 @@ def _base_productmap( # We iterate in reverse order such that the output dimensions are in the same order # as the input dimensions. for pos in reversed(positions): - spec = [None] * len(parameters) # type: list[int | None] + spec: list[int | None] = [None] * len(parameters) spec[pos] = 0 vmap_specs.append(spec) diff --git a/src/lcm/simulation/simulate.py b/src/lcm/simulation/simulate.py index 9f7a079..adaa094 100644 --- a/src/lcm/simulation/simulate.py +++ b/src/lcm/simulation/simulate.py @@ -14,6 +14,7 @@ from lcm.dispatchers import simulation_spacemap, vmap_1d from lcm.interfaces import InternalModel, StateChoiceSpace from lcm.typing import InternalUserFunction, ParamsDict +from lcm.utils import draw_random_seed def simulate( @@ -27,7 +28,7 @@ def simulate( solve_model: Callable[..., list[Array]] | None = None, pre_computed_vf_arr_list: list[Array] | None = None, additional_targets: list[str] | None = None, - seed: int = 12345, + seed: int | None = None, ) -> pd.DataFrame: """Simulate the model forward in time. @@ -52,7 +53,8 @@ def simulate( provided, the model is solved first. additional_targets: List of targets to compute. If provided, the targets are computed and added to the simulation results. - seed: Random number seed; will be passed to `jax.random.key`. + seed: Random number seed; will be passed to `jax.random.key`. If not provided, + a random seed will be generated. Returns: DataFrame with the simulation results. @@ -69,22 +71,24 @@ def simulate( else: vf_arr_list = pre_computed_vf_arr_list + if seed is None: + seed = draw_random_seed() + logger.info("Starting simulation") - # Update the vf_arr_list - # ---------------------------------------------------------------------------------- + # Preparations + # ================================================================================== + n_periods = len(vf_arr_list) + n_initial_states = len(next(iter(initial_states.values()))) + # We drop the value function array for the first period, because it is not needed # for the simulation. This is because in the first period the agents only consider # the current utility and the value function of next period. Similarly, the last # value function array is not required, as the agents only consider the current # utility in the last period. - # ================================================================================== - vf_arr_list = vf_arr_list[1:] + [jnp.empty(0)] - - # Preparations - # ================================================================================== - n_periods = len(vf_arr_list) - n_initial_states = len(next(iter(initial_states.values()))) + next_vf_arr = dict( + zip(range(n_periods), vf_arr_list[1:] + [jnp.empty(0)], strict=True) + ) discrete_policy_calculator = get_discrete_policy_calculator( variable_info=model.variable_info, @@ -124,7 +128,7 @@ def simulate( data_scs=data_scs, compute_ccv=compute_ccv_policy_functions[period], continuous_choice_grids=continuous_choice_grids[period], - vf_arr=vf_arr_list[period], + vf_arr=next_vf_arr[period], params=params, ) diff --git a/src/lcm/utils.py b/src/lcm/utils.py index 86d2fd4..1f951ee 100644 --- a/src/lcm/utils.py +++ b/src/lcm/utils.py @@ -1,3 +1,4 @@ +import os from collections import Counter from collections.abc import Iterable from itertools import chain @@ -10,3 +11,13 @@ def find_duplicates(*containers: Iterable[T]) -> set[T]: combined = chain.from_iterable(containers) counts = Counter(combined) return {v for v, count in counts.items() if count > 1} + + +def draw_random_seed() -> int: + """Generate a random seed using the operating system's secure entropy pool. + + Returns: + Random seed. + + """ + return int.from_bytes(os.urandom(4), "little") From cc5550352991ae5344608d7b937b1970a3551b25 Mon Sep 17 00:00:00 2001 From: Tim Mensinger Date: Wed, 26 Feb 2025 18:33:32 +0100 Subject: [PATCH 08/20] Create conditional continuation module --- src/lcm/conditional_continuation.py | 80 ++++++++++++++++++++++ src/lcm/entry_point.py | 96 +++------------------------ src/lcm/utility_and_feasibility.py | 12 ++-- tests/simulation/test_simulate.py | 12 ++-- tests/solution/test_solve_brute.py | 10 +-- tests/test_entry_point.py | 32 ++++----- tests/test_utility_and_feasibility.py | 2 +- 7 files changed, 123 insertions(+), 121 deletions(-) create mode 100644 src/lcm/conditional_continuation.py diff --git a/src/lcm/conditional_continuation.py b/src/lcm/conditional_continuation.py new file mode 100644 index 0000000..1f769e6 --- /dev/null +++ b/src/lcm/conditional_continuation.py @@ -0,0 +1,80 @@ +import functools +from collections.abc import Callable + +import jax.numpy as jnp +from jax import Array + +from lcm.argmax import argmax +from lcm.dispatchers import productmap +from lcm.typing import ParamsDict + + +def get_compute_conditional_continuation_value( + utility_and_feasibility: Callable[..., tuple[Array, Array]], + continuous_choice_variables: tuple[str, ...], +) -> Callable[..., Array]: + """Get a function that computes the conditional continuation value. + + This function solves the continuous choice problem conditional on a state- + (discrete-)choice combination; and is used in the model solution process. + + Args: + utility_and_feasibility: A function that takes a state-choice combination and + return the utility of that combination (scalar) and whether the state-choice + combination is feasible (bool). + continuous_choice_variables: Tuple of choice variable names that are continuous. + + Returns: + A function that takes a state-choice combination and returns the conditional + continuation value over the continuous choices. + + """ + if continuous_choice_variables: + utility_and_feasibility = productmap( + func=utility_and_feasibility, + variables=continuous_choice_variables, + ) + + @functools.wraps(utility_and_feasibility) + def compute_ccv(params: ParamsDict, **kwargs: Array) -> Array: + u, f = utility_and_feasibility(params=params, **kwargs) + return u.max(where=f, initial=-jnp.inf) + + return compute_ccv + + +def get_compute_conditional_continuation_policy( + utility_and_feasibility: Callable[..., tuple[Array, Array]], + continuous_choice_variables: tuple[str, ...], +) -> Callable[..., tuple[Array, Array]]: + """Get a function that computes the conditional continuation policy. + + This function solves the continuous choice problem conditional on a state- + (discrete-)choice combination; and is used in the model simulation process. + + Args: + utility_and_feasibility: A function that takes a state-choice combination and + return the utility of that combination (scalar) and whether the state-choice + combination is feasible (bool). + continuous_choice_variables: Tuple of choice variable names that are + continuous. + + Returns: + A function that takes a state-choice combination and returns the conditional + continuation value over the continuous choices, and the index that maximizes the + conditional continuation value. + + """ + if continuous_choice_variables: + utility_and_feasibility = productmap( + func=utility_and_feasibility, + variables=continuous_choice_variables, + ) + + @functools.wraps(utility_and_feasibility) + def compute_ccp(params: ParamsDict, **kwargs: Array) -> tuple[Array, Array]: + u, f = utility_and_feasibility(params=params, **kwargs) + _argmax, _max = argmax(u, where=f, initial=-jnp.inf) + return _argmax, _max + + return compute_ccp diff --git a/src/lcm/entry_point.py b/src/lcm/entry_point.py index 0b55868..6427707 100644 --- a/src/lcm/entry_point.py +++ b/src/lcm/entry_point.py @@ -1,16 +1,16 @@ -import functools from collections.abc import Callable from functools import partial from typing import Literal import jax -import jax.numpy as jnp import pandas as pd from jax import Array -from lcm.argmax import argmax +from lcm.conditional_continuation import ( + get_compute_conditional_continuation_policy, + get_compute_conditional_continuation_value, +) from lcm.discrete_problem import get_solve_discrete_problem -from lcm.dispatchers import productmap from lcm.input_processing import process_model from lcm.logging import get_logger from lcm.next_state import get_next_state_function @@ -114,20 +114,20 @@ def get_lcm_function( # ============================================================================== u_and_f = get_utility_and_feasibility_function( model=_mod, - state_space_info=state_space_infos[period], + next_state_space_info=state_space_infos[period], period=period, is_last_period=is_last_period, ) - compute_ccv = create_compute_conditional_continuation_value( + compute_ccv = get_compute_conditional_continuation_value( utility_and_feasibility=u_and_f, - continuous_choice_variables=list(_choice_grids), + continuous_choice_variables=tuple(_choice_grids), ) compute_ccv_functions.append(compute_ccv) - compute_ccv_argmax = create_compute_conditional_continuation_policy( + compute_ccv_argmax = get_compute_conditional_continuation_policy( utility_and_feasibility=u_and_f, - continuous_choice_variables=list(_choice_grids), + continuous_choice_variables=tuple(_choice_grids), ) compute_ccv_policy_functions.append(compute_ccv_argmax) @@ -174,81 +174,3 @@ def get_lcm_function( target_func = partial(simulate_model, solve_model=solve_model) return target_func, _mod.params - - -def create_compute_conditional_continuation_value( - utility_and_feasibility: Callable[..., tuple[Array, Array]], - continuous_choice_variables: list[str], -) -> Callable[..., Array]: - """Create a function that computes the conditional continuation value. - - Note: - ----- - This function solves the continuous choice problem conditional on a state- - (discrete-)choice combination. - - Args: - utility_and_feasibility (callable): A function that takes a state-choice - combination and return the utility of that combination (float) and whether - the state-choice combination is feasible (bool). - continuous_choice_variables (list): List of choice variable names that are - continuous. - - Returns: - A function that takes a state-choice combination and returns the conditional - continuation value over the continuous choices. - - """ - if continuous_choice_variables: - utility_and_feasibility = productmap( - func=utility_and_feasibility, - variables=tuple(continuous_choice_variables), - ) - - @functools.wraps(utility_and_feasibility) - def compute_ccv(*args: Array | ParamsDict, **kwargs: Array | ParamsDict) -> Array: - u, f = utility_and_feasibility(*args, **kwargs) - return u.max(where=f, initial=-jnp.inf) - - return compute_ccv - - -def create_compute_conditional_continuation_policy( - utility_and_feasibility: Callable[..., tuple[Array, Array]], - continuous_choice_variables: list[str], -) -> Callable[..., tuple[Array, Array]]: - """Create a function that computes the conditional continuation policy. - - Note: - ----- - This function solves the continuous choice problem conditional on a state- - (discrete-)choice combination. - - Args: - utility_and_feasibility (callable): A function that takes a state-choice - combination and return the utility of that combination (float) and whether - the state-choice combination is feasible (bool). - continuous_choice_variables (list): List of choice variable names that are - continuous. - - Returns: - A function that takes a state-choice combination and returns the conditional - continuation value over the continuous choices, and the index that maximizes the - conditional continuation value. - - """ - if continuous_choice_variables: - utility_and_feasibility = productmap( - func=utility_and_feasibility, - variables=tuple(continuous_choice_variables), - ) - - @functools.wraps(utility_and_feasibility) - def compute_ccv_policy( - *args: Array | ParamsDict, **kwargs: Array | ParamsDict - ) -> tuple[Array, Array]: - u, f = utility_and_feasibility(*args, **kwargs) - _argmax, _max = argmax(u, where=f, initial=-jnp.inf) - return _argmax, _max - - return compute_ccv_policy diff --git a/src/lcm/utility_and_feasibility.py b/src/lcm/utility_and_feasibility.py index b510c86..f86f111 100644 --- a/src/lcm/utility_and_feasibility.py +++ b/src/lcm/utility_and_feasibility.py @@ -16,7 +16,7 @@ def get_utility_and_feasibility_function( model: InternalModel, - state_space_info: StateSpaceInfo, + next_state_space_info: StateSpaceInfo, period: int, *, is_last_period: bool, @@ -25,7 +25,7 @@ def get_utility_and_feasibility_function( Args: model: The internal model object. - state_space_info: The state space information. + next_state_space_info: The state space information of the next period. period: The period to create the utility and feasibility function for. is_last_period: Whether the period is the last period. @@ -37,20 +37,20 @@ def get_utility_and_feasibility_function( if is_last_period: return get_utility_and_feasibility_function_last_period(model, period=period) return get_utility_and_feasibility_function_before_last_period( - model, state_space_info=state_space_info, period=period + model, next_state_space_info=next_state_space_info, period=period ) def get_utility_and_feasibility_function_before_last_period( model: InternalModel, - state_space_info: StateSpaceInfo, + next_state_space_info: StateSpaceInfo, period: int, ) -> Callable[..., tuple[Array, Array]]: """Create the utility and feasibility function for a period before the last period. Args: model: The internal model object. - state_space_info: The state space information. + next_state_space_info: The state space information of the next period. period: The period to create the utility and feasibility function for. Returns: @@ -70,7 +70,7 @@ def get_utility_and_feasibility_function_before_last_period( calculate_state_transition = get_next_state_function(model, target=Target.SOLVE) calculate_next_weights = get_next_stochastic_weights_function(model) calculate_node_weights = _get_node_weights_function(stochastic_variables) - _scalar_value_function = get_value_function_representation(state_space_info) + _scalar_value_function = get_value_function_representation(next_state_space_info) value_function = productmap( _scalar_value_function, variables=tuple(f"next_{var}" for var in stochastic_variables), diff --git a/tests/simulation/test_simulate.py b/tests/simulation/test_simulate.py index 32aa03f..63cb891 100644 --- a/tests/simulation/test_simulate.py +++ b/tests/simulation/test_simulate.py @@ -4,10 +4,10 @@ from numpy.testing import assert_array_almost_equal, assert_array_equal from pybaum import tree_equal -from lcm.entry_point import ( - create_compute_conditional_continuation_policy, - get_lcm_function, +from lcm.conditional_continuation import ( + get_compute_conditional_continuation_policy, ) +from lcm.entry_point import get_lcm_function from lcm.input_processing import process_model from lcm.logging import get_logger from lcm.next_state import get_next_state_function @@ -49,13 +49,13 @@ def simulate_inputs(): for period in range(model.n_periods): u_and_f = get_utility_and_feasibility_function( model=model, - state_space_info=state_space_info, + next_state_space_info=state_space_info, period=period, is_last_period=True, ) - compute_ccv = create_compute_conditional_continuation_policy( + compute_ccv = get_compute_conditional_continuation_policy( utility_and_feasibility=u_and_f, - continuous_choice_variables=["consumption"], + continuous_choice_variables=("consumption",), ) compute_ccv_policy_functions.append(compute_ccv) diff --git a/tests/solution/test_solve_brute.py b/tests/solution/test_solve_brute.py index 5b0450b..075fd1e 100644 --- a/tests/solution/test_solve_brute.py +++ b/tests/solution/test_solve_brute.py @@ -2,7 +2,7 @@ import numpy as np from numpy.testing import assert_array_almost_equal as aaae -from lcm.entry_point import create_compute_conditional_continuation_value +from lcm.conditional_continuation import get_compute_conditional_continuation_value from lcm.interfaces import StateChoiceSpace from lcm.logging import get_logger from lcm.ndimage import map_coordinates @@ -79,9 +79,9 @@ def _get_continuation_value(lazy, wealth, vf_arr): coordinates=jnp.array([wealth]), ) - compute_ccv = create_compute_conditional_continuation_value( + compute_ccv = get_compute_conditional_continuation_value( utility_and_feasibility=_utility_and_feasibility, - continuous_choice_variables=["consumption"], + continuous_choice_variables=("consumption",), ) utility_and_feasibility_functions = [compute_ccv] * 2 @@ -130,9 +130,9 @@ def _utility_and_feasibility(a, c, b, d, vf_arr, params): # noqa: ARG001 continuous_choice_grids = {"d": jnp.arange(12.0)} - compute_ccv = create_compute_conditional_continuation_value( + compute_ccv = get_compute_conditional_continuation_value( utility_and_feasibility=_utility_and_feasibility, - continuous_choice_variables=["d"], + continuous_choice_variables=("d",), ) expected = np.array([[[6.0, 7, 8], [7, 8, 9]], [[7, 8, 9], [8, 9, 10]]]) diff --git a/tests/test_entry_point.py b/tests/test_entry_point.py index 77e887d..4e45e93 100644 --- a/tests/test_entry_point.py +++ b/tests/test_entry_point.py @@ -2,11 +2,11 @@ import pytest from pybaum import tree_equal, tree_map -from lcm.entry_point import ( - create_compute_conditional_continuation_policy, - create_compute_conditional_continuation_value, - get_lcm_function, +from lcm.conditional_continuation import ( + get_compute_conditional_continuation_policy, + get_compute_conditional_continuation_value, ) +from lcm.entry_point import get_lcm_function from lcm.input_processing import process_model from lcm.solution.state_choice_space import create_state_choice_space from lcm.utility_and_feasibility import get_utility_and_feasibility_function @@ -189,14 +189,14 @@ def test_create_compute_conditional_continuation_value(): u_and_f = get_utility_and_feasibility_function( model=model, - state_space_info=state_space_info, + next_state_space_info=state_space_info, period=model.n_periods - 1, is_last_period=True, ) - compute_ccv = create_compute_conditional_continuation_value( + compute_ccv = get_compute_conditional_continuation_value( utility_and_feasibility=u_and_f, - continuous_choice_variables=["consumption"], + continuous_choice_variables=("consumption",), ) val = compute_ccv( @@ -234,14 +234,14 @@ def test_create_compute_conditional_continuation_value_with_discrete_model(): u_and_f = get_utility_and_feasibility_function( model=model, - state_space_info=state_space_info, + next_state_space_info=state_space_info, period=model.n_periods - 1, is_last_period=True, ) - compute_ccv = create_compute_conditional_continuation_value( + compute_ccv = get_compute_conditional_continuation_value( utility_and_feasibility=u_and_f, - continuous_choice_variables=[], + continuous_choice_variables=(), ) val = compute_ccv( @@ -284,14 +284,14 @@ def test_create_compute_conditional_continuation_policy(): u_and_f = get_utility_and_feasibility_function( model=model, - state_space_info=state_space_info, + next_state_space_info=state_space_info, period=model.n_periods - 1, is_last_period=True, ) - compute_ccv_policy = create_compute_conditional_continuation_policy( + compute_ccv_policy = get_compute_conditional_continuation_policy( utility_and_feasibility=u_and_f, - continuous_choice_variables=["consumption"], + continuous_choice_variables=("consumption",), ) policy, val = compute_ccv_policy( @@ -330,14 +330,14 @@ def test_create_compute_conditional_continuation_policy_with_discrete_model(): u_and_f = get_utility_and_feasibility_function( model=model, - state_space_info=state_space_info, + next_state_space_info=state_space_info, period=model.n_periods - 1, is_last_period=True, ) - compute_ccv_policy = create_compute_conditional_continuation_policy( + compute_ccv_policy = get_compute_conditional_continuation_policy( utility_and_feasibility=u_and_f, - continuous_choice_variables=[], + continuous_choice_variables=(), ) policy, val = compute_ccv_policy( diff --git a/tests/test_utility_and_feasibility.py b/tests/test_utility_and_feasibility.py index 9a89380..3b08a11 100644 --- a/tests/test_utility_and_feasibility.py +++ b/tests/test_utility_and_feasibility.py @@ -39,7 +39,7 @@ def test_get_utility_and_feasibility_function(): u_and_f = get_utility_and_feasibility_function( model=model, - state_space_info=state_space_info, + next_state_space_info=state_space_info, period=model.n_periods - 1, is_last_period=True, ) From c88f368a1c65cc3e4f24e84b1169561197f82149 Mon Sep 17 00:00:00 2001 From: Tim Mensinger Date: Wed, 26 Feb 2025 19:03:37 +0100 Subject: [PATCH 09/20] Refactor entry point --- src/lcm/entry_point.py | 106 ++++++++++++++++---------------- src/lcm/simulation/simulate.py | 13 ++-- src/lcm/solution/solve_brute.py | 16 ++--- 3 files changed, 66 insertions(+), 69 deletions(-) diff --git a/src/lcm/entry_point.py b/src/lcm/entry_point.py index 6427707..dffcbbb 100644 --- a/src/lcm/entry_point.py +++ b/src/lcm/entry_point.py @@ -12,12 +12,13 @@ ) from lcm.discrete_problem import get_solve_discrete_problem from lcm.input_processing import process_model +from lcm.interfaces import StateChoiceSpace, StateSpaceInfo from lcm.logging import get_logger from lcm.next_state import get_next_state_function from lcm.simulation.simulate import simulate from lcm.solution.solve_brute import solve from lcm.solution.state_choice_space import create_state_choice_space -from lcm.typing import ParamsDict, Target +from lcm.typing import DiscreteProblemSolverFunction, ParamsDict, Target from lcm.user_model import Model from lcm.utility_and_feasibility import ( get_utility_and_feasibility_function, @@ -59,62 +60,49 @@ def get_lcm_function( if targets not in {"solve", "simulate", "solve_and_simulate"}: raise NotImplementedError - _mod = process_model(model) - last_period = _mod.n_periods - 1 + internal_model = process_model(model) + last_period = internal_model.n_periods - 1 logger = get_logger(debug_mode=debug_mode) # ================================================================================== - # create list of continuous choice grids + # Create continuous choice grids + # ---------------------------------------------------------------------------------- + # For now they are the same in all periods but this can change. # ================================================================================== - # for now they are the same in all periods but this will change. - _subset = _mod.variable_info.query("is_continuous & is_choice").index.tolist() - _choice_grids = {k: _mod.grids[k] for k in _subset} - continuous_choice_grids = [_choice_grids] * _mod.n_periods + continuous_choices_names = internal_model.variable_info.query( + "is_continuous & is_choice" + ).index.tolist() + _choice_grids = {n: internal_model.grids[n] for n in continuous_choices_names} + continuous_choice_grids = { + period: _choice_grids for period in range(internal_model.n_periods) + } # ================================================================================== - # Initialize other argument lists + # Create model functions and state-choice-spaces # ================================================================================== - state_choice_spaces = [] - state_space_infos = [] - compute_ccv_functions = [] - compute_ccv_policy_functions = [] - choice_segments = [] # type: ignore[var-annotated] - emax_calculators = [] + state_choice_spaces: dict[int, StateChoiceSpace] = {} + state_space_infos: dict[int, StateSpaceInfo] = {} + compute_ccv_functions: dict[int, Callable[[Array, Array], Array]] = {} + compute_ccp_functions: dict[int, Callable[..., tuple[Array, Array]]] = {} + solve_discrete_problem_functions: dict[int, DiscreteProblemSolverFunction] = {} - # ================================================================================== - # Create stace choice space for each period - # ================================================================================== - for period in range(_mod.n_periods): + for period in reversed(range(internal_model.n_periods)): is_last_period = period == last_period - # call state space creation function, append trivial items to their lists - # ============================================================================== state_choice_space, state_space_info = create_state_choice_space( - model=_mod, + model=internal_model, is_last_period=is_last_period, ) - state_choice_spaces.append(state_choice_space) - choice_segments.append(None) - state_space_infos.append(state_space_info) - - # ================================================================================== - # Shift space info (in period t we require the space info of period t+1) - # ================================================================================== - state_space_infos = state_space_infos[1:] + [{}] # type: ignore[list-item] - - # ================================================================================== - # Create model functions - # ================================================================================== - for period in range(_mod.n_periods): - is_last_period = period == last_period + if is_last_period: + next_state_space_info = LastPeriodsNextStateSpaceInfo + else: + next_state_space_info = state_space_infos[period + 1] - # create the compute conditional continuation value functions and append to list - # ============================================================================== u_and_f = get_utility_and_feasibility_function( - model=_mod, - next_state_space_info=state_space_infos[period], + model=internal_model, + next_state_space_info=next_state_space_info, period=period, is_last_period=is_last_period, ) @@ -123,22 +111,23 @@ def get_lcm_function( utility_and_feasibility=u_and_f, continuous_choice_variables=tuple(_choice_grids), ) - compute_ccv_functions.append(compute_ccv) - compute_ccv_argmax = get_compute_conditional_continuation_policy( + compute_ccp = get_compute_conditional_continuation_policy( utility_and_feasibility=u_and_f, continuous_choice_variables=tuple(_choice_grids), ) - compute_ccv_policy_functions.append(compute_ccv_argmax) - # create list of emax_calculators - # ============================================================================== - calculator = get_solve_discrete_problem( - random_utility_shock_type=_mod.random_utility_shocks, - variable_info=_mod.variable_info, + solve_discrete_problem = get_solve_discrete_problem( + random_utility_shock_type=internal_model.random_utility_shocks, + variable_info=internal_model.variable_info, is_last_period=is_last_period, ) - emax_calculators.append(calculator) + + state_choice_spaces[period] = state_choice_space + state_space_infos[period] = state_space_info + compute_ccv_functions[period] = compute_ccv + compute_ccp_functions[period] = compute_ccp + solve_discrete_problem_functions[period] = solve_discrete_problem # ================================================================================== # select requested solver and partial arguments into it @@ -148,18 +137,20 @@ def get_lcm_function( state_choice_spaces=state_choice_spaces, continuous_choice_grids=continuous_choice_grids, compute_ccv_functions=compute_ccv_functions, - emax_calculators=emax_calculators, + emax_calculators=solve_discrete_problem_functions, logger=logger, ) solve_model = jax.jit(_solve_model) if jit else _solve_model - _next_state_simulate = get_next_state_function(model=_mod, target=Target.SIMULATE) + _next_state_simulate = get_next_state_function( + model=internal_model, target=Target.SIMULATE + ) simulate_model = partial( simulate, continuous_choice_grids=continuous_choice_grids, - compute_ccv_policy_functions=compute_ccv_policy_functions, - model=_mod, + compute_ccv_policy_functions=compute_ccp_functions, + model=internal_model, next_state=jax.jit(_next_state_simulate), logger=logger, ) @@ -173,4 +164,11 @@ def get_lcm_function( elif targets == "solve_and_simulate": target_func = partial(simulate_model, solve_model=solve_model) - return target_func, _mod.params + return target_func, internal_model.params + + +LastPeriodsNextStateSpaceInfo = StateSpaceInfo( + states_names=(), + discrete_states={}, + continuous_states={}, +) diff --git a/src/lcm/simulation/simulate.py b/src/lcm/simulation/simulate.py index adaa094..34f61bb 100644 --- a/src/lcm/simulation/simulate.py +++ b/src/lcm/simulation/simulate.py @@ -20,8 +20,8 @@ def simulate( params: ParamsDict, initial_states: dict[str, Array], - continuous_choice_grids: list[dict[str, Array]], - compute_ccv_policy_functions: list[Callable[..., tuple[Array, Array]]], + continuous_choice_grids: dict[int, dict[str, Array]], + compute_ccv_policy_functions: dict[int, Callable[..., tuple[Array, Array]]], model: InternalModel, next_state: Callable[..., dict[str, Array]], logger: logging.Logger, @@ -36,11 +36,10 @@ def simulate( params: Dict of model parameters. initial_states: List of initial states to start from. Typically from the observed dataset. - continuous_choice_grids: List of dicts of length n_periods. Each dict - contains 1d grids for continuous choice variables. - compute_ccv_policy_functions: List of functions of length n_periods. Each - function computes the conditional continuation value dependent on the - discrete choices. + continuous_choice_grids: Dict of length n_periods. Each dict contains 1d grids + for continuous choice variables. + compute_ccv_policy_functions: Dict of length n_periods. Each function computes + the conditional continuation value dependent on the discrete choices. next_state: Function that returns the next state given the current state and choice variables. For stochastic variables, it returns a random draw from the distribution of the next state. diff --git a/src/lcm/solution/solve_brute.py b/src/lcm/solution/solve_brute.py index e3622d2..44e325e 100644 --- a/src/lcm/solution/solve_brute.py +++ b/src/lcm/solution/solve_brute.py @@ -11,10 +11,10 @@ def solve( params: ParamsDict, - state_choice_spaces: list[StateChoiceSpace], - continuous_choice_grids: list[dict[str, Array]], - compute_ccv_functions: list[Callable[[Array, Array], Array]], - emax_calculators: list[DiscreteProblemSolverFunction], + state_choice_spaces: dict[int, StateChoiceSpace], + continuous_choice_grids: dict[int, dict[str, Array]], + compute_ccv_functions: dict[int, Callable[[Array, Array], Array]], + emax_calculators: dict[int, DiscreteProblemSolverFunction], logger: logging.Logger, ) -> list[Array]: """Solve a model by brute force. @@ -30,10 +30,10 @@ def solve( Args: params: Dict of model parameters. - state_choice_spaces: List with one state_choice_space per period. - continuous_choice_grids: List of dicts with 1d grids for continuous - choice variables. - compute_ccv_functions: List of functions needed to solve the agent's + state_choice_spaces: Dict with one state_choice_space per period. + continuous_choice_grids: Dict with one dict of 1d grids for continuous + choice variables per period. + compute_ccv_functions: Dict with one function needed to solve the agent's problem. Each function depends on: - discrete and continuous state variables - discrete and continuous choice variables From dab3594ccb60fdd05d00e4a63c69135c43899664 Mon Sep 17 00:00:00 2001 From: Tim Mensinger Date: Wed, 26 Feb 2025 19:29:59 +0100 Subject: [PATCH 10/20] Fix mypy errors --- tests/solution/test_solve_brute.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/solution/test_solve_brute.py b/tests/solution/test_solve_brute.py index 075fd1e..11ce8fb 100644 --- a/tests/solution/test_solve_brute.py +++ b/tests/solution/test_solve_brute.py @@ -38,7 +38,7 @@ def test_solve_brute(): }, ordered_var_names=("lazy", "working", "wealth"), ) - state_choice_spaces = [_scs] * 2 + state_choice_spaces = {0: _scs, 1: _scs} # ================================================================================== # create continuous choice grids @@ -47,7 +47,7 @@ def test_solve_brute(): # you 1 if working and have at most 2 in existing wealth, so _ccg = {"consumption": jnp.array([0, 1, 2, 3])} - continuous_choice_grids = [_ccg] * 2 + continuous_choice_grids = {0: _ccg, 1: _ccg} # ================================================================================== # create the utility_and_feasibility functions @@ -84,7 +84,7 @@ def _get_continuation_value(lazy, wealth, vf_arr): continuous_choice_variables=("consumption",), ) - utility_and_feasibility_functions = [compute_ccv] * 2 + compute_ccv_functions = {0: compute_ccv, 1: compute_ccv} # ================================================================================== # create emax aggregators and choice segments @@ -94,7 +94,7 @@ def calculate_emax(values, params): # noqa: ARG001 """Take max over axis that corresponds to working.""" return values.max(axis=1) - emax_calculators = [calculate_emax] * 2 + emax_calculators = {0: calculate_emax, 1: calculate_emax} # ================================================================================== # call solve function @@ -104,7 +104,7 @@ def calculate_emax(values, params): # noqa: ARG001 params=params, state_choice_spaces=state_choice_spaces, continuous_choice_grids=continuous_choice_grids, - compute_ccv_functions=utility_and_feasibility_functions, + compute_ccv_functions=compute_ccv_functions, emax_calculators=emax_calculators, logger=get_logger(debug_mode=False), ) From 85a592b950252dd4ce14cda23d701667e721ebca Mon Sep 17 00:00:00 2001 From: Tim Mensinger Date: Thu, 27 Feb 2025 10:44:12 +0100 Subject: [PATCH 11/20] Fix failing stochastic test --- tests/test_models/get_model.py | 17 ++++++++++++++--- tests/test_stochastic.py | 28 ++++++++++++---------------- 2 files changed, 26 insertions(+), 19 deletions(-) diff --git a/tests/test_models/get_model.py b/tests/test_models/get_model.py index 2cf0937..66649f1 100644 --- a/tests/test_models/get_model.py +++ b/tests/test_models/get_model.py @@ -89,9 +89,20 @@ def get_params( ], # Transition from period 1 to period 2 [ - # Description is the same as above - [[0, 1.0], [1.0, 0]], - [[0, 1.0], [0.0, 1.0]], + # Current working decision 0 + [ + # Current partner state 0 + [0, 1.0], + # Current partner state 1 + [1.0, 0], + ], + # Current working decision 1 + [ + # Current partner state 0 + [0, 1.0], + # Current partner state 1 + [0.0, 1.0], + ], ], ], ) diff --git a/tests/test_stochastic.py b/tests/test_stochastic.py index 2ae3b24..e6e6b4b 100644 --- a/tests/test_stochastic.py +++ b/tests/test_stochastic.py @@ -19,7 +19,7 @@ def test_get_lcm_function_with_simulate_target(): targets="solve_and_simulate", ) - res = simulate_model( + res: pd.DataFrame = simulate_model( # type: ignore[assignment] params=get_params(), initial_states={ "health": jnp.array([1, 1, 0, 0]), @@ -28,21 +28,17 @@ def test_get_lcm_function_with_simulate_target(): }, ) - expected_partner = [ - 0, - 0, - 1, - 0, # period 0 - 1, - 1, - 1, - 1, # period 1 - 1, - 1, - 1, - 0, # period 2 - ] - assert jnp.array_equal(res["partner"].values, expected_partner) # type: ignore[call-overload, arg-type] + # This is derived from the partner transition in get_params. + expected_next_partner = ( + (res.working.astype(bool) | ~res.partner.astype(bool)).astype(int).loc[:1] + ) + + pd.testing.assert_series_equal( + res["partner"].loc[1:], + expected_next_partner, + check_index=False, + check_names=False, + ) # ====================================================================================== From c8d811cfcd610b912d561b6ffcf78a652950a042 Mon Sep 17 00:00:00 2001 From: Tim Mensinger Date: Thu, 27 Feb 2025 11:54:51 +0100 Subject: [PATCH 12/20] Start refactoring of simulation --- src/lcm/interfaces.py | 21 ++ src/lcm/next_state.py | 2 +- src/lcm/random.py | 62 ++++++ src/lcm/random_choice.py | 31 --- src/lcm/simulation/processing.py | 121 +++++++++++ src/lcm/simulation/simulate.py | 227 ++------------------ src/lcm/simulation/state_choice_space.py | 51 +++++ tests/simulation/test_processing.py | 94 ++++++++ tests/simulation/test_simulate.py | 109 ---------- tests/simulation/test_state_choice_space.py | 37 ++++ tests/test_random.py | 21 ++ tests/test_random_choice.py | 12 -- 12 files changed, 427 insertions(+), 361 deletions(-) create mode 100644 src/lcm/random.py delete mode 100644 src/lcm/random_choice.py create mode 100644 src/lcm/simulation/processing.py create mode 100644 src/lcm/simulation/state_choice_space.py create mode 100644 tests/simulation/test_processing.py create mode 100644 tests/simulation/test_state_choice_space.py create mode 100644 tests/test_random.py delete mode 100644 tests/test_random_choice.py diff --git a/src/lcm/interfaces.py b/src/lcm/interfaces.py index f3ede38..a6890d5 100644 --- a/src/lcm/interfaces.py +++ b/src/lcm/interfaces.py @@ -1,3 +1,4 @@ +import dataclasses as dc from collections.abc import Mapping from dataclasses import dataclass @@ -36,6 +37,26 @@ class StateChoiceSpace: choices: dict[str, Array] ordered_var_names: tuple[str, ...] + def replace( + self, + states: dict[str, Array] | None = None, + choices: dict[str, Array] | None = None, + ) -> "StateChoiceSpace": + """Replace the states or choices in the state-choice space. + + Args: + states: Dictionary with new states. If None, the existing states are used. + choices: Dictionary with new choices. If None, the existing choices are + used. + + Returns: + New state-choice space with the replaced states or choices. + + """ + states = states if states is not None else self.states + choices = choices if choices is not None else self.choices + return dc.replace(self, states=states, choices=choices) + @dataclass(frozen=True) class StateSpaceInfo: diff --git a/src/lcm/next_state.py b/src/lcm/next_state.py index a82958a..dc6dd5e 100644 --- a/src/lcm/next_state.py +++ b/src/lcm/next_state.py @@ -7,7 +7,7 @@ from jax import Array from lcm.interfaces import InternalModel -from lcm.random_choice import random_choice +from lcm.random import random_choice from lcm.typing import Scalar, StochasticNextFunction, Target diff --git a/src/lcm/random.py b/src/lcm/random.py new file mode 100644 index 0000000..4c096b1 --- /dev/null +++ b/src/lcm/random.py @@ -0,0 +1,62 @@ +from functools import partial + +import jax +from jax import Array + + +def random_choice( + labels: jax.Array, + probs: jax.Array, + key: jax.Array, +) -> jax.Array: + """Draw multiple random choices. + + Args: + labels: 1d array of labels. + probs: 2d array of probabilities. Second dimension must be + the same length as the first dimension of labels. + key: Random key. + + Returns: + Selected labels. 1d array of length len(probs). + + """ + keys = jax.random.split(key, probs.shape[0]) + return _vmapped_choice(keys, probs, labels) + + +@partial(jax.vmap, in_axes=(0, 0, None)) +def _vmapped_choice(key: jax.Array, probs: jax.Array, labels: jax.Array) -> jax.Array: + return jax.random.choice(key, a=labels, p=probs) + + +def generate_simulation_keys( + key: Array, ids: list[str] +) -> tuple[Array, dict[str, Array]]: + """Generate pseudo-random number generator keys (PRNG keys) for simulation. + + PRNG keys in JAX are immutable objects used to control random number generation. + A key can be used to generate a stream of random numbers, e.g., given a key, one can + call jax.random.normal(key) to generate a stream of normal random numbers. In order + to ensure that each simulation is based on a different stream of random numbers, we + split the key into one key per stochastic variable, and one key that will be passed + to the next iteration in order to generate new keys. + + See the JAX documentation for more details: + https://docs.jax.dev/en/latest/random-numbers.html#random-numbers-in-jax + + Args: + key: Random key. + ids: List of names for which a key is to be generated. + + Returns: + - Updated random key. + - Dict with random keys for each id in ids. + + """ + keys = jax.random.split(key, num=len(ids) + 1) + + key = keys[0] + simulation_keys = dict(zip(ids, keys[1:], strict=True)) + + return key, simulation_keys diff --git a/src/lcm/random_choice.py b/src/lcm/random_choice.py deleted file mode 100644 index 6e37d58..0000000 --- a/src/lcm/random_choice.py +++ /dev/null @@ -1,31 +0,0 @@ -from functools import partial - -import jax - - -def random_choice( - labels: jax.Array, - probs: jax.Array, - key: jax.Array, -) -> jax.Array: - """Draw multiple random choices. - - Args: - labels: 1d array of labels. - probs: 2d array of probabilities. Second dimension must be - the same length as the first dimension of labels. - key: Random key. - - Returns: - Selected labels. 1d array of length len(probs). - - """ - keys = jax.random.split(key, probs.shape[0]) - return _vmapped_random_choice(keys, probs, labels) - - -@partial(jax.vmap, in_axes=(0, 0, None)) -def _vmapped_random_choice( - key: jax.Array, probs: jax.Array, labels: jax.Array -) -> jax.Array: - return jax.random.choice(key, a=labels, p=probs) diff --git a/src/lcm/simulation/processing.py b/src/lcm/simulation/processing.py new file mode 100644 index 0000000..e0d4b8a --- /dev/null +++ b/src/lcm/simulation/processing.py @@ -0,0 +1,121 @@ +import inspect +from typing import Any + +import jax.numpy as jnp +import pandas as pd +from dags import concatenate_functions +from jax import Array + +from lcm.dispatchers import vmap_1d +from lcm.interfaces import InternalModel +from lcm.typing import InternalUserFunction, ParamsDict + + +def process_simulated_data( + results: list[dict[str, Any]], + model: InternalModel, + params: ParamsDict, + additional_targets: list[str] | None = None, +) -> dict[str, Array]: + """Process and flatten the simulation results. + + This function produces a dict of arrays for each var with dimension (n_periods * + n_initial_states,). The arrays are flattened, so that the resulting dictionary has a + one-dimensional array for each variable. The length of this array is the number of + periods times the number of initial states. The order of array elements is given by + an outer level of periods and an inner level of initial states ids. + + Args: + results: List of dicts with simulation results. Each dict contains the value, + choices, and states for one period. Choices and states are stored in a + nested dictionary. + model: Model. + params: Parameters. + additional_targets: List of additional targets to compute. + + Returns: + Dict with processed simulation results. The keys are the variable names and the + values are the flattened arrays, with dimension (n_periods * n_initial_states,). + Additionally, the _period variable is added. + + """ + n_periods = len(results) + n_initial_states = len(results[0]["value"]) + + list_of_dicts = [ + {"value": d["value"], **d["choices"], **d["states"]} for d in results + ] + dict_of_lists = { + key: [d[key] for d in list_of_dicts] for key in list(list_of_dicts[0]) + } + out = {key: jnp.concatenate(values) for key, values in dict_of_lists.items()} + out["_period"] = jnp.repeat(jnp.arange(n_periods), n_initial_states) + + if additional_targets is not None: + calculated_targets = _compute_targets( + out, + targets=additional_targets, + model_functions=model.functions, + params=params, + ) + out = {**out, **calculated_targets} + + return out + + +def as_data_frame(processed: dict[str, Array], n_periods: int) -> pd.DataFrame: + """Convert processed simulation results to DataFrame. + + Args: + processed: Dict with processed simulation results. + n_periods: Number of periods. + + Returns: + DataFrame with the simulation results. The index is a multi-index with the first + level corresponding to the period and the second level corresponding to the + initial state id. The columns correspond to the value, and the choice and state + variables, and potentially auxiliary variables. + + """ + n_initial_states = len(processed["value"]) // n_periods + index = pd.MultiIndex.from_product( + [range(n_periods), range(n_initial_states)], + names=["period", "initial_state_id"], + ) + return pd.DataFrame(processed, index=index) + + +def _compute_targets( + processed_results: dict[str, Array], + targets: list[str], + model_functions: dict[str, InternalUserFunction], + params: ParamsDict, +) -> dict[str, Array]: + """Compute targets. + + Args: + processed_results: Dict with processed simulation results. Values must be + one-dimensional arrays. + targets: List of targets to compute. + model_functions: Dict with model functions. + params: Dict with model parameters. + + Returns: + Dict with computed targets. + + """ + target_func = concatenate_functions( + functions=model_functions, + targets=targets, + return_type="dict", + ) + + # get list of variables over which we want to vectorize the target function + variables = tuple( + p for p in list(inspect.signature(target_func).parameters) if p != "params" + ) + + target_func = vmap_1d(target_func, variables=variables) + + kwargs = {k: v for k, v in processed_results.items() if k in variables} + return target_func(params=params, **kwargs) diff --git a/src/lcm/simulation/simulate.py b/src/lcm/simulation/simulate.py index 34f61bb..e5c03a7 100644 --- a/src/lcm/simulation/simulate.py +++ b/src/lcm/simulation/simulate.py @@ -1,19 +1,19 @@ -import inspect import logging from collections.abc import Callable from functools import partial -from typing import Any import jax import jax.numpy as jnp import pandas as pd -from dags import concatenate_functions from jax import Array, vmap from lcm.argmax import argmax from lcm.dispatchers import simulation_spacemap, vmap_1d from lcm.interfaces import InternalModel, StateChoiceSpace -from lcm.typing import InternalUserFunction, ParamsDict +from lcm.random import generate_simulation_keys +from lcm.simulation.processing import as_data_frame, process_simulated_data +from lcm.simulation.state_choice_space import create_state_choice_space +from lcm.typing import ParamsDict from lcm.utils import draw_random_seed @@ -80,6 +80,11 @@ def simulate( n_periods = len(vf_arr_list) n_initial_states = len(next(iter(initial_states.values()))) + data_scs = create_state_choice_space( + model=model, + initial_states=initial_states, + ) + # We drop the value function array for the first period, because it is not needed # for the simulation. This is because in the first period the agents only consider # the current utility and the value function of next period. Similarly, the last @@ -109,10 +114,7 @@ def simulate( # of combination variables and initial states. The space has to be created in # each iteration because the states change over time. # ============================================================================== - data_scs = create_data_scs( - states=states, - model=model, - ) + data_scs = data_scs.replace(states) # Compute objects dependent on data-state-choice-space # ============================================================================== @@ -175,7 +177,7 @@ def simulate( # Update states # ============================================================================== - key, sim_keys = _generate_simulation_keys( + key, stochastic_variables_keys = generate_simulation_keys( key=key, ids=model.function_info.query("is_stochastic_next").index.tolist(), ) @@ -185,7 +187,7 @@ def simulate( **choices, _period=jnp.repeat(period, n_initial_states), params=params, - keys=sim_keys, + keys=stochastic_variables_keys, ) # 'next_' prefix is added by the next_state function, but needs to be removed @@ -194,18 +196,14 @@ def simulate( logger.info("Period: %s", period) - processed = _process_simulated_data(_simulation_results) - - if additional_targets is not None: - calculated_targets = _compute_targets( - processed, - targets=additional_targets, - model_functions=model.functions, - params=params, - ) - processed = {**processed, **calculated_targets} + processed = process_simulated_data( + _simulation_results, + model=model, + params=params, + additional_targets=additional_targets, + ) - return _as_data_frame(processed, n_periods=n_periods) + return as_data_frame(processed, n_periods=n_periods) def solve_continuous_problem( @@ -255,142 +253,6 @@ def solve_continuous_problem( ) -# ====================================================================================== -# Output processing -# ====================================================================================== - - -def _as_data_frame(processed: dict[str, Array], n_periods: int) -> pd.DataFrame: - """Convert processed simulation results to DataFrame. - - Args: - processed: Dict with processed simulation results. - n_periods: Number of periods. - - Returns: - DataFrame with the simulation results. The index is a multi-index with the first - level corresponding to the period and the second level corresponding to the - initial state id. The columns correspond to the value, and the choice and state - variables, and potentially auxiliary variables. - - """ - n_initial_states = len(processed["value"]) // n_periods - index = pd.MultiIndex.from_product( - [range(n_periods), range(n_initial_states)], - names=["period", "initial_state_id"], - ) - return pd.DataFrame(processed, index=index) - - -def _compute_targets( - processed_results: dict[str, Array], - targets: list[str], - model_functions: dict[str, InternalUserFunction], - params: ParamsDict, -) -> dict[str, Array]: - """Compute targets. - - Args: - processed_results: Dict with processed simulation results. Values must be - one-dimensional arrays. - targets: List of targets to compute. - model_functions: Dict with model functions. - params: Dict with model parameters. - - Returns: - Dict with computed targets. - - """ - target_func = concatenate_functions( - functions=model_functions, - targets=targets, - return_type="dict", - ) - - # get list of variables over which we want to vectorize the target function - variables = tuple( - p for p in list(inspect.signature(target_func).parameters) if p != "params" - ) - - target_func = vmap_1d(target_func, variables=variables) - - kwargs = {k: v for k, v in processed_results.items() if k in variables} - return target_func(params=params, **kwargs) - - -def _process_simulated_data(results: list[dict[str, Any]]) -> dict[str, Array]: - """Process and flatten the simulation results. - - This function produces a dict of arrays for each var with dimension (n_periods * - n_initial_states,). The arrays are flattened, so that the resulting dictionary has a - one-dimensional array for each variable. The length of this array is the number of - periods times the number of initial states. The order of array elements is given by - an outer level of periods and an inner level of initial states ids. - - - Args: - results (list): List of dicts with simulation results. Each dict contains the - value, choices, and states for one period. Choices and states are stored in - a nested dictionary. - - Returns: - Dict with processed simulation results. The keys are the variable names and the - values are the flattened arrays, with dimension (n_periods * n_initial_states,). - Additionally, the _period variable is added. - - """ - n_periods = len(results) - n_initial_states = len(results[0]["value"]) - - list_of_dicts = [ - {"value": d["value"], **d["choices"], **d["states"]} for d in results - ] - dict_of_lists = { - key: [d[key] for d in list_of_dicts] for key in list(list_of_dicts[0]) - } - out = {key: jnp.concatenate(values) for key, values in dict_of_lists.items()} - out["_period"] = jnp.repeat(jnp.arange(n_periods), n_initial_states) - - return out - - -# ====================================================================================== -# Simulation keys -# ====================================================================================== - - -def _generate_simulation_keys( - key: Array, ids: list[str] -) -> tuple[Array, dict[str, Array]]: - """Generate pseudo-random number generator keys (PRNG keys) for simulation. - - PRNG keys in JAX are immutable objects used to control random number generation. - A key can be used to generate a stream of random numbers, e.g., given a key, one can - call jax.random.normal(key) to generate a stream of normal random numbers. In order - to ensure that each simulation is based on a different stream of random numbers, we - split the key into one key per simulation unit, and one key that will be passed to - the next iteration in order to generate new keys. - - See the JAX documentation for more details: - https://docs.jax.dev/en/latest/random-numbers.html#random-numbers-in-jax - - Args: - key: PRNG key. - ids: List of names for which a key is to be generated. - - Returns: - - Updated PRNG key. - - Dict with PRNG keys for each id in ids. - - """ - keys = jax.random.split(key, num=len(ids) + 1) - - key = keys[0] - simulation_keys = dict(zip(ids, keys[1:], strict=True)) - - return key, simulation_keys - - # ====================================================================================== # Filter policy # ====================================================================================== @@ -450,57 +312,6 @@ def retrieve_choices( vmapped_unravel_index = vmap(jnp.unravel_index, in_axes=(0, None)) -# ====================================================================================== -# Data State Choice Space -# ====================================================================================== - - -def create_data_scs( - states: dict[str, Array], - model: InternalModel, -) -> StateChoiceSpace: - """Create data state choice space. - - Args: - states: Dict with initial states. - model: Model instance. - - Returns: - Data state choice space. - - """ - # preparations - # ================================================================================== - vi = model.variable_info - - # check that all states have an initial value - # ================================================================================== - state_names = set(vi.query("is_state").index) - - if state_names != set(states.keys()): - missing = state_names - set(states.keys()) - too_many = set(states.keys()) - state_names - raise ValueError( - "You need to provide an initial value for each state variable in the model." - f"\n\nMissing initial states: {missing}\n", - f"Provided variables that are not states: {too_many}", - ) - - # get choices - # ================================================================================== - choices = { - name: grid - for name, grid in model.grids.items() - if name in vi.query("is_choice & is_discrete").index.tolist() - } - - return StateChoiceSpace( - states=states, - choices=choices, - ordered_var_names=tuple(vi.query("is_state | is_discrete").index.tolist()), - ) - - # ====================================================================================== # Discrete policy # ====================================================================================== diff --git a/src/lcm/simulation/state_choice_space.py b/src/lcm/simulation/state_choice_space.py new file mode 100644 index 0000000..eeb9f70 --- /dev/null +++ b/src/lcm/simulation/state_choice_space.py @@ -0,0 +1,51 @@ +from jax import Array + +from lcm.interfaces import InternalModel, StateChoiceSpace + + +def create_state_choice_space( + model: InternalModel, + initial_states: dict[str, Array], +) -> StateChoiceSpace: + """Create the initial state choice space. + + In comparison to the solution, the state choice space in the simulation must be + created during each iteration, because the states change over time. + + Args: + model: Model instance. + initial_states: Dict with initial states. + + Returns: + State choice space. + + Raises: + ValueError: If the initial states do not match the state variables in the model. + + """ + vi = model.variable_info + state_names = set(vi.query("is_state").index) + + if state_names != set(initial_states.keys()): + missing = state_names - set(initial_states.keys()) + too_many = set(initial_states.keys()) - state_names + raise ValueError( + "You need to provide an initial value for each state variable in the model." + f"\n\nMissing initial states: {missing}\n", + f"Provided variables that are not states: {too_many}", + ) + + ordered_var_names = tuple(vi.query("is_state | is_discrete").index) + discrete_choice_names = vi.query("is_choice & is_discrete").index + + discrete_choices = { + name: grid + for name, grid in model.grids.items() + if name in discrete_choice_names + } + + return StateChoiceSpace( + states=initial_states, + choices=discrete_choices, + ordered_var_names=ordered_var_names, + ) diff --git a/tests/simulation/test_processing.py b/tests/simulation/test_processing.py new file mode 100644 index 0000000..3958163 --- /dev/null +++ b/tests/simulation/test_processing.py @@ -0,0 +1,94 @@ +import jax.numpy as jnp +import pandas as pd +from pybaum import tree_equal + +from lcm.simulation.processing import ( + _compute_targets, + as_data_frame, + process_simulated_data, +) + + +def test_compute_targets(): + processed_results = { + "a": jnp.arange(3), + "b": 1 + jnp.arange(3), + "c": 2 + jnp.arange(3), + } + + def f_a(a, params): + return a + params["disutility_of_work"] + + def f_b(b, params): # noqa: ARG001 + return b + + def f_c(params): # noqa: ARG001 + return None + + model_functions = {"fa": f_a, "fb": f_b, "fc": f_c} + + got = _compute_targets( + processed_results=processed_results, + targets=["fa", "fb"], + model_functions=model_functions, # type: ignore[arg-type] + params={"disutility_of_work": -1.0}, + ) + expected = { + "fa": jnp.arange(3) - 1.0, + "fb": 1 + jnp.arange(3), + } + assert tree_equal(expected, got) + + +def test_as_data_frame(): + processed = { + "value": -6 + jnp.arange(6), + "a": jnp.arange(6), + "b": 6 + jnp.arange(6), + } + got = as_data_frame(processed, n_periods=2) + expected = pd.DataFrame( + { + "period": [0, 0, 0, 1, 1, 1], + "initial_state_id": [0, 1, 2, 0, 1, 2], + **processed, + }, + ).set_index(["period", "initial_state_id"]) + pd.testing.assert_frame_equal(got, expected) + + +def test_process_simulated_data(): + simulated = [ + { + "value": jnp.array([0.1, 0.2]), + "states": {"a": jnp.array([1, 2]), "b": jnp.array([-1, -2])}, + "choices": {"c": jnp.array([5, 6]), "d": jnp.array([-5, -6])}, + }, + { + "value": jnp.array([0.3, 0.4]), + "states": { + "b": jnp.array([-3, -4]), + "a": jnp.array([3, 4]), + }, + "choices": { + "d": jnp.array([-7, -8]), + "c": jnp.array([7, 8]), + }, + }, + ] + expected = { + "value": jnp.array([0.1, 0.2, 0.3, 0.4]), + "c": jnp.array([5, 6, 7, 8]), + "d": jnp.array([-5, -6, -7, -8]), + "a": jnp.array([1, 2, 3, 4]), + "b": jnp.array([-1, -2, -3, -4]), + } + + got = process_simulated_data( + simulated, + # Rest is none, since we are not computing any additional targets + model=None, # type: ignore[arg-type] + params=None, # type: ignore[arg-type] + additional_targets=None, + ) + assert tree_equal(expected, got) diff --git a/tests/simulation/test_simulate.py b/tests/simulation/test_simulate.py index 63cb891..22ae35e 100644 --- a/tests/simulation/test_simulate.py +++ b/tests/simulation/test_simulate.py @@ -2,7 +2,6 @@ import pandas as pd import pytest from numpy.testing import assert_array_almost_equal, assert_array_equal -from pybaum import tree_equal from lcm.conditional_continuation import ( get_compute_conditional_continuation_policy, @@ -12,11 +11,6 @@ from lcm.logging import get_logger from lcm.next_state import get_next_state_function from lcm.simulation.simulate import ( - _as_data_frame, - _compute_targets, - _generate_simulation_keys, - _process_simulated_data, - create_data_scs, determine_discrete_choice_axes, filter_ccv_policy, retrieve_choices, @@ -282,94 +276,6 @@ def test_effect_of_disutility_of_work(): # ====================================================================================== -def test_generate_simulation_keys(): - key = jnp.arange(2, dtype="uint32") # PRNG dtype - stochastic_next_functions = ["a", "b"] - got = _generate_simulation_keys(key, stochastic_next_functions) - # assert that all generated keys are different from each other - matrix = jnp.array([key, got[0], got[1]["a"], got[1]["b"]]) - assert jnp.linalg.matrix_rank(matrix) == 2 - - -def test_as_data_frame(): - processed = { - "value": -6 + jnp.arange(6), - "a": jnp.arange(6), - "b": 6 + jnp.arange(6), - } - got = _as_data_frame(processed, n_periods=2) - expected = pd.DataFrame( - { - "period": [0, 0, 0, 1, 1, 1], - "initial_state_id": [0, 1, 2, 0, 1, 2], - **processed, - }, - ).set_index(["period", "initial_state_id"]) - pd.testing.assert_frame_equal(got, expected) - - -def test_compute_targets(): - processed_results = { - "a": jnp.arange(3), - "b": 1 + jnp.arange(3), - "c": 2 + jnp.arange(3), - } - - def f_a(a, params): - return a + params["disutility_of_work"] - - def f_b(b, params): # noqa: ARG001 - return b - - def f_c(params): # noqa: ARG001 - return None - - model_functions = {"fa": f_a, "fb": f_b, "fc": f_c} - - got = _compute_targets( - processed_results=processed_results, - targets=["fa", "fb"], - model_functions=model_functions, # type: ignore[arg-type] - params={"disutility_of_work": -1.0}, - ) - expected = { - "fa": jnp.arange(3) - 1.0, - "fb": 1 + jnp.arange(3), - } - assert tree_equal(expected, got) - - -def test_process_simulated_data(): - simulated = [ - { - "value": jnp.array([0.1, 0.2]), - "states": {"a": jnp.array([1, 2]), "b": jnp.array([-1, -2])}, - "choices": {"c": jnp.array([5, 6]), "d": jnp.array([-5, -6])}, - }, - { - "value": jnp.array([0.3, 0.4]), - "states": { - "b": jnp.array([-3, -4]), - "a": jnp.array([3, 4]), - }, - "choices": { - "d": jnp.array([-7, -8]), - "c": jnp.array([7, 8]), - }, - }, - ] - expected = { - "value": jnp.array([0.1, 0.2, 0.3, 0.4]), - "c": jnp.array([5, 6, 7, 8]), - "d": jnp.array([-5, -6, -7, -8]), - "a": jnp.array([1, 2, 3, 4]), - "b": jnp.array([-1, -2, -3, -4]), - } - - got = _process_simulated_data(simulated) - assert tree_equal(expected, got) - - def test_retrieve_choices(): got = retrieve_choices( flat_indices=jnp.array([0, 3, 7]), @@ -397,21 +303,6 @@ def test_filter_ccv_policy(): assert jnp.all(got == jnp.array([0, 0])) -def test_create_data_state_choice_space(): - model_config = get_model_config("iskhakov_et_al_2017", n_periods=3) - model = process_model(model_config) - got_space = create_data_scs( - states={ - "wealth": jnp.array([10.0, 20.0]), - "lagged_retirement": jnp.array([0, 1]), - }, - model=model, - ) - assert_array_equal(got_space.choices["retirement"], jnp.array([0, 1])) - assert_array_equal(got_space.states["wealth"], jnp.array([10.0, 20.0])) - assert_array_equal(got_space.states["lagged_retirement"], jnp.array([0, 1])) - - def test_determine_discrete_choice_axes(): variable_info = pd.DataFrame( { diff --git a/tests/simulation/test_state_choice_space.py b/tests/simulation/test_state_choice_space.py new file mode 100644 index 0000000..45e8e41 --- /dev/null +++ b/tests/simulation/test_state_choice_space.py @@ -0,0 +1,37 @@ +import jax.numpy as jnp +from numpy.testing import assert_array_equal + +from lcm.input_processing import process_model +from lcm.simulation.state_choice_space import create_state_choice_space +from tests.test_models import get_model_config + + +def test_create_state_choice_space(): + model_config = get_model_config("iskhakov_et_al_2017", n_periods=3) + model = process_model(model_config) + got_space = create_state_choice_space( + model=model, + initial_states={ + "wealth": jnp.array([10.0, 20.0]), + "lagged_retirement": jnp.array([0, 1]), + }, + ) + assert_array_equal(got_space.choices["retirement"], jnp.array([0, 1])) + assert_array_equal(got_space.states["wealth"], jnp.array([10.0, 20.0])) + assert_array_equal(got_space.states["lagged_retirement"], jnp.array([0, 1])) + + +def test_create_state_choice_space_replace(): + model_config = get_model_config("iskhakov_et_al_2017", n_periods=3) + model = process_model(model_config) + space = create_state_choice_space( + model=model, + initial_states={ + "wealth": jnp.array([10.0, 20.0]), + "lagged_retirement": jnp.array([0, 1]), + }, + ) + new_space = space.replace( + states={"wealth": jnp.array([10.0, 30.0])}, + ) + assert_array_equal(new_space.states["wealth"], jnp.array([10.0, 30.0])) diff --git a/tests/test_random.py b/tests/test_random.py new file mode 100644 index 0000000..f6cc714 --- /dev/null +++ b/tests/test_random.py @@ -0,0 +1,21 @@ +import jax +import jax.numpy as jnp + +from lcm.random import generate_simulation_keys, random_choice + + +def test_random_choice(): + key = jax.random.key(0) + probs = jnp.array([[0.0, 0, 1], [1, 0, 0], [0, 1, 0]]) + labels = jnp.array([1, 2, 3]) + got = random_choice(labels=labels, probs=probs, key=key) + assert jnp.array_equal(got, jnp.array([3, 1, 2])) + + +def test_generate_simulation_keys(): + key = jnp.arange(2, dtype="uint32") # PRNG dtype + stochastic_next_functions = ["a", "b"] + got = generate_simulation_keys(key, stochastic_next_functions) + # assert that all generated keys are different from each other + matrix = jnp.array([key, got[0], got[1]["a"], got[1]["b"]]) + assert jnp.linalg.matrix_rank(matrix) == 2 diff --git a/tests/test_random_choice.py b/tests/test_random_choice.py deleted file mode 100644 index 2315788..0000000 --- a/tests/test_random_choice.py +++ /dev/null @@ -1,12 +0,0 @@ -import jax -import jax.numpy as jnp - -from lcm.random_choice import random_choice - - -def test_random_choice(): - key = jax.random.key(0) - probs = jnp.array([[0.0, 0, 1], [1, 0, 0], [0, 1, 0]]) - labels = jnp.array([1, 2, 3]) - got = random_choice(labels=labels, probs=probs, key=key) - assert jnp.array_equal(got, jnp.array([3, 1, 2])) From 94c177d9dd4c00b19c737d840b026582d99d10c5 Mon Sep 17 00:00:00 2001 From: Tim Mensinger Date: Thu, 27 Feb 2025 12:19:44 +0100 Subject: [PATCH 13/20] Move discrete policy from simulation to discrete problem module --- src/lcm/discrete_problem.py | 68 +++++++++++++++++++++++++++--- src/lcm/entry_point.py | 8 ++-- src/lcm/simulation/simulate.py | 69 +++---------------------------- src/lcm/solution/solve_brute.py | 4 +- src/lcm/typing.py | 14 ++++++- tests/simulation/test_simulate.py | 21 +++------- tests/test_discrete_problem.py | 24 ++++++++--- 7 files changed, 111 insertions(+), 97 deletions(-) diff --git a/src/lcm/discrete_problem.py b/src/lcm/discrete_problem.py index dd405c8..b265dcd 100644 --- a/src/lcm/discrete_problem.py +++ b/src/lcm/discrete_problem.py @@ -23,15 +23,21 @@ import pandas as pd from jax import Array -from lcm.typing import DiscreteProblemSolverFunction, ParamsDict, ShockType +from lcm.argmax import argmax +from lcm.typing import ( + DiscreteProblemPolicySolverFunction, + DiscreteProblemValueSolverFunction, + ParamsDict, + ShockType, +) -def get_solve_discrete_problem( +def get_solve_discrete_problem_value( *, random_utility_shock_type: ShockType, variable_info: pd.DataFrame, is_last_period: bool, -) -> DiscreteProblemSolverFunction: +) -> DiscreteProblemValueSolverFunction: """Get function that computes the expected max. of conditional continuation values. The maximum is taken over the discrete choice variables in each state. @@ -51,7 +57,7 @@ def get_solve_discrete_problem( if is_last_period: variable_info = variable_info.query("~is_auxiliary") - choice_axes = _determine_discrete_choice_axes(variable_info) + choice_axes = _determine_discrete_choice_axes_solution(variable_info) if random_utility_shock_type == ShockType.NONE: func = _solve_discrete_problem_no_shocks @@ -63,6 +69,35 @@ def get_solve_discrete_problem( return partial(func, choice_axes=choice_axes) +def get_solve_discrete_problem_policy( + *, + variable_info: pd.DataFrame, +) -> DiscreteProblemPolicySolverFunction: + """Return a function that calculates the argmax and max of continuation values. + + The argmax is taken over the discrete choice variables in each state. + + Args: + variable_info (pd.DataFrame): DataFrame with information about the model + variables. + + Returns: + callable: Function that calculates the argmax of the conditional continuation + values. The function depends on: + - values (jax.Array): Multidimensional jax array with conditional + continuation values. + + """ + choice_axes = _determine_discrete_choice_axes_simulation(variable_info) + + def _calculate_discrete_argmax( + values: Array, choice_axes: tuple[int, ...] + ) -> tuple[Array, Array]: + return argmax(values, axis=choice_axes) + + return partial(_calculate_discrete_argmax, choice_axes=choice_axes) + + # ====================================================================================== # Discrete problem with no shocks # ====================================================================================== @@ -129,10 +164,10 @@ def _calculate_emax_extreme_value_shocks( # ====================================================================================== -def _determine_discrete_choice_axes( +def _determine_discrete_choice_axes_solution( variable_info: pd.DataFrame, ) -> tuple[int, ...]: - """Get axes of a state-choice-space that correspond to discrete choices. + """Get axes of state-choice-space that correspond to discrete choices in solution. Args: variable_info: DataFrame with information about the variables. @@ -148,3 +183,24 @@ def _determine_discrete_choice_axes( return tuple( i for i, ax in enumerate(variable_info.index) if ax in discrete_choice_vars ) + + +def _determine_discrete_choice_axes_simulation( + variable_info: pd.DataFrame, +) -> tuple[int, ...]: + """Get axes of state-choice-space that correspond to discrete choices in simulation. + + Args: + variable_info: DataFrame with information about the variables. + + Returns: + A tuple of indices representing the axes' positions in the value function that + correspond to discrete choices. + + """ + discrete_choice_vars = set( + variable_info.query("is_choice & is_discrete").index.tolist() + ) + + # The first dimension corresponds to the simulated states, so add 1. + return tuple(1 + i for i in range(len(discrete_choice_vars))) diff --git a/src/lcm/entry_point.py b/src/lcm/entry_point.py index dffcbbb..997b69c 100644 --- a/src/lcm/entry_point.py +++ b/src/lcm/entry_point.py @@ -10,7 +10,7 @@ get_compute_conditional_continuation_policy, get_compute_conditional_continuation_value, ) -from lcm.discrete_problem import get_solve_discrete_problem +from lcm.discrete_problem import get_solve_discrete_problem_value from lcm.input_processing import process_model from lcm.interfaces import StateChoiceSpace, StateSpaceInfo from lcm.logging import get_logger @@ -18,7 +18,7 @@ from lcm.simulation.simulate import simulate from lcm.solution.solve_brute import solve from lcm.solution.state_choice_space import create_state_choice_space -from lcm.typing import DiscreteProblemSolverFunction, ParamsDict, Target +from lcm.typing import DiscreteProblemValueSolverFunction, ParamsDict, Target from lcm.user_model import Model from lcm.utility_and_feasibility import ( get_utility_and_feasibility_function, @@ -85,7 +85,7 @@ def get_lcm_function( state_space_infos: dict[int, StateSpaceInfo] = {} compute_ccv_functions: dict[int, Callable[[Array, Array], Array]] = {} compute_ccp_functions: dict[int, Callable[..., tuple[Array, Array]]] = {} - solve_discrete_problem_functions: dict[int, DiscreteProblemSolverFunction] = {} + solve_discrete_problem_functions: dict[int, DiscreteProblemValueSolverFunction] = {} for period in reversed(range(internal_model.n_periods)): is_last_period = period == last_period @@ -117,7 +117,7 @@ def get_lcm_function( continuous_choice_variables=tuple(_choice_grids), ) - solve_discrete_problem = get_solve_discrete_problem( + solve_discrete_problem = get_solve_discrete_problem_value( random_utility_shock_type=internal_model.random_utility_shocks, variable_info=internal_model.variable_info, is_last_period=is_last_period, diff --git a/src/lcm/simulation/simulate.py b/src/lcm/simulation/simulate.py index e5c03a7..f25c407 100644 --- a/src/lcm/simulation/simulate.py +++ b/src/lcm/simulation/simulate.py @@ -7,7 +7,7 @@ import pandas as pd from jax import Array, vmap -from lcm.argmax import argmax +from lcm.discrete_problem import get_solve_discrete_problem_policy from lcm.dispatchers import simulation_spacemap, vmap_1d from lcm.interfaces import InternalModel, StateChoiceSpace from lcm.random import generate_simulation_keys @@ -94,7 +94,7 @@ def simulate( zip(range(n_periods), vf_arr_list[1:] + [jnp.empty(0)], strict=True) ) - discrete_policy_calculator = get_discrete_policy_calculator( + discrete_policy_calculator = get_solve_discrete_problem_policy( variable_info=model.variable_info, ) @@ -135,7 +135,7 @@ def simulate( # Get optimal discrete choice given the optimal conditional continuous choices # ============================================================================== - discrete_argmax, value = discrete_policy_calculator(ccv) + discrete_argmax, value = discrete_policy_calculator(ccv, params=params) # Select optimal continuous choice corresponding to optimal discrete choice # ------------------------------------------------------------------------------ @@ -182,17 +182,16 @@ def simulate( ids=model.function_info.query("is_stochastic_next").index.tolist(), ) - states = next_state( + states_with_prefix = next_state( **states, **choices, _period=jnp.repeat(period, n_initial_states), params=params, keys=stochastic_variables_keys, ) - # 'next_' prefix is added by the next_state function, but needs to be removed - # because in the next period, next states are current states. - states = {k.removeprefix("next_"): v for k, v in states.items()} + # because in the next period, next states will be current states. + states = {k.removeprefix("next_"): v for k, v in states_with_prefix.items()} logger.info("Period: %s", period) @@ -310,59 +309,3 @@ def retrieve_choices( # vmap jnp.unravel_index over the first axis of the `indices` argument, while holding # the `shape` argument constant (in_axes = (0, None)). vmapped_unravel_index = vmap(jnp.unravel_index, in_axes=(0, None)) - - -# ====================================================================================== -# Discrete policy -# ====================================================================================== - - -def get_discrete_policy_calculator( - variable_info: pd.DataFrame, -) -> Callable[..., tuple[Array, Array]]: - """Return a function that calculates the argmax and max of continuation values. - - The argmax is taken over the discrete choice variables in each state. - - Args: - variable_info (pd.DataFrame): DataFrame with information about the model - variables. - - Returns: - callable: Function that calculates the argmax of the conditional continuation - values. The function depends on: - - values (jax.Array): Multidimensional jax array with conditional - continuation values. - - """ - choice_axes = determine_discrete_choice_axes(variable_info) - - def _calculate_discrete_argmax( - values: Array, choice_axes: tuple[int, ...] - ) -> tuple[Array, Array]: - return argmax(values, axis=choice_axes) - - return partial(_calculate_discrete_argmax, choice_axes=choice_axes) - - -def determine_discrete_choice_axes(variable_info: pd.DataFrame) -> tuple[int, ...]: - """Determine which axes correspond to discrete choices. - - Args: - variable_info (pd.DataFrame): DataFrame with information about the variables. - - Returns: - tuple: Tuple of ints, specifying which axes in a value function correspond to - discrete choices. - - """ - discrete_choice_vars = variable_info.query( - "is_choice & is_discrete", - ).index.tolist() - - choice_vars = set(variable_info.query("is_choice").index.tolist()) - - # The first dimension corresponds to the simulated states, so add 1. - return tuple( - 1 + i for i, ax in enumerate(discrete_choice_vars) if ax in choice_vars - ) diff --git a/src/lcm/solution/solve_brute.py b/src/lcm/solution/solve_brute.py index 44e325e..6acee7a 100644 --- a/src/lcm/solution/solve_brute.py +++ b/src/lcm/solution/solve_brute.py @@ -6,7 +6,7 @@ from lcm.dispatchers import productmap from lcm.interfaces import StateChoiceSpace -from lcm.typing import DiscreteProblemSolverFunction, ParamsDict +from lcm.typing import DiscreteProblemValueSolverFunction, ParamsDict def solve( @@ -14,7 +14,7 @@ def solve( state_choice_spaces: dict[int, StateChoiceSpace], continuous_choice_grids: dict[int, dict[str, Array]], compute_ccv_functions: dict[int, Callable[[Array, Array], Array]], - emax_calculators: dict[int, DiscreteProblemSolverFunction], + emax_calculators: dict[int, DiscreteProblemValueSolverFunction], logger: logging.Logger, ) -> list[Array]: """Solve a model by brute force. diff --git a/src/lcm/typing.py b/src/lcm/typing.py index 1df6e16..a989471 100644 --- a/src/lcm/typing.py +++ b/src/lcm/typing.py @@ -33,8 +33,8 @@ def __call__( # noqa: D102 ) -> Scalar: ... -class DiscreteProblemSolverFunction(Protocol): - """The function that solves the discrete problem. +class DiscreteProblemValueSolverFunction(Protocol): + """The function that solves for the value of the discrete problem. Only used for type checking. @@ -43,6 +43,16 @@ class DiscreteProblemSolverFunction(Protocol): def __call__(self, values: Array, params: ParamsDict) -> Array: ... # noqa: D102 +class DiscreteProblemPolicySolverFunction(Protocol): + """The function that solves for the policy of the discrete problem. + + Only used for type checking. + + """ + + def __call__(self, values: Array, params: ParamsDict) -> tuple[Array, Array]: ... # noqa: D102 + + class StochasticNextFunction(Protocol): """The function that simulates the next state of a stochastic variable. diff --git a/tests/simulation/test_simulate.py b/tests/simulation/test_simulate.py index 22ae35e..aa8a1bb 100644 --- a/tests/simulation/test_simulate.py +++ b/tests/simulation/test_simulate.py @@ -1,5 +1,6 @@ +from typing import TYPE_CHECKING + import jax.numpy as jnp -import pandas as pd import pytest from numpy.testing import assert_array_almost_equal, assert_array_equal @@ -11,7 +12,6 @@ from lcm.logging import get_logger from lcm.next_state import get_next_state_function from lcm.simulation.simulate import ( - determine_discrete_choice_axes, filter_ccv_policy, retrieve_choices, simulate, @@ -24,6 +24,10 @@ get_params, ) +if TYPE_CHECKING: + import pandas as pd + + # ====================================================================================== # Test simulate using raw inputs # ====================================================================================== @@ -301,16 +305,3 @@ def test_filter_ccv_policy(): vars_grid_shape=vars_grid_shape, ) assert jnp.all(got == jnp.array([0, 0])) - - -def test_determine_discrete_choice_axes(): - variable_info = pd.DataFrame( - { - "is_state": [True, True, False, True, False, False], - "is_choice": [False, False, True, True, True, True], - "is_discrete": [True, True, True, True, True, False], - "is_continuous": [False, True, False, False, False, True], - }, - ) - got = determine_discrete_choice_axes(variable_info) - assert got == (1, 2, 3) diff --git a/tests/test_discrete_problem.py b/tests/test_discrete_problem.py index 30e11a4..141d3c6 100644 --- a/tests/test_discrete_problem.py +++ b/tests/test_discrete_problem.py @@ -5,9 +5,10 @@ from lcm.discrete_problem import ( _calculate_emax_extreme_value_shocks, - _determine_discrete_choice_axes, + _determine_discrete_choice_axes_simulation, + _determine_discrete_choice_axes_solution, _solve_discrete_problem_no_shocks, - get_solve_discrete_problem, + get_solve_discrete_problem_value, ) from lcm.typing import ShockType @@ -27,7 +28,7 @@ def test_get_solve_discrete_problem_illustrative(): }, ) # leads to choice_axes = [1] - solve_discrete_problem = get_solve_discrete_problem( + solve_discrete_problem = get_solve_discrete_problem_value( random_utility_shock_type=ShockType.NONE, variable_info=variable_info, is_last_period=False, @@ -130,7 +131,7 @@ def test_determine_discrete_choice_axes_illustrative_one_var(): }, ) - assert _determine_discrete_choice_axes(variable_info) == (1,) + assert _determine_discrete_choice_axes_solution(variable_info) == (1,) @pytest.mark.illustrative @@ -144,4 +145,17 @@ def test_determine_discrete_choice_axes_illustrative_three_var(): }, ) - assert _determine_discrete_choice_axes(variable_info) == (1, 2, 3) + assert _determine_discrete_choice_axes_solution(variable_info) == (1, 2, 3) + + +def test_determine_discrete_choice_axes(): + variable_info = pd.DataFrame( + { + "is_state": [True, True, False, True, False, False], + "is_choice": [False, False, True, True, True, True], + "is_discrete": [True, True, True, True, True, False], + "is_continuous": [False, True, False, False, False, True], + }, + ) + got = _determine_discrete_choice_axes_simulation(variable_info) + assert got == (1, 2, 3) From 7db750937b4ff88ff1a4c687f847ad92985c77a3 Mon Sep 17 00:00:00 2001 From: Tim Mensinger Date: Thu, 27 Feb 2025 14:01:23 +0100 Subject: [PATCH 14/20] Implemenent solve_and_simulate function --- src/lcm/discrete_problem.py | 4 ++- src/lcm/entry_point.py | 22 +++++++++--- src/lcm/simulation/simulate.py | 56 ++++++++++++++++++++----------- tests/simulation/test_simulate.py | 12 +++---- tests/test_entry_point.py | 4 +-- tests/test_stochastic.py | 4 +-- 6 files changed, 67 insertions(+), 35 deletions(-) diff --git a/src/lcm/discrete_problem.py b/src/lcm/discrete_problem.py index b265dcd..fb24502 100644 --- a/src/lcm/discrete_problem.py +++ b/src/lcm/discrete_problem.py @@ -91,7 +91,9 @@ def get_solve_discrete_problem_policy( choice_axes = _determine_discrete_choice_axes_simulation(variable_info) def _calculate_discrete_argmax( - values: Array, choice_axes: tuple[int, ...] + values: Array, + choice_axes: tuple[int, ...], + params: ParamsDict, # noqa: ARG001 ) -> tuple[Array, Array]: return argmax(values, axis=choice_axes) diff --git a/src/lcm/entry_point.py b/src/lcm/entry_point.py index 997b69c..b8c2340 100644 --- a/src/lcm/entry_point.py +++ b/src/lcm/entry_point.py @@ -15,7 +15,7 @@ from lcm.interfaces import StateChoiceSpace, StateSpaceInfo from lcm.logging import get_logger from lcm.next_state import get_next_state_function -from lcm.simulation.simulate import simulate +from lcm.simulation.simulate import simulate, solve_and_simulate from lcm.solution.solve_brute import solve from lcm.solution.state_choice_space import create_state_choice_space from lcm.typing import DiscreteProblemValueSolverFunction, ParamsDict, Target @@ -141,18 +141,30 @@ def get_lcm_function( logger=logger, ) - solve_model = jax.jit(_solve_model) if jit else _solve_model - _next_state_simulate = get_next_state_function( model=internal_model, target=Target.SIMULATE ) + + solve_model = jax.jit(_solve_model) if jit else _solve_model + next_state_simulate = jax.jit(_next_state_simulate) if jit else _next_state_simulate + simulate_model = partial( simulate, continuous_choice_grids=continuous_choice_grids, compute_ccv_policy_functions=compute_ccp_functions, model=internal_model, - next_state=jax.jit(_next_state_simulate), + next_state=next_state_simulate, # type: ignore[arg-type] + logger=logger, + ) + + solve_and_simulate_model = partial( + solve_and_simulate, + continuous_choice_grids=continuous_choice_grids, + compute_ccv_policy_functions=compute_ccp_functions, + model=internal_model, + next_state=next_state_simulate, # type: ignore[arg-type] logger=logger, + solve_model=solve_model, ) target_func: Callable[..., list[Array] | pd.DataFrame] @@ -162,7 +174,7 @@ def get_lcm_function( elif targets == "simulate": target_func = simulate_model elif targets == "solve_and_simulate": - target_func = partial(simulate_model, solve_model=solve_model) + target_func = solve_and_simulate_model return target_func, internal_model.params diff --git a/src/lcm/simulation/simulate.py b/src/lcm/simulation/simulate.py index f25c407..126ec57 100644 --- a/src/lcm/simulation/simulate.py +++ b/src/lcm/simulation/simulate.py @@ -17,6 +17,39 @@ from lcm.utils import draw_random_seed +def solve_and_simulate( + params: ParamsDict, + initial_states: dict[str, Array], + continuous_choice_grids: dict[int, dict[str, Array]], + compute_ccv_policy_functions: dict[int, Callable[..., tuple[Array, Array]]], + model: InternalModel, + next_state: Callable[..., dict[str, Array]], + logger: logging.Logger, + solve_model: Callable[..., list[Array]], + *, + additional_targets: list[str] | None = None, + seed: int | None = None, +) -> pd.DataFrame: + """First solve the model and then simulate the model forward in time. + + Same docstring as `simulate` mutatis mutandis. + + """ + vf_arr_list = solve_model(params) + return simulate( + params=params, + initial_states=initial_states, + continuous_choice_grids=continuous_choice_grids, + compute_ccv_policy_functions=compute_ccv_policy_functions, + model=model, + next_state=next_state, + logger=logger, + vf_arr_list=vf_arr_list, + additional_targets=additional_targets, + seed=seed, + ) + + def simulate( params: ParamsDict, initial_states: dict[str, Array], @@ -25,12 +58,12 @@ def simulate( model: InternalModel, next_state: Callable[..., dict[str, Array]], logger: logging.Logger, - solve_model: Callable[..., list[Array]] | None = None, - pre_computed_vf_arr_list: list[Array] | None = None, + vf_arr_list: list[Array], + *, additional_targets: list[str] | None = None, seed: int | None = None, ) -> pd.DataFrame: - """Simulate the model forward in time. + """Simulate the model forward in time given pre-computed value function arrays. Args: params: Dict of model parameters. @@ -45,11 +78,7 @@ def simulate( draw from the distribution of the next state. model: Model instance. logger: Logger that logs to stdout. - solve_model: Function that solves the model. Is only required if - vf_arr_list is not provided. - pre_computed_vf_arr_list: List of value function arrays of length - n_periods. This is the output of the model's `solve` function. If not - provided, the model is solved first. + vf_arr_list: List of value function arrays of length n_periods. additional_targets: List of targets to compute. If provided, the targets are computed and added to the simulation results. seed: Random number seed; will be passed to `jax.random.key`. If not provided, @@ -59,17 +88,6 @@ def simulate( DataFrame with the simulation results. """ - if pre_computed_vf_arr_list is None: - if solve_model is None: - raise ValueError( - "You need to provide either vf_arr_list or solve_model.", - ) - # We do not need to convert the params here, because the solve_model function - # will do it. - vf_arr_list = solve_model(params) - else: - vf_arr_list = pre_computed_vf_arr_list - if seed is None: seed = draw_random_seed() diff --git a/tests/simulation/test_simulate.py b/tests/simulation/test_simulate.py index aa8a1bb..670a740 100644 --- a/tests/simulation/test_simulate.py +++ b/tests/simulation/test_simulate.py @@ -80,7 +80,7 @@ def test_simulate_using_raw_inputs(simulate_inputs): got = simulate( params=params, - pre_computed_vf_arr_list=[jnp.empty(0)], + vf_arr_list=[jnp.empty(0)], initial_states={"wealth": jnp.array([1.0, 50.400803])}, logger=get_logger(debug_mode=False), **simulate_inputs, @@ -130,7 +130,7 @@ def test_simulate_using_get_lcm_function( res: pd.DataFrame = simulate_model( # type: ignore[assignment] params, - pre_computed_vf_arr_list=vf_arr_list, + vf_arr_list=vf_arr_list, initial_states={ "wealth": jnp.array([20.0, 150, 250, 320]), }, @@ -206,13 +206,13 @@ def test_effect_of_beta_on_last_period(): res_low: pd.DataFrame = simulate_model( # type: ignore[assignment] params_low, - pre_computed_vf_arr_list=solution_low, + vf_arr_list=solution_low, initial_states={"wealth": initial_wealth}, ) res_high: pd.DataFrame = simulate_model( # type: ignore[assignment] params_high, - pre_computed_vf_arr_list=solution_high, + vf_arr_list=solution_high, initial_states={"wealth": initial_wealth}, ) @@ -250,13 +250,13 @@ def test_effect_of_disutility_of_work(): res_low: pd.DataFrame = simulate_model( # type: ignore[assignment] params_low, - pre_computed_vf_arr_list=solution_low, + vf_arr_list=solution_low, initial_states={"wealth": initial_wealth}, ) res_high: pd.DataFrame = simulate_model( # type: ignore[assignment] params_high, - pre_computed_vf_arr_list=solution_high, + vf_arr_list=solution_high, initial_states={"wealth": initial_wealth}, ) diff --git a/tests/test_entry_point.py b/tests/test_entry_point.py index 4e45e93..1a8d413 100644 --- a/tests/test_entry_point.py +++ b/tests/test_entry_point.py @@ -110,7 +110,7 @@ def test_get_lcm_function_with_simulation_is_coherent(model): solve_then_simulate = simulate_model( params, - pre_computed_vf_arr_list=vf_arr_list, + vf_arr_list=vf_arr_list, initial_states={ "wealth": jnp.array([0.0, 10.0, 50.0]), }, @@ -149,7 +149,7 @@ def test_get_lcm_function_with_simulation_target_iskhakov_et_al_2017(model): simulate_model( params, - pre_computed_vf_arr_list=vf_arr_list, + vf_arr_list=vf_arr_list, initial_states={ "wealth": jnp.array([10.0, 10.0, 20.0]), "lagged_retirement": jnp.array( diff --git a/tests/test_stochastic.py b/tests/test_stochastic.py index e6e6b4b..6dba1e4 100644 --- a/tests/test_stochastic.py +++ b/tests/test_stochastic.py @@ -135,12 +135,12 @@ def test_compare_deterministic_and_stochastic_results(model_and_params): simulation_deterministic = simulate_model_deterministic( params, - pre_computed_vf_arr_list=solution_deterministic, + vf_arr_list=solution_deterministic, initial_states=initial_states, ) simulation_stochastic = simulate_model_stochastic( params, - pre_computed_vf_arr_list=solution_stochastic, + vf_arr_list=solution_stochastic, initial_states=initial_states, ) pd.testing.assert_frame_equal(simulation_deterministic, simulation_stochastic) # type: ignore[arg-type] From db0a22c451ab46cad6d3c5531184779b2c66db50 Mon Sep 17 00:00:00 2001 From: Tim Mensinger Date: Thu, 27 Feb 2025 14:29:10 +0100 Subject: [PATCH 15/20] Use vf_arr_dict with period keys instead of vf_arr_list --- src/lcm/entry_point.py | 4 ++-- src/lcm/simulation/simulate.py | 23 +++++++---------------- src/lcm/solution/solve_brute.py | 11 +++++------ tests/simulation/test_simulate.py | 18 +++++++++--------- tests/solution/test_solve_brute.py | 2 +- tests/test_analytical_solution.py | 4 +++- tests/test_entry_point.py | 8 ++++---- tests/test_regression_test.py | 5 +++-- tests/test_stochastic.py | 17 +++++++++++------ 9 files changed, 45 insertions(+), 47 deletions(-) diff --git a/src/lcm/entry_point.py b/src/lcm/entry_point.py index b8c2340..afa4097 100644 --- a/src/lcm/entry_point.py +++ b/src/lcm/entry_point.py @@ -31,7 +31,7 @@ def get_lcm_function( targets: Literal["solve", "simulate", "solve_and_simulate"], debug_mode: bool = True, jit: bool = True, -) -> tuple[Callable[..., list[Array] | pd.DataFrame], ParamsDict]: +) -> tuple[Callable[..., dict[int, Array] | pd.DataFrame], ParamsDict]: """Entry point for users to get high level functions generated by lcm. Return the function to solve and/or simulate a model along with a template for the @@ -167,7 +167,7 @@ def get_lcm_function( solve_model=solve_model, ) - target_func: Callable[..., list[Array] | pd.DataFrame] + target_func: Callable[..., dict[int, Array] | pd.DataFrame] if targets == "solve": target_func = solve_model diff --git a/src/lcm/simulation/simulate.py b/src/lcm/simulation/simulate.py index 126ec57..f94219b 100644 --- a/src/lcm/simulation/simulate.py +++ b/src/lcm/simulation/simulate.py @@ -25,7 +25,7 @@ def solve_and_simulate( model: InternalModel, next_state: Callable[..., dict[str, Array]], logger: logging.Logger, - solve_model: Callable[..., list[Array]], + solve_model: Callable[..., dict[int, Array]], *, additional_targets: list[str] | None = None, seed: int | None = None, @@ -35,7 +35,7 @@ def solve_and_simulate( Same docstring as `simulate` mutatis mutandis. """ - vf_arr_list = solve_model(params) + vf_arr_dict = solve_model(params) return simulate( params=params, initial_states=initial_states, @@ -44,7 +44,7 @@ def solve_and_simulate( model=model, next_state=next_state, logger=logger, - vf_arr_list=vf_arr_list, + vf_arr_dict=vf_arr_dict, additional_targets=additional_targets, seed=seed, ) @@ -58,7 +58,7 @@ def simulate( model: InternalModel, next_state: Callable[..., dict[str, Array]], logger: logging.Logger, - vf_arr_list: list[Array], + vf_arr_dict: dict[int, Array], *, additional_targets: list[str] | None = None, seed: int | None = None, @@ -78,7 +78,7 @@ def simulate( draw from the distribution of the next state. model: Model instance. logger: Logger that logs to stdout. - vf_arr_list: List of value function arrays of length n_periods. + vf_arr_dict: Dict of value function arrays of length n_periods. additional_targets: List of targets to compute. If provided, the targets are computed and added to the simulation results. seed: Random number seed; will be passed to `jax.random.key`. If not provided, @@ -95,7 +95,7 @@ def simulate( # Preparations # ================================================================================== - n_periods = len(vf_arr_list) + n_periods = len(vf_arr_dict) n_initial_states = len(next(iter(initial_states.values()))) data_scs = create_state_choice_space( @@ -103,15 +103,6 @@ def simulate( initial_states=initial_states, ) - # We drop the value function array for the first period, because it is not needed - # for the simulation. This is because in the first period the agents only consider - # the current utility and the value function of next period. Similarly, the last - # value function array is not required, as the agents only consider the current - # utility in the last period. - next_vf_arr = dict( - zip(range(n_periods), vf_arr_list[1:] + [jnp.empty(0)], strict=True) - ) - discrete_policy_calculator = get_solve_discrete_problem_policy( variable_info=model.variable_info, ) @@ -147,7 +138,7 @@ def simulate( data_scs=data_scs, compute_ccv=compute_ccv_policy_functions[period], continuous_choice_grids=continuous_choice_grids[period], - vf_arr=next_vf_arr[period], + vf_arr=vf_arr_dict.get(period + 1, jnp.empty(0)), params=params, ) diff --git a/src/lcm/solution/solve_brute.py b/src/lcm/solution/solve_brute.py index 6acee7a..2cbd176 100644 --- a/src/lcm/solution/solve_brute.py +++ b/src/lcm/solution/solve_brute.py @@ -16,7 +16,7 @@ def solve( compute_ccv_functions: dict[int, Callable[[Array, Array], Array]], emax_calculators: dict[int, DiscreteProblemValueSolverFunction], logger: logging.Logger, -) -> list[Array]: +) -> dict[int, Array]: """Solve a model by brute force. Notes: @@ -45,12 +45,11 @@ def solve( logger: Logger that logs to stdout. Returns: - List with one value function array per period. + Dict with one value function array per period. """ - # extract information n_periods = len(state_choice_spaces) - reversed_solution = [] + solution = {} vf_arr = None logger.info("Starting solution") @@ -69,11 +68,11 @@ def solve( # solve discrete problem by calculating expected maximum over discrete choices calculate_emax = emax_calculators[period] vf_arr = calculate_emax(conditional_continuation_values, params=params) - reversed_solution.append(vf_arr) + solution[period] = vf_arr logger.info("Period: %s", period) - return list(reversed(reversed_solution)) + return solution def solve_continuous_problem( diff --git a/tests/simulation/test_simulate.py b/tests/simulation/test_simulate.py index 670a740..b0f9370 100644 --- a/tests/simulation/test_simulate.py +++ b/tests/simulation/test_simulate.py @@ -80,7 +80,7 @@ def test_simulate_using_raw_inputs(simulate_inputs): got = simulate( params=params, - vf_arr_list=[jnp.empty(0)], + vf_arr_dict={0: jnp.empty(0)}, initial_states={"wealth": jnp.array([1.0, 50.400803])}, logger=get_logger(debug_mode=False), **simulate_inputs, @@ -112,8 +112,8 @@ def _model_solution(n_periods): solve_model, _ = get_lcm_function(model_config, targets="solve") params = get_params() - vf_arr_list = solve_model(params=params) - return vf_arr_list, params, model_config + vf_arr_dict = solve_model(params=params) + return vf_arr_dict, params, model_config return _model_solution @@ -122,7 +122,7 @@ def test_simulate_using_get_lcm_function( iskhakov_et_al_2017_stripped_down_model_solution, ): n_periods = 3 - vf_arr_list, params, model = iskhakov_et_al_2017_stripped_down_model_solution( + vf_arr_dict, params, model = iskhakov_et_al_2017_stripped_down_model_solution( n_periods=n_periods, ) @@ -130,7 +130,7 @@ def test_simulate_using_get_lcm_function( res: pd.DataFrame = simulate_model( # type: ignore[assignment] params, - vf_arr_list=vf_arr_list, + vf_arr_dict=vf_arr_dict, initial_states={ "wealth": jnp.array([20.0, 150, 250, 320]), }, @@ -206,13 +206,13 @@ def test_effect_of_beta_on_last_period(): res_low: pd.DataFrame = simulate_model( # type: ignore[assignment] params_low, - vf_arr_list=solution_low, + vf_arr_dict=solution_low, initial_states={"wealth": initial_wealth}, ) res_high: pd.DataFrame = simulate_model( # type: ignore[assignment] params_high, - vf_arr_list=solution_high, + vf_arr_dict=solution_high, initial_states={"wealth": initial_wealth}, ) @@ -250,13 +250,13 @@ def test_effect_of_disutility_of_work(): res_low: pd.DataFrame = simulate_model( # type: ignore[assignment] params_low, - vf_arr_list=solution_low, + vf_arr_dict=solution_low, initial_states={"wealth": initial_wealth}, ) res_high: pd.DataFrame = simulate_model( # type: ignore[assignment] params_high, - vf_arr_list=solution_high, + vf_arr_dict=solution_high, initial_states={"wealth": initial_wealth}, ) diff --git a/tests/solution/test_solve_brute.py b/tests/solution/test_solve_brute.py index 11ce8fb..80532d9 100644 --- a/tests/solution/test_solve_brute.py +++ b/tests/solution/test_solve_brute.py @@ -109,7 +109,7 @@ def calculate_emax(values, params): # noqa: ARG001 logger=get_logger(debug_mode=False), ) - assert isinstance(solution, list) + assert isinstance(solution, dict) def test_solve_continuous_problem_no_vf_arr(): diff --git a/tests/test_analytical_solution.py b/tests/test_analytical_solution.py index c8a0fa1..7c20c46 100644 --- a/tests/test_analytical_solution.py +++ b/tests/test_analytical_solution.py @@ -68,7 +68,9 @@ def test_analytical_solution(model_name, model_and_params): # ================================================================================== solve_model, _ = get_lcm_function(model=model_and_params["model"], targets="solve") - vf_arr_list: list[Array] = solve_model(params=model_and_params["params"]) # type: ignore[assignment] + vf_arr_dict: dict[int, Array] = solve_model(params=model_and_params["params"]) # type: ignore[assignment] + vf_arr_list = list(dict(sorted(vf_arr_dict.items(), key=lambda x: x[0])).values()) + _numerical = np.stack(vf_arr_list) numerical = { "worker": _numerical[:, 0, :], diff --git a/tests/test_entry_point.py b/tests/test_entry_point.py index 1a8d413..0bee350 100644 --- a/tests/test_entry_point.py +++ b/tests/test_entry_point.py @@ -103,14 +103,14 @@ def test_get_lcm_function_with_simulation_is_coherent(model): # solve solve_model, params_template = get_lcm_function(model=model, targets="solve") params = tree_map(lambda _: 0.2, params_template) - vf_arr_list = solve_model(params) + vf_arr_dict = solve_model(params) # simulate using solution simulate_model, _ = get_lcm_function(model=model, targets="simulate") solve_then_simulate = simulate_model( params, - vf_arr_list=vf_arr_list, + vf_arr_dict=vf_arr_dict, initial_states={ "wealth": jnp.array([0.0, 10.0, 50.0]), }, @@ -142,14 +142,14 @@ def test_get_lcm_function_with_simulation_target_iskhakov_et_al_2017(model): # solve model solve_model, params_template = get_lcm_function(model=model, targets="solve") params = tree_map(lambda _: 0.2, params_template) - vf_arr_list = solve_model(params) + vf_arr_dict = solve_model(params) # simulate using solution simulate_model, _ = get_lcm_function(model=model, targets="simulate") simulate_model( params, - vf_arr_list=vf_arr_list, + vf_arr_dict=vf_arr_dict, initial_states={ "wealth": jnp.array([10.0, 10.0, 20.0]), "lagged_retirement": jnp.array( diff --git a/tests/test_regression_test.py b/tests/test_regression_test.py index ce77c7a..295fff9 100644 --- a/tests/test_regression_test.py +++ b/tests/test_regression_test.py @@ -1,5 +1,6 @@ import jax.numpy as jnp import pandas as pd +from jax import Array from numpy.testing import assert_array_almost_equal as aaae from pandas.testing import assert_frame_equal @@ -31,7 +32,7 @@ def test_regression_test(): disutility_of_work=1.0, interest_rate=0.05, ) - got_solve = solve(params) + got_solve: dict[int, Array] = solve(params) # type: ignore[assignment] solve_and_simulate, _ = get_lcm_function( model=model_config, @@ -47,5 +48,5 @@ def test_regression_test(): # Compare # ================================================================================== - aaae(expected_solve, got_solve, decimal=5) + aaae(expected_solve, list(got_solve.values()), decimal=5) assert_frame_equal(expected_simulate, got_simulate) # type: ignore[arg-type] diff --git a/tests/test_stochastic.py b/tests/test_stochastic.py index 6dba1e4..581188f 100644 --- a/tests/test_stochastic.py +++ b/tests/test_stochastic.py @@ -1,6 +1,7 @@ import jax.numpy as jnp import pandas as pd import pytest +from jax import Array import lcm from lcm.entry_point import ( @@ -94,7 +95,7 @@ def next_health_deterministic(health): return model_deterministic, model_stochastic, params -def test_compare_deterministic_and_stochastic_results(model_and_params): +def test_compare_deterministic_and_stochastic_results_value_function(model_and_params): """Test that the deterministic and stochastic models produce the same results.""" model_deterministic, model_stochastic, params = model_and_params @@ -110,10 +111,14 @@ def test_compare_deterministic_and_stochastic_results(model_and_params): targets="solve", ) - solution_deterministic = solve_model_deterministic(params) - solution_stochastic = solve_model_stochastic(params) + solution_deterministic: dict[int, Array] = solve_model_deterministic(params) # type: ignore[assignment] + solution_stochastic: dict[int, Array] = solve_model_stochastic(params) # type: ignore[assignment] - assert jnp.array_equal(solution_deterministic, solution_stochastic, equal_nan=True) # type: ignore[arg-type] + assert jnp.array_equal( + jnp.array(list(solution_deterministic.values())), + jnp.array(list(solution_stochastic.values())), + equal_nan=True, + ) # ================================================================================== # Compare simulation results @@ -135,12 +140,12 @@ def test_compare_deterministic_and_stochastic_results(model_and_params): simulation_deterministic = simulate_model_deterministic( params, - vf_arr_list=solution_deterministic, + vf_arr_dict=solution_deterministic, initial_states=initial_states, ) simulation_stochastic = simulate_model_stochastic( params, - vf_arr_list=solution_stochastic, + vf_arr_dict=solution_stochastic, initial_states=initial_states, ) pd.testing.assert_frame_equal(simulation_deterministic, simulation_stochastic) # type: ignore[arg-type] From fe9e1cd1087c58d181bba8af4cbc916bcd456dea Mon Sep 17 00:00:00 2001 From: Tim Mensinger Date: Thu, 27 Feb 2025 15:34:01 +0100 Subject: [PATCH 16/20] Refactor simulation --- src/lcm/interfaces.py | 9 ++ src/lcm/simulation/processing.py | 11 +-- src/lcm/simulation/simulate.py | 144 +++++++++++++--------------- tests/simulation/test_processing.py | 25 ++--- tests/simulation/test_simulate.py | 12 +-- 5 files changed, 101 insertions(+), 100 deletions(-) diff --git a/src/lcm/interfaces.py b/src/lcm/interfaces.py index a6890d5..2eab3db 100644 --- a/src/lcm/interfaces.py +++ b/src/lcm/interfaces.py @@ -116,3 +116,12 @@ class InternalModel: n_periods: int # Not properly processed yet random_utility_shocks: ShockType + + +@dataclass(frozen=True) +class InternalSimulationPeriodResults: + """The results of a simulation for one period.""" + + value: Array + choices: dict[str, Array] + states: dict[str, Array] diff --git a/src/lcm/simulation/processing.py b/src/lcm/simulation/processing.py index e0d4b8a..a14c224 100644 --- a/src/lcm/simulation/processing.py +++ b/src/lcm/simulation/processing.py @@ -1,5 +1,4 @@ import inspect -from typing import Any import jax.numpy as jnp import pandas as pd @@ -7,12 +6,12 @@ from jax import Array from lcm.dispatchers import vmap_1d -from lcm.interfaces import InternalModel +from lcm.interfaces import InternalModel, InternalSimulationPeriodResults from lcm.typing import InternalUserFunction, ParamsDict def process_simulated_data( - results: list[dict[str, Any]], + results: dict[int, InternalSimulationPeriodResults], model: InternalModel, params: ParamsDict, additional_targets: list[str] | None = None, @@ -26,7 +25,7 @@ def process_simulated_data( an outer level of periods and an inner level of initial states ids. Args: - results: List of dicts with simulation results. Each dict contains the value, + results: Dict with simulation results. Each dict contains the value, choices, and states for one period. Choices and states are stored in a nested dictionary. model: Model. @@ -40,10 +39,10 @@ def process_simulated_data( """ n_periods = len(results) - n_initial_states = len(results[0]["value"]) + n_initial_states = len(results[0].value) list_of_dicts = [ - {"value": d["value"], **d["choices"], **d["states"]} for d in results + {"value": d.value, **d.choices, **d.states} for d in results.values() ] dict_of_lists = { key: [d[key] for d in list_of_dicts] for key in list(list_of_dicts[0]) diff --git a/src/lcm/simulation/simulate.py b/src/lcm/simulation/simulate.py index f94219b..1398855 100644 --- a/src/lcm/simulation/simulate.py +++ b/src/lcm/simulation/simulate.py @@ -9,7 +9,11 @@ from lcm.discrete_problem import get_solve_discrete_problem_policy from lcm.dispatchers import simulation_spacemap, vmap_1d -from lcm.interfaces import InternalModel, StateChoiceSpace +from lcm.interfaces import ( + InternalModel, + InternalSimulationPeriodResults, + StateChoiceSpace, +) from lcm.random import generate_simulation_keys from lcm.simulation.processing import as_data_frame, process_simulated_data from lcm.simulation.state_choice_space import create_state_choice_space @@ -94,17 +98,17 @@ def simulate( logger.info("Starting simulation") # Preparations - # ================================================================================== + # ---------------------------------------------------------------------------------- n_periods = len(vf_arr_dict) n_initial_states = len(next(iter(initial_states.values()))) - data_scs = create_state_choice_space( + state_choice_space = create_state_choice_space( model=model, initial_states=initial_states, ) discrete_policy_calculator = get_solve_discrete_problem_policy( - variable_info=model.variable_info, + variable_info=model.variable_info ) # The following variables are updated during the forward simulation @@ -112,80 +116,77 @@ def simulate( key = jax.random.key(seed=seed) # Forward simulation - # ================================================================================== - _simulation_results = [] + # ---------------------------------------------------------------------------------- + simulation_results = {} for period in range(n_periods): - # Create data state choice space - # ------------------------------------------------------------------------------ - # Initial states are treated as combination variables, so that the combination - # variables in the data-state-choice-space correspond to the feasible product - # of combination variables and initial states. The space has to be created in - # each iteration because the states change over time. - # ============================================================================== - data_scs = data_scs.replace(states) - - # Compute objects dependent on data-state-choice-space - # ============================================================================== - choices_grid_shape = tuple(len(grid) for grid in data_scs.choices.values()) - cont_choices_grid_shape = tuple( + state_choice_space = state_choice_space.replace(states) + + # We compute these grid shapes in the loop because they can change over time. + # TODO (@timmens): This could still be pre-computed in the entry point. # noqa: TD003,E501 + discrete_choices_grid_shape = tuple( + len(grid) for grid in state_choice_space.choices.values() + ) + continuous_choices_grid_shape = tuple( len(grid) for grid in continuous_choice_grids[period].values() ) # Compute optimal continuous choice conditional on discrete choices - # ============================================================================== - ccv_policy, ccv = solve_continuous_problem( - data_scs=data_scs, - compute_ccv=compute_ccv_policy_functions[period], - continuous_choice_grids=continuous_choice_grids[period], - vf_arr=vf_arr_dict.get(period + 1, jnp.empty(0)), - params=params, + # ------------------------------------------------------------------------------ + # We need to pass the value function array of the next period to the continuous + # choice problem solver. If we are at the last period, we pass an empty array. + next_period_vf_arr = vf_arr_dict.get(period + 1, jnp.empty(0)) + + conditional_continuous_choice_argmax, conditional_continuous_choice_max = ( + solve_continuous_problem( + data_scs=state_choice_space, + compute_ccv=compute_ccv_policy_functions[period], + continuous_choice_grids=continuous_choice_grids[period], + vf_arr=next_period_vf_arr, + params=params, + ) ) # Get optimal discrete choice given the optimal conditional continuous choices - # ============================================================================== - discrete_argmax, value = discrete_policy_calculator(ccv, params=params) + # ------------------------------------------------------------------------------ + discrete_argmax, choice_value = discrete_policy_calculator( + conditional_continuous_choice_max, params=params + ) - # Select optimal continuous choice corresponding to optimal discrete choice + # Get optimal continuous choice index given optimal discrete choice # ------------------------------------------------------------------------------ - # The conditional continuous choice argmax is computed for each discrete choice - # in the data-state-choice-space. Here we select the the optimal continuous - # choice corresponding to the optimal discrete choice. - # ============================================================================== - cont_choice_argmax = filter_ccv_policy( - ccv_policy=ccv_policy, + continuous_choice_argmax = get_continuous_choice_argmax_given_discrete( + conditional_continuous_choice_argmax=conditional_continuous_choice_argmax, discrete_argmax=discrete_argmax, - vars_grid_shape=choices_grid_shape, + discrete_choices_grid_shape=discrete_choices_grid_shape, ) - # Convert optimal choice indices to actual choice values - # ============================================================================== - choices = retrieve_choices( + # Convert choice indices to choice values + # ------------------------------------------------------------------------------ + discrete_choices = get_values_from_indices( flat_indices=discrete_argmax, - grids=data_scs.choices, - grids_shapes=choices_grid_shape, + grids=state_choice_space.choices, + grids_shapes=discrete_choices_grid_shape, ) - cont_choices = retrieve_choices( - flat_indices=cont_choice_argmax, + continuous_choices = get_values_from_indices( + flat_indices=continuous_choice_argmax, grids=continuous_choice_grids[period], - grids_shapes=cont_choices_grid_shape, + grids_shapes=continuous_choices_grid_shape, ) # Store results - # ============================================================================== - choices = {**choices, **cont_choices} - - _simulation_results.append( - { - "value": value, - "choices": choices, - "states": states, - }, + # ------------------------------------------------------------------------------ + choices = {**discrete_choices, **continuous_choices} + + simulation_results[period] = InternalSimulationPeriodResults( + value=choice_value, + choices=choices, + states=states, ) # Update states - # ============================================================================== + # ------------------------------------------------------------------------------ key, stochastic_variables_keys = generate_simulation_keys( key=key, ids=model.function_info.query("is_stochastic_next").index.tolist(), @@ -205,7 +206,7 @@ def simulate( logger.info("Period: %s", period) processed = process_simulated_data( - _simulation_results, + simulation_results, model=model, params=params, additional_targets=additional_targets, @@ -221,7 +222,7 @@ def solve_continuous_problem( vf_arr: Array, params: ParamsDict, ) -> tuple[Array, Array]: - """Solve the agent's continuous choices problem problem. + """Solve the agents' continuous choice problem. Args: data_scs: Class with entries choices and states. @@ -261,43 +262,34 @@ def solve_continuous_problem( ) -# ====================================================================================== -# Filter policy -# ====================================================================================== - - -@partial(vmap_1d, variables=("ccv_policy", "discrete_argmax")) -def filter_ccv_policy( - ccv_policy: Array, +@partial(vmap_1d, variables=("conditional_continuous_choice_argmax", "discrete_argmax")) +def get_continuous_choice_argmax_given_discrete( + conditional_continuous_choice_argmax: Array, discrete_argmax: Array, - vars_grid_shape: tuple[int, ...], + discrete_choices_grid_shape: tuple[int, ...], ) -> Array: """Select optimal continuous choice index given optimal discrete choice. Args: - ccv_policy: Index array of optimal continous choices + conditional_continuous_choice_argmax: Index array of optimal continous choices conditional on discrete choices. discrete_argmax: Index array of optimal discrete choices. - vars_grid_shape: Shape of the variables grid. + discrete_choices_grid_shape: Shape of the discrete choices grid. Returns: Index array of optimal continuous choices. """ - if discrete_argmax is None: - out = ccv_policy - else: - indices = jnp.unravel_index(discrete_argmax, shape=vars_grid_shape) - out = ccv_policy[indices] - return out + indices = jnp.unravel_index(discrete_argmax, shape=discrete_choices_grid_shape) + return conditional_continuous_choice_argmax[indices] -def retrieve_choices( +def get_values_from_indices( flat_indices: Array, grids: dict[str, Array], grids_shapes: tuple[int, ...], ) -> dict[str, Array]: - """Retrieve choices given flat indices. + """Retrieve values from indices. Args: flat_indices: General indices. Represents the index of the flattened grid. @@ -305,7 +297,7 @@ def retrieve_choices( grids_shapes: Shape of the grids. Is used to unravel the index. Returns: - Dictionary of choices. + Dictionary of values. """ nd_indices = vmapped_unravel_index(flat_indices, grids_shapes) diff --git a/tests/simulation/test_processing.py b/tests/simulation/test_processing.py index 3958163..0bd14a4 100644 --- a/tests/simulation/test_processing.py +++ b/tests/simulation/test_processing.py @@ -2,6 +2,7 @@ import pandas as pd from pybaum import tree_equal +from lcm.interfaces import InternalSimulationPeriodResults from lcm.simulation.processing import ( _compute_targets, as_data_frame, @@ -58,24 +59,24 @@ def test_as_data_frame(): def test_process_simulated_data(): - simulated = [ - { - "value": jnp.array([0.1, 0.2]), - "states": {"a": jnp.array([1, 2]), "b": jnp.array([-1, -2])}, - "choices": {"c": jnp.array([5, 6]), "d": jnp.array([-5, -6])}, - }, - { - "value": jnp.array([0.3, 0.4]), - "states": { + simulated = { + 0: InternalSimulationPeriodResults( + value=jnp.array([0.1, 0.2]), + states={"a": jnp.array([1, 2]), "b": jnp.array([-1, -2])}, + choices={"c": jnp.array([5, 6]), "d": jnp.array([-5, -6])}, + ), + 1: InternalSimulationPeriodResults( + value=jnp.array([0.3, 0.4]), + states={ "b": jnp.array([-3, -4]), "a": jnp.array([3, 4]), }, - "choices": { + choices={ "d": jnp.array([-7, -8]), "c": jnp.array([7, 8]), }, - }, - ] + ), + } expected = { "value": jnp.array([0.1, 0.2, 0.3, 0.4]), "c": jnp.array([5, 6, 7, 8]), diff --git a/tests/simulation/test_simulate.py b/tests/simulation/test_simulate.py index b0f9370..fbcc39e 100644 --- a/tests/simulation/test_simulate.py +++ b/tests/simulation/test_simulate.py @@ -12,8 +12,8 @@ from lcm.logging import get_logger from lcm.next_state import get_next_state_function from lcm.simulation.simulate import ( - filter_ccv_policy, - retrieve_choices, + get_continuous_choice_argmax_given_discrete, + get_values_from_indices, simulate, ) from lcm.solution.state_choice_space import create_state_choice_space @@ -281,7 +281,7 @@ def test_effect_of_disutility_of_work(): def test_retrieve_choices(): - got = retrieve_choices( + got = get_values_from_indices( flat_indices=jnp.array([0, 3, 7]), grids={"a": jnp.linspace(0, 1, 5), "b": jnp.linspace(10, 20, 6)}, grids_shapes=(5, 6), @@ -299,9 +299,9 @@ def test_filter_ccv_policy(): ) argmax = jnp.array([0, 1]) vars_grid_shape = (2,) - got = filter_ccv_policy( - ccv_policy=ccc_policy, + got = get_continuous_choice_argmax_given_discrete( + conditional_continuous_choice_argmax=ccc_policy, discrete_argmax=argmax, - vars_grid_shape=vars_grid_shape, + discrete_choices_grid_shape=vars_grid_shape, ) assert jnp.all(got == jnp.array([0, 0])) From 2f1b08cc74e403cda08f7ea5fdb2fa4b987260d0 Mon Sep 17 00:00:00 2001 From: Tim Mensinger Date: Thu, 27 Feb 2025 18:30:13 +0100 Subject: [PATCH 17/20] Add continuous choices to state-choice-space --- src/lcm/entry_point.py | 20 ++---------- src/lcm/interfaces.py | 35 ++++++++++++++++----- src/lcm/simulation/simulate.py | 23 +++++--------- src/lcm/simulation/state_choice_space.py | 12 +++++-- src/lcm/solution/solve_brute.py | 11 ++----- src/lcm/solution/state_choice_space.py | 14 ++++++--- src/lcm/utils.py | 19 +++++++++++ tests/simulation/test_simulate.py | 8 ++--- tests/simulation/test_state_choice_space.py | 2 +- tests/solution/test_solve_brute.py | 23 +++++--------- tests/solution/test_state_space.py | 3 +- 11 files changed, 89 insertions(+), 81 deletions(-) diff --git a/src/lcm/entry_point.py b/src/lcm/entry_point.py index afa4097..b4178c5 100644 --- a/src/lcm/entry_point.py +++ b/src/lcm/entry_point.py @@ -65,19 +65,6 @@ def get_lcm_function( logger = get_logger(debug_mode=debug_mode) - # ================================================================================== - # Create continuous choice grids - # ---------------------------------------------------------------------------------- - # For now they are the same in all periods but this can change. - # ================================================================================== - continuous_choices_names = internal_model.variable_info.query( - "is_continuous & is_choice" - ).index.tolist() - _choice_grids = {n: internal_model.grids[n] for n in continuous_choices_names} - continuous_choice_grids = { - period: _choice_grids for period in range(internal_model.n_periods) - } - # ================================================================================== # Create model functions and state-choice-spaces # ================================================================================== @@ -109,12 +96,12 @@ def get_lcm_function( compute_ccv = get_compute_conditional_continuation_value( utility_and_feasibility=u_and_f, - continuous_choice_variables=tuple(_choice_grids), + continuous_choice_variables=tuple(state_choice_space.continuous_choices), ) compute_ccp = get_compute_conditional_continuation_policy( utility_and_feasibility=u_and_f, - continuous_choice_variables=tuple(_choice_grids), + continuous_choice_variables=tuple(state_choice_space.continuous_choices), ) solve_discrete_problem = get_solve_discrete_problem_value( @@ -135,7 +122,6 @@ def get_lcm_function( _solve_model = partial( solve, state_choice_spaces=state_choice_spaces, - continuous_choice_grids=continuous_choice_grids, compute_ccv_functions=compute_ccv_functions, emax_calculators=solve_discrete_problem_functions, logger=logger, @@ -150,7 +136,6 @@ def get_lcm_function( simulate_model = partial( simulate, - continuous_choice_grids=continuous_choice_grids, compute_ccv_policy_functions=compute_ccp_functions, model=internal_model, next_state=next_state_simulate, # type: ignore[arg-type] @@ -159,7 +144,6 @@ def get_lcm_function( solve_and_simulate_model = partial( solve_and_simulate, - continuous_choice_grids=continuous_choice_grids, compute_ccv_policy_functions=compute_ccp_functions, model=internal_model, next_state=next_state_simulate, # type: ignore[arg-type] diff --git a/src/lcm/interfaces.py b/src/lcm/interfaces.py index 2eab3db..51f2c60 100644 --- a/src/lcm/interfaces.py +++ b/src/lcm/interfaces.py @@ -7,6 +7,7 @@ from lcm.grids import ContinuousGrid, DiscreteGrid, Grid from lcm.typing import InternalUserFunction, ParamsDict, ShockType +from lcm.utils import first_non_none @dataclass(frozen=True) @@ -25,37 +26,55 @@ class StateChoiceSpace: The state-choice space becomes the product of state-combinations with the full Cartesian product of the choice variables. + Note: + ----- + We store discrete and continuous choices separately since these are handled during + different stages of the solution and simulation processes. + Attributes: states: Dictionary containing the values of the state variables. - choices: Dictionary containing the values of the choice variables. + discrete_choices: Dictionary containing the values of the discrete choice + variables. + continuous_choices: Dictionary containing the values of the continuous choice + variables. ordered_var_names: Tuple with names of state and choice variables in the order they appear in the variable info table. """ states: dict[str, Array] - choices: dict[str, Array] + discrete_choices: dict[str, Array] + continuous_choices: dict[str, Array] ordered_var_names: tuple[str, ...] def replace( self, states: dict[str, Array] | None = None, - choices: dict[str, Array] | None = None, + discrete_choices: dict[str, Array] | None = None, + continuous_choices: dict[str, Array] | None = None, ) -> "StateChoiceSpace": """Replace the states or choices in the state-choice space. Args: states: Dictionary with new states. If None, the existing states are used. - choices: Dictionary with new choices. If None, the existing choices are - used. + discrete_choices: Dictionary with new discrete choices. If None, the + existing discrete choices are used. + continuous_choices: Dictionary with new continuous choices. If None, the + existing continuous choices are used. Returns: New state-choice space with the replaced states or choices. """ - states = states if states is not None else self.states - choices = choices if choices is not None else self.choices - return dc.replace(self, states=states, choices=choices) + states = first_non_none(states, self.states) + discrete_choices = first_non_none(discrete_choices, self.discrete_choices) + continuous_choices = first_non_none(continuous_choices, self.continuous_choices) + return dc.replace( + self, + states=states, + discrete_choices=discrete_choices, + continuous_choices=continuous_choices, + ) @dataclass(frozen=True) diff --git a/src/lcm/simulation/simulate.py b/src/lcm/simulation/simulate.py index 1398855..797264e 100644 --- a/src/lcm/simulation/simulate.py +++ b/src/lcm/simulation/simulate.py @@ -24,7 +24,6 @@ def solve_and_simulate( params: ParamsDict, initial_states: dict[str, Array], - continuous_choice_grids: dict[int, dict[str, Array]], compute_ccv_policy_functions: dict[int, Callable[..., tuple[Array, Array]]], model: InternalModel, next_state: Callable[..., dict[str, Array]], @@ -43,7 +42,6 @@ def solve_and_simulate( return simulate( params=params, initial_states=initial_states, - continuous_choice_grids=continuous_choice_grids, compute_ccv_policy_functions=compute_ccv_policy_functions, model=model, next_state=next_state, @@ -57,7 +55,6 @@ def solve_and_simulate( def simulate( params: ParamsDict, initial_states: dict[str, Array], - continuous_choice_grids: dict[int, dict[str, Array]], compute_ccv_policy_functions: dict[int, Callable[..., tuple[Array, Array]]], model: InternalModel, next_state: Callable[..., dict[str, Array]], @@ -73,8 +70,6 @@ def simulate( params: Dict of model parameters. initial_states: List of initial states to start from. Typically from the observed dataset. - continuous_choice_grids: Dict of length n_periods. Each dict contains 1d grids - for continuous choice variables. compute_ccv_policy_functions: Dict of length n_periods. Each function computes the conditional continuation value dependent on the discrete choices. next_state: Function that returns the next state given the current @@ -125,10 +120,10 @@ def simulate( # We compute these grid shapes in the loop because they can change over time. # TODO (@timmens): This could still be pre-computed in the entry point. # noqa: TD003,E501 discrete_choices_grid_shape = tuple( - len(grid) for grid in state_choice_space.choices.values() + len(grid) for grid in state_choice_space.discrete_choices.values() ) continuous_choices_grid_shape = tuple( - len(grid) for grid in continuous_choice_grids[period].values() + len(grid) for grid in state_choice_space.continuous_choices.values() ) # Compute optimal continuous choice conditional on discrete choices @@ -141,7 +136,6 @@ def simulate( solve_continuous_problem( data_scs=state_choice_space, compute_ccv=compute_ccv_policy_functions[period], - continuous_choice_grids=continuous_choice_grids[period], vf_arr=next_period_vf_arr, params=params, ) @@ -165,13 +159,13 @@ def simulate( # ------------------------------------------------------------------------------ discrete_choices = get_values_from_indices( flat_indices=discrete_argmax, - grids=state_choice_space.choices, + grids=state_choice_space.discrete_choices, grids_shapes=discrete_choices_grid_shape, ) continuous_choices = get_values_from_indices( flat_indices=continuous_choice_argmax, - grids=continuous_choice_grids[period], + grids=state_choice_space.continuous_choices, grids_shapes=continuous_choices_grid_shape, ) @@ -218,7 +212,6 @@ def simulate( def solve_continuous_problem( data_scs: StateChoiceSpace, compute_ccv: Callable[..., tuple[Array, Array]], - continuous_choice_grids: dict[str, Array], vf_arr: Array, params: ParamsDict, ) -> tuple[Array, Array]: @@ -233,8 +226,6 @@ def solve_continuous_problem( - discrete and continuous choice variables - vf_arr - params - continuous_choice_grids: List of dicts with 1d grids for continuous - choice variables. vf_arr: Value function array. params: Dict of model parameters. @@ -248,15 +239,15 @@ def solve_continuous_problem( """ _gridmapped = simulation_spacemap( func=compute_ccv, - choices_var_names=tuple(data_scs.choices), + choices_var_names=tuple(data_scs.discrete_choices), states_var_names=tuple(data_scs.states), ) gridmapped = jax.jit(_gridmapped) return gridmapped( - **data_scs.choices, **data_scs.states, - **continuous_choice_grids, + **data_scs.discrete_choices, + **data_scs.continuous_choices, vf_arr=vf_arr, params=params, ) diff --git a/src/lcm/simulation/state_choice_space.py b/src/lcm/simulation/state_choice_space.py index eeb9f70..e81e0ac 100644 --- a/src/lcm/simulation/state_choice_space.py +++ b/src/lcm/simulation/state_choice_space.py @@ -36,16 +36,22 @@ def create_state_choice_space( ) ordered_var_names = tuple(vi.query("is_state | is_discrete").index) - discrete_choice_names = vi.query("is_choice & is_discrete").index discrete_choices = { name: grid for name, grid in model.grids.items() - if name in discrete_choice_names + if name in vi.query("is_choice & is_discrete").index + } + + continuous_choices = { + name: grid + for name, grid in model.grids.items() + if name in vi.query("is_choice & is_continuous").index } return StateChoiceSpace( states=initial_states, - choices=discrete_choices, + discrete_choices=discrete_choices, + continuous_choices=continuous_choices, ordered_var_names=ordered_var_names, ) diff --git a/src/lcm/solution/solve_brute.py b/src/lcm/solution/solve_brute.py index 2cbd176..9b27c8a 100644 --- a/src/lcm/solution/solve_brute.py +++ b/src/lcm/solution/solve_brute.py @@ -12,7 +12,6 @@ def solve( params: ParamsDict, state_choice_spaces: dict[int, StateChoiceSpace], - continuous_choice_grids: dict[int, dict[str, Array]], compute_ccv_functions: dict[int, Callable[[Array, Array], Array]], emax_calculators: dict[int, DiscreteProblemValueSolverFunction], logger: logging.Logger, @@ -31,8 +30,6 @@ def solve( Args: params: Dict of model parameters. state_choice_spaces: Dict with one state_choice_space per period. - continuous_choice_grids: Dict with one dict of 1d grids for continuous - choice variables per period. compute_ccv_functions: Dict with one function needed to solve the agent's problem. Each function depends on: - discrete and continuous state variables @@ -60,7 +57,6 @@ def solve( conditional_continuation_values = solve_continuous_problem( state_choice_space=state_choice_spaces[period], compute_ccv=compute_ccv_functions[period], - continuous_choice_grids=continuous_choice_grids[period], vf_arr=vf_arr, params=params, ) @@ -78,7 +74,6 @@ def solve( def solve_continuous_problem( state_choice_space: StateChoiceSpace, compute_ccv: Callable[..., Array], - continuous_choice_grids: dict[str, Array], vf_arr: Array | None, params: ParamsDict, ) -> Array: @@ -93,8 +88,6 @@ def solve_continuous_problem( - discrete and continuous choice variables - vf_arr - params - continuous_choice_grids: List of dicts with 1d grids for continuous - choice variables. vf_arr: Value function array. params: Dict of model parameters. @@ -112,8 +105,8 @@ def solve_continuous_problem( return gridmapped( **state_choice_space.states, - **state_choice_space.choices, - **continuous_choice_grids, + **state_choice_space.discrete_choices, + **state_choice_space.continuous_choices, vf_arr=vf_arr, params=params, ) diff --git a/src/lcm/solution/state_choice_space.py b/src/lcm/solution/state_choice_space.py index 5bdb9ca..d341c35 100644 --- a/src/lcm/solution/state_choice_space.py +++ b/src/lcm/solution/state_choice_space.py @@ -34,11 +34,14 @@ def create_state_choice_space( discrete_states = {sn: model.gridspecs[sn] for sn in discrete_states_names} continuous_states = {sn: model.gridspecs[sn] for sn in continuous_states_names} - state_grids = {sn: model.grids[sn] for sn in vi.query("is_state").index.tolist()} - choice_grids = { - sn: model.grids[sn] for sn in vi.query("is_choice & is_discrete").index.tolist() + state_grids = {sn: model.grids[sn] for sn in vi.query("is_state").index} + discrete_choices = { + sn: model.grids[sn] for sn in vi.query("is_choice & is_discrete").index } - ordered_var_names = tuple(vi.query("is_state | is_discrete").index.tolist()) + continuous_choices = { + sn: model.grids[sn] for sn in vi.query("is_choice & is_continuous").index + } + ordered_var_names = tuple(vi.query("is_state | is_discrete").index) state_space_info = StateSpaceInfo( states_names=tuple(discrete_states_names + continuous_states_names), @@ -48,7 +51,8 @@ def create_state_choice_space( state_choice_space = StateChoiceSpace( states=state_grids, - choices=choice_grids, + discrete_choices=discrete_choices, + continuous_choices=continuous_choices, ordered_var_names=ordered_var_names, ) diff --git a/src/lcm/utils.py b/src/lcm/utils.py index 1f951ee..683b722 100644 --- a/src/lcm/utils.py +++ b/src/lcm/utils.py @@ -21,3 +21,22 @@ def draw_random_seed() -> int: """ return int.from_bytes(os.urandom(4), "little") + + +def first_non_none(*args: T | None) -> T: + """Return the first non-None argument. + + Args: + *args: Arguments to check. + + Returns: + The first non-None argument. + + Raises: + ValueError: If all arguments are None. + + """ + for arg in args: + if arg is not None: + return arg + raise ValueError("All arguments are None") diff --git a/tests/simulation/test_simulate.py b/tests/simulation/test_simulate.py index fbcc39e..e08c270 100644 --- a/tests/simulation/test_simulate.py +++ b/tests/simulation/test_simulate.py @@ -36,6 +36,9 @@ @pytest.fixture def simulate_inputs(): model_config = get_model_config("iskhakov_et_al_2017_stripped_down", n_periods=1) + choices = model_config.choices + choices["consumption"] = choices["consumption"].replace(stop=100) # type: ignore[attr-defined] + model_config = model_config.replace(choices=choices) model = process_model(model_config) state_space_info = create_state_choice_space( @@ -57,12 +60,7 @@ def simulate_inputs(): ) compute_ccv_policy_functions.append(compute_ccv) - n_grid_points = model_config.choices["consumption"].n_points # type: ignore[attr-defined] - return { - "continuous_choice_grids": [ - {"consumption": jnp.linspace(1, 100, num=n_grid_points)}, - ], "compute_ccv_policy_functions": compute_ccv_policy_functions, "model": model, "next_state": get_next_state_function(model, target=Target.SIMULATE), diff --git a/tests/simulation/test_state_choice_space.py b/tests/simulation/test_state_choice_space.py index 45e8e41..c696fc7 100644 --- a/tests/simulation/test_state_choice_space.py +++ b/tests/simulation/test_state_choice_space.py @@ -16,7 +16,7 @@ def test_create_state_choice_space(): "lagged_retirement": jnp.array([0, 1]), }, ) - assert_array_equal(got_space.choices["retirement"], jnp.array([0, 1])) + assert_array_equal(got_space.discrete_choices["retirement"], jnp.array([0, 1])) assert_array_equal(got_space.states["wealth"], jnp.array([10.0, 20.0])) assert_array_equal(got_space.states["lagged_retirement"], jnp.array([0, 1])) diff --git a/tests/solution/test_solve_brute.py b/tests/solution/test_solve_brute.py index 80532d9..e2d1f4e 100644 --- a/tests/solution/test_solve_brute.py +++ b/tests/solution/test_solve_brute.py @@ -26,12 +26,15 @@ def test_solve_brute(): # create the list of state_choice_spaces # ================================================================================== _scs = StateChoiceSpace( - choices={ + discrete_choices={ # pick [0, 1] such that no label translation is needed # lazy is like a type, it influences utility but is not affected by choices "lazy": jnp.array([0, 1]), "working": jnp.array([0, 1]), }, + continuous_choices={ + "consumption": jnp.array([0, 1, 2, 3]), + }, states={ # pick [0, 1, 2] such that no coordinate mapping is needed "wealth": jnp.array([0.0, 1.0, 2.0]), @@ -40,15 +43,6 @@ def test_solve_brute(): ) state_choice_spaces = {0: _scs, 1: _scs} - # ================================================================================== - # create continuous choice grids - # ================================================================================== - - # you 1 if working and have at most 2 in existing wealth, so - _ccg = {"consumption": jnp.array([0, 1, 2, 3])} - - continuous_choice_grids = {0: _ccg, 1: _ccg} - # ================================================================================== # create the utility_and_feasibility functions # ================================================================================== @@ -103,7 +97,6 @@ def calculate_emax(values, params): # noqa: ARG001 solution = solve( params=params, state_choice_spaces=state_choice_spaces, - continuous_choice_grids=continuous_choice_grids, compute_ccv_functions=compute_ccv_functions, emax_calculators=emax_calculators, logger=get_logger(debug_mode=False), @@ -114,11 +107,14 @@ def calculate_emax(values, params): # noqa: ARG001 def test_solve_continuous_problem_no_vf_arr(): state_choice_space = StateChoiceSpace( - choices={ + discrete_choices={ "a": jnp.array([0, 1.0]), "b": jnp.array([2, 3.0]), "c": jnp.array([4, 5, 6]), }, + continuous_choices={ + "d": jnp.arange(12.0), + }, states={}, ordered_var_names=("a", "b", "c"), ) @@ -128,8 +124,6 @@ def _utility_and_feasibility(a, c, b, d, vf_arr, params): # noqa: ARG001 feasib = d <= a + b + c return util, feasib - continuous_choice_grids = {"d": jnp.arange(12.0)} - compute_ccv = get_compute_conditional_continuation_value( utility_and_feasibility=_utility_and_feasibility, continuous_choice_variables=("d",), @@ -140,7 +134,6 @@ def _utility_and_feasibility(a, c, b, d, vf_arr, params): # noqa: ARG001 got = solve_continuous_problem( state_choice_space, compute_ccv, - continuous_choice_grids, vf_arr=None, params={}, ) diff --git a/tests/solution/test_state_space.py b/tests/solution/test_state_space.py index f4fe134..fd66cf7 100644 --- a/tests/solution/test_state_space.py +++ b/tests/solution/test_state_space.py @@ -21,7 +21,8 @@ def test_create_state_choice_space(): assert isinstance(state_space_info, StateSpaceInfo) assert jnp.array_equal( - state_choice_space.choices["retirement"], model.choices["retirement"].to_jax() + state_choice_space.discrete_choices["retirement"], + model.choices["retirement"].to_jax(), ) assert jnp.array_equal( state_choice_space.states["wealth"], model.states["wealth"].to_jax() From b898fdc45920627f2fd986da9172727658f898db Mon Sep 17 00:00:00 2001 From: Tim Mensinger Date: Fri, 28 Feb 2025 10:41:53 +0100 Subject: [PATCH 18/20] Split state-choice-space and state-space-info creation --- src/lcm/entry_point.py | 12 +++++-- src/lcm/solution/state_choice_space.py | 50 ++++++++++++++++++-------- tests/simulation/test_simulate.py | 6 ++-- tests/solution/test_state_space.py | 16 +++++++-- tests/test_entry_point.py | 18 +++++----- tests/test_utility_and_feasibility.py | 6 ++-- 6 files changed, 73 insertions(+), 35 deletions(-) diff --git a/src/lcm/entry_point.py b/src/lcm/entry_point.py index b4178c5..7c065c1 100644 --- a/src/lcm/entry_point.py +++ b/src/lcm/entry_point.py @@ -17,7 +17,10 @@ from lcm.next_state import get_next_state_function from lcm.simulation.simulate import simulate, solve_and_simulate from lcm.solution.solve_brute import solve -from lcm.solution.state_choice_space import create_state_choice_space +from lcm.solution.state_choice_space import ( + create_state_choice_space, + create_state_space_info, +) from lcm.typing import DiscreteProblemValueSolverFunction, ParamsDict, Target from lcm.user_model import Model from lcm.utility_and_feasibility import ( @@ -77,7 +80,12 @@ def get_lcm_function( for period in reversed(range(internal_model.n_periods)): is_last_period = period == last_period - state_choice_space, state_space_info = create_state_choice_space( + state_choice_space = create_state_choice_space( + model=internal_model, + is_last_period=is_last_period, + ) + + state_space_info = create_state_space_info( model=internal_model, is_last_period=is_last_period, ) diff --git a/src/lcm/solution/state_choice_space.py b/src/lcm/solution/state_choice_space.py index d341c35..494c84e 100644 --- a/src/lcm/solution/state_choice_space.py +++ b/src/lcm/solution/state_choice_space.py @@ -7,7 +7,7 @@ def create_state_choice_space( model: InternalModel, *, is_last_period: bool, -) -> tuple[StateChoiceSpace, StateSpaceInfo]: +) -> StateChoiceSpace: """Create a state-choice-space for the model solution. A state-choice-space is a compressed representation of all feasible states and the @@ -28,12 +28,6 @@ def create_state_choice_space( if is_last_period: vi = vi.query("~is_auxiliary") - discrete_states_names = vi.query("is_discrete & is_state").index.tolist() - continuous_states_names = vi.query("is_continuous & is_state").index.tolist() - - discrete_states = {sn: model.gridspecs[sn] for sn in discrete_states_names} - continuous_states = {sn: model.gridspecs[sn] for sn in continuous_states_names} - state_grids = {sn: model.grids[sn] for sn in vi.query("is_state").index} discrete_choices = { sn: model.grids[sn] for sn in vi.query("is_choice & is_discrete").index @@ -43,17 +37,43 @@ def create_state_choice_space( } ordered_var_names = tuple(vi.query("is_state | is_discrete").index) - state_space_info = StateSpaceInfo( - states_names=tuple(discrete_states_names + continuous_states_names), - discrete_states=discrete_states, # type: ignore[arg-type] - continuous_states=continuous_states, # type: ignore[arg-type] - ) - - state_choice_space = StateChoiceSpace( + return StateChoiceSpace( states=state_grids, discrete_choices=discrete_choices, continuous_choices=continuous_choices, ordered_var_names=ordered_var_names, ) - return state_choice_space, state_space_info + +def create_state_space_info( + model: InternalModel, + *, + is_last_period: bool, +) -> StateSpaceInfo: + """Create a state-space information for the model solution. + + A state-space information is a compressed representation of all feasible states. + + Args: + model: A processed model. + is_last_period: Whether the function is created for the last period. + + Returns: + The state-space information. + + """ + vi = model.variable_info + if is_last_period: + vi = vi.query("~is_auxiliary") + + discrete_states_names = vi.query("is_discrete & is_state").index.tolist() + continuous_states_names = vi.query("is_continuous & is_state").index.tolist() + + discrete_states = {sn: model.gridspecs[sn] for sn in discrete_states_names} + continuous_states = {sn: model.gridspecs[sn] for sn in continuous_states_names} + + return StateSpaceInfo( + states_names=tuple(discrete_states_names + continuous_states_names), + discrete_states=discrete_states, # type: ignore[arg-type] + continuous_states=continuous_states, # type: ignore[arg-type] + ) diff --git a/tests/simulation/test_simulate.py b/tests/simulation/test_simulate.py index e08c270..2453de7 100644 --- a/tests/simulation/test_simulate.py +++ b/tests/simulation/test_simulate.py @@ -16,7 +16,7 @@ get_values_from_indices, simulate, ) -from lcm.solution.state_choice_space import create_state_choice_space +from lcm.solution.state_choice_space import create_state_space_info from lcm.typing import Target from lcm.utility_and_feasibility import get_utility_and_feasibility_function from tests.test_models import ( @@ -41,10 +41,10 @@ def simulate_inputs(): model_config = model_config.replace(choices=choices) model = process_model(model_config) - state_space_info = create_state_choice_space( + state_space_info = create_state_space_info( model=model, is_last_period=False, - )[1] + ) compute_ccv_policy_functions = [] for period in range(model.n_periods): diff --git a/tests/solution/test_state_space.py b/tests/solution/test_state_space.py index fd66cf7..73ccb3f 100644 --- a/tests/solution/test_state_space.py +++ b/tests/solution/test_state_space.py @@ -4,6 +4,7 @@ from lcm.interfaces import StateChoiceSpace, StateSpaceInfo from lcm.solution.state_choice_space import ( create_state_choice_space, + create_state_space_info, ) from tests.test_models import get_model_config @@ -12,14 +13,12 @@ def test_create_state_choice_space(): model = get_model_config("iskhakov_et_al_2017_stripped_down", n_periods=3) internal_model = process_model(model) - state_choice_space, state_space_info = create_state_choice_space( + state_choice_space = create_state_choice_space( model=internal_model, is_last_period=False, ) assert isinstance(state_choice_space, StateChoiceSpace) - assert isinstance(state_space_info, StateSpaceInfo) - assert jnp.array_equal( state_choice_space.discrete_choices["retirement"], model.choices["retirement"].to_jax(), @@ -28,6 +27,17 @@ def test_create_state_choice_space(): state_choice_space.states["wealth"], model.states["wealth"].to_jax() ) + +def test_create_state_space_info(): + model = get_model_config("iskhakov_et_al_2017_stripped_down", n_periods=3) + internal_model = process_model(model) + + state_space_info = create_state_space_info( + model=internal_model, + is_last_period=False, + ) + + assert isinstance(state_space_info, StateSpaceInfo) assert state_space_info.states_names == ("wealth",) assert state_space_info.discrete_states == {} assert state_space_info.continuous_states == model.states diff --git a/tests/test_entry_point.py b/tests/test_entry_point.py index 0bee350..e829ade 100644 --- a/tests/test_entry_point.py +++ b/tests/test_entry_point.py @@ -8,7 +8,7 @@ ) from lcm.entry_point import get_lcm_function from lcm.input_processing import process_model -from lcm.solution.state_choice_space import create_state_choice_space +from lcm.solution.state_choice_space import create_state_space_info from lcm.utility_and_feasibility import get_utility_and_feasibility_function from tests.test_models import get_model_config from tests.test_models.deterministic import RetirementStatus @@ -182,10 +182,10 @@ def test_create_compute_conditional_continuation_value(): }, } - state_space_info = create_state_choice_space( + state_space_info = create_state_space_info( model=model, is_last_period=False, - )[1] + ) u_and_f = get_utility_and_feasibility_function( model=model, @@ -227,10 +227,10 @@ def test_create_compute_conditional_continuation_value_with_discrete_model(): }, } - state_space_info = create_state_choice_space( + state_space_info = create_state_space_info( model=model, is_last_period=False, - )[1] + ) u_and_f = get_utility_and_feasibility_function( model=model, @@ -277,10 +277,10 @@ def test_create_compute_conditional_continuation_policy(): }, } - state_space_info = create_state_choice_space( + state_space_info = create_state_space_info( model=model, is_last_period=False, - )[1] + ) u_and_f = get_utility_and_feasibility_function( model=model, @@ -323,10 +323,10 @@ def test_create_compute_conditional_continuation_policy_with_discrete_model(): }, } - state_space_info = create_state_choice_space( + state_space_info = create_state_space_info( model=model, is_last_period=False, - )[1] + ) u_and_f = get_utility_and_feasibility_function( model=model, diff --git a/tests/test_utility_and_feasibility.py b/tests/test_utility_and_feasibility.py index 3b08a11..1802731 100644 --- a/tests/test_utility_and_feasibility.py +++ b/tests/test_utility_and_feasibility.py @@ -6,7 +6,7 @@ from lcm.input_processing import process_model from lcm.interfaces import InternalModel -from lcm.solution.state_choice_space import create_state_choice_space +from lcm.solution.state_choice_space import create_state_space_info from lcm.typing import ShockType from lcm.utility_and_feasibility import ( _get_feasibility, @@ -32,10 +32,10 @@ def test_get_utility_and_feasibility_function(): }, } - state_space_info = create_state_choice_space( + state_space_info = create_state_space_info( model=model, is_last_period=False, - )[1] + ) u_and_f = get_utility_and_feasibility_function( model=model, From e761c1195e25491bb35a21b20cc52490dc2f05d4 Mon Sep 17 00:00:00 2001 From: Tim Mensinger Date: Fri, 28 Feb 2025 11:40:15 +0100 Subject: [PATCH 19/20] Consolidate simulation and solution state-choice-space functions --- src/lcm/entry_point.py | 2 +- src/lcm/simulation/simulate.py | 2 +- src/lcm/simulation/state_choice_space.py | 57 --------- src/lcm/solution/state_choice_space.py | 79 ------------- src/lcm/state_choice_space.py | 108 ++++++++++++++++++ tests/simulation/test_simulate.py | 2 +- tests/simulation/test_state_choice_space.py | 37 ------ tests/test_entry_point.py | 2 +- ...te_space.py => test_state_choice_space.py} | 36 +++++- tests/test_utility_and_feasibility.py | 2 +- 10 files changed, 147 insertions(+), 180 deletions(-) delete mode 100644 src/lcm/simulation/state_choice_space.py delete mode 100644 src/lcm/solution/state_choice_space.py create mode 100644 src/lcm/state_choice_space.py delete mode 100644 tests/simulation/test_state_choice_space.py rename tests/{solution/test_state_space.py => test_state_choice_space.py} (50%) diff --git a/src/lcm/entry_point.py b/src/lcm/entry_point.py index 7c065c1..32048b7 100644 --- a/src/lcm/entry_point.py +++ b/src/lcm/entry_point.py @@ -17,7 +17,7 @@ from lcm.next_state import get_next_state_function from lcm.simulation.simulate import simulate, solve_and_simulate from lcm.solution.solve_brute import solve -from lcm.solution.state_choice_space import ( +from lcm.state_choice_space import ( create_state_choice_space, create_state_space_info, ) diff --git a/src/lcm/simulation/simulate.py b/src/lcm/simulation/simulate.py index 797264e..67d1246 100644 --- a/src/lcm/simulation/simulate.py +++ b/src/lcm/simulation/simulate.py @@ -16,7 +16,7 @@ ) from lcm.random import generate_simulation_keys from lcm.simulation.processing import as_data_frame, process_simulated_data -from lcm.simulation.state_choice_space import create_state_choice_space +from lcm.state_choice_space import create_state_choice_space from lcm.typing import ParamsDict from lcm.utils import draw_random_seed diff --git a/src/lcm/simulation/state_choice_space.py b/src/lcm/simulation/state_choice_space.py deleted file mode 100644 index e81e0ac..0000000 --- a/src/lcm/simulation/state_choice_space.py +++ /dev/null @@ -1,57 +0,0 @@ -from jax import Array - -from lcm.interfaces import InternalModel, StateChoiceSpace - - -def create_state_choice_space( - model: InternalModel, - initial_states: dict[str, Array], -) -> StateChoiceSpace: - """Create the initial state choice space. - - In comparison to the solution, the state choice space in the simulation must be - created during each iteration, because the states change over time. - - Args: - model: Model instance. - initial_states: Dict with initial states. - - Returns: - State choice space. - - Raises: - ValueError: If the initial states do not match the state variables in the model. - - """ - vi = model.variable_info - state_names = set(vi.query("is_state").index) - - if state_names != set(initial_states.keys()): - missing = state_names - set(initial_states.keys()) - too_many = set(initial_states.keys()) - state_names - raise ValueError( - "You need to provide an initial value for each state variable in the model." - f"\n\nMissing initial states: {missing}\n", - f"Provided variables that are not states: {too_many}", - ) - - ordered_var_names = tuple(vi.query("is_state | is_discrete").index) - - discrete_choices = { - name: grid - for name, grid in model.grids.items() - if name in vi.query("is_choice & is_discrete").index - } - - continuous_choices = { - name: grid - for name, grid in model.grids.items() - if name in vi.query("is_choice & is_continuous").index - } - - return StateChoiceSpace( - states=initial_states, - discrete_choices=discrete_choices, - continuous_choices=continuous_choices, - ordered_var_names=ordered_var_names, - ) diff --git a/src/lcm/solution/state_choice_space.py b/src/lcm/solution/state_choice_space.py deleted file mode 100644 index 494c84e..0000000 --- a/src/lcm/solution/state_choice_space.py +++ /dev/null @@ -1,79 +0,0 @@ -"""Create a state space for a given model.""" - -from lcm.interfaces import InternalModel, StateChoiceSpace, StateSpaceInfo - - -def create_state_choice_space( - model: InternalModel, - *, - is_last_period: bool, -) -> StateChoiceSpace: - """Create a state-choice-space for the model solution. - - A state-choice-space is a compressed representation of all feasible states and the - feasible discrete choices within that state. - - Args: - model: A processed model. - is_last_period: Whether the function is created for the last period. - - Returns: - - An object containing the variable values of all variables in the - state-choice-space, the grid specifications for the state variables, and the - names of the state variables. Continuous choice variables are not included. - - The state-space information. - - """ - vi = model.variable_info - if is_last_period: - vi = vi.query("~is_auxiliary") - - state_grids = {sn: model.grids[sn] for sn in vi.query("is_state").index} - discrete_choices = { - sn: model.grids[sn] for sn in vi.query("is_choice & is_discrete").index - } - continuous_choices = { - sn: model.grids[sn] for sn in vi.query("is_choice & is_continuous").index - } - ordered_var_names = tuple(vi.query("is_state | is_discrete").index) - - return StateChoiceSpace( - states=state_grids, - discrete_choices=discrete_choices, - continuous_choices=continuous_choices, - ordered_var_names=ordered_var_names, - ) - - -def create_state_space_info( - model: InternalModel, - *, - is_last_period: bool, -) -> StateSpaceInfo: - """Create a state-space information for the model solution. - - A state-space information is a compressed representation of all feasible states. - - Args: - model: A processed model. - is_last_period: Whether the function is created for the last period. - - Returns: - The state-space information. - - """ - vi = model.variable_info - if is_last_period: - vi = vi.query("~is_auxiliary") - - discrete_states_names = vi.query("is_discrete & is_state").index.tolist() - continuous_states_names = vi.query("is_continuous & is_state").index.tolist() - - discrete_states = {sn: model.gridspecs[sn] for sn in discrete_states_names} - continuous_states = {sn: model.gridspecs[sn] for sn in continuous_states_names} - - return StateSpaceInfo( - states_names=tuple(discrete_states_names + continuous_states_names), - discrete_states=discrete_states, # type: ignore[arg-type] - continuous_states=continuous_states, # type: ignore[arg-type] - ) diff --git a/src/lcm/state_choice_space.py b/src/lcm/state_choice_space.py new file mode 100644 index 0000000..0f826f0 --- /dev/null +++ b/src/lcm/state_choice_space.py @@ -0,0 +1,108 @@ +"""Create a state space for a given model.""" + +import pandas as pd +from jax import Array + +from lcm.interfaces import InternalModel, StateChoiceSpace, StateSpaceInfo + + +def create_state_choice_space( + model: InternalModel, + *, + initial_states: dict[str, Array] | None = None, + is_last_period: bool = False, +) -> StateChoiceSpace: + """Create a state-choice-space. + + Creates the state-choice-space for the solution and simulation of a model. In the + simulation, initial states must be provided. + + Args: + model: A processed model. + initial_states: A dictionary with the initial values of the state variables. + If None, the initial values are the minimum values of the state variables. + is_last_period: Whether the state-choice-space is created for the last period, + in which case auxiliary variables are not included. + + Returns: + A state-choice-space. Contains the grids of the discrete and continuous choices, + the grids of the state variables, or the initial values of the state variables, + and the names of the state and choice variables in the order they appear in the + variable info table. + + """ + vi = model.variable_info + if is_last_period: + vi = vi.query("~is_auxiliary") + + if initial_states is None: + states = {sn: model.grids[sn] for sn in vi.query("is_state").index} + else: + _validate_initial_states(initial_states, variable_info=vi) + states = initial_states + + discrete_choices = { + name: model.grids[name] for name in vi.query("is_choice & is_discrete").index + } + continuous_choices = { + name: model.grids[name] for name in vi.query("is_choice & is_continuous").index + } + ordered_var_names = tuple(vi.query("is_state | is_discrete").index) + + return StateChoiceSpace( + states=states, + discrete_choices=discrete_choices, + continuous_choices=continuous_choices, + ordered_var_names=ordered_var_names, + ) + + +def create_state_space_info( + model: InternalModel, + *, + is_last_period: bool, +) -> StateSpaceInfo: + """Create a state-space information for the model solution. + + A state-space information is a compressed representation of all feasible states. + + Args: + model: A processed model. + is_last_period: Whether the function is created for the last period. + + Returns: + The state-space information. + + """ + vi = model.variable_info + if is_last_period: + vi = vi.query("~is_auxiliary") + + discrete_states_names = vi.query("is_discrete & is_state").index.tolist() + continuous_states_names = vi.query("is_continuous & is_state").index.tolist() + + discrete_states = {sn: model.gridspecs[sn] for sn in discrete_states_names} + continuous_states = {sn: model.gridspecs[sn] for sn in continuous_states_names} + + return StateSpaceInfo( + states_names=tuple(discrete_states_names + continuous_states_names), + discrete_states=discrete_states, # type: ignore[arg-type] + continuous_states=continuous_states, # type: ignore[arg-type] + ) + + +def _validate_initial_states( + initial_states: dict[str, Array], variable_info: pd.DataFrame +) -> None: + """Checks if each model-state has an initial value.""" + states_names_in_model = set(variable_info.query("is_state").index) + provided_states_names = set(initial_states) + + if states_names_in_model != provided_states_names: + missing = states_names_in_model - provided_states_names + too_many = provided_states_names - states_names_in_model + raise ValueError( + "You need to provide an initial array for each state variable in the model." + f"\n\nMissing initial states: {missing}\n", + f"Provided variables that are not states: {too_many}", + ) diff --git a/tests/simulation/test_simulate.py b/tests/simulation/test_simulate.py index 2453de7..aae30ca 100644 --- a/tests/simulation/test_simulate.py +++ b/tests/simulation/test_simulate.py @@ -16,7 +16,7 @@ get_values_from_indices, simulate, ) -from lcm.solution.state_choice_space import create_state_space_info +from lcm.state_choice_space import create_state_space_info from lcm.typing import Target from lcm.utility_and_feasibility import get_utility_and_feasibility_function from tests.test_models import ( diff --git a/tests/simulation/test_state_choice_space.py b/tests/simulation/test_state_choice_space.py deleted file mode 100644 index c696fc7..0000000 --- a/tests/simulation/test_state_choice_space.py +++ /dev/null @@ -1,37 +0,0 @@ -import jax.numpy as jnp -from numpy.testing import assert_array_equal - -from lcm.input_processing import process_model -from lcm.simulation.state_choice_space import create_state_choice_space -from tests.test_models import get_model_config - - -def test_create_state_choice_space(): - model_config = get_model_config("iskhakov_et_al_2017", n_periods=3) - model = process_model(model_config) - got_space = create_state_choice_space( - model=model, - initial_states={ - "wealth": jnp.array([10.0, 20.0]), - "lagged_retirement": jnp.array([0, 1]), - }, - ) - assert_array_equal(got_space.discrete_choices["retirement"], jnp.array([0, 1])) - assert_array_equal(got_space.states["wealth"], jnp.array([10.0, 20.0])) - assert_array_equal(got_space.states["lagged_retirement"], jnp.array([0, 1])) - - -def test_create_state_choice_space_replace(): - model_config = get_model_config("iskhakov_et_al_2017", n_periods=3) - model = process_model(model_config) - space = create_state_choice_space( - model=model, - initial_states={ - "wealth": jnp.array([10.0, 20.0]), - "lagged_retirement": jnp.array([0, 1]), - }, - ) - new_space = space.replace( - states={"wealth": jnp.array([10.0, 30.0])}, - ) - assert_array_equal(new_space.states["wealth"], jnp.array([10.0, 30.0])) diff --git a/tests/test_entry_point.py b/tests/test_entry_point.py index e829ade..3a4e1d0 100644 --- a/tests/test_entry_point.py +++ b/tests/test_entry_point.py @@ -8,7 +8,7 @@ ) from lcm.entry_point import get_lcm_function from lcm.input_processing import process_model -from lcm.solution.state_choice_space import create_state_space_info +from lcm.state_choice_space import create_state_space_info from lcm.utility_and_feasibility import get_utility_and_feasibility_function from tests.test_models import get_model_config from tests.test_models.deterministic import RetirementStatus diff --git a/tests/solution/test_state_space.py b/tests/test_state_choice_space.py similarity index 50% rename from tests/solution/test_state_space.py rename to tests/test_state_choice_space.py index 73ccb3f..11de980 100644 --- a/tests/solution/test_state_space.py +++ b/tests/test_state_choice_space.py @@ -1,15 +1,16 @@ import jax.numpy as jnp +from numpy.testing import assert_array_equal from lcm.input_processing import process_model from lcm.interfaces import StateChoiceSpace, StateSpaceInfo -from lcm.solution.state_choice_space import ( +from lcm.state_choice_space import ( create_state_choice_space, create_state_space_info, ) from tests.test_models import get_model_config -def test_create_state_choice_space(): +def test_create_state_choice_space_solution(): model = get_model_config("iskhakov_et_al_2017_stripped_down", n_periods=3) internal_model = process_model(model) @@ -28,6 +29,21 @@ def test_create_state_choice_space(): ) +def test_create_state_choice_space_simulation(): + model_config = get_model_config("iskhakov_et_al_2017", n_periods=3) + model = process_model(model_config) + got_space = create_state_choice_space( + model=model, + initial_states={ + "wealth": jnp.array([10.0, 20.0]), + "lagged_retirement": jnp.array([0, 1]), + }, + ) + assert_array_equal(got_space.discrete_choices["retirement"], jnp.array([0, 1])) + assert_array_equal(got_space.states["wealth"], jnp.array([10.0, 20.0])) + assert_array_equal(got_space.states["lagged_retirement"], jnp.array([0, 1])) + + def test_create_state_space_info(): model = get_model_config("iskhakov_et_al_2017_stripped_down", n_periods=3) internal_model = process_model(model) @@ -41,3 +57,19 @@ def test_create_state_space_info(): assert state_space_info.states_names == ("wealth",) assert state_space_info.discrete_states == {} assert state_space_info.continuous_states == model.states + + +def test_create_state_choice_space_replace(): + model_config = get_model_config("iskhakov_et_al_2017", n_periods=3) + model = process_model(model_config) + space = create_state_choice_space( + model=model, + initial_states={ + "wealth": jnp.array([10.0, 20.0]), + "lagged_retirement": jnp.array([0, 1]), + }, + ) + new_space = space.replace( + states={"wealth": jnp.array([10.0, 30.0])}, + ) + assert_array_equal(new_space.states["wealth"], jnp.array([10.0, 30.0])) diff --git a/tests/test_utility_and_feasibility.py b/tests/test_utility_and_feasibility.py index 1802731..cf71bde 100644 --- a/tests/test_utility_and_feasibility.py +++ b/tests/test_utility_and_feasibility.py @@ -6,7 +6,7 @@ from lcm.input_processing import process_model from lcm.interfaces import InternalModel -from lcm.solution.state_choice_space import create_state_space_info +from lcm.state_choice_space import create_state_space_info from lcm.typing import ShockType from lcm.utility_and_feasibility import ( _get_feasibility, From 5d5c047f2195350159f5f6209207a2c88e4465f1 Mon Sep 17 00:00:00 2001 From: Tim Mensinger Date: Fri, 28 Feb 2025 16:29:03 +0100 Subject: [PATCH 20/20] Auto-review --- src/lcm/random.py | 11 +++++++++++ src/lcm/simulation/simulate.py | 3 +-- src/lcm/state_choice_space.py | 23 ++++++++++++++++------- src/lcm/utils.py | 11 ----------- 4 files changed, 28 insertions(+), 20 deletions(-) diff --git a/src/lcm/random.py b/src/lcm/random.py index 4c096b1..29feb75 100644 --- a/src/lcm/random.py +++ b/src/lcm/random.py @@ -1,3 +1,4 @@ +import os from functools import partial import jax @@ -60,3 +61,13 @@ def generate_simulation_keys( simulation_keys = dict(zip(ids, keys[1:], strict=True)) return key, simulation_keys + + +def draw_random_seed() -> int: + """Generate a random seed using the operating system's secure entropy pool. + + Returns: + Random seed. + + """ + return int.from_bytes(os.urandom(4), "little") diff --git a/src/lcm/simulation/simulate.py b/src/lcm/simulation/simulate.py index 67d1246..75ebe46 100644 --- a/src/lcm/simulation/simulate.py +++ b/src/lcm/simulation/simulate.py @@ -14,11 +14,10 @@ InternalSimulationPeriodResults, StateChoiceSpace, ) -from lcm.random import generate_simulation_keys +from lcm.random import draw_random_seed, generate_simulation_keys from lcm.simulation.processing import as_data_frame, process_simulated_data from lcm.state_choice_space import create_state_choice_space from lcm.typing import ParamsDict -from lcm.utils import draw_random_seed def solve_and_simulate( diff --git a/src/lcm/state_choice_space.py b/src/lcm/state_choice_space.py index 0f826f0..db3b620 100644 --- a/src/lcm/state_choice_space.py +++ b/src/lcm/state_choice_space.py @@ -3,6 +3,7 @@ import pandas as pd from jax import Array +from lcm.grids import ContinuousGrid, DiscreteGrid from lcm.interfaces import InternalModel, StateChoiceSpace, StateSpaceInfo @@ -78,16 +79,24 @@ def create_state_space_info( if is_last_period: vi = vi.query("~is_auxiliary") - discrete_states_names = vi.query("is_discrete & is_state").index.tolist() - continuous_states_names = vi.query("is_continuous & is_state").index.tolist() + state_names = vi.query("is_state").index.tolist() - discrete_states = {sn: model.gridspecs[sn] for sn in discrete_states_names} - continuous_states = {sn: model.gridspecs[sn] for sn in continuous_states_names} + discrete_states = { + name: grid_spec + for name, grid_spec in model.gridspecs.items() + if name in state_names and isinstance(grid_spec, DiscreteGrid) + } + + continuous_states = { + name: grid_spec + for name, grid_spec in model.gridspecs.items() + if name in state_names and isinstance(grid_spec, ContinuousGrid) + } return StateSpaceInfo( - states_names=tuple(discrete_states_names + continuous_states_names), - discrete_states=discrete_states, # type: ignore[arg-type] - continuous_states=continuous_states, # type: ignore[arg-type] + states_names=tuple(state_names), + discrete_states=discrete_states, + continuous_states=continuous_states, ) diff --git a/src/lcm/utils.py b/src/lcm/utils.py index 683b722..6b31fdc 100644 --- a/src/lcm/utils.py +++ b/src/lcm/utils.py @@ -1,4 +1,3 @@ -import os from collections import Counter from collections.abc import Iterable from itertools import chain @@ -13,16 +12,6 @@ def find_duplicates(*containers: Iterable[T]) -> set[T]: return {v for v, count in counts.items() if count > 1} -def draw_random_seed() -> int: - """Generate a random seed using the operating system's secure entropy pool. - - Returns: - Random seed. - - """ - return int.from_bytes(os.urandom(4), "little") - - def first_non_none(*args: T | None) -> T: """Return the first non-None argument.