-
Notifications
You must be signed in to change notification settings - Fork 8
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #12 from schencej/add-image-modality
Add interface to handle image matrices directly
- Loading branch information
Showing
7 changed files
with
122 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
68 changes: 68 additions & 0 deletions
68
smqtk_detection/impls/detect_image_objects/random_detector.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
from random import randrange | ||
import numpy as np | ||
from typing import Iterable, Tuple, Dict, Hashable, Sequence | ||
|
||
from smqtk_detection.interfaces.detect_image_objects import DetectImageObjects | ||
from smqtk_detection.utils.bbox import AxisAlignedBoundingBox | ||
|
||
|
||
class RandomDetector(DetectImageObjects): | ||
""" | ||
Example implementation of the `DetectImageObjects` interface. An instance | ||
of this class acts as a functor to generate paired bounding boxes and | ||
classification maps for objects detected in a set of images. | ||
""" | ||
|
||
def detect_objects( | ||
self, | ||
img_iter: Iterable[np.ndarray] | ||
) -> Iterable[Iterable[Tuple[AxisAlignedBoundingBox, Dict[Hashable, float]]]]: | ||
""" | ||
Return random set of detections for each image in the input set. | ||
""" | ||
|
||
dets = [] | ||
|
||
for img in img_iter: | ||
img_h = img.shape[0] | ||
img_w = img.shape[1] | ||
|
||
num_dets = randrange(10) | ||
|
||
dets.append( | ||
[( | ||
self._gen_random_bbox(img_w, img_h), | ||
self._gen_random_class_map([0, 1, 2]) | ||
) for _ in range(num_dets)] | ||
) | ||
|
||
return dets | ||
|
||
def _gen_random_bbox(self, img_w: int, img_h: int) -> AxisAlignedBoundingBox: | ||
""" | ||
Creates `AxisAlignedBoundingBox` object with random vertices within | ||
passed image size. | ||
""" | ||
|
||
min_vertex = [randrange(int(img_w/2)), randrange(int(img_h/2))] | ||
max_vertex = [randrange(int(img_w/2), img_w), randrange(int(img_h/2), img_h)] | ||
|
||
return AxisAlignedBoundingBox(min_vertex, max_vertex) | ||
|
||
def _gen_random_class_map(self, classes: Sequence) -> Dict[Hashable, float]: | ||
""" | ||
Creates dictionary of random classification scores for the list of | ||
input classes. | ||
""" | ||
|
||
scores = np.random.rand(len(classes)) | ||
scores = scores / scores.sum() | ||
|
||
d = {} | ||
for i, c in enumerate(classes): | ||
d[c] = scores[i] | ||
|
||
return d | ||
|
||
def get_config(self) -> dict: | ||
return {} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
import abc | ||
from typing import Iterable, Hashable, Dict, Tuple | ||
import numpy as np | ||
|
||
from smqtk_core import Configurable, Pluggable | ||
from smqtk_detection.utils.bbox import AxisAlignedBoundingBox | ||
|
||
|
||
class DetectImageObjects (Configurable, Pluggable): | ||
""" | ||
Algorithm that generates object bounding boxes and classification maps for | ||
a set of input image matricies as ``numpy.ndarray`` type arrays. | ||
""" | ||
|
||
@abc.abstractmethod | ||
def detect_objects( | ||
self, | ||
img_iter: Iterable[np.ndarray] | ||
) -> Iterable[Iterable[Tuple[AxisAlignedBoundingBox, Dict[Hashable, float]]]]: | ||
""" | ||
Generate paired bounding boxes and classification maps for detected | ||
objects in the given set of images. | ||
:param img_iter: Iterable of input images as numpy arrays. | ||
:return: Iterable of sets of paired bounding boxes and classification | ||
maps. Each set is the collection of detections for the | ||
corresponding input image. | ||
""" | ||
|
||
def __call__( | ||
self, | ||
img_iter: Iterable[np.ndarray] | ||
) -> Iterable[Iterable[Tuple[AxisAlignedBoundingBox, Dict[Hashable, float]]]]: | ||
""" | ||
Calls `detect_objects() with the given iterable set of images.` | ||
""" | ||
|
||
return self.detect_objects(img_iter) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters