Skip to content

Commit

Permalink
Refactor code for better modularity and add new features
Browse files Browse the repository at this point in the history
Deleted the primitives.py and postprocessing.text.py files along with their corresponding tests as their functionalities have been refactored. parse_mark_id method from the deleted files have been moved to sam_segmentation.py which is a simplified version of the extract_marks_in_brackets method and is more suitable for the task at hand. extract_relevant_masks has been replaced by a process method in SamResponseProcessor which extracts the relevant mark ids from the text and uses them to index the detections object.

The mark generating classes have been moved to the wrappers directory as they are effectively serving as wrappers around existing models. Made these changes to make the classes more modular, easy to understand and use. They can be extended in the future to include more models other than SAM.

Also, created two new abstract base classes BasePromptCreator and BaseResponseProcessor in base.py file under pipelines directory. These two classes would be used to ensure a consistent structure across different pipelines.

In maestro/__init__.py, SegmentAnythingMarkGenerator was renamed to SegmentAnything to adhere to the new structure. Bumped the version of Maestro to 0.2.0rc1 from 0.1.1rc1 reflecting the amount of changes made. Finally, updated README.md to reflect the changes made to the API but the new API usage example still needs to be added.
  • Loading branch information
SkalskiP committed Dec 4, 2023
1 parent 71943ee commit 1157fc5
Show file tree
Hide file tree
Showing 12 changed files with 81 additions and 183 deletions.
59 changes: 1 addition & 58 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,6 @@ pip install maestro

## 🔌 API

🚧 The project is still under construction. The redesigned API is coming soon.

