ConvolutedDog commented on code in PR #19535:
URL: https://github.com/apache/tvm/pull/19535#discussion_r3218686954
##########
python/tvm/relax/frontend/onnx/onnx_frontend.py:
##########
@@ -52,12 +52,26 @@
from tvm import TVMError, relax, tirx, topi
from tvm.ir import IRModule
from tvm.ir.supply import NameSupply
+from tvm.runtime import DataType, DataTypeCode
from tvm.tirx.generic import cast
from tvm.topi.utils import get_const_tuple
from ..common import autopad
+def _relax_dtype_is_floating_point(dtype: str) -> bool:
+    """Whether a Relax dtype string is a floating point type."""
+    try:
+        code = DataType(dtype).type_code
+    except (ValueError, TypeError, TVMError):
+        return False
+    return (
+        code == DataTypeCode.FLOAT
+        or code == DataTypeCode.BFLOAT
+        or (code >= DataTypeCode.Float8E3M4 and code <= DataTypeCode.Float4E2M1FN)
Review Comment:
`DataTypeCode` in tvm_ffi defines the FP8/FP6/FP4 members (`Float8E3M4` through `Float4E2M1FN`) in CamelCase:
```py
from enum import IntEnum


class DataTypeCode(IntEnum):
    """DLDataTypeCode code in DLTensor."""

    INT = 0
    UINT = 1
    FLOAT = 2
    HANDLE = 3
    BFLOAT = 4
    BOOL = 6
    Float8E3M4 = 7
    Float8E4M3 = 8
    Float8E4M3B11FNUZ = 9
    Float8E4M3FN = 10
    Float8E4M3FNUZ = 11
    Float8E5M2 = 12
    Float8E5M2FNUZ = 13
    Float8E8M0FNU = 14
    Float6E2M3FN = 15
    Float6E3M2FN = 16
    Float4E2M1FN = 17
```
There are no ALL_CAPS aliases for these members, so we have to use the actual CamelCase names; an ALL_CAPS spelling would raise `AttributeError` on lookup.
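A self-contained sketch of the check in the diff, assuming only the enum values listed above. It copies that enum locally so it runs without TVM, and the helper `is_floating_point_code` is hypothetical: it takes a type code directly instead of parsing a dtype string through `DataType`. Because the low-precision codes are contiguous (7 through 17), a single range check covers every FP8/FP6/FP4 member.

```python
from enum import IntEnum


# Local stand-in mirroring the tvm_ffi DataTypeCode quoted above (values copied
# from that listing), so this sketch runs without TVM installed.
class DataTypeCode(IntEnum):
    INT = 0
    UINT = 1
    FLOAT = 2
    HANDLE = 3
    BFLOAT = 4
    BOOL = 6
    Float8E3M4 = 7
    Float8E4M3 = 8
    Float8E4M3B11FNUZ = 9
    Float8E4M3FN = 10
    Float8E4M3FNUZ = 11
    Float8E5M2 = 12
    Float8E5M2FNUZ = 13
    Float8E8M0FNU = 14
    Float6E2M3FN = 15
    Float6E3M2FN = 16
    Float4E2M1FN = 17


def is_floating_point_code(code: int) -> bool:
    """Hypothetical helper: the type-code test from the diff, minus dtype parsing."""
    return (
        code == DataTypeCode.FLOAT
        or code == DataTypeCode.BFLOAT
        # FP8/FP6/FP4 codes are contiguous (7..17), so one range check covers them.
        or DataTypeCode.Float8E3M4 <= code <= DataTypeCode.Float4E2M1FN
    )


# An ALL_CAPS spelling (hypothetical name below) does not exist on the enum:
# the attribute lookup raises AttributeError, so hasattr() reports False.
ALL_CAPS_ALIAS_EXISTS = hasattr(DataTypeCode, "FLOAT8_E3M4")
```

For example, `is_floating_point_code(DataTypeCode.Float8E4M3FN)` is `True`, while `DataTypeCode.INT` and `DataTypeCode.BOOL` give `False`.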
--
This is an automated message from the Apache Git Service.