comaniac commented on a change in pull request #7630:
URL: https://github.com/apache/tvm/pull/7630#discussion_r594049237



##########
File path: python/tvm/script/scope_handler.py
##########
@@ -16,32 +16,59 @@
 # under the License.
 """TVM Script Parser Scope Handler Classes"""
 # pylint: disable=redefined-builtin, unused-argument, invalid-name, 
relative-beyond-top-level
+from typing import Tuple, Any, Callable, Optional, List, Union, Mapping
 
+import synr
 from synr import ast
 import tvm.tir
-from .utils import get_param_list, from_synr_span
+from tvm.runtime import Object
+from tvm.ir import Span, Range
+from tvm.tir import Stmt, PrimExpr, IterVar, Var, Buffer, BufferRegion, ForKind
+
+from .context_maintainer import ContextMaintainer
+from .utils import (
+    get_param_list,
+    tvm_span_from_synr,
+    buffer_slice_to_region,
+    call_with_error_reporting,
+)
 from .registry import register
+from .node import BufferSlice
 
 
 class ScopeHandler:
     """Base class for all scope handlers"""
 
-    def __init__(self, func):
-        self.func = func
-        self.body = None
-        self.node = None
-        self.context = None
+    def __init__(self, func: Callable):

Review comment:
       I found that some classes added/changed by this PR have docstrings but 
others don't. It would be good to add the missing ones for the classes 
changed by this PR.

##########
File path: tests/python/unittest/test_tvmscript_error_report.py
##########
@@ -144,6 +144,197 @@ def test_no_body():
     check_error(no_body, 3)
 
 
+def allocate_with_buffers() -> None:

Review comment:
       nit: why do some tests have return annotations but others don't? 
IMHO, it's fine to omit the type annotation for a function returning None.

##########
File path: python/tvm/script/special_stmt.py
##########
@@ -142,21 +181,300 @@ def buffer_decl(
                 buffer_type,
                 span=span,
             )
-            self.context.update_symbol(self.node.lhs.id.name, buffer)
+            self.context.update_symbol(self.node.lhs.id.name, buffer, 
self.node)
             return buffer
 
         super().__init__(buffer_decl, def_symbol=True)
 
 
+@register
+class AllocBuffer(SpecialStmt):
+    """Special function alloc_buffer(shape, dtype, data, strides, elem_offset, 
scope, align,
+                                     offset_factor, buffer_type)
+
+    Example
+    -------
+    .. code-block:: python
+
+        A = tir.alloc_buffer((128, 128), dtype="float32")
+
+    """
+
+    def __init__(self):
+        def alloc_buffer(
+            shape,
+            dtype="float32",
+            data=None,
+            strides=None,
+            elem_offset=None,
+            scope="",
+            align=-1,
+            offset_factor=0,
+            buffer_type="default",
+            span=None,
+        ):
+            if not isinstance(self.node, ast.Assign):
+                self.context.report_error(
+                    "alloc_buffer must be assigned to a buffer, e.g. A = 
alloc_buffer(...)",
+                    self.node.span,
+                )
+
+            if strides is None:
+                strides = []
+            align = convert_to_int(align, "align", self.context.report_error, 
self.node.span)
+            offset_factor = convert_to_int(
+                offset_factor, "offset_factor", self.context.report_error, 
self.node.span
+            )
+            buffer = tvm.tir.decl_buffer(
+                shape,
+                dtype,
+                self.node.lhs.id.name,
+                data,
+                strides,
+                elem_offset,
+                scope,
+                align,
+                offset_factor,
+                buffer_type,
+                span=span,
+            )
+            self.context.current_block_scope().alloc_buffers.append(buffer)
+            self.context.update_symbol(self.node.lhs.id.name, buffer, 
self.node)
+
+        super().__init__(alloc_buffer, def_symbol=True)
+
+
+@register
+class BlockVarBind(SpecialStmt):
+    """Special function bind(block_iter, binding_value)
+
+    Example
+    -------
+    .. code-block:: python
+
+        tir.bind(vx, i)
+
+    """
+
+    def __init__(self):
+        def bind(iter_var, values, span=None):
+            block_scope = self.context.current_block_scope()
+            if iter_var in block_scope.iter_bindings:
+                self.context.report_error("Duplicate iter_var bindings of " + 
str(iter_var), span)
+            block_scope.iter_bindings[iter_var] = values
+
+        super().__init__(bind, def_symbol=False)
+
+
+@register
+class BlockReads(SpecialStmt):
+    """Special function reads([read_buffer_regions])
+
+    Example
+    -------
+    .. code-block:: python
+
+        tir.reads([A[vi: vi + 4, vk: vk + 4], B[vk: vk + 4, vj]])
+
+    """
+
+    def __init__(self):
+        def reads(read_regions: Union[BufferSlice, List[BufferSlice]], span: 
Span = None):
+            assert self.context, "call 'exit_scope' before 'enter_scope'"
+            block_scope = self.context.current_block_scope()
+            if block_scope.reads is not None:
+                self.context.report_error(
+                    "Duplicate write region declaration, "
+                    + "previous one is "
+                    + str(", ".join(str(x) for x in block_scope.reads)),
+                    span,
+                )
+            if isinstance(read_regions, list):
+                pass
+            elif isinstance(read_regions, BufferSlice):
+                read_regions = [read_regions]
+            else:
+                self.context.report_error(
+                    "Error input type. "
+                    + f"Expected BufferSlice or List[BufferSlice], but got 
{type(read_regions)}",
+                    span,
+                )
+            block_scope.reads = read_regions
+
+        super().__init__(reads, def_symbol=False)
+
+
+@register
+class BlockWrites(SpecialStmt):
+    """Special function writes([write_buffer_regions])
+
+    Example
+    -------
+    .. code-block:: python
+
+        tir.writes([C[vi: vi + 4, vj])
+
+    """
+
+    def __init__(self):
+        def writes(write_region: Union[BufferSlice, List[BufferSlice]], span: 
Span = None):
+            assert self.context, "call 'exit_scope' before 'enter_scope'"

