tkonolige commented on a change in pull request #7630:
URL: https://github.com/apache/tvm/pull/7630#discussion_r597853815



##########
File path: python/tvm/script/context_maintainer.py
##########
@@ -16,59 +16,179 @@
 # under the License.
 """TVM Script Context Maintainer for TIR"""
 
-from tvm.te import schedule
+from typing import List, Mapping, Union, Optional, Dict, Callable
+import synr
+
+
+import tvm
+from tvm.ir import Span
+from tvm.tir import Var, Buffer, PrimExpr, Stmt, MatchBufferRegion
+from tvm.runtime import Object
+from .node import BufferSlice
+
+
+class BlockInfo:
+    """Information for block and block_realize signature"""
+
+    alloc_buffers: List[Buffer] = []
+    """List[Buffer]: list of tir.alloc_buffer statements in the block signature"""
+    match_buffers: List[MatchBufferRegion] = []
+    """List[MatchBufferRegion]: list of tir.match_buffer_region statements in the block signature"""
+    iter_bindings: Mapping[Var, PrimExpr] = {}
+    """Mapping[Var, PrimExpr]: map of block iter var to its values"""
+    reads: Optional[List[BufferSlice]] = None
+    """Optional[List[BufferSlice]]:
+    list of tir.reads statements in the block signature, None for not-visited"""
+    writes: Optional[List[BufferSlice]] = None
+    """Optional[List[BufferSlice]]:
+    list of tir.writes statements in the block signature, None for not-visited"""
+    annotations: Optional[Mapping[str, Object]] = None
+    """Optional[Mapping[str, Object]]:
+    list of tir.block_attr statements in the block signature, None for not-visited"""
+    predicate: Optional[PrimExpr] = None
+    """Optional[PrimExpr]: block realize predicate, None for not-visited"""
+    init: Optional[Stmt] = None
+    """Optional[Stmt]: init part of the block, None for not-visited"""
+
+    def __init__(self):
+        self.alloc_buffers = []
+        self.match_buffers = []
+        self.iter_bindings = {}
+        self.reads = None
+        self.writes = None
+        self.annotations = None
+        self.predicate = None
+        self.init = None
 
 
 class ContextMaintainer:
-    """Maintain all the necessary context info"""
+    """Maintain all the necessary context info
+    Parameters
+    ----------
+    _report_error : Callable[[str, Union[Span, synr.ast.Span]], None]
+        The report error function handle
+    """
 
-    def __init__(self, parser):
+    # scope context
+    node_stack: List[List[synr.ast.Node]] = []
+    """List[List[synr.ast.Node]]: The ast nodes insides the current scope"""
+    block_info_stack: List[BlockInfo] = []
+    """List[BlockInfo]: The block info for the current block scope"""
+    loop_stack: List[List[Var]] = []
+    """List[List[Var]]: List of loop vars inside the current block scope"""
+    symbols: List[Dict[str, Union[Var, Buffer]]] = []
+    """List[Dict[str, Union[Var, Buffer]]]: Symbol map from name to object for the current scope"""
+
+    # function context
+    func_params: List[Var] = []
+    """List[Var]: The function parameters"""
+    func_buffer_map: Mapping[Var, Buffer] = {}
+    """Mapping[Var, Buffer]: The function buffer map"""
+    func_dict_attr: Mapping[str, Object] = {}
+    """Mapping[str, Object]: The function attrs"""
+    func_var_env_dict: Mapping[Var, str] = {}
+    """Mapping[Var, str]: The map from var to env thread"""
+
+    # parser and analyzer
+    analyzer: tvm.arith.Analyzer = tvm.arith.Analyzer()
+    """tvm.arith.Analyzer: The analyzer for simplifying"""
+    _report_error: Callable[[str, Union[Span, synr.ast.Span]], None]
+    """Callable[[str, Union[Span, synr.ast.Span]], None]: The report error function handle"""
+
+    def __init__(self, _report_error: Callable[[str, Union[Span, synr.ast.Span]], None]):
         # scope context
