diff --git a/Cargo.lock b/Cargo.lock index b609894..958b766 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -302,6 +302,7 @@ dependencies = [ "ndarray", "oneshot", "thiserror", + "windows", ] [[package]] @@ -603,6 +604,70 @@ version = "0.9.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f" +[[package]] +name = "windows" +version = "0.58.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dd04d41d93c4992d421894c18c8b43496aa748dd4c081bac0dc93eb0489272b6" +dependencies = [ + "windows-core", + "windows-targets", +] + +[[package]] +name = "windows-core" +version = "0.58.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6ba6d44ec8c2591c134257ce647b7ea6b20335bf6379a27dac5f1641fcf59f99" +dependencies = [ + "windows-implement", + "windows-interface", + "windows-result", + "windows-strings", + "windows-targets", +] + +[[package]] +name = "windows-implement" +version = "0.58.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2bbd5b46c938e506ecbce286b6628a02171d56153ba733b6c741fc627ec9579b" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.72", +] + +[[package]] +name = "windows-interface" +version = "0.58.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "053c4c462dc91d3b1504c6fe5a726dd15e216ba718e84a0e46a88fbe5ded3515" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.72", +] + +[[package]] +name = "windows-result" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1d1043d8214f791817bab27572aaa8af63732e11bf84aa21a45a78d6c317ae0e" +dependencies = [ + "windows-targets", +] + +[[package]] +name = "windows-strings" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4cd9b125c486025df0eabcb585e62173c6c9eddcec5d117d3b6e8c30e2ee4d10" +dependencies = [ + "windows-result", + "windows-targets", +] + [[package]] name = "windows-sys" version = "0.52.0" diff --git a/Cargo.toml b/Cargo.toml index 4ba7f6d..436641b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -26,6 +26,30 @@ alsa = "0.9.0" [target.'cfg(any(target_os = "macos", target_os = "ios"))'.dependencies] coreaudio-rs = "0.12.0" +[target.'cfg(target_os = "windows")'.dependencies] +windows = { version = "0.58.0", features = [ + "Win32_Media_Audio", + "Win32_Foundation", + "Win32_Devices_Properties", + "Win32_Media_KernelStreaming", + "Win32_System_Com_StructuredStorage", + "Win32_System_Threading", + "Win32_Security", + "Win32_System_SystemServices", + "Win32_System_Variant", + "Win32_Media_Multimedia", + "Win32_UI_Shell_PropertiesSystem" +]} + [[example]] name = "enumerate_alsa" path = "examples/enumerate_alsa.rs" + +[[example]] +name = "enumerate_coreaudio" +path = "examples/enumerate_coreaudio.rs" + +[[example]] +name = "enumerate_wasapi" +path = "examples/enumerate_wasapi.rs" + diff --git a/build.rs b/build.rs index c248868..23f1985 100644 --- a/build.rs +++ b/build.rs @@ -6,6 +6,8 @@ fn main() { wasm: { any(target_os = "wasm32") }, os_alsa: { any(target_os = "linux", target_os = "dragonfly", target_os = "freebsd", target_os = "netbsd") }, - os_coreaudio: { any (target_os = "macos", target_os = "ios") } + os_coreaudio: { any (target_os = "macos", target_os = "ios") }, + os_wasapi: { target_os = "windows" }, + unsupported: { not(any(os_alsa, os_coreaudio, os_wasapi))} } } diff --git a/examples/enumerate_alsa.rs b/examples/enumerate_alsa.rs index 7aaaa26..378b3b6 100644 --- a/examples/enumerate_alsa.rs +++ b/examples/enumerate_alsa.rs @@ -1,9 +1,7 @@ -use std::error::Error; - mod util; #[cfg(os_alsa)] -fn main() -> Result<(), Box> { +fn main() -> Result<(), Box> { use crate::util::enumerate::enumerate_devices; use interflow::backends::alsa::AlsaDriver; diff --git a/examples/enumerate_wasapi.rs b/examples/enumerate_wasapi.rs new file mode 100644 index 0000000..05c658c --- /dev/null +++ b/examples/enumerate_wasapi.rs @@ -0,0 +1,13 @@ +mod util; + +#[cfg(os_wasapi)] +fn main() -> Result<(), Box> { + use crate::util::enumerate::enumerate_devices; + use interflow::backends::wasapi::WasapiDriver; + enumerate_devices(WasapiDriver) +} + +#[cfg(not(os_wasapi))] +fn main() { + println!("WASAPI driver is not available on this platform"); +} diff --git a/src/audio_buffer.rs b/src/audio_buffer.rs index 2ff0d4d..240f95b 100644 --- a/src/audio_buffer.rs +++ b/src/audio_buffer.rs @@ -173,7 +173,7 @@ impl AudioBufferBase { for (inp, out) in self.as_interleaved().iter().zip(output.iter_mut()) { *out = *inp; } - return true; + true } } @@ -212,8 +212,7 @@ impl AudioBufferBase { pub fn channels_mut(&mut self) -> impl '_ + Iterator> { self.storage.rows_mut().into_iter() } - - /// Return a mutable interleaved 2-D array view, where samples are in rows and channels are in +/// Return a mutable interleaved 2-D array view, where samples are in rows and channels are in /// columns. pub fn as_interleaved_mut(&mut self) -> ArrayViewMut2 { self.storage.view_mut().reversed_axes() diff --git a/src/backends/alsa.rs b/src/backends/alsa.rs index 17552f2..4c5cd2c 100644 --- a/src/backends/alsa.rs +++ b/src/backends/alsa.rs @@ -10,7 +10,7 @@ use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::Arc; use std::thread::JoinHandle; use std::time::Duration; -use std::{borrow::Cow, ffi::CStr}; +use std::borrow::Cow; use alsa::{device_name::HintIter, pcm, PCM}; use thiserror::Error; @@ -53,10 +53,6 @@ impl AudioDriver for AlsaDriver { } fn list_devices(&self) -> Result, Self::Error> { - const C_PCM: &CStr = match CStr::from_bytes_with_nul(b"pcm\0") { - Ok(cstr) => cstr, - Err(_) => unreachable!(), - }; Ok(HintIter::new(None, c"pcm")? .filter_map(|hint| AlsaDevice::new(hint.name.as_ref()?, hint.direction?).ok())) } @@ -210,11 +206,12 @@ impl AlsaDevice { fn default_config(&self) -> Result { let samplerate = 48000.; // Default ALSA sample rate let channel_count = 2; // Stereo stream - let channels = 1 << channel_count - 1; + let channels = 1 << (channel_count - 1); Ok(StreamConfig { samplerate: samplerate as _, channels, buffer_size_range: (None, None), + exclusive: false, }) } } @@ -259,6 +256,7 @@ impl AlsaStream { channels: ChannelMap32::default() .with_indices(std::iter::repeat(1).take(num_channels)), buffer_size_range: (Some(period_size), Some(period_size)), + exclusive: false, }; let mut timestamp = Timestamp::new(samplerate); let mut buffer = vec![0f32; period_size * num_channels]; @@ -274,13 +272,10 @@ impl AlsaStream { } let frames = device.pcm.avail_update()? as usize; let len = frames * num_channels; - match io.readi(&mut buffer[..len]) { - Err(err) => { - log::warn!("ALSA PCM error, trying to recover ..."); - log::debug!("Error: {err}"); - device.pcm.try_recover(err, true)?; - } - _ => {} + if let Err(err) = io.readi(&mut buffer[..len]) { + log::warn!("ALSA PCM error, trying to recover ..."); + log::debug!("Error: {err}"); + device.pcm.try_recover(err, true)?; } let buffer = AudioRef::from_interleaved(&buffer[..len], num_channels).unwrap(); let context = AudioCallbackContext { @@ -333,6 +328,7 @@ impl AlsaStream { channels: ChannelMap32::default() .with_indices(std::iter::repeat(1).take(num_channels)), buffer_size_range: (Some(period_size), Some(period_size)), + exclusive: false, }; let frames = device.pcm.avail_update()? as usize; let mut timestamp = Timestamp::new(samplerate); @@ -358,10 +354,7 @@ impl AlsaStream { }; callback.on_output_data(context, input); timestamp += frames as u64; - match io.writei(&buffer[..len]) { - Err(err) => device.pcm.try_recover(err, true)?, - _ => {} - } + if let Err(err) = io.writei(&buffer[..len]) { device.pcm.try_recover(err, true)? } match device.pcm.state() { pcm::State::Suspended => { if hwp.can_resume() { diff --git a/src/backends/coreaudio.rs b/src/backends/coreaudio.rs index 95263d8..6c09e55 100644 --- a/src/backends/coreaudio.rs +++ b/src/backends/coreaudio.rs @@ -160,12 +160,18 @@ impl AudioDevice for CoreAudioDevice { .iter() .copied() .filter(move |sr| samplerate_range.contains(sr)) - .map(move |sr| { + .flat_map(move |sr| { + [false, true] + .into_iter() + .map(move |exclusive| (sr, exclusive)) + }) + .map(move |(samplerate, exclusive)| { let channels = 1 << asbd.mFormat.mChannelsPerFrame as u32 - 1; StreamConfig { - samplerate: sr, + samplerate, channels, buffer_size_range: (None, None), + exclusive, } }) })) @@ -195,6 +201,7 @@ impl AudioInputDevice for CoreAudioDevice { channels: 0b1, // Hardcoded to mono on non-interleaved inputs samplerate, buffer_size_range: (None, None), + exclusive: false, }) } @@ -226,6 +233,7 @@ impl AudioOutputDevice for CoreAudioDevice { samplerate, buffer_size_range: (None, None), channels: 0b11, + exclusive: false, }) } diff --git a/src/backends/mod.rs b/src/backends/mod.rs index 906acb8..baccd85 100644 --- a/src/backends/mod.rs +++ b/src/backends/mod.rs @@ -4,7 +4,13 @@ //! //! Each backend is provided in its own submodule. Types should be public so that the user isn't //! limited to going through the main API if they want to choose a specific backend. -use crate::{AudioDriver, AudioInputDevice, AudioOutputDevice, DeviceType}; + +use crate::{ + AudioDriver, AudioInputDevice, AudioOutputDevice, DeviceType, +}; + +#[cfg(unsupported)] +compile_error!("Unsupported platform (supports ALSA, CoreAudio, and WASAPI)"); #[cfg(os_alsa)] pub mod alsa; @@ -12,6 +18,9 @@ pub mod alsa; #[cfg(os_coreaudio)] pub mod coreaudio; +#[cfg(os_wasapi)] +pub mod wasapi; + /// Returns the default driver. /// /// "Default" here means that it is a supported driver that is available on the platform. @@ -25,11 +34,17 @@ pub mod coreaudio; /// | **Platform** | **Driver** | /// |:------------:|:----------:| /// | Linux | ALSA | +/// | macOS | CoreAudio | +/// | Windows | WASAPI | +#[cfg(any(os_alsa, os_coreaudio, os_wasapi))] +#[allow(clippy::needless_return)] pub fn default_driver() -> impl AudioDriver { #[cfg(os_alsa)] return alsa::AlsaDriver; #[cfg(os_coreaudio)] return coreaudio::CoreAudioDriver; + #[cfg(os_wasapi)] + return wasapi::WasapiDriver; } /// Returns the default input device for the given audio driver. @@ -51,11 +66,15 @@ where /// "Default" here means both in terms of platform support but also can include runtime selection. /// Therefore, it is better to use this method directly rather than first getting the default /// driver from [`default_driver`]. +#[cfg(any(os_alsa, os_coreaudio, os_wasapi))] +#[allow(clippy::needless_return)] pub fn default_input_device() -> impl AudioInputDevice { #[cfg(os_alsa)] return default_input_device_from(&alsa::AlsaDriver); #[cfg(os_coreaudio)] return default_input_device_from(&coreaudio::CoreAudioDriver); + #[cfg(os_wasapi)] + return default_input_device_from(&wasapi::WasapiDriver); } /// Returns the default input device for the given audio driver. @@ -77,9 +96,13 @@ where /// "Default" here means both in terms of platform support but also can include runtime selection. /// Therefore, it is better to use this method directly rather than first getting the default /// driver from [`default_driver`]. +#[cfg(any(os_alsa, os_coreaudio, os_wasapi))] +#[allow(clippy::needless_return)] pub fn default_output_device() -> impl AudioOutputDevice { #[cfg(os_alsa)] return default_output_device_from(&alsa::AlsaDriver); #[cfg(os_coreaudio)] return default_output_device_from(&coreaudio::CoreAudioDriver); + #[cfg(os_wasapi)] + return default_output_device_from(&wasapi::WasapiDriver); } diff --git a/src/backends/wasapi/device.rs b/src/backends/wasapi/device.rs new file mode 100644 index 0000000..f29bcdb --- /dev/null +++ b/src/backends/wasapi/device.rs @@ -0,0 +1,152 @@ +use super::{error, stream}; +use crate::backends::wasapi::stream::WasapiStream; +use crate::channel_map::Bitset; +use crate::prelude::wasapi::util::WasapiMMDevice; +use crate::{AudioDevice, AudioInputCallback, AudioInputDevice, AudioOutputCallback, AudioOutputDevice, Channel, DeviceType, StreamConfig}; +use std::borrow::Cow; +use windows::Win32::Media::Audio; + +/// Type of devices available from the WASAPI driver. +#[derive(Debug, Clone)] +pub struct WasapiDevice { + device: WasapiMMDevice, + device_type: DeviceType, +} + +impl WasapiDevice { + pub(crate) fn new(device: Audio::IMMDevice, device_type: DeviceType) -> Self { + WasapiDevice { + device: WasapiMMDevice::new(device), + device_type, + } + } +} + +impl AudioDevice for WasapiDevice { + type Error = error::WasapiError; + + fn name(&self) -> Cow { + match self.device.name() { + Some(std) => Cow::Owned(std), + None => { + eprintln!("Cannot get audio device name"); + Cow::Borrowed("") + } + } + } + + fn device_type(&self) -> DeviceType { + self.device_type + } + + fn channel_map(&self) -> impl IntoIterator { + [] + } + + fn is_config_supported(&self, config: &StreamConfig) -> bool { + match self.device_type { + DeviceType::Output => { + stream::is_output_config_supported(self.device.clone(), config) + } + _ => false, + } + } + + fn enumerate_configurations(&self) -> Option> { + None::<[StreamConfig; 0]> + } +} + + +impl AudioInputDevice for WasapiDevice { + type StreamHandle = WasapiStream; + + fn default_input_config(&self) -> Result { + let audio_client = self.device.activate::()?; + let format = unsafe { + audio_client.GetMixFormat()?.read_unaligned() }; + let frame_size = unsafe { audio_client.GetBufferSize() }.map(|i| i as usize).ok(); + Ok(StreamConfig { + channels: 0u32.with_indices(0..format.nChannels as _), + exclusive: false, + samplerate: format.nSamplesPerSec as _, + buffer_size_range: (frame_size, frame_size), + }) + } + + fn create_input_stream( + &self, + stream_config: StreamConfig, + callback: Callback, + ) -> Result, Self::Error> { + Ok(WasapiStream::new_input( + self.device.clone(), + stream_config, + callback, + )) + } +} + +impl AudioOutputDevice for WasapiDevice { + type StreamHandle = WasapiStream; + + fn default_output_config(&self) -> Result { + let audio_client = self.device.activate::()?; + let format = unsafe { + audio_client.GetMixFormat()?.read_unaligned() }; + let frame_size = unsafe { audio_client.GetBufferSize() }.map(|i| i as usize).ok(); + Ok(StreamConfig { + channels: 0u32.with_indices(0..format.nChannels as _), + exclusive: false, + samplerate: format.nSamplesPerSec as _, + buffer_size_range: (frame_size, frame_size), + }) + } + + fn create_output_stream( + &self, + stream_config: StreamConfig, + callback: Callback, + ) -> Result, Self::Error> { + Ok(WasapiStream::new_output( + self.device.clone(), + stream_config, + callback, + )) + } +} + +/// An iterable collection WASAPI devices. +pub struct WasapiDeviceList { + pub(crate) collection: Audio::IMMDeviceCollection, + pub(crate) total_count: u32, + pub(crate) next_item: u32, + pub(crate) device_type: DeviceType, +} + +unsafe impl Send for WasapiDeviceList {} + +unsafe impl Sync for WasapiDeviceList {} + +impl Iterator for WasapiDeviceList { + type Item = WasapiDevice; + + fn next(&mut self) -> Option { + if self.next_item >= self.total_count { + return None; + } + + unsafe { + let device = self.collection.Item(self.next_item).unwrap(); + self.next_item += 1; + Some(WasapiDevice::new(device, self.device_type)) + } + } + + fn size_hint(&self) -> (usize, Option) { + let rest = (self.total_count - self.next_item) as usize; + (rest, Some(rest)) + } +} + +impl ExactSizeIterator for WasapiDeviceList {} diff --git a/src/backends/wasapi/driver.rs b/src/backends/wasapi/driver.rs new file mode 100644 index 0000000..e49eba3 --- /dev/null +++ b/src/backends/wasapi/driver.rs @@ -0,0 +1,113 @@ +use std::borrow::Cow; +use windows::Win32::System::Com; +use windows::Win32::Media::Audio; +use std::sync::OnceLock; +use crate::backends::wasapi::device::{WasapiDevice, WasapiDeviceList}; + +use super::{error, util}; + +use crate::{AudioDriver, DeviceType}; + +/// The WASAPI driver. +#[derive(Debug, Clone, Default)] +pub struct WasapiDriver; + +impl AudioDriver for WasapiDriver { + type Error = error::WasapiError; + type Device = WasapiDevice; + + const DISPLAY_NAME: &'static str = "WASAPI"; + + fn version(&self) -> Result, Self::Error> { + Ok(Cow::Borrowed("unknown")) + } + + fn default_device(&self, device_type: DeviceType) -> Result, Self::Error> { + audio_device_enumerator().get_default_device(device_type) + } + + fn list_devices(&self) -> Result, Self::Error> { + audio_device_enumerator().get_device_list() + } +} + +pub fn audio_device_enumerator() -> &'static AudioDeviceEnumerator { + ENUMERATOR.get_or_init(|| { + // Make sure COM is initialised. + util::com_initializer(); + + unsafe { + let enumerator = Com::CoCreateInstance::<_, Audio::IMMDeviceEnumerator>( + &Audio::MMDeviceEnumerator, + None, + Com::CLSCTX_ALL, + ) + .unwrap(); + + AudioDeviceEnumerator(enumerator) + } + }) +} + +static ENUMERATOR: OnceLock = OnceLock::new(); + +/// Send/Sync wrapper around `IMMDeviceEnumerator`. +pub struct AudioDeviceEnumerator(Audio::IMMDeviceEnumerator); + +impl AudioDeviceEnumerator { + // Returns the default output device. + fn get_default_device( + &self, + device_type: DeviceType, + ) -> Result, error::WasapiError> { + let data_flow = match device_type { + DeviceType::Input => Audio::eCapture, + DeviceType::Output => Audio::eRender, + _ => return Ok(None), + }; + + unsafe { + let device = self.0.GetDefaultAudioEndpoint(data_flow, Audio::eConsole)?; + + Ok(Some(WasapiDevice::new(device, DeviceType::Output))) + } + } + + // Returns a chained iterator of output and input devices. + fn get_device_list(&self) -> Result, error::WasapiError> { + // Create separate collections for output and input devices and then chain them. + unsafe { + let output_collection = self + .0 + .EnumAudioEndpoints(Audio::eRender, Audio::DEVICE_STATE_ACTIVE)?; + + let count = output_collection.GetCount()?; + + let output_device_list = WasapiDeviceList { + collection: output_collection, + total_count: count, + next_item: 0, + device_type: DeviceType::Output, + }; + + let input_collection = self + .0 + .EnumAudioEndpoints(Audio::eCapture, Audio::DEVICE_STATE_ACTIVE)?; + + let count = input_collection.GetCount()?; + + let input_device_list = WasapiDeviceList { + collection: input_collection, + total_count: count, + next_item: 0, + device_type: DeviceType::Input, + }; + + Ok(output_device_list.chain(input_device_list)) + } + } +} + +unsafe impl Send for AudioDeviceEnumerator {} + +unsafe impl Sync for AudioDeviceEnumerator {} \ No newline at end of file diff --git a/src/backends/wasapi/error.rs b/src/backends/wasapi/error.rs new file mode 100644 index 0000000..7ce2d8d --- /dev/null +++ b/src/backends/wasapi/error.rs @@ -0,0 +1,16 @@ +use thiserror::Error; + +/// Type of errors from the WASAPI backend. +#[derive(Debug, Error)] +#[error("WASAPI error: ")] +pub enum WasapiError { + /// Error originating from WASAPI. + #[error("{} (code {})", .0.message(), .0.code())] + BackendError(#[from] windows::core::Error), + /// Requested WASAPI device configuration is not available + #[error("Configuration not available")] + ConfigurationNotAvailable, + /// Windows Foundation error + #[error("Win32 error: {0}")] + FoundationError(String), +} \ No newline at end of file diff --git a/src/backends/wasapi/mod.rs b/src/backends/wasapi/mod.rs new file mode 100644 index 0000000..3673fd1 --- /dev/null +++ b/src/backends/wasapi/mod.rs @@ -0,0 +1,10 @@ +mod util; + +mod error; + +pub(crate) mod driver; +mod device; +mod stream; +pub mod prelude; + +pub use prelude::*; \ No newline at end of file diff --git a/src/backends/wasapi/prelude.rs b/src/backends/wasapi/prelude.rs new file mode 100644 index 0000000..3fa8663 --- /dev/null +++ b/src/backends/wasapi/prelude.rs @@ -0,0 +1,6 @@ +pub use super::{ + device::WasapiDevice, + driver::WasapiDriver, + error::WasapiError, + stream::WasapiStream, +}; diff --git a/src/backends/wasapi/stream.rs b/src/backends/wasapi/stream.rs new file mode 100644 index 0000000..783c591 --- /dev/null +++ b/src/backends/wasapi/stream.rs @@ -0,0 +1,521 @@ +use super::error; +use crate::audio_buffer::AudioMut; +use crate::backends::wasapi::util::WasapiMMDevice; +use crate::channel_map::Bitset; +use crate::prelude::{AudioRef, Timestamp}; +use crate::{ + AudioCallbackContext, AudioInput, AudioInputCallback, AudioOutput, AudioOutputCallback, + AudioStreamHandle, StreamConfig, +}; +use duplicate::duplicate_item; +use std::marker::PhantomData; +use std::ptr::NonNull; +use std::sync::atomic::{AtomicBool, Ordering}; +use std::sync::Arc; +use std::thread::JoinHandle; +use std::time::Duration; +use std::{ops, ptr, slice}; +use windows::core::imp::CoTaskMemFree; +use windows::core::Interface; +use windows::Win32::Foundation; +use windows::Win32::Foundation::{CloseHandle, HANDLE}; +use windows::Win32::Media::{Audio, KernelStreaming, Multimedia}; +use windows::Win32::System::Threading; + +type EjectSignal = Arc; + +#[duplicate_item( +name ty; +[AudioCaptureBuffer] [IAudioCaptureClient]; +[AudioRenderBuffer] [IAudioRenderClient]; +)] +struct name<'a, T> { + interface: &'a Audio::ty, + data: NonNull, + frame_size: usize, + channels: usize, + __type: PhantomData, +} + +#[duplicate_item( +name; +[AudioCaptureBuffer]; +[AudioRenderBuffer]; +)] +impl<'a, T> ops::Deref for name<'a, T> { + type Target = [T]; + + fn deref(&self) -> &Self::Target { + unsafe { slice::from_raw_parts(self.data.cast().as_ptr(), self.channels * self.frame_size) } + } +} + +#[duplicate_item( +name; +[AudioCaptureBuffer]; +[AudioRenderBuffer]; +)] +impl<'a, T> ops::DerefMut for name<'a, T> { + fn deref_mut(&mut self) -> &mut Self::Target { + unsafe { + slice::from_raw_parts_mut(self.data.cast().as_ptr(), self.channels * self.frame_size) + } + } +} + +impl Drop for AudioCaptureBuffer<'_, T> { + fn drop(&mut self) { + unsafe { self.interface.ReleaseBuffer(self.frame_size as _).unwrap() }; + } +} + +impl Drop for AudioRenderBuffer<'_, T> { + fn drop(&mut self) { + unsafe { + self.interface + .ReleaseBuffer(self.frame_size as _, 0) + .unwrap(); + } + } +} + +impl<'a, T> AudioRenderBuffer<'a, T> { + fn from_client( + render_client: &'a Audio::IAudioRenderClient, + channels: usize, + frame_size: usize, + ) -> Result { + let data = NonNull::new(unsafe { render_client.GetBuffer(frame_size as _) }?) + .expect("Audio buffer data is null"); + Ok(Self { + interface: render_client, + data, + frame_size, + channels, + __type: PhantomData, + }) + } +} +impl<'a, T> AudioCaptureBuffer<'a, T> { + fn from_client( + capture_client: &'a Audio::IAudioCaptureClient, + channels: usize, + ) -> Result, error::WasapiError> { + let mut buf_ptr = ptr::null_mut(); + let mut frame_size = 0; + let mut flags = 0; + unsafe { + capture_client.GetBuffer(&mut buf_ptr, &mut frame_size, &mut flags, None, None) + }?; + let Some(data) = NonNull::new(buf_ptr as _) else { return Ok(None); }; + Ok(Some(Self { + interface: capture_client, + data, + frame_size: frame_size as _, + channels, + __type: PhantomData, + })) + } +} + +struct AudioThread { + audio_client: Audio::IAudioClient, + interface: Interface, + audio_clock: Audio::IAudioClock, + stream_config: StreamConfig, + eject_signal: EjectSignal, + frame_size: usize, + callback: Callback, + event_handle: HANDLE, + clock_start: Duration, +} + +impl AudioThread { + fn finalize(self) -> Result { + if !self.event_handle.is_invalid() { + unsafe { CloseHandle(self.event_handle) }?; + } + let _ = unsafe { + self.audio_client + .Stop() + .inspect_err(|err| eprintln!("Cannot stop audio thread: {err}")) + }; + Ok(self.callback) + } +} + +impl AudioThread { + fn new( + device: WasapiMMDevice, + eject_signal: EjectSignal, + mut stream_config: StreamConfig, + callback: Callback, + ) -> Result { + unsafe { + let audio_client: Audio::IAudioClient = device.activate()?; + let sharemode = if stream_config.exclusive { + Audio::AUDCLNT_SHAREMODE_EXCLUSIVE + } else { + Audio::AUDCLNT_SHAREMODE_SHARED + }; + let format = { + let mut format = config_to_waveformatextensible(&stream_config); + let mut actual_format = ptr::null_mut(); + audio_client + .IsFormatSupported( + sharemode, + &format.Format, + (!stream_config.exclusive).then_some(&mut actual_format), + ) + .ok()?; + if !stream_config.exclusive { + assert!(!actual_format.is_null()); + format.Format = actual_format.read_unaligned(); + CoTaskMemFree(actual_format.cast()); + let sample_rate = format.Format.nSamplesPerSec; + stream_config.channels = 0u32.with_indices(0..format.Format.nChannels as _); + stream_config.samplerate = sample_rate as _; + } + format + }; + let frame_size = stream_config + .buffer_size_range + .0 + .or(stream_config.buffer_size_range.1); + let buffer_duration = frame_size + .map(|frame_size| { + buffer_size_to_duration(frame_size, stream_config.samplerate as _) + }) + .unwrap_or(0); + audio_client.Initialize( + sharemode, + Audio::AUDCLNT_STREAMFLAGS_EVENTCALLBACK + | Audio::AUDCLNT_STREAMFLAGS_AUTOCONVERTPCM, + buffer_duration, + 0, + &format.Format, + None, + )?; + let buffer_size = audio_client.GetBufferSize()? as usize; + let event_handle = { + let event_handle = + Threading::CreateEventA(None, false, false, windows::core::PCSTR(ptr::null()))?; + audio_client.SetEventHandle(event_handle)?; + event_handle + }; + let interface = audio_client.GetService::()?; + let audio_clock = audio_client.GetService::()?; + let frame_size = buffer_size; + Ok(Self { + audio_client, + interface, + audio_clock, + event_handle, + frame_size, + eject_signal, + stream_config: StreamConfig { + buffer_size_range: (Some(frame_size), Some(frame_size)), + ..stream_config + }, + clock_start: Duration::ZERO, + callback, + }) + } + } + + fn await_frame(&mut self) -> Result<(), error::WasapiError> { + let _ = unsafe { + let result = Threading::WaitForSingleObject(self.event_handle, Threading::INFINITE); + if result == Foundation::WAIT_FAILED { + let err = Foundation::GetLastError(); + let description = format!("Waiting for event handle failed: {:?}", err); + return Err(error::WasapiError::FoundationError(description)); + } + result + }; + Ok(()) + } + + fn output_timestamp(&self) -> Result { + let clock = stream_instant(&self.audio_clock)?; + let diff = clock - self.clock_start; + Ok(Timestamp::from_duration( + self.stream_config.samplerate, + diff, + )) + } +} + +impl AudioThread { + fn run(mut self) -> Result { + set_thread_priority(); + unsafe { + self.audio_client.Start()?; + } + self.clock_start = stream_instant(&self.audio_clock)?; + loop { + if self.eject_signal.load(Ordering::Relaxed) { + break self.finalize(); + } + self.await_frame()?; + self.process()?; + } + .inspect_err(|err| eprintln!("Render thread process error: {err}")) + } + + fn process(&mut self) -> Result<(), error::WasapiError> { + let frames_available = unsafe { + self.interface.GetNextPacketSize()? as usize + }; + if frames_available == 0 { + return Ok(()); + } + let Some(mut buffer) = AudioCaptureBuffer::::from_client( + &self.interface, + self.stream_config.channels.count(), + )? else { + eprintln!("Null buffer from WASAPI"); + return Ok(()); + }; + let timestamp = self.output_timestamp()?; + let context = AudioCallbackContext { + stream_config: self.stream_config, + timestamp, + }; + let buffer = + AudioRef::from_interleaved(&mut buffer, self.stream_config.channels.count()).unwrap(); + let output = AudioInput { timestamp, buffer }; + self.callback.on_input_data(context, output); + Ok(()) + } +} + +impl AudioThread { + fn run(mut self) -> Result { + set_thread_priority(); + unsafe { + self.audio_client.Start()?; + } + self.clock_start = stream_instant(&self.audio_clock)?; + loop { + if self.eject_signal.load(Ordering::Relaxed) { + break self.finalize(); + } + self.await_frame()?; + self.process()?; + } + .inspect_err(|err| eprintln!("Render thread process error: {err}")) + } + + fn process(&mut self) -> Result<(), error::WasapiError> { + let frames_available = unsafe { + let padding = self.audio_client.GetCurrentPadding()? as usize; + self.frame_size - padding + }; + if frames_available == 0 { + return Ok(()); + } + let frames_requested = if let Some(max_frames) = self.stream_config.buffer_size_range.1 { + frames_available.min(max_frames) + } else { + frames_available + }; + let mut buffer = AudioRenderBuffer::::from_client( + &self.interface, + self.stream_config.channels.count(), + frames_requested, + )?; + let timestamp = self.output_timestamp()?; + let context = AudioCallbackContext { + stream_config: self.stream_config, + timestamp, + }; + let buffer = + AudioMut::from_interleaved_mut(&mut buffer, self.stream_config.channels.count()) + .unwrap(); + let output = AudioOutput { timestamp, buffer }; + self.callback.on_output_data(context, output); + Ok(()) + } +} + +/// Type representing a WASAPI audio stream. +pub struct WasapiStream { + join_handle: JoinHandle>, + eject_signal: EjectSignal, +} + +impl AudioStreamHandle for WasapiStream { + type Error = error::WasapiError; + + fn eject(self) -> Result { + self.eject_signal.store(true, Ordering::Relaxed); + self.join_handle + .join() + .expect("Audio output thread panicked") + } +} + +impl WasapiStream { + pub(crate) fn new_input( + device: WasapiMMDevice, + stream_config: StreamConfig, + callback: Callback, + ) -> Self { + let eject_signal = EjectSignal::default(); + let join_handle = std::thread::Builder::new() + .name("interflow_wasapi_output_stream".to_string()) + .spawn({ + let eject_signal = eject_signal.clone(); + move || { + let inner: AudioThread = + AudioThread::new(device, eject_signal, stream_config, callback) + .inspect_err(|err| { + eprintln!("Failed to create render thread: {err}") + })?; + inner.run() + } + }) + .expect("Cannot spawn audio output thread"); + Self { + join_handle, + eject_signal, + } + } +} + +impl WasapiStream { + pub(crate) fn new_output( + device: WasapiMMDevice, + stream_config: StreamConfig, + callback: Callback, + ) -> Self { + let eject_signal = EjectSignal::default(); + let join_handle = std::thread::Builder::new() + .name("interflow_wasapi_output_stream".to_string()) + .spawn({ + let eject_signal = eject_signal.clone(); + move || { + let inner: AudioThread = + AudioThread::new(device, eject_signal, stream_config, callback) + .inspect_err(|err| { + eprintln!("Failed to create render thread: {err}") + })?; + inner.run() + } + }) + .expect("Cannot spawn audio output thread"); + Self { + join_handle, + eject_signal, + } + } +} + +fn set_thread_priority() { + unsafe { + let thread_id = Threading::GetCurrentThreadId(); + + let _ = Threading::SetThreadPriority( + HANDLE(thread_id as isize as _), + Threading::THREAD_PRIORITY_TIME_CRITICAL, + ); + } +} + +pub fn buffer_size_to_duration(buffer_size: usize, sample_rate: u32) -> i64 { + (buffer_size as i64 / sample_rate as i64) * (1_000_000_000 / 100) +} + +fn stream_instant(audio_clock: &Audio::IAudioClock) -> Result { + let mut position: u64 = 0; + let mut qpc_position: u64 = 0; + unsafe { + audio_clock.GetPosition(&mut position, Some(&mut qpc_position))?; + }; + // The `qpc_position` is in 100 nanosecond units. Convert it to nanoseconds. + let qpc_nanos = qpc_position * 100; + let instant = Duration::from_nanos(qpc_nanos); + Ok(instant) +} + +pub(crate) fn config_to_waveformatextensible(config: &StreamConfig) -> Audio::WAVEFORMATEXTENSIBLE { + let format_tag = KernelStreaming::WAVE_FORMAT_EXTENSIBLE; + let channels = config.channels as u16; + let sample_rate = config.samplerate as u32; + let sample_bytes = size_of::() as u16; + let avg_bytes_per_sec = u32::from(channels) * sample_rate * u32::from(sample_bytes); + let block_align = channels * sample_bytes; + let bits_per_sample = 8 * sample_bytes; + + let cb_size = { + let extensible_size = size_of::(); + let ex_size = size_of::(); + (extensible_size - ex_size) as u16 + }; + + let waveformatex = Audio::WAVEFORMATEX { + wFormatTag: format_tag as u16, + nChannels: channels, + nSamplesPerSec: sample_rate, + nAvgBytesPerSec: avg_bytes_per_sec, + nBlockAlign: block_align, + wBitsPerSample: bits_per_sample, + cbSize: cb_size, + }; + + let channel_mask = KernelStreaming::KSAUDIO_SPEAKER_DIRECTOUT; + + let sub_format = Multimedia::KSDATAFORMAT_SUBTYPE_IEEE_FLOAT; + + let waveformatextensible = Audio::WAVEFORMATEXTENSIBLE { + Format: waveformatex, + Samples: Audio::WAVEFORMATEXTENSIBLE_0 { + wSamplesPerBlock: bits_per_sample, + }, + dwChannelMask: channel_mask, + SubFormat: sub_format, + }; + + waveformatextensible +} + +pub(crate) fn is_output_config_supported( + device: WasapiMMDevice, + stream_config: &StreamConfig, +) -> bool { + let mut try_ = || unsafe { + let audio_client: Audio::IAudioClient = device.activate()?; + let sharemode = if stream_config.exclusive { + Audio::AUDCLNT_SHAREMODE_EXCLUSIVE + } else { + Audio::AUDCLNT_SHAREMODE_SHARED + }; + let mut format = config_to_waveformatextensible(&stream_config); + let mut actual_format = ptr::null_mut(); + audio_client + .IsFormatSupported( + sharemode, + &format.Format, + (!stream_config.exclusive).then_some(&mut actual_format), + ) + .ok()?; + if !stream_config.exclusive { + assert!(!actual_format.is_null()); + format.Format = actual_format.read_unaligned(); + CoTaskMemFree(actual_format.cast()); + let sample_rate = format.Format.nSamplesPerSec; + let new_channels = 0u32.with_indices(0..format.Format.nChannels as _); + let new_samplerate = sample_rate as f64; + if stream_config.samplerate != new_samplerate + || stream_config.channels.count() != new_channels.count() + { + return Ok(false); + } + } + Ok::<_, error::WasapiError>(true) + }; + try_() + .inspect_err(|err| eprintln!("Error while checking configuration is valid: {err}")) + .unwrap_or(false) +} diff --git a/src/backends/wasapi/util.rs b/src/backends/wasapi/util.rs new file mode 100644 index 0000000..d45eaf1 --- /dev/null +++ b/src/backends/wasapi/util.rs @@ -0,0 +1,130 @@ +use crate::prelude::wasapi::error; +use std::marker::PhantomData; +use windows::core::Interface; +use windows::Win32::Foundation::RPC_E_CHANGED_MODE; +use windows::Win32::Media::Audio; +use windows::Win32::System::Com; +use windows::Win32::System::Com::{CoInitializeEx, CoUninitialize, StructuredStorage, COINIT_APARTMENTTHREADED, STGM_READ}; +use windows::Win32::Devices::Properties; +use windows::Win32::System::Variant::VT_LPWSTR; +use std::ffi::OsString; +use std::os::windows::ffi::OsStringExt; + +thread_local!(static COM_INITIALIZER: ComInitializer = { + unsafe { + // Try to initialize COM with STA by default to avoid compatibility issues with the ASIO + // backend (where CoInitialize() is called by the ASIO SDK) or winit (where drag and drop + // requires STA). + // This call can fail with RPC_E_CHANGED_MODE if another library initialized COM with MTA. + // That's OK though since COM ensures thread-safety/compatibility through marshalling when + // necessary. + let result = CoInitializeEx(None, COINIT_APARTMENTTHREADED); + if result.is_ok() || result == RPC_E_CHANGED_MODE { + ComInitializer { + result, + _ptr: PhantomData, + } + } else { + // COM initialization failed in another way, something is really wrong. + panic!( + "Failed to initialize COM: {}", + std::io::Error::from_raw_os_error(result.0) + ); + } + } +}); + +/// RAII object that guards the fact that COM is initialized. +/// +// We store a raw pointer because it's the only way at the moment to remove `Send`/`Sync` from the +// object. +struct ComInitializer { + result: windows::core::HRESULT, + _ptr: PhantomData<*mut ()>, +} + +impl Drop for ComInitializer { + #[inline] + fn drop(&mut self) { + // Need to avoid calling CoUninitialize() if CoInitializeEx failed since it may have + // returned RPC_E_MODE_CHANGED - which is OK, see above. + if self.result.is_ok() { + unsafe { CoUninitialize() }; + } + } +} + +/// Ensures that COM is initialized in this thread. +#[inline] +pub fn com_initializer() { + COM_INITIALIZER.with(|_| {}); +} + +#[derive(Debug, Clone)] +pub struct WasapiMMDevice(Audio::IMMDevice); + +unsafe impl Send for WasapiMMDevice {} + +impl WasapiMMDevice { + pub(crate) fn new(device: Audio::IMMDevice) -> Self { + Self(device) + } + + pub(crate) fn activate(&self) -> Result { + unsafe { + self.0 + .Activate::(Com::CLSCTX_ALL, None) + .map_err(|err| error::WasapiError::BackendError(err)) + } + } + + pub(crate) fn name(&self) -> Option { + get_device_name(&self.0) + } +} + +fn get_device_name(device: &Audio::IMMDevice) -> Option { + unsafe { + // Open the device's property store. + let property_store = device + .OpenPropertyStore(STGM_READ) + .expect("could not open property store"); + + // Get the endpoint's friendly-name property, else the interface's friendly-name, else the device description. + let mut property_value = property_store + .GetValue(&Properties::DEVPKEY_Device_FriendlyName as *const _ as *const _) + .or(property_store.GetValue( + &Properties::DEVPKEY_DeviceInterface_FriendlyName as *const _ as *const _, + )) + .or(property_store + .GetValue(&Properties::DEVPKEY_Device_DeviceDesc as *const _ as *const _)) + .ok()?; + + let prop_variant = &property_value.as_raw().Anonymous.Anonymous; + + // Read the friendly-name from the union data field, expecting a *const u16. + if prop_variant.vt != VT_LPWSTR.0 { + return None; + } + + let ptr_utf16 = *(&prop_variant.Anonymous as *const _ as *const *const u16); + + // Find the length of the friendly name. + let mut len = 0; + while *ptr_utf16.offset(len) != 0 { + len += 1; + } + + // Convert to a string. + let name_slice = std::slice::from_raw_parts(ptr_utf16, len as usize); + let name_os_string: OsString = OsStringExt::from_wide(name_slice); + let name = name_os_string + .into_string() + .unwrap_or_else(|os_string| os_string.to_string_lossy().into()); + + // Clean up. + StructuredStorage::PropVariantClear(&mut property_value).ok()?; + + Some(name) + } +} \ No newline at end of file diff --git a/src/channel_map.rs b/src/channel_map.rs index 94674d7..c360184 100644 --- a/src/channel_map.rs +++ b/src/channel_map.rs @@ -20,8 +20,7 @@ pub trait Bitset: Sized { fn indices(&self) -> impl IntoIterator { (0..self.capacity()).filter_map(|i| self.get_index(i).then_some(i)) } - - /// Count the number of `true` elements in this bit set. +/// Count the number of `true` elements in this bit set. fn count(&self) -> usize { self.indices().into_iter().count() } @@ -31,8 +30,7 @@ pub trait Bitset: Sized { self.set_index(index, value); self } - - /// Builder-like method for setting all provided indices to `. +/// Builder-like method for setting all provided indices to `. fn with_indices(mut self, indices: impl IntoIterator) -> Self { for ix in indices { self.set_index(ix, true); diff --git a/src/lib.rs b/src/lib.rs index b034b61..9693d73 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -65,6 +65,9 @@ pub struct StreamConfig { /// honoring this setting, and in future versions may provide additional buffering to ensure /// it, but for now you should not make assumptions on buffer sizes based on this setting. pub buffer_size_range: (Option, Option), + /// Whether the device should be exclusively held (meaning no other application can open the + /// same device). + pub exclusive: bool, } /// Audio channel description. diff --git a/src/prelude.rs b/src/prelude.rs index bea0c8c..77d64d6 100644 --- a/src/prelude.rs +++ b/src/prelude.rs @@ -1,2 +1,4 @@ pub use crate::backends::*; +#[cfg(os_wasapi)] +pub use crate::backends::wasapi::prelude::*; pub use crate::*;