Skip to content

Commit

Permalink
Make TensorSchema a dataclass (#159)
Browse files Browse the repository at this point in the history
  • Loading branch information
gsakkis authored Jun 30, 2022
1 parent 6d25994 commit 233d1ee
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 23 deletions.
6 changes: 4 additions & 2 deletions tests/readers/test_tensor_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,8 @@ def parametrize_fields(*fields):
def test_max_partition_weight_dense(dense_uri, fields, key_dim_index, memory_budget):
config = {"py.max_incomplete_retries": 0, "sm.memory_budget": memory_budget}
with tiledb.open(dense_uri, config=config) as a:
schema = DenseTensorSchema(ArrayParams(a, key_dim_index, fields))
params = ArrayParams(a, key_dim_index, fields)
schema = DenseTensorSchema.from_array_params(params)
max_weight = schema.max_partition_weight
for key_range in schema.key_range.partition_by_weight(max_weight):
# query succeeds without incomplete retries
Expand All @@ -118,7 +119,8 @@ def test_max_partition_weight_sparse(sparse_uri, fields, key_dim_index, memory_b
}
with tiledb.open(sparse_uri, config=config) as a:
key_dim = a.dim(key_dim_index)
schema = SparseTensorSchema(ArrayParams(a, key_dim_index, fields))
params = ArrayParams(a, key_dim_index, fields)
schema = SparseTensorSchema.from_array_params(params)
max_weight = schema.max_partition_weight
for key_range in schema.key_range.partition_by_weight(max_weight):
# query succeeds without incomplete retries
Expand Down
39 changes: 23 additions & 16 deletions tiledb/ml/readers/_tensor_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from abc import ABC, abstractmethod
from collections import Counter
from dataclasses import dataclass
from math import ceil
from operator import itemgetter
from typing import (
Expand All @@ -11,6 +12,7 @@
Iterable,
Optional,
Sequence,
Tuple,
TypeVar,
Union,
cast,
Expand All @@ -27,23 +29,33 @@
Tensor = TypeVar("Tensor")


@dataclass(frozen=True)  # type: ignore
class TensorSchema(ABC):
    """
    A class to encapsulate the information needed for mapping a TileDB array to tensors.
    """

    # Underscore-prefixed dataclass fields; instances are normally built via
    # the `from_array_params` alternate constructor below, which maps the
    # keys of `ArrayParams._tensor_schema_kwargs` onto these names.
    _array: tiledb.Array
    _key_dim_index: int
    _fields: Sequence[str]
    _all_dims: Sequence[str]
    _ned: Sequence[Tuple[Any, Any]]
    _query_kwargs: Dict[str, Any]
    _transform: Optional[Callable[[Tensor], Tensor]]

    @classmethod
    def from_array_params(
        cls,
        array_params: ArrayParams,
        transform: Optional[Callable[[Tensor], Tensor]] = None,
    ) -> TensorSchema:
        """Alternate constructor: build a schema from an ``ArrayParams`` instance.

        :param array_params: source of the array/fields/dims configuration;
            its ``_tensor_schema_kwargs`` keys are prefixed with ``_`` to
            match the dataclass field names above.
        :param transform: optional callable applied to each produced tensor.
        """
        kwargs = {"_" + k: v for k, v in array_params._tensor_schema_kwargs.items()}
        return cls(_transform=transform, **kwargs)

    @property
    def fields(self) -> Sequence[str]:
        """Names of attributes and dimensions to read."""
        return self._fields

@property
def field_dtypes(self) -> Sequence[np.dtype]:
Expand Down Expand Up @@ -182,20 +194,15 @@ def max_partition_weight(self) -> int:
class SparseTensorSchema(TensorSchema):
    # Class-level flag distinguishing sparse from dense schemas.
    sparse = True

    def __init__(self, **kwargs: Any):
        """Initialize the dataclass fields, then precompute the key range.

        Counts the occurrences of each key-dimension value by running an
        incomplete-aware query over the key dimension only, so partitioning
        by weight can account for duplicate keys in sparse arrays.
        """
        super().__init__(**kwargs)
        # Sparse reads must also fetch all dimension coordinates.
        self._query_kwargs["dims"] = self._all_dims
        key_counter: Counter[Any] = Counter()
        key_dim = self._all_dims[0]
        query = self._array.query(dims=(key_dim,), attrs=(), return_incomplete=True)
        for result in query.multi_index[:]:
            key_counter.update(result[key_dim])
        self._key_range = InclusiveRange.factory(key_counter)

@property
def key_range(self) -> InclusiveRange[Any, int]:
Expand Down
6 changes: 3 additions & 3 deletions tiledb/ml/readers/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,11 +120,11 @@ def _worker_init(worker_id: int) -> None:

def _get_tensor_schema(array_params: ArrayParams) -> TensorSchema:
    """Select the TensorSchema subclass matching the array's layout.

    Dense arrays get a ``DenseTensorSchema``. Sparse 2-D arrays get a
    ``SparseTensorSchema`` with a ``tocsr`` transform (each COO result is
    converted to CSR); sparse arrays of any other rank are returned without
    a transform.
    """
    if not array_params.array.schema.sparse:
        return DenseTensorSchema.from_array_params(array_params)
    elif array_params.array.ndim == 2:
        return SparseTensorSchema.from_array_params(array_params, methodcaller("tocsr"))
    else:
        return SparseTensorSchema.from_array_params(array_params)


_SingleCollator = Callable[[TensorLikeSequence], torch.Tensor]
Expand Down
4 changes: 2 additions & 2 deletions tiledb/ml/readers/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,9 +82,9 @@ def _get_tensor_specs(schema: TensorSchema) -> Union[TensorSpec, Sequence[Tensor

def _get_tensor_schema(array_params: ArrayParams) -> TensorSchema:
    """Select the TensorSchema subclass matching the array's layout.

    Dense arrays get a ``DenseTensorSchema``; sparse arrays get a
    ``SparseTensorSchema`` whose results are converted to ``tf.SparseTensor``
    via the ``_coo_to_sparse_tensor`` transform.
    """
    if not array_params.array.schema.sparse:
        return DenseTensorSchema.from_array_params(array_params)
    else:
        return SparseTensorSchema.from_array_params(array_params, _coo_to_sparse_tensor)


def _coo_to_sparse_tensor(coo: sparse.COO) -> tf.SparseTensor:
Expand Down

0 comments on commit 233d1ee

Please sign in to comment.