Skip to content

Commit

Permalink
Fix the bugs with cpu + sparse calculation
Browse files Browse the repository at this point in the history
  • Loading branch information
YifanLu2000 committed Oct 25, 2024
1 parent e1df648 commit 45aa2c1
Show file tree
Hide file tree
Showing 7 changed files with 289 additions and 61 deletions.
2 changes: 2 additions & 0 deletions spateo/align.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,13 @@
generate_label_transfer_prior,
get_optimal_mapping_relationship,
grid_deformation,
group_pca,
morpho_align,
morpho_align_ref,
paste_align,
paste_align_ref,
paste_transform,
solve_RT_by_correspondence,
split_slice,
tps_deformation,
)
2 changes: 2 additions & 0 deletions spateo/alignment/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,10 @@
generate_label_transfer_prior,
get_labels_based_on_coords,
get_optimal_mapping_relationship,
group_pca,
mapping_aligned_coords,
mapping_center_coords,
solve_RT_by_correspondence,
split_slice,
tps_deformation,
)
23 changes: 20 additions & 3 deletions spateo/alignment/methods/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import numpy as np
import scipy
import scipy.linalg
import scipy.sparse as sp
import scipy.special as special
from scipy.sparse import coo_matrix, csr_matrix, issparse

Expand Down Expand Up @@ -1044,10 +1045,18 @@ def log(self, a):
return np.log(a)

def concatenate(self, arrays, axis=0):
    """Join a sequence of arrays along an existing axis.

    All inputs must be of one kind: either all scipy sparse matrices or
    all ``numpy.ndarray`` objects.

    Args:
        arrays: Iterable of arrays to join. It is materialised once so that
            generators and other one-shot iterables are supported.
        axis: Axis along which to join. For sparse inputs only the two
            matrix axes are meaningful: ``0``/``-2`` stacks rows and
            ``1``/``-1`` stacks columns.

    Returns:
        A sparse matrix when every input is sparse, otherwise the result of
        ``np.concatenate``.

    Raises:
        ValueError: If the inputs mix sparse and dense arrays, or if ``axis``
            is not a valid axis for 2-D sparse matrices.
    """
    # Materialise once: the type checks below iterate the inputs more than
    # once, which would exhaust a generator before the actual join.
    arrays = list(arrays)

    # `all()` is True for an empty sequence, so require at least one element
    # before choosing the sparse path; empty input falls through to numpy,
    # which reports the problem itself.
    if arrays and all(issparse(arr) for arr in arrays):
        if axis in (0, -2):
            return sp.vstack(arrays)
        if axis in (1, -1):
            return sp.hstack(arrays)
        raise ValueError(f"axis {axis} is out of bounds for 2-D sparse matrices")

    if all(isinstance(arr, np.ndarray) for arr in arrays):
        return np.concatenate(arrays, axis)

    raise ValueError("All arrays should be of the same type")

# Sum the elements of `a`, optionally along `axis`.
# Sparse inputs (per `issparse`) are delegated to the matrix's own `.sum(axis=axis)`;
# note that `keepdims` is NOT forwarded on that path. Dense inputs use `np.sum`.
# NOTE(review): the bare `return np.sum(...)` directly under the signature appears to be
# the pre-change line shown by this commit view, superseded by the branch below -- confirm.
def sum(self, a, axis=None, keepdims=False):
return np.sum(a, axis, keepdims=keepdims)
if issparse(a):
# Sparse path: result type/shape is whatever the sparse object's `.sum` returns.
return a.sum(axis=axis)
else:
# Dense path: honours both `axis` and `keepdims`.
return np.sum(a, axis, keepdims=keepdims)

# Return `np.arange(start, stop, step)`.
# Note the parameter order: `stop` is the first positional argument, ahead of `start`.
# `type_as` is accepted but not used in this body.
def arange(self, stop, start=0, step=1, type_as=None):
return np.arange(start, stop, step)
Expand All @@ -1071,7 +1080,15 @@ def power(self, a, exponents):
return np.power(a, exponents)

