SigureMo opened a new issue, #176:
URL: https://github.com/apache/tvm-ffi/issues/176

   While adapting FlashInfer to work with PaddlePaddle through TVM FFI, I found that there appears to be no general protocol for exchanging dtype objects across the FFI boundary when a function needs to accept dtype parameters. Currently, `torch.dtype` is still handled as a special case:
   
   
https://github.com/apache/tvm-ffi/blob/0729193f475c7ab1059524fcfa6ffc742b0addac/python/tvm_ffi/cython/function.pxi#L696
   
   I noticed that https://github.com/data-apis/array-api/issues/972 is an RFC proposing a DLPack-based dtype exchange protocol. However, there does not seem to have been recent progress on it, and I don't see a corresponding implementation in TVM FFI.
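
DLPack already describes element types with a `DLDataType` struct made of `(code, bits, lanes)`, so a DLPack-based exchange would presumably let any framework hand over that triple instead of a framework-specific dtype object. Below is a minimal Python sketch of the consumer side. It assumes a hook named `__dlpack_data_type__` that returns the triple. That name is hypothetical and chosen only for illustration; it is not necessarily what the RFC specifies. The type-code values mirror `DLDataTypeCode` in `dlpack.h`, and the `float32x4` spelling for multi-lane types follows TVM's dtype-string convention. `FrameworkDType` is a hypothetical stand-in for a framework dtype such as `paddle.dtype`.

```python
from enum import IntEnum


class DLDataTypeCode(IntEnum):
    # Type codes from dlpack.h; newer DLPack versions define additional codes.
    kDLInt = 0
    kDLUInt = 1
    kDLFloat = 2
    kDLOpaqueHandle = 3
    kDLBfloat = 4
    kDLComplex = 5
    kDLBool = 6


_PREFIX = {
    DLDataTypeCode.kDLInt: "int",
    DLDataTypeCode.kDLUInt: "uint",
    DLDataTypeCode.kDLFloat: "float",
    DLDataTypeCode.kDLBfloat: "bfloat",
    DLDataTypeCode.kDLComplex: "complex",
}


def dtype_from_dlpack(obj: object) -> str:
    """Convert any object exposing the hypothetical __dlpack_data_type__ hook to a dtype string."""
    code, bits, lanes = obj.__dlpack_data_type__()
    code = DLDataTypeCode(code)  # raises ValueError for unknown codes
    if code == DLDataTypeCode.kDLBool:
        name = "bool"
    elif code in _PREFIX:
        name = f"{_PREFIX[code]}{bits}"
    else:
        raise TypeError(f"unsupported DLPack type code: {code!r}")
    # Single-lane types keep the plain name; vector types get an "x<lanes>" suffix.
    return name if lanes == 1 else f"{name}x{lanes}"


class FrameworkDType:
    """Hypothetical framework dtype implementing the hook (e.g. a paddle.dtype or torch.dtype)."""

    def __init__(self, code: int, bits: int, lanes: int = 1) -> None:
        self._triple = (code, bits, lanes)

    def __dlpack_data_type__(self) -> tuple:
        return self._triple


print(dtype_from_dlpack(FrameworkDType(DLDataTypeCode.kDLBfloat, 16)))  # bfloat16
print(dtype_from_dlpack(FrameworkDType(DLDataTypeCode.kDLFloat, 32, 4)))  # float32x4
```

With this design, the FFI layer would need one generic branch for "has the hook" rather than one special case per framework.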
   
   I also noticed that TVM FFI has a `__tvm_ffi_dtype__` protocol, but during argument conversion it is only honored for `tvm_ffi.dtype` instances:
   
   
https://github.com/apache/tvm-ffi/blob/0729193f475c7ab1059524fcfa6ffc742b0addac/python/tvm_ffi/cython/function.pxi#L655
   
   As a result, to pass dtypes across the boundary, I'm currently converting `paddle.dtype` to `tvm_ffi.dtype` manually:
   
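   ```python
   # https://github.com/cattidea/flashinfer/pull/2#discussion_r2397423854
   import paddle
   import tvm_ffi


   def paddle_dtype_to_tvm_ffi_dtype(dtype: paddle.dtype) -> tvm_ffi.dtype:
       # Assumes str(dtype) has the form "paddle.<name>" (e.g. "paddle.float32");
       # keep only the part after the first "." so tvm_ffi.dtype gets "float32".
       dtype_str = str(dtype).split(".", 1)[-1]
       return tvm_ffi.dtype(dtype_str)


   # Excerpt: instance_key, module, and the *_dtype variables come from the
   # surrounding FlashInfer code; each paddle dtype is converted before the call.
   if instance_key not in MoERunner.runner_dict:
       MoERunner.runner_dict[instance_key] = module.init(
           paddle_dtype_to_tvm_ffi_dtype(x_dtype),
           paddle_dtype_to_tvm_ffi_dtype(weight_dtype),
           paddle_dtype_to_tvm_ffi_dtype(output_dtype),
           ...
       )
   ```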
   
   My questions are:
   - Is there an existing protocol for dtype exchange that I may have overlooked?
   - If not, what is the preferred path forward? Should we continue pursuing the DLPack-based dtype exchange protocol, or would it be acceptable to relax the `__tvm_ffi_dtype__` check (for example, by changing the condition to `hasattr(arg, "__tvm_ffi_dtype__")` instead of restricting it to `tvm_ffi.dtype`)?
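
The difference between the current check and the `hasattr`-based relaxation in the second question can be modeled in plain Python. This sketch does not reproduce tvm_ffi's actual conversion code. `FFIDType`, `OwnDType`, `PaddleLikeDType`, `convert_strict`, and `convert_relaxed` are hypothetical stand-ins. Raising `TypeError` stands in for whatever fallback path the real converter takes.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class FFIDType:
    """Hypothetical stand-in for the FFI-side dtype value; not tvm_ffi's real class."""

    name: str


class OwnDType:
    """Hypothetical stand-in for `tvm_ffi.dtype`, the FFI's own dtype class."""

    def __init__(self, name: str) -> None:
        self.__tvm_ffi_dtype__ = FFIDType(name)


class PaddleLikeDType:
    """Hypothetical third-party dtype that opts in by exposing __tvm_ffi_dtype__."""

    def __init__(self, name: str) -> None:
        self.name = name

    @property
    def __tvm_ffi_dtype__(self) -> FFIDType:
        return FFIDType(self.name)


def convert_strict(arg: object) -> FFIDType:
    # Behavior described in the issue: only the FFI's own dtype class is unwrapped.
    if isinstance(arg, OwnDType):
        return arg.__tvm_ffi_dtype__
    raise TypeError(f"cannot convert {type(arg).__name__} to a dtype")


def convert_relaxed(arg: object) -> FFIDType:
    # Proposed relaxation: any object exposing __tvm_ffi_dtype__ is unwrapped (duck typing).
    if hasattr(arg, "__tvm_ffi_dtype__"):
        return arg.__tvm_ffi_dtype__
    raise TypeError(f"cannot convert {type(arg).__name__} to a dtype")


bf16 = PaddleLikeDType("bfloat16")
print(convert_relaxed(bf16))  # FFIDType(name='bfloat16')
try:
    convert_strict(bf16)
except TypeError as exc:
    print(exc)  # cannot convert PaddleLikeDType to a dtype
```

Argument conversion is a hot path, so a real implementation would presumably keep the cheaper exact-type checks first and fall back to the attribute lookup only for otherwise unrecognized objects.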
   


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

