Skip to content

Commit

Permalink
✨Supported raw stuffs in NpSafeSerializer
Browse files Browse the repository at this point in the history
  • Loading branch information
carefree0910 committed Oct 27, 2024
1 parent dd8e173 commit 65f2ac5
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 7 deletions.
39 changes: 37 additions & 2 deletions core/toolkit/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -798,6 +798,7 @@ def save(
folder: TPath,
data: Union["np.ndarray", Callable[[], "np.ndarray"]],
*,
to_raw: bool = False,
verbose: bool = True,
) -> None:
import numpy as np
Expand All @@ -811,7 +812,10 @@ def save(
with timeit(f"save '{folder}'", enabled=verbose):
if not isinstance(data, np.ndarray):
data = data()
np.save(array_path, data)
if to_raw:
data.tofile(array_path)
else:
np.save(array_path, data)
with (folder / cls.size_file).open("w") as f:
f.write(str(get_file_size(array_path)))

Expand All @@ -821,12 +825,31 @@ def load(cls, folder: TPath, *, mmap_mode: Optional[str] = None) -> "np.ndarray"

return np.load(to_path(folder) / cls.array_file, mmap_mode=mmap_mode) # type: ignore

@classmethod
def load_raw(
cls,
folder: TPath,
*,
dtype: "np.dtype",
shape: Tuple[int, ...],
mmap_mode: Optional[str] = None,
) -> "np.ndarray":
import numpy as np

array_path = to_path(folder) / cls.array_file
if mmap_mode is None:
return np.fromfile(array_path, dtype=dtype).reshape(shape)
return np.memmap(array_path, dtype=dtype, mode=mmap_mode, shape=shape) # type: ignore

@classmethod
def try_load(
cls,
folder: TPath,
*,
dtype: Optional["np.dtype"] = None,
shape: Optional[Tuple[int, ...]] = None,
mmap_mode: Optional[str] = None,
from_raw: bool = False,
no_load: bool = False,
**kwargs: Any,
) -> Optional["np.ndarray"]:
Expand All @@ -848,6 +871,12 @@ def try_load(
return None
if no_load:
return np.zeros(0)
if from_raw:
if kwargs:
raise ValueError("`kwargs` are not supported for `from_raw`")
if dtype is None or shape is None:
raise ValueError("`dtype` and `shape` are required for `from_raw`")
return cls.load_raw(folder, dtype=dtype, shape=shape, mmap_mode=mmap_mode)
return np.load(array_path, mmap_mode=mmap_mode, **kwargs) # type: ignore

@classmethod
Expand All @@ -856,7 +885,10 @@ def load_with(
folder: TPath,
init_fn: Callable[[], "np.ndarray"],
*,
dtype: Optional["np.dtype"] = None,
shape: Optional[Tuple[int, ...]] = None,
mmap_mode: Optional[str] = None,
use_raw: bool = False,
no_load: bool = False,
verbose: bool = True,
**kwargs: Any,
Expand All @@ -870,15 +902,18 @@ def load_with(

load_func = lambda: cls.try_load(
folder,
dtype=dtype,
shape=shape,
mmap_mode=mmap_mode,
from_raw=use_raw,
no_load=no_load,
**kwargs,
)
array = load_func()
if array is None:
folder = to_path(folder)
folder.mkdir(parents=True, exist_ok=True)
cls.save(folder, init_fn, verbose=verbose)
cls.save(folder, init_fn, to_raw=use_raw, verbose=verbose)
array = load_func()
if array is None: # pragma: no cover
raise RuntimeError(f"failed to load array from '{folder}'")
Expand Down
28 changes: 23 additions & 5 deletions tests/test_toolkit/test_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -420,7 +420,8 @@ class TestNpSafeSerializer(unittest.TestCase):
def setUp(self):
self.temp_dir = tempfile.TemporaryDirectory()
self.folder = Path(self.temp_dir.name)
self.data = np.array([1, 2, 3, 4, 5])
self.data = np.array([1, 2, 3, 4, 5], dtype="S7")
self.rawd = dict(dtype=self.data.dtype, shape=self.data.shape)

def tearDown(self):
self.temp_dir.cleanup()
Expand All @@ -429,15 +430,29 @@ def test_save(self):
NpSafeSerializer.save(self.folder, self.data)
self.assertTrue((self.folder / NpSafeSerializer.array_file).exists())
self.assertTrue((self.folder / NpSafeSerializer.size_file).exists())
NpSafeSerializer.cleanup(self.folder)
NpSafeSerializer.save(self.folder, self.data, to_raw=True)
self.assertTrue((self.folder / NpSafeSerializer.array_file).exists())
self.assertTrue((self.folder / NpSafeSerializer.size_file).exists())

def test_load(self):
NpSafeSerializer.save(self.folder, self.data)
loaded_data = NpSafeSerializer.load(self.folder)
np.testing.assert_array_equal(loaded_data, self.data)
NpSafeSerializer.cleanup(self.folder)
NpSafeSerializer.save(self.folder, self.data, to_raw=True)
loaded_data = NpSafeSerializer.load_raw(self.folder, **self.rawd)
np.testing.assert_array_equal(loaded_data, self.data)
loaded_data = NpSafeSerializer.load_raw(self.folder, **self.rawd, mmap_mode="r")
np.testing.assert_array_equal(loaded_data, self.data)

def test_try_load(self):
NpSafeSerializer.save(self.folder, self.data)
np.testing.assert_array_equal(NpSafeSerializer.try_load(self.folder), self.data)
NpSafeSerializer.cleanup(self.folder)
NpSafeSerializer.save(self.folder, self.data, to_raw=True)
loaded_data = NpSafeSerializer.try_load(self.folder, **self.rawd, from_raw=True)
np.testing.assert_array_equal(loaded_data, self.data)
self.assertIsNone(NpSafeSerializer.try_load(self.folder / "invalid"))
np.save(self.folder / NpSafeSerializer.array_file, self.data[..., :2])
self.assertIsNone(NpSafeSerializer.try_load(self.folder))
Expand All @@ -459,12 +474,15 @@ def test_try_load_invalid_size(self):
self.assertIsNone(loaded_data)

def test_load_with(self):
def init_fn():
return np.array([6, 7, 8, 9, 10])
def init():
return np.array([6, 7, 8, 9, 10], dtype=self.data.dtype)

NpSafeSerializer.cleanup(self.folder)
loaded_data = NpSafeSerializer.load_with(self.folder, init_fn)
np.testing.assert_array_equal(loaded_data, init_fn())
arr = NpSafeSerializer.load_with(self.folder, init)
np.testing.assert_array_equal(arr, init())
NpSafeSerializer.cleanup(self.folder)
arr = NpSafeSerializer.load_with(self.folder, init, **self.rawd, use_raw=True)
np.testing.assert_array_equal(arr, init())
self.assertTrue((self.folder / NpSafeSerializer.array_file).exists())
self.assertTrue((self.folder / NpSafeSerializer.size_file).exists())

Expand Down

0 comments on commit 65f2ac5

Please sign in to comment.