-        self.node_stack = []  # AST nodes of scopes
-        self.symbols = []  # symbols of scopes
+        self.node_stack = []
+        self.block_info_stack = []
+        self.loop_stack = []
+        self.symbols = []
         # function context
-        self.func_params = []  # parameter list of function
-        self.func_buffer_map = {}  # buffer_map of function
-        self.func_dict_attr = {}  # func_attr of function
-        self.func_var_env_dict = {}  # map from var to env_name
-        # parser
-        self.parser = parser
-
-    def pop_scope(self):
-        """Pop the inner most scope"""
-        self.symbols.pop()
-        self.node_stack.pop()
+        self.func_params = []
+        self.func_buffer_map = {}
+        self.func_dict_attr = {}
+        self.func_var_env_dict = {}
+        # parser and analyzer
+        self._report_error = _report_error
+        self.analyzer = tvm.arith.Analyzer()
 
-    def new_scope(self, nodes=None):
-        """Creating a new scope"""
+    def enter_scope(self, nodes: Optional[List[synr.ast.Node]] = None):
+        """Creates a new scope
+
+        Note
+        ----
+        This function is used for normal scopes that do not involve 
+        a `with block` scope. Use `enter_block_scope`
+        for block scope cases.
+
+        Parameters
+        ----------
+        nodes : Optional[List[synr.ast.Node]]
+            The synr AST nodes in new scope
+        """
         if nodes is None:
             nodes = []
         self.node_stack.append(list(reversed(nodes)))
         self.symbols.append(dict())
 
