From ab93f2280d76b9379eac5f31775faf9fa68e9e26 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Zo=C3=AB=20Bilodeau?= <70441641+zbilodea@users.noreply.github.com> Date: Mon, 26 Feb 2024 09:32:56 +0100 Subject: [PATCH] feat: add branch skimming option for reading ROOT files (#71) * first commit * saving * added branch filtering * formatting? * got rid of eval(), now pylint is complaining about exec()... * need noqa for 'unused variable' * Switched to using uproot's 'cut' argument --- src/hepconvert/_utils.py | 4 +- src/hepconvert/copy_root.py | 30 ++++++---- src/hepconvert/merge.py | 30 ++++++---- src/hepconvert/root_to_parquet.py | 4 ++ tests/test_copy_root.py | 2 +- tests/test_skim.py | 96 +++++++++++++++++++++++++++++++ 6 files changed, 142 insertions(+), 24 deletions(-) create mode 100644 tests/test_skim.py diff --git a/src/hepconvert/_utils.py b/src/hepconvert/_utils.py index 2c3ea0c..d7f7928 100644 --- a/src/hepconvert/_utils.py +++ b/src/hepconvert/_utils.py @@ -6,7 +6,7 @@ def group_branches(tree, keep_branches): """ Creates groups for ak.zip to avoid duplicate counters being created. 
- Groups created if branches have the same .member("fLeafCount") + Groups created if branches have the same branch.member("fLeafCount") """ groups = [] count_branches = [] @@ -78,7 +78,7 @@ def filter_branches(tree, keep_branches, drop_branches, count_branches): or tree.name == next(iter(keep_branches.keys())) ): keep_branches = keep_branches.get(tree.name) - if isinstance(keep_branches, str): + if isinstance(keep_branches, str) or len(keep_branches) == 1: keep_branches = tree.keys(filter_name=keep_branches) return [ b.name diff --git a/src/hepconvert/copy_root.py b/src/hepconvert/copy_root.py index 985dad6..321adc7 100644 --- a/src/hepconvert/copy_root.py +++ b/src/hepconvert/copy_root.py @@ -21,6 +21,8 @@ def copy_root( # add_branches=None, #TO-DO: add functionality for this, just specify about the counter issue keep_trees=None, drop_trees=None, + cut=None, + expressions=None, progress_bar=None, force=False, fieldname_separator="_", @@ -45,14 +47,23 @@ def copy_root( :type keep_branches: list of str, str, or dict, optional :param drop_branches: To remove branches from all trees, pass a list of names of branches to remove. Wildcarding supported ("Jet_*"). If removing branches from one of multiple trees, - pass a dict of structure: {tree: [branch1, branch2]} to remove branch1 and branch2 from ttree "tree". Defaults to None. Command line option: ``--drop-branches``. + pass a dict of structure: {tree: [branch1, branch2]} to remove branch1 and branch2 from TTree "tree". Defaults to None. Command line option: ``--drop-branches``. :type drop_branches: list of str, str, or dict, optional - :param drop_trees: To keep only certain a ttrees in a file, pass a list of names of ttrees to keep. All others will be removed. - Defaults to None. Command line option: ``--keep-trees``. - :type keep_trees: str or list of str, optional - :param drop_trees: To remove a ttree from a file, pass a list of names of ttrees to remove. 
+ :param keep_branches: To keep only specified branches from all trees, pass a list of names of branches to + keep. If keeping branches from one of multiple trees, pass a dict of structure: {tree: [branch1, branch2]} + to keep only branch1 and branch2 from TTree "tree". Defaults to None. Command line option: ``--keep-branches``. + :type keep_branches: list of str, str, or dict, optional + :param drop_trees: To remove a TTree from a file, pass a list of names of trees to remove. Defaults to None. Command line option: ``--drop-trees``. :type drop_trees: str or list of str, optional + :param keep_trees: To keep only certain TTrees in a file, pass a list of names of trees to keep. All others will be removed. + Defaults to None. Command line option: ``--keep-trees``. + :type keep_trees: str or list of str, optional + :param cut: If not None, this expression filters all of the ``expressions``. + :type cut: None or str + :param expressions: Names of ``TBranches`` or aliases to convert to arrays or mathematical expressions of them. + Uses the ``language`` to evaluate. If None, all ``TBranches`` selected by the filters are included. + :type expressions: None, str, or list of str :param progress_bar: Displays a progress bar. Can input a custom tqdm progress bar object, or set ``True`` for a default tqdm progress bar. Must have tqdm installed. 
:type progress_bar: Bool, tqdm.std.tqdm object @@ -220,6 +231,8 @@ def copy_root( step_size=step_size, how=dict, filter_name=lambda b: b in kb, + expressions=expressions, + cut=cut, ): for group in groups: if (len(group)) > 1: @@ -242,6 +255,7 @@ def copy_root( if key in kb: del chunk[key] if first: + first = False if drop_branches: branch_types = { name: array.type @@ -250,7 +264,6 @@ def copy_root( } else: branch_types = {name: array.type for name, array in chunk.items()} - out_file.mktree( tree.name, branch_types, @@ -260,11 +273,6 @@ def copy_root( initial_basket_capacity=initial_basket_capacity, resize_factor=resize_factor, ) - try: - out_file[tree.name].extend(chunk) - except AssertionError: - msg = "Are the branch_names correct?" - first = False else: try: diff --git a/src/hepconvert/merge.py b/src/hepconvert/merge.py index 6a4059b..5ed4c92 100644 --- a/src/hepconvert/merge.py +++ b/src/hepconvert/merge.py @@ -6,7 +6,11 @@ import uproot from hepconvert import _utils -from hepconvert._utils import filter_branches, get_counter_branches, group_branches +from hepconvert._utils import ( + filter_branches, + get_counter_branches, + group_branches, +) from hepconvert.histogram_adding import _hadd_1d, _hadd_2d, _hadd_3d @@ -18,6 +22,8 @@ def merge_root( drop_branches=None, keep_trees=None, drop_trees=None, + cut=None, + expressions=None, progress_bar=None, fieldname_separator="_", title="", @@ -45,14 +51,19 @@ def merge_root( :type keep_branches: list of str, str, or dict, optional :param drop_branches: To remove branches from all trees, pass a list of names of branches to remove. Wildcarding supported ("Jet_*"). If removing branches from one of multiple trees, - pass a dict of structure: {tree: [branch1, branch2]} to remove branch1 and branch2 from ttree "tree". Defaults to None. Command line option: ``--drop-branches``. + pass a dict of structure: {tree: [branch1, branch2]} to remove branch1 and branch2 from TTree "tree". Defaults to None. 
Command line option: ``--drop-branches``. :type drop_branches: list of str, str, or dict, optional - :param drop_trees: To keep only certain a ttrees in a file, pass a list of names of ttrees to keep. All others will be removed. + :param keep_trees: To keep only certain TTrees in a file, pass a list of names of trees to keep. All others will be removed. Defaults to None. Command line option: ``--keep-trees``. :type keep_trees: str or list of str, optional - :param drop_trees: To remove a ttree from a file, pass a list of names of ttrees to remove. + :param drop_trees: To remove a TTree from a file, pass a list of names of trees to remove. Defaults to None. Command line option: ``--drop-trees``. :type drop_trees: str or list of str, optional + :param cut: If not None, this expression filters all of the ``expressions``. + :type cut: None or str + :param expressions: Names of ``TBranches`` or aliases to convert to arrays or mathematical expressions of them. + Uses the ``language`` to evaluate. If None, all ``TBranches`` selected by the filters are included. + :type expressions: None, str, or list of str :param progress_bar: Displays a progress bar. Can input a custom tqdm progress bar object, or set ``True`` for a default tqdm progress bar. Must have tqdm installed. :type progress_bar: Bool, tqdm.std.tqdm object @@ -138,9 +149,8 @@ def merge_root( first = True else: if append: - raise FileNotFoundError( - "File %s" + destination + " not found. File must exist to append." - ) + msg = f"File {destination} not found. Can only append to existing files." 
+ raise FileNotFoundError(msg) out_file = uproot.recreate( destination, compression=uproot.compression.Compression.from_code_pair( @@ -251,9 +261,9 @@ def merge_root( step_size=step_size, how=dict, filter_name=lambda b: b in kb, # noqa: B023 + cut=cut, + expressions=expressions, ): - for key in count_branches: - del chunk[key] for group in groups: if (len(group)) > 1: chunk.update( @@ -276,6 +286,7 @@ def merge_root( if branch_types is None: branch_types = {name: array.type for name, array in chunk.items()} if first: + first = False out_file.mktree( tree.name, branch_types, @@ -289,7 +300,6 @@ def merge_root( out_file[tree.name].extend(chunk) except AssertionError: msg = "TTrees must have the same structure to be merged. Are the branch_names correct?" - first = False else: try: diff --git a/src/hepconvert/root_to_parquet.py b/src/hepconvert/root_to_parquet.py index 28dc1ad..a550ede 100644 --- a/src/hepconvert/root_to_parquet.py +++ b/src/hepconvert/root_to_parquet.py @@ -13,6 +13,8 @@ def root_to_parquet( tree=None, drop_branches=None, keep_branches=None, + cut=None, + expressions=None, force=False, step_size="100 MB", list_to32=False, @@ -217,6 +219,8 @@ def root_to_parquet( for i in f[tree].iterate( step_size=step_size, filter_name=filter_b, + cut=cut, + expressions=expressions, ) ), out_file, diff --git a/tests/test_copy_root.py b/tests/test_copy_root.py index 3c3842f..5ebf19d 100644 --- a/tests/test_copy_root.py +++ b/tests/test_copy_root.py @@ -149,7 +149,7 @@ def test_keep_tree(tmp_path): with pytest.raises( ValueError, - match="Key 'tree5' does not match any TTree in ROOT file/Users/zobil/Desktop/directory/two_trees.root", + match=f"Key 'tree5' does not match any TTree in ROOT file{tmp_path}two_trees.root", ): hepconvert.copy_root( Path(tmp_path) / "copied.root", diff --git a/tests/test_skim.py b/tests/test_skim.py new file mode 100644 index 0000000..9e6c6bd --- /dev/null +++ b/tests/test_skim.py @@ -0,0 +1,96 @@ +from __future__ import annotations + +from 
pathlib import Path + +import awkward as ak +import pytest +import uproot + +import hepconvert + +skhep_testdata = pytest.importorskip("skhep_testdata") + + +def test_copy(tmp_path): + file = uproot.open(skhep_testdata.data_path("uproot-HZZ.root")) + cut_exp = "Jet_Px >= 10" + hepconvert.copy_root( + Path(tmp_path) / "copy.root", + skhep_testdata.data_path("uproot-HZZ.root"), + keep_branches="Jet_", + counter_name=lambda counted: "N" + counted, + force=True, + expressions="Jet_Px", + cut=cut_exp, + ) + hepconvert_file = uproot.open(Path(tmp_path) / "copy.root") + for key in hepconvert_file["events"].keys(): + if key.startswith("Jet_Px"): + assert key in file["events"].keys() + assert len(hepconvert_file["events"][key].arrays()) == 2321 + + +def test_trigger(tmp_path): + file = uproot.open( + skhep_testdata.data_path("nanoAOD_2015_CMS_Open_Data_ttbar.root") + ) + + hepconvert.copy_root( + Path(tmp_path) / "copy.root", + skhep_testdata.data_path("nanoAOD_2015_CMS_Open_Data_ttbar.root"), + keep_branches=["Jet_*"], + force=True, + cut="HLT_QuadPFJet_DoubleBTagCSV_VBF_Mqq200", + ) + hepconvert_file = uproot.open(Path(tmp_path) / "copy.root") + correct_len = 0 + for i in file["Events"]["HLT_QuadPFJet_DoubleBTagCSV_VBF_Mqq200"].array(): + if i is True: + correct_len += 1 + for key in hepconvert_file["Events"].keys(): + if key.startswith("Jet_"): + assert key in file["Events"].keys() + assert len(hepconvert_file["Events"][key].array()) == correct_len + assert ak.all( + hepconvert_file["Events"][key].array() + == file["Events"][key].array()[ + file["Events"]["HLT_QuadPFJet_DoubleBTagCSV_VBF_Mqq200"].array() + ] + ) + + +def test_incorrect_shape(tmp_path): + cut_exp = "Jet_Px >= 10" + with pytest.raises(IndexError): + hepconvert.copy_root( + Path(tmp_path) / "copy.root", + skhep_testdata.data_path("uproot-HZZ.root"), + counter_name=lambda counted: "N" + counted, + force=True, + cut=cut_exp, + ) + + +def test_merge_cut(tmp_path): + file = 
uproot.open(skhep_testdata.data_path("uproot-HZZ.root"))["events"] + + hepconvert.merge_root( + Path(tmp_path) / "test_simple.root", + [ + skhep_testdata.data_path("uproot-HZZ.root"), + skhep_testdata.data_path("uproot-HZZ.root"), + ], + cut="Jet_Px >= 10", + keep_branches="Jet_*", + progress_bar=True, + force=True, + counter_name=lambda counted: "N" + counted, + ) + hepconvert_file = uproot.open(Path(tmp_path) / "test_simple.root") + for key in hepconvert_file["events"].keys(): + if key.startswith("Jet_"): + assert key in file.keys() + assert ak.all( + hepconvert_file["events"][key].array()[:2421] + == file[key].array()[file["Jet_Px"].array() >= 10] + )