zackcquic commented on a change in pull request #8079:
URL: https://github.com/apache/tvm/pull/8079#discussion_r642494514



##########
File path: python/tvm/ir/op.py
##########
@@ -85,17 +85,87 @@ def reset_attr(self, attr_name):
         """
         _ffi_api.OpResetAttr(self, attr_name)
 
+    def add_type_rel(self, rel_name, type_rel_func=None):
+        """Attach the type function corresponding to the return type.
 
-def register_op(op_name):
-    """Register an operator by name
+        Parameters
+        ----------
+        rel_name : str
+            The type relation name to register.
+
+        type_rel_func: function (args: List[Type], attrs: Attrs) -> Type

Review comment:
       type_rel_func: Optional[function (args: List[Type], attrs: Attrs) -> 
Type]

##########
File path: python/tvm/ir/op.py
##########
@@ -85,17 +85,87 @@ def reset_attr(self, attr_name):
         """
         _ffi_api.OpResetAttr(self, attr_name)
 
+    def add_type_rel(self, rel_name, type_rel_func=None):
+        """Attach the type function corresponding to the return type.
 
-def register_op(op_name):
-    """Register an operator by name
+        Parameters
+        ----------
+        rel_name : str
+            The type relation name to register.
+
+        type_rel_func: function (args: List[Type], attrs: Attrs) -> Type
+            The backing relation function which can solve an arbitrary 
relation on variables.
+            Differences with type_rel_func in C++:
+            1, when type_rel_func is not None:
+               1) OpAddTypeRel on C++ side will adjust type_rel_func with 
TypeReporter to
+                  calling convention of relay type system.
+               2) type_rel_func returns output argument's type, return None 
means can't
+                  infer output's type.
+               3) only support single output operators for now, the last 
argument is output tensor.

Review comment:
      Have you encountered operators with multiple outputs?
  (Just curious — I can't come up with an example so far.)

##########
File path: src/ir/op.cc
##########
@@ -102,10 +102,66 @@ 
TVM_REGISTER_GLOBAL("ir.OpResetAttr").set_body_typed([](Op op, String attr_name)
   reg.reset_attr(attr_name);
 });
 