Review comment:
       Seems like Stmts need this checker? Even if a Stmt doesn't need the info 
from context, it may still need to report errors through it. Could we use a 
wrapper in SpecialStmt to guarantee this so that we don't need to 
repeat this checker in every statement? In case some Stmts don't need this 
check, we could also add an optional flag to `__init__` to turn this 
checker off.

##########
File path: python/tvm/script/scope_handler.py
##########
@@ -185,92 +232,384 @@ def let(var, value, span):
         super().__init__(let, concise_scope=False, def_symbol=False)
 
 
+@register
+class Block(WithScopeHandler):
+    """ With scope handler tir.block(extents, name) as iter_vars"""
+
+    def __init__(self):
+        def block(axes=None, name_hint: str = "", span: Optional[Span] = None):
+            assert (
+                self.node and self.context and self.body
+            ), "call 'exit_scope' before 'enter_scope'"
+            block_info = self.context.block_info_stack[-1]
+            if axes is None:
+                axes = []
+            if len(axes) != len(self.block_vars):
+                self.context.report_error(
+                    "Inconsistent number of block vars, "
+                    + f"there are {len(axes)} axes but {len(self.block_vars)} 
block vars. "
+                    + "The number of block vars should match the number of 
axes.",
+                    self.node.span,
+                )
+            block_iters: List[IterVar] = []
+            for i, axis in enumerate(axes):
+                axis = tvm.runtime.convert(axis)
+                if isinstance(axis, tvm.tir.PrimExpr):
+                    block_var_dom = Range.from_min_extent(0, axis)
+                    block_iters.append(IterVar(block_var_dom, 
self.block_vars[i], 0))
+                elif isinstance(axis, Range):
+                    block_iters.append(IterVar(axis, self.block_vars[i], 0))
+                elif isinstance(axis, IterVar):
+                    block_iters.append(IterVar(axis.dom, self.block_vars[i], 
axis.iter_type))
+                else:
+                    self.context.report_error(
+                        "Invalid argument of tir.block(), "
+                        + f"expected PrimExpr, Range or IterVar, but got 
{type(axis)}",
+                        self.node.span,
+                    )
+
+            # create block read/write regions
+
+            reads: List[BufferRegion] = (
+                [buffer_slice_to_region(read) for read in block_info.reads]
+                if block_info.reads
+                else []
+            )
+            writes: List[BufferRegion] = (
+                [buffer_slice_to_region(write) for write in block_info.writes]
+                if block_info.writes
+                else []
+            )
+            inner = tvm.tir.Block(
+                block_iters,
+                reads,
+                writes,
+                name_hint,
+                self.body,
+                block_info.init,
+                block_info.alloc_buffers,
+                block_info.match_buffers,
+                block_info.annotations,
+                span,
+            )
+            # create block var iter binding
+            values: List[PrimExpr]
+            if not block_info.iter_bindings:
+                values = self.context.loop_stack[-2].copy()
+                if len(values) == 0:
+                    values = [tvm.tir.const(float("nan"), dtype="float32")] * 
len(block_iters)
+                elif len(values) != len(block_iters):
+                    self.context.report_error(
+                        "Number of block iter var and outer loop nesting 
mismatch, "
+                        + f"{len(block_iters)} block iter vars but 
{len(values)} loops",
+                        self.node.span,
+                    )
+            else:
+                for block_var in self.block_vars:
+                    if block_var not in block_info.iter_bindings:
+                        self.context.report_error(
+                            "Missing block iter var binding for " + 
block_var.name,
+                            self.node.span,
+                        )
+                values = [block_info.iter_bindings[block_var] for block_var in 
self.block_vars]
+            predicate = (
+                tvm.tir.const(True, "bool")
+                if block_info.predicate is None
+                else block_info.predicate
+            )
+            body = tvm.tir.BlockRealize(values, predicate, inner, span)
+            return body
+
+        super().__init__(func=block, concise_scope=False, def_symbol=True)
+        self.block_vars = None
+
+    def enter_scope(
+        self,
+        node: synr.ast.Node,
+        context: ContextMaintainer,
+        arg_list: List[Any],
+        span: synr.ast.Span,
+    ):
+        # define block vars
+        assert isinstance(
+            node, ast.With
+        ), f"BlockScopeHandler expected to work on ast.With but got 
{type(node)}"
+
+        var_names = WithScopeHandler.get_optional_var_names(node, context)
+        self.block_vars = [tvm.te.var(name) for name in var_names]
+        for block_var in self.block_vars:
+            context.update_symbol(block_var.name, block_var, node)
+
+
+@register
+class InitBlock(WithScopeHandler):
+    """ With scope handler tir.init()"""
+
+    def __init__(self):
+        def init(span: Span = None):
+            assert self.context, "call 'exit_scope' before 'enter_scope'"
+            if self.context.block_info_stack[-2].init is not None:
+                self.context.report_error("Duplicate init block declaration", 
span)
+            self.context.block_info_stack[-2].init = self.body
+
+        super().__init__(func=init, concise_scope=False, def_symbol=True)
+
+
 class ForScopeHandler(ScopeHandler):
     """Base class for all for scope handlers"""
 
     def __init__(self, func):
         super().__init__(func)
