Skip to content

Commit

Permalink
Fix multichannel support
Browse files Browse the repository at this point in the history
  • Loading branch information
gregogiudici committed Feb 13, 2025
1 parent cdf2c26 commit 71c7ad3
Show file tree
Hide file tree
Showing 6 changed files with 617 additions and 107 deletions.
445 changes: 380 additions & 65 deletions examples/library_comparison.ipynb

Large diffs are not rendered by default.

38 changes: 36 additions & 2 deletions examples/test_timeit.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,40 @@ def test_stretch():
t = timeit.timeit(test_stretch, number=10)
print('Test 2 (stretch): %f' % t)

def test_multichannel():
    """Run the stretcher on 1-, 2- and 4-channel versions of the same recording.

    The loaded signal is given a leading channel axis, then duplicated along
    that axis twice so the processor is exercised with one, two and four
    channels. The processor is re-configured with ``preset`` before each run
    because the channel count changes. Shapes are printed at every step so the
    output can be checked by eye; nothing is asserted.
    """
    y, sr = librosa.load('examples/les_bridge_fing01__00000.wav', sr=None, mono=False)
    print('Original file', y.shape)
    # Add a leading channel axis.
    # NOTE(review): assumes the file is mono so that `y` is 1-D here; a
    # multi-channel file would become 3-D -- confirm against the asset.
    y = y[np.newaxis, :]
    print('Original MONO file', y.shape)

    ps = pystretch.Signalsmith.Stretch()
    ps.preset(y.shape[0], sr)

    # Process the single-channel signal
    y_1 = ps.process(y)
    print('Stretched MONO', y_1.shape)

    # Copy the first channel to a second channel
    y = np.concatenate((y, y), axis=0)
    # The array now holds two channels, so label it as stereo (was 'MONO').
    print('Original STEREO file', y.shape)

    # Process the two-channel signal, shifted by 12 semitones
    ps.preset(y.shape[0], sr)
    ps.setTransposeSemitones(12)
    y_2 = ps.process(y)

    print('Stretched STEREO', y_2.shape)

    # Copy the first two channels to other two channels
    y = np.concatenate((y, y), axis=0)
    print('Original Multichannel', y.shape)

    # Process the four-channel signal
    ps.preset(y.shape[0], sr)
    y_3 = ps.process(y)
    print('Stretched Multichannel', y_3.shape)


if __name__ == '__main__':
test_1()
test_2()
# test_1()
# test_2()
test_multichannel()
2 changes: 1 addition & 1 deletion src/python_stretch/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# This file is required to make Python treat the directories as containing packages
__doc__ = "A Python Wrapprer of the Signalsmith Stretch C++ library for pitch and time stretching"
__doc__ = "A simple python library for pitch shifting and time stretching"
__version__ = "0.2.0"

from . import Signalsmith
116 changes: 77 additions & 39 deletions src/signalsmith-bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@ namespace nb = nanobind;
using namespace nb::literals;