def dot(self, a, b):
    """Matrix product of ``a`` and ``b`` for dense and/or sparse operands.

    Args:
        a: Left operand, a ``numpy.ndarray`` or a scipy sparse matrix.
        b: Right operand, a ``numpy.ndarray`` or a scipy sparse matrix.

    Returns:
        The product ``a @ b``. The result is sparse when both operands are
        sparse; otherwise it is whatever the sparse ``dot`` / ``np.dot``
        call returns for the given operands.
    """
    # A sparse left operand multiplies both sparse and dense right operands
    # through its own `dot`, so one branch covers both cases.
    if sp.issparse(a):
        return a.dot(b)

    # Dense @ sparse: use (a @ b) == (b.T @ a.T).T so the product is still
    # driven by the sparse operand's `dot`.
    if sp.issparse(b):
        return b.T.dot(a.T).T

    return np.dot(a, b)

# Return the product of the elements of `a` along `axis` via `np.prod`.
# The default here is `axis=0`, which is passed through to `np.prod` explicitly.
def prod(self, a, axis=0):
return np.prod(a, axis=axis)
Expand Down
115 changes: 65 additions & 50 deletions spateo/alignment/methods/morpho_class.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import random

import networkx
import numpy as np
import ot
import scipy.sparse as sp
import torch
from anndata import AnnData