-        self.loop_vars = None
-
-    def enter_scope(self, node, context, arg_list, span):
-        assert isinstance(node, ast.For)
+        self.loop_vars: Optional[List[Var]] = None
+
+    def enter_scope(
+        self,
+        node: synr.ast.Node,
+        context: ContextMaintainer,
+        arg_list: List[Any],
+        span: synr.ast.Span,
+    ):
+        assert isinstance(
+            node, ast.For
+        ), f"ForScopeHandler expected to work on ast.For but got {type(node)}"
 
         loop_var_names = list()
         spans = list()
         if isinstance(node.lhs, ast.Var):
             loop_var_names.append(node.lhs.id.name)
-            spans.append(from_synr_span(node.lhs.id.span))
-        elif isinstance(node.lhs, ast.Tuple):
-            for elt in node.lhs.values:
+            spans.append(tvm_span_from_synr(node.lhs.id.span))
+        elif isinstance(node.lhs, list):
+            for elt in node.lhs:
                 if not isinstance(elt, ast.Var):
-                    context.report_error("Invalid loop var", elt.span)
+                    context.report_error(
+                        f"Invalid loop var. Expected a var, but got 
{type(elt)}", elt.span
+                    )
                 loop_var_names.append(elt.id.name)
-                spans.append(from_synr_span(elt.id.span))
+                spans.append(tvm_span_from_synr(elt.id.span))
         else:
-            context.report_error("Invalid loop var", node.lhs.span)
+            context.report_error(
+                f"Invalid loop var. Expected var or list of vars as lhs, but 
got {type(node.lhs)}",
+                span,
+            )
 
         self.loop_vars = [
             tvm.te.var(name, dtype="int32", span=span) for name, span in 
zip(loop_var_names, spans)
         ]
         for loop_var in self.loop_vars:
-            context.update_symbol(loop_var.name, loop_var)
+            context.update_symbol(loop_var.name, loop_var, node)
+            context.loop_stack[-1].append(loop_var)
+
+    def exit_scope(
+        self,
+        node: synr.ast.Node,
+        context: ContextMaintainer,
+        arg_list: List[Any],
+        span: synr.ast.Span,
+    ):
+        assert self.loop_vars, "call 'exit_scope' before 'enter_scope'"
+        for _ in self.loop_vars:
+            context.loop_stack[-1].pop()
+        return super().exit_scope(node, context, arg_list, span)
+
+    def create_loop(
+        self,
+        begin: PrimExpr,
+        end: PrimExpr,
+        kind: ForKind,
+        thread_binding: Optional[str] = None,
+        annotations: Optional[Mapping[str, Object]] = None,
+        span: Optional[Span] = None,
+    ) -> tvm.tir.For:
+        """
+        Helper function for creating For in TVM Script parser.
+
+        Parameters
+        ----------
+        begin : PrimExpr
+            The beginning value.
+
+        end : PrimExpr
+            The endding value.
+
+        kind : ForKind
+            The type of the for.
+
+        thread_binding: Optional[str]
+            The thread this loop binds to.
+
+        annotations : Optional[Mapping[str, Object]]
+            Additional annotation hints.
+
+        span : Optional[Span]
+            The location of this for in the source code.
+
+        Returns
+        -------
+        for : For
+            The constructed For.
+        """
+        assert (
+            self.loop_vars and self.context and self.node
+        ), "call 'exit_scope' before 'enter_scope'"
+        if len(self.loop_vars) != 1:
+            self.context.report_error(
+                f"Expect exact only one loop var, but get {self.loop_vars}", 
self.node.span
+            )
+        extent = end if begin == 0 else self.context.analyzer.simplify(end - 
begin)
+        annos: Mapping[str, Object]
+        if annotations is None:
+            annos = {}
+        else:

