Skip to content

Commit

Permalink
Add python support for Array (#2492)
Browse files Browse the repository at this point in the history
### What problem does this PR solve?

Support insert and select Array data in python clients

Issue link:#2455

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
- [x] Breaking Change (fix or feature that could cause existing
functionality not to work as expected)
- [x] Refactoring
- [x] Test cases
- [x] Python SDK impacted, Need to update PyPI
  • Loading branch information
yangzq50 authored Jan 23, 2025
1 parent 939d180 commit 5c97af6
Show file tree
Hide file tree
Showing 18 changed files with 3,198 additions and 2,253 deletions.
21 changes: 19 additions & 2 deletions python/infinity_embedded/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

from pathlib import Path
from typing import Union
from dataclasses import dataclass
from dataclasses import dataclass, field
import numpy as np


Expand Down Expand Up @@ -60,9 +60,26 @@ def __repr__(self):
return str(self)


@dataclass
class Array:
elements: list = field(default_factory=list)

def append(self, element):
self.elements.append(element)

def __init__(self, *args):
self.elements = list(args)

def __str__(self):
return f"Array({', '.join(str(e) for e in self.elements)})"

def __repr__(self):
return str(self)


URI = Union[NetworkAddress, Path]
VEC = Union[list, np.ndarray]
INSERT_DATA = dict[str, Union[str, int, float, list[Union[int, float]]], SparseVector, dict]
INSERT_DATA = dict[str, Union[str, int, float, list[Union[int, float]]], SparseVector, dict, Array]

LOCAL_HOST = NetworkAddress("127.0.0.1", 23817)

Expand Down
311 changes: 215 additions & 96 deletions python/infinity_embedded/local_infinity/types.py

Large diffs are not rendered by default.

23 changes: 19 additions & 4 deletions python/infinity_embedded/local_infinity/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
import sqlglot.expressions as exp
import numpy as np
from infinity_embedded.errors import ErrorCode
from infinity_embedded.common import InfinityException, SparseVector
from infinity_embedded.common import InfinityException, SparseVector, Array
from infinity_embedded.local_infinity.types import build_result, logic_type_to_dtype
from infinity_embedded.utils import binary_exp_to_paser_exp
from infinity_embedded.embedded_infinity_ext import WrapInExpr, WrapParsedExpr, WrapFunctionExpr, \
Expand Down Expand Up @@ -297,7 +297,10 @@ def get_local_constant_expr_from_python_value(value) -> WrapConstantExpr:
case _:
raise InfinityException(ErrorCode.INVALID_EXPRESSION,
f"Invalid sparse vector value type: {type(next(iter(value.values())))}")

case Array():
constant_expression.literal_type = LiteralType.kCurlyBracketsArray
constant_expression.curly_brackets_array = [get_local_constant_expr_from_python_value(child) for child in
value.elements]
case _:
raise InfinityException(ErrorCode.INVALID_EXPRESSION, f"Invalid constant type: {type(value)}")
return constant_expression
Expand Down Expand Up @@ -506,6 +509,10 @@ def get_data_type(column_info: dict) -> WrapDataType:
raise InfinityException(ErrorCode.NO_COLUMN_DEFINED, f"Column definition without data type")
datatype = column_info["type"].lower()
column_big_info = [item.strip() for item in datatype.split(",")]
return get_data_type_from_column_big_info(column_big_info)


def get_data_type_from_column_big_info(column_big_info: list) -> WrapDataType:
column_big_info_first_str = column_big_info[0]
match column_big_info_first_str:
case "vector" | "multivector" | "tensor" | "tensorarray":
Expand All @@ -516,9 +523,17 @@ def get_data_type(column_info: dict) -> WrapDataType:
sparse_type = get_sparse_type(column_big_info)
return sparse_type
# return get_sparse_info(column_info, column_defs, column_name, index)
case "array":
proto_column_type = WrapDataType()
proto_column_type.logical_type = LogicalType.kArray
proto_column_type.array_type = get_data_type_from_column_big_info(column_big_info[1:])
return proto_column_type
case _:
if len(column_big_info) > 1:
raise InfinityException(ErrorCode.INVALID_DATA_TYPE,
f"Unknown datatype: {column_big_info}, too many arguments")
proto_column_type = WrapDataType()
match datatype:
match column_big_info_first_str:
case "int8":
proto_column_type.logical_type = LogicalType.kTinyInt
case "int16":
Expand Down Expand Up @@ -550,7 +565,7 @@ def get_data_type(column_info: dict) -> WrapDataType:
case "timestamp":
proto_column_type.logical_type = LogicalType.kTimestamp
case _:
raise InfinityException(ErrorCode.INVALID_DATA_TYPE, f"Unknown datatype: {datatype}")
raise InfinityException(ErrorCode.INVALID_DATA_TYPE, f"Unknown datatype: {column_big_info_first_str}")
return proto_column_type


Expand Down
21 changes: 19 additions & 2 deletions python/infinity_sdk/infinity/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.
from pathlib import Path
from typing import Union
from dataclasses import dataclass
from dataclasses import dataclass, field
import numpy as np


Expand Down Expand Up @@ -59,9 +59,26 @@ def __repr__(self):
return str(self)


@dataclass
class Array:
elements: list = field(default_factory=list)

def append(self, element):
self.elements.append(element)

def __init__(self, *args):
self.elements = list(args)

def __str__(self):
return f"Array({', '.join(str(e) for e in self.elements)})"

def __repr__(self):
return str(self)


URI = Union[NetworkAddress, Path]
VEC = Union[list, np.ndarray]
INSERT_DATA = dict[str, Union[str, int, float, list[Union[int, float]]], SparseVector, dict]
INSERT_DATA = dict[str, Union[str, int, float, list[Union[int, float]]], SparseVector, dict, Array]

LOCAL_HOST = NetworkAddress("127.0.0.1", 23817)

Expand Down
Loading

0 comments on commit 5c97af6

Please sign in to comment.