forked from luigifusco/compression-filter
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathfilter_wrapper.py
115 lines (95 loc) · 4.44 KB
/
filter_wrapper.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import argparse
import struct
import os
from typing import Tuple
from collections.abc import Mapping
def double_to_uint32(f):
    """Reinterpret the bit pattern of a float64 as two unsigned ints.

    ``f`` is packed as a C ``double`` and the same bytes are read back as two
    C ``unsigned int`` values, both using the platform's native byte order and
    size. No numeric conversion takes place; only the bits are reused.

    Returns:
        A 2-tuple of ints in native memory order. On common platforms these
        are two uint32 words; which one holds the high half depends on the
        machine's endianness.
    """
    # Native ('@') formats are used on purpose: presumably the C side of the
    # filter reassembles the double from the words on the same machine. Confirm.
    return struct.unpack('II', struct.pack('d', f))
def float_to_uint32(f):
    """Return the IEEE-754 float32 bit pattern of ``f`` as an unsigned int.

    ``f`` is packed as a C ``float`` (rounding to single precision) and the
    resulting bytes are read back as one native C ``unsigned int``.
    ``struct`` raises ``OverflowError`` if ``f`` is outside the float32 range.
    """
    (bits,) = struct.unpack('I', struct.pack('f', f))
    return bits
class JP2SPWV_Filter(Mapping):
    """Dataset-creation settings for the JP2SPWV HDF5 compression filter.

    An instance is a read-only mapping with the keys ``dtype``, ``chunks``,
    ``compression`` and ``compression_opts``. It can therefore be unpacked
    with ``**`` into a dataset-creation call (presumably h5py's
    ``create_dataset``; confirm against callers).

    The filter options (``compression_opts``) are laid out as::

        (height, width, bits32(base_cr), residual_type, <residual value>)

    ``<residual value>`` is one uint32 holding the float32 bit pattern of the
    residual value. The exception is ``fixed_sparsification``, which stores
    the float64 bit pattern split over two uint32 slots.
    """

    FILTER_ID = 308

    # Residual mode name -> integer code written into the filter options.
    _RESIDUAL_TYPES = {
        "quantile_target": 1,
        "max_error_target": 2,
        "relative_error_target": 3,
        "fixed_sparsification": 4,
    }

    def __init__(self, base_cr: float, height: int, width: int, residual_opt: Tuple[str, float], data_dim: int = 2, filter_path: str = None):
        """Validate the settings and precompute the filter options and chunk shape.

        Args:
            base_cr: Base compression ratio. It is converted to float and
                stored in the options as a float32 bit pattern.
            height: Chunk size of the second-to-last dimension; must be > 0.
            width: Chunk size of the last dimension; must be > 0.
            residual_opt: ``(mode, value)``. ``mode`` must be a key of
                ``_RESIDUAL_TYPES`` and ``value`` is converted to float.
            data_dim: Rank of the dataset; must be >= 2. Leading dimensions
                get chunk size 1.
            filter_path: Currently ignored. It is kept for backward
                compatibility (setting ``HDF5_PLUGIN_PATH`` from it was disabled).

        Raises:
            ValueError: If height/width are not positive, data_dim < 2, or
                the residual mode is unknown.
        """
        # Written as `not (... > 0)` so NaN is rejected, matching the old assert.
        if not (height > 0 and width > 0):
            raise ValueError(f"height and width must be positive, got height={height}, width={width}")
        if data_dim < 2:
            raise ValueError(f"data_dim must be at least 2, got {data_dim}")

        residual_type_str, residual_opt_val = residual_opt
        try:
            residual_type = self._RESIDUAL_TYPES[residual_type_str]
        except KeyError:
            # Previously this only printed a warning and produced options
            # without any residual fields, which the filter cannot decode.
            raise ValueError(
                f"Unknown residual_type {residual_type_str!r}, has to be one of "
                "'quantile_target', 'max_error_target', 'relative_error_target' "
                "or 'fixed_sparsification'"
            ) from None

        base_cr = float(base_cr)
        residual_opt_val = float(residual_opt_val)

        self.base_cr = base_cr
        self.height = height
        self.width = width
        self.residual_opt = residual_opt
        self.data_dim = data_dim

        hdf_filter_opts = [int(height), int(width), float_to_uint32(base_cr), residual_type]
        if residual_type_str == "fixed_sparsification":
            # This mode keeps its value at float64 precision, split over two uint32 slots.
            hdf_filter_opts.extend(double_to_uint32(residual_opt_val))
        else:
            hdf_filter_opts.append(float_to_uint32(residual_opt_val))
        self.hdf_filter_opts = tuple(hdf_filter_opts)

        # Each chunk is a single height x width slice: size 1 in every leading dimension.
        self.chunks = (1,) * (data_dim - 2) + (int(height), int(width))

    # Mapping-based filter object; compare hdf5plugin's filter classes:
    # https://github.com/silx-kit/hdf5plugin/blob/main/src/hdf5plugin/_filters.py
    @property
    def _kwargs(self):
        """Keyword arguments exposed through the Mapping interface."""
        return {
            'dtype': 'float32',
            'chunks': self.chunks,
            'compression': self.FILTER_ID,
            'compression_opts': self.hdf_filter_opts,
        }

    def __hash__(self):
        # Mapping defines __eq__, which sets its __hash__ to None; restore
        # hashing from the filter identity and its options.
        return hash((self.FILTER_ID, self.hdf_filter_opts))

    def __len__(self):
        return len(self._kwargs)

    def __iter__(self):
        return iter(self._kwargs)

    def __getitem__(self, item):
        return self._kwargs[item]
