From d7732a3e96b75a20545c9a0185e46b45b1acc3a3 Mon Sep 17 00:00:00 2001 From: "zhenshan.cao" Date: Fri, 28 May 2021 15:05:27 +0800 Subject: [PATCH] Fix bug: infer binary_vector failed (#134) * Fix bug: infer binary_vector failed Signed-off-by: zhenshan.cao * Fix bug: use fpath Signed-off-by: zhenshan.cao --- examples/collection.py | 10 +++++----- examples/hello_milvus.py | 2 +- pymilvus_orm/types.py | 4 ++-- setup.py | 2 +- tests/test_types.py | 2 +- 5 files changed, 10 insertions(+), 10 deletions(-) diff --git a/examples/collection.py b/examples/collection.py index cd0fe2f..fd41504 100644 --- a/examples/collection.py +++ b/examples/collection.py @@ -64,7 +64,7 @@ def gen_default_fields(): default_fields = [ FieldSchema(name="int64", dtype=DataType.INT64, is_primary=False), - FieldSchema(name="float", dtype=DataType.FLOAT), + FieldSchema(name="double", dtype=DataType.DOUBLE), FieldSchema(name=default_float_vec_field_name, dtype=DataType.FLOAT_VECTOR, dim=default_dim) ] default_schema = CollectionSchema(fields=default_fields, description="test collection") @@ -74,7 +74,7 @@ def gen_default_fields(): def gen_default_fields_with_primary_key_1(): default_fields = [ FieldSchema(name="int64", dtype=DataType.INT64, is_primary=True), - FieldSchema(name="float", dtype=DataType.FLOAT), + FieldSchema(name="double", dtype=DataType.DOUBLE), FieldSchema(name=default_float_vec_field_name, dtype=DataType.FLOAT_VECTOR, dim=default_dim) ] default_schema = CollectionSchema(fields=default_fields, description="test collection") @@ -84,7 +84,7 @@ def gen_default_fields_with_primary_key_1(): def gen_default_fields_with_primary_key_2(): default_fields = [ FieldSchema(name="int64", dtype=DataType.INT64), - FieldSchema(name="float", dtype=DataType.FLOAT), + FieldSchema(name="double", dtype=DataType.DOUBLE), FieldSchema(name=default_float_vec_field_name, dtype=DataType.FLOAT_VECTOR, dim=default_dim) ] default_schema = CollectionSchema(fields=default_fields, description="test collection", primary_field="int64") @@ -94,7 +94,7 @@ def gen_default_fields_with_primary_key_2(): def gen_binary_schema(): binary_fields = [ FieldSchema(name="int64", dtype=DataType.INT64, is_primary=False), - FieldSchema(name="float", dtype=DataType.FLOAT), + FieldSchema(name="double", dtype=DataType.DOUBLE), FieldSchema(name=default_binary_vec_field_name, dtype=DataType.BINARY_VECTOR, dim=default_dim) ] default_schema = CollectionSchema(fields=binary_fields, description="test collection") @@ -226,4 +226,4 @@ def test_specify_primary_key(): test_collection_with_data() test_create_index_float_vector() test_create_index_binary_vector() -# test_specify_primary_key() +test_specify_primary_key() diff --git a/examples/hello_milvus.py b/examples/hello_milvus.py index eb92caa..d5ac73a 100644 --- a/examples/hello_milvus.py +++ b/examples/hello_milvus.py @@ -23,7 +23,7 @@ def hello_milvus(): dim = 128 default_fields = [ schema.FieldSchema(name="count", dtype=DataType.INT64, is_primary=False), - schema.FieldSchema(name="score", dtype=DataType.FLOAT), + schema.FieldSchema(name="score", dtype=DataType.DOUBLE), schema.FieldSchema(name="float_vector", dtype=DataType.FLOAT_VECTOR, dim=dim) ] default_schema = schema.CollectionSchema(fields=default_fields, description="test collection") diff --git a/pymilvus_orm/types.py b/pymilvus_orm/types.py index bab6d9e..b1ed84d 100644 --- a/pymilvus_orm/types.py +++ b/pymilvus_orm/types.py @@ -109,6 +109,8 @@ def infer_dtype_by_scaladata(data): return DataType.BOOL if isinstance(data, np.bool_): return DataType.BOOL + if isinstance(data, bytes): + return DataType.BINARY_VECTOR if is_float(data): return DataType.DOUBLE @@ -131,8 +133,6 @@ def infer_dtype_bydata(data): d_type = dtype_str_map.get(type_str, DataType.UNKNOWN) if is_numeric_datatype(d_type): d_type = DataType.FLOAT_VECTOR - elif type_str in ("bytes",): - d_type = DataType.BINARY_VECTOR else: d_type = DataType.UNKNOWN diff --git a/setup.py b/setup.py index ca15df3..bf58482 100644 --- a/setup.py +++ b/setup.py @@ -13,7 +13,7 @@ def simple_parse_requirements(fpath): requirements = [] - with open("requirements.txt", 'r') as f: + with open(fpath, 'r') as f: for line in f.readlines(): if line.startswith("--"): continue diff --git a/tests/test_types.py b/tests/test_types.py index 241313b..75d2216 100644 --- a/tests/test_types.py +++ b/tests/test_types.py @@ -67,7 +67,7 @@ def test_infer_dtype_bydata(self): [True], [1.0, 2.0], ["abc"], - [bytes("abc", encoding='ascii'), bytes("def", encoding='ascii')], + bytes("abc", encoding='ascii'), 1, True, "abc",