// Buffer class for reading audio, with offset reading
// (Buffer class is based on "Wav" class used in https://github.com/Signalsmith-Audio/signalsmith-stretch/blob/example-cmd/cmd/util/wav.h)
// Buffer class for reading audio (with offset reading)
// (Buffer class is based on "Wav" class in https://github.com/Signalsmith-Audio/signalsmith-stretch/blob/main/cmd/util/wav.h)
template<typename Sample=float>
class Buffer{
private:
Expand Down Expand Up @@ -70,7 +70,6 @@ struct Stretch{
// Default constructor: default-constructs the wrapped stretch_ member.
Stretch() : stretch_() {}
// Seeded constructor: passes `seed` straight to the wrapped stretch_ member's constructor.
Stretch(long seed) : stretch_(seed) {}


// === Access to private members ===
// Setter for sampleRate_: stores `value` and does nothing else.
void set_sr(Sample value) { sampleRate_ = value; }
// Getter for sampleRate_: returns the stored value by copy.
Sample sampleRate() const { return sampleRate_;}
Expand Down Expand Up @@ -121,16 +120,16 @@ struct Stretch{
// Sets the pitch shift, expressed in semitones, on the wrapped processor.
// Both arguments are forwarded unchanged to stretch_.setTransposeSemitones().
//   semitones     - amount of pitch shift in semitones.
//   tonalityLimit - forwarded as-is (default 0); its meaning is defined by the
//                   wrapped library, not by this wrapper.
void setTransposeSemitones(Sample semitones, Sample tonalityLimit=0) {
stretch_.setTransposeSemitones(semitones, tonalityLimit);
}
void setFreqMap(std::function<Sample(Sample)> inputToOutput) {
stretch_.setFreqMap(inputToOutput);
}
// void setFreqMap(std::function<Sample(Sample)> inputToOutput) {
// stretch_.setFreqMap(inputToOutput);
// }

// Stores the time-stretch factor in timeFactor_; no processing happens here.
//   timeFactor - value saved for later use.
// NOTE(review): presumably read by process() to size the output -- that code
// is not fully visible from here, confirm.
void setTimeFactor(Sample timeFactor) {
timeFactor_ = timeFactor;
}

// ====================
// Simple stretch function
// ==================== TO BE REMOVED ====================
// Simple stretch function
void simple_stretch_(const float* inputSignal, size_t inputSize, float* outputSignal, size_t outputSize) {
// Compress by linear interpolation
for (size_t i = 0; i < outputSize; ++i) {
Expand All @@ -155,10 +154,6 @@ struct Stretch{
nb::ndarray<nb::numpy, float, nb::ndim<2>> process(nb::ndarray<nb::numpy, float, nb::ndim<2>> audio_input) {
auto inData = audio_input.data();

if (audio_input.shape(0) > 2) {
throw std::runtime_error("Only mono or stereo audio is supported. The input should have shape (1,samples) or (2,samples).");
}

size_t numChannels = audio_input.shape(0);
size_t inputLength = audio_input.shape(1);

Expand All @@ -177,12 +172,12 @@ struct Stretch{
outputChannels[i] = new float[paddedOutputLength]();
}

// Copy input audio to input_data
// Copy from inData to inputChannels
for (size_t i = 0; i < numChannels; ++i) {
std::copy(inData + i*inputLength , inData + (i+1)*inputLength , inputChannels[i]);
}

// Wrap input/output buffer with Buffer class (for offset reading/writing)
// Wrap input/output channel-buffer with Buffer class (for offset reading/writing)
Buffer<float> inBuffer(inputChannels, paddedInputLength);
Buffer<float> outBuffer(outputChannels, outputLength);

Expand All @@ -204,12 +199,12 @@ struct Stretch{
size_t outShape[2] = {numChannels, outputLength };
float* outData = new float[numChannels * outShape[1]];

// Copy data from outputChannels to outData
// Copy from outputChannels to outData
for (size_t i = 0; i < numChannels; ++i) {
std::copy(outputChannels[i] + tailSamples, outputChannels[i] + paddedOutputLength , outData + i * outputLength );
}

// Reset the stretch processor or we will get an error: free() invalid pointer
// REMEMBER: Reset the stretch processor or we will get an error: free() invalid pointer
stretch_.reset();

// Clean up
Expand All @@ -234,36 +229,79 @@ struct Stretch{
using Sample = float;

NB_MODULE(Signalsmith, m) {
nb::class_<Stretch<Sample>>(m, "Stretch")
.def(nb::init<>())
.def(nb::init<long>(), "seed"_a)
// Getters
.def("blockSamples", &Stretch<Sample>::blockSamples)
.def("intervalSamples", &Stretch<Sample>::intervalSamples)
.def("inputLatency", &Stretch<Sample>::inputLatency)
.def("outputLatency", &Stretch<Sample>::outputLatency)
// Access to timeFactor_, sampleRate_
m.doc() = "Python binding of the Signalsmith Stretch library, providing time-stretching and pitch-shifting capabilities.";

nb::class_<Stretch<Sample>>(m, "Stretch", "Class for Stretch processor.")
.def(nb::init<>(), "Default constructor.")
.def(nb::init<long>(), "seed"_a, "Constructor with seed for deterministic behavior.")

// Attribute getters
.def("blockSamples", &Stretch<Sample>::blockSamples, "Get the block size used in processing.")
.def("intervalSamples", &Stretch<Sample>::intervalSamples, "Get the interval size for overlapping.")
.def("inputLatency", &Stretch<Sample>::inputLatency, "Get the input latency of the processor in samples.")
.def("outputLatency", &Stretch<Sample>::outputLatency, "Get the output latency of the processor in samples.")

// Access to timeFactor_ and sampleRate_
.def_prop_rw("sampleRate",
[](Stretch<Sample> &t) { return t.sampleRate() ; },
[](Stretch<Sample> &t, Sample value) { t.set_sr(value); })
[](Stretch<Sample> &t) { return t.sampleRate(); },
[](Stretch<Sample> &t, Sample value) { t.set_sr(value); },
"Sample rate of the processor in Hz.")
.def_prop_rw("timeFactor",
[](Stretch<Sample> &t) { return t.timeFactor() ; },
[](Stretch<Sample> &t, Sample value) { t.set_tf(value); })
[](Stretch<Sample> &t) { return t.timeFactor(); },
[](Stretch<Sample> &t, Sample value) { t.set_tf(value); },
"Time-stretching factor. A value >1 speeds up the signal, <1 slows it down.")

// Settings
.def("reset", &Stretch<Sample>::reset)
.def("reset", &Stretch<Sample>::reset, "Reset the processor to its initial state.")
.def("preset", &Stretch<Sample>::preset,
"nChannels"_a, "sampleRate"_a, "cheaper"_a=false)
"nChannels"_a, "sampleRate"_a, "cheaper"_a=false,
"Configure the Stretch processor with a preset.\n\n"
"Parameters:\n"
"----------\n"
"- nChannels (int): Number of audio channels.\n"
"- sampleRate (float): Sample rate in Hz.\n"
"- cheaper (bool, optional): If True, uses a lower-quality but more efficient configuration (default: False).")
.def("configure", &Stretch<Sample>::configure,
"nChannels"_a, "blockSamples"_a, "intervalSamples"_a)
"nChannels"_a, "blockSamples"_a, "intervalSamples"_a,
"Manually configure the stretch processor.\n\n"
"Parameters:\n"
"----------\n"
"- nChannels (int): Number of audio channels.\n"
"- blockSamples (int): Block size for processing.\n"
"- intervalSamples (int): Interval size for overlapping.")

.def("setTransposeFactor", &Stretch<Sample>::setTransposeFactor,
"multiplier"_a, "tonalityLimit"_a=0)
"multiplier"_a, "tonalityLimit"_a=0,
"Set the transposition factor for pitch shifting.\n\n"
"Parameters:\n"
"----------\n"
"- multiplier (float): Pitch shift multiplier (e.g., 2.0 for an octave up).\n"
"- tonalityLimit (float, optional): Restriction on tonal adjustments (default: 0).")
.def("setTransposeSemitones", &Stretch<Sample>::setTransposeSemitones,
"semitones"_a, "tonalityLimit"_a=0)
.def("setFreqMap", &Stretch<Sample>::setFreqMap,
"inputToOutput"_a)
"semitones"_a, "tonalityLimit"_a=0,
"Set the pitch shift in semitones.\n\n"
"Parameters:\n"
"----------\n"
"- semitones (float): Number of semitones to shift (e.g., +12 for an octave up).\n"
"- tonalityLimit (float, optional): Restriction on tonal adjustments (default: 0).")
.def("setTimeFactor", &Stretch<Sample>::setTimeFactor,
"timeFactor"_a)
// Processing
"timeFactor"_a,
"Set the time-stretching factor.\n\n"
"Parameters:\n"
"----------\n"
"- timeFactor (float): Factor by which time is stretched or compressed (e.g., 0.5 slows down by half, 2.0 doubles speed).")

// PROCESSING
.def("process", &Stretch<Sample>::process,
"audio_input"_a);
"audio_input"_a,
"Process an input audio buffer and return the stretched or pitch-shifted output.\n\n"
"Parameters:\n"
"----------\n"
"- audio_input (numpy.ndarray): Input audio buffer to be processed.\n\n"
"Returns:\n"
"----------\n"
"- numpy.ndarray: Stretched or pitch-shifted output audio buffer.")
;
// .def("setFreqMap", &Stretch<Sample>::setFreqMap,
// "inputToOutput"_a) // TODO: implement custom frequency mapping
}
35 changes: 35 additions & 0 deletions tests/test_pitch_shift.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
import python_stretch as m
import numpy as np

def _assert_pitch_shift_keeps_shape(n_channels, n_samples=44100):
    """Pitch-shift random audio by 12 semitones and check the shape is unchanged.

    A fresh processor is created, set to transpose by 12 semitones, run once on
    a ``(n_channels, n_samples)`` random array and deleted before the shape is
    asserted, mirroring the order of operations of the individual tests.
    """
    # np.random.rand yields float64.
    # NOTE(review): presumably the binding converts this to its own sample
    # type on the way in -- confirm.
    x = np.random.rand(n_channels, n_samples)

