Skip to content

Commit

Permalink
Merge branch 'yifan'
Browse files Browse the repository at this point in the history
  • Loading branch information
YifanLu2000 committed Sep 30, 2024
2 parents 2485c62 + b7becb3 commit f43f7fb
Show file tree
Hide file tree
Showing 9 changed files with 152 additions and 85 deletions.
24 changes: 23 additions & 1 deletion docs/technicals/spatial_transcriptomics_alignment.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,26 @@
# Spatial transcriptomics alignment

This section describes the technical details behind Spateo's spatial transcriptomics alignment pipeline.
This section describes the technical details behind Spateo's spatial transcriptomics alignment pipeline.

## Background

The sequential slicing and subsequent spatial transcriptomic profiling at the whole embryo level offer us an unprecedented opportunity to reconstruct the molecular hologram of the entire 3D embryo structure. However, conventional sectioning and downstream library preparation can rotate, transform, deform, and introduce missing regions in each profiled tissue section. In addition, with advancements in technology, spatial transcriptomics techniques with single-cell and even subcellular resolution are gradually emerging, and a single slice often contains hundreds of thousands of cells. Therefore, it is in general necessary to develop scalable and robust algorithms to reconstruct 3D structures to recover the relative spatial locations of single cells across different slices while allowing local distortion within the same slice.


## Methodology
Consider a series of spatially-resolved transcriptomics samples, such as consecutive tissue sections from the same embryo, denote as $\mathcal{D} = \{\mathcal{S}^i\}_{i=1}^k$, where $\mathcal{S}^i=$ is the $i$-th section


### Problem formulation

### Generative process

### Transformation model

### Define prior distributions

### Variational Bayesian Inference

## Function Design


1 change: 1 addition & 0 deletions spateo/align.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
BA_transform_and_assignment,
Mesh_correction,
align_preprocess,
calc_distance,
calc_exp_dissimilarity,
generate_label_transfer_dict,
generate_label_transfer_prior,
Expand Down
1 change: 1 addition & 0 deletions spateo/alignment/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from .methods import (
Mesh_correction,
align_preprocess,
calc_distance,
calc_exp_dissimilarity,
generate_label_transfer_dict,
)
Expand Down
1 change: 1 addition & 0 deletions spateo/alignment/methods/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
_power,
_prod,
_unsqueeze,
calc_distance,
check_backend,
check_exp,
con_K,
Expand Down
29 changes: 24 additions & 5 deletions spateo/alignment/methods/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -528,7 +528,7 @@ def diag(self, a, k=0):
"""
raise NotImplementedError()

def unique(self, a, return_inverse=False):
def unique(self, a, return_index=False, return_inverse=False):
r"""
Finds unique elements of given tensor.
Expand Down Expand Up @@ -1058,8 +1058,8 @@ def data(self, a, type_as=None):
else:
return np.asarray(a, dtype=type_as.dtype)

def unique(self, a, return_inverse=False, axis=None):
return np.unique(a, return_inverse=return_inverse, axis=axis)
def unique(self, a, return_index, return_inverse=False, axis=None):
return np.unique(a, return_index=return_index, return_inverse=return_inverse, axis=axis)

def unsqueeze(self, a, axis=-1):
return np.expand_dims(a, axis)
Expand Down Expand Up @@ -1325,8 +1325,27 @@ def data(self, a, type_as=None):
else:
return torch.as_tensor(a, dtype=type_as.dtype, device=type_as.device)

def unique(self, a, return_inverse=False, axis=None):
return torch.unique(a, return_inverse=return_inverse, dim=axis)
def unique(self, a, return_index=False, return_inverse=False, axis=None):
unique_values, inverse_indices = torch.unique(a, sorted=False, return_inverse=True, dim=axis)

result = [unique_values]

if return_index:
x_sort, idx_sorted = torch.sort(inverse_indices)
return_index = idx_sorted[
torch.hstack(
[
torch.where((x_sort[1:] - x_sort[:-1]) != 0)[0],
torch.tensor([len(inverse_indices) - 1], device=x_sort.device),
]
)
]
result.append(return_index)

if return_inverse:
result.append(inverse_indices)

return tuple(result) if len(result) > 1 else result[0]

