sahooora opened a new issue #7811:
URL: https://github.com/apache/tvm/issues/7811
I have an onnx model that I want to run with my own data type.
So I got the idea from bring_your_own_datatypes.py and developed the
following code:
```
import numpy as np
import tvm
from tvm.contrib import graph_runtime
import os
import onnx
from tvm import relay
import ctypes
from tvm.relay.frontend.change_datatype import ChangeDatatype
#register the custom type with TVM
ctypes.CDLL('./float_st.so', ctypes.RTLD_GLOBAL)
tvm.target.datatype.register("float_st", 150)
ctx = tvm.cpu()
# load example onnx model
onnx_model = onnx.load('./model.onnx')
# convert to relay, needs the onnx model and input layer name and shape
module, params = relay.frontend.from_onnx( onnx_model, {"input_1":
(100,128,3)} )
ex = tvm.relay.create_executor("graph", mod=module)
def convert_ndarray(dst_dtype, array):
"""Converts an NDArray into the specified datatype"""
x = relay.var("x", shape=array.shape, dtype=str(array.dtype))
cast = relay.Function([x], x.astype(dst_dtype))
with tvm.transform.PassContext(config={"tir.disable_vectorize": True}):
a = relay.create_executor("graph").evaluate(cast)(array)
return a
tvm.target.datatype.register_op(
tvm.target.datatype.create_lower_func(
{
(32, 128): "FloatToFloatst",
}
),
"Cast",
"llvm",
"float",
"float_st",
)
tvm.target.datatype.register_op(
tvm.target.datatype.create_lower_func({(128, 32): "FloatstToFloat"}),
"Cast",
"llvm",
"float_st",
"float",
)
src_dtype = "float32"
dst_dtype = "custom[float_st]128"
module = relay.transform.InferType()(module)
# Currently, custom datatypes only work if you run simplify_inference
beforehand
module = tvm.relay.transform.SimplifyInference()(module)
# Run type inference before changing datatype
module = tvm.relay.transform.InferType()(module)
# Change datatype from float to float_st and re-infer types
cdtype = ChangeDatatype(src_dtype, dst_dtype)
expr = cdtype.visit(module["main"])
module = tvm.relay.transform.InferType()(module)
# We need to convert our input:
data_shape = [100,128,3]
input = np.random.uniform(size=data_shape).astype('float32')
input_st = convert_ndarray(dst_dtype, input)
# We also convert the parameters:
params_st = {k: convert_ndarray(dst_dtype, v) for k, v in params.items()}
#register all the needed functions:
tvm.target.datatype.register_op(
tvm.target.datatype.create_lower_func({128: "FloatToFloatst"}),
"FloatImm",
"llvm",
"float_st",
)
tvm.target.datatype.register_op(
tvm.target.datatype.lower_ite, "Call", "llvm", "float_st",
intrinsic_name="tir.if_then_else"
)
tvm.target.datatype.register_op(
tvm.target.datatype.lower_call_pure_extern,
"Call",
"llvm",
"float_st",
intrinsic_name="tir.call_pure_extern",
)
tvm.target.datatype.register_op(
tvm.target.datatype.create_lower_func({128: "FloatstMul"}),
"Mul",
"llvm",
"float_st",
)
tvm.target.datatype.register_op(
tvm.target.datatype.create_lower_func({128: "FloatstDiv"}),
"Div",
"llvm",
"float_st",
)
tvm.target.datatype.register_op(
tvm.target.datatype.create_lower_func({128: "FloatstSqrt"}),
"Call",
"llvm",
"float_st",
intrinsic_name="tir.sqrt",
)
tvm.target.datatype.register_op(
tvm.target.datatype.create_lower_func({128: "FloatstSub"}),
"Sub",
"llvm",
"float_st",
)
tvm.target.datatype.register_op(
tvm.target.datatype.create_lower_func({128: "FloatstExp"}),
"Call",
"llvm",
"float_st",
intrinsic_name="tir.exp",
)
tvm.target.datatype.register_op(
tvm.target.datatype.create_lower_func({128: "FloatstMax"}),
"Max",
"llvm",
"float_st",
)
tvm.target.datatype.register_min_func(
tvm.target.datatype.create_min_lower_func({128: "MinFloatst"},
"float_st"),
"float_st",
)
# Vectorization is not implemented with custom datatypes.
with tvm.transform.PassContext(config={"tir.disable_vectorize": True}):
result_myfloat = ex.evaluate(expr)(input_st, **params_st)
result_myfloat = convert_ndarray(src_dtype, result_myfloat).asnumpy()
print(result_myfloat)
```
but when I try to run it I get the following error:
```
data types custom[float_st]128 and float32do not match in BroadcastRel
data types custom[float_st]128 and float32do not match in BroadcastRel
```
After changing the model's data type form float to custom data type using
line `expr = cdtype.visit(module["main"])` with printing `expr` I noticed that
there are still two instructions in the model's tree which have `float32` data
type:
```
%387 = zeros(shape=[100, 128], dtype="float32");
%402 = zeros(shape=[100, 128], dtype="float32");
```
I guess the error is related to these `zeros` but I don't know how can
change their data type to my `custom` data type. Any idea?
I uploaded the onnx model as well as float_st.so
[here](https://gofile.io/d/wlQEi9) for reproducing the error
Thanka in advance
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]