This is an automated email from the ASF dual-hosted git repository. tqchen pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/tvm.git
commit b228037a291bbc5d97fbeadceb75efc8fb94f2c7 Author: Eric Lunderberg <[email protected]> AuthorDate: Wed Apr 5 09:31:21 2023 -0500 Expose attrs argument of "ir.IRModule" to Rust bindings --- rust/tvm/src/ir/module.rs | 16 +++++++++++----- src/ir/module.cc | 16 +++++++++++++++- 2 files changed, 26 insertions(+), 6 deletions(-) diff --git a/rust/tvm/src/ir/module.rs b/rust/tvm/src/ir/module.rs index 8f71a8be2c..4cdca826ec 100644 --- a/rust/tvm/src/ir/module.rs +++ b/rust/tvm/src/ir/module.rs @@ -28,7 +28,7 @@ use crate::runtime::array::Array; use crate::runtime::function::Result; use crate::runtime::map::Map; use crate::runtime::string::String as TVMString; -use crate::runtime::{external, IsObjectRef, Object}; +use crate::runtime::{external, IsObjectRef, Object, ObjectRef}; use super::expr::GlobalVar; use super::function::BaseFunc; @@ -62,7 +62,7 @@ external! { #[name("relay.parser.ParseExpr")] fn parse_expression(file_name: TVMString, source: TVMString) -> IRModule; #[name("ir.IRModule")] - fn module_new(funcs: Map<GlobalVar, BaseFunc>, types: Map<GlobalTypeVar, TypeData>) -> IRModule; + fn module_new(funcs: Map<GlobalVar, BaseFunc>, types: Map<GlobalTypeVar, TypeData>, attrs: Map<TVMString, ObjectRef>) -> IRModule; // Module methods #[name("ir.Module_Add")] fn module_add(module: IRModule, type_name: GlobalVar, expr: BaseFunc, update: bool) -> IRModule; @@ -99,18 +99,24 @@ external! { // Note: we don't expose update here as update is going to be removed. impl IRModule { - pub fn new<'a, F, T>(funcs: F, types: T) -> Result<IRModule> + pub fn new<'a, F, T, A>(funcs: F, types: T, attrs: A) -> Result<IRModule> where F: IntoIterator<Item = (&'a GlobalVar, &'a BaseFunc)>, T: IntoIterator<Item = (&'a GlobalTypeVar, &'a TypeData)>, + A: IntoIterator<Item = (&'a TVMString, &'a ObjectRef)>, { - module_new(Map::from_iter(funcs), Map::from_iter(types)) + module_new( + Map::from_iter(funcs), + Map::from_iter(types), + Map::from_iter(attrs), + ) } pub fn empty() -> Result<IRModule> { let funcs = HashMap::<GlobalVar, BaseFunc>::new(); let types = HashMap::<GlobalTypeVar, TypeData>::new(); - IRModule::new(funcs.iter(), types.iter()) + let attrs = HashMap::<TVMString, ObjectRef>::new(); + IRModule::new(funcs.iter(), types.iter(), attrs.iter()) } pub fn parse<N, S>(file_name: N, source: S) -> Result<IRModule> diff --git a/src/ir/module.cc b/src/ir/module.cc index ba66a66894..77316f55ed 100644 --- a/src/ir/module.cc +++ b/src/ir/module.cc @@ -383,7 +383,21 @@ TVM_REGISTER_NODE_TYPE(IRModuleNode); TVM_REGISTER_GLOBAL("ir.IRModule") .set_body_typed([](tvm::Map<GlobalVar, BaseFunc> funcs, tvm::Map<GlobalTypeVar, TypeData> types, - tvm::DictAttrs attrs) { return IRModule(funcs, types, {}, {}, attrs); }); + tvm::ObjectRef attrs) { + auto dict_attrs = [&attrs]() { + if (!attrs.defined()) { + return DictAttrs(); + } else if (auto* as_dict_attrs = attrs.as<tvm::DictAttrsNode>()) { + return GetRef<tvm::DictAttrs>(as_dict_attrs); + } else if (attrs.as<tvm::MapNode>()) { + return tvm::DictAttrs(Downcast<Map<String, ObjectRef>>(attrs)); + } else { + LOG(FATAL) << "Expected attrs argument to be either DictAttrs or Map<String,ObjectRef>"; + } + }(); + + return IRModule(funcs, types, {}, {}, dict_attrs); + }); TVM_REGISTER_GLOBAL("ir.Module_Add") .set_body_typed([](IRModule mod, GlobalVar var, ObjectRef val, bool update) -> IRModule {
