diff --git a/src/supervised_hybrid/segment.py b/src/supervised_hybrid/segment.py index 754a30a..103de5f 100644 --- a/src/supervised_hybrid/segment.py +++ b/src/supervised_hybrid/segment.py @@ -90,6 +90,7 @@ def pdac( max_segment_length: float, min_segment_length: float, threshold: float, + not_strict: bool ) -> list[Segment]: """applies the probabilistic Divide-and-Conquer algorithm to split an audio into segments satisfying the max-segment-length and min-segment-length conditions @@ -100,6 +101,7 @@ def pdac( max_segment_length (float): the maximum length of a segment min_segment_length (float): the minimum length of a segment threshold (float): probability threshold + not_strict (bool): whether segments longer than max are allowed Returns: list[Segment]: resulting segmentation @@ -118,7 +120,7 @@ def recusrive_split(sgm): while j < len(sorted_indices): split_idx = sorted_indices[j] split_prob = sgm.probs[split_idx] - if split_prob > threshold: + if not_strict and split_prob > threshold: segments.append(sgm) break @@ -132,7 +134,13 @@ def recusrive_split(sgm): break j += 1 else: - segments.append(sgm) + if not_strict: + segments.append(sgm) + else: + if sgm_a.duration > min_segment_length: + recusrive_split(sgm_a) + if sgm_b.duration > min_segment_length: + recusrive_split(sgm_b) recusrive_split(sgm) @@ -231,6 +239,7 @@ def segment(args): args.dac_max_segment_length, args.dac_min_segment_length, args.dac_threshold, + args.not_strict ) yaml_content = update_yaml_content(yaml_content, segments, wav_path.name) @@ -316,6 +325,13 @@ def segment(args): help="after each split by the algorithm, the resulting segments are trimmed to" "the first and last points that corresponds to a probability above this value", ) + parser.add_argument( + "--not_strict", + action="store_true", + help="whether segments longer than max are allowed." + "If this argument is used, respecting the classification threshold conditions (p > thr)" + "is more important than the length conditions (len < max)." + ) args = parser.parse_args() segment(args)