-    def update_symbol(self, name, symbol):
+    def enter_block_scope(self, nodes: Optional[List[synr.ast.Node]] = None):
+        """Creates a new block scope, the function will call `enter_scope` implicitly
+        Besides the behaviors of `enter_scope`, it will update loop_stack and block_info_stack
+        to maintain block info.
+
+        Note
+        ----
+        This function should be used to handle a block scope,
+        aka the blocks that involve a `with block` scope.

Review comment:
       This is still confusing. What is a block? How does it differ from a regular scope?

##########
File path: python/tvm/script/special_stmt.py
##########
@@ -17,30 +17,81 @@
 """TVM Script Parser Special Stmt Classes"""
 # pylint: disable=unused-argument, no-self-argument, inconsistent-return-statements
 # pylint: disable=relative-beyond-top-level
+from typing import Callable, List, Optional, Tuple, Any, Mapping, Union
+
+import synr
 from synr import ast
 
 import tvm.tir
+from tvm.runtime import Object
 from tvm import te
-from .utils import get_param_list, from_synr_span
+from tvm.ir import Span
+from tvm.tir import IntImm
+from .utils import (
+    get_param_list,
+    tvm_span_from_synr,
+    buffer_slice_to_region,
+    call_with_error_reporting,
+)
 from .registry import register
+from .context_maintainer import ContextMaintainer
+from .node import BufferSlice
+
+
+def convert_to_int(
+    value: Union[IntImm, int],
+    arg_name: str,
+    report_error: Callable,
+    span: Union[Span, synr.ast.Span],
+) -> int:
+    """convert a const int or TVM IntImm to Python int.
+    Report an error when input cannot be converted to int.
+
+    Parameters
+    ----------
+    value : Union[tvm.tir.IntImm, int]
+        The input value to be converted.
+    arg_name : str
+        function arg name for error reporting
+    report_error: Callable
+        The report error function handle
+    span : Union[synr.ast.Span, tvm.ir.Span】

Review comment:
       ```suggestion
       span : Union[synr.ast.Span, tvm.ir.Span】
   ```
   ```suggestion
       span : Union[synr.ast.Span, tvm.ir.Span]
   ```

##########
File path: python/tvm/tir/analysis/analysis.py
##########
@@ -106,3 +106,26 @@ def verify_gpu_code(func, constraints):
         The result of verification.
     """
     return _ffi_api.verify_gpu_code(func, constraints)
+
+
+def get_block_access_region(block, buffer_var_map):
+    """Auto detect the block read/write region according to body stmt
+        It will detect the read/write region as an array in order of appearance in AST
+
+    Parameters
+    ----------
+    block: tvm.tir.Block
+        The block to be detected.

Review comment:
       ```suggestion
           The block in which we are detecting read/write regions.
   ```

##########
File path: python/tvm/script/special_stmt.py
##########
@@ -17,30 +17,81 @@
 """TVM Script Parser Special Stmt Classes"""
 # pylint: disable=unused-argument, no-self-argument, inconsistent-return-statements
 # pylint: disable=relative-beyond-top-level
+from typing import Callable, List, Optional, Tuple, Any, Mapping, Union
+
+import synr
 from synr import ast
 
 import tvm.tir
+from tvm.runtime import Object
 from tvm import te
-from .utils import get_param_list, from_synr_span
+from tvm.ir import Span
+from tvm.tir import IntImm
+from .utils import (
+    get_param_list,
+    tvm_span_from_synr,
+    buffer_slice_to_region,
+    call_with_error_reporting,
+)
 from .registry import register
+from .context_maintainer import ContextMaintainer
+from .node import BufferSlice
+
+
+def convert_to_int(
+    value: Union[IntImm, int],
+    arg_name: str,
+    report_error: Callable,
+    span: Union[Span, synr.ast.Span],
+) -> int:
+    """convert a const int or TVM IntImm to Python int.
+    Report an error when input cannot be converted to int.

Review comment:
       ```suggestion
       Reports an error when input cannot be converted to int.
   ```
   ```suggestion
       Report an error when input cannot be converted to int.
   ```

##########
File path: src/tir/analysis/block_access_region_detector.cc
##########
@@ -0,0 +1,246 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tir/analysis/block_region_detector.cc
+ * \brief Detect block read/write regions by visiting its body
+ */
+
+#include <tvm/arith/analyzer.h>
+#include <tvm/tir/op.h>
+#include <tvm/tir/stmt_functor.h>
+
+namespace tvm {
+namespace tir {
+
+/*!
+ * \brief Auto detect the block read write region
+ *        It will detect the read/write region as an array in order of appearance in AST

Review comment:
       ```suggestion
    * \brief Detect which regions of tensors in this block are read or written to. Regions are sorted by order of appearance in the AST.
   ```

##########
File path: python/tvm/script/special_stmt.py
##########
@@ -17,30 +17,81 @@
 """TVM Script Parser Special Stmt Classes"""
 # pylint: disable=unused-argument, no-self-argument, inconsistent-return-statements
 # pylint: disable=relative-beyond-top-level
+from typing import Callable, List, Optional, Tuple, Any, Mapping, Union
+
+import synr
 from synr import ast
 
 import tvm.tir
+from tvm.runtime import Object
 from tvm import te
-from .utils import get_param_list, from_synr_span
+from tvm.ir import Span
+from tvm.tir import IntImm
+from .utils import (
+    get_param_list,
+    tvm_span_from_synr,
+    buffer_slice_to_region,
+    call_with_error_reporting,
+)
 from .registry import register
+from .context_maintainer import ContextMaintainer
+from .node import BufferSlice
+
+
+def convert_to_int(
+    value: Union[IntImm, int],
+    arg_name: str,
+    report_error: Callable,
+    span: Union[Span, synr.ast.Span],
+) -> int:
+    """convert a const int or TVM IntImm to Python int.
+    Report an error when input cannot be converted to int.
+
+    Parameters
+    ----------
+    value : Union[tvm.tir.IntImm, int]
+        The input value to be converted.
+    arg_name : str
+        function arg name for error reporting

Review comment:
       ```suggestion
           Function argument name for error reporting.
   ```
   ```suggestion
           function arg name for error reporting
   ```

##########
File path: python/tvm/tir/analysis/analysis.py
##########
@@ -106,3 +106,26 @@ def verify_gpu_code(func, constraints):
         The result of verification.
     """
     return _ffi_api.verify_gpu_code(func, constraints)
+
+
+def get_block_access_region(block, buffer_var_map):
+    """Auto detect the block read/write region according to body stmt
+        It will detect the read/write region as an array in order of appearance in AST

Review comment:
       ```suggestion
       """Detect which regions of tensors in this block are read or written to.
          Regions are sorted by order of appearance in the AST.
   ```




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to