Skip to content

Commit

Permalink
Adding abstract types and documentation.
Browse files Browse the repository at this point in the history
- Added the ttrig_transform type to specify the DCT and DST types
- Added the tnorm type to specify the normalization option for the FFTs.
  • Loading branch information
gabyfle committed Nov 11, 2024
1 parent 98a73c4 commit 6d324e9
Show file tree
Hide file tree
Showing 4 changed files with 122 additions and 72 deletions.
25 changes: 13 additions & 12 deletions src/owl/fftpack/owl_fft_d.mli
Original file line number Diff line number Diff line change
Expand Up @@ -5,32 +5,33 @@

open Bigarray
open Owl_dense_ndarray_generic
open Owl_fft_generic

val fft
: ?axis:int
-> ?norm:int
-> ?norm:tnorm
-> ?nthreads:int
-> (Complex.t, complex64_elt) Owl_dense_ndarray_generic.t
-> (Complex.t, complex64_elt) Owl_dense_ndarray_generic.t

val ifft
: ?axis:int
-> ?norm:int
-> ?norm:tnorm
-> ?nthreads:int
-> (Complex.t, complex64_elt) Owl_dense_ndarray_generic.t
-> (Complex.t, complex64_elt) Owl_dense_ndarray_generic.t

val rfft
: ?axis:int
-> ?norm:int
-> ?norm:tnorm
-> ?nthreads:int
-> (float, float64_elt) t
-> (Complex.t, complex64_elt) t

val irfft
: ?axis:int
-> ?n:int
-> ?norm:int
-> ?norm:tnorm
-> ?nthreads:int
-> (Complex.t, complex64_elt) t
-> (float, float64_elt) t
Expand All @@ -41,35 +42,35 @@ val ifft2 : (Complex.t, complex64_elt) t -> (Complex.t, complex64_elt) t

val dct
: ?axis:int
-> ?ttype:int
-> ?norm:int
-> ?ttype:ttrig_transform
-> ?norm:tnorm
-> ?ortho:bool
-> ?nthreads:int
-> (float, float64_elt) t
-> (float, float64_elt) t

val idct
: ?axis:int
-> ?ttype:int
-> ?norm:int
-> ?ttype:ttrig_transform
-> ?norm:tnorm
-> ?ortho:bool
-> ?nthreads:int
-> (float, float64_elt) t
-> (float, float64_elt) t

val dst
: ?axis:int
-> ?ttype:int
-> ?norm:int
-> ?ttype:ttrig_transform
-> ?norm:tnorm
-> ?ortho:bool
-> ?nthreads:int
-> (float, float64_elt) t
-> (float, float64_elt) t

val idst
: ?axis:int
-> ?ttype:int
-> ?norm:int
-> ?ttype:ttrig_transform
-> ?norm:tnorm
-> ?ortho:bool
-> ?nthreads:int
-> (float, float64_elt) t
Expand Down
66 changes: 42 additions & 24 deletions src/owl/fftpack/owl_fft_generic.ml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,17 @@

open Owl_dense_ndarray_generic

let fft ?axis ?(norm : int = 0) ?(nthreads : int = 1) x =
type tnorm =
| Backward
| Forward
| Ortho

let tnorm_to_int = function
| Backward -> 0
| Forward -> 1
| Ortho -> 2

let fft ?axis ?(norm : tnorm = Backward) ?(nthreads : int = 1) x =
let axis =
match axis with
| Some a -> a
Expand All @@ -14,11 +24,11 @@ let fft ?axis ?(norm : int = 0) ?(nthreads : int = 1) x =
let axis = if axis < 0 then num_dims x + axis else axis in
assert (axis < num_dims x);
let y = empty (kind x) (shape x) in
Owl_fftpack._owl_cfftf (kind x) x y axis norm nthreads;
Owl_fftpack._owl_cfftf (kind x) x y axis (tnorm_to_int norm) nthreads;
y


let ifft ?axis ?(norm : int = 1) ?(nthreads : int = 1) x =
let ifft ?axis ?(norm : tnorm = Forward) ?(nthreads : int = 1) x =
let axis =
match axis with
| Some a -> a
Expand All @@ -27,11 +37,11 @@ let ifft ?axis ?(norm : int = 1) ?(nthreads : int = 1) x =
let axis = if axis < 0 then num_dims x + axis else axis in
assert (axis < num_dims x);
let y = empty (kind x) (shape x) in
Owl_fftpack._owl_cfftb (kind x) x y axis norm nthreads;
Owl_fftpack._owl_cfftb (kind x) x y axis (tnorm_to_int norm) nthreads;
y


