From 98a73c48881ba2fda8988968ed33a3f0a24ac591 Mon Sep 17 00:00:00 2001 From: Gabriel Santamaria Date: Mon, 11 Nov 2024 01:48:11 +0100 Subject: [PATCH] FFT module revamp. - Added several new functionnalities to the FFT module by changing the dependency from FFTPACK to POCKETFTT. - new optionnal parameters for the API (nthreads, norm, ...) - new functions for cosine and sine transforms (dct, dst, ...) - Switched from dune 2.0 to dune 3.16 (this was required as I ran throught issues with linking while using 2.0) --- .gitmodules | 3 + dune-project | 2 +- examples/dune | 17 +- src/base/core/owl_graph.ml | 2 +- src/owl/dune | 17 +- src/owl/fftpack/fftpack.h | 34 - src/owl/fftpack/fftpack_impl.h | 1450 ------------------------ src/owl/fftpack/owl_fft_d.mli | 65 +- src/owl/fftpack/owl_fft_generic.ml | 109 +- src/owl/fftpack/owl_fft_generic.mli | 112 +- src/owl/fftpack/owl_fft_s.mli | 65 +- src/owl/fftpack/owl_fftpack.ml | 218 +++- src/owl/fftpack/owl_fftpack_float32.c | 41 - src/owl/fftpack/owl_fftpack_float32.cc | 39 + src/owl/fftpack/owl_fftpack_float64.c | 41 - src/owl/fftpack/owl_fftpack_float64.cc | 38 + src/owl/fftpack/owl_fftpack_impl.h | 568 ++++++---- src/owl/fftpack/pocketfft | 1 + src/owl/nlp/owl_nlp_corpus.ml | 2 +- src/owl/nlp/owl_nlp_lda.ml | 2 +- src/owl/nlp/owl_nlp_tfidf.ml | 2 +- src/owl/nlp/owl_nlp_vocabulary.ml | 2 +- 22 files changed, 934 insertions(+), 1896 deletions(-) create mode 100644 .gitmodules delete mode 100644 src/owl/fftpack/fftpack.h delete mode 100644 src/owl/fftpack/fftpack_impl.h delete mode 100644 src/owl/fftpack/owl_fftpack_float32.c create mode 100644 src/owl/fftpack/owl_fftpack_float32.cc delete mode 100644 src/owl/fftpack/owl_fftpack_float64.c create mode 100644 src/owl/fftpack/owl_fftpack_float64.cc create mode 160000 src/owl/fftpack/pocketfft diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 000000000..775cc6982 --- /dev/null +++ b/.gitmodules @@ -0,0 +1,3 @@ +[submodule "src/owl/fftpack/pocketfft"] + path = src/owl/fftpack/pocketfft + url = https://github.com/mreineck/pocketfft diff --git a/dune-project b/dune-project index 04a070d3b..fe4c8519b 100644 --- a/dune-project +++ b/dune-project @@ -1,3 +1,3 @@ -(lang dune 2.0) +(lang dune 3.16) (name owl) diff --git a/examples/dune b/examples/dune index 6d5a16f8e..14e0e23be 100644 --- a/examples/dune +++ b/examples/dune @@ -1,5 +1,5 @@ (executables - (names + (names backprop checkpoint cifar10_vgg @@ -25,6 +25,15 @@ squeezenet test_log tfidf - vgg16 - ) - (libraries owl)) \ No newline at end of file + vgg16) + (libraries owl) + (flags ; in order to make the examples compile correctly even with the warnings. + (:standard + -warn-error + -unused-value-declaration + -warn-error + -unused-var-strict + -warn-error + -unused-var + -warn-error + -unused-field))) diff --git a/src/base/core/owl_graph.ml b/src/base/core/owl_graph.ml index 2c9d544c3..f344b1276 100644 --- a/src/base/core/owl_graph.ml +++ b/src/base/core/owl_graph.ml @@ -13,7 +13,7 @@ type 'a node = mutable next : 'a node array ; (* children of the node *) mutable attr : 'a (* indicate the validity *) - } + } [@@warning "-69"] type order = | BFS diff --git a/src/owl/dune b/src/owl/dune index 0dcc3c817..0bd4c0914 100644 --- a/src/owl/dune +++ b/src/owl/dune @@ -32,6 +32,8 @@ (copy_files# fftpack/*) +(copy_files# fftpack/pocketfft/*.h) + (copy_files# misc/*) (copy_files# nlp/*) @@ -42,6 +44,15 @@ (name owl) (public_name owl) (wrapped false) + (foreign_stubs + (language cxx) + (names + ;; FFTPACK + owl_fftpack_float32 + owl_fftpack_float64) + (flags + :standard + (:include c_flags.sexp))) (foreign_stubs (language c) (names @@ -65,9 +76,6 @@ owl_ndarray_utils_stub owl_slicing_basic_stub owl_slicing_fancy_stub - ;; FFTPACK - owl_fftpack_float32 - owl_fftpack_float64 ;; stats SFMT owl_stats_dist_beta @@ -202,7 +210,8 @@ (:include c_flags.sexp))) (c_library_flags :standard - (:include c_library_flags.sexp)) + (:include c_library_flags.sexp) + -lstdc++) (flags :standard (:include ocaml_flags.sexp)) diff --git a/src/owl/fftpack/fftpack.h b/src/owl/fftpack/fftpack.h deleted file mode 100644 index eba57ad0c..000000000 --- a/src/owl/fftpack/fftpack.h +++ /dev/null @@ -1,34 +0,0 @@ -/* - * OWL - OCaml Scientific Computing - * Copyright (c) 2016-2022 Liang Wang - */ - -/* Refer the doc on http://www.netlib.org/fftpack/doc */ - -#ifdef __cplusplus -extern "C" { -#endif - -// Single precision FFT - -extern void float32_fftpack_cffti(int N, const float wsave[]); -extern void float32_fftpack_cfftf(int N, float c[], const float wsave[]); -extern void float32_fftpack_cfftb(int N, float c[], const float wsave[]); - -extern void float32_fftpack_rffti(int N, const float wsave[]); -extern void float32_fftpack_rfftf(int N, float r[], const float wsave[]); -extern void float32_fftpack_rfftb(int N, float r[], const float wsave[]); - -// Double precision FFT - -extern void float64_fftpack_cffti(int N, const double wsave[]); -extern void float64_fftpack_cfftf(int N, double c[], const double wsave[]); -extern void float64_fftpack_cfftb(int N, double c[], const double wsave[]); - -extern void float64_fftpack_rffti(int N, const double wsave[]); -extern void float64_fftpack_rfftf(int N, double r[], const double wsave[]); -extern void float64_fftpack_rfftb(int N, double r[], const double wsave[]); - -#ifdef __cplusplus -} -#endif diff --git a/src/owl/fftpack/fftpack_impl.h b/src/owl/fftpack/fftpack_impl.h deleted file mode 100644 index 9e90375c7..000000000 --- a/src/owl/fftpack/fftpack_impl.h +++ /dev/null @@ -1,1450 +0,0 @@ -/* - * fftpack.c : A set of FFT routines in C. - * Algorithmically based on Fortran-77 FFTPACK by Paul N. Swarztrauber (Version 4, 1985). - * - * Further adapted into Owl from Numpy library. -*/ - -#include -#include -#include - -#define ref(u,a) u[a] - -#define MAXFAC 13 /* maximum number of factors in factorization of n */ -#define NSPECIAL 4 /* number of factors for which we have special-case routines */ - -#ifdef __cplusplus -extern "C" { -#endif - -#ifdef Treal - - -/* ---------------------------------------------------------------------- - passf2, passf3, passf4, passf5, passf. Complex FFT passes fwd and bwd. ------------------------------------------------------------------------ */ - -static void passf2(int ido, int l1, const Treal cc[], Treal ch[], const Treal wa1[], int isign) - /* isign==+1 for backward transform */ - { - int i, k, ah, ac; - Treal ti2, tr2; - if (ido <= 2) { - for (k=0; k= l1) { - for (j=1; j idp) idlj -= idp; - war = wa[idlj - 2]; - wai = wa[idlj-1]; - for (ik=0; ik= l1) { - for (j=1; j= l1) { - for (k=0; k= l1) { - for (j=1; j= l1) { - for (k=0; k= l1) { - for (j=1; j= l1) { - for (j=1; j 5) { - wa[i1-1] = wa[i-1]; - wa[i1] = wa[i]; - } - } - l1 = l2; - } - } /* cffti1 */ - - - /* ------------------------------------------------------------------- -rfftf1, rfftb1, owl_fftpack_rfftf, owl_fftpack_rfftb, rffti1, owl_fftpack_rffti. Treal FFTs. ----------------------------------------------------------------------- */ - -static void rfftf1(int n, Treal c[], Treal ch[], const Treal wa[], const int ifac[MAXFAC+2]) - { - int i; - int k1, l1, l2, na, kh, nf, ip, iw, ix2, ix3, ix4, ido, idl1; - Treal *cinput, *coutput; - nf = ifac[1]; - na = 1; - l2 = n; - iw = n-1; - for (k1 = 1; k1 <= nf; ++k1) { - kh = nf - k1; - ip = ifac[kh + 2]; - l1 = l2 / ip; - ido = n / l2; - idl1 = ido*l1; - iw -= (ip - 1)*ido; - na = !na; - if (na) { - cinput = ch; - coutput = c; - } else { - cinput = c; - coutput = ch; - } - switch (ip) { - case 4: - ix2 = iw + ido; - ix3 = ix2 + ido; - radf4(ido, l1, cinput, coutput, &wa[iw], &wa[ix2], &wa[ix3]); - break; - case 2: - radf2(ido, l1, cinput, coutput, &wa[iw]); - break; - case 3: - ix2 = iw + ido; - radf3(ido, l1, cinput, coutput, &wa[iw], &wa[ix2]); - break; - case 5: - ix2 = iw + ido; - ix3 = ix2 + ido; - ix4 = ix3 + ido; - radf5(ido, l1, cinput, coutput, &wa[iw], &wa[ix2], &wa[ix3], &wa[ix4]); - break; - default: - if (ido == 1) - na = !na; - if (na == 0) { - radfg(ido, ip, l1, idl1, c, ch, &wa[iw]); - na = 1; - } else { - radfg(ido, ip, l1, idl1, ch, c, &wa[iw]); - na = 0; - } - } - l2 = l1; - } - if (na == 1) return; - for (i = 0; i < n; i++) c[i] = ch[i]; - } /* rfftf1 */ - - -static void rfftb1(int n, Treal c[], Treal ch[], const Treal wa[], const int ifac[MAXFAC+2]) - { - int i; - int k1, l1, l2, na, nf, ip, iw, ix2, ix3, ix4, ido, idl1; - Treal *cinput, *coutput; - nf = ifac[1]; - na = 0; - l1 = 1; - iw = 0; - for (k1=1; k1<=nf; k1++) { - ip = ifac[k1 + 1]; - l2 = ip*l1; - ido = n / l2; - idl1 = ido*l1; - if (na) { - cinput = ch; - coutput = c; - } else { - cinput = c; - coutput = ch; - } - switch (ip) { - case 4: - ix2 = iw + ido; - ix3 = ix2 + ido; - radb4(ido, l1, cinput, coutput, &wa[iw], &wa[ix2], &wa[ix3]); - na = !na; - break; - case 2: - radb2(ido, l1, cinput, coutput, &wa[iw]); - na = !na; - break; - case 3: - ix2 = iw + ido; - radb3(ido, l1, cinput, coutput, &wa[iw], &wa[ix2]); - na = !na; - break; - case 5: - ix2 = iw + ido; - ix3 = ix2 + ido; - ix4 = ix3 + ido; - radb5(ido, l1, cinput, coutput, &wa[iw], &wa[ix2], &wa[ix3], &wa[ix4]); - na = !na; - break; - default: - radbg(ido, ip, l1, idl1, cinput, coutput, &wa[iw]); - if (ido == 1) na = !na; - } - l1 = l2; - iw += (ip - 1)*ido; - } - if (na == 0) return; - for (i=0; i (Complex.t, complex64_elt) t -> (Complex.t, complex64_elt) t +val fft + : ?axis:int + -> ?norm:int + -> ?nthreads:int + -> (Complex.t, complex64_elt) Owl_dense_ndarray_generic.t + -> (Complex.t, complex64_elt) Owl_dense_ndarray_generic.t -val ifft : ?axis:int -> (Complex.t, complex64_elt) t -> (Complex.t, complex64_elt) t +val ifft + : ?axis:int + -> ?norm:int + -> ?nthreads:int + -> (Complex.t, complex64_elt) Owl_dense_ndarray_generic.t + -> (Complex.t, complex64_elt) Owl_dense_ndarray_generic.t -val rfft : ?axis:int -> (float, float64_elt) t -> (Complex.t, complex64_elt) t +val rfft + : ?axis:int + -> ?norm:int + -> ?nthreads:int + -> (float, float64_elt) t + -> (Complex.t, complex64_elt) t -val irfft : ?axis:int -> ?n:int -> (Complex.t, complex64_elt) t -> (float, float64_elt) t +val irfft + : ?axis:int + -> ?n:int + -> ?norm:int + -> ?nthreads:int + -> (Complex.t, complex64_elt) t + -> (float, float64_elt) t val fft2 : (Complex.t, complex64_elt) t -> (Complex.t, complex64_elt) t val ifft2 : (Complex.t, complex64_elt) t -> (Complex.t, complex64_elt) t + +val dct + : ?axis:int + -> ?ttype:int + -> ?norm:int + -> ?ortho:bool + -> ?nthreads:int + -> (float, float64_elt) t + -> (float, float64_elt) t + +val idct + : ?axis:int + -> ?ttype:int + -> ?norm:int + -> ?ortho:bool + -> ?nthreads:int + -> (float, float64_elt) t + -> (float, float64_elt) t + +val dst + : ?axis:int + -> ?ttype:int + -> ?norm:int + -> ?ortho:bool + -> ?nthreads:int + -> (float, float64_elt) t + -> (float, float64_elt) t + +val idst + : ?axis:int + -> ?ttype:int + -> ?norm:int + -> ?ortho:bool + -> ?nthreads:int + -> (float, float64_elt) t + -> (float, float64_elt) t diff --git a/src/owl/fftpack/owl_fft_generic.ml b/src/owl/fftpack/owl_fft_generic.ml index f785e7c46..4d24c219c 100644 --- a/src/owl/fftpack/owl_fft_generic.ml +++ b/src/owl/fftpack/owl_fft_generic.ml @@ -5,68 +5,143 @@ open Owl_dense_ndarray_generic -let fft ?axis x = +let fft ?axis ?(norm : int = 0) ?(nthreads : int = 1) x = let axis = match axis with | Some a -> a - | None -> num_dims x - 1 + | None -> num_dims x - 1 in + let axis = if axis < 0 then num_dims x + axis else axis in assert (axis < num_dims x); let y = empty (kind x) (shape x) in - Owl_fftpack._owl_cfftf (kind x) x y axis; + Owl_fftpack._owl_cfftf (kind x) x y axis norm nthreads; y -let ifft ?axis x = +let ifft ?axis ?(norm : int = 1) ?(nthreads : int = 1) x = let axis = match axis with | Some a -> a - | None -> num_dims x - 1 + | None -> num_dims x - 1 in + let axis = if axis < 0 then num_dims x + axis else axis in assert (axis < num_dims x); let y = empty (kind x) (shape x) in - Owl_fftpack._owl_cfftb (kind x) x y axis; - let norm = Complex.{ re = float_of_int (shape y).(axis); im = 0. } in - div_scalar_ y norm; + Owl_fftpack._owl_cfftb (kind x) x y axis norm nthreads; y -let rfft ?axis ~otyp x = +let rfft ?axis ?(norm : int = 0) ?(nthreads : int = 1) ~(otyp : ('a, 'b) kind) x = let axis = match axis with | Some a -> a - | None -> num_dims x - 1 + | None -> num_dims x - 1 in + let axis = if axis < 0 then num_dims x + axis else axis in assert (axis < num_dims x); let s = shape x in s.(axis) <- (s.(axis) / 2) + 1; let y = empty otyp s in let ityp = kind x in - Owl_fftpack._owl_rfftf ityp otyp x y axis; + Owl_fftpack._owl_rfftf ityp otyp x y axis norm nthreads; y -let irfft ?axis ?n ~otyp x = +let irfft ?axis ?n ?(norm : int = 1) ?(nthreads : int = 1) ~(otyp : ('a, 'b) kind) x = let axis = match axis with | Some a -> a - | None -> num_dims x - 1 + | None -> num_dims x - 1 in + let axis = if axis < 0 then num_dims x + axis else axis in assert (axis < num_dims x); let s = shape x in let _ = match n with | Some n -> s.(axis) <- n - | None -> s.(axis) <- (s.(axis) - 1) * 2 + | None -> s.(axis) <- (s.(axis) - 1) * 2 in let y = empty otyp s in let ityp = kind x in - Owl_fftpack._owl_rfftb ityp otyp x y axis; - let norm = float_of_int s.(axis) in - div_scalar_ y norm; + Owl_fftpack._owl_rfftb ityp otyp x y axis norm nthreads; y let fft2 x = fft ~axis:0 x |> fft ~axis:1 let ifft2 x = ifft ~axis:0 x |> ifft ~axis:1 + +let dct ?axis ?(ttype = 2) ?(norm : int = 0) ?(ortho : bool option) ?(nthreads = 1) x = + let axis = + match axis with + | Some a -> a + | None -> num_dims x - 1 + in + let axis = if axis < 0 then num_dims x + axis else axis in + assert (axis < num_dims x); + let ortho = + match ortho with + | Some o -> o + | None -> if norm = 2 then true else false + in + assert (ttype > 0 || ttype < 5); + let y = empty (kind x) (shape x) in + Owl_fftpack._owl_dctf (kind x) x y axis ttype norm ortho nthreads; + y + + +let idct ?axis ?(ttype = 3) ?(norm : int = 1) ?(ortho : bool option) ?(nthreads = 1) x = + let axis = + match axis with + | Some a -> a + | None -> num_dims x - 1 + in + let axis = if axis < 0 then num_dims x + axis else axis in + assert (axis < num_dims x); + let ortho = + match ortho with + | Some o -> o + | None -> if norm = 2 then true else false + in + assert (ttype > 0 || ttype < 5); + let y = empty (kind x) (shape x) in + Owl_fftpack._owl_dctb (kind x) x y axis ttype norm ortho nthreads; + y + + +let dst ?axis ?(ttype = 2) ?(norm : int = 0) ?(ortho : bool option) ?(nthreads = 1) x = + let axis = + match axis with + | Some a -> a + | None -> num_dims x - 1 + in + let axis = if axis < 0 then num_dims x + axis else axis in + assert (axis < num_dims x); + let ortho = + match ortho with + | Some o -> o + | None -> if norm = 2 then true else false + in + assert (ttype > 0 || ttype < 5); + let y = empty (kind x) (shape x) in + Owl_fftpack._owl_dstf (kind x) x y axis ttype norm ortho nthreads; + y + + +let idst ?axis ?(ttype = 3) ?(norm : int = 1) ?(ortho : bool option) ?(nthreads = 1) x = + let axis = + match axis with + | Some a -> a + | None -> num_dims x - 1 + in + let axis = if axis < 0 then num_dims x + axis else axis in + assert (axis < num_dims x); + let ortho = + match ortho with + | Some o -> o + | None -> if norm = 2 then true else false + in + assert (ttype > 0 || ttype < 5); + let y = empty (kind x) (shape x) in + Owl_fftpack._owl_dstb (kind x) x y axis ttype norm ortho nthreads; + y diff --git a/src/owl/fftpack/owl_fft_generic.mli b/src/owl/fftpack/owl_fft_generic.mli index d805a3b4c..ef0d63032 100644 --- a/src/owl/fftpack/owl_fft_generic.mli +++ b/src/owl/fftpack/owl_fft_generic.mli @@ -7,45 +7,93 @@ open Owl_dense_ndarray_generic -(** {5 Basic functions} *) +(** {5 Discrete Fourier Transforms functions} *) -val fft : ?axis:int -> (Complex.t, 'a) t -> (Complex.t, 'a) t -(** -[fft ~axis x] performs 1-dimensional FFT on a complex input. [axis] is the -highest dimension if not specified. The return is not scaled. - *) +val fft + : ?axis:int + -> ?norm:int + -> ?nthreads:int + -> (Complex.t, 'a) t + -> (Complex.t, 'a) t +(** [fft ~axis x] performs 1-dimensional FFT on a complex input. [axis] is the + highest dimension if not specified. The return is not scaled. *) -val ifft : ?axis:int -> (Complex.t, 'a) t -> (Complex.t, 'a) t -(** -[ifft ~axis x] performs inverse 1-dimensional FFT on a complex input. The parameter [axis] -indicates the highest dimension by default. - *) +val ifft + : ?axis:int + -> ?norm:int + -> ?nthreads:int + -> (Complex.t, 'a) t + -> (Complex.t, 'a) t +(** [ifft ~axis x] performs inverse 1-dimensional FFT on a complex input. The parameter [axis] + indicates the highest dimension by default. *) -val rfft : ?axis:int -> otyp:(Complex.t, 'a) kind -> (float, 'b) t -> (Complex.t, 'a) t -(** -[rfft ~axis ~otyp x] performs 1-dimensional FFT on real input along the -[axis]. [otyp] is used to specify the output type, it must be the consistent -precision with input [x]. You can skip this parameter by using a submodule -with specific precision such as [Owl.Fft.S] or [Owl.Fft.D]. - *) +val rfft + : ?axis:int + -> ?norm:int + -> ?nthreads:int + -> otyp:('a, 'b) kind + -> ('c, 'd) t + -> ('a, 'b) t +(** [rfft ~axis ~otyp x] performs 1-dimensional FFT on real input along the + [axis]. [otyp] is used to specify the output type, it must be the consistent + precision with input [x]. You can skip this parameter by using a submodule + with specific precision such as [Owl.Fft.S] or [Owl.Fft.D]. *) - val irfft +val irfft : ?axis:int -> ?n:int - -> otyp:(float, 'a) kind - -> (Complex.t, 'b) t - -> (float, 'a) t -(** -[irfft ~axis ~n x] is the inverse function of [rfft]. Note the [n] parameter -is used to specified the size of output. - *) + -> ?norm:int + -> ?nthreads:int + -> otyp:('a, 'b) kind + -> ('c, 'd) t + -> ('a, 'b) t +(** [irfft ~axis ~n x] is the inverse function of [rfft]. Note the [n] parameter + is used to specified the size of output. *) val fft2 : (Complex.t, 'a) t -> (Complex.t, 'a) t -(** -[fft2 x] performs 2-dimensional FFT on a complex input. The return is not scaled. - *) +(** [fft2 x] performs 2-dimensional FFT on a complex input. The return is not scaled. *) val ifft2 : (Complex.t, 'a) t -> (Complex.t, 'a) t -(** -[ifft2 x] performs inverse 2-dimensional FFT on a complex input. - *) +(** [ifft2 x] performs inverse 2-dimensional FFT on a complex input. *) + +(** {5 Discrete Cosine & Sine Transforms functions} *) + +val dct + : ?axis:int + -> ?ttype:int + -> ?norm:int + -> ?ortho:bool + -> ?nthreads:int + -> ('a, 'b) t + -> ('a, 'b) t +(** [dct ~axis ~type x] performs 1-dimensional Discrete Cosine Transform (DCT) on a real input. Default type is 2. *) + +val idct + : ?axis:int + -> ?ttype:int + -> ?norm:int + -> ?ortho:bool + -> ?nthreads:int + -> ('a, 'b) t + -> ('a, 'b) t +(** [idct ~axis ~type x] performs inverse 1-dimensional Discrete Cosine Transform (DCT) on a real input. Default type is 2. *) + +val dst + : ?axis:int + -> ?ttype:int + -> ?norm:int + -> ?ortho:bool + -> ?nthreads:int + -> ('a, 'b) t + -> ('a, 'b) t +(** [dst ~axis ~type x] performs 1-dimensional Discrete Sine Transform (DST) on a real input. Default type is 2. *) + +val idst + : ?axis:int + -> ?ttype:int + -> ?norm:int + -> ?ortho:bool + -> ?nthreads:int + -> ('a, 'b) t + -> ('a, 'b) t +(** [idst ~axis ~type x] performs inverse 1-dimensional Discrete Sine Transform (DST) on a real input. Default type is 2. *) diff --git a/src/owl/fftpack/owl_fft_s.mli b/src/owl/fftpack/owl_fft_s.mli index fe7fab5e6..aa4b0a643 100644 --- a/src/owl/fftpack/owl_fft_s.mli +++ b/src/owl/fftpack/owl_fft_s.mli @@ -6,14 +6,71 @@ open Bigarray open Owl_dense_ndarray_generic -val fft : ?axis:int -> (Complex.t, complex32_elt) t -> (Complex.t, complex32_elt) t +val fft + : ?axis:int + -> ?norm:int + -> ?nthreads:int + -> (Complex.t, complex32_elt) Owl_dense_ndarray_generic.t + -> (Complex.t, complex32_elt) Owl_dense_ndarray_generic.t -val ifft : ?axis:int -> (Complex.t, complex32_elt) t -> (Complex.t, complex32_elt) t +val ifft + : ?axis:int + -> ?norm:int + -> ?nthreads:int + -> (Complex.t, complex32_elt) Owl_dense_ndarray_generic.t + -> (Complex.t, complex32_elt) Owl_dense_ndarray_generic.t -val rfft : ?axis:int -> (float, float32_elt) t -> (Complex.t, complex32_elt) t +val rfft + : ?axis:int + -> ?norm:int + -> ?nthreads:int + -> (float, float32_elt) t + -> (Complex.t, complex32_elt) t -val irfft : ?axis:int -> ?n:int -> (Complex.t, complex32_elt) t -> (float, float32_elt) t +val irfft + : ?axis:int + -> ?n:int + -> ?norm:int + -> ?nthreads:int + -> (Complex.t, complex32_elt) t + -> (float, float32_elt) t val fft2 : (Complex.t, complex32_elt) t -> (Complex.t, complex32_elt) t val ifft2 : (Complex.t, complex32_elt) t -> (Complex.t, complex32_elt) t + +val dct + : ?axis:int + -> ?ttype:int + -> ?norm:int + -> ?ortho:bool + -> ?nthreads:int + -> (float, float32_elt) t + -> (float, float32_elt) t + +val idct + : ?axis:int + -> ?ttype:int + -> ?norm:int + -> ?ortho:bool + -> ?nthreads:int + -> (float, float32_elt) t + -> (float, float32_elt) t + +val dst + : ?axis:int + -> ?ttype:int + -> ?norm:int + -> ?ortho:bool + -> ?nthreads:int + -> (float, float32_elt) t + -> (float, float32_elt) t + +val idst + : ?axis:int + -> ?ttype:int + -> ?norm:int + -> ?ortho:bool + -> ?nthreads:int + -> (float, float32_elt) t + -> (float, float32_elt) t diff --git a/src/owl/fftpack/owl_fftpack.ml b/src/owl/fftpack/owl_fftpack.ml index 205b02db7..4e54c22ee 100644 --- a/src/owl/fftpack/owl_fftpack.ml +++ b/src/owl/fftpack/owl_fftpack.ml @@ -6,10 +6,13 @@ open Bigarray open Owl_core_types +(* Forward Real FFT *) external owl_float32_rfftf : (float, float32_elt) owl_arr -> (Complex.t, complex32_elt) owl_arr -> int + -> int + -> int -> unit = "float32_rfftf" @@ -17,24 +20,37 @@ external owl_float64_rfftf : (float, float64_elt) owl_arr -> (Complex.t, complex64_elt) owl_arr -> int + -> int + -> int -> unit = "float64_rfftf" let _owl_rfftf - : type a b c d. - (a, b) kind -> (c, d) kind -> (a, b) owl_arr -> (c, d) owl_arr -> int -> unit + : type a b c d. + (a, b) kind + -> (c, d) kind + -> (a, b) owl_arr + -> (c, d) owl_arr + -> int + -> int + -> int + -> unit = - fun ityp otyp x y axis -> + fun ityp otyp x y axis norm nthreads -> match ityp, otyp with - | Float32, Complex32 -> owl_float32_rfftf x y axis - | Float64, Complex64 -> owl_float64_rfftf x y axis - | _ -> failwith "_owl_rfftf: unsupported operation" + | Float32, Complex32 -> owl_float32_rfftf x y axis norm nthreads + | Float64, Complex64 -> owl_float64_rfftf x y axis norm nthreads + | _ -> failwith "_owl_rfftf: unsupported operation" +(* Backward Real FFT *) + external owl_float32_rfftb : (Complex.t, complex32_elt) owl_arr -> (float, float32_elt) owl_arr -> int + -> int + -> int -> unit = "float32_rfftb" @@ -42,57 +58,191 @@ external owl_float64_rfftb : (Complex.t, complex64_elt) owl_arr -> (float, float64_elt) owl_arr -> int + -> int + -> int -> unit = "float64_rfftb" let _owl_rfftb - : type a b c d. - (a, b) kind -> (c, d) kind -> (a, b) owl_arr -> (c, d) owl_arr -> int -> unit + : type a b c d. + (a, b) kind + -> (c, d) kind + -> (a, b) owl_arr + -> (c, d) owl_arr + -> int + -> int + -> int + -> unit = - fun ityp otyp x y axis -> + fun ityp otyp x y axis norm nthreads -> match ityp, otyp with - | Complex32, Float32 -> owl_float32_rfftb x y axis - | Complex64, Float64 -> owl_float64_rfftb x y axis - | _ -> failwith "_owl_rfftb: unsupported operation" + | Complex32, Float32 -> owl_float32_rfftb x y axis norm nthreads + | Complex64, Float64 -> owl_float64_rfftb x y axis norm nthreads + | _ -> failwith "_owl_rfftb: unsupported operation" -external owl_complex32_cfftf - : (Complex.t, complex32_elt) owl_arr +external owl_complex32_cfft + : bool -> (Complex.t, complex32_elt) owl_arr + -> (Complex.t, complex32_elt) owl_arr + -> int + -> int -> int -> unit - = "float32_cfftf" + = "float64_cfft_bytecode" "float32_cfft" -external owl_complex64_cfftf - : (Complex.t, complex64_elt) owl_arr +external owl_complex64_cfft + : bool + -> (Complex.t, complex64_elt) owl_arr -> (Complex.t, complex64_elt) owl_arr -> int + -> int + -> int -> unit - = "float64_cfftf" + = "float64_cfft_bytecode" "float64_cfft" -let _owl_cfftf : type a b. (a, b) kind -> (a, b) owl_arr -> (a, b) owl_arr -> int -> unit +let _owl_cfftf + : type a b. (a, b) kind -> (a, b) owl_arr -> (a, b) owl_arr -> int -> int -> int -> unit = function - | Complex32 -> owl_complex32_cfftf - | Complex64 -> owl_complex64_cfftf - | _ -> failwith "_owl_cfftf: unsupported operation" + | Complex32 -> true |> owl_complex32_cfft + | Complex64 -> true |> owl_complex64_cfft + | _ -> failwith "_owl_cfftf: unsupported operation" -external owl_complex32_cfftb - : (Complex.t, complex32_elt) owl_arr - -> (Complex.t, complex32_elt) owl_arr +let _owl_cfftb + : type a b. (a, b) kind -> (a, b) owl_arr -> (a, b) owl_arr -> int -> int -> int -> unit + = function + | Complex32 -> false |> owl_complex32_cfft + | Complex64 -> false |> owl_complex64_cfft + | _ -> failwith "_owl_cfftb: unsupported operation" + + +(* DCT and DST *) + +(* little helper to get the inverse type of DSTs and DCTs *) +let inverse_map = function + | 1 -> 1 + | 2 -> 3 + | 3 -> 2 + | 4 -> 4 + | _ -> failwith "unknown transform type" + + +(* DCT *) + +external owl_float32_dct + : (float, float32_elt) owl_arr + -> (float, float32_elt) owl_arr + -> int + -> int + -> int + -> bool -> int -> unit - = "float32_cfftb" + = "float32_dct_bytecode" "float32_dct" -external owl_complex64_cfftb - : (Complex.t, complex64_elt) owl_arr - -> (Complex.t, complex64_elt) owl_arr +external owl_float64_dct + : (float, float64_elt) owl_arr + -> (float, float64_elt) owl_arr + -> int + -> int + -> int + -> bool -> int -> unit - = "float64_cfftb" + = "float64_dct_bytecode" "float64_dct" + +let _owl_dctf + : type a b. + (a, b) kind + -> (a, b) owl_arr + -> (a, b) owl_arr + -> int + -> int + -> int + -> bool + -> int + -> unit + = function + | Float32 -> owl_float32_dct + | Float64 -> owl_float64_dct + | _ -> failwith "_owl_dctf: unsupported operation" + + +let _owl_dctb + : type a b. + (a, b) kind + -> (a, b) owl_arr + -> (a, b) owl_arr + -> int + -> int + -> int + -> bool + -> int + -> unit + = + fun ityp x y ttype axis norm ortho nthreads -> + match ityp with + | Float32 -> owl_float32_dct x y (inverse_map ttype) axis norm ortho nthreads + | Float64 -> owl_float64_dct x y (inverse_map ttype) axis norm ortho nthreads + | _ -> failwith "_owl_dctb: unsupported operation" + -let _owl_cfftb : type a b. (a, b) kind -> (a, b) owl_arr -> (a, b) owl_arr -> int -> unit +(* DST *) + +external owl_float32_dst + : (float, float32_elt) owl_arr + -> (float, float32_elt) owl_arr + -> int + -> int + -> int + -> bool + -> int + -> unit + = "float32_dst_bytecode" "float32_dst" + +external owl_float64_dst + : (float, float64_elt) owl_arr + -> (float, float64_elt) owl_arr + -> int + -> int + -> int + -> bool + -> int + -> unit + = "float64_dst_bytecode" "float64_dst" + +let _owl_dstf + : type a b. + (a, b) kind + -> (a, b) owl_arr + -> (a, b) owl_arr + -> int + -> int + -> int + -> bool + -> int + -> unit = function - | Complex32 -> owl_complex32_cfftb - | Complex64 -> owl_complex64_cfftb - | _ -> failwith "_owl_cfftf: unsupported operation" + | Float32 -> owl_float32_dst + | Float64 -> owl_float64_dst + | _ -> failwith "_owl_dstf: unsupported operation" + + +let _owl_dstb + : type a b. + (a, b) kind + -> (a, b) owl_arr + -> (a, b) owl_arr + -> int + -> int + -> int + -> bool + -> int + -> unit + = + fun ityp x y ttype axis norm ortho nthreads -> + match ityp with + | Float32 -> owl_float32_dst x y (inverse_map ttype) axis norm ortho nthreads + | Float64 -> owl_float64_dst x y (inverse_map ttype) axis norm ortho nthreads + | _ -> failwith "_owl_dstb: unsupported operation" diff --git a/src/owl/fftpack/owl_fftpack_float32.c b/src/owl/fftpack/owl_fftpack_float32.c deleted file mode 100644 index 4bbd337aa..000000000 --- a/src/owl/fftpack/owl_fftpack_float32.c +++ /dev/null @@ -1,41 +0,0 @@ -/* - * OWL - OCaml Scientific Computing - * Copyright (c) 2016-2022 Liang Wang - */ - -#include - -#include "owl_core.h" - - -#define Treal float - -#define REAL_COPY owl_float32_copy -#define COMPLEX_COPY owl_complex32_copy -#define FFTPACK_CFFTI float32_fftpack_cffti -#define FFTPACK_CFFTF float32_fftpack_cfftf -#define FFTPACK_CFFTB float32_fftpack_cfftb -#define FFTPACK_RFFTI float32_fftpack_rffti -#define FFTPACK_RFFTF float32_fftpack_rfftf -#define FFTPACK_RFFTB float32_fftpack_rfftb -#define STUB_CFFTF float32_cfftf -#define STUB_CFFTB float32_cfftb -#define STUB_RFFTF float32_rfftf -#define STUB_RFFTB float32_rfftb - -#include "owl_fftpack_impl.h" - -#undef REAL_COPY -#undef COMPLEX_COPY -#undef FFTPACK_CFFTI -#undef FFTPACK_CFFTF -#undef FFTPACK_CFFTB -#undef FFTPACK_RFFTI -#undef FFTPACK_RFFTF -#undef FFTPACK_RFFTB -#undef STUB_CFFTF -#undef STUB_CFFTB -#undef STUB_RFFTF -#undef STUB_RFFTB - -#undef Treal diff --git a/src/owl/fftpack/owl_fftpack_float32.cc b/src/owl/fftpack/owl_fftpack_float32.cc new file mode 100644 index 000000000..c50778385 --- /dev/null +++ b/src/owl/fftpack/owl_fftpack_float32.cc @@ -0,0 +1,39 @@ +/* + * OWL - OCaml Scientific Computing + * Copyright (c) 2016-2022 Liang Wang + */ + +#include + +#define Treal float + +extern "C" +{ +#include "owl_core.h" +} + +#define REAL_COPY owl_float32_copy +#define COMPLEX_COPY owl_complex32_copy +#define STUB_CFFT float32_cfft +#define STUB_CFFT_bytecode float32_cfft_bytecode +#define STUB_RFFTF float32_rfftf +#define STUB_RFFTB float32_rfftb +#define STUB_RDCT float32_dct +#define STUB_RDCT_bytecode float32_dct_bytecode +#define STUB_RDST float32_dst +#define STUB_RDST_bytecode float32_dst_bytecode + +#include "owl_fftpack_impl.h" + +#undef REAL_COPY +#undef COMPLEX_COPY +#undef STUB_CFFT +#undef STUB_CFFT_bytecode +#undef STUB_RFFTF +#undef STUB_RFFTB +#undef STUB_RDCT +#undef STUB_RDCT_bytecode +#undef STUB_RDST +#undef STUB_RDST_bytecode + +#undef Treal diff --git a/src/owl/fftpack/owl_fftpack_float64.c b/src/owl/fftpack/owl_fftpack_float64.c deleted file mode 100644 index 2a9fb7832..000000000 --- a/src/owl/fftpack/owl_fftpack_float64.c +++ /dev/null @@ -1,41 +0,0 @@ -/* - * OWL - OCaml Scientific Computing - * Copyright (c) 2016-2022 Liang Wang - */ - -#include - -#include "owl_core.h" - - -#define Treal double - -#define REAL_COPY owl_float64_copy -#define COMPLEX_COPY owl_complex64_copy -#define FFTPACK_CFFTI float64_fftpack_cffti -#define FFTPACK_CFFTF float64_fftpack_cfftf -#define FFTPACK_CFFTB float64_fftpack_cfftb -#define FFTPACK_RFFTI float64_fftpack_rffti -#define FFTPACK_RFFTF float64_fftpack_rfftf -#define FFTPACK_RFFTB float64_fftpack_rfftb -#define STUB_CFFTF float64_cfftf -#define STUB_CFFTB float64_cfftb -#define STUB_RFFTF float64_rfftf -#define STUB_RFFTB float64_rfftb - -#include "owl_fftpack_impl.h" - -#undef REAL_COPY -#undef COMPLEX_COPY -#undef FFTPACK_CFFTI -#undef FFTPACK_CFFTF -#undef FFTPACK_CFFTB -#undef FFTPACK_RFFTI -#undef FFTPACK_RFFTF -#undef FFTPACK_RFFTB -#undef STUB_CFFTF -#undef STUB_CFFTB -#undef STUB_RFFTF -#undef STUB_RFFTB - -#undef Treal diff --git a/src/owl/fftpack/owl_fftpack_float64.cc b/src/owl/fftpack/owl_fftpack_float64.cc new file mode 100644 index 000000000..b34050df0 --- /dev/null +++ b/src/owl/fftpack/owl_fftpack_float64.cc @@ -0,0 +1,38 @@ +/* + * OWL - OCaml Scientific Computing + * Copyright (c) 2016-2022 Liang Wang + */ + +#include +#define Treal double + +extern "C" +{ +#include "owl_core.h" +} + +#define REAL_COPY owl_float64_copy +#define COMPLEX_COPY owl_complex64_copy +#define STUB_CFFT float64_cfft +#define STUB_CFFT_bytecode float64_cfft_bytecode +#define STUB_RFFTF float64_rfftf +#define STUB_RFFTB float64_rfftb +#define STUB_RDCT float64_dct +#define STUB_RDCT_bytecode float64_dct_bytecode +#define STUB_RDST float64_dst +#define STUB_RDST_bytecode float64_dst_bytecode + +#include "owl_fftpack_impl.h" + +#undef REAL_COPY +#undef COMPLEX_COPY +#undef STUB_CFFT +#undef STUB_CFFT_bytecode +#undef STUB_RFFTF +#undef STUB_RFFTB +#undef STUB_RDCT +#undef STUB_RDCT_bytecode +#undef STUB_RDST +#undef STUB_RDST_bytecode + +#undef Treal diff --git a/src/owl/fftpack/owl_fftpack_impl.h b/src/owl/fftpack/owl_fftpack_impl.h index f62491c6e..f3a36ebc4 100644 --- a/src/owl/fftpack/owl_fftpack_impl.h +++ b/src/owl/fftpack/owl_fftpack_impl.h @@ -3,264 +3,382 @@ * Copyright (c) 2016-2022 Liang Wang */ - #ifdef Treal -#include "fftpack_impl.h" - - -/** Owl's interface function to FFTPACK **/ - - -void FFTPACK_CFFTI (int n, Treal wsave[]) { - if (n == 1) return; - int iw1 = 2 * n; - int iw2 = iw1 + 2 * n; - cffti1(n, wsave + iw1, (int*) (wsave + iw2)); -} - - -void FFTPACK_CFFTF (int n, Treal c[], Treal wsave[]) { - if (n == 1) return; - int iw1 = 2 * n; - int iw2 = iw1 + 2 * n; - cfftf1(n, c, wsave, wsave + iw1, (int*) (wsave + iw2), -1); -} - - -void FFTPACK_CFFTB (int n, Treal c[], Treal wsave[]) { - if (n == 1) return; - int iw1 = 2 * n; - int iw2 = iw1 + 2 * n; - cfftf1(n, c, wsave, wsave + iw1, (int*) (wsave + iw2), +1); -} - - -void FFTPACK_RFFTI (int n, Treal wsave[]) { - if (n == 1) return; - rffti1(n, wsave + n, (int*) (wsave + 2 * n)); -} - - -void FFTPACK_RFFTF (int n, Treal r[], Treal wsave[]) { - if (n == 1) return; - rfftf1(n, r, wsave, wsave + n, (int*) (wsave + 2 * n)); -} - - -void FFTPACK_RFFTB(int n, Treal r[], Treal wsave[]) { - if (n == 1) return; - rfftb1(n, r, wsave, wsave + n, (int*) (wsave + 2 * n)); -} - - -/** Helper functions **/ - - -// uppack from halfcomplex x to complex y -static OWL_INLINE void halfcomplex_unpack (int n, Treal* x, int ofsx, int incx, _Complex Treal* y, int ofsy, int incy) { - int i; - *(y + ofsy) = *(x + ofsx) + 0 * I; - - for (i = 1; i < n - i; i++) { - ofsx += incx + incx; - ofsy += incy; - Treal re = *(x + ofsx - incx); - Treal im = *(x + ofsx); - *(y + ofsy) = re + im * I; +#include "pocketfft_hdronly.h" + +/** Owl's interface function to pocketfft **/ +/** Adapted from scipy's pypocketfft.cxx **/ + +using namespace pocketfft::detail; + +template +T norm_fct(int inorm, size_t N) +{ + switch (inorm) + { + case 0: // "backward" - no normalization for forward transform + return T(1); + case 1: // "forward" - 1/n normalization for forward transform + return T(1) / T(N); + case 2: // "ortho" - 1/sqrt(n) normalization for both directions + return T(1) / std::sqrt(T(N)); + default: + caml_failwith("invalid value for inorm (must be 0, 1, or 2)"); + // This will never be reached + return T(0); } - - if (i == n - i) - *(y + ofsy + incy) = *(x + ofsx + incx) + 0 * I; } - -// pack from complex x to halfcomplex y -static OWL_INLINE void halfcomplex_pack (int n, _Complex Treal* x, int ofsx, int incx, Treal* y, int ofsy, int incy) { - int i; - *(y + ofsy) = creal(*(x + ofsx)); - - for (i = 1; i < n - i; i++) { - ofsx += incx; - ofsy += incy + incy; - *(y + ofsy - incy) = creal(*(x + ofsx)); - *(y + ofsy) = cimag(*(x + ofsx)); +template +T compute_norm_factor(const shape_t &dims, const shape_t &axes, int inorm, size_t fct = 1, int delta = 0) +{ + if (inorm == 0) + return T(1); + size_t N = 1; + for (auto a : axes) + { + N *= fct * size_t(int64_t(dims[a]) + delta); } - - if (i == n - i) - *(y + ofsy + incy) = creal(*(x + ofsx + incx)); + return norm_fct(inorm, N); } - -/** Owl's stub functions **/ - - -value STUB_CFFTF (value vX, value vY, value vD) { - struct caml_ba_array *X = Caml_ba_array_val(vX); - _Complex Treal *X_data = (_Complex Treal *) X->data; - - struct caml_ba_array *Y = Caml_ba_array_val(vY); - _Complex Treal *Y_data = (_Complex Treal *) Y->data; - - int d = Long_val(vD); - int n = X->dim[d]; - size_t ws_sz = 4 * n * sizeof(Treal); - size_t fc_sz = (MAXFAC + 2) * sizeof(int); - void* wsave = malloc(ws_sz + fc_sz); - void* data = malloc(2 * n * sizeof(Treal)); - - int stdx = c_ndarray_stride_dim(X, d); - int slcx = c_ndarray_slice_dim(X,d); - int stdy = c_ndarray_stride_dim(Y, d); - int slcy = c_ndarray_slice_dim(Y,d); - int m = c_ndarray_numel(X) / slcx; - - FFTPACK_CFFTI(n, wsave); - - int ofsx = 0; - int ofsy = 0; - - for (int i = 0; i < m; i ++) { - for (int j = 0; j < stdx; j++) { - COMPLEX_COPY(n, X_data, ofsx + j, stdx, data, 0, 1); - FFTPACK_CFFTF(n, (Treal*) data, wsave); - COMPLEX_COPY(n, data, 0, 1, Y_data, ofsy + j, stdy); +extern "C" +{ + + /** Owl's stub functions **/ + + /** + * Complex-to-complex FFT + * @param forward: true for forward transform, false for backward transform + * @param X: input array + * @param Y: output array + * @param d: dimension along which to perform the transform + * @param norm: normalization factor + * @param nthreads: number of threads to use + * + * @return unit + */ + value STUB_CFFT(value vForward, value vX, value vY, value vD, value vNorm, value vNthreads) + { + struct caml_ba_array *X = Caml_ba_array_val(vX); + std::complex *X_data = reinterpret_cast *>(X->data); + + struct caml_ba_array *Y = Caml_ba_array_val(vY); + std::complex *Y_data = reinterpret_cast *>(Y->data); + + int d = Long_val(vD); + int n = X->dim[d]; + int norm = Long_val(vNorm); + int nthreads = Long_val(vNthreads); + int forward = Bool_val(vForward); + + shape_t dims; + stride_t stride_in, stride_out; + + for (int i = 0; i < X->num_dims; ++i) + { + dims.push_back(static_cast(X->dim[i])); } - ofsx += slcx; - ofsy += slcy; - } - - free(wsave); - free(data); - - return Val_unit; -} - - -value STUB_CFFTB (value vX, value vY, value vD) { - struct caml_ba_array *X = Caml_ba_array_val(vX); - _Complex Treal *X_data = (_Complex Treal *) X->data; - - struct caml_ba_array *Y = Caml_ba_array_val(vY); - _Complex Treal *Y_data = (_Complex Treal *) Y->data; - - int d = Long_val(vD); - int n = X->dim[d]; - size_t ws_sz = 4 * n * sizeof(Treal); - size_t fc_sz = (MAXFAC + 2) * sizeof(int); - void* wsave = malloc(ws_sz + fc_sz); - void* data = malloc(2 * n * sizeof(Treal)); - - int stdx = c_ndarray_stride_dim(X, d); - int slcx = c_ndarray_slice_dim(X,d); - int stdy = c_ndarray_stride_dim(Y, d); - int slcy = c_ndarray_slice_dim(Y,d); - int m = c_ndarray_numel(X) / slcx; - - FFTPACK_CFFTI(n, wsave); - int ofsx = 0; - int ofsy = 0; - - for (int i = 0; i < m; i ++) { - for (int j = 0; j < stdx; j++) { - COMPLEX_COPY(n, X_data, ofsx + j, stdx, data, 0, 1); - FFTPACK_CFFTB(n, (Treal*) data, wsave); - COMPLEX_COPY(n, data, 0, 1, Y_data, ofsy + j, stdy); + size_t multiplier = sizeof(std::complex); + for (int i = 0; i < X->num_dims; ++i) + { + stride_in.push_back(c_ndarray_stride_dim(X, i) * multiplier); + stride_out.push_back(c_ndarray_stride_dim(Y, i) * multiplier); } - ofsx += slcx; - ofsy += slcy; - } - - free(wsave); - free(data); - return Val_unit; -} + shape_t axes{static_cast(d)}; + { + Treal norm_factor = compute_norm_factor(dims, axes, norm); + try + { + pocketfft::detail::c2c(dims, stride_in, stride_out, axes, forward, + X_data, Y_data, norm_factor, nthreads); + } + catch (const std::exception &e) + { + caml_failwith(e.what()); // maybe raise an OCaml exception here ?? + } + } + return Val_unit; + } -value STUB_RFFTF (value vX, value vY, value vD) { - struct caml_ba_array *X = Caml_ba_array_val(vX); - Treal *X_data = (Treal *) X->data; + /** + * Complex-to-complex FFT + * @param argv: array of arguments + * @param argn: number of arguments + * @see STUB_CFFT, https://ocaml.org/manual/5.2/intfc.html#ss:c-prim-impl + */ + value STUB_CFFT_bytecode(value *argv, int argn) + { + return STUB_CFFT(argv[0], argv[1], argv[2], argv[3], argv[4], argv[5]); + } - struct caml_ba_array *Y = Caml_ba_array_val(vY); - _Complex Treal *Y_data = (_Complex Treal *) Y->data; + /** + * Forward Real-to-complex FFT + * @param X: input array (real data) + * @param Y: output array (complex data) + * @param d: dimension along which to perform the transform + * @param norm: normalization factor + * @param nthreads: number of threads to use + * + * @return unit + */ + value STUB_RFFTF(value vX, value vY, value vD, value vNorm, value vNthreads) + { + struct caml_ba_array *X = Caml_ba_array_val(vX); + Treal *X_data = reinterpret_cast(X->data); + + struct caml_ba_array *Y = Caml_ba_array_val(vY); + std::complex *Y_data = reinterpret_cast *>(Y->data); + + int d = Long_val(vD); + int n = X->dim[d]; + int norm = Long_val(vNorm); + int nthreads = Long_val(vNthreads); + + shape_t dims; + stride_t stride_in, stride_out; + + for (int i = 0; i < X->num_dims; ++i) + { + dims.push_back(static_cast(X->dim[i])); + } - int d = Long_val(vD); - int n = X->dim[d]; - size_t ws_sz = 2 * n * sizeof(Treal); - size_t fc_sz = (MAXFAC + 2) * sizeof(int); - void* wsave = malloc(ws_sz + fc_sz); - void* data = malloc(n * sizeof(Treal)); + size_t multiplier = sizeof(Treal); + for (int i = 0; i < X->num_dims; ++i) + { + stride_in.push_back(c_ndarray_stride_dim(X, i) * multiplier); + } - int stdx = c_ndarray_stride_dim(X, d); - int slcx = c_ndarray_slice_dim(X,d); - int stdy = c_ndarray_stride_dim(Y, d); - int slcy = c_ndarray_slice_dim(Y,d); - int m = c_ndarray_numel(X) / slcx; + multiplier = sizeof(std::complex); + for (int i = 0; i < Y->num_dims; ++i) + { + stride_out.push_back(c_ndarray_stride_dim(Y, i) * multiplier); + } - FFTPACK_RFFTI(n, wsave); + shape_t axes{static_cast(d)}; + { + Treal norm_factor = compute_norm_factor(dims, axes, norm); + try + { + pocketfft::r2c(dims, stride_in, stride_out, axes, pocketfft::FORWARD, + X_data, Y_data, norm_factor, nthreads); + } + catch (const std::exception &e) + { + caml_failwith(e.what()); // maybe raise an OCaml exception here ?? + } + } - int ofsx = 0; - int ofsy = 0; + return Val_unit; + } - for (int i = 0; i < m; i ++) { - for (int j = 0; j < stdx; j++) { - REAL_COPY(n, X_data, ofsx + j, stdx, data, 0, 1); - FFTPACK_RFFTF(n, (Treal*) data, wsave); - halfcomplex_unpack(n, data, 0, 1, Y_data, ofsy + j, stdy); + /** + * Backward Real-to-complex FFT + * @param X: input array (complex data) + * @param Y: output array (real data) + * @param d: dimension along which to perform the transform + * @param norm: normalization factor + * @param nthreads: number of threads to use + * + * @return unit + */ + value STUB_RFFTB(value vX, value vY, value vD, value vNorm, value vNthreads) + { + struct caml_ba_array *X = Caml_ba_array_val(vX); + std::complex *X_data = reinterpret_cast *>(X->data); + + struct caml_ba_array *Y = Caml_ba_array_val(vY); + Treal *Y_data = reinterpret_cast(Y->data); + + int d = Long_val(vD); + int n = X->dim[d]; + int norm = Long_val(vNorm); + int nthreads = Long_val(vNthreads); + + if (Y->dim[d] != (X->dim[d] - 1) * 2) + caml_failwith("Invalid output dimension for inverse real FFT"); + + shape_t dims; + stride_t stride_in, stride_out; + + int ncomplex = X->dim[d]; + int nreal = Y->dim[d]; + + for (int i = 0; i < X->num_dims; ++i) + { + if (i == d) + { + dims.push_back(static_cast(nreal)); + } + else + { + dims.push_back(static_cast(X->dim[i])); + } } - ofsx += slcx; - ofsy += slcy; - } - free(wsave); - free(data); + size_t multiplier = sizeof(std::complex); + for (int i = 0; i < X->num_dims; ++i) + { + stride_in.push_back(c_ndarray_stride_dim(X, i) * multiplier); + } - return Val_unit; -} + multiplier = sizeof(Treal); + for (int i = 0; i < Y->num_dims; ++i) + { + stride_out.push_back(c_ndarray_stride_dim(Y, i) * multiplier); + } + shape_t axes{static_cast(d)}; + { + Treal norm_factor = compute_norm_factor(dims, axes, norm); + try + { + pocketfft::c2r(dims, stride_in, stride_out, axes, pocketfft::BACKWARD, + X_data, Y_data, norm_factor, nthreads); + } + catch (const std::exception &e) + { + caml_failwith(e.what()); // maybe raise an OCaml exception here ?? + } + } -value STUB_RFFTB (value vX, value vY, value vD) { - struct caml_ba_array *X = Caml_ba_array_val(vX); - _Complex Treal *X_data = (_Complex Treal *) X->data; + return Val_unit; + } - struct caml_ba_array *Y = Caml_ba_array_val(vY); - Treal *Y_data = (Treal *) Y->data; + /** + * Discrete Cosine Transform + * @param X: input array + * @param Y: output array + * @param d: dimension along which to perform the transform + * @param type: type of DCT (1, 2, 3, or 4) + * @param norm: normalization factor + * @param nthreads: number of threads to use + * + * @return unit + */ + value STUB_RDCT(value vX, value vY, value vD, value vType, value vNorm, value vOrtho, value vNthreads) + { + struct caml_ba_array *X = Caml_ba_array_val(vX); + Treal *X_data = reinterpret_cast(X->data); + + struct caml_ba_array *Y = Caml_ba_array_val(vY); + Treal *Y_data = reinterpret_cast(Y->data); + + int d = Long_val(vD); + int n = X->dim[d]; + int type = Long_val(vType); + if (type < 1 || type > 4) // should not happen as it's checked on the OCaml side + caml_failwith("invalid value for type (must be 1, 2, 3, or 4)"); + int norm = Long_val(vNorm); + bool ortho = Bool_val(vOrtho); + int nthreads = Long_val(vNthreads); + + shape_t dims; + stride_t stride_in, stride_out; + + for (int i = 0; i < X->num_dims; ++i) + { + dims.push_back(static_cast(X->dim[i])); + } - int d = Long_val(vD); - int n = Y->dim[d]; - size_t ws_sz = 2 * n * sizeof(Treal); - size_t fc_sz = (MAXFAC + 2) * sizeof(int); - void* wsave = malloc(ws_sz + fc_sz); - void* data = malloc(n * sizeof(_Complex Treal)); + size_t multiplier = sizeof(Treal); + for (int i = 0; i < X->num_dims; ++i) + { + stride_in.push_back(c_ndarray_stride_dim(X, i) * multiplier); + stride_out.push_back(c_ndarray_stride_dim(Y, i) * multiplier); + } - int stdx = c_ndarray_stride_dim(X, d); - int slcx = c_ndarray_slice_dim(X,d); - int stdy = c_ndarray_stride_dim(Y, d); - int slcy = c_ndarray_slice_dim(Y,d); - int m = c_ndarray_numel(X) / slcx; + shape_t axes{static_cast(d)}; + { + Treal norm_factor = (type == 1) ? compute_norm_factor(dims, axes, norm, 2, 1) + : compute_norm_factor(dims, axes, norm, 2); + try + { + pocketfft::detail::dct(dims, stride_in, stride_out, axes, type, + X_data, Y_data, norm_factor, ortho, nthreads); + } + catch (const std::exception &e) + { + caml_failwith(e.what()); // maybe raise an OCaml exception here ?? + } + } - FFTPACK_RFFTI(n, wsave); + return Val_unit; + } - int ofsx = 0; - int ofsy = 0; + value STUB_RDCT_bytecode(value *argv, int argn) + { + return STUB_RDCT(argv[0], argv[1], argv[2], argv[3], argv[4], argv[5], argv[6]); + } - for (int i = 0; i < m; i ++) { - for (int j = 0; j < stdx; j++) { - halfcomplex_pack(n, X_data, ofsx + j, stdx, data, 0, 1); - FFTPACK_RFFTB(n, (Treal*) data, wsave); - REAL_COPY(n, (Treal*) data, 0, 1, Y_data, ofsy + j, stdy); + /** + * Discrete Sine Transform + * @param X: input array + * @param Y: output array + * @param d: dimension along which to perform the transform + * @param type: type of DST (1, 2, 3, or 4) + * @param norm: normalization factor + * @param nthreads: number of threads to use + * + * @return unit + */ + value STUB_RDST(value vX, value vY, value vD, value vType, value vNorm, value vOrtho, value vNthreads) + { + struct caml_ba_array *X = Caml_ba_array_val(vX); + Treal *X_data = reinterpret_cast(X->data); + + struct caml_ba_array *Y = Caml_ba_array_val(vY); + Treal *Y_data = reinterpret_cast(Y->data); + + int d = Long_val(vD); + int n = X->dim[d]; + int type = Long_val(vType); + if (type < 1 || type > 4) // should not happen as it's checked on the OCaml side + caml_failwith("invalid value for type (must be 1, 2, 3, or 4)"); + int norm = Long_val(vNorm); + bool ortho = Bool_val(vOrtho); + int nthreads = Long_val(vNthreads); + + shape_t dims; + stride_t stride_in, stride_out; + + for (int i = 0; i < X->num_dims; ++i) + { + dims.push_back(static_cast(X->dim[i])); } - ofsx += slcx; - ofsy += slcy; - } - free(wsave); - free(data); + size_t multiplier = sizeof(Treal); + for (int i = 0; i < X->num_dims; ++i) + { + stride_in.push_back(c_ndarray_stride_dim(X, i) * multiplier); + stride_out.push_back(c_ndarray_stride_dim(Y, i) * multiplier); + } - return Val_unit; -} + shape_t axes{static_cast(d)}; + { + Treal norm_factor = (type == 1) ? compute_norm_factor(dims, axes, norm, 2, 1) + : compute_norm_factor(dims, axes, norm, 2); + try + { + pocketfft::detail::dst(dims, stride_in, stride_out, axes, type, + X_data, Y_data, norm_factor, ortho, nthreads); + } + catch (const std::exception &e) + { + caml_failwith(e.what()); // maybe raise an OCaml exception here ?? + } + } + return Val_unit; + } -#endif //Treal + value STUB_RDST_bytecode(value *argv, int argn) + { + return STUB_RDST(argv[0], argv[1], argv[2], argv[3], argv[4], argv[5], argv[6]); + } +} // extern "C" +#endif // Treal diff --git a/src/owl/fftpack/pocketfft b/src/owl/fftpack/pocketfft new file mode 160000 index 000000000..bb87ca50d --- /dev/null +++ b/src/owl/fftpack/pocketfft @@ -0,0 +1 @@ +Subproject commit bb87ca50df0478415a12d9011dc374eeed4e9d93 diff --git a/src/owl/nlp/owl_nlp_corpus.ml b/src/owl/nlp/owl_nlp_corpus.ml index 89c42b161..84a7199c0 100644 --- a/src/owl/nlp/owl_nlp_corpus.ml +++ b/src/owl/nlp/owl_nlp_corpus.ml @@ -21,7 +21,7 @@ type t = mutable minlen : int ; (* minimum length of document to save *) mutable docid : int array (* document id, can refer to original data *) - } + } [@@warning "-69"] let _close_if_open = function | Some h -> close_in h diff --git a/src/owl/nlp/owl_nlp_lda.ml b/src/owl/nlp/owl_nlp_lda.ml index 90894f2c0..05cb5dce9 100644 --- a/src/owl/nlp/owl_nlp_lda.ml +++ b/src/owl/nlp/owl_nlp_lda.ml @@ -43,7 +43,7 @@ type model = mutable data : Owl_nlp_corpus.t ; (* training data, tokenised*) mutable vocb : (string, int) Hashtbl.t (* vocabulary, or dictionary if you prefer *) - } + } [@@warning "-69"] let include_token m w d k = m.t__k.(k) <- m.t__k.(k) +. 1.; diff --git a/src/owl/nlp/owl_nlp_tfidf.ml b/src/owl/nlp/owl_nlp_tfidf.ml index d083bffac..6142c285a 100644 --- a/src/owl/nlp/owl_nlp_tfidf.ml +++ b/src/owl/nlp/owl_nlp_tfidf.ml @@ -30,7 +30,7 @@ type t = mutable corpus : Owl_nlp_corpus.t ; (* corpus type *) mutable handle : in_channel option (* file descriptor of the tfidf *) - } + } [@@warning "-69"] (* various types of TF and IDF functions *) diff --git a/src/owl/nlp/owl_nlp_vocabulary.ml b/src/owl/nlp/owl_nlp_vocabulary.ml index 913267dff..195f7c96d 100644 --- a/src/owl/nlp/owl_nlp_vocabulary.ml +++ b/src/owl/nlp/owl_nlp_vocabulary.ml @@ -11,7 +11,7 @@ type t = mutable i2w : (int, string) Hashtbl.t ; (* index -> word *) mutable i2f : (int, int) Hashtbl.t (* index -> freq *) - } + } [@@warning "-69"] let get_w2i d = d.w2i