From c0d9ed3d18b708956d9d9d7c43b0c591d66db996 Mon Sep 17 00:00:00 2001 From: Enno Hermann Date: Mon, 9 Dec 2024 16:13:13 +0100 Subject: [PATCH] fix: handle difference in xtts/tortoise attention (#199) --- TTS/tts/layers/tortoise/arch_utils.py | 17 ++-- TTS/tts/layers/tortoise/autoregressive.py | 4 +- TTS/tts/layers/tortoise/classifier.py | 2 +- TTS/tts/layers/tortoise/diffusion_decoder.py | 21 +++-- TTS/tts/layers/xtts/latent_encoder.py | 95 -------------------- 5 files changed, 28 insertions(+), 111 deletions(-) delete mode 100644 TTS/tts/layers/xtts/latent_encoder.py diff --git a/TTS/tts/layers/tortoise/arch_utils.py b/TTS/tts/layers/tortoise/arch_utils.py index 4c3733e691..1bbf676393 100644 --- a/TTS/tts/layers/tortoise/arch_utils.py +++ b/TTS/tts/layers/tortoise/arch_utils.py @@ -70,11 +70,10 @@ def forward(self, qkv, mask=None, rel_pos=None): weight = rel_pos(weight.reshape(bs, self.n_heads, weight.shape[-2], weight.shape[-1])).reshape( bs * self.n_heads, weight.shape[-2], weight.shape[-1] ) - weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype) if mask is not None: - # The proper way to do this is to mask before the softmax using -inf, but that doesn't work properly on CPUs. - mask = mask.repeat(self.n_heads, 1).unsqueeze(1) - weight = weight * mask + mask = mask.repeat(self.n_heads, 1, 1) + weight[mask.logical_not()] = -torch.inf + weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype) a = torch.einsum("bts,bcs->bct", weight, v) return a.reshape(bs, -1, length) @@ -93,7 +92,9 @@ def __init__( channels, num_heads=1, num_head_channels=-1, + *, relative_pos_embeddings=False, + tortoise_norm=False, ): super().__init__() self.channels = channels @@ -108,6 +109,7 @@ def __init__( self.qkv = nn.Conv1d(channels, channels * 3, 1) # split heads before split qkv self.attention = QKVAttentionLegacy(self.num_heads) + self.tortoise_norm = tortoise_norm self.proj_out = zero_module(nn.Conv1d(channels, channels, 1)) if relative_pos_embeddings: @@ -124,10 +126,13 @@ def __init__( def forward(self, x, mask=None): b, c, *spatial = x.shape x = x.reshape(b, c, -1) - qkv = self.qkv(self.norm(x)) + x_norm = self.norm(x) + qkv = self.qkv(x_norm) h = self.attention(qkv, mask, self.relative_pos_embeddings) h = self.proj_out(h) - return (x + h).reshape(b, c, *spatial) + if self.tortoise_norm: + return (x + h).reshape(b, c, *spatial) + return (x_norm + h).reshape(b, c, *spatial) class Upsample(nn.Module): diff --git a/TTS/tts/layers/tortoise/autoregressive.py b/TTS/tts/layers/tortoise/autoregressive.py index 07cf3d542b..00c884e973 100644 --- a/TTS/tts/layers/tortoise/autoregressive.py +++ b/TTS/tts/layers/tortoise/autoregressive.py @@ -176,12 +176,14 @@ def __init__( embedding_dim, attn_blocks=6, num_attn_heads=4, + *, + tortoise_norm=False, ): super().__init__() attn = [] self.init = nn.Conv1d(spec_dim, embedding_dim, kernel_size=1) for a in range(attn_blocks): - attn.append(AttentionBlock(embedding_dim, num_attn_heads)) + attn.append(AttentionBlock(embedding_dim, num_attn_heads, tortoise_norm=tortoise_norm)) self.attn = nn.Sequential(*attn) self.dim = embedding_dim diff --git a/TTS/tts/layers/tortoise/classifier.py b/TTS/tts/layers/tortoise/classifier.py index c72834e9a8..337323db67 100644 --- a/TTS/tts/layers/tortoise/classifier.py +++ b/TTS/tts/layers/tortoise/classifier.py @@ -97,7 +97,7 @@ def __init__( self.final = nn.Sequential(normalization(ch), nn.SiLU(), nn.Conv1d(ch, embedding_dim, 1)) attn = [] for a in range(attn_blocks): - attn.append(AttentionBlock(embedding_dim, num_attn_heads)) + attn.append(AttentionBlock(embedding_dim, num_attn_heads, tortoise_norm=True)) self.attn = nn.Sequential(*attn) self.dim = embedding_dim diff --git a/TTS/tts/layers/tortoise/diffusion_decoder.py b/TTS/tts/layers/tortoise/diffusion_decoder.py index 15bbfb7121..cfdeaff8bb 100644 --- a/TTS/tts/layers/tortoise/diffusion_decoder.py +++ b/TTS/tts/layers/tortoise/diffusion_decoder.py @@ -130,7 +130,7 @@ def __init__(self, model_channels, dropout, num_heads): dims=1, use_scale_shift_norm=True, ) - self.attn = AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True) + self.attn = AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True, tortoise_norm=True) def forward(self, x, time_emb): y = self.resblk(x, time_emb) @@ -177,17 +177,17 @@ def __init__( # transformer network. self.code_embedding = nn.Embedding(in_tokens, model_channels) self.code_converter = nn.Sequential( - AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True), - AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True), - AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True), + AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True, tortoise_norm=True), + AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True, tortoise_norm=True), + AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True, tortoise_norm=True), ) self.code_norm = normalization(model_channels) self.latent_conditioner = nn.Sequential( nn.Conv1d(in_latent_channels, model_channels, 3, padding=1), - AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True), - AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True), - AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True), - AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True), + AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True, tortoise_norm=True), + AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True, tortoise_norm=True), + AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True, tortoise_norm=True), + AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True, tortoise_norm=True), ) self.contextual_embedder = nn.Sequential( nn.Conv1d(in_channels, model_channels, 3, padding=1, stride=2), @@ -196,26 +196,31 @@ def __init__( model_channels * 2, num_heads, relative_pos_embeddings=True, + tortoise_norm=True, ), AttentionBlock( model_channels * 2, num_heads, relative_pos_embeddings=True, + tortoise_norm=True, ), AttentionBlock( model_channels * 2, num_heads, relative_pos_embeddings=True, + tortoise_norm=True, ), AttentionBlock( model_channels * 2, num_heads, relative_pos_embeddings=True, + tortoise_norm=True, ), AttentionBlock( model_channels * 2, num_heads, relative_pos_embeddings=True, + tortoise_norm=True, ), ) self.unconditioned_embedding = nn.Parameter(torch.randn(1, model_channels, 1)) diff --git a/TTS/tts/layers/xtts/latent_encoder.py b/TTS/tts/layers/xtts/latent_encoder.py deleted file mode 100644 index 6becffb8b7..0000000000 --- a/TTS/tts/layers/xtts/latent_encoder.py +++ /dev/null @@ -1,95 +0,0 @@ -# ported from: Originally ported from: https://github.com/neonbjb/tortoise-tts - -import math - -import torch -from torch import nn -from torch.nn import functional as F - -from TTS.tts.layers.tortoise.arch_utils import normalization, zero_module - - -def conv_nd(dims, *args, **kwargs): - if dims == 1: - return nn.Conv1d(*args, **kwargs) - elif dims == 2: - return nn.Conv2d(*args, **kwargs) - elif dims == 3: - return nn.Conv3d(*args, **kwargs) - raise ValueError(f"unsupported dimensions: {dims}") - - -class QKVAttention(nn.Module): - def __init__(self, n_heads): - super().__init__() - self.n_heads = n_heads - - def forward(self, qkv, mask=None, qk_bias=0): - """ - Apply QKV attention. - - :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs. - :return: an [N x (H * C) x T] tensor after attention. - """ - bs, width, length = qkv.shape - assert width % (3 * self.n_heads) == 0 - ch = width // (3 * self.n_heads) - q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1) - scale = 1 / math.sqrt(math.sqrt(ch)) - weight = torch.einsum("bct,bcs->bts", q * scale, k * scale) # More stable with f16 than dividing afterwards - weight = weight + qk_bias - if mask is not None: - mask = mask.repeat(self.n_heads, 1, 1) - weight[mask.logical_not()] = -torch.inf - weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype) - a = torch.einsum("bts,bcs->bct", weight, v) - - return a.reshape(bs, -1, length) - - -class AttentionBlock(nn.Module): - """An attention block that allows spatial positions to attend to each other.""" - - def __init__( - self, - channels, - num_heads=1, - num_head_channels=-1, - out_channels=None, - do_activation=False, - ): - super().__init__() - self.channels = channels - out_channels = channels if out_channels is None else out_channels - self.do_activation = do_activation - if num_head_channels == -1: - self.num_heads = num_heads - else: - assert ( - channels % num_head_channels == 0 - ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}" - self.num_heads = channels // num_head_channels - self.norm = normalization(channels) - self.qkv = conv_nd(1, channels, out_channels * 3, 1) - self.attention = QKVAttention(self.num_heads) - - self.x_proj = nn.Identity() if out_channels == channels else conv_nd(1, channels, out_channels, 1) - self.proj_out = zero_module(conv_nd(1, out_channels, out_channels, 1)) - - def forward(self, x, mask=None, qk_bias=0): - b, c, *spatial = x.shape - if mask is not None: - if len(mask.shape) == 2: - mask = mask.unsqueeze(0).repeat(x.shape[0], 1, 1) - if mask.shape[1] != x.shape[-1]: - mask = mask[:, : x.shape[-1], : x.shape[-1]] - - x = x.reshape(b, c, -1) - x = self.norm(x) - if self.do_activation: - x = F.silu(x, inplace=True) - qkv = self.qkv(x) - h = self.attention(qkv, mask=mask, qk_bias=qk_bias) - h = self.proj_out(h) - xp = self.x_proj(x) - return (xp + h).reshape(b, xp.shape[1], *spatial)