diff --git a/lhotse/audio/recording.py b/lhotse/audio/recording.py index ec0f605f2..cd8a41fb2 100644 --- a/lhotse/audio/recording.py +++ b/lhotse/audio/recording.py @@ -168,6 +168,23 @@ def is_placeholder(self) -> bool: def num_channels(self) -> int: return len(self.channel_ids) + @property + def source_format(self) -> str: + """Infer format of the audio sources. + If all sources have the same format, return it. + If sources have different formats, raise an error. + """ + source_formats = list(set([s.format for s in self.sources])) + + if len(source_formats) == 1: + # if all sources have the same format, return it + return source_formats[0] + else: + # at the moment, we don't resolve different formats + raise NotImplementedError( + "Sources have different formats. Resolving to a single format not implemented." + ) + @staticmethod def from_file( path: Pathlike, diff --git a/lhotse/audio/source.py b/lhotse/audio/source.py index 459881ea2..88bb9743d 100644 --- a/lhotse/audio/source.py +++ b/lhotse/audio/source.py @@ -1,3 +1,5 @@ +import io +import os import warnings from dataclasses import dataclass from io import BytesIO, FileIO @@ -6,6 +8,7 @@ from typing import List, Optional, Tuple, Union import numpy as np +import soundfile as sf import torch from lhotse.audio.backend import read_audio @@ -64,6 +67,10 @@ class AudioSource: def has_video(self) -> bool: return self.video is not None + @property + def format(self) -> str: + return self._get_format() + def load_audio( self, offset: Seconds = 0.0, @@ -316,3 +323,24 @@ def _prepare_for_reading( ) return source + + def _get_format(self) -> str: + """Get format for the audio source. + If using 'file' or 'url' types, the format is inferred from the file extension, as in soundfile. + If using 'memory' type, the format is inferred from the binary data. 
+ """ + if self.type in ("file", "url"): + # Resolve audio format based on the filename + format = os.path.splitext(self.source)[-1][1:] + return format.lower() + elif self.type == "memory": + sf_info = sf.info(io.BytesIO(self.source)) + if sf_info.format == "OGG" and sf_info.subtype == "OPUS": + # soundfile describes opus as ogg container with opus coding + return "opus" + else: + return sf_info.format.lower() + else: + raise NotImplementedError( + f"Getting format not implemented for source type {self.type}" + ) diff --git a/lhotse/bin/modes/shar.py b/lhotse/bin/modes/shar.py index 95b670529..cffdf596b 100644 --- a/lhotse/bin/modes/shar.py +++ b/lhotse/bin/modes/shar.py @@ -27,8 +27,8 @@ def shar(): "-a", "--audio", default="none", - type=click.Choice(["none", "wav", "flac", "mp3", "opus"]), - help="Format in which to export audio (disabled by default, enabling will make a copy of the data)", + type=click.Choice(["none", "wav", "flac", "mp3", "opus", "original"]), + help="Format in which to export audio. 
Original will save in the same format as the original audio (disabled by default, enabling will make a copy of the data)", ) @click.option( "-f", diff --git a/lhotse/shar/writers/audio.py b/lhotse/shar/writers/audio.py index b3a855f59..22bdf4cf1 100644 --- a/lhotse/shar/writers/audio.py +++ b/lhotse/shar/writers/audio.py @@ -66,6 +66,14 @@ def close(self): def output_paths(self) -> List[str]: return self.tar_writer.output_paths + def resolve_format(self, original_format: Optional[str]) -> str: + """Return the encoding format: ``self.format``, or ``original_format`` when the writer is in "original" mode.""" + if self.format != "original": + return self.format + if not original_format: + raise ValueError("original_format must be a non-empty format name when the writer format is 'original'.") + return original_format + def write_placeholder(self, key: str) -> None: self.tar_writer.write(f"{key}.nodata", BytesIO()) self.tar_writer.write(f"{key}.nometa", BytesIO(), count=False) @@ -76,15 +84,18 @@ def write( value: np.ndarray, sampling_rate: int, manifest: Recording, + original_format: Optional[str] = None, ) -> None: + save_format = self.resolve_format(original_format) + value, manifest, sampling_rate = self._maybe_resample( - value, manifest, sampling_rate + value, manifest, sampling_rate, format=save_format ) # Write binary data stream = BytesIO() save_audio( - dest=stream, src=value, sampling_rate=sampling_rate, format=self.format + dest=stream, src=value, sampling_rate=sampling_rate, format=save_format ) - self.tar_writer.write(f"{key}.{self.format}", stream) + self.tar_writer.write(f"{key}.{save_format}", stream) @@ -103,13 +114,14 @@ def _maybe_resample( audio: Union[torch.Tensor, np.ndarray], manifest: Recording, sampling_rate: int, + format: str, ) -> Tuple[Union[np.ndarray, torch.Tensor], Recording, int]: # Resampling is required for some versions of OPUS encoders. # First resample the manifest which only adjusts the metadata; # then resample the audio array to 48kHz. 
OPUS_DEFAULT_SAMPLING_RATE = 48000 if ( - self.format == "opus" + format == "opus" and is_torchaudio_available() and not isinstance(get_current_audio_backend(), LibsndfileBackend) and sampling_rate != OPUS_DEFAULT_SAMPLING_RATE diff --git a/lhotse/shar/writers/shar.py b/lhotse/shar/writers/shar.py index 229a073b6..2c1a88442 100644 --- a/lhotse/shar/writers/shar.py +++ b/lhotse/shar/writers/shar.py @@ -135,7 +135,16 @@ def write(self, cut: Cut) -> None: recording.sources[0].channels = cut_channels recording.channel_ids = cut_channels self.writers["recording"].write( - cut.id, data, cut.sampling_rate, manifest=recording + cut.id, + data, + cut.sampling_rate, + manifest=recording, + # Only inspect the sources when the writer needs it: source_format raises for mixed formats and unsupported source types. + original_format=( + cut.recording.source_format + if getattr(self.writers["recording"], "format", None) == "original" + else None + ), ) cut = fastcopy(cut, recording=recording) else: @@ -224,6 +233,7 @@ def resolve_writer(name: str) -> Tuple[FieldWriter, str]: "flac": (partial(AudioTarWriter, format="flac"), ".tar"), "mp3": (partial(AudioTarWriter, format="mp3"), ".tar"), "opus": (partial(AudioTarWriter, format="opus"), ".tar"), + "original": (partial(AudioTarWriter, format="original"), ".tar"), "lilcom": (partial(ArrayTarWriter, compression="lilcom"), ".tar"), "numpy": (partial(ArrayTarWriter, compression="numpy"), ".tar"), "jsonl": (JsonlShardWriter, ".jsonl.gz"), diff --git a/lhotse/testing/dummies.py b/lhotse/testing/dummies.py index aec6a7581..0999906aa 100644 --- a/lhotse/testing/dummies.py +++ b/lhotse/testing/dummies.py @@ -63,6 +63,7 @@ def dummy_recording( duration: float = 1.0, sampling_rate: int = 16000, with_data: bool = False, + source_format: str = "wav", ) -> Recording: num_samples = compute_num_samples(duration, sampling_rate) return Recording( @@ -72,6 +73,7 @@ def dummy_recording( sampling_rate=sampling_rate, num_samples=num_samples, with_data=with_data, + format=source_format, ) ], sampling_rate=sampling_rate, @@ -85,6 +87,7 @@ def dummy_audio_source( sampling_rate: int = 16000, channels: Optional[List[int]] = None, with_data: bool = False, + format: str = 
"wav", ) -> AudioSource: if channels is None: channels = [0] @@ -95,21 +98,40 @@ def dummy_audio_source( else: import soundfile - # 1kHz sine wave - data = torch.sin(2 * np.pi * 1000 * torch.arange(num_samples)) + # generate 1kHz sine wave + f_sine = 1000 + assert ( + f_sine < sampling_rate / 2 + ), f"Sine wave frequency {f_sine} exceeds Nyquist frequency {sampling_rate/2} for sampling rate {sampling_rate}" + data = torch.sin(2 * np.pi * f_sine / sampling_rate * torch.arange(num_samples)) + + # prepare multichannel data if len(channels) > 1: data = data.unsqueeze(0).expand(len(channels), -1).transpose(0, 1) # ensure each channel has different data for channel selection testing mults = torch.tensor([1 / idx for idx in range(1, len(channels) + 1)]) data = data * mults + + # prepare source with the selected format binary_data = BytesIO() - soundfile.write( - binary_data, - data.numpy(), - sampling_rate, - format="wav", - closefd=False, - ) + if format == "opus": + # workaround for OPUS: soundfile supports OPUS as a subtype of OGG format + soundfile.write( + binary_data, + data.numpy(), + sampling_rate, + format="OGG", + subtype="OPUS", + closefd=False, + ) + else: + soundfile.write( + binary_data, + data.numpy(), + sampling_rate, + format=format, + closefd=False, + ) binary_data.seek(0) return AudioSource( type="memory", channels=channels, source=binary_data.getvalue() diff --git a/test/shar/test_write.py b/test/shar/test_write.py index cee35de1c..7d88bd3d0 100644 --- a/test/shar/test_write.py +++ b/test/shar/test_write.py @@ -66,55 +66,6 @@ def test_tar_writer_pipe(tmp_path: Path): assert f2.read() == b"test" -@pytest.mark.parametrize( - "format", - [ - "wav", - pytest.param( - "flac", - marks=pytest.mark.skipif( - not check_torchaudio_version_gt("0.12.1"), - reason="Torchaudio v0.12.1 or greater is required.", - ), - ), - # "mp3", # apparently doesn't work in CI, mp3 encoder is missing - pytest.param( - "opus", - marks=pytest.mark.skipif( - not 
check_torchaudio_version_gt("2.1.0"), - reason="Torchaudio v2.1.0 or greater is required.", - ), - ), - ], -) -def test_audio_tar_writer(tmp_path: Path, format: str): - from lhotse.testing.dummies import dummy_recording - - recording = dummy_recording(0, with_data=True) - audio = recording.load_audio() - - with AudioTarWriter( - str(tmp_path / "test.tar"), shard_size=None, format=format - ) as writer: - writer.write( - key="my-recording", - value=audio, - sampling_rate=recording.sampling_rate, - manifest=recording, - ) - - (path,) = writer.output_paths - - ((deserialized_recording, inner_path),) = list(TarIterator(path)) - - deserialized_audio = deserialized_recording.resample( - recording.sampling_rate - ).load_audio() - - rmse = np.sqrt(np.mean((audio - deserialized_audio) ** 2)) - assert rmse < 0.5 - - @pytest.mark.parametrize( ["format", "backend"], [ @@ -175,6 +126,59 @@ def test_audio_tar_writer(tmp_path: Path, format: str, backend: str): assert rmse < 0.5 +@pytest.mark.parametrize( + ["original_format", "rmse_threshold"], + [("wav", 0.0), ("flac", 0.0), ("mp3", 0.003), ("opus", 0.3)], +) +def test_audio_tar_writer_original_format( + tmp_path: Path, original_format: str, rmse_threshold: float +): + """Test using AudioTarWriter to write the audio signal in the exact same format + as it was loaded from the source. 
+ """ + from lhotse.testing.dummies import dummy_recording + + backend = "default" # use the default backend for reading the audio + writer_format = "original" # write the audio in the same format as it was loaded + + recording = dummy_recording(0, with_data=True, source_format=original_format) + audio = recording.load_audio() + + assert ( + recording.source_format == original_format + ), f"Recording source format ({recording.source_format}) not matching the expected original format ({original_format})" + + with audio_backend(backend): + with AudioTarWriter( + str(tmp_path / "test.tar"), shard_size=None, format=writer_format + ) as writer: + writer.write( + key="my-recording", + value=audio, + sampling_rate=recording.sampling_rate, + manifest=recording, + original_format=recording.source_format, + ) + (path,) = writer.output_paths + ((deserialized_recording, inner_path),) = list(TarIterator(path)) + + # make sure the deserialized audio is in the same format as the original + assert ( + deserialized_recording.source_format == original_format + ), f"Deserialized recording source format ({deserialized_recording.source_format}) not matching the expected original format ({original_format})" + + # load audio + deserialized_audio = deserialized_recording.resample( + recording.sampling_rate + ).load_audio() + + # check difference between original and deserialized audio + rmse = np.sqrt(np.mean((audio - deserialized_audio) ** 2)) + assert ( + rmse <= rmse_threshold + ), f"RMSE between original and deserialized audio is {rmse}, which is above the threshold of {rmse_threshold}" + + def test_shar_writer(tmp_path: Path): # Prepare data cuts = DummyManifest(CutSet, begin_id=0, end_id=20, with_data=True)