Skip to content

Commit

Permalink
feat.tensor-complex
Browse files Browse the repository at this point in the history
  • Loading branch information
chachaleo committed Dec 22, 2023
1 parent 63c0a8c commit cb8de91
Show file tree
Hide file tree
Showing 51 changed files with 939 additions and 118 deletions.
16 changes: 16 additions & 0 deletions nodegen/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
class FixedImpl(Enum):
FP8x23 = 'FP8x23'
FP16x16 = 'FP16x16'
FP64x64 = 'FP64x64'



def to_fp(x: np.ndarray, fp_impl: FixedImpl):
Expand All @@ -18,15 +20,19 @@ def to_fp(x: np.ndarray, fp_impl: FixedImpl):
return (x * 2**23).astype(np.int64)
case FixedImpl.FP16x16:
return (x * 2**16).astype(np.int64)
case FixedImpl.FP64x64:
return (x * 2**64)


class Dtype(Enum):
FP8x23 = 'FP8x23'
FP16x16 = 'FP16x16'
FP64x64 = 'FP64x64'
I8 = 'i8'
I32 = 'i32'
U32 = 'u32'
BOOL = 'bool'
COMPLEX64 = 'complex64'


class Tensor:
Expand Down Expand Up @@ -166,8 +172,15 @@ def get_data_statement(data: np.ndarray, dtype: Dtype) -> list[str]:
return ["FP8x23 { "+f"mag: {abs(int(x))}, sign: {str(x < 0).lower()} "+"}" for x in data.flatten()]
case Dtype.FP16x16:
return ["FP16x16 { "+f"mag: {abs(int(x))}, sign: {str(x < 0).lower()} "+"}" for x in data.flatten()]
case Dtype.FP64x64:
return ["FP64x64 { "+f"mag: {abs(int(x))}, sign: {str(x < 0).lower()} "+"}" for x in data.flatten()]
case Dtype.BOOL:
return [str(x).lower() for x in data.flatten()]
case Dtype.COMPLEX64:
return ["complex64 { "+"real: FP64x64 { "+f"mag: {abs(int(np.real(x)))}, sign: {str(np.real(x) < 0).lower()} "+"} , img: FP64x64 { "+f"mag: {abs(int(np.imag(x)))}, sign: {str(np.imag(x) < 0).lower()} "+"} }" for x in data.flatten()]





def get_data_statement_for_sequences(data: Sequence, dtype: Dtype) -> list[list[str]]:
Expand Down Expand Up @@ -227,6 +240,7 @@ def find_all_types(tensors: list[Tensor | Sequence]) -> list[Dtype]:
Dtype.FP8x23: ["orion::operators::tensor::FP8x23Tensor",],
Dtype.FP16x16: ["orion::operators::tensor::FP16x16Tensor",],
Dtype.BOOL: ["orion::operators::tensor::BoolTensor",],
Dtype.COMPLEX64: ["orion::operators::tensor::Complex64Tensor",],
}


Expand All @@ -246,6 +260,7 @@ def find_all_types(tensors: list[Tensor | Sequence]) -> list[Dtype]:
Dtype.FP8x23: ["orion::operators::tensor::FP8x23TensorPartialEq",],
Dtype.FP16x16: ["orion::operators::tensor::FP16x16TensorPartialEq",],
Dtype.BOOL: ["orion::operators::tensor::BoolTensorPartialEq",],
Dtype.COMPLEX64: ["orion::operators::tensor::Complex64TensorPartialEq",],
}


Expand All @@ -256,4 +271,5 @@ def find_all_types(tensors: list[Tensor | Sequence]) -> list[Dtype]:
Dtype.FP8x23: ["orion::numbers::{FixedTrait, FP8x23}",],
Dtype.FP16x16: ["orion::numbers::{FixedTrait, FP16x16}",],
Dtype.BOOL: [],
Dtype.COMPLEX64: ["orion::numbers::{NumberTrait, complex64}",],
}
26 changes: 26 additions & 0 deletions nodegen/node/reduce_l2.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import numpy as np



