diff --git a/.github/actions/build/action.yml b/.github/actions/build/action.yml
index efb2c29..7d5532b 100644
--- a/.github/actions/build/action.yml
+++ b/.github/actions/build/action.yml
@@ -24,7 +24,7 @@ runs:
        # stop the build if there are Python syntax errors or undefined names
        flake8 ikpls/ --count --select=E9,F63,F7,F82 --show-source --statistics
        # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
-       flake8 ikpls/ --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
+       flake8 ikpls/ --count --exit-zero --max-complexity=10 --max-line-length=88 --statistics
      shell: bash
    - name: Install IKPLS dependencies
diff --git a/examples/fast_cross_val_numpy.py b/examples/fast_cross_val_numpy.py
index 5de4566..ea3aa84 100644
--- a/examples/fast_cross_val_numpy.py
+++ b/examples/fast_cross_val_numpy.py
@@ -107,7 +107,8 @@ def mse_for_each_target(Y_true, Y_pred):
    best_num_components = np.asarray(
        [
            [
-                np_pls_alg_1_fast_cv_results[split][f'num_components_lowest_mse_target_{i}']
+                np_pls_alg_1_fast_cv_results[split]
+                [f'num_components_lowest_mse_target_{i}']
                for split in unique_splits
            ]
            for i in range(M)
diff --git a/ikpls/fast_cross_validation/numpy_ikpls.py b/ikpls/fast_cross_validation/numpy_ikpls.py
index ddc647b..d972b2b 100644
--- a/ikpls/fast_cross_validation/numpy_ikpls.py
+++ b/ikpls/fast_cross_validation/numpy_ikpls.py
@@ -34,18 +34,18 @@ class PLS:
        Whether to center `X` before fitting by subtracting its row of column-wise
        means from each row. The row of column-wise means is computed on the training
        set for each fold to avoid data leakage.
-    
+
    center_Y : bool, optional default=True
        Whether to center `Y` before fitting by subtracting its row of column-wise
        means from each row. The row of column-wise means is computed on the training
        set for each fold to avoid data leakage.
-    
+
    scale_X : bool, optional default=True
        Whether to scale `X` before fitting by dividing each row with the row of `X`'s
        column-wise standard deviations. Bessel's correction for the unbiased estimate
        of the sample standard deviation is used. The row of column-wise standard
        deviations is computed on the training set for each fold to avoid data leakage.
-    
+
    scale_Y : bool, optional default=True
        Whether to scale `Y` before fitting by dividing each row with the row of `X`'s
        column-wise standard deviations. Bessel's correction for the unbiased estimate
@@ -64,7 +64,7 @@ class PLS:
    ------
    ValueError
        If `algorithm` is not 1 or 2.
-    
+
    Notes
    -----
    Any centering and scaling is undone before returning predictions to ensure that
@@ -482,7 +482,7 @@ def _stateless_predict(
            predictions for that specific number of components is used. If
            `n_components` is None, returns a prediction for each number of components
            up to `A`.
-    
+
        Notes
        -----
        """
diff --git a/ikpls/jax_ikpls_alg_1.py b/ikpls/jax_ikpls_alg_1.py
index 719bbe2..4ae0a15 100644
--- a/ikpls/jax_ikpls_alg_1.py
+++ b/ikpls/jax_ikpls_alg_1.py
@@ -31,11 +31,11 @@ class PLS(PLSBase):
    center_X : bool, default=True
        Whether to center `X` before fitting by subtracting its row of column-wise
        means from each row.
-    
+
    center_Y : bool, default=True
        Whether to center `Y` before fitting by subtracting its row of column-wise
        means from each row.
-    
+
    scale_X : bool, default=True
        Whether to scale `X` before fitting by dividing each row with the row of `X`'s
        column-wise standard deviations. Bessel's correction for the unbiased estimate
@@ -413,11 +413,11 @@ def stateless_fit(
        center_X : bool, default=True
            Whether to center `X` before fitting by subtracting its row of column-wise
            means from each row.
-    
+
        center_Y : bool, default=True
            Whether to center `Y` before fitting by subtracting its row of column-wise
            means from each row.
-    
+
        scale_X : bool, default=True
            Whether to scale `X` before fitting by dividing each row with the row of
            `X`'s column-wise standard deviations. Bessel's correction for the unbiased
diff --git a/ikpls/jax_ikpls_alg_2.py b/ikpls/jax_ikpls_alg_2.py
index 628c833..a57376b 100644
--- a/ikpls/jax_ikpls_alg_2.py
+++ b/ikpls/jax_ikpls_alg_2.py
@@ -31,11 +31,11 @@ class PLS(PLSBase):
    center_X : bool, default=True
        Whether to center `X` before fitting by subtracting its row of column-wise
        means from each row.
-    
+
    center_Y : bool, default=True
        Whether to center `Y` before fitting by subtracting its row of column-wise
        means from each row.
-    
+
    scale_X : bool, default=True
        Whether to scale `X` before fitting by dividing each row with the row of `X`'s
        column-wise standard deviations. Bessel's correction for the unbiased estimate
@@ -389,11 +389,11 @@ def stateless_fit(
        center_X : bool, default=True
            Whether to center `X` before fitting by subtracting its row of column-wise
            means from each row.
-    
+
        center_Y : bool, default=True
            Whether to center `Y` before fitting by subtracting its row of column-wise
            means from each row.
-    
+
        scale_X : bool, default=True
            Whether to scale `X` before fitting by dividing each row with the row of
            `X`'s column-wise standard deviations. Bessel's correction for the unbiased
diff --git a/ikpls/jax_ikpls_base.py b/ikpls/jax_ikpls_base.py
index 8d76683..e7d16f7 100644
--- a/ikpls/jax_ikpls_base.py
+++ b/ikpls/jax_ikpls_base.py
@@ -5,7 +5,7 @@
 Implementations of concrete classes exist for both Improved Kernel PLS Algorithm #1
 and Improved Kernel PLS Algorithm #2.

-For more details, refer to the paper: 
+For more details, refer to the paper:
 "Improved Kernel Partial Least Squares Regression" by Dayal and MacGregor.

 Author: Ole-Christian Galbo Engstrøm
@@ -669,11 +669,11 @@ def stateless_fit(
        center_X : bool, default=True
            Whether to center `X` before fitting by subtracting its row of column-wise
            means from each row.
-    
+
        center_Y : bool, default=True
            Whether to center `Y` before fitting by subtracting its row of column-wise
            means from each row.
-    
+
        scale_X : bool, default=True
            Whether to scale `X` before fitting by dividing each row with the row of
            `X`'s column-wise standard deviations. Bessel's correction for the unbiased
@@ -949,11 +949,11 @@ def stateless_fit_predict_eval(
        center_X : bool, default=True
            Whether to center `X` before fitting by subtracting its row of column-wise
            means from each row.
-    
+
        center_Y : bool, default=True
            Whether to center `Y` before fitting by subtracting its row of column-wise
            means from each row.
-    
+
        scale_X : bool, default=True
            Whether to scale `X` before fitting by dividing each row with the row of
            `X`'s column-wise standard deviations. Bessel's correction for the unbiased
@@ -1167,11 +1167,11 @@ def _inner_cross_validate(
        center_X : bool, default=True
            Whether to center `X` before fitting by subtracting its row of column-wise
            means from each row.
-    
+
        center_Y : bool, default=True
            Whether to center `Y` before fitting by subtracting its row of column-wise
            means from each row.
-    
+
        scale_X : bool, default=True
            Whether to scale `X` before fitting by dividing each row with the row of
            `X`'s column-wise standard deviations. Bessel's correction for the unbiased
diff --git a/ikpls/numpy_ikpls.py b/ikpls/numpy_ikpls.py
index e0c2581..27d8dca 100644
--- a/ikpls/numpy_ikpls.py
+++ b/ikpls/numpy_ikpls.py
@@ -29,15 +29,15 @@ class PLS(BaseEstimator):
    ----------
    algorithm : int, default=1
        Whether to use Improved Kernel PLS Algorithm #1 or #2.
-    
+
    center_X : bool, default=True
        Whether to center `X` before fitting by subtracting its row of column-wise
        means from each row.
-    
+
    center_Y : bool, default=True
        Whether to center `Y` before fitting by subtracting its row of column-wise
        means from each row.
-    
+
    scale_X : bool, default=True
        Whether to scale `X` before fitting by dividing each row with the row of `X`'s
        column-wise standard deviations. Bessel's correction for the unbiased estimate
diff --git a/tests/test_ikpls.py b/tests/test_ikpls.py
index 51c2ed1..ca206e8 100644
--- a/tests/test_ikpls.py
+++ b/tests/test_ikpls.py
@@ -3613,7 +3613,7 @@ def rmse_per_component(Y_true: npt.NDArray, Y_pred: npt.NDArray) -> npt.NDArray:
            mse = np.mean(se, axis=-2)
            rmse = np.sqrt(mse)
            return rmse
-    
+
        def jax_rmse_per_component(
            Y_true: npt.NDArray, Y_pred: npt.NDArray
        ) -> npt.NDArray:
@@ -3624,7 +3624,7 @@ def jax_rmse_per_component(
            mse = np.mean(se, axis=-2)
            rmse = jnp.sqrt(mse)
            return rmse
-    
+
        jnp_splits = jnp.array(splits)

        def cv_splitter(splits: npt.NDArray):
@@ -3753,14 +3753,15 @@ def cv_splitter(splits: npt.NDArray):
        # Convert the results from dict to list for easier comparison
        fast_cv_np_pls_alg_1_results = [
-            np.asarray(value) for value in fast_cv_np_pls_alg_1_results.values()
+            np.asarray(value) for value
+            in fast_cv_np_pls_alg_1_results.values()
        ]
        fast_cv_np_pls_alg_2_results = [
            np.asarray(value) for value in fast_cv_np_pls_alg_2_results.values()
        ]

-        # Sort fast cv results according to the unique splits for comparison with the
-        # other algorithms
+        # Sort fast cv results according to the unique splits for comparison with
+        # the other algorithms
        unique_splits, sort_indices = np.unique(splits, return_index=True)
        unique_splits = unique_splits.astype(int)
        fast_cv_order = np.argsort(sort_indices)
@@ -3810,15 +3811,21 @@ def cv_splitter(splits: npt.NDArray):
            ["RMSE"],
        )

-        # Get the best number of components in terms of minimizing validation RMSE for
-        # each split is equal among all algorithms
+        # Get the best number of components in terms of minimizing validation RMSE
+        # for each split is equal among all algorithms
        unique_splits = np.unique(splits).astype(int)
        np_pls_alg_1_best_num_components = [
-            [np.argmin(np_pls_alg_1_rmses[split][..., i]) for split in unique_splits]
+            [
+                np.argmin(np_pls_alg_1_rmses[split][..., i])
+                for split in unique_splits
+            ]
            for i in range(M)
        ]
        np_pls_alg_2_best_num_components = [
-            [np.argmin(np_pls_alg_2_rmses[split][..., i]) for split in unique_splits]
+            [
+                np.argmin(np_pls_alg_2_rmses[split][..., i])
+                for split in unique_splits
+            ]
            for i in range(M)
        ]
        fast_cv_np_pls_alg_1_best_num_components = [
@@ -3865,14 +3872,16 @@ def cv_splitter(splits: npt.NDArray):
        ]
        np_pls_alg_1_best_rmses = [
            [
-                np_pls_alg_1_rmses[split][np_pls_alg_1_best_num_components[i][split], i]
+                np_pls_alg_1_rmses[split]
+                [np_pls_alg_1_best_num_components[i][split], i]
                for split in unique_splits
            ]
            for i in range(M)
        ]
        np_pls_alg_2_best_rmses = [
            [
-                np_pls_alg_2_rmses[split][np_pls_alg_2_best_num_components[i][split], i]
+                np_pls_alg_2_rmses[split]
+                [np_pls_alg_2_best_num_components[i][split], i]
                for split in unique_splits
            ]
            for i in range(M)
@@ -3932,22 +3941,49 @@ def cv_splitter(splits: npt.NDArray):
            for i in range(M)
        ]
-        assert_allclose(np_pls_alg_2_best_rmses, np_pls_alg_1_best_rmses, atol=atol, rtol=rtol)
        assert_allclose(
-            fast_cv_np_pls_alg_1_best_rmses, np_pls_alg_1_best_rmses, atol=atol, rtol=rtol
+            np_pls_alg_2_best_rmses,
+            np_pls_alg_1_best_rmses,
+            atol=atol,
+            rtol=rtol
+        )
+        assert_allclose(
+            fast_cv_np_pls_alg_1_best_rmses,
+            np_pls_alg_1_best_rmses,
+            atol=atol,
+            rtol=rtol
+        )
+        assert_allclose(
+            fast_cv_np_pls_alg_2_best_rmses,
+            np_pls_alg_1_best_rmses,
+            atol=atol,
+            rtol=rtol
        )
        assert_allclose(
-            fast_cv_np_pls_alg_2_best_rmses, np_pls_alg_1_best_rmses, atol=atol, rtol=rtol
+            jax_pls_alg_1_best_rmses,
+            np_pls_alg_1_best_rmses,
+            atol=atol,
+            rtol=rtol
        )
-        assert_allclose(jax_pls_alg_1_best_rmses, np_pls_alg_1_best_rmses, atol=atol, rtol=rtol)
-        assert_allclose(jax_pls_alg_2_best_rmses, np_pls_alg_1_best_rmses, atol=atol, rtol=rtol)
        assert_allclose(
-            diff_jax_pls_alg_1_best_rmses, np_pls_alg_1_best_rmses, atol=atol, rtol=rtol
+            jax_pls_alg_2_best_rmses,
+            np_pls_alg_1_best_rmses,
+            atol=atol,
+            rtol=rtol
        )
        assert_allclose(
-            diff_jax_pls_alg_2_best_rmses, np_pls_alg_1_best_rmses, atol=atol, rtol=rtol
+            diff_jax_pls_alg_1_best_rmses,
+            np_pls_alg_1_best_rmses,
+            atol=atol,
+            rtol=rtol
        )
-    
+        assert_allclose(
+            diff_jax_pls_alg_2_best_rmses,
+            np_pls_alg_1_best_rmses,
+            atol=atol,
+            rtol=rtol
+        )
+
        # JAX will issue a warning if os.fork() is called as JAX is incompatible with
        # multi-threaded code. os.fork() is called by the other cross-validation
        # algorithms. However, there is no interaction between the JAX and the other
@@ -3965,17 +4001,17 @@ def test_center_scale_combinations_pls_1(self):
        Description
        -----------
        This test loads input predictor variables, a single target variable, and split
-        indices for cross-validation. It then calls the 'check_center_scale_combinations'
-        method to validate the cross-validation results for all possible combinations of
-        centering and scaling.
+        indices for cross-validation. It then calls the
+        `check_center_scale_combinations` method to validate the cross-validation
+        results for all possible combinations of centering and scaling.

        Returns:
        None
        """
        X = self.load_X()
-        X = X[..., :10] # Decrease the amount of features in the interest of time.
+        X = X[..., :10]  # Decrease the amount of features in the interest of time.
        Y = self.load_Y(["Protein"])
-        splits = self.load_Y(["split"]) # Contains 3 splits of different sizes
+        splits = self.load_Y(["split"])  # Contains 3 splits of different sizes
        assert Y.shape[1] == 1

        self.check_center_scale_combinations(X, Y, splits, atol=0, rtol=1e-8)
@@ -4002,14 +4038,14 @@ def test_center_scale_combinations_pls_2_m_less_k(self):
        -----------
        This test loads input predictor variables, multiple target variables (where M
        is less than K), and split indices for cross-validation. It then calls the
-        'check_center_scale_combinations' method to validate the cross-validation results
-        for all possible combinations of centering and scaling.
+        `check_center_scale_combinations` method to validate the cross-validation
+        results for all possible combinations of centering and scaling.

        Returns:
        None
        """
        X = self.load_X()
-        X = X[..., :11] # Decrease the amount of features in the interest of time.
+        X = X[..., :11]  # Decrease the amount of features in the interest of time.
        Y = self.load_Y(
            [
                "Rye_Midsummer",
@@ -4024,11 +4060,11 @@ def test_center_scale_combinations_pls_2_m_less_k(self):
                "Protein",
            ]
        )
-        splits = self.load_Y(["split"]) # Contains 3 splits of different sizes
+        splits = self.load_Y(["split"])  # Contains 3 splits of different sizes
        assert Y.shape[1] > 1
        assert Y.shape[1] < X.shape[1]
        self.check_center_scale_combinations(X, Y, splits, atol=0, rtol=1e-7)
-    
+
    # JAX will issue a warning if os.fork() is called as JAX is incompatible with
    # multi-threaded code. os.fork() is called by the other cross-validation
    # algorithms. However, there is no interaction between the JAX and the other
@@ -4047,8 +4083,8 @@ def test_center_scale_combinations_pls_2_m_eq_k(self):
        -----------
        This test loads input predictor variables, multiple target variables (where M
        is equal to K), and split indices for cross-validation. It then calls the
-        'check_center_scale_combinations' method to validate the cross-validation results
-        for all possible combinations of centering and scaling.
+        `check_center_scale_combinations` method to validate the cross-validation
+        results for all possible combinations of centering and scaling.

        Returns:
        None
@@ -4069,11 +4105,11 @@ def test_center_scale_combinations_pls_2_m_eq_k(self):
            ]
        )
        X = X[..., :10]
-        splits = self.load_Y(["split"]) # Contains 3 splits of different sizes
+        splits = self.load_Y(["split"])  # Contains 3 splits of different sizes
        assert Y.shape[1] > 1
        assert Y.shape[1] == X.shape[1]
        self.check_center_scale_combinations(X, Y, splits, atol=0, rtol=1e-8)
-    
+
    # JAX will issue a warning if os.fork() is called as JAX is incompatible with
    # multi-threaded code. os.fork() is called by the other cross-validation
    # algorithms. However, there is no interaction between the JAX and the other
@@ -4092,8 +4128,8 @@ def test_center_scale_combinations_pls_2_m_greater_k(self):
        -----------
        This test loads input predictor variables, multiple target variables (where M
        is greater than K), and split indices for cross-validation. It then calls the
-        'check_center_scale_combinations' method to validate the cross-validation results
-        for all possible combinations of centering and scaling.
+        `check_center_scale_combinations` method to validate the cross-validation
+        results for all possible combinations of centering and scaling.

        Returns:
        None
@@ -4114,7 +4150,7 @@ def test_center_scale_combinations_pls_2_m_greater_k(self):
            ]
        )
        X = X[..., :9]
-        splits = self.load_Y(["split"]) # Contains 3 splits of different sizes
+        splits = self.load_Y(["split"])  # Contains 3 splits of different sizes
        assert Y.shape[1] > 1
        assert Y.shape[1] > X.shape[1]
-        self.check_center_scale_combinations(X, Y, splits, atol=0, rtol=1e-8)
\ No newline at end of file
+        self.check_center_scale_combinations(X, Y, splits, atol=0, rtol=1e-8)
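The docstrings touched above all describe the same preprocessing contract: column-wise means and standard deviations are computed on the training part of each fold only (to avoid data leakage), the standard deviations use Bessel's correction, and any centering and scaling of `Y` is undone before predictions are returned. The short NumPy sketch below illustrates that contract only; it is not the ikpls implementation, the fold is hypothetical, and ordinary least squares stands in for the PLS fit purely to keep the example self-contained.

import numpy as np

rng = np.random.default_rng(0)
X = rng.normal(size=(100, 5))   # predictors
Y = rng.normal(size=(100, 2))   # targets

val_mask = np.zeros(100, dtype=bool)
val_mask[:20] = True            # hypothetical validation fold
X_train, Y_train = X[~val_mask], Y[~val_mask]
X_val = X[val_mask]

# Statistics come from the training fold only, so no information leaks
# from the validation fold into the preprocessing.
X_mean = X_train.mean(axis=0, keepdims=True)
Y_mean = Y_train.mean(axis=0, keepdims=True)
X_std = X_train.std(axis=0, ddof=1, keepdims=True)  # Bessel's correction
Y_std = Y_train.std(axis=0, ddof=1, keepdims=True)

X_train_p = (X_train - X_mean) / X_std
Y_train_p = (Y_train - Y_mean) / Y_std
X_val_p = (X_val - X_mean) / X_std  # validation data reuses training statistics

# Ordinary least squares stands in for the PLS fit to keep the sketch short.
B, *_ = np.linalg.lstsq(X_train_p, Y_train_p, rcond=None)

# Undo the centering and scaling of Y before returning predictions.
Y_val_pred = (X_val_p @ B) * Y_std + Y_mean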