def unsqueeze(self, a, axis=-1):
return torch.unsqueeze(a, axis)
Expand Down
31 changes: 18 additions & 13 deletions spateo/alignment/methods/morpho_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,6 @@ def run(
Returns:
np.ndarray: The final cell-cell assignment matrix.
"""
# print(f'summin: {self.nx.sum(self.exp_layers_A[0], axis=1, keepdims=True).min()}')
if self.nn_init:
self._coarse_rigid_alignment()

Expand Down Expand Up @@ -560,7 +559,6 @@ def _guidance_pair_preprocess(
self.X_AI = self.nx.from_numpy(self.guidance_pair[1], type_as=self.type_as)
self.V_AI = self.nx.zeros(self.X_AI.shape, type_as=self.type_as)
self.R_AI = self.nx.zeros(self.X_AI.shape, type_as=self.type_as)
# print(self.V_AI)

if self.normalize_c:
# Normalize the guidance pairs
Expand Down Expand Up @@ -789,7 +787,6 @@ def _init_probability_parameters(
Y=exp_B[sub_sample_B],
metric=d_s,
)
# print(exp_A[sub_sample_A])
min_exp_dist = self.nx.min(exp_dist, 1)
self.probability_parameters[i] = self.nx.maximum(
min_exp_dist[self.nx.argsort(min_exp_dist)[int(sub_sample_A.shape[0] * 0.05)]] / 5,
Expand Down Expand Up @@ -821,13 +818,16 @@ def _construct_kernel(
NotImplementedError: If the specified kernel type is not implemented.
"""

unique_spatial_coords = _unique(self.nx, self.coordsA, 0)
# unique_spatial_coords = _unique(self.nx, self.coordsA, 0)
unique_spatial_coords, unique_idx = self.nx.unique(self.coordsA, return_index=True, axis=0)
inducing_variables_idx = (
np.random.choice(unique_spatial_coords.shape[0], inducing_variables_num, replace=False)
if unique_spatial_coords.shape[0] > inducing_variables_num
else np.arange(unique_spatial_coords.shape[0])
)
self.inducing_variables = unique_spatial_coords[inducing_variables_idx, :]
inducing_variables_idx = unique_idx[inducing_variables_idx]
self.inducing_variables = self.coordsA[inducing_variables_idx, :]
# self.inducing_variables = unique_spatial_coords[inducing_variables_idx, :]
# (self.inducing_variables, _) = sample(
# X=unique_spatial_coords, n_sampling=inducing_variables_num, sampling_method=sampling_method
# )
Expand All @@ -839,6 +839,14 @@ def _construct_kernel(
if self.guidance_effect in ["nonrigid", "both"]
else None
)
elif self.kernel_type == "geodist":
pass
# TODO: finish this
# if self.graph is None:
# self.graph = _construct_graph(self.coordsA, self.knn)
# self.GammaSparse = con_K_graph(self.graph, inducing_variables_idx, inducing_variables_idx, beta=self.kernel_bandwidth)
# self.U = con_K_graph(self.graph, np.arange(self.NA), inducing_variables_idx, beta=self.kernel_bandwidth)

else:
raise NotImplementedError(f"Kernel type '{self.kernel_type}' is not implemented.")

Expand Down Expand Up @@ -1097,7 +1105,6 @@ def _update_assignment_P(
exp_layer_dist = calc_distance(
self.exp_layers_A, exp_layer_B_chunk, self.dissimilarity, self.label_transfer
)

P, K_NA_spatial_chunk, K_NA_sigma2_chunk, sigma2_related_chunk = get_P_core(
spatial_dist=spatial_dist, exp_dist=exp_layer_dist, **common_kwargs
)
Expand All @@ -1120,7 +1127,6 @@ def _update_assignment_P(
Y=self.coordsB[self.batch_idx, :] if self.SVI_mode else self.coordsB,
metric="euc",
) # NA x batch_size (SVI_mode) / NA x NB (not SVI_mode)
# print(self.pre_compute_dist)
if self.pre_compute_dist:
exp_layer_dist = (
[exp_layer_d[:, self.batch_idx] for exp_layer_d in self.exp_layer_dist]
Expand Down Expand Up @@ -1163,7 +1169,6 @@ def _update_assignment_P(
self.K_NA_sigma2 = self.K_NA_sigma2.to_dense()

self.sigma2_related = sigma2_related / (self.Dim * self.Sp_sigma2)
# print(self.sigma2_related)

def _update_gamma(
self,
Expand Down Expand Up @@ -1282,7 +1287,6 @@ def _update_rigid(
if self.SVI_mode
else _dot(self.nx)(self.K_NB, self.coordsB)[None, :],
)
# print(self.Sp)
# solve rotation using SVD formula
mu_XB, mu_XA, mu_Vn = PXB, PXA, PVA
mu_X_deno, mu_Vn_deno = _copy(self.nx, self.Sp), _copy(self.nx, self.Sp)
Expand Down Expand Up @@ -1462,11 +1466,11 @@ def _wrap_output(

if not (self.vecfld_key_added is None):
norm_dict = {
"mean_transformed": self.nx.to_numpy(self.normalize_means[1]),
"mean_fixed": self.nx.to_numpy(self.normalize_means[0]),
"mean_transformed": self.nx.to_numpy(self.normalize_means[0]),
"mean_fixed": self.nx.to_numpy(self.normalize_means[1]),
"scale": self.nx.to_numpy(self.normalize_scales[1]),
"scale_transformed": self.nx.to_numpy(self.normalize_scales[1]),
"scale_fixed": self.nx.to_numpy(self.normalize_scales[0]),
"scale_transformed": self.nx.to_numpy(self.normalize_scales[0]),
"scale_fixed": self.nx.to_numpy(self.normalize_scales[1]),
}

self.vecfld = {
Expand All @@ -1488,4 +1492,5 @@ def _wrap_output(
"NA": self.NA,
"sigma2_variance": self.nx.to_numpy(self.sigma2_variance),
"method": "Spateo",
"norm_dict": norm_dict,
}
2 changes: 1 addition & 1 deletion spateo/alignment/methods/morpho_mesh_correction.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
try:
from .libfastpd import fastpd
except ImportError:
print("fastpd is not installed. Please compile the fastpd library.")
print("fastpd is not installed. If you need mesh correction, please compile the fastpd library.")


# TODO: add str as the input type for the models
Expand Down
Loading

0 comments on commit f43f7fb

Please sign in to comment.