diff --git a/dfdx-core/src/tensor/webgpu/mod.rs b/dfdx-core/src/tensor/webgpu/mod.rs index 666ce53e..b22a5619 100644 --- a/dfdx-core/src/tensor/webgpu/mod.rs +++ b/dfdx-core/src/tensor/webgpu/mod.rs @@ -1,8 +1,10 @@ mod allocate; mod device; +mod types; pub use device::Buffer; pub use device::Webgpu; +pub use types::*; #[cfg(test)] mod tests { diff --git a/dfdx-core/src/tensor/webgpu/types.rs b/dfdx-core/src/tensor/webgpu/types.rs new file mode 100644 index 00000000..6f4d147b --- /dev/null +++ b/dfdx-core/src/tensor/webgpu/types.rs @@ -0,0 +1,56 @@ +use crate::shapes::Unit; + +/// A primitive data type natively supported by WebGPU. +/// +/// See: https://www.w3.org/TR/WGSL/#types +/// +/// todo: support packed types +pub trait WebgpuNativeType : Unit { + /// Name of the data type in WGSL. + const NAME: &'static str; +} + +macro_rules! webgpu_type { + ($RustTy:ty) => { + impl WebgpuNativeType for $RustTy { + const NAME: &'static str = stringify!($RustTy); + } + }; + ($RustTy:ty, $WgpuTy:expr) => { + impl WebgpuNativeType for $RustTy { + const NAME: &'static str = $WgpuTy; + } + }; +} + +/* +see: +- https://docs.rs/wgpu/latest/wgpu/struct.Features.html#associatedconstant.SHADER_F16 +- https://docs.rs/wgpu/latest/wgpu/struct.Features.html#associatedconstant.SHADER_F64 +- https://docs.rs/wgpu/latest/wgpu/struct.Features.html#associatedconstant.SHADER_I16 + */ +#[cfg(feature = "f16")] +webgpu_type!(half::f16, "f16"); +webgpu_type!(f32); +// todo: only enable when f64 feature is enabled +#[cfg(feature = "f64")] +webgpu_type!(f64); + +#[cfg(feature = "i16")] +webgpu_type!(i16); +webgpu_type!(i32); + +webgpu_type!(u32); +webgpu_type!(bool); + +pub(crate) trait HasGlslType { + const TYPE: &'static str; +} + +impl HasGlslType for f32 { + const TYPE: &'static str = "float"; +} + +impl HasGlslType for f64 { + const TYPE: &'static str = "double"; +} diff --git a/dfdx-core/src/tensor_ops/utilities/webgpu_kernels.rs b/dfdx-core/src/tensor_ops/utilities/webgpu_kernels.rs index bd976711..4dd0e1bb 100644 --- a/dfdx-core/src/tensor_ops/utilities/webgpu_kernels.rs +++ b/dfdx-core/src/tensor_ops/utilities/webgpu_kernels.rs @@ -6,6 +6,11 @@ use crate::{ use core::any::TypeId; use std::{borrow::Cow, marker::PhantomData, sync::Arc, vec::Vec}; +use wgpu::{ + BindingType, BufferBindingType, ComputePipelineDescriptor, Device, PipelineLayout, ShaderStages, +}; + + /// Creates a [`BindGroup`] for a pipeline from a set of [`wgpu::BindingResource`]s. macro_rules! webgpu_params { ($self:expr, $pipeline:expr; $($x:expr),+ $(,)? ) => { @@ -72,6 +77,7 @@ macro_rules! webgpu_unary { } }; } +pub(crate) use webgpu_unary; /// Zero-sized marker type for forward pass TypeId #[derive(Debug, Default)] @@ -85,22 +91,6 @@ pub(crate) struct Backward { _phantom: PhantomData<(E, K)>, } -pub(crate) trait HasGlslType { - const TYPE: &'static str; -} - -impl HasGlslType for f32 { - const TYPE: &'static str = "float"; -} - -impl HasGlslType for f64 { - const TYPE: &'static str = "double"; -} - -pub(crate) use webgpu_unary; -use wgpu::{ - BindingType, BufferBindingType, ComputePipelineDescriptor, Device, PipelineLayout, ShaderStages, -}; impl + 'static> UnaryKernel for Webgpu { const BACKWARD_WITHOUT_INP: bool = K::DF_USES_FX;