This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 959669c [Rust] More Rust bindings for Attrs (#7082)
959669c is described below
commit 959669c648d63e6d71df3849bad2293a8866a3f9
Author: Andrew Liu <[email protected]>
AuthorDate: Sat Dec 26 06:15:56 2020 -0800
[Rust] More Rust bindings for Attrs (#7082)
---
include/tvm/relay/attrs/nn.h | 6 ++--
rust/tvm/src/ir/relay/attrs/nn.rs | 36 ++++++++++++++++++++++
rust/tvm/src/ir/relay/attrs/transform.rs | 52 ++++++++++++++++++++++++++++++++
3 files changed, 91 insertions(+), 3 deletions(-)
diff --git a/include/tvm/relay/attrs/nn.h b/include/tvm/relay/attrs/nn.h
index 4d867be..c3c58e5 100644
--- a/include/tvm/relay/attrs/nn.h
+++ b/include/tvm/relay/attrs/nn.h
@@ -717,7 +717,7 @@ struct AvgPool2DAttrs : public tvm::AttrsNode<AvgPool2DAttrs> {
Array<IndexExpr> pool_size;
Array<IndexExpr> strides;
Array<IndexExpr> padding;
- std::string layout;
+ tvm::String layout;
bool ceil_mode;
bool count_include_pad;
@@ -977,8 +977,8 @@ struct FIFOBufferAttrs : public tvm::AttrsNode<FIFOBufferAttrs> {
struct UpSamplingAttrs : public tvm::AttrsNode<UpSamplingAttrs> {
double scale_h;
double scale_w;
- std::string layout;
- std::string method;
+ tvm::String layout;
+ tvm::String method;
bool align_corners;
TVM_DECLARE_ATTRS(UpSamplingAttrs, "relay.attrs.UpSamplingAttrs") {
diff --git a/rust/tvm/src/ir/relay/attrs/nn.rs b/rust/tvm/src/ir/relay/attrs/nn.rs
index 7ecd92f..f0137fa 100644
--- a/rust/tvm/src/ir/relay/attrs/nn.rs
+++ b/rust/tvm/src/ir/relay/attrs/nn.rs
@@ -106,3 +106,39 @@ pub struct BatchNormAttrsNode {
pub center: bool,
pub scale: bool,
}
+
+/// Rust view of the attrs node registered under the type key
+/// `relay.attrs.LeakyReluAttrs`.
+///
+/// `#[repr(C)]` fixes the field layout so it can mirror the C++ struct; the
+/// field order and types below must match that declaration (not shown in
+/// this file — confirm against `include/tvm/relay/attrs/nn.h`).
+#[repr(C)]
+#[derive(Object, Debug)]
+#[ref_name = "LeakyReluAttrs"]
+#[type_key = "relay.attrs.LeakyReluAttrs"]
+pub struct LeakyReluAttrsNode {
+ /// Attrs-node header; kept as the first field so the layout lines up with the C++ base class.
+ pub base: BaseAttrsNode,
+ /// Presumably the negative-slope coefficient, mirroring a C++ `double` — TODO confirm.
+ pub alpha: f64,
+}
+
+/// Rust view of the attrs node registered under the type key
+/// `relay.attrs.AvgPool2DAttrs`.
+///
+/// `#[repr(C)]` fixes the field layout. The fields must appear in the same
+/// order, with layout-compatible types, as in the C++ `AvgPool2DAttrs` struct:
+/// `pool_size`, `strides`, `padding`, `layout`, `ceil_mode`, `count_include_pad`.
+#[repr(C)]
+#[derive(Object, Debug)]
+#[ref_name = "AvgPool2DAttrs"]
+#[type_key = "relay.attrs.AvgPool2DAttrs"]
+pub struct AvgPool2DAttrsNode {
+ /// Attrs-node header; kept as the first field so the layout lines up with the C++ base class.
+ pub base: BaseAttrsNode,
+ pub pool_size: Array<IndexExpr>,
+ pub strides: Array<IndexExpr>,
+ pub padding: Array<IndexExpr>,
+ /// The C++ struct declares this as `tvm::String` (an object reference), which
+ /// matches `TString`; a `std::string` member would not be layout-compatible.
+ pub layout: TString,
+ pub ceil_mode: bool,
+ pub count_include_pad: bool,
+}
+
+/// Rust view of the attrs node registered under the type key
+/// `relay.attrs.UpSamplingAttrs`.
+///
+/// `#[repr(C)]` fixes the field layout. The fields must appear in the same
+/// order, with layout-compatible types, as in the C++ `UpSamplingAttrs` struct:
+/// `scale_h`, `scale_w` (C++ `double`), `layout`, `method` (C++ `tvm::String`),
+/// `align_corners`.
+#[repr(C)]
+#[derive(Object, Debug)]
+#[ref_name = "UpSamplingAttrs"]
+#[type_key = "relay.attrs.UpSamplingAttrs"]
+pub struct UpSamplingAttrsNode {
+ /// Attrs-node header; kept as the first field so the layout lines up with the C++ base class.
+ pub base: BaseAttrsNode,
+ pub scale_h: f64,
+ pub scale_w: f64,
+ /// Declared as `tvm::String` in C++ so it can be read as `TString` here.
+ pub layout: TString,
+ /// Declared as `tvm::String` in C++ so it can be read as `TString` here.
+ pub method: TString,
+ pub align_corners: bool,
+}
diff --git a/rust/tvm/src/ir/relay/attrs/transform.rs b/rust/tvm/src/ir/relay/attrs/transform.rs
index c459f96..b5f7c20 100644
--- a/rust/tvm/src/ir/relay/attrs/transform.rs
+++ b/rust/tvm/src/ir/relay/attrs/transform.rs
@@ -18,8 +18,13 @@
*/
use crate::ir::attrs::BaseAttrsNode;
+use crate::ir::PrimExpr;
+use crate::runtime::array::Array;
+use crate::runtime::ObjectRef;
use tvm_macros::Object;
+/// Local alias so field types read like the C++ attrs declarations
+/// (`Array<IndexExpr>`); on the Rust side it is simply `PrimExpr`.
+type IndexExpr = PrimExpr;
+
#[repr(C)]
#[derive(Object, Debug)]
#[ref_name = "ExpandDimsAttrs"]
@@ -29,3 +34,50 @@ pub struct ExpandDimsAttrsNode {
pub axis: i32,
pub num_newaxis: i32,
}
+
+/// Rust view of the attrs node registered under the type key
+/// `relay.attrs.ConcatenateAttrs`.
+///
+/// `#[repr(C)]` fixes the field layout; field order and types must mirror the
+/// C++ struct (confirm against `include/tvm/relay/attrs/transform.h`).
+#[repr(C)]
+#[derive(Object, Debug)]
+#[ref_name = "ConcatenateAttrs"]
+#[type_key = "relay.attrs.ConcatenateAttrs"]
+pub struct ConcatenateAttrsNode {
+ /// Attrs-node header; kept as the first field so the layout lines up with the C++ base class.
+ pub base: BaseAttrsNode,
+ /// Presumably mirrors a C++ `int` axis — TODO confirm width matches `i32`.
+ pub axis: i32,
+}
+
+/// Rust view of the attrs node registered under the type key
+/// `relay.attrs.ReshapeAttrs`.
+///
+/// `#[repr(C)]` fixes the field layout; field order and types must mirror the
+/// C++ struct (confirm against `include/tvm/relay/attrs/transform.h`).
+#[repr(C)]
+#[derive(Object, Debug)]
+#[ref_name = "ReshapeAttrs"]
+#[type_key = "relay.attrs.ReshapeAttrs"]
+pub struct ReshapeAttrsNode {
+ /// Attrs-node header; kept as the first field so the layout lines up with the C++ base class.
+ pub base: BaseAttrsNode,
+ /// NOTE(review): elements are typed as `IndexExpr` here; the C++ element type
+ /// may be narrower (e.g. an integer node) — confirm before relying on the type.
+ pub newshape: Array<IndexExpr>,
+ pub reverse: bool,
+}
+
+/// Rust view of the attrs node registered under the type key
+/// `relay.attrs.SplitAttrs`.
+///
+/// `#[repr(C)]` fixes the field layout; field order and types must mirror the
+/// C++ struct (confirm against `include/tvm/relay/attrs/transform.h`).
+#[repr(C)]
+#[derive(Object, Debug)]
+#[ref_name = "SplitAttrs"]
+#[type_key = "relay.attrs.SplitAttrs"]
+pub struct SplitAttrsNode {
+ /// Attrs-node header; kept as the first field so the layout lines up with the C++ base class.
+ pub base: BaseAttrsNode,
+ /// Untyped `ObjectRef`, presumably because the value may be either a section
+ /// count or an array of split indices — downcast to the concrete type before use.
+ pub indices_or_sections: ObjectRef,
+ pub axis: i32,
+}
+
+/// Rust view of the attrs node registered under the type key
+/// `relay.attrs.TransposeAttrs`.
+///
+/// `#[repr(C)]` fixes the field layout; field order and types must mirror the
+/// C++ struct (confirm against `include/tvm/relay/attrs/transform.h`).
+#[repr(C)]
+#[derive(Object, Debug)]
+#[ref_name = "TransposeAttrs"]
+#[type_key = "relay.attrs.TransposeAttrs"]
+pub struct TransposeAttrsNode {
+ /// Attrs-node header; kept as the first field so the layout lines up with the C++ base class.
+ pub base: BaseAttrsNode,
+ /// NOTE(review): this may be a null reference when no axes were supplied on the
+ /// C++ side — confirm before indexing into it.
+ pub axes: Array<IndexExpr>,
+}
+
+/// Rust view of the attrs node registered under the type key
+/// `relay.attrs.SqueezeAttrs`.
+///
+/// `#[repr(C)]` fixes the field layout; field order and types must mirror the
+/// C++ struct (confirm against `include/tvm/relay/attrs/transform.h`).
+#[repr(C)]
+#[derive(Object, Debug)]
+#[ref_name = "SqueezeAttrs"]
+#[type_key = "relay.attrs.SqueezeAttrs"]
+pub struct SqueezeAttrsNode {
+ /// Attrs-node header; kept as the first field so the layout lines up with the C++ base class.
+ pub base: BaseAttrsNode,
+ /// NOTE(review): this may be a null reference when no axis was supplied on the
+ /// C++ side — confirm before indexing into it.
+ pub axis: Array<IndexExpr>,
+}