Skip to content

Commit

Permalink
Merge branch 'main' of https://github.com/thoglu/jammy_flows into main
Browse files Browse the repository at this point in the history
  • Loading branch information
thoglu committed Jan 2, 2024
2 parents 1c562d7 + ac6ad90 commit 1e4bf50
Show file tree
Hide file tree
Showing 4 changed files with 201 additions and 67 deletions.
173 changes: 154 additions & 19 deletions jammy_flows/helper_fns/plotting/spherical.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,26 @@

###### plotting functions for sphere (s2), employing a flexible grid to save computing while still having smooth contours

def _transform_to_world(ax_object, coords, projection_type="zen_azi"):
"""
Does the coordinate transformation to world coordaintes given a certain projection_type
"""

if(projection_type=="zen_azi"):
healpy_phi_theta_coords=coords*180.0/numpy.pi

# theta
healpy_phi_theta_coords[:,1]=90.0-healpy_phi_theta_coords[:,0]

# phi
healpy_phi_theta_coords[:,0]=coords[:,1]*180.0/numpy.pi
else:
raise Exception("Unknown projection type ", projection_type)

world_coords=ax_object.wcs.all_world2pix(healpy_phi_theta_coords,1)

return world_coords

#### monkeypatching add function for more flexible ticklabels
def add(self,
axis=None,
Expand Down Expand Up @@ -198,6 +218,45 @@ def __init__(self, *args, **kwargs):
super().__init__(*args,
frame_class = kwargs.pop('frame_class', EllipticalFrame),
**kwargs)

def proj_plot(self, *args, **kwargs):
"""
Calls matplotlib.plot in world coordinates, and does an internal transformation first.
Internally the healpy ax uses dec/ra, so have to take of that here.
"""
assert(len(args)>=2)

projection_type="zen_azi"

if("projection_type" in kwargs):
projection_type=kwargs["projection_type"]

assert(projection_type=="zen_azi"), "For now only zen_azi is supported"

x=args[0]
y=args[1]

new_x=x
new_y=y
if(type(x)==list):
new_x=numpy.array(x)
if(type(y)==list):
new_y=numpy.array(y)

assert(new_x.ndim==1)
assert(new_y.ndim==1)

combined_coords=numpy.concatenate([new_x[:,None], new_y[:,None]], axis=1)

world_coords=_transform_to_world(self, combined_coords, projection_type=projection_type)

further_args=args[2:]

new_kwargs=kwargs.copy()
del new_kwargs["projection_type"]

self.plot(world_coords[:,0], world_coords[:,1], *further_args, **new_kwargs)


register_projection(MollviewAzimuth)

Expand Down Expand Up @@ -322,6 +381,44 @@ def graticule(self,
self.coords[1].set_axislabel("zenith [deg]", minpad=zenith_axislabel_minpad) # shift the label a little to the left


def proj_plot(self, *args, **kwargs):
"""
Calls matplotlib.plot in world coordinates, and does an internal transformation first.
Internally the healpy ax uses dec/ra, so have to take of that here.
"""
assert(len(args)>=2)

projection_type="zen_azi"

if("projection_type" in kwargs):
projection_type=kwargs["projection_type"]

assert(projection_type=="zen_azi"), "For now only zen_azi is supported"

x=args[0]
y=args[1]

new_x=x
new_y=y
if(type(x)==list):
new_x=numpy.array(x)
if(type(y)==list):
new_y=numpy.array(y)

assert(new_x.ndim==1)
assert(new_y.ndim==1)

combined_coords=numpy.concatenate([new_x[:,None], new_y[:,None]], axis=1)

world_coords=_transform_to_world(self, combined_coords, projection_type=projection_type)

further_args=args[2:]

new_kwargs=kwargs.copy()
del new_kwargs["projection_type"]

self.plot(world_coords[:,0], world_coords[:,1], *further_args, **new_kwargs)

register_projection(OrthviewAzimuth)

def get_meshed_positions_and_areas(samples,
Expand Down Expand Up @@ -351,6 +448,7 @@ def get_meshed_positions_and_areas(samples,
return ang_vals, per_pixel_areas, moc_map

def get_multiresolution_evals(pdf,
conditional_input=None,
sub_pdf_index=0,
samplesize=10000,
max_entries_per_pixel=5,
Expand All @@ -365,15 +463,42 @@ def get_multiresolution_evals(pdf,
moc_map (Healpix map): Multiresolution healpix map
"""

samples,_,_,_=pdf.sample(samplesize=samplesize)
data_summary_repeated=None
if(conditional_input is not None):
data_summary_repeated=conditional_input

if(type(conditional_input)==list):
if(conditional_input[0].ndim==2):
assert(conditional_input[0].shape[0]==1), "Only a single conditional input item must be given!"
data_summary_repeated=[ci.repeat_interleave(samplesize, dim=0) if ci.ndim==2 else ci[None,:].repeat_interleave(samplesize, dim=0) for ci in conditional_input]
else:
if(conditional_input.ndim==2):
assert(conditional_input.shape[0]==1), "Only a single conditional input item must be given!"
data_summary_repeated=conditional_input.repeat_interleave(samplesize, dim=0) if conditional_input.ndim==2 else conditional_input[None,:].repeat_interleave(samplesize, dim=0)

samples,_,_,_=pdf.sample(samplesize=samplesize, conditional_input=data_summary_repeated)
eval_positions, eval_areas, moc_map=get_meshed_positions_and_areas(samples,max_entries_per_pixel=max_entries_per_pixel)

assert(pdf.pdf_defs_list[sub_pdf_index]=="s2"), ("Trying to get multiresolution for s2 subdimension, but subdimension %d is of type %s" % (sub_pdf_index, pdf.pdf_defs_list[sub_pdf_index]))

if(use_density_if_possible and (sub_pdf_index==0)):
xyz_positions=pdf.transform_target_into_returnable_params(torch.from_numpy(eval_positions).to(samples))
log_pdf,_,_=pdf(xyz_positions, force_embedding_coordinates=True)
pdf_evals=log_pdf.exp().cpu().detach().numpy()

if(data_summary_repeated is not None):

moc_size=xyz_positions.shape[0]
if(type(data_summary_repeated)==list):
assert(moc_size<=data_summary_repeated[0].shape[0])
data_summary_repeated=[ci[:moc_size] for ci in data_summary_repeated]
else:
assert(moc_size<=data_summary_repeated.shape[0])
data_summary_repeated=data_summary_repeated[:moc_size]

log_pdf,_,_=pdf(xyz_positions, force_embedding_coordinates=True, conditional_input=data_summary_repeated)

log_pdf=log_pdf.cpu().detach().numpy()
pdf_evals=numpy.exp(log_pdf)

else:


Expand All @@ -387,13 +512,17 @@ def get_multiresolution_evals(pdf,
pdf_evals[unique_indices]=counts/float(sum(counts))#eval_areas[unique_indices]
pdf_evals=pdf_evals/eval_areas

return eval_positions, pdf_evals, eval_areas, moc_map
# no log_pdf in sample_based evaluation
log_pdf=None

return eval_positions, log_pdf, pdf_evals, eval_areas, moc_map


def plot_multiresolution_healpy(pdf,
fig=None,
ax_to_plot=None,
samplesize=10000,
samplesize=10000,
conditional_input=None,
sub_pdf_index=0,
max_entries_per_pixel=5,
draw_pixels=True,
Expand All @@ -408,7 +537,8 @@ def plot_multiresolution_healpy(pdf,
contour_colors=None, # None -> pick colors from color scheme
zoom=False,
visualization="zen_azi", # zen_azi or dec_ra
declination_trafo_function=None): # required to transform to dec/ra before plotting
declination_trafo_function=None, # required to transform to dec/ra before plotting
show_grid=False):

"""
Visualizes an S2 pdf, or a certain S2 subpart of a PDF using an adaptive healpix grid from mhealpy. Useful if the PDF
Expand All @@ -418,9 +548,10 @@ def plot_multiresolution_healpy(pdf,
"""
assert("s2" in pdf.pdf_defs_list ), "Requires that at least one s2 sub-manifold exists."

eval_positions, pdf_evals, eval_areas, moc_map=get_multiresolution_evals(pdf,
eval_positions, _, pdf_evals, eval_areas, moc_map=get_multiresolution_evals(pdf,
sub_pdf_index=sub_pdf_index,
samplesize=samplesize,
conditional_input=conditional_input,
max_entries_per_pixel=max_entries_per_pixel,
use_density_if_possible=use_density_if_possible)

Expand All @@ -431,10 +562,7 @@ def plot_multiresolution_healpy(pdf,
moc_map=moc_map,
fig=fig,
ax_to_plot=ax_to_plot,
samplesize=samplesize,
max_entries_per_pixel=max_entries_per_pixel,
draw_pixels=draw_pixels,
use_density_if_possible=use_density_if_possible,
log_scale=log_scale,
cbar=cbar,
cbar_kwargs=cbar_kwargs,
Expand All @@ -445,7 +573,8 @@ def plot_multiresolution_healpy(pdf,
contour_colors=contour_colors, # None -> pick colors from color scheme
zoom=zoom,
visualization=visualization, # zen_azi or dec_ra
declination_trafo_function=declination_trafo_function) # required to transform to dec/ra before plotting )
declination_trafo_function=declination_trafo_function,
show_grid=show_grid)

return ax

Expand All @@ -455,10 +584,7 @@ def _plot_multiresolution_healpy(eval_positions,
moc_map=None,
fig=None,
ax_to_plot=None,
samplesize=10000,
max_entries_per_pixel=5,
draw_pixels=True,
use_density_if_possible=True,
log_scale=True,
cbar=True,
cbar_kwargs={},
Expand All @@ -468,8 +594,9 @@ def _plot_multiresolution_healpy(eval_positions,
contour_probs=[0.68, 0.95],
contour_colors=None, # None -> pick colors from color scheme
zoom=False,
visualization="zen_azi", # zen_azi or dec_ra
declination_trafo_function=None): # required to transform to dec/ra before plotting
projection_type="zen_azi", # zen_azi or dec_ra
declination_trafo_function=None, # required to transform to dec/ra before plotting
show_grid=False):

"""
Visualizes an S2 pdf, or a certain S2 subpart of a PDF using an adaptive healpix grid from mhealpy. Useful if the PDF
Expand All @@ -485,6 +612,7 @@ def _plot_multiresolution_healpy(eval_positions,
if(moc_map is None):
sample_pix = mhealpy.ang2pix(mhealpy.MAX_NSIDE, eval_positions[:,0], eval_positions[:,1], nest = True)
moc_map = HealpixMap.moc_histogram(mhealpy.MAX_NSIDE, sample_pix, 1, nest=True)
assert(len(eval_positions)==moc_map.npix)

if(zoom):

Expand Down Expand Up @@ -615,23 +743,30 @@ def _plot_multiresolution_healpy(eval_positions,
target=3
smaller_than_target=desirable_dist_between_graticules<target
num_significant_points=0

while(smaller_than_target):
target=target/10.0
num_significant_points+=1
smaller_than_target=desirable_dist_between_graticules<target


grat_format="d"
if(num_significant_points>0):
grat_format=grat_format+"."+"d"*num_significant_points

graticule_default_kwargs["tick_format"]=grat_format

## max out at 60 degrees for full sky
graticule_default_kwargs["dmer"]=min(desirable_dist_between_graticules, 60.0)
graticule_default_kwargs["dpar"]=min(desirable_dist_between_graticules, 60.0)

for extra_kwarg in graticule_kwargs:
graticule_default_kwargs[extra_kwarg]=graticule_kwargs[extra_kwarg]

ax.graticule(**graticule_default_kwargs)

if(show_grid):
moc_map.plot_grid(ax, linewidth = .1, color = 'white');

return ax

18 changes: 17 additions & 1 deletion jammy_flows/layers/intervals/rational_quadratic_spline.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,10 @@ def _flow_mapping(self, inputs, extra_inputs=None):

[x, log_det]=inputs

## make sure we stay in range
x=torch.where(x>1.0, 1.0, x)
x=torch.where(x<-1.0, -1.0, x)

if(self.use_permanent_parameters):
widths=self.rel_log_widths.to(x)
heights=self.rel_log_heights.to(x)
Expand Down Expand Up @@ -219,13 +223,21 @@ def _flow_mapping(self, inputs, extra_inputs=None):
)

log_det_new=log_det+log_det_update.sum(axis=-1)

## make sure we stay in range
x=torch.where(x>1.0, 1.0, x)
x=torch.where(x<-1.0, -1.0, x)

return x, log_det_new

def _inv_flow_mapping(self, inputs, extra_inputs=None):

[x, log_det]=inputs

## make sure we stay in range
x=torch.where(x>1.0, 1.0, x)
x=torch.where(x<-1.0, -1.0, x)

if(self.use_permanent_parameters):
widths=self.rel_log_widths.to(x)
heights=self.rel_log_heights.to(x)
Expand Down Expand Up @@ -296,7 +308,11 @@ def _inv_flow_mapping(self, inputs, extra_inputs=None):
)

log_det_new=log_det+log_det_update.sum(axis=-1)


## make sure we stay in range
x=torch.where(x>1.0, 1.0, x)
x=torch.where(x<-1.0, -1.0, x)

return x, log_det_new


Expand Down
13 changes: 7 additions & 6 deletions jammy_flows/layers/spheres/fvm_2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,8 @@ def _inv_flow_mapping(self, inputs, extra_inputs=None):

## go to cylinder from angle
prev_ret=torch.cos(x[:,:1])
fw_upd=torch.log(torch.sin(x[:,0]))
fw_upd=torch.log(torch.sin(sphere_base.return_safe_angle_within_pi(x[:,0])))

log_det=log_det+fw_upd

## intermediate [-1,1]->[-1,1] transformation
Expand Down Expand Up @@ -252,7 +253,7 @@ def _inv_flow_mapping(self, inputs, extra_inputs=None):
### we have to make the angles safe here...TODO: change to external transformation
ret=torch.where(ret<=-1.0, -1.0+1e-7, ret)
ret=torch.where(ret>=1.0, 1.0-1e-7, ret)

angle=x[:,1:]

if(self.boundary_cos_theta_identity_region==0.0):
Expand Down Expand Up @@ -307,8 +308,8 @@ def _inv_flow_mapping(self, inputs, extra_inputs=None):

log_det=torch.masked_scatter(input=log_det, mask=contained_mask, source=log_det_contained)

## go back to angle
ret=torch.acos(ret)
## go back to angle in a safe way
ret=sphere_base.return_safe_angle_within_pi(torch.acos(ret))
rev_upd=torch.log(torch.sin(ret))[:,0]
log_det=log_det-rev_upd

Expand Down Expand Up @@ -356,7 +357,7 @@ def _flow_mapping(self, inputs, extra_inputs=None, sf_extra=None):

## go to cylinder from angle
prev_ret=torch.cos(x[:,:1])
fw_upd=torch.log(torch.sin(x[:,0]))
fw_upd=torch.log(torch.sin(sphere_base.return_safe_angle_within_pi(x[:,0])))
log_det=log_det+fw_upd

angle=x[:,1:]
Expand Down Expand Up @@ -446,7 +447,7 @@ def _flow_mapping(self, inputs, extra_inputs=None, sf_extra=None):

## go back to angle
ret=torch.acos(ret)
rev_upd=torch.log(torch.sin(ret))[:,0]
rev_upd=torch.log(torch.sin(sphere_base.return_safe_angle_within_pi(ret)))[:,0]
log_det=log_det-rev_upd

ret=torch.cat([ret, angle], dim=1)
Expand Down
Loading

0 comments on commit 1e4bf50

Please sign in to comment.