From 6d788ad2487a0d43bfb3b73cb4ab14c561cbeb2e Mon Sep 17 00:00:00 2001 From: Tim Mensinger Date: Wed, 5 Feb 2025 10:25:49 +0100 Subject: [PATCH] Remove filters (#101) --- explanations/dispatchers.ipynb | 158 +++++---- explanations/function_representation.ipynb | 73 ++-- src/lcm/argmax.py | 72 ---- src/lcm/discrete_problem.py | 105 +----- src/lcm/dispatchers.py | 31 +- src/lcm/entry_point.py | 13 +- src/lcm/function_representation.py | 4 +- src/lcm/input_processing/process_model.py | 19 +- src/lcm/input_processing/util.py | 30 +- src/lcm/interfaces.py | 16 +- src/lcm/simulate.py | 122 +------ src/lcm/solve_brute.py | 5 +- src/lcm/state_space.py | 336 +------------------ src/lcm/typing.py | 18 +- tests/input_processing/test_process_model.py | 18 +- tests/test_argmax.py | 46 +-- tests/test_discrete_problem.py | 247 ++------------ tests/test_dispatchers.py | 10 +- tests/test_entry_point.py | 24 +- tests/test_model_functions.py | 140 ++++++-- tests/test_models/deterministic.py | 15 +- tests/test_models/discrete_deterministic.py | 14 +- tests/test_models/stochastic.py | 5 +- tests/test_simulate.py | 39 +-- tests/test_solve_brute.py | 9 +- tests/test_state_space.py | 303 ----------------- 26 files changed, 338 insertions(+), 1534 deletions(-) diff --git a/explanations/dispatchers.ipynb b/explanations/dispatchers.ipynb index cab40c16..2210d0f2 100644 --- a/explanations/dispatchers.ipynb +++ b/explanations/dispatchers.ipynb @@ -10,6 +10,13 @@ "`spacemap` are used by `lcm`." ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "**WARNING:** This notebook is outdated." + ] + }, { "cell_type": "code", "execution_count": null, @@ -31,7 +38,7 @@ "source": [ "# `vmap_1d`\n", "\n", - "Let's start by vectorizing the function `f` over axis `a` using Jax' `vmap` function." + "Let's start by vectorizing the function `f` over axis `a` using JAX's `vmap` function." ] }, { @@ -49,6 +56,13 @@ "execution_count": null, "metadata": {}, "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.\n" + ] + }, { "data": { "text/plain": [ @@ -251,7 +265,7 @@ "\n", "If the valid values of a variable in a state-choice space depend on another variable, that variable is termed a _sparse_ variable; otherwise, it is a _dense_ variable. To dispatch a function across an entire state-choice space, we must vectorize over both dense and sparse variables. Since, by definition, all values of dense variables are valid, we can simply perform a `productmap` over the Cartesian grid of their values. The valid combinations of sparse variables are stored as a collection of 1D arrays (see below for an example). For these, we can perform a call to `vmap_1d`.\n", "\n", - "Consider a simplified version of our deterministic test model. Curly brackets {} denote discrete variables; square brackets [] represent continuous variables.\n", + "Consider a simplified version of our deterministic test model. Curly brackets {...} denote discrete variables; square brackets [...] represent continuous variables.\n", "\n", "- **Choice variables:**\n", "\n", @@ -265,8 +279,8 @@ "\n", " - _wealth_ $\\in [1, 2, 3, 4]$\n", "\n", - "- **Filter:**\n", - " - Absorbing retirement filter: If _lagged_retirement_ is 1, then the choice\n", + "- **Constraints:**\n", + " - Absorbing retirement constraint: If _lagged_retirement_ is 1, then the choice\n", " _retirement_ can never be 0." ] }, @@ -291,7 +305,7 @@ " return jnp.log(consumption) - 0.5 * working + retirement_habit\n", "\n", "\n", - "def absorbing_retirement_filter(retirement, lagged_retirement):\n", + "def absorbing_retirement_constraint(retirement, lagged_retirement):\n", " return jnp.logical_or(retirement == 1, lagged_retirement == 0)\n", "\n", "\n", @@ -300,7 +314,7 @@ " \"utility\": utility,\n", " \"next_lagged_retirement\": lambda retirement: retirement,\n", " \"next_wealth\": lambda wealth, consumption: wealth - consumption,\n", - " \"absorbing_retirement_filter\": absorbing_retirement_filter,\n", + " \"absorbing_retirement_constraint\": absorbing_retirement_constraint,\n", " },\n", " n_periods=1,\n", " choices={\n", @@ -325,11 +339,9 @@ "\n", "processed_model = process_model(model)\n", "\n", - "sc_space, space_info, state_indexer, segments = create_state_choice_space(\n", + "sc_space, space_info = create_state_choice_space(\n", " processed_model,\n", - " period=2,\n", " is_last_period=False,\n", - " jit_filter=False,\n", ")" ] }, @@ -350,7 +362,9 @@ { "data": { "text/plain": [ - "{'wealth': Array([1., 2., 3., 4.], dtype=float32)}" + "{'lagged_retirement': Array([0, 1], dtype=int32),\n", + " 'retirement': Array([0, 1], dtype=int32),\n", + " 'wealth': Array([1., 2., 3., 4.], dtype=float32)}" ] }, "execution_count": null, @@ -370,8 +384,7 @@ { "data": { "text/plain": [ - "{'lagged_retirement': Array([0, 0, 1], dtype=int32),\n", - " 'retirement': Array([0, 1, 1], dtype=int32)}" + "{}" ] }, "execution_count": null, @@ -409,35 +422,17 @@ " \n", " \n", " \n", - " lagged_retirement\n", - " retirement\n", " \n", " \n", " \n", - " \n", - " 0\n", - " 0\n", - " 0\n", - " \n", - " \n", - " 1\n", - " 0\n", - " 1\n", - " \n", - " \n", - " 2\n", - " 1\n", - " 1\n", - " \n", " \n", "\n", "" ], "text/plain": [ - " lagged_retirement retirement\n", - "0 0 0\n", - "1 0 1\n", - "2 1 1" + "Empty DataFrame\n", + "Columns: []\n", + "Index: []" ] }, "execution_count": null, @@ -496,34 +491,6 @@ "---" ] }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "It is also worth noting the connection between the sparse variable representation\n", - "and the `segments`. " - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "{'segment_ids': Array([0, 0, 1], dtype=int32), 'num_segments': 2}" - ] - }, - "execution_count": null, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "segments" - ] - }, { "cell_type": "markdown", "metadata": {}, @@ -556,7 +523,6 @@ " func=utility,\n", " dense_vars=list(sc_space.dense_vars),\n", " sparse_vars=list(sc_space.sparse_vars),\n", - " put_dense_first=False,\n", ")" ] }, @@ -568,7 +534,9 @@ { "data": { "text/plain": [ - "{'wealth': Array([1., 2., 3., 4.], dtype=float32)}" + "{'lagged_retirement': Array([0, 1], dtype=int32),\n", + " 'retirement': Array([0, 1], dtype=int32),\n", + " 'wealth': Array([1., 2., 3., 4.], dtype=float32)}" ] }, "execution_count": null, @@ -588,8 +556,7 @@ { "data": { "text/plain": [ - "{'lagged_retirement': Array([0, 0, 1], dtype=int32),\n", - " 'retirement': Array([0, 1, 1], dtype=int32)}" + "{}" ] }, "execution_count": null, @@ -609,9 +576,11 @@ { "data": { "text/plain": [ - "Array([[-0.5, -0.5, -0.5, -0.5],\n", - " [ 0. , 0. , 0. , 0. ],\n", - " [ 1. , 2. , 3. , 4. ]], dtype=float32)" + "Array([[[-0.5, -0.5, -0.5, -0.5],\n", + " [ 0. , 0. , 0. , 0. ]],\n", + "\n", + " [[ 0.5, 1.5, 2.5, 3.5],\n", + " [ 1. , 2. , 3. , 4. ]]], dtype=float32)" ] }, "execution_count": null, @@ -636,7 +605,7 @@ { "data": { "text/plain": [ - "(3, 4)" + "(2, 2, 4)" ] }, "execution_count": null, @@ -655,6 +624,26 @@ "Let's try to get this result via looping over the grids and calling `utility` directly" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Space(sparse_vars={}, dense_vars={'lagged_retirement': Array([0, 1], dtype=int32), 'retirement': Array([0, 1], dtype=int32), 'wealth': Array([1., 2., 3., 4.], dtype=float32)})" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "sc_space" + ] + }, { "cell_type": "code", "execution_count": null, @@ -664,8 +653,8 @@ "data": { "text/plain": [ "Array([[-0.5, -0.5, -0.5, -0.5],\n", - " [ 0. , 0. , 0. , 0. ],\n", - " [ 1. , 2. , 3. , 4. ]], dtype=float32)" + " [ 1. , 2. , 3. , 4. ],\n", + " [ 0. , 0. , 0. , 0. ]], dtype=float32)" ] }, "execution_count": null, @@ -679,8 +668,8 @@ "# loop over valid combinations of sparse variables (first axis)\n", "for i, (lagged_retirement, retirement) in enumerate(\n", " zip(\n", - " sc_space.sparse_vars[\"lagged_retirement\"],\n", - " sc_space.sparse_vars[\"retirement\"],\n", + " sc_space.dense_vars[\"lagged_retirement\"],\n", + " sc_space.dense_vars[\"retirement\"],\n", " strict=False,\n", " ),\n", "):\n", @@ -749,13 +738,18 @@ { "data": { "text/plain": [ - "Array([[[-0.5 , -0.5 , -0.5 , -0.5 ],\n", - " [ 0. , 0. , 0. , 0. ],\n", - " [ 1. , 2. , 3. , 4. ]],\n", + "Array([[[[-0.5 , -0.5 , -0.5 , -0.5 ],\n", + " [ 0. , 0. , 0. , 0. ]],\n", + "\n", + " [[ 0.5 , 1.5 , 2.5 , 3.5 ],\n", + " [ 1. , 2. , 3. , 4. ]]],\n", "\n", - " [[ 5.4914646, 5.4914646, 5.4914646, 5.4914646],\n", - " [ 5.9914646, 5.9914646, 5.9914646, 5.9914646],\n", - " [ 6.9914646, 7.9914646, 8.991465 , 9.991465 ]]], dtype=float32)" + "\n", + " [[[ 5.4914646, 5.4914646, 5.4914646, 5.4914646],\n", + " [ 5.9914646, 5.9914646, 5.9914646, 5.9914646]],\n", + "\n", + " [[ 6.4914646, 7.4914646, 8.491465 , 9.491465 ],\n", + " [ 6.9914646, 7.9914646, 8.991465 , 9.991465 ]]]], dtype=float32)" ] }, "execution_count": null, @@ -780,7 +774,7 @@ { "data": { "text/plain": [ - "(2, 3, 4)" + "(2, 2, 2, 4)" ] }, "execution_count": null, @@ -795,7 +789,7 @@ ], "metadata": { "kernelspec": { - "display_name": "lcm", + "display_name": "test-cpu", "language": "python", "name": "python3" }, @@ -809,7 +803,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.12.3" + "version": "3.12.0" } }, "nbformat": 4, diff --git a/explanations/function_representation.ipynb b/explanations/function_representation.ipynb index 0935ae64..8a39cefa 100644 --- a/explanations/function_representation.ipynb +++ b/explanations/function_representation.ipynb @@ -64,7 +64,7 @@ "### Example\n", "\n", "As an example, consider a stripped-down version of the deterministic model from\n", - "Iskhakov et al. (2017), which removes the absorbing retirement filter and the lagged\n", + "Iskhakov et al. (2017), which removes the absorbing retirement constraint and the lagged\n", "retirement state compared to the original model (this version can be found in the\n", "`tests/test_models/deterministic.py` module). Here we also use a coarser grid to\n", "showcase the behavior of the function representation." @@ -119,8 +119,9 @@ "\n", "model = Model(\n", " description=(\n", - " \"Starts from Iskhakov et al. (2017), removes filters and the lagged_retirement \"\n", - " \"state, and adds a wage function that depends on age.\"\n", + " \"Starts from Iskhakov et al. (2017), removes the absorbing retirement \"\n", + " \"constraint and the lagged_retirement state, and adds a wage function that \"\n", + " \"depends on age.\"\n", " ),\n", " n_periods=2,\n", " functions={\n", @@ -368,9 +369,9 @@ { "data": { "text/html": [ - "