diff --git a/pipeline/validate.py b/pipeline/validate.py
index a29b369..e4838d2 100644
--- a/pipeline/validate.py
+++ b/pipeline/validate.py
@@ -16,12 +16,13 @@
 from pipeline.errors import *
 from pipeline.logs import init_logger, SUFFIXES, SUFFIX_LIST
-from pipeline.utils import BypassSwitch
+from pipeline.utils import BypassSwitch, open_kerchunk
 from ujson import JSONDecodeError
+from dask.distributed import LocalCluster
 
 
 ## 1. Array Selection Tools
 
-def find_dimensions(dimlen: int, divisions: int):
+def find_dimensions(dimlen: int, divisions: int) -> int:
     """Determine index of slice end position given length of dimension and fraction to assess"""
     # Round down then add 1
     slicemax = int(dimlen/divisions)+1
@@ -41,9 +42,9 @@ def get_vslice(shape: list, dtypes: list, lengths: list, divisions: list, logger
     logger.debug(f'Slice {vslice}')
     return vslice
 
-def get_concat_dims(xfiles, detailfile=None):
-    # Not usable with virtual dimensions
-    concat_dims = {'time':0}
+def get_concat_dims(xobjs, detailfile=None):
+    """Retrieve the sizes of the concatenation dims"""
+    concat_dims={}
     if os.path.isfile(detailfile):
         with open(detailfile) as f:
             details = json.load(f)
@@ -52,9 +53,7 @@ def get_concat_dims(xfiles, detailfile=None):
         for dim in details['concat_dims']:
             concat_dims[dim] = 0
 
-    for xf in xfiles:
-        # Open netcdf in lowest memory intensive way possible.
-        ds = xr.open_dataset(xf)
+    for ds in xobjs:
         for dim in concat_dims.keys():
             concat_dims[dim] += ds[dim].shape[0]
     return concat_dims
@@ -99,7 +98,7 @@ def pick_index(nfiles: list, indexes: list):
         indexes.append(index)
     return indexes
 
-def locate_kerchunk(args, logger, get_str=False):
+def locate_kerchunk(args, logger, get_str=False, remote_protocol='https'):
     """Gets the name of the latest kerchunk file for this project code"""
     files = os.listdir(args.proj_dir) # Get filename only
     kfiles = []
@@ -122,7 +121,7 @@
             if get_str:
                 return kfile, False
             else:
-                return open_kerchunk(kfile, logger, remote_protocol='https'), False
+                return open_kerchunk(kfile, logger, remote_protocol=remote_protocol), False
         elif check_complete:
             if not args.forceful:
                 logger.error('File already exists and no override is set')
@@ -149,44 +148,6 @@
         logger.error(f'No Kerchunk file located at {args.proj_dir} and no in-place validation indicated - exiting')
         raise MissingKerchunkError
 
-def open_kerchunk(kfile: str, logger, isparq=False, remote_protocol='file'):
-    """Open kerchunk file from JSON/parquet formats"""
-    if isparq:
-        logger.debug('Opening Kerchunk Parquet store')
-        from fsspec.implementations.reference import ReferenceFileSystem
-        fs = ReferenceFileSystem(
-            kfile,
-            remote_protocol='file',
-            target_protocol="file",
-            lazy=True)
-        return xr.open_dataset(
-            fs.get_mapper(),
-            engine="zarr",
-            backend_kwargs={"consolidated": False, "decode_times": False}
-        )
-    else:
-        logger.debug('Opening Kerchunk JSON file')
-        try:
-            mapper = fsspec.get_mapper('reference://',fo=kfile, target_options={"compression":None}, remote_protocol=remote_protocol)
-        except JSONDecodeError as err:
-            logger.error(f"Kerchunk file {kfile} appears to be empty")
-            raise MissingKerchunkError
-        # Need a safe repeat here
-        ds = None
-        attempts = 0
-        while attempts < 3 and not ds:
-            attempts += 1
-            try:
-                ds = xr.open_zarr(mapper, consolidated=False, decode_times=True)
-            except OverflowError:
-                ds = None
-            except Exception as err:
-                raise MissingKerchunkError(message=f'Failed to open kerchunk file {kfile}')
-        if not ds:
-            raise ChunkDataError
-        logger.debug('Successfully opened Kerchunk with virtual xarray ds')
-        return ds
-
 def mem_to_value(mem):
     """Convert a memory value i.e 2G into a value"""
     suffix = mem[-1]
@@ -199,17 +160,6 @@ def value_to_mem(value):
         suffix_index += 1
     return f'{value:.0f}{SUFFIX_LIST[suffix_index]}'
 
-def check_memory(nfiles, indexes, mem, logger):
-    logger.info(f'Performing Memory Allowance check for {len(indexes)} files')
-    memcap = mem_to_value(mem)
-    nftotal = 0
-    for index in indexes:
-        nftotal += os.path.getsize(nfiles[index])
-
-    logger.debug(f'Determined memory requirement is {nftotal} - allocated {memcap}')
-    if nftotal > memcap:
-        raise ExpectMemoryError(required=value_to_mem(nftotal), current=mem)
-
 def open_netcdfs(args, logger, thorough=False, concat_dims='time'):
     """Returns a single xarray object with one timestep:
     - Select a single file and a single timestep from that file
@@ -223,22 +173,13 @@
         thorough = True
     xobjs = []
     if not thorough:
-        if not args.bypass.skip_memcheck:
-            check_memory(xfiles, indexes, args.memory, logger)
-        else:
-            logger.warning('Memory checks bypassed')
-        for one, i in enumerate(indexes):
+        for i in indexes:
             xobjs.append(xr.open_dataset(xfiles[i]))
-
         if len(xobjs) == 0:
             logger.error('No valid timestep objects identified')
             raise NoValidTimeSlicesError(message='Kerchunk', verbose=args.verbose)
         return xobjs, indexes, xfiles
     else:
-        if not args.bypass.skip_memcheck:
-            check_memory(xfiles, [i for i in range(len(xfiles))], args.memory, logger)
-        else:
-            logger.warning('Memory checks bypassed')
         xobj = xr.concat([xr.open_dataset(fx) for fx in xfiles], dim=concat_dims, data_vars='minimal')
         return xobj, None, xfiles
 
@@ -286,7 +227,11 @@ def compare_data(vname: str, xbox, kerchunk_box, logger, bypass=False):
 
     tolerance = None
     testpass = True
-    if not np.array_equal(xbox, kerchunk_box, equal_nan=True):
+    try:
+        equality = np.array_equal(xbox, kerchunk_box, equal_nan=True)
+    except TypeError as err:
+        equality = np.array_equal(xbox, kerchunk_box)
+    if not equality:
         logger.warning(f'Failed equality check for {vname}')
         raise ValidationError
     try:
@@ -324,6 +269,10 @@
             raise ValidationError
 
 def validate_shape_to_tolerance(nfiles: int, xv, dims, xshape, kshape, logger, detailfile=None):
+    """Special case function for validating a shaped array to some tolerance
+     - Alternative to opening N files, only works if each file has roughly the same total shape.
+     - Tolerance is based on the number of files supplied, more files means the tolerance is lower?
+    """
     tolerance = 1/(nfiles*5)
     logger.info(f'Attempting shape bypass using concat-dim tolerance {tolerance*100}%')
     try:
@@ -621,6 +570,9 @@ def validate_dataset(args, fh=None, logid=None, **kwargs):
     logger = init_logger(args.verbose, args.mode,'validate', fh=fh, logid=logid)
     logger.info(f'Starting tests for {args.proj_code}')
 
+    # Experimenting with a local dask cluster for memory limit
+    cluster = LocalCluster(n_workers=1, threads_per_worker=1, memory_target_fraction=0.95, memory_limit=str(args.memory + 'B'))
+
     if hasattr(args, 'backtrack'):
         if args.backtrack:
             run_backtrack(args, logger)
@@ -663,15 +615,15 @@
     else:
         ## Set up loop variables
         fullset = bool(args.quality)
-        concat_dims = get_concat_dims(xfiles, detailfile=f'{args.proj_dir}/detail-cfg.json')
-
-        logger.info(f"Attempting file subset validation: {len(indexes)}/{nfiles}")
-        for step, index in enumerate(indexes):
-            xobj = xobjs[step]
-            logger.info(f'Running tests for selected file: {index} ({step+1}/{len(indexes)})')
-            fullset = attempt_timestep(args, xobj, kobj, step, nfiles, logger, concat_dims=concat_dims)
-            if fullset:
-                break
+        concat_dims = get_concat_dims(xobjs, detailfile=f'{args.proj_dir}/detail-cfg.json')
+        if not fullset:
+            logger.info(f"Attempting file subset validation: {len(indexes)}/{nfiles}")
+            for step, index in enumerate(indexes):
+                xobj = xobjs[step]
+                logger.info(f'Running tests for selected file: {index} ({step+1}/{len(indexes)})')
+                fullset = attempt_timestep(args, xobj, kobj, step, nfiles, logger, concat_dims=concat_dims)
+                if fullset:
+                    break
 
     if fullset:
         print()