Skip to content

Commit

Permalink
don't require chips for every class in ClassificationImageDataset
Browse files Browse the repository at this point in the history
  • Loading branch information
AdeelH committed Apr 3, 2024
1 parent 03dc037 commit 79d6292
Showing 1 changed file with 12 additions and 2 deletions.
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Dict, Iterable, List, Optional, Tuple
from os import PathLike
from os.path import splitext
from os.path import join, splitext
from pathlib import Path
from itertools import chain

Expand Down Expand Up @@ -64,10 +64,20 @@ def make_image_folder_dataset(data_dir: str,
return DatasetFolder(
data_dir, loader=load_image, extensions=IMG_EXTENSIONS)

from rastervision.pipeline.file_system.utils import (file_exists,
list_paths)

class_dirs = [join(data_dir, c) for c in classes]
classes_present = [
c for c, dir in zip(classes, class_dirs)
if file_exists(dir, include_dir=True) and len(list_paths(dir)) > 0
]
class_to_id = {c: classes.index(c) for c in classes_present}

class ImageFolder(DatasetFolder):
def find_classes(self,
directory: str) -> Tuple[List[str], Dict[str, int]]:
"""Override to force mapping from class name to class index."""
return classes, {c: i for (i, c) in enumerate(classes)}
return classes_present, class_to_id

return ImageFolder(data_dir, loader=load_image, extensions=IMG_EXTENSIONS)

0 comments on commit 79d6292

Please sign in to comment.