Skip to content

Commit

Permalink
use load image, upadte readme eaxmple
Browse files Browse the repository at this point in the history
  • Loading branch information
capjamesg committed Dec 6, 2023
1 parent 6f19cbd commit b0a7659
Show file tree
Hide file tree
Showing 6 changed files with 114 additions and 19 deletions.
32 changes: 32 additions & 0 deletions .github/workflows/publish.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
name: Publish WorkFlow

on:
release:
types: [created]

jobs:
build:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: [3.8]
steps:
- name: 🛎️ Checkout
uses: actions/checkout@v3
with:
ref: ${{ github.head_ref }}
- name: 🐍 Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}
- name: 🦾 Install dependencies
run: |
python -m pip install --upgrade pip
pip install ".[dev]"
- name: 🚀 Publish to PyPi
env:
PYPI_USERNAME: ${{ secrets.PYPI_USERNAME }}
PYPI_PASSWORD: ${{ secrets.PYPI_PASSWORD }}
PYPI_TEST_PASSWORD: ${{ secrets.PYPI_TEST_PASSWORD }}
run: |
make publish -e PYPI_USERNAME=$PYPI_USERNAME -e PYPI_PASSWORD=$PYPI_PASSWORD -e PYPI_TEST_PASSWORD=$PYPI_TEST_PASSWORD
30 changes: 30 additions & 0 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
name: Test WorkFlow

on:
pull_request:
branches: [main]

jobs:
build:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: [3.7, 3.8, 3.9]
steps:
- name: 🛎️ Checkout
uses: actions/checkout@v3
with:
ref: ${{ github.head_ref }}
- name: 🐍 Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}
- name: 🦾 Install dependencies
run: |
python -m pip install --upgrade pip
pip install ".[dev]"
- name: 🧹 Lint with flake8
run: |
make check_code_quality
- name: 🧪 Test
run: "python -m pytest ./test"
16 changes: 16 additions & 0 deletions .github/workflows/welcome.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
on:
issues:
types: [opened]
pull_request_target:
types: [opened]

jobs:
build:
name: 👋 Welcome
runs-on: ubuntu-latest
steps:
- uses: actions/[email protected]
with:
repo-token: ${{ secrets.GITHUB_TOKEN }}
issue-message: "Hello there, thank you for opening an Issue ! 🙏🏻 The team was notified and they will get back to you soon."
pr-message: "Hello there, thank you for opening an PR ! 🙏🏻 The team was notified and they will get back to you soon."
26 changes: 21 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,16 +33,32 @@ You can find a full list of `detection` Target Models on [the main autodistill r

```python
from autodistill_grounded_sam import GroundedSAM
from autodistill_yolov8 import YOLOv8
from autodistill.detection import CaptionOntology
from autodistill.utils import plot
import cv2


# define an ontology to map class names to our GroundingDINO prompt
# define an ontology to map class names to our GroundedSAM prompt
# the ontology dictionary has the format {caption: class}
# where caption is the prompt sent to the base model, and class is the label that will
# be saved for that caption in the generated annotations
# then, load the model
base_model = GroundedSAM(ontology=CaptionOntology({"shipping container": "container"}))

base_model = GroundedSAM(
ontology=CaptionOntology(
{
"person": "person",
"shipping container": "shipping container",
}
)
)

# run inference on a single image
results = base_model.predict("logistics.jpeg")

plot(
image=cv2.imread("logistics.jpeg"),
classes=base_model.ontology.classes(),
detections=results
)
# label all images in a folder called `context_images`
base_model.label("./context_images", extension=".jpeg")
```
Expand Down
26 changes: 14 additions & 12 deletions autodistill_grounded_sam/grounded_sam.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,26 +5,26 @@
os.environ["TOKENIZERS_PARALLELISM"] = "false"

import torch
import cv2

torch.use_deterministic_algorithms(False)

from typing import Any

import numpy as np
import supervision as sv
from autodistill_grounded_sam.helpers import (combine_detections,
load_grounding_dino,
load_SAM)
from autodistill.helpers import load_image
from groundingdino.util.inference import Model
from segment_anything import SamPredictor

import numpy as np
from autodistill.detection import CaptionOntology, DetectionBaseModel

from autodistill_grounded_sam.helpers import (
combine_detections,
load_grounding_dino,
load_SAM
)

HOME = os.path.expanduser("~")
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")


@dataclass
class GroundedSAM(DetectionBaseModel):
ontology: CaptionOntology
Expand All @@ -33,15 +33,17 @@ class GroundedSAM(DetectionBaseModel):
box_threshold: float
text_threshold: float

def __init__(self, ontology: CaptionOntology, box_threshold=0.35, text_threshold=0.25):
def __init__(
self, ontology: CaptionOntology, box_threshold=0.35, text_threshold=0.25
):
self.ontology = ontology
self.grounding_dino_model = load_grounding_dino()
self.sam_predictor = load_SAM()
self.box_threshold = box_threshold
self.text_threshold = text_threshold

def predict(self, input: str) -> sv.Detections:
image = cv2.imread(input)
def predict(self, input: Any) -> sv.Detections:
image = load_image(input, return_format="cv2")

# GroundingDINO predictions
detections_list = []
Expand Down Expand Up @@ -76,4 +78,4 @@ def predict(self, input: str) -> sv.Detections:
detections.mask = np.array(result_masks)

# separate in supervision to combine detections and override class_ids
return detections
return detections
3 changes: 1 addition & 2 deletions autodistill_grounded_sam/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,11 @@
import urllib.request

import numpy as np
import supervision as sv
import torch
from groundingdino.util.inference import Model
from segment_anything import SamPredictor, sam_model_registry

import supervision as sv

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

if not torch.cuda.is_available():
Expand Down

0 comments on commit b0a7659

Please sign in to comment.