Review comment:
       ```suggestion
           annos: Mapping[str, Object] = {}
           if annotations is not None:
   ```

##########
File path: python/tvm/script/special_stmt.py
##########
@@ -142,21 +181,300 @@ def buffer_decl(
                 buffer_type,
                 span=span,
             )
-            self.context.update_symbol(self.node.lhs.id.name, buffer)
+            self.context.update_symbol(self.node.lhs.id.name, buffer, 
self.node)
             return buffer
 
         super().__init__(buffer_decl, def_symbol=True)
 
 
+@register
+class AllocBuffer(SpecialStmt):
+    """Special function alloc_buffer(shape, dtype, data, strides, elem_offset, 
scope, align,
+                                     offset_factor, buffer_type)
+
+    Example
+    -------
+    .. code-block:: python
+
+        A = tir.alloc_buffer((128, 128), dtype="float32")
+
+    """
+
+    def __init__(self):
+        def alloc_buffer(
+            shape,
+            dtype="float32",
+            data=None,
+            strides=None,
+            elem_offset=None,
+            scope="",
+            align=-1,
+            offset_factor=0,
+            buffer_type="default",
+            span=None,
+        ):
+            if not isinstance(self.node, ast.Assign):
+                self.context.report_error(
+                    "alloc_buffer must be assigned to a buffer, e.g. A = 
alloc_buffer(...)",
+                    self.node.span,
+                )
+
+            if strides is None:
+                strides = []
+            align = convert_to_int(align, "align", self.context.report_error, 
self.node.span)
+            offset_factor = convert_to_int(
+                offset_factor, "offset_factor", self.context.report_error, 
self.node.span
+            )
+            buffer = tvm.tir.decl_buffer(
+                shape,
+                dtype,
+                self.node.lhs.id.name,
+                data,
+                strides,
+                elem_offset,
+                scope,
+                align,
+                offset_factor,
+                buffer_type,
+                span=span,
+            )
+            self.context.current_block_scope().alloc_buffers.append(buffer)
+            self.context.update_symbol(self.node.lhs.id.name, buffer, 
self.node)
+
+        super().__init__(alloc_buffer, def_symbol=True)
+
+
+@register
+class BlockVarBind(SpecialStmt):
+    """Special function bind(block_iter, binding_value)
+
+    Example
+    -------
+    .. code-block:: python
+
+        tir.bind(vx, i)
+
+    """
+
+    def __init__(self):
+        def bind(iter_var, values, span=None):
+            block_scope = self.context.current_block_scope()
+            if iter_var in block_scope.iter_bindings:
+                self.context.report_error("Duplicate iter_var bindings of " + 
str(iter_var), span)
+            block_scope.iter_bindings[iter_var] = values
+
+        super().__init__(bind, def_symbol=False)
+
+
+@register
+class BlockReads(SpecialStmt):
+    """Special function reads([read_buffer_regions])
+
+    Example
+    -------
+    .. code-block:: python
+
+        tir.reads([A[vi: vi + 4, vk: vk + 4], B[vk: vk + 4, vj]])
+
+    """
+
+    def __init__(self):
+        def reads(read_regions: Union[BufferSlice, List[BufferSlice]], span: 
Span = None):
+            assert self.context, "call 'exit_scope' before 'enter_scope'"
+            block_scope = self.context.current_block_scope()
+            if block_scope.reads is not None:
+                self.context.report_error(
+                    "Duplicate write region declaration, "
+                    + "previous one is "
+                    + str(", ".join(str(x) for x in block_scope.reads)),
+                    span,
+                )
+            if isinstance(read_regions, list):
+                pass
+            elif isinstance(read_regions, BufferSlice):
+                read_regions = [read_regions]
+            else:
+                self.context.report_error(
+                    "Error input type. "
+                    + f"Expected BufferSlice or List[BufferSlice], but got 
{type(read_regions)}",
+                    span,
+                )
+            block_scope.reads = read_regions
+
+        super().__init__(reads, def_symbol=False)
+
+
+@register
+class BlockWrites(SpecialStmt):
+    """Special function writes([write_buffer_regions])
+
+    Example
+    -------
+    .. code-block:: python
+
+        tir.writes([C[vi: vi + 4, vj])
+
+    """
+
+    def __init__(self):
+        def writes(write_region: Union[BufferSlice, List[BufferSlice]], span: 
Span = None):
+            assert self.context, "call 'exit_scope' before 'enter_scope'"
+            block_scope = self.context.current_block_scope()
+            if block_scope.writes is not None:
+                self.context.report_error(
+                    "Duplicate write region declaration, "
+                    + "previous one is "
+                    + str(", ".join(str(x) for x in block_scope.writes)),
+                    span,
+                )
+            if isinstance(write_region, list):
+                pass
+            elif isinstance(write_region, BufferSlice):
+                write_region = [write_region]
+            else:
+                self.context.report_error(
+                    "Error input type. "
+                    + f"Expected BufferSlice or List[BufferSlice], but got 
{type(write_region)}",
+                    span,
+                )
+            block_scope.writes = write_region
+
+        super().__init__(writes, def_symbol=False)
+
+
+@register
+class BlockAttr(SpecialStmt):
+    """Special function block_attr({attr_key: attr_value})
+
+    Example
+    -------
+    .. code-block:: python
+
+        tir.block_attr({"double_buffer_scope": 1})
+
+    """
+
+    def __init__(self):
+        def block_attr(attrs: Mapping[str, Object], span: Span = None):
+            assert self.context, "call 'exit_scope' before 'enter_scope'"
+            block_scope = self.context.current_block_scope()
+            if block_scope.annotations is not None:
+                self.context.report_error(
+                    "Duplicate block annotations declaration, "
+                    + "previous one is "
+                    + str(block_scope.annotations),
+                    span,
+                )
+            attrs = {
+                key: tvm.tir.StringImm(val) if isinstance(val, str) else val
+                for key, val in attrs.items()
+            }
+            block_scope.annotations = attrs
+
+        super().__init__(block_attr, def_symbol=False)
+
+
+@register
+class BlockPredicate(SpecialStmt):
+    """Special function where(predicate)
+
+    Example
+    -------
+    .. code-block:: python
+
+        tir.where(i < 4)
+
+    """
+
+    def __init__(self):
+        def where(predicate, span=None):
+            block_scope = self.context.current_block_scope()

