diff --git a/datumaro/plugins/lfw_format.py b/datumaro/plugins/lfw_format.py index 6bf3d4a9b6..e78ee15736 100644 --- a/datumaro/plugins/lfw_format.py +++ b/datumaro/plugins/lfw_format.py @@ -65,13 +65,12 @@ def _load_items(self, path): items = {} label_categories = self._categories.get(AnnotationType.label) + images = {} if osp.isdir(self._images_dir): images = { - osp.splitext(osp.relpath(p, self._images_dir))[0].replace("\\", "/"): p + osp.splitext(osp.relpath(p, self._images_dir))[0].replace("\\", "/"): Image(path=p) for p in find_images(self._images_dir, recursive=True) } - else: - images = {} with open(path, encoding="utf-8") as f: @@ -85,51 +84,47 @@ def get_label_id(label_name): for line in f: pair = line.strip().split("\t") + if len(pair) == 1 and pair[0] != "": annotations = [] image = pair[0] item_id = pair[0] objects = item_id.split("/") + if 1 < len(objects): label_name = objects[0] label = get_label_id(label_name) if label is not None: annotations.append(Label(label)) item_id = item_id[len(label_name) + 1 :] - if item_id not in items: - image = images.get(item_id) - if image: - image = Image(path=image) + if item_id not in items: items[item_id] = DatasetItem( - id=item_id, subset=self._subset, media=image, annotations=annotations + id=item_id, + subset=self._subset, + media=images.get(image), + annotations=annotations, ) + elif len(pair) == 3: image1, id1 = self.get_image_name(pair[0], pair[1]) image2, id2 = self.get_image_name(pair[0], pair[2]) label = get_label_id(pair[0]) if id1 not in items: - annotations = [] - annotations.append(Label(label)) - - image = images.get(image1) - if image: - image = Image(path=image) - items[id1] = DatasetItem( - id=id1, subset=self._subset, media=image, annotations=annotations + id=id1, + subset=self._subset, + media=images.get(image1), + annotations=[Label(label)], ) - if id2 not in items: - annotations = [] - annotations.append(Label(label)) - - image = images.get(image2) - if image: - image = Image(path=image) + if id2 not in items: items[id2] = DatasetItem( - id=id2, subset=self._subset, media=image, annotations=annotations + id=id2, + subset=self._subset, + media=images.get(image2), + annotations=[Label(label)], ) # pairs form a directed graph @@ -139,35 +134,32 @@ def get_label_id(label_name): elif len(pair) == 4: image1, id1 = self.get_image_name(pair[0], pair[1]) + if pair[2] == "-": - image2 = pair[3] - id2 = pair[3] + image2, id2 = pair[3], pair[3] else: image2, id2 = self.get_image_name(pair[2], pair[3]) + if id1 not in items: - annotations = [] label = get_label_id(pair[0]) - annotations.append(Label(label)) - - image = images.get(image1) - if image: - image = Image(path=image) - items[id1] = DatasetItem( - id=id1, subset=self._subset, media=image, annotations=annotations + id=id1, + subset=self._subset, + media=images.get(image1), + annotations=[Label(label)], ) + if id2 not in items: annotations = [] if pair[2] != "-": label = get_label_id(pair[2]) annotations.append(Label(label)) - image = images.get(image2) - if image: - image = Image(path=image) - items[id2] = DatasetItem( - id=id2, subset=self._subset, media=image, annotations=annotations + id=id2, + subset=self._subset, + media=images.get(image2), + annotations=annotations, ) # pairs form a directed graph @@ -179,24 +171,35 @@ def get_label_id(label_name): if osp.isfile(landmarks_file): with open(landmarks_file, encoding="utf-8") as f: for line in f: - line = line.split("\t") + line_parts = line.split("\t") + item_id = osp.splitext(line_parts[0])[0] - item_id = osp.splitext(line[0])[0] objects = item_id.split("/") - if 1 < len(objects): - label_name = objects[0] - label = get_label_id(label_name) - if label is not None: - item_id = item_id[len(label_name) + 1 :] + label_name = objects[0] if 1 < len(objects) else "" + + label = get_label_id(label_name) + if label is not None: + item_id = item_id[len(label_name) + 1 :] + if item_id not in items: items[item_id] = DatasetItem( id=item_id, subset=self._subset, - image=osp.join(self._images_dir, line[0]), + image=Image(path=osp.join(self._images_dir, line_parts[0])), ) annotations = items[item_id].annotations - annotations.append(Points([float(p) for p in line[1:]], label=label)) + annotations.append(Points([float(p) for p in line_parts[1:]], label=label)) + + labeled_images = set( + item.media.path + for item in items.values() + if getattr(item.media, "path", None) is not None + ) + for image_name, image in images.items(): + if image.path in labeled_images: + continue + items[image_name] = DatasetItem(id=image_name, subset=self._subset, image=image) return items @@ -241,90 +244,94 @@ def apply(self): label_categories = self._extractor.categories()[AnnotationType.label] labels = {label.name: 0 for label in label_categories} + unlabeled_items = [] + neutral_items = [] + included_items = [] positive_pairs = [] negative_pairs = [] - neutral_items = [] landmarks = [] - included_items = [] for item in subset: - anns = [ann for ann in item.annotations if ann.type == AnnotationType.label] - label, label_name = None, None - if anns: - label = anns[0] - label_name = label_categories[anns[0].label].name - labels[label_name] += 1 + label_annotations = [ + ann for ann in item.annotations if ann.type == AnnotationType.label + ] + + if not label_annotations: + unlabeled_items.append(item) + continue + + label_obj = label_annotations[0] + label_name = label_categories[label_obj.label].name + labels[label_name] += 1 if self._save_media and item.media: - subdir = osp.join(subset_name, LfwPath.IMAGES_DIR) - if label_name: - subdir = osp.join(subdir, label_name) + subdir = osp.join(subset_name, LfwPath.IMAGES_DIR, label_name) self._save_image(item, subdir=subdir) - if label is not None: - person1 = label_name - num1 = item.id - if num1.startswith(person1): - num1 = int(num1.replace(person1, "")[1:]) - curr_item = person1 + "/" + str(num1) + person1 = label_name + num1 = item.id + if num1.startswith(person1): + num1 = int(num1.replace(person1, "")[1:]) + curr_item = person1 + "/" + str(num1) - if "positive_pairs" in label.attributes: + if "positive_pairs" in label_obj.attributes: + if curr_item not in included_items: + included_items.append(curr_item) + for pair in label_obj.attributes["positive_pairs"]: + search = LfwPath.PATTERN.search(pair) + if search: + num2 = search.groups()[1] + num2 = int(num2) + else: + num2 = pair + if num2.startswith(person1): + num2 = num2.replace(person1, "")[1:] + curr_item = person1 + "/" + str(num2) if curr_item not in included_items: included_items.append(curr_item) - for pair in label.attributes["positive_pairs"]: - search = LfwPath.PATTERN.search(pair) - if search: - num2 = search.groups()[1] - num2 = int(num2) - else: - num2 = pair - if num2.startswith(person1): - num2 = num2.replace(person1, "")[1:] - curr_item = person1 + "/" + str(num2) - if curr_item not in included_items: - included_items.append(curr_item) - positive_pairs.append("%s\t%s\t%s" % (person1, num1, num2)) - - if "negative_pairs" in label.attributes: + positive_pairs.append("%s\t%s\t%s" % (person1, num1, num2)) + + if "negative_pairs" in label_obj.attributes: + if curr_item not in included_items: + included_items.append(curr_item) + for pair in label_obj.attributes["negative_pairs"]: + search = LfwPath.PATTERN.search(pair) + curr_item = "" + if search: + person2, num2 = search.groups() + num2 = int(num2) + curr_item += person2 + "/" + else: + person2 = "-" + num2 = pair + objects = pair.split("/") + if 1 < len(objects) and objects[0] in labels: + person2 = objects[0] + num2 = pair.replace(person2, "")[1:] + curr_item += person2 + "/" + curr_item += str(num2) if curr_item not in included_items: included_items.append(curr_item) - for pair in label.attributes["negative_pairs"]: - search = LfwPath.PATTERN.search(pair) - curr_item = "" - if search: - person2, num2 = search.groups() - num2 = int(num2) - curr_item += person2 + "/" - else: - person2 = "-" - num2 = pair - objects = pair.split("/") - if 1 < len(objects) and objects[0] in labels: - person2 = objects[0] - num2 = pair.replace(person2, "")[1:] - curr_item += person2 + "/" - curr_item += str(num2) - if curr_item not in included_items: - included_items.append(curr_item) - negative_pairs.append("%s\t%s\t%s\t%s" % (person1, num1, person2, num2)) - - if ( - "positive_pairs" not in label.attributes - and "negative_pairs" not in label.attributes - and curr_item not in included_items - ): - neutral_items.append("%s/%s" % (person1, item.id)) - included_items.append(curr_item) + negative_pairs.append("%s\t%s\t%s\t%s" % (person1, num1, person2, num2)) - elif item.id not in included_items: - neutral_items.append(item.id) - included_items.append(item.id) + if ( + "positive_pairs" not in label_obj.attributes + and "negative_pairs" not in label_obj.attributes + and curr_item not in included_items + ): + neutral_items.append("%s/%s" % (person1, item.id)) + included_items.append(curr_item) item_landmarks = [p for p in item.annotations if p.type == AnnotationType.points] for landmark in item_landmarks: + label_name = label_categories[landmark.label].name landmarks.append( - "%s\t%s" - % (item.id + LfwPath.IMAGE_EXT, "\t".join(str(p) for p in landmark.points)) + "%s/%s\t%s" + % ( + label_name, + item.id + LfwPath.IMAGE_EXT, + "\t".join(str(p) for p in landmark.points), + ) ) annotations_dir = osp.join(self._save_dir, subset_name, LfwPath.ANNOTATION_DIR) @@ -344,3 +351,7 @@ def apply(self): people_file = osp.join(annotations_dir, LfwPath.PEOPLE_FILE) with open(people_file, "w", encoding="utf-8") as f: f.writelines(["%s\t%d\n" % (label, labels[label]) for label in labels]) + + if unlabeled_items and self._save_media: + for item in unlabeled_items: + self._save_image(item, subdir=osp.join(item.subset, LfwPath.IMAGES_DIR)) diff --git a/tests/test_lfw_format.py b/tests/test_lfw_format.py index c9f492cc71..da180dad86 100644 --- a/tests/test_lfw_format.py +++ b/tests/test_lfw_format.py @@ -143,6 +143,38 @@ def test_can_save_and_load_with_landmarks(self): compare_datasets(self, source_dataset, parsed_dataset) + @mark_requirement(Requirements.DATUM_GENERAL_REQ) + def test_can_save_and_load_with_only_landmarks(self): + source_dataset = Dataset.from_iterable( + [ + DatasetItem( + id="name0_0001", + subset="test", + media=Image(data=np.ones((2, 5, 3))), + annotations=[ + Points([0, 4, 3, 3, 2, 2, 1, 0, 3, 0], label=0), + ], + ), + ], + categories=["name0"], + ) + + target_dataset = Dataset.from_iterable( + [ + DatasetItem( + id="name0_0001", + subset="test", + media=Image(data=np.ones((2, 5, 3))), + ), + ], + categories=["name0"], + ) + + LfwConverter.convert(source_dataset, "./lfw", save_media=True) + parsed_dataset = Dataset.import_from("./lfw", "lfw") + + compare_datasets(self, parsed_dataset, target_dataset) + @mark_requirement(Requirements.DATUM_GENERAL_REQ) def test_can_save_and_load_with_no_subsets(self): source_dataset = Dataset.from_iterable(