This is an automated email from the ASF dual-hosted git repository.

wuwei pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new a302e0fd63 [TIR] Enhance Python Type Annotations for TIR stmt (#16076)
a302e0fd63 is described below

commit a302e0fd63c7f11bd7ce0fe709b7cc526548ae1c
Author: Siyuan Feng <[email protected]>
AuthorDate: Tue Nov 7 02:15:42 2023 +0800

    [TIR] Enhance Python Type Annotations for TIR stmt (#16076)
    
    This PR enhances the Python annotations for the TIR stmt,
    adding class member variables annotations.
---
 include/tvm/tir/stmt.h                             |   8 +-
 python/tvm/tir/stmt.py                             | 298 +++++++++++++++------
 .../python/unittest/test_tir_schedule_analysis.py  |   2 +-
 tests/python/unittest/test_tir_schedule_state.py   |   2 +-
 .../unittest/test_tvmscript_ir_builder_tir.py      |  12 +-
 5 files changed, 233 insertions(+), 89 deletions(-)

diff --git a/include/tvm/tir/stmt.h b/include/tvm/tir/stmt.h
index bddf87101f..07cc9b5ad0 100644
--- a/include/tvm/tir/stmt.h
+++ b/include/tvm/tir/stmt.h
@@ -68,7 +68,7 @@ class LetStmtNode : public StmtNode {
  public:
   /*! \brief The variable. */
   Var var;
-  /*! \brief The value to be binded. */
+  /*! \brief The value to be bound. */
   PrimExpr value;
   /*! \brief The body block. */
   Stmt body;
@@ -876,7 +876,7 @@ class SeqStmt : public Stmt {
 };
 
 /*!
- * \brief IfThenElse statment.
+ * \brief IfThenElse statement.
  */
 class IfThenElseNode : public StmtNode {
  public:
@@ -951,7 +951,7 @@ enum class ForKind : int {
 };
 
 /*!
- * \brief A for loop, with poissible type annotations.
+ * \brief A for loop, with possible type annotations.
  *
  * \code
  *
@@ -1388,7 +1388,7 @@ namespace attr {
 constexpr const char* thread_extent = "thread_extent";
 /*! \brief Mark launching of a virtual thread. */
 constexpr const char* virtual_thread = "virtual_thread";
-/*! \brief Mark region is processed by a co-proccesor */
+/*! \brief Mark region is processed by a co-processor */
 constexpr const char* coproc_scope = "coproc_scope";
 /*!
  * \brief Mark region creates coprocessor micro ops,
diff --git a/python/tvm/tir/stmt.py b/python/tvm/tir/stmt.py
index 484a6f04b6..992c388e27 100644
--- a/python/tvm/tir/stmt.py
+++ b/python/tvm/tir/stmt.py
@@ -14,7 +14,6 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-# pylint: disable=invalid-name
 """Statement AST Node in TVM.
 
 Each statement node have subfields that can be visited from python side.
@@ -32,11 +31,11 @@ from typing import List, Mapping, Optional, Union
 
 import tvm._ffi
 from tvm.ir import PrimExpr, Range, Span
-from tvm.runtime import Object, Scriptable, const
+from tvm.runtime import Object, Scriptable, const, NDArray
 
 from . import _ffi_api
-from .buffer import Buffer
-from .expr import IterVar
+from .buffer import Buffer, DataProducer
+from .expr import Var, IterVar
 
 
 class Stmt(Object, Scriptable):
@@ -53,16 +52,21 @@ class LetStmt(Stmt):
         The variable in the binding.
 
     value : PrimExpr
-        The value in to be binded.
+        The value to be bound.
 
     body : Stmt
         The body statement.
 
     span : Optional[Span]
-        The location of this itervar in the source code.
+        The location of the stmt in the source code.
     """
 
-    def __init__(self, var, value, body, span=None):
+    var: Var
+    value: PrimExpr
+    body: Stmt
+    span: Optional[Span]
+
+    def __init__(self, var: Var, value: PrimExpr, body: Stmt, span: Optional[Span] = None) -> None:
         self.__init_handle_by_constructor__(
             _ffi_api.LetStmt, var, value, body, span  # type: ignore
         )
@@ -84,10 +88,17 @@ class AssertStmt(Stmt):
         The body statement.
 
     span : Optional[Span]
-        The location of this itervar in the source code.
+        The location of the stmt in the source code.
     """
 
-    def __init__(self, condition, message, body, span=None):
+    condition: PrimExpr
+    message: PrimExpr
+    body: Stmt
+    span: Optional[Span]
+
+    def __init__(
+        self, condition: PrimExpr, message: PrimExpr, body: Stmt, span: Optional[Span] = None
+    ) -> None:
         self.__init_handle_by_constructor__(
             _ffi_api.AssertStmt, condition, message, body, span  # type: ignore
         )
@@ -106,7 +117,7 @@ class ForKind(IntEnum):
     PARALLEL = 1
     VECTORIZED = 2
     UNROLLED = 3
-    THREAD_BINDING = 4
+    THREAD_BINDING = 4  # pylint: disable=invalid-name
 
 
 @tvm._ffi.register_object("tir.For")
@@ -118,7 +129,7 @@ class For(Stmt):
     loop_var : Var
         The loop variable.
 
-    min_val : PrimExpr
+    min : PrimExpr
         The beginning value.
 
     extent : PrimExpr
@@ -134,28 +145,37 @@ class For(Stmt):
         The thread this loop binds to. Only valid
         if kind is ThreadBinding
 
-    annotations: tvm.ir.Map
+    annotations : Optional[Mapping[str, Object]]
         Additional annotation hints.
 
     span : Optional[Span]
-        The location of this itervar in the source code.
+        The location of the stmt in the source code.
     """
 
+    loop_var: Var
+    min: PrimExpr
+    extent: PrimExpr
+    kind: ForKind
+    body: Stmt
+    thread_binding: Optional[IterVar]
+    annotations: Mapping[str, Object]
+    span: Optional[Span]
+
     def __init__(
         self,
-        loop_var,
-        min_val,
-        extent,
-        kind,
-        body,
-        thread_binding=None,
-        annotations=None,
-        span=None,
-    ):
+        loop_var: Var,
+        min: PrimExpr,  # pylint: disable=redefined-builtin
+        extent: PrimExpr,
+        kind: ForKind,
+        body: Stmt,
+        thread_binding: Optional[IterVar] = None,
+        annotations: Optional[Mapping[str, Object]] = None,
+        span: Optional[Span] = None,
+    ) -> None:
         self.__init_handle_by_constructor__(
             _ffi_api.For,  # type: ignore
             loop_var,
-            min_val,
+            min,
             extent,
             kind,
             body,
@@ -178,16 +198,15 @@ class While(Stmt):
         The body statement.
 
     span : Optional[Span]
-        The location of this itervar in the source code.
+        The location of the stmt in the source code.
     """
 
-    def __init__(self, condition, body, span=None):
-        self.__init_handle_by_constructor__(
-            _ffi_api.While,  # type: ignore
-            condition,
-            body,
-            span,
-        )
+    condition: PrimExpr
+    body: Stmt
+    span: Optional[Span]
+
+    def __init__(self, condition: PrimExpr, body: Stmt, span: Optional[Span] = None) -> None:
+        self.__init_handle_by_constructor__(_ffi_api.While, condition, body, span)  # type: ignore
 
 
 @tvm._ffi.register_object("tir.BufferStore")
@@ -206,10 +225,21 @@ class BufferStore(Stmt):
         The indices location to be stored.
 
     span : Optional[Span]
-        The location of this itervar in the source code.
+        The location of the stmt in the source code.
     """
 
-    def __init__(self, buffer, value, indices, span=None):
+    buffer: Buffer
+    value: PrimExpr
+    indices: List[PrimExpr]
+    span: Optional[Span]
+
+    def __init__(
+        self,
+        buffer: Buffer,
+        value: PrimExpr,
+        indices: List[PrimExpr],
+        span: Optional[Span] = None,
+    ) -> None:
         self.__init_handle_by_constructor__(
             _ffi_api.BufferStore, buffer, value, indices, span  # type: ignore
         )
@@ -234,10 +264,23 @@ class BufferRealize(Stmt):
         The body of the statement.
 
     span : Optional[Span]
-        The location of this itervar in the source code.
+        The location of the stmt in the source code.
     """
 
-    def __init__(self, buffer, bounds, condition, body, span=None):
+    buffer: Buffer
+    bounds: List[Range]
+    condition: PrimExpr
+    body: Stmt
+    span: Optional[Span]
+
+    def __init__(
+        self,
+        buffer: Buffer,
+        bounds: List[Range],
+        condition: PrimExpr,
+        body: Stmt,
+        span: Optional[Span] = None,
+    ) -> None:
         self.__init_handle_by_constructor__(
             _ffi_api.BufferRealize, buffer, bounds, condition, body, span  # type: ignore
         )
@@ -259,10 +302,21 @@ class ProducerStore(Stmt):
         The index arguments of the store.
 
     span : Optional[Span]
-        The location of this itervar in the source code.
+        The location of the stmt in the source code.
     """
 
-    def __init__(self, producer, value, indices, span=None):
+    producer: DataProducer
+    value: PrimExpr
+    indices: List[PrimExpr]
+    span: Optional[Span]
+
+    def __init__(
+        self,
+        producer: DataProducer,
+        value: PrimExpr,
+        indices: List[PrimExpr],
+        span: Optional[Span] = None,
+    ) -> None:
         self.__init_handle_by_constructor__(
             _ffi_api.ProducerStore, producer, value, indices, span  # type: ignore
         )
@@ -293,10 +347,27 @@ class Allocate(Stmt):
         Additional annotation hints
 
     span : Optional[Span]
-        The location of this itervar in the source code.
+        The location of the stmt in the source code.
     """
 
-    def __init__(self, buffer_var, dtype, extents, condition, body, annotations=None, span=None):
+    buffer_var: Var
+    dtype: str
+    extents: List[PrimExpr]
+    condition: PrimExpr
+    body: Stmt
+    annotations: Mapping[str, Object]
+    span: Optional[Span]
+
+    def __init__(
+        self,
+        buffer_var: Var,
+        dtype: str,
+        extents: List[PrimExpr],
+        condition: PrimExpr,
+        body: Stmt,
+        annotations: Optional[Mapping[str, Object]] = None,
+        span: Optional[Span] = None,
+    ) -> None:
         if annotations is None:
             annotations = dict()
         self.__init_handle_by_constructor__(
@@ -335,16 +406,41 @@ class AllocateConst(Stmt):
     body : Stmt
         The body statement.
 
-    annotations : Optional[Map]
+    annotations : Optional[Mapping[str, Object]]
         Additional annotations about the allocation.
 
     span : Optional[Span]
-        The location of this itervar in the source code.
+        The location of the stmt in the source code.
     """
 
-    def __init__(self, buffer_var, dtype, extents, data_or_idx, body, annotations=None, span=None):
+    buffer_var: Var
+    dtype: str
+    extents: List[PrimExpr]
+    data: Optional[NDArray]
+    irmod_storage_idx: Optional[int]
+    body: Stmt
+    annotations: Mapping[str, Object]
+    span: Optional[Span]
+
+    def __init__(
+        self,
+        buffer_var: Var,
+        dtype: str,
+        extents: List[PrimExpr],
+        data_or_idx: Union[NDArray, int],
+        body: Stmt,
+        annotations: Optional[Mapping[str, Object]] = None,
+        span: Optional[Span] = None,
+    ) -> None:
         self.__init_handle_by_constructor__(
-            _ffi_api.AllocateConst, buffer_var, dtype, extents, data_or_idx, body, annotations, span
+            _ffi_api.AllocateConst,  # type: ignore
+            buffer_var,
+            dtype,
+            extents,
+            data_or_idx,
+            body,
+            annotations,
+            span,
         )
 
 
@@ -364,7 +460,11 @@ class DeclBuffer(Stmt):
         The location of this DeclBuffer in the source code.
     """
 
-    def __init__(self, buffer, body, span=None):
+    buffer: Buffer
+    body: Stmt
+    span: Optional[Span]
+
+    def __init__(self, buffer: Buffer, body: Stmt, span: Optional[Span] = None) -> None:
         self.__init_handle_by_constructor__(_ffi_api.DeclBuffer, buffer, body, span)
 
 
@@ -374,7 +474,7 @@ class AttrStmt(Stmt):
 
     Parameters
     ----------
-    node : Node
+    node : Object
         The node to annotate the attribute
 
     attr_key : str
@@ -387,10 +487,18 @@ class AttrStmt(Stmt):
         The body statement.
 
     span : Optional[Span]
-        The location of this itervar in the source code.
+        The location of the stmt in the source code.
     """
 
-    def __init__(self, node, attr_key, value, body, span=None):
+    node: Object
+    attr_key: str
+    value: PrimExpr
+    body: Stmt
+    span: Optional[Span]
+
+    def __init__(
+        self, node: Object, attr_key: str, value: PrimExpr, body: Stmt, span: Optional[Span] = None
+    ) -> None:
         self.__init_handle_by_constructor__(
             _ffi_api.AttrStmt, node, attr_key, value, body, span  # type: ignore
         )
@@ -405,7 +513,7 @@ class ProducerRealize(Stmt):
     producer : DataProducer
         The data producer.
 
-    bounds : list of range
+    bounds : List[Range]
         The bound of realize
 
     condition : PrimExpr
@@ -418,18 +526,33 @@ class ProducerRealize(Stmt):
         The storage scope associated with this realization
 
     span : Optional[Span]
-        The location of this itervar in the source code.
+        The location of the stmt in the source code.
     """
 
-    def __init__(self, producer, bounds, condition, body, storage_scope="", span=None):
+    producer: DataProducer
+    bounds: List[Range]
+    condition: PrimExpr
+    body: Stmt
+    storage_scope: str
+    span: Optional[Span]
+
+    def __init__(
+        self,
+        producer: DataProducer,
+        bounds: List[Range],
+        condition: PrimExpr,
+        body: Stmt,
+        storage_scope: str = "",
+        span: Optional[Span] = None,
+    ) -> None:
         self.__init_handle_by_constructor__(
-            _ffi_api.ProducerRealize,
+            _ffi_api.ProducerRealize,  # type: ignore
             producer,
             bounds,
             condition,
             body,
             storage_scope,
-            span,  # type: ignore
+            span,
         )
 
 
@@ -443,13 +566,16 @@ class SeqStmt(Stmt):
         The statements
 
     span : Optional[Span]
-        The location of this itervar in the source code.
+        The location of the stmt in the source code.
     """
 
-    def __init__(self, seq, span=None):
+    seq: List[Stmt]
+    span: Optional[Span]
+
+    def __init__(self, seq: List[Stmt], span: Optional[Span] = None) -> None:
         self.__init_handle_by_constructor__(_ffi_api.SeqStmt, seq, span)  # type: ignore
 
-    def __getitem__(self, i):
+    def __getitem__(self, i: int):
         return self.seq[i]
 
     def __len__(self):
@@ -468,14 +594,24 @@ class IfThenElse(Stmt):
     then_case : Stmt
         The statement to execute if condition is true.
 
-    else_case : Stmt
+    else_case : Optional[Stmt]
         The statement to execute if condition is false.
 
     span : Optional[Span]
-        The location of this itervar in the source code.
+        The location of the stmt in the source code.
     """
 
-    def __init__(self, condition, then_case, else_case, span=None):
+    condition: PrimExpr
+    then_case: Stmt
+    else_case: Optional[Stmt]
+
+    def __init__(
+        self,
+        condition: PrimExpr,
+        then_case: Stmt,
+        else_case: Optional[Stmt],
+        span: Optional[Span] = None,
+    ) -> None:
         self.__init_handle_by_constructor__(
             _ffi_api.IfThenElse, condition, then_case, else_case, span  # type: ignore
         )
@@ -488,13 +624,16 @@ class Evaluate(Stmt):
     Parameters
     ----------
     value : PrimExpr
-        The expression to be evalued.
+        The expression to be evaluated.
 
     span : Optional[Span]
-        The location of this itervar in the source code.
+        The location of the stmt in the source code.
     """
 
-    def __init__(self, value, span=None):
+    value: PrimExpr
+    span: Optional[Span]
+
+    def __init__(self, value: PrimExpr, span: Optional[Span] = None) -> None:
         self.__init_handle_by_constructor__(_ffi_api.Evaluate, value, span)  # type: ignore
 
 
@@ -507,14 +646,18 @@ class Prefetch(Stmt):
     buffer : Buffer
         The buffer to be prefetched.
 
-    bounds : list of Range
+    bounds : List[Range]
         The bounds to be prefetched.
 
     span : Optional[Span]
-        The location of this itervar in the source code.
+        The location of the stmt in the source code.
     """
 
-    def __init__(self, buffer, bounds, span=None):
+    buffer: Buffer
+    bounds: List[Range]
+    span: Optional[Span]
+
+    def __init__(self, buffer: Buffer, bounds: List[Range], span: Optional[Span] = None) -> None:
         self.__init_handle_by_constructor__(_ffi_api.Prefetch, buffer, bounds, span)  # type: ignore
 
 
@@ -534,7 +677,7 @@ class BufferRegion(Object, Scriptable):
     buffer: Buffer
     region: List[Range]
 
-    def __init__(self, buffer: Buffer, region: List[Range]):
+    def __init__(self, buffer: Buffer, region: List[Range]) -> None:
         self.__init_handle_by_constructor__(_ffi_api.BufferRegion, buffer, region)  # type: ignore
 
 
@@ -554,7 +697,7 @@ class MatchBufferRegion(Object, Scriptable):
     buffer: Buffer
     source: BufferRegion
 
-    def __init__(self, buffer: Buffer, source: BufferRegion):
+    def __init__(self, buffer: Buffer, source: BufferRegion) -> None:
         self.__init_handle_by_constructor__(
             _ffi_api.MatchBufferRegion, buffer, source  # type: ignore
         )
@@ -603,9 +746,9 @@ class Block(Stmt):
     name_hint: str
     body: Stmt
     init: Optional[Stmt]
-    alloc_buffers: Optional[List[Buffer]]
-    match_buffers: Optional[List[MatchBufferRegion]]
-    annotations: Optional[Mapping[str, Object]]
+    alloc_buffers: List[Buffer]
+    match_buffers: List[MatchBufferRegion]
+    annotations: Mapping[str, Object]
     span: Optional[Span]
 
     def __init__(
@@ -620,7 +763,7 @@ class Block(Stmt):
         match_buffers: Optional[List[MatchBufferRegion]] = None,
         annotations: Optional[Mapping[str, Object]] = None,
         span: Optional[Span] = None,
-    ):
+    ) -> None:
         if alloc_buffers is None:
             alloc_buffers = []
         if match_buffers is None:
@@ -672,7 +815,7 @@ class BlockRealize(Stmt):
         predicate: Union[PrimExpr, bool],
         block: Block,
         span: Optional[Span] = None,
-    ):
+    ) -> None:
         if isinstance(predicate, bool):
             predicate = const(predicate, "bool")
         self.__init_handle_by_constructor__(
@@ -684,12 +827,12 @@ class BlockRealize(Stmt):
         )  # type: ignore
 
 
-def stmt_seq(*args):
+def stmt_seq(*args: Union[PrimExpr, Stmt]) -> SeqStmt:
     """Make sequence of statements
 
     Parameters
     ----------
-    args : list of Expr or Var
+    *args : Union[PrimExpr, Stmt]
         List of statements to be combined as sequence.
 
     Returns
@@ -707,17 +850,18 @@ def stmt_seq(*args):
     return SeqStmt(ret)
 
 
-def stmt_list(stmt):
+def stmt_list(stmt: Stmt) -> List[Stmt]:
     """Make list of stmt from blocks.
 
     Parameters
     ----------
-    stmt : A block statement
+    stmt : Stmt
+        The input statement.
 
     Returns
     -------
-    stmt_list : list of Stmt
-         The unpacked list of statements
+    stmt_list : List[Stmt]
+        The unpacked list of statements
     """
     if isinstance(stmt, SeqStmt):
         res = []
diff --git a/tests/python/unittest/test_tir_schedule_analysis.py b/tests/python/unittest/test_tir_schedule_analysis.py
index c4fc49da9f..cc87818db4 100644
--- a/tests/python/unittest/test_tir_schedule_analysis.py
+++ b/tests/python/unittest/test_tir_schedule_analysis.py
@@ -60,7 +60,7 @@ def _make_loops(loop_vars: List[Var], extents: List[int]) -> List[For]:
     return [
         For(
             loop_var=loop_var,
-            min_val=0,
+            min=0,
             extent=extent,
             kind=ForKind.SERIAL,
             body=Evaluate(0),
diff --git a/tests/python/unittest/test_tir_schedule_state.py b/tests/python/unittest/test_tir_schedule_state.py
index db6909a048..74880e5a42 100644
--- a/tests/python/unittest/test_tir_schedule_state.py
+++ b/tests/python/unittest/test_tir_schedule_state.py
@@ -323,7 +323,7 @@ def test_replace_block_in_opaque_block():
     sref = s.get_sref(for_loop)
     new_for_loop = tir.For(
         loop_var=for_loop.loop_var,
-        min_val=0,
+        min=0,
         extent=128,
         kind=tir.ForKind.SERIAL,
         body=tir.Evaluate(0),
diff --git a/tests/python/unittest/test_tvmscript_ir_builder_tir.py b/tests/python/unittest/test_tvmscript_ir_builder_tir.py
index e13b609d86..5362dae303 100644
--- a/tests/python/unittest/test_tvmscript_ir_builder_tir.py
+++ b/tests/python/unittest/test_tvmscript_ir_builder_tir.py
@@ -226,7 +226,7 @@ def test_ir_builder_tir_for():
     # the expected for
     thread_binding_expected = tir.For(
         loop_var=tir.Var("", "int32"),
-        min_val=0,
+        min=0,
         extent=8,
         kind=tir.ForKind.THREAD_BINDING,
         body=tir.Evaluate(0),
@@ -236,28 +236,28 @@ def test_ir_builder_tir_for():
     )
     unroll_expected = tir.For(
         loop_var=tir.Var("", "int32"),
-        min_val=0,
+        min=0,
         extent=16,
         kind=tir.ForKind.UNROLLED,
         body=thread_binding_expected,
     )
     vectorized_expected = tir.For(
         loop_var=tir.Var("", "int32"),
-        min_val=0,
+        min=0,
         extent=32,
         kind=tir.ForKind.VECTORIZED,
         body=unroll_expected,
     )
     parallel_expected = tir.For(
         loop_var=tir.Var("", "int32"),
-        min_val=0,
+        min=0,
         extent=64,
         kind=tir.ForKind.PARALLEL,
         body=vectorized_expected,
     )
     for_expected = tir.For(
         loop_var=tir.Var("", "int32"),
-        min_val=0,
+        min=0,
         extent=128,
         kind=tir.ForKind.SERIAL,
         body=parallel_expected,
@@ -277,7 +277,7 @@ def test_ir_builder_tir_for_uint():
 
     for_expected = tir.For(
         loop_var=tir.Var("", "uint32"),
-        min_val=tir.const(0, "uint32"),
+        min=tir.const(0, "uint32"),
         extent=tir.const(128, "uint32"),
         kind=tir.ForKind.SERIAL,
         body=tir.Evaluate(0),

Reply via email to