This is an automated email from the ASF dual-hosted git repository.
syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 3f3473e2d5 [TIR] Enhance Python Type Annotations for TIR Expr (#16083)
3f3473e2d5 is described below
commit 3f3473e2d57ba5933fe0a24d39c2d6e67f2b45c0
Author: Siyuan Feng <[email protected]>
AuthorDate: Wed Nov 8 10:54:59 2023 +0800
[TIR] Enhance Python Type Annotations for TIR Expr (#16083)
This PR enhances the Python annotations for the TIR expr,
adding annotations for class member variables.
---
include/tvm/ir/expr.h | 4 +-
include/tvm/tir/var.h | 2 +-
python/tvm/ir/expr.py | 40 ++++--
python/tvm/tir/expr.py | 379 +++++++++++++++++++++++++++++++------------------
4 files changed, 274 insertions(+), 151 deletions(-)
diff --git a/include/tvm/ir/expr.h b/include/tvm/ir/expr.h
index c8531c8846..b8857e598d 100644
--- a/include/tvm/ir/expr.h
+++ b/include/tvm/ir/expr.h
@@ -711,7 +711,7 @@ class RangeNode : public Object {
TVM_DECLARE_FINAL_OBJECT_INFO(RangeNode, Object);
};
-/*! \brief Range constainer */
+/*! \brief Range container */
class Range : public ObjectRef {
public:
/*!
@@ -736,7 +736,7 @@ class Range : public ObjectRef {
TVM_DEFINE_OBJECT_REF_METHODS(Range, ObjectRef, RangeNode);
};
-// implementataions
+// implementations
inline const Type& RelayExprNode::checked_type() const {
ICHECK(checked_type_.defined()) << "internal error: the type checker has "
<< "not populated the checked_type "
diff --git a/include/tvm/tir/var.h b/include/tvm/tir/var.h
index 9cd2bed657..6c2c6dd5fc 100644
--- a/include/tvm/tir/var.h
+++ b/include/tvm/tir/var.h
@@ -270,7 +270,7 @@ class IterVarNode : public Object {
IterVarType iter_type;
/*!
* \brief additional tag on the iteration variable,
- * set this if this is binded already to a known thread tag.
+ * set this if this is bound already to a known thread tag.
*/
String thread_tag;
/*!
diff --git a/python/tvm/ir/expr.py b/python/tvm/ir/expr.py
index f0f2245e7f..36c425cb85 100644
--- a/python/tvm/ir/expr.py
+++ b/python/tvm/ir/expr.py
@@ -16,17 +16,21 @@
# under the License.
"""Common expressions data structures in the IR."""
from numbers import Number
+from typing import Callable, Optional
import tvm._ffi
-from ..runtime import Scriptable, const, convert
+from ..runtime import Object, Scriptable, const, convert
from . import _ffi_api
-from .base import Node
+from .base import Node, Span
+from .type import Type
class BaseExpr(Node):
"""Base class of all the expressions."""
+ span: Optional[Span]
+
class PrimExpr(BaseExpr):
"""Base class of all primitive expressions.
@@ -35,6 +39,8 @@ class PrimExpr(BaseExpr):
optimizations and integer analysis.
"""
+ dtype: str
+
class RelayExpr(BaseExpr):
"""Base class of all non-primitive expressions."""
@@ -67,10 +73,12 @@ class GlobalVar(RelayExpr):
The name of the variable.
"""
- def __init__(self, name_hint, type_annot=None):
+ name_hint: str
+
+ def __init__(self, name_hint: str, type_annot: Optional[Type] = None):
self.__init_handle_by_constructor__(_ffi_api.GlobalVar, name_hint, type_annot)
- def __call__(self, *args):
+ def __call__(self, *args: RelayExpr) -> BaseExpr:
"""Call the global variable.
Parameters
@@ -94,7 +102,9 @@ class GlobalVar(RelayExpr):
arg_types = [type(x) for x in args]
raise RuntimeError(f"Do not know how to handle GlobalVar.__call__ for types {arg_types}")
- def astext(self, show_meta_data=True, annotate=None):
+ def astext(
+ self, show_meta_data: bool = True, annotate: Optional[Callable[[Object], str]] = None
+ ) -> str:
"""Get the text format of the expression.
Parameters
@@ -140,7 +150,7 @@ class Range(Node, Scriptable):
The end value of the range.
span : Optional[Span]
- The location of this itervar in the source code.
+ The location of this node in the source code.
Note
----
@@ -148,14 +158,22 @@ class Range(Node, Scriptable):
if the end argument is not None. Otherwise, it creates `[0, begin)`.
"""
- def __init__(self, begin, end=None, span=None):
+ min: PrimExpr
+ extent: PrimExpr
+ span: Optional[Span]
+
+ def __init__(
+ self, begin: PrimExpr, end: Optional[PrimExpr] = None, span: Optional[Span] = None
+ ) -> None:
if end is None:
end = convert(begin)
begin = const(0, dtype=end.dtype, span=span)
self.__init_handle_by_constructor__(_ffi_api.Range, begin, end, span)
@staticmethod
- def from_min_extent(min_value, extent, span=None):
+ def from_min_extent(
+ min_value: PrimExpr, extent: PrimExpr, span: Optional[Span] = None
+ ) -> "Range":
"""Construct a Range by min and extent.
This constructs a range in [min_value, min_value + extent)
@@ -169,7 +187,7 @@ class Range(Node, Scriptable):
The extent of the range.
span : Optional[Span]
- The location of this itervar in the source code.
+ The location of this node in the source code.
Returns
-------
@@ -178,8 +196,8 @@ class Range(Node, Scriptable):
"""
return _ffi_api.Range_from_min_extent(min_value, extent, span)
- def __eq__(self, other):
+ def __eq__(self, other: Object) -> bool:
return tvm.ir.structural_equal(self, other)
- def __ne__(self, other):
+ def __ne__(self, other: Object) -> bool:
return not self.__eq__(other)
diff --git a/python/tvm/tir/expr.py b/python/tvm/tir/expr.py
index f93e39ee0f..fad9fca083 100644
--- a/python/tvm/tir/expr.py
+++ b/python/tvm/tir/expr.py
@@ -27,7 +27,7 @@ For example, you can use addexp.a to get the left operand of an Add node.
assert(isinstance(y, tvm.tir.Add))
assert(y.a == x)
"""
-from typing import Optional, Union
+from typing import List, Optional, Union
import tvm._ffi
import tvm.ir._ffi_api
@@ -38,9 +38,10 @@ from tvm.runtime import DataType, DataTypeCode, Object, ObjectGeneric, Scriptabl
from . import _ffi_api
from . import generic as _generic
+from .buffer import Buffer, DataProducer
-def div_ambiguity_error():
+def div_ambiguity_error() -> RuntimeError:
return RuntimeError(
"TVM supports multiple types of integer divisions, "
+ "please call div, indexdiv/indexmod, floordiv/floormod "
@@ -69,111 +70,111 @@ class ExprOp(object):
# TODO(tkonolige): use inspect to add source information to these objects
- def __add__(self, other):
+ def __add__(self, other: PrimExpr) -> PrimExpr:
return _generic.add(self, other)
- def __radd__(self, other):
+ def __radd__(self, other: PrimExpr) -> PrimExpr:
return _generic.add(other, self)
- def __sub__(self, other):
+ def __sub__(self, other: PrimExpr) -> PrimExpr:
return _generic.subtract(self, other)
- def __rsub__(self, other):
+ def __rsub__(self, other: PrimExpr) -> PrimExpr:
return _generic.subtract(other, self)
- def __mul__(self, other):
+ def __mul__(self, other: PrimExpr) -> PrimExpr:
return _generic.multiply(self, other)
- def __rmul__(self, other):
+ def __rmul__(self, other: PrimExpr) -> PrimExpr:
return _generic.multiply(other, self)
- def __div__(self, other):
+ def __div__(self, other: PrimExpr) -> PrimExpr:
if _dtype_is_int(self) and _dtype_is_int(other):
raise div_ambiguity_error()
return _generic.divide(self, other)
- def __rdiv__(self, other):
+ def __rdiv__(self, other: PrimExpr) -> PrimExpr:
if _dtype_is_int(self) and _dtype_is_int(other):
raise div_ambiguity_error()
return _generic.divide(other, self)
- def __truediv__(self, other):
+ def __truediv__(self, other: PrimExpr) -> PrimExpr:
if _dtype_is_int(self) and _dtype_is_int(other):
raise div_ambiguity_error()
return _generic.divide(self, other)
- def __rtruediv__(self, other):
+ def __rtruediv__(self, other: PrimExpr) -> PrimExpr:
if _dtype_is_int(self) and _dtype_is_int(other):
raise div_ambiguity_error()
return _generic.divide(other, self)
- def __floordiv__(self, other):
+ def __floordiv__(self, other: PrimExpr) -> PrimExpr:
return _generic.floordiv(self, other)
- def __rfloordiv__(self, other):
+ def __rfloordiv__(self, other: PrimExpr) -> PrimExpr:
return _generic.floordiv(other, self, None)
- def __mod__(self, other):
+ def __mod__(self, other: PrimExpr) -> PrimExpr:
return _ffi_api._OpFloorMod(self, other, None) # type: ignore
- def __rmod__(self, other):
+ def __rmod__(self, other: PrimExpr) -> PrimExpr:
return _ffi_api._OpFloorMod(other, self, None) # type: ignore
- def __neg__(self):
+ def __neg__(self) -> PrimExpr:
neg_one = const(-1, self.dtype) # type: ignore
return self.__mul__(neg_one)
- def __lshift__(self, other):
+ def __lshift__(self, other: PrimExpr) -> PrimExpr:
return _ffi_api.left_shift(self, other, None) # type: ignore
- def __rlshift__(self, other):
+ def __rlshift__(self, other: PrimExpr) -> PrimExpr:
return _ffi_api.left_shift(other, self, None) # type: ignore
- def __rshift__(self, other):
+ def __rshift__(self, other: PrimExpr) -> PrimExpr:
return _ffi_api.right_shift(self, other, None) # type: ignore
- def __rrshift__(self, other):
+ def __rrshift__(self, other: PrimExpr) -> PrimExpr:
return _ffi_api.right_shift(other, self, None) # type: ignore
- def __and__(self, other):
+ def __and__(self, other: PrimExpr) -> PrimExpr:
return _ffi_api.bitwise_and(self, other, None) # type: ignore
- def __rand__(self, other):
+ def __rand__(self, other: PrimExpr) -> PrimExpr:
return _ffi_api.bitwise_and(other, self, None) # type: ignore
- def __or__(self, other):
+ def __or__(self, other: PrimExpr) -> PrimExpr:
return _ffi_api.bitwise_or(self, other, None) # type: ignore
- def __ror__(self, other):
+ def __ror__(self, other: PrimExpr) -> PrimExpr:
return _ffi_api.bitwise_or(other, self, None) # type: ignore
- def __xor__(self, other):
+ def __xor__(self, other: PrimExpr) -> PrimExpr:
return _ffi_api.bitwise_xor(self, other, None) # type: ignore
- def __rxor__(self, other):
+ def __rxor__(self, other: PrimExpr) -> PrimExpr:
return _ffi_api.bitwise_xor(other, self, None) # type: ignore
- def __invert__(self):
+ def __invert__(self) -> PrimExpr:
if _dtype_is_float(self):
raise RuntimeError("Cannot use ~ operator on float type Expr.")
return _ffi_api.bitwise_not(self, None) # type: ignore
- def __lt__(self, other):
+ def __lt__(self, other: PrimExpr) -> PrimExpr:
return _ffi_api._OpLT(self, other, None) # type: ignore
- def __le__(self, other):
+ def __le__(self, other: PrimExpr) -> PrimExpr:
return _ffi_api._OpLE(self, other, None) # type: ignore
- def __eq__(self, other):
+ def __eq__(self, other: PrimExpr) -> PrimExpr:
return EqualOp(self, other)
- def __ne__(self, other):
+ def __ne__(self, other: PrimExpr) -> PrimExpr:
return NotEqualOp(self, other)
- def __gt__(self, other):
+ def __gt__(self, other: PrimExpr) -> PrimExpr:
return _ffi_api._OpGT(self, other, None) # type: ignore
- def __ge__(self, other):
+ def __ge__(self, other: PrimExpr) -> PrimExpr:
return _ffi_api._OpGE(self, other, None) # type: ignore
def __nonzero__(self):
@@ -182,10 +183,10 @@ class ExprOp(object):
+ "use tvm.tir.all / tvm.tir.any instead"
)
- def __bool__(self):
+ def __bool__(self) -> bool:
return self.__nonzero__()
- def equal(self, other, span=None):
+ def equal(self, other: PrimExpr, span: Optional[Span] = None) -> bool:
"""Build an equal check expression with other expr.
Parameters
@@ -203,7 +204,7 @@ class ExprOp(object):
"""
return _ffi_api._OpEQ(self, other, span) # type: ignore
- def astype(self, dtype: str, span: Optional[Span] = None):
+ def astype(self, dtype: str, span: Optional[Span] = None) -> PrimExpr:
"""Cast the expression to other type.
Parameters
@@ -243,18 +244,18 @@ class EqualOp(ObjectGeneric, ExprOp):
# This class is not manipulated by C++. So use python's identity check function is sufficient
same_as = object.__eq__
- def __init__(self, a, b, span=None):
+ def __init__(self, a: PrimExpr, b: PrimExpr, span: Optional[Span] = None):
self.a = a
self.b = b
self.span = span
- def __nonzero__(self):
+ def __nonzero__(self) -> bool:
return self.a.same_as(self.b)
- def __bool__(self):
+ def __bool__(self) -> bool:
return self.__nonzero__()
- def asobject(self):
+ def asobject(self) -> PrimExpr:
"""Convert object."""
return _ffi_api._OpEQ(self.a, self.b, self.span) # type: ignore
@@ -280,18 +281,18 @@ class NotEqualOp(ObjectGeneric, ExprOp):
# This class is not manipulated by C++. So use python's identity check function is sufficient
same_as = object.__eq__
- def __init__(self, a, b, span=None):
+ def __init__(self, a: PrimExpr, b: PrimExpr, span: Optional[Span] = None) -> None:
self.a = a
self.b = b
self.span = span
- def __nonzero__(self):
+ def __nonzero__(self) -> bool:
return not self.a.same_as(self.b)
- def __bool__(self):
+ def __bool__(self) -> bool:
return self.__nonzero__()
- def asobject(self):
+ def asobject(self) -> PrimExpr:
"""Convert object."""
return _ffi_api._OpNE(self.a, self.b, self.span) # type: ignore
@@ -309,11 +310,11 @@ class IntImmEnum(ObjectGeneric):
The location of the cast in the source.
"""
- def __init__(self, value, span=None):
+ def __init__(self, value: int, span: Optional[Span] = None) -> None:
self.value = value
self.span = span
- def asobject(self):
+ def asobject(self) -> "IntImm":
"""Convert object."""
return IntImm("int32", self.value, self.span) # type: ignore
@@ -331,11 +332,13 @@ class ConstExpr(PrimExprWithOp):
class BinaryOpExpr(PrimExprWithOp):
- pass
+ a: PrimExpr
+ b: PrimExpr
class CmpExpr(PrimExprWithOp):
- pass
+ a: PrimExpr
+ b: PrimExpr
class LogicalExpr(PrimExprWithOp):
@@ -351,14 +354,17 @@ class Var(PrimExprWithOp):
name : str
The name
- dtype : Union[str, tvm.irType]
+ dtype : Union[str, ir.Type]
The data type
span : Optional[Span]
- The location of this itervar in the source code.
+ The location of this expression in the source code.
"""
- def __init__(self, name: str, dtype: Union[str, ir.Type], span: Optional[Span] = None):
+ name_hint: str
+ type_annotation: ir.Type
+
+ def __init__(self, name: str, dtype: Union[str, ir.Type], span: Optional[Span] = None) -> None:
self.__init_handle_by_constructor__(_ffi_api.Var, name, dtype, span) # type: ignore
@@ -372,15 +378,15 @@ class SizeVar(Var):
name : str
The name
- dtype : int
+ dtype : Union[str, ir.Type]
The data type
span : Optional[Span]
- The location of this itervar in the source code.
+ The location of this expression in the source code.
"""
# pylint: disable=super-init-not-called
- def __init__(self, name, dtype, span=None):
+ def __init__(self, name: str, dtype: Union[str, ir.Type], span: Optional[Span] = None) -> None:
self.__init_handle_by_constructor__(_ffi_api.SizeVar, name, dtype, span) # type: ignore
@@ -405,7 +411,7 @@ class IterVar(Object, ExprOp, Scriptable):
The thread type tag.
span : Optional[Span]
- The location of this itervar in the source code.
+ The location of this expression in the source code.
See Also
--------
@@ -423,7 +429,19 @@ class IterVar(Object, ExprOp, Scriptable):
Parallelized = 7
Tensorized = 8
- def __init__(self, dom, var, iter_type, thread_tag="", span=None):
+ dom: ir.Range
+ var: Var
+ iter_type: int
+ thread_tag: str
+
+ def __init__(
+ self,
+ dom: ir.Range,
+ var: Union[Var, str],
+ iter_type: int,
+ thread_tag: str = "",
+ span: Optional[Span] = None,
+ ) -> None:
if dom is not None:
if isinstance(dom, (list, tuple)):
if len(dom) != 2:
@@ -464,10 +482,22 @@ class CommReducer(Object, Scriptable):
The identity elements.
span : Optional[Span]
- The location of this itervar in the source code.
+ The location of this expression in the source code.
"""
- def __init__(self, lhs, rhs, result, identity_element, span=None):
+ lhs: List[Var]
+ rhs: List[Var]
+ result: List[PrimExpr]
+ identity_element: List[PrimExpr]
+
+ def __init__(
+ self,
+ lhs: List[Var],
+ rhs: List[Var],
+ result: List[PrimExpr],
+ identity_element: List[PrimExpr],
+ span: Optional[Span] = None,
+ ) -> None:
self.__init_handle_by_constructor__(
_ffi_api.CommReducer, lhs, rhs, result, identity_element, span # type: ignore
)
@@ -498,10 +528,27 @@ class Reduce(PrimExprWithOp):
The initial value for output. This can be an int, float or ProducerLoad
span : Optional[Span]
- The location of this itervar in the source code.
+ The location of this expression in the source code.
"""
- def __init__(self, combiner, src, rdom, condition, value_index, init=None, span=None):
+ combiner: CommReducer
+ source: List[PrimExpr]
+ init: List[PrimExpr]
+ axis: List[IterVar]
+ condition: PrimExpr
+ value_index: int
+
+ def __init__(
+ self,
+ combiner: CommReducer,
+ src: List[PrimExpr],
+ rdom: List[IterVar],
+ condition: PrimExpr,
+ value_index: int,
+ init: Optional[List[PrimExpr]] = None,
+ span: Optional[Span] = None,
+ ) -> None:
+ init = [] if init is None else init
self.__init_handle_by_constructor__(
_ffi_api.Reduce, combiner, src, rdom, condition, value_index, init, span # type: ignore
)
@@ -520,15 +567,17 @@ class FloatImm(ConstExpr):
The constant value.
span : Optional[Span]
- The location of this itervar in the source code.
+ The location of this expression in the source code.
"""
- def __init__(self, dtype, value, span=None):
+ value: float
+
+ def __init__(self, dtype: str, value: float, span: Optional[Span] = None) -> None:
self.__init_handle_by_constructor__(
tvm.ir._ffi_api.FloatImm, dtype, value, span # type: ignore
)
- def __float__(self):
+ def __float__(self) -> float:
return self.value
@@ -545,30 +594,32 @@ class IntImm(ConstExpr):
The constant value.
span : Optional[Span]
- The location of this itervar in the source code.
+ The location of this expression in the source code.
"""
- def __init__(self, dtype, value, span=None):
+ value: int
+
+ def __init__(self, dtype: str, value: int, span: Optional[Span] = None) -> None:
self.__init_handle_by_constructor__(
tvm.ir._ffi_api.IntImm, dtype, value, span # type: ignore
)
- def __hash__(self):
+ def __hash__(self) -> int:
return self.value
- def __int__(self):
+ def __int__(self) -> int:
return self.value
- def __nonzero__(self):
+ def __nonzero__(self) -> bool:
return self.value != 0
- def __eq__(self, other):
+ def __eq__(self, other: PrimExpr) -> PrimExpr:
return _ffi_api._OpEQ(self, other, None) # type: ignore
- def __ne__(self, other):
+ def __ne__(self, other: PrimExpr) -> PrimExpr:
return _ffi_api._OpNE(self, other, None) # type: ignore
- def __bool__(self):
+ def __bool__(self) -> bool:
return self.__nonzero__()
@@ -582,23 +633,25 @@ class StringImm(ConstExpr):
The value of the function.
span : Optional[Span]
- The location of this itervar in the source code.
+ The location of this expression in the source code.
"""
- def __init__(self, value, span=None):
+ value: str
+
+ def __init__(self, value: str, span: Optional[Span] = None) -> None:
self.__init_handle_by_constructor__(_ffi_api.StringImm, value, span) # type: ignore
- def __eq__(self, other):
+ def __eq__(self, other: PrimExpr) -> bool:
if isinstance(other, ConstExpr):
return self.value == other.value
return self.value == other
- def __ne__(self, other):
+ def __ne__(self, other: PrimExpr) -> bool:
if isinstance(other, ConstExpr):
return self.value != other.value
return self.value != other
- def __hash__(self):
+ def __hash__(self) -> int:
return PrimExpr.__hash__(self)
@@ -615,10 +668,12 @@ class Cast(PrimExprWithOp):
The value of the function.
span : Optional[Span]
- The location of this itervar in the source code.
+ The location of this expression in the source code.
"""
- def __init__(self, dtype, value, span=None):
+ value: PrimExpr
+
+ def __init__(self, dtype, value, span: Optional[Span] = None) -> None:
self.__init_handle_by_constructor__(_ffi_api.Cast, dtype, value, span) # type: ignore
@@ -635,10 +690,10 @@ class Add(BinaryOpExpr):
The right hand operand.
span : Optional[Span]
- The location of this itervar in the source code.
+ The location of this expression in the source code.
"""
- def __init__(self, a, b, span=None):
+ def __init__(self, a: PrimExpr, b: PrimExpr, span: Optional[Span] = None) -> None:
self.__init_handle_by_constructor__(_ffi_api.Add, a, b, span) # type: ignore
@@ -655,10 +710,10 @@ class Sub(BinaryOpExpr):
The right hand operand.
span : Optional[Span]
- The location of this itervar in the source code.
+ The location of this expression in the source code.
"""
- def __init__(self, a, b, span=None):
+ def __init__(self, a: PrimExpr, b: PrimExpr, span: Optional[Span] = None) -> None:
self.__init_handle_by_constructor__(_ffi_api.Sub, a, b, span) # type: ignore
@@ -675,10 +730,10 @@ class Mul(BinaryOpExpr):
The right hand operand.
span : Optional[Span]
- The location of this itervar in the source code.
+ The location of this expression in the source code.
"""
- def __init__(self, a, b, span=None):
+ def __init__(self, a: PrimExpr, b: PrimExpr, span: Optional[Span] = None) -> None:
self.__init_handle_by_constructor__(_ffi_api.Mul, a, b, span) # type: ignore
@@ -695,10 +750,10 @@ class Div(BinaryOpExpr):
The right hand operand.
span : Optional[Span]
- The location of this itervar in the source code.
+ The location of this expression in the source code.
"""
- def __init__(self, a, b, span=None):
+ def __init__(self, a: PrimExpr, b: PrimExpr, span: Optional[Span] = None) -> None:
self.__init_handle_by_constructor__(_ffi_api.Div, a, b, span) # type: ignore
@@ -715,10 +770,10 @@ class Mod(BinaryOpExpr):
The right hand operand.
span : Optional[Span]
- The location of this itervar in the source code.
+ The location of this expression in the source code.
"""
- def __init__(self, a, b, span=None):
+ def __init__(self, a: PrimExpr, b: PrimExpr, span: Optional[Span] = None) -> None:
self.__init_handle_by_constructor__(_ffi_api.Mod, a, b, span) # type: ignore
@@ -735,10 +790,10 @@ class FloorDiv(BinaryOpExpr):
The right hand operand.
span : Optional[Span]
- The location of this itervar in the source code.
+ The location of this expression in the source code.
"""
- def __init__(self, a, b, span=None):
+ def __init__(self, a: PrimExpr, b: PrimExpr, span: Optional[Span] = None) -> None:
self.__init_handle_by_constructor__(_ffi_api.FloorDiv, a, b, span) # type: ignore
@@ -755,10 +810,10 @@ class FloorMod(BinaryOpExpr):
The right hand operand.
span : Optional[Span]
- The location of this itervar in the source code.
+ The location of this expression in the source code.
"""
- def __init__(self, a, b, span=None):
+ def __init__(self, a: PrimExpr, b: PrimExpr, span: Optional[Span] = None) -> None:
self.__init_handle_by_constructor__(_ffi_api.FloorMod, a, b, span) # type: ignore
@@ -775,10 +830,10 @@ class Min(BinaryOpExpr):
The right hand operand.
span : Optional[Span]
- The location of this itervar in the source code.
+ The location of this expression in the source code.
"""
- def __init__(self, a, b, span=None):
+ def __init__(self, a: PrimExpr, b: PrimExpr, span: Optional[Span] = None) -> None:
self.__init_handle_by_constructor__(_ffi_api.Min, a, b, span) # type: ignore
@@ -795,10 +850,10 @@ class Max(BinaryOpExpr):
The right hand operand.
span : Optional[Span]
- The location of this itervar in the source code.
+ The location of this expression in the source code.
"""
- def __init__(self, a, b, span=None):
+ def __init__(self, a: PrimExpr, b: PrimExpr, span: Optional[Span] = None) -> None:
self.__init_handle_by_constructor__(_ffi_api.Max, a, b, span) # type: ignore
@@ -815,10 +870,10 @@ class EQ(CmpExpr):
The right hand operand.
span : Optional[Span]
- The location of this itervar in the source code.
+ The location of this expression in the source code.
"""
- def __init__(self, a, b, span=None):
+ def __init__(self, a: PrimExpr, b: PrimExpr, span: Optional[Span] = None) -> None:
self.__init_handle_by_constructor__(_ffi_api.EQ, a, b, span) # type: ignore
@@ -835,10 +890,10 @@ class NE(CmpExpr):
The right hand operand.
span : Optional[Span]
- The location of this itervar in the source code.
+ The location of this expression in the source code.
"""
- def __init__(self, a, b, span=None):
+ def __init__(self, a: PrimExpr, b: PrimExpr, span: Optional[Span] = None) -> None:
self.__init_handle_by_constructor__(_ffi_api.NE, a, b, span) # type: ignore
@@ -855,10 +910,10 @@ class LT(CmpExpr):
The right hand operand.
span : Optional[Span]
- The location of this itervar in the source code.
+ The location of this expression in the source code.
"""
- def __init__(self, a, b, span=None):
+ def __init__(self, a: PrimExpr, b: PrimExpr, span: Optional[Span] = None) -> None:
self.__init_handle_by_constructor__(_ffi_api.LT, a, b, span) # type: ignore
@@ -875,10 +930,10 @@ class LE(CmpExpr):
The right hand operand.
span : Optional[Span]
- The location of this itervar in the source code.
+ The location of this expression in the source code.
"""
- def __init__(self, a, b, span=None):
+ def __init__(self, a: PrimExpr, b: PrimExpr, span: Optional[Span] = None) -> None:
self.__init_handle_by_constructor__(_ffi_api.LE, a, b, span) # type: ignore
@@ -895,10 +950,10 @@ class GT(CmpExpr):
The right hand operand.
span : Optional[Span]
- The location of this itervar in the source code.
+ The location of this expression in the source code.
"""
- def __init__(self, a, b, span=None):
+ def __init__(self, a: PrimExpr, b: PrimExpr, span: Optional[Span] = None) -> None:
self.__init_handle_by_constructor__(_ffi_api.GT, a, b, span) # type: ignore
@@ -915,10 +970,10 @@ class GE(CmpExpr):
The right hand operand.
span : Optional[Span]
- The location of this itervar in the source code.
+ The location of this expression in the source code.
"""
- def __init__(self, a, b, span=None):
+ def __init__(self, a: PrimExpr, b: PrimExpr, span: Optional[Span] = None) -> None:
self.__init_handle_by_constructor__(_ffi_api.GE, a, b, span) # type: ignore
@@ -935,10 +990,10 @@ class And(LogicalExpr):
The right hand operand.
span : Optional[Span]
- The location of this itervar in the source code.
+ The location of this expression in the source code.
"""
- def __init__(self, a, b, span=None):
+ def __init__(self, a: PrimExpr, b: PrimExpr, span: Optional[Span] = None) -> None:
self.__init_handle_by_constructor__(_ffi_api.And, a, b, span) # type: ignore
@@ -955,10 +1010,13 @@ class Or(LogicalExpr):
The right hand operand.
span : Optional[Span]
- The location of this itervar in the source code.
+ The location of this expression in the source code.
"""
- def __init__(self, a, b, span=None):
+ a: PrimExpr
+ b: PrimExpr
+
+ def __init__(self, a: PrimExpr, b: PrimExpr, span: Optional[Span] = None) -> None:
self.__init_handle_by_constructor__(_ffi_api.Or, a, b, span) # type: ignore
@@ -972,10 +1030,12 @@ class Not(LogicalExpr):
The input value
span : Optional[Span]
- The location of this itervar in the source code.
+ The location of this expression in the source code.
"""
- def __init__(self, a, span=None):
+ a: PrimExpr
+
+ def __init__(self, a: PrimExpr, span: Optional[Span] = None) -> None:
self.__init_handle_by_constructor__(_ffi_api.Not, a, span) # type: ignore
@@ -1002,10 +1062,20 @@ class Select(PrimExprWithOp):
The value to take when condition is false.
span : Optional[Span]
- The location of this itervar in the source code.
+ The location of this expression in the source code.
"""
- def __init__(self, condition, true_value, false_value, span=None):
+ condition: PrimExpr
+ true_value: PrimExpr
+ false_value: PrimExpr
+
+ def __init__(
+ self,
+ condition: PrimExpr,
+ true_value: PrimExpr,
+ false_value: PrimExpr,
+ span: Optional[Span] = None,
+ ) -> None:
if isinstance(condition, bool):
condition = IntImm("bool", condition)
self.__init_handle_by_constructor__(
@@ -1026,10 +1096,15 @@ class BufferLoad(PrimExprWithOp):
The buffer indices.
span : Optional[Span]
- The location of this itervar in the source code.
+ The location of this expression in the source code.
"""
- def __init__(self, buffer, indices, span=None):
+ buffer: Buffer
+ indices: List[PrimExpr]
+
+ def __init__(
+ self, buffer: Buffer, indices: List[PrimExpr], span: Optional[Span] = None
+ ) -> None:
self.__init_handle_by_constructor__(
_ffi_api.BufferLoad, buffer, indices, span # type: ignore
)
@@ -1048,10 +1123,15 @@ class ProducerLoad(PrimExprWithOp):
The buffer indices.
span : Optional[Span]
- The location of this itervar in the source code.
+ The location of this expression in the source code.
"""
- def __init__(self, producer, indices, span=None):
+ producer: DataProducer
+ indices: List[PrimExpr]
+
+ def __init__(
+ self, producer: DataProducer, indices: List[PrimExpr], span: Optional[Span] = None
+ ) -> None:
self.__init_handle_by_constructor__(
_ffi_api.ProducerLoad, producer, indices, span # type: ignore
)
@@ -1073,10 +1153,16 @@ class Ramp(PrimExprWithOp):
The lanes of the expression.
span : Optional[Span]
- The location of this itervar in the source code.
+ The location of this expression in the source code.
"""
- def __init__(self, base, stride, lanes, span=None):
+ base: PrimExpr
+ stride: PrimExpr
+ lanes: int
+
+ def __init__(
+ self, base: PrimExpr, stride: PrimExpr, lanes: int, span: Optional[Span] = None
+ ) -> None:
self.__init_handle_by_constructor__(
_ffi_api.Ramp, base, stride, lanes, span # type: ignore
)
@@ -1095,10 +1181,13 @@ class Broadcast(PrimExprWithOp):
The lanes of the expression.
span : Optional[Span]
- The location of this itervar in the source code.
+ The location of this expression in the source code.
"""
- def __init__(self, value, lanes, span=None):
+ value: PrimExpr
+ lanes: int
+
+ def __init__(self, value: PrimExpr, lanes: int, span: Optional[Span] = None) -> None:
self.__init_handle_by_constructor__(_ffi_api.Broadcast, value, lanes, span) # type: ignore
@@ -1108,17 +1197,22 @@ class Shuffle(PrimExprWithOp):
Parameters
----------
- vectors : Array of Expr
+ vectors : List[PrimExpr]
The vectors
- indices : Array of indices
+ indices : List[PrimExpr]
The indices
span : Optional[Span]
- The location of this itervar in the source code.
+ The location of this expression in the source code.
"""
- def __init__(self, vectors, indices, span=None):
+ vectors: List[PrimExpr]
+ indices: List[PrimExpr]
+
+ def __init__(
+ self, vectors: List[PrimExpr], indices: List[PrimExpr], span: Optional[Span] = None
+ ) -> None:
self.__init_handle_by_constructor__(
_ffi_api.Shuffle, vectors, indices, span # type: ignore
)
@@ -1144,7 +1238,7 @@ class Call(PrimExprWithOp):
dtype : str
The return data type
- op : Union[RelayExpr, str]
+ op : Union[Op, str]
The function to be called, or the name
to the global tvm.Op
@@ -1152,10 +1246,15 @@ class Call(PrimExprWithOp):
The input arguments to the call
span : Optional[Span]
- The location of this itervar in the source code.
+ The location of this expression in the source code.
"""
- def __init__(self, dtype, op, args, span=None):
+ op: Op
+ args: List[PrimExpr]
+
+ def __init__(
+ self, dtype: str, op: Union[Op, str], args: List[PrimExpr], span: Optional[Span] = None
+ ) -> None:
if isinstance(op, str):
if not op.startswith("tir."):
raise ValueError(
@@ -1180,16 +1279,22 @@ class Let(PrimExprWithOp):
The variable in the binding.
value : PrimExpr
- The value in to be binded.
+ The value in to be bound.
body : PrimExpr
The body expression.
span : Optional[Span]
- The location of this itervar in the source code.
+ The location of this expression in the source code.
"""
- def __init__(self, var, value, body, span=None):
+ var: Var
+ value: PrimExpr
+ body: PrimExpr
+
+ def __init__(
+ self, var: Var, value: PrimExpr, body: PrimExpr, span: Optional[Span] = None
+ ) -> None:
self.__init_handle_by_constructor__(_ffi_api.Let, var, value, body, span) # type: ignore
@@ -1198,8 +1303,8 @@ class Any(PrimExprWithOp):
"""Any node.
span : Optional[Span]
- The location of this itervar in the source code.
+ The location of this expression in the source code.
"""
- def __init__(self, span=None):
+ def __init__(self, span: Optional[Span] = None) -> None:
self.__init_handle_by_constructor__(_ffi_api.Any, span) # type: ignore