-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathmodels.py
170 lines (140 loc) · 6.05 KB
/
models.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
import math
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
class SincConv_fast(nn.Module):
    """Sinc-based convolution (SincNet).

    Learns band-pass filters parameterized by their low cutoff frequency
    and bandwidth (both in Hz) instead of free-form kernel weights.

    Parameters
    ----------
    out_channels : `int`
        Number of filters.
    kernel_size : `int`
        Filter length. Forced to the next odd value so filters are
        perfectly symmetric.
    sample_rate : `int`, optional
        Sample rate. Defaults to 16000.
    in_channels : `int`
        Number of input channels. Must be 1.
    stride, padding, dilation : `int`
        Standard `torch.nn.Conv1d` arguments.
    bias : `bool`
        Unsupported; must be False.
    groups : `int`
        Unsupported; must be 1.
    min_low_hz, min_band_hz : `int`
        Lower bounds enforced on the learned cutoff and bandwidth.

    Usage
    -----
    See `torch.nn.Conv1d`

    Reference
    ---------
    Mirco Ravanelli, Yoshua Bengio,
    "Speaker Recognition from raw waveform with SincNet".
    https://arxiv.org/abs/1808.00158
    """

    @staticmethod
    def to_mel(hz):
        """Convert a frequency in Hz to the Mel scale."""
        return 2595 * np.log10(1 + hz / 700)

    @staticmethod
    def to_hz(mel):
        """Convert a Mel-scale value back to Hz."""
        return 700 * (10 ** (mel / 2595) - 1)

    def __init__(self, out_channels, kernel_size, sample_rate=16000, in_channels=1,
                 stride=1, padding=0, dilation=1, bias=False, groups=1, min_low_hz=50, min_band_hz=50):
        super(SincConv_fast, self).__init__()
        if in_channels != 1:
            # Bug fix: the previous "%"-formatting left literal braces in the
            # message, e.g. "in_channels = {2}".
            msg = ('SincConv only support one input channel '
                   '(here, in_channels = %d).' % in_channels)
            raise ValueError(msg)
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        # Forcing the filters to be odd (i.e., perfectly symmetric).
        if kernel_size % 2 == 0:
            self.kernel_size = self.kernel_size + 1
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        if bias:
            raise ValueError('SincConv does not support bias.')
        if groups > 1:
            raise ValueError('SincConv does not support groups.')
        self.sample_rate = sample_rate
        self.min_low_hz = min_low_hz
        self.min_band_hz = min_band_hz
        # Initialize filterbanks such that they are equally spaced in Mel scale.
        low_hz = 30
        high_hz = self.sample_rate / 2 - (self.min_low_hz + self.min_band_hz)
        mel = np.linspace(self.to_mel(low_hz),
                          self.to_mel(high_hz),
                          self.out_channels + 1)
        hz = self.to_hz(mel)
        # Filter lower frequency (out_channels, 1).
        self.low_hz_ = nn.Parameter(torch.Tensor(hz[:-1]).view(-1, 1))
        # Filter frequency band (out_channels, 1).
        self.band_hz_ = nn.Parameter(torch.Tensor(np.diff(hz)).view(-1, 1))
        # Hamming window: only the left half is computed; the right half
        # follows by symmetry in forward().  (1, kernel_size/2)
        n_lin = torch.linspace(0, (self.kernel_size / 2) - 1,
                               steps=int(self.kernel_size / 2))
        self.window_ = 0.54 - 0.46 * torch.cos(2 * math.pi * n_lin / self.kernel_size)
        # Negative half of the time axis scaled by 2*pi/sample_rate; due to
        # symmetry only half of the axis is needed.
        n = (self.kernel_size - 1) / 2.0
        self.n_ = 2 * math.pi * torch.arange(-n, 0).view(1, -1) / self.sample_rate

    def forward(self, waveforms):
        """
        Parameters
        ----------
        waveforms : `torch.Tensor` (batch_size, 1, n_samples)
            Batch of waveforms.

        Returns
        -------
        features : `torch.Tensor` (batch_size, out_channels, n_samples_out)
            Batch of sinc filters activations.
        """
        # Move cached halves to the input's device (they are plain tensors,
        # not buffers, so .to() on the module does not move them).
        self.n_ = self.n_.to(waveforms.device)
        self.window_ = self.window_.to(waveforms.device)
        low = self.min_low_hz + torch.abs(self.low_hz_)
        high = torch.clamp(low + self.min_band_hz + torch.abs(self.band_hz_),
                           self.min_low_hz, self.sample_rate / 2)
        band = (high - low)[:, 0]
        f_times_t_low = torch.matmul(low, self.n_)
        f_times_t_high = torch.matmul(high, self.n_)
        # Equivalent of Eq. 4 of the reference paper (SPEAKER RECOGNITION FROM
        # RAW WAVEFORM WITH SINCNET): the sinc difference is expanded and
        # simplified to avoid several useless computations.
        band_pass_left = ((torch.sin(f_times_t_high) - torch.sin(f_times_t_low))
                          / (self.n_ / 2)) * self.window_
        band_pass_center = 2 * band.view(-1, 1)
        band_pass_right = torch.flip(band_pass_left, dims=[1])
        band_pass = torch.cat([band_pass_left, band_pass_center, band_pass_right], dim=1)
        band_pass = band_pass / (2 * band[:, None])
        self.filters = band_pass.view(self.out_channels, 1, self.kernel_size)
        return F.conv1d(waveforms, self.filters, stride=self.stride,
                        padding=self.padding, dilation=self.dilation,
                        bias=None, groups=1)
