diff --git a/explanations/dispatchers.ipynb b/explanations/dispatchers.ipynb
index cab40c16..2210d0f2 100644
--- a/explanations/dispatchers.ipynb
+++ b/explanations/dispatchers.ipynb
@@ -10,6 +10,13 @@
"`spacemap` are used by `lcm`."
]
},
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "**WARNING:** This notebook is outdated."
+ ]
+ },
{
"cell_type": "code",
"execution_count": null,
@@ -31,7 +38,7 @@
"source": [
"# `vmap_1d`\n",
"\n",
- "Let's start by vectorizing the function `f` over axis `a` using Jax' `vmap` function."
+ "Let's start by vectorizing the function `f` over axis `a` using JAX's `vmap` function."
]
},
{
@@ -49,6 +56,13 @@
"execution_count": null,
"metadata": {},
"outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.\n"
+ ]
+ },
{
"data": {
"text/plain": [
@@ -251,7 +265,7 @@
"\n",
"If the valid values of a variable in a state-choice space depend on another variable, that variable is termed a _sparse_ variable; otherwise, it is a _dense_ variable. To dispatch a function across an entire state-choice space, we must vectorize over both dense and sparse variables. Since, by definition, all values of dense variables are valid, we can simply perform a `productmap` over the Cartesian grid of their values. The valid combinations of sparse variables are stored as a collection of 1D arrays (see below for an example). For these, we can perform a call to `vmap_1d`.\n",
"\n",
- "Consider a simplified version of our deterministic test model. Curly brackets {} denote discrete variables; square brackets [] represent continuous variables.\n",
+ "Consider a simplified version of our deterministic test model. Curly brackets {...} denote discrete variables; square brackets [...] denote continuous variables.\n",
"\n",
"- **Choice variables:**\n",
"\n",
@@ -265,8 +279,8 @@
"\n",
" - _wealth_ $\\in [1, 2, 3, 4]$\n",
"\n",
- "- **Filter:**\n",
- " - Absorbing retirement filter: If _lagged_retirement_ is 1, then the choice\n",
+ "- **Constraints:**\n",
+ " - Absorbing retirement constraint: If _lagged_retirement_ is 1, then the choice\n",
" _retirement_ can never be 0."
]
},
@@ -291,7 +305,7 @@
" return jnp.log(consumption) - 0.5 * working + retirement_habit\n",
"\n",
"\n",
- "def absorbing_retirement_filter(retirement, lagged_retirement):\n",
+ "def absorbing_retirement_constraint(retirement, lagged_retirement):\n",
" return jnp.logical_or(retirement == 1, lagged_retirement == 0)\n",
"\n",
"\n",
@@ -300,7 +314,7 @@
" \"utility\": utility,\n",
" \"next_lagged_retirement\": lambda retirement: retirement,\n",
" \"next_wealth\": lambda wealth, consumption: wealth - consumption,\n",
- " \"absorbing_retirement_filter\": absorbing_retirement_filter,\n",
+ " \"absorbing_retirement_constraint\": absorbing_retirement_constraint,\n",
" },\n",
" n_periods=1,\n",
" choices={\n",
@@ -325,11 +339,9 @@
"\n",
"processed_model = process_model(model)\n",
"\n",
- "sc_space, space_info, state_indexer, segments = create_state_choice_space(\n",
+ "sc_space, space_info = create_state_choice_space(\n",
" processed_model,\n",
- " period=2,\n",
" is_last_period=False,\n",
- " jit_filter=False,\n",
")"
]
},
@@ -350,7 +362,9 @@
{
"data": {
"text/plain": [
- "{'wealth': Array([1., 2., 3., 4.], dtype=float32)}"
+ "{'lagged_retirement': Array([0, 1], dtype=int32),\n",
+ " 'retirement': Array([0, 1], dtype=int32),\n",
+ " 'wealth': Array([1., 2., 3., 4.], dtype=float32)}"
]
},
"execution_count": null,
@@ -370,8 +384,7 @@
{
"data": {
"text/plain": [
- "{'lagged_retirement': Array([0, 0, 1], dtype=int32),\n",
- " 'retirement': Array([0, 1, 1], dtype=int32)}"
+ "{}"
]
},
"execution_count": null,
@@ -409,35 +422,17 @@
"  <thead>\n",
"    <tr style=\"text-align: right;\">\n",
"      <th></th>\n",
- "      <th>lagged_retirement</th>\n",
- "      <th>retirement</th>\n",
"    </tr>\n",
"  </thead>\n",
"  <tbody>\n",
- "    <tr>\n",
- "      <th>0</th>\n",
- "      <td>0</td>\n",
- "      <td>0</td>\n",
- "    </tr>\n",
- "    <tr>\n",
- "      <th>1</th>\n",
- "      <td>0</td>\n",
- "      <td>1</td>\n",
- "    </tr>\n",
- "    <tr>\n",
- "      <th>2</th>\n",
- "      <td>1</td>\n",
- "      <td>1</td>\n",
- "    </tr>\n",
"  </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
- " lagged_retirement retirement\n",
- "0 0 0\n",
- "1 0 1\n",
- "2 1 1"
+ "Empty DataFrame\n",
+ "Columns: []\n",
+ "Index: []"
]
},
"execution_count": null,
@@ -496,34 +491,6 @@
"---"
]
},
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "It is also worth noting the connection between the sparse variable representation\n",
- "and the `segments`. "
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "{'segment_ids': Array([0, 0, 1], dtype=int32), 'num_segments': 2}"
- ]
- },
- "execution_count": null,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "segments"
- ]
- },
{
"cell_type": "markdown",
"metadata": {},
@@ -556,7 +523,6 @@
" func=utility,\n",
" dense_vars=list(sc_space.dense_vars),\n",
" sparse_vars=list(sc_space.sparse_vars),\n",
- " put_dense_first=False,\n",
")"
]
},
@@ -568,7 +534,9 @@
{
"data": {
"text/plain": [
- "{'wealth': Array([1., 2., 3., 4.], dtype=float32)}"
+ "{'lagged_retirement': Array([0, 1], dtype=int32),\n",
+ " 'retirement': Array([0, 1], dtype=int32),\n",
+ " 'wealth': Array([1., 2., 3., 4.], dtype=float32)}"
]
},
"execution_count": null,
@@ -588,8 +556,7 @@
{
"data": {
"text/plain": [
- "{'lagged_retirement': Array([0, 0, 1], dtype=int32),\n",
- " 'retirement': Array([0, 1, 1], dtype=int32)}"
+ "{}"
]
},
"execution_count": null,
@@ -609,9 +576,11 @@
{
"data": {
"text/plain": [
- "Array([[-0.5, -0.5, -0.5, -0.5],\n",
- " [ 0. , 0. , 0. , 0. ],\n",
- " [ 1. , 2. , 3. , 4. ]], dtype=float32)"
+ "Array([[[-0.5, -0.5, -0.5, -0.5],\n",
+ " [ 0. , 0. , 0. , 0. ]],\n",
+ "\n",
+ " [[ 0.5, 1.5, 2.5, 3.5],\n",
+ " [ 1. , 2. , 3. , 4. ]]], dtype=float32)"
]
},
"execution_count": null,
@@ -636,7 +605,7 @@
{
"data": {
"text/plain": [
- "(3, 4)"
+ "(2, 2, 4)"
]
},
"execution_count": null,
@@ -655,6 +624,26 @@
"Let's try to get this result via looping over the grids and calling `utility` directly"
]
},
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "Space(sparse_vars={}, dense_vars={'lagged_retirement': Array([0, 1], dtype=int32), 'retirement': Array([0, 1], dtype=int32), 'wealth': Array([1., 2., 3., 4.], dtype=float32)})"
+ ]
+ },
+ "execution_count": null,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "sc_space"
+ ]
+ },
{
"cell_type": "code",
"execution_count": null,
@@ -664,8 +653,8 @@
"data": {
"text/plain": [
"Array([[-0.5, -0.5, -0.5, -0.5],\n",
- " [ 0. , 0. , 0. , 0. ],\n",
- " [ 1. , 2. , 3. , 4. ]], dtype=float32)"
+ " [ 1. , 2. , 3. , 4. ],\n",
+ " [ 0. , 0. , 0. , 0. ]], dtype=float32)"
]
},
"execution_count": null,
@@ -679,8 +668,8 @@
- "# loop over valid combinations of sparse variables (first axis)\n",
+ "# loop over combinations of the dense variables (first axis)\n",
"for i, (lagged_retirement, retirement) in enumerate(\n",
" zip(\n",
- " sc_space.sparse_vars[\"lagged_retirement\"],\n",
- " sc_space.sparse_vars[\"retirement\"],\n",
+ " sc_space.dense_vars[\"lagged_retirement\"],\n",
+ " sc_space.dense_vars[\"retirement\"],\n",
" strict=False,\n",
" ),\n",
"):\n",
@@ -749,13 +738,18 @@
{
"data": {
"text/plain": [
- "Array([[[-0.5 , -0.5 , -0.5 , -0.5 ],\n",
- " [ 0. , 0. , 0. , 0. ],\n",
- " [ 1. , 2. , 3. , 4. ]],\n",
+ "Array([[[[-0.5 , -0.5 , -0.5 , -0.5 ],\n",
+ " [ 0. , 0. , 0. , 0. ]],\n",
+ "\n",
+ " [[ 0.5 , 1.5 , 2.5 , 3.5 ],\n",
+ " [ 1. , 2. , 3. , 4. ]]],\n",
"\n",
- " [[ 5.4914646, 5.4914646, 5.4914646, 5.4914646],\n",
- " [ 5.9914646, 5.9914646, 5.9914646, 5.9914646],\n",
- " [ 6.9914646, 7.9914646, 8.991465 , 9.991465 ]]], dtype=float32)"
+ "\n",
+ " [[[ 5.4914646, 5.4914646, 5.4914646, 5.4914646],\n",
+ " [ 5.9914646, 5.9914646, 5.9914646, 5.9914646]],\n",
+ "\n",
+ " [[ 6.4914646, 7.4914646, 8.491465 , 9.491465 ],\n",
+ " [ 6.9914646, 7.9914646, 8.991465 , 9.991465 ]]]], dtype=float32)"
]
},
"execution_count": null,
@@ -780,7 +774,7 @@
{
"data": {
"text/plain": [
- "(2, 3, 4)"
+ "(2, 2, 2, 4)"
]
},
"execution_count": null,
@@ -795,7 +789,7 @@
],
"metadata": {
"kernelspec": {
- "display_name": "lcm",
+ "display_name": "test-cpu",
"language": "python",
"name": "python3"
},
@@ -809,7 +803,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
- "version": "3.12.3"
+ "version": "3.12.0"
}
},
"nbformat": 4,
diff --git a/explanations/function_representation.ipynb b/explanations/function_representation.ipynb
index 0935ae64..8a39cefa 100644
--- a/explanations/function_representation.ipynb
+++ b/explanations/function_representation.ipynb
@@ -64,7 +64,7 @@
"### Example\n",
"\n",
"As an example, consider a stripped-down version of the deterministic model from\n",
- "Iskhakov et al. (2017), which removes the absorbing retirement filter and the lagged\n",
+ "Iskhakov et al. (2017), which removes the absorbing retirement constraint and the lagged\n",
"retirement state compared to the original model (this version can be found in the\n",
"`tests/test_models/deterministic.py` module). Here we also use a coarser grid to\n",
"showcase the behavior of the function representation."
@@ -119,8 +119,9 @@
"\n",
"model = Model(\n",
" description=(\n",
- " \"Starts from Iskhakov et al. (2017), removes filters and the lagged_retirement \"\n",
- " \"state, and adds a wage function that depends on age.\"\n",
+ " \"Starts from Iskhakov et al. (2017), removes the absorbing retirement \"\n",
+ " \"constraint and the lagged_retirement state, and adds a wage function that \"\n",
+ " \"depends on age.\"\n",
" ),\n",
" n_periods=2,\n",
" functions={\n",
@@ -368,9 +369,9 @@
{
"data": {
"text/html": [
- "