Skip to content

Commit

Permalink
Merge pull request #21 from equinor/memory_use_2
Browse files Browse the repository at this point in the history
Memory use 2
  • Loading branch information
adamchengtkc authored Jul 5, 2024
2 parents 2095ab7 + e762043 commit 6a464d7
Show file tree
Hide file tree
Showing 6 changed files with 83 additions and 40 deletions.
12 changes: 8 additions & 4 deletions tests/warmth3d/test_3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
@pytest.mark.mpi
def test_3d_compare():
comm = MPI.COMM_WORLD
inc = 2100
inc = 1000
model_pickled = f"model-out-inc_{inc}.p"
if comm.rank == 0 and not os.path.isfile(model_pickled):
global runtime_1D_sim
Expand Down Expand Up @@ -46,7 +46,7 @@ def test_3d_compare():
i.adiab = 0.3e-3


model.simulator.simulate_every = 1
model.simulator.simulate_every = 2

#
# set 1D simulation parameters to be most similar to those in the (later) 3D simulation, for better comparison
Expand All @@ -62,7 +62,6 @@ def test_3d_compare():
print("Total time 1D simulations:", runtime_1D_sim)

pickle.dump( model, open( model_pickled, "wb" ) )
# model = pickle.load( open( model_pickled, "rb" ) )
try:
os.mkdir('mesh')
except FileExistsError:
Expand All @@ -84,7 +83,12 @@ def test_3d_compare():
hx = nnx // 2
hy = nny // 2

nn = model.builder.nodes[hy-mm2.padX][hx-mm2.padX]
nn0 = model.builder.nodes[hy-mm2.padX][hx-mm2.padX]

node_result_path = str(nn0.node_path).replace(".pickle", "_results")
assert os.path.exists(node_result_path), f"ERROR: Node result file {node_result_path} is missing."

nn = pickle.load(open(node_result_path,"rb"))
dd = nn._depth_out[:,0]

mm2_pos, mm2_temp = mm2.get_node_pos_and_temp(-1)
Expand Down
55 changes: 40 additions & 15 deletions warmth/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,10 @@ def __init__(self):
self._crust_ls:np.ndarray[np.float64]|None=None
self._lith_ls:np.ndarray[np.float64]|None=None
self._subsidence:np.ndarray[np.float64]|None=None

self.seabed_arr:np.ndarray[np.float64]|None=None
self.top_crust_arr:np.ndarray[np.float64]|None=None
self.top_lith_arr:np.ndarray[np.float64]|None=None
self.top_aest_arr:np.ndarray[np.float64]|None=None

