tkonolige commented on a change in pull request #7630:
URL: https://github.com/apache/tvm/pull/7630#discussion_r594480289



##########
File path: python/tvm/script/parser.py
##########
@@ -401,25 +413,52 @@ def my_function(x: ty.handle):  # 1. Argument types
         """
 
         self.init_function_parsing_env()
-        self.context.new_scope(nodes=node.body.stmts)
+        self.context.enter_scope(nodes=node.body.stmts)
 
         # add parameters of function
         for arg in node.params:
             arg_var = tvm.te.var(arg.name, self.parse_type(arg.ty, arg))
-            self.context.update_symbol(arg.name, arg_var)
+            self.context.update_symbol(arg.name, arg_var, node)
             self.context.func_params.append(arg_var)
 
-        # fetch the body and return a tir.PrimFunc
+        # New Scope : Implicit root block
+        # Each function contains an implicit root block in TensorIR,
+        # so here we need a block scope for it. Please note that 
`enter_block_scope`
+        # will not create a block directly but just store some information.

Review comment:
       ```suggestion
           # will not create a block directly but just stores some information.
   ```

##########
File path: python/tvm/script/context_maintainer.py
##########
@@ -16,59 +16,145 @@
 # under the License.
 """TVM Script Context Maintainer for TIR"""
 
-from tvm.te import schedule
+from typing import List, Mapping, Union, Optional, Dict, Callable
+import synr
+
+
+import tvm
+from tvm.ir import Span
+from tvm.tir import Var, Buffer, PrimExpr, Stmt, MatchBufferRegion
+from tvm.runtime import Object
+from .node import BufferSlice
+
+
+class BlockInfo:
+    """Information for block and block_realize signature"""
+
+    alloc_buffers: List[Buffer]
+    match_buffers: List[MatchBufferRegion]
+    iter_bindings: Mapping[Var, PrimExpr]
+    reads: Optional[List[BufferSlice]]
+    writes: Optional[List[BufferSlice]]
+    annotations: Optional[Mapping[str, Object]]
+    predicate: Optional[PrimExpr]
+    init: Optional[Stmt]
+
+    def __init__(self):
+        self.alloc_buffers = []
+        self.match_buffers = []
+        self.iter_bindings = {}
+        self.reads = None
+        self.writes = None
+        self.annotations = None
+        self.predicate = None
+        self.init = None
 
 
 class ContextMaintainer:
     """Maintain all the necessary context info"""
 
-    def __init__(self, parser):
+    # scope context
+    # ast_node inside a scope
+    node_stack: List[List[synr.ast.Node]]
+    # stack of BlockInfo for nested block scopes, innermost last
+    block_info_stack: List[BlockInfo]
+    # loop stacks inside a block
+    loop_stack: List[List[Var]]
+    symbols: List[Dict[str, Union[Var, Buffer]]]
+
+    # function context
+    func_params: List[Var]
+    func_buffer_map: Mapping[Var, Buffer]
+    func_dict_attr: Mapping[str, Object]
+    func_var_env_dict: Mapping[Var, str]
+
+    # parser and analyzer
+    _report_error: Callable[[str, Union[Span, synr.ast.Span]], None]
+    analyzer: tvm.arith.Analyzer
+
+    def __init__(self, _report_error: Callable[[str, Union[Span, 
synr.ast.Span]], None]):
         # scope context
         self.node_stack = []  # AST nodes of scopes
+        self.block_info_stack = []  # Block info of scopes
+        self.loop_stack = []  # stack of loop vars
         self.symbols = []  # symbols of scopes
         # function context
         self.func_params = []  # parameter list of function
         self.func_buffer_map = {}  # buffer_map of function
         self.func_dict_attr = {}  # func_attr of function
         self.func_var_env_dict = {}  # map from var to env_name
-        # parser
-        self.parser = parser
+        # parser and analyzer
+        self._report_error = _report_error
+        self.analyzer = tvm.arith.Analyzer()
 
-    def pop_scope(self):
-        """Pop the inner most scope"""
-        self.symbols.pop()
-        self.node_stack.pop()
+    def enter_scope(self, nodes: Optional[List[synr.ast.Node]] = None):
+        """Creating a new scope
 
-    def new_scope(self, nodes=None):
-        """Creating a new scope"""
+        Parameters
+        ----------
+        nodes : Optional[List[synr.ast.Node]]
+            The synr AST nodes in new scope
+        """
         if nodes is None:
             nodes = []
         self.node_stack.append(list(reversed(nodes)))
         self.symbols.append(dict())
 
-    def update_symbol(self, name, symbol):
+    def enter_block_scope(self, nodes: Optional[List[synr.ast.Node]] = None):
+        """Creating a new block scope, the function will call `enter_scope` 
implicitly

Review comment:
       ```suggestion
           """Creates a new block scope, the function will call `enter_scope` 
implicitly
   ```

##########
File path: python/tvm/script/scope_handler.py
##########
@@ -55,24 +82,27 @@ def __init__(self, func, concise_scope, def_symbol):
     @staticmethod
     def get_optional_var_names(node, context):
         """Get list of names from ast.With's optional_vars"""
