diff --git a/renumics/spotlight/embeddings/__init__.py b/renumics/spotlight/embeddings/__init__.py
index 56e9a49b..4c8dba2a 100644
--- a/renumics/spotlight/embeddings/__init__.py
+++ b/renumics/spotlight/embeddings/__init__.py
@@ -26,6 +26,7 @@ def create_embedders(data_store: Any, columns: List[str]) -> Dict[str, Embedder]
     for column in columns:
         for embedder_class in registered_embedders:
             try:
+                # embedder = FunctionalEmbedder(func, preprocessor, data_store, column)
                 embedder = embedder_class(data_store, column)
             except CannotEmbed:
                 continue
diff --git a/renumics/spotlight/embeddings/decorator.py b/renumics/spotlight/embeddings/decorator.py
index 7127d1a2..00b775ae 100644
--- a/renumics/spotlight/embeddings/decorator.py
+++ b/renumics/spotlight/embeddings/decorator.py
@@ -2,7 +2,17 @@
 A decorator for data analysis functions
 """
 
-from typing import Type
+from typing import Any, Callable, Iterable, List, Optional, Type
+import itertools
+import io
+import av
+import numpy as np
+import PIL.Image
+from numpy.typing import DTypeLike
+
+from renumics.spotlight import dtypes
+from renumics.spotlight.dtypes import create_dtype
+from renumics.spotlight.embeddings.exceptions import CannotEmbed
 
 from .typing import Embedder
 from .registry import register_embedder
@@ -13,3 +23,90 @@ def embedder(klass: Type[Embedder]) -> Type[Embedder]:
     """
     register_embedder(klass)
     return klass
+
+
+def embed(accepts: DTypeLike, *, sampling_rate: Optional[int] = None) -> Callable:
+    """
+    Register a batched embedding function for all columns of the given dtype.
+    """
+    dtype = create_dtype(accepts)
+
+    if dtypes.is_image_dtype(dtype):
+
+        def _preprocess_batch(raw_values: List[bytes]) -> List[PIL.Image.Image]:
+            return [PIL.Image.open(io.BytesIO(value)) for value in raw_values]
+
+    elif dtypes.is_audio_dtype(dtype):
+        if sampling_rate is None:
+            raise ValueError(
+                "No sampling rate specified, but one is required for `audio` embedding."
+            )
+
+        def _preprocess_batch(raw_values: Any) -> List[np.ndarray]:
+            # Decode each audio blob and resample it to mono at the requested rate.
+            resampled_batch = []
+            for raw_data in raw_values:
+                with av.open(io.BytesIO(raw_data), "r") as container:
+                    resampler = av.AudioResampler(
+                        format="dbl", layout="mono", rate=sampling_rate
+                    )
+                    data = []
+                    for frame in container.decode(audio=0):
+                        resampled_frames = resampler.resample(frame)
+                        for resampled_frame in resampled_frames:
+                            frame_array = resampled_frame.to_ndarray()[0]
+                            data.append(frame_array)
+                    resampled_batch.append(np.concatenate(data, axis=0))
+            return resampled_batch
+
+    else:
+
+        def _preprocess_batch(raw_values: Any) -> Any:
+            return raw_values
+
+    def decorate(
+        func: Callable[[Iterable[list]], Iterable[List[Optional[np.ndarray]]]]
+    ) -> Callable:
+        class EmbedderImpl(Embedder):
+            def __init__(self, data_store: Any, column: str):
+                self.dtype = dtype
+                if data_store.dtypes[column].name != self.dtype.name:
+                    raise CannotEmbed()
+
+                self.data_store = data_store
+                self.column = column
+                self.batch_size = 16
+
+                self._occupied_indices: List[int] = []
+
+            def _iter_batches(self):
+                self._occupied_indices = []
+                batch = []
+                for i in range(len(self.data_store)):
+                    value = self.data_store.get_converted_value(
+                        self.column, i, simple=False, check=False
+                    )
+
+                    if value is None:
+                        continue
+
+                    self._occupied_indices.append(i)
+                    batch.append(value)
+                    if len(batch) == self.batch_size:
+                        yield _preprocess_batch(batch)
+                        batch = []
+                if batch:
+                    # Do not silently drop the last, incomplete batch.
+                    yield _preprocess_batch(batch)
+
+            def __call__(self) -> np.ndarray:
+                embeddings = itertools.chain.from_iterable(func(self._iter_batches()))
+                # Scatter embeddings back to their original rows; rows with `None` values stay `None`.
+                res = np.empty(len(self.data_store), dtype=np.object_)
+                res[self._occupied_indices] = list(embeddings)
+                return res
+
+        register_embedder(EmbedderImpl)
+        return func
+
+    return decorate
diff --git a/renumics/spotlight/embeddings/embedders/vit.py b/renumics/spotlight/embeddings/embedders/vit.py
index 4fe17187..8221ed83 100644
--- a/renumics/spotlight/embeddings/embedders/vit.py
+++ b/renumics/spotlight/embeddings/embedders/vit.py
@@ -1,68 +1,29 @@
-import io
-from typing import Any, List
+from typing import Iterable, List
 
-from PIL import Image
-import numpy as np
+import PIL.Image
 import transformers
 
-from renumics.spotlight import dtypes
-from renumics.spotlight.embeddings.decorator import embedder
-from renumics.spotlight.embeddings.exceptions import CannotEmbed
-from renumics.spotlight.embeddings.registry import unregister_embedder
-from renumics.spotlight.embeddings.typing import Embedder
+from renumics.spotlight.embeddings.decorator import embed
 from renumics.spotlight.logging import logger
 
 try:
     import torch
 except ImportError:
     logger.warning("`ViTEmbedder` requires `pytorch` to be installed.")
-    _torch_available = False
 else:
-    _torch_available = True
-
-
-@embedder
-class ViTEmbedder(Embedder):
-    def __init__(self, data_store: Any, column: str) -> None:
-        if not dtypes.is_image_dtype(data_store.dtypes[column]):
-            raise CannotEmbed
-        self._data_store = data_store
-        self._column = column
-
-        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+
+    @embed("image")
+    def vit(batches: Iterable[List[PIL.Image.Image]]):
+        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
         model_name = "google/vit-base-patch16-224"
-        self.processor = transformers.AutoImageProcessor.from_pretrained(model_name)
-        self.model = transformers.ViTModel.from_pretrained(model_name).to(self.device)
-
-    def __call__(self) -> np.ndarray:
-        values = self._data_store.get_converted_values(
-            self._column, indices=slice(None), simple=False, check=False
-        )
-        none_mask = [sample is None for sample in values]
-        if all(none_mask):
-            return np.array([None] * len(values), dtype=np.object_)
-
-        embeddings = self.embed_images(
-            [Image.open(io.BytesIO(value)) for value in values if value is not None]
-        )
-
-        if any(none_mask):
-            res = np.empty(len(values), dtype=np.object_)
-            res[np.nonzero(~np.array(none_mask))[0]] = list(embeddings)
-            return res
-
-        return embeddings
-
-    def embed_images(self, batch: List[Image.Image]) -> np.ndarray:
-        images = [image.convert("RGB") for image in batch]
-        inputs = self.processor(images=images, return_tensors="pt")
-        with torch.no_grad():
-            outputs = self.model(**inputs.to(self.device))
-        embeddings = outputs.last_hidden_state[:, 0].cpu().numpy()
-
-        return embeddings
+        processor = transformers.AutoImageProcessor.from_pretrained(model_name)
+        model = transformers.ViTModel.from_pretrained(model_name).to(device)
+        for batch in batches:
+            images = [image.convert("RGB") for image in batch]
+            inputs = processor(images=images, return_tensors="pt")
+            with torch.no_grad():
+                outputs = model(**inputs.to(device))
+            embeddings = outputs.last_hidden_state[:, 0].cpu().numpy()
-
-if not _torch_available:
-    unregister_embedder(ViTEmbedder)
+
+            yield embeddings
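
Usage sketch for the new `embed` decorator (illustrative only; the `mean_color` function and its toy embedding are not part of this changeset). A decorated function receives an iterable of preprocessed batches (lists of decoded PIL images for "image" columns) and yields one list of embeddings per batch:

    from typing import Iterable, List

    import numpy as np
    import PIL.Image

    from renumics.spotlight.embeddings.decorator import embed

    @embed("image")
    def mean_color(batches: Iterable[List[PIL.Image.Image]]):
        # One embedding per sample: the per-channel mean of the RGB image.
        for batch in batches:
            yield [
                np.asarray(image.convert("RGB")).mean(axis=(0, 1))
                for image in batch
            ]

For audio columns a sampling rate is required, e.g. `@embed("audio", sampling_rate=16000)`; the decorator then delivers batches of mono float arrays resampled to that rate.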