class Reduce_l2(RunAll):
@staticmethod
def reduce_l2_fp8x23():
Expand Down Expand Up @@ -107,4 +108,29 @@ def reduce_l2_axis_0():

reduce_l2_export_do_not_keepdims()
reduce_l2_export_keepdims()
reduce_l2_axis_0()

@staticmethod
def reduce_l2_complex64():



def reduce_l2_axis_0():
shape = [2, 3]
axes = np.array([0], dtype=np.int64)
keepdims = True
x = np.reshape(np.array([1.+2.j, 2.-1.j, 3.-3.j, 3.-2.j, 3.+5.j, 4.- 1.j]), shape)
y = np.sqrt(np.sum(a=np.square(abs(x)), axis=tuple(axes), keepdims=True))
print(to_fp(x.flatten(), FixedImpl.FP64x64))

x = Tensor(Dtype.COMPLEX64, x.shape, to_fp(
x.flatten(), FixedImpl.FP64x64))

y = Tensor(Dtype.COMPLEX64, y.shape, to_fp(
y.flatten(), FixedImpl.FP64x64))

name = "reduce_l2_complex64_axis_0"
make_test(
[x], y, "input_0.reduce_l2(0, true)", name)

reduce_l2_axis_0()
7 changes: 6 additions & 1 deletion src/numbers/complex_number/complex64.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,12 @@ impl Complex64Impl of ComplexTrait<complex64, FP64x64> {
let y = self.img;
let two = FP64x64Impl::new(TWO, false);
let real = (((x.pow(two) + y.pow(two)).sqrt() + x) / two).sqrt();
let img = (((x.pow(two) + y.pow(two)).sqrt() - x) / two).sqrt();
let img = if y == FP64x64Impl::ZERO() {
FP64x64Impl::ZERO()
} else {
(((x.pow(two) + y.pow(two)).sqrt() - x) / two).sqrt()
};

let img = FP64x64Impl::new(img.mag, y.sign);
complex64 { real, img }
}
Expand Down
4 changes: 3 additions & 1 deletion src/operators/ml.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,6 @@ use orion::operators::ml::tree_ensemble::tree_ensemble_classifier::{
use orion::operators::ml::tree_ensemble::tree_ensemble_regressor::{
TreeEnsembleRegressor, TreeEnsembleRegressorImpl, TreeEnsembleRegressorTrait, AGGREGATE_FUNCTION
};
use orion::operators::ml::linear::linear_regressor::{LinearRegressorTrait, LinearRegressorImpl, LinearRegressor};
use orion::operators::ml::linear::linear_regressor::{
LinearRegressorTrait, LinearRegressorImpl, LinearRegressor
};
4 changes: 4 additions & 0 deletions src/operators/tensor.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -39,3 +39,7 @@ use orion::operators::tensor::implementations::tensor_u32::{

use orion::operators::tensor::implementations::tensor_bool::{BoolTensor, BoolTensorPartialEq};

use orion::operators::tensor::implementations::tensor_complex64::{
Complex64Tensor, Complex64TensorAdd, Complex64TensorSub, Complex64TensorMul, Complex64TensorDiv,
Complex64TensorPartialEq,
};
1 change: 1 addition & 0 deletions src/operators/tensor/implementations.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,4 @@ mod tensor_fp64x64;
mod tensor_fp32x32;
mod tensor_fp16x16wide;
mod tensor_fp8x23wide;
mod tensor_complex64;
4 changes: 3 additions & 1 deletion src/operators/tensor/implementations/tensor_bool.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -472,7 +472,9 @@ impl BoolTensor of TensorTrait<bool> {
panic(array!['not supported!'])
}

fn gather_nd(self: @Tensor<bool>, indices: Tensor<usize>, batch_dims: Option<usize>) -> Tensor<bool> {
fn gather_nd(
self: @Tensor<bool>, indices: Tensor<usize>, batch_dims: Option<usize>
) -> Tensor<bool> {
math::gather_nd::gather_nd(self, indices, batch_dims)
}
}
Expand Down
Loading

0 comments on commit cb8de91

Please sign in to comment.