Skip to content

Commit

Permalink
Add need_repeat_flag in phone based ctc graph compiler (#727)
Browse files Browse the repository at this point in the history
* Fix is_repeat_token in icefall

* Fix phone based recipe

* Update egs/librispeech/ASR/conformer_ctc3/train.py

Co-authored-by: Fangjun Kuang <[email protected]>

* Fix black

Co-authored-by: Fangjun Kuang <[email protected]>
  • Loading branch information
pkufool and csukuangfj authored Dec 4, 2022
1 parent e6a6727 commit c25c8c6
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 4 deletions.
1 change: 1 addition & 0 deletions egs/librispeech/ASR/conformer_ctc3/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -890,6 +890,7 @@ def run(rank, world_size, args):
graph_compiler = CtcTrainingGraphCompiler(
lexicon,
device=device,
need_repeat_flag=params.delay_penalty > 0,
)
# Manually add the sos/eos ID with their default values
# from the BPE recipe which we're adapting here.
Expand Down
18 changes: 14 additions & 4 deletions icefall/graph_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ def __init__(
lexicon: Lexicon,
device: torch.device,
oov: str = "<UNK>",
need_repeat_flag: bool = False,
):
"""
Args:
Expand All @@ -39,6 +40,13 @@ def __init__(
oov:
Out of vocabulary word. When a word in the transcript
does not exist in the lexicon, it is replaced with `oov`.
need_repeat_flag:
If True, will add an attribute named `_is_repeat_token_` to ctc_topo
indicating whether this token is a repeat token in ctc graph.
This attribute is needed to implement delay-penalty for phone-based
ctc loss. See https://github.com/k2-fsa/k2/pull/1086 for more
details. Note: The above change MUST be included in k2 to open this
flag.
"""
L_inv = lexicon.L_inv.to(device)
assert L_inv.requires_grad is False
Expand All @@ -53,6 +61,12 @@ def __init__(
ctc_topo = k2.ctc_topo(max_token_id, modified=False)

self.ctc_topo = ctc_topo.to(device)

if need_repeat_flag:
self.ctc_topo._is_repeat_token_ = (
self.ctc_topo.labels != self.ctc_topo.aux_labels
)

self.device = device

def compile(self, texts: List[str]) -> k2.Fsa:
Expand All @@ -79,10 +93,6 @@ def compile(self, texts: List[str]) -> k2.Fsa:

fsa_with_self_loops = k2.arc_sort(fsa_with_self_loops)

self.ctc_topo._is_repeat_token_ = (
self.ctc_topo.labels != self.ctc_topo.aux_labels
).int()

decoding_graph = k2.compose(
self.ctc_topo, fsa_with_self_loops, treat_epsilons_specially=False
)
Expand Down

0 comments on commit c25c8c6

Please sign in to comment.