Skip to content

Commit

Permalink
Merge pull request #375 from Renumics/feature/208-support-embedding-l…
Browse files Browse the repository at this point in the history
…ength

Feature/208 support embedding length
  • Loading branch information
druzsan authored Nov 20, 2023
2 parents 5fa0b61 + 0189eee commit 528e3c9
Show file tree
Hide file tree
Showing 2 changed files with 350 additions and 24 deletions.
68 changes: 44 additions & 24 deletions renumics/spotlight/dataset/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,16 +143,17 @@ def unescape_dataset_name(escaped_name: str) -> str:
np.ndarray,
list,
tuple,
range,
bool,
int,
float,
np.bool_,
np.integer,
np.floating,
),
"Window": (np.ndarray, list, tuple),
"Embedding": (spotlight_dtypes.Embedding, np.ndarray, list, tuple),
"Sequence1D": (spotlight_dtypes.Sequence1D, np.ndarray, list, tuple),
"Window": (np.ndarray, list, tuple, range),
"Embedding": (spotlight_dtypes.Embedding, np.ndarray, list, tuple, range),
"Sequence1D": (spotlight_dtypes.Sequence1D, np.ndarray, list, tuple, range),
"Audio": (spotlight_dtypes.Audio, bytes, str, os.PathLike),
"Image": (spotlight_dtypes.Image, bytes, str, os.PathLike, np.ndarray, list, tuple),
"Mesh": (spotlight_dtypes.Mesh, trimesh.Trimesh, str, os.PathLike),
Expand Down Expand Up @@ -1227,6 +1228,7 @@ def append_embedding_column(
default: EmbeddingColumnInputType = None,
description: Optional[str] = None,
tags: Optional[List[str]] = None,
length: Optional[int] = None,
dtype: Union[str, np.dtype] = "float32",
) -> None:
"""
Expand Down Expand Up @@ -1257,7 +1259,7 @@ def append_embedding_column(
)
self._append_column(
name,
spotlight_dtypes.embedding_dtype,
spotlight_dtypes.EmbeddingDType(length),
values,
h5py.vlen_dtype(np_dtype),
order,
Expand Down Expand Up @@ -1707,6 +1709,13 @@ def append_column(
append_column_fn = self.append_window_column
elif spotlight_dtypes.is_embedding_dtype(dtype):
append_column_fn = self.append_embedding_column
if dtype.length is not None:
if "length" in attrs and attrs["length"] != dtype.length:
raise exceptions.InvalidAttributeError(
f"Embedding length differs between `dtype` ({dtype.length}) "
f"and `length` ({attrs['length']}) keyword argument."
)
attrs["length"] = dtype.length
elif spotlight_dtypes.is_sequence_1d_dtype(dtype):
append_column_fn = self.append_sequence_1d_column
elif spotlight_dtypes.is_audio_dtype(dtype):
Expand Down Expand Up @@ -2482,6 +2491,8 @@ def _append_column(
)
self._column_names.add(name)
column.attrs["type"] = dtype.name
if spotlight_dtypes.is_embedding_dtype(dtype) and dtype.length is not None:
column.attrs["value_shape"] = (dtype.length,)
self.set_column_attributes(
name,
order,
Expand Down Expand Up @@ -2881,10 +2892,10 @@ def _encode_simple_values(
# embedding should go through `_encode_value` element-wise.
if values.ndim == 1:
# Handle 1-dimensional input as a single embedding.
self._assert_valid_or_set_embedding_shape(values.shape, column)
self._assert_valid_or_set_value_shape(values.shape, column)
values_list = list(np.broadcast_to(values, (1, len(values))))
elif values.ndim == 2:
self._assert_valid_or_set_embedding_shape(values.shape[1:], column)
self._assert_valid_or_set_value_shape(values.shape[1:], column)
values_list = list(values)
else:
raise exceptions.InvalidShapeError(
Expand Down Expand Up @@ -3108,7 +3119,7 @@ def _encode_simple_value(
if isinstance(value, spotlight_dtypes.Embedding):
value = value.encode(attrs.get("format", None))
value = np.asarray(value, dtype=column.dtype.metadata["vlen"])
self._assert_valid_or_set_embedding_shape(value.shape, column)
self._assert_valid_or_set_value_shape(value.shape, column)
return value
if isinstance(value, np.str_):
return value.tolist()
Expand Down Expand Up @@ -3159,7 +3170,7 @@ def _encode_ref_value(
value = spotlight_dtypes.Embedding(value) # type: ignore
value = value.encode()
self._assert_valid_or_set_value_dtype(value.dtype, column)
self._assert_valid_or_set_embedding_shape(value.shape, column)
self._assert_valid_or_set_value_shape(value.shape, column)
return value
if spotlight_dtypes.is_sequence_1d_dtype(dtype):
if not isinstance(value, spotlight_dtypes.Sequence1D):
Expand Down Expand Up @@ -3475,6 +3486,12 @@ def _get_dtype(
return spotlight_dtypes.Sequence1DDType(
x.get("x_label", "x"), x.get("y_label", "y")
)
if type_name == "Embedding":
try:
length = x["value_shape"][0]
except (KeyError, IndexError):
return spotlight_dtypes.embedding_dtype
return spotlight_dtypes.EmbeddingDType(length)
return spotlight_dtypes.create_dtype(type_name)

@staticmethod
Expand Down Expand Up @@ -3585,6 +3602,9 @@ def _assert_index_exists(self, index: IndexType, check_type: bool = False) -> No
def _assert_valid_or_set_value_dtype(
self, dtype: np.dtype, column: h5py.Dataset
) -> None:
"""
Set value dtype for the whole column if not yet set, check shape otherwise.
"""
attrs = column.attrs
if "value_dtype" in attrs:
if dtype.str != attrs["value_dtype"]:
Expand All @@ -3604,27 +3624,27 @@ def _assert_valid_or_set_value_dtype(
f"received."
)

def _assert_valid_or_set_embedding_shape(
def _assert_valid_or_set_value_shape(
self, shape: Tuple[int, ...], column: h5py.Dataset
) -> None:
"""
Set value shape for the whole column if not yet set, check shape otherwise.
"""
attrs = column.attrs
if shape == (0,) and attrs.get("optional", False):
# Do not check shape if an empty array given for an optional column.
return
if "value_shape" in attrs:
if shape != attrs["value_shape"]:
column_name = self._get_column_name(column)
raise exceptions.InvalidShapeError(
f'Values for `Embedding` column "{column_name}" '
f'should have shape {attrs["value_shape"]}, but '
f"value with shape {shape} received."
)
elif len(shape) == 1 and shape[0] > 0:
try:
target_shape = attrs["value_shape"]
except KeyError:
# Target shape isn't set, set.
attrs["value_shape"] = shape
else:
column_name = self._get_column_name(column)
raise exceptions.InvalidShapeError(
f'Values for `Embedding` column "{column_name}" should '
f"have shape `(num_features,)`, `num_features > 0`, "
f"but value with shape {shape} received."
)
# Target shape is set, check.
if shape != target_shape:
name = self._get_column_name(column)
dtype = self._get_dtype(column)
raise exceptions.InvalidShapeError(
f'Values for {dtype} column "{name}" should have shape '
f"{target_shape}, but value with shape {shape} received."
)
Loading

0 comments on commit 528e3c9

Please sign in to comment.