    ps = m.Signalsmith.Stretch()
    ps.setTransposeSemitones(12)

    y = ps.process(x)
    del ps

    # Report the channel count: the multichannel test picks it at random, so a
    # bare shape mismatch would not say which configuration failed.
    assert y.shape == (n_channels, n_samples), (
        "n_channels=%d: got shape %r" % (n_channels, y.shape)
    )


def test_mono():
    """A single-channel signal keeps its shape after pitch shifting."""
    _assert_pitch_shift_keeps_shape(1)


def test_stereo():
    """A two-channel signal keeps its shape after pitch shifting."""
    _assert_pitch_shift_keeps_shape(2)


def test_multichannel():
    """A signal with a random channel count in [3, 9] keeps its shape."""
    nChannels = np.random.randint(3, 10)
    _assert_pitch_shift_keeps_shape(nChannels)

88 changes: 88 additions & 0 deletions tests/test_time_stretch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
import python_stretch as m
import numpy as np

def _assert_stretched_lengths(n_channels, time_factor, expected_lengths):
    """Stretch a 44100- and a 22050-sample signal and check the output shapes.

    One processor is created, given ``time_factor`` via ``setTimeFactor`` and
    used for both signals in turn; it is deleted before any assertion runs,
    mirroring the order of operations of the individual tests.

    Parameters
    ----------
    n_channels : int
        Number of rows (channels) in both random input arrays.
    time_factor : float
        Value passed to ``setTimeFactor``.
    expected_lengths : tuple of (int, int)
        Expected output lengths for the 44100- and 22050-sample inputs.
    """
    x1 = np.random.rand(n_channels, 44100)
    x2 = np.random.rand(n_channels, 22050)

