# readme.py — audacitorch example script (77 lines, 2.19 KB).
# NOTE: the original capture was a GitHub page scrape; navigation chrome and
# the duplicated line-number gutter have been removed.
import torch
import torch.nn as nn
class MyVolumeModel(nn.Module):
    """Toy effect model: amplify the incoming waveform by a factor of 2."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Doubling the amplitude stands in for real "neural net magic".
        amplified = 2 * x
        return amplified
from audacitorch import WaveformToWaveformBase
class MyVolumeModelWrapper(WaveformToWaveformBase):
    """Adapter exposing the volume model through the audacitorch interface."""

    def do_forward_pass(self, x: torch.Tensor) -> torch.Tensor:
        # Any preprocessing belongs here; x arrives as a waveform tensor
        # shaped (n_channels, n_samples).
        # Any postprocessing belongs here; the returned value must keep the
        # multichannel (n_channels, n_samples) waveform layout.
        return self.model(x)
# Plugin metadata consumed by the Audacity deep-learning effect loader.
metadata = {
    'sample_rate': 48000,
    'domain_tags': ['music', 'speech', 'environmental'],
    # Fix: the model computes x * 2, i.e. a gain of 20*log10(2) ~= 6 dB.
    # The original text claimed 3 dB, which would be a factor of ~1.41.
    'short_description': 'Use me to boost volume by 6dB :).',
    'long_description': 'This description can be a max of 280 characters aaaaaaaaaaaaaaaaaaaa.',
    'tags': ['volume boost'],
    'labels': ['boosted'],
    'effect_type': 'waveform-to-waveform',
    'multichannel': False,
}
from pathlib import Path
from audacitorch.utils import save_model, validate_metadata, \
    get_example_inputs, test_run

# Directory that will hold the serialized model + metadata bundle.
root = Path('booster-net')
root.mkdir(exist_ok=True, parents=True)

# Build the raw model and wrap it in the audacitorch adapter.
model = MyVolumeModel()
wrapper = MyVolumeModelWrapper(model)

# Serialize with torch.jit.script, torch.jit.trace, or a mix of both.
#
# Option 1 — torch.jit.script: preferred in most cases, though it may
# require substantial changes to the source code.
serialized_model = torch.jit.script(wrapper)

# Option 2 — torch.jit.trace: usually easier, but be extra careful that
# the traced model still behaves correctly. (Note: this reassignment
# replaces the scripted model from option 1.)
example_inputs = get_example_inputs()
serialized_model = torch.jit.trace(wrapper, example_inputs[0],
                                   check_inputs=example_inputs)

# Smoke-test the serialized model.
test_run(serialized_model)

# Confirm the metadata dict is well-formed before bundling.
success, msg = validate_metadata(metadata)
assert success

# Write the serialized model and its metadata to disk.
save_model(serialized_model, metadata, root)