Review comment:
       Need to check `self.context`?

##########
File path: python/tvm/script/special_stmt.py
##########
@@ -142,21 +181,300 @@ def buffer_decl(
                 buffer_type,
                 span=span,
             )
-            self.context.update_symbol(self.node.lhs.id.name, buffer)
+            self.context.update_symbol(self.node.lhs.id.name, buffer, 
self.node)
             return buffer
 
         super().__init__(buffer_decl, def_symbol=True)
 
 
+@register
+class AllocBuffer(SpecialStmt):
+    """Special function alloc_buffer(shape, dtype, data, strides, elem_offset, 
scope, align,
+                                     offset_factor, buffer_type)
+
+    Example
+    -------
+    .. code-block:: python
+
+        A = tir.alloc_buffer((128, 128), dtype="float32")
+

Review comment:
       remove this line (ditto to the rest)

##########
File path: python/tvm/script/parser.py
##########
@@ -191,9 +194,11 @@ def report_error(self, message, span):
         ----------
         message : str
             Error message
-        span : synr.ast.Span
+        span : synr.ast.Span or tvm.ir.Span

Review comment:
       It would be better to make the type annotation consistent, i.e., use 
`Union[...]`.

##########
File path: python/tvm/script/parser.py
##########
@@ -401,25 +413,52 @@ def my_function(x: ty.handle):  # 1. Argument types
         """
 
         self.init_function_parsing_env()
-        self.context.new_scope(nodes=node.body.stmts)
+        self.context.enter_scope(nodes=node.body.stmts)
 
         # add parameters of function
         for arg in node.params:
             arg_var = tvm.te.var(arg.name, self.parse_type(arg.ty, arg))
-            self.context.update_symbol(arg.name, arg_var)
+            self.context.update_symbol(arg.name, arg_var, node)
             self.context.func_params.append(arg_var)
 
-        # fetch the body and return a tir.PrimFunc
+        # New Scope : Implicit root block
+        # Each function contains an implicit root block in TensorIR,
+        # so here we need a block scope for it. Please note that 
`enter_block_scope`
+        # will not create a block directly but just store some information.
+        # If the PrimFunc is not a TensorIR func (e.g. TE scheduled func or 
low-level func),
+        # the roo block will not be added. The logic to add root block is in 
`_ffi_api.Complete`

