Skip to content

Commit

Permalink
Resolve merge conflict
Browse files Browse the repository at this point in the history
  • Loading branch information
Justinezgh committed Mar 5, 2024
2 parents 1e3dad2 + 6a6d717 commit 43b9272
Show file tree
Hide file tree
Showing 7 changed files with 340 additions and 41 deletions.
Binary file not shown.
Binary file not shown.
89 changes: 63 additions & 26 deletions sbi_lens/normflow/train_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,57 +3,94 @@
import jax
import jax.numpy as jnp
import optax
import tensorflow_probability as tfp

tfp = tfp.experimental.substrates.jax
tfb = tfp.bijectors
tfd = tfp.distributions

class TrainModel:
def loss_mse(self, params, theta, x, state_resnet):
y, opt_state_resnet = self.compressor.apply(params, state_resnet, None, x)

loss = jnp.mean(jnp.sum((y - theta) ** 2, axis=1))

return loss, opt_state_resnet

def loss_vmim(self, params, theta, x, state_resnet):
y, opt_state_resnet = self.compressor.apply(params, state_resnet, None, x)
log_prob = self.nf.apply(params, theta, y)

return -jnp.mean(log_prob), opt_state_resnet

def loss_nll(self, params, theta, x, _):
y, _ = self.compressor.apply(
self.info_compressor[0], self.info_compressor[1], None, x
)
log_prob = self.nf.apply(params, theta, y)

return -jnp.mean(log_prob), _

class TrainModel:
def __init__(
self,
compressor,
nf,
optimizer,
loss_name,
nb_pixels,
nb_bins,
dim=None,
info_compressor=None,
):
self.compressor = compressor
self.nf = nf
self.optimizer = optimizer
self.nb_pixels = nb_pixels
self.nb_bins = nb_bins
self.dim = dim # summary statistic dimension

if loss_name == "train_compressor_mse":
self.loss = self.loss_mse
elif loss_name == "train_compressor_vmim":
self.loss = self.loss_vmim
elif loss_name == "train_compressor_gnll":
self.loss = self.loss_gnll
if self.dim is None:
raise ValueError("dim should be specified when using gnll compressor")
elif loss_name == "loss_for_sbi":
if info_compressor is None:
raise NotImplementedError
raise ValueError("sbi loss needs compressor informations")
else:
self.info_compressor = info_compressor
self.loss = self.loss_nll

def loss_mse(self, params, theta, x, state_resnet):
"""Compute the Mean Squared Error loss"""
y, opt_state_resnet = self.compressor.apply(params, state_resnet, None, x)

loss = jnp.mean(jnp.sum((y - theta) ** 2, axis=1))

return loss, opt_state_resnet

def loss_mae(self, params, theta, x, state_resnet):
"""Compute the Mean Absolute Error loss"""
y, opt_state_resnet = self.compressor.apply(params, state_resnet, None, x)

loss = jnp.mean(jnp.sum(jnp.absolute(y - theta), axis=1))

return loss, opt_state_resnet

def loss_vmim(self, params, theta, x, state_resnet):
"""Compute the Variational Mutual Information Maximization loss"""
y, opt_state_resnet = self.compressor.apply(params, state_resnet, None, x)
log_prob = self.nf.apply(params, theta, y)

return -jnp.mean(log_prob), opt_state_resnet

def loss_gnll(self, params, theta, x, state_resnet):
"""Compute the Gaussian Negative Log Likelihood loss"""
y, opt_state_resnet = self.compressor.apply(params, state_resnet, None, x)
y_mean = y[..., : self.dim]
y_var = y[..., self.dim :]
y_var = tfb.FillScaleTriL(diag_bijector=tfb.Softplus(low=1e-3)).forward(y_var)

@jax.jit
@jax.vmap
def _get_log_prob(y_mean, y_var, theta):
likelihood = tfd.MultivariateNormalTriL(y_mean, y_var)
return likelihood.log_prob(theta)

loss = -jnp.mean(_get_log_prob(y_mean, y_var, theta))

return loss, opt_state_resnet

def loss_nll(self, params, theta, x, _):
"""Compute the Negative Log Likelihood loss.
This loss is for inference so it requires to have a trained compressor.
"""
y, _ = self.compressor.apply(
self.info_compressor[0], self.info_compressor[1], None, x
)
log_prob = self.nf.apply(params, theta, y)

return -jnp.mean(log_prob), _

@partial(jax.jit, static_argnums=(0,))
def update(self, model_params, opt_state, theta, x, state_resnet=None):
(loss, opt_state_resnet), grads = jax.value_and_grad(self.loss, has_aux=True)(
Expand Down
243 changes: 243 additions & 0 deletions sbi_lens/simulator/Lpt_field.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,243 @@
import astropy.units as u
import jax
import jax.numpy as jnp
import jax_cosmo as jc
import jax_cosmo.constants as constants
import numpy as np
import numpyro
import numpyro.distributions as dist
from jax.scipy.ndimage import map_coordinates
from jax_cosmo.scipy.integrate import simps
from jaxpm.kernels import fftk
from jaxpm.painting import cic_paint, compensate_cic
from jaxpm.pm import growth_factor, pm_forces

from sbi_lens.simulator.redshift import subdivide

# code from François Lanusse


def linear_field(mesh_shape, box_size, pk, field):
"""
Generate initial conditions.
"""
kvec = fftk(mesh_shape)
kmesh = (
sum((kk / box_size[i] * mesh_shape[i]) ** 2 for i, kk in enumerate(kvec)) ** 0.5
)
pkmesh = (
pk(kmesh)
* (mesh_shape[0] * mesh_shape[1] * mesh_shape[2])
/ (box_size[0] * box_size[1] * box_size[2])
)

field = jnp.fft.rfftn(field) * pkmesh**0.5
field = jnp.fft.irfftn(field)
return field


def lpt_lightcone(cosmo, initial_conditions, positions, a, mesh_shape):
"""
Computes first order LPT displacement
"""
initial_force = pm_forces(positions, delta=initial_conditions).reshape(
mesh_shape + [3]
)
a = jnp.atleast_1d(a)
dx = growth_factor(cosmo, a).reshape([1, 1, -1, 1]) * initial_force
return dx.reshape([-1, 3])


def convergence_Born(cosmo, density_planes, r, a, dx, dz, coords, z_source):
"""
Compute the Born convergence
Args:
cosmo: `Cosmology`, cosmology object.
density_planes: list of dictionaries (r, a, density_plane, dx, dz),
lens planes to use
coords: a 3-D array of angular coordinates in radians of N points with
shape [batch, N, 2].
z_source: 1-D `Tensor` of source redshifts with shape [Nz] .
name: `string`, name of the operation.
Returns:
`Tensor` of shape [batch_size, N, Nz], of convergence values.
"""
# Compute constant prefactor:
constant_factor = 3 / 2 * cosmo.Omega_m * (constants.H0 / constants.c) ** 2
# Compute comoving distance of source galaxies
r_s = jc.background.radial_comoving_distance(cosmo, 1 / (1 + z_source))

convergence = 0
n_planes = len(r)

def scan_fn(carry, i):
density_planes, a, r = carry

p = density_planes[:, :, i]
density_normalization = dz * r[i] / a[i]
p = (p - p.mean()) * constant_factor * density_normalization

# Interpolate at the density plane coordinates
im = map_coordinates(p, coords * r[i] / dx - 0.5, order=1, mode="wrap")

return carry, im * jnp.clip(1.0 - (r[i] / r_s), 0, 1000).reshape([-1, 1, 1])

_, convergence = jax.lax.scan(scan_fn, (density_planes, a, r), jnp.arange(n_planes))

return convergence.sum(axis=0)


def make_full_field_model(field_size, field_npix, box_shape, box_size):
def forward_model(cosmo, nz_shear, initial_conditions):
# Create a small function to generate the matter power spectrum
k = jnp.logspace(-4, 1, 128)
pk = jc.power.linear_matter_power(cosmo, k)

def pk_fn(x):
return jc.scipy.interpolate.interp(x.reshape([-1]), k, pk).reshape(x.shape)

# Create initial conditions
lin_field = linear_field(box_shape, box_size, pk_fn, initial_conditions)

# Create particles
particles = jnp.stack(
jnp.meshgrid(*[jnp.arange(s) for s in box_shape]), axis=-1
).reshape([-1, 3])

# Compute the scale factor that corresponds to each slice of the volume
r = (jnp.arange(box_shape[-1]) + 0.5) * box_size[-1] / box_shape[-1]
a = jc.background.a_of_chi(cosmo, r)

cosmo = jc.Cosmology(
Omega_c=cosmo.Omega_c,
sigma8=cosmo.sigma8,
Omega_b=cosmo.Omega_b,
h=cosmo.h,
n_s=cosmo.n_s,
w0=cosmo.w0,
Omega_k=0.0,
wa=0.0,
)

# Initial displacement
dx = lpt_lightcone(cosmo, lin_field, particles, a, box_shape)

# Paint the particles on a new mesh
lightcone = cic_paint(jnp.zeros(box_shape), particles + dx)
# Apply de-cic filter to recover more signal on small scales
lightcone = compensate_cic(lightcone)

dx = box_size[0] / box_shape[0]
dz = box_size[-1] / box_shape[-1]

# Defining the coordinate grid for lensing map
xgrid, ygrid = np.meshgrid(
np.linspace(
0, field_size, box_shape[0], endpoint=False
), # range of X coordinates
np.linspace(0, field_size, box_shape[1], endpoint=False),
) # range of Y coordinates
coords = jnp.array((np.stack([xgrid, ygrid], axis=0) * u.deg).to(u.rad))

# Generate convergence maps by integrating over nz and source planes
convergence_maps = [
simps(
lambda z: nz(z).reshape([-1, 1, 1])
* convergence_Born(cosmo, lightcone, r, a, dx, dz, coords, z),
0.01,
3.0,
N=32,
)
for nz in nz_shear
]

# Reshape the maps to desired resoluton
convergence_maps = [
kmap.reshape(
[
field_npix,
box_shape[0] // field_npix,
field_npix,
box_shape[1] // field_npix,
]
)
.mean(axis=1)
.mean(axis=-1)
for kmap in convergence_maps
]

return convergence_maps, lightcone

return forward_model


# Define the probabilistic model
def lensingLpt(
N,
map_size,
box_size,
box_shape,
gal_per_arcmin2,
sigma_e,
nbins,
a,
b,
z0,
):
"""
This function defines the top-level forward model for our observations
"""

pix_area = (map_size * 60 / N) ** 2

# Sampling initial conditions
initial_conditions = numpyro.sample(
"z", dist.Normal(jnp.zeros(box_shape), jnp.ones(box_shape))
)

omega_c = numpyro.sample("omega_c", dist.TruncatedNormal(0.2664, 0.2, low=0))
omega_b = numpyro.sample("omega_b", dist.Normal(0.0492, 0.006))
sigma_8 = numpyro.sample("sigma_8", dist.Normal(0.831, 0.14))
h_0 = numpyro.sample("h_0", dist.Normal(0.6727, 0.063))
n_s = numpyro.sample("n_s", dist.Normal(0.9645, 0.08))
w_0 = numpyro.sample("w_0", dist.TruncatedNormal(-1.0, 0.9, low=-2.0, high=-0.3))

cosmo = jc.Cosmology(
Omega_c=omega_c,
Omega_b=omega_b,
sigma8=sigma_8,
h=h_0,
n_s=n_s,
w0=w_0,
wa=0.0,
Omega_k=0.0,
)

# Generate random convergence maps
nz = jc.redshift.smail_nz(a, b, z0, gals_per_arcmin2=gal_per_arcmin2)
nz_shear = subdivide(nz, nbins=nbins, zphot_sigma=0.05)

lensing_model = jax.jit(
make_full_field_model(
field_size=map_size,
field_npix=N,
box_size=box_size,
box_shape=box_shape,
)
)

field, _ = lensing_model(cosmo, nz_shear, initial_conditions)
field = jnp.transpose(jnp.array(field), [1, 2, 0])

x = numpyro.sample(
"y",
dist.MultivariateNormal(
loc=field,
covariance_matrix=jnp.diag(
sigma_e**2
/ (jnp.array([b.gals_per_arcmin2 for b in nz_shear]) * pix_area)
),
),
)

return x
7 changes: 7 additions & 0 deletions sbi_lens/simulator/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,13 @@
tfp = tfp.substrates.jax
tfd = tfp.distributions

np.complex = complex
np.float = float

SOURCE_FILE = Path(__file__)
SOURCE_DIR = SOURCE_FILE.parent
ROOT_DIR = SOURCE_DIR.parent.resolve()
DATA_DIR = ROOT_DIR / "data"

class ForwardModelMassMap:
def __init__(self, config, model_type, lognormal_shifts, with_noise):
Expand Down
Loading

0 comments on commit 43b9272

Please sign in to comment.