Skip to content

Commit

Permalink
Add @info into algorithm comparison scripts
Browse files Browse the repository at this point in the history
  • Loading branch information
zengfung committed Jul 17, 2022
1 parent 5fb19c4 commit f952f67
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 9 deletions.
2 changes: 2 additions & 0 deletions src/CompareAlgorithmClassificationRates.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ include("mod/utils.jl")
import .LDBUtils: repeat_experiment, dict2dataframes, aggregate_results

## ========= Setup =========
@info "LDB Object Setup"
# TODO: Change wavelet filter
wt = wavelet(WT.coif2)
# TODO: Change decomposition levels
Expand Down Expand Up @@ -62,6 +63,7 @@ classifiers = Dict("LDA" => LDA(), "CT" => CT())
measures = [MisclassificationRate(), MulticlassPrecision(), MulticlassTruePositiveRate()]

## ========== Run experiments ==========
@info "Run Experiments"
# Set `save_data` to `true` to see the results in csv files.
results_raw = repeat_experiment(ldbs, classifiers, measures; repeats = 100, save_data=false)
results_by_measure = Dict("$((string typeof)(measure))" => dict2dataframes(results_raw, measure; save_data=false) for measure in measures)
Expand Down
23 changes: 14 additions & 9 deletions src/CompareAlgorithmVisualizations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ include("mod/utils.jl")
import .LDBUtils: compute_ldb_vectors, plot_coefficients

## ========== Setup ==========
@info "LDB Object Setup"
train_x, train_y = generateclassdata(ClassData(:tri, 33 , 33 , 33 ), false)
test_x , test_y = generateclassdata(ClassData(:tri, 333, 333, 333), true)
# TODO: Change wavelet filter
Expand Down Expand Up @@ -104,6 +105,7 @@ ldbkemd = LocalDiscriminantBasis(wt = wt,
n_features = n_features)

## ========== Display sample dataset ==========
@info "Data Setup"
# TODO: Change indexing if train size is changed above
class1_mean = mean(train_x[:, 1:33], dims=2)
class2_mean = mean(train_x[:,34:66], dims=2)
Expand All @@ -114,6 +116,7 @@ plot!(p01, class3_mean, linestyle=:dot, color=:black, label="Class 3")
plot!(p01, title="Mean training waveform")

## ========== Fit LDB ==========
@info "LDB Transform"
# Fit and transform training data
ldbk_train_features = WaveletsExt.fit_transform(ldbk , train_x, train_y)
ldbkash_train_features = WaveletsExt.fit_transform(ldbkash, train_x, train_y)
Expand Down Expand Up @@ -144,6 +147,7 @@ p09 = plot_coefficients(ldbkash_train_features) |> p -> plot!(p, title="LDBKASH"
p10 = plot_coefficients(ldbkemd_train_features) |> p -> plot!(p, title="LDBKEMD")

## ========== Fit features in Linear Discriminant Analysis (LDA) classifier ==========
@info "Model Training"
# Data wrangling to fit MLJ.jl syntax
original_train_features = DataFrame(train_x', :auto)
ldbk_train_features = DataFrame(ldbk_train_features', :auto)
Expand All @@ -156,7 +160,7 @@ ldbkash_test_features = DataFrame(ldbkash_test_features', :auto)
ldbkemd_test_features = DataFrame(ldbkemd_test_features', :auto)
test_y = coerce(test_y, Multiclass)
# Model fitting
LDA = @load LDA pkg=MultivariateStats
LDA = @load LDA pkg=MultivariateStats verbosity=0
original_classifier = machine(LDA(), original_train_features, train_y)
ldbk_classifier = machine(LDA(), ldbk_train_features, train_y)
ldbkash_classifier = machine(LDA(), ldbkash_train_features, train_y)
Expand All @@ -175,19 +179,20 @@ ldbk_test_ŷ = predict_mode(ldbk_classifier , ldbk_test_features)
ldbkash_test_ŷ = predict_mode(ldbkash_classifier , ldbkash_test_features)
ldbkemd_test_ŷ = predict_mode(ldbkemd_classifier , ldbkemd_test_features)
# Evaluate Model
@info "Model Evaluation"
original_train_accuracy = Accuracy()(original_train_ŷ, train_y)
@info "Original LDB Train Accuracy: $original_train_accuracy"
@info "\tOriginal LDB Train Accuracy: $original_train_accuracy"
ldbk_train_accuracy = Accuracy()(ldbk_train_ŷ , train_y)
@info "LDBK Train Accuracy: $ldbk_train_accuracy"
@info "\tLDBK Train Accuracy: $ldbk_train_accuracy"
ldbkash_train_accuracy = Accuracy()(ldbkash_train_ŷ , train_y)
@info "LDBASH Train Accuracy: $ldbkash_train_accuracy"
@info "\tLDBASH Train Accuracy: $ldbkash_train_accuracy"
ldbkemd_train_accuracy = Accuracy()(ldbkemd_train_ŷ , train_y)
@info "LDBKEMD Train Accuracy: $ldbkemd_train_accuracy"
@info "\tLDBKEMD Train Accuracy: $ldbkemd_train_accuracy"
original_test_accuracy = Accuracy()(original_test_ŷ, test_y)
@info "Original LDB Test Accuracy: $original_test_accuracy"
@info "\tOriginal LDB Test Accuracy: $original_test_accuracy"
ldbk_test_accuracy = Accuracy()(ldbk_test_ŷ , test_y)
@info "LDBK Test Accuracy: $ldbk_test_accuracy"
@info "\tLDBK Test Accuracy: $ldbk_test_accuracy"
ldbkash_test_accuracy = Accuracy()(ldbkash_test_ŷ , test_y)
@info "LDBKASH Test Accuracy: $ldbkash_test_accuracy"
@info "\tLDBKASH Test Accuracy: $ldbkash_test_accuracy"
ldbkemd_test_accuracy = Accuracy()(ldbkemd_test_ŷ , test_y)
@info "LDBKEMD Test Accuracy: $ldbkemd_test_accuracy"
@info "\tLDBKEMD Test Accuracy: $ldbkemd_test_accuracy"

0 comments on commit f952f67

Please sign in to comment.