Skip to content

Commit

Permalink
Merge pull request #250 from Starlitnightly/main
Browse files Browse the repository at this point in the history
Optimized `SCC` Implementation and Removed `TensorFlow` Dependencies
  • Loading branch information
Xiaojieqiu authored Oct 9, 2024
2 parents 1b4f67d + 882d8b4 commit 3c01986
Show file tree
Hide file tree
Showing 25 changed files with 4,544 additions and 60 deletions.
8 changes: 4 additions & 4 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
adjustText
anndata>=0.8.0
colorcet>=2.0.1
cvxopt>=1.2.3
# cvxopt>=1.2.3
csbdeep>=0.6.3
descartes
dynamo-release>=1.4.1
Expand All @@ -21,9 +21,9 @@ networkx>=2.6.3
numba>=0.46.0
numpy>=1.18.1
opencv-python>=4.5.4.60
pandana
# pandana
pandas>=0.25.1
paste-bio>=1.4.0
# paste-bio>=1.4.0
plotly>=5.1.0
POT>=0.8.1
psutil>=5.6.3
Expand All @@ -39,7 +39,7 @@ seaborn>=0.9.0
setuptools>=58.0.4
Shapely>=1.8.0
statsmodels>=0.9.0
tensorflow
# tensorflow
tqdm>=4.62.3
torch
trame>=2.2.5
Expand Down
117 changes: 117 additions & 0 deletions spateo/external/CAST/CAST_Mark.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
from collections import OrderedDict
from timeit import default_timer as timer

import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import torch
from tqdm import trange

from .model.aug import random_aug
from .utils import coords2adjacentmat


def train_seq(graphs, args, dump_epoch_list, out_prefix, model):
    """Train the CAST Mark GNN model on a list of tissue samples.

    For every epoch, each sample is augmented twice with ``random_aug``, both
    views are passed through ``model``, and the loss is built from the
    node-count-normalised correlation matrices of the two embeddings:
    an invariance term (negative trace of the cross-correlation) plus
    ``args.lambd`` times two decorrelation terms (squared distance of each
    auto-correlation matrix from the identity).

    Args:
        graphs (List[Tuple(str, dgl.Graph, torch.Tensor)]): List of 3-member
            tuples, one per tissue sample: sample name, DGL graph object and
            feature matrix.
        args (model_GCNII.Args): Training parameters. The attributes read here
            are ``device``, ``lr1``, ``wd1``, ``epochs``, ``dfr``, ``der`` and
            ``lambd``.
        dump_epoch_list (List): Epoch ids at which a snapshot of the
            embeddings and the loss log is written to disk (debug use). Pass
            an empty list to disable snapshots.
        out_prefix (str): File name prefix for the snapshot files.
        model (model_GCNII.CCA_SSG): The GNN model.

    Returns:
        Tuple(Dict, List, CCA_SSG): An ordered dictionary mapping sample name
        to its graph embedding, the list of logged loss values (one per
        epoch), and the trained model object.

    Raises:
        ValueError: If ``graphs`` is empty while ``args.epochs`` is positive.
    """
    # Without any sample the inner loop never produces a loss, which used to
    # surface as an opaque UnboundLocalError when the loss was logged.
    if args.epochs > 0 and len(graphs) == 0:
        raise ValueError("`graphs` must contain at least one sample to train on.")

    model = model.to(args.device)

    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr1, weight_decay=args.wd1)

    loss_log = []
    time_now = timer()

    t = trange(args.epochs, desc="", leave=True)
    for epoch in t:

        # Optional debug snapshot, taken before this epoch's updates.
        if epoch in dump_epoch_list:
            torch.save(_collect_embeddings(graphs, model), f"{out_prefix}_embed_dict_epoch{epoch}.pt")
            torch.save(loss_log, f"{out_prefix}_loss_log_epoch{epoch}.pt")
            print(f"Successfully dumped epoch {epoch}")

        model.train()
        # NOTE(review): gradients are cleared once per epoch while backward()
        # and step() run once per sample, so gradients accumulate across the
        # samples of an epoch. Kept unchanged to preserve training behaviour;
        # confirm this is intended.
        optimizer.zero_grad()

        for name_, graph_, feat_ in graphs:
            # Augmentation is data preparation only; no gradients are needed.
            with torch.no_grad():
                N = graph_.number_of_nodes()
                graph1, feat1 = random_aug(graph_, feat_, args.dfr, args.der)
                graph2, feat2 = random_aug(graph_, feat_, args.dfr, args.der)

                graph1 = graph1.add_self_loop()
                graph2 = graph2.add_self_loop()

            z1, z2 = model(graph1, feat1, graph2, feat2)

            # Cross- and auto-correlation matrices, normalised by node count.
            c = torch.mm(z1.T, z2) / N
            c1 = torch.mm(z1.T, z1) / N
            c2 = torch.mm(z2.T, z2) / N

            loss_inv = -torch.diagonal(c).sum()
            iden = torch.eye(c.size(0), device=args.device)
            loss_dec1 = (iden - c1).pow(2).sum()
            loss_dec2 = (iden - c2).pow(2).sum()
            loss = loss_inv + args.lambd * (loss_dec1 + loss_dec2)
            loss.backward()
            optimizer.step()

        # Only the loss of the last sample processed in this epoch is logged.
        loss_log.append(loss.item())
        time_step = timer() - time_now
        time_now += time_step
        t.set_description(f"Loss: {loss.item():.3f} step time={time_step:.3f}s")
        t.refresh()

    return _collect_embeddings(graphs, model), loss_log, model


