-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathbuild_mask.py
74 lines (59 loc) · 2.83 KB
/
build_mask.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
"""Build a common validity mask across bias-adjustment references.

Regrid the ``tg_mean`` indicator of each reference once onto a common
regular 0.1° lat/lon grid, then mask every gridpoint that is NaN in any
reference. The mask (1 = valid in every reference, 0 = masked) is written
to ``{paths.base}mask.zarr``.
"""
import atexit
import os
from pathlib import Path

# esmpy (imported through xesmf / xscen) reads the ESMFMKFILE environment
# variable when it is imported, so it must be set *before* the imports below.
# If it is missing (e.g. the conda environment was not activated), fall back
# to <prefix>/lib/esmf.mk: on POSIX layouts os.py lives in
# <prefix>/lib/pythonX.Y/, hence the two ``.parent`` hops.
if 'ESMFMKFILE' not in os.environ:
    os.environ['ESMFMKFILE'] = str(Path(os.__file__).parent.parent / 'esmf.mk')

import xscen as xs  # noqa: E402  (must follow the ESMFMKFILE setup)
from dask.distributed import Client  # noqa: E402
import xarray as xr  # noqa: E402
import xesmf  # noqa: E402
from xscen import CONFIG  # noqa: E402

# YAML configuration files loaded into xscen's global CONFIG.
path = 'config/path_part.yml'
config = 'config/config_part.yml'
# Only print what is being loaded when run as a script, not when imported.
xs.load_config(path, config, verbose=(__name__ == '__main__'), reset=True)
if __name__ == '__main__':
    # Notify by e-mail when the script exits (xscen helper).
    atexit.register(xs.send_mail_on_exit, subject='Build mask')
    daskkws = CONFIG['dask'].get('client', {})
    tdd = CONFIG['tdd']  # keyword arguments forwarded to to_dataset_dict

    # Create the project catalog.
    pcat = xs.ProjectCatalog(CONFIG['project_catalog']['path'])

    # Default cluster size; any of these keys present in CONFIG's dask client
    # settings overrides the default instead of raising a duplicate-keyword
    # TypeError as a plain `Client(n_workers=..., **daskkws)` call would.
    client_kwargs = {
        'n_workers': 3,
        'threads_per_worker': 5,
        'memory_limit': "8GB",
        **daskkws,
    }
    with Client(**client_kwargs):
        # Common target grid, regular 0.1° in both directions, with bounds
        # spanning lon -83..-55 and lat 42..63
        # (cf_grid_2d(lon0_b, lon1_b, d_lon, lat0_b, lat1_b, d_lat)).
        ds_grid = xesmf.util.cf_grid_2d(-83, -55, 0.1, 42, 63, 0.1)

        # Load the inputs: one dataset per reference.
        dict_input = pcat.search(
            processing_level="indicators",
            adjustment='IC6',
            source='CanESM5',
            experiment='ssp245',
            variable='tg_mean',
            domain='QC',
        ).to_dataset_dict(**tdd)

        # Regrid each reference onto the common grid once; references already
        # present in the catalog under the mask domain are skipped.
        for did, ds in dict_input.items():
            var = list(ds.data_vars)[0]
            # Prefer the catalog id attribute: splitting the dict key on '.'
            # gives the wrong id if the id itself contains a dot.
            ds_id = ds.attrs.get('cat:id', did.split('.')[0])
            if pcat.exists_in_cat(id=ds_id, domain='QC-reg1c-mask',
                                  variable=var, processing_level="indicators"):
                continue

            out = xs.regrid_dataset(ds=ds[[var]], ds_grid=ds_grid,
                                    to_level=ds.attrs['cat:processing_level'])
            out.attrs['cat:domain'] = 'QC-reg1c-mask'
            # Drop the vestigial grid-mapping variable (presumably inherited
            # from the source rotated-pole grid); it no longer applies.
            out = out.drop_vars('rotated_pole', errors='ignore')
            # CF axis attributes; presumably what lets xscen resolve the
            # 'X'/'Y' keys of the rechunk spec below to lon/lat.
            out.lat.attrs['axis'] = 'Y'
            out.lon.attrs['axis'] = 'X'
            xs.save_and_update(
                ds=out,
                pcat=pcat,
                path=CONFIG['paths']['indicators'],
                save_kwargs=dict(rechunk={'time': -1, 'X': 50, 'Y': 50}),
            )

        # Build the mask: a gridpoint is kept (1) only if it is non-NaN in
        # every reference; a NaN in any reference masks it (0).
        ds = pcat.search(
            processing_level="indicators",
            domain='QC-reg1c-mask',
            variable='tg_mean',
        ).to_dataset(concat_on='bias_adjust_project')
        # Only the first time step is inspected. This assumes each
        # reference's NaN pattern is constant in time. NOTE(review): confirm.
        valid_in_all_refs = (
            ds.tg_mean.isel(time=0)
            .notnull()
            .all(dim='bias_adjust_project')
            .drop_vars('time')  # leftover scalar coordinate from isel
        )
        mask = xr.where(valid_in_all_refs, 1, 0).to_dataset(name='mask')
        # mode='o': xscen's overwrite mode, so re-running replaces the mask.
        xs.save_to_zarr(mask, f"{CONFIG['paths']['base']}mask.zarr",
                        mode='o')