Skip to content

Commit

Permalink
Supporting NumPy >= 1.26.3 which includes NumPy 2.0 versions (#37)
Browse files Browse the repository at this point in the history
* Update to support NumPy >= 1.26.3 which includes 2.0.1

* Updated GitHub workflows to no longer construct the paper.pdf for JOSS, as the JOSS review is over and the article is published.

* Fixed type annotations to also work on MacOS

* Updated README

* Updated test tolerances due to MacOS. Also updated README.

* Updated test tolerances to accommodate MacOS precision.

* Updated test tolerances to accommodate MacOS precision.
  • Loading branch information
sm00thix authored Aug 5, 2024
1 parent e7046e8 commit d11746d
Show file tree
Hide file tree
Showing 9 changed files with 68 additions and 87 deletions.
23 changes: 0 additions & 23 deletions .github/workflows/draft-pdf.yml

This file was deleted.

10 changes: 7 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@

The `ikpls` software package provides fast and efficient tools for PLS (Partial Least Squares) modeling. This package is designed to help researchers and practitioners handle PLS modeling faster than previously possible - particularly on large datasets.

## Citation
If you use the `ikpls` software package for your work, please cite [this Journal of Open Source Software article](https://joss.theoj.org/papers/10.21105/joss.06533). If you use the fast cross-validation algorithm implemented in `ikpls.fast_cross_validation.numpy_ikpls`, please also cite [this arXiv preprint](https://arxiv.org/abs/2401.13185).

## Unlock the Power of Fast and Stable Partial Least Squares Modeling with IKPLS

Dive into cutting-edge Python implementations of the IKPLS (Improved Kernel Partial Least Squares) Algorithms #1 and #2 [[1]](#references) for CPUs, GPUs, and TPUs. IKPLS is both fast [[2]](#references) and numerically stable [[3]](#references) making it optimal for PLS modeling.
Expand Down Expand Up @@ -49,9 +52,10 @@ and scaling can be enabled or disabled independently from eachother and for X an
by setting the parameters `center_X`, `center_Y`, `scale_X`, and `scale_Y`, respectively.
In addition to correctly handling (column-wise) centering and scaling,
the fast cross-validation algorithm **correctly handles row-wise preprocessing**
such as (row-wise) centering and scaling of the X and Y input matrices,
convolution, or other preprocessing. Row-wise preprocessing can safely be
applied before passing the data to the fast cross-validation algorithm.
that operates independently on each sample such as (row-wise) centering and scaling
of the X and Y input matrices, convolution, or other preprocessing. Row-wise
preprocessing can safely be applied before passing the data to the fast
cross-validation algorithm.

## Prerequisites

Expand Down
2 changes: 1 addition & 1 deletion ikpls/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "1.2.4"
__version__ = "1.2.5"
54 changes: 27 additions & 27 deletions ikpls/fast_cross_validation/numpy_ikpls.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def __init__(
scale_X: bool = True,
scale_Y: bool = True,
algorithm: int = 1,
dtype: np.float_ = np.float64,
dtype: np.floating = np.float64,
) -> None:
self.center_X = center_X
self.center_Y = center_Y
Expand Down Expand Up @@ -139,27 +139,27 @@ def _stateless_fit(
validation_indices: npt.NDArray[np.int_],
) -> Union[
tuple[
npt.NDArray[np.float_],
npt.NDArray[np.float_],
npt.NDArray[np.float_],
npt.NDArray[np.float_],
npt.NDArray[np.float_],
npt.NDArray[np.float_],
npt.NDArray[np.float_],
npt.NDArray[np.float_],
npt.NDArray[np.float_],
npt.NDArray[np.float_],
npt.NDArray[np.floating],
npt.NDArray[np.floating],
npt.NDArray[np.floating],
npt.NDArray[np.floating],
npt.NDArray[np.floating],
npt.NDArray[np.floating],
npt.NDArray[np.floating],
npt.NDArray[np.floating],
npt.NDArray[np.floating],
npt.NDArray[np.floating],
],
tuple[
npt.NDArray[np.float_],
npt.NDArray[np.float_],
npt.NDArray[np.float_],
npt.NDArray[np.float_],
npt.NDArray[np.float_],
npt.NDArray[np.float_],
npt.NDArray[np.float_],
npt.NDArray[np.float_],
npt.NDArray[np.float_],
npt.NDArray[np.floating],
npt.NDArray[np.floating],
npt.NDArray[np.floating],
npt.NDArray[np.floating],
npt.NDArray[np.floating],
npt.NDArray[np.floating],
npt.NDArray[np.floating],
npt.NDArray[np.floating],
npt.NDArray[np.floating],
],
]:
"""
Expand Down Expand Up @@ -433,13 +433,13 @@ def _stateless_fit(
def _stateless_predict(
self,
indices: npt.NDArray[np.int_],
B: npt.NDArray[np.float_],
training_X_mean: npt.NDArray[np.float_],
training_Y_mean: npt.NDArray[np.float_],
training_X_std: npt.NDArray[np.float_],
training_Y_std: npt.NDArray[np.float_],
B: npt.NDArray[np.floating],
training_X_mean: npt.NDArray[np.floating],
training_Y_mean: npt.NDArray[np.floating],
training_X_std: npt.NDArray[np.floating],
training_Y_std: npt.NDArray[np.floating],
n_components: Union[None, int] = None,
) -> npt.NDArray[np.float_]:
) -> npt.NDArray[np.floating]:
"""
Predicts with Improved Kernel PLS Algorithm #1 on `X` with `B` using
`n_components` components. If `n_components` is None, then predictions are
Expand Down Expand Up @@ -503,7 +503,7 @@ def _stateless_fit_predict_eval(
self,
validation_indices: npt.NDArray[np.int_],
metric_function: Callable[
[npt.NDArray[np.float_], npt.NDArray[np.float_]], Any
[npt.NDArray[np.floating], npt.NDArray[np.floating]], Any
],
) -> Any:
"""
Expand Down
2 changes: 1 addition & 1 deletion ikpls/jax_ikpls_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ def __init__(
self.X_std = None
self.Y_std = None

def _weight_warning(self, arg: Tuple[npt.NDArray[np.int_], npt.NDArray[np.float_]]):
def _weight_warning(self, arg: Tuple[npt.NDArray[np.int_], npt.NDArray[np.floating]]):
"""
Display a warning message if the weight is close to zero.
Expand Down
4 changes: 2 additions & 2 deletions ikpls/numpy_ikpls.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def __init__(
scale_X: bool = True,
scale_Y: bool = True,
copy: bool = True,
dtype: np.float_ = np.float64,
dtype: np.floating = np.float64,
) -> None:
self.algorithm = algorithm
self.center_X = center_X
Expand Down Expand Up @@ -300,7 +300,7 @@ def fit(self, X: npt.ArrayLike, Y: npt.ArrayLike, A: int) -> None:

def predict(
self, X: npt.ArrayLike, n_components: Union[None, int] = None
) -> npt.NDArray[np.float_]:
) -> npt.NDArray[np.floating]:
"""
Predicts with Improved Kernel PLS Algorithm #1 on `X` with `B` using
`n_components` components. If `n_components` is None, then predictions are
Expand Down
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "ikpls"
version = "1.2.4"
version = "1.2.5"
description = "Improved Kernel PLS and Fast Cross-Validation."
authors = ["Sm00thix <[email protected]>"]
maintainers = ["Sm00thix <[email protected]>"]
Expand All @@ -11,7 +11,7 @@ repository = "https://github.com/Sm00thix/IKPLS"

[tool.poetry.dependencies]
python = ">=3.9, <3.13"
numpy = "^1.26.3"
numpy = ">=1.26.3"
jax = "^0.4.20"
jaxlib = "^0.4.20"
scikit-learn = "^1.5.0"
Expand Down
2 changes: 1 addition & 1 deletion tests/load_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def load_spectra():
resp_byte_array = resp.read()
byte_contents = io.BytesIO(resp_byte_array)
npz_arr = np.load(byte_contents)
spectra = np.row_stack([npz_arr[k] for k in npz_arr.keys()])
spectra = np.vstack([npz_arr[k] for k in npz_arr.keys()])
spectra = spectra.astype(np.float64)
spectra = -np.log10(spectra)
return spectra
54 changes: 27 additions & 27 deletions tests/test_ikpls.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,27 +47,27 @@ class TestClass:
csv : DataFrame
The CSV data containing target values.
raw_spectra : NDArray[float]
raw_spectra : npt.NDArray[np.float64]
The raw spectral data.
"""

csv = load_data.load_csv()
raw_spectra = load_data.load_spectra()

def load_X(self) -> npt.NDArray[np.float_]:
def load_X(self) -> npt.NDArray[np.float64]:
"""
Description
-----------
Load the raw spectral data.
Returns
-------
npt.NDArray[np.float_]
npt.NDArray[np.float64]
The raw spectral data.
"""
return np.copy(self.raw_spectra)

def load_Y(self, values: list[str]) -> npt.NDArray[np.float_]:
def load_Y(self, values: list[str]) -> npt.NDArray[np.float64]:
"""
Description
-----------
Expand All @@ -80,7 +80,7 @@ def load_Y(self, values: list[str]) -> npt.NDArray[np.float_]:
Returns
-------
NDArray[float]
npt.NDArray[np.float64]
Target values as a NumPy array.
"""
target_values = self.csv[values].to_numpy()
Expand Down Expand Up @@ -895,8 +895,8 @@ def test_pls_1(self) -> None:
jax_pls_alg_2=jax_pls_alg_2,
diff_jax_pls_alg_1=diff_jax_pls_alg_1,
diff_jax_pls_alg_2=diff_jax_pls_alg_2,
atol=1e-8,
rtol=6e-5,
atol=3e-8,
rtol=2e-4,
)

