Commit

restructured code

JonasFrey96 committed Feb 28, 2025
1 parent b7bb52a commit 41b5b16
Showing 8 changed files with 71 additions and 23 deletions.
11 changes: 7 additions & 4 deletions .gitignore
@@ -1,5 +1,8 @@
-scripts/wandb/*
-scripts/images/*
-plr.egg-info/*
+2025_lecture_examples/scripts/wandb/*
+2025_lecture_examples/scripts/images/*
+2025_lecture_examples/plr.egg-info/*
+2025_lecture_examples/.venv/*
+2024_home_work/.data/*
+2024_home_work/.venv/*
 .venv/*
-2024_home_work/data/MNIST/raw/train-images-idx3-ubyte
+2024_home_work/plr_exercise.egg-info/*
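
Since the ignore rules now point at per-directory paths, it may be worth confirming they match as intended. A minimal sketch using `git check-ignore`, run from the repository root (the file paths below are hypothetical):

import subprocess

# Hypothetical paths to test against the updated .gitignore; `git check-ignore -v`
# prints the matching pattern and its source line, or nothing if not ignored.
paths = [
    "2025_lecture_examples/scripts/wandb/run.log",
    "2024_home_work/.data/MNIST/raw/t10k-images-idx3-ubyte",
]
for p in paths:
    result = subprocess.run(
        ["git", "check-ignore", "-v", p], capture_output=True, text=True
    )
    print(p, "->", result.stdout.strip() or "not ignored")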
83 changes: 64 additions & 19 deletions 2024_home_work/train.py
@@ -66,15 +66,22 @@ def test(model, device, test_loader, epoch):
         for data, target in test_loader:
             data, target = data.to(device), target.to(device)
             output = model(data)
-            test_loss += F.nll_loss(output, target, reduction="sum").item()  # sum up batch loss
-            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
+            test_loss += F.nll_loss(
+                output, target, reduction="sum"
+            ).item()  # sum up batch loss
+            pred = output.argmax(
+                dim=1, keepdim=True
+            )  # get the index of the max log-probability
             correct += pred.eq(target.view_as(pred)).sum().item()
 
     test_loss /= len(test_loader.dataset)
 
     print(
         "\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n".format(
-            test_loss, correct, len(test_loader.dataset), 100.0 * correct / len(test_loader.dataset)
+            test_loss,
+            correct,
+            len(test_loader.dataset),
+            100.0 * correct / len(test_loader.dataset),
         )
     )

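A note on the reformatted loss accumulation above: reduction="sum" adds up per-sample losses, so the later division by len(test_loader.dataset) gives an exact per-sample average even when the final batch is smaller than the rest. A self-contained sketch of the same pattern, on synthetic data rather than the model above:

import torch
import torch.nn.functional as F

# Synthetic log-probabilities and labels, split into uneven "batches" to mimic
# a DataLoader whose last batch is smaller.
log_probs = torch.log_softmax(torch.randn(10, 3), dim=1)
targets = torch.randint(0, 3, (10,))
batches = [(log_probs[:6], targets[:6]), (log_probs[6:], targets[6:])]

test_loss = 0.0
for output, target in batches:
    # Summing per-sample losses keeps the accumulator independent of batch size.
    test_loss += F.nll_loss(output, target, reduction="sum").item()

test_loss /= 10  # divide by the dataset size for the exact per-sample average
print(f"Average loss: {test_loss:.4f}")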
@@ -83,25 +90,65 @@ def main():
     # Training settings
     parser = argparse.ArgumentParser(description="PyTorch MNIST Example")
     parser.add_argument(
-        "--batch-size", type=int, default=64, metavar="N", help="input batch size for training (default: 64)"
+        "--batch-size",
+        type=int,
+        default=64,
+        metavar="N",
+        help="input batch size for training (default: 64)",
     )
     parser.add_argument(
-        "--test-batch-size", type=int, default=1000, metavar="N", help="input batch size for testing (default: 1000)"
+        "--test-batch-size",
+        type=int,
+        default=1000,
+        metavar="N",
+        help="input batch size for testing (default: 1000)",
     )
-    parser.add_argument("--epochs", type=int, default=2, metavar="N", help="number of epochs to train (default: 14)")
-    parser.add_argument("--lr", type=float, default=1.0, metavar="LR", help="learning rate (default: 1.0)")
-    parser.add_argument("--gamma", type=float, default=0.7, metavar="M", help="Learning rate step gamma (default: 0.7)")
-    parser.add_argument("--no-cuda", action="store_true", default=False, help="disables CUDA training")
-    parser.add_argument("--dry-run", action="store_true", default=False, help="quickly check a single pass")
-    parser.add_argument("--seed", type=int, default=1, metavar="S", help="random seed (default: 1)")
+    parser.add_argument(
+        "--epochs",
+        type=int,
+        default=2,
+        metavar="N",
+        help="number of epochs to train (default: 14)",
+    )
+    parser.add_argument(
+        "--lr",
+        type=float,
+        default=1.0,
+        metavar="LR",
+        help="learning rate (default: 1.0)",
+    )
+    parser.add_argument(
+        "--gamma",
+        type=float,
+        default=0.7,
+        metavar="M",
+        help="Learning rate step gamma (default: 0.7)",
+    )
+    parser.add_argument(
+        "--no-cuda", action="store_true", default=False, help="disables CUDA training"
+    )
+    parser.add_argument(
+        "--dry-run",
+        action="store_true",
+        default=False,
+        help="quickly check a single pass",
+    )
+    parser.add_argument(
+        "--seed", type=int, default=1, metavar="S", help="random seed (default: 1)"
+    )
     parser.add_argument(
         "--log-interval",
         type=int,
         default=10,
         metavar="N",
         help="how many batches to wait before logging training status",
     )
-    parser.add_argument("--save-model", action="store_true", default=False, help="For Saving the current Model")
+    parser.add_argument(
+        "--save-model",
+        action="store_true",
+        default=False,
+        help="For Saving the current Model",
+    )
     args = parser.parse_args()
     use_cuda = not args.no_cuda and torch.cuda.is_available()

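The exploded argument lists parse exactly as the single-line versions did. For reference, a standalone sketch mirroring two of the flags above (it does not import train.py):

import argparse

# Two flags mirrored from the parser above; parse_args accepts an explicit
# argv list, which makes the behavior easy to check without a shell.
parser = argparse.ArgumentParser(description="PyTorch MNIST Example")
parser.add_argument("--epochs", type=int, default=2, metavar="N")
parser.add_argument("--no-cuda", action="store_true", default=False)

args = parser.parse_args(["--epochs", "5", "--no-cuda"])
print(args.epochs, args.no_cuda)  # -> 5 True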
@@ -119,9 +166,11 @@ def main():
         train_kwargs.update(cuda_kwargs)
         test_kwargs.update(cuda_kwargs)
 
-    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
-    dataset1 = datasets.MNIST("../data", train=True, download=True, transform=transform)
-    dataset2 = datasets.MNIST("../data", train=False, transform=transform)
+    transform = transforms.Compose(
+        [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
+    )
+    dataset1 = datasets.MNIST(".data", train=True, download=True, transform=transform)
+    dataset2 = datasets.MNIST(".data", train=False, transform=transform)
     train_loader = torch.utils.data.DataLoader(dataset1, **train_kwargs)
     test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs)

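The Normalize constants (0.1307,) and (0.3081,) are the MNIST training set's global pixel mean and standard deviation. If in doubt, they can be recomputed along these lines (this downloads MNIST into .data, matching the path above):

import torch
from torchvision import datasets, transforms

# Load MNIST with ToTensor only, then compute global pixel statistics.
ds = datasets.MNIST(".data", train=True, download=True, transform=transforms.ToTensor())
imgs = torch.stack([img for img, _ in ds])  # shape: (60000, 1, 28, 28)
print(imgs.mean().item(), imgs.std().item())  # approximately 0.1307, 0.3081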
@@ -140,7 +189,3 @@ def main():
 
 if __name__ == "__main__":
     main()
-
-test = 12
-
-test = 13
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
