Skip to content

Commit

Permalink
add animation of optimization process
Browse files Browse the repository at this point in the history
  • Loading branch information
YifanLu2000 committed Oct 25, 2024
1 parent 3182a3d commit e1df648
Showing 1 changed file with 249 additions and 0 deletions.
249 changes: 249 additions & 0 deletions spateo/plotting/static/align.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import List, Optional, Union

import matplotlib as mpl
import matplotlib.animation as animation
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
Expand Down Expand Up @@ -671,6 +672,254 @@ def overlay_slices_2d(
)


def optimization_animation(
aligned_slices: List[AnnData],
label_key: Optional[str] = None,
spatial_key: str = "spatial",
key_added: str = "align_spatial",
iter_key_added: Optional[str] = "iter_spatial",
filename: Optional[str] = "Visualization2D",
fps: int = 10,
stepsize: int = 10,
cmap="Set1",
palette: Optional[dict] = None,
point_size: Optional[float] = None,
n_sampling: int = -1,
):
assert len(aligned_slices) == 2, "Input aligned_slices must be 2 slices!"

if label_key is not None:
if [label_key in s.obs.keys() for s in aligned_slices]:
labels = [s.obs[label_key] for s in aligned_slices]
label1 = aligned_slices[0].obs[label_key]
label2 = aligned_slices[1].obs[label_key]
else:
label1 = np.zeros((aligned_slices[0].shape[0],), dtype=np.int32)
label2 = np.ones((aligned_slices[1].shape[0],), dtype=np.int32)

if n_sampling > 0:
sampling_idx1 = (
np.random.choice(aligned_slices[0].shape[0], n_sampling, replace=False)
if n_sampling < aligned_slices[0].shape[0]
else np.arange(aligned_slices[0].shape[0])
)
sampling_idx2 = (
np.random.choice(aligned_slices[1].shape[0], n_sampling, replace=False)
if n_sampling < aligned_slices[1].shape[0]
else np.arange(aligned_slices[1].shape[0])
)
else:
sampling_idx1 = np.arange(aligned_slices[0].shape[0])
sampling_idx2 = np.arange(aligned_slices[1].shape[0])

# generate palette
if (palette is None) and (label_key is not None):
palette = _agenerate_palette(*labels, cmap=cmap)

if label_key is not None:
label1_colors = [palette[cat] for cat in label1[sampling_idx1]]
label2_colors = [palette[cat] for cat in label2[sampling_idx2]]
else:
label1_colors = ["#e41a1c" for cat in label1[sampling_idx1]]
label2_colors = ["#377eb8" for cat in label2[sampling_idx2]]

if point_size is None:
point_size = 500 * 10 / (len(sampling_idx1) + len(sampling_idx2))

coordsB = aligned_slices[0].obsm[spatial_key]
plt.ioff()
fig, ax = plt.subplots(figsize=(10, 5))
ax.axis("equal")
ax.set_xticks([])
ax.set_yticks([])
artists = []
iter_dict = aligned_slices[1].uns[iter_key_added]
iter = len(iter_dict[key_added])
iteration = range(0, iter, stepsize)
ax.scatter(
coordsB[sampling_idx1, 0], coordsB[sampling_idx1, 1], marker="o", s=point_size, c=label1_colors, edgecolors=None
)
for i in iteration:
frame = ax.scatter(
iter_dict[key_added][i][sampling_idx2, 0],
iter_dict[key_added][i][sampling_idx2, 1],
marker="o",
s=point_size,
c=label2_colors,
edgecolors=None,
)
title_text = "Iter: {}, sigma2: {:.3f}.".format(i, iter_dict["sigma2"][i])
tit = ax.text(0.5, 1.02, title_text, ha="center", va="bottom", size=16, weight="bold", transform=ax.transAxes)
artists.append([frame, tit])
ani = animation.ArtistAnimation(fig=fig, artists=artists, interval=4, blit=False)
ani.save(filename + ".gif", fps=fps, dpi=100)
plt.close()


def plot_deformation_grid(
adata,
spatial_key,
origin_spatial_key,
label_key,
predict_func,
ax,
point_size,
grid_num=10,
line_width=0.5,
grid_color="black",
expand_scale=0.1,
palette=None,
title="",
legend=True,
fontsize=8,
fill=False,
):
x = adata.obsm[spatial_key][:, 0]
y = adata.obsm[spatial_key][:, 1]
origin_x = adata.obsm[origin_spatial_key][:, 0]
origin_y = adata.obsm[origin_spatial_key][:, 1]
label = adata.obs[label_key]
if palette is None:
n_colors = len(np.unique(label))
palette = sns.color_palette("Paired", n_colors)

# plot deformation grid

# Generate the grid points
horizontal_lines_data = pd.DataFrame(columns=["x", "y", "type"])
vertical_lines_data = pd.DataFrame(columns=["x", "y", "type"])
x_min, x_max = origin_x.min(), origin_x.max()
y_min, y_max = origin_y.min(), origin_y.max()
# expand the min max
x_length = x_max - x_min
y_length = y_max - y_min
x_min -= x_length * expand_scale
x_max += x_length * expand_scale
y_min -= y_length * expand_scale
y_max += y_length * expand_scale
horizontal_values = np.linspace(y_min, y_max, grid_num)
vertical_values = np.linspace(x_min, x_max, grid_num)
horizontal_lines, vertical_lines = [], []

