From 409e62b6e8ea8208bf6724bef05e9bddaf189e96 Mon Sep 17 00:00:00 2001 From: Jonathunky <12529409+Jonathunky@users.noreply.github.com> Date: Wed, 1 Nov 2023 17:08:08 +0100 Subject: [PATCH] Let's work with higher-res images then --- u2net_train.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/u2net_train.py b/u2net_train.py index 8c5bf76b..ab37ff65 100644 --- a/u2net_train.py +++ b/u2net_train.py @@ -39,8 +39,8 @@ "plain_resized": { "name": "Plain Images", "message": "Learning the dataset itself...\n", - "transform": [Resize(512), ToTensorLab()], - "batch_factor": 1, + "transform": [Resize(1024), ToTensorLab()], + "batch_factor": 0.3, }, "flipped_v": { "name": "Vertical Flips", @@ -367,13 +367,11 @@ def get_dataloader(tra_img_name_list, tra_lbl_name_list, transform, batch_size): transform=transform, ) - cores = 8 - if batch_size == 10: - cores = 2 # freeing up memory a bit + cores = 2 # freeing up memory a bit # DataLoader for the dataset dataloader = DataLoader( - dataset, batch_size=batch_size, shuffle=True, num_workers=cores + dataset, batch_size=int(batch_size), shuffle=True, num_workers=cores ) return dataloader