Skip to content

Commit

Permalink
option to allow for segments longer than max
Browse files Browse the repository at this point in the history
  • Loading branch information
johntsi committed Jan 31, 2023
1 parent e514944 commit 418b5e6
Showing 1 changed file with 18 additions and 2 deletions.
20 changes: 18 additions & 2 deletions src/supervised_hybrid/segment.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ def pdac(
max_segment_length: float,
min_segment_length: float,
threshold: float,
not_strict: bool
) -> list[Segment]:
"""applies the probabilistic Divide-and-Conquer algorithm to split an audio
into segments satisfying the max-segment-length and min-segment-length conditions
Expand All @@ -100,6 +101,7 @@ def pdac(
max_segment_length (float): the maximum length of a segment
min_segment_length (float): the minimum length of a segment
threshold (float): probability threshold
not_strict (bool): whether segments longer than max are allowed
Returns:
list[Segment]: resulting segmentation
Expand All @@ -118,7 +120,7 @@ def recusrive_split(sgm):
while j < len(sorted_indices):
split_idx = sorted_indices[j]
split_prob = sgm.probs[split_idx]
if split_prob > threshold:
if not_strict and split_prob > threshold:
segments.append(sgm)
break

Expand All @@ -132,7 +134,13 @@ def recusrive_split(sgm):
break
j += 1
else:
segments.append(sgm)
if not_strict:
segments.append(sgm)
else:
if sgm_a.duration > min_segment_length:
recusrive_split(sgm_a)
if sgm_b.duration > min_segment_length:
recusrive_split(sgm_b)

recusrive_split(sgm)

Expand Down Expand Up @@ -231,6 +239,7 @@ def segment(args):
args.dac_max_segment_length,
args.dac_min_segment_length,
args.dac_threshold,
args.not_strict
)

yaml_content = update_yaml_content(yaml_content, segments, wav_path.name)
Expand Down Expand Up @@ -316,6 +325,13 @@ def segment(args):
help="after each split by the algorithm, the resulting segments are trimmed to"
"the first and last points that corresponds to a probability above this value",
)
# Opt-in flag (defaults to False): when set, a segment is kept as-is once its
# best split point has a probability above the threshold, even if the segment
# is still longer than the maximum length.
parser.add_argument(
    "--not_strict",
    action="store_true",
    # Adjacent string literals are joined with no separator, so each piece
    # must carry its own trailing space.
    help="whether segments longer than max are allowed. "
    "If this argument is used, respecting the classification threshold "
    "conditions (p > thr) is more important than the length conditions "
    "(len < max).",
)
args = parser.parse_args()

segment(args)

0 comments on commit 418b5e6

Please sign in to comment.