This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 959669c  [Rust] More Rust bindings for Attrs (#7082)
959669c is described below

commit 959669c648d63e6d71df3849bad2293a8866a3f9
Author: Andrew Liu <[email protected]>
AuthorDate: Sat Dec 26 06:15:56 2020 -0800

    [Rust] More Rust bindings for Attrs (#7082)
---
 include/tvm/relay/attrs/nn.h             |  6 ++--
 rust/tvm/src/ir/relay/attrs/nn.rs        | 36 ++++++++++++++++++++++
 rust/tvm/src/ir/relay/attrs/transform.rs | 52 ++++++++++++++++++++++++++++++++
 3 files changed, 91 insertions(+), 3 deletions(-)

diff --git a/include/tvm/relay/attrs/nn.h b/include/tvm/relay/attrs/nn.h
index 4d867be..c3c58e5 100644
--- a/include/tvm/relay/attrs/nn.h
+++ b/include/tvm/relay/attrs/nn.h
@@ -717,7 +717,7 @@ struct AvgPool2DAttrs : public tvm::AttrsNode<AvgPool2DAttrs> {
   Array<IndexExpr> pool_size;
   Array<IndexExpr> strides;
   Array<IndexExpr> padding;
-  std::string layout;
+  tvm::String layout;
   bool ceil_mode;
   bool count_include_pad;
 
@@ -977,8 +977,8 @@ struct FIFOBufferAttrs : public tvm::AttrsNode<FIFOBufferAttrs> {
 struct UpSamplingAttrs : public tvm::AttrsNode<UpSamplingAttrs> {
   double scale_h;
   double scale_w;
-  std::string layout;
-  std::string method;
+  tvm::String layout;
+  tvm::String method;
   bool align_corners;
 
   TVM_DECLARE_ATTRS(UpSamplingAttrs, "relay.attrs.UpSamplingAttrs") {
diff --git a/rust/tvm/src/ir/relay/attrs/nn.rs b/rust/tvm/src/ir/relay/attrs/nn.rs
index 7ecd92f..f0137fa 100644
--- a/rust/tvm/src/ir/relay/attrs/nn.rs
+++ b/rust/tvm/src/ir/relay/attrs/nn.rs
@@ -106,3 +106,39 @@ pub struct BatchNormAttrsNode {
     pub center: bool,
     pub scale: bool,
 }
+
+#[repr(C)]
+#[derive(Object, Debug)]
+#[ref_name = "LeakyReluAttrs"]
+#[type_key = "relay.attrs.LeakyReluAttrs"]
+pub struct LeakyReluAttrsNode {
+    pub base: BaseAttrsNode,
+    pub alpha: f64,
+}
+
+#[repr(C)]
+#[derive(Object, Debug)]
+#[ref_name = "AvgPool2DAttrs"]
+#[type_key = "relay.attrs.AvgPool2DAttrs"]
+pub struct AvgPool2DAttrsNode {
+    pub base: BaseAttrsNode,
+    pub pool_size: Array<IndexExpr>,
+    pub strides: Array<IndexExpr>,
+    pub padding: Array<IndexExpr>,
+    pub layout: TString,
+    pub ceil_mode: bool,
+    pub count_include_pad: bool,
+}
+
+#[repr(C)]
+#[derive(Object, Debug)]
+#[ref_name = "UpSamplingAttrs"]
+#[type_key = "relay.attrs.UpSamplingAttrs"]
+pub struct UpSamplingAttrsNode {
+    pub base: BaseAttrsNode,
+    pub scale_h: f64,
+    pub scale_w: f64,
+    pub layout: TString,
+    pub method: TString,
+    pub align_corners: bool,
+}
diff --git a/rust/tvm/src/ir/relay/attrs/transform.rs b/rust/tvm/src/ir/relay/attrs/transform.rs
index c459f96..b5f7c20 100644
--- a/rust/tvm/src/ir/relay/attrs/transform.rs
+++ b/rust/tvm/src/ir/relay/attrs/transform.rs
@@ -18,8 +18,13 @@
  */
 
 use crate::ir::attrs::BaseAttrsNode;
+use crate::ir::PrimExpr;
+use crate::runtime::array::Array;
+use crate::runtime::ObjectRef;
 use tvm_macros::Object;
 
+type IndexExpr = PrimExpr;
+
 #[repr(C)]
 #[derive(Object, Debug)]
 #[ref_name = "ExpandDimsAttrs"]
@@ -29,3 +34,50 @@ pub struct ExpandDimsAttrsNode {
     pub axis: i32,
     pub num_newaxis: i32,
 }
+
+#[repr(C)]
+#[derive(Object, Debug)]
+#[ref_name = "ConcatenateAttrs"]
+#[type_key = "relay.attrs.ConcatenateAttrs"]
+pub struct ConcatenateAttrsNode {
+    pub base: BaseAttrsNode,
+    pub axis: i32,
+}
+
+#[repr(C)]
+#[derive(Object, Debug)]
+#[ref_name = "ReshapeAttrs"]
+#[type_key = "relay.attrs.ReshapeAttrs"]
+pub struct ReshapeAttrsNode {
+    pub base: BaseAttrsNode,
+    pub newshape: Array<IndexExpr>,
+    pub reverse: bool,
+}
+
+#[repr(C)]
+#[derive(Object, Debug)]
+#[ref_name = "SplitAttrs"]
+#[type_key = "relay.attrs.SplitAttrs"]
+pub struct SplitAttrsNode {
+    pub base: BaseAttrsNode,
+    pub indices_or_sections: ObjectRef,
+    pub axis: i32,
+}
+
+#[repr(C)]
+#[derive(Object, Debug)]
+#[ref_name = "TransposeAttrs"]
+#[type_key = "relay.attrs.TransposeAttrs"]
+pub struct TransposeAttrsNode {
+    pub base: BaseAttrsNode,
+    pub axes: Array<IndexExpr>,
+}
+
+#[repr(C)]
+#[derive(Object, Debug)]
+#[ref_name = "SqueezeAttrs"]
+#[type_key = "relay.attrs.SqueezeAttrs"]
+pub struct SqueezeAttrsNode {
+    pub base: BaseAttrsNode,
+    pub axis: Array<IndexExpr>,
+}

Reply via email to