Merge pull request #382 from Renumics/feature/embed-columns
Feature/embed columns
druzsan authored Dec 5, 2023
2 parents 6cb470c + c36b885 commit cbc3826
Showing 30 changed files with 927 additions and 80 deletions.
100 changes: 66 additions & 34 deletions poetry.lock

Large diffs are not rendered by default.

12 changes: 9 additions & 3 deletions pyproject.toml
@@ -92,6 +92,7 @@ httpx = "^0.23.0"
datasets = { extras = ["audio"], version = "^2.12.0" }
pydantic-settings = "^2.0.3"
pycatch22 = { version = "!=0.4.4", optional = true }
transformers = "^4.35.2"

[tool.poetry.extras]
descriptors = ["pycatch22"]
@@ -125,14 +126,18 @@ types-pillow = "^10.0.0.1"
pandas-stubs = "^2.0.2.230605"
ruff = "^0.0.281"
check-wheel-contents = "^0.6.0"
torch = {version = "^2.1.1+cpu", source = "torch-cpu"}

[tool.poetry.group.playbook.dependencies]
datasets = "^2.12.0"
transformers = "^4.29.2"
torch = "^2.0.1"
towhee = "^0.9.0"
annoy = "^1.17.2"
cleanlab = "^2.4.0"
torch = {version = "^2.1.1+cpu", source = "torch-cpu"}

[[tool.poetry.source]]
name = "torch-cpu"
url = "https://download.pytorch.org/whl/cpu"
priority = "explicit"

[tool.poetry-dynamic-versioning]
enable = true
@@ -175,6 +180,7 @@ module = [
"datasets",
"diffimg",
"tests.ui._autogenerated_ui_elements",
"transformers",
]
ignore_missing_imports = true

3 changes: 2 additions & 1 deletion renumics/spotlight/analysis/registry.py
@@ -1,10 +1,11 @@
"""
Manage data analyzers available for Spotlight's automatic dataset analysis.
"""
from typing import Set

from .typing import DataAnalyzer

registered_analyzers = set()
registered_analyzers: Set[DataAnalyzer] = set()


def register_analyzer(analyzer: DataAnalyzer) -> None:
66 changes: 61 additions & 5 deletions renumics/spotlight/app.py
@@ -31,6 +31,7 @@
ResetLayoutMessage,
WebsocketManager,
)
from renumics.spotlight.embeddings import create_embedders, run_embedders
from renumics.spotlight.layout.nodes import Layout
from renumics.spotlight.backend.config import Config
from renumics.spotlight.typing import PathType
@@ -52,10 +53,8 @@
from renumics.spotlight.app_config import AppConfig
from renumics.spotlight.data_source import DataSource, create_datasource
from renumics.spotlight import layouts

from renumics.spotlight.data_store import DataStore

from renumics.spotlight.dtypes import DTypeMap
from renumics.spotlight import dtypes as spotlight_dtypes

CURRENT_LAYOUT_KEY = "layout.current"

@@ -83,6 +82,15 @@ class IssuesUpdatedMessage(Message):
data: Any = None


class ColumnsUpdatedMessage(Message):
"""
Notify about updated embeddings.
"""

type: Literal["columnsUpdated"] = "columnsUpdated"
data: List[str]


class SpotlightApp(FastAPI):
"""
Spotlight wsgi application
@@ -97,7 +105,7 @@ class SpotlightApp(FastAPI):

# datasource
_dataset: Optional[Union[PathType, pd.DataFrame]]
_user_dtypes: DTypeMap
_user_dtypes: spotlight_dtypes.DTypeMap
_data_source: Optional[DataSource]
_data_store: Optional[DataStore]

@@ -115,7 +123,10 @@ class SpotlightApp(FastAPI):
# data issues
issues: Optional[List[DataIssue]] = []
_custom_issues: List[DataIssue] = []
analyze_columns: Union[List[str], bool] = False
analyze_columns: Union[List[str], bool]

# embedding
embed_columns: Union[List[str], bool]

def __init__(self) -> None:
super().__init__()
@@ -130,6 +141,7 @@ def __init__(self) -> None:
self.analyze_columns = False
self.issues = None
self._custom_issues = []
self.embed_columns = False

self._dataset = None
self._user_dtypes = {}
@@ -320,6 +332,8 @@ def update(self, config: AppConfig) -> None:
self.analyze_columns = config.analyze
if config.custom_issues is not None:
self.custom_issues = config.custom_issues
if config.embed is not None:
self.embed_columns = config.embed
if config.dataset is not None:
self._dataset = config.dataset
self._data_source = create_datasource(self._dataset)
@@ -334,6 +348,7 @@ def update(self, config: AppConfig) -> None:
self._data_store = DataStore(data_source, self._user_dtypes)
self._broadcast(RefreshMessage())
self._update_issues()
self._update_embeddings()
if config.layout is not None:
if self._data_store is not None:
dataset_uid = self._data_store.uid
@@ -451,6 +466,47 @@ def _on_issues_ready(future: Future) -> None:

task.future.add_done_callback(_on_issues_ready)

def _update_embeddings(self) -> None:
"""
Compute embeddings, store them in the data store and notify the client.
"""
if not self.embed_columns:
return

if self._data_store is None:
return

logger.info("Embedding started.")

if self.embed_columns is True:
embed_columns = self._data_store.column_names
else:
embed_columns = [
column
for column in self.embed_columns
if column in self._data_store.column_names
]

embedders = create_embedders(self._data_store, embed_columns)

self._data_store.embeddings = {column: None for column in embedders}

task = self.task_manager.create_task(
run_embedders, (embedders,), name="update_embeddings"
)

def _on_embeddings_ready(future: Future) -> None:
if self._data_store is None:
return
try:
self._data_store.embeddings = future.result()
except CancelledError:
return
logger.info("Embedding done.")
self._broadcast(ColumnsUpdatedMessage(data=list(embedders.keys())))

task.future.add_done_callback(_on_embeddings_ready)

def _broadcast(self, message: Message) -> None:
"""
Broadcast a message to all connected clients via websocket
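The new _update_embeddings hook first registers placeholder (None) embeddings so the computed columns are visible right away, then runs the embedders in a background task and, once the results are written back to the data store, broadcasts a ColumnsUpdatedMessage to all clients. A rough, self-contained sketch (not part of this diff) of what that notification serializes to, assuming the Message base class behaves like a plain pydantic model; the column names are made up:

from typing import List, Literal

from pydantic import BaseModel


class ColumnsUpdatedMessage(BaseModel):
    # Stand-in for the message class added in app.py above.
    type: Literal["columnsUpdated"] = "columnsUpdated"
    data: List[str]


# A connected client would receive roughly this JSON once embedding finishes:
print(ColumnsUpdatedMessage(data=["image", "audio"]).json())
# -> {"type": "columnsUpdated", "data": ["image", "audio"]}
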
5 changes: 4 additions & 1 deletion renumics/spotlight/app_config.py
@@ -24,9 +24,12 @@ class AppConfig:
project_root: Optional[Path] = None

# data analysis
analyze: Optional[Union[bool, List[str]]] = None
analyze: Optional[Union[List[str], bool]] = None
custom_issues: Optional[List[DataIssue]] = None

# embedding
embed: Optional[Union[List[str], bool]] = None

# frontend
layout: Optional[Layout] = None
filebrowsing_allowed: Optional[bool] = None
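
The new embed option mirrors analyze: None leaves the app's current setting untouched, True requests embeddings for every column, and a list restricts embedding to the named columns. A minimal sketch (not part of this diff), assuming the remaining AppConfig fields keep their None defaults; the column names are made up:

from renumics.spotlight.app_config import AppConfig

AppConfig(embed=True)                      # embed every column
AppConfig(embed=["image", "description"])  # embed only these columns
AppConfig(embed=None)                      # default: keep the current setting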
11 changes: 11 additions & 0 deletions renumics/spotlight/backend/exceptions.py
@@ -181,3 +181,14 @@ def __init__(self) -> None:
"dataset using `dataset.rebuild()`.",
status.HTTP_500_INTERNAL_SERVER_ERROR,
)


class ComputedColumnNotReady(Problem):
"""A computed column is not yet ready"""

def __init__(self, name: str) -> None:
super().__init__(
"Computed column not ready",
f"Computed column {name} is not yet ready.",
status.HTTP_404_NOT_FOUND,
)
25 changes: 20 additions & 5 deletions renumics/spotlight/cli.py
@@ -6,7 +6,7 @@
import platform
import signal
import sys
from typing import Dict, Optional, Tuple, Union, List
from typing import Dict, Optional, Tuple, Union

import click

@@ -94,9 +94,21 @@ def cli_dtype_callback(
)
@click.option(
"--analyze",
default=[],
default=(),
multiple=True,
help="Automatically analyze issues for all columns.",
help="Columns to analyze (if no --analyze-all).",
)
@click.option(
"--embed-all",
is_flag=True,
default=False,
help="Automatically embed all columns.",
)
@click.option(
"--embed",
default=(),
multiple=True,
help="Columns to embed (if no --embed-all).",
)
@click.option("-v", "--verbose", is_flag=True)
@click.version_option(spotlight.__version__)
@@ -109,8 +121,10 @@ def main(
dtype: Optional[Dict[str, str]],
no_browser: bool,
filebrowsing: bool,
analyze: List[str],
analyze: Tuple[str],
analyze_all: bool,
embed: Tuple[str],
embed_all: bool,
verbose: bool,
) -> None:
"""
@@ -135,5 +149,6 @@
no_browser=no_browser,
allow_filebrowsing=filebrowsing,
wait="forever",
analyze=True if analyze_all else analyze,
analyze=True if analyze_all else list(analyze),
embed=True if embed_all else list(embed),
)
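
The two new flags mirror the analyze pair: --embed-all turns embedding on for every column, while repeated --embed options select individual columns (e.g. spotlight data.csv --embed image --embed audio). The CLI forwards the result as embed= to the viewer entry point; below is a hedged sketch of the assumed programmatic equivalent — the embed keyword on spotlight.show, the dataset file and the column name are assumptions, not confirmed by this diff:

import pandas as pd

from renumics import spotlight

df = pd.read_csv("data.csv")  # hypothetical dataset

spotlight.show(df, embed=["image"])  # assumed equivalent of `--embed image`
spotlight.show(df, embed=True)       # assumed equivalent of `--embed-all`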
1 change: 1 addition & 0 deletions renumics/spotlight/data_source/data_source.py
@@ -27,6 +27,7 @@ class ColumnMetadata:
hidden: bool = False
description: Optional[str] = None
tags: List[str] = dataclasses.field(default_factory=list)
computed: bool = False


class DataSource(ABC):
60 changes: 53 additions & 7 deletions renumics/spotlight/data_store.py
@@ -3,7 +3,7 @@
import io
import os
import statistics
from typing import Any, Iterable, List, Optional, Set, Union, cast
from typing import Any, Dict, Iterable, List, Optional, Set, Union, cast

import numpy as np
import filetype
@@ -25,16 +25,19 @@
from renumics.spotlight.media.image import Image
from renumics.spotlight.media.sequence_1d import Sequence1D
from renumics.spotlight.media.embedding import Embedding
from renumics.spotlight.backend.exceptions import ComputedColumnNotReady


class DataStore:
_data_source: DataSource
_user_dtypes: spotlight_dtypes.DTypeMap
_dtypes: spotlight_dtypes.DTypeMap
_embeddings: Dict[str, Optional[np.ndarray]]

def __init__(
self, data_source: DataSource, user_dtypes: spotlight_dtypes.DTypeMap
) -> None:
self._embeddings = {}
self._data_source = data_source
self._user_dtypes = user_dtypes
self._update_dtypes()
@@ -56,20 +59,53 @@ def generation_id(self) -> int:

@property
def column_names(self) -> List[str]:
return self._data_source.column_names
return self._data_source.column_names + list(self._embeddings)

@property
def data_source(self) -> DataSource:
return self._data_source

@property
def dtypes(self) -> spotlight_dtypes.DTypeMap:
return self._dtypes
dtypes_ = self._dtypes.copy()
for column, embeddings in self._embeddings.items():
if embeddings is None:
length = None
else:
try:
length = len(
next(
embedding
for embedding in embeddings
if embedding is not None
)
)
except StopIteration:
length = None
dtypes_[column] = spotlight_dtypes.EmbeddingDType(length=length)
return dtypes_

@property
def embeddings(self) -> Dict[str, Optional[np.ndarray]]:
return self._embeddings

@embeddings.setter
def embeddings(self, new_embeddings: Dict[str, Optional[np.ndarray]]) -> None:
self._embeddings = new_embeddings

def check_generation_id(self, generation_id: int) -> None:
self._data_source.check_generation_id(generation_id)

def get_column_metadata(self, column_name: str) -> ColumnMetadata:
if column_name in self._embeddings:
return ColumnMetadata(
nullable=True,
editable=False,
hidden=True,
description=None,
tags=[],
computed=True,
)
return self._data_source.get_column_metadata(column_name)

def get_converted_values(
@@ -79,18 +115,28 @@ def get_converted_values(
simple: bool = False,
check: bool = True,
) -> List[ConvertedValue]:
dtype = self._dtypes[column_name]
normalized_values = self._data_source.get_column_values(column_name, indices)
dtype = self.dtypes[column_name]
if column_name in self._embeddings:
embeddings = self._embeddings[column_name]
if embeddings is None:
raise ComputedColumnNotReady(column_name)
normalized_values: Iterable = embeddings[indices]
else:
normalized_values = self._data_source.get_column_values(
column_name, indices
)
converted_values = [
convert_to_dtype(value, dtype, simple=simple, check=check)
for value in normalized_values
]
return converted_values

def get_converted_value(
self, column_name: str, index: int, simple: bool = False
self, column_name: str, index: int, simple: bool = False, check: bool = True
) -> ConvertedValue:
return self.get_converted_values(column_name, indices=[index], simple=simple)[0]
return self.get_converted_values(
column_name, indices=[index], simple=simple, check=check
)[0]

def get_waveform(self, column_name: str, index: int) -> Optional[np.ndarray]:
"""
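The dtypes property above reports each computed column as an EmbeddingDType whose length is taken from the first non-None embedding, or None while nothing has been computed yet. A standalone sketch (not part of this diff) of that lookup, with made-up values:

from typing import Optional

import numpy as np

# Made-up column values: one embedding computed, the rest still missing.
embeddings = [None, np.array([0.1, 0.2, 0.3]), None]

length: Optional[int]
try:
    length = len(next(e for e in embeddings if e is not None))
except StopIteration:
    length = None  # no embedding computed yet

print(length)  # 3 -> EmbeddingDType(length=3)
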
(Diffs for the remaining changed files are not rendered.)
