From f1a60ea5de7e736c88d2a431e9f813973fd750c6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ole=20Engstr=C3=B8m?= Date: Sun, 7 Jul 2024 22:26:58 +0200 Subject: [PATCH] Now comparing std to 0 using float epsilon instead of direct comparison. This seems to have been the cause of #33 --- .github/workflows/pull_request_test_workflow.yml | 2 +- .github/workflows/test_workflow.yml | 2 +- ikpls/__init__.py | 2 +- ikpls/fast_cross_validation/numpy_ikpls.py | 11 ++++++----- ikpls/jax_ikpls_base.py | 3 ++- ikpls/numpy_ikpls.py | 11 ++++++----- pyproject.toml | 2 +- 7 files changed, 18 insertions(+), 15 deletions(-) diff --git a/.github/workflows/pull_request_test_workflow.yml b/.github/workflows/pull_request_test_workflow.yml index 31c3547..4a4cfe8 100644 --- a/.github/workflows/pull_request_test_workflow.yml +++ b/.github/workflows/pull_request_test_workflow.yml @@ -15,7 +15,7 @@ jobs: strategy: fail-fast: false matrix: - os: [ubuntu-latest, windows-latest, macos-12] + os: [ubuntu-latest, windows-latest, macos-latest] python-version: ["3.9", "3.10", "3.11", "3.12"] steps: - uses: actions/checkout@v4 diff --git a/.github/workflows/test_workflow.yml b/.github/workflows/test_workflow.yml index ae61d01..401a686 100644 --- a/.github/workflows/test_workflow.yml +++ b/.github/workflows/test_workflow.yml @@ -15,7 +15,7 @@ jobs: strategy: fail-fast: false matrix: - os: [ubuntu-latest, windows-latest, macos-12] + os: [ubuntu-latest, windows-latest, macos-latest] python-version: ["3.9", "3.10", "3.11", "3.12"] steps: - uses: actions/checkout@v4 diff --git a/ikpls/__init__.py b/ikpls/__init__.py index 10aa336..b3f9ac7 100644 --- a/ikpls/__init__.py +++ b/ikpls/__init__.py @@ -1 +1 @@ -__version__ = "1.2.3" +__version__ = "1.2.4" diff --git a/ikpls/fast_cross_validation/numpy_ikpls.py b/ikpls/fast_cross_validation/numpy_ikpls.py index d972b2b..9c4208c 100644 --- a/ikpls/fast_cross_validation/numpy_ikpls.py +++ b/ikpls/fast_cross_validation/numpy_ikpls.py @@ -92,6 +92,7 @@ def __init__( self.scale_Y = scale_Y self.algorithm = algorithm self.dtype = dtype + self.eps = np.finfo(dtype).eps self.name = f"Improved Kernel PLS Algorithm #{algorithm}" if self.algorithm not in [1, 2]: raise ValueError( @@ -269,7 +270,7 @@ def _stateless_fit( + train_sum_sq_X ) ) - training_X_std[training_X_std == 0] = 1 + training_X_std[np.abs(training_X_std) <= self.eps] = 1 # Compute the training set standard deviations for Y if self.scale_Y: @@ -289,7 +290,7 @@ def _stateless_fit( + train_sum_sq_Y ) ) - training_Y_std[training_Y_std == 0] = 1 + training_Y_std[np.abs(training_Y_std) <= self.eps] = 1 # Subtract the validation set's contribution from the total XTY training_XTY = self.XTY - validation_X.T @ validation_Y @@ -341,7 +342,7 @@ def _stateless_fit( # Step 2 if self.M == 1: norm = la.norm(training_XTY, ord=2) - if np.isclose(norm, 0, atol=np.finfo(np.float64).eps, rtol=0): + if np.isclose(norm, 0, atol=self.eps, rtol=0): self._weight_warning(i) break w = training_XTY / norm @@ -352,7 +353,7 @@ def _stateless_fit( q = eig_vecs[:, -1:] w = training_XTY @ q norm = la.norm(w) - if np.isclose(norm, 0, atol=np.finfo(np.float64).eps, rtol=0): + if np.isclose(norm, 0, atol=self.eps, rtol=0): self._weight_warning(i) break w = w / norm @@ -360,7 +361,7 @@ def _stateless_fit( training_XTYYTX = training_XTY @ training_XTY.T eig_vals, eig_vecs = la.eigh(training_XTYYTX) norm = eig_vals[-1] - if np.isclose(norm, 0, atol=np.finfo(np.float64).eps, rtol=0): + if np.isclose(norm, 0, atol=self.eps, rtol=0): self._weight_warning(i) break w = eig_vecs[:, -1:] diff --git a/ikpls/jax_ikpls_base.py b/ikpls/jax_ikpls_base.py index 4a7fe89..f645dfd 100644 --- a/ikpls/jax_ikpls_base.py +++ b/ikpls/jax_ikpls_base.py @@ -101,6 +101,7 @@ def __init__( self.scale_Y = scale_Y self.copy = copy self.dtype = dtype + self.eps = jnp.finfo(self.dtype).eps self.reverse_differentiable = reverse_differentiable self.verbose = verbose self.name = "Improved Kernel PLS Algorithm" @@ -236,7 +237,7 @@ def get_std(self, A: ArrayLike): print(f"get_stds for {self.name} will be JIT compiled...") A_std = jnp.std(A, axis=0, dtype=self.dtype, keepdims=True, ddof=1) - A_std = jnp.where(A_std == 0, 1, A_std) + A_std = jnp.where(jnp.abs(A_std) <= self.eps, 1, A_std) return A_std @partial(jax.jit, static_argnums=(0, 3, 4, 5, 6, 7)) diff --git a/ikpls/numpy_ikpls.py b/ikpls/numpy_ikpls.py index 27d8dca..0ff7109 100644 --- a/ikpls/numpy_ikpls.py +++ b/ikpls/numpy_ikpls.py @@ -88,6 +88,7 @@ def __init__( self.scale_Y = scale_Y self.copy = copy self.dtype = dtype + self.eps = np.finfo(dtype).eps self.name = f"Improved Kernel PLS Algorithm #{algorithm}" if self.algorithm not in [1, 2]: raise ValueError( @@ -207,12 +208,12 @@ def fit(self, X: npt.ArrayLike, Y: npt.ArrayLike, A: int) -> None: if self.scale_X: self.X_std = X.std(axis=0, ddof=1, dtype=self.dtype, keepdims=True) - self.X_std[self.X_std == 0] = 1 + self.X_std[np.abs(self.X_std) <= self.eps] = 1 X /= self.X_std if self.scale_Y: self.Y_std = Y.std(axis=0, ddof=1, dtype=self.dtype, keepdims=True) - self.Y_std[self.Y_std == 0] = 1 + self.Y_std[np.abs(self.Y_std) <= self.eps] = 1 Y /= self.Y_std N, K = X.shape @@ -246,7 +247,7 @@ def fit(self, X: npt.ArrayLike, Y: npt.ArrayLike, A: int) -> None: # Step 2 if M == 1: norm = la.norm(XTY, ord=2) - if np.isclose(norm, 0, atol=np.finfo(np.float64).eps, rtol=0): + if np.isclose(norm, 0, atol=self.eps, rtol=0): self._weight_warning(i) break w = XTY / norm @@ -257,7 +258,7 @@ def fit(self, X: npt.ArrayLike, Y: npt.ArrayLike, A: int) -> None: q = eig_vecs[:, -1:] w = XTY @ q norm = la.norm(w) - if np.isclose(norm, 0, atol=np.finfo(np.float64).eps, rtol=0): + if np.isclose(norm, 0, atol=self.eps, rtol=0): self._weight_warning(i) break w = w / norm @@ -265,7 +266,7 @@ def fit(self, X: npt.ArrayLike, Y: npt.ArrayLike, A: int) -> None: XTYYTX = XTY @ XTY.T eig_vals, eig_vecs = la.eigh(XTYYTX) norm = eig_vals[-1] - if np.isclose(norm, 0, atol=np.finfo(np.float64).eps, rtol=0): + if np.isclose(norm, 0, atol=self.eps, rtol=0): self._weight_warning(i) break w = eig_vecs[:, -1:] diff --git a/pyproject.toml b/pyproject.toml index 3f262ce..00b72b1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "ikpls" -version = "1.2.3" +version = "1.2.4" description = "Improved Kernel PLS and Fast Cross-Validation." authors = ["Sm00thix "] maintainers = ["Sm00thix "]