# CLI residual-target options, in the precedence order select_residual_opt checks them.
_RESIDUAL_OPTION_NAMES = (
    "quantile_target",
    "max_error_target",
    "relative_error_target",
    "fixed_sparsification",
)


def build_parser():
    """Return the argument parser for this module's command-line entry point."""
    parser = argparse.ArgumentParser()
    parser.add_argument('-b', '--base_cr', type=float, default=200.0, help='base compression ratio')
    parser.add_argument('-H', '--height', type=int, default=721, help='height of the data slice or size of latitude dim')
    parser.add_argument('-w', '--width', type=int, default=1440, help='width of the data slice or size of longitude dim')
    # The residual targets are alternatives: at most one may be given, and
    # giving none selects the default relative error target.
    residual = parser.add_mutually_exclusive_group()
    residual.add_argument('-m', '--max_error_target', default=None, type=float, help='residual mode "max_error_target" with this value')
    residual.add_argument('-r', '--relative_error_target', default=None, type=float, help='residual mode "relative_error_target" with this value')
    residual.add_argument('-q', '--quantile_target', default=None, type=float, help='residual mode "quantile_target" with this value')
    residual.add_argument('-s', '--fixed_sparsification', default=None, type=float, help='residual mode "fixed_sparsification" with this value')
    return parser


def select_residual_opt(args):
    """Pick the residual mode and value from parsed CLI arguments.

    Args:
        args: Namespace with one attribute per name in ``_RESIDUAL_OPTION_NAMES``;
            unset options are ``None``.

    Returns:
        ``(residual_type, value)`` for ``JP2SPWV_Filter(residual_opt=...)``.
        The first option that was set wins. If none was set, a notice is
        printed and ``("relative_error_target", 0.01)`` is returned.
    """
    for name in _RESIDUAL_OPTION_NAMES:
        value = getattr(args, name)
        # Compare with None rather than truthiness: 0.0 is a valid target value.
        if value is not None:
            return name, float(value)
    print('Using default settings: relative error target of 0.01')
    return "relative_error_target", 0.01


def main(argv=None):
    """Parse ``argv`` (``sys.argv[1:]`` when None) and print the filter options."""
    args = build_parser().parse_args(argv)
    jp2spwv_filter = JP2SPWV_Filter(
        base_cr=args.base_cr,
        height=args.height,
        width=args.width,
        residual_opt=select_residual_opt(args),
    )
    print(jp2spwv_filter.hdf_filter_opts)


if __name__ == "__main__":
    main()