Skip to content

Commit

Permalink
Add transpose_sparse() op
Browse files Browse the repository at this point in the history
  • Loading branch information
fritzo committed Feb 18, 2022
1 parent 36ed437 commit b22b945
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 2 deletions.
17 changes: 16 additions & 1 deletion test/test_cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from tspyro.cluster import make_clustering_gibbs
from tspyro.cluster import make_fake_data
from tspyro.cluster import make_reproduction_tensor
from tspyro.cluster import transpose_sparse


@pytest.mark.parametrize("num_epochs", [1, 10])
Expand All @@ -16,7 +17,10 @@ def test_clustering(num_variants, num_samples, num_clusters, num_epochs):
data = make_fake_data(num_samples, num_variants)

print(f"Creating {num_clusters} clusters via {num_epochs} epochs")
clusters = make_clustering_gibbs(data, num_clusters, num_epochs=num_epochs)
assignment, clusters = make_clustering_gibbs(
data, num_clusters, num_epochs=num_epochs
)
assert assignment.shape == (num_samples,)
assert clusters.shape == (num_variants, num_clusters)

print("Creating a reproduction tensor")
Expand All @@ -33,6 +37,17 @@ def test_clustering(num_variants, num_samples, num_clusters, num_epochs):
assert torch.allclose(reproduce.sum(-1), torch.ones(num_clusters, num_clusters))


@pytest.mark.parametrize("num_samples", [5, 10, 100])
@pytest.mark.parametrize("num_variants", [4, 8, 100])
def test_transpose(num_variants, num_samples):
data1 = make_fake_data(num_samples, num_variants)
data2 = transpose_sparse(data1)
data3 = transpose_sparse(data2)
for k, v1 in data1.items():
v3 = data3[k]
assert (v1 == v3).all()


if __name__ == "__main__":
parser = argparse.ArgumentParser(description="benchmark clustering algorithm")
parser.add_argument("-n", "--num-samples", default=1000, type=int)
Expand Down
44 changes: 43 additions & 1 deletion tspyro/cluster.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Tuple

import numpy as np
import torch
import tqdm
Expand Down Expand Up @@ -33,6 +35,9 @@ def check_sparse_genotypes(data: dict):
assert index.shape == (int(offsets[-1, -1]),)
assert index.shape == values.shape

assert offsets[0, 0] == 0
assert (offsets[1:, 0] == offsets[:-1, 1]).all()

return offsets, index, values


Expand Down Expand Up @@ -66,6 +71,43 @@ def make_fake_data(num_samples, num_variants):
return data


def transpose_sparse(data: dict, num_cols=0) -> dict:
"""
Convert from row-oriented to column-oriented.
This function is its own inverse.
:param dict data: A dict representing sparse genotypes.
:param int num_cols: Optional number of columns.
:returns: A dict representing sparse genotypes.
:rtype: dict
"""
old_offsets, old_index, old_values = check_sparse_genotypes(data)
num_rows = len(old_offsets)
num_cols = max(num_cols, 1 + int(old_index.max()))

# Initialize positions and create offsets.
position = torch.zeros(1 + num_cols, dtype=torch.long)
for r in range(num_rows):
beg, end = old_offsets[r].tolist()
index = old_index[beg:end]
position[1 + index] += 1
position.cumsum_(0)
new_offsets = torch.stack([position[:-1], position[1:]], dim=-1)

# Populate index and values.
new_index = torch.zeros_like(old_index)
new_values = torch.zeros_like(old_values)
for r in range(num_rows):
beg, end = old_offsets[r].tolist()
index = old_index[beg:end]
values = old_values[beg:end]
new_index[position[index]] = r
new_values[position[index]] = values
position[index] += 1

return dict(offsets=new_offsets, index=new_index, values=new_values)


def naive_encoder(ts):
"""
Make an encoding of the sparse genotype data naively: haplotype by haplotype.
Expand Down Expand Up @@ -145,7 +187,7 @@ def make_clustering_gibbs(
num_clusters: int,
*,
num_epochs: int = 10,
) -> torch.Tensor:
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Clusters sparse genotypes using subsample-annealed Gibbs sampling.
Expand Down

0 comments on commit b22b945

Please sign in to comment.