@property
def shf(self)->float:
Expand All @@ -123,50 +126,72 @@ def result(self)-> Results|None:
Results|None
None if not simulated
"""
items = [self._depth_out,self.temperature_out,self._idsed]
if any(isinstance(i,type(None)) for i in items):
return None
else:
return Results(self._depth_out,self.temperature_out,self._idsed,self.sediments,self.kCrust,self.kLith,self.kAsth)
return Results(self._depth_out,self.temperature_out,self._idsed,self.sediments,self.kCrust,self.kLith,self.kAsth)

def clear_unused_data(self):
"""Removes most arrays of detailed input and output that are not needed by warmth3D, in order to save memory.
"""
# self.max_time = self._depth_out.shape[1]
self._depth_out = None
self.temperature_out =None
self._idsed = None
self.coord_initial = None
self._crust_ls = None
self._lith_ls = None
self._subsidence =None

def compute_derived_arrays(self):
"""Computes depths of seabed, top crust, top lithosphere and top aestenosphere, and stores them with the node.
This allows the depth and temperature arrays to be discarded to save memory.
"""
self.top_crust_arr = [ self._depth_out[ np.where(self._idsed[:,age] == -1)[0][0], age] for age in range(self.max_time)]
self.top_lith_arr = [ self._depth_out[ np.where(self._idsed[:,age] == -2)[0][0], age] for age in range(self.max_time)]
self.top_aest_arr = [ self._depth_out[ np.where(self._idsed[:,age] == -3)[0][0], age] for age in range(self.max_time)]
self.seabed_arr = np.array( [ self._depth_out[np.where(~np.isnan(self.temperature_out[:,age]))[0][0],age] for age in range(self.max_time)])


@property
def crust_ls(self)->np.ndarray[np.float64]:
if isinstance(self.result,Results):
all_age = self.result.ages
all_age = np.arange(len(self.top_lith_arr),dtype=np.int32)
val = np.zeros(all_age.size)
for age in all_age:
val[age] = self.result.crust_thickness(age)
val[age] = self.top_lith_arr[age] - self.top_crust_arr[age]

return val
else:
return self._crust_ls
@property
def lith_ls(self)->np.ndarray[np.float64]:
if isinstance(self.result,Results):
all_age = self.result.ages
all_age = np.arange(len(self.top_lith_arr),dtype=np.int32)
val = np.zeros(all_age.size)
for age in all_age:
val[age] = self.result.lithosphere_thickness(age)
val[age] = self.top_aest_arr[age] - self.top_lith_arr[age]
return val
else:
return self._lith_ls
@property
def subsidence(self)->np.ndarray[np.float64]:
if isinstance(self.result,Results):
all_age = self.result.ages
all_age = np.arange(len(self.top_lith_arr),dtype=np.int32)
val = np.zeros(all_age.size)
for age in all_age:
val[age] = self.result.seabed(age)
val[age] = self.seabed_arr[age]
# val[age] = self.result.seabed(age)
return val
else:
return self._subsidence
@property
def sed_thickness_ls(self)->float:
if isinstance(self.result,Results):
all_age = self.result.ages
all_age = np.arange(len(self.top_lith_arr),dtype=np.int32)
val = np.zeros(all_age.size)
for age in all_age:
seabed = self.result.seabed(age)
top_crust = self.result.top_crust(age)
# seabed = self.result.seabed(age)
# top_crust = self.result.top_crust(age)
seabed = self.seabed_arr[age]
top_crust = self.top_crust_arr[age]
val[age] = top_crust - seabed
return val
else:
Expand Down
2 changes: 1 addition & 1 deletion warmth/forward_modelling.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ def _sediment_conductivity_sekiguchi(mean_porosity: np.ndarray[np.float64], cond
temperature_K=273.15+mid_pt_temperautureC
conductivity = 1.84+358*((1.0227*conductivity)-1.882)*((1/temperature_K)-0.00068)
effective_conductivity = conductivity*(1-mean_porosity)
return effective_conductivity-0.5
return effective_conductivity

def _check_beta(self, wd_diff: float, beta_current: float, beta_all: np.ndarray[np.float64], Wd_diff_all: np.ndarray[np.float64]) -> tuple[bool, np.ndarray[np.float64], np.ndarray[np.float64]]:
"""Check if current beta factor matches the observed subsidence
Expand Down
27 changes: 18 additions & 9 deletions warmth/mesh_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,24 +138,32 @@ def interpolateNode(interpolationNodes: List[single_node], interpolationWeights=
node.X = np.sum( np.array( [node.X * w for node,w in zip(interpolationNodes,iWeightNorm)] ) )
node.Y = np.sum( np.array( [node.Y * w for node,w in zip(interpolationNodes,iWeightNorm)] ) )

times = range(node.result._depth.shape[1])
# self.top_crust_arr = [ self._depth_out[ np.where(self._idsed[:,age] == -1)[0][0], age] for age in range(self.max_time)]
# #print ("PING B")
# self.top_lith_arr = [ self._depth_out[ np.where(self._idsed[:,age] == -2)[0][0], age] for age in range(self.max_time)]
# #print ("PING C")
# self.top_aest_arr = [ self._depth_out[ np.where(self._idsed[:,age] == -3)[0][0], age] for age in range(self.max_time)]
# #print ("PING D")

# self.top_lithosphere(age)-self.top_crust(age)

if node.subsidence is None:
node.subsidence = np.sum( np.array( [ [node.result.seabed(t) for t in times] * w for node,w in zip(interpolationNodes,iWeightNorm)] ) , axis = 0)
node.subsidence = np.sum( np.array( [ node.seabed_arr[:] * w for node,w in zip(interpolationNodes,iWeightNorm)] ) , axis = 0)
if node.crust_ls is None:
node.crust_ls = np.sum( np.array( [ [node.result.crust_thickness(t) for t in times] * w for node,w in zip(interpolationNodes,iWeightNorm)] ) , axis = 0)
node.crust_ls = np.sum( np.array( [ (node.top_lith_arr[:]-node.top_crust_arr[:]) * w for node,w in zip(interpolationNodes,iWeightNorm)] ) , axis = 0)
if node.lith_ls is None:
node.lith_ls = np.sum( np.array( [ [node.result.lithosphere_thickness(t) for t in times] * w for node,w in zip(interpolationNodes,iWeightNorm)] ) , axis = 0)
node.crust_ls = np.sum( np.array( [ (node.top_aest_arr[:]-node.top_lithosphere[:]) * w for node,w in zip(interpolationNodes,iWeightNorm)] ) , axis = 0)

if node.beta is None:
node.beta = np.sum( np.array( [node.beta * w for node,w in zip(interpolationNodes,iWeightNorm)] ) , axis = 0)
if node.kAsth is None:
node.kAsth = np.sum( np.array( [node.kAsth * w for node,w in zip(interpolationNodes,iWeightNorm)] ) , axis = 0)
if node.kLith is None:
node.kLith = np.sum( np.array( [node.kLith * w for node,w in zip(interpolationNodes,iWeightNorm)] ) , axis = 0)
if node._depth_out is None:
node._depth_out = np.sum([node.result._depth_out*w for n,w in zip(interpolationNodes[0:1], [1] )], axis=0)
if node.temperature_out is None:
node.temperature_out = np.sum([n.result.temperature_out*w for n,w in zip(interpolationNodes[0:1], [1] )], axis=0)
# if node._depth_out is None:
# node._depth_out = np.sum([node.result._depth_out*w for n,w in zip(interpolationNodes[0:1], [1] )], axis=0)
# if node.temperature_out is None:
# node.temperature_out = np.sum([n.result.temperature_out*w for n,w in zip(interpolationNodes[0:1], [1] )], axis=0)

if node.sed is None:
node.sed = np.sum([n.sed*w for n,w in zip(interpolationNodes,iWeightNorm)], axis=0)
Expand All @@ -167,7 +175,7 @@ def interpolateNode(interpolationNodes: List[single_node], interpolationWeights=
def interpolate_all_nodes(builder:Builder)->Builder:
for ni in range(len(builder.nodes)):
for nj in range(len(builder.nodes[ni])):
if builder.nodes[ni][nj] is False:
if (builder.nodes[ni][nj] is False) or (not builder.nodes[ni][nj]._full_simulation):
closest_x_up = []
for j in range(ni,len(builder.nodes[nj])):
matching_x = [ i[0] for i in builder.indexer_full_sim if i[0]==j ]
Expand Down Expand Up @@ -195,6 +203,7 @@ def interpolate_all_nodes(builder:Builder)->Builder:

interpolationNodes = [ builder.nodes[i[0]][i[1]] for i in itertools.product(closest_x_up+closest_x_down, closest_y_up+closest_y_down) ]
interpolationNodes = [nn for nn in interpolationNodes if nn is not False]
interpolationNodes = [nn for nn in interpolationNodes if nn._full_simulation]
node = interpolateNode(interpolationNodes)
node.X, node.Y = builder.grid.location_grid[ni,nj,:]
builder.nodes[ni][nj] = node
Expand Down
15 changes: 8 additions & 7 deletions warmth/postprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@
class Results:
"""Simulation results
"""
def __init__(self,depth:np.ndarray, temperature:np.ndarray,sediments_ids:np.ndarray,sediment_input:pd.DataFrame,k_crust:float,k_lith:float,k_asth:float):
def __init__(self, depth:np.ndarray, temperature:np.ndarray,
sediments_ids:np.ndarray,sediment_input:pd.DataFrame,k_crust:float,k_lith:float,k_asth:float):
self._depth=depth
self._temperature=temperature
self._sediments_ids=sediments_ids
Expand All @@ -22,7 +23,7 @@ class resultValues(TypedDict):
depth: np.ndarray[np.float64]
layerId: np.ndarray[np.int32]
values:np.ndarray[np.float64]

@property
def ages(self)->np.ndarray[np.int32]:
"""Array of all simulated ages
Expand All @@ -32,7 +33,7 @@ def ages(self)->np.ndarray[np.int32]:
np.ndarray
Array of ages
"""
return np.arange(self._depth.shape[1],dtype=np.int32)
return np.arange(self._max_time,dtype=np.int32)

def top_crust(self,age:int)->float:
"""Depth of crust
Expand Down Expand Up @@ -62,7 +63,7 @@ def top_lithosphere(self,age:int)->float:
-------
float
Depth of lithospheric mantle / Moho from sea level (m)
"""
"""
depth_idx= np.where(self.sediment_ids(age) == -2)[0][0]
return self._depth[depth_idx,age]

Expand All @@ -78,7 +79,7 @@ def top_asthenosphere(self,age:int)->float:
-------
float
Depth of Asthenosphere from sea level (m)
"""
"""
depth_idx= np.where(self.sediment_ids(age) == -3)[0][0]
return self._depth[depth_idx,age]

Expand Down Expand Up @@ -572,7 +573,7 @@ def interp_value(self):
for n in self._builder.iter_node():
if n._full_simulation is False:
idx = n.indexer
val =interped[idx[0],idx[1]]
val =interped[idx[1],idx[0]]
setattr(n,prop,val)
return

Expand All @@ -594,7 +595,7 @@ def interp_arr(self):
if isinstance(getattr(node,prop),type(None)):
setattr(node,prop,np.zeros(self.n_age))
idx = node.indexer
interpolated_val =interp_all_this_age[idx[0],idx[1]]
interpolated_val =interp_all_this_age[idx[1],idx[0]]
arr = getattr(node,prop)
arr[age] =interpolated_val
setattr(node,prop,arr)
Expand Down
12 changes: 8 additions & 4 deletions warmth/simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,9 @@ def run(self) -> Path:
self.node = fw.current_node
self._pad_sediments()
self.node.simulated_at = time.time()
self.node.max_time = self.node._depth_out.shape[1]
self.node.compute_derived_arrays()
self.node.node_path = self.node_path
filepath = self._save_results()
# Delete input node
self.node_path.unlink(missing_ok=True)
Expand Down Expand Up @@ -244,12 +247,13 @@ def _parellel_run(self, save, purge, use_mpi=False):
if save==False:
from shutil import rmtree
rmtree(self._builder.parameters.output_path)
if filtered >0:
logger.info(f"Interpolating results back to {filtered} partial simulated nodes")
interp_res= Results_interpolator(self._builder)
interp_res.run()
# if filtered >0:
# logger.info(f"Interpolating results back to {filtered} partial simulated nodes")
# interp_res= Results_interpolator(self._builder)
# interp_res.run()
return
def put_node_to_grid(self,node:single_node):
node.clear_unused_data()
self._builder.nodes[node.indexer[0]][node.indexer[1]]=node
return

Expand Down

0 comments on commit 6a464d7

Please sign in to comment.