let rfft ?axis ?(norm : int = 0) ?(nthreads : int = 1) ~(otyp : ('a, 'b) kind) x =
let rfft ?axis ?(norm : tnorm = Backward) ?(nthreads : int = 1) ~(otyp : ('a, 'b) kind) x =
let axis =
match axis with
| Some a -> a
Expand All @@ -43,11 +53,11 @@ let rfft ?axis ?(norm : int = 0) ?(nthreads : int = 1) ~(otyp : ('a, 'b) kind) x
s.(axis) <- (s.(axis) / 2) + 1;
let y = empty otyp s in
let ityp = kind x in
Owl_fftpack._owl_rfftf ityp otyp x y axis norm nthreads;
Owl_fftpack._owl_rfftf ityp otyp x y axis (tnorm_to_int norm) nthreads;
y


let irfft ?axis ?n ?(norm : int = 1) ?(nthreads : int = 1) ~(otyp : ('a, 'b) kind) x =
let irfft ?axis ?n ?(norm : tnorm = Forward) ?(nthreads : int = 1) ~(otyp : ('a, 'b) kind) x =
let axis =
match axis with
| Some a -> a
Expand All @@ -63,15 +73,27 @@ let irfft ?axis ?n ?(norm : int = 1) ?(nthreads : int = 1) ~(otyp : ('a, 'b) kin
in
let y = empty otyp s in
let ityp = kind x in
Owl_fftpack._owl_rfftb ityp otyp x y axis norm nthreads;
Owl_fftpack._owl_rfftb ityp otyp x y axis (tnorm_to_int norm) nthreads;
y


let fft2 x = fft ~axis:0 x |> fft ~axis:1

let ifft2 x = ifft ~axis:0 x |> ifft ~axis:1

let dct ?axis ?(ttype = 2) ?(norm : int = 0) ?(ortho : bool option) ?(nthreads = 1) x =
type ttrig_transform =
| I
| II
| III
| IV

let ttrig_transform_to_int = function
| I -> 1
| II -> 2
| III -> 3
| IV -> 4

let dct ?axis ?(ttype: ttrig_transform = II) ?(norm : tnorm = Backward) ?(ortho : bool option) ?(nthreads = 1) x =
let axis =
match axis with
| Some a -> a
Expand All @@ -82,15 +104,14 @@ let dct ?axis ?(ttype = 2) ?(norm : int = 0) ?(ortho : bool option) ?(nthreads =
let ortho =
match ortho with
| Some o -> o
| None -> if norm = 2 then true else false
| None -> if norm = Ortho then true else false
in
assert (ttype > 0 || ttype < 5);
let y = empty (kind x) (shape x) in
Owl_fftpack._owl_dctf (kind x) x y axis ttype norm ortho nthreads;
Owl_fftpack._owl_dctf (kind x) x y axis (ttrig_transform_to_int ttype) (tnorm_to_int norm) ortho nthreads;
y


let idct ?axis ?(ttype = 3) ?(norm : int = 1) ?(ortho : bool option) ?(nthreads = 1) x =
let idct ?axis ?(ttype: ttrig_transform = II) ?(norm : tnorm = Forward) ?(ortho : bool option) ?(nthreads = 1) x =
let axis =
match axis with
| Some a -> a
Expand All @@ -101,15 +122,14 @@ let idct ?axis ?(ttype = 3) ?(norm : int = 1) ?(ortho : bool option) ?(nthreads
let ortho =
match ortho with
| Some o -> o
| None -> if norm = 2 then true else false
| None -> if norm = Ortho then true else false
in
assert (ttype > 0 || ttype < 5);
let y = empty (kind x) (shape x) in
Owl_fftpack._owl_dctb (kind x) x y axis ttype norm ortho nthreads;
Owl_fftpack._owl_dctb (kind x) x y axis (ttrig_transform_to_int ttype) (tnorm_to_int norm) ortho nthreads;
y


let dst ?axis ?(ttype = 2) ?(norm : int = 0) ?(ortho : bool option) ?(nthreads = 1) x =
let dst ?axis ?(ttype: ttrig_transform = II) ?(norm : tnorm = Backward) ?(ortho : bool option) ?(nthreads = 1) x =
let axis =
match axis with
| Some a -> a
Expand All @@ -120,15 +140,14 @@ let dst ?axis ?(ttype = 2) ?(norm : int = 0) ?(ortho : bool option) ?(nthreads =
let ortho =
match ortho with
| Some o -> o
| None -> if norm = 2 then true else false
| None -> if norm = Ortho then true else false
in
assert (ttype > 0 || ttype < 5);
let y = empty (kind x) (shape x) in
Owl_fftpack._owl_dstf (kind x) x y axis ttype norm ortho nthreads;
Owl_fftpack._owl_dstf (kind x) x y axis (ttrig_transform_to_int ttype) (tnorm_to_int norm) ortho nthreads;
y


let idst ?axis ?(ttype = 3) ?(norm : int = 1) ?(ortho : bool option) ?(nthreads = 1) x =
let idst ?axis ?(ttype = III) ?(norm : tnorm = Forward) ?(ortho : bool option) ?(nthreads = 1) x =
let axis =
match axis with
| Some a -> a
Expand All @@ -139,9 +158,8 @@ let idst ?axis ?(ttype = 3) ?(norm : int = 1) ?(ortho : bool option) ?(nthreads
let ortho =
match ortho with
| Some o -> o
| None -> if norm = 2 then true else false
| None -> if norm = Ortho then true else false
in
assert (ttype > 0 || ttype < 5);
let y = empty (kind x) (shape x) in
Owl_fftpack._owl_dstb (kind x) x y axis ttype norm ortho nthreads;
Owl_fftpack._owl_dstb (kind x) x y axis (ttrig_transform_to_int ttype) (tnorm_to_int norm) ortho nthreads;
y
78 changes: 54 additions & 24 deletions src/owl/fftpack/owl_fft_generic.mli
Original file line number Diff line number Diff line change
Expand Up @@ -7,48 +7,59 @@

open Owl_dense_ndarray_generic

(** Normalisation options for transforms. *)
type tnorm =
| Backward (** No normalization on Forward and scaling by 1/N on Backward *)
| Forward (** Normalization by 1/N on the Forward transform. *)
| Ortho (** Forward and Backward are scaled by 1/sqrt(N) *)

(** {5 Discrete Fourier Transforms functions} *)

val fft
: ?axis:int
-> ?norm:int
-> ?norm:tnorm
-> ?nthreads:int
-> (Complex.t, 'a) t
-> (Complex.t, 'a) t
(** [fft ~axis x] performs 1-dimensional FFT on a complex input. [axis] is the
highest dimension if not specified. The return is not scaled. *)
(** [fft ~axis ~norm x] performs 1-dimensional FFT on a complex input. [axis] is the
highest dimension if not specified. [norm] is the normalization option. By default, [norm] is set to [Backward].
[nthreads] is the desired number of threads used to compute the fft. *)

val ifft
: ?axis:int
-> ?norm:int
-> ?norm:tnorm
-> ?nthreads:int
-> (Complex.t, 'a) t
-> (Complex.t, 'a) t
(** [ifft ~axis x] performs inverse 1-dimensional FFT on a complex input. The parameter [axis]
indicates the highest dimension by default. *)
indicates the highest dimension by default. [norm] is the normalization option. By default, [norm] is set to [Forward].
[nthreads] is the desired number of threads used to compute the fft. *)

val rfft
: ?axis:int
-> ?norm:int
-> ?norm:tnorm
-> ?nthreads:int
-> otyp:('a, 'b) kind
-> ('c, 'd) t
-> ('a, 'b) t
(** [rfft ~axis ~otyp x] performs 1-dimensional FFT on real input along the
[axis]. [otyp] is used to specify the output type, it must be the consistent
precision with input [x]. You can skip this parameter by using a submodule
with specific precision such as [Owl.Fft.S] or [Owl.Fft.D]. *)
[axis]. [norm] is the normalization option. By default, [norm] is set to [Backward].
[nthreads] is the desired number of threads used to compute the fft.
[otyp] is used to specify the output type, it must be the consistent precision with input [x].
You can skip this parameter by using a submodule with specific precision such as [Owl.Fft.S] or [Owl.Fft.D]. *)

val irfft
: ?axis:int
-> ?n:int
-> ?norm:int
-> ?norm:tnorm
-> ?nthreads:int
-> otyp:('a, 'b) kind
-> ('c, 'd) t
-> ('a, 'b) t
(** [irfft ~axis ~n x] is the inverse function of [rfft]. Note the [n] parameter
is used to specified the size of output. *)
(** [irfft ~axis ~n x] is the inverse function of [rfft]. [norm] is the normalization option.
By default, [norm] is set to [Forward].
[nthreads] is the desired number of threads used to compute the fft.
Note the [n] parameter is used to specified the size of output. *)

val fft2 : (Complex.t, 'a) t -> (Complex.t, 'a) t
(** [fft2 x] performs 2-dimensional FFT on a complex input. The return is not scaled. *)
Expand All @@ -58,42 +69,61 @@ val ifft2 : (Complex.t, 'a) t -> (Complex.t, 'a) t

(** {5 Discrete Cosine & Sine Transforms functions} *)

type ttrig_transform =
| I
| II
| III
| IV
(** Trigonometric (Cosine and Sine) transform types. *)

val dct
: ?axis:int
-> ?ttype:int
-> ?norm:int
-> ?ttype:ttrig_transform
-> ?norm:tnorm
-> ?ortho:bool
-> ?nthreads:int
-> ('a, 'b) t
-> ('a, 'b) t
(** [dct ~axis ~type x] performs 1-dimensional Discrete Cosine Transform (DCT) on a real input. Default type is 2. *)
(** [dct ?axis ?ttype ?norm ?ortho ?nthreads x] performs 1-dimensional Discrete Cosine Transform (DCT) on a real input.
[ttype] is the DCT type to use for this transform. Default value is [II].
[norm] is the normalization option. By default, [norm] is set to [Backward].
[ortho] constrols whether or not we should use the orthogonalized variant of the DCT. *)

val idct
: ?axis:int
-> ?ttype:int
-> ?norm:int
-> ?ttype:ttrig_transform
-> ?norm:tnorm
-> ?ortho:bool
-> ?nthreads:int
-> ('a, 'b) t
-> ('a, 'b) t
(** [idct ~axis ~type x] performs inverse 1-dimensional Discrete Cosine Transform (DCT) on a real input. Default type is 2. *)
(** [idct ?axis ?ttype ?norm ?ortho ?nthreads x] performs inverse 1-dimensional Discrete Cosine Transform (DCT) on a real input.
[ttype] is the DCT type to use for this transform. Default value is [II].
[norm] is the normalization option. By default, [norm] is set to [Forward].
[ortho] constrols whether or not we should use the orthogonalized variant of the DCT. *)

val dst
: ?axis:int
-> ?ttype:int
-> ?norm:int
-> ?ttype:ttrig_transform
-> ?norm:tnorm
-> ?ortho:bool
-> ?nthreads:int
-> ('a, 'b) t
-> ('a, 'b) t
(** [dst ~axis ~type x] performs 1-dimensional Discrete Sine Transform (DST) on a real input. Default type is 2. *)
(** [dst ?axis ?ttype ?norm ?ortho ?nthreads x] performs 1-dimensional Discrete Sine Transform (DCT) on a real input.
[ttype] is the DCT type to use for this transform. Default value is [II].
[norm] is the normalization option. By default, [norm] is set to [Backward].
[ortho] constrols whether or not we should use the orthogonalized variant of the DST. *)

val idst
: ?axis:int
-> ?ttype:int
-> ?norm:int
-> ?ttype:ttrig_transform
-> ?norm:tnorm
-> ?ortho:bool
-> ?nthreads:int
-> ('a, 'b) t
-> ('a, 'b) t
(** [idst ~axis ~type x] performs inverse 1-dimensional Discrete Sine Transform (DST) on a real input. Default type is 2. *)
(** [idst ?axis ?ttype ?norm ?ortho ?nthreads x] performs inverse 1-dimensional Discrete Sine Transform (DST) on a real input.
[ttype] is the DST type to use for this transform. Default value is [II].
[norm] is the normalization option. By default, [norm] is set to [Forward].
[ortho] constrols whether or not we should use the orthogonalized variant of the DST. *)
Loading

0 comments on commit 6d324e9

Please sign in to comment.