sahooora opened a new issue #7811:
URL: https://github.com/apache/tvm/issues/7811


   I‌ have an onnx model that I want to run with my own data type.
   
   So I got the idea from bring_your_own_datatypes.py and developed the 
following code:
   
   ```
   import numpy as np
   import tvm
   from tvm.contrib import graph_runtime
   import os
   import onnx
   from tvm import relay
   import ctypes
   from tvm.relay.frontend.change_datatype import ChangeDatatype
   
   
   #register the custom type with TVM
   ctypes.CDLL('./float_st.so', ctypes.RTLD_GLOBAL)
   tvm.target.datatype.register("float_st", 150)
   
   ctx = tvm.cpu()
   
   # load example onnx model 
   onnx_model = onnx.load('./model.onnx')
   
   # convert to relay, needs the onnx model and input layer name and shape
   module, params = relay.frontend.from_onnx( onnx_model, {"input_1": 
(100,128,3)} )
   
   
   ex = tvm.relay.create_executor("graph", mod=module)
   
   
   def convert_ndarray(dst_dtype, array):
       """Converts an NDArray into the specified datatype"""
       x = relay.var("x", shape=array.shape, dtype=str(array.dtype))
       cast = relay.Function([x], x.astype(dst_dtype))
       with tvm.transform.PassContext(config={"tir.disable_vectorize": True}):
           a = relay.create_executor("graph").evaluate(cast)(array)
           return a
   
   
   tvm.target.datatype.register_op(
       tvm.target.datatype.create_lower_func(
           {
               (32, 128): "FloatToFloatst", 
           }
       ),
       "Cast",
       "llvm",
       "float",
       "float_st",
   )
   
   tvm.target.datatype.register_op(
       tvm.target.datatype.create_lower_func({(128, 32): "FloatstToFloat"}),
       "Cast",
       "llvm",
       "float_st",
       "float",
   )
   
   
   
   src_dtype = "float32"
   dst_dtype = "custom[float_st]128"
   
   
   module = relay.transform.InferType()(module)
   
   # Currently, custom datatypes only work if you run simplify_inference 
beforehand
   module = tvm.relay.transform.SimplifyInference()(module)
   
   # Run type inference before changing datatype
   module = tvm.relay.transform.InferType()(module)
   
   # Change datatype from float to float_st and re-infer types
   cdtype = ChangeDatatype(src_dtype, dst_dtype)
   expr = cdtype.visit(module["main"])
   module = tvm.relay.transform.InferType()(module)
   
   # We need to convert our input:
   data_shape = [100,128,3]
   input = np.random.uniform(size=data_shape).astype('float32')
   input_st = convert_ndarray(dst_dtype, input)
   
   # We also convert the parameters:
   params_st = {k: convert_ndarray(dst_dtype, v) for k, v in params.items()}
   
   
   #register all the needed functions:
   tvm.target.datatype.register_op(
       tvm.target.datatype.create_lower_func({128: "FloatToFloatst"}),
       "FloatImm",
       "llvm",
       "float_st",
   )
   
   tvm.target.datatype.register_op(
       tvm.target.datatype.lower_ite, "Call", "llvm", "float_st", 
intrinsic_name="tir.if_then_else"
   )
   
   tvm.target.datatype.register_op(
       tvm.target.datatype.lower_call_pure_extern,
       "Call",
       "llvm",
       "float_st",
       intrinsic_name="tir.call_pure_extern",
   )
   
   tvm.target.datatype.register_op(
       tvm.target.datatype.create_lower_func({128: "FloatstMul"}),
       "Mul",
       "llvm",
       "float_st",
   )
   tvm.target.datatype.register_op(
       tvm.target.datatype.create_lower_func({128: "FloatstDiv"}),
       "Div",
       "llvm",
       "float_st",
   )
   
   tvm.target.datatype.register_op(
       tvm.target.datatype.create_lower_func({128: "FloatstSqrt"}),
       "Call",
       "llvm",
       "float_st",
       intrinsic_name="tir.sqrt",
   )
   
   tvm.target.datatype.register_op(
       tvm.target.datatype.create_lower_func({128: "FloatstSub"}),
       "Sub",
       "llvm",
       "float_st",
   )
   
   tvm.target.datatype.register_op(
       tvm.target.datatype.create_lower_func({128: "FloatstExp"}),
       "Call",
       "llvm",
       "float_st",
       intrinsic_name="tir.exp",
   )
   
   tvm.target.datatype.register_op(
       tvm.target.datatype.create_lower_func({128: "FloatstMax"}),
       "Max",
       "llvm",
       "float_st",
   )
   
   tvm.target.datatype.register_min_func(
       tvm.target.datatype.create_min_lower_func({128: "MinFloatst"}, 
"float_st"),
       "float_st",
   )
   
   
   
   # Vectorization is not implemented with custom datatypes.
   with tvm.transform.PassContext(config={"tir.disable_vectorize": True}):
       result_myfloat = ex.evaluate(expr)(input_st, **params_st)
       result_myfloat = convert_ndarray(src_dtype, result_myfloat).asnumpy()
       print(result_myfloat)
   ```
   
   but when I try to run it I get the following error:
   
   ```
   data types custom[float_st]128 and float32do not match in BroadcastRel
   data types custom[float_st]128 and float32do not match in BroadcastRel
   ```
   
   After changing the model's data type form float to custom data type using 
line `expr = cdtype.visit(module["main"])` with printing `expr` I noticed that 
there are still two instructions in the model's tree which have `float32` data 
type:
   ```
   %387 = zeros(shape=[100, 128], dtype="float32");
   %402 = zeros(shape=[100, 128], dtype="float32");
   ```
   I guess the error is related to these `zeros` but I don't know how can 
change their data type to my `custom` data type. Any idea?
   
   I uploaded the onnx model as well as float_st.so 
[here](https://gofile.io/d/wlQEi9) for reproducing the error
   
   Thanka in advance


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to