Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Expand options to prefix output and CSV logs with the input file name #23

Merged
merged 4 commits into from
Sep 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 10 additions & 5 deletions spine/ana/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ class AnaBase(ABC):
_run_modes = ('reco', 'truth', 'both', 'all')

def __init__(self, obj_type=None, run_mode=None, append=False,
overwrite=False, output_prefix=None):
overwrite=False, log_dir=None, prefix=None):
"""Initialize default anlysis script object properties.

Parameters
Expand All @@ -52,7 +52,9 @@ def __init__(self, obj_type=None, run_mode=None, append=False,
If True, appends existing CSV files instead of creating new ones
overwrite : bool, default False
If True and an output CSV file exists, overwrite it
output_prefix : str, default None
log_dir : str
Output CSV file directory (shared with driver log)
prefix : str, default None
Name to prefix every output CSV file with
"""
# Initialize default keys
Expand Down Expand Up @@ -109,7 +111,8 @@ def __init__(self, obj_type=None, run_mode=None, append=False,
self.overwrite_file = overwrite

# Initialize a writer dictionary to be filled by the children classes
self.output_prefix = output_prefix
self.log_dir = log_dir
self.output_prefix = prefix
self.writers = {}

def initialize_writer(self, name):
Expand All @@ -123,8 +126,10 @@ def initialize_writer(self, name):
# Define the name of the file to write to
assert len(name) > 0, "Must provide a non-empty name."
file_name = f'{self.name}_{name}.csv'
if self.output_prefix is not None:
file_name = f'{self.output_prefix}_{file_name}.csv'
if self.output_prefix:
file_name = f'{self.output_prefix}_{file_name}'
if self.log_dir:
file_name = f'{self.log_dir}/{file_name}'

# Initialize the writer
self.writers[name] = CSVWriter(
Expand Down
12 changes: 10 additions & 2 deletions spine/ana/factories.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
ANA_DICT.update(**module_dict(module))


def ana_script_factory(name, cfg, parent_path=''):
def ana_script_factory(name, cfg, overwrite=False, log_dir=None, prefix=None):
"""Instantiates an analyzer module from a configuration dictionary.

Parameters
Expand All @@ -22,6 +22,13 @@ def ana_script_factory(name, cfg, parent_path=''):
parent_path : str
Path to the parent directory of the main analysis configuration. This
allows for the use of relative paths in the analyzers.
overwrite : bool, default False
If `True`, overwrite the CSV logs if they already exist
log_dir : str, optional
Output CSV file directory (shared with driver log)
prefix : str, optional
Input file prefix. If requested, it will be used to prefix
all the output CSV files.

Returns
-------
Expand All @@ -32,4 +39,5 @@ def ana_script_factory(name, cfg, parent_path=''):
cfg['name'] = name

# Instantiate the analysis script module
return instantiate(ANA_DICT, cfg)
return instantiate(
ANA_DICT, cfg, overwrite=overwrite, log_dir=log_dir, prefix=prefix)
43 changes: 36 additions & 7 deletions spine/ana/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,22 +19,50 @@ class AnaManager:
CSV writers needed to store the output of the analysis scripts.
"""

def __init__(self, cfg, parent_path=''):
def __init__(self, cfg, log_dir=None, prefix=None):
"""Initialize the analysis manager.

Parameters
----------
cfg : dict
Analysis script configurations
parent_path : str, optional
Path to the analysis tools configuration file
log_dir : str
Output CSV file directory (shared with driver log)
prefix : str, optional
Input file prefix. If requested, it will be used to prefix
all the output CSV files.
"""
# Parse the analysis block configuration
self.parse_config(log_dir, prefix, **cfg)

def parse_config(self, log_dir, prefix, overwrite=False,
prefix_output=False, **modules):
"""Parse the analysis tool configuration.

Parameters
----------
log_dir : str
Output CSV file directory (shared with driver log)
prefix : str
Input file prefix. If requested, it will be used to prefix
all the output CSV files.
overwrite : bool, default False
If `True`, overwrite the CSV logs if they already exist
prefix_output : bool, optional
If `True`, will prefix the output CSV names with the input file name
**modules : dict
List of analysis script modules
"""
# Loop over the analyzer modules and get their priorities
keys = np.array(list(cfg.keys()))
keys = np.array(list(modules.keys()))
priorities = -np.ones(len(keys), dtype=np.int32)
for i, k in enumerate(keys):
if 'priority' in cfg[k]:
priorities[i] = cfg[k].pop('priority')
if 'priority' in modules[k]:
priorities[i] = modules[k].pop('priority')

# Only use the prefix if the output is to be prefixed
if not prefix_output:
prefix = None

# Add the modules to a processor list in decreasing order of priority
self.watch = StopwatchManager()
Expand All @@ -45,7 +73,8 @@ def __init__(self, cfg, parent_path=''):
self.watch.initialize(k)

# Append
self.modules[k] = ana_script_factory(k, cfg[k], parent_path)
self.modules[k] = ana_script_factory(
k, modules[k], overwrite, log_dir, prefix)

def __call__(self, data):
"""Pass one batch of data through the analysis scripts
Expand Down
49 changes: 40 additions & 9 deletions spine/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ def __init__(self, cfg, rank=None):
assert self.model is None or self.unwrap, (
"Must unwrap the model output to run analysis scripts.")
self.watch.initialize('ana')
self.ana = AnaManager(ana)
self.ana = AnaManager(ana, log_dir=self.log_dir, prefix=self.prefix)

def __len__(self):
"""Returns the number of events in the underlying reader object."""
Expand Down Expand Up @@ -357,19 +357,16 @@ def initialize_io(self, loader=None, reader=None, writer=None):
self.reader = reader_factory(reader)
self.iter_per_epoch = len(self.reader)

# Fetch an appropriate common prefix for all input files
self.prefix = self.get_prefix(self.reader.file_paths)

# Initialize the data writer, if provided
self.writer = None
if writer is not None:
assert self.loader is None or self.unwrap, (
"Must unwrap the model output to write it to file.")
self.watch.initialize('write')
self.writer = writer_factory(writer)

# If requested, extract the name of the input file to prefix logs
if self.prefix_log:
assert len(self.reader.file_paths) == 1, (
"To prefix log, there should be a single input file name.")
self.log_prefix = pathlib.Path(self.reader.file_paths[0]).stem
self.writer = writer_factory(writer, prefix=self.prefix)

# Harmonize the iterations and epochs parameters
assert (self.iterations is None) or (self.epochs is None), (
Expand All @@ -381,6 +378,40 @@ def initialize_io(self, loader=None, reader=None, writer=None):
elif self.epochs is not None:
self.iterations = self.epochs*self.iter_per_epoch

@staticmethod
def get_prefix(file_paths):
    """Builds an appropriate output prefix based on the list of input files.

    Only the base names of the files are considered (directories are
    ignored). For a single file, the prefix is its name stripped of its
    extension. For multiple files, the prefix is made of the leading
    characters shared by all file names, followed by a summary of the form
    `<first>--<count>--<last>`, where `<first>` and `<last>` are what
    remains of the first and last file names (without extension) once the
    shared part is removed, and `<count>` is the number of files in the
    list which are neither the first nor the last.

    Parameters
    ----------
    file_paths : List[str]
        List of input file paths

    Returns
    -------
    str
        Shared input summary string to be used to prefix outputs. This is
        an empty string if the list of input files is empty.

    Examples
    --------
    `['/data/sample.h5']` gives `'sample'`, whereas
    `['run_01.h5', 'run_02.h5', 'run_03.h5']` gives `'run_01--1--3'`.
    """
    # Fetch file base names (ignore where they live)
    file_names = [os.path.basename(path) for path in file_paths]

    # Without input files there is nothing to summarize. Return early so
    # that the first/last lookups below do not raise an IndexError.
    if not file_names:
        return ''

    # Get the shared prefix of all files in the list, without extension
    prefix = os.path.splitext(os.path.commonprefix(file_names))[0]

    # If there is only one file, done
    if len(file_names) == 1:
        return prefix

    def remainder(file_name):
        """Returns what is left of a file name past the shared prefix."""
        stem = os.path.splitext(file_name[len(prefix):])[0]

        # If what is left starts with a dot, it is (part of) an extension
        # rather than a distinguishing name component: drop it
        return '' if stem.startswith('.') else stem

    # Otherwise, form the suffix from the first and last file names
    first = remainder(file_names[0])
    last = remainder(file_names[-1])

    return f'{prefix}{first}--{len(file_names) - 2}--{last}'

def initialize_log(self):
"""Initialize the output log for this driver process."""
# Make a directory if it does not exist
Expand All @@ -401,7 +432,7 @@ def initialize_log(self):

# If requested, prefix the log name with the input file name
if self.prefix_log:
log_name = f'{self.log_prefix}_{log_name}'
log_name = f'{self.prefix}_{log_name}'

# Initialize the log
log_path = os.path.join(self.log_dir, log_name)
Expand Down
18 changes: 8 additions & 10 deletions spine/io/factories.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,8 +123,7 @@ def dataset_factory(dataset_cfg, entry_list=None, dtype=None):

def sampler_factory(sampler_cfg, dataset, minibatch_size, distributed=False,
num_replicas=1, rank=0):
"""
Instantiates sampler based on type specified in configuration under
"""Instantiates sampler based on type specified in configuration under
`io.sampler.name`. The name must match the name of a class under
`spine.io.sample`.

Expand Down Expand Up @@ -163,8 +162,7 @@ def sampler_factory(sampler_cfg, dataset, minibatch_size, distributed=False,


def collate_factory(collate_cfg):
"""
Instantiates collate function based on type specified in configuration
"""Instantiates collate function based on type specified in configuration
under `io.collate.name`. The name must match the name of a class
under `spine.io.collates`.

Expand All @@ -183,8 +181,7 @@ def collate_factory(collate_cfg):


def reader_factory(reader_cfg):
"""
Instantiates reader based on type specified in configuration under
"""Instantiates reader based on type specified in configuration under
`io.reader.name`. The name must match the name of a class under
`spine.io.readers`.

Expand All @@ -206,16 +203,17 @@ def reader_factory(reader_cfg):
return instantiate(READER_DICT, reader_cfg)


def writer_factory(writer_cfg):
"""
Instantiates writer based on type specified in configuration under
def writer_factory(writer_cfg, prefix=None):
"""Instantiates writer based on type specified in configuration under
`io.writer.name`. The name must match the name of a class under
`spine.io.writers`.

Parameters
----------
writer_cfg : dict
Writer configuration dictionary
prefix : str, optional
Input file prefix to use as an output name

Returns
-------
Expand All @@ -227,4 +225,4 @@ def writer_factory(writer_cfg):
Currently the choice is limited to `HDF5Writer` only.
"""
# Initialize writer
return instantiate(WRITER_DICT, writer_cfg)
return instantiate(WRITER_DICT, writer_cfg, prefix=prefix)
20 changes: 15 additions & 5 deletions spine/io/write/hdf5.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,13 +37,13 @@ class HDF5Writer:
"""
name = 'hdf5'

def __init__(self, file_name='output.h5', keys=None, skip_keys=None,
dummy_ds=None, overwrite=False, append=False):
def __init__(self, file_name=None, keys=None, skip_keys=None,
dummy_ds=None, overwrite=False, append=False, prefix=None):
"""Initializes the basics of the output file.

Parameters
----------
file_name : str, default 'output.h5'
file_name : str, optional
Name of the output HDF5 file. If not provided, it is built from the
input file `prefix` as `<prefix>_spine.h5`
keys : List[str], optional
List of data product keys to store. If not specified, store everything
Expand All @@ -53,10 +53,20 @@ def __init__(self, file_name='output.h5', keys=None, skip_keys=None,
Keys for which to create placeholder datasets. For each key, specify
the object type it is supposed to represent as a string.
overwrite : bool, default False
If True, overwrite the output file if it already exists
If `True`, overwrite the output file if it already exists
append : bool, default False
If True, add new values to the end of an existing file
If `True`, add new values to the end of an existing file
prefix : str, optional
Input file prefix. It will be used to form the output file name,
provided that no file_name is explicitly provided
"""
# If the output file name is not provided, use the input file prefix
if not file_name:
assert prefix is not None, (
"If the output `file_name` is not provided, must provide"
"the input file `prefix` to build it from.")
file_name = prefix + '_spine.h5'

# Check that output file does not already exist, if requested
if not overwrite and os.path.isfile(file_name):
raise FileExistsError(f"File with name {file_name} already exists.")
Expand Down
Loading