diff --git a/lectures/wealth_dynamics.md b/lectures/wealth_dynamics.md index 760ba85..914284c 100644 --- a/lectures/wealth_dynamics.md +++ b/lectures/wealth_dynamics.md @@ -4,7 +4,7 @@ jupytext: extension: .md format_name: myst format_version: 0.13 - jupytext_version: 1.15.2 + jupytext_version: 1.16.1 kernelspec: display_name: Python 3 (ipykernel) language: python @@ -16,17 +16,41 @@ kernelspec: ```{include} _admonition/gpu.md ``` -This lecture is the extended JAX implementation of [this lecture](https://python.quantecon.org/wealth_dynamics.html). +In this lecture we examine wealth dynamics in large cross-section of agents who +are subject to both -Please refer that lecture for all background and notation. +* idiosyncratic shocks, which affect labor income and returns, and +* an aggregate shock, which also impacts on labor income and returns -We will use the following imports. +In most macroeconomic models savings and consumption are determined by optimization. + +Here savings and consumption behavior is taken as given -- you can plug in your +favorite model to obtain savings behavior and then analyze distribution dynamics +using the techniques described below. + +One of our interests will be how different aspects of wealth dynamics -- such +as labor income and the rate of return on investments -- feed into measures of +inequality, such as the Gini coefficient. + +In addition to JAX and Anaconda, this lecture will need the following libraries: ```{code-cell} ipython3 +:tags: [hide-output] + +!pip install quantecon +``` + +We will use the following imports: + +```{code-cell} ipython3 +import numba +import pandas as pd +import numpy as np import matplotlib.pyplot as plt +import quantecon as qe import jax import jax.numpy as jnp -from collections import namedtuple +from time import time ``` Let's check the GPU we are running @@ -35,558 +59,596 @@ Let's check the GPU we are running !nvidia-smi ``` -## Lorenz Curves and the Gini Coefficient +## Wealth dynamics -Before we investigate wealth dynamics, we briefly review some measures of -inequality. -### Lorenz Curves +Wealth evolves as follows: -One popular graphical measure of inequality is the [Lorenz curve](https://en.wikipedia.org/wiki/Lorenz_curve). +```{math} + w_{t+1} = (1 + r_{t+1}) s(w_t) + y_{t+1} +``` -To illustrate, let us define a function `lorenz_curve_jax` that returns the cumulative share of people and the cumulative share of income earned. +Here -```{code-cell} ipython3 -@jax.jit -def lorenz_curve_jax(y): - n = y.shape[0] - y = jnp.sort(y) - s = jnp.concatenate((jnp.zeros(1), jnp.cumsum(y))) - _cum_p = jnp.arange(1, n + 1) / n - cum_income = s / s[n] - cum_people = jnp.concatenate((jnp.zeros(1), _cum_p)) - return cum_people, cum_income -``` +- $w_t$ is wealth at time $t$ for a given household, +- $r_t$ is the rate of return of financial assets, +- $y_t$ is labor income and +- $s(w_t)$ is savings (current wealth minus current consumption) -Let's suppose that -```{code-cell} ipython3 -n = 10_000 # Size of sample -rand_key = jax.random.PRNGKey(101) # Set random key -w = jnp.exp(jax.random.normal(rand_key, shape=(n,))) # Lognormal draws -``` +There is an aggregate state process -is data representing the wealth of 10,000 households. +$$ + z_{t+1} = a z_t + b + \sigma_z \epsilon_{t+1} +$$ -We can compute and plot the Lorenz curve as follows: +that affects the interest rate and labor income. -```{code-cell} ipython3 -%%time +In particular, the gross interest rates obey -f_vals, l_vals = lorenz_curve_jax(w) -``` +$$ + R_t := 1 + r_t = c_r \exp(z_t) + \exp(\mu_r + \sigma_r \xi_t) +$$ -```{code-cell} ipython3 -%%time +while -# This will be much faster as it will use the jitted function -f_vals, l_vals = lorenz_curve_jax(w) -``` +$$ + y_t = c_y \exp(z_t) + \exp(\mu_y + \sigma_y \zeta_t) +$$ -```{code-cell} ipython3 -fig, ax = plt.subplots() -ax.plot(f_vals, l_vals, label='Lorenz curve, lognormal sample') -ax.plot(f_vals, f_vals, label='Lorenz curve, equality') -ax.legend() -plt.show() -``` +The tuple $\{ (\epsilon_t, \xi_t, \zeta_t) \}$ is IID and standard normal in $\mathbb R^3$. -Here is another example, which shows how the Lorenz curve shifts as the -underlying distribution changes. +(Each household receives their own idiosyncratic shocks.) -We generate 10,000 observations using the Pareto distribution with a range of -parameters, and then compute the Lorenz curve corresponding to each set of -observations. +Regarding the savings function $s$, our default model will be -```{code-cell} ipython3 -a_vals = (1, 2, 5) # Pareto tail index -n = 10_000 # size of each sample -``` +```{math} +:label: sav_ah -```{code-cell} ipython3 -fig, ax = plt.subplots() -for a in a_vals: - rand_key = jax.random.PRNGKey(a*100) - u = jax.random.uniform(rand_key, shape=(n,)) - y = u**(-1/a) # distributed as Pareto with tail index a - f_vals, l_vals = lorenz_curve_jax(y) - ax.plot(f_vals, l_vals, label=f'$a = {a}$') - -ax.plot(f_vals, f_vals, label='equality') -ax.legend() -plt.show() +s(w) = s_0 w \cdot \mathbb 1\{w \geq \hat w\} ``` -You can see that, as the tail parameter of the Pareto distribution increases, inequality decreases. - -This is to be expected, because a higher tail index implies less weight in the tail of the Pareto distribution. +where $s_0$ is a positive constant. -### The Gini Coefficient +Thus, -The definition and interpretation of the Gini coefficient can be found on the corresponding [Wikipedia page](https://en.wikipedia.org/wiki/Gini_coefficient). +* for $w < \hat w$, the household saves nothing, while +* for $w \geq \bar w$, the household saves a fraction $s_0$ of their wealth. -We can test it on the Weibull distribution with parameter $a$, where the Gini coefficient is known to be +## Implementation -$$ -G = 1 - 2^{-1/a} -$$ +### Numba implementation -Let's define a function to compute the Gini coefficient. +Here's a function that collects parameters and useful constants ```{code-cell} ipython3 -@jax.jit -def gini_jax(y): - n = y.shape[0] - g_sum = 0 +def create_wealth_model(w_hat=1.0, # Savings parameter + s_0=0.75, # Savings parameter + c_y=1.0, # Labor income parameter + μ_y=1.0, # Labor income parameter + σ_y=0.2, # Labor income parameter + c_r=0.05, # Rate of return parameter + μ_r=0.1, # Rate of return parameter + σ_r=0.5, # Rate of return parameter + a=0.5, # Aggregate shock parameter + b=0.0, # Aggregate shock parameter + σ_z=0.1): # Aggregate shock parameter + """ + Create a wealth model with given parameters. - def sum_y_gini(i, g_sum): - g_sum += jnp.sum(jnp.abs(y[i] - y)) - return g_sum + Return a tuple model = (household_params, aggregate_params), where + household_params collects household information and aggregate_params + collects information relevant to the aggregate shock process. - g_sum = jax.lax.fori_loop(0, n, sum_y_gini, 0) - return g_sum / (2 * n * jnp.sum(y)) + """ + # Mean and variance of z process + z_mean = b / (1 - a) + z_var = σ_z**2 / (1 - a**2) + exp_z_mean = np.exp(z_mean + z_var / 2) + # Mean of R and y processes + R_mean = c_r * exp_z_mean + np.exp(μ_r + σ_r**2 / 2) + y_mean = c_y * exp_z_mean + np.exp(μ_y + σ_y**2 / 2) + # Test stability condition ensuring wealth does not diverge + # to infinity. + α = R_mean * s_0 + if α >= 1: + raise ValueError("Stability condition failed.") + # Pack values into tuples and return them + household_params = (w_hat, s_0, c_y, μ_y, σ_y, c_r, μ_r, σ_r, y_mean) + aggregate_params = (a, b, σ_z, z_mean, z_var) + model = household_params, aggregate_params + return model ``` -Let's see if the Gini coefficient computed from a simulated sample matches -this at each fixed value of $a$. +Here's a function that generates the aggregate state process ```{code-cell} ipython3 -a_vals = range(1, 20) -ginis = [] -ginis_theoretical = [] -n = 100 - -for a in a_vals: - rand_key = jax.random.PRNGKey(a) - y = jax.random.weibull_min(rand_key, 1, a, shape=(n,)) - ginis.append(gini_jax(y)) - ginis_theoretical.append(1 - 2**(-1/a)) +@numba.jit +def generate_aggregate_state_sequence(aggregate_params, length=100): + a, b, σ_z, z_mean, z_var = aggregate_params + z = np.empty(length+1) + z[0] = z_mean # Initialize at z_mean + for t in range(length): + z[t+1] = a * z[t] + b + σ_z * np.random.randn() + return z ``` +Here's a function that updates household wealth by one period, taking the +current value of the aggregate shock + ```{code-cell} ipython3 -fig, ax = plt.subplots() -ax.plot(a_vals, ginis, label='estimated gini coefficient') -ax.plot(a_vals, ginis_theoretical, label='theoretical gini coefficient') -ax.legend() -ax.set_xlabel("Weibull parameter $a$") -ax.set_ylabel("Gini coefficient") -plt.show() +@numba.jit +def update_wealth(household_params, w, z): + """ + Generate w_{t+1} given w_t and z_{t+1}. + """ + # Unpack + w_hat, s_0, c_y, μ_y, σ_y, c_r, μ_r, σ_r, y_mean = household_params + # Update wealth + y = c_y * np.exp(z) + np.exp(μ_y + σ_y * np.random.randn()) + wp = y + if w >= w_hat: + R = c_r * np.exp(z) + np.exp(μ_r + σ_r * np.random.randn()) + wp += R * s_0 * w + return wp ``` -The simulation shows that the fit is good. - -## A Model of Wealth Dynamics - -Having discussed inequality measures, let us now turn to wealth dynamics. +Here's a function to simulate the time series of wealth for an individual household -The model we will study is +```{code-cell} ipython3 +@numba.jit +def wealth_time_series(model, w_0, sim_length): + """ + Generate a single time series of length sim_length for wealth given initial + value w_0. The function generates its own aggregate shock sequence. -```{math} -:label: wealth_dynam_ah + """ + # Unpack + household_params, aggregate_params = model + a, b, σ_z, z_mean, z_var = aggregate_params + # Initialize and update + z = generate_aggregate_state_sequence(aggregate_params, + length=sim_length) + w = np.empty(sim_length) + w[0] = w_0 + for t in range(sim_length-1): + w[t+1] = update_wealth(household_params, w[t], z[t+1]) + return w +``` + +Let's look at the wealth dynamics of an individual household + +```{code-cell} ipython3 +model = create_wealth_model() +household_params, aggregate_params = model +w_hat, s_0, c_y, μ_y, σ_y, c_r, μ_r, σ_r, y_mean = household_params +a, b, σ_z, z_mean, z_var = aggregate_params +ts_length = 200 +w = wealth_time_series(model, y_mean, ts_length) +``` -w_{t+1} = (1 + r_{t+1}) s(w_t) + y_{t+1} +```{code-cell} ipython3 +fig, ax = plt.subplots() +ax.plot(w) +plt.show() ``` -where +Notice the large spikes in wealth over time. -- $w_t$ is wealth at time $t$ for a given household, -- $r_t$ is the rate of return of financial assets, -- $y_t$ is current non-financial (e.g., labor) income and -- $s(w_t)$ is current wealth net of consumption +Such spikes are related to heavy tails in the wealth distribution, which we +discuss below. -## Implementation using JAX +Here's a function to simulate a cross section of households forward in time. -Let's define a model to represent the wealth dynamics. +Note the use of parallelization to speed up computation. ```{code-cell} ipython3 -# NamedTuple Model +@numba.jit(parallel=True) +def update_cross_section(model, w_distribution, z_sequence): + """ + Shifts a cross-section of households forward in time -Model = namedtuple("Model", ("w_hat", "s_0", "c_y", "μ_y", - "σ_y", "c_r", "μ_r", "σ_r", "a", - "b", "σ_z", "z_mean", "z_var", "y_mean")) -``` + Takes -Here's a function to create the Model with the given parameters + * a current distribution of wealth values as w_distribution and + * an aggregate shock sequence z_sequence -```{code-cell} ipython3 -def create_wealth_model(w_hat=1.0, - s_0=0.75, - c_y=1.0, - μ_y=1.0, - σ_y=0.2, - c_r=0.05, - μ_r=0.1, - σ_r=0.5, - a=0.5, - b=0.0, - σ_z=0.1): - """ - Create a wealth model with given parameters and return - and instance of NamedTuple Model. - """ - z_mean = b / (1 - a) - z_var = σ_z**2 / (1 - a**2) - exp_z_mean = jnp.exp(z_mean + z_var / 2) - R_mean = c_r * exp_z_mean + jnp.exp(μ_r + σ_r**2 / 2) - y_mean = c_y * exp_z_mean + jnp.exp(μ_y + σ_y**2 / 2) - # Test a stability condition that ensures wealth does not diverge - # to infinity. - α = R_mean * s_0 - if α >= 1: - raise ValueError("Stability condition failed.") - return Model(w_hat=w_hat, s_0=s_0, c_y=c_y, μ_y=μ_y, - σ_y=σ_y, c_r=c_r, μ_r=μ_r, σ_r=σ_r, a=a, - b=b, σ_z=σ_z, z_mean=z_mean, z_var=z_var, y_mean=y_mean) -``` + and updates each w_t in w_distribution to w_{t+j}, where + j = len(z_sequence). -The following function updates one period with the given current wealth and persistent state. + Returns the new distribution. -```{code-cell} ipython3 -def update_states_jax(arrays, wdy, size, rand_key): - """ - Update one period, given current wealth w and persistent - state z. They are stored in the form of tuples under the arrays argument """ - # Unpack w and z - w, z = arrays + # Unpack + household_params, aggregate_params = model + + num_households = len(w_distribution) + new_distribution = np.empty_like(w_distribution) + z = z_sequence + + # Update each household + for i in numba.prange(num_households): + w = w_distribution[i] + for t in range(sim_length): + w = update_wealth(household_params, w, z[t]) + new_distribution[i] = w + return new_distribution +``` - rand_key, *subkey = jax.random.split(rand_key, 3) - zp = wdy.a * z + wdy.b + wdy.σ_z * jax.random.normal(rand_key, shape=size) +Parallelization works in the function above because the time path of each +household can be calculated independently once the path for the aggregate state +is known. - # Update wealth - y = wdy.c_y * jnp.exp(zp) + jnp.exp( - wdy.μ_y + wdy.σ_y * jax.random.normal(subkey[0], shape=size)) - wp = y +Let's see how long it takes to shift a large cross-section of households forward +200 periods - R = wdy.c_r * jnp.exp(zp) + jnp.exp( - wdy.μ_r + wdy.σ_r * jax.random.normal(subkey[1], shape=size)) - wp += (w >= wdy.w_hat) * R * wdy.s_0 * w - return wp, zp +```{code-cell} ipython3 +sim_length = 200 +num_households = 10_000_000 +ψ_0 = np.full(num_households, y_mean) # Initial distribution +z_sequence = generate_aggregate_state_sequence(aggregate_params, + length=sim_length) +print("Generating cross-section using Numba") +start_time = time() +ψ_star = update_cross_section(model, ψ_0, z_sequence) +numba_time = time() - start_time +print(f"Generated cross-section in {numba_time} seconds.\n") ``` -Here’s function to simulate the time series of wealth for individual households using a `for` loop and JAX. +### JAX implementation -```{code-cell} ipython3 -# Using JAX and for loop +Let's redo some of the preceding calculations using JAX and see how execution +speed compares -def wealth_time_series_for_loop_jax(w_0, n, wdy, size, rand_seed=1): +```{code-cell} ipython3 +def update_cross_section_jax(model, w_distribution, z_sequence, key): """ - Generate a single time series of length n for wealth given - initial value w_0. + Shifts a cross-section of households forward in time + + Takes - * This implementation uses a `for` loop. + * a current distribution of wealth values as w_distribution and + * an aggregate shock sequence z_sequence + + and updates each w_t in w_distribution to w_{t+j}, where + j = len(z_sequence). - The initial persistent state z_0 for each household is drawn from - the stationary distribution of the AR(1) process. + Returns the new distribution. - * wdy: NamedTuple Model - * w_0: scalar/vector - * n: int - * size: size/shape of the w_0 - * rand_seed: int (Used to generate PRNG key) """ - rand_key = jax.random.PRNGKey(rand_seed) - rand_key, *subkey = jax.random.split(rand_key, n) + # Unpack, simplify names + household_params, aggregate_params = model + w_hat, s_0, c_y, μ_y, σ_y, c_r, μ_r, σ_r, y_mean = household_params + w = w_distribution + n = len(w) - w_0 = jax.device_put(w_0).reshape(size) + # Update wealth + for t, z in enumerate(z_sequence): + U = jax.random.normal(key, (2, n)) + y = c_y * jnp.exp(z) + jnp.exp(μ_y + σ_y * U[0, :]) + R = c_r * jnp.exp(z) + jnp.exp(μ_r + σ_r * U[1, :]) + w = y + jnp.where(w < w_hat, 0.0, R * s_0 * w) + key, subkey = jax.random.split(key) - z = wdy.z_mean + jnp.sqrt(wdy.z_var) * jax.random.normal(rand_key, shape=size) - w = [w_0] - for t in range(n-1): - w_, z = update_states_jax((w[t], z), wdy, size, subkey[t]) - w.append(w_) - return jnp.array(w) + return w ``` -Let's try simulating the model at different parameter values and investigate the implications for the wealth distribution using the above function. +Let's see how long it takes to shift the cross-section of households forward +using JAX ```{code-cell} ipython3 -wdy = create_wealth_model() # default model -ts_length = 200 -size = (1,) +sim_length = 200 +num_households = 10_000_000 +ψ_0 = jnp.full(num_households, y_mean) # Initial distribution +z_sequence = generate_aggregate_state_sequence(aggregate_params, + length=sim_length) +z_sequence = jnp.array(z_sequence) ``` ```{code-cell} ipython3 -%%time - -w_jax_result = wealth_time_series_for_loop_jax(wdy.y_mean, - ts_length, wdy, size).block_until_ready() +print("Generating cross-section using JAX") +key = jax.random.PRNGKey(1234) +start_time = time() +ψ_star = update_cross_section_jax(model, ψ_0, z_sequence, key) +jax_time = time() - start_time +print(f"Generated cross-section in {jax_time} seconds.\n") ``` ```{code-cell} ipython3 -fig, ax = plt.subplots() -ax.plot(w_jax_result) -plt.show() +print("Repeating without compile time.") +key = jax.random.PRNGKey(1234) +start_time = time() +ψ_star = update_cross_section_jax(model, ψ_0, z_sequence, key) +jax_time = time() - start_time +print(f"Generated cross-section in {jax_time} seconds") ``` -We can further try to optimize and speed up the compile time of the above function by replacing `for` loop with [`jax.lax.scan`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.scan.html). +And let's see how long it takes if we compile the loop. ```{code-cell} ipython3 -def wealth_time_series_jax(w_0, n, wdy, size, rand_seed=1): +def update_cross_section_jax_compiled(model, + w_distribution, + w_size, + z_sequence, + key): """ - Generate a single time series of length n for wealth given - initial value w_0. - - * This implementation uses `jax.lax.scan`. + Shifts a cross-section of households forward in time - The initial persistent state z_0 for each household is drawn from - the stationary distribution of the AR(1) process. + Takes - * wdy: NamedTuple Model - * w_0: scalar/vector - * n: int - * size: size/shape of the w_0 - * rand_seed: int (Used to generate PRNG key) - """ - rand_key = jax.random.PRNGKey(rand_seed) - rand_key, *subkey = jax.random.split(rand_key, n) + * a current distribution of wealth values as w_distribution and + * an aggregate shock sequence z_sequence - w_0 = jax.device_put(w_0).reshape(size) - z_init = wdy.z_mean + jnp.sqrt(wdy.z_var) * jax.random.normal(rand_key, shape=size) - arrays = w_0, z_init - rand_sub_keys = jnp.array(subkey) + and updates each w_t in w_distribution to w_{t+j}, where + j = len(z_sequence). - w_final = jnp.array([w_0]) + Returns the new distribution. - # Define the function for each update - def update_w_z(arrays, rand_sub_key): - wp, zp = update_states_jax(arrays, wdy, size, rand_sub_key) - return (wp, zp), wp + """ + # Unpack, simplify names + household_params, aggregate_params = model + w_hat, s_0, c_y, μ_y, σ_y, c_r, μ_r, σ_r, y_mean = household_params + w = w_distribution + n = len(w) + z = z_sequence + sim_length = len(z) - arrays_last, w_values = jax.lax.scan(update_w_z, arrays, rand_sub_keys) - return jnp.concatenate((w_final, w_values)) + def body_function(t, state): + key, w = state + key, subkey = jax.random.split(key) + U = jax.random.normal(subkey, (2, n)) + y = c_y * jnp.exp(z[t]) + jnp.exp(μ_y + σ_y * U[0, :]) + R = c_r * jnp.exp(z[t]) + jnp.exp(μ_r + σ_r * U[1, :]) + w = y + jnp.where(w < w_hat, 0.0, R * s_0 * w) + return key, w -# Create the jit function -wealth_time_series_jax = jax.jit(wealth_time_series_jax, static_argnums=(1,3,)) + key, w = jax.lax.fori_loop(0, sim_length, body_function, (key, w)) + return w ``` -Let's try simulating the model at different parameter values and investigate the implications for the wealth distribution and also observe the difference in time between `wealth_time_series_jax` and `wealth_time_series_for_loop_jax`. - ```{code-cell} ipython3 -wdy = create_wealth_model() # default model -ts_length = 200 -size = (1,) +update_cross_section_jax_compiled = jax.jit( + update_cross_section_jax_compiled, static_argnums=(2,) +) ``` ```{code-cell} ipython3 -%%time - -w_jax_result = wealth_time_series_jax(wdy.y_mean, ts_length, wdy, size).block_until_ready() +print("Generating cross-section using JAX with compiled loop") +key = jax.random.PRNGKey(1234) +start_time = time() +ψ_star = update_cross_section_jax_compiled( + model, ψ_0, num_households, z_sequence, key +) +jax_fori_time = time() - start_time +print(f"Generated cross-section in {jax_fori_time} seconds.\n") ``` -Running the above function again will be even faster because of JAX's JIT. - ```{code-cell} ipython3 -%%time - -# 2nd time is expected to be very fast because of JIT -w_jax_result = wealth_time_series_jax(wdy.y_mean, ts_length, wdy, size).block_until_ready() +print("Repeating without compile time") +key = jax.random.PRNGKey(1234) +start_time = time() +ψ_star = update_cross_section_jax_compiled( + model, ψ_0, num_households, z_sequence, key +) +jax_fori_time = time() - start_time +print(f"Generated cross-section in {jax_fori_time} seconds") ``` ```{code-cell} ipython3 -fig, ax = plt.subplots() -ax.plot(w_jax_result) -plt.show() +print(f"JAX is {numba_time / jax_fori_time:.4f} times faster.\n") ``` -Now here’s function to simulate a cross section of households forward in time. +### Pareto tails -```{code-cell} ipython3 -def update_cross_section_jax(w_distribution, shift_length, wdy, size, rand_seed=2): - """ - Shifts a cross-section of household forward in time +In most countries, the cross-sectional distribution of wealth exhibits a Pareto +tail (power law). - * wdy: NamedTuple Model - * w_distribution: array_like, represents current cross-section +Let's see if our model can replicate this stylized fact by running a simulation +that generates a cross-section of wealth and generating a suitable rank-size plot. - Takes a current distribution of wealth values as w_distribution - and updates each w_t in w_distribution to w_{t+j}, where - j = shift_length. +We will use the function `rank_size` from `quantecon` library. - Returns the new distribution. - """ - new_dist = wealth_time_series_jax(w_distribution, shift_length, wdy, size, rand_seed) - new_distribution = new_dist[-1, :] - return new_distribution +In the limit, data that obeys a power law generates a straight line. + +```{code-cell} ipython3 +model = create_wealth_model() +key = jax.random.PRNGKey(1234) +ψ_star = update_cross_section_jax_compiled( + model, ψ_0, num_households, z_sequence, key +) +fig, ax = plt.subplots() +rank_data, size_data = qe.rank_size(ψ_star, c=0.001) +ax.loglog(rank_data, size_data, 'o', markersize=3.0, alpha=0.5) +ax.set_xlabel("log rank") +ax.set_ylabel("log size") -# Create the jit function -update_cross_section_jax = jax.jit(update_cross_section_jax, static_argnums=(1,3,)) +plt.show() ``` -## Applications +### Lorenz curves and Gini coefficients -Let's try simulating the model at different parameter values and investigate -the implications for the wealth distribution. +To study the impact of parameters on inequality, we examine Lorenz curves +and the Gini coefficients at different parameters. +QuantEcon provides functions to compute Lorenz curves and Gini coefficients that are accelerated using Numba. -### Inequality Measures +Here we provide JAX-based functions that do the same job and are faster for large data sets on parallel hardware. -Let's look at how inequality varies with returns on financial assets. -The next function generates a cross section and then computes the Lorenz -curve and Gini coefficient. +#### Lorenz curve + +Recall that, for sorted data $w_1, \ldots, w_n$, the Lorenz curve +generates data points $(x_i, y_i)_{i=0}^n$ according to + +$$ + x_0 = y_0 = 0 + \qquad \text{and, for $i \geq 1$,} \quad + x_i = \frac{i}{n}, + \qquad + y_i = + \frac{\sum_{j \leq i} w_j}{\sum_{j \leq n} w_j} +$$ ```{code-cell} ipython3 -def generate_lorenz_and_gini_jax(wdy, num_households=100_000, T=500): - """ - Generate the Lorenz curve data and gini coefficient corresponding to a - WealthDynamics mode by simulating num_households forward to time T. - """ - size = (num_households, ) - ψ_0 = jnp.full(size, wdy.y_mean) - ψ_star = update_cross_section_jax(ψ_0, T, wdy, size) - return gini_jax(ψ_star), lorenz_curve_jax(ψ_star) +def _lorenz_curve_jax(w, w_size): + n = w.shape[0] + w = jnp.sort(w) + x = jnp.arange(n + 1) / n + s = jnp.concatenate((jnp.zeros(1), jnp.cumsum(w))) + y = s / s[n] + return x, y -# Create the jit function -generate_lorenz_and_gini_jax = jax.jit(generate_lorenz_and_gini_jax, - static_argnums=(1,2,)) +lorenz_curve_jax = jax.jit(_lorenz_curve_jax, static_argnums=(1,)) ``` -Now we investigate how the Lorenz curves associated with the wealth distribution change as return to savings varies. - -The code below plots Lorenz curves for three different values of $\mu_r$. +Let's test ```{code-cell} ipython3 -%%time - -fig, ax = plt.subplots() -μ_r_vals = (0.0, 0.025, 0.05) -gini_vals = [] - -for μ_r in μ_r_vals: - wdy = create_wealth_model(μ_r=μ_r) - gv, (f_vals, l_vals) = generate_lorenz_and_gini_jax(wdy) - ax.plot(f_vals, l_vals, label=f'$\psi^*$ at $\mu_r = {μ_r:0.2}$') - gini_vals.append(gv) +sim_length = 200 +num_households = 1_000_000 +ψ_0 = jnp.full(num_households, y_mean) # Initial distribution +z_sequence = generate_aggregate_state_sequence(aggregate_params, + length=sim_length) +z_sequence = jnp.array(z_sequence) +``` -ax.plot(f_vals, f_vals, label='equality') -ax.legend(loc="upper left") -plt.show() +```{code-cell} ipython3 +key = jax.random.PRNGKey(1234) +ψ_star = update_cross_section_jax_compiled( + model, ψ_0, num_households, z_sequence, key +) ``` -The Lorenz curve shifts downwards as returns on financial income rise, indicating a rise in inequality. +```{code-cell} ipython3 +%time _ = lorenz_curve_jax(ψ_star, num_households) +``` -Now let's check the Gini coefficient. +```{code-cell} ipython3 +%time x, y = lorenz_curve_jax(ψ_star, num_households) +``` ```{code-cell} ipython3 fig, ax = plt.subplots() -ax.plot(μ_r_vals, gini_vals, label='gini coefficient') -ax.set_xlabel("$\mu_r$") +ax.plot(x, y, label="Lorenz curve at defaults") +ax.plot(x, x, 'k-', lw=1) ax.legend() plt.show() ``` -Once again, we see that inequality increases as returns on financial income -rise. +#### Gini Coefficient -Let's finish this section by investigating what happens when we change the -volatility term $\sigma_r$ in financial returns. +Recall that, for sorted data $w_1, \ldots, w_n$, the Gini coefficient takes the form -```{code-cell} ipython3 -%%time -fig, ax = plt.subplots() -σ_r_vals = (0.35, 0.45, 0.52) -gini_vals = [] +$$ +G := +\frac + {\sum_{i=1}^n \sum_{j = 1}^n |w_j - w_i|} + {2n\sum_{i=1}^n w_i}. +$$ (eq:gini) -for σ_r in σ_r_vals: - wdy = create_wealth_model(σ_r=σ_r) - gv, (f_vals, l_vals) = generate_lorenz_and_gini_jax(wdy) - ax.plot(f_vals, l_vals, label=f'$\psi^*$ at $\sigma_r = {σ_r:0.2}$') - gini_vals.append(gv) +Here's a function that computes the Gini coefficient using vectorization. -ax.plot(f_vals, f_vals, label='equality') -ax.legend(loc="upper left") -plt.show() +```{code-cell} ipython3 +def _gini_jax(w, w_size): + w_1 = jnp.reshape(w, (w_size, 1)) + w_2 = jnp.reshape(w, (1, w_size)) + g_sum = jnp.sum(jnp.abs(w_1 - w_2)) + return g_sum / (2 * w_size * jnp.sum(w)) + +gini_jax = jax.jit(_gini_jax, static_argnums=(1,)) +``` + +```{code-cell} ipython3 +%time gini = gini_jax(ψ_star, num_households).block_until_ready() ``` -We see that greater volatility has the effect of increasing inequality in this model. +```{code-cell} ipython3 +%time gini = gini_jax(ψ_star, num_households).block_until_ready() +gini +``` ## Exercises ```{exercise} :label: wd_ex1 -For a wealth or income distribution with Pareto tail, a higher tail index suggests lower inequality. - -Indeed, it is possible to prove that the Gini coefficient of the Pareto -distribution with tail index $a$ is $1/(2a - 1)$. - -To the extent that you can, confirm this by simulation. +In this exercise, write an alternative version of `gini_jax` that uses `vmap` instead of reshaping and broadcasting. -In particular, generate a plot of the Gini coefficient against the tail index -using both the theoretical value just given and the value computed from a sample via `gini_jax`. - -For the values of the tail index, use `a_vals = jnp.linspace(1, 10, 25)`. - -Use sample of size 1,000 for each $a$ and the sampling method for generating Pareto draws employed in the discussion of Lorenz curves for the Pareto distribution. - -To the extent that you can, interpret the monotone relationship between the -Gini index and $a$. +Test with the same array to see if you can obtain the same output ``` ```{solution-start} wd_ex1 :class: dropdown ``` -Here is one solution, which produces a good match between theory and -simulation. +Here's one solution: ```{code-cell} ipython3 -a_vals = jnp.linspace(1, 10, 25) # Pareto tail index -ginis = [] +@jax.jit +def gini_jax_vmap(w): -n = 1000 # size of each sample -fig, ax = plt.subplots() -for i, a in enumerate(a_vals): - rand_key = jax.random.PRNGKey(i*10) - u = jax.random.uniform(rand_key, shape=(n,)) - y = u**(-1/a) - ginis.append(gini_jax(y)) -ax.plot(a_vals, ginis, label='sampled') -ax.plot(a_vals, 1/(2*a_vals - 1), label='theoretical') -ax.legend() -plt.show() + def _inner_sum(x): + return jnp.sum(jnp.abs(x - w)) + + inner_sum = jax.vmap(_inner_sum) + + full_sum = jnp.sum(inner_sum(w)) + return full_sum / (2 * len(w) * jnp.sum(w)) ``` -In general, for a Pareto distribution, a higher tail index implies less weight -in the right hand tail. - -This means less extreme values for wealth and hence more equality. +```{code-cell} ipython3 +%time gini = gini_jax_vmap(ψ_star).block_until_ready() +gini +``` -More equality translates to a lower Gini index. +```{code-cell} ipython3 +%time gini = gini_jax_vmap(ψ_star).block_until_ready() +gini +``` ```{solution-end} ``` + ```{exercise-start} :label: wd_ex2 ``` -When savings is constant, the wealth process has the same quasi-linear -structure as a Kesten process, with multiplicative and additive shocks. +In this exercise we investigate how the parameters determining the rate of return on assets and labor income shape inequality. + +In doing so we recall that -The Kesten--Goldie theorem tells us that Kesten processes have Pareto tails under a range of parameterizations. +$$ + R_t := 1 + r_t = c_r \exp(z_t) + \exp(\mu_r + \sigma_r \xi_t) +$$ -The theorem does not directly apply here, since savings is not always constant and since the multiplicative and additive terms in {eq}`wealth_dynam_ah` are not IID. +while + +$$ + y_t = c_y \exp(z_t) + \exp(\mu_y + \sigma_y \zeta_t) +$$ -At the same time, given the similarities, perhaps Pareto tails will arise. +Investigate how the Lorenz curves and the Gini coefficient associated with the wealth distribution change as return to savings varies. -To test this, run a simulation that generates a cross-section of wealth and -generate a rank-size plot. +In particular, plot Lorenz curves for the following three different values of $\mu_r$ -In viewing the plot, remember that Pareto tails generate a straight line. Is -this what you see? +```{code-cell} ipython3 +μ_r_vals = (0.0, 0.025, 0.05) +``` -For sample size and initial conditions, use +Use the following as your initial cross-sectional distribution ```{code-cell} ipython3 -num_households = 250_000 -T = 500 # Shift forward T periods -ψ_0 = jnp.full((num_households, ), wdy.y_mean) # Initial distribution +num_households = 1_000_000 +ψ_0 = jnp.full(num_households, y_mean) # Initial distribution ``` +Once you have done that, plot the Gini coefficients as well. + +Do the outcomes match your intuition? + ```{exercise-end} ``` @@ -594,41 +656,154 @@ T = 500 # Shift forward T periods :class: dropdown ``` -First let's generate the distribution: +Here is one solution ```{code-cell} ipython3 -num_households = 250_000 -T = 500 # how far to shift forward in time -size = (num_households, ) +key = jax.random.PRNGKey(1234) +fig, ax = plt.subplots() +gini_vals = [] +for μ_r in μ_r_vals: + model = create_wealth_model(μ_r=μ_r) + ψ_star = update_cross_section_jax_compiled( + model, ψ_0, num_households, z_sequence, key + ) + x, y = lorenz_curve_jax(ψ_star, num_households) + g = gini_jax(ψ_star, num_households) + ax.plot(x, y, label=f'$\psi^*$ at $\mu_r = {μ_r:0.2}$') + gini_vals.append(g) +ax.plot(x, y, label='equality') +ax.legend(loc="upper left") +plt.show() +``` -wdy = create_wealth_model() -ψ_0 = jnp.full(size, wdy.y_mean) -ψ_star = update_cross_section_jax(ψ_0, T, wdy, size) +The Lorenz curve shifts downwards as returns on financial income rise, indicating a rise in inequality. + +Now let's check the Gini coefficient + +```{code-cell} ipython3 +fig, ax = plt.subplots() +ax.plot(μ_r_vals, gini_vals, label='Gini coefficient') +ax.set_xlabel("$\mu_r$") +ax.legend() +plt.show() ``` -Let's define a function to get the rank data +As expected, inequality increases as returns on financial income rise. + +```{solution-end} +``` + +```{exercise-start} +:label: wd_ex3 +``` + +Now investigate what happens when we change the volatility term $\sigma_r$ in financial returns. + +Use the same initial condition as before and the sequence ```{code-cell} ipython3 -def rank_size(data, c=1): - w = -jnp.sort(-data) # Reverse sort - w = w[:int(len(w) * c)] # extract top (c * 100)% - rank_data = jnp.arange(len(w)) + 1 - size_data = w - return rank_data, size_data +σ_r_vals = (0.35, 0.45, 0.52) +``` + +To isolate the role of volatility, set $\mu_r = - \sigma_r^2 / 2$ at each $\sigma_r$. + +(This holds the variance of the idiosyncratic term $\exp(\mu_r + \sigma_r \zeta)$ constant.) + +```{exercise-end} +``` + +```{solution-start} wd_ex3 +:class: dropdown ``` -Now let's see the rank-size plot: +Here's one solution ```{code-cell} ipython3 +key = jax.random.PRNGKey(1234) fig, ax = plt.subplots() -rank_data, size_data = rank_size(ψ_star, c=0.001) -ax.loglog(rank_data, size_data, 'o', markersize=3.0, alpha=0.5) -ax.set_xlabel("log rank") -ax.set_ylabel("log size") +gini_vals = [] +for σ_r in σ_r_vals: + model = create_wealth_model(σ_r=σ_r, μ_r=(-σ_r**2/2)) + ψ_star = update_cross_section_jax_compiled( + model, ψ_0, num_households, z_sequence, key + ) + x, y = lorenz_curve_jax(ψ_star, num_households) + g = gini_jax(ψ_star, num_households) + ax.plot(x, y, label=f'$\psi^*$ at $\sigma_r = {σ_r:0.2}$') + gini_vals.append(g) +ax.plot(x, y, label='equality') +ax.legend(loc="upper left") +plt.show() +``` +```{code-cell} ipython3 +fig, ax = plt.subplots() +ax.plot(σ_r_vals, gini_vals, label='Gini coefficient') +ax.set_xlabel("$\sigma_r$") +ax.legend() plt.show() ``` ```{solution-end} ``` + +```{exercise-start} +:label: wd_ex4 +``` +In this exercise, examine which has more impact on inequality: + +- a 5% rise in volatility of the rate of return, +- or a 5% rise in volatility of labor income. + +Test this by + +1. Shifting $\sigma_r$ up 5% from the baseline and plotting the Lorenz curve +1. Shifting $\sigma_y$ up 5% from the baseline and plotting the Lorenz curve + +Plot both on the same figure and examine the result. + +```{exercise-end} +``` + +```{solution-start} wd_ex4 +:class: dropdown +``` + +Here's one solution. + +It shows that increasing volatility in financial income has a greater effect + +```{code-cell} ipython3 +model = create_wealth_model() +household_params, aggregate_params = model +w_hat, s_0, c_y, μ_y, σ_y, c_r, μ_r, σ_r, y_mean = household_params +σ_r_default = σ_r +σ_y_default = σ_y + +ψ_star = update_cross_section_jax_compiled( + model, ψ_0, num_households, z_sequence, key +) +x_default, y_default = lorenz_curve_jax(ψ_star, num_households) + +model = create_wealth_model(σ_r=(1.05 * σ_r_default)) +ψ_star = update_cross_section_jax_compiled( + model, ψ_0, num_households, z_sequence, key +) +x_financial, y_financial = lorenz_curve_jax(ψ_star, num_households) + +model = create_wealth_model(σ_y=(1.05 * σ_y_default)) +ψ_star = update_cross_section_jax_compiled( + model, ψ_0, num_households, z_sequence, key +) +x_labor, y_labor = lorenz_curve_jax(ψ_star, num_households) + +fig, ax = plt.subplots() +ax.plot(x_default, x_default, 'k-', lw=1, label='equality') +ax.plot(x_financial, y_financial, label=r'higher $\sigma_r$') +ax.plot(x_labor, y_labor, label=r'higher $\sigma_y$') +ax.legend() +plt.show() +``` +```{solution-end} +``` \ No newline at end of file