This is an automated email from the ASF dual-hosted git repository.
wuwei pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new a302e0fd63 [TIR] Enhance Python Type Annotations for TIR stmt (#16076)
a302e0fd63 is described below
commit a302e0fd63c7f11bd7ce0fe709b7cc526548ae1c
Author: Siyuan Feng <[email protected]>
AuthorDate: Tue Nov 7 02:15:42 2023 +0800
[TIR] Enhance Python Type Annotations for TIR stmt (#16076)
This PR enhances the Python type annotations for the TIR statement nodes,
adding annotations for class member variables.
---
include/tvm/tir/stmt.h | 8 +-
python/tvm/tir/stmt.py | 298 +++++++++++++++------
.../python/unittest/test_tir_schedule_analysis.py | 2 +-
tests/python/unittest/test_tir_schedule_state.py | 2 +-
.../unittest/test_tvmscript_ir_builder_tir.py | 12 +-
5 files changed, 233 insertions(+), 89 deletions(-)
diff --git a/include/tvm/tir/stmt.h b/include/tvm/tir/stmt.h
index bddf87101f..07cc9b5ad0 100644
--- a/include/tvm/tir/stmt.h
+++ b/include/tvm/tir/stmt.h
@@ -68,7 +68,7 @@ class LetStmtNode : public StmtNode {
public:
/*! \brief The variable. */
Var var;
- /*! \brief The value to be binded. */
+ /*! \brief The value to be bound. */
PrimExpr value;
/*! \brief The body block. */
Stmt body;
@@ -876,7 +876,7 @@ class SeqStmt : public Stmt {
};
/*!
- * \brief IfThenElse statment.
+ * \brief IfThenElse statement.
*/
class IfThenElseNode : public StmtNode {
public:
@@ -951,7 +951,7 @@ enum class ForKind : int {
};
/*!
- * \brief A for loop, with poissible type annotations.
+ * \brief A for loop, with possible type annotations.
*
* \code
*
@@ -1388,7 +1388,7 @@ namespace attr {
constexpr const char* thread_extent = "thread_extent";
/*! \brief Mark launching of a virtual thread. */
constexpr const char* virtual_thread = "virtual_thread";
-/*! \brief Mark region is processed by a co-proccesor */
+/*! \brief Mark region is processed by a co-processor */
constexpr const char* coproc_scope = "coproc_scope";
/*!
* \brief Mark region creates coprocessor micro ops,
diff --git a/python/tvm/tir/stmt.py b/python/tvm/tir/stmt.py
index 484a6f04b6..992c388e27 100644
--- a/python/tvm/tir/stmt.py
+++ b/python/tvm/tir/stmt.py
@@ -14,7 +14,6 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-# pylint: disable=invalid-name
"""Statement AST Node in TVM.
Each statement node have subfields that can be visited from python side.
@@ -32,11 +31,11 @@ from typing import List, Mapping, Optional, Union
import tvm._ffi
from tvm.ir import PrimExpr, Range, Span
-from tvm.runtime import Object, Scriptable, const
+from tvm.runtime import Object, Scriptable, const, NDArray
from . import _ffi_api
-from .buffer import Buffer
-from .expr import IterVar
+from .buffer import Buffer, DataProducer
+from .expr import Var, IterVar
class Stmt(Object, Scriptable):
@@ -53,16 +52,21 @@ class LetStmt(Stmt):
The variable in the binding.
value : PrimExpr
- The value in to be binded.
+ The value to be bound.
body : Stmt
The body statement.
span : Optional[Span]
- The location of this itervar in the source code.
+ The location of the stmt in the source code.
"""
- def __init__(self, var, value, body, span=None):
+ var: Var
+ value: PrimExpr
+ body: Stmt
+ span: Optional[Span]
+
+ def __init__(self, var: Var, value: PrimExpr, body: Stmt, span:
Optional[Span] = None) -> None:
self.__init_handle_by_constructor__(
_ffi_api.LetStmt, var, value, body, span # type: ignore
)
@@ -84,10 +88,17 @@ class AssertStmt(Stmt):
The body statement.
span : Optional[Span]
- The location of this itervar in the source code.
+ The location of the stmt in the source code.
"""
- def __init__(self, condition, message, body, span=None):
+ condition: PrimExpr
+ message: PrimExpr
+ body: Stmt
+ span: Optional[Span]
+
+ def __init__(
+ self, condition: PrimExpr, message: PrimExpr, body: Stmt, span:
Optional[Span] = None
+ ) -> None:
self.__init_handle_by_constructor__(
_ffi_api.AssertStmt, condition, message, body, span # type: ignore
)
@@ -106,7 +117,7 @@ class ForKind(IntEnum):
PARALLEL = 1
VECTORIZED = 2
UNROLLED = 3
- THREAD_BINDING = 4
+ THREAD_BINDING = 4 # pylint: disable=invalid-name
@tvm._ffi.register_object("tir.For")
@@ -118,7 +129,7 @@ class For(Stmt):
loop_var : Var
The loop variable.
- min_val : PrimExpr
+ min : PrimExpr
The beginning value.
extent : PrimExpr
@@ -134,28 +145,37 @@ class For(Stmt):
The thread this loop binds to. Only valid
if kind is ThreadBinding
- annotations: tvm.ir.Map
+ annotations: Optional[Mapping[str, Object]]
Additional annotation hints.
span : Optional[Span]
- The location of this itervar in the source code.
+ The location of the stmt in the source code.
"""
+ loop_var: Var
+ min: PrimExpr
+ extent: PrimExpr
+ kind: ForKind
+ body: Stmt
+ thread_binding: Optional[IterVar]
+ annotations: Mapping[str, Object]
+ span: Optional[Span]
+
def __init__(
self,
- loop_var,
- min_val,
- extent,
- kind,
- body,
- thread_binding=None,
- annotations=None,
- span=None,
- ):
+ loop_var: Var,
+ min: PrimExpr, # pylint: disable=redefined-builtin
+ extent: PrimExpr,
+ kind: ForKind,
+ body: Stmt,
+ thread_binding: Optional[IterVar] = None,
+ annotations: Optional[Mapping[str, Object]] = None,
+ span: Optional[Span] = None,
+ ) -> None:
self.__init_handle_by_constructor__(
_ffi_api.For, # type: ignore
loop_var,
- min_val,
+ min,
extent,
kind,
body,
@@ -178,16 +198,15 @@ class While(Stmt):
The body statement.
span : Optional[Span]
- The location of this itervar in the source code.
+ The location of the stmt in the source code.
"""
- def __init__(self, condition, body, span=None):
- self.__init_handle_by_constructor__(
- _ffi_api.While, # type: ignore
- condition,
- body,
- span,
- )
+ condition: PrimExpr
+ body: Stmt
+ span: Optional[Span]
+
+ def __init__(self, condition: PrimExpr, body: Stmt, span: Optional[Span] =
None) -> None:
+ self.__init_handle_by_constructor__(_ffi_api.While, condition, body,
span) # type: ignore
@tvm._ffi.register_object("tir.BufferStore")
@@ -206,10 +225,21 @@ class BufferStore(Stmt):
The indices location to be stored.
span : Optional[Span]
- The location of this itervar in the source code.
+ The location of the stmt in the source code.
"""
- def __init__(self, buffer, value, indices, span=None):
+ buffer: Buffer
+ value: PrimExpr
+ indices: List[PrimExpr]
+ span: Optional[Span]
+
+ def __init__(
+ self,
+ buffer: Buffer,
+ value: PrimExpr,
+ indices: List[PrimExpr],
+ span: Optional[Span] = None,
+ ) -> None:
self.__init_handle_by_constructor__(
_ffi_api.BufferStore, buffer, value, indices, span # type: ignore
)
@@ -234,10 +264,23 @@ class BufferRealize(Stmt):
The body of the statement.
span : Optional[Span]
- The location of this itervar in the source code.
+ The location of the stmt in the source code.
"""
- def __init__(self, buffer, bounds, condition, body, span=None):
+ buffer: Buffer
+ bounds: List[Range]
+ condition: PrimExpr
+ body: Stmt
+ span: Optional[Span]
+
+ def __init__(
+ self,
+ buffer: Buffer,
+ bounds: List[Range],
+ condition: PrimExpr,
+ body: Stmt,
+ span: Optional[Span] = None,
+ ) -> None:
self.__init_handle_by_constructor__(
_ffi_api.BufferRealize, buffer, bounds, condition, body, span #
type: ignore
)
@@ -259,10 +302,21 @@ class ProducerStore(Stmt):
The index arguments of the store.
span : Optional[Span]
- The location of this itervar in the source code.
+ The location of the stmt in the source code.
"""
- def __init__(self, producer, value, indices, span=None):
+ producer: DataProducer
+ value: PrimExpr
+ indices: List[PrimExpr]
+ span: Optional[Span]
+
+ def __init__(
+ self,
+ producer: DataProducer,
+ value: PrimExpr,
+ indices: List[PrimExpr],
+ span: Optional[Span] = None,
+ ) -> None:
self.__init_handle_by_constructor__(
_ffi_api.ProducerStore, producer, value, indices, span # type:
ignore
)
@@ -293,10 +347,27 @@ class Allocate(Stmt):
Additional annotation hints
span : Optional[Span]
- The location of this itervar in the source code.
+ The location of the stmt in the source code.
"""
- def __init__(self, buffer_var, dtype, extents, condition, body,
annotations=None, span=None):
+ buffer_var: Var
+ dtype: str
+ extents: List[PrimExpr]
+ condition: PrimExpr
+ body: Stmt
+ annotations: Mapping[str, Object]
+ span: Optional[Span]
+
+ def __init__(
+ self,
+ buffer_var: Var,
+ dtype: str,
+ extents: List[PrimExpr],
+ condition: PrimExpr,
+ body: Stmt,
+ annotations: Optional[Mapping[str, Object]] = None,
+ span: Optional[Span] = None,
+ ) -> None:
if annotations is None:
annotations = dict()
self.__init_handle_by_constructor__(
@@ -335,16 +406,41 @@ class AllocateConst(Stmt):
body : Stmt
The body statement.
- annotations : Optional[Map]
+ annotations : Optional[Mapping[str, Object]]
Additional annotations about the allocation.
span : Optional[Span]
- The location of this itervar in the source code.
+ The location of the stmt in the source code.
"""
- def __init__(self, buffer_var, dtype, extents, data_or_idx, body,
annotations=None, span=None):
+ buffer_var: Var
+ dtype: str
+ extents: List[PrimExpr]
+ data: Optional[NDArray]
+ irmod_storage_idx: Optional[int]
+ body: Stmt
+ annotations: Mapping[str, Object]
+ span: Optional[Span]
+
+ def __init__(
+ self,
+ buffer_var: Var,
+ dtype: str,
+ extents: List[PrimExpr],
+ data_or_idx: Union[NDArray, int],
+ body: Stmt,
+ annotations: Optional[Mapping[str, Object]] = None,
+ span: Optional[Span] = None,
+ ) -> None:
self.__init_handle_by_constructor__(
- _ffi_api.AllocateConst, buffer_var, dtype, extents, data_or_idx,
body, annotations, span
+ _ffi_api.AllocateConst, # type: ignore
+ buffer_var,
+ dtype,
+ extents,
+ data_or_idx,
+ body,
+ annotations,
+ span,
)
@@ -364,7 +460,11 @@ class DeclBuffer(Stmt):
The location of this DeclBuffer in the source code.
"""
- def __init__(self, buffer, body, span=None):
+ buffer: Buffer
+ body: Stmt
+ span: Optional[Span]
+
+ def __init__(self, buffer: Buffer, body: Stmt, span: Optional[Span] =
None) -> None:
self.__init_handle_by_constructor__(_ffi_api.DeclBuffer, buffer, body,
span)
@@ -374,7 +474,7 @@ class AttrStmt(Stmt):
Parameters
----------
- node : Node
+ node : Object
The node to annotate the attribute
attr_key : str
@@ -387,10 +487,18 @@ class AttrStmt(Stmt):
The body statement.
span : Optional[Span]
- The location of this itervar in the source code.
+ The location of the stmt in the source code.
"""
- def __init__(self, node, attr_key, value, body, span=None):
+ node: Object
+ attr_key: str
+ value: PrimExpr
+ body: Stmt
+ span: Optional[Span]
+
+ def __init__(
+ self, node: Object, attr_key: str, value: PrimExpr, body: Stmt, span:
Optional[Span] = None
+ ) -> None:
self.__init_handle_by_constructor__(
_ffi_api.AttrStmt, node, attr_key, value, body, span # type:
ignore
)
@@ -405,7 +513,7 @@ class ProducerRealize(Stmt):
producer : DataProducer
The data producer.
- bounds : list of range
+ bounds : List[Range]
The bound of realize
condition : PrimExpr
@@ -418,18 +526,33 @@ class ProducerRealize(Stmt):
The storage scope associated with this realization
span : Optional[Span]
- The location of this itervar in the source code.
+ The location of the stmt in the source code.
"""
- def __init__(self, producer, bounds, condition, body, storage_scope="",
span=None):
+ producer: DataProducer
+ bounds: List[Range]
+ condition: PrimExpr
+ body: Stmt
+ storage_scope: str
+ span: Optional[Span]
+
+ def __init__(
+ self,
+ producer: DataProducer,
+ bounds: List[Range],
+ condition: PrimExpr,
+ body: Stmt,
+ storage_scope: str = "",
+ span: Optional[Span] = None,
+ ) -> None:
self.__init_handle_by_constructor__(
- _ffi_api.ProducerRealize,
+ _ffi_api.ProducerRealize, # type: ignore
producer,
bounds,
condition,
body,
storage_scope,
- span, # type: ignore
+ span,
)
@@ -443,13 +566,16 @@ class SeqStmt(Stmt):
The statements
span : Optional[Span]
- The location of this itervar in the source code.
+ The location of the stmt in the source code.
"""
- def __init__(self, seq, span=None):
+ seq: List[Stmt]
+ span: Optional[Span]
+
+ def __init__(self, seq: List[Stmt], span: Optional[Span] = None) -> None:
self.__init_handle_by_constructor__(_ffi_api.SeqStmt, seq, span) #
type: ignore
- def __getitem__(self, i):
+ def __getitem__(self, i: int):
return self.seq[i]
def __len__(self):
@@ -468,14 +594,24 @@ class IfThenElse(Stmt):
then_case : Stmt
The statement to execute if condition is true.
- else_case : Stmt
+ else_case : Optional[Stmt]
The statement to execute if condition is false.
span : Optional[Span]
- The location of this itervar in the source code.
+ The location of the stmt in the source code.
"""
- def __init__(self, condition, then_case, else_case, span=None):
+ condition: PrimExpr
+ then_case: Stmt
+ else_case: Optional[Stmt]
+
+ def __init__(
+ self,
+ condition: PrimExpr,
+ then_case: Stmt,
+ else_case: Optional[Stmt],
+ span: Optional[Span] = None,
+ ) -> None:
self.__init_handle_by_constructor__(
_ffi_api.IfThenElse, condition, then_case, else_case, span #
type: ignore
)
@@ -488,13 +624,16 @@ class Evaluate(Stmt):
Parameters
----------
value : PrimExpr
- The expression to be evalued.
+ The expression to be evaluated.
span : Optional[Span]
- The location of this itervar in the source code.
+ The location of the stmt in the source code.
"""
- def __init__(self, value, span=None):
+ value: PrimExpr
+ span: Optional[Span]
+
+ def __init__(self, value: PrimExpr, span: Optional[Span] = None) -> None:
self.__init_handle_by_constructor__(_ffi_api.Evaluate, value, span) #
type: ignore
@@ -507,14 +646,18 @@ class Prefetch(Stmt):
buffer : Buffer
The buffer to be prefetched.
- bounds : list of Range
+ bounds : List[Range]
The bounds to be prefetched.
span : Optional[Span]
- The location of this itervar in the source code.
+ The location of the stmt in the source code.
"""
- def __init__(self, buffer, bounds, span=None):
+ buffer: Buffer
+ bounds: List[Range]
+ span: Optional[Span]
+
+ def __init__(self, buffer: Buffer, bounds: List[Range], span:
Optional[Span] = None) -> None:
self.__init_handle_by_constructor__(_ffi_api.Prefetch, buffer, bounds,
span) # type: ignore
@@ -534,7 +677,7 @@ class BufferRegion(Object, Scriptable):
buffer: Buffer
region: List[Range]
- def __init__(self, buffer: Buffer, region: List[Range]):
+ def __init__(self, buffer: Buffer, region: List[Range]) -> None:
self.__init_handle_by_constructor__(_ffi_api.BufferRegion, buffer,
region) # type: ignore
@@ -554,7 +697,7 @@ class MatchBufferRegion(Object, Scriptable):
buffer: Buffer
source: BufferRegion
- def __init__(self, buffer: Buffer, source: BufferRegion):
+ def __init__(self, buffer: Buffer, source: BufferRegion) -> None:
self.__init_handle_by_constructor__(
_ffi_api.MatchBufferRegion, buffer, source # type: ignore
)
@@ -603,9 +746,9 @@ class Block(Stmt):
name_hint: str
body: Stmt
init: Optional[Stmt]
- alloc_buffers: Optional[List[Buffer]]
- match_buffers: Optional[List[MatchBufferRegion]]
- annotations: Optional[Mapping[str, Object]]
+ alloc_buffers: List[Buffer]
+ match_buffers: List[MatchBufferRegion]
+ annotations: Mapping[str, Object]
span: Optional[Span]
def __init__(
@@ -620,7 +763,7 @@ class Block(Stmt):
match_buffers: Optional[List[MatchBufferRegion]] = None,
annotations: Optional[Mapping[str, Object]] = None,
span: Optional[Span] = None,
- ):
+ ) -> None:
if alloc_buffers is None:
alloc_buffers = []
if match_buffers is None:
@@ -672,7 +815,7 @@ class BlockRealize(Stmt):
predicate: Union[PrimExpr, bool],
block: Block,
span: Optional[Span] = None,
- ):
+ ) -> None:
if isinstance(predicate, bool):
predicate = const(predicate, "bool")
self.__init_handle_by_constructor__(
@@ -684,12 +827,12 @@ class BlockRealize(Stmt):
) # type: ignore
-def stmt_seq(*args):
+def stmt_seq(*args: Union[PrimExpr, Stmt]) -> SeqStmt:
"""Make sequence of statements
Parameters
----------
- args : list of Expr or Var
+ *args : Union[PrimExpr, Stmt]
List of statements to be combined as sequence.
Returns
@@ -707,17 +850,18 @@ def stmt_seq(*args):
return SeqStmt(ret)
-def stmt_list(stmt):
+def stmt_list(stmt: Stmt) -> List[Stmt]:
"""Make list of stmt from blocks.
Parameters
----------
- stmt : A block statement
+ stmt : Stmt
+ The input statement.
Returns
-------
- stmt_list : list of Stmt
- The unpacked list of statements
+ stmt_list : List[Stmt]
+ The unpacked list of statements
"""
if isinstance(stmt, SeqStmt):
res = []
diff --git a/tests/python/unittest/test_tir_schedule_analysis.py
b/tests/python/unittest/test_tir_schedule_analysis.py
index c4fc49da9f..cc87818db4 100644
--- a/tests/python/unittest/test_tir_schedule_analysis.py
+++ b/tests/python/unittest/test_tir_schedule_analysis.py
@@ -60,7 +60,7 @@ def _make_loops(loop_vars: List[Var], extents: List[int]) ->
List[For]:
return [
For(
loop_var=loop_var,
- min_val=0,
+ min=0,
extent=extent,
kind=ForKind.SERIAL,
body=Evaluate(0),
diff --git a/tests/python/unittest/test_tir_schedule_state.py
b/tests/python/unittest/test_tir_schedule_state.py
index db6909a048..74880e5a42 100644
--- a/tests/python/unittest/test_tir_schedule_state.py
+++ b/tests/python/unittest/test_tir_schedule_state.py
@@ -323,7 +323,7 @@ def test_replace_block_in_opaque_block():
sref = s.get_sref(for_loop)
new_for_loop = tir.For(
loop_var=for_loop.loop_var,
- min_val=0,
+ min=0,
extent=128,
kind=tir.ForKind.SERIAL,
body=tir.Evaluate(0),
diff --git a/tests/python/unittest/test_tvmscript_ir_builder_tir.py
b/tests/python/unittest/test_tvmscript_ir_builder_tir.py
index e13b609d86..5362dae303 100644
--- a/tests/python/unittest/test_tvmscript_ir_builder_tir.py
+++ b/tests/python/unittest/test_tvmscript_ir_builder_tir.py
@@ -226,7 +226,7 @@ def test_ir_builder_tir_for():
# the expected for
thread_binding_expected = tir.For(
loop_var=tir.Var("", "int32"),
- min_val=0,
+ min=0,
extent=8,
kind=tir.ForKind.THREAD_BINDING,
body=tir.Evaluate(0),
@@ -236,28 +236,28 @@ def test_ir_builder_tir_for():
)
unroll_expected = tir.For(
loop_var=tir.Var("", "int32"),
- min_val=0,
+ min=0,
extent=16,
kind=tir.ForKind.UNROLLED,
body=thread_binding_expected,
)
vectorized_expected = tir.For(
loop_var=tir.Var("", "int32"),
- min_val=0,
+ min=0,
extent=32,
kind=tir.ForKind.VECTORIZED,
body=unroll_expected,
)
parallel_expected = tir.For(
loop_var=tir.Var("", "int32"),
- min_val=0,
+ min=0,
extent=64,
kind=tir.ForKind.PARALLEL,
body=vectorized_expected,
)
for_expected = tir.For(
loop_var=tir.Var("", "int32"),
- min_val=0,
+ min=0,
extent=128,
kind=tir.ForKind.SERIAL,
body=parallel_expected,
@@ -277,7 +277,7 @@ def test_ir_builder_tir_for_uint():
for_expected = tir.For(
loop_var=tir.Var("", "uint32"),
- min_val=tir.const(0, "uint32"),
+ min=tir.const(0, "uint32"),
extent=tir.const(128, "uint32"),
kind=tir.ForKind.SERIAL,
body=tir.Evaluate(0),