shingjan commented on a change in pull request #9432:
URL: https://github.com/apache/tvm/pull/9432#discussion_r745084911
##########
File path: python/tvm/script/tir/__init__.pyi
##########
@@ -33,21 +33,51 @@ from numbers import Number
import builtins
from tvm.tir.function import PrimFunc
-from tvm.tir import PrimExpr, Range, IterVar, Var
+from tvm.tir import Range
from tvm.runtime import Object
from .node import BufferSlice
"""
redefine types
"""
+class PrimExpr:
+ def __init__(self: PrimExpr) -> None: ...
+ @overload
+ def __add__(self: PrimExpr, other: PrimExpr) -> PrimExpr: ...
+ @overload
+ def __add__(self: PrimExpr, other: Union[int, float]) -> PrimExpr: ...
+ @overload
+ def __sub__(self: PrimExpr, other: PrimExpr) -> PrimExpr: ...
+ @overload
+ def __sub__(self: PrimExpr, other: Union[int, float]) -> PrimExpr: ...
+ @overload
+ def __mul__(self: PrimExpr, other: PrimExpr) -> PrimExpr: ...
+ @overload
+ def __mul__(self: PrimExpr, other: Union[int, float]) -> PrimExpr: ...
+ @overload
+ def __div__(self: PrimExpr, other: PrimExpr) -> PrimExpr: ...
+ @overload
+ def __div__(self: PrimExpr, other: Union[int, float]) -> PrimExpr: ...
+ def __radd__(self: PrimExpr, other: Union[int, float]) -> PrimExpr: ...
+ def __rsub__(self: PrimExpr, other: Union[int, float]) -> PrimExpr: ...
+ def __rmul__(self: PrimExpr, other: Union[int, float]) -> PrimExpr: ...
+ def __rdiv__(self: PrimExpr, other: Union[int, float]) -> PrimExpr: ...
+
+class Var(PrimExpr): ...
+class IterVar(Var): ...
+
class Buffer:
@overload
- def __getitem__(self: Buffer, pos: List[Union[PrimExpr, int]]) -> PrimExpr: ...
+ def __getitem__(
+ self: Buffer, pos: Tuple[Union[PrimExpr, int], Union[PrimExpr, int]]
Review comment:
done
##########
File path: python/tvm/script/tir/__init__.pyi
##########
@@ -33,21 +33,51 @@ from numbers import Number
import builtins
from tvm.tir.function import PrimFunc
-from tvm.tir import PrimExpr, Range, IterVar, Var
+from tvm.tir import Range
from tvm.runtime import Object
from .node import BufferSlice
"""
redefine types
"""
+class PrimExpr:
+ def __init__(self: PrimExpr) -> None: ...
+ @overload
+ def __add__(self: PrimExpr, other: PrimExpr) -> PrimExpr: ...
+ @overload
+ def __add__(self: PrimExpr, other: Union[int, float]) -> PrimExpr: ...
+ @overload
+ def __sub__(self: PrimExpr, other: PrimExpr) -> PrimExpr: ...
+ @overload
+ def __sub__(self: PrimExpr, other: Union[int, float]) -> PrimExpr: ...
+ @overload
+ def __mul__(self: PrimExpr, other: PrimExpr) -> PrimExpr: ...
+ @overload
+ def __mul__(self: PrimExpr, other: Union[int, float]) -> PrimExpr: ...
+ @overload
+ def __div__(self: PrimExpr, other: PrimExpr) -> PrimExpr: ...
+ @overload
+ def __div__(self: PrimExpr, other: Union[int, float]) -> PrimExpr: ...
+ def __radd__(self: PrimExpr, other: Union[int, float]) -> PrimExpr: ...
+ def __rsub__(self: PrimExpr, other: Union[int, float]) -> PrimExpr: ...
+ def __rmul__(self: PrimExpr, other: Union[int, float]) -> PrimExpr: ...
+ def __rdiv__(self: PrimExpr, other: Union[int, float]) -> PrimExpr: ...
+
+class Var(PrimExpr): ...
+class IterVar(Var): ...
+
class Buffer:
@overload
- def __getitem__(self: Buffer, pos: List[Union[PrimExpr, int]]) -> PrimExpr: ...
+ def __getitem__(
+ self: Buffer, pos: Tuple[Union[PrimExpr, int], Union[PrimExpr, int]]
+ ) -> PrimExpr: ...
@overload
def __getitem__(self: Buffer, pos: Union[PrimExpr, int]) -> PrimExpr: ...
@overload
- def __setitem__(self: Buffer, pos: List[Union[PrimExpr, int]], value: PrimExpr) -> None: ...
+ def __setitem__(
Review comment:
done
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]