![maestro-docs-Snap](https://github.com/roboflow/multimodal-maestro/assets/26109316/a787b7c0-527e-465a-9ca9-d46f4d63ea53)

## 🚀 examples
Expand All @@ -50,69 +48,14 @@ Find dog.

<br>

- **load image**

```python
import cv2

image = cv2.imread("...")
```

- **create and refine marks**

```python
import maestro

generator = maestro.SegmentAnythingMarkGenerator(device='cuda')
marks = generator.generate(image=image)
marks = maestro.refine_marks(marks=marks)
```

- **visualize marks**

```python
mark_visualizer = maestro.MarkVisualizer()
marked_image = mark_visualizer.visualize(image=image, marks=marks)
```
![image-vs-marked-image](https://github.com/roboflow/multimodal-maestro/assets/26109316/92951ed2-65c0-475a-9279-6fd344757092)

- **prompt**

```python
prompt = "Find dog."

response = maestro.prompt_image(api_key=api_key, image=marked_image, prompt=prompt)
```

```
>>> "The dog is prominently featured in the center of the image with the label [9]."
```

- **extract related marks**

```python
masks = maestro.extract_relevant_masks(text=response, detections=refined_marks)
```

```
>>> {'6': array([
... [False, False, False, ..., False, False, False],
... [False, False, False, ..., False, False, False],
... [False, False, False, ..., False, False, False],
... ...,
... [ True, True, True, ..., False, False, False],
... [ True, True, True, ..., False, False, False],
... [ True, True, True, ..., False, False, False]])
... }
```
TODO: Add new API example.

</details>

![multimodal-maestro](https://github.com/roboflow/multimodal-maestro/assets/26109316/c04f2b18-2a1d-4535-9582-e5d3ec0a926e)

## 🚧 roadmap

- [ ] Rewriting the `maestro` API.
- [ ] Update [HF space](https://huggingface.co/spaces/Roboflow/SoM).
- [ ] Documentation page.
- [ ] Add GroundingDINO prompting strategy.
Expand Down
7 changes: 1 addition & 6 deletions maestro/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
__version__ = "development"

from maestro.lmms.gpt4 import prompt_image
from maestro.markers.sam import SegmentAnythingMarkGenerator
from maestro.wrappers.sam import SegmentAnything
from maestro.postprocessing.mask import (
compute_mask_iou_vectorized,
mask_non_max_suppression,
Expand All @@ -17,9 +17,4 @@
masks_to_marks,
refine_marks
)
from maestro.postprocessing.text import (
extract_marks_in_brackets,
extract_relevant_masks
)
from maestro.visualizers import MarkVisualizer
from maestro.primitives import MarkMode
12 changes: 9 additions & 3 deletions maestro/pipelines/base.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,15 @@
from abc import ABC
from abc import ABC, abstractmethod


class BasePromptCreator(ABC):
pass

@abstractmethod
def create(self, *args, **kwargs):
pass


class BaseResponseProcessor(ABC):
pass

@abstractmethod
def process(self, *args, **kwargs):
pass
52 changes: 46 additions & 6 deletions maestro/pipelines/sam_segmentation.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,45 @@
from typing import Optional, Tuple, List
from typing import Optional, Tuple, List, Callable

import re
import numpy as np
import supervision as sv

from maestro.pipelines.base import BasePromptCreator, BaseResponseProcessor


def extract_mark_ids(text: str) -> List[str]:
"""
Extracts all unique marks enclosed in square brackets from a given string.
Duplicates are removed and the results are sorted in descending order.
Args:
text (str): The string to be searched.
Returns:
List[str]: A list of unique marks found within square brackets, sorted in
descending order.
"""
pattern = r'\[(\d+)\]'
found_marks = re.findall(pattern, text)
unique_marks = set(found_marks)
return sorted(unique_marks, key=int, reverse=False)


def default_annotate(image: np.ndarray, marks: sv.Detections) -> np.ndarray:
h, w, _ = image.shape
line_thickness = sv.calculate_dynamic_line_thickness(resolution_wh=(w, h))
mask_annotator = sv.MaskAnnotator(
color_lookup=sv.ColorLookup.INDEX, opacity=0.4)
polygon_annotator = sv.PolygonAnnotator(
color_lookup=sv.ColorLookup.INDEX, thickness=line_thickness)

annotated_image = image.copy()
annotated_image = mask_annotator.annotate(
scene=annotated_image, detections=marks)
return polygon_annotator.annotate(
scene=annotated_image, detections=marks)


class SamPromptCreator(BasePromptCreator):
def __init__(self, device: str):
self.device = device
Expand All @@ -20,16 +54,22 @@ def create(

class SamResponseProcessor(BaseResponseProcessor):

def process(self, text: str) -> List[str]:
pass
def __init__(
self,
annotate: Callable[[np.ndarray, sv.Detections], np.ndarray] = default_annotate,
) -> None:
self.annotate = annotate

def extract(self, text: str, marks: sv.Detections) -> sv.Detections:
pass
def process(self, text: str, marks: sv.Detections) -> sv.Detections:
mark_ids = extract_mark_ids(text=text)
mark_ids = np.array(mark_ids, dtype=int)
return marks[mark_ids]

def visualize(
self,
text: str,
image: np.ndarray,
marks: sv.Detections
) -> np.ndarray:
pass
marks = self.process(text=text, marks=marks)
return self.annotate(image, marks)
62 changes: 0 additions & 62 deletions maestro/postprocessing/text.py

This file was deleted.

13 changes: 0 additions & 13 deletions maestro/primitives.py

This file was deleted.

File renamed without changes.
4 changes: 2 additions & 2 deletions maestro/markers/sam.py → maestro/wrappers/sam.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from maestro.postprocessing.mask import masks_to_marks


class SegmentAnythingMarkGenerator:
class SegmentAnything:
"""
A class for performing image segmentation using a specified model.
Expand All @@ -28,7 +28,7 @@ def __init__(self, device: str = 'cpu', model_name: str = "facebook/sam-vit-huge
image_processor=self.image_processor,
device=self.device)

def generate(
def predict(
self,
image: np.ndarray,
mask: Optional[np.ndarray] = None
Expand Down
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "maestro"
version = "0.1.1rc1"
version = "0.2.0rc1"
description = "Visual Prompting for Large Multimodal Models (LMMs)"
authors = ["Piotr Skalski <[email protected]>"]
maintainers = ["Piotr Skalski <[email protected]>"]
Expand Down Expand Up @@ -34,7 +34,7 @@ classifiers=[

[tool.poetry.dependencies]
python = ">=3.8,<3.12.0"
supervision = "^0.17.0rc4"
supervision = "^0.17.0rc6"
requests = "^2.31.0"
transformers = "^4.35.2"

Expand Down
Empty file added test/pipelines/__init__.py
Empty file.
20 changes: 20 additions & 0 deletions test/pipelines/test_sam_segmentation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from typing import List

import pytest

from maestro.pipelines.sam_segmentation import extract_mark_ids


@pytest.mark.parametrize(
"text, expected_result",
[
("[1]", ["1"]),
("lorem ipsum [1] dolor sit amet", ["1"]),
("[1] lorem ipsum [2] dolor sit amet", ["1", "2"]),
("[1] lorem ipsum [1] dolor sit amet", ["1"]),
("[2] lorem ipsum [1] dolor sit amet", ["1", "2"])
]
)
def test_extract_marks_in_brackets(text: str, expected_result: List[str]) -> None:
result = extract_mark_ids(text=text)
assert result == expected_result
31 changes: 0 additions & 31 deletions test/test_postprocess.py

This file was deleted.

0 comments on commit 1157fc5

Please sign in to comment.