diff --git a/arkane/encorr/bac.py b/arkane/encorr/bac.py index 4ee87e9a5d..f9b09f22c3 100644 --- a/arkane/encorr/bac.py +++ b/arkane/encorr/bac.py @@ -1047,20 +1047,21 @@ def fit(self, logging.info(f'RMSE/MAE before fitting: {stats_before.rmse:.2f}/{stats_before.mae:.2f} kcal/mol') logging.info(f'RMSE/MAE after fitting: {stats_after.rmse:.2f}/{stats_after.mae:.2f} kcal/mol') - rmse_before = [test_data.calculate_stats().rmse for test_data in test_data_results] - mae_before = [test_data.calculate_stats().mae for test_data in test_data_results] - rmse_after = [test_data.calculate_stats(for_bac_data=True).rmse for test_data in test_data_results] - mae_after = [test_data.calculate_stats(for_bac_data=True).mae for test_data in test_data_results] + num_test_data = sum(len(test_data) for test_data in test_data_results) + rmse_before = np.sqrt(np.sum([test_data.calculate_stats().rmse**2 * len(test_data) for test_data in test_data_results]) / num_test_data) + mae_before = np.sum([test_data.calculate_stats().mae * len(test_data) for test_data in test_data_results]) / num_test_data + rmse_after = np.sqrt(np.sum([test_data.calculate_stats(for_bac_data=True).rmse**2 * len(test_data) for test_data in test_data_results]) / num_test_data) + mae_after = np.sum([test_data.calculate_stats(for_bac_data=True).mae * len(test_data) for test_data in test_data_results]) / num_test_data logging.info('\nCross-validation results:') - logging.info(f'Testing RMSE before fitting (mean +- 1 std): ' - f'{np.average(rmse_before):.2f} +- {np.std(rmse_before):.2f} kcal/mol') - logging.info(f'Testing MAE before fitting (mean +- 1 std): ' - f'{np.average(mae_before):.2f} +- {np.std(mae_before):.2f} kcal/mol') - logging.info(f'Testing RMSE after fitting (mean +- 1 std): ' - f'{np.average(rmse_after):.2f} +- {np.std(rmse_after):.2f} kcal/mol') - logging.info(f'Testing MAE after fitting (mean +- 1 std): ' - f'{np.average(mae_after):.2f} +- {np.std(mae_after):.2f} kcal/mol') + logging.info(f'Testing RMSE before fitting: ' + f'{rmse_before:.2f} kcal/mol') + logging.info(f'Testing MAE before fitting: ' + f'{mae_before:.2f} kcal/mol') + logging.info(f'Testing RMSE after fitting: ' + f'{rmse_after:.2f} kcal/mol') + logging.info(f'Testing MAE after fitting: ' + f'{mae_after:.2f} kcal/mol') def get_confidence_intervals(x: np.ndarray,