if fill:
for i, vertical_value in enumerate(vertical_values):
vertical_line = np.linspace(y_min, y_max, 1000)[:, np.newaxis]
vertical_line = np.concatenate([np.ones_like(vertical_line) * vertical_value, vertical_line], axis=1)
deformed_vertical_line = predict_func(vertical_line)
if (i == 0) or (i == len(vertical_values) - 1):
if i == 0:
edge_points_vert_up = deformed_vertical_line
if i == len(vertical_values) - 1:
edge_points_vert_down = deformed_vertical_line
else:
continue

for i, horizontal_value in enumerate(horizontal_values):
horizontal_line = np.linspace(x_min, x_max, 1000)[:, np.newaxis]
horizontal_line = np.concatenate(
[horizontal_line, np.ones_like(horizontal_line) * horizontal_value], axis=1
)
deformed_horizontal_line = predict_func(horizontal_line)
if (i == 0) or (i == len(horizontal_values) - 1):
if i == 0:
edge_points_hor_right = deformed_horizontal_line
if i == len(vertical_values) - 1:
edge_points_hor_left = deformed_horizontal_line
else:
continue
edge_x = [
edge_points_vert_up[:, 0],
edge_points_hor_right[:, 0],
np.flip(edge_points_vert_down[:, 0]),
np.flip(edge_points_hor_left[:, 0]),
]
edge_y = [
edge_points_vert_up[:, 1],
edge_points_hor_right[:, 1],
np.flip(edge_points_vert_down[:, 1]),
np.flip(edge_points_hor_left[:, 1]),
]
ax.fill(edge_x, edge_y, color=np.array([249, 249, 249]) / 255, alpha=0.5)
sns.scatterplot(x=x, y=y, hue=label, palette=palette, ax=ax, s=point_size, legend=legend)

# Plot horizontal lines
for i, vertical_value in enumerate(vertical_values):
vertical_line = np.linspace(y_min, y_max, 1000)[:, np.newaxis]
vertical_line = np.concatenate([np.ones_like(vertical_line) * vertical_value, vertical_line], axis=1)
deformed_vertical_line = predict_func(vertical_line)
if (i == 0) or (i == len(vertical_values) - 1):
# ax.plot(deformed_vertical_line[:,0], deformed_vertical_line[:,1], color=np.array([233,72,23])/255, linewidth=line_width, alpha=1)
continue
else:
ax.plot(
deformed_vertical_line[:, 0],
deformed_vertical_line[:, 1],
color=grid_color,
linewidth=line_width,
alpha=0.8,
)

# Plot vertical line
for i, horizontal_value in enumerate(horizontal_values):
horizontal_line = np.linspace(x_min, x_max, 1000)[:, np.newaxis]
horizontal_line = np.concatenate([horizontal_line, np.ones_like(horizontal_line) * horizontal_value], axis=1)
deformed_horizontal_line = predict_func(horizontal_line)
if (i == 0) or (i == len(horizontal_values) - 1):
# ax.plot(deformed_horizontal_line[:,0], deformed_horizontal_line[:,1], color=np.array([233,72,23])/255, linewidth=line_width, alpha=1)
continue
else:
ax.plot(
deformed_horizontal_line[:, 0],
deformed_horizontal_line[:, 1],
color=grid_color,
linewidth=line_width,
alpha=0.8,
)

# edge_points_vert_up = []
# edge_points_vert_down = []

for i, vertical_value in enumerate(vertical_values):
vertical_line = np.linspace(y_min, y_max, 1000)[:, np.newaxis]
vertical_line = np.concatenate([np.ones_like(vertical_line) * vertical_value, vertical_line], axis=1)
deformed_vertical_line = predict_func(vertical_line)
if (i == 0) or (i == len(vertical_values) - 1):
ax.plot(
deformed_vertical_line[:, 0],
deformed_vertical_line[:, 1],
color=np.array([91, 139, 200]) / 255,
linewidth=1.5 * line_width,
alpha=1,
)
else:
continue
# ax.plot(deformed_vertical_line[:,0], deformed_vertical_line[:,1], color=grid_color, linewidth=line_width, alpha=1)

for i, horizontal_value in enumerate(horizontal_values):
horizontal_line = np.linspace(x_min, x_max, 1000)[:, np.newaxis]
horizontal_line = np.concatenate([horizontal_line, np.ones_like(horizontal_line) * horizontal_value], axis=1)
deformed_horizontal_line = predict_func(horizontal_line)
if (i == 0) or (i == len(horizontal_values) - 1):
ax.plot(
deformed_horizontal_line[:, 0],
deformed_horizontal_line[:, 1],
color=np.array([91, 139, 200]) / 255,
linewidth=1.5 * line_width,
alpha=1,
)
else:
# ax.plot(deformed_horizontal_line[:,0], deformed_horizontal_line[:,1], color=grid_color, linewidth=line_width, alpha=1)
continue

if legend:
ax.legend().remove()
ax.set_facecolor("white")
ax.axis("off")
if title != "":
ax.set_title(title + " mapping", fontsize=fontsize)
ax.set_aspect("equal")


# def plot_align_correspondence_2d(
# slices: List[AnnData],
# mapping: List[np.ndarray],
Expand Down

0 comments on commit e1df648

Please sign in to comment.