-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathZonal_RFR_iterate.py
197 lines (161 loc) · 10.6 KB
/
Zonal_RFR_iterate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
# import netCDF4 as nc
import os
import datetime
import pandas as pd
# import seaborn as sns
import numpy as np
import scipy.signal as sps
# import xarray as xr
# %matplotlib inline
from matplotlib import pyplot as plt
import matplotlib.colors
# from cartopy import config
# import cartopy.crs as ccrs
from numpy import ma
# from tqdm.notebook import trange, tqdm
from time import time
from glob import glob
# from cartopy.util import add_cyclic_point
# Meredith G. L. Brown
# Jan 2023
#
# # set conditions
# # LAGS = [1,3,7,14,21,28,45,60,90,120] #
# # LAGS = [1,3,7,14,21,28,45,60] #
# LAGS = list(range(1,60,5))
# topnodes = 30 # in the new code this is the total top number of nodes, whereas in the old code it is top nodes per input feature
# normdat = 0 # 1 = yes, 0 = no
# min_thresh = 0.001
# sel_vars = 7
# windat = 1 #windowing 1= yes, 0=no
# winstart = 0
# winend = 601
# Member labels 'ens01' .. 'ens10', generated instead of typed out by hand.
members = ['ens{:02d}'.format(index) for index in range(1, 11)]
# sel_vars options: 1 = T1000+T050 (all zonal)
# 2 = AOD+T050 (all zonal)
# 3 = AOD+T050+T1000 (all zonal)
# 4 = T1000 (all zonal)
# 5 = T050 (all zonal)
# 6 = Global AOD, T050, T1000
# 7 = Tropics AOD, T050, T1000
# 8 = Tropics+SubtropN AOD, T050, T1000
# 9 = Tropics+SubtropN+TempN AOD, T050, T1000
# 10 = Tropics+SubtropN+TempN+PolN AOD, T050, T1000
# 11 = Tropics+SubtropN+TempN+PolN+SubtropS AOD, T050, T1000
# 12 = Tropics+SubtropS AOD, T050, T1000
# 13 = Tropics+SubtropS+TempS AOD, T050, T1000
# 14 = Tropics+SubtropS+TempS+PolS AOD, T050, T1000
# 15 = Tropics+SubtropS+SubtropN AOD, T050, T1000
# 16 = Tropics+SubtropS+SubtropN+TempN+TempS AOD, T050, T1000
# 17 = Tropics+SubtropS+SubtropN+TempN+TempS+PolN+PolS AOD, T050, T1000
fname = '/Users/merbrow/Documents/CLDERA/RFR-CLDERA/ZonalHSW_LV_AOD_T050_T1000.csv'

# ---------------------------------------------------------------------------
# Per-member driver: pick the feature series selected by `sel_vars`,
# optionally window and min-max normalize them, call RFR.run, then write the
# weight table and the pruned-network plot for that member.
#
# NOTE: LAGS, topnodes, normdat, min_thresh, sel_vars, windat, winstart and
# winend are read below but are not assigned anywhere in this file (their
# assignments near the top are commented out), so they must already exist in
# the namespace this script executes in.
# ---------------------------------------------------------------------------

# Project imports, hoisted out of the loop so they execute once.
# `normalize` is only referenced by the commented-out max-norm variant below.
from sklearn.preprocessing import normalize
from ZonalFunctions import readRFRfeatures
from RFR import run
from RFR import pruneTopEdges
from RFR import plotGraphNetwork

_ZONES_ALL = ['PolN', 'TempN', 'SubtropN', 'Tropics', 'SubtropS', 'TempS', 'PolS']
_VARS_ALL = ['AOD', 'T050', 'T1000']

# sel_vars -> (zones, variables, variable_major).
# variable_major=True lists every zone of the first variable, then every zone
# of the next variable; False lists every variable of the first zone, then the
# next zone.  The resulting key order is the feature order handed to RFR.run.
_SELECTIONS = {
    1: (_ZONES_ALL, ['T1000', 'T050'], True),
    2: (_ZONES_ALL, ['AOD', 'T050'], True),
    3: (_ZONES_ALL, ['AOD', 'T050', 'T1000'], True),
    4: (_ZONES_ALL, ['T1000'], True),
    5: (_ZONES_ALL, ['T050'], True),
    6: (['Globe'], _VARS_ALL, False),
    7: (['Tropics'], _VARS_ALL, False),
    8: (['Tropics', 'SubtropN'], _VARS_ALL, False),
    9: (['Tropics', 'SubtropN', 'TempN'], _VARS_ALL, False),
    10: (['Tropics', 'SubtropN', 'TempN', 'PolN'], _VARS_ALL, False),
    11: (['Tropics', 'SubtropN', 'TempN', 'PolN', 'SubtropS'], _VARS_ALL, False),
    12: (['Tropics', 'SubtropS'], _VARS_ALL, False),
    13: (['Tropics', 'SubtropS', 'TempS'], _VARS_ALL, False),
    14: (['Tropics', 'SubtropS', 'TempS', 'PolS'], _VARS_ALL, False),
    15: (['Tropics', 'SubtropS', 'SubtropN'], _VARS_ALL, False),
    16: (['Tropics', 'SubtropS', 'SubtropN', 'TempS', 'TempN'], _VARS_ALL, False),
    17: (['Tropics', 'SubtropS', 'SubtropN', 'TempS', 'TempN', 'PolS', 'PolN'], _VARS_ALL, False),
}


