From 2600564144fd475fe206fbe7eee9e9c85376a4a4 Mon Sep 17 00:00:00 2001
From: Wilder Wohns
Date: Fri, 11 Feb 2022 14:27:13 -0500
Subject: [PATCH] make cluster centroids

---
 tspyro/cluster.py | 80 +++++++++++++++++++++++++++++++++++++++++++++--
 1 file changed, 78 insertions(+), 2 deletions(-)

diff --git a/tspyro/cluster.py b/tspyro/cluster.py
index 1532af5..b67be48 100644
--- a/tspyro/cluster.py
+++ b/tspyro/cluster.py
@@ -1,8 +1,12 @@
+import collections
+import math
 from typing import Tuple
 
+import networkx as nx
 import numpy as np
 import torch
 import tqdm
+from community import community_louvain
 
 
 def check_sparse_genotypes(data: dict):
@@ -119,7 +123,7 @@ def naive_encoder(ts):
     offsets = []
     index = []
     values = []
-    for haplo in tqdm(ts.haplotypes(), total=ts.num_samples):
+    for haplo in tqdm.tqdm(ts.haplotypes(), total=ts.num_samples):
         begin = len(index)
         for i, g in enumerate(haplo):
             assert g in "-01"
@@ -148,7 +152,7 @@ def variant_encoder(ts):
     offsets = []
     index = []
     values = []
-    for var in tqdm(ts.variants(), total=ts.num_sites):
+    for var in tqdm.tqdm(ts.variants(), total=ts.num_sites):
         geno = var.genotypes
 
         begin = len(index)
@@ -251,6 +255,78 @@ def make_clustering_gibbs(
     return assignment, clusters
 
 
+# Functions borrowed from https://github.com/tskit-dev/what-is-an-arg-paper
+
+
+def convert_nx(ts, mutation_rate=1e-8):
+    """
+    Returns the specified tree sequence as an undirected networkx graph.
+    Weights are the edge's total span * e^(-edge length * mutation_rate).
+
+    Mutation rate is in units of probability of mutation per base pair per generation.
+    """
+    G = nx.Graph()
+    edges = collections.defaultdict(list)
+    times = ts.tables.nodes.time
+    for edge in ts.edges():
+        edges[(edge.child, edge.parent)].append(
+            (edge.left, edge.right, times[edge.parent], times[edge.child])
+        )
+    for node in ts.nodes():
+        G.add_node(node.id, time=node.time, flags=node.flags)
+    for edge, intervals in edges.items():
+        G.add_edge(
+            *edge,
+            weight=sum(
+                (right - left) * math.e ** (-(top - bottom) * mutation_rate)
+                for (left, right, top, bottom) in intervals
+            ),
+        )
+
+    return G
+
+
+def make_clustering_louvain(ts, resolution, max_cluster_size):
+    """
+    Clusters a tree sequence using the Louvain community detection algorithm.
+
+    :param tskit.TreeSequence ts: The input tree sequence.
+    :param float resolution: Passed to community_louvain.best_partition: "Will
+        change the size of the communities, default to 1. represents the time
+        described in “Laplacian Dynamics and Multiscale Modular Structure in
+        Networks”, R. Lambiotte, J.-C. Delvenne, M. Barahona".
+    :param int max_cluster_size: Maximum number of nodes per cluster used to
+        compute its centroid, usually 10000 or fewer to avoid memory issues.
+    :returns: A ``(num_clusters, num_sites)``-shaped tensor of cluster centroids.
+    :rtype: torch.Tensor
+    """
+    graph = convert_nx(ts)
+
+    partition = community_louvain.best_partition(
+        graph, weight="weight", resolution=resolution
+    )  # resolution > 1: fewer clusters; < 1: more clusters
+    assignment = torch.tensor([partition[node] for node in range(ts.num_nodes)])
+    clusters = torch.unique(assignment)
+    cluster_centroids = np.zeros((len(clusters), ts.num_sites))
+    for cluster in tqdm.tqdm(clusters):
+        cluster_nodes = torch.nonzero(assignment == cluster, as_tuple=True)[0]
+        if len(cluster_nodes) > max_cluster_size:
+            cluster_nodes = cluster_nodes[
+                torch.randperm(len(cluster_nodes))[:max_cluster_size]
+            ]
+        tables = ts.dump_tables()
+        arr = np.zeros_like(tables.nodes.flags)
+        arr[cluster_nodes.numpy()] = 1  # tskit.NODE_IS_SAMPLE
+        tables.nodes.flags = arr
+        all_samples = tables.tree_sequence()
+        all_samples = all_samples.simplify(filter_sites=False)
+        geno = all_samples.genotype_matrix()
+        missing = geno == -1
+        mx = np.ma.masked_array(geno, mask=missing)
+        cluster_centroids[cluster] = mx.sum(axis=1) / mx.count(axis=1)
+    return torch.as_tensor(cluster_centroids)
+
+
 def make_reproduction_tensor(
     clusters: torch.Tensor,
     *,
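
A minimal usage sketch of the new make_clustering_louvain entry point, assuming a tree sequence saved as example.trees; the file name and the resolution and max_cluster_size values are illustrative placeholders, not anything fixed by the patch:

    import tskit

    from tspyro.cluster import make_clustering_louvain

    # Load an inferred tree sequence (placeholder file name).
    ts = tskit.load("example.trees")

    # resolution < 1 gives more, smaller clusters; each cluster's centroid is
    # averaged over at most 10,000 nodes to keep memory bounded.
    centroids = make_clustering_louvain(ts, resolution=0.6, max_cluster_size=10_000)
    print(centroids.shape)  # (num_clusters, num_sites)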