wip: decorator for functional embedder
neindochoh committed Nov 30, 2023
1 parent 4b07dca commit 555c3ef
Showing 3 changed files with 107 additions and 55 deletions.
1 change: 1 addition & 0 deletions renumics/spotlight/embeddings/__init__.py
@@ -26,6 +26,7 @@ def create_embedders(data_store: Any, columns: List[str]) -> Dict[str, Embedder]
     for column in columns:
         for embedder_class in registered_embedders:
             try:
+                # embedder = FunctionalEmbedder(func, preprocessor, data_store, column)
                 embedder = embedder_class(data_store, column)
             except CannotEmbed:
                 continue
92 changes: 91 additions & 1 deletion renumics/spotlight/embeddings/decorator.py
@@ -2,7 +2,19 @@
 A decorator for data analysis functions
 """

-from typing import Type
+from typing import Any, Callable, Iterable, List, Optional, Type
+import itertools
+import io
+
+import av
+import numpy as np
+import PIL.Image
+from numpy.typing import DTypeLike
+
+from renumics.spotlight import dtypes
+from renumics.spotlight.dtypes import create_dtype
+from renumics.spotlight.embeddings.exceptions import CannotEmbed

 from .typing import Embedder
 from .registry import register_embedder

@@ -13,3 +25,81 @@ def embedder(klass: Type[Embedder]) -> Type[Embedder]:
"""
register_embedder(klass)
return klass


def embed(accepts: DTypeLike, *, sampling_rate: Optional[int] = None):
dtype = create_dtype(accepts)

if dtypes.is_image_dtype(dtype):

def _preprocess_batch(raw_values: List[bytes]):
return [PIL.Image.open(io.BytesIO(value)) for value in raw_values]

elif dtypes.is_audio_dtype(dtype):
if sampling_rate is None:
raise ValueError(
"No sampling rate specified, but required for `audio` embedding."
)

def _preprocess_batch(raw_values: Any):
resampled_batch = []
for raw_data in raw_values:
with av.open(io.BytesIO(raw_data), "r") as container:
resampler = av.AudioResampler(
format="dbl", layout="mono", rate=16000
)
data = []
for frame in container.decode(audio=0):
resampled_frames = resampler.resample(frame)
for resampled_frame in resampled_frames:
frame_array = resampled_frame.to_ndarray()[0]
data.append(frame_array)
resampled_batch.append(np.concatenate(data, axis=0))
return resampled_batch

else:

def _preprocess_batch(raw_values: Any):
return raw_values

def decorate(
func: Callable[[Iterable[list]], Iterable[List[Optional[np.ndarray]]]]
):
class EmbedderImpl(Embedder):
def __init__(self, data_store: Any, column: str):
self.dtype = dtype
if data_store.dtypes[column].name != self.dtype.name:
raise CannotEmbed()

self.data_store = data_store
self.column = column
self.batch_size = 16

self._occupied_indices = []

def _iter_batches(self):
self._occupied_indices = []
batch = []
for i in range(len(self.data_store)):
value = self.data_store.get_converted_value(
self.column, i, simple=False, check=False
)

if value is None:
continue

self._occupied_indices.append(i)
batch.append(value)
if len(batch) == self.batch_size:
yield _preprocess_batch(batch)
batch = []

def __call__(self) -> np.ndarray:
embeddings = itertools.chain(*func(self._iter_batches()))
res = np.empty(len(self.data_store), dtype=np.object_)
res[self._occupied_indices] = list(embeddings)
return res

register_embedder(EmbedderImpl)

return decorate
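
Taken together, the decorator turns any generator over preprocessed batches into a registered embedder. A minimal sketch of how a functional embedder could look with this WIP API (the function name and the random 8-dimensional vectors are made up for illustration; a real embedder would run a model):

import numpy as np
from renumics.spotlight.embeddings.decorator import embed

@embed("image")
def random_embedder(batches):
    # One batch of PIL images in, one list of per-image embedding vectors out.
    for batch in batches:
        yield [np.random.rand(8) for _ in batch]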
69 changes: 15 additions & 54 deletions renumics/spotlight/embeddings/embedders/vit.py
@@ -1,68 +1,29 @@
-import io
-from typing import Any, List
+from typing import Iterable, List

-from PIL import Image
 import numpy as np
+import PIL.Image
 import transformers

-from renumics.spotlight import dtypes
-from renumics.spotlight.embeddings.decorator import embedder
-from renumics.spotlight.embeddings.exceptions import CannotEmbed
-from renumics.spotlight.embeddings.registry import unregister_embedder
-from renumics.spotlight.embeddings.typing import Embedder
+from renumics.spotlight.embeddings.decorator import embed
 from renumics.spotlight.logging import logger

 try:
     import torch
 except ImportError:
     logger.warning("`ViTEmbedder` requires `pytorch` to be installed.")
     _torch_available = False
 else:
     _torch_available = True


-@embedder
-class ViTEmbedder(Embedder):
-    def __init__(self, data_store: Any, column: str) -> None:
-        if not dtypes.is_image_dtype(data_store.dtypes[column]):
-            raise CannotEmbed
-        self._data_store = data_store
-        self._column = column
-
-        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
-        model_name = "google/vit-base-patch16-224"
-        self.processor = transformers.AutoImageProcessor.from_pretrained(model_name)
-        self.model = transformers.ViTModel.from_pretrained(model_name).to(self.device)
-
-    def __call__(self) -> np.ndarray:
-        values = self._data_store.get_converted_values(
-            self._column, indices=slice(None), simple=False, check=False
-        )
-        none_mask = [sample is None for sample in values]
-        if all(none_mask):
-            return np.array([None] * len(values), dtype=np.object_)
-
-        embeddings = self.embed_images(
-            [Image.open(io.BytesIO(value)) for value in values if value is not None]
-        )
-
-        if any(none_mask):
-            res = np.empty(len(values), dtype=np.object_)
-            res[np.nonzero(~np.array(none_mask))[0]] = list(embeddings)
-            return res
-
-        return embeddings
-
-    def embed_images(self, batch: List[Image.Image]) -> np.ndarray:
-        images = [image.convert("RGB") for image in batch]
-        inputs = self.processor(images=images, return_tensors="pt")
-        with torch.no_grad():
-            outputs = self.model(**inputs.to(self.device))
-        embeddings = outputs.last_hidden_state[:, 0].cpu().numpy()
-
-        return embeddings
-
-
-if not _torch_available:
-    unregister_embedder(ViTEmbedder)
+@embed("image")
+def vit(batches: Iterable[List[PIL.Image.Image]]) -> Iterable[np.ndarray]:
+    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+    model_name = "google/vit-base-patch16-224"
+    processor = transformers.AutoImageProcessor.from_pretrained(model_name)
+    model = transformers.ViTModel.from_pretrained(model_name).to(device)
+
+    for batch in batches:
+        images = [image.convert("RGB") for image in batch]
+        inputs = processor(images=images, return_tensors="pt")
+        with torch.no_grad():
+            outputs = model(**inputs.to(device))
+        # Use the CLS token as a single embedding vector per image.
+        embeddings = outputs.last_hidden_state[:, 0].cpu().numpy()
+
+        yield embeddings
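
Since `vit` is now a plain generator function, it can also be exercised outside of Spotlight's embedder registry. A hypothetical smoke test, assuming `torch` and `transformers` are installed (the blank test image is made up; ViT-base produces one 768-dimensional CLS embedding per image):

import PIL.Image
from renumics.spotlight.embeddings.embedders.vit import vit

batches = [[PIL.Image.new("RGB", (224, 224))]]  # one batch with one blank image
for embedding_batch in vit(batches):
    print(embedding_batch.shape)  # (1, 768)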
