This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git

commit b228037a291bbc5d97fbeadceb75efc8fb94f2c7
Author: Eric Lunderberg <[email protected]>
AuthorDate: Wed Apr 5 09:31:21 2023 -0500

    Expose attrs argument of "ir.IRModule" to Rust bindings
---
 rust/tvm/src/ir/module.rs | 16 +++++++++++-----
 src/ir/module.cc          | 16 +++++++++++++++-
 2 files changed, 26 insertions(+), 6 deletions(-)

diff --git a/rust/tvm/src/ir/module.rs b/rust/tvm/src/ir/module.rs
index 8f71a8be2c..4cdca826ec 100644
--- a/rust/tvm/src/ir/module.rs
+++ b/rust/tvm/src/ir/module.rs
@@ -28,7 +28,7 @@ use crate::runtime::array::Array;
 use crate::runtime::function::Result;
 use crate::runtime::map::Map;
 use crate::runtime::string::String as TVMString;
-use crate::runtime::{external, IsObjectRef, Object};
+use crate::runtime::{external, IsObjectRef, Object, ObjectRef};
 
 use super::expr::GlobalVar;
 use super::function::BaseFunc;
@@ -62,7 +62,7 @@ external! {
     #[name("relay.parser.ParseExpr")]
     fn parse_expression(file_name: TVMString, source: TVMString) -> IRModule;
     #[name("ir.IRModule")]
-    fn module_new(funcs: Map<GlobalVar, BaseFunc>, types: Map<GlobalTypeVar, TypeData>) -> IRModule;
+    fn module_new(funcs: Map<GlobalVar, BaseFunc>, types: Map<GlobalTypeVar, TypeData>, attrs: Map<TVMString, ObjectRef>) -> IRModule;
     // Module methods
     #[name("ir.Module_Add")]
     fn module_add(module: IRModule, type_name: GlobalVar, expr: BaseFunc, update: bool) -> IRModule;
@@ -99,18 +99,24 @@ external! {
 // Note: we don't expose update here as update is going to be removed.
 
 impl IRModule {
-    pub fn new<'a, F, T>(funcs: F, types: T) -> Result<IRModule>
+    pub fn new<'a, F, T, A>(funcs: F, types: T, attrs: A) -> Result<IRModule>
     where
         F: IntoIterator<Item = (&'a GlobalVar, &'a BaseFunc)>,
         T: IntoIterator<Item = (&'a GlobalTypeVar, &'a TypeData)>,
+        A: IntoIterator<Item = (&'a TVMString, &'a ObjectRef)>,
     {
-        module_new(Map::from_iter(funcs), Map::from_iter(types))
+        module_new(
+            Map::from_iter(funcs),
+            Map::from_iter(types),
+            Map::from_iter(attrs),
+        )
     }
 
     pub fn empty() -> Result<IRModule> {
         let funcs = HashMap::<GlobalVar, BaseFunc>::new();
         let types = HashMap::<GlobalTypeVar, TypeData>::new();
-        IRModule::new(funcs.iter(), types.iter())
+        let attrs = HashMap::<TVMString, ObjectRef>::new();
+        IRModule::new(funcs.iter(), types.iter(), attrs.iter())
     }
 
     pub fn parse<N, S>(file_name: N, source: S) -> Result<IRModule>
diff --git a/src/ir/module.cc b/src/ir/module.cc
index ba66a66894..77316f55ed 100644
--- a/src/ir/module.cc
+++ b/src/ir/module.cc
@@ -383,7 +383,21 @@ TVM_REGISTER_NODE_TYPE(IRModuleNode);
 
 TVM_REGISTER_GLOBAL("ir.IRModule")
     .set_body_typed([](tvm::Map<GlobalVar, BaseFunc> funcs, tvm::Map<GlobalTypeVar, TypeData> types,
-                       tvm::DictAttrs attrs) { return IRModule(funcs, types, {}, {}, attrs); });
+                       tvm::ObjectRef attrs) {
+      auto dict_attrs = [&attrs]() {
+        if (!attrs.defined()) {
+          return DictAttrs();
+        } else if (auto* as_dict_attrs = attrs.as<tvm::DictAttrsNode>()) {
+          return GetRef<tvm::DictAttrs>(as_dict_attrs);
+        } else if (attrs.as<tvm::MapNode>()) {
+          return tvm::DictAttrs(Downcast<Map<String, ObjectRef>>(attrs));
+        } else {
+          LOG(FATAL) << "Expected attrs argument to be either DictAttrs or Map<String,ObjectRef>";
+        }
+      }();
+
+      return IRModule(funcs, types, {}, {}, dict_attrs);
+    });
 
 TVM_REGISTER_GLOBAL("ir.Module_Add")
     .set_body_typed([](IRModule mod, GlobalVar var, ObjectRef val, bool update) -> IRModule {

Reply via email to