Merge pull request gizatechxyz#624 from chachaleo/feat/TreeEnsemble
feat: tree ensemble
raphaelDkhn authored Mar 31, 2024
2 parents 57871a4 + ad74d40 commit b02b601
Showing 9 changed files with 1,077 additions and 0 deletions.
8 changes: 8 additions & 0 deletions docgen/src/main.rs
@@ -59,6 +59,14 @@ fn main() {
    doc_trait(trait_path, doc_path, label);
    doc_functions(trait_path, doc_path, trait_name, label);

    // TREE ENSEMBLE DOC
    let trait_path = "src/operators/ml/tree_ensemble/tree_ensemble.cairo";
    let doc_path = "docs/framework/operators/machine-learning/tree-ensemble";
    let label = "tree_ensemble";
    let trait_name: &str = "TreeEnsembleTrait";
    doc_trait(trait_path, doc_path, label);
    doc_functions(trait_path, doc_path, trait_name, label);

    // LINEAR REGRESSOR DOC
    let trait_path = "src/operators/ml/linear/linear_regressor.cairo";
    let doc_path = "docs/framework/operators/machine-learning/linear-regressor";
22 changes: 22 additions & 0 deletions docs/framework/operators/machine-learning/tree-ensemble/README.md
@@ -0,0 +1,22 @@
# Tree Ensemble

`TreeEnsembleTrait` provides a trait definition for tree ensemble problems.

```rust
use orion::operators::ml::TreeEnsembleTrait;
```

### Data types

Orion currently supports only fixed point data types for `TreeEnsembleTrait`.

| Data type | dtype |
| -------------------- | ------------------------------------------------------------- |
| Fixed point (signed) | `TreeEnsembleTrait<FP8x23 \| FP16x16 \| FP64x64 \| FP32x32>` |
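
As a rough illustration of what a fixed point value holds, the sketch below models `FP16x16` in plain Rust (not Cairo). It assumes the value is stored as a magnitude scaled by 2^16 plus a separate sign flag, which is how the `{ mag, sign }` literals in Orion examples read. The `Fp16x16` struct and the `encode` and `decode` helpers are hypothetical names and are not part of Orion.

```rust
/// Hypothetical stand-in for Orion's `FP16x16 { mag, sign }` (not the Cairo type).
#[derive(Debug, Clone, Copy, PartialEq)]
struct Fp16x16 {
    mag: u32,   // |x| scaled by 2^16 (assumed layout)
    sign: bool, // true when the value is negative
}

/// Scale factor for 16 fractional bits.
const ONE: f64 = 65536.0;

/// Converts a float to the assumed fixed point layout, rounding to nearest.
fn encode(x: f64) -> Fp16x16 {
    Fp16x16 {
        mag: (x.abs() * ONE).round() as u32,
        sign: x < 0.0,
    }
}

/// Converts the assumed fixed point layout back to a float.
fn decode(v: Fp16x16) -> f64 {
    let m = v.mag as f64 / ONE;
    if v.sign { -m } else { m }
}
```

Under these assumptions, `encode(1.2)` gives `mag: 78643` with `sign: false`, and `decode` of `{ mag: 32768, sign: true }` gives `-0.5`.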


***

| function | description |
| --- | --- |
| [`tree_ensemble.predict`](tree_ensemble.predict.md) | Returns the regressed values for each input in a batch. |
@@ -0,0 +1,139 @@
# TreeEnsemble::predict

```rust
fn predict(
    X: @Tensor<T>,
    nodes_splits: Tensor<T>,
    nodes_featureids: Span<usize>,
    nodes_modes: Span<NODE_MODE>,
    nodes_truenodeids: Span<usize>,
    nodes_falsenodeids: Span<usize>,
    nodes_trueleafs: Span<usize>,
    nodes_falseleafs: Span<usize>,
    leaf_targetids: Span<usize>,
    leaf_weights: Tensor<T>,
    tree_roots: Span<usize>,
    post_transform: POST_TRANSFORM,
    aggregate_function: AGGREGATE_FUNCTION,
    nodes_hitrates: Option<Tensor<T>>,
    nodes_missing_value_tracks_true: Option<Span<usize>>,
    membership_values: Option<Tensor<T>>,
    n_targets: usize
) -> MutMatrix::<T>;
```

Tree Ensemble operator. Returns the regressed values for each input in a batch. Inputs have dimensions `[N, F]`, where `N` is the input batch size and `F` is the number of input features. Outputs have dimensions `[N, n_targets]`, where `N` is the batch size and `n_targets` is the number of targets, which is a configurable attribute.

## Args

* `X`: Input 2D tensor.
* `nodes_splits`: Thresholds to split on for each node whose mode is not 'NODE_MODE::MEMBER'.
* `nodes_featureids`: Feature id for each node.
* `nodes_modes`: The comparison operation performed by each node. This is encoded as an enumeration of 'NODE_MODE::LEQ', 'NODE_MODE::LT', 'NODE_MODE::GTE', 'NODE_MODE::GT', 'NODE_MODE::EQ', 'NODE_MODE::NEQ', and 'NODE_MODE::MEMBER'.
* `nodes_truenodeids`: If `nodes_trueleafs` is 0 (false) at an entry, this represents the position of the true branch node.
* `nodes_falsenodeids`: If `nodes_falseleafs` is 0 (false) at an entry, this represents the position of the false branch node.
* `nodes_trueleafs`: For each node, 1 if the true branch is a leaf and 0 if it is an interior node.
* `nodes_falseleafs`: For each node, 1 if the false branch is a leaf and 0 if it is an interior node.
* `leaf_targetids`: The index of the target that this leaf contributes to (this must be in range `[0, n_targets)`).
* `leaf_weights`: The weight for each leaf.
* `tree_roots`: Index into `nodes_*` for the root of each tree. The tree structure is derived from the branching of each node.
* `post_transform`: Indicates the transform to apply to the score. One of 'POST_TRANSFORM::NONE', 'POST_TRANSFORM::SOFTMAX', 'POST_TRANSFORM::LOGISTIC', 'POST_TRANSFORM::SOFTMAXZERO' or 'POST_TRANSFORM::PROBIT'.
* `aggregate_function`: Defines how to aggregate leaf values within a target. One of 'AGGREGATE_FUNCTION::AVERAGE', 'AGGREGATE_FUNCTION::SUM', 'AGGREGATE_FUNCTION::MIN' or 'AGGREGATE_FUNCTION::MAX'. Defaults to 'AGGREGATE_FUNCTION::SUM'.
* `nodes_hitrates`: Popularity of each node, used for performance and may be omitted.
* `nodes_missing_value_tracks_true`: For each node, define whether to follow the true branch (if attribute value is 1) or false branch (if attribute value is 0) in the presence of a NaN input feature. This attribute may be left undefined and the default value is false (0) for all nodes.
* `membership_values`: Members to test membership of for each set membership node. Lists all of the members to test against, in the order that the 'NODE_MODE::MEMBER' mode appears in `nodes_modes`, delimited by `NaN`s. It has the same number of sets of values as there are nodes with mode 'NODE_MODE::MEMBER'. This may be omitted if the ensemble doesn't contain any 'NODE_MODE::MEMBER' nodes.
* `n_targets`: The total number of targets.
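
The node arrays above can be read as a flattened tree. The following Rust sketch (not Cairo, and not Orion's implementation) shows one way the traversal could work under these assumptions: plain `f64` stands in for the fixed point type, every node uses the `LEQ` comparison, aggregation is `SUM`, no post-transform is applied, and when a `nodes_trueleafs` or `nodes_falseleafs` entry is 1 the matching `nodes_truenodeids` or `nodes_falsenodeids` entry is taken as an index into the `leaf_*` arrays. The `Ensemble` struct and the `find_leaf` and `predict_sum` helpers are hypothetical names.

```rust
/// Illustrative only: `f64` replaces Orion's fixed point `T`, and every node
/// is assumed to use the LEQ comparison (`x[feature] <= split`).
struct Ensemble<'a> {
    nodes_splits: &'a [f64],
    nodes_featureids: &'a [usize],
    nodes_truenodeids: &'a [usize],
    nodes_falsenodeids: &'a [usize],
    nodes_trueleafs: &'a [usize],
    nodes_falseleafs: &'a [usize],
    leaf_targetids: &'a [usize],
    leaf_weights: &'a [f64],
    tree_roots: &'a [usize],
    n_targets: usize,
}

/// Walks one tree from `root` and returns the index of the leaf it reaches.
fn find_leaf(e: &Ensemble, root: usize, row: &[f64]) -> usize {
    let mut node = root;
    loop {
        let go_true = row[e.nodes_featureids[node]] <= e.nodes_splits[node];
        // Pick the branch target and the flag saying whether it is a leaf.
        let (next, is_leaf) = if go_true {
            (e.nodes_truenodeids[node], e.nodes_trueleafs[node])
        } else {
            (e.nodes_falsenodeids[node], e.nodes_falseleafs[node])
        };
        if is_leaf == 1 {
            return next; // assumed to index into leaf_targetids / leaf_weights
        }
        node = next; // otherwise an index into the nodes_* arrays
    }
}

/// SUM aggregation without post-transform: for every input row, each tree's
/// reached leaf adds its weight to the target it contributes to.
fn predict_sum(e: &Ensemble, x: &[Vec<f64>]) -> Vec<Vec<f64>> {
    x.iter()
        .map(|row| {
            let mut out = vec![0.0; e.n_targets];
            for &root in e.tree_roots {
                let leaf = find_leaf(e, root, row);
                out[e.leaf_targetids[leaf]] += e.leaf_weights[leaf];
            }
            out
        })
        .collect()
}
```

Fed with the approximate decimal values of the one-tree example in the Examples section (splits `[3.14, 1.2, 4.2]`, leaf weights `[5.23, 12.12, -12.23, 7.21]`, inputs `[1.2, 3.4]`, `[-0.12, 1.66]`, `[4.14, 1.77]`), `predict_sum` returns `[[5.23, 0.0], [5.23, 0.0], [0.0, 12.12]]`, in line with the documented output.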


## Returns

* Output of shape [Batch Size, Number of targets]

## Type Constraints

`T`, the element type of `X` and of the other tensor arguments, must be a fixed point type.

## Examples

```rust
use orion::numbers::FP16x16;
use orion::operators::tensor::{Tensor, TensorTrait, FP16x16Tensor, U32Tensor};
use orion::operators::ml::{TreeEnsembleTrait, POST_TRANSFORM, AGGREGATE_FUNCTION, NODE_MODE};
use orion::operators::matrix::{MutMatrix, MutMatrixImpl};
use orion::numbers::NumberTrait;

fn example_tree_ensemble_one_tree() -> MutMatrix::<FP16x16> {
    // Input X, shape [3, 2]: three samples with two features each.
    let mut shape = ArrayTrait::<usize>::new();
    shape.append(3);
    shape.append(2);

    let mut data = ArrayTrait::new();
    data.append(FP16x16 { mag: 78643, sign: false });
    data.append(FP16x16 { mag: 222822, sign: false });
    data.append(FP16x16 { mag: 7864, sign: true });
    data.append(FP16x16 { mag: 108789, sign: false });
    data.append(FP16x16 { mag: 271319, sign: false });
    data.append(FP16x16 { mag: 115998, sign: false });
    let mut X = TensorTrait::new(shape.span(), data.span());

    // Leaf weights, one per leaf (4 leaves).
    let mut shape = ArrayTrait::<usize>::new();
    shape.append(4);

    let mut data = ArrayTrait::new();
    data.append(FP16x16 { mag: 342753, sign: false });
    data.append(FP16x16 { mag: 794296, sign: false });
    data.append(FP16x16 { mag: 801505, sign: true });
    data.append(FP16x16 { mag: 472514, sign: false });
    let leaf_weights = TensorTrait::new(shape.span(), data.span());

    // Split thresholds, one per interior node (3 nodes).
    let mut shape = ArrayTrait::<usize>::new();
    shape.append(3);

    let mut data = ArrayTrait::new();
    data.append(FP16x16 { mag: 205783, sign: false });
    data.append(FP16x16 { mag: 78643, sign: false });
    data.append(FP16x16 { mag: 275251, sign: false });
    let nodes_splits = TensorTrait::new(shape.span(), data.span());

    let membership_values = Option::None;

    let n_targets = 2;
    let aggregate_function = AGGREGATE_FUNCTION::SUM;
    let nodes_missing_value_tracks_true = Option::None;
    let nodes_hitrates = Option::None;
    let post_transform = POST_TRANSFORM::NONE;

    // A single tree rooted at node 0; every node uses the LEQ comparison.
    let tree_roots: Span<usize> = array![0].span();
    let nodes_modes: Span<NODE_MODE> = array![NODE_MODE::LEQ, NODE_MODE::LEQ, NODE_MODE::LEQ].span();

    let nodes_featureids: Span<usize> = array![0, 0, 0].span();
    let nodes_truenodeids: Span<usize> = array![1, 0, 1].span();
    let nodes_trueleafs: Span<usize> = array![0, 1, 1].span();
    let nodes_falsenodeids: Span<usize> = array![2, 2, 3].span();
    let nodes_falseleafs: Span<usize> = array![0, 1, 1].span();
    let leaf_targetids: Span<usize> = array![0, 1, 0, 1].span();

    return TreeEnsembleTrait::predict(
        @X,
        nodes_splits,
        nodes_featureids,
        nodes_modes,
        nodes_truenodeids,
        nodes_falsenodeids,
        nodes_trueleafs,
        nodes_falseleafs,
        leaf_targetids,
        leaf_weights,
        tree_roots,
        post_transform,
        aggregate_function,
        nodes_hitrates,
        nodes_missing_value_tracks_true,
        membership_values,
        n_targets
    );
}

>>> [[ 5.23 0. ]
[ 5.23 0. ]
[ 0. 12.12]]
```
3 changes: 3 additions & 0 deletions src/operators/ml.cairo
@@ -3,6 +3,8 @@ mod linear;
mod svm;
mod normalizer;

use orion::operators::ml::tree_ensemble::tree_ensemble::{TreeEnsembleTrait};

use orion::operators::ml::tree_ensemble::core::{
    TreeEnsemble, TreeEnsembleAttributes, TreeEnsembleImpl, NODE_MODES
};
@@ -32,3 +34,4 @@ enum POST_TRANSFORM {
    SOFTMAXZERO,
    PROBIT,
}

1 change: 1 addition & 0 deletions src/operators/ml/tree_ensemble.cairo
@@ -1,3 +1,4 @@
mod core;
mod tree_ensemble_classifier;
mod tree_ensemble_regressor;
mod tree_ensemble;