    ps = m.Signalsmith.Stretch()
    ps.setTimeFactor(time_factor)

    y1 = ps.process(x1)
    y2 = ps.process(x2)
    del ps

    # Report the configuration: the multichannel tests pick the channel count
    # at random, so a bare shape mismatch would not be reproducible.
    context = "n_channels=%d, time_factor=%r" % (n_channels, time_factor)
    assert y1.shape == (n_channels, expected_lengths[0]), context
    assert y2.shape == (n_channels, expected_lengths[1]), context


def test_mono_double_length():
    """Time factor 0.5 doubles the length of a mono signal."""
    _assert_stretched_lengths(1, 0.5, (88200, 44100))


def test_mono_half_length():
    """Time factor 2 halves the length of a mono signal."""
    _assert_stretched_lengths(1, 2., (22050, 11025))


def test_stereo_double_length():
    """Time factor 0.5 doubles the length of a stereo signal."""
    _assert_stretched_lengths(2, 0.5, (88200, 44100))


def test_stereo_half_length():
    """Time factor 2 halves the length of a stereo signal."""
    _assert_stretched_lengths(2, 2., (22050, 11025))


def test_multichannel_double_length():
    """Time factor 0.5 doubles the length for a random channel count in [3, 9]."""
    n_channels = np.random.randint(3, 10)
    _assert_stretched_lengths(n_channels, 0.5, (88200, 44100))


def test_multichannel_half_length():
    """Time factor 2 halves the length for a random channel count in [3, 9]."""
    n_channels = np.random.randint(3, 10)
    _assert_stretched_lengths(n_channels, 2., (22050, 11025))

0 comments on commit 71c7ad3

Please sign in to comment.