Skip to content

Commit

Permalink
Implement reproduction operator on clusters
Browse files Browse the repository at this point in the history
  • Loading branch information
fritzo committed Feb 16, 2022
1 parent 1953ecd commit 1034dc9
Show file tree
Hide file tree
Showing 2 changed files with 119 additions and 33 deletions.
45 changes: 18 additions & 27 deletions test/test_cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,45 +2,36 @@

import pytest
import torch
from tspyro.cluster import check_sparse_genotypes
from tspyro.cluster import make_clustering_gibbs


def make_fake_data(num_samples, num_variants):
    """
    Build random sparse genotype data for exercising the clustering code.

    Every sample draws an independent Bernoulli(0.5) genotype and an
    independent Bernoulli(0.5) observation mask over ``num_variants`` sites;
    only the observed entries are stored.

    :returns: A dict with keys ``offsets`` (one ``[beg, end]`` pair per
        sample, delimiting that sample's slice of the flat arrays),
        ``index`` (observed variant positions) and ``values`` (the observed
        boolean genotypes).
    :rtype: dict
    """
    shape = (num_samples, num_variants)
    genotypes = torch.full(shape, 0.5).bernoulli_().bool()
    observed = torch.full(shape, 0.5).bernoulli_().bool()

    row_offsets = []
    row_index = []
    row_values = []
    stop = 0
    for genotype, is_observed in zip(genotypes, observed):
        positions = is_observed.nonzero(as_tuple=True)[0]
        # Each sample owns the half-open slice [start, stop) of the flat data.
        start, stop = stop, stop + len(positions)
        row_offsets.append([start, stop])
        row_index.append(positions)
        row_values.append(genotype[is_observed])

    sparse = dict(
        offsets=torch.tensor(row_offsets),
        index=torch.cat(row_index),
        values=torch.cat(row_values),
    )
    # Validate the layout before handing it to callers.
    check_sparse_genotypes(sparse)
    return sparse
from tspyro.cluster import make_fake_data
from tspyro.cluster import make_reproduction_tensor


@pytest.mark.parametrize("num_epochs", [1, 10])
@pytest.mark.parametrize("num_clusters", [5])
@pytest.mark.parametrize("num_samples", [5, 10, 100])
@pytest.mark.parametrize("num_variants", [20])
def test_clustering(num_variants, num_samples, num_clusters, num_epochs):
    """
    Smoke-test Gibbs clustering followed by reproduction-tensor construction.

    Checks the shape of the cluster matrix, and that the reproduction tensor
    is symmetric in its first two axes and normalized over its last axis.
    """
    print(f"Making fake data of size {num_samples} x {num_variants}")
    data = make_fake_data(num_samples, num_variants)

    # The clusterer sizes its output by the largest observed variant index,
    # and the fake data masks each entry with probability 0.5, so with few
    # samples the trailing variants can be entirely unobserved. Derive the
    # expected size from the data instead of assuming all variants appear.
    num_observed = 1 + int(data["index"].max())
    assert num_observed <= num_variants

    print(f"Creating {num_clusters} clusters via {num_epochs} epochs")
    clusters = make_clustering_gibbs(data, num_clusters, num_epochs=num_epochs)
    assert clusters.shape == (num_observed, num_clusters)

    print("Creating a reproduction tensor")
    crossover_rate = torch.rand(num_observed - 1) * 0.01
    mutation_rate = torch.rand(num_observed) * 0.001
    reproduce = make_reproduction_tensor(
        clusters,
        crossover_rate=crossover_rate,
        mutation_rate=mutation_rate,
    )
    assert isinstance(reproduce, torch.Tensor)
    assert reproduce.shape == (num_clusters, num_clusters, num_clusters)
    # Swapping the two parents must not change the offspring distribution.
    assert torch.allclose(reproduce, reproduce.transpose(0, 1))
    # Each (parent, parent) pair yields a normalized distribution over children.
    assert torch.allclose(reproduce.sum(-1), torch.ones(num_clusters, num_clusters))


if __name__ == "__main__":
parser = argparse.ArgumentParser(description="benchmark clustering algorithm")
Expand All @@ -50,6 +41,6 @@ def test_clustering_gibbs(num_variants, num_samples, num_clusters, num_epochs):
parser.add_argument("-e", "--num-epochs", default=10, type=int)
args = parser.parse_args()