Review comment:
       ```suggestion
           # the root block will not be added. The logic to add root block is 
in `_ffi_api.Complete`
   ```

##########
File path: python/tvm/script/special_stmt.py
##########
@@ -142,21 +181,300 @@ def buffer_decl(
                 buffer_type,
                 span=span,
             )
-            self.context.update_symbol(self.node.lhs.id.name, buffer)
+            self.context.update_symbol(self.node.lhs.id.name, buffer, 
self.node)
             return buffer
 
         super().__init__(buffer_decl, def_symbol=True)
 
 
+@register
+class AllocBuffer(SpecialStmt):
+    """Special function alloc_buffer(shape, dtype, data, strides, elem_offset, 
scope, align,
+                                     offset_factor, buffer_type)
+

Review comment:
       parameters? (ditto to the rest)

##########
File path: python/tvm/script/node.py
##########
@@ -0,0 +1,150 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=redefined-builtin
+"""TVM Script nodes."""
+
+from typing import Optional, Union, List, Callable
+import synr
+
+from tvm.runtime import ObjectGeneric
+from tvm.tir import PrimExpr, Buffer, BufferLoad
+from tvm.ir import Span
+
+
+class Slice:
+    """A helper class to present slice information for BufferSlice
+
+    Parameters
+    ----------
+    start : Union[PrimExpr, int]
+        The start index.
+
+    stop : Optional[Union[PrimExpr, int]]
+        The stop index, None means the Slice is an element-wise index
+
+    span : Optional[Span]
+        The location of the slice in the source.
+    """
+
+    start: Union[PrimExpr, int]
+    stop: Optional[Union[PrimExpr, int]]
+    span: Optional[Span]
+
+    def __init__(
+        self,
+        start: Union[PrimExpr, int],
+        stop: Optional[Union[PrimExpr, int]] = None,
+        span: Optional[Span] = None,
+    ):
+        self.start = start
+        self.stop = stop
+        self.span = span
+
+
+class BufferSlice(ObjectGeneric):
+    """A generic object for representing general buffer access. Following 
cases are supported:
+        - element wise access buffer[i, j], which can be convert to BufferLoad 
if necessary

Review comment:
       ```suggestion
           - element wise access buffer[i, j], which can be converted to 
BufferLoad if necessary
   ```

##########
File path: python/tvm/script/special_stmt.py
##########
@@ -142,21 +181,300 @@ def buffer_decl(
                 buffer_type,
                 span=span,
             )
-            self.context.update_symbol(self.node.lhs.id.name, buffer)
+            self.context.update_symbol(self.node.lhs.id.name, buffer, 
self.node)
             return buffer
 
         super().__init__(buffer_decl, def_symbol=True)
 
 
