Skip to content

Commit

Permalink
Improvements in opt-example and co
Browse files Browse the repository at this point in the history
  • Loading branch information
emmanuellujan committed Jun 21, 2024
1 parent 49a09e4 commit d525197
Showing 1 changed file with 10 additions and 3 deletions.
13 changes: 10 additions & 3 deletions examples/Opt-ACE-aHfO2/fit-opt-ace-ahfo2.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ conf_train, conf_test = split(ds, n_train, n_test)

# NEW: utilizty functions #######################################################

function estimate_time(confs, iap; batch_size = 30)
function estimate_time(confs, iap; batch_size = 50)
if length(confs) < batch_size
batch_size = length(confs)
end
Expand Down Expand Up @@ -91,12 +91,16 @@ for (i, species, body_order, polynomial_degree, rcutoff, wL, csp, r0) in ho
csp = csp,
r0 = r0)
iap = LBasisPotential(basis)
## Compute energy and force descriptors
e_descr_new = compute_local_descriptors(conf_train, iap.basis, pbar = false)
f_descr_new = compute_force_descriptors(conf_train, iap.basis, pbar = false)
ds_cur = DataSet(conf_train .+ e_descr_new .+ f_descr_new)
## Learn
learn!(iap, ds_cur, weights, intercept)
## Get true and predicted values
e, e_pred = get_all_energies(ds_cur), get_all_energies(ds_cur, iap)
f, f_pred = get_all_forces(ds_cur), get_all_forces(ds_cur, iap)
## Compute metrics
e_mae, e_rmse, e_rsq = calc_metrics(e_pred, e)
f_mae, f_rmse, f_rsq = calc_metrics(f_pred, f)
time_us = estimate_time(conf_train, iap) * 10^6
Expand All @@ -109,17 +113,20 @@ for (i, species, body_order, polynomial_degree, rcutoff, wL, csp, r0) in ho
:f_rmse => f_rmse,
:f_rsq => f_rsq,
:time_us => time_us)
## Compute multi-objetive loss based on error and time
if e_mae < e_mae_max &&
f_mae < f_mae_max
loss = time_us
else
loss = time_us + error * 10^3
end
## Print results
println("")
print("E_MAE:$(round(e_mae; digits=3)), ")
print("F_MAE:$(round(f_mae; digits=3)), ")
println("Time per force per atom | µs:$(round(time_us; digits=3))")
flush(stdout)
## Return loss
push!(ho.history, (species, body_order, polynomial_degree, rcutoff, wL, csp, r0))
push!(ho.results, (loss, metrics, iap))
end
Expand All @@ -128,16 +135,16 @@ end

# Prnt and save optimization results
results = get_results(ho)
println(results)
@save_dataframe path results
results

# Optimal IAP
opt_iap = ho.minimum[3]
@save_var res_path opt_iap.β
@save_var res_path opt_iap.β0
@save_var res_path opt_iap.basis

# Plot loss vs time
# Plot error vs time
err_time = plot_err_time(ho)
@save_fig res_path err_time
DisplayAs.PNG(err_time)
Expand Down

0 comments on commit d525197

Please sign in to comment.