tkonolige commented on a change in pull request #7630:
URL: https://github.com/apache/tvm/pull/7630#discussion_r595390660



##########
File path: python/tvm/script/context_maintainer.py
##########
@@ -16,59 +16,166 @@
 # under the License.
 """TVM Script Context Maintainer for TIR"""
 
-from tvm.te import schedule
+from typing import List, Mapping, Union, Optional, Dict, Callable
+import synr
+
+
+import tvm
+from tvm.ir import Span
+from tvm.tir import Var, Buffer, PrimExpr, Stmt, MatchBufferRegion
+from tvm.runtime import Object
+from .node import BufferSlice
+
+
+class BlockInfo:
+    """Information for block and block_realize signature"""
+
+    alloc_buffers: List[Buffer] = []
+    """List[Buffer]: alloc_buffers list for the block"""
+    match_buffers: List[MatchBufferRegion] = []
+    """List[MatchBufferRegion]: match_buffer_region list for the block"""

Review comment:
       Please document these fields in more detail instead of just restating the variable names. You've done the same thing for the other variables.
   
   Maybe it would be helpful to provide an example of a block and show what each variable corresponds to.

##########
File path: python/tvm/script/context_maintainer.py
##########
@@ -16,59 +16,166 @@
 # under the License.
 """TVM Script Context Maintainer for TIR"""
 
-from tvm.te import schedule
+from typing import List, Mapping, Union, Optional, Dict, Callable
+import synr
+
+
+import tvm
+from tvm.ir import Span
+from tvm.tir import Var, Buffer, PrimExpr, Stmt, MatchBufferRegion
+from tvm.runtime import Object
+from .node import BufferSlice
+
+
+class BlockInfo:
+    """Information for block and block_realize signature"""
+
+    alloc_buffers: List[Buffer] = []
+    """List[Buffer]: alloc_buffers list for the block"""
+    match_buffers: List[MatchBufferRegion] = []
+    """List[MatchBufferRegion]: match_buffer_region list for the block"""
+    iter_bindings: Mapping[Var, PrimExpr] = {}
+    """Mapping[Var, PrimExpr]: block iter var and its values"""
+    reads: Optional[List[BufferSlice]] = None
+    """Optional[List[BufferSlice]]: block read buffer regions, None for not-visited"""
+    writes: Optional[List[BufferSlice]] = None
+    """Optional[List[BufferSlice]]: block write buffer regions, None for not-visited"""
+    annotations: Optional[Mapping[str, Object]] = None
+    """Optional[Mapping[str, Object]]: block annotations, None for not-visited"""
+    predicate: Optional[PrimExpr] = None
+    """Optional[PrimExpr]: block realize predicate, None for not-visited"""
+    init: Optional[Stmt] = None
+    """Optional[Stmt]: init part of the block, None for not-visited"""
+
+    def __init__(self):
+        self.alloc_buffers = []
+        self.match_buffers = []
+        self.iter_bindings = {}
+        self.reads = None
+        self.writes = None
+        self.annotations = None
+        self.predicate = None
+        self.init = None
 
 
 class ContextMaintainer:
-    """Maintain all the necessary context info"""
+    """Maintain all the necessary context info
+    Parameters
+    ----------
+    _report_error : Callable[[str, Union[Span, synr.ast.Span]], None]
+        The report error function handle
+    """
 
-    def __init__(self, parser):
+    # scope context
+    node_stack: List[List[synr.ast.Node]] = []
+    """List[List[synr.ast.Node]]: The ast nodes insides the current scope"""
+    block_info_stack: List[BlockInfo] = []
+    """List[BlockInfo]: The block info for the current block scope"""
+    loop_stack: List[List[Var]] = []
+    """List[List[Var]]: List of loop vars inside the current block scope"""
+    symbols: List[Dict[str, Union[Var, Buffer]]] = []
+    """List[Dict[str, Union[Var, Buffer]]]: Symbol map from name to object for the current scope"""
+
+    # function context
+    func_params: List[Var] = []
+    """List[Var]: The function parameters"""
+    func_buffer_map: Mapping[Var, Buffer] = {}
+    """Mapping[Var, Buffer]: The function buffer map"""
+    func_dict_attr: Mapping[str, Object] = {}
+    """Mapping[str, Object]: The function attrs"""
+    func_var_env_dict: Mapping[Var, str] = {}
+    """Mapping[Var, str]: The map from var to env thread"""
+
+    # parser and analyzer
+    analyzer: tvm.arith.Analyzer = tvm.arith.Analyzer()
+    """tvm.arith.Analyzer: The analyzer for simplifying"""
+    _report_error: Callable[[str, Union[Span, synr.ast.Span]], None]
+    """Callable[[str, Union[Span, synr.ast.Span]], None]: The report error function handle"""
+
+    def __init__(self, _report_error: Callable[[str, Union[Span, synr.ast.Span]], None]):
         # scope context
