From 2f45c7beb41f17dffe743763455137e60db74344 Mon Sep 17 00:00:00 2001 From: Hao Ye Date: Tue, 24 Sep 2019 14:17:14 -0400 Subject: [PATCH] add tests and accuracy-preserving rho calculation --- R/statistical_tests.R | 4 ++-- src/forecast_machine.cpp | 22 +++++++++------------- tests/testthat/test_02_helper_functions.R | 6 ++++++ tests/testthat/test_11_error_checking.R | 4 ++++ 4 files changed, 21 insertions(+), 15 deletions(-) diff --git a/R/statistical_tests.R b/R/statistical_tests.R index a8192eb..3f6dece 100644 --- a/R/statistical_tests.R +++ b/R/statistical_tests.R @@ -23,7 +23,7 @@ test_nonlinearity <- function(ts, method = "ebisuzaki", num_surr = 200, T_period = 1, E = 1, ...) { - compute_stats <- function(ts, ...) + compute_test_stat <- function(ts, ...) { results <- s_map(ts, stats_only = TRUE, silent = TRUE, ...) delta_rho <- max(results$rho) - results$rho[results$theta == 0] @@ -31,7 +31,7 @@ test_nonlinearity <- function(ts, method = "ebisuzaki", num_surr = 200, return(c(delta_rho = delta_rho, delta_mae = delta_mae)) } - actual_stats <- compute_stats(ts, ...) + actual_stats <- compute_test_stat(ts, ...) delta_rho <- actual_stats["delta_rho"] delta_mae <- actual_stats["delta_mae"] names(delta_rho) <- NULL diff --git a/src/forecast_machine.cpp b/src/forecast_machine.cpp index 909c200..e27bd30 100644 --- a/src/forecast_machine.cpp +++ b/src/forecast_machine.cpp @@ -755,14 +755,16 @@ std::vector sort_indices(const vec& v, std::vector idx) PredStats compute_stats_internal(const vec& obs, const vec& pred) { + // Obtain environment containing function + Rcpp::Environment base("package:stats"); + + // Make function callable from C++ + Rcpp::Function cor_r = base["cor"]; + + size_t num_pred = 0; double sum_errors = 0; double sum_squared_errors = 0; - double sum_obs = 0; - double sum_pred = 0; - double sum_squared_obs = 0; - double sum_squared_pred = 0; - double sum_prod = 0; size_t same_sign = 0; size_t num_vectors = obs.size(); if(pred.size() < num_vectors) @@ -775,11 +777,6 @@ PredStats compute_stats_internal(const vec& obs, const vec& pred) ++ num_pred; sum_errors += fabs(obs[k] - pred[k]); sum_squared_errors += (obs[k] - pred[k]) * (obs[k] - pred[k]); - sum_obs += obs[k]; - sum_pred += pred[k]; - sum_squared_obs += obs[k] * obs[k]; - sum_squared_pred += pred[k] * pred[k]; - sum_prod += obs[k] * pred[k]; if((obs[k] >= 0 && pred[k] >= 0) || (obs[k] <= 0 && pred[k] <= 0)) ++ same_sign; @@ -788,9 +785,8 @@ PredStats compute_stats_internal(const vec& obs, const vec& pred) PredStats output; output.num_pred = num_pred; - output.rho = (sum_prod * num_pred - sum_obs * sum_pred) / - sqrt((sum_squared_obs * num_pred - sum_obs * sum_obs) * - (sum_squared_pred * num_pred - sum_pred * sum_pred)); + Rcpp::NumericVector cor_output = cor_r(obs, pred, "pairwise"); + output.rho = cor_output[0]; output.mae = sum_errors / double(num_pred); output.rmse = sqrt(sum_squared_errors / double(num_pred)); output.perc = double(same_sign) / double(num_pred); diff --git a/tests/testthat/test_02_helper_functions.R b/tests/testthat/test_02_helper_functions.R index 504309d..cef2268 100644 --- a/tests/testthat/test_02_helper_functions.R +++ b/tests/testthat/test_02_helper_functions.R @@ -6,6 +6,12 @@ test_that("rEDM_warning filters warnings", { expect_warning(rEDM_warning("test ABC123", silent = TRUE), NA) }) +test_that("compute_stats has acceptable accuracy on degenerate input", { + x <- rep(c(-0.2, 0.36, -0.47, 0.30), 75) + out <- simplex(x, silent = TRUE) + expect_identical(out$rho, rep(1, 10)) +}) + test_that("check_params_against_lib produces desired output", { lib <- matrix(c(1, 5), ncol = 2) expect_true(check_params_against_lib(3, 1, 1, lib)) diff --git a/tests/testthat/test_11_error_checking.R b/tests/testthat/test_11_error_checking.R index b04b801..5d8ccdb 100644 --- a/tests/testthat/test_11_error_checking.R +++ b/tests/testthat/test_11_error_checking.R @@ -16,6 +16,10 @@ test_that("beginning of range checking works", { expect_match(w, "start_of_range = 201, but num_vectors = 200", all = FALSE) expect_match(w, "start of time_range was greater than the number of vectors; skipping", all = FALSE) expect_match(w, "no nearest neighbors found; using NA for forecast", all = FALSE) + simplex_out$rho <- NaN + simplex_out$p_val <- NaN + simplex_out$const_pred_rho <- NaN + simplex_out$const_p_val <- NaN expect_known_hash(round(simplex_out, 4), "e5b3bb5459") })