def _build_subset(member, selection):
    """Return the ordered list of data keys for `member` under `selection`.

    Keys have the form ``<zone><member>_<variable>``.

    Raises:
        ValueError: if `selection` is not one of the options in _SELECTIONS.
    """
    if selection not in _SELECTIONS:
        raise ValueError(
            'unsupported sel_vars value {!r}; expected one of {}'.format(
                selection, sorted(_SELECTIONS)))
    zones, variables, variable_major = _SELECTIONS[selection]
    if variable_major:
        return [zone + member + '_' + var for var in variables for zone in zones]
    return [zone + member + '_' + var for zone in zones for var in variables]


def _normalize_by_variable(data, keys):
    """Min-max scale each series using the range shared by its variable.

    All keys ending in the same ``_<variable>`` suffix are scaled with the
    minimum and maximum taken over every one of those series together.

    Returns a new dict of numpy arrays; `data` itself is left untouched.
    """
    # Group keys by the text after the last underscore (AOD / T050 / T1000).
    groups = {}
    for key in keys:
        groups.setdefault(key.rsplit('_', 1)[-1], []).append(key)

    scaled = dict(data)
    for group_keys in groups.values():
        gmin = min(np.min(data[key]) for key in group_keys)
        grange = max(np.max(data[key]) for key in group_keys) - gmin
        for key in group_keys:
            scaled[key] = (np.asarray(data[key]) - gmin) / grange
    return scaled


# Fail before any expensive work if a 0/1 switch holds something else;
# otherwise the run would only break later on an undefined df / gname.
if normdat not in (0, 1):
    raise ValueError('normdat must be 0 or 1, got {!r}'.format(normdat))
if windat not in (0, 1):
    raise ValueError('windat must be 0 or 1, got {!r}'.format(windat))

# The input file is the same for every member, so read it once.
RAW_DATA1 = readRFRfeatures(fname)
for key in RAW_DATA1.keys():
    # RAW_DATA[key] = list(normalize([RAW_DATA[key]], norm="max")[0])
    RAW_DATA1[key] = list(RAW_DATA1[key])

# Both outputs below are written under this directory.
os.makedirs('SHAPoutputs', exist_ok=True)

for member in members:
    SUBSET = _build_subset(member, sel_vars)
    DATA = {key: RAW_DATA1[key] for key in SUBSET}

    # Windowing happens before normalization, so the normalization range is
    # computed on the windowed series.
    if windat == 1:
        DATA = {key: DATA[key][winstart:winend] for key in SUBSET}

    # Global series for this member; not used further in this script.
    SUBSET_G = ['Globe'+member+'_AOD','Globe'+member+'_T050','Globe'+member+'_T1000']
    GLOBDATA = {key: RAW_DATA1[key] for key in SUBSET_G}

    if normdat == 1:
        DATA_NORM = _normalize_by_variable(DATA, SUBSET)
        df,model,training_data = run(DATA_NORM,LAGS,randVar=False,model=True)
    else:
        df,model,training_data = run(DATA,LAGS,randVar=False,model=True)
    data_INPUT,data_TARGET = training_data

    # write out all of the weights
    # NOTE(review): this name carries no normalization tag, so a normalized
    # run and a raw run with otherwise equal settings share one file.
    weights_tag = member+"_"+str(sel_vars)
    if windat == 1:
        weights_tag = weights_tag+"_WIND"+str(winend)
    pd.DataFrame(df).to_csv('SHAPoutputs/ZonalHSW_LV_'+weights_tag+'weights.csv')

    df_new = pruneTopEdges(df,n=topnodes,top_perc=None,minVal=min_thresh)

    # NOTE(review): the normalized plot name has no window tag, so normalized
    # runs that differ only in windowing share one file.
    plot_stem = "SHAPoutputs/Pruned_"+member+"_"+str(topnodes)+"_"+str(sel_vars)
    if normdat == 1:
        gname = plot_stem+"_NORM.png"
    elif windat == 1:
        gname = plot_stem+"_WIND"+str(winend)+".png"
    else:
        gname = plot_stem+".png"
    plotGraphNetwork(df_new,gname)