diff --git a/bin/larcv_check_valid.py b/bin/larcv_check_valid.py
index e95ea52c..34ecd84f 100644
--- a/bin/larcv_check_valid.py
+++ b/bin/larcv_check_valid.py
@@ -64,7 +64,7 @@ def main(source, source_list, output):
(set(keys_list[idx]) != set(all_keys))):
print(f"- Bad file: {file_path}")
out_file.write(f'{file_path}\n')
- bad_files += file_path
+ bad_files.append(file_path)
suffix = ':' if len(bad_files) > 0 else '.'
print(f"\nFound {len(bad_files)} bad files{suffix}")
diff --git a/spine/ana/base.py b/spine/ana/base.py
index 60933737..db1d3a18 100644
--- a/spine/ana/base.py
+++ b/spine/ana/base.py
@@ -25,11 +25,19 @@ class AnaBase(ABC):
units : str
Units in which the coordinates are expressed
"""
+
+ # Name of the analysis script (as specified in the configuration)
name = None
+
+ # Alternative allowed names of the analysis script
aliases = ()
- keys = None
+
+    # Units in which the analysis script expects objects to be expressed
units = 'cm'
+ # Set of data keys needed for this analysis script to operate
+ _keys = ()
+
# List of recognized object types
_obj_types = ('fragment', 'particle', 'interaction')
@@ -58,9 +66,7 @@ def __init__(self, obj_type=None, run_mode=None, append=False,
Name to prefix every output CSV file with
"""
# Initialize default keys
- if self.keys is None:
- self.keys = {}
- self.keys.update({
+ self.update_keys({
'index': True, 'file_index': True,
'file_entry_index': False, 'run_info': False
})
@@ -104,7 +110,9 @@ def __init__(self, obj_type=None, run_mode=None, append=False,
self.obj_keys = (self.fragment_keys
+ self.particle_keys
+ self.interaction_keys)
- self.keys.update({k:True for k in self.obj_keys})
+
+ # Update underlying keys, if needed
+ self.update_keys({k:True for k in self.obj_keys})
# Store the append flag
self.append_file = append
@@ -136,6 +144,42 @@ def initialize_writer(self, name):
file_name, append=self.append_file,
overwrite=self.overwrite_file)
+ @property
+ def keys(self):
+ """Dictionary of (key, necessity) pairs which determine which data keys
+        are needed/optional for the analysis script to run.
+
+ Returns
+ -------
+ Dict[str, bool]
+ Dictionary of (key, necessity) pairs to be used
+ """
+ return dict(self._keys)
+
+ @keys.setter
+ def keys(self, keys):
+ """Converts a dictionary of keys to an immutable tuple.
+
+ Parameters
+ ----------
+ Dict[str, bool]
+ Dictionary of (key, necessity) pairs to be used
+ """
+ self._keys = tuple(keys.items())
+
+ def update_keys(self, update_dict):
+ """Update the underlying set of keys and their necessity in place.
+
+ Parameters
+ ----------
+ update_dict : Dict[str, bool]
+ Dictionary of (key, necessity) pairs to update the keys with
+ """
+ if len(update_dict) > 0:
+ keys = self.keys
+ keys.update(update_dict)
+ self._keys = tuple(keys.items())
+
def get_base_dict(self, data):
"""Builds the entry information dictionary.
diff --git a/spine/ana/diag/__init__.py b/spine/ana/diag/__init__.py
new file mode 100644
index 00000000..257eaec4
--- /dev/null
+++ b/spine/ana/diag/__init__.py
@@ -0,0 +1,10 @@
+'''Diagnostic analysis scripts.
+
+This submodule is used to run basic diagnostic analyses such as:
+- Track dE/dx profile
+- Track energy reconstruction
+- Shower start dE/dx
+- ...
+'''
+
+from .shower import *
diff --git a/spine/ana/diag/shower.py b/spine/ana/diag/shower.py
new file mode 100644
index 00000000..4d708280
--- /dev/null
+++ b/spine/ana/diag/shower.py
@@ -0,0 +1,64 @@
+'''Module to evaluate diagnostic metrics on showers.'''
+
+from spine.ana.base import AnaBase
+
+__all__ = ['ShowerStartDEdxAna']
+
+
+class ShowerStartDEdxAna(AnaBase):
+ """This analysis script computes the dE/dx value within some distance
+ from the start point of an EM shower object.
+
+ This is a useful diagnostic tool to evaluate the calorimetric separability
+ of different EM shower types (electron vs photon), which are expected to
+ have different dE/dx patterns near their start point.
+ """
+
+ # Name of the analysis script (as specified in the configuration)
+ name = 'shower_start_dedx'
+
+ def __init__(self, radius, obj_type='particle', run_mode='both',
+ truth_point_mode='points', truth_dep_mode='depositions',
+ **kwargs):
+ """Initialize the analysis script.
+
+ Parameters
+ ----------
+ radius : Union[float, List[float]]
+            Radius around the start point within which to evaluate the dE/dx
+ **kwargs : dict, optional
+ Additional arguments to pass to :class:`AnaBase`
+ """
+ # Initialize the parent class
+ super().__init__(obj_type, run_mode, **kwargs)
+
+ # Store parameters
+ self.radius = radius
+
+        # Initialize one CSV writer per object key to process
+        for key in self.obj_keys:
+            self.initialize_writer(key)
+
+ def process(self, data):
+ """Evaluate shower start dE/dx for one entry.
+
+ Parameters
+ ----------
+ data : dict
+ Dictionary of data products
+ """
+
+ # Loop over all requested object types
+ for key in self.obj_keys:
+ # Loop over all objects of that type
+ for obj in data[key]:
+                # Compute the displacement vector from start to end point
+                disp = obj.end_point - obj.start_point
+
+                # Build a dictionary of scalars out of it
+ out = {'disp_x': disp[0], 'disp_y': disp[1], 'disp_z': disp[2]}
+
+ # Write the row to file
+                self.append(key, **out)
diff --git a/spine/ana/factories.py b/spine/ana/factories.py
index 558f2d74..f00b4625 100644
--- a/spine/ana/factories.py
+++ b/spine/ana/factories.py
@@ -10,7 +10,7 @@
ANA_DICT.update(**module_dict(module))
-def ana_script_factory(name, cfg, overwrite=False, log_dir=None, prefix=None):
+def ana_script_factory(name, cfg, overwrite=None, log_dir=None, prefix=None):
"""Instantiates an analyzer module from a configuration dictionary.
Parameters
@@ -22,7 +22,7 @@ def ana_script_factory(name, cfg, overwrite=False, log_dir=None, prefix=None):
parent_path : str
Path to the parent directory of the main analysis configuration. This
allows for the use of relative paths in the analyzers.
- overwrite : bool, default False
+ overwrite : bool, optional
If `True`, overwrite the CSV logs if they already exist
log_dir : str, optional
Output CSV file directory (shared with driver log)
@@ -39,5 +39,9 @@ def ana_script_factory(name, cfg, overwrite=False, log_dir=None, prefix=None):
cfg['name'] = name
# Instantiate the analysis script module
- return instantiate(
- ANA_DICT, cfg, overwrite=overwrite, log_dir=log_dir, prefix=prefix)
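+    # Only forward `overwrite` if it was explicitly provided; otherwise let
+    # the analysis script fall back to its own default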
+ if overwrite is not None:
+ return instantiate(
+ ANA_DICT, cfg, overwrite=overwrite, log_dir=log_dir, prefix=prefix)
+ else:
+ return instantiate(
+ ANA_DICT, cfg, log_dir=log_dir, prefix=prefix)
diff --git a/spine/ana/manager.py b/spine/ana/manager.py
index 11c03470..fd0bdb21 100644
--- a/spine/ana/manager.py
+++ b/spine/ana/manager.py
@@ -1,5 +1,6 @@
"""Manages the operation of analysis scripts."""
+from copy import deepcopy
from collections import defaultdict, OrderedDict
import numpy as np
@@ -35,7 +36,7 @@ def __init__(self, cfg, log_dir=None, prefix=None):
# Parse the analysis block configuration
self.parse_config(log_dir, prefix, **cfg)
- def parse_config(self, log_dir, prefix, overwrite=False,
+ def parse_config(self, log_dir, prefix, overwrite=None,
prefix_output=False, **modules):
"""Parse the analysis tool configuration.
@@ -46,7 +47,7 @@ def parse_config(self, log_dir, prefix, overwrite=False,
prefix : str
Input file prefix. If requested, it will be used to prefix
all the output CSV files.
- overwrite : bool, default False
+ overwrite : bool, optional
If `True`, overwrite the CSV logs if they already exist
prefix_output : bool, optional
If `True`, will prefix the output CSV names with the input file name
@@ -54,6 +55,7 @@ def parse_config(self, log_dir, prefix, overwrite=False,
List of analysis script modules
"""
# Loop over the analyzer modules and get their priorities
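+        # Work on a copy so the input configuration is not modified in place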
+ modules = deepcopy(modules)
keys = np.array(list(modules.keys()))
priorities = -np.ones(len(keys), dtype=np.int32)
for i, k in enumerate(keys):
diff --git a/spine/ana/metric/cluster.py b/spine/ana/metric/cluster.py
index 502d1011..fa37d45c 100644
--- a/spine/ana/metric/cluster.py
+++ b/spine/ana/metric/cluster.py
@@ -22,13 +22,16 @@ class ClusterAna(AnaBase):
- particles
- interactions
"""
+
+ # Name of the analysis script (as specified in the configuration)
name = 'cluster_eval'
# Label column to use for each clustering label_col
- _label_cols = {
- 'fragment': CLUST_COL, 'particle': GROUP_COL,
- 'interaction': INTER_COL
- }
+ _label_cols = (
+ ('fragment', CLUST_COL),
+ ('particle', GROUP_COL),
+ ('interaction', INTER_COL)
+ )
def __init__(self, obj_type=None, use_objects=False, per_object=True,
per_shape=True, metrics=('pur', 'eff', 'ari'),
@@ -53,7 +56,7 @@ def __init__(self, obj_type=None, use_objects=False, per_object=True,
label_key : str, default 'clust_label_adapt'
Name of the tensor which contains the cluster labels, when
using the raw reconstruction output
- label_col : str
+ label_col : str, optional
Column name in the label tensor specifying the aggregation label_col
**kwargs : dict, optional
Additional arguments to pass to :class:`AnaBase`
@@ -71,9 +74,9 @@ def __init__(self, obj_type=None, use_objects=False, per_object=True,
# Initialize the parent class
super().__init__(obj_type, 'both', **kwargs)
- if not use_objects:
- for key in self.obj_keys:
- del self.keys[key]
+
+        # If the clustering is not done per object, fix the target
if not per_object:
self.obj_type = [label_col]
@@ -90,27 +93,47 @@ def __init__(self, obj_type=None, use_objects=False, per_object=True,
# Convert metric strings to functions
self.metrics = {m: getattr(spine.utils.metrics, m) for m in metrics}
- # List the necessary data products
+ # If objects are not used, remove them from the required keys
+ keys = self.keys
+ if not use_objects:
+ for key in self.obj_keys:
+ del keys[key]
+
+ # List other necessary data products
if self.per_object:
if not self.use_objects:
# Store the labels and the clusters output by the reco chain
- self.keys[label_key] = True
+ keys[label_key] = True
for obj in self.obj_type:
- self.keys[f'{obj}_clusts'] = True
- self.keys[f'{obj}_shapes'] = True
+ keys[f'{obj}_clusts'] = True
+ keys[f'{obj}_shapes'] = True
else:
- self.keys['points'] = True
+ keys['points'] = True
else:
- self.keys[label_key] = True
- self.keys['clusts'] = True
- self.keys['group_pred'] = True
+ keys[label_key] = True
+ keys['clusts'] = True
+ keys['group_pred'] = True
+
+ self.keys = keys
# Initialize the output
for obj in self.obj_type:
self.initialize_writer(obj)
+ @property
+ def label_cols(self):
+ """Dictionary of (key, column_id) pairs which determine which column
+ in the label tensor corresponds to a specific clustering target.
+
+ Returns
+ -------
+ Dict[str, int]
+ Dictionary of (key, column_id) mapping from name to label column
+ """
+ return dict(self._label_cols)
+
def process(self, data):
"""Store the clustering metrics for one entry.
@@ -124,7 +147,7 @@ def process(self, data):
# Build the cluster labels for this object type
if not self.use_objects:
# Fetch the right label column
- label_col = self.label_col or self._label_cols[obj_type]
+ label_col = self.label_col or self.label_cols[obj_type]
num_points = len(data[self.label_key])
labels = data[self.label_key][:, label_col]
shapes = data[self.label_key][:, SHAPE_COL]
diff --git a/spine/ana/metric/point.py b/spine/ana/metric/point.py
index ec9f25b3..aac631a5 100644
--- a/spine/ana/metric/point.py
+++ b/spine/ana/metric/point.py
@@ -25,8 +25,13 @@ class PointProposalAna(AnaBase):
- Point type classification accuracy
- Point end classification accuracy
"""
+
+ # Name of the analysis script (as specified in the configuration)
name = 'point_eval'
+ # Set of data keys needed for this analysis script to operate
+ _keys = (('ppn_pred', True),)
+
def __init__(self, num_classes=LOWES_SHP, label_key='ppn_label',
endpoints=False, **kwargs):
"""Initialize the analysis script.
@@ -51,8 +56,7 @@ def __init__(self, num_classes=LOWES_SHP, label_key='ppn_label',
self.endpoints = endpoints
# Append other required key
- self.keys['ppn_pred'] = True
- self.keys[self.label_key] = True
+ self.update_keys({self.label_key: True})
# Initialize the output
self.initialize_writer('truth_to_reco')
diff --git a/spine/ana/metric/segment.py b/spine/ana/metric/segment.py
index 05dd0496..218d797e 100644
--- a/spine/ana/metric/segment.py
+++ b/spine/ana/metric/segment.py
@@ -16,6 +16,8 @@ class SegmentAna(AnaBase):
"""Class which computes and stores the necessary data to build a
semantic segmentation confusion matrix.
"""
+
+ # Name of the analysis script (as specified in the configuration)
name = 'segment_eval'
def __init__(self, summary=True, num_classes=GHOST_SHP, ghost=False,
@@ -70,19 +72,22 @@ def __init__(self, summary=True, num_classes=GHOST_SHP, ghost=False,
"Cannot produce ghost metrics from fragments/particles.")
# List the necessary data products
+ keys = self.keys
if not use_fragments and not use_particles:
self.obj_source = None
- self.keys[label_key] = True
- self.keys['segmentation'] = True
+ keys[label_key] = True
+ keys['segmentation'] = True
if ghost:
self.num_classes += 1
- self.keys['ghost'] = True
+ keys['ghost'] = True
else:
- self.keys['points'] = True
+ keys['points'] = True
self.obj_type = 'fragments' if use_fragments else 'particles'
for prefix in ['reco', 'truth']:
- self.keys[f'{prefix}_{self.obj_type}'] = True
+ keys[f'{prefix}_{self.obj_type}'] = True
+
+ self.keys = keys
# Initialize the output
if summary:
diff --git a/spine/ana/script/event.py b/spine/ana/script/event.py
index 7bb2fad6..cb8b73d4 100644
--- a/spine/ana/script/event.py
+++ b/spine/ana/script/event.py
@@ -14,9 +14,14 @@
class EventAna(AnaBase):
"""Class which saves basic event information (and their matches)."""
+
+ # Name of the analysis script (as specified in the configuration)
name = 'event'
- keys = {'index': True, 'reco_particles': True,
- 'truth_particles': True, 'run_info': False}
+
+ # Set of data keys needed for this analysis script to operate
+ _keys = (
+ ('reco_particles', True), ('truth_particles', True)
+ )
def __init__(self, **kwargs):
"""Initialize the CSV event logging class.
diff --git a/spine/ana/script/save.py b/spine/ana/script/save.py
index 39841e38..b4f4ce7c 100644
--- a/spine/ana/script/save.py
+++ b/spine/ana/script/save.py
@@ -19,17 +19,17 @@ class SaveAna(AnaBase):
_match_modes = (None, 'reco_to_truth', 'truth_to_reco', 'both', 'all')
# Default object types when a match is not found
- _default_objs = {
- 'reco_fragments': RecoFragment(),
- 'truth_fragments': TruthFragment(),
- 'reco_particles': RecoParticle(),
- 'truth_particles': TruthParticle(),
- 'reco_interactions': RecoInteraction(),
- 'truth_interactions': TruthInteraction()
- }
+ _default_objs = (
+ ('reco_fragments', RecoFragment()),
+ ('truth_fragments', TruthFragment()),
+ ('reco_particles', RecoParticle()),
+ ('truth_particles', TruthParticle()),
+ ('reco_interactions', RecoInteraction()),
+ ('truth_interactions', TruthInteraction())
+ )
def __init__(self, obj_type, fragment=None, particle=None, interaction=None,
- run_mode='both', match_mode='both', **kwargs):
+ lengths=None, run_mode='both', match_mode='both', **kwargs):
"""Initialize the CSV logging class.
If any of the `fragments`, `particles` or `interactions` are specified
@@ -40,8 +40,14 @@ def __init__(self, obj_type, fragment=None, particle=None, interaction=None,
----------
obj_type : Union[str, List[str]], default ['particle', 'interaction']
Objects to build files from
- attrs : List[str]
- List of object attributes to store
+ fragment : List[str], optional
+ List of fragment attributes to store
+ particle : List[str], optional
+ List of particle attributes to store
+ interaction : List[str], optional
+ List of interaction attributes to store
+ lengths : Dict[str, int], optional
+ Lengths to use for variable-length object attributes
match_mode : str, default 'both'
If reconstructed and truth are available, specified which matching
direction(s) should be saved to the log file.
@@ -60,6 +66,9 @@ def __init__(self, obj_type, fragment=None, particle=None, interaction=None,
"When storing matches, you must load both reco and truth "
f"objects, i.e. set `run_mode` to `True`. Got {run_mode}.")
+ # Store default objects as a dictionary
+ self.default_objs = dict(self._default_objs)
+
# Store the list of attributes to store for each object type
attrs = {
'fragments': fragment,
@@ -80,7 +89,7 @@ def __init__(self, obj_type, fragment=None, particle=None, interaction=None,
for run_mode in ['reco', 'truth']:
key = f'{run_mode}_{obj_t}'
if attrs[obj_t] is not None:
- all_keys = self._default_objs[key].as_dict().keys()
+ all_keys = self.default_objs[key].as_dict().keys()
self.attrs[key] = set(attrs[obj_t]) & set(all_keys)
leftover -= (leftover & self.attrs[key])
@@ -92,16 +101,22 @@ def __init__(self, obj_type, fragment=None, particle=None, interaction=None,
"The following keys were not found in either the reco "
f"or the truth {obj_t} : {leftover}")
+ # Store the list of variable-length array lengths
+ self.lengths = lengths
+
# Add the necessary keys associated with matching, if needed
+ keys = {}
if match_mode is not None:
for prefix in self.prefixes:
for obj_name in obj_type:
if prefix == 'reco' and match_mode != 'truth_to_reco':
- self.keys[f'{obj_name}_matches_r2t'] = True
- self.keys[f'{obj_name}_matches_r2t_overlap'] = True
+ keys[f'{obj_name}_matches_r2t'] = True
+ keys[f'{obj_name}_matches_r2t_overlap'] = True
if prefix == 'truth' and match_mode != 'reco_to_truth':
- self.keys[f'{obj_name}_matches_t2r'] = True
- self.keys[f'{obj_name}_matches_t2r_overlap'] = True
+ keys[f'{obj_name}_matches_t2r'] = True
+ keys[f'{obj_name}_matches_t2r_overlap'] = True
+
+ self.update_keys(keys)
# Initialize one CSV writer per object type
for key in self.obj_keys:
@@ -125,11 +140,13 @@ def process(self, data):
other = other_prefix[prefix]
attrs = self.attrs[key]
attrs_other = self.attrs[f'{other}_{obj_type}']
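+            # The same length specification applies to both matched objects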
+ lengths = self.lengths
+ lengths_other = self.lengths
if (self.match_mode is None or
self.match_mode == f'{other}_to_{prefix}'):
# If there is no matches, save objects by themselves
for i, obj in enumerate(data[key]):
- self.append(key, **obj.scalar_dict(attrs))
+ self.append(key, **obj.scalar_dict(attrs, lengths))
else:
# If there are matches, combine the objects with their best
@@ -137,12 +154,12 @@ def process(self, data):
match_suffix = f'{prefix[0]}2{other[0]}'
match_key = f'{obj_type[:-1]}_matches_{match_suffix}'
for idx, (obj_i, obj_j) in enumerate(data[match_key]):
- src_dict = obj_i.scalar_dict(attrs)
+ src_dict = obj_i.scalar_dict(attrs, lengths)
if obj_j is not None:
- tgt_dict = obj_j.scalar_dict(attrs_other)
+ tgt_dict = obj_j.scalar_dict(attrs_other, lengths_other)
else:
- default_obj = self._default_objs[f'{other}_{obj_type}']
- tgt_dict = default_obj.scalar_dict(attrs_other)
+ default_obj = self.default_objs[f'{other}_{obj_type}']
+ tgt_dict = default_obj.scalar_dict(attrs_other, lengths_other)
src_dict = {f'{prefix}_{k}':v for k, v in src_dict.items()}
tgt_dict = {f'{other}_{k}':v for k, v in tgt_dict.items()}
diff --git a/spine/ana/template.py b/spine/ana/template.py
index 361c1f45..9e6712ab 100644
--- a/spine/ana/template.py
+++ b/spine/ana/template.py
@@ -19,7 +19,9 @@
class TemplateAna(AnaBase):
"""Description of what the analysis script is supposed to be doing."""
- name = 'template' # Name used to call the analysis script in the config
+
+ # Name of the analysis script (as specified in the configuration)
+ name = 'template'
def __init__(self, arg0, arg1, obj_type, run_mode, append_file,
overwrite_file, output_prefix):
@@ -56,7 +58,7 @@ def __init__(self, arg0, arg1, obj_type, run_mode, append_file,
self.initialize_writer('template')
# Add additional required data products
- self.keys['prod'] = True # Means we must have 'data' in the dictionary
+ self.update_keys({'prod': True}) # Means we must have 'prod' in the dictionary
def process(self, data):
"""Pass data products corresponding to one entry through the analysis.
diff --git a/spine/build/base.py b/spine/build/base.py
index 7769c218..2d69119f 100644
--- a/spine/build/base.py
+++ b/spine/build/base.py
@@ -12,58 +12,42 @@ class BuilderBase(ABC):
A Builder class takes input data and full chain result dictionaries
and processes them into human-readable data structures.
-
- Attributes
- ----------
- name : str
- Name of the builder (to call it from a configuration file)
- reco_type : object
- Data class representation object of a reconstructed object
- truth_type : object
- Data class representation object of a truth object
- build_reco_keys : Dict[str, bool]
- Dictionary of keys used to build the reconstructed objects from the
- data products and whether they are essential or not
- build_truth_keys : Dict[str, bool]
- Dictionary of keys used to build the truth objects from the
- data products and whether they are essential or not
- load_reco_keys : Dict[str, bool]
- Dictionary of keys used to load the reconstructed objects from
- existing stored objects and whether they are essential or not
- load_truth_keys : Dict[str, bool]
- Dictionary of keys used to load the truth objects from
- existing stored objects and whether they are essential or not
"""
- name = None
-
- reco_type = None
- truth_type = None
-
- build_reco_keys = {
- 'points': True, 'depositions': True, 'sources': False
- }
-
- build_truth_keys = {
- 'label_tensor': True, 'label_adapt_tensor': True,
- 'label_g4_tensor': False, 'points': True, 'points_label': True,
- 'points_g4': False, 'depositions': True, 'depositions_label': True,
- 'depositions_q_label': False, 'depositions_g4': False,
- 'sources': False, 'sources_label': False
- }
- load_reco_keys = {
- 'points': True, 'depositions': True, 'sources': False
- }
-
- load_truth_keys = {
- 'points': True, 'points_label': True, 'points_g4': False,
- 'depositions': True, 'depositions_label': True,
- 'depositions_q_label': False, 'depositions_g4': False,
- 'sources': False, 'sources_label': False
- }
+ # Builder name
+ name = None
- # List of recognized run modes
- _run_modes = ['reco', 'truth', 'both', 'all']
+ # Types of objects constructed by the builder
+ _reco_type = None
+ _truth_type = None
+
+ # Necessary/optional data products to build a reconstructed object
+ _build_reco_keys = (
+ ('points', True), ('depositions', True), ('sources', False)
+ )
+
+ # Necessary/optional data products to build a truth object
+ _build_truth_keys = (
+ ('label_tensor', True), ('label_adapt_tensor', True),
+ ('label_g4_tensor', False), ('points', True),
+ ('points_label', True), ('points_g4', False), ('depositions', True),
+ ('depositions_label', True), ('depositions_q_label', False),
+ ('depositions_g4', False), ('sources', False),
+ ('sources_label', False)
+ )
+
+ # Necessary/optional data products to load a reconstructed object
+ _load_reco_keys = (
+ ('points', True), ('depositions', True), ('sources', False)
+ )
+
+ # Necessary/optional data products to load a truth object
+ _load_truth_keys = (
+ ('points', True), ('points_label', True), ('points_g4', False),
+ ('depositions', True), ('depositions_label', True),
+ ('depositions_q_label', False), ('depositions_g4', False),
+ ('sources', False), ('sources_label', False)
+ )
def __init__(self, mode, units):
"""Initializes the builder.
@@ -77,13 +61,8 @@ def __init__(self, mode, units):
Units in which the position arguments of the constructed objects
should be expressed (one of 'cm' or 'px')
"""
- # Check on the mode, store it
- assert mode in self._run_modes, (
- f"Run mode not recognized: {mode}. Must be one of 'reco', "
- "'truth', 'both' or 'all'.")
+ # Store the mode and units
self.mode = mode
-
- # Store the target units
self.units = units
def __call__(self, data):
@@ -158,7 +137,7 @@ def check_units(self, data, key, entry=None):
def construct(self, func, data, entry=None):
"""Prepares the input based on the required data and runs constructor.
-
+
Parameters
----------
func : str
@@ -176,8 +155,8 @@ def construct(self, func, data, entry=None):
# Get the description of the fields needed by this source object
input_data = {}
method, dtype = func.split('_')
- keys = getattr(self, f'{func}_keys')
- for key, req in keys.items():
+ keys = getattr(self, f'_{func}_keys')
+ for key, req in keys:
# If the field has no default value, must be provided
if req and key not in data:
raise KeyError(
@@ -191,7 +170,7 @@ def construct(self, func, data, entry=None):
input_data[key] = data[key]
obj_list = getattr(self, func)(input_data)
- default = getattr(self, f'{dtype}_type')()
+ default = getattr(self, f'_{dtype}_type')()
return ObjectList(obj_list, default)
diff --git a/spine/build/fragment.py b/spine/build/fragment.py
index 9b97e1a3..c78ba808 100644
--- a/spine/build/fragment.py
+++ b/spine/build/fragment.py
@@ -1,4 +1,4 @@
-"""Classes in charge of constructing FragmentBase objects."""
+"""Classes in charge of constructing Fragment objects."""
from typing import List
from dataclasses import dataclass
@@ -24,32 +24,39 @@ class FragmentBuilder(BuilderBase):
necessary information and builds :class:`RecoFragment` and
:class:`TruthFragment` objects from it.
"""
+
+ # Builder name
name = 'fragment'
- reco_type = RecoFragment
- truth_type = TruthFragment
-
- build_reco_keys = {
- 'fragment_clusts': True, 'fragment_shapes': True,
- 'fragment_start_points': False, 'fragment_end_points': False,
- 'fragment_group_pred': False, 'fragment_node_pred': False,
- **BuilderBase.build_reco_keys
- }
-
- build_truth_keys = {
- 'particles': False,
- **BuilderBase.build_truth_keys
- }
-
- load_reco_keys = {
- 'reco_fragments': True,
- **BuilderBase.load_reco_keys
- }
-
- load_truth_keys = {
- 'truth_fragments': True,
- **BuilderBase.load_truth_keys
- }
+ # Types of objects constructed by the builder
+ _reco_type = RecoFragment
+ _truth_type = TruthFragment
+
+ # Necessary/optional data products to build a reconstructed object
+ _build_reco_keys = (
+ ('fragment_clusts', True), ('fragment_shapes', True),
+ ('fragment_start_points', False), ('fragment_end_points', False),
+ ('fragment_group_pred', False), ('fragment_node_pred', False),
+ *BuilderBase._build_reco_keys
+ )
+
+ # Necessary/optional data products to build a truth object
+ _build_truth_keys = (
+ ('particles', False),
+ *BuilderBase._build_truth_keys
+ )
+
+ # Necessary/optional data products to load a reconstructed object
+ _load_reco_keys = (
+ ('reco_fragments', True),
+ *BuilderBase._load_reco_keys
+ )
+
+ # Necessary/optional data products to load a truth object
+ _load_truth_keys = (
+ ('truth_fragments', True),
+ *BuilderBase._load_truth_keys
+ )
def build_reco(self, data):
"""Builds :class:`RecoFragment` objects from the full chain output.
diff --git a/spine/build/interaction.py b/spine/build/interaction.py
index a7ac4a3a..a4eb315a 100644
--- a/spine/build/interaction.py
+++ b/spine/build/interaction.py
@@ -1,4 +1,4 @@
-"""Class in charge of constructing *Interaction objects."""
+"""Class in charge of constructing Interaction objects."""
from collections import defaultdict
from warnings import warn
@@ -19,26 +19,33 @@ class InteractionBuilder(BuilderBase):
necessary information and builds :class:`RecoInteraction` and
:class:`TruthInteraction` objects from it.
"""
- name = 'interaction'
-
- reco_type = RecoInteraction
- truth_type = TruthInteraction
-
- build_reco_keys = {
- 'reco_particles': True
- }
- build_truth_keys = {
- 'truth_particles': True, 'neutrinos': False
- }
-
- load_reco_keys = {
- 'reco_interactions': True, 'reco_particles': True
- }
+ # Builder name
+ name = 'interaction'
- load_truth_keys = {
- 'truth_interactions': True, 'truth_particles': True
- }
+ # Types of objects constructed by the builder
+ _reco_type = RecoInteraction
+ _truth_type = TruthInteraction
+
+ # Necessary/optional data products to build a reconstructed object
+ _build_reco_keys = (
+ ('reco_particles', True),
+ )
+
+ # Necessary/optional data products to build a truth object
+ _build_truth_keys = (
+ ('truth_particles', True), ('neutrinos', False)
+ )
+
+ # Necessary/optional data products to load a reconstructed object
+ _load_reco_keys = (
+ ('reco_interactions', True), ('reco_particles', True)
+ )
+
+ # Necessary/optional data products to load a truth object
+ _load_truth_keys = (
+ ('truth_interactions', True), ('truth_particles', True)
+ )
def build_reco(self, data):
"""Builds :class:`RecoInteraction` objects from the full chain output.
diff --git a/spine/build/manager.py b/spine/build/manager.py
index 8fda8005..9b94ade3 100644
--- a/spine/build/manager.py
+++ b/spine/build/manager.py
@@ -19,19 +19,28 @@ class BuildManager:
- Interpret the raw output of the reconstruction chain
- Load up existing objects stored as dictionaries
"""
+
+ # List of recognized run modes
+ _run_modes = ('reco', 'truth', 'both', 'all')
+
+ # List of recognized units
+ _units = ('cm', 'px')
+
# Name of input data products needed to build representations. These names
- # are not set in stone, so they can be set in the configuration
- sources = {
- 'data_tensor': ['data_adapt', 'data'],
- 'label_tensor': 'clust_label',
- 'label_adapt_tensor': ['clust_label_adapt', 'clust_label'],
- 'label_g4_tensor': 'clust_label_g4',
- 'depositions_q_label': 'charge_label',
- 'sources': ['sources_adapt', 'sources'],
- 'sources_label': 'sources_label',
- 'particles': 'particles',
- 'neutrinos': 'neutrinos'
- }
+ # are not set in stone; they can be set in the configuration
+ _sources = (
+ ('data_tensor', ('data_adapt', 'data')),
+ ('label_tensor', ('clust_label',)),
+ ('label_adapt_tensor', ('clust_label_adapt', 'clust_label')),
+ ('label_g4_tensor', ('clust_label_g4',)),
+ ('depositions_q_label', ('charge_label',)),
+ ('sources', ('sources_adapt', 'sources')),
+ ('sources_label', ('sources_label',)),
+ ('particles', ('particles',)),
+ ('neutrinos', ('neutrinos',)),
+ ('flashes', ('flashes',)),
+ ('crthits', ('crthits',))
+ )
def __init__(self, fragments, particles, interactions,
mode='both', units='cm', sources=None):
@@ -51,26 +60,29 @@ def __init__(self, fragments, particles, interactions,
Dictionary which maps the necessary data products onto a name
in the input/output dictionary of the reconstruction chain.
"""
- # Check on the mode, store desired units
- assert mode in ['reco', 'truth', 'both', 'all'], (
- f"Run mode not recognized: {mode}. Must be one of 'reco', "
- "'truth', 'both' or 'all'.")
+ # Check on the mode, store it
+ assert mode in self._run_modes, (
+ f"Run mode not recognized: {mode}. Must be one {self._run_modes}")
self.mode = mode
+
+ # Check on the units, store them
+ assert units in self._units, (
+ f"Units not recognized: {units}. Must be one {self._units}")
self.units = units
- # Parse the build sources based on defaults
+ # If custom sources are provided, update the tuple
if sources is not None:
+ sources_dict = dict(self._sources)
for key, value in sources.items():
- assert key in self.sources, (
+ assert key in sources_dict, (
"Unexpected data product specified in `sources`: "
- f"{key}. Should be one of {list(self.sources.keys())}.")
- self.sources.update(**sources)
+ f"{key}. Should be one of {list(sources_dict.keys())}.")
+ if isinstance(value, str):
+ sources_dict[key] = (value,)
+ else:
+ sources_dict[key] = tuple(value)
- for key, value in self.sources.items():
- if isinstance(value, str):
- self.sources[key] = [value]
- else:
- self.sources[key] = value
+ self._sources = tuple(sources_dict.items())
# Initialize the builders
self.builders = OrderedDict()
@@ -79,10 +91,9 @@ def __init__(self, fragments, particles, interactions,
if particles:
self.builders['particle'] = ParticleBuilder(mode, units)
if interactions:
- assert particles, (
- "Interactions are built from particles. If "
- "`interactions` is True, so must "
- "`particles` be.")
+ assert particles is not None, (
+ "Interactions are built from particles. If `interactions` "
+ "is True, so must `particles` be.")
self.builders['interaction'] = InteractionBuilder(mode, units)
assert len(self.builders), (
@@ -150,7 +161,7 @@ def build_sources(self, data, entry=None):
"""
# Fetch the orginal sources
sources = {}
- for key, alt_keys in self.sources.items():
+ for key, alt_keys in self._sources:
for alt in alt_keys:
if alt in data:
sources[key] = data[alt]
diff --git a/spine/build/particle.py b/spine/build/particle.py
index f4a34368..1d3744c4 100644
--- a/spine/build/particle.py
+++ b/spine/build/particle.py
@@ -1,4 +1,4 @@
-"""Classes in charge of constructing ParticleBase objects."""
+"""Classes in charge of constructing Particle objects."""
import numpy as np
@@ -19,34 +19,41 @@ class ParticleBuilder(BuilderBase):
necessary information and builds :class:`RecoParticle` and
:class:`TruthParticle` objects from it.
"""
+
+ # Builder name
name = 'particle'
- reco_type = RecoParticle
- truth_type = TruthParticle
-
- build_reco_keys = {
- 'particle_clusts': True, 'particle_shapes': True,
- 'particle_start_points': True, 'particle_end_points': True,
- 'particle_group_pred': True, 'particle_node_type_pred': True,
- 'particle_node_primary_pred': True,
- 'particle_node_orient_pred': False, 'reco_fragments': False,
- **BuilderBase.build_reco_keys
- }
-
- build_truth_keys = {
- 'particles': False,
- **BuilderBase.build_truth_keys
- }
-
- load_reco_keys = {
- 'reco_particles': True,
- **BuilderBase.load_reco_keys
- }
-
- load_truth_keys = {
- 'truth_particles': True,
- **BuilderBase.load_truth_keys
- }
+ # Types of objects constructed by the builder
+ _reco_type = RecoParticle
+ _truth_type = TruthParticle
+
+ # Necessary/optional data products to build a reconstructed object
+ _build_reco_keys = (
+ ('particle_clusts', True), ('particle_shapes', True),
+ ('particle_start_points', True), ('particle_end_points', True),
+ ('particle_group_pred', True), ('particle_node_type_pred', True),
+ ('particle_node_primary_pred', True),
+ ('particle_node_orient_pred', False), ('reco_fragments', False),
+ *BuilderBase._build_reco_keys
+ )
+
+ # Necessary/optional data products to build a truth object
+ _build_truth_keys = (
+ ('particles', False),
+ *BuilderBase._build_truth_keys
+ )
+
+ # Necessary/optional data products to load a reconstructed object
+ _load_reco_keys = (
+ ('reco_particles', True),
+ *BuilderBase._load_reco_keys
+ )
+
+ # Necessary/optional data products to load a truth object
+ _load_truth_keys = (
+ ('truth_particles', True),
+ *BuilderBase._load_truth_keys
+ )
def build_reco(self, data):
"""Builds :class:`RecoParticle` objects from the full chain output.
@@ -209,7 +216,7 @@ def _build_truth(self, label_tensor, label_adapt_tensor, particles, points,
particle = TruthParticle(**particles[group_id].as_dict())
assert particle.id == group_id, (
"The ordering of the true particle is wrong.")
-
+
# Override the index of the particle but preserve it
particle.orig_id = group_id
particle.id = i
diff --git a/spine/data/base.py b/spine/data/base.py
index a7d7d270..de1b7e91 100644
--- a/spine/data/base.py
+++ b/spine/data/base.py
@@ -13,37 +13,37 @@ class DataBase:
"""
# Enumerated attributes
- _enum_attrs = {}
+ _enum_attrs = ()
- # Fixed-length attributes as (key, size) pairs
- _fixed_length_attrs = {}
+ # Fixed-length attributes as (key, size) or (key, (size, dtype)) pairs
+ _fixed_length_attrs = ()
- # Variable-length attributes as (key, dtype) pairs
- _var_length_attrs = {}
-
- # Attributes to be binarized to form an integer from a variable-length array
- _binarize_attrs = []
+ # Variable-length attributes as (key, dtype) or (key, (width, dtype)) pairs
+ _var_length_attrs = ()
# Attributes specifying coordinates
- _pos_attrs = []
+ _pos_attrs = ()
# Attributes specifying vector components
- _vec_attrs = []
+ _vec_attrs = ()
# String attributes
- _str_attrs = []
+ _str_attrs = ()
# Boolean attributes
- _bool_attrs = []
+ _bool_attrs = ()
# Attributes to concatenate when merging objects
- _cat_attrs = []
+ _cat_attrs = ()
+
+ # Attributes that must never be stored to file
+ _skip_attrs = ()
- # Attributes that should not be stored to file (long-form attributes)
- _skip_attrs = []
+ # Attributes that must not be stored to file when storing lite files
+ _lite_skip_attrs = ()
# Euclidean axis labels
- _axes = ['x', 'y', 'z']
+ _axes = ('x', 'y', 'z')
def __post_init__(self):
"""Immediately called after building the class attributes.
@@ -56,7 +56,7 @@ def __post_init__(self):
format one gets when loading string from HDF5 files.
"""
# Provide default values to the variable-length array attributes
- for attr, dtype in self._var_length_attrs.items():
+ for attr, dtype in self._var_length_attrs:
if getattr(self, attr) is None:
if not isinstance(dtype, tuple):
setattr(self, attr, np.empty(0, dtype=dtype))
@@ -65,7 +65,7 @@ def __post_init__(self):
setattr(self, attr, np.empty((0, width), dtype=dtype))
# Provide default values to the fixed-length array attributes
- for attr, size in self._fixed_length_attrs.items():
+ for attr, size in self._fixed_length_attrs:
if getattr(self, attr) is None:
if not isinstance(size, tuple):
dtype = np.float32
@@ -103,7 +103,7 @@ def __eq__(self, other):
if self.__class__ != other.__class__:
return False
- # Check that all attributes are identical
+ # Check that all base attributes are identical
for k, v in self.__dict__.items():
if np.isscalar(v):
# For scalars, regular comparison will do
@@ -133,17 +133,28 @@ def set_precision(self, precision):
dtype = f'{val.dtype.str[:-1]}{precision}'
setattr(self, attr, val.astype(dtype))
- def as_dict(self):
+ def as_dict(self, lite=False):
"""Returns the data class as dictionary of (key, value) pairs.
+ Parameters
+ ----------
+ lite : bool, default False
+ If `True`, the `_lite_skip_attrs` are dropped
+
Returns
-------
dict
Dictionary of attribute names and their values
"""
- return {k: v for k, v in asdict(self).items() if not k in self._skip_attrs}
+ # Build a list of attributes to skip
+ if not lite:
+ skip_attrs = self._skip_attrs
+ else:
+ skip_attrs = (*self._skip_attrs, *self._lite_skip_attrs)
- def scalar_dict(self, attrs=None):
+ return {k: v for k, v in asdict(self).items() if not k in skip_attrs}
+
+ def scalar_dict(self, attrs=None, lengths=None, lite=False):
"""Returns the data class attributes as a dictionary of scalars.
This is useful when storing data classes in CSV files, which expect
@@ -154,40 +165,54 @@ def scalar_dict(self, attrs=None):
attrs : List[str], optional
List of attribute names to include in the dictionary. If not
specified, all the keys are included.
+ lengths : Dict[str, int], optional
+ Specifies the length of variable-length attributes
+ lite : bool, default False
+ If `True`, the `_lite_skip_attrs` are dropped
"""
# Loop over the attributes of the data class
+ lengths = lengths or {}
scalar_dict, found = {}, []
- for attr, value in self.as_dict().items():
+ for attr, value in self.as_dict(lite).items():
# If the attribute is not requested, skip
if attrs is not None and attr not in attrs:
continue
else:
found.append(attr)
- # If the attribute is long-form attribute, skip it
- if (attr not in self._binarize_attrs and
- (attr in self._skip_attrs or attr in self._var_length_attrs)):
- continue
-
# Dispatch
if np.isscalar(value):
# If the attribute is a scalar, store as is
scalar_dict[attr] = value
- elif attr in self._binarize_attrs:
- # If the list is to be binarized, do it
- scalar_dict[attr] = int(np.sum(2**value))
-
elif attr in (self._pos_attrs + self._vec_attrs):
# If the attribute is a position or vector, expand with axis
for i, v in enumerate(value):
scalar_dict[f'{attr}_{self._axes[i]}'] = v
- elif attr in self._fixed_length_attrs:
- # If the attribute is a fixed length array, expand with index
+ elif attr in self.fixed_length_attrs:
+ # If the attribute is a fixed-length array, expand with index
for i, v in enumerate(value):
scalar_dict[f'{attr}_{i}'] = v
+ elif attr in self.var_length_attrs:
+ if attr in lengths:
+ # If the attribute is a variable-length array with a length
+ # provided, resize it to match that length and store it
+ for i in range(lengths[attr]):
+ if i < len(value):
+ scalar_dict[f'{attr}_{i}'] = value[i]
+ else:
+ scalar_dict[f'{attr}_{i}'] = None
+
+ else:
+ # If the attribute is a variable-length array of
+ # indeterminate length, do not store it
+ assert attrs is None or attr not in attrs, (
+ f"Cannot cast {attr} to scalars. To cast a variable-"
+ "length array, must provide a fixed length.")
+ continue
+
else:
raise ValueError(
f"Cannot expand the `{attr}` attribute of "
@@ -203,36 +228,36 @@ def scalar_dict(self, attrs=None):
@property
def fixed_length_attrs(self):
- """Fetches the dictionary of fixed-length array attributes.
+ """Fetches the dictionary of fixed-length array attributes as a dictionary.
Returns
-------
Dict[str, int]
- Dictioary which maps fixed-length attributes onto their length
+ Dictionary which maps fixed-length attributes onto their length
"""
- return self._fixed_length_attrs
+ return dict(self._fixed_length_attrs)
@property
def var_length_attrs(self):
- """Fetches the list of variable-length array attributes.
+ """Fetches the list of variable-length array attributes as a dictionary.
Returns
-------
Dict[str, type]
Dictionary which maps variable-length attributes onto their type
"""
- return self._fixed_length_attrs
+ return dict(self._var_length_attrs)
@property
def enum_attrs(self):
- """Fetches the list of enumerated attributes.
+ """Fetches the list of enumerated attributes as a dictionary.
Returns
-------
Dict[int, Dict[int, str]]
Dictionary which maps names onto enumerator descriptors
"""
- return self._enum_attrs
+ return {k: dict(v) for k, v in self._enum_attrs}
@property
def skip_attrs(self):
@@ -245,6 +270,17 @@ def skip_attrs(self):
"""
return self._skip_attrs
+ @property
+ def lite_skip_attrs(self):
+ """Fetches the list of attributes to not store to lite file.
+
+ Returns
+ -------
+ List[str]
+ List of attributes to exclude from the storage process
+ """
+ return self._lite_skip_attrs
+
@dataclass(eq=False)
class PosDataBase(DataBase):
diff --git a/spine/data/crt.py b/spine/data/crt.py
index 0cafd1bb..683f2d21 100644
--- a/spine/data/crt.py
+++ b/spine/data/crt.py
@@ -63,19 +63,19 @@ class CRTHit(PosDataBase):
units: str = 'cm'
# Fixed-length attributes
- _fixed_length_attrs = {'center': 3, 'width': 3}
+ _fixed_length_attrs = (('center', 3), ('width', 3))
# Variable-length attributes
- _var_length_attrs = {'feb_id': np.ubyte}
+ _var_length_attrs = (('feb_id', np.ubyte),)
# Attributes specifying coordinates
- _pos_attrs = ['position']
+ _pos_attrs = ('position',)
# Attributes specifying vector components
- _vec_attrs = ['width']
+ _vec_attrs = ('width',)
# String attributes
- _str_attrs = ['tagger', 'units']
+ _str_attrs = ('tagger', 'units')
@classmethod
def from_larcv(cls, crthit):
@@ -92,7 +92,7 @@ def from_larcv(cls, crthit):
CRT hit object
"""
# Get the physical center and width of the CRT hit
- axes = ['x', 'y', 'z']
+ axes = ('x', 'y', 'z')
center = np.array([getattr(crthit, f'{a}_pos')() for a in axes])
width = np.array([getattr(crthit, f'{a}_err')() for a in axes])
@@ -100,7 +100,8 @@ def from_larcv(cls, crthit):
feb_id = np.array([ord(c) for c in crthit.feb_id()], dtype=np.ubyte)
# Get the number of PEs per FEB channel
- # TODO: This is a dictionary of dictionaries, hard to store
+ # TODO: This is a dictionary of dictionaries, need to figure out
+ # how to unpack in a sensible manner
return cls(id=crthit.id(), plane=crthit.plane(),
tagger=crthit.tagger(), feb_id=feb_id, ts0_s=crthit.ts0_s(),
diff --git a/spine/data/meta.py b/spine/data/meta.py
index b64a8120..0ef39487 100644
--- a/spine/data/meta.py
+++ b/spine/data/meta.py
@@ -34,12 +34,12 @@ class Meta(DataBase):
count: np.ndarray = None
# Fixed-length attributes
- _fixed_length_attrs = {
- 'lower': 3, 'upper': 3, 'size': 3, 'count': (3, np.int64)
- }
+ _fixed_length_attrs = (
+ ('lower', 3), ('upper', 3), ('size', 3), ('count', (3, np.int64))
+ )
# Attributes specifying vector components
- _vec_attrs = ['lower', 'upper', 'size', 'count']
+ _vec_attrs = ('lower', 'upper', 'size', 'count')
@property
def dimension(self):
diff --git a/spine/data/neutrino.py b/spine/data/neutrino.py
index 67b2e3fd..e66532e1 100644
--- a/spine/data/neutrino.py
+++ b/spine/data/neutrino.py
@@ -63,6 +63,8 @@ class Neutrino(PosDataBase):
Energy transfer (Q0) in GeV
lepton_p : float
Absolute momentum of the lepton
+ distance_travel : float
+        True distance traveled by the neutrino before interacting
theta : float
Angle between incoming and outgoing leptons in radians
creation_process : str
@@ -96,6 +98,7 @@ class Neutrino(PosDataBase):
momentum_transfer_mag: float = -1.
energy_transfer: float = -1.
lepton_p: float = -1.
+ distance_travel: float = -1.
theta: float = -1.
creation_process: str = ''
position: np.ndarray = None
@@ -103,23 +106,23 @@ class Neutrino(PosDataBase):
units: str = 'cm'
# Fixed-length attributes
- _fixed_length_attrs = {'position': 3, 'momentum': 3}
+ _fixed_length_attrs = (('position', 3), ('momentum', 3))
# Attributes specifying coordinates
- _pos_attrs = ['position']
+ _pos_attrs = ('position',)
# Attributes specifying vector components
- _vec_attrs = ['momentum']
+ _vec_attrs = ('momentum',)
# Enumerated attributes
- _enum_attrs = {
- 'current_type': {v : k for k, v in NU_CURR_TYPE.items()},
- 'interaction_mode': {v : k for k, v in NU_INT_TYPE.items()},
- 'interaction_type': {v : k for k, v in NU_INT_TYPE.items()}
- }
+ _enum_attrs = (
+ ('current_type', tuple((v, k) for k, v in NU_CURR_TYPE.items())),
+ ('interaction_mode', tuple((v, k) for k, v in NU_INT_TYPE.items())),
+ ('interaction_type', tuple((v, k) for k, v in NU_INT_TYPE.items()))
+ )
# String attributes
- _str_attrs = ['creation_process']
+ _str_attrs = ('creation_process',)
@classmethod
def from_larcv(cls, neutrino):
@@ -139,13 +142,14 @@ def from_larcv(cls, neutrino):
obj_dict = {}
# Load the scalar attributes
- for key in ['id', 'interaction_id', 'mct_index', 'nu_track_id',
+ for key in ('id', 'interaction_id', 'mct_index', 'nu_track_id',
'lepton_track_id', 'pdg_code', 'lepton_pdg_code',
'current_type', 'interaction_mode', 'interaction_type',
'target', 'nucleon', 'quark', 'energy_init',
'hadronic_invariant_mass', 'bjorken_x', 'inelasticity',
'momentum_transfer', 'momentum_transfer_mag',
- 'energy_transfer', 'lepton_p', 'theta', 'creation_process']:
+ 'energy_transfer', 'lepton_p', 'distance_travel',
+ 'theta', 'creation_process'):
if not hasattr(neutrino, key):
warn(f"The LArCV Neutrino object is missing the {key} "
"attribute. It will miss from the Neutrino object.")
@@ -163,7 +167,7 @@ def from_larcv(cls, neutrino):
[getattr(vector, a)() for a in pos_attrs], dtype=np.float32)
# Load the momentum attribute (special care needed)
- mom_attrs = ['px', 'py', 'pz']
+ mom_attrs = ('px', 'py', 'pz')
if not hasattr(neutrino, 'momentum'):
warn("The LArCV Neutrino object is missing the momentum "
"attribute. It will miss from the Neutrino object.")
diff --git a/spine/data/optical.py b/spine/data/optical.py
index 2eb090a1..0515b552 100644
--- a/spine/data/optical.py
+++ b/spine/data/optical.py
@@ -20,6 +20,8 @@ class Flash(PosDataBase):
----------
id : int
Index of the flash in the list
+ volume_id : int
+        Index of the optical volume in which the flash was recorded
time : float
Time with respect to the trigger in microseconds
time_width : float
@@ -46,6 +48,7 @@ class Flash(PosDataBase):
Units in which the position coordinates are expressed
"""
id: int = -1
+ volume_id: int = -1
frame: int = -1
in_beam_frame: bool = False
on_beam_time: bool = False
@@ -60,16 +63,16 @@ class Flash(PosDataBase):
units: str = 'cm'
# Fixed-length attributes
- _fixed_length_attrs = {'center': 3, 'width': 3}
+ _fixed_length_attrs = (('center', 3), ('width', 3))
# Variable-length attributes
- _var_length_attrs = {'pe_per_ch': np.float32}
+ _var_length_attrs = (('pe_per_ch', np.float32),)
# Attributes specifying coordinates
- _pos_attrs = ['center']
+ _pos_attrs = ('center',)
# Attributes specifying vector components
- _vec_attrs = ['width']
+ _vec_attrs = ('width',)
@classmethod
def from_larcv(cls, flash):
@@ -86,14 +89,20 @@ def from_larcv(cls, flash):
Flash object
"""
# Get the physical center and width of the flash
- axes = ['x', 'y', 'z']
+ axes = ('x', 'y', 'z')
center = np.array([getattr(flash, f'{a}Center')() for a in axes])
width = np.array([getattr(flash, f'{a}Width')() for a in axes])
# Get the number of PEs per optical channel
pe_per_ch = np.array(list(flash.PEPerOpDet()), dtype=np.float32)
- return cls(id=flash.id(), frame=flash.frame(),
+ # Get the volume ID, if it is filled (TODO: simplify with update)
+ volume_id = -1
+ for attr in ('tpc', 'volume_id'):
+ if hasattr(flash, attr):
+ volume_id = getattr(flash, attr)()
+
+ return cls(id=flash.id(), volume_id=volume_id, frame=flash.frame(),
in_beam_frame=flash.inBeamFrame(),
on_beam_time=flash.onBeamTime(), time=flash.time(),
time_abs=flash.absTime(), time_width=flash.timeWidth(),
diff --git a/spine/data/out/base.py b/spine/data/out/base.py
index 6a4df0ca..7a7688ba 100644
--- a/spine/data/out/base.py
+++ b/spine/data/out/base.py
@@ -68,24 +68,26 @@ class OutBase(PosDataBase):
units: str = 'cm'
# Variable-length attribtues
- _var_length_attrs = {
- 'index': np.int64, 'depositions': np.float32,
- 'match_ids': np.int64, 'match_overlaps': np.float32,
- 'points': (3, np.float32), 'sources': (2, np.int64),
- 'module_ids': np.int64
- }
-
- # Attributes to be binarized to form an integer from a variable-length array
- _binarize_attrs = ['module_ids']
+ _var_length_attrs = (
+ ('index', np.int64), ('depositions', np.float32),
+ ('match_ids', np.int64), ('match_overlaps', np.float32),
+ ('points', (3, np.float32)), ('sources', (2, np.int64)),
+ ('module_ids', np.int64)
+ )
# Boolean attributes
- _bool_attrs = ['is_contained', 'is_matched', 'is_cathode_crosser', 'is_truth']
+ _bool_attrs = (
+ 'is_contained', 'is_matched', 'is_cathode_crosser', 'is_truth'
+ )
# Attributes to concatenate when merging objects
- _cat_attrs = ['index', 'points', 'depositions', 'sources']
+ _cat_attrs = ('index', 'points', 'depositions', 'sources')
+
+ # Attributes that must never be stored to file
+ _skip_attrs = ('points', 'depositions', 'sources')
- # Attributes that should not be stored
- _skip_attrs = ['points', 'depositions', 'sources']
+ # Attributes that must not be stored to file when storing lite files
+ _lite_skip_attrs = ('index',)
@property
def size(self):
@@ -205,24 +207,30 @@ class TruthBase(OutBase):
is_truth: bool = True
# Variable-length attribtues
- _var_length_attrs = {
- 'depositions_q': np.float32, 'index_adapt': np.int64,
- 'depositions_adapt': np.float32, 'depositions_adapt_q': np.float32,
- 'index_g4': np.int64, 'depositions_g4': np.int64,
- 'points_adapt': (3, np.float32), 'sources_adapt': (2, np.int64),
- 'points_g4': (3, np.float32), **OutBase._var_length_attrs
- }
+ _var_length_attrs = (
+ ('depositions_q', np.float32), ('index_adapt', np.int64),
+ ('depositions_adapt', np.float32), ('depositions_adapt_q', np.float32),
+ ('index_g4', np.int64), ('depositions_g4', np.int64),
+ ('points_adapt', (3, np.float32)), ('sources_adapt', (2, np.int64)),
+ ('points_g4', (3, np.float32)), *OutBase._var_length_attrs
+ )
# Attributes to concatenate when merging objects
- _cat_attrs = ['depositions_q', 'index_adapt', 'points_adapt',
- 'depositions_adapt', 'depositions_adapt_q', 'sources_adapt',
- 'index_g4', 'points_g4', 'depositions_g4',
- *OutBase._cat_attrs]
-
- # Attributes that should not be stored
- _skip_attrs = ['depositions_q', 'points_adapt', 'depositions_adapt',
- 'depositions_adapt_q', 'sources_adapt', 'depositions_g4',
- 'points_g4', 'depositions_g4', *OutBase._skip_attrs]
+ _cat_attrs = (
+ 'depositions_q', 'index_adapt', 'points_adapt', 'depositions_adapt',
+ 'depositions_adapt_q', 'sources_adapt', 'index_g4', 'points_g4',
+ 'depositions_g4', *OutBase._cat_attrs
+ )
+
+ # Attributes that must never be stored to file
+ _skip_attrs = (
+ 'depositions_q', 'points_adapt', 'depositions_adapt',
+ 'depositions_adapt_q', 'sources_adapt', 'depositions_g4',
+ 'points_g4', 'depositions_g4', *OutBase._skip_attrs
+ )
+
+ # Attributes that must not be stored to file when storing lite files
+ _lite_skip_attrs = ('index_adapt', 'index_g4', *OutBase._lite_skip_attrs)
@property
def size_adapt(self):
diff --git a/spine/data/out/fragment.py b/spine/data/out/fragment.py
index 1c096cfb..32b3faf3 100644
--- a/spine/data/out/fragment.py
+++ b/spine/data/out/fragment.py
@@ -52,22 +52,24 @@ class FragmentBase:
end_dir: np.ndarray = None
# Fixed-length attributes
- _fixed_length_attrs = {'start_point': 3, 'end_point': 3, 'start_dir': 3,
- 'end_dir': 3}
+ _fixed_length_attrs = (
+ ('start_point', 3), ('end_point', 3), ('start_dir', 3),
+ ('end_dir', 3)
+ )
# Attributes specifying coordinates
- _pos_attrs = ['start_point', 'end_point']
+ _pos_attrs = ('start_point', 'end_point')
# Attributes specifying vector components
- _vec_attrs = ['start_dir', 'end_dir']
+ _vec_attrs = ('start_dir', 'end_dir')
# Boolean attributes
- _bool_attrs = ['is_primary']
+ _bool_attrs = ('is_primary',)
# Enumerated attributes
- _enum_attrs = {
- 'shape': {v : k for k, v in SHAPE_LABELS.items()}
- }
+ _enum_attrs = (
+ ('shape', tuple((v, k) for k, v in SHAPE_LABELS.items())),
+ )
def __str__(self):
"""Human-readable string representation of the fragment object.
@@ -96,12 +98,16 @@ class RecoFragment(FragmentBase, RecoBase):
primary_scores: np.ndarray = None
# Fixed-length attributes
- _fixed_length_attrs = {
- 'primary_scores': 2,
- **FragmentBase._fixed_length_attrs}
+ _fixed_length_attrs = (
+ ('primary_scores', 2),
+ *FragmentBase._fixed_length_attrs
+ )
# Boolean attributes
- _bool_attrs = [*RecoBase._bool_attrs, *FragmentBase._bool_attrs]
+ _bool_attrs = (
+ *RecoBase._bool_attrs,
+ *FragmentBase._bool_attrs
+ )
def __str__(self):
"""Human-readable string representation of the fragment object.
@@ -143,30 +149,37 @@ class TruthFragment(Particle, FragmentBase, TruthBase):
reco_end_dir: np.ndarray = None
# Fixed-length attributes
- _fixed_length_attrs = {
- **FragmentBase._fixed_length_attrs,
- **Particle._fixed_length_attrs,
- 'reco_start_dir': 3, 'reco_end_dir': 3
- }
+ _fixed_length_attrs = (
+ ('reco_start_dir', 3), ('reco_end_dir', 3),
+ *FragmentBase._fixed_length_attrs,
+ *Particle._fixed_length_attrs
+ )
# Variable-length attributes
- _var_length_attrs = {
- **TruthBase._var_length_attrs,
- **Particle._var_length_attrs,
- 'children_counts': np.int32
- }
+ _var_length_attrs = (
+ ('children_counts', np.int32),
+ *TruthBase._var_length_attrs,
+ *Particle._var_length_attrs
+ )
# Attributes specifying coordinates
- _pos_attrs = [*FragmentBase._pos_attrs, *Particle._pos_attrs]
+ _pos_attrs = (
+ *FragmentBase._pos_attrs,
+ *Particle._pos_attrs
+ )
# Attributes specifying vector components
- _vec_attrs = [
- *FragmentBase._vec_attrs, *Particle._vec_attrs,
- 'reco_start_dir', 'reco_end_dir'
- ]
+ _vec_attrs = (
+ 'reco_start_dir', 'reco_end_dir',
+ *FragmentBase._vec_attrs,
+ *Particle._vec_attrs
+ )
# Boolean attributes
- _bool_attrs = [*TruthBase._bool_attrs, *FragmentBase._bool_attrs]
+ _bool_attrs = (
+ *TruthBase._bool_attrs,
+ *FragmentBase._bool_attrs
+ )
def __str__(self):
"""Human-readable string representation of the fragment object.
diff --git a/spine/data/out/interaction.py b/spine/data/out/interaction.py
index c90bb3da..7cf4993a 100644
--- a/spine/data/out/interaction.py
+++ b/spine/data/out/interaction.py
@@ -38,10 +38,12 @@ class InteractionBase:
Whether this interaction vertex is inside the fiducial volume
is_flash_matched : bool
True if the interaction was matched to an optical flash
- flash_id : int
- Index of the optical flash the interaction was matched to
- flash_time : float
- Time at which the flash occurred in nanoseconds
+ flash_ids : np.ndarray
+ (F) Indices of the optical flashes the interaction is matched to
+ flash_volume_ids : np.ndarray
+        (F) Indices of the optical volumes the flashes were recorded in
+ flash_times : np.ndarray
+ (F) Times at which the flashes occurred in microseconds
flash_total_pe : float
Total number of photoelectrons associated with the flash
flash_hypo_pe : float
@@ -57,31 +59,34 @@ class InteractionBase:
vertex: np.ndarray = None
is_fiducial: bool = False
is_flash_matched: bool = False
- flash_id: int = -1
- flash_time: float = -np.inf
+ flash_ids: np.ndarray = None
+ flash_volume_ids: np.ndarray = None
+ flash_times: np.ndarray = None
flash_total_pe: float = -1.
flash_hypo_pe: float = -1.
topology: str = None
# Fixed-length attributes
- _fixed_length_attrs = {
- 'vertex': 3, 'particle_counts': len(PID_LABELS) - 1,
- 'primary_particle_counts': len(PID_LABELS) - 1
- }
+ _fixed_length_attrs = (
+ ('vertex', 3), ('particle_counts', len(PID_LABELS) - 1),
+ ('primary_particle_counts', len(PID_LABELS) - 1)
+ )
# Variable-length attributes as (key, dtype) pairs
- _var_length_attrs = {
- 'particles': object, 'particle_ids': np.int32
- }
+ _var_length_attrs = (
+ ('particles', object), ('particle_ids', np.int32),
+ ('flash_ids', np.int32), ('flash_volume_ids', np.int32),
+ ('flash_times', np.float32)
+ )
# Attributes specifying coordinates
- _pos_attrs = ['vertex']
+ _pos_attrs = ('vertex',)
# Boolean attributes
- _bool_attrs = ['is_fiducial', 'is_flash_matched']
+ _bool_attrs = ('is_fiducial', 'is_flash_matched')
- # Attributes that should not be stored
- _skip_attrs = ['particles']
+ # Attributes that must never be stored to file
+ _skip_attrs = ('particles',)
def __str__(self):
"""Human-readable string representation of the interaction object.
@@ -217,16 +222,23 @@ def from_particles(cls, particles):
class RecoInteraction(InteractionBase, RecoBase):
"""Reconstructed interaction information."""
- # Attributes that should not be stored
- _skip_attrs = [*RecoBase._skip_attrs, *InteractionBase._skip_attrs]
+ # Attributes that must never be stored to file
+ _skip_attrs = (
+ *RecoBase._skip_attrs,
+ *InteractionBase._skip_attrs
+ )
# Variable-length attributes
- _var_length_attrs = {
- **RecoBase._var_length_attrs, **InteractionBase._var_length_attrs
- }
+ _var_length_attrs = (
+ *RecoBase._var_length_attrs,
+ *InteractionBase._var_length_attrs
+ )
# Boolean attributes
- _bool_attrs = [*RecoBase._bool_attrs, *InteractionBase._bool_attrs]
+ _bool_attrs = (
+ *RecoBase._bool_attrs,
+ *InteractionBase._bool_attrs
+ )
def __str__(self):
"""Human-readable string representation of the interaction object.
@@ -258,28 +270,36 @@ class TruthInteraction(Neutrino, InteractionBase, TruthBase):
reco_vertex: np.ndarray = None
# Fixed-length attributes
- _fixed_length_attrs = {
- **Neutrino._fixed_length_attrs,
- **InteractionBase._fixed_length_attrs,
- 'reco_vertex': 3,
- }
+ _fixed_length_attrs = (
+ ('reco_vertex', 3),
+ *Neutrino._fixed_length_attrs,
+ *InteractionBase._fixed_length_attrs
+ )
# Variable-length attributes
- _var_length_attrs = {
- **TruthBase._var_length_attrs,
- **InteractionBase._var_length_attrs
- }
+ _var_length_attrs = (
+ *TruthBase._var_length_attrs,
+ *InteractionBase._var_length_attrs
+ )
# Attributes specifying coordinates
- _pos_attrs = [
- *InteractionBase._pos_attrs, *Neutrino._pos_attrs, 'reco_vertex'
- ]
+ _pos_attrs = (
+ 'reco_vertex',
+ *InteractionBase._pos_attrs,
+ *Neutrino._pos_attrs
+ )
# Boolean attributes
- _bool_attrs = [*TruthBase._bool_attrs, *InteractionBase._bool_attrs]
-
- # Attributes that should not be stored
- _skip_attrs = [*TruthBase._skip_attrs, *InteractionBase._skip_attrs]
+ _bool_attrs = (
+ *TruthBase._bool_attrs,
+ *InteractionBase._bool_attrs
+ )
+
+ # Attributes that must never be stored to file
+ _skip_attrs = (
+ *TruthBase._skip_attrs,
+ *InteractionBase._skip_attrs
+ )
def __str__(self):
"""Human-readable string representation of the interaction object.
diff --git a/spine/data/out/particle.py b/spine/data/out/particle.py
index 9b9d5b38..9b5599b9 100644
--- a/spine/data/out/particle.py
+++ b/spine/data/out/particle.py
@@ -60,8 +60,12 @@ class ParticleBase:
Kinetic energy reconstructed from the energy depositions alone in MeV
csda_ke : float
Kinetic energy reconstructed from the particle range in MeV
+ csda_ke_per_pid : np.ndarray
+ Same as `csda_ke` but for every available track PID hypothesis
mcs_ke : float
Kinetic energy reconstructed using the MCS method in MeV
+ mcs_ke_per_pid : np.ndarray
+ Same as `mcs_ke` but for every available track PID hypothesis
momentum : np.ndarray
3-momentum of the particle at the production point in MeV/c
p : float
@@ -87,39 +91,43 @@ class ParticleBase:
ke: float = -1.
calo_ke: float = -1.
csda_ke: float = -1.
+ csda_ke_per_pid: np.ndarray = None
mcs_ke: float = -1.
+ mcs_ke_per_pid: np.ndarray = None
momentum: np.ndarray = None
p: float = None
is_valid: bool = True
# Fixed-length attributes
- _fixed_length_attrs = {
- 'start_point': 3, 'end_point': 3, 'start_dir': 3, 'end_dir': 3,
- 'momentum': 3
- }
+ _fixed_length_attrs = (
+ ('start_point', 3), ('end_point', 3), ('start_dir', 3),
+ ('end_dir', 3), ('momentum', 3),
+ ('csda_ke_per_pid', len(PID_LABELS) - 1),
+ ('mcs_ke_per_pid', len(PID_LABELS) - 1)
+ )
# Variable-length attributes as (key, dtype) pairs
- _var_length_attrs = {
- 'fragments': object, 'fragment_ids': np.int32
- }
+ _var_length_attrs = (
+ ('fragments', object), ('fragment_ids', np.int32)
+ )
# Attributes specifying coordinates
- _pos_attrs = ['start_point', 'end_point']
+ _pos_attrs = ('start_point', 'end_point')
# Attributes specifying vector components
- _vec_attrs = ['start_dir', 'end_dir', 'momentum']
+ _vec_attrs = ('start_dir', 'end_dir', 'momentum')
# Boolean attributes
- _bool_attrs = ['is_primary', 'is_valid']
+ _bool_attrs = ('is_primary', 'is_valid')
# Enumerated attributes
- _enum_attrs = {
- 'shape': {v : k for k, v in SHAPE_LABELS.items()},
- 'pid': {v : k for k, v in PID_LABELS.items()}
- }
+ _enum_attrs = (
+ ('shape', tuple((v, k) for k, v in SHAPE_LABELS.items())),
+ ('pid', tuple((v, k) for k, v in PID_LABELS.items()))
+ )
- # Attributes that should not be stored
- _skip_attrs = ['fragments', 'ppn_points']
+ # Attributes that must never be stored to file
+ _skip_attrs = ('fragments', 'ppn_points')
def __str__(self):
"""Human-readable string representation of the particle object.
@@ -196,29 +204,45 @@ class RecoParticle(ParticleBase, RecoBase):
(M) List of indexes of PPN points associated with this particle
ppn_points : np.ndarray
(M, 3) List of PPN points tagged to this particle
+ vertex_distance : float
+ Set-to-point distance between the particle points and the parent
+ interaction vertex, in cm
+ shower_split_angle : float
+ Estimate of the opening angle of the shower in degrees. Set to -1 if
+ the particle is not a shower
"""
pid_scores: np.ndarray = None
primary_scores: np.ndarray = None
ppn_ids: np.ndarray = None
ppn_points: np.ndarray = None
+ vertex_distance: float = -1.
+ shower_split_angle: float = -1.
# Fixed-length attributes
- _fixed_length_attrs = {
- 'pid_scores': len(PID_LABELS) - 1,
- 'primary_scores': 2,
- **ParticleBase._fixed_length_attrs}
+ _fixed_length_attrs = (
+ ('pid_scores', len(PID_LABELS) - 1), ('primary_scores', 2),
+ *ParticleBase._fixed_length_attrs
+ )
# Variable-length attributes
- _var_length_attrs = {
- **RecoBase._var_length_attrs, **ParticleBase._var_length_attrs,
- 'ppn_ids': np.int32, 'ppn_points': (3, np.float32)
- }
+ _var_length_attrs = (
+ ('ppn_ids', np.int32), ('ppn_points', (3, np.float32)),
+ *RecoBase._var_length_attrs,
+ *ParticleBase._var_length_attrs
+ )
# Boolean attributes
- _bool_attrs = [*RecoBase._bool_attrs, *ParticleBase._bool_attrs]
-
- # Attributes that should not be stored
- _skip_attrs = [*RecoBase._skip_attrs, *ParticleBase._skip_attrs, 'ppn_points']
+ _bool_attrs = (
+ *RecoBase._bool_attrs,
+ *ParticleBase._bool_attrs
+ )
+
+ # Attributes that must never be stored to file
+ _skip_attrs = (
+ 'ppn_points',
+ *RecoBase._skip_attrs,
+ *ParticleBase._skip_attrs
+ )
def __str__(self):
"""Human-readable string representation of the particle object.
@@ -351,6 +375,30 @@ def momentum(self):
def momentum(self, momentum):
pass
+ @property
+ def reco_ke(self):
+ """Alias for `ke`, to match nomenclature in truth."""
+ return self.ke
+
+ @property
+ def reco_momentum(self):
+ """Alias for `momentum`, to match nomenclature in truth."""
+ return self.momentum
+
+ @property
+ def reco_length(self):
+ """Alias for `length`, to match nomenclature in truth."""
+ return self.length
+
+ @property
+ def reco_start_dir(self):
+ """Alias for `start_dir`, to match nomenclature in truth."""
+ return self.start_dir
+
+ @property
+ def reco_end_dir(self):
+ """Alias for `end_dir`, to match nomenclature in truth."""
+ return self.end_dir
@dataclass(eq=False)
@inherit_docstring(TruthBase, ParticleBase)
@@ -373,42 +421,58 @@ class TruthParticle(Particle, ParticleBase, TruthBase):
reco_end_dir : np.ndarray
(3) Particle direction estimate w.r.t. the end point (only assigned
to track objects)
+ reco_ke : float
+ Best-guess reconstructed kinetic energy of the particle in MeV
+ reco_momentum : np.ndarray
+ (3) Best-guess reconstructed momentum of the particle in MeV/c
"""
orig_interaction_id: int = -1
children_counts: np.ndarray = None
reco_length: float = -1.
reco_start_dir: np.ndarray = None
reco_end_dir: np.ndarray = None
+ reco_ke: float = -1.
+ reco_momentum: np.ndarray = None
# Fixed-length attributes
- _fixed_length_attrs = {
- **ParticleBase._fixed_length_attrs,
- **Particle._fixed_length_attrs,
- 'reco_start_dir': 3, 'reco_end_dir': 3
- }
+ _fixed_length_attrs = (
+ ('reco_start_dir', 3), ('reco_end_dir', 3), ('reco_momentum', 3),
+ *ParticleBase._fixed_length_attrs,
+ *Particle._fixed_length_attrs
+ )
# Variable-length attributes
- _var_length_attrs = {
- **TruthBase._var_length_attrs,
- **ParticleBase._var_length_attrs,
- **Particle._var_length_attrs,
- 'children_counts': np.int32
- }
+ _var_length_attrs = (
+ ('children_counts', np.int32),
+ *TruthBase._var_length_attrs,
+ *ParticleBase._var_length_attrs,
+ *Particle._var_length_attrs
+ )
# Attributes specifying coordinates
- _pos_attrs = [*ParticleBase._pos_attrs, *Particle._pos_attrs]
+ _pos_attrs = (
+ *ParticleBase._pos_attrs,
+ *Particle._pos_attrs
+ )
# Attributes specifying vector components
- _vec_attrs = [
- *ParticleBase._vec_attrs, *Particle._vec_attrs,
- 'reco_start_dir', 'reco_end_dir'
- ]
+ _vec_attrs = (
+ 'reco_start_dir', 'reco_end_dir', 'reco_momentum',
+ *ParticleBase._vec_attrs,
+ *Particle._vec_attrs
+ )
# Boolean attributes
- _bool_attrs = [*TruthBase._bool_attrs, *ParticleBase._bool_attrs]
+ _bool_attrs = (
+ *TruthBase._bool_attrs,
+ *ParticleBase._bool_attrs
+ )
- # Attributes that should not be stored
- _skip_attrs = [*TruthBase._skip_attrs, *ParticleBase._skip_attrs]
+ # Attributes that must never be stored to file
+ _skip_attrs = (
+ *TruthBase._skip_attrs,
+ *ParticleBase._skip_attrs
+ )
def __str__(self):
"""Human-readable string representation of the particle object.
@@ -483,3 +547,57 @@ def ke(self):
@ke.setter
def ke(self, ke):
pass
+
+ @property
+ def reco_ke(self):
+ """Best-guess reconstructed kinetic energy in MeV.
+
+ Uses calorimetry for EM activity and this order of preference for tracks:
+ - CSDA-based estimate if it is available
+ - MCS-based estimate if it is available
+ - Calorimetry if all else fails
+
+ Returns
+ -------
+ float
+ Best-guess kinetic energy
+ """
+ if self.shape != TRACK_SHP:
+ # If a particle is not a track, can only use calorimetry
+ return self.calo_ke
+
+ else:
+ # If the particle is a track, use the CSDA estimate for contained
+ # tracks and the MCS estimate for uncontained tracks
+ if self.is_contained and self.csda_ke > 0.0:
+ return self.csda_ke
+ elif not self.is_contained and self.mcs_ke > 0.0:
+ return self.mcs_ke
+ else:
+ return self.calo_ke
+
+ @reco_ke.setter
+ def reco_ke(self, reco_ke):
+ pass
+
+ @property
+ def reco_momentum(self):
+ """Best-guess reconstructed momentum in MeV/c.
+
+ Returns
+ -------
+ np.ndarray
+ (3) Momentum vector
+ """
+ ke = self.reco_ke
+ if ke >= 0.0 and self.reco_start_dir[0] != -np.inf and self.pid in PID_MASSES:
+ mass = PID_MASSES[self.pid]
+ mom = np.sqrt(ke**2 + 2 * ke * mass)
+ return mom * self.reco_start_dir
+
+ else:
+ return np.full(3, -np.inf, dtype=np.float32)
+
+ @reco_momentum.setter
+ def reco_momentum(self, reco_momentum):
+ pass
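
The `reco_momentum` property above scales the reconstructed start direction by the momentum magnitude |p| = sqrt(T^2 + 2*T*m). A small standalone sketch of that relation, with the muon mass value quoted here for illustration only:

.. code-block:: python

    import numpy as np

    MUON_MASS = 105.66  # MeV/c^2, value assumed here for illustration
    ke = 500.0          # hypothetical reconstructed kinetic energy in MeV

    # Same relativistic relation as used in `reco_momentum`: |p| = sqrt(T^2 + 2*T*m)
    p = np.sqrt(ke**2 + 2 * ke * MUON_MASS)
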
diff --git a/spine/data/particle.py b/spine/data/particle.py
index 3631a208..e3630a74 100644
--- a/spine/data/particle.py
+++ b/spine/data/particle.py
@@ -151,30 +151,33 @@ class Particle(PosDataBase):
units: str = 'cm'
# Fixed-length attributes
- _fixed_length_attrs = {'position': 3, 'end_position': 3,
- 'parent_position': 3, 'ancestor_position': 3,
- 'first_step': 3, 'last_step': 3, 'momentum': 3,
- 'end_momentum': 3}
+ _fixed_length_attrs = (
+ ('position', 3), ('end_position', 3), ('parent_position', 3),
+ ('ancestor_position', 3), ('first_step', 3), ('last_step', 3),
+ ('momentum', 3), ('end_momentum', 3)
+ )
# Variable-length attributes
- _var_length_attrs = {'children_id': np.int64}
+ _var_length_attrs = (('children_id', np.int64),)
# Attributes specifying coordinates
- _pos_attrs = ['position', 'end_position', 'parent_position',
- 'ancestor_position', 'first_step', 'last_step']
+ _pos_attrs = (
+ 'position', 'end_position', 'parent_position', 'ancestor_position',
+ 'first_step', 'last_step'
+ )
# Attributes specifying vector components
- _vec_attrs = ['momentum', 'end_momentum']
+ _vec_attrs = ('momentum', 'end_momentum')
# Enumerated attributes
- _enum_attrs = {
- 'shape': {v : k for k, v in SHAPE_LABELS.items()},
- 'pid': {v : k for k, v in PID_LABELS.items()}
- }
+ _enum_attrs = (
+ ('shape', tuple((v, k) for k, v in SHAPE_LABELS.items())),
+ ('pid', tuple((v, k) for k, v in PID_LABELS.items()))
+ )
# String attributes
- _str_attrs = ['creation_process', 'parent_creation_process',
- 'ancestor_creation_process']
+ _str_attrs = ('creation_process', 'parent_creation_process',
+ 'ancestor_creation_process')
@property
def p(self):
@@ -242,12 +245,12 @@ def from_larcv(cls, particle):
obj_dict = {}
# Load the scalar attributes
- for prefix in ['', 'parent_', 'ancestor_']:
- for key in ['track_id', 'pdg_code', 'creation_process', 't']:
+ for prefix in ('', 'parent_', 'ancestor_'):
+ for key in ('track_id', 'pdg_code', 'creation_process', 't'):
obj_dict[prefix+key] = getattr(particle, prefix+key)()
- for key in ['id', 'gen_id', 'group_id', 'interaction_id', 'parent_id',
+ for key in ('id', 'gen_id', 'group_id', 'interaction_id', 'parent_id',
'mct_index', 'mcst_index', 'num_voxels', 'shape',
- 'energy_init', 'energy_deposit', 'distance_travel']:
+ 'energy_init', 'energy_deposit', 'distance_travel'):
if not hasattr(particle, key):
warn(f"The LArCV Particle object is missing the {key} "
"attribute. It will miss from the Particle object.")
@@ -266,8 +269,8 @@ def from_larcv(cls, particle):
# Load the other array attributes (special care needed)
obj_dict['children_id'] = np.asarray(particle.children_id(), dtype=int)
- mom_attrs = ['px', 'py', 'pz']
- for prefix in ['', 'end_']:
+ mom_attrs = ('px', 'py', 'pz')
+ for prefix in ('', 'end_'):
key = prefix + 'momentum'
if not hasattr(particle, key):
warn(f"The LArCV Particle object is missing the {key} "
diff --git a/spine/driver.py b/spine/driver.py
index 2468d3d9..40e9ec54 100644
--- a/spine/driver.py
+++ b/spine/driver.py
@@ -134,10 +134,6 @@ def __init__(self, cfg, rank=None):
self.ana = AnaManager(
ana, log_dir=self.log_dir, prefix=self.log_prefix)
- def __len__(self):
- """Returns the number of events in the underlying reader object."""
- return len(self.reader)
-
def process_config(self, io, base=None, model=None, build=None,
post=None, ana=None, rank=None):
"""Reads the configuration and dumps it to the logger.
@@ -457,6 +453,51 @@ def initialize_log(self):
log_path = os.path.join(self.log_dir, log_name)
self.logger = CSVWriter(log_path, overwrite=self.overwrite_log)
+ def __len__(self):
+ """Returns the number of events in the underlying reader object.
+
+ Returns
+ -------
+ int
+ Number of elements in the underlying loader/reader.
+ """
+ return len(self.reader)
+
+ def __iter__(self):
+ """Resets the counter and returns itself.
+
+ Returns
+ -------
+ object
+ The Driver itself
+ """
+ # If a loader is used, reinitialize it. Otherwise set an entry counter
+ if self.loader is not None:
+ self.loader_iter = iter(self.loader)
+ self.counter = None
+ else:
+ self.counter = 0
+
+ return self
+
+ def __next__(self):
+ """Defines how to process the next entry in the iterator.
+
+ Returns
+ -------
+ Union[dict, List[dict]]
+ Either one combined data dictionary, or one per entry in the batch
+ """
+ # If there are more entries to go through, return data (the counter
+ # is None when the entries come from a loader)
+ if self.counter is None or self.counter < len(self):
+ data = self.process(self.counter)
+ if self.counter is not None:
+ self.counter += 1
+
+ return data
+
+ raise StopIteration
+
def run(self):
"""Loop over the requested number of iterations, process them."""
# To run the loop, must know how many times it must be done
diff --git a/spine/io/collate.py b/spine/io/collate.py
index bb956fa9..26e63b21 100644
--- a/spine/io/collate.py
+++ b/spine/io/collate.py
@@ -26,7 +26,7 @@ class CollateAll:
name = 'all'
def __init__(self, split=False, target_id=0, detector=None,
- boundary=None, overlay=None, source=None):
+ geometry_file=None, overlay=None, source=None):
"""Initialize the collation parameters.
Parameters
@@ -38,8 +38,8 @@ def __init__(self, split=False, target_id=0, detector=None,
If split is `True`, specifies where to relocate the points
detector : str, optional
Name of a recognized detector to load the geometry from
- boundary : str, optional
- Path to a `.npy` boundary file to load the boundaries from
+ geometry_file : str, optional
+ Path to a `.yaml` geometry file to load the geometry from
overlay : dict, optional
Image overlay configuration
source : dict, optional
@@ -50,11 +50,11 @@ def __init__(self, split=False, target_id=0, detector=None,
self.split = split
self.source = None
if split:
- assert (detector is not None) or (boundary is not None), (
+ assert (detector is not None) or (geometry_file is not None), (
"If splitting the input per module, must provide detector")
self.target_id = target_id
- self.geo = Geometry(detector, boundary)
+ self.geo = Geometry(detector, geometry_file)
self.source = source
if overlay is not None:
@@ -127,7 +127,7 @@ def __call__(self, batch):
# one batch ID per [batch, volume] pair
voxels_v, features_v, batch_ids_v = [], [], []
counts = np.empty(
- batch_size*self.geo.num_modules, dtype=np.int64)
+ batch_size*self.geo.tpc.num_modules, dtype=np.int64)
for s, sample in enumerate(batch):
# Identify which point belongs to which module
voxels, features, meta = sample[key]
@@ -152,7 +152,7 @@ def __call__(self, batch):
for m, module_index in enumerate(module_indexes):
voxels_v.append(voxels[module_index])
features_v.append(features[module_index])
- idx = self.geo.num_modules * s + m
+ idx = self.geo.tpc.num_modules * s + m
batch_ids_v.append(np.full(len(module_index),
idx, dtype=voxels.dtype))
counts[idx] = len(module_index)
@@ -192,13 +192,13 @@ def __call__(self, batch):
else:
features_v = []
counts = np.empty(
- batch_size*self.geo.num_modules, dtype=np.int64)
+ batch_size*self.geo.tpc.num_modules, dtype=np.int64)
for s, sample in enumerate(batch):
features = sample[key]
- for m in range(self.geo.num_modules):
+ for m in range(self.geo.tpc.num_modules):
module_index = np.where(sources[s][:, 0] == m)[0]
features_v.append(features[module_index])
- idx = self.geo.num_modules * s + m
+ idx = self.geo.tpc.num_modules * s + m
counts[idx] = len(module_index)
tensor = np.vstack(features_v)
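
Since the collation now takes a `geometry_file` rather than a raw boundary array, splitting per module can be configured either with a recognized detector name or with an explicit geometry file. A minimal construction sketch (the file path below is a placeholder):

.. code-block:: python

    from spine.io.collate import CollateAll  # import path assumed from the file layout

    # Provide either `detector` or `geometry_file` when `split` is True
    collate_fn = CollateAll(split=True, geometry_file='detector_geometry.yaml')
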
diff --git a/spine/io/parse/base.py b/spine/io/parse/base.py
index 0fef7a5e..ecfdb88c 100644
--- a/spine/io/parse/base.py
+++ b/spine/io/parse/base.py
@@ -20,8 +20,12 @@ class ParserBase(ABC):
tree_keys : List[str]
List of file data product name
"""
+
+ # Name of the parser (as specified in the configuration)
name = None
- aliases = []
+
+ # Alternative allowed names of the parser
+ aliases = ()
def __init__(self, dtype, **kwargs):
"""Loops over data product names, stores them.
diff --git a/spine/io/parse/cluster.py b/spine/io/parse/cluster.py
index caea2b72..a5546143 100644
--- a/spine/io/parse/cluster.py
+++ b/spine/io/parse/cluster.py
@@ -36,6 +36,8 @@ class Cluster2DParser(ParserBase):
parser: cluster2d
cluster_event: cluster2d_pcluster
"""
+
+ # Name of the parser (as specified in the configuration)
name = 'cluster2d'
def __init__(self, dtype, cluster_event, projection_id):
@@ -133,6 +135,8 @@ class Cluster3DParser(ParserBase):
primary_include_mpr: true
break_clusters: false
"""
+
+ # Name of the parser (as specified in the configuration)
name = 'cluster3d'
def __init__(self, dtype, particle_event=None, add_particle_info=False,
@@ -187,7 +191,7 @@ def __init__(self, dtype, particle_event=None, add_particle_info=False,
if self.add_particle_info:
assert particle_event is not None, (
"If `add_particle_info` is `True`, must provide the "
- "`particle_event` argument")
+ "`particle_event` argument.")
def __call__(self, trees):
"""Parse one entry.
@@ -380,6 +384,8 @@ class Cluster3DChargeRescaledParser(Cluster3DParser):
"""Identical to :class:`Cluster3DParser`, but computes rescaled charges
on the fly.
"""
+
+ # Name of the parser (as specified in the configuration)
name = 'cluster3d_rescale_charge'
def __init__(self, sparse_value_event_list, collection_only=False,
@@ -460,6 +466,8 @@ class Cluster3DMultiModuleParser(Cluster3DParser):
"""Identical to :class:`Cluster3DParser`, but fetches charge information
from multiple detector modules independently.
"""
+
+ # Name of the parser (as specified in the configuration)
name = 'cluster3d_multi_module'
def __call__(self, trees):
diff --git a/spine/io/parse/misc.py b/spine/io/parse/misc.py
index 06aad44f..39a5e642 100644
--- a/spine/io/parse/misc.py
+++ b/spine/io/parse/misc.py
@@ -35,8 +35,12 @@ class MetaParser(ParserBase):
parser: meta
sparse_event: sparse3d_pcluster
"""
+
+ # Name of the parser (as specified in the configuration)
name = 'meta'
- aliases = ['meta2d', 'meta3d']
+
+ # Alternative allowed names of the parser
+ aliases = ('meta2d', 'meta3d')
def __call__(self, trees):
"""Parse one entry.
@@ -83,7 +87,7 @@ def process(self, sparse_event=None, cluster_event=None):
"""
# Check on the input, pick a source for the metadata
assert (sparse_event is not None) ^ (cluster_event is not None), (
- "Must specify either `sparse_event` or `cluster_event`")
+ "Must specify either `sparse_event` or `cluster_event`.")
ref_event = sparse_event if sparse_event is not None else cluster_event
# Fetch a specific projection, if needed
@@ -104,6 +108,8 @@ class RunInfoParser(ParserBase):
parser: run_info
sparse_event: sparse3d_pcluster
"""
+
+ # Name of the parser (as specified in the configuration)
name = 'run_info'
def __call__(self, trees):
@@ -135,7 +141,7 @@ def process(self, sparse_event=None, cluster_event=None):
"""
# Check on the input, pick a source for the run information
assert (sparse_event is not None) ^ (cluster_event is not None), (
- "Must specify either `sparse_event` or `cluster_event`")
+ "Must specify either `sparse_event` or `cluster_event`.")
ref_event = sparse_event if sparse_event is not None else cluster_event
return RunInfo.from_larcv(ref_event)
@@ -144,15 +150,25 @@ def process(self, sparse_event=None, cluster_event=None):
class FlashParser(ParserBase):
"""Copy construct Flash and return an array of `Flash`.
+ This parser also handles flashes that have been split between their
+ respective optical volumes, provided a `flash_event_list` is supplied. It
+ assumes that the trees are listed in order of the volume ID they
+ correspond to.
+
.. code-block:: yaml
schema:
- flashes_cryoE:
+ flashes:
parser: flash
- flash_event: flash_cryoE
-
+ flash_event_list:
+ - flash_cryoE
+ - flash_cryoW
"""
+
+ # Name of the parser (as specified in the configuration)
name = 'flash'
- aliases = ['opflash']
+
+ # Alternative allowed names of the parser
+ aliases = ('opflash',)
def __call__(self, trees):
"""Parse one entry.
@@ -179,19 +195,32 @@ def process(self, flash_event=None, flash_event_list=None):
List[Flash]
List of optical flash objects
"""
- # Check on the input, aggregate the sources for the optical flashes
+ # Check on the input
assert ((flash_event is not None) ^
(flash_event_list is not None)), (
- "Must specify either `flash_event` or `flash_event_list`")
+ "Must specify either `flash_event` or `flash_event_list`.")
+
+ # Parse flash objects
if flash_event is not None:
+ # If there is a single flash event, parse it as is
flash_list = flash_event.as_vector()
- else:
- flash_list = []
- for flash_event in flash_event_list:
- flash_list.extend(flash_event.as_vector())
+ flashes = [Flash.from_larcv(larcv.Flash(f)) for f in flash_list]
- # Output as a list of LArCV optical flash objects
- flashes = [Flash.from_larcv(larcv.Flash(f)) for f in flash_list]
+ else:
+ # Otherwise, set the volume ID of each flash to its source index
+ # and assign flash IDs sequentially across all volumes
+ flashes = []
+ idx = 0
+ for volume_id, flash_event in enumerate(flash_event_list):
+ for f in flash_event.as_vector():
+ # Cast and update attributes
+ flash = Flash.from_larcv(f)
+ flash.id = idx
+ flash.volume_id = volume_id
+
+ # Append, increment counter
+ flashes.append(flash)
+ idx += 1
return ObjectList(flashes, Flash())
@@ -205,6 +234,8 @@ class CRTHitParser(ParserBase):
parser: crthit
crthit_event: crthit_crthit
"""
+
+ # Name of the parser (as specified in the configuration)
name = 'crthit'
def __call__(self, trees):
@@ -245,6 +276,8 @@ class TriggerParser(ParserBase):
parser: trigger
trigger_event: trigger_base
"""
+
+ # Name of the parser (as specified in the configuration)
name = 'trigger'
def __call__(self, trees):
diff --git a/spine/io/parse/particle.py b/spine/io/parse/particle.py
index a2a06f72..0b33752c 100644
--- a/spine/io/parse/particle.py
+++ b/spine/io/parse/particle.py
@@ -16,13 +16,13 @@
from spine.utils.globals import TRACK_SHP, PDG_TO_PID, PID_MASSES
from spine.utils.particles import process_particles
-from spine.utils.ppn import get_ppn_labels, image_coordinates
+from spine.utils.ppn import get_ppn_labels, get_vertex_labels, image_coordinates
from spine.utils.conditional import larcv
from .base import ParserBase
__all__ = ['ParticleParser', 'NeutrinoParser', 'ParticlePointParser',
- 'ParticleCoordinateParser', 'ParticleGraphParser',
+ 'ParticleCoordinateParser', 'VertexPointParser', 'ParticleGraphParser',
'SingleParticlePIDParser', 'SingleParticleEnergyParser']
@@ -40,6 +40,8 @@ class ParticleParser(ParserBase):
pixel_coordinates: True
post_process: True
"""
+
+ # Name of the parser (as specified in the configuration)
name = 'particle'
def __init__(self, pixel_coordinates=True, post_process=True,
@@ -164,6 +166,8 @@ class NeutrinoParser(ParserBase):
pixel_coordinates: True
asis: False
"""
+
+ # Name of the parser (as specified in the configuration)
name = 'neutrino'
def __init__(self, pixel_coordinates=True, asis=False, **kwargs):
@@ -257,6 +261,8 @@ class ParticlePointParser(ParserBase):
sparse_event: sparse3d_pcluster
include_point_tagging: True
"""
+
+ # Name of the parser (as specified in the configuration)
name = 'particle_points'
def __init__(self, include_point_tagging=True, **kwargs):
@@ -316,9 +322,9 @@ def process(self, particle_event, sparse_event=None, cluster_event=None):
meta = ref_event.meta()
# Get the point labels
- particles_v = particle_event.as_vector()
+ particle_v = particle_event.as_vector()
point_labels = get_ppn_labels(
- particles_v, meta, self.ftype,
+ particle_v, meta, self.ftype,
include_point_tagging=self.include_point_tagging)
return point_labels[:, :3], point_labels[:, 3:], Meta.from_larcv(meta)
@@ -337,6 +343,8 @@ class ParticleCoordinateParser(ParserBase):
particle_event: particle_pcluster
sparse_event: sparse3d_pcluster
"""
+
+ # Name of the parser (as specified in the configuration)
name = 'particle_coords'
def __call__(self, trees):
@@ -381,11 +389,11 @@ def process(self, particle_event, sparse_event=None, cluster_event=None):
meta = ref_event.meta()
# Scale particle coordinates to image size
- particles_v = particle_event.as_vector()
+ particle_v = particle_event.as_vector()
# Make features
- features = np.empty((len(particles_v), 8), dtype=self.ftype)
- for i, p in enumerate(particles_v):
+ features = np.empty((len(particle_v), 8), dtype=self.ftype)
+ for i, p in enumerate(particle_v):
start_point = last_point = image_coordinates(meta, p.first_step())
if p.shape() == TRACK_SHP: # End point only meaningful for tracks
last_point = image_coordinates(meta, p.last_step())
@@ -395,6 +403,86 @@ def process(self, particle_event, sparse_event=None, cluster_event=None):
return features[:, :6], features[:, 6:], Meta.from_larcv(meta)
+class VertexPointParser(ParserBase):
+ """Class that retrieves the vertices.
+
+ It provides the coordinates of points where multiple particles originate:
+ - If the `neutrino_event` is provided, it simply uses the coordinates of
+ the neutrino interaction points.
+ - If the `particle_event` is provided instead, it looks for ancestor point
+ positions shared by at least two particles.
+
+ .. code-block:: yaml
+
+ schema:
+ vertices:
+ parser: vertex_points
+ particle_event: particle_pcluster
+ #neutrino_event: neutrino_mpv
+ sparse_event: sparse3d_pcluster
+ include_point_tagging: True
+ """
+
+ # Name of the parser (as specified in the configuration)
+ name = 'vertex_points'
+
+ def __call__(self, trees):
+ """Parse one entry.
+
+ Parameters
+ ----------
+ trees : dict
+ Dictionary which maps each data product name to a LArCV object
+ """
+ return self.process(**self.get_input_data(trees))
+
+ def process(self, particle_event=None, neutrino_event=None,
+ sparse_event=None, cluster_event=None):
+ """Fetch the list of label vertex points.
+
+ Parameters
+ ----------
+ particle_event : larcv.EventParticle
+ Particle event which contains the list of true particles
+ neutrino_event : larcv.EventNeutrino
+ Neutrino event which contains the list of true neutrinos
+ sparse_event : larcv.EventSparseTensor3D, optional
+ Tensor which contains the metadata needed to convert the
+ positions to voxel coordinates
+ cluster_event : larcv.EventClusterVoxel3D, optional
+ Cluster which contains the metadata needed to convert the
+ positions to voxel coordinates
+
+ Returns
+ -------
+ np_voxels : np.ndarray
+ (N, 3) array of [x, y, z] coordinates
+ np_features : np.ndarray
+ (N, 1) array of [vertex ID]
+ meta : Meta
+ Metadata of the parsed image
+ """
+ # Check that only one source of vertex is provided
+ assert (particle_event is not None) ^ (neutrino_event is not None), (
+ "Must provide either `particle_event` or `neutrino_event` to "
+ "get the vertex points, not both.")
+
+ # Fetch the metadata
+ assert (sparse_event is not None) ^ (cluster_event is not None), (
+ "Must provide either `sparse_event` or `cluster_event` to "
+ "get the metadata and convert positions to voxel units.")
+ ref_event = sparse_event if sparse_event is not None else cluster_event
+ meta = ref_event.meta()
+
+ # Get the vertex labels
+ particle_v = particle_event.as_vector() if particle_event else None
+ neutrino_v = neutrino_event.as_vector() if neutrino_event else None
+ point_labels = get_vertex_labels(
+ particle_v, neutrino_v, meta, self.ftype)
+
+ return point_labels[:, :3], point_labels[:, 3:], Meta.from_larcv(meta)
+
+
class ParticleGraphParser(ParserBase):
"""Class that uses larcv.EventParticle to construct edges
between particles (i.e. clusters).
@@ -408,6 +496,8 @@ class ParticleGraphParser(ParserBase):
cluster_event: cluster3d_pcluster
"""
+
+ # Name of the parser (as specified in the configuration)
name = 'particle_graph'
def __call__(self, trees):
@@ -518,6 +608,8 @@ class SingleParticlePIDParser(ParserBase):
parser: single_particle_pid
particle_event: particle_pcluster
"""
+
+ # Name of the parser (as specified in the configuration)
name = 'single_particle_pid'
def __call__(self, trees):
@@ -564,6 +656,8 @@ class SingleParticleEnergyParser(ParserBase):
parser: single_particle_energy
particle_event: particle_pcluster
"""
+
+ # Name of the parser (as specified in the configuration)
name = 'single_particle_energy'
def __call__(self, trees):
diff --git a/spine/io/parse/sparse.py b/spine/io/parse/sparse.py
index c80a3446..e428e203 100644
--- a/spine/io/parse/sparse.py
+++ b/spine/io/parse/sparse.py
@@ -35,6 +35,8 @@ class Sparse2DParser(ParserBase):
- ...
projection_id: 0
"""
+
+ # Name of the parser (as specified in the configuration)
name = 'parse_sparse2d'
def __init__(self, dtype, projection_id, sparse_event=None,
@@ -60,9 +62,9 @@ def __init__(self, dtype, projection_id, sparse_event=None,
# Get the number of features in the output tensor
assert (sparse_event is not None) ^ (sparse_event_list is not None), (
- "Must provide either `sparse_event` or `sparse_event_list`")
+ "Must provide either `sparse_event` or `sparse_event_list`.")
assert sparse_event_list is None or len(sparse_event_list), (
- "Must provide as least 1 sparse_event in the list")
+ "Must provide at least 1 sparse_event in the list.")
self.num_features = 1
if sparse_event_list is not None:
@@ -116,9 +118,9 @@ def process(self, sparse_event=None, sparse_event_list=None):
larcv.fill_2d_voxels(tensor, np_voxels)
else:
assert meta == tensor.meta(), (
- "The metadata must match between tensors")
+ "The metadata must match between tensors.")
assert num_points == tensor.as_vector().size(), (
- "The number of pixels must match between tensors")
+ "The number of pixels must match between tensors.")
# Get the feature vector for this tensor
np_data = np.empty((num_points, 1), dtype=self.ftype)
@@ -141,6 +143,8 @@ class Sparse3DParser(ParserBase):
- sparse3d_pcluster_1
- ...
"""
+
+ # Name of the parser (as specified in the configuration)
name = 'sparse3d'
def __init__(self, dtype, sparse_event=None, sparse_event_list=None,
@@ -190,9 +194,9 @@ def __init__(self, dtype, sparse_event=None, sparse_event_list=None,
# Get the number of features in the output tensor
assert (sparse_event is not None) ^ (sparse_event_list is not None), (
- "Must provide either `sparse_event` or `sparse_event_list`")
+ "Must provide either `sparse_event` or `sparse_event_list`.")
assert sparse_event_list is None or len(sparse_event_list), (
- "Must provide as least 1 sparse_event in the list")
+ "Must provide at least 1 sparse_event in the list.")
num_tensors = 1 if sparse_event is not None else len(sparse_event_list)
if self.num_features is not None:
@@ -256,7 +260,7 @@ def process(self, sparse_event=None, sparse_event_list=None):
meta = sparse_event.meta()
else:
assert meta == sparse_event.meta(), (
- "The metadata must match between tensors")
+ "The metadata must match between tensors.")
if num_points is None:
num_points = sparse_event.as_vector().size()
@@ -265,7 +269,7 @@ def process(self, sparse_event=None, sparse_event_list=None):
larcv.fill_3d_voxels(sparse_event, np_voxels)
else:
assert num_points == sparse_event.as_vector().size(), (
- "The number of pixels must match between tensors")
+ "The number of pixels must match between tensors.")
# Get the feature vector for this tensor
np_data = np.empty((num_points, 1), dtype=self.ftype)
@@ -309,8 +313,9 @@ class Sparse3DGhostParser(Sparse3DParser):
parser: sparse3d
sparse_event_semantics: sparse3d_semantics
"""
+
+ # Name of the parser (as specified in the configuration)
name = 'sparse3d_ghost'
- aliases = []
def __call__(self, trees):
"""Parse one entry.
@@ -356,8 +361,12 @@ class Sparse3DChargeRescaledParser(Sparse3DParser):
parser: sparse3d_charge_rescaled
sparse_event_semantics: sparse3d_semantics
"""
+
+ # Name of the parser (as specified in the configuration)
name = 'parse_sparse3d_rescale_charge'
- aliases = ['parse_sparse3d_charge_rescaled']
+
+ # Alternative allowed names of the parser
+ aliases = ('parse_sparse3d_charge_rescaled',)
def __init__(self, collection_only=False, collection_id=2, **kwargs):
"""Initialize the parser.
diff --git a/spine/io/read/base.py b/spine/io/read/base.py
index f0209007..25a3c71b 100644
--- a/spine/io/read/base.py
+++ b/spine/io/read/base.py
@@ -119,14 +119,10 @@ def process_file_paths(self, file_keys, limit_num_files=None,
f"File key {file_key} yielded no compatible path.")
for path in file_paths:
if (limit_num_files is not None and
- len(self.file_paths) > limit_num_files):
+ len(self.file_paths) >= limit_num_files):
break
self.file_paths.append(path)
- if (limit_num_files is not None and
- len(self.file_paths) >= limit_num_files):
- break
-
self.file_paths = sorted(self.file_paths)
# Print out the list of loaded files
diff --git a/spine/io/write/hdf5.py b/spine/io/write/hdf5.py
index c83c1749..4ed92bda 100644
--- a/spine/io/write/hdf5.py
+++ b/spine/io/write/hdf5.py
@@ -38,12 +38,13 @@ class HDF5Writer:
name = 'hdf5'
def __init__(self, file_name=None, keys=None, skip_keys=None, dummy_ds=None,
- overwrite=False, append=False, prefix=None, split=False):
+ overwrite=False, append=False, prefix=None, split=False,
+ lite=False):
"""Initializes the basics of the output file.
Parameters
----------
- file_name : str, default 'spine.h5'
+ file_name : str, optional
Name of the output HDF5 file
keys : List[str], optional
List of data product keys to store. If not specified, store everything
@@ -61,6 +62,8 @@ def __init__(self, file_name=None, keys=None, skip_keys=None, dummy_ds=None,
provided that no file_name is explicitly provided
split : bool, default False
If `True`, split the output to produce one file per input file
+ lite : bool, default False
+ If `True`, the lite version of objects is stored (drop point indexes)
"""
# If the output file name is not provided, use the input file prefix(es)
if not file_name:
@@ -98,6 +101,7 @@ def __init__(self, file_name=None, keys=None, skip_keys=None, dummy_ds=None,
self.file_name = file_name
self.append = append
self.split = split
+ self.lite = lite
self.ready = False
self.object_dtypes = [] # TODO: make this a set
@@ -335,7 +339,7 @@ def get_object_dtype(self, obj):
List of (key, dtype) pairs
"""
object_dtype = []
- for key, val in obj.as_dict().items():
+ for key, val in obj.as_dict(self.lite).items():
# Append the relevant data type
if isinstance(val, str):
# String
@@ -344,7 +348,7 @@ def get_object_dtype(self, obj):
elif hasattr(obj, 'enum_attrs') and key in obj.enum_attrs:
# Recognized enumerated list
enum_dtype = h5py.enum_dtype(
- obj.enum_attrs[key], basetype=type(val))
+ dict(obj.enum_attrs[key]), basetype=type(val))
object_dtype.append((key, enum_dtype))
elif np.isscalar(val):
@@ -529,7 +533,7 @@ def append_key(self, out_file, event, data, key, batch_id):
array = [array]
if val.dtype in self.object_dtypes:
- self.store_objects(out_file, event, key, array, val.dtype)
+ self.store_objects(out_file, event, key, array, val.dtype, self.lite)
else:
self.store(out_file, event, key, array)
@@ -647,7 +651,7 @@ def store_flat(out_file, event, key, array_list):
event[key] = region_ref
@staticmethod
- def store_objects(out_file, event, key, array, obj_dtype):
+ def store_objects(out_file, event, key, array, obj_dtype, lite):
"""Stores a list of objects with understandable attributes in the file
and stores its mapping in the event dataset.
@@ -663,11 +667,13 @@ def store_objects(out_file, event, key, array, obj_dtype):
Array of objects or dictionaries to be stored
obj_dtype : list
List of (key, dtype) pairs which specify what's to store
+ lite : bool
+ If `True`, store the lite version of objects
"""
# Convert list of objects to list of storable objects
objects = np.empty(len(array), obj_dtype)
for i, obj in enumerate(array):
- objects[i] = tuple(obj.as_dict().values())
+ objects[i] = tuple(obj.as_dict(lite).values())
# Extend the dataset, store array
dataset = out_file[key]
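
The new `lite` flag is set once at construction and propagates down to `store_objects`, where objects are serialized through `as_dict` with the flag passed along, dropping the point indexes from the stored objects. A minimal sketch (the output file name is a placeholder):

.. code-block:: python

    from spine.io.write.hdf5 import HDF5Writer  # import path assumed from the file layout

    # Store the lite version of every object (point indexes are dropped)
    writer = HDF5Writer(file_name='spine_lite.h5', lite=True)
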
diff --git a/spine/model/layer/gnn/encode/cnn.py b/spine/model/layer/gnn/encode/cnn.py
index b41628fb..6f162df8 100644
--- a/spine/model/layer/gnn/encode/cnn.py
+++ b/spine/model/layer/gnn/encode/cnn.py
@@ -13,6 +13,8 @@
class ClustCNNNodeEncoder(torch.nn.Module):
"""Produces cluster node features using a sparse residual CNN encoder."""
+
+ # Name of the node encoder (as specified in the configuration)
name = 'cnn'
def __init__(self, **cfg):
@@ -68,6 +70,8 @@ class ClustCNNEdgeEncoder(torch.nn.Module):
Considers an edge as an image containing both objects connected by
the edge in a single image.
"""
+
+ # Name of the edge encoder (as specified in the configuration)
name = 'cnn'
def __init__(self, **cfg):
@@ -143,6 +147,8 @@ class ClustCNNGlobalEncoder(torch.nn.Module):
Considers the whole graph as an image containing all objects in it.
"""
+
+ # Name of the global encoder (as specified in the configuration)
name = 'cnn'
def __init__(self, **cfg):
diff --git a/spine/model/layer/gnn/encode/empty.py b/spine/model/layer/gnn/encode/empty.py
index ab810d2e..af7fba47 100644
--- a/spine/model/layer/gnn/encode/empty.py
+++ b/spine/model/layer/gnn/encode/empty.py
@@ -10,6 +10,8 @@
class EmptyClusterNodeEncoder(torch.nn.Module):
"""Produces empty cluster node features."""
+
+ # Name of the node encoder (as specified in the configuration)
name = 'empty'
def forward(self, data, clusts, **kwargs):
@@ -38,6 +40,8 @@ def forward(self, data, clusts, **kwargs):
class EmptyClusterEdgeEncoder(torch.nn.Module):
"""Produces empty cluster edge features."""
+
+ # Name of the edge encoder (as specified in the configuration)
name = 'empty'
def forward(self, data, clusts, edge_index, **kwargs):
@@ -68,6 +72,8 @@ def forward(self, data, clusts, edge_index, **kwargs):
class EmptyClusterGlobalEncoder(torch.nn.Module):
"""Produces empty global graph features."""
+
+ # Name of the global encoder (as specified in the configuration)
name = 'empty'
def forward(self, data, clusts, **kwargs):
diff --git a/spine/model/layer/gnn/encode/geometric.py b/spine/model/layer/gnn/encode/geometric.py
index f621f117..7ec03a72 100644
--- a/spine/model/layer/gnn/encode/geometric.py
+++ b/spine/model/layer/gnn/encode/geometric.py
@@ -41,8 +41,12 @@ class ClustGeoNodeEncoder(torch.nn.Module):
- Start dEdx (1)
- End dEdx (1)
"""
+
+ # Name of the node encoder (as specified in the configuration)
name = 'geometric'
- aliases = ['geo']
+
+ # Alternative allowed names of the node encoder
+ aliases = ('geo',)
def __init__(self, use_numpy=True, add_value=False, add_shape=False,
add_points=False, add_local_dirs=False, dir_max_dist=5.,
@@ -296,8 +300,12 @@ class ClustGeoEdgeEncoder(torch.nn.Module):
- Length of the displacement vector (1)
- Outer product of the displacement vector (9)
"""
+
+ # Name of the edge encoder (as specified in the configuration)
name = 'geometric'
- aliases = ['geo']
+
+ # Alternative allowed names of the edge encoder
+ aliases = ('geo',)
def __init__(self, use_numpy=True):
"""Initializes the geometric-based node encoder.
diff --git a/spine/model/layer/gnn/encode/mixed.py b/spine/model/layer/gnn/encode/mixed.py
index 64f4ad9d..9c87c901 100644
--- a/spine/model/layer/gnn/encode/mixed.py
+++ b/spine/model/layer/gnn/encode/mixed.py
@@ -16,6 +16,8 @@
class ClustGeoCNNMixNodeEncoder(torch.nn.Module):
"""Produces cluster node features using both geometric and CNN encoders."""
+
+ # Name of the node encoder (as specified in the configuration)
name = 'geo_cnn_mix'
def __init__(self, geo_encoder, cnn_encoder, activation='elu'):
@@ -79,6 +81,8 @@ def forward(self, data, clusts, **kwargs):
class ClustGeoCNNMixEdgeEncoder(torch.nn.Module):
"""Produces cluster edge features using both geometric and CNN encoders."""
+
+ # Name of the edge encoder (as specified in the configuration)
name = 'geo_cnn_mix'
def __init__(self, geo_encoder, cnn_encoder):
diff --git a/spine/model/layer/gnn/graph/base.py b/spine/model/layer/gnn/graph/base.py
index 6a3045f9..265bac5c 100644
--- a/spine/model/layer/gnn/graph/base.py
+++ b/spine/model/layer/gnn/graph/base.py
@@ -14,6 +14,9 @@
class GraphBase:
"""Parent class for all graph constructors."""
+ # Name of the graph constructor (as specified in the configuration)
+ name = None
+
def __init__(self, directed=False, max_length=None, classes=None,
max_count=None, dist_method='voxel', dist_algorithm='brute'):
"""Initializes attributes shared accross all graph constructors.
@@ -65,7 +68,7 @@ def __init__(self, directed=False, max_length=None, classes=None,
# Store whether the inter-cluster distance matrix must be evaluated
self.compute_dist = (max_length is not None or
- self.name in ['mst', 'knn'])
+ self.name in ('mst', 'knn'))
# If this is a loop graph, simply set as undirected
assert self.name != 'loop' or self.directed, (
diff --git a/spine/model/layer/gnn/graph/bipartite.py b/spine/model/layer/gnn/graph/bipartite.py
index c4ad199f..24c38218 100644
--- a/spine/model/layer/gnn/graph/bipartite.py
+++ b/spine/model/layer/gnn/graph/bipartite.py
@@ -17,6 +17,8 @@ class BipartiteGraph(GraphBase):
See :class:`GraphBase` for attributes/methods shared
across all graph constructors.
"""
+
+ # Name of the graph constructor (as specified in the configuration)
name = 'bipartite'
def __init__(self, directed_to, **kwargs):
diff --git a/spine/model/layer/gnn/graph/complete.py b/spine/model/layer/gnn/graph/complete.py
index c79aa586..77afc052 100644
--- a/spine/model/layer/gnn/graph/complete.py
+++ b/spine/model/layer/gnn/graph/complete.py
@@ -16,6 +16,8 @@ class CompleteGraph(GraphBase):
See :class:`GraphBase` for attributes/methods shared
across all graph constructors.
"""
+
+ # Name of the graph constructor (as specified in the configuration)
name = 'complete'
def generate(self, clusts, **kwargs):
diff --git a/spine/model/layer/gnn/graph/delaunay.py b/spine/model/layer/gnn/graph/delaunay.py
index f5f37274..f9cfc37d 100644
--- a/spine/model/layer/gnn/graph/delaunay.py
+++ b/spine/model/layer/gnn/graph/delaunay.py
@@ -21,6 +21,8 @@ class DelaunayGraph(GraphBase):
See :class:`GraphBase` for attributes/methods shared
across all graph constructors.
"""
+
+ # Name of the graph constructor (as specified in the configuration)
name = 'delaunay'
def generate(self, data, clusts, **kwargs):
diff --git a/spine/model/layer/gnn/graph/knn.py b/spine/model/layer/gnn/graph/knn.py
index 1dee2e38..25dcc5cc 100644
--- a/spine/model/layer/gnn/graph/knn.py
+++ b/spine/model/layer/gnn/graph/knn.py
@@ -19,6 +19,8 @@ class KNNGraph(GraphBase):
See :class:`GraphBase` for attributes/methods shared
across all graph constructors.
"""
+
+ # Name of the graph constructor (as specified in the configuration)
name = 'knn'
def __init__(self, k, **kwargs):
diff --git a/spine/model/layer/gnn/graph/loop.py b/spine/model/layer/gnn/graph/loop.py
index 02ef26e1..e3865ff3 100644
--- a/spine/model/layer/gnn/graph/loop.py
+++ b/spine/model/layer/gnn/graph/loop.py
@@ -15,6 +15,8 @@ class LoopGraph(GraphBase):
See :class:`GraphBase` for attributes/methods shared
across all graph constructors.
"""
+
+ # Name of the graph constructor (as specified in the configuration)
name = 'loop'
def generate(self, clusts, **kwargs):
diff --git a/spine/model/layer/gnn/graph/mst.py b/spine/model/layer/gnn/graph/mst.py
index a39879a3..65d8479f 100644
--- a/spine/model/layer/gnn/graph/mst.py
+++ b/spine/model/layer/gnn/graph/mst.py
@@ -22,6 +22,8 @@ class MSTGraph(GraphBase):
See :class:`GraphBase` for attributes/methods shared
across all graph constructors.
"""
+
+ # Name of the graph constructor (as specified in the configuration)
name = 'mst'
def generate(self, clusts, dist_mat, **kwargs):
diff --git a/spine/model/layer/gnn/loss/edge_channel.py b/spine/model/layer/gnn/loss/edge_channel.py
index 6c603c10..86e4ae8b 100644
--- a/spine/model/layer/gnn/loss/edge_channel.py
+++ b/spine/model/layer/gnn/loss/edge_channel.py
@@ -36,6 +36,8 @@ class EdgeChannelLoss(torch.nn.Module):
See configuration files prefixed with `grappa_` under the `config`
directory for detailed examples of working configurations.
"""
+
+ # Name of the GNN loss (as specified in the configuration)
name = 'channel'
def __init__(self, target, mode='group', loss='ce', balance_loss=False,
diff --git a/spine/model/layer/gnn/loss/node_class.py b/spine/model/layer/gnn/loss/node_class.py
index 0f26140c..71b47460 100644
--- a/spine/model/layer/gnn/loss/node_class.py
+++ b/spine/model/layer/gnn/loss/node_class.py
@@ -34,8 +34,12 @@ class NodeClassLoss(torch.nn.Module):
See configuration files prefixed with `grappa_` under the `config`
directory for detailed examples of working configurations.
"""
+
+ # Name of the loss (as specified in the configuration)
name = 'class'
- aliases = ['classification']
+
+ # Alternative allowed names of the loss
+ aliases = ('classification',)
def __init__(self, target, loss='ce', balance_loss=False, weights=None):
"""Initialize the node classifcation loss function.
diff --git a/spine/model/layer/gnn/loss/node_orient.py b/spine/model/layer/gnn/loss/node_orient.py
index ed8daa79..698c9aca 100644
--- a/spine/model/layer/gnn/loss/node_orient.py
+++ b/spine/model/layer/gnn/loss/node_orient.py
@@ -35,8 +35,12 @@ class NodeOrientLoss(torch.nn.Module):
See configuration files prefixed with `grappa_` under the `config`
directory for detailed examples of working configurations.
"""
+
+ # Name of the loss (as specified in the configuration)
name = 'orient'
- aliases = ['orientation']
+
+ # Alternative allowed names of the loss
+ aliases = ('orientation',)
def __init__(self, loss='ce'):
"""Initialize the node orientation loss function.
diff --git a/spine/model/layer/gnn/loss/node_reg.py b/spine/model/layer/gnn/loss/node_reg.py
index d3821a7b..dda9a4cb 100644
--- a/spine/model/layer/gnn/loss/node_reg.py
+++ b/spine/model/layer/gnn/loss/node_reg.py
@@ -33,8 +33,12 @@ class NodeRegressionLoss(torch.nn.Module):
See configuration files prefixed with `grappa_` under the `config`
directory for detailed examples of working configurations.
"""
+
+ # Name of the loss (as specified in the configuration)
name = 'reg'
- aliases = ['regression']
+
+ # Alternative allowed names of the loss
+ aliases = ('regression',)
def __init__(self, target, loss='mse'):
"""Initialize the node regression loss function.
diff --git a/spine/model/layer/gnn/loss/node_shower_primary.py b/spine/model/layer/gnn/loss/node_shower_primary.py
index e9368c4e..9da47777 100644
--- a/spine/model/layer/gnn/loss/node_shower_primary.py
+++ b/spine/model/layer/gnn/loss/node_shower_primary.py
@@ -35,6 +35,8 @@ class NodeShowerPrimaryLoss(torch.nn.Module):
See configuration files prefixed with `grappa_` under the `config`
directory for detailed examples of working configurations.
"""
+
+ # Name of the loss (as specified in the configuration)
name = 'shower_primary'
def __init__(self, balance_loss=False, high_purity=False,
diff --git a/spine/model/layer/gnn/loss/node_vertex.py b/spine/model/layer/gnn/loss/node_vertex.py
index e58df7ba..0f510bc7 100644
--- a/spine/model/layer/gnn/loss/node_vertex.py
+++ b/spine/model/layer/gnn/loss/node_vertex.py
@@ -45,12 +45,14 @@ class NodeVertexLoss(torch.nn.Module):
See configuration files prefixed with `grappa_` under the `config`
directory for detailed examples of working configurations.
"""
+
+ # Name of the loss (as specified in the configuration)
name = 'vertex'
def __init__(self, balance_primary_loss=False, primary_loss='ce',
regression_loss='mse', only_contained=True,
normalize_positions=False, use_anchor_points=False,
- return_vertex_labels=False, detector=None, boundaries=None):
+ return_vertex_labels=False, detector=None, geometry_file=None):
"""Initialize the vertex regression loss function.
Parameters
@@ -71,8 +73,8 @@ def __init__(self, balance_primary_loss=False, primary_loss='ce',
If `True`, return the list vertex labels (one per particle)
detector : str, optional
Name of a recognized detector to load the geometry from
- boundaries : str, optional
- Path to a `.npy` boundary file to load the boundaries from
+ geometry_file : str, optional
+ Path to a `.yaml` geometry file to load the geometry from
"""
# Initialize the parent class
super().__init__()
@@ -95,7 +97,7 @@ def __init__(self, balance_primary_loss=False, primary_loss='ce',
# If containment is requested, intialize geometry
if self.only_contained:
- self.geo = Geometry(detector, boundaries)
+ self.geo = Geometry(detector, geometry_file)
self.geo.define_containment_volumes(margin=0., mode='module')
def forward(self, clust_label, clusts, node_pred, meta=None,
diff --git a/spine/model/layer/gnn/model/layer/agnnconv.py b/spine/model/layer/gnn/model/layer/agnnconv.py
index 406e1670..0968c2a0 100644
--- a/spine/model/layer/gnn/model/layer/agnnconv.py
+++ b/spine/model/layer/gnn/model/layer/agnnconv.py
@@ -18,6 +18,8 @@ class AGNNConvNodeLayer(nn.Module):
Source: https://arxiv.org/abs/1803.03735
"""
+
+ # Name of the node layer (as specified in the configuration)
name = 'agnnconv'
def __init__(self, node_in, edge_in, glob_in,
diff --git a/spine/model/layer/gnn/model/layer/econv.py b/spine/model/layer/gnn/model/layer/econv.py
index 3c24fbae..02d455c2 100644
--- a/spine/model/layer/gnn/model/layer/econv.py
+++ b/spine/model/layer/gnn/model/layer/econv.py
@@ -21,6 +21,8 @@ class EConvNodeLayer(nn.Module):
Source: https://arxiv.org/abs/1801.07829
"""
+
+ # Name of the node layer (as specified in the configuration)
name = 'econv'
def __init__(self, node_in, edge_in, glob_in, mlp, aggr='max', **kwargs):
diff --git a/spine/model/layer/gnn/model/layer/gatconv.py b/spine/model/layer/gnn/model/layer/gatconv.py
index cef25da8..b9f9366b 100644
--- a/spine/model/layer/gnn/model/layer/gatconv.py
+++ b/spine/model/layer/gnn/model/layer/gatconv.py
@@ -18,6 +18,8 @@ class GATConvNodeLayer(nn.Module):
Source: https://arxiv.org/abs/1710.10903
"""
+
+ # Name of the node layer (as specified in the configuration)
name = 'gatconv'
def __init__(self, node_in, edge_in, glob_in, out_channels,
diff --git a/spine/model/layer/gnn/model/layer/mlp.py b/spine/model/layer/gnn/model/layer/mlp.py
index 8c801ff9..57fa6139 100644
--- a/spine/model/layer/gnn/model/layer/mlp.py
+++ b/spine/model/layer/gnn/model/layer/mlp.py
@@ -21,6 +21,8 @@ class MLPEdgeLayer(nn.Module):
multi-layer perceptron (MLP) and outputs an (E, N_o) vector, with N_o the
width of the MLP (feature size of the hidden representation).
"""
+
+ # Name of the edge layer (as specified in the configuration)
name = 'mlp'
def __init__(self, node_in, edge_in, glob_in, mlp):
@@ -96,6 +98,8 @@ class MLPNodeLayer(nn.Module):
to form a (N_o + N_c + N_g) feature vector. This new vector is passed
through a second MLP to update the node features to (N_o').
"""
+
+ # Name of the node layer (as specified in the configuration)
name = 'mlp'
def __init__(self, node_in, edge_in, glob_in, message_mlp,
@@ -199,6 +203,8 @@ class MLPGlobalLayer(nn.Module):
perceptron (MLP) and outputs a (B, N_o) vector, with N_o the width of the
MLP (feature size of the hidden representation).
"""
+
+ # Name of the global layer (as specified in the configuration)
name = 'mlp'
def __init__(self, node_in, glob_in, mlp, reduction='mean'):
diff --git a/spine/model/layer/gnn/model/layer/nnconv.py b/spine/model/layer/gnn/model/layer/nnconv.py
index 06491240..83f692ff 100644
--- a/spine/model/layer/gnn/model/layer/nnconv.py
+++ b/spine/model/layer/gnn/model/layer/nnconv.py
@@ -22,6 +22,8 @@ class NNConvNodeLayer(nn.Module):
Source: https://arxiv.org/abs/1704.02901
"""
+
+ # Name of the node layer (as specified in the configuration)
name = 'nnconv'
def __init__(self, node_in, edge_in, glob_in, out_channels,
diff --git a/spine/model/layer/gnn/model/meta.py b/spine/model/layer/gnn/model/meta.py
index 2a0ced1a..85774f60 100644
--- a/spine/model/layer/gnn/model/meta.py
+++ b/spine/model/layer/gnn/model/meta.py
@@ -15,6 +15,8 @@
class MetaLayerGNN(nn.Module):
"""Completely generic message-passing GNN."""
+
+ # Name of the model (as specified in the configuration)
name = 'meta'
def __init__(self, node_feats=0, node_layer=None, edge_feats=0,
diff --git a/spine/model/uresnet.py b/spine/model/uresnet.py
index 4c6d69c9..466dbb9f 100644
--- a/spine/model/uresnet.py
+++ b/spine/model/uresnet.py
@@ -8,8 +8,11 @@
import MinkowskiEngine as ME
from spine.data import TensorBatch
-from spine.utils.globals import BATCH_COL, VALUE_COL, GHOST_SHP
+from spine.utils.globals import BATCH_COL, COORD_COLS, VALUE_COL, GHOST_SHP
from spine.utils.logger import logger
+from spine.utils.torch_local import local_cdist
+
+from .layer.factories import loss_fn_factory
from .layer.cnn.act_norm import act_factory, norm_factory
from .layer.cnn.uresnet_layers import UResNet
@@ -19,7 +22,7 @@
class UResNetSegmentation(nn.Module):
"""UResNet for semantic segmentation.
-
+
Typical configuration should look like:
.. code-block:: yaml
@@ -103,7 +106,7 @@ def forward(self, data):
- 1 is the batch ID
- D is the number of dimensions in the input image
- N_f is the number of features per voxel
-
+
Returns
-------
dict
@@ -195,10 +198,6 @@ def __init__(self, uresnet, uresnet_loss):
# Initialize the loss configuration
self.process_loss_config(**uresnet_loss)
- # Initialize the cross-entropy loss
- # TODO: Make it configurable
- self.xe = torch.nn.functional.cross_entropy
-
def process_model_config(self, num_classes, ghost=False, **kwargs):
"""Process the parameters of the upstream model needed for in the loss.
@@ -215,12 +214,15 @@ def process_model_config(self, num_classes, ghost=False, **kwargs):
self.num_classes = num_classes
self.ghost = ghost
- def process_loss_config(self, ghost_label=-1, alpha=1.0, beta=1.0,
- balance_loss=False):
+ def process_loss_config(self, loss='ce', ghost_label=-1, alpha=1.0,
+ beta=1.0, balance_loss=False,
+ upweight_points=False, upweight_radius=20):
"""Process the loss function parameters.
Parameters
----------
+ loss : str, default 'ce'
+ Loss function used for semantic segmentation
ghost_label : int, default -1
ID of ghost points. If specified (> -1), classify ghosts only
alpha : float, default 1.0
@@ -229,12 +231,22 @@ def process_loss_config(self, ghost_label=-1, alpha=1.0, beta=1.0,
Ghost mask loss prefactor
balance_loss : bool, default False
Whether to weight the loss to account for class imbalance
+ upweight_points : bool, default False
+ Whether to weight the loss higher near specific points (to be
+ provided as `point_label` as a loss input)
+            upweight_radius : float, default 20
+ Radius around the points of interest for which to upweight the loss
"""
+ # Set the loss function
+ self.loss_fn = loss_fn_factory(loss, reduction='none')
+
# Store the loss configuration
- self.ghost_label = ghost_label
- self.alpha = alpha
- self.beta = beta
- self.balance_loss = balance_loss
+ self.ghost_label = ghost_label
+ self.alpha = alpha
+ self.beta = beta
+ self.balance_loss = balance_loss
+ self.upweight_points = upweight_points
+ self.upweight_radius = upweight_radius
# If a ghost label is provided, it cannot be in conjecture with
# having a dedicated ghost masking layer
@@ -242,10 +254,10 @@ def process_loss_config(self, ghost_label=-1, alpha=1.0, beta=1.0,
"Cannot classify ghost exclusively (ghost_label) and "
"have a dedicated ghost masking layer at the same time.")
- def forward(self, seg_label, segmentation, ghost=None,
+ def forward(self, seg_label, segmentation, point_label=None, ghost=None,
weights=None, **kwargs):
"""Computes the cross-entropy loss of the semantic segmentation
- predictions.
+ predictions.
Parameters
----------
@@ -253,9 +265,12 @@ def forward(self, seg_label, segmentation, ghost=None,
(N, 1 + D + 1) Tensor of segmentation labels for the batch
segmentation : TensorBatch
(N, N_c) Tensor of logits from the segmentation model
- ghost : TensorBatch
+ point_label : TensorBatch, optional
+            (P, 1 + D + 1) Tensor of points of interest for the batch. This
+ is used to upweight the loss near specific points.
+ ghost : TensorBatch, optional
(N, 2) Tensor of ghost logits from the segmentation model
- weights : torch.Tensor, optional
+ weights : TensorBatch, optional
(N) Tensor of weights for each pixel in the batch
**kwargs : dict, optional
Other outputs of the upstream model which are not relevant here
@@ -266,27 +281,40 @@ def forward(self, seg_label, segmentation, ghost=None,
Dictionary of accuracies and losses
"""
# Get the underlying tensor in each TensorBatch
- seg_label = seg_label.tensor
- segmentation = segmentation.tensor
+ seg_label_t = seg_label.tensor
+ segmentation_t = segmentation.tensor
+ ghost_t = ghost.tensor if ghost is not None else ghost
+ weights_t = weights.tensor if weights is not None else weights
# Make sure that the segmentation output and labels have the same length
- assert len(seg_label) == len(segmentation), (
- f"The `segmentation` output length ({len(segmentation)}) and "
- f"its labels ({len(seg_label)}) do not match.")
- assert not self.ghost or len(seg_label) == len(ghost), (
- f"The `ghost` output length ({len(ghost)}) and "
- f"its labels ({len(seg_label)}) do not match.")
-
- # If the loss is to be class-weighted, cannot also provide weights
- assert not self.balance_loss or weights is None, (
- "If weights are provided, cannot also class-weight loss.")
+ assert len(seg_label_t) == len(segmentation_t), (
+ f"The `segmentation` output length ({len(segmentation_t)}) "
+ f"and its labels ({len(seg_label_t)}) do not match.")
+ assert not self.ghost or len(seg_label_t) == len(ghost_t), (
+ f"The `ghost` output length ({len(ghost_t)}) and "
+ f"its labels ({len(seg_label_t)}) do not match.")
+ assert not self.ghost or weights is None, (
+            "Providing explicit weights is not compatible with performing "
+ "deghosting in tandem with semantic segmentation.")
+
+ # If requested, produce weights based on point-proximity
+ if self.upweight_points:
+ assert point_label is not None, (
+                "If upweighting the loss near points of interest, must "
+ "provide a list of such points in `point_label`.")
+ dist_weights = self.get_distance_weights(seg_label, point_label)
+ if weights is not None:
+                weights_t *= dist_weights
+ else:
+ weights_t = dist_weights
# Check that the labels have sensible values
if self.ghost_label > -1:
- labels = (seg_label[:, VALUE_COL] == self.ghost_label).long()
+ labels_t = (seg_label_t[:, VALUE_COL] == self.ghost_label).long()
+
else:
- labels = seg_label[:, VALUE_COL].long()
- if torch.any(labels > self.num_classes):
+ labels_t = seg_label_t[:, VALUE_COL].long()
+ if torch.any(labels_t > self.num_classes):
raise ValueError(
"The segmentation labels contain labels larger than "
"the number of logits output by the model.")
@@ -294,18 +322,18 @@ def forward(self, seg_label, segmentation, ghost=None,
# If there is a dedicated ghost layer, apply the mask first
if self.ghost:
# Count the number of voxels in each class
- ghost_labels = (labels == GHOST_SHP).long()
- ghost_loss, ghost_acc, ghost_acc_class = self.loss_accuracy(
- ghost, ghost_labels)
+ ghost_labels_t = (labels_t == GHOST_SHP).long()
+ ghost_loss, ghost_acc, ghost_acc_class = self.get_loss_accuracy(
+ ghost_t, ghost_labels_t)
# Restrict the segmentation target to true non-ghosts
- nonghost = torch.nonzero(ghost_labels == 0).flatten()
- segmentation = segmentation[nonghost]
- labels = labels[nonghost]
+ nonghost = torch.nonzero(ghost_labels_t == 0).flatten()
+ segmentation_t = segmentation_t[nonghost]
+ labels_t = labels_t[nonghost]
# Compute the loss/accuracy of the semantic segmentation step
- seg_loss, seg_acc, seg_acc_class = self.loss_accuracy(
- segmentation, labels, weights)
+ seg_loss, seg_acc, seg_acc_class, weights_t = self.get_loss_accuracy(
+ segmentation_t, labels_t, weights_t)
# Get the combined loss and accuracies
result = {}
@@ -331,9 +359,56 @@ def forward(self, seg_label, segmentation, ghost=None,
for c in range(self.num_classes):
result[f'accuracy_class_{c}'] = seg_acc_class[c]
+ if weights_t is not None:
+ result['weights'] = TensorBatch(weights_t, seg_label.counts)
+
return result
- def loss_accuracy(self, logits, labels, weights=None):
+ def get_distance_weights(self, seg_label, point_label):
+ """Define weights for each of the points in the image based on their
+        distance from points of interest (typically vertices, but user-defined).
+
+ Parameters
+ ----------
+ seg_label : TensorBatch
+ (N, 1 + D + 1) Tensor of segmentation labels for the batch
+ point_label : TensorBatch
+            (P, 1 + D + 1) Tensor of points of interest for the batch. This
+ is used to upweight the loss of points near a vertex.
+
+ Returns
+ -------
+ torch.Tensor
+ (N) Array of weights associated with each point
+ """
+ # Loop over the entries in the batch, compute proximity for each point
+ dists = torch.full_like(seg_label.tensor[:, 0], float('inf'))
+ for b in range(seg_label.batch_size):
+ # Fetch image voxel and point coordinates for this entry
+ voxels_b = seg_label[b][:, COORD_COLS]
+ points_b = point_label[b][:, COORD_COLS]
+ if not len(points_b) or not len(voxels_b):
+ continue
+
+ # Compute the minimal distance to any point in this entry
+ dist_mat = local_cdist(voxels_b, points_b)
+ dists_b = torch.min(dist_mat, dim=1).values
+
+ # Record information in the batch-wise tensor
+ lower, upper = seg_label.edges[b], seg_label.edges[b+1]
+ dists[lower:upper] = dists_b
+
+ # Upweight the points within some distance of the points of interest
+ proximity = (dists < self.upweight_radius).long()
+ close_count = torch.sum(proximity)
+ counts = torch.tensor(
+ [len(dists) - close_count, close_count],
+ dtype=torch.long, device=dists.device)
+ weights = len(proximity)/2/counts
+
+ return weights[proximity]
+
+ def get_loss_accuracy(self, logits, labels, weights=None):
"""Computes the loss, global and classwise accuracy.
Parameters
@@ -353,10 +428,12 @@ def loss_accuracy(self, logits, labels, weights=None):
Global accuracy
np.ndarray
(N_c) Vector of class-wise accuracy
+ torch.Tensor
+ (N) Updated set of weights for each pixel in the batch
"""
# If there is no input, nothing to do
if not len(logits):
- return 0., 1., np.ones(num_classes, dtype=np.float32)
+            return 0., 1., np.ones(logits.shape[1], dtype=np.float32), weights
# Count the number of voxels in each class
num_classes = logits.shape[1]
@@ -365,15 +442,22 @@ def loss_accuracy(self, logits, labels, weights=None):
for c in range(num_classes):
counts[c] = torch.sum(labels == c).item()
+        # If requested, create a set of weights based on class prevalence
+ if self.balance_loss:
+ class_weight = torch.ones(
+ len(counts), dtype=logits.dtype, device=logits.device)
+ class_weight[counts > 0] = len(labels)/num_classes/counts[counts > 0]
+ class_weights = class_weight[labels]
+ if weights is not None:
+ weights *= class_weights
+ else:
+ weights = class_weights
+
# Compute the loss
- if self.balance_loss and torch.all(counts):
- class_weight = len(labels)/num_classes/counts
- loss = self.xe(logits, labels, weight=class_weight)
+ if weights is None:
+ loss = self.loss_fn(logits, labels).mean()
else:
- if weights is None:
- loss = self.xe(logits, labels, reduction='mean')
- else:
- loss = (weights*self.xe(logits, labels, reduction='none')).sum()
+ loss = (weights*self.loss_fn(logits, labels)).sum()/weights.sum()
# Compute the accuracies
with torch.no_grad():
@@ -388,4 +472,4 @@ def loss_accuracy(self, logits, labels, weights=None):
# Global prediction accuracy
acc = (preds == labels).sum().item() / torch.sum(counts).item()
- return loss, acc, acc_class
+ return loss, acc, acc_class, weights
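
For reference, a minimal standalone sketch of the point-proximity upweighting implemented in `get_distance_weights` above; the function name, the plain-tensor interface and the `clamp(min=1)` guard are illustrative additions, while the real loss operates on `TensorBatch` objects through `local_cdist`:

import torch

def proximity_weights(voxels, points, radius):
    """Weight voxels so that those within `radius` of any point of interest
    carry as much total weight as all remaining voxels combined."""
    # Distance from every voxel to its closest point of interest
    dists = torch.cdist(voxels, points).min(dim=1).values

    # Binary proximity label: 1 if close to a point of interest, 0 otherwise
    proximity = (dists < radius).long()

    # Balance the two groups so that each sums to N/2 (clamp avoids division by 0)
    counts = torch.stack([(proximity == 0).sum(), (proximity == 1).sum()])
    weights = len(proximity)/2/counts.clamp(min=1)

    return weights[proximity]

# Toy example: 100 voxels, 2 points of interest, 20 px radius
voxels = 100*torch.rand(100, 3)
points = 100*torch.rand(2, 3)
print(proximity_weights(voxels, points, radius=20.))
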
diff --git a/spine/post/base.py b/spine/post/base.py
index 5b7c0aa2..97d0fdc9 100644
--- a/spine/post/base.py
+++ b/spine/post/base.py
@@ -17,24 +17,22 @@ class PostBase(ABC):
----------
name : str
Name of the post-processor as defined in the configuration file
- aliases : []
+ aliases : Tuple[str]
Alternative acceptable names for a post-processor
- parent_path : str
- Path to the main configuration file (to access relative configurations)
- keys : Dict[str, bool]
- List of data product keys used to operate the post-processor
- truth_point_mode : str
- Type of `points` attribute to use for the truth particles
- units : str
- Units in which the objects must be expressed (one of 'px' or 'cm')
"""
+
+ # Name of the post-processor (as specified in the configuration)
name = None
+
+ # Alternative allowed names of the post-processor
aliases = ()
- parent_path = ''
- keys = None
- truth_point_mode = 'points'
+
+    # Units in which the post-processor expects objects to be expressed
units = 'cm'
+ # Set of data keys needed for this post-processor to operate
+ _keys = ()
+
# List of recognized object types
_obj_types = ('fragment', 'particle', 'interaction')
@@ -42,27 +40,27 @@ class PostBase(ABC):
_run_modes = ('reco', 'truth', 'both', 'all')
# List of known point modes for true particles and their corresponding keys
- _point_modes = {
- 'points': 'points_label',
- 'points_adapt': 'points',
- 'points_g4': 'points_g4'
- }
+ _point_modes = (
+ ('points', 'points_label'),
+ ('points_adapt', 'points'),
+ ('points_g4', 'points_g4')
+ )
# List of known source modes for true particles and their corresponding keys
- _source_modes = {
- 'sources': 'sources_label',
- 'sources_adapt': 'sources',
- 'sources_g4': 'sources_g4'
- }
+ _source_modes = (
+ ('sources', 'sources_label'),
+ ('sources_adapt', 'sources'),
+ ('sources_g4', 'sources_g4')
+ )
# List of known deposition modes for true particles and their corresponding keys
- _dep_modes = {
- 'depositions': 'depositions_label',
- 'depositions_q': 'depositions_q_label',
- 'depositions_adapt': 'depositions_label_adapt',
- 'depositions_adapt_q': 'depositions',
- 'depositions_g4': 'depositions_g4'
- }
+ _dep_modes = (
+ ('depositions', 'depositions_label'),
+ ('depositions_q', 'depositions_q_label'),
+ ('depositions_adapt', 'depositions_label_adapt'),
+ ('depositions_adapt_q', 'depositions'),
+ ('depositions_g4', 'depositions_g4')
+ )
def __init__(self, obj_type=None, run_mode=None, truth_point_mode=None,
truth_dep_mode=None, parent_path=None):
@@ -88,10 +86,6 @@ def __init__(self, obj_type=None, run_mode=None, truth_point_mode=None,
Path to the parent directory of the main analysis configuration. This
allows for the use of relative paths in the post-processors.
"""
- # Initialize default keys
- if self.keys is None:
- self.keys = {}
-
# If run mode is specified, process it
if run_mode is not None:
# Check that the run mode is recognized
@@ -126,35 +120,97 @@ def __init__(self, obj_type=None, run_mode=None, truth_point_mode=None,
+ self.particle_keys
+ self.interaction_keys)
- self.keys.update({k:True for k in self.obj_keys})
+ # Update underlying keys, if needed
+ self.update_keys({k: True for k in self.obj_keys})
# If a truth point mode is specified, store it
if truth_point_mode is not None:
- assert truth_point_mode in self._point_modes, (
+ assert truth_point_mode in self.point_modes, (
"The `truth_point_mode` argument must be one of "
- f"{self._point_modes.keys()}. Got `{truth_point_mode}` instead.")
+ f"{self.point_modes.keys()}. Got `{truth_point_mode}` instead.")
self.truth_point_mode = truth_point_mode
- self.truth_point_key = self._point_modes[self.truth_point_mode]
+ self.truth_point_key = self.point_modes[self.truth_point_mode]
self.truth_source_mode = truth_point_mode.replace('points', 'sources')
- self.truth_source_key = self._source_modes[self.truth_source_mode]
+ self.truth_source_key = self.source_modes[self.truth_source_mode]
self.truth_index_mode = truth_point_mode.replace('points', 'index')
# If a truth deposition mode is specified, store it
if truth_dep_mode is not None:
- assert truth_dep_mode in self._dep_modes, (
+ assert truth_dep_mode in self.dep_modes, (
"The `truth_dep_mode` argument must be one of "
- f"{self._dep_modes.keys()}. Got `{truth_dep_mode}` instead.")
+ f"{self.dep_modes.keys()}. Got `{truth_dep_mode}` instead.")
if truth_point_mode is not None:
prefix = truth_point_mode.replace('points', 'depositions')
assert truth_dep_mode.startswith(prefix), (
"Points mode {truth_point_mode} and deposition mode "
"{truth_dep_mode} are incompatible.")
self.truth_dep_mode = truth_dep_mode
- self.truth_dep_key = self._dep_modes[truth_dep_mode]
+ self.truth_dep_key = self.dep_modes[truth_dep_mode]
# Store the parent path
self.parent_path = parent_path
+ @property
+ def keys(self):
+ """Dictionary of (key, necessity) pairs which determine which data keys
+ are needed/optional for the post-processor to run.
+
+ Returns
+ -------
+ Dict[str, bool]
+ Dictionary of (key, necessity) pairs to be used
+ """
+ return dict(self._keys)
+
+ @property
+ def point_modes(self):
+        """Dictionary which makes the correspondence between the name of a true
+        object point attribute and the underlying point tensor it points to.
+
+ Returns
+ -------
+ Dict[str, str]
+ Dictionary of (attribute, key) mapping for point coordinates
+ """
+ return dict(self._point_modes)
+
+ @property
+ def source_modes(self):
+        """Dictionary which makes the correspondence between the name of a true
+        object source attribute and the underlying source tensor it points to.
+
+ Returns
+ -------
+ Dict[str, str]
+ Dictionary of (attribute, key) mapping for point sources
+ """
+ return dict(self._source_modes)
+
+ @property
+ def dep_modes(self):
+        """Dictionary which makes the correspondence between the name of a true
+        object deposition attribute and the underlying deposition array it points to.
+
+ Returns
+ -------
+ Dict[str, str]
+ Dictionary of (attribute, key) mapping for point depositions
+ """
+ return dict(self._dep_modes)
+
+ def update_keys(self, update_dict):
+ """Update the underlying set of keys and their necessity in place.
+
+ Parameters
+ ----------
+ update_dict : Dict[str, bool]
+ Dictionary of (key, necessity) pairs to update the keys with
+ """
+ if len(update_dict) > 0:
+ keys = self.keys
+ keys.update(update_dict)
+ self._keys = tuple(keys.items())
+
def __call__(self, data, entry=None):
"""Calls the post processor on one entry.
@@ -172,7 +228,7 @@ def __call__(self, data, entry=None):
"""
# Fetch the input dictionary
data_filter = {}
- for key, req in self.keys.items():
+ for key, req in self._keys:
# If this key is needed, check that it exists
assert not req or key in data, (
f"Post-processor `{self.name}` is missing an essential "
@@ -290,8 +346,8 @@ def check_units(self, obj):
"""
if obj.units != self.units:
raise ValueError(
- f"Coordinates must be expressed in "
- f"{self.units}; currently in {obj.units} instead.")
+ f"Coordinates must be expressed in {self.units} but are "
+ f"currently in {obj.units} instead.")
@abstractmethod
def process(self, data):
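
The reason `_keys` (and the mode maps) are now stored as class-level tuples rather than dictionaries is that mutable class attributes are shared across instances and subclasses. A minimal sketch of the pattern, using a stripped-down stand-in class:

class Base:
    # Immutable class-level default: safe to share between instances
    _keys = (('run_info', False),)

    @property
    def keys(self):
        """Return the (key, necessity) pairs as a fresh dictionary."""
        return dict(self._keys)

    def update_keys(self, update_dict):
        """Merge new (key, necessity) pairs into this instance only."""
        if len(update_dict) > 0:
            keys = self.keys
            keys.update(update_dict)
            self._keys = tuple(keys.items())  # Shadows the class attribute

a, b = Base(), Base()
a.update_keys({'points': True})
assert 'points' in a.keys and 'points' not in b.keys  # No cross-talk
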
diff --git a/spine/post/crt/crt_matching.py b/spine/post/crt/crt_matching.py
index d2eb8bc4..f9192d1f 100644
--- a/spine/post/crt/crt_matching.py
+++ b/spine/post/crt/crt_matching.py
@@ -4,49 +4,48 @@
class CRTMatchProcessor(PostBase):
- '''
- Associates TPC interactions with optical flashes.
- '''
- name = 'run_crt_tpc_matching'
- data_cap = ['crthits']
- result_cap = ['interactions'] # TODO: Should be done at particle level
-
- def __init__(self,
- crthit_keys,
- **kwargs):
- '''
- Post processor for running CRT-TPC matching using matcha.
+ """Associates TPC particles with CRT hits.
+ """
+
+ # Name of the post-processor (as specified in the configuration)
+ name = 'crt_match'
+
+ # Alternative allowed names of the post-processor
+ aliases = ('run_crt_matching',)
+
+ def __init__(self, crthit_key, obj_type='particle', run_mode='reco'):
+ """Initialize the CRT/TPC matching post-processor.
Parameters
----------
- crthit_keys : List[str]
- List of keys that provide the CRT information in the data dictionary
+ crthit_key : str
+ Data product key which provides the CRT information
**kwargs : dict
Keyword arguments to pass to the CRT-TPC matching algorithm
- '''
+ """
+ # Initialize the parent class
+ super().__init__(obj_type, run_mode)
+
# Store the relevant attributes
- self.crthit_keys = crthit_keys
+ self.crthit_key = crthit_key
def process(self, data_dict, result_dict):
- '''
- Find [interaction, flash] pairs
-
+        """Find particle/CRT matches for one entry.
+
         Parameters
----------
- data_dict : dict
- Input data dictionary
- result_dict : dict
- Chain output dictionary
+ data : dict
+ Dictionary of data products
Notes
-----
This post-processor also modifies the list of Interactions
in-place by adding the following attributes:
- interaction.crthit_matched: (bool)
- Indicator for whether the given interaction has a CRT-TPC match
- interaction.crthit_id: (list of ints)
- List of IDs for CRT hits that were matched to one or more tracks
- '''
+ particle.is_crthit_matched: bool
+ Indicator for whether the given particle has a CRT-TPC match
+ particle.crthit_ids: List[int]
+ List of IDs for CRT hits that were matched to that particle
+ """
crthits = {}
assert len(self.crthit_keys) > 0
for key in self.crthit_keys:
diff --git a/spine/post/manager.py b/spine/post/manager.py
index c5ca196f..d28cbf8d 100644
--- a/spine/post/manager.py
+++ b/spine/post/manager.py
@@ -1,6 +1,7 @@
"""Manages the operation of post-processors."""
from warnings import warn
+from copy import deepcopy
from collections import defaultdict, OrderedDict
import numpy as np
@@ -27,6 +28,7 @@ def __init__(self, cfg, parent_path=None):
Path to the analysis tools configuration file
"""
# Loop over the post-processor modules and get their priorities
+ cfg = deepcopy(cfg)
keys = np.array(list(cfg.keys()))
priorities = -np.ones(len(keys), dtype=np.int32)
for i, k in enumerate(keys):
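
The new `deepcopy(cfg)` guard matters because the manager consumes per-module options out of the configuration blocks as it builds the processors; without the copy, the dictionary passed in by the caller would be silently modified. A minimal illustration of the failure mode (the `priority` field is used here as the assumed consumed option):

from copy import deepcopy

cfg = {'direction': {'priority': 1, 'optimize': True}}

def build_without_copy(cfg):
    for name, module_cfg in cfg.items():
        module_cfg.pop('priority', None)  # Mutates the caller's dictionary

def build_with_copy(cfg):
    cfg = deepcopy(cfg)
    for name, module_cfg in cfg.items():
        module_cfg.pop('priority', None)  # Only the local copy is touched

build_with_copy(cfg)
assert 'priority' in cfg['direction']      # Caller's configuration intact
build_without_copy(cfg)
assert 'priority' not in cfg['direction']  # Caller's configuration modified
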
diff --git a/spine/post/metric/match.py b/spine/post/metric/match.py
index c18043b8..ea0c1e1c 100644
--- a/spine/post/metric/match.py
+++ b/spine/post/metric/match.py
@@ -14,6 +14,8 @@
class MatchProcessor(PostBase):
"""Does the matching between reconstructed and true objects."""
+
+ # Name of the post-processor (as specified in the configuration)
name = 'match'
def __init__(self, fragment=None, particle=None, interaction=None,
@@ -36,8 +38,8 @@ def __init__(self, fragment=None, particle=None, interaction=None,
# Initialize the necessary matchers
configs = {'fragment': fragment, 'particle': particle,
'interaction': interaction}
+ keys = {}
self.matchers = {}
- self.keys = {}
for key, cfg in configs.items():
if cfg is not None and cfg != False:
# Initialize the matcher
@@ -47,11 +49,14 @@ def __init__(self, fragment=None, particle=None, interaction=None,
# If any matcher includes ghost points, must load meta
if self.matchers[key].ghost:
- self.keys['meta'] = True
+ keys['meta'] = True
assert len(self.matchers), (
"Must specify one of 'fragment', 'particle' or 'interaction'.")
+ # Update the set of keys necessary for this post-processor
+ self.update_keys(keys)
+
# Initialize the parent class
super().__init__(list(self.matchers.keys()), 'both', truth_point_mode)
diff --git a/spine/post/optical/barycenter.py b/spine/post/optical/barycenter.py
index 09c3317f..cab47a34 100644
--- a/spine/post/optical/barycenter.py
+++ b/spine/post/optical/barycenter.py
@@ -1,18 +1,21 @@
+"""Module that supports barycenter-based flash matching."""
+
import numpy as np
-from spine.utils.geo import Geometry
from spine.utils.numba_local import cdist
+
class BarycenterFlashMatcher:
"""Matches interactions and flashes by matching the charge barycenter of
TPC interactions with the light barycenter of optical flashes.
"""
- axes = ['x', 'y', 'z']
- def __init__(self, match_method='threshold', dimensions=[1,2],
- charge_weighted=False, time_window=None, first_flash_only=False,
- min_inter_size=None, min_flash_pe=None, match_distance=None,
- detector=None, boundary_file=None, source_file=None):
+ # List of valid matching methods
+ _match_methods = ('threshold', 'best')
+
+ def __init__(self, match_method='threshold', dimensions=[1, 2],
+ charge_weighted=False, time_window=None, first_flash_only=False,
+ min_inter_size=None, min_flash_pe=None, match_distance=None):
"""Initalize the barycenter flash matcher.
Parameters
@@ -21,7 +24,7 @@ def __init__(self, match_method='threshold', dimensions=[1,2],
Matching method (one of 'threshold' or 'best')
- 'threshold': If the two barycenters are within some distance, match
- 'best': For each flash, pick the best matched interaction
- dimensions: list, default [0,1,2]
+            dimensions : list, default [1, 2]
Dimensions involved in the distance computation
charge_weighted : bool, default False
Use interaction pixel charge information to weight the centroid
@@ -35,14 +38,17 @@ def __init__(self, match_method='threshold', dimensions=[1,2],
Minimum number of total PE in a flash to consider it
match_distance : float, optional
If a threshold is used, specifies the acceptable distance
- detector : str, optional
- Detector to get the geometry from
- boundary_file : str, optional
- Path to a detector boundary file. Supersedes `detector` if set
"""
- self.geo = None
- if detector is not None or boundary_file is not None:
- self.geo = Geometry(detector, boundary_file, source_file)
+ # Check validity of key parameters
+ if match_method not in self._match_methods:
+ raise ValueError(
+ "Barycenter flash matching method not recognized: "
+ f"{match_method}. Must be one of {self._match_methods}.")
+
+ if match_method == 'threshold':
+ assert match_distance is not None, (
+ "When using the `threshold` method, must specify the "
+ "`match_distance` parameter.")
# Store the flash matching parameters
self.match_method = match_method
@@ -54,22 +60,15 @@ def __init__(self, match_method='threshold', dimensions=[1,2],
self.min_flash_pe = min_flash_pe
self.match_distance = match_distance
- # Check validity of certain parameters
- if self.match_method not in ['threshold', 'best']:
- msg = f'Barycenter flash matching method not recognized: {match_method}'
- raise ValueError(msg)
- if self.match_method == 'threshold':
- assert self.match_distance is not None, \
- 'When using the `threshold` method, must specify `match_distance`'
def get_matches(self, interactions, flashes):
"""Makes [interaction, flash] pairs that have compatible barycenters.
Parameters
----------
- interactions : List[Interaction]
+ interactions : List[Union[RecoInteraction, TruthInteraction]]
List of interactions
- flashes : List[larcv.Flash]
+ flashes : List[Flash]
List of optical flashes
Returns
@@ -77,24 +76,30 @@ def get_matches(self, interactions, flashes):
List[[Interaction, larcv.Flash, float]]
List of [interaction, flash, distance] triplets
"""
- # Restrict the flashes to those that fit the selection criteria
- flashes = np.asarray(flashes, dtype=object)
+ # Restrict the flashes to those that fit the selection criteria.
+ # Skip if there are no valid flashes
if self.time_window is not None:
t1, t2 = self.time_window
- mask = [(f.time() > t1 and f.time() < t2) for f in flashes]
- flashes = flashes[mask]
+            flashes = [f for f in flashes if (f.time > t1 and f.time < t2)]
+
if self.min_flash_pe is not None:
- mask = [f.flash_total_pE > self.min_flash_pe for f in flashes]
- flashes = flashes[mask]
+ flashes = [f for f in flashes if f.total_pe > self.min_flash_pe]
+
if not len(flashes):
return []
+
+ # If requested, restrict the list of flashes to match to the first one
if self.first_flash_only:
flashes = [flashes[0]]
- # Restrict the interactions to those that fit the selection criterion
+ # Restrict the interactions to those that fit the selection criterion.
+ # Skip if there are no valid interactions
if self.min_inter_size is not None:
- interactions = [ia for ia in interactions \
- if ia.size > self.min_inter_size]
+ interactions = [
+ inter for inter in interactions
+ if inter.size > self.min_inter_size
+ ]
+
if not len(interactions):
return []
@@ -102,19 +107,19 @@ def get_matches(self, interactions, flashes):
op_centroids = np.empty((len(flashes), len(self.dims)))
op_widths = np.empty((len(flashes), len(self.dims)))
for i, f in enumerate(flashes):
- for j, d in enumerate(self.dims):
- op_centroids[i, j] = getattr(f, f'{self.axes[d]}Center')()
- op_widths[i, j] = getattr(f, f'{self.axes[d]}Width')()
+ op_centroids[i] = f.center[self.dims]
+            op_widths[i] = f.width[self.dims]
# Get interactions centroids
int_centroids = np.empty((len(interactions), len(self.dims)))
- for i, ia in enumerate(interactions):
+ for i, inter in enumerate(interactions):
if not self.charge_weighted:
- int_centroids[i] = np.mean(ia.points[:, self.dims], axis=0)
+ int_centroids[i] = np.mean(inter.points[:, self.dims], axis=0)
+
else:
- int_centroids[i] = np.sum(ia.depositions * \
- ia.points[:, self.dims], axis=0) \
- / np.sum(ia.depositions)
+ int_centroids[i] = np.sum(
+ inter.depositions * inter.points[:, self.dims], axis=0)
+ int_centroids[i] /= np.sum(inter.depositions)
# Compute the flash to interaction distance matrix
dist_mat = cdist(op_centroids, int_centroids)
@@ -126,14 +131,17 @@ def get_matches(self, interactions, flashes):
for i, f in enumerate(flashes):
best_match = np.argmin(dist_mat[i])
dist = dist_mat[i, best_match]
- if self.match_distance is not None \
- and dist > self.match_distance:
- continue
+ if (self.match_distance is not None and
+ dist > self.match_distance):
+ continue
match.append((interactions[best_match], f, dist))
+
elif self.match_method == 'threshold':
# Find all compatible pairs
valid_pairs = np.vstack(np.where(dist_mat <= self.match_distance)).T
- matches = [(interactions[j], flashes[i], dist_mat[i,j]) \
- for i, j in valid_pairs]
+ matches = [
+ (interactions[j], flashes[i], dist_mat[i,j])
+ for i, j in valid_pairs
+ ]
return matches
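
A compact sketch of the 'threshold' barycenter matching above, written with plain NumPy/SciPy rather than the internal `numba_local.cdist`; the flash centroids are assumed to be pre-extracted into an array:

import numpy as np
from scipy.spatial.distance import cdist

def barycenter_matches(flash_centers, inter_points, match_distance, dims=(1, 2)):
    """Pair flashes and interactions whose centroids are within `match_distance`."""
    dims = list(dims)
    # Optical flash centroids restricted to the requested dimensions
    op_centroids = flash_centers[:, dims]

    # Charge barycenter of each interaction (unweighted version)
    int_centroids = np.array([pts[:, dims].mean(axis=0) for pts in inter_points])

    # Flash-to-interaction distance matrix and all compatible pairs
    dist_mat = cdist(op_centroids, int_centroids)
    pairs = np.vstack(np.where(dist_mat <= match_distance)).T

    return [(j, i, dist_mat[i, j]) for i, j in pairs]  # (interaction, flash, dist)

flash_centers = 100*np.random.rand(3, 3)
inter_points = [100*np.random.rand(50, 3) for _ in range(4)]
print(barycenter_matches(flash_centers, inter_points, match_distance=30.))
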
diff --git a/spine/post/optical/flash_matching.py b/spine/post/optical/flash_matching.py
index cacc5c37..9cc8161f 100644
--- a/spine/post/optical/flash_matching.py
+++ b/spine/post/optical/flash_matching.py
@@ -1,8 +1,15 @@
+"""Post-processor in charge of finding matches between charge and light.
+"""
+
import numpy as np
from warnings import warn
from spine.post.base import PostBase
+from spine.data.out.base import OutBase
+
+from spine.utils.geo import Geometry
+
from .barycenter import BarycenterFlashMatcher
from .likelihood import LikelihoodFlashMatcher
@@ -11,20 +18,35 @@
class FlashMatchProcessor(PostBase):
"""Associates TPC interactions with optical flashes."""
+
+ # Name of the post-processor (as specified in the configuration)
name = 'flash_match'
- aliases = ['run_flash_matching']
- def __init__(self, flash_map, method='likelihood', run_mode='reco',
- truth_point_mode='points', parent_path=None, **kwargs):
+ # Alternative allowed names of the post-processor
+ aliases = ('run_flash_matching',)
+
+ def __init__(self, flash_key, volume, ref_volume_id=None,
+ method='likelihood', detector=None, geometry_file=None,
+ run_mode='reco', truth_point_mode='points',
+ truth_dep_mode='depositions', parent_path=None, **kwargs):
"""Initialize the flash matching algorithm.
Parameters
----------
+ flash_key : str
+ Flash data product name. In most cases, this is unambiguous, unless
+ there are multiple types of segregated optical detectors
+ volume : str
+ Physical volume corresponding to each flash ('module' or 'tpc')
+        ref_volume_id : int, optional
+            If specified, all interactions and flashes are shifted to this
+            reference optical volume before the matching is run
method : str, default 'likelihood'
Flash matching method (one of 'likelihood' or 'barycenter')
- flash_map : dict
- Maps a flash data product key in the data ditctionary to an
- optical volume in the detector
+ detector : str, optional
+ Detector to get the geometry from
+ geometry_file : str, optional
+ Path to a `.yaml` geometry file to load the geometry from
parent_path : str, optional
Path to the parent directory of the main analysis configuration.
This allows for the use of relative paths in the post-processors.
@@ -33,13 +55,21 @@ def __init__(self, flash_map, method='likelihood', run_mode='reco',
"""
# Initialize the parent class
super().__init__(
- 'interaction', run_mode, truth_point_mode,
+ 'interaction', run_mode, truth_point_mode, truth_dep_mode,
parent_path=parent_path)
- # If there is no map from flash data product to volume ID, throw
- self.flash_map = flash_map
- for key in self.flash_map:
- self.keys[key] = True
+ # Make sure the flash data product is available, store
+ self.flash_key = flash_key
+ self.update_keys({flash_key: True})
+
+ # Initialize the detector geometry
+ self.geo = Geometry(detector, geometry_file)
+
+ # Get the volume within which each flash is confined
+ assert volume in ['tpc', 'module'], (
+ "The `volume` must be one of 'tpc' or 'module'.")
+ self.volume = volume
+ self.ref_volume_id = ref_volume_id
# Initialize the flash matching algorithm
if method == 'barycenter':
@@ -47,7 +77,7 @@ def __init__(self, flash_map, method='likelihood', run_mode='reco',
elif method == 'likelihood':
self.matcher = LikelihoodFlashMatcher(
- **kwargs, parent_path=self.parent_path)
+ detector=detector, parent_path=self.parent_path, **kwargs)
else:
raise ValueError(f'Flash matching method not recognized: {method}')
@@ -63,15 +93,25 @@ def process(self, data):
Notes
-----
This post-processor modifies the list of `interaction` objectss
- in-place by adding the following attributes:
+        in-place by filling the following attributes:
- interaction.is_flash_matched: (bool)
Indicator for whether the given interaction has a flash match
- - interaction.flash_time: float
- The flash time in microseconds
+ - interaction.flash_ids: np.ndarray
+ The flash IDs in the flash list
+ - interaction.flash_volume_ids: np.ndarray
+ The flash optical volume IDs in the flash list
+ - interaction.flash_times: np.ndarray
+ The flash time(s) in microseconds
- interaction.flash_total_pe: float
- - interaction.flash_hypo_pe: float
+ Total number of PEs associated with the matched flash(es)
+ - interaction.flash_hypo_pe: float, optional
+            Total number of PEs associated with the hypothesis flash
"""
- # Loop over the keys to match
+ # Fetch the optical volume each flash belongs to
+ flashes = data[self.flash_key]
+ volume_ids = np.asarray([f.volume_id for f in flashes])
+
+ # Loop over the optical volumes, run flash matching
for k in self.interaction_keys:
# Fetch interactions, nothing to do if there are not any
interactions = data[k]
@@ -83,31 +123,95 @@ def process(self, data):
# Clear previous flash matching information
for inter in interactions:
+ inter.flash_ids = []
+ inter.flash_volume_ids = []
+ inter.flash_times = []
if inter.is_flash_matched:
inter.is_flash_matched = False
- inter.flash_id = -1
- inter.flash_time = -np.inf
- inter.flash_total_pe = -1.0
- inter.flash_hypo_pe = -1.0
-
- # Loop over flash keys
- for key, module_id in self.flash_map.items():
- # Get the list of flashes associated with that key
- flashes = data[key]
-
- # Get list of interactions that originate from the same module
- # TODO: this only works for interactions coming from a single
- # TODO: module. Must fix this.
- ints = [inter for inter in interactions if inter.module_ids[0] == module_id]
+ inter.flash_total_pe = -1.
+ inter.flash_hypo_pe = -1.
+
+ # Loop over the optical volumes
+ for volume_id in np.unique(volume_ids):
+ # Get the list of flashes associated with this optical volume
+ flashes_v = []
+ for flash in flashes:
+ # Skip if the flash is not associated with the right volume
+ if flash.volume_id != volume_id:
+ continue
+
+ # Reshape the flash based on geometry
+ pe_per_ch = np.zeros(
+ self.geo.optical.num_detectors_per_volume,
+ dtype=flash.pe_per_ch.dtype)
+ if self.ref_volume_id is not None:
+ lower = flash.volume_id*len(pe_per_ch)
+ upper = (flash.volume_id + 1)*len(pe_per_ch)
+ pe_per_ch = flash.pe_per_ch[lower:upper]
+ else:
+ pe_per_ch[:len(flash.pe_per_ch)] = flash.pe_per_ch
+
+ flash.pe_per_ch = pe_per_ch
+ flashes_v.append(flash)
+
+ # Crop interactions to only include depositions in the optical volume
+ interactions_v = []
+ for inter in interactions:
+ # Fetch the points in the current optical volume
+ sources = self.get_sources(inter)
+ if self.volume == 'module':
+ index = self.geo.get_volume_index(sources, volume_id)
+
+ elif self.volume == 'tpc':
+ num_cpm = self.geo.tpc.num_chambers_per_module
+ module_id, tpc_id = volume_id//num_cpm, volume_id%num_cpm
+ index = self.geo.get_volume_index(sources, module_id, tpc_id)
+
+ # If there are no points in this volume, proceed
+ if len(index) == 0:
+ continue
+
+ # Fetch points and depositions
+ points = self.get_points(inter)[index]
+ depositions = self.get_depositions(inter)[index]
+ if self.ref_volume_id is not None:
+ # If the reference volume is specified, shift positions
+ points = self.geo.translate(
+ points, volume_id, self.ref_volume_id)
+
+ # Create an interaction which holds positions/depositions
+ inter_v = OutBase(
+ id=inter.id, points=points, depositions=depositions)
+ interactions_v.append(inter_v)
# Run flash matching
- matches = self.matcher.get_matches(ints, flashes)
+ matches = self.matcher.get_matches(interactions_v, flashes_v)
# Store flash information
- for inter, flash, match in matches:
- inter.is_flash_matched = True
- inter.flash_id = int(flash.id)
- inter.flash_time = float(flash.time)
- inter.flash_total_pe = float(flash.total_pe)
+ for inter_v, flash, match in matches:
+ # Get the interaction that matches the cropped version
+ inter = interactions[inter_v.id]
+
+ # Get the flash hypothesis (if the matcher produces one)
+ hypo_pe = -1.
if hasattr(match, 'hypothesis'):
- inter.flash_hypo_pe = float(np.sum(match.hypothesis))
+ hypo_pe = float(np.sum(list(match.hypothesis)))
+
+ # Append
+ inter.flash_ids.append(int(flash.id))
+ inter.flash_volume_ids.append(int(flash.volume_id))
+ inter.flash_times.append(float(flash.time))
+ if inter.is_flash_matched:
+ inter.flash_total_pe += float(flash.total_pe)
+ inter.flash_hypo_pe += hypo_pe
+
+ else:
+ inter.is_flash_matched = True
+ inter.flash_total_pe = float(flash.total_pe)
+ inter.flash_hypo_pe = hypo_pe
+
+ # Cast list attributes to numpy arrays
+ for inter in interactions:
+ inter.flash_ids = np.asarray(inter.flash_ids, dtype=np.int32)
+ inter.flash_volume_ids = np.asarray(inter.flash_volume_ids, dtype=np.int32)
+ inter.flash_times = np.asarray(inter.flash_times, dtype=np.float32)
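
The per-volume strategy above boils down to: group flashes by optical volume, crop each interaction to its depositions in that volume, match within the volume, then accumulate the results per original interaction. A schematic sketch with stand-in callables (`volume_of_point` and `get_matches` are placeholders, not the real Geometry or matcher interfaces):

import numpy as np
from collections import defaultdict

def match_per_volume(flashes, interactions, volume_of_point, get_matches):
    """flashes: list of (flash_id, volume_id); interactions: dict of id -> (N, 3) points."""
    matched = defaultdict(list)

    # Group the flashes by the optical volume they belong to
    volumes = defaultdict(list)
    for flash_id, volume_id in flashes:
        volumes[volume_id].append(flash_id)

    for volume_id, flash_ids in volumes.items():
        # Crop every interaction to its points inside this optical volume
        cropped = {}
        for inter_id, points in interactions.items():
            index = np.where(volume_of_point(points) == volume_id)[0]
            if len(index):
                cropped[inter_id] = points[index]

        # Match within the volume, accumulate per original interaction
        for inter_id, flash_id in get_matches(cropped, flash_ids):
            matched[inter_id].append(flash_id)

    # Cast the accumulated lists to arrays, as done for the real attributes
    return {k: np.asarray(v, dtype=np.int32) for k, v in matched.items()}

# Toy usage with trivial stand-ins for the volume lookup and the matcher
print(match_per_volume(
    [(0, 0), (1, 1)], {0: np.random.rand(20, 3)},
    lambda pts: (pts[:, 0] > 0.5).astype(int),
    lambda crops, fids: [(i, fids[0]) for i in crops]))
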
diff --git a/spine/post/optical/likelihood.py b/spine/post/optical/likelihood.py
index 1d387d5c..f93ceda8 100644
--- a/spine/post/optical/likelihood.py
+++ b/spine/post/optical/likelihood.py
@@ -1,56 +1,62 @@
+"""Module which supports likelihood-based flash matching (OpT0Finder)."""
+
import os, sys
import numpy as np
import time
-from spine.utils.geo import Geometry
-
class LikelihoodFlashMatcher:
"""Interface class between full chain outputs and OpT0Finder
See https://github.com/drinkingkazu/OpT0Finder for more details about it.
"""
- def __init__(self, cfg, parent_path=None, reflash_merging_window=None,
- detector=None, boundary_file=None, scaling=1.,
- truth_dep_mode='depositions'):
+
+ def __init__(self, cfg, detector, parent_path=None,
+ reflash_merging_window=None, scaling=1., alpha=0.21,
+ recombination_mip=0.65, legacy=False):
"""Initialize the likelihood-based flash matching algorithm.
Parameters
----------
cfg : str
Flash matching configuration file path
+ detector : str, optional
+ Detector to get the geometry from
parent_path : str, optional
Path to the parent configuration file (allows for relative paths)
reflash_merging_window : float, optional
Maximum time between successive flashes to be considered a reflash
- detector : str, optional
- Detector to get the geometry from
- boundary_file : str, optional
- Path to a detector boundary file. Supersedes `detector` if set
scaling : Union[float, str], default 1.
Global scaling factor for the depositions (can be an expression)
- truth_dep_mode : str, default 'depositions'
- Attribute used to fetch deposition values for truth interactions
+ alpha : float, default 0.21
+ Number of excitons (Ar*) divided by number of electron-ion pairs (e-,Ar+)
+ recombination_mip : float, default 0.65
+ Recombination factor for MIP-like particles in LAr
+ legacy : bool, default False
+ Use the legacy OpT0Finder function(s). TODO: remove when dropping legacy
"""
# Initialize the flash manager (OpT0Finder wrapper)
- self.initialize_backend(cfg, parent_path)
-
- # Initialize the geometry
- self.geo = Geometry(detector, boundary_file)
+ self.initialize_backend(cfg, detector, parent_path)
# Get the external parameters
- self.truth_dep_mode = truth_dep_mode
self.reflash_merging_window = reflash_merging_window
self.scaling = scaling
if isinstance(self.scaling, str):
self.scaling = eval(self.scaling)
+ self.alpha = alpha
+ if isinstance(self.alpha, str):
+ self.alpha = eval(self.alpha)
+ self.recombination_mip = recombination_mip
+ if isinstance(self.recombination_mip, str):
+ self.recombination_mip = eval(self.recombination_mip)
+ self.legacy = legacy
# Initialize flash matching attributes
self.matches = None
self.qcluster_v = None
self.flash_v = None
- def initialize_backend(self, cfg, parent_path):
+ def initialize_backend(self, cfg, detector, parent_path):
"""Initialize OpT0Finder (backend).
Expects that the environment variable `FMATCH_BASEDIR` is set.
@@ -62,6 +68,8 @@ def initialize_backend(self, cfg, parent_path):
----------
cfg: str
Path to config for OpT0Finder
+ detector : str, optional
+ Detector to get the geometry from
parent_path : str, optional
Path to the parent configuration file (allows for relative paths)
"""
@@ -83,9 +91,17 @@ def initialize_backend(self, cfg, parent_path):
os.environ['FMATCH_DATADIR'] = os.path.join(basedir, 'dat')
# Load up the detector specifications
+ if detector is None:
+ det_cfg = os.path.join(basedir, 'dat/detector_specs.cfg')
+ else:
+ det_cfg = os.path.join(basedir, f'dat/detector_specs_{detector}.cfg')
+
+ if not os.path.isfile(det_cfg):
+ raise FileNotFoundError(
+                    f"Cannot find detector specification file: {det_cfg}.")
+
from flashmatch import flashmatch
- flashmatch.DetectorSpecs.GetME(
- os.path.join(basedir, 'dat/detector_specs.cfg'))
+ flashmatch.DetectorSpecs.GetME(det_cfg)
# Fetch and initialize the OpT0Finder configuration
if parent_path is not None and not os.path.isfile(cfg):
@@ -124,23 +140,6 @@ def get_matches(self, interactions, flashes):
if not len(interactions) or not len(flashes):
return []
- # Get the module ID in which all the interactions live
- module_ids = np.empty(len(interactions), dtype=np.int64)
- for i, inter in enumerate(interactions):
- ids = inter.module_ids
- assert len(ids) > 0, (
- "The interaction object does not contain any information "
- "about which optical module produced it; must be provided.")
- assert len(ids) == 1, (
- "Cannot match interactions that are composed of points "
- "originating for more than one optical module.")
- module_ids[i] = ids[0]
-
- # Check that all interactions live in one module, store it
- assert len(np.unique(module_ids)) == 1, (
- "Should only provide interactions from a single optical module.")
- self.module_id = module_ids[0]
-
# Build a list of QCluster_t (OpT0Finder interaction representation)
self.qcluster_v = self.make_qcluster_list(interactions)
@@ -188,19 +187,19 @@ def make_qcluster_list(self, interactions):
qcluster.time = 0
# Get the point coordinates
- points = self.geo.translate(inter.points[valid_mask],
- self.module_id, 0)
+ points = inter.points[valid_mask]
# Get the depositions
- if not inter.is_truth:
- depositions = inter.depositions[valid_mask]
- else:
- depositions = getattr(inter, self.truth_dep_mode)
+ depositions = inter.depositions[valid_mask]
# Fill the trajectory
pytraj = np.hstack([points, depositions[:, None]])
traj = flashmatch.as_geoalgo_trajectory(pytraj)
- qcluster += self.light_path.MakeQCluster(traj, self.scaling)
+ if self.legacy:
+ qcluster += self.light_path.MakeQCluster(traj, self.scaling)
+ else:
+ qcluster += self.light_path.MakeQCluster(
+ traj, self.scaling, self.alpha, self.recombination_mip)
# Append
qcluster_v.append(qcluster)
@@ -243,7 +242,7 @@ def make_flash_list(self, flashes):
for idx, f in enumerate(flashes):
# Initialize the Flash_t object
flash = flashmatch.Flash_t()
- flash.idx = f.id # Assign a unique index
+ flash.idx = int(f.id) # Assign a unique index
flash.time = f.time # Flash timing, a candidate T0
# Assign the flash position and error on this position
@@ -251,9 +250,8 @@ def make_flash_list(self, flashes):
flash.x_err, flash.y_err, flash.z_err = 0, 0, 0
# Assign the individual PMT optical hit PEs
- offset = 0 if len(f.pe_per_ch) == 180 else 180
- for i in range(180):
- flash.pe_v.push_back(f.pe_per_ch[i + offset])
+ for i in range(len(f.pe_per_ch)):
+ flash.pe_v.push_back(f.pe_per_ch[i])
flash.pe_err_v.push_back(0.)
# Append
@@ -286,7 +284,6 @@ def run_flash_matching(self):
# Adjust the output position to account for the module shift
for m in all_matches:
pos = np.array([m.tpc_point.x, m.tpc_point.y, m.tpc_point.z])
- pos = self.geo.translate(pos, 0, self.module_id)
m.tpc_point.x = pos[0]
m.tpc_point.y = pos[1]
m.tpc_point.z = pos[2]
diff --git a/spine/post/reco/calo.py b/spine/post/reco/calo.py
index 477bb733..63bae3ec 100644
--- a/spine/post/reco/calo.py
+++ b/spine/post/reco/calo.py
@@ -14,8 +14,12 @@ class CalorimetricEnergyProcessor(PostBase):
"""Compute calorimetric energy by summing the charge depositions and
scaling by the ADC to MeV conversion factor, if needed.
"""
+
+ # Name of the post-processor (as specified in the configuration)
name = 'calo_ke'
- aliases = ['reconstruct_calo_energy']
+
+ # Alternative allowed names of the post-processor
+ aliases = ('reconstruct_calo_energy',)
def __init__(self, scaling=1., shower_fudge=1., obj_type='particle',
run_mode='reco', truth_dep_mode='depositions'):
@@ -65,12 +69,19 @@ def process(self, data):
class CalibrationProcessor(PostBase):
"""Apply calibrations to the reconstructed objects."""
+
+ # Name of the post-processor (as specified in the configuration)
name = 'calibration'
- aliases = ['apply_calibrations']
- keys = {'run_info': False}
- def __init__(self, dedx=2.2, do_tracking=False, obj_type='particle',
- run_mode='reco', truth_point_mode='points', **cfg):
+ # Alternative allowed names of the post-processor
+ aliases = ('apply_calibrations',)
+
+ # Set of data keys needed for this post-processor to operate
+ _keys = (('run_info', False),)
+
+ def __init__(self, dedx=2.2, do_tracking=False,
+ obj_type=('particle', 'interaction'), run_mode='reco',
+ truth_point_mode='points', **cfg):
"""Initialize the calibration manager.
Parameters
@@ -94,12 +105,22 @@ def __init__(self, dedx=2.2, do_tracking=False, obj_type='particle',
self.do_tracking = do_tracking
# Add necessary keys
- self.keys['points'] = run_mode != 'truth'
- self.keys[self.truth_point_key] = run_mode != 'reco'
- self.keys['depositions'] = run_mode != 'truth'
- self.keys[self.truth_dep_key] = run_mode != 'reco'
- self.keys['sources'] = run_mode != 'truth'
- self.keys[self.truth_source_key] = run_mode != 'reco'
+ keys = {}
+ if run_mode != 'truth':
+ keys.update({
+ 'points': True,
+ 'depositions': True,
+ 'sources': True
+ })
+
+ if run_mode != 'reco':
+ keys.update({
+ self.truth_point_key: True,
+ self.truth_dep_key: True,
+ self.truth_source_key: True
+ })
+
+ self.update_keys(keys)
def process(self, data):
"""Apply calibrations to each particle in one entry.
@@ -116,7 +137,7 @@ def process(self, data):
run_id = run_info.run
# Loop over particle objects
- for k in self.obj_keys:
+ for k in self.particle_keys:
points_key = 'points' if not 'truth' in k else self.truth_point_key
source_key = 'sources' if not 'truth' in k else self.truth_source_key
dep_key = 'depositions' if not 'truth' in k else self.truth_dep_key
@@ -155,3 +176,14 @@ def process(self, data):
data[dep_key][unass_index] = self.calibrator(
data[points_key][unass_index], data[dep_key][unass_index],
data[source_key][unass_index], run_id, self.dedx)
+
+        # If requested, update the depositions attribute of the interactions
+ for k in self.interaction_keys:
+ dep_key = 'depositions' if not 'truth' in k else self.truth_dep_key
+ for inter in data[k]:
+ # Update depositions for the interaction
+ depositions = data[dep_key][inter.index]
+                if not inter.is_truth:
+ inter.depositions = depositions
+ else:
+ setattr(inter, self.truth_dep_mode, depositions)
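
A minimal sketch of the new interaction-level refresh at the end of `process`: once the per-point depositions are calibrated, each interaction simply re-reads its slice of the deposition array through its `index` (the `Interaction` stand-in below is illustrative, not the real output class):

import numpy as np
from dataclasses import dataclass

@dataclass
class Interaction:
    index: np.ndarray              # Indices of this interaction's points
    depositions: np.ndarray = None
    is_truth: bool = False

depositions = np.ones(10, dtype=np.float32)   # Per-point charge, pre-calibration
inter = Interaction(index=np.array([0, 1, 2, 3]))

depositions *= 2.5                             # Stand-in for the calibration step
if not inter.is_truth:
    inter.depositions = depositions[inter.index]  # Refresh the interaction view

print(inter.depositions.sum())                 # 10.0 after calibration
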
diff --git a/spine/post/reco/cathode_cross.py b/spine/post/reco/cathode_cross.py
index ceb2bbb8..1e342b01 100644
--- a/spine/post/reco/cathode_cross.py
+++ b/spine/post/reco/cathode_cross.py
@@ -23,13 +23,17 @@ class CathodeCrosserProcessor(PostBase):
- If the particle is sigificantly out-of-time, a cathode crosser will
be composed of two distinct reconstructed particle objects
"""
+
+ # Name of the post-processor (as specified in the configuration)
name = 'cathode_crosser'
- aliases = ['find_cathode_crossers']
+
+ # Alternative allowed names of the post-processor
+ aliases = ('find_cathode_crossers',)
def __init__(self, crossing_point_tolerance, offset_tolerance,
angle_tolerance, adjust_crossers=True, merge_crossers=True,
- detector=None, boundary_file=None, source_file=None,
- run_mode='reco', truth_point_mode='points'):
+ detector=None, geometry_file=None, run_mode='reco',
+ truth_point_mode='points'):
"""Initialize the cathode crosser finder algorithm.
Parameters
@@ -51,17 +55,15 @@ def __init__(self, crossing_point_tolerance, offset_tolerance,
and merge them into one particle
detector : str, optional
Detector to get the geometry from
- boundary_file : str, optional
- Path to a detector boundary file. Supersedes `detector` if set
- source_file : str, optional
- Path to a detector source file. Supersedes `detector` if set
+ geometry_file : str, optional
+ Path to a `.yaml` geometry file to load the geometry from
"""
# Initialize the parent class
super().__init__(
- ['particle', 'interaction'], run_mode, truth_point_mode)
+ ('particle', 'interaction'), run_mode, truth_point_mode)
# Initialize the geometry
- self.geo = Geometry(detector, boundary_file, source_file)
+ self.geo = Geometry(detector, geometry_file)
# Store the matching parameters
self.crossing_point_tolerance = crossing_point_tolerance
@@ -71,10 +73,13 @@ def __init__(self, crossing_point_tolerance, offset_tolerance,
self.merge_crossers = merge_crossers
# Add the points to the list of keys to load
+ keys = {}
if run_mode != 'truth':
- self.keys['points'] = True
+ keys['points'] = True
if run_mode != 'reco':
- self.keys[truth_point_mode] = True
+ keys[truth_point_mode] = True
+
+ self.update_keys(keys)
def process(self, data):
"""Find cathode crossing particles in one entry.
@@ -142,13 +147,14 @@ def process(self, data):
continue
# Get the cathode position, drift axis and cathode plane axes
- daxis, cpos = self.geo.cathodes[modules_i[0]]
+ daxis = self.geo.tpc[modules_i[0]].drift_axis
+ cpos = self.geo.tpc[modules_i[0]].cathode_pos
caxes = np.array([i for i in range(3) if i != daxis])
# Store the distance of the particle to the cathode
- tpc_offset = self.geo.get_min_tpc_offset(
+ tpc_offset = self.geo.get_min_volume_offset(
end_points_i, modules_i[0], tpcs_i[0])[daxis]
- cdists = end_points_i[:,daxis] - tpc_offset - cpos
+ cdists = end_points_i[:, daxis] - tpc_offset - cpos
# Loop over other tracks
j = i + 1
@@ -254,24 +260,12 @@ def adjust_positions(self, data, idx_i, idx_j=None, truth=False):
points_attr = 'points' if not truth else self.truth_point_mode
points_key = 'points' if not truth else self.truth_point_key
particles = data[part_key]
- closest_attr = [None, None]
if idx_j is not None:
# Merge particles
int_id_i = particles[idx_i].interaction_id
int_id_j = particles[idx_j].interaction_id
particles[idx_i].merge(particles.pop(idx_j))
- # Assign start and end point to a specific TPC
- for attr in ('start_point', 'end_point'):
- key_point = getattr(particles[idx_i], attr)
- points = self.get_points(particles[idx_i])
- argmin = np.argmin(cdist(key_point[None, :], points))
- sources = self.get_sources(particles[idx_i])
- tpc_id = self.geo.get_contributors(sources[argmin][None, :])[1]
- closest_attr[tpc_id[0]] = attr
-
- assert np.all([val is not None for val in closest_attr])
-
# Update the particle IDs and interaction IDs
assert idx_j > idx_i
for i, p in enumerate(particles):
@@ -279,6 +273,18 @@ def adjust_positions(self, data, idx_i, idx_j=None, truth=False):
if p.interaction_id == int_id_j:
p.interaction_id = int_id_i
+ # Assign start and end point to a specific TPC
+ closest_attr = [None, None]
+ for attr in ('start_point', 'end_point'):
+ key_point = getattr(particles[idx_i], attr)
+ points = self.get_points(particles[idx_i])
+ argmin = np.argmin(cdist(key_point[None, :], points))
+ sources = self.get_sources(particles[idx_i])
+ tpc_id = self.geo.get_contributors(sources[argmin][None, :])[1]
+ closest_attr[tpc_id[0]] = attr
+
+ assert np.all([val is not None for val in closest_attr])
+
# Get TPCs that contributed to this particle
particle = particles[idx_i]
modules, tpcs = self.geo.get_contributors(self.get_sources(particle))
@@ -291,7 +297,8 @@ def adjust_positions(self, data, idx_i, idx_j=None, truth=False):
# Get the cathode position
m = modules[0]
- daxis, cpos = self.geo.cathodes[m]
+ daxis = self.geo.tpc[m].drift_axis
+ cpos = self.geo.tpc[m].cathode_pos
# Loop over contributing TPCs, shift the points in each independently
offsets, global_offset = self.get_cathode_offsets(
@@ -300,7 +307,7 @@ def adjust_positions(self, data, idx_i, idx_j=None, truth=False):
# Move each of the sister particles by the same amount
for sister in sisters:
# Find the index corresponding to the sister particle
- tpc_index = self.geo.get_tpc_index(
+ tpc_index = self.geo.get_volume_index(
self.get_sources(sister), m, t)
index = self.get_index(sister)[tpc_index]
if not len(index):
@@ -377,28 +384,29 @@ def get_cathode_offsets(self, particle, module, tpcs):
General offset for this particle (proxy of out-of-time displacement)
"""
# Get the cathode position
- daxis, cpos = self.geo.cathodes[module]
+ daxis = self.geo.tpc[module].drift_axis
+ cpos = self.geo.tpc[module].cathode_pos
dvector = (np.arange(3) == daxis).astype(float)
# Check which side of the cathode each TPC lives
flip = (-1) ** (
- self.geo.boundaries[module, tpcs[0], daxis].mean()
- > self.geo.boundaries[module, tpcs[1], daxis].mean())
+ self.geo.tpc[module, tpcs[0]].boundaries[daxis].mean()
+ > self.geo.tpc[module, tpcs[1]].boundaries[daxis].mean())
# Loop over the contributing TPCs
closest_points = np.empty((2, 3))
offsets = np.empty(2)
for i, t in enumerate(tpcs):
# Get the end points of the track segment
- index = self.geo.get_tpc_index(
+ index = self.geo.get_volume_index(
self.get_sources(particle), module, t)
points = self.get_points(particle)[index]
idx0, idx1, _ = farthest_pair(points, 'recursive')
end_points = points[[idx0, idx1]]
# Find the point closest to the cathode
- tpc_offset = self.geo.get_min_tpc_offset(end_points,
- module, t)[daxis]
+ tpc_offset = self.geo.get_min_volume_offset(
+ end_points, module, t)[daxis]
cdists = end_points[:, daxis] - tpc_offset - cpos
argmin = np.argmin(np.abs(cdists))
closest_points[i] = end_points[argmin]
diff --git a/spine/post/reco/direction.py b/spine/post/reco/direction.py
index 720c2f58..87de5e1f 100644
--- a/spine/post/reco/direction.py
+++ b/spine/post/reco/direction.py
@@ -14,8 +14,12 @@ class DirectionProcessor(PostBase):
This modules assign the `start_dir` and `end_dir` attributes.
"""
+
+ # Name of the post-processor (as specified in the configuration)
name = 'direction'
- aliases = ['reconstruct_directions']
+
+ # Alternative allowed names of the post-processor
+ aliases = ('reconstruct_directions',)
def __init__(self, neighborhood_radius=-1, optimize=True,
obj_type='particle', truth_point_mode='points',
diff --git a/spine/post/reco/geometry.py b/spine/post/reco/geometry.py
index 45f92a66..89b3fca4 100644
--- a/spine/post/reco/geometry.py
+++ b/spine/post/reco/geometry.py
@@ -17,13 +17,17 @@ class ContainmentProcessor(PostBase):
boundaries of the detector and assign the `is_contained` attribute
accordingly.
"""
+
+ # Name of the post-processor (as specified in the configuration)
name = 'containment'
- aliases = ['check_containment']
+
+ # Alternative allowed names of the post-processor
+ aliases = ('check_containment',)
def __init__(self, margin, cathode_margin=None, detector=None,
- boundary_file=None, source_file=None, mode='module',
+ geometry_file=None, mode='module',
allow_multi_module=False, min_particle_sizes=0,
- obj_type=['particle', 'interaction'],
+ obj_type=('particle', 'interaction'),
truth_point_mode='points', run_mode='both'):
"""Initialize the containment conditions.
@@ -45,10 +49,8 @@ def __init__(self, margin, cathode_margin=None, detector=None,
If specified, sets a different margin for the cathode boundaries
detector : str, optional
Detector to get the geometry from
- boundary_file : str, optional
- Path to a detector boundary file. Supersedes `detector` if set
- source_file : str, optional
- Path to a detector source file. Supersedes `detector` if set
+ geometry_file : str, optional
+ Path to a `.yaml` geometry file to load the geometry from
mode : str, default 'module'
Containement criterion (one of 'global', 'module', 'tpc'):
- If 'tpc', makes sure it is contained within a single tpc
@@ -72,14 +74,14 @@ def __init__(self, margin, cathode_margin=None, detector=None,
# Initialize the geometry, if needed
if mode != 'meta':
self.use_meta = False
- self.geo = Geometry(detector, boundary_file, source_file)
+ self.geo = Geometry(detector, geometry_file)
self.geo.define_containment_volumes(margin, cathode_margin, mode)
else:
- assert detector is None and boundary_file is None, (
+ assert detector is None and geometry_file is None, (
"When using `meta` to check containment, must not "
"provide geometry information.")
- self.keys['meta'] = True
+ self.update_keys({'meta': True})
self.use_meta = True
self.margin = margin
@@ -154,11 +156,15 @@ class FiducialProcessor(PostBase):
The fiducial volume is defined as a margin distances from each of the
detector walls.
"""
+
+ # Name of the post-processor (as specified in the configuration)
name = 'fiducial'
- aliases = ['check_fiducial']
+
+ # Alternative allowed names of the post-processor
+ aliases = ('check_fiducial',)
def __init__(self, margin, cathode_margin=None, detector=None,
- boundary_file=None, mode='module', run_mode='both',
+ geometry_file=None, mode='module', run_mode='both',
truth_vertex_mode='vertex'):
"""Initialize the fiducial conditions.
@@ -175,8 +181,8 @@ def __init__(self, margin, cathode_margin=None, detector=None,
If specified, sets a different margin for the cathode boundaries
detector : str, default 'icarus'
Detector to get the geometry from
- boundary_file : str, optional
- Path to a detector boundary file. Supersedes `detector` if set
+ geometry_file : str, optional
+ Path to a `.yaml` geometry file to load the geometry from
mode : str, default 'module'
Containement criterion (one of 'global', 'module', 'tpc'):
- If 'tpc', makes sure it is contained within a single tpc
@@ -194,14 +200,14 @@ def __init__(self, margin, cathode_margin=None, detector=None,
# Initialize the geometry
if mode != 'meta':
self.use_meta = False
- self.geo = Geometry(detector, boundary_file)
+ self.geo = Geometry(detector, geometry_file)
self.geo.define_containment_volumes(margin, cathode_margin, mode)
else:
- assert detector is None and boundary_file is None, (
+ assert detector is None and geometry_file is None, (
"When using `meta` to check containment, must not "
"provide geometry information.")
- self.keys['meta'] = True
+ self.update_keys({'meta': True})
self.use_meta = True
self.margin = margin
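
Note on usage: the containment and fiducial processors above now take a single `geometry_file` argument in place of the separate boundary/source files, and the `meta` mode explicitly forbids passing any geometry information. A minimal sketch of the two ways these processors might be configured, written as plain Python dictionaries; only the parameter names come from this diff, the surrounding configuration schema is an assumption:

# Hypothetical containment configurations mirroring the new signature
containment_from_geometry = {
    'margin': 5.0,             # distance (cm) from every wall
    'cathode_margin': 1.0,     # optional tighter margin at the cathode
    'detector': 'icarus',      # or geometry_file: 'path/to/geometry.yaml'
    'mode': 'module',          # containment checked per module
}
containment_from_meta = {
    'margin': 5.0,
    'mode': 'meta',            # uses the image metadata; no geometry allowed
}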
diff --git a/spine/post/reco/kinematics.py b/spine/post/reco/kinematics.py
index d79a36ad..33ea323e 100644
--- a/spine/post/reco/kinematics.py
+++ b/spine/post/reco/kinematics.py
@@ -18,8 +18,12 @@ class ParticleShapeLogicProcessor(PostBase):
- If a particle has track shape, it can only have a track PID
- If a particle has delta/michel shape, it can only be a secondary electron
"""
+
+ # Name of the post-processor (as specified in the configuration)
name = 'shape_logic'
- aliases = ['enforce_particle_semantics']
+
+ # Alternative allowed names of the post-processor
+ aliases = ('enforce_particle_semantics',)
def __init__(self, enforce_pid=True, enforce_primary=True):
"""Store information about which particle properties should
@@ -78,8 +82,12 @@ class ParticleThresholdProcessor(PostBase):
"""Adjust the particle PID and primary properties according to customizable
thresholds and priority orderings.
"""
+
+ # Name of the post-processor (as specified in the configuration)
name = 'particle_threshold'
- aliases = ['adjust_particle_properties']
+
+ # Alternative allowed names of the post-processor
+ aliases = ('adjust_particle_properties',)
def __init__(self, shower_pid_thresholds=None, track_pid_thresholds=None,
primary_threshold=None):
@@ -156,8 +164,12 @@ class InteractionTopologyProcessor(PostBase):
"""Adjust the topology of interactions by applying thresholds on the
minimum kinetic energy of particles.
"""
+
+ # Name of the post-processor (as specified in the configuration)
name = 'topology_threshold'
- aliases = ['adjust_interaction_topology']
+
+ # Alternative allowed names of the post-processor
+ aliases = ('adjust_interaction_topology',)
def __init__(self, ke_thresholds, reco_ke_mode='ke',
truth_ke_mode='energy_deposit', run_mode='both'):
diff --git a/spine/post/reco/label.py b/spine/post/reco/label.py
index bccc83ee..f7a61a44 100644
--- a/spine/post/reco/label.py
+++ b/spine/post/reco/label.py
@@ -16,8 +16,12 @@ class ChildrenProcessor(PostBase):
"""Count the number of children of a given particle, using the particle
hierarchy information from :class:`ParticleGraphParser`.
"""
+
+ # Name of the post-processor (as specified in the configuration)
name = 'children_count'
- aliases = ['count_children']
+
+ # Alternative allowed names of the post-processor
+ aliases = ('count_children',)
def __init__(self, mode='shape', obj_type='particle'):
"""Initialize the children counting parameters.
diff --git a/spine/post/reco/mcs.py b/spine/post/reco/mcs.py
index def28a95..75ec3079 100644
--- a/spine/post/reco/mcs.py
+++ b/spine/post/reco/mcs.py
@@ -16,13 +16,18 @@ class MCSEnergyProcessor(PostBase):
"""Reconstruct the kinetic energy of tracks based on their Multiple-Coulomb
scattering (MCS) angles while passing through liquid argon.
"""
+
+ # Name of the post-processor (as specified in the configuration)
name = 'mcs_ke'
- aliases = ['reconstruct_mcs_energy']
+
+ # Alternative allowed names of the post-processor
+ aliases = ('reconstruct_mcs_energy',)
def __init__(self, tracking_mode='bin_pca', segment_length=5.0,
split_angle=False, res_a=0.25, res_b=1.25,
- include_pids=[MUON_PID, PION_PID, PROT_PID, KAON_PID],
- only_uncontained=False, obj_type='particle', run_mode='both',
+ include_pids=(MUON_PID, PION_PID, PROT_PID, KAON_PID),
+ fill_per_pid=False, only_uncontained=False,
+ obj_type='particle', run_mode='both',
truth_point_mode='points', **kwargs):
"""Store the necessary attributes to do MCS-based estimations.
@@ -41,6 +46,8 @@ def __init__(self, tracking_mode='bin_pca', segment_length=5.0,
Parameter b in the a/dx^b which models the angular uncertainty
include_pids : list, default [2, 3, 4, 5]
Particle species to compute the kinetic energy for
+ fill_per_pid : bool, default False
+ If `True`, compute the MCS KE estimate under all PID assumptions
only_uncontained : bool, default False
Only run the algorithm on particles that are not contained
         **kwargs : dict, optional
@@ -51,6 +58,7 @@ def __init__(self, tracking_mode='bin_pca', segment_length=5.0,
# Store the general parameters
self.include_pids = include_pids
+ self.fill_per_pid = fill_per_pid
self.only_uncontained = only_uncontained
# Store the tracking parameters
@@ -108,3 +116,14 @@ def process(self, data):
obj.mcs_ke = mcs_fit(
theta, mass, self.segment_length, 1,
self.split_angle, self.res_a, self.res_b)
+
+ # If requested, convert the KE to other PID hypotheses
+ if self.fill_per_pid:
+ # Compute the momentum (what MCS is truly sensitive to)
+ mom = np.sqrt(obj.mcs_ke**2 + 2*mass*obj.mcs_ke)
+
+ # For each PID, convert back to KE
+ for pid in self.include_pids:
+ mass = PID_MASSES[pid]
+ ke = np.sqrt(mom**2 + mass**2) - mass
+ obj.mcs_ke_per_pid[pid] = ke
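
The per-PID fill works because MCS is primarily sensitive to momentum: the fitted kinetic energy is first converted to momentum, then converted back to kinetic energy under each mass hypothesis. A standalone sketch of those two relations (the mass values are illustrative placeholders, not the project's PID_MASSES table):

import numpy as np

def ke_to_momentum(ke, mass):
    # p^2 = KE^2 + 2*m*KE (relativistic kinematics, natural units)
    return np.sqrt(ke**2 + 2*mass*ke)

def momentum_to_ke(mom, mass):
    # KE = sqrt(p^2 + m^2) - m
    return np.sqrt(mom**2 + mass**2) - mass

# Re-express a muon-hypothesis KE under a pion hypothesis (MeV)
mu_mass, pi_mass = 105.66, 139.57
mom = ke_to_momentum(500.0, mu_mass)
print(momentum_to_ke(mom, pi_mass))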
diff --git a/spine/post/reco/points.py b/spine/post/reco/points.py
index 90aaae23..f6e9bf23 100644
--- a/spine/post/reco/points.py
+++ b/spine/post/reco/points.py
@@ -11,9 +11,15 @@
class TrackExtremaProcessor(PostBase):
"""Assigns track start point and end point."""
+
+ # Name of the post-processor (as specified in the configuration)
name = 'track_extrema'
- aliases = ['assign_track_extrema']
- keys = {'ppn_candidates': False}
+
+ # Alternative allowed names of the post-processor
+ aliases = ('assign_track_extrema',)
+
+ # Set of data keys needed for this post-processor to operate
+ _keys = (('ppn_candidates', False),)
def __init__(self, method='local', obj_type='particle', **kwargs):
"""Initialize the track end point assignment parameters.
@@ -60,7 +66,7 @@ def process(self, data):
elif self.method == 'ppn':
assert 'ppn_candidates' in data, (
- "Must run the `ppn_points` post-processor "
+ "Must run the `ppn` post-processor "
"before using PPN predictions to assign extrema.")
flip = not check_track_orientation_ppn(
part.start_point, part.end_point,
diff --git a/spine/post/reco/ppn.py b/spine/post/reco/ppn.py
index dba40041..bced0250 100644
--- a/spine/post/reco/ppn.py
+++ b/spine/post/reco/ppn.py
@@ -23,10 +23,18 @@ class PPNProcessor(PostBase):
If `restrict_shape` is `True`, points will be matched to particles with
the same predicted semantic type only.
"""
+
+ # Name of the post-processor (as specified in the configuration)
name = 'ppn'
- aliases = ['get_ppn_candidates']
- keys = {'segmentation': True, 'ppn_points': True, 'ppn_coords': True,
- 'ppn_masks': True, 'ppn_classify_endpoints': False}
+
+ # Alternative allowed names of the post-processor
+ aliases = ('get_ppn_candidates',)
+
+ # Set of data keys needed for this post-processor to operate
+ _keys = (
+ ('segmentation', True), ('ppn_points', True), ('ppn_coords', True),
+ ('ppn_masks', True), ('ppn_classify_endpoints', False)
+ )
def __init__(self, assign_to_particles=False, restrict_shape=False,
match_threshold=2., **ppn_pred_cfg):
diff --git a/spine/post/reco/shower.py b/spine/post/reco/shower.py
index a56cfcb7..7cf8b045 100644
--- a/spine/post/reco/shower.py
+++ b/spine/post/reco/shower.py
@@ -1,13 +1,15 @@
import numpy as np
-from spine.utils.globals import (PHOT_PID, PROT_PID, PION_PID, ELEC_PID)
+from spine.utils.globals import (PHOT_PID, PROT_PID, PION_PID, ELEC_PID,
+ SHOWR_SHP, TRACK_SHP)
from spine.post.base import PostBase
from scipy.spatial.distance import cdist
from sklearn.cluster import DBSCAN
from sklearn.metrics.pairwise import cosine_similarity
-__all__ = ['ConversionDistanceProcessor', 'ShowerMultiArmCheck']
+__all__ = ['ConversionDistanceProcessor', 'ShowerMultiArmCheck',
+ 'ShowerStartpointCorrectionProcessor']
class ConversionDistanceProcessor(PostBase):
@@ -17,8 +19,12 @@ class ConversionDistanceProcessor(PostBase):
NOTE: This processor can only change reco electron shower pid to
photon pid depending on the distance threshold.
"""
+
+ # Name of the post-processor (as specified in the configuration)
name = 'shower_conversion_distance'
- aliases = ['shower_separation_processor']
+
+ # Alternative allowed names of the post-processor
+ aliases = ('shower_separation_processor',)
def __init__(self, threshold=-1.0, vertex_mode='vertex'):
"""Specify the EM shower conversion distance threshold and
@@ -69,6 +75,7 @@ def process(self, data):
criterion = self.convdist_vertex_startpoint(ia, p)
else:
raise ValueError('Invalid point mode')
+ p.vertex_distance = criterion
if criterion >= self.threshold:
p.pid = PHOT_PID
@@ -165,15 +172,19 @@ class ShowerMultiArmCheck(PostBase):
NOTE: This processor can only change reco electron shower pid to
photon pid depending on the angle threshold.
"""
+
+ # Name of the post-processor (as specified in the configuration)
name = 'shower_multi_arm_check'
- aliases = ['shower_multi_arm']
+
+ # Alternative allowed names of the post-processor
+ aliases = ('shower_multi_arm',)
- def __init__(self, threshold=0.25, min_samples=20, eps=0.02):
+ def __init__(self, threshold=70, min_samples=20, eps=0.02):
"""Specify the threshold for the number of arms of showers.
Parameters
----------
- threshold : float, default 0.25
+ threshold : float, default 70 (deg)
If the electron shower's leading and subleading angle are
separated by more than this, the shower is considered to be
invalid and its PID will be changed to PHOT_PID.
@@ -206,6 +217,7 @@ def process(self, data):
angle = self.compute_angular_criterion(p, ia.vertex,
eps=self.eps,
min_samples=self.min_samples)
+ p.shower_split_angle = angle
if angle > self.threshold:
p.pid = PHOT_PID
@@ -230,7 +242,7 @@ def compute_angular_criterion(p, vertex, eps, min_samples):
-------
max_angle : float
Maximum angle between the mean cluster direction vectors
- of the shower points.
+ of the shower points (degrees)
"""
points = p.points
depositions = p.depositions
@@ -272,6 +284,79 @@ def compute_angular_criterion(p, vertex, eps, min_samples):
vecs = np.vstack(vecs)
cos_dist = cosine_similarity(vecs)
# max_angle ranges from 0 (parallel) to 2 (antiparallel)
- max_angle = (np.abs(1.0 - cos_dist)).max()
+ max_angle = np.clip((1.0 - cos_dist).max(), a_min=0, a_max=2)
+ max_angle_deg = np.rad2deg(np.arccos(1 - max_angle))
# counts = counts[1:]
- return max_angle
\ No newline at end of file
+ return max_angle_deg
+
+
+class ShowerStartpointCorrectionProcessor(PostBase):
+ """Correct the startpoint of the primary EM shower by
+ finding the closest point to the vertex.
+ """
+
+ # Name of the post-processor (as specified in the configuration)
+ name = 'showerstart_correction_processor'
+
+ # Alternative allowed names of the post-processor
+ aliases = ('reco_shower_startpoint_correction',)
+
+ def __init__(self, threshold=1.0):
+ """Specify the EM shower conversion distance threshold and
+ the type of vertex to use for the distance calculation.
+
+ Parameters
+ ----------
+ threshold : float, default -1.0
+ If EM shower has a conversion distance greater than this,
+ its PID will be changed to PHOT_PID.
+ """
+ super().__init__('interaction', 'reco')
+ self.threshold = threshold
+
+ def process(self, data):
+ """Update the shower startpoint using the closest point to the vertex.
+
+ Parameters
+ ----------
+ data : dict
+ Dictionaries of data products
+ """
+ # Loop over the reco interactions
+ for ia in data['reco_interactions']:
+ vertex = ia.vertex
+ for p in ia.particles:
+ if (p.shape == SHOWR_SHP) and (p.is_primary):
+ new_point = self.correct_shower_startpoint(p, ia)
+ p.start_point = new_point
+
+
+ @staticmethod
+ def correct_shower_startpoint(shower_p, ia):
+ """Function to correct the shower startpoint by finding the closest
+ point to the vertex.
+
+ Parameters
+ ----------
+ shower_p : RecoParticle
+ Primary EM shower to correct the startpoint.
+ ia : RecoInteraction
+ Reco interaction whose primary track points are used as the reference.
+
+ Returns
+ -------
+ guess : np.ndarray
+ (3, ) array of the corrected shower startpoint.
+ """
+ track_points = [p.points for p in ia.particles if p.shape == TRACK_SHP and p.is_primary]
+ if not track_points:
+ return shower_p.start_point
+
+ track_points = np.vstack(track_points)
+ dist = cdist(shower_p.points.reshape(-1, 3), track_points.reshape(-1, 3))
+ min_dist = dist.min()
+ closest_idx, _ = np.where(dist == min_dist)
+ if len(closest_idx) == 0:
+ return shower_p.start_point
+ guess = shower_p.points[closest_idx[0]]
+ return guess
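
The start point correction above selects, among all shower voxels, the one closest to any voxel of a primary track in the same interaction. A standalone sketch of that nearest-pair lookup with `scipy.spatial.distance.cdist`; the array shapes are arbitrary and stand in for `shower_p.points` and the stacked primary track points:

import numpy as np
from scipy.spatial.distance import cdist

rng = np.random.default_rng(0)
shower_points = rng.random((50, 3))    # stand-in for shower_p.points
track_points = rng.random((200, 3))    # stand-in for stacked track points

# Pairwise distances between every shower point and every track point
dist = cdist(shower_points, track_points)

# The row of the globally closest pair gives the corrected start point
row, _ = np.unravel_index(np.argmin(dist), dist.shape)
corrected_start = shower_points[row]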
diff --git a/spine/post/reco/tracking.py b/spine/post/reco/tracking.py
index 55b5c426..db97d44c 100644
--- a/spine/post/reco/tracking.py
+++ b/spine/post/reco/tracking.py
@@ -19,12 +19,16 @@ class CSDAEnergyProcessor(PostBase):
"""Reconstruct the kinetic energy of tracks based on their range in liquid
argon using the continuous slowing down approximation (CSDA).
"""
+
+ # Name of the post-processor (as specified in the configuration)
name = 'csda_ke'
- aliases = ['reconstruct_csda_energy']
+
+ # Alternative allowed names of the post-processor
+ aliases = ('reconstruct_csda_energy',)
def __init__(self, tracking_mode='step_next',
- include_pids=[MUON_PID, PION_PID, PROT_PID, KAON_PID],
- obj_type='particle', run_mode='both',
+ include_pids=(MUON_PID, PION_PID, PROT_PID, KAON_PID),
+ fill_per_pid=False, obj_type='particle', run_mode='both',
truth_point_mode='points', **kwargs):
"""Store the necessary attributes to do CSDA range-based estimation.
@@ -35,6 +39,8 @@ def __init__(self, tracking_mode='step_next',
'step', 'step_next', 'bin_pca' or 'spline')
include_pids : list, default [2, 3, 4, 5]
Particle species to compute the kinetic energy for
+ fill_per_pid : bool, default False
+ If `True`, compute the CSDA KE estimate under all PID assumptions
**kwargs : dict, optional
Additional arguments to pass to the tracking algorithm
"""
@@ -43,6 +49,7 @@ def __init__(self, tracking_mode='step_next',
# Fetch the functions that map the range to a KE
self.include_pids = include_pids
+ self.fill_per_pid = fill_per_pid
self.splines = {
ptype: csda_table_spline(ptype) for ptype in include_pids}
@@ -88,16 +95,26 @@ def process(self, data):
# Compute the CSDA kinetic energy
if length > 0.:
obj.csda_ke = self.splines[obj.pid](length).item()
+ if self.fill_per_pid:
+ for pid in self.include_pids:
+ obj.csda_ke_per_pid[pid] = self.splines[pid](length).item()
else:
obj.csda_ke = 0.
+ if self.fill_per_pid:
+ for pid in self.include_pids:
+ obj.csda_ke_per_pid[pid] = 0.
class TrackValidityProcessor(PostBase):
"""Check if track is valid based on the proximity to reconstructed vertex.
Can also reject small tracks that are close to the vertex (optional).
"""
+
+ # Name of the post-processor (as specified in the configuration)
name = 'track_validity'
- aliases = ['track_validity_processor']
+
+ # Alternative allowed names of the post-processor
+ aliases = ('track_validity_processor',)
def __init__(self, threshold=3., ke_threshold=50.,
check_small_track=False, **kwargs):
@@ -135,7 +152,8 @@ def process(self, data):
if p.shape == TRACK_SHP and p.is_primary:
# Check vertex attachment
dist = np.linalg.norm(p.points - ia.vertex, axis=1)
- if dist.min() > self.threshold:
+ p.vertex_distance = dist.min()
+ if p.vertex_distance > self.threshold:
p.is_primary = False
# Check if track is small and within radius from vertex
if self.check_small_track:
@@ -148,8 +166,12 @@ def process(self, data):
class TrackShowerMergerProcessor(PostBase):
"""Merge tracks into showers based on a set of selection criteria.
"""
+
+ # Name of the post-processor (as specified in the configuration)
name = 'merge_track_to_shower'
- aliases = ['track_shower_merger']
+
+ # Alternative allowed names of the post-processor
+ aliases = ('track_shower_merger',)
def __init__(self, angle_threshold=10, adjacency_threshold=0.5,
dedx_threshold=-1, track_length_limit=50, **kwargs):
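
For reference, the CSDA estimator above maps a reconstructed track length to a kinetic energy through a per-species spline of the CSDA range tables; with `fill_per_pid` enabled, the same length is simply evaluated under every spline. A hedged sketch with a generic interpolator standing in for the project's `csda_table_spline` (the table values below are made up for illustration):

import numpy as np
from scipy.interpolate import CubicSpline

# Illustrative (range [cm], KE [MeV]) pairs per species; real values come
# from the CSDA tables behind csda_table_spline
tables = {
    'muon':   ([10., 50., 100., 200.], [25., 120., 230., 450.]),
    'proton': ([10., 50., 100., 200.], [60., 210., 380., 700.]),
}
splines = {pid: CubicSpline(r, ke) for pid, (r, ke) in tables.items()}

length = 75.0  # reconstructed track length
ke_per_pid = {pid: float(spline(length)) for pid, spline in splines.items()}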
diff --git a/spine/post/reco/vertex.py b/spine/post/reco/vertex.py
index 322ccd9d..cfb4a89e 100644
--- a/spine/post/reco/vertex.py
+++ b/spine/post/reco/vertex.py
@@ -12,10 +12,14 @@
class VertexProcessor(PostBase):
"""Reconstruct one vertex for each interaction in the provided list."""
+
+ # Name of the post-processor (as specified in the configuration)
name = 'vertex'
- aliases = ['reconstruct_vertex']
- def __init__(self, include_shapes=[SHOWR_SHP, TRACK_SHP],
+ # Alternative allowed names of the post-processor
+ aliases = ('reconstruct_vertex',)
+
+ def __init__(self, include_shapes=(SHOWR_SHP, TRACK_SHP),
use_primaries=True, update_primaries=False,
anchor_vertex=True, touching_threshold=2.0,
angle_threshold=0.3, run_mode='both',
diff --git a/spine/post/trigger/trigger.py b/spine/post/trigger/trigger.py
index c4135ec6..60474fb3 100644
--- a/spine/post/trigger/trigger.py
+++ b/spine/post/trigger/trigger.py
@@ -13,9 +13,15 @@ class TriggerProcessor(PostBase):
"""Parses trigger information from a CSV file and adds a new trigger_info
data product to the data dictionary.
"""
+
+ # Name of the post-processor (as specified in the configuration)
name = 'trigger'
- aliases = ['parse_trigger']
- keys = {'run_info': True}
+
+ # Alternative allowed names of the post-processor
+ aliases = ('parse_trigger',)
+
+ # Set of data keys needed for this post-processor to operate
+ _keys = (('run_info', True),)
def __init__(self, file_path, correct_flash_times=True, flash_keys=None,
flash_time_corr_us=4):
@@ -53,7 +59,7 @@ def __init__(self, file_path, correct_flash_times=True, flash_keys=None,
self.flash_time_corr_us = flash_time_corr_us
# Add flash keys to the required data products
- self.keys.update({k: True for k in self.flash_keys})
+ self.update_keys({k: True for k in self.flash_keys})
def process(self, data):
"""Parse the trigger information of one entry.
diff --git a/spine/utils/calib/lifetime.py b/spine/utils/calib/lifetime.py
index 59d3f5c4..a5b67149 100644
--- a/spine/utils/calib/lifetime.py
+++ b/spine/utils/calib/lifetime.py
@@ -13,12 +13,8 @@ class LifetimeCalibrator:
"""
name = 'lifetime'
- def __init__(self,
- num_tpcs,
- lifetime = None,
- driftv = None,
- lifetime_db = None,
- driftv_db = None):
+ def __init__(self, num_tpcs, lifetime=None, driftv=None, lifetime_db=None,
+ driftv_db=None):
"""Load the information needed to make a lifetime correction.
Parameters
@@ -108,12 +104,14 @@ def process(self, points, values, geo, tpc_id, run_id=None):
driftv = self.driftv[run_id]
# Compute the distance to the anode plane
- m, t = tpc_id // geo.num_tpcs_per_module, tpc_id % geo.num_tpcs_per_module
- daxis, position = geo.anodes[m, t]
+ m = tpc_id // geo.tpc.num_chambers_per_module
+ t = tpc_id % geo.tpc.num_chambers_per_module
+ daxis = geo.tpc[m, t].drift_axis
+ position = geo.tpc[m, t].anode_pos
drifts = np.abs(points[:, daxis] - position)
# Clip down to the physical range of possible drift distances
- max_drift = geo.ranges[m, t][daxis]
+ max_drift = geo.tpc[m, t].dimensions[daxis]
drifts = np.clip(drifts, 0., max_drift)
# Convert the drift distances to correction factors
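
Downstream of the drift-distance computation above, the lifetime module scales each charge deposition by the exponential attenuation accumulated over its drift time. A hedged sketch of that conversion, assuming the standard exp(t_drift/tau) correction (the drift velocity and lifetime values are illustrative):

import numpy as np

def lifetime_correction(drifts_cm, drift_velocity_cm_per_us, lifetime_us):
    # Drift time of each deposition, then the attenuation correction factor
    drift_times = drifts_cm / drift_velocity_cm_per_us
    return np.exp(drift_times / lifetime_us)

drifts = np.array([10., 50., 140.])   # distance to the anode (cm)
charges = np.array([120., 95., 80.])  # uncorrected charge
corrected = charges * lifetime_correction(drifts, 0.16, 3000.)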
diff --git a/spine/utils/calib/manager.py b/spine/utils/calib/manager.py
index 85c49eee..abf69659 100644
--- a/spine/utils/calib/manager.py
+++ b/spine/utils/calib/manager.py
@@ -40,9 +40,9 @@ def __init__(self, geometry, **cfg):
# Add necessary geometry information
if key != 'recombination':
- value['num_tpcs'] = self.geo.num_tpcs
+ value['num_tpcs'] = self.geo.tpc.num_chambers
else:
- value['drift_dir'] = self.geo.drift_dirs[0, 0]
+ value['drift_dir'] = self.geo.tpc[0, 0].drift_dir
# Append
self.modules[key] = calibrator_factory(key, value)
@@ -78,12 +78,13 @@ def __call__(self, points, values, sources=None, run_id=None,
# Create a mask for each of the TPC volume in the detector
if sources is not None:
tpc_indexes = []
- for t in range(self.geo.num_tpcs):
- # Get the set of points associated with this TPC
- module_id = t//self.geo.num_tpcs_per_module
- tpc_id = t%self.geo.num_tpcs_per_module
- tpc_index = self.geo.get_tpc_index(sources, module_id, tpc_id)
- tpc_indexes.append(tpc_index)
+ for module_id in range(self.geo.tpc.num_modules):
+ for tpc_id in range(self.geo.tpc.num_chambers_per_module):
+ # Get the set of points associated with this TPC
+ tpc_index = self.geo.get_volume_index(
+ sources, module_id, tpc_id)
+ tpc_indexes.append(tpc_index)
+
else:
assert points is not None, (
"If sources are not given, must provide points instead.")
@@ -91,7 +92,7 @@ def __call__(self, points, values, sources=None, run_id=None,
# Loop over the TPCs, apply the relevant calibration corrections
new_values = np.copy(values)
- for t in range(self.geo.num_tpcs):
+ for t in range(self.geo.tpc.num_chambers):
# Restrict to the TPC of interest
if len(tpc_indexes[t]) == 0:
continue
diff --git a/spine/utils/calib/recombination.py b/spine/utils/calib/recombination.py
index 0098ec71..045e4c05 100644
--- a/spine/utils/calib/recombination.py
+++ b/spine/utils/calib/recombination.py
@@ -56,6 +56,7 @@ def __init__(self, efield, drift_dir, model='mbox', birks_a=0.800,
self.model = 'birks'
self.a = birks_a
self.k = birks_k/efield/LAR_DENSITY # cm/MeV
+
elif model in ['mbox', 'mbox_ell']:
self.model = 'mbox'
self.alpha = mbox_alpha
@@ -64,6 +65,7 @@ def __init__(self, efield, drift_dir, model='mbox', birks_a=0.800,
if model == 'mbox_ell':
self.use_angles = True
self.r = mbox_ell_r
+
else:
raise ValueError(
f"Recombination model not recognized: {model}. "
diff --git a/spine/utils/geo/__init__.py b/spine/utils/geo/__init__.py
index 36fb9df2..421fad46 100644
--- a/spine/utils/geo/__init__.py
+++ b/spine/utils/geo/__init__.py
@@ -1,3 +1,3 @@
-"""Geometry module."""
+"""Detector geometry module."""
-from .base import Geometry
+from .manager import Geometry
diff --git a/spine/utils/geo/base.py b/spine/utils/geo/base.py
deleted file mode 100644
index 4d3e7bc9..00000000
--- a/spine/utils/geo/base.py
+++ /dev/null
@@ -1,746 +0,0 @@
-"""Module with a general-purpose geometry class."""
-
-import os
-import pathlib
-from dataclasses import dataclass
-
-import numpy as np
-
-
-@dataclass
-class Geometry:
- """Handles all geometry functions for a collection of box-shaped TPCs.
-
- Attributes
- ----------
- boundaries : np.ndarray
- (N_m, N_t, D, 2) Array of TPC boundaries
- - N_m is the number of modules (or cryostat) in the detector
- - N_t is the number of TPCs per module (or cryostat)
- - D is the number of dimension (always 3)
- - 2 corresponds to the lower/upper boundaries along that axis
- sources : np.ndarray
- (N, m, N_t, N_s, 2) Array of contributing logical TPCs to each TPC
- - N_s is the number of contributing logical TPCs to a geometry TPC
- - 2 corresponds to the [module ID, tpc ID] of a contributing pair
- If this is not specified, the assumption is that there is an exact
- match between logical and physical TPCs (as specified by boundaries)
- opdets : np.ndarray
- (N_m[, N_t], N_p, 3) Array of optical detector locations
- - N_p is the number of optical detectors per module or TPC
- - 3 corresponds to the [x, y, z] optical detector coordinates
- ranges : np.ndarray
- (N_m, N_t, D) Array of TPC ranges
- tpcs : np.ndarray
- (N_m*N_t, D, 2) Array of individual TPC boundaries
- num_tpcs : int
- Number of TPC volumes in the detector, N_m*N_t
- modules : np.ndarray
- (N_m, D, 2) Array of detector module boundaries
- num_modules : int
- Number of modules in the detector, N_m
- detector : np.ndarray
- (D, 2) Boundaries of the detector as a whole
- centers : np.ndarray
- (N_m, 3) Centers of the detector modules
- anodes : np.ndarray
- (N_m, N_t, 2) List of (axis, position) pairs of each anode
- cathodes : np.ndarray
- (N_m, 2) List of (axis, position) pairs of each cathode
- anode_wall_ids : np.ndarray
- (N_m, N_t, 2) Maps each (module, tpc) pair onto a specific anode
- cathode_wall_ids : np.ndarray
- (N_m, N_t, 2) Maps each (module, tpc) pair onto a specific cathode
- drift_dirs : np.ndarray
- (N_m, N_t, D) Drift direction in each TPC
- """
- boundaries: np.ndarray
- modules: np.ndarray
- detector: np.ndarray
- sources: np.ndarray
- opdets: np.ndarray
- centers: np.ndarray
- anodes: np.ndarray
- cathodes: np.ndarray
- anode_wall_ids: np.ndarray
- cathode_wall_ids: np.ndarray
- drift_dirs: np.ndarray
-
- def __init__(self, detector=None, boundaries=None,
- sources=None, opdets=None):
- """Initializes a detector geometry object.
-
- The boundary file is a (N_m, N_t, D, 2) np.ndarray where:
- - N_m is the number of modules (or cryostat) in the detector
- - N_t is the number of TPCs per module (or cryostat)
- - D is the number of dimension (always 3)
- - 2 corresponds to the lower/upper boundaries along that axis
-
- The sources file is a (N_m, N_t, N_s, 2) np.ndarray where:
- - N_s is the number of contributing logical TPCs to a geometry TPC
- - 2 corresponds to the [module ID, tpc ID] of a contributing pair
-
- The opdets file is a (N_m[, N_t], N_p, 3) np.ndarray where:
- - N_p is the number of optical detectors per module or TPC
- - 3 corresponds to the [x, y, z] optical detector coordinates
-
- Parameters
- ----------
- detector : str, optional
- Name of a recognized detector to the geometry from
- boundaries : str, optional
- Path to a `.npy` boundary file to load the boundaries from
- sources : str, optional
- Path to a `.npy` source file to load the sources from
- opdets : str, optional
- Path to a `.npy` opdet file to load the opdet coordinates from
- """
- # If the boundary file is not provided, fetch a default boundary file
- assert detector is not None or boundaries is not None, (
- "Must minimally provide a detector boundary file source")
- if boundaries is None:
- path = pathlib.Path(__file__).parent
- boundaries = os.path.join(path, 'source',
- f'{detector.lower()}_boundaries.npy')
-
- # If the source file is not a file, fetch the default source file
- if sources is None and detector is not None:
- path = pathlib.Path(__file__).parent
- file_path = os.path.join(path, 'source',
- f'{detector.lower()}_sources.npy')
- if os.path.isfile(file_path):
- sources = file_path
-
- # If the opdets file is not a file, fetch the default opdets file
- if opdets is None and detector is not None:
- path = pathlib.Path(__file__).parent
- file_path = os.path.join(path, 'source',
- f'{detector.lower()}_opdets.npy')
- if os.path.isfile(file_path):
- opdets = file_path
-
- # Check that the boundary file exists, load it
- if not os.path.isfile(boundaries):
- raise FileNotFoundError("Could not find boundary "
- f"file: {boundaries}")
- self.boundaries = np.load(boundaries)
-
- # Check that the sources file exists, load it
- self.sources = None
- if sources is not None:
- if not os.path.isfile(sources):
- raise FileNotFoundError("Could not find sources "
- f"file: {sources}")
- self.sources = np.load(sources)
- assert self.sources.shape[:2] == self.boundaries.shape[:2], (
- "There should be one list of sources per TPC")
- else:
- # Match the source of each TPC in order of (module ID, tpc ID)
- shape = (*self.boundaries.shape[:2], 1, 2)
- num_tpcs = shape[0]*shape[1]
- module_ids = np.arange(num_tpcs)//self.num_tpcs_per_module
- tpc_ids = np.arange(num_tpcs)%self.num_tpcs_per_module
- self.sources = np.vstack((module_ids, tpc_ids)).T.reshape(shape)
-
- # Check that the optical detector file exists, load it
- self.opdets = None
- if opdets is not None:
- if not os.path.isfile(opdets):
- raise FileNotFoundError("Could not find opdets "
- f"file: {opdets}")
- self.opdets = np.load(opdets)
- assert (self.opdets.shape[:2] == self.boundaries.shape[:2] or
- (self.opdets.shape[0] == self.boundaries.shape[0] and
- len(self.opdets.shape) == 3)), (
- "There should be one list of opdets per module or TPC")
-
- # Build a list of modules
- self.build_modules()
-
- # Build an all-encompassing detector object
- self.build_detector()
-
- # Build cathodes/anodes if the modules share a central cathode
- if self.boundaries.shape[1] == 2:
- self.build_planes()
-
- # Containment volumes to be defined by the user
- self._cont_volumes = None
- self._cont_use_source = False
-
- @property
- def tpcs(self):
- """Single list of all TPCs.
-
- Returns
- -------
- np.ndarray
- (N_m*N_t, D, 2) Array of TPC boundaries
- """
- return self.boundaries.reshape(-1, 3, 2)
-
- @property
- def ranges(self):
- """Range of each TPC.
-
- Returns
- -------
- np.ndarray
- (N_m, N_t, D) Array of TPC ranges
- """
- return np.abs(self.boundaries[..., 1] - self.boundaries[...,0])
-
- @property
- def num_tpcs(self):
- """Number of TPC volumes.
-
- Returns
- -------
- int
- Number of TPC volumes, N_m*N_t
- """
- return len(self.tpcs)
-
- @property
- def num_tpcs_per_module(self):
- """Number of TPC volumes per module.
-
- Returns
- -------
- int
- Number of TPC volumes per module, N_t
- """
- return self.boundaries.shape[1]
-
- @property
- def num_modules(self):
- """Number of detector modules.
-
- Returns
- -------
- int
- Number of detector modules, N_m
- """
- return len(self.modules)
-
- def build_modules(self):
- """Converts the list of boundaries of TPCs that make up the modules into
- a list of boundaries that encompass each module. Also store the center
- of each module and the total number of moudules.
- """
- self.modules = np.empty((len(self.boundaries), 3, 2))
- self.centers = np.empty((len(self.boundaries), 3))
- for m, module in enumerate(self.boundaries):
- self.modules[m] = self.merge_volumes(module)
- self.centers[m] = np.mean(self.modules[m], axis=1)
-
- def build_detector(self):
- """Converts the list of boundaries of TPCs that make up the detector
- into a single set of overall detector boundaries.
- """
- self.detector = self.merge_volumes(self.tpcs)
-
- def build_planes(self):
- """Converts the list of boundaries of TPCs that make up the modules and
- tpcs into a list of cathode plane positions for each module and anode
- plane positions for each TPC. The cathode/anode positions are expressed
- as a simple number pair [axis, position] with axis the drift axis and
- position the cathode position along that axis.
-
- Also stores a [axis, side] pair for each TPC which tells which of the
- walls of the TPCs is the cathode wall
- """
- tpc_shape = self.boundaries.shape[:2]
- self.anodes = np.empty(tpc_shape, dtype = object)
- self.cathodes = np.empty(tpc_shape[0], dtype = object)
- self.drift_dirs = np.empty((*tpc_shape, 3))
- self.cathode_wall_ids = np.empty((*tpc_shape, 2), dtype = np.int32)
- self.anode_wall_ids = np.empty((*tpc_shape, 2), dtype = np.int32)
- for m, module in enumerate(self.boundaries):
- # Check that the module is central-cathode style
- assert len(module) == 2, (
- "A module with < 2 TPCs has no central cathode.")
-
- # Identify the drift axis
- centers = np.mean(module, axis=-1)
- drift_dir = centers[1] - centers[0]
- drift_dir /= np.linalg.norm(drift_dir)
- axis = np.where(drift_dir)[0]
- assert len(axis) == 1, (
- "The drift direction is not aligned with an axis, abort.")
- axis = axis[0]
-
- # Store the cathode position
- midpoint = np.sum(centers, axis=0)/2
- self.cathodes[m] = [axis, midpoint[axis]]
-
- # Store the wall ID of each TPC that makes up the module
- for t, tpc in enumerate(module):
- # Store which side the anode/cathode are on
- side = int(centers[t][axis] - midpoint[axis] < 0.)
- self.cathode_wall_ids[m, t] = [axis, side]
- self.anode_wall_ids[m, t] = [axis, 1-side]
-
- # Store the position of the anode for each TPC
- anode_pos = self.boundaries[m, t, axis, 1-side]
- self.anodes[m, t] = [axis, anode_pos]
-
- # Store the drift direction for each TPC
- self.drift_dirs[m, t] = (-1)**side * drift_dir
-
- def get_contributors(self, sources):
- """Gets the list of [module ID, tpc ID] pairs that contributed to a
- particle or interaction object, as defined in this geometry.
-
- Parameters
- ----------
- sources : np.ndarray
- (N, 2) Array of [module ID, tpc ID] pairs, one per voxel
-
- Returns
- -------
- List[np.ndarray]
- (2, N_t) Pair of arrays: the first contains the list of
- contributing modules, the second of contributing tpcs.
- """
- sources = np.unique(sources, axis=0)
- contributor_mask = np.zeros(self.boundaries.shape[:2], dtype=bool)
- for m, module_source in enumerate(self.sources):
- for t, tpc_source in enumerate(module_source):
- for source in sources:
- if (tpc_source == source).all(axis=-1).any(axis=-1):
- contributor_mask[m, t] = True
- break
-
- return np.where(contributor_mask)
-
- def get_tpc_index(self, sources, module_id, tpc_id):
- """Gets the list of indices of points that belong to a specify
- [module ID, tpc ID] pair.
-
- Parameters
- ----------
- sources : np.ndarray
- (S, 2) : List of [module ID, tpc ID] pairs that created
- the point cloud (as defined upstream)
- module_id : int
- ID of the module
- tpc_id : int
- ID of the TPC within the module
-
- Returns
- -------
- np.ndarray
- (N) Index of points that belong to that TPC
- """
- mask = np.zeros(len(sources), dtype=bool)
- for source in self.sources[module_id, tpc_id]:
- mask |= (sources == source).all(axis=-1)
-
- return np.where(mask)[0]
-
- def get_closest_tpc_indexes(self, points):
- """For each TPC, get the list of points that live closer to it
- than any other TPC in the detector.
-
- Parameters
- ----------
- points : np.ndarray
- (N, 3) Set of point coordinates
-
- Returns
- -------
- List[np.ndarray]
- List of index of points that belong to each TPC
- """
- # Compute the distance from the points to each TPC
- distances = np.empty((self.num_tpcs, len(points)))
- for t in range(self.num_tpcs):
- module_id = t//self.num_tpcs_per_module
- tpc_id = t%self.num_tpcs_per_module
- offsets = self.get_tpc_offsets(points, module_id, tpc_id)
- distances[t] = np.linalg.norm(offsets, axis=1)
-
- # For each TPC, append the list of point indices associated with it
- tpc_indexes = []
- argmins = np.argmin(distances, axis=0)
- for t in range(self.num_tpcs):
- tpc_indexes.append(np.where(argmins == t)[0])
-
- return tpc_indexes
-
- def get_closest_module(self, points):
- """For each point, find the ID of the closest module.
-
- Parameters
- ----------
- points : np.ndarray
- (N, 3) Set of point coordinates
-
- Returns
- -------
- np.ndarray
- (N) List of module indexes, one per input point
- """
- module_ids = np.empty(len(points), dtype = np.int32)
- for module_id, c in enumerate(self.centers):
- # Find out the boundaries of the volume closest to this module
- dists = self.centers - c
- lower_pad = np.zeros(dists.shape)
- upper_pad = np.zeros(dists.shape)
- lower_pad[dists >= 0], upper_pad[dists <= 0] = np.inf, np.inf
- lower = c + np.max(dists - lower_pad, axis=0) / 2
- upper = c + np.min(dists + upper_pad, axis=0) / 2
-
- # Assign all points within those boundaries to this module
- mask = np.all(points > lower, axis = 1) \
- & np.all(points < upper, axis = 1)
- module_ids[mask] = module_id
-
- return module_ids
-
- def get_closest_module_indexes(self, points):
- """For each module, get the list of points that live closer to it
- than any other module in the detector.
-
- Parameters
- ----------
- points : np.ndarray
- (N, 3) Set of point coordinates
-
- Returns
- -------
- List[np.ndarray]
- List of index of points that belong to each module
- """
- # For each module, append the list of point indices associated with it
- module_ids = self.get_closest_module(points)
- module_indexes = []
- for m in range(self.num_modules):
- module_indexes.append(np.where(module_ids == m)[0])
-
- return module_indexes
-
- def get_tpc_offsets(self, points, module_id, tpc_id):
- """Compute how far each point is from a TPC volume.
-
- Parameters
- ----------
- points : np.ndarray
- (N, 3) : Point coordinates
- module_id : int
- ID of the module
- tpc_id : int
- ID of the TPC within the module
-
- Returns
- -------
- np.ndarray
- (N, 3) Offsets w.r.t. to the TPC location
- """
- # Compute the axis-wise distances of each point to each boundary
- tpc = self.boundaries[module_id, tpc_id]
- dists = points[..., None] - tpc
-
- # If a point is between two boundaries, the distance is 0. If it is
- # outside, the distance is that of the closest boundary
- signs = (np.sign(dists[..., 0]) + np.sign(dists[..., 1]))/2
- offsets = signs * np.min(np.abs(dists), axis=-1)
-
- return offsets
-
- def get_min_tpc_offset(self, points, module_id, tpc_id):
- """Get the minimum offset to apply to a point cloud to bring it
- within the boundaries of a TPC.
-
- Parameters
- ----------
- points : np.ndarray
- (N, 3) : Point coordinates
- module_id : int
- ID of the module
- tpc_id : int
- ID of the TPC within the module
-
- Returns
- -------
- np.ndarray
- (3) Offsets w.r.t. to the TPC location
- """
- # Compute the distance for each point, get the maximum necessary offset
- offsets = self.get_tpc_offsets(points, module_id, tpc_id)
- offsets = offsets[np.argmax(np.abs(offsets), axis=0), np.arange(3)]
-
- return offsets
-
- def translate(self, points, source_id, target_id, factor=None):
- """Moves a point cloud from one module to another one
-
- Parameters
- ----------
- points : np.ndarray
- (N, 3) Set of point coordinates
- source_id: int
- Module ID from which to move the point cloud
- target_id : int
- Module ID to which to move the point cloud
- factor : Union[float, np.ndarray], optional
- Multiplicative factor to apply to the offset. This is necessary if
- the points are not expressed in detector coordinates
-
- Returns
- -------
- np.ndarray
- (N, 3) Set of translated point coordinates
- """
- # If the source and target are the same, nothing to do here
- if target_id == source_id:
- return np.copy(points)
-
- # Fetch the inter-module shift
- offset = self.centers[target_id] - self.centers[source_id]
- if factor is not None:
- offset *= factor
-
- # Translate
- return points + offset
-
- def split(self, points, target_id, sources=None, meta=None):
- """Migrate all points to a target module, organize them by module ID.
-
- Parameters
- ----------
- points : np.ndarray
- (N, 3) Set of point coordinates
- target_id : int
- Module ID to which to move the point cloud
- sources : np.ndarray, optional
- (N, 2) Array of [module ID, tpc ID] pairs, one per voxel
- meta : Meta, optional
- Meta information about the voxelized image. If provided, the
- points are assumed to be provided in voxel coordinates.
-
- Returns
- -------
- np.ndarray
- (N, 3) Shifted set of points
- List[np.ndarray]
- List of index of points that belong to each module
- """
-
- # Check that the target ID exists
- assert target_id > -1 and target_id < len(self.modules), (
- "Target ID should be in [0, N_modules[")
-
- # Get the module ID of each of the input points
- convert = False
- if sources is not None:
- # If provided, simply use that
- module_indexes = []
- for m in range(self.num_modules):
- module_indexes.append(np.where(sources[:, 0] == m)[0])
-
- else:
- # If the points are expressed in pixel coordinates, translate
- convert = meta is not None
- if convert:
- points = meta.to_cm(points, center=True)
-
- # If not provided, find module each point belongs to by proximity
- module_indexes = self.get_closest_module_indexes(points)
-
- # Now shifts all points that are not in the target
- for module_id, module_index in enumerate(module_indexes):
- # If this is the target module, nothing to do here
- if module_id == target_id:
- continue
-
- # Shift the coordinates
- points[module_index] = self.translate(
- points[module_index], module_id, target_id)
-
- # Bring the coordinates back to pixels, if they were shifted
- if convert:
- points = meta.to_px(points, floor=True)
-
- return points, module_indexes
-
- def check_containment(self, points, sources=None,
- allow_multi_module=False, summarize=True):
- """Check whether a point cloud comes within some distance of the
- boundaries of a certain subset of detector volumes, depending on the
- mode.
-
- Parameters
- ----------
- points : np.ndarray
- (N, 3) Set of point coordinates
- sources : np.ndarray, optional
- (S, 2) : List of [module ID, tpc ID] pairs that created the
- point cloud
- allow_multi_module : bool, default `False`
- Whether to allow particles/interactions to span multiple modules
- summarize : bool, default `True`
- If `True`, only returns a single flag for the whole cloud.
- Otherwise, returns a boolean array corresponding to each point.
-
- Returns
- -------
- Union[bool, np.ndarray]
- `True` if the particle is contained, `False` if not
- """
- # If the containment volumes are not defined, throw
- if self._cont_volumes is None:
- raise ValueError("Must call `define_containment_volumes` first.")
-
- # If sources are provided, only consider source volumes
- if self._cont_use_source:
- # Get the contributing TPCs
- assert len(points) == len(sources), (
- "Need to provide sources to make a source-based check.")
- contributors = self.get_contributors(sources)
- if not allow_multi_module and len(np.unique(contributors[0])) > 1:
- return False
-
- # Define the smallest box containing all contributing TPCs
- index = contributors[0] * self.boundaries.shape[1] + contributors[1]
- volume = self.merge_volumes(self._cont_volumes[index])
- volumes = [volume]
- else:
- volumes = self._cont_volumes
-
- # Loop over volumes, make sure the cloud is contained in at least one
- if summarize:
- contained = False
- for v in volumes:
- if (points > v[:, 0]).all() and (points < v[:, 1]).all():
- contained = True
- break
- else:
- contained = np.zeros(len(points), dtype=bool)
- for v in volumes:
- contained |= ((points > v[:, 0]).all(axis=1) &
- (points < v[:, 1]).all(axis=1))
-
- return contained
-
- def define_containment_volumes(self, margin, cathode_margin=None,
- mode ='module'):
- """This function defines a list of volumes to check containment against.
-
- If the containment is checked against a constant volume, it is more
- efficient to call this function once and call `check_containment`
- reapitedly after.
-
- Parameters
- ----------
- margin : Union[float, List[float], np.array]
- Minimum distance from a detector wall to be considered contained:
- - If float: distance buffer is shared between all 6 walls
- - If [x,y,z]: distance is shared between pairs of walls facing
- each other and perpendicular to a shared axis
- - If [[x_low,x_up], [y_low,y_up], [z_low,z_up]]: distance is
- specified individually of each wall.
- cathode_margin : float, optional
- If specified, sets a different margin for the cathode boundaries
- mode : str, default 'module'
- Containement criterion (one of 'global', 'module', 'tpc'):
- - If 'tpc', makes sure it is contained within a single TPC
- - If 'module', makes sure it is contained within a single module
- - If 'detector', makes sure it is contained within in the detector
- - If 'source', use the origin of voxels to determine which TPC(s)
- contributed to them, and define volumes accordingly
- """
- # Translate the margin parameter to a (3,2) matrix
- if np.isscalar(margin):
- margin = np.full((3,2), margin)
- elif len(np.array(margin).shape) == 1:
- assert len(margin) == 3, (
- "Must provide one value per axis.")
- margin = np.repeat([margin], 2, axis=0).T
- else:
- assert np.array(margin).shape == (3,2), (
- "Must provide two values per axis.")
- margin = np.copy(margin)
-
- # Establish the volumes to check against
- self._cont_volumes = []
- if mode in ['tpc', 'source']:
- for m, module in enumerate(self.boundaries):
- for t, tpc in enumerate(module):
- vol = self.adapt_volume(tpc, margin, \
- cathode_margin, m, t)
- self._cont_volumes.append(vol)
- self._cont_use_source = mode == 'source'
- elif mode == 'module':
- for m in self.modules:
- vol = self.adapt_volume(m, margin)
- self._cont_volumes.append(vol)
- self._cont_use_source = False
- elif mode == 'detector':
- vol = self.adapt_volume(self.detector, margin)
- self._cont_volumes.append(vol)
- self._cont_use_source = False
- else:
- raise ValueError(f"Containement check mode not recognized: {mode}.")
-
- self._cont_volumes = np.array(self._cont_volumes)
-
- def adapt_volume(self, ref_volume, margin, cathode_margin=None,
- module_id=None, tpc_id=None):
- """Apply margins from a given volume. Takes care of subtleties
- associated with the cathode, if requested.
-
- Parameters
- ----------
- ref_volume : np.ndarray
- (3, 2) Array of volume boundaries
- margin : np.ndarray
- Minimum distance from a detector wall to be considered contained as
- [[x_low,x_up], [y_low,y_up], [z_low,z_up]], i.e. distance is
- specified individually of each wall.
- cathode_margin : float, optional
- If specified, sets a different margin for the cathode boundaries
- module_id : int, optional
- ID of the module
- tpc_id : int, optional
- ID of the TPC within the module
-
- Returns
- -------
- np.ndarray
- (3, 2) Updated array of volume boundaries
- """
- # Reduce the volume according to the margin
- volume = np.copy(ref_volume)
- volume[:,0] += margin[:,0]
- volume[:,1] -= margin[:,1]
-
- # If a cathode margin is provided, adapt the cathode wall differently
- if cathode_margin is not None:
- axis, side = self.cathode_wall_ids[module_id, tpc_id]
- flip = (-1) ** side
- volume[axis, side] += flip * (cathode_margin - margin[axis, side])
-
- return volume
-
- @staticmethod
- def merge_volumes(volumes):
- """Given a list of volumes and their boundaries, find the smallest box
- that encompass all volumes combined.
-
- Parameters
- ----------
- volumes : np.ndarray
- (N, 3, 2) List of volume boundaries
-
- Returns
- -------
- np.ndarray
- (3, 2) Boundaries of the combined volume
- """
- volume = np.empty((3, 2))
- volume[:,0] = np.min(volumes, axis=0)[:,0]
- volume[:,1] = np.max(volumes, axis=0)[:,1]
-
- return volume
diff --git a/spine/utils/geo/detector/__init__.py b/spine/utils/geo/detector/__init__.py
new file mode 100644
index 00000000..42976bcb
--- /dev/null
+++ b/spine/utils/geo/detector/__init__.py
@@ -0,0 +1,11 @@
+"""Module which holds all detector component geometries.
+
+This includes:
+- :class:`TPCDetector` for a set of organized TPCs
+- :class:`OptDetector` for a set of organized light collection detectors
+- :class:`CRTDetector` for a set of organized CRT planes
+"""
+
+from .tpc import TPCDetector
+from .optical import OptDetector
+from .crt import CRTDetector
diff --git a/spine/utils/geo/detector/base.py b/spine/utils/geo/detector/base.py
new file mode 100644
index 00000000..02f6462f
--- /dev/null
+++ b/spine/utils/geo/detector/base.py
@@ -0,0 +1,82 @@
+"""Basic detector components shared across multiple subsystems.
+
+This currently handles:
+- :class:`Box` which corresponds to box-shaped detector modules.
+"""
+
+from dataclasses import dataclass
+
+import numpy as np
+
+__all__ = ['Box']
+
+
+@dataclass
+class Box:
+ """Class which holds all methods associated with a box-shapes component.
+
+ Attributes
+ ----------
+ boundaries : np.ndarray
+ (3, 2) Box boundaries
+ - 3 is the number of dimensions
+ - 2 corresponds to the lower/upper boundaries along each axis
+ """
+ boundaries: np.ndarray
+
+ def __init__(self, lower, upper):
+ """Initialize the box object.
+
+ Parameters
+ ----------
+ lower : np.ndarray
+ (3) Lower bounds of the box
+ upper : np.ndarray
+ (3) Upper bounds of the box
+ """
+ # Store lower and upper boundaries in one array
+ self.boundaries = np.vstack((lower, upper)).T
+
+ @property
+ def center(self):
+ """Center of the box.
+
+ Returns
+ -------
+ np.ndarray
+ Center of the box
+ """
+ return np.mean(self.boundaries, axis=1)
+
+ @property
+ def lower(self):
+ """Lower bounds of the box.
+
+ Returns
+ -------
+ np.ndarray
+ Lower bounds of the box
+ """
+ return self.boundaries[:, 0]
+
+ @property
+ def upper(self):
+ """Upper bounds of the box.
+
+ Returns
+ -------
+ np.ndarray
+ Upper bounds of the box
+ """
+ return self.boundaries[:, 1]
+
+ @property
+ def dimensions(self):
+ """Dimensions of the box.
+
+ Returns
+ -------
+ np.ndarray
+ Box dimensions
+ """
+ return self.boundaries[:, 1] - self.boundaries[:, 0]
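
A quick usage sketch of the new `Box` class, showing how the derived properties relate to the stored boundaries; it assumes the new module is importable at the path introduced by this diff and uses arbitrary coordinates:

import numpy as np
from spine.utils.geo.detector.base import Box

box = Box(np.array([-100., -50., 0.]), np.array([100., 50., 500.]))
print(box.center)      # [  0.   0. 250.]
print(box.dimensions)  # [200. 100. 500.]
print(box.lower, box.upper)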
diff --git a/spine/utils/geo/detector/crt.py b/spine/utils/geo/detector/crt.py
new file mode 100644
index 00000000..98828ff3
--- /dev/null
+++ b/spine/utils/geo/detector/crt.py
@@ -0,0 +1,77 @@
+"""CRT detector geometry classes."""
+
+from typing import List
+from dataclasses import dataclass
+
+import numpy as np
+
+__all__ = ['CRTDetector']
+
+
+@dataclass
+class CRTDetector:
+ """Handles all geometry queries for a set of cosmic-ray taggers.
+
+ Attributes
+ ----------
+ positions : np.ndarray
+ (N_c, 3) Location of the center of each of the CRT planes
+ - N_c is the number of CRT planes
+ dimensions : np.ndarray
+ (N_c, 3) Dimensions of each of the CRT planes
+ - N_c is the number of CRT planes
+ norms : np.ndarray
+ (N_c) Axis aligned with the norm of each of the CRT planes
+ - N_c is the number of CRT planes
+ det_ids : Dict[int, int], optional
+ Mapping between the CRT channel and its corresponding detector
+ """
+ positions : np.ndarray
+ dimensions : np.ndarray
+ norms : np.ndarray
+ det_ids : dict = None
+
+ def __init__(self, dimensions, positions, norms, logical_ids=None):
+ """Parse the CRT detector configuration.
+
+ The assumption here is that the CRT detectors collectively cover the
+ entire detector, regardless of TPC/optical modularization.
+
+ Parameters
+ ----------
+ dimensions : List[List[float]]
+ Dimensions of each of the CRT planes
+ positions : List[List[float]]
+ Positions of each of the CRT planes
+ norms : List[int]
+ Axis along which the norm of each CRT plane points
+ logical_ids : List[int], optional
+ Logical index corresponding to each CRT channel. If not specified,
+ it is assumed that there is a one-to-one correspondence between
+ physical and logical CRT planes.
+ """
+ # Check the sanity of the configuration
+ assert len(positions) == len(dimensions), (
+ "Must provide the dimensions of each of the CRT element. "
+ f"Got {len(dimensions)}, but expected {len(positions)}.")
+ assert logical_ids is None or len(logical_ids) == len(positions), (
+ "Must provide the logical ID of each of the CRT element. "
+ f"Got {len(logical_ids)}, but expected {len(positions)}.")
+
+ # Store CRT detector parameters
+ self.positions = np.asarray(positions)
+ self.dimensions = np.asarray(dimensions)
+ self.norms = np.asarray(norms, dtype=int)
+ if logical_ids is not None:
+ self.det_ids = {idx: i for i, idx in enumerate(logical_ids)}
+
+ @property
+ def num_detectors(self):
+ """Returns the number of CRT planes around the detector.
+
+ Returns
+ -------
+ int
+ Number of CRT planes, N_c
+ """
+ return self.positions.shape[0]
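
A hedged construction example for the new `CRTDetector`, with made-up plane positions and dimensions; the keyword names follow the `__init__` signature above:

from spine.utils.geo.detector.crt import CRTDetector

crt = CRTDetector(
    dimensions=[[400., 400., 1.], [400., 400., 1.]],  # two thin planes (cm)
    positions=[[0., 0., -250.], [0., 0., 250.]],      # illustrative centers
    norms=[2, 2],                                     # both norms along z
)
print(crt.num_detectors)  # 2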
diff --git a/spine/utils/geo/detector/optical.py b/spine/utils/geo/detector/optical.py
new file mode 100644
index 00000000..4d3ca614
--- /dev/null
+++ b/spine/utils/geo/detector/optical.py
@@ -0,0 +1,163 @@
+"""Optical detector geometry classes."""
+
+from typing import List
+from dataclasses import dataclass
+
+import numpy as np
+
+__all__ = ['OptDetector']
+
+
+@dataclass
+class OptDetector:
+ """Handles all geometry queries for a set of optical detectors.
+
+ Attributes
+ ----------
+ volume : str
+ Optical volume segmentation ('tpc' or 'module'), as defined by the
+ set of PMTs in each volume
+ positions : np.ndarray
+ (N_v, N_o, 3) Location of the center of each of the optical detectors
+ - N_v is the number of optical volumes
+ - N_o is the number of optical detectors in each volume
+ shape : List[str]
+ (N_d) Optical detector shape(s), combination of 'ellipsoid' and 'box'
+ - N_d is the number of detector types
+ dimensions : np.ndarray
+ (N_d, 3) Dimensions of each of the optical detector types
+ - N_d is the number of detector types
+ shape_ids : np.ndarray, optional
+ (N_o) Type of each of the optical detectors
+ - N_o is the number of optical detectors
+ det_ids : np.ndarray, optional
+ (N_c) Mapping between the optical channel and its corresponding detector
+ - N_c is the number of optical channels (this number can be larger
+ than the number of detectors if e.g. multiple SiPMs are used per
+ optical detector)
+ """
+ volume: str
+ positions: np.ndarray
+ shape: list
+ dimensions: np.ndarray
+ shape_ids: np.ndarray = None
+ det_ids: np.ndarray = None
+
+ def __init__(self, volume, volume_offsets, shape, dimensions, positions,
+ shape_ids=None, det_ids=None, global_index=False):
+ """Parse the optical detector configuration.
+
+ Parameters
+ ----------
+ volume : str
+ Optical detector segmentation (per 'tpc' or per 'module')
+ volume_offsets : np.ndarray
+ Offsets of the optical volumes w.r.t. the origin
+ shape : Union[str, List[str]]
+ Optical detector geometry (combination of 'ellipsoid' and/or 'box')
+ dimensions : Union[List[float], List[List[float]]]
+ (N_o, 3) List of optical detector dimensions along each axis
+ (either the ellipsoid axis lengths or the box edge lengths)
+ positions : List[List[float]]
+ (N_o, 3) Positions of each of the optical detectors
+ shape_ids : List[int], optional
+ (N_o) If there are different types of optical detectors, specify
+ which type each index corresponds to
+ det_ids : List[int], optional
+ (N_c) If there are multiple readout channels which contribute to each
+ physical optical detector, map each channel onto a physical detector
+ global_index : bool, default False
+ If `True`, the flash objects have a `pe_per_ch` attribute which refers
+ to the entire index of optical detectors, rather than one volume
+ """
+ # Parse the detector shape(s) and their mapping, store it as a list
+ assert (shape in ['ellipsoid', 'box'] or
+ np.all([s in ['ellipsoid', 'box'] for s in shape])), (
+ "The shape of optical detectors must be represented as either "
+ "an 'ellipsoid' or a 'box', or a list of them.")
+ assert isinstance(shape, str) or shape_ids is not None, (
+ "If mutiple shapes are provided, must provide a shape map.")
+ assert shape_ids is None or len(shape_ids) == len(positions), (
+ "Must provide a shape index for each optical channel.")
+
+ self.shape = shape
+ if isinstance(shape, str):
+ self.shape = [shape]
+ self.shape_ids = shape_ids
+ if shape_ids is not None:
+ self.shape_ids = np.asarray(shape_ids, dtype=int)
+
+ # Parse the detector dimensions, store as a 2D array
+ self.dimensions = np.asarray(dimensions).reshape(-1, 3)
+ assert len(self.dimensions) == len(self.shape), (
+ "Must provide optical detector dimensions for each shape.")
+
+ # Store remaining optical detector parameters
+ self.volume = volume
+ self.det_ids = det_ids
+ if det_ids is not None:
+ self.det_ids = np.asarray(det_ids, dtype=int)
+
+ # Store optical detector positions
+ count = len(positions)
+ offsets = np.asarray(volume_offsets)
+ relative_positions = np.asarray(positions)
+ self.positions = np.empty((len(offsets), count, 3))
+ for v in range(len(offsets)):
+ self.positions[v] = relative_positions + offsets[v]
+
+ # Store if the flash points to the entire index of optical detectors
+ self.global_index = global_index
+
+ @property
+ def num_volumes(self):
+ """Returns the number of optical volumes.
+
+ Returns
+ -------
+ int
+ Number of optical volumes, N_v
+ """
+ return self.positions.shape[0]
+
+ @property
+ def num_detectors_per_volume(self):
+ """Returns the number of optical detectors in each optical volume.
+
+ Returns
+ -------
+ int
+ Number of optical detectors in each volume, N_o
+ """
+ return self.positions.shape[1]
+
+ @property
+ def num_detectors(self):
+ """Number of optical detectors.
+
+ Returns
+ -------
+ int
+ Total number of optical detectors, N_v*N_o
+ """
+ return self.num_volumes*self.num_detectors_per_volume
+
+ def volume_index(self, volume_id):
+ """Returns an index which corresponds to detectors in a certain volume.
+
+ Parameters
+ ----------
+ volume_id : int
+ ID of the volume to return the index for
+
+ Returns
+ -------
+ np.ndarray
+ Index of the detectors which belong to the requested volume ID
+ """
+ # If using a global index, all volumes point to the same index
+ if self.global_index:
+ return np.arange(self.num_detectors)
+
+ return (volume_id*self.num_detectors_per_volume +
+ np.arange(self.num_detectors_per_volume))
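
A hedged construction example for the new `OptDetector`, showing how a single list of relative positions is replicated across volume offsets and how `volume_index` maps a volume ID onto its detector slice (all numbers are illustrative):

from spine.utils.geo.detector.optical import OptDetector

opt = OptDetector(
    volume='module',
    volume_offsets=[[0., 0., -200.], [0., 0., 200.]],  # two optical volumes
    shape='box',
    dimensions=[20., 20., 1.],
    positions=[[-50., 0., 0.], [50., 0., 0.]],         # two PMTs per volume
)
print(opt.num_volumes, opt.num_detectors_per_volume)   # 2 2
print(opt.volume_index(1))                             # [2 3]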
diff --git a/spine/utils/geo/detector/tpc.py b/spine/utils/geo/detector/tpc.py
new file mode 100644
index 00000000..2e2c6c65
--- /dev/null
+++ b/spine/utils/geo/detector/tpc.py
@@ -0,0 +1,425 @@
+"""TPC detector geometry classes."""
+
+from typing import List
+from dataclasses import dataclass
+
+import numpy as np
+
+from .base import Box
+
+__all__ = ['TPCDetector']
+
+
+@dataclass
+class Chamber(Box):
+ """Class which holds all properties of an individual time-projection
+ chamber (TPC).
+
+ Attributes
+ ----------
+ drift_dir : np.ndarray
+ (3) TPC drift direction vector (normalized)
+ drift_axis : int
+ Axis along which the electrons drift (0, 1 or 2)
+ """
+ drift_dir: np.ndarray
+ drift_axis: int
+
+ def __init__(self, position, dimensions, drift_dir):
+ """Initialize the TPC object.
+
+ Parameters
+ ----------
+ position : np.ndarray
+ (3) Position of the center of the TPC
+ dimensions : np.ndarray
+ (3) Dimension of the TPC
+ drift_dir : np.ndarray
+ (3) Drift direction vector
+ """
+ # Initialize the underlying box object
+ lower = position - dimensions/2
+ upper = position + dimensions/2
+ super().__init__(lower, upper)
+
+ # Make sure that the drift axis only points in one direction
+ nonzero_axes = np.where(drift_dir)[0]
+ assert len(nonzero_axes) == 1, (
+ "The drift direction must be aligned with a base axis.")
+
+ # Store drift information
+ self.drift_dir = drift_dir
+ self.drift_axis = nonzero_axes[0]
+
+ @property
+ def drift_sign(self):
+        """Sign of the drift w.r.t. the drift axis orientation.
+
+ Returns
+ -------
+ int
+            Returns the sign of the drift vector w.r.t. the drift axis
+ """
+ return int(self.drift_dir[self.drift_axis])
+
+ @property
+ def anode_side(self):
+ """Returns whether the anode is on the lower or upper boundary of
+ the TPC along the drift axis (0 for lower, 1 for upper).
+
+ Returns
+ -------
+ int
+ Anode side of the TPC
+ """
+ return (self.drift_sign + 1)//2
+
+ @property
+ def cathode_side(self):
+ """Returns whether the cathode is on the lower or upper boundary of
+ the TPC along the drift axis (0 for lower, 1 for upper).
+
+ Returns
+ -------
+ int
+ Cathode side of the TPC
+ """
+ return 1 - self.anode_side
+
+ @property
+ def anode_pos(self):
+ """Position of the anode along the drift direction.
+
+ Returns
+ -------
+ float
+ Anode position along the drift direction
+ """
+ return self.boundaries[self.drift_axis, self.anode_side]
+
+ @property
+ def cathode_pos(self):
+ """Position of the cathode along the drift direction.
+
+ Returns
+ -------
+ float
+ Cathode position along the drift direction
+ """
+ return self.boundaries[self.drift_axis, self.cathode_side]
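
A short sketch of the anode/cathode bookkeeping above, with made-up boundaries: a chamber drifting along +x has its anode on the upper x wall and its cathode on the lower one.

```python
import numpy as np

drift_dir = np.array([1., 0., 0.])
drift_axis = int(np.where(drift_dir)[0][0])   # 0
drift_sign = int(drift_dir[drift_axis])       # +1
anode_side = (drift_sign + 1) // 2            # 1 -> upper boundary
cathode_side = 1 - anode_side                 # 0 -> lower boundary

boundaries = np.array([[-50., 50.], [-100., 100.], [-100., 100.]])
print(boundaries[drift_axis, anode_side])     # 50.0 (anode plane)
print(boundaries[drift_axis, cathode_side])   # -50.0 (cathode plane)
```
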
+
+
+@dataclass
+class Module(Box):
+ """Class which holds all properties of a TPC module.
+
+ A module can hold either one chamber or two chambers with a shared cathode.
+
+ Attributes
+ ----------
+ chambers : List[Chamber]
+ List of individual TPCs that make up the module
+ """
+ chambers: List[Chamber]
+
+ def __init__(self, positions, dimensions, drift_dirs=None):
+        """Initialize the TPC module.
+
+ Parameters
+ ----------
+ positions : np.ndarray
+ (N_t) List of TPC center positions, one per TPC
+ dimensions : np.ndarray
+ (3) Dimensions of one TPC
+ drift_dirs : np.ndarray, optional
+ (N_t, 3) List of drift directions. If this is not provided, it is
+ inferred from the module configuration, provided that the module
+ is composed of two TPCs with a shared cathode.
+ """
+ # Sanity checks
+ assert len(positions) in [1, 2], (
+ "A TPC module must contain exactly one or two TPCs.")
+ assert (drift_dirs is not None) ^ (len(positions) == 2), (
+ "For TPC modules with one TPC, the drift direction cannot be "
+            "inferred and must be provided explicitly. For modules with "
+ "two TPCs, must not set the drift direction arbitrarily.")
+
+ # Build TPCs
+ self.chambers = []
+ for t in range(len(positions)):
+ # Fetch the drift axis. If not provided, join the two TPC centers
+ if drift_dirs is not None:
+ drift_dir = drift_dirs[t]
+ else:
+ drift_dir = positions[t] - positions[1 - t]
+ drift_dir /= np.linalg.norm(drift_dir)
+
+ # Instantiate TPC
+ self.chambers.append(Chamber(positions[t], dimensions, drift_dir))
+
+ # Initialize the underlying box object
+ lower = np.min(np.vstack([c.lower for c in self.chambers]), axis=0)
+ upper = np.max(np.vstack([c.upper for c in self.chambers]), axis=0)
+ super().__init__(lower, upper)
+
+ @property
+ def num_chambers(self):
+ """Number of individual TPCs that make up this module.
+
+ Returns
+ -------
+ int
+ Number of TPCs in the module
+ """
+ return len(self.chambers)
+
+ @property
+ def drift_axis(self):
+ """Drift axis for the module (shared between chambers).
+
+ Returns
+ -------
+ int
+ Axis along which electrons drift in this module (0, 1 or 2)
+ """
+ return self.chambers[0].drift_axis
+
+ @property
+ def cathode_pos(self):
+ """Location of the cathode plane along the drift axis.
+
+ Returns
+ -------
+ float
+ Location of the cathode plane along the drift axis
+ """
+ return np.mean([c.cathode_pos for c in self.chambers])
+
+ def __len__(self):
+ """Returns the number of TPCs in the module.
+
+ Returns
+ -------
+ int
+ Number of TPCs in the module
+ """
+ return self.num_chambers
+
+ def __getitem__(self, idx):
+ """Returns an underlying TPC of index idx.
+
+ Parameters
+ ----------
+ idx : int
+ Index of the TPC within the module
+
+ Returns
+ -------
+ Chamber
+ Chamber object
+ """
+ return self.chambers[idx]
+
+ def __iter__(self):
+        """Resets the iterator counter and returns self.
+
+ Returns
+ -------
+ Module
+ The module itself
+ """
+ self._counter = 0
+ return self
+
+ def __next__(self):
+ """Defines how to process the next TPC in the module.
+
+ Returns
+ -------
+ Chamber
+ Next Chamber instance in the list
+ """
+ # If there are more TPCs to go through, return it
+ if self._counter < len(self):
+ tpc = self.chambers[self._counter]
+ self._counter += 1
+
+ return tpc
+
+ raise StopIteration
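
The drift-direction inference for two-TPC modules boils down to joining the two chamber centers, so that each chamber drifts away from the shared cathode. A numpy-only sketch with hypothetical centers:

```python
import numpy as np

positions = np.array([[-25., 0., 0.], [25., 0., 0.]])  # hypothetical TPC centers
for t in range(2):
    drift_dir = positions[t] - positions[1 - t]
    drift_dir /= np.linalg.norm(drift_dir)
    print(t, drift_dir)  # TPC 0 -> [-1, 0, 0], TPC 1 -> [+1, 0, 0]
```
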
+
+
+@dataclass
+class TPCDetector(Box):
+ """Handles all geometry queries for a set of time-projection chambers.
+
+ Attributes
+ ----------
+ modules : List[Module]
+ (N_m) List of TPC modules associated with this detector
+ chambers : List[Chamber]
+ (N_t) List of individual TPC associated with this detector
+ det_ids : np.ndarray, optional
+ (N_c) Map between logical and physical TPC index
+ """
+ modules : List[Module]
+ chambers: List[Chamber]
+ det_ids : np.ndarray = None
+
+ def __init__(self, dimensions, positions, module_ids, det_ids=None,
+ drift_dirs=None):
+ """Parse the detector boundary configuration.
+
+ Parameters
+ ----------
+ dimensions : List[float]
+ (3) Dimensions of one TPC
+ positions : List[List[float]]
+ (N_t) List of TPC center positions, one per TPC
+ module_ids : List[int]
+ (N_t) List of the module IDs each TPC belongs to
+ det_ids : List[int], optional
+ (N_c) Index of the physical detector which corresponds to each
+ logical ID. This is needed if a TPC is divided into multiple logical
+            IDs. If this is not specified, it is assumed that there is a one-to-one
+            correspondence between logical and physical IDs.
+ drift_dirs : List[List[float]], optional
+ (N_t) List of drift direction vectors. If this is not provided, it
+ is inferred from the module configuration, provided that modules
+ are composed of two TPCs (with a shared cathode)
+ """
+ # Check the sanity of the configuration
+ assert len(dimensions) == 3, (
+ "Should provide the TPC dimension along 3 dimensions.")
+ assert np.all([len(pos) == 3 for pos in positions]), (
+ "Must provide the TPC position along 3 dimensions.")
+ assert len(module_ids) == len(positions), (
+ "Must provide one module ID for each TPC.")
+
+ # Cast the dimensions, positions, ids to arrays
+ dimensions = np.asarray(dimensions)
+ positions = np.asarray(positions)
+ module_ids = np.asarray(module_ids, dtype=int)
+
+ # Construct TPC chambers, organized by module
+ self.modules = []
+ self.chambers = []
+ for m in np.unique(module_ids):
+ # Narrow down the set of TPCs to those in this module
+ module_index = np.where(module_ids == m)[0]
+ module_positions = positions[module_index]
+ module_drift_dirs = None
+ if drift_dirs is not None:
+ module_drift_dirs = drift_dirs[module_index]
+
+ # Initialize the module, store
+ module = Module(module_positions, dimensions, module_drift_dirs)
+ self.modules.append(module)
+ self.chambers.extend(module.chambers)
+
+ # Check that if detector IDs are provided, they are comprehensive
+ if det_ids is not None:
+ self.det_ids = np.asarray(det_ids, dtype=int)
+ assert len(np.unique(det_ids)) == self.num_chambers_per_module, (
+ "All physical TPCs must be associated with at least one "
+ "logical TPC.")
+
+        # Initialize the underlying all-encompassing box object
+ lower = np.min(np.vstack([m.lower for m in self.modules]), axis=0)
+ upper = np.max(np.vstack([m.upper for m in self.modules]), axis=0)
+ super().__init__(lower, upper)
+
+ @property
+ def num_chambers(self):
+        """Number of individual TPC volumes.
+
+ Returns
+ -------
+ int
+ Number of TPC volumes, N_t
+ """
+ return len(self.chambers)
+
+ @property
+ def num_modules(self):
+ """Number of detector modules.
+
+ Returns
+ -------
+ int
+ Number of detector modules, N_m
+ """
+ return len(self.modules)
+
+ @property
+ def num_chambers_per_module(self):
+ """Number of TPC volumes per module.
+
+ Returns
+ -------
+ int
+            Number of TPC volumes per module
+ """
+ return len(self.modules[0])
+
+ def __len__(self):
+ """Returns the number of modules in the detector.
+
+ Returns
+ -------
+ int
+ Number of modules in the detector
+ """
+ return self.num_modules
+
+ def __getitem__(self, idx):
+ """Returns an underlying module or TPC, depending on the index type.
+
+ If the index is specified as a simple integer, a module is returned. If
+ the index is specified with two integers, a specific chamber within a
+ module is returned instead.
+
+ Parameters
+ ----------
+        idx : Union[int, List[int]]
+ Module index or pair of [module ID, chamber ID]
+
+ Returns
+ -------
+ Union[Module, Chamber]
+ Module or Chamber object
+ """
+ if np.isscalar(idx):
+ return self.modules[idx]
+
+ else:
+ module_id, chamber_id = idx
+ return self.modules[module_id].chambers[chamber_id]
+
+ def __iter__(self):
+        """Resets the iterator counter and returns self.
+
+ Returns
+ -------
+ TPCDetector
+            The detector itself
+ """
+ self._counter = 0
+ return self
+
+ def __next__(self):
+ """Defines how to process the next Module in the detector.
+
+ Returns
+ -------
+ Module
+ Next Module instance in the list
+ """
+        # If there are more modules to go through, return the next one
+ if self._counter < len(self):
+ module = self.modules[self._counter]
+ self._counter += 1
+
+ return module
+
+ raise StopIteration
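
A usage sketch for `TPCDetector`, fed with the values from `2x2_single_geometry.yaml` added below; the import path follows this PR's file layout and assumes the `detector` subpackage is importable.

```python
# Usage sketch (values borrowed from 2x2_single_geometry.yaml in this PR)
from spine.utils.geo.detector.tpc import TPCDetector

tpc = TPCDetector(
    dimensions=[30.1135, 123.708602905, 61.632598877],
    positions=[[-15.2155, 0.0, 0.0], [15.2155, 0.0, 0.0]],
    module_ids=[0, 0])

print(tpc.num_modules, tpc.num_chambers)  # 1 module, 2 chambers
print(tpc[0, 1].drift_axis)               # drift along x -> 0
for module in tpc:                        # the detector iterates over modules,
    for chamber in module:                # modules iterate over chambers
        print(chamber.anode_pos, chamber.cathode_pos)
```
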
diff --git a/spine/utils/geo/manager.py b/spine/utils/geo/manager.py
new file mode 100644
index 00000000..7eea1402
--- /dev/null
+++ b/spine/utils/geo/manager.py
@@ -0,0 +1,639 @@
+"""Module with a general-purpose geometry class.
+
+This class supports the storage of:
+- TPC boundaries
+- Optical detector shape/locations
+- CRT detector shape/locations
+
+It also provides a plethora of useful functions to query the geometry.
+"""
+
+import os
+import pathlib
+from dataclasses import dataclass
+
+import yaml
+import numpy as np
+from scipy.spatial.distance import cdist
+
+from .detector import *
+
+
+@dataclass
+class Geometry:
+ """Handles all geometry functions for a collection of box-shaped TPCs with
+    an arbitrary set of optical detectors organized in optical volumes and CRT
+ planes.
+
+ Attributes
+ ----------
+ tpc : TPCDetector
+ TPC detector properties
+ optical : OptDetector, optional
+ Optical detector properties
+ crt : CRTDetector, optional
+ CRT detector properties
+ """
+ tpc: TPCDetector
+ optical: OptDetector = None
+ crt: CRTDetector = None
+
+ def __init__(self, detector=None, file_path=None):
+ """Initializes a detector geometry object.
+
+ The geometry configuration file is a YAML file which contains all the
+ necessary information to construct the physical boundaries of the
+        detector (TPC size, positions, etc.) and (optionally) the set
+ of optical detectors and CRTs.
+
+ If the detector is already supported, the name is sufficient.
+ Supported: 'icarus', 'sbnd', '2x2', '2x2_single', 'ndlar'
+
+ Parameters
+ ----------
+ detector : str, optional
+            Name of a recognized detector to load the geometry from
+ file_path : str, optional
+ Path to a `.yaml` geometry configuration
+ """
+ # Check that we are provided with either a detector name or a file
+ assert (detector is not None) ^ (file_path is not None), (
+ "Must provide either a `detector` name or a geometry "
+ "`file_path`, not neither, not both.")
+
+ # If a detector name is provided, find the corresponding geometry file
+ if detector is not None:
+ path = pathlib.Path(__file__).parent
+ file_path = os.path.join(
+ path, 'source', f'{detector.lower()}_geometry.yaml')
+
+ # Check that the geometry configuration file exists
+ if not os.path.isfile(file_path):
+ raise FileNotFoundError(
+ f"Could not find the geometry file: {file_path}")
+
+ # Load the geometry file, parse it
+ with open(file_path, 'r', encoding='utf-8') as cfg_yaml:
+ cfg = yaml.safe_load(cfg_yaml)
+
+ self.parse_configuration(**cfg)
+
+ # Initialize place-holders for the containment conditions to be defined
+ # by the end-user using :func:`define_containment_volumes`
+ self._cont_volumes = None
+ self._cont_use_source = False
+
+ def parse_configuration(self, tpc, optical=None, crt=None):
+ """Parse the geometry configuration.
+
+ Parameters
+ ----------
+ tpc : dict
+ Detector boundary configuration
+ optical : dict, optional
+ Optical detector configuration
+ crt : dict, optional
+ CRT detector configuration
+ """
+ # Load the charge detector boundaries
+ self.tpc = TPCDetector(**tpc)
+
+ # Load the optical detectors
+ if optical is not None:
+ self.parse_optical(**optical)
+
+ # Load the CRT detectors
+ if crt is not None:
+ self.crt = CRTDetector(**crt)
+
+ def parse_optical(self, volume, **optical):
+ """Parse the optical detector configuration.
+
+ Parameters
+ ----------
+ volume : str
+            Optical volume granularity (one of 'tpc' or 'module')
+ **optical : dict
+            Rest of the optical detector configuration
+ """
+        # Get the number of optical volumes based on the volume type
+ assert volume in ['module', 'tpc'], (
+ "Optical detector positions must be provided by TPC or module.")
+
+ if volume == 'module':
+ offsets = [module.center for module in self.tpc.modules]
+ else:
+ offsets = [chamber.center for chamber in self.tpc.chambers]
+
+ # Initialize the optical detector object
+ self.optical = OptDetector(volume, offsets, **optical)
+
+ def get_sources(self, sources):
+ """Converts logical TPC indexes to physical TPC indexes.
+
+ Parameters
+ ----------
+ sources : np.ndarray
+ (N, 2) Array of logical [module ID, tpc ID] pairs, one per point
+
+ Returns
+        -------
+ np.ndarray
+ (N, 2) Array of physical [module ID, tpc ID] pairs, one per point
+ """
+ # If logical and physical TPCs are aligned, nothing to do
+ if self.tpc.det_ids is None:
+ return sources
+
+ # Otherwise, map logical to physical
+ sources = np.copy(sources)
+ sources[:, 1] = self.tpc.det_ids[sources[:, 1]]
+
+ return sources
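
The logical-to-physical remapping is a plain fancy-indexing operation. A numpy-only sketch using the `det_ids: [0, 0, 1, 1]` mapping from the ICARUS configuration added below:

```python
import numpy as np

det_ids = np.array([0, 0, 1, 1])   # logical TPCs 0/1 and 2/3 share a physical TPC
sources = np.array([[0, 1],        # module 0, logical TPC 1
                    [1, 2]])       # module 1, logical TPC 2

physical = np.copy(sources)
physical[:, 1] = det_ids[physical[:, 1]]
print(physical)                    # [[0 0]
                                   #  [1 1]]
```
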
+
+ def get_contributors(self, sources):
+ """Gets the list of [module ID, tpc ID] pairs that contributed to a
+ particle or interaction object, as defined in this geometry.
+
+ Parameters
+ ----------
+ sources : np.ndarray
+ (N, 2) Array of [module ID, tpc ID] pairs, one per point
+
+ Returns
+ -------
+ List[np.ndarray]
+ (2, N_t) Pair of arrays: the first contains the list of
+ contributing modules, the second of contributing tpcs.
+ """
+ # Fetch the list of unique logical [module ID, tpc ID] pairs
+ sources = np.unique(sources, axis=0)
+
+ # If the logical TPCs differ from the physical TPCs, convert
+ if self.tpc.det_ids is not None:
+ sources = self.get_sources(sources)
+ sources = np.unique(sources, axis=0)
+
+ # Return as a list of physical [module ID, tpc ID] pairs
+ return list(sources.T)
+
+ def get_volume_index(self, sources, module_id, tpc_id=None):
+ """Gets the list of indices of points that belong to a certain
+ detector volume (module or individual TPC).
+
+ Parameters
+ ----------
+ sources : np.ndarray
+ (N, 2) Array of [module ID, tpc ID] pairs, one per point
+ module_id : int
+ ID of the module
+ tpc_id : int, optional
+            ID of the TPC within the module. If not specified, the index is
+            defined w.r.t. the module as a whole
+
+ Returns
+ -------
+ np.ndarray
+ (N) Index of points that belong to the requested detector volume
+ """
+ # If the logical TPCs differ from the physical TPCs, convert
+ sources = self.get_sources(sources)
+
+ # Compute and return the index
+ if tpc_id is None:
+ return np.where(sources[:, 0] == module_id)[0]
+ else:
+ return np.where((sources == [module_id, tpc_id]).all(axis=-1))[0]
+
+ def get_closest_tpc(self, points):
+ """For each point, find the ID of the closest TPC.
+
+ There is a natural assumption that all TPCs are boxes of identical
+        sizes, so that the relative proximity of a point to a TPC is
+ equivalent to its proximity to the TPC center.
+
+ Parameters
+ ----------
+ points : np.ndarray
+ (N, 3) Set of point coordinates
+
+ Returns
+ -------
+ np.ndarray
+ (N) List of TPC indexes, one per input point
+ """
+ # Get the TPC centers
+ centers = np.asarray([chamber.center for chamber in self.tpc.chambers])
+
+ # Compute the pair-wise distances between points and TPC centers
+ dists = cdist(points, centers)
+
+        # Return the index of the closest TPC center for each point
+ return np.argmin(dists, axis=1)
+
+ def get_closest_module(self, points):
+ """For each point, find the ID of the closest module.
+
+ There is a natural assumption that all modules are boxes of identical
+        sizes, so that the relative proximity of a point to a module is
+ equivalent to its proximity to the module center.
+
+ Parameters
+ ----------
+ points : np.ndarray
+ (N, 3) Set of point coordinates
+
+ Returns
+ -------
+ np.ndarray
+ (N) List of module indexes, one per input point
+ """
+ # Get the module centers
+ centers = np.asarray([module.center for module in self.tpc.modules])
+
+ # Compute the pair-wise distances between points and module centers
+ dists = cdist(points, centers)
+
+        # Return the index of the closest module center for each point
+ return np.argmin(dists, axis=1)
+
+ def get_closest_tpc_indexes(self, points):
+ """For each TPC, get the list of points that live closer to it than any
+ other TPC in the detector.
+
+ Parameters
+ ----------
+ points : np.ndarray
+ (N, 3) Set of point coordinates
+
+ Returns
+ -------
+ List[np.ndarray]
+            List of indexes of the points that belong to each TPC
+ """
+ # Start by finding the closest TPC to each of the points
+ closest_ids = self.get_closest_tpc(points)
+
+ # For each TPC, append the list of point indices associated with it
+ tpc_indexes = []
+ for t in range(self.tpc.num_chambers):
+ tpc_indexes.append(np.where(closest_ids == t)[0])
+
+ return tpc_indexes
+
+ def get_closest_module_indexes(self, points):
+ """For each module, get the list of points that live closer to it
+ than any other module in the detector.
+
+ Parameters
+ ----------
+ points : np.ndarray
+ (N, 3) Set of point coordinates
+
+ Returns
+ -------
+ List[np.ndarray]
+            List of indexes of the points that belong to each module
+ """
+ # Start by finding the closest TPC to each of the points
+ closest_ids = self.get_closest_module(points)
+
+ # For each module, append the list of point indices associated with it
+ module_indexes = []
+ for m in range(self.tpc.num_modules):
+ module_indexes.append(np.where(closest_ids == m)[0])
+
+ return module_indexes
+
+ def get_volume_offsets(self, points, module_id, tpc_id=None):
+ """Compute how far each point is from a certain detector volume.
+
+ Parameters
+ ----------
+ points : np.ndarray
+            (N, 3) Point coordinates
+ module_id : int
+ ID of the module
+ tpc_id : int, optional
+ ID of the TPC within the module. If not specified, the volume
+ offsets are estimated w.r.t. the module itself
+
+ Returns
+ -------
+ np.ndarray
+            (N, 3) Offsets w.r.t. the volume boundaries
+ """
+ # Compute the axis-wise distances of each point to each boundary
+ idx = module_id if tpc_id is None else (module_id, tpc_id)
+ boundaries = self.tpc[idx].boundaries
+ dists = points[..., None] - boundaries
+
+ # If a point is between two boundaries, the distance is 0. If it is
+ # outside, the distance is that of the closest boundary
+ signs = (np.sign(dists[..., 0]) + np.sign(dists[..., 1]))/2
+ offsets = signs * np.min(np.abs(dists), axis=-1)
+
+ return offsets
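
A numpy-only sketch of the offset arithmetic above with made-up boundaries: points inside the volume get a zero offset along every axis, points outside get the signed distance to the closest wall.

```python
import numpy as np

boundaries = np.array([[0., 10.], [0., 10.], [0., 10.]])  # hypothetical volume
points = np.array([[5., 5., 5.],      # inside
                   [12., 5., -3.]])   # outside in x (+2) and z (-3)

dists = points[..., None] - boundaries
signs = (np.sign(dists[..., 0]) + np.sign(dists[..., 1])) / 2
offsets = signs * np.min(np.abs(dists), axis=-1)
print(offsets)  # [[0. 0. 0.], [2. 0. -3.]]
```
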
+
+ def get_min_volume_offset(self, points, module_id, tpc_id=None):
+ """Get the minimum offset to apply to a point cloud to bring it
+ within the boundaries of a volume.
+
+ Parameters
+ ----------
+ points : np.ndarray
+            (N, 3) Point coordinates
+ module_id : int
+ ID of the module
+ tpc_id : int, optional
+ ID of the TPC within the module. If not specified, the volume
+ offsets are estimated w.r.t. the module itself
+
+ Returns
+ -------
+ np.ndarray
+            (3) Offsets w.r.t. the volume location
+ """
+ # Compute the distance for each point, get the maximum necessary offset
+ offsets = self.get_volume_offsets(points, module_id, tpc_id)
+ offsets = offsets[np.argmax(np.abs(offsets), axis=0), np.arange(3)]
+
+ return offsets
+
+ def translate(self, points, source_id, target_id, factor=None):
+        """Moves a point cloud from one module to another.
+
+ Parameters
+ ----------
+ points : np.ndarray
+ (N, 3) Set of point coordinates
+ source_id: int
+ Module ID from which to move the point cloud
+ target_id : int
+ Module ID to which to move the point cloud
+ factor : Union[float, np.ndarray], optional
+ Multiplicative factor to apply to the offset. This is necessary if
+ the points are not expressed in detector coordinates
+
+ Returns
+ -------
+ np.ndarray
+ (N, 3) Set of translated point coordinates
+ """
+ # If the source and target are the same, nothing to do here
+ if target_id == source_id:
+ return np.copy(points)
+
+ # Fetch the inter-module shift
+ offset = self.tpc[target_id].center - self.tpc[source_id].center
+ if factor is not None:
+ offset *= factor
+
+ # Translate
+ return points + offset
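
The translation is a rigid shift by the difference of the two module centers. A sketch with hypothetical centers:

```python
import numpy as np

center_source = np.array([-280., 0., 0.])
center_target = np.array([280., 0., 0.])
points = np.array([[-275., 1., 2.], [-285., -1., 3.]])

offset = center_target - center_source
print(points + offset)  # same cloud, now expressed around the target module
```
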
+
+ def split(self, points, target_id, sources=None, meta=None):
+ """Migrate all points to a target module, organize them by module ID.
+
+ Parameters
+ ----------
+ points : np.ndarray
+ (N, 3) Set of point coordinates
+ target_id : int
+ Module ID to which to move the point cloud
+ sources : np.ndarray, optional
+ (N, 2) Array of [module ID, tpc ID] pairs, one per voxel
+ meta : Meta, optional
+ Meta information about the voxelized image. If provided, the
+ points are assumed to be provided in voxel coordinates.
+
+ Returns
+ -------
+ np.ndarray
+ (N, 3) Shifted set of points
+ List[np.ndarray]
+            List of indexes of the points that belong to each module
+ """
+ # Check that the target ID exists
+ assert target_id > -1 and target_id < self.tpc.num_modules, (
+ "Target ID should be in [0, N_modules[.")
+
+ # Get the module ID of each of the input points
+ convert = False
+ if sources is not None:
+ # If sources are provided, simply use that
+ module_indexes = []
+ for m in range(self.tpc.num_modules):
+ module_indexes.append(np.where(sources[:, 0] == m)[0])
+
+ else:
+ # If the points are expressed in pixel coordinates, translate
+ convert = meta is not None
+ if convert:
+ points = meta.to_cm(points, center=True)
+
+ # If not provided, find module each point belongs to by proximity
+ module_indexes = self.get_closest_module_indexes(points)
+
+        # Now shift all points that are not in the target module
+ for module_id, module_index in enumerate(module_indexes):
+ # If this is the target module, nothing to do here
+ if module_id == target_id:
+ continue
+
+ # Shift the coordinates
+ points[module_index] = self.translate(
+ points[module_index], module_id, target_id)
+
+ # Bring the coordinates back to pixels, if they were shifted
+ if convert:
+ points = meta.to_px(points, floor=True)
+
+ return points, module_indexes
+
+ def check_containment(self, points, sources=None,
+ allow_multi_module=False, summarize=True):
+ """Check whether a point cloud comes within some distance of the
+ boundaries of a certain subset of detector volumes, depending on the
+ mode.
+
+ Parameters
+ ----------
+ points : np.ndarray
+ (N, 3) Set of point coordinates
+ sources : np.ndarray, optional
+            (N, 2) List of [module ID, tpc ID] pairs that created the
+ point cloud
+ allow_multi_module : bool, default `False`
+ Whether to allow particles/interactions to span multiple modules
+ summarize : bool, default `True`
+ If `True`, only returns a single flag for the whole cloud.
+ Otherwise, returns a boolean array corresponding to each point.
+
+ Returns
+ -------
+ Union[bool, np.ndarray]
+ `True` if the particle is contained, `False` if not
+ """
+ # If the containment volumes are not defined, throw
+ if self._cont_volumes is None:
+ raise ValueError("Must call `define_containment_volumes` first.")
+
+ # If sources are provided, only consider source volumes
+ if self._cont_use_source:
+ # Get the contributing TPCs
+ assert len(points) == len(sources), (
+ "Need to provide sources to make a source-based check.")
+ contributors = self.get_contributors(sources)
+ if not allow_multi_module and len(np.unique(contributors[0])) > 1:
+ return False
+
+ # Define the smallest box containing all contributing TPCs
+ # TODO: this is not ideal
+ index = contributors[0] * self.tpc.num_chambers_per_module + contributors[1]
+ volume = self.merge_volumes(self._cont_volumes[index])
+ volumes = [volume]
+
+ else:
+ volumes = self._cont_volumes
+
+ # Loop over volumes, make sure the cloud is contained in at least one
+ if summarize:
+ contained = False
+ for v in volumes:
+ if (points > v[:, 0]).all() and (points < v[:, 1]).all():
+ contained = True
+ break
+ else:
+ contained = np.zeros(len(points), dtype=bool)
+ for v in volumes:
+ contained |= ((points > v[:, 0]).all(axis=1) &
+ (points < v[:, 1]).all(axis=1))
+
+ return contained
+
+ def define_containment_volumes(self, margin, cathode_margin=None,
+ mode ='module'):
+ """This function defines a list of volumes to check containment against.
+
+ If the containment is checked against a constant volume, it is more
+ efficient to call this function once and call `check_containment`
+        repeatedly afterwards.
+
+ Parameters
+ ----------
+ margin : Union[float, List[float], np.array]
+ Minimum distance from a detector wall to be considered contained:
+ - If float: distance buffer is shared between all 6 walls
+ - If [x,y,z]: distance is shared between pairs of walls facing
+ each other and perpendicular to a shared axis
+ - If [[x_low,x_up], [y_low,y_up], [z_low,z_up]]: distance is
+            specified individually for each wall.
+ cathode_margin : float, optional
+ If specified, sets a different margin for the cathode boundaries
+ mode : str, default 'module'
+            Containment criterion (one of 'detector', 'module', 'tpc', 'source'):
+            - If 'tpc', makes sure it is contained within a single TPC
+            - If 'module', makes sure it is contained within a single module
+            - If 'detector', makes sure it is contained within the detector
+ - If 'source', use the origin of voxels to determine which TPC(s)
+ contributed to them, and define volumes accordingly
+ """
+ # Translate the margin parameter to a (3,2) matrix
+ if np.isscalar(margin):
+ margin = np.full((3, 2), margin)
+ elif len(np.array(margin).shape) == 1:
+ assert len(margin) == 3, (
+ "Must provide one value per axis.")
+ margin = np.repeat([margin], 2, axis=0).T
+ else:
+ assert np.array(margin).shape == (3, 2), (
+ "Must provide two values per axis.")
+ margin = np.copy(margin)
+
+ # Establish the volumes to check against
+ self._cont_volumes = []
+ if mode in ['tpc', 'source']:
+ for m, module in enumerate(self.tpc):
+ for t, tpc in enumerate(module):
+ vol = self.adapt_volume(
+ tpc.boundaries, margin, cathode_margin, m, t)
+ self._cont_volumes.append(vol)
+ self._cont_use_source = mode == 'source'
+
+ elif mode == 'module':
+ for module in self.tpc:
+ vol = self.adapt_volume(module.boundaries, margin)
+ self._cont_volumes.append(vol)
+ self._cont_use_source = False
+
+ elif mode == 'detector':
+ vol = self.adapt_volume(self.tpc.boundaries, margin)
+ self._cont_volumes.append(vol)
+ self._cont_use_source = False
+
+ else:
+            raise ValueError(f"Containment check mode not recognized: {mode}.")
+
+ self._cont_volumes = np.array(self._cont_volumes)
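
A sketch of the three accepted `margin` formats, mirrored in a standalone helper (hypothetical name, not part of this PR) so the normalization to a (3, 2) array can be checked in isolation:

```python
import numpy as np

def normalize_margin(margin):
    """Mirror of the margin normalization in define_containment_volumes."""
    if np.isscalar(margin):
        return np.full((3, 2), margin)
    margin = np.asarray(margin, dtype=float)
    if margin.ndim == 1:
        assert len(margin) == 3, "Must provide one value per axis."
        return np.repeat([margin], 2, axis=0).T
    assert margin.shape == (3, 2), "Must provide two values per axis."
    return np.copy(margin)

print(normalize_margin(5.))                                   # 5 cm on all 6 walls
print(normalize_margin([5., 10., 15.]))                       # one value per axis
print(normalize_margin([[5., 0.], [10., 10.], [15., 20.]]))   # one value per wall
```
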
+
+ def adapt_volume(self, ref_volume, margin, cathode_margin=None,
+ module_id=None, tpc_id=None):
+        """Apply margins to a given volume. Takes care of subtleties
+ associated with the cathode, if needed.
+
+ Parameters
+ ----------
+ ref_volume : np.ndarray
+ (3, 2) Array of volume boundaries
+ margin : np.ndarray
+ Minimum distance from a detector wall to be considered contained as
+ [[x_low,x_up], [y_low,y_up], [z_low,z_up]], i.e. distance is
+            specified individually for each wall.
+ cathode_margin : float, optional
+ If specified, sets a different margin for the cathode boundaries
+ module_id : int, optional
+ ID of the module
+ tpc_id : int, optional
+ ID of the TPC within the module
+
+ Returns
+ -------
+ np.ndarray
+ (3, 2) Updated array of volume boundaries
+ """
+ # Reduce the volume according to the margin
+ volume = np.copy(ref_volume)
+        volume[:, 0] += margin[:, 0]
+        volume[:, 1] -= margin[:, 1]
+
+ # If a cathode margin is provided, adapt the cathode wall differently
+ if cathode_margin is not None:
+ axis = self.tpc[module_id, tpc_id].drift_axis
+ side = self.tpc[module_id, tpc_id].cathode_side
+
+ flip = (-1) ** side
+ volume[axis, side] += flip * (cathode_margin - margin[axis, side])
+
+ return volume
+
+ @staticmethod
+ def merge_volumes(volumes):
+ """Given a list of volumes and their boundaries, find the smallest box
+        that encompasses all volumes combined.
+
+ Parameters
+ ----------
+ volumes : np.ndarray
+ (N, 3, 2) List of volume boundaries
+
+ Returns
+ -------
+ np.ndarray
+ (3, 2) Boundaries of the combined volume
+ """
+ volume = np.empty((3, 2))
+ volume[:, 0] = np.min(volumes, axis=0)[:, 0]
+ volume[:, 1] = np.max(volumes, axis=0)[:, 1]
+
+ return volume
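
An end-to-end usage sketch, assuming the package is importable and the 'icarus' configuration added below ships with it; the import path matches this PR's layout (`spine/utils/geo/manager.py`):

```python
import numpy as np
from spine.utils.geo.manager import Geometry

geo = Geometry(detector='icarus')
geo.define_containment_volumes(margin=5., cathode_margin=2., mode='tpc')

points = np.random.uniform(-300., 300., size=(100, 3))
print(geo.check_containment(points))                   # one flag for the whole cloud
print(geo.check_containment(points, summarize=False))  # one flag per point
```
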
diff --git a/spine/utils/geo/source/2x2_boundaries.npy b/spine/utils/geo/source/2x2_boundaries.npy
deleted file mode 120000
index d9edf96e..00000000
--- a/spine/utils/geo/source/2x2_boundaries.npy
+++ /dev/null
@@ -1 +0,0 @@
-2x2_mr5_boundaries.npy
\ No newline at end of file
diff --git a/spine/utils/geo/source/2x2_geometry.yaml b/spine/utils/geo/source/2x2_geometry.yaml
new file mode 120000
index 00000000..295ad9e2
--- /dev/null
+++ b/spine/utils/geo/source/2x2_geometry.yaml
@@ -0,0 +1 @@
+2x2_mr5_geometry.yaml
\ No newline at end of file
diff --git a/spine/utils/geo/source/2x2_mr3_boundaries.npy b/spine/utils/geo/source/2x2_mr3_boundaries.npy
deleted file mode 100644
index 8b876393..00000000
Binary files a/spine/utils/geo/source/2x2_mr3_boundaries.npy and /dev/null differ
diff --git a/spine/utils/geo/source/2x2_mr3_geometry.yaml b/spine/utils/geo/source/2x2_mr3_geometry.yaml
new file mode 100644
index 00000000..196227fe
--- /dev/null
+++ b/spine/utils/geo/source/2x2_mr3_geometry.yaml
@@ -0,0 +1,39 @@
+tpc:
+ dimensions: [30.27225, 124.152, 62.076]
+ module_ids: [0, 0, 1, 1, 2, 2, 3, 3]
+ positions:
+ - [18.205125, 43.0, 33.5]
+ - [48.794875, 43.0, 33.5]
+ - [18.205125, 43.0, -33.5]
+ - [48.794875, 43.0, -33.5]
+ - [-48.794875, 43.0, 33.5]
+ - [-18.205125, 43.0, 33.5]
+ - [-48.794875, 43.0, -33.5]
+ - [-18.205125, 43.0, -33.5]
+optical:
+ volume: tpc
+ shape: [box, box]
+ dimensions:
+ - [27.17, 31.02, 1.0]
+ - [27.17, 10.34, 1.0]
+ shape_ids: [0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1]
+ det_ids: [0, 0, 0, 0, 0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 4, 4, 4, 4, 5, 5, 6, 6, 7, 7,
+ 8, 8, 8, 8, 8, 8, 9, 9, 10, 10, 11, 11, 12, 12, 12, 12, 12, 12, 13, 13, 14, 14,
+ 15, 15]
+ positions:
+ - [0, -46.53, -31.49]
+ - [0, -25.85, -31.49]
+ - [0, -15.51, -31.49]
+ - [0, -5.16, -31.49]
+ - [0, 15.51, -31.49]
+ - [0, 36.19, -31.49]
+ - [0, 46.53, -31.49]
+ - [0, 56.87, -31.49]
+ - [0, -46.53, 31.49]
+ - [0, -25.84, 31.49]
+ - [0, -15.51, 31.49]
+ - [0, -5.17, 31.49]
+ - [0, 15.51, 31.49]
+ - [0, 36.19, 31.49]
+ - [0, 46.53, 31.49]
+ - [0, 56.87, 31.49]
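
A configuration like the one above can also be loaded directly through its file path rather than a registered detector name; a short sketch (the path shown is illustrative):

```python
from spine.utils.geo.manager import Geometry

geo = Geometry(file_path='spine/utils/geo/source/2x2_mr3_geometry.yaml')
print(geo.tpc.num_modules)      # 4 modules of 2 TPCs each
print(geo.optical.num_volumes)  # optical volumes defined per TPC -> 8
```
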
diff --git a/spine/utils/geo/source/2x2_mr4_boundaries.npy b/spine/utils/geo/source/2x2_mr4_boundaries.npy
deleted file mode 100644
index 23014f22..00000000
Binary files a/spine/utils/geo/source/2x2_mr4_boundaries.npy and /dev/null differ
diff --git a/spine/utils/geo/source/2x2_mr4_geometry.yaml b/spine/utils/geo/source/2x2_mr4_geometry.yaml
new file mode 100644
index 00000000..0406b977
--- /dev/null
+++ b/spine/utils/geo/source/2x2_mr4_geometry.yaml
@@ -0,0 +1,39 @@
+tpc:
+ dimensions: [30.27225, 124.152, 62.076]
+ module_ids: [0, 0, 1, 1, 2, 2, 3, 3]
+ positions:
+ - [18.205125, -267.0, 1333.5]
+ - [48.794875, -267.0, 1333.5]
+ - [18.205125, -267.0, 1266.5]
+ - [48.794875, -267.0, 1266.5]
+ - [-48.794875, -267.0, 1333.5]
+ - [-18.205125, -267.0, 1333.5]
+ - [-48.794875, -267.0, 1266.5]
+ - [-18.205125, -267.0, 1266.5]
+optical:
+ volume: tpc
+ shape: [box, box]
+ dimensions:
+ - [27.17, 31.02, 1.0]
+ - [27.17, 10.34, 1.0]
+ shape_ids: [0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1]
+ det_ids: [0, 0, 0, 0, 0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 4, 4, 4, 4, 5, 5, 6, 6, 7, 7,
+ 8, 8, 8, 8, 8, 8, 9, 9, 10, 10, 11, 11, 12, 12, 12, 12, 12, 12, 13, 13, 14, 14,
+ 15, 15]
+ positions:
+ - [0, -46.53, -31.49]
+ - [0, -25.85, -31.49]
+ - [0, -15.51, -31.49]
+ - [0, -5.16, -31.49]
+ - [0, 15.51, -31.49]
+ - [0, 36.19, -31.49]
+ - [0, 46.53, -31.49]
+ - [0, 56.87, -31.49]
+ - [0, -46.53, 31.49]
+ - [0, -25.84, 31.49]
+ - [0, -15.51, 31.49]
+ - [0, -5.17, 31.49]
+ - [0, 15.51, 31.49]
+ - [0, 36.19, 31.49]
+ - [0, 46.53, 31.49]
+ - [0, 56.87, 31.49]
diff --git a/spine/utils/geo/source/2x2_mr5_boundaries.npy b/spine/utils/geo/source/2x2_mr5_boundaries.npy
deleted file mode 100644
index 6eb844a3..00000000
Binary files a/spine/utils/geo/source/2x2_mr5_boundaries.npy and /dev/null differ
diff --git a/spine/utils/geo/source/2x2_mr5_geometry.yaml b/spine/utils/geo/source/2x2_mr5_geometry.yaml
new file mode 100644
index 00000000..ba723123
--- /dev/null
+++ b/spine/utils/geo/source/2x2_mr5_geometry.yaml
@@ -0,0 +1,39 @@
+tpc:
+ dimensions: [30.27225, 124.152, 62.076]
+ module_ids: [0, 0, 1, 1, 2, 2, 3, 3]
+ positions:
+ - [18.205125, -0.0, 33.5]
+ - [48.794875, -0.0, 33.5]
+ - [18.205125, -0.0, -33.5]
+ - [48.794875, -0.0, -33.5]
+ - [-48.794875, -0.0, 33.5]
+ - [-18.205125, -0.0, 33.5]
+ - [-48.794875, -0.0, -33.5]
+ - [-18.205125, -0.0, -33.5]
+optical:
+ volume: tpc
+ shape: [box, box]
+ dimensions:
+ - [27.17, 31.02, 1.0]
+ - [27.17, 10.34, 1.0]
+ shape_ids: [0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1]
+ det_ids: [0, 0, 0, 0, 0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 4, 4, 4, 4, 5, 5, 6, 6, 7, 7,
+ 8, 8, 8, 8, 8, 8, 9, 9, 10, 10, 11, 11, 12, 12, 12, 12, 12, 12, 13, 13, 14, 14,
+ 15, 15]
+ positions:
+ - [0, -46.53, -31.49]
+ - [0, -25.85, -31.49]
+ - [0, -15.51, -31.49]
+ - [0, -5.16, -31.49]
+ - [0, 15.51, -31.49]
+ - [0, 36.19, -31.49]
+ - [0, 46.53, -31.49]
+ - [0, 56.87, -31.49]
+ - [0, -46.53, 31.49]
+ - [0, -25.84, 31.49]
+ - [0, -15.51, 31.49]
+ - [0, -5.17, 31.49]
+ - [0, 15.51, 31.49]
+ - [0, 36.19, 31.49]
+ - [0, 46.53, 31.49]
+ - [0, 56.87, 31.49]
diff --git a/spine/utils/geo/source/2x2_single_boundaries.npy b/spine/utils/geo/source/2x2_single_boundaries.npy
deleted file mode 100644
index 0250b977..00000000
Binary files a/spine/utils/geo/source/2x2_single_boundaries.npy and /dev/null differ
diff --git a/spine/utils/geo/source/2x2_single_geometry.yaml b/spine/utils/geo/source/2x2_single_geometry.yaml
new file mode 100644
index 00000000..06cbc933
--- /dev/null
+++ b/spine/utils/geo/source/2x2_single_geometry.yaml
@@ -0,0 +1,33 @@
+tpc:
+ dimensions: [30.1135, 123.708602905, 61.632598877]
+ module_ids: [0, 0]
+ positions:
+ - [-15.2155, 0.0, 0.0]
+ - [15.2155, 0.0, 0.0]
+optical:
+ volume: tpc
+ shape: [box, box]
+ dimensions:
+ - [27.17, 31.02, 1.0]
+ - [27.17, 10.34, 1.0]
+ shape_ids: [0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1]
+ det_ids: [0, 0, 0, 0, 0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 4, 4, 4, 4, 5, 5, 6, 6, 7, 7,
+ 8, 8, 8, 8, 8, 8, 9, 9, 10, 10, 11, 11, 12, 12, 12, 12, 12, 12, 13, 13, 14, 14,
+ 15, 15]
+ positions:
+ - [0, -46.53, -31.49]
+ - [0, -25.85, -31.49]
+ - [0, -15.51, -31.49]
+ - [0, -5.16, -31.49]
+ - [0, 15.51, -31.49]
+ - [0, 36.19, -31.49]
+ - [0, 46.53, -31.49]
+ - [0, 56.87, -31.49]
+ - [0, -46.53, 31.49]
+ - [0, -25.84, 31.49]
+ - [0, -15.51, 31.49]
+ - [0, -5.17, 31.49]
+ - [0, 15.51, 31.49]
+ - [0, 36.19, 31.49]
+ - [0, 46.53, 31.49]
+ - [0, 56.87, 31.49]
diff --git a/spine/utils/geo/source/icarus_boundaries.npy b/spine/utils/geo/source/icarus_boundaries.npy
deleted file mode 100644
index 2a3678a4..00000000
Binary files a/spine/utils/geo/source/icarus_boundaries.npy and /dev/null differ
diff --git a/spine/utils/geo/source/icarus_geometry.yaml b/spine/utils/geo/source/icarus_geometry.yaml
new file mode 100644
index 00000000..94fb4a12
--- /dev/null
+++ b/spine/utils/geo/source/icarus_geometry.yaml
@@ -0,0 +1,227 @@
+tpc:
+ dimensions: [148.2, 316.82, 1789.901]
+ module_ids: [0, 0, 1, 1]
+ det_ids: [0, 0, 1, 1]
+ positions:
+ - [-284.39, -23.45, 0.0]
+ - [-136.04, -23.45, 0.0]
+ - [136.04, -23.45, 0.0]
+ - [284.39, -23.45, 0.0]
+optical:
+ volume: module
+ shape: ellipsoid
+ dimensions: [10.8, 20.2, 20.2]
+ positions:
+ - [-160.855, -52.8, -872.9]
+ - [-160.855, 52.8, -872.9]
+ - [-160.855, -105.6, -823.02]
+ - [-160.855, 0.0, -823.02]
+ - [-160.855, 105.6, -823.02]
+ - [-160.855, -105.6, -773.14]
+ - [-160.855, 0.0, -773.14]
+ - [-160.855, 105.6, -773.14]
+ - [-160.855, -52.8, -723.26]
+ - [-160.855, 52.8, -723.26]
+ - [-160.855, -52.8, -673.38]
+ - [-160.855, 52.8, -673.38]
+ - [-160.855, -105.6, -623.5]
+ - [-160.855, 0.0, -623.5]
+ - [-160.855, 105.6, -623.5]
+ - [-160.855, -105.6, -573.62]
+ - [-160.855, 0.0, -573.62]
+ - [-160.855, 105.6, -573.62]
+ - [-160.855, -52.8, -523.74]
+ - [-160.855, 52.8, -523.74]
+ - [-160.855, -52.8, -473.86]
+ - [-160.855, 52.8, -473.86]
+ - [-160.855, -105.6, -423.98]
+ - [-160.855, 0.0, -423.98]
+ - [-160.855, 105.6, -423.98]
+ - [-160.855, -105.6, -374.1]
+ - [-160.855, 0.0, -374.1]
+ - [-160.855, 105.6, -374.1]
+ - [-160.855, -52.8, -324.22]
+ - [-160.855, 52.8, -324.22]
+ - [-160.855, -52.8, -274.34]
+ - [-160.855, 52.8, -274.34]
+ - [-160.855, -105.6, -224.46]
+ - [-160.855, 0.0, -224.46]
+ - [-160.855, 105.6, -224.46]
+ - [-160.855, -105.6, -174.58]
+ - [-160.855, 0.0, -174.58]
+ - [-160.855, 105.6, -174.58]
+ - [-160.855, -52.8, -124.7]
+ - [-160.855, 52.8, -124.7]
+ - [-160.855, -52.8, -74.82]
+ - [-160.855, 52.8, -74.82]
+ - [-160.855, -105.6, -24.94]
+ - [-160.855, 0.0, -24.94]
+ - [-160.855, 105.6, -24.94]
+ - [-160.855, -105.6, 24.94]
+ - [-160.855, 0.0, 24.94]
+ - [-160.855, 105.6, 24.94]
+ - [-160.855, -52.8, 74.82]
+ - [-160.855, 52.8, 74.82]
+ - [-160.855, -52.8, 124.7]
+ - [-160.855, 52.8, 124.7]
+ - [-160.855, -105.6, 174.58]
+ - [-160.855, 0.0, 174.58]
+ - [-160.855, 105.6, 174.58]
+ - [-160.855, -105.6, 224.46]
+ - [-160.855, 0.0, 224.46]
+ - [-160.855, 105.6, 224.46]
+ - [-160.855, -52.8, 274.34]
+ - [-160.855, 52.8, 274.34]
+ - [-160.855, -52.8, 324.22]
+ - [-160.855, 52.8, 324.22]
+ - [-160.855, -105.6, 374.1]
+ - [-160.855, 0.0, 374.1]
+ - [-160.855, 105.6, 374.1]
+ - [-160.855, -105.6, 423.98]
+ - [-160.855, 0.0, 423.98]
+ - [-160.855, 105.6, 423.98]
+ - [-160.855, -52.8, 473.86]
+ - [-160.855, 52.8, 473.86]
+ - [-160.855, -52.8, 523.74]
+ - [-160.855, 52.8, 523.74]
+ - [-160.855, -105.6, 573.62]
+ - [-160.855, 0.0, 573.62]
+ - [-160.855, 105.6, 573.62]
+ - [-160.855, -105.6, 623.5]
+ - [-160.855, 0.0, 623.5]
+ - [-160.855, 105.6, 623.5]
+ - [-160.855, -52.8, 673.38]
+ - [-160.855, 52.8, 673.38]
+ - [-160.855, -52.8, 723.26]
+ - [-160.855, 52.8, 723.26]
+ - [-160.855, -105.6, 773.14]
+ - [-160.855, 0.0, 773.14]
+ - [-160.855, 105.6, 773.14]
+ - [-160.855, -105.6, 823.02]
+ - [-160.855, 0.0, 823.02]
+ - [-160.855, 105.6, 823.02]
+ - [-160.855, -52.8, 872.9]
+ - [-160.855, 52.8, 872.9]
+ - [160.855, -52.8, -872.9]
+ - [160.855, 52.8, -872.9]
+ - [160.855, -105.6, -823.02]
+ - [160.855, 0.0, -823.02]
+ - [160.855, 105.6, -823.02]
+ - [160.855, -105.6, -773.14]
+ - [160.855, 0.0, -773.14]
+ - [160.855, 105.6, -773.14]
+ - [160.855, -52.8, -723.26]
+ - [160.855, 52.8, -723.26]
+ - [160.855, -52.8, -673.38]
+ - [160.855, 52.8, -673.38]
+ - [160.855, -105.6, -623.5]
+ - [160.855, 0.0, -623.5]
+ - [160.855, 105.6, -623.5]
+ - [160.855, -105.6, -573.62]
+ - [160.855, 0.0, -573.62]
+ - [160.855, 105.6, -573.62]
+ - [160.855, -52.8, -523.74]
+ - [160.855, 52.8, -523.74]
+ - [160.855, -52.8, -473.86]
+ - [160.855, 52.8, -473.86]
+ - [160.855, -105.6, -423.98]
+ - [160.855, 0.0, -423.98]
+ - [160.855, 105.6, -423.98]
+ - [160.855, -105.6, -374.1]
+ - [160.855, 0.0, -374.1]
+ - [160.855, 105.6, -374.1]
+ - [160.855, -52.8, -324.22]
+ - [160.855, 52.8, -324.22]
+ - [160.855, -52.8, -274.34]
+ - [160.855, 52.8, -274.34]
+ - [160.855, -105.6, -224.46]
+ - [160.855, 0.0, -224.46]
+ - [160.855, 105.6, -224.46]
+ - [160.855, -105.6, -174.58]
+ - [160.855, 0.0, -174.58]
+ - [160.855, 105.6, -174.58]
+ - [160.855, -52.8, -124.7]
+ - [160.855, 52.8, -124.7]
+ - [160.855, -52.8, -74.82]
+ - [160.855, 52.8, -74.82]
+ - [160.855, -105.6, -24.94]
+ - [160.855, 0.0, -24.94]
+ - [160.855, 105.6, -24.94]
+ - [160.855, -105.6, 24.94]
+ - [160.855, 0.0, 24.94]
+ - [160.855, 105.6, 24.94]
+ - [160.855, -52.8, 74.82]
+ - [160.855, 52.8, 74.82]
+ - [160.855, -52.8, 124.7]
+ - [160.855, 52.8, 124.7]
+ - [160.855, -105.6, 174.58]
+ - [160.855, 0.0, 174.58]
+ - [160.855, 105.6, 174.58]
+ - [160.855, -105.6, 224.46]
+ - [160.855, 0.0, 224.46]
+ - [160.855, 105.6, 224.46]
+ - [160.855, -52.8, 274.34]
+ - [160.855, 52.8, 274.34]
+ - [160.855, -52.8, 324.22]
+ - [160.855, 52.8, 324.22]
+ - [160.855, -105.6, 374.1]
+ - [160.855, 0.0, 374.1]
+ - [160.855, 105.6, 374.1]
+ - [160.855, -105.6, 423.98]
+ - [160.855, 0.0, 423.98]
+ - [160.855, 105.6, 423.98]
+ - [160.855, -52.8, 473.86]
+ - [160.855, 52.8, 473.86]
+ - [160.855, -52.8, 523.74]
+ - [160.855, 52.8, 523.74]
+ - [160.855, -105.6, 573.62]
+ - [160.855, 0.0, 573.62]
+ - [160.855, 105.6, 573.62]
+ - [160.855, -105.6, 623.5]
+ - [160.855, 0.0, 623.5]
+ - [160.855, 105.6, 623.5]
+ - [160.855, -52.8, 673.38]
+ - [160.855, 52.8, 673.38]
+ - [160.855, -52.8, 723.26]
+ - [160.855, 52.8, 723.26]
+ - [160.855, -105.6, 773.14]
+ - [160.855, 0.0, 773.14]
+ - [160.855, 105.6, 773.14]
+ - [160.855, -105.6, 823.02]
+ - [160.855, 0.0, 823.02]
+ - [160.855, 105.6, 823.02]
+ - [160.855, -52.8, 872.9]
+ - [160.855, 52.8, 872.9]
+crt:
+ logical_ids: [30, 31, 32, 33, 34, 40, 41, 42, 43, 44, 45, 46, 47, 50]
+ norms: [1, 0, 0, 2, 2, 0, 0, 0, 0, 0, 0, 2, 2, 1]
+ dimensions:
+ - [1100.0, 2.5, 2560.0]
+ - [2.5, 162.0, 2560.0]
+ - [2.5, 162.0, 2560.0]
+ - [897.0, 162.0, 2.5]
+ - [1083.0, 160.0, 2.5]
+ - [8.27, 750.0, 1546.68]
+ - [8.27, 670.0, 800.0]
+ - [8.27, 750.0, 800.0]
+ - [8.27, 680.0, 1546.68]
+ - [8.27, 620.0, 800.0]
+ - [8.27, 730.0, 800.0]
+ - [1000.0, 800.0, 14.41]
+ - [1016.55, 480.0, 8.27]
+ - [1100.0, 2.0, 2560.0]
+ positions:
+ - [0.0, 617.389, 150.0]
+ - [555.265, 496.0, 150.0]
+ - [-555.265, 496.0, 150.0]
+ - [-91.5, 496.0, -1143.4]
+ - [0.5, 525.0, 1533.608]
+ - [530.355, 75.0, -400.0]
+ - [560.19, 55.0, -7.0]
+ - [530.355, 75.0, 759.34]
+ - [-530.355, 80.0, -400.0]
+ - [-560.19, 90.0, -7.0]
+ - [-530.355, 85.0, 759.34]
+ - [0.0, 100.0, -1127.535]
+ - [0.0, 80.0, 1173.855]
+ - [0.0, -346.652, 150.0]
diff --git a/spine/utils/geo/source/icarus_opdets.npy b/spine/utils/geo/source/icarus_opdets.npy
deleted file mode 100644
index c96e8b69..00000000
Binary files a/spine/utils/geo/source/icarus_opdets.npy and /dev/null differ
diff --git a/spine/utils/geo/source/icarus_sources.npy b/spine/utils/geo/source/icarus_sources.npy
deleted file mode 100644
index 7ec32031..00000000
Binary files a/spine/utils/geo/source/icarus_sources.npy and /dev/null differ
diff --git a/spine/utils/geo/source/ndlar_boundaries.npy b/spine/utils/geo/source/ndlar_boundaries.npy
deleted file mode 120000
index 21f8f62a..00000000
--- a/spine/utils/geo/source/ndlar_boundaries.npy
+++ /dev/null
@@ -1 +0,0 @@
-ndlar_v1_boundaries.npy
\ No newline at end of file
diff --git a/spine/utils/geo/source/ndlar_geometry.yaml b/spine/utils/geo/source/ndlar_geometry.yaml
new file mode 120000
index 00000000..f5d9c9d7
--- /dev/null
+++ b/spine/utils/geo/source/ndlar_geometry.yaml
@@ -0,0 +1 @@
+ndlar_v1_geometry.yaml
\ No newline at end of file
diff --git a/spine/utils/geo/source/ndlar_v0_boundaries.npy b/spine/utils/geo/source/ndlar_v0_boundaries.npy
deleted file mode 100644
index 06f10d9b..00000000
Binary files a/spine/utils/geo/source/ndlar_v0_boundaries.npy and /dev/null differ
diff --git a/spine/utils/geo/source/ndlar_v0_geometry.yaml b/spine/utils/geo/source/ndlar_v0_geometry.yaml
new file mode 100644
index 00000000..49018101
--- /dev/null
+++ b/spine/utils/geo/source/ndlar_v0_geometry.yaml
@@ -0,0 +1,77 @@
+tpc:
+ dimensions: [50.4, 304.0, 97.28]
+ module_ids: [0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9, 9, 10, 10,
+ 11, 11, 12, 12, 13, 13, 14, 14, 15, 15, 16, 16, 17, 17, 18, 18, 19, 19, 20, 20,
+ 21, 21, 22, 22, 23, 23, 24, 24, 25, 25, 26, 26, 27, 27, 28, 28, 29, 29, 30, 30,
+ 31, 31, 32, 32, 33, 33, 34, 34]
+ positions:
+ - [-331.5, 3.387, 462.36]
+ - [-281.1, 3.387, 462.36]
+ - [-331.5, 3.387, 563.78]
+ - [-281.1, 3.387, 563.78]
+ - [-331.5, 3.387, 665.2]
+ - [-281.1, 3.387, 665.2]
+ - [-331.5, 3.387, 766.62]
+ - [-281.1, 3.387, 766.62]
+ - [-331.5, 3.387, 868.04]
+ - [-281.1, 3.387, 868.04]
+ - [-229.4, 3.387, 462.36]
+ - [-179.0, 3.387, 462.36]
+ - [-229.4, 3.387, 563.78]
+ - [-179.0, 3.387, 563.78]
+ - [-229.4, 3.387, 665.2]
+ - [-179.0, 3.387, 665.2]
+ - [-229.4, 3.387, 766.62]
+ - [-179.0, 3.387, 766.62]
+ - [-229.4, 3.387, 868.04]
+ - [-179.0, 3.387, 868.04]
+ - [-127.3, 3.387, 462.36]
+ - [-76.9, 3.387, 462.36]
+ - [-127.3, 3.387, 563.78]
+ - [-76.9, 3.387, 563.78]
+ - [-127.3, 3.387, 665.2]
+ - [-76.9, 3.387, 665.2]
+ - [-127.3, 3.387, 766.62]
+ - [-76.9, 3.387, 766.62]
+ - [-127.3, 3.387, 868.04]
+ - [-76.9, 3.387, 868.04]
+ - [-25.2, 3.387, 462.36]
+ - [25.2, 3.387, 462.36]
+ - [-25.2, 3.387, 563.78]
+ - [25.2, 3.387, 563.78]
+ - [-25.2, 3.387, 665.2]
+ - [25.2, 3.387, 665.2]
+ - [-25.2, 3.387, 766.62]
+ - [25.2, 3.387, 766.62]
+ - [-25.2, 3.387, 868.04]
+ - [25.2, 3.387, 868.04]
+ - [76.9, 3.387, 462.36]
+ - [127.3, 3.387, 462.36]
+ - [76.9, 3.387, 563.78]
+ - [127.3, 3.387, 563.78]
+ - [76.9, 3.387, 665.2]
+ - [127.3, 3.387, 665.2]
+ - [76.9, 3.387, 766.62]
+ - [127.3, 3.387, 766.62]
+ - [76.9, 3.387, 868.04]
+ - [127.3, 3.387, 868.04]
+ - [179.0, 3.387, 462.36]
+ - [229.4, 3.387, 462.36]
+ - [179.0, 3.387, 563.78]
+ - [229.4, 3.387, 563.78]
+ - [179.0, 3.387, 665.2]
+ - [229.4, 3.387, 665.2]
+ - [179.0, 3.387, 766.62]
+ - [229.4, 3.387, 766.62]
+ - [179.0, 3.387, 868.04]
+ - [229.4, 3.387, 868.04]
+ - [281.1, 3.387, 462.36]
+ - [331.5, 3.387, 462.36]
+ - [281.1, 3.387, 563.78]
+ - [331.5, 3.387, 563.78]
+ - [281.1, 3.387, 665.2]
+ - [331.5, 3.387, 665.2]
+ - [281.1, 3.387, 766.62]
+ - [331.5, 3.387, 766.62]
+ - [281.1, 3.387, 868.04]
+ - [331.5, 3.387, 868.04]
diff --git a/spine/utils/geo/source/ndlar_v1_boundaries.npy b/spine/utils/geo/source/ndlar_v1_boundaries.npy
deleted file mode 100644
index 913f9e90..00000000
Binary files a/spine/utils/geo/source/ndlar_v1_boundaries.npy and /dev/null differ
diff --git a/spine/utils/geo/source/ndlar_v1_geometry.yaml b/spine/utils/geo/source/ndlar_v1_geometry.yaml
new file mode 100644
index 00000000..20af08e9
--- /dev/null
+++ b/spine/utils/geo/source/ndlar_v1_geometry.yaml
@@ -0,0 +1,77 @@
+tpc:
+ dimensions: [50.4, 304.0, 97.28]
+ module_ids: [0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9, 9, 10, 10,
+ 11, 11, 12, 12, 13, 13, 14, 14, 15, 15, 16, 16, 17, 17, 18, 18, 19, 19, 20, 20,
+ 21, 21, 22, 22, 23, 23, 24, 24, 25, 25, 26, 26, 27, 27, 28, 28, 29, 29, 30, 30,
+ 31, 31, 32, 32, 33, 33, 34, 34]
+ positions:
+ - [-325.2, -62.929, 465.7559]
+ - [-274.8, -62.929, 465.7559]
+ - [-325.2, -62.929, 565.7559]
+ - [-274.8, -62.929, 565.7559]
+ - [-325.2, -62.929, 665.7559]
+ - [-274.8, -62.929, 665.7559]
+ - [-325.2, -62.929, 765.7559]
+ - [-274.8, -62.929, 765.7559]
+ - [-325.2, -62.929, 865.7559]
+ - [-274.8, -62.929, 865.7559]
+ - [-225.2, -62.929, 465.7559]
+ - [-174.8, -62.929, 465.7559]
+ - [-225.2, -62.929, 565.7559]
+ - [-174.8, -62.929, 565.7559]
+ - [-225.2, -62.929, 665.7559]
+ - [-174.8, -62.929, 665.7559]
+ - [-225.2, -62.929, 765.7559]
+ - [-174.8, -62.929, 765.7559]
+ - [-225.2, -62.929, 865.7559]
+ - [-174.8, -62.929, 865.7559]
+ - [-125.2, -62.929, 465.7559]
+ - [-74.8, -62.929, 465.7559]
+ - [-125.2, -62.929, 565.7559]
+ - [-74.8, -62.929, 565.7559]
+ - [-125.2, -62.929, 665.7559]
+ - [-74.8, -62.929, 665.7559]
+ - [-125.2, -62.929, 765.7559]
+ - [-74.8, -62.929, 765.7559]
+ - [-125.2, -62.929, 865.7559]
+ - [-74.8, -62.929, 865.7559]
+ - [-25.2, -62.929, 465.7559]
+ - [25.2, -62.929, 465.7559]
+ - [-25.2, -62.929, 565.7559]
+ - [25.2, -62.929, 565.7559]
+ - [-25.2, -62.929, 665.7559]
+ - [25.2, -62.929, 665.7559]
+ - [-25.2, -62.929, 765.7559]
+ - [25.2, -62.929, 765.7559]
+ - [-25.2, -62.929, 865.7559]
+ - [25.2, -62.929, 865.7559]
+ - [74.8, -62.929, 465.7559]
+ - [125.2, -62.929, 465.7559]
+ - [74.8, -62.929, 565.7559]
+ - [125.2, -62.929, 565.7559]
+ - [74.8, -62.929, 665.7559]
+ - [125.2, -62.929, 665.7559]
+ - [74.8, -62.929, 765.7559]
+ - [125.2, -62.929, 765.7559]
+ - [74.8, -62.929, 865.7559]
+ - [125.2, -62.929, 865.7559]
+ - [174.8, -62.929, 465.7559]
+ - [225.2, -62.929, 465.7559]
+ - [174.8, -62.929, 565.7559]
+ - [225.2, -62.929, 565.7559]
+ - [174.8, -62.929, 665.7559]
+ - [225.2, -62.929, 665.7559]
+ - [174.8, -62.929, 765.7559]
+ - [225.2, -62.929, 765.7559]
+ - [174.8, -62.929, 865.7559]
+ - [225.2, -62.929, 865.7559]
+ - [274.8, -62.929, 465.7559]
+ - [325.2, -62.929, 465.7559]
+ - [274.8, -62.929, 565.7559]
+ - [325.2, -62.929, 565.7559]
+ - [274.8, -62.929, 665.7559]
+ - [325.2, -62.929, 665.7559]
+ - [274.8, -62.929, 765.7559]
+ - [325.2, -62.929, 765.7559]
+ - [274.8, -62.929, 865.7559]
+ - [325.2, -62.929, 865.7559]
diff --git a/spine/utils/geo/source/sbnd_boundaries.npy b/spine/utils/geo/source/sbnd_boundaries.npy
deleted file mode 100644
index beb9f15b..00000000
Binary files a/spine/utils/geo/source/sbnd_boundaries.npy and /dev/null differ
diff --git a/spine/utils/geo/source/sbnd_geometry.yaml b/spine/utils/geo/source/sbnd_geometry.yaml
new file mode 100644
index 00000000..5df9531a
--- /dev/null
+++ b/spine/utils/geo/source/sbnd_geometry.yaml
@@ -0,0 +1,339 @@
+tpc:
+ dimensions: [201.3, 400.016, 499.51562]
+ module_ids: [0, 0]
+ positions:
+ - [-100.65, 0.0, 254.70019]
+ - [100.65, 0.0, 254.70019]
+optical:
+ volume: module
+ shape: [box, ellipsoid]
+ global_index: true
+ dimensions:
+ - [0.2, 20.6, 7.6]
+ - [10.8, 20.2, 20.2]
+ shape_ids: [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0,
+ 0, 0, 0]
+ positions:
+ - [-213.75, -135.0, -234.44599]
+ - [213.75, -135.0, -234.44599]
+ - [-213.75, 0.0, -234.44599]
+ - [213.75, 0.0, -234.44599]
+ - [-213.75, 135.0, -234.44599]
+ - [213.75, 135.0, -234.44599]
+ - [-213.4, -175.0, -226.82599]
+ - [213.4, -175.0, -226.82599]
+ - [-213.4, -95.0, -226.82599]
+ - [213.4, -95.0, -226.82599]
+ - [-213.4, -40.0, -226.82599]
+ - [213.4, -40.0, -226.82599]
+ - [-213.4, 40.0, -226.82599]
+ - [213.4, 40.0, -226.82599]
+ - [-213.4, 95.0, -226.82599]
+ - [213.4, 95.0, -226.82599]
+ - [-213.4, 175.0, -226.82599]
+ - [213.4, 175.0, -226.82599]
+ - [-213.75, -135.0, -219.20599]
+ - [213.75, -135.0, -219.20599]
+ - [-213.75, 0.0, -219.20599]
+ - [213.75, 0.0, -219.20599]
+ - [-213.75, 135.0, -219.20599]
+ - [213.75, 135.0, -219.20599]
+ - [-213.75, -175.0, -204.44599]
+ - [213.75, -175.0, -204.44599]
+ - [-213.75, -95.0, -204.44599]
+ - [213.75, -95.0, -204.44599]
+ - [-213.75, -40.0, -204.44599]
+ - [213.75, -40.0, -204.44599]
+ - [-213.75, 40.0, -204.44599]
+ - [213.75, 40.0, -204.44599]
+ - [-213.75, 95.0, -204.44599]
+ - [213.75, 95.0, -204.44599]
+ - [-213.75, 175.0, -204.44599]
+ - [213.75, 175.0, -204.44599]
+ - [-213.4, -135.0, -196.82599]
+ - [213.4, -135.0, -196.82599]
+ - [-213.4, 0.0, -196.82599]
+ - [213.4, 0.0, -196.82599]
+ - [-213.4, 135.0, -196.82599]
+ - [213.4, 135.0, -196.82599]
+ - [-213.75, -175.0, -189.20599]
+ - [213.75, -175.0, -189.20599]
+ - [-213.75, -95.0, -189.20599]
+ - [213.75, -95.0, -189.20599]
+ - [-213.75, -40.0, -189.20599]
+ - [213.75, -40.0, -189.20599]
+ - [-213.75, 40.0, -189.20599]
+ - [213.75, 40.0, -189.20599]
+ - [-213.75, 95.0, -189.20599]
+ - [213.75, 95.0, -189.20599]
+ - [-213.75, 175.0, -189.20599]
+ - [213.75, 175.0, -189.20599]
+ - [-213.75, -135.0, -174.44599]
+ - [213.75, -135.0, -174.44599]
+ - [-213.75, 0.0, -174.44599]
+ - [213.75, 0.0, -174.44599]
+ - [-213.75, 135.0, -174.44599]
+ - [213.75, 135.0, -174.44599]
+ - [-213.4, -175.0, -166.82599]
+ - [213.4, -175.0, -166.82599]
+ - [-213.4, -95.0, -166.82599]
+ - [213.4, -95.0, -166.82599]
+ - [-213.4, -40.0, -166.82599]
+ - [213.4, -40.0, -166.82599]
+ - [-213.4, 40.0, -166.82599]
+ - [213.4, 40.0, -166.82599]
+ - [-213.4, 95.0, -166.82599]
+ - [213.4, 95.0, -166.82599]
+ - [-213.4, 175.0, -166.82599]
+ - [213.4, 175.0, -166.82599]
+ - [-213.75, -135.0, -159.20599]
+ - [213.75, -135.0, -159.20599]
+ - [-213.75, 0.0, -159.20599]
+ - [213.75, 0.0, -159.20599]
+ - [-213.75, 135.0, -159.20599]
+ - [213.75, 135.0, -159.20599]
+ - [-213.75, -135.0, -101.16219]
+ - [213.75, -135.0, -101.16219]
+ - [-213.75, 0.0, -101.16219]
+ - [213.75, 0.0, -101.16219]
+ - [-213.75, 135.0, -101.16219]
+ - [213.75, 135.0, -101.16219]
+ - [-213.4, -175.0, -93.54219]
+ - [213.4, -175.0, -93.54219]
+ - [-213.4, -95.0, -93.54219]
+ - [213.4, -95.0, -93.54219]
+ - [-213.4, -40.0, -93.54219]
+ - [213.4, -40.0, -93.54219]
+ - [-213.4, 40.0, -93.54219]
+ - [213.4, 40.0, -93.54219]
+ - [-213.4, 95.0, -93.54219]
+ - [213.4, 95.0, -93.54219]
+ - [-213.4, 175.0, -93.54219]
+ - [213.4, 175.0, -93.54219]
+ - [-213.75, -135.0, -85.92219]
+ - [213.75, -135.0, -85.92219]
+ - [-213.75, 0.0, -85.92219]
+ - [213.75, 0.0, -85.92219]
+ - [-213.75, 135.0, -85.92219]
+ - [213.75, 135.0, -85.92219]
+ - [-213.75, -175.0, -71.16219]
+ - [213.75, -175.0, -71.16219]
+ - [-213.75, -95.0, -71.16219]
+ - [213.75, -95.0, -71.16219]
+ - [-213.75, -40.0, -71.16219]
+ - [213.75, -40.0, -71.16219]
+ - [-213.75, 40.0, -71.16219]
+ - [213.75, 40.0, -71.16219]
+ - [-213.75, 95.0, -71.16219]
+ - [213.75, 95.0, -71.16219]
+ - [-213.75, 175.0, -71.16219]
+ - [213.75, 175.0, -71.16219]
+ - [-213.4, -135.0, -63.54219]
+ - [213.4, -135.0, -63.54219]
+ - [-213.4, 0.0, -63.54219]
+ - [213.4, 0.0, -63.54219]
+ - [-213.4, 135.0, -63.54219]
+ - [213.4, 135.0, -63.54219]
+ - [-213.75, -175.0, -55.92219]
+ - [213.75, -175.0, -55.92219]
+ - [-213.75, -95.0, -55.92219]
+ - [213.75, -95.0, -55.92219]
+ - [-213.75, -40.0, -55.92219]
+ - [213.75, -40.0, -55.92219]
+ - [-213.75, 40.0, -55.92219]
+ - [213.75, 40.0, -55.92219]
+ - [-213.75, 95.0, -55.92219]
+ - [213.75, 95.0, -55.92219]
+ - [-213.75, 175.0, -55.92219]
+ - [213.75, 175.0, -55.92219]
+ - [-213.75, -135.0, -41.16219]
+ - [213.75, -135.0, -41.16219]
+ - [-213.75, 0.0, -41.16219]
+ - [213.75, 0.0, -41.16219]
+ - [-213.75, 135.0, -41.16219]
+ - [213.75, 135.0, -41.16219]
+ - [-213.4, -175.0, -33.54219]
+ - [213.4, -175.0, -33.54219]
+ - [-213.4, -95.0, -33.54219]
+ - [213.4, -95.0, -33.54219]
+ - [-213.4, -40.0, -33.54219]
+ - [213.4, -40.0, -33.54219]
+ - [-213.4, 40.0, -33.54219]
+ - [213.4, 40.0, -33.54219]
+ - [-213.4, 95.0, -33.54219]
+ - [213.4, 95.0, -33.54219]
+ - [-213.4, 175.0, -33.54219]
+ - [213.4, 175.0, -33.54219]
+ - [-213.75, -135.0, -25.92219]
+ - [213.75, -135.0, -25.92219]
+ - [-213.75, 0.0, -25.92219]
+ - [213.75, 0.0, -25.92219]
+ - [-213.75, 135.0, -25.92219]
+ - [213.75, 135.0, -25.92219]
+ - [-213.75, -135.0, 25.92181]
+ - [213.75, -135.0, 25.92181]
+ - [-213.75, 0.0, 25.92181]
+ - [213.75, 0.0, 25.92181]
+ - [-213.75, 135.0, 25.92181]
+ - [213.75, 135.0, 25.92181]
+ - [-213.4, -175.0, 33.54181]
+ - [213.4, -175.0, 33.54181]
+ - [-213.4, -95.0, 33.54181]
+ - [213.4, -95.0, 33.54181]
+ - [-213.4, -40.0, 33.54181]
+ - [213.4, -40.0, 33.54181]
+ - [-213.4, 40.0, 33.54181]
+ - [213.4, 40.0, 33.54181]
+ - [-213.4, 95.0, 33.54181]
+ - [213.4, 95.0, 33.54181]
+ - [-213.4, 175.0, 33.54181]
+ - [213.4, 175.0, 33.54181]
+ - [-213.75, -135.0, 41.16181]
+ - [213.75, -135.0, 41.16181]
+ - [-213.75, 0.0, 41.16181]
+ - [213.75, 0.0, 41.16181]
+ - [-213.75, 135.0, 41.16181]
+ - [213.75, 135.0, 41.16181]
+ - [-213.75, -175.0, 55.92181]
+ - [213.75, -175.0, 55.92181]
+ - [-213.75, -95.0, 55.92181]
+ - [213.75, -95.0, 55.92181]
+ - [-213.75, -40.0, 55.92181]
+ - [213.75, -40.0, 55.92181]
+ - [-213.75, 40.0, 55.92181]
+ - [213.75, 40.0, 55.92181]
+ - [-213.75, 95.0, 55.92181]
+ - [213.75, 95.0, 55.92181]
+ - [-213.75, 175.0, 55.92181]
+ - [213.75, 175.0, 55.92181]
+ - [-213.4, -135.0, 63.54181]
+ - [213.4, -135.0, 63.54181]
+ - [-213.4, 0.0, 63.54181]
+ - [213.4, 0.0, 63.54181]
+ - [-213.4, 135.0, 63.54181]
+ - [213.4, 135.0, 63.54181]
+ - [-213.75, -175.0, 71.16181]
+ - [213.75, -175.0, 71.16181]
+ - [-213.75, -95.0, 71.16181]
+ - [213.75, -95.0, 71.16181]
+ - [-213.75, -40.0, 71.16181]
+ - [213.75, -40.0, 71.16181]
+ - [-213.75, 40.0, 71.16181]
+ - [213.75, 40.0, 71.16181]
+ - [-213.75, 95.0, 71.16181]
+ - [213.75, 95.0, 71.16181]
+ - [-213.75, 175.0, 71.16181]
+ - [213.75, 175.0, 71.16181]
+ - [-213.75, -135.0, 85.92181]
+ - [213.75, -135.0, 85.92181]
+ - [-213.75, 0.0, 85.92181]
+ - [213.75, 0.0, 85.92181]
+ - [-213.75, 135.0, 85.92181]
+ - [213.75, 135.0, 85.92181]
+ - [-213.4, -175.0, 93.54181]
+ - [213.4, -175.0, 93.54181]
+ - [-213.4, -95.0, 93.54181]
+ - [213.4, -95.0, 93.54181]
+ - [-213.4, -40.0, 93.54181]
+ - [213.4, -40.0, 93.54181]
+ - [-213.4, 40.0, 93.54181]
+ - [213.4, 40.0, 93.54181]
+ - [-213.4, 95.0, 93.54181]
+ - [213.4, 95.0, 93.54181]
+ - [-213.4, 175.0, 93.54181]
+ - [213.4, 175.0, 93.54181]
+ - [-213.75, -135.0, 101.16181]
+ - [213.75, -135.0, 101.16181]
+ - [-213.75, 0.0, 101.16181]
+ - [213.75, 0.0, 101.16181]
+ - [-213.75, 135.0, 101.16181]
+ - [213.75, 135.0, 101.16181]
+ - [-213.75, -135.0, 159.20581]
+ - [213.75, -135.0, 159.20581]
+ - [-213.75, 0.0, 159.20581]
+ - [213.75, 0.0, 159.20581]
+ - [-213.75, 135.0, 159.20581]
+ - [213.75, 135.0, 159.20581]
+ - [-213.4, -175.0, 166.82581]
+ - [213.4, -175.0, 166.82581]
+ - [-213.4, -95.0, 166.82581]
+ - [213.4, -95.0, 166.82581]
+ - [-213.4, -40.0, 166.82581]
+ - [213.4, -40.0, 166.82581]
+ - [-213.4, 40.0, 166.82581]
+ - [213.4, 40.0, 166.82581]
+ - [-213.4, 95.0, 166.82581]
+ - [213.4, 95.0, 166.82581]
+ - [-213.4, 175.0, 166.82581]
+ - [213.4, 175.0, 166.82581]
+ - [-213.75, -135.0, 174.44581]
+ - [213.75, -135.0, 174.44581]
+ - [-213.75, 0.0, 174.44581]
+ - [213.75, 0.0, 174.44581]
+ - [-213.75, 135.0, 174.44581]
+ - [213.75, 135.0, 174.44581]
+ - [-213.75, -175.0, 189.20581]
+ - [213.75, -175.0, 189.20581]
+ - [-213.75, -95.0, 189.20581]
+ - [213.75, -95.0, 189.20581]
+ - [-213.75, -40.0, 189.20581]
+ - [213.75, -40.0, 189.20581]
+ - [-213.75, 40.0, 189.20581]
+ - [213.75, 40.0, 189.20581]
+ - [-213.75, 95.0, 189.20581]
+ - [213.75, 95.0, 189.20581]
+ - [-213.75, 175.0, 189.20581]
+ - [213.75, 175.0, 189.20581]
+ - [-213.4, -135.0, 196.82581]
+ - [213.4, -135.0, 196.82581]
+ - [-213.4, 0.0, 196.82581]
+ - [213.4, 0.0, 196.82581]
+ - [-213.4, 135.0, 196.82581]
+ - [213.4, 135.0, 196.82581]
+ - [-213.75, -175.0, 204.44581]
+ - [213.75, -175.0, 204.44581]
+ - [-213.75, -95.0, 204.44581]
+ - [213.75, -95.0, 204.44581]
+ - [-213.75, -40.0, 204.44581]
+ - [213.75, -40.0, 204.44581]
+ - [-213.75, 40.0, 204.44581]
+ - [213.75, 40.0, 204.44581]
+ - [-213.75, 95.0, 204.44581]
+ - [213.75, 95.0, 204.44581]
+ - [-213.75, 175.0, 204.44581]
+ - [213.75, 175.0, 204.44581]
+ - [-213.75, -135.0, 219.20581]
+ - [213.75, -135.0, 219.20581]
+ - [-213.75, 0.0, 219.20581]
+ - [213.75, 0.0, 219.20581]
+ - [-213.75, 135.0, 219.20581]
+ - [213.75, 135.0, 219.20581]
+ - [-213.4, -175.0, 226.82581]
+ - [213.4, -175.0, 226.82581]
+ - [-213.4, -95.0, 226.82581]
+ - [213.4, -95.0, 226.82581]
+ - [-213.4, -40.0, 226.82581]
+ - [213.4, -40.0, 226.82581]
+ - [-213.4, 40.0, 226.82581]
+ - [213.4, 40.0, 226.82581]
+ - [-213.4, 95.0, 226.82581]
+ - [213.4, 95.0, 226.82581]
+ - [-213.4, 175.0, 226.82581]
+ - [213.4, 175.0, 226.82581]
+ - [-213.75, -135.0, 234.44581]
+ - [213.75, -135.0, 234.44581]
+ - [-213.75, 0.0, 234.44581]
+ - [213.75, 0.0, 234.44581]
+ - [-213.75, 135.0, 234.44581]
+ - [213.75, 135.0, 234.44581]
diff --git a/spine/utils/geo/source/sbnd_sources.npy b/spine/utils/geo/source/sbnd_sources.npy
deleted file mode 100644
index 89152193..00000000
Binary files a/spine/utils/geo/source/sbnd_sources.npy and /dev/null differ
diff --git a/spine/utils/ppn.py b/spine/utils/ppn.py
index 6e915387..9b3c05cc 100644
--- a/spine/utils/ppn.py
+++ b/spine/utils/ppn.py
@@ -546,6 +546,67 @@ def get_ppn_labels(particle_v, meta, dtype, dim=3, min_voxel_count=1,
return np.array(part_info, dtype=dtype)
+def get_vertex_labels(particle_v, neutrino_v, meta, dtype):
+ """Gets particle vertex coordinates.
+
+ It provides the coordinates of points where multiple particles originate:
+ - If the `neutrino_event` is provided, it simply uses the coordinates of
+ the neutrino interaction points.
+ - If the `particle_event` is provided instead, it looks for ancestor point
+ positions shared by at least two **primary** particles.
+
+ Parameters
+ ----------
+ particle_v : List[larcv.Particle]
+ List of LArCV particle objects in the image
+ neutrino_v : List[larcv.Neutrino]
+ List of LArCV neutrino objects in the image
+ meta : larcv::Voxel3DMeta or larcv::ImageMeta
+ Metadata information
+ dtype : str
+        Typing of the output vertex labels
+
+ Returns
+ -------
+ np.array
+ Array of points of shape (N, 4) where 4 = x, y, z, vertex_id
+ """
+ # If the particles are provided, find unique ancestors
+ vertexes = []
+ if particle_v is not None:
+ # Fetch all ancestor positions of primary particles
+ anc_positions = []
+ for i, p in enumerate(particle_v):
+ if p.parent_id() == p.id() or p.ancestor_pdg_code() == 111:
+ if image_contains(meta, p.ancestor_position()):
+ anc_pos = image_coordinates(meta, p.ancestor_position())
+ anc_positions.append(anc_pos)
+
+ # If there is no primary, nothing to do
+ if not len(anc_positions):
+ return np.empty((0, 4), dtype=dtype)
+
+ # Find those that appear > once
+ anc_positions = np.vstack(anc_positions)
+ unique_positions, counts = np.unique(
+ anc_positions, return_counts=True, axis=0)
+ for i, idx in enumerate(np.where(counts > 1)[0]):
+ vertexes.append([*unique_positions[idx], i])
+
+ # If the neutrinos are provided, straightforward
+ if neutrino_v is not None:
+ for i, n in enumerate(neutrino_v):
+ if image_contains(meta, n.position()):
+ nu_pos = image_coordinates(meta, n.position())
+ vertexes.append([*nu_pos, i])
+
+    # If there are no vertexes, nothing to do
+ if not len(vertexes):
+ return np.empty((0, 4), dtype=dtype)
+
+ return np.vstack(vertexes).astype(dtype)
+
+
def image_contains(meta, point, dim=3):
"""Checks whether a point is contained in the image box defined by meta.
diff --git a/spine/utils/unwrap.py b/spine/utils/unwrap.py
index 08ad41ab..bf9a46e5 100644
--- a/spine/utils/unwrap.py
+++ b/spine/utils/unwrap.py
@@ -33,7 +33,7 @@ def __init__(self, geometry=None, remove_batch_col=False):
Remove column which specifies batch ID from the unwrapped tensors
"""
self.geo = geometry
- self.num_volumes = self.geo.num_modules if self.geo else 1
+ self.num_volumes = self.geo.tpc.num_modules if self.geo else 1
self.remove_batch_col = remove_batch_col
def __call__(self, data):
diff --git a/spine/vis/__init__.py b/spine/vis/__init__.py
index 51b58dab..17c7baea 100644
--- a/spine/vis/__init__.py
+++ b/spine/vis/__init__.py
@@ -1,12 +1,12 @@
"""Module which centralizes all tools used to visualize data."""
from .out import Drawer
+from .geo import GeoDrawer
from .train import TrainDrawer
from .point import scatter_points
from .cluster import scatter_clusters
from .box import scatter_boxes
from .particle import scatter_particles
from .network import network_topology, network_schematic
-from .detector import detector_traces
from .evaluation import heatmap, annotate_heatmap
from .layout import layout3d, dual_figure3d
diff --git a/spine/vis/box.py b/spine/vis/box.py
index babdbdbb..ef2890f2 100644
--- a/spine/vis/box.py
+++ b/spine/vis/box.py
@@ -130,7 +130,9 @@ def box_trace(lower, upper, draw_faces=False, line=None, linewidth=None,
def box_traces(lowers, uppers, draw_faces=False, color=None, linewidth=None,
- hovertext=None, shared_legend=True, name=None, **kwargs):
+ hovertext=None, cmin=None, cmax=None, shared_legend=True,
+ legendgroup=None, showlegend=True, group_name=None, name=None,
+ **kwargs):
"""Function which produces a list of plotly traces of boxes given a list of
lower bounds and upper bounds in x, y and z.
@@ -148,8 +150,16 @@ def box_traces(lowers, uppers, draw_faces=False, color=None, linewidth=None,
Width of the box edge lines
hovertext : Union[int, str, np.ndarray], optional
Text associated with every box or each box
+ cmin : float, optional
+ Minimum value along the color scale
+ cmax : float, optional
+ Maximum value along the color scale
shared_legend : bool, default True
If True, the plotly legend of all boxes is shared as one
+ legendgroup : str, optional
+ Legend group to be shared between all boxes
+ showlegend : bool, default `True`
+        Whether to show legends or not
name : str, optional
Name of the trace(s)
**kwargs : dict, optional
@@ -172,9 +182,24 @@ def box_traces(lowers, uppers, draw_faces=False, color=None, linewidth=None,
len(hovertext) == len(lowers)), (
"Specify one hovertext for all boxes, or one hovertext per box.")
+ # If one color is provided per box, give an associated hovertext
+ if hovertext is None and color is not None and not np.isscalar(color):
+ hovertext = [f'Value: {v:0.3f}' for v in color]
+
+ # If cmin/cmax are not provided, must build them so that all boxes
+ # share the same colorscale range (not guaranteed otherwise)
+ if color is not None and not np.isscalar(color) and len(color) > 0:
+ if cmin is None:
+ cmin = np.min(color)
+ if cmax is None:
+ cmax = np.max(color)
+
+ # If the legend is to be shared, make sure there is a common legend group
+ if shared_legend and legendgroup is None:
+ legendgroup = 'group_' + str(time.time())
+
# Loop over the list of box boundaries
traces = []
- group_name = 'group_' + str(time.time())
for i, (lower, upper) in enumerate(zip(lowers, uppers)):
# Fetch the right color/hovertext combination
col, hov = color, hovertext
@@ -184,17 +209,16 @@ def box_traces(lowers, uppers, draw_faces=False, color=None, linewidth=None,
hov = hovertext[i]
# If the legend is shared, only draw the legend of the first trace
- legendgroup, showlegend, name_i = None, True, name
if shared_legend:
- legendgroup = group_name
- showlegend = i == 0
+ showlegend = showlegend and i == 0
+ name_i = name
else:
name_i = f'{name} {i}'
# Append list of traces
traces.append(box_trace(
- lower, upper, draw_faces, linewidth=linewidth,
- color=col, hovertext=hov, legendgroup=legendgroup,
+ lower, upper, draw_faces, linewidth=linewidth, color=col,
+ hovertext=hov, cmin=cmin, cmax=cmax, legendgroup=legendgroup,
showlegend=showlegend, name=name_i, **kwargs))
return traces
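
A hedged usage sketch of the updated `box_traces` signature (the values are made up): passing one scalar per box together with a pre-computed `cmin`/`cmax` keeps every box on the same colorscale, which is the point of the new arguments.

import numpy as np
from spine.vis.box import box_traces

lowers = np.array([[0., 0., 0.], [2., 0., 0.], [4., 0., 0.]])
uppers = lowers + 1.
values = np.array([0.2, 0.5, 0.9])  # one value per box

traces = box_traces(
    lowers, uppers, draw_faces=True, color=values,
    cmin=values.min(), cmax=values.max(), shared_legend=True, name='Boxes')
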
diff --git a/spine/vis/cluster.py b/spine/vis/cluster.py
index 73c93ba5..8a6da1e4 100644
--- a/spine/vis/cluster.py
+++ b/spine/vis/cluster.py
@@ -145,12 +145,17 @@ def scatter_clusters(points, clusts, color=None, hovertext=None,
# If cmin/cmax are not provided, must build them so that all clusters
# share the same colorscale range (not guaranteed otherwise)
- if color is not None and len(color) and cmin is None or cmax is None:
- if np.isscalar(color[0]):
- cmin, cmax = np.min(color), np.max(color)
- else:
- cmin = np.min(np.concatenate(color))
- cmax = np.max(np.concatenate(color))
+ if color is not None and not np.isscalar(color) and len(color) > 0:
+        if cmin is None:
+            if np.isscalar(color[0]):
+                cmin = np.min(color)
+            else:
+                cmin = np.min(np.concatenate(color))
+        if cmax is None:
+            if np.isscalar(color[0]):
+                cmax = np.max(color)
+            else:
+                cmax = np.max(np.concatenate(color))
# Loop over the list of clusters
traces = []
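
For reference, the corrected per-cluster colorscale range: when each cluster carries its own array of values, the global range must come from their concatenation, not from the ragged list itself.

import numpy as np

color = [np.array([0.1, 0.4]), np.array([0.2, 0.9, 0.5])]  # one array per cluster
cmin = np.min(np.concatenate(color))
cmax = np.max(np.concatenate(color))
print(cmin, cmax)  # 0.1 0.9
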
diff --git a/spine/vis/detector.py b/spine/vis/detector.py
deleted file mode 100644
index 37e524e7..00000000
--- a/spine/vis/detector.py
+++ /dev/null
@@ -1,68 +0,0 @@
-"""Draw detectors based on their geometry definition."""
-
-from spine.utils.geo import Geometry
-
-from .box import box_traces
-
-
-def detector_traces(detector=None, boundaries=None, meta=None,
- detector_coords=True, draw_faces=False, shared_legend=True,
- name='Detector', color='rgba(0,0,0,0.150)',
- linewidth=5, **kwargs):
- """Function which takes loads a file with detector boundaries and
- produces a list of traces which represent them in a 3D event display.
-
- The detector boundary file is a `.npy` or `.npz` file which contains
- a single tensor of shape (N, 3, 2), with N the number of detector
- volumes. The first column for each volume represents the lower boundary
- and the second the upper boundary. The boundaries must be ordered.
-
- Parameters
- ----------
- detector : str, optional
- Name of a recognized detector to the geometry from
- boundaries : str, optional
- Name of a recognized detector to get the geometry from or path
- to a `.npy` boundary file to load the boundaries from.
- meta : Meta, optional
- Metadata information (only needed if pixel_coordinates is True)
- detector_coords : bool, default False
- If False, the coordinates are converted to pixel indices
- draw_faces : bool, default False
- Weather or not to draw the box faces, or only the edges
- shared_legend : bool, default True
- If True, the legend entry in plotly is shared between all the
- detector volumes
- name : Union[str, List[str]], default 'Detector'
- Name(s) of the detector volumes
- color : Union[int, str, np.ndarray]
- Color of boxes or list of color of boxes
- linewidth : int, default 2
- Width of the box edge lines
- **kwargs : dict, optional
- List of additional arguments to pass to
- spine.viusalization.boxes.box_traces
-
- Returns
- -------
- List[Union[plotly.graph_objs.Scatter3D, plotly.graph_objs.Mesh3D]]
- List of detector traces (one per TPC)
- """
- # Load the list of boundaries
- boundaries = Geometry(detector, boundaries).tpcs
-
- # If required, convert to pixel coordinates
- if not detector_coords:
- assert meta is not None, (
- "Must provide meta information to convert the detector "
- "boundaries to pixel coordinates.")
- boundaries = meta.to_px(
- boundaries.transpose(0,2,1)).transpose(0,2,1)
-
- # Get a trace per detector volume
- detectors = box_traces(
- boundaries[..., 0], boundaries[..., 1], draw_faces=draw_faces,
- color=color, linewidth=linewidth, shared_legend=shared_legend,
- name=name, **kwargs)
-
- return detectors
diff --git a/spine/vis/ellipsoid.py b/spine/vis/ellipsoid.py
index 25229f3e..1c8c8295 100644
--- a/spine/vis/ellipsoid.py
+++ b/spine/vis/ellipsoid.py
@@ -1,13 +1,16 @@
"""Module to convert a point cloud into an ellipsoidal envelope."""
+import time
+
import numpy as np
from scipy.special import gammaincinv # pylint: disable=E0611
import plotly.graph_objs as go
-def ellipsoid_trace(points, contour=0.5, num_samples=10, color=None,
+def ellipsoid_trace(points=None, centroid=None, covmat=None, contour=0.5,
+ num_samples=10, color=None, intensity=None, hovertext=None,
showscale=False, **kwargs):
- """Converts a cloud of points into a 3D ellipsoid.
+ """Converts a cloud of points or a covariance matrix into a 3D ellipsoid.
This function uses the centroid and the covariance matrix of a cloud of
points to define an ellipsoid which would encompass a user-defined fraction
@@ -16,8 +19,12 @@ def ellipsoid_trace(points, contour=0.5, num_samples=10, color=None,
Parameters
----------
- points : np.ndarray
+ points : np.ndarray, optional
(N, 3) Array of point coordinates
+ centroid : np.ndarray, optional
+ (3) Centroid
+ covmat : np.ndarray, optional
+ (3, 3) Covariance matrix which defines the ellipsoid shape
contour : float, default 0.5
Fraction of the points contained in the ellipsoid, under the
Gaussian distribution assumption
@@ -26,11 +33,29 @@ def ellipsoid_trace(points, contour=0.5, num_samples=10, color=None,
system of the ellipsoid. A larger number increases the resolution.
color : Union[str, float], optional
Color of ellipse
+ intensity : Union[int, float], optional
+        Color intensity of the ellipsoid along the colorscale axis
+    hovertext : Union[int, str, np.ndarray], optional
+        Text associated with the ellipsoid
showscale : bool, default False
If True, show the colorscale of the :class:`plotly.graph_objs.Mesh3d`
**kwargs : dict, optional
- Additional parameters to pass to the
+ Additional parameters to pass to the underlying
+ :class:`plotly.graph_objs.Mesh3d` object
"""
+ # Ensure that either a cloud of points or a covariance matrix is provided
+ assert (points is not None) ^ (centroid is not None and covmat is not None), (
+ "Must provide either `points` or both `centroid` and `covmat`.")
+
+ # Update hovertemplate style
+    hovertemplate = 'x: %{x}<br>y: %{y}<br>z: %{z}'
+    if hovertext is not None:
+        if not np.isscalar(hovertext):
+            hovertemplate += '<br>%{text}'
+        else:
+            hovertemplate += f'<br>{hovertext}'
+ hovertext = None
+
# Compute the points on a unit sphere
phi = np.linspace(0, 2*np.pi, num=num_samples)
theta = np.linspace(-np.pi/2, np.pi/2, num=num_samples)
@@ -40,9 +65,10 @@ def ellipsoid_trace(points, contour=0.5, num_samples=10, color=None,
z = np.sin(theta)
unit_points = np.vstack((x.flatten(), y.flatten(), z.flatten())).T
- # Get the centroid and the covariance matrix
- centroid = np.mean(points, axis=0)
- covmat = np.cov((points - centroid).T)
+ # Get the centroid and the covariance matrix, if needed
+ if points is not None:
+ centroid = np.mean(points, axis=0)
+ covmat = np.cov((points - centroid).T)
# Diagonalize the covariance matrix, get rotation matrix
w, v = np.linalg.eigh(covmat)
@@ -51,19 +77,110 @@ def ellipsoid_trace(points, contour=0.5, num_samples=10, color=None,
# Compute the radius corresponding to the contour probability and rotate
# the points into the basis of the covariance matrix
- assert contour > 0. and contour < 1., (
- "The `contour` parameter should be a probability.")
- radius = np.sqrt(2*gammaincinv(1.5, contour))
+ radius = 1.
+ if contour is not None:
+ assert contour > 0. and contour < 1., (
+ "The `contour` parameter should be a probability.")
+ radius = np.sqrt(2*gammaincinv(1.5, contour))
+
ell_points = centroid + radius*np.dot(unit_points, rotmat)
- # Convert the color provided to a set of intensities
- intensity = None
- if color is not None:
- assert np.isscalar('color'), (
- "Should provide a single color for the ellipsoid.")
- intensity = [color]*len(ell_points)
+ # Convert the color provided to a set of intensities, if needed
+ if color is not None and not isinstance(color, str):
+ assert intensity is None, (
+ "Must not provide both `color` and `intensity`.")
+ intensity = np.full(len(ell_points), color)
+ color = None
# Append Mesh3d object
return go.Mesh3d(
x=ell_points[:, 0], y=ell_points[:, 1], z=ell_points[:, 2],
- intensity=intensity, alphahull=0, showscale=showscale, **kwargs)
+ color=color, intensity=intensity, alphahull=0, showscale=showscale,
+ hovertemplate=hovertemplate, **kwargs)
+
+
+def ellipsoid_traces(centroids, covmat, color=None, hovertext=None, cmin=None,
+ cmax=None, shared_legend=True, legendgroup=None,
+ showlegend=True, name=None, **kwargs):
+ """Function which produces a list of plotly traces of ellipsoids given a
+ list of centroids and one covariance matrix in x, y and z.
+
+ Parameters
+ ----------
+ centroids : np.ndarray
+ (N, 3) Positions of each of the ellipsoid centroids
+ covmat : np.ndarray
+        (3, 3) Covariance matrix which defines the base shape of all the ellipsoids
+ color : Union[str, np.ndarray], optional
+        Color of ellipsoids or list of colors of ellipsoids
+ hovertext : Union[int, str, np.ndarray], optional
+ Text associated with every ellipsoid or each ellipsoid
+ cmin : float, optional
+ Minimum value along the color scale
+ cmax : float, optional
+ Maximum value along the color scale
+ shared_legend : bool, default True
+ If True, the plotly legend of all ellipsoids is shared as one
+ legendgroup : str, optional
+        Legend group to be shared between all ellipsoids
+    showlegend : bool, default `True`
+        Whether to show legends or not
+ name : str, optional
+ Name of the trace(s)
+ **kwargs : dict, optional
+ List of additional arguments to pass to the underlying list of
+ :class:`plotly.graph_objs.Mesh3D`
+
+ Returns
+ -------
+ Union[List[plotly.graph_objs.Mesh3D]]
+ Ellipsoid traces
+ """
+ # Check the parameters
+ assert color is None or np.isscalar(color) or len(color) == len(centroids), (
+ "Specify one color for all ellipsoids, or one color per ellipsoid.")
+ assert (hovertext is None or np.isscalar(hovertext) or
+ len(hovertext) == len(centroids)), (
+ "Specify one hovertext for all ellipsoids, or one hovertext per "
+ "ellipsoid.")
+
+ # If one color is provided per ellipsoid, give an associated hovertext
+ if hovertext is None and color is not None and not np.isscalar(color):
+ hovertext = [f'Value: {v:0.3f}' for v in color]
+
+ # If cmin/cmax are not provided, must build them so that all ellipsoids
+ # share the same colorscale range (not guaranteed otherwise)
+ if color is not None and not np.isscalar(color) and len(color) > 0:
+ if cmin is None:
+ cmin = np.min(color)
+ if cmax is None:
+ cmax = np.max(color)
+
+ # If the legend is to be shared, make sure there is a common legend group
+ if shared_legend and legendgroup is None:
+ legendgroup = 'group_' + str(time.time())
+
+ # Loop over the list of ellipsoid centroids
+ traces = []
+ col, hov = color, hovertext
+ for i, centroid in enumerate(centroids):
+ # Fetch the right color/hovertext combination
+ if color is not None and not np.isscalar(color):
+ col = color[i]
+ if hovertext is not None and not np.isscalar(hovertext):
+ hov = hovertext[i]
+
+ # If the legend is shared, only draw the legend of the first trace
+ if shared_legend:
+ showlegend = showlegend and i == 0
+ name_i = name
+ else:
+ name_i = f'{name} {i}'
+
+ # Append list of traces
+ traces.append(ellipsoid_trace(
+ centroid=centroid, covmat=covmat, contour=None, color=col,
+ hovertext=hov, cmin=cmin, cmax=cmax, legendgroup=legendgroup,
+ showlegend=showlegend, name=name_i, **kwargs))
+
+ return traces
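
The `contour` parameter of `ellipsoid_trace` maps to a Mahalanobis radius through the chi-square distribution with three degrees of freedom. A standalone check of the conversion used above:

import numpy as np
from scipy.special import gammaincinv

# P(chi2_3 < r^2) = contour  <=>  r = sqrt(2 * gammaincinv(3/2, contour))
contour = 0.5
radius = np.sqrt(2*gammaincinv(1.5, contour))
print(radius)  # ~1.54, the 50% containment radius of a 3D standard normal
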
diff --git a/spine/vis/geo.py b/spine/vis/geo.py
new file mode 100644
index 00000000..b1475ed6
--- /dev/null
+++ b/spine/vis/geo.py
@@ -0,0 +1,267 @@
+"""Draw detectors based on their geometry definition."""
+
+import time
+from functools import partial
+
+import numpy as np
+
+from spine.utils.geo import Geometry
+
+from .box import box_traces
+from .ellipsoid import ellipsoid_traces
+
+
+class GeoDrawer:
+ """Handles drawing all things related to the detector geometry.
+
+    This class loads a :class:`Geometry` object once from a geometry file
+ and uses it to represent all things related to the detector geometry:
+ - TPC boundaries
+ - Optical detectors
+ - CRT detectors
+ """
+
+ def __init__(self, detector=None, file_path=None, detector_coords=True):
+ """Initializes the underlying detector :class:`Geometry` object.
+
+ Parameters
+ ----------
+ detector : str, optional
+            Name of a recognized detector to load the geometry from
+ file_path : str, optional
+ Path to a `.yaml` geometry configuration
+        detector_coords : bool, default True
+ If False, the coordinates are converted to pixel indices
+ """
+ # Initialize the detector geometry
+ self.geo = Geometry(detector, file_path)
+
+        # Store whether to use detector coordinates or not
+ self.detector_coords = detector_coords
+
+ def tpc_traces(self, meta=None, draw_faces=False, shared_legend=True,
+ name='Detector', color='rgba(0,0,0,0.150)', linewidth=5,
+ **kwargs):
+ """Function which produces a list of traces which represent the TPCs in
+ a 3D event display.
+
+ Parameters
+ ----------
+ meta : Meta, optional
+            Metadata information (only needed if `detector_coords` is False)
+ draw_faces : bool, default False
+            Whether or not to draw the box faces, or only the edges
+ shared_legend : bool, default True
+ If True, the legend entry in plotly is shared between all the
+ detector volumes
+ name : Union[str, List[str]], default 'Detector'
+ Name(s) of the detector volumes
+ color : Union[int, str, np.ndarray]
+ Color of boxes or list of color of boxes
+        linewidth : int, default 5
+ Width of the box edge lines
+ **kwargs : dict, optional
+ List of additional arguments to pass to
+            spine.vis.box.box_traces
+
+ Returns
+ -------
+ List[Union[plotly.graph_objs.Scatter3D, plotly.graph_objs.Mesh3D]]
+ List of detector traces (one per TPC)
+ """
+ # Load the list of TPC boundaries
+ boundaries = np.stack([c.boundaries for c in self.geo.tpc.chambers])
+
+ # If required, convert to pixel coordinates
+ if not self.detector_coords:
+ assert meta is not None, (
+ "Must provide meta information to convert the TPC "
+ "boundaries to pixel coordinates.")
+ boundaries = meta.to_px(boundaries.transpose(0,2,1)).transpose(0,2,1)
+
+ # Get a trace per detector volume
+ detectors = box_traces(
+ boundaries[..., 0], boundaries[..., 1], draw_faces=draw_faces,
+ color=color, linewidth=linewidth, shared_legend=shared_legend,
+ name=name, **kwargs)
+
+ return detectors
+
+ def optical_traces(self, meta=None, shared_legend=True, legendgroup=None,
+ name='Optical', color='rgba(0,0,255,0.25)', cmin=None,
+ cmax=None, zero_supress=False, volume_id=None, **kwargs):
+ """Function which produces a list of traces which represent the optical
+ detectors in a 3D event display.
+
+ Parameters
+ ----------
+ meta : Meta, optional
+            Metadata information (only needed if `detector_coords` is False)
+ shared_legend : bool, default True
+ If True, the legend entry in plotly is shared between all the
+ detector volumes
+ legendgroup : str, optional
+            Legend group to be shared between all optical detectors
+        name : Union[str, List[str]], default 'Optical'
+            Name(s) of the optical detector volumes
+ color : Union[int, str, np.ndarray]
+ Color of optical detectors or list of color of optical detectors
+ cmin : float, optional
+ Minimum value along the color scale
+ cmax : float, optional
+ Maximum value along the color scale
+ zero_supress : bool, default False
+ If `True`, do not draw optical detectors that are not activated
+ volume_id : int, optional
+ Specifies which optical volume to represent. If not specified, all
+ the optical volumes are drawn
+ **kwargs : dict, optional
+ List of additional arguments to pass to
+ spine.vis.ellipsoid.ellipsoid_traces or spine.vis.box.box_traces
+
+ Returns
+ -------
+ List[plotly.graph_objs.Mesh3D]
+ List of optical detector traces (one per optical detector)
+ """
+        # Check that there are optical detectors to draw
+ assert self.geo.optical is not None, (
+ "This geometry does not have optical detectors to draw.")
+
+ # Fetch the optical element positions and dimensions
+ if volume_id is None:
+ positions = self.geo.optical.positions.reshape(-1, 3)
+ else:
+ positions = self.geo.optical.positions[volume_id]
+ half_dimensions = self.geo.optical.dimensions/2
+
+ # If there is more than one detector shape, fetch shape IDs
+ shape_ids = None
+ if self.geo.optical.shape_ids is not None:
+ shape_ids = self.geo.optical.shape_ids
+ if volume_id is None:
+ shape_ids = np.tile(shape_ids, self.geo.optical.num_volumes)
+
+ # Convert the positions to pixel coordinates, if needed
+ if not self.detector_coords:
+ assert meta is not None, (
+ "Must provide meta information to convert the optical "
+ "element positions/dimensions to pixel coordinates.")
+ positions = meta.to_px(positions)
+ half_dimensions = meta.to_px(half_dimensions)
+
+        # Check that the colors provided match the number of optical detectors
+ if color is not None and not np.isscalar(color):
+ assert len(color) == len(positions), (
+ "Must provide one value for each optical detector.")
+
+ # If cmin/cmax are not provided, must build them so that all optical
+ # detectors share the same colorscale range (not guaranteed otherwise)
+ if color is not None and not np.isscalar(color) and len(color) > 0:
+ if cmin is None:
+ cmin = np.min(color)
+ if cmax is None:
+ cmax = np.max(color)
+
+ # If the legend is to be shared, make sure there is a common legend group
+ if shared_legend and legendgroup is None:
+ legendgroup = 'group_' + str(time.time())
+
+ # Draw each of the optical detectors
+ traces = []
+ for i, shape in enumerate(self.geo.optical.shape):
+ # Restrict the positions to those of this shape, if needed
+ if shape_ids is None:
+ pos = positions
+ col = color
+ else:
+ index = np.where(np.asarray(shape_ids) == i)[0]
+ pos = positions[index]
+ if color is not None and not np.isscalar(color):
+ col = color[index]
+ else:
+ col = color
+
+ # If zero-supression is requested, only draw the optical detectors
+ # which record a non-zero signal
+ if zero_supress and color is not None and not np.isscalar(color):
+ index = np.where(np.asarray(col) != 0)[0]
+ pos = pos[index]
+ col = col[index]
+
+            # Determine whether to show legends or not
+ showlegend = not shared_legend or i == 0
+
+ # Dispatch the drawing based on the type of optical detector
+ hd = half_dimensions[i]
+ if shape == 'box':
+ # Convert the positions/dimensions to box lower/upper bounds
+ lower, upper = pos - hd, pos + hd
+
+ # Build boxes
+ traces += box_traces(
+ lower, upper, shared_legend=shared_legend, name=name,
+ color=col, cmin=cmin, cmax=cmax, draw_faces=True,
+ legendgroup=legendgroup, showlegend=showlegend, **kwargs)
+
+ else:
+ # Convert the optical detector dimensions to a covariance matrix
+ covmat = np.diag(hd**2)
+
+ # Build ellipsoids
+ traces += ellipsoid_traces(
+ pos, covmat, shared_legend=shared_legend, name=name,
+ color=col, cmin=cmin, cmax=cmax,
+ legendgroup=legendgroup, showlegend=showlegend, **kwargs)
+
+ return traces
+
+ def crt_traces(self, meta=None, detector_coords=True, shared_legend=True,
+ name='CRT', color='rgba(0,255,0,0.25)', **kwargs):
+        """Function which produces a list of traces which represent the CRT
+        detectors in a 3D event display.
+
+ Parameters
+ ----------
+ meta : Meta, optional
+            Metadata information (only needed if `detector_coords` is False)
+        detector_coords : bool, default True
+ If False, the coordinates are converted to pixel indices
+ shared_legend : bool, default True
+ If True, the legend entry in plotly is shared between all the
+ detector volumes
+        name : Union[str, List[str]], default 'CRT'
+            Name(s) of the CRT volumes
+ color : Union[int, str, np.ndarray]
+ Color of CRT detectors or list of color of CRT detectors
+ **kwargs : dict, optional
+ List of additional arguments to pass to
+ spine.vis.ellipsoid.ellipsoid_traces or spine.vis.box.box_traces
+
+ Returns
+ -------
+ List[plotly.graph_objs.Mesh3D]
+ List of CRT detector traces (one per CRT element)
+ """
+ # Check that there are CRT planes to draw
+ assert self.geo.crt is not None, (
+ "This geometry does not have CRT planes to draw.")
+
+ # Fetch the CRT element positions and dimensions
+ positions = self.geo.crt.positions
+ half_dimensions = self.geo.crt.dimensions/2
+ if not self.detector_coords:
+ assert meta is not None, (
+ "Must provide meta information to convert the CRT "
+ "element positions/dimensions to pixel coordinates.")
+ positions = meta.to_px(positions)
+ half_dimensions = meta.to_px(half_dimensions)
+
+ # Convert the positions/dimensions to box lower/upper bounds
+ lower = positions - half_dimensions
+ upper = positions + half_dimensions
+
+ # Build and return boxes
+ return box_traces(
+ lower, upper, shared_legend=shared_legend, name=name,
+ color=color, draw_faces=True, **kwargs)
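
A hedged usage sketch of the new `GeoDrawer` class. The detector name below is an assumption; any geometry recognized by `spine.utils.geo.Geometry` works, and the optical/CRT calls only apply when the geometry file defines those systems.

import plotly.graph_objs as go
from spine.vis.geo import GeoDrawer

geo_drawer = GeoDrawer(detector='icarus')  # detector name is an assumption

traces = geo_drawer.tpc_traces()           # TPC boundary boxes
# traces += geo_drawer.optical_traces()    # optical detectors, if defined
# traces += geo_drawer.crt_traces()        # CRT planes, if defined

go.Figure(traces).show()
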
diff --git a/spine/vis/hull.py b/spine/vis/hull.py
index 6f153405..7fcc9a69 100644
--- a/spine/vis/hull.py
+++ b/spine/vis/hull.py
@@ -4,7 +4,8 @@
import plotly.graph_objs as go
-def hull_trace(points, color=None, showscale=False, alphahull=0, **kwargs):
+def hull_trace(points, color=None, intensity=None, showscale=False,
+ alphahull=0, **kwargs):
"""Converts a cloud of points into a 3D convex hull.
This function represents a point cloud by forming a mesh with the points
@@ -16,23 +17,23 @@ def hull_trace(points, color=None, showscale=False, alphahull=0, **kwargs):
(N, 3) Array of point coordinates
color : Union[str, float, np.ndarray], optional
Color of hull
+ intensity : Union[int, float], optional
+        Color intensity of the hull along the colorscale axis
showscale : bool, default False
If True, show the colorscale of the :class:`plotly.graph_objs.Mesh3d`
alphahull : float, default 0
Parameter that sets how to define the hull. 0 is the convex hull,
larger numbers correspond to alpha-shapes.
**kwargs : dict, optional
- Additional parameters to pass to the
+ Additional parameters to pass to the underlying
+ :class:`plotly.graph_objs.Mesh3d` object
"""
- # Convert the color provided to a set of intensities
- intensity = None
- if color is not None:
- if np.isscalar(color):
- intensity = [color]*len(points)
- else:
- assert len(color) == points, (
- "The color must be a scalar or one value per point")
- intensity = color
+ # Convert the color provided to a set of intensities, if needed
+ if color is not None and not isinstance(color, str):
+ assert intensity is None, (
+ "Must not provide both `color` and `intensity`.")
+        intensity = np.full(len(points), color)
+ color = None
# Append Mesh3d object
return go.Mesh3d(
diff --git a/spine/vis/layout.py b/spine/vis/layout.py
index 06485797..38a5f700 100644
--- a/spine/vis/layout.py
+++ b/spine/vis/layout.py
@@ -11,7 +11,7 @@
from spine.utils.geo import Geometry
-
+# Colorscale definitions
PLOTLY_COLORS = colors.qualitative.Plotly
PLOTLY_COLORS_TUPLE = colors.convert_colors_to_same_type(
deepcopy(PLOTLY_COLORS), 'tuple')[0]
@@ -105,8 +105,8 @@ def layout3d(ranges=None, meta=None, detector=None, titles=None,
assert (ranges is None or None in ranges) and meta is None, (
"Should not specify `detector` along with `ranges` or `meta`.")
geo = Geometry(detector)
- lengths = geo.detector[:,1] - geo.detector[:,0]
- ranges = geo.detector
+ lengths = geo.tpc.dimensions
+ ranges = geo.tpc.boundaries
# Add some padding
ranges[:,0] -= lengths*0.1
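
The padding applied to the detector ranges in `layout3d` is a flat 10% of each axis length. A small numpy illustration, assuming `boundaries` is a (3, 2) array of [lower, upper] per axis (the values below are made up):

import numpy as np

boundaries = np.array([[-360., 360.], [-180., 180.], [-900., 900.]])
lengths = boundaries[:, 1] - boundaries[:, 0]

ranges = boundaries.copy()
ranges[:, 0] -= lengths*0.1
ranges[:, 1] += lengths*0.1
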
diff --git a/spine/vis/out.py b/spine/vis/out.py
index fc39d754..1da4bf92 100644
--- a/spine/vis/out.py
+++ b/spine/vis/out.py
@@ -7,15 +7,15 @@
from spine.utils.globals import COORD_COLS, PID_LABELS, SHAPE_LABELS, TRACK_SHP
+from .geo import GeoDrawer
from .point import scatter_points
from .cluster import scatter_clusters
-from .detector import detector_traces
from .layout import (
layout3d, dual_figure3d, PLOTLY_COLORS_WGRAY, HIGH_CONTRAST_COLORS)
class Drawer:
- """Class dedicated to drawing the true/reconstructed output.
+ """Handles drawing the true/reconstructed output.
This class is given the entire input/output dictionary from one entry and
provides functions to represent the output.
@@ -29,8 +29,8 @@ class Drawer:
# List of known point modes
_point_modes = ('points', 'points_adapt', 'points_g4')
- # Map between attribute and underlying point objects
- _point_map = {'points': 'points_label', 'points_adapt': 'points',
+ # Map between point attributes and underlying point objects
+ _point_map = {'points': 'points_label', 'points_adapt': 'points',
'points_g4': 'points_g4'}
def __init__(self, data, draw_mode='both', truth_point_mode='points',
@@ -85,21 +85,23 @@ def __init__(self, data, draw_mode='both', truth_point_mode='points',
self.truth_point_mode = truth_point_mode
self.truth_index_mode = truth_point_mode.replace('points', 'index')
- # Save the detector properties
+        # If detector information is provided, initialize the geometry drawer
+ self.geo_drawer = None
self.meta = data.get('meta', None)
- self.detector = detector
- self.detector_coords = detector_coords
+ if detector is not None:
+ self.geo_drawer = GeoDrawer(
+ detector=detector, detector_coords=detector_coords)
# Initialize the layout
self.split_scene = split_scene
meta = self.meta if detector is None else None
self.layout = layout3d(
- detector=self.detector, meta=meta,
- detector_coords=self.detector_coords, **kwargs)
+ detector=detector, meta=meta, detector_coords=detector_coords,
+ **kwargs)
def get(self, obj_type, attr=None, draw_end_points=False,
- draw_vertices=False, synchronize=False, titles=None,
- split_traces=False):
+ draw_vertices=False, draw_flashes=False, synchronize=False,
+ titles=None, split_traces=False):
"""Draw the requested object type with the requested mode.
Parameters
@@ -113,6 +115,8 @@ def get(self, obj_type, attr=None, draw_end_points=False,
If `True`, draw the fragment or particle end points
draw_vertices : bool, default False
If `True`, draw the interaction vertices
+ draw_flashes : bool, default False
+ If `True`, draw flashes that have been matched to interactions
synchronize : bool, default False
If `True`, matches the camera position/angle of one plot to the other
titles : List[str], optional
@@ -148,7 +152,7 @@ def get(self, obj_type, attr=None, draw_end_points=False,
traces[prefix] += self._start_point_trace(obj_name)
traces[prefix] += self._end_point_trace(obj_name)
- # Fetch the vertex, if requested
+ # Fetch the vertices, if requested
if draw_vertices:
for prefix in self.prefixes:
obj_name = f'{prefix}_interactions'
@@ -156,17 +160,23 @@ def get(self, obj_type, attr=None, draw_end_points=False,
"Must provide interactions to draw their vertices.")
traces[prefix] += self._vertex_trace(obj_name)
- # Add the detector traces, if available
- if self.detector is not None:
+ # Fetch the flashes, if requested
+ if draw_flashes:
+ assert 'flashes' in self.data, (
+ "Must provide the `flashes` objects to draw them.")
+ for prefix in self.prefixes:
+ obj_name = f'{prefix}_interactions'
+ assert obj_name in self.data, (
+ "Must provide interactions to draw matched flashes.")
+ traces[prefix] += self._flash_trace(obj_name)
+
+ # Add the TPC traces, if available
+ if self.geo_drawer is not None:
if len(self.prefixes) and self.split_scene:
for prefix in self.prefixes:
- traces[prefix] += detector_traces(
- detector=self.detector, meta=self.meta,
- detector_coords=self.detector_coords)
+ traces[prefix] += self.geo_drawer.tpc_traces(meta=self.meta)
else:
- traces[self.prefixes[-1]] += detector_traces(
- detector=self.detector, meta=self.meta,
- detector_coords=self.detector_coords)
+ traces[self.prefixes[-1]] += self.geo_drawer.tpc_traces(meta=self.meta)
# Initialize the figure, return
if len(self.prefixes) > 1 and self.split_scene:
@@ -350,8 +360,8 @@ def _start_point_trace(self, obj_name, color='black', markersize=5,
Returns
-------
- dict
- Dictionary of color parameters (colorscale, cmin, cmax)
+ list
+ List of start point traces
"""
return self._point_trace(
obj_name, 'start_point', color=color, markersize=markersize,
@@ -376,8 +386,8 @@ def _end_point_trace(self, obj_name, color='black', markersize=5,
Returns
-------
- dict
- Dictionary of color parameters (colorscale, cmin, cmax)
+ list
+ List of end point traces
"""
return self._point_trace(
obj_name, 'end_point', color=color, markersize=markersize,
@@ -402,8 +412,8 @@ def _vertex_trace(self, obj_name, vertex_attr='vertex', color='green',
Returns
-------
- dict
- Dictionary of color parameters (colorscale, cmin, cmax)
+ list
+ List of vertex point traces
"""
return self._point_trace(
obj_name, vertex_attr, color=color, markersize=markersize,
@@ -423,8 +433,8 @@ def _point_trace(self, obj_name, point_attr, **kwargs):
Returns
-------
- dict
- Dictionary of color parameters (colorscale, cmin, cmax)
+ list
+ List of point traces
"""
# Define the name of the trace
name = (' '.join(obj_name.split('_')).capitalize()[:-1] + ' ' +
@@ -452,3 +462,44 @@ def _point_trace(self, obj_name, point_attr, **kwargs):
return scatter_points(
points, hovertext=np.array(hovertext), name=name, **kwargs)
+
+ def _flash_trace(self, obj_name, **kwargs):
+        """Draw the cumulative PEs of flashes that have been matched to
+ interactions specified by `obj_name`.
+
+ Parameters
+ ----------
+ obj_name : str
+ Name of the object to draw
+ **kwargs : dict, optional
+ List of additional arguments to pass to :func:`optical_traces`
+
+ Returns
+ -------
+ list
+ List of optical detector traces
+ """
+ # Define the name of the trace
+ name = ' '.join(obj_name.split('_')).capitalize()[:-1] + ' flashes'
+
+ # Find the list of flash IDs to draw
+ flash_ids = []
+ for inter in self.data[obj_name]:
+ if inter.is_flash_matched:
+ flash_ids.extend(inter.flash_ids)
+
+        # Sum values from each flash to build a global color scale
+ color = np.zeros(self.geo_drawer.geo.optical.num_detectors)
+ opt_det_ids = self.geo_drawer.geo.optical.det_ids
+ for flash_id in flash_ids:
+ flash = self.data['flashes'][flash_id]
+ index = self.geo_drawer.geo.optical.volume_index(flash.volume_id)
+ pe_per_ch = flash.pe_per_ch
+ if opt_det_ids is not None:
+ pe_per_ch = np.bincount(opt_det_ids, weights=pe_per_ch)
+ color[index] += pe_per_ch
+
+ # Return the set of optical detectors with a color scale
+ return self.geo_drawer.optical_traces(
+ meta=self.meta, color=color, zero_supress=True,
+ colorscale='Inferno', name=name)
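
The per-channel to per-detector PE remapping inside `_flash_trace` relies on `np.bincount` with weights: channels that map onto the same physical optical detector have their PEs summed. A standalone sketch with made-up values:

import numpy as np

pe_per_ch = np.array([5., 3., 0., 7.])  # PEs recorded per readout channel
opt_det_ids = np.array([0, 0, 1, 2])    # channel -> physical optical detector

pe_per_det = np.bincount(opt_det_ids, weights=pe_per_ch)
print(pe_per_det)  # [8. 0. 7.]
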
diff --git a/spine/vis/train.py b/spine/vis/train.py
index 2e282daf..bc8d14c7 100644
--- a/spine/vis/train.py
+++ b/spine/vis/train.py
@@ -267,7 +267,7 @@ def draw(self, model, metric, limits=None, model_name=None,
epoch_v = [epoch_t[iter_t == it] for it in iter_v]
mask = np.where(
np.array([len(e) for e in epoch_v]) == 1)[0]
- epoch_v = [float(epoch_v[i]) for i in mask]
+ epoch_v = [float(epoch_v[i].iloc[0]) for i in mask]
iter_v = iter_v[mask]
metric_v_mean = metric_v_mean[mask]
metric_v_err = metric_v_err[mask]
@@ -430,6 +430,8 @@ def get_training_df(self, log_dir, keys):
log_dfs = []
for i, f in enumerate(np.array(log_files)[order]):
df = pd.read_csv(f, nrows=end_points[i+1]-end_points[i])
+ if len(df) == 0:
+ continue
for key_list in keys:
key, key_name = self.find_key(df, key_list)
df[key_name] = df[key]