Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
chenyangkang committed Oct 27, 2024
1 parent 7f0c050 commit 853ca39
Show file tree
Hide file tree
Showing 2 changed files with 138 additions and 159 deletions.
281 changes: 130 additions & 151 deletions stemflow/utils/plot_gif.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,23 @@
from functools import partial
from typing import Tuple, Union

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas
import pandas as pd
import matplotlib
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation, PillowWriter
from sklearn.preprocessing import LabelEncoder
from matplotlib.colors import Normalize


def make_sample_gif(
data: pd.core.frame.DataFrame,
data: pd.DataFrame,
file_path: str,
col: str = "abundance",
Spatio1: str = "longitude",
Spatio2: str = "latitude",
Temporal1: str = "DOY",
figsize: Tuple[Union[float, int]] = (18, 9),
xlims: Tuple[Union[float, int]] = None,
ylims: Tuple[Union[float, int]] = None,
figsize: Tuple[Union[float, int], Union[float, int]] = (18, 9),
xlims: Tuple[Union[float, int], Union[float, int]] = None,
ylims: Tuple[Union[float, int], Union[float, int]] = None,
grid: bool = True,
lng_size: int = 20,
lat_size: int = 20,
Expand All @@ -33,205 +31,186 @@ def make_sample_gif(
dpi: Union[float, int] = 300,
fps: int = 30,
cmap: str = "plasma",
verbose=1,
verbose: int = 1,
):
"""make GIF with plt.imshow function
A function to generate GIF file of spatio-temporal pattern.
"""
Create a GIF visualizing spatio-temporal data using plt.imshow.
Args:
data:
Input dataframe. Data should be trimmed to the target area/time slice before applying this function.
file_path:
Output GIF file path
col:
Column that contains the value to plot
Spatio1:
Spatio variable column 1
Spatio2:
Spatio variable column 2
Temporal1:
Temporal variable column 1
figsize:
Size of the figure. In matplotlib style.
xlims:
xlim of the figure. If None, default to xlim=(data[Spatio1].min(), data[Spatio1].max()). In matplotlib style.
ylims:
ylim of the figure. If None, default to ylim=(data[Spatio2].min(), data[Spatio2].max()). In matplotlib style.
grid:
Whether to add grids.
lng_size:
pixel count to aggregate at longitudinal direction. Larger means finer resolution.
lat_size:
pixel count to aggregate at latitudinal direction. Larger means finer resolution.
xtick_interval:
the size of x tick interval. If None, default to the closest 10-based value (0.001, 0.01, 1, ..., 1000, etc.).
ytick_interval:
the size of y tick interval.
log_scale:
log transform the target value or not.
vmin:
vmin of color map.
vmax:
vmax of color map. If None, set to the maximum of the plotted values (after the log transform when log_scale is True).
lightgrey_under:
Whether to set color as lightgrey where values are below vmin.
adder:
If log_scale==True, value = np.log(value + adder)
dpi:
dpi of the GIF.
fps:
speed of GIF playing (frames per second).
cmap:
color map
verbose:
Print current frame if verbose >= 1.
data (pd.DataFrame): Input DataFrame, pre-filtered for the target area/time.
file_path (str): Output GIF file path.
col (str): Column containing the values to plot.
Spatio1 (str): First spatial variable column.
Spatio2 (str): Second spatial variable column.
Temporal1 (str): Temporal variable column.
figsize (Tuple[Union[float, int], Union[float, int]]): Figure size.
xlims (Tuple[Union[float, int], Union[float, int]]): x-axis limits.
ylims (Tuple[Union[float, int], Union[float, int]]): y-axis limits.
grid (bool): Whether to display a grid.
lng_size (int): Number of longitudinal pixels (resolution).
lat_size (int): Number of latitudinal pixels (resolution).
xtick_interval (Union[float, int, None]): Interval between x-ticks.
ytick_interval (Union[float, int, None]): Interval between y-ticks.
log_scale (bool): Whether to apply a logarithmic scale to the data.
vmin (Union[float, int]): Minimum value for color scaling.
vmax (Union[float, int, None]): Maximum value for color scaling.
lightgrey_under (bool): Use light grey color for values below vmin.
adder (Union[int, float]): Value to add before log transformation.
dpi (Union[float, int]): Dots per inch for the output GIF.
fps (int): Frames per second for the GIF.
cmap (str): Colormap to use.
verbose (int): Verbosity level.
"""
#
# Sort data by the temporal variable
data = data.sort_values(by=Temporal1)
data.loc[:, "Temporal_indexer"] = LabelEncoder().fit_transform(data[Temporal1])
data["Temporal_indexer"], _ = pd.factorize(data[Temporal1])

#
# Set x and y limits if not provided
if xlims is None:
xlims = (data[Spatio1].min(), data[Spatio1].max())
if ylims is None:
ylims = (data[Spatio2].min(), data[Spatio2].max())

#
lng_gird = np.linspace(xlims[0], xlims[1], lng_size + 1)[1:]
lat_gird = np.linspace(ylims[0], ylims[1], lat_size + 1)[::-1][1:]
# Create spatial grids without slicing
lng_grid = np.linspace(xlims[0], xlims[1], lng_size + 1)
lat_grid = np.linspace(ylims[0], ylims[1], lat_size + 1)[::-1]

# xtick_interval & ytick_interval
closest_set = [10.0 ** (i) for i in np.arange(-15, 15, 1)] + [10.0 ** (i) / 2 for i in np.arange(-15, 15, 1)]
spatio1_base = (data[Spatio1].max() - data[Spatio1].min()) / 5
# Determine tick intervals
closest_set = (
[10.0 ** i for i in np.arange(-15, 15, 1)]
+ [10.0 ** i / 2 for i in np.arange(-15, 15, 1)]
)
spatio1_base = (xlims[1] - xlims[0]) / 5
if xtick_interval is None:
xtick_interval = min(closest_set, key=lambda x: np.inf if x - spatio1_base > 0 else abs(x - spatio1_base))
xtick_interval = min(
closest_set,
key=lambda x: np.inf if x - spatio1_base > 0 else abs(x - spatio1_base),
)
if xtick_interval >= 1:
xtick_interval = int(xtick_interval)

spatio2_base = (data[Spatio2].max() - data[Spatio2].min()) / 5
spatio2_base = (ylims[1] - ylims[0]) / 5
if ytick_interval is None:
ytick_interval = min(closest_set, key=lambda x: np.inf if x - spatio2_base > 0 else abs(x - spatio2_base))
ytick_interval = min(
closest_set,
key=lambda x: np.inf if x - spatio2_base > 0 else abs(x - spatio2_base),
)
if ytick_interval >= 1:
ytick_interval = int(ytick_interval)

# Utility function to round numbers to the same decimal places
def round_to_same_decimal_places(A, B):
# Convert B to string to count decimal places
str_B = str(B)
if "." in str_B:
decimal_places = len(str_B.split(".")[1])
else:
decimal_places = 0

# Round A to the same number of decimal places as B
rounded_A = round(A, decimal_places)

if abs(rounded_A) > 1000:
# Use format to convert to scientific notation with the same number of decimal places
formatted_A = format(rounded_A, f".{decimal_places}e")
else:
# Simply convert to string with the required precision
formatted_A = f"{rounded_A:.{decimal_places}f}"

return formatted_A

# Initialize figure and axes
fig, ax = plt.subplots(figsize=figsize)

def animate(i, norm, log_scale=log_scale):
# Set color scaling
if vmax is None:
vmax = (
np.max(np.log(data[col].values + adder))
if log_scale
else np.max(data[col].values)
)
norm = Normalize(vmin=vmin, vmax=vmax)

# Prepare colormap
my_cmap = plt.get_cmap(cmap)
if lightgrey_under:
my_cmap.set_under("lightgrey")

# Initialize the image to set up the colorbar
im = ax.imshow(
np.zeros((lat_size, lng_size)), norm=norm, cmap=my_cmap, animated=True
)
cbar = fig.colorbar(im, ax=ax, shrink=0.5)
cbar.ax.get_yaxis().labelpad = 15
cbar_label = f"log({col})" if log_scale else col
cbar.ax.set_ylabel(cbar_label, rotation=270)

# Precompute tick labels and positions
x_ticks = np.arange(xlims[0], xlims[1] + xtick_interval, xtick_interval)
x_tick_labels = [round_to_same_decimal_places(val, xtick_interval) for val in x_ticks]
# Find positions of x_ticks within lng_grid
x_tick_positions = np.searchsorted(lng_grid, x_ticks, side='left')
ax.set_xticks(x_tick_positions)
ax.set_xticklabels(x_tick_labels)

y_ticks = np.arange(ylims[0], ylims[1] + ytick_interval, ytick_interval)
y_tick_labels = [round_to_same_decimal_places(val, ytick_interval) for val in y_ticks]
# Since lat_grid is reversed, we need to account for that
y_tick_positions = lat_size - np.searchsorted(lat_grid[::-1], y_ticks, side='left') - 1
ax.set_yticks(y_tick_positions)
ax.set_yticklabels(y_tick_labels)

# Animation function
def animate(i):
if verbose >= 1:
print(i, end=".")
print(f"Processing frame {i+1}/{frames}", end="\r")

ax.clear()
sub = data[data["Temporal_indexer"] == i].copy()
temporal_value = np.array(sub[Temporal1].values)[0]

g1 = np.digitize(sub[Spatio1], lng_gird, right=True)
g1 = np.where(g1 >= lng_size, lng_size - 1, g1).astype("int")
sub = data[data["Temporal_indexer"] == i]
if sub.empty:
return []

g2 = np.digitize(sub[Spatio2], lat_gird, right=True)
g2 = np.where(g2 >= lng_size, lng_size - 1, g2).astype("int")
temporal_value = sub[Temporal1].iloc[0]

sub.loc[:, f"{Spatio1}_grid"] = g1
sub.loc[:, f"{Spatio2}_grid"] = g2
sub = sub[(sub[f"{Spatio1}_grid"] <= lng_size - 1) & (sub[f"{Spatio2}_grid"] <= lat_size - 1)]
# Correct digitization with adjusted bins
g1 = np.digitize(sub[Spatio1], lng_grid, right=False) - 1
g1 = np.clip(g1, 0, lng_size - 1).astype(int)

sub = sub.groupby([f"{Spatio1}_grid", f"{Spatio2}_grid"])[[col]].mean().reset_index(drop=False)
g2 = np.digitize(sub[Spatio2], lat_grid, right=False) - 1
g2 = np.clip(g2, 0, lat_size - 1).astype(int)

im = np.array([np.nan] * lat_size * lng_size).reshape(lat_size, lng_size)
sub[f"{Spatio1}_grid"] = g1
sub[f"{Spatio2}_grid"] = g2

if log_scale:
im[sub[f"{Spatio2}_grid"].values, sub[f"{Spatio1}_grid"].values] = np.log(sub[col] + adder)
else:
im[sub[f"{Spatio2}_grid"].values, sub[f"{Spatio1}_grid"].values] = sub[col]

my_cmap = matplotlib.colormaps.get_cmap(cmap)
grouped = sub.groupby(
[f"{Spatio2}_grid", f"{Spatio1}_grid"]
)[col].mean()

if lightgrey_under:
my_cmap.set_under("lightgrey")

scat1 = ax.imshow(im, norm=norm, cmap=my_cmap)
im_data = np.full((lat_size, lng_size), np.nan)
indices = (grouped.index.get_level_values(0), grouped.index.get_level_values(1))
values = np.log(grouped.values + adder) if log_scale else grouped.values
im_data[indices] = values

im = ax.imshow(im_data, norm=norm, cmap=my_cmap, animated=True)
ax.set_title(f"{Temporal1}: {temporal_value}", fontsize=30)

# Reset ticks
xtick_labels = np.arange(xlims[0], xlims[1], xtick_interval)
xtick_labels = [round_to_same_decimal_places(i, xtick_interval) for i in xtick_labels]
xtick_positions = np.linspace(0, im.shape[1] - 1, len(xtick_labels))
ax.set_xticks(xtick_positions)
ax.set_xticklabels(xtick_labels)

ytick_labels = np.arange(ylims[0], ylims[1], ytick_interval)
ytick_labels = [round_to_same_decimal_places(i, ytick_interval) for i in ytick_labels]
ytick_positions = np.linspace(im.shape[0] - 1, 0, len(ytick_labels))
ax.set_yticks(ytick_positions)
ax.set_yticklabels(ytick_labels)
# Re-apply ticks and grid in each frame
ax.set_xticks(x_tick_positions)
ax.set_xticklabels(x_tick_labels)
ax.set_yticks(y_tick_positions)
ax.set_yticklabels(y_tick_labels)

# Grid?
if grid:
plt.grid(alpha=0.5)
plt.tight_layout()

return (scat1,)

# scale the color norm
if vmax is None:
if log_scale:
vmax = np.max(np.log(data[col].values + adder))
else:
vmax = np.max(data[col].values)
ax.grid(alpha=0.5)

norm = matplotlib.colors.Normalize(vmin=vmin, vmax=vmax)
return [im]

# for getting the color bar
scat1 = partial(animate, norm=norm, log_scale=log_scale)(0)
frames = data["Temporal_indexer"].nunique()

cbar = fig.colorbar(scat1[0], norm=norm, shrink=0.5)
cbar.ax.get_yaxis().labelpad = 15
if log_scale:
cbar.ax.set_ylabel(f"log({col})", rotation=270)
else:
cbar.ax.set_ylabel(f"{col}", rotation=270)

if grid:
plt.grid(alpha=0.5)
plt.tight_layout()

partial_animate = partial(animate, norm=norm, log_scale=log_scale)

frames = len(data["Temporal_indexer"].unique())

# animate!
# Create animation
ani = FuncAnimation(
fig,
partial_animate,
interval=40,
animate,
frames=frames,
interval=1000 / fps,
blit=True,
repeat=True,
frames=frames,
)

ani.save(file_path, dpi=dpi, writer=PillowWriter(fps=fps))
plt.close()
print()
print("Finish!")
if verbose >= 1:
print("\nAnimation saved successfully!")
16 changes: 8 additions & 8 deletions tests/test_random_state_reproducibility.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,10 @@ def test_random_state_reproducibility():
ensemble_df3['stixel_checklist_count'].values)

# 1 and 2 (with the different random state)
assert np.sum(ensemble_df1['calibration_point_x_jitter'].values -
ensemble_df2['calibration_point_x_jitter'].values) != 0
assert np.sum(ensemble_df1['stixel_checklist_count'].values -
ensemble_df2['stixel_checklist_count'].values) != 0
assert np.sum(ensemble_df1['calibration_point_x_jitter'].values[0] - \
ensemble_df2['calibration_point_x_jitter'].values[0]) != 0
assert np.sum(ensemble_df1['stixel_checklist_count'].values[0] - \
ensemble_df2['stixel_checklist_count'].values[0]) != 0


def test_random_state_reproducibility_completely_random_rotation_angle():
Expand All @@ -59,8 +59,8 @@ def test_random_state_reproducibility_completely_random_rotation_angle():
ensemble_df3['stixel_checklist_count'].values)

# 1 and 2 (with the different random state)
assert np.sum(ensemble_df1['calibration_point_x_jitter'].values -
ensemble_df2['calibration_point_x_jitter'].values) != 0
assert np.sum(ensemble_df1['stixel_checklist_count'].values -
ensemble_df2['stixel_checklist_count'].values) != 0
assert np.sum(ensemble_df1['calibration_point_x_jitter'].values[0] -
ensemble_df2['calibration_point_x_jitter'].values[0]) != 0
assert np.sum(ensemble_df1['stixel_checklist_count'].values[0] -
ensemble_df2['stixel_checklist_count'].values[0]) != 0

0 comments on commit 853ca39

Please sign in to comment.