Civitasv opened a new pull request, #15212:
URL: https://github.com/apache/tvm/pull/15212
I could not find a suitable way to register a custom ONNX operator in TVM, so
I implemented one.
It is actually quite simple: we just need to update the `convert_map`, which is
used to convert ONNX to Relax IR.
An example:
```py
import onnx
from onnx import helper, TensorProto
from tvm.relax.frontend.onnx import from_onnx
from tvm import relax
from tvm.relax.frontend.onnx import OnnxOpConverter
from tvm.relax.frontend.onnx import register_custom_op
# Operator type of the custom ONNX node; the converter is also registered
# under this name before calling from_onnx.
op_name = "my_custom_op"
# Custom ONNX domain of the node; the model gets a matching opset import for it.
op_domain = "ai.onnx.contrib"
class MyCustomOpConverter(OnnxOpConverter):
    """Converter for the example custom ONNX operator.

    The operator is lowered to an element-wise absolute value (``relax.op.abs``)
    of the node's first input.
    """

    @classmethod
    def _impl_v13(cls, bb, inputs, attr, params):
        # Only the first input is consumed; bb, attr and params are unused here.
        # NOTE(review): the ``_v13`` suffix presumably lets the frontend pick
        # this handler by opset version -- confirm which opset it passes in.
        source = inputs[0]
        return relax.op.abs(source)
def create_custom_opset():
    """Return an ONNX opset-import entry declaring version 2 of ``op_domain``."""
    return helper.make_opsetid(op_domain, 2)
def create_custom_model():
    """Build a one-node ONNX model that applies the custom operator.

    The graph maps a float32 input ``x`` of shape [32, 32] to a float32 output
    ``y`` of the same shape through a single ``op_name`` node in ``op_domain``.
    The custom domain's opset entry is appended to the model's opset imports.
    """
    shape = [32, 32]
    input_info = helper.make_tensor_value_info("x", TensorProto.FLOAT, shape)
    output_info = helper.make_tensor_value_info("y", TensorProto.FLOAT, shape)
    custom_node = helper.make_node(op_name, ["x"], ["y"], domain=op_domain)

    graph = helper.make_graph(
        [custom_node],
        "custom_op_graph",
        inputs=[input_info],
        outputs=[output_info],
    )
    model = helper.make_model(graph, producer_name="custom_op_model")
    # The node lives in a non-default domain, so declare that domain's opset too.
    model.opset_import.append(create_custom_opset())
    return model
if __name__ == "__main__":
    model = create_custom_model()

    # Validate the model with the ONNX checker; it is expected to succeed.
    # Only the checker's validation error is caught, so unrelated failures
    # (e.g. programming errors) still surface instead of being printed away.
    try:
        onnx.checker.check_model(model)
    except onnx.checker.ValidationError as exception:
        # f-string avoids the double space print(a, b) would insert.
        print(f"Model is not valid: {exception}")
    else:
        print("Model is valid.")

    # Register the converter for the custom operator before conversion.
    register_custom_op(op_name, MyCustomOpConverter)

    ir_mod = from_onnx(model, keep_params_in_input=True)
    # detach_params also returns the detached parameters; they are unused here.
    ir_mod, _ = relax.frontend.detach_params(ir_mod)
    ir_mod.show()
```
It will show:
```txt
Model is valid.
# from tvm.script import ir as I
# from tvm.script import relax as R
@I.ir_module
class Module:
    @R.function
    def main(
        x: R.Tensor((32, 32), dtype="float32")
    ) -> R.Tensor((32, 32), dtype="float32"):
        R.func_attr({"global_symbol": "main", "num_input": 1})
        with R.dataflow():
            gv: R.Tensor((32, 32), dtype="float32") = R.abs(x)
            R.output(gv)
        return gv
```
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]