fix docs and tests
0x73e committed Oct 23, 2023
1 parent 6462c61 commit 19266ef
Showing 3 changed files with 8 additions and 8 deletions.
@@ -1,7 +1,7 @@
# NNTrait::thresholded_relu

```rust
-fn thresholded_relu(inputs: @Tensor<T>, alpha: @T) -> Tensor<T>
+fn thresholded_relu(tensor: @Tensor<T>, alpha: @T) -> Tensor<T>
```

Applies the thresholded rectified linear unit (Thresholded ReLU) activation function element-wise to a given tensor.
@@ -35,13 +35,13 @@ fn thresholded_relu_example() -> Tensor<FP8x23> {
FixedTrait::new(0, false),
FixedTrait::new(256, false),
FixedTrait::new(512, false),
-FixedTrait::new(513, false),
+FixedTrait::new(257, false),
]
.span(),
);
let alpha = FixedTrait::from_felt(256); // 1.0

return NNTrait::leaky_relu(@tensor, @alpha);
}
->>> [[0, 0], [512, 513]]
+>>> [[0, 0], [512, 257]]
```
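
For reference, thresholded ReLU zeroes every element that is not strictly greater than `alpha` and passes the rest through unchanged. The NumPy sketch below is not part of the repository; it simply replays the documented example on the raw fixed-point values (where 256 stands for 1.0 in that example) to show how the expected output `[[0, 0], [512, 257]]` arises.

```python
import numpy as np

def thresholded_relu(x: np.ndarray, alpha: int) -> np.ndarray:
    # Thresholded ReLU: keep x where x > alpha, zero everything else.
    return np.where(x > alpha, x, 0)

# Raw values from the documented example (256 represents 1.0 here).
x = np.array([[0, 256], [512, 257]])
print(thresholded_relu(x, alpha=256))  # prints the expected [[0, 0], [512, 257]]
```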
nodegen/node/thresholded_relu.py (4 changes: 2 additions & 2 deletions)
@@ -6,7 +6,7 @@
class Thresholded_relu(RunAll):

@staticmethod
-    def leaky_thresholded_fp8x23():
+    def thresholded_relu_fp8x23():

alpha = 1.0

@@ -25,7 +25,7 @@ def leaky_thresholded_fp8x23():
name, Trait.NN)

@staticmethod
-    def leaky_thresholded_fp16x16():
+    def thresholded_relu_fp16x16():

alpha = 1.0

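
The renamed generators presumably follow the usual nodegen pattern: compute a float reference with NumPy, then convert it into raw fixed-point values for the generated Cairo test. The sketch below illustrates that idea only; `SCALE`, `to_raw`, and the input shape are assumptions for illustration, not the repository's actual helpers.

```python
import numpy as np

SCALE = 2 ** 16  # assumed scaling for an FP16x16-style layout; FP8x23 would use a different factor

def to_raw(x: np.ndarray) -> np.ndarray:
    # Convert float values to raw fixed-point integers (assumed convention).
    return np.round(x * SCALE).astype(np.int64)

x = np.random.uniform(-3, 3, (2, 2))
alpha = 1.0
reference = np.where(x > alpha, x, 0.0)  # thresholded ReLU computed in float space
print(to_raw(reference))                 # raw values a generated test could assert against
```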
src/operators/nn/core.cairo (6 changes: 3 additions & 3 deletions)
@@ -454,7 +454,7 @@ trait NNTrait<T> {
/// # NNTrait::thresholded_relu
///
/// ```rust
-/// fn thresholded_relu(inputs: @Tensor<T>, alpha: @T) -> Tensor<T>
+/// fn thresholded_relu(tensor: @Tensor<T>, alpha: @T) -> Tensor<T>
/// ```
///
/// Applies the thresholded rectified linear unit (Thresholded ReLU) activation function element-wise to a given tensor.
@@ -488,15 +488,15 @@ trait NNTrait<T> {
/// FixedTrait::new(0, false),
/// FixedTrait::new(256, false),
/// FixedTrait::new(512, false),
-/// FixedTrait::new(513, false),
+/// FixedTrait::new(257, false),
/// ]
/// .span(),
/// );
/// let alpha = FixedTrait::from_felt(256); // 1.0
///
/// return NNTrait::leaky_relu(@tensor, @alpha);
/// }
-/// >>> [[0, 0], [512, 513]]
+/// >>> [[0, 0], [512, 257]]
/// ```
///
fn thresholded_relu(tensor: @Tensor<T>, alpha: @T) -> Tensor<T>;
