Skip to content

Commit

Permalink
integrate optimisation and parameter evolution plot into plot_models …
Browse files Browse the repository at this point in the history
…function

Change-Id: Iafe8a0214c4d783c998e9f624af5016303911e3f
  • Loading branch information
Jaquier Aurélien Tristan committed Nov 5, 2024
1 parent 3be0b7d commit c05e965
Show file tree
Hide file tree
Showing 4 changed files with 84 additions and 49 deletions.
51 changes: 3 additions & 48 deletions bluepyemodel/emodel_pipeline/emodel_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
from bluepyemodel.access_point import get_access_point
from bluepyemodel.efeatures_extraction.efeatures_extraction import extract_save_features_protocols
from bluepyemodel.emodel_pipeline import plotting
from bluepyemodel.evaluation.evaluation import get_evaluator_from_access_point
from bluepyemodel.export_emodel.export_emodel import export_emodels_sonata
from bluepyemodel.model.model_configuration import configure_model
from bluepyemodel.optimisation import setup_and_run_optimisation
Expand Down Expand Up @@ -258,57 +257,14 @@ def plot(self, only_validated=False, load_from_local=False, seeds=None):
"""
pp_settings = self.access_point.pipeline_settings

cell_evaluator = get_evaluator_from_access_point(
self.access_point,
include_validation_protocols=True,
record_ions_and_currents=pp_settings.plot_currentscape,
)

chkp_paths = glob.glob("./checkpoints/**/*.pkl", recursive=True)
if not chkp_paths:
raise ValueError("The checkpoints directory is empty, or there are no .pkl files.")

# Filter the checkpoints to plot
checkpoint_paths = []
for chkp_path in chkp_paths:
if self.access_point.emodel_metadata.emodel not in chkp_path.split("/"):
continue
if (
self.access_point.emodel_metadata.iteration
and self.access_point.emodel_metadata.iteration not in chkp_path.split("/")
):
continue
checkpoint_paths.append(chkp_path)

stem = str(pathlib.Path(chkp_path).stem)
seed = int(stem.rsplit("seed=", maxsplit=1)[-1])

plotting.optimisation(
optimiser=pp_settings.optimiser,
emodel=self.access_point.emodel_metadata.emodel,
iteration=self.access_point.emodel_metadata.iteration,
seed=seed,
checkpoint_path=chkp_path,
figures_dir=pathlib.Path("./figures")
/ self.access_point.emodel_metadata.emodel
/ "optimisation",
)

if pp_settings.plot_parameter_evolution:
plotting.evolution_parameters_density(
evaluator=cell_evaluator,
checkpoint_paths=checkpoint_paths,
metadata=self.access_point.emodel_metadata,
figures_dir=pathlib.Path("./figures")
/ self.access_point.emodel_metadata.emodel
/ "parameter_evolution",
)

return plotting.plot_models(
access_point=self.access_point,
mapper=self.mapper,
seeds=seeds,
figures_dir=pathlib.Path("./figures") / self.access_point.emodel_metadata.emodel,
plot_optimisation_progress=True,
optimiser=pp_settings.optimiser,
plot_parameter_evolution=True,
plot_distributions=True,
plot_scores=True,
plot_traces=True,
Expand All @@ -331,7 +287,6 @@ def plot(self, only_validated=False, load_from_local=False, seeds=None):
only_validated=only_validated,
save_recordings=pp_settings.save_recordings,
load_from_local=load_from_local,
cell_evaluator=cell_evaluator,
)

def export_emodels(self, only_validated=False, seeds=None):
Expand Down
36 changes: 36 additions & 0 deletions bluepyemodel/emodel_pipeline/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@
from bluepyemodel.evaluation.utils import define_EPSP_feature
from bluepyemodel.evaluation.utils import define_EPSP_protocol
from bluepyemodel.model.morphology_utils import get_basal_and_apical_max_radial_distances
from bluepyemodel.tools.utils import existing_checkpoint_paths
from bluepyemodel.tools.utils import get_amplitude_from_feature_key
from bluepyemodel.tools.utils import make_dir
from bluepyemodel.tools.utils import read_checkpoint
Expand Down Expand Up @@ -1773,6 +1774,9 @@ def plot_models(
mapper,
seeds=None,
figures_dir="./figures",
plot_optimisation_progress=True,
optimiser=None,
plot_parameter_evolution=True,
plot_distributions=True,
plot_scores=True,
plot_traces=True,
Expand Down Expand Up @@ -1807,6 +1811,10 @@ def plot_models(
individual in the population.
seeds (list): if not None, filter emodels to keep only the ones with these seeds.
figures_dir (str): path of the directory in which the figures should be saved.
plot_optimisation_progress (bool): True to plot the optimisation progress from checkpoint
optimiser (str): name of the algorithm used for optimisation, can be "IBEA", "SO-CMA"
or "MO-CMA". Is used in optimisation progress plotting.
plot_parameter_evolution (bool): True to plot parameter evolution
plot_distributions (bool): True to plot the parameters distributions
plot_scores (bool): True to plot the scores
plot_traces (bool): True to plot the traces
Expand Down Expand Up @@ -1885,6 +1893,34 @@ def plot_models(
access_point.pipeline_settings.currentscape_config["current"]["names"],
use_fixed_dt_recordings=False,
)

if plot_optimisation_progress or plot_parameter_evolution:
checkpoint_paths = existing_checkpoint_paths(access_point.emodel_metadata)

if plot_optimisation_progress:
if optimiser is None:
logger.warning("Will not plot optimisation progress because optimiser was not given.")
else:
for chkp_path in checkpoint_paths:
stem = str(Path(chkp_path).stem)
seed = int(stem.rsplit("seed=", maxsplit=1)[-1])

optimisation(
optimiser=optimiser,
emodel=access_point.emodel_metadata.emodel,
iteration=access_point.emodel_metadata.iteration,
seed=seed,
checkpoint_path=chkp_path,
figures_dir=figures_dir / "optimisation",
)

if plot_parameter_evolution:
evolution_parameters_density(
evaluator=cell_evaluator,
checkpoint_paths=checkpoint_paths,
metadata=access_point.emodel_metadata,
figures_dir=figures_dir / "parameter_evolution",
)

if any(
(
Expand Down
23 changes: 22 additions & 1 deletion bluepyemodel/tasks/emodel_creation/optimisation.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
from bluepyemodel.tasks.luigi_tools import WorkflowTaskRequiringMechanisms
from bluepyemodel.tasks.luigi_tools import WorkflowWrapperTask
from bluepyemodel.tools.mechanisms import compile_mechs_in_emodel_dir
from bluepyemodel.tools.utils import existing_checkpoint_paths

# pylint: disable=W0235,W0621,W0404,W0611,W0703,E1128
logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -1250,6 +1251,7 @@ def run(self):
""" """

plot_optimisation = self.access_point.pipeline_settings.plot_optimisation
optimiser = self.access_point.pipeline_settings.optimiser
plot_currentscape = self.access_point.pipeline_settings.plot_currentscape
plot_bAP_EPSP = self.access_point.pipeline_settings.plot_bAP_EPSP
plot_IV_curves = self.access_point.pipeline_settings.plot_IV_curves
Expand All @@ -1269,6 +1271,9 @@ def run(self):
mapper=mapper,
seeds=range(self.seed, self.seed + batch_size),
figures_dir=Path("./figures") / self.emodel,
plot_optimisation_progress=plot_optimisation,
optimiser=optimiser,
plot_parameter_evolution=plot_optimisation,
plot_distributions=plot_optimisation,
plot_traces=plot_optimisation,
plot_scores=plot_optimisation,
Expand All @@ -1289,7 +1294,8 @@ def run(self):
)

if isinstance(self.access_point, NexusAccessPoint):
self.access_point.update_emodel_images(seed=self.seed, keep_old_images=False)
for seed in range(self.seed, self.seed + batch_size):
self.access_point.update_emodel_images(seed=seed, keep_old_images=False)

def output(self):
""" """
Expand All @@ -1299,6 +1305,21 @@ def output(self):

outputs = []
if plot_optimisation:
# optimisation progress
for checkpoint_path in existing_checkpoint_paths(self.access_point.emodel_metadata):
p = Path(checkpoint_path)
fname = p.stem
fname += "__optimisation.pdf"
fpath = Path("./figures") / self.emodel / "optimisation" / fname
outputs.append(luigi.LocalTarget(fpath))

# parameter evolution
for seed in range(self.seed, self.seed + batch_size):
fname = self.access_point.emodel_metadata.as_string(seed=seed)
fname += "__evo_parameter_density.pdf"
fpath = Path("./figures") / self.emodel / "parameter_evolution" / fname
outputs.append(luigi.LocalTarget(fpath))

# distribution
fname = self.access_point.emodel_metadata.as_string()
fname += "__parameters_distribution.pdf"
Expand Down
23 changes: 23 additions & 0 deletions bluepyemodel/tools/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
limitations under the License.
"""

import glob
import logging
import pickle
from pathlib import Path
Expand All @@ -28,6 +29,28 @@
logger = logging.getLogger("__main__")


def existing_checkpoint_paths(emodel_metadata, checkpoint_paths=None):
"""Returns a list of existing checkpoint paths conforming to metadata.
Args:
emodel_metadata (EModelMetadata): contains emodel and iteration
that should be present in each returned checkpoint path
checkpoint_paths (list): list of existing checkpoint paths to be filtered
using metadata. If None, will be created on the spot.
"""
if checkpoint_paths is None:
checkpoint_paths = glob.glob("./checkpoints/**/*.pkl", recursive=True)
if not checkpoint_paths:
raise ValueError("The checkpoints directory is empty, or there are no .pkl files.")

if not emodel_metadata.iteration:
return [chkp for chkp in checkpoint_paths if emodel_metadata.emodel in chkp.split("/")]
return [
chkp for chkp in checkpoint_paths
if emodel_metadata.emodel in chkp.split("/")
and emodel_metadata.iteration in chkp.split("/")
]

def checkpoint_path_exists(checkpoint_path):
"""Returns True if checkpoint path exists, False if not.
Expand Down

0 comments on commit c05e965

Please sign in to comment.