From b7867dcb6c30272833a95997667a7cf152bcc5d2 Mon Sep 17 00:00:00 2001 From: Francois Drielsma Date: Tue, 17 Sep 2024 10:17:13 -0700 Subject: [PATCH 1/4] Add option to provide the ana overwrite flag at the base level --- spine/ana/factories.py | 6 ++++-- spine/ana/manager.py | 28 +++++++++++++++++++++++----- 2 files changed, 27 insertions(+), 7 deletions(-) diff --git a/spine/ana/factories.py b/spine/ana/factories.py index 6b1c4f1c..62a5cb71 100644 --- a/spine/ana/factories.py +++ b/spine/ana/factories.py @@ -10,7 +10,7 @@ ANA_DICT.update(**module_dict(module)) -def ana_script_factory(name, cfg, parent_path=''): +def ana_script_factory(name, cfg, parent_path='', overwrite=False): """Instantiates an analyzer module from a configuration dictionary. Parameters @@ -22,6 +22,8 @@ def ana_script_factory(name, cfg, parent_path=''): parent_path : str Path to the parent directory of the main analysis configuration. This allows for the use of relative paths in the analyzers. + overwrite : bool, default False + If `True`, overwrite the CSV logs if they already exist Returns ------- @@ -32,4 +34,4 @@ def ana_script_factory(name, cfg, parent_path=''): cfg['name'] = name # Instantiate the analysis script module - return instantiate(ANA_DICT, cfg) + return instantiate(ANA_DICT, cfg, overwrite=overwrite) diff --git a/spine/ana/manager.py b/spine/ana/manager.py index 57948708..55d59c3b 100644 --- a/spine/ana/manager.py +++ b/spine/ana/manager.py @@ -27,14 +27,31 @@ def __init__(self, cfg, parent_path=''): cfg : dict Analysis script configurations parent_path : str, optional - Path to the analysis tools configuration file + Path to the parent directory of the main analysis configuration. + This allows for the use of relative paths in the analyzers. + """ + # Parse the analysis block configuration + self.parse_config(parent_path, **cfg) + + def parse_config(self, parent_path, overwrite=False, **modules): + """Parse the analysis tool configuration. 
+ + Parameters + ---------- + parent_path : str, optional + Path to the parent directory of the main analysis configuration. + This allows for the use of relative paths in the analyzers. + overwrite : bool, default False + If `True`, overwrite the CSV logs if they already exist + **modules : dict + List of analysis script modules """ # Loop over the analyzer modules and get their priorities - keys = np.array(list(cfg.keys())) + keys = np.array(list(modules.keys())) priorities = -np.ones(len(keys), dtype=np.int32) for i, k in enumerate(keys): - if 'priority' in cfg[k]: - priorities[i] = cfg[k].pop('priority') + if 'priority' in modules[k]: + priorities[i] = modules[k].pop('priority') # Add the modules to a processor list in decreasing order of priority self.watch = StopwatchManager() @@ -45,7 +62,8 @@ def __init__(self, cfg, parent_path=''): self.watch.initialize(k) # Append - self.modules[k] = ana_script_factory(k, cfg[k], parent_path) + self.modules[k] = ana_script_factory( + k, modules[k], parent_path, overwrite) def __call__(self, data): """Pass one batch of data through the analysis scripts From f0423166164b20dbc66bbd27ea5c2d2c0483453b Mon Sep 17 00:00:00 2001 From: Francois Drielsma Date: Tue, 17 Sep 2024 11:45:48 -0700 Subject: [PATCH 2/4] Add support to generate a log prefix when using a list of input files --- spine/driver.py | 45 ++++++++++++++++++++++++++++++++++++++------- 1 file changed, 38 insertions(+), 7 deletions(-) diff --git a/spine/driver.py b/spine/driver.py index c91d5883..eb7cba3c 100644 --- a/spine/driver.py +++ b/spine/driver.py @@ -357,6 +357,9 @@ def initialize_io(self, loader=None, reader=None, writer=None): self.reader = reader_factory(reader) self.iter_per_epoch = len(self.reader) + # Fetch an appropriate common prefix for all input files + self.prefix = self.initialize_prefix(self.reader.file_paths) + # Initialize the data writer, if provided self.writer = None if writer is not None: @@ -365,12 +368,6 @@ def initialize_io(self, 
loader=None, reader=None, writer=None): self.watch.initialize('write') self.writer = writer_factory(writer) - # If requested, extract the name of the input file to prefix logs - if self.prefix_log: - assert len(self.reader.file_paths) == 1, ( - "To prefix log, there should be a single input file name.") - self.log_prefix = pathlib.Path(self.reader.file_paths[0]).stem - # Harmonize the iterations and epochs parameters assert (self.iterations is None) or (self.epochs is None), ( "Must not specify both `iterations` or `epochs` parameters.") @@ -381,6 +378,40 @@ def initialize_io(self, loader=None, reader=None, writer=None): elif self.epochs is not None: self.iterations = self.epochs*self.iter_per_epoch + @staticmethod + def initialize_prefix(file_paths): + """Builds an appropriate output prefix based on the list of input files. + + Parameters + ---------- + file_paths : List[str] + List of input file paths + + Returns + ------- + str + Shared input summary string to be used to prefix outputs + """ + # Fetch file base names (ignore where they live) + file_names = [os.path.basename(f) for f in file_paths] + + # Get the shared prefix of all files in the list + prefix = os.path.splitext(os.path.commonprefix(file_names))[0] + + # If there is only one file, done + if len(file_names) == 1: + return prefix + + # Otherwise, form the suffix from the first and last file names + first = os.path.splitext(file_names[0][len(prefix):]) + last = os.path.splitext(file_names[-1][len(prefix):]) + first = first[0] if first[0] and first[0][0] != '.' else '' + last = last[0] if last[0] and last[0][0] != '.' 
else '' + + suffix = f'{first}--{len(file_names)-2}--{last}' + + return prefix + suffix + def initialize_log(self): """Initialize the output log for this driver process.""" # Make a directory if it does not exist @@ -401,7 +432,7 @@ def initialize_log(self): # If requested, prefix the log name with the input file name if self.prefix_log: - log_name = f'{self.log_prefix}_{log_name}' + log_name = f'{self.prefix}_{log_name}' # Initialize the log log_path = os.path.join(self.log_dir, log_name) From 3d718c5e21472963cef024f9f6db11d40f32d52c Mon Sep 17 00:00:00 2001 From: Francois Drielsma Date: Wed, 18 Sep 2024 08:59:51 -0700 Subject: [PATCH 3/4] When no name is provided for an output HDF5 file, use the input data to form it --- spine/driver.py | 6 +++--- spine/io/factories.py | 18 ++++++++---------- spine/io/write/hdf5.py | 20 +++++++++++++++----- 3 files changed, 26 insertions(+), 18 deletions(-) diff --git a/spine/driver.py b/spine/driver.py index eb7cba3c..4aa0994d 100644 --- a/spine/driver.py +++ b/spine/driver.py @@ -358,7 +358,7 @@ def initialize_io(self, loader=None, reader=None, writer=None): self.iter_per_epoch = len(self.reader) # Fetch an appropriate common prefix for all input files - self.prefix = self.initialize_prefix(self.reader.file_paths) + self.prefix = self.get_prefix(self.reader.file_paths) # Initialize the data writer, if provided self.writer = None @@ -366,7 +366,7 @@ def initialize_io(self, loader=None, reader=None, writer=None): assert self.loader is None or self.unwrap, ( "Must unwrap the model output to write it to file.") self.watch.initialize('write') - self.writer = writer_factory(writer) + self.writer = writer_factory(writer, prefix=self.prefix) # Harmonize the iterations and epochs parameters assert (self.iterations is None) or (self.epochs is None), ( @@ -379,7 +379,7 @@ def initialize_io(self, loader=None, reader=None, writer=None): self.iterations = self.epochs*self.iter_per_epoch @staticmethod - def initialize_prefix(file_paths): + 
def get_prefix(file_paths): """Builds an appropriate output prefix based on the list of input files. Parameters diff --git a/spine/io/factories.py b/spine/io/factories.py index 2fbd7821..85766a23 100644 --- a/spine/io/factories.py +++ b/spine/io/factories.py @@ -123,8 +123,7 @@ def dataset_factory(dataset_cfg, entry_list=None, dtype=None): def sampler_factory(sampler_cfg, dataset, minibatch_size, distributed=False, num_replicas=1, rank=0): - """ - Instantiates sampler based on type specified in configuration under + """Instantiates sampler based on type specified in configuration under `io.sampler.name`. The name must match the name of a class under `spine.io.sample`. @@ -163,8 +162,7 @@ def sampler_factory(sampler_cfg, dataset, minibatch_size, distributed=False, def collate_factory(collate_cfg): - """ - Instantiates collate function based on type specified in configuration + """Instantiates collate function based on type specified in configuration under `io.collate.name`. The name must match the name of a class under `spine.io.collates`. @@ -183,8 +181,7 @@ def collate_factory(collate_cfg): def reader_factory(reader_cfg): - """ - Instantiates reader based on type specified in configuration under + """Instantiates reader based on type specified in configuration under `io.reader.name`. The name must match the name of a class under `spine.io.readers`. @@ -206,9 +203,8 @@ def reader_factory(reader_cfg): return instantiate(READER_DICT, reader_cfg) -def writer_factory(writer_cfg): - """ - Instantiates writer based on type specified in configuration under +def writer_factory(writer_cfg, prefix=None): + """Instantiates writer based on type specified in configuration under `io.writer.name`. The name must match the name of a class under `spine.io.writers`. 
@@ -216,6 +212,8 @@ def writer_factory(writer_cfg): ---------- writer_cfg : dict Writer configuration dictionary + prefix : str, optional + Input file prefix to use as an output name Returns ------- @@ -227,4 +225,4 @@ def writer_factory(writer_cfg): Currently the choice is limited to `HDF5Writer` only. """ # Initialize writer - return instantiate(WRITER_DICT, writer_cfg) + return instantiate(WRITER_DICT, writer_cfg, prefix=prefix) diff --git a/spine/io/write/hdf5.py b/spine/io/write/hdf5.py index 19cca10b..6855870a 100644 --- a/spine/io/write/hdf5.py +++ b/spine/io/write/hdf5.py @@ -37,13 +37,13 @@ class HDF5Writer: """ name = 'hdf5' - def __init__(self, file_name='output.h5', keys=None, skip_keys=None, - dummy_ds=None, overwrite=False, append=False): + def __init__(self, file_name=None, keys=None, skip_keys=None, + dummy_ds=None, overwrite=False, append=False, prefix=None): """Initializes the basics of the output file. Parameters ---------- - file_name : str, default 'output.h5' + file_name : str, default 'spine.h5' Name of the output HDF5 file keys : List[str], optional List of data product keys to store. If not specified, store everything @@ -53,10 +53,20 @@ def __init__(self, file_name='output.h5', keys=None, skip_keys=None, Keys for which to create placeholder datasets. For each key, specify the object type it is supposed to represent as a string. overwrite : bool, default False - If True, overwrite the output file if it already exists + If `True`, overwrite the output file if it already exists append : bool, default False - If True, add new values to the end of an existing file + If `True`, add new values to the end of an existing file + prefix : str, optional + Input file prefix. 
It will be used to form the output file name, + provided that no file_name is explicitly provided """ + # If the output file name is not provided, use the input file prefix + if not file_name: + assert prefix is not None, ( + "If the output `file_name` is not provided, must provide" + "the input file `prefix` to build it from.") + file_name = prefix + '_spine.h5' + # Check that output file does not already exist, if requestes if not overwrite and os.path.isfile(file_name): raise FileExistsError(f"File with name {file_name} already exists.") From 6d3cb50c5f41d51dc768d8ee245abdc533e8b09e Mon Sep 17 00:00:00 2001 From: Francois Drielsma Date: Wed, 18 Sep 2024 10:00:40 -0700 Subject: [PATCH 4/4] Add option to prefix the CSV outputs of analyzer with the input file name --- spine/ana/base.py | 15 ++++++++++----- spine/ana/factories.py | 10 ++++++++-- spine/ana/manager.py | 31 +++++++++++++++++++++---------- spine/driver.py | 2 +- 4 files changed, 40 insertions(+), 18 deletions(-) diff --git a/spine/ana/base.py b/spine/ana/base.py index d24312a7..60933737 100644 --- a/spine/ana/base.py +++ b/spine/ana/base.py @@ -37,7 +37,7 @@ class AnaBase(ABC): _run_modes = ('reco', 'truth', 'both', 'all') def __init__(self, obj_type=None, run_mode=None, append=False, - overwrite=False, output_prefix=None): + overwrite=False, log_dir=None, prefix=None): """Initialize default anlysis script object properties. 
Parameters @@ -52,7 +52,9 @@ def __init__(self, obj_type=None, run_mode=None, append=False, If True, appends existing CSV files instead of creating new ones overwrite : bool, default False If True and an output CSV file exists, overwrite it - output_prefix : str, default None + log_dir : str + Output CSV file directory (shared with driver log) + prefix : str, default None Name to prefix every output CSV file with """ # Initialize default keys @@ -109,7 +111,8 @@ def __init__(self, obj_type=None, run_mode=None, append=False, self.overwrite_file = overwrite # Initialize a writer dictionary to be filled by the children classes - self.output_prefix = output_prefix + self.log_dir = log_dir + self.output_prefix = prefix self.writers = {} def initialize_writer(self, name): @@ -123,8 +126,10 @@ def initialize_writer(self, name): # Define the name of the file to write to assert len(name) > 0, "Must provide a non-empty name." file_name = f'{self.name}_{name}.csv' - if self.output_prefix is not None: - file_name = f'{self.output_prefix}_{file_name}.csv' + if self.output_prefix: + file_name = f'{self.output_prefix}_{file_name}' + if self.log_dir: + file_name = f'{self.log_dir}/{file_name}' # Initialize the writer self.writers[name] = CSVWriter( diff --git a/spine/ana/factories.py b/spine/ana/factories.py index 62a5cb71..558f2d74 100644 --- a/spine/ana/factories.py +++ b/spine/ana/factories.py @@ -10,7 +10,7 @@ ANA_DICT.update(**module_dict(module)) -def ana_script_factory(name, cfg, parent_path='', overwrite=False): +def ana_script_factory(name, cfg, overwrite=False, log_dir=None, prefix=None): """Instantiates an analyzer module from a configuration dictionary. Parameters @@ -24,6 +24,11 @@ def ana_script_factory(name, cfg, parent_path='', overwrite=False): allows for the use of relative paths in the analyzers. 
overwrite : bool, default False If `True`, overwrite the CSV logs if they already exist + log_dir : str, optional + Output CSV file directory (shared with driver log) + prefix : str, optional + Input file prefix. If requested, it will be used to prefix + all the output CSV files. Returns ------- @@ -34,4 +39,5 @@ def ana_script_factory(name, cfg, parent_path='', overwrite=False): cfg['name'] = name # Instantiate the analysis script module - return instantiate(ANA_DICT, cfg, overwrite=overwrite) + return instantiate( + ANA_DICT, cfg, overwrite=overwrite, log_dir=log_dir, prefix=prefix) diff --git a/spine/ana/manager.py b/spine/ana/manager.py index 55d59c3b..11c03470 100644 --- a/spine/ana/manager.py +++ b/spine/ana/manager.py @@ -19,30 +19,37 @@ class AnaManager: CSV writers needed to store the output of the analysis scripts. """ - def __init__(self, cfg, parent_path=''): + def __init__(self, cfg, log_dir=None, prefix=None): """Initialize the analysis manager. Parameters ---------- cfg : dict Analysis script configurations - parent_path : str, optional - Path to the parent directory of the main analysis configuration. - This allows for the use of relative paths in the analyzers. + log_dir : str + Output CSV file directory (shared with driver log) + prefix : str, optional + Input file prefix. If requested, it will be used to prefix + all the output CSV files. """ # Parse the analysis block configuration - self.parse_config(parent_path, **cfg) + self.parse_config(log_dir, prefix, **cfg) - def parse_config(self, parent_path, overwrite=False, **modules): + def parse_config(self, log_dir, prefix, overwrite=False, + prefix_output=False, **modules): """Parse the analysis tool configuration. Parameters ---------- - parent_path : str, optional - Path to the parent directory of the main analysis configuration. - This allows for the use of relative paths in the analyzers. + log_dir : str + Output CSV file directory (shared with driver log) + prefix : str + Input file prefix. 
If requested, it will be used to prefix + all the output CSV files. overwrite : bool, default False If `True`, overwrite the CSV logs if they already exist + prefix_output : bool, optional + If `True`, will prefix the output CSV names with the input file name **modules : dict List of analysis script modules """ @@ -53,6 +60,10 @@ def parse_config(self, parent_path, overwrite=False, **modules): if 'priority' in modules[k]: priorities[i] = modules[k].pop('priority') + # Only use the prefix if the output is to be prefixed + if not prefix_output: + prefix = None + # Add the modules to a processor list in decreasing order of priority self.watch = StopwatchManager() self.modules = OrderedDict() @@ -63,7 +74,7 @@ def parse_config(self, parent_path, overwrite=False, **modules): # Append self.modules[k] = ana_script_factory( - k, modules[k], parent_path, overwrite) + k, modules[k], overwrite, log_dir, prefix) def __call__(self, data): """Pass one batch of data through the analysis scripts diff --git a/spine/driver.py b/spine/driver.py index 4aa0994d..3d7e5a39 100644 --- a/spine/driver.py +++ b/spine/driver.py @@ -132,7 +132,7 @@ def __init__(self, cfg, rank=None): assert self.model is None or self.unwrap, ( "Must unwrap the model output to run analysis scripts.") self.watch.initialize('ana') - self.ana = AnaManager(ana) + self.ana = AnaManager(ana, log_dir=self.log_dir, prefix=self.prefix) def __len__(self): """Returns the number of events in the underlying reader object."""