Skip to content

Commit

Permalink
refactor, cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
chanshing committed Jun 6, 2023
1 parent 0127824 commit 6b77ae6
Showing 1 changed file with 12 additions and 13 deletions.
25 changes: 12 additions & 13 deletions src/stepcount/sslmodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from tqdm import tqdm
from torchvision import transforms
from torch.utils.data.dataset import Dataset
from torch.utils.data import DataLoader

verbose = False
torch_cache_path = Path(__file__).parent / 'torch_hub_cache'
Expand Down Expand Up @@ -122,12 +121,12 @@ class EarlyStopping:
doesn't improve after a given patience."""

def __init__(
self,
patience=5,
verbose=False,
delta=0,
path="checkpoint.pt",
trace_func=print,
self,
patience=5,
verbose=False,
delta=0,
path="checkpoint.pt",
trace_func=print,
):
"""
Args:
Expand Down Expand Up @@ -313,7 +312,7 @@ def train(model, train_loader, val_loader, device, class_weights=None, weights_p
model.train()
train_losses = []
train_acces = []
for i, (x, y, _) in enumerate(tqdm(train_loader, disable=not verbose)):
for x, y, _ in tqdm(train_loader, disable=not verbose):
x.requires_grad_(True)
x = x.to(device, dtype=torch.float)
true_y = y.to(device, dtype=torch.long)
Expand All @@ -336,11 +335,11 @@ def train(model, train_loader, val_loader, device, class_weights=None, weights_p

epoch_len = len(str(num_epoch))
print_msg = (
f"[{epoch:>{epoch_len}}/{num_epoch:>{epoch_len}}] | "
+ f"train_loss: {np.mean(train_losses):.3f} | "
+ f"train_acc: {np.mean(train_acces):.3f} | "
+ f"val_loss: {val_loss:.3f} | "
+ f"val_acc: {val_acc:.2f}"
f"[{epoch:>{epoch_len}}/{num_epoch:>{epoch_len}}] | "
+ f"train_loss: {np.mean(train_losses):.3f} | "
+ f"train_acc: {np.mean(train_acces):.3f} | "
+ f"val_loss: {val_loss:.3f} | "
+ f"val_acc: {val_acc:.2f}"
)

early_stopping(val_loss, model)
Expand Down

0 comments on commit 6b77ae6

Please sign in to comment.