From 2857c053b146fea2c712612ca44e33aceb175b2d Mon Sep 17 00:00:00 2001 From: Alexander Druz Date: Wed, 22 Nov 2023 14:58:00 +0100 Subject: [PATCH 01/24] WIP: dummy embed columns --- renumics/spotlight/analysis/registry.py | 3 +- renumics/spotlight/data_store.py | 29 ++++++++++++--- renumics/spotlight/embeddings/__init__.py | 37 +++++++++++++++++++ renumics/spotlight/embeddings/decorator.py | 14 +++++++ .../embeddings/embedders/__init__.py | 0 .../spotlight/embeddings/embedders/dummy.py | 13 +++++++ renumics/spotlight/embeddings/registry.py | 22 +++++++++++ renumics/spotlight/embeddings/typing.py | 10 +++++ 8 files changed, 122 insertions(+), 6 deletions(-) create mode 100644 renumics/spotlight/embeddings/__init__.py create mode 100644 renumics/spotlight/embeddings/decorator.py create mode 100644 renumics/spotlight/embeddings/embedders/__init__.py create mode 100644 renumics/spotlight/embeddings/embedders/dummy.py create mode 100644 renumics/spotlight/embeddings/registry.py create mode 100644 renumics/spotlight/embeddings/typing.py diff --git a/renumics/spotlight/analysis/registry.py b/renumics/spotlight/analysis/registry.py index 8d7b75b5..d50212ae 100644 --- a/renumics/spotlight/analysis/registry.py +++ b/renumics/spotlight/analysis/registry.py @@ -1,10 +1,11 @@ """ Manage data analyzers available for spotlights automatic dataset analysis. """ +from typing import Set from .typing import DataAnalyzer -registered_analyzers = set() +registered_analyzers: Set[DataAnalyzer] = set() def register_analyzer(analyzer: DataAnalyzer) -> None: diff --git a/renumics/spotlight/data_store.py b/renumics/spotlight/data_store.py index aff0f73a..68802321 100644 --- a/renumics/spotlight/data_store.py +++ b/renumics/spotlight/data_store.py @@ -3,7 +3,7 @@ import io import os import statistics -from typing import Any, Iterable, List, Optional, Set, Union, cast +from typing import Any, Dict, Iterable, List, Optional, Set, Union, cast import numpy as np import filetype @@ -17,6 +17,7 @@ from renumics.spotlight.data_source import DataSource from renumics.spotlight.dtypes.conversion import ConvertedValue, convert_to_dtype from renumics.spotlight.data_source.data_source import ColumnMetadata +from renumics.spotlight.embeddings import embed_columns from renumics.spotlight.io import audio from renumics.spotlight.typing import is_iterable, is_pathtype from renumics.spotlight.media.mesh import Mesh @@ -31,13 +32,16 @@ class DataStore: _data_source: DataSource _user_dtypes: spotlight_dtypes.DTypeMap _dtypes: spotlight_dtypes.DTypeMap + _embeddings: Dict[str, np.ndarray] def __init__( self, data_source: DataSource, user_dtypes: spotlight_dtypes.DTypeMap ) -> None: + self._embeddings = {} self._data_source = data_source self._user_dtypes = user_dtypes self._update_dtypes() + self._embeddings = embed_columns(self, self._data_source.column_names) def __len__(self) -> int: return len(self._data_source) @@ -56,7 +60,7 @@ def generation_id(self) -> int: @property def column_names(self) -> List[str]: - return self._data_source.column_names + return self._data_source.column_names + list(self._embeddings) @property def data_source(self) -> DataSource: @@ -64,12 +68,22 @@ def data_source(self) -> DataSource: @property def dtypes(self) -> spotlight_dtypes.DTypeMap: - return self._dtypes + return { + **self._dtypes, + **{ + column: spotlight_dtypes.EmbeddingDType(length=embeddings.shape[1]) + for column, embeddings in self._embeddings.items() + }, + } def check_generation_id(self, generation_id: int) -> None: 
self._data_source.check_generation_id(generation_id) def get_column_metadata(self, column_name: str) -> ColumnMetadata: + if column_name in self._embeddings: + return ColumnMetadata( + nullable=True, editable=False, hidden=True, description=None, tags=[] + ) return self._data_source.get_column_metadata(column_name) def get_converted_values( @@ -79,8 +93,13 @@ def get_converted_values( simple: bool = False, check: bool = True, ) -> List[ConvertedValue]: - dtype = self._dtypes[column_name] - normalized_values = self._data_source.get_column_values(column_name, indices) + dtype = self.dtypes[column_name] + if column_name in self._embeddings: + normalized_values: Iterable = self._embeddings[column_name] + else: + normalized_values = self._data_source.get_column_values( + column_name, indices + ) converted_values = [ convert_to_dtype(value, dtype, simple=simple, check=check) for value in normalized_values diff --git a/renumics/spotlight/embeddings/__init__.py b/renumics/spotlight/embeddings/__init__.py new file mode 100644 index 00000000..2f30db26 --- /dev/null +++ b/renumics/spotlight/embeddings/__init__.py @@ -0,0 +1,37 @@ +""" +Dataset Analysis +""" + +import importlib +import pkgutil +from typing import Any, Dict, List + +import numpy as np + +from renumics.spotlight.logging import logger + +from .registry import registered_embedders +from . import embedders as embedders_namespace + +# import all modules in .embedders +for module_info in pkgutil.iter_modules(embedders_namespace.__path__): + importlib.import_module(embedders_namespace.__name__ + "." + module_info.name) + + +def embed_columns(data_store: Any, columns: List[str]) -> Dict[str, np.ndarray]: + """ + Find dataset issues in the data source + """ + + logger.info("Embedding started.") + + all_embeddings: Dict[str, np.ndarray] = {} + for column in columns: + for embedder in registered_embedders: + if (embeddings := embedder(data_store, column)) is not None: + all_embeddings[f"{column}.embedding"] = embeddings + break + + logger.info("Embedding done.") + + return all_embeddings diff --git a/renumics/spotlight/embeddings/decorator.py b/renumics/spotlight/embeddings/decorator.py new file mode 100644 index 00000000..19370dc4 --- /dev/null +++ b/renumics/spotlight/embeddings/decorator.py @@ -0,0 +1,14 @@ +""" +A decorator for data analysis functions +""" + +from .typing import Embedder +from .registry import register_embedder + + +def embedder(func: Embedder) -> Embedder: + """ + register an embedder function + """ + register_embedder(func) + return func diff --git a/renumics/spotlight/embeddings/embedders/__init__.py b/renumics/spotlight/embeddings/embedders/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/renumics/spotlight/embeddings/embedders/dummy.py b/renumics/spotlight/embeddings/embedders/dummy.py new file mode 100644 index 00000000..8db53fa0 --- /dev/null +++ b/renumics/spotlight/embeddings/embedders/dummy.py @@ -0,0 +1,13 @@ +from typing import Any, Optional + +import numpy as np + +from renumics.spotlight import dtypes +from renumics.spotlight.embeddings.decorator import embedder + + +@embedder +def dummy(data_store: Any, column: str) -> Optional[np.ndarray]: + if dtypes.is_image_dtype(data_store.dtypes[column]): + return np.random.random((len(data_store), 4)) + return None diff --git a/renumics/spotlight/embeddings/registry.py b/renumics/spotlight/embeddings/registry.py new file mode 100644 index 00000000..800b6307 --- /dev/null +++ b/renumics/spotlight/embeddings/registry.py @@ -0,0 +1,22 @@ +""" +Manage 
data analyzers available for spotlights automatic dataset analysis. +""" +from typing import Set + +from .typing import Embedder + +registered_embedders: Set[Embedder] = set() + + +def register_embedder(embedder: Embedder) -> None: + """ + Register an embedder + """ + registered_embedders.add(embedder) + + +def unregister_embedder(embedder: Embedder) -> None: + """ + Unregister an embedder + """ + registered_embedders.remove(embedder) diff --git a/renumics/spotlight/embeddings/typing.py b/renumics/spotlight/embeddings/typing.py new file mode 100644 index 00000000..d3517114 --- /dev/null +++ b/renumics/spotlight/embeddings/typing.py @@ -0,0 +1,10 @@ +""" +Shared types for embeddings +""" + +from typing import Any, Callable, Optional + +import numpy as np + + +Embedder = Callable[[Any, str], Optional[np.ndarray]] From 4b6cc6df6325d9d04a9ccb2a00884136e5107345 Mon Sep 17 00:00:00 2001 From: Alexander Druz Date: Tue, 28 Nov 2023 16:25:41 +0100 Subject: [PATCH 02/24] Compute embeddings lazy --- renumics/spotlight/app.py | 39 ++++++++++ renumics/spotlight/backend/exceptions.py | 11 +++ renumics/spotlight/data_source/data_source.py | 1 + renumics/spotlight/data_store.py | 28 +++++-- renumics/spotlight/embeddings/__init__.py | 28 +++++-- renumics/spotlight/embeddings/decorator.py | 9 ++- .../spotlight/embeddings/embedders/dummy.py | 17 +++-- renumics/spotlight/embeddings/exceptions.py | 9 +++ renumics/spotlight/embeddings/registry.py | 8 +- renumics/spotlight/embeddings/typing.py | 16 +++- renumics/spotlight_plugins/core/api/table.py | 23 ++++++ src/client/apis/TableApi.ts | 73 +++++++++++++++++++ src/client/models/Column.ts | 9 +++ 13 files changed, 242 insertions(+), 29 deletions(-) create mode 100644 renumics/spotlight/embeddings/exceptions.py diff --git a/renumics/spotlight/app.py b/renumics/spotlight/app.py index 5cc67f56..645261e4 100644 --- a/renumics/spotlight/app.py +++ b/renumics/spotlight/app.py @@ -31,6 +31,7 @@ ResetLayoutMessage, WebsocketManager, ) +from renumics.spotlight.embeddings import create_embedders, embed from renumics.spotlight.layout.nodes import Layout from renumics.spotlight.backend.config import Config from renumics.spotlight.typing import PathType @@ -83,6 +84,15 @@ class IssuesUpdatedMessage(Message): data: Any = None +class EmbeddingsUpdatedMessage(Message): + """ + Notify about updated embeddings. + """ + + type: Literal["columnsUpdated"] = "columnsUpdated" + data: Any = None + + class SpotlightApp(FastAPI): """ Spotlight wsgi application @@ -334,6 +344,7 @@ def update(self, config: AppConfig) -> None: self._data_store = DataStore(data_source, self._user_dtypes) self._broadcast(RefreshMessage()) self._update_issues() + self._update_embeddings() if config.layout is not None: if self._data_store is not None: dataset_uid = self._data_store.uid @@ -451,6 +462,34 @@ def _on_issues_ready(future: Future) -> None: task.future.add_done_callback(_on_issues_ready) + def _update_embeddings(self) -> None: + """ + Update embeddings, update them in the data store and notify client about. 
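+        Embedders are created synchronously and registered as placeholder
+        columns in the data store; the embeddings themselves are computed in
+        a background task, and clients are notified once the results are
+        available.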
+ """ + self._broadcast(EmbeddingsUpdatedMessage()) + + if self._data_store is None: + return + + embedders = create_embedders(self._data_store, self._data_store.column_names) + + self._data_store.embeddings = {column: None for column in embedders} + + task = self.task_manager.create_task( + embed, (embedders,), name="update_embeddings" + ) + + def _on_embeddings_ready(future: Future) -> None: + if self._data_store is None: + return + try: + self._data_store.embeddings = future.result() + except CancelledError: + return + self._broadcast(EmbeddingsUpdatedMessage()) + + task.future.add_done_callback(_on_embeddings_ready) + def _broadcast(self, message: Message) -> None: """ Broadcast a message to all connected clients via websocket diff --git a/renumics/spotlight/backend/exceptions.py b/renumics/spotlight/backend/exceptions.py index 6bbfa155..cf08cf90 100644 --- a/renumics/spotlight/backend/exceptions.py +++ b/renumics/spotlight/backend/exceptions.py @@ -181,3 +181,14 @@ def __init__(self) -> None: "dataset using `dataset.rebuild()`.", status.HTTP_500_INTERNAL_SERVER_ERROR, ) + + +class ComputedColumnNotReady(Problem): + """A computed column is not yet ready""" + + def __init__(self, name: str) -> None: + super().__init__( + "Computed column not ready", + f"Computed column {name} is not yet ready.", + status.HTTP_404_NOT_FOUND, + ) diff --git a/renumics/spotlight/data_source/data_source.py b/renumics/spotlight/data_source/data_source.py index 2807710c..d2e2e300 100644 --- a/renumics/spotlight/data_source/data_source.py +++ b/renumics/spotlight/data_source/data_source.py @@ -27,6 +27,7 @@ class ColumnMetadata: hidden: bool = False description: Optional[str] = None tags: List[str] = dataclasses.field(default_factory=list) + computed: bool = False class DataSource(ABC): diff --git a/renumics/spotlight/data_store.py b/renumics/spotlight/data_store.py index 68802321..a63cb161 100644 --- a/renumics/spotlight/data_store.py +++ b/renumics/spotlight/data_store.py @@ -17,7 +17,6 @@ from renumics.spotlight.data_source import DataSource from renumics.spotlight.dtypes.conversion import ConvertedValue, convert_to_dtype from renumics.spotlight.data_source.data_source import ColumnMetadata -from renumics.spotlight.embeddings import embed_columns from renumics.spotlight.io import audio from renumics.spotlight.typing import is_iterable, is_pathtype from renumics.spotlight.media.mesh import Mesh @@ -32,7 +31,7 @@ class DataStore: _data_source: DataSource _user_dtypes: spotlight_dtypes.DTypeMap _dtypes: spotlight_dtypes.DTypeMap - _embeddings: Dict[str, np.ndarray] + _embeddings: Dict[str, Optional[np.ndarray]] def __init__( self, data_source: DataSource, user_dtypes: spotlight_dtypes.DTypeMap @@ -41,7 +40,6 @@ def __init__( self._data_source = data_source self._user_dtypes = user_dtypes self._update_dtypes() - self._embeddings = embed_columns(self, self._data_source.column_names) def __len__(self) -> int: return len(self._data_source) @@ -71,18 +69,33 @@ def dtypes(self) -> spotlight_dtypes.DTypeMap: return { **self._dtypes, **{ - column: spotlight_dtypes.EmbeddingDType(length=embeddings.shape[1]) + column: spotlight_dtypes.EmbeddingDType( + length=None if embeddings is None else embeddings.shape[1] + ) for column, embeddings in self._embeddings.items() }, } + @property + def embeddings(self) -> Dict[str, Optional[np.ndarray]]: + return self._embeddings + + @embeddings.setter + def embeddings(self, new_embeddings: Dict[str, Optional[np.ndarray]]) -> None: + self._embeddings = new_embeddings + def 
check_generation_id(self, generation_id: int) -> None: self._data_source.check_generation_id(generation_id) def get_column_metadata(self, column_name: str) -> ColumnMetadata: if column_name in self._embeddings: return ColumnMetadata( - nullable=True, editable=False, hidden=True, description=None, tags=[] + nullable=True, + editable=False, + hidden=True, + description=None, + tags=[], + computed=True, ) return self._data_source.get_column_metadata(column_name) @@ -95,7 +108,10 @@ def get_converted_values( ) -> List[ConvertedValue]: dtype = self.dtypes[column_name] if column_name in self._embeddings: - normalized_values: Iterable = self._embeddings[column_name] + embeddings = self._embeddings[column_name] + if embeddings is None: + return [None] * len(self) + normalized_values: Iterable = embeddings else: normalized_values = self._data_source.get_column_values( column_name, indices diff --git a/renumics/spotlight/embeddings/__init__.py b/renumics/spotlight/embeddings/__init__.py index 2f30db26..3e75cb1c 100644 --- a/renumics/spotlight/embeddings/__init__.py +++ b/renumics/spotlight/embeddings/__init__.py @@ -7,6 +7,8 @@ from typing import Any, Dict, List import numpy as np +from renumics.spotlight.embeddings.exceptions import CannotEmbed +from renumics.spotlight.embeddings.typing import Embedder from renumics.spotlight.logging import logger @@ -18,20 +20,30 @@ importlib.import_module(embedders_namespace.__name__ + "." + module_info.name) -def embed_columns(data_store: Any, columns: List[str]) -> Dict[str, np.ndarray]: +def create_embedders(data_store: Any, columns: List[str]) -> Dict[str, Embedder]: """ - Find dataset issues in the data source + Create embedding functions for the given data store. """ logger.info("Embedding started.") - all_embeddings: Dict[str, np.ndarray] = {} + embedders: Dict[str, Embedder] = {} for column in columns: - for embedder in registered_embedders: - if (embeddings := embedder(data_store, column)) is not None: - all_embeddings[f"{column}.embedding"] = embeddings - break + for embedder_class in registered_embedders: + try: + embedder = embedder_class(data_store, column) + except CannotEmbed: + continue + embedders[f"{column}.embedding"] = embedder + break logger.info("Embedding done.") - return all_embeddings + return embedders + + +def embed(embedders: Dict[str, Embedder]) -> Dict[str, np.ndarray]: + """ + Run the given functions. 
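+    Returns a mapping from computed column name to the embedding array
+    produced by the corresponding embedder.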
+ """ + return {column: embedder() for column, embedder in embedders.items()} diff --git a/renumics/spotlight/embeddings/decorator.py b/renumics/spotlight/embeddings/decorator.py index 19370dc4..7127d1a2 100644 --- a/renumics/spotlight/embeddings/decorator.py +++ b/renumics/spotlight/embeddings/decorator.py @@ -2,13 +2,14 @@ A decorator for data analysis functions """ +from typing import Type from .typing import Embedder from .registry import register_embedder -def embedder(func: Embedder) -> Embedder: +def embedder(klass: Type[Embedder]) -> Type[Embedder]: """ - register an embedder function + register an embedder class """ - register_embedder(func) - return func + register_embedder(klass) + return klass diff --git a/renumics/spotlight/embeddings/embedders/dummy.py b/renumics/spotlight/embeddings/embedders/dummy.py index 8db53fa0..951337af 100644 --- a/renumics/spotlight/embeddings/embedders/dummy.py +++ b/renumics/spotlight/embeddings/embedders/dummy.py @@ -1,13 +1,20 @@ -from typing import Any, Optional +from typing import Any import numpy as np from renumics.spotlight import dtypes from renumics.spotlight.embeddings.decorator import embedder +from renumics.spotlight.embeddings.exceptions import CannotEmbed +from renumics.spotlight.embeddings.typing import Embedder @embedder -def dummy(data_store: Any, column: str) -> Optional[np.ndarray]: - if dtypes.is_image_dtype(data_store.dtypes[column]): - return np.random.random((len(data_store), 4)) - return None +class Dummy(Embedder): + def __init__(self, data_store: Any, column: str) -> None: + if not dtypes.is_image_dtype(data_store.dtypes[column]): + raise CannotEmbed + self._data_store = data_store + self._column = column + + def __call__(self) -> np.ndarray: + return np.random.random((len(self._data_store), 4)) diff --git a/renumics/spotlight/embeddings/exceptions.py b/renumics/spotlight/embeddings/exceptions.py new file mode 100644 index 00000000..92c57bf7 --- /dev/null +++ b/renumics/spotlight/embeddings/exceptions.py @@ -0,0 +1,9 @@ +""" +Exceptions used by embedders. +""" + + +class CannotEmbed(Exception): + """ + Raised when a column cannot be embed by an embedder. + """ diff --git a/renumics/spotlight/embeddings/registry.py b/renumics/spotlight/embeddings/registry.py index 800b6307..eaab1f9a 100644 --- a/renumics/spotlight/embeddings/registry.py +++ b/renumics/spotlight/embeddings/registry.py @@ -1,21 +1,21 @@ """ Manage data analyzers available for spotlights automatic dataset analysis. 
""" -from typing import Set +from typing import Set, Type from .typing import Embedder -registered_embedders: Set[Embedder] = set() +registered_embedders: Set[Type[Embedder]] = set() -def register_embedder(embedder: Embedder) -> None: +def register_embedder(embedder: Type[Embedder]) -> None: """ Register an embedder """ registered_embedders.add(embedder) -def unregister_embedder(embedder: Embedder) -> None: +def unregister_embedder(embedder: Type[Embedder]) -> None: """ Unregister an embedder """ diff --git a/renumics/spotlight/embeddings/typing.py b/renumics/spotlight/embeddings/typing.py index d3517114..a5dacc5f 100644 --- a/renumics/spotlight/embeddings/typing.py +++ b/renumics/spotlight/embeddings/typing.py @@ -2,9 +2,21 @@ Shared types for embeddings """ -from typing import Any, Callable, Optional +from abc import ABC, abstractmethod +from typing import Any import numpy as np -Embedder = Callable[[Any, str], Optional[np.ndarray]] +class Embedder(ABC): + @abstractmethod + def __init__(self, data_store: Any, column: str) -> None: + """ + Raise if dtype of the given column is not supported. + """ + + @abstractmethod + def __call__(self) -> np.ndarray: + """ + Embed the given column. + """ diff --git a/renumics/spotlight_plugins/core/api/table.py b/renumics/spotlight_plugins/core/api/table.py index 81d265c2..e55d2a8e 100644 --- a/renumics/spotlight_plugins/core/api/table.py +++ b/renumics/spotlight_plugins/core/api/table.py @@ -28,6 +28,7 @@ class Column(BaseModel): values: List[Any] description: Optional[str] tags: Optional[List[str]] + computed: bool class Table(BaseModel): @@ -82,6 +83,7 @@ def get_table(request: Request) -> ORJSONResponse: dtype=dtype.dict(), description=meta.description, tags=meta.tags, + computed=meta.computed, ) columns.append(column) @@ -95,6 +97,27 @@ def get_table(request: Request) -> ORJSONResponse: ) +@router.get( + "/{column}", + tags=["table"], + operation_id="get_column", +) +async def get_table_column( + column: str, generation_id: int, request: Request +) -> Response: + """ + table column api endpoint + """ + app: SpotlightApp = request.app + data_store = app.data_store + if data_store is None: + return ORJSONResponse(None) + data_store.check_generation_id(generation_id) + + values = data_store.get_converted_values(column, simple=False) + return ORJSONResponse(values) + + @router.get( "/{column}/{row}", tags=["table"], diff --git a/src/client/apis/TableApi.ts b/src/client/apis/TableApi.ts index f7ab8807..267e9206 100644 --- a/src/client/apis/TableApi.ts +++ b/src/client/apis/TableApi.ts @@ -27,6 +27,11 @@ export interface GetCellRequest { generationId: number; } +export interface GetColumnRequest { + column: string; + generationId: number; +} + export interface GetWaveformRequest { column: string; row: number; @@ -121,6 +126,74 @@ export class TableApi extends runtime.BaseAPI { return await response.value(); } + /** + * table column api endpoint + * Get Table Column + */ + async getColumnRaw( + requestParameters: GetColumnRequest, + initOverrides?: RequestInit | runtime.InitOverrideFunction + ): Promise> { + if ( + requestParameters.column === null || + requestParameters.column === undefined + ) { + throw new runtime.RequiredError( + 'column', + 'Required parameter requestParameters.column was null or undefined when calling getColumn.' 
+ ); + } + + if ( + requestParameters.generationId === null || + requestParameters.generationId === undefined + ) { + throw new runtime.RequiredError( + 'generationId', + 'Required parameter requestParameters.generationId was null or undefined when calling getColumn.' + ); + } + + const queryParameters: any = {}; + + if (requestParameters.generationId !== undefined) { + queryParameters['generation_id'] = requestParameters.generationId; + } + + const headerParameters: runtime.HTTPHeaders = {}; + + const response = await this.request( + { + path: `/api/table/{column}`.replace( + `{${'column'}}`, + encodeURIComponent(String(requestParameters.column)) + ), + method: 'GET', + headers: headerParameters, + query: queryParameters, + }, + initOverrides + ); + + if (this.isJsonMime(response.headers.get('content-type'))) { + return new runtime.JSONApiResponse(response); + } else { + return new runtime.TextApiResponse(response) as any; + } + } + + /** + * table column api endpoint + * Get Table Column + */ + async getColumn( + requestParameters: GetColumnRequest, + initOverrides?: RequestInit | runtime.InitOverrideFunction + ): Promise { + const response = await this.getColumnRaw(requestParameters, initOverrides); + return await response.value(); + } + /** * table slice api endpoint * Get Table diff --git a/src/client/models/Column.ts b/src/client/models/Column.ts index e80dfad0..ffe02507 100644 --- a/src/client/models/Column.ts +++ b/src/client/models/Column.ts @@ -67,6 +67,12 @@ export interface Column { * @memberof Column */ tags: Array | null; + /** + * + * @type {boolean} + * @memberof Column + */ + computed: boolean; } /** @@ -82,6 +88,7 @@ export function instanceOfColumn(value: object): boolean { isInstance = isInstance && 'values' in value; isInstance = isInstance && 'description' in value; isInstance = isInstance && 'tags' in value; + isInstance = isInstance && 'computed' in value; return isInstance; } @@ -103,6 +110,7 @@ export function ColumnFromJSONTyped(json: any, ignoreDiscriminator: boolean): Co values: json['values'], description: json['description'], tags: json['tags'], + computed: json['computed'], }; } @@ -122,5 +130,6 @@ export function ColumnToJSON(value?: Column | null): any { values: value.values, description: value.description, tags: value.tags, + computed: value.computed, }; } From 2eaa418c88cd36c4b2b4274d91a448d92f19df1a Mon Sep 17 00:00:00 2001 From: Dominik Haentsch Date: Wed, 29 Nov 2023 13:23:17 +0100 Subject: [PATCH 03/24] feat: fetch computed column values when ready --- renumics/spotlight/app.py | 8 ++-- renumics/spotlight/data_store.py | 3 +- .../spotlight/embeddings/embedders/dummy.py | 2 + renumics/spotlight_plugins/core/api/table.py | 17 ++++++-- src/client/models/Column.ts | 2 +- src/dataformat/index.ts | 2 +- src/lenses/LensFactory.tsx | 2 +- src/lenses/useCellValues.ts | 35 +++++++++++----- src/stores/dataset/columnFactory.ts | 1 + src/stores/dataset/dataset.ts | 42 ++++++++++++++++++- src/types/dataset.ts | 1 + src/widgets/DataGrid/hooks/useCellValue.ts | 2 +- 12 files changed, 90 insertions(+), 27 deletions(-) diff --git a/renumics/spotlight/app.py b/renumics/spotlight/app.py index 645261e4..02ef2e87 100644 --- a/renumics/spotlight/app.py +++ b/renumics/spotlight/app.py @@ -84,13 +84,13 @@ class IssuesUpdatedMessage(Message): data: Any = None -class EmbeddingsUpdatedMessage(Message): +class ColumnsUpdatedMessage(Message): """ Notify about updated embeddings. 
""" type: Literal["columnsUpdated"] = "columnsUpdated" - data: Any = None + data: List[str] class SpotlightApp(FastAPI): @@ -466,8 +466,6 @@ def _update_embeddings(self) -> None: """ Update embeddings, update them in the data store and notify client about. """ - self._broadcast(EmbeddingsUpdatedMessage()) - if self._data_store is None: return @@ -486,7 +484,7 @@ def _on_embeddings_ready(future: Future) -> None: self._data_store.embeddings = future.result() except CancelledError: return - self._broadcast(EmbeddingsUpdatedMessage()) + self._broadcast(ColumnsUpdatedMessage(data=list(embedders.keys()))) task.future.add_done_callback(_on_embeddings_ready) diff --git a/renumics/spotlight/data_store.py b/renumics/spotlight/data_store.py index a63cb161..af0f5bb7 100644 --- a/renumics/spotlight/data_store.py +++ b/renumics/spotlight/data_store.py @@ -25,6 +25,7 @@ from renumics.spotlight.media.image import Image from renumics.spotlight.media.sequence_1d import Sequence1D from renumics.spotlight.media.embedding import Embedding +from renumics.spotlight.backend.exceptions import ComputedColumnNotReady class DataStore: @@ -110,7 +111,7 @@ def get_converted_values( if column_name in self._embeddings: embeddings = self._embeddings[column_name] if embeddings is None: - return [None] * len(self) + raise ComputedColumnNotReady(column_name) normalized_values: Iterable = embeddings else: normalized_values = self._data_source.get_column_values( diff --git a/renumics/spotlight/embeddings/embedders/dummy.py b/renumics/spotlight/embeddings/embedders/dummy.py index 951337af..621a699a 100644 --- a/renumics/spotlight/embeddings/embedders/dummy.py +++ b/renumics/spotlight/embeddings/embedders/dummy.py @@ -1,4 +1,5 @@ from typing import Any +import time import numpy as np @@ -17,4 +18,5 @@ def __init__(self, data_store: Any, column: str) -> None: self._column = column def __call__(self) -> np.ndarray: + time.sleep(10) return np.random.random((len(self._data_store), 4)) diff --git a/renumics/spotlight_plugins/core/api/table.py b/renumics/spotlight_plugins/core/api/table.py index e55d2a8e..84840315 100644 --- a/renumics/spotlight_plugins/core/api/table.py +++ b/renumics/spotlight_plugins/core/api/table.py @@ -8,7 +8,11 @@ from fastapi.responses import ORJSONResponse, Response from pydantic import BaseModel -from renumics.spotlight.backend.exceptions import FilebrowsingNotAllowed, InvalidPath +from renumics.spotlight.backend.exceptions import ( + ComputedColumnNotReady, + FilebrowsingNotAllowed, + InvalidPath, +) from renumics.spotlight.app import SpotlightApp from renumics.spotlight.app_config import AppConfig from renumics.spotlight.io.path import is_path_relative_to @@ -25,7 +29,7 @@ class Column(BaseModel): optional: bool hidden: bool dtype: Any - values: List[Any] + values: Optional[List[Any]] description: Optional[str] tags: Optional[List[str]] computed: bool @@ -72,7 +76,12 @@ def get_table(request: Request) -> ORJSONResponse: columns = [] for column_name in data_store.column_names: dtype = data_store.dtypes[column_name] - values = data_store.get_converted_values(column_name, simple=True, check=False) + try: + values = data_store.get_converted_values( + column_name, simple=True, check=False + ) + except ComputedColumnNotReady: + values = None meta = data_store.get_column_metadata(column_name) column = Column( name=column_name, @@ -114,7 +123,7 @@ async def get_table_column( return ORJSONResponse(None) data_store.check_generation_id(generation_id) - values = data_store.get_converted_values(column, simple=False) + 
values = data_store.get_converted_values(column, simple=True, check=False) return ORJSONResponse(values) diff --git a/src/client/models/Column.ts b/src/client/models/Column.ts index ffe02507..7ca50307 100644 --- a/src/client/models/Column.ts +++ b/src/client/models/Column.ts @@ -54,7 +54,7 @@ export interface Column { * @type {Array} * @memberof Column */ - values: Array; + values: Array | null; /** * * @type {string} diff --git a/src/dataformat/index.ts b/src/dataformat/index.ts index 61e99e50..fcbd4706 100644 --- a/src/dataformat/index.ts +++ b/src/dataformat/index.ts @@ -28,7 +28,7 @@ export class Formatter { format(value: any, type: DataType, full = false): string { // format a single value by it's DataType (usually taken from column.type) - if (value === null) { + if (value === null || value === undefined) { return ''; } diff --git a/src/lenses/LensFactory.tsx b/src/lenses/LensFactory.tsx index 5d2e0790..f8117d4f 100644 --- a/src/lenses/LensFactory.tsx +++ b/src/lenses/LensFactory.tsx @@ -97,7 +97,7 @@ const LensFactory: FunctionComponent = ({ const allEditable = columns.every((c) => c.editable); if (problem) { - return Failed to load value!; + return {problem.title}; } if (!values || !urls) return ; if (!LensComponent) return View not found ({view})!; diff --git a/src/lenses/useCellValues.ts b/src/lenses/useCellValues.ts index 7051e986..e7b0221c 100644 --- a/src/lenses/useCellValues.ts +++ b/src/lenses/useCellValues.ts @@ -6,15 +6,29 @@ import { shallow } from 'zustand/shallow'; import { usePrevious } from '../hooks'; async function fetchValue(row: number, column: string, raw: boolean) { - const response = await api.table.getCellRaw({ - row, - column, - generationId: useDataset.getState().generationID, - }); - if (raw) { - return response.raw.arrayBuffer(); - } else { - return response.value(); + try { + const response = await api.table.getCellRaw({ + row, + column, + generationId: useDataset.getState().generationID, + }); + if (raw) { + return response.raw.arrayBuffer(); + } else { + return response.value(); + } + // eslint-disable-next-line @typescript-eslint/no-explicit-any + } catch (error: any) { + if (error.response?.json) { + throw await error.response.json(); + } else { + const problem: Problem = { + type: 'FailedToLoadValue', + title: 'Failed to load value', + detail: error.toString?.(), + }; + throw problem; + } } } @@ -27,7 +41,7 @@ function useCellValues( ): [unknown[] | undefined, Problem | undefined] { const cellsSelector = useCallback( (d: Dataset) => { - return columnKeys.map((key) => d.columnData[key][rowIndex]); + return columnKeys.map((key) => d.columnData[key]?.[rowIndex]); }, [rowIndex, columnKeys] ); @@ -92,7 +106,6 @@ function useCellValues( }) .catch((error) => { if (!cancelledRef.current) { - console.error(error); setProblem(error); } }); diff --git a/src/stores/dataset/columnFactory.ts b/src/stores/dataset/columnFactory.ts index 25097deb..4707dc40 100644 --- a/src/stores/dataset/columnFactory.ts +++ b/src/stores/dataset/columnFactory.ts @@ -77,6 +77,7 @@ export function makeColumn(column: Column, index: number): DataColumn { type: makeDatatype(column.dtype, column.optional), editable: column.editable, optional: column.optional, + computed: column.computed, hidden: column.hidden, description: column.description ?? 
'', tags: _.uniq(column.tags), diff --git a/src/stores/dataset/dataset.ts b/src/stores/dataset/dataset.ts index 7463da7b..6b639b83 100644 --- a/src/stores/dataset/dataset.ts +++ b/src/stores/dataset/dataset.ts @@ -67,6 +67,7 @@ export interface Dataset { lastFocusedRow?: number; // the last row that has been focused by a view openTable: (path: string) => void; //open the table file at path fetch: () => void; // fetch the dataset from the backend + refetchColumnValues: (columnKey: string) => void; // refetch values for a single column (after update/computation) fetchIssues: () => void; // fetch the ready issues refresh: () => void; // refresh the dataset from the backend addFilter: (filter: Filter) => void; // add a new filter @@ -135,7 +136,7 @@ const fetchTable = async (): Promise<{ const columnData: TableData = {}; table.columns.forEach((rawColumn, i) => { const dsColumn = columns[i]; - if (rawColumn.values === undefined) { + if (rawColumn.values === undefined || rawColumn.values === null) { return; } @@ -243,7 +244,7 @@ export const useDataset = create( filtered: {}, }; - set(() => ({ + set({ uid, generationID, filename, @@ -253,6 +254,37 @@ export const useDataset = create( columnsByKey: _.keyBy(dataframe.columns, 'key'), columnData: dataframe.data, columnStats, + }); + }, + refetchColumnValues: async (columnKey) => { + const column = get().columnsByKey[columnKey]; + + let rawValues = null; + try { + rawValues = await api.table.getColumn({ + column: columnKey, + generationId: get().generationID, + }); + } catch (error) { + notifyAPIError(error); + return; + } + let values = rawValues.map((value: unknown) => + convertValue(value, column.type) + ); + + switch (column.type.kind) { + case 'int': + case 'Category': + values = Int32Array.from(values); + break; + case 'float': + values = Float32Array.from(values); + break; + } + + set(({ columnData }) => ({ + columnData: { ...columnData, [columnKey]: values }, })); }, fetchIssues: async () => { @@ -562,6 +594,12 @@ websocketService.registerMessageHandler('issuesUpdated', () => { useDataset.getState().fetchIssues(); }); +websocketService.registerMessageHandler('columnsUpdated', (columnKeys: string[]) => { + for (const columnKey of columnKeys) { + useDataset.getState().refetchColumnValues(columnKey); + } +}); + useDataset.subscribe( (state) => state.columns, (columns) => { diff --git a/src/types/dataset.ts b/src/types/dataset.ts index 7e4b6801..805dabae 100644 --- a/src/types/dataset.ts +++ b/src/types/dataset.ts @@ -8,6 +8,7 @@ export interface DataColumn { type: datatypes.DataType; editable: boolean; optional: boolean; + computed: boolean; hidden: boolean; description: string; tags: string[]; diff --git a/src/widgets/DataGrid/hooks/useCellValue.ts b/src/widgets/DataGrid/hooks/useCellValue.ts index 33450a45..0786fdfb 100644 --- a/src/widgets/DataGrid/hooks/useCellValue.ts +++ b/src/widgets/DataGrid/hooks/useCellValue.ts @@ -4,7 +4,7 @@ import useSort from './useSort'; // eslint-disable-next-line @typescript-eslint/no-explicit-any function useCellValue(columnKey: string, rowIndex: number): any { const originalIndex = useSort().getOriginalIndex(rowIndex); - return useDataset((d: Dataset) => d.columnData[columnKey][originalIndex]); + return useDataset((d: Dataset) => d.columnData[columnKey]?.[originalIndex]); } export default useCellValue; From 2880e56b0eed06fa0c9373a178291637ad219c36 Mon Sep 17 00:00:00 2001 From: Dominik Haentsch Date: Wed, 29 Nov 2023 14:19:06 +0100 Subject: [PATCH 04/24] feat: delay fetching of cell values until column is 
computed --- src/lenses/useCellValues.ts | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/src/lenses/useCellValues.ts b/src/lenses/useCellValues.ts index e7b0221c..75882413 100644 --- a/src/lenses/useCellValues.ts +++ b/src/lenses/useCellValues.ts @@ -1,3 +1,4 @@ +import _ from 'lodash'; import { useCallback, useEffect, useRef, useState } from 'react'; import { Dataset, useDataset } from '../stores/dataset'; import { Problem } from '../types'; @@ -57,7 +58,17 @@ function useCellValues( const cellEntries = useDataset(cellsSelector, shallow); const columns = useDataset(columnsSelector, shallow); const generationId = useDataset((d) => d.generationID); - const previousGenerationId = usePrevious(generationId); + + const isAnyColumnComputing = useDataset((d) => + _.some( + columnKeys, + (key) => d.columnsByKey[key].computed && d.columnData[key] === undefined + ) + ); + + const previousGenerationId = usePrevious( + isAnyColumnComputing ? undefined : generationId + ); const [values, setValues] = useState(); const [problem, setProblem] = useState(); @@ -66,6 +77,7 @@ function useCellValues( const promisesRef = useRef>>({}); const cancelledRef = useRef(false); + useEffect(() => { // reset cancelled for StrictMode in dev cancelledRef.current = false; @@ -75,6 +87,8 @@ function useCellValues( }, []); useEffect(() => { + if (isAnyColumnComputing) return; + const promises = promisesRef.current; if (generationId !== previousGenerationId) { @@ -111,6 +125,7 @@ function useCellValues( }); } }, [ + isAnyColumnComputing, cellEntries, columnKeys, columns, From a7b53902dd79132a01cd281fad9e13047c2a885b Mon Sep 17 00:00:00 2001 From: Dominik Haentsch Date: Wed, 29 Nov 2023 15:09:18 +0100 Subject: [PATCH 05/24] feat: delay umap call until embeddings are ready --- src/widgets/SimilarityMap/SimilarityMap.tsx | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/src/widgets/SimilarityMap/SimilarityMap.tsx b/src/widgets/SimilarityMap/SimilarityMap.tsx index 9bd9c55f..a377f4e4 100644 --- a/src/widgets/SimilarityMap/SimilarityMap.tsx +++ b/src/widgets/SimilarityMap/SimilarityMap.tsx @@ -273,6 +273,13 @@ const SimilarityMap: Widget = () => { const widgetId = useMemo(() => uuidv4(), []); + const anyColumnComputing = useDataset((d) => + _.some( + placeByColumnKeys, + (key) => d.columnsByKey[key]?.computed && d.columnData[key] === undefined + ) + ); + useEffect(() => { setVisibleIndices([]); setPositions([]); @@ -282,7 +289,11 @@ const SimilarityMap: Widget = () => { setIsComputing(false); return; } + setIsComputing(true); + if (anyColumnComputing) { + return; + } const reductionPromise = reductionMethod === 'umap' @@ -327,6 +338,7 @@ const SimilarityMap: Widget = () => { umapMetric, umapMinDist, pcaNormalization, + anyColumnComputing, ]); const getOriginalIndices = useCallback( From a277b357530f7eb5dfdf0b593eed4ca3af7319bf Mon Sep 17 00:00:00 2001 From: Dominik Haentsch Date: Wed, 29 Nov 2023 15:23:10 +0100 Subject: [PATCH 06/24] feat: show specific loading message while waiting for computed embeddings --- src/widgets/SimilarityMap/SimilarityMap.tsx | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/src/widgets/SimilarityMap/SimilarityMap.tsx b/src/widgets/SimilarityMap/SimilarityMap.tsx index a377f4e4..275ebf3e 100644 --- a/src/widgets/SimilarityMap/SimilarityMap.tsx +++ b/src/widgets/SimilarityMap/SimilarityMap.tsx @@ -36,6 +36,7 @@ import TooltipContent from './TooltipContent'; import { ReductionMethod } from './types'; import Info from 
'../../components/ui/Info'; import { unknownDataType } from '../../datatypes'; +import { Spinner } from '../../lib'; const MapContainer = styled.div` ${tw`bg-gray-100 border-gray-400 w-full h-full overflow-hidden`} @@ -103,6 +104,7 @@ const SimilarityMap: Widget = () => { >('pcaNormalization', 'none'); const [isComputing, setIsComputing] = useState(false); + const [loadingMessage, setLoadingMessage] = useState(''); const { fullColumns, @@ -290,11 +292,15 @@ const SimilarityMap: Widget = () => { return; } - setIsComputing(true); if (anyColumnComputing) { + setIsComputing(true); + setLoadingMessage(`Computing embeddings`); return; } + setIsComputing(true); + setLoadingMessage(`Computing ${reductionMethod}`); + const reductionPromise = reductionMethod === 'umap' ? dataService.computeUmap( @@ -473,7 +479,14 @@ const SimilarityMap: Widget = () => { ); } else if (isComputing) { - content = ; + content = ( + +
+                <Spinner />
+                <p>{loadingMessage}</p>
+
+
+ ); } else if (!areColumnsSelected) { content = ( From 415e687f856bf185f2907cadc34fbcea7c07a78e Mon Sep 17 00:00:00 2001 From: Dominik Haentsch Date: Wed, 29 Nov 2023 16:23:34 +0100 Subject: [PATCH 07/24] chore: remove unnecessary import --- src/widgets/SimilarityMap/SimilarityMap.tsx | 1 - 1 file changed, 1 deletion(-) diff --git a/src/widgets/SimilarityMap/SimilarityMap.tsx b/src/widgets/SimilarityMap/SimilarityMap.tsx index 275ebf3e..096115c3 100644 --- a/src/widgets/SimilarityMap/SimilarityMap.tsx +++ b/src/widgets/SimilarityMap/SimilarityMap.tsx @@ -1,5 +1,4 @@ import SimilaritiesIcon from '../../icons/Bubbles'; -import LoadingIndicator from '../../components/LoadingIndicator'; import Plot, { MergeStrategy, Points, From 117371149e56eef8b378d2c02de5dedb0c8d85f2 Mon Sep 17 00:00:00 2001 From: Alexander Druz Date: Thu, 30 Nov 2023 09:38:27 +0100 Subject: [PATCH 08/24] Implement ViT embedder --- poetry.lock | 100 ++++++++++++------ pyproject.toml | 11 +- renumics/spotlight/app.py | 3 + renumics/spotlight/data_store.py | 33 ++++-- renumics/spotlight/embeddings/__init__.py | 10 +- .../spotlight/embeddings/embedders/dummy.py | 22 ---- .../spotlight/embeddings/embedders/vit.py | 68 ++++++++++++ renumics/spotlight_plugins/core/api/table.py | 1 + 8 files changed, 170 insertions(+), 78 deletions(-) delete mode 100644 renumics/spotlight/embeddings/embedders/dummy.py create mode 100644 renumics/spotlight/embeddings/embedders/vit.py diff --git a/poetry.lock b/poetry.lock index 5d2f2b1b..edc589cb 100644 --- a/poetry.lock +++ b/poetry.lock @@ -3404,6 +3404,43 @@ importlib-metadata = {version = "*", markers = "python_version < \"3.9\""} llvmlite = "==0.40.*" numpy = ">=1.21,<1.25" +[[package]] +name = "numpy" +version = "1.24.3" +description = "Fundamental package for array computing in Python" +optional = false +python-versions = ">=3.8" +files = [ + {file = "numpy-1.24.3-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:3c1104d3c036fb81ab923f507536daedc718d0ad5a8707c6061cdfd6d184e570"}, + {file = "numpy-1.24.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:202de8f38fc4a45a3eea4b63e2f376e5f2dc64ef0fa692838e31a808520efaf7"}, + {file = "numpy-1.24.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8535303847b89aa6b0f00aa1dc62867b5a32923e4d1681a35b5eef2d9591a463"}, + {file = "numpy-1.24.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2d926b52ba1367f9acb76b0df6ed21f0b16a1ad87c6720a1121674e5cf63e2b6"}, + {file = "numpy-1.24.3-cp310-cp310-win32.whl", hash = "sha256:f21c442fdd2805e91799fbe044a7b999b8571bb0ab0f7850d0cb9641a687092b"}, + {file = "numpy-1.24.3-cp310-cp310-win_amd64.whl", hash = "sha256:ab5f23af8c16022663a652d3b25dcdc272ac3f83c3af4c02eb8b824e6b3ab9d7"}, + {file = "numpy-1.24.3-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:9a7721ec204d3a237225db3e194c25268faf92e19338a35f3a224469cb6039a3"}, + {file = "numpy-1.24.3-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:d6cc757de514c00b24ae8cf5c876af2a7c3df189028d68c0cb4eaa9cd5afc2bf"}, + {file = "numpy-1.24.3-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:76e3f4e85fc5d4fd311f6e9b794d0c00e7002ec122be271f2019d63376f1d385"}, + {file = "numpy-1.24.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a1d3c026f57ceaad42f8231305d4653d5f05dc6332a730ae5c0bea3513de0950"}, + {file = "numpy-1.24.3-cp311-cp311-win32.whl", hash = "sha256:c91c4afd8abc3908e00a44b2672718905b8611503f7ff87390cc0ac3423fb096"}, + {file = 
"numpy-1.24.3-cp311-cp311-win_amd64.whl", hash = "sha256:5342cf6aad47943286afa6f1609cad9b4266a05e7f2ec408e2cf7aea7ff69d80"}, + {file = "numpy-1.24.3-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:7776ea65423ca6a15255ba1872d82d207bd1e09f6d0894ee4a64678dd2204078"}, + {file = "numpy-1.24.3-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:ae8d0be48d1b6ed82588934aaaa179875e7dc4f3d84da18d7eae6eb3f06c242c"}, + {file = "numpy-1.24.3-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ecde0f8adef7dfdec993fd54b0f78183051b6580f606111a6d789cd14c61ea0c"}, + {file = "numpy-1.24.3-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4749e053a29364d3452c034827102ee100986903263e89884922ef01a0a6fd2f"}, + {file = "numpy-1.24.3-cp38-cp38-win32.whl", hash = "sha256:d933fabd8f6a319e8530d0de4fcc2e6a61917e0b0c271fded460032db42a0fe4"}, + {file = "numpy-1.24.3-cp38-cp38-win_amd64.whl", hash = "sha256:56e48aec79ae238f6e4395886b5eaed058abb7231fb3361ddd7bfdf4eed54289"}, + {file = "numpy-1.24.3-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:4719d5aefb5189f50887773699eaf94e7d1e02bf36c1a9d353d9f46703758ca4"}, + {file = "numpy-1.24.3-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:0ec87a7084caa559c36e0a2309e4ecb1baa03b687201d0a847c8b0ed476a7187"}, + {file = "numpy-1.24.3-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ea8282b9bcfe2b5e7d491d0bf7f3e2da29700cec05b49e64d6246923329f2b02"}, + {file = "numpy-1.24.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:210461d87fb02a84ef243cac5e814aad2b7f4be953b32cb53327bb49fd77fbb4"}, + {file = "numpy-1.24.3-cp39-cp39-win32.whl", hash = "sha256:784c6da1a07818491b0ffd63c6bbe5a33deaa0e25a20e1b3ea20cf0e43f8046c"}, + {file = "numpy-1.24.3-cp39-cp39-win_amd64.whl", hash = "sha256:d5036197ecae68d7f491fcdb4df90082b0d4960ca6599ba2659957aafced7c17"}, + {file = "numpy-1.24.3-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:352ee00c7f8387b44d19f4cada524586f07379c0d49270f87233983bc5087ca0"}, + {file = "numpy-1.24.3-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1a7d6acc2e7524c9955e5c903160aa4ea083736fde7e91276b0e5d98e6332812"}, + {file = "numpy-1.24.3-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:35400e6a8d102fd07c71ed7dcadd9eb62ee9a6e84ec159bd48c28235bbb0f8e4"}, + {file = "numpy-1.24.3.tar.gz", hash = "sha256:ab344f1bf21f140adab8e47fdbc7c35a477dc01408791f8ba00d018dd0bc5155"}, +] + [[package]] name = "numpy" version = "1.24.4" @@ -3576,8 +3613,8 @@ files = [ [package.dependencies] numpy = [ {version = ">=1.20.3", markers = "python_version < \"3.10\""}, - {version = ">=1.21.0", markers = "python_version >= \"3.10\" and python_version < \"3.11\""}, {version = ">=1.23.2", markers = "python_version >= \"3.11\""}, + {version = ">=1.21.0", markers = "python_version >= \"3.10\" and python_version < \"3.11\""}, ] python-dateutil = ">=2.8.2" pytz = ">=2020.1" @@ -3633,6 +3670,7 @@ files = [ ] [package.dependencies] +numpy = {version = "<=1.24.3", markers = "python_full_version <= \"3.8.0\""} types-pytz = ">=2022.1.1" [[package]] @@ -5903,31 +5941,19 @@ files = [ [[package]] name = "torch" -version = "2.1.0" +version = "2.1.1+cpu" description = "Tensors and Dynamic neural networks in Python with strong GPU acceleration" optional = false python-versions = ">=3.8.0" files = [ - {file = "torch-2.1.0-cp310-cp310-manylinux1_x86_64.whl", hash = "sha256:bf57f8184b2c317ef81fb33dc233ce4d850cd98ef3f4a38be59c7c1572d175db"}, - {file = 
"torch-2.1.0-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:a04a0296d47f28960f51c18c5489a8c3472f624ec3b5bcc8e2096314df8c3342"}, - {file = "torch-2.1.0-cp310-cp310-win_amd64.whl", hash = "sha256:0bd691efea319b14ef239ede16d8a45c246916456fa3ed4f217d8af679433cc6"}, - {file = "torch-2.1.0-cp310-none-macosx_10_9_x86_64.whl", hash = "sha256:101c139152959cb20ab370fc192672c50093747906ee4ceace44d8dd703f29af"}, - {file = "torch-2.1.0-cp310-none-macosx_11_0_arm64.whl", hash = "sha256:a6b7438a90a870e4cdeb15301519ae6c043c883fcd224d303c5b118082814767"}, - {file = "torch-2.1.0-cp311-cp311-manylinux1_x86_64.whl", hash = "sha256:2224622407ca52611cbc5b628106fde22ed8e679031f5a99ce286629fc696128"}, - {file = "torch-2.1.0-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:8132efb782cd181cc2dcca5e58effbe4217cdb2581206ac71466d535bf778867"}, - {file = "torch-2.1.0-cp311-cp311-win_amd64.whl", hash = "sha256:5c3bfa91ce25ba10116c224c59d5b64cdcce07161321d978bd5a1f15e1ebce72"}, - {file = "torch-2.1.0-cp311-none-macosx_10_9_x86_64.whl", hash = "sha256:601b0a2a9d9233fb4b81f7d47dca9680d4f3a78ca3f781078b6ad1ced8a90523"}, - {file = "torch-2.1.0-cp311-none-macosx_11_0_arm64.whl", hash = "sha256:3cd1dedff13884d890f18eea620184fb4cd8fd3c68ce3300498f427ae93aa962"}, - {file = "torch-2.1.0-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:fb7bf0cc1a3db484eb5d713942a93172f3bac026fcb377a0cd107093d2eba777"}, - {file = "torch-2.1.0-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:761822761fffaa1c18a62c5deb13abaa780862577d3eadc428f1daa632536905"}, - {file = "torch-2.1.0-cp38-cp38-win_amd64.whl", hash = "sha256:458a6d6d8f7d2ccc348ac4d62ea661b39a3592ad15be385bebd0a31ced7e00f4"}, - {file = "torch-2.1.0-cp38-none-macosx_10_9_x86_64.whl", hash = "sha256:c8bf7eaf9514465e5d9101e05195183470a6215bb50295c61b52302a04edb690"}, - {file = "torch-2.1.0-cp38-none-macosx_11_0_arm64.whl", hash = "sha256:05661c32ec14bc3a157193d0f19a7b19d8e61eb787b33353cad30202c295e83b"}, - {file = "torch-2.1.0-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:556d8dd3e0c290ed9d4d7de598a213fb9f7c59135b4fee144364a8a887016a55"}, - {file = "torch-2.1.0-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:de7d63c6ecece118684415a3dbd4805af4a4c1ee1490cccf7405d8c240a481b4"}, - {file = "torch-2.1.0-cp39-cp39-win_amd64.whl", hash = "sha256:2419cf49aaf3b2336c7aa7a54a1b949fa295b1ae36f77e2aecb3a74e3a947255"}, - {file = "torch-2.1.0-cp39-none-macosx_10_9_x86_64.whl", hash = "sha256:6ad491e70dbe4288d17fdbfc7fbfa766d66cbe219bc4871c7a8096f4a37c98df"}, - {file = "torch-2.1.0-cp39-none-macosx_11_0_arm64.whl", hash = "sha256:421739685eba5e0beba42cb649740b15d44b0d565c04e6ed667b41148734a75b"}, + {file = "torch-2.1.1+cpu-cp310-cp310-linux_x86_64.whl", hash = "sha256:16e6cdb1991ff1f15f469b12fffb5b20f00b1ecff8c1073c23c7760fba90fcbc"}, + {file = "torch-2.1.1+cpu-cp310-cp310-win_amd64.whl", hash = "sha256:8e4e47aa4a004d2f5edd516b296da6fba07981ae9d1fc0da862abf651dc07d96"}, + {file = "torch-2.1.1+cpu-cp311-cp311-linux_x86_64.whl", hash = "sha256:d83b13cb17544f9851cc31fed197865eae0c0f5d32df9d8d6d8535df7d2e5109"}, + {file = "torch-2.1.1+cpu-cp311-cp311-win_amd64.whl", hash = "sha256:23be0cb945970443c97d4f9ea61ed03b27f924d835de689dd4134f30966c13f7"}, + {file = "torch-2.1.1+cpu-cp38-cp38-linux_x86_64.whl", hash = "sha256:9399a8dfe4833bb544e7a0c332c47db1e389a06b4ae5ac9b3b167d863adc95d9"}, + {file = "torch-2.1.1+cpu-cp38-cp38-win_amd64.whl", hash = "sha256:d8e19bb465e1fa15f3231ff57bbf0a673e5dcaca39eee4f58a4be7832b80c9da"}, + {file = "torch-2.1.1+cpu-cp39-cp39-linux_x86_64.whl", hash = 
"sha256:b2cc98815251f8a2d102c2f8f4afe8304c2df61ce9c237198032c7903d97fdbb"}, + {file = "torch-2.1.1+cpu-cp39-cp39-win_amd64.whl", hash = "sha256:b09d431c8e53511b6c3624f1d79cce9cd28f4c27ca20c5a49e4dcc1f9e7377e5"}, ] [package.dependencies] @@ -5939,8 +5965,14 @@ sympy = "*" typing-extensions = "*" [package.extras] +dynamo = ["jinja2"] opt-einsum = ["opt-einsum (>=3.3)"] +[package.source] +type = "legacy" +url = "https://download.pytorch.org/whl/cpu" +reference = "torch-cpu" + [[package]] name = "tornado" version = "6.3.3" @@ -6018,13 +6050,13 @@ test = ["argcomplete (>=3.0.3)", "mypy (>=1.6.0)", "pre-commit", "pytest (>=7.0, [[package]] name = "transformers" -version = "4.35.0" +version = "4.35.2" description = "State-of-the-art Machine Learning for JAX, PyTorch and TensorFlow" optional = false python-versions = ">=3.8.0" files = [ - {file = "transformers-4.35.0-py3-none-any.whl", hash = "sha256:45aa9370d7d9ba1c43e6bfa04d7f8b61238497d4b646e573fd95e597fe4040ff"}, - {file = "transformers-4.35.0.tar.gz", hash = "sha256:e4b41763f651282fc979348d3aa148244387ddc9165f4b18455798c770ae23b9"}, + {file = "transformers-4.35.2-py3-none-any.whl", hash = "sha256:9dfa76f8692379544ead84d98f537be01cd1070de75c74efb13abcbc938fbe2f"}, + {file = "transformers-4.35.2.tar.gz", hash = "sha256:2d125e197d77b0cdb6c9201df9fa7e2101493272e448b9fba9341c695bee2f52"}, ] [package.dependencies] @@ -6036,21 +6068,21 @@ pyyaml = ">=5.1" regex = "!=2019.12.17" requests = "*" safetensors = ">=0.3.1" -tokenizers = ">=0.14,<0.15" +tokenizers = ">=0.14,<0.19" tqdm = ">=4.27" [package.extras] accelerate = ["accelerate (>=0.20.3)"] agents = ["Pillow (<10.0.0)", "accelerate (>=0.20.3)", "datasets (!=2.5.0)", "diffusers", "opencv-python", "sentencepiece (>=0.1.91,!=0.1.92)", "torch (>=1.10,!=1.12.0)"] -all = ["Pillow (<10.0.0)", "accelerate (>=0.20.3)", "av (==9.2.0)", "codecarbon (==1.2.0)", "decord (==0.6.0)", "flax (>=0.4.1,<=0.7.0)", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "kenlm", "keras-nlp (>=0.3.1)", "librosa", "onnxconverter-common", "optax (>=0.0.8,<=0.1.4)", "optuna", "phonemizer", "protobuf", "pyctcdecode (>=0.4.0)", "ray[tune]", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "tensorflow (>=2.6,<2.15)", "tensorflow-text (<2.15)", "tf2onnx", "timm", "tokenizers (>=0.14,<0.15)", "torch (>=1.10,!=1.12.0)", "torchaudio", "torchvision"] +all = ["Pillow (<10.0.0)", "accelerate (>=0.20.3)", "av (==9.2.0)", "codecarbon (==1.2.0)", "decord (==0.6.0)", "flax (>=0.4.1,<=0.7.0)", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "kenlm", "keras-nlp (>=0.3.1)", "librosa", "onnxconverter-common", "optax (>=0.0.8,<=0.1.4)", "optuna", "phonemizer", "protobuf", "pyctcdecode (>=0.4.0)", "ray[tune]", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "tensorflow (>=2.6,<2.15)", "tensorflow-text (<2.15)", "tf2onnx", "timm", "tokenizers (>=0.14,<0.19)", "torch (>=1.10,!=1.12.0)", "torchaudio", "torchvision"] audio = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)"] codecarbon = ["codecarbon (==1.2.0)"] deepspeed = ["accelerate (>=0.20.3)", "deepspeed (>=0.9.3)"] deepspeed-testing = ["GitPython (<3.1.19)", "accelerate (>=0.20.3)", "beautifulsoup4", "black (>=23.1,<24.0)", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "deepspeed (>=0.9.3)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "hf-doc-builder (>=0.3.0)", "nltk", "optuna", "parameterized", "protobuf", "psutil", "pytest (>=7.2.0)", "pytest-timeout", "pytest-xdist", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "sacrebleu (>=1.4.12,<2.0.0)", 
"sacremoses", "sentencepiece (>=0.1.91,!=0.1.92)", "tensorboard", "timeout-decorator"] -dev = ["GitPython (<3.1.19)", "Pillow (<10.0.0)", "accelerate (>=0.20.3)", "av (==9.2.0)", "beautifulsoup4", "black (>=23.1,<24.0)", "codecarbon (==1.2.0)", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "decord (==0.6.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "flax (>=0.4.1,<=0.7.0)", "fugashi (>=1.0)", "hf-doc-builder", "hf-doc-builder (>=0.3.0)", "ipadic (>=1.0.0,<2.0)", "isort (>=5.5.4)", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "kenlm", "keras-nlp (>=0.3.1)", "librosa", "nltk", "onnxconverter-common", "optax (>=0.0.8,<=0.1.4)", "optuna", "parameterized", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pytest (>=7.2.0)", "pytest-timeout", "pytest-xdist", "ray[tune]", "rhoknp (>=1.1.0,<1.3.1)", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (>=0.0.241,<=0.0.259)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "sudachidict-core (>=20220729)", "sudachipy (>=0.6.6)", "tensorboard", "tensorflow (>=2.6,<2.15)", "tensorflow-text (<2.15)", "tf2onnx", "timeout-decorator", "timm", "tokenizers (>=0.14,<0.15)", "torch (>=1.10,!=1.12.0)", "torchaudio", "torchvision", "unidic (>=1.0.2)", "unidic-lite (>=1.0.7)", "urllib3 (<2.0.0)"] -dev-tensorflow = ["GitPython (<3.1.19)", "Pillow (<10.0.0)", "beautifulsoup4", "black (>=23.1,<24.0)", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "hf-doc-builder", "hf-doc-builder (>=0.3.0)", "isort (>=5.5.4)", "kenlm", "keras-nlp (>=0.3.1)", "librosa", "nltk", "onnxconverter-common", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)", "parameterized", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pytest (>=7.2.0)", "pytest-timeout", "pytest-xdist", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (>=0.0.241,<=0.0.259)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "sentencepiece (>=0.1.91,!=0.1.92)", "tensorboard", "tensorflow (>=2.6,<2.15)", "tensorflow-text (<2.15)", "tf2onnx", "timeout-decorator", "tokenizers (>=0.14,<0.15)", "urllib3 (<2.0.0)"] -dev-torch = ["GitPython (<3.1.19)", "Pillow (<10.0.0)", "accelerate (>=0.20.3)", "beautifulsoup4", "black (>=23.1,<24.0)", "codecarbon (==1.2.0)", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "fugashi (>=1.0)", "hf-doc-builder", "hf-doc-builder (>=0.3.0)", "ipadic (>=1.0.0,<2.0)", "isort (>=5.5.4)", "kenlm", "librosa", "nltk", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)", "optuna", "parameterized", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pytest (>=7.2.0)", "pytest-timeout", "pytest-xdist", "ray[tune]", "rhoknp (>=1.1.0,<1.3.1)", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (>=0.0.241,<=0.0.259)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "sudachidict-core (>=20220729)", "sudachipy (>=0.6.6)", "tensorboard", "timeout-decorator", "timm", "tokenizers (>=0.14,<0.15)", "torch (>=1.10,!=1.12.0)", "torchaudio", "torchvision", "unidic (>=1.0.2)", "unidic-lite (>=1.0.7)", "urllib3 (<2.0.0)"] -docs = ["Pillow (<10.0.0)", "accelerate (>=0.20.3)", "av (==9.2.0)", "codecarbon (==1.2.0)", "decord (==0.6.0)", "flax (>=0.4.1,<=0.7.0)", "hf-doc-builder", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "kenlm", "keras-nlp (>=0.3.1)", "librosa", 
"onnxconverter-common", "optax (>=0.0.8,<=0.1.4)", "optuna", "phonemizer", "protobuf", "pyctcdecode (>=0.4.0)", "ray[tune]", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "tensorflow (>=2.6,<2.15)", "tensorflow-text (<2.15)", "tf2onnx", "timm", "tokenizers (>=0.14,<0.15)", "torch (>=1.10,!=1.12.0)", "torchaudio", "torchvision"] +dev = ["GitPython (<3.1.19)", "Pillow (<10.0.0)", "accelerate (>=0.20.3)", "av (==9.2.0)", "beautifulsoup4", "black (>=23.1,<24.0)", "codecarbon (==1.2.0)", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "decord (==0.6.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "flax (>=0.4.1,<=0.7.0)", "fugashi (>=1.0)", "hf-doc-builder", "hf-doc-builder (>=0.3.0)", "ipadic (>=1.0.0,<2.0)", "isort (>=5.5.4)", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "kenlm", "keras-nlp (>=0.3.1)", "librosa", "nltk", "onnxconverter-common", "optax (>=0.0.8,<=0.1.4)", "optuna", "parameterized", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pytest (>=7.2.0)", "pytest-timeout", "pytest-xdist", "ray[tune]", "rhoknp (>=1.1.0,<1.3.1)", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (>=0.0.241,<=0.0.259)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "sudachidict-core (>=20220729)", "sudachipy (>=0.6.6)", "tensorboard", "tensorflow (>=2.6,<2.15)", "tensorflow-text (<2.15)", "tf2onnx", "timeout-decorator", "timm", "tokenizers (>=0.14,<0.19)", "torch (>=1.10,!=1.12.0)", "torchaudio", "torchvision", "unidic (>=1.0.2)", "unidic-lite (>=1.0.7)", "urllib3 (<2.0.0)"] +dev-tensorflow = ["GitPython (<3.1.19)", "Pillow (<10.0.0)", "beautifulsoup4", "black (>=23.1,<24.0)", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "hf-doc-builder", "hf-doc-builder (>=0.3.0)", "isort (>=5.5.4)", "kenlm", "keras-nlp (>=0.3.1)", "librosa", "nltk", "onnxconverter-common", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)", "parameterized", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pytest (>=7.2.0)", "pytest-timeout", "pytest-xdist", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (>=0.0.241,<=0.0.259)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "sentencepiece (>=0.1.91,!=0.1.92)", "tensorboard", "tensorflow (>=2.6,<2.15)", "tensorflow-text (<2.15)", "tf2onnx", "timeout-decorator", "tokenizers (>=0.14,<0.19)", "urllib3 (<2.0.0)"] +dev-torch = ["GitPython (<3.1.19)", "Pillow (<10.0.0)", "accelerate (>=0.20.3)", "beautifulsoup4", "black (>=23.1,<24.0)", "codecarbon (==1.2.0)", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "fugashi (>=1.0)", "hf-doc-builder", "hf-doc-builder (>=0.3.0)", "ipadic (>=1.0.0,<2.0)", "isort (>=5.5.4)", "kenlm", "librosa", "nltk", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)", "optuna", "parameterized", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pytest (>=7.2.0)", "pytest-timeout", "pytest-xdist", "ray[tune]", "rhoknp (>=1.1.0,<1.3.1)", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (>=0.0.241,<=0.0.259)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "sudachidict-core (>=20220729)", "sudachipy (>=0.6.6)", "tensorboard", "timeout-decorator", "timm", "tokenizers (>=0.14,<0.19)", "torch (>=1.10,!=1.12.0)", "torchaudio", "torchvision", "unidic (>=1.0.2)", "unidic-lite (>=1.0.7)", "urllib3 (<2.0.0)"] +docs = ["Pillow 
(<10.0.0)", "accelerate (>=0.20.3)", "av (==9.2.0)", "codecarbon (==1.2.0)", "decord (==0.6.0)", "flax (>=0.4.1,<=0.7.0)", "hf-doc-builder", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "kenlm", "keras-nlp (>=0.3.1)", "librosa", "onnxconverter-common", "optax (>=0.0.8,<=0.1.4)", "optuna", "phonemizer", "protobuf", "pyctcdecode (>=0.4.0)", "ray[tune]", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "tensorflow (>=2.6,<2.15)", "tensorflow-text (<2.15)", "tf2onnx", "timm", "tokenizers (>=0.14,<0.19)", "torch (>=1.10,!=1.12.0)", "torchaudio", "torchvision"] docs-specific = ["hf-doc-builder"] flax = ["flax (>=0.4.1,<=0.7.0)", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "optax (>=0.0.8,<=0.1.4)"] flax-speech = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)"] @@ -6076,11 +6108,11 @@ tf = ["keras-nlp (>=0.3.1)", "onnxconverter-common", "tensorflow (>=2.6,<2.15)", tf-cpu = ["keras-nlp (>=0.3.1)", "onnxconverter-common", "tensorflow-cpu (>=2.6,<2.15)", "tensorflow-text (<2.15)", "tf2onnx"] tf-speech = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)"] timm = ["timm"] -tokenizers = ["tokenizers (>=0.14,<0.15)"] +tokenizers = ["tokenizers (>=0.14,<0.19)"] torch = ["accelerate (>=0.20.3)", "torch (>=1.10,!=1.12.0)"] torch-speech = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)", "torchaudio"] torch-vision = ["Pillow (<10.0.0)", "torchvision"] -torchhub = ["filelock", "huggingface-hub (>=0.16.4,<1.0)", "importlib-metadata", "numpy (>=1.17)", "packaging (>=20.0)", "protobuf", "regex (!=2019.12.17)", "requests", "sentencepiece (>=0.1.91,!=0.1.92)", "tokenizers (>=0.14,<0.15)", "torch (>=1.10,!=1.12.0)", "tqdm (>=4.27)"] +torchhub = ["filelock", "huggingface-hub (>=0.16.4,<1.0)", "importlib-metadata", "numpy (>=1.17)", "packaging (>=20.0)", "protobuf", "regex (!=2019.12.17)", "requests", "sentencepiece (>=0.1.91,!=0.1.92)", "tokenizers (>=0.14,<0.19)", "torch (>=1.10,!=1.12.0)", "tqdm (>=4.27)"] video = ["av (==9.2.0)", "decord (==0.6.0)"] vision = ["Pillow (<10.0.0)"] @@ -7061,4 +7093,4 @@ descriptors = ["pycatch22"] [metadata] lock-version = "2.0" python-versions = ">=3.8, <3.12" -content-hash = "27d2c03b47a7211f861566b88e5949dea1dc062c511b002b18dd4bbb3554ce11" +content-hash = "98b1beb638bb2f57aa8f5884392c03b1df6bc78c0d74cbcb2804280d8eb3bb17" diff --git a/pyproject.toml b/pyproject.toml index c6264f98..979da674 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -92,6 +92,8 @@ httpx = "^0.23.0" datasets = { extras = ["audio"], version = "^2.12.0" } pydantic-settings = "^2.0.3" pycatch22 = { version = "!=0.4.4", optional = true } +transformers = "^4.35.2" +torch = {version = "^2.1.1+cpu", source = "torch-cpu"} [tool.poetry.extras] descriptors = ["pycatch22"] @@ -127,13 +129,15 @@ ruff = "^0.0.281" check-wheel-contents = "^0.6.0" [tool.poetry.group.playbook.dependencies] -datasets = "^2.12.0" -transformers = "^4.29.2" -torch = "^2.0.1" towhee = "^0.9.0" annoy = "^1.17.2" cleanlab = "^2.4.0" +[[tool.poetry.source]] +name = "torch-cpu" +url = "https://download.pytorch.org/whl/cpu" +priority = "explicit" + [tool.poetry-dynamic-versioning] enable = true vcs = "git" @@ -175,6 +179,7 @@ module = [ "datasets", "diffimg", "tests.ui._autogenerated_ui_elements", + "transformers", ] ignore_missing_imports = true diff --git a/renumics/spotlight/app.py b/renumics/spotlight/app.py index 02ef2e87..c5c9b3c9 100644 --- a/renumics/spotlight/app.py +++ b/renumics/spotlight/app.py @@ -469,6 +469,8 @@ def _update_embeddings(self) -> None: if self._data_store is None: return + 
logger.info("Embedding started.") + embedders = create_embedders(self._data_store, self._data_store.column_names) self._data_store.embeddings = {column: None for column in embedders} @@ -484,6 +486,7 @@ def _on_embeddings_ready(future: Future) -> None: self._data_store.embeddings = future.result() except CancelledError: return + logger.info("Embedding done.") self._broadcast(ColumnsUpdatedMessage(data=list(embedders.keys()))) task.future.add_done_callback(_on_embeddings_ready) diff --git a/renumics/spotlight/data_store.py b/renumics/spotlight/data_store.py index af0f5bb7..3d80d3eb 100644 --- a/renumics/spotlight/data_store.py +++ b/renumics/spotlight/data_store.py @@ -67,15 +67,23 @@ def data_source(self) -> DataSource: @property def dtypes(self) -> spotlight_dtypes.DTypeMap: - return { - **self._dtypes, - **{ - column: spotlight_dtypes.EmbeddingDType( - length=None if embeddings is None else embeddings.shape[1] - ) - for column, embeddings in self._embeddings.items() - }, - } + dtypes_ = self._dtypes.copy() + for column, embeddings in self._embeddings.items(): + if embeddings is None: + length = None + else: + try: + length = len( + next( + embedding + for embedding in embeddings + if embedding is not None + ) + ) + except StopIteration: + length = None + dtypes_[column] = spotlight_dtypes.EmbeddingDType(length=length) + return dtypes_ @property def embeddings(self) -> Dict[str, Optional[np.ndarray]]: @@ -83,6 +91,7 @@ def embeddings(self) -> Dict[str, Optional[np.ndarray]]: @embeddings.setter def embeddings(self, new_embeddings: Dict[str, Optional[np.ndarray]]) -> None: + print(new_embeddings) self._embeddings = new_embeddings def check_generation_id(self, generation_id: int) -> None: @@ -112,15 +121,19 @@ def get_converted_values( embeddings = self._embeddings[column_name] if embeddings is None: raise ComputedColumnNotReady(column_name) - normalized_values: Iterable = embeddings + normalized_values: Iterable = embeddings[indices] else: normalized_values = self._data_source.get_column_values( column_name, indices ) + if column_name == "image.embedding": + print(normalized_values) converted_values = [ convert_to_dtype(value, dtype, simple=simple, check=check) for value in normalized_values ] + if column_name == "image.embedding": + print(converted_values) return converted_values def get_converted_value( diff --git a/renumics/spotlight/embeddings/__init__.py b/renumics/spotlight/embeddings/__init__.py index 3e75cb1c..56e9a49b 100644 --- a/renumics/spotlight/embeddings/__init__.py +++ b/renumics/spotlight/embeddings/__init__.py @@ -7,11 +7,9 @@ from typing import Any, Dict, List import numpy as np + from renumics.spotlight.embeddings.exceptions import CannotEmbed from renumics.spotlight.embeddings.typing import Embedder - -from renumics.spotlight.logging import logger - from .registry import registered_embedders from . import embedders as embedders_namespace @@ -24,9 +22,6 @@ def create_embedders(data_store: Any, columns: List[str]) -> Dict[str, Embedder] """ Create embedding functions for the given data store. 
""" - - logger.info("Embedding started.") - embedders: Dict[str, Embedder] = {} for column in columns: for embedder_class in registered_embedders: @@ -36,9 +31,6 @@ def create_embedders(data_store: Any, columns: List[str]) -> Dict[str, Embedder] continue embedders[f"{column}.embedding"] = embedder break - - logger.info("Embedding done.") - return embedders diff --git a/renumics/spotlight/embeddings/embedders/dummy.py b/renumics/spotlight/embeddings/embedders/dummy.py deleted file mode 100644 index 621a699a..00000000 --- a/renumics/spotlight/embeddings/embedders/dummy.py +++ /dev/null @@ -1,22 +0,0 @@ -from typing import Any -import time - -import numpy as np - -from renumics.spotlight import dtypes -from renumics.spotlight.embeddings.decorator import embedder -from renumics.spotlight.embeddings.exceptions import CannotEmbed -from renumics.spotlight.embeddings.typing import Embedder - - -@embedder -class Dummy(Embedder): - def __init__(self, data_store: Any, column: str) -> None: - if not dtypes.is_image_dtype(data_store.dtypes[column]): - raise CannotEmbed - self._data_store = data_store - self._column = column - - def __call__(self) -> np.ndarray: - time.sleep(10) - return np.random.random((len(self._data_store), 4)) diff --git a/renumics/spotlight/embeddings/embedders/vit.py b/renumics/spotlight/embeddings/embedders/vit.py new file mode 100644 index 00000000..4fe17187 --- /dev/null +++ b/renumics/spotlight/embeddings/embedders/vit.py @@ -0,0 +1,68 @@ +import io +from typing import Any, List + +from PIL import Image +import numpy as np +import transformers + +from renumics.spotlight import dtypes +from renumics.spotlight.embeddings.decorator import embedder +from renumics.spotlight.embeddings.exceptions import CannotEmbed +from renumics.spotlight.embeddings.registry import unregister_embedder +from renumics.spotlight.embeddings.typing import Embedder +from renumics.spotlight.logging import logger + +try: + import torch +except ImportError: + logger.warning("`ViTEmbedder` requires `pytorch` to be installed.") + _torch_available = False +else: + _torch_available = True + + +@embedder +class ViTEmbedder(Embedder): + def __init__(self, data_store: Any, column: str) -> None: + if not dtypes.is_image_dtype(data_store.dtypes[column]): + raise CannotEmbed + self._data_store = data_store + self._column = column + + self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + + model_name = "google/vit-base-patch16-224" + self.processor = transformers.AutoImageProcessor.from_pretrained(model_name) + self.model = transformers.ViTModel.from_pretrained(model_name).to(self.device) + + def __call__(self) -> np.ndarray: + values = self._data_store.get_converted_values( + self._column, indices=slice(None), simple=False, check=False + ) + none_mask = [sample is None for sample in values] + if all(none_mask): + return np.array([None] * len(values), dtype=np.object_) + + embeddings = self.embed_images( + [Image.open(io.BytesIO(value)) for value in values if value is not None] + ) + + if any(none_mask): + res = np.empty(len(values), dtype=np.object_) + res[np.nonzero(~np.array(none_mask))[0]] = list(embeddings) + return res + + return embeddings + + def embed_images(self, batch: List[Image.Image]) -> np.ndarray: + images = [image.convert("RGB") for image in batch] + inputs = self.processor(images=images, return_tensors="pt") + with torch.no_grad(): + outputs = self.model(**inputs.to(self.device)) + embeddings = outputs.last_hidden_state[:, 0].cpu().numpy() + + return embeddings + + +if not 
_torch_available: + unregister_embedder(ViTEmbedder) diff --git a/renumics/spotlight_plugins/core/api/table.py b/renumics/spotlight_plugins/core/api/table.py index 84840315..b02dac34 100644 --- a/renumics/spotlight_plugins/core/api/table.py +++ b/renumics/spotlight_plugins/core/api/table.py @@ -145,6 +145,7 @@ async def get_table_cell( data_store.check_generation_id(generation_id) value = data_store.get_converted_value(column, row, simple=False) + print(column, value) if isinstance(value, bytes): return Response(value, media_type="application/octet-stream") From 7d1891a11c886e24f78eaa41257db23e3c7a6eb2 Mon Sep 17 00:00:00 2001 From: Alexander Druz Date: Thu, 30 Nov 2023 09:39:57 +0100 Subject: [PATCH 09/24] Remove debug prints --- renumics/spotlight/data_store.py | 5 ----- renumics/spotlight_plugins/core/api/table.py | 1 - 2 files changed, 6 deletions(-) diff --git a/renumics/spotlight/data_store.py b/renumics/spotlight/data_store.py index 3d80d3eb..de656085 100644 --- a/renumics/spotlight/data_store.py +++ b/renumics/spotlight/data_store.py @@ -91,7 +91,6 @@ def embeddings(self) -> Dict[str, Optional[np.ndarray]]: @embeddings.setter def embeddings(self, new_embeddings: Dict[str, Optional[np.ndarray]]) -> None: - print(new_embeddings) self._embeddings = new_embeddings def check_generation_id(self, generation_id: int) -> None: @@ -126,14 +125,10 @@ def get_converted_values( normalized_values = self._data_source.get_column_values( column_name, indices ) - if column_name == "image.embedding": - print(normalized_values) converted_values = [ convert_to_dtype(value, dtype, simple=simple, check=check) for value in normalized_values ] - if column_name == "image.embedding": - print(converted_values) return converted_values def get_converted_value( diff --git a/renumics/spotlight_plugins/core/api/table.py b/renumics/spotlight_plugins/core/api/table.py index b02dac34..84840315 100644 --- a/renumics/spotlight_plugins/core/api/table.py +++ b/renumics/spotlight_plugins/core/api/table.py @@ -145,7 +145,6 @@ async def get_table_cell( data_store.check_generation_id(generation_id) value = data_store.get_converted_value(column, row, simple=False) - print(column, value) if isinstance(value, bytes): return Response(value, media_type="application/octet-stream") From e4eb843f11442beddda31f376ba9bf737f58d79b Mon Sep 17 00:00:00 2001 From: Dominik Haentsch Date: Thu, 30 Nov 2023 14:42:57 +0100 Subject: [PATCH 10/24] feat: add gte embedder for text --- poetry.lock | 163 +++++++++++++++--- pyproject.toml | 5 +- .../spotlight/embeddings/embedders/gte.py | 50 ++++++ 3 files changed, 194 insertions(+), 24 deletions(-) create mode 100644 renumics/spotlight/embeddings/embedders/gte.py diff --git a/poetry.lock b/poetry.lock index edc589cb..0cf1feb4 100644 --- a/poetry.lock +++ b/poetry.lock @@ -2076,7 +2076,6 @@ optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, !=3.5.*, !=3.6.*" files = [ {file = "jsonpointer-2.4-py2.py3-none-any.whl", hash = "sha256:15d51bba20eea3165644553647711d150376234112651b4f1811022aecad7d7a"}, - {file = "jsonpointer-2.4.tar.gz", hash = "sha256:585cee82b70211fa9e6043b7bb89db6e1aa49524340dde8ad6b63206ea689d88"}, ] [[package]] @@ -2710,16 +2709,6 @@ files = [ {file = "MarkupSafe-2.1.3-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:5bbe06f8eeafd38e5d0a4894ffec89378b6c6a625ff57e3028921f8ff59318ac"}, {file = "MarkupSafe-2.1.3-cp311-cp311-win32.whl", hash = "sha256:dd15ff04ffd7e05ffcb7fe79f1b98041b8ea30ae9234aed2a9168b5797c3effb"}, {file = 
"MarkupSafe-2.1.3-cp311-cp311-win_amd64.whl", hash = "sha256:134da1eca9ec0ae528110ccc9e48041e0828d79f24121a1a146161103c76e686"}, - {file = "MarkupSafe-2.1.3-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:f698de3fd0c4e6972b92290a45bd9b1536bffe8c6759c62471efaa8acb4c37bc"}, - {file = "MarkupSafe-2.1.3-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:aa57bd9cf8ae831a362185ee444e15a93ecb2e344c8e52e4d721ea3ab6ef1823"}, - {file = "MarkupSafe-2.1.3-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ffcc3f7c66b5f5b7931a5aa68fc9cecc51e685ef90282f4a82f0f5e9b704ad11"}, - {file = "MarkupSafe-2.1.3-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:47d4f1c5f80fc62fdd7777d0d40a2e9dda0a05883ab11374334f6c4de38adffd"}, - {file = "MarkupSafe-2.1.3-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1f67c7038d560d92149c060157d623c542173016c4babc0c1913cca0564b9939"}, - {file = "MarkupSafe-2.1.3-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:9aad3c1755095ce347e26488214ef77e0485a3c34a50c5a5e2471dff60b9dd9c"}, - {file = "MarkupSafe-2.1.3-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:14ff806850827afd6b07a5f32bd917fb7f45b046ba40c57abdb636674a8b559c"}, - {file = "MarkupSafe-2.1.3-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8f9293864fe09b8149f0cc42ce56e3f0e54de883a9de90cd427f191c346eb2e1"}, - {file = "MarkupSafe-2.1.3-cp312-cp312-win32.whl", hash = "sha256:715d3562f79d540f251b99ebd6d8baa547118974341db04f5ad06d5ea3eb8007"}, - {file = "MarkupSafe-2.1.3-cp312-cp312-win_amd64.whl", hash = "sha256:1b8dd8c3fd14349433c79fa8abeb573a55fc0fdd769133baac1f5e07abf54aeb"}, {file = "MarkupSafe-2.1.3-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:8e254ae696c88d98da6555f5ace2279cf7cd5b3f52be2b5cf97feafe883b58d2"}, {file = "MarkupSafe-2.1.3-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cb0932dc158471523c9637e807d9bfb93e06a95cbf010f1a38b98623b929ef2b"}, {file = "MarkupSafe-2.1.3-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9402b03f1a1b4dc4c19845e5c749e3ab82d5078d16a2a4c2cd2df62d57bb0707"}, @@ -3290,6 +3279,31 @@ files = [ {file = "nh3-0.2.14.tar.gz", hash = "sha256:a0c509894fd4dccdff557068e5074999ae3b75f4c5a2d6fb5415e782e25679c4"}, ] +[[package]] +name = "nltk" +version = "3.8.1" +description = "Natural Language Toolkit" +optional = false +python-versions = ">=3.7" +files = [ + {file = "nltk-3.8.1-py3-none-any.whl", hash = "sha256:fd5c9109f976fa86bcadba8f91e47f5e9293bd034474752e92a520f81c93dda5"}, + {file = "nltk-3.8.1.zip", hash = "sha256:1834da3d0682cba4f2cede2f9aad6b0fafb6461ba451db0efb6f9c39798d64d3"}, +] + +[package.dependencies] +click = "*" +joblib = "*" +regex = ">=2021.8.3" +tqdm = "*" + +[package.extras] +all = ["matplotlib", "numpy", "pyparsing", "python-crfsuite", "requests", "scikit-learn", "scipy", "twython"] +corenlp = ["requests"] +machine-learning = ["numpy", "python-crfsuite", "scikit-learn", "scipy"] +plot = ["matplotlib"] +tgrep = ["pyparsing"] +twitter = ["twython"] + [[package]] name = "nodeenv" version = "1.8.0" @@ -4100,7 +4114,6 @@ description = "22 CAnonical Time-series Features" optional = true python-versions = "*" files = [ - {file = "pycatch22-0.4.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:4f22ecf340c75448ad571020c423533c242f3e577e462219951dc6c33d545d58"}, {file = "pycatch22-0.4.2-cp39-cp39-macosx_12_0_x86_64.whl", hash = 
"sha256:8086ed7c223c08ea408f8bc9dd70d16691e6c5da0368d49f85a148557ad5a9c1"}, {file = "pycatch22-0.4.2.tar.gz", hash = "sha256:162bd9774a326993c564120221e779fb6e578139e9fb73ed6066a3137e61c3ad"}, ] @@ -4558,7 +4571,6 @@ files = [ {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:69b023b2b4daa7548bcfbd4aa3da05b3a74b772db9e23b982788168117739938"}, {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:81e0b275a9ecc9c0c0c07b4b90ba548307583c125f54d5b6946cfee6360c733d"}, {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ba336e390cd8e4d1739f42dfe9bb83a3cc2e80f567d8805e11b46f4a943f5515"}, - {file = "PyYAML-6.0.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:326c013efe8048858a6d312ddd31d56e468118ad4cdeda36c719bf5bb6192290"}, {file = "PyYAML-6.0.1-cp310-cp310-win32.whl", hash = "sha256:bd4af7373a854424dabd882decdc5579653d7868b8fb26dc7d0e99f823aa5924"}, {file = "PyYAML-6.0.1-cp310-cp310-win_amd64.whl", hash = "sha256:fd1592b3fdf65fff2ad0004b5e363300ef59ced41c2e6b3a99d4089fa8c5435d"}, {file = "PyYAML-6.0.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:6965a7bc3cf88e5a1c3bd2e0b5c22f8d677dc88a455344035f03399034eb3007"}, @@ -4566,15 +4578,8 @@ files = [ {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:42f8152b8dbc4fe7d96729ec2b99c7097d656dc1213a3229ca5383f973a5ed6d"}, {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:062582fca9fabdd2c8b54a3ef1c978d786e0f6b3a1510e0ac93ef59e0ddae2bc"}, {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d2b04aac4d386b172d5b9692e2d2da8de7bfb6c387fa4f801fbf6fb2e6ba4673"}, - {file = "PyYAML-6.0.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:e7d73685e87afe9f3b36c799222440d6cf362062f78be1013661b00c5c6f678b"}, {file = "PyYAML-6.0.1-cp311-cp311-win32.whl", hash = "sha256:1635fd110e8d85d55237ab316b5b011de701ea0f29d07611174a1b42f1444741"}, {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"}, - {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"}, - {file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"}, - {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"}, - {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"}, - {file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"}, - {file = "PyYAML-6.0.1-cp312-cp312-win_amd64.whl", hash = "sha256:0d3304d8c0adc42be59c5f8a4d9e3d7379e6955ad754aa9d6ab7a398b59dd1df"}, {file = "PyYAML-6.0.1-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:50550eb667afee136e9a77d6dc71ae76a44df8b3e51e41b77f6de2932bfe0f47"}, {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1fe35611261b29bd1de0070f0b2f47cb6ff71fa6595c077e42bd0c419fa27b98"}, {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:704219a11b772aea0d8ecd7058d0082713c3562b4e271b849ad7dc4a5c90c13c"}, @@ 
-4591,7 +4596,6 @@ files = [ {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a0cd17c15d3bb3fa06978b4e8958dcdc6e0174ccea823003a106c7d4d7899ac5"}, {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:28c119d996beec18c05208a8bd78cbe4007878c6dd15091efb73a30e90539696"}, {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7e07cbde391ba96ab58e532ff4803f79c4129397514e1413a7dc761ccd755735"}, - {file = "PyYAML-6.0.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:49a183be227561de579b4a36efbb21b3eab9651dd81b1858589f796549873dd6"}, {file = "PyYAML-6.0.1-cp38-cp38-win32.whl", hash = "sha256:184c5108a2aca3c5b3d3bf9395d50893a7ab82a38004c8f61c258d4428e80206"}, {file = "PyYAML-6.0.1-cp38-cp38-win_amd64.whl", hash = "sha256:1e2722cc9fbb45d9b87631ac70924c11d3a401b2d7f410cc0e3bbf249f2dca62"}, {file = "PyYAML-6.0.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:9eb6caa9a297fc2c2fb8862bc5370d0303ddba53ba97e71f08023b6cd73d16a8"}, @@ -4599,7 +4603,6 @@ files = [ {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5773183b6446b2c99bb77e77595dd486303b4faab2b086e7b17bc6bef28865f6"}, {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b786eecbdf8499b9ca1d697215862083bd6d2a99965554781d0d8d1ad31e13a0"}, {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bc1bf2925a1ecd43da378f4db9e4f799775d6367bdb94671027b73b393a7c42c"}, - {file = "PyYAML-6.0.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:04ac92ad1925b2cff1db0cfebffb6ffc43457495c9b3c39d3fcae417d7125dc5"}, {file = "PyYAML-6.0.1-cp39-cp39-win32.whl", hash = "sha256:faca3bdcf85b2fc05d06ff3fbc1f83e1391b3e724afa3feba7d13eeab355484c"}, {file = "PyYAML-6.0.1-cp39-cp39-win_amd64.whl", hash = "sha256:510c9deebc5c0225e8c96813043e62b680ba2f9c50a08d3724c7f28a747d1486"}, {file = "PyYAML-6.0.1.tar.gz", hash = "sha256:bfdf460b1736c775f2ba9f6a92bca30bc2095067b8a9d77876d1fad6cc3b4a43"}, @@ -5447,6 +5450,82 @@ nativelib = ["pyobjc-framework-Cocoa", "pywin32"] objc = ["pyobjc-framework-Cocoa"] win32 = ["pywin32"] +[[package]] +name = "sentence-transformers" +version = "2.2.2" +description = "Multilingual text embeddings" +optional = false +python-versions = ">=3.6.0" +files = [ + {file = "sentence-transformers-2.2.2.tar.gz", hash = "sha256:dbc60163b27de21076c9a30d24b5b7b6fa05141d68cf2553fa9a77bf79a29136"}, +] + +[package.dependencies] +huggingface-hub = ">=0.4.0" +nltk = "*" +numpy = "*" +scikit-learn = "*" +scipy = "*" +sentencepiece = "*" +torch = ">=1.6.0" +torchvision = "*" +tqdm = "*" +transformers = ">=4.6.0,<5.0.0" + +[[package]] +name = "sentencepiece" +version = "0.1.99" +description = "SentencePiece python wrapper" +optional = false +python-versions = "*" +files = [ + {file = "sentencepiece-0.1.99-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:0eb528e70571b7c02723e5804322469b82fe7ea418c96051d0286c0fa028db73"}, + {file = "sentencepiece-0.1.99-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:77d7fafb2c4e4659cbdf303929503f37a26eabc4ff31d3a79bf1c5a1b338caa7"}, + {file = "sentencepiece-0.1.99-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:be9cf5b9e404c245aeb3d3723c737ba7a8f5d4ba262ef233a431fa6c45f732a0"}, + {file = "sentencepiece-0.1.99-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:baed1a26464998f9710d20e52607c29ffd4293e7c71c6a1f83f51ad0911ec12c"}, + {file = 
"sentencepiece-0.1.99-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9832f08bb372d4c8b567612f8eab9e36e268dff645f1c28f9f8e851be705f6d1"}, + {file = "sentencepiece-0.1.99-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:019e7535108e309dae2b253a75834fc3128240aa87c00eb80732078cdc182588"}, + {file = "sentencepiece-0.1.99-cp310-cp310-win32.whl", hash = "sha256:fa16a830416bb823fa2a52cbdd474d1f7f3bba527fd2304fb4b140dad31bb9bc"}, + {file = "sentencepiece-0.1.99-cp310-cp310-win_amd64.whl", hash = "sha256:14b0eccb7b641d4591c3e12ae44cab537d68352e4d3b6424944f0c447d2348d5"}, + {file = "sentencepiece-0.1.99-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:6d3c56f24183a1e8bd61043ff2c58dfecdc68a5dd8955dc13bab83afd5f76b81"}, + {file = "sentencepiece-0.1.99-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:ed6ea1819fd612c989999e44a51bf556d0ef6abfb553080b9be3d347e18bcfb7"}, + {file = "sentencepiece-0.1.99-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:a2a0260cd1fb7bd8b4d4f39dc2444a8d5fd4e0a0c4d5c899810ef1abf99b2d45"}, + {file = "sentencepiece-0.1.99-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8a1abff4d1ff81c77cac3cc6fefa34fa4b8b371e5ee51cb7e8d1ebc996d05983"}, + {file = "sentencepiece-0.1.99-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:004e6a621d4bc88978eecb6ea7959264239a17b70f2cbc348033d8195c9808ec"}, + {file = "sentencepiece-0.1.99-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:db361e03342c41680afae5807590bc88aa0e17cfd1a42696a160e4005fcda03b"}, + {file = "sentencepiece-0.1.99-cp311-cp311-win32.whl", hash = "sha256:2d95e19168875b70df62916eb55428a0cbcb834ac51d5a7e664eda74def9e1e0"}, + {file = "sentencepiece-0.1.99-cp311-cp311-win_amd64.whl", hash = "sha256:f90d73a6f81248a909f55d8e6ef56fec32d559e1e9af045f0b0322637cb8e5c7"}, + {file = "sentencepiece-0.1.99-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:62e24c81e74bd87a6e0d63c51beb6527e4c0add67e1a17bac18bcd2076afcfeb"}, + {file = "sentencepiece-0.1.99-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:57efcc2d51caff20d9573567d9fd3f854d9efe613ed58a439c78c9f93101384a"}, + {file = "sentencepiece-0.1.99-cp36-cp36m-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6a904c46197993bd1e95b93a6e373dca2f170379d64441041e2e628ad4afb16f"}, + {file = "sentencepiece-0.1.99-cp36-cp36m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d89adf59854741c0d465f0e1525b388c0d174f611cc04af54153c5c4f36088c4"}, + {file = "sentencepiece-0.1.99-cp36-cp36m-win32.whl", hash = "sha256:47c378146928690d1bc106fdf0da768cebd03b65dd8405aa3dd88f9c81e35dba"}, + {file = "sentencepiece-0.1.99-cp36-cp36m-win_amd64.whl", hash = "sha256:9ba142e7a90dd6d823c44f9870abdad45e6c63958eb60fe44cca6828d3b69da2"}, + {file = "sentencepiece-0.1.99-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:b7b1a9ae4d7c6f1f867e63370cca25cc17b6f4886729595b885ee07a58d3cec3"}, + {file = "sentencepiece-0.1.99-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d0f644c9d4d35c096a538507b2163e6191512460035bf51358794a78515b74f7"}, + {file = "sentencepiece-0.1.99-cp37-cp37m-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c8843d23a0f686d85e569bd6dcd0dd0e0cbc03731e63497ca6d5bacd18df8b85"}, + {file = "sentencepiece-0.1.99-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:33e6f690a1caebb4867a2e367afa1918ad35be257ecdb3455d2bbd787936f155"}, + {file = 
"sentencepiece-0.1.99-cp37-cp37m-win32.whl", hash = "sha256:8a321866c2f85da7beac74a824b4ad6ddc2a4c9bccd9382529506d48f744a12c"}, + {file = "sentencepiece-0.1.99-cp37-cp37m-win_amd64.whl", hash = "sha256:c42f753bcfb7661c122a15b20be7f684b61fc8592c89c870adf52382ea72262d"}, + {file = "sentencepiece-0.1.99-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:85b476406da69c70586f0bb682fcca4c9b40e5059814f2db92303ea4585c650c"}, + {file = "sentencepiece-0.1.99-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:cfbcfe13c69d3f87b7fcd5da168df7290a6d006329be71f90ba4f56bc77f8561"}, + {file = "sentencepiece-0.1.99-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:445b0ec381af1cd4eef95243e7180c63d9c384443c16c4c47a28196bd1cda937"}, + {file = "sentencepiece-0.1.99-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c6890ea0f2b4703f62d0bf27932e35808b1f679bdb05c7eeb3812b935ba02001"}, + {file = "sentencepiece-0.1.99-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:fb71af492b0eefbf9f2501bec97bcd043b6812ab000d119eaf4bd33f9e283d03"}, + {file = "sentencepiece-0.1.99-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:27b866b5bd3ddd54166bbcbf5c8d7dd2e0b397fac8537991c7f544220b1f67bc"}, + {file = "sentencepiece-0.1.99-cp38-cp38-win32.whl", hash = "sha256:b133e8a499eac49c581c3c76e9bdd08c338cc1939e441fee6f92c0ccb5f1f8be"}, + {file = "sentencepiece-0.1.99-cp38-cp38-win_amd64.whl", hash = "sha256:0eaf3591dd0690a87f44f4df129cf8d05d8a4029b5b6709b489b8e27f9a9bcff"}, + {file = "sentencepiece-0.1.99-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:38efeda9bbfb55052d482a009c6a37e52f42ebffcea9d3a98a61de7aee356a28"}, + {file = "sentencepiece-0.1.99-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:6c030b081dc1e1bcc9fadc314b19b740715d3d566ad73a482da20d7d46fd444c"}, + {file = "sentencepiece-0.1.99-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:84dbe53e02e4f8a2e45d2ac3e430d5c83182142658e25edd76539b7648928727"}, + {file = "sentencepiece-0.1.99-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0b0f55d0a0ee1719b4b04221fe0c9f0c3461dc3dabd77a035fa2f4788eb3ef9a"}, + {file = "sentencepiece-0.1.99-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:18e800f206cd235dc27dc749299e05853a4e4332e8d3dfd81bf13d0e5b9007d9"}, + {file = "sentencepiece-0.1.99-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2ae1c40cda8f9d5b0423cfa98542735c0235e7597d79caf318855cdf971b2280"}, + {file = "sentencepiece-0.1.99-cp39-cp39-win32.whl", hash = "sha256:c84ce33af12ca222d14a1cdd37bd76a69401e32bc68fe61c67ef6b59402f4ab8"}, + {file = "sentencepiece-0.1.99-cp39-cp39-win_amd64.whl", hash = "sha256:350e5c74d739973f1c9643edb80f7cc904dc948578bcb1d43c6f2b173e5d18dd"}, + {file = "sentencepiece-0.1.99.tar.gz", hash = "sha256:189c48f5cb2949288f97ccdb97f0473098d9c3dcf5a3d99d4eabe719ec27297f"}, +] + [[package]] name = "setuptools" version = "68.2.2" @@ -5973,6 +6052,44 @@ type = "legacy" url = "https://download.pytorch.org/whl/cpu" reference = "torch-cpu" +[[package]] +name = "torchvision" +version = "0.16.1" +description = "image and video datasets and models for torch deep learning" +optional = false +python-versions = ">=3.8" +files = [ + {file = "torchvision-0.16.1-cp310-cp310-macosx_10_13_x86_64.whl", hash = "sha256:987132795e5c037cb74e7be35a693999fdb2f603152266ee15b80206e83a5b0c"}, + {file = "torchvision-0.16.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:25da6a7b22ea0348f62c45ec0daf157731096babcae65d222404081af96e085c"}, + 
{file = "torchvision-0.16.1-cp310-cp310-manylinux1_x86_64.whl", hash = "sha256:c82e291e674a18b67f92ddb476ae18498fb46d7032ae914f3fda90c955e7d86f"}, + {file = "torchvision-0.16.1-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:632887b22e67ce32a3ede806b868bba4057601e46d680de14b32a391eac1b483"}, + {file = "torchvision-0.16.1-cp310-cp310-win_amd64.whl", hash = "sha256:92c76a5092b4033efdb183b11fa4854a7630e23c46f4a1c3ffd70c30cb5be4fc"}, + {file = "torchvision-0.16.1-cp311-cp311-macosx_10_13_x86_64.whl", hash = "sha256:4aea5cf491c6c21b1cbdbb1bf2a3838a59d4db93ad5f49019a6564d3ca7127c7"}, + {file = "torchvision-0.16.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:3391757167637ace3ef33a67c9d5ef86b1f8cbd93eaa5bad45eebcf266ea6089"}, + {file = "torchvision-0.16.1-cp311-cp311-manylinux1_x86_64.whl", hash = "sha256:4f9d5b192b336982e6dbe32c070b05606f0b53e87d722ae332a02909fbf988ed"}, + {file = "torchvision-0.16.1-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:3d34601614958c4e30f53ec0eb7bf3f282ee72bb747734be2d75422831a43384"}, + {file = "torchvision-0.16.1-cp311-cp311-win_amd64.whl", hash = "sha256:e11af530585574eb5ca837b8f151bcdd57c10e35c3af56c76a10f3281d2a2f2c"}, + {file = "torchvision-0.16.1-cp38-cp38-macosx_10_13_x86_64.whl", hash = "sha256:4f2cad621fb96cf10e29af93e16c98b3226bdd53ae712b57e873c3deaf061617"}, + {file = "torchvision-0.16.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:1d614b3c9e2de9cd75cc0e4e1923fcfbbcd9fdb9f08a0bbbbf7e135e4a0a1cfa"}, + {file = "torchvision-0.16.1-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:73271e930501a008fe24ba38945b2a75b25a6098f4c2f4402e39a9d0dd305ca6"}, + {file = "torchvision-0.16.1-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:fab67ddc4809fcc2a04610b13cac5193b9d3be2896b77538bfdff401b13022e5"}, + {file = "torchvision-0.16.1-cp38-cp38-win_amd64.whl", hash = "sha256:13782d574033efec6646d1a2f5d85f4c59fcf3f403367bb407b15df07adc87e0"}, + {file = "torchvision-0.16.1-cp39-cp39-macosx_10_13_x86_64.whl", hash = "sha256:f14d201c37176dc4106eec76b229d6585a1505266b8cea99d3366fd38897b7c0"}, + {file = "torchvision-0.16.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:a15e88a93a7501cc75b761a2dcd07aaedaaf9cbfaf48c8affa8c98989ecbb19d"}, + {file = "torchvision-0.16.1-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:72fde5fdb462e66ebe25ae42d2ee11434cbc395f74cad0d3b22cf60524345cc5"}, + {file = "torchvision-0.16.1-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:153f753f14eba58969cdc86360893a57f8bf63f8136c7d1cd4388108560b5446"}, + {file = "torchvision-0.16.1-cp39-cp39-win_amd64.whl", hash = "sha256:75e33b198b1265f61d822aa66d646ec3df67a712470ffec1e0c37ff46d4103c1"}, +] + +[package.dependencies] +numpy = "*" +pillow = ">=5.3.0,<8.3.dev0 || >=8.4.dev0" +requests = "*" +torch = "2.1.1" + +[package.extras] +scipy = ["scipy"] + [[package]] name = "tornado" version = "6.3.3" @@ -7093,4 +7210,4 @@ descriptors = ["pycatch22"] [metadata] lock-version = "2.0" python-versions = ">=3.8, <3.12" -content-hash = "98b1beb638bb2f57aa8f5884392c03b1df6bc78c0d74cbcb2804280d8eb3bb17" +content-hash = "ade3838a00ad06a60e0e9e262ebf891327899fc914af5aa068d657f9461513e5" diff --git a/pyproject.toml b/pyproject.toml index 979da674..fc1eefd5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -93,7 +93,8 @@ datasets = { extras = ["audio"], version = "^2.12.0" } pydantic-settings = "^2.0.3" pycatch22 = { version = "!=0.4.4", optional = true } transformers = "^4.35.2" -torch = {version = "^2.1.1+cpu", source = "torch-cpu"} +torch = { version = "^2.1.1+cpu", source = "torch-cpu" } 
+sentence-transformers = "^2.2.2" [tool.poetry.extras] descriptors = ["pycatch22"] @@ -180,6 +181,8 @@ module = [ "diffimg", "tests.ui._autogenerated_ui_elements", "transformers", + "zstandard", + "sentence_transformers", ] ignore_missing_imports = true diff --git a/renumics/spotlight/embeddings/embedders/gte.py b/renumics/spotlight/embeddings/embedders/gte.py new file mode 100644 index 00000000..0b6af4c7 --- /dev/null +++ b/renumics/spotlight/embeddings/embedders/gte.py @@ -0,0 +1,50 @@ +from typing import Any, List + +import numpy as np +import sentence_transformers + +from renumics.spotlight import dtypes +from renumics.spotlight.embeddings.decorator import embedder +from renumics.spotlight.embeddings.exceptions import CannotEmbed +from renumics.spotlight.embeddings.typing import Embedder +from renumics.spotlight.logging import logger + +try: + import torch +except ImportError: + logger.warning("`GTE Embedder` requires `pytorch` to be installed.") +else: + + @embedder + class GteEmbedder(Embedder): + def __init__(self, data_store: Any, column: str) -> None: + if not dtypes.is_str_dtype(data_store.dtypes[column]): + raise CannotEmbed + self._data_store = data_store + self._column = column + + def __call__(self) -> np.ndarray: + values = self._data_store.get_converted_values( + self._column, indices=slice(None), simple=False, check=False + ) + none_mask = [sample is None for sample in values] + if all(none_mask): + return np.array([None] * len(values), dtype=np.object_) + + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + + model = sentence_transformers.SentenceTransformer( + "thenlper/gte-base", device=device + ) + + def _embed_batch(batch: List[str]) -> np.ndarray: + return model.encode(batch, normalize_embeddings=True) + + embeddings = _embed_batch([value for value in values if value is not None]) + + if any(none_mask): + res = np.empty(len(values), dtype=np.object_) + res[np.nonzero(~np.array(none_mask))[0]] = list(embeddings) + return res + + return embeddings From d4d8e11c44fd918feea8e84b8f75df9f9e6e2e6f Mon Sep 17 00:00:00 2001 From: Dominik Haentsch Date: Thu, 30 Nov 2023 15:30:53 +0100 Subject: [PATCH 11/24] feat: add wav2vec2 embedder for audio data --- .../spotlight/embeddings/embedders/wav2vec.py | 73 +++++++++++++++++++ 1 file changed, 73 insertions(+) create mode 100644 renumics/spotlight/embeddings/embedders/wav2vec.py diff --git a/renumics/spotlight/embeddings/embedders/wav2vec.py b/renumics/spotlight/embeddings/embedders/wav2vec.py new file mode 100644 index 00000000..ffcfdba2 --- /dev/null +++ b/renumics/spotlight/embeddings/embedders/wav2vec.py @@ -0,0 +1,73 @@ +from typing import Any, List + +import numpy as np +import transformers +import av +import io + +from renumics.spotlight import dtypes +from renumics.spotlight.embeddings.decorator import embedder +from renumics.spotlight.embeddings.exceptions import CannotEmbed +from renumics.spotlight.embeddings.typing import Embedder +from renumics.spotlight.logging import logger + +try: + import torch +except ImportError: + logger.warning("`Wav2Vec Embedder` requires `pytorch` to be installed.") +else: + + @embedder + class Wav2VecEmbedder(Embedder): + def __init__(self, data_store: Any, column: str) -> None: + if not dtypes.is_audio_dtype(data_store.dtypes[column]): + raise CannotEmbed + self._data_store = data_store + self._column = column + + def __call__(self) -> np.ndarray: + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + model_name = "facebook/wav2vec2-base-960h" + 
sampling_rate = 16000 + processor = transformers.AutoFeatureExtractor.from_pretrained(model_name) + model = transformers.AutoModel.from_pretrained(model_name).to(device) + + values = self._data_store.get_converted_values( + self._column, indices=slice(None), simple=False, check=False + ) + none_mask = [sample is None for sample in values] + if all(none_mask): + return np.array([None] * len(values), dtype=np.object_) + + def _embed_batch(batch: List[bytes]) -> np.ndarray: + resampler = av.AudioResampler(format="dbl", layout="mono", rate=16000) + resampled_batch = [] + for raw_data in batch: + with av.open(io.BytesIO(raw_data), "r") as container: + data = [] + for frame in container.decode(audio=0): + resampled_frames = resampler.resample(frame) + for resampled_frame in resampled_frames: + frame_array = resampled_frame.to_ndarray()[0] + data.append(frame_array) + resampled_batch.append(np.concatenate(data, axis=0)) + + inputs = processor( + raw_speech=resampled_batch, + sampling_rate=sampling_rate, + padding="longest", + return_tensors="pt", + ) + with torch.no_grad(): + outputs = model(**inputs.to(device)) + embeddings = outputs.last_hidden_state[:, 0].cpu().numpy() + return embeddings + + embeddings = _embed_batch([value for value in values if value is not None]) + + if any(none_mask): + res = np.empty(len(values), dtype=np.object_) + res[np.nonzero(~np.array(none_mask))[0]] = list(embeddings) + return res + + return embeddings From 4b07dca31986ac73a35b9d80645d558e122237a9 Mon Sep 17 00:00:00 2001 From: Dominik Haentsch Date: Thu, 30 Nov 2023 15:44:05 +0100 Subject: [PATCH 12/24] fix: use a separate resampler per audio file --- renumics/spotlight/embeddings/embedders/wav2vec.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/renumics/spotlight/embeddings/embedders/wav2vec.py b/renumics/spotlight/embeddings/embedders/wav2vec.py index ffcfdba2..22490215 100644 --- a/renumics/spotlight/embeddings/embedders/wav2vec.py +++ b/renumics/spotlight/embeddings/embedders/wav2vec.py @@ -40,10 +40,12 @@ def __call__(self) -> np.ndarray: return np.array([None] * len(values), dtype=np.object_) def _embed_batch(batch: List[bytes]) -> np.ndarray: - resampler = av.AudioResampler(format="dbl", layout="mono", rate=16000) resampled_batch = [] for raw_data in batch: with av.open(io.BytesIO(raw_data), "r") as container: + resampler = av.AudioResampler( + format="dbl", layout="mono", rate=16000 + ) data = [] for frame in container.decode(audio=0): resampled_frames = resampler.resample(frame) From 555c3efd3502bb21cca54f9cc5d91aa17749f978 Mon Sep 17 00:00:00 2001 From: Dominik Haentsch Date: Thu, 30 Nov 2023 17:09:35 +0100 Subject: [PATCH 13/24] wip: decorator for functional embedder --- renumics/spotlight/embeddings/__init__.py | 1 + renumics/spotlight/embeddings/decorator.py | 92 ++++++++++++++++++- .../spotlight/embeddings/embedders/vit.py | 69 +++----------- 3 files changed, 107 insertions(+), 55 deletions(-) diff --git a/renumics/spotlight/embeddings/__init__.py b/renumics/spotlight/embeddings/__init__.py index 56e9a49b..4c8dba2a 100644 --- a/renumics/spotlight/embeddings/__init__.py +++ b/renumics/spotlight/embeddings/__init__.py @@ -26,6 +26,7 @@ def create_embedders(data_store: Any, columns: List[str]) -> Dict[str, Embedder] for column in columns: for embedder_class in registered_embedders: try: + # embedder = FunctionalEmbedder(func, preprocessor, data_store, column) embedder = embedder_class(data_store, column) except CannotEmbed: continue diff --git 
a/renumics/spotlight/embeddings/decorator.py b/renumics/spotlight/embeddings/decorator.py index 7127d1a2..00b775ae 100644 --- a/renumics/spotlight/embeddings/decorator.py +++ b/renumics/spotlight/embeddings/decorator.py @@ -2,7 +2,19 @@ A decorator for data analysis functions """ -from typing import Type +from typing import Callable, Iterable, List, Optional, Type, Any +import itertools +import io +import av +import numpy as np +import PIL.Image +from numpy._typing import DTypeLike +from numpy.lib import math +from renumics.spotlight import dtypes + +from renumics.spotlight.dtypes import create_dtype + +from renumics.spotlight.embeddings.exceptions import CannotEmbed from .typing import Embedder from .registry import register_embedder @@ -13,3 +25,81 @@ def embedder(klass: Type[Embedder]) -> Type[Embedder]: """ register_embedder(klass) return klass + + +def embed(accepts: DTypeLike, *, sampling_rate: Optional[int] = None): + dtype = create_dtype(accepts) + + if dtypes.is_image_dtype(dtype): + + def _preprocess_batch(raw_values: List[bytes]): + return [PIL.Image.open(io.BytesIO(value)) for value in raw_values] + + elif dtypes.is_audio_dtype(dtype): + if sampling_rate is None: + raise ValueError( + "No sampling rate specified, but required for `audio` embedding." + ) + + def _preprocess_batch(raw_values: Any): + resampled_batch = [] + for raw_data in raw_values: + with av.open(io.BytesIO(raw_data), "r") as container: + resampler = av.AudioResampler( + format="dbl", layout="mono", rate=16000 + ) + data = [] + for frame in container.decode(audio=0): + resampled_frames = resampler.resample(frame) + for resampled_frame in resampled_frames: + frame_array = resampled_frame.to_ndarray()[0] + data.append(frame_array) + resampled_batch.append(np.concatenate(data, axis=0)) + return resampled_batch + + else: + + def _preprocess_batch(raw_values: Any): + return raw_values + + def decorate( + func: Callable[[Iterable[list]], Iterable[List[Optional[np.ndarray]]]] + ): + class EmbedderImpl(Embedder): + def __init__(self, data_store: Any, column: str): + self.dtype = dtype + if data_store.dtypes[column].name != self.dtype.name: + raise CannotEmbed() + + self.data_store = data_store + self.column = column + self.batch_size = 16 + + self._occupied_indices = [] + + def _iter_batches(self): + self._occupied_indices = [] + batch = [] + for i in range(len(self.data_store)): + value = self.data_store.get_converted_value( + self.column, i, simple=False, check=False + ) + + if value is None: + continue + + self._occupied_indices.append(i) + batch.append(value) + if len(batch) == self.batch_size: + yield _preprocess_batch(batch) + batch = [] + + def __call__(self) -> np.ndarray: + embeddings = itertools.chain(*func(self._iter_batches())) + res = np.empty(len(self.data_store), dtype=np.object_) + res[self._occupied_indices] = list(embeddings) + return res + + register_embedder(EmbedderImpl) + + return decorate diff --git a/renumics/spotlight/embeddings/embedders/vit.py b/renumics/spotlight/embeddings/embedders/vit.py index 4fe17187..8221ed83 100644 --- a/renumics/spotlight/embeddings/embedders/vit.py +++ b/renumics/spotlight/embeddings/embedders/vit.py @@ -1,68 +1,29 @@ -import io -from typing import Any, List +from typing import Iterable, List -from PIL import Image -import numpy as np +import PIL.Image import transformers -from renumics.spotlight import dtypes -from renumics.spotlight.embeddings.decorator import embedder -from renumics.spotlight.embeddings.exceptions import CannotEmbed -from 
renumics.spotlight.embeddings.registry import unregister_embedder -from renumics.spotlight.embeddings.typing import Embedder +from renumics.spotlight.embeddings.decorator import embed from renumics.spotlight.logging import logger try: import torch except ImportError: logger.warning("`ViTEmbedder` requires `pytorch` to be installed.") - _torch_available = False else: - _torch_available = True - - -@embedder -class ViTEmbedder(Embedder): - def __init__(self, data_store: Any, column: str) -> None: - if not dtypes.is_image_dtype(data_store.dtypes[column]): - raise CannotEmbed - self._data_store = data_store - self._column = column - - self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + @embed("image") + def vit(batches: Iterable[List[PIL.Image.Image]]): + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") model_name = "google/vit-base-patch16-224" - self.processor = transformers.AutoImageProcessor.from_pretrained(model_name) - self.model = transformers.ViTModel.from_pretrained(model_name).to(self.device) - - def __call__(self) -> np.ndarray: - values = self._data_store.get_converted_values( - self._column, indices=slice(None), simple=False, check=False - ) - none_mask = [sample is None for sample in values] - if all(none_mask): - return np.array([None] * len(values), dtype=np.object_) - - embeddings = self.embed_images( - [Image.open(io.BytesIO(value)) for value in values if value is not None] - ) - - if any(none_mask): - res = np.empty(len(values), dtype=np.object_) - res[np.nonzero(~np.array(none_mask))[0]] = list(embeddings) - return res - - return embeddings - - def embed_images(self, batch: List[Image.Image]) -> np.ndarray: - images = [image.convert("RGB") for image in batch] - inputs = self.processor(images=images, return_tensors="pt") - with torch.no_grad(): - outputs = self.model(**inputs.to(self.device)) - embeddings = outputs.last_hidden_state[:, 0].cpu().numpy() - - return embeddings + processor = transformers.AutoImageProcessor.from_pretrained(model_name) + model = transformers.ViTModel.from_pretrained(model_name).to(device) + for batch in batches: + images = [image.convert("RGB") for image in batch] + inputs = processor(images=images, return_tensors="pt") + with torch.no_grad(): + outputs = model(**inputs.to(device)) + embeddings = outputs.last_hidden_state[:, 0].cpu().numpy() -if not _torch_available: - unregister_embedder(ViTEmbedder) + yield embeddings From 3e7d9d63ec53c88989058f5d670d179023292007 Mon Sep 17 00:00:00 2001 From: Alexander Druz Date: Fri, 1 Dec 2023 11:39:40 +0100 Subject: [PATCH 14/24] Install and lock CPU torch into dev/playbook deps --- poetry.lock | 59 +++++++++++++++++++++++++++++++------------------- pyproject.toml | 4 +++- 2 files changed, 40 insertions(+), 23 deletions(-) diff --git a/poetry.lock b/poetry.lock index 0cf1feb4..8e3937c0 100644 --- a/poetry.lock +++ b/poetry.lock @@ -2076,6 +2076,7 @@ optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, !=3.5.*, !=3.6.*" files = [ {file = "jsonpointer-2.4-py2.py3-none-any.whl", hash = "sha256:15d51bba20eea3165644553647711d150376234112651b4f1811022aecad7d7a"}, + {file = "jsonpointer-2.4.tar.gz", hash = "sha256:585cee82b70211fa9e6043b7bb89db6e1aa49524340dde8ad6b63206ea689d88"}, ] [[package]] @@ -2709,6 +2710,16 @@ files = [ {file = "MarkupSafe-2.1.3-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:5bbe06f8eeafd38e5d0a4894ffec89378b6c6a625ff57e3028921f8ff59318ac"}, {file = "MarkupSafe-2.1.3-cp311-cp311-win32.whl", hash = 
"sha256:dd15ff04ffd7e05ffcb7fe79f1b98041b8ea30ae9234aed2a9168b5797c3effb"}, {file = "MarkupSafe-2.1.3-cp311-cp311-win_amd64.whl", hash = "sha256:134da1eca9ec0ae528110ccc9e48041e0828d79f24121a1a146161103c76e686"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:f698de3fd0c4e6972b92290a45bd9b1536bffe8c6759c62471efaa8acb4c37bc"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:aa57bd9cf8ae831a362185ee444e15a93ecb2e344c8e52e4d721ea3ab6ef1823"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ffcc3f7c66b5f5b7931a5aa68fc9cecc51e685ef90282f4a82f0f5e9b704ad11"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:47d4f1c5f80fc62fdd7777d0d40a2e9dda0a05883ab11374334f6c4de38adffd"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1f67c7038d560d92149c060157d623c542173016c4babc0c1913cca0564b9939"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:9aad3c1755095ce347e26488214ef77e0485a3c34a50c5a5e2471dff60b9dd9c"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:14ff806850827afd6b07a5f32bd917fb7f45b046ba40c57abdb636674a8b559c"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8f9293864fe09b8149f0cc42ce56e3f0e54de883a9de90cd427f191c346eb2e1"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-win32.whl", hash = "sha256:715d3562f79d540f251b99ebd6d8baa547118974341db04f5ad06d5ea3eb8007"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-win_amd64.whl", hash = "sha256:1b8dd8c3fd14349433c79fa8abeb573a55fc0fdd769133baac1f5e07abf54aeb"}, {file = "MarkupSafe-2.1.3-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:8e254ae696c88d98da6555f5ace2279cf7cd5b3f52be2b5cf97feafe883b58d2"}, {file = "MarkupSafe-2.1.3-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cb0932dc158471523c9637e807d9bfb93e06a95cbf010f1a38b98623b929ef2b"}, {file = "MarkupSafe-2.1.3-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9402b03f1a1b4dc4c19845e5c749e3ab82d5078d16a2a4c2cd2df62d57bb0707"}, @@ -4114,6 +4125,7 @@ description = "22 CAnonical Time-series Features" optional = true python-versions = "*" files = [ + {file = "pycatch22-0.4.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:4f22ecf340c75448ad571020c423533c242f3e577e462219951dc6c33d545d58"}, {file = "pycatch22-0.4.2-cp39-cp39-macosx_12_0_x86_64.whl", hash = "sha256:8086ed7c223c08ea408f8bc9dd70d16691e6c5da0368d49f85a148557ad5a9c1"}, {file = "pycatch22-0.4.2.tar.gz", hash = "sha256:162bd9774a326993c564120221e779fb6e578139e9fb73ed6066a3137e61c3ad"}, ] @@ -4571,6 +4583,7 @@ files = [ {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:69b023b2b4daa7548bcfbd4aa3da05b3a74b772db9e23b982788168117739938"}, {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:81e0b275a9ecc9c0c0c07b4b90ba548307583c125f54d5b6946cfee6360c733d"}, {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ba336e390cd8e4d1739f42dfe9bb83a3cc2e80f567d8805e11b46f4a943f5515"}, + {file = "PyYAML-6.0.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:326c013efe8048858a6d312ddd31d56e468118ad4cdeda36c719bf5bb6192290"}, {file = "PyYAML-6.0.1-cp310-cp310-win32.whl", hash = 
"sha256:bd4af7373a854424dabd882decdc5579653d7868b8fb26dc7d0e99f823aa5924"}, {file = "PyYAML-6.0.1-cp310-cp310-win_amd64.whl", hash = "sha256:fd1592b3fdf65fff2ad0004b5e363300ef59ced41c2e6b3a99d4089fa8c5435d"}, {file = "PyYAML-6.0.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:6965a7bc3cf88e5a1c3bd2e0b5c22f8d677dc88a455344035f03399034eb3007"}, @@ -4578,8 +4591,15 @@ files = [ {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:42f8152b8dbc4fe7d96729ec2b99c7097d656dc1213a3229ca5383f973a5ed6d"}, {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:062582fca9fabdd2c8b54a3ef1c978d786e0f6b3a1510e0ac93ef59e0ddae2bc"}, {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d2b04aac4d386b172d5b9692e2d2da8de7bfb6c387fa4f801fbf6fb2e6ba4673"}, + {file = "PyYAML-6.0.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:e7d73685e87afe9f3b36c799222440d6cf362062f78be1013661b00c5c6f678b"}, {file = "PyYAML-6.0.1-cp311-cp311-win32.whl", hash = "sha256:1635fd110e8d85d55237ab316b5b011de701ea0f29d07611174a1b42f1444741"}, {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"}, + {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"}, + {file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"}, + {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"}, + {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"}, + {file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"}, + {file = "PyYAML-6.0.1-cp312-cp312-win_amd64.whl", hash = "sha256:0d3304d8c0adc42be59c5f8a4d9e3d7379e6955ad754aa9d6ab7a398b59dd1df"}, {file = "PyYAML-6.0.1-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:50550eb667afee136e9a77d6dc71ae76a44df8b3e51e41b77f6de2932bfe0f47"}, {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1fe35611261b29bd1de0070f0b2f47cb6ff71fa6595c077e42bd0c419fa27b98"}, {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:704219a11b772aea0d8ecd7058d0082713c3562b4e271b849ad7dc4a5c90c13c"}, @@ -4596,6 +4616,7 @@ files = [ {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a0cd17c15d3bb3fa06978b4e8958dcdc6e0174ccea823003a106c7d4d7899ac5"}, {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:28c119d996beec18c05208a8bd78cbe4007878c6dd15091efb73a30e90539696"}, {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7e07cbde391ba96ab58e532ff4803f79c4129397514e1413a7dc761ccd755735"}, + {file = "PyYAML-6.0.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:49a183be227561de579b4a36efbb21b3eab9651dd81b1858589f796549873dd6"}, {file = "PyYAML-6.0.1-cp38-cp38-win32.whl", hash = "sha256:184c5108a2aca3c5b3d3bf9395d50893a7ab82a38004c8f61c258d4428e80206"}, {file = "PyYAML-6.0.1-cp38-cp38-win_amd64.whl", hash = "sha256:1e2722cc9fbb45d9b87631ac70924c11d3a401b2d7f410cc0e3bbf249f2dca62"}, 
{file = "PyYAML-6.0.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:9eb6caa9a297fc2c2fb8862bc5370d0303ddba53ba97e71f08023b6cd73d16a8"}, @@ -4603,6 +4624,7 @@ files = [ {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5773183b6446b2c99bb77e77595dd486303b4faab2b086e7b17bc6bef28865f6"}, {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b786eecbdf8499b9ca1d697215862083bd6d2a99965554781d0d8d1ad31e13a0"}, {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bc1bf2925a1ecd43da378f4db9e4f799775d6367bdb94671027b73b393a7c42c"}, + {file = "PyYAML-6.0.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:04ac92ad1925b2cff1db0cfebffb6ffc43457495c9b3c39d3fcae417d7125dc5"}, {file = "PyYAML-6.0.1-cp39-cp39-win32.whl", hash = "sha256:faca3bdcf85b2fc05d06ff3fbc1f83e1391b3e724afa3feba7d13eeab355484c"}, {file = "PyYAML-6.0.1-cp39-cp39-win_amd64.whl", hash = "sha256:510c9deebc5c0225e8c96813043e62b680ba2f9c50a08d3724c7f28a747d1486"}, {file = "PyYAML-6.0.1.tar.gz", hash = "sha256:bfdf460b1736c775f2ba9f6a92bca30bc2095067b8a9d77876d1fad6cc3b4a43"}, @@ -6054,31 +6076,19 @@ reference = "torch-cpu" [[package]] name = "torchvision" -version = "0.16.1" +version = "0.16.1+cpu" description = "image and video datasets and models for torch deep learning" optional = false python-versions = ">=3.8" files = [ - {file = "torchvision-0.16.1-cp310-cp310-macosx_10_13_x86_64.whl", hash = "sha256:987132795e5c037cb74e7be35a693999fdb2f603152266ee15b80206e83a5b0c"}, - {file = "torchvision-0.16.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:25da6a7b22ea0348f62c45ec0daf157731096babcae65d222404081af96e085c"}, - {file = "torchvision-0.16.1-cp310-cp310-manylinux1_x86_64.whl", hash = "sha256:c82e291e674a18b67f92ddb476ae18498fb46d7032ae914f3fda90c955e7d86f"}, - {file = "torchvision-0.16.1-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:632887b22e67ce32a3ede806b868bba4057601e46d680de14b32a391eac1b483"}, - {file = "torchvision-0.16.1-cp310-cp310-win_amd64.whl", hash = "sha256:92c76a5092b4033efdb183b11fa4854a7630e23c46f4a1c3ffd70c30cb5be4fc"}, - {file = "torchvision-0.16.1-cp311-cp311-macosx_10_13_x86_64.whl", hash = "sha256:4aea5cf491c6c21b1cbdbb1bf2a3838a59d4db93ad5f49019a6564d3ca7127c7"}, - {file = "torchvision-0.16.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:3391757167637ace3ef33a67c9d5ef86b1f8cbd93eaa5bad45eebcf266ea6089"}, - {file = "torchvision-0.16.1-cp311-cp311-manylinux1_x86_64.whl", hash = "sha256:4f9d5b192b336982e6dbe32c070b05606f0b53e87d722ae332a02909fbf988ed"}, - {file = "torchvision-0.16.1-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:3d34601614958c4e30f53ec0eb7bf3f282ee72bb747734be2d75422831a43384"}, - {file = "torchvision-0.16.1-cp311-cp311-win_amd64.whl", hash = "sha256:e11af530585574eb5ca837b8f151bcdd57c10e35c3af56c76a10f3281d2a2f2c"}, - {file = "torchvision-0.16.1-cp38-cp38-macosx_10_13_x86_64.whl", hash = "sha256:4f2cad621fb96cf10e29af93e16c98b3226bdd53ae712b57e873c3deaf061617"}, - {file = "torchvision-0.16.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:1d614b3c9e2de9cd75cc0e4e1923fcfbbcd9fdb9f08a0bbbbf7e135e4a0a1cfa"}, - {file = "torchvision-0.16.1-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:73271e930501a008fe24ba38945b2a75b25a6098f4c2f4402e39a9d0dd305ca6"}, - {file = "torchvision-0.16.1-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:fab67ddc4809fcc2a04610b13cac5193b9d3be2896b77538bfdff401b13022e5"}, - {file = 
"torchvision-0.16.1-cp38-cp38-win_amd64.whl", hash = "sha256:13782d574033efec6646d1a2f5d85f4c59fcf3f403367bb407b15df07adc87e0"}, - {file = "torchvision-0.16.1-cp39-cp39-macosx_10_13_x86_64.whl", hash = "sha256:f14d201c37176dc4106eec76b229d6585a1505266b8cea99d3366fd38897b7c0"}, - {file = "torchvision-0.16.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:a15e88a93a7501cc75b761a2dcd07aaedaaf9cbfaf48c8affa8c98989ecbb19d"}, - {file = "torchvision-0.16.1-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:72fde5fdb462e66ebe25ae42d2ee11434cbc395f74cad0d3b22cf60524345cc5"}, - {file = "torchvision-0.16.1-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:153f753f14eba58969cdc86360893a57f8bf63f8136c7d1cd4388108560b5446"}, - {file = "torchvision-0.16.1-cp39-cp39-win_amd64.whl", hash = "sha256:75e33b198b1265f61d822aa66d646ec3df67a712470ffec1e0c37ff46d4103c1"}, + {file = "torchvision-0.16.1+cpu-cp310-cp310-linux_x86_64.whl", hash = "sha256:52442b5408dab35fedc6ccd75a7df38aadb6eb5d078184ac04ba7c1a9db8ae9c"}, + {file = "torchvision-0.16.1+cpu-cp310-cp310-win_amd64.whl", hash = "sha256:d77afeed3fe309f3f73943431840d40fa61a33a52076bbb780c63e6b5f79962a"}, + {file = "torchvision-0.16.1+cpu-cp311-cp311-linux_x86_64.whl", hash = "sha256:574b58e20ca89ebd2d8d1ff72dc6e6944734ed1327849214a349c39d5454ac4f"}, + {file = "torchvision-0.16.1+cpu-cp311-cp311-win_amd64.whl", hash = "sha256:f6fed6f1311d34c4751d70d60408ea8526842ede66f5aad7272d6de5c4337ebc"}, + {file = "torchvision-0.16.1+cpu-cp38-cp38-linux_x86_64.whl", hash = "sha256:53a5f20778bb300038d1846fb111e73ae5d4babc6b9cfb6a6cbeb2eefa399c0b"}, + {file = "torchvision-0.16.1+cpu-cp38-cp38-win_amd64.whl", hash = "sha256:89b9bf0e80fc4c45f114707d61e15741ce0e9591badf29aa939863be9fd9543d"}, + {file = "torchvision-0.16.1+cpu-cp39-cp39-linux_x86_64.whl", hash = "sha256:048b029410bc4c7ce87b1f2b621fd8b88249c68350ee773c3152faa088e9fcca"}, + {file = "torchvision-0.16.1+cpu-cp39-cp39-win_amd64.whl", hash = "sha256:73a4b3317eae425baabefe6b61197c75b0b766f1809c5cfdad07dd6bfad255a2"}, ] [package.dependencies] @@ -6090,6 +6100,11 @@ torch = "2.1.1" [package.extras] scipy = ["scipy"] +[package.source] +type = "legacy" +url = "https://download.pytorch.org/whl/cpu" +reference = "torch-cpu" + [[package]] name = "tornado" version = "6.3.3" @@ -7210,4 +7225,4 @@ descriptors = ["pycatch22"] [metadata] lock-version = "2.0" python-versions = ">=3.8, <3.12" -content-hash = "ade3838a00ad06a60e0e9e262ebf891327899fc914af5aa068d657f9461513e5" +content-hash = "e59f36657cc6a85482383e4d59ec7f7bb0ac6ce20c627369a9666b03ee8bbf5c" diff --git a/pyproject.toml b/pyproject.toml index fc1eefd5..708868b2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -93,7 +93,6 @@ datasets = { extras = ["audio"], version = "^2.12.0" } pydantic-settings = "^2.0.3" pycatch22 = { version = "!=0.4.4", optional = true } transformers = "^4.35.2" -torch = { version = "^2.1.1+cpu", source = "torch-cpu" } sentence-transformers = "^2.2.2" [tool.poetry.extras] @@ -128,11 +127,14 @@ types-pillow = "^10.0.0.1" pandas-stubs = "^2.0.2.230605" ruff = "^0.0.281" check-wheel-contents = "^0.6.0" +torchvision = {version = "^0.16.1+cpu", source = "torch-cpu"} +torch = {version = "^2.1.1+cpu", source = "torch-cpu"} [tool.poetry.group.playbook.dependencies] towhee = "^0.9.0" annoy = "^1.17.2" cleanlab = "^2.4.0" +torch = {version = "^2.1.1+cpu", source = "torch-cpu"} [[tool.poetry.source]] name = "torch-cpu" From 2580fd6bc4703bfd2f3e41c499d5445e1189f55a Mon Sep 17 00:00:00 2001 From: Alexander Druz Date: Fri, 1 Dec 2023 11:41:06 +0100 Subject: 
[PATCH 15/24] Use rewrite default embedders as functional embedders --- renumics/spotlight/data_store.py | 6 +- renumics/spotlight/embeddings/__init__.py | 13 +- renumics/spotlight/embeddings/decorator.py | 113 +++++------------- .../spotlight/embeddings/embedders/gte.py | 48 ++------ .../spotlight/embeddings/embedders/vit.py | 5 +- .../spotlight/embeddings/embedders/wav2vec.py | 75 ------------ .../embeddings/embedders/wav2vec2.py | 32 +++++ renumics/spotlight/embeddings/exceptions.py | 9 -- .../spotlight/embeddings/preprocessors.py | 33 +++++ renumics/spotlight/embeddings/registry.py | 18 ++- renumics/spotlight/embeddings/typing.py | 64 +++++++++- 11 files changed, 188 insertions(+), 228 deletions(-) delete mode 100644 renumics/spotlight/embeddings/embedders/wav2vec.py create mode 100644 renumics/spotlight/embeddings/embedders/wav2vec2.py delete mode 100644 renumics/spotlight/embeddings/exceptions.py create mode 100644 renumics/spotlight/embeddings/preprocessors.py diff --git a/renumics/spotlight/data_store.py b/renumics/spotlight/data_store.py index de656085..61f67d7b 100644 --- a/renumics/spotlight/data_store.py +++ b/renumics/spotlight/data_store.py @@ -132,9 +132,11 @@ def get_converted_values( return converted_values def get_converted_value( - self, column_name: str, index: int, simple: bool = False + self, column_name: str, index: int, simple: bool = False, check: bool = True ) -> ConvertedValue: - return self.get_converted_values(column_name, indices=[index], simple=simple)[0] + return self.get_converted_values( + column_name, indices=[index], simple=simple, check=check + )[0] def get_waveform(self, column_name: str, index: int) -> Optional[np.ndarray]: """ diff --git a/renumics/spotlight/embeddings/__init__.py b/renumics/spotlight/embeddings/__init__.py index 4c8dba2a..29aebf59 100644 --- a/renumics/spotlight/embeddings/__init__.py +++ b/renumics/spotlight/embeddings/__init__.py @@ -8,7 +8,6 @@ import numpy as np -from renumics.spotlight.embeddings.exceptions import CannotEmbed from renumics.spotlight.embeddings.typing import Embedder from .registry import registered_embedders from . 
import embedders as embedders_namespace @@ -24,14 +23,12 @@ def create_embedders(data_store: Any, columns: List[str]) -> Dict[str, Embedder] """ embedders: Dict[str, Embedder] = {} for column in columns: - for embedder_class in registered_embedders: - try: - # embedder = FunctionalEmbedder(func, preprocessor, data_store, column) - embedder = embedder_class(data_store, column) - except CannotEmbed: + for name, (embedder_class, dtype, args, kwargs) in registered_embedders.items(): + if data_store.dtypes[column].name != dtype.name: continue - embedders[f"{column}.embedding"] = embedder - break + + embedder = embedder_class(data_store, column, *args, **kwargs) + embedders[f"{column}.{name}.embedding"] = embedder return embedders diff --git a/renumics/spotlight/embeddings/decorator.py b/renumics/spotlight/embeddings/decorator.py index 00b775ae..62007f71 100644 --- a/renumics/spotlight/embeddings/decorator.py +++ b/renumics/spotlight/embeddings/decorator.py @@ -2,104 +2,49 @@ A decorator for data analysis functions """ -from typing import Callable, Iterable, List, Optional, Type, Any -import itertools -import io -import av -import numpy as np -import PIL.Image -from numpy._typing import DTypeLike -from numpy.lib import math +import functools +from typing import Callable, Dict, Optional, Any from renumics.spotlight import dtypes from renumics.spotlight.dtypes import create_dtype - -from renumics.spotlight.embeddings.exceptions import CannotEmbed -from .typing import Embedder +from renumics.spotlight.embeddings.preprocessors import ( + preprocess_audio_batch, + preprocess_batch, + preprocess_image_batch, +) +from .typing import EmbedFunc, FunctionalEmbedder from .registry import register_embedder -def embedder(klass: Type[Embedder]) -> Type[Embedder]: - """ - register an embedder class - """ - register_embedder(klass) - return klass - +def embed( + dtype: Any, *, name: Optional[str] = None, sampling_rate: Optional[int] = None +) -> Callable[[EmbedFunc], EmbedFunc]: + dtype = create_dtype(dtype) -def embed(accepts: DTypeLike, *, sampling_rate: Optional[int] = None): - dtype = create_dtype(accepts) + kwargs: Dict[str, Any] = {} if dtypes.is_image_dtype(dtype): - - def _preprocess_batch(raw_values: List[bytes]): - return [PIL.Image.open(io.BytesIO(value)) for value in raw_values] - + kwargs["preprocess_func"] = preprocess_image_batch elif dtypes.is_audio_dtype(dtype): if sampling_rate is None: raise ValueError( - "No sampling rate specified, but required for `audio` embedding." + "No sampling rate specified, but required for audio embedding." 
) - def _preprocess_batch(raw_values: Any): - resampled_batch = [] - for raw_data in raw_values: - with av.open(io.BytesIO(raw_data), "r") as container: - resampler = av.AudioResampler( - format="dbl", layout="mono", rate=16000 - ) - data = [] - for frame in container.decode(audio=0): - resampled_frames = resampler.resample(frame) - for resampled_frame in resampled_frames: - frame_array = resampled_frame.to_ndarray()[0] - data.append(frame_array) - resampled_batch.append(np.concatenate(data, axis=0)) - return resampled_batch - + kwargs["preprocess_func"] = functools.partial( + preprocess_audio_batch, sampling_rate=sampling_rate + ) else: - - def _preprocess_batch(raw_values: Any): - return raw_values - - def decorate( - func: Callable[[Iterable[list]], Iterable[List[Optional[np.ndarray]]]] - ): - class EmbedderImpl(Embedder): - def __init__(self, data_store: Any, column: str): - self.dtype = dtype - if data_store.dtypes[column].name != self.dtype.name: - raise CannotEmbed() - - self.data_store = data_store - self.column = column - self.batch_size = 16 - - self._occupied_indices = [] - - def _iter_batches(self): - self._occupied_indices = [] - batch = [] - for i in range(len(self.data_store)): - value = self.data_store.get_converted_value( - self.column, i, simple=False, check=False - ) - - if value is None: - continue - - self._occupied_indices.append(i) - batch.append(value) - if len(batch) == self.batch_size: - yield _preprocess_batch(batch) - batch = [] - - def __call__(self) -> np.ndarray: - embeddings = itertools.chain(*func(self._iter_batches())) - res = np.empty(len(self.data_store), dtype=np.object_) - res[self._occupied_indices] = list(embeddings) - return res - - register_embedder(EmbedderImpl) + kwargs["preprocess_func"] = preprocess_batch + + def decorate(func: EmbedFunc) -> EmbedFunc: + kwargs["embed_func"] = func + register_embedder( + FunctionalEmbedder, + dtype, + func.__name__ if name is None else name, + **kwargs, + ) + return func return decorate diff --git a/renumics/spotlight/embeddings/embedders/gte.py b/renumics/spotlight/embeddings/embedders/gte.py index 0b6af4c7..48e9f7f2 100644 --- a/renumics/spotlight/embeddings/embedders/gte.py +++ b/renumics/spotlight/embeddings/embedders/gte.py @@ -1,12 +1,9 @@ -from typing import Any, List +from typing import Iterable, List import numpy as np import sentence_transformers +from renumics.spotlight.embeddings.decorator import embed -from renumics.spotlight import dtypes -from renumics.spotlight.embeddings.decorator import embedder -from renumics.spotlight.embeddings.exceptions import CannotEmbed -from renumics.spotlight.embeddings.typing import Embedder from renumics.spotlight.logging import logger try: @@ -15,36 +12,13 @@ logger.warning("`GTE Embedder` requires `pytorch` to be installed.") else: - @embedder - class GteEmbedder(Embedder): - def __init__(self, data_store: Any, column: str) -> None: - if not dtypes.is_str_dtype(data_store.dtypes[column]): - raise CannotEmbed - self._data_store = data_store - self._column = column + @embed("str") + def gte(batches: Iterable[List[str]]) -> Iterable[List[np.ndarray]]: + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + model = sentence_transformers.SentenceTransformer( + "thenlper/gte-base", device=device + ) - def __call__(self) -> np.ndarray: - values = self._data_store.get_converted_values( - self._column, indices=slice(None), simple=False, check=False - ) - none_mask = [sample is None for sample in values] - if all(none_mask): - return np.array([None] * 
len(values), dtype=np.object_) - - device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") - - model = sentence_transformers.SentenceTransformer( - "thenlper/gte-base", device=device - ) - - def _embed_batch(batch: List[str]) -> np.ndarray: - return model.encode(batch, normalize_embeddings=True) - - embeddings = _embed_batch([value for value in values if value is not None]) - - if any(none_mask): - res = np.empty(len(values), dtype=np.object_) - res[np.nonzero(~np.array(none_mask))[0]] = list(embeddings) - return res - - return embeddings + for batch in batches: + embeddings = model.encode(batch, normalize_embeddings=True) + yield list(embeddings) diff --git a/renumics/spotlight/embeddings/embedders/vit.py b/renumics/spotlight/embeddings/embedders/vit.py index 8221ed83..bcdd0cf1 100644 --- a/renumics/spotlight/embeddings/embedders/vit.py +++ b/renumics/spotlight/embeddings/embedders/vit.py @@ -1,6 +1,7 @@ from typing import Iterable, List import PIL.Image +import numpy as np import transformers from renumics.spotlight.embeddings.decorator import embed @@ -13,7 +14,7 @@ else: @embed("image") - def vit(batches: Iterable[List[PIL.Image.Image]]): + def vit(batches: Iterable[List[PIL.Image.Image]]) -> Iterable[List[np.ndarray]]: device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") model_name = "google/vit-base-patch16-224" processor = transformers.AutoImageProcessor.from_pretrained(model_name) @@ -26,4 +27,4 @@ def vit(batches: Iterable[List[PIL.Image.Image]]): outputs = model(**inputs.to(device)) embeddings = outputs.last_hidden_state[:, 0].cpu().numpy() - yield embeddings + yield list(embeddings) diff --git a/renumics/spotlight/embeddings/embedders/wav2vec.py b/renumics/spotlight/embeddings/embedders/wav2vec.py deleted file mode 100644 index 22490215..00000000 --- a/renumics/spotlight/embeddings/embedders/wav2vec.py +++ /dev/null @@ -1,75 +0,0 @@ -from typing import Any, List - -import numpy as np -import transformers -import av -import io - -from renumics.spotlight import dtypes -from renumics.spotlight.embeddings.decorator import embedder -from renumics.spotlight.embeddings.exceptions import CannotEmbed -from renumics.spotlight.embeddings.typing import Embedder -from renumics.spotlight.logging import logger - -try: - import torch -except ImportError: - logger.warning("`Wav2Vec Embedder` requires `pytorch` to be installed.") -else: - - @embedder - class Wav2VecEmbedder(Embedder): - def __init__(self, data_store: Any, column: str) -> None: - if not dtypes.is_audio_dtype(data_store.dtypes[column]): - raise CannotEmbed - self._data_store = data_store - self._column = column - - def __call__(self) -> np.ndarray: - device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") - model_name = "facebook/wav2vec2-base-960h" - sampling_rate = 16000 - processor = transformers.AutoFeatureExtractor.from_pretrained(model_name) - model = transformers.AutoModel.from_pretrained(model_name).to(device) - - values = self._data_store.get_converted_values( - self._column, indices=slice(None), simple=False, check=False - ) - none_mask = [sample is None for sample in values] - if all(none_mask): - return np.array([None] * len(values), dtype=np.object_) - - def _embed_batch(batch: List[bytes]) -> np.ndarray: - resampled_batch = [] - for raw_data in batch: - with av.open(io.BytesIO(raw_data), "r") as container: - resampler = av.AudioResampler( - format="dbl", layout="mono", rate=16000 - ) - data = [] - for frame in container.decode(audio=0): - resampled_frames = 
resampler.resample(frame) - for resampled_frame in resampled_frames: - frame_array = resampled_frame.to_ndarray()[0] - data.append(frame_array) - resampled_batch.append(np.concatenate(data, axis=0)) - - inputs = processor( - raw_speech=resampled_batch, - sampling_rate=sampling_rate, - padding="longest", - return_tensors="pt", - ) - with torch.no_grad(): - outputs = model(**inputs.to(device)) - embeddings = outputs.last_hidden_state[:, 0].cpu().numpy() - return embeddings - - embeddings = _embed_batch([value for value in values if value is not None]) - - if any(none_mask): - res = np.empty(len(values), dtype=np.object_) - res[np.nonzero(~np.array(none_mask))[0]] = list(embeddings) - return res - - return embeddings diff --git a/renumics/spotlight/embeddings/embedders/wav2vec2.py b/renumics/spotlight/embeddings/embedders/wav2vec2.py new file mode 100644 index 00000000..339fc208 --- /dev/null +++ b/renumics/spotlight/embeddings/embedders/wav2vec2.py @@ -0,0 +1,32 @@ +from typing import Iterable, List + +import numpy as np +import transformers + +from renumics.spotlight.embeddings.decorator import embed +from renumics.spotlight.logging import logger + +try: + import torch +except ImportError: + logger.warning("`Wav2Vec Embedder` requires `pytorch` to be installed.") +else: + + @embed("audio", sampling_rate=16000) + def wav2vec2(batches: Iterable[List[np.ndarray]]) -> Iterable[List[np.ndarray]]: + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + model_name = "facebook/wav2vec2-base-960h" + processor = transformers.AutoFeatureExtractor.from_pretrained(model_name) + model = transformers.AutoModel.from_pretrained(model_name).to(device) + + for batch in batches: + inputs = processor( + raw_speech=batch, + sampling_rate=16000, + padding="longest", + return_tensors="pt", + ) + with torch.no_grad(): + outputs = model(**inputs.to(device)) + embeddings = outputs.last_hidden_state[:, 0].cpu().numpy() + yield list(embeddings) diff --git a/renumics/spotlight/embeddings/exceptions.py b/renumics/spotlight/embeddings/exceptions.py deleted file mode 100644 index 92c57bf7..00000000 --- a/renumics/spotlight/embeddings/exceptions.py +++ /dev/null @@ -1,9 +0,0 @@ -""" -Exceptions used by embedders. -""" - - -class CannotEmbed(Exception): - """ - Raised when a column cannot be embed by an embedder. 
- """ diff --git a/renumics/spotlight/embeddings/preprocessors.py b/renumics/spotlight/embeddings/preprocessors.py new file mode 100644 index 00000000..2ca446fd --- /dev/null +++ b/renumics/spotlight/embeddings/preprocessors.py @@ -0,0 +1,33 @@ +import io +from typing import List + +import PIL.Image +import av +import numpy as np + + +def preprocess_batch(raw_values: list) -> list: + return raw_values + + +def preprocess_image_batch(raw_values: List[bytes]) -> List[PIL.Image.Image]: + return [PIL.Image.open(io.BytesIO(value)) for value in raw_values] + + +def preprocess_audio_batch( + raw_values: List[bytes], sampling_rate: int +) -> List[np.ndarray]: + resampled_batch = [] + for raw_data in raw_values: + with av.open(io.BytesIO(raw_data), "r") as container: + resampler = av.AudioResampler( + format="dbl", layout="mono", rate=sampling_rate + ) + data = [] + for frame in container.decode(audio=0): + resampled_frames = resampler.resample(frame) + for resampled_frame in resampled_frames: + frame_array = resampled_frame.to_ndarray()[0] + data.append(frame_array) + resampled_batch.append(np.concatenate(data, axis=0)) + return resampled_batch diff --git a/renumics/spotlight/embeddings/registry.py b/renumics/spotlight/embeddings/registry.py index eaab1f9a..ddb2003f 100644 --- a/renumics/spotlight/embeddings/registry.py +++ b/renumics/spotlight/embeddings/registry.py @@ -1,22 +1,28 @@ """ Manage data analyzers available for spotlights automatic dataset analysis. """ -from typing import Set, Type +from typing import Any, Dict, Tuple, Type + +from renumics.spotlight.dtypes import DType from .typing import Embedder -registered_embedders: Set[Type[Embedder]] = set() +registered_embedders: Dict[ + str, Tuple[Type[Embedder], DType, tuple, Dict[str, Any]] +] = {} -def register_embedder(embedder: Type[Embedder]) -> None: +def register_embedder( + embedder: Type[Embedder], dtype: DType, name: str, *args: Any, **kwargs: Any +) -> None: """ Register an embedder """ - registered_embedders.add(embedder) + registered_embedders[name] = (embedder, dtype, args, kwargs) -def unregister_embedder(embedder: Type[Embedder]) -> None: +def unregister_embedder(embedder: str) -> None: """ Unregister an embedder """ - registered_embedders.remove(embedder) + del registered_embedders[embedder] diff --git a/renumics/spotlight/embeddings/typing.py b/renumics/spotlight/embeddings/typing.py index a5dacc5f..078efdcc 100644 --- a/renumics/spotlight/embeddings/typing.py +++ b/renumics/spotlight/embeddings/typing.py @@ -2,21 +2,75 @@ Shared types for embeddings """ +import itertools from abc import ABC, abstractmethod -from typing import Any +from typing import Any, Callable, Iterable, List, Optional import numpy as np class Embedder(ABC): - @abstractmethod + data_store: Any + column: str + def __init__(self, data_store: Any, column: str) -> None: - """ - Raise if dtype of the given column is not supported. - """ + self.data_store = data_store + self.column = column @abstractmethod def __call__(self) -> np.ndarray: """ Embed the given column. 
""" + + +PreprocessFunc = Callable[[list], list] +EmbedFunc = Callable[[Iterable[list]], Iterable[List[Optional[np.ndarray]]]] + + +class FunctionalEmbedder(Embedder): + preprocess_func: PreprocessFunc + embed_func: EmbedFunc + batch_size: int + _occupied_indices: List[int] + + def __init__( + self, + data_store: Any, + column: str, + preprocess_func: PreprocessFunc, + embed_func: EmbedFunc, + ) -> None: + super().__init__(data_store, column) + self.preprocess_func = preprocess_func + self.embed_func = embed_func + self.batch_size = 16 + self._occupied_indices = [] + + def _iter_batches(self) -> Iterable[list]: + """ + Yield batches with data, i.e. without the `None` values. + """ + self._occupied_indices = [] + batch = [] + for i in range(len(self.data_store)): + value = self.data_store.get_converted_value( + self.column, i, simple=False, check=False + ) + + if value is None: + continue + + self._occupied_indices.append(i) + batch.append(value) + if len(batch) == self.batch_size: + yield self.preprocess_func(batch) + batch = [] + if batch: + yield self.preprocess_func(batch) + + def __call__(self) -> np.ndarray: + embeddings = itertools.chain(*self.embed_func(self._iter_batches())) + res = np.empty(len(self.data_store), dtype=np.object_) + res[self._occupied_indices] = list(embeddings) + return res From a70ecf4f05f79ec0742286facc683df18e4beeb2 Mon Sep 17 00:00:00 2001 From: Alexander Druz Date: Fri, 1 Dec 2023 11:48:41 +0100 Subject: [PATCH 16/24] Type data store for embedders properly --- renumics/spotlight/embeddings/__init__.py | 5 +++-- renumics/spotlight/embeddings/typing.py | 10 ++++++---- 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/renumics/spotlight/embeddings/__init__.py b/renumics/spotlight/embeddings/__init__.py index 29aebf59..51a55431 100644 --- a/renumics/spotlight/embeddings/__init__.py +++ b/renumics/spotlight/embeddings/__init__.py @@ -4,9 +4,10 @@ import importlib import pkgutil -from typing import Any, Dict, List +from typing import Dict, List import numpy as np +from renumics.spotlight.data_store import DataStore from renumics.spotlight.embeddings.typing import Embedder from .registry import registered_embedders @@ -17,7 +18,7 @@ importlib.import_module(embedders_namespace.__name__ + "." + module_info.name) -def create_embedders(data_store: Any, columns: List[str]) -> Dict[str, Embedder]: +def create_embedders(data_store: DataStore, columns: List[str]) -> Dict[str, Embedder]: """ Create embedding functions for the given data store. 
""" diff --git a/renumics/spotlight/embeddings/typing.py b/renumics/spotlight/embeddings/typing.py index 078efdcc..981c8dd1 100644 --- a/renumics/spotlight/embeddings/typing.py +++ b/renumics/spotlight/embeddings/typing.py @@ -4,16 +4,18 @@ import itertools from abc import ABC, abstractmethod -from typing import Any, Callable, Iterable, List, Optional +from typing import Callable, Iterable, List, Optional import numpy as np +from renumics.spotlight.data_store import DataStore + class Embedder(ABC): - data_store: Any + data_store: DataStore column: str - def __init__(self, data_store: Any, column: str) -> None: + def __init__(self, data_store: DataStore, column: str) -> None: self.data_store = data_store self.column = column @@ -36,7 +38,7 @@ class FunctionalEmbedder(Embedder): def __init__( self, - data_store: Any, + data_store: DataStore, column: str, preprocess_func: PreprocessFunc, embed_func: EmbedFunc, From ba1adfc9893e1adb6a356219feb8bfbfdc3ba570 Mon Sep 17 00:00:00 2001 From: Alexander Druz Date: Fri, 1 Dec 2023 13:37:14 +0100 Subject: [PATCH 17/24] Overload embed decorator --- renumics/spotlight/embeddings/decorator.py | 71 +++++++++++++++++-- .../spotlight/embeddings/embedders/gte.py | 6 +- .../spotlight/embeddings/embedders/vit.py | 8 ++- .../embeddings/embedders/wav2vec2.py | 8 ++- 4 files changed, 79 insertions(+), 14 deletions(-) diff --git a/renumics/spotlight/embeddings/decorator.py b/renumics/spotlight/embeddings/decorator.py index 62007f71..f61b9d2e 100644 --- a/renumics/spotlight/embeddings/decorator.py +++ b/renumics/spotlight/embeddings/decorator.py @@ -3,23 +3,84 @@ """ import functools -from typing import Callable, Dict, Optional, Any +from typing import ( + Callable, + Dict, + Iterable, + List, + Literal, + Optional, + Any, + Union, + overload, +) + +import PIL.Image +import numpy as np from renumics.spotlight import dtypes +from renumics.spotlight.media.audio import Audio +from renumics.spotlight.media.embedding import Embedding -from renumics.spotlight.dtypes import create_dtype -from renumics.spotlight.embeddings.preprocessors import ( +from renumics.spotlight.media.image import Image +from renumics.spotlight.media.sequence_1d import Sequence1D +from .preprocessors import ( preprocess_audio_batch, preprocess_batch, preprocess_image_batch, ) -from .typing import EmbedFunc, FunctionalEmbedder from .registry import register_embedder +from .typing import EmbedFunc, FunctionalEmbedder + + +EmbedImageFunc = Callable[ + [Iterable[List[PIL.Image.Image]]], Iterable[List[Optional[np.ndarray]]] +] +EmbedArrayFunc = Callable[ + [Iterable[List[np.ndarray]]], Iterable[List[Optional[np.ndarray]]] +] + + +@overload +def embed( + dtype: Union[Literal["image", "Image"], Image], *, name: Optional[str] = None +) -> Callable[[EmbedImageFunc], EmbedImageFunc]: + ... + + +@overload +def embed( + dtype: Union[Literal["audio", "Audio"], Audio], + *, + name: Optional[str] = None, + sampling_rate: int, +) -> Callable[[EmbedArrayFunc], EmbedArrayFunc]: + ... + + +@overload +def embed( + dtype: Union[ + Literal["embedding", "Embedding", "sequence1d", "Sequence1D"], + Embedding, + Sequence1D, + ], + *, + name: Optional[str] = None, +) -> Callable[[EmbedArrayFunc], EmbedArrayFunc]: + ... + + +@overload +def embed( + dtype: Any, *, name: Optional[str] = None, sampling_rate: Optional[int] = None +) -> Callable[[EmbedFunc], EmbedFunc]: + ... 
def embed( dtype: Any, *, name: Optional[str] = None, sampling_rate: Optional[int] = None ) -> Callable[[EmbedFunc], EmbedFunc]: - dtype = create_dtype(dtype) + dtype = dtypes.create_dtype(dtype) kwargs: Dict[str, Any] = {} diff --git a/renumics/spotlight/embeddings/embedders/gte.py b/renumics/spotlight/embeddings/embedders/gte.py index 48e9f7f2..c5486a86 100644 --- a/renumics/spotlight/embeddings/embedders/gte.py +++ b/renumics/spotlight/embeddings/embedders/gte.py @@ -1,4 +1,4 @@ -from typing import Iterable, List +from typing import Iterable, List, Optional import numpy as np import sentence_transformers @@ -9,11 +9,11 @@ try: import torch except ImportError: - logger.warning("`GTE Embedder` requires `pytorch` to be installed.") + logger.warning("GTE embedder requires `pytorch` to be installed.") else: @embed("str") - def gte(batches: Iterable[List[str]]) -> Iterable[List[np.ndarray]]: + def gte(batches: Iterable[List[str]]) -> Iterable[List[Optional[np.ndarray]]]: device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") model = sentence_transformers.SentenceTransformer( "thenlper/gte-base", device=device diff --git a/renumics/spotlight/embeddings/embedders/vit.py b/renumics/spotlight/embeddings/embedders/vit.py index bcdd0cf1..4995a135 100644 --- a/renumics/spotlight/embeddings/embedders/vit.py +++ b/renumics/spotlight/embeddings/embedders/vit.py @@ -1,4 +1,4 @@ -from typing import Iterable, List +from typing import Iterable, List, Optional import PIL.Image import numpy as np @@ -10,11 +10,13 @@ try: import torch except ImportError: - logger.warning("`ViTEmbedder` requires `pytorch` to be installed.") + logger.warning("ViT embedder requires `pytorch` to be installed.") else: @embed("image") - def vit(batches: Iterable[List[PIL.Image.Image]]) -> Iterable[List[np.ndarray]]: + def vit( + batches: Iterable[List[PIL.Image.Image]], + ) -> Iterable[List[Optional[np.ndarray]]]: device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") model_name = "google/vit-base-patch16-224" processor = transformers.AutoImageProcessor.from_pretrained(model_name) diff --git a/renumics/spotlight/embeddings/embedders/wav2vec2.py b/renumics/spotlight/embeddings/embedders/wav2vec2.py index 339fc208..8cc136d2 100644 --- a/renumics/spotlight/embeddings/embedders/wav2vec2.py +++ b/renumics/spotlight/embeddings/embedders/wav2vec2.py @@ -1,4 +1,4 @@ -from typing import Iterable, List +from typing import Iterable, List, Optional import numpy as np import transformers @@ -9,11 +9,13 @@ try: import torch except ImportError: - logger.warning("`Wav2Vec Embedder` requires `pytorch` to be installed.") + logger.warning("Wav2Vec embedder requires `pytorch` to be installed.") else: @embed("audio", sampling_rate=16000) - def wav2vec2(batches: Iterable[List[np.ndarray]]) -> Iterable[List[np.ndarray]]: + def wav2vec2( + batches: Iterable[List[np.ndarray]], + ) -> Iterable[List[Optional[np.ndarray]]]: device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") model_name = "facebook/wav2vec2-base-960h" processor = transformers.AutoFeatureExtractor.from_pretrained(model_name) From 8d43e82197660a3f659e37a864851140f0124983 Mon Sep 17 00:00:00 2001 From: Alexander Druz Date: Fri, 1 Dec 2023 16:54:01 +0100 Subject: [PATCH 18/24] Get rid of the `sentence-transformers` since it installs `torchvision` and `torch` --- poetry.lock | 103 +----------------- pyproject.toml | 1 - .../spotlight/embeddings/embedders/gte.py | 31 +++++- 3 files changed, 26 insertions(+), 109 deletions(-) diff --git 
a/poetry.lock b/poetry.lock index 8e3937c0..581e451b 100644 --- a/poetry.lock +++ b/poetry.lock @@ -3290,31 +3290,6 @@ files = [ {file = "nh3-0.2.14.tar.gz", hash = "sha256:a0c509894fd4dccdff557068e5074999ae3b75f4c5a2d6fb5415e782e25679c4"}, ] -[[package]] -name = "nltk" -version = "3.8.1" -description = "Natural Language Toolkit" -optional = false -python-versions = ">=3.7" -files = [ - {file = "nltk-3.8.1-py3-none-any.whl", hash = "sha256:fd5c9109f976fa86bcadba8f91e47f5e9293bd034474752e92a520f81c93dda5"}, - {file = "nltk-3.8.1.zip", hash = "sha256:1834da3d0682cba4f2cede2f9aad6b0fafb6461ba451db0efb6f9c39798d64d3"}, -] - -[package.dependencies] -click = "*" -joblib = "*" -regex = ">=2021.8.3" -tqdm = "*" - -[package.extras] -all = ["matplotlib", "numpy", "pyparsing", "python-crfsuite", "requests", "scikit-learn", "scipy", "twython"] -corenlp = ["requests"] -machine-learning = ["numpy", "python-crfsuite", "scikit-learn", "scipy"] -plot = ["matplotlib"] -tgrep = ["pyparsing"] -twitter = ["twython"] - [[package]] name = "nodeenv" version = "1.8.0" @@ -5472,82 +5447,6 @@ nativelib = ["pyobjc-framework-Cocoa", "pywin32"] objc = ["pyobjc-framework-Cocoa"] win32 = ["pywin32"] -[[package]] -name = "sentence-transformers" -version = "2.2.2" -description = "Multilingual text embeddings" -optional = false -python-versions = ">=3.6.0" -files = [ - {file = "sentence-transformers-2.2.2.tar.gz", hash = "sha256:dbc60163b27de21076c9a30d24b5b7b6fa05141d68cf2553fa9a77bf79a29136"}, -] - -[package.dependencies] -huggingface-hub = ">=0.4.0" -nltk = "*" -numpy = "*" -scikit-learn = "*" -scipy = "*" -sentencepiece = "*" -torch = ">=1.6.0" -torchvision = "*" -tqdm = "*" -transformers = ">=4.6.0,<5.0.0" - -[[package]] -name = "sentencepiece" -version = "0.1.99" -description = "SentencePiece python wrapper" -optional = false -python-versions = "*" -files = [ - {file = "sentencepiece-0.1.99-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:0eb528e70571b7c02723e5804322469b82fe7ea418c96051d0286c0fa028db73"}, - {file = "sentencepiece-0.1.99-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:77d7fafb2c4e4659cbdf303929503f37a26eabc4ff31d3a79bf1c5a1b338caa7"}, - {file = "sentencepiece-0.1.99-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:be9cf5b9e404c245aeb3d3723c737ba7a8f5d4ba262ef233a431fa6c45f732a0"}, - {file = "sentencepiece-0.1.99-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:baed1a26464998f9710d20e52607c29ffd4293e7c71c6a1f83f51ad0911ec12c"}, - {file = "sentencepiece-0.1.99-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9832f08bb372d4c8b567612f8eab9e36e268dff645f1c28f9f8e851be705f6d1"}, - {file = "sentencepiece-0.1.99-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:019e7535108e309dae2b253a75834fc3128240aa87c00eb80732078cdc182588"}, - {file = "sentencepiece-0.1.99-cp310-cp310-win32.whl", hash = "sha256:fa16a830416bb823fa2a52cbdd474d1f7f3bba527fd2304fb4b140dad31bb9bc"}, - {file = "sentencepiece-0.1.99-cp310-cp310-win_amd64.whl", hash = "sha256:14b0eccb7b641d4591c3e12ae44cab537d68352e4d3b6424944f0c447d2348d5"}, - {file = "sentencepiece-0.1.99-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:6d3c56f24183a1e8bd61043ff2c58dfecdc68a5dd8955dc13bab83afd5f76b81"}, - {file = "sentencepiece-0.1.99-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:ed6ea1819fd612c989999e44a51bf556d0ef6abfb553080b9be3d347e18bcfb7"}, - {file = "sentencepiece-0.1.99-cp311-cp311-macosx_11_0_arm64.whl", hash = 
"sha256:a2a0260cd1fb7bd8b4d4f39dc2444a8d5fd4e0a0c4d5c899810ef1abf99b2d45"}, - {file = "sentencepiece-0.1.99-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8a1abff4d1ff81c77cac3cc6fefa34fa4b8b371e5ee51cb7e8d1ebc996d05983"}, - {file = "sentencepiece-0.1.99-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:004e6a621d4bc88978eecb6ea7959264239a17b70f2cbc348033d8195c9808ec"}, - {file = "sentencepiece-0.1.99-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:db361e03342c41680afae5807590bc88aa0e17cfd1a42696a160e4005fcda03b"}, - {file = "sentencepiece-0.1.99-cp311-cp311-win32.whl", hash = "sha256:2d95e19168875b70df62916eb55428a0cbcb834ac51d5a7e664eda74def9e1e0"}, - {file = "sentencepiece-0.1.99-cp311-cp311-win_amd64.whl", hash = "sha256:f90d73a6f81248a909f55d8e6ef56fec32d559e1e9af045f0b0322637cb8e5c7"}, - {file = "sentencepiece-0.1.99-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:62e24c81e74bd87a6e0d63c51beb6527e4c0add67e1a17bac18bcd2076afcfeb"}, - {file = "sentencepiece-0.1.99-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:57efcc2d51caff20d9573567d9fd3f854d9efe613ed58a439c78c9f93101384a"}, - {file = "sentencepiece-0.1.99-cp36-cp36m-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6a904c46197993bd1e95b93a6e373dca2f170379d64441041e2e628ad4afb16f"}, - {file = "sentencepiece-0.1.99-cp36-cp36m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d89adf59854741c0d465f0e1525b388c0d174f611cc04af54153c5c4f36088c4"}, - {file = "sentencepiece-0.1.99-cp36-cp36m-win32.whl", hash = "sha256:47c378146928690d1bc106fdf0da768cebd03b65dd8405aa3dd88f9c81e35dba"}, - {file = "sentencepiece-0.1.99-cp36-cp36m-win_amd64.whl", hash = "sha256:9ba142e7a90dd6d823c44f9870abdad45e6c63958eb60fe44cca6828d3b69da2"}, - {file = "sentencepiece-0.1.99-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:b7b1a9ae4d7c6f1f867e63370cca25cc17b6f4886729595b885ee07a58d3cec3"}, - {file = "sentencepiece-0.1.99-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d0f644c9d4d35c096a538507b2163e6191512460035bf51358794a78515b74f7"}, - {file = "sentencepiece-0.1.99-cp37-cp37m-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c8843d23a0f686d85e569bd6dcd0dd0e0cbc03731e63497ca6d5bacd18df8b85"}, - {file = "sentencepiece-0.1.99-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:33e6f690a1caebb4867a2e367afa1918ad35be257ecdb3455d2bbd787936f155"}, - {file = "sentencepiece-0.1.99-cp37-cp37m-win32.whl", hash = "sha256:8a321866c2f85da7beac74a824b4ad6ddc2a4c9bccd9382529506d48f744a12c"}, - {file = "sentencepiece-0.1.99-cp37-cp37m-win_amd64.whl", hash = "sha256:c42f753bcfb7661c122a15b20be7f684b61fc8592c89c870adf52382ea72262d"}, - {file = "sentencepiece-0.1.99-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:85b476406da69c70586f0bb682fcca4c9b40e5059814f2db92303ea4585c650c"}, - {file = "sentencepiece-0.1.99-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:cfbcfe13c69d3f87b7fcd5da168df7290a6d006329be71f90ba4f56bc77f8561"}, - {file = "sentencepiece-0.1.99-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:445b0ec381af1cd4eef95243e7180c63d9c384443c16c4c47a28196bd1cda937"}, - {file = "sentencepiece-0.1.99-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c6890ea0f2b4703f62d0bf27932e35808b1f679bdb05c7eeb3812b935ba02001"}, - {file = "sentencepiece-0.1.99-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = 
"sha256:fb71af492b0eefbf9f2501bec97bcd043b6812ab000d119eaf4bd33f9e283d03"}, - {file = "sentencepiece-0.1.99-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:27b866b5bd3ddd54166bbcbf5c8d7dd2e0b397fac8537991c7f544220b1f67bc"}, - {file = "sentencepiece-0.1.99-cp38-cp38-win32.whl", hash = "sha256:b133e8a499eac49c581c3c76e9bdd08c338cc1939e441fee6f92c0ccb5f1f8be"}, - {file = "sentencepiece-0.1.99-cp38-cp38-win_amd64.whl", hash = "sha256:0eaf3591dd0690a87f44f4df129cf8d05d8a4029b5b6709b489b8e27f9a9bcff"}, - {file = "sentencepiece-0.1.99-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:38efeda9bbfb55052d482a009c6a37e52f42ebffcea9d3a98a61de7aee356a28"}, - {file = "sentencepiece-0.1.99-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:6c030b081dc1e1bcc9fadc314b19b740715d3d566ad73a482da20d7d46fd444c"}, - {file = "sentencepiece-0.1.99-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:84dbe53e02e4f8a2e45d2ac3e430d5c83182142658e25edd76539b7648928727"}, - {file = "sentencepiece-0.1.99-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0b0f55d0a0ee1719b4b04221fe0c9f0c3461dc3dabd77a035fa2f4788eb3ef9a"}, - {file = "sentencepiece-0.1.99-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:18e800f206cd235dc27dc749299e05853a4e4332e8d3dfd81bf13d0e5b9007d9"}, - {file = "sentencepiece-0.1.99-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2ae1c40cda8f9d5b0423cfa98542735c0235e7597d79caf318855cdf971b2280"}, - {file = "sentencepiece-0.1.99-cp39-cp39-win32.whl", hash = "sha256:c84ce33af12ca222d14a1cdd37bd76a69401e32bc68fe61c67ef6b59402f4ab8"}, - {file = "sentencepiece-0.1.99-cp39-cp39-win_amd64.whl", hash = "sha256:350e5c74d739973f1c9643edb80f7cc904dc948578bcb1d43c6f2b173e5d18dd"}, - {file = "sentencepiece-0.1.99.tar.gz", hash = "sha256:189c48f5cb2949288f97ccdb97f0473098d9c3dcf5a3d99d4eabe719ec27297f"}, -] - [[package]] name = "setuptools" version = "68.2.2" @@ -7225,4 +7124,4 @@ descriptors = ["pycatch22"] [metadata] lock-version = "2.0" python-versions = ">=3.8, <3.12" -content-hash = "e59f36657cc6a85482383e4d59ec7f7bb0ac6ce20c627369a9666b03ee8bbf5c" +content-hash = "49f1495035945dd4c6e404de9ea4512302b3ab96e86e98f9a3da370469ca7a89" diff --git a/pyproject.toml b/pyproject.toml index 708868b2..37f5193e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -93,7 +93,6 @@ datasets = { extras = ["audio"], version = "^2.12.0" } pydantic-settings = "^2.0.3" pycatch22 = { version = "!=0.4.4", optional = true } transformers = "^4.35.2" -sentence-transformers = "^2.2.2" [tool.poetry.extras] descriptors = ["pycatch22"] diff --git a/renumics/spotlight/embeddings/embedders/gte.py b/renumics/spotlight/embeddings/embedders/gte.py index c5486a86..5f535b6a 100644 --- a/renumics/spotlight/embeddings/embedders/gte.py +++ b/renumics/spotlight/embeddings/embedders/gte.py @@ -1,24 +1,43 @@ from typing import Iterable, List, Optional import numpy as np -import sentence_transformers -from renumics.spotlight.embeddings.decorator import embed +import transformers from renumics.spotlight.logging import logger +from renumics.spotlight.embeddings.decorator import embed try: import torch + import torch.nn.functional as F except ImportError: logger.warning("GTE embedder requires `pytorch` to be installed.") else: + def average_pool( + last_hidden_state: torch.Tensor, attention_mask: torch.Tensor + ) -> torch.Tensor: + masked_last_hidden_state = last_hidden_state.masked_fill( + ~attention_mask[..., None].bool(), 0.0 + ) + return ( + 
masked_last_hidden_state.sum(dim=1) / attention_mask.sum(dim=1)[..., None] + ) + @embed("str") def gte(batches: Iterable[List[str]]) -> Iterable[List[Optional[np.ndarray]]]: device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") - model = sentence_transformers.SentenceTransformer( - "thenlper/gte-base", device=device - ) + model_name = "thenlper/gte-base" + tokenizer = transformers.AutoTokenizer.from_pretrained(model_name) + model = transformers.AutoModel.from_pretrained(model_name).to(device) for batch in batches: - embeddings = model.encode(batch, normalize_embeddings=True) + inputs = tokenizer( + batch, padding=True, truncation=True, return_tensors="pt" + ) + with torch.no_grad(): + outputs = model(**inputs) + embeddings = average_pool( + outputs.last_hidden_state, inputs["attention_mask"] + ) + embeddings = F.normalize(embeddings, p=2, dim=1).cpu().numpy() yield list(embeddings) From 9219e016a531497bd9e733205559a5123449d024 Mon Sep 17 00:00:00 2001 From: Alexander Druz Date: Sat, 2 Dec 2023 09:43:38 +0100 Subject: [PATCH 19/24] Add doctrings to embedding module --- renumics/spotlight/__init__.py | 3 +- renumics/spotlight/app.py | 4 +- renumics/spotlight/embeddings/__init__.py | 5 +- renumics/spotlight/embeddings/decorator.py | 70 ++++++++++++------- .../spotlight/embeddings/embedders/gte.py | 6 +- .../spotlight/embeddings/preprocessors.py | 15 ++++ renumics/spotlight/embeddings/registry.py | 8 +-- renumics/spotlight/embeddings/typing.py | 30 +++++++- 8 files changed, 104 insertions(+), 37 deletions(-) diff --git a/renumics/spotlight/__init__.py b/renumics/spotlight/__init__.py index 5fd91702..6a1b0b7b 100644 --- a/renumics/spotlight/__init__.py +++ b/renumics/spotlight/__init__.py @@ -79,6 +79,7 @@ from .plugin_loader import load_plugins from .settings import settings from .analysis.typing import DataIssue +from .embeddings.decorator import embed # noqa: F401 from . import cache, logging if not settings.verbose: @@ -86,7 +87,7 @@ __plugins__ = load_plugins() -__all__ = ["show", "close", "viewers", "Viewer", "clear_caches", "DataIssue"] +__all__ = ["show", "close", "viewers", "Viewer", "clear_caches", "DataIssue", "embed"] def clear_caches() -> None: diff --git a/renumics/spotlight/app.py b/renumics/spotlight/app.py index c5c9b3c9..c6107dee 100644 --- a/renumics/spotlight/app.py +++ b/renumics/spotlight/app.py @@ -31,7 +31,7 @@ ResetLayoutMessage, WebsocketManager, ) -from renumics.spotlight.embeddings import create_embedders, embed +from renumics.spotlight.embeddings import create_embedders, run_embedders from renumics.spotlight.layout.nodes import Layout from renumics.spotlight.backend.config import Config from renumics.spotlight.typing import PathType @@ -476,7 +476,7 @@ def _update_embeddings(self) -> None: self._data_store.embeddings = {column: None for column in embedders} task = self.task_manager.create_task( - embed, (embedders,), name="update_embeddings" + run_embedders, (embedders,), name="update_embeddings" ) def _on_embeddings_ready(future: Future) -> None: diff --git a/renumics/spotlight/embeddings/__init__.py b/renumics/spotlight/embeddings/__init__.py index 51a55431..54214906 100644 --- a/renumics/spotlight/embeddings/__init__.py +++ b/renumics/spotlight/embeddings/__init__.py @@ -11,6 +11,7 @@ from renumics.spotlight.embeddings.typing import Embedder from .registry import registered_embedders +from .decorator import embed # noqa: F401 from . 
import embedders as embedders_namespace # import all modules in .embedders @@ -22,6 +23,7 @@ def create_embedders(data_store: DataStore, columns: List[str]) -> Dict[str, Emb """ Create embedding functions for the given data store. """ + print(registered_embedders.keys()) embedders: Dict[str, Embedder] = {} for column in columns: for name, (embedder_class, dtype, args, kwargs) in registered_embedders.items(): @@ -30,10 +32,11 @@ def create_embedders(data_store: DataStore, columns: List[str]) -> Dict[str, Emb embedder = embedder_class(data_store, column, *args, **kwargs) embedders[f"{column}.{name}.embedding"] = embedder + print(embedders.keys()) return embedders -def embed(embedders: Dict[str, Embedder]) -> Dict[str, np.ndarray]: +def run_embedders(embedders: Dict[str, Embedder]) -> Dict[str, np.ndarray]: """ Run the given functions. """ diff --git a/renumics/spotlight/embeddings/decorator.py b/renumics/spotlight/embeddings/decorator.py index f61b9d2e..eb8adc60 100644 --- a/renumics/spotlight/embeddings/decorator.py +++ b/renumics/spotlight/embeddings/decorator.py @@ -3,41 +3,17 @@ """ import functools -from typing import ( - Callable, - Dict, - Iterable, - List, - Literal, - Optional, - Any, - Union, - overload, -) +from typing import Callable, Dict, Literal, Optional, Any, Union, overload -import PIL.Image -import numpy as np from renumics.spotlight import dtypes -from renumics.spotlight.media.audio import Audio -from renumics.spotlight.media.embedding import Embedding - -from renumics.spotlight.media.image import Image -from renumics.spotlight.media.sequence_1d import Sequence1D +from renumics.spotlight.media import Audio, Embedding, Image, Sequence1D from .preprocessors import ( preprocess_audio_batch, preprocess_batch, preprocess_image_batch, ) from .registry import register_embedder -from .typing import EmbedFunc, FunctionalEmbedder - - -EmbedImageFunc = Callable[ - [Iterable[List[PIL.Image.Image]]], Iterable[List[Optional[np.ndarray]]] -] -EmbedArrayFunc = Callable[ - [Iterable[List[np.ndarray]]], Iterable[List[Optional[np.ndarray]]] -] +from .typing import EmbedArrayFunc, EmbedFunc, EmbedImageFunc, FunctionalEmbedder @overload @@ -80,6 +56,46 @@ def embed( def embed( dtype: Any, *, name: Optional[str] = None, sampling_rate: Optional[int] = None ) -> Callable[[EmbedFunc], EmbedFunc]: + """ + Decorator for marking an embedding function as an Spotlight embedder for the + given data type. + + The decorated function receives an iterable with preprocessed data batches + as Python lists. Batches contain preprocessed data samples: + Pillow images in case of an image embedder; + mono channel audio PCM resampled to the given sampling rate as an double + array in case of an audio embedder; + raw data from Spotlight data store otherwise. + The decorated function should return iterable with batches of embeddings. + Embeddings must be represented by arrays of the same length, or `None` if an + input sample cannot be embedded). + + Args: + dtype: Spotlight data type which can be embedded with the given + function. For more than one data type, use this decorator multiple + times. + name: Optional embedder name. If not given, the name of the decorated + function will be used. + sampling_rate: Optional sampling rate. Only relevant for audio + embedders, otherwise ignored. 
+ + Example of the [JinaAI v2 small](https://huggingface.co/jinaai/jina-embeddings-v2-small-en) + model: + ```python + from typing import Iterable, List, Optional + + import numpy as np + import transformers + from renumics.spotlight import embed + + @embed("str", name="jina-v2-small") + def jina_v2_small(batches: Iterable[List[str]]) -> Iterable[List[Optional[np.ndarray]]]: + model = transformers.AutoModel.from_pretrained('jinaai/jina-embeddings-v2-small-en', trust_remote_code=True) + for batch in batches: + embeddings = model.encode(batch) + yield list(embeddings) + ``` + """ dtype = dtypes.create_dtype(dtype) kwargs: Dict[str, Any] = {} diff --git a/renumics/spotlight/embeddings/embedders/gte.py b/renumics/spotlight/embeddings/embedders/gte.py index 5f535b6a..8a8255cc 100644 --- a/renumics/spotlight/embeddings/embedders/gte.py +++ b/renumics/spotlight/embeddings/embedders/gte.py @@ -32,7 +32,11 @@ def gte(batches: Iterable[List[str]]) -> Iterable[List[Optional[np.ndarray]]]: for batch in batches: inputs = tokenizer( - batch, padding=True, truncation=True, return_tensors="pt" + batch, + padding=True, + truncation=True, + max_length=512, + return_tensors="pt", ) with torch.no_grad(): outputs = model(**inputs) diff --git a/renumics/spotlight/embeddings/preprocessors.py b/renumics/spotlight/embeddings/preprocessors.py index 2ca446fd..7e07c78b 100644 --- a/renumics/spotlight/embeddings/preprocessors.py +++ b/renumics/spotlight/embeddings/preprocessors.py @@ -1,3 +1,8 @@ +""" +Data preprocessors available to embedders for converting data batches from our +internal format to a common format. +""" + import io from typing import List @@ -7,16 +12,26 @@ def preprocess_batch(raw_values: list) -> list: + """ + Preprocess a batch of any data, returning the batch as is. + """ return raw_values def preprocess_image_batch(raw_values: List[bytes]) -> List[PIL.Image.Image]: + """ + Preprocess a batch of image data, converting images to Pillow image format. + """ return [PIL.Image.open(io.BytesIO(value)) for value in raw_values] def preprocess_audio_batch( raw_values: List[bytes], sampling_rate: int ) -> List[np.ndarray]: + """ + Preprocess a batch of audio data, returning mono channel audio PCM resampled + to the given sampling rate as a double array. 
+ """ resampled_batch = [] for raw_data in raw_values: with av.open(io.BytesIO(raw_data), "r") as container: diff --git a/renumics/spotlight/embeddings/registry.py b/renumics/spotlight/embeddings/registry.py index ddb2003f..97f6923e 100644 --- a/renumics/spotlight/embeddings/registry.py +++ b/renumics/spotlight/embeddings/registry.py @@ -3,21 +3,21 @@ """ from typing import Any, Dict, Tuple, Type -from renumics.spotlight.dtypes import DType - +from renumics.spotlight import dtypes from .typing import Embedder registered_embedders: Dict[ - str, Tuple[Type[Embedder], DType, tuple, Dict[str, Any]] + str, Tuple[Type[Embedder], dtypes.DType, tuple, Dict[str, Any]] ] = {} def register_embedder( - embedder: Type[Embedder], dtype: DType, name: str, *args: Any, **kwargs: Any + embedder: Type[Embedder], dtype: dtypes.DType, name: str, *args: Any, **kwargs: Any ) -> None: """ Register an embedder """ + print(f"{name} embedder registered.") registered_embedders[name] = (embedder, dtype, args, kwargs) diff --git a/renumics/spotlight/embeddings/typing.py b/renumics/spotlight/embeddings/typing.py index 981c8dd1..831f8453 100644 --- a/renumics/spotlight/embeddings/typing.py +++ b/renumics/spotlight/embeddings/typing.py @@ -6,12 +6,21 @@ from abc import ABC, abstractmethod from typing import Callable, Iterable, List, Optional +import PIL.Image import numpy as np from renumics.spotlight.data_store import DataStore class Embedder(ABC): + """ + Base data store embedder class. + + Args: + data_store: Data store. + column: A column in the data store to embed. + """ + data_store: DataStore column: str @@ -28,9 +37,25 @@ def __call__(self) -> np.ndarray: PreprocessFunc = Callable[[list], list] EmbedFunc = Callable[[Iterable[list]], Iterable[List[Optional[np.ndarray]]]] +EmbedImageFunc = Callable[ + [Iterable[List[PIL.Image.Image]]], Iterable[List[Optional[np.ndarray]]] +] +EmbedArrayFunc = Callable[ + [Iterable[List[np.ndarray]]], Iterable[List[Optional[np.ndarray]]] +] class FunctionalEmbedder(Embedder): + """ + Wrapper for preprocessing and embedding functions. + + Attrs: + preprocess_func: Preprocessing function. Receives a batch of data in our + internal format and prepares is for embedding. + embed_func: Embedding function. Receives an iterable with preprocessed + data batches and yields batches of generated embeddings. + """ + preprocess_func: PreprocessFunc embed_func: EmbedFunc batch_size: int @@ -51,7 +76,7 @@ def __init__( def _iter_batches(self) -> Iterable[list]: """ - Yield batches with data, i.e. without the `None` values. + Yield batches with valid data, i.e. without the `None` values. """ self._occupied_indices = [] batch = [] @@ -72,6 +97,9 @@ def _iter_batches(self) -> Iterable[list]: yield self.preprocess_func(batch) def __call__(self) -> np.ndarray: + """ + Embed the given column. 
+ """ embeddings = itertools.chain(*self.embed_func(self._iter_batches())) res = np.empty(len(self.data_store), dtype=np.object_) res[self._occupied_indices] = list(embeddings) From 6c679a73fb38c2e67c4a052295417fa4ac4a6be1 Mon Sep 17 00:00:00 2001 From: Alexander Druz Date: Mon, 4 Dec 2023 09:46:59 +0100 Subject: [PATCH 20/24] Do not expose `embed` decorator --- renumics/spotlight/__init__.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/renumics/spotlight/__init__.py b/renumics/spotlight/__init__.py index 6a1b0b7b..5fd91702 100644 --- a/renumics/spotlight/__init__.py +++ b/renumics/spotlight/__init__.py @@ -79,7 +79,6 @@ from .plugin_loader import load_plugins from .settings import settings from .analysis.typing import DataIssue -from .embeddings.decorator import embed # noqa: F401 from . import cache, logging if not settings.verbose: @@ -87,7 +86,7 @@ __plugins__ = load_plugins() -__all__ = ["show", "close", "viewers", "Viewer", "clear_caches", "DataIssue", "embed"] +__all__ = ["show", "close", "viewers", "Viewer", "clear_caches", "DataIssue"] def clear_caches() -> None: From cf81c82ea9540bfe3219d856ca0090041ea6c53d Mon Sep 17 00:00:00 2001 From: Alexander Druz Date: Mon, 4 Dec 2023 09:56:15 +0100 Subject: [PATCH 21/24] Remove `torchvision` from dev deps --- poetry.lock | 33 +-------------------------------- pyproject.toml | 3 --- 2 files changed, 1 insertion(+), 35 deletions(-) diff --git a/poetry.lock b/poetry.lock index 581e451b..3149921d 100644 --- a/poetry.lock +++ b/poetry.lock @@ -5973,37 +5973,6 @@ type = "legacy" url = "https://download.pytorch.org/whl/cpu" reference = "torch-cpu" -[[package]] -name = "torchvision" -version = "0.16.1+cpu" -description = "image and video datasets and models for torch deep learning" -optional = false -python-versions = ">=3.8" -files = [ - {file = "torchvision-0.16.1+cpu-cp310-cp310-linux_x86_64.whl", hash = "sha256:52442b5408dab35fedc6ccd75a7df38aadb6eb5d078184ac04ba7c1a9db8ae9c"}, - {file = "torchvision-0.16.1+cpu-cp310-cp310-win_amd64.whl", hash = "sha256:d77afeed3fe309f3f73943431840d40fa61a33a52076bbb780c63e6b5f79962a"}, - {file = "torchvision-0.16.1+cpu-cp311-cp311-linux_x86_64.whl", hash = "sha256:574b58e20ca89ebd2d8d1ff72dc6e6944734ed1327849214a349c39d5454ac4f"}, - {file = "torchvision-0.16.1+cpu-cp311-cp311-win_amd64.whl", hash = "sha256:f6fed6f1311d34c4751d70d60408ea8526842ede66f5aad7272d6de5c4337ebc"}, - {file = "torchvision-0.16.1+cpu-cp38-cp38-linux_x86_64.whl", hash = "sha256:53a5f20778bb300038d1846fb111e73ae5d4babc6b9cfb6a6cbeb2eefa399c0b"}, - {file = "torchvision-0.16.1+cpu-cp38-cp38-win_amd64.whl", hash = "sha256:89b9bf0e80fc4c45f114707d61e15741ce0e9591badf29aa939863be9fd9543d"}, - {file = "torchvision-0.16.1+cpu-cp39-cp39-linux_x86_64.whl", hash = "sha256:048b029410bc4c7ce87b1f2b621fd8b88249c68350ee773c3152faa088e9fcca"}, - {file = "torchvision-0.16.1+cpu-cp39-cp39-win_amd64.whl", hash = "sha256:73a4b3317eae425baabefe6b61197c75b0b766f1809c5cfdad07dd6bfad255a2"}, -] - -[package.dependencies] -numpy = "*" -pillow = ">=5.3.0,<8.3.dev0 || >=8.4.dev0" -requests = "*" -torch = "2.1.1" - -[package.extras] -scipy = ["scipy"] - -[package.source] -type = "legacy" -url = "https://download.pytorch.org/whl/cpu" -reference = "torch-cpu" - [[package]] name = "tornado" version = "6.3.3" @@ -7124,4 +7093,4 @@ descriptors = ["pycatch22"] [metadata] lock-version = "2.0" python-versions = ">=3.8, <3.12" -content-hash = "49f1495035945dd4c6e404de9ea4512302b3ab96e86e98f9a3da370469ca7a89" +content-hash = 
"b94339642a6815125a017be3386f559c1818fcbe645896203b8a7c191a940a27" diff --git a/pyproject.toml b/pyproject.toml index 37f5193e..96f0b12c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -126,7 +126,6 @@ types-pillow = "^10.0.0.1" pandas-stubs = "^2.0.2.230605" ruff = "^0.0.281" check-wheel-contents = "^0.6.0" -torchvision = {version = "^0.16.1+cpu", source = "torch-cpu"} torch = {version = "^2.1.1+cpu", source = "torch-cpu"} [tool.poetry.group.playbook.dependencies] @@ -182,8 +181,6 @@ module = [ "diffimg", "tests.ui._autogenerated_ui_elements", "transformers", - "zstandard", - "sentence_transformers", ] ignore_missing_imports = true From 5b241de1a0ba17386901a8fc75574a7f4a57a0b9 Mon Sep 17 00:00:00 2001 From: Alexander Druz Date: Mon, 4 Dec 2023 10:10:55 +0100 Subject: [PATCH 22/24] Remove debug prints --- renumics/spotlight/app.py | 6 ++---- renumics/spotlight/embeddings/__init__.py | 2 -- renumics/spotlight/embeddings/registry.py | 1 - 3 files changed, 2 insertions(+), 7 deletions(-) diff --git a/renumics/spotlight/app.py b/renumics/spotlight/app.py index c6107dee..d0081ba1 100644 --- a/renumics/spotlight/app.py +++ b/renumics/spotlight/app.py @@ -53,10 +53,8 @@ from renumics.spotlight.app_config import AppConfig from renumics.spotlight.data_source import DataSource, create_datasource from renumics.spotlight import layouts - from renumics.spotlight.data_store import DataStore - -from renumics.spotlight.dtypes import DTypeMap +from renumics.spotlight import dtypes as spotlight_dtypes CURRENT_LAYOUT_KEY = "layout.current" @@ -107,7 +105,7 @@ class SpotlightApp(FastAPI): # datasource _dataset: Optional[Union[PathType, pd.DataFrame]] - _user_dtypes: DTypeMap + _user_dtypes: spotlight_dtypes.DTypeMap _data_source: Optional[DataSource] _data_store: Optional[DataStore] diff --git a/renumics/spotlight/embeddings/__init__.py b/renumics/spotlight/embeddings/__init__.py index 54214906..6dc06f7b 100644 --- a/renumics/spotlight/embeddings/__init__.py +++ b/renumics/spotlight/embeddings/__init__.py @@ -23,7 +23,6 @@ def create_embedders(data_store: DataStore, columns: List[str]) -> Dict[str, Emb """ Create embedding functions for the given data store. 
""" - print(registered_embedders.keys()) embedders: Dict[str, Embedder] = {} for column in columns: for name, (embedder_class, dtype, args, kwargs) in registered_embedders.items(): @@ -32,7 +31,6 @@ def create_embedders(data_store: DataStore, columns: List[str]) -> Dict[str, Emb embedder = embedder_class(data_store, column, *args, **kwargs) embedders[f"{column}.{name}.embedding"] = embedder - print(embedders.keys()) return embedders diff --git a/renumics/spotlight/embeddings/registry.py b/renumics/spotlight/embeddings/registry.py index 97f6923e..d9676f11 100644 --- a/renumics/spotlight/embeddings/registry.py +++ b/renumics/spotlight/embeddings/registry.py @@ -17,7 +17,6 @@ def register_embedder( """ Register an embedder """ - print(f"{name} embedder registered.") registered_embedders[name] = (embedder, dtype, args, kwargs) From b0d5f025408f3925e3459fc6942df3c1bafea423 Mon Sep 17 00:00:00 2001 From: Alexander Druz Date: Mon, 4 Dec 2023 14:43:45 +0100 Subject: [PATCH 23/24] Add embed argument to `spotlight.show` and CLI --- renumics/spotlight/app.py | 22 ++++++++++++++++++++-- renumics/spotlight/app_config.py | 5 ++++- renumics/spotlight/cli.py | 25 ++++++++++++++++++++----- renumics/spotlight/viewer.py | 10 +++++++++- 4 files changed, 53 insertions(+), 9 deletions(-) diff --git a/renumics/spotlight/app.py b/renumics/spotlight/app.py index d0081ba1..147f0254 100644 --- a/renumics/spotlight/app.py +++ b/renumics/spotlight/app.py @@ -123,7 +123,10 @@ class SpotlightApp(FastAPI): # data issues issues: Optional[List[DataIssue]] = [] _custom_issues: List[DataIssue] = [] - analyze_columns: Union[List[str], bool] = False + analyze_columns: Union[List[str], bool] + + # embedding + embed_columns: Union[List[str], bool] def __init__(self) -> None: super().__init__() @@ -138,6 +141,7 @@ def __init__(self) -> None: self.analyze_columns = False self.issues = None self._custom_issues = [] + self.embed_columns = False self._dataset = None self._user_dtypes = {} @@ -328,6 +332,8 @@ def update(self, config: AppConfig) -> None: self.analyze_columns = config.analyze if config.custom_issues is not None: self.custom_issues = config.custom_issues + if config.embed is not None: + self.embed_columns = config.embed if config.dataset is not None: self._dataset = config.dataset self._data_source = create_datasource(self._dataset) @@ -464,12 +470,24 @@ def _update_embeddings(self) -> None: """ Update embeddings, update them in the data store and notify client about. 
""" + if not self.embed_columns: + return + if self._data_store is None: return logger.info("Embedding started.") - embedders = create_embedders(self._data_store, self._data_store.column_names) + if self.embed_columns is True: + embed_columns = self._data_store.column_names + else: + embed_columns = [ + column + for column in self.embed_columns + if column in self._data_store.column_names + ] + + embedders = create_embedders(self._data_store, embed_columns) self._data_store.embeddings = {column: None for column in embedders} diff --git a/renumics/spotlight/app_config.py b/renumics/spotlight/app_config.py index 15bcf217..e1364fa3 100644 --- a/renumics/spotlight/app_config.py +++ b/renumics/spotlight/app_config.py @@ -24,9 +24,12 @@ class AppConfig: project_root: Optional[Path] = None # data analysis - analyze: Optional[Union[bool, List[str]]] = None + analyze: Optional[Union[List[str], bool]] = None custom_issues: Optional[List[DataIssue]] = None + # embedding + embed: Optional[Union[List[str], bool]] = None + # frontend layout: Optional[Layout] = None filebrowsing_allowed: Optional[bool] = None diff --git a/renumics/spotlight/cli.py b/renumics/spotlight/cli.py index 7854a8c7..6736fb9b 100644 --- a/renumics/spotlight/cli.py +++ b/renumics/spotlight/cli.py @@ -6,7 +6,7 @@ import platform import signal import sys -from typing import Dict, Optional, Tuple, Union, List +from typing import Dict, Optional, Tuple, Union import click @@ -94,9 +94,21 @@ def cli_dtype_callback( ) @click.option( "--analyze", - default=[], + default=(), multiple=True, - help="Automatically analyze issues for all columns.", + help="Columns to analyze (if no --analyze-all).", +) +@click.option( + "--embed-all", + is_flag=True, + default=False, + help="Automatically embed all columns.", +) +@click.option( + "--embed", + default=(), + multiple=True, + help="Columns to embed (if no --analyze-all).", ) @click.option("-v", "--verbose", is_flag=True) @click.version_option(spotlight.__version__) @@ -109,8 +121,10 @@ def main( dtype: Optional[Dict[str, str]], no_browser: bool, filebrowsing: bool, - analyze: List[str], + analyze: Tuple[str], analyze_all: bool, + embed: Tuple[str], + embed_all: bool, verbose: bool, ) -> None: """ @@ -135,5 +149,6 @@ def main( no_browser=no_browser, allow_filebrowsing=filebrowsing, wait="forever", - analyze=True if analyze_all else analyze, + analyze=True if analyze_all else list(analyze), + embed=True if embed_all else list(embed), ) diff --git a/renumics/spotlight/viewer.py b/renumics/spotlight/viewer.py index 40a07683..94c979b9 100644 --- a/renumics/spotlight/viewer.py +++ b/renumics/spotlight/viewer.py @@ -150,8 +150,9 @@ def show( allow_filebrowsing: Union[bool, Literal["auto"]] = "auto", wait: Union[bool, Literal["auto", "forever"]] = "auto", dtype: Optional[Dict[str, Any]] = None, - analyze: Optional[Union[bool, List[str]]] = None, + analyze: Optional[Union[List[str], bool]] = None, issues: Optional[Collection[DataIssue]] = None, + embed: Optional[Union[List[str], bool]] = None, ) -> None: """ Show a dataset or folder in this spotlight viewer. @@ -173,6 +174,8 @@ def show( column types allowed by Spotlight (for dataframes only). analyze: Automatically analyze common dataset issues (disabled by default). issues: Custom dataset issues displayed in the viewer. + embed: Automatically embed all or given columns with default + embedders (disabled by default). 
""" if is_pathtype(dataset): @@ -206,6 +209,7 @@ def show( project_root=project_root, analyze=analyze, custom_issues=list(issues) if issues else None, + embed=embed, layout=parsed_layout, filebrowsing_allowed=filebrowsing_allowed, ) @@ -373,6 +377,7 @@ def show( dtype: Optional[Dict[str, Any]] = None, analyze: Optional[Union[bool, List[str]]] = None, issues: Optional[Collection[DataIssue]] = None, + embed: Optional[Union[List[str], bool]] = None, ) -> Viewer: """ Start a new Spotlight viewer. @@ -397,6 +402,8 @@ def show( column types allowed by Spotlight (for dataframes only). analyze: Automatically analyze common dataset issues (disabled by default). issues: Custom dataset issues displayed in the viewer. + embed: Automatically embed all or given columns with default + embedders (disabled by default). """ viewer = None @@ -419,6 +426,7 @@ def show( dtype=dtype, analyze=analyze, issues=issues, + embed=embed, ) return viewer From c36b8851db7154844adc28f963d148085d449b70 Mon Sep 17 00:00:00 2001 From: Alexander Druz Date: Tue, 5 Dec 2023 10:51:55 +0100 Subject: [PATCH 24/24] Fix help string --- renumics/spotlight/cli.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/renumics/spotlight/cli.py b/renumics/spotlight/cli.py index 6736fb9b..e455b428 100644 --- a/renumics/spotlight/cli.py +++ b/renumics/spotlight/cli.py @@ -108,7 +108,7 @@ def cli_dtype_callback( "--embed", default=(), multiple=True, - help="Columns to embed (if no --analyze-all).", + help="Columns to embed (if no --embed-all).", ) @click.option("-v", "--verbose", is_flag=True) @click.version_option(spotlight.__version__)