Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/208 support embedding length #375

Merged
merged 5 commits into from
Nov 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 44 additions & 24 deletions renumics/spotlight/dataset/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,16 +143,17 @@ def unescape_dataset_name(escaped_name: str) -> str:
np.ndarray,
list,
tuple,
range,
bool,
int,
float,
np.bool_,
np.integer,
np.floating,
),
"Window": (np.ndarray, list, tuple),
"Embedding": (spotlight_dtypes.Embedding, np.ndarray, list, tuple),
"Sequence1D": (spotlight_dtypes.Sequence1D, np.ndarray, list, tuple),
"Window": (np.ndarray, list, tuple, range),
"Embedding": (spotlight_dtypes.Embedding, np.ndarray, list, tuple, range),
"Sequence1D": (spotlight_dtypes.Sequence1D, np.ndarray, list, tuple, range),
"Audio": (spotlight_dtypes.Audio, bytes, str, os.PathLike),
"Image": (spotlight_dtypes.Image, bytes, str, os.PathLike, np.ndarray, list, tuple),
"Mesh": (spotlight_dtypes.Mesh, trimesh.Trimesh, str, os.PathLike),
Expand Down Expand Up @@ -1227,6 +1228,7 @@ def append_embedding_column(
default: EmbeddingColumnInputType = None,
description: Optional[str] = None,
tags: Optional[List[str]] = None,
length: Optional[int] = None,
dtype: Union[str, np.dtype] = "float32",
) -> None:
"""
Expand Down Expand Up @@ -1257,7 +1259,7 @@ def append_embedding_column(
)
self._append_column(
name,
spotlight_dtypes.embedding_dtype,
spotlight_dtypes.EmbeddingDType(length),
values,
h5py.vlen_dtype(np_dtype),
order,
Expand Down Expand Up @@ -1707,6 +1709,13 @@ def append_column(
append_column_fn = self.append_window_column
elif spotlight_dtypes.is_embedding_dtype(dtype):
append_column_fn = self.append_embedding_column
if dtype.length is not None:
if "length" in attrs and attrs["length"] != dtype.length:
raise exceptions.InvalidAttributeError(
f"Embedding length differs between `dtype` ({dtype.length}) "
f"and `length` ({attrs['length']}) keyword argument."
)
attrs["length"] = dtype.length
elif spotlight_dtypes.is_sequence_1d_dtype(dtype):
append_column_fn = self.append_sequence_1d_column
elif spotlight_dtypes.is_audio_dtype(dtype):
Expand Down Expand Up @@ -2482,6 +2491,8 @@ def _append_column(
)
self._column_names.add(name)
column.attrs["type"] = dtype.name
if spotlight_dtypes.is_embedding_dtype(dtype) and dtype.length is not None:
column.attrs["value_shape"] = (dtype.length,)
self.set_column_attributes(
name,
order,
Expand Down Expand Up @@ -2881,10 +2892,10 @@ def _encode_simple_values(
# embedding should go through `_encode_value` element-wise.
if values.ndim == 1:
# Handle 1-dimensional input as a single embedding.
self._assert_valid_or_set_embedding_shape(values.shape, column)
self._assert_valid_or_set_value_shape(values.shape, column)
values_list = list(np.broadcast_to(values, (1, len(values))))
elif values.ndim == 2:
self._assert_valid_or_set_embedding_shape(values.shape[1:], column)
self._assert_valid_or_set_value_shape(values.shape[1:], column)
values_list = list(values)
else:
raise exceptions.InvalidShapeError(
Expand Down Expand Up @@ -3108,7 +3119,7 @@ def _encode_simple_value(
if isinstance(value, spotlight_dtypes.Embedding):
value = value.encode(attrs.get("format", None))
value = np.asarray(value, dtype=column.dtype.metadata["vlen"])
self._assert_valid_or_set_embedding_shape(value.shape, column)
self._assert_valid_or_set_value_shape(value.shape, column)
return value
if isinstance(value, np.str_):
return value.tolist()
Expand Down Expand Up @@ -3159,7 +3170,7 @@ def _encode_ref_value(
value = spotlight_dtypes.Embedding(value) # type: ignore
value = value.encode()
self._assert_valid_or_set_value_dtype(value.dtype, column)
self._assert_valid_or_set_embedding_shape(value.shape, column)
self._assert_valid_or_set_value_shape(value.shape, column)
return value
if spotlight_dtypes.is_sequence_1d_dtype(dtype):
if not isinstance(value, spotlight_dtypes.Sequence1D):
Expand Down Expand Up @@ -3475,6 +3486,12 @@ def _get_dtype(
return spotlight_dtypes.Sequence1DDType(
x.get("x_label", "x"), x.get("y_label", "y")
)
if type_name == "Embedding":
try:
length = x["value_shape"][0]
except (KeyError, IndexError):
return spotlight_dtypes.embedding_dtype
return spotlight_dtypes.EmbeddingDType(length)
return spotlight_dtypes.create_dtype(type_name)

@staticmethod
Expand Down Expand Up @@ -3585,6 +3602,9 @@ def _assert_index_exists(self, index: IndexType, check_type: bool = False) -> No
def _assert_valid_or_set_value_dtype(
self, dtype: np.dtype, column: h5py.Dataset
) -> None:
"""
Set value dtype for the whole column if not yet set, check shape otherwise.
"""
attrs = column.attrs
if "value_dtype" in attrs:
if dtype.str != attrs["value_dtype"]:
Expand All @@ -3604,27 +3624,27 @@ def _assert_valid_or_set_value_dtype(
f"received."
)

def _assert_valid_or_set_embedding_shape(
def _assert_valid_or_set_value_shape(
self, shape: Tuple[int, ...], column: h5py.Dataset
) -> None:
"""
Set value shape for the whole column if not yet set, check shape otherwise.
"""
attrs = column.attrs
if shape == (0,) and attrs.get("optional", False):
# Do not check shape if an empty array given for an optional column.
return
if "value_shape" in attrs:
if shape != attrs["value_shape"]:
column_name = self._get_column_name(column)
raise exceptions.InvalidShapeError(
f'Values for `Embedding` column "{column_name}" '
f'should have shape {attrs["value_shape"]}, but '
f"value with shape {shape} received."
)
elif len(shape) == 1 and shape[0] > 0:
try:
target_shape = attrs["value_shape"]
except KeyError:
# Target shape isn't set, set.
attrs["value_shape"] = shape
else:
column_name = self._get_column_name(column)
raise exceptions.InvalidShapeError(
f'Values for `Embedding` column "{column_name}" should '
f"have shape `(num_features,)`, `num_features > 0`, "
f"but value with shape {shape} received."
)
# Target shape is set, check.
if shape != target_shape:
name = self._get_column_name(column)
dtype = self._get_dtype(column)
raise exceptions.InvalidShapeError(
f'Values for {dtype} column "{name}" should have shape '
f"{target_shape}, but value with shape {shape} received."
)
Loading