This is an automated email from the ASF dual-hosted git repository.

syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 3f3473e2d5 [TIR] Enhance Python Type Annotations for TIR Expr (#16083)
3f3473e2d5 is described below

commit 3f3473e2d57ba5933fe0a24d39c2d6e67f2b45c0
Author: Siyuan Feng <[email protected]>
AuthorDate: Wed Nov 8 10:54:59 2023 +0800

    [TIR] Enhance Python Type Annotations for TIR Expr (#16083)
    
    This PR enhances the Python type annotations for TIR expressions,
    adding annotations for class member variables.
---
 include/tvm/ir/expr.h  |   4 +-
 include/tvm/tir/var.h  |   2 +-
 python/tvm/ir/expr.py  |  40 ++++--
 python/tvm/tir/expr.py | 379 +++++++++++++++++++++++++++++++------------------
 4 files changed, 274 insertions(+), 151 deletions(-)

diff --git a/include/tvm/ir/expr.h b/include/tvm/ir/expr.h
index c8531c8846..b8857e598d 100644
--- a/include/tvm/ir/expr.h
+++ b/include/tvm/ir/expr.h
@@ -711,7 +711,7 @@ class RangeNode : public Object {
   TVM_DECLARE_FINAL_OBJECT_INFO(RangeNode, Object);
 };
 
-/*! \brief Range constainer  */
+/*! \brief Range container  */
 class Range : public ObjectRef {
  public:
   /*!
@@ -736,7 +736,7 @@ class Range : public ObjectRef {
   TVM_DEFINE_OBJECT_REF_METHODS(Range, ObjectRef, RangeNode);
 };
 
-// implementataions
+// implementations
 inline const Type& RelayExprNode::checked_type() const {
   ICHECK(checked_type_.defined()) << "internal error: the type checker has "
                                   << "not populated the checked_type "
diff --git a/include/tvm/tir/var.h b/include/tvm/tir/var.h
index 9cd2bed657..6c2c6dd5fc 100644
--- a/include/tvm/tir/var.h
+++ b/include/tvm/tir/var.h
@@ -270,7 +270,7 @@ class IterVarNode : public Object {
   IterVarType iter_type;
   /*!
    * \brief additional tag on the iteration variable,
-   *  set this if this is binded already to a known thread tag.
+   *  set this if this is bound already to a known thread tag.
    */
   String thread_tag;
   /*!
diff --git a/python/tvm/ir/expr.py b/python/tvm/ir/expr.py
index f0f2245e7f..36c425cb85 100644
--- a/python/tvm/ir/expr.py
+++ b/python/tvm/ir/expr.py
@@ -16,17 +16,21 @@
 # under the License.
 """Common expressions data structures in the IR."""
 from numbers import Number
+from typing import Callable, Optional
 
 import tvm._ffi
 
-from ..runtime import Scriptable, const, convert
+from ..runtime import Object, Scriptable, const, convert
 from . import _ffi_api
-from .base import Node
+from .base import Node, Span
+from .type import Type
 
 
 class BaseExpr(Node):
     """Base class of all the expressions."""
 
+    span: Optional[Span]
+
 
 class PrimExpr(BaseExpr):
     """Base class of all primitive expressions.
@@ -35,6 +39,8 @@ class PrimExpr(BaseExpr):
     optimizations and integer analysis.
     """
 
+    dtype: str
+
 
 class RelayExpr(BaseExpr):
     """Base class of all non-primitive expressions."""
@@ -67,10 +73,12 @@ class GlobalVar(RelayExpr):
         The name of the variable.
     """
 
-    def __init__(self, name_hint, type_annot=None):
+    name_hint: str
+
+    def __init__(self, name_hint: str, type_annot: Optional[Type] = None):
         self.__init_handle_by_constructor__(_ffi_api.GlobalVar, name_hint, 
type_annot)
 
-    def __call__(self, *args):
+    def __call__(self, *args: RelayExpr) -> BaseExpr:
         """Call the global variable.
 
         Parameters
@@ -94,7 +102,9 @@ class GlobalVar(RelayExpr):
         arg_types = [type(x) for x in args]
         raise RuntimeError(f"Do not know how to handle GlobalVar.__call__ for 
types {arg_types}")
 
-    def astext(self, show_meta_data=True, annotate=None):
+    def astext(
+        self, show_meta_data: bool = True, annotate: 
Optional[Callable[[Object], str]] = None
+    ) -> str:
         """Get the text format of the expression.
 
         Parameters
@@ -140,7 +150,7 @@ class Range(Node, Scriptable):
         The end value of the range.
 
     span : Optional[Span]
-        The location of this itervar in the source code.
+        The location of this node in the source code.
 
     Note
     ----
@@ -148,14 +158,22 @@ class Range(Node, Scriptable):
     if the end argument is not None. Otherwise, it creates `[0, begin)`.
     """
 
-    def __init__(self, begin, end=None, span=None):
+    min: PrimExpr
+    extent: PrimExpr
+    span: Optional[Span]
+
+    def __init__(
+        self, begin: PrimExpr, end: Optional[PrimExpr] = None, span: 
Optional[Span] = None
+    ) -> None:
         if end is None:
             end = convert(begin)
             begin = const(0, dtype=end.dtype, span=span)
         self.__init_handle_by_constructor__(_ffi_api.Range, begin, end, span)
 
     @staticmethod
-    def from_min_extent(min_value, extent, span=None):
+    def from_min_extent(
+        min_value: PrimExpr, extent: PrimExpr, span: Optional[Span] = None
+    ) -> "Range":
         """Construct a Range by min and extent.
 
         This constructs a range in [min_value, min_value + extent)
@@ -169,7 +187,7 @@ class Range(Node, Scriptable):
             The extent of the range.
 
         span : Optional[Span]
-            The location of this itervar in the source code.
+            The location of this node in the source code.
 
         Returns
         -------
@@ -178,8 +196,8 @@ class Range(Node, Scriptable):
         """
         return _ffi_api.Range_from_min_extent(min_value, extent, span)
 
-    def __eq__(self, other):
+    def __eq__(self, other: Object) -> bool:
         return tvm.ir.structural_equal(self, other)
 
-    def __ne__(self, other):
+    def __ne__(self, other: Object) -> bool:
         return not self.__eq__(other)
diff --git a/python/tvm/tir/expr.py b/python/tvm/tir/expr.py
index f93e39ee0f..fad9fca083 100644
--- a/python/tvm/tir/expr.py
+++ b/python/tvm/tir/expr.py
@@ -27,7 +27,7 @@ For example, you can use addexp.a to get the left operand of 
an Add node.
   assert(isinstance(y, tvm.tir.Add))
   assert(y.a == x)
 """
-from typing import Optional, Union
+from typing import List, Optional, Union
 
 import tvm._ffi
 import tvm.ir._ffi_api
@@ -38,9 +38,10 @@ from tvm.runtime import DataType, DataTypeCode, Object, 
ObjectGeneric, Scriptabl
 
 from . import _ffi_api
 from . import generic as _generic
+from .buffer import Buffer, DataProducer
 
 
-def div_ambiguity_error():
+def div_ambiguity_error() -> RuntimeError:
     return RuntimeError(
         "TVM supports multiple types of integer divisions, "
         + "please call div, indexdiv/indexmod, floordiv/floormod "
@@ -69,111 +70,111 @@ class ExprOp(object):
 
     # TODO(tkonolige): use inspect to add source information to these objects
 
-    def __add__(self, other):
+    def __add__(self, other: PrimExpr) -> PrimExpr:
         return _generic.add(self, other)
 
-    def __radd__(self, other):
+    def __radd__(self, other: PrimExpr) -> PrimExpr:
         return _generic.add(other, self)
 
-    def __sub__(self, other):
+    def __sub__(self, other: PrimExpr) -> PrimExpr:
         return _generic.subtract(self, other)
 
-    def __rsub__(self, other):
+    def __rsub__(self, other: PrimExpr) -> PrimExpr:
         return _generic.subtract(other, self)
 
-    def __mul__(self, other):
+    def __mul__(self, other: PrimExpr) -> PrimExpr:
         return _generic.multiply(self, other)
 
-    def __rmul__(self, other):
+    def __rmul__(self, other: PrimExpr) -> PrimExpr:
         return _generic.multiply(other, self)
 
-    def __div__(self, other):
+    def __div__(self, other: PrimExpr) -> PrimExpr:
         if _dtype_is_int(self) and _dtype_is_int(other):
             raise div_ambiguity_error()
         return _generic.divide(self, other)
 
-    def __rdiv__(self, other):
+    def __rdiv__(self, other: PrimExpr) -> PrimExpr:
         if _dtype_is_int(self) and _dtype_is_int(other):
             raise div_ambiguity_error()
         return _generic.divide(other, self)
 
-    def __truediv__(self, other):
+    def __truediv__(self, other: PrimExpr) -> PrimExpr:
         if _dtype_is_int(self) and _dtype_is_int(other):
             raise div_ambiguity_error()
         return _generic.divide(self, other)
 
-    def __rtruediv__(self, other):
+    def __rtruediv__(self, other: PrimExpr) -> PrimExpr:
         if _dtype_is_int(self) and _dtype_is_int(other):
             raise div_ambiguity_error()
         return _generic.divide(other, self)
 
-    def __floordiv__(self, other):
+    def __floordiv__(self, other: PrimExpr) -> PrimExpr:
         return _generic.floordiv(self, other)
 
-    def __rfloordiv__(self, other):
+    def __rfloordiv__(self, other: PrimExpr) -> PrimExpr:
         return _generic.floordiv(other, self, None)
 
-    def __mod__(self, other):
+    def __mod__(self, other: PrimExpr) -> PrimExpr:
         return _ffi_api._OpFloorMod(self, other, None)  # type: ignore
 
-    def __rmod__(self, other):
+    def __rmod__(self, other: PrimExpr) -> PrimExpr:
         return _ffi_api._OpFloorMod(other, self, None)  # type: ignore
 
-    def __neg__(self):
+    def __neg__(self) -> PrimExpr:
         neg_one = const(-1, self.dtype)  # type: ignore
         return self.__mul__(neg_one)
 
-    def __lshift__(self, other):
+    def __lshift__(self, other: PrimExpr) -> PrimExpr:
         return _ffi_api.left_shift(self, other, None)  # type: ignore
 
-    def __rlshift__(self, other):
+    def __rlshift__(self, other: PrimExpr) -> PrimExpr:
         return _ffi_api.left_shift(other, self, None)  # type: ignore
 
-    def __rshift__(self, other):
+    def __rshift__(self, other: PrimExpr) -> PrimExpr:
         return _ffi_api.right_shift(self, other, None)  # type: ignore
 
-    def __rrshift__(self, other):
+    def __rrshift__(self, other: PrimExpr) -> PrimExpr:
         return _ffi_api.right_shift(other, self, None)  # type: ignore
 
-    def __and__(self, other):
+    def __and__(self, other: PrimExpr) -> PrimExpr:
         return _ffi_api.bitwise_and(self, other, None)  # type: ignore
 
-    def __rand__(self, other):
+    def __rand__(self, other: PrimExpr) -> PrimExpr:
         return _ffi_api.bitwise_and(other, self, None)  # type: ignore
 
-    def __or__(self, other):
+    def __or__(self, other: PrimExpr) -> PrimExpr:
         return _ffi_api.bitwise_or(self, other, None)  # type: ignore
 
-    def __ror__(self, other):
+    def __ror__(self, other: PrimExpr) -> PrimExpr:
         return _ffi_api.bitwise_or(other, self, None)  # type: ignore
 
-    def __xor__(self, other):
+    def __xor__(self, other: PrimExpr) -> PrimExpr:
         return _ffi_api.bitwise_xor(self, other, None)  # type: ignore
 
-    def __rxor__(self, other):
+    def __rxor__(self, other: PrimExpr) -> PrimExpr:
         return _ffi_api.bitwise_xor(other, self, None)  # type: ignore
 
-    def __invert__(self):
+    def __invert__(self) -> PrimExpr:
         if _dtype_is_float(self):
             raise RuntimeError("Cannot use ~ operator on float type Expr.")
         return _ffi_api.bitwise_not(self, None)  # type: ignore
 
-    def __lt__(self, other):
+    def __lt__(self, other: PrimExpr) -> PrimExpr:
         return _ffi_api._OpLT(self, other, None)  # type: ignore
 
-    def __le__(self, other):
+    def __le__(self, other: PrimExpr) -> PrimExpr:
         return _ffi_api._OpLE(self, other, None)  # type: ignore
 
-    def __eq__(self, other):
+    def __eq__(self, other: PrimExpr) -> PrimExpr:
         return EqualOp(self, other)
 
-    def __ne__(self, other):
+    def __ne__(self, other: PrimExpr) -> PrimExpr:
         return NotEqualOp(self, other)
 
-    def __gt__(self, other):
+    def __gt__(self, other: PrimExpr) -> PrimExpr:
         return _ffi_api._OpGT(self, other, None)  # type: ignore
 
-    def __ge__(self, other):
+    def __ge__(self, other: PrimExpr) -> PrimExpr:
         return _ffi_api._OpGE(self, other, None)  # type: ignore
 
     def __nonzero__(self):
@@ -182,10 +183,10 @@ class ExprOp(object):
             + "use tvm.tir.all / tvm.tir.any instead"
         )
 
-    def __bool__(self):
+    def __bool__(self) -> bool:
         return self.__nonzero__()
 
-    def equal(self, other, span=None):
+    def equal(self, other: PrimExpr, span: Optional[Span] = None) -> bool:
         """Build an equal check expression with other expr.
 
         Parameters
@@ -203,7 +204,7 @@ class ExprOp(object):
         """
         return _ffi_api._OpEQ(self, other, span)  # type: ignore
 
-    def astype(self, dtype: str, span: Optional[Span] = None):
+    def astype(self, dtype: str, span: Optional[Span] = None) -> PrimExpr:
         """Cast the expression to other type.
 
         Parameters
@@ -243,18 +244,18 @@ class EqualOp(ObjectGeneric, ExprOp):
     # This class is not manipulated by C++. So use python's identity check 
function is sufficient
     same_as = object.__eq__
 
-    def __init__(self, a, b, span=None):
+    def __init__(self, a: PrimExpr, b: PrimExpr, span: Optional[Span] = None):
         self.a = a
         self.b = b
         self.span = span
 
-    def __nonzero__(self):
+    def __nonzero__(self) -> bool:
         return self.a.same_as(self.b)
 
-    def __bool__(self):
+    def __bool__(self) -> bool:
         return self.__nonzero__()
 
-    def asobject(self):
+    def asobject(self) -> PrimExpr:
         """Convert object."""
         return _ffi_api._OpEQ(self.a, self.b, self.span)  # type: ignore
 
@@ -280,18 +281,18 @@ class NotEqualOp(ObjectGeneric, ExprOp):
     # This class is not manipulated by C++. So use python's identity check 
function is sufficient
     same_as = object.__eq__
 
-    def __init__(self, a, b, span=None):
+    def __init__(self, a: PrimExpr, b: PrimExpr, span: Optional[Span] = None) 
-> None:
         self.a = a
         self.b = b
         self.span = span
 
-    def __nonzero__(self):
+    def __nonzero__(self) -> bool:
         return not self.a.same_as(self.b)
 
-    def __bool__(self):
+    def __bool__(self) -> bool:
         return self.__nonzero__()
 
-    def asobject(self):
+    def asobject(self) -> PrimExpr:
         """Convert object."""
         return _ffi_api._OpNE(self.a, self.b, self.span)  # type: ignore
 
@@ -309,11 +310,11 @@ class IntImmEnum(ObjectGeneric):
         The location of the cast in the source.
     """
 
-    def __init__(self, value, span=None):
+    def __init__(self, value: int, span: Optional[Span] = None) -> None:
         self.value = value
         self.span = span
 
-    def asobject(self):
+    def asobject(self) -> "IntImm":
         """Convert object."""
         return IntImm("int32", self.value, self.span)  # type: ignore
 
@@ -331,11 +332,13 @@ class ConstExpr(PrimExprWithOp):
 
 
 class BinaryOpExpr(PrimExprWithOp):
-    pass
+    a: PrimExpr
+    b: PrimExpr
 
 
 class CmpExpr(PrimExprWithOp):
-    pass
+    a: PrimExpr
+    b: PrimExpr
 
 
 class LogicalExpr(PrimExprWithOp):
@@ -351,14 +354,17 @@ class Var(PrimExprWithOp):
     name : str
         The name
 
-    dtype : Union[str, tvm.irType]
+    dtype : Union[str, ir.Type]
         The data type
 
     span : Optional[Span]
-        The location of this itervar in the source code.
+        The location of this expression in the source code.
     """
 
-    def __init__(self, name: str, dtype: Union[str, ir.Type], span: 
Optional[Span] = None):
+    name_hint: str
+    type_annotation: ir.Type
+
+    def __init__(self, name: str, dtype: Union[str, ir.Type], span: 
Optional[Span] = None) -> None:
         self.__init_handle_by_constructor__(_ffi_api.Var, name, dtype, span)  
# type: ignore
 
 
@@ -372,15 +378,15 @@ class SizeVar(Var):
     name : str
         The name
 
-    dtype : int
+    dtype : Union[str, ir.Type]
         The data type
 
     span : Optional[Span]
-        The location of this itervar in the source code.
+        The location of this expression in the source code.
     """
 
     # pylint: disable=super-init-not-called
-    def __init__(self, name, dtype, span=None):
+    def __init__(self, name: str, dtype: Union[str, ir.Type], span: 
Optional[Span] = None) -> None:
         self.__init_handle_by_constructor__(_ffi_api.SizeVar, name, dtype, 
span)  # type: ignore
 
 
@@ -405,7 +411,7 @@ class IterVar(Object, ExprOp, Scriptable):
         The thread type tag.
 
     span : Optional[Span]
-        The location of this itervar in the source code.
+        The location of this expression in the source code.
 
     See Also
     --------
@@ -423,7 +429,19 @@ class IterVar(Object, ExprOp, Scriptable):
     Parallelized = 7
     Tensorized = 8
 
-    def __init__(self, dom, var, iter_type, thread_tag="", span=None):
+    dom: ir.Range
+    var: Var
+    iter_type: int
+    thread_tag: str
+
+    def __init__(
+        self,
+        dom: ir.Range,
+        var: Union[Var, str],
+        iter_type: int,
+        thread_tag: str = "",
+        span: Optional[Span] = None,
+    ) -> None:
         if dom is not None:
             if isinstance(dom, (list, tuple)):
                 if len(dom) != 2:
@@ -464,10 +482,22 @@ class CommReducer(Object, Scriptable):
        The identity elements.
 
     span : Optional[Span]
-        The location of this itervar in the source code.
+        The location of this expression in the source code.
     """
 
-    def __init__(self, lhs, rhs, result, identity_element, span=None):
+    lhs: List[Var]
+    rhs: List[Var]
+    result: List[PrimExpr]
+    identity_element: List[PrimExpr]
+
+    def __init__(
+        self,
+        lhs: List[Var],
+        rhs: List[Var],
+        result: List[PrimExpr],
+        identity_element: List[PrimExpr],
+        span: Optional[Span] = None,
+    ) -> None:
         self.__init_handle_by_constructor__(
             _ffi_api.CommReducer, lhs, rhs, result, identity_element, span  # 
type: ignore
         )
@@ -498,10 +528,27 @@ class Reduce(PrimExprWithOp):
         The initial value for output. This can be an int, float or ProducerLoad
 
     span : Optional[Span]
-        The location of this itervar in the source code.
+        The location of this expression in the source code.
     """
 
-    def __init__(self, combiner, src, rdom, condition, value_index, init=None, 
span=None):
+    combiner: CommReducer
+    source: List[PrimExpr]
+    init: List[PrimExpr]
+    axis: List[IterVar]
+    condition: PrimExpr
+    value_index: int
+
+    def __init__(
+        self,
+        combiner: CommReducer,
+        src: List[PrimExpr],
+        rdom: List[IterVar],
+        condition: PrimExpr,
+        value_index: int,
+        init: Optional[List[PrimExpr]] = None,
+        span: Optional[Span] = None,
+    ) -> None:
+        init = [] if init is None else init
         self.__init_handle_by_constructor__(
             _ffi_api.Reduce, combiner, src, rdom, condition, value_index, 
init, span  # type: ignore
         )
@@ -520,15 +567,17 @@ class FloatImm(ConstExpr):
         The constant value.
 
     span : Optional[Span]
-        The location of this itervar in the source code.
+        The location of this expression in the source code.
     """
 
-    def __init__(self, dtype, value, span=None):
+    value: float
+
+    def __init__(self, dtype: str, value: float, span: Optional[Span] = None) 
-> None:
         self.__init_handle_by_constructor__(
             tvm.ir._ffi_api.FloatImm, dtype, value, span  # type: ignore
         )
 
-    def __float__(self):
+    def __float__(self) -> float:
         return self.value
 
 
@@ -545,30 +594,32 @@ class IntImm(ConstExpr):
         The constant value.
 
     span : Optional[Span]
-        The location of this itervar in the source code.
+        The location of this expression in the source code.
     """
 
-    def __init__(self, dtype, value, span=None):
+    value: int
+
+    def __init__(self, dtype: str, value: int, span: Optional[Span] = None) -> 
None:
         self.__init_handle_by_constructor__(
             tvm.ir._ffi_api.IntImm, dtype, value, span  # type: ignore
         )
 
-    def __hash__(self):
+    def __hash__(self) -> int:
         return self.value
 
-    def __int__(self):
+    def __int__(self) -> int:
         return self.value
 
-    def __nonzero__(self):
+    def __nonzero__(self) -> bool:
         return self.value != 0
 
-    def __eq__(self, other):
+    def __eq__(self, other: PrimExpr) -> PrimExpr:
         return _ffi_api._OpEQ(self, other, None)  # type: ignore
 
-    def __ne__(self, other):
+    def __ne__(self, other: PrimExpr) -> PrimExpr:
         return _ffi_api._OpNE(self, other, None)  # type: ignore
 
-    def __bool__(self):
+    def __bool__(self) -> bool:
         return self.__nonzero__()
 
 
@@ -582,23 +633,25 @@ class StringImm(ConstExpr):
         The value of the function.
 
     span : Optional[Span]
-        The location of this itervar in the source code.
+        The location of this expression in the source code.
     """
 
-    def __init__(self, value, span=None):
+    value: str
+
+    def __init__(self, value: str, span: Optional[Span] = None) -> None:
         self.__init_handle_by_constructor__(_ffi_api.StringImm, value, span)  
# type: ignore
 
-    def __eq__(self, other):
+    def __eq__(self, other: PrimExpr) -> bool:
         if isinstance(other, ConstExpr):
             return self.value == other.value
         return self.value == other
 
-    def __ne__(self, other):
+    def __ne__(self, other: PrimExpr) -> bool:
         if isinstance(other, ConstExpr):
             return self.value != other.value
         return self.value != other
 
-    def __hash__(self):
+    def __hash__(self) -> int:
         return PrimExpr.__hash__(self)
 
 
@@ -615,10 +668,12 @@ class Cast(PrimExprWithOp):
         The value of the function.
 
     span : Optional[Span]
-        The location of this itervar in the source code.
+        The location of this expression in the source code.
     """
 
-    def __init__(self, dtype, value, span=None):
+    value: PrimExpr
+
+    def __init__(self, dtype, value, span: Optional[Span] = None) -> None:
         self.__init_handle_by_constructor__(_ffi_api.Cast, dtype, value, span) 
 # type: ignore
 
 
@@ -635,10 +690,10 @@ class Add(BinaryOpExpr):
         The right hand operand.
 
     span : Optional[Span]
-        The location of this itervar in the source code.
+        The location of this expression in the source code.
     """
 
-    def __init__(self, a, b, span=None):
+    def __init__(self, a: PrimExpr, b: PrimExpr, span: Optional[Span] = None) 
-> None:
         self.__init_handle_by_constructor__(_ffi_api.Add, a, b, span)  # type: 
ignore
 
 
@@ -655,10 +710,10 @@ class Sub(BinaryOpExpr):
         The right hand operand.
 
     span : Optional[Span]
-        The location of this itervar in the source code.
+        The location of this expression in the source code.
     """
 
-    def __init__(self, a, b, span=None):
+    def __init__(self, a: PrimExpr, b: PrimExpr, span: Optional[Span] = None) 
-> None:
         self.__init_handle_by_constructor__(_ffi_api.Sub, a, b, span)  # type: 
ignore
 
 
@@ -675,10 +730,10 @@ class Mul(BinaryOpExpr):
         The right hand operand.
 
     span : Optional[Span]
-        The location of this itervar in the source code.
+        The location of this expression in the source code.
     """
 
-    def __init__(self, a, b, span=None):
+    def __init__(self, a: PrimExpr, b: PrimExpr, span: Optional[Span] = None) 
-> None:
         self.__init_handle_by_constructor__(_ffi_api.Mul, a, b, span)  # type: 
ignore
 
 
@@ -695,10 +750,10 @@ class Div(BinaryOpExpr):
         The right hand operand.
 
     span : Optional[Span]
-        The location of this itervar in the source code.
+        The location of this expression in the source code.
     """
 
-    def __init__(self, a, b, span=None):
+    def __init__(self, a: PrimExpr, b: PrimExpr, span: Optional[Span] = None) 
-> None:
         self.__init_handle_by_constructor__(_ffi_api.Div, a, b, span)  # type: 
ignore
 
 
@@ -715,10 +770,10 @@ class Mod(BinaryOpExpr):
         The right hand operand.
 
     span : Optional[Span]
-        The location of this itervar in the source code.
+        The location of this expression in the source code.
     """
 
-    def __init__(self, a, b, span=None):
+    def __init__(self, a: PrimExpr, b: PrimExpr, span: Optional[Span] = None) 
-> None:
         self.__init_handle_by_constructor__(_ffi_api.Mod, a, b, span)  # type: 
ignore
 
 
@@ -735,10 +790,10 @@ class FloorDiv(BinaryOpExpr):
         The right hand operand.
 
     span : Optional[Span]
-        The location of this itervar in the source code.
+        The location of this expression in the source code.
     """
 
-    def __init__(self, a, b, span=None):
+    def __init__(self, a: PrimExpr, b: PrimExpr, span: Optional[Span] = None) 
-> None:
         self.__init_handle_by_constructor__(_ffi_api.FloorDiv, a, b, span)  # 
type: ignore
 
 
@@ -755,10 +810,10 @@ class FloorMod(BinaryOpExpr):
         The right hand operand.
 
     span : Optional[Span]
-        The location of this itervar in the source code.
+        The location of this expression in the source code.
     """
 
-    def __init__(self, a, b, span=None):
+    def __init__(self, a: PrimExpr, b: PrimExpr, span: Optional[Span] = None) 
-> None:
         self.__init_handle_by_constructor__(_ffi_api.FloorMod, a, b, span)  # 
type: ignore
 
 
@@ -775,10 +830,10 @@ class Min(BinaryOpExpr):
         The right hand operand.
 
     span : Optional[Span]
-        The location of this itervar in the source code.
+        The location of this expression in the source code.
     """
 
-    def __init__(self, a, b, span=None):
+    def __init__(self, a: PrimExpr, b: PrimExpr, span: Optional[Span] = None) 
-> None:
         self.__init_handle_by_constructor__(_ffi_api.Min, a, b, span)  # type: 
ignore
 
 
@@ -795,10 +850,10 @@ class Max(BinaryOpExpr):
         The right hand operand.
 
     span : Optional[Span]
-        The location of this itervar in the source code.
+        The location of this expression in the source code.
     """
 
-    def __init__(self, a, b, span=None):
+    def __init__(self, a: PrimExpr, b: PrimExpr, span: Optional[Span] = None) 
-> None:
         self.__init_handle_by_constructor__(_ffi_api.Max, a, b, span)  # type: 
ignore
 
 
@@ -815,10 +870,10 @@ class EQ(CmpExpr):
         The right hand operand.
 
     span : Optional[Span]
-        The location of this itervar in the source code.
+        The location of this expression in the source code.
     """
 
-    def __init__(self, a, b, span=None):
+    def __init__(self, a: PrimExpr, b: PrimExpr, span: Optional[Span] = None) 
-> None:
         self.__init_handle_by_constructor__(_ffi_api.EQ, a, b, span)  # type: 
ignore
 
 
@@ -835,10 +890,10 @@ class NE(CmpExpr):
         The right hand operand.
 
     span : Optional[Span]
-        The location of this itervar in the source code.
+        The location of this expression in the source code.
     """
 
-    def __init__(self, a, b, span=None):
+    def __init__(self, a: PrimExpr, b: PrimExpr, span: Optional[Span] = None) 
-> None:
         self.__init_handle_by_constructor__(_ffi_api.NE, a, b, span)  # type: 
ignore
 
 
@@ -855,10 +910,10 @@ class LT(CmpExpr):
         The right hand operand.
 
     span : Optional[Span]
-        The location of this itervar in the source code.
+        The location of this expression in the source code.
     """
 
-    def __init__(self, a, b, span=None):
+    def __init__(self, a: PrimExpr, b: PrimExpr, span: Optional[Span] = None) 
-> None:
         self.__init_handle_by_constructor__(_ffi_api.LT, a, b, span)  # type: 
ignore
 
 
@@ -875,10 +930,10 @@ class LE(CmpExpr):
         The right hand operand.
 
     span : Optional[Span]
-        The location of this itervar in the source code.
+        The location of this expression in the source code.
     """
 
-    def __init__(self, a, b, span=None):
+    def __init__(self, a: PrimExpr, b: PrimExpr, span: Optional[Span] = None) 
-> None:
         self.__init_handle_by_constructor__(_ffi_api.LE, a, b, span)  # type: 
ignore
 
 
@@ -895,10 +950,10 @@ class GT(CmpExpr):
         The right hand operand.
 
     span : Optional[Span]
-        The location of this itervar in the source code.
+        The location of this expression in the source code.
     """
 
-    def __init__(self, a, b, span=None):
+    def __init__(self, a: PrimExpr, b: PrimExpr, span: Optional[Span] = None) 
-> None:
         self.__init_handle_by_constructor__(_ffi_api.GT, a, b, span)  # type: 
ignore
 
 
@@ -915,10 +970,10 @@ class GE(CmpExpr):
         The right hand operand.
 
     span : Optional[Span]
-        The location of this itervar in the source code.
+        The location of this expression in the source code.
     """
 
-    def __init__(self, a, b, span=None):
+    def __init__(self, a: PrimExpr, b: PrimExpr, span: Optional[Span] = None) 
-> None:
         self.__init_handle_by_constructor__(_ffi_api.GE, a, b, span)  # type: 
ignore
 
 
@@ -935,10 +990,10 @@ class And(LogicalExpr):
         The right hand operand.
 
     span : Optional[Span]
-        The location of this itervar in the source code.
+        The location of this expression in the source code.
     """
 
-    def __init__(self, a, b, span=None):
+    def __init__(self, a: PrimExpr, b: PrimExpr, span: Optional[Span] = None) 
-> None:
         self.__init_handle_by_constructor__(_ffi_api.And, a, b, span)  # type: 
ignore
 
 
@@ -955,10 +1010,13 @@ class Or(LogicalExpr):
         The right hand operand.
 
     span : Optional[Span]
-        The location of this itervar in the source code.
+        The location of this expression in the source code.
     """
 
-    def __init__(self, a, b, span=None):
+    a: PrimExpr
+    b: PrimExpr
+
+    def __init__(self, a: PrimExpr, b: PrimExpr, span: Optional[Span] = None) 
-> None:
         self.__init_handle_by_constructor__(_ffi_api.Or, a, b, span)  # type: 
ignore
 
 
@@ -972,10 +1030,12 @@ class Not(LogicalExpr):
         The input value
 
     span : Optional[Span]
-        The location of this itervar in the source code.
+        The location of this expression in the source code.
     """
 
-    def __init__(self, a, span=None):
+    a: PrimExpr
+
+    def __init__(self, a: PrimExpr, span: Optional[Span] = None) -> None:
         self.__init_handle_by_constructor__(_ffi_api.Not, a, span)  # type: 
ignore
 
 
@@ -1002,10 +1062,20 @@ class Select(PrimExprWithOp):
         The value to take when condition is false.
 
     span : Optional[Span]
-        The location of this itervar in the source code.
+        The location of this expression in the source code.
     """
 
-    def __init__(self, condition, true_value, false_value, span=None):
+    condition: PrimExpr
+    true_value: PrimExpr
+    false_value: PrimExpr
+
+    def __init__(
+        self,
+        condition: PrimExpr,
+        true_value: PrimExpr,
+        false_value: PrimExpr,
+        span: Optional[Span] = None,
+    ) -> None:
         if isinstance(condition, bool):
             condition = IntImm("bool", condition)
         self.__init_handle_by_constructor__(
@@ -1026,10 +1096,15 @@ class BufferLoad(PrimExprWithOp):
         The buffer indices.
 
     span : Optional[Span]
-        The location of this itervar in the source code.
+        The location of this expression in the source code.
     """
 
-    def __init__(self, buffer, indices, span=None):
+    buffer: Buffer
+    indices: List[PrimExpr]
+
+    def __init__(
+        self, buffer: Buffer, indices: List[PrimExpr], span: Optional[Span] = 
None
+    ) -> None:
         self.__init_handle_by_constructor__(
             _ffi_api.BufferLoad, buffer, indices, span  # type: ignore
         )
@@ -1048,10 +1123,15 @@ class ProducerLoad(PrimExprWithOp):
         The buffer indices.
 
     span : Optional[Span]
-        The location of this itervar in the source code.
+        The location of this expression in the source code.
     """
 
-    def __init__(self, producer, indices, span=None):
+    producer: DataProducer
+    indices: List[PrimExpr]
+
+    def __init__(
+        self, producer: DataProducer, indices: List[PrimExpr], span: Optional[Span] = None
+    ) -> None:
         self.__init_handle_by_constructor__(
             _ffi_api.ProducerLoad, producer, indices, span  # type: ignore
         )
@@ -1073,10 +1153,16 @@ class Ramp(PrimExprWithOp):
         The lanes of the expression.
 
     span : Optional[Span]
-        The location of this itervar in the source code.
+        The location of this expression in the source code.
     """
 
-    def __init__(self, base, stride, lanes, span=None):
+    base: PrimExpr
+    stride: PrimExpr
+    lanes: int
+
+    def __init__(
+        self, base: PrimExpr, stride: PrimExpr, lanes: int, span: Optional[Span] = None
+    ) -> None:
         self.__init_handle_by_constructor__(
             _ffi_api.Ramp, base, stride, lanes, span  # type: ignore
         )
@@ -1095,10 +1181,13 @@ class Broadcast(PrimExprWithOp):
         The lanes of the expression.
 
     span : Optional[Span]
-        The location of this itervar in the source code.
+        The location of this expression in the source code.
     """
 
-    def __init__(self, value, lanes, span=None):
+    value: PrimExpr
+    lanes: int
+
+    def __init__(self, value: PrimExpr, lanes: int, span: Optional[Span] = None) -> None:
         self.__init_handle_by_constructor__(_ffi_api.Broadcast, value, lanes, span)  # type: ignore
 
 
@@ -1108,17 +1197,22 @@ class Shuffle(PrimExprWithOp):
 
     Parameters
     ----------
-    vectors : Array of Expr
+    vectors : List[PrimExpr]
         The vectors
 
-    indices : Array of indices
+    indices : List[PrimExpr]
         The indices
 
     span : Optional[Span]
-        The location of this itervar in the source code.
+        The location of this expression in the source code.
     """
 
-    def __init__(self, vectors, indices, span=None):
+    vectors: List[PrimExpr]
+    indices: List[PrimExpr]
+
+    def __init__(
+        self, vectors: List[PrimExpr], indices: List[PrimExpr], span: Optional[Span] = None
+    ) -> None:
         self.__init_handle_by_constructor__(
             _ffi_api.Shuffle, vectors, indices, span  # type: ignore
         )
@@ -1144,7 +1238,7 @@ class Call(PrimExprWithOp):
     dtype : str
         The return data type
 
-    op : Union[RelayExpr, str]
+    op : Union[Op, str]
         The function to be called, or the name
         to the global tvm.Op
 
@@ -1152,10 +1246,15 @@ class Call(PrimExprWithOp):
         The input arguments to the call
 
     span : Optional[Span]
-        The location of this itervar in the source code.
+        The location of this expression in the source code.
     """
 
-    def __init__(self, dtype, op, args, span=None):
+    op: Op
+    args: List[PrimExpr]
+
+    def __init__(
+        self, dtype: str, op: Union[Op, str], args: List[PrimExpr], span: Optional[Span] = None
+    ) -> None:
         if isinstance(op, str):
             if not op.startswith("tir."):
                 raise ValueError(
@@ -1180,16 +1279,22 @@ class Let(PrimExprWithOp):
         The variable in the binding.
 
     value : PrimExpr
-        The value in to be binded.
+        The value to be bound.
 
     body : PrimExpr
         The body expression.
 
     span : Optional[Span]
-        The location of this itervar in the source code.
+        The location of this expression in the source code.
     """
 
-    def __init__(self, var, value, body, span=None):
+    var: Var
+    value: PrimExpr
+    body: PrimExpr
+
+    def __init__(
+        self, var: Var, value: PrimExpr, body: PrimExpr, span: Optional[Span] = None
+    ) -> None:
         self.__init_handle_by_constructor__(_ffi_api.Let, var, value, body, span)  # type: ignore
 
 
@@ -1198,8 +1303,8 @@ class Any(PrimExprWithOp):
     """Any node.
 
     span : Optional[Span]
-        The location of this itervar in the source code.
+        The location of this expression in the source code.
     """
 
-    def __init__(self, span=None):
+    def __init__(self, span: Optional[Span] = None) -> None:
         self.__init_handle_by_constructor__(_ffi_api.Any, span)  # type: ignore


Reply via email to