This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new caf6b0339c [TVMScript][Parser] Add more warp-level builtins and `Range` (#14279)
caf6b0339c is described below

commit caf6b0339c48763177657c011cb4ad5c55be6bbe
Author: Zihao Ye <[email protected]>
AuthorDate: Mon Mar 13 03:57:08 2023 +0800

    [TVMScript][Parser] Add more warp-level builtins and `Range` (#14279)
    
    # Motivation
    Several builtins ("tvm_storage_sync", "tvm_warp_shuffle",
    "tvm_warp_shuffle_up", "tvm_warp_shuffle_down", "tvm_warp_activemask")
    and `Range` appear in the TVMScript printer's output but are missing
    from the TVMScript parser. This PR fixes that behavior by adding them
    to the parser, so that printed scripts can be parsed back (round-tripped).
---
 python/tvm/script/ir_builder/tir/ir.py            |  51 +++++++---
 python/tvm/tir/op.py                              | 108 +++++++++++++++++++++-
 tests/python/unittest/test_tvmscript_roundtrip.py |  55 +++++++++++
 3 files changed, 201 insertions(+), 13 deletions(-)

diff --git a/python/tvm/script/ir_builder/tir/ir.py 
b/python/tvm/script/ir_builder/tir/ir.py
index d65f9adea8..45350c5a65 100644
--- a/python/tvm/script/ir_builder/tir/ir.py
+++ b/python/tvm/script/ir_builder/tir/ir.py
@@ -29,7 +29,8 @@ from typing_extensions import Literal
 import numpy as np  # type: ignore
 
 from tvm import tir
-from tvm.ir import Range, Type
+from tvm import ir
+from tvm.ir import Type
 from tvm.ir.base import deprecated
 from tvm.runtime import String, convert, ndarray
 from tvm.target import Target
@@ -496,7 +497,7 @@ def alloc_buffer(
     )
 
 
-def _as_range(dom: Union[Range, List[PrimExpr]]) -> Range:
+def _as_range(dom: Union[ir.Range, List[PrimExpr]]) -> ir.Range:
     """The range constructor.
 
     Parameters
@@ -509,13 +510,13 @@ def _as_range(dom: Union[Range, List[PrimExpr]]) -> Range:
     res : Range
         The Range.
     """
-    if isinstance(dom, Range):
+    if isinstance(dom, ir.Range):
         return dom
     if isinstance(dom, (list, tuple)):
-        return Range(dom[0], dom[1])
+        return ir.Range(dom[0], dom[1])
     if hasattr(dom, "dtype"):
-        return Range(IntImm(dom.dtype, 0), dom)
-    return Range(0, dom)
+        return ir.Range(IntImm(dom.dtype, 0), dom)
+    return ir.Range(0, dom)
 
 
 class axis:  # pylint: disable=invalid-name
@@ -523,7 +524,7 @@ class axis:  # pylint: disable=invalid-name
 
     @staticmethod
     def spatial(
-        dom: Union[Range, List[PrimExpr], Tuple[PrimExpr]],
+        dom: Union[ir.Range, List[PrimExpr], Tuple[PrimExpr]],
         binding: PrimExpr,
         dtype: str = "int32",
     ) -> Var:
@@ -551,7 +552,7 @@ class axis:  # pylint: disable=invalid-name
 
     @staticmethod
     def reduce(
-        dom: Union[Range, List[PrimExpr], Tuple[PrimExpr]],
+        dom: Union[ir.Range, List[PrimExpr], Tuple[PrimExpr]],
         binding: PrimExpr,
         dtype: str = "int32",
     ) -> Var:
@@ -579,7 +580,7 @@ class axis:  # pylint: disable=invalid-name
 
     @staticmethod
     def scan(
-        dom: Union[Range, List[PrimExpr], Tuple[PrimExpr]],
+        dom: Union[ir.Range, List[PrimExpr], Tuple[PrimExpr]],
         binding: PrimExpr,
         dtype: str = "int32",
     ) -> Var:
@@ -607,7 +608,7 @@ class axis:  # pylint: disable=invalid-name
 
     @staticmethod
     def opaque(
-        dom: Union[Range, List[PrimExpr], Tuple[PrimExpr]],
+        dom: Union[ir.Range, List[PrimExpr], Tuple[PrimExpr]],
         binding: PrimExpr,
         dtype: str = "int32",
     ) -> Var:
@@ -1288,7 +1289,7 @@ def buffer_store(
 
 def prefetch(
     buffer: Buffer,  # pylint: disable=redefined-outer-name
-    bounds: List[Range],
+    bounds: List[ir.Range],
 ) -> None:
     """The prefetch hint for a buffer.
 
@@ -1579,7 +1580,7 @@ def max(a: PrimExpr, b: PrimExpr) -> PrimExpr:  # pylint: 
disable=redefined-buil
     return _ffi_api.max(a, b)  # type: ignore[attr-defined] # pylint: 
disable=no-member
 
 
-def iter_var(v: Union[Var, str], dom: Range, iter_type: str, thread_tag: str) 
-> IterVar:
+def iter_var(v: Union[Var, str], dom: ir.Range, iter_type: str, thread_tag: 
str) -> IterVar:
     """The iteration variable.
 
     Parameters
@@ -1666,6 +1667,21 @@ def target(target_config: Union[Dict, str]) -> Target:
     return Target(target_config)
 
 
+def Range(begin: PrimExpr, end: PrimExpr) -> ir.Range:  # pylint: 
disable=invalid-name
+    """
+    Create a Range object.
+
+    Parameters
+    ----------
+    begin : PrimExpr
+        The begin value of the range.
+
+    end : Optional[PrimExpr]
+        The end value of the range.
+    """
+    return ir.Range(begin, end)
+
+
 class meta_var:  # pylint: disable=invalid-name
     """A meta variable used in TVMScript metaprogramming. It means that the 
value of the variable
     does not appear in the final TIR, but only stays in the parser.
@@ -1782,6 +1798,11 @@ tvm_mma_sync = _op_wrapper(_tir_op.tvm_mma_sync)
 tvm_bmma_sync = _op_wrapper(_tir_op.tvm_bmma_sync)
 tvm_fill_fragment = _op_wrapper(_tir_op.tvm_fill_fragment)
 tvm_store_matrix_sync = _op_wrapper(_tir_op.tvm_store_matrix_sync)
+tvm_storage_sync = _tir_op.tvm_storage_sync
+tvm_warp_shuffle = _tir_op.tvm_warp_shuffle
+tvm_warp_shuffle_up = _tir_op.tvm_warp_shuffle_up
+tvm_warp_shuffle_down = _tir_op.tvm_warp_shuffle_down
+tvm_warp_activemask = _tir_op.tvm_warp_activemask
 ptx_wait_group = _op_wrapper(_tir_op.ptx_wait_group)
 ptx_commit_group = _op_wrapper(_tir_op.ptx_commit_group)
 assume = _op_wrapper(_tir_op.assume)
@@ -2042,6 +2063,11 @@ __all__ = [
     "tvm_bmma_sync",
     "tvm_fill_fragment",
     "tvm_store_matrix_sync",
+    "tvm_storage_sync",
+    "tvm_warp_shuffle",
+    "tvm_warp_shuffle_up",
+    "tvm_warp_shuffle_down",
+    "tvm_warp_activemask",
     "ptx_mma",
     "ptx_mma_sp",
     "ptx_ldmatrix",
@@ -2109,4 +2135,5 @@ __all__ = [
     "Let",
     "IterVar",
     "CommReducer",
+    "Range",
 ]
diff --git a/python/tvm/tir/op.py b/python/tvm/tir/op.py
index 0a9c4fdfaa..0fe460c085 100644
--- a/python/tvm/tir/op.py
+++ b/python/tvm/tir/op.py
@@ -569,7 +569,8 @@ def lookup_param(param_name, span=None):
 
 
 def tvm_thread_allreduce(*freduce_args):
-    """
+    """Perform allreduce inside threadblock.
+
     Parameters
     ----------
     freduce_args : Expr
@@ -583,6 +584,111 @@ def tvm_thread_allreduce(*freduce_args):
     return call_intrin("handle", "tir.tvm_thread_allreduce", *freduce_args)
 
 
+def tvm_storage_sync(storage_scope):
+    """Perform synchronization in specified scope.
+
+    Parameters
+    ----------
+    storage_scope : str
+        The storage scope to perform synchronization.
+
+    Returns
+    -------
+    call : PrimExpr
+        The call expression.
+    """
+    return call_intrin("handle", "tir.tvm_storage_sync", storage_scope)
+
+
+def tvm_warp_shuffle(mask, value, warp_id, width, warp_size):
+    """Exchange value between threads inside a warp.
+
+    Parameters
+    ----------
+    mask : PrimExpr
+        The warp mask indicates active threads inside warp.
+    value : PrimExpr
+        The value to exchange.
+    warp_id : PrimExpr
+        The source lane index to fetch value.
+    width : PrimExpr
+        The width of sub-sections to perform warp shuffle.
+    warp_size : PrimExpr
+        The warp size.
+
+    Returns
+    -------
+    call : PrimExpr
+        The call expression.
+    """
+    return call_intrin(value.dtype, "tir.tvm_warp_shuffle", mask, value, 
warp_id, width, warp_size)
+
+
+def tvm_warp_shuffle_up(mask, value, offset, width, warp_size):
+    """Copy value from a lane with lower (by offset) index relative to caller.
+
+    Parameters
+    ----------
+    mask : PrimExpr
+        The warp mask indicates active threads inside warp.
+    value : PrimExpr
+        The value to exchange.
+    offset : PrimExpr
+        The difference between source lane index and destination lane index:
+        `offset = dst_lane_idx - src_lane_idx`
+    width : PrimExpr
+        The width of sub-sections to perform warp shuffle.
+    warp_size : PrimExpr
+        The warp size.
+
+    Returns
+    -------
+    call : PrimExpr
+        The call expression.
+    """
+    return call_intrin(
+        value.dtype, "tir.tvm_warp_shuffle_up", mask, value, offset, width, 
warp_size
+    )
+
+
+def tvm_warp_shuffle_down(mask, value, offset, width, warp_size):
+    """Copy value from a lane with higher (by offset) index relative to caller.
+
+    Parameters
+    ----------
+    mask : PrimExpr
+        The warp mask indicates active threads inside warp.
+    value : PrimExpr
+        The value to exchange.
+    offset : PrimExpr
+        The difference between source lane index and destination lane index:
+        `offset = src_lane_idx - dst_lane_idx`
+    width : PrimExpr
+        The width of sub-sections to perform warp shuffle.
+    warp_size : PrimExpr
+        The warp size.
+
+    Returns
+    -------
+    call : PrimExpr
+        The call expression.
+    """
+    return call_intrin(
+        value.dtype, "tir.tvm_warp_shuffle_down", mask, value, offset, width, 
warp_size
+    )
+
+
+def tvm_warp_activemask():
+    """Return a 32-bit mask indicates currently active threads in a calling 
warp.
+
+    Returns
+    -------
+    call : PrimExpr
+        The call expression.
+    """
+    return call_intrin("uint32", "tir.tvm_warp_activemask")
+
+
 def type_annotation(dtype):
     """Create a type annotation expression
 
diff --git a/tests/python/unittest/test_tvmscript_roundtrip.py 
b/tests/python/unittest/test_tvmscript_roundtrip.py
index c956f3bb02..6f07b6a75a 100644
--- a/tests/python/unittest/test_tvmscript_roundtrip.py
+++ b/tests/python/unittest/test_tvmscript_roundtrip.py
@@ -3623,6 +3623,60 @@ def merge_shape_var_def():
     return main
 
 
+def tvm_shfl_builtins():
+    @T.prim_func
+    def func(
+        A: T.handle("float32"),
+        B: T.handle("float32"),
+        C: T.handle("float32"),
+    ):
+        blockIdx_x = T.launch_thread("blockIdx.x", 1)
+        threadIdx_x = T.launch_thread("threadIdx.x", 32)
+        A_warp = T.allocate([1], "float32", "local")
+        B_warp = T.allocate([1], "float32", "local")
+        red_buf0 = T.allocate([1], "float32", "local")
+        A_warp_1 = T.Buffer((32,), data=A_warp, scope="local")
+        A_1 = T.Buffer((32,), data=A)
+        A_warp_1[0] = A_1[threadIdx_x]
+        B_warp_1 = T.Buffer((32,), data=B_warp, scope="local")
+        T.tvm_storage_sync("warp")
+        B_warp_1[0] = T.tvm_warp_shuffle(
+            T.tvm_warp_activemask(), A_warp_1[0], threadIdx_x % 4 * 8 + 
threadIdx_x // 4, 32, 32
+        ) + T.float32(1)
+        red_buf0_1 = T.Buffer((1,), data=red_buf0, scope="local")
+        with T.attr(
+            T.comm_reducer(lambda x0, y0: x0 + y0, [T.float32(0)]),
+            "reduce_scope",
+            T.reinterpret("handle", T.uint64(0)),
+        ):
+            mask = T.allocate([1], "uint32", "local")
+            t0 = T.allocate([1], "float32", "local")
+            red_buf0_1[0] = A_warp_1[0]
+            mask_1 = T.Buffer((1,), "uint32", data=mask, scope="local")
+            mask_1[0] = T.tvm_warp_activemask()
+            t0_1 = T.Buffer((1,), data=t0, scope="local")
+            t0_1[0] = T.tvm_warp_shuffle_down(mask_1[0], red_buf0_1[0], 16, 
32, 32)
+            red_buf0_1[0] = red_buf0_1[0] + t0_1[0]
+            t0_1[0] = T.tvm_warp_shuffle_down(mask_1[0], red_buf0_1[0], 8, 32, 
32)
+            red_buf0_1[0] = red_buf0_1[0] + t0_1[0]
+            t0_1[0] = T.tvm_warp_shuffle_down(mask_1[0], red_buf0_1[0], 4, 32, 
32)
+            red_buf0_1[0] = red_buf0_1[0] + t0_1[0]
+            t0_1[0] = T.tvm_warp_shuffle_down(mask_1[0], red_buf0_1[0], 2, 32, 
32)
+            red_buf0_1[0] = red_buf0_1[0] + t0_1[0]
+            t0_1[0] = T.tvm_warp_shuffle_down(mask_1[0], red_buf0_1[0], 1, 32, 
32)
+            red_buf0_1[0] = red_buf0_1[0] + t0_1[0]
+            red_buf0_1[0] = T.tvm_warp_shuffle(mask_1[0], red_buf0_1[0], 0, 
32, 32)
+            # NOTE(Zihao): test tvm_warp_shuffle_up
+            red_buf0_1[0] = T.tvm_warp_shuffle_up(mask_1[0], red_buf0_1[0], 0, 
32, 32)
+        if threadIdx_x == 0:
+            C_1 = T.Buffer((1,), data=C)
+            C_1[0] = red_buf0_1[0]
+        B_1 = T.Buffer((32,), data=B)
+        B_1[threadIdx_x] = B_warp_1[0]
+
+    return func
+
+
 ir_generator = tvm.testing.parameter(
     launch_env_thread,
     opt_gemm_normalize,
@@ -3686,6 +3740,7 @@ ir_generator = tvm.testing.parameter(
     let_stmt_value,
     string_stride,
     merge_shape_var_def,
+    tvm_shfl_builtins,
 )
 
 

Reply via email to