This repository has been archived by the owner on Sep 1, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 138
/
Copy pathdump_mfcc_feature.py
117 lines (95 loc) · 3.72 KB
/
dump_mfcc_feature.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import logging
import math
import os
import sys
import soundfile as sf
import torch
import torchaudio
import tqdm
from npy_append_array import NpyAppendArray
# Send log records to stdout with timestamps. Verbosity is taken from the
# LOGLEVEL environment variable and defaults to INFO.
_log_level = os.getenv("LOGLEVEL", "INFO").upper()
logging.basicConfig(
    level=_log_level,
    format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    stream=sys.stdout,
)
# Named logger shared by every function in this script.
logger = logging.getLogger("dump_mfcc_feature")
class MfccFeatureReader(object):
    """Extract MFCC + delta + delta-delta features from audio files.

    MFCCs are computed with ``torchaudio.compliance.kaldi.mfcc`` (with
    ``use_energy=False``). First- and second-order deltas are then stacked
    next to them.
    """

    def __init__(self, sample_rate):
        # Sampling rate every input file must have; ``read_audio`` asserts it.
        self.sample_rate = sample_rate

    def read_audio(self, path, ref_len=None):
        """Load ``path`` with ``soundfile`` and return a 1-D waveform.

        Two-dimensional audio is down-mixed by averaging over the last axis.
        soundfile returns multi-channel data as ``(frames, channels)``.

        Args:
            path: Audio file readable by ``soundfile.read``.
            ref_len: Optional expected number of samples. A warning is logged
                when the decoded length differs from it by more than 160
                samples.

        Returns:
            The 1-D numpy waveform.

        Raises:
            AssertionError: If the file's sample rate is not
                ``self.sample_rate``, or the decoded data is not 1-D after
                down-mixing.
        """
        wav, sr = sf.read(path)
        assert sr == self.sample_rate, sr
        if wav.ndim == 2:
            wav = wav.mean(-1)
        assert wav.ndim == 1, wav.ndim
        # Small length mismatches against the manifest are tolerated
        # (160 samples is 10 ms at 16 kHz); larger ones are only reported.
        if ref_len is not None and abs(ref_len - len(wav)) > 160:
            # Use the module logger (not the root logger) so the record is
            # tagged with this script's logger name like every other message.
            logger.warning(f"ref {ref_len} != read {len(wav)} ({path})")
        return wav

    def get_feats(self, path, ref_len=None):
        """Compute stacked MFCC, delta and delta-delta features for ``path``.

        Args:
            path: Audio file to read (see ``read_audio``).
            ref_len: Optional expected sample count, forwarded to
                ``read_audio``.

        Returns:
            A contiguous float32 tensor of shape ``(time, 3 * n_ceps)``. Its
            columns are ``[mfcc, delta, delta-delta]``. ``n_ceps`` is the
            cepstral count produced by ``kaldi.mfcc`` (13 by default).
        """
        x = self.read_audio(path, ref_len)
        with torch.no_grad():
            x = torch.from_numpy(x).float()
            x = x.view(1, -1)  # kaldi.mfcc expects a (channel, samples) input
            mfccs = torchaudio.compliance.kaldi.mfcc(
                waveform=x,
                sample_frequency=self.sample_rate,
                use_energy=False,
            )  # (time, freq)
            # compute_deltas differentiates along the last axis, so time has
            # to be last.
            mfccs = mfccs.transpose(0, 1)  # (freq, time)
            deltas = torchaudio.functional.compute_deltas(mfccs)
            ddeltas = torchaudio.functional.compute_deltas(deltas)
            concat = torch.cat([mfccs, deltas, ddeltas], dim=0)
            concat = concat.transpose(0, 1).contiguous()  # (time, freq)
        return concat
