Skip to content

Commit

Permalink
Updated with comments and rearranged sections
Browse files Browse the repository at this point in the history
  • Loading branch information
dwest77a committed Mar 6, 2024
1 parent 6355ed8 commit 32f7c3f
Showing 1 changed file with 130 additions and 57 deletions.
187 changes: 130 additions & 57 deletions pipeline/validate.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@
import math

from pipeline.errors import *
from pipeline.logs import init_logger, SUFFIXES, SUFFIX_LIST, BypassSwitch
from pipeline.logs import init_logger, SUFFIXES, SUFFIX_LIST
from pipeline.utils import BypassSwitch
from ujson import JSONDecodeError

## 1. Array Selection Tools

Expand All @@ -40,13 +42,15 @@ def get_vslice(shape: list, dtypes: list, lengths: list, divisions: list, logger
return vslice

def get_concat_dims(xfiles, detailfile=None):
# Not usable with virtual dimensions
concat_dims = {'time':0}
if os.path.isfile(detailfile):
with open(detailfile) as f:
details = json.load(f)
# Initialise concat dims
if 'concat_dims' in details:
concat_dims[details['concat_dims']] = 0
for dim in details['concat_dims']:
concat_dims[dim] = 0

for xf in xfiles:
# Open netcdf in lowest memory intensive way possible.
Expand Down Expand Up @@ -162,7 +166,11 @@ def open_kerchunk(kfile: str, logger, isparq=False, remote_protocol='file'):
)
else:
logger.debug('Opening Kerchunk JSON file')
mapper = fsspec.get_mapper('reference://',fo=kfile, target_options={"compression":None}, remote_protocol=remote_protocol)
try:
mapper = fsspec.get_mapper('reference://',fo=kfile, target_options={"compression":None}, remote_protocol=remote_protocol)
except JSONDecodeError as err:
logger.error(f"Kerchunk file {kfile} appears to be empty")
raise MissingKerchunkError
# Need a safe repeat here
ds = None
attempts = 0
Expand Down Expand Up @@ -200,7 +208,7 @@ def check_memory(nfiles, indexes, mem, logger):
if nftotal > memcap:
raise ExpectMemoryError(required=value_to_mem(nftotal), current=mem)

