fix: handle difference in xtts/tortoise attention (#199)
eginhard authored Dec 9, 2024 · 1 parent b545ab8 · commit c0d9ed3
Showing 5 changed files with 28 additions and 111 deletions.
17 changes: 11 additions & 6 deletions TTS/tts/layers/tortoise/arch_utils.py
@@ -70,11 +70,10 @@ def forward(self, qkv, mask=None, rel_pos=None):
             weight = rel_pos(weight.reshape(bs, self.n_heads, weight.shape[-2], weight.shape[-1])).reshape(
                 bs * self.n_heads, weight.shape[-2], weight.shape[-1]
             )
-        weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
         if mask is not None:
-            # The proper way to do this is to mask before the softmax using -inf, but that doesn't work properly on CPUs.
-            mask = mask.repeat(self.n_heads, 1).unsqueeze(1)
-            weight = weight * mask
+            mask = mask.repeat(self.n_heads, 1, 1)
+            weight[mask.logical_not()] = -torch.inf
+        weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
         a = torch.einsum("bts,bcs->bct", weight, v)
 
         return a.reshape(bs, -1, length)
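
For context on the change above: multiplying the attention weights by a 0/1 mask after the softmax (the old behaviour) zeroes the masked positions but leaves each row of weights unnormalized, whereas setting the masked logits to -inf before the softmax renormalizes over the visible positions only. A minimal standalone sketch of the difference (illustrative shapes only, not code from the repository; it uses masked_fill instead of the patch's in-place assignment):

import torch

# Toy attention weights: (batch * heads, target_len, source_len).
weight = torch.randn(2, 4, 4)
# Boolean mask hiding the last two source positions (True = keep).
mask = torch.tensor([True, True, False, False]).expand(2, 4, 4)

# Old behaviour: softmax first, then zero out the masked weights.
post = torch.softmax(weight, dim=-1) * mask

# New behaviour: push masked logits to -inf, then softmax.
pre = torch.softmax(weight.masked_fill(~mask, float("-inf")), dim=-1)

print(post.sum(dim=-1))  # rows sum to less than 1 wherever positions are masked
print(pre.sum(dim=-1))   # rows sum to 1: the weights renormalize over the visible positions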
@@ -93,7 +92,9 @@ def __init__(
         channels,
         num_heads=1,
         num_head_channels=-1,
+        *,
         relative_pos_embeddings=False,
+        tortoise_norm=False,
     ):
         super().__init__()
         self.channels = channels
@@ -108,6 +109,7 @@ def __init__(
         self.qkv = nn.Conv1d(channels, channels * 3, 1)
         # split heads before split qkv
         self.attention = QKVAttentionLegacy(self.num_heads)
+        self.tortoise_norm = tortoise_norm
 
         self.proj_out = zero_module(nn.Conv1d(channels, channels, 1))
         if relative_pos_embeddings:
@@ -124,10 +126,13 @@ def __init__(
     def forward(self, x, mask=None):
         b, c, *spatial = x.shape
         x = x.reshape(b, c, -1)
-        qkv = self.qkv(self.norm(x))
+        x_norm = self.norm(x)
+        qkv = self.qkv(x_norm)
         h = self.attention(qkv, mask, self.relative_pos_embeddings)
         h = self.proj_out(h)
-        return (x + h).reshape(b, c, *spatial)
+        if self.tortoise_norm:
+            return (x + h).reshape(b, c, *spatial)
+        return (x_norm + h).reshape(b, c, *spatial)
 
 
 class Upsample(nn.Module):
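
The new keyword-only tortoise_norm flag selects which tensor feeds the residual connection: with tortoise_norm=True the block keeps the original Tortoise behaviour and adds the attention output to the raw input (x + h), while the new default adds it to the normalized input (x_norm + h), which appears to match the XTTS copy of this block that is deleted later in this commit. A self-contained sketch of just the residual wiring (norm and attn below are illustrative placeholders, not the real QKV attention):

import torch
from torch import nn

def attention_block_residual(x, norm, attn, *, tortoise_norm=False):
    # Illustrative only: norm and attn stand in for the block's GroupNorm
    # and QKV attention + output projection; only the residual wiring differs.
    x_norm = norm(x)
    h = attn(x_norm)
    if tortoise_norm:
        return x + h  # Tortoise: residual taken from the raw input
    return x_norm + h  # new default: residual taken from the normalized input

channels = 64
norm = nn.GroupNorm(32, channels)
attn = nn.Conv1d(channels, channels, 1)  # placeholder for the real attention + projection
x = torch.randn(2, channels, 50)

out_tortoise = attention_block_residual(x, norm, attn, tortoise_norm=True)
out_default = attention_block_residual(x, norm, attn, tortoise_norm=False)
print(torch.allclose(out_tortoise, out_default))  # False: the two variants are not interchangeable

The remaining files in this commit pass tortoise_norm=True at Tortoise's own call sites (classifier.py, diffusion_decoder.py) or thread the flag through as a constructor argument (autoregressive.py), which appears to be what lets the duplicated copy in TTS/tts/layers/xtts/latent_encoder.py be removed.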
4 changes: 3 additions & 1 deletion TTS/tts/layers/tortoise/autoregressive.py
@@ -176,12 +176,14 @@ def __init__(
         embedding_dim,
         attn_blocks=6,
         num_attn_heads=4,
+        *,
+        tortoise_norm=False,
     ):
         super().__init__()
         attn = []
         self.init = nn.Conv1d(spec_dim, embedding_dim, kernel_size=1)
         for a in range(attn_blocks):
-            attn.append(AttentionBlock(embedding_dim, num_attn_heads))
+            attn.append(AttentionBlock(embedding_dim, num_attn_heads, tortoise_norm=tortoise_norm))
         self.attn = nn.Sequential(*attn)
         self.dim = embedding_dim
 
2 changes: 1 addition & 1 deletion TTS/tts/layers/tortoise/classifier.py
@@ -97,7 +97,7 @@ def __init__(
         self.final = nn.Sequential(normalization(ch), nn.SiLU(), nn.Conv1d(ch, embedding_dim, 1))
         attn = []
         for a in range(attn_blocks):
-            attn.append(AttentionBlock(embedding_dim, num_attn_heads))
+            attn.append(AttentionBlock(embedding_dim, num_attn_heads, tortoise_norm=True))
         self.attn = nn.Sequential(*attn)
         self.dim = embedding_dim
 
21 changes: 13 additions & 8 deletions TTS/tts/layers/tortoise/diffusion_decoder.py
@@ -130,7 +130,7 @@ def __init__(self, model_channels, dropout, num_heads):
             dims=1,
             use_scale_shift_norm=True,
         )
-        self.attn = AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True)
+        self.attn = AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True, tortoise_norm=True)
 
     def forward(self, x, time_emb):
         y = self.resblk(x, time_emb)
@@ -177,17 +177,17 @@ def __init__(
         # transformer network.
         self.code_embedding = nn.Embedding(in_tokens, model_channels)
         self.code_converter = nn.Sequential(
-            AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
-            AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
-            AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
+            AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True, tortoise_norm=True),
+            AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True, tortoise_norm=True),
+            AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True, tortoise_norm=True),
         )
         self.code_norm = normalization(model_channels)
         self.latent_conditioner = nn.Sequential(
             nn.Conv1d(in_latent_channels, model_channels, 3, padding=1),
-            AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
-            AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
-            AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
-            AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
+            AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True, tortoise_norm=True),
+            AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True, tortoise_norm=True),
+            AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True, tortoise_norm=True),
+            AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True, tortoise_norm=True),
         )
         self.contextual_embedder = nn.Sequential(
             nn.Conv1d(in_channels, model_channels, 3, padding=1, stride=2),
@@ -196,26 +196,31 @@ def __init__(
                 model_channels * 2,
                 num_heads,
                 relative_pos_embeddings=True,
+                tortoise_norm=True,
             ),
             AttentionBlock(
                 model_channels * 2,
                 num_heads,
                 relative_pos_embeddings=True,
+                tortoise_norm=True,
             ),
             AttentionBlock(
                 model_channels * 2,
                 num_heads,
                 relative_pos_embeddings=True,
+                tortoise_norm=True,
             ),
             AttentionBlock(
                 model_channels * 2,
                 num_heads,
                 relative_pos_embeddings=True,
+                tortoise_norm=True,
             ),
             AttentionBlock(
                 model_channels * 2,
                 num_heads,
                 relative_pos_embeddings=True,
+                tortoise_norm=True,
             ),
         )
         self.unconditioned_embedding = nn.Parameter(torch.randn(1, model_channels, 1))
95 changes: 0 additions & 95 deletions TTS/tts/layers/xtts/latent_encoder.py

This file was deleted.