class myResnet(nn.Module):
    """ResNet-18 backbone with the final FC layer replaced by a 10-way head.

    Parameters
    ----------
    pretrained : `bool`, optional
        Whether to load ImageNet-pretrained weights. Defaults to True.
        Bug fix: this flag was previously ignored and pretrained weights
        were always loaded.
    """

    def __init__(self, pretrained=True):
        super(myResnet, self).__init__()
        # Forward the caller's choice instead of hard-coding pretrained=True.
        self.model = models.resnet18(pretrained=pretrained)
        # Replace the 1000-way ImageNet head with a 10-class head
        # (512 is resnet18's final feature dimension).
        self.model.fc = nn.Linear(512, 10, bias=True)

    def forward(self, x):
        """Run the classifier; returns (batch, 10) logits."""
        x = self.model(x)
        return x
class MS_SincResNet(nn.Module):
    """Multi-scale SincNet front-end feeding a ResNet-18 classifier.

    Three SincConv branches with different kernel lengths (251, 501, 1001)
    each map a layer-normalized waveform to a (batch, 160, 1024) feature
    map; the three maps are stacked as channels of a 2-D input for the
    ResNet backbone.
    """

    def __init__(self):
        super(MS_SincResNet, self).__init__()
        self.layerNorm = nn.LayerNorm([1, 48000])

        def _branch(kernel_size):
            # One branch: sinc filterbank -> batch norm -> ReLU -> pooling.
            return nn.Sequential(
                SincConv_fast(out_channels=160, kernel_size=kernel_size),
                nn.BatchNorm1d(160),
                nn.ReLU(inplace=True),
                nn.AdaptiveAvgPool1d(1024))

        self.sincNet1 = _branch(251)
        self.sincNet2 = _branch(501)
        self.sincNet3 = _branch(1001)
        self.resnet = myResnet(pretrained=True)

    def forward(self, x):
        x = self.layerNorm(x)
        feats = [self.sincNet1(x), self.sincNet2(x), self.sincNet3(x)]
        # In-place unsqueeze on purpose: the per-branch features returned
        # to the caller carry the extra channel dimension as well.
        for feat in feats:
            feat.unsqueeze_(dim=1)
        out = self.resnet(torch.cat(feats, dim=1))
        return out, feats[0], feats[1], feats[2]