Commit 25c2998
Merge pull request #24 from francois-drielsma/develop
Add option to split output by input file, add shower track post-processors
francois-drielsma authored Oct 4, 2024
2 parents 841dfb7 + d6de450 commit 25c2998
Showing 25 changed files with 371 additions and 82 deletions.
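The headline change threads a new `split_output` option from the driver configuration down to the HDF5 writer, so a process reading several input files can produce one output file per input instead of a single merged file. Below is a minimal configuration sketch in Python dict form; the `base` placement of `split_output` is an assumption inferred from `Driver.initialize_base`, while the `io.writer.name` dispatch is documented in `spine/io/factories.py` further down:

# Hypothetical configuration sketch -- the `base` key placement is assumed,
# not confirmed by this diff; `io.writer.name` dispatch comes from factories.py.
cfg = {
    'base': {
        'split_output': True,    # write one output file per input file
    },
    'io': {
        'writer': {
            'name': 'hdf5',      # resolves to HDF5Writer via writer_factory
            'overwrite': True,
        },
    },
}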
2 changes: 0 additions & 2 deletions spine/__init__.py
@@ -1,5 +1,3 @@
"""Top-level module of the SPICE source code."""

from .data import *
from .driver import Driver
from .version import __version__
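Dropping the top-level `from .data import *` re-export means data containers are no longer reachable from the `spine` package root; the import updates repeated throughout this commit follow the pattern below.

# Before this commit (relied on the top-level wildcard re-export):
#     from spine import TensorBatch, Meta
# After (explicit subpackage import, as in the diffs that follow):
from spine.data import TensorBatch, Meta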
33 changes: 24 additions & 9 deletions spine/driver.py
@@ -17,7 +17,6 @@

import yaml
import psutil
import pathlib
import numpy as np
import torch

@@ -132,7 +131,8 @@ def __init__(self, cfg, rank=None):
assert self.model is None or self.unwrap, (
"Must unwrap the model output to run analysis scripts.")
self.watch.initialize('ana')
self.ana = AnaManager(ana, log_dir=self.log_dir, prefix=self.prefix)
self.ana = AnaManager(
ana, log_dir=self.log_dir, prefix=self.log_prefix)

def __len__(self):
"""Returns the number of events in the underlying reader object."""
@@ -232,7 +232,7 @@ def initialize_base(self, seed, dtype='float32', world_size=0,
log_dir='logs', prefix_log=False, overwrite_log=False,
parent_path=None, iterations=None, epochs=None,
unwrap=False, rank=None, log_step=1, distributed=False,
train=None, verbosity='info'):
split_output=False, train=None, verbosity='info'):
"""Initialize the base driver parameters.
Parameters
@@ -263,6 +263,10 @@ def initialize_base(self, seed, dtype='float32', world_size=0,
Number of iterations before the logging is called (1: every step)
distributed : bool, default False
If `True`, this process is distributed among multiple processes
train : dict, optional
Training configuration dictionary
split_output : bool, default False
Split the output of the process into one file per input file
verbosity : int, default 'info'
Verbosity level to pass to the `logging` module. Pick one of
'debug', 'info', 'warning', 'error', 'critical'.
@@ -309,6 +313,7 @@ def initialize_base(self, seed, dtype='float32', world_size=0,
self.unwrap = unwrap
self.seed = seed
self.log_step = log_step
self.split_output = split_output

return train

@@ -358,15 +363,17 @@ def initialize_io(self, loader=None, reader=None, writer=None):
self.iter_per_epoch = len(self.reader)

# Fetch an appropriate common prefix for all input files
self.prefix = self.get_prefix(self.reader.file_paths)
self.log_prefix, self.output_prefix = self.get_prefixes(
self.reader.file_paths, self.split_output)

# Initialize the data writer, if provided
self.writer = None
if writer is not None:
assert self.loader is None or self.unwrap, (
"Must unwrap the model output to write it to file.")
self.watch.initialize('write')
self.writer = writer_factory(writer, prefix=self.prefix)
self.writer = writer_factory(
writer, prefix=self.output_prefix, split=self.split_output)

# Harmonize the iterations and epochs parameters
assert (self.iterations is None) or (self.epochs is None), (
@@ -379,17 +386,19 @@ def initialize_io(self, loader=None, reader=None, writer=None):
self.iterations = self.epochs*self.iter_per_epoch

@staticmethod
def get_prefix(file_paths):
def get_prefixes(file_paths, split_output):
"""Builds an appropriate output prefix based on the list of input files.
Parameters
----------
file_paths : List[str]
List of input file paths
split_output : bool
Split the output of the process into one file per input file
Returns
-------
str
Union[str, List[str]]
Shared input summary string to be used to prefix outputs
"""
# Fetch file base names (ignore where they live)
@@ -409,8 +418,14 @@ def get_prefix(file_paths):
last = last[0] if last[0] and last[0][0] != '.' else ''

suffix = f'{first}--{len(file_names)-2}--{last}'
log_prefix = prefix + suffix

return prefix + suffix
# Always provide a single prefix for the log, adapt output prefix
if not split_output:
return log_prefix, log_prefix
else:
return (log_prefix,
[os.path.splitext(name)[0] for name in file_names])

def initialize_log(self):
"""Initialize the output log for this driver process."""
@@ -432,7 +447,7 @@ def initialize_log(self):

# If requested, prefix the log name with the input file name
if self.prefix_log:
log_name = f'{self.prefix}_{log_name}'
log_name = f'{self.log_prefix}_{log_name}'

# Initialize the log
log_path = os.path.join(self.log_dir, log_name)
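The effect of the new `get_prefixes` can be seen in a small standalone sketch; the shared-prefix computation used for the log name is elided, and the file names are hypothetical:

import os

# Hypothetical input files
paths = ['/data/run1_file0.root', '/data/run1_file1.root']
file_names = [os.path.basename(p) for p in paths]

# With split_output=False, the log and the output share one summary prefix.
# With split_output=True, the log keeps the summary prefix while the output
# prefix becomes one entry per input file, as in the split branch above.
output_prefix = [os.path.splitext(name)[0] for name in file_names]
print(output_prefix)  # ['run1_file0', 'run1_file1']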
3 changes: 2 additions & 1 deletion spine/io/collate.py
@@ -6,7 +6,8 @@

import numpy as np

from spine import TensorBatch, IndexBatch, EdgeIndexBatch
from spine.data import TensorBatch, IndexBatch, EdgeIndexBatch

from spine.utils.geo import Geometry

__all__ = ['CollateAll']
6 changes: 4 additions & 2 deletions spine/io/factories.py
@@ -203,7 +203,7 @@ def reader_factory(reader_cfg):
return instantiate(READER_DICT, reader_cfg)


def writer_factory(writer_cfg, prefix=None):
def writer_factory(writer_cfg, prefix=None, split=False):
"""Instantiates writer based on type specified in configuration under
`io.writer.name`. The name must match the name of a class under
`spine.io.writers`.
@@ -214,6 +214,8 @@ def writer_factory(writer_cfg, prefix=None, split=False):
Writer configuration dictionary
prefix : str, optional
Input file prefix to use as an output name
split : bool, default False
Split the output into one file per input file
Returns
-------
@@ -225,4 +227,4 @@ def writer_factory(writer_cfg, prefix=None, split=False):
Currently the choice is limited to `HDF5Writer` only.
"""
# Initialize writer
return instantiate(WRITER_DICT, writer_cfg, prefix=prefix)
return instantiate(WRITER_DICT, writer_cfg, prefix=prefix, split=split)
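Combined with the driver change above, the factory call now looks roughly as follows. This is a sketch with hypothetical prefixes; the output names follow the `{prefix}_spine.h5` rule visible in `HDF5Writer.__init__` below:

# Sketch of the factory call when splitting is requested; the prefix list
# comes from Driver.get_prefixes(file_paths, split_output=True).
writer = writer_factory(
    {'name': 'hdf5', 'overwrite': True},    # writer_cfg
    prefix=['run1_file0', 'run1_file1'],    # one prefix per input file
    split=True,
)
# HDF5Writer then derives ['run1_file0_spine.h5', 'run1_file1_spine.h5']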
3 changes: 2 additions & 1 deletion spine/io/parse/cluster.py
@@ -10,7 +10,8 @@

import numpy as np

from spine import Meta
from spine.data import Meta

from spine.utils.globals import DELTA_SHP
from spine.utils.particles import process_particle_event
from spine.utils.ppn import image_coordinates
2 changes: 1 addition & 1 deletion spine/io/parse/misc.py
@@ -10,7 +10,7 @@
"""


from spine import Meta, RunInfo, Flash, CRTHit, Trigger, ObjectList
from spine.data import Meta, RunInfo, Flash, CRTHit, Trigger, ObjectList

from spine.utils.conditional import larcv

3 changes: 2 additions & 1 deletion spine/io/parse/particle.py
@@ -12,7 +12,8 @@

import numpy as np

from spine import Meta, Particle, Neutrino, ObjectList
from spine.data import Meta, Particle, Neutrino, ObjectList

from spine.utils.globals import TRACK_SHP, PDG_TO_PID, PID_MASSES
from spine.utils.particles import process_particles
from spine.utils.ppn import get_ppn_labels, image_coordinates
3 changes: 2 additions & 1 deletion spine/io/parse/sparse.py
@@ -9,7 +9,8 @@

import numpy as np

from spine import Meta
from spine.data import Meta

from spine.utils.globals import GHOST_SHP
from spine.utils.ghost import compute_rescaled_charge
from spine.utils.conditional import larcv
112 changes: 77 additions & 35 deletions spine/io/write/hdf5.py
@@ -38,7 +38,8 @@ class HDF5Writer:
name = 'hdf5'

def __init__(self, file_name=None, keys=None, skip_keys=None,
dummy_ds=None, overwrite=False, append=False, prefix=None):
dummy_ds=None, overwrite=False, append=False,
prefix=None, split=False):
"""Initializes the basics of the output file.
Parameters
@@ -59,21 +60,38 @@ def __init__(self, file_name=None, keys=None, skip_keys=None,
prefix : str, optional
Input file prefix. It will be used to form the output file name,
provided that no file_name is explicitly provided
split : bool, default False
If `True`, split the output to produce one file per input file
"""
# If the output file name is not provided, use the input file prefix
# If the output file name is not provided, use the input file prefix(es)
if not file_name:
assert prefix is not None, (
"If the output `file_name` is not provided, must provide"
"the input file `prefix` to build it from.")
file_name = prefix + '_spine.h5'
if not split:
file_name = f'{prefix}_spine.h5'
else:
file_name = [f'{pre}_spine.h5' for pre in prefix]

elif split:
file_name = [f'{file_name}_{i}' for i in range(len(prefix))]

# Check that output file does not already exist, if requested
if not overwrite and os.path.isfile(file_name):
raise FileExistsError(f"File with name {file_name} already exists.")
# Check that the output file(s) do(es) not already exist, if requested
if not overwrite:
if not split:
if os.path.isfile(file_name):
raise FileExistsError(
f"File with name {file_name} already exists.")
else:
for f in file_name:
if os.path.isfile(f):
raise FileExistsError(
f"File with name {f} already exists.")

# Store persistent attributes
self.file_name = file_name
self.append = append
self.split = split
self.ready = False
self.object_dtypes = [] # TODO: make this a set

@@ -131,20 +149,22 @@ def create(self, data, cfg=None):
for key in self.keys:
self.register_key(data, key)

# Initialize the output HDF5 file
with h5py.File(self.file_name, 'w') as out_file:
# Initialize the info dataset that stores environment parameters
if cfg is not None:
out_file.create_dataset(
'info', (0,), maxshape=(None,), dtype=None)
out_file['info'].attrs['cfg'] = yaml.dump(cfg)
out_file['info'].attrs['version'] = __version__
# Initialize the output HDF5 file(s)
file_names = [self.file_name] if not self.split else self.file_name
for file_name in file_names:
with h5py.File(file_name, 'w') as out_file:
# Initialize the info dataset that stores environment parameters
if cfg is not None:
out_file.create_dataset(
'info', (0,), maxshape=(None,), dtype=None)
out_file['info'].attrs['cfg'] = yaml.dump(cfg)
out_file['info'].attrs['version'] = __version__

# Initialize the event dataset and their reference array datasets
self.initialize_datasets(out_file)
# Initialize the event dataset and their reference array datasets
self.initialize_datasets(out_file)

# Mark file as ready for use
self.ready = True
# Mark file(s) as ready for use
self.ready = True

def get_stored_keys(self, data):
"""Get the list of data product keys to store.
@@ -432,23 +452,45 @@ def __call__(self, data, cfg=None):
self.create(data, cfg)
self.ready = True

# Append file
with h5py.File(self.file_name, 'a') as out_file:
# Loop over batch IDs
for batch_id in range(batch_size):
# Initialize a new event
event = np.empty(1, self.event_dtype)

# Initialize a dictionary of references to be passed to the
# event dataset and store the input and result keys
for key in self.keys:
self.append_key(out_file, event, data, key, batch_id)

# Append event
event_id = len(out_file['events'])
event_ds = out_file['events']
event_ds.resize(event_id + 1, axis=0) # pylint: disable=E1101
event_ds[event_id] = event
# Append file(s)
if not self.split:
with h5py.File(self.file_name, 'a') as out_file:
# Loop over batch IDs
for batch_id in range(batch_size):
self.append_entry(out_file, data, batch_id)

else:
file_ids = data['file_index']
for file_id in np.unique(file_ids):
with h5py.File(self.file_name[file_id], 'a') as out_file:
for batch_id in np.where(file_ids == file_id)[0]:
self.append_entry(out_file, data, batch_id)

def append_entry(self, out_file, data, batch_id):
"""Stores one entry.
Parameters
----------
out_file : h5py.File
HDF5 file instance
data : dict
Dictionary of data products
batch_id : int
Batch ID to be stored
"""
# Initialize a new event
event = np.empty(1, self.event_dtype)

# Initialize a dictionary of references to be passed to the
# event dataset and store the input and result keys
for key in self.keys:
self.append_key(out_file, event, data, key, batch_id)

# Append event
event_id = len(out_file['events'])
event_ds = out_file['events']
event_ds.resize(event_id + 1, axis=0) # pylint: disable=E1101
event_ds[event_id] = event

def append_key(self, out_file, event, data, key, batch_id):
"""Stores data key in a specific dataset of an HDF5 file.
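The split branch of `__call__` groups batch entries by their source file before appending. A condensed, runnable sketch of that grouping, assuming `data['file_index']` carries one input-file index per batch entry as the diff implies:

import numpy as np

# Hypothetical batch-entry -> input-file map, standing in for data['file_index']
file_ids = np.asarray([0, 0, 1, 1, 0])

for file_id in np.unique(file_ids):
    # Entries routed to the file opened as self.file_name[file_id]
    batch_ids = np.where(file_ids == file_id)[0]
    print(file_id, batch_ids)  # 0 [0 1 4], then 1 [2 3]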
3 changes: 2 additions & 1 deletion spine/model/grappa.py
@@ -8,7 +8,8 @@
from .layer.factories import final_factory
from .layer.gnn.factories import *

from spine import TensorBatch, IndexBatch, EdgeIndexBatch
from spine.data import TensorBatch, IndexBatch, EdgeIndexBatch

from spine.utils.globals import (
BATCH_COL, COORD_COLS, CLUST_COL, GROUP_COL, SHAPE_COL, LOWES_SHP)
from spine.utils.enums import enum_factory
3 changes: 2 additions & 1 deletion spine/model/layer/common/dbscan.py
@@ -4,7 +4,8 @@
import torch
from sklearn.cluster import DBSCAN as sklearn_dbscan

from spine import TensorBatch, IndexBatch
from spine.data import TensorBatch, IndexBatch

from spine.utils.globals import (
SHOWR_SHP, TRACK_SHP, MICHL_SHP, DELTA_SHP, COORD_COLS, PPN_SHAPE_COL,
COORD_START_COLS, COORD_END_COLS)
2 changes: 1 addition & 1 deletion spine/model/layer/common/final.py
@@ -4,7 +4,7 @@

from torch import nn

from spine import TensorBatch
from spine.data import TensorBatch

from .mlp import MLP
from .evidential import EvidentialModel
3 changes: 2 additions & 1 deletion spine/model/layer/gnn/encode/cnn.py
@@ -4,7 +4,8 @@

from spine.model.layer.cnn.encoder import SparseResidualEncoder

from spine import TensorBatch, IndexBatch
from spine.data import TensorBatch, IndexBatch

from spine.utils.globals import BATCH_COL

__all__ = ['ClustCNNNodeEncoder', 'ClustCNNEdgeEncoder', 'ClustCNNGlobalEncoder']
3 changes: 2 additions & 1 deletion spine/model/layer/gnn/encode/geometric.py
@@ -1,7 +1,8 @@
import numpy as np
import torch

from spine import TensorBatch
from spine.data import TensorBatch

from spine.utils.torch_local import local_cdist
from spine.utils.globals import COORD_COLS, VALUE_COL, SHAPE_COL
from spine.utils.gnn.cluster import (
(Diffs for the remaining 11 changed files not shown.)
