diff --git a/descreen/training/proc.py b/descreen/training/proc.py
index 7de7e16..de0956e 100644
--- a/descreen/training/proc.py
+++ b/descreen/training/proc.py
@@ -32,7 +32,6 @@ def train[
 
     optimizer = torch.optim.RAdam(model.parameters(), lr=0.001)
     # optimizer = torch.optim.SGD(model.parameters(), lr=0.002)
-    # optimizer = torch.optim.LBFGS(model.parameters(), lr=0.1)
 
     input_size = model.input_size(patch_size)
     padding = model.reduced_padding(input_size)
@@ -81,28 +80,13 @@ def interrupt(signum, frame):
         test_step()
 
     def train_step(x, y):
-        if False:
-            loss = None
-
-            def clos():
-                global loss
-                pred = model(x)
-                loss = loss_fn(pred, y)
-                print(f"loss: {loss}")
-                # Backpropagation
-                optimizer.zero_grad()
-                loss.backward()
-                return loss
-
-            optimizer.step(clos)
-        else:
-            pred = model(x)
-            loss = descreen_loss(pred, y, tv=0.01)
-            optimizer.zero_grad()
-            loss.backward()
-            print(f"loss: {loss}")
-            optimizer.step()
-            ema_model.update_parameters(model)
+        pred = model(x)
+        loss = descreen_loss(pred, y, tv=0.01)
+        optimizer.zero_grad()
+        loss.backward()
+        print(f"loss: {loss}")
+        optimizer.step()
+        ema_model.update_parameters(model)
 
     def valid_step(dataloader):
         # Set the model to evaluation mode - important for batch normalization and dropout layers
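
Note: `descreen_loss` and `ema_model` are defined elsewhere in the training code and are not visible in this diff. As a rough sketch only, assuming the loss combines a pixel-wise L1 reconstruction term with an anisotropic total-variation penalty weighted by `tv`, it might look like the following (the actual implementation in descreen may differ):

    import torch
    import torch.nn.functional as F

    def descreen_loss(pred: torch.Tensor, target: torch.Tensor, tv: float = 0.0) -> torch.Tensor:
        # Hypothetical reconstruction term: mean absolute error against the clean target
        recon = F.l1_loss(pred, target)
        # Anisotropic total-variation penalty on the prediction, scaled by the `tv` weight
        tv_h = (pred[..., 1:, :] - pred[..., :-1, :]).abs().mean()
        tv_w = (pred[..., :, 1:] - pred[..., :, :-1]).abs().mean()
        return recon + tv * (tv_h + tv_w)

The `ema_model.update_parameters(model)` call in the new `train_step` matches the `torch.optim.swa_utils.AveragedModel` API, which maintains an exponential moving average of the weights when constructed with `get_ema_multi_avg_fn`; whether this project uses that helper or a custom EMA wrapper is not shown here.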