-TVM_REGISTER_GLOBAL("ir.RegisterOp").set_body_typed([](String op_name) {
+TVM_REGISTER_GLOBAL("ir.RegisterOp").set_body_typed([](String op_name, String 
descr) {
   const OpRegEntry* reg = OpRegistry::Global()->Get(op_name);
   ICHECK(reg == nullptr) << "AttributeError: Operator " << op_name << " is 
registered before";
-  OpRegistry::Global()->RegisterOrGet(op_name).set_name();
+  auto& op = OpRegistry::Global()->RegisterOrGet(op_name).set_name();
+  op.describe(descr);
+});
+
+TVM_REGISTER_GLOBAL("ir.OpAddTypeRel")
+    .set_body_typed([](Op op, String rel_name, runtime::TVMArgValue value) {
+      auto& reg = OpRegistry::Global()->RegisterOrGet(op->name).set_name();
+      if (value.type_code() == kTVMPackedFuncHandle) {
+        // do an eager copy of the PackedFunc to avoid deleting function from 
frontend.
+        PackedFunc* fcopy = new PackedFunc(value.operator 
tvm::runtime::PackedFunc());
+        auto f = [=](const Array<Type>& args, int num_inputs, const Attrs& 
attrs,
+                     const TypeReporter& reporter) -> bool {
+          Array<Type> input_types(args.begin(), args.end() - 1);
+          // call customized relation functions
+          // *fcopy's signature: function (args: List[Type], attrs: Attrs) -> 
Type
+          Type ret_type = (*fcopy)(input_types, attrs);
+          // when defined ret_type, inference of output type is ok, do type 
assign
+          // otherwise, inference failure happens
+          if (ret_type.defined()) {
+            // the last argument is output
+            reporter->Assign(args[args.size() - 1], ret_type);
+            return true;
+          }
+          return false;
+        };
+        // adjust function call to call conventions of relay type system with 
TypeReporter
+        auto type_rel = runtime::TypedPackedFunc<bool(const Array<Type>&, int, 
const Attrs&,
+                                                      const TypeReporter&)>(f);
+        reg.add_type_rel(rel_name, type_rel);
+      } else if (value.type_code() == kTVMNullptr) {
+        // Call relation functions of relay
+        auto func_name = std::string("tvm.relay.type_relation.") + rel_name;
+        auto* f = runtime::Registry::Get(func_name);
+        CHECK(f != nullptr) << "AddTypeRel error: no type_relation 
registered.";

Review comment:
       use ICHECK here.

##########
File path: tests/python/relay/test_type_infer.py
##########
@@ -416,6 +417,127 @@ def test_dynamic_function():
     assert mod["main"].params[0].checked_type == s_tt
 
 
+def test_custom_op_infer():
+    """" Tests infer type for custom_op """
+    op_name = "custom_log"
+    _op.register(op_name, r"code(cal log of a tensor.)code")
+    _op.get(op_name).set_num_inputs(1)
+    _op.get(op_name).add_argument("data_0", "Tensor", "The input data tensor.")
+    # call default relation functions
+    _op.get(op_name).add_type_rel("Identity")
+    _op.get(op_name).set_support_level(1)
+    _op.register_pattern(op_name, _op.OpPattern.ELEMWISE)
+    _op.register_stateful(op_name, False)
+
+    def clog(x):
+        return relay.Call(_op.get(op_name), [x])
+
+    tp = relay.TensorType((10, 10), "float32")
+    x = relay.var("x", tp)
+    sb = relay.ScopeBuilder()
+    t1 = sb.let("t1", clog(x))
+    t2 = sb.let("t2", relay.add(t1, x))
+    sb.ret(t2)
+    f = relay.Function([x], sb.get())
+    fchecked = infer_expr(f)
+    assert fchecked.checked_type == relay.FuncType([tp], tp)
+
+
+def test_custom_add_broadcast_op():
+    """ Tests infer type for broadcast custom_op """
+    op_name = "custom_broadcast_add"
+    _op.register(op_name, r"code(Add two tensor with inner broadcasting.)code")
+    _op.get(op_name).set_num_inputs(2)
+    _op.get(op_name).add_argument("data_0", "Tensor", "The input data tensor.")
+    _op.get(op_name).add_argument("data_1", "Tensor", "The input data tensor.")
+    # call default relation functions
+    _op.get(op_name).add_type_rel("Broadcast")
+    _op.get(op_name).set_support_level(1)
+    _op.register_stateful(op_name, False)
+
+    def broadcast_add(x, y):
+        return relay.Call(_op.get(op_name), [x, y])
+
+    x = relay.var("x", shape=(10, 4))
+    y = relay.var("y", shape=(5, 10, 1))
+    z = broadcast_add(x, y)
+    func = relay.Function([x, y], z)
+    t1 = relay.TensorType((10, 4), "float32")
+    t2 = relay.TensorType((5, 10, 1), "float32")
+    t3 = relay.TensorType((5, 10, 4), "float32")
+    expected_ty = relay.FuncType([t1, t2], t3)
+    assert_has_type(func, expected_ty)
+
+
+def test_custom_op_rel_infer():
+    """" Tests infer type for custom_op """
+
+    def custom_log1_rel(arg_types, attrs):
+        assert len(arg_types) == 1, "type relation arg number mismatch!"
+        if attrs:
+            assert isinstance(attrs, DictAttrs)
+        inputa_type = arg_types[0]
+        return relay.TensorType(inputa_type.shape, inputa_type.dtype)
+
+    op_name = "custom_log1"
+    _op.register(op_name, r"code(cal log of a tensor.)code")
+    _op.get(op_name).set_num_inputs(1)
+    _op.get(op_name).add_argument("data_0", "Tensor", "The input data tensor.")
+    _op.get(op_name).set_attrs_type_key("DictAttrs")
+    # call customized relation functions
+    _op.get(op_name).add_type_rel("custom_log1", custom_log1_rel)
+    _op.get(op_name).set_support_level(1)
+    _op.register_pattern(op_name, _op.OpPattern.ELEMWISE)
+    _op.register_stateful(op_name, False)
+
+    def clog(x):
+        return relay.Call(_op.get(op_name), [x])
+
+    tp = relay.TensorType((10, 10), "float32")
+    x = relay.var("x", tp)
+    sb = relay.ScopeBuilder()
+    t1 = sb.let("t1", clog(x))
+    t2 = sb.let("t2", relay.add(t1, x))
+    sb.ret(t2)
+    f = relay.Function([x], sb.get())
+    fchecked = infer_expr(f)
+    assert fchecked.checked_type == relay.FuncType([tp], tp)
+
+
+def test_custom_op_rel_infer_exception():
+    """" Tests infer type for custom_op """
+
+    def custom_log1_rel(arg_types, attrs):
+        assert len(arg_types) == 2, "type relation arg number mismatch!"
+        return None
+
+    op_name = "custom_log2"
+    _op.register(op_name, r"code(cal log of a tensor.)code")
+    _op.get(op_name).set_num_inputs(1)
+    _op.get(op_name).add_argument("data_0", "Tensor", "The input data tensor.")
+    _op.get(op_name).set_attrs_type_key("DictAttrs")
+    # call customized relation functions
+    _op.get(op_name).add_type_rel("custom_log2", custom_log1_rel)
+    _op.get(op_name).set_support_level(1)
+    _op.register_pattern(op_name, _op.OpPattern.ELEMWISE)
+    _op.register_stateful(op_name, False)
+
+    def clog(x):
+        return relay.Call(_op.get(op_name), [x])
+
+    tp = relay.TensorType((10, 10), "float32")
+    x = relay.var("x", tp)
+    sb = relay.ScopeBuilder()
+    t1 = sb.let("t1", clog(x))
+    t2 = sb.let("t2", relay.add(t1, x))
+    sb.ret(t2)
+    f = relay.Function([x], sb.get())
+    try:

Review comment:
      Check the error message, e.g.:
  ``` python
  with pytest.raises(tvm.error.TVMError) as cm:
      fchecked = infer_expr(f)
  assert "type relation arg number mismatch!" in str(cm.value)
  ```

##########
File path: python/tvm/relay/op/op.py
##########
@@ -40,6 +40,40 @@ def get(op_name):
     return tvm.ir.Op.get(op_name)

Review comment:
       Be consistent with others (remove return) or document why it is 
specially returned.

##########
File path: python/tvm/relay/op/op.py
##########
@@ -40,6 +40,40 @@ def get(op_name):
     return tvm.ir.Op.get(op_name)
 
 
+def register(op_name, describe=""):
+    """Get the Op for a given name.
+    when the op_name is not registered, create a new empty op with the given 
name.
+    when the op_name has been registered, abort with an error message.

Review comment:
       Could you add a test case for this ? (registering with existing name)

##########
File path: src/ir/op.cc
##########
@@ -102,10 +102,66 @@ 
TVM_REGISTER_GLOBAL("ir.OpResetAttr").set_body_typed([](Op op, String attr_name)
   reg.reset_attr(attr_name);
 });
 
-TVM_REGISTER_GLOBAL("ir.RegisterOp").set_body_typed([](String op_name) {
+TVM_REGISTER_GLOBAL("ir.RegisterOp").set_body_typed([](String op_name, String 
descr) {
   const OpRegEntry* reg = OpRegistry::Global()->Get(op_name);
   ICHECK(reg == nullptr) << "AttributeError: Operator " << op_name << " is 
registered before";
-  OpRegistry::Global()->RegisterOrGet(op_name).set_name();
+  auto& op = OpRegistry::Global()->RegisterOrGet(op_name).set_name();
+  op.describe(descr);
+});
+
+TVM_REGISTER_GLOBAL("ir.OpAddTypeRel")
+    .set_body_typed([](Op op, String rel_name, runtime::TVMArgValue value) {
+      auto& reg = OpRegistry::Global()->RegisterOrGet(op->name).set_name();
+      if (value.type_code() == kTVMPackedFuncHandle) {
+        // do an eager copy of the PackedFunc to avoid deleting function from 
frontend.
+        PackedFunc* fcopy = new PackedFunc(value.operator 
tvm::runtime::PackedFunc());
+        auto f = [=](const Array<Type>& args, int num_inputs, const Attrs& 
attrs,
+                     const TypeReporter& reporter) -> bool {
+          Array<Type> input_types(args.begin(), args.end() - 1);
+          // call customized relation functions
+          // *fcopy's signature: function (args: List[Type], attrs: Attrs) -> 
Type
+          Type ret_type = (*fcopy)(input_types, attrs);
+          // when defined ret_type, inference of output type is ok, do type 
assign
+          // otherwise, inference failure happens
+          if (ret_type.defined()) {
+            // the last argument is output
+            reporter->Assign(args[args.size() - 1], ret_type);

Review comment:
      NIT: It will be clearer to use args.back() (vs. args[args.size() - 1]),
I think.

##########
File path: python/tvm/relay/op/op.py
##########
@@ -40,6 +40,40 @@ def get(op_name):
     return tvm.ir.Op.get(op_name)
 
 
+def register(op_name, describe=""):
+    """Get the Op for a given name.
+    when the op_name is not registered, create a new empty op with the given 
name.
+    when the op_name has been registered, abort with an error message.
+
+    Parameters
+    ----------
+    op_name : str
+        The operator name
+
+    describe : str
+        The operator description
+    """
+
+    tvm.ir.register_op(op_name, describe)
+
+
+def register_stateful(op_name, stateful, level=10):
+    """Register operator pattern for an op.
+
+    Parameters
+    ----------
+    op_name : str
+        The name of the op.
+
+    stateful : bool
+        The stateful flag.
+
+    level : int
+        The priority level
+    """
+    return tvm.ir.register_op_attr(op_name, "TOpIsStateful", stateful, level)

Review comment:
      Be consistent with others (remove return) or document why it is
specially returned.

##########
File path: src/ir/op.cc
##########
@@ -102,10 +102,66 @@ 
TVM_REGISTER_GLOBAL("ir.OpResetAttr").set_body_typed([](Op op, String attr_name)
   reg.reset_attr(attr_name);
 });
 
-TVM_REGISTER_GLOBAL("ir.RegisterOp").set_body_typed([](String op_name) {
+TVM_REGISTER_GLOBAL("ir.RegisterOp").set_body_typed([](String op_name, String 
descr) {
   const OpRegEntry* reg = OpRegistry::Global()->Get(op_name);
   ICHECK(reg == nullptr) << "AttributeError: Operator " << op_name << " is 
registered before";
-  OpRegistry::Global()->RegisterOrGet(op_name).set_name();
+  auto& op = OpRegistry::Global()->RegisterOrGet(op_name).set_name();
+  op.describe(descr);
+});
+
+TVM_REGISTER_GLOBAL("ir.OpAddTypeRel")
+    .set_body_typed([](Op op, String rel_name, runtime::TVMArgValue value) {
+      auto& reg = OpRegistry::Global()->RegisterOrGet(op->name).set_name();
+      if (value.type_code() == kTVMPackedFuncHandle) {
+        // do an eager copy of the PackedFunc to avoid deleting function from 
frontend.
+        PackedFunc* fcopy = new PackedFunc(value.operator 
tvm::runtime::PackedFunc());

Review comment:
       Will be great to add TODO, and describe design choice here.
   Like @altanh said:
   "should emphasize this python API strictly for user prototyping and not for 
any checked in code, since the type function exposed through FFI is not the C++ 
type relation (it's strictly weaker since we don't have access to the type 
reporter, and cannot propagate constraints to the inputs, only to the output). 
We should probably not use "type relation" to describe the python API but I'm 
not sure what other name to use."

##########
File path: python/tvm/relay/op/op.py
##########
@@ -40,6 +40,40 @@ def get(op_name):
     return tvm.ir.Op.get(op_name)
 
 
+def register(op_name, describe=""):

Review comment:
       I think this is duplicated with ```register_op```, right?  Remove one 
and keep one you like.

##########
File path: python/tvm/ir/op.py
##########
@@ -85,17 +85,87 @@ def reset_attr(self, attr_name):
         """
         _ffi_api.OpResetAttr(self, attr_name)
 
+    def add_type_rel(self, rel_name, type_rel_func=None):
+        """Attach the type function corresponding to the return type.
 
-def register_op(op_name):
-    """Register an operator by name
+        Parameters
+        ----------
+        rel_name : str
+            The type relation name to register.
+
+        type_rel_func: function (args: List[Type], attrs: Attrs) -> Type
+            The backing relation function which can solve an arbitrary 
relation on variables.
+            Differences with type_rel_func in C++:
+            1, when type_rel_func is not None:
+               1) OpAddTypeRel on C++ side will adjust type_rel_func with 
TypeReporter to
+                  calling convention of relay type system.
+               2) type_rel_func returns output argument's type, return None 
means can't
+                  infer output's type.
+               3) only support single output operators for now, the last 
argument is output tensor.
+            2, when type_rel_func is None, will call predefined type_rel_funcs 
in relay
+               accorrding to `tvm.relay.type_relation.` + rel_name.
+        """
+        _ffi_api.OpAddTypeRel(self, rel_name, type_rel_func)
+
+    def add_argument(self, name, type, description):  # pylint: 
disable=redefined-builtin
+        """Add arguments information to the function.
+
+        Parameters
+        ----------
+        name : str
+            The argument name.
+        type : str
+            The argument type.
+        description : str
+            The argument description.
+        """
+        _ffi_api.OpAddArgument(self, name, type, description)
+
+    def set_support_level(self, level):
+        """Set the support level of op.
+
+        Parameters
+        ----------
+        level : int
+            The support level.
+        """
+        _ffi_api.OpSetSupportLevel(self, level)
+
+    def set_num_inputs(self, n):
+        """Set the support level of op.
+
+        Parameters
+        ----------
+        n : int
+            The input number.
+        """
+        _ffi_api.OpSetNumInputs(self, n)
+
+    def set_attrs_type_key(self, key):
+        """Set the attribute type key of op.
+
+        Parameters
+        ----------
+        key : str
+            The type key.
+        """
+        _ffi_api.OpSetAttrsTypeKey(self, key)
+
+
+def register_op(op_name, describe=""):
+    """Register an operator by name.
+    when the op_name is not registered, create a new empty op with the given 
name.
+    when the op_name has been registered, abort with an error message.
 
     Parameters
     ----------
     op_name : str
         The name of new operator
+    describe : str

Review comment:
       describe: Optional[str]




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to