Cluster centroids #21

Open · wants to merge 1 commit into main

80 changes: 78 additions & 2 deletions tspyro/cluster.py
@@ -1,8 +1,12 @@
import collections
import math
from typing import Tuple

import networkx as nx
import numpy as np
import torch
import tqdm
from community import community_louvain


def check_sparse_genotypes(data: dict):
@@ -119,7 +123,7 @@ def naive_encoder(ts):
    offsets = []
    index = []
    values = []
-    for haplo in tqdm(ts.haplotypes(), total=ts.num_samples):
+    for haplo in tqdm.tqdm(ts.haplotypes(), total=ts.num_samples):
        begin = len(index)
        for i, g in enumerate(haplo):
            assert g in "-01"
@@ -148,7 +152,7 @@ def variant_encoder(ts):
    offsets = []
    index = []
    values = []
-    for var in tqdm(ts.variants(), total=ts.num_sites):
+    for var in tqdm.tqdm(ts.variants(), total=ts.num_sites):
        geno = var.genotypes
        begin = len(index)

@@ -251,6 +255,78 @@ def make_clustering_gibbs(
    return assignment, clusters


# Functions borrowed from https://github.com/tskit-dev/what-is-an-arg-paper


def convert_nx(ts, mutation_rate=1e-8):
"""
Returns the specified tree sequence as an networkx directed graph.
Weights are the edge's total span * e^(-edge length * mutation_rate).

Mutation rate is in units: probability of mutation per base pair per generation.
"""
    G = nx.Graph()
    edges = collections.defaultdict(list)
    times = ts.tables.nodes.time
    for edge in ts.edges():
        edges[(edge.child, edge.parent)].append(
            (edge.left, edge.right, times[edge.parent], times[edge.child])
        )
    for node in ts.nodes():
        G.add_node(node.id, time=node.time, flags=node.flags)
    for edge, intervals in edges.items():
        G.add_edge(
            *edge,
            weight=sum(
                (right - left) * math.e ** (-(top - bottom) * mutation_rate)
                for (left, right, top, bottom) in intervals
            ),
        )

    return G
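
# Editor's sketch (not part of this PR): a minimal usage example of convert_nx.
# It assumes msprime is installed; the helper name and simulation parameters
# below are illustrative only.
def _example_convert_nx():
    import msprime

    ts = msprime.simulate(
        sample_size=10, Ne=1e4, length=1e4, recombination_rate=1e-8, random_seed=1
    )
    G = convert_nx(ts, mutation_rate=1e-8)
    # Each (child, parent) edge weight is the sum over its intervals of
    # (right - left) * exp(-(parent_time - child_time) * mutation_rate),
    # so edges spanning long branches contribute small weights.
    print(G.number_of_nodes(), G.number_of_edges())
    return G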


def make_clustering_louvain(ts, resolution, max_cluster_size):
"""
Clusters a tree sequence using the Louvain community detection algorithm.

:param tskit.TreeSequence ts: The input tree sequence
:param float resolution: Passed to community_louvain: "Will change the size of
the communities, default to 1. represents the time described in
“Laplacian Dynamics and Multiscale Modular Structure in Networks”, R.
Lambiotte, J.-C. Delvenne, M. Barahona"
:param int max_cluster_size: Maximum cluster size to calculate centroids from,
usually 10000 nodes or less to avoid memory issues.
:returns: A ``(num_variants,num_clusters)``-shaped tensor of clusters centroids.
:rtype: torch.Tensor
"""
    graph = convert_nx(ts)

    partition = community_louvain.best_partition(
        graph, weight="weight", resolution=resolution
    )  # resolution > 1: fewer clusters. < 1: more clusters
    assignment = torch.tensor(list(partition.values()))
    clusters = torch.unique(assignment)
    cluster_centroids = np.zeros((len(clusters), ts.num_sites))
    for cluster in tqdm.tqdm(clusters):
        # Node ids assigned to this cluster, subsampled if the cluster is too large.
        cluster_nodes = torch.nonzero(assignment == cluster).flatten()
        if len(cluster_nodes) > max_cluster_size:
            perm = torch.randperm(len(cluster_nodes))[:max_cluster_size]
            cluster_nodes = cluster_nodes[perm]
        # Flag the cluster's nodes as samples and simplify to read out their genotypes.
        tables = ts.dump_tables()
        flags = np.zeros_like(tables.nodes.flags)
        flags[cluster_nodes.numpy()] = 1  # tskit.NODE_IS_SAMPLE
        tables.nodes.flags = flags
        cluster_samples = tables.tree_sequence().simplify(filter_sites=False)
        geno = cluster_samples.genotype_matrix()
        # Average each site's genotypes over the cluster, ignoring missing (-1) entries.
        missing = geno == -1
        mx = np.ma.masked_array(geno, mask=missing)
        cluster_centroids[cluster] = mx.sum(axis=1) / mx.count(axis=1)
    return torch.as_tensor(cluster_centroids)
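
# Editor's sketch (not part of this PR): hypothetical usage of make_clustering_louvain.
# The helper name and argument values are illustrative; `ts` is assumed to be a
# tskit.TreeSequence with at least one site.
def _example_clustering_louvain(ts):
    centroids = make_clustering_louvain(ts, resolution=1.0, max_cluster_size=10000)
    # centroids[c, s] is the mean genotype of cluster c at site s, with missing
    # (-1) genotypes excluded from the average.
    print(centroids.shape)  # (num_clusters, num_sites)
    return centroids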


def make_reproduction_tensor(
clusters: torch.Tensor,
*,