Skip to content

Commit

Permalink
upgrade to PASTE 1.4.0
Browse files Browse the repository at this point in the history
fixed compatibility with POT v0.9.0
  • Loading branch information
mrland99 committed May 26, 2023
1 parent f5617b4 commit a9b10b2
Show file tree
Hide file tree
Showing 8 changed files with 82 additions and 8 deletions.
Binary file removed dist/paste-bio-1.3.0.tar.gz
Binary file not shown.
Binary file added dist/paste-bio-1.4.0.tar.gz
Binary file not shown.
Binary file removed dist/paste_bio-1.3.0-py3-none-any.whl
Binary file not shown.
Binary file added dist/paste_bio-1.4.0-py3-none-any.whl
Binary file not shown.
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
anndata>=0.7.6
scanpy>=1.7.2
POT>=0.9.0
POT=0.9.0
numpy
scipy
scikit-learn>=0.24.0
Expand Down
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[metadata]
name = paste-bio
version = 1.3.0
version = 1.4.0
author = Max Land
author_email = [email protected]
description = A computational method to align and integrate spatial transcriptomics experiments.
Expand Down
82 changes: 79 additions & 3 deletions src/paste/PASTE.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,7 +316,7 @@ def center_NMF(W, H, slices, pis, lmbda, n_components, random_seed, dissimilarit
H_new = model.components_
return W_new, H_new

def my_fused_gromov_wasserstein(M, C1, C2, p, q, G_init = None, loss_fun='square_loss', alpha=0.5, armijo=False, log=False,numItermax=200, use_gpu = False, **kwargs):
def my_fused_gromov_wasserstein(M, C1, C2, p, q, G_init = None, loss_fun='square_loss', alpha=0.5, armijo=False, log=False,numItermax=200, tol_rel=1e-9, tol_abs=1e-9, use_gpu = False, **kwargs):
"""
Adapted fused_gromov_wasserstein with the added capability of defining a G_init (inital mapping).
Also added capability of utilizing different POT backends to speed up computation.
Expand All @@ -343,9 +343,19 @@ def f(G):

def df(G):
return ot.gromov.gwggrad(constC, hC1, hC2, G)

if loss_fun == 'kl_loss':
armijo = True # there is no closed form line-search with KL

if armijo:
def line_search(cost, G, deltaG, Mi, cost_G, **kwargs):
return ot.optim.line_search_armijo(cost, G, deltaG, Mi, cost_G, nx=nx, **kwargs)
else:
def line_search(cost, G, deltaG, Mi, cost_G, **kwargs):
return solve_gromov_linesearch(G, deltaG, cost_G, C1, C2, M=0., reg=1., nx=nx, **kwargs)

if log:
res, log = ot.optim.cg(p, q, (1 - alpha) * M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, log=True, **kwargs)
res, log = ot.optim.cg(p, q, (1 - alpha) * M, alpha, f, df, G0, line_search, log=True, numItermax=numItermax, stopThr=tol_rel, stopThr2=tol_abs, **kwargs)

fgw_dist = log['loss'][-1]

Expand All @@ -355,4 +365,70 @@ def df(G):
return res, log

else:
return ot.optim.cg(p, q, (1 - alpha) * M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, **kwargs)
return ot.optim.cg(p, q, (1 - alpha) * M, alpha, f, df, G0, line_search, numItermax=numItermax, stopThr=tol_rel, stopThr2=tol_abs, **kwargs)

def solve_gromov_linesearch(G, deltaG, cost_G, C1, C2, M, reg,
alpha_min=None, alpha_max=None, nx=None, **kwargs):
"""
Solve the linesearch in the FW iterations
Parameters
----------
G : array-like, shape(ns,nt)
The transport map at a given iteration of the FW
deltaG : array-like (ns,nt)
Difference between the optimal map found by linearization in the FW algorithm and the value at a given iteration
cost_G : float
Value of the cost at `G`
C1 : array-like (ns,ns), optional
Structure matrix in the source domain.
C2 : array-like (nt,nt), optional
Structure matrix in the target domain.
M : array-like (ns,nt)
Cost matrix between the features.
reg : float
Regularization parameter.
alpha_min : float, optional
Minimum value for alpha
alpha_max : float, optional
Maximum value for alpha
nx : backend, optional
If let to its default value None, a backend test will be conducted.
Returns
-------
alpha : float
The optimal step size of the FW
fc : int
nb of function call. Useless here
cost_G : float
The value of the cost for the next iteration
.. _references-solve-linesearch:
References
----------
.. [24] Vayer Titouan, Chapel Laetitia, Flamary Rémi, Tavenard Romain and Courty Nicolas
"Optimal Transport for structured data with application on graphs"
International Conference on Machine Learning (ICML). 2019.
"""
if nx is None:
G, deltaG, C1, C2, M = ot.utils.list_to_array(G, deltaG, C1, C2, M)

if isinstance(M, int) or isinstance(M, float):
nx = ot.backend.get_backend(G, deltaG, C1, C2)
else:
nx = ot.backend.get_backend(G, deltaG, C1, C2, M)

dot = nx.dot(nx.dot(C1, deltaG), C2.T)
a = -2 * reg * nx.sum(dot * deltaG)
b = nx.sum(M * deltaG) - 2 * reg * (nx.sum(dot * G) + nx.sum(nx.dot(nx.dot(C1, G), C2.T) * deltaG))

alpha = ot.optim.solve_1d_linesearch_quad(a, b)
if alpha_min is not None or alpha_max is not None:
alpha = np.clip(alpha, alpha_min, alpha_max)

# the new cost is deduced from the line search quadratic function
cost_G = cost_G + a * (alpha ** 2) + b * alpha

return alpha, 1, cost_G
4 changes: 1 addition & 3 deletions src/paste_bio.egg-info/PKG-INFO
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Metadata-Version: 2.1
Name: paste-bio
Version: 1.3.0
Version: 1.4.0
Summary: A computational method to align and integrate spatial transcriptomics experiments.
Home-page: https://github.com/raphael-group/paste
Author: Max Land
Expand Down Expand Up @@ -31,8 +31,6 @@ You can read full paper [here](https://www.nature.com/articles/s41592-022-01459-

Additional examples and the code to reproduce the paper's analyses can be found [here](https://github.com/raphael-group/paste_reproducibility). Preprocessed datasets used in the paper can be found on [zenodo](https://doi.org/10.5281/zenodo.6334774).

PASTE is actively being worked on with future updates coming.

### Recent News

* PASTE is now published in [Nature Methods](https://www.nature.com/articles/s41592-022-01459-6)!
Expand Down

0 comments on commit a9b10b2

Please sign in to comment.