-        self.node_stack = []  # AST nodes of scopes
-        self.symbols = []  # symbols of scopes
+        self.node_stack = []
+        self.block_info_stack = []
+        self.loop_stack = []
+        self.symbols = []
         # function context
-        self.func_params = []  # parameter list of function
-        self.func_buffer_map = {}  # buffer_map of function
-        self.func_dict_attr = {}  # func_attr of function
-        self.func_var_env_dict = {}  # map from var to env_name
-        # parser
-        self.parser = parser
-
-    def pop_scope(self):
-        """Pop the inner most scope"""
-        self.symbols.pop()
-        self.node_stack.pop()
+        self.func_params = []
+        self.func_buffer_map = {}
+        self.func_dict_attr = {}
+        self.func_var_env_dict = {}
+        # parser and analyzer
+        self._report_error = _report_error
+        self.analyzer = tvm.arith.Analyzer()
+
+    def enter_scope(self, nodes: Optional[List[synr.ast.Node]] = None):
+        """Creates a new scope
 
-    def new_scope(self, nodes=None):
-        """Creating a new scope"""
+        Parameters
+        ----------
+        nodes : Optional[List[synr.ast.Node]]
+            The synr AST nodes in new scope
+        """
         if nodes is None:
             nodes = []
         self.node_stack.append(list(reversed(nodes)))
         self.symbols.append(dict())
 
-    def update_symbol(self, name, symbol):
+    def enter_block_scope(self, nodes: Optional[List[synr.ast.Node]] = None):
+        """Creates a new block scope, the function will call `enter_scope` implicitly
+        Besides the behaviors of `enter_scope`, it will update loop_stack and block_info_stack
+        to maintain block info.
+        It should be used when the scope is a block (or likely to be a block)

Review comment:
       Is this only for things like
   ```
   with tir.block([]):
       ...
   ```
   ?

##########
File path: python/tvm/script/special_stmt.py
##########
@@ -142,21 +181,300 @@ def buffer_decl(
                 buffer_type,
                 span=span,
             )
-            self.context.update_symbol(self.node.lhs.id.name, buffer)
+            self.context.update_symbol(self.node.lhs.id.name, buffer, self.node)
             return buffer
 
         super().__init__(buffer_decl, def_symbol=True)
 
 
+@register
+class AllocBuffer(SpecialStmt):
+    """Special function alloc_buffer(shape, dtype, data, strides, elem_offset, scope, align,
+                                     offset_factor, buffer_type)
+
+    Example
+    -------
+    .. code-block:: python
+
+        A = tir.alloc_buffer((128, 128), dtype="float32")
+
+    """
+
+    def __init__(self):
+        def alloc_buffer(
+            shape,
+            dtype="float32",
+            data=None,
+            strides=None,
+            elem_offset=None,
+            scope="",
+            align=-1,
+            offset_factor=0,
+            buffer_type="default",
+            span=None,
+        ):
+            if not isinstance(self.node, ast.Assign):
+                self.context.report_error(
+                    "alloc_buffer must be assigned to a buffer, e.g. A = alloc_buffer(...)",
+                    self.node.span,
+                )
+
+            if strides is None:
+                strides = []
+            align = convert_to_int(align, "align", self.context.report_error, self.node.span)
+            offset_factor = convert_to_int(
+                offset_factor, "offset_factor", self.context.report_error, self.node.span
+            )
+            buffer = tvm.tir.decl_buffer(
+                shape,
+                dtype,
+                self.node.lhs.id.name,
+                data,
+                strides,
+                elem_offset,
+                scope,
+                align,
+                offset_factor,
+                buffer_type,
+                span=span,
+            )
+            self.context.current_block_scope().alloc_buffers.append(buffer)
+            self.context.update_symbol(self.node.lhs.id.name, buffer, self.node)
+
+        super().__init__(alloc_buffer, def_symbol=True)
+
+
+@register
+class BlockVarBind(SpecialStmt):
+    """Special function bind(block_iter, binding_value)
+
+    Example
+    -------
+    .. code-block:: python
+
+        tir.bind(vx, i)