+@register
+class AllocBuffer(SpecialStmt):
+    """Special function alloc_buffer(shape, dtype, data, strides, elem_offset, 
scope, align,
+                                     offset_factor, buffer_type)
+
+    Example
+    -------
+    .. code-block:: python
+
+        A = tir.alloc_buffer((128, 128), dtype="float32")
+
+    """
+
+    def __init__(self):
+        def alloc_buffer(
+            shape,
+            dtype="float32",
+            data=None,
+            strides=None,
+            elem_offset=None,
+            scope="",
+            align=-1,
+            offset_factor=0,
+            buffer_type="default",
+            span=None,
+        ):
+            if not isinstance(self.node, ast.Assign):
+                self.context.report_error(
+                    "alloc_buffer must be assigned to a buffer, e.g. A = 
alloc_buffer(...)",
+                    self.node.span,
+                )
+
+            if strides is None:
+                strides = []
+            align = convert_to_int(align, "align", self.context.report_error, 
self.node.span)
+            offset_factor = convert_to_int(
+                offset_factor, "offset_factor", self.context.report_error, 
self.node.span
+            )
+            buffer = tvm.tir.decl_buffer(
+                shape,
+                dtype,
+                self.node.lhs.id.name,
+                data,
+                strides,
+                elem_offset,
+                scope,
+                align,
+                offset_factor,
+                buffer_type,
+                span=span,
+            )
+            self.context.current_block_scope().alloc_buffers.append(buffer)
+            self.context.update_symbol(self.node.lhs.id.name, buffer, 
self.node)
+
+        super().__init__(alloc_buffer, def_symbol=True)
+
+
+@register
+class BlockVarBind(SpecialStmt):
+    """Special function bind(block_iter, binding_value)
+
+    Example
+    -------
+    .. code-block:: python
+
+        tir.bind(vx, i)
+
+    """
+
+    def __init__(self):
+        def bind(iter_var, values, span=None):
+            block_scope = self.context.current_block_scope()
+            if iter_var in block_scope.iter_bindings:
+                self.context.report_error("Duplicate iter_var bindings of " + 
str(iter_var), span)
+            block_scope.iter_bindings[iter_var] = values
+
+        super().__init__(bind, def_symbol=False)
+
+
+@register
+class BlockReads(SpecialStmt):
+    """Special function reads([read_buffer_regions])
+
+    Example
+    -------
+    .. code-block:: python
+
+        tir.reads([A[vi: vi + 4, vk: vk + 4], B[vk: vk + 4, vj]])
+
+    """
+
+    def __init__(self):
+        def reads(read_regions: Union[BufferSlice, List[BufferSlice]], span: 
Span = None):
+            assert self.context, "call 'exit_scope' before 'enter_scope'"
+            block_scope = self.context.current_block_scope()
+            if block_scope.reads is not None:
+                self.context.report_error(
+                    "Duplicate read region declaration, "
+                    + "previous one is "
+                    + str(", ".join(str(x) for x in block_scope.reads)),
+                    span,
+                )
+            if isinstance(read_regions, list):
+                pass
+            elif isinstance(read_regions, BufferSlice):
+                read_regions = [read_regions]
+            else:
+                self.context.report_error(
+                    "Error input type. "
+                    + f"Expected BufferSlice or List[BufferSlice], but got 
{type(read_regions)}",
+                    span,
+                )
+            block_scope.reads = read_regions
+
+        super().__init__(reads, def_symbol=False)
+
+
+@register
+class BlockWrites(SpecialStmt):
+    """Special function writes([write_buffer_regions])
+
+    Example
+    -------
+    .. code-block:: python
+
+        tir.writes([C[vi: vi + 4, vj]])
+
+    """
+
+    def __init__(self):
+        def writes(write_region: Union[BufferSlice, List[BufferSlice]], span: 
Span = None):
+            assert self.context, "call 'exit_scope' before 'enter_scope'"
+            block_scope = self.context.current_block_scope()
+            if block_scope.writes is not None:
+                self.context.report_error(
+                    "Duplicate write region declaration, "
+                    + "previous one is "
+                    + str(", ".join(str(x) for x in block_scope.writes)),
+                    span,
+                )
+            if isinstance(write_region, list):
+                pass
+            elif isinstance(write_region, BufferSlice):
+                write_region = [write_region]
+            else:
+                self.context.report_error(
+                    "Error input type. "
+                    + f"Expected BufferSlice or List[BufferSlice], but got 
{type(write_region)}",
+                    span,
+                )
+            block_scope.writes = write_region
+
+        super().__init__(writes, def_symbol=False)
+
+
+@register
+class BlockAttr(SpecialStmt):
+    """Special function block_attr({attr_key: attr_value})
+
+    Example
+    -------
+    .. code-block:: python
+
+        tir.block_attr({"double_buffer_scope": 1})
+
+    """
+
+    def __init__(self):
+        def block_attr(attrs: Mapping[str, Object], span: Span = None):
+            assert self.context, "call 'exit_scope' before 'enter_scope'"
+            block_scope = self.context.current_block_scope()
+            if block_scope.annotations is not None:
+                self.context.report_error(
+                    "Duplicate block annotations declaration, "
+                    + "previous one is "
+                    + str(block_scope.annotations),
+                    span,
+                )
+            attrs = {
+                key: tvm.tir.StringImm(val) if isinstance(val, str) else val
+                for key, val in attrs.items()
+            }
+            block_scope.annotations = attrs
+
+        super().__init__(block_attr, def_symbol=False)
+
+
+@register
+class BlockPredicate(SpecialStmt):
+    """Special function where(predicate)
+
+    Example
+    -------
+    .. code-block:: python
+
+        tir.where(i < 4)
+
+    """
+
+    def __init__(self):
+        def where(predicate, span=None):
+            block_scope = self.context.current_block_scope()
+            if block_scope.predicate is not None:
+                self.context.report_error(
+                    "Duplicate block predicate declaration, "
+                    + "previous one is "
+                    + str(block_scope.predicate),
+                    span,
+                )
+
+            block_scope.predicate = predicate
+
+        super().__init__(where, def_symbol=False)
+
+
+@register
+class BlockMatchBufferRegion(SpecialStmt):
+    """Special function match_buffer_region(source, strides, elem_offset, 
align, offset_factor)
+
+    Example
+    -------
+    .. code-block:: python
+
+        B = tir.match_buffer_region(A[0: 4])
+
+    """
+
+    def __init__(self):
+        def match_buffer_region(
+            source,
+            strides=None,
+            elem_offset=None,
+            align=-1,
+            offset_factor=0,
+            span=None,
+        ):
+            if not isinstance(self.node, ast.Assign):
+                self.context.report_error(

Review comment:
       Need to check `self.context`?

##########
File path: python/tvm/tir/analysis/analysis.py
##########
@@ -106,3 +106,26 @@ def verify_gpu_code(func, constraints):
         The result of verification.
     """
     return _ffi_api.verify_gpu_code(func, constraints)
+
+
+def get_block_access_region(block, buffer_var_map):
+    """Auto detect the block read/write region according to body stmt
+        It will detect the read/write region as an array in order of 
appearance in AST
+
+    Parameters
+    ----------
+    block: tvm.tir.Block
+        The block to be detected.
+
+    buffer_var_map : Dict[Var, Buffer]
+        The outside buffers which may be accessed the block. Mapping from 
buffer var to the buffer

Review comment:
       ```suggestion
           The outside buffers which may access the block. Mapping from buffer 
var to the buffer
   ```

##########
File path: python/tvm/script/special_stmt.py
##########
@@ -142,21 +181,300 @@ def buffer_decl(
                 buffer_type,
                 span=span,
             )
-            self.context.update_symbol(self.node.lhs.id.name, buffer)
+            self.context.update_symbol(self.node.lhs.id.name, buffer, 
self.node)
             return buffer
 
         super().__init__(buffer_decl, def_symbol=True)
 
 
+@register
+class AllocBuffer(SpecialStmt):
+    """Special function alloc_buffer(shape, dtype, data, strides, elem_offset, 
scope, align,
+                                     offset_factor, buffer_type)
+
+    Example
+    -------
+    .. code-block:: python
+
+        A = tir.alloc_buffer((128, 128), dtype="float32")
+
+    """
+
+    def __init__(self):
+        def alloc_buffer(
+            shape,
+            dtype="float32",
+            data=None,
+            strides=None,
+            elem_offset=None,
+            scope="",
+            align=-1,
+            offset_factor=0,
+            buffer_type="default",
+            span=None,
+        ):
+            if not isinstance(self.node, ast.Assign):
+                self.context.report_error(
+                    "alloc_buffer must be assigned to a buffer, e.g. A = 
alloc_buffer(...)",
+                    self.node.span,
+                )
+
+            if strides is None:
+                strides = []
+            align = convert_to_int(align, "align", self.context.report_error, 
self.node.span)
+            offset_factor = convert_to_int(
+                offset_factor, "offset_factor", self.context.report_error, 
self.node.span
+            )
+            buffer = tvm.tir.decl_buffer(
+                shape,
+                dtype,
+                self.node.lhs.id.name,
+                data,
+                strides,
+                elem_offset,
+                scope,
+                align,
+                offset_factor,
+                buffer_type,
+                span=span,
+            )
+            self.context.current_block_scope().alloc_buffers.append(buffer)
+            self.context.update_symbol(self.node.lhs.id.name, buffer, 
self.node)
+
+        super().__init__(alloc_buffer, def_symbol=True)
+
+
+@register
+class BlockVarBind(SpecialStmt):
+    """Special function bind(block_iter, binding_value)
+
+    Example
+    -------
+    .. code-block:: python
+
+        tir.bind(vx, i)
+
+    """
+
+    def __init__(self):
+        def bind(iter_var, values, span=None):
+            block_scope = self.context.current_block_scope()
+            if iter_var in block_scope.iter_bindings:
+                self.context.report_error("Duplicate iter_var bindings of " + 
str(iter_var), span)
+            block_scope.iter_bindings[iter_var] = values
+
+        super().__init__(bind, def_symbol=False)
+
+
+@register
+class BlockReads(SpecialStmt):
+    """Special function reads([read_buffer_regions])
+
+    Example
+    -------
+    .. code-block:: python
+
+        tir.reads([A[vi: vi + 4, vk: vk + 4], B[vk: vk + 4, vj]])
+
+    """
+
+    def __init__(self):
+        def reads(read_regions: Union[BufferSlice, List[BufferSlice]], span: 
Span = None):
+            assert self.context, "call 'exit_scope' before 'enter_scope'"
+            block_scope = self.context.current_block_scope()
+            if block_scope.reads is not None:
+                self.context.report_error(
+                    "Duplicate read region declaration, "
+                    + "previous one is "
+                    + str(", ".join(str(x) for x in block_scope.reads)),
+                    span,
+                )
+            if isinstance(read_regions, list):
+                pass
+            elif isinstance(read_regions, BufferSlice):
+                read_regions = [read_regions]
+            else:

Review comment:
       ```suggestion
               if isinstance(read_regions, BufferSlice):
                   read_regions = [read_regions]
               if not isinstance(read_regions, list):
   ```




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to