Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Option to save audio in the original format when exporting to shar #1422

Merged
merged 2 commits into from
Nov 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions lhotse/audio/recording.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,23 @@ def is_placeholder(self) -> bool:
def num_channels(self) -> int:
return len(self.channel_ids)

@property
def source_format(self) -> str:
    """Infer the common audio format of this recording's sources.

    The ``format`` of every item in ``self.sources`` is collected; when all
    of them agree, that single format is returned.

    :return: the format shared by all sources.
    :raises NotImplementedError: when the sources use more than one format
        (resolving them to a single format is not supported), or when there
        are no sources to infer a format from.
    """
    # A set keeps one entry per distinct format.
    source_formats = {s.format for s in self.sources}

    if not source_formats:
        # An empty source list is not a format mismatch; report it as such
        # instead of falling through to the "different formats" message.
        raise NotImplementedError(
            "Recording has no sources. Cannot infer the source format."
        )

    if len(source_formats) > 1:
        # at the moment, we don't resolve different formats
        raise NotImplementedError(
            "Sources have different formats. Resolving to a single format not implemented."
        )

    # Exactly one distinct format remains.
    (common_format,) = source_formats
    return common_format

@staticmethod
def from_file(
path: Pathlike,
Expand Down
28 changes: 28 additions & 0 deletions lhotse/audio/source.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import io
import os
import warnings
from dataclasses import dataclass
from io import BytesIO, FileIO
Expand All @@ -6,6 +8,7 @@
from typing import List, Optional, Tuple, Union

import numpy as np
import soundfile as sf
import torch

from lhotse.audio.backend import read_audio
Expand Down Expand Up @@ -64,6 +67,10 @@ class AudioSource:
def has_video(self) -> bool:
return self.video is not None

@property
def format(self) -> str:
    """Audio format of this source, as reported by ``self._get_format()``."""
    resolved_format = self._get_format()
    return resolved_format

def load_audio(
self,
offset: Seconds = 0.0,
Expand Down Expand Up @@ -316,3 +323,24 @@ def _prepare_for_reading(
)

return source

def _get_format(self) -> str:
"""Get format for the audio source.
If using 'file' or 'url' types, the format is inferred from the file extension, as in soundfile.
If using 'memory' type, the format is inferred from the binary data.
"""
if self.type in ("file", "url"):
# Resolve audio format based on the filename
format = os.path.splitext(self.source)[-1][1:]
return format.lower()
elif self.type == "memory":
sf_info = sf.info(io.BytesIO(self.source))
if sf_info.format == "OGG" and sf_info.subtype == "OPUS":
# soundfile describes opus as ogg container with opus coding
return "opus"
else:
return sf_info.format.lower()
else:
raise NotImplementedError(
f"Getting format not implemented for source type {self.type}"
)
4 changes: 2 additions & 2 deletions lhotse/bin/modes/shar.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@ def shar():
"-a",
"--audio",
default="none",
type=click.Choice(["none", "wav", "flac", "mp3", "opus"]),
help="Format in which to export audio (disabled by default, enabling will make a copy of the data)",
type=click.Choice(["none", "wav", "flac", "mp3", "opus", "original"]),
help="Format in which to export audio. Original will save in the same format as the original audio (disabled by default, enabling will make a copy of the data)",
)
@click.option(
"-f",
Expand Down
18 changes: 15 additions & 3 deletions lhotse/shar/writers/audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,14 @@ def close(self):
def output_paths(self) -> List[str]:
return self.tar_writer.output_paths

def resolve_format(self, original_format: str):
    """Pick the format used to save an audio item.

    When the writer's ``format`` is "original", the caller-provided
    ``original_format`` is returned unchanged; otherwise the format
    stored on the writer is returned.
    """
    keep_original = self.format == "original"
    return original_format if keep_original else self.format

def write_placeholder(self, key: str) -> None:
    """Write two empty entries for ``key`` through ``self.tar_writer``.

    The first entry is named ``<key>.nodata``; the second is named
    ``<key>.nometa`` and is written with ``count=False``.
    """
    tar = self.tar_writer
    empty_data = BytesIO()
    empty_meta = BytesIO()
    tar.write(f"{key}.nodata", empty_data)
    tar.write(f"{key}.nometa", empty_meta, count=False)
Expand All @@ -76,15 +84,18 @@ def write(
value: np.ndarray,
sampling_rate: int,
manifest: Recording,
original_format: Optional[str] = None,
) -> None:
save_format = self.resolve_format(original_format)

value, manifest, sampling_rate = self._maybe_resample(
value, manifest, sampling_rate
value, manifest, sampling_rate, format=save_format
)

# Write binary data
stream = BytesIO()
save_audio(
dest=stream, src=value, sampling_rate=sampling_rate, format=self.format
dest=stream, src=value, sampling_rate=sampling_rate, format=save_format
)
self.tar_writer.write(f"{key}.{self.format}", stream)

Expand All @@ -103,13 +114,14 @@ def _maybe_resample(
audio: Union[torch.Tensor, np.ndarray],
manifest: Recording,
sampling_rate: int,
format: str,
) -> Tuple[Union[np.ndarray, torch.Tensor], Recording, int]:
# Resampling is required for some versions of OPUS encoders.
# First resample the manifest which only adjusts the metadata;
# then resample the audio array to 48kHz.
OPUS_DEFAULT_SAMPLING_RATE = 48000
if (
self.format == "opus"
format == "opus"
and is_torchaudio_available()
and not isinstance(get_current_audio_backend(), LibsndfileBackend)
and sampling_rate != OPUS_DEFAULT_SAMPLING_RATE
Expand Down
7 changes: 6 additions & 1 deletion lhotse/shar/writers/shar.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,11 @@ def write(self, cut: Cut) -> None:
recording.sources[0].channels = cut_channels
recording.channel_ids = cut_channels
self.writers["recording"].write(
cut.id, data, cut.sampling_rate, manifest=recording
cut.id,
data,
cut.sampling_rate,
manifest=recording,
original_format=cut.recording.source_format,
)
cut = fastcopy(cut, recording=recording)
else:
Expand Down Expand Up @@ -224,6 +228,7 @@ def resolve_writer(name: str) -> Tuple[FieldWriter, str]:
"flac": (partial(AudioTarWriter, format="flac"), ".tar"),
"mp3": (partial(AudioTarWriter, format="mp3"), ".tar"),
"opus": (partial(AudioTarWriter, format="opus"), ".tar"),
"original": (partial(AudioTarWriter, format="original"), ".tar"),
"lilcom": (partial(ArrayTarWriter, compression="lilcom"), ".tar"),
"numpy": (partial(ArrayTarWriter, compression="numpy"), ".tar"),
"jsonl": (JsonlShardWriter, ".jsonl.gz"),
Expand Down
40 changes: 31 additions & 9 deletions lhotse/testing/dummies.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ def dummy_recording(
duration: float = 1.0,
sampling_rate: int = 16000,
with_data: bool = False,
source_format: str = "wav",
) -> Recording:
num_samples = compute_num_samples(duration, sampling_rate)
return Recording(
Expand All @@ -72,6 +73,7 @@ def dummy_recording(
sampling_rate=sampling_rate,
num_samples=num_samples,
with_data=with_data,
format=source_format,
)
],
sampling_rate=sampling_rate,
Expand All @@ -85,6 +87,7 @@ def dummy_audio_source(
sampling_rate: int = 16000,
channels: Optional[List[int]] = None,
with_data: bool = False,
format: str = "wav",
) -> AudioSource:
if channels is None:
channels = [0]
Expand All @@ -95,21 +98,40 @@ def dummy_audio_source(
else:
import soundfile

# 1kHz sine wave
data = torch.sin(2 * np.pi * 1000 * torch.arange(num_samples))
# generate 1kHz sine wave
f_sine = 1000
assert (
f_sine < sampling_rate / 2
), f"Sine wave frequency {f_sine} exceeds Nyquist frequency {sampling_rate/2} for sampling rate {sampling_rate}"
data = torch.sin(2 * np.pi * f_sine / sampling_rate * torch.arange(num_samples))
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@pzelasko, I'm including a bugfix here: the sine-wave frequency was not normalized by the sampling rate.


# prepare multichannel data
if len(channels) > 1:
data = data.unsqueeze(0).expand(len(channels), -1).transpose(0, 1)
# ensure each channel has different data for channel selection testing
mults = torch.tensor([1 / idx for idx in range(1, len(channels) + 1)])
data = data * mults

# prepare source with the selected format
binary_data = BytesIO()
soundfile.write(
binary_data,
data.numpy(),
sampling_rate,
format="wav",
closefd=False,
)
if format == "opus":
# workaround for OPUS: soundfile supports OPUS as a subtype of OGG format
soundfile.write(
binary_data,
data.numpy(),
sampling_rate,
format="OGG",
subtype="OPUS",
closefd=False,
)
else:
soundfile.write(
binary_data,
data.numpy(),
sampling_rate,
format=format,
closefd=False,
)
binary_data.seek(0)
return AudioSource(
type="memory", channels=channels, source=binary_data.getvalue()
Expand Down
102 changes: 53 additions & 49 deletions test/shar/test_write.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,55 +66,6 @@ def test_tar_writer_pipe(tmp_path: Path):
assert f2.read() == b"test"


@pytest.mark.parametrize(
    "format",
    [
        "wav",
        pytest.param(
            "flac",
            marks=pytest.mark.skipif(
                not check_torchaudio_version_gt("0.12.1"),
                reason="Torchaudio v0.12.1 or greater is required.",
            ),
        ),
        # "mp3",  # apparently doesn't work in CI, mp3 encoder is missing
        pytest.param(
            "opus",
            marks=pytest.mark.skipif(
                not check_torchaudio_version_gt("2.1.0"),
                reason="Torchaudio v2.1.0 or greater is required.",
            ),
        ),
    ],
)
def test_audio_tar_writer(tmp_path: Path, format: str):
    """Round-trip a dummy recording through AudioTarWriter in the given format.

    The audio read back from the tar file (resampled to the original
    sampling rate) must stay within an RMSE of 0.5 of the written audio.
    """
    from lhotse.testing.dummies import dummy_recording

    reference = dummy_recording(0, with_data=True)
    expected_audio = reference.load_audio()
    tar_path = str(tmp_path / "test.tar")

    with AudioTarWriter(tar_path, shard_size=None, format=format) as writer:
        writer.write(
            key="my-recording",
            value=expected_audio,
            sampling_rate=reference.sampling_rate,
            manifest=reference,
        )

    # Exactly one output tar holding exactly one (recording, path) entry is expected.
    (path,) = writer.output_paths
    entries = list(TarIterator(path))
    ((restored_recording, _inner_path),) = entries

    restored = restored_recording.resample(reference.sampling_rate)
    restored_audio = restored.load_audio()

    squared_error = (expected_audio - restored_audio) ** 2
    rmse = np.sqrt(np.mean(squared_error))
    assert rmse < 0.5

@pytest.mark.parametrize(
["format", "backend"],
[
Expand Down Expand Up @@ -175,6 +126,59 @@ def test_audio_tar_writer(tmp_path: Path, format: str, backend: str):
assert rmse < 0.5


# Each case pairs a source format with the largest RMSE tolerated after the
# write/read round trip; a threshold of 0.0 (wav, flac) requires the audio to
# come back exactly, while mp3 and opus are allowed a non-zero error.
@pytest.mark.parametrize(
["original_format", "rmse_threshold"],
[("wav", 0.0), ("flac", 0.0), ("mp3", 0.003), ("opus", 0.3)],
)
def test_audio_tar_writer_original_format(
tmp_path: Path, original_format: str, rmse_threshold: float
):
"""Test using AudioTarWriter to write the audio signal in the exact same format
as it was loaded from the source.

The recording is created with ``source_format=original_format``, written with
the writer format set to "original", read back with TarIterator, and checked
for both the reported source format and the RMSE against the original audio.
"""
from lhotse.testing.dummies import dummy_recording

backend = "default" # use the default backend for reading the audio
writer_format = "original" # write the audio in the same format as it was loaded

recording = dummy_recording(0, with_data=True, source_format=original_format)
audio = recording.load_audio()

# sanity check: the dummy recording must report the requested source format
assert (
recording.source_format == original_format
), f"Recording source format ({recording.source_format}) not matching the expected original format ({original_format})"

with audio_backend(backend):
with AudioTarWriter(
str(tmp_path / "test.tar"), shard_size=None, format=writer_format
) as writer:
# the source format is passed explicitly via `original_format`
writer.write(
key="my-recording",
value=audio,
sampling_rate=recording.sampling_rate,
manifest=recording,
original_format=recording.source_format,
)
# the unpacking below requires exactly one output path and exactly one tar entry
(path,) = writer.output_paths
((deserialized_recording, inner_path),) = list(TarIterator(path))

# make sure the deserialized audio is in the same format as the original
assert (
deserialized_recording.source_format == original_format
), f"Deserialized recording source format ({deserialized_recording.source_format}) not matching the expected original format ({original_format})"

# load audio
# (resampled to the original sampling rate before comparing)
deserialized_audio = deserialized_recording.resample(
recording.sampling_rate
).load_audio()

# check difference between original and deserialized audio
rmse = np.sqrt(np.mean((audio - deserialized_audio) ** 2))
assert (
rmse <= rmse_threshold
), f"RMSE between original and deserialized audio is {rmse}, which is above the threshold of {rmse_threshold}"


def test_shar_writer(tmp_path: Path):
# Prepare data
cuts = DummyManifest(CutSet, begin_id=0, end_id=20, with_data=True)
Expand Down
Loading