xqdan commented on a change in pull request #8079:
URL: https://github.com/apache/tvm/pull/8079#discussion_r638529020
##########
File path: src/ir/op.cc
##########
@@ -102,10 +102,61 @@
TVM_REGISTER_GLOBAL("ir.OpResetAttr").set_body_typed([](Op op, String attr_name)
reg.reset_attr(attr_name);
});
-TVM_REGISTER_GLOBAL("ir.RegisterOp").set_body_typed([](String op_name) {
- const OpRegEntry* reg = OpRegistry::Global()->Get(op_name);
- ICHECK(reg == nullptr) << "AttributeError: Operator " << op_name << " is
registered before";
- OpRegistry::Global()->RegisterOrGet(op_name).set_name();
+TVM_REGISTER_GLOBAL("ir.RegisterOp").set_body_typed([](String op_name, String
descr) {
+ auto& reg = OpRegistry::Global()->RegisterOrGet(op_name).set_name();
+ reg.describe(descr);
+});
+
+TVM_REGISTER_GLOBAL("ir.OpAddTypeRel")
+ .set_body_typed([](Op op, String rel_name, runtime::TVMArgValue value) {
+ auto& reg = OpRegistry::Global()->RegisterOrGet(op->name).set_name();
+ if (value.type_code() == kTVMPackedFuncHandle) {
+ // do an eager copy of the PackedFunc to avoid deleting function from
frontend.
+ PackedFunc* fcopy = new PackedFunc(value.operator
tvm::runtime::PackedFunc());
+ auto f = [=](const Array<Type>& args, int num_inputs, const Attrs&
attrs,
+ const TypeReporter& reporter) -> bool {
+ Array<Type> input_types(args.begin(), args.end() - 1);
+ // call customized relation functions
+ Type ret_type = (*fcopy)(input_types, attrs);
+ if (ret_type.defined()) {
+ // the last argument is output
+ reporter->Assign(args[args.size() - 1], ret_type);
+ return true;
+ }
+ return false;
+ };
+ // adjust function call to relay type system with TypeReporter
+ auto type_rel = runtime::TypedPackedFunc<bool(const Array<Type>&, int,
const Attrs&,
+ const TypeReporter&)>(f);
+ reg.add_type_rel(rel_name, type_rel);
+ } else if (value.type_code() == kTVMNullptr) {
+ // Call relation functions of relay
+ auto func_name = std::string("tvm.relay.type_relation.") + rel_name;
Review comment:
I can add a case for (1): a custom type relation goes through the `if (value.type_code()
== kTVMPackedFuncHandle)` branch, so there is no need to look it up by
`std::string("tvm.relay.type_relation.") + rel_name`.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]