pytorch_script.py (forked from shervinea/pytorch-data-generator)
import torch
from torch.utils import data
from my_classes import Dataset

# CUDA for PyTorch
use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0" if use_cuda else "cpu")
torch.backends.cudnn.benchmark = True  # let cuDNN pick the fastest algorithms for fixed-size inputs

# Parameters
params = {'batch_size': 1,
          'shuffle': True,
          'num_workers': 6}  # number of subprocesses used for data loading
max_epochs = 100

# Datasets
partition = {'train': ['id-1', 'id-2', 'id-3'], 'validation': ['id-4']}  # IDs
labels = {'id-1': 0, 'id-2': 1, 'id-3': 2, 'id-4': 1}  # Labels

# Generators
training_set = Dataset(partition['train'], labels)
training_generator = data.DataLoader(training_set, **params)

validation_set = Dataset(partition['validation'], labels)
validation_generator = data.DataLoader(validation_set, **params)

# Loop over epochs
for epoch in range(max_epochs):
    # Training
    for local_batch, local_labels in training_generator:
        # Transfer to GPU
        local_batch, local_labels = local_batch.to(device), local_labels.to(device)

        # Model computations
        [...]

    # Validation
    with torch.set_grad_enabled(False):
        for local_batch, local_labels in validation_generator:
            # Transfer to GPU
            local_batch, local_labels = local_batch.to(device), local_labels.to(device)

            # Model computations
            [...]
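
# ------------------------------------------------------------------
# Sketch of the `my_classes.Dataset` this script imports (assumption).
# The real class is not shown in this file; below is a minimal,
# hypothetical map-style torch.utils.data.Dataset keyed by sample IDs,
# where each sample is assumed to be stored on disk as 'data/<ID>.pt'.
# Adapt the loading logic to however your data is actually stored.
# ------------------------------------------------------------------
import torch                  # repeated here so the sketch stands alone
from torch.utils import data


class Dataset(data.Dataset):
    """Characterizes a dataset for PyTorch, indexed by a list of sample IDs."""

    def __init__(self, list_IDs, labels):
        self.labels = labels      # dict mapping ID -> label
        self.list_IDs = list_IDs  # list of sample IDs

    def __len__(self):
        # Total number of samples
        return len(self.list_IDs)

    def __getitem__(self, index):
        # Select one sample ID, then load its tensor and label
        ID = self.list_IDs[index]
        X = torch.load('data/' + ID + '.pt')  # assumed on-disk layout
        y = self.labels[ID]
        return X, y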