Partially written validate docstrings
dwest77a committed Mar 28, 2024
1 parent 416197a commit 0c049ea
Showing 1 changed file with 31 additions and 79 deletions.
110 changes: 31 additions & 79 deletions pipeline/validate.py
@@ -16,12 +16,13 @@

from pipeline.errors import *
from pipeline.logs import init_logger, SUFFIXES, SUFFIX_LIST
from pipeline.utils import BypassSwitch
from pipeline.utils import BypassSwitch, open_kerchunk
from ujson import JSONDecodeError
from dask.distributed import LocalCluster

## 1. Array Selection Tools

def find_dimensions(dimlen: int, divisions: int):
def find_dimensions(dimlen: int, divisions: int) -> int:
"""Determine index of slice end position given length of dimension and fraction to assess"""
# Round down then add 1
slicemax = int(dimlen/divisions)+1
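# e.g. dimlen=100, divisions=4 gives slicemax = int(100/4)+1 = 26,
# i.e. the slice end covers just over a quarter of the dimension.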
@@ -41,9 +42,9 @@ def get_vslice(shape: list, dtypes: list, lengths: list, divisions: list, logger
logger.debug(f'Slice {vslice}')
return vslice

def get_concat_dims(xfiles, detailfile=None):
# Not usable with virtual dimensions
concat_dims = {'time':0}
def get_concat_dims(xobjs, detailfile=None):
"""Retrieve the sizes of the concatenation dims"""
concat_dims={}
if os.path.isfile(detailfile):
with open(detailfile) as f:
details = json.load(f)
@@ -52,9 +53,7 @@ def get_concat_dims(xfiles, detailfile=None):
for dim in details['concat_dims']:
concat_dims[dim] = 0

for xf in xfiles:
# Open netcdf in the least memory-intensive way possible.
ds = xr.open_dataset(xf)
for ds in xobjs:
for dim in concat_dims.keys():
concat_dims[dim] += ds[dim].shape[0]
return concat_dims
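A brief usage sketch (illustrative only, not part of this diff) of the reworked get_concat_dims, which now takes pre-opened datasets rather than file paths; xfiles and proj_dir are assumed stand-ins for values supplied by the surrounding pipeline:

xobjs = [xr.open_dataset(f) for f in xfiles]
sizes = get_concat_dims(xobjs, detailfile=f'{proj_dir}/detail-cfg.json')
# e.g. {'time': 3650} when ten 365-step files share a 'time' concat dim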
@@ -99,7 +98,7 @@ def pick_index(nfiles: list, indexes: list):
indexes.append(index)
return indexes

def locate_kerchunk(args, logger, get_str=False):
def locate_kerchunk(args, logger, get_str=False, remote_protocol='https'):
"""Gets the name of the latest kerchunk file for this project code"""
files = os.listdir(args.proj_dir) # Get filename only
kfiles = []
@@ -122,7 +121,7 @@ def locate_kerchunk(args, logger, get_str=False):
if get_str:
return kfile, False
else:
return open_kerchunk(kfile, logger, remote_protocol='https'), False
return open_kerchunk(kfile, logger, remote_protocol=remote_protocol), False
elif check_complete:
if not args.forceful:
logger.error('File already exists and no override is set')
@@ -149,44 +148,6 @@ def locate_kerchunk(args, logger, get_str=False):
logger.error(f'No Kerchunk file located at {args.proj_dir} and no in-place validation indicated - exiting')
raise MissingKerchunkError

def open_kerchunk(kfile: str, logger, isparq=False, remote_protocol='file'):
"""Open kerchunk file from JSON/parquet formats"""
if isparq:
logger.debug('Opening Kerchunk Parquet store')
from fsspec.implementations.reference import ReferenceFileSystem
fs = ReferenceFileSystem(
kfile,
remote_protocol='file',
target_protocol="file",
lazy=True)
return xr.open_dataset(
fs.get_mapper(),
engine="zarr",
backend_kwargs={"consolidated": False, "decode_times": False}
)
else:
logger.debug('Opening Kerchunk JSON file')
try:
mapper = fsspec.get_mapper('reference://',fo=kfile, target_options={"compression":None}, remote_protocol=remote_protocol)
except JSONDecodeError as err:
logger.error(f"Kerchunk file {kfile} appears to be empty")
raise MissingKerchunkError
# Need a safe repeat here
ds = None
attempts = 0
while attempts < 3 and not ds:
attempts += 1
try:
ds = xr.open_zarr(mapper, consolidated=False, decode_times=True)
except OverflowError:
ds = None
except Exception as err:
raise MissingKerchunkError(message=f'Failed to open kerchunk file {kfile}')
if not ds:
raise ChunkDataError
logger.debug('Successfully opened Kerchunk with virtual xarray ds')
return ds

def mem_to_value(mem):
"""Convert a memory value i.e 2G into a value"""
suffix = mem[-1]
@@ -199,17 +160,6 @@ def value_to_mem(value):
suffix_index += 1
return f'{value:.0f}{SUFFIX_LIST[suffix_index]}'
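# Rough round trip (assumes SUFFIXES maps 'K'/'M'/'G' to powers of 1000):
#   mem_to_value('2G') -> roughly 2e9 (bytes)
#   value_to_mem(2e9)  -> '2G'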

def check_memory(nfiles, indexes, mem, logger):
logger.info(f'Performing Memory Allowance check for {len(indexes)} files')
memcap = mem_to_value(mem)
nftotal = 0
for index in indexes:
nftotal += os.path.getsize(nfiles[index])

logger.debug(f'Determined memory requirement is {nftotal} - allocated {memcap}')
if nftotal > memcap:
raise ExpectMemoryError(required=value_to_mem(nftotal), current=mem)

def open_netcdfs(args, logger, thorough=False, concat_dims='time'):
"""Returns a single xarray object with one timestep:
- Select a single file and a single timestep from that file
@@ -223,22 +173,13 @@ def open_netcdfs(args, logger, thorough=False, concat_dims='time'):
thorough = True
xobjs = []
if not thorough:
if not args.bypass.skip_memcheck:
check_memory(xfiles, indexes, args.memory, logger)
else:
logger.warning('Memory checks bypassed')
for one, i in enumerate(indexes):
for i in indexes:
xobjs.append(xr.open_dataset(xfiles[i]))

if len(xobjs) == 0:
logger.error('No valid timestep objects identified')
raise NoValidTimeSlicesError(message='Kerchunk', verbose=args.verbose)
return xobjs, indexes, xfiles
else:
if not args.bypass.skip_memcheck:
check_memory(xfiles, [i for i in range(len(xfiles))], args.memory, logger)
else:
logger.warning('Memory checks bypassed')
xobj = xr.concat([xr.open_dataset(fx) for fx in xfiles], dim=concat_dims, data_vars='minimal')
return xobj, None, xfiles

@@ -286,7 +227,11 @@ def compare_data(vname: str, xbox, kerchunk_box, logger, bypass=False):
tolerance = None

testpass = True
if not np.array_equal(xbox, kerchunk_box, equal_nan=True):
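# equal_nan=True is only supported for numeric dtypes; for e.g. string arrays
# np.array_equal raises a TypeError, so fall back to a strict comparison.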
try:
equality = np.array_equal(xbox, kerchunk_box, equal_nan=True)
except TypeError as err:
equality = np.array_equal(xbox, kerchunk_box)
if not equality:
logger.warning(f'Failed equality check for {vname}')
raise ValidationError
try:
@@ -324,6 +269,10 @@ def compare_data(vname: str, xbox, kerchunk_box, logger, bypass=False):
raise ValidationError

def validate_shape_to_tolerance(nfiles: int, xv, dims, xshape, kshape, logger, detailfile=None):
"""Special case function for validating a shaped array to some tolerance
- Alternative to opening N files; only works if each file has roughly the same total shape.
- Tolerance is based on the number of files supplied; more files means a lower tolerance.
"""
tolerance = 1/(nfiles*5)
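# e.g. nfiles=10 gives tolerance = 1/50 = 0.02, i.e. a 2% allowance on the concat dim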
logger.info(f'Attempting shape bypass using concat-dim tolerance {tolerance*100}%')
try:
@@ -621,6 +570,9 @@ def validate_dataset(args, fh=None, logid=None, **kwargs):
logger = init_logger(args.verbose, args.mode,'validate', fh=fh, logid=logid)
logger.info(f'Starting tests for {args.proj_code}')

# Experimenting with a local dask cluster for memory limit
cluster = LocalCluster(n_workers=1, threads_per_worker=1, memory_target_fraction=0.95, memory_limit=str(args.memory + 'B'))
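# Assumes args.memory is a string like '2G' (see mem_to_value), so appending 'B'
# gives a dask-parseable limit such as '2GB'.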

if hasattr(args, 'backtrack'):
if args.backtrack:
run_backtrack(args, logger)
Expand Down Expand Up @@ -663,15 +615,15 @@ def validate_dataset(args, fh=None, logid=None, **kwargs):
else:
## Set up loop variables
fullset = bool(args.quality)
concat_dims = get_concat_dims(xfiles, detailfile=f'{args.proj_dir}/detail-cfg.json')

logger.info(f"Attempting file subset validation: {len(indexes)}/{nfiles}")
for step, index in enumerate(indexes):
xobj = xobjs[step]
logger.info(f'Running tests for selected file: {index} ({step+1}/{len(indexes)})')
fullset = attempt_timestep(args, xobj, kobj, step, nfiles, logger, concat_dims=concat_dims)
if fullset:
break
concat_dims = get_concat_dims(xobjs, detailfile=f'{args.proj_dir}/detail-cfg.json')
if not fullset:
logger.info(f"Attempting file subset validation: {len(indexes)}/{nfiles}")
for step, index in enumerate(indexes):
xobj = xobjs[step]
logger.info(f'Running tests for selected file: {index} ({step+1}/{len(indexes)})')
fullset = attempt_timestep(args, xobj, kobj, step, nfiles, logger, concat_dims=concat_dims)
if fullset:
break

if fullset:
print()
