diff --git a/autodistill_grounded_sam/grounded_sam.py b/autodistill_grounded_sam/grounded_sam.py index de101e0..4810ffe 100644 --- a/autodistill_grounded_sam/grounded_sam.py +++ b/autodistill_grounded_sam/grounded_sam.py @@ -21,6 +21,8 @@ from autodistill.detection import CaptionOntology, DetectionBaseModel +from supervision.detection.utils import mask_to_xyxy + HOME = os.path.expanduser("~") DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") @@ -77,5 +79,8 @@ def predict(self, input: Any) -> sv.Detections: detections.mask = np.array(result_masks) + # override GroundingDINO bboxes with calculated masks bboxes + detections.xyxy = mask_to_xyxy(detections.mask) + # separate in supervision to combine detections and override class_ids return detections