test_clustering_gibbs(
test_clustering(
args.num_variants, args.num_samples, args.num_clusters, args.num_epochs
)
107 changes: 101 additions & 6 deletions tspyro/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,36 @@ def check_sparse_genotypes(data: dict):
return offsets, index, values


def make_fake_data(num_samples, num_variants):
    """
    Generate random sparse genotypes (for testing and benchmarking).

    Genotype values and the observation mask are both drawn as independent
    Bernoulli(0.5) variables of shape ``(num_samples, num_variants)``; only
    entries selected by the mask are kept.

    :returns: A dict representing sparse genotypes, with keys ``offsets``,
        ``index`` and ``values``.
    :rtype: dict
    """
    size = (num_samples, num_variants)
    dense_values = torch.full(size, 0.5).bernoulli_().bool()
    dense_mask = torch.full(size, 0.5).bernoulli_().bool()

    spans, sites, alleles = [], [], []
    cursor = 0  # Running length of the flattened index/values arrays.
    for row_values, row_mask in zip(dense_values, dense_mask):
        (observed_sites,) = row_mask.nonzero(as_tuple=True)
        sites.append(observed_sites)
        alleles.append(row_values[row_mask])
        spans.append([cursor, cursor + len(observed_sites)])
        cursor += len(observed_sites)

    fake = dict(
        offsets=torch.tensor(spans),
        index=torch.cat(sites),
        values=torch.cat(alleles),
    )
    # Fail fast if the generated structure is not a valid sparse genotype dict.
    check_sparse_genotypes(fake)

    return fake


def make_clustering_gibbs(
data: dict,
num_clusters: int,
Expand All @@ -54,14 +84,14 @@ def make_clustering_gibbs(
offsets, index, values = check_sparse_genotypes(data)
N = len(offsets)
P = 1 + int(index.max())
K = num_clusters
C = num_clusters
MISSING = -1
prior = 0.5 # Jeffreys prior

# The single site Gibbs algorithm treats assignment as the latent variable,
# and caches sufficient statistics in the counts tensor.
assignment = torch.full((N,), MISSING, dtype=torch.long)
counts = torch.full((P, K, 2), prior, dtype=torch.float)
counts = torch.full((P, C, 2), prior, dtype=torch.float)
assert N < 2 ** 23, "counts has too little precision"

# Use a shuffled linearly annealed schedule.
Expand All @@ -85,15 +115,15 @@ def make_clustering_gibbs(
logits = torch.where(value_n[:, None], heads, tails)
logits = logits.div_(posterior.sum(-1)).log_().sum(0)
logits -= logits.max()
k = int(logits.exp_().multinomial(1))
assignment[n] = k
c = int(logits.exp_().multinomial(1))
assignment[n] = c
else:
# Remove the datum from the current cluster.
assert assignment[n] != MISSING
k = int(assignment[n])
c = int(assignment[n])
assignment[n] = MISSING

counts[index_n, k, value_n.long()] += sign
counts[index_n, c, value_n.long()] += sign
assert all(assignment >= 0)
assert pending[+1] % N == 0
assert pending[-1] % N == 0
Expand All @@ -102,3 +132,68 @@ def make_clustering_gibbs(

clusters = (counts[..., 1] / counts.sum(-1)).round_().bool()
return clusters


def _symmetric_flip_matrices(rate: torch.Tensor) -> torch.Tensor:
    """
    Convert a vector of rates to a batch of symmetric 2x2 stochastic matrices.

    :param torch.Tensor rate: A 1-dimensional tensor of rates.
    :returns: A ``(len(rate), 2, 2)``-shaped tensor whose ``n``-th matrix is
        ``[[p, 1 - p], [1 - p, p]]`` with ``p = (1 + exp(-rate[n])) / 2``.
    :rtype: torch.Tensor
    """
    keep = rate.neg().exp().mul(0.5).add(0.5)
    flip = 1 - keep
    return torch.stack(
        [torch.stack([keep, flip], dim=-1), torch.stack([flip, keep], dim=-1)],
        dim=-2,
    )


def make_reproduction_tensor(
    clusters: torch.Tensor,
    *,
    crossover_rate: torch.Tensor,
    mutation_rate: torch.Tensor,
) -> torch.Tensor:
    """
    Computes pairwise conditional probabilities of sexual reproduction
    (crossover + mutation) using a pair HMM over genotypes.

    The two-valued hidden state selects which of the two parent clusters is
    compared against the child cluster at the current variant. Intermediate
    tensors are created with the dtype and device of the mutation matrices,
    so non-default dtypes and devices are respected.

    :param torch.Tensor clusters: A ``(num_variants, num_clusters)``-shaped
        array of clusters.
    :param torch.Tensor crossover_rate: A ``num_variants-1``-long vector of
        crossover rates between successive variants. A rate ``r`` gives a
        probability ``(1 - exp(-r)) / 2`` of switching hidden state.
    :param torch.Tensor mutation_rate: A ``num_variants``-long vector of
        mutation rates, one per variant site. A rate ``r`` gives a
        probability ``(1 - exp(-r)) / 2`` that parent and child alleles
        differ.
    :returns: A reproduction tensor of shape ``(C, C, C)`` where ``C`` is the
        number of clusters. This tensor is symmetric in the first two axes and
        a normalized probability mass function over the third axis.
    :rtype: torch.Tensor
    """
    P, C = clusters.shape
    assert crossover_rate.shape == (P - 1,)
    assert mutation_rate.shape == (P,)

    # Construct per-site mutation matrices and per-gap transition matrices.
    mutate = _symmetric_flip_matrices(mutation_rate)
    dtype, device = mutate.dtype, mutate.device
    transition = _symmetric_flip_matrices(crossover_rate).to(
        dtype=dtype, device=device
    )

    # Integer alleles are needed for indexing; convert once, outside the loop.
    alleles = clusters.to(device=device).long()

    # Apply pair HMM along each genotype.
    result = torch.zeros(C, C, C, dtype=dtype, device=device)
    state = torch.full((C, C, C, 2), 0.5, dtype=dtype, device=device)
    for site in tqdm.tqdm(range(P)):
        # Update with observation + mutation noise. Hidden state 0 compares
        # the child (last axis) with the first parent (axis 0); hidden state 1
        # compares it with the second parent (axis 1).
        c = alleles[site]
        emit = mutate[site]
        state[..., 0] *= emit[c[:, None, None], c]
        state[..., 1] *= emit[c[None, :, None], c]

        # Transition via crossover (there is no gap after the last site).
        if site < P - 1:
            state = state @ transition[site]

        # Numerically stabilize by moving norm to the result.
        total = state.sum(-1)
        state /= total.unsqueeze(-1)
        result += total.log()

    # Convert accumulated log-likelihoods to a probability tensor over the
    # last axis.
    result -= result.logsumexp(-1, True)
    result.exp_()
    return result

0 comments on commit 1034dc9

Please sign in to comment.