tests for the caching mechanism #27

Merged: 24 commits (Aug 9, 2023)

Commits
c668ccd
add tests for `caching.path`
keewis Aug 8, 2023
133cdb5
add tests for the basic encoders
keewis Aug 8, 2023
42aa856
add tests for the variable encoder
keewis Aug 8, 2023
f841f54
add tests for the `Group` encoder
keewis Aug 8, 2023
6ddebc5
add tests for the generic encoder
keewis Aug 8, 2023
378a361
add tests for the json preprocessor
keewis Aug 8, 2023
098cadf
fix a test bug
keewis Aug 8, 2023
b6450af
add tests for the datetime decoder
keewis Aug 8, 2023
edd23f9
add tests for the json postprocessor
keewis Aug 8, 2023
4745736
replace the callable `parse_bytes` parameter with `type_code`
keewis Aug 8, 2023
7e98481
use the new `postprocess` for decoding json objects
keewis Aug 8, 2023
2c85246
integrate the decoding of `numpy` arrays into `decode_array`
keewis Aug 8, 2023
6a6b13b
check `arr.fs.fs.protocol` since the outer filesystem is always `dir`
keewis Aug 8, 2023
b09650d
add tests for `decode_array`
keewis Aug 8, 2023
5983bb6
fix the array tests
keewis Aug 8, 2023
9d3f5c6
add tests for the `Variable` decoder
keewis Aug 9, 2023
8976683
add tests for the `Group` decoder
keewis Aug 9, 2023
3f7cd80
add tests for the hierarchy decoder
keewis Aug 9, 2023
7862a0a
check that records-per-chunk is passed through decode_group and decod…
keewis Aug 9, 2023
c390147
add tests for the high-level encoder
keewis Aug 9, 2023
58a7d2a
add tests for the high-level decoder
keewis Aug 9, 2023
2e39562
add tests for the cache creation function
keewis Aug 9, 2023
71d84e7
add tests for the cache reader
keewis Aug 9, 2023
ed44ba5
specifically cast timedeltas and datetimes to int64
keewis Aug 9, 2023
22 changes: 19 additions & 3 deletions ceos_alos2/array.py
@@ -6,6 +6,22 @@
 
 from ceos_alos2.utils import parse_bytes
 
+raw_dtypes = {
+    "C*8": np.dtype([("real", ">f4"), ("imag", ">f4")]),
+    "IU2": np.dtype(">u2"),
+}
+
+
+def parse_data(content, type_code):
+    dtype = raw_dtypes.get(type_code)
+    if dtype is None:
+        raise ValueError(f"unknown type code: {type_code}")
+
+    raw = np.frombuffer(content, dtype)
+    if type_code == "C*8":
+        return raw["real"] + 1j * raw["imag"]
+    return raw
+
 
 def normalize_chunksize(chunksize, dim_size):
     if chunksize in (None, -1) or chunksize > dim_size:
@@ -96,7 +112,7 @@ class Array:
     dtype: str | np.dtype = field(repr=True)
 
     # convert raw bytes to data
-    parse_bytes: callable = field(repr=False)
+    type_code: str = field(repr=False)
 
     # chunk sizes: chunks in (rows, cols)
     records_per_chunk: int | None = field(repr=True, default=None)
@@ -128,7 +144,7 @@ def __eq__(self, other):
             and self.shape == other.shape
             and self.dtype == other.dtype
             and self.records_per_chunk == other.records_per_chunk
-            and self.parse_bytes == other.parse_bytes
+            and self.type_code == other.type_code
         )
 
@@ -142,7 +158,7 @@ def __getitem__(self, indexers):
         for chunk_info, ranges in tasks:
             chunk = read_chunk(f, **chunk_info)
             raw_bytes = extract_ranges(chunk, ranges)
-            chunk_data = [self.parse_bytes(part) for part in raw_bytes]
+            chunk_data = [parse_data(part, type_code=self.type_code) for part in raw_bytes]
             data_.extend(chunk_data)
 
         data = np.stack(data_, axis=0)
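
Review note: with `parse_data` and `raw_dtypes` moved here from `sar_image`, the parser can be exercised on its own. A minimal sketch of the two supported type codes (buffer contents invented for illustration):

import numpy as np

from ceos_alos2.array import parse_data

# "C*8": pairs of big-endian float32 (real, imag) decode to complex values
content = np.array([(1.0, 2.0), (3.0, -4.0)], dtype=[("real", ">f4"), ("imag", ">f4")]).tobytes()
print(parse_data(content, type_code="C*8"))  # [1.+2.j 3.-4.j]

# "IU2": big-endian uint16, returned as parsed
print(parse_data(b"\x00\x01\x00\x02", type_code="IU2"))  # [1 2]

# any other code raises ValueError("unknown type code: ...")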
4 changes: 2 additions & 2 deletions ceos_alos2/sar_image/caching/__init__.py
@@ -1,6 +1,6 @@
 import json
 
-from ceos_alos2.sar_image.caching.decoders import decode_hierarchy, decode_objects
+from ceos_alos2.sar_image.caching.decoders import decode_hierarchy, postprocess
 from ceos_alos2.sar_image.caching.encoders import encode_hierarchy, preprocess
 from ceos_alos2.sar_image.caching.path import (
     local_cache_location,
@@ -19,7 +19,7 @@ def encode(obj):
 
 
 def decode(cache, records_per_chunk):
-    partially_decoded = json.loads(cache, object_hook=decode_objects)
+    partially_decoded = json.loads(cache, object_hook=postprocess)
 
     return decode_hierarchy(partially_decoded, records_per_chunk=records_per_chunk)
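
Review note: a minimal sketch of the `object_hook` round trip that `decode` now relies on; `postprocess` only rewrites `{"__type__": "tuple"}` objects, and everything else is left for `decode_hierarchy`. The cache snippet is invented for illustration:

import json

from ceos_alos2.sar_image.caching.decoders import postprocess

cache = '{"shape": {"__type__": "tuple", "data": [100, 50]}, "dtype": "complex64"}'
print(json.loads(cache, object_hook=postprocess))
# {'shape': (100, 50), 'dtype': 'complex64'}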
67 changes: 27 additions & 40 deletions ceos_alos2/sar_image/caching/decoders.py
@@ -5,7 +5,13 @@
 
 from ceos_alos2.array import Array
 from ceos_alos2.hierarchy import Group, Variable
-from ceos_alos2.sar_image.io import parse_data
 
 
+def postprocess(obj):
+    if obj.get("__type__") == "tuple":
+        return tuple(obj["data"])
+
+    return obj
+
+
 def decode_datetime(obj):
@@ -16,29 +22,16 @@ def decode_datetime(obj):
     return reference + offsets
 
 
-def decode_objects(obj):
+def decode_array(encoded, records_per_chunk):
     def default_decode(obj):
         return np.array(obj["data"], dtype=obj["dtype"])
 
-    if obj.get("__type__") == "tuple":
-        return tuple(obj["data"])
-    elif obj.get("__type__") != "array":
-        return obj
-
-    dtype = np.dtype(obj["dtype"])
-    decoders = {
-        "M": decode_datetime,
-    }
-    decoder = decoders.get(dtype.kind, default_decode)
+    if encoded.get("__type__") == "array":
+        dtype = np.dtype(encoded["dtype"])
+        decoders = {"M": decode_datetime}
+        decoder = decoders.get(dtype.kind, default_decode)
 
-    return decoder(obj)
-
-
-def decode_array(encoded, records_per_chunk):
-    if not isinstance(encoded, dict):
-        return encoded
-    elif encoded.get("__type__") != "record_array":
-        raise ValueError(f"unknown type: {encoded['__type__']}")
+        return decoder(encoded)
 
     mapper = fsspec.get_mapper(encoded["root"])
     try:
@@ -48,7 +41,7 @@ def decode_array(encoded, records_per_chunk):
 
     fs = DirFileSystem(fs=mapper.fs, path=mapper.root)
 
-    parser = curry(parse_data, type_code=encoded["type_code"])
+    type_code = encoded["type_code"]
     url = encoded["url"]
     shape = encoded["shape"]
     dtype = encoded["dtype"]
@@ -59,11 +52,23 @@ def decode_array(encoded, records_per_chunk):
         byte_ranges=byte_ranges,
         shape=shape,
         dtype=dtype,
-        parse_bytes=parser,
+        type_code=type_code,
         records_per_chunk=records_per_chunk,
     )
 
 
+def decode_variable(encoded, records_per_chunk):
+    data = decode_array(encoded["data"], records_per_chunk=records_per_chunk)
+
+    return Variable(dims=encoded["dims"], data=data, attrs=encoded["attrs"])
+
+
+def decode_group(encoded, records_per_chunk):
+    data = valmap(curry(decode_hierarchy, records_per_chunk=records_per_chunk), encoded["data"])
+
+    return Group(path=encoded["path"], url=encoded["url"], data=data, attrs=encoded["attrs"])
+
+
 def decode_hierarchy(encoded, records_per_chunk):
     type_ = encoded.get("__type__")
 
@@ -76,21 +81,3 @@ def decode_hierarchy(encoded, records_per_chunk):
         return encoded
 
     return decoder(encoded, records_per_chunk=records_per_chunk)
-
-
-def decode_group(encoded, records_per_chunk):
-    if encoded.get("__type__") != "group":
-        raise ValueError("not a group")
-
-    data = valmap(curry(decode_hierarchy, records_per_chunk=records_per_chunk), encoded["data"])
-
-    return Group(path=encoded["path"], url=encoded["url"], data=data, attrs=encoded["attrs"])
-
-
-def decode_variable(encoded, records_per_chunk):
-    if encoded.get("__type__") != "variable":
-        raise ValueError("not a variable")
-
-    data = decode_array(encoded["data"], records_per_chunk=records_per_chunk)
-
-    return Variable(dims=encoded["dims"], data=data, attrs=encoded["attrs"])
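
Review note: the body of `decode_datetime` is collapsed in this view, so the decode half below is an assumption inferred from `encode_datetime` in encoders.py and the `reference + offsets` return line:

import numpy as np

times = np.array(["2023-08-09T00:00:00", "2023-08-09T00:00:10"], dtype="datetime64[s]")

# encode (as in encoders.py): int64 offsets relative to the first timestamp
units, _ = np.datetime_data(times.dtype)
data = (times - times[0]).astype("int64").tolist()
encoding = {"reference": str(times[0]), "units": units}

# decode (assumed shape of the collapsed body): rebuild the reference, add back the offsets
reference = np.datetime64(encoding["reference"], encoding["units"])
offsets = np.array(data, dtype=f"timedelta64[{encoding['units']}]")
print(reference + offsets)  # the original `times`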
36 changes: 15 additions & 21 deletions ceos_alos2/sar_image/caching/encoders.py
@@ -8,20 +8,31 @@
 def encode_timedelta(obj):
     units, _ = np.datetime_data(obj.dtype)
 
-    return obj.astype(int).tolist(), {"units": units}
+    return obj.astype("int64").tolist(), {"units": units}
 
 
 def encode_datetime(obj):
     units, _ = np.datetime_data(obj.dtype)
     reference = obj[0]
 
     encoding = {"reference": str(reference), "units": units}
-    encoded = (obj - reference).astype(int).tolist()
+    encoded = (obj - reference).astype("int64").tolist()
 
     return encoded, encoding
 
 
-def encode_arraylike(obj):
+def encode_array(obj):
+    if isinstance(obj, Array):
+        return {
+            "__type__": "backend_array",
+            "root": obj.fs.path,
+            "url": obj.url,
+            "shape": obj.shape,
+            "dtype": str(obj.dtype),
+            "byte_ranges": obj.byte_ranges,
+            "type_code": obj.type_code,
+        }
+
     def default_encode(obj):
         return obj.tolist(), {}
 
@@ -40,21 +51,6 @@ def default_encode(obj):
     }
 
 
-def encode_array(obj):
-    if not isinstance(obj, Array):
-        return encode_arraylike(obj)
-
-    return {
-        "__type__": "record_array",
-        "root": obj.fs.path,
-        "url": obj.url,
-        "shape": obj.shape,
-        "dtype": str(obj.dtype),
-        "byte_ranges": obj.byte_ranges,
-        "type_code": obj.parse_bytes.keywords["type_code"],
-    }
-
-
 def encode_variable(var):
     encoded_data = encode_array(var.data)
 
@@ -70,10 +66,8 @@ def encode_group(group):
     def encode_entry(obj):
         if isinstance(obj, Group):
             return encode_group(obj)
-        elif isinstance(obj, Variable):
-            return encode_variable(obj)
-        else:
-            return ValueError(f"unknown type: {type(obj)}")
+        return encode_variable(obj)
 
     encoded_data = valmap(encode_entry, group.data)
 
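
Review note on commit ed44ba5: `astype(int)` resolves to the platform default integer, which is 32-bit on Windows under NumPy < 2, so nanosecond-scale offsets can silently wrap there; `astype("int64")` pins the width. A quick demonstration:

import numpy as np

offsets = np.array([0, 86_400_000_000_000], dtype="timedelta64[ns]")  # 0 and one day, in ns
print(offsets.astype("int64").tolist())  # [0, 86400000000000] on every platform
# with a 32-bit default int, offsets.astype(int) would wrap, since
# 86_400_000_000_000 > 2**31 - 1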
7 changes: 3 additions & 4 deletions ceos_alos2/sar_image/caching/path.py
@@ -1,9 +1,9 @@
 import hashlib
-import pathlib
 
 import platformdirs
 
 project_name = "xarray-ceos-alos2"
+cache_root = platformdirs.user_cache_path(project_name)
 
 
 def hashsum(data, algorithm="sha256"):
@@ -13,11 +13,10 @@ def hashsum(data, algorithm="sha256"):
 
 
 def local_cache_location(remote_root, path):
-    subdirs, fname = f"/{path}".rsplit("/", 1)
+    _, fname = f"/{path}".rsplit("/", 1)
     cache_name = f"{fname}.index"
 
-    local_root = pathlib.Path(platformdirs.user_cache_dir(project_name))
-    return local_root / hashsum(remote_root) / cache_name
+    return cache_root / hashsum(remote_root) / cache_name
 
 
 def remote_cache_location(remote_root, path):
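
Review note: the module-level `cache_root` lets both helpers share a single `platformdirs` lookup. A sketch of the resulting layout; the `hashsum` body is collapsed above, so the hex-digest line is an assumption, and the paths are invented:

import hashlib

import platformdirs

cache_root = platformdirs.user_cache_path("xarray-ceos-alos2")


def hashsum(data, algorithm="sha256"):
    # assumed body: hex digest of the UTF-8 encoded string
    return hashlib.new(algorithm, data.encode()).hexdigest()


def local_cache_location(remote_root, path):
    _, fname = f"/{path}".rsplit("/", 1)
    return cache_root / hashsum(remote_root) / f"{fname}.index"


print(local_cache_location("s3://bucket/scene", "image/IMG-HH"))
# e.g. ~/.cache/xarray-ceos-alos2/<sha256 of the remote root>/IMG-HH.index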
18 changes: 1 addition & 17 deletions ceos_alos2/sar_image/io.py
@@ -1,14 +1,11 @@
 import itertools
 import math
 
-import numpy as np
-from tlz.functoolz import curry
 from tlz.itertoolz import concat
 
 from ceos_alos2.array import Array
 from ceos_alos2.common import record_preamble
 from ceos_alos2.sar_image.file_descriptor import file_descriptor_record
-from ceos_alos2.sar_image.metadata import raw_dtypes
 from ceos_alos2.sar_image.processed_data import processed_data_record
 from ceos_alos2.sar_image.signal_data import signal_data_record
 from ceos_alos2.utils import to_dict
@@ -19,27 +16,14 @@
 }
 
 
-def parse_data(content, type_code):
-    dtype = raw_dtypes.get(type_code)
-    if dtype is None:
-        raise ValueError(f"unknown type code: {type_code}")
-
-    raw = np.frombuffer(content, dtype)
-    if type_code == "C*8":
-        return raw["real"] + 1j * raw["imag"]
-    return raw
-
-
 def create_array(fs, path, byte_ranges, shape, dtype, type_code, records_per_chunk):
-    parser = curry(parse_data, type_code=type_code)
-
     return Array(
         fs=fs,
         url=path,
         byte_ranges=byte_ranges,
         shape=shape,
         dtype=dtype,
-        parse_bytes=parser,
+        type_code=type_code,
         records_per_chunk=records_per_chunk,
     )
 
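
Review note: replacing the curried `parse_data` with the plain `type_code` string makes the array metadata directly JSON-serializable, which is what the cache encoder needs; previously encoders.py had to reach into `parse_bytes.keywords["type_code"]`. A small illustration (field values invented):

import json

encoded = {"__type__": "backend_array", "type_code": "C*8", "shape": [100, 50]}
print(json.dumps(encoded))  # fine: plain strings and numbers
# a curried callable in the same dict would raise
# TypeError: Object of type curry is not JSON serializable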
5 changes: 0 additions & 5 deletions ceos_alos2/sar_image/metadata.py
@@ -172,11 +172,6 @@ def metadata_to_groups(metadata):
     return Group(path="", url=None, data=processed, attrs=attrs)
 
 
-raw_dtypes = {
-    "C*8": np.dtype([("real", ">f4"), ("imag", ">f4")]),
-    "IU2": np.dtype(">u2"),
-}
-
 dtypes = {
     "C*8": np.dtype("complex64"),
     "IU2": np.dtype("uint16"),
17 changes: 10 additions & 7 deletions ceos_alos2/testing.py
@@ -149,12 +149,15 @@ def diff_array(a, b):
     sections = []
     if a.fs != b.fs:
         lines = ["Differing filesystem:"]
-        if a.fs.protocol != b.fs.protocol:
-            lines.append(f" L protocol {a.fs.protocol}")
-            lines.append(f" R protocol {b.fs.protocol}")
+        # fs.protocol is always `dir`, so we have to check the wrapped fs
+        if a.fs.fs.protocol != b.fs.fs.protocol:
+            lines.append(f" L protocol {a.fs.fs.protocol}")
+            lines.append(f" R protocol {b.fs.fs.protocol}")
         if a.fs.path != b.fs.path:
             lines.append(f" L path {a.fs.path}")
             lines.append(f" R path {b.fs.path}")
+        if len(lines) == 1:
+            lines.append(" (unknown differences)")
         sections.append(newline.join(lines))
     if a.url != b.url:
         lines = [
@@ -183,11 +186,11 @@ def diff_array(a, b):
             f" {a.dtype} != {b.dtype}",
         ]
         sections.append(newline.join(lines))
-    if a.parse_bytes != b.parse_bytes:
+    if a.type_code != b.type_code:
         lines = [
-            "Differing byte parser:",
-            f" L type_code {a.parse_bytes.parameters['type_code']}",
-            f" R type_code {b.parse_bytes.parameters['type_code']}",
+            "Differing type code:",
+            f" L type_code {a.type_code}",
+            f" R type_code {b.type_code}",
         ]
         sections.append(newline.join(lines))
     if a.records_per_chunk != b.records_per_chunk:
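
Review note on commit 6a6b13b: `decode_array` always wraps the underlying filesystem in a `DirFileSystem`, whose own `protocol` is the constant "dir", so the meaningful comparison lives one level down. A minimal sketch using an in-memory filesystem:

import fsspec
from fsspec.implementations.dirfs import DirFileSystem

inner = fsspec.filesystem("memory")
fs = DirFileSystem(fs=inner, path="/cache")

print(fs.protocol)     # "dir" for every DirFileSystem
print(fs.fs.protocol)  # "memory": the wrapped filesystem's protocol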