Skip to content
This repository has been archived by the owner on Jul 20, 2023. It is now read-only.

Commit

Permalink
add tests and accuracy-preserving rho calculation
Browse files Browse the repository at this point in the history
  • Loading branch information
ha0ye committed Sep 24, 2019
1 parent f7fbf5d commit 2f45c7b
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 15 deletions.
4 changes: 2 additions & 2 deletions R/statistical_tests.R
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,15 @@
test_nonlinearity <- function(ts, method = "ebisuzaki", num_surr = 200,
T_period = 1, E = 1, ...)
{
compute_stats <- function(ts, ...)
compute_test_stat <- function(ts, ...)
{
results <- s_map(ts, stats_only = TRUE, silent = TRUE, ...)
delta_rho <- max(results$rho) - results$rho[results$theta == 0]
delta_mae <- results$mae[results$theta == 0] - min(results$mae)
return(c(delta_rho = delta_rho, delta_mae = delta_mae))
}

actual_stats <- compute_stats(ts, ...)
actual_stats <- compute_test_stat(ts, ...)
delta_rho <- actual_stats["delta_rho"]
delta_mae <- actual_stats["delta_mae"]
names(delta_rho) <- NULL
Expand Down
22 changes: 9 additions & 13 deletions src/forecast_machine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -755,14 +755,16 @@ std::vector<size_t> sort_indices(const vec& v, std::vector<size_t> idx)

PredStats compute_stats_internal(const vec& obs, const vec& pred)
{
// Obtain environment containing function
Rcpp::Environment base("package:stats");

// Make function callable from C++
Rcpp::Function cor_r = base["cor"];


size_t num_pred = 0;
double sum_errors = 0;
double sum_squared_errors = 0;
double sum_obs = 0;
double sum_pred = 0;
double sum_squared_obs = 0;
double sum_squared_pred = 0;
double sum_prod = 0;
size_t same_sign = 0;
size_t num_vectors = obs.size();
if(pred.size() < num_vectors)
Expand All @@ -775,11 +777,6 @@ PredStats compute_stats_internal(const vec& obs, const vec& pred)
++ num_pred;
sum_errors += fabs(obs[k] - pred[k]);
sum_squared_errors += (obs[k] - pred[k]) * (obs[k] - pred[k]);
sum_obs += obs[k];
sum_pred += pred[k];
sum_squared_obs += obs[k] * obs[k];
sum_squared_pred += pred[k] * pred[k];
sum_prod += obs[k] * pred[k];
if((obs[k] >= 0 && pred[k] >= 0) ||
(obs[k] <= 0 && pred[k] <= 0))
++ same_sign;
Expand All @@ -788,9 +785,8 @@ PredStats compute_stats_internal(const vec& obs, const vec& pred)

PredStats output;
output.num_pred = num_pred;
output.rho = (sum_prod * num_pred - sum_obs * sum_pred) /
sqrt((sum_squared_obs * num_pred - sum_obs * sum_obs) *
(sum_squared_pred * num_pred - sum_pred * sum_pred));
Rcpp::NumericVector cor_output = cor_r(obs, pred, "pairwise");
output.rho = cor_output[0];
output.mae = sum_errors / double(num_pred);
output.rmse = sqrt(sum_squared_errors / double(num_pred));
output.perc = double(same_sign) / double(num_pred);
Expand Down
6 changes: 6 additions & 0 deletions tests/testthat/test_02_helper_functions.R
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,12 @@ test_that("rEDM_warning filters warnings", {
expect_warning(rEDM_warning("test ABC123", silent = TRUE), NA)
})

test_that("compute_stats has acceptable accuracy on degenerate input", {
x <- rep(c(-0.2, 0.36, -0.47, 0.30), 75)
out <- simplex(x, silent = TRUE)
expect_identical(out$rho, rep(1, 10))
})

test_that("check_params_against_lib produces desired output", {
lib <- matrix(c(1, 5), ncol = 2)
expect_true(check_params_against_lib(3, 1, 1, lib))
Expand Down
4 changes: 4 additions & 0 deletions tests/testthat/test_11_error_checking.R
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@ test_that("beginning of range checking works", {
expect_match(w, "start_of_range = 201, but num_vectors = 200", all = FALSE)
expect_match(w, "start of time_range was greater than the number of vectors; skipping", all = FALSE)
expect_match(w, "no nearest neighbors found; using NA for forecast", all = FALSE)
simplex_out$rho <- NaN
simplex_out$p_val <- NaN
simplex_out$const_pred_rho <- NaN
simplex_out$const_p_val <- NaN
expect_known_hash(round(simplex_out, 4), "e5b3bb5459")
})

Expand Down

0 comments on commit 2f45c7b

Please sign in to comment.