diff --git a/src/owl/fftpack/owl_fft_d.mli b/src/owl/fftpack/owl_fft_d.mli index 093791d42..ea80f18bb 100644 --- a/src/owl/fftpack/owl_fft_d.mli +++ b/src/owl/fftpack/owl_fft_d.mli @@ -5,24 +5,25 @@ open Bigarray open Owl_dense_ndarray_generic +open Owl_fft_generic val fft : ?axis:int - -> ?norm:int + -> ?norm:tnorm -> ?nthreads:int -> (Complex.t, complex64_elt) Owl_dense_ndarray_generic.t -> (Complex.t, complex64_elt) Owl_dense_ndarray_generic.t val ifft : ?axis:int - -> ?norm:int + -> ?norm:tnorm -> ?nthreads:int -> (Complex.t, complex64_elt) Owl_dense_ndarray_generic.t -> (Complex.t, complex64_elt) Owl_dense_ndarray_generic.t val rfft : ?axis:int - -> ?norm:int + -> ?norm:tnorm -> ?nthreads:int -> (float, float64_elt) t -> (Complex.t, complex64_elt) t @@ -30,7 +31,7 @@ val rfft val irfft : ?axis:int -> ?n:int - -> ?norm:int + -> ?norm:tnorm -> ?nthreads:int -> (Complex.t, complex64_elt) t -> (float, float64_elt) t @@ -41,8 +42,8 @@ val ifft2 : (Complex.t, complex64_elt) t -> (Complex.t, complex64_elt) t val dct : ?axis:int - -> ?ttype:int - -> ?norm:int + -> ?ttype:ttrig_transform + -> ?norm:tnorm -> ?ortho:bool -> ?nthreads:int -> (float, float64_elt) t @@ -50,8 +51,8 @@ val dct val idct : ?axis:int - -> ?ttype:int - -> ?norm:int + -> ?ttype:ttrig_transform + -> ?norm:tnorm -> ?ortho:bool -> ?nthreads:int -> (float, float64_elt) t @@ -59,8 +60,8 @@ val idct val dst : ?axis:int - -> ?ttype:int - -> ?norm:int + -> ?ttype:ttrig_transform + -> ?norm:tnorm -> ?ortho:bool -> ?nthreads:int -> (float, float64_elt) t @@ -68,8 +69,8 @@ val dst val idst : ?axis:int - -> ?ttype:int - -> ?norm:int + -> ?ttype:ttrig_transform + -> ?norm:tnorm -> ?ortho:bool -> ?nthreads:int -> (float, float64_elt) t diff --git a/src/owl/fftpack/owl_fft_generic.ml b/src/owl/fftpack/owl_fft_generic.ml index 4d24c219c..377e7a0c7 100644 --- a/src/owl/fftpack/owl_fft_generic.ml +++ b/src/owl/fftpack/owl_fft_generic.ml @@ -5,7 +5,17 @@ open Owl_dense_ndarray_generic -let fft ?axis ?(norm : int = 0) ?(nthreads : int = 1) x = +type tnorm = + | Backward + | Forward + | Ortho + +let tnorm_to_int = function + | Backward -> 0 + | Forward -> 1 + | Ortho -> 2 + +let fft ?axis ?(norm : tnorm = Backward) ?(nthreads : int = 1) x = let axis = match axis with | Some a -> a @@ -14,11 +24,11 @@ let fft ?axis ?(norm : int = 0) ?(nthreads : int = 1) x = let axis = if axis < 0 then num_dims x + axis else axis in assert (axis < num_dims x); let y = empty (kind x) (shape x) in - Owl_fftpack._owl_cfftf (kind x) x y axis norm nthreads; + Owl_fftpack._owl_cfftf (kind x) x y axis (tnorm_to_int norm) nthreads; y -let ifft ?axis ?(norm : int = 1) ?(nthreads : int = 1) x = +let ifft ?axis ?(norm : tnorm = Forward) ?(nthreads : int = 1) x = let axis = match axis with | Some a -> a @@ -27,11 +37,11 @@ let ifft ?axis ?(norm : int = 1) ?(nthreads : int = 1) x = let axis = if axis < 0 then num_dims x + axis else axis in assert (axis < num_dims x); let y = empty (kind x) (shape x) in - Owl_fftpack._owl_cfftb (kind x) x y axis norm nthreads; + Owl_fftpack._owl_cfftb (kind x) x y axis (tnorm_to_int norm) nthreads; y -let rfft ?axis ?(norm : int = 0) ?(nthreads : int = 1) ~(otyp : ('a, 'b) kind) x = +let rfft ?axis ?(norm : tnorm = Backward) ?(nthreads : int = 1) ~(otyp : ('a, 'b) kind) x = let axis = match axis with | Some a -> a @@ -43,11 +53,11 @@ let rfft ?axis ?(norm : int = 0) ?(nthreads : int = 1) ~(otyp : ('a, 'b) kind) x s.(axis) <- (s.(axis) / 2) + 1; let y = empty otyp s in let ityp = kind x in - Owl_fftpack._owl_rfftf ityp otyp x y axis norm nthreads; + Owl_fftpack._owl_rfftf ityp otyp x y axis (tnorm_to_int norm) nthreads; y -let irfft ?axis ?n ?(norm : int = 1) ?(nthreads : int = 1) ~(otyp : ('a, 'b) kind) x = +let irfft ?axis ?n ?(norm : tnorm = Forward) ?(nthreads : int = 1) ~(otyp : ('a, 'b) kind) x = let axis = match axis with | Some a -> a @@ -63,7 +73,7 @@ let irfft ?axis ?n ?(norm : int = 1) ?(nthreads : int = 1) ~(otyp : ('a, 'b) kin in let y = empty otyp s in let ityp = kind x in - Owl_fftpack._owl_rfftb ityp otyp x y axis norm nthreads; + Owl_fftpack._owl_rfftb ityp otyp x y axis (tnorm_to_int norm) nthreads; y @@ -71,7 +81,19 @@ let fft2 x = fft ~axis:0 x |> fft ~axis:1 let ifft2 x = ifft ~axis:0 x |> ifft ~axis:1 -let dct ?axis ?(ttype = 2) ?(norm : int = 0) ?(ortho : bool option) ?(nthreads = 1) x = +type ttrig_transform = + | I + | II + | III + | IV + +let ttrig_transform_to_int = function + | I -> 1 + | II -> 2 + | III -> 3 + | IV -> 4 + +let dct ?axis ?(ttype: ttrig_transform = II) ?(norm : tnorm = Backward) ?(ortho : bool option) ?(nthreads = 1) x = let axis = match axis with | Some a -> a @@ -82,15 +104,14 @@ let dct ?axis ?(ttype = 2) ?(norm : int = 0) ?(ortho : bool option) ?(nthreads = let ortho = match ortho with | Some o -> o - | None -> if norm = 2 then true else false + | None -> if norm = Ortho then true else false in - assert (ttype > 0 || ttype < 5); let y = empty (kind x) (shape x) in - Owl_fftpack._owl_dctf (kind x) x y axis ttype norm ortho nthreads; + Owl_fftpack._owl_dctf (kind x) x y axis (ttrig_transform_to_int ttype) (tnorm_to_int norm) ortho nthreads; y -let idct ?axis ?(ttype = 3) ?(norm : int = 1) ?(ortho : bool option) ?(nthreads = 1) x = +let idct ?axis ?(ttype: ttrig_transform = II) ?(norm : tnorm = Forward) ?(ortho : bool option) ?(nthreads = 1) x = let axis = match axis with | Some a -> a @@ -101,15 +122,14 @@ let idct ?axis ?(ttype = 3) ?(norm : int = 1) ?(ortho : bool option) ?(nthreads let ortho = match ortho with | Some o -> o - | None -> if norm = 2 then true else false + | None -> if norm = Ortho then true else false in - assert (ttype > 0 || ttype < 5); let y = empty (kind x) (shape x) in - Owl_fftpack._owl_dctb (kind x) x y axis ttype norm ortho nthreads; + Owl_fftpack._owl_dctb (kind x) x y axis (ttrig_transform_to_int ttype) (tnorm_to_int norm) ortho nthreads; y -let dst ?axis ?(ttype = 2) ?(norm : int = 0) ?(ortho : bool option) ?(nthreads = 1) x = +let dst ?axis ?(ttype: ttrig_transform = II) ?(norm : tnorm = Backward) ?(ortho : bool option) ?(nthreads = 1) x = let axis = match axis with | Some a -> a @@ -120,15 +140,14 @@ let dst ?axis ?(ttype = 2) ?(norm : int = 0) ?(ortho : bool option) ?(nthreads = let ortho = match ortho with | Some o -> o - | None -> if norm = 2 then true else false + | None -> if norm = Ortho then true else false in - assert (ttype > 0 || ttype < 5); let y = empty (kind x) (shape x) in - Owl_fftpack._owl_dstf (kind x) x y axis ttype norm ortho nthreads; + Owl_fftpack._owl_dstf (kind x) x y axis (ttrig_transform_to_int ttype) (tnorm_to_int norm) ortho nthreads; y -let idst ?axis ?(ttype = 3) ?(norm : int = 1) ?(ortho : bool option) ?(nthreads = 1) x = +let idst ?axis ?(ttype = III) ?(norm : tnorm = Forward) ?(ortho : bool option) ?(nthreads = 1) x = let axis = match axis with | Some a -> a @@ -139,9 +158,8 @@ let idst ?axis ?(ttype = 3) ?(norm : int = 1) ?(ortho : bool option) ?(nthreads let ortho = match ortho with | Some o -> o - | None -> if norm = 2 then true else false + | None -> if norm = Ortho then true else false in - assert (ttype > 0 || ttype < 5); let y = empty (kind x) (shape x) in - Owl_fftpack._owl_dstb (kind x) x y axis ttype norm ortho nthreads; + Owl_fftpack._owl_dstb (kind x) x y axis (ttrig_transform_to_int ttype) (tnorm_to_int norm) ortho nthreads; y diff --git a/src/owl/fftpack/owl_fft_generic.mli b/src/owl/fftpack/owl_fft_generic.mli index ef0d63032..d66528bc9 100644 --- a/src/owl/fftpack/owl_fft_generic.mli +++ b/src/owl/fftpack/owl_fft_generic.mli @@ -7,48 +7,59 @@ open Owl_dense_ndarray_generic +(** Normalisation options for transforms. *) +type tnorm = + | Backward (** No normalization on Forward and scaling by 1/N on Backward *) + | Forward (** Normalization by 1/N on the Forward transform. *) + | Ortho (** Forward and Backward are scaled by 1/sqrt(N) *) + (** {5 Discrete Fourier Transforms functions} *) val fft : ?axis:int - -> ?norm:int + -> ?norm:tnorm -> ?nthreads:int -> (Complex.t, 'a) t -> (Complex.t, 'a) t -(** [fft ~axis x] performs 1-dimensional FFT on a complex input. [axis] is the - highest dimension if not specified. The return is not scaled. *) +(** [fft ~axis ~norm x] performs 1-dimensional FFT on a complex input. [axis] is the + highest dimension if not specified. [norm] is the normalization option. By default, [norm] is set to [Backward]. + [nthreads] is the desired number of threads used to compute the fft. *) val ifft : ?axis:int - -> ?norm:int + -> ?norm:tnorm -> ?nthreads:int -> (Complex.t, 'a) t -> (Complex.t, 'a) t (** [ifft ~axis x] performs inverse 1-dimensional FFT on a complex input. The parameter [axis] - indicates the highest dimension by default. *) + indicates the highest dimension by default. [norm] is the normalization option. By default, [norm] is set to [Forward]. + [nthreads] is the desired number of threads used to compute the fft. *) val rfft : ?axis:int - -> ?norm:int + -> ?norm:tnorm -> ?nthreads:int -> otyp:('a, 'b) kind -> ('c, 'd) t -> ('a, 'b) t (** [rfft ~axis ~otyp x] performs 1-dimensional FFT on real input along the - [axis]. [otyp] is used to specify the output type, it must be the consistent - precision with input [x]. You can skip this parameter by using a submodule - with specific precision such as [Owl.Fft.S] or [Owl.Fft.D]. *) + [axis]. [norm] is the normalization option. By default, [norm] is set to [Backward]. + [nthreads] is the desired number of threads used to compute the fft. + [otyp] is used to specify the output type, it must be the consistent precision with input [x]. + You can skip this parameter by using a submodule with specific precision such as [Owl.Fft.S] or [Owl.Fft.D]. *) val irfft : ?axis:int -> ?n:int - -> ?norm:int + -> ?norm:tnorm -> ?nthreads:int -> otyp:('a, 'b) kind -> ('c, 'd) t -> ('a, 'b) t -(** [irfft ~axis ~n x] is the inverse function of [rfft]. Note the [n] parameter - is used to specified the size of output. *) +(** [irfft ~axis ~n x] is the inverse function of [rfft]. [norm] is the normalization option. + By default, [norm] is set to [Forward]. + [nthreads] is the desired number of threads used to compute the fft. + Note the [n] parameter is used to specified the size of output. *) val fft2 : (Complex.t, 'a) t -> (Complex.t, 'a) t (** [fft2 x] performs 2-dimensional FFT on a complex input. The return is not scaled. *) @@ -58,42 +69,61 @@ val ifft2 : (Complex.t, 'a) t -> (Complex.t, 'a) t (** {5 Discrete Cosine & Sine Transforms functions} *) +type ttrig_transform = + | I + | II + | III + | IV +(** Trigonometric (Cosine and Sine) transform types. *) + val dct : ?axis:int - -> ?ttype:int - -> ?norm:int + -> ?ttype:ttrig_transform + -> ?norm:tnorm -> ?ortho:bool -> ?nthreads:int -> ('a, 'b) t -> ('a, 'b) t -(** [dct ~axis ~type x] performs 1-dimensional Discrete Cosine Transform (DCT) on a real input. Default type is 2. *) +(** [dct ?axis ?ttype ?norm ?ortho ?nthreads x] performs 1-dimensional Discrete Cosine Transform (DCT) on a real input. + [ttype] is the DCT type to use for this transform. Default value is [II]. + [norm] is the normalization option. By default, [norm] is set to [Backward]. + [ortho] constrols whether or not we should use the orthogonalized variant of the DCT. *) val idct : ?axis:int - -> ?ttype:int - -> ?norm:int + -> ?ttype:ttrig_transform + -> ?norm:tnorm -> ?ortho:bool -> ?nthreads:int -> ('a, 'b) t -> ('a, 'b) t -(** [idct ~axis ~type x] performs inverse 1-dimensional Discrete Cosine Transform (DCT) on a real input. Default type is 2. *) +(** [idct ?axis ?ttype ?norm ?ortho ?nthreads x] performs inverse 1-dimensional Discrete Cosine Transform (DCT) on a real input. + [ttype] is the DCT type to use for this transform. Default value is [II]. + [norm] is the normalization option. By default, [norm] is set to [Forward]. + [ortho] constrols whether or not we should use the orthogonalized variant of the DCT. *) val dst : ?axis:int - -> ?ttype:int - -> ?norm:int + -> ?ttype:ttrig_transform + -> ?norm:tnorm -> ?ortho:bool -> ?nthreads:int -> ('a, 'b) t -> ('a, 'b) t -(** [dst ~axis ~type x] performs 1-dimensional Discrete Sine Transform (DST) on a real input. Default type is 2. *) +(** [dst ?axis ?ttype ?norm ?ortho ?nthreads x] performs 1-dimensional Discrete Sine Transform (DCT) on a real input. + [ttype] is the DCT type to use for this transform. Default value is [II]. + [norm] is the normalization option. By default, [norm] is set to [Backward]. + [ortho] constrols whether or not we should use the orthogonalized variant of the DST. *) val idst : ?axis:int - -> ?ttype:int - -> ?norm:int + -> ?ttype:ttrig_transform + -> ?norm:tnorm -> ?ortho:bool -> ?nthreads:int -> ('a, 'b) t -> ('a, 'b) t -(** [idst ~axis ~type x] performs inverse 1-dimensional Discrete Sine Transform (DST) on a real input. Default type is 2. *) +(** [idst ?axis ?ttype ?norm ?ortho ?nthreads x] performs inverse 1-dimensional Discrete Sine Transform (DST) on a real input. + [ttype] is the DST type to use for this transform. Default value is [II]. + [norm] is the normalization option. By default, [norm] is set to [Forward]. + [ortho] constrols whether or not we should use the orthogonalized variant of the DST. *) \ No newline at end of file diff --git a/src/owl/fftpack/owl_fft_s.mli b/src/owl/fftpack/owl_fft_s.mli index aa4b0a643..3914d685c 100644 --- a/src/owl/fftpack/owl_fft_s.mli +++ b/src/owl/fftpack/owl_fft_s.mli @@ -5,24 +5,25 @@ open Bigarray open Owl_dense_ndarray_generic +open Owl_fft_generic val fft : ?axis:int - -> ?norm:int + -> ?norm:tnorm -> ?nthreads:int -> (Complex.t, complex32_elt) Owl_dense_ndarray_generic.t -> (Complex.t, complex32_elt) Owl_dense_ndarray_generic.t val ifft : ?axis:int - -> ?norm:int + -> ?norm:tnorm -> ?nthreads:int -> (Complex.t, complex32_elt) Owl_dense_ndarray_generic.t -> (Complex.t, complex32_elt) Owl_dense_ndarray_generic.t val rfft : ?axis:int - -> ?norm:int + -> ?norm:tnorm -> ?nthreads:int -> (float, float32_elt) t -> (Complex.t, complex32_elt) t @@ -30,7 +31,7 @@ val rfft val irfft : ?axis:int -> ?n:int - -> ?norm:int + -> ?norm:tnorm -> ?nthreads:int -> (Complex.t, complex32_elt) t -> (float, float32_elt) t @@ -41,8 +42,8 @@ val ifft2 : (Complex.t, complex32_elt) t -> (Complex.t, complex32_elt) t val dct : ?axis:int - -> ?ttype:int - -> ?norm:int + -> ?ttype:ttrig_transform + -> ?norm:tnorm -> ?ortho:bool -> ?nthreads:int -> (float, float32_elt) t @@ -50,8 +51,8 @@ val dct val idct : ?axis:int - -> ?ttype:int - -> ?norm:int + -> ?ttype:ttrig_transform + -> ?norm:tnorm -> ?ortho:bool -> ?nthreads:int -> (float, float32_elt) t @@ -59,8 +60,8 @@ val idct val dst : ?axis:int - -> ?ttype:int - -> ?norm:int + -> ?ttype:ttrig_transform + -> ?norm:tnorm -> ?ortho:bool -> ?nthreads:int -> (float, float32_elt) t @@ -68,8 +69,8 @@ val dst val idst : ?axis:int - -> ?ttype:int - -> ?norm:int + -> ?ttype:ttrig_transform + -> ?norm:tnorm -> ?ortho:bool -> ?nthreads:int -> (float, float32_elt) t