def open_netcdfs(args, logger, thorough=False):
def open_netcdfs(args, logger, thorough=False, concat_dims='time'):
"""Returns a single xarray object with one timestep:
- Select a single file and a single timestep from that file
- Verify that a single timestep can be selected (Yes: return this xarray object, No: select all files and select a single timestep from that)
Expand Down Expand Up @@ -229,7 +237,7 @@ def open_netcdfs(args, logger, thorough=False):
check_memory(xfiles, [i for i in range(len(xfiles))], args.memory, logger)
else:
logger.warning('Memory checks bypassed')
xobj = xr.concat([xr.open_dataset(fx) for fx in xfiles], dim='time', data_vars='minimal')
xobj = xr.concat([xr.open_dataset(fx) for fx in xfiles], dim=concat_dims, data_vars='minimal')
return xobj, None, xfiles

## 3. Validation Testing
Expand Down Expand Up @@ -276,7 +284,7 @@ def compare_data(vname: str, xbox, kerchunk_box, logger, bypass=False):
tolerance = None

testpass = True
if not np.array_equal(xbox, kerchunk_box):
if not np.array_equal(xbox, kerchunk_box, equal_nan=True):
logger.warning(f'Failed equality check for {vname}')
raise ValidationError
try:
Expand Down Expand Up @@ -370,17 +378,48 @@ def validate_shapes(xobj, kobj, step: int, nfiles: list, xv: str, logger, proj_c
except:
pass
logger.debug(f'{xv} - dimension-adjusted shapes - K: {kshape}, X: {xshape}')

if len(xshape) != len(kshape):
raise ShapeMismatchError(var=xv, first=kshape, second=xshape)
elif xshape != kshape and bypass_shape: # Special bypass-shape testing
if concat_dims == {}:
if xshape != kshape:
# Incorrect dimensions on the shapes of the arrays
if xshape != kshape and bypass_shape: # Special bypass-shape testing
logger.info('Attempting special bypass using tolerance feature')
validate_shape_to_tolerance(nfiles, xv, xobj[xv].dims, xshape, kshape, logger, detailfile=detailfile)
else:
raise TrueShapeValidationError
raise ShapeMismatchError(var=xv, first=xshape, second=kshape)

def check_for_nan(box, bypass, logger, label=None):
"""Special function for assessing if a box selection has non-NaN values within it.
- Needs further testing using different data types"""
logger.debug(f'Checking nan values for {label}')

def handle_boxissue(err):
if type(err) == TypeError:
return False
else:
if bypass.skip_boxfail:
logger.warning(f'{err} - Uncaught error bypassed')
return False
else:
raise err

if box.size == 1:
try:
isnan = np.isnan(box)
except Exception as err:
isnan = handle_boxissue(err)
else:
pass
try:
kb = np.array(box)
isnan = np.all(kb!=kb)
except Exception as err:
isnan = handle_boxissue(err)

if not isnan and box.size >= 1:
try:
isnan = np.all(kb == np.mean(kb))
except Exception as err:
isnan = handle_boxissue(err)
return isnan

def validate_selection(xvariable, kvariable, vname: str, divs: int, currentdiv: int, logger, bypass=BypassSwitch()):
"""Validate this data selection in xvariable/kvariable objects
Expand All @@ -395,8 +434,9 @@ def validate_selection(xvariable, kvariable, vname: str, divs: int, currentdiv:
logger.debug(f'Attempt {repeat} - {currentdiv} divs for {vname}')

vslice = []
shape = {}
shape = []
if divs > 1:
shape = xvariable.shape
logger.debug(f'Detected shape {shape} for {vname}')
dtypes = [xvariable[xvariable.dims[x]].dtype for x in range(len(xvariable.shape))]
lengths = [len(xvariable[xvariable.dims[x]]) for x in range(len(xvariable.shape))]
Expand All @@ -409,28 +449,24 @@ def validate_selection(xvariable, kvariable, vname: str, divs: int, currentdiv:
kbox = kvariable

# Zero shape means no point running divisions - just perform full check
if shape == {} and vslice == []:
if shape == [] and vslice == []:
logger.debug(f'Skipping to full selection (1 division) for {vname}')
currentdiv = 1

try:
kb = np.array(kbox)
if np.all(kb!=kb):
isnan = True
elif np.all(kb == np.mean(kb)):
isnan = True
else:
isnan = False
except Exception as err:
if bypass.skip_boxfail:
logger.warning(f'{err} - check versions')
isnan = True
else:
raise err
try_multiple = 0
knan, xnan = False, True
# Attempt nan checking multiple times due to network issues.
while try_multiple < 3 and knan != xnan:
knan = check_for_nan(kbox, bypass, logger, label='Kerchunk')
xnan = check_for_nan(xbox, bypass, logger, label='Xarray')
try_multiple += 1

if knan != xnan:
raise ValidationError('Kerchunk/NetCDF value mismatch - expected NaN, received values')

if kbox.size >= 1 and not isnan:
if kbox.size >= 1 and not knan:
# Evaluate kerchunk vs xarray and stop here
logger.debug(f'Found non-NaN values with box-size: {int(kbox.size)}')
logger.debug(f'Found comparable box-size: {int(kbox.size)} values')
compare_data(vname, xbox, kbox, logger, bypass=bypass.skip_data_sum)
else:
logger.debug(f'Attempt {repeat} - slice is Null')
Expand All @@ -442,26 +478,39 @@ def validate_selection(xvariable, kvariable, vname: str, divs: int, currentdiv:
if not bypass.skip_softfail:
raise SoftfailBypassError

def validate_data(xobj, kobj, xv: str, step: int, logger, bypass=BypassSwitch(), depth_default=128):
def validate_data(xobj, kobj, xv: str, step: int, logger, bypass=BypassSwitch(), depth_default=128, nfiles=2):
"""Run growing selection test for specified variable from xarray and kerchunk datasets"""
logger.info(f'{xv} : Starting growbox data tests for {step}')
logger.info(f'{xv} : Starting growbox data tests for {step+1} - {depth_default}')

kvariable, xvariable = match_timestamp(xobj[xv], kobj[xv], logger)
if nfiles > 1: # Timestep matching not required if only one file
kvariable, xvariable = match_timestamp(xobj[xv], kobj[xv], logger)
else:
kvariable = kobj[xv]
xvariable = xobj[xv]

# Attempt 128 divisions within selection - 128, 64, 32, 16, 8, 4, 2, 1
return validate_selection(xvariable, kvariable, xv, depth_default, depth_default, logger, bypass=bypass)

def validate_timestep(args, xobj, kobj, step: int, nfiles: int, logger, concat_dims={}):
def validate_timestep(args, xobj, kobj, step: int, nfiles: int, logger, concat_dims={}, index=0):
"""Run all tests for a single file which may or may not equate to 1 timestep"""
# Note: step indexed from 0

# Run Variable and Shape validation

if 'virtual' in concat_dims:
# Assume virtual dimension is first?
logger.info("Filtering out virtual dimension for testing")
virtual = {concat_dims['virtual']:index}
logger.debug(f'Kerchunk index: {index}')
kobj = kobj.isel(**virtual)

xvars = set(xobj.variables)
kvars = set(kobj.variables)
if xvars&kvars != xvars: # Overlap of sets - all xvars should be in kvars
missing = (xvars^kvars)&xvars
raise VariableMismatchError(missing=missing)
else:
logger.info(f'Passed Variable tests')
logger.info(f'Passed Variable tests - all required variables are present')
print()
for xv in xvars:
validate_shapes(xobj, kobj, step, nfiles, xv, logger, args.proj_code,
Expand All @@ -472,7 +521,7 @@ def validate_timestep(args, xobj, kobj, step: int, nfiles: int, logger, concat_d
logger.info(f'Passed all Shape tests')
print()
for xv in xvars:
validate_data(xobj, kobj, xv, step, logger, bypass=args.bypass)
validate_data(xobj, kobj, xv, step, logger, bypass=args.bypass, nfiles=nfiles)
logger.info(f'{xv} : Passed Data test')

def run_successful(args, logger):
Expand All @@ -495,8 +544,18 @@ def run_successful(args, logger):
os.makedirs(complete_dir)

# Open config file to get correct version
version_no = 'kr1.0'
detailfile = f'{args.proj_dir}/detail-cfg.json'
if os.path.isfile(detailfile):
with open(detailfile) as f:
details = json.load(f)
if 'version_no' in details:
version_no = details['version_no']
logger.info(f'Found version {version_no} in detail-cfg')
else:
logger.warning('detail-cfg.json file missing or unreachable - using default version number')

newfile = f'{complete_dir}/{args.proj_code}_kr1.0.json'
newfile = f'{complete_dir}/{args.proj_code}_{version_no}.json'
if args.dryrun:
logger.info(f'DRYRUN: mv {kfile} {newfile}')
else:
Expand Down Expand Up @@ -531,21 +590,16 @@ def run_backtrack(args, logger):

logger.info(f'{args.proj_code} Successfully backtracked to pre-validation')

def attempt_timestep(args, xobj, kobj, step, nfiles, logger, xfiles, depth=0, concat_dims={}):
def attempt_timestep(args, xobj, kobj, step, nfiles, logger, concat_dims={}, fullset=False):
"""Handler for attempting processing on a timestep multiple times.
- Handles error conditions"""
try:
validate_timestep(args, xobj, kobj, step, nfiles, logger, concat_dims=concat_dims)
except ShapeMismatchError as err:
if depth == 2:
if fullset:
raise TrueShapeValidationError
else:
return True
except XKShapeToleranceError as err:
if depth < 1:
# Try new routine to just get the key variable sizes
concat_dims = get_concat_dims(xfiles, detailfile=f'{args.proj_dir}/detail-cfg.json')
attempt_timestep(args, xobj, kobj, step+1, nfiles, logger, xfiles, depth=depth+1, concat_dims=concat_dims)
else:
return True
except Exception as err:
raise err

Expand Down Expand Up @@ -575,26 +629,45 @@ def validate_dataset(args):
if indexes == None:
args.quality = True

detailfile = f'{args.proj_dir}/detail-cfg.json'
with open(detailfile) as f:
details = json.load(f)

## Open kerchunk file
kobj, _v = locate_kerchunk(args, logger)
if not kobj:
raise MissingKerchunkError

## Set up loop variables
fullset = bool(args.quality)
virtual = False
if 'virtual_concat' in details:
virtual = details['virtual_concat']

if virtual:
concat_dims = {'virtual': details['combine_kwargs']['concat_dims'][0]}
# Perform virtual attempt
logger.info(f"Attempting file subset validation: {len(indexes)}/{nfiles} (virtual dimension)")
for step, index in enumerate(indexes):
xobj = xobjs[step]
logger.info(f'Running tests for selected file: {index} ({step+1}/{len(indexes)})')
attempt_timestep(args, xobj, kobj, step, nfiles, logger, concat_dims=concat_dims, index=index)
else:
## Set up loop variables
fullset = bool(args.quality)
concat_dims = get_concat_dims(xfiles, detailfile=f'{args.proj_dir}/detail-cfg.json')

if not fullset:
logger.info(f"Attempting file subset validation: {len(indexes)}/{nfiles}")
for step, index in enumerate(indexes):
xobj = xobjs[step]
logger.info(f'Running tests for selected file: {index} ({step+1}/{len(indexes)})')
fullset = attempt_timestep(args, xobj, kobj, step+1, nfiles, logger, xfiles)

if fullset:
print()
logger.info(f"Attempting total validation")
xobjs, indexes, nfiles = open_netcdfs(args, logger, thorough=True)
fullset = attempt_timestep(args, xobjs, kobj, 0, 1, logger, xfiles)
fullset = attempt_timestep(args, xobj, kobj, step, nfiles, logger, concat_dims=concat_dims)
if fullset:
break

if fullset:
print()
logger.info(f"Attempting total validation")
xobjs, indexes, nfiles = open_netcdfs(args, logger, thorough=True)
attempt_timestep(args, xobjs, kobj, 0, 1, logger, xfiles, concat_dims=concat_dims, fullset=True)

logger.info('All tests passed successfully')
print()
Expand Down

0 comments on commit 32f7c3f

Please sign in to comment.