Skip to content

Commit

Permalink
Channel selection for multi-channel custom recording fields (#1299)
Browse files Browse the repository at this point in the history
* Channel selection for multi-channel custom recording fields

* fix

* fix exporting multicuts to shar
  • Loading branch information
pzelasko authored Mar 7, 2024
1 parent 7cc8fb4 commit b34e805
Show file tree
Hide file tree
Showing 7 changed files with 171 additions and 33 deletions.
21 changes: 17 additions & 4 deletions lhotse/custom.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import numpy as np

from lhotse import Recording
from lhotse.utils import ifnone
from lhotse.utils import fastcopy, ifnone


class CustomFieldMixin:
Expand Down Expand Up @@ -81,6 +81,14 @@ def __delattr__(self, key: str) -> None:
raise AttributeError(f"No such member: '{key}'")
del self.custom[key]

def with_custom(self, name: str, value: Any):
"""Return a copy of this object with an extra custom field assigned to it."""
cpy = fastcopy(
self, custom=self.custom.copy() if self.custom is not None else {}
)
cpy.custom[name] = value
return cpy

def load_custom(self, name: str) -> np.ndarray:
"""
Load custom data as numpy array. The custom data is expected to have
Expand All @@ -103,9 +111,14 @@ def load_custom(self, name: str) -> np.ndarray:
# TemporalArray supports slicing.
return value.load(start=self.start, duration=self.duration)
elif isinstance(value, Recording):
# Recording supports slicing. Note: we will not slice the channels
# as cut.channels referes to cut.recording and not the custom field.
return value.load_audio(offset=self.start, duration=self.duration)
# Recording supports slicing.
# Note: cut.channels referes to cut.recording and not the custom field.
# We have to use a special channel selector field instead; e.g.:
# if this is "target_recording", we'll look for "target_recording_channel_selector"
channels = self.custom.get(f"{name}_channel_selector")
return value.load_audio(
channels=channels, offset=self.start, duration=self.duration
)
else:
raise ValueError(
f"To load {name}, the cut needs to have field {name} (or cut.custom['{name}']) "
Expand Down
35 changes: 34 additions & 1 deletion lhotse/cut/multi.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from functools import partial, reduce
from itertools import groupby
from operator import add
from typing import Any, Callable, Iterable, List, Optional, Tuple, Union
from typing import Any, Callable, Iterable, List, Optional, Sequence, Tuple, Union

import numpy as np
import torch
Expand Down Expand Up @@ -365,6 +365,39 @@ def merge_supervisions(

return fastcopy(self, supervisions=msups)

def with_channels(self, channels: Union[List[int], int]) -> DataCut:
"""
Select specified channels from this cut.
Supports extending to other channels available in the underlying :class:`Recording`.
If a single channel is provided, we'll return a :class:`~lhotse.cut.MonoCut`,
otherwise we'll return a :class:`~lhotse.cut.MultiCut'.
"""
mono = isinstance(channels, int) or len(channels) == 1
assert set([channels] if mono else channels).issubset(
set(self.recording.channel_ids)
), f"Cannot select {channels=} because they are not a subset of {self.recording.channel_ids=}"

if mono:
from .mono import MonoCut

if isinstance(channels, Sequence):
(channels,) = channels
return MonoCut(
id=f"{self.id}-{channels}",
recording=self.recording,
start=self.start,
duration=self.duration,
channel=channels,
supervisions=[
fastcopy(s, channel=channels)
for s in self.supervisions
if is_equal_or_contains(s.channel, channels)
],
custom=self.custom,
)

return fastcopy(self, channel=channels)

@staticmethod
def from_mono(*cuts: DataCut) -> "MultiCut":
"""
Expand Down
73 changes: 46 additions & 27 deletions lhotse/cut/set.py
Original file line number Diff line number Diff line change
Expand Up @@ -1394,34 +1394,21 @@ def truncate(
:param rng: optional random number generator to be used with a 'random' ``offset_type``.
:return: a new CutSet instance with truncated cuts.
"""
truncated_cuts = []
for cut in self:
if cut.duration <= max_duration:
truncated_cuts.append(cut)
continue

def compute_offset():
if offset_type == "start":
return 0.0
last_offset = cut.duration - max_duration
if offset_type == "end":
return last_offset
if offset_type == "random":
if rng is None:
return random.uniform(0.0, last_offset)
else:
return rng.uniform(0.0, last_offset)
raise ValueError(f"Unknown 'offset_type' option: {offset_type}")

truncated_cuts.append(
cut.truncate(
offset=compute_offset(),
duration=max_duration,
keep_excessive_supervisions=keep_excessive_supervisions,
preserve_id=preserve_id,
)
assert offset_type in (
"start",
"end",
"random",
), f"Unknown offset type: '{offset_type}'"
return self.map(
partial(
_truncate_single,
max_duration=max_duration,
offset_type=offset_type,
keep_excessive_supervisions=keep_excessive_supervisions,
preserve_id=preserve_id,
rng=rng,
)
return CutSet(truncated_cuts)
)

def extend_by(
self,
Expand Down Expand Up @@ -3368,6 +3355,38 @@ def _drop_supervisions(cut, *args, **kwargs):
return cut.drop_supervisions(*args, **kwargs)


def _truncate_single(
cut: Cut,
max_duration: Seconds,
offset_type: str,
keep_excessive_supervisions: bool = True,
preserve_id: bool = False,
rng: Optional[random.Random] = None,
) -> Cut:
if cut.duration <= max_duration:
return cut

def compute_offset():
if offset_type == "start":
return 0.0
last_offset = cut.duration - max_duration
if offset_type == "end":
return last_offset
if offset_type == "random":
if rng is None:
return random.uniform(0.0, last_offset)
else:
return rng.uniform(0.0, last_offset)
raise ValueError(f"Unknown 'offset_type' option: {offset_type}")

return cut.truncate(
offset=compute_offset(),
duration=max_duration,
keep_excessive_supervisions=keep_excessive_supervisions,
preserve_id=preserve_id,
)


def _export_to_shar_single(
cuts: CutSet,
output_dir: Pathlike,
Expand Down
23 changes: 23 additions & 0 deletions lhotse/shar/writers/shar.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,12 @@ def write(self, cut: Cut) -> None:
if cut.has_recording:
data = cut.load_audio()
recording = to_shar_placeholder(cut.recording, cut)
cut_channels = _aslist(cut.channel)
if recording.channel_ids != cut_channels:
# If recording is multi-channel but the cut refers to a subset of them,
# we have to update the recording manifest accordingly
recording.sources[0].channels = cut_channels
recording.channel_ids = cut_channels
self.writers["recording"].write(
cut.id, data, cut.sampling_rate, manifest=recording
)
Expand Down Expand Up @@ -171,13 +177,24 @@ def write(self, cut: Cut) -> None:
else:
data = cut.load_custom(key)
placeholder_obj = to_shar_placeholder(val, cut)
channel_selector_key = f"{key}_channel_selector"
kwargs = {}
if isinstance(val, Recording):
kwargs["sampling_rate"] = val.sampling_rate
if cut.has_custom(channel_selector_key):
# override custom recording channels since the audio was loaded via cut
# and used the channel selector
placeholder_obj.sources[0].channels = cut.custom[
channel_selector_key
]
placeholder_obj.channel_ids = cut.custom[
channel_selector_key
]
self.writers[key].write(
cut.id, data, manifest=placeholder_obj, **kwargs
)
cut = fastcopy(cut, custom=cut.custom.copy())
cut.custom.pop(channel_selector_key, None) # no longer needed
setattr(cut, key, placeholder_obj)
else:
self.writers[key].write_placeholder(cut.id)
Expand Down Expand Up @@ -224,3 +241,9 @@ def _create_cuts_output_url(base_output_url: str, shard_suffix: str) -> str:
base_output_url = base_output_url.replace("pipe:", "pipe:gzip -c | ")

return f"{base_output_url}/cuts{shard_suffix}.jsonl.gz"


def _aslist(x):
if isinstance(x, list):
return x
return [x]
3 changes: 3 additions & 0 deletions lhotse/testing/dummies.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,9 @@ def dummy_audio_source(
data = torch.sin(2 * np.pi * 1000 * torch.arange(num_samples))
if len(channels) > 1:
data = data.unsqueeze(0).expand(len(channels), -1).transpose(0, 1)
# ensure each channel has different data for channel selection testing
mults = torch.tensor([1 / idx for idx in range(1, len(channels) + 1)])
data = data * mults
binary_data = BytesIO()
soundfile.write(
binary_data,
Expand Down
40 changes: 40 additions & 0 deletions test/cut/test_custom_attrs.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from lhotse.serialization import deserialize_item
from lhotse.testing.dummies import (
dummy_cut,
dummy_multi_channel_recording,
dummy_multi_cut,
dummy_recording,
dummy_supervision,
Expand Down Expand Up @@ -401,3 +402,42 @@ def test_del_attr_mono_cut(cut):
with pytest.raises(AttributeError):
del cut.extra_metadata
assert "extra_metadata" not in cut.custom


def test_multi_cut_custom_multi_recording_channel_selector():
cut = dummy_multi_cut(0, channel=[0, 1, 2, 3], with_data=True)
cut.target_recording = dummy_multi_channel_recording(
1, channel_ids=[0, 1, 2, 3], with_data=True
)

# All input channels
ref_audio = cut.load_audio()
assert ref_audio.shape == (4, 16000)

# Input channel selection
two_channel_in = cut.with_channels([0, 1])
audio = two_channel_in.load_audio()
assert audio.shape == (2, 16000)
np.testing.assert_allclose(ref_audio[:2], audio)

# Input channel selection, different channels
two_channel_in = cut.with_channels([0, 3])
audio = two_channel_in.load_audio()
assert audio.shape == (2, 16000)
np.testing.assert_allclose(ref_audio[::3], audio)

# All output channels
ref_tgt_audio = cut.load_target_recording()
assert ref_tgt_audio.shape == (4, 16000)

# Output channel selection
two_channel_out = cut.with_custom("target_recording_channel_selector", [0, 1])
audio = two_channel_out.load_target_recording()
assert audio.shape == (2, 16000)
np.testing.assert_allclose(ref_tgt_audio[:2], audio)

# Output channel selection, different channels
two_channel_out = cut.with_custom("target_recording_channel_selector", [0, 3])
audio = two_channel_out.load_target_recording()
assert audio.shape == (2, 16000)
np.testing.assert_allclose(ref_tgt_audio[::3], audio)
9 changes: 8 additions & 1 deletion test/cut/test_cut_truncate.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from lhotse.cut import CutSet, MixedCut, MixTrack, MonoCut, PaddingCut
from lhotse.features import Features
from lhotse.supervision import SupervisionSegment, SupervisionSet
from lhotse.testing.dummies import DummyManifest, dummy_cut, dummy_recording
from lhotse.testing.dummies import DummyManifest, as_lazy, dummy_cut, dummy_recording
from lhotse.testing.random import deterministic_rng


Expand Down Expand Up @@ -238,6 +238,13 @@ def test_truncate_mixed_cut_gap_or_padding(gapped_mixed_cut, offset):
assert audio is not None


def test_truncate_cut_set_lazy_result(cut_set):
with as_lazy(cut_set, ".jsonl") as lazy_cuts:
truncated_cut_set = lazy_cuts.truncate(max_duration=5, offset_type="start")
assert truncated_cut_set.is_lazy
assert all(c.duration == pytest.approx(5.0) for c in truncated_cut_set)


def test_truncate_cut_set_offset_start(cut_set):
truncated_cut_set = cut_set.truncate(max_duration=5, offset_type="start")
cut1, cut2 = truncated_cut_set
Expand Down

0 comments on commit b34e805

Please sign in to comment.