DAOS-16355 pydaos: add dir object cache for Datasets (#15888)
A Dataset typically reads many samples that are stored under just a few
common directories. To reduce the lookup cost for each sample file, this
commit introduces a cache of directory objects.

Signed-off-by: Denis Barakhtanov <[email protected]>
0xE0F authored on Feb 24, 2025
1 parent ff91fe2, commit f9ecc96
Showing 2 changed files with 322 additions and 105 deletions.
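
The second changed file is not shown on this page, but the diff below threads dir_cache_size down to torch_shim.torch_connect, so the cache itself lives below the Python layer. For illustration only, the idea can be sketched in pure Python along these lines, assuming an LRU eviction policy and a hypothetical open_dir callback (the real implementation may differ):

    import collections

    class _DirCache:
        """Illustrative only: maps a parent directory path to its open
        directory object, so repeated lookups of samples stored in the
        same directory skip re-resolving every path component."""

        def __init__(self, capacity=64 * 1024):  # mirrors DIR_CACHE_SIZE
            self._capacity = capacity
            self._entries = collections.OrderedDict()

        def get(self, path, open_dir):
            """Return the cached dir object for path, opening and caching
            it on a miss; evict the least recently used entry when full."""
            obj = self._entries.get(path)
            if obj is not None:
                self._entries.move_to_end(path)    # mark as recently used
                return obj
            obj = open_dir(path)                   # hypothetical callback
            self._entries[path] = obj
            if len(self._entries) > self._capacity:
                self._entries.popitem(last=False)  # drop the oldest entry
            return obj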
src/client/pydaos/torch/torch_api.py (13 additions & 6 deletions)
@@ -24,6 +24,7 @@
 ITER_BATCH_SIZE = 32
 READDIR_BATCH_SIZE = 128
 PARALLEL_SCAN_WORKERS = 16
+DIR_CACHE_SIZE = 64 * 1024
 
 
 def transform_fn_default(data):
@@ -56,6 +57,8 @@ class Dataset(TorchDataset):
         Function to transform samples from storage to in-memory representation
     readdir_batch_size: int (optional)
         Number of directory entries to read for each readdir call.
+    dir_cache_size: int (optional)
+        Number of directory object entries to cache in memory.
 
     Methods
@@ -78,12 +81,13 @@ class Dataset(TorchDataset):
     # pylint: disable=too-many-arguments
     def __init__(self, pool=None, cont=None, path=None,
                  transform_fn=transform_fn_default,
-                 readdir_batch_size=READDIR_BATCH_SIZE):
+                 readdir_batch_size=READDIR_BATCH_SIZE,
+                 dir_cache_size=DIR_CACHE_SIZE):
         super().__init__()
 
         self._pool = pool
         self._cont = cont
-        self._dfs = _Dfs(pool=pool, cont=cont)
+        self._dfs = _Dfs(pool=pool, cont=cont, dir_cache_size=dir_cache_size)
         self._transform_fn = transform_fn
         self._readdir_batch_size = readdir_batch_size

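The new keyword argument lets callers size the cache to roughly match the number of distinct sample directories. A hypothetical example, assuming the class is exposed as pydaos.torch.Dataset and using placeholder pool and container labels:

    from pydaos.torch import Dataset

    # "pool_label" and "cont_label" are placeholders.
    ds = Dataset(pool="pool_label", cont="cont_label",
                 dir_cache_size=128 * 1024)  # default DIR_CACHE_SIZE is 64 * 1024
    sample = ds[0]                           # map-style access by index
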
@@ -171,6 +175,8 @@ class IterableDataset(TorchIterableDataset):
         Number of directory entries to read for each readdir call.
     batch_size: int (optional)
         Number of samples to fetch per iteration.
+    dir_cache_size: int (optional)
+        Number of directory object entries to cache in memory.
 
     Methods
@@ -187,12 +193,13 @@ class IterableDataset(TorchIterableDataset):
     def __init__(self, pool=None, cont=None, path=None,
                  transform_fn=transform_fn_default,
                  readdir_batch_size=READDIR_BATCH_SIZE,
-                 batch_size=ITER_BATCH_SIZE):
+                 batch_size=ITER_BATCH_SIZE,
+                 dir_cache_size=DIR_CACHE_SIZE):
         super().__init__()
 
         self._pool = pool
         self._cont = cont
-        self._dfs = _Dfs(pool=pool, cont=cont)
+        self._dfs = _Dfs(pool=pool, cont=cont, dir_cache_size=dir_cache_size)
         self._transform_fn = transform_fn
         self._readdir_batch_size = readdir_batch_size
         self._batch_size = batch_size
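
The same knob applies to the iterable variant, which yields batch_size samples per iteration. A hypothetical example, again with placeholder labels:

    from pydaos.torch import IterableDataset

    ds = IterableDataset(pool="pool_label", cont="cont_label",
                         batch_size=32,              # default ITER_BATCH_SIZE
                         dir_cache_size=128 * 1024)
    for batch in ds:
        print(len(batch))  # each iteration yields up to batch_size samples
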
@@ -506,14 +513,14 @@ class _Dfs():
     Should not be used directly.
     """
 
-    def __init__(self, pool=None, cont=None, rd_only=True):
+    def __init__(self, pool=None, cont=None, rd_only=True, dir_cache_size=DIR_CACHE_SIZE):
         if pool is None:
             raise ValueError("pool label or UUID is required")
         if cont is None:
             raise ValueError("container label or UUID is required")
 
         self._dc = DaosClient()
-        (ret, dfs) = torch_shim.torch_connect(DAOS_MAGIC, pool, cont, rd_only)
+        (ret, dfs) = torch_shim.torch_connect(DAOS_MAGIC, pool, cont, rd_only, dir_cache_size)
         if ret != 0:
             raise OSError(ret, os.strerror(ret), f"could not connect to {pool}:{cont}")