Skip to content

Commit

Permalink
perf: better handling of cuda devices by id across both detection and…
Browse files Browse the repository at this point in the history
… clustering commands with --device cuda:0
  • Loading branch information
danellecline committed Jan 14, 2025
1 parent 0c2e48e commit ae8e395
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 6 deletions.
9 changes: 3 additions & 6 deletions sdcat/cluster/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,15 +54,12 @@ def run_cluster_det(det_dir, save_dir, device, use_vits, config_ini, alpha, clus
min_similarity = float(config('cluster', 'min_similarity'))
model = config('cluster', 'model')

if device != 'cpu':
if 'cuda' in device:
num_devices = torch.cuda.device_count()
info(f'{num_devices} cuda devices available')
info(f'Using device {device}')
if 'cuda' in device:
device_num = device.split(':')[-1]
info(f'Setting CUDA_VISIBLE_DEVICES to {device_num}')
torch.cuda.set_device(device)
os.environ['CUDA_VISIBLE_DEVICES'] = device_num
device_ = torch.device(device)
torch.cuda.set_device(device_)

save_dir = Path(save_dir)
save_dir.mkdir(parents=True, exist_ok=True)
Expand Down
6 changes: 6 additions & 0 deletions sdcat/detect/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,12 @@ def run_detect(show: bool, image_dir: str, save_dir: str, model: str, model_type
create_logger_file('detect')

if not skip_sahi:
if 'cuda' in device:
num_devices = torch.cuda.device_count()
info(f'{num_devices} cuda devices available')
device_ = torch.device(device)
torch.cuda.set_device(device_)
device = 'cuda'
detection_model = create_model(model, conf, device, model_type)

if Path(model).is_dir():
Expand Down

0 comments on commit ae8e395

Please sign in to comment.