Skip to content

Commit

Permalink
Fill container (#3)
Browse files Browse the repository at this point in the history
Add option to directly fill DNN data container from within c++ via pybindings.
  • Loading branch information
mhuen authored Oct 18, 2020
1 parent 907a773 commit c372fec
Show file tree
Hide file tree
Showing 4 changed files with 654 additions and 95 deletions.
35 changes: 30 additions & 5 deletions ic3_data/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,15 +92,20 @@ def Configure(self):
# or if it is one that computes the values for one DOM at a time
if self._config['data_format'] in [
'total_dom_charge',
'reduced_summary_statistics_data',
'cascade_classification_data',
'mc_tree_input_data',
]:
self._calculate_per_dom = False
self._calculation_method = 'calculate_for_detector'
class_string = 'ic3_data.data_formats_detector.{}'.format(
self._config['data_format'])
elif self._config['data_format'] in [
'reduced_summary_statistics_data',
]:
self._calculation_method = 'fill_container'
class_string = 'ic3_data.data_formats_fill.{}'.format(
self._config['data_format'])
else:
self._calculate_per_dom = True
self._calculation_method = 'calculate_per_dom'
class_string = 'ic3_data.data_formats.{}'.format(
self._config['data_format'])
self._data_format_func = misc.load_class(class_string)
Expand Down Expand Up @@ -154,7 +159,7 @@ def Physics(self, frame):
# ------------------------------------------------
# Calculate DNN input data separately for each DOM
# ------------------------------------------------
if self._calculate_per_dom:
if self._calculation_method == 'calculate_per_dom':
# restructure pulses
# charges, times, dom_times_dict, dom_charges_dict = \
# self.restructure_pulses(pulses)
Expand Down Expand Up @@ -218,7 +223,7 @@ def Physics(self, frame):
# ---------------------------------------------------
# Calculate DNN input data for whole detector at once
# ---------------------------------------------------
else:
elif self._calculation_method == 'calculate_for_detector':

global_time_offset, data_dict = self._data_format_func(
frame=frame,
Expand All @@ -240,6 +245,26 @@ def Physics(self, frame):
self._container.global_time_offset.value = global_time_offset
self._container.global_time_offset_batch[self._batch_index] = \
global_time_offset

# --------------------------------
# Directly fill DNN Data Container
# --------------------------------
elif self._calculation_method == 'fill_container':

# fill the data container
self._data_format_func(
container=self._container,
batch_index=self._batch_index,
write_to_frame=self._write_to_frame,
frame=frame,
pulses=pulses,
config=self._config,
dom_exclusions=self._dom_exclusions,
partial_exclusion=self._partial_exclusion,
)

else:
raise ValueError('Unknown method:', self._calculation_method)
# ---------------------------------------------------

# measure time
Expand Down
137 changes: 137 additions & 0 deletions ic3_data/data_formats_fill.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
from __future__ import print_function, division
import numpy as np

from ic3_data.ext_boost import fill_reduced_summary_statistics_data

"""All data format functions must have the following signature:
Parameters
----------
container : DNNDataContainer
The data container that will be filled.
batch_index : int
The batch index.
write_to_frame : bool
Whether or not the DNN data should be written to the frame.
frame : I3Frame
The current frame.
pulses : I3RecoPulseSeriesMap
The pulse series map from which to calculate the DNN input data.
config : dict
A dictionary that contains all configuration settings.
dom_exclusions : list of str, None
List of frame keys that define DOMs or TimeWindows that should be
excluded. Typical values for this are:
['BrightDOMs','SaturationWindows', 'BadDomsList', 'CalibrationErrata']
partial_exclusion : bool, None
If True, partially exclude DOMS, e.g. only omit pulses from
excluded TimeWindows defined in 'dom_exclusions'.
If False, all pulses from a DOM will be excluded if the omkey
exists in the dom_exclusions.
*args
Variable length argument list.
**kwargs
Arbitrary keyword arguments.
"""


def total_dom_charge(container, batch_index, write_to_frame, frame, pulses,
                     config, dom_exclusions, partial_exclusion,
                     *args, **kwargs):
    """Fill the data container with the total charge of each DOM.

    Thin wrapper around `fill_reduced_summary_statistics_data` that requests
    only the total charge feature: the first-pulse time and the pulse time
    standard deviation are both switched off.

    Parameters
    ----------
    container : DNNDataContainer
        The data container that will be filled.
    batch_index : int
        The batch index.
    write_to_frame : bool
        Whether or not the DNN data should be written to the frame.
    frame : I3Frame
        The current frame. Not used by this function.
    pulses : I3RecoPulseSeriesMap
        The pulse series map from which to calculate the DNN input data.
        Forwarded to the fill routine as its `pulse_key` argument.
    config : dict
        A dictionary that contains all configuration settings.
        Not used by this function.
    dom_exclusions : list of str, None
        List of frame keys that define DOMs or TimeWindows that should be
        excluded. Typical values for this are:
        ['BrightDOMs','SaturationWindows', 'BadDomsList', 'CalibrationErrata']
        Accepted to satisfy the common data format signature; it is not
        forwarded to the fill routine.
    partial_exclusion : bool, None
        If True, partially exclude DOMs, e.g. only omit pulses from
        excluded TimeWindows defined in 'dom_exclusions'.
        If False, all pulses from a DOM will be excluded if the omkey
        exists in the dom_exclusions.
        Accepted to satisfy the common data format signature; it is not
        forwarded to the fill routine.
    *args
        Variable length argument list (ignored).
    **kwargs
        Arbitrary keyword arguments (ignored).
    """
    # Only the charge feature is enabled for this data format.
    fill_reduced_summary_statistics_data(
        container=container,
        batch_index=batch_index,
        write_to_frame=write_to_frame,
        pulse_key=pulses,
        add_total_charge=True,
        add_t_first=False,
        add_t_std=False,
    )


def reduced_summary_statistics_data(container, batch_index, write_to_frame,
                                    frame, pulses, config, dom_exclusions,
                                    partial_exclusion, *args, **kwargs):
    """Fill the data container with reduced summary statistics per DOM.

    Thin wrapper around `fill_reduced_summary_statistics_data` with all three
    features enabled: total DOM charge, time of the first pulse, and the
    standard deviation of the pulse times. The pulse times are calculated
    relative to the charge weighted mean time of all pulses.

    Parameters
    ----------
    container : DNNDataContainer
        The data container that will be filled.
    batch_index : int
        The batch index.
    write_to_frame : bool
        Whether or not the DNN data should be written to the frame.
    frame : I3Frame
        The current frame. Not used by this function.
    pulses : I3RecoPulseSeriesMap
        The pulse series map from which to calculate the DNN input data.
        Forwarded to the fill routine as its `pulse_key` argument.
    config : dict
        A dictionary that contains all configuration settings.
        Not used by this function.
    dom_exclusions : list of str, None
        List of frame keys that define DOMs or TimeWindows that should be
        excluded. Typical values for this are:
        ['BrightDOMs','SaturationWindows', 'BadDomsList', 'CalibrationErrata']
        Accepted to satisfy the common data format signature; it is not
        forwarded to the fill routine.
    partial_exclusion : bool, None
        If True, partially exclude DOMs, e.g. only omit pulses from
        excluded TimeWindows defined in 'dom_exclusions'.
        If False, all pulses from a DOM will be excluded if the omkey
        exists in the dom_exclusions.
        Accepted to satisfy the common data format signature; it is not
        forwarded to the fill routine.
    *args
        Variable length argument list (ignored).
    **kwargs
        Arbitrary keyword arguments (ignored).
    """
    # Every available summary feature is enabled for this data format.
    fill_reduced_summary_statistics_data(
        container=container,
        batch_index=batch_index,
        write_to_frame=write_to_frame,
        pulse_key=pulses,
        add_total_charge=True,
        add_t_first=True,
        add_t_std=True,
    )
100 changes: 10 additions & 90 deletions ic3_data_ext/ext_boost.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ wrapped with boost python.
#include <boost/python.hpp>

#include "utils.cpp"
#include "reduced_summary_statistics.cpp"

/*
Depending on the boost version, we need to use numpy differently.
Expand All @@ -41,6 +42,7 @@ See answers and discussion provided here:
namespace bn = boost::python::numpy;
#endif

namespace bp = boost::python;


/******************************************************
Expand Down Expand Up @@ -151,96 +153,6 @@ See answers and discussion provided here:
/******************************************************
Functions with pybinding for python-based usage
******************************************************/
// Compute a reduced set of per-DOM summary statistics from a pulse series map.
//
// Parameters:
//   pulse_map_obj    : python object from which an I3RecoPulseSeriesMap
//                      reference is extracted.
//   add_total_charge : if true, add the summed pulse charge of the DOM.
//   add_t_first      : if true, add the time of the first stored pulse of the
//                      DOM, relative to the global time offset.
//   add_t_std        : if true, add the accumulator's std() over the pulses
//                      of the DOM (0 if the DOM has exactly one pulse).
//
// Returns a python tuple (global_offset_time, data_dict). data_dict maps every
// key of the pulse map that holds at least one pulse to a tuple
// (bin_values, bin_indices, bin_exclusions); the exclusions list is always
// empty. Enabled features occupy consecutive bin indices starting at 0, in
// the order: total charge, first pulse time, time std.
// global_offset_time remains 0 unless add_t_first is set.
template <typename T>
inline boost::python::tuple get_reduced_summary_statistics_data(
        const boost::python::object& pulse_map_obj,
        const bool add_total_charge,
        const bool add_t_first,
        const bool add_t_std
    ) {

    // Unwrap the pulse series map from the python object
    I3RecoPulseSeriesMap& pulse_map =
        boost::python::extract<I3RecoPulseSeriesMap&>(pulse_map_obj);

    // The global offset is only needed for the first-pulse time feature, so
    // the extra pass over all pulses is skipped otherwise.
    // NOTE(review): presumably the charge weighted mean pulse time; this
    // depends on MeanVarianceAccumulator, which is defined elsewhere.
    T global_offset_time = 0.;
    if (add_t_first){
        MeanVarianceAccumulator<T> all_pulses_acc;
        for (auto const& entry : pulse_map){
            for (auto const& pulse : entry.second){
                all_pulses_acc.add_element(
                    pulse.GetTime(), pulse.GetCharge());
            }
        }
        global_offset_time = all_pulses_acc.mean();
    }

    // Second pass: collect the requested features for every DOM
    boost::python::dict data_dict;
    for (auto const& entry : pulse_map){

        auto const& pulses = entry.second;

        // DOMs without pulses do not get an entry in the output
        const unsigned int n_pulses = pulses.size();
        if (n_pulses == 0){
            continue;
        }

        // accumulate charge (always) and time statistics (on request)
        T total_charge = 0.0;
        MeanVarianceAccumulator<T> time_acc;
        for (auto const& pulse : pulses){
            total_charge += pulse.GetCharge();
            if (add_t_std){
                time_acc.add_element(pulse.GetTime(), pulse.GetCharge());
            }
        }

        boost::python::list values;
        boost::python::list indices;
        boost::python::list exclusions; // empty dummy exclusions

        // index of the next feature slot to be written
        int next_index = 0;

        // Total DOM charge
        if (add_total_charge){
            indices.append(next_index++);
            values.append(total_charge);
        }

        // time of first pulse
        if (add_t_first){
            indices.append(next_index++);
            values.append(pulses[0].GetTime() - global_offset_time);
        }

        // time std deviation of pulses at DOM
        if (add_t_std){
            indices.append(next_index++);
            if (n_pulses == 1){
                // a single pulse has no spread
                values.append(0.);
            } else{
                values.append(time_acc.std());
            }
        }

        data_dict[entry.first] = boost::python::make_tuple(
            values, indices, exclusions);
    }

    return boost::python::make_tuple(global_offset_time, data_dict);
}

template <typename T>
inline boost::python::dict get_cascade_classification_data(
Expand Down Expand Up @@ -989,6 +901,14 @@ BOOST_PYTHON_MODULE(ext_boost)
boost::python::def("get_reduced_summary_statistics_data",
&get_reduced_summary_statistics_data<double>);

boost::python::def("fill_reduced_summary_statistics_data",
&fill_reduced_summary_statistics_data<double>,
(bp::arg("container"), bp::arg("pulse_key"),
bp::arg("add_total_charge"), bp::arg("add_t_first"),
bp::arg("add_t_std"), bp::arg("write_to_frame"),
bp::arg("batch_index"))
);

boost::python::def("get_cascade_classification_data",
&get_cascade_classification_data<double>);

Expand Down
Loading

0 comments on commit c372fec

Please sign in to comment.