This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new caf6b0339c [TVMScript][Parser] Add more warp-level builtins and `Range` (#14279)
caf6b0339c is described below
commit caf6b0339c48763177657c011cb4ad5c55be6bbe
Author: Zihao Ye <[email protected]>
AuthorDate: Mon Mar 13 03:57:08 2023 +0800
[TVMScript][Parser] Add more warp-level builtins and `Range` (#14279)
# Motivation
Several builtins (`tvm_storage_sync`, `tvm_warp_shuffle`,
`tvm_warp_shuffle_up`, `tvm_warp_shuffle_down`, `tvm_warp_activemask`) and
`Range` are emitted by the TVMScript printer but are missing from the
TVMScript parser, so printed scripts cannot be parsed back. This PR adds them.
---
python/tvm/script/ir_builder/tir/ir.py | 51 +++++++---
python/tvm/tir/op.py | 108 +++++++++++++++++++++-
tests/python/unittest/test_tvmscript_roundtrip.py | 55 +++++++++++
3 files changed, 201 insertions(+), 13 deletions(-)
diff --git a/python/tvm/script/ir_builder/tir/ir.py b/python/tvm/script/ir_builder/tir/ir.py
index d65f9adea8..45350c5a65 100644
--- a/python/tvm/script/ir_builder/tir/ir.py
+++ b/python/tvm/script/ir_builder/tir/ir.py
@@ -29,7 +29,8 @@ from typing_extensions import Literal
 import numpy as np  # type: ignore
 
 from tvm import tir
-from tvm.ir import Range, Type
+from tvm import ir
+from tvm.ir import Type
 from tvm.ir.base import deprecated
 from tvm.runtime import String, convert, ndarray
 from tvm.target import Target
@@ -496,7 +497,7 @@ def alloc_buffer(
     )
 
 
-def _as_range(dom: Union[Range, List[PrimExpr]]) -> Range:
+def _as_range(dom: Union[ir.Range, List[PrimExpr]]) -> ir.Range:
     """The range constructor.
 
     Parameters
@@ -509,13 +510,13 @@ def _as_range(dom: Union[Range, List[PrimExpr]]) -> Range:
     res : Range
         The Range.
     """
-    if isinstance(dom, Range):
+    if isinstance(dom, ir.Range):
         return dom
     if isinstance(dom, (list, tuple)):
-        return Range(dom[0], dom[1])
+        return ir.Range(dom[0], dom[1])
     if hasattr(dom, "dtype"):
-        return Range(IntImm(dom.dtype, 0), dom)
-    return Range(0, dom)
+        return ir.Range(IntImm(dom.dtype, 0), dom)
+    return ir.Range(0, dom)
 
 
 class axis:  # pylint: disable=invalid-name
@@ -523,7 +524,7 @@ class axis:  # pylint: disable=invalid-name
 
     @staticmethod
     def spatial(
-        dom: Union[Range, List[PrimExpr], Tuple[PrimExpr]],
+        dom: Union[ir.Range, List[PrimExpr], Tuple[PrimExpr]],
         binding: PrimExpr,
         dtype: str = "int32",
     ) -> Var:
@@ -551,7 +552,7 @@ class axis:  # pylint: disable=invalid-name
 
     @staticmethod
     def reduce(
-        dom: Union[Range, List[PrimExpr], Tuple[PrimExpr]],
+        dom: Union[ir.Range, List[PrimExpr], Tuple[PrimExpr]],
         binding: PrimExpr,
         dtype: str = "int32",
     ) -> Var:
@@ -579,7 +580,7 @@ class axis:  # pylint: disable=invalid-name
 
     @staticmethod
     def scan(
-        dom: Union[Range, List[PrimExpr], Tuple[PrimExpr]],
+        dom: Union[ir.Range, List[PrimExpr], Tuple[PrimExpr]],
         binding: PrimExpr,
         dtype: str = "int32",
     ) -> Var:
@@ -607,7 +608,7 @@ class axis:  # pylint: disable=invalid-name
 
     @staticmethod
     def opaque(
-        dom: Union[Range, List[PrimExpr], Tuple[PrimExpr]],
+        dom: Union[ir.Range, List[PrimExpr], Tuple[PrimExpr]],
         binding: PrimExpr,
         dtype: str = "int32",
     ) -> Var:
@@ -1288,7 +1289,7 @@ def buffer_store(
 
 def prefetch(
     buffer: Buffer,  # pylint: disable=redefined-outer-name
-    bounds: List[Range],
+    bounds: List[ir.Range],
 ) -> None:
     """The prefetch hint for a buffer.
 
@@ -1579,7 +1580,7 @@ def max(a: PrimExpr, b: PrimExpr) -> PrimExpr:  # pylint: disable=redefined-buil
     return _ffi_api.max(a, b)  # type: ignore[attr-defined] # pylint: disable=no-member
 
 
-def iter_var(v: Union[Var, str], dom: Range, iter_type: str, thread_tag: str) -> IterVar:
+def iter_var(v: Union[Var, str], dom: ir.Range, iter_type: str, thread_tag: str) -> IterVar:
     """The iteration variable.
 
     Parameters
@@ -1666,6 +1667,21 @@ def target(target_config: Union[Dict, str]) -> Target:
     return Target(target_config)
 
 
+def Range(begin: PrimExpr, end: PrimExpr) -> ir.Range:  # pylint: disable=invalid-name
+    """
+    Create a Range object.
+
+    Parameters
+    ----------
+    begin : PrimExpr
+        The begin value of the range.
+
+    end : PrimExpr
+        The end value of the range.
+    """
+    return ir.Range(begin, end)
+
+
 class meta_var:  # pylint: disable=invalid-name
     """A meta variable used in TVMScript metaprogramming. It means that the value of the variable
     does not appear in the final TIR, but only stays in the parser.
@@ -1782,6 +1798,11 @@ tvm_mma_sync = _op_wrapper(_tir_op.tvm_mma_sync)
 tvm_bmma_sync = _op_wrapper(_tir_op.tvm_bmma_sync)
 tvm_fill_fragment = _op_wrapper(_tir_op.tvm_fill_fragment)
 tvm_store_matrix_sync = _op_wrapper(_tir_op.tvm_store_matrix_sync)
+tvm_storage_sync = _tir_op.tvm_storage_sync
+tvm_warp_shuffle = _tir_op.tvm_warp_shuffle
+tvm_warp_shuffle_up = _tir_op.tvm_warp_shuffle_up
+tvm_warp_shuffle_down = _tir_op.tvm_warp_shuffle_down
+tvm_warp_activemask = _tir_op.tvm_warp_activemask
 ptx_wait_group = _op_wrapper(_tir_op.ptx_wait_group)
 ptx_commit_group = _op_wrapper(_tir_op.ptx_commit_group)
 assume = _op_wrapper(_tir_op.assume)
@@ -2042,6 +2063,11 @@ __all__ = [
     "tvm_bmma_sync",
     "tvm_fill_fragment",
     "tvm_store_matrix_sync",
+    "tvm_storage_sync",
+    "tvm_warp_shuffle",
+    "tvm_warp_shuffle_up",
+    "tvm_warp_shuffle_down",
+    "tvm_warp_activemask",
     "ptx_mma",
     "ptx_mma_sp",
     "ptx_ldmatrix",
@@ -2109,4 +2135,5 @@ __all__ = [
     "Let",
     "IterVar",
     "CommReducer",
+    "Range",
 ]
diff --git a/python/tvm/tir/op.py b/python/tvm/tir/op.py
index 0a9c4fdfaa..0fe460c085 100644
--- a/python/tvm/tir/op.py
+++ b/python/tvm/tir/op.py
@@ -569,7 +569,8 @@ def lookup_param(param_name, span=None):
 
 
 def tvm_thread_allreduce(*freduce_args):
-    """
+    """Perform allreduce inside a thread block.
+
     Parameters
     ----------
     freduce_args : Expr
@@ -583,6 +584,111 @@ def tvm_thread_allreduce(*freduce_args):
     return call_intrin("handle", "tir.tvm_thread_allreduce", *freduce_args)
 
 
+def tvm_storage_sync(storage_scope):
+    """Perform synchronization in the specified scope.
+
+    Parameters
+    ----------
+    storage_scope : str
+        The storage scope in which to perform synchronization.
+
+    Returns
+    -------
+    call : PrimExpr
+        The call expression.
+    """
+    return call_intrin("int32", "tir.tvm_storage_sync", storage_scope)
+
+
+def tvm_warp_shuffle(mask, value, warp_id, width, warp_size):
+    """Exchange value between threads inside a warp.
+
+    Parameters
+    ----------
+    mask : PrimExpr
+        The warp mask indicating the active threads inside the warp.
+    value : PrimExpr
+        The value to exchange.
+    warp_id : PrimExpr
+        The source lane index to fetch the value from.
+    width : PrimExpr
+        The width of the sub-sections in which the warp shuffle is performed.
+    warp_size : PrimExpr
+        The warp size.
+
+    Returns
+    -------
+    call : PrimExpr
+        The call expression.
+    """
+    return call_intrin(value.dtype, "tir.tvm_warp_shuffle", mask, value, warp_id, width, warp_size)
+
+
+def tvm_warp_shuffle_up(mask, value, offset, width, warp_size):
+    """Copy value from a lane with lower (by offset) index relative to caller.
+
+    Parameters
+    ----------
+    mask : PrimExpr
+        The warp mask indicating the active threads inside the warp.
+    value : PrimExpr
+        The value to exchange.
+    offset : PrimExpr
+        The difference between source lane index and destination lane index:
+        `offset = dst_lane_idx - src_lane_idx`
+    width : PrimExpr
+        The width of the sub-sections in which the warp shuffle is performed.
+    warp_size : PrimExpr
+        The warp size.
+
+    Returns
+    -------
+    call : PrimExpr
+        The call expression.
+    """
+    return call_intrin(
+        value.dtype, "tir.tvm_warp_shuffle_up", mask, value, offset, width, warp_size
+    )
+
+
+def tvm_warp_shuffle_down(mask, value, offset, width, warp_size):
+    """Copy value from a lane with higher (by offset) index relative to caller.
+
+    Parameters
+    ----------
+    mask : PrimExpr
+        The warp mask indicating the active threads inside the warp.
+    value : PrimExpr
+        The value to exchange.
+    offset : PrimExpr
+        The difference between source lane index and destination lane index:
+        `offset = src_lane_idx - dst_lane_idx`
+    width : PrimExpr
+        The width of the sub-sections in which the warp shuffle is performed.
+    warp_size : PrimExpr
+        The warp size.
+
+    Returns
+    -------
+    call : PrimExpr
+        The call expression.
+    """
+    return call_intrin(
+        value.dtype, "tir.tvm_warp_shuffle_down", mask, value, offset, width, warp_size
+    )
+
+
+def tvm_warp_activemask():
+    """Return a 32-bit mask indicating currently active threads in the calling warp.
+
+    Returns
+    -------
+    call : PrimExpr
+        The call expression.
+    """
+    return call_intrin("uint32", "tir.tvm_warp_activemask")
+
+
 def type_annotation(dtype):
     """Create a type annotation expression
 
diff --git a/tests/python/unittest/test_tvmscript_roundtrip.py b/tests/python/unittest/test_tvmscript_roundtrip.py
index c956f3bb02..6f07b6a75a 100644
--- a/tests/python/unittest/test_tvmscript_roundtrip.py
+++ b/tests/python/unittest/test_tvmscript_roundtrip.py
@@ -3623,6 +3623,60 @@ def merge_shape_var_def():
     return main
 
 
+def tvm_shfl_builtins():
+    @T.prim_func
+    def func(
+        A: T.handle("float32"),
+        B: T.handle("float32"),
+        C: T.handle("float32"),
+    ):
+        blockIdx_x = T.launch_thread("blockIdx.x", 1)
+        threadIdx_x = T.launch_thread("threadIdx.x", 32)
+        A_warp = T.allocate([1], "float32", "local")
+        B_warp = T.allocate([1], "float32", "local")
+        red_buf0 = T.allocate([1], "float32", "local")
+        A_warp_1 = T.Buffer((32,), data=A_warp, scope="local")
+        A_1 = T.Buffer((32,), data=A)
+        A_warp_1[0] = A_1[threadIdx_x]
+        B_warp_1 = T.Buffer((32,), data=B_warp, scope="local")
+        T.tvm_storage_sync("warp")
+        B_warp_1[0] = T.tvm_warp_shuffle(
+            T.tvm_warp_activemask(), A_warp_1[0], threadIdx_x % 4 * 8 + threadIdx_x // 4, 32, 32
+        ) + T.float32(1)
+        red_buf0_1 = T.Buffer((1,), data=red_buf0, scope="local")
+        with T.attr(
+            T.comm_reducer(lambda x0, y0: x0 + y0, [T.float32(0)]),
+            "reduce_scope",
+            T.reinterpret("handle", T.uint64(0)),
+        ):
+            mask = T.allocate([1], "uint32", "local")
+            t0 = T.allocate([1], "float32", "local")
+            red_buf0_1[0] = A_warp_1[0]
+            mask_1 = T.Buffer((1,), "uint32", data=mask, scope="local")
+            mask_1[0] = T.tvm_warp_activemask()
+            t0_1 = T.Buffer((1,), data=t0, scope="local")
+            t0_1[0] = T.tvm_warp_shuffle_down(mask_1[0], red_buf0_1[0], 16, 32, 32)
+            red_buf0_1[0] = red_buf0_1[0] + t0_1[0]
+            t0_1[0] = T.tvm_warp_shuffle_down(mask_1[0], red_buf0_1[0], 8, 32, 32)
+            red_buf0_1[0] = red_buf0_1[0] + t0_1[0]
+            t0_1[0] = T.tvm_warp_shuffle_down(mask_1[0], red_buf0_1[0], 4, 32, 32)
+            red_buf0_1[0] = red_buf0_1[0] + t0_1[0]
+            t0_1[0] = T.tvm_warp_shuffle_down(mask_1[0], red_buf0_1[0], 2, 32, 32)
+            red_buf0_1[0] = red_buf0_1[0] + t0_1[0]
+            t0_1[0] = T.tvm_warp_shuffle_down(mask_1[0], red_buf0_1[0], 1, 32, 32)
+            red_buf0_1[0] = red_buf0_1[0] + t0_1[0]
+            red_buf0_1[0] = T.tvm_warp_shuffle(mask_1[0], red_buf0_1[0], 0, 32, 32)
+            # NOTE(Zihao): test tvm_warp_shuffle_up
+            red_buf0_1[0] = T.tvm_warp_shuffle_up(mask_1[0], red_buf0_1[0], 0, 32, 32)
+        if threadIdx_x == 0:
+            C_1 = T.Buffer((1,), data=C)
+            C_1[0] = red_buf0_1[0]
+        B_1 = T.Buffer((32,), data=B)
+        B_1[threadIdx_x] = B_warp_1[0]
+
+    return func
+
+
 ir_generator = tvm.testing.parameter(
     launch_env_thread,
     opt_gemm_normalize,
@@ -3686,6 +3740,7 @@ ir_generator = tvm.testing.parameter(
     let_stmt_value,
     string_stride,
     merge_shape_var_def,
+    tvm_shfl_builtins,
 )
 
 