Civitasv opened a new pull request, #15212:
URL: https://github.com/apache/tvm/pull/15212

   I could not find a suitable way to register a custom ONNX operator in TVM,
   so I tried to implement it.
   
   It is actually quite simple: we just need to update the `convert_map`, which
   is used to convert ONNX to Relax IR.
   
   An example:
   
   ```py
   import onnx
   from onnx import helper, TensorProto
   from tvm.relax.frontend.onnx import from_onnx
   from tvm import relax
   from tvm.relax.frontend.onnx import OnnxOpConverter
   from tvm.relax.frontend.onnx import register_custom_op
   
   # op_name is used both as the ONNX node's op type and as the key passed to
   # register_custom_op; op_domain is the node's domain and its opset domain.
   op_name, op_domain = "my_custom_op", "ai.onnx.contrib"
   
   
   class MyCustomOpConverter(OnnxOpConverter):
       """Example converter that lowers the custom op to ``relax.op.abs``."""

       # NOTE(review): the sample output shows this _v13 implementation being
       # used even though the custom domain is declared at version 2, so the
       # suffix does not appear to track the custom-domain version — confirm
       # against OnnxOpConverter.get_converter.
       @classmethod
       def _impl_v13(cls, bb, inputs, attr, params):
           """Return ``relax.op.abs`` applied to the node's first input."""
           first_input = inputs[0]
           return relax.op.abs(first_input)
   
   
   def create_custom_opset():
       """Return the opset identifier for ``op_domain`` at version 2."""
       return helper.make_opsetid(op_domain, 2)
   
   
   def create_custom_model():
       """Build a one-node ONNX model whose only node is the custom operator.

       The graph maps a float32 input ``x`` of shape [32, 32] to an output ``y``
       of the same shape through a single ``op_name`` node in ``op_domain``.
       """
       dims = [32, 32]
       x_info = helper.make_tensor_value_info("x", TensorProto.FLOAT, dims)
       y_info = helper.make_tensor_value_info("y", TensorProto.FLOAT, dims)
       custom_node = helper.make_node(op_name, ["x"], ["y"], domain=op_domain)

       graph = helper.make_graph(
           [custom_node],
           "custom_op_graph",
           inputs=[x_info],
           outputs=[y_info],
       )
       model = helper.make_model(graph, producer_name="custom_op_model")

       # Add the custom operator's domain to the model's opset imports.
       model.opset_import.append(create_custom_opset())
       return model
   
   
   if __name__ == "__main__":
       model = create_custom_model()

       # Validate with onnx.checker first; the custom-op model is expected to
       # pass. Only validation failures are reported here — any other error is
       # a real bug and should propagate instead of being swallowed.
       try:
           onnx.checker.check_model(model)
       except onnx.checker.ValidationError as exception:
           print("Model is not valid: ", str(exception))
       else:
           print("Model is valid.")

       # Register the converter so from_onnx can translate op_name nodes.
       register_custom_op(op_name, MyCustomOpConverter)

       ir_mod = from_onnx(model, keep_params_in_input=True)
       # params is unused here; it is shown to illustrate detach_params' output.
       ir_mod, params = relax.frontend.detach_params(ir_mod)

       ir_mod.show()
   ```
   
   Running the script prints:
   
   ```txt
   Model is valid.
   # from tvm.script import ir as I
   # from tvm.script import relax as R
   
   
   @I.ir_module
   class Module:
       @R.function
       def main(
           x: R.Tensor((32, 32), dtype="float32")
       ) -> R.Tensor((32, 32), dtype="float32"):
           R.func_attr({"global_symbol": "main", "num_input": 1})
           with R.dataflow():
               gv: R.Tensor((32, 32), dtype="float32") = R.abs(x)
               R.output(gv)
           return gv
   ```


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to