zackcquic commented on a change in pull request #8079:
URL: https://github.com/apache/tvm/pull/8079#discussion_r637993294
##########
File path: python/tvm/ir/op.py
##########
@@ -85,17 +85,76 @@ def reset_attr(self, attr_name):
"""
_ffi_api.OpResetAttr(self, attr_name)
+ def add_type_rel(self, rel_name, type_rel_func=None):
Review comment:
Since we are now in the Python world, what will happen if ```type_rel_func```
raises an exception?
Could you add a test case for this?
##########
File path: python/tvm/ir/op.py
##########
@@ -85,17 +85,76 @@ def reset_attr(self, attr_name):
"""
_ffi_api.OpResetAttr(self, attr_name)
+ def add_type_rel(self, rel_name, type_rel_func=None):
+ """Attach the type function corresponding to the return type.
-def register_op(op_name):
+ Parameters
+ ----------
+ rel_name : str
+ The type relation name to register.
+
+ type_rel_func: function (args: List[Type], num_inputs:int, attrs: Attrs) -> Type
Review comment:
From the test case you provided and ```ir.OpAddTypeRel```, I think the
function signature should be ```function (args: List[Type], attrs) -> Type```,
where the first argument contains only the input types and the return value
has a special meaning (```None``` vs. a defined ```Type```).
The callback function is quite different from ```type_rel_func``` in C++, so it
would be great to document it here.
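To make the proposed contract concrete, here is a hedged sketch of a user callback written against ```(args: List[Type], attrs) -> Type```. It runs without TVM: ```(shape, dtype)``` tuples stand in for ```TensorType```, ```None``` inputs stand in for an ```IncompleteType```, and ```elementwise_rel``` is a hypothetical name. Returning ```None``` means "cannot be solved yet"; a defined type is what the C++ wrapper assigns to the output.

```python
from typing import List, Optional, Tuple

# Stand-in for a TensorType: (shape, dtype).
TensorTy = Tuple[Tuple[int, ...], str]


def elementwise_rel(input_types: List[Optional[TensorTy]], attrs: object) -> Optional[TensorTy]:
    """Hypothetical type relation for a binary elementwise op.

    input_types: types of the inputs only; the output slot is not included.
                 None stands in for a not-yet-inferred (incomplete) type.
    attrs: the call's attributes (unused here).
    Returns the output type, or None if it cannot be determined yet.
    """
    if len(input_types) != 2:
        raise ValueError("expected 2 input types, got %d" % len(input_types))
    lhs, rhs = input_types
    if lhs is None or rhs is None:
        return None  # not enough information yet
    if lhs != rhs:
        raise TypeError("mismatched input types: %r vs %r" % (lhs, rhs))
    return lhs
```

Note that the callback never sees the output slot, unlike the C++ ```type_rel_func```, which receives it as the last element of ```args```.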
##########
File path: include/tvm/ir/op.h
##########
@@ -250,6 +250,12 @@ class OpRegEntry {
*/
template <typename AttrsType>
inline OpRegEntry& set_attrs_type();
+ /*!
+ * \brief Set the the attrs type key and index to be AttrsType.
+ * \tparam key the attribute type key to be set.
Review comment:
It would be great to capitalize the first word after the parameter name (the
attribute -> The attribute), to be consistent with other comments.
##########
File path: src/ir/op.cc
##########
@@ -102,10 +102,61 @@
TVM_REGISTER_GLOBAL("ir.OpResetAttr").set_body_typed([](Op op, String attr_name)
reg.reset_attr(attr_name);
});
-TVM_REGISTER_GLOBAL("ir.RegisterOp").set_body_typed([](String op_name) {
- const OpRegEntry* reg = OpRegistry::Global()->Get(op_name);
- ICHECK(reg == nullptr) << "AttributeError: Operator " << op_name << " is registered before";
- OpRegistry::Global()->RegisterOrGet(op_name).set_name();
+TVM_REGISTER_GLOBAL("ir.RegisterOp").set_body_typed([](String op_name, String descr) {
+ auto& reg = OpRegistry::Global()->RegisterOrGet(op_name).set_name();
+ reg.describe(descr);
+});
+
+TVM_REGISTER_GLOBAL("ir.OpAddTypeRel")
+ .set_body_typed([](Op op, String rel_name, runtime::TVMArgValue value) {
+ auto& reg = OpRegistry::Global()->RegisterOrGet(op->name).set_name();
+ if (value.type_code() == kTVMPackedFuncHandle) {
+ // do an eager copy of the PackedFunc to avoid deleting function from frontend.
+ PackedFunc* fcopy = new PackedFunc(value.operator tvm::runtime::PackedFunc());
+ auto f = [=](const Array<Type>& args, int num_inputs, const Attrs& attrs,
+ const TypeReporter& reporter) -> bool {
+ Array<Type> input_types(args.begin(), args.end() - 1);
+ // call customized relation functions
+ Type ret_type = (*fcopy)(input_types, attrs);
+ if (ret_type.defined()) {
+ // the last argument is output
+ reporter->Assign(args[args.size() - 1], ret_type);
+ return true;
+ }
+ return false;
+ };
+ // adjust function call to relay type system with TypeReporter
+ auto type_rel = runtime::TypedPackedFunc<bool(const Array<Type>&, int, const Attrs&,
+ const TypeReporter&)>(f);
+ reg.add_type_rel(rel_name, type_rel);
+ } else if (value.type_code() == kTVMNullptr) {
+ // Call relation functions of relay
+ auto func_name = std::string("tvm.relay.type_relation.") + rel_name;
Review comment:
Please document this behavior on the Python side.
Also, it would be great to add more test cases for:
1. Looking up a builtin relation, e.g. ```BroadcastRel``` (I saw there are some
in the test cases).
2. Looking up a registered custom type relation.
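The fallback path is small enough to model directly, which also suggests the shape of both requested tests. This is a TVM-free sketch: ```_GLOBAL_FUNCS```, ```register_global```, and ```resolve_type_rel``` are hypothetical stand-ins for TVM's global function registry, only the ```tvm.relay.type_relation.``` prefix comes from the diff, and failing loudly on a missing name is an assumed design choice (the diff is cut off before that branch is visible).

```python
from typing import Callable, Dict, Optional

# Hypothetical stand-in for TVM's global PackedFunc registry.
_GLOBAL_FUNCS: Dict[str, Callable] = {}

# Prefix that ir.OpAddTypeRel prepends when type_rel_func is None (from the diff).
TYPE_REL_PREFIX = "tvm.relay.type_relation."


def register_global(name: str, func: Callable) -> None:
    if name in _GLOBAL_FUNCS:
        raise KeyError("global function %s is already registered" % name)
    _GLOBAL_FUNCS[name] = func


def resolve_type_rel(rel_name: str, type_rel_func: Optional[Callable] = None) -> Callable:
    """Pick the relation an op should use.

    An explicit callback wins; otherwise the relation is looked up by name
    under TYPE_REL_PREFIX, raising if nothing is registered there.
    """
    if type_rel_func is not None:
        return type_rel_func
    full_name = TYPE_REL_PREFIX + rel_name
    func = _GLOBAL_FUNCS.get(full_name)
    if func is None:
        raise ValueError("no type relation registered under %r" % full_name)
    return func


# Case 1: a builtin relation exposed from C++ (modeled here by a lambda).
register_global(TYPE_REL_PREFIX + "BroadcastRel", lambda *args: True)
# Case 2: a custom relation registered from the frontend.
register_global(TYPE_REL_PREFIX + "CustomRel", lambda *args: False)
```

A third case worth covering is an unknown ```rel_name```, so that a typo produces a clear error instead of an op with no type relation.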
##########
File path: python/tvm/ir/op.py
##########
@@ -85,17 +85,76 @@ def reset_attr(self, attr_name):
"""
_ffi_api.OpResetAttr(self, attr_name)
+ def add_type_rel(self, rel_name, type_rel_func=None):
+ """Attach the type function corresponding to the return type.
-def register_op(op_name):
+ Parameters
+ ----------
+ rel_name : str
+ The type relation name to register.
+
+ type_rel_func: function (args: List[Type], num_inputs:int, attrs: Attrs) -> Type
Review comment:
I just found that in ```ir.OpAddTypeRel```, ```type_rel_func``` can be None,
in which case it looks up registered relations.
It would be great to document this behavior.
##########
File path: src/ir/op.cc
##########
@@ -102,10 +102,61 @@
TVM_REGISTER_GLOBAL("ir.OpResetAttr").set_body_typed([](Op op, String attr_name)
reg.reset_attr(attr_name);
});
-TVM_REGISTER_GLOBAL("ir.RegisterOp").set_body_typed([](String op_name) {
- const OpRegEntry* reg = OpRegistry::Global()->Get(op_name);
- ICHECK(reg == nullptr) << "AttributeError: Operator " << op_name << " is registered before";
- OpRegistry::Global()->RegisterOrGet(op_name).set_name();
+TVM_REGISTER_GLOBAL("ir.RegisterOp").set_body_typed([](String op_name, String descr) {
+ auto& reg = OpRegistry::Global()->RegisterOrGet(op_name).set_name();
+ reg.describe(descr);
+});
+
+TVM_REGISTER_GLOBAL("ir.OpAddTypeRel")
+ .set_body_typed([](Op op, String rel_name, runtime::TVMArgValue value) {
+ auto& reg = OpRegistry::Global()->RegisterOrGet(op->name).set_name();
+ if (value.type_code() == kTVMPackedFuncHandle) {
+ // do an eager copy of the PackedFunc to avoid deleting function from frontend.
+ PackedFunc* fcopy = new PackedFunc(value.operator tvm::runtime::PackedFunc());
Review comment:
It would be great to document ```fcopy```'s function signature and the special
meaning of its return value (previously, ```TypeReporter``` took care of this).
Also, consider adding some error handling (what if an exception is raised?) or
documenting the limitations.
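A hedged Python model of the lambda around ```fcopy```, written mainly to pin down the contract: ```fcopy(input_types, attrs)``` sees every argument except the trailing output slot; a defined result is assigned to ```args[-1]``` and the relation returns ```True```, while ```None``` makes it return ```False```. The exception wrapping is a suggested addition, not something the diff does, and ```make_type_rel```, ```TypeRelError```, and the ```assign``` callable (standing in for ```TypeReporter::Assign```) are hypothetical.

```python
from typing import Any, Callable, List, Optional


class TypeRelError(RuntimeError):
    """Hypothetical error type that records which relation failed."""


def make_type_rel(rel_name: str, fcopy: Callable[[List[Any], Any], Optional[Any]]):
    """Adapt a frontend callback to the (args, num_inputs, attrs, reporter) form.

    fcopy(input_types, attrs) -> Optional[Type]
      input_types: args without the trailing output slot.
      Returns the output type, or None if it cannot be inferred yet.
    """

    def type_rel(args: List[Any], num_inputs: int, attrs: Any,
                 assign: Callable[[Any, Any], None]) -> bool:
        input_types = list(args[:-1])  # like the diff, num_inputs is not used
        try:
            ret_type = fcopy(input_types, attrs)
        except Exception as err:  # surface frontend failures with context
            raise TypeRelError("type relation %s failed: %s" % (rel_name, err)) from err
        if ret_type is None:
            return False  # unresolved for now; the solver may revisit it
        assign(args[-1], ret_type)  # the last argument is the output
        return True

    return type_rel
```

Keeping the original exception as ```__cause__``` preserves the frontend traceback while the message names the relation that failed.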
##########
File path: src/ir/op.cc
##########
@@ -102,10 +102,61 @@
TVM_REGISTER_GLOBAL("ir.OpResetAttr").set_body_typed([](Op op, String attr_name)
reg.reset_attr(attr_name);
});
-TVM_REGISTER_GLOBAL("ir.RegisterOp").set_body_typed([](String op_name) {
- const OpRegEntry* reg = OpRegistry::Global()->Get(op_name);
- ICHECK(reg == nullptr) << "AttributeError: Operator " << op_name << " is registered before";
- OpRegistry::Global()->RegisterOrGet(op_name).set_name();
+TVM_REGISTER_GLOBAL("ir.RegisterOp").set_body_typed([](String op_name, String descr) {
+ auto& reg = OpRegistry::Global()->RegisterOrGet(op_name).set_name();
+ reg.describe(descr);
Review comment:
It would be great to check whether ```op_name``` has already been registered.
##########
File path: python/tvm/ir/op.py
##########
@@ -85,17 +85,76 @@ def reset_attr(self, attr_name):
"""
_ffi_api.OpResetAttr(self, attr_name)
+ def add_type_rel(self, rel_name, type_rel_func=None):
+ """Attach the type function corresponding to the return type.
-def register_op(op_name):
+ Parameters
+ ----------
+ rel_name : str
+ The type relation name to register.
+
+ type_rel_func: function (args: List[Type], num_inputs:int, attrs: Attrs) -> Type
+ The backing relation function which can solve an arbitrary relation on variables.
+ """
+ _ffi_api.OpAddTypeRel(self, rel_name, type_rel_func)
+
+ def add_argument(self, name, type, description): # pylint: disable=redefined-builtin
+ """Add arguments information to the function.
+
+ Parameters
+ ----------
+ name : str
+ The argument name.
+ type : str
+ The argument type.
+ description : str
+ The argument description.
+ """
+ _ffi_api.OpAddArgument(self, name, type, description)
+
+ def set_support_level(self, level):
+ """Set the support level of op.
+
+ Parameters
+ ----------
+ level : int
+ The support level.
+ """
+ _ffi_api.OpSetSupportLevel(self, level)
+
+ def set_num_inputs(self, n):
+ """Set the support level of op.
+
+ Parameters
+ ----------
+ n : int
+ The input number.
+ """
+ _ffi_api.OpSetNumInputs(self, n)
+
+ def set_attrs_type_key(self, key):
+ """Set the attribute type key of op.
+
+ Parameters
+ ----------
+ key : str
+ The type key.
+ """
+ _ffi_api.OpSetAttrsTypeKey(self, key)
+
+
+def register_op(op_name, describe=""):
"""Register an operator by name
Parameters
----------
op_name : str
The name of new operator
+ describe : str
+ The detail describe of new operator
"""
- _ffi_api.RegisterOp(op_name)
+ return _ffi_api.RegisterOp(op_name, describe)
Review comment:
I considered returning the op reference before, but I eventually removed it.
What do you think? Will it be a common use case?
##########
File path: src/relay/op/tensor/reduce.cc
##########
@@ -315,6 +315,8 @@ bool ReduceRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
return true;
}
+TVM_REGISTER_GLOBAL("tvm.relay.type_relation.ReduceRel").set_body_typed(ReduceRel);
Review comment:
Maybe we don't need to expose these (otherwise we would need to expose all the
C++ relations).
I just found the approach used in test case
[test_ir_type.py:94](https://github.com/apache/tvm/blob/0b2f30aef2c1c1ed4ec504157b54ceaab182e9ab/tests/python/unittest/test_ir_type.py#L94);
maybe we can do it that way.
--
This is an automated message from the Apache Git Service.