def get_path_iterator(tsv, nshard, rank):
    """Select shard ``rank`` of ``nshard`` from a manifest TSV.

    The first line of ``tsv`` is a root directory. Every following line must
    have exactly five tab-separated fields. Of these, the third
    (``wav_path``) is joined onto the root and the fifth (``nsample_wav``)
    is parsed as an int.

    The manifest lines are cut into contiguous shards of
    ``ceil(total / nshard)`` lines each.

    Args:
        tsv: Path of the manifest file.
        nshard: Total number of shards.
        rank: Index of the shard to select (0-based).

    Returns:
        ``(iterate, n)``. ``iterate()`` returns a generator of
        ``(f"{root}/{wav_path}", int(nsample_wav))`` tuples for the shard,
        and ``n`` is the number of lines in the shard.

    Raises:
        AssertionError: If the selected shard is empty.
    """
    with open(tsv, "r") as f:
        root = f.readline().rstrip()
        lines = [line.rstrip() for line in f]
    tot = len(lines)
    shard_size = math.ceil(tot / nshard)
    start, end = rank * shard_size, min((rank + 1) * shard_size, tot)
    # f-string so the actual bounds appear in the failure message.
    assert start < end, f"start={start}, end={end}"
    logger.info(
        f"rank {rank} of {nshard}, process {end-start} "
        f"({start}-{end}) out of {tot}"
    )
    lines = lines[start:end]

    def iterate():
        for line in lines:
            _, _, wav_path, _, nsample_wav = line.split("\t")
            yield f"{root}/{wav_path}", int(nsample_wav)

    return iterate, len(lines)
def dump_feature(tsv_dir, split, nshard, rank, feat_dir, sample_rate=16_000):
    """Extract MFCC features for one shard of ``{tsv_dir}/{split}.tsv``.

    Two files are written into ``feat_dir`` (created if missing):

    * ``{split}_{rank}_{nshard}.npy``: each utterance's feature matrix is
      appended via ``NpyAppendArray``, in manifest order.
    * ``{split}_{rank}_{nshard}.len``: one line per utterance holding
      ``len(feat)``, i.e. that utterance's number of feature rows.

    An existing ``.npy`` for the shard is deleted first so that appending
    starts from an empty file. The ``.len`` file is truncated by opening it
    in write mode.

    Args:
        tsv_dir: Directory containing ``{split}.tsv``.
        split: Manifest name (without ``.tsv``).
        nshard: Total number of shards.
        rank: Shard index to process.
        feat_dir: Output directory.
        sample_rate: Sampling rate every audio file must have.
    """
    reader = MfccFeatureReader(sample_rate)
    generator, num = get_path_iterator(f"{tsv_dir}/{split}.tsv", nshard, rank)
    iterator = generator()

    feat_path = f"{feat_dir}/{split}_{rank}_{nshard}.npy"
    leng_path = f"{feat_dir}/{split}_{rank}_{nshard}.len"

    os.makedirs(feat_dir, exist_ok=True)
    if os.path.exists(feat_path):
        os.remove(feat_path)

    feat_f = NpyAppendArray(feat_path)
    try:
        with open(leng_path, "w") as leng_f:
            for path, nsample in tqdm.tqdm(iterator, total=num):
                feat = reader.get_feats(path, nsample)
                feat_f.append(feat.cpu().numpy())
                leng_f.write(f"{len(feat)}\n")
    finally:
        # Release the .npy handle deterministically, even if extraction
        # fails part-way, instead of relying on garbage collection. This is
        # guarded in case the installed npy_append_array version does not
        # provide close().
        close = getattr(feat_f, "close", None)
        if close is not None:
            close()
    logger.info("finished successfully")
if __name__ == "__main__":
    import argparse

    cli = argparse.ArgumentParser()
    # Positional arguments in command-line order. Their names match
    # dump_feature's parameters so the parsed namespace can be splatted
    # straight into it. A type of None keeps the raw string.
    for positional, kind in (
        ("tsv_dir", None),
        ("split", None),
        ("nshard", int),
        ("rank", int),
        ("feat_dir", None),
    ):
        cli.add_argument(positional, type=kind)
    cli.add_argument("--sample_rate", type=int, default=16000)

    parsed = cli.parse_args()
    logger.info(parsed)
    dump_feature(**vars(parsed))