From 81f6aef7fb42faff5999c6b319d8fdd4ba48dda2 Mon Sep 17 00:00:00 2001 From: raphaelDkhn Date: Mon, 11 Mar 2024 17:36:55 +0100 Subject: [PATCH 1/4] refactor osiris deserializer --- osiris/cairo/serde/deserialize.py | 341 ++++++++++++++++++------------ tests/test_deserialize.py | 36 ++-- 2 files changed, 223 insertions(+), 154 deletions(-) diff --git a/osiris/cairo/serde/deserialize.py b/osiris/cairo/serde/deserialize.py index bc7ea7b..f86dcf7 100644 --- a/osiris/cairo/serde/deserialize.py +++ b/osiris/cairo/serde/deserialize.py @@ -1,172 +1,241 @@ -import json - import numpy as np - -from .utils import felt_to_int, from_fp - - -def deserializer(serialized: str, dtype: str): - # Check if the serialized data is a string and needs conversion - if isinstance(serialized, str): - serialized = convert_data(serialized) - - # Function to deserialize individual elements within a tuple - def deserialize_element(element, element_type): - if element_type in ("u8", "u16", "u32", "u64", "u128", "i8", "i16", "i32", "i64", "i128"): - return deserialize_int(element) - elif element_type.startswith("FP"): - return deserialize_fixed_point(element, element_type) - elif element_type.startswith("Span<") and element_type.endswith(">"): - inner_type = element_type[5:-1] - if inner_type.startswith("FP"): - return deserialize_arr_fixed_point(element, inner_type) - else: - return deserialize_arr_int(element) - elif element_type.startswith("Tensor<") and element_type.endswith(">"): - inner_type = element_type[7:-1] - if inner_type.startswith("FP"): - return deserialize_tensor_fixed_point(element, inner_type) - else: - return deserialize_tensor_int(element) - elif element_type.startswith("(") and element_type.endswith(")"): - # Recursive call for nested tuples - return deserializer(element, element_type) +from math import isclose + + +def deserializer(serialized, dtype): + if dtype in ['u32', 'i32']: + return int(serialized) + + elif dtype == 'FP16x16': + parts = serialized.split() + value = int(parts[0]) / 2**16 + if len(parts) > 1 and parts[1] == '1': # Check for negative sign + value = -value + return value + + elif dtype.startswith('Span<'): + inner_type = dtype[5:-1] + if 'FP16x16' in inner_type: + # For FP16x16, elements consist of two parts (value and sign) + elements = serialized[1:-1].split() + deserialized_elements = [] + for i in range(0, len(elements), 2): + element = ' '.join(elements[i:i+2]) + deserialized_elements.append(deserializer(element, inner_type)) + return np.array(deserialized_elements, dtype=np.float64) else: - raise ValueError(f"Unsupported data type: {element_type}") - - # Handle tuple data type - if dtype.startswith("(") and dtype.endswith(")"): - types = dtype[1:-1].split(", ") - deserialized_elements = [] - i = 0 # Initialize loop counter - - while i < len(serialized): - ele_type = types[len(deserialized_elements)] - - if ele_type.startswith("Tensor<"): - # For Tensors, take two elements from serialized (shape and data) - ele = serialized[i:i+2] - i += 2 - else: - # For other types, take one element - ele = serialized[i] - i += 1 - - if ele_type.startswith("Tensor<"): - deserialized_elements.append( - deserialize_element(ele, ele_type)) - else: - deserialized_elements.append( - deserialize_element([ele], ele_type)) - - if len(deserialized_elements) != len(types): - raise ValueError( - "Serialized data length does not match tuple length") - - return tuple(deserialized_elements) - - else: - return deserialize_element(serialized, dtype) - - -def parse_return_value(return_value): - """ - Parse a ReturnValue dictionary to extract the integer value or recursively parse an array of ReturnValues (cf: OrionRunner ReturnValues). - """ - if 'Int' in return_value: - # Convert hexadecimal string to integer - return int(return_value['Int'], 16) - elif 'Array' in return_value: - # Recursively parse each item in the array - return [parse_return_value(item) for item in return_value['Array']] - else: - raise ValueError("Invalid ReturnValue format") - - -def convert_data(data): - """ - Convert the given JSON-like data structure to the desired format. - """ - parsed_data = json.loads(data) - result = [] - for item in parsed_data: - # Parse each item based on its keys - if 'Array' in item: - # Process array items - result.append(parse_return_value(item)) - elif 'Int' in item: - # Process single int items - result.append(parse_return_value(item)) + elements = serialized[1:-1].split() + return np.array([deserializer(e, inner_type) for e in elements], dtype=np.int64) + + elif dtype.startswith('Tensor<'): + inner_type = dtype[7:-1] + parts = serialized.split('] [') + dims = [int(d) for d in parts[0][1:].split()] + if 'FP16x16' in inner_type: + values = parts[1][:-1].split() # Split the values normally first + # Now, process every two items (value and sign) as one FP16x16 element + tensor_data = np.array([deserializer( + ' '.join(values[i:i+2]), inner_type) for i in range(0, len(values), 2)]) else: - raise ValueError("Invalid data format") - return result + values = parts[1][:-1].split() + tensor_data = np.array( + [deserializer(v, inner_type) for v in values]) + return tensor_data.reshape(dims) + + elif dtype.startswith('('): # Tuple + types = dtype[1:-1].split(', ') + if 'Tensor' in types[0]: # Handling Tensor as the first element in the tuple + tensor_end = serialized.find(']') + 2 # Find the end of the Tensor definition + # Handle cases where there might be nested arrays or tensors + depth = 1 + for i in range(tensor_end, len(serialized)): + if serialized[i] == '[': + depth += 1 + elif serialized[i] == ']': + depth -= 1 + if depth == 0: + tensor_end = i + 1 + break + part1 = deserializer(serialized[:tensor_end].strip(), types[0]) + part2 = deserializer(serialized[tensor_end:].strip(), types[1]) + else: + split_index = serialized.find(']') + 2 + part1 = deserializer(serialized[:split_index].strip(), types[0]) + part2 = deserializer(serialized[split_index:].strip(), types[1]) + return part1, part2 + + else: + raise ValueError(f"Unknown data type: {dtype}") + +# import json + +# import numpy as np + +# from .utils import felt_to_int, from_fp + + +# def deserializer(serialized: str, dtype: str): +# # Check if the serialized data is a string and needs conversion +# if isinstance(serialized, str): +# serialized = convert_data(serialized) + +# # Function to deserialize individual elements within a tuple +# def deserialize_element(element, element_type): +# if element_type in ("u8", "u16", "u32", "u64", "u128", "i8", "i16", "i32", "i64", "i128"): +# return deserialize_int(element) +# elif element_type.startswith("FP"): +# return deserialize_fixed_point(element, element_type) +# elif element_type.startswith("Span<") and element_type.endswith(">"): +# inner_type = element_type[5:-1] +# if inner_type.startswith("FP"): +# return deserialize_arr_fixed_point(element, inner_type) +# else: +# return deserialize_arr_int(element) +# elif element_type.startswith("Tensor<") and element_type.endswith(">"): +# inner_type = element_type[7:-1] +# if inner_type.startswith("FP"): +# return deserialize_tensor_fixed_point(element, inner_type) +# else: +# return deserialize_tensor_int(element) +# elif element_type.startswith("(") and element_type.endswith(")"): +# # Recursive call for nested tuples +# return deserializer(element, element_type) +# else: +# raise ValueError(f"Unsupported data type: {element_type}") + +# # Handle tuple data type +# if dtype.startswith("(") and dtype.endswith(")"): +# types = dtype[1:-1].split(", ") +# deserialized_elements = [] +# i = 0 # Initialize loop counter + +# while i < len(serialized): +# ele_type = types[len(deserialized_elements)] + +# if ele_type.startswith("Tensor<"): +# # For Tensors, take two elements from serialized (shape and data) +# ele = serialized[i:i+2] +# i += 2 +# else: +# # For other types, take one element +# ele = serialized[i] +# i += 1 + +# if ele_type.startswith("Tensor<"): +# deserialized_elements.append( +# deserialize_element(ele, ele_type)) +# else: +# deserialized_elements.append( +# deserialize_element([ele], ele_type)) + +# if len(deserialized_elements) != len(types): +# raise ValueError( +# "Serialized data length does not match tuple length") + +# return tuple(deserialized_elements) + +# else: +# return deserialize_element(serialized, dtype) + + +# def parse_return_value(return_value): +# """ +# Parse a ReturnValue dictionary to extract the integer value or recursively parse an array of ReturnValues (cf: OrionRunner ReturnValues). +# """ +# if 'Int' in return_value: +# # Convert hexadecimal string to integer +# return int(return_value['Int'], 16) +# elif 'Array' in return_value: +# # Recursively parse each item in the array +# return [parse_return_value(item) for item in return_value['Array']] +# else: +# raise ValueError("Invalid ReturnValue format") + + +# def convert_data(data): +# """ +# Convert the given JSON-like data structure to the desired format. +# """ +# parsed_data = json.loads(data) +# result = [] +# for item in parsed_data: +# # Parse each item based on its keys +# if 'Array' in item: +# # Process array items +# result.append(parse_return_value(item)) +# elif 'Int' in item: +# # Process single int items +# result.append(parse_return_value(item)) +# else: +# raise ValueError("Invalid data format") +# return result -# ================= INT ================= +# # ================= INT ================= -def deserialize_int(serialized: list) -> np.int64: - return np.int64(felt_to_int(serialized[0])) +# def deserialize_int(serialized: list) -> np.int64: +# return np.int64(felt_to_int(serialized[0])) -# ================= FIXED POINT ================= +# # ================= FIXED POINT ================= -def deserialize_fixed_point(serialized: list, impl='FP16x16') -> np.float64: - serialized_mag = from_fp(serialized[0], impl) - serialized_sign = serialized[1] +# def deserialize_fixed_point(serialized: list, impl='FP16x16') -> np.float64: +# serialized_mag = from_fp(serialized[0], impl) +# serialized_sign = serialized[1] - deserialized = serialized_mag if serialized_sign == 0 else -serialized_mag - return np.float64(deserialized) +# deserialized = serialized_mag if serialized_sign == 0 else -serialized_mag +# return np.float64(deserialized) -# ================= ARRAY INT ================= +# # ================= ARRAY INT ================= -def deserialize_arr_int(serialized): +# def deserialize_arr_int(serialized): - serialized = serialized[0] +# serialized = serialized[0] - deserialized = [] - for ele in serialized: - deserialized.append(felt_to_int(ele)) +# deserialized = [] +# for ele in serialized: +# deserialized.append(felt_to_int(ele)) - return np.array(deserialized) +# return np.array(deserialized) -# ================= ARRAY FIXED POINT ================= +# # ================= ARRAY FIXED POINT ================= -def deserialize_arr_fixed_point(serialized: list, impl='FP16x16'): +# def deserialize_arr_fixed_point(serialized: list, impl='FP16x16'): - serialized = serialized[0] +# serialized = serialized[0] - if len(serialized) % 2 != 0: - raise ValueError("Array length must be even") +# if len(serialized) % 2 != 0: +# raise ValueError("Array length must be even") - deserialized = [] - for i in range(0, len(serialized), 2): - mag = serialized[i] - sign = serialized[i + 1] +# deserialized = [] +# for i in range(0, len(serialized), 2): +# mag = serialized[i] +# sign = serialized[i + 1] - deserialized.append(deserialize_fixed_point([mag, sign], impl)) +# deserialized.append(deserialize_fixed_point([mag, sign], impl)) - return np.array(deserialized) +# return np.array(deserialized) -# ================= TENSOR INT ================= +# # ================= TENSOR INT ================= -def deserialize_tensor_int(serialized: list) -> np.array: - shape = serialized[0] - data = deserialize_arr_int([serialized[1]]) +# def deserialize_tensor_int(serialized: list) -> np.array: +# shape = serialized[0] +# data = deserialize_arr_int([serialized[1]]) - return np.array(data, dtype=np.int64).reshape(shape) +# return np.array(data, dtype=np.int64).reshape(shape) -# ================= TENSOR FIXED POINT ================= +# # ================= TENSOR FIXED POINT ================= -def deserialize_tensor_fixed_point(serialized: list, impl='FP16x16') -> np.array: - shape = serialized[0] - data = deserialize_arr_fixed_point([serialized[1]], impl) +# def deserialize_tensor_fixed_point(serialized: list, impl='FP16x16') -> np.array: +# shape = serialized[0] +# data = deserialize_arr_fixed_point([serialized[1]], impl) - return np.array(data, dtype=np.float64).reshape(shape) +# return np.array(data, dtype=np.float64).reshape(shape) diff --git a/tests/test_deserialize.py b/tests/test_deserialize.py index e4f14c5..8734397 100644 --- a/tests/test_deserialize.py +++ b/tests/test_deserialize.py @@ -7,68 +7,68 @@ def test_deserialize_int(): - serialized = '[{"Int":"2A"}]' + serialized = '42' deserialized = deserializer(serialized, 'u32') assert deserialized == 42 - serialized = '[{"Int":"800000000000010FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFD7"}]' + serialized = '-42' deserialized = deserializer(serialized, 'i32') assert deserialized == -42 def test_deserialize_fp(): - serialized = '[{"Int":"2A6B85"}, {"Int":"0"}]' + serialized = '2780037 0' deserialized = deserializer(serialized, 'FP16x16') assert isclose(deserialized, 42.42, rel_tol=1e-7) - serialized = '[{"Int":"2A6B85"}, {"Int":"1"}]' + serialized = '2780037 1' deserialized = deserializer(serialized, 'FP16x16') assert isclose(deserialized, -42.42, rel_tol=1e-7) def test_deserialize_array_int(): - serialized = '[{"Array": [{"Int": "0x1"}, {"Int": "0x2"}]}]' + serialized = '[1 2]' deserialized = deserializer(serialized, 'Span') assert np.array_equal(deserialized, np.array([1, 2], dtype=np.int64)) - serialized = '[{"Array": [{"Int": "2A"}, {"Int": "800000000000010FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFD7"}]}]' + serialized = '[42 -42]' deserialized = deserializer(serialized, 'Span') assert np.array_equal(deserialized, np.array([42, -42], dtype=np.int64)) def test_deserialize_arr_fixed_point(): - serialized = '[{"Array": [{"Int": "2A6B85"}, {"Int": "0"}, {"Int": "2A6B85"}, {"Int": "0x1"}]}]' + serialized = '[2780037 0 2780037 1]' deserialized = deserializer(serialized, 'Span') expected = np.array([42.42, -42.42], dtype=np.float64) assert np.all(np.isclose(deserialized, expected, atol=1e-7)) def test_deserialize_tensor_int(): - serialized = '[{"Array": [{"Int": "0x2"}, {"Int": "0x2"}]}, {"Array": [{"Int": "0x1"}, {"Int": "0x2"}, {"Int": "0x3"}, {"Int": "0x4"}]}]' + serialized = '[2 2] [1 2 3 4]' deserialized = deserializer(serialized, 'Tensor') assert np.array_equal(deserialized, np.array( ([1, 2], [3, 4]), dtype=np.int64)) - serialized = '[{"Array": [{"Int": "0x2"}, {"Int": "0x2"}]}, {"Array": [{"Int": "2A"}, {"Int": "2A"},{"Int": "800000000000010FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFD7"}, {"Int": "800000000000010FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFD7"}]}]' + serialized = '[2 2] [42 42 -42 -42]' deserialized = deserializer(serialized, 'Tensor') assert np.array_equal(deserialized, np.array([[42, 42], [-42, -42]])) def test_deserialize_tensor_fixed_point(): - serialized = '[{"Array": [{"Int": "0x2"}, {"Int": "0x2"}]}, {"Array": [{"Int": "2A6B85"}, {"Int": "0x0"}, {"Int": "2A6B85"}, {"Int": "0x0"}, {"Int": "2A6B85"}, {"Int": "0x1"}, {"Int": "2A6B85"}, {"Int": "0x1"}]}]' + serialized = '[2 2] [2780037 0 2780037 0 2780037 1 2780037 1]' expected_array = np.array([[42.42, 42.42], [-42.42, -42.42]]) deserialized = deserializer(serialized, 'Tensor') assert np.allclose(deserialized, expected_array, atol=1e-7) def test_deserialize_tuple_int(): - serialized = '[{"Int":"0x1"},{"Int":"0x3"}]' + serialized = '1 3' deserialized = deserializer(serialized, '(u32, u32)') assert deserialized == (1, 3) def test_deserialize_tuple_span(): - serialized = '[{"Array":[{"Int":"0x1"},{"Int":"0x2"}]},{"Int":"0x3"}]' + serialized = '[1 2] 3' deserialized = deserializer(serialized, '(Span, u32)') expected = (np.array([1, 2]), 3) npt.assert_array_equal(deserialized[0], expected[0]) @@ -76,14 +76,14 @@ def test_deserialize_tuple_span(): def test_deserialize_tuple_span_tensor_fp(): - serialized = '[{"Array":[{"Int":"0x1"},{"Int":"0x2"}]},{"Array": [{"Int": "0x2"}, {"Int": "0x2"}]}, {"Array": [{"Int": "2A6B85"}, {"Int": "0x0"}, {"Int": "2A6B85"}, {"Int": "0x0"}, {"Int": "2A6B85"}, {"Int": "0x1"}, {"Int": "2A6B85"}, {"Int": "0x1"}]}]' + serialized = '[1 2] [2 2] [2780037 0 2780037 0 2780037 1 2780037 1]' deserialized = deserializer(serialized, '(Span, Tensor)') expected = (np.array([1, 2]), np.array([[42.42, 42.42], [-42.42, -42.42]])) npt.assert_array_equal(deserialized[0], expected[0]) assert np.allclose(deserialized[1], expected[1], atol=1e-7) - serialized = '[{"Array": [{"Int": "0x2"}, {"Int": "0x2"}]}, {"Array": [{"Int": "2A6B85"}, {"Int": "0x0"}, {"Int": "2A6B85"}, {"Int": "0x0"}, {"Int": "2A6B85"}, {"Int": "0x1"}, {"Int": "2A6B85"}, {"Int": "0x1"}]}, {"Array":[{"Int":"0x1"},{"Int":"0x2"}]}]' - deserialized = deserializer(serialized, '(Tensor, Span)') - expected = (np.array([[42.42, 42.42], [-42.42, -42.42]]), np.array([1, 2])) - assert np.allclose(deserialized[0], expected[0], atol=1e-7) - npt.assert_array_equal(deserialized[1], expected[1]) + # serialized = '[2 2] [2780037 0 2780037 0 2780037 1 2780037 1] [1 2]' + # deserialized = deserializer(serialized, '(Tensor, Span)') + # expected = (np.array([[42.42, 42.42], [-42.42, -42.42]]), np.array([1, 2])) + # assert np.allclose(deserialized[0], expected[0], atol=1e-7) + # npt.assert_array_equal(deserialized[1], expected[1]) From 6526e81c3a1f78ce8227c6849aa4404b0a9aa010 Mon Sep 17 00:00:00 2001 From: raphaelDkhn Date: Thu, 14 Mar 2024 09:40:42 +0100 Subject: [PATCH 2/4] fix complex tuples --- osiris/cairo/serde/deserialize.py | 183 ++---------------------------- tests/test_deserialize.py | 11 +- 2 files changed, 15 insertions(+), 179 deletions(-) diff --git a/osiris/cairo/serde/deserialize.py b/osiris/cairo/serde/deserialize.py index f86dcf7..1d5ca39 100644 --- a/osiris/cairo/serde/deserialize.py +++ b/osiris/cairo/serde/deserialize.py @@ -45,7 +45,9 @@ def deserializer(serialized, dtype): elif dtype.startswith('('): # Tuple types = dtype[1:-1].split(', ') if 'Tensor' in types[0]: # Handling Tensor as the first element in the tuple - tensor_end = serialized.find(']') + 2 # Find the end of the Tensor definition + # Find the end of the Tensor definition + tensor_end = find_nth_occurrence(serialized, ']', 2) + # Handle cases where there might be nested arrays or tensors depth = 1 for i in range(tensor_end, len(serialized)): @@ -63,179 +65,14 @@ def deserializer(serialized, dtype): part1 = deserializer(serialized[:split_index].strip(), types[0]) part2 = deserializer(serialized[split_index:].strip(), types[1]) return part1, part2 - + else: raise ValueError(f"Unknown data type: {dtype}") -# import json - -# import numpy as np - -# from .utils import felt_to_int, from_fp - - -# def deserializer(serialized: str, dtype: str): -# # Check if the serialized data is a string and needs conversion -# if isinstance(serialized, str): -# serialized = convert_data(serialized) - -# # Function to deserialize individual elements within a tuple -# def deserialize_element(element, element_type): -# if element_type in ("u8", "u16", "u32", "u64", "u128", "i8", "i16", "i32", "i64", "i128"): -# return deserialize_int(element) -# elif element_type.startswith("FP"): -# return deserialize_fixed_point(element, element_type) -# elif element_type.startswith("Span<") and element_type.endswith(">"): -# inner_type = element_type[5:-1] -# if inner_type.startswith("FP"): -# return deserialize_arr_fixed_point(element, inner_type) -# else: -# return deserialize_arr_int(element) -# elif element_type.startswith("Tensor<") and element_type.endswith(">"): -# inner_type = element_type[7:-1] -# if inner_type.startswith("FP"): -# return deserialize_tensor_fixed_point(element, inner_type) -# else: -# return deserialize_tensor_int(element) -# elif element_type.startswith("(") and element_type.endswith(")"): -# # Recursive call for nested tuples -# return deserializer(element, element_type) -# else: -# raise ValueError(f"Unsupported data type: {element_type}") - -# # Handle tuple data type -# if dtype.startswith("(") and dtype.endswith(")"): -# types = dtype[1:-1].split(", ") -# deserialized_elements = [] -# i = 0 # Initialize loop counter - -# while i < len(serialized): -# ele_type = types[len(deserialized_elements)] - -# if ele_type.startswith("Tensor<"): -# # For Tensors, take two elements from serialized (shape and data) -# ele = serialized[i:i+2] -# i += 2 -# else: -# # For other types, take one element -# ele = serialized[i] -# i += 1 - -# if ele_type.startswith("Tensor<"): -# deserialized_elements.append( -# deserialize_element(ele, ele_type)) -# else: -# deserialized_elements.append( -# deserialize_element([ele], ele_type)) - -# if len(deserialized_elements) != len(types): -# raise ValueError( -# "Serialized data length does not match tuple length") - -# return tuple(deserialized_elements) - -# else: -# return deserialize_element(serialized, dtype) - - -# def parse_return_value(return_value): -# """ -# Parse a ReturnValue dictionary to extract the integer value or recursively parse an array of ReturnValues (cf: OrionRunner ReturnValues). -# """ -# if 'Int' in return_value: -# # Convert hexadecimal string to integer -# return int(return_value['Int'], 16) -# elif 'Array' in return_value: -# # Recursively parse each item in the array -# return [parse_return_value(item) for item in return_value['Array']] -# else: -# raise ValueError("Invalid ReturnValue format") - - -# def convert_data(data): -# """ -# Convert the given JSON-like data structure to the desired format. -# """ -# parsed_data = json.loads(data) -# result = [] -# for item in parsed_data: -# # Parse each item based on its keys -# if 'Array' in item: -# # Process array items -# result.append(parse_return_value(item)) -# elif 'Int' in item: -# # Process single int items -# result.append(parse_return_value(item)) -# else: -# raise ValueError("Invalid data format") -# return result - - -# # ================= INT ================= - - -# def deserialize_int(serialized: list) -> np.int64: -# return np.int64(felt_to_int(serialized[0])) - - -# # ================= FIXED POINT ================= - - -# def deserialize_fixed_point(serialized: list, impl='FP16x16') -> np.float64: -# serialized_mag = from_fp(serialized[0], impl) -# serialized_sign = serialized[1] - -# deserialized = serialized_mag if serialized_sign == 0 else -serialized_mag -# return np.float64(deserialized) - - -# # ================= ARRAY INT ================= - - -# def deserialize_arr_int(serialized): - -# serialized = serialized[0] - -# deserialized = [] -# for ele in serialized: -# deserialized.append(felt_to_int(ele)) - -# return np.array(deserialized) - -# # ================= ARRAY FIXED POINT ================= - - -# def deserialize_arr_fixed_point(serialized: list, impl='FP16x16'): - -# serialized = serialized[0] - -# if len(serialized) % 2 != 0: -# raise ValueError("Array length must be even") - -# deserialized = [] -# for i in range(0, len(serialized), 2): -# mag = serialized[i] -# sign = serialized[i + 1] - -# deserialized.append(deserialize_fixed_point([mag, sign], impl)) - -# return np.array(deserialized) - - -# # ================= TENSOR INT ================= - - -# def deserialize_tensor_int(serialized: list) -> np.array: -# shape = serialized[0] -# data = deserialize_arr_int([serialized[1]]) - -# return np.array(data, dtype=np.int64).reshape(shape) - - -# # ================= TENSOR FIXED POINT ================= - -# def deserialize_tensor_fixed_point(serialized: list, impl='FP16x16') -> np.array: -# shape = serialized[0] -# data = deserialize_arr_fixed_point([serialized[1]], impl) -# return np.array(data, dtype=np.float64).reshape(shape) +def find_nth_occurrence(string, sub_string, n): + start_index = string.find(sub_string) + while start_index >= 0 and n > 1: + start_index = string.find(sub_string, start_index + 1) + n -= 1 + return start_index diff --git a/tests/test_deserialize.py b/tests/test_deserialize.py index 8734397..9b71b17 100644 --- a/tests/test_deserialize.py +++ b/tests/test_deserialize.py @@ -1,6 +1,5 @@ import numpy as np import numpy.testing as npt -import pytest from math import isclose from osiris.cairo.serde.deserialize import * @@ -82,8 +81,8 @@ def test_deserialize_tuple_span_tensor_fp(): npt.assert_array_equal(deserialized[0], expected[0]) assert np.allclose(deserialized[1], expected[1], atol=1e-7) - # serialized = '[2 2] [2780037 0 2780037 0 2780037 1 2780037 1] [1 2]' - # deserialized = deserializer(serialized, '(Tensor, Span)') - # expected = (np.array([[42.42, 42.42], [-42.42, -42.42]]), np.array([1, 2])) - # assert np.allclose(deserialized[0], expected[0], atol=1e-7) - # npt.assert_array_equal(deserialized[1], expected[1]) + serialized = '[2 2] [2780037 0 2780037 0 2780037 1 2780037 1] [1 2]' + deserialized = deserializer(serialized, '(Tensor, Span)') + expected = (np.array([[42.42, 42.42], [-42.42, -42.42]]), np.array([1, 2])) + assert np.allclose(deserialized[0], expected[0], atol=1e-7) + npt.assert_array_equal(deserialized[1], expected[1]) From 8ac0cdb8d684f6871afd7d13a9e2ba21285dce3f Mon Sep 17 00:00:00 2001 From: raphaelDkhn Date: Thu, 14 Mar 2024 10:01:13 +0100 Subject: [PATCH 3/4] handle negativity and non default fp --- osiris/cairo/serde/deserialize.py | 18 +++++++++--------- tests/test_deserialize.py | 6 +++--- 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/osiris/cairo/serde/deserialize.py b/osiris/cairo/serde/deserialize.py index 1d5ca39..c624018 100644 --- a/osiris/cairo/serde/deserialize.py +++ b/osiris/cairo/serde/deserialize.py @@ -1,22 +1,22 @@ import numpy as np -from math import isclose +from osiris.cairo.serde.utils import felt_to_int, from_fp def deserializer(serialized, dtype): - if dtype in ['u32', 'i32']: - return int(serialized) + if dtype in ["u8", "u16", "u32", "u64", "u128", "i8", "i16", "i32", "i64", "i128"]: + return felt_to_int(int(serialized)) - elif dtype == 'FP16x16': + elif dtype.startswith("FP"): parts = serialized.split() - value = int(parts[0]) / 2**16 + value = from_fp(int(parts[0])) if len(parts) > 1 and parts[1] == '1': # Check for negative sign value = -value return value elif dtype.startswith('Span<'): inner_type = dtype[5:-1] - if 'FP16x16' in inner_type: - # For FP16x16, elements consist of two parts (value and sign) + if inner_type.startswith("FP"): + # For fixed point, elements consist of two parts (value and sign) elements = serialized[1:-1].split() deserialized_elements = [] for i in range(0, len(elements), 2): @@ -31,9 +31,9 @@ def deserializer(serialized, dtype): inner_type = dtype[7:-1] parts = serialized.split('] [') dims = [int(d) for d in parts[0][1:].split()] - if 'FP16x16' in inner_type: + if inner_type.startswith("FP"): values = parts[1][:-1].split() # Split the values normally first - # Now, process every two items (value and sign) as one FP16x16 element + # Now, process every two items (value and sign) as one fixed point element tensor_data = np.array([deserializer( ' '.join(values[i:i+2]), inner_type) for i in range(0, len(values), 2)]) else: diff --git a/tests/test_deserialize.py b/tests/test_deserialize.py index 9b71b17..e78502e 100644 --- a/tests/test_deserialize.py +++ b/tests/test_deserialize.py @@ -10,7 +10,7 @@ def test_deserialize_int(): deserialized = deserializer(serialized, 'u32') assert deserialized == 42 - serialized = '-42' + serialized = '3618502788666131213697322783095070105623107215331596699973092056135872020439' deserialized = deserializer(serialized, 'i32') assert deserialized == -42 @@ -30,7 +30,7 @@ def test_deserialize_array_int(): deserialized = deserializer(serialized, 'Span') assert np.array_equal(deserialized, np.array([1, 2], dtype=np.int64)) - serialized = '[42 -42]' + serialized = '[42 3618502788666131213697322783095070105623107215331596699973092056135872020439]' deserialized = deserializer(serialized, 'Span') assert np.array_equal(deserialized, np.array([42, -42], dtype=np.int64)) @@ -48,7 +48,7 @@ def test_deserialize_tensor_int(): assert np.array_equal(deserialized, np.array( ([1, 2], [3, 4]), dtype=np.int64)) - serialized = '[2 2] [42 42 -42 -42]' + serialized = '[2 2] [42 42 3618502788666131213697322783095070105623107215331596699973092056135872020439 3618502788666131213697322783095070105623107215331596699973092056135872020439]' deserialized = deserializer(serialized, 'Tensor') assert np.array_equal(deserialized, np.array([[42, 42], [-42, -42]])) From a238162cf60a4a8080f1485f38c5d75781d6862d Mon Sep 17 00:00:00 2001 From: raphaelDkhn Date: Thu, 14 Mar 2024 10:04:08 +0100 Subject: [PATCH 4/4] refactor code --- osiris/cairo/serde/deserialize.py | 114 ++++++++++++++++-------------- 1 file changed, 62 insertions(+), 52 deletions(-) diff --git a/osiris/cairo/serde/deserialize.py b/osiris/cairo/serde/deserialize.py index c624018..f7158e7 100644 --- a/osiris/cairo/serde/deserialize.py +++ b/osiris/cairo/serde/deserialize.py @@ -2,74 +2,84 @@ from osiris.cairo.serde.utils import felt_to_int, from_fp + def deserializer(serialized, dtype): + if dtype in ["u8", "u16", "u32", "u64", "u128", "i8", "i16", "i32", "i64", "i128"]: return felt_to_int(int(serialized)) elif dtype.startswith("FP"): - parts = serialized.split() - value = from_fp(int(parts[0])) - if len(parts) > 1 and parts[1] == '1': # Check for negative sign - value = -value - return value + return deserialize_fp(serialized) elif dtype.startswith('Span<'): - inner_type = dtype[5:-1] - if inner_type.startswith("FP"): - # For fixed point, elements consist of two parts (value and sign) - elements = serialized[1:-1].split() - deserialized_elements = [] - for i in range(0, len(elements), 2): - element = ' '.join(elements[i:i+2]) - deserialized_elements.append(deserializer(element, inner_type)) - return np.array(deserialized_elements, dtype=np.float64) - else: - elements = serialized[1:-1].split() - return np.array([deserializer(e, inner_type) for e in elements], dtype=np.int64) + return deserialize_span(serialized, dtype) elif dtype.startswith('Tensor<'): - inner_type = dtype[7:-1] - parts = serialized.split('] [') - dims = [int(d) for d in parts[0][1:].split()] - if inner_type.startswith("FP"): - values = parts[1][:-1].split() # Split the values normally first - # Now, process every two items (value and sign) as one fixed point element - tensor_data = np.array([deserializer( - ' '.join(values[i:i+2]), inner_type) for i in range(0, len(values), 2)]) - else: - values = parts[1][:-1].split() - tensor_data = np.array( - [deserializer(v, inner_type) for v in values]) - return tensor_data.reshape(dims) + return deserialize_tensor(serialized, dtype) elif dtype.startswith('('): # Tuple - types = dtype[1:-1].split(', ') - if 'Tensor' in types[0]: # Handling Tensor as the first element in the tuple - # Find the end of the Tensor definition - tensor_end = find_nth_occurrence(serialized, ']', 2) - - # Handle cases where there might be nested arrays or tensors - depth = 1 - for i in range(tensor_end, len(serialized)): - if serialized[i] == '[': - depth += 1 - elif serialized[i] == ']': - depth -= 1 - if depth == 0: - tensor_end = i + 1 - break - part1 = deserializer(serialized[:tensor_end].strip(), types[0]) - part2 = deserializer(serialized[tensor_end:].strip(), types[1]) - else: - split_index = serialized.find(']') + 2 - part1 = deserializer(serialized[:split_index].strip(), types[0]) - part2 = deserializer(serialized[split_index:].strip(), types[1]) - return part1, part2 + return deserialize_tuple(serialized, dtype) else: raise ValueError(f"Unknown data type: {dtype}") +def deserialize_fp(serialized): + parts = serialized.split() + value = from_fp(int(parts[0])) + if len(parts) > 1 and parts[1] == '1': # Check for negative sign + value = -value + return value + + +def deserialize_span(serialized, dtype): + inner_type = dtype[5:-1] + elements = serialized[1:-1].split() + if inner_type.startswith("FP"): + # For fixed point, elements consist of two parts (value and sign) + deserialized_elements = [deserializer(' '.join(elements[i:i + 2]), inner_type) + for i in range(0, len(elements), 2)] + return np.array(deserialized_elements, dtype=np.float64) + else: + return np.array([deserializer(e, inner_type) for e in elements], dtype=np.int64) + + +def deserialize_tensor(serialized, dtype): + inner_type = dtype[7:-1] + parts = serialized.split('] [') + dims = [int(d) for d in parts[0][1:].split()] + values = parts[1][:-1].split() + if inner_type.startswith("FP"): + tensor_data = np.array([deserializer(' '.join(values[i:i + 2]), inner_type) + for i in range(0, len(values), 2)]) + else: + tensor_data = np.array( + [deserializer(v, inner_type) for v in values]) + return tensor_data.reshape(dims) + + +def deserialize_tuple(serialized, dtype): + types = dtype[1:-1].split(', ') + if 'Tensor' in types[0]: + tensor_end = find_nth_occurrence(serialized, ']', 2) + depth = 1 + for i in range(tensor_end, len(serialized)): + if serialized[i] == '[': + depth += 1 + elif serialized[i] == ']': + depth -= 1 + if depth == 0: + tensor_end = i + 1 + break + part1 = deserializer(serialized[:tensor_end].strip(), types[0]) + part2 = deserializer(serialized[tensor_end:].strip(), types[1]) + else: + split_index = serialized.find(']') + 2 + part1 = deserializer(serialized[:split_index].strip(), types[0]) + part2 = deserializer(serialized[split_index:].strip(), types[1]) + return part1, part2 + + def find_nth_occurrence(string, sub_string, n): start_index = string.find(sub_string) while start_index >= 0 and n > 1: