diff --git a/utils/python/linelast_cwtrain.py b/utils/python/linelast_cwtrain.py index 3e4ea86d..11b0df82 100644 --- a/utils/python/linelast_cwtrain.py +++ b/utils/python/linelast_cwtrain.py @@ -3,6 +3,7 @@ # SPDX-License-Identifier: MIT import numpy as np +from scipy.interpolate import interp1d import h5py import os import matplotlib.pyplot as plt @@ -153,6 +154,7 @@ def get_results(samples, prefix): def solve_time_scaling_plot(samples, res, scale_prefix, plt_name = "scaling.png"): plt.plot(samples, res[:,0], label='FOM') plt.plot(samples, res[:,1], label='ROM') + plt.xscale('log') plt.xlabel(scale_prefix) plt.yscale('log') plt.ylabel("Solve time [s]") @@ -164,6 +166,7 @@ def solve_time_scaling_plot(samples, res, scale_prefix, plt_name = "scaling.png" def relerr_scaling_plot(samples, res, scale_prefix, plt_name = "scaling.png"): plt.plot(samples, res[:,2], label='Relative error') + plt.xscale('log') plt.xlabel(scale_prefix) plt.yscale('log') plt.ylabel("Relative error [-]") @@ -174,6 +177,7 @@ def relerr_scaling_plot(samples, res, scale_prefix, plt_name = "scaling.png"): def speedup_scaling_plot(samples, res, scale_prefix, plt_name = "scaling.png"): plt.plot(samples, res[:,3]) + plt.xscale('log') plt.xlabel(scale_prefix) plt.ylabel("Speedup factor [-]") @@ -181,11 +185,26 @@ def speedup_scaling_plot(samples, res, scale_prefix, plt_name = "scaling.png"): plt.savefig("speedup_" + plt_name, dpi=300) plt.clf() +def export_opt_val(filename, opt_nb_a, opt_nb_i, opt_speedup): + f = open(filename, "w") + out_txt = "Optimal number of bases (interpolated) is: " + str(round(float(opt_nb_a), 4)) + " bases\nOptimal number of bases (rounded) is: " + str(opt_nb_i) + " bases\nSpeedup at rounded number of bases is: " + str(round(float(opt_speedup), 4)) + " x" + f.write(out_txt) + f.close() + def create_scaling_plot(samples, res, scale_prefix, plt_name = "plot.png"): plt.rc('axes', labelsize=14) + ferr = interp1d(samples, res[:,2]) # Rel err + ferr_i = interp1d(res[:,2], samples) # Inverse correlation + x_star_a = ferr_i(1e-2) # Analytical x_star + x_star_i = np.ceil(x_star_a) # Next integer + + fspeed = interp1d(samples, res[:,3]) # speedup factor + opt_speedup = fspeed(x_star_i) + solve_time_scaling_plot(samples, res, scale_prefix, plt_name) relerr_scaling_plot(samples, res, scale_prefix, plt_name) speedup_scaling_plot(samples, res, scale_prefix, plt_name) + export_opt_val("opt_vals.txt", x_star_a, x_star_i, opt_speedup) def get_nr(txt, split_txt = 'comparison'): return int(txt.split('.')[0].split(split_txt)[1])