diff --git a/u2net_train.py b/u2net_train.py index 8c5bf76b..ab37ff65 100644 --- a/u2net_train.py +++ b/u2net_train.py @@ -39,8 +39,8 @@ "plain_resized": { "name": "Plain Images", "message": "Learning the dataset itself...\n", - "transform": [Resize(512), ToTensorLab()], - "batch_factor": 1, + "transform": [Resize(1024), ToTensorLab()], + "batch_factor": 0.3, }, "flipped_v": { "name": "Vertical Flips", @@ -367,13 +367,11 @@ def get_dataloader(tra_img_name_list, tra_lbl_name_list, transform, batch_size): transform=transform, ) - cores = 8 - if batch_size == 10: - cores = 2 # freeing up memory a bit + cores = 2 # freeing up memory a bit # DataLoader for the dataset dataloader = DataLoader( - dataset, batch_size=batch_size, shuffle=True, num_workers=cores + dataset, batch_size=int(batch_size), shuffle=True, num_workers=cores ) return dataloader