diff --git a/test/test_cluster.py b/test/test_cluster.py
index 72e0b99..2f16d6d 100644
--- a/test/test_cluster.py
+++ b/test/test_cluster.py
@@ -5,6 +5,7 @@
 from tspyro.cluster import make_clustering_gibbs
 from tspyro.cluster import make_fake_data
 from tspyro.cluster import make_reproduction_tensor
+from tspyro.cluster import transpose_sparse
 
 
 @pytest.mark.parametrize("num_epochs", [1, 10])
@@ -16,7 +17,10 @@ def test_clustering(num_variants, num_samples, num_clusters, num_epochs):
     data = make_fake_data(num_samples, num_variants)
 
     print(f"Creating {num_clusters} clusters via {num_epochs} epochs")
-    clusters = make_clustering_gibbs(data, num_clusters, num_epochs=num_epochs)
+    assignment, clusters = make_clustering_gibbs(
+        data, num_clusters, num_epochs=num_epochs
+    )
+    assert assignment.shape == (num_samples,)
     assert clusters.shape == (num_variants, num_clusters)
 
     print("Creating a reproduction tensor")
@@ -33,6 +37,17 @@ def test_clustering(num_variants, num_samples, num_clusters, num_epochs):
     assert torch.allclose(reproduce.sum(-1), torch.ones(num_clusters, num_clusters))
 
 
+@pytest.mark.parametrize("num_samples", [5, 10, 100])
+@pytest.mark.parametrize("num_variants", [4, 8, 100])
+def test_transpose(num_variants, num_samples):
+    data1 = make_fake_data(num_samples, num_variants)
+    data2 = transpose_sparse(data1)
+    data3 = transpose_sparse(data2)
+    for k, v1 in data1.items():
+        v3 = data3[k]
+        assert (v1 == v3).all()
+
+
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="benchmark clustering algorithm")
     parser.add_argument("-n", "--num-samples", default=1000, type=int)
diff --git a/tspyro/cluster.py b/tspyro/cluster.py
index a900aa4..fdb6351 100644
--- a/tspyro/cluster.py
+++ b/tspyro/cluster.py
@@ -1,3 +1,5 @@
+from typing import Tuple
+
 import numpy as np
 import torch
 import tqdm
@@ -33,6 +35,9 @@ def check_sparse_genotypes(data: dict):
     assert index.shape == (int(offsets[-1, -1]),)
     assert index.shape == values.shape
 
+    assert offsets[0, 0] == 0
+    assert (offsets[1:, 0] == offsets[:-1, 1]).all()
+
     return offsets, index, values
 
 
@@ -66,6 +71,43 @@ def make_fake_data(num_samples, num_variants):
     return data
 
 
+def transpose_sparse(data: dict, num_cols=0) -> dict:
+    """
+    Convert from row-oriented to column-oriented.
+    This function is its own inverse.
+
+    :param dict data: A dict representing sparse genotypes.
+    :param int num_cols: Optional number of columns.
+    :returns: A dict representing sparse genotypes.
+    :rtype: dict
+    """
+    old_offsets, old_index, old_values = check_sparse_genotypes(data)
+    num_rows = len(old_offsets)
+    num_cols = max(num_cols, 1 + int(old_index.max()))
+
+    # Initialize positions and create offsets.
+    position = torch.zeros(1 + num_cols, dtype=torch.long)
+    for r in range(num_rows):
+        beg, end = old_offsets[r].tolist()
+        index = old_index[beg:end]
+        position[1 + index] += 1
+    position.cumsum_(0)
+    new_offsets = torch.stack([position[:-1], position[1:]], dim=-1)
+
+    # Populate index and values.
+    new_index = torch.zeros_like(old_index)
+    new_values = torch.zeros_like(old_values)
+    for r in range(num_rows):
+        beg, end = old_offsets[r].tolist()
+        index = old_index[beg:end]
+        values = old_values[beg:end]
+        new_index[position[index]] = r
+        new_values[position[index]] = values
+        position[index] += 1
+
+    return dict(offsets=new_offsets, index=new_index, values=new_values)
+
+
 def naive_encoder(ts):
     """
     Make an encoding of the sparse genotype data naively: haplotype by haplotype.
@@ -145,7 +187,7 @@ def make_clustering_gibbs(
     num_clusters: int,
     *,
     num_epochs: int = 10,
-) -> torch.Tensor:
+) -> Tuple[torch.Tensor, torch.Tensor]:
     """
     Clusters sparse genotypes using subsample-annealed Gibbs sampling.