From c06c51956d4493d3f2a830ac3c2b9e09e0fb6f71 Mon Sep 17 00:00:00 2001 From: Tim Mensinger Date: Mon, 3 Feb 2025 14:45:24 +0100 Subject: [PATCH 01/21] Replace filters by constraints internally --- src/lcm/input_processing/util.py | 4 ++-- tests/input_processing/test_process_model.py | 1 + tests/test_simulate.py | 1 + 3 files changed, 4 insertions(+), 2 deletions(-) diff --git a/src/lcm/input_processing/util.py b/src/lcm/input_processing/util.py index 5392c8a7..216fd296 100644 --- a/src/lcm/input_processing/util.py +++ b/src/lcm/input_processing/util.py @@ -22,8 +22,8 @@ def get_function_info(model: Model) -> pd.DataFrame: """ info = pd.DataFrame(index=list(model.functions)) - info["is_filter"] = info.index.str.endswith("_filter") - info["is_constraint"] = info.index.str.endswith("_constraint") + info["is_filter"] = False + info["is_constraint"] = info.index.str.endswith(("_constraint", "_filter")) info["is_next"] = ( info.index.str.startswith("next_") & ~info["is_constraint"] & ~info["is_filter"] ) diff --git a/tests/input_processing/test_process_model.py b/tests/input_processing/test_process_model.py index 86c78485..41b3fe6c 100644 --- a/tests/input_processing/test_process_model.py +++ b/tests/input_processing/test_process_model.py @@ -104,6 +104,7 @@ def test_get_grids(model): assert_array_equal(got["c"], jnp.array([0, 1])) +@pytest.mark.xfail(reason="Filters are replaced by constraints internally currently.") def test_process_model_iskhakov_et_al_2017(): model_config = get_model_config("iskhakov_et_al_2017", n_periods=3) model = process_model(model_config) diff --git a/tests/test_simulate.py b/tests/test_simulate.py index 859aa99c..b3f7a666 100644 --- a/tests/test_simulate.py +++ b/tests/test_simulate.py @@ -409,6 +409,7 @@ def test_filter_ccv_policy(): assert jnp.all(got == jnp.array([0, 0])) +@pytest.mark.xfail(reason="Filters are replaced by constraints internally currently.") def test_create_data_state_choice_space(): model_config = 
get_model_config("iskhakov_et_al_2017", n_periods=3) model = process_model(model_config) From 92eb813a5a47d509eb09bc5c19d0d2281586aad5 Mon Sep 17 00:00:00 2001 From: Tim Mensinger Date: Mon, 3 Feb 2025 14:58:33 +0100 Subject: [PATCH 02/21] Quick-fix dispatchers explanation notebook --- explanations/dispatchers.ipynb | 131 +++++++++++++++++---------------- 1 file changed, 69 insertions(+), 62 deletions(-) diff --git a/explanations/dispatchers.ipynb b/explanations/dispatchers.ipynb index cab40c16..6a3e2cc3 100644 --- a/explanations/dispatchers.ipynb +++ b/explanations/dispatchers.ipynb @@ -31,7 +31,7 @@ "source": [ "# `vmap_1d`\n", "\n", - "Let's start by vectorizing the function `f` over axis `a` using Jax' `vmap` function." + "Let's start by vectorizing the function `f` over axis `a` using JAX's `vmap` function." ] }, { @@ -49,6 +49,13 @@ "execution_count": null, "metadata": {}, "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.\n" + ] + }, { "data": { "text/plain": [ @@ -251,7 +258,7 @@ "\n", "If the valid values of a variable in a state-choice space depend on another variable, that variable is termed a _sparse_ variable; otherwise, it is a _dense_ variable. To dispatch a function across an entire state-choice space, we must vectorize over both dense and sparse variables. Since, by definition, all values of dense variables are valid, we can simply perform a `productmap` over the Cartesian grid of their values. The valid combinations of sparse variables are stored as a collection of 1D arrays (see below for an example). For these, we can perform a call to `vmap_1d`.\n", "\n", - "Consider a simplified version of our deterministic test model. Curly brackets {} denote discrete variables; square brackets [] represent continuous variables.\n", + "Consider a simplified version of our deterministic test model. 
Curly brackets {...} denote discrete variables; square brackets [...] represent continuous variables.\n", "\n", "- **Choice variables:**\n", "\n", @@ -265,8 +272,8 @@ "\n", " - _wealth_ $\\in [1, 2, 3, 4]$\n", "\n", - "- **Filter:**\n", - " - Absorbing retirement filter: If _lagged_retirement_ is 1, then the choice\n", + "- **Constraints:**\n", + " - Absorbing retirement constraint: If _lagged_retirement_ is 1, then the choice\n", " _retirement_ can never be 0." ] }, @@ -291,7 +298,7 @@ " return jnp.log(consumption) - 0.5 * working + retirement_habit\n", "\n", "\n", - "def absorbing_retirement_filter(retirement, lagged_retirement):\n", + "def absorbing_retirement_constraint(retirement, lagged_retirement):\n", " return jnp.logical_or(retirement == 1, lagged_retirement == 0)\n", "\n", "\n", @@ -300,7 +307,7 @@ " \"utility\": utility,\n", " \"next_lagged_retirement\": lambda retirement: retirement,\n", " \"next_wealth\": lambda wealth, consumption: wealth - consumption,\n", - " \"absorbing_retirement_filter\": absorbing_retirement_filter,\n", + " \"absorbing_retirement_constraint\": absorbing_retirement_constraint,\n", " },\n", " n_periods=1,\n", " choices={\n", @@ -350,7 +357,9 @@ { "data": { "text/plain": [ - "{'wealth': Array([1., 2., 3., 4.], dtype=float32)}" + "{'lagged_retirement': Array([0, 1], dtype=int32),\n", + " 'retirement': Array([0, 1], dtype=int32),\n", + " 'wealth': Array([1., 2., 3., 4.], dtype=float32)}" ] }, "execution_count": null, @@ -370,8 +379,7 @@ { "data": { "text/plain": [ - "{'lagged_retirement': Array([0, 0, 1], dtype=int32),\n", - " 'retirement': Array([0, 1, 1], dtype=int32)}" + "{}" ] }, "execution_count": null, @@ -409,35 +417,17 @@ " \n", " \n", " \n", - " lagged_retirement\n", - " retirement\n", " \n", " \n", " \n", - " \n", - " 0\n", - " 0\n", - " 0\n", - " \n", - " \n", - " 1\n", - " 0\n", - " 1\n", - " \n", - " \n", - " 2\n", - " 1\n", - " 1\n", - " \n", " \n", "\n", "" ], "text/plain": [ - " lagged_retirement retirement\n", - "0 0 
0\n", - "1 0 1\n", - "2 1 1" + "Empty DataFrame\n", + "Columns: []\n", + "Index: []" ] }, "execution_count": null, @@ -508,18 +498,7 @@ "cell_type": "code", "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "{'segment_ids': Array([0, 0, 1], dtype=int32), 'num_segments': 2}" - ] - }, - "execution_count": null, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "segments" ] @@ -568,7 +547,9 @@ { "data": { "text/plain": [ - "{'wealth': Array([1., 2., 3., 4.], dtype=float32)}" + "{'lagged_retirement': Array([0, 1], dtype=int32),\n", + " 'retirement': Array([0, 1], dtype=int32),\n", + " 'wealth': Array([1., 2., 3., 4.], dtype=float32)}" ] }, "execution_count": null, @@ -588,8 +569,7 @@ { "data": { "text/plain": [ - "{'lagged_retirement': Array([0, 0, 1], dtype=int32),\n", - " 'retirement': Array([0, 1, 1], dtype=int32)}" + "{}" ] }, "execution_count": null, @@ -609,9 +589,11 @@ { "data": { "text/plain": [ - "Array([[-0.5, -0.5, -0.5, -0.5],\n", - " [ 0. , 0. , 0. , 0. ],\n", - " [ 1. , 2. , 3. , 4. ]], dtype=float32)" + "Array([[[-0.5, -0.5, -0.5, -0.5],\n", + " [ 0. , 0. , 0. , 0. ]],\n", + "\n", + " [[ 0.5, 1.5, 2.5, 3.5],\n", + " [ 1. , 2. , 3. , 4. 
]]], dtype=float32)" ] }, "execution_count": null, @@ -636,7 +618,7 @@ { "data": { "text/plain": [ - "(3, 4)" + "(2, 2, 4)" ] }, "execution_count": null, @@ -655,6 +637,26 @@ "Let's try to get this result via looping over the grids and calling `utility` directly" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Space(sparse_vars={}, dense_vars={'lagged_retirement': Array([0, 1], dtype=int32), 'retirement': Array([0, 1], dtype=int32), 'wealth': Array([1., 2., 3., 4.], dtype=float32)})" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "sc_space" + ] + }, { "cell_type": "code", "execution_count": null, @@ -664,8 +666,8 @@ "data": { "text/plain": [ "Array([[-0.5, -0.5, -0.5, -0.5],\n", - " [ 0. , 0. , 0. , 0. ],\n", - " [ 1. , 2. , 3. , 4. ]], dtype=float32)" + " [ 1. , 2. , 3. , 4. ],\n", + " [ 0. , 0. , 0. , 0. ]], dtype=float32)" ] }, "execution_count": null, @@ -679,8 +681,8 @@ "# loop over valid combinations of sparse variables (first axis)\n", "for i, (lagged_retirement, retirement) in enumerate(\n", " zip(\n", - " sc_space.sparse_vars[\"lagged_retirement\"],\n", - " sc_space.sparse_vars[\"retirement\"],\n", + " sc_space.dense_vars[\"lagged_retirement\"],\n", + " sc_space.dense_vars[\"retirement\"],\n", " strict=False,\n", " ),\n", "):\n", @@ -749,13 +751,18 @@ { "data": { "text/plain": [ - "Array([[[-0.5 , -0.5 , -0.5 , -0.5 ],\n", - " [ 0. , 0. , 0. , 0. ],\n", - " [ 1. , 2. , 3. , 4. ]],\n", + "Array([[[[-0.5 , -0.5 , -0.5 , -0.5 ],\n", + " [ 0. , 0. , 0. , 0. ]],\n", + "\n", + " [[ 0.5 , 1.5 , 2.5 , 3.5 ],\n", + " [ 1. , 2. , 3. , 4. 
]]],\n", + "\n", "\n", - " [[ 5.4914646, 5.4914646, 5.4914646, 5.4914646],\n", - " [ 5.9914646, 5.9914646, 5.9914646, 5.9914646],\n", - " [ 6.9914646, 7.9914646, 8.991465 , 9.991465 ]]], dtype=float32)" + " [[[ 5.4914646, 5.4914646, 5.4914646, 5.4914646],\n", + " [ 5.9914646, 5.9914646, 5.9914646, 5.9914646]],\n", + "\n", + " [[ 6.4914646, 7.4914646, 8.491465 , 9.491465 ],\n", + " [ 6.9914646, 7.9914646, 8.991465 , 9.991465 ]]]], dtype=float32)" ] }, "execution_count": null, @@ -780,7 +787,7 @@ { "data": { "text/plain": [ - "(2, 3, 4)" + "(2, 2, 2, 4)" ] }, "execution_count": null, @@ -795,7 +802,7 @@ ], "metadata": { "kernelspec": { - "display_name": "lcm", + "display_name": "test-cpu", "language": "python", "name": "python3" }, @@ -809,7 +816,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.12.3" + "version": "3.12.0" } }, "nbformat": 4, From 14c62c757db8ac7d0a0c505c8b27656b16598255 Mon Sep 17 00:00:00 2001 From: Tim Mensinger Date: Mon, 3 Feb 2025 15:17:31 +0100 Subject: [PATCH 03/21] First iteration: Remove sparse variables from discrete problem and typing --- src/lcm/discrete_problem.py | 23 +++-------------------- src/lcm/entry_point.py | 1 - src/lcm/typing.py | 18 +----------------- tests/test_discrete_problem.py | 6 ++---- 4 files changed, 6 insertions(+), 42 deletions(-) diff --git a/src/lcm/discrete_problem.py b/src/lcm/discrete_problem.py index a26baab2..0f308372 100644 --- a/src/lcm/discrete_problem.py +++ b/src/lcm/discrete_problem.py @@ -24,9 +24,8 @@ import jax.numpy as jnp import pandas as pd from jax import Array -from jax.ops import segment_max -from lcm.typing import ParamsDict, SegmentInfo, ShockType +from lcm.typing import ParamsDict, ShockType def get_solve_discrete_problem( @@ -34,20 +33,16 @@ def get_solve_discrete_problem( random_utility_shock_type: ShockType, variable_info: pd.DataFrame, is_last_period: bool, - choice_segments: SegmentInfo | None, ) -> Callable[[Array], Array]: """Get 
function that computes the expected max. of conditional continuation values. - The maximum is taken over the discrete and sparse choice variables in each state. + The maximum is taken over the discrete choice variables in each state. Args: random_utility_shock_type: Type of choice shock. Currently only Shock.NONE is supported. Work for "extreme_value" is in progress. variable_info: DataFrame with information about the variables. is_last_period: Whether the function is created for the last period. - choice_segments: Contains segment information of sparse choices. If None, there - are no sparse choices. - params: Dictionary with model parameters. Returns: callable: Function that calculates the expected maximum of the conditional @@ -67,11 +62,7 @@ def get_solve_discrete_problem( else: raise ValueError(f"Invalid shock_type: {random_utility_shock_type}.") - return partial( - func, - choice_axes=choice_axes, - choice_segments=choice_segments, - ) + return partial(func, choice_axes=choice_axes) # ====================================================================================== @@ -82,7 +73,6 @@ def get_solve_discrete_problem( def _solve_discrete_problem_no_shocks( cc_values: Array, choice_axes: tuple[int, ...] | None, - choice_segments: SegmentInfo | None, params: ParamsDict, # noqa: ARG001 ) -> Array: """Reduce conditional continuation values over discrete choices. @@ -93,7 +83,6 @@ def _solve_discrete_problem_no_shocks( choice_axes (tuple[int, ...]): A tuple of indices representing the axes in the value function that correspond to discrete choices. Returns None if there are no discrete choice axes. - choice_segments: See `get_solve_discrete_problem`. params: See `get_solve_discrete_problem`. 
Returns: @@ -105,12 +94,6 @@ def _solve_discrete_problem_no_shocks( out = cc_values if choice_axes is not None: out = out.max(axis=choice_axes) - if choice_segments is not None: - out = segment_max( - data=out, - indices_are_sorted=True, - **choice_segments, - ) return out diff --git a/src/lcm/entry_point.py b/src/lcm/entry_point.py index 1e560fb6..c41b2ebe 100644 --- a/src/lcm/entry_point.py +++ b/src/lcm/entry_point.py @@ -150,7 +150,6 @@ def get_lcm_function( random_utility_shock_type=_mod.random_utility_shocks, variable_info=_mod.variable_info, is_last_period=is_last_period, - choice_segments=choice_segments[period], ) emax_calculators.append(calculator) diff --git a/src/lcm/typing.py b/src/lcm/typing.py index 610bcddd..cdae3318 100644 --- a/src/lcm/typing.py +++ b/src/lcm/typing.py @@ -1,5 +1,5 @@ from enum import Enum -from typing import Any, TypedDict +from typing import Any from jax import Array @@ -15,19 +15,3 @@ class ShockType(Enum): EXTREME_VALUE = "extreme_value" NONE = None - - -class SegmentInfo(TypedDict): - """Information on segments which is passed to `jax.ops.segment_max`. - - - "segment_ids" are a 1d integer jax.Array that partitions the first dimension of - `data` into segments over which we need to aggregate. - - - "num_segments" is the number of segments. - - The segment_ids are assumed to be sorted. 
- - """ - - segment_ids: Array - num_segments: int diff --git a/tests/test_discrete_problem.py b/tests/test_discrete_problem.py index 3f87f33e..503a0925 100644 --- a/tests/test_discrete_problem.py +++ b/tests/test_discrete_problem.py @@ -41,6 +41,7 @@ def segment_info(): # ====================================================================================== +@pytest.mark.xfail(reason="Removec choice segments") @pytest.mark.parametrize(("collapse", "n_extra_axes"), test_cases) def test_aggregation_without_shocks(cc_values, segment_info, collapse, n_extra_axes): cc_values, var_info = _get_reshaped_cc_values_and_variable_info( @@ -157,7 +158,6 @@ def test_get_solve_discrete_problem_illustrative(): random_utility_shock_type=ShockType.NONE, variable_info=variable_info, is_last_period=False, - choice_segments=None, ) cc_values = jnp.array( @@ -172,6 +172,7 @@ def test_get_solve_discrete_problem_illustrative(): aaae(got, jnp.array([1, 3, 5])) +@pytest.mark.xfail(reason="Removec choice segments") @pytest.mark.illustrative def test_solve_discrete_problem_no_shocks_illustrative(): cc_values = jnp.array( @@ -187,7 +188,6 @@ def test_solve_discrete_problem_no_shocks_illustrative(): got = _solve_discrete_problem_no_shocks( cc_values, choice_axes=0, - choice_segments=None, params=None, ) aaae(got, jnp.array([4, 5])) @@ -197,7 +197,6 @@ def test_solve_discrete_problem_no_shocks_illustrative(): got = _solve_discrete_problem_no_shocks( cc_values, choice_axes=None, - choice_segments={"segment_ids": jnp.array([0, 0, 1]), "num_segments": 2}, params=None, ) aaae(got, jnp.array([[2, 3], [4, 5]])) @@ -207,7 +206,6 @@ def test_solve_discrete_problem_no_shocks_illustrative(): got = _solve_discrete_problem_no_shocks( cc_values, choice_axes=1, - choice_segments={"segment_ids": jnp.array([0, 0, 1]), "num_segments": 2}, params=None, ) aaae(got, jnp.array([3, 5])) From 1b8f8e5b531463b92d905d933342741e05d7d39c Mon Sep 17 00:00:00 2001 From: Tim Mensinger Date: Mon, 3 Feb 2025 15:25:05 +0100 
Subject: [PATCH 04/21] First iteration: Remove sparse variables from solve brute --- src/lcm/solve_brute.py | 5 ++--- tests/test_solve_brute.py | 2 ++ 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/src/lcm/solve_brute.py b/src/lcm/solve_brute.py index a7f1ab82..b5da69ad 100644 --- a/src/lcm/solve_brute.py +++ b/src/lcm/solve_brute.py @@ -88,7 +88,7 @@ def solve_continuous_problem( """Solve the agent's continuous choices problem problem. Args: - state_choice_space (Space): Class with entries dense_vars and sparse_vars. + state_choice_space (Space): Class with model variables. compute_ccv (callable): Function that returns the conditional continuation values for a given combination of states and discrete choices. The function depends on: @@ -113,7 +113,7 @@ def solve_continuous_problem( _gridmapped = spacemap( func=compute_ccv, dense_vars=list(state_choice_space.dense_vars), - sparse_vars=list(state_choice_space.sparse_vars), + sparse_vars=[], put_dense_first=False, ) gridmapped = jax.jit(_gridmapped) @@ -121,7 +121,6 @@ def solve_continuous_problem( return gridmapped( **state_choice_space.dense_vars, **continuous_choice_grids, - **state_choice_space.sparse_vars, **state_indexers, vf_arr=vf_arr, params=params, diff --git a/tests/test_solve_brute.py b/tests/test_solve_brute.py index 594635e2..788d10e5 100644 --- a/tests/test_solve_brute.py +++ b/tests/test_solve_brute.py @@ -1,5 +1,6 @@ import jax.numpy as jnp import numpy as np +import pytest from numpy.testing import assert_array_almost_equal as aaae from lcm.entry_point import create_compute_conditional_continuation_value @@ -116,6 +117,7 @@ def calculate_emax(values, params): # noqa: ARG001 assert isinstance(solution, list) +@pytest.mark.xfail(reason="Removec sparse vars segments") def test_solve_continuous_problem_no_vf_arr(): state_choice_space = Space( dense_vars={"a": jnp.array([0, 1.0]), "b": jnp.array([2, 3.0])}, From 2570ace9a81eb53384aca4e3563b5acac8b25d8d Mon Sep 17 00:00:00 2001 From: Tim 
Mensinger Date: Mon, 3 Feb 2025 15:32:10 +0100 Subject: [PATCH 05/21] First iteration: Remove sparse variables from function representation --- src/lcm/function_representation.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/lcm/function_representation.py b/src/lcm/function_representation.py index 2b21b3f3..7b1fe9fd 100644 --- a/src/lcm/function_representation.py +++ b/src/lcm/function_representation.py @@ -20,8 +20,7 @@ def get_function_representation( """Create a function representation of pre-calculated values on a grid. An example of a pre-calculated function is a value or policy function. These are - evaluated on the space of all sparse and dense discrete state variables as well as - all continuous state variables. + evaluated on the space of all discrete and continuous state variables. This function dynamically generates a function that looks up and interpolates values of the pre-calculated function. The arguments of the resulting function can be split @@ -37,7 +36,6 @@ def get_function_representation( The resulting function roughly does the following steps: - Translate values of discrete variables into positions - - Look up the position of sparse variables in an indexer array. - Index into the array with the pre-calculated function values to extract only the part on which interpolation is needed. 
- Translate values of continuous variables into coordinates needed for interpolation From c96bae755ce1d8ade607ec99562db6e80b4dc2df Mon Sep 17 00:00:00 2001 From: Tim Mensinger Date: Mon, 3 Feb 2025 15:37:09 +0100 Subject: [PATCH 06/21] First iteration: Remove segment argmax --- src/lcm/argmax.py | 72 -------------------------------------------- src/lcm/simulate.py | 7 ++--- tests/test_argmax.py | 46 +--------------------------- 3 files changed, 3 insertions(+), 122 deletions(-) diff --git a/src/lcm/argmax.py b/src/lcm/argmax.py index c713f36f..60f20aea 100644 --- a/src/lcm/argmax.py +++ b/src/lcm/argmax.py @@ -1,6 +1,5 @@ import jax.numpy as jnp from jax import Array -from jax.ops import segment_max # ====================================================================================== # argmax @@ -93,74 +92,3 @@ def _flatten_last_n_axes(a: Array, n: int) -> Array: """ return a.reshape(*a.shape[:-n], -1) - - -# ====================================================================================== -# segment argmax -# ====================================================================================== - - -def segment_argmax( - data: Array, - segment_ids: Array, - num_segments: int, -) -> tuple[Array, Array]: - """Computes the maximum within segments of an array over the first axis of data. - - See `jax.ops.segment_max` for reference. If multiple maxima exist, the last index - will be selected. - - Args: - data (Array): Multidimensional array. - segment_ids (Array): An array with integer dtype that indicates the segments - of data (along its leading axis) to be reduced. Values can be repeated and - need not be sorted. Values outside of the range [0, num_segments) are - dropped and do not contribute to the result. - num_segments (int): An int with nonnegative value indicating the number of - segments. The default is set to be the minimum number of segments that would - support all indices in segment_ids, calculated as max(segment_ids) + 1. 
- Since num_segments determines the size of the output, a static value must be - provided to use segment_max in a JIT-compiled function. - - Returns: - - Array: The argmax values. Has shape (num_segments, *data.shape[1:]). - - - Array: The maximum values. Has shape (num_segments, *data.shape[1:]). - - """ - # Compute segment maximum and bring to the same shape as data - # ================================================================================== - segment_maximum = segment_max( - data=data, - segment_ids=segment_ids, - num_segments=num_segments, - indices_are_sorted=True, - ) - segment_maximum_expanded = segment_maximum[segment_ids] - - # Check where the array attains its maximum - # ================================================================================== - max_value_mask = data == segment_maximum_expanded - - # Create index array of argmax indices for each segment (has same shape as data) - # ================================================================================== - arange = jnp.arange(data.shape[0]) - reshaped = arange.reshape(-1, *([1] * (data.ndim - 1))) - segment_argmax_ids = jnp.broadcast_to(reshaped, data.shape) - - # Set indices to zero that do not correspond to a maximum - # ================================================================================== - max_value_indices = max_value_mask * segment_argmax_ids - - # Select argmax indices for each segment - # ---------------------------------------------------------------------------------- - # Note: If multiple maxima exist, this approach will select the last index. 
- # ================================================================================== - segment_argmax = segment_max( - data=max_value_indices, - segment_ids=segment_ids, - num_segments=num_segments, - indices_are_sorted=True, - ) - - return segment_argmax, segment_maximum diff --git a/src/lcm/simulate.py b/src/lcm/simulate.py index 561cfce9..2c7cdeea 100644 --- a/src/lcm/simulate.py +++ b/src/lcm/simulate.py @@ -7,7 +7,7 @@ from dags import concatenate_functions from jax import vmap -from lcm.argmax import argmax, segment_argmax +from lcm.argmax import argmax from lcm.dispatchers import spacemap, vmap_1d from lcm.interfaces import InternalModel, Space @@ -623,10 +623,7 @@ def _calculate_discrete_argmax(values, choice_axes, choice_segments): # Determine argmax and max over sparse choices # ============================================================================== - if choice_segments is not None: - sparse_argmax, _max = segment_argmax(_max, **choice_segments) - else: - sparse_argmax = None + sparse_argmax = None return dense_argmax, sparse_argmax, _max diff --git a/tests/test_argmax.py b/tests/test_argmax.py index d515ea1d..e2499c7b 100644 --- a/tests/test_argmax.py +++ b/tests/test_argmax.py @@ -2,11 +2,10 @@ from jax import jit from numpy.testing import assert_array_equal -from lcm.argmax import _flatten_last_n_axes, _move_axes_to_back, argmax, segment_argmax +from lcm.argmax import _flatten_last_n_axes, _move_axes_to_back, argmax # Test jitted functions # ====================================================================================== -jitted_segment_argmax = jit(segment_argmax, static_argnums=2) jitted_argmax = jit(argmax, static_argnums=[1, 2]) @@ -109,49 +108,6 @@ def test_argmax_with_ties(): assert_array_equal(_argmax, jnp.array([0, 0])) -# ====================================================================================== -# segment argmax -# ====================================================================================== - - -def 
test_segment_argmax_1d(): - data = jnp.arange(10) - segment_ids = jnp.array([0, 0, 0, 1, 1, 2, 2, 2, 2, 2]) - _argmax, _max = jitted_segment_argmax(data, segment_ids, num_segments=3) - assert_array_equal(_argmax, jnp.array([2, 4, 9])) - assert_array_equal(_max, jnp.array([2, 4, 9])) - - -def test_segment_argmax_2d(): - data = jnp.arange(10).reshape(5, 2) - segment_ids = jnp.array([0, 0, 0, 1, 1]) - _argmax, _max = jitted_segment_argmax(data, segment_ids, num_segments=2) - assert_array_equal(_argmax, jnp.array([[2, 2], [4, 4]])) - assert_array_equal(_max, jnp.array([[4, 5], [8, 9]])) - - -def test_segment_argmax_3d(): - data = jnp.array( - [ - [[0, 5], [3, 0]], - [[1, 2], [0, 0]], - [[0, 0], [0, 0]], - ], - ) - segment_ids = jnp.array([0, 0, 1]) - _argmax, _max = jitted_segment_argmax(data, segment_ids, num_segments=2) - assert_array_equal(_argmax, jnp.array([[[1, 0], [0, 1]], [[2, 2], [2, 2]]])) - assert_array_equal(_max, jnp.array([[[1, 5], [3, 0]], [[0, 0], [0, 0]]])) - - -def test_segment_argmax_ties(): - # If multiple maxima exist, segment_argmax will select the last index. 
- data = jnp.zeros(3) - segment_ids = jnp.array([0, 1, 1]) - _argmax, _ = jitted_segment_argmax(data, segment_ids, num_segments=2) - assert_array_equal(_argmax, jnp.array([0, 2])) - - # ====================================================================================== # Move axes to back # ====================================================================================== From 3ec43677e12c750bfa5ad610bcee5d5dba467284 Mon Sep 17 00:00:00 2001 From: Tim Mensinger Date: Mon, 3 Feb 2025 16:12:27 +0100 Subject: [PATCH 07/21] Remove sparse variables from state choice space --- src/lcm/entry_point.py | 2 -- src/lcm/state_space.py | 52 ++++------------------------------- tests/test_entry_point.py | 8 ------ tests/test_model_functions.py | 2 -- tests/test_simulate.py | 2 -- tests/test_state_space.py | 2 -- 6 files changed, 6 insertions(+), 62 deletions(-) diff --git a/src/lcm/entry_point.py b/src/lcm/entry_point.py index c41b2ebe..93de8a3a 100644 --- a/src/lcm/entry_point.py +++ b/src/lcm/entry_point.py @@ -96,9 +96,7 @@ def get_lcm_function( # ============================================================================== sc_space, space_info, state_indexer, segments = create_state_choice_space( model=_mod, - period=period, is_last_period=is_last_period, - jit_filter=False, ) state_choice_spaces.append(sc_space) diff --git a/src/lcm/state_space.py b/src/lcm/state_space.py index cb1e3e0c..8bd5af5a 100644 --- a/src/lcm/state_space.py +++ b/src/lcm/state_space.py @@ -8,12 +8,10 @@ from dags import concatenate_functions from lcm.dispatchers import productmap, spacemap -from lcm.interfaces import IndexerInfo, InternalModel, Space, SpaceInfo +from lcm.interfaces import InternalModel, Space, SpaceInfo -def create_state_choice_space( - model: InternalModel, period, *, is_last_period, jit_filter -): +def create_state_choice_space(model: InternalModel, *, is_last_period: bool): """Create a state choice space for the model. 
A state_choice_space is a compressed representation of all feasible states and the @@ -33,9 +31,7 @@ def create_state_choice_space( Args: model (Model): A processed model. - period (int): The period for which the state space is created. is_last_period (bool): Whether the function is created for the last period. - jit_filter (bool): If True, the filter function is compiled with JAX. Returns: Space: Space object containing the sparse and dense variables. This can be used @@ -54,9 +50,6 @@ def create_state_choice_space( if is_last_period: vi = vi.query("~is_auxiliary") - has_sparse_states = (vi.is_sparse & vi.is_state).any() - has_sparse_vars = vi.is_sparse.any() - # ================================================================================== # create state choice space # ================================================================================== @@ -64,47 +57,23 @@ def create_state_choice_space( grids=model.grids, subset=vi.query("is_dense & ~(is_choice & is_continuous)").index.tolist(), ) - if has_sparse_vars: - _filter_mask = create_filter_mask( - model=model, - subset=vi.query("is_sparse").index.tolist(), - fixed_inputs={"_period": period}, - jit_filter=jit_filter, - ) - - _combination_grid = create_combination_grid( - grids=model.grids, - masks=_filter_mask, - subset=vi.query("is_sparse").index.tolist(), - ) - else: - _combination_grid = {} state_choice_space = Space( - sparse_vars=_combination_grid, + sparse_vars={}, dense_vars=_value_grid, ) # ================================================================================== # create indexers and segments # ================================================================================== - if has_sparse_vars: - _state_indexer, _, choice_segments = create_indexers_and_segments( - mask=_filter_mask, - n_sparse_states=len(vi.query("is_sparse & is_state")), - ) - else: - _state_indexer = None - choice_segments = None + choice_segments = None - state_indexers = {"state_indexer": _state_indexer} if 
has_sparse_states else {} + state_indexers = {} # type: ignore[var-annotated] # ================================================================================== # create state space info # ================================================================================== # axis_names axis_names = vi.query("is_dense & is_state").index.tolist() - if has_sparse_states: - axis_names = ["state_index", *axis_names] # lookup_info _discrete_states = set(vi.query("is_discrete & is_state").index.tolist()) @@ -115,16 +84,7 @@ def create_state_choice_space( interpolation_info = {k: v for k, v in model.gridspecs.items() if k in _cont_states} # indexer infos - if has_sparse_states: - indexer_infos = [ - IndexerInfo( - axis_names=vi.query("is_sparse & is_state").index.tolist(), - name="state_indexer", - out_name="state_index", - ), - ] - else: - indexer_infos = [] + indexer_infos = [] # type: ignore[var-annotated] space_info = SpaceInfo( axis_names=axis_names, diff --git a/tests/test_entry_point.py b/tests/test_entry_point.py index 5b1201e7..2abd19ea 100644 --- a/tests/test_entry_point.py +++ b/tests/test_entry_point.py @@ -184,9 +184,7 @@ def test_create_compute_conditional_continuation_value(): _, space_info, _, _ = create_state_choice_space( model=model, - period=0, is_last_period=False, - jit_filter=False, ) u_and_f = get_utility_and_feasibility_function( @@ -232,9 +230,7 @@ def test_create_compute_conditional_continuation_value_with_discrete_model(): _, space_info, _, _ = create_state_choice_space( model=model, - period=0, is_last_period=False, - jit_filter=False, ) u_and_f = get_utility_and_feasibility_function( @@ -285,9 +281,7 @@ def test_create_compute_conditional_continuation_policy(): _, space_info, _, _ = create_state_choice_space( model=model, - period=0, is_last_period=False, - jit_filter=False, ) u_and_f = get_utility_and_feasibility_function( @@ -334,9 +328,7 @@ def test_create_compute_conditional_continuation_policy_with_discrete_model(): _, space_info, _, _ = 
create_state_choice_space( model=model, - period=0, is_last_period=False, - jit_filter=False, ) u_and_f = get_utility_and_feasibility_function( diff --git a/tests/test_model_functions.py b/tests/test_model_functions.py index 748be938..fd6f854f 100644 --- a/tests/test_model_functions.py +++ b/tests/test_model_functions.py @@ -58,9 +58,7 @@ def test_get_utility_and_feasibility_function(): _, space_info, _, _ = create_state_choice_space( model=model, - period=0, is_last_period=False, - jit_filter=False, ) u_and_f = get_utility_and_feasibility_function( diff --git a/tests/test_simulate.py b/tests/test_simulate.py index b3f7a666..42478c77 100644 --- a/tests/test_simulate.py +++ b/tests/test_simulate.py @@ -44,9 +44,7 @@ def simulate_inputs(): _, space_info, _, _ = create_state_choice_space( model=model, - period=0, is_last_period=False, - jit_filter=False, ) compute_ccv_policy_functions = [] diff --git a/tests/test_state_space.py b/tests/test_state_space.py index b73dd3e2..d0b3b68b 100644 --- a/tests/test_state_space.py +++ b/tests/test_state_space.py @@ -22,9 +22,7 @@ def test_create_state_choice_space(): ) create_state_choice_space( model=_model, - period=0, is_last_period=False, - jit_filter=False, ) From e39c7c501e0ed551e0ee1c0b7ea7abf7f92b4c83 Mon Sep 17 00:00:00 2001 From: Tim Mensinger Date: Mon, 3 Feb 2025 18:34:35 +0100 Subject: [PATCH 08/21] Remove sparse variables from solution state-choice-space --- src/lcm/simulate.py | 98 ++------------------------------------- src/lcm/state_space.py | 21 +++++---- tests/test_simulate.py | 20 -------- tests/test_state_space.py | 24 +++++----- 4 files changed, 28 insertions(+), 135 deletions(-) diff --git a/src/lcm/simulate.py b/src/lcm/simulate.py index 2c7cdeea..527699b8 100644 --- a/src/lcm/simulate.py +++ b/src/lcm/simulate.py @@ -109,7 +109,6 @@ def simulate( data_scs, data_choice_segments = create_data_scs( states=states, model=model, - period=period, ) # Compute objects dependent on data-state-choice-space @@ -467,7 
+466,6 @@ def retrieve_non_sparse_choices(indices, grids, grid_shape): def create_data_scs( states, model: InternalModel, - period, ): """Create data state choice space. @@ -485,10 +483,6 @@ def create_data_scs( # ================================================================================== vi = model.variable_info - has_sparse_choice_vars = len(vi.query("is_sparse & is_choice")) > 0 - - n_states = len(next(iter(states.values()))) - # check that all states have an initial value # ================================================================================== state_names = set(vi.query("is_state").index) @@ -504,68 +498,14 @@ def create_data_scs( # get sparse and dense choices # ================================================================================== - sparse_choices = { - name: grid - for name, grid in model.grids.items() - if name in vi.query("is_sparse & is_choice").index.tolist() - } - dense_choices = { name: grid for name, grid in model.grids.items() if name in vi.query("is_dense & is_choice & ~is_continuous").index.tolist() } - # create sparse choice state product - # ================================================================================== - if has_sparse_choice_vars: - # create sparse choice product - # ============================================================================== - sc_product, n_sc_product_combinations = dict_product(sparse_choices) - - # create full sparse choice state product - # ============================================================================== - _combination_grid = {} - for name, state in states.items(): - _combination_grid[name] = jnp.repeat( - state, - repeats=n_sc_product_combinations, - ) - - for name, choice in sc_product.items(): - _combination_grid[name] = jnp.tile(choice, reps=n_states) - - # create filter mask - # ============================================================================== - filter_names = model.function_info.query("is_filter").index.tolist() - - scalar_filter = 
concatenate_functions( - functions=model.functions, - targets=filter_names, - aggregator=jnp.logical_and, - ) - - fixed_inputs = {"_period": period} - potential_kwargs = _combination_grid | fixed_inputs - - parameters = list(inspect.signature(scalar_filter).parameters) - kwargs = {k: v for k, v in potential_kwargs.items() if k in parameters} - - # we do not vmap over the period variable - vmapped_parameters = [p for p in parameters if p != "_period"] - - _filter = vmap_1d(scalar_filter, variables=vmapped_parameters) - mask = _filter(**kwargs) - - # filter infeasible combinations - # ============================================================================== - combination_grid = { - name: grid[mask] for name, grid in _combination_grid.items() - } - - else: - combination_grid = states - data_choice_segments = None + combination_grid = states + data_choice_segments = None data_scs = Space( sparse_vars=combination_grid, @@ -574,13 +514,7 @@ def create_data_scs( # create choice segments # ================================================================================== - if has_sparse_choice_vars: - data_choice_segments = create_choice_segments( - mask=mask, - n_sparse_states=n_states, - ) - else: - data_choice_segments = None + data_choice_segments = None return data_scs, data_choice_segments @@ -652,32 +586,6 @@ def dict_product(d): return dict(zip(d.keys(), list(stacked.T), strict=True)), len(stacked) -def create_choice_segments(mask, n_sparse_states): - """Create choice segment info related to sparse states and choices. - - Comment: Can be made more memory efficient by reshaping mask into 2d. - - Args: - mask (jnp.ndarray): Boolean 1d array, where each entry corresponds to a - data-state-choice combination. - n_sparse_states (np.ndarray): Number of sparse states - - Returns: - dict: Dict with segment_info. 
- - """ - n_choice_combinations = len(mask) // n_sparse_states - state_ids = jnp.repeat( - jnp.arange(n_sparse_states), - repeats=n_choice_combinations, - ) - segments = state_ids[mask] - return { - "segment_ids": jnp.array(segments), - "num_segments": len(jnp.unique(segments)), - } - - def determine_discrete_dense_choice_axes(variable_info): """Determine which axes correspond to discrete and dense choices. diff --git a/src/lcm/state_space.py b/src/lcm/state_space.py index 8bd5af5a..a3feefa2 100644 --- a/src/lcm/state_space.py +++ b/src/lcm/state_space.py @@ -96,7 +96,16 @@ def create_state_choice_space(model: InternalModel, *, is_last_period: bool): return state_choice_space, space_info, state_indexers, choice_segments -def create_filter_mask(model: InternalModel, subset, fixed_inputs=None, *, jit_filter): +def _create_value_grid(grids, subset): + return {name: grid for name, grid in grids.items() if name in subset} + + +# ====================================================================================== +# Unused code from hereon: +# ====================================================================================== + + +def _create_filter_mask(model: InternalModel, subset, fixed_inputs=None, *, jit_filter): """Create mask for combinations of grid values that is True if all filters are True. Args: @@ -143,7 +152,7 @@ def create_filter_mask(model: InternalModel, subset, fixed_inputs=None, *, jit_f return _filter(**kwargs) -def create_forward_mask( +def _create_forward_mask( initial, grids, next_functions, @@ -233,7 +242,7 @@ def create_forward_mask( return mask -def create_combination_grid(grids, masks, subset=None): +def _create_combination_grid(grids, masks, subset=None): """Create a grid of all feasible combinations of variables. 
Args: @@ -287,7 +296,7 @@ def _combine_masks(masks): return np.array(mask) -def create_indexers_and_segments(mask, n_sparse_states, fill_value=-1): +def _create_indexers_and_segments(mask, n_sparse_states, fill_value=-1): """Create indexers and segment info related to sparse states and choices. Notes: @@ -341,7 +350,3 @@ def create_indexers_and_segments(mask, n_sparse_states, fill_value=-1): jnp.array(state_choice_indexer), {"segment_ids": jnp.array(segments), "num_segments": n_feasible_states}, ) - - -def _create_value_grid(grids, subset): - return {name: grid for name, grid in grids.items() if name in subset} diff --git a/tests/test_simulate.py b/tests/test_simulate.py index 42478c77..faac01ff 100644 --- a/tests/test_simulate.py +++ b/tests/test_simulate.py @@ -1,7 +1,6 @@ import jax.numpy as jnp import pandas as pd import pytest -from jax import random from numpy.testing import assert_array_almost_equal, assert_array_equal from pybaum import tree_equal @@ -18,7 +17,6 @@ _compute_targets, _generate_simulation_keys, _process_simulated_data, - create_choice_segments, create_data_scs, determine_discrete_dense_choice_axes, dict_product, @@ -427,24 +425,6 @@ def test_create_data_state_choice_space(): assert got_segment_info["num_segments"] == 2 -def test_choice_segments(): - got = create_choice_segments( - mask=jnp.array([True, False, True, False, True, False]), - n_sparse_states=2, - ) - assert_array_equal(jnp.array([0, 0, 1]), got["segment_ids"]) - assert got["num_segments"] == 2 - - -def test_choice_segments_weakly_increasing(): - key = random.PRNGKey(12345) - n_states, n_choices = random.randint(key, shape=(2,), minval=1, maxval=100) - mask_len = n_states * n_choices - mask = random.choice(key, a=2, shape=(mask_len,), p=jnp.array([0.5, 0.5])) - got = create_choice_segments(mask, n_sparse_states=n_states)["segment_ids"] - assert jnp.all(got[1:] - got[:-1] >= 0) - - def test_dict_product(): d = {"a": jnp.array([0, 1]), "b": jnp.array([2, 3])} got_dict, got_length = 
dict_product(d) diff --git a/tests/test_state_space.py b/tests/test_state_space.py index d0b3b68b..42f12af8 100644 --- a/tests/test_state_space.py +++ b/tests/test_state_space.py @@ -7,10 +7,10 @@ from lcm.input_processing import process_model from lcm.interfaces import InternalModel from lcm.state_space import ( - create_combination_grid, - create_filter_mask, - create_forward_mask, - create_indexers_and_segments, + _create_combination_grid, + _create_filter_mask, + _create_forward_mask, + _create_indexers_and_segments, create_state_choice_space, ) from tests.test_models import get_model_config @@ -80,7 +80,7 @@ def absorbing_retirement_filter(retirement, lagged_retirement): @pytest.mark.parametrize(("period", "expected"), PARAMETRIZATION) def test_create_filter_mask(filter_mask_inputs, period, expected): - calculated = create_filter_mask( + calculated = _create_filter_mask( model=filter_mask_inputs, subset=["lagged_retirement", "retirement"], fixed_inputs={"period": period}, @@ -98,7 +98,7 @@ def test_create_combination_grid(): mask = jnp.array([[True, False], [True, True]]) - calculated = create_combination_grid(grids=grids, masks=mask) + calculated = _create_combination_grid(grids=grids, masks=mask) expected = { "lagged_retirement": jnp.array([0, 1, 1]), @@ -120,7 +120,7 @@ def test_create_combination_grid_2_masks(): jnp.array([[True, True], [False, True]]), ] - calculated = create_combination_grid(grids=grids, masks=masks) + calculated = _create_combination_grid(grids=grids, masks=masks) expected = { "lagged_retirement": jnp.array([0, 1]), @@ -143,7 +143,7 @@ def test_create_combination_grid_multiple_masks(): jnp.array([[True, True], [False, True]]), ] - calculated = create_combination_grid(grids=grids, masks=masks) + calculated = _create_combination_grid(grids=grids, masks=masks) expected = { "lagged_retirement": jnp.array([0, 1]), @@ -177,7 +177,7 @@ def test_create_forward_mask(): def next_experience(experience, working): return experience + working - 
calculated = create_forward_mask( + calculated = _create_forward_mask( initial=initial, grids=grids, next_functions={"next_experience": next_experience}, @@ -218,7 +218,7 @@ def next_experience(experience, working): def next_health(experience, working): return ((experience + working) > 1).astype(int) - calculated = create_forward_mask( + calculated = _create_forward_mask( initial=initial, grids=grids, next_functions={"next_experience": next_experience, "next_health": next_health}, @@ -271,7 +271,7 @@ def next_experience(experience, healthy_working): def next_health(working): return (working == 2).astype(int) - calculated = create_forward_mask( + calculated = _create_forward_mask( initial=initial, grids=grids, next_functions={"next_experience": next_experience, "next_health": next_health}, @@ -300,7 +300,7 @@ def test_create_indexers_and_segments(): mask[2] = True mask = jnp.array(mask) - state_indexer, choice_indexer, segments = create_indexers_and_segments( + state_indexer, choice_indexer, segments = _create_indexers_and_segments( mask=mask, n_sparse_states=2, ) From 353136d5a1f938fdd6314e68197cf4a0ac715183 Mon Sep 17 00:00:00 2001 From: Tim Mensinger Date: Mon, 3 Feb 2025 18:51:56 +0100 Subject: [PATCH 09/21] Remove sparse variables from state choice space --- src/lcm/dispatchers.py | 4 +- src/lcm/simulate.py | 8 +- src/lcm/solve_brute.py | 1 - src/lcm/state_space.py | 260 -------------------------------------- tests/test_dispatchers.py | 10 +- tests/test_state_space.py | 244 ----------------------------------- 6 files changed, 6 insertions(+), 521 deletions(-) diff --git a/src/lcm/dispatchers.py b/src/lcm/dispatchers.py index d6eed9b4..13de957b 100644 --- a/src/lcm/dispatchers.py +++ b/src/lcm/dispatchers.py @@ -13,8 +13,6 @@ def spacemap( func: F, dense_vars: list[str], sparse_vars: list[str], - *, - put_dense_first: bool, ) -> F: """Apply vmap such that func is evaluated on a space of dense and sparse variables. 
@@ -71,6 +69,8 @@ def spacemap( # Apply vmap_1d for sparse and _base_productmap for dense variables # ================================================================================== + put_dense_first = False + if not sparse_vars: vmapped = _base_productmap(func, dense_vars) elif put_dense_first: diff --git a/src/lcm/simulate.py b/src/lcm/simulate.py index 527699b8..59fb5f9d 100644 --- a/src/lcm/simulate.py +++ b/src/lcm/simulate.py @@ -260,7 +260,6 @@ def solve_continuous_problem( func=compute_ccv, dense_vars=list(data_scs.dense_vars), sparse_vars=list(data_scs.sparse_vars), - put_dense_first=False, ) gridmapped = jax.jit(_gridmapped) @@ -504,11 +503,8 @@ def create_data_scs( if name in vi.query("is_dense & is_choice & ~is_continuous").index.tolist() } - combination_grid = states - data_choice_segments = None - data_scs = Space( - sparse_vars=combination_grid, + sparse_vars=states, dense_vars=dense_choices, ) @@ -545,7 +541,7 @@ def get_discrete_policy_calculator(variable_info): """ choice_axes = determine_discrete_dense_choice_axes(variable_info) - def _calculate_discrete_argmax(values, choice_axes, choice_segments): + def _calculate_discrete_argmax(values, choice_axes, choice_segments): # noqa: ARG001 _max = values # Determine argmax and max over dense choices diff --git a/src/lcm/solve_brute.py b/src/lcm/solve_brute.py index b5da69ad..cf055f29 100644 --- a/src/lcm/solve_brute.py +++ b/src/lcm/solve_brute.py @@ -114,7 +114,6 @@ def solve_continuous_problem( func=compute_ccv, dense_vars=list(state_choice_space.dense_vars), sparse_vars=[], - put_dense_first=False, ) gridmapped = jax.jit(_gridmapped) diff --git a/src/lcm/state_space.py b/src/lcm/state_space.py index a3feefa2..f4f25a4a 100644 --- a/src/lcm/state_space.py +++ b/src/lcm/state_space.py @@ -1,13 +1,5 @@ """Create a state space for a given model.""" -import inspect - -import jax -import jax.numpy as jnp -import numpy as np -from dags import concatenate_functions - -from lcm.dispatchers import 
productmap, spacemap from lcm.interfaces import InternalModel, Space, SpaceInfo @@ -98,255 +90,3 @@ def create_state_choice_space(model: InternalModel, *, is_last_period: bool): def _create_value_grid(grids, subset): return {name: grid for name, grid in grids.items() if name in subset} - - -# ====================================================================================== -# Unused code from hereon: -# ====================================================================================== - - -def _create_filter_mask(model: InternalModel, subset, fixed_inputs=None, *, jit_filter): - """Create mask for combinations of grid values that is True if all filters are True. - - Args: - model (Model): A processed model. - subset (list): The subset of variables to be considered in the mask. - fixed_inputs (dict): A dict of fixed inputs for the filters or aux_functions. An - example would be a model period. - jit_filter (bool): Whether the aggregated filter function is jitted before - applying it. - - Returns: - jax.numpy.ndarray: Multi-Dimensional boolean array that is True for a feasible - combination of variables. The order of the dimensions in the mask is defined - by the order of `grids`. 
- - """ - # preparations - if subset is None: - subset = model.variable_info.query("is_sparse").index.tolist() - - fixed_inputs = {} if fixed_inputs is None else fixed_inputs - _axis_names = [name for name in model.grids if name in subset] - _filter_names = model.function_info.query("is_filter").index.tolist() - - # Create scalar dag function to evaluate all filters - _scalar_filter = concatenate_functions( - functions=model.functions, - targets=_filter_names, - aggregator=jnp.logical_and, - ) - - # Apply dispatcher to get mask - _filter = productmap(_scalar_filter, variables=_axis_names) - - _valid_args = set(inspect.signature(_filter).parameters.keys()) - _potential_kwargs = {**model.grids, **fixed_inputs} - - kwargs = {k: v for k, v in _potential_kwargs.items() if k in _valid_args} - - # Calculate mask - if jit_filter: - _filter = jax.jit(_filter) - - return _filter(**kwargs) - - -def _create_forward_mask( - initial, - grids, - next_functions, - fixed_inputs=None, - aux_functions=None, - *, - jit_next, -): - """Create a mask for combinations of grid values. - - .. warning:: - This function is extremely experimental and probably buggy - - Args: - initial (dict): Dict of arrays with valid combinations of variables. - grids (dict): Dictionary containing a one-dimensional grid for each - variable that is used as a basis to construct the higher dimensional - grid. - next_functions (dict): Dict of functions for the state space transitions. - All keys need to start with "next". - fixed_inputs (dict): A dict of fixed inputs for the next_functions or - aux_functions. An example would be a model period. - aux_functions (dict): Auxiliary functions that calculate derived variables - needed in the filters. - jit_next (bool): Whether the aggregated next_function is jitted before - applying it. 
- - """ - # preparations - _state_vars = [ - name for name in grids if f"next_{name}" in next_functions - ] # sort in order of grids - _aux_functions = {} if aux_functions is None else aux_functions - _shape = tuple(len(grids[name]) for name in _state_vars) - _next_functions = { - f"next_{name}": next_functions[f"next_{name}"] for name in _state_vars - } - _fixed_inputs = {} if fixed_inputs is None else fixed_inputs - - # find valid arguments - valid_args = set(initial) | set(_aux_functions) | set(_fixed_inputs) - - # find next functions with only valid arguments - _valid_next_functions = {} - for name, func in _next_functions.items(): - present_args = set(inspect.signature(func).parameters) - if present_args.issubset(valid_args): - _valid_next_functions[name] = func - - # create scalar next function - _next = concatenate_functions( - functions={**_valid_next_functions, **_aux_functions}, - targets=list(_valid_next_functions), - return_type="dict", - ) - - # apply dispatcher - needed_args = set(inspect.signature(_next).parameters) - _needed_initial = {k: val for k, val in initial.items() if k in needed_args} - _gridmapped = spacemap( - _next, - dense_vars=[], - sparse_vars=list(_needed_initial), - put_dense_first=True, - ) - - # calculate next values - if jit_next: - _gridmapped = jax.jit(_gridmapped) - _next_values = _gridmapped(**_needed_initial, **_fixed_inputs) - - # create all-false mask - mask = np.full(_shape, fill_value=False) - - # fill with full slices to get indexers - indices = [] - for i, var in enumerate(_state_vars): - name = f"next_{var}" - if name in _next_values: - indices.append(_next_values[name]) - else: - indices.append(slice(0, _shape[i])) - - # set mask to True with indexers - mask[tuple(indices)] = True - - return mask - - -def _create_combination_grid(grids, masks, subset=None): - """Create a grid of all feasible combinations of variables. 
- - Args: - grids (dict): Dictionary containing a one-dimensional grid for each - dimension of the combination grid. - masks (list): List of masks that define the feasible combinations. - subset (list): The subset of the variables that enter the combination grid. - By default all variables in grids are considered. - - Returns: - dict: Dictionary containing a one-dimensional array for each variable in the - combination grid. Together these arrays store all feasible combinations - of variables. - - """ - _subset = list(grids) if subset is None else subset - _axis_names = [name for name in grids if name in _subset] - _grids = {name: jnp.array(grids[name]) for name in _axis_names} - - # get combined mask - _mask_np = np.array(_combine_masks(masks)) - - # Calculate meshgrid - _all_combis = jnp.meshgrid(*_grids.values(), indexing="ij") - - # Flatten meshgrid entries - return { - name: arr[_mask_np] for name, arr in zip(_axis_names, _all_combis, strict=True) - } - - -def _combine_masks(masks): - """Combine multiple masks into one. - - Args: - masks (list): List of masks. - - Returns: - np.ndarray: Combined mask. - - """ - if isinstance(masks, np.ndarray | jnp.ndarray): - _masks = [masks] - else: - _masks = sorted(masks, key=lambda x: len(x.shape), reverse=True) - - mask = _masks[0] - for m in _masks[1:]: - _shape = tuple(list(m.shape) + [1] * (mask.ndim - m.ndim)) - mask = jnp.logical_and(mask, m.reshape(_shape)) - return np.array(mask) - - -def _create_indexers_and_segments(mask, n_sparse_states, fill_value=-1): - """Create indexers and segment info related to sparse states and choices. - - Notes: - ------ - - This probably does not work if there is not at least one sparse state variable - and at least one sparse choice variable. - - Args: - mask (np.ndarray): Boolean array with one dimension per state - or choice variable that is True for feasible state-choice - combinations. The state variables occupy the first dimensions. - I.e. 
the shape is (n_s1, ..., n_sm, n_c1, ..., n_cm). - n_sparse_states (np.ndarray): Number of sparse state variables. - fill_value (np.ndarray): Value of the index array for infeasible - states or choices. - - - Returns: - jax.numpy.ndarray: The state indexer with (n_s1, ..., n_sm). The entries are - ``fill_value`` for infeasible states and count the feasible states - otherwise. - jax.numpy.ndarray: the state-choice indexer with shape - (n_feasible_states, n_c1, ..., n_cn). The entries are ``fill_value`` for - infeasible state-choice combinations and count the feasible state-choice - combinations otherwise. - dict: Dict with segment_info. - - """ - mask = np.array(mask) - - choice_axes = tuple(range(n_sparse_states, mask.ndim)) - is_feasible_state = mask.any(axis=choice_axes) - n_feasible_states = np.count_nonzero(is_feasible_state) - - state_indexer = np.full(is_feasible_state.shape, fill_value) - state_indexer[is_feasible_state] = np.arange(n_feasible_states) - - # reduce mask before doing calculations and using higher dtypes - reduced_mask = mask[is_feasible_state] - - counter = reduced_mask.cumsum().reshape(reduced_mask.shape) - 1 - state_choice_indexer = np.full(reduced_mask.shape, fill_value) - state_choice_indexer[reduced_mask] = counter[reduced_mask] - - new_choice_axes = tuple(range(1, mask.ndim - n_sparse_states + 1)) - n_choices = np.count_nonzero(reduced_mask, new_choice_axes) - segments = np.repeat(np.arange(n_feasible_states), n_choices) - - return ( - jnp.array(state_indexer), - jnp.array(state_choice_indexer), - {"segment_ids": jnp.array(segments), "num_segments": n_feasible_states}, - ) diff --git a/tests/test_dispatchers.py b/tests/test_dispatchers.py index 99a91951..4287dca0 100644 --- a/tests/test_dispatchers.py +++ b/tests/test_dispatchers.py @@ -225,11 +225,9 @@ def expected_spacemap(): return allow_args(g)(*helper).reshape(3, 2, 4 * 5) -@pytest.mark.parametrize("put_dense_first", [True, False]) def test_spacemap_all_arguments_mapped( 
setup_spacemap, expected_spacemap, - put_dense_first, ): dense_vars, sparse_vars = setup_spacemap @@ -237,14 +235,10 @@ def test_spacemap_all_arguments_mapped( g, list(dense_vars), list(sparse_vars), - put_dense_first=put_dense_first, ) calculated = decorated(**dense_vars, **sparse_vars) - if put_dense_first: - aaae(calculated, expected_spacemap) - else: - aaae(calculated, jnp.transpose(expected_spacemap, axes=(2, 0, 1))) + aaae(calculated, jnp.transpose(expected_spacemap, axes=(2, 0, 1))) @pytest.mark.parametrize( @@ -264,7 +258,7 @@ def test_spacemap_all_arguments_mapped( ) def test_spacemap_arguments_overlap(error_msg, dense_vars, sparse_vars): with pytest.raises(ValueError, match=error_msg): - spacemap(g, dense_vars, sparse_vars, put_dense_first=True) + spacemap(g, dense_vars, sparse_vars) # ====================================================================================== diff --git a/tests/test_state_space.py b/tests/test_state_space.py index 42f12af8..4a0c3767 100644 --- a/tests/test_state_space.py +++ b/tests/test_state_space.py @@ -1,16 +1,10 @@ import jax.numpy as jnp -import numpy as np import pandas as pd import pytest -from numpy.testing import assert_array_almost_equal as aaae from lcm.input_processing import process_model from lcm.interfaces import InternalModel from lcm.state_space import ( - _create_combination_grid, - _create_filter_mask, - _create_forward_mask, - _create_indexers_and_segments, create_state_choice_space, ) from tests.test_models import get_model_config @@ -76,241 +70,3 @@ def absorbing_retirement_filter(retirement, lagged_retirement): (50, jnp.array([[False, False], [False, True]])), (10, jnp.array([[True, True], [False, True]])), ] - - -@pytest.mark.parametrize(("period", "expected"), PARAMETRIZATION) -def test_create_filter_mask(filter_mask_inputs, period, expected): - calculated = _create_filter_mask( - model=filter_mask_inputs, - subset=["lagged_retirement", "retirement"], - fixed_inputs={"period": period}, - 
jit_filter=False, - ) - - aaae(calculated, expected) - - -def test_create_combination_grid(): - grids = { - "lagged_retirement": jnp.array([0, 1]), - "retirement": jnp.array([0, 1]), - } - - mask = jnp.array([[True, False], [True, True]]) - - calculated = _create_combination_grid(grids=grids, masks=mask) - - expected = { - "lagged_retirement": jnp.array([0, 1, 1]), - "retirement": jnp.array([0, 0, 1]), - } - - for key, expected_value in expected.items(): - aaae(calculated[key], expected_value) - - -def test_create_combination_grid_2_masks(): - grids = { - "lagged_retirement": jnp.array([0, 1]), - "retirement": jnp.array([0, 1]), - } - - masks = [ - jnp.array([[True, False], [True, True]]), - jnp.array([[True, True], [False, True]]), - ] - - calculated = _create_combination_grid(grids=grids, masks=masks) - - expected = { - "lagged_retirement": jnp.array([0, 1]), - "retirement": jnp.array([0, 1]), - } - - for key, expected_value in expected.items(): - aaae(calculated[key], expected_value) - - -def test_create_combination_grid_multiple_masks(): - grids = { - "lagged_retirement": jnp.array([0, 1]), - "retirement": jnp.array([0, 1]), - } - - masks = [ - jnp.array([[True, False], [True, True]]), - jnp.array([[True, False], [True, True]]), - jnp.array([[True, True], [False, True]]), - ] - - calculated = _create_combination_grid(grids=grids, masks=masks) - - expected = { - "lagged_retirement": jnp.array([0, 1]), - "retirement": jnp.array([0, 1]), - } - - for key, expected_value in expected.items(): - aaae(calculated[key], expected_value) - - -def test_create_forward_mask(): - """We use the following simplified test case (that does not make economic sense). 
- - - People can stay at home (work=0), work part time (work=1) or full time (work=2) - - Experience is measured in work units - - Initial experience is only [0, 1] but in total one can accumulate up to 6 points - - People can only work full time if they have no previous work experience - - People have to work at least part time if they have no previous experience - - """ - grids = { - "experience": jnp.arange(6), - "working": jnp.array([0, 1, 2]), - } - - initial = { - "experience": jnp.array([0, 0, 1, 1]), - "working": jnp.array([1, 2, 0, 1]), - } - - def next_experience(experience, working): - return experience + working - - calculated = _create_forward_mask( - initial=initial, - grids=grids, - next_functions={"next_experience": next_experience}, - jit_next=True, - ) - - expected = jnp.array([False, True, True, False, False, False]) - - aaae(calculated, expected) - - -def test_create_forward_mask_multiple_next_funcs(): - """We use another simple example. - - - People can stay at home (work=0), work part time (work=1) or full time (work=2) - - Experience is measured in work units - - Initial experience is only [0, 1] but in total one can accumulate up to 6 points - - People can only work full time if they have no previous work experience - - People have to work at least part time if they have no previous experience - - People get bad health after they have more than 1 experience - - """ - grids = { - "experience": jnp.arange(6), - "working": jnp.array([0, 1, 2]), - "health": jnp.array([0, 1]), - } - - initial = { - "experience": jnp.array([0, 0, 1, 1]), - "working": jnp.array([1, 2, 0, 1]), - "health": jnp.array([0, 0, 0, 0]), - } - - def next_experience(experience, working): - return experience + working - - def next_health(experience, working): - return ((experience + working) > 1).astype(int) - - calculated = _create_forward_mask( - initial=initial, - grids=grids, - next_functions={"next_experience": next_experience, "next_health": next_health}, - 
jit_next=True, - ) - - expected = jnp.array( - [ - [False, False], - [True, False], - [False, True], - [False, False], - [False, False], - [False, False], - ], - ) - - aaae(calculated, expected) - - -def test_forward_mask_w_aux_function(): - """We use another simple example. - - - People can stay at home (work=0), work part time (work=1) or full time (work=2) - - Experience is measured in work units - - Initial experience is only [0, 1] but in total one can accumulate up to 6 points - - People have to work at least part time if they have no previous experience - - People get bad health after they work full time. - - In bad health additional work experience does not add anything to experience. - - """ - grids = { - "experience": jnp.arange(6), - "working": jnp.array([0, 1, 2]), - "health": jnp.array([0, 1]), - } - - initial = { - "experience": jnp.array([0, 0, 1, 1, 2]), - "working": jnp.array([1, 2, 0, 1, 2]), - "health": jnp.array([0, 0, 1, 1, 1]), - } - - def healthy_working(health, working): - return jnp.where(health == 0, working, 0) - - def next_experience(experience, healthy_working): - return experience + healthy_working - - def next_health(working): - return (working == 2).astype(int) - - calculated = _create_forward_mask( - initial=initial, - grids=grids, - next_functions={"next_experience": next_experience, "next_health": next_health}, - aux_functions={"healthy_working": healthy_working}, - jit_next=False, - ) - - expected = jnp.array( - [ - [False, False], - [True, False], - [False, True], - [False, False], - [False, False], - [False, False], - ], - ) - - aaae(calculated, expected) - - -def test_create_indexers_and_segments(): - mask = np.full((3, 3, 2), fill_value=False) - mask[1, 0, 0] = True - mask[1, -1, -1] = True - mask[2] = True - mask = jnp.array(mask) - - state_indexer, choice_indexer, segments = _create_indexers_and_segments( - mask=mask, - n_sparse_states=2, - ) - - expected_state_indexer = jnp.array([[-1, -1, -1], [0, -1, 1], [2, 3, 4]]) - - 
expected_choice_indexer = jnp.array([[0, -1], [-1, 1], [2, 3], [4, 5], [6, 7]]) - - expected_segments = jnp.array([0, 1, 2, 2, 3, 3, 4, 4]) - - aaae(state_indexer, expected_state_indexer) - aaae(choice_indexer, expected_choice_indexer) - aaae(segments["segment_ids"], expected_segments) From dc951e2dd40ea2398b7b787cf13bbbd4789f4114 Mon Sep 17 00:00:00 2001 From: Tim Mensinger Date: Mon, 3 Feb 2025 18:57:24 +0100 Subject: [PATCH 10/21] First iteration: Remove filter and sparse vars from process-model --- src/lcm/input_processing/process_model.py | 19 ++++--------------- 1 file changed, 4 insertions(+), 15 deletions(-) diff --git a/src/lcm/input_processing/process_model.py b/src/lcm/input_processing/process_model.py index 179c6268..5a828b0e 100644 --- a/src/lcm/input_processing/process_model.py +++ b/src/lcm/input_processing/process_model.py @@ -65,13 +65,11 @@ def _get_internal_functions( dict: Dictionary containing all functions of the model. The keys are the names of the functions. The values are the processed functions. The main difference between processed and unprocessed functions is that - processed functions take `params` as argument unless they are filter - functions. + processed functions take `params` as argument. """ variable_info = get_variable_info(model) grids = get_grids(model) - function_info = get_function_info(model) raw_functions = deepcopy(model.functions) @@ -93,9 +91,8 @@ def _get_internal_functions( # ================================================================================== # We wrap the user functions such that they can be called with the 'params' argument # instead of the individual parameters. This is done for all functions except for - # filter functions, because they cannot depend on model parameters; and dynamically - # generated weighting functions for stochastic next functions, since they are - # constructed to accept the 'params' argument by default. 
+ # the dynamically generated weighting functions for stochastic next functions, since + # they are constructed to accept the 'params' argument by default. functions = {} for name, func in raw_functions.items(): @@ -105,19 +102,11 @@ def _get_internal_functions( processed_func = func else: - is_filter_function = function_info.loc[name, "is_filter"] # params[name] contains the dictionary of parameters for the function, which # is empty if the function does not depend on any model parameters. depends_on_params = bool(params[name]) - if is_filter_function: - if params.get(name, False): - raise ValueError( - f"filters cannot depend on model parameters, but {name} does." - ) - processed_func = func - - elif depends_on_params: + if depends_on_params: processed_func = _replace_func_parameters_by_params( func=func, params=params, From 609bc04a0206d24fe6e56ae6671f36f7bfe7c64a Mon Sep 17 00:00:00 2001 From: Tim Mensinger Date: Mon, 3 Feb 2025 19:37:30 +0100 Subject: [PATCH 11/21] Remove is_sparse and is_dense entries from variable info --- src/lcm/discrete_problem.py | 8 ++--- src/lcm/input_processing/util.py | 28 +++++------------ src/lcm/interfaces.py | 3 +- src/lcm/simulate.py | 13 ++------ src/lcm/state_space.py | 4 +-- tests/input_processing/test_process_model.py | 10 ------ tests/test_discrete_problem.py | 33 ++++++-------------- tests/test_simulate.py | 3 +- 8 files changed, 27 insertions(+), 75 deletions(-) diff --git a/src/lcm/discrete_problem.py b/src/lcm/discrete_problem.py index 0f308372..f89dd35f 100644 --- a/src/lcm/discrete_problem.py +++ b/src/lcm/discrete_problem.py @@ -213,15 +213,11 @@ def _determine_dense_discrete_choice_axes( discrete choice axes. """ - has_sparse = variable_info["is_sparse"].any() - # List of dense variables excluding continuous choice variables. 
- dense_vars = variable_info.query( - "is_dense & ~(is_choice & is_continuous)", + axes = variable_info.query( + "~(is_choice & is_continuous)", ).index.tolist() - axes = ["__sparse__", *dense_vars] if has_sparse else dense_vars - choice_vars = set(variable_info.query("is_choice").index.tolist()) choice_indices = tuple(i for i, ax in enumerate(axes) if ax in choice_vars) diff --git a/src/lcm/input_processing/util.py b/src/lcm/input_processing/util.py index 216fd296..e5cec0f8 100644 --- a/src/lcm/input_processing/util.py +++ b/src/lcm/input_processing/util.py @@ -18,15 +18,13 @@ def get_function_info(model: Model) -> pd.DataFrame: pd.DataFrame: A table with information about all functions in the model. The index contains the name of a model function. The columns are booleans that are True if the function has the corresponding property. The columns are: - is_next, is_stochastic_next, is_filter, is_constraint. + is_next, is_stochastic_next, is_constraint. """ info = pd.DataFrame(index=list(model.functions)) - info["is_filter"] = False + # Convert both filter and constraint to constraints, until we forbid filters. info["is_constraint"] = info.index.str.endswith(("_constraint", "_filter")) - info["is_next"] = ( - info.index.str.startswith("next_") & ~info["is_constraint"] & ~info["is_filter"] - ) + info["is_next"] = info.index.str.startswith("next_") & ~info["is_constraint"] info["is_stochastic_next"] = [ hasattr(func, "_stochastic_info") for func in model.functions.values() ] @@ -43,7 +41,7 @@ def get_variable_info(model: Model) -> pd.DataFrame: pd.DataFrame: A table with information about all variables in the model. The index contains the name of a model variable. The columns are booleans that are True if the variable has the corresponding property. The columns are: - is_state, is_choice, is_continuous, is_discrete, is_sparse, is_dense. + is_state, is_choice, is_continuous, is_discrete. 
""" function_info = get_function_info(model) @@ -72,20 +70,10 @@ def get_variable_info(model: Model) -> pd.DataFrame: ) info["is_auxiliary"] = [var in auxiliary_variables for var in variables] - filter_names = function_info.query("is_filter").index.tolist() - filtered_variables: set[str] = set() - for name in filter_names: - filtered_variables.update(get_ancestors(model.functions, name)) - - info["is_sparse"] = [var in filtered_variables for var in variables] - info["is_dense"] = ~info["is_sparse"] - - order = info.query("is_sparse & is_state").index.tolist() - order += info.query("is_sparse & is_choice").index.tolist() - order += info.query("is_dense & is_discrete & is_state").index.tolist() - order += info.query("is_dense & is_discrete & is_choice").index.tolist() - order += info.query("is_dense & is_continuous & is_state").index.tolist() - order += info.query("is_dense & is_continuous & is_choice").index.tolist() + order = info.query("is_discrete & is_state").index.tolist() + order += info.query("is_discrete & is_choice").index.tolist() + order += info.query("is_continuous & is_state").index.tolist() + order += info.query("is_continuous & is_choice").index.tolist() if set(order) != set(info.index): raise ValueError("Order and index do not match.") diff --git a/src/lcm/interfaces.py b/src/lcm/interfaces.py index 9e71d810..64d09a84 100644 --- a/src/lcm/interfaces.py +++ b/src/lcm/interfaces.py @@ -80,8 +80,7 @@ class InternalModel: variable_info: A table with information about all variables in the model. The index contains the name of a model variable. The columns are booleans that are True if the variable has the corresponding property. The columns are: - is_state, is_choice, is_continuous, is_discrete, is_sparse, columns are: - is_state, is_choice, is_continuous, is_discrete, is_sparse, is_dense. + is_state, is_choice, is_continuous, is_discrete. functions: Dictionary that maps names of functions to functions. 
The functions differ from the user functions in that they all except the filter functions take ``params`` as keyword argument. If the original function depended on diff --git a/src/lcm/simulate.py b/src/lcm/simulate.py index 59fb5f9d..a6fb07bb 100644 --- a/src/lcm/simulate.py +++ b/src/lcm/simulate.py @@ -88,8 +88,6 @@ def simulate( variable_info=model.variable_info, ) - sparse_choice_variables = model.variable_info.query("is_choice & is_sparse").index - # The following variables are updated during the forward simulation states = initial_states key = jax.random.PRNGKey(seed=seed) @@ -168,14 +166,9 @@ def simulate( grid_shape=cont_choice_grid_shape, ) - sparse_choices = { - key: data_scs.sparse_vars[key][sparse_argmax] - for key in sparse_choice_variables - } - # Store results # ============================================================================== - choices = {**dense_choices, **sparse_choices, **cont_choices} + choices = {**dense_choices, **cont_choices} _simulation_results.append( { @@ -500,7 +493,7 @@ def create_data_scs( dense_choices = { name: grid for name, grid in model.grids.items() - if name in vi.query("is_dense & is_choice & ~is_continuous").index.tolist() + if name in vi.query("is_choice & ~is_continuous").index.tolist() } data_scs = Space( @@ -594,7 +587,7 @@ def determine_discrete_dense_choice_axes(variable_info): """ discrete_dense_choice_vars = variable_info.query( - "~is_continuous & is_dense & is_choice", + "~is_continuous & is_choice", ).index.tolist() choice_vars = set(variable_info.query("is_choice").index.tolist()) diff --git a/src/lcm/state_space.py b/src/lcm/state_space.py index f4f25a4a..86c3df78 100644 --- a/src/lcm/state_space.py +++ b/src/lcm/state_space.py @@ -47,7 +47,7 @@ def create_state_choice_space(model: InternalModel, *, is_last_period: bool): # ================================================================================== _value_grid = _create_value_grid( grids=model.grids, - subset=vi.query("is_dense & ~(is_choice & 
is_continuous)").index.tolist(), + subset=vi.query("~(is_choice & is_continuous)").index.tolist(), ) state_choice_space = Space( @@ -65,7 +65,7 @@ def create_state_choice_space(model: InternalModel, *, is_last_period: bool): # create state space info # ================================================================================== # axis_names - axis_names = vi.query("is_dense & is_state").index.tolist() + axis_names = vi.query("is_state").index.tolist() # lookup_info _discrete_states = set(vi.query("is_discrete & is_state").index.tolist()) diff --git a/tests/input_processing/test_process_model.py b/tests/input_processing/test_process_model.py index 41b3fe6c..29080b34 100644 --- a/tests/input_processing/test_process_model.py +++ b/tests/input_processing/test_process_model.py @@ -59,7 +59,6 @@ def test_get_function_info(model): got = get_function_info(model) exp = pd.DataFrame( { - "is_filter": [False], "is_constraint": [False], "is_next": [True], "is_stochastic_next": [False], @@ -79,8 +78,6 @@ def test_get_variable_info(model): "is_discrete": [True, True], "is_stochastic": [False, False], "is_auxiliary": [False, True], - "is_sparse": [False, False], - "is_dense": [True, True], }, index=["a", "c"], ) @@ -110,11 +107,6 @@ def test_process_model_iskhakov_et_al_2017(): model = process_model(model_config) # Variable Info - assert ( - model.variable_info["is_sparse"].to_numpy() - == np.array([True, True, False, False]) - ).all() - assert ( model.variable_info["is_state"].to_numpy() == np.array([True, False, True, False]) @@ -178,8 +170,6 @@ def test_process_model(): model = process_model(model_config) # Variable Info - assert ~(model.variable_info["is_sparse"].to_numpy()).any() - assert ( model.variable_info["is_state"].to_numpy() == np.array([False, True, False]) ).all() diff --git a/tests/test_discrete_problem.py b/tests/test_discrete_problem.py index 503a0925..05d0a398 100644 --- a/tests/test_discrete_problem.py +++ b/tests/test_discrete_problem.py @@ -85,6 +85,7 
@@ def test_aggregation_without_shocks(cc_values, segment_info, collapse, n_extra_a test_cases.append((scale, exp, collapse, n_axes)) +@pytest.mark.xfail(reason="Updated sparse variables") @pytest.mark.parametrize(("scale", "expected", "collapse", "n_extra_axes"), test_cases) def test_aggregation_with_extreme_value_shocks( cc_values, @@ -120,11 +121,8 @@ def _get_reshaped_cc_values_and_variable_info(cc_values, collapse, n_extra_axes) n_agg_axes = 1 if collapse else 2 names = [f"v{i}" for i in range(n_variables)] is_choice = [False, True] + [False] * n_extra_axes + [True] * n_agg_axes - is_sparse = [True, True] + [False] * (n_variables - 2) var_info = pd.DataFrame(index=names) var_info["is_choice"] = is_choice - var_info["is_sparse"] = is_sparse - var_info["is_dense"] = ~var_info["is_sparse"] var_info["is_continuous"] = False if collapse: @@ -147,10 +145,8 @@ def _get_reshaped_cc_values_and_variable_info(cc_values, collapse, n_extra_axes) def test_get_solve_discrete_problem_illustrative(): variable_info = pd.DataFrame( { - "is_sparse": [True, False, False], - "is_dense": [False, True, True], - "is_choice": [True, True, False], - "is_continuous": [False, False, False], + "is_choice": [False, True], + "is_continuous": [False, False], }, ) # leads to choice_axes = [1] @@ -307,33 +303,24 @@ def test_segment_logsumexp_illustrative(): @pytest.mark.illustrative -def test_determine_discrete_choice_axes_illustrative(): - # No discrete choice variable - # ================================================================================== - +def test_determine_discrete_choice_axes_illustrative_one_var(): variable_info = pd.DataFrame( { - "is_sparse": [True, False], - "is_dense": [False, True], - "is_choice": [True, False], - "is_discrete": [True, True], + "is_choice": [False, True], "is_continuous": [False, False], }, ) - assert _determine_dense_discrete_choice_axes(variable_info) is None + assert _determine_dense_discrete_choice_axes(variable_info) == (1,) - # One discrete 
choice variable - # ================================================================================== +@pytest.mark.illustrative +def test_determine_discrete_choice_axes_illustrative_three_var(): variable_info = pd.DataFrame( { - "is_sparse": [True, False, False, False], - "is_dense": [False, True, True, True], - "is_choice": [True, True, False, True], - "is_discrete": [True, True, True, True], + "is_choice": [False, True, True, True], "is_continuous": [False, False, False, False], }, ) - assert _determine_dense_discrete_choice_axes(variable_info) == (1, 3) + assert _determine_dense_discrete_choice_axes(variable_info) == (1, 2, 3) diff --git a/tests/test_simulate.py b/tests/test_simulate.py index faac01ff..3aede9a9 100644 --- a/tests/test_simulate.py +++ b/tests/test_simulate.py @@ -438,10 +438,9 @@ def test_determine_discrete_dense_choice_axes(): variable_info = pd.DataFrame( { "is_state": [True, True, False, True, False, False], - "is_dense": [False, True, True, False, True, True], "is_choice": [False, False, True, True, True, True], "is_continuous": [False, True, False, False, False, True], }, ) got = determine_discrete_dense_choice_axes(variable_info) - assert got == (1, 2) + assert got == (1, 2, 3) From 553959917c2035d111e939f17e7a92bc02b0ab9e Mon Sep 17 00:00:00 2001 From: Tim Mensinger Date: Mon, 3 Feb 2025 19:48:51 +0100 Subject: [PATCH 12/21] Remove nearly all mentions of sparse variables from state-choice-space creation --- src/lcm/discrete_problem.py | 20 ++++++-------------- src/lcm/entry_point.py | 5 ++++- src/lcm/state_space.py | 33 +++++++-------------------------- tests/test_entry_point.py | 8 ++++---- tests/test_model_functions.py | 2 +- tests/test_simulate.py | 2 +- 6 files changed, 23 insertions(+), 47 deletions(-) diff --git a/src/lcm/discrete_problem.py b/src/lcm/discrete_problem.py index f89dd35f..6a20009e 100644 --- a/src/lcm/discrete_problem.py +++ b/src/lcm/discrete_problem.py @@ -196,27 +196,19 @@ def _segment_logsumexp(a, segment_info): 
def _determine_dense_discrete_choice_axes( variable_info: pd.DataFrame, ) -> tuple[int, ...] | None: - """Get axes of a state choice space that correspond to dense discrete choices. - - Note: The dense choice axes determine over which axes we reduce the conditional - continuation values using a non-segmented operation. The axes ordering of the - conditional continuation value array is given by [sparse_variable, dense_variables]. - The dense continuous choice dimension is already reduced as we are working with - the conditional continuation values. + """Get axes of a state-choice-space that correspond to discrete choices. Args: - variable_info (pd.DataFrame): DataFrame with information about the variables. + variable_info: DataFrame with information about the variables. Returns: - tuple[int, ...] | None: A tuple of indices representing the axes in the value - function that correspond to discrete choices. Returns None if there are no - discrete choice axes. + tuple[int, ...] | None: A tuple of indices representing the axes positions in + the value function that correspond to discrete choices. Returns None if + there are no discrete choice axes. """ # List of dense variables excluding continuous choice variables. 
- axes = variable_info.query( - "~(is_choice & is_continuous)", - ).index.tolist() + axes = variable_info.query("~(is_choice & is_continuous)").index.tolist() choice_vars = set(variable_info.query("is_choice").index.tolist()) diff --git a/src/lcm/entry_point.py b/src/lcm/entry_point.py index 93de8a3a..b7ffd299 100644 --- a/src/lcm/entry_point.py +++ b/src/lcm/entry_point.py @@ -94,11 +94,14 @@ def get_lcm_function( # call state space creation function, append trivial items to their lists # ============================================================================== - sc_space, space_info, state_indexer, segments = create_state_choice_space( + sc_space, space_info = create_state_choice_space( model=_mod, is_last_period=is_last_period, ) + state_indexer = {} + segments = None + state_choice_spaces.append(sc_space) choice_segments.append(segments) diff --git a/src/lcm/state_space.py b/src/lcm/state_space.py index 86c3df78..73c1f838 100644 --- a/src/lcm/state_space.py +++ b/src/lcm/state_space.py @@ -3,23 +3,13 @@ from lcm.interfaces import InternalModel, Space, SpaceInfo -def create_state_choice_space(model: InternalModel, *, is_last_period: bool): - """Create a state choice space for the model. +def create_state_choice_space( + model: InternalModel, *, is_last_period: bool +) -> tuple[Space, SpaceInfo]: + """Create a state-choice-space for the model. - A state_choice_space is a compressed representation of all feasible states and the - feasible discrete choices within that state. We currently use the following - compressions: - - We distinguish between dense and sparse variables (dense_vars and sparse_vars). - Dense state or choice variables are those whose set of feasible values does not - depend on any other state or choice variables. Sparse state or choice variables are - all other state variables. 
For dense state variables it is thus enough to store the - grid of feasible values (value_grid), whereas for sparse variables all feasible - combinations (combination_grid) have to be stored. - - Note: - ----- - - We only use the filter mask, not the forward mask (yet). + A state-choice-space is a compressed representation of all feasible states and the + feasible discrete choices within that state. Args: model (Model): A processed model. @@ -30,9 +20,6 @@ def create_state_choice_space(model: InternalModel, *, is_last_period: bool): to execute a function on an entire space. SpaceInfo: A SpaceInfo object that contains all information needed to work with the output of a function evaluated on the space. - dict: Dictionary containing state indexer arrays. - jnp.ndarray: Jax array containing the choice segments needed for the emax - calculations. """ # ================================================================================== @@ -54,12 +41,6 @@ def create_state_choice_space(model: InternalModel, *, is_last_period: bool): sparse_vars={}, dense_vars=_value_grid, ) - # ================================================================================== - # create indexers and segments - # ================================================================================== - choice_segments = None - - state_indexers = {} # type: ignore[var-annotated] # ================================================================================== # create state space info @@ -85,7 +66,7 @@ def create_state_choice_space(model: InternalModel, *, is_last_period: bool): indexer_infos=indexer_infos, ) - return state_choice_space, space_info, state_indexers, choice_segments + return state_choice_space, space_info def _create_value_grid(grids, subset): diff --git a/tests/test_entry_point.py b/tests/test_entry_point.py index 2abd19ea..2ebffcbf 100644 --- a/tests/test_entry_point.py +++ b/tests/test_entry_point.py @@ -182,7 +182,7 @@ def test_create_compute_conditional_continuation_value(): 
}, } - _, space_info, _, _ = create_state_choice_space( + _, space_info = create_state_choice_space( model=model, is_last_period=False, ) @@ -228,7 +228,7 @@ def test_create_compute_conditional_continuation_value_with_discrete_model(): }, } - _, space_info, _, _ = create_state_choice_space( + _, space_info = create_state_choice_space( model=model, is_last_period=False, ) @@ -279,7 +279,7 @@ def test_create_compute_conditional_continuation_policy(): }, } - _, space_info, _, _ = create_state_choice_space( + _, space_info = create_state_choice_space( model=model, is_last_period=False, ) @@ -326,7 +326,7 @@ def test_create_compute_conditional_continuation_policy_with_discrete_model(): }, } - _, space_info, _, _ = create_state_choice_space( + _, space_info = create_state_choice_space( model=model, is_last_period=False, ) diff --git a/tests/test_model_functions.py b/tests/test_model_functions.py index fd6f854f..d5d7f420 100644 --- a/tests/test_model_functions.py +++ b/tests/test_model_functions.py @@ -56,7 +56,7 @@ def test_get_utility_and_feasibility_function(): }, } - _, space_info, _, _ = create_state_choice_space( + _, space_info = create_state_choice_space( model=model, is_last_period=False, ) diff --git a/tests/test_simulate.py b/tests/test_simulate.py index 3aede9a9..45afab9e 100644 --- a/tests/test_simulate.py +++ b/tests/test_simulate.py @@ -40,7 +40,7 @@ def simulate_inputs(): model_config = get_model_config("iskhakov_et_al_2017_stripped_down", n_periods=1) model = process_model(model_config) - _, space_info, _, _ = create_state_choice_space( + _, space_info = create_state_choice_space( model=model, is_last_period=False, ) From 1ea6367fdaee204e4d454d2872342ef3b7035bf5 Mon Sep 17 00:00:00 2001 From: Tim Mensinger Date: Mon, 3 Feb 2025 19:57:59 +0100 Subject: [PATCH 13/21] Make sparse vars optional for spacemap --- src/lcm/dispatchers.py | 31 ++++++++++++++----------------- src/lcm/solve_brute.py | 1 - tests/test_dispatchers.py | 2 +- 3 files changed, 15 
insertions(+), 19 deletions(-) diff --git a/src/lcm/dispatchers.py b/src/lcm/dispatchers.py index 13de957b..c91d99d0 100644 --- a/src/lcm/dispatchers.py +++ b/src/lcm/dispatchers.py @@ -12,7 +12,7 @@ def spacemap( func: F, dense_vars: list[str], - sparse_vars: list[str], + sparse_vars: list[str] | None = None, ) -> F: """Apply vmap such that func is evaluated on a space of dense and sparse variables. @@ -46,36 +46,33 @@ def spacemap( """ # Check inputs and prepare function # ================================================================================== - overlap = set(dense_vars).intersection(sparse_vars) - if overlap: - raise ValueError( - f"Dense and sparse variables must be disjoint. Overlap: {overlap}", - ) - duplicates = {v for v in dense_vars if dense_vars.count(v) > 1} if duplicates: raise ValueError( f"Same argument provided more than once in dense variables: {duplicates}", ) - duplicates = {v for v in sparse_vars if sparse_vars.count(v) > 1} - if duplicates: - raise ValueError( - f"Same argument provided more than once in sparse variables: {duplicates}", - ) + if sparse_vars: + overlap = set(dense_vars).intersection(sparse_vars) + if overlap: + raise ValueError( + f"Dense and sparse variables must be disjoint. 
Overlap: {overlap}", + ) + + duplicates = {v for v in sparse_vars if sparse_vars.count(v) > 1} + if duplicates: + raise ValueError( + "Same argument provided more than once in sparse variables: " + f"{duplicates}", + ) # jax.vmap cannot deal with keyword-only arguments func = allow_args(func) # Apply vmap_1d for sparse and _base_productmap for dense variables # ================================================================================== - put_dense_first = False - if not sparse_vars: vmapped = _base_productmap(func, dense_vars) - elif put_dense_first: - vmapped = vmap_1d(func, variables=sparse_vars, callable_with="only_args") - vmapped = _base_productmap(vmapped, dense_vars) else: vmapped = _base_productmap(func, dense_vars) vmapped = vmap_1d(vmapped, variables=sparse_vars, callable_with="only_args") diff --git a/src/lcm/solve_brute.py b/src/lcm/solve_brute.py index cf055f29..52e1de6b 100644 --- a/src/lcm/solve_brute.py +++ b/src/lcm/solve_brute.py @@ -113,7 +113,6 @@ def solve_continuous_problem( _gridmapped = spacemap( func=compute_ccv, dense_vars=list(state_choice_space.dense_vars), - sparse_vars=[], ) gridmapped = jax.jit(_gridmapped) diff --git a/tests/test_dispatchers.py b/tests/test_dispatchers.py index 4287dca0..291f9601 100644 --- a/tests/test_dispatchers.py +++ b/tests/test_dispatchers.py @@ -258,7 +258,7 @@ def test_spacemap_all_arguments_mapped( ) def test_spacemap_arguments_overlap(error_msg, dense_vars, sparse_vars): with pytest.raises(ValueError, match=error_msg): - spacemap(g, dense_vars, sparse_vars) + spacemap(g, dense_vars=dense_vars, sparse_vars=sparse_vars) # ====================================================================================== From 68dbc99e791702b83b0400eaf695fb0df8523b7d Mon Sep 17 00:00:00 2001 From: Tim Mensinger Date: Mon, 3 Feb 2025 20:10:01 +0100 Subject: [PATCH 14/21] Remove all filter references from code-base --- src/lcm/entry_point.py | 11 ++-- src/lcm/interfaces.py | 11 ++-- 
tests/input_processing/test_process_model.py | 9 ++-- tests/test_entry_point.py | 8 +-- tests/test_models/deterministic.py | 15 +++--- tests/test_models/discrete_deterministic.py | 14 +---- tests/test_models/stochastic.py | 5 +- tests/test_state_space.py | 57 -------------------- 8 files changed, 28 insertions(+), 102 deletions(-) diff --git a/src/lcm/entry_point.py b/src/lcm/entry_point.py index b7ffd299..30175c01 100644 --- a/src/lcm/entry_point.py +++ b/src/lcm/entry_point.py @@ -83,7 +83,7 @@ def get_lcm_function( space_infos = [] compute_ccv_functions = [] compute_ccv_policy_functions = [] - choice_segments = [] + choice_segments = [] # type: ignore[var-annotated] emax_calculators = [] # ================================================================================== @@ -99,23 +99,20 @@ def get_lcm_function( is_last_period=is_last_period, ) - state_indexer = {} - segments = None - state_choice_spaces.append(sc_space) - choice_segments.append(segments) + choice_segments.append(None) if is_last_period: state_indexers.append({}) else: - state_indexers.append(state_indexer) + state_indexers.append({}) space_infos.append(space_info) # ================================================================================== # Shift space info (in period t we require the space info of period t+1) # ================================================================================== - space_infos = space_infos[1:] + [{}] + space_infos = space_infos[1:] + [{}] # type: ignore[list-item] # ================================================================================== # Create model functions diff --git a/src/lcm/interfaces.py b/src/lcm/interfaces.py index 64d09a84..6a51951d 100644 --- a/src/lcm/interfaces.py +++ b/src/lcm/interfaces.py @@ -82,15 +82,14 @@ class InternalModel: are True if the variable has the corresponding property. The columns are: is_state, is_choice, is_continuous, is_discrete. functions: Dictionary that maps names of functions to functions. 
The functions - differ from the user functions in that they all except the filter functions - take ``params`` as keyword argument. If the original function depended on - model parameters, those are automatically extracted from ``params`` and - passed to the original function. Otherwise, the ``params`` argument is - simply ignored. + differ from the user functions in that they take ``params`` as keyword + argument. If the original function depended on model parameters, those are + automatically extracted from ``params`` and passed to the original function. + Otherwise, the ``params`` argument is simply ignored. function_info: A table with information about all functions in the model. The index contains the name of a function. The columns are booleans that are True if the function has the corresponding property. The columns are: - is_filter, is_constraint, is_next. + is_constraint, is_next. params: Dict of model parameters. n_periods: Number of periods. random_utility_shocks: Type of random utility shocks. 
diff --git a/tests/input_processing/test_process_model.py b/tests/input_processing/test_process_model.py index 29080b34..10d912a7 100644 --- a/tests/input_processing/test_process_model.py +++ b/tests/input_processing/test_process_model.py @@ -101,7 +101,6 @@ def test_get_grids(model): assert_array_equal(got["c"], jnp.array([0, 1])) -@pytest.mark.xfail(reason="Filters are replaced by constraints internally currently.") def test_process_model_iskhakov_et_al_2017(): model_config = get_model_config("iskhakov_et_al_2017", n_periods=3) model = process_model(model_config) @@ -159,7 +158,7 @@ def test_process_model_iskhakov_et_al_2017(): assert ( model.function_info["is_constraint"].to_numpy() - == np.array([False, False, False, True, False, False, False]) + == np.array([False, False, False, True, True, False, False]) ).all() assert ~model.function_info.loc["utility"].to_numpy().any() @@ -264,13 +263,13 @@ def raw_func(health, wealth): ) -def test_variable_info_with_continuous_filter_has_unique_index(): +def test_variable_info_with_continuous_constraint_has_unique_index(): model = get_model_config("iskhakov_et_al_2017", n_periods=3) - def wealth_filter(wealth): + def wealth_constraint(wealth): return wealth > 200 - model.functions["wealth_filter"] = wealth_filter + model.functions["wealth_constraint"] = wealth_constraint got = get_variable_info(model) assert got.index.is_unique diff --git a/tests/test_entry_point.py b/tests/test_entry_point.py index 2ebffcbf..e8d3a567 100644 --- a/tests/test_entry_point.py +++ b/tests/test_entry_point.py @@ -360,20 +360,20 @@ def test_create_compute_conditional_continuation_policy_with_discrete_model(): # ====================================================================================== -# Test filter with _period argument +# Test constraints with _period argument # ====================================================================================== -def test_get_lcm_function_with_period_argument_in_filter(): +def 
test_get_lcm_function_with_period_argument_in_constraint(): model = get_model_config("iskhakov_et_al_2017", n_periods=3) - def absorbing_retirement_filter(retirement, lagged_retirement, _period): + def absorbing_retirement_constraint(retirement, lagged_retirement, _period): return jnp.logical_or( retirement == RetirementStatus.retired, lagged_retirement == RetirementStatus.working, ) - model.functions["absorbing_retirement_filter"] = absorbing_retirement_filter + model.functions["absorbing_retirement_constraint"] = absorbing_retirement_constraint solve_model, params_template = get_lcm_function(model=model) params = tree_map(lambda _: 0.2, params_template) diff --git a/tests/test_models/deterministic.py b/tests/test_models/deterministic.py index bb176efa..92e56451 100644 --- a/tests/test_models/deterministic.py +++ b/tests/test_models/deterministic.py @@ -34,7 +34,7 @@ def utility(consumption, working, disutility_of_work): return jnp.log(consumption) - disutility_of_work * working -def utility_with_filter( +def utility_with_constraint( consumption, working, disutility_of_work, @@ -80,10 +80,7 @@ def consumption_constraint(consumption, wealth): return consumption <= wealth -# -------------------------------------------------------------------------------------- -# Filters -# -------------------------------------------------------------------------------------- -def absorbing_retirement_filter(retirement, lagged_retirement): +def absorbing_retirement_constraint(retirement, lagged_retirement): return jnp.logical_or( retirement == RetirementStatus.retired, lagged_retirement == RetirementStatus.working, @@ -102,11 +99,11 @@ def absorbing_retirement_filter(retirement, lagged_retirement): ), n_periods=3, functions={ - "utility": utility_with_filter, + "utility": utility_with_constraint, "next_wealth": next_wealth, "next_lagged_retirement": lambda retirement: retirement, "consumption_constraint": consumption_constraint, - "absorbing_retirement_filter": 
absorbing_retirement_filter, + "absorbing_retirement_constraint": absorbing_retirement_constraint, "labor_income": labor_income, "working": working, }, @@ -131,8 +128,8 @@ def absorbing_retirement_filter(retirement, lagged_retirement): ISKHAKOV_ET_AL_2017_STRIPPED_DOWN = Model( description=( - "Starts from Iskhakov et al. (2017), removes filters and the lagged_retirement " - "state, and adds wage function that depends on age." + "Starts from Iskhakov et al. (2017), removes absorbing retirement constraint " + "and the lagged_retirement state, and adds wage function that depends on age." ), n_periods=3, functions={ diff --git a/tests/test_models/discrete_deterministic.py b/tests/test_models/discrete_deterministic.py index bade5068..423faaa2 100644 --- a/tests/test_models/discrete_deterministic.py +++ b/tests/test_models/discrete_deterministic.py @@ -71,23 +71,13 @@ def consumption_constraint(consumption, wealth): return consumption <= wealth -# -------------------------------------------------------------------------------------- -# Filters -# -------------------------------------------------------------------------------------- -def absorbing_retirement_filter(retirement, lagged_retirement): - return jnp.logical_or( - retirement == RetirementStatus.retired, - lagged_retirement == RetirementStatus.working, - ) - - # ====================================================================================== # Model specifications # ====================================================================================== ISKHAKOV_ET_AL_2017_DISCRETE = Model( description=( - "Starts from Iskhakov et al. (2017), removes filters and the lagged_retirement " - "state, and makes the consumption decision and the wealth state discrete." + "Starts from Iskhakov et al. (2017), removes absorbing retirement constraint " + "and the lagged_retirement state, and makes the consumption decision discrete." 
), n_periods=3, functions={ diff --git a/tests/test_models/stochastic.py b/tests/test_models/stochastic.py index 3f11c872..f7e05ace 100644 --- a/tests/test_models/stochastic.py +++ b/tests/test_models/stochastic.py @@ -99,8 +99,9 @@ def consumption_constraint(consumption, wealth): ISKHAKOV_ET_AL_2017_STOCHASTIC = Model( description=( - "Starts from Iskhakov et al. (2017), removes filters and the lagged_retirement " - "state, and adds discrete stochastic state variables health and partner." + "Starts from Iskhakov et al. (2017), removes absorbing retirement constraint " + "and the lagged_retirement state, and adds discrete stochastic state variables " + "health and partner." ), n_periods=3, functions={ diff --git a/tests/test_state_space.py b/tests/test_state_space.py index 4a0c3767..cc0de209 100644 --- a/tests/test_state_space.py +++ b/tests/test_state_space.py @@ -1,9 +1,4 @@ -import jax.numpy as jnp -import pandas as pd -import pytest - from lcm.input_processing import process_model -from lcm.interfaces import InternalModel from lcm.state_space import ( create_state_choice_space, ) @@ -18,55 +13,3 @@ def test_create_state_choice_space(): model=_model, is_last_period=False, ) - - -@pytest.fixture -def filter_mask_inputs(): - def age(period): - return period + 18 - - def mandatory_retirement_filter(retirement, age): - return jnp.logical_or(retirement == 1, age < 65) - - def mandatory_lagged_retirement_filter(lagged_retirement, age): - return jnp.logical_or(lagged_retirement == 1, age < 66) - - def absorbing_retirement_filter(retirement, lagged_retirement): - return jnp.logical_or(retirement == 1, lagged_retirement == 0) - - grids = { - "lagged_retirement": jnp.array([0, 1]), - "retirement": jnp.array([0, 1]), - } - - functions = { - "mandatory_retirement_filter": mandatory_retirement_filter, - "mandatory_lagged_retirement_filter": mandatory_lagged_retirement_filter, - "absorbing_retirement_filter": absorbing_retirement_filter, - "age": age, - } - - function_info = 
pd.DataFrame( - index=functions.keys(), - columns=["is_filter"], - data=[[True], [True], [True], [False]], - ) - - # create a model instance where some attributes are set to None because they - # are not needed for create_filter_mask - return InternalModel( - grids=grids, - gridspecs=None, - variable_info=None, - functions=functions, - function_info=function_info, - params=None, - random_utility_shocks=None, - n_periods=100, - ) - - -PARAMETRIZATION = [ - (50, jnp.array([[False, False], [False, True]])), - (10, jnp.array([[True, True], [False, True]])), -] From 015a6a083e94da66ee91e740710c37280c39bf62 Mon Sep 17 00:00:00 2001 From: Tim Mensinger Date: Tue, 4 Feb 2025 09:25:06 +0100 Subject: [PATCH 15/21] Fix test_solve_brute.py --- tests/test_solve_brute.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/tests/test_solve_brute.py b/tests/test_solve_brute.py index 788d10e5..44395cca 100644 --- a/tests/test_solve_brute.py +++ b/tests/test_solve_brute.py @@ -1,6 +1,5 @@ import jax.numpy as jnp import numpy as np -import pytest from numpy.testing import assert_array_almost_equal as aaae from lcm.entry_point import create_compute_conditional_continuation_value @@ -117,11 +116,14 @@ def calculate_emax(values, params): # noqa: ARG001 assert isinstance(solution, list) -@pytest.mark.xfail(reason="Removec sparse vars segments") def test_solve_continuous_problem_no_vf_arr(): state_choice_space = Space( - dense_vars={"a": jnp.array([0, 1.0]), "b": jnp.array([2, 3.0])}, - sparse_vars={"c": jnp.array([4, 5, 6])}, + dense_vars={ + "a": jnp.array([0, 1.0]), + "b": jnp.array([2, 3.0]), + "c": jnp.array([4, 5, 6]), + }, + sparse_vars={}, ) def _utility_and_feasibility(a, c, b, d, vf_arr, params): # noqa: ARG001 @@ -137,7 +139,6 @@ def _utility_and_feasibility(a, c, b, d, vf_arr, params): # noqa: ARG001 ) expected = np.array([[[6.0, 7, 8], [7, 8, 9]], [[7, 8, 9], [8, 9, 10]]]) - expected = np.transpose(expected, axes=(2, 0, 1)) got = 
solve_continuous_problem( state_choice_space, From 10945f4f164e188ccb80b569c6d40604ee47aff0 Mon Sep 17 00:00:00 2001 From: Tim Mensinger Date: Tue, 4 Feb 2025 09:28:31 +0100 Subject: [PATCH 16/21] Fix test_create_data_state_choice_space --- tests/test_simulate.py | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/tests/test_simulate.py b/tests/test_simulate.py index 45afab9e..f5a6a130 100644 --- a/tests/test_simulate.py +++ b/tests/test_simulate.py @@ -405,7 +405,6 @@ def test_filter_ccv_policy(): assert jnp.all(got == jnp.array([0, 0])) -@pytest.mark.xfail(reason="Filters are replaced by constraints internally currently.") def test_create_data_state_choice_space(): model_config = get_model_config("iskhakov_et_al_2017", n_periods=3) model = process_model(model_config) @@ -415,14 +414,11 @@ def test_create_data_state_choice_space(): "lagged_retirement": jnp.array([0, 1]), }, model=model, - period=0, ) - assert got_space.dense_vars == {} - assert_array_equal(got_space.sparse_vars["wealth"], jnp.array([10.0, 10.0, 20.0])) - assert_array_equal(got_space.sparse_vars["lagged_retirement"], jnp.array([0, 0, 1])) - assert_array_equal(got_space.sparse_vars["retirement"], jnp.array([0, 1, 1])) - assert_array_equal(got_segment_info["segment_ids"], jnp.array([0, 0, 1])) - assert got_segment_info["num_segments"] == 2 + assert_array_equal(got_space.dense_vars["retirement"], jnp.array([0, 1])) + assert_array_equal(got_space.sparse_vars["wealth"], jnp.array([10.0, 20.0])) + assert_array_equal(got_space.sparse_vars["lagged_retirement"], jnp.array([0, 1])) + assert got_segment_info is None def test_dict_product(): From a2048a77b5d9565e760f6cdec765813e239e2d16 Mon Sep 17 00:00:00 2001 From: Tim Mensinger Date: Tue, 4 Feb 2025 09:45:53 +0100 Subject: [PATCH 17/21] Fix test_discrete_problem.py --- src/lcm/discrete_problem.py | 58 +-------- tests/test_discrete_problem.py | 212 ++------------------------------- 2 files changed, 9 insertions(+), 261 deletions(-)
diff --git a/src/lcm/discrete_problem.py b/src/lcm/discrete_problem.py index 6a20009e..9430a952 100644 --- a/src/lcm/discrete_problem.py +++ b/src/lcm/discrete_problem.py @@ -21,7 +21,6 @@ from functools import partial import jax -import jax.numpy as jnp import pandas as pd from jax import Array @@ -105,7 +104,7 @@ def _solve_discrete_problem_no_shocks( # ====================================================================================== -def _calculate_emax_extreme_value_shocks(values, choice_axes, choice_segments, params): +def _calculate_emax_extreme_value_shocks(values, choice_axes, params): """Aggregate conditional continuation values over discrete choices. Args: @@ -129,65 +128,10 @@ def _calculate_emax_extreme_value_shocks(values, choice_axes, choice_segments, p out = values if choice_axes is not None: out = scale * jax.scipy.special.logsumexp(out / scale, axis=choice_axes) - if choice_segments is not None: - out = _segment_extreme_value_emax_over_first_axis(out, scale, choice_segments) return out -def _segment_extreme_value_emax_over_first_axis(a, scale, segment_info): - """Calculate emax under iid extreme value assumption over segments of first axis. - - TODO: Explain in more detail how this function is related to EMAX under IID EV. - - Args: - a (jax.numpy.ndarray): Multidimensional jax array. - scale (float): Scale parameter of the extreme value distribution. - segment_info (dict): Dictionary with the entries "segment_ids" - and "num_segments". segment_ids are a 1d integer array that partitions the - last dimension of a. "num_segments" is the number of segments. The - segment_ids are assumed to be sorted. - - Returns: - jax.numpy.ndarray - - """ - return scale * _segment_logsumexp(a / scale, segment_info) - - -def _segment_logsumexp(a, segment_info): - """Calculate a logsumexp over segments of the first axis of a. - - We use the familiar logsumexp trick for numerical stability. 
See: - https://gregorygundersen.com/blog/2020/02/09/log-sum-exp/ for details. - - Args: - a (jax.numpy.ndarray): Multidimensional jax array. - segment_info (dict): Dictionary with the entries "segment_ids" - and "num_segments". segment_ids are a 1d integer array that partitions the - first dimension of a. "num_segments" is the number of segments. The - segment_ids are assumed to be sorted. - - Returns: - jax.numpy.ndarray - - """ - segmax = jax.ops.segment_max( - data=a, - indices_are_sorted=True, - **segment_info, - ) - - exp = jnp.exp(a - segmax[segment_info["segment_ids"]]) - - summed = jax.ops.segment_sum( - data=exp, - indices_are_sorted=True, - **segment_info, - ) - return segmax + jnp.log(summed) - - # ====================================================================================== # Auxiliary functions # ====================================================================================== diff --git a/tests/test_discrete_problem.py b/tests/test_discrete_problem.py index 05d0a398..59556ad7 100644 --- a/tests/test_discrete_problem.py +++ b/tests/test_discrete_problem.py @@ -1,141 +1,16 @@ -from functools import partial -from itertools import product - import jax.numpy as jnp import pandas as pd import pytest -from jax.ops import segment_max from numpy.testing import assert_array_almost_equal as aaae from lcm.discrete_problem import ( _calculate_emax_extreme_value_shocks, _determine_dense_discrete_choice_axes, - _segment_extreme_value_emax_over_first_axis, - _segment_logsumexp, _solve_discrete_problem_no_shocks, get_solve_discrete_problem, ) from lcm.typing import ShockType - -@pytest.fixture -def cc_values(): - """Conditional continuation values.""" - v_t = jnp.arange(20).reshape(2, 2, 5) / 2 - # reuse old test case from when segment axis was last - return jnp.transpose(v_t, axes=(2, 0, 1)) - - -@pytest.fixture -def segment_info(): - return { - "segment_ids": jnp.array([0, 0, 1, 1, 1]), - "num_segments": 2, - } - - -test_cases = list(product([True, 
False], range(3))) - -# ====================================================================================== -# Aggregation without shocks -# ====================================================================================== - - -@pytest.mark.xfail(reason="Removec choice segments") -@pytest.mark.parametrize(("collapse", "n_extra_axes"), test_cases) -def test_aggregation_without_shocks(cc_values, segment_info, collapse, n_extra_axes): - cc_values, var_info = _get_reshaped_cc_values_and_variable_info( - cc_values, - collapse, - n_extra_axes, - ) - - solve_discrete_problem = get_solve_discrete_problem( - random_utility_shock_type=ShockType.NONE, - variable_info=var_info, - is_last_period=False, - choice_segments=segment_info, - ) - - calculated = solve_discrete_problem(cc_values, params=None) - - expected = jnp.array([8, 9.5]) - - expected_shape = tuple([2] + [1] * n_extra_axes) - assert calculated.shape == expected_shape - aaae(calculated.flatten(), expected) - - -# ====================================================================================== -# Aggregation with extreme value shocks -# ====================================================================================== - -scaling_factors = [0.3, 0.6, 1, 2.5, 10] -expected_results = [ - [8.051974, 9.560844], - [8.225906, 9.800117], - [8.559682, 10.265875], - [10.595821, 12.880228], - [25.184761, 30.494621], -] -test_cases = [] -for scale, exp in zip(scaling_factors, expected_results, strict=True): - for collapse in [True, False]: - for n_axes in range(3): - test_cases.append((scale, exp, collapse, n_axes)) - - -@pytest.mark.xfail(reason="Updated sparse variables") -@pytest.mark.parametrize(("scale", "expected", "collapse", "n_extra_axes"), test_cases) -def test_aggregation_with_extreme_value_shocks( - cc_values, - segment_info, - scale, - expected, - collapse, - n_extra_axes, -): - cc_values, var_info = _get_reshaped_cc_values_and_variable_info( - cc_values, - collapse, - n_extra_axes, - ) - - 
choice_axes = _determine_dense_discrete_choice_axes(var_info) - solve_discrete_problem = partial( - _calculate_emax_extreme_value_shocks, - choice_axes=choice_axes, - choice_segments=segment_info, - params={"additive_utility_shock": {"scale": scale}}, - ) - - calculated = solve_discrete_problem(cc_values) - - expected_shape = tuple([2] + [1] * n_extra_axes) - assert calculated.shape == expected_shape - aaae(calculated.flatten(), jnp.array(expected), decimal=5) - - -def _get_reshaped_cc_values_and_variable_info(cc_values, collapse, n_extra_axes): - n_variables = cc_values.ndim + 1 + n_extra_axes - collapse - n_agg_axes = 1 if collapse else 2 - names = [f"v{i}" for i in range(n_variables)] - is_choice = [False, True] + [False] * n_extra_axes + [True] * n_agg_axes - var_info = pd.DataFrame(index=names) - var_info["is_choice"] = is_choice - var_info["is_continuous"] = False - - if collapse: - cc_values = cc_values.reshape(5, 4) - - new_shape = tuple( - [cc_values.shape[0]] + [1] * n_extra_axes + list(cc_values.shape[1:]), - ) - cc_values = cc_values.reshape(new_shape) - - return cc_values, var_info - - # ====================================================================================== # Illustrative # ====================================================================================== @@ -168,7 +43,6 @@ def test_get_solve_discrete_problem_illustrative(): aaae(got, jnp.array([1, 3, 5])) -@pytest.mark.xfail(reason="Removec choice segments") @pytest.mark.illustrative def test_solve_discrete_problem_no_shocks_illustrative(): cc_values = jnp.array( @@ -179,7 +53,7 @@ def test_solve_discrete_problem_no_shocks_illustrative(): ], ) - # Only choice axes + # Single choice axes # ================================================================================== got = _solve_discrete_problem_no_shocks( cc_values, @@ -188,23 +62,14 @@ def test_solve_discrete_problem_no_shocks_illustrative(): ) aaae(got, jnp.array([4, 5])) - # Only choice segment - # 
================================================================================== - got = _solve_discrete_problem_no_shocks( - cc_values, - choice_axes=None, - params=None, - ) - aaae(got, jnp.array([[2, 3], [4, 5]])) - - # Choice axes and choice segment + # Tuple of choice axes # ================================================================================== got = _solve_discrete_problem_no_shocks( cc_values, - choice_axes=1, + choice_axes=(0, 1), params=None, ) - aaae(got, jnp.array([3, 5])) + aaae(got, 5) @pytest.mark.illustrative @@ -217,84 +82,23 @@ def test_calculate_emax_extreme_value_shocks_illustrative(): ], ) - # Only choice axes + # Single choice axes # ================================================================================== got = _calculate_emax_extreme_value_shocks( cc_values, choice_axes=0, - choice_segments=None, params={"additive_utility_shock": {"scale": 0.1}}, ) aaae(got, jnp.array([4, 5]), decimal=5) - # Only choice segment - # ================================================================================== - got = _calculate_emax_extreme_value_shocks( - cc_values, - choice_axes=None, - choice_segments={"segment_ids": jnp.array([0, 0, 1]), "num_segments": 2}, - params={"additive_utility_shock": {"scale": 0.1}}, - ) - aaae(got, jnp.array([[2, 3], [4, 5]]), decimal=5) - - # Choice axes and choice segment + # Tuple of choice axes # ================================================================================== got = _calculate_emax_extreme_value_shocks( cc_values, - choice_axes=1, - choice_segments={"segment_ids": jnp.array([0, 0, 1]), "num_segments": 2}, + choice_axes=(0, 1), params={"additive_utility_shock": {"scale": 0.1}}, ) - aaae(got, jnp.array([3, 5]), decimal=5) - - -# ====================================================================================== -# Segment max over first axis -# ====================================================================================== - - -@pytest.mark.illustrative -def 
test_segment_max_over_first_axis_illustrative(): - a = jnp.arange(4) - segment_info = { - "segment_ids": jnp.array([0, 0, 1, 1]), - "num_segments": 2, - } - got = segment_max(a, indices_are_sorted=True, **segment_info) - expected = jnp.array([1, 3]) - aaae(got, expected) - - -@pytest.mark.illustrative -def test_segment_extreme_value_emax_over_first_axis_illustrative(): - a = jnp.array([[0, 1], [2, 3], [4, 5], [6, 7], [8, 9], [10, 11]]) - - segment_info = { - "segment_ids": jnp.array([0, 0, 0, 1, 1, 1]), - "num_segments": 2, - } - - got = _segment_extreme_value_emax_over_first_axis( - a, - scale=0.1, - segment_info=segment_info, - ) - expected = jnp.array([[4, 5], [10, 11]]) - aaae(got, expected) - - -@pytest.mark.illustrative -def test_segment_logsumexp_illustrative(): - a = jnp.array([[0, 1], [2, 3], [4, 5], [6, 7], [8, 9], [10, 11]]) - - segment_info = { - "segment_ids": jnp.array([0, 0, 0, 1, 1, 1]), - "num_segments": 2, - } - - got = _segment_logsumexp(a, segment_info) - expected = jnp.array([[4, 5], [10, 11]]) - aaae(got, expected, decimal=0) + aaae(got, 5, decimal=5) # ====================================================================================== From cbf56a19f8416d5759541b77355de4922f7dbb4b Mon Sep 17 00:00:00 2001 From: Tim Mensinger Date: Tue, 4 Feb 2025 09:52:12 +0100 Subject: [PATCH 18/21] Update function_representation explanation notebook --- explanations/function_representation.ipynb | 73 +++++++--------------- 1 file changed, 21 insertions(+), 52 deletions(-) diff --git a/explanations/function_representation.ipynb b/explanations/function_representation.ipynb index 0935ae64..8a39cefa 100644 --- a/explanations/function_representation.ipynb +++ b/explanations/function_representation.ipynb @@ -64,7 +64,7 @@ "### Example\n", "\n", "As an example, consider a stripped-down version of the deterministic model from\n", - "Iskhakov et al. (2017), which removes the absorbing retirement filter and the lagged\n", + "Iskhakov et al. 
(2017), which removes the absorbing retirement constraint and the lagged\n", "retirement state compared to the original model (this version can be found in the\n", "`tests/test_models/deterministic.py` module). Here we also use a coarser grid to\n", "showcase the behavior of the function representation." @@ -119,8 +119,9 @@ "\n", "model = Model(\n", " description=(\n", - " \"Starts from Iskhakov et al. (2017), removes filters and the lagged_retirement \"\n", - " \"state, and adds a wage function that depends on age.\"\n", + " \"Starts from Iskhakov et al. (2017), removes the absorbing retirement \"\n", + " \"constraint and the lagged_retirement state, and adds a wage function that \"\n", + " \"depends on age.\"\n", " ),\n", " n_periods=2,\n", " functions={\n", @@ -368,9 +369,9 @@ { "data": { "text/html": [ - "