Expand Down Expand Up @@ -36,6 +38,8 @@
check_rep_layer,
check_spatial_coords,
con_K,
con_K_graph,
construct_knn_graph,
filter_common_genes,
get_P_core,
get_rep,
Expand Down Expand Up @@ -94,7 +98,7 @@ class Morpho_pairwise:
K (Union[int, float]): Number of sparse inducing points used for Nyström approximation for the kernel. Default is 15.
kernel_type (str): Type of kernel used. Default is "euc".
sigma2_init_scale (Optional[Union[int, float]]): Initial value for the spatial dispersion level. Default is 0.1.
partial_robust_level (float): Robust level of partial alignment. Default is 25.
partial_robust_level (float): Robust level of partial alignment. Default is 10.
normalize_c (bool): Whether to normalize spatial coordinates. Default is True.
normalize_g (bool): Whether to normalize gene expression. Default is True.
Expand Down Expand Up @@ -138,6 +142,8 @@ def __init__(
beta: Union[int, float] = 0.01,
K: Union[int, float] = 15,
kernel_type: str = "euc",
graph: Optional[networkx.Graph] = None,
graph_knn: int = 10,
sigma2_init_scale: Optional[Union[int, float]] = 0.1,
sigma2_end: Optional[Union[int, float]] = None,
gamma_a: float = 1.0,
Expand Down Expand Up @@ -193,6 +199,8 @@ def __init__(
self.K = K
self.kernel_type = kernel_type
self.kernel_bandwidth = beta
self.graph = graph
self.graph_knn = graph_knn
self.sigma2_init_scale = sigma2_init_scale
self.sigma2_end = sigma2_end
self.partial_robust_level = partial_robust_level
Expand Down Expand Up @@ -840,13 +848,12 @@ def _construct_kernel(
else None
)
elif self.kernel_type == "geodist":
pass
# TODO: finish this
# if self.graph is None:
# self.graph = _construct_graph(self.coordsA, self.knn)
# self.GammaSparse = con_K_graph(self.graph, inducing_variables_idx, inducing_variables_idx, beta=self.kernel_bandwidth)
# self.U = con_K_graph(self.graph, np.arange(self.NA), inducing_variables_idx, beta=self.kernel_bandwidth)

if self.graph is None:
self.graph = construct_knn_graph(self.coordsA, self.graph_knn)
self.U = con_K_graph(self.graph, inducing_variables_idx, beta=self.kernel_bandwidth)
self.GammaSparse = self.U[inducing_variables_idx, :]
self.U_I = None # currently not supported, as the guidance points are not in the graph
else:
raise NotImplementedError(f"Kernel type '{self.kernel_type}' is not implemented.")

Expand Down Expand Up @@ -1163,10 +1170,17 @@ def _update_assignment_P(
self.Sp_sigma2 = Sp_sigma2

if self.sparse_calculation_mode:
self.K_NA = self.K_NA.to_dense()
self.K_NB = self.K_NB.to_dense()
self.K_NA_spatial = self.K_NA_spatial.to_dense()
self.K_NA_sigma2 = self.K_NA_sigma2.to_dense()
if nx_torch(self.nx):
self.K_NA = self.K_NA.to_dense()
self.K_NB = self.K_NB.to_dense()
self.K_NA_spatial = self.K_NA_spatial.to_dense()
self.K_NA_sigma2 = self.K_NA_sigma2.to_dense()

else:
self.K_NA = self.K_NA.A.squeeze(-1)
self.K_NB = self.K_NB.A.squeeze(0)
self.K_NA_spatial = self.K_NA_spatial
self.K_NA_sigma2 = self.K_NA_sigma2

self.sigma2_related = sigma2_related / (self.Dim * self.Sp_sigma2)

Expand Down Expand Up @@ -1234,38 +1248,38 @@ def _update_nonrigid(
"""

SigmaInv = self.sigma2 * self.lambdaVF * self.GammaSparse + _dot(self.nx)(
SigmaInv = self.sigma2 * self.lambdaVF * self.GammaSparse + self.nx.dot(
self.U.T, self.nx.einsum("ij,i->ij", self.U, self.K_NA)
)
if self.SVI_mode:
PXB_term = _dot(self.nx)(self.P, self.coordsB[self.batch_idx, :]) - self.nx.einsum(
PXB_term = self.nx.dot(self.P, self.coordsB[self.batch_idx, :]) - self.nx.einsum(
"ij,i->ij", self.RnA, self.K_NA
)
self.SigmaInv = self.step_size * SigmaInv + (1 - self.step_size) * self.SigmaInv
self.PXB_term = self.step_size * PXB_term + (1 - self.step_size) * self.PXB_term
else:
self.PXB_term = _dot(self.nx)(self.P, self.coordsB) - self.nx.einsum("ij,i->ij", self.RnA, self.K_NA)
self.PXB_term = self.nx.dot(self.P, self.coordsB) - self.nx.einsum("ij,i->ij", self.RnA, self.K_NA)
self.SigmaInv = SigmaInv

UPXB_term = _dot(self.nx)(self.U.T, self.PXB_term)
UPXB_term = self.nx.dot(self.U.T, self.PXB_term)

# TODO: can we store these kernel multiple results? They are fixed
if self.guidance and ((self.guidance_effect == "nonrigid") or (self.guidance_effect == "both")):
self.SigmaInv += (self.sigma2 * self.guidance_weight * self.Sp / self.U_I.shape[0]) * _dot(self.nx)(
self.SigmaInv += (self.sigma2 * self.guidance_weight * self.Sp / self.U_I.shape[0]) * self.nx.dot(
self.U_I.T, self.U_I
)
UPXB_term += (self.sigma2 * self.guidance_weight * self.Sp / self.U_I.shape[0]) * _dot(self.nx)(
UPXB_term += (self.sigma2 * self.guidance_weight * self.Sp / self.U_I.shape[0]) * self.nx.dot(
self.U_I.T, self.X_BI - self.R_AI
)

Sigma = _pinv(self.nx)(self.SigmaInv)
self.Coff = _dot(self.nx)(Sigma, UPXB_term)
self.Coff = self.nx.dot(Sigma, UPXB_term)

self.VnA = _dot(self.nx)(self.U, self.Coff)
self.VnA = self.nx.dot(self.U, self.Coff)
if self.guidance and ((self.guidance_effect == "nonrigid") or (self.guidance_effect == "both")):
self.V_AI = _dot(self.nx)(self.U_I, self.Coff)
self.V_AI = self.nx.dot(self.U_I, self.Coff)
self.SigmaDiag = self.sigma2 * self.nx.einsum(
"ij->i", self.nx.einsum("ij,ji->ij", self.U, _dot(self.nx)(Sigma, self.U.T))
"ij->i", self.nx.einsum("ij,ji->ij", self.U, self.nx.dot(Sigma, self.U.T))
)

def _update_rigid(
Expand All @@ -1281,11 +1295,11 @@ def _update_rigid(
"""

PXA, PVA, PXB = (
_dot(self.nx)(self.K_NA, self.coordsA)[None, :],
_dot(self.nx)(self.K_NA, self.VnA)[None, :],
_dot(self.nx)(self.K_NB, self.coordsB[self.batch_idx, :])[None, :]
self.nx.dot(self.K_NA, self.coordsA)[None, :],
self.nx.dot(self.K_NA, self.VnA)[None, :],
self.nx.dot(self.K_NB, self.coordsB[self.batch_idx, :])[None, :]
if self.SVI_mode
else _dot(self.nx)(self.K_NB, self.coordsB)[None, :],
else self.nx.dot(self.K_NB, self.coordsB)[None, :],
)
# solve rotation using SVD formula
mu_XB, mu_XA, mu_Vn = PXB, PXA, PVA
Expand All @@ -1297,10 +1311,10 @@ def _update_rigid(
mu_X_deno += (self.sigma2 * self.guidance_weight * self.Sp / self.X_BI.shape[0]) * self.X_BI.shape[0]
mu_Vn_deno += (self.sigma2 * self.guidance_weight * self.Sp / self.X_BI.shape[0]) * self.X_BI.shape[0]
if self.nn_init:
mu_XB += (self.sigma2 * self.nn_init_weight * self.Sp / self.nx.sum(self.inlier_P)) * _dot(self.nx)(
mu_XB += (self.sigma2 * self.nn_init_weight * self.Sp / self.nx.sum(self.inlier_P)) * self.nx.dot(
self.inlier_P.T, self.inlier_B
)
mu_XA += (self.sigma2 * self.nn_init_weight * self.Sp / self.nx.sum(self.inlier_P)) * _dot(self.nx)(
mu_XA += (self.sigma2 * self.nn_init_weight * self.Sp / self.nx.sum(self.inlier_P)) * self.nx.dot(
self.inlier_P.T, self.inlier_A
)
mu_X_deno += (self.sigma2 * self.nn_init_weight * self.Sp / self.nx.sum(self.inlier_P)) * self.nx.sum(
Expand All @@ -1323,44 +1337,43 @@ def _update_rigid(
if self.nn_init:
inlier_A_hat = self.inlier_A - mu_XA
inlier_B_hat = self.inlier_B - mu_XB

A = -(
_dot(self.nx)(XA_hat.T, self.nx.einsum("ij,i->ij", VnA_hat, self.K_NA))
- _dot(self.nx)(_dot(self.nx)(XA_hat.T, self.P), XB_hat)
self.nx.dot(XA_hat.T, self.nx.einsum("ij,i->ij", VnA_hat, self.K_NA))
- self.nx.dot(self.nx.dot(XA_hat.T, self.P), XB_hat)
).T

if self.guidance_effect in ("rigid", "both"):
A -= (self.sigma2 * self.guidance_weight * self.Sp / self.X_BI.shape[0]) * _dot(self.nx)(
A -= (self.sigma2 * self.guidance_weight * self.Sp / self.X_BI.shape[0]) * self.nx.dot(
X_AI_hat.T, V_AI_hat - X_BI_hat
).T

if self.nn_init:
A -= (self.sigma2 * self.nn_init_weight * self.Sp / self.nx.sum(self.inlier_P)) * _dot(self.nx)(
A -= (self.sigma2 * self.nn_init_weight * self.Sp / self.nx.sum(self.inlier_P)) * self.nx.dot(
(inlier_A_hat * self.inlier_P).T, -inlier_B_hat
).T

svdU, svdS, svdV = _linalg(self.nx).svd(A)
self.C[-1, -1] = _linalg(self.nx).det(_dot(self.nx)(svdU, svdV))
self.C[-1, -1] = _linalg(self.nx).det(self.nx.dot(svdU, svdV))

R = _dot(self.nx)(_dot(self.nx)(svdU, self.C), svdV)
R = self.nx.dot(self.nx.dot(svdU, self.C), svdV)
if self.SVI_mode and self.step_size < 1:
self.R = self.step_size * R + (1 - self.step_size) * self.R
else:
self.R = R

# solve translation using SVD formula
t_numerator = PXB - PVA - _dot(self.nx)(PXA, self.R.T)
t_numerator = PXB - PVA - self.nx.dot(PXA, self.R.T)
t_deno = _copy(self.nx, self.Sp)

if self.guidance and (self.guidance_effect in ("rigid", "both")):
t_numerator += (self.sigma2 * self.guidance_weight * self.Sp / self.X_BI.shape[0]) * self.nx.sum(
self.X_BI - self.V_AI - _dot(self.nx)(self.X_AI, self.R.T), axis=0
self.X_BI - self.V_AI - self.nx.dot(self.X_AI, self.R.T), axis=0
)
t_deno += (self.sigma2 * self.guidance_weight * self.Sp / self.X_BI.shape[0]) * self.X_BI.shape[0]

if self.nn_init:
t_numerator += (self.sigma2 * self.nn_init_weight * self.Sp / self.nx.sum(self.inlier_P)) * _dot(self.nx)(
self.inlier_P.T, self.inlier_B - _dot(self.nx)(self.inlier_A, self.R.T)
t_numerator += (self.sigma2 * self.nn_init_weight * self.Sp / self.nx.sum(self.inlier_P)) * self.nx.dot(
self.inlier_P.T, self.inlier_B - self.nx.dot(self.inlier_A, self.R.T)
)
t_deno += (self.sigma2 * self.nn_init_weight * self.Sp / self.nx.sum(self.inlier_P)) * self.nx.sum(
self.inlier_P
Expand All @@ -1372,11 +1385,11 @@ def _update_rigid(
else:
self.t = t

self.RnA = _dot(self.nx)(self.coordsA, self.R.T) + self.t
self.RnA = self.nx.dot(self.coordsA, self.R.T) + self.t
if self.nn_init:
self.inlier_R = _dot(self.nx)(self.inlier_A, self.R.T) + self.t
self.inlier_R = self.nx.dot(self.inlier_A, self.R.T) + self.t
if self.guidance:
self.R_AI = _dot(self.nx)(self.R_AI, self.R.T) + self.t
self.R_AI = self.nx.dot(self.R_AI, self.R.T) + self.t

def _update_sigma2(
self,
Expand Down Expand Up @@ -1420,23 +1433,24 @@ def _get_optimal_R(
"""

mu_XnA, mu_XnB = (
_dot(self.nx)(self.K_NA, self.coordsA) / self.Sp,
_dot(self.nx)(self.K_NB, self.coordsB[self.batch_idx, :]) / self.Sp
self.nx.dot(self.K_NA, self.coordsA) / self.Sp,
self.nx.dot(self.K_NB, self.coordsB[self.batch_idx, :]) / self.Sp
if self.SVI_mode
else _dot(self.nx)(self.K_NB, self.coordsB) / self.Sp,
else self.nx.dot(self.K_NB, self.coordsB) / self.Sp,
)
XnABar, XnBBar = (
self.coordsA - mu_XnA,
self.coordsB[self.batch_idx, :] - mu_XnB if self.SVI_mode else self.coordsB - mu_XnB,
)
A = _dot(self.nx)(_dot(self.nx)(self.P, XnBBar).T, XnABar)
A = self.nx.dot(self.nx.dot(self.P, XnBBar).T, XnABar)

# get the optimal rotation matrix R
svdU, svdS, svdV = _linalg(self.nx).svd(A)
self.C[-1, -1] = _linalg(self.nx).det(_dot(self.nx)(svdU, svdV))
self.optimal_R = _dot(self.nx)(_dot(self.nx)(svdU, self.C), svdV)
self.optimal_t = mu_XnB - _dot(self.nx)(mu_XnA, self.optimal_R.T)
self.optimal_RnA = _dot(self.nx)(self.coordsA, self.optimal_R.T) + self.optimal_t

self.C[-1, -1] = _linalg(self.nx).det(self.nx.dot(svdU, svdV))
self.optimal_R = self.nx.dot(self.nx.dot(svdU, self.C), svdV)
self.optimal_t = mu_XnB - self.nx.dot(mu_XnA, self.optimal_R.T)
self.optimal_RnA = self.nx.dot(self.coordsA, self.optimal_R.T) + self.optimal_t

def _wrap_output(
self,
Expand Down Expand Up @@ -1468,7 +1482,7 @@ def _wrap_output(
norm_dict = {
"mean_transformed": self.nx.to_numpy(self.normalize_means[0]),
"mean_fixed": self.nx.to_numpy(self.normalize_means[1]),
"scale": self.nx.to_numpy(self.normalize_scales[1]),
"scale": self.nx.to_numpy(self.normalize_scales[0]),
"scale_transformed": self.nx.to_numpy(self.normalize_scales[0]),
"scale_fixed": self.nx.to_numpy(self.normalize_scales[1]),
}
Expand All @@ -1493,4 +1507,5 @@ def _wrap_output(
"sigma2_variance": self.nx.to_numpy(self.sigma2_variance),
"method": "Spateo",
"norm_dict": norm_dict,
"kernel_type": self.kernel_type,
}
Loading

0 comments on commit 45aa2c1

Please sign in to comment.