-        assert isinstance(node, ast.With)
-
-        var_names = None
-        if isinstance(node.items[0].optional_vars, ast.Name):
-            var_names = [node.items[0].optional_vars.id]
-        elif isinstance(node.items[0].optional_vars, (ast.List, ast.Tuple)):
-            for var in node.items[0].optional_vars.elts:
-                if not isinstance(var, ast.Name):
-                    context.report_error("Invalid optional var definition")
-            var_names = [var.id for var in node.items[0].optional_vars.elts]
+        assert isinstance(
+            node, ast.With
+        ), f"WithScopeHandler expected to work on ast.With but got 
{type(node)}"
+
+        if isinstance(node.lhs, list):
+            for var in node.lhs:
+                if not isinstance(var, ast.Var):
+                    context.report_error(
+                        "Invalid optional var definition, only list of Var is 
valid", node.span

Review comment:
       Report what was actually provided.

##########
File path: python/tvm/script/parser.py
##########
@@ -716,47 +772,50 @@ def transform_Slice(self, node):
         end = self.transform(node.end)
         if not (isinstance(node.step, ast.Constant) and node.step.value == 1):
             self.report_error("Only step size 1 is supported for slices.", 
node.step.span)
-        extent = end - start
-        if isinstance(extent, tvm.tir.PrimExpr):
-            ana = tvm.arith.Analyzer()
-            extent = ana.simplify(extent)
-        return tvm.ir.Range.from_min_extent(start, extent, 
span=from_synr_span(node.span))
+        return Slice(start, end)
 
     def transform_Subscript(self, node):
         """Array access visitor.
 
         By now only 2 types of Subscript are supported:
             1. Buffer[index, index, ...], Buffer element access(BufferLoad & 
BufferStore)
                Var[index] Buffer element access()
-            2. meta[type_key][index], Meta info access
+            2. Buffer[start: stop, start: stop, ...], 
BufferRealize(realize(buffer[...]))
         """
 
         symbol = self.transform(node.params[0])
         if symbol is None:
             self.report_error(f"Variable {node.value.id} is not defined.", 
node.params[0].span)
 
         indexes = [self.transform(x) for x in node.params[1].values]
-        if isinstance(indexes[0], tvm.ir.Range):
-            return symbol, indexes
-
         if isinstance(symbol, tvm.tir.expr.Var):
-            return tvm.tir.Load("float32", symbol, indexes, True, 
span=from_synr_span(node.span))
-        if isinstance(symbol, tvm.tir.Buffer):
-            return tvm.tir.BufferLoad(symbol, indexes, 
span=from_synr_span(node.span))
-
-        self.report_error(
-            f"Cannot subscript from a {type(symbol).__name__}. Only variables 
and "
-            "buffers are supported.",
-            node.params[0].span,
-        )
+            for index in indexes:
+                if not isinstance(index, (tvm.tir.PrimExpr, int)):
+                    self.report_error(
+                        "Buffer load indexes expected int or PrimExpr, but got 
" + type(index),

Review comment:
       ```suggestion
                        "Buffer load indexes should be int or PrimExpr, but they are " + str(type(index)),
   ```

##########
File path: python/tvm/script/scope_handler.py
##########
@@ -185,92 +232,384 @@ def let(var, value, span):
         super().__init__(let, concise_scope=False, def_symbol=False)
 
 
+@register
+class Block(WithScopeHandler):
+    """ With scope handler tir.block(extents, name) as iter_vars"""
+
+    def __init__(self):
+        def block(axes=None, name_hint: str = "", span: Optional[Span] = None):
+            assert (
+                self.node and self.context and self.body
+            ), "call 'exit_scope' before 'enter_scope'"
+            block_info = self.context.block_info_stack[-1]
+            if axes is None:
+                axes = []
+            if len(axes) != len(self.block_vars):
+                self.context.report_error(
+                    "Inconsistent number of block vars, "
+                    + f"there are {len(axes)} axes but {len(self.block_vars)} 
block vars. "
+                    + "The number of block vars should match the number of 
axes.",
+                    self.node.span,
+                )
+            block_iters: List[IterVar] = []
+            for i, axis in enumerate(axes):
+                axis = tvm.runtime.convert(axis)
+                if isinstance(axis, tvm.tir.PrimExpr):
+                    block_var_dom = Range.from_min_extent(0, axis)
+                    block_iters.append(IterVar(block_var_dom, 
self.block_vars[i], 0))
+                elif isinstance(axis, Range):
+                    block_iters.append(IterVar(axis, self.block_vars[i], 0))
+                elif isinstance(axis, IterVar):
+                    block_iters.append(IterVar(axis.dom, self.block_vars[i], 
axis.iter_type))
+                else:
+                    self.context.report_error(
+                        "Invalid argument of tir.block(), "
+                        + f"expected PrimExpr, Range or IterVar, but got 
{type(axis)}",
+                        self.node.span,
+                    )
+
+            # create block read/write regions
+
+            reads: List[BufferRegion] = (
+                [buffer_slice_to_region(read) for read in block_info.reads]
+                if block_info.reads
+                else []
+            )
+            writes: List[BufferRegion] = (
+                [buffer_slice_to_region(write) for write in block_info.writes]
+                if block_info.writes
+                else []
+            )
+            inner = tvm.tir.Block(
+                block_iters,
+                reads,
+                writes,
+                name_hint,
+                self.body,
+                block_info.init,
+                block_info.alloc_buffers,
+                block_info.match_buffers,
+                block_info.annotations,
+                span,
+            )
+            # create block var iter binding
+            values: List[PrimExpr]
+            if not block_info.iter_bindings:
+                values = self.context.loop_stack[-2].copy()
+                if len(values) == 0:
+                    values = [tvm.tir.const(float("nan"), dtype="float32")] * 
len(block_iters)
+                elif len(values) != len(block_iters):
+                    self.context.report_error(
+                        "Number of block iter var and outer loop nesting 
mismatch, "
+                        + f"{len(block_iters)} block iter vars but 
{len(values)} loops",
+                        self.node.span,
+                    )
+            else:
+                for block_var in self.block_vars:
+                    if block_var not in block_info.iter_bindings:
+                        self.context.report_error(
+                            "Missing block iter var binding for " + 
block_var.name,
+                            self.node.span,
+                        )
+                values = [block_info.iter_bindings[block_var] for block_var in 
self.block_vars]
+            predicate = (
+                tvm.tir.const(True, "bool")
+                if block_info.predicate is None
+                else block_info.predicate
+            )
+            body = tvm.tir.BlockRealize(values, predicate, inner, span)
+            return body
+
+        super().__init__(func=block, concise_scope=False, def_symbol=True)
+        self.block_vars = None
+
+    def enter_scope(
+        self,
+        node: synr.ast.Node,
+        context: ContextMaintainer,
+        arg_list: List[Any],
+        span: synr.ast.Span,
+    ):
+        # define block vars
+        assert isinstance(
+            node, ast.With
+        ), f"BlockScopeHandler expected to work on ast.With but got 
{type(node)}"
+
+        var_names = WithScopeHandler.get_optional_var_names(node, context)
+        self.block_vars = [tvm.te.var(name) for name in var_names]
+        for block_var in self.block_vars:
+            context.update_symbol(block_var.name, block_var, node)
+
+
+@register
+class InitBlock(WithScopeHandler):
+    """ With scope handler tir.init()"""
+
+    def __init__(self):
+        def init(span: Span = None):
+            assert self.context, "call 'exit_scope' before 'enter_scope'"
+            if self.context.block_info_stack[-2].init is not None:
+                self.context.report_error("Duplicate init block declaration", 
span)
+            self.context.block_info_stack[-2].init = self.body
+
+        super().__init__(func=init, concise_scope=False, def_symbol=True)
+
+
 class ForScopeHandler(ScopeHandler):
     """Base class for all for scope handlers"""
 
     def __init__(self, func):
         super().__init__(func)
-        self.loop_vars = None
-
-    def enter_scope(self, node, context, arg_list, span):
-        assert isinstance(node, ast.For)
+        self.loop_vars: Optional[List[Var]] = None
+
+    def enter_scope(
+        self,
+        node: synr.ast.Node,
+        context: ContextMaintainer,
+        arg_list: List[Any],
+        span: synr.ast.Span,
+    ):
+        assert isinstance(
+            node, ast.For
+        ), f"ForScopeHandler expected to work on ast.For but got {type(node)}"
 
         loop_var_names = list()
         spans = list()
         if isinstance(node.lhs, ast.Var):
             loop_var_names.append(node.lhs.id.name)
-            spans.append(from_synr_span(node.lhs.id.span))
-        elif isinstance(node.lhs, ast.Tuple):
-            for elt in node.lhs.values:
+            spans.append(tvm_span_from_synr(node.lhs.id.span))
+        elif isinstance(node.lhs, list):
+            for elt in node.lhs:
                 if not isinstance(elt, ast.Var):
-                    context.report_error("Invalid loop var", elt.span)
+                    context.report_error(
+                        f"Invalid loop var. Expected a var, but got 
{type(elt)}", elt.span
+                    )
                 loop_var_names.append(elt.id.name)
-                spans.append(from_synr_span(elt.id.span))
+                spans.append(tvm_span_from_synr(elt.id.span))
         else:
-            context.report_error("Invalid loop var", node.lhs.span)
+            context.report_error(
+                f"Invalid loop var. Expected var or list of vars as lhs, but 
got {type(node.lhs)}",
+                span,
+            )
 
         self.loop_vars = [
             tvm.te.var(name, dtype="int32", span=span) for name, span in 
zip(loop_var_names, spans)
         ]
         for loop_var in self.loop_vars:
-            context.update_symbol(loop_var.name, loop_var)
+            context.update_symbol(loop_var.name, loop_var, node)
+            context.loop_stack[-1].append(loop_var)
+
+    def exit_scope(
+        self,
+        node: synr.ast.Node,
+        context: ContextMaintainer,
+        arg_list: List[Any],
+        span: synr.ast.Span,
+    ):
+        assert self.loop_vars, "call 'exit_scope' before 'enter_scope'"
+        for _ in self.loop_vars:
+            context.loop_stack[-1].pop()
+        return super().exit_scope(node, context, arg_list, span)
+
+    def create_loop(
+        self,
+        begin: PrimExpr,
+        end: PrimExpr,
+        kind: ForKind,
+        thread_binding: Optional[str] = None,
+        annotations: Optional[Mapping[str, Object]] = None,
+        span: Optional[Span] = None,
+    ) -> tvm.tir.For:
+        """
+        Helper function for creating For in TVM Script parser.
+
+        Parameters
+        ----------
+        begin : PrimExpr
+            The beginning value.
+
+        end : PrimExpr
+            The ending value.
+
+        kind : ForKind
+            The type of the for.
+
+        thread_binding: Optional[str]
+            The thread this loop binds to.
+
+        annotations : Optional[Mapping[str, Object]]
+            Additional annotation hints.
+
+        span : Optional[Span]
+            The location of this for in the source code.
+
+        Returns
+        -------
+        for : For
+            The constructed For.
+        """
+        assert (
+            self.loop_vars and self.context and self.node
+        ), "call 'exit_scope' before 'enter_scope'"
+        if len(self.loop_vars) != 1:
+            self.context.report_error(
+                f"Expect exact only one loop var, but get {self.loop_vars}", 
self.node.span

Review comment:
       ```suggestion
                   f"Expected exactly one loop var, but got {self.loop_vars}", 
self.node.span
   ```

##########
File path: python/tvm/script/special_stmt.py
##########
@@ -142,21 +181,300 @@ def buffer_decl(
                 buffer_type,
                 span=span,
             )
-            self.context.update_symbol(self.node.lhs.id.name, buffer)
+            self.context.update_symbol(self.node.lhs.id.name, buffer, 
self.node)
             return buffer
 
         super().__init__(buffer_decl, def_symbol=True)
 
 
+@register
+class AllocBuffer(SpecialStmt):
+    """Special function alloc_buffer(shape, dtype, data, strides, elem_offset, 
scope, align,
+                                     offset_factor, buffer_type)
+
+    Example
+    -------
+    .. code-block:: python
+
+        A = tir.alloc_buffer((128, 128), dtype="float32")
+
+    """
+
+    def __init__(self):
+        def alloc_buffer(
+            shape,
+            dtype="float32",
+            data=None,
+            strides=None,
+            elem_offset=None,
+            scope="",
+            align=-1,
+            offset_factor=0,
+            buffer_type="default",
+            span=None,
+        ):
+            if not isinstance(self.node, ast.Assign):
+                self.context.report_error(
+                    "alloc_buffer must be assigned to a buffer, e.g. A = 
alloc_buffer(...)",
+                    self.node.span,
+                )
+
+            if strides is None:
+                strides = []
+            align = convert_to_int(align, "align", self.context.report_error, 
self.node.span)
+            offset_factor = convert_to_int(
+                offset_factor, "offset_factor", self.context.report_error, 
self.node.span
+            )
+            buffer = tvm.tir.decl_buffer(
+                shape,
+                dtype,
+                self.node.lhs.id.name,
+                data,
+                strides,
+                elem_offset,
+                scope,
+                align,
+                offset_factor,
+                buffer_type,
+                span=span,
+            )
+            self.context.current_block_scope().alloc_buffers.append(buffer)
+            self.context.update_symbol(self.node.lhs.id.name, buffer, 
self.node)
+
+        super().__init__(alloc_buffer, def_symbol=True)
+
+
+@register
+class BlockVarBind(SpecialStmt):
+    """Special function bind(block_iter, binding_value)
+
+    Example
+    -------
+    .. code-block:: python
+
+        tir.bind(vx, i)
+
+    """
+
+    def __init__(self):
+        def bind(iter_var, values, span=None):
+            block_scope = self.context.current_block_scope()
+            if iter_var in block_scope.iter_bindings:
+                self.context.report_error("Duplicate iter_var bindings of " + 
str(iter_var), span)
+            block_scope.iter_bindings[iter_var] = values
+
+        super().__init__(bind, def_symbol=False)
+
+
+@register
+class BlockReads(SpecialStmt):
+    """Special function reads([read_buffer_regions])
+
+    Example
+    -------
+    .. code-block:: python
+
+        tir.reads([A[vi: vi + 4, vk: vk + 4], B[vk: vk + 4, vj]])
+
+    """
+
+    def __init__(self):
+        def reads(read_regions: Union[BufferSlice, List[BufferSlice]], span: 
Span = None):
+            assert self.context, "call 'exit_scope' before 'enter_scope'"
+            block_scope = self.context.current_block_scope()
+            if block_scope.reads is not None:
+                self.context.report_error(
+                    "Duplicate read region declaration, "
+                    + "previous one is "
+                    + str(", ".join(str(x) for x in block_scope.reads)),
+                    span,
+                )
+            if isinstance(read_regions, list):
+                pass
+            elif isinstance(read_regions, BufferSlice):
+                read_regions = [read_regions]
+            else:
+                self.context.report_error(
+                    "Error input type. "

Review comment:
       ```suggestion
                       "Incorrect input type. "
   ```

##########
File path: python/tvm/script/scope_handler.py
##########
@@ -185,92 +232,384 @@ def let(var, value, span):
         super().__init__(let, concise_scope=False, def_symbol=False)
 
 
+@register
+class Block(WithScopeHandler):
+    """ With scope handler tir.block(extents, name) as iter_vars"""
+
+    def __init__(self):
+        def block(axes=None, name_hint: str = "", span: Optional[Span] = None):
+            assert (
+                self.node and self.context and self.body
+            ), "call 'exit_scope' before 'enter_scope'"
+            block_info = self.context.block_info_stack[-1]
+            if axes is None:
+                axes = []
+            if len(axes) != len(self.block_vars):
+                self.context.report_error(
+                    "Inconsistent number of block vars, "
+                    + f"there are {len(axes)} axes but {len(self.block_vars)} 
block vars. "
+                    + "The number of block vars should match the number of 
axes.",
+                    self.node.span,
+                )
+            block_iters: List[IterVar] = []
+            for i, axis in enumerate(axes):
+                axis = tvm.runtime.convert(axis)
+                if isinstance(axis, tvm.tir.PrimExpr):
+                    block_var_dom = Range.from_min_extent(0, axis)
+                    block_iters.append(IterVar(block_var_dom, 
self.block_vars[i], 0))
+                elif isinstance(axis, Range):
+                    block_iters.append(IterVar(axis, self.block_vars[i], 0))
+                elif isinstance(axis, IterVar):
+                    block_iters.append(IterVar(axis.dom, self.block_vars[i], 
axis.iter_type))
+                else:
+                    self.context.report_error(
+                        "Invalid argument of tir.block(), "
+                        + f"expected PrimExpr, Range or IterVar, but got 
{type(axis)}",
+                        self.node.span,
+                    )
+
+            # create block read/write regions
+
+            reads: List[BufferRegion] = (
+                [buffer_slice_to_region(read) for read in block_info.reads]
+                if block_info.reads
+                else []
+            )
+            writes: List[BufferRegion] = (
+                [buffer_slice_to_region(write) for write in block_info.writes]
+                if block_info.writes
+                else []
+            )
+            inner = tvm.tir.Block(
+                block_iters,
+                reads,
+                writes,
+                name_hint,
+                self.body,
+                block_info.init,
+                block_info.alloc_buffers,
+                block_info.match_buffers,
+                block_info.annotations,
+                span,
+            )
+            # create block var iter binding
+            values: List[PrimExpr]
+            if not block_info.iter_bindings:
+                values = self.context.loop_stack[-2].copy()
+                if len(values) == 0:
+                    values = [tvm.tir.const(float("nan"), dtype="float32")] * 
len(block_iters)
+                elif len(values) != len(block_iters):
+                    self.context.report_error(
+                        "Number of block iter var and outer loop nesting 
mismatch, "
+                        + f"{len(block_iters)} block iter vars but 
{len(values)} loops",
+                        self.node.span,
+                    )
+            else:
+                for block_var in self.block_vars:
+                    if block_var not in block_info.iter_bindings:
+                        self.context.report_error(
+                            "Missing block iter var binding for " + 
block_var.name,
+                            self.node.span,
+                        )
+                values = [block_info.iter_bindings[block_var] for block_var in 
self.block_vars]
+            predicate = (
+                tvm.tir.const(True, "bool")
+                if block_info.predicate is None
+                else block_info.predicate
+            )
+            body = tvm.tir.BlockRealize(values, predicate, inner, span)
+            return body
+
+        super().__init__(func=block, concise_scope=False, def_symbol=True)
+        self.block_vars = None
+
+    def enter_scope(
+        self,
+        node: synr.ast.Node,
+        context: ContextMaintainer,
+        arg_list: List[Any],
+        span: synr.ast.Span,
+    ):
+        # define block vars
+        assert isinstance(
+            node, ast.With
+        ), f"BlockScopeHandler expected to work on ast.With but got 
{type(node)}"
+
+        var_names = WithScopeHandler.get_optional_var_names(node, context)
+        self.block_vars = [tvm.te.var(name) for name in var_names]
+        for block_var in self.block_vars:
+            context.update_symbol(block_var.name, block_var, node)
+
+
+@register
+class InitBlock(WithScopeHandler):
+    """ With scope handler tir.init()"""
+
+    def __init__(self):
+        def init(span: Span = None):
+            assert self.context, "call 'exit_scope' before 'enter_scope'"
+            if self.context.block_info_stack[-2].init is not None:
+                self.context.report_error("Duplicate init block declaration", 
span)
+            self.context.block_info_stack[-2].init = self.body
+
+        super().__init__(func=init, concise_scope=False, def_symbol=True)
+
+
 class ForScopeHandler(ScopeHandler):
     """Base class for all for scope handlers"""
 
     def __init__(self, func):
         super().__init__(func)
-        self.loop_vars = None
-
-    def enter_scope(self, node, context, arg_list, span):
-        assert isinstance(node, ast.For)
+        self.loop_vars: Optional[List[Var]] = None
+
+    def enter_scope(
+        self,
+        node: synr.ast.Node,
+        context: ContextMaintainer,
+        arg_list: List[Any],
+        span: synr.ast.Span,
+    ):
+        assert isinstance(
+            node, ast.For
+        ), f"ForScopeHandler expected to work on ast.For but got {type(node)}"
 
         loop_var_names = list()
         spans = list()
         if isinstance(node.lhs, ast.Var):
             loop_var_names.append(node.lhs.id.name)
-            spans.append(from_synr_span(node.lhs.id.span))
-        elif isinstance(node.lhs, ast.Tuple):
-            for elt in node.lhs.values:
+            spans.append(tvm_span_from_synr(node.lhs.id.span))
+        elif isinstance(node.lhs, list):
+            for elt in node.lhs:
                 if not isinstance(elt, ast.Var):
-                    context.report_error("Invalid loop var", elt.span)
+                    context.report_error(
+                        f"Invalid loop var. Expected a var, but got 
{type(elt)}", elt.span
+                    )
                 loop_var_names.append(elt.id.name)
-                spans.append(from_synr_span(elt.id.span))
+                spans.append(tvm_span_from_synr(elt.id.span))
         else:
-            context.report_error("Invalid loop var", node.lhs.span)
+            context.report_error(
+                f"Invalid loop var. Expected var or list of vars as lhs, but 
got {type(node.lhs)}",
+                span,
+            )
 
         self.loop_vars = [
             tvm.te.var(name, dtype="int32", span=span) for name, span in 
zip(loop_var_names, spans)
         ]
         for loop_var in self.loop_vars:
-            context.update_symbol(loop_var.name, loop_var)
+            context.update_symbol(loop_var.name, loop_var, node)
+            context.loop_stack[-1].append(loop_var)
+
+    def exit_scope(
+        self,
+        node: synr.ast.Node,
+        context: ContextMaintainer,
+        arg_list: List[Any],
+        span: synr.ast.Span,
+    ):
+        assert self.loop_vars, "call 'exit_scope' before 'enter_scope'"
+        for _ in self.loop_vars:
+            context.loop_stack[-1].pop()
+        return super().exit_scope(node, context, arg_list, span)
+
+    def create_loop(
+        self,
+        begin: PrimExpr,
+        end: PrimExpr,
+        kind: ForKind,
+        thread_binding: Optional[str] = None,
+        annotations: Optional[Mapping[str, Object]] = None,
+        span: Optional[Span] = None,
+    ) -> tvm.tir.For:
+        """
+        Helper function for creating For in TVM Script parser.
+
+        Parameters
+        ----------
+        begin : PrimExpr
+            The beginning value.
+
+        end : PrimExpr
+            The ending value.
+
+        kind : ForKind
+            The type of the for.
+
+        thread_binding: Optional[str]
+            The thread this loop binds to.
+
+        annotations : Optional[Mapping[str, Object]]
+            Additional annotation hints.
+
+        span : Optional[Span]
+            The location of this for in the source code.
+
+        Returns
+        -------
+        for : For
+            The constructed For.
+        """
+        assert (
+            self.loop_vars and self.context and self.node
+        ), "call 'exit_scope' before 'enter_scope'"
+        if len(self.loop_vars) != 1:
+            self.context.report_error(
+                f"Expect exact only one loop var, but get {self.loop_vars}", 
self.node.span
+            )
+        extent = end if begin == 0 else self.context.analyzer.simplify(end - 
begin)
+        annos: Mapping[str, Object]
+        if annotations is None:
+            annos = {}
+        else:
+            annos = {
+                key: tvm.tir.StringImm(val) if isinstance(val, str) else val
+                for key, val in annotations.items()
+            }
+        return tvm.tir.For(
+            self.loop_vars[0],
+            begin,
+            extent,
+            kind,
+            self.body,
+            thread_binding=thread_binding,
+            annotations=annos,
+            span=span,
+        )
 
 
 @register
 class Serial(ForScopeHandler):
-    """ For scope handler tir.serial(begin, end)"""
+    """ For scope handler tir.serial(begin, end, annotations)"""
 
     def __init__(self):
-        def serial(begin, end, span):
-            if len(self.loop_vars) != 1:
-                self.context.report_error("Expect exact 1 loop var", span)
-            ana = tvm.arith.Analyzer()
-            extent = end if begin == 0 else ana.simplify(end - begin)
-            return tvm.tir.For(self.loop_vars[0], begin, extent, 0, self.body, 
span=span)
+        def serial(
+            begin: PrimExpr,
+            end: PrimExpr,
+            annotations: Optional[Mapping[str, Object]] = None,
+            span: Optional[Span] = None,
+        ):
+            return self.create_loop(begin, end, ForKind.SERIAL, 
annotations=annotations, span=span)
 
         super().__init__(serial)
 
 
 @register
 class Parallel(ForScopeHandler):
-    """ For scope handler tir.parallel(begin, end)"""
+    """ For scope handler tir.parallel(begin, end, annotations)"""
 
     def __init__(self):
-        def parallel(begin, end, span):
-            if len(self.loop_vars) != 1:
-                self.context.report_error("Expect exact 1 loop var")
-            ana = tvm.arith.Analyzer()
-            extent = end if begin == 0 else ana.simplify(end - begin)
-            return tvm.tir.For(self.loop_vars[0], begin, extent, 1, self.body, 
span=span)
+        def parallel(
+            begin: PrimExpr,
+            end: PrimExpr,
+            annotations: Optional[Mapping[str, Object]] = None,
+            span: Optional[Span] = None,
+        ):
+            return self.create_loop(
+                begin, end, ForKind.PARALLEL, annotations=annotations, 
span=span
+            )
 
         super().__init__(parallel)
 
 
 @register
 class Vectorized(ForScopeHandler):
-    """ For scope handler tir.vectorized(begin, end)"""
+    """ For scope handler tir.vectorized(begin, end, annotations)"""
 
     def __init__(self):
-        def vectorized(begin, end, span):
-            if len(self.loop_vars) != 1:
-                self.context.report_error("Expect exact 1 loop var")
-            ana = tvm.arith.Analyzer()
-            extent = end if begin == 0 else ana.simplify(end - begin)
-            return tvm.tir.For(self.loop_vars[0], begin, extent, 2, self.body, 
span=span)
+        def vectorized(
+            begin: PrimExpr,
+            end: PrimExpr,
+            annotations: Optional[Mapping[str, Object]] = None,
+            span: Optional[Span] = None,
+        ):
+            return self.create_loop(
+                begin, end, ForKind.VECTORIZED, annotations=annotations, 
span=span
+            )
 
         super().__init__(vectorized)
 
 
 @register
 class Unroll(ForScopeHandler):
-    """ For scope handler tir.unroll(begin, end)"""
+    """ For scope handler tir.unroll(begin, end, annotations)"""
 
     def __init__(self):
-        def unroll(begin, end, span):
-            if len(self.loop_vars) != 1:
-                self.context.report_error("Expect exact 1 loop var")
-            ana = tvm.arith.Analyzer()
-            extent = end if begin == 0 else ana.simplify(end - begin)
-            return tvm.tir.For(self.loop_vars[0], begin, extent, 3, self.body, 
span=span)
+        def unroll(
+            begin: PrimExpr,
+            end: PrimExpr,
+            annotations: Optional[Mapping[str, Object]] = None,
+            span: Optional[Span] = None,
+        ):
+            return self.create_loop(
+                begin, end, ForKind.UNROLLED, annotations=annotations, 
span=span
+            )
 
         super().__init__(unroll)
+
+
+@register
+class ThreadBinding(ForScopeHandler):
+    """ For scope handler tir.thread_binding(begin, end, thread, 
annotations)"""
+
+    def __init__(self):
+        def thread_binding(
+            begin: PrimExpr,
+            end: PrimExpr,
+            thread: str,
+            annotations: Optional[Mapping[str, Object]] = None,
+            span: Optional[Span] = None,
+        ):
+            thread_iter_var = IterVar(None, None, 1, thread, span=span)
+            return self.create_loop(
+                begin,
+                end,
+                ForKind.THREAD_BINDING,
+                thread_binding=thread_iter_var,
+                annotations=annotations,
+                span=span,
+            )
+
+        super().__init__(thread_binding)
+
+
+@register
+class RangeHandler(ForScopeHandler):
+    """For scope handler range(begin, end, annotations)
+    Note that tir.range is totally the same as tir.serial
+    """
+
+    def __init__(self):
+        def for_range(
+            begin: PrimExpr,
+            end: PrimExpr,
+            annotations: Optional[Mapping[str, Object]] = None,
+            span: Optional[Span] = None,
+        ):
+            return self.create_loop(begin, end, ForKind.SERIAL, 
annotations=annotations, span=span)
+
+        super().__init__(for_range)
+
+    def signature(self):
+        return "range", get_param_list(self.func)
+
+
+@register
+class Grid(ForScopeHandler):
+    """ For scope handler tir.grid(extents)"""
+
+    def __init__(self):
+        def grid(*extents: List[PrimExpr], span: Span):
+            assert (
+                self.node and self.context and self.loop_vars
+            ), "call 'exit_scope' before 'enter_scope'"
+            if len(self.loop_vars) != len(extents):
+                self.context.report_error(
+                    "Inconsistent number of loop vars and extents, "
+                    + f"got {len(self.loop_vars)} vs {len(extents)}",
+                    self.node.span,
+                )
+            body = self.body
+            for loop_var, extent in zip(reversed(self.loop_vars), 
reversed(extents)):
+                body = tvm.tir.For(loop_var, 0, extent, 0, body, span=span)

Review comment:
       Use the enum element instead of `0` here (the second `0`).

##########
File path: python/tvm/script/special_stmt.py
##########
@@ -142,21 +181,300 @@ def buffer_decl(
                 buffer_type,
                 span=span,
             )
-            self.context.update_symbol(self.node.lhs.id.name, buffer)
+            self.context.update_symbol(self.node.lhs.id.name, buffer, 
self.node)
             return buffer
 
         super().__init__(buffer_decl, def_symbol=True)
 
 
+@register
+class AllocBuffer(SpecialStmt):
+    """Special function alloc_buffer(shape, dtype, data, strides, elem_offset, 
scope, align,
+                                     offset_factor, buffer_type)
+
+    Example
+    -------
+    .. code-block:: python
+
+        A = tir.alloc_buffer((128, 128), dtype="float32")
+
+    """
+
+    def __init__(self):
+        def alloc_buffer(
+            shape,
+            dtype="float32",
+            data=None,
+            strides=None,
+            elem_offset=None,
+            scope="",
+            align=-1,
+            offset_factor=0,
+            buffer_type="default",
+            span=None,
+        ):
+            if not isinstance(self.node, ast.Assign):
+                self.context.report_error(
+                    "alloc_buffer must be assigned to a buffer, e.g. A = 
alloc_buffer(...)",
+                    self.node.span,
+                )
+
+            if strides is None:
+                strides = []
+            align = convert_to_int(align, "align", self.context.report_error, 
self.node.span)
+            offset_factor = convert_to_int(
+                offset_factor, "offset_factor", self.context.report_error, 
self.node.span
+            )
+            buffer = tvm.tir.decl_buffer(
+                shape,
+                dtype,
+                self.node.lhs.id.name,
+                data,
+                strides,
+                elem_offset,
+                scope,
+                align,
+                offset_factor,
+                buffer_type,
+                span=span,
+            )
+            self.context.current_block_scope().alloc_buffers.append(buffer)
+            self.context.update_symbol(self.node.lhs.id.name, buffer, 
self.node)
+
+        super().__init__(alloc_buffer, def_symbol=True)
+
+
+@register
+class BlockVarBind(SpecialStmt):
+    """Special function bind(block_iter, binding_value)
+
+    Example
+    -------
+    .. code-block:: python
+
+        tir.bind(vx, i)
+
+    """
+
+    def __init__(self):
+        def bind(iter_var, values, span=None):
+            block_scope = self.context.current_block_scope()
+            if iter_var in block_scope.iter_bindings:
+                self.context.report_error("Duplicate iter_var bindings of " + 
str(iter_var), span)
+            block_scope.iter_bindings[iter_var] = values
+
+        super().__init__(bind, def_symbol=False)
+
+
+@register
+class BlockReads(SpecialStmt):
+    """Special function reads([read_buffer_regions])
+
+    Example
+    -------
+    .. code-block:: python
+
+        tir.reads([A[vi: vi + 4, vk: vk + 4], B[vk: vk + 4, vj]])
+
+    """
+
+    def __init__(self):
+        def reads(read_regions: Union[BufferSlice, List[BufferSlice]], span: 
Span = None):
+            assert self.context, "call 'exit_scope' before 'enter_scope'"
+            block_scope = self.context.current_block_scope()
+            if block_scope.reads is not None:
+                self.context.report_error(
+                    "Duplicate write region declaration, "
+                    + "previous one is "
+                    + str(", ".join(str(x) for x in block_scope.reads)),
+                    span,
+                )
+            if isinstance(read_regions, list):
+                pass
+            elif isinstance(read_regions, BufferSlice):
+                read_regions = [read_regions]
+            else:
+                self.context.report_error(
+                    "Error input type. "
+                    + f"Expected BufferSlice or List[BufferSlice], but got 
{type(read_regions)}",
+                    span,
+                )
+            block_scope.reads = read_regions
+
+        super().__init__(reads, def_symbol=False)
+
+
+@register
+class BlockWrites(SpecialStmt):
+    """Special function writes([write_buffer_regions])
+
+    Example
+    -------
+    .. code-block:: python
+
+        tir.writes([C[vi: vi + 4, vj])
+
+    """
+
+    def __init__(self):
+        def writes(write_region: Union[BufferSlice, List[BufferSlice]], span: 
Span = None):
+            assert self.context, "call 'exit_scope' before 'enter_scope'"
+            block_scope = self.context.current_block_scope()
+            if block_scope.writes is not None:
+                self.context.report_error(
+                    "Duplicate write region declaration, "
+                    + "previous one is "
+                    + str(", ".join(str(x) for x in block_scope.writes)),
+                    span,
+                )
+            if isinstance(write_region, list):
+                pass
+            elif isinstance(write_region, BufferSlice):
+                write_region = [write_region]
+            else:
+                self.context.report_error(
+                    "Error input type. "
+                    + f"Expected BufferSlice or List[BufferSlice], but got 
{type(write_region)}",
+                    span,
+                )
+            block_scope.writes = write_region
+
+        super().__init__(writes, def_symbol=False)
+
+
+@register
+class BlockAttr(SpecialStmt):
+    """Special function block_attr({attr_key: attr_value})
+
+    Example
+    -------
+    .. code-block:: python
+
+        tir.block_attr({"double_buffer_scope": 1})
+
+    """
+
+    def __init__(self):
+        def block_attr(attrs: Mapping[str, Object], span: Span = None):
+            assert self.context, "call 'exit_scope' before 'enter_scope'"
+            block_scope = self.context.current_block_scope()
+            if block_scope.annotations is not None:
+                self.context.report_error(
+                    "Duplicate block annotations declaration, "
+                    + "previous one is "
+                    + str(block_scope.annotations),
+                    span,
+                )
+            attrs = {
+                key: tvm.tir.StringImm(val) if isinstance(val, str) else val
+                for key, val in attrs.items()
+            }
+            block_scope.annotations = attrs
+
+        super().__init__(block_attr, def_symbol=False)
+
+
+@register
+class BlockPredicate(SpecialStmt):
+    """Special function where(predicate)
+
+    Example
+    -------
+    .. code-block:: python
+
+        tir.where(i < 4)
+
+    """
+
+    def __init__(self):
+        def where(predicate, span=None):
+            block_scope = self.context.current_block_scope()
+            if block_scope.predicate is not None:
+                self.context.report_error(
+                    "Duplicate block predicate declaration, "
+                    + "previous one is "
+                    + str(block_scope.predicate),
+                    span,
+                )
+
+            block_scope.predicate = predicate
+
+        super().__init__(where, def_symbol=False)
+
+
+@register
+class BlockMatchBufferRegion(SpecialStmt):
+    """Special function match_buffer_region(source, strides, elem_offset, 
align, offset_factor)
+
+    Example
+    -------
+    .. code-block:: python
+
+        B = tir.match_buffer_region(A[0: 4])
+
+    """
+
+    def __init__(self):
+        def match_buffer_region(
+            source,
+            strides=None,
+            elem_offset=None,
+            align=-1,
+            offset_factor=0,
+            span=None,
+        ):
+            if not isinstance(self.node, ast.Assign):
+                self.context.report_error(
+                    "match_buffer_region must be assigned to a buffer, "
+                    + "e.g. A = match_buffer_region(...)",
+                    self.node.span,
+                )
+
+            if strides is None:
+                strides = []
+            align = convert_to_int(align, "align", self.context.report_error, 
self.node.span)
+            offset_factor = convert_to_int(
+                offset_factor, "offset_factor", self.context.report_error, 
self.node.span
+            )
+
+            if not isinstance(source, BufferSlice):
+                self.context.report_error(
+                    "match_buffer_region needs a buffer region as source",
+                    span=span,
+                )
+            buffer_region = buffer_slice_to_region(source)
+            shape = [r.extent for r in buffer_region.region]
+            buffer = tvm.tir.decl_buffer(
+                shape,
+                buffer_region.buffer.dtype,
+                self.node.lhs.id.name,
+                data=None,
+                strides=strides,
+                elem_offset=elem_offset,
+                scope=buffer_region.buffer.scope,
+                data_alignment=align,
+                offset_factor=offset_factor,
+                span=span,
+            )
+            self.context.current_block_scope().match_buffers.append(
+                tvm.tir.MatchBufferRegion(buffer, buffer_region)
+            )
+            self.context.update_symbol(self.node.lhs.id.name, buffer, 
self.node)
+
+        super().__init__(match_buffer_region, def_symbol=True)
+
+
 @register
 class VarDef(SpecialStmt):
     """ Special function for defining a Var"""
 
     def __init__(self):
         def var(dtype, span):
-            assert isinstance(self.node, ast.Assign)
+            assert isinstance(
+                self.node, ast.Assign
+            ), f"VarDef expected to work on ast.Assign but got 
{type(self.node)}"

Review comment:
       ```suggestion
               ), f"VarDef expected ast.Assign but got {type(self.node)}"
   ```

##########
File path: python/tvm/script/context_maintainer.py
##########
@@ -16,59 +16,145 @@
 # under the License.
 """TVM Script Context Maintainer for TIR"""
 
-from tvm.te import schedule
+from typing import List, Mapping, Union, Optional, Dict, Callable
+import synr
+
+
+import tvm
+from tvm.ir import Span
+from tvm.tir import Var, Buffer, PrimExpr, Stmt, MatchBufferRegion
+from tvm.runtime import Object
+from .node import BufferSlice
+
+
+class BlockInfo:
+    """Information for block and block_realize signature"""
+
+    alloc_buffers: List[Buffer]
+    match_buffers: List[MatchBufferRegion]
+    iter_bindings: Mapping[Var, PrimExpr]
+    reads: Optional[List[BufferSlice]]
+    writes: Optional[List[BufferSlice]]
+    annotations: Optional[Mapping[str, Object]]
+    predicate: Optional[PrimExpr]
+    init: Optional[Stmt]
+
+    def __init__(self):
+        self.alloc_buffers = []
+        self.match_buffers = []
+        self.iter_bindings = {}
+        self.reads = None
+        self.writes = None
+        self.annotations = None
+        self.predicate = None
+        self.init = None
 
 
 class ContextMaintainer:
     """Maintain all the necessary context info"""
 
-    def __init__(self, parser):
+    # scope context
+    # ast_node inside a scope
+    node_stack: List[List[synr.ast.Node]]
+    # loop stacks inside a block
+    block_info_stack: List[BlockInfo]
+    # loop stacks inside a block
+    loop_stack: List[List[Var]]
+    symbols: List[Dict[str, Union[Var, Buffer]]]
+
+    # function context
+    func_params: List[Var]
+    func_buffer_map: Mapping[Var, Buffer]
+    func_dict_attr: Mapping[str, Object]
+    func_var_env_dict: Mapping[Var, str]
+
+    # parser and analyzer
+    _report_error: Callable[[str, Union[Span, synr.ast.Span]], None]
+    analyzer: tvm.arith.Analyzer
+
+    def __init__(self, _report_error: Callable[[str, Union[Span, 
synr.ast.Span]], None]):
         # scope context
         self.node_stack = []  # AST nodes of scopes
+        self.block_info_stack = []  # Block info of scopes
+        self.loop_stack = []  # stack of loop vars
         self.symbols = []  # symbols of scopes
         # function context
         self.func_params = []  # parameter list of function
         self.func_buffer_map = {}  # buffer_map of function
         self.func_dict_attr = {}  # func_attr of function
         self.func_var_env_dict = {}  # map from var to env_name
-        # parser
-        self.parser = parser
+        # parser and analyzer
+        self._report_error = _report_error
+        self.analyzer = tvm.arith.Analyzer()
 
-    def pop_scope(self):
-        """Pop the inner most scope"""
-        self.symbols.pop()
-        self.node_stack.pop()
+    def enter_scope(self, nodes: Optional[List[synr.ast.Node]] = None):
+        """Creating a new scope
 
-    def new_scope(self, nodes=None):
-        """Creating a new scope"""
+        Parameters
+        ----------
+        nodes : Optional[List[synr.ast.Node]]
+            The synr AST nodes in new scope
+        """
         if nodes is None:
             nodes = []
         self.node_stack.append(list(reversed(nodes)))
         self.symbols.append(dict())
 
-    def update_symbol(self, name, symbol):
+    def enter_block_scope(self, nodes: Optional[List[synr.ast.Node]] = None):

Review comment:
       Document the difference between a regular scope and a block scope from the user 
perspective. When should `enter_scope` be used and when should 
`enter_block_scope` be used?

##########
File path: python/tvm/script/context_maintainer.py
##########
@@ -16,59 +16,145 @@
 # under the License.
 """TVM Script Context Maintainer for TIR"""
 
-from tvm.te import schedule
+from typing import List, Mapping, Union, Optional, Dict, Callable
+import synr
+
+
+import tvm
+from tvm.ir import Span
+from tvm.tir import Var, Buffer, PrimExpr, Stmt, MatchBufferRegion
+from tvm.runtime import Object
+from .node import BufferSlice
+
+
+class BlockInfo:
+    """Information for block and block_realize signature"""
+
+    alloc_buffers: List[Buffer]
+    match_buffers: List[MatchBufferRegion]
+    iter_bindings: Mapping[Var, PrimExpr]
+    reads: Optional[List[BufferSlice]]
+    writes: Optional[List[BufferSlice]]
+    annotations: Optional[Mapping[str, Object]]
+    predicate: Optional[PrimExpr]
+    init: Optional[Stmt]

Review comment:
       Document these.

##########
File path: python/tvm/script/special_stmt.py
##########
@@ -142,12 +181,289 @@ def buffer_decl(
                 buffer_type,
                 span=span,
             )
-            self.context.update_symbol(self.node.lhs.id.name, buffer)
+            self.context.update_symbol(self.node.lhs.id.name, buffer, 
self.node)
             return buffer
 
         super().__init__(buffer_decl, def_symbol=True)
 
 
+@register
+class AllocBuffer(SpecialStmt):
+    """Special function alloc_buffer(shape, dtype, data, strides, elem_offset, 
scope, align,
+                                     offset_factor, buffer_type)
+
+    Example
+    -------
+    .. code-block:: python
+
+        A = tir.alloc_buffer((128, 128), dtype="float32")
+
+    """
+
+    def __init__(self):
+        def alloc_buffer(
+            shape,
+            dtype="float32",
+            data=None,
+            strides=None,
+            elem_offset=None,
+            scope="",
+            align=-1,
+            offset_factor=0,
+            buffer_type="default",
+            span=None,
+        ):
+            if not isinstance(self.node, ast.Assign):
+                self.context.report_error(
+                    "Need assign the alloc_buffer to a buffer, e.g. A = 
alloc_buffer(...)",
+                    self.node.span,
+                )
+
+            if strides is None:
+                strides = []
+            align = convert_to_int(align, "align", self.context.report_error, 
self.node.span)
+            offset_factor = convert_to_int(
+                offset_factor, "offset_factor", self.context.report_error, 
self.node.span
+            )
+            buffer = tvm.tir.decl_buffer(
+                shape,
+                dtype,
+                self.node.lhs.id.name,
+                data,
+                strides,
+                elem_offset,
+                scope,
+                align,
+                offset_factor,
+                buffer_type,
+                span=span,
+            )
+            self.context.current_block_scope().alloc_buffers.append(buffer)
+            self.context.update_symbol(self.node.lhs.id.name, buffer, 
self.node)
+
+        super().__init__(alloc_buffer, def_symbol=True)
+
+
+@register
+class BlockVarBind(SpecialStmt):
+    """Special function bind(block_iter, binding_value)
+
+    Example
+    -------
+    .. code-block:: python
+
+        tir.bind(vx, i)
+
+    """
+
+    def __init__(self):
+        def bind(iter_var, values, span=None):
+            block_scope = self.context.current_block_scope()
+            if iter_var in block_scope.iter_bindings:
+                self.context.report_error("Duplicate iter_var bindings of " + 
str(iter_var), span)
+            block_scope.iter_bindings[iter_var] = values
+
+        super().__init__(bind, def_symbol=False)
+
+
+@register
+class BlockReads(SpecialStmt):
+    """Special function reads([read_buffer_regions])
+
+    Example
+    -------
+    .. code-block:: python
+
+        tir.reads([A[vi: vi + 4, vk: vk + 4], B[vk: vk + 4, vj]])
+
+    """
+
+    def __init__(self):
+        def reads(read_regions: Union[BufferSlice, List[BufferSlice]], span: 
Span = None):
+            assert self.context
+            block_scope = self.context.current_block_scope()
+            if block_scope.reads is not None:
+                self.context.report_error(
+                    "Duplicate write region declaration, "
+                    + "previous one is "
+                    + str(", ".join(str(x) for x in block_scope.reads)),
+                    span,
+                )
+            if isinstance(read_regions, list):
+                pass
+            elif isinstance(read_regions, BufferSlice):
+                read_regions = [read_regions]
+            else:
+                self.context.report_error(
+                    "Error input type. "
+                    + f"Expects BufferSlice or List[BufferSlice], but gets 
{type(read_regions)}",
+                    span,
+                )
+            block_scope.reads = read_regions
+
+        super().__init__(reads, def_symbol=False)
+
+
+@register
+class BlockWrites(SpecialStmt):
+    """Special function writes([write_buffer_regions])
+
+    Example
+    -------
+    .. code-block:: python
+
+        tir.writes([C[vi: vi + 4, vj])
+
+    """
+
+    def __init__(self):
+        def writes(write_region: Union[BufferSlice, List[BufferSlice]], span: 
Span = None):
+            assert self.context
+            block_scope = self.context.current_block_scope()
+            if block_scope.writes is not None:
+                self.context.report_error(
+                    "Duplicate write region declaration, "
+                    + "previous one is "
+                    + str(", ".join(str(x) for x in block_scope.writes)),
+                    span,
+                )
+            if isinstance(write_region, list):
+                pass
+            elif isinstance(write_region, BufferSlice):
+                write_region = [write_region]
+            else:
+                self.context.report_error(
+                    "Error input type. "
+                    + f"Expects BufferSlice or List[BufferSlice], but gets 
{type(write_region)}",
+                    span,
+                )
+            block_scope.writes = write_region
+
+        super().__init__(writes, def_symbol=False)
+
+
+@register
+class BlockAttr(SpecialStmt):
+    """Special function block_attr({attr_key: attr_value})
+
+    Example
+    -------
+    .. code-block:: python
+
+        tir.block_attr({"double_buffer_scope": 1})
+
+    """
+
+    def __init__(self):
+        def block_attr(attrs: Mapping[str, Object], span: Span = None):
+            assert self.context
+            block_scope = self.context.current_block_scope()
+            if block_scope.annotations is not None:
+                self.context.report_error(
+                    "Duplicate block annotations declaration, "
+                    + "previous one is "
+                    + str(block_scope.annotations),
+                    span,
+                )
+            attrs = {
+                key: tvm.tir.StringImm(val) if isinstance(val, str) else val
+                for key, val in attrs.items()
+            }
+            block_scope.annotations = attrs
+
+        super().__init__(block_attr, def_symbol=False)
+
+
+@register
+class BlockPredicate(SpecialStmt):
+    """Special function where(predicate)
+
+    Example
+    -------
+    .. code-block:: python
+
+        tir.where(i < 4)
+
+    """
+
+    def __init__(self):
+        def where(predicate, span=None):
+            block_scope = self.context.current_block_scope()
+            if block_scope.predicate is not None:
+                self.context.report_error(
+                    "Duplicate block predicate declaration, "
+                    + "previous one is "
+                    + str(block_scope.predicate),
+                    span,
+                )
+
+            block_scope.predicate = predicate
+
+        super().__init__(where, def_symbol=False)
+
+
+@register
+class BlockMatchBufferRegion(SpecialStmt):
+    """Special function match_buffer_region(source, strides, elem_offset, 
align, offset_factor)
+
+    Example
+    -------
+    .. code-block:: python
+
+        B = tir.match_buffer_region(A[0: 4])
+
+    """
+
+    def __init__(self):
+        def match_buffer_region(
+            source,
+            strides=None,
+            elem_offset=None,
+            align=-1,
+            offset_factor=0,
+            span=None,
+        ):
+            if not isinstance(self.node, ast.Assign):
+                self.context.report_error(
+                    "Need assign the match_buffer_region to a buffer, "
+                    + "e.g. A = match_buffer_region(...)",
+                    self.node.span,
+                )
+
+            if strides is None:
+                strides = []
+            align = convert_to_int(align, "align", self.context.report_error, 
self.node.span)
+            offset_factor = convert_to_int(
+                offset_factor, "offset_factor", self.context.report_error, 
self.node.span
+            )
+
+            if not isinstance(source, BufferSlice):
+                self.context.report_error(
+                    "match_buffer_region needs a buffer region as source",

Review comment:
       Add it to the docstring so that users can find it.

##########
File path: python/tvm/script/context_maintainer.py
##########
@@ -16,59 +16,145 @@
 # under the License.
 """TVM Script Context Maintainer for TIR"""
 
-from tvm.te import schedule
+from typing import List, Mapping, Union, Optional, Dict, Callable
+import synr
+
+
+import tvm
+from tvm.ir import Span
+from tvm.tir import Var, Buffer, PrimExpr, Stmt, MatchBufferRegion
+from tvm.runtime import Object
+from .node import BufferSlice
+
+
+class BlockInfo:
+    """Information for block and block_realize signature"""
+
+    alloc_buffers: List[Buffer]
+    match_buffers: List[MatchBufferRegion]
+    iter_bindings: Mapping[Var, PrimExpr]
+    reads: Optional[List[BufferSlice]]
+    writes: Optional[List[BufferSlice]]
+    annotations: Optional[Mapping[str, Object]]
+    predicate: Optional[PrimExpr]
+    init: Optional[Stmt]
+
+    def __init__(self):
+        self.alloc_buffers = []
+        self.match_buffers = []
+        self.iter_bindings = {}
+        self.reads = None
+        self.writes = None
+        self.annotations = None
+        self.predicate = None
+        self.init = None
 
 
 class ContextMaintainer:
     """Maintain all the necessary context info"""
 
-    def __init__(self, parser):
+    # scope context
+    # ast_node inside a scope
+    node_stack: List[List[synr.ast.Node]]
+    # loop stacks inside a block
+    block_info_stack: List[BlockInfo]
+    # loop stacks inside a block
+    loop_stack: List[List[Var]]
+    symbols: List[Dict[str, Union[Var, Buffer]]]
+
+    # function context
+    func_params: List[Var]
+    func_buffer_map: Mapping[Var, Buffer]
+    func_dict_attr: Mapping[str, Object]
+    func_var_env_dict: Mapping[Var, str]
+
+    # parser and analyzer
+    _report_error: Callable[[str, Union[Span, synr.ast.Span]], None]
+    analyzer: tvm.arith.Analyzer
+
+    def __init__(self, _report_error: Callable[[str, Union[Span, 
synr.ast.Span]], None]):
         # scope context
         self.node_stack = []  # AST nodes of scopes
+        self.block_info_stack = []  # Block info of scopes
+        self.loop_stack = []  # stack of loop vars
         self.symbols = []  # symbols of scopes
         # function context
         self.func_params = []  # parameter list of function
         self.func_buffer_map = {}  # buffer_map of function
         self.func_dict_attr = {}  # func_attr of function
         self.func_var_env_dict = {}  # map from var to env_name
-        # parser
-        self.parser = parser
+        # parser and analyzer
+        self._report_error = _report_error
+        self.analyzer = tvm.arith.Analyzer()
 
-    def pop_scope(self):
-        """Pop the inner most scope"""
-        self.symbols.pop()
-        self.node_stack.pop()
+    def enter_scope(self, nodes: Optional[List[synr.ast.Node]] = None):
+        """Creating a new scope
 
-    def new_scope(self, nodes=None):
-        """Creating a new scope"""
+        Parameters
+        ----------
+        nodes : Optional[List[synr.ast.Node]]
+            The synr AST nodes in new scope
+        """
         if nodes is None:
             nodes = []
         self.node_stack.append(list(reversed(nodes)))
         self.symbols.append(dict())
 
-    def update_symbol(self, name, symbol):
+    def enter_block_scope(self, nodes: Optional[List[synr.ast.Node]] = None):
+        """Creating a new block scope, the function will call `enter_scope` 
implicitly
+        Besides behaviors of normal `enter_scope`, it will update loop_stack 
and block_info_stack
+        for block info maintaining.

Review comment:
       ```suggestion
           Besides the behaviors of `enter_scope`, it will update loop_stack 
and block_info_stack
           to maintain block info.
   ```

##########
File path: python/tvm/script/scope_handler.py
##########
@@ -55,24 +82,27 @@ def __init__(self, func, concise_scope, def_symbol):
     @staticmethod
     def get_optional_var_names(node, context):
         """Get list of names from ast.With's optional_vars"""
-        assert isinstance(node, ast.With)
-
-        var_names = None
-        if isinstance(node.items[0].optional_vars, ast.Name):
-            var_names = [node.items[0].optional_vars.id]
-        elif isinstance(node.items[0].optional_vars, (ast.List, ast.Tuple)):
-            for var in node.items[0].optional_vars.elts:
-                if not isinstance(var, ast.Name):
-                    context.report_error("Invalid optional var definition")
-            var_names = [var.id for var in node.items[0].optional_vars.elts]
+        assert isinstance(
+            node, ast.With
+        ), f"WithScopeHandler expected to work on ast.With but got 
{type(node)}"
+
+        if isinstance(node.lhs, list):
+            for var in node.lhs:
+                if not isinstance(var, ast.Var):
+                    context.report_error(
+                        "Invalid optional var definition, only list of Var is 
valid", node.span
+                    )
+            var_names = [var.id.name for var in node.lhs]
         else:
-            context.report_error("Invalid optional var definition")
+            context.report_error(
+                "Invalid optional var definition, only list of Var is valid", 
node.span

Review comment:
       Report what was actually provided.

##########
File path: python/tvm/script/context_maintainer.py
##########
@@ -16,59 +16,145 @@
 # under the License.
 """TVM Script Context Maintainer for TIR"""
 
-from tvm.te import schedule
+from typing import List, Mapping, Union, Optional, Dict, Callable
+import synr
+
+
+import tvm
+from tvm.ir import Span
+from tvm.tir import Var, Buffer, PrimExpr, Stmt, MatchBufferRegion
+from tvm.runtime import Object
+from .node import BufferSlice
+
+
+class BlockInfo:
+    """Information for block and block_realize signature"""
+
+    alloc_buffers: List[Buffer]
+    match_buffers: List[MatchBufferRegion]
+    iter_bindings: Mapping[Var, PrimExpr]
+    reads: Optional[List[BufferSlice]]
+    writes: Optional[List[BufferSlice]]
+    annotations: Optional[Mapping[str, Object]]
+    predicate: Optional[PrimExpr]
+    init: Optional[Stmt]
+
+    def __init__(self):
+        self.alloc_buffers = []
+        self.match_buffers = []
+        self.iter_bindings = {}
+        self.reads = None
+        self.writes = None
+        self.annotations = None
+        self.predicate = None
+        self.init = None
 
 
 class ContextMaintainer:
     """Maintain all the necessary context info"""
 
-    def __init__(self, parser):
+    # scope context
+    # ast_node inside a scope
+    node_stack: List[List[synr.ast.Node]]
+    # block info inside a block scope
+    block_info_stack: List[BlockInfo]
+    # loop stacks inside a block
+    loop_stack: List[List[Var]]
+    symbols: List[Dict[str, Union[Var, Buffer]]]
+
+    # function context
+    func_params: List[Var]
+    func_buffer_map: Mapping[Var, Buffer]
+    func_dict_attr: Mapping[str, Object]
+    func_var_env_dict: Mapping[Var, str]
+
+    # parser and analyzer
+    _report_error: Callable[[str, Union[Span, synr.ast.Span]], None]
+    analyzer: tvm.arith.Analyzer
+
+    def __init__(self, _report_error: Callable[[str, Union[Span, 
synr.ast.Span]], None]):
         # scope context
         self.node_stack = []  # AST nodes of scopes
+        self.block_info_stack = []  # Block info of scopes
+        self.loop_stack = []  # stack of loop vars
         self.symbols = []  # symbols of scopes
         # function context
         self.func_params = []  # parameter list of function
         self.func_buffer_map = {}  # buffer_map of function
         self.func_dict_attr = {}  # func_attr of function
         self.func_var_env_dict = {}  # map from var to env_name
-        # parser
-        self.parser = parser
+        # parser and analyzer
+        self._report_error = _report_error
+        self.analyzer = tvm.arith.Analyzer()
 
-    def pop_scope(self):
-        """Pop the inner most scope"""
-        self.symbols.pop()
-        self.node_stack.pop()
+    def enter_scope(self, nodes: Optional[List[synr.ast.Node]] = None):
+        """Creating a new scope

Review comment:
       ```suggestion
           """Creates a new scope
   ```

##########
File path: python/tvm/script/scope_handler.py
##########
@@ -185,92 +232,384 @@ def let(var, value, span):
         super().__init__(let, concise_scope=False, def_symbol=False)
 
 
+@register
+class Block(WithScopeHandler):
+    """ With scope handler tir.block(extents, name) as iter_vars"""
+
+    def __init__(self):
+        def block(axes=None, name_hint: str = "", span: Optional[Span] = None):
+            assert (
+                self.node and self.context and self.body
+            ), "call 'exit_scope' before 'enter_scope'"
+            block_info = self.context.block_info_stack[-1]
+            if axes is None:
+                axes = []
+            if len(axes) != len(self.block_vars):
+                self.context.report_error(
+                    "Inconsistent number of block vars, "
+                    + f"there are {len(axes)} axes but {len(self.block_vars)} 
block vars. "
+                    + "The number of block vars should match the number of 
axes.",
+                    self.node.span,
+                )
+            block_iters: List[IterVar] = []
+            for i, axis in enumerate(axes):
+                axis = tvm.runtime.convert(axis)
+                if isinstance(axis, tvm.tir.PrimExpr):
+                    block_var_dom = Range.from_min_extent(0, axis)
+                    block_iters.append(IterVar(block_var_dom, 
self.block_vars[i], 0))
+                elif isinstance(axis, Range):
+                    block_iters.append(IterVar(axis, self.block_vars[i], 0))
+                elif isinstance(axis, IterVar):
+                    block_iters.append(IterVar(axis.dom, self.block_vars[i], 
axis.iter_type))
+                else:
+                    self.context.report_error(
+                        "Invalid argument of tir.block(), "
+                        + f"expected PrimExpr, Range or IterVar, but got 
{type(axis)}",
+                        self.node.span,
+                    )
+
+            # create block read/write regions
+
+            reads: List[BufferRegion] = (
+                [buffer_slice_to_region(read) for read in block_info.reads]
+                if block_info.reads
+                else []
+            )
+            writes: List[BufferRegion] = (
+                [buffer_slice_to_region(write) for write in block_info.writes]
+                if block_info.writes
+                else []
+            )
+            inner = tvm.tir.Block(
+                block_iters,
+                reads,
+                writes,
+                name_hint,
+                self.body,
+                block_info.init,
+                block_info.alloc_buffers,
+                block_info.match_buffers,
+                block_info.annotations,
+                span,
+            )
+            # create block var iter binding
+            values: List[PrimExpr]
+            if not block_info.iter_bindings:
+                values = self.context.loop_stack[-2].copy()
+                if len(values) == 0:
+                    values = [tvm.tir.const(float("nan"), dtype="float32")] * 
len(block_iters)
+                elif len(values) != len(block_iters):
+                    self.context.report_error(
+                        "Number of block iter var and outer loop nesting 
mismatch, "
+                        + f"{len(block_iters)} block iter vars but 
{len(values)} loops",
+                        self.node.span,
+                    )
+            else:
+                for block_var in self.block_vars:
+                    if block_var not in block_info.iter_bindings:
+                        self.context.report_error(
+                            "Missing block iter var binding for " + 
block_var.name,
+                            self.node.span,
+                        )
+                values = [block_info.iter_bindings[block_var] for block_var in 
self.block_vars]
+            predicate = (
+                tvm.tir.const(True, "bool")
+                if block_info.predicate is None
+                else block_info.predicate
+            )
+            body = tvm.tir.BlockRealize(values, predicate, inner, span)
+            return body
+
+        super().__init__(func=block, concise_scope=False, def_symbol=True)
+        self.block_vars = None
+
+    def enter_scope(
+        self,
+        node: synr.ast.Node,
+        context: ContextMaintainer,
+        arg_list: List[Any],
+        span: synr.ast.Span,
+    ):
+        # define block vars
+        assert isinstance(
+            node, ast.With
+        ), f"BlockScopeHandler expected to work on ast.With but got 
{type(node)}"
+
+        var_names = WithScopeHandler.get_optional_var_names(node, context)
+        self.block_vars = [tvm.te.var(name) for name in var_names]
+        for block_var in self.block_vars:
+            context.update_symbol(block_var.name, block_var, node)
+
+
+@register
+class InitBlock(WithScopeHandler):
+    """ With scope handler tir.init()"""
+
+    def __init__(self):
+        def init(span: Span = None):
+            assert self.context, "call 'exit_scope' before 'enter_scope'"
+            if self.context.block_info_stack[-2].init is not None:
+                self.context.report_error("Duplicate init block declaration", 
span)
+            self.context.block_info_stack[-2].init = self.body
+
+        super().__init__(func=init, concise_scope=False, def_symbol=True)
+
+
 class ForScopeHandler(ScopeHandler):
     """Base class for all for scope handlers"""
 
     def __init__(self, func):
         super().__init__(func)
-        self.loop_vars = None
-
-    def enter_scope(self, node, context, arg_list, span):
-        assert isinstance(node, ast.For)
+        self.loop_vars: Optional[List[Var]] = None
+
+    def enter_scope(
+        self,
+        node: synr.ast.Node,
+        context: ContextMaintainer,
+        arg_list: List[Any],
+        span: synr.ast.Span,
+    ):
+        assert isinstance(
+            node, ast.For
+        ), f"ForScopeHandler expected to work on ast.For but got {type(node)}"

Review comment:
       ```suggestion
           ), f"ForScopeHandler expected ast.For but got {type(node)}"
   ```

##########
File path: python/tvm/script/special_stmt.py
##########
@@ -142,21 +181,300 @@ def buffer_decl(
                 buffer_type,
                 span=span,
             )
-            self.context.update_symbol(self.node.lhs.id.name, buffer)
+            self.context.update_symbol(self.node.lhs.id.name, buffer, 
self.node)
             return buffer
 
         super().__init__(buffer_decl, def_symbol=True)
 
 
+@register
+class AllocBuffer(SpecialStmt):
+    """Special function alloc_buffer(shape, dtype, data, strides, elem_offset, 
scope, align,
+                                     offset_factor, buffer_type)
+
+    Example
+    -------
+    .. code-block:: python
+
+        A = tir.alloc_buffer((128, 128), dtype="float32")
+
+    """
+
+    def __init__(self):
+        def alloc_buffer(
+            shape,
+            dtype="float32",
+            data=None,
+            strides=None,
+            elem_offset=None,
+            scope="",
+            align=-1,
+            offset_factor=0,
+            buffer_type="default",
+            span=None,
+        ):
+            if not isinstance(self.node, ast.Assign):
+                self.context.report_error(
+                    "alloc_buffer must be assigned to a buffer, e.g. A = 
alloc_buffer(...)",
+                    self.node.span,
+                )
+
+            if strides is None:
+                strides = []
+            align = convert_to_int(align, "align", self.context.report_error, 
self.node.span)
+            offset_factor = convert_to_int(
+                offset_factor, "offset_factor", self.context.report_error, 
self.node.span
+            )
+            buffer = tvm.tir.decl_buffer(
+                shape,
+                dtype,
+                self.node.lhs.id.name,
+                data,
+                strides,
+                elem_offset,
+                scope,
+                align,
+                offset_factor,
+                buffer_type,
+                span=span,
+            )
+            self.context.current_block_scope().alloc_buffers.append(buffer)
+            self.context.update_symbol(self.node.lhs.id.name, buffer, 
self.node)
+
+        super().__init__(alloc_buffer, def_symbol=True)
+
+
+@register
+class BlockVarBind(SpecialStmt):
+    """Special function bind(block_iter, binding_value)
+
+    Example
+    -------
+    .. code-block:: python
+
+        tir.bind(vx, i)

Review comment:
       Can you give a little more detail in the examples.

##########
File path: python/tvm/script/special_stmt.py
##########
@@ -167,10 +485,12 @@ class EnvThread(SpecialStmt):
 
     def __init__(self):
         def env_thread(env_name, span):
-            assert isinstance(self.node, ast.Assign)
+            assert isinstance(
+                self.node, ast.Assign
+            ), f"EnvThread expected to work on ast.Assign but got 
{type(self.node)}"

Review comment:
       ```suggestion
               ), f"EnvThread expected ast.Assign but got {type(self.node)}"
   ```

##########
File path: python/tvm/script/scope_handler.py
##########
@@ -185,92 +232,384 @@ def let(var, value, span):
         super().__init__(let, concise_scope=False, def_symbol=False)
 
 
+@register
+class Block(WithScopeHandler):
+    """ With scope handler tir.block(extents, name) as iter_vars"""
+
+    def __init__(self):
+        def block(axes=None, name_hint: str = "", span: Optional[Span] = None):
+            assert (
+                self.node and self.context and self.body
+            ), "call 'exit_scope' before 'enter_scope'"
+            block_info = self.context.block_info_stack[-1]
+            if axes is None:
+                axes = []
+            if len(axes) != len(self.block_vars):
+                self.context.report_error(
+                    "Inconsistent number of block vars, "
+                    + f"there are {len(axes)} axes but {len(self.block_vars)} 
block vars. "
+                    + "The number of block vars should match the number of 
axes.",
+                    self.node.span,
+                )
+            block_iters: List[IterVar] = []
+            for i, axis in enumerate(axes):
+                axis = tvm.runtime.convert(axis)
+                if isinstance(axis, tvm.tir.PrimExpr):
+                    block_var_dom = Range.from_min_extent(0, axis)
+                    block_iters.append(IterVar(block_var_dom, 
self.block_vars[i], 0))
+                elif isinstance(axis, Range):
+                    block_iters.append(IterVar(axis, self.block_vars[i], 0))
+                elif isinstance(axis, IterVar):
+                    block_iters.append(IterVar(axis.dom, self.block_vars[i], 
axis.iter_type))
+                else:
+                    self.context.report_error(
+                        "Invalid argument of tir.block(), "
+                        + f"expected PrimExpr, Range or IterVar, but got 
{type(axis)}",
+                        self.node.span,
+                    )
+
+            # create block read/write regions
+
+            reads: List[BufferRegion] = (
+                [buffer_slice_to_region(read) for read in block_info.reads]
+                if block_info.reads
+                else []
+            )
+            writes: List[BufferRegion] = (
+                [buffer_slice_to_region(write) for write in block_info.writes]
+                if block_info.writes
+                else []
+            )
+            inner = tvm.tir.Block(
+                block_iters,
+                reads,
+                writes,
+                name_hint,
+                self.body,
+                block_info.init,
+                block_info.alloc_buffers,
+                block_info.match_buffers,
+                block_info.annotations,
+                span,
+            )
+            # create block var iter binding
+            values: List[PrimExpr]
+            if not block_info.iter_bindings:
+                values = self.context.loop_stack[-2].copy()
+                if len(values) == 0:
+                    values = [tvm.tir.const(float("nan"), dtype="float32")] * 
len(block_iters)
+                elif len(values) != len(block_iters):
+                    self.context.report_error(
+                        "Number of block iter var and outer loop nesting 
mismatch, "
+                        + f"{len(block_iters)} block iter vars but 
{len(values)} loops",
+                        self.node.span,
+                    )
+            else:
+                for block_var in self.block_vars:
+                    if block_var not in block_info.iter_bindings:
+                        self.context.report_error(
+                            "Missing block iter var binding for " + 
block_var.name,
+                            self.node.span,
+                        )
+                values = [block_info.iter_bindings[block_var] for block_var in 
self.block_vars]
+            predicate = (
+                tvm.tir.const(True, "bool")
+                if block_info.predicate is None
+                else block_info.predicate
+            )
+            body = tvm.tir.BlockRealize(values, predicate, inner, span)
+            return body
+
+        super().__init__(func=block, concise_scope=False, def_symbol=True)
+        self.block_vars = None
+
+    def enter_scope(
+        self,
+        node: synr.ast.Node,
+        context: ContextMaintainer,
+        arg_list: List[Any],
+        span: synr.ast.Span,
+    ):
+        # define block vars
+        assert isinstance(
+            node, ast.With
+        ), f"BlockScopeHandler expected to work on ast.With but got 
{type(node)}"
+
+        var_names = WithScopeHandler.get_optional_var_names(node, context)
+        self.block_vars = [tvm.te.var(name) for name in var_names]
+        for block_var in self.block_vars:
+            context.update_symbol(block_var.name, block_var, node)
+
+
+@register
+class InitBlock(WithScopeHandler):
+    """ With scope handler tir.init()"""
+
+    def __init__(self):
+        def init(span: Span = None):
+            assert self.context, "call 'exit_scope' before 'enter_scope'"
+            if self.context.block_info_stack[-2].init is not None:
+                self.context.report_error("Duplicate init block declaration", 
span)
+            self.context.block_info_stack[-2].init = self.body
+
+        super().__init__(func=init, concise_scope=False, def_symbol=True)
+
+
 class ForScopeHandler(ScopeHandler):
     """Base class for all for scope handlers"""
 
     def __init__(self, func):
         super().__init__(func)
-        self.loop_vars = None
-
-    def enter_scope(self, node, context, arg_list, span):
-        assert isinstance(node, ast.For)
+        self.loop_vars: Optional[List[Var]] = None
+
+    def enter_scope(
+        self,
+        node: synr.ast.Node,
+        context: ContextMaintainer,
+        arg_list: List[Any],
+        span: synr.ast.Span,
+    ):
+        assert isinstance(
+            node, ast.For
+        ), f"ForScopeHandler expected to work on ast.For but got {type(node)}"
 
         loop_var_names = list()
         spans = list()
         if isinstance(node.lhs, ast.Var):
             loop_var_names.append(node.lhs.id.name)
-            spans.append(from_synr_span(node.lhs.id.span))
-        elif isinstance(node.lhs, ast.Tuple):
-            for elt in node.lhs.values:
+            spans.append(tvm_span_from_synr(node.lhs.id.span))
+        elif isinstance(node.lhs, list):
+            for elt in node.lhs:
                 if not isinstance(elt, ast.Var):
-                    context.report_error("Invalid loop var", elt.span)
+                    context.report_error(
+                        f"Invalid loop var. Expected a var, but got 
{type(elt)}", elt.span
+                    )
                 loop_var_names.append(elt.id.name)
-                spans.append(from_synr_span(elt.id.span))
+                spans.append(tvm_span_from_synr(elt.id.span))
         else:
-            context.report_error("Invalid loop var", node.lhs.span)
+            context.report_error(
+                f"Invalid loop var. Expected var or list of vars as lhs, but 
got {type(node.lhs)}",
+                span,
+            )
 
         self.loop_vars = [
             tvm.te.var(name, dtype="int32", span=span) for name, span in 
zip(loop_var_names, spans)
         ]
         for loop_var in self.loop_vars:
-            context.update_symbol(loop_var.name, loop_var)
+            context.update_symbol(loop_var.name, loop_var, node)
+            context.loop_stack[-1].append(loop_var)
+
+    def exit_scope(
+        self,
+        node: synr.ast.Node,
+        context: ContextMaintainer,
+        arg_list: List[Any],
+        span: synr.ast.Span,
+    ):
+        assert self.loop_vars, "call 'exit_scope' before 'enter_scope'"
+        for _ in self.loop_vars:
+            context.loop_stack[-1].pop()
+        return super().exit_scope(node, context, arg_list, span)
+
+    def create_loop(
+        self,
+        begin: PrimExpr,
+        end: PrimExpr,
+        kind: ForKind,
+        thread_binding: Optional[str] = None,
+        annotations: Optional[Mapping[str, Object]] = None,
+        span: Optional[Span] = None,
+    ) -> tvm.tir.For:
+        """
+        Helper function for creating For in TVM Script parser.
+
+        Parameters
+        ----------
+        begin : PrimExpr
+            The beginning value.
+
+        end : PrimExpr
+            The ending value.
+
+        kind : ForKind
+            The type of the for.
+
+        thread_binding: Optional[str]
+            The thread this loop binds to.
+
+        annotations : Optional[Mapping[str, Object]]
+            Additional annotation hints.
+
+        span : Optional[Span]
+            The location of this for in the source code.
+
+        Returns
+        -------
+        for : For
+            The constructed For.
+        """
+        assert (
+            self.loop_vars and self.context and self.node
+        ), "call 'exit_scope' before 'enter_scope'"
+        if len(self.loop_vars) != 1:
+            self.context.report_error(
+                f"Expect exact only one loop var, but get {self.loop_vars}", 
self.node.span
+            )
+        extent = end if begin == 0 else self.context.analyzer.simplify(end - 
begin)
+        annos: Mapping[str, Object]
+        if annotations is None:
+            annos = {}
+        else:
+            annos = {
+                key: tvm.tir.StringImm(val) if isinstance(val, str) else val
+                for key, val in annotations.items()
+            }
+        return tvm.tir.For(
+            self.loop_vars[0],
+            begin,
+            extent,
+            kind,
+            self.body,
+            thread_binding=thread_binding,
+            annotations=annos,
+            span=span,
+        )
 
 
 @register
 class Serial(ForScopeHandler):
-    """ For scope handler tir.serial(begin, end)"""
+    """ For scope handler tir.serial(begin, end, annotations)"""
 
     def __init__(self):
-        def serial(begin, end, span):
-            if len(self.loop_vars) != 1:
-                self.context.report_error("Expect exact 1 loop var", span)
-            ana = tvm.arith.Analyzer()
-            extent = end if begin == 0 else ana.simplify(end - begin)
-            return tvm.tir.For(self.loop_vars[0], begin, extent, 0, self.body, 
span=span)
+        def serial(
+            begin: PrimExpr,
+            end: PrimExpr,
+            annotations: Optional[Mapping[str, Object]] = None,
+            span: Optional[Span] = None,
+        ):
+            return self.create_loop(begin, end, ForKind.SERIAL, 
annotations=annotations, span=span)
 
         super().__init__(serial)
 
 
 @register
 class Parallel(ForScopeHandler):
-    """ For scope handler tir.parallel(begin, end)"""
+    """ For scope handler tir.parallel(begin, end, annotations)"""
 
     def __init__(self):
-        def parallel(begin, end, span):
-            if len(self.loop_vars) != 1:
-                self.context.report_error("Expect exact 1 loop var")
-            ana = tvm.arith.Analyzer()
-            extent = end if begin == 0 else ana.simplify(end - begin)
-            return tvm.tir.For(self.loop_vars[0], begin, extent, 1, self.body, 
span=span)
+        def parallel(
+            begin: PrimExpr,
+            end: PrimExpr,
+            annotations: Optional[Mapping[str, Object]] = None,
+            span: Optional[Span] = None,
+        ):
+            return self.create_loop(
+                begin, end, ForKind.PARALLEL, annotations=annotations, 
span=span
+            )
 
         super().__init__(parallel)
 
 
 @register
 class Vectorized(ForScopeHandler):
-    """ For scope handler tir.vectorized(begin, end)"""
+    """ For scope handler tir.vectorized(begin, end, annotations)"""
 
     def __init__(self):
-        def vectorized(begin, end, span):
-            if len(self.loop_vars) != 1:
-                self.context.report_error("Expect exact 1 loop var")
-            ana = tvm.arith.Analyzer()
-            extent = end if begin == 0 else ana.simplify(end - begin)
-            return tvm.tir.For(self.loop_vars[0], begin, extent, 2, self.body, 
span=span)
+        def vectorized(
+            begin: PrimExpr,
+            end: PrimExpr,
+            annotations: Optional[Mapping[str, Object]] = None,
+            span: Optional[Span] = None,
+        ):
+            return self.create_loop(
+                begin, end, ForKind.VECTORIZED, annotations=annotations, 
span=span
+            )
 
         super().__init__(vectorized)
 
 
 @register
 class Unroll(ForScopeHandler):
-    """ For scope handler tir.unroll(begin, end)"""
+    """ For scope handler tir.unroll(begin, end, annotations)"""
 
     def __init__(self):
-        def unroll(begin, end, span):
-            if len(self.loop_vars) != 1:
-                self.context.report_error("Expect exact 1 loop var")
-            ana = tvm.arith.Analyzer()
-            extent = end if begin == 0 else ana.simplify(end - begin)
-            return tvm.tir.For(self.loop_vars[0], begin, extent, 3, self.body, 
span=span)
+        def unroll(
+            begin: PrimExpr,
+            end: PrimExpr,
+            annotations: Optional[Mapping[str, Object]] = None,
+            span: Optional[Span] = None,
+        ):
+            return self.create_loop(
+                begin, end, ForKind.UNROLLED, annotations=annotations, 
span=span
+            )
 
         super().__init__(unroll)
+
+
+@register
+class ThreadBinding(ForScopeHandler):
+    """ For scope handler tir.thread_binding(begin, end, thread, 
annotations)"""
+
+    def __init__(self):
+        def thread_binding(
+            begin: PrimExpr,
+            end: PrimExpr,
+            thread: str,
+            annotations: Optional[Mapping[str, Object]] = None,
+            span: Optional[Span] = None,
+        ):
+            thread_iter_var = IterVar(None, None, 1, thread, span=span)

Review comment:
       Use the enum entry instead of `1` here.

##########
File path: python/tvm/script/special_stmt.py
##########
@@ -142,21 +181,300 @@ def buffer_decl(
                 buffer_type,
                 span=span,
             )
-            self.context.update_symbol(self.node.lhs.id.name, buffer)
+            self.context.update_symbol(self.node.lhs.id.name, buffer, 
self.node)
             return buffer
 
         super().__init__(buffer_decl, def_symbol=True)
 
 
+@register
+class AllocBuffer(SpecialStmt):
+    """Special function alloc_buffer(shape, dtype, data, strides, elem_offset, 
scope, align,
+                                     offset_factor, buffer_type)
+
+    Example
+    -------
+    .. code-block:: python
+
+        A = tir.alloc_buffer((128, 128), dtype="float32")
+
+    """
+
+    def __init__(self):
+        def alloc_buffer(
+            shape,
+            dtype="float32",
+            data=None,
+            strides=None,
+            elem_offset=None,
+            scope="",
+            align=-1,
+            offset_factor=0,
+            buffer_type="default",
+            span=None,
+        ):
+            if not isinstance(self.node, ast.Assign):
+                self.context.report_error(
+                    "alloc_buffer must be assigned to a buffer, e.g. A = 
alloc_buffer(...)",
+                    self.node.span,
+                )
+
+            if strides is None:
+                strides = []
+            align = convert_to_int(align, "align", self.context.report_error, 
self.node.span)
+            offset_factor = convert_to_int(
+                offset_factor, "offset_factor", self.context.report_error, 
self.node.span
+            )
+            buffer = tvm.tir.decl_buffer(
+                shape,
+                dtype,
+                self.node.lhs.id.name,
+                data,
+                strides,
+                elem_offset,
+                scope,
+                align,
+                offset_factor,
+                buffer_type,
+                span=span,
+            )
+            self.context.current_block_scope().alloc_buffers.append(buffer)
+            self.context.update_symbol(self.node.lhs.id.name, buffer, 
self.node)
+
+        super().__init__(alloc_buffer, def_symbol=True)
+
+
+@register
+class BlockVarBind(SpecialStmt):
+    """Special function bind(block_iter, binding_value)
+
+    Example
+    -------
+    .. code-block:: python
+
+        tir.bind(vx, i)
+
+    """
+
+    def __init__(self):
+        def bind(iter_var, values, span=None):
+            block_scope = self.context.current_block_scope()
+            if iter_var in block_scope.iter_bindings:
+                self.context.report_error("Duplicate iter_var bindings of " + 
str(iter_var), span)
+            block_scope.iter_bindings[iter_var] = values
+
+        super().__init__(bind, def_symbol=False)
+
+
+@register
+class BlockReads(SpecialStmt):
+    """Special function reads([read_buffer_regions])
+
+    Example
+    -------
+    .. code-block:: python
+
+        tir.reads([A[vi: vi + 4, vk: vk + 4], B[vk: vk + 4, vj]])
+
+    """
+
+    def __init__(self):
+        def reads(read_regions: Union[BufferSlice, List[BufferSlice]], span: 
Span = None):
+            assert self.context, "call 'exit_scope' before 'enter_scope'"
+            block_scope = self.context.current_block_scope()
+            if block_scope.reads is not None:
+                self.context.report_error(
+                    "Duplicate read region declaration, "
+                    + "previous one is "
+                    + str(", ".join(str(x) for x in block_scope.reads)),
+                    span,
+                )
+            if isinstance(read_regions, list):
+                pass
+            elif isinstance(read_regions, BufferSlice):
+                read_regions = [read_regions]
+            else:
+                self.context.report_error(
+                    "Error input type. "
+                    + f"Expected BufferSlice or List[BufferSlice], but got 
{type(read_regions)}",
+                    span,
+                )
+            block_scope.reads = read_regions
+
+        super().__init__(reads, def_symbol=False)
+
+
+@register
+class BlockWrites(SpecialStmt):
+    """Special function writes([write_buffer_regions])
+
+    Example
+    -------
+    .. code-block:: python
+
+        tir.writes([C[vi: vi + 4, vj]])
+
+    """
+
+    def __init__(self):
+        def writes(write_region: Union[BufferSlice, List[BufferSlice]], span: 
Span = None):
+            assert self.context, "call 'exit_scope' before 'enter_scope'"
+            block_scope = self.context.current_block_scope()
+            if block_scope.writes is not None:
+                self.context.report_error(
+                    "Duplicate write region declaration, "
+                    + "previous one is "
+                    + str(", ".join(str(x) for x in block_scope.writes)),
+                    span,
+                )
+            if isinstance(write_region, list):
+                pass
+            elif isinstance(write_region, BufferSlice):
+                write_region = [write_region]
+            else:
+                self.context.report_error(
+                    "Error input type. "

Review comment:
       ```suggestion
                       "Incorrect input type. "
   ```

##########
File path: src/printer/tvmscript_printer.cc
##########
@@ -713,6 +782,88 @@ Doc TVMScriptPrinter::VisitStmt_(const BufferStoreNode* 
op) {
   return doc;
 }
 
+Doc TVMScriptPrinter::VisitStmt_(const BlockRealizeNode* op) {
+  const auto* block_op = op->block.as<BlockNode>();
+  // print block name and block vars
+  Doc doc;
+  doc << "with tir.block([";
+  std::vector<Doc> block_var_docs;
+  for (const auto& iter_var : block_op->iter_vars) {
+    Doc block_var_doc;
+    if (is_zero(iter_var->dom->min) && iter_var->iter_type == kDataPar) {
+      block_var_doc << Print(iter_var->dom->extent);
+    } else {
+      block_var_doc << "tir.";
+      switch (iter_var->iter_type) {
+        case kDataPar:
+          block_var_doc << "range";
+          break;
+        case kCommReduce:
+          block_var_doc << "reduce_axis";
+          break;
+        case kOrdered:
+          block_var_doc << "scan_axis";
+          break;
+        case kOpaque:
+          block_var_doc << "opaque_axis";
+          break;
+        default:
+          LOG(FATAL) << "Unknown block var iter type";

Review comment:
       Might as well print the number.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to