-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathasr_utils.py
49 lines (44 loc) · 2.41 KB
/
asr_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
import alex_asr.fst as fst
from math import *
# Calibration buckets as (lower bound, upper bound, replacement value).
# A score falls into a bucket when lower <= score < upper; together the
# buckets cover the half-open range [0.0, 2.0).
# NOTE(review): the replacement values look like empirically fitted
# posteriors -- confirm their origin before editing them.
_CALIBRATION_TABLE = (
    (0.999999999999975, 2.0, 0.88018348094713794),
    (0.9999999999961884, 0.999999999999975, 0.80349730129423402),
    (0.999999999830834, 0.9999999999961884, 0.73428284555412904),
    (0.9999999953944774, 0.999999999830834, 0.67181212520944233),
    (0.9999999061531399, 0.9999999953944774, 0.61542808146230965),
    (0.9999985463132696, 0.9999999061531399, 0.5645376743766013),
    (0.9999801852444311, 0.9999985463132696, 0.51860564536208376),
    (0.9997798377111068, 0.9999801852444311, 0.47714888739515909),
    (0.9979068550339635, 0.9997798377111068, 0.43973136376290312),
    (0.9855103527780578, 0.9979068550339635, 0.40595952188642426),
    (0.9107043589029054, 0.9855103527780578, 0.37547815398674106),
    (0.6495785364198595, 0.9107043589029054, 0.34796666105620766),
    (0.29778636903306327, 0.6495785364198595, 0.32313568084043276),
    (0.08195954925619552, 0.29778636903306327, 0.30072404436424938),
    (0.017934535127822632, 0.08195954925619552, 0.28049602899087711),
    (0.0008663571990016966, 0.017934535127822632, 0.24576056846402694),
    (4.3793087001897236e-05, 0.0008663571990016966, 0.21746408465901332),
    (2.438732279388375e-06, 4.3793087001897236e-05, 0.19441297252222531),
    (3.773110309325966e-08, 2.438732279388375e-06, 0.16759461792800706),
    (1.898540256696829e-10, 3.773110309325966e-08, 0.14254056133908768),
    (9.034359693017792e-14, 1.898540256696829e-10, 0.11983387028320044),
    (3.668374817174373e-20, 9.034359693017792e-14, 0.099474926398142052),
    (0.0, 3.668374817174373e-20, 0.096552470524865874),
)


def _find_approx(weight):
    """Return the calibration-table replacement for ``weight``.

    ``weight`` is matched against the half-open buckets of
    ``_CALIBRATION_TABLE``.  When no bucket contains it (for example a value
    of 2.0 or more, a negative value, or NaN) a warning is printed and
    ``weight`` is returned unchanged.
    """
    for lower, upper, replacement in _CALIBRATION_TABLE:
        if lower <= weight < upper:
            return replacement
    # Parenthesised single-argument form prints the same text on both
    # Python 2 and Python 3.
    print("Lattice calibration warning: cannot map input score.")
    return weight


def lattice_calibration(lat):
    """Calibrate and renormalise the arc weights of ``lat`` in place.

    Every arc weight is read as a negative natural log, converted back with
    ``exp(-weight)``, replaced by its calibration-table value, and then
    divided by the sum of the replaced values over the arcs of the same
    state, so the outgoing arcs of each state sum to one in the linear
    domain.  Results are written back as ``fst.LogWeight`` values.

    Args:
        lat: object exposing ``states``; each state exposes ``arcs`` and
            each arc has an assignable ``weight`` convertible with
            ``float()``.

    Returns:
        The same ``lat`` object, after its arcs have been modified.
    """
    for state in lat.states:
        cum = 0.0
        # First pass: swap each score for its calibrated value and
        # accumulate the per-state normaliser.
        for arc in state.arcs:
            weight = exp(-float(arc.weight))
            aprx = _find_approx(weight)
            cum += aprx
            arc.weight = fst.LogWeight(-log(aprx))
        # Second pass: the normaliser is only known once every arc of the
        # state has been visited, so the division happens in a second sweep
        # over the weights stored above.
        for arc in state.arcs:
            arc.weight = fst.LogWeight(-log(exp(-float(arc.weight)) / cum))
    return lat