Review comment:
       I think we should show a full block example here; context is a key part of a useful example.

##########
File path: python/tvm/script/scope_handler.py
##########
@@ -55,24 +82,29 @@ def __init__(self, func, concise_scope, def_symbol):
     @staticmethod
     def get_optional_var_names(node, context):
         """Get list of names from ast.With's optional_vars"""
-        assert isinstance(node, ast.With)
-
-        var_names = None
-        if isinstance(node.items[0].optional_vars, ast.Name):
-            var_names = [node.items[0].optional_vars.id]
-        elif isinstance(node.items[0].optional_vars, (ast.List, ast.Tuple)):
-            for var in node.items[0].optional_vars.elts:
-                if not isinstance(var, ast.Name):
-                    context.report_error("Invalid optional var definition")
-            var_names = [var.id for var in node.items[0].optional_vars.elts]
+        assert isinstance(
+            node, ast.With
+        ), f"WithScopeHandler expected to work on ast.With but got {type(node)}"

Review comment:
       ```suggestion
           ), f"WithScopeHandler expected ast.With but got {type(node)}"
   ```

##########
File path: python/tvm/script/special_stmt.py
##########
@@ -142,21 +200,293 @@ def buffer_decl(
                 buffer_type,
                 span=span,
             )
-            self.context.update_symbol(self.node.lhs.id.name, buffer)
+            self.context.update_symbol(self.node.lhs.id.name, buffer, self.node)
             return buffer
 
         super().__init__(buffer_decl, def_symbol=True)
 
 