def _collect_embeddings(graphs, model):
    """Switch ``model`` to eval mode and return an OrderedDict of sample name -> embedding."""
    model.eval()
    with torch.no_grad():
        embeddings = OrderedDict()
        for name, graph, feat in graphs:
            embeddings[name] = model.get_embedding(graph, feat)
    return embeddings


# graph construction tools
def delaunay_dgl(sample_name, df, output_path, if_plot=True, strategy_t="convex"):
    """Build a DGL graph for one sample from its 2D coordinates.

    The first two columns of ``df`` are taken as coordinates and handed to
    ``coords2adjacentmat`` (``output_mode="raw"``); the graph it returns is
    optionally drawn to a PNG and then converted with ``dgl.from_networkx``.

    Args:
        sample_name (str): Sample name, used in the output image file name.
        df: Table-like object convertible with ``np.array``; columns 0 and 1
            are used as the coordinates.
        output_path (str): Directory in which ``delaunay_{sample_name}.png``
            is saved when plotting is enabled.
        if_plot (bool): Whether to draw and save the graph.
        strategy_t (str): Forwarded unchanged to ``coords2adjacentmat``.

    Returns:
        The DGL graph produced by ``dgl.from_networkx``.
    """
    table = np.array(df)
    coords = np.column_stack((table[:, 0], table[:, 1]))
    delaunay_graph = coords2adjacentmat(coords, output_mode="raw", strategy_t=strategy_t)

    if if_plot:
        node_ids = delaunay_graph.nodes
        # Place every node at its own coordinate row.
        positions = {node: xy for node, xy in zip(node_ids, coords[node_ids, :])}
        draw_style = {
            "node_size": 1,
            "node_color": "#000000",
            "edge_color": "#5A98AF",
            "alpha": 0.6,
        }
        _, ax = plt.subplots(nrows=1, ncols=1, figsize=(10, 10))
        nx.draw(delaunay_graph, positions, ax=ax, **draw_style)
        plt.axis("equal")
        # NOTE(review): the figure is saved but never closed here.
        plt.savefig(f"{output_path}/delaunay_{sample_name}.png")

    # dgl is imported only at this point, so it is not needed at module import time.
    import dgl

    return dgl.from_networkx(delaunay_graph)
Loading

0 comments on commit 3c01986

Please sign in to comment.