Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove filters #101

Merged
merged 23 commits into from
Feb 5, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
c06c519
Replace filters by constraints internally
timmens Feb 3, 2025
92eb813
Quick-fix dispatchers explanation notebook
timmens Feb 3, 2025
14c62c7
First iteration: Remove sparse variables from discrete problem and ty…
timmens Feb 3, 2025
1b8f8e5
First iteration: Remove sparse variables from solve brute
timmens Feb 3, 2025
2570ace
First iteration: Remove sparse variables from function representation
timmens Feb 3, 2025
c96bae7
First iteration: Remove segment argmax
timmens Feb 3, 2025
3ec4367
Remove sparse variables from state choice space
timmens Feb 3, 2025
e39c7c5
Remove sparse variables from solution state-choice-space
timmens Feb 3, 2025
353136d
Remove sparse variables from state choice space
timmens Feb 3, 2025
dc951e2
First iteration: Remove filter and sparse vars from process-model
timmens Feb 3, 2025
609bc04
Remove is_sparse and is_dense entries from variable info
timmens Feb 3, 2025
5539599
Remove nearly all mentions of sparse variables from state-choice-spac…
timmens Feb 3, 2025
1ea6367
Make sparse vars optional for spacemap
timmens Feb 3, 2025
68dbc99
Remove all filter references from code-base
timmens Feb 3, 2025
015a6a0
Fix test_solve_brute.py
timmens Feb 4, 2025
10945f4
Fix test_create_data_state_choice_sapce
timmens Feb 4, 2025
a2048a7
Fix test_discrete_problem.py
timmens Feb 4, 2025
cbf56a1
Update function_representation explanation notebook
timmens Feb 4, 2025
a8fd9fa
Fix dispatchers explanation notebook; but do not rewrite
timmens Feb 4, 2025
010c339
Merge branch 'main' into remove-filter
timmens Feb 4, 2025
d8d3996
Incorporate comments from review
timmens Feb 4, 2025
517ac13
Merge branch 'remove-filter' of https://github.com/OpenSourceEconomic…
timmens Feb 4, 2025
0ab9a0b
Re-use test case from state-choice-space in testing of get_combined_c…
timmens Feb 5, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
158 changes: 76 additions & 82 deletions explanations/dispatchers.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,13 @@
"`spacemap` are used by `lcm`."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**WARNING:** This notebook is outdated."
]
},
{
"cell_type": "code",
"execution_count": null,
Expand All @@ -31,7 +38,7 @@
"source": [
"# `vmap_1d`\n",
"\n",
"Let's start by vectorizing the function `f` over axis `a` using Jax' `vmap` function."
"Let's start by vectorizing the function `f` over axis `a` using JAX's `vmap` function."
]
},
{
Expand All @@ -49,6 +56,13 @@
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.\n"
]
},
{
"data": {
"text/plain": [
Expand Down Expand Up @@ -251,7 +265,7 @@
"\n",
"If the valid values of a variable in a state-choice space depend on another variable, that variable is termed a _sparse_ variable; otherwise, it is a _dense_ variable. To dispatch a function across an entire state-choice space, we must vectorize over both dense and sparse variables. Since, by definition, all values of dense variables are valid, we can simply perform a `productmap` over the Cartesian grid of their values. The valid combinations of sparse variables are stored as a collection of 1D arrays (see below for an example). For these, we can perform a call to `vmap_1d`.\n",
"\n",
"Consider a simplified version of our deterministic test model. Curly brackets {} denote discrete variables; square brackets [] represent continuous variables.\n",
"Consider a simplified version of our deterministic test model. Curly brackets {...} denote discrete variables; square brackets [...] represent continuous variables.\n",
"\n",
"- **Choice variables:**\n",
"\n",
Expand All @@ -265,8 +279,8 @@
"\n",
" - _wealth_ $\\in [1, 2, 3, 4]$\n",
"\n",
"- **Filter:**\n",
" - Absorbing retirement filter: If _lagged_retirement_ is 1, then the choice\n",
"- **Constraints:**\n",
" - Absorbing retirement constraint: If _lagged_retirement_ is 1, then the choice\n",
" _retirement_ can never be 0."
]
},
Expand All @@ -291,7 +305,7 @@
" return jnp.log(consumption) - 0.5 * working + retirement_habit\n",
"\n",
"\n",
"def absorbing_retirement_filter(retirement, lagged_retirement):\n",
"def absorbing_retirement_constraint(retirement, lagged_retirement):\n",
" return jnp.logical_or(retirement == 1, lagged_retirement == 0)\n",
"\n",
"\n",
Expand All @@ -300,7 +314,7 @@
" \"utility\": utility,\n",
" \"next_lagged_retirement\": lambda retirement: retirement,\n",
" \"next_wealth\": lambda wealth, consumption: wealth - consumption,\n",
" \"absorbing_retirement_filter\": absorbing_retirement_filter,\n",
" \"absorbing_retirement_constraint\": absorbing_retirement_constraint,\n",
" },\n",
" n_periods=1,\n",
" choices={\n",
Expand All @@ -325,11 +339,9 @@
"\n",
"processed_model = process_model(model)\n",
"\n",
"sc_space, space_info, state_indexer, segments = create_state_choice_space(\n",
"sc_space, space_info = create_state_choice_space(\n",
" processed_model,\n",
" period=2,\n",
" is_last_period=False,\n",
" jit_filter=False,\n",
")"
]
},
Expand All @@ -350,7 +362,9 @@
{
"data": {
"text/plain": [
"{'wealth': Array([1., 2., 3., 4.], dtype=float32)}"
"{'lagged_retirement': Array([0, 1], dtype=int32),\n",
" 'retirement': Array([0, 1], dtype=int32),\n",
" 'wealth': Array([1., 2., 3., 4.], dtype=float32)}"
]
},
"execution_count": null,
Expand All @@ -370,8 +384,7 @@
{
"data": {
"text/plain": [
"{'lagged_retirement': Array([0, 0, 1], dtype=int32),\n",
" 'retirement': Array([0, 1, 1], dtype=int32)}"
"{}"
]
},
"execution_count": null,
Expand Down Expand Up @@ -409,35 +422,17 @@
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>lagged_retirement</th>\n",
" <th>retirement</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" lagged_retirement retirement\n",
"0 0 0\n",
"1 0 1\n",
"2 1 1"
"Empty DataFrame\n",
"Columns: []\n",
"Index: []"
]
},
"execution_count": null,
Expand Down Expand Up @@ -496,34 +491,6 @@
"---"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"It is also worth noting the connection between the sparse variable representation\n",
"and the `segments`. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'segment_ids': Array([0, 0, 1], dtype=int32), 'num_segments': 2}"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"segments"
]
},
{
"cell_type": "markdown",
"metadata": {},
Expand Down Expand Up @@ -556,7 +523,6 @@
" func=utility,\n",
" dense_vars=list(sc_space.dense_vars),\n",
" sparse_vars=list(sc_space.sparse_vars),\n",
" put_dense_first=False,\n",
")"
]
},
Expand All @@ -568,7 +534,9 @@
{
"data": {
"text/plain": [
"{'wealth': Array([1., 2., 3., 4.], dtype=float32)}"
"{'lagged_retirement': Array([0, 1], dtype=int32),\n",
" 'retirement': Array([0, 1], dtype=int32),\n",
" 'wealth': Array([1., 2., 3., 4.], dtype=float32)}"
]
},
"execution_count": null,
Expand All @@ -588,8 +556,7 @@
{
"data": {
"text/plain": [
"{'lagged_retirement': Array([0, 0, 1], dtype=int32),\n",
" 'retirement': Array([0, 1, 1], dtype=int32)}"
"{}"
]
},
"execution_count": null,
Expand All @@ -609,9 +576,11 @@
{
"data": {
"text/plain": [
"Array([[-0.5, -0.5, -0.5, -0.5],\n",
" [ 0. , 0. , 0. , 0. ],\n",
" [ 1. , 2. , 3. , 4. ]], dtype=float32)"
"Array([[[-0.5, -0.5, -0.5, -0.5],\n",
" [ 0. , 0. , 0. , 0. ]],\n",
"\n",
" [[ 0.5, 1.5, 2.5, 3.5],\n",
" [ 1. , 2. , 3. , 4. ]]], dtype=float32)"
]
},
"execution_count": null,
Expand All @@ -636,7 +605,7 @@
{
"data": {
"text/plain": [
"(3, 4)"
"(2, 2, 4)"
]
},
"execution_count": null,
Expand All @@ -655,6 +624,26 @@
"Let's try to get this result via looping over the grids and calling `utility` directly"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Space(sparse_vars={}, dense_vars={'lagged_retirement': Array([0, 1], dtype=int32), 'retirement': Array([0, 1], dtype=int32), 'wealth': Array([1., 2., 3., 4.], dtype=float32)})"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"sc_space"
]
},
{
"cell_type": "code",
"execution_count": null,
Expand All @@ -664,8 +653,8 @@
"data": {
"text/plain": [
"Array([[-0.5, -0.5, -0.5, -0.5],\n",
" [ 0. , 0. , 0. , 0. ],\n",
" [ 1. , 2. , 3. , 4. ]], dtype=float32)"
" [ 1. , 2. , 3. , 4. ],\n",
" [ 0. , 0. , 0. , 0. ]], dtype=float32)"
]
},
"execution_count": null,
Expand All @@ -679,8 +668,8 @@
"# loop over valid combinations of sparse variables (first axis)\n",
"for i, (lagged_retirement, retirement) in enumerate(\n",
" zip(\n",
" sc_space.sparse_vars[\"lagged_retirement\"],\n",
" sc_space.sparse_vars[\"retirement\"],\n",
" sc_space.dense_vars[\"lagged_retirement\"],\n",
" sc_space.dense_vars[\"retirement\"],\n",
" strict=False,\n",
" ),\n",
"):\n",
Expand Down Expand Up @@ -749,13 +738,18 @@
{
"data": {
"text/plain": [
"Array([[[-0.5 , -0.5 , -0.5 , -0.5 ],\n",
" [ 0. , 0. , 0. , 0. ],\n",
" [ 1. , 2. , 3. , 4. ]],\n",
"Array([[[[-0.5 , -0.5 , -0.5 , -0.5 ],\n",
" [ 0. , 0. , 0. , 0. ]],\n",
"\n",
" [[ 0.5 , 1.5 , 2.5 , 3.5 ],\n",
" [ 1. , 2. , 3. , 4. ]]],\n",
"\n",
" [[ 5.4914646, 5.4914646, 5.4914646, 5.4914646],\n",
" [ 5.9914646, 5.9914646, 5.9914646, 5.9914646],\n",
" [ 6.9914646, 7.9914646, 8.991465 , 9.991465 ]]], dtype=float32)"
"\n",
" [[[ 5.4914646, 5.4914646, 5.4914646, 5.4914646],\n",
" [ 5.9914646, 5.9914646, 5.9914646, 5.9914646]],\n",
"\n",
" [[ 6.4914646, 7.4914646, 8.491465 , 9.491465 ],\n",
" [ 6.9914646, 7.9914646, 8.991465 , 9.991465 ]]]], dtype=float32)"
]
},
"execution_count": null,
Expand All @@ -780,7 +774,7 @@
{
"data": {
"text/plain": [
"(2, 3, 4)"
"(2, 2, 2, 4)"
]
},
"execution_count": null,
Expand All @@ -795,7 +789,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "lcm",
"display_name": "test-cpu",
"language": "python",
"name": "python3"
},
Expand All @@ -809,7 +803,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.3"
"version": "3.12.0"
}
},
"nbformat": 4,
Expand Down
Loading
Loading