From 57871a4c9eb4a14dde2bfc03eefa88adf0e6b298 Mon Sep 17 00:00:00 2001 From: raphaelDkhn Date: Wed, 27 Mar 2024 17:22:54 +0100 Subject: [PATCH 1/3] fix gather_elements --- .../tensor/math/gather_elements.cairo | 94 +++++++------------ 1 file changed, 35 insertions(+), 59 deletions(-) diff --git a/src/operators/tensor/math/gather_elements.cairo b/src/operators/tensor/math/gather_elements.cairo index cc8b9ae20..e4b624e42 100644 --- a/src/operators/tensor/math/gather_elements.cairo +++ b/src/operators/tensor/math/gather_elements.cairo @@ -1,7 +1,9 @@ +use core::option::OptionTrait; +use core::traits::TryInto; use alexandria_data_structures::array_ext::SpanTraitExt; use orion::numbers::NumberTrait; -use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::{{TensorTrait, Tensor}, core::{unravel_index, stride}}; /// Cf: TensorTrait::gather_elements docstring fn gather_elements, impl TCopy: Copy, impl TDrop: Drop,>( @@ -19,71 +21,45 @@ fn gather_elements, impl TCopy: Copy, im }; assert(axis < (*self.shape).len(), 'axis out of dimensions'); - let axis_shape = *(*self.shape).at(axis); - - // Adjust indices that are negative - let mut adjusted_indices = array![]; - let mut indices_data = indices.data.clone(); - loop { - match indices_data.pop_front() { - Option::Some(index) => { - let adjusted_index: usize = if *index < 0 { - let val: u32 = (axis_shape.try_into().unwrap() + *index).try_into().unwrap(); - val - } else { - let val: u32 = (*index).try_into().unwrap(); - val - }; - assert(adjusted_index >= 0 && adjusted_index < axis_shape, 'Index out of bounds'); - adjusted_indices.append(adjusted_index); - }, - Option::None => { break; } - }; - }; + let data_strides = stride(*self.shape); let mut output_data = array![]; - let mut data_shape_clone = (*self.shape).clone(); - let mut multiplier = 1; - let mut looper = 1; - let mut ind = 0; - loop { - match data_shape_clone.pop_front() { - Option::Some(val) => { - if ind >= axis { - multiplier *= *val; - 
} - if ind > axis { - looper *= *val; - } - ind += 1; - }, - Option::None => { break; } - }; - }; + let mut i: usize = 0; + while i < indices + .data + .len() { + let indice = *indices.data.at(i); + let adjusted_indice: u32 = if indice < 0 { + ((*(*self.shape).at(axis)).try_into().unwrap() + indice).try_into().unwrap() + } else { + indice.try_into().unwrap() + }; - let inner_loop = multiplier / axis_shape; - let mut adjusted_indices_iter = adjusted_indices.clone(); + assert(adjusted_indice < (*(*self.shape).at(axis)), 'Index out of bounds'); - let mut i: usize = 0; - loop { - match adjusted_indices_iter.pop_front() { - Option::Some(indice) => { - let value = if axis == 0 { - indice * inner_loop + (i % inner_loop) - } else if axis == (*self.shape).len() - 1 { - indice + axis_shape * (i / axis_shape) - } else { - indice * looper - + (i % looper) - + (multiplier / axis_shape) * (i / (multiplier / axis_shape)) + let multidim_index = unravel_index(i, indices.shape); + let mut flat_index_for_data = 0; + + let mut j: usize = 0; + while j < multidim_index + .len() { + let dim_index = *multidim_index.at(j); + if j == axis { + flat_index_for_data += adjusted_indice * (*data_strides.at(j)); + } else { + flat_index_for_data += (dim_index * *data_strides.at(j)) + } + j += 1; }; - output_data.append(*self.data[value]); - i += 1; - }, - Option::None => { break; } + assert( + flat_index_for_data < (*self.data).len().try_into().unwrap(), + 'Flat index out of bounds' + ); + + output_data.append(*(*self.data).at(flat_index_for_data)); + i += 1; }; - }; TensorTrait::::new(indices.shape, output_data.span()) } From ad74d400650de1c3ca4245f7085700360a9153db Mon Sep 17 00:00:00 2001 From: chachaleo Date: Thu, 28 Mar 2024 20:40:39 +0100 Subject: [PATCH 2/3] feat: tree ensemble --- docgen/src/main.rs | 8 + .../machine-learning/tree-ensemble/README.md | 22 + .../tree-ensemble/tree_ensemble.predict.md | 139 ++++ src/operators/ml.cairo | 3 + src/operators/ml/tree_ensemble.cairo | 1 + 
.../ml/tree_ensemble/tree_ensemble.cairo | 602 ++++++++++++++++++ tests/lib.cairo | 1 + tests/ml.cairo | 1 + tests/ml/tree_ensemble_test.cairo | 300 +++++++++ 9 files changed, 1077 insertions(+) create mode 100644 docs/framework/operators/machine-learning/tree-ensemble/README.md create mode 100644 docs/framework/operators/machine-learning/tree-ensemble/tree_ensemble.predict.md create mode 100644 src/operators/ml/tree_ensemble/tree_ensemble.cairo create mode 100644 tests/ml/tree_ensemble_test.cairo diff --git a/docgen/src/main.rs b/docgen/src/main.rs index 8d1f90f4b..fb7dba8f1 100644 --- a/docgen/src/main.rs +++ b/docgen/src/main.rs @@ -59,6 +59,14 @@ fn main() { doc_trait(trait_path, doc_path, label); doc_functions(trait_path, doc_path, trait_name, label); + // TREE ENSEMBLE DOC + let trait_path = "src/operators/ml/tree_ensemble/tree_ensemble.cairo"; + let doc_path = "docs/framework/operators/machine-learning/tree-ensemble"; + let label = "tree_ensemble"; + let trait_name: &str = "TreeEnsembleTrait"; + doc_trait(trait_path, doc_path, label); + doc_functions(trait_path, doc_path, trait_name, label); + // LINEAR REGRESSOR DOC let trait_path = "src/operators/ml/linear/linear_regressor.cairo"; let doc_path = "docs/framework/operators/machine-learning/linear-regressor"; diff --git a/docs/framework/operators/machine-learning/tree-ensemble/README.md b/docs/framework/operators/machine-learning/tree-ensemble/README.md new file mode 100644 index 000000000..26fcfb205 --- /dev/null +++ b/docs/framework/operators/machine-learning/tree-ensemble/README.md @@ -0,0 +1,22 @@ +# Tree Ensemble + +`TreeEnsembleTrait` provides a trait definition for tree ensemble problem. + +```rust +use orion::operators::ml::TreeEnsembleTrait; +``` + +### Data types + +Orion supports currently only fixed point data types for `TreeEnsembleTrait`. 
+ +| Data type | dtype | +| -------------------- | ------------------------------------------------------------- | +| Fixed point (signed) | `TreeEnsembleTrait` | + + +*** + +| function | description | +| --- | --- | +| [`tree_ensemble.predict`](tree_ensemble.predict.md) | Returns the regressed values for each input in a batch. | \ No newline at end of file diff --git a/docs/framework/operators/machine-learning/tree-ensemble/tree_ensemble.predict.md b/docs/framework/operators/machine-learning/tree-ensemble/tree_ensemble.predict.md new file mode 100644 index 000000000..a7f97e96d --- /dev/null +++ b/docs/framework/operators/machine-learning/tree-ensemble/tree_ensemble.predict.md @@ -0,0 +1,139 @@ +# TreeEnsemble::predict + +```rust + fn predict(X: @Tensor, + nodes_splits: Tensor, + nodes_featureids: Span, + nodes_modes: Span, + nodes_truenodeids: Span, + nodes_falsenodeids: Span, + nodes_trueleafs: Span, + nodes_falseleafs: Span, + leaf_targetids: Span, + leaf_weights: Tensor, + tree_roots: Span, + post_transform: POST_TRANSFORM, + aggregate_function: AGGREGATE_FUNCTION, + nodes_hitrates: Option>, + nodes_missing_value_tracks_true: Option>, + membership_values: Option>, + n_targets: usize + ) -> MutMatrix::; +``` + +Tree Ensemble operator. Returns the regressed values for each input in a batch. Inputs have dimensions [N, F] where N is the input batch size and F is the number of input features. Outputs have dimensions [N, num_targets] where N is the batch size and num_targets is the number of targets, which is a configurable attribute. + +## Args + +* `X`: Input 2D tensor. +* `nodes_splits`: Thresholds to do the splitting on for each node with mode that is not 'BRANCH_MEMBER'. +* `nodes_featureids`: Feature id for each node. +* `nodes_modes`: The comparison operation performed by the node. 
This is encoded as an enumeration of 'NODE_MODE::LEQ', 'NODE_MODE::LT', 'NODE_MODE::GTE', 'NODE_MODE::GT', 'NODE_MODE::EQ', 'NODE_MODE::NEQ', and 'NODE_MODE::MEMBER'
+* `nodes_truenodeids`: If `nodes_trueleafs` is 0 (false) at an entry, this represents the position of the true branch node.
+* `nodes_falsenodeids`: If `nodes_falseleafs` is 0 (false) at an entry, this represents the position of the false branch node.
+* `nodes_trueleafs`: 1 if true branch is leaf for each node and 0 an interior node.
+* `nodes_falseleafs`: 1 if false branch is leaf for each node and 0 an interior node.
+* `leaf_targetids`: The index of the target that this leaf contributes to (this must be in range `[0, n_targets)`).
+* `leaf_weights`: The weight for each leaf.
+* `tree_roots`: Index into `nodes_*` for the root of each tree. The tree structure is derived from the branching of each node.
+* `post_transform`: Indicates the transform to apply to the score. One of 'POST_TRANSFORM::NONE', 'POST_TRANSFORM::SOFTMAX', 'POST_TRANSFORM::LOGISTIC', 'POST_TRANSFORM::SOFTMAX_ZERO' or 'POST_TRANSFORM::PROBIT',
+* `aggregate_function`: Defines how to aggregate leaf values within a target. One of 'AGGREGATE_FUNCTION::AVERAGE', 'AGGREGATE_FUNCTION::SUM', 'AGGREGATE_FUNCTION::MIN', 'AGGREGATE_FUNCTION::MAX'; defaults to 'AGGREGATE_FUNCTION::SUM'
+* `nodes_hitrates`: Popularity of each node, used for performance and may be omitted.
+* `nodes_missing_value_tracks_true`: For each node, define whether to follow the true branch (if attribute value is 1) or false branch (if attribute value is 0) in the presence of a NaN input feature. This attribute may be left undefined and the default value is false (0) for all nodes.
+* `membership_values`: Members to test membership of for each set membership node. List all of the members to test against in the order that the 'BRANCH_MEMBER' mode appears in `node_modes`, delimited by `NaN`s. Will have the same number of sets of values as nodes with mode 'BRANCH_MEMBER'. 
This may be omitted if the node doesn't contain any 'BRANCH_MEMBER' nodes. +* `n_targets`: The total number of targets. + + +## Returns + +* Output of shape [Batch Size, Number of targets] + +## Type Constraints + +`TreeEnsembleClassifier` and `X` must be fixed points + +## Examples + +```rust +use orion::numbers::FP16x16; +use orion::operators::tensor::{Tensor, TensorTrait, FP16x16Tensor, U32Tensor}; +use orion::operators::ml::{TreeEnsembleTrait,POST_TRANSFORM, AGGREGATE_FUNCTION, NODE_MODE}; +use orion::operators::matrix::{MutMatrix, MutMatrixImpl}; +use orion::numbers::NumberTrait; + +fn example_tree_ensemble_one_tree() -> MutMatrix:: { + let mut shape = ArrayTrait::::new(); + shape.append(3); + shape.append(2); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 78643, sign: false }); + data.append(FP16x16 { mag: 222822, sign: false }); + data.append(FP16x16 { mag: 7864, sign: true }); + data.append(FP16x16 { mag: 108789, sign: false }); + data.append(FP16x16 { mag: 271319, sign: false }); + data.append(FP16x16 { mag: 115998, sign: false }); + let mut X = TensorTrait::new(shape.span(), data.span()); + + let mut shape = ArrayTrait::::new(); + shape.append(4); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 342753, sign: false }); + data.append(FP16x16 { mag: 794296, sign: false }); + data.append(FP16x16 { mag: 801505, sign: true }); + data.append(FP16x16 { mag: 472514, sign: false }); + let leaf_weights = TensorTrait::new(shape.span(), data.span()); + + let mut shape = ArrayTrait::::new(); + shape.append(3); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 205783, sign: false }); + data.append(FP16x16 { mag: 78643, sign: false }); + data.append(FP16x16 { mag: 275251, sign: false }); + let nodes_splits = TensorTrait::new(shape.span(), data.span()); + + let membership_values = Option::None; + + let n_targets = 2; + let aggregate_function = AGGREGATE_FUNCTION::SUM; + let nodes_missing_value_tracks_true = 
Option::None; + let nodes_hitrates = Option::None; + let post_transform = POST_TRANSFORM::NONE; + + let tree_roots: Span = array![0].span(); + let nodes_modes: Span = array![MODE::LEQ, MODE::LEQ, MODE::LEQ].span(); + + let nodes_featureids: Span = array![0, 0, 0].span(); + let nodes_truenodeids: Span = array![1, 0, 1].span(); + let nodes_trueleafs: Span = array![0, 1, 1].span(); + let nodes_falsenodeids: Span = array![2, 2, 3].span(); + let nodes_falseleafs: Span = array![0, 1, 1].span(); + let leaf_targetids: Span = array![0, 1, 0, 1].span(); + + return TreeEnsembleTrait::predict( + @X, + nodes_splits, + nodes_featureids, + nodes_modes, + nodes_truenodeids, + nodes_falsenodeids, + nodes_trueleafs, + nodes_falseleafs, + leaf_targetids, + leaf_weights, + tree_roots, + post_transform, + aggregate_function, + nodes_hitrates, + nodes_missing_value_tracks_true, + membership_values, + n_targets + ); +} + +>>> [[ 5.23 0. ] + [ 5.23 0. ] + [ 0. 12.12]] +``` diff --git a/src/operators/ml.cairo b/src/operators/ml.cairo index 08e9e40fb..93a0394a6 100644 --- a/src/operators/ml.cairo +++ b/src/operators/ml.cairo @@ -3,6 +3,8 @@ mod linear; mod svm; mod normalizer; +use orion::operators::ml::tree_ensemble::tree_ensemble::{TreeEnsembleTrait}; + use orion::operators::ml::tree_ensemble::core::{ TreeEnsemble, TreeEnsembleAttributes, TreeEnsembleImpl, NODE_MODES }; @@ -32,3 +34,4 @@ enum POST_TRANSFORM { SOFTMAXZERO, PROBIT, } + diff --git a/src/operators/ml/tree_ensemble.cairo b/src/operators/ml/tree_ensemble.cairo index 32c96c0bd..925c1ea7e 100644 --- a/src/operators/ml/tree_ensemble.cairo +++ b/src/operators/ml/tree_ensemble.cairo @@ -1,3 +1,4 @@ mod core; mod tree_ensemble_classifier; mod tree_ensemble_regressor; +mod tree_ensemble; diff --git a/src/operators/ml/tree_ensemble/tree_ensemble.cairo b/src/operators/ml/tree_ensemble/tree_ensemble.cairo new file mode 100644 index 000000000..51e3f9ec2 --- /dev/null +++ b/src/operators/ml/tree_ensemble/tree_ensemble.cairo @@ -0,0 +1,602 
@@ +use orion::operators::tensor::{Tensor, TensorTrait}; +use orion::numbers::NumberTrait; + +use orion::operators::matrix::{MutMatrix, MutMatrixImpl, MutMatrixTrait}; + +#[derive(Copy, Drop)] +enum AGGREGATE_FUNCTION { + AVERAGE, + SUM, + MIN, + MAX, +} + +#[derive(Copy, Drop)] +enum POST_TRANSFORM { + NONE, + SOFTMAX, + LOGISTIC, + SOFTMAX_ZERO, + PROBIT, +} + +#[derive(Copy, Drop)] +enum NODE_MODE { + LEQ, + LT, + GTE, + GT, + EQ, + NEQ, + MEMBER, +} + +/// Trait +/// +/// predict - Returns the regressed values for each input in a batch. +trait TreeEnsembleTrait { + /// # TreeEnsemble::predict + /// + /// ```rust + /// fn predict(X: @Tensor, + /// nodes_splits: Tensor, + /// nodes_featureids: Span, + /// nodes_modes: Span, + /// nodes_truenodeids: Span, + /// nodes_falsenodeids: Span, + /// nodes_trueleafs: Span, + /// nodes_falseleafs: Span, + /// leaf_targetids: Span, + /// leaf_weights: Tensor, + /// tree_roots: Span, + /// post_transform: POST_TRANSFORM, + /// aggregate_function: AGGREGATE_FUNCTION, + /// nodes_hitrates: Option>, + /// nodes_missing_value_tracks_true: Option>, + /// membership_values: Option>, + /// n_targets: usize + /// ) -> MutMatrix::; + /// ``` + /// + /// Tree Ensemble operator. Returns the regressed values for each input in a batch. Inputs have dimensions [N, F] where N is the input batch size and F is the number of input features. Outputs have dimensions [N, num_targets] where N is the batch size and num_targets is the number of targets, which is a configurable attribute. + /// + /// ## Args + /// + /// * `X`: Input 2D tensor. + /// * `nodes_splits`: Thresholds to do the splitting on for each node with mode that is not 'BRANCH_MEMBER'. + /// * `nodes_featureids`: Feature id for each node. + /// * `nodes_modes`: The comparison operation performed by the node. 
This is encoded as an enumeration of 'NODE_MODE::LEQ', 'NODE_MODE::LT', 'NODE_MODE::GTE', 'NODE_MODE::GT', 'NODE_MODE::EQ', 'NODE_MODE::NEQ', and 'NODE_MODE::MEMBER'
+    /// * `nodes_truenodeids`: If `nodes_trueleafs` is 0 (false) at an entry, this represents the position of the true branch node.
+    /// * `nodes_falsenodeids`: If `nodes_falseleafs` is 0 (false) at an entry, this represents the position of the false branch node.
+    /// * `nodes_trueleafs`: 1 if true branch is leaf for each node and 0 an interior node.
+    /// * `nodes_falseleafs`: 1 if false branch is leaf for each node and 0 an interior node.
+    /// * `leaf_targetids`: The index of the target that this leaf contributes to (this must be in range `[0, n_targets)`).
+    /// * `leaf_weights`: The weight for each leaf.
+    /// * `tree_roots`: Index into `nodes_*` for the root of each tree. The tree structure is derived from the branching of each node.
+    /// * `post_transform`: Indicates the transform to apply to the score. One of 'POST_TRANSFORM::NONE', 'POST_TRANSFORM::SOFTMAX', 'POST_TRANSFORM::LOGISTIC', 'POST_TRANSFORM::SOFTMAX_ZERO' or 'POST_TRANSFORM::PROBIT',
+    /// * `aggregate_function`: Defines how to aggregate leaf values within a target. One of 'AGGREGATE_FUNCTION::AVERAGE', 'AGGREGATE_FUNCTION::SUM', 'AGGREGATE_FUNCTION::MIN', 'AGGREGATE_FUNCTION::MAX'; defaults to 'AGGREGATE_FUNCTION::SUM'
+    /// * `nodes_hitrates`: Popularity of each node, used for performance and may be omitted.
+    /// * `nodes_missing_value_tracks_true`: For each node, define whether to follow the true branch (if attribute value is 1) or false branch (if attribute value is 0) in the presence of a NaN input feature. This attribute may be left undefined and the default value is false (0) for all nodes.
+    /// * `membership_values`: Members to test membership of for each set membership node. List all of the members to test against in the order that the 'BRANCH_MEMBER' mode appears in `node_modes`, delimited by `NaN`s. 
Will have the same number of sets of values as nodes with mode 'BRANCH_MEMBER'. This may be omitted if the node doesn't contain any 'BRANCH_MEMBER' nodes. + /// * `n_targets`: The total number of targets. + + /// + /// ## Returns + /// + /// * Output of shape [Batch Size, Number of targets] + /// + /// ## Type Constraints + /// + /// `T` must be fixed point + /// + /// ## Examples + /// + /// ```rust + /// use orion::numbers::FP16x16; + /// use orion::operators::tensor::{Tensor, TensorTrait, FP16x16Tensor, U32Tensor}; + /// use orion::operators::ml::{TreeEnsembleTrait,POST_TRANSFORM, AGGREGATE_FUNCTION, NODE_MODE}; + /// use orion::operators::matrix::{MutMatrix, MutMatrixImpl}; + /// use orion::numbers::NumberTrait; + /// + /// fn example_tree_ensemble_one_tree() -> MutMatrix:: { + /// let mut shape = ArrayTrait::::new(); + /// shape.append(3); + /// shape.append(2); + /// + /// let mut data = ArrayTrait::new(); + /// data.append(FP16x16 { mag: 78643, sign: false }); + /// data.append(FP16x16 { mag: 222822, sign: false }); + /// data.append(FP16x16 { mag: 7864, sign: true }); + /// data.append(FP16x16 { mag: 108789, sign: false }); + /// data.append(FP16x16 { mag: 271319, sign: false }); + /// data.append(FP16x16 { mag: 115998, sign: false }); + /// let mut X = TensorTrait::new(shape.span(), data.span()); + /// + /// let mut shape = ArrayTrait::::new(); + /// shape.append(4); + /// + /// let mut data = ArrayTrait::new(); + /// data.append(FP16x16 { mag: 342753, sign: false }); + /// data.append(FP16x16 { mag: 794296, sign: false }); + /// data.append(FP16x16 { mag: 801505, sign: true }); + /// data.append(FP16x16 { mag: 472514, sign: false }); + /// let leaf_weights = TensorTrait::new(shape.span(), data.span()); + /// + /// let mut shape = ArrayTrait::::new(); + /// shape.append(3); + /// + /// let mut data = ArrayTrait::new(); + /// data.append(FP16x16 { mag: 205783, sign: false }); + /// data.append(FP16x16 { mag: 78643, sign: false }); + /// data.append(FP16x16 
{ mag: 275251, sign: false }); + /// let nodes_splits = TensorTrait::new(shape.span(), data.span()); + /// + /// let membership_values = Option::None; + /// + /// let n_targets = 2; + /// let aggregate_function = AGGREGATE_FUNCTION::SUM; + /// let nodes_missing_value_tracks_true = Option::None; + /// let nodes_hitrates = Option::None; + /// let post_transform = POST_TRANSFORM::NONE; + /// + /// let tree_roots: Span = array![0].span(); + /// let nodes_modes: Span = array![MODE::LEQ, MODE::LEQ, MODE::LEQ].span(); + /// + /// let nodes_featureids: Span = array![0, 0, 0].span(); + /// let nodes_truenodeids: Span = array![1, 0, 1].span(); + /// let nodes_trueleafs: Span = array![0, 1, 1].span(); + /// let nodes_falsenodeids: Span = array![2, 2, 3].span(); + /// let nodes_falseleafs: Span = array![0, 1, 1].span(); + /// let leaf_targetids: Span = array![0, 1, 0, 1].span(); + /// + /// return TreeEnsembleTrait::predict( + /// @X, + /// nodes_splits, + /// nodes_featureids, + /// nodes_modes, + /// nodes_truenodeids, + /// nodes_falsenodeids, + /// nodes_trueleafs, + /// nodes_falseleafs, + /// leaf_targetids, + /// leaf_weights, + /// tree_roots, + /// post_transform, + /// aggregate_function, + /// nodes_hitrates, + /// nodes_missing_value_tracks_true, + /// membership_values, + /// n_targets + /// ); + /// } + /// + /// >>> [[ 5.23 0. ] + /// [ 5.23 0. ] + /// [ 0. 
12.12]] + /// ``` + /// + fn predict( + X: @Tensor, + nodes_splits: Tensor, + nodes_featureids: Span, + nodes_modes: Span, + nodes_truenodeids: Span, + nodes_falsenodeids: Span, + nodes_trueleafs: Span, + nodes_falseleafs: Span, + leaf_targetids: Span, + leaf_weights: Tensor, + tree_roots: Span, + post_transform: POST_TRANSFORM, + aggregate_function: AGGREGATE_FUNCTION, + nodes_hitrates: Option>, + nodes_missing_value_tracks_true: Option>, + membership_values: Option>, + n_targets: usize + ) -> MutMatrix::; +} + + +impl TreeEnsembleImpl< + T, + MAG, + +TensorTrait, + +NumberTrait, + +Copy, + +Drop, + +PartialOrd, + +PartialEq, + +Add, + +Div, + +Mul, + +Into, + +AddEq, +> of TreeEnsembleTrait { + fn predict( + X: @Tensor, + nodes_splits: Tensor, + nodes_featureids: Span, + nodes_modes: Span, + nodes_truenodeids: Span, + nodes_falsenodeids: Span, + nodes_trueleafs: Span, + nodes_falseleafs: Span, + leaf_targetids: Span, + leaf_weights: Tensor, + tree_roots: Span, + post_transform: POST_TRANSFORM, + aggregate_function: AGGREGATE_FUNCTION, + nodes_hitrates: Option>, + nodes_missing_value_tracks_true: Option>, + membership_values: Option>, + n_targets: usize + ) -> MutMatrix:: { + let batch_size = *(*X).shape.at(0); + let n_features = *(*X).shape.at(1); + let n_trees = tree_roots.len(); + + let mut set_membership_iter = array![].span(); + let mut map_member_to_nodeid = Default::default(); + + let mut res: MutMatrix = MutMatrixImpl::new(batch_size, n_targets); + + let (nodes_missing_value_tracks_true, nodes_missing_value_tracks_true_flag) = + match nodes_missing_value_tracks_true { + Option::Some(nodes_missing_value_tracks_true) => { + (nodes_missing_value_tracks_true, true) + }, + Option::None => { (array![].span(), false) } + }; + + match membership_values { + Option::Some(membership_values) => { + set_membership_iter = membership_values.data.clone(); + + let mut tree_roots_iter = tree_roots.clone(); + loop { + match tree_roots_iter.pop_front() { + 
Option::Some(root_index) => { + let root_index = *root_index; + let is_leaf = (*nodes_trueleafs.at(root_index) == 1 + && *nodes_falseleafs.at(root_index) == 1 + && *nodes_truenodeids + .at(root_index) == *nodes_falsenodeids + .at(root_index)); + map_members_to_nodeids( + root_index, + is_leaf, + nodes_modes, + nodes_truenodeids, + nodes_falsenodeids, + nodes_trueleafs, + nodes_falseleafs, + ref set_membership_iter, + ref map_member_to_nodeid, + ); + }, + Option::None => { break; } + } + }; + }, + Option::None => {} + }; + + match aggregate_function { + AGGREGATE_FUNCTION::AVERAGE => { res.set(batch_size, n_targets, NumberTrait::zero()); }, + AGGREGATE_FUNCTION::SUM => { res.set(batch_size, n_targets, NumberTrait::zero()); }, + AGGREGATE_FUNCTION::MIN => { + let mut i = 0; + while i != batch_size { + let mut j = 0; + while j != n_targets { + res.set(i, j, NumberTrait::min_value()); + j += 1; + }; + i += 1; + }; + }, + AGGREGATE_FUNCTION::MAX => { + let mut i = 0; + while i != batch_size { + let mut j = 0; + while j != n_targets { + res.set(i, j, NumberTrait::max_value()); + j += 1; + }; + i += 1; + }; + }, + } + + let mut weights = ArrayTrait::new(); + let mut target_ids = ArrayTrait::new(); + + let mut tree_roots_iter = tree_roots.clone(); + loop { + match tree_roots_iter.pop_front() { + Option::Some(root_index) => { + let root_index = *root_index; + let is_leaf = (*nodes_trueleafs.at(root_index) == 1 + && *nodes_falseleafs.at(root_index) == 1 + && *nodes_truenodeids.at(root_index) == *nodes_falsenodeids.at(root_index)); + + let mut batch_num = 0; + while batch_num != batch_size { + let x_batch = SpanTrait::slice( + (*X).data, batch_num * n_features, n_features + ); + + let (weight, target) = iterate_node( + x_batch, + root_index, + is_leaf, + nodes_splits.data, + nodes_featureids, + nodes_modes, + nodes_truenodeids, + nodes_falsenodeids, + nodes_trueleafs, + nodes_falseleafs, + leaf_targetids, + leaf_weights, + nodes_hitrates, + nodes_missing_value_tracks_true, + 
nodes_missing_value_tracks_true_flag, + ref map_member_to_nodeid, + ); + weights.append(weight); + target_ids.append(target); + batch_num += 1; + }; + }, + Option::None => { break; } + } + }; + + let weights = weights.span(); + let target_ids = target_ids.span(); + + let mut batch_num = 0; + while batch_num != batch_size { + match aggregate_function { + AGGREGATE_FUNCTION::AVERAGE => { + let mut i = 0; + while i != n_trees { + let index = i * batch_size + batch_num; + res + .set( + batch_num, + *target_ids.at(index), + res.at(batch_num, *target_ids.at(index)) + + *weights.at(index) + / NumberTrait::new_unscaled(n_trees.into(), false) + ); + i += 1; + }; + }, + AGGREGATE_FUNCTION::SUM => { + let mut i = 0; + while i != n_trees { + let index = i * batch_size + batch_num; + res + .set( + batch_num, + *target_ids.at(index), + res.at(batch_num, *target_ids.at(index)) + *weights.at(index) + ); + i += 1; + }; + }, + AGGREGATE_FUNCTION::MIN => { + let mut i = 0; + while i != n_targets { + let val = NumberTrait::min( + res.at(batch_num, *target_ids.at(batch_num)), *weights.at(batch_num) + ); + res.set(batch_num, *target_ids.at(batch_num), val); + i += 1; + }; + }, + AGGREGATE_FUNCTION::MAX => { + let mut i = 0; + while i != n_targets { + let val = NumberTrait::max( + res.at(batch_num, *target_ids.at(batch_num)), *weights.at(batch_num) + ); + res.set(batch_num, *target_ids.at(batch_num), val); + i += 1; + }; + } + } + + batch_num += 1; + }; + + // Post Transform + let mut res = match post_transform { + POST_TRANSFORM::NONE => res, + POST_TRANSFORM::SOFTMAX => res.softmax(1), + POST_TRANSFORM::LOGISTIC => res.sigmoid(), + POST_TRANSFORM::SOFTMAX_ZERO => res.softmax_zero(1), + POST_TRANSFORM::PROBIT => core::panic_with_felt252('Probit not supported yet'), + }; + + return res; + } +} +fn iterate_node< + T, + MAG, + +TensorTrait, + +NumberTrait, + +Copy, + +Drop, + +PartialOrd, + +PartialEq, +>( + X: Span, + current_node_index: usize, + is_leaf: bool, + nodes_splits: Span, + 
nodes_featureids: Span, + nodes_modes: Span, + nodes_truenodeids: Span, + nodes_falsenodeids: Span, + nodes_trueleafs: Span, + nodes_falseleafs: Span, + leaf_targetids: Span, + leaf_weights: Tensor, + nodes_hitrates: Option>, + nodes_missing_value_tracks_true: Span, + nodes_missing_value_tracks_true_flag: bool, + ref map_member_to_nodeid: Felt252Dict>>, +) -> (T, usize) { + let mut current_node_index = current_node_index; + let mut is_leaf = is_leaf; + + while !is_leaf { + let nmvtt_flag = if nodes_missing_value_tracks_true_flag { + *nodes_missing_value_tracks_true.at(current_node_index) == 1 + } else { + nodes_missing_value_tracks_true_flag + }; + if compare( + *X.at(*nodes_featureids.at(current_node_index)), + current_node_index, + *nodes_splits.at(current_node_index), + *nodes_modes.at(current_node_index), + ref map_member_to_nodeid, + nmvtt_flag + ) { + is_leaf = *nodes_trueleafs.at(current_node_index) == 1; + current_node_index = *nodes_truenodeids.at(current_node_index); + } else { + is_leaf = *nodes_falseleafs.at(current_node_index) == 1; + current_node_index = *nodes_falsenodeids.at(current_node_index); + }; + }; + + return (*leaf_weights.data.at(current_node_index), *leaf_targetids.at(current_node_index)); +} + +fn map_members_to_nodeids< + T, + MAG, + +TensorTrait, + +NumberTrait, + +Copy, + +Drop, + +PartialOrd, + +PartialEq, +>( + current_node_index: usize, + is_leaf: bool, + nodes_modes: Span, + nodes_truenodeids: Span, + nodes_falsenodeids: Span, + nodes_trueleafs: Span, + nodes_falseleafs: Span, + ref set_membership_iter: Span, + ref map_member_to_nodeid: Felt252Dict>>, +) { + let mut current_node_index = current_node_index; + let mut is_leaf = is_leaf; + + if is_leaf { + return; + } + + match *nodes_modes.at(current_node_index) { + NODE_MODE::LEQ => {}, + NODE_MODE::LT => {}, + NODE_MODE::GTE => {}, + NODE_MODE::GT => {}, + NODE_MODE::EQ => {}, + NODE_MODE::NEQ => {}, + NODE_MODE::MEMBER => { + let mut subset_members = ArrayTrait::new(); + loop { + 
match set_membership_iter.pop_front() { + Option::Some(v) => { + if *v == NumberTrait::NaN() { + break; + } + subset_members.append(*v) + }, + Option::None => { break; } + } + }; + map_member_to_nodeid + .insert(current_node_index.into(), NullableTrait::new(subset_members.span())); + }, + } + // true branch + map_members_to_nodeids( + *nodes_truenodeids.at(current_node_index), + *nodes_trueleafs.at(current_node_index) == 1, + nodes_modes, + nodes_truenodeids, + nodes_falsenodeids, + nodes_trueleafs, + nodes_falseleafs, + ref set_membership_iter, + ref map_member_to_nodeid, + ); + + // false branch + map_members_to_nodeids( + *nodes_falsenodeids.at(current_node_index), + *nodes_falseleafs.at(current_node_index) == 1, + nodes_modes, + nodes_truenodeids, + nodes_falsenodeids, + nodes_trueleafs, + nodes_falseleafs, + ref set_membership_iter, + ref map_member_to_nodeid, + ); +} + + +fn compare< + T, MAG, +TensorTrait, +NumberTrait, +Copy, +Drop, +PartialOrd, +PartialEq +>( + x_feat: T, + current_node_index: usize, + value: T, + mode: NODE_MODE, + ref map_members_to_nodeids: Felt252Dict>>, + nodes_missing_value_tracks_true_flag: bool, +) -> bool { + match mode { + NODE_MODE::LEQ => { + (x_feat <= value && !x_feat.is_nan()) || nodes_missing_value_tracks_true_flag + }, + NODE_MODE::LT => { + (x_feat < value && !x_feat.is_nan()) || nodes_missing_value_tracks_true_flag + }, + NODE_MODE::GTE => { + (x_feat >= value && !x_feat.is_nan()) || nodes_missing_value_tracks_true_flag + }, + NODE_MODE::GT => { + (x_feat > value && !x_feat.is_nan()) || nodes_missing_value_tracks_true_flag + }, + NODE_MODE::EQ => { + (x_feat == value && !x_feat.is_nan()) || nodes_missing_value_tracks_true_flag + }, + NODE_MODE::NEQ => { + (x_feat != value && !x_feat.is_nan()) || nodes_missing_value_tracks_true_flag + }, + NODE_MODE::MEMBER => { + let mut set_members = map_members_to_nodeids.get(current_node_index.into()).deref(); + loop { + match set_members.pop_front() { + Option::Some(v) => { if x_feat 
== *v { + break true; + } }, + Option::None => { break false; } + } + } + }, + } +} diff --git a/tests/lib.cairo b/tests/lib.cairo index f5cecb77d..c408347ef 100644 --- a/tests/lib.cairo +++ b/tests/lib.cairo @@ -5,3 +5,4 @@ mod nodes; mod ml; mod operators; + diff --git a/tests/ml.cairo b/tests/ml.cairo index 4e3e0781e..b92dbcd83 100644 --- a/tests/ml.cairo +++ b/tests/ml.cairo @@ -5,3 +5,4 @@ mod linear_classifier_test; mod svm_regressor_test; mod svm_classifier_test; mod normalizer_test; +mod tree_ensemble_test; diff --git a/tests/ml/tree_ensemble_test.cairo b/tests/ml/tree_ensemble_test.cairo new file mode 100644 index 000000000..59a5592f6 --- /dev/null +++ b/tests/ml/tree_ensemble_test.cairo @@ -0,0 +1,300 @@ +use orion::numbers::FP16x16; +use orion::operators::tensor::{Tensor, TensorTrait, FP16x16Tensor, U32Tensor}; +use orion::operators::ml::tree_ensemble::tree_ensemble::{ + TreeEnsembleTrait, POST_TRANSFORM, AGGREGATE_FUNCTION, NODE_MODE +}; +use orion::operators::matrix::{MutMatrix, MutMatrixImpl, MutMatrixTrait}; +use orion::numbers::NumberTrait; + + +#[test] +#[available_gas(200000000000)] +fn export_tree_ensemble_two_trees() { + let mut shape = ArrayTrait::::new(); + shape.append(3); + shape.append(3); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 32768, sign: true }); + data.append(FP16x16 { mag: 26214, sign: true }); + data.append(FP16x16 { mag: 19660, sign: true }); + data.append(FP16x16 { mag: 13107, sign: true }); + data.append(FP16x16 { mag: 6553, sign: true }); + data.append(FP16x16 { mag: 0, sign: false }); + data.append(FP16x16 { mag: 6553, sign: false }); + data.append(FP16x16 { mag: 13107, sign: false }); + data.append(FP16x16 { mag: 19660, sign: false }); + let mut X = TensorTrait::new(shape.span(), data.span()); + + let mut shape = ArrayTrait::::new(); + shape.append(6); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 5041, sign: false }); + data.append(FP16x16 { mag: 32768, sign: false }); + 
data.append(FP16x16 { mag: 32768, sign: false }); + data.append(FP16x16 { mag: 0, sign: false }); + data.append(FP16x16 { mag: 18724, sign: false }); + data.append(FP16x16 { mag: 32768, sign: false }); + let leaf_weights = TensorTrait::new(shape.span(), data.span()); + + let mut shape = ArrayTrait::::new(); + shape.append(4); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 17462, sign: false }); + data.append(FP16x16 { mag: 40726, sign: false }); + data.append(FP16x16 { mag: 36652, sign: true }); + data.append(FP16x16 { mag: 47240, sign: true }); + let nodes_splits = TensorTrait::new(shape.span(), data.span()); + + let n_targets = 1; + let aggregate_function = AGGREGATE_FUNCTION::AVERAGE; + let nodes_missing_value_tracks_true = Option::None; + let nodes_hitrates = Option::None; + let post_transform = POST_TRANSFORM::NONE; + + let tree_roots: Span = array![0, 2].span(); + let nodes_modes: Span = array![ + NODE_MODE::LEQ, NODE_MODE::LEQ, NODE_MODE::LEQ, NODE_MODE::LEQ + ] + .span(); + + let nodes_featureids: Span = array![0, 2, 0, 0].span(); + let nodes_truenodeids: Span = array![1, 0, 3, 4].span(); + let nodes_trueleafs: Span = array![0, 1, 1, 1].span(); + let nodes_falsenodeids: Span = array![2, 1, 3, 5].span(); + let nodes_falseleafs: Span = array![1, 1, 0, 1].span(); + let leaf_targetids: Span = array![0, 0, 0, 0, 0, 0].span(); + + let mut scores = TreeEnsembleTrait::predict( + @X, + nodes_splits, + nodes_featureids, + nodes_modes, + nodes_truenodeids, + nodes_falsenodeids, + nodes_trueleafs, + nodes_falseleafs, + leaf_targetids, + leaf_weights, + tree_roots, + post_transform, + aggregate_function, + nodes_hitrates, + nodes_missing_value_tracks_true, + Option::None, + n_targets + ); + + // ASSERT SCOREs + assert(scores.at(0, 0) == FP16x16 { mag: 18904, sign: false }, 'scores.at(0, 0)'); + assert(scores.at(1, 0) == FP16x16 { mag: 18904, sign: false }, 'scores.at(1, 0)'); + assert(scores.at(2, 0) == FP16x16 { mag: 18904, sign: false }, 
'scores.at(2, 0)'); +} + + +#[test] +#[available_gas(200000000000)] +fn export_tree_ensemble_one_tree() { + let mut shape = ArrayTrait::::new(); + shape.append(3); + shape.append(2); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 78643, sign: false }); + data.append(FP16x16 { mag: 222822, sign: false }); + data.append(FP16x16 { mag: 7864, sign: true }); + data.append(FP16x16 { mag: 108789, sign: false }); + data.append(FP16x16 { mag: 271319, sign: false }); + data.append(FP16x16 { mag: 115998, sign: false }); + let mut X = TensorTrait::new(shape.span(), data.span()); + + let mut shape = ArrayTrait::::new(); + shape.append(4); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 342753, sign: false }); + data.append(FP16x16 { mag: 794296, sign: false }); + data.append(FP16x16 { mag: 801505, sign: true }); + data.append(FP16x16 { mag: 472514, sign: false }); + let leaf_weights = TensorTrait::new(shape.span(), data.span()); + + let mut shape = ArrayTrait::::new(); + shape.append(3); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 205783, sign: false }); + data.append(FP16x16 { mag: 78643, sign: false }); + data.append(FP16x16 { mag: 275251, sign: false }); + let nodes_splits = TensorTrait::new(shape.span(), data.span()); + + let membership_values = Option::None; + + let n_targets = 2; + let aggregate_function = AGGREGATE_FUNCTION::SUM; + let nodes_missing_value_tracks_true = Option::None; + let nodes_hitrates = Option::None; + let post_transform = POST_TRANSFORM::NONE; + + let tree_roots: Span = array![0].span(); + let nodes_modes: Span = array![NODE_MODE::LEQ, NODE_MODE::LEQ, NODE_MODE::LEQ] + .span(); + + let nodes_featureids: Span = array![0, 0, 0].span(); + let nodes_truenodeids: Span = array![1, 0, 1].span(); + let nodes_trueleafs: Span = array![0, 1, 1].span(); + let nodes_falsenodeids: Span = array![2, 2, 3].span(); + let nodes_falseleafs: Span = array![0, 1, 1].span(); + let leaf_targetids: Span = array![0, 
1, 0, 1].span(); + + let mut scores = TreeEnsembleTrait::predict( + @X, + nodes_splits, + nodes_featureids, + nodes_modes, + nodes_truenodeids, + nodes_falsenodeids, + nodes_trueleafs, + nodes_falseleafs, + leaf_targetids, + leaf_weights, + tree_roots, + post_transform, + aggregate_function, + nodes_hitrates, + nodes_missing_value_tracks_true, + membership_values, + n_targets + ); + + // ASSERT SCOREs + assert(scores.at(0, 0) == FP16x16 { mag: 342753, sign: false }, 'scores.at(0, 0)'); + assert(scores.at(0, 1) == FP16x16 { mag: 0, sign: false }, 'scores.at(0, 1)'); + + assert(scores.at(1, 0) == FP16x16 { mag: 342753, sign: false }, 'scores.at(1, 0)'); + assert(scores.at(1, 1) == FP16x16 { mag: 0, sign: false }, 'scores.at(1, 1)'); + + assert(scores.at(2, 0) == FP16x16 { mag: 0, sign: false }, 'scores.at(2, 0)'); + assert(scores.at(2, 1) == FP16x16 { mag: 794296, sign: false }, 'scores.at(2, 1)'); +} + + +#[test] +#[available_gas(200000000000)] +fn export_tree_ensemble_set_membership() { + let mut shape = ArrayTrait::::new(); + shape.append(6); + shape.append(1); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 78643, sign: false }); + data.append(FP16x16 { mag: 222822, sign: false }); + data.append(FP16x16 { mag: 7864, sign: true }); + data.append(NumberTrait::::NaN()); + data.append(FP16x16 { mag: 786432, sign: false }); + data.append(FP16x16 { mag: 458752, sign: false }); + let mut X = TensorTrait::new(shape.span(), data.span()); + + let mut shape = ArrayTrait::::new(); + shape.append(4); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 655360, sign: false }); + data.append(FP16x16 { mag: 65536000, sign: false }); + data.append(FP16x16 { mag: 6553600, sign: false }); + let leaf_weights = TensorTrait::new(shape.span(), data.span()); + + let mut shape = ArrayTrait::::new(); + shape.append(3); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 720896, sign: false 
}); + data.append(FP16x16 { mag: 1522663424, sign: false }); + data.append(NumberTrait::::NaN()); + let nodes_splits = TensorTrait::new(shape.span(), data.span()); + + let mut shape = ArrayTrait::::new(); + shape.append(8); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 78643, sign: false }); + data.append(FP16x16 { mag: 242483, sign: false }); + data.append(FP16x16 { mag: 524288, sign: false }); + data.append(FP16x16 { mag: 589824, sign: false }); + data.append(NumberTrait::::NaN()); + data.append(FP16x16 { mag: 786432, sign: false }); + data.append(FP16x16 { mag: 458752, sign: false }); + data.append(NumberTrait::::NaN()); + let membership_values = Option::Some(TensorTrait::new(shape.span(), data.span())); + + let n_targets = 4; + let aggregate_function = AGGREGATE_FUNCTION::SUM; + let nodes_missing_value_tracks_true = Option::None; + let nodes_hitrates = Option::None; + let post_transform = POST_TRANSFORM::NONE; + + let tree_roots: Span = array![0].span(); + let nodes_modes: Span = array![NODE_MODE::LEQ, NODE_MODE::MEMBER, NODE_MODE::MEMBER] + .span(); + + let nodes_featureids: Span = array![0, 0, 0].span(); + let nodes_truenodeids: Span = array![1, 0, 1].span(); + let nodes_trueleafs: Span = array![0, 1, 1].span(); + let nodes_falsenodeids: Span = array![2, 2, 3].span(); + let nodes_falseleafs: Span = array![1, 0, 1].span(); + let leaf_targetids: Span = array![0, 1, 2, 3].span(); + + let mut scores = TreeEnsembleTrait::predict( + @X, + nodes_splits, + nodes_featureids, + nodes_modes, + nodes_truenodeids, + nodes_falsenodeids, + nodes_trueleafs, + nodes_falseleafs, + leaf_targetids, + leaf_weights, + tree_roots, + post_transform, + aggregate_function, + nodes_hitrates, + nodes_missing_value_tracks_true, + membership_values, + n_targets + ); + + // ASSERT SCOREs + assert(scores.at(0, 0) == FP16x16 { mag: 65536, sign: false }, 'scores.at(0, 0)'); + assert(scores.at(0, 1) == FP16x16 { mag: 0, sign: false }, 'scores.at(0, 1)'); + 
assert(scores.at(0, 2) == FP16x16 { mag: 0, sign: false }, 'scores.at(0, 2)'); + assert(scores.at(0, 3) == FP16x16 { mag: 0, sign: false }, 'scores.at(0, 3)'); + + assert(scores.at(1, 0) == FP16x16 { mag: 0, sign: false }, 'scores.at(1, 0)'); + assert(scores.at(1, 1) == FP16x16 { mag: 0, sign: false }, 'scores.at(1, 1)'); + assert(scores.at(1, 2) == FP16x16 { mag: 0, sign: false }, 'scores.at(1, 2)'); + assert(scores.at(1, 3) == FP16x16 { mag: 6553600, sign: false }, 'scores.at(1, 3)'); + + assert(scores.at(2, 0) == FP16x16 { mag: 0, sign: false }, 'scores.at(2, 0)'); + assert(scores.at(2, 1) == FP16x16 { mag: 0, sign: false }, 'scores.at(2, 1)'); + assert(scores.at(2, 2) == FP16x16 { mag: 0, sign: false }, 'scores.at(2, 2)'); + assert(scores.at(2, 3) == FP16x16 { mag: 6553600, sign: false }, 'scores.at(2, 3)'); + + assert(scores.at(3, 0) == FP16x16 { mag: 0, sign: false }, 'scores.at(3, 0)'); + assert(scores.at(3, 1) == FP16x16 { mag: 0, sign: false }, 'scores.at(3, 1)'); + assert(scores.at(3, 2) == FP16x16 { mag: 65536000, sign: false }, 'scores.at(3, 2)'); + assert(scores.at(3, 3) == FP16x16 { mag: 0, sign: false }, 'scores.at(3, 3)'); + + assert(scores.at(4, 0) == FP16x16 { mag: 0, sign: false }, 'scores.at(4, 0)'); + assert(scores.at(4, 1) == FP16x16 { mag: 0, sign: false }, 'scores.at(4, 1)'); + assert(scores.at(4, 2) == FP16x16 { mag: 65536000, sign: false }, 'scores.at(4, 2)'); + assert(scores.at(4, 3) == FP16x16 { mag: 0, sign: false }, 'scores.at(4, 3)'); + + assert(scores.at(5, 0) == FP16x16 { mag: 0, sign: false }, 'scores.at(5, 0)'); + assert(scores.at(5, 1) == FP16x16 { mag: 655360, sign: false }, 'scores.at(5, 1)'); + assert(scores.at(5, 2) == FP16x16 { mag: 0, sign: false }, 'scores.at(5, 2)'); + assert(scores.at(5, 3) == FP16x16 { mag: 0, sign: false }, 'scores.at(5, 3)'); +} + From 28b82f8cd673b5afc81d6696ee742c06108e7036 Mon Sep 17 00:00:00 2001 From: chachaleo Date: Mon, 1 Apr 2024 16:49:25 +0200 Subject: [PATCH 3/3] feat: roi_align --- 
docs/SUMMARY.md | 1 + docs/framework/compatibility.md | 1 + .../operators/neural-network/nn.roi_align.md | 126 ++ nodegen/node/roi_align.py | 496 ++++++++ src/operators/matrix.cairo | 319 ++--- src/operators/ml/svm/core.cairo | 9 +- src/operators/ml/svm/svm_classifier.cairo | 228 ++-- src/operators/ml/svm/svm_regressor.cairo | 63 +- src/operators/nn/core.cairo | 139 +++ src/operators/nn/functional.cairo | 1 + src/operators/nn/functional/col2im.cairo | 2 +- src/operators/nn/functional/conv.cairo | 1057 +++++++++-------- src/operators/nn/functional/grid_sample.cairo | 281 ++--- src/operators/nn/functional/roi_align.cairo | 450 +++++++ .../nn/implementations/nn_fp16x16.cairo | 26 + .../nn/implementations/nn_fp32x32.cairo | 26 + .../nn/implementations/nn_fp64x64.cairo | 26 + .../nn/implementations/nn_fp8x23.cairo | 24 + src/operators/nn/implementations/nn_i32.cairo | 16 + src/operators/nn/implementations/nn_i8.cairo | 16 + src/operators/nn/implementations/nn_u32.cairo | 16 + src/operators/tensor/core.cairo | 2 +- src/operators/tensor/helpers.cairo | 82 +- .../implementations/tensor_complex64.cairo | 5 +- .../implementations/tensor_fp16x16.cairo | 22 +- .../implementations/tensor_fp16x16wide.cairo | 23 +- .../implementations/tensor_fp32x32.cairo | 22 +- .../implementations/tensor_fp64x64.cairo | 22 +- .../implementations/tensor_fp8x23wide.cairo | 22 +- .../tensor/implementations/tensor_i32.cairo | 18 +- .../tensor/implementations/tensor_i8.cairo | 18 +- .../tensor/implementations/tensor_u32.cairo | 18 +- src/operators/tensor/linalg/transpose.cairo | 11 +- src/operators/tensor/manipulation/split.cairo | 135 ++- .../manipulation/split_to_sequence.cairo | 2 +- src/operators/tensor/math/cumsum.cairo | 146 +-- src/operators/tensor/math/gather_nd.cairo | 9 +- .../tensor/math/layer_normalization.cairo | 3 +- src/operators/tensor/math/less_equal.cairo | 7 +- src/operators/tensor/math/max.cairo | 45 +- src/operators/tensor/math/min.cairo | 45 +- 
src/operators/tensor/math/range.cairo | 11 +- src/operators/tensor/math/reduce_l1.cairo | 7 +- src/operators/tensor/math/resize.cairo | 296 ++--- .../tensor/quantization/qlinear_matmul.cairo | 28 +- tests/lib.cairo | 1 - tests/nodes.cairo | 3 + tests/nodes/gather_elements_axis1.cairo | 2 +- tests/nodes/gather_elements_axis2.cairo | 2 +- tests/nodes/gather_elements_default.cairo | 2 +- .../gather_elements_negative_indices.cairo | 2 +- tests/nodes/gather_fp16x16_3d_axis1.cairo | 2 +- tests/nodes/gather_fp16x16_3d_axis2.cairo | 2 +- tests/nodes/gather_fp16x16_3d_default.cairo | 2 +- tests/nodes/gather_negative_axis.cairo | 2 +- tests/nodes/gather_negative_indices.cairo | 2 +- tests/nodes/reshape_reduced_dims.cairo | 2 +- tests/nodes/reshape_reordered_all_dims.cairo | 2 +- tests/nodes/reshape_reordered_last_dims.cairo | 2 +- tests/nodes/roi_align_aligned_false.cairo | 35 + .../roi_align_aligned_false/input_0.cairo | 115 ++ .../roi_align_aligned_false/input_1.cairo | 25 + .../roi_align_aligned_false/output_0.cairo | 90 ++ tests/nodes/roi_align_aligned_true.cairo | 36 + .../roi_align_aligned_true/input_0.cairo | 115 ++ .../roi_align_aligned_true/input_1.cairo | 25 + .../roi_align_aligned_true/output_0.cairo | 90 ++ tests/nodes/roi_align_mode_max.cairo | 36 + tests/nodes/roi_align_mode_max/input_0.cairo | 115 ++ tests/nodes/roi_align_mode_max/input_1.cairo | 25 + tests/nodes/roi_align_mode_max/output_0.cairo | 90 ++ 71 files changed, 3743 insertions(+), 1404 deletions(-) create mode 100644 docs/framework/operators/neural-network/nn.roi_align.md create mode 100644 nodegen/node/roi_align.py create mode 100644 src/operators/nn/functional/roi_align.cairo create mode 100644 tests/nodes/roi_align_aligned_false.cairo create mode 100644 tests/nodes/roi_align_aligned_false/input_0.cairo create mode 100644 tests/nodes/roi_align_aligned_false/input_1.cairo create mode 100644 tests/nodes/roi_align_aligned_false/output_0.cairo create mode 100644 
tests/nodes/roi_align_aligned_true.cairo create mode 100644 tests/nodes/roi_align_aligned_true/input_0.cairo create mode 100644 tests/nodes/roi_align_aligned_true/input_1.cairo create mode 100644 tests/nodes/roi_align_aligned_true/output_0.cairo create mode 100644 tests/nodes/roi_align_mode_max.cairo create mode 100644 tests/nodes/roi_align_mode_max/input_0.cairo create mode 100644 tests/nodes/roi_align_mode_max/input_1.cairo create mode 100644 tests/nodes/roi_align_mode_max/output_0.cairo diff --git a/docs/SUMMARY.md b/docs/SUMMARY.md index 477601b37..60cb1807c 100644 --- a/docs/SUMMARY.md +++ b/docs/SUMMARY.md @@ -178,6 +178,7 @@ * [nn.conv](framework/operators/neural-network/nn.conv.md) * [nn.depth_to_space](framework/operators/neural-network/nn.depth_to_space.md) * [nn.space_to_depth](framework/operators/neural-network/nn.space_to_depth.md) + * [nn.roi\_align](framework/operators/neural-network/nn.roi\_align.md) * [Machine Learning](framework/operators/machine-learning/README.md) * [Tree Ensemble Classifier](framework/operators/machine-learning/tree-ensemble-classifier/README.md) * [tree\_ensemble\_classifier.predict](framework/operators/machine-learning/tree-ensemble-classifier/tree\_ensemble\_classifier.predict.md) diff --git a/docs/framework/compatibility.md b/docs/framework/compatibility.md index f3f84ac3f..fd8faf0bf 100644 --- a/docs/framework/compatibility.md +++ b/docs/framework/compatibility.md @@ -47,6 +47,7 @@ You can see below the list of current supported ONNX Operators: | [Col2im](operators/neural-network/nn.col2im\_sigmoid.md) | :white\_check\_mark: | | [ConvTranspose](operators/neural-network/nn.conv\_transpose_.md) | :white\_check\_mark: | | [Conv](operators/neural-network/nn.conv.md) | :white\_check\_mark: | +| [RoiAlign](operators/neural-network/nn.roi\_align.md) | :white\_check\_mark: | | [Sinh](operators/tensor/tensor.sinh.md) | :white\_check\_mark: | | [Asinh](operators/tensor/tensor.asinh.md) | :white\_check\_mark: | | 
[Atanh](operators/tensor/tensor.atanh.md) | :white\_check\_mark: | diff --git a/docs/framework/operators/neural-network/nn.roi_align.md b/docs/framework/operators/neural-network/nn.roi_align.md new file mode 100644 index 000000000..2021f2317 --- /dev/null +++ b/docs/framework/operators/neural-network/nn.roi_align.md @@ -0,0 +1,126 @@ +# NNTrait::roi_align + +```rust + fn roi_align( + X: @Tensor, + roi: @Tensor, + batch_indices: @Tensor, + coordinate_transformation_mode: Option< + orion::operators::nn::functional::roi_align::TRANSFORMATION_MODE + >, + mode: Option, + output_height: Option, + output_width: Option, + sampling_ratio: Option, + spatial_scale: Option, +) -> Tensor; +``` + + RoiAlign consumes an input tensor X and region of interests (rois) to apply pooling across each RoI; it produces a 4-D tensor of shape (num_rois, C, output_height, output_width). + +## Args + +* `X`(`@Tensor`) - Input data tensor from the previous operator; 4-D feature map of shape (N, C, H, W), where N is the batch size, C is the number of channels, and H and W are the height and the width of the data. +* `rois`(`@Tensor`) - RoIs (Regions of Interest) to pool over; rois is 2-D input of shape (num_rois, 4) given as [[x1, y1, x2, y2], ...]. +* `batch_indices`(`@Tensor`) - 1-D tensor of shape (num_rois,) with each element denoting the index of the corresponding image in the batch. +* `coordinate_transformation_mode`(`Option`) - Allowed values are 'half_pixel' and 'output_half_pixel'. Use the value 'half_pixel' to pixel shift the input coordinates by -0.5 (default behavior). Use the value 'output_half_pixel' to omit the pixel shift for the input +* `mode`(`Option`) -The pooling method. Two modes are supported: 'avg' and 'max'. Default is 'avg'. +* `output_height`(`Option`) - default 1; Pooled output Y's height. +* `output_width`(`Option`) - default 1; Pooled output Y's width. 
+* `sampling_ratio`(`Option`) - Number of sampling points in the interpolation grid used to compute the output value of each pooled output bin. If > 0, then exactly sampling_ratio x sampling_ratio grid points are used. If == 0, then an adaptive number of grid points are used (computed as ceil(roi_width / output_width), and likewise for height). Default is 0. +* `spatial_scale`(`Option`) - Multiplicative spatial scale factor to translate ROI coordinates from their input spatial scale to the scale used when pooling, i.e., spatial scale of the input feature map X relative to the input image. Default is 1.0. + +## Returns + +A `Tensor` RoI pooled output, 4-D tensor of shape (num_rois, C, output_height, output_width). The r-th batch element Y[r-1] is a pooled feature map corresponding to the r-th RoI X[r-1]. + +## Example + +```rust +use orion::operators::nn::NNTrait; +use orion::numbers::FixedTrait; +use orion::operators::nn::FP16x16NN; +use orion::numbers::FP16x16; +use orion::operators::tensor::{Tensor, TensorTrait, FP16x16Tensor}; +use orion::operators::nn::functional::roi_align::TRANSFORMATION_MODE; + +fn example_roi_align() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(1); + shape.append(1); + shape.append(5); + shape.append(5); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 18114, sign: false }); + data.append(FP16x16 { mag: 46858, sign: false }); + data.append(FP16x16 { mag: 12831, sign: false }); + data.append(FP16x16 { mag: 22387, sign: false }); + data.append(FP16x16 { mag: 30395, sign: false }); + data.append(FP16x16 { mag: 63157, sign: false }); + data.append(FP16x16 { mag: 5865, sign: false }); + data.append(FP16x16 { mag: 19129, sign: false }); + data.append(FP16x16 { mag: 44256, sign: false }); + data.append(FP16x16 { mag: 1533, sign: false }); + data.append(FP16x16 { mag: 21397, sign: false }); + data.append(FP16x16 { mag: 55567, sign: false }); + data.append(FP16x16 { mag: 63556, sign: false }); + 
data.append(FP16x16 { mag: 16193, sign: false }); + data.append(FP16x16 { mag: 61184, sign: false }); + data.append(FP16x16 { mag: 1350, sign: false }); + data.append(FP16x16 { mag: 11272, sign: false }); + data.append(FP16x16 { mag: 14123, sign: false }); + data.append(FP16x16 { mag: 28796, sign: false }); + data.append(FP16x16 { mag: 4279, sign: false }); + data.append(FP16x16 { mag: 26620, sign: false }); + data.append(FP16x16 { mag: 14378, sign: false }); + data.append(FP16x16 { mag: 29314, sign: false }); + data.append(FP16x16 { mag: 30716, sign: false }); + data.append(FP16x16 { mag: 46589, sign: false }); + let mut X = TensorTrait::new(shape.span(), data.span()); + + let batch_indices = TensorTrait::new(array![3].span(), array![0, 0, 0].span()); + + let mut shape = ArrayTrait::::new(); + shape.append(3); + shape.append(4); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 0, sign: false }); + data.append(FP16x16 { mag: 0, sign: false }); + data.append(FP16x16 { mag: 589824, sign: false }); + data.append(FP16x16 { mag: 589824, sign: false }); + data.append(FP16x16 { mag: 0, sign: false }); + data.append(FP16x16 { mag: 327680, sign: false }); + data.append(FP16x16 { mag: 262144, sign: false }); + data.append(FP16x16 { mag: 589824, sign: false }); + data.append(FP16x16 { mag: 327680, sign: false }); + data.append(FP16x16 { mag: 327680, sign: false }); + data.append(FP16x16 { mag: 589824, sign: false }); + data.append(FP16x16 { mag: 589824, sign: false }); + let rois = TensorTrait::new(shape.span(), data.span()); + + return roi_align( + @X, + @rois, + @batch_indices, + Option::Some(TRANSFORMATION_MODE::OUTPUT_HALF_PIXEL), + Option::None, + Option::Some(2), + Option::Some(2), + Option::Some(FP16x16 { mag: 65536, sign: false }), + Option::Some(FP16x16 { mag: 32768, sign: false }), + ); +} +>>> [[[[0.2083422 , 0.44005 ], + [0.20385626, 0.39676717]]], + + + [[[0.09630001, 0.19375 ], + [0.3128 , 0.33335 ]]], + + + [[[0.4394 , 0.0653 ], + [0.4687 , 
0.7109 ]]]] + +```` \ No newline at end of file diff --git a/nodegen/node/roi_align.py b/nodegen/node/roi_align.py new file mode 100644 index 000000000..a9f070bad --- /dev/null +++ b/nodegen/node/roi_align.py @@ -0,0 +1,496 @@ +import numpy as np +from nodegen.node import RunAll +from ..helpers import make_test, to_fp, Tensor, Dtype, FixedImpl, Trait +from typing import Tuple + +import numpy as np + + +class PreCalc: + def __init__(self, pos1=0, pos2=0, pos3=0, pos4=0, w1=0, w2=0, w3=0, w4=0): # type: ignore + self.pos1 = pos1 + self.pos2 = pos2 + self.pos3 = pos3 + self.pos4 = pos4 + self.w1 = w1 + self.w2 = w2 + self.w3 = w3 + self.w4 = w4 + + def __repr__(self) -> str: + return f"PreCalc({self.pos1},{self.pos2},{self.pos3},{self.pos4},{self.w1},{self.w2},{self.w3},{self.w4})" + + + +def pre_calc_for_bilinear_interpolate( # type: ignore + height: int, + width: int, + pooled_height: int, + pooled_width: int, + iy_upper: int, + ix_upper: int, + roi_start_h, + roi_start_w, + bin_size_h, + bin_size_w, + roi_bin_grid_h: int, + roi_bin_grid_w: int, + pre_calc, +): + pre_calc_index = 0 + for ph in range(pooled_height): + for pw in range(pooled_width): + for iy in range(iy_upper): + yy = ( + roi_start_h + + ph * bin_size_h + + (iy + 0.5) * bin_size_h / roi_bin_grid_h + ) + for ix in range(ix_upper): + xx = ( + roi_start_w + + pw * bin_size_w + + (ix + 0.5) * bin_size_w / roi_bin_grid_w + ) + x = xx + y = yy + # deal with: inverse elements are out of feature map boundary + if y < -1.0 or y > height or x < -1.0 or x > width: + pc = pre_calc[pre_calc_index] + pc.pos1 = 0 + pc.pos2 = 0 + pc.pos3 = 0 + pc.pos4 = 0 + pc.w1 = 0 + pc.w2 = 0 + pc.w3 = 0 + pc.w4 = 0 + pre_calc_index += 1 + continue + y = max(y, 0) + x = max(x, 0) + y_low = int(y) + x_low = int(x) + if y_low >= height - 1: + y_high = y_low = height - 1 + y = y_low + else: + y_high = y_low + 1 + if x_low >= width - 1: + x_high = x_low = width - 1 + x = x_low + else: + x_high = x_low + 1 + ly = y - y_low + lx = x - 
x_low + hy = 1.0 - ly + hx = 1.0 - lx + w1 = hy * hx + w2 = hy * lx + w3 = ly * hx + w4 = ly * lx + # save weights and indices + pc = PreCalc() + pc.pos1 = y_low * width + x_low + pc.pos2 = y_low * width + x_high + pc.pos3 = y_high * width + x_low + pc.pos4 = y_high * width + x_high + pc.w1 = w1 + pc.w2 = w2 + pc.w3 = w3 + pc.w4 = w4 + pre_calc[pre_calc_index] = pc + pre_calc_index += 1 + + +def roi_align_forward( # type: ignore + output_shape: Tuple[int, int, int, int], + bottom_data, + spatial_scale, + height: int, + width: int, + sampling_ratio, + bottom_rois, + num_roi_cols: int, + top_data, + mode, + half_pixel: bool, + batch_indices_ptr, +): + n_rois = output_shape[0] + channels = output_shape[1] + pooled_height = output_shape[2] + pooled_width = output_shape[3] + # 100 is a randomly chosen value; may need tuning + for n in range(n_rois): + index_n = n * channels * pooled_width * pooled_height + # bottom_rois + offset_bottom_rois = n * num_roi_cols + roi_batch_ind = batch_indices_ptr[n] + # Do not use rounding; this implementation detail is critical. 
+ offset = 0.5 if half_pixel else 0.0 + roi_start_w = bottom_rois[offset_bottom_rois + 0] * spatial_scale - offset + roi_start_h = bottom_rois[offset_bottom_rois + 1] * spatial_scale - offset + roi_end_w = bottom_rois[offset_bottom_rois + 2] * spatial_scale - offset + roi_end_h = bottom_rois[offset_bottom_rois + 3] * spatial_scale - offset + roi_width = roi_end_w - roi_start_w + roi_height = roi_end_h - roi_start_h + if not half_pixel: + # Force malformed ROIs to be 1x1 + roi_width = max(roi_width, 1.0) + roi_height = max(roi_height, 1.0) + bin_size_h = roi_height / pooled_height + bin_size_w = roi_width / pooled_width + # We use roi_bin_grid to sample the grid and mimic integral + roi_bin_grid_h = ( + int(sampling_ratio) + if sampling_ratio > 0 + else int(np.ceil(roi_height / pooled_height)) + ) + roi_bin_grid_w = ( + int(sampling_ratio) + if sampling_ratio > 0 + else int(np.ceil(roi_width / pooled_width)) + ) + # We do average (integral) pooling inside a bin + count = int(max(roi_bin_grid_h * roi_bin_grid_w, 1)) + # we want to precalculate indices and weights shared by all channels, + # this is the key point of optimization + pre_calc = [ + PreCalc() + for i in range( + roi_bin_grid_h * roi_bin_grid_w * pooled_width * pooled_height + ) + ] + pre_calc_for_bilinear_interpolate( + height, + width, + pooled_height, + pooled_width, + roi_bin_grid_h, + roi_bin_grid_w, + roi_start_h, + roi_start_w, + bin_size_h, + bin_size_w, + roi_bin_grid_h, + roi_bin_grid_w, + pre_calc, + ) + for c in range(channels): + index_n_c = index_n + c * pooled_width * pooled_height + # bottom_data + offset_bottom_data = int( + (roi_batch_ind * channels + c) * height * width + ) + + pre_calc_index = 0 + for ph in range(pooled_height): + for pw in range(pooled_width): + index = index_n_c + ph * pooled_width + pw + output_val = 0.0 + if mode == "avg": # avg pooling + for _iy in range(roi_bin_grid_h): + for _ix in range(roi_bin_grid_w): + pc = pre_calc[pre_calc_index] + output_val += ( + pc.w1 + 
* bottom_data[offset_bottom_data + pc.pos1] + + pc.w2 + * bottom_data[offset_bottom_data + pc.pos2] + + pc.w3 + * bottom_data[offset_bottom_data + pc.pos3] + + pc.w4 + * bottom_data[offset_bottom_data + pc.pos4] + ) + pre_calc_index += 1 + output_val /= count + else: # max pooling + max_flag = False + for _iy in range(roi_bin_grid_h): + for _ix in range(roi_bin_grid_w): + pc = pre_calc[pre_calc_index] + val = max( + pc.w1 + * bottom_data[offset_bottom_data + pc.pos1], + pc.w2 + * bottom_data[offset_bottom_data + pc.pos2], + pc.w3 + * bottom_data[offset_bottom_data + pc.pos3], + pc.w4 + * bottom_data[offset_bottom_data + pc.pos4], + ) + if not max_flag: + output_val = val + max_flag = True + else: + output_val = max(output_val, val) + pre_calc_index += 1 + top_data[index] = output_val + +def roi_align( + X, + rois, + batch_indices, + coordinate_transformation_mode=None, + mode=None, + output_height=None, + output_width=None, + sampling_ratio=None, + spatial_scale=None, +): + num_channels = X.shape[1] + num_rois = batch_indices.shape[0] + num_roi_cols = rois.shape[1] + y_dims = (num_rois, num_channels, output_height, output_width) + Y = np.empty(y_dims, dtype=X.dtype).flatten() + roi_align_forward( + y_dims, + X.flatten(), + spatial_scale, + X.shape[2], # height, 3 + X.shape[3], # width, 4 + sampling_ratio, + rois.flatten(), + num_roi_cols, + Y, + mode.lower(), + coordinate_transformation_mode.lower() == "half_pixel", + batch_indices.flatten(), + ) + return (Y.reshape(y_dims).astype(X.dtype),) + +def get_roi_align_input_values(): # type: ignore + X = np.array( + [ + [ + [ + [ + 0.2764, + 0.7150, + 0.1958, + 0.3416, + 0.4638, + 0.0259, + 0.2963, + 0.6518, + 0.4856, + 0.7250, + ], + [ + 0.9637, + 0.0895, + 0.2919, + 0.6753, + 0.0234, + 0.6132, + 0.8085, + 0.5324, + 0.8992, + 0.4467, + ], + [ + 0.3265, + 0.8479, + 0.9698, + 0.2471, + 0.9336, + 0.1878, + 0.4766, + 0.4308, + 0.3400, + 0.2162, + ], + [ + 0.0206, + 0.1720, + 0.2155, + 0.4394, + 0.0653, + 0.3406, + 0.7724, + 
0.3921, + 0.2541, + 0.5799, + ], + [ + 0.4062, + 0.2194, + 0.4473, + 0.4687, + 0.7109, + 0.9327, + 0.9815, + 0.6320, + 0.1728, + 0.6119, + ], + [ + 0.3097, + 0.1283, + 0.4984, + 0.5068, + 0.4279, + 0.0173, + 0.4388, + 0.0430, + 0.4671, + 0.7119, + ], + [ + 0.1011, + 0.8477, + 0.4726, + 0.1777, + 0.9923, + 0.4042, + 0.1869, + 0.7795, + 0.9946, + 0.9689, + ], + [ + 0.1366, + 0.3671, + 0.7011, + 0.6234, + 0.9867, + 0.5585, + 0.6985, + 0.5609, + 0.8788, + 0.9928, + ], + [ + 0.5697, + 0.8511, + 0.6711, + 0.9406, + 0.8751, + 0.7496, + 0.1650, + 0.1049, + 0.1559, + 0.2514, + ], + [ + 0.7012, + 0.4056, + 0.7879, + 0.3461, + 0.0415, + 0.2998, + 0.5094, + 0.3727, + 0.5482, + 0.0502, + ], + ] + ] + ], + dtype=np.float32, + ) + batch_indices = np.array([0, 0, 0], dtype=np.int64) + rois = np.array([[0, 0, 9, 9], [0, 5, 4, 9], [5, 5, 9, 9]], dtype=np.float32) + return X, batch_indices, rois + + +class Roi_align(RunAll): + + @staticmethod + def export_roialign_aligned_false() -> None: + + + + x, batch_indices, rois = get_roi_align_input_values() + y = roi_align(X=x, rois=rois, batch_indices=batch_indices, spatial_scale=1.0, + output_height=5, + output_width=5, + sampling_ratio=2, + coordinate_transformation_mode="output_half_pixel",mode ='avg') + + y = np.array(y) + + x = Tensor(Dtype.FP16x16, x.shape, to_fp(x.flatten(), FixedImpl.FP16x16)) + y = Tensor(Dtype.FP16x16, y.shape, to_fp(y.flatten(), FixedImpl.FP16x16)) + rois = Tensor(Dtype.FP16x16, rois.shape, to_fp(rois.flatten(), FixedImpl.FP16x16)) + + name = "roi_align_aligned_false" + func_sig = "NNTrait::roi_align(" + func_sig += "@input_0," + func_sig += "@input_1," + func_sig += "@TensorTrait::new(array![3].span(), array![0, 0, 0].span())," + func_sig += "Option::Some(TRANSFORMATION_MODE::OUTPUT_HALF_PIXEL)," + func_sig += "Option::None," + func_sig += "Option::Some(5)," + func_sig += "Option::Some(5)," + func_sig += "Option::Some(FP16x16 { mag: 131072, sign: false })," + func_sig += "Option::Some(FP16x16 { mag: 65536, sign: 
false }))" + make_test( + [x, rois], y, func_sig, name, Trait.NN) + + @staticmethod + def export_roialign_aligned_true() -> None: + + x, batch_indices, rois = get_roi_align_input_values() + # (num_rois, C, output_height, output_width) + y = roi_align(X=x, rois=rois, batch_indices=batch_indices, spatial_scale=1.0, + output_height=5, + output_width=5, + sampling_ratio=2, + coordinate_transformation_mode="half_pixel",mode ='avg') + + + y = np.array(y) + + x = Tensor(Dtype.FP16x16, x.shape, to_fp(x.flatten(), FixedImpl.FP16x16)) + y = Tensor(Dtype.FP16x16, y.shape, to_fp(y.flatten(), FixedImpl.FP16x16)) + rois = Tensor(Dtype.FP16x16, rois.shape, to_fp(rois.flatten(), FixedImpl.FP16x16)) + + + name = "roi_align_aligned_true" + func_sig = "NNTrait::roi_align(" + func_sig += "@input_0," + func_sig += "@input_1," + func_sig += "@TensorTrait::new(array![3].span(), array![0, 0, 0].span())," + func_sig += "Option::Some(TRANSFORMATION_MODE::HALF_PIXEL)," + func_sig += "Option::None," + func_sig += "Option::Some(5)," + func_sig += "Option::Some(5)," + func_sig += "Option::Some(FP16x16 { mag: 131072, sign: false })," + func_sig += "Option::Some(FP16x16 { mag: 65536, sign: false }))" + make_test( + [x, rois], y, func_sig, name, Trait.NN) + + @staticmethod + def export_roialign_mode_max() -> None: + x, batch_indices, rois = get_roi_align_input_values() + y = roi_align(X=x, rois=rois, batch_indices=batch_indices, spatial_scale=1.0, + output_height=5, + output_width=5, + sampling_ratio=2, + coordinate_transformation_mode="output_half_pixel",mode ='max') + + y = np.array(y) + + x = Tensor(Dtype.FP16x16, x.shape, to_fp(x.flatten(), FixedImpl.FP16x16)) + y = Tensor(Dtype.FP16x16, y.shape, to_fp(y.flatten(), FixedImpl.FP16x16)) + rois = Tensor(Dtype.FP16x16, rois.shape, to_fp(rois.flatten(), FixedImpl.FP16x16)) + + + name = "roi_align_mode_max" + func_sig = "NNTrait::roi_align(" + func_sig += "@input_0," + func_sig += "@input_1," + func_sig += "@TensorTrait::new(array![3].span(), 
array![0, 0, 0].span())," + func_sig += "Option::Some(TRANSFORMATION_MODE::OUTPUT_HALF_PIXEL)," + func_sig += "Option::Some(MODE::MAX)," + func_sig += "Option::Some(5)," + func_sig += "Option::Some(5)," + func_sig += "Option::Some(FP16x16 { mag: 131072, sign: false })," + func_sig += "Option::Some(FP16x16 { mag: 65536, sign: false }))" + make_test( + [x, rois], y, func_sig, name, Trait.NN) + + + diff --git a/src/operators/matrix.cairo b/src/operators/matrix.cairo index 5e7564d11..efdee2e3a 100644 --- a/src/operators/matrix.cairo +++ b/src/operators/matrix.cairo @@ -90,66 +90,70 @@ impl MutMatrixImpl< if axis == 0 { let mut col: usize = 0; - while col != self.cols { - let mut max_value = self.get(0, col); - let mut max_value = match max_value { - Option::Some => { max_value.unwrap() }, - Option::None => { NumberTrait::min_value() } - }; - let mut max_index = 0; - - let mut row: usize = 1; - while row != self.rows { - let mut value = self.get(row, col); - let mut value = match value { - Option::Some => { value.unwrap() }, + while col != self + .cols { + let mut max_value = self.get(0, col); + let mut max_value = match max_value { + Option::Some => { max_value.unwrap() }, Option::None => { NumberTrait::min_value() } }; - - if value > max_value { - max_value = value; - max_index = row; - } - - row += 1; + let mut max_index = 0; + + let mut row: usize = 1; + while row != self + .rows { + let mut value = self.get(row, col); + let mut value = match value { + Option::Some => { value.unwrap() }, + Option::None => { NumberTrait::min_value() } + }; + + if value > max_value { + max_value = value; + max_index = row; + } + + row += 1; + }; + + result.append(max_index); + col += 1; }; - result.append(max_index); - col += 1; - }; - return result.span(); } let mut row: usize = 0; - while row != self.rows { - let mut max_value = self.get(row, 0); - let mut max_value = match max_value { - Option::Some => { max_value.unwrap() }, - Option::None => { NumberTrait::min_value() } - }; - 
let mut max_index = 0; - - let mut col: usize = 1; - while col != self.cols { - let mut value = self.get(row, col); - let mut value = match value { - Option::Some => { value.unwrap() }, + while row != self + .rows { + let mut max_value = self.get(row, 0); + let mut max_value = match max_value { + Option::Some => { max_value.unwrap() }, Option::None => { NumberTrait::min_value() } }; + let mut max_index = 0; - if value > max_value { - max_value = value; - max_index = col; - } + let mut col: usize = 1; + while col != self + .cols { + let mut value = self.get(row, col); + let mut value = match value { + Option::Some => { value.unwrap() }, + Option::None => { NumberTrait::min_value() } + }; + + if value > max_value { + max_value = value; + max_index = col; + } + + col += 1; + }; - col += 1; + result.append(max_index); + row += 1; }; - result.append(max_index); - row += 1; - }; - result.span() } @@ -161,50 +165,56 @@ impl MutMatrixImpl< if axis == 0 { let mut col: usize = 0; - while col != self.cols { - let mut sum_exp = NumberTrait::zero(); - let mut row: usize = 0; - while row != self.rows { - let value = self.get(row, col).unwrap().into(); - sum_exp += value.exp(); - - row += 1; - }; - - row = 0; - while row != self.rows { - let value = self.get(row, col).unwrap().into(); - let softmax_value = (value.exp() / sum_exp).into(); - result.set(row, col, softmax_value); + while col != self + .cols { + let mut sum_exp = NumberTrait::zero(); + let mut row: usize = 0; + while row != self + .rows { + let value = self.get(row, col).unwrap().into(); + sum_exp += value.exp(); + + row += 1; + }; + + row = 0; + while row != self + .rows { + let value = self.get(row, col).unwrap().into(); + let softmax_value = (value.exp() / sum_exp).into(); + result.set(row, col, softmax_value); + + row += 1; + }; - row += 1; + col += 1; }; - - col += 1; - }; } else { let mut row: usize = 0; - while row != self.rows { - let mut sum_exp = NumberTrait::zero(); - let mut col: usize = 0; - while col != 
self.cols { - let value = self.get(row, col).unwrap().into(); - sum_exp += value.exp(); - - col += 1; - }; + while row != self + .rows { + let mut sum_exp = NumberTrait::zero(); + let mut col: usize = 0; + while col != self + .cols { + let value = self.get(row, col).unwrap().into(); + sum_exp += value.exp(); + + col += 1; + }; + + col = 0; + while col != self + .cols { + let value = self.get(row, col).unwrap().into(); + let softmax_value = (value.exp() / sum_exp).into(); + result.set(row, col, softmax_value); + + col += 1; + }; - col = 0; - while col != self.cols { - let value = self.get(row, col).unwrap().into(); - let softmax_value = (value.exp() / sum_exp).into(); - result.set(row, col, softmax_value); - - col += 1; + row += 1; }; - - row += 1; - }; } result @@ -220,65 +230,71 @@ impl MutMatrixImpl< if axis == 0 { let mut col: usize = 0; - while col != self.cols { - let mut sum_exp = NumberTrait::zero(); - let mut row: usize = 0; - while row != self.rows { - let value = self.get(row, col).unwrap().into(); - - if value != NumberTrait::zero() { - sum_exp += value.exp(); - } + while col != self + .cols { + let mut sum_exp = NumberTrait::zero(); + let mut row: usize = 0; + while row != self + .rows { + let value = self.get(row, col).unwrap().into(); + + if value != NumberTrait::zero() { + sum_exp += value.exp(); + } + + row += 1; + }; + + row = 0; + while row != self + .rows { + let value = self.get(row, col).unwrap().into(); + + if value != NumberTrait::zero() { + let softmax_value = (value.exp() / sum_exp).into(); + result.set(row, col, softmax_value); + } else { + result.set(row, col, NumberTrait::zero()); + } + + row += 1; + }; - row += 1; - }; - - row = 0; - while row != self.rows { - let value = self.get(row, col).unwrap().into(); - - if value != NumberTrait::zero() { - let softmax_value = (value.exp() / sum_exp).into(); - result.set(row, col, softmax_value); - } else { - result.set(row, col, NumberTrait::zero()); - } - - row += 1; + col += 1; }; - - col += 1; 
- }; } else { let mut row: usize = 0; - while row != self.rows { - let mut sum_exp = NumberTrait::zero(); - let mut col: usize = 0; - while col != self.cols { - let value = self.get(row, col).unwrap().into(); - if value != NumberTrait::zero() { - sum_exp += value.exp(); - } + while row != self + .rows { + let mut sum_exp = NumberTrait::zero(); + let mut col: usize = 0; + while col != self + .cols { + let value = self.get(row, col).unwrap().into(); + if value != NumberTrait::zero() { + sum_exp += value.exp(); + } + + col += 1; + }; + + col = 0; + while col != self + .cols { + let value = self.get(row, col).unwrap().into(); + + if value != NumberTrait::zero() { + let softmax_value = (value.exp() / sum_exp).into(); + result.set(row, col, softmax_value); + } else { + result.set(row, col, NumberTrait::zero()); + } + + col += 1; + }; - col += 1; - }; - - col = 0; - while col != self.cols { - let value = self.get(row, col).unwrap().into(); - - if value != NumberTrait::zero() { - let softmax_value = (value.exp() / sum_exp).into(); - result.set(row, col, softmax_value); - } else { - result.set(row, col, NumberTrait::zero()); - } - - col += 1; + row += 1; }; - - row += 1; - }; } result @@ -289,23 +305,26 @@ impl MutMatrixImpl< let mut result = MutMatrixImpl::new(self.rows, self.cols); let mut row: usize = 0; - while row != self.rows { - let mut col: usize = 0; - while col != self.cols { - let value = self.get(row, col); + while row != self + .rows { + let mut col: usize = 0; + while col != self + .cols { + let value = self.get(row, col); - if value.is_some() { - let value = NumberTrait::one() - / (NumberTrait::one() + (value.unwrap() * NumberTrait::neg_one()).exp()); + if value.is_some() { + let value = NumberTrait::one() + / (NumberTrait::one() + + (value.unwrap() * NumberTrait::neg_one()).exp()); - result.set(row, col, value); - } + result.set(row, col, value); + } - col += 1; - }; + col += 1; + }; - row += 1; - }; + row += 1; + }; result } diff --git 
a/src/operators/ml/svm/core.cairo b/src/operators/ml/svm/core.cairo index 365cb0c1b..64c853077 100644 --- a/src/operators/ml/svm/core.cairo +++ b/src/operators/ml/svm/core.cairo @@ -81,10 +81,11 @@ fn squared_diff< ) -> T { let mut i = 0; let mut sum = NumberTrait::zero(); - while i != pA.len() { - sum = sum + (*pA.at(i) - *pB.at(i)).pow(NumberTrait::one() + NumberTrait::one()); - i += 1; - }; + while i != pA + .len() { + sum = sum + (*pA.at(i) - *pB.at(i)).pow(NumberTrait::one() + NumberTrait::one()); + i += 1; + }; sum } diff --git a/src/operators/ml/svm/svm_classifier.cairo b/src/operators/ml/svm/svm_classifier.cairo index 4df3d63f6..8d1f8b90b 100644 --- a/src/operators/ml/svm/svm_classifier.cairo +++ b/src/operators/ml/svm/svm_classifier.cairo @@ -266,11 +266,12 @@ impl SVMClassifierImpl< let (vectors_per_class_, starting_vector_) = match self.vectors_per_class { Option::Some(vectors_per_class) => { let mut i = 0; - while i != vectors_per_class.len() { - starting_vector_.append(vector_count_); - vector_count_ += *vectors_per_class.at(i); - i += 1; - }; + while i != vectors_per_class + .len() { + starting_vector_.append(vector_count_); + vector_count_ += *vectors_per_class.at(i); + i += 1; + }; (vectors_per_class, starting_vector_.span()) }, @@ -309,17 +310,19 @@ impl SVMClassifierImpl< MODE::SVM_LINEAR => { let mut res: Array = array![]; let mut n = 0; - while n != *X.shape.at(0) { - let mut x_n = get_row(@X, n); - let scores = run_linear(ref self, x_n, coefs, class_count_, kernel_type_); - let mut i = 0; - while i != scores.len() { - res.append(*scores.at(i)); - i += 1; - }; + while n != *X + .shape + .at(0) { + let mut x_n = get_row(@X, n); + let scores = run_linear(ref self, x_n, coefs, class_count_, kernel_type_); + let mut i = 0; + while i != scores.len() { + res.append(*scores.at(i)); + i += 1; + }; - n += 1; - }; + n += 1; + }; ( TensorTrait::new(array![*X.shape.at(0), class_count_].span(), res.span()), @@ -330,33 +333,35 @@ impl SVMClassifierImpl< let 
mut res: Array = array![]; let mut votes: Array = array![]; let mut n = 0; - while n != *X.shape.at(0) { - let mut x_n = get_row(@X, n); - let (scores, mut vote) = run_svm( - ref self, - x_n, - sv, - vector_count_, - kernel_type_, - class_count_, - starting_vector_, - coefs, - vectors_per_class_ - ); - let mut i = 0; - while i != scores.len() { - res.append(*scores.at(i)); - i += 1; - }; + while n != *X + .shape + .at(0) { + let mut x_n = get_row(@X, n); + let (scores, mut vote) = run_svm( + ref self, + x_n, + sv, + vector_count_, + kernel_type_, + class_count_, + starting_vector_, + coefs, + vectors_per_class_ + ); + let mut i = 0; + while i != scores.len() { + res.append(*scores.at(i)); + i += 1; + }; - let mut i = 0; - while i != vote.len() { - votes.append(vote.at(i)); - i += 1; - }; + let mut i = 0; + while i != vote.len() { + votes.append(vote.at(i)); + i += 1; + }; - n += 1; - }; + n += 1; + }; ( TensorTrait::new( @@ -377,18 +382,20 @@ impl SVMClassifierImpl< let (scores, has_proba) = if self.prob_a.len() > 0 { let mut scores: Array = array![]; let mut n = 0; - while n != *res.shape.at(0) { - let res_n = get_row(@res, n); - let mut s = probablities(ref self, res_n, class_count_); - - let mut i = 0; - while i != s.len() { - scores.append(s.at(i)); - i += 1; + while n != *res + .shape + .at(0) { + let res_n = get_row(@res, n); + let mut s = probablities(ref self, res_n, class_count_); + + let mut i = 0; + while i != s.len() { + scores.append(s.at(i)); + i += 1; + }; + + n += 1; }; - - n += 1; - }; ( TensorTrait::new( array![*res.shape.at(0), scores.len() / *res.shape.at(0)].span(), @@ -409,50 +416,56 @@ impl SVMClassifierImpl< let mut final_scores: Array = array![]; let mut n = 0; - while n != *scores.shape.at(0) { - let mut scores_n = get_row(@scores, n); - match votes { - Option::Some(votes) => { - let mut votes_n = get_row(@votes, n); - let (label, new_scores) = compute_final_scores( - ref self, - votes_n, - scores_n, - weights_are_all_positive_, - 
has_proba, - self.classlabels - ); - - let mut i = 0; - while i != new_scores.data.len() { - final_scores.append(*new_scores.data.at(i)); - i += 1; - }; + while n != *scores + .shape + .at(0) { + let mut scores_n = get_row(@scores, n); + match votes { + Option::Some(votes) => { + let mut votes_n = get_row(@votes, n); + let (label, new_scores) = compute_final_scores( + ref self, + votes_n, + scores_n, + weights_are_all_positive_, + has_proba, + self.classlabels + ); - labels.append(label); - }, - Option::None => { - let (label, new_scores) = compute_final_scores( - ref self, - array![].span(), - scores_n, - weights_are_all_positive_, - has_proba, - self.classlabels - ); - - let mut i = 0; - while i != new_scores.data.len() { - final_scores.append(*new_scores.data.at(i)); - i += 1; - }; - - labels.append(label); - }, - } + let mut i = 0; + while i != new_scores + .data + .len() { + final_scores.append(*new_scores.data.at(i)); + i += 1; + }; + + labels.append(label); + }, + Option::None => { + let (label, new_scores) = compute_final_scores( + ref self, + array![].span(), + scores_n, + weights_are_all_positive_, + has_proba, + self.classlabels + ); - n += 1; - }; + let mut i = 0; + while i != new_scores + .data + .len() { + final_scores.append(*new_scores.data.at(i)); + i += 1; + }; + + labels.append(label); + }, + } + + n += 1; + }; let labels = labels.span(); @@ -460,10 +473,11 @@ impl SVMClassifierImpl< if self.classlabels.len() > 0 { let mut class_labels: Array = array![]; let mut i = 0; - while i != labels.len() { - class_labels.append(*self.classlabels.at(*labels.at(i))); - i += 1; - }; + while i != labels + .len() { + class_labels.append(*self.classlabels.at(*labels.at(i))); + i += 1; + }; return ( class_labels.span(), @@ -1070,11 +1084,12 @@ fn dot_start_end< let mut sum = NumberTrait::zero(); let mut index_a = a_start; let mut index_b = b_start; - while index_a != a_end && index_b != b_end { - sum = sum + *pA.at(index_a) * *pB.at(index_b); - index_a += 1; - 
index_b += 1; - }; + while index_a != a_end + && index_b != b_end { + sum = sum + *pA.at(index_a) * *pB.at(index_b); + index_a += 1; + index_b += 1; + }; sum } @@ -1110,10 +1125,11 @@ fn squared_diff< ) -> T { let mut i = 0; let mut sum = NumberTrait::zero(); - while i != pA.len() { - sum = sum + (*pA.at(i) - *pB.at(i)).pow(NumberTrait::one() + NumberTrait::one()); - i += 1; - }; + while i != pA + .len() { + sum = sum + (*pA.at(i) - *pB.at(i)).pow(NumberTrait::one() + NumberTrait::one()); + i += 1; + }; sum } diff --git a/src/operators/ml/svm/svm_regressor.cairo b/src/operators/ml/svm/svm_regressor.cairo index 1d5858a2f..286729ff4 100644 --- a/src/operators/ml/svm/svm_regressor.cairo +++ b/src/operators/ml/svm/svm_regressor.cairo @@ -189,40 +189,43 @@ impl SVMRegressorImpl< let mut z: Array = array![]; let mut n = 0; - while n != *X.shape.at(0) { - let mut s = NumberTrait::zero(); - match mode_ { - MODE::SVM_LINEAR => { - let mut x_n = get_row(@X, n); - s = kernel_dot(self.kernel_params, x_n, self.coefficients, kernel_type_); - s += *self.rho.at(0); - }, - MODE::SVM_SVC => { - let mut x_n = get_row(@X, n); - let mut j = 0; - while j != self.n_supports { - let mut sv_j = get_row(@sv, j); - let d = kernel_dot(self.kernel_params, x_n, sv_j, kernel_type_); - s += *self.coefficients.at(j) * d; - j += 1; - }; + while n != *X + .shape + .at(0) { + let mut s = NumberTrait::zero(); + match mode_ { + MODE::SVM_LINEAR => { + let mut x_n = get_row(@X, n); + s = kernel_dot(self.kernel_params, x_n, self.coefficients, kernel_type_); + s += *self.rho.at(0); + }, + MODE::SVM_SVC => { + let mut x_n = get_row(@X, n); + let mut j = 0; + while j != self + .n_supports { + let mut sv_j = get_row(@sv, j); + let d = kernel_dot(self.kernel_params, x_n, sv_j, kernel_type_); + s += *self.coefficients.at(j) * d; + j += 1; + }; - s += *self.rho.at(0); - }, - } - if self.one_class == 1 { - let elem = if s > NumberTrait::zero() { - NumberTrait::one() + s += *self.rho.at(0); + }, + } + if 
self.one_class == 1 { + let elem = if s > NumberTrait::zero() { + NumberTrait::one() + } else { + -NumberTrait::one() + }; + z.append(elem); } else { - -NumberTrait::one() + z.append(s); }; - z.append(elem); - } else { - z.append(s); - }; - n += 1; - }; + n += 1; + }; // Post Transform let mut score = TensorTrait::new(array![*X.shape.at(0)].span(), z.span()); diff --git a/src/operators/nn/core.cairo b/src/operators/nn/core.cairo index ad66197fd..e916b47c4 100644 --- a/src/operators/nn/core.cairo +++ b/src/operators/nn/core.cairo @@ -1304,4 +1304,143 @@ trait NNTrait { mode: Option, padding_mode: Option, ) -> Tensor; + /// # NNTrait::roi_align + /// + /// ```rust + /// fn roi_align( + /// X: @Tensor, + /// roi: @Tensor, + /// batch_indices: @Tensor, + /// coordinate_transformation_mode: Option< + /// orion::operators::nn::functional::roi_align::TRANSFORMATION_MODE + /// >, + /// mode: Option, + /// output_height: Option, + /// output_width: Option, + /// sampling_ratio: Option, + /// spatial_scale: Option, + /// ) -> Tensor; + /// ``` + /// + /// RoiAlign consumes an input tensor X and region of interests (rois) to apply pooling across each RoI; it produces a 4-D tensor of shape (num_rois, C, output_height, output_width). + /// + /// ## Args + /// + /// * `X`(`@Tensor`) - Input data tensor from the previous operator; 4-D feature map of shape (N, C, H, W), where N is the batch size, C is the number of channels, and H and W are the height and the width of the data. + /// * `rois`(`@Tensor`) - RoIs (Regions of Interest) to pool over; rois is 2-D input of shape (num_rois, 4) given as [[x1, y1, x2, y2], ...]. + /// * `batch_indices`(`@Tensor`) - 1-D tensor of shape (num_rois,) with each element denoting the index of the corresponding image in the batch. + /// * `coordinate_transformation_mode`(`Option`) - Allowed values are 'half_pixel' and 'output_half_pixel'. Use the value 'half_pixel' to pixel shift the input coordinates by -0.5 (default behavior). 
Use the value 'output_half_pixel' to omit the pixel shift for the input + /// * `mode`(`Option`) -The pooling method. Two modes are supported: 'avg' and 'max'. Default is 'avg'. + /// * `output_height`(`Option`) - default 1; Pooled output Y's height. + /// * `output_width`(`Option`) - default 1; Pooled output Y's width. + /// * `sampling_ratio`(`Option`) - Number of sampling points in the interpolation grid used to compute the output value of each pooled output bin. If > 0, then exactly sampling_ratio x sampling_ratio grid points are used. If == 0, then an adaptive number of grid points are used (computed as ceil(roi_width / output_width), and likewise for height). Default is 0. + /// * `spatial_scale`(`Option`) - Multiplicative spatial scale factor to translate ROI coordinates from their input spatial scale to the scale used when pooling, i.e., spatial scale of the input feature map X relative to the input image. Default is 1.0. + /// + /// ## Returns + /// + /// A `Tensor` RoI pooled output, 4-D tensor of shape (num_rois, C, output_height, output_width). The r-th batch element Y[r-1] is a pooled feature map corresponding to the r-th RoI X[r-1]. 
+ /// + /// ## Example + /// + /// ```rust + /// use orion::operators::nn::NNTrait; + /// use orion::numbers::FixedTrait; + /// use orion::operators::nn::FP16x16NN; + /// use orion::numbers::FP16x16; + /// use orion::operators::tensor::{Tensor, TensorTrait, FP16x16Tensor}; + /// use orion::operators::nn::functional::roi_align::TRANSFORMATION_MODE; + /// + /// fn example_roi_align() -> Tensor { + /// let mut shape = ArrayTrait::::new(); + /// shape.append(1); + /// shape.append(1); + /// shape.append(5); + /// shape.append(5); + /// + /// let mut data = ArrayTrait::new(); + /// data.append(FP16x16 { mag: 18114, sign: false }); + /// data.append(FP16x16 { mag: 46858, sign: false }); + /// data.append(FP16x16 { mag: 12831, sign: false }); + /// data.append(FP16x16 { mag: 22387, sign: false }); + /// data.append(FP16x16 { mag: 30395, sign: false }); + /// data.append(FP16x16 { mag: 63157, sign: false }); + /// data.append(FP16x16 { mag: 5865, sign: false }); + /// data.append(FP16x16 { mag: 19129, sign: false }); + /// data.append(FP16x16 { mag: 44256, sign: false }); + /// data.append(FP16x16 { mag: 1533, sign: false }); + /// data.append(FP16x16 { mag: 21397, sign: false }); + /// data.append(FP16x16 { mag: 55567, sign: false }); + /// data.append(FP16x16 { mag: 63556, sign: false }); + /// data.append(FP16x16 { mag: 16193, sign: false }); + /// data.append(FP16x16 { mag: 61184, sign: false }); + /// data.append(FP16x16 { mag: 1350, sign: false }); + /// data.append(FP16x16 { mag: 11272, sign: false }); + /// data.append(FP16x16 { mag: 14123, sign: false }); + /// data.append(FP16x16 { mag: 28796, sign: false }); + /// data.append(FP16x16 { mag: 4279, sign: false }); + /// data.append(FP16x16 { mag: 26620, sign: false }); + /// data.append(FP16x16 { mag: 14378, sign: false }); + /// data.append(FP16x16 { mag: 29314, sign: false }); + /// data.append(FP16x16 { mag: 30716, sign: false }); + /// data.append(FP16x16 { mag: 46589, sign: false }); + /// let mut X = 
TensorTrait::new(shape.span(), data.span()); + /// + /// let batch_indices = TensorTrait::new(array![3].span(), array![0, 0, 0].span()); + /// + /// let mut shape = ArrayTrait::::new(); + /// shape.append(3); + /// shape.append(4); + /// + /// let mut data = ArrayTrait::new(); + /// data.append(FP16x16 { mag: 0, sign: false }); + /// data.append(FP16x16 { mag: 0, sign: false }); + /// data.append(FP16x16 { mag: 589824, sign: false }); + /// data.append(FP16x16 { mag: 589824, sign: false }); + /// data.append(FP16x16 { mag: 0, sign: false }); + /// data.append(FP16x16 { mag: 327680, sign: false }); + /// data.append(FP16x16 { mag: 262144, sign: false }); + /// data.append(FP16x16 { mag: 589824, sign: false }); + /// data.append(FP16x16 { mag: 327680, sign: false }); + /// data.append(FP16x16 { mag: 327680, sign: false }); + /// data.append(FP16x16 { mag: 589824, sign: false }); + /// data.append(FP16x16 { mag: 589824, sign: false }); + /// let rois = TensorTrait::new(shape.span(), data.span()); + /// + /// return NNTrait::roi_align( + /// @X, + /// @rois, + /// @batch_indices, + /// Option::Some(TRANSFORMATION_MODE::OUTPUT_HALF_PIXEL), + /// Option::None, + /// Option::Some(2), + /// Option::Some(2), + /// Option::Some(FP16x16 { mag: 65536, sign: false }), + /// Option::Some(FP16x16 { mag: 32768, sign: false }), + /// ); + /// } + /// >>> [[[[0.2083422 , 0.44005 ], + /// [0.20385626, 0.39676717]]], + /// + /// + /// [[[0.09630001, 0.19375 ], + /// [0.3128 , 0.33335 ]]], + /// + /// + /// [[[0.4394 , 0.0653 ], + /// [0.4687 , 0.7109 ]]]] + /// + /// ```` + fn roi_align( + X: @Tensor, + roi: @Tensor, + batch_indices: @Tensor, + coordinate_transformation_mode: Option< + orion::operators::nn::functional::roi_align::TRANSFORMATION_MODE + >, + mode: Option, + output_height: Option, + output_width: Option, + sampling_ratio: Option, + spatial_scale: Option, + ) -> Tensor; } diff --git a/src/operators/nn/functional.cairo b/src/operators/nn/functional.cairo index 
45e1c1ec9..7384b8064 100644 --- a/src/operators/nn/functional.cairo +++ b/src/operators/nn/functional.cairo @@ -16,3 +16,4 @@ mod conv_transpose; mod depth_to_space; mod space_to_depth; mod conv; +mod roi_align; diff --git a/src/operators/nn/functional/col2im.cairo b/src/operators/nn/functional/col2im.cairo index b08d9f650..465f65cfb 100644 --- a/src/operators/nn/functional/col2im.cairo +++ b/src/operators/nn/functional/col2im.cairo @@ -299,4 +299,4 @@ fn prod, +Copy, +NumberTrait, +TensorTrait, +Mul< }; prod -} \ No newline at end of file +} diff --git a/src/operators/nn/functional/conv.cairo b/src/operators/nn/functional/conv.cairo index ac72c336d..2000b0845 100644 --- a/src/operators/nn/functional/conv.cairo +++ b/src/operators/nn/functional/conv.cairo @@ -193,22 +193,23 @@ fn conv< let mut p = 0; let mut i = 0; - while i != res_b.len() { - let cv = *res_cv.at(i); - - let mut n = 0; - while n != cv.data.len() { - final.append(*cv.data.at(n)); - n += 1; - }; + while i != res_b + .len() { + let cv = *res_cv.at(i); + + let mut n = 0; + while n != cv.data.len() { + final.append(*cv.data.at(n)); + n += 1; + }; - p += *cv.shape.at(1); - if p >= td { - p = 0; - } + p += *cv.shape.at(1); + if p >= td { + p = 0; + } - i += 1; - }; + i += 1; + }; let final = final.span(); @@ -217,24 +218,32 @@ fn conv< let mut final_b: Array = array![]; let final_stride = stride(final_shape); let mut i = 0; - while i != *final_shape.at(0) { - let mut j = 0; - while j != B.len() { - let mut k = 0; - while k != *final_stride.at(1) { - final_b - .append( - *final.at(i * *final_stride.at(0) + j * *final_stride.at(1) + k) - + *B.at(j) - ); - k += 1; - }; + while i != *final_shape + .at(0) { + let mut j = 0; + while j != B + .len() { + let mut k = 0; + while k != *final_stride + .at(1) { + final_b + .append( + *final + .at( + i * *final_stride.at(0) + + j * *final_stride.at(1) + + k + ) + + *B.at(j) + ); + k += 1; + }; - j += 1; - }; + j += 1; + }; - i += 1; - }; + i += 1; + }; final_b.span() 
}, @@ -253,13 +262,14 @@ fn conv< new_shape.append_span(SpanTrait::slice((*W).shape, 0, (*W).shape.len() - nd)); let mut i = 0; - while i != dilations.len() { - let d = *dilations.at(i); - let di = (*W).shape.len() - nd + i; - new_shape.append(*(*W).shape.at(di) + (*(*W).shape.at(di) - 1) * (d - 1)); - new_kernel_shape.append(*kernel_shape.at(i) + (*kernel_shape.at(i) - 1) * (d - 1)); - i += 1; - }; + while i != dilations + .len() { + let d = *dilations.at(i); + let di = (*W).shape.len() - nd + i; + new_shape.append(*(*W).shape.at(di) + (*(*W).shape.at(di) - 1) * (d - 1)); + new_kernel_shape.append(*kernel_shape.at(i) + (*kernel_shape.at(i) - 1) * (d - 1)); + i += 1; + }; let new_shape = new_shape.span(); let new_w_strides = stride(new_shape); @@ -273,12 +283,13 @@ fn conv< indices.append(arange(0, *new_shape.at(1), 1)); let mut i = 0; - while i != dilations.len() { - let d = *dilations.at(i); - let di = (*W).shape.len() - nd + i; - indices.append(arange(0, *new_shape.at(di), d)); - i += 1; - }; + while i != dilations + .len() { + let d = *dilations.at(i); + let di = (*W).shape.len() - nd + i; + indices.append(arange(0, *new_shape.at(di), d)); + i += 1; + }; let set_of_all_indices = cartesian(indices.span()); @@ -286,29 +297,32 @@ fn conv< let mut i = 0; let mut prev = 0; - while i != (*W).data.len() { - let nd_index = *set_of_all_indices.at(i); - let mut flatten_index = 0; - let mut j = 0; - while j != nd_index.len() { - flatten_index += *nd_index.at(j) * *new_w_strides.at(j); - j += 1; - }; + while i != (*W) + .data + .len() { + let nd_index = *set_of_all_indices.at(i); + let mut flatten_index = 0; + let mut j = 0; + while j != nd_index + .len() { + flatten_index += *nd_index.at(j) * *new_w_strides.at(j); + j += 1; + }; - if flatten_index > prev + 1 { - let mut j = prev + 1; - while j != flatten_index { - new_w_arr.append(NumberTrait::zero()); - }; + if flatten_index > prev + 1 { + let mut j = prev + 1; + while j != flatten_index { + 
new_w_arr.append(NumberTrait::zero()); + }; - j += 1; - } + j += 1; + } - new_w_arr.append(*(*W).data.at(i)); - new_w.set(flatten_index, *(*W).data.at(i)); - prev = flatten_index; - i += 1; - }; + new_w_arr.append(*(*W).data.at(i)); + new_w.set(flatten_index, *(*W).data.at(i)); + prev = flatten_index; + i += 1; + }; } let pads = match auto_pad { @@ -425,42 +439,51 @@ fn conv< let w = SpanTrait::slice((*W).data, nw * sC * kh + c * kh, kh); let mut io = bh; - while io < eh.into() { - let hr = (io - bh) / sth.into(); - if hr < h_out.into() { - let i = io + (kh % 2).into(); - - let ih1 = I32Number::max(0, i + oh).into(); - let ih2 = I32Number::min(i + oh + kh.into(), sH.into()).into(); - let img = SpanTrait::slice((*X).data, n * sN + c * sC + ih1, ih2 - ih1); - - let s = if w.len() != img.len() { - let jh1 = I32Number::max(0, -i - oh).into(); - let jh2 = I32Number::min(sH.into() - (i + oh), kh.into()).into(); - - let w_ = SpanTrait::slice(w, jh1, jh2 - jh1); - assert(w_.len() == img.len(), 'unexpected w and img len'); - dot(img, w_) - } else { - dot(img, w) - }; + while io < eh + .into() { + let hr = (io - bh) / sth.into(); + if hr < h_out.into() { + let i = io + (kh % 2).into(); + + let ih1 = I32Number::max(0, i + oh).into(); + let ih2 = I32Number::min(i + oh + kh.into(), sH.into()).into(); + let img = SpanTrait::slice( + (*X).data, n * sN + c * sC + ih1, ih2 - ih1 + ); - let hr = if hr < 0 { - *res_strides.at(1) - hr.into() - } else { - hr.into() - }; + let s = if w.len() != img.len() { + let jh1 = I32Number::max(0, -i - oh).into(); + let jh2 = I32Number::min(sH.into() - (i + oh), kh.into()) + .into(); - res - .set( - n * *res_strides.at(0) + nw * *res_strides.at(1) + hr, - res.at(n * *res_strides.at(0) + nw * *res_strides.at(1) + hr) - + s - ); - } + let w_ = SpanTrait::slice(w, jh1, jh2 - jh1); + assert(w_.len() == img.len(), 'unexpected w and img len'); + dot(img, w_) + } else { + dot(img, w) + }; - io += sth.into(); - }; + let hr = if hr < 0 { + 
*res_strides.at(1) - hr.into() + } else { + hr.into() + }; + + res + .set( + n * *res_strides.at(0) + nw * *res_strides.at(1) + hr, + res + .at( + n * *res_strides.at(0) + + nw * *res_strides.at(1) + + hr + ) + + s + ); + } + + io += sth.into(); + }; c += 1; }; @@ -558,102 +581,114 @@ fn conv< ); let mut io = bh; - while io < eh.into() { - let hr = (io - bh) / sth.into(); - if hr < h_out.into() { - let i = io + (kh % 2).into(); - let ih1 = I32Number::max(0, i + oh).into(); - let ih2 = I32Number::min(i + oh + kh.into(), sH.into()).into(); - - let mut jo = bw; - while jo < ew.into() { - let wr = (jo - bw) / stw.into(); - if wr < w_out.into() { - let j = jo + (kw % 2).into(); - let iw1 = I32Number::max(0, j + ow).into(); - let iw2 = I32Number::min(j + ow + kw.into(), sW.into()).into(); - - let mut img: Array = array![]; - let mut ihi = ih1; - while ihi != ih2 { - img - .append_span( - SpanTrait::slice( - (*X).data, - n * (sC * sH * sW) - + c * (sH * sW) - + ihi * sW - + iw1, - iw2 - iw1 - ) - ); - ihi += 1; - }; + while io < eh + .into() { + let hr = (io - bh) / sth.into(); + if hr < h_out.into() { + let i = io + (kh % 2).into(); + let ih1 = I32Number::max(0, i + oh).into(); + let ih2 = I32Number::min(i + oh + kh.into(), sH.into()).into(); + + let mut jo = bw; + while jo < ew + .into() { + let wr = (jo - bw) / stw.into(); + if wr < w_out.into() { + let j = jo + (kw % 2).into(); + let iw1 = I32Number::max(0, j + ow).into(); + let iw2 = I32Number::min(j + ow + kw.into(), sW.into()) + .into(); - let img = img.span(); + let mut img: Array = array![]; + let mut ihi = ih1; + while ihi != ih2 { + img + .append_span( + SpanTrait::slice( + (*X).data, + n * (sC * sH * sW) + + c * (sH * sW) + + ihi * sW + + iw1, + iw2 - iw1 + ) + ); + ihi += 1; + }; - let s = if w.len() != img.len() { - let jh1 = I32Number::max(0, -i - oh).into(); - let jh2 = I32Number::min(sH.into() - (i + oh), kh.into()) - .into(); + let img = img.span(); - let jw1 = I32Number::max(0, -j - ow).into(); - let 
jw2 = I32Number::min(sW.into() - (j + ow), kw.into()) - .into(); + let s = if w.len() != img.len() { + let jh1 = I32Number::max(0, -i - oh).into(); + let jh2 = I32Number::min( + sH.into() - (i + oh), kh.into() + ) + .into(); - let mut w_: Array = array![]; - let mut jhj = jh1; - while jhj != jh2 { - w_ - .append_span( - SpanTrait::slice(w, jhj * kw + jw1, jw2 - jw1) - ); - jhj += 1; - }; + let jw1 = I32Number::max(0, -j - ow).into(); + let jw2 = I32Number::min( + sW.into() - (j + ow), kw.into() + ) + .into(); - let w_ = w_.span(); + let mut w_: Array = array![]; + let mut jhj = jh1; + while jhj != jh2 { + w_ + .append_span( + SpanTrait::slice( + w, jhj * kw + jw1, jw2 - jw1 + ) + ); + jhj += 1; + }; - assert(w_.len() == img.len(), 'unexpected w and img len'); - dot(img, w_) - } else { - dot(img, w) - }; + let w_ = w_.span(); - let hr = if hr < 0 { - h_out - hr.into() - } else { - hr.into() - }; + assert( + w_.len() == img.len(), + 'unexpected w and img len' + ); + dot(img, w_) + } else { + dot(img, w) + }; - let wr = if wr < 0 { - w_out - wr.into() - } else { - wr.into() - }; + let hr = if hr < 0 { + h_out - hr.into() + } else { + hr.into() + }; + + let wr = if wr < 0 { + w_out - wr.into() + } else { + wr.into() + }; - res - .set( - n * *res_strides.at(0) - + nw * *res_strides.at(1) - + hr * *res_strides.at(2) - + wr, res - .at( + .set( n * *res_strides.at(0) + nw * *res_strides.at(1) + hr * *res_strides.at(2) - + wr - ) - + s - ); - } + + wr, + res + .at( + n * *res_strides.at(0) + + nw * *res_strides.at(1) + + hr * *res_strides.at(2) + + wr + ) + + s + ); + } - jo += stw.into(); - }; - } + jo += stw.into(); + }; + } - io += sth.into(); - }; + io += sth.into(); + }; c += 1; }; @@ -767,151 +802,165 @@ fn conv< ); let mut io = bh; - while io < eh.into() { - let hr = (io - bh) / sth.into(); - if hr < h_out.into() { - let i = io + (kh % 2).into(); - let ih1 = I32Number::max(0, i + oh).into(); - let ih2 = I32Number::min(i + oh + kh.into(), sH.into()).into(); - - let 
mut jo = bw; - while jo < ew.into() { - let wr = (jo - bw) / stw.into(); - if wr < w_out.into() { - let j = jo + (kw % 2).into(); - let iw1 = I32Number::max(0, j + ow).into(); - let iw2 = I32Number::min(j + ow + kw.into(), sW.into()).into(); - - let mut zo = bz; - while zo < ez.into() { - let zr = (zo - bz) / stz.into(); - if zr < z_out.into() { - let z = zo + (kz % 2).into(); - let iz1 = I32Number::max(0, z + oz).into(); - let iz2 = I32Number::min(z + oz + kz.into(), sW.into()) + while io < eh + .into() { + let hr = (io - bh) / sth.into(); + if hr < h_out.into() { + let i = io + (kh % 2).into(); + let ih1 = I32Number::max(0, i + oh).into(); + let ih2 = I32Number::min(i + oh + kh.into(), sH.into()).into(); + + let mut jo = bw; + while jo < ew + .into() { + let wr = (jo - bw) / stw.into(); + if wr < w_out.into() { + let j = jo + (kw % 2).into(); + let iw1 = I32Number::max(0, j + ow).into(); + let iw2 = I32Number::min(j + ow + kw.into(), sW.into()) .into(); - let mut img: Array = array![]; - let mut ihi = ih1; - while ihi != ih2 { - let mut iwi = iw1; - while iwi != iw2 { - img - .append_span( - SpanTrait::slice( - (*X).data, - n * (sC * sH * sW * sZ) - + c * (sH * sW * sZ) - + ihi * (sW * sZ) - + iwi * sZ - + iz1, - iz2 - iz1 + let mut zo = bz; + while zo < ez + .into() { + let zr = (zo - bz) / stz.into(); + if zr < z_out.into() { + let z = zo + (kz % 2).into(); + let iz1 = I32Number::max(0, z + oz).into(); + let iz2 = I32Number::min( + z + oz + kz.into(), sW.into() + ) + .into(); + + let mut img: Array = array![]; + let mut ihi = ih1; + while ihi != ih2 { + let mut iwi = iw1; + while iwi != iw2 { + img + .append_span( + SpanTrait::slice( + (*X).data, + n * (sC * sH * sW * sZ) + + c * (sH * sW * sZ) + + ihi * (sW * sZ) + + iwi * sZ + + iz1, + iz2 - iz1 + ) + ); + iwi += 1; + }; + + ihi += 1; + }; + + let img = img.span(); + + let s = if w.len() != img.len() { + let jh1 = I32Number::max(0, -i - oh) + .into(); + let jh2 = I32Number::min( + sH.into() - (i + oh), 
kh.into() ) - ); - iwi += 1; - }; + .into(); - ihi += 1; - }; - - let img = img.span(); - - let s = if w.len() != img.len() { - let jh1 = I32Number::max(0, -i - oh).into(); - let jh2 = I32Number::min( - sH.into() - (i + oh), kh.into() - ) - .into(); - - let jw1 = I32Number::max(0, -j - ow).into(); - let jw2 = I32Number::min( - sW.into() - (j + ow), kw.into() - ) - .into(); - - let jz1 = I32Number::max(0, -z - oz).into(); - let jz2 = I32Number::min( - sZ.into() - (z + oz), kz.into() - ) - .into(); + let jw1 = I32Number::max(0, -j - ow) + .into(); + let jw2 = I32Number::min( + sW.into() - (j + ow), kw.into() + ) + .into(); - let mut w_: Array = array![]; - let mut jhj = jh1; - while jhj != jh2 { - let mut jwj = jw1; - while jwj != jw2 { - w_ - .append_span( - SpanTrait::slice( - w, - jhj * kw * kz + jwj * kz + jz1, - jz2 - jz1 - ) + let jz1 = I32Number::max(0, -z - oz) + .into(); + let jz2 = I32Number::min( + sZ.into() - (z + oz), kz.into() + ) + .into(); + + let mut w_: Array = array![]; + let mut jhj = jh1; + while jhj != jh2 { + let mut jwj = jw1; + while jwj != jw2 { + w_ + .append_span( + SpanTrait::slice( + w, + jhj * kw * kz + + jwj * kz + + jz1, + jz2 - jz1 + ) + ); + jwj += 1; + }; + + jhj += 1; + }; + + let w_ = w_.span(); + + assert( + w_.len() == img.len(), + 'unexpected w and img len' ); - jwj += 1; - }; + dot(img, w_) + } else { + dot(img, w) + }; + + let hr = if hr < 0 { + h_out - hr.into() + } else { + hr.into() + }; + + let wr = if wr < 0 { + w_out - wr.into() + } else { + wr.into() + }; + + let zr = if zr < 0 { + z_out - zr.into() + } else { + zr.into() + }; + + res + .set( + n * *res_strides.at(0) + + nw * *res_strides.at(1) + + hr * *res_strides.at(2) + + wr * *res_strides.at(3) + + zr, + res + .at( + n * *res_strides.at(0) + + nw + * *res_strides.at(1) + + hr + * *res_strides.at(2) + + wr + * *res_strides.at(3) + + zr + ) + + s + ); + } - jhj += 1; + zo += stz.into(); }; - - let w_ = w_.span(); - - assert( - w_.len() == img.len(), - 'unexpected w 
and img len' - ); - dot(img, w_) - } else { - dot(img, w) - }; - - let hr = if hr < 0 { - h_out - hr.into() - } else { - hr.into() - }; - - let wr = if wr < 0 { - w_out - wr.into() - } else { - wr.into() - }; - - let zr = if zr < 0 { - z_out - zr.into() - } else { - zr.into() - }; - - res - .set( - n * *res_strides.at(0) - + nw * *res_strides.at(1) - + hr * *res_strides.at(2) - + wr * *res_strides.at(3) - + zr, - res - .at( - n * *res_strides.at(0) - + nw * *res_strides.at(1) - + hr * *res_strides.at(2) - + wr * *res_strides.at(3) - + zr - ) - + s - ); } - zo += stz.into(); + jo += stw.into(); }; - } - - jo += stw.into(); - }; - } + } - io += sth.into(); - }; + io += sth.into(); + }; c += 1; }; @@ -990,10 +1039,11 @@ fn conv< while j != sM { let b_j = *B.at(j); let mut k = 0; - while k != *res_strides.at(1) { - res.set(i * *res_strides.at(0) + j * *res_strides.at(1) + k, b_j); - k += 1; - }; + while k != *res_strides + .at(1) { + res.set(i * *res_strides.at(0) + j * *res_strides.at(1) + k, b_j); + k += 1; + }; j += 1; }; @@ -1014,185 +1064,211 @@ fn conv< (*W).data, nw * *w_stride.at(0) + c * *w_stride.at(1), *w_stride.at(1) ); let mut i = 0; - while i != *range_len.at(0) * *range_stride.at(0) { - let mut io_index: Array = array![]; - let mut r_index: Array = array![]; - let mut flatten_index = i; - - let mut nx = 0; - while nx != nd { - let (n_index, rem) = DivRem::div_rem( - flatten_index, (*range_stride.at(nx)).try_into().unwrap() - ); - - flatten_index = rem; - io_index - .append(n_index.into() * (*strides.at(nx)).into() + *b_index.at(nx)); - r_index.append(n_index.into()); - nx += 1; - }; + while i != *range_len.at(0) + * *range_stride + .at(0) { + let mut io_index: Array = array![]; + let mut r_index: Array = array![]; + let mut flatten_index = i; - if r_index_check(r_index.span(), shape_out) { - let mut indices: Array = array![]; - let mut i1_index: Array = array![]; - let mut i2_index: Array = array![]; - let mut idiff_index: Array = array![]; - - let mut 
nx = 0; - while nx != nd { - indices.append(*io_index.at(nx) + (*kernel_shape.at(nx) % 2).into()); - i1_index - .append( - I32Number::max(0, *indices.at(nx) + *o_index.at(nx)).into() - ); - i2_index - .append( - I32Number::min( - (*(*X).shape.at(nx + 2)).into(), - *indices.at(nx) - + *o_index.at(nx) - + (*kernel_shape.at(nx)).into() - ) - .into() + let mut nx = 0; + while nx != nd { + let (n_index, rem) = DivRem::div_rem( + flatten_index, (*range_stride.at(nx)).try_into().unwrap() ); - if nx != nd - 1 { - idiff_index.append(*i2_index.at(nx) - *i1_index.at(nx)); - } - nx += 1; - }; - - let i1_index = i1_index.span(); - let mut img: Array = array![]; - - let img = if nx == 1 { - let img = SpanTrait::slice( - (*X).data, - n * sN + c * sC + *i1_index.at(nd - 1), - *i2_index.at(nd - 1) - *i1_index.at(nd - 1) - ); - img - } else { - let i_stride = stride(idiff_index.span()); + flatten_index = rem; + io_index + .append( + n_index.into() * (*strides.at(nx)).into() + *b_index.at(nx) + ); + r_index.append(n_index.into()); + nx += 1; + }; - let mut ii = 0; - while ii != *i_stride.at(0) * *idiff_index.at(0) { - let mut flatten_index = ii; - let mut start = n * *x_stride.at(0) + c * *x_stride.at(1); + if r_index_check(r_index.span(), shape_out) { + let mut indices: Array = array![]; + let mut i1_index: Array = array![]; + let mut i2_index: Array = array![]; + let mut idiff_index: Array = array![]; let mut nx = 0; - while nx != nd - 1 { - let (ii_index, rem) = DivRem::div_rem( - flatten_index, (*i_stride.at(nx)).try_into().unwrap() - ); - flatten_index = rem; + while nx != nd { + indices + .append( + *io_index.at(nx) + (*kernel_shape.at(nx) % 2).into() + ); + i1_index + .append( + I32Number::max(0, *indices.at(nx) + *o_index.at(nx)) + .into() + ); + i2_index + .append( + I32Number::min( + (*(*X).shape.at(nx + 2)).into(), + *indices.at(nx) + + *o_index.at(nx) + + (*kernel_shape.at(nx)).into() + ) + .into() + ); - start += (*i1_index.at(nx) + ii_index) * *x_stride.at(2 + nx); + if 
nx != nd - 1 { + idiff_index.append(*i2_index.at(nx) - *i1_index.at(nx)); + } nx += 1; }; - img - .append_span( - SpanTrait::slice( - (*X).data, - start + *i1_index.at(nd - 1), - *i2_index.at(nd - 1) - *i1_index.at(nd - 1) - ) - ); - ii += 1; - }; + let i1_index = i1_index.span(); + let mut img: Array = array![]; - img.span() - }; - - let s = if w.len() != img.len() { - let mut j1_index: Array = array![]; - let mut j2_index: Array = array![]; - let mut jdiff_index: Array = array![]; - - let mut nx = 0; - while nx != nd { - j1_index - .append( - I32Number::max(0, -*indices.at(nx) - *o_index.at(nx)).into() - ); - j2_index - .append( - I32Number::min( - (*(*X).shape.at(nx + 2)).into() - - *indices.at(nx) - - *o_index.at(nx), - (*kernel_shape.at(nx)).into() - ) - .into() + let img = if nx == 1 { + let img = SpanTrait::slice( + (*X).data, + n * sN + c * sC + *i1_index.at(nd - 1), + *i2_index.at(nd - 1) - *i1_index.at(nd - 1) ); - if nx != nd - 1 { - jdiff_index.append(*j2_index.at(nx) - *j1_index.at(nx)); - } - nx += 1; - }; + img + } else { + let i_stride = stride(idiff_index.span()); + + let mut ii = 0; + while ii != *i_stride.at(0) + * *idiff_index + .at(0) { + let mut flatten_index = ii; + let mut start = n * *x_stride.at(0) + + c * *x_stride.at(1); + + let mut nx = 0; + while nx != nd + - 1 { + let (ii_index, rem) = DivRem::div_rem( + flatten_index, + (*i_stride.at(nx)).try_into().unwrap() + ); + flatten_index = rem; - let j1_index = j1_index.span(); + start += (*i1_index.at(nx) + ii_index) + * *x_stride.at(2 + nx); + nx += 1; + }; - let mut w_: Array = array![]; + img + .append_span( + SpanTrait::slice( + (*X).data, + start + *i1_index.at(nd - 1), + *i2_index.at(nd - 1) + - *i1_index.at(nd - 1) + ) + ); + ii += 1; + }; - let w_ = if nx == 1 { - let w_ = SpanTrait::slice( - w, - *j1_index.at(nd - 1), - *j2_index.at(nd - 1) - *j1_index.at(nd - 1) - ); - w_ - } else { - let j_stride = stride(jdiff_index.span()); + img.span() + }; - let mut jj = 0; - while jj != 
*j_stride.at(0) * *jdiff_index.at(0) { - let mut flatten_index = jj; - let mut start = 0; + let s = if w.len() != img.len() { + let mut j1_index: Array = array![]; + let mut j2_index: Array = array![]; + let mut jdiff_index: Array = array![]; let mut nx = 0; - while nx != nd - 1 { - let (jj_index, rem) = DivRem::div_rem( - flatten_index, (*j_stride.at(nx)).try_into().unwrap() - ); - flatten_index = rem; - start += (*j1_index.at(nx) + jj_index) - * *kernel_shape.at(nx); + while nx != nd { + j1_index + .append( + I32Number::max( + 0, -*indices.at(nx) - *o_index.at(nx) + ) + .into() + ); + j2_index + .append( + I32Number::min( + (*(*X).shape.at(nx + 2)).into() + - *indices.at(nx) + - *o_index.at(nx), + (*kernel_shape.at(nx)).into() + ) + .into() + ); + if nx != nd - 1 { + jdiff_index.append(*j2_index.at(nx) - *j1_index.at(nx)); + } nx += 1; }; - w_ - .append_span( - SpanTrait::slice( - w, - start + *j1_index.at(nd - 1), - *j2_index.at(nd - 1) - *j1_index.at(nd - 1) - ) + + let j1_index = j1_index.span(); + + let mut w_: Array = array![]; + + let w_ = if nx == 1 { + let w_ = SpanTrait::slice( + w, + *j1_index.at(nd - 1), + *j2_index.at(nd - 1) - *j1_index.at(nd - 1) ); - jj += 1; - }; + w_ + } else { + let j_stride = stride(jdiff_index.span()); + + let mut jj = 0; + while jj != *j_stride.at(0) + * *jdiff_index + .at(0) { + let mut flatten_index = jj; + let mut start = 0; + + let mut nx = 0; + while nx != nd + - 1 { + let (jj_index, rem) = DivRem::div_rem( + flatten_index, + (*j_stride.at(nx)) + .try_into() + .unwrap() + ); + flatten_index = rem; + start += (*j1_index.at(nx) + jj_index) + * *kernel_shape.at(nx); + nx += 1; + }; + w_ + .append_span( + SpanTrait::slice( + w, + start + *j1_index.at(nd - 1), + *j2_index.at(nd - 1) + - *j1_index.at(nd - 1) + ) + ); + jj += 1; + }; - w_.span() - }; + w_.span() + }; - dot(img, w_) - } else { - dot(img, w) - }; + dot(img, w_) + } else { + dot(img, w) + }; - let mut res_index = n * *res_strides.at(0) + nw * *res_strides.at(1); + 
let mut res_index = n * *res_strides.at(0) + + nw * *res_strides.at(1); - let mut nx = 0; - while nx != nd { - res_index += (*r_index.at(nx)).into() * *res_strides.at(2 + nx); - nx += 1; - }; + let mut nx = 0; + while nx != nd { + res_index += (*r_index.at(nx)).into() * *res_strides.at(2 + nx); + nx += 1; + }; - res.set(res_index, res.at(res_index) + s); - }; + res.set(res_index, res.at(res_index) + s); + }; - i += 1 - }; + i += 1 + }; c += 1; }; @@ -1306,14 +1382,15 @@ fn cartesian(mut arrays: Span>,) -> Span> { let mut m = n; let mut i = 0; - while i != arrays.len() { - m = m / (*(arrays.at(i))).len(); - let mut out = repeat(*(arrays.at(i)), m); - out = repeat_2(out, size_arrays, i); + while i != arrays + .len() { + m = m / (*(arrays.at(i))).len(); + let mut out = repeat(*(arrays.at(i)), m); + out = repeat_2(out, size_arrays, i); - output_arrays.append(out); - i += 1; - }; + output_arrays.append(out); + i += 1; + }; let output_arrays = output_arrays.span(); @@ -1339,15 +1416,16 @@ fn repeat_2(mut array: Array, size_array: Span, index: usize) -> A let mut i = 0; while i != index { let mut j = 1; - while j != *size_array.at(index - 1 - i) { - let mut k = 0; - while k != size { - array.append(*array.at(k)); - k += 1; - }; + while j != *size_array + .at(index - 1 - i) { + let mut k = 0; + while k != size { + array.append(*array.at(k)); + k += 1; + }; - j += 1; - }; + j += 1; + }; size = size * *size_array.at(index - 1 - i); i += 1; @@ -1359,15 +1437,16 @@ fn repeat_2(mut array: Array, size_array: Span, index: usize) -> A fn repeat(array: Span, m: usize,) -> Array { let mut out: Array = array![]; let mut j = 0; - while j != array.len() { - let mut k = 0; - while k != m { - out.append(*array.at(j)); - k += 1; - }; + while j != array + .len() { + let mut k = 0; + while k != m { + out.append(*array.at(j)); + k += 1; + }; - j += 1; - }; + j += 1; + }; out } diff --git a/src/operators/nn/functional/grid_sample.cairo b/src/operators/nn/functional/grid_sample.cairo index 
aed560e37..909065bfa 100644 --- a/src/operators/nn/functional/grid_sample.cairo +++ b/src/operators/nn/functional/grid_sample.cairo @@ -99,68 +99,70 @@ fn grid_sample< let all_coords = get_all_coords(SpanTrait::slice(grid_dims, 1, grid_dims.len() - 2)); let mut ix = 0; - while ix != all_coords.len() { - let ox = *all_coords.at(ix); - let nx = get_sub(grid_data, grid_data_stride, ox); - let nx = reverse(nx); - let x = gs_denormalize_coordinates(nx, dims, align_corner); - - let x = match mode { - MODE::NEAREST => { rint(x) }, - MODE::LINEAR => { x }, - MODE::CUBIC => { x }, - }; + while ix != all_coords + .len() { + let ox = *all_coords.at(ix); + let nx = get_sub(grid_data, grid_data_stride, ox); + let nx = reverse(nx); + let x = gs_denormalize_coordinates(nx, dims, align_corner); + + let x = match mode { + MODE::NEAREST => { rint(x) }, + MODE::LINEAR => { x }, + MODE::CUBIC => { x }, + }; - let mut new_x: Array = array![]; - let mut i = 0; - while i != x.len() { - let v = *x.at(i); - let mut x_min = *border.at(i); - let mut x_max = *border.at(i + num_dims); - let new_v = if v < x_min || v > x_max { - let v = match padding_mode { - PADDING_MODE::ZEROS => { v }, - PADDING_MODE::BORDER => { - clamp( - v, - NumberTrait::zero(), - NumberTrait::new_unscaled((*dims.at(i)).into(), false) - - NumberTrait::one() - ) - }, - PADDING_MODE::REFLECTION => { gs_reflect(v, x_min, x_max) }, + let mut new_x: Array = array![]; + let mut i = 0; + while i != x + .len() { + let v = *x.at(i); + let mut x_min = *border.at(i); + let mut x_max = *border.at(i + num_dims); + let new_v = if v < x_min || v > x_max { + let v = match padding_mode { + PADDING_MODE::ZEROS => { v }, + PADDING_MODE::BORDER => { + clamp( + v, + NumberTrait::zero(), + NumberTrait::new_unscaled((*dims.at(i)).into(), false) + - NumberTrait::one() + ) + }, + PADDING_MODE::REFLECTION => { gs_reflect(v, x_min, x_max) }, + }; + v + } else { + v + }; + + new_x.append(new_v); + i += 1; }; - v - } else { - v - }; - 
new_x.append(new_v); - i += 1; - }; + let x = new_x.span(); + + let y = match mode { + MODE::NEAREST => { + pixel_at_ndarray(X_data, dims, X_data_stride, x, border, padding_mode) + }, + MODE::LINEAR => { + gs_linear_interpolation_nd_with_x( + X_data, dims, X_data_stride, x, border, padding_mode + ) + }, + MODE::CUBIC => { + gs_cubic_interpolation_nd_with_x( + X_data, dims, X_data_stride, x, border, padding_mode + ) + }, + }; - let x = new_x.span(); - - let y = match mode { - MODE::NEAREST => { - pixel_at_ndarray(X_data, dims, X_data_stride, x, border, padding_mode) - }, - MODE::LINEAR => { - gs_linear_interpolation_nd_with_x( - X_data, dims, X_data_stride, x, border, padding_mode - ) - }, - MODE::CUBIC => { - gs_cubic_interpolation_nd_with_x( - X_data, dims, X_data_stride, x, border, padding_mode - ) - }, + Y.append(y); + ix += 1; }; - Y.append(y); - ix += 1; - }; - c += 1; }; @@ -288,26 +290,27 @@ fn gs_cubic_interpolation_nd_with_x< let mut res1d: Array = array![]; let mut i = 0; - while i != *data_dims.at(0) { - let sub_data = SpanTrait::slice(data, i * *data_stride.at(0), *data_stride.at(0)); - let sub_x = SpanTrait::slice(x, 1, x.len() - 1); - - let data_dims_sub = SpanTrait::slice(data_dims, 1, data_dims.len() - 1); - let data_stride_sub = SpanTrait::slice(data_stride, 1, data_stride.len() - 1); - - let border1 = SpanTrait::slice(border, 1, num_dims - 1); - let border2 = SpanTrait::slice(border, num_dims + 1, num_dims - 1); - let mut border = ArrayTrait::new(); - border.append_span(border1); - border.append_span(border2); - - let r = gs_cubic_interpolation_nd_with_x( - sub_data, data_dims_sub, data_stride_sub, sub_x, border.span(), padding_mode - ); + while i != *data_dims + .at(0) { + let sub_data = SpanTrait::slice(data, i * *data_stride.at(0), *data_stride.at(0)); + let sub_x = SpanTrait::slice(x, 1, x.len() - 1); + + let data_dims_sub = SpanTrait::slice(data_dims, 1, data_dims.len() - 1); + let data_stride_sub = SpanTrait::slice(data_stride, 1, 
data_stride.len() - 1); + + let border1 = SpanTrait::slice(border, 1, num_dims - 1); + let border2 = SpanTrait::slice(border, num_dims + 1, num_dims - 1); + let mut border = ArrayTrait::new(); + border.append_span(border1); + border.append_span(border2); + + let r = gs_cubic_interpolation_nd_with_x( + sub_data, data_dims_sub, data_stride_sub, sub_x, border.span(), padding_mode + ); - res1d.append(r); - i += 1; - }; + res1d.append(r); + i += 1; + }; gs_cubic_interpolation_1d_with_x( res1d.span(), *x.at(0), array![*border.at(0), *border.at(num_dims)].span(), padding_mode @@ -408,26 +411,27 @@ fn gs_linear_interpolation_nd_with_x< let mut res1d: Array = array![]; let mut i = 0; - while i != *data_dims.at(0) { - let sub_data = SpanTrait::slice(data, i * *data_stride.at(0), *data_stride.at(0)); - let sub_x = SpanTrait::slice(x, 1, x.len() - 1); - - let data_dims_sub = SpanTrait::slice(data_dims, 1, data_dims.len() - 1); - let data_stride_sub = SpanTrait::slice(data_stride, 1, data_stride.len() - 1); - - let border1 = SpanTrait::slice(border, 1, num_dims - 1); - let border2 = SpanTrait::slice(border, num_dims + 1, num_dims - 1); - let mut border = ArrayTrait::new(); - border.append_span(border1); - border.append_span(border2); - - let r = gs_linear_interpolation_nd_with_x( - sub_data, data_dims_sub, data_stride_sub, sub_x, border.span(), padding_mode - ); + while i != *data_dims + .at(0) { + let sub_data = SpanTrait::slice(data, i * *data_stride.at(0), *data_stride.at(0)); + let sub_x = SpanTrait::slice(x, 1, x.len() - 1); + + let data_dims_sub = SpanTrait::slice(data_dims, 1, data_dims.len() - 1); + let data_stride_sub = SpanTrait::slice(data_stride, 1, data_stride.len() - 1); + + let border1 = SpanTrait::slice(border, 1, num_dims - 1); + let border2 = SpanTrait::slice(border, num_dims + 1, num_dims - 1); + let mut border = ArrayTrait::new(); + border.append_span(border1); + border.append_span(border2); + + let r = gs_linear_interpolation_nd_with_x( + sub_data, 
data_dims_sub, data_stride_sub, sub_x, border.span(), padding_mode + ); - res1d.append(r); - i += 1; - }; + res1d.append(r); + i += 1; + }; gs_linear_interpolation_1d_with_x( res1d.span(), *x.at(0), array![*border.at(0), *border.at(num_dims)].span(), padding_mode @@ -586,20 +590,21 @@ fn rint< let two: T = NumberTrait::one() + NumberTrait::one(); let mut i = 0; - while i != data.len() { - let x = *data.at(i); - let mut round = NumberTrait::round(x); - - let diff = round - x; - if diff == NumberTrait::half() { - if round % two != NumberTrait::zero() { - round -= NumberTrait::one() + while i != data + .len() { + let x = *data.at(i); + let mut round = NumberTrait::round(x); + + let diff = round - x; + if diff == NumberTrait::half() { + if round % two != NumberTrait::zero() { + round -= NumberTrait::one() + } } - } - rint.append(round); - i += 1; - }; + rint.append(round); + i += 1; + }; rint.span() } @@ -779,12 +784,13 @@ fn gs_denormalize_coordinates< let mut x: Array = array![]; let mut i = 0; - while i != n.len() { - let v = *n.at(i); - let dim = *dims.at(i); - x.append(gs_denormalize(v, dim, align_corner)); - i += 1; - }; + while i != n + .len() { + let v = *n.at(i); + let dim = *dims.at(i); + x.append(gs_denormalize(v, dim, align_corner)); + i += 1; + }; x.span() } @@ -851,14 +857,15 @@ fn cartesian(mut arrays: Span>,) -> Span> { let mut m = n; let mut i = 0; - while i != arrays.len() { - m = m / (*(arrays.at(i))).len(); - let mut out = repeat(*(arrays.at(i)), m); - out = repeat_2(out, size_arrays, i); - - output_arrays.append(out); - i += 1; - }; + while i != arrays + .len() { + m = m / (*(arrays.at(i))).len(); + let mut out = repeat(*(arrays.at(i)), m); + out = repeat_2(out, size_arrays, i); + + output_arrays.append(out); + i += 1; + }; let output_arrays = output_arrays.span(); @@ -884,15 +891,16 @@ fn repeat_2(mut array: Array, size_array: Span, index: usize) -> A let mut i = 0; while i != index { let mut j = 1; - while j != *size_array.at(index - 1 - i) { - 
let mut k = 0; - while k != size { - array.append(*array.at(k)); - k += 1; - }; + while j != *size_array + .at(index - 1 - i) { + let mut k = 0; + while k != size { + array.append(*array.at(k)); + k += 1; + }; - j += 1; - }; + j += 1; + }; size = size * *size_array.at(index - 1 - i); i += 1; @@ -904,15 +912,16 @@ fn repeat_2(mut array: Array, size_array: Span, index: usize) -> A fn repeat(array: Span, m: usize,) -> Array { let mut out: Array = array![]; let mut j = 0; - while j != array.len() { - let mut k = 0; - while k != m { - out.append(*array.at(j)); - k += 1; - }; + while j != array + .len() { + let mut k = 0; + while k != m { + out.append(*array.at(j)); + k += 1; + }; - j += 1; - }; + j += 1; + }; out } diff --git a/src/operators/nn/functional/roi_align.cairo b/src/operators/nn/functional/roi_align.cairo new file mode 100644 index 000000000..1f82e7c99 --- /dev/null +++ b/src/operators/nn/functional/roi_align.cairo @@ -0,0 +1,450 @@ +use orion::numbers::NumberTrait; +use orion::operators::tensor::{TensorTrait, Tensor}; + +#[derive(Copy, Drop, Destruct)] +struct PreCalc { + pos1: usize, + pos2: usize, + pos3: usize, + pos4: usize, + w1: T, + w2: T, + w3: T, + w4: T, +} + +#[derive(Copy, Drop)] +enum MODE { + AVG, + MAX, +} + +#[derive(Copy, Drop)] +enum TRANSFORMATION_MODE { + HALF_PIXEL, + OUTPUT_HALF_PIXEL, +} + +/// Cf: TensorTrait::roi_align docstring +fn roi_align< + T, + MAG, + +TensorTrait, + +NumberTrait, + +PartialOrd, + +PartialEq, + +Copy, + +Drop, + +TryInto, + +Into, + +AddEq, + +Add, + +Div, + +Mul, + +Sub, + +Neg, + +DivEq, +>( + X: @Tensor, + roi: @Tensor, + batch_indices: @Tensor, + coordinate_transformation_mode: Option, + mode: Option, + output_height: Option, + output_width: Option, + sampling_ratio: Option, + spatial_scale: Option, +) -> Tensor { + let num_channels = *(*X).shape.at(1); + let num_rois = *(*batch_indices).shape.at(0); + let num_roi_cols = *(*roi).shape.at(1); + + let output_height = match output_height { + 
Option::Some(output_height) => output_height, + Option::None => 1, + }; + + let output_width = match output_width { + Option::Some(output_width) => output_width, + Option::None => 1, + }; + + let coordinate_transformation_mode = match coordinate_transformation_mode { + Option::Some(coordinate_transformation_mode) => coordinate_transformation_mode, + Option::None => TRANSFORMATION_MODE::HALF_PIXEL, + }; + + let half_pixel = match coordinate_transformation_mode { + TRANSFORMATION_MODE::HALF_PIXEL => true, + TRANSFORMATION_MODE::OUTPUT_HALF_PIXEL => false, + }; + + let y_dims = array![num_rois, num_channels, output_height, output_width].span(); + + let y_data = roi_align_forward( + y_dims, + (*X).data, + spatial_scale, + *(*X).shape.at(2), + *(*X).shape.at(3), + sampling_ratio, + (*roi).data, + num_roi_cols, + mode, + half_pixel, + (*batch_indices).data + ); + + return TensorTrait::new(y_dims, y_data); +} + + +fn roi_align_forward< + T, + MAG, + +TensorTrait, + +NumberTrait, + +PartialOrd, + +PartialEq, + +Copy, + +Drop, + +TryInto, + +Into, + +AddEq, + +Add, + +Div, + +Mul, + +Sub, + +Neg, + +DivEq, +>( + output_shape: Span, + bottom_data: Span, + spatial_scale: Option, + height: usize, + width: usize, + sampling_ratio: Option, + bottom_rois: Span, + num_roi_cols: usize, + mode: Option, + half_pixel: bool, + batch_indices_ptr: Span, +) -> Span { + let n_rois = *output_shape.at(0); + let channels = *output_shape.at(1); + let pooled_height = *output_shape.at(2); + let pooled_width = *output_shape.at(3); + + let mut top_data = ArrayTrait::new(); + + let spatial_scale = match spatial_scale { + Option::Some(spatial_scale) => spatial_scale, + Option::None => NumberTrait::one(), + }; + + let sampling_ratio = match sampling_ratio { + Option::Some(sampling_ratio) => sampling_ratio, + Option::None => NumberTrait::zero(), + }; + + let mode = match mode { + Option::Some(mode) => mode, + Option::None => MODE::AVG, + }; + + let mut n = 0; + + while n != n_rois { + let 
offset_bottom_rois = n * num_roi_cols; + let roi_batch_ind = *batch_indices_ptr.at(n); + + let offset: T = if half_pixel { + NumberTrait::half() + } else { + NumberTrait::zero() + }; + + let roi_start_w = *bottom_rois.at(offset_bottom_rois + 0) * spatial_scale - offset; + let roi_start_h = *bottom_rois.at(offset_bottom_rois + 1) * spatial_scale - offset; + let roi_end_w = *bottom_rois.at(offset_bottom_rois + 2) * spatial_scale - offset; + let roi_end_h = *bottom_rois.at(offset_bottom_rois + 3) * spatial_scale - offset; + + let mut roi_width = roi_end_w - roi_start_w; + let mut roi_height = roi_end_h - roi_start_h; + + if !half_pixel { + roi_width = NumberTrait::max(roi_width, NumberTrait::one()); + roi_height = NumberTrait::max(roi_height, NumberTrait::one()); + } + + let bin_size_h = roi_height / NumberTrait::new_unscaled(pooled_height.into(), false); + let bin_size_w = roi_width / NumberTrait::new_unscaled(pooled_width.into(), false); + + let roi_bin_grid_h: usize = if sampling_ratio > NumberTrait::zero() { + sampling_ratio.try_into().unwrap() + } else { + NumberTrait::< + T + >::ceil(roi_height / NumberTrait::new_unscaled(pooled_height.into(), false)) + .try_into() + .unwrap() + }; + + let roi_bin_grid_w: usize = if sampling_ratio > NumberTrait::zero() { + sampling_ratio.try_into().unwrap() + } else { + NumberTrait::< + T + >::ceil(roi_width / NumberTrait::new_unscaled(pooled_width.into(), false)) + .try_into() + .unwrap() + }; + + let count = NumberTrait::max(roi_bin_grid_h * roi_bin_grid_w, 1); + let mut pre_calc = ArrayTrait::new(); + let pre_calc = pre_calc_for_bilinear_interpolate( + height, + width, + pooled_height, + pooled_width, + roi_bin_grid_h, + roi_bin_grid_w, + roi_start_h, + roi_start_w, + bin_size_h, + bin_size_w, + roi_bin_grid_h, + roi_bin_grid_w, + ref pre_calc, + ); + + let mut c = 0; + while c != channels { + let offset_bottom_data = (roi_batch_ind * channels + c) * height * width; + + let mut pre_calc_index = 0; + let mut ph = 0; + while ph 
!= pooled_height { + let mut pw = 0; + while pw != pooled_width { + let mut output_val = NumberTrait::zero(); + + match mode { + MODE::AVG => { + let mut _iy = 0; + while _iy != roi_bin_grid_h { + let mut _ix = 0; + while _ix != roi_bin_grid_w { + let pc = *pre_calc.at(pre_calc_index); + + output_val += + (pc.w1 * *bottom_data.at(offset_bottom_data + pc.pos1) + + pc.w2 * *bottom_data.at(offset_bottom_data + pc.pos2) + + pc.w3 * *bottom_data.at(offset_bottom_data + pc.pos3) + + pc.w4 + * *bottom_data.at(offset_bottom_data + pc.pos4)); + pre_calc_index += 1; + _ix += 1; + }; + + _iy += 1; + }; + + output_val /= NumberTrait::new_unscaled(count.into(), false); + top_data.append(output_val); + }, + MODE::MAX => { + let mut max_flag = false; + let mut _iy = 0; + while _iy != roi_bin_grid_h { + let mut _ix = 0; + let mut val = NumberTrait::zero(); + + while _ix != roi_bin_grid_w { + let pc = *pre_calc.at(pre_calc_index); + + val = + NumberTrait::max( + NumberTrait::max( + pc.w1 + * *bottom_data.at(offset_bottom_data + pc.pos1), + pc.w2 + * *bottom_data.at(offset_bottom_data + pc.pos2) + ), + NumberTrait::max( + pc.w3 + * *bottom_data.at(offset_bottom_data + pc.pos3), + pc.w4 + * *bottom_data.at(offset_bottom_data + pc.pos4) + ) + ); + if !max_flag { + output_val = val; + max_flag = true; + } else { + output_val = NumberTrait::max(output_val, val); + } + pre_calc_index += 1; + + _ix += 1; + }; + _iy += 1; + }; + top_data.append(output_val); + } + } + pw += 1; + }; + ph += 1; + }; + + c += 1; + }; + n += 1; + }; + + return top_data.span(); +} + + +fn pre_calc_for_bilinear_interpolate< + T, + MAG, + +TensorTrait, + +NumberTrait, + +PartialOrd, + +PartialEq, + +Copy, + +Drop, + +TryInto, + +Into, + +AddEq, + +Add, + +Div, + +Mul, + +Sub, + +Neg, +>( + height: usize, + width: usize, + pooled_height: usize, + pooled_width: usize, + iy_upper: usize, + ix_upper: usize, + roi_start_h: T, + roi_start_w: T, + bin_size_h: T, + bin_size_w: T, + roi_bin_grid_h: usize, + roi_bin_grid_w: 
usize, + ref pre_calc: Array> +) -> Span> { + let mut pre_calc_index = 0; + + let roi_bin_grid_h = NumberTrait::new_unscaled(roi_bin_grid_h.into(), false); + let roi_bin_grid_w = NumberTrait::new_unscaled(roi_bin_grid_w.into(), false); + + let height = NumberTrait::new_unscaled(height.into(), false); + let width = NumberTrait::new_unscaled(width.into(), false); + + let mut ph = 0; + while ph != pooled_height { + let mut pw = 0; + while pw != pooled_width { + let mut iy: usize = 0; + while iy != iy_upper { + let yy = roi_start_h + + NumberTrait::new_unscaled(ph.into(), false) * bin_size_h + + (NumberTrait::new_unscaled(iy.into(), false) + NumberTrait::half()) + * bin_size_h + / roi_bin_grid_h; + + let mut ix = 0; + while ix != ix_upper { + let xx = roi_start_w + + NumberTrait::new_unscaled(pw.into(), false) * bin_size_w + + (NumberTrait::new_unscaled(ix.into(), false) + NumberTrait::half()) + * bin_size_w + / roi_bin_grid_w; + let mut x: T = xx; + let mut y: T = yy; + + if y < -NumberTrait::one() + || y > height + || x < -NumberTrait::one() + || x > width { + let pc: PreCalc = PreCalc { + pos1: 0, + pos2: 0, + pos3: 0, + pos4: 0, + w1: NumberTrait::zero(), + w2: NumberTrait::zero(), + w3: NumberTrait::zero(), + w4: NumberTrait::zero() + }; + pre_calc.append(pc); + pre_calc_index += 1; + } else { + y = NumberTrait::max(y, NumberTrait::zero()); + x = NumberTrait::max(x, NumberTrait::zero()); + + let mut y_low = NumberTrait::floor(y); + let mut x_low = NumberTrait::floor(x); + + let mut y_high = y_low; + let mut x_high = x_low; + + if y_low >= height - NumberTrait::one() { + y_low = height - NumberTrait::one(); + y_high = y_low; + y = y_low; + } else { + y_high += NumberTrait::one(); + }; + + if x_low >= width - NumberTrait::one() { + x_low = width - NumberTrait::one(); + x_high = x_low; + x = x_low; + } else { + x_high += NumberTrait::one(); + }; + + let ly = y - y_low; + let lx = x - x_low; + let hy = NumberTrait::one() - ly; + let hx = NumberTrait::one() - lx; + let 
w1 = hy * hx; + let w2 = hy * lx; + let w3 = ly * hx; + let w4 = ly * lx; + + let pc: PreCalc = PreCalc { + pos1: (y_low * width + x_low).try_into().unwrap(), + pos2: (y_low * width + x_high).try_into().unwrap(), + pos3: (y_high * width + x_low).try_into().unwrap(), + pos4: (y_high * width + x_high).try_into().unwrap(), + w1: w1, + w2: w2, + w3: w3, + w4: w4 + }; + pre_calc.append(pc); + + pre_calc_index += 1; + } + + ix += 1; + }; + + iy += 1; + }; + pw += 1; + }; + ph += 1; + }; + + return pre_calc.span(); +} + diff --git a/src/operators/nn/implementations/nn_fp16x16.cairo b/src/operators/nn/implementations/nn_fp16x16.cairo index cd65d0fd0..444a9a102 100644 --- a/src/operators/nn/implementations/nn_fp16x16.cairo +++ b/src/operators/nn/implementations/nn_fp16x16.cairo @@ -143,4 +143,30 @@ impl FP16x16NN of NNTrait { ) -> Tensor { functional::conv::conv(X, W, B, auto_pad, dilations, group, kernel_shape, pads, strides) } + + fn roi_align( + X: @Tensor, + roi: @Tensor, + batch_indices: @Tensor, + coordinate_transformation_mode: Option< + orion::operators::nn::functional::roi_align::TRANSFORMATION_MODE + >, + mode: Option, + output_height: Option, + output_width: Option, + sampling_ratio: Option, + spatial_scale: Option, + ) -> Tensor { + functional::roi_align::roi_align( + X, + roi, + batch_indices, + coordinate_transformation_mode, + mode, + output_height, + output_width, + sampling_ratio, + spatial_scale, + ) + } } diff --git a/src/operators/nn/implementations/nn_fp32x32.cairo b/src/operators/nn/implementations/nn_fp32x32.cairo index 4baa425b7..247d82a63 100644 --- a/src/operators/nn/implementations/nn_fp32x32.cairo +++ b/src/operators/nn/implementations/nn_fp32x32.cairo @@ -137,4 +137,30 @@ impl FP32x32NN of NNTrait { ) -> Tensor { functional::conv::conv(X, W, B, auto_pad, dilations, group, kernel_shape, pads, strides) } + + fn roi_align( + X: @Tensor, + roi: @Tensor, + batch_indices: @Tensor, + coordinate_transformation_mode: Option< + 
orion::operators::nn::functional::roi_align::TRANSFORMATION_MODE + >, + mode: Option, + output_height: Option, + output_width: Option, + sampling_ratio: Option, + spatial_scale: Option, + ) -> Tensor { + functional::roi_align::roi_align( + X, + roi, + batch_indices, + coordinate_transformation_mode, + mode, + output_height, + output_width, + sampling_ratio, + spatial_scale, + ) + } } diff --git a/src/operators/nn/implementations/nn_fp64x64.cairo b/src/operators/nn/implementations/nn_fp64x64.cairo index 588387700..5dba37875 100644 --- a/src/operators/nn/implementations/nn_fp64x64.cairo +++ b/src/operators/nn/implementations/nn_fp64x64.cairo @@ -137,4 +137,30 @@ impl FP64x64NN of NNTrait { ) -> Tensor { functional::conv::conv(X, W, B, auto_pad, dilations, group, kernel_shape, pads, strides) } + + fn roi_align( + X: @Tensor, + roi: @Tensor, + batch_indices: @Tensor, + coordinate_transformation_mode: Option< + orion::operators::nn::functional::roi_align::TRANSFORMATION_MODE + >, + mode: Option, + output_height: Option, + output_width: Option, + sampling_ratio: Option, + spatial_scale: Option, + ) -> Tensor { + functional::roi_align::roi_align( + X, + roi, + batch_indices, + coordinate_transformation_mode, + mode, + output_height, + output_width, + sampling_ratio, + spatial_scale, + ) + } } diff --git a/src/operators/nn/implementations/nn_fp8x23.cairo b/src/operators/nn/implementations/nn_fp8x23.cairo index bb8d73436..d4e7f624e 100644 --- a/src/operators/nn/implementations/nn_fp8x23.cairo +++ b/src/operators/nn/implementations/nn_fp8x23.cairo @@ -139,4 +139,28 @@ impl FP8x23NN of NNTrait { ) -> Tensor { functional::conv::conv(X, W, B, auto_pad, dilations, group, kernel_shape, pads, strides) } + + fn roi_align( + X: @Tensor, + roi: @Tensor, + batch_indices: @Tensor, + coordinate_transformation_mode: Option, + mode: Option, + output_height: Option, + output_width: Option, + sampling_ratio: Option, + spatial_scale: Option, + ) -> Tensor { + 
functional::roi_align::roi_align( + X, + roi, + batch_indices, + coordinate_transformation_mode, + mode, + output_height, + output_width, + sampling_ratio, + spatial_scale, + ) + } } diff --git a/src/operators/nn/implementations/nn_i32.cairo b/src/operators/nn/implementations/nn_i32.cairo index 49200fe0b..81e887d87 100644 --- a/src/operators/nn/implementations/nn_i32.cairo +++ b/src/operators/nn/implementations/nn_i32.cairo @@ -130,4 +130,20 @@ impl I32NN of NNTrait { ) -> Tensor { functional::conv::conv(X, W, B, auto_pad, dilations, group, kernel_shape, pads, strides) } + + fn roi_align( + X: @Tensor, + roi: @Tensor, + batch_indices: @Tensor, + coordinate_transformation_mode: Option< + orion::operators::nn::functional::roi_align::TRANSFORMATION_MODE + >, + mode: Option, + output_height: Option, + output_width: Option, + sampling_ratio: Option, + spatial_scale: Option, + ) -> Tensor { + panic(array!['not supported!']) + } } diff --git a/src/operators/nn/implementations/nn_i8.cairo b/src/operators/nn/implementations/nn_i8.cairo index f481fdc80..b50b9b3d7 100644 --- a/src/operators/nn/implementations/nn_i8.cairo +++ b/src/operators/nn/implementations/nn_i8.cairo @@ -130,4 +130,20 @@ impl I8NN of NNTrait { ) -> Tensor { functional::conv::conv(X, W, B, auto_pad, dilations, group, kernel_shape, pads, strides) } + + fn roi_align( + X: @Tensor, + roi: @Tensor, + batch_indices: @Tensor, + coordinate_transformation_mode: Option< + orion::operators::nn::functional::roi_align::TRANSFORMATION_MODE + >, + mode: Option, + output_height: Option, + output_width: Option, + sampling_ratio: Option, + spatial_scale: Option, + ) -> Tensor { + panic(array!['not supported!']) + } } diff --git a/src/operators/nn/implementations/nn_u32.cairo b/src/operators/nn/implementations/nn_u32.cairo index ec0d6dedc..7ec5501ad 100644 --- a/src/operators/nn/implementations/nn_u32.cairo +++ b/src/operators/nn/implementations/nn_u32.cairo @@ -130,4 +130,20 @@ impl U32NN of NNTrait { ) -> Tensor { 
functional::conv::conv(X, W, B, auto_pad, dilations, group, kernel_shape, pads, strides) } + + fn roi_align( + X: @Tensor, + roi: @Tensor, + batch_indices: @Tensor, + coordinate_transformation_mode: Option< + orion::operators::nn::functional::roi_align::TRANSFORMATION_MODE + >, + mode: Option, + output_height: Option, + output_width: Option, + sampling_ratio: Option, + spatial_scale: Option, + ) -> Tensor { + panic(array!['not supported!']) + } } diff --git a/src/operators/tensor/core.cairo b/src/operators/tensor/core.cairo index 02f9cc6e4..6f2eef605 100644 --- a/src/operators/tensor/core.cairo +++ b/src/operators/tensor/core.cairo @@ -683,7 +683,7 @@ trait TensorTrait { axes: Option>, keepdims: Option, noop_with_empty_axes: Option - ) -> Tensor; + ) -> Tensor; /// # tensor.argmax /// /// ```rust diff --git a/src/operators/tensor/helpers.cairo b/src/operators/tensor/helpers.cairo index 550ff45c5..8482e5cdf 100644 --- a/src/operators/tensor/helpers.cairo +++ b/src/operators/tensor/helpers.cairo @@ -52,32 +52,33 @@ fn check_compatibility(mut shape_1: Span, mut shape_2: Span) { let mut iter_2 = shape_2.len(); // Iterate while there are dimensions left in either shape - while iter_1 > 0 || iter_2 > 0 { - // Get the current dimension for each shape, defaulting to 1 if we've run out of dimensions - let dim_1 = if iter_1 > 0 { - *shape_1[iter_1 - 1] - } else { - 1 - }; - let dim_2 = if iter_2 > 0 { - *shape_2[iter_2 - 1] - } else { - 1 - }; + while iter_1 > 0 + || iter_2 > 0 { + // Get the current dimension for each shape, defaulting to 1 if we've run out of dimensions + let dim_1 = if iter_1 > 0 { + *shape_1[iter_1 - 1] + } else { + 1 + }; + let dim_2 = if iter_2 > 0 { + *shape_2[iter_2 - 1] + } else { + 1 + }; - // Check the broadcasting rule for the current dimension - if dim_1 != dim_2 && dim_1 != 1 && dim_2 != 1 { - panic(array!['tensors shape must match']); - } + // Check the broadcasting rule for the current dimension + if dim_1 != dim_2 && dim_1 != 1 && dim_2 != 1 
{ + panic(array!['tensors shape must match']); + } - // Move to the next dimension - if iter_1 > 0 { - iter_1 -= 1; - } - if iter_2 > 0 { - iter_2 -= 1; + // Move to the next dimension + if iter_1 > 0 { + iter_1 -= 1; + } + if iter_2 > 0 { + iter_2 -= 1; + } } - } } /// Computes the index in the broadcasted tensor corresponding to the given indices and shape. @@ -250,17 +251,18 @@ fn combine_indices(mut output_indices: Span, axis_index: usize, axis: usi let mut result: Array = array![]; let mut n: usize = 0; - while n != output_indices.len() + 1 { - if n == axis { - result.append(axis_index); - } else if n > axis { - result.append(*output_indices[n - 1_usize]); - } else { - result.append(*output_indices[n]); - } + while n != output_indices.len() + + 1 { + if n == axis { + result.append(axis_index); + } else if n > axis { + result.append(*output_indices[n - 1_usize]); + } else { + result.append(*output_indices[n]); + } - n += 1; - }; + n += 1; + }; result.span() } @@ -313,13 +315,15 @@ fn broadcast_shape(mut shape1: Span, mut shape2: Span) -> Span = array![]; - while !shape1.is_empty() || !shape2.is_empty() { - let dim1 = *shape1.pop_back().unwrap_or(@1); - let dim2 = *shape2.pop_back().unwrap_or(@1); + while !shape1.is_empty() + || !shape2 + .is_empty() { + let dim1 = *shape1.pop_back().unwrap_or(@1); + let dim2 = *shape2.pop_back().unwrap_or(@1); - let broadcasted_dim = u32_max(dim1, dim2); - result.append(broadcasted_dim); - }; + let broadcasted_dim = u32_max(dim1, dim2); + result.append(broadcasted_dim); + }; result.reverse().span() } diff --git a/src/operators/tensor/implementations/tensor_complex64.cairo b/src/operators/tensor/implementations/tensor_complex64.cairo index 8acb0891e..24e64c923 100644 --- a/src/operators/tensor/implementations/tensor_complex64.cairo +++ b/src/operators/tensor/implementations/tensor_complex64.cairo @@ -89,10 +89,7 @@ impl Complex64Tensor of TensorTrait { } fn argmax( - self: @Tensor, - axis: i32, - keepdims: Option, - 
select_last_index: Option + self: @Tensor, axis: i32, keepdims: Option, select_last_index: Option ) -> Tensor { panic(array!['not supported!']) } diff --git a/src/operators/tensor/implementations/tensor_fp16x16.cairo b/src/operators/tensor/implementations/tensor_fp16x16.cairo index 27f853df5..58b36a4bc 100644 --- a/src/operators/tensor/implementations/tensor_fp16x16.cairo +++ b/src/operators/tensor/implementations/tensor_fp16x16.cairo @@ -71,7 +71,9 @@ impl FP16x16Tensor of TensorTrait { unravel_index(index, *self.shape) } - fn reshape(self: @Tensor, target_shape: Span, allowzero: bool) -> Tensor { + fn reshape( + self: @Tensor, target_shape: Span, allowzero: bool + ) -> Tensor { reshape(self, target_shape, allowzero) } @@ -353,9 +355,7 @@ impl FP16x16Tensor of TensorTrait { core_tensor::slice::(self, starts, ends, axes, steps) } - fn gather( - self: @Tensor, indices: Tensor, axis: Option - ) -> Tensor { + fn gather(self: @Tensor, indices: Tensor, axis: Option) -> Tensor { math::gather::gather(self, indices, axis) } @@ -765,17 +765,19 @@ fn relative_eq(lhs: @FP16x16, rhs: @FP16x16) -> bool { fn tensor_eq(mut lhs: Tensor, mut rhs: Tensor,) -> bool { let mut is_eq = true; - while lhs.shape.len() != 0 && is_eq { - is_eq = lhs.shape.pop_front().unwrap() == rhs.shape.pop_front().unwrap(); - }; + while lhs.shape.len() != 0 + && is_eq { + is_eq = lhs.shape.pop_front().unwrap() == rhs.shape.pop_front().unwrap(); + }; if !is_eq { return false; } - while lhs.data.len() != 0 && is_eq { - is_eq = relative_eq(lhs.data.pop_front().unwrap(), rhs.data.pop_front().unwrap()); - }; + while lhs.data.len() != 0 + && is_eq { + is_eq = relative_eq(lhs.data.pop_front().unwrap(), rhs.data.pop_front().unwrap()); + }; is_eq } diff --git a/src/operators/tensor/implementations/tensor_fp16x16wide.cairo b/src/operators/tensor/implementations/tensor_fp16x16wide.cairo index 61485bae6..89eb2d09b 100644 --- a/src/operators/tensor/implementations/tensor_fp16x16wide.cairo +++ 
b/src/operators/tensor/implementations/tensor_fp16x16wide.cairo @@ -75,7 +75,9 @@ impl FP16x16WTensor of TensorTrait { unravel_index(index, *self.shape) } - fn reshape(self: @Tensor, target_shape: Span, allowzero: bool) -> Tensor { + fn reshape( + self: @Tensor, target_shape: Span, allowzero: bool + ) -> Tensor { reshape(self, target_shape, allowzero) } @@ -93,10 +95,7 @@ impl FP16x16WTensor of TensorTrait { } fn argmax( - self: @Tensor, - axis: i32, - keepdims: Option, - select_last_index: Option + self: @Tensor, axis: i32, keepdims: Option, select_last_index: Option ) -> Tensor { math::argmax::argmax(self, axis, keepdims, select_last_index) } @@ -724,17 +723,19 @@ fn relative_eq(lhs: @FP16x16W, rhs: @FP16x16W) -> bool { fn tensor_eq(mut lhs: Tensor, mut rhs: Tensor,) -> bool { let mut is_eq = true; - while lhs.shape.len() != 0 && is_eq { - is_eq = lhs.shape.pop_front().unwrap() == rhs.shape.pop_front().unwrap(); - }; + while lhs.shape.len() != 0 + && is_eq { + is_eq = lhs.shape.pop_front().unwrap() == rhs.shape.pop_front().unwrap(); + }; if !is_eq { return false; } - while lhs.data.len() != 0 && is_eq { - is_eq = relative_eq(lhs.data.pop_front().unwrap(), rhs.data.pop_front().unwrap()); - }; + while lhs.data.len() != 0 + && is_eq { + is_eq = relative_eq(lhs.data.pop_front().unwrap(), rhs.data.pop_front().unwrap()); + }; is_eq } diff --git a/src/operators/tensor/implementations/tensor_fp32x32.cairo b/src/operators/tensor/implementations/tensor_fp32x32.cairo index 6ea3c7d94..942680b81 100644 --- a/src/operators/tensor/implementations/tensor_fp32x32.cairo +++ b/src/operators/tensor/implementations/tensor_fp32x32.cairo @@ -68,7 +68,9 @@ impl FP32x32Tensor of TensorTrait { unravel_index(index, *self.shape) } - fn reshape(self: @Tensor, target_shape: Span, allowzero: bool) -> Tensor { + fn reshape( + self: @Tensor, target_shape: Span, allowzero: bool + ) -> Tensor { reshape(self, target_shape, allowzero) } @@ -350,9 +352,7 @@ impl FP32x32Tensor of TensorTrait { 
core_tensor::slice::(self, starts, ends, axes, steps) } - fn gather( - self: @Tensor, indices: Tensor, axis: Option - ) -> Tensor { + fn gather(self: @Tensor, indices: Tensor, axis: Option) -> Tensor { math::gather::gather(self, indices, axis) } @@ -771,17 +771,19 @@ fn relative_eq(lhs: @FP32x32, rhs: @FP32x32) -> bool { fn tensor_eq(mut lhs: Tensor, mut rhs: Tensor,) -> bool { let mut is_eq = true; - while lhs.shape.len() != 0 && is_eq { - is_eq = lhs.shape.pop_front().unwrap() == rhs.shape.pop_front().unwrap(); - }; + while lhs.shape.len() != 0 + && is_eq { + is_eq = lhs.shape.pop_front().unwrap() == rhs.shape.pop_front().unwrap(); + }; if !is_eq { return false; } - while lhs.data.len() != 0 && is_eq { - is_eq = relative_eq(lhs.data.pop_front().unwrap(), rhs.data.pop_front().unwrap()); - }; + while lhs.data.len() != 0 + && is_eq { + is_eq = relative_eq(lhs.data.pop_front().unwrap(), rhs.data.pop_front().unwrap()); + }; is_eq } diff --git a/src/operators/tensor/implementations/tensor_fp64x64.cairo b/src/operators/tensor/implementations/tensor_fp64x64.cairo index af955fff1..be1cd04c3 100644 --- a/src/operators/tensor/implementations/tensor_fp64x64.cairo +++ b/src/operators/tensor/implementations/tensor_fp64x64.cairo @@ -68,7 +68,9 @@ impl FP64x64Tensor of TensorTrait { unravel_index(index, *self.shape) } - fn reshape(self: @Tensor, target_shape: Span, allowzero: bool) -> Tensor { + fn reshape( + self: @Tensor, target_shape: Span, allowzero: bool + ) -> Tensor { reshape(self, target_shape, allowzero) } @@ -350,9 +352,7 @@ impl FP64x64Tensor of TensorTrait { core_tensor::slice::(self, starts, ends, axes, steps) } - fn gather( - self: @Tensor, indices: Tensor, axis: Option - ) -> Tensor { + fn gather(self: @Tensor, indices: Tensor, axis: Option) -> Tensor { math::gather::gather(self, indices, axis) } @@ -771,17 +771,19 @@ fn relative_eq(lhs: @FP64x64, rhs: @FP64x64) -> bool { fn tensor_eq(mut lhs: Tensor, mut rhs: Tensor,) -> bool { let mut is_eq = true; - while 
lhs.shape.len() != 0 && is_eq { - is_eq = lhs.shape.pop_front().unwrap() == rhs.shape.pop_front().unwrap(); - }; + while lhs.shape.len() != 0 + && is_eq { + is_eq = lhs.shape.pop_front().unwrap() == rhs.shape.pop_front().unwrap(); + }; if !is_eq { return false; } - while lhs.shape.len() != 0 && is_eq { - is_eq = relative_eq(lhs.data.pop_front().unwrap(), rhs.data.pop_front().unwrap()); - }; + while lhs.shape.len() != 0 + && is_eq { + is_eq = relative_eq(lhs.data.pop_front().unwrap(), rhs.data.pop_front().unwrap()); + }; is_eq } diff --git a/src/operators/tensor/implementations/tensor_fp8x23wide.cairo b/src/operators/tensor/implementations/tensor_fp8x23wide.cairo index ef65871d4..bee949f2c 100644 --- a/src/operators/tensor/implementations/tensor_fp8x23wide.cairo +++ b/src/operators/tensor/implementations/tensor_fp8x23wide.cairo @@ -71,7 +71,9 @@ impl FP8x23WTensor of TensorTrait { unravel_index(index, *self.shape) } - fn reshape(self: @Tensor, target_shape: Span, allowzero: bool) -> Tensor { + fn reshape( + self: @Tensor, target_shape: Span, allowzero: bool + ) -> Tensor { reshape(self, target_shape, allowzero) } @@ -304,9 +306,7 @@ impl FP8x23WTensor of TensorTrait { core_tensor::slice::(self, starts, ends, axes, steps) } - fn gather( - self: @Tensor, indices: Tensor, axis: Option - ) -> Tensor { + fn gather(self: @Tensor, indices: Tensor, axis: Option) -> Tensor { math::gather::gather(self, indices, axis) } @@ -725,17 +725,19 @@ fn relative_eq(lhs: @FP8x23W, rhs: @FP8x23W) -> bool { fn tensor_eq(mut lhs: Tensor, mut rhs: Tensor,) -> bool { let mut is_eq = true; - while lhs.shape.len() != 0 && is_eq { - is_eq = lhs.shape.pop_front().unwrap() == rhs.shape.pop_front().unwrap(); - }; + while lhs.shape.len() != 0 + && is_eq { + is_eq = lhs.shape.pop_front().unwrap() == rhs.shape.pop_front().unwrap(); + }; if !is_eq { return false; } - while lhs.data.len() != 0 && is_eq { - is_eq = relative_eq(lhs.data.pop_front().unwrap(), rhs.data.pop_front().unwrap()); - }; + while 
lhs.data.len() != 0 + && is_eq { + is_eq = relative_eq(lhs.data.pop_front().unwrap(), rhs.data.pop_front().unwrap()); + }; is_eq } diff --git a/src/operators/tensor/implementations/tensor_i32.cairo b/src/operators/tensor/implementations/tensor_i32.cairo index 924a6b1fd..4caa2ad24 100644 --- a/src/operators/tensor/implementations/tensor_i32.cairo +++ b/src/operators/tensor/implementations/tensor_i32.cairo @@ -435,9 +435,7 @@ impl I32Tensor of TensorTrait { panic(array!['not supported!']) } - fn gather_elements( - self: @Tensor, indices: Tensor, axis: Option - ) -> Tensor { + fn gather_elements(self: @Tensor, indices: Tensor, axis: Option) -> Tensor { math::gather_elements::gather_elements(self, indices, axis) } @@ -716,17 +714,19 @@ impl I32TensorPartialOrd of PartialOrd> { fn tensor_eq(mut lhs: Tensor, mut rhs: Tensor,) -> bool { let mut is_eq = true; - while lhs.shape.len() != 0 && is_eq { - is_eq = lhs.shape.pop_front().unwrap() == rhs.shape.pop_front().unwrap(); - }; + while lhs.shape.len() != 0 + && is_eq { + is_eq = lhs.shape.pop_front().unwrap() == rhs.shape.pop_front().unwrap(); + }; if !is_eq { return false; } - while lhs.data.len() != 0 && is_eq { - is_eq = lhs.data.pop_front().unwrap() == rhs.data.pop_front().unwrap(); - }; + while lhs.data.len() != 0 + && is_eq { + is_eq = lhs.data.pop_front().unwrap() == rhs.data.pop_front().unwrap(); + }; is_eq } diff --git a/src/operators/tensor/implementations/tensor_i8.cairo b/src/operators/tensor/implementations/tensor_i8.cairo index f523c47b2..7c7770a77 100644 --- a/src/operators/tensor/implementations/tensor_i8.cairo +++ b/src/operators/tensor/implementations/tensor_i8.cairo @@ -438,9 +438,7 @@ impl I8Tensor of TensorTrait { panic(array!['not supported!']) } - fn gather_elements( - self: @Tensor, indices: Tensor, axis: Option - ) -> Tensor { + fn gather_elements(self: @Tensor, indices: Tensor, axis: Option) -> Tensor { math::gather_elements::gather_elements(self, indices, axis) } @@ -707,17 +705,19 @@ impl 
I8TensorPartialOrd of PartialOrd> { fn tensor_eq(mut lhs: Tensor, mut rhs: Tensor,) -> bool { let mut is_eq = true; - while lhs.shape.len() != 0 && is_eq { - is_eq = lhs.shape.pop_front().unwrap() == rhs.shape.pop_front().unwrap(); - }; + while lhs.shape.len() != 0 + && is_eq { + is_eq = lhs.shape.pop_front().unwrap() == rhs.shape.pop_front().unwrap(); + }; if !is_eq { return false; } - while lhs.data.len() == 0 && !is_eq { - is_eq = lhs.data.pop_front().unwrap() == rhs.data.pop_front().unwrap(); - }; + while lhs.data.len() == 0 + && !is_eq { + is_eq = lhs.data.pop_front().unwrap() == rhs.data.pop_front().unwrap(); + }; is_eq } diff --git a/src/operators/tensor/implementations/tensor_u32.cairo b/src/operators/tensor/implementations/tensor_u32.cairo index 7aa2ade26..1ad0ef8de 100644 --- a/src/operators/tensor/implementations/tensor_u32.cairo +++ b/src/operators/tensor/implementations/tensor_u32.cairo @@ -382,9 +382,7 @@ impl U32Tensor of TensorTrait { panic(array!['not supported!']) } - fn gather_elements( - self: @Tensor, indices: Tensor, axis: Option - ) -> Tensor { + fn gather_elements(self: @Tensor, indices: Tensor, axis: Option) -> Tensor { math::gather_elements::gather_elements(self, indices, axis) } @@ -661,17 +659,19 @@ impl U32TensorPartialOrd of PartialOrd> { fn tensor_eq(mut lhs: Tensor, mut rhs: Tensor,) -> bool { let mut is_eq = true; - while lhs.shape.len() != 0 && is_eq { - is_eq = lhs.shape.pop_front().unwrap() == rhs.shape.pop_front().unwrap(); - }; + while lhs.shape.len() != 0 + && is_eq { + is_eq = lhs.shape.pop_front().unwrap() == rhs.shape.pop_front().unwrap(); + }; if !is_eq { return false; } - while lhs.data.len() != 0 && is_eq { - is_eq = lhs.data.pop_front().unwrap() == rhs.data.pop_front().unwrap(); - }; + while lhs.data.len() != 0 + && is_eq { + is_eq = lhs.data.pop_front().unwrap() == rhs.data.pop_front().unwrap(); + }; is_eq } diff --git a/src/operators/tensor/linalg/transpose.cairo b/src/operators/tensor/linalg/transpose.cairo index 
97ad240b4..b1f4381c4 100644 --- a/src/operators/tensor/linalg/transpose.cairo +++ b/src/operators/tensor/linalg/transpose.cairo @@ -29,12 +29,13 @@ fn transpose, impl TCopy: Copy, impl TDrop: D let mut input_indices: Array = array![]; let mut output_axis: usize = 0; - while output_axis != axes.len() { - let input_axis = find_axis(axes, output_axis); - input_indices.append(*output_indices[input_axis]); + while output_axis != axes + .len() { + let input_axis = find_axis(axes, output_axis); + input_indices.append(*output_indices[input_axis]); - output_axis += 1; - }; + output_axis += 1; + }; let input_index = ravel_index(*self.shape, input_indices.span()); output_data.append(*(*self.data)[input_index]); diff --git a/src/operators/tensor/manipulation/split.cairo b/src/operators/tensor/manipulation/split.cairo index a8036f219..fb8621614 100644 --- a/src/operators/tensor/manipulation/split.cairo +++ b/src/operators/tensor/manipulation/split.cairo @@ -69,42 +69,45 @@ fn split_num_outputs, +Drop, +TensorTrait,>( let mut sli: MutMatrix = MutMatrixImpl::new((*t).shape.len(), 2); let mut pos: usize = 0; let mut i = 0; - while i != (*t).shape.len() { - let s: usize = *(*t).shape.at(i); - sli.set(i, 0, 0); - sli.set(i, 1, s); - i += 1; - }; + while i != (*t) + .shape + .len() { + let s: usize = *(*t).shape.at(i); + sli.set(i, 0, 0); + sli.set(i, 1, s); + i += 1; + }; let mut i: usize = 0; - while i != split.len() { - let spl = *split.at(i); - sli.set(axis, 0, pos); - pos += spl; - sli.set(axis, 1, pos); + while i != split + .len() { + let spl = *split.at(i); + sli.set(axis, 0, pos); + pos += spl; + sli.set(axis, 1, pos); - let end_ele_0 = match sli.get(axis, 0) { - Option::Some(res) => res, - Option::None => { - assert(false, 'Get end_ele_0 is failed'); - 0 - }, - }; - let end_ele_1 = match sli.get(axis, 1) { - Option::Some(res) => res, - Option::None => { - assert(false, 'Get end_ele_0 is failed'); - 0 - }, + let end_ele_0 = match sli.get(axis, 0) { + Option::Some(res) => res, 
+ Option::None => { + assert(false, 'Get end_ele_0 is failed'); + 0 + }, + }; + let end_ele_1 = match sli.get(axis, 1) { + Option::Some(res) => res, + Option::None => { + assert(false, 'Get end_ele_0 is failed'); + 0 + }, + }; + let starts: Span = array![sli.get(0, 0).unwrap(), end_ele_0].span(); + let ends: Span = array![sli.get(0, 1).unwrap(), end_ele_1].span(); + let axes: Option> = Option::None(()); + let steps: Option> = Option::None(()); + let sub_t: Tensor = t.slice(starts, ends, axes, steps); + splited_t.append(sub_t); + i += 1; }; - let starts: Span = array![sli.get(0, 0).unwrap(), end_ele_0].span(); - let ends: Span = array![sli.get(0, 1).unwrap(), end_ele_1].span(); - let axes: Option> = Option::None(()); - let steps: Option> = Option::None(()); - let sub_t: Tensor = t.slice(starts, ends, axes, steps); - splited_t.append(sub_t); - i += 1; - }; splited_t } @@ -118,42 +121,46 @@ fn split_has_split, +Drop, +TensorTrait,>( let mut sli: MutMatrix = MutMatrixImpl::new((*t).shape.len(), 2); let mut pos: usize = 0; let mut i = 0; - while i != (*t).shape.len() { - let s: usize = *(*t).shape.at(i); - sli.set(i, 0, 0); - sli.set(i, 1, s); - i += 1; - }; + while i != (*t) + .shape + .len() { + let s: usize = *(*t).shape.at(i); + sli.set(i, 0, 0); + sli.set(i, 1, s); + i += 1; + }; let mut i: usize = 0; - while i != split.data.len() { - let spl: usize = split.at(indices: array![i].span()); - sli.set(axis, 0, pos); - pos += spl; - sli.set(axis, 1, pos); + while i != split + .data + .len() { + let spl: usize = split.at(indices: array![i].span()); + sli.set(axis, 0, pos); + pos += spl; + sli.set(axis, 1, pos); - let end_ele_0 = match sli.get(axis, 0) { - Option::Some(res) => res, - Option::None => { - assert(false, 'Get end_ele_0 is failed'); - 0 - }, - }; - let end_ele_1 = match sli.get(axis, 1) { - Option::Some(res) => res, - Option::None => { - assert(false, 'Get end_ele_0 is failed'); - 0 - }, + let end_ele_0 = match sli.get(axis, 0) { + Option::Some(res) => res, + 
Option::None => { + assert(false, 'Get end_ele_0 is failed'); + 0 + }, + }; + let end_ele_1 = match sli.get(axis, 1) { + Option::Some(res) => res, + Option::None => { + assert(false, 'Get end_ele_0 is failed'); + 0 + }, + }; + let starts: Span = array![sli.get(0, 0).unwrap(), end_ele_0].span(); + let ends: Span = array![sli.get(0, 1).unwrap(), end_ele_1].span(); + let axes: Option> = Option::None(()); + let steps: Option> = Option::None(()); + let sub_t: Tensor = t.slice(starts, ends, axes, steps); + splited_t.append(sub_t); + i += 1; }; - let starts: Span = array![sli.get(0, 0).unwrap(), end_ele_0].span(); - let ends: Span = array![sli.get(0, 1).unwrap(), end_ele_1].span(); - let axes: Option> = Option::None(()); - let steps: Option> = Option::None(()); - let sub_t: Tensor = t.slice(starts, ends, axes, steps); - splited_t.append(sub_t); - i += 1; - }; splited_t } diff --git a/src/operators/tensor/manipulation/split_to_sequence.cairo b/src/operators/tensor/manipulation/split_to_sequence.cairo index 2e8e4704c..51c3662e8 100644 --- a/src/operators/tensor/manipulation/split_to_sequence.cairo +++ b/src/operators/tensor/manipulation/split_to_sequence.cairo @@ -202,4 +202,4 @@ fn split_has_split, +Drop, +TensorTrait,>( }; splited_t -} \ No newline at end of file +} diff --git a/src/operators/tensor/math/cumsum.cairo b/src/operators/tensor/math/cumsum.cairo index 6fef885d2..247494617 100644 --- a/src/operators/tensor/math/cumsum.cairo +++ b/src/operators/tensor/math/cumsum.cairo @@ -51,37 +51,38 @@ fn cumsum_forward< let mut index: usize = 0; - while index != data.len() { - let current_indices = unravel_index(index, *self.shape); - let axis_value = *current_indices[axis]; - - if axis_value == 0 { - if exclusive { - output_data.append(zero); + while index != data + .len() { + let current_indices = unravel_index(index, *self.shape); + let axis_value = *current_indices[axis]; + + if axis_value == 0 { + if exclusive { + output_data.append(zero); + } else { + 
output_data.append(*(data)[index]); + } } else { - output_data.append(*(data)[index]); + let previous_axis_element_indices = replace_index( + current_indices, axis, axis_value - 1 + ); + let previous_axis_element_index = ravel_index( + *self.shape, previous_axis_element_indices + ); + + if exclusive { + output_data + .append( + *output_data[previous_axis_element_index] + + *(data)[previous_axis_element_index] + ); + } else { + output_data.append(*output_data[previous_axis_element_index] + *(data)[index]); + }; } - } else { - let previous_axis_element_indices = replace_index( - current_indices, axis, axis_value - 1 - ); - let previous_axis_element_index = ravel_index( - *self.shape, previous_axis_element_indices - ); - - if exclusive { - output_data - .append( - *output_data[previous_axis_element_index] - + *(data)[previous_axis_element_index] - ); - } else { - output_data.append(*output_data[previous_axis_element_index] + *(data)[index]); - }; - } - index += 1; - }; + index += 1; + }; TensorTrait::::new(*self.shape, output_data.span()) } @@ -106,54 +107,59 @@ fn cumsum_reverse< let data = *self.data; let mut output_data = array![]; let mut index: usize = 0; - while index != data.len() { - let current_indices = unravel_index(index, *self.shape); - let mut axis_value = *current_indices[axis]; - - if axis_value == 0 { - // If the axis value is 0, we need to sum all the elements - // in the axis. - let mut sum = *(data)[index]; - if exclusive { - sum = zero; - } + while index != data + .len() { + let current_indices = unravel_index(index, *self.shape); + let mut axis_value = *current_indices[axis]; + + if axis_value == 0 { + // If the axis value is 0, we need to sum all the elements + // in the axis. 
+ let mut sum = *(data)[index]; + if exclusive { + sum = zero; + } - let end_index = *(*self.shape)[axis] - 1; + let end_index = *(*self.shape)[axis] - 1; - loop { - axis_value += 1; - if axis_value > end_index { - break (); - } + loop { + axis_value += 1; + if axis_value > end_index { + break (); + } - let next_axis_element_indices = replace_index(current_indices, axis, axis_value); - let next_axis_element_index = ravel_index(*self.shape, next_axis_element_indices); - sum += *data[next_axis_element_index]; - }; - - output_data.append(sum); - } else { - // If the axis value is not 0, we only need to do a subtraction - let previous_axis_element_indices = replace_index( - current_indices, axis, axis_value - 1 - ); - let previous_axis_element_index = ravel_index( - *self.shape, previous_axis_element_indices - ); - - if exclusive { - output_data.append(*output_data[previous_axis_element_index] - *(data)[index]); - } else { - output_data - .append( - *output_data[previous_axis_element_index] - - *(data)[previous_axis_element_index] + let next_axis_element_indices = replace_index( + current_indices, axis, axis_value ); + let next_axis_element_index = ravel_index( + *self.shape, next_axis_element_indices + ); + sum += *data[next_axis_element_index]; + }; + + output_data.append(sum); + } else { + // If the axis value is not 0, we only need to do a subtraction + let previous_axis_element_indices = replace_index( + current_indices, axis, axis_value - 1 + ); + let previous_axis_element_index = ravel_index( + *self.shape, previous_axis_element_indices + ); + + if exclusive { + output_data.append(*output_data[previous_axis_element_index] - *(data)[index]); + } else { + output_data + .append( + *output_data[previous_axis_element_index] + - *(data)[previous_axis_element_index] + ); + } } - } - index += 1; - }; + index += 1; + }; TensorTrait::::new(*self.shape, output_data.span()) } diff --git a/src/operators/tensor/math/gather_nd.cairo b/src/operators/tensor/math/gather_nd.cairo 
index e5f340487..99564de6f 100644 --- a/src/operators/tensor/math/gather_nd.cairo +++ b/src/operators/tensor/math/gather_nd.cairo @@ -127,10 +127,11 @@ fn gather_nd, impl TCopy: Copy, impl TDr if (index == *indices_shape_last - 1) { let mut data_ind: usize = result; - while data_ind != result + incrementer { - index_data.append(data_ind + incr); - data_ind += 1; - }; + while data_ind != result + + incrementer { + index_data.append(data_ind + incr); + data_ind += 1; + }; result = 0; }; diff --git a/src/operators/tensor/math/layer_normalization.cairo b/src/operators/tensor/math/layer_normalization.cairo index b6aa33ec0..6473ce15a 100644 --- a/src/operators/tensor/math/layer_normalization.cairo +++ b/src/operators/tensor/math/layer_normalization.cairo @@ -99,7 +99,8 @@ fn layer_normalization< let x_diff = x_mat - x_mean; let x_squared_diff = x_diff * x_diff; - let variance = x_squared_diff.reduce_sum(Option::Some(array![1].span()), Option::Some(true), Option::Some(false)) + let variance = x_squared_diff + .reduce_sum(Option::Some(array![1].span()), Option::Some(true), Option::Some(false)) / TensorTrait::new(shape_one.span(), col_number_tensor.span()); let variance_eps = variance + TensorTrait::new(shape_one.span(), epsilon_tensor.span()); diff --git a/src/operators/tensor/math/less_equal.cairo b/src/operators/tensor/math/less_equal.cairo index dea786878..dd54a0a41 100644 --- a/src/operators/tensor/math/less_equal.cairo +++ b/src/operators/tensor/math/less_equal.cairo @@ -4,12 +4,7 @@ use orion::operators::tensor::helpers::{ }; /// Cf: TensorTrait::less_equal docstring -fn less_equal< - T, - impl TPartialOrd: PartialOrd, - impl TCopy: Copy, - impl TDrop: Drop ->( +fn less_equal, impl TCopy: Copy, impl TDrop: Drop>( y: @Tensor, z: @Tensor ) -> Tensor { let broadcasted_shape = broadcast_shape(*y.shape, *z.shape); diff --git a/src/operators/tensor/math/max.cairo b/src/operators/tensor/math/max.cairo index 3ce6d4919..c6a55576a 100644 --- 
a/src/operators/tensor/math/max.cairo +++ b/src/operators/tensor/math/max.cairo @@ -28,35 +28,36 @@ fn max< let mut tensor_counter: usize = 1; - while tensor_counter != tensors.len() { - let mut new_max_data: Array = array![]; + while tensor_counter != tensors + .len() { + let mut new_max_data: Array = array![]; - let mut current_tensor = *tensors.at(tensor_counter); + let mut current_tensor = *tensors.at(tensor_counter); - let mut broadcasted_shape = broadcast_shape(max_shape, current_tensor.shape); + let mut broadcasted_shape = broadcast_shape(max_shape, current_tensor.shape); - let num_elements = len_from_shape(broadcasted_shape); - let mut n: usize = 0; - while n != num_elements { - let mut indices_broadcasted = unravel_index(n, broadcasted_shape); + let num_elements = len_from_shape(broadcasted_shape); + let mut n: usize = 0; + while n != num_elements { + let mut indices_broadcasted = unravel_index(n, broadcasted_shape); - let mut indices_self = broadcast_index_mapping(max_shape, indices_broadcasted); - let mut indices_other = broadcast_index_mapping( - current_tensor.shape, indices_broadcasted - ); + let mut indices_self = broadcast_index_mapping(max_shape, indices_broadcasted); + let mut indices_other = broadcast_index_mapping( + current_tensor.shape, indices_broadcasted + ); - let mut max_value = NumberTrait::max( - *(max_data)[indices_self], *(current_tensor.data)[indices_other] - ); - new_max_data.append(max_value); + let mut max_value = NumberTrait::max( + *(max_data)[indices_self], *(current_tensor.data)[indices_other] + ); + new_max_data.append(max_value); - n += 1; - }; + n += 1; + }; - max_shape = broadcasted_shape; - max_data = new_max_data.span(); - tensor_counter += 1; - }; + max_shape = broadcasted_shape; + max_data = new_max_data.span(); + tensor_counter += 1; + }; TensorTrait::::new(max_shape, max_data) } diff --git a/src/operators/tensor/math/min.cairo b/src/operators/tensor/math/min.cairo index 2e7acadab..9f0de2dfd 100644 --- 
a/src/operators/tensor/math/min.cairo +++ b/src/operators/tensor/math/min.cairo @@ -28,35 +28,36 @@ fn min< let mut tensor_counter: usize = 1; - while tensor_counter != tensors.len() { - let mut new_min_data: Array = array![]; + while tensor_counter != tensors + .len() { + let mut new_min_data: Array = array![]; - let mut current_tensor = *tensors.at(tensor_counter); + let mut current_tensor = *tensors.at(tensor_counter); - let mut broadcasted_shape = broadcast_shape(min_shape, current_tensor.shape); + let mut broadcasted_shape = broadcast_shape(min_shape, current_tensor.shape); - let num_elements = len_from_shape(broadcasted_shape); - let mut n: usize = 0; - while n != num_elements { - let mut indices_broadcasted = unravel_index(n, broadcasted_shape); + let num_elements = len_from_shape(broadcasted_shape); + let mut n: usize = 0; + while n != num_elements { + let mut indices_broadcasted = unravel_index(n, broadcasted_shape); - let mut indices_self = broadcast_index_mapping(min_shape, indices_broadcasted); - let mut indices_other = broadcast_index_mapping( - current_tensor.shape, indices_broadcasted - ); + let mut indices_self = broadcast_index_mapping(min_shape, indices_broadcasted); + let mut indices_other = broadcast_index_mapping( + current_tensor.shape, indices_broadcasted + ); - let mut min_value = NumberTrait::min( - *(min_data)[indices_self], *(current_tensor.data)[indices_other] - ); - new_min_data.append(min_value); + let mut min_value = NumberTrait::min( + *(min_data)[indices_self], *(current_tensor.data)[indices_other] + ); + new_min_data.append(min_value); - n += 1; - }; + n += 1; + }; - min_shape = broadcasted_shape; - min_data = new_min_data.span(); - tensor_counter += 1; - }; + min_shape = broadcasted_shape; + min_data = new_min_data.span(); + tensor_counter += 1; + }; TensorTrait::::new(min_shape, min_data) } diff --git a/src/operators/tensor/math/range.cairo b/src/operators/tensor/math/range.cairo index 1edc0f628..b1950f2b9 100644 --- 
a/src/operators/tensor/math/range.cairo +++ b/src/operators/tensor/math/range.cairo @@ -18,11 +18,12 @@ fn range< ) -> Tensor { let mut result: Array = array![]; let zero: T = NumberTrait::zero(); - while !(step >= zero && start >= end) && !(step <= zero && start <= end) { - let v = start; - result.append(v); - start += step; - }; + while !(step >= zero && start >= end) + && !(step <= zero && start <= end) { + let v = start; + result.append(v); + start += step; + }; let shape = array![result.len()]; diff --git a/src/operators/tensor/math/reduce_l1.cairo b/src/operators/tensor/math/reduce_l1.cairo index 29b83b69d..422af1a6d 100644 --- a/src/operators/tensor/math/reduce_l1.cairo +++ b/src/operators/tensor/math/reduce_l1.cairo @@ -16,5 +16,10 @@ fn reduce_l1< ) -> Tensor { let data_abs = self.abs(); - data_abs.reduce_sum(Option::Some(array![axis.try_into().unwrap()].span()), Option::Some(keepdims), Option::Some(false)) + data_abs + .reduce_sum( + Option::Some(array![axis.try_into().unwrap()].span()), + Option::Some(keepdims), + Option::Some(false) + ) } diff --git a/src/operators/tensor/math/resize.cairo b/src/operators/tensor/math/resize.cairo index ab0ef86f7..5b93e5497 100644 --- a/src/operators/tensor/math/resize.cairo +++ b/src/operators/tensor/math/resize.cairo @@ -283,13 +283,14 @@ fn interpolate_nd< KEEP_ASPECT_RATIO_POLICY::NOT_LARGER => { let mut scale = *scale_factors.at(*axes.at(0)); let mut i = 1; - while i != axes.len() { - if scale > *scale_factors.at(*axes.at(i)) { - scale = *scale_factors.at(*axes.at(i)); - } + while i != axes + .len() { + if scale > *scale_factors.at(*axes.at(i)) { + scale = *scale_factors.at(*axes.at(i)); + } - i += 1; - }; + i += 1; + }; let mut scale_factors: Array = array![]; let mut d = 0; @@ -341,13 +342,14 @@ fn interpolate_nd< KEEP_ASPECT_RATIO_POLICY::NOT_SMALLER => { let mut scale = *scale_factors.at(*axes.at(0)); let mut i = 1; - while i != axes.len() { - if scale < *scale_factors.at(*axes.at(i)) { - scale = 
*scale_factors.at(*axes.at(i)); - } + while i != axes + .len() { + if scale < *scale_factors.at(*axes.at(i)) { + scale = *scale_factors.at(*axes.at(i)); + } - i += 1; - }; + i += 1; + }; let mut scale_factors: Array = array![]; let mut d = 0; @@ -409,12 +411,13 @@ fn interpolate_nd< }; let mut i = 0; - while i != scale_factors.len() { - let item = *scale_factors.at(i) - * NumberTrait::new_unscaled((*(*(data).shape).at(i)).into(), false); - output_size.append(item.try_into().unwrap()); - i += 1; - }; + while i != scale_factors + .len() { + let item = *scale_factors.at(i) + * NumberTrait::new_unscaled((*(*(data).shape).at(i)).into(), false); + output_size.append(item.try_into().unwrap()); + i += 1; + }; (output_size.span(), scale_factors) }, @@ -422,17 +425,18 @@ fn interpolate_nd< let mut ret: Array> = array![]; let mut i = 0; - while i != output_size.len() { - let mut temp = ArrayTrait::::new(); - let mut j = 0; - while j != *output_size.at(i) { - temp.append(j); - j += 1; - }; + while i != output_size + .len() { + let mut temp = ArrayTrait::::new(); + let mut j = 0; + while j != *output_size.at(i) { + temp.append(j); + j += 1; + }; - ret.append(temp.span()); - i += 1; - }; + ret.append(temp.span()); + i += 1; + }; let mut ret = cartesian(ret.span()); let mut ret_data = array![]; @@ -442,10 +446,11 @@ fn interpolate_nd< Option::Some(X) => { let mut x: Array = array![]; let mut i = 0; - while i != X.len() { - x.append(NumberTrait::new_unscaled((*X.at(i)).into(), false)); - i += 1; - }; + while i != X + .len() { + x.append(NumberTrait::new_unscaled((*X.at(i)).into(), false)); + i += 1; + }; let mut x = x.span(); let item = interpolate_nd_with_x( @@ -499,14 +504,15 @@ fn cartesian(mut arrays: Span>,) -> Array> { let mut m = n; let mut i = 0; - while i != arrays.len() { - m = m / (*(arrays.at(i))).len(); - let mut out = repeat(*(arrays.at(i)), m); - out = repeat_2(out, size_arrays, i); + while i != arrays + .len() { + m = m / (*(arrays.at(i))).len(); + let mut out = 
repeat(*(arrays.at(i)), m); + out = repeat_2(out, size_arrays, i); - output_arrays.append(out); - i += 1; - }; + output_arrays.append(out); + i += 1; + }; let output_arrays = output_arrays.span(); @@ -532,15 +538,16 @@ fn repeat_2(mut array: Array, size_array: Span, index: usize) -> A let mut i = 0; while i != index { let mut j = 1; - while j != *size_array.at(index - 1 - i) { - let mut k = 0; - while k != size { - array.append(*array.at(k)); - k += 1; - }; + while j != *size_array + .at(index - 1 - i) { + let mut k = 0; + while k != size { + array.append(*array.at(k)); + k += 1; + }; - j += 1; - }; + j += 1; + }; size = size * *size_array.at(index - 1 - i); i += 1; @@ -552,15 +559,16 @@ fn repeat_2(mut array: Array, size_array: Span, index: usize) -> A fn repeat(array: Span, m: usize,) -> Array { let mut out = array![]; let mut j = 0; - while j != array.len() { - let mut k = 0; - while k != m { - out.append(*array.at(j)); - k += 1; - }; + while j != array + .len() { + let mut k = 0; + while k != m { + out.append(*array.at(j)); + k += 1; + }; - j += 1; - }; + j += 1; + }; out } @@ -648,35 +656,37 @@ fn interpolate_nd_with_x< }; let mut i = 0; - while i != *(*data).shape.at(0) { - let data = get_row_n(data, i); - - let mut r = interpolate_nd_with_x( - @data, - n - 1, - scale_factor, - output_size, - x, - antialias, - mode, - nearest_mode, - reduced_roi, - extrapolation_value, - coordinate_transformation_mode, - exclude_outside, - cubic_coeff_a - ); + while i != *(*data) + .shape + .at(0) { + let data = get_row_n(data, i); + + let mut r = interpolate_nd_with_x( + @data, + n - 1, + scale_factor, + output_size, + x, + antialias, + mode, + nearest_mode, + reduced_roi, + extrapolation_value, + coordinate_transformation_mode, + exclude_outside, + cubic_coeff_a + ); + + loop { + match r.data.pop_front() { + Option::Some(item) => { res1d.append(*item); }, + Option::None => { break; } + } + }; - loop { - match r.data.pop_front() { - Option::Some(item) => { 
res1d.append(*item); }, - Option::None => { break; } - } + i += 1; }; - i += 1; - }; - let mut shape = array![]; shape.append(res1d.len()); @@ -727,14 +737,16 @@ fn get_row_n, +Copy, +Drop,>( let mut stride_output = 1; let mut i = 0; - while i != (*data).shape.len() { - if i != 0 { - output_shape.append(*(*data).shape.at(i)); - stride_output = stride_output * *(*data).shape.at(i); - } + while i != (*data) + .shape + .len() { + if i != 0 { + output_shape.append(*(*data).shape.at(i)); + stride_output = stride_output * *(*data).shape.at(i); + } - i += 1; - }; + i += 1; + }; let mut i = 0; while i != stride_output { @@ -897,17 +909,19 @@ fn interpolate_1d_with_x< let mut coeffs_exclude_outside: Array = array![]; let mut sum = NumberTrait::zero(); let mut i = 0; - while i != idxes.data.len() { - if *idxes.data.at(i) { - coeffs_exclude_outside.append(NumberTrait::zero()); - sum += NumberTrait::zero(); - } else { - coeffs_exclude_outside.append(*coeffs.data.at(i)); - sum += *coeffs.data.at(i); - } + while i != idxes + .data + .len() { + if *idxes.data.at(i) { + coeffs_exclude_outside.append(NumberTrait::zero()); + sum += NumberTrait::zero(); + } else { + coeffs_exclude_outside.append(*coeffs.data.at(i)); + sum += *coeffs.data.at(i); + } - i += 1; - }; + i += 1; + }; let mut coeff_div: Array = array![]; let mut i = 0; @@ -974,21 +988,23 @@ fn get_neighbor< let mut idxes_centered = array![]; let mut ret = array![]; let mut i = 0; - while i != idxes.data.len() { - ret.append(*padded.at(*idxes.data.at(i))); - - if *idxes.data.at(i) >= pad_width { - if (*idxes.data.at(i) - pad_width) >= (*data).data.len() { - idxes_centered.append(true); + while i != idxes + .data + .len() { + ret.append(*padded.at(*idxes.data.at(i))); + + if *idxes.data.at(i) >= pad_width { + if (*idxes.data.at(i) - pad_width) >= (*data).data.len() { + idxes_centered.append(true); + } else { + idxes_centered.append(false); + } } else { - idxes_centered.append(false); + idxes_centered.append(true); } - } else 
{ - idxes_centered.append(true); - } - i += 1; - }; + i += 1; + }; let mut shape = array![]; shape.append(idxes.data.len()); @@ -1049,22 +1065,23 @@ fn get_neighbor_idxes< } let mut i = 0; - while i != n / 2 { - if i_low - i < 0 { - idxes.append(i_high + i); - i_high += 1; - } else { - idxes.append(i_low - i); - } - if i_high + i >= limit { - i_low -= 1; - idxes.append(i_low - i); - } else { - idxes.append(i_high + i); - } + while i != n + / 2 { + if i_low - i < 0 { + idxes.append(i_high + i); + i_high += 1; + } else { + idxes.append(i_low - i); + } + if i_high + i >= limit { + i_low -= 1; + idxes.append(i_low - i); + } else { + idxes.append(i_high + i); + } - i += 1; - } + i += 1; + } } else { core::panic_with_felt252('MUST BE EVEN'); } @@ -1129,21 +1146,22 @@ fn linear_coeffs_antialias< // arange and clip + compute sum let mut i = start; - while i != start + footprint { - let value = NumberTrait::one() - NumberTrait::abs((i - ratio) * scale); - - if value < NumberTrait::zero() { - coeffs.append(NumberTrait::zero()); - } else if value > NumberTrait::one() { - coeffs.append(NumberTrait::one()); - sum += NumberTrait::one(); - } else { - coeffs.append(value); - sum += value; - } + while i != start + + footprint { + let value = NumberTrait::one() - NumberTrait::abs((i - ratio) * scale); + + if value < NumberTrait::zero() { + coeffs.append(NumberTrait::zero()); + } else if value > NumberTrait::one() { + coeffs.append(NumberTrait::one()); + sum += NumberTrait::one(); + } else { + coeffs.append(value); + sum += value; + } - i += NumberTrait::one(); - }; + i += NumberTrait::one(); + }; let n = coeffs.len(); diff --git a/src/operators/tensor/quantization/qlinear_matmul.cairo b/src/operators/tensor/quantization/qlinear_matmul.cairo index 325f4fd30..f869ae6d0 100644 --- a/src/operators/tensor/quantization/qlinear_matmul.cairo +++ b/src/operators/tensor/quantization/qlinear_matmul.cairo @@ -78,14 +78,15 @@ fn qlinear_matmul< b_shape_reduced.append(n); let mut i = 0; - while i 
!= stride(a_shape) / (m * k) { - result_updates( - @subtensor(@dequantized_a, i * (m * k), a_shape_reduced.span()), - @subtensor(@dequantized_b, i * (k * n), b_shape_reduced.span()), - ref x_data - ); - i += 1; - }; + while i != stride(a_shape) + / (m * k) { + result_updates( + @subtensor(@dequantized_a, i * (m * k), a_shape_reduced.span()), + @subtensor(@dequantized_b, i * (k * n), b_shape_reduced.span()), + ref x_data + ); + i += 1; + }; x_shape(ref x_shape, a_shape, m, n); let x = TensorTrait::new(x_shape.span(), x_data.span()); @@ -94,12 +95,13 @@ fn qlinear_matmul< } fn x_shape(ref x_data: Array, mut shape: Span, m: usize, n: usize) { - while shape.len() != 2 { - match shape.pop_front() { - Option::Some(elem) => { x_data.append(*elem); }, - Option::None => { break; } + while shape + .len() != 2 { + match shape.pop_front() { + Option::Some(elem) => { x_data.append(*elem); }, + Option::None => { break; } + }; }; - }; x_data.append(m); x_data.append(n); diff --git a/tests/lib.cairo b/tests/lib.cairo index c408347ef..f5cecb77d 100644 --- a/tests/lib.cairo +++ b/tests/lib.cairo @@ -5,4 +5,3 @@ mod nodes; mod ml; mod operators; - diff --git a/tests/nodes.cairo b/tests/nodes.cairo index 244d8b0c9..11461e004 100644 --- a/tests/nodes.cairo +++ b/tests/nodes.cairo @@ -985,3 +985,6 @@ mod argmax_negative_axis_keepdims; mod argmax_negative_axis_keepdims_select_last_index; mod argmax_no_keepdims; mod argmax_no_keepdims_select_last_index; +mod roi_align_mode_max; +mod roi_align_aligned_false; +mod roi_align_aligned_true; diff --git a/tests/nodes/gather_elements_axis1.cairo b/tests/nodes/gather_elements_axis1.cairo index 82b08e271..638f2d60b 100644 --- a/tests/nodes/gather_elements_axis1.cairo +++ b/tests/nodes/gather_elements_axis1.cairo @@ -18,7 +18,7 @@ fn test_gather_elements_axis1() { let input_1 = input_1::input_1(); let z_0 = output_0::output_0(); - let y_0 = input_0.gather_elements(indices:input_1, axis:Option::Some(1)); + let y_0 = input_0.gather_elements(indices: 
input_1, axis: Option::Some(1)); assert_eq(y_0, z_0); } diff --git a/tests/nodes/gather_elements_axis2.cairo b/tests/nodes/gather_elements_axis2.cairo index 0e0b7caea..dc7b0c373 100644 --- a/tests/nodes/gather_elements_axis2.cairo +++ b/tests/nodes/gather_elements_axis2.cairo @@ -18,7 +18,7 @@ fn test_gather_elements_axis2() { let input_1 = input_1::input_1(); let z_0 = output_0::output_0(); - let y_0 = input_0.gather_elements(indices:input_1, axis:Option::Some(2)); + let y_0 = input_0.gather_elements(indices: input_1, axis: Option::Some(2)); assert_eq(y_0, z_0); } diff --git a/tests/nodes/gather_elements_default.cairo b/tests/nodes/gather_elements_default.cairo index 9d1a099c1..98d66dd58 100644 --- a/tests/nodes/gather_elements_default.cairo +++ b/tests/nodes/gather_elements_default.cairo @@ -18,7 +18,7 @@ fn test_gather_elements_default() { let input_1 = input_1::input_1(); let z_0 = output_0::output_0(); - let y_0 = input_0.gather_elements(indices:input_1, axis:Option::Some(0)); + let y_0 = input_0.gather_elements(indices: input_1, axis: Option::Some(0)); assert_eq(y_0, z_0); } diff --git a/tests/nodes/gather_elements_negative_indices.cairo b/tests/nodes/gather_elements_negative_indices.cairo index 0aff55566..8980eff88 100644 --- a/tests/nodes/gather_elements_negative_indices.cairo +++ b/tests/nodes/gather_elements_negative_indices.cairo @@ -18,7 +18,7 @@ fn test_gather_elements_negative_indices() { let input_1 = input_1::input_1(); let z_0 = output_0::output_0(); - let y_0 = input_0.gather_elements(indices:input_1, axis:Option::Some(0)); + let y_0 = input_0.gather_elements(indices: input_1, axis: Option::Some(0)); assert_eq(y_0, z_0); } diff --git a/tests/nodes/gather_fp16x16_3d_axis1.cairo b/tests/nodes/gather_fp16x16_3d_axis1.cairo index d10ab5245..643aacd78 100644 --- a/tests/nodes/gather_fp16x16_3d_axis1.cairo +++ b/tests/nodes/gather_fp16x16_3d_axis1.cairo @@ -18,7 +18,7 @@ fn test_gather_fp16x16_3d_axis1() { let input_1 = input_1::input_1(); let z_0 = 
output_0::output_0(); - let y_0 = input_0.gather(indices:input_1, axis:Option::Some(1)); + let y_0 = input_0.gather(indices: input_1, axis: Option::Some(1)); assert_eq(y_0, z_0); } diff --git a/tests/nodes/gather_fp16x16_3d_axis2.cairo b/tests/nodes/gather_fp16x16_3d_axis2.cairo index 40ef5691d..e256664b6 100644 --- a/tests/nodes/gather_fp16x16_3d_axis2.cairo +++ b/tests/nodes/gather_fp16x16_3d_axis2.cairo @@ -18,7 +18,7 @@ fn test_gather_fp16x16_3d_axis2() { let input_1 = input_1::input_1(); let z_0 = output_0::output_0(); - let y_0 = input_0.gather(indices:input_1, axis:Option::Some(2)); + let y_0 = input_0.gather(indices: input_1, axis: Option::Some(2)); assert_eq(y_0, z_0); } diff --git a/tests/nodes/gather_fp16x16_3d_default.cairo b/tests/nodes/gather_fp16x16_3d_default.cairo index 2003b0838..affd608a4 100644 --- a/tests/nodes/gather_fp16x16_3d_default.cairo +++ b/tests/nodes/gather_fp16x16_3d_default.cairo @@ -18,7 +18,7 @@ fn test_gather_fp16x16_3d_default() { let input_1 = input_1::input_1(); let z_0 = output_0::output_0(); - let y_0 = input_0.gather(indices:input_1, axis:Option::Some(0)); + let y_0 = input_0.gather(indices: input_1, axis: Option::Some(0)); assert_eq(y_0, z_0); } diff --git a/tests/nodes/gather_negative_axis.cairo b/tests/nodes/gather_negative_axis.cairo index 27c511614..b5b4854e5 100644 --- a/tests/nodes/gather_negative_axis.cairo +++ b/tests/nodes/gather_negative_axis.cairo @@ -18,7 +18,7 @@ fn test_gather_negative_axis() { let input_1 = input_1::input_1(); let z_0 = output_0::output_0(); - let y_0 = input_0.gather(indices:input_1, axis:Option::Some(-1)); + let y_0 = input_0.gather(indices: input_1, axis: Option::Some(-1)); assert_eq(y_0, z_0); } diff --git a/tests/nodes/gather_negative_indices.cairo b/tests/nodes/gather_negative_indices.cairo index 559a276ea..03dfbaa74 100644 --- a/tests/nodes/gather_negative_indices.cairo +++ b/tests/nodes/gather_negative_indices.cairo @@ -18,7 +18,7 @@ fn test_gather_negative_indices() { let input_1 = 
input_1::input_1(); let z_0 = output_0::output_0(); - let y_0 = input_0.gather(indices:input_1, axis:Option::Some(0)); + let y_0 = input_0.gather(indices: input_1, axis: Option::Some(0)); assert_eq(y_0, z_0); } diff --git a/tests/nodes/reshape_reduced_dims.cairo b/tests/nodes/reshape_reduced_dims.cairo index 7952505d1..4d42db34e 100644 --- a/tests/nodes/reshape_reduced_dims.cairo +++ b/tests/nodes/reshape_reduced_dims.cairo @@ -14,7 +14,7 @@ fn test_reshape_reduced_dims() { let input_0 = input_0::input_0(); let z_0 = output_0::output_0(); - let y_0 = input_0.reshape(array![2,12].span(), false); + let y_0 = input_0.reshape(array![2, 12].span(), false); assert_eq(y_0, z_0); } diff --git a/tests/nodes/reshape_reordered_all_dims.cairo b/tests/nodes/reshape_reordered_all_dims.cairo index 237c867c2..b9d1f456e 100644 --- a/tests/nodes/reshape_reordered_all_dims.cairo +++ b/tests/nodes/reshape_reordered_all_dims.cairo @@ -14,7 +14,7 @@ fn test_reshape_reordered_all_dims() { let input_0 = input_0::input_0(); let z_0 = output_0::output_0(); - let y_0 = input_0.reshape(array![4,2,3].span(), false); + let y_0 = input_0.reshape(array![4, 2, 3].span(), false); assert_eq(y_0, z_0); } diff --git a/tests/nodes/reshape_reordered_last_dims.cairo b/tests/nodes/reshape_reordered_last_dims.cairo index 5c5f4fd7e..79a82982a 100644 --- a/tests/nodes/reshape_reordered_last_dims.cairo +++ b/tests/nodes/reshape_reordered_last_dims.cairo @@ -14,7 +14,7 @@ fn test_reshape_reordered_last_dims() { let input_0 = input_0::input_0(); let z_0 = output_0::output_0(); - let y_0 = input_0.reshape(array![2,4,3].span(), false); + let y_0 = input_0.reshape(array![2, 4, 3].span(), false); assert_eq(y_0, z_0); } diff --git a/tests/nodes/roi_align_aligned_false.cairo b/tests/nodes/roi_align_aligned_false.cairo new file mode 100644 index 000000000..39bf86488 --- /dev/null +++ b/tests/nodes/roi_align_aligned_false.cairo @@ -0,0 +1,35 @@ +mod input_0; +mod input_1; +mod output_0; + + +use 
orion::operators::tensor::FP16x16TensorPartialEq; +use orion::operators::nn::FP16x16NN; +use orion::numbers::FixedTrait; +use orion::utils::{assert_eq, assert_seq_eq}; +use orion::operators::nn::NNTrait; +use orion::operators::nn::functional::roi_align::TRANSFORMATION_MODE; +use orion::numbers::FP16x16; +use orion::operators::tensor::{TensorTrait, U32Tensor}; + +#[test] +#[available_gas(2000000000)] +fn test_roi_align_aligned_false() { + let input_0 = input_0::input_0(); + let input_1 = input_1::input_1(); + let z_0 = output_0::output_0(); + + let y_0 = NNTrait::roi_align( + @input_0, + @input_1, + @TensorTrait::new(array![3].span(), array![0, 0, 0].span()), + Option::Some(TRANSFORMATION_MODE::OUTPUT_HALF_PIXEL), + Option::None, + Option::Some(5), + Option::Some(5), + Option::Some(FP16x16 { mag: 131072, sign: false }), + Option::Some(FP16x16 { mag: 65536, sign: false }) + ); + + assert_eq(y_0, z_0); +} diff --git a/tests/nodes/roi_align_aligned_false/input_0.cairo b/tests/nodes/roi_align_aligned_false/input_0.cairo new file mode 100644 index 000000000..91e6fffbd --- /dev/null +++ b/tests/nodes/roi_align_aligned_false/input_0.cairo @@ -0,0 +1,115 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::{FP16x16Tensor, FP16x16TensorAdd}; +use orion::numbers::{FixedTrait, FP16x16}; + +fn input_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(1); + shape.append(1); + shape.append(10); + shape.append(10); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 18114, sign: false }); + data.append(FP16x16 { mag: 46858, sign: false }); + data.append(FP16x16 { mag: 12831, sign: false }); + data.append(FP16x16 { mag: 22387, sign: false }); + data.append(FP16x16 { mag: 30395, sign: false }); + data.append(FP16x16 { mag: 1697, sign: false }); + data.append(FP16x16 { mag: 19418, sign: false }); + data.append(FP16x16 { mag: 42716, sign: false }); + data.append(FP16x16 { 
mag: 31824, sign: false }); + data.append(FP16x16 { mag: 47513, sign: false }); + data.append(FP16x16 { mag: 63157, sign: false }); + data.append(FP16x16 { mag: 5865, sign: false }); + data.append(FP16x16 { mag: 19129, sign: false }); + data.append(FP16x16 { mag: 44256, sign: false }); + data.append(FP16x16 { mag: 1533, sign: false }); + data.append(FP16x16 { mag: 40186, sign: false }); + data.append(FP16x16 { mag: 52985, sign: false }); + data.append(FP16x16 { mag: 34891, sign: false }); + data.append(FP16x16 { mag: 58929, sign: false }); + data.append(FP16x16 { mag: 29274, sign: false }); + data.append(FP16x16 { mag: 21397, sign: false }); + data.append(FP16x16 { mag: 55567, sign: false }); + data.append(FP16x16 { mag: 63556, sign: false }); + data.append(FP16x16 { mag: 16193, sign: false }); + data.append(FP16x16 { mag: 61184, sign: false }); + data.append(FP16x16 { mag: 12307, sign: false }); + data.append(FP16x16 { mag: 31234, sign: false }); + data.append(FP16x16 { mag: 28232, sign: false }); + data.append(FP16x16 { mag: 22282, sign: false }); + data.append(FP16x16 { mag: 14168, sign: false }); + data.append(FP16x16 { mag: 1350, sign: false }); + data.append(FP16x16 { mag: 11272, sign: false }); + data.append(FP16x16 { mag: 14123, sign: false }); + data.append(FP16x16 { mag: 28796, sign: false }); + data.append(FP16x16 { mag: 4279, sign: false }); + data.append(FP16x16 { mag: 22321, sign: false }); + data.append(FP16x16 { mag: 50620, sign: false }); + data.append(FP16x16 { mag: 25696, sign: false }); + data.append(FP16x16 { mag: 16652, sign: false }); + data.append(FP16x16 { mag: 38004, sign: false }); + data.append(FP16x16 { mag: 26620, sign: false }); + data.append(FP16x16 { mag: 14378, sign: false }); + data.append(FP16x16 { mag: 29314, sign: false }); + data.append(FP16x16 { mag: 30716, sign: false }); + data.append(FP16x16 { mag: 46589, sign: false }); + data.append(FP16x16 { mag: 61125, sign: false }); + data.append(FP16x16 { mag: 64323, sign: false }); 
+ data.append(FP16x16 { mag: 41418, sign: false }); + data.append(FP16x16 { mag: 11324, sign: false }); + data.append(FP16x16 { mag: 40101, sign: false }); + data.append(FP16x16 { mag: 20296, sign: false }); + data.append(FP16x16 { mag: 8408, sign: false }); + data.append(FP16x16 { mag: 32663, sign: false }); + data.append(FP16x16 { mag: 33213, sign: false }); + data.append(FP16x16 { mag: 28042, sign: false }); + data.append(FP16x16 { mag: 1133, sign: false }); + data.append(FP16x16 { mag: 28757, sign: false }); + data.append(FP16x16 { mag: 2818, sign: false }); + data.append(FP16x16 { mag: 30611, sign: false }); + data.append(FP16x16 { mag: 46655, sign: false }); + data.append(FP16x16 { mag: 6625, sign: false }); + data.append(FP16x16 { mag: 55554, sign: false }); + data.append(FP16x16 { mag: 30972, sign: false }); + data.append(FP16x16 { mag: 11645, sign: false }); + data.append(FP16x16 { mag: 65031, sign: false }); + data.append(FP16x16 { mag: 26489, sign: false }); + data.append(FP16x16 { mag: 12248, sign: false }); + data.append(FP16x16 { mag: 51085, sign: false }); + data.append(FP16x16 { mag: 65182, sign: false }); + data.append(FP16x16 { mag: 63497, sign: false }); + data.append(FP16x16 { mag: 8952, sign: false }); + data.append(FP16x16 { mag: 24058, sign: false }); + data.append(FP16x16 { mag: 45947, sign: false }); + data.append(FP16x16 { mag: 40855, sign: false }); + data.append(FP16x16 { mag: 64664, sign: false }); + data.append(FP16x16 { mag: 36601, sign: false }); + data.append(FP16x16 { mag: 45776, sign: false }); + data.append(FP16x16 { mag: 36759, sign: false }); + data.append(FP16x16 { mag: 57593, sign: false }); + data.append(FP16x16 { mag: 65064, sign: false }); + data.append(FP16x16 { mag: 37335, sign: false }); + data.append(FP16x16 { mag: 55777, sign: false }); + data.append(FP16x16 { mag: 43981, sign: false }); + data.append(FP16x16 { mag: 61643, sign: false }); + data.append(FP16x16 { mag: 57350, sign: false }); + data.append(FP16x16 { mag: 
49125, sign: false }); + data.append(FP16x16 { mag: 10813, sign: false }); + data.append(FP16x16 { mag: 6874, sign: false }); + data.append(FP16x16 { mag: 10217, sign: false }); + data.append(FP16x16 { mag: 16475, sign: false }); + data.append(FP16x16 { mag: 45953, sign: false }); + data.append(FP16x16 { mag: 26581, sign: false }); + data.append(FP16x16 { mag: 51635, sign: false }); + data.append(FP16x16 { mag: 22682, sign: false }); + data.append(FP16x16 { mag: 2719, sign: false }); + data.append(FP16x16 { mag: 19647, sign: false }); + data.append(FP16x16 { mag: 33384, sign: false }); + data.append(FP16x16 { mag: 24425, sign: false }); + data.append(FP16x16 { mag: 35926, sign: false }); + data.append(FP16x16 { mag: 3289, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/roi_align_aligned_false/input_1.cairo b/tests/nodes/roi_align_aligned_false/input_1.cairo new file mode 100644 index 000000000..58887c027 --- /dev/null +++ b/tests/nodes/roi_align_aligned_false/input_1.cairo @@ -0,0 +1,25 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::{FP16x16Tensor, FP16x16TensorAdd}; +use orion::numbers::{FixedTrait, FP16x16}; + +fn input_1() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(3); + shape.append(4); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 0, sign: false }); + data.append(FP16x16 { mag: 0, sign: false }); + data.append(FP16x16 { mag: 589824, sign: false }); + data.append(FP16x16 { mag: 589824, sign: false }); + data.append(FP16x16 { mag: 0, sign: false }); + data.append(FP16x16 { mag: 327680, sign: false }); + data.append(FP16x16 { mag: 262144, sign: false }); + data.append(FP16x16 { mag: 589824, sign: false }); + data.append(FP16x16 { mag: 327680, sign: false }); + data.append(FP16x16 { mag: 327680, sign: false }); + data.append(FP16x16 { mag: 589824, sign: false }); + data.append(FP16x16 { mag: 
589824, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/roi_align_aligned_false/output_0.cairo b/tests/nodes/roi_align_aligned_false/output_0.cairo new file mode 100644 index 000000000..951678f12 --- /dev/null +++ b/tests/nodes/roi_align_aligned_false/output_0.cairo @@ -0,0 +1,90 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::{FP16x16Tensor, FP16x16TensorAdd}; +use orion::numbers::{FixedTrait, FP16x16}; + +fn output_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(3); + shape.append(1); + shape.append(5); + shape.append(5); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 30567, sign: false }); + data.append(FP16x16 { mag: 29265, sign: false }); + data.append(FP16x16 { mag: 22316, sign: false }); + data.append(FP16x16 { mag: 37280, sign: false }); + data.append(FP16x16 { mag: 39765, sign: false }); + data.append(FP16x16 { mag: 24338, sign: false }); + data.append(FP16x16 { mag: 28152, sign: false }); + data.append(FP16x16 { mag: 25134, sign: false }); + data.append(FP16x16 { mag: 36453, sign: false }); + data.append(FP16x16 { mag: 23006, sign: false }); + data.append(FP16x16 { mag: 18140, sign: false }); + data.append(FP16x16 { mag: 32000, sign: false }); + data.append(FP16x16 { mag: 34222, sign: false }); + data.append(FP16x16 { mag: 36226, sign: false }); + data.append(FP16x16 { mag: 27332, sign: false }); + data.append(FP16x16 { mag: 30883, sign: false }); + data.append(FP16x16 { mag: 31746, sign: false }); + data.append(FP16x16 { mag: 45249, sign: false }); + data.append(FP16x16 { mag: 32246, sign: false }); + data.append(FP16x16 { mag: 57501, sign: false }); + data.append(FP16x16 { mag: 40887, sign: false }); + data.append(FP16x16 { mag: 46691, sign: false }); + data.append(FP16x16 { mag: 41217, sign: false }); + data.append(FP16x16 { mag: 21987, sign: false }); + data.append(FP16x16 { mag: 22902, 
sign: false }); + data.append(FP16x16 { mag: 19803, sign: false }); + data.append(FP16x16 { mag: 28210, sign: false }); + data.append(FP16x16 { mag: 30774, sign: false }); + data.append(FP16x16 { mag: 26066, sign: false }); + data.append(FP16x16 { mag: 35537, sign: false }); + data.append(FP16x16 { mag: 23956, sign: false }); + data.append(FP16x16 { mag: 46197, sign: false }); + data.append(FP16x16 { mag: 33848, sign: false }); + data.append(FP16x16 { mag: 20783, sign: false }); + data.append(FP16x16 { mag: 45969, sign: false }); + data.append(FP16x16 { mag: 19086, sign: false }); + data.append(FP16x16 { mag: 33154, sign: false }); + data.append(FP16x16 { mag: 42441, sign: false }); + data.append(FP16x16 { mag: 40861, sign: false }); + data.append(FP16x16 { mag: 54387, sign: false }); + data.append(FP16x16 { mag: 38769, sign: false }); + data.append(FP16x16 { mag: 48421, sign: false }); + data.append(FP16x16 { mag: 46191, sign: false }); + data.append(FP16x16 { mag: 54863, sign: false }); + data.append(FP16x16 { mag: 58282, sign: false }); + data.append(FP16x16 { mag: 40807, sign: false }); + data.append(FP16x16 { mag: 40322, sign: false }); + data.append(FP16x16 { mag: 46511, sign: false }); + data.append(FP16x16 { mag: 40327, sign: false }); + data.append(FP16x16 { mag: 30049, sign: false }); + data.append(FP16x16 { mag: 15627, sign: false }); + data.append(FP16x16 { mag: 22148, sign: false }); + data.append(FP16x16 { mag: 24359, sign: false }); + data.append(FP16x16 { mag: 39976, sign: false }); + data.append(FP16x16 { mag: 49811, sign: false }); + data.append(FP16x16 { mag: 24688, sign: false }); + data.append(FP16x16 { mag: 24807, sign: false }); + data.append(FP16x16 { mag: 46837, sign: false }); + data.append(FP16x16 { mag: 60575, sign: false }); + data.append(FP16x16 { mag: 63752, sign: false }); + data.append(FP16x16 { mag: 37676, sign: false }); + data.append(FP16x16 { mag: 38182, sign: false }); + data.append(FP16x16 { mag: 37416, sign: false }); + 
data.append(FP16x16 { mag: 49932, sign: false }); + data.append(FP16x16 { mag: 57474, sign: false }); + data.append(FP16x16 { mag: 35095, sign: false }); + data.append(FP16x16 { mag: 16815, sign: false }); + data.append(FP16x16 { mag: 14031, sign: false }); + data.append(FP16x16 { mag: 18324, sign: false }); + data.append(FP16x16 { mag: 23592, sign: false }); + data.append(FP16x16 { mag: 28605, sign: false }); + data.append(FP16x16 { mag: 22965, sign: false }); + data.append(FP16x16 { mag: 18923, sign: false }); + data.append(FP16x16 { mag: 23995, sign: false }); + data.append(FP16x16 { mag: 15395, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/roi_align_aligned_true.cairo b/tests/nodes/roi_align_aligned_true.cairo new file mode 100644 index 000000000..178e6949b --- /dev/null +++ b/tests/nodes/roi_align_aligned_true.cairo @@ -0,0 +1,36 @@ +mod input_0; +mod input_1; +mod output_0; + + +use orion::operators::tensor::FP16x16TensorPartialEq; +use orion::operators::nn::FP16x16NN; +use orion::numbers::FixedTrait; +use orion::utils::{assert_eq, assert_seq_eq}; +use orion::operators::nn::NNTrait; +use orion::operators::nn::functional::roi_align::TRANSFORMATION_MODE; +use orion::numbers::FP16x16; +use orion::operators::tensor::{TensorTrait, U32Tensor}; + + +#[test] +#[available_gas(2000000000)] +fn test_roi_align_aligned_true() { + let input_0 = input_0::input_0(); + let input_1 = input_1::input_1(); + let z_0 = output_0::output_0(); + + let y_0 = NNTrait::roi_align( + @input_0, + @input_1, + @TensorTrait::new(array![3].span(), array![0, 0, 0].span()), + Option::Some(TRANSFORMATION_MODE::HALF_PIXEL), + Option::None, + Option::Some(5), + Option::Some(5), + Option::Some(FP16x16 { mag: 131072, sign: false }), + Option::Some(FP16x16 { mag: 65536, sign: false }) + ); + + assert_eq(y_0, z_0); +} diff --git a/tests/nodes/roi_align_aligned_true/input_0.cairo b/tests/nodes/roi_align_aligned_true/input_0.cairo new file mode 100644 index 
000000000..91e6fffbd --- /dev/null +++ b/tests/nodes/roi_align_aligned_true/input_0.cairo @@ -0,0 +1,115 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::{FP16x16Tensor, FP16x16TensorAdd}; +use orion::numbers::{FixedTrait, FP16x16}; + +fn input_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(1); + shape.append(1); + shape.append(10); + shape.append(10); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 18114, sign: false }); + data.append(FP16x16 { mag: 46858, sign: false }); + data.append(FP16x16 { mag: 12831, sign: false }); + data.append(FP16x16 { mag: 22387, sign: false }); + data.append(FP16x16 { mag: 30395, sign: false }); + data.append(FP16x16 { mag: 1697, sign: false }); + data.append(FP16x16 { mag: 19418, sign: false }); + data.append(FP16x16 { mag: 42716, sign: false }); + data.append(FP16x16 { mag: 31824, sign: false }); + data.append(FP16x16 { mag: 47513, sign: false }); + data.append(FP16x16 { mag: 63157, sign: false }); + data.append(FP16x16 { mag: 5865, sign: false }); + data.append(FP16x16 { mag: 19129, sign: false }); + data.append(FP16x16 { mag: 44256, sign: false }); + data.append(FP16x16 { mag: 1533, sign: false }); + data.append(FP16x16 { mag: 40186, sign: false }); + data.append(FP16x16 { mag: 52985, sign: false }); + data.append(FP16x16 { mag: 34891, sign: false }); + data.append(FP16x16 { mag: 58929, sign: false }); + data.append(FP16x16 { mag: 29274, sign: false }); + data.append(FP16x16 { mag: 21397, sign: false }); + data.append(FP16x16 { mag: 55567, sign: false }); + data.append(FP16x16 { mag: 63556, sign: false }); + data.append(FP16x16 { mag: 16193, sign: false }); + data.append(FP16x16 { mag: 61184, sign: false }); + data.append(FP16x16 { mag: 12307, sign: false }); + data.append(FP16x16 { mag: 31234, sign: false }); + data.append(FP16x16 { mag: 28232, sign: false }); + data.append(FP16x16 { mag: 22282, sign: false 
}); + data.append(FP16x16 { mag: 14168, sign: false }); + data.append(FP16x16 { mag: 1350, sign: false }); + data.append(FP16x16 { mag: 11272, sign: false }); + data.append(FP16x16 { mag: 14123, sign: false }); + data.append(FP16x16 { mag: 28796, sign: false }); + data.append(FP16x16 { mag: 4279, sign: false }); + data.append(FP16x16 { mag: 22321, sign: false }); + data.append(FP16x16 { mag: 50620, sign: false }); + data.append(FP16x16 { mag: 25696, sign: false }); + data.append(FP16x16 { mag: 16652, sign: false }); + data.append(FP16x16 { mag: 38004, sign: false }); + data.append(FP16x16 { mag: 26620, sign: false }); + data.append(FP16x16 { mag: 14378, sign: false }); + data.append(FP16x16 { mag: 29314, sign: false }); + data.append(FP16x16 { mag: 30716, sign: false }); + data.append(FP16x16 { mag: 46589, sign: false }); + data.append(FP16x16 { mag: 61125, sign: false }); + data.append(FP16x16 { mag: 64323, sign: false }); + data.append(FP16x16 { mag: 41418, sign: false }); + data.append(FP16x16 { mag: 11324, sign: false }); + data.append(FP16x16 { mag: 40101, sign: false }); + data.append(FP16x16 { mag: 20296, sign: false }); + data.append(FP16x16 { mag: 8408, sign: false }); + data.append(FP16x16 { mag: 32663, sign: false }); + data.append(FP16x16 { mag: 33213, sign: false }); + data.append(FP16x16 { mag: 28042, sign: false }); + data.append(FP16x16 { mag: 1133, sign: false }); + data.append(FP16x16 { mag: 28757, sign: false }); + data.append(FP16x16 { mag: 2818, sign: false }); + data.append(FP16x16 { mag: 30611, sign: false }); + data.append(FP16x16 { mag: 46655, sign: false }); + data.append(FP16x16 { mag: 6625, sign: false }); + data.append(FP16x16 { mag: 55554, sign: false }); + data.append(FP16x16 { mag: 30972, sign: false }); + data.append(FP16x16 { mag: 11645, sign: false }); + data.append(FP16x16 { mag: 65031, sign: false }); + data.append(FP16x16 { mag: 26489, sign: false }); + data.append(FP16x16 { mag: 12248, sign: false }); + data.append(FP16x16 { 
mag: 51085, sign: false }); + data.append(FP16x16 { mag: 65182, sign: false }); + data.append(FP16x16 { mag: 63497, sign: false }); + data.append(FP16x16 { mag: 8952, sign: false }); + data.append(FP16x16 { mag: 24058, sign: false }); + data.append(FP16x16 { mag: 45947, sign: false }); + data.append(FP16x16 { mag: 40855, sign: false }); + data.append(FP16x16 { mag: 64664, sign: false }); + data.append(FP16x16 { mag: 36601, sign: false }); + data.append(FP16x16 { mag: 45776, sign: false }); + data.append(FP16x16 { mag: 36759, sign: false }); + data.append(FP16x16 { mag: 57593, sign: false }); + data.append(FP16x16 { mag: 65064, sign: false }); + data.append(FP16x16 { mag: 37335, sign: false }); + data.append(FP16x16 { mag: 55777, sign: false }); + data.append(FP16x16 { mag: 43981, sign: false }); + data.append(FP16x16 { mag: 61643, sign: false }); + data.append(FP16x16 { mag: 57350, sign: false }); + data.append(FP16x16 { mag: 49125, sign: false }); + data.append(FP16x16 { mag: 10813, sign: false }); + data.append(FP16x16 { mag: 6874, sign: false }); + data.append(FP16x16 { mag: 10217, sign: false }); + data.append(FP16x16 { mag: 16475, sign: false }); + data.append(FP16x16 { mag: 45953, sign: false }); + data.append(FP16x16 { mag: 26581, sign: false }); + data.append(FP16x16 { mag: 51635, sign: false }); + data.append(FP16x16 { mag: 22682, sign: false }); + data.append(FP16x16 { mag: 2719, sign: false }); + data.append(FP16x16 { mag: 19647, sign: false }); + data.append(FP16x16 { mag: 33384, sign: false }); + data.append(FP16x16 { mag: 24425, sign: false }); + data.append(FP16x16 { mag: 35926, sign: false }); + data.append(FP16x16 { mag: 3289, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/roi_align_aligned_true/input_1.cairo b/tests/nodes/roi_align_aligned_true/input_1.cairo new file mode 100644 index 000000000..58887c027 --- /dev/null +++ b/tests/nodes/roi_align_aligned_true/input_1.cairo @@ -0,0 +1,25 @@ +use 
core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::{FP16x16Tensor, FP16x16TensorAdd}; +use orion::numbers::{FixedTrait, FP16x16}; + +fn input_1() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(3); + shape.append(4); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 0, sign: false }); + data.append(FP16x16 { mag: 0, sign: false }); + data.append(FP16x16 { mag: 589824, sign: false }); + data.append(FP16x16 { mag: 589824, sign: false }); + data.append(FP16x16 { mag: 0, sign: false }); + data.append(FP16x16 { mag: 327680, sign: false }); + data.append(FP16x16 { mag: 262144, sign: false }); + data.append(FP16x16 { mag: 589824, sign: false }); + data.append(FP16x16 { mag: 327680, sign: false }); + data.append(FP16x16 { mag: 327680, sign: false }); + data.append(FP16x16 { mag: 589824, sign: false }); + data.append(FP16x16 { mag: 589824, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/roi_align_aligned_true/output_0.cairo b/tests/nodes/roi_align_aligned_true/output_0.cairo new file mode 100644 index 000000000..cb00c2000 --- /dev/null +++ b/tests/nodes/roi_align_aligned_true/output_0.cairo @@ -0,0 +1,90 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::{FP16x16Tensor, FP16x16TensorAdd}; +use orion::numbers::{FixedTrait, FP16x16}; + +fn output_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(3); + shape.append(1); + shape.append(5); + shape.append(5); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 33933, sign: false }); + data.append(FP16x16 { mag: 22505, sign: false }); + data.append(FP16x16 { mag: 21161, sign: false }); + data.append(FP16x16 { mag: 29318, sign: false }); + data.append(FP16x16 { mag: 41574, sign: false }); + data.append(FP16x16 { mag: 26416, sign: false }); + data.append(FP16x16 { mag: 
35169, sign: false }); + data.append(FP16x16 { mag: 29018, sign: false }); + data.append(FP16x16 { mag: 31859, sign: false }); + data.append(FP16x16 { mag: 26365, sign: false }); + data.append(FP16x16 { mag: 16462, sign: false }); + data.append(FP16x16 { mag: 26224, sign: false }); + data.append(FP16x16 { mag: 33785, sign: false }); + data.append(FP16x16 { mag: 45571, sign: false }); + data.append(FP16x16 { mag: 22710, sign: false }); + data.append(FP16x16 { mag: 21957, sign: false }); + data.append(FP16x16 { mag: 30153, sign: false }); + data.append(FP16x16 { mag: 38539, sign: false }); + data.append(FP16x16 { mag: 22535, sign: false }); + data.append(FP16x16 { mag: 44887, sign: false }); + data.append(FP16x16 { mag: 32321, sign: false }); + data.append(FP16x16 { mag: 46796, sign: false }); + data.append(FP16x16 { mag: 53853, sign: false }); + data.append(FP16x16 { mag: 30928, sign: false }); + data.append(FP16x16 { mag: 26473, sign: false }); + data.append(FP16x16 { mag: 20116, sign: false }); + data.append(FP16x16 { mag: 14331, sign: false }); + data.append(FP16x16 { mag: 21868, sign: false }); + data.append(FP16x16 { mag: 31981, sign: false }); + data.append(FP16x16 { mag: 31913, sign: false }); + data.append(FP16x16 { mag: 12261, sign: false }); + data.append(FP16x16 { mag: 32205, sign: false }); + data.append(FP16x16 { mag: 36445, sign: false }); + data.append(FP16x16 { mag: 27470, sign: false }); + data.append(FP16x16 { mag: 24157, sign: false }); + data.append(FP16x16 { mag: 9389, sign: false }); + data.append(FP16x16 { mag: 30201, sign: false }); + data.append(FP16x16 { mag: 39133, sign: false }); + data.append(FP16x16 { mag: 34796, sign: false }); + data.append(FP16x16 { mag: 32650, sign: false }); + data.append(FP16x16 { mag: 18272, sign: false }); + data.append(FP16x16 { mag: 28742, sign: false }); + data.append(FP16x16 { mag: 39465, sign: false }); + data.append(FP16x16 { mag: 45877, sign: false }); + data.append(FP16x16 { mag: 49311, sign: false }); + 
data.append(FP16x16 { mag: 37839, sign: false }); + data.append(FP16x16 { mag: 46031, sign: false }); + data.append(FP16x16 { mag: 47519, sign: false }); + data.append(FP16x16 { mag: 48087, sign: false }); + data.append(FP16x16 { mag: 53497, sign: false }); + data.append(FP16x16 { mag: 15684, sign: false }); + data.append(FP16x16 { mag: 26706, sign: false }); + data.append(FP16x16 { mag: 22144, sign: false }); + data.append(FP16x16 { mag: 16549, sign: false }); + data.append(FP16x16 { mag: 31086, sign: false }); + data.append(FP16x16 { mag: 24056, sign: false }); + data.append(FP16x16 { mag: 17705, sign: false }); + data.append(FP16x16 { mag: 26903, sign: false }); + data.append(FP16x16 { mag: 42066, sign: false }); + data.append(FP16x16 { mag: 54445, sign: false }); + data.append(FP16x16 { mag: 36414, sign: false }); + data.append(FP16x16 { mag: 29772, sign: false }); + data.append(FP16x16 { mag: 36467, sign: false }); + data.append(FP16x16 { mag: 49161, sign: false }); + data.append(FP16x16 { mag: 60948, sign: false }); + data.append(FP16x16 { mag: 43422, sign: false }); + data.append(FP16x16 { mag: 36809, sign: false }); + data.append(FP16x16 { mag: 31540, sign: false }); + data.append(FP16x16 { mag: 32469, sign: false }); + data.append(FP16x16 { mag: 43667, sign: false }); + data.append(FP16x16 { mag: 43487, sign: false }); + data.append(FP16x16 { mag: 24386, sign: false }); + data.append(FP16x16 { mag: 13474, sign: false }); + data.append(FP16x16 { mag: 12633, sign: false }); + data.append(FP16x16 { mag: 16243, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/roi_align_mode_max.cairo b/tests/nodes/roi_align_mode_max.cairo new file mode 100644 index 000000000..a9f7eb06a --- /dev/null +++ b/tests/nodes/roi_align_mode_max.cairo @@ -0,0 +1,36 @@ +mod input_0; +mod input_1; +mod output_0; + + +use orion::operators::tensor::FP16x16TensorPartialEq; +use orion::operators::nn::FP16x16NN; +use orion::numbers::FixedTrait; +use 
orion::utils::{assert_eq, assert_seq_eq}; +use orion::operators::nn::NNTrait; +use orion::operators::nn::functional::roi_align::TRANSFORMATION_MODE; +use orion::operators::nn::functional::roi_align::MODE; +use orion::numbers::FP16x16; +use orion::operators::tensor::{TensorTrait, U32Tensor}; + +#[test] +#[available_gas(2000000000)] +fn test_roi_align_mode_max() { + let input_0 = input_0::input_0(); + let input_1 = input_1::input_1(); + let z_0 = output_0::output_0(); + + let y_0 = NNTrait::roi_align( + @input_0, + @input_1, + @TensorTrait::new(array![3].span(), array![0, 0, 0].span()), + Option::Some(TRANSFORMATION_MODE::OUTPUT_HALF_PIXEL), + Option::Some(MODE::MAX), + Option::Some(5), + Option::Some(5), + Option::Some(FP16x16 { mag: 131072, sign: false }), + Option::Some(FP16x16 { mag: 65536, sign: false }) + ); + + assert_eq(y_0, z_0); +} diff --git a/tests/nodes/roi_align_mode_max/input_0.cairo b/tests/nodes/roi_align_mode_max/input_0.cairo new file mode 100644 index 000000000..91e6fffbd --- /dev/null +++ b/tests/nodes/roi_align_mode_max/input_0.cairo @@ -0,0 +1,115 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::{FP16x16Tensor, FP16x16TensorAdd}; +use orion::numbers::{FixedTrait, FP16x16}; + +fn input_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(1); + shape.append(1); + shape.append(10); + shape.append(10); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 18114, sign: false }); + data.append(FP16x16 { mag: 46858, sign: false }); + data.append(FP16x16 { mag: 12831, sign: false }); + data.append(FP16x16 { mag: 22387, sign: false }); + data.append(FP16x16 { mag: 30395, sign: false }); + data.append(FP16x16 { mag: 1697, sign: false }); + data.append(FP16x16 { mag: 19418, sign: false }); + data.append(FP16x16 { mag: 42716, sign: false }); + data.append(FP16x16 { mag: 31824, sign: false }); + data.append(FP16x16 { mag: 47513, sign: false }); + 
data.append(FP16x16 { mag: 63157, sign: false }); + data.append(FP16x16 { mag: 5865, sign: false }); + data.append(FP16x16 { mag: 19129, sign: false }); + data.append(FP16x16 { mag: 44256, sign: false }); + data.append(FP16x16 { mag: 1533, sign: false }); + data.append(FP16x16 { mag: 40186, sign: false }); + data.append(FP16x16 { mag: 52985, sign: false }); + data.append(FP16x16 { mag: 34891, sign: false }); + data.append(FP16x16 { mag: 58929, sign: false }); + data.append(FP16x16 { mag: 29274, sign: false }); + data.append(FP16x16 { mag: 21397, sign: false }); + data.append(FP16x16 { mag: 55567, sign: false }); + data.append(FP16x16 { mag: 63556, sign: false }); + data.append(FP16x16 { mag: 16193, sign: false }); + data.append(FP16x16 { mag: 61184, sign: false }); + data.append(FP16x16 { mag: 12307, sign: false }); + data.append(FP16x16 { mag: 31234, sign: false }); + data.append(FP16x16 { mag: 28232, sign: false }); + data.append(FP16x16 { mag: 22282, sign: false }); + data.append(FP16x16 { mag: 14168, sign: false }); + data.append(FP16x16 { mag: 1350, sign: false }); + data.append(FP16x16 { mag: 11272, sign: false }); + data.append(FP16x16 { mag: 14123, sign: false }); + data.append(FP16x16 { mag: 28796, sign: false }); + data.append(FP16x16 { mag: 4279, sign: false }); + data.append(FP16x16 { mag: 22321, sign: false }); + data.append(FP16x16 { mag: 50620, sign: false }); + data.append(FP16x16 { mag: 25696, sign: false }); + data.append(FP16x16 { mag: 16652, sign: false }); + data.append(FP16x16 { mag: 38004, sign: false }); + data.append(FP16x16 { mag: 26620, sign: false }); + data.append(FP16x16 { mag: 14378, sign: false }); + data.append(FP16x16 { mag: 29314, sign: false }); + data.append(FP16x16 { mag: 30716, sign: false }); + data.append(FP16x16 { mag: 46589, sign: false }); + data.append(FP16x16 { mag: 61125, sign: false }); + data.append(FP16x16 { mag: 64323, sign: false }); + data.append(FP16x16 { mag: 41418, sign: false }); + data.append(FP16x16 { mag: 
11324, sign: false }); + data.append(FP16x16 { mag: 40101, sign: false }); + data.append(FP16x16 { mag: 20296, sign: false }); + data.append(FP16x16 { mag: 8408, sign: false }); + data.append(FP16x16 { mag: 32663, sign: false }); + data.append(FP16x16 { mag: 33213, sign: false }); + data.append(FP16x16 { mag: 28042, sign: false }); + data.append(FP16x16 { mag: 1133, sign: false }); + data.append(FP16x16 { mag: 28757, sign: false }); + data.append(FP16x16 { mag: 2818, sign: false }); + data.append(FP16x16 { mag: 30611, sign: false }); + data.append(FP16x16 { mag: 46655, sign: false }); + data.append(FP16x16 { mag: 6625, sign: false }); + data.append(FP16x16 { mag: 55554, sign: false }); + data.append(FP16x16 { mag: 30972, sign: false }); + data.append(FP16x16 { mag: 11645, sign: false }); + data.append(FP16x16 { mag: 65031, sign: false }); + data.append(FP16x16 { mag: 26489, sign: false }); + data.append(FP16x16 { mag: 12248, sign: false }); + data.append(FP16x16 { mag: 51085, sign: false }); + data.append(FP16x16 { mag: 65182, sign: false }); + data.append(FP16x16 { mag: 63497, sign: false }); + data.append(FP16x16 { mag: 8952, sign: false }); + data.append(FP16x16 { mag: 24058, sign: false }); + data.append(FP16x16 { mag: 45947, sign: false }); + data.append(FP16x16 { mag: 40855, sign: false }); + data.append(FP16x16 { mag: 64664, sign: false }); + data.append(FP16x16 { mag: 36601, sign: false }); + data.append(FP16x16 { mag: 45776, sign: false }); + data.append(FP16x16 { mag: 36759, sign: false }); + data.append(FP16x16 { mag: 57593, sign: false }); + data.append(FP16x16 { mag: 65064, sign: false }); + data.append(FP16x16 { mag: 37335, sign: false }); + data.append(FP16x16 { mag: 55777, sign: false }); + data.append(FP16x16 { mag: 43981, sign: false }); + data.append(FP16x16 { mag: 61643, sign: false }); + data.append(FP16x16 { mag: 57350, sign: false }); + data.append(FP16x16 { mag: 49125, sign: false }); + data.append(FP16x16 { mag: 10813, sign: false }); + 
data.append(FP16x16 { mag: 6874, sign: false }); + data.append(FP16x16 { mag: 10217, sign: false }); + data.append(FP16x16 { mag: 16475, sign: false }); + data.append(FP16x16 { mag: 45953, sign: false }); + data.append(FP16x16 { mag: 26581, sign: false }); + data.append(FP16x16 { mag: 51635, sign: false }); + data.append(FP16x16 { mag: 22682, sign: false }); + data.append(FP16x16 { mag: 2719, sign: false }); + data.append(FP16x16 { mag: 19647, sign: false }); + data.append(FP16x16 { mag: 33384, sign: false }); + data.append(FP16x16 { mag: 24425, sign: false }); + data.append(FP16x16 { mag: 35926, sign: false }); + data.append(FP16x16 { mag: 3289, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/roi_align_mode_max/input_1.cairo b/tests/nodes/roi_align_mode_max/input_1.cairo new file mode 100644 index 000000000..58887c027 --- /dev/null +++ b/tests/nodes/roi_align_mode_max/input_1.cairo @@ -0,0 +1,25 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::{FP16x16Tensor, FP16x16TensorAdd}; +use orion::numbers::{FixedTrait, FP16x16}; + +fn input_1() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(3); + shape.append(4); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 0, sign: false }); + data.append(FP16x16 { mag: 0, sign: false }); + data.append(FP16x16 { mag: 589824, sign: false }); + data.append(FP16x16 { mag: 589824, sign: false }); + data.append(FP16x16 { mag: 0, sign: false }); + data.append(FP16x16 { mag: 327680, sign: false }); + data.append(FP16x16 { mag: 262144, sign: false }); + data.append(FP16x16 { mag: 589824, sign: false }); + data.append(FP16x16 { mag: 327680, sign: false }); + data.append(FP16x16 { mag: 327680, sign: false }); + data.append(FP16x16 { mag: 589824, sign: false }); + data.append(FP16x16 { mag: 589824, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} diff --git 
a/tests/nodes/roi_align_mode_max/output_0.cairo b/tests/nodes/roi_align_mode_max/output_0.cairo new file mode 100644 index 000000000..5c0f6623b --- /dev/null +++ b/tests/nodes/roi_align_mode_max/output_0.cairo @@ -0,0 +1,90 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::{FP16x16Tensor, FP16x16TensorAdd}; +use orion::numbers::{FixedTrait, FP16x16}; + +fn output_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(3); + shape.append(1); + shape.append(5); + shape.append(5); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 22578, sign: false }); + data.append(FP16x16 { mag: 24451, sign: false }); + data.append(FP16x16 { mag: 24815, sign: false }); + data.append(FP16x16 { mag: 29274, sign: false }); + data.append(FP16x16 { mag: 24897, sign: false }); + data.append(FP16x16 { mag: 27089, sign: false }); + data.append(FP16x16 { mag: 35750, sign: false }); + data.append(FP16x16 { mag: 43593, sign: false }); + data.append(FP16x16 { mag: 36572, sign: false }); + data.append(FP16x16 { mag: 17767, sign: false }); + data.append(FP16x16 { mag: 13909, sign: false }); + data.append(FP16x16 { mag: 26820, sign: false }); + data.append(FP16x16 { mag: 55165, sign: false }); + data.append(FP16x16 { mag: 51941, sign: false }); + data.append(FP16x16 { mag: 24377, sign: false }); + data.append(FP16x16 { mag: 30694, sign: false }); + data.append(FP16x16 { mag: 26045, sign: false }); + data.append(FP16x16 { mag: 52512, sign: false }); + data.append(FP16x16 { mag: 32566, sign: false }); + data.append(FP16x16 { mag: 36013, sign: false }); + data.append(FP16x16 { mag: 23566, sign: false }); + data.append(FP16x16 { mag: 34057, sign: false }); + data.append(FP16x16 { mag: 35413, sign: false }); + data.append(FP16x16 { mag: 15607, sign: false }); + data.append(FP16x16 { mag: 13102, sign: false }); + data.append(FP16x16 { mag: 19999, sign: false }); + data.append(FP16x16 { mag: 
33332, sign: false }); + data.append(FP16x16 { mag: 20904, sign: false }); + data.append(FP16x16 { mag: 26570, sign: false }); + data.append(FP16x16 { mag: 31215, sign: false }); + data.append(FP16x16 { mag: 33332, sign: false }); + data.append(FP16x16 { mag: 55554, sign: false }); + data.append(FP16x16 { mag: 24777, sign: false }); + data.append(FP16x16 { mag: 16342, sign: false }); + data.append(FP16x16 { mag: 52025, sign: false }); + data.append(FP16x16 { mag: 11547, sign: false }); + data.append(FP16x16 { mag: 19246, sign: false }); + data.append(FP16x16 { mag: 29406, sign: false }); + data.append(FP16x16 { mag: 32684, sign: false }); + data.append(FP16x16 { mag: 41385, sign: false }); + data.append(FP16x16 { mag: 33466, sign: false }); + data.append(FP16x16 { mag: 55777, sign: false }); + data.append(FP16x16 { mag: 35184, sign: false }); + data.append(FP16x16 { mag: 61643, sign: false }); + data.append(FP16x16 { mag: 45880, sign: false }); + data.append(FP16x16 { mag: 29410, sign: false }); + data.append(FP16x16 { mag: 33466, sign: false }); + data.append(FP16x16 { mag: 33046, sign: false }); + data.append(FP16x16 { mag: 36985, sign: false }); + data.append(FP16x16 { mag: 27528, sign: false }); + data.append(FP16x16 { mag: 13803, sign: false }); + data.append(FP16x16 { mag: 23005, sign: false }); + data.append(FP16x16 { mag: 24520, sign: false }); + data.append(FP16x16 { mag: 39109, sign: false }); + data.append(FP16x16 { mag: 30478, sign: false }); + data.append(FP16x16 { mag: 21191, sign: false }); + data.append(FP16x16 { mag: 20434, sign: false }); + data.append(FP16x16 { mag: 40868, sign: false }); + data.append(FP16x16 { mag: 65182, sign: false }); + data.append(FP16x16 { mag: 50798, sign: false }); + data.append(FP16x16 { mag: 23425, sign: false }); + data.append(FP16x16 { mag: 36621, sign: false }); + data.append(FP16x16 { mag: 23525, sign: false }); + data.append(FP16x16 { mag: 46074, sign: false }); + data.append(FP16x16 { mag: 41641, sign: false }); 
+ data.append(FP16x16 { mag: 39300, sign: false }); + data.append(FP16x16 { mag: 18310, sign: false }); + data.append(FP16x16 { mag: 11762, sign: false }); + data.append(FP16x16 { mag: 23037, sign: false }); + data.append(FP16x16 { mag: 20820, sign: false }); + data.append(FP16x16 { mag: 23580, sign: false }); + data.append(FP16x16 { mag: 26707, sign: false }); + data.append(FP16x16 { mag: 15632, sign: false }); + data.append(FP16x16 { mag: 28741, sign: false }); + data.append(FP16x16 { mag: 17244, sign: false }); + TensorTrait::new(shape.span(), data.span()) +}