From d525197712aa77a1e842d0f6c4b08bce616fd95a Mon Sep 17 00:00:00 2001 From: Emmanuel Lujan Date: Fri, 21 Jun 2024 17:45:53 -0400 Subject: [PATCH] Improvements in opt-example and co --- examples/Opt-ACE-aHfO2/fit-opt-ace-ahfo2.jl | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/examples/Opt-ACE-aHfO2/fit-opt-ace-ahfo2.jl b/examples/Opt-ACE-aHfO2/fit-opt-ace-ahfo2.jl index d84900f1..7106bf25 100644 --- a/examples/Opt-ACE-aHfO2/fit-opt-ace-ahfo2.jl +++ b/examples/Opt-ACE-aHfO2/fit-opt-ace-ahfo2.jl @@ -31,7 +31,7 @@ conf_train, conf_test = split(ds, n_train, n_test) # NEW: utilizty functions ####################################################### -function estimate_time(confs, iap; batch_size = 30) +function estimate_time(confs, iap; batch_size = 50) if length(confs) < batch_size batch_size = length(confs) end @@ -91,12 +91,16 @@ for (i, species, body_order, polynomial_degree, rcutoff, wL, csp, r0) in ho csp = csp, r0 = r0) iap = LBasisPotential(basis) + ## Compute energy and force descriptors e_descr_new = compute_local_descriptors(conf_train, iap.basis, pbar = false) f_descr_new = compute_force_descriptors(conf_train, iap.basis, pbar = false) ds_cur = DataSet(conf_train .+ e_descr_new .+ f_descr_new) + ## Learn learn!(iap, ds_cur, weights, intercept) + ## Get true and predicted values e, e_pred = get_all_energies(ds_cur), get_all_energies(ds_cur, iap) f, f_pred = get_all_forces(ds_cur), get_all_forces(ds_cur, iap) + ## Compute metrics e_mae, e_rmse, e_rsq = calc_metrics(e_pred, e) f_mae, f_rmse, f_rsq = calc_metrics(f_pred, f) time_us = estimate_time(conf_train, iap) * 10^6 @@ -109,17 +113,20 @@ for (i, species, body_order, polynomial_degree, rcutoff, wL, csp, r0) in ho :f_rmse => f_rmse, :f_rsq => f_rsq, :time_us => time_us) + ## Compute multi-objetive loss based on error and time if e_mae < e_mae_max && f_mae < f_mae_max loss = time_us else loss = time_us + error * 10^3 end + ## Print results println("") print("E_MAE:$(round(e_mae; digits=3)), ") print("F_MAE:$(round(f_mae; digits=3)), ") println("Time per force per atom | µs:$(round(time_us; digits=3))") flush(stdout) + ## Return loss push!(ho.history, (species, body_order, polynomial_degree, rcutoff, wL, csp, r0)) push!(ho.results, (loss, metrics, iap)) end @@ -128,8 +135,8 @@ end # Prnt and save optimization results results = get_results(ho) -println(results) @save_dataframe path results +results # Optimal IAP opt_iap = ho.minimum[3] @@ -137,7 +144,7 @@ opt_iap = ho.minimum[3] @save_var res_path opt_iap.β0 @save_var res_path opt_iap.basis -# Plot loss vs time +# Plot error vs time err_time = plot_err_time(ho) @save_fig res_path err_time DisplayAs.PNG(err_time)