Skip to content

Commit

Permalink
Remove filters (#101)
Browse files Browse the repository at this point in the history
  • Loading branch information
timmens authored Feb 5, 2025
1 parent 0e8b073 commit 6d788ad
Show file tree
Hide file tree
Showing 26 changed files with 338 additions and 1,534 deletions.
158 changes: 76 additions & 82 deletions explanations/dispatchers.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,13 @@
"`spacemap` are used by `lcm`."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**WARNING:** This notebook is outdated."
]
},
{
"cell_type": "code",
"execution_count": null,
Expand All @@ -31,7 +38,7 @@
"source": [
"# `vmap_1d`\n",
"\n",
"Let's start by vectorizing the function `f` over axis `a` using Jax' `vmap` function."
"Let's start by vectorizing the function `f` over axis `a` using JAX's `vmap` function."
]
},
{
Expand All @@ -49,6 +56,13 @@
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.\n"
]
},
{
"data": {
"text/plain": [
Expand Down Expand Up @@ -251,7 +265,7 @@
"\n",
"If the valid values of a variable in a state-choice space depend on another variable, that variable is termed a _sparse_ variable; otherwise, it is a _dense_ variable. To dispatch a function across an entire state-choice space, we must vectorize over both dense and sparse variables. Since, by definition, all values of dense variables are valid, we can simply perform a `productmap` over the Cartesian grid of their values. The valid combinations of sparse variables are stored as a collection of 1D arrays (see below for an example). For these, we can perform a call to `vmap_1d`.\n",
"\n",
"Consider a simplified version of our deterministic test model. Curly brackets {} denote discrete variables; square brackets [] represent continuous variables.\n",
"Consider a simplified version of our deterministic test model. Curly brackets {...} denote discrete variables; square brackets [...] represent continuous variables.\n",
"\n",
"- **Choice variables:**\n",
"\n",
Expand All @@ -265,8 +279,8 @@
"\n",
" - _wealth_ $\\in [1, 2, 3, 4]$\n",
"\n",
"- **Filter:**\n",
" - Absorbing retirement filter: If _lagged_retirement_ is 1, then the choice\n",
"- **Constraints:**\n",
" - Absorbing retirement constraint: If _lagged_retirement_ is 1, then the choice\n",
" _retirement_ can never be 0."
]
},
Expand All @@ -291,7 +305,7 @@
" return jnp.log(consumption) - 0.5 * working + retirement_habit\n",
"\n",
"\n",
"def absorbing_retirement_filter(retirement, lagged_retirement):\n",
"def absorbing_retirement_constraint(retirement, lagged_retirement):\n",
" return jnp.logical_or(retirement == 1, lagged_retirement == 0)\n",
"\n",
"\n",
Expand All @@ -300,7 +314,7 @@
" \"utility\": utility,\n",
" \"next_lagged_retirement\": lambda retirement: retirement,\n",
" \"next_wealth\": lambda wealth, consumption: wealth - consumption,\n",
" \"absorbing_retirement_filter\": absorbing_retirement_filter,\n",
" \"absorbing_retirement_constraint\": absorbing_retirement_constraint,\n",
" },\n",
" n_periods=1,\n",
" choices={\n",
Expand All @@ -325,11 +339,9 @@
"\n",
"processed_model = process_model(model)\n",
"\n",
"sc_space, space_info, state_indexer, segments = create_state_choice_space(\n",
"sc_space, space_info = create_state_choice_space(\n",
" processed_model,\n",
" period=2,\n",
" is_last_period=False,\n",
" jit_filter=False,\n",
")"
]
},
Expand All @@ -350,7 +362,9 @@
{
"data": {
"text/plain": [
"{'wealth': Array([1., 2., 3., 4.], dtype=float32)}"
"{'lagged_retirement': Array([0, 1], dtype=int32),\n",
" 'retirement': Array([0, 1], dtype=int32),\n",
" 'wealth': Array([1., 2., 3., 4.], dtype=float32)}"
]
},
"execution_count": null,
Expand All @@ -370,8 +384,7 @@
{
"data": {
"text/plain": [
"{'lagged_retirement': Array([0, 0, 1], dtype=int32),\n",
" 'retirement': Array([0, 1, 1], dtype=int32)}"
"{}"
]
},
"execution_count": null,
Expand Down Expand Up @@ -409,35 +422,17 @@
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>lagged_retirement</th>\n",
" <th>retirement</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" lagged_retirement retirement\n",
"0 0 0\n",
"1 0 1\n",
"2 1 1"
"Empty DataFrame\n",
"Columns: []\n",
"Index: []"
]
},
"execution_count": null,
Expand Down Expand Up @@ -496,34 +491,6 @@
"---"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"It is also worth noting the connection between the sparse variable representation\n",
"and the `segments`. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'segment_ids': Array([0, 0, 1], dtype=int32), 'num_segments': 2}"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"segments"
]
},
{
"cell_type": "markdown",
"metadata": {},
Expand Down Expand Up @@ -556,7 +523,6 @@
" func=utility,\n",
" dense_vars=list(sc_space.dense_vars),\n",
" sparse_vars=list(sc_space.sparse_vars),\n",
" put_dense_first=False,\n",
")"
]
},
Expand All @@ -568,7 +534,9 @@
{
"data": {
"text/plain": [
"{'wealth': Array([1., 2., 3., 4.], dtype=float32)}"
"{'lagged_retirement': Array([0, 1], dtype=int32),\n",
" 'retirement': Array([0, 1], dtype=int32),\n",
" 'wealth': Array([1., 2., 3., 4.], dtype=float32)}"
]
},
"execution_count": null,
Expand All @@ -588,8 +556,7 @@
{
"data": {
"text/plain": [
"{'lagged_retirement': Array([0, 0, 1], dtype=int32),\n",
" 'retirement': Array([0, 1, 1], dtype=int32)}"
"{}"
]
},
"execution_count": null,
Expand All @@ -609,9 +576,11 @@
{
"data": {
"text/plain": [
"Array([[-0.5, -0.5, -0.5, -0.5],\n",
" [ 0. , 0. , 0. , 0. ],\n",
" [ 1. , 2. , 3. , 4. ]], dtype=float32)"
"Array([[[-0.5, -0.5, -0.5, -0.5],\n",
" [ 0. , 0. , 0. , 0. ]],\n",
"\n",
" [[ 0.5, 1.5, 2.5, 3.5],\n",
" [ 1. , 2. , 3. , 4. ]]], dtype=float32)"
]
},
"execution_count": null,
Expand All @@ -636,7 +605,7 @@
{
"data": {
"text/plain": [
"(3, 4)"
"(2, 2, 4)"
]
},
"execution_count": null,
Expand All @@ -655,6 +624,26 @@
"Let's try to get this result via looping over the grids and calling `utility` directly"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Space(sparse_vars={}, dense_vars={'lagged_retirement': Array([0, 1], dtype=int32), 'retirement': Array([0, 1], dtype=int32), 'wealth': Array([1., 2., 3., 4.], dtype=float32)})"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"sc_space"
]
},
{
"cell_type": "code",
"execution_count": null,
Expand All @@ -664,8 +653,8 @@
"data": {
"text/plain": [
"Array([[-0.5, -0.5, -0.5, -0.5],\n",
" [ 0. , 0. , 0. , 0. ],\n",
" [ 1. , 2. , 3. , 4. ]], dtype=float32)"
" [ 1. , 2. , 3. , 4. ],\n",
" [ 0. , 0. , 0. , 0. ]], dtype=float32)"
]
},
"execution_count": null,
Expand All @@ -679,8 +668,8 @@
"# loop over valid combinations of sparse variables (first axis)\n",
"for i, (lagged_retirement, retirement) in enumerate(\n",
" zip(\n",
" sc_space.sparse_vars[\"lagged_retirement\"],\n",
" sc_space.sparse_vars[\"retirement\"],\n",
" sc_space.dense_vars[\"lagged_retirement\"],\n",
" sc_space.dense_vars[\"retirement\"],\n",
" strict=False,\n",
" ),\n",
"):\n",
Expand Down Expand Up @@ -749,13 +738,18 @@
{
"data": {
"text/plain": [
"Array([[[-0.5 , -0.5 , -0.5 , -0.5 ],\n",
" [ 0. , 0. , 0. , 0. ],\n",
" [ 1. , 2. , 3. , 4. ]],\n",
"Array([[[[-0.5 , -0.5 , -0.5 , -0.5 ],\n",
" [ 0. , 0. , 0. , 0. ]],\n",
"\n",
" [[ 0.5 , 1.5 , 2.5 , 3.5 ],\n",
" [ 1. , 2. , 3. , 4. ]]],\n",
"\n",
" [[ 5.4914646, 5.4914646, 5.4914646, 5.4914646],\n",
" [ 5.9914646, 5.9914646, 5.9914646, 5.9914646],\n",
" [ 6.9914646, 7.9914646, 8.991465 , 9.991465 ]]], dtype=float32)"
"\n",
" [[[ 5.4914646, 5.4914646, 5.4914646, 5.4914646],\n",
" [ 5.9914646, 5.9914646, 5.9914646, 5.9914646]],\n",
"\n",
" [[ 6.4914646, 7.4914646, 8.491465 , 9.491465 ],\n",
" [ 6.9914646, 7.9914646, 8.991465 , 9.991465 ]]]], dtype=float32)"
]
},
"execution_count": null,
Expand All @@ -780,7 +774,7 @@
{
"data": {
"text/plain": [
"(2, 3, 4)"
"(2, 2, 2, 4)"
]
},
"execution_count": null,
Expand All @@ -795,7 +789,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "lcm",
"display_name": "test-cpu",
"language": "python",
"name": "python3"
},
Expand All @@ -809,7 +803,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.3"
"version": "3.12.0"
}
},
"nbformat": 4,
Expand Down
Loading

0 comments on commit 6d788ad

Please sign in to comment.