Skip to content

Commit

Permalink
add functions for set train and set eval in original traj pred
Browse files Browse the repository at this point in the history
  • Loading branch information
souryadey committed Feb 4, 2025
1 parent dbeb532 commit 689e426
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 8 deletions.
22 changes: 14 additions & 8 deletions dlkoopman/traj_pred.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,15 @@ def _evolve(self, Y0) -> torch.Tensor:
return Ypred


def _set_train(self):
self.ae.train()
self.Knet.train()

def _set_eval(self):
self.ae.eval()
self.Knet.eval()


def train_net(self,
numepochs=10, batch_size=250, early_stopping=0, early_stopping_metric='total_loss',
lr=1e-3, weight_decay=0., decoder_loss_weight=1e-2,
Expand Down Expand Up @@ -305,8 +314,7 @@ def train_net(self,
self.dh.Xtr = self.dh.Xtr[torch.randperm(self.dh.Xtr.shape[0])]

## Training ##
self.ae.train()
self.Knet.train()
self._set_train()

# Start batches
for batch in range(numbatches):
Expand Down Expand Up @@ -376,8 +384,7 @@ def train_net(self,

## Validation ##
if do_val:
self.ae.eval()
self.Knet.eval()
self._set_eval()

with torch.no_grad():
Yva, Xrva = self.ae(self.dh.Xva) # shapes: Yva = (num_va_trajectories, num_indexes, encoded_size), Xrva = (num_va_trajectories, num_indexes, input_size)
Expand Down Expand Up @@ -434,8 +441,7 @@ def test_net(self):
print("WARNING: You have called 'test_net()', but there is no test data. Please pass a 'DataHandler' object containing 'Xte'.")

else:
self.ae.eval()
self.Knet.eval()
self._set_eval()

with torch.no_grad():
Yte, Xrte = self.ae(self.dh.Xte) # shapes: Yte = (num_te_trajectories, num_indexes, encoded_size), Xrte = (num_te_trajectories, num_indexes, input_size)
Expand Down Expand Up @@ -471,8 +477,8 @@ def predict_new(self, X0) -> torch.Tensor:
if self.cfg.normalize_Xdata:
X0 = utils.scale(X0, scale=self.dh.Xscale)

self.ae.eval()
self.Knet.eval()
self._set_eval()

with torch.no_grad():
Y0 = self.ae.encoder(X0)
Ypred = self._evolve(Y0)
Expand Down
1 change: 1 addition & 0 deletions dlkoopman/traj_pred_control.py
Original file line number Diff line number Diff line change
Expand Up @@ -553,6 +553,7 @@ def predict_new(self, X0, U) -> torch.Tensor:
U = utils.scale(U, scale=self.dh.Uscale)

self._set_eval()

with torch.no_grad():
Y0 = self.data_ae.encoder(X0) # shape = (num_new_trajectories, data_encoded_size)
V = self.control_enc(U) # shape = (num_new_trajectories, num_indexes, control_encoded_size)
Expand Down

0 comments on commit 689e426

Please sign in to comment.