Skip to content

Commit

Permalink
Updated all submodules
Browse files Browse the repository at this point in the history
  • Loading branch information
YifanLu2000 committed Aug 22, 2024
1 parent 508dc2d commit 25bf43f
Show file tree
Hide file tree
Showing 4 changed files with 67 additions and 3 deletions.
64 changes: 64 additions & 0 deletions spateo/alignment/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,6 +334,11 @@ def solve_RT_by_correspondence(
return R, t


##############
# Simulation #
##############


def split_slice(
adata,
spatial_key,
Expand All @@ -350,3 +355,62 @@ def split_slice(
sorted_adata.obs["slice"] = slice_id
split_adata.append(sorted_adata)
return split_adata[:split_num]


def tps_deformation(
adata,
spatial_key,
key_added,
grid_num=2,
tps_noise_scale=25,
add_corner_points=True,
alpha=0.1,
inplace=True,
):
from tps import ThinPlateSpline

spatial = adata.obsm[spatial_key]
# get min max
x_min, x_max = np.min(spatial[:, 0]), np.max(spatial[:, 0])
y_min, y_max = np.min(spatial[:, 1]), np.max(spatial[:, 1])

# define the length of grid
grid_size_x = (x_max - x_min) / grid_num
grid_size_y = (y_max - y_min) / grid_num
# generate the grid
x_grid = np.linspace(x_min, x_max, grid_num + 1)[:-1] + grid_size_x / 2
y_grid = np.linspace(y_min, y_max, grid_num + 1)[:-1] + grid_size_y / 2
xx, yy = np.meshgrid(x_grid, y_grid)
# generate control points
src_points = []
dst_points = []
for i in range(xx.shape[0]):
for j in range(xx.shape[1]):
x_center, y_center = xx[i, j], yy[i, j]
x = x_center
y = y_center
src_points.append(np.column_stack([x, y]))
dst_points.append(
src_points[-1] + np.random.normal(scale=(grid_size_x + grid_size_y) * tps_noise_scale / 2, size=(1, 2))
)
src_points = np.concatenate(src_points, axis=0)
dst_points = np.concatenate(dst_points, axis=0)
if add_corner_points:
src_points = np.concatenate(
[np.array([[x_min, y_min], [x_min, y_max], [x_max, y_max], [x_max, y_min]]), src_points], 0
)
dst_points = np.concatenate(
[np.array([[x_min, y_min], [x_min, y_max], [x_max, y_max], [x_max, y_min]]), dst_points], 0
)
# calculate the TPS deformation
tps = ThinPlateSpline(alpha=alpha) # Regularization
tps.fit(src_points, dst_points)
# perform tps deformation
tps_spatial = tps.transform(spatial)
if inplace:
adata.obsm[key_added] = tps_spatial
return lambda x: tps.transform(x)
else:
adata_tps = adata.copy()
adata_tps.obsm[key_added] = tps_spatial
return adata_tps, lambda x: tps.transform(x)

0 comments on commit 25bf43f

Please sign in to comment.