From 71d72735e554d3b4f48b464f8e073db6ac446348 Mon Sep 17 00:00:00 2001
From: James Gallagher
Date: Wed, 22 Nov 2023 10:38:39 +0000
Subject: [PATCH] add embed_text and embed_image functions

---
 autodistill_clip/clip_model.py | 15 ++++++++++++++-
 1 file changed, 14 insertions(+), 1 deletion(-)

diff --git a/autodistill_clip/clip_model.py b/autodistill_clip/clip_model.py
index 939903d..aed3912 100644
--- a/autodistill_clip/clip_model.py
+++ b/autodistill_clip/clip_model.py
@@ -36,12 +36,25 @@ def __init__(self, ontology: CaptionOntology):
         self.clip_preprocess = preprocess
         self.tokenize = clip.tokenize
 
+    def embed_image(self, input: str) -> np.ndarray:
+        image = self.clip_preprocess(Image.open(input)).unsqueeze(0).to(DEVICE)
+
+        with torch.no_grad():
+            image_features = self.clip_model.encode_image(image)
+
+        return image_features.cpu().numpy()
+
+    def embed_text(self, input: str) -> np.ndarray:
+        # Inference only: no_grad matches embed_image and skips autograd bookkeeping.
+        with torch.no_grad():
+            return self.clip_model.encode_text(self.tokenize([input]).to(DEVICE)).cpu().numpy()
+
     def predict(self, input: str) -> sv.Classifications:
         labels = self.ontology.prompts()
 
         image = self.clip_preprocess(Image.open(input)).unsqueeze(0).to(DEVICE)
 
-        if isinstance(labels, str):
+        if isinstance(labels[0], str):
             text = self.tokenize(labels).to(DEVICE)
 
             with torch.no_grad():