Drop LBFGS
curegit committed Apr 26, 2024
1 parent f82bf99 commit 9f831f1
Showing 1 changed file with 7 additions and 23 deletions.
descreen/training/proc.py (7 additions & 23 deletions)
@@ -32,7 +32,6 @@ def train[
 
     optimizer = torch.optim.RAdam(model.parameters(), lr=0.001)
     # optimizer = torch.optim.SGD(model.parameters(), lr=0.002)
-    # optimizer = torch.optim.LBFGS(model.parameters(), lr=0.1)
 
     input_size = model.input_size(patch_size)
     padding = model.reduced_padding(input_size)
@@ -81,28 +80,13 @@ def interrupt(signum, frame):
        test_step()
 
    def train_step(x, y):
-        if False:
-            loss = None
-
-            def clos():
-                global loss
-                pred = model(x)
-                loss = loss_fn(pred, y)
-                print(f"loss: {loss}")
-                # Backpropagation
-                optimizer.zero_grad()
-                loss.backward()
-                return loss
-
-            optimizer.step(clos)
-        else:
-            pred = model(x)
-            loss = descreen_loss(pred, y, tv=0.01)
-            optimizer.zero_grad()
-            loss.backward()
-            print(f"loss: {loss}")
-            optimizer.step()
-            ema_model.update_parameters(model)
+        pred = model(x)
+        loss = descreen_loss(pred, y, tv=0.01)
+        optimizer.zero_grad()
+        loss.backward()
+        print(f"loss: {loss}")
+        optimizer.step()
+        ema_model.update_parameters(model)
 
    def valid_step(dataloader):
        # Set the model to evaluation mode - important for batch normalization and dropout layers
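
For context on what was dropped: torch.optim.LBFGS cannot use the plain zero_grad/backward/step sequence, because its step() re-evaluates the objective internally and therefore must be passed a closure that recomputes the loss; step(closure) returns that loss. The removed branch also had a scoping bug: "global loss" inside clos binds a module-level name, so the "loss = None" local in train_step would never have been updated; "nonlocal" (or simply using the value returned by step) is the usual fix. Below is a minimal self-contained sketch of the correct closure pattern; the linear model, MSE loss, and random data are stand-ins for the project's model and descreen_loss, not code from this repository.

    # Minimal sketch of the closure-based LBFGS update that this commit removes.
    # The model, loss function, and data are hypothetical stand-ins.
    import torch

    model = torch.nn.Linear(4, 1)        # stand-in for the descreen model
    loss_fn = torch.nn.MSELoss()         # stand-in for descreen_loss
    optimizer = torch.optim.LBFGS(model.parameters(), lr=0.1)

    x = torch.randn(8, 4)
    y = torch.randn(8, 1)

    def closure():
        # LBFGS may call this several times per step, so it must
        # recompute the loss and gradients from scratch each time.
        optimizer.zero_grad()
        pred = model(x)
        loss = loss_fn(pred, y)
        loss.backward()
        return loss

    # step(closure) runs the closure and returns its final loss,
    # so no global/nonlocal bookkeeping is needed.
    loss = optimizer.step(closure)
    print(f"loss: {loss}")

With RAdam, the optimizer that is kept, a single gradient evaluation per batch suffices, which is exactly what the seven added lines do; dropping the dead LBFGS branch removes the closure machinery along with it.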