+@register
+class AllocBuffer(SpecialStmt):
+    """Special function alloc_buffer(shape, dtype, data, strides, elem_offset, scope, align,
+                                     offset_factor, buffer_type)
+
+    Example
+    -------
+    .. code-block:: python
+
+        A = tir.alloc_buffer((128, 128), dtype="float32")
+    """
+
+    def __init__(self):
+        def alloc_buffer(
+            shape,
+            dtype="float32",
+            data=None,
+            strides=None,
+            elem_offset=None,
+            scope="",
+            align=-1,
+            offset_factor=0,
+            buffer_type="default",
+            span=None,
+        ):
+            if not isinstance(self.node, ast.Assign):
+                self.context.report_error(
+                    "alloc_buffer must be assigned to a buffer, e.g. A = alloc_buffer(...)",
+                    self.node.span,
+                )
+
+            if strides is None:
+                strides = []
+            align = convert_to_int(align, "align", self.context.report_error, self.node.span)
+            offset_factor = convert_to_int(
+                offset_factor, "offset_factor", self.context.report_error, self.node.span
+            )
+            buffer = tvm.tir.decl_buffer(
+                shape,
+                dtype,
+                self.node.lhs.id.name,
+                data,
+                strides,
+                elem_offset,
+                scope,
+                align,
+                offset_factor,
+                buffer_type,
+                span=span,
+            )
+            self.context.current_block_scope().alloc_buffers.append(buffer)
+            self.context.update_symbol(self.node.lhs.id.name, buffer, self.node)
+
+        super().__init__(alloc_buffer, def_symbol=True)
+
+
+@register
+class BlockVarBind(SpecialStmt):
+    """Special function bind(block_iter, binding_value)
+
+    Example
+    -------
+    .. code-block:: python
+
+        tir.bind(vx, i)
+    """
+
+    def __init__(self):
+        def bind(iter_var, values, span=None):
+            block_scope = self.context.current_block_scope()
+            if iter_var in block_scope.iter_bindings:
+                self.context.report_error("Duplicate iter_var bindings of " + str(iter_var), span)
+            block_scope.iter_bindings[iter_var] = values
+
+        super().__init__(bind, def_symbol=False)
+
+
+@register
+class BlockReads(SpecialStmt):
+    """Special function reads([read_buffer_regions])
+
+    Example
+    -------
+    .. code-block:: python
+
+        tir.reads([A[vi: vi + 4, vk: vk + 4], B[vk: vk + 4, vj]])
+    """
+
+    def __init__(self):
+        def reads(read_regions: Union[BufferSlice, List[BufferSlice]], span: Span = None):
+            assert self.context, "call 'exit_scope' before 'enter_scope'"
+            block_scope = self.context.current_block_scope()
+            if block_scope.reads is not None:
+                self.context.report_error(
+                    "Duplicate write region declaration, "
+                    + "previous one is "
+                    + str(", ".join(str(x) for x in block_scope.reads)),
+                    span,
+                )
+            if isinstance(read_regions, BufferSlice):
+                read_regions = [read_regions]
+            if not isinstance(read_regions, list):
+                self.context.report_error(
+                    "Incorrect input type. "
+                    + f"Expected BufferSlice or List[BufferSlice], but got {type(read_regions)}",
+                    span,
+                )
+            block_scope.reads = read_regions
+
+        super().__init__(reads, def_symbol=False)
+
+
+@register
+class BlockWrites(SpecialStmt):
+    """Special function writes([write_buffer_regions])
+
+    Example
+    -------
+    .. code-block:: python
+
+        tir.writes([C[vi: vi + 4, vj])
+    """
+
+    def __init__(self):
+        def writes(write_region: Union[BufferSlice, List[BufferSlice]], span: Span = None):
+            assert self.context, "call 'exit_scope' before 'enter_scope'"
+            block_scope = self.context.current_block_scope()
+            if block_scope.writes is not None:
+                self.context.report_error(
+                    "Duplicate write region declaration, "
+                    + "previous one is "
+                    + str(", ".join(str(x) for x in block_scope.writes)),
+                    span,
+                )
+            if isinstance(write_region, list):
+                pass
+            elif isinstance(write_region, BufferSlice):
+                write_region = [write_region]
+            else:
+                self.context.report_error(
+                    "Incorrect input type. "
+                    + f"Expected BufferSlice or List[BufferSlice], but got {type(write_region)}",
+                    span,
+                )
+            block_scope.writes = write_region
+
+        super().__init__(writes, def_symbol=False)
+
+
+@register
+class BlockAttr(SpecialStmt):
+    """Special function block_attr({attr_key: attr_value})
+
+    Example
+    -------
+    .. code-block:: python
+
+        tir.block_attr({"double_buffer_scope": 1})
+    """
+
+    def __init__(self):
+        def block_attr(attrs: Mapping[str, Object], span: Span = None):
+            assert self.context, "call 'exit_scope' before 'enter_scope'"
+            block_scope = self.context.current_block_scope()
+            if block_scope.annotations is not None:
+                self.context.report_error(
+                    "Duplicate block annotations declaration, "
+                    + "previous one is "
+                    + str(block_scope.annotations),
+                    span,
+                )
+            attrs = {
+                key: tvm.tir.StringImm(val) if isinstance(val, str) else val
+                for key, val in attrs.items()
+            }
+            block_scope.annotations = attrs
+
+        super().__init__(block_attr, def_symbol=False)
+
+
+@register
+class BlockPredicate(SpecialStmt):
+    """Special function where(predicate)
+
+    Example
+    -------
+    .. code-block:: python
+
+        tir.where(i < 4)
+    """
+
+    def __init__(self):
+        def where(predicate, span=None):
+            assert self.context, "call 'exit_scope' before 'enter_scope'"
+            block_scope = self.context.current_block_scope()
+            if block_scope.predicate is not None:
+                self.context.report_error(
+                    "Duplicate block predicate declaration, "
+                    + "previous one is "
+                    + str(block_scope.predicate),
+                    span,
+                )
+
+            block_scope.predicate = predicate
+
+        super().__init__(where, def_symbol=False)
+
+
+@register
+class BlockMatchBufferRegion(SpecialStmt):
+    """Special function match_buffer_region(source, strides, elem_offset, align, offset_factor)
+
+    Example
+    -------
+    .. code-block:: python
+
+        B = tir.match_buffer_region(A[0: 4])
+    """
+
+    def __init__(self):
+        def match_buffer_region(
+            source,
+            strides=None,
+            elem_offset=None,
+            align=-1,
+            offset_factor=0,
+            span=None,
+        ):
+            assert self.context, "call 'exit_scope' before 'enter_scope'"
+            if not isinstance(self.node, ast.Assign):
+                self.context.report_error(
+                    "match_buffer_region must be assigned to a buffer, "
+                    + "e.g. A = match_buffer_region(...)",
+                    self.node.span,
+                )
+
+            if strides is None:
+                strides = []
+            align = convert_to_int(align, "align", self.context.report_error, self.node.span)
+            offset_factor = convert_to_int(
+                offset_factor, "offset_factor", self.context.report_error, self.node.span
+            )
+
+            if not isinstance(source, BufferSlice):
+                self.context.report_error(
+                    "match_buffer_region needs a buffer region as source",

Review comment:
       Please show an example of using a buffer region as the source in this error message.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to