This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 875296c762 [TVMScript] Linter-friendly function definitions (#13713)
875296c762 is described below
commit 875296c762f4654da7cd560674485dabdadcfdb6
Author: Yaxing Cai <[email protected]>
AuthorDate: Fri Jan 6 21:52:53 2023 -0800
[TVMScript] Linter-friendly function definitions (#13713)
Previously, functions such as `T.int8`, `T.uint32`, and `T.float64x64` were
generated in loops that used `globals()` to register the symbols at module
level, in order to reduce code complexity.
However, linters such as Pylint cannot see symbols added through `globals()`,
so these implicitly defined functions were invisible to them.
This PR defines these functions explicitly, giving a better
linter experience.
---
python/tvm/script/ir_builder/tir/ir.py | 205 ++++++++++++++++++++++++++++-----
1 file changed, 174 insertions(+), 31 deletions(-)
diff --git a/python/tvm/script/ir_builder/tir/ir.py b/python/tvm/script/ir_builder/tir/ir.py
index ac1e990a96..48b2834479 100644
--- a/python/tvm/script/ir_builder/tir/ir.py
+++ b/python/tvm/script/ir_builder/tir/ir.py
@@ -1229,36 +1229,107 @@ def evaluate(value: PrimExpr) -> None:
return _ffi_api.Evaluate(value) # type: ignore[attr-defined] # pylint:
disable=no-member
-__all__ = []
-for _dtype in ["Float", "UInt", "Int"]:
- for _size in ["8", "16", "32", "64"]:
- for _lanes in ["", "x4", "x8", "x16", "x32", "x64"]:
- _name = _dtype + _size + _lanes # pylint: disable=invalid-name
-
- def func_gen(name: str):
- """Generate a function for each PrimExpr dtype.
-
- Parameters
- ----------
- name: str
- The ffi function name to call.
- """
-
- def func(
- expr: Union[
- None,
- PrimExpr,
- Literal["inf", "-inf", "nan"],
- ] = None
- ) -> PrimExpr:
- if isinstance(expr, str):
- expr = float(expr)
- return getattr(_ffi_api, name)(expr)
-
- return func
-
- globals()[_name.lower()] = func_gen(_name)
- __all__.append(_name.lower())
+def func_gen(name: str):
+ """Generate a function for each PrimExpr dtype.
+
+ Parameters
+ ----------
+ name: str
+ The ffi function name to call.
+ """
+
+ def func(
+ expr: Union[
+ None,
+ PrimExpr,
+ Literal["inf", "-inf", "nan"],
+ int,
+ float,
+ ] = None
+ ) -> PrimExpr:
+ if isinstance(expr, str):
+ expr = float(expr)
+ return getattr(_ffi_api, name)(expr)
+
+ return func
+
+
+# pylint: disable=invalid-name
+int8 = func_gen(("Int8"))
+int16 = func_gen(("Int16"))
+int32 = func_gen(("Int32"))
+int64 = func_gen(("Int64"))
+int8x4 = func_gen(("Int8x4"))
+int16x4 = func_gen(("Int16x4"))
+int32x4 = func_gen(("Int32x4"))
+int64x4 = func_gen(("Int64x4"))
+int8x8 = func_gen(("Int8x8"))
+int16x8 = func_gen(("Int16x8"))
+int32x8 = func_gen(("Int32x8"))
+int64x8 = func_gen(("Int64x8"))
+int8x16 = func_gen(("Int8x16"))
+int16x16 = func_gen(("Int16x16"))
+int32x16 = func_gen(("Int32x16"))
+int64x16 = func_gen(("Int64x16"))
+int8x32 = func_gen(("Int8x32"))
+int16x32 = func_gen(("Int16x32"))
+int32x32 = func_gen(("Int32x32"))
+int64x32 = func_gen(("Int64x32"))
+int8x64 = func_gen(("Int8x64"))
+int16x64 = func_gen(("Int16x64"))
+int32x64 = func_gen(("Int32x64"))
+int64x64 = func_gen(("Int64x64"))
+
+uint8 = func_gen(("UInt8"))
+uint16 = func_gen(("UInt16"))
+uint32 = func_gen(("UInt32"))
+uint64 = func_gen(("UInt64"))
+uint8x4 = func_gen(("UInt8x4"))
+uint16x4 = func_gen(("UInt16x4"))
+uint32x4 = func_gen(("UInt32x4"))
+uint64x4 = func_gen(("UInt64x4"))
+uint8x8 = func_gen(("UInt8x8"))
+uint16x8 = func_gen(("UInt16x8"))
+uint32x8 = func_gen(("UInt32x8"))
+uint64x8 = func_gen(("UInt64x8"))
+uint8x16 = func_gen(("UInt8x16"))
+uint16x16 = func_gen(("UInt16x16"))
+uint32x16 = func_gen(("UInt32x16"))
+uint64x16 = func_gen(("UInt64x16"))
+uint8x32 = func_gen(("UInt8x32"))
+uint16x32 = func_gen(("UInt16x32"))
+uint32x32 = func_gen(("UInt32x32"))
+uint64x32 = func_gen(("UInt64x32"))
+uint8x64 = func_gen(("UInt8x64"))
+uint16x64 = func_gen(("UInt16x64"))
+uint32x64 = func_gen(("UInt32x64"))
+uint64x64 = func_gen(("UInt64x64"))
+
+float8 = func_gen(("Float8"))
+float16 = func_gen(("Float16"))
+float32 = func_gen(("Float32"))
+float64 = func_gen(("Float64"))
+float8x4 = func_gen(("Float8x4"))
+float16x4 = func_gen(("Float16x4"))
+float32x4 = func_gen(("Float32x4"))
+float64x4 = func_gen(("Float64x4"))
+float8x8 = func_gen(("Float8x8"))
+float16x8 = func_gen(("Float16x8"))
+float32x8 = func_gen(("Float32x8"))
+float64x8 = func_gen(("Float64x8"))
+float8x16 = func_gen(("Float8x16"))
+float16x16 = func_gen(("Float16x16"))
+float32x16 = func_gen(("Float32x16"))
+float64x16 = func_gen(("Float64x16"))
+float8x32 = func_gen(("Float8x32"))
+float16x32 = func_gen(("Float16x32"))
+float32x32 = func_gen(("Float32x32"))
+float64x32 = func_gen(("Float64x32"))
+float8x64 = func_gen(("Float8x64"))
+float16x64 = func_gen(("Float16x64"))
+float32x64 = func_gen(("Float32x64"))
+float64x64 = func_gen(("Float64x64"))
+# pylint: enable=invalid-name
def boolean(expr: Optional[PrimExpr] = None) -> PrimExpr:
@@ -1621,7 +1692,79 @@ class meta_var:
# pylint: enable=invalid-name
-__all__ += [
+__all__ = [
+ "int8",
+ "int16",
+ "int32",
+ "int64",
+ "int8x4",
+ "int16x4",
+ "int32x4",
+ "int64x4",
+ "int8x8",
+ "int16x8",
+ "int32x8",
+ "int64x8",
+ "int8x16",
+ "int16x16",
+ "int32x16",
+ "int64x16",
+ "int8x32",
+ "int16x32",
+ "int32x32",
+ "int64x32",
+ "int8x64",
+ "int16x64",
+ "int32x64",
+ "int64x64",
+ "uint8",
+ "uint16",
+ "uint32",
+ "uint64",
+ "uint8x4",
+ "uint16x4",
+ "uint32x4",
+ "uint64x4",
+ "uint8x8",
+ "uint16x8",
+ "uint32x8",
+ "uint64x8",
+ "uint8x16",
+ "uint16x16",
+ "uint32x16",
+ "uint64x16",
+ "uint8x32",
+ "uint16x32",
+ "uint32x32",
+ "uint64x32",
+ "uint8x64",
+ "uint16x64",
+ "uint32x64",
+ "uint64x64",
+ "float8",
+ "float16",
+ "float32",
+ "float64",
+ "float8x4",
+ "float16x4",
+ "float32x4",
+ "float64x4",
+ "float8x8",
+ "float16x8",
+ "float32x8",
+ "float64x8",
+ "float8x16",
+ "float16x16",
+ "float32x16",
+ "float64x16",
+ "float8x32",
+ "float16x32",
+ "float32x32",
+ "float64x32",
+ "float8x64",
+ "float16x64",
+ "float32x64",
+ "float64x64",
"buffer_decl",
"prim_func",
"arg",