diff --git a/tests/test_0017-multi-basket-multi-branch-fetch.py b/tests/test_0017-multi-basket-multi-branch-fetch.py new file mode 100644 index 000000000..6b4737975 --- /dev/null +++ b/tests/test_0017-multi-basket-multi-branch-fetch.py @@ -0,0 +1,328 @@ +# BSD 3-Clause License; see https://github.com/scikit-hep/uproot4/blob/master/LICENSE + +from __future__ import absolute_import + +import sys +import json + +try: + from io import StringIO +except ImportError: + from StringIO import StringIO + +import numpy +import pytest +import skhep_testdata + +import uproot4 +import uproot4.interpret.numerical +import uproot4.interpret.library +import uproot4.source.futures + + +def test_any_basket(): + interpretation = uproot4.interpret.numerical.AsDtype(">i4") + + with uproot4.open( + skhep_testdata.data_path("uproot-sample-6.20.04-uncompressed.root") + )["sample/i4"] as branch: + assert branch.basket(0).array(interpretation).tolist() == [ + -15, + -14, + -13, + -12, + -11, + -10, + -9, + ] + assert branch.basket(1).array(interpretation).tolist() == [ + -8, + -7, + -6, + -5, + -4, + -3, + -2, + ] + assert branch.basket(2).array(interpretation).tolist() == [ + -1, + 0, + 1, + 2, + 3, + 4, + 5, + ] + assert branch.basket(3).array(interpretation).tolist() == [ + 6, + 7, + 8, + 9, + 10, + 11, + 12, + ] + assert branch.basket(4).array(interpretation).tolist() == [ + 13, + 14, + ] + + +def test_stitching_arrays(): + interpretation = uproot4.interpret.numerical.AsDtype("i8") + expectation = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14] + basket_arrays = [[0, 1, 2, 3, 4], [5, 6], [], [7, 8, 9], [10], [11, 12, 13, 14]] + basket_arrays = [numpy.array(x) for x in basket_arrays] + entry_offsets = numpy.array([0, 5, 7, 7, 10, 11, 15]) + library = uproot4.interpret.library._libraries["np"] + + for start in range(16): + for stop in range(15, -1, -1): + actual = interpretation.final_array( + basket_arrays, start, stop, entry_offsets, library, None + ) + assert expectation[start:stop] == actual.tolist() + + +def test_names_entries_to_ranges_or_baskets(): + with uproot4.open( + skhep_testdata.data_path("uproot-sample-6.20.04-uncompressed.root") + )["sample"] as sample: + out = sample._names_entries_to_ranges_or_baskets(["i4"], 0, 30) + assert all(x[0] == "i4" for x in out) + assert [x[2] for x in out] == [0, 1, 2, 3, 4] + assert [x[3] for x in out] == [ + (6992, 7091), + (16085, 16184), + (25939, 26038), + (35042, 35141), + (40396, 40475), + ] + + +def test_ranges_or_baskets_to_arrays(): + with uproot4.open( + skhep_testdata.data_path("uproot-sample-6.20.04-uncompressed.root") + )["sample"] as sample: + branch = sample["i4"] + + ranges_or_baskets = sample._names_entries_to_ranges_or_baskets(["i4"], 0, 30) + branchid_interpretation = { + id(branch): uproot4.interpret.numerical.AsDtype(">i4") + } + entry_start, entry_stop = (0, 30) + decompression_executor = uproot4.source.futures.TrivialExecutor() + interpretation_executor = uproot4.source.futures.TrivialExecutor() + array_cache = None + library = uproot4.interpret.library._libraries["np"] + + output = sample._ranges_or_baskets_to_arrays( + ranges_or_baskets, + branchid_interpretation, + entry_start, + entry_stop, + decompression_executor, + interpretation_executor, + array_cache, + library, + ) + assert output["i4"].tolist() == [ + -15, + -14, + -13, + -12, + -11, + -10, + -9, + -8, + -7, + -6, + -5, + -4, + -3, + -2, + -1, + 0, + 1, + 2, + 3, + 4, + 5, + 6, + 7, + 8, + 9, + 10, + 11, + 12, + 13, + 14, + ] + + +def test_branch_array(): + with uproot4.open( + 
skhep_testdata.data_path("uproot-sample-6.20.04-uncompressed.root") + )["sample/i4"] as branch: + assert branch.array( + uproot4.interpret.numerical.AsDtype(">i4"), library="np" + ).tolist() == [ + -15, + -14, + -13, + -12, + -11, + -10, + -9, + -8, + -7, + -6, + -5, + -4, + -3, + -2, + -1, + 0, + 1, + 2, + 3, + 4, + 5, + 6, + 7, + 8, + 9, + 10, + 11, + 12, + 13, + 14, + ] + + assert branch.array( + uproot4.interpret.numerical.AsDtype(">i4"), + entry_start=3, + entry_stop=-5, + library="np", + ).tolist() == [ + -12, + -11, + -10, + -9, + -8, + -7, + -6, + -5, + -4, + -3, + -2, + -1, + 0, + 1, + 2, + 3, + 4, + 5, + 6, + 7, + 8, + 9, + ] + + assert branch.array( + uproot4.interpret.numerical.AsDtype(">i4"), + entry_start=3, + entry_stop=-5, + interpretation_executor=uproot4.decompression_executor, + library="np", + ).tolist() == [ + -12, + -11, + -10, + -9, + -8, + -7, + -6, + -5, + -4, + -3, + -2, + -1, + 0, + 1, + 2, + 3, + 4, + 5, + 6, + 7, + 8, + 9, + ] + + with pytest.raises(ValueError): + branch.array(uproot4.interpret.numerical.AsDtype(">i8"), library="np") + + +def test_cache(): + with uproot4.open( + skhep_testdata.data_path("uproot-sample-6.20.04-uncompressed.root") + ) as f: + assert f.cache_key == "db4be408-93ad-11ea-9027-d201a8c0beef:/" + assert f["sample"].cache_key == "db4be408-93ad-11ea-9027-d201a8c0beef:/sample" + assert ( + f["sample/i4"].cache_key + == "db4be408-93ad-11ea-9027-d201a8c0beef:/sample:i4" + ) + i4 = f["sample/i4"] + assert list(f.array_cache) == [] + i4.array(uproot4.interpret.numerical.AsDtype(">i4"), library="np") + assert list(f.array_cache) == [ + "db4be408-93ad-11ea-9027-d201a8c0beef:/sample:i4:AsDtype(Bi4(),Li4()):0-30:np" + ] + + with pytest.raises(OSError): + i4.array( + uproot4.interpret.numerical.AsDtype(">i4"), entry_start=3, library="np" + ) + + i4.array(uproot4.interpret.numerical.AsDtype(">i4"), library="np") + + +def test_pandas(): + pandas = pytest.importorskip("pandas") + with uproot4.open( + skhep_testdata.data_path("uproot-sample-6.20.04-uncompressed.root") + )["sample/i4"] as branch: + series = branch.array( + uproot4.interpret.numerical.AsDtype(">i4"), + entry_start=3, + entry_stop=-5, + interpretation_executor=uproot4.decompression_executor, + library="pd", + ) + assert isinstance(series, pandas.Series) + assert series.values.tolist() == [ + -12, + -11, + -10, + -9, + -8, + -7, + -6, + -5, + -4, + -3, + -2, + -1, + 0, + 1, + 2, + 3, + 4, + 5, + 6, + 7, + 8, + 9, + ] diff --git a/uproot4/__init__.py b/uproot4/__init__.py index 1b1746fd0..f1efa4594 100644 --- a/uproot4/__init__.py +++ b/uproot4/__init__.py @@ -8,13 +8,6 @@ from uproot4.cache import LRUCache from uproot4.cache import LRUArrayCache -object_cache = LRUCache(100) -array_cache = LRUArrayCache("100 MB") - -import uproot4.interpret - -library = "ak" - from uproot4.source.memmap import MemmapSource from uproot4.source.file import FileSource from uproot4.source.http import HTTPSource @@ -22,6 +15,11 @@ from uproot4.source.xrootd import XRootDSource from uproot4.source.xrootd import MultithreadedXRootDSource from uproot4.source.cursor import Cursor +from uproot4.source.futures import TrivialExecutor +from uproot4.source.futures import ThreadPoolExecutor + +decompression_executor = ThreadPoolExecutor() +interpretation_executor = TrivialExecutor() from uproot4.reading import open from uproot4.reading import ReadOnlyFile @@ -33,6 +31,11 @@ from uproot4.model import has_class_named from uproot4.model import class_named +import uproot4.interpret +import uproot4.interpret.library + 
+default_library = "ak" + import uproot4.models.TObject import uproot4.models.TString import uproot4.models.TArray diff --git a/uproot4/behaviors/TBranch.py b/uproot4/behaviors/TBranch.py index 660427c62..5e2c90560 100644 --- a/uproot4/behaviors/TBranch.py +++ b/uproot4/behaviors/TBranch.py @@ -2,14 +2,22 @@ from __future__ import absolute_import +import sys import threading try: from collections.abc import Mapping + from collections.abc import MutableMapping except ImportError: from collections import Mapping + from collections import MutableMapping +try: + import queue +except ImportError: + import Queue as queue import uproot4.source.cursor +import uproot4.interpret.library import uproot4.reading import uproot4.models.TBasket import uproot4.models.TObjArray @@ -28,6 +36,42 @@ def _get_recursive(hasbranches, where): return None +def _regularize_entries_start_stop(num_entries, entry_start, entry_stop): + if entry_start is None: + entry_start = 0 + elif entry_start < 0: + entry_start += num_entries + entry_start = min(num_entries, max(0, entry_start)) + + if entry_stop is None: + entry_stop = num_entries + elif entry_stop < 0: + entry_stop += num_entries + entry_stop = min(num_entries, max(0, entry_stop)) + + if entry_stop < entry_start: + entry_stop = entry_start + + return int(entry_start), int(entry_stop) + + +def _regularize_executors(decompression_executor, interpretation_executor): + if decompression_executor is None: + decompression_executor = uproot4.decompression_executor + if interpretation_executor is None: + interpretation_executor = uproot4.interpretation_executor + return decompression_executor, interpretation_executor + + +def _regularize_array_cache(array_cache, file): + if isinstance(array_cache, MutableMapping): + return array_cache + elif array_cache is None: + return file._array_cache + else: + raise TypeError("array_cache must be None or a MutableMapping") + + class HasBranches(Mapping): @property def branches(self): @@ -36,6 +80,10 @@ def branches(self): def __getitem__(self, where): original_where = where + got = self._lookup.get(original_where) + if got is not None: + return got + if uproot4._util.isint(where): return self.branches[where] elif uproot4._util.isstr(where): @@ -55,6 +103,7 @@ def __getitem__(self, where): where = "/".join([x for x in where.split("/") if x != ""]) for k, v in self.iteritems(recursive=True): if where == k: + self._lookup[original_where] = v return v else: raise uproot4.KeyInFileError(original_where, self._file.file_path) @@ -62,6 +111,7 @@ def __getitem__(self, where): elif recursive: got = _get_recursive(self, where) if got is not None: + self._lookup[original_where] = got return got else: raise uproot4.KeyInFileError(original_where, self._file.file_path) @@ -69,6 +119,7 @@ def __getitem__(self, where): else: for branch in self.branches: if branch.name == where: + self._lookup[original_where] = branch return branch else: raise uproot4.KeyInFileError(original_where, self._file.file_path) @@ -231,11 +282,134 @@ def __iter__(self): def __len__(self): return len(self.branches) + def _names_entries_to_ranges_or_baskets( + self, branch_names, entry_start, entry_stop + ): + out = [] + for name in branch_names: + branch = self[name] + for basket_num, range_or_basket in branch.entries_to_ranges_or_baskets( + entry_start, entry_stop + ): + out.append((name, branch, basket_num, range_or_basket)) + return out + + def _ranges_or_baskets_to_arrays( + self, + ranges_or_baskets, + branchid_interpretation, + entry_start, + entry_stop, + 
decompression_executor, + interpretation_executor, + cache, + library, + ): + notifications = queue.Queue() + + branchid_name = {} + branchid_arrays = {} + branchid_num_baskets = {} + ranges = [] + range_args = {} + range_original_index = {} + original_index = 0 + + for name, branch, basket_num, range_or_basket in ranges_or_baskets: + if id(branch) not in branchid_name: + branchid_name[id(branch)] = name + branchid_arrays[id(branch)] = {} + branchid_num_baskets[id(branch)] = 0 + branchid_num_baskets[id(branch)] += 1 + + if isinstance(range_or_basket, tuple) and len(range_or_basket) == 2: + ranges.append(range_or_basket) + range_args[range_or_basket] = (branch, basket_num) + range_original_index[range_or_basket] = original_index + else: + notifications.put(range_or_basket) + + original_index += 1 + + self._file.source.chunks(ranges, notifications=notifications) + + def replace(ranges_or_baskets, original_index, basket): + name, branch, basket_num, range_or_basket = ranges_or_baskets[ + original_index + ] + ranges_or_baskets[original_index] = name, branch, basket_num, basket + + def chunk_to_basket(chunk, branch, basket_num): + try: + cursor = uproot4.source.cursor.Cursor(chunk.start) + basket = uproot4.models.TBasket.Model_TBasket.read( + chunk, cursor, {"basket_num": basket_num}, self._file, branch + ) + original_index = range_original_index[(chunk.start, chunk.stop)] + replace(ranges_or_baskets, original_index, basket) + except Exception: + notifications.put(sys.exc_info()) + else: + notifications.put(basket) + + output = {} + + def basket_to_array(basket): + try: + assert basket.basket_num is not None + branch = basket.parent + interpretation = branchid_interpretation[id(branch)] + basket_arrays = branchid_arrays[id(branch)] + basket_arrays[basket.basket_num] = interpretation.basket_array( + basket, branch + ) + if len(basket_arrays) == branchid_num_baskets[id(branch)]: + name = branchid_name[id(branch)] + output[name] = interpretation.final_array( + basket_arrays, + entry_start, + entry_stop, + branch.entry_offsets, + library, + branch, + ) + except Exception: + notifications.put(sys.exc_info()) + else: + notifications.put(None) + + while len(output) < len(branchid_arrays): + try: + obj = notifications.get(timeout=0.001) + except queue.Empty: + continue + + if isinstance(obj, uproot4.source.chunk.Chunk): + chunk = obj + args = range_args[(chunk.start, chunk.stop)] + decompression_executor.submit(chunk_to_basket, chunk, *args) + + elif isinstance(obj, uproot4.models.TBasket.Model_TBasket): + basket = obj + interpretation_executor.submit(basket_to_array, basket) + + elif obj is None: + pass + + elif isinstance(obj, tuple) and len(obj) == 3: + uproot4.source.futures.delayed_raise(*obj) + + else: + raise AssertionError(obj) + + return dict((name, output[name]) for name, _, _, _ in ranges_or_baskets) + class TBranch(HasBranches): def postprocess(self, chunk, cursor, context): fWriteBasket = self.member("fWriteBasket") + self._lookup = {} self._interpretation = None self._count_branch = None self._count_leaf = None @@ -254,7 +428,13 @@ def postprocess(self, chunk, cursor, context): self._embedded_baskets_lock = None elif self.has_member("fBaskets"): - self._embedded_baskets = self.member("fBaskets") + self._embedded_baskets = [] + for basket in self.member("fBaskets"): + if basket is not None: + basket._basket_num = self._num_normal_baskets + len( + self._embedded_baskets + ) + self._embedded_baskets.append(basket) self._embedded_baskets_lock = None else: @@ -275,6 +455,13 @@ def tree(self): 
out = out.parent return out + @property + def cache_key(self): + if isinstance(self._parent, uproot4.behaviors.TTree.TTree): + return self.parent.cache_key + ":" + self.name + else: + return self.parent.cache_key + "/" + self.name + @property def entry_offsets(self): if self._num_normal_baskets == 0: @@ -308,7 +495,13 @@ def embedded_baskets(self): self.tree.chunk, cursor, {}, self._file, self ) with self._embedded_baskets_lock: - self._embedded_baskets = baskets + self._embedded_baskets = [] + for basket in baskets: + if basket is not None: + basket._basket_num = self._num_normal_baskets + len( + self._embedded_baskets + ) + self._embedded_baskets.append(basket) return self._embedded_baskets @@ -358,19 +551,13 @@ def __repr__(self): repr(self.name), len(self), id(self) ) - def basket_compressed_bytes(self, basket_num): - raise NotImplementedError - - def basket_uncompressed_bytes(self, basket_num): - raise NotImplementedError - - def basket_cursor(self, basket_num): + def basket_chunk_bytes(self, basket_num): if 0 <= basket_num < self._num_normal_baskets: - return uproot4.source.cursor.Cursor(self.member("fBasketSeek")[basket_num]) + return int(self.member("fBasketBytes")[basket_num]) elif 0 <= basket_num < self.num_baskets: raise IndexError( """branch {0} has {1} normal baskets; cannot get """ - """basket cursor {2} because only normal baskets have cursors + """basket chunk {2} because only normal baskets have chunks in file {3}""".format( repr(self.name), self._num_normal_baskets, @@ -380,19 +567,23 @@ def basket_cursor(self, basket_num): ) else: raise IndexError( - """branch {0} has {1} baskets; cannot get basket cursor {2} + """branch {0} has {1} baskets; cannot get basket chunk {2} in file {3}""".format( repr(self.name), self.num_baskets, basket_num, self._file.file_path ) ) - def basket_chunk_bytes(self, basket_num): + def basket_chunk_cursor(self, basket_num): if 0 <= basket_num < self._num_normal_baskets: - return int(self.member("fBasketBytes")[basket_num]) + start = self.member("fBasketSeek")[basket_num] + stop = start + self.basket_chunk_bytes(basket_num) + cursor = uproot4.source.cursor.Cursor(start) + chunk = self._file.source.chunk(start, stop) + return chunk, cursor elif 0 <= basket_num < self.num_baskets: raise IndexError( - """branch {0} has {1} normal baskets; cannot get """ - """basket chunk {2} because only normal baskets have chunks + """branch {0} has {1} normal baskets; cannot get chunk and """ + """cursor for basket {2} because only normal baskets have cursors in file {3}""".format( repr(self.name), self._num_normal_baskets, @@ -402,16 +593,17 @@ def basket_chunk_bytes(self, basket_num): ) else: raise IndexError( - """branch {0} has {1} baskets; cannot get basket chunk {2} + """branch {0} has {1} baskets; cannot get cursor and chunk """ + """for basket {2} in file {3}""".format( repr(self.name), self.num_baskets, basket_num, self._file.file_path ) ) def basket_key(self, basket_num): - cursor = self.basket_cursor(basket_num) - start = cursor.index + start = self.member("fBasketSeek")[basket_num] stop = start + uproot4.reading.ReadOnlyKey._format_big.size + cursor = uproot4.source.cursor.Cursor(start) chunk = self._file.source.chunk(start, stop) return uproot4.reading.ReadOnlyKey( chunk, cursor, {}, self._file, self, read_strings=False @@ -419,12 +611,9 @@ def basket_key(self, basket_num): def basket(self, basket_num): if 0 <= basket_num < self._num_normal_baskets: - cursor = self.basket_cursor(basket_num) - start = cursor.index - stop = start + 
self.basket_chunk_bytes(basket_num) - chunk = self._file.source.chunk(start, stop) + chunk, cursor = self.basket_chunk_cursor(basket_num) return uproot4.models.TBasket.Model_TBasket.read( - chunk, cursor, {}, self._file, self + chunk, cursor, {"basket_num": basket_num}, self._file, self ) elif 0 <= basket_num < self.num_baskets: return self.embedded_baskets[basket_num - self._num_normal_baskets] @@ -435,3 +624,77 @@ def basket(self, basket_num): repr(self.name), self.num_baskets, basket_num, self._file.file_path ) ) + + def entries_to_ranges_or_baskets(self, entry_start, entry_stop): + entry_offsets = self.entry_offsets + out = [] + start = entry_offsets[0] + for basket_num, stop in enumerate(entry_offsets[1:]): + if entry_start < stop and start <= entry_stop: + if 0 <= basket_num < self._num_normal_baskets: + byte_start = self.member("fBasketSeek")[basket_num] + byte_stop = byte_start + self.basket_chunk_bytes(basket_num) + out.append((basket_num, (byte_start, byte_stop))) + elif 0 <= basket_num < self.num_baskets: + out.append((basket_num, self.basket(basket_num))) + else: + raise AssertionError((self.name, basket_num)) + start = stop + return out + + def array( + self, + interpretation=None, + entry_start=None, + entry_stop=None, + decompression_executor=None, + interpretation_executor=None, + array_cache=None, + library="ak", + ): + if interpretation is None: + interpretation = self.interpretation + branchid_interpretation = {id(self): interpretation} + + entry_start, entry_stop = _regularize_entries_start_stop( + self.num_entries, entry_start, entry_stop + ) + decompression_executor, interpretation_executor = _regularize_executors( + decompression_executor, interpretation_executor + ) + array_cache = _regularize_array_cache(array_cache, self._file) + library = uproot4.interpret.library._regularize_library(library) + + cache_key = "{0}:{1}:{2}-{3}:{4}".format( + self.cache_key, + interpretation.cache_key, + entry_start, + entry_stop, + library.name, + ) + if array_cache is not None: + got = array_cache.get(cache_key) + if got is not None: + return got + + ranges_or_baskets = [] + for basket_num, range_or_basket in self.entries_to_ranges_or_baskets( + entry_start, entry_stop + ): + ranges_or_baskets.append((None, self, basket_num, range_or_basket)) + + out = self._ranges_or_baskets_to_arrays( + ranges_or_baskets, + branchid_interpretation, + entry_start, + entry_stop, + decompression_executor, + interpretation_executor, + array_cache, + library, + )[None] + + if array_cache is not None: + array_cache[cache_key] = out + + return out diff --git a/uproot4/behaviors/TTree.py b/uproot4/behaviors/TTree.py index ab7ea1a28..9eacb4dc7 100644 --- a/uproot4/behaviors/TTree.py +++ b/uproot4/behaviors/TTree.py @@ -24,8 +24,15 @@ def __repr__(self): def postprocess(self, chunk, cursor, context): self._chunk = chunk + + self._lookup = {} + return self + @property + def cache_key(self): + return self.parent.parent.cache_key + self.name + @property def chunk(self): return self._chunk diff --git a/uproot4/cache.py b/uproot4/cache.py index 6f0e96098..b7da715d8 100644 --- a/uproot4/cache.py +++ b/uproot4/cache.py @@ -39,6 +39,13 @@ def __init__(self, limit): self._data = {} self._lock = threading.Lock() + def __repr__(self): + if self._limit is None: + limit = "(no limit)" + else: + limit = "({0}/{1} full)".format(self._current, self._limit) + return "".format(limit, id(self)) + @property def limit(self): """ @@ -131,3 +138,10 @@ def __init__(self, limit_bytes): else: limit = 
uproot4._util.memory_size(limit_bytes) super(LRUArrayCache, self).__init__(limit) + + def __repr__(self): + if self._limit is None: + limit = "(no limit)" + else: + limit = "({0}/{1} bytes full)".format(self._current, self._limit) + return "".format(limit, id(self)) diff --git a/uproot4/interpret/__init__.py b/uproot4/interpret/__init__.py index cc82fe920..198b495ef 100644 --- a/uproot4/interpret/__init__.py +++ b/uproot4/interpret/__init__.py @@ -3,147 +3,6 @@ from __future__ import absolute_import -class Library(object): - """ - Indicates the type of array to produce. - - * `imported`: The imported library or raises a helpful "how to" - message if it could not be imported. - * `wrap_numpy(array)`: Wraps a NumPy array into the native type for - this library. - * `wrap_jagged(array)`: Wraps a jagged array into the native type for - this library. - * `wrap_python(array)`: Wraps an array of Python objects into the native - type for this library. - """ - - @property - def imported(self): - raise AssertionError - - def wrap_numpy(self, array): - raise AssertionError - - def wrap_jagged(self, array): - raise AssertionError - - def wrap_python(self, array): - raise AssertionError - - def __repr__(self): - return repr(self.name) - - def __eq__(self, other): - return type(_libraries[self.name]) is type(_libraries[other.name]) # noqa: E721 - - -class NumPy(Library): - name = "np" - - @property - def imported(self): - import numpy - - return numpy - - def wrap_numpy(self, array): - return array - - def wrap_jagged(self, array): - return self.wrap_python(array) - - def wrap_python(self, array): - numpy = self.imported - out = numpy.zeros(len(array), dtype=numpy.object) - for i, x in enumerate(array): - out[i] = x - return out - - -class Awkward(Library): - name = "ak" - - @property - def imported(self): - try: - import awkward1 - except ImportError: - raise ImportError( - """install the 'awkward1' package with: - - pip install awkward1""" - ) - else: - return awkward1 - - -class Pandas(Library): - name = "pd" - - @property - def imported(self): - try: - import pandas - except ImportError: - raise ImportError( - """install the 'pandas' package with: - - pip install pandas - -or - - conda install pandas""" - ) - else: - return pandas - - def wrap_numpy(self, array): - pandas = self.imported - return pandas.Series(pandas) - - def wrap_jagged(self, array): - array = array.compact - pandas = self.imported - index = pandas.MultiIndex.from_arrays( - [array.parents, array.localindex], names=["entry", "subentry"] - ) - return pandas.Series(array.content, index=index) - - def wrap_python(self, array): - pandas = self.imported - return pandas.Series(array) - - -_libraries = { - NumPy.name: NumPy(), - Awkward.name: Awkward(), - Pandas.name: Pandas(), -} - -_libraries["numpy"] = _libraries[NumPy.name] -_libraries["Numpy"] = _libraries[NumPy.name] -_libraries["NumPy"] = _libraries[NumPy.name] -_libraries["NUMPY"] = _libraries[NumPy.name] - -_libraries["awkward1"] = _libraries[Awkward.name] -_libraries["Awkward1"] = _libraries[Awkward.name] -_libraries["AWKWARD1"] = _libraries[Awkward.name] -_libraries["awkward"] = _libraries[Awkward.name] -_libraries["Awkward"] = _libraries[Awkward.name] -_libraries["AWKWARD"] = _libraries[Awkward.name] - -_libraries["pandas"] = _libraries[Pandas.name] -_libraries["Pandas"] = _libraries[Pandas.name] -_libraries["PANDAS"] = _libraries[Pandas.name] - - -def _regularize_library(library): - try: - return _libraries[library] - except KeyError: - raise ValueError("unrecognized library: 
{0}".format(repr(library))) - - class Interpretation(object): """ Abstract class for interpreting TTree basket data as arrays (NumPy and @@ -155,23 +14,11 @@ class Interpretation(object): dimension) of the NumPy array that would be created. * `awkward_form`: Form of the Awkward Array that would be created (requires `awkward1`); used by the `ak.type` function. - * `empty_array(library)`: An empty, finalized array, as defined by this - Interpretation and Library. - * `num_items(num_bytes, num_entries)`: Predict the number of items. * `basket_array(data, byte_offsets)`: Create a basket_array from a basket's `data` and `byte_offsets`. - * `fillable_array(num_items, num_entries)`: Create the array that is - incrementally filled by baskets as they arrive. This may include - excess at the beginning of the first basket and the end of the - last basket to cover full baskets (trimmed later). - * `fill(basket_array, fillable_array, item_start, item_stop, entry_start, - entry_stop)`: Copy data from the basket_array to fillable_array - with possible transformations (e.g. big-to-native endian, shift - offsets). - * `trim(fillable_array, entry_start, entry_stop)`: Remove any excess - entries in the first and last baskets. - * `finalize(fillable_array, library)`: Return an array in the desired - form for a given Library. + * `final_array(basket_arrays, entry_start, entry_stop, entry_offsets, library)`: + Combine basket_arrays with basket excess trimmed and in the form + required by a given library. """ @property @@ -186,77 +33,25 @@ def numpy_dtype(self): def awkward_form(self): raise AssertionError - def empty_array(self, library): - raise AssertionError - - def num_items(self, num_bytes, num_entries): - raise AssertionError - def basket_array(self, data, byte_offsets): raise AssertionError - def fillable_array(self, num_items, num_entries): - raise AssertionError - - def fill( - self, - basket_array, - fillable_array, - item_start, - item_stop, - entry_start, - entry_stop, + def final_array( + self, basket_arrays, entry_start, entry_stop, entry_offsets, library ): raise AssertionError - def trim(self, fillable_array, entry_start, entry_stop): - raise AssertionError - - def finalize(self, fillable_array, library): - raise AssertionError - - def hook_before_basket_array(self, data, byte_offsets): - pass - - def hook_after_basket_array(self, data, byte_offsets, basket_array): - pass - - def hook_before_fillable_array(self, num_items, num_entries): - pass - - def hook_after_fillable_array(self, num_items, num_entries, fillable_array): - pass - - def hook_before_fill( - self, - basket_array, - fillable_array, - item_start, - item_stop, - entry_start, - entry_stop, - ): - pass - - def hook_after_fill( - self, - basket_array, - fillable_array, - item_start, - item_stop, - entry_start, - entry_stop, - ): + def hook_before_basket_array(self, *args, **kwargs): pass - def hook_before_trim(self, fillable_array, entry_start, entry_stop): + def hook_after_basket_array(self, *args, **kwargs): pass - def hook_after_trim(self, fillable_array, entry_start, entry_stop, trimmed_array): + def hook_before_final_array(self, *args, **kwargs): pass - def hook_before_finalize(self, fillable_array, library): + def hook_before_library_finalize(self, *args, **kwargs): pass - def hook_after_finalize(self, fillable_array, library, final_array): + def hook_after_final_array(self, *args, **kwargs): pass diff --git a/uproot4/interpret/jagged.py b/uproot4/interpret/jagged.py new file mode 100644 index 000000000..ddeaeafbc --- /dev/null 
+++ b/uproot4/interpret/jagged.py @@ -0,0 +1,48 @@ +# BSD 3-Clause License; see https://github.com/scikit-hep/uproot4/blob/master/LICENSE + +from __future__ import absolute_import + + +class CompactJaggedArray(object): + def __init__(self, offsets, content): + self._offsets = offsets + self._content = content + + @property + def offsets(self): + return self._offsets + + @property + def content(self): + return self._content + + @property + def parents(self): + raise NotImplementedError + + @property + def localindex(self): + raise NotImplementedError + + +class JaggedArray(object): + def __init__(self, starts, stops, content): + self._starts = starts + self._stops = stops + self._content = content + + @property + def starts(self): + return self._starts + + @property + def stops(self): + return self._stops + + @property + def content(self): + return self._content + + @property + def compact(self): + raise NotImplementedError diff --git a/uproot4/interpret/library.py b/uproot4/interpret/library.py new file mode 100644 index 000000000..853beb959 --- /dev/null +++ b/uproot4/interpret/library.py @@ -0,0 +1,199 @@ +# BSD 3-Clause License; see https://github.com/scikit-hep/uproot4/blob/master/LICENSE + +from __future__ import absolute_import + +import numpy + +import uproot4.interpret.jagged +import uproot4.interpret.objects + + +class Library(object): + """ + Indicates the type of array to produce. + + * `imported`: The imported library or raises a helpful "how to" + message if it could not be imported. + * `wrap_numpy(array)`: Wraps a NumPy array into the native type for + this library. + * `wrap_jagged(array)`: Wraps a jagged array into the native type for + this library. + * `wrap_python(array)`: Wraps an array of Python objects into the native + type for this library. 
+ """ + + @property + def imported(self): + raise AssertionError + + def empty(self, shape, dtype): + raise AssertionError + + def finalize(self, array, branch): + raise AssertionError + + def __repr__(self): + return repr(self.name) + + def __eq__(self, other): + return type(_libraries[self.name]) is type(_libraries[other.name]) # noqa: E721 + + +class NumPy(Library): + name = "np" + + @property + def imported(self): + import numpy + + return numpy + + def empty(self, shape, dtype): + return numpy.empty(shape, dtype) + + def finalize(self, array, branch): + if isinstance(array, uproot4.interpret.jagged.JaggedArray): + out = numpy.zeros(len(array), dtype=numpy.object) + for i, x in enumerate(array): + out[i] = x + return out + + else: + return array + + +class Awkward(Library): + name = "ak" + + @property + def imported(self): + try: + import awkward1 + except ImportError: + raise ImportError( + """install the 'awkward1' package with: + + pip install awkward1""" + ) + else: + return awkward1 + + +class Pandas(Library): + name = "pd" + + @property + def imported(self): + try: + import pandas + except ImportError: + raise ImportError( + """install the 'pandas' package with: + + pip install pandas + +or + + conda install pandas""" + ) + else: + return pandas + + def empty(self, shape, dtype): + return numpy.empty(shape, dtype) + + def finalize(self, array, branch): + pandas = self.imported + + if isinstance(array, uproot4.interpret.jagged.JaggedArray): + compact = array.compact + index = pandas.MultiIndex.from_arrays( + [compact.parents, compact.localindex], names=["entry", "subentry"] + ) + return pandas.Series(compact.content, index=index) + + elif isinstance(array, uproot4.interpret.objects.ObjectArray): + out = numpy.zeros(len(array), dtype=numpy.object) + for i, x in enumerate(array): + out[i] = x + return pandas.Series(out) + + else: + return pandas.Series(array) + + +class CuPy(Library): + name = "cp" + + @property + def imported(self): + try: + import cupy + except ImportError: + raise ImportError( + """install the 'cupy' package with: + + pip install cupy + +or + + conda install cupy""" + ) + else: + return cupy + + def empty(self, shape, dtype): + cupy = self.imported + return cupy.empty(shape, dtype) + + def finalize(self, array, branch): + cupy = self.imported + + if isinstance(array, uproot4.interpret.jagged.JaggedArray): + raise TypeError("jagged arrays and objects are not supported by CuPy") + + else: + return cupy.array(array) + + +_libraries = { + NumPy.name: NumPy(), + Awkward.name: Awkward(), + Pandas.name: Pandas(), + CuPy.name: CuPy(), +} + +_libraries["numpy"] = _libraries[NumPy.name] +_libraries["Numpy"] = _libraries[NumPy.name] +_libraries["NumPy"] = _libraries[NumPy.name] +_libraries["NUMPY"] = _libraries[NumPy.name] + +_libraries["awkward1"] = _libraries[Awkward.name] +_libraries["Awkward1"] = _libraries[Awkward.name] +_libraries["AWKWARD1"] = _libraries[Awkward.name] +_libraries["awkward"] = _libraries[Awkward.name] +_libraries["Awkward"] = _libraries[Awkward.name] +_libraries["AWKWARD"] = _libraries[Awkward.name] + +_libraries["pandas"] = _libraries[Pandas.name] +_libraries["Pandas"] = _libraries[Pandas.name] +_libraries["PANDAS"] = _libraries[Pandas.name] + +_libraries["cupy"] = _libraries[CuPy.name] +_libraries["Cupy"] = _libraries[CuPy.name] +_libraries["CuPy"] = _libraries[CuPy.name] +_libraries["CUPY"] = _libraries[CuPy.name] + + +def _regularize_library(library): + if isinstance(library, Library): + return _libraries[library.name] + + elif isinstance(library, 
type) and issubclass(library, Library): + return _libraries[library().name] + + else: + try: + return _libraries[library] + except KeyError: + raise ValueError("unrecognized library: {0}".format(repr(library))) diff --git a/uproot4/interpret/numerical.py b/uproot4/interpret/numerical.py index 411eaab52..5289640dd 100644 --- a/uproot4/interpret/numerical.py +++ b/uproot4/interpret/numerical.py @@ -16,36 +16,94 @@ def _dtype_shape(dtype): class Numerical(uproot4.interpret.Interpretation): - def empty_array(self, library): - return library.wrap_numpy(numpy.empty(0, dtype=self.numpy_dtype)) - - def fillable_array(self, num_items, num_entries): - assert num_items == num_entries - dtype, shape = _dtype_shape(self.to_dtype) - quotient, remainder = divmod(num_items, numpy.prod(shape)) - if remainder != 0: - raise ValueError( - "cannot reshape {0} items into dimensions {1}".format(num_items, shape) - ) - return numpy.empty(quotient, dtype=self.to_dtype) - - def fill( - self, - basket_array, - fillable_array, - item_start, - item_stop, - entry_start, - entry_stop, - ): - assert item_start == entry_start and item_stop == entry_stop - fillable_array.reshape(-1)[item_start:item_stop] = basket_array.reshape(-1) + @property + def to_dtype(self): + return self._to_dtype - def trim(self, fillable_array, entry_start, entry_stop): - return fillable_array[entry_start:entry_stop] + @property + def numpy_dtype(self): + return self._to_dtype - def finalize(self, fillable_array, library): - return library.wrap_numpy(fillable_array) + @property + def awkward_form(self): + raise NotImplementedError + + def final_array( + self, basket_arrays, entry_start, entry_stop, entry_offsets, library, branch + ): + self.hook_before_final_array( + basket_arrays=basket_arrays, + entry_start=entry_start, + entry_stop=entry_stop, + entry_offsets=entry_offsets, + library=library, + branch=branch, + ) + + if not entry_start < entry_stop: + output = library.empty((0,), self.to_dtype) + + else: + length = 0 + start = entry_offsets[0] + for basket_num, stop in enumerate(entry_offsets[1:]): + if start <= entry_start and entry_stop <= stop: + length += entry_stop - entry_start + elif start <= entry_start < stop: + length += stop - entry_start + elif start <= entry_stop <= stop: + length += entry_stop - start + elif entry_start < stop and start <= entry_stop: + length += stop - start + start = stop + + output = library.empty((length,), self.to_dtype) + + start = entry_offsets[0] + for basket_num, stop in enumerate(entry_offsets[1:]): + if start <= entry_start and entry_stop <= stop: + local_start = entry_start - start + local_stop = entry_stop - start + basket_array = basket_arrays[basket_num] + output[:] = basket_array[local_start:local_stop] + elif start <= entry_start < stop: + local_start = entry_start - start + local_stop = stop - start + basket_array = basket_arrays[basket_num] + output[: stop - entry_start] = basket_array[local_start:local_stop] + elif start <= entry_stop <= stop: + local_start = 0 + local_stop = entry_stop - start + basket_array = basket_arrays[basket_num] + output[start - entry_start :] = basket_array[local_start:local_stop] + elif entry_start < stop and start <= entry_stop: + basket_array = basket_arrays[basket_num] + output[start - entry_start : stop - entry_start] = basket_array + start = stop + + self.hook_before_library_finalize( + basket_arrays=basket_arrays, + entry_start=entry_start, + entry_stop=entry_stop, + entry_offsets=entry_offsets, + library=library, + branch=branch, + output=output, + ) + + output 
= library.finalize(output, branch) + + self.hook_after_final_array( + basket_arrays=basket_arrays, + entry_start=entry_start, + entry_stop=entry_stop, + entry_offsets=entry_offsets, + library=library, + branch=branch, + output=output, + ) + + return output class AsDtype(Numerical): @@ -56,14 +114,13 @@ def __init__(self, from_dtype, to_dtype=None): else: self._to_dtype = numpy.dtype(to_dtype) + def __repr__(self): + return "AsDtype({0}, {1})".format(repr(self._from_dtype), repr(self._to_dtype)) + @property def from_dtype(self): return self._from_dtype - @property - def to_dtype(self): - return self._to_dtype - _numpy_byteorder_to_cache_key = { "!": "B", ">": "B", @@ -107,26 +164,31 @@ def form(dtype, name): + "]" ) - return "AsDtype({0},{1})".format(from_dtype, to_dtype) - - @property - def numpy_dtype(self): - return self._to_dtype + return "{0}({1},{2})".format(type(self).__name__, from_dtype, to_dtype) - @property - def awkward_form(self): - raise NotImplementedError + def basket_array(self, basket, branch): + self.hook_before_basket_array(basket=basket, branch=branch) - def num_items(self, num_bytes, num_entries): + assert basket.byte_offsets is None dtype, shape = _dtype_shape(self._from_dtype) - quotient, remainder = divmod(num_bytes, dtype.itemsize) - assert remainder == 0 - return quotient + try: + output = basket.data.view(self._from_dtype).reshape((-1,) + shape) + except ValueError: + raise ValueError( + """basket {0} in branch {1} has the wrong number of bytes ({2}) """ + """for interpretation {3} +in file {4}""".format( + basket.basket_num, + repr(branch.name), + len(basket.data), + self, + branch.file.file_path, + ) + ) - def basket_array(self, data, byte_offsets): - assert byte_offsets is None - dtype, shape = _dtype_shape(self._from_dtype) - return data.view(self._from_dtype).reshape((-1,) + shape) + self.hook_before_basket_array(basket=basket, branch=branch, output=output) + + return output class AsArray(AsDtype): diff --git a/uproot4/interpret/objects.py b/uproot4/interpret/objects.py new file mode 100644 index 000000000..ac55f01e5 --- /dev/null +++ b/uproot4/interpret/objects.py @@ -0,0 +1,9 @@ +# BSD 3-Clause License; see https://github.com/scikit-hep/uproot4/blob/master/LICENSE + +from __future__ import absolute_import + +import uproot4.interpret.jagged + + +class ObjectArray(uproot4.interpret.jagged.JaggedArray): + pass diff --git a/uproot4/models/TBasket.py b/uproot4/models/TBasket.py index c11478003..6a03b77b4 100644 --- a/uproot4/models/TBasket.py +++ b/uproot4/models/TBasket.py @@ -9,6 +9,7 @@ import uproot4.model import uproot4.deserialization import uproot4.compression +import uproot4.behaviors.TBranch _tbasket_format1 = struct.Struct(">ihiIhh") @@ -17,10 +18,19 @@ class Model_TBasket(uproot4.model.Model): + def __repr__(self): + basket_num = self._basket_num if self._basket_num is not None else "(unknown)" + return "".format( + basket_num, repr(self._parent.name), id(self) + ) + def read_numbytes_version(self, chunk, cursor, context): pass def read_members(self, chunk, cursor, context): + assert isinstance(self._parent, uproot4.behaviors.TBranch.TBranch) + self._basket_num = context.get("basket_num") + ( self._members["fNbytes"], self._key_version, @@ -64,7 +74,7 @@ def read_members(self, chunk, cursor, context): cursor.skip(self._members["fKeylen"]) self._raw_data = None - self._data = cursor.bytes(chunk, self.border) + self._data = cursor.bytes(chunk, self.border, copy_if_memmap=True) else: if self.compressed_bytes != self.uncompressed_bytes: @@ -73,7 +83,9 @@ 
def read_members(self, chunk, cursor, context): ) self._raw_data = uncompressed.get(0, self.uncompressed_bytes) else: - self._raw_data = cursor.bytes(chunk, self.uncompressed_bytes) + self._raw_data = cursor.bytes( + chunk, self.uncompressed_bytes, copy_if_memmap=True + ) if self.border != self.uncompressed_bytes: self._data = self._raw_data[: self.border] @@ -90,6 +102,10 @@ def read_members(self, chunk, cursor, context): self._data = self._raw_data self._byte_offsets = None + @property + def basket_num(self): + return self._basket_num + @property def key_version(self): return self._key_version @@ -138,5 +154,10 @@ def data(self): def byte_offsets(self): return self._byte_offsets + def array(self, interpretation=None): + if interpretation is None: + interpretation = self._parent.interpretation + return interpretation.basket_array(self, self.parent) + uproot4.classes["TBasket"] = Model_TBasket diff --git a/uproot4/reading.py b/uproot4/reading.py index bc57a83f8..6da17fb0e 100644 --- a/uproot4/reading.py +++ b/uproot4/reading.py @@ -19,6 +19,7 @@ from collections import MutableMapping import uproot4.compression +import uproot4.cache import uproot4.source.cursor import uproot4.source.chunk import uproot4.source.memmap @@ -33,18 +34,21 @@ def open( file_path, - object_cache=uproot4.object_cache, - array_cache=uproot4.array_cache, + object_cache=100, + array_cache="100 MB", classes=uproot4.classes, **options ): """ Args: file_path (str or Path): File path or URL to open. - object_cache (None or MutableMapping): Cache of objects drawn from - ROOT directories (e.g histograms, TTrees, other directories). - array_cache (None or MutableMapping): Cache of arrays drawn from - TTrees. + object_cache (None, MutableMapping, or int): Cache of objects drawn + from ROOT directories (e.g histograms, TTrees, other directories); + if None, do not use a cache; if an int, create a new cache of this + size. + array_cache (None, MutableMapping, or memory size): Cache of arrays + drawn from TTrees; if None, do not use a cache; if a memory size, + create a new cache of this size. classes (None or MutableMapping): If None, defaults to uproot4.classes; otherwise, a container of class definitions that is both used to fill with new classes and search for dependencies. 
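A minimal sketch of how the reworked cache arguments are meant to be used, based on the new defaults above and on test_cache in the new test file; the file name here is a placeholder, not from the diff:

    import uproot4

    # An int builds a fresh LRUCache; a memory-size string builds an LRUArrayCache.
    # Passing None disables that cache, and any MutableMapping is used as given.
    with uproot4.open(
        "sample.root",            # placeholder path
        object_cache=100,         # up to 100 directory objects
        array_cache="100 MB",     # arrays evicted once this size is exceeded
    ) as f:
        print(f.cache_key)          # "<hex uuid>:/" -- prefix for per-branch cache keys
        print(list(f.array_cache))  # empty until a branch array is read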
@@ -108,8 +112,8 @@ class ReadOnlyFile(object): def __init__( self, file_path, - object_cache=uproot4.object_cache, - array_cache=uproot4.array_cache, + object_cache=100, + array_cache="100 MB", classes=uproot4.classes, **options ): @@ -238,8 +242,10 @@ def object_cache(self): def object_cache(self, value): if value is None or isinstance(value, MutableMapping): self._object_cache = value + elif uproot4._util.isint(value): + self._object_cache = uproot4.cache.LRUCache(value) else: - raise TypeError("object_cache must be None or a MutableMapping") + raise TypeError("object_cache must be None, a MutableMapping, or an int") @property def array_cache(self): @@ -249,8 +255,12 @@ def array_cache(self): def array_cache(self, value): if value is None or isinstance(value, MutableMapping): self._array_cache = value + elif uproot4._util.isint(value) or uproot4._util.isstr(value): + self._array_cache = uproot4.cache.LRUArrayCache(value) else: - raise TypeError("array_cache must be None or a MutableMapping") + raise TypeError( + "array_cache must be None, a MutableMapping, or a memory size" + ) @property def classes(self): @@ -1023,6 +1033,38 @@ def show_streamers(self, classname=None, stream=sys.stdout): """ self._file.show_streamers(classname=classname, stream=stream) + @property + def cache_key(self): + return self.file.hex_uuid + ":" + "/".join(self.path) + "/" + + @property + def object_cache(self): + return self._file._object_cache + + @object_cache.setter + def object_cache(self, value): + if value is None or isinstance(value, MutableMapping): + self._file._object_cache = value + elif uproot4._util.isint(value): + self._file._object_cache = uproot4.cache.LRUCache(value) + else: + raise TypeError("object_cache must be None, a MutableMapping, or an int") + + @property + def array_cache(self): + return self._file._array_cache + + @array_cache.setter + def array_cache(self, value): + if value is None or isinstance(value, MutableMapping): + self._file._array_cache = value + elif uproot4._util.isint(value) or uproot4._util.isstr(value): + self._file._array_cache = uproot4.cache.LRUArrayCache(value) + else: + raise TypeError( + "array_cache must be None, a MutableMapping, or a memory size" + ) + def iterclassnames( self, recursive=True, diff --git a/uproot4/source/cursor.py b/uproot4/source/cursor.py index d07ae2b9d..5f592dbfd 100644 --- a/uproot4/source/cursor.py +++ b/uproot4/source/cursor.py @@ -171,18 +171,27 @@ def field(self, chunk, format, move=True): self._index = stop return format.unpack(chunk.get(start, stop))[0] - def bytes(self, chunk, length, move=True): + def bytes(self, chunk, length, move=True, copy_if_memmap=False): """ Interpret data at this index of the Chunk as raw bytes with a given `length`. If `move` is False, only peek: don't update the index. + + If `copy_if_memmap` is True and the chunk is a np.memmap, it is copied. 
""" start = self._index stop = start + length if move: self._index = stop - return chunk.get(start, stop) + out = chunk.get(start, stop) + if copy_if_memmap: + step = out + while getattr(step, "base", None) is not None: + if isinstance(step, numpy.memmap): + return numpy.array(out, copy=True) + step = step.base + return out def array(self, chunk, length, dtype, move=True): """ diff --git a/uproot4/source/futures.py b/uproot4/source/futures.py index 392937733..0c8897a41 100644 --- a/uproot4/source/futures.py +++ b/uproot4/source/futures.py @@ -10,6 +10,7 @@ from __future__ import absolute_import +import os import sys import time import threading @@ -79,6 +80,56 @@ def add_done_callback(self, fn): return fn(self) +class TrivialExecutor(Executor): + """ + An Executor that doesn't manage any Threads or Resources. + """ + + def __repr__(self): + return "".format(id(self)) + + @property + def num_workers(self): + """ + Always returns 0, which indicates the lack of background workers. + """ + return 0 + + def __enter__(self): + """ + Returns self. + """ + return self + + def __exit__(self, exception_type, exception_value, traceback): + """ + Does nothing. + """ + pass + + def submit(self, fn, *args, **kwargs): + """ + Immediately evaluate the function `fn` with `args` and `kwargs`. + """ + if isinstance(fn, TrivialFuture): + return fn + else: + return TrivialFuture(fn(*args, **kwargs)) + + def map(self, func, *iterables): + """ + Like Python's Executor. + """ + for x in iterables: + yield func(x) + + def shutdown(self, wait=True): + """ + Does nothing. + """ + pass + + class ResourceExecutor(Executor): """ An Executor that doesn't manage any Threads, but does manage Resources, @@ -105,6 +156,7 @@ def __enter__(self): Passes `__enter__` to the Resource. """ self._resource.__enter__() + return self def __exit__(self, exception_type, exception_value, traceback): """ @@ -252,12 +304,106 @@ def run(self): assert isinstance(future, TaskFuture) try: - future._result = future._task(self._resource) + if self._resource is None: + future._result = future._task() + else: + future._result = future._task(self._resource) except Exception: future._excinfo = sys.exc_info() future._set_finished() +class ThreadPoolExecutor(Executor): + """ + An Executor that manages only Threads, not Resources. + + All Threads are shut down when exiting a context block. + """ + + def __init__(self, num_workers=None): + """ + Args: + num_workers (None or int): Number of threads to launch; if None, + use os.cpu_count(). + """ + if num_workers is None: + if hasattr(os, "cpu_count"): + num_workers = os.cpu_count() + + else: + import multiprocessing + + num_workers = multiprocessing.cpu_count() + + self._work_queue = queue.Queue() + self._workers = [] + for x in range(num_workers): + self._workers.append(ThreadResourceWorker(None, self._work_queue)) + for thread in self._workers: + thread.start() + + def __repr__(self): + return "".format( + len(self._workers), id(self) + ) + + @property + def num_workers(self): + """ + The number of Threads in this thread pool. + """ + return len(self._workers) + + @property + def workers(self): + """ + The Threads in this thread pool. + """ + return self._workers + + def __enter__(self): + """ + Returns self. + """ + return self + + def __exit__(self, exception_type, exception_value, traceback): + """ + Shuts down the Threads in the thread pool. + """ + self.shutdown() + + def submit(self, fn, *args, **kwargs): + """ + Submits a function to be evaluated by a Thread in the thread pool. 
+ """ + if len(args) != 0 or len(kwargs) != 0: + task = TaskFuture(lambda: fn(*args, **kwargs)) + else: + task = TaskFuture(fn) + + self._work_queue.put(task) + return task + + def map(self, func, *iterables): + """ + Like Python's Executor. + """ + futures = [self.submit(func, x) for x in iterables] + for future in futures: + yield future.result() + + def shutdown(self, wait=True): + """ + Puts None on the `work_queue` until all Threads get the message and + shut down. + """ + while any(thread.is_alive() for thread in self._workers): + for x in range(len(self._workers)): + self._work_queue.put(None) + time.sleep(0.001) + + class ThreadResourceExecutor(Executor): """ An Executor that manages Threads as well as Resources, such as file handles @@ -298,6 +444,7 @@ def __enter__(self): """ for thread in self._workers: thread.resource.__enter__() + return self def __exit__(self, exception_type, exception_value, traceback): """
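Read end to end, the diff wires branch reading through explicit interpretations, entry ranges, executors, and the per-file array cache. A short sketch of that path, mirroring the calls exercised in the new test file (the sample file comes from skhep_testdata; everything else is introduced by this diff):

    import skhep_testdata
    import uproot4
    import uproot4.interpret.numerical

    with uproot4.open(
        skhep_testdata.data_path("uproot-sample-6.20.04-uncompressed.root")
    )["sample/i4"] as branch:
        # a single basket, decoded with an explicit interpretation
        first = branch.basket(0).array(uproot4.interpret.numerical.AsDtype(">i4"))

        # a ranged read: negative entry_stop counts back from the end, and the
        # module-level executors are the defaults being passed explicitly
        values = branch.array(
            uproot4.interpret.numerical.AsDtype(">i4"),
            entry_start=3,
            entry_stop=-5,
            decompression_executor=uproot4.decompression_executor,    # ThreadPoolExecutor
            interpretation_executor=uproot4.interpretation_executor,  # TrivialExecutor
            library="np",
        )
        assert values.tolist()[0] == -12  # entry 3 of the -15..14 sample branch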