self.check_predictions(
Expand Down Expand Up @@ -963,8 +963,8 @@ def test_pls_1(self) -> None:
jax_pls_alg_2=jax_pls_alg_2,
diff_jax_pls_alg_1=diff_jax_pls_alg_1,
diff_jax_pls_alg_2=diff_jax_pls_alg_2,
atol=1e-8,
rtol=6e-5,
atol=3e-8,
rtol=2e-4,
)

self.check_predictions(
Expand Down Expand Up @@ -3178,26 +3178,26 @@ def test_fast_cross_val_pls_1(self):
splits = self.load_Y(["split"])
assert Y.shape[1] == 1
self.check_fast_cross_val_pls(
X, Y, splits, center=False, scale=False, atol=0, rtol=1e-8
X, Y, splits, center=False, scale=False, atol=0, rtol=1e-7
)
self.check_fast_cross_val_pls(
X, Y, splits, center=True, scale=False, atol=0, rtol=1e-8
X, Y, splits, center=True, scale=False, atol=0, rtol=1e-7
)
self.check_fast_cross_val_pls(
X, Y, splits, center=True, scale=True, atol=0, rtol=1e-8
X, Y, splits, center=True, scale=True, atol=0, rtol=1e-7
)

# Remove the singleton dimension and check that the predictions are consistent.
Y = Y.squeeze()
assert Y.ndim == 1
self.check_fast_cross_val_pls(
X, Y, splits, center=False, scale=False, atol=0, rtol=1e-8
X, Y, splits, center=False, scale=False, atol=0, rtol=1e-7
)
self.check_fast_cross_val_pls(
X, Y, splits, center=True, scale=False, atol=0, rtol=1e-8
X, Y, splits, center=True, scale=False, atol=0, rtol=1e-7
)
self.check_fast_cross_val_pls(
X, Y, splits, center=True, scale=True, atol=0, rtol=1e-8
X, Y, splits, center=True, scale=True, atol=0, rtol=1e-7
)

# JAX will issue a warning if os.fork() is called as JAX is incompatible with
Expand Down Expand Up @@ -3243,13 +3243,13 @@ def test_fast_cross_val_pls_2_m_less_k(self):
assert Y.shape[1] > 1
assert Y.shape[1] < X.shape[1]
self.check_fast_cross_val_pls(
X, Y, splits, center=False, scale=False, atol=0, rtol=1e-7
X, Y, splits, center=False, scale=False, atol=0, rtol=1e-6
)
self.check_fast_cross_val_pls(
X, Y, splits, center=True, scale=False, atol=0, rtol=1e-7
X, Y, splits, center=True, scale=False, atol=0, rtol=1e-6
)
self.check_fast_cross_val_pls(
X, Y, splits, center=True, scale=True, atol=0, rtol=1e-7
X, Y, splits, center=True, scale=True, atol=0, rtol=1e-6
)

# JAX will issue a warning if os.fork() is called as JAX is incompatible with
Expand Down Expand Up @@ -3296,13 +3296,13 @@ def test_fast_cross_val_pls_2_m_eq_k(self):
assert Y.shape[1] > 1
assert Y.shape[1] == X.shape[1]
self.check_fast_cross_val_pls(
X, Y, splits, center=False, scale=False, atol=0, rtol=1e-8
X, Y, splits, center=False, scale=False, atol=0, rtol=2e-8
)
self.check_fast_cross_val_pls(
X, Y, splits, center=True, scale=False, atol=0, rtol=1e-8
X, Y, splits, center=True, scale=False, atol=0, rtol=2e-8
)
self.check_fast_cross_val_pls(
X, Y, splits, center=True, scale=True, atol=0, rtol=1e-8
X, Y, splits, center=True, scale=True, atol=0, rtol=2e-8
)

# JAX will issue a warning if os.fork() is called as JAX is incompatible with
Expand Down Expand Up @@ -3448,13 +3448,13 @@ def test_fast_cross_val_pls_2_m_less_k_loocv(self):
assert Y.shape[1] > 1
assert Y.shape[1] < X.shape[1]
self.check_fast_cross_val_pls(
X, Y, splits, center=False, scale=False, atol=2e-6, rtol=1e-8
X, Y, splits, center=False, scale=False, atol=1e-4, rtol=2e-8
)
self.check_fast_cross_val_pls(
X, Y, splits, center=True, scale=False, atol=5e-6, rtol=1e-8
X, Y, splits, center=True, scale=False, atol=1e-4, rtol=2e-8
)
self.check_fast_cross_val_pls(
X, Y, splits, center=True, scale=True, atol=3e-6, rtol=1e-8
X, Y, splits, center=True, scale=True, atol=1e-4, rtol=2e-8
)

def test_fast_cross_val_pls_2_m_eq_k_loocv(self):
Expand Down Expand Up @@ -3494,10 +3494,10 @@ def test_fast_cross_val_pls_2_m_eq_k_loocv(self):
assert Y.shape[1] > 1
assert Y.shape[1] == X.shape[1]
self.check_fast_cross_val_pls(
X, Y, splits, center=False, scale=False, atol=1e-7, rtol=1e-8
X, Y, splits, center=False, scale=False, atol=2e-7, rtol=1e-8
)
self.check_fast_cross_val_pls(
X, Y, splits, center=True, scale=False, atol=1e-7, rtol=1e-8
X, Y, splits, center=True, scale=False, atol=2e-7, rtol=1e-8
)
self.check_fast_cross_val_pls(
X, Y, splits, center=True, scale=True, atol=1e-7, rtol=1e-8
Expand Down Expand Up @@ -4111,7 +4111,7 @@ def test_center_scale_combinations_pls_2_m_eq_k(self):
splits = self.load_Y(["split"]) # Contains 3 splits of different sizes
assert Y.shape[1] > 1
assert Y.shape[1] == X.shape[1]
self.check_center_scale_combinations(X, Y, splits, atol=0, rtol=1e-8)
self.check_center_scale_combinations(X, Y, splits, atol=0, rtol=3e-8)

# JAX will issue a warning if os.fork() is called as JAX is incompatible with
# multi-threaded code. os.fork() is called by the other cross-validation
Expand Down

0 comments on commit d11746d

Please sign in to comment.