Skip to content

Commit

Permalink
Relax mutation_rate parameter
Browse files Browse the repository at this point in the history
  • Loading branch information
fritzo committed Feb 18, 2022
1 parent b22b945 commit 94a3af6
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions tspyro/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,15 +265,16 @@ def make_reproduction_tensor(
array of clusters.
:param torch.Tensor crossover_rate: A ``num_variants-1``-long vector of
probabilties of crossover between successive variants.
:param torch.Tensor mutation_rate: A ``num_variants``-long vector of
mutation probabilites of mutation at each variant site.
:param torch.Tensor mutation_rate: A scalar or ``num_variants``-long vector
of mutation probabilites of mutation at each variant site.
:returns: A reproduction tensor of shape ``(C, C, C)`` where ``C`` is the
number of clusters. This tensor is symmetric in the first two axes and
a normalized probability mass function over the third axis.
:rtype: torch.Tensor
"""
P, C = clusters.shape
assert crossover_rate.shape == (P - 1,)
mutation_rate = torch.as_tensor(mutation_rate, dtype=torch.float).expand(P)
assert mutation_rate.shape == (P,)

# Construct a transition matrix.
Expand Down

0 comments on commit 94a3af6

Please sign in to comment.