This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 41b65a3144 [TVMScript] IRBuilder methods for `Block` (#12815)
41b65a3144 is described below
commit 41b65a3144595afb04228be1334dc77c08d11ba7
Author: Yaxing Cai <[email protected]>
AuthorDate: Fri Sep 16 18:11:06 2022 -0700
[TVMScript] IRBuilder methods for `Block` (#12815)
This PR introduces the remaining IRBuilder methods for `Block`.
Co-authored-by: yongwww <[email protected]>
---
include/tvm/script/ir_builder/tir/frame.h | 35 +++
include/tvm/script/ir_builder/tir/ir.h | 49 +++++
python/tvm/script/ir_builder/base.py | 18 +-
python/tvm/script/ir_builder/ir/ir.py | 2 +-
python/tvm/script/ir_builder/tir/frame.py | 7 +-
python/tvm/script/ir_builder/tir/ir.py | 235 +++++++++++++++++----
src/script/ir_builder/tir/frame.cc | 15 ++
src/script/ir_builder/tir/ir.cc | 80 +++++++
.../unittest/test_tvmscript_ir_builder_tir.py | 50 ++++-
tests/scripts/task_mypy.sh | 3 +
10 files changed, 442 insertions(+), 52 deletions(-)
diff --git a/include/tvm/script/ir_builder/tir/frame.h
b/include/tvm/script/ir_builder/tir/frame.h
index 2902b982d5..c76b400d96 100644
--- a/include/tvm/script/ir_builder/tir/frame.h
+++ b/include/tvm/script/ir_builder/tir/frame.h
@@ -187,6 +187,41 @@ class BlockFrame : public TIRFrame {
TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(BlockFrame, TIRFrame,
BlockFrameNode);
};
+/*!
+ * \brief A frame that represents the block initialization statement (`T.init()`).
+ *
+ * When the frame exits, the statements emitted inside it are stored as the `init`
+ * of the BlockFrame located by FindBlockFrame. Declaring a second init for the
+ * same block is reported as an error when the frame is entered.
+ *
+ * \sa BlockInitFrame
+ */
+class BlockInitFrameNode : public TIRFrameNode {
+ public:
+  void VisitAttrs(tvm::AttrVisitor* v) { TIRFrameNode::VisitAttrs(v); }
+
+  static constexpr const char* _type_key = "script.ir_builder.tir.BlockInitFrame";
+  TVM_DECLARE_FINAL_OBJECT_INFO(BlockInitFrameNode, TIRFrameNode);
+
+ public:
+  /*!
+   * \brief The method called when entering RAII scope.
+   * \sa tvm::support::With
+   */
+  void EnterWithScope() final;
+  /*!
+   * \brief The method called when exiting RAII scope.
+   * \sa tvm::support::With
+   */
+  void ExitWithScope() final;
+};
+
+/*!
+ * \brief Managed reference to BlockInitFrameNode.
+ *
+ * \sa BlockInitFrameNode
+ */
+class BlockInitFrame : public TIRFrame {
+ public:
+  TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(BlockInitFrame, TIRFrame,
+                                                    BlockInitFrameNode);
+};
+
/*!
* \brief A frame that represents the for loop.
*
diff --git a/include/tvm/script/ir_builder/tir/ir.h
b/include/tvm/script/ir_builder/tir/ir.h
index 037606253a..191887648d 100644
--- a/include/tvm/script/ir_builder/tir/ir.h
+++ b/include/tvm/script/ir_builder/tir/ir.h
@@ -141,6 +141,55 @@ void PreflattenedBuffer(Buffer postflattened_buffer,
Array<PrimExpr> shape,
*/
BlockFrame Block(String name, bool no_realize = false);
+/*!
+ * \brief The block initialization statement.
+ * \return The BlockInitFrame. Statements emitted inside it become the block's init.
+ */
+BlockInitFrame Init();
+
+/*!
+ * \brief The block predicate statement.
+ * \param predicate The predicate condition. At most one predicate may be declared per block.
+ */
+void Where(PrimExpr predicate);
+
+/*!
+ * \brief The block buffer region reading statement.
+ * \param buffer_slices The buffer regions to read. Each element must be either a
+ *        tir::BufferRegion or a tir::BufferLoad; a BufferLoad is converted into a
+ *        BufferRegion. At most one read list may be declared per block.
+ */
+void Reads(Array<ObjectRef> buffer_slices);
+
+/*!
+ * \brief The block buffer region writing statement.
+ * \param buffer_slices The buffer regions to write. Each element must be either a
+ *        tir::BufferRegion or a tir::BufferLoad; a BufferLoad is converted into a
+ *        BufferRegion. At most one write list may be declared per block.
+ */
+void Writes(Array<ObjectRef> buffer_slices);
+
+/*!
+ * \brief The block annotation statement.
+ * \param attrs The annotations of the block. At most one annotation map may be
+ *        declared per block.
+ */
+void BlockAttrs(Map<String, ObjectRef> attrs);
+
+/*!
+ * \brief The buffer allocation function.
+ *
+ * The new buffer is appended to the allocations of the current block frame, or to
+ * the root allocations of the current PrimFunc frame when no block frame is found.
+ *
+ * \param shape The shape of the buffer.
+ * \param dtype The data type in the content of the buffer.
+ * \param data The pointer to the head of the data.
+ * \param strides The strides of each dimension.
+ * \param elem_offset The offset in terms of number of dtype elements (including lanes).
+ * \param storage_scope The optional storage scope of buffer data pointer.
+ * \param align The alignment requirement of data pointer in bytes.
+ * \param offset_factor The factor of elem_offset field.
+ * \param buffer_type The buffer type.
+ * \param axis_separators The separators between input axes when generating flattened
+ *        output axes.
+ * \return The allocated buffer.
+ */
+Buffer AllocBuffer(Array<PrimExpr> shape, DataType dtype = DataType::Float(32),
+                   Optional<Var> data = NullOpt, Array<PrimExpr> strides = {},
+                   PrimExpr elem_offset = PrimExpr(), String storage_scope = "", int align = -1,
+                   int offset_factor = 0, String buffer_type = "default",
+                   Array<IntImm> axis_separators = {});
namespace axis {
/*!
diff --git a/python/tvm/script/ir_builder/base.py
b/python/tvm/script/ir_builder/base.py
index 767fa8bf25..7aa33ee49c 100644
--- a/python/tvm/script/ir_builder/base.py
+++ b/python/tvm/script/ir_builder/base.py
@@ -61,11 +61,11 @@ class IRBuilderFrame(_Object):
"""
def __enter__(self) -> "IRBuilderFrame":
- _ffi_api.IRBuilderFrameEnter(self) # pylint: disable=no-member #
type: ignore
+ _ffi_api.IRBuilderFrameEnter(self) # type: ignore[attr-defined] #
pylint: disable=no-member
return self
def __exit__(self, ptype, value, trace) -> None: # pylint:
disable=unused-argument
- _ffi_api.IRBuilderFrameExit(self) # pylint: disable=no-member # type:
ignore
+ _ffi_api.IRBuilderFrameExit(self) # type: ignore[attr-defined] #
pylint: disable=no-member
def add_callback(self, callback: Callable[[], None]) -> None:
"""Add a callback method invoked when exiting the with-scope.
@@ -75,7 +75,7 @@ class IRBuilderFrame(_Object):
callback : Callable[[], None]
The callback method to be invoked.
"""
- _ffi_api.IRBuilderFrameAddCallback( # pylint: disable=no-member #
type: ignore
+ _ffi_api.IRBuilderFrameAddCallback( # type: ignore[attr-defined] #
pylint: disable=no-member
self, callback
)
@@ -104,7 +104,7 @@ class IRBuilder(_Object):
def __init__(self) -> None:
"""Construct an IRBuilder."""
self.__init_handle_by_constructor__(
- _ffi_api.IRBuilder # pylint: disable=no-member # type: ignore
+ _ffi_api.IRBuilder # type: ignore[attr-defined] # pylint:
disable=no-member
)
def __enter__(self) -> "IRBuilder":
@@ -119,11 +119,11 @@ class IRBuilder(_Object):
with IRBuilder() as builder:
assert IRBuilder.current() == builder
"""
- _ffi_api.IRBuilderEnter(self) # pylint: disable=no-member # type:
ignore
+ _ffi_api.IRBuilderEnter(self) # type: ignore[attr-defined] # pylint:
disable=no-member
return self
def __exit__(self, ptype, value, trace) -> None: # pylint:
disable=unused-argument
- _ffi_api.IRBuilderExit(self) # pylint: disable=no-member # type:
ignore
+ _ffi_api.IRBuilderExit(self) # type: ignore[attr-defined] # pylint:
disable=no-member
@staticmethod
def current() -> "IRBuilder":
@@ -134,11 +134,11 @@ class IRBuilder(_Object):
builder : IRBuilder
The current IRBuilder.
"""
- return _ffi_api.IRBuilderCurrent() # pylint: disable=no-member #
type: ignore
+ return _ffi_api.IRBuilderCurrent() # type: ignore[attr-defined] #
pylint: disable=no-member
def get(self) -> _Object:
"""Get the constructed IR."""
- return _ffi_api.IRBuilderGet(self) # pylint: disable=no-member #
type: ignore
+ return _ffi_api.IRBuilderGet(self) # type: ignore[attr-defined] #
pylint: disable=no-member
@staticmethod
def name(s: str, v: Any) -> Any:
@@ -156,7 +156,7 @@ class IRBuilder(_Object):
v : Any
The same object with the name set.
"""
- return _ffi_api.IRBuilderName(s, v) # pylint: disable=no-member #
type: ignore
+ return _ffi_api.IRBuilderName(s, v) # type: ignore[attr-defined] #
pylint: disable=no-member
@staticmethod
def name_many( # pylint: disable=invalid-name
diff --git a/python/tvm/script/ir_builder/ir/ir.py
b/python/tvm/script/ir_builder/ir/ir.py
index df92036435..213180463c 100644
--- a/python/tvm/script/ir_builder/ir/ir.py
+++ b/python/tvm/script/ir_builder/ir/ir.py
@@ -21,4 +21,4 @@ from .frame import IRModuleFrame
def ir_module() -> IRModuleFrame:
- return _ffi_api.IRModule() # pylint: disable=no-member # type: ignore
+ return _ffi_api.IRModule() # type: ignore[attr-defined] # pylint:
disable=no-member
diff --git a/python/tvm/script/ir_builder/tir/frame.py
b/python/tvm/script/ir_builder/tir/frame.py
index 75bb0231ae..2ad08f3516 100644
--- a/python/tvm/script/ir_builder/tir/frame.py
+++ b/python/tvm/script/ir_builder/tir/frame.py
@@ -38,8 +38,13 @@ class BlockFrame(TIRFrame):
...
+@_register_object("script.ir_builder.tir.BlockInitFrame")
+class BlockInitFrame(TIRFrame):
+    """The frame of a block initialization statement, created by ``T.init()``."""
+
+
@_register_object("script.ir_builder.tir.ForFrame")
class ForFrame(TIRFrame):
- def __enter__(self) -> Union[Var, List[Var]]:
+ def __enter__(self) -> Union[Var, List[Var]]: # type: ignore[override]
super().__enter__()
return self.vars if len(self.vars) > 1 else self.vars[0]
diff --git a/python/tvm/script/ir_builder/tir/ir.py
b/python/tvm/script/ir_builder/tir/ir.py
index 40cd99c744..d1dc1c8960 100644
--- a/python/tvm/script/ir_builder/tir/ir.py
+++ b/python/tvm/script/ir_builder/tir/ir.py
@@ -25,6 +25,7 @@ from tvm.tir import (
Buffer,
BufferLoad,
BufferRegion,
+ IntImm,
PrimExpr,
StringImm,
Var,
@@ -85,7 +86,7 @@ def buffer_decl(
The declared buffer.
"""
shape = (shape,) if isinstance(shape, (PrimExpr, Integral)) else shape
- return _ffi_api.BufferDecl( # pylint: disable=no-member # type: ignore
+ return _ffi_api.BufferDecl( # type: ignore[attr-defined] # pylint:
disable=no-member
shape,
dtype,
"",
@@ -108,7 +109,7 @@ def prim_func() -> frame.PrimFuncFrame:
res : frame.PrimFuncFrame
The PrimFuncFrame.
"""
- return _ffi_api.PrimFunc() # pylint: disable=no-member # type: ignore
+ return _ffi_api.PrimFunc() # type: ignore[attr-defined] # pylint:
disable=no-member
def arg(name: str, obj: Union[Var, Buffer]) -> Union[Var, Buffer]:
@@ -127,7 +128,7 @@ def arg(name: str, obj: Union[Var, Buffer]) -> Union[Var,
Buffer]:
res : Union[Var, Buffer]
The argument.
"""
- return _ffi_api.Arg(name, obj) # pylint: disable=no-member # type: ignore
+ return _ffi_api.Arg(name, obj) # type: ignore[attr-defined] # pylint:
disable=no-member
def func_name(name: str) -> None:
@@ -138,7 +139,7 @@ def func_name(name: str) -> None:
name : str
The name of the PrimFunc.
"""
- _ffi_api.FuncName(name) # pylint: disable=no-member # type: ignore
+ _ffi_api.FuncName(name) # type: ignore[attr-defined] # pylint:
disable=no-member
def func_attr(attrs: Dict[str, Any]) -> None:
@@ -149,7 +150,7 @@ def func_attr(attrs: Dict[str, Any]) -> None:
attrs : Dict[str, Any]
The annotations of the PrimFunc.
"""
- _ffi_api.FuncAttrs(attrs) # pylint: disable=no-member # type: ignore
+ _ffi_api.FuncAttrs(attrs) # type: ignore[attr-defined] # pylint:
disable=no-member
def func_ret(ret_type: Type) -> Type:
@@ -165,7 +166,7 @@ def func_ret(ret_type: Type) -> Type:
res : Type
The return type.
"""
- return _ffi_api.FuncRet(ret_type) # pylint: disable=no-member # type:
ignore
+ return _ffi_api.FuncRet(ret_type) # type: ignore[attr-defined] # pylint:
disable=no-member
def match_buffer(
@@ -242,7 +243,7 @@ def match_buffer(
shape = (shape,) if isinstance(shape, (PrimExpr, Integral)) else shape
if strides is None:
strides = []
- return _ffi_api.MatchBuffer( # pylint: disable=no-member # type: ignore
+ return _ffi_api.MatchBuffer( # type: ignore[attr-defined] # pylint:
disable=no-member
param,
shape,
dtype,
@@ -310,7 +311,7 @@ def preflattened_buffer(
shape = (shape,) if isinstance(shape, (PrimExpr, Integral)) else shape
if strides is None:
strides = []
- _ffi_api.PreflattenedBuffer( # pylint: disable=no-member # type: ignore
+ _ffi_api.PreflattenedBuffer( # type: ignore[attr-defined] # pylint:
disable=no-member
postflattened,
shape,
dtype,
@@ -341,7 +342,155 @@ def block(name: str = "", no_realize: bool = False) ->
frame.BlockFrame:
res : frame.BlockFrame
The BlockFrame.
"""
- return _ffi_api.Block(name, no_realize) # pylint: disable=no-member #
type: ignore
+ return _ffi_api.Block(name, no_realize) # type: ignore[attr-defined] #
pylint: disable=no-member
+
+
+def init() -> frame.BlockInitFrame:
+    """Create the frame of a block initialization statement.
+
+    Statements emitted inside ``with T.init():`` become the ``init`` part of
+    the enclosing block.
+
+    Returns
+    -------
+    res : frame.BlockInitFrame
+        The BlockInitFrame.
+    """
+    block_init_frame = _ffi_api.Init()  # type: ignore[attr-defined] # pylint: disable=no-member
+    return block_init_frame
+
+
+def where(predicate: Union[PrimExpr, int]) -> None:
+    """The block predicate statement.
+
+    Parameters
+    ----------
+    predicate : Union[PrimExpr, int]
+        The predicate condition. Python ``bool`` values and the integers ``0``
+        and ``1`` are converted into a boolean ``IntImm``.
+
+    Raises
+    ------
+    ValueError
+        If ``predicate`` is an integer other than ``0`` or ``1``.
+    """
+    # ``bool`` is a subclass of ``int``, so this branch also handles True/False.
+    if isinstance(predicate, int):
+        if predicate not in (0, 1):
+            raise ValueError(f"Invalid value for predicate: {predicate}")
+        predicate = IntImm("bool", predicate)
+    _ffi_api.Where(predicate)  # type: ignore[attr-defined] # pylint: disable=no-member
+
+
+# A single buffer access accepted by ``reads``/``writes``.
+_BufferSlice = Union[BufferRegion, BufferLoad]
+# One positional argument of ``reads``/``writes``: a slice, or a list/tuple of slices.
+_BufferSliceArg = Union[_BufferSlice, List[_BufferSlice], Tuple[_BufferSlice, ...]]
+
+
+def _flatten_buffer_slices(buffer_slices: Tuple[_BufferSliceArg, ...]) -> List[_BufferSlice]:
+    """Normalize the variadic arguments of ``reads``/``writes`` into one list.
+
+    A single list or tuple argument is unpacked; otherwise every positional
+    argument is treated as one buffer slice (zero arguments give ``[]``).
+    """
+    if len(buffer_slices) == 1 and isinstance(buffer_slices[0], (list, tuple)):
+        return list(buffer_slices[0])
+    return list(buffer_slices)  # type: ignore[return-value]
+
+
+def reads(*buffer_slices: _BufferSliceArg) -> None:
+    """The block buffer region reading statement.
+
+    Parameters
+    ----------
+    *buffer_slices : Union[BufferRegion, BufferLoad, List, Tuple]
+        The buffer regions to read, passed either as separate positional
+        arguments or as a single list/tuple. Each element must be a
+        ``BufferRegion`` or a ``BufferLoad``.
+    """
+    regions = _flatten_buffer_slices(buffer_slices)
+    _ffi_api.Reads(regions)  # type: ignore[attr-defined] # pylint: disable=no-member
+
+
+def writes(*buffer_slices: _BufferSliceArg) -> None:
+    """The block buffer region writing statement.
+
+    Parameters
+    ----------
+    *buffer_slices : Union[BufferRegion, BufferLoad, List, Tuple]
+        The buffer regions to write, passed either as separate positional
+        arguments or as a single list/tuple. Each element must be a
+        ``BufferRegion`` or a ``BufferLoad``.
+    """
+    regions = _flatten_buffer_slices(buffer_slices)
+    _ffi_api.Writes(regions)  # type: ignore[attr-defined] # pylint: disable=no-member
+
+
+def block_attr(attrs: Dict[str, Any]) -> None:
+    """The block annotation statement.
+
+    Parameters
+    ----------
+    attrs : Dict[str, Any]
+        The annotations of the block. They may be declared at most once per block.
+    """
+    _ffi_api.BlockAttrs(attrs)  # type: ignore[attr-defined] # pylint: disable=no-member
+
+
+def alloc_buffer(
+    shape: Union[List[PrimExpr], Tuple[PrimExpr], PrimExpr, Integral],
+    dtype: str = "float32",
+    data: Optional[Var] = None,
+    strides: Optional[List[PrimExpr]] = None,
+    elem_offset: Optional[PrimExpr] = None,
+    scope: str = "",
+    align: int = -1,
+    offset_factor: int = 0,
+    buffer_type: str = "default",
+    axis_separators: Optional[List[int]] = None,
+) -> Buffer:
+    """The buffer allocation function.
+
+    The new buffer is attached to the current ``T.block()`` frame, or to the
+    root allocations of the current ``T.prim_func()`` frame when no block frame
+    is found.
+
+    Parameters
+    ----------
+    shape : Union[List[PrimExpr], Tuple[PrimExpr], PrimExpr, Integral]
+        The shape of the buffer. A scalar is treated as a 1-D shape.
+
+    dtype : str
+        The data type in the content of the buffer.
+
+    data : Optional[Var]
+        The pointer to the head of the data.
+
+    strides : Optional[List[PrimExpr]]
+        The strides of each dimension. ``None`` means no explicit strides.
+
+    elem_offset : Optional[PrimExpr]
+        The offset in terms of number of dtype elements (including lanes).
+
+    scope : str
+        The optional storage scope of buffer data pointer.
+
+    align : int
+        The alignment requirement of data pointer in bytes.
+
+    offset_factor : int
+        The factor of elem_offset field.
+
+    buffer_type : str
+        The buffer type.
+
+    axis_separators : Optional[List[int]]
+        The separators between input axes when generating flattened output axes.
+
+    Returns
+    -------
+    res : Buffer
+        The allocated buffer.
+    """
+    shape = (shape,) if isinstance(shape, (PrimExpr, Integral)) else shape
+    if strides is None:
+        strides = []
+    return _ffi_api.AllocBuffer(  # type: ignore[attr-defined] # pylint: disable=no-member
+        shape,
+        dtype,
+        data,
+        strides,
+        elem_offset,
+        scope,
+        align,
+        offset_factor,
+        buffer_type,
+        axis_separators,
+    )
def _as_range(dom: Union[Range, List[PrimExpr]]) -> Range:
@@ -387,7 +536,7 @@ class axis: # pylint: disable=invalid-name
res : Var
The iteration variable.
"""
- return _ffi_api.AxisSpatial( # pylint: disable=no-member # type:
ignore
+ return _ffi_api.AxisSpatial( # type: ignore[attr-defined] # pylint:
disable=no-member
_as_range(dom), binding, dtype
)
@@ -413,7 +562,7 @@ class axis: # pylint: disable=invalid-name
res : Var
The iteration variable.
"""
- return _ffi_api.AxisReduce( # pylint: disable=no-member # type: ignore
+ return _ffi_api.AxisReduce( # type: ignore[attr-defined] # pylint:
disable=no-member
_as_range(dom), binding, dtype
)
@@ -439,7 +588,7 @@ class axis: # pylint: disable=invalid-name
res : Var
The iteration variable.
"""
- return _ffi_api.AxisScan( # pylint: disable=no-member # type: ignore
+ return _ffi_api.AxisScan( # type: ignore[attr-defined] # pylint:
disable=no-member
_as_range(dom), binding, dtype
)
@@ -465,7 +614,7 @@ class axis: # pylint: disable=invalid-name
res : Var
The iteration variable.
"""
- return _ffi_api.AxisOpaque( # pylint: disable=no-member # type: ignore
+ return _ffi_api.AxisOpaque( # type: ignore[attr-defined] # pylint:
disable=no-member
_as_range(dom), binding, dtype
)
@@ -489,7 +638,7 @@ class axis: # pylint: disable=invalid-name
res : Var
The iteration variables.
"""
- iter_vars = _ffi_api.AxisRemap( # pylint: disable=no-member # type:
ignore
+ iter_vars = _ffi_api.AxisRemap( # type: ignore[attr-defined] #
pylint: disable=no-member
kinds, bindings, dtype
)
return iter_vars[0] if len(iter_vars) == 1 else iter_vars
@@ -522,7 +671,7 @@ def serial(
if stop is None:
stop = start
start = 0
- return _ffi_api.Serial(start, stop, annotations) # pylint:
disable=no-member # type: ignore
+ return _ffi_api.Serial(start, stop, annotations) # type:
ignore[attr-defined] # pylint: disable=no-member
def parallel(
@@ -549,7 +698,7 @@ def parallel(
if stop is None:
stop = start
start = 0
- return _ffi_api.Parallel(start, stop, annotations) # pylint:
disable=no-member # type: ignore
+ return _ffi_api.Parallel(start, stop, annotations) # type:
ignore[attr-defined] # pylint: disable=no-member
def vectorized(
@@ -576,7 +725,7 @@ def vectorized(
if stop is None:
stop = start
start = 0
- return _ffi_api.Vectorized(start, stop, annotations) # pylint:
disable=no-member # type: ignore
+ return _ffi_api.Vectorized(start, stop, annotations) # type:
ignore[attr-defined] # pylint: disable=no-member
def unroll(
@@ -603,7 +752,7 @@ def unroll(
if stop is None:
stop = start
start = 0
- return _ffi_api.Unroll(start, stop, annotations) # pylint:
disable=no-member # type: ignore
+ return _ffi_api.Unroll(start, stop, annotations) # type:
ignore[attr-defined] # pylint: disable=no-member
def thread_binding(
@@ -643,7 +792,7 @@ def thread_binding(
elif stop is None:
stop = start
start = 0
- return _ffi_api.ThreadBinding( # pylint: disable=no-member # type: ignore
+ return _ffi_api.ThreadBinding( # type: ignore[attr-defined] # pylint:
disable=no-member
start, stop, thread, annotations
)
@@ -661,7 +810,7 @@ def grid(*extents: PrimExpr) -> frame.ForFrame:
res : frame.ForFrame
The ForFrame.
"""
- return _ffi_api.Grid(extents) # pylint: disable=no-member # type: ignore
+ return _ffi_api.Grid(extents) # type: ignore[attr-defined] # pylint:
disable=no-member
def evaluate(value: PrimExpr) -> None:
@@ -674,7 +823,7 @@ def evaluate(value: PrimExpr) -> None:
"""
if isinstance(value, str):
value = StringImm(value)
- return _ffi_api.Evaluate(value) # pylint: disable=no-member # type: ignore
+ return _ffi_api.Evaluate(value) # type: ignore[attr-defined] # pylint:
disable=no-member
def int8(expr: Optional[PrimExpr] = None) -> PrimExpr:
@@ -690,7 +839,7 @@ def int8(expr: Optional[PrimExpr] = None) -> PrimExpr:
res : PrimExpr
The new tir.Var with type int8 or casted expression with type int8.
"""
- return _ffi_api.Int8(expr) # pylint: disable=no-member # type: ignore
+ return _ffi_api.Int8(expr) # type: ignore[attr-defined] # pylint:
disable=no-member
def int16(expr: Optional[PrimExpr] = None) -> PrimExpr:
@@ -706,7 +855,7 @@ def int16(expr: Optional[PrimExpr] = None) -> PrimExpr:
res : PrimExpr
The new tir.Var with type int16 or casted expression with type int16.
"""
- return _ffi_api.Int16(expr) # pylint: disable=no-member # type: ignore
+ return _ffi_api.Int16(expr) # type: ignore[attr-defined] # pylint:
disable=no-member
def int32(expr: Optional[PrimExpr] = None) -> PrimExpr:
@@ -722,7 +871,7 @@ def int32(expr: Optional[PrimExpr] = None) -> PrimExpr:
res : PrimExpr
The new tir.Var with type int32 or casted expression with type int32.
"""
- return _ffi_api.Int32(expr) # pylint: disable=no-member # type: ignore
+ return _ffi_api.Int32(expr) # type: ignore[attr-defined] # pylint:
disable=no-member
def int64(expr: Optional[PrimExpr] = None) -> PrimExpr:
@@ -738,7 +887,7 @@ def int64(expr: Optional[PrimExpr] = None) -> PrimExpr:
res : PrimExpr
The new tir.Var with type int64 or casted expression with type int64.
"""
- return _ffi_api.Int64(expr) # pylint: disable=no-member # type: ignore
+ return _ffi_api.Int64(expr) # type: ignore[attr-defined] # pylint:
disable=no-member
def uint8(expr: Optional[PrimExpr] = None) -> PrimExpr:
@@ -754,7 +903,7 @@ def uint8(expr: Optional[PrimExpr] = None) -> PrimExpr:
res : PrimExpr
The new tir.Var with type uint8 or casted expression with type uint8.
"""
- return _ffi_api.UInt8(expr) # pylint: disable=no-member # type: ignore
+ return _ffi_api.UInt8(expr) # type: ignore[attr-defined] # pylint:
disable=no-member
def uint16(expr: Optional[PrimExpr] = None) -> PrimExpr:
@@ -770,7 +919,7 @@ def uint16(expr: Optional[PrimExpr] = None) -> PrimExpr:
res : PrimExpr
The new tir.Var with type uint16 or casted expression with type uint16.
"""
- return _ffi_api.UInt16(expr) # pylint: disable=no-member # type: ignore
+ return _ffi_api.UInt16(expr) # type: ignore[attr-defined] # pylint:
disable=no-member
def uint32(expr: Optional[PrimExpr] = None) -> PrimExpr:
@@ -786,7 +935,7 @@ def uint32(expr: Optional[PrimExpr] = None) -> PrimExpr:
res : PrimExpr
The new tir.Var with type uint32 or casted expression with type uint32.
"""
- return _ffi_api.UInt32(expr) # pylint: disable=no-member # type: ignore
+ return _ffi_api.UInt32(expr) # type: ignore[attr-defined] # pylint:
disable=no-member
def uint64(expr: Optional[PrimExpr] = None) -> PrimExpr:
@@ -802,7 +951,7 @@ def uint64(expr: Optional[PrimExpr] = None) -> PrimExpr:
res : PrimExpr
The new tir.Var with type uint64 or casted expression with type uint64.
"""
- return _ffi_api.UInt64(expr) # pylint: disable=no-member # type: ignore
+ return _ffi_api.UInt64(expr) # type: ignore[attr-defined] # pylint:
disable=no-member
def float8(expr: Optional[PrimExpr] = None) -> PrimExpr:
@@ -818,7 +967,7 @@ def float8(expr: Optional[PrimExpr] = None) -> PrimExpr:
res : PrimExpr
The new tir.Var with type float8 or casted expression with type float8.
"""
- return _ffi_api.Float8(expr) # pylint: disable=no-member # type: ignore
+ return _ffi_api.Float8(expr) # type: ignore[attr-defined] # pylint:
disable=no-member
def float16(expr: Optional[PrimExpr] = None) -> PrimExpr:
@@ -834,7 +983,7 @@ def float16(expr: Optional[PrimExpr] = None) -> PrimExpr:
res : PrimExpr
The new tir.Var with type float16 or casted expression with type
float16.
"""
- return _ffi_api.Float16(expr) # pylint: disable=no-member # type: ignore
+ return _ffi_api.Float16(expr) # type: ignore[attr-defined] # pylint:
disable=no-member
def float32(expr: Optional[PrimExpr] = None) -> PrimExpr:
@@ -850,7 +999,7 @@ def float32(expr: Optional[PrimExpr] = None) -> PrimExpr:
res : PrimExpr
The new tir.Var with type float32 or casted expression with type
float32.
"""
- return _ffi_api.Float32(expr) # pylint: disable=no-member # type: ignore
+ return _ffi_api.Float32(expr) # type: ignore[attr-defined] # pylint:
disable=no-member
def float64(expr: Optional[PrimExpr] = None) -> PrimExpr:
@@ -866,7 +1015,7 @@ def float64(expr: Optional[PrimExpr] = None) -> PrimExpr:
res : PrimExpr
The new tir.Var with type float64 or casted expression with type
float64.
"""
- return _ffi_api.Float64(expr) # pylint: disable=no-member # type: ignore
+ return _ffi_api.Float64(expr) # type: ignore[attr-defined] # pylint:
disable=no-member
def int32x4(expr: Optional[PrimExpr] = None) -> PrimExpr:
@@ -882,7 +1031,7 @@ def int32x4(expr: Optional[PrimExpr] = None) -> PrimExpr:
res : PrimExpr
The new tir.Var with type int32x4 or casted expression with type
int32x4.
"""
- return _ffi_api.Int32x4(expr) # pylint: disable=no-member # type: ignore
+ return _ffi_api.Int32x4(expr) # type: ignore[attr-defined] # pylint:
disable=no-member
def int32x8(expr: Optional[PrimExpr] = None) -> PrimExpr:
@@ -898,7 +1047,7 @@ def int32x8(expr: Optional[PrimExpr] = None) -> PrimExpr:
res : PrimExpr
The new tir.Var with type int32x8 or casted expression with type
int32x8.
"""
- return _ffi_api.Int32x8(expr) # pylint: disable=no-member # type: ignore
+ return _ffi_api.Int32x8(expr) # type: ignore[attr-defined] # pylint:
disable=no-member
def int32x16(expr: Optional[PrimExpr] = None) -> PrimExpr:
@@ -914,7 +1063,7 @@ def int32x16(expr: Optional[PrimExpr] = None) -> PrimExpr:
res : PrimExpr
The new tir.Var with type int32x16 or casted expression with type
int32x16.
"""
- return _ffi_api.Int32x16(expr) # pylint: disable=no-member # type: ignore
+ return _ffi_api.Int32x16(expr) # type: ignore[attr-defined] # pylint:
disable=no-member
def boolean(expr: Optional[PrimExpr] = None) -> PrimExpr:
@@ -930,7 +1079,7 @@ def boolean(expr: Optional[PrimExpr] = None) -> PrimExpr:
res : PrimExpr
The new tir.Var with type boolean or casted expression with type
boolean.
"""
- return _ffi_api.Boolean(expr) # pylint: disable=no-member # type: ignore
+ return _ffi_api.Boolean(expr) # type: ignore[attr-defined] # pylint:
disable=no-member
def handle(expr: Optional[PrimExpr] = None) -> PrimExpr:
@@ -946,7 +1095,7 @@ def handle(expr: Optional[PrimExpr] = None) -> PrimExpr:
res : PrimExpr
The new tir.Var with type handle or casted expression with type handle.
"""
- return _ffi_api.Handle(expr) # pylint: disable=no-member # type: ignore
+ return _ffi_api.Handle(expr) # type: ignore[attr-defined] # pylint:
disable=no-member
def void(expr: Optional[PrimExpr] = None) -> PrimExpr:
@@ -962,7 +1111,7 @@ def void(expr: Optional[PrimExpr] = None) -> PrimExpr:
res : PrimExpr
The new tir.Var with type void or casted expression with type void.
"""
- return _ffi_api.Void(expr) # pylint: disable=no-member # type: ignore
+ return _ffi_api.Void(expr) # type: ignore[attr-defined] # pylint:
disable=no-member
def var(dtype, name="") -> Var:
@@ -981,7 +1130,7 @@ def var(dtype, name="") -> Var:
res : Var
The result tir.Var.
"""
- return Var(name, dtype) # pylint: disable=no-member # type: ignore
+ return Var(name, dtype) # pylint: disable=no-member
# pylint: enable=invalid-name
@@ -997,6 +1146,12 @@ __all__ = [
"match_buffer",
"preflattened_buffer",
"block",
+ "init",
+ "where",
+ "reads",
+ "writes",
+ "block_attr",
+ "alloc_buffer",
"axis",
"serial",
"parallel",
diff --git a/src/script/ir_builder/tir/frame.cc
b/src/script/ir_builder/tir/frame.cc
index e54bf75eef..8b8b2a4d80 100644
--- a/src/script/ir_builder/tir/frame.cc
+++ b/src/script/ir_builder/tir/frame.cc
@@ -73,6 +73,20 @@ void BlockFrameNode::ExitWithScope() {
}
}
+void BlockInitFrameNode::EnterWithScope() {
+  // A block may carry at most one init statement; reject a second `T.init()`.
+  const BlockFrame block_frame = FindBlockFrame("T.init");
+  if (block_frame->init.defined()) {
+    LOG(FATAL) << "ValueError: Duplicate block init declaration";
+  }
+  TIRFrameNode::EnterWithScope();
+}
+
+void BlockInitFrameNode::ExitWithScope() {
+  TIRFrameNode::ExitWithScope();
+  // Hand the statements collected in this frame to the block as its init.
+  FindBlockFrame("T.init")->init = AsStmt(stmts);
+}
+
void ForFrameNode::ExitWithScope() {
TIRFrameNode::ExitWithScope();
AddToParent(this->f_make_for_loop(vars, doms, AsStmt(stmts)));
@@ -81,6 +95,7 @@ void ForFrameNode::ExitWithScope() {
TVM_REGISTER_NODE_TYPE(TIRFrameNode);
TVM_REGISTER_NODE_TYPE(PrimFuncFrameNode);
TVM_REGISTER_NODE_TYPE(BlockFrameNode);
+TVM_REGISTER_NODE_TYPE(BlockInitFrameNode);
TVM_REGISTER_NODE_TYPE(ForFrameNode);
} // namespace tir
diff --git a/src/script/ir_builder/tir/ir.cc b/src/script/ir_builder/tir/ir.cc
index 5013e32172..75e7592626 100644
--- a/src/script/ir_builder/tir/ir.cc
+++ b/src/script/ir_builder/tir/ir.cc
@@ -173,6 +173,80 @@ BlockFrame Block(String name, bool no_realize) {
return BlockFrame(n);
}
+/*! \brief Create a new `T.init()` frame; its body becomes the block's init on exit. */
+BlockInitFrame Init() { return BlockInitFrame(make_object<BlockInitFrameNode>()); }
+
+/*! \brief Set the predicate of the current block; a second declaration is an error. */
+void Where(PrimExpr predicate) {
+  BlockFrame frame = FindBlockFrame("T.where");
+  if (frame->predicate.defined()) {
+    LOG(FATAL) << "ValueError: Duplicate block predicate declaration, previous one is "
+               << frame->predicate;
+  }
+  frame->predicate = predicate;
+}
+
+/*!
+ * \brief Convert user-provided read/write slices into BufferRegions.
+ * \param buffer_slices Objects that must each be a BufferRegion or a BufferLoad.
+ * \param kind "reads" or "writes"; only used in the error message.
+ * \return The converted regions, in the original order.
+ */
+static Array<tvm::tir::BufferRegion> BufferRegionsFromSlices(
+    const Array<ObjectRef>& buffer_slices, const char* kind) {
+  using namespace tvm::tir;
+  Array<BufferRegion> regions;
+  for (const ObjectRef& obj : buffer_slices) {
+    if (const auto* buffer_region = obj.as<BufferRegionNode>()) {
+      regions.push_back(GetRef<BufferRegion>(buffer_region));
+    } else if (const auto* buffer_load = obj.as<BufferLoadNode>()) {
+      regions.push_back(BufferRegionFromLoad(GetRef<BufferLoad>(buffer_load)));
+    } else {
+      // Report the offending type so the user can locate the bad argument.
+      LOG(FATAL) << "Invalid type for buffer " << kind
+                 << ": expected BufferRegion or BufferLoad, but got "
+                 << (obj.defined() ? obj->GetTypeKey() : std::string("None"));
+    }
+  }
+  return regions;
+}
+
+/*! \brief Set the read regions of the current block; a second declaration is an error. */
+void Reads(Array<ObjectRef> buffer_slices) {
+  BlockFrame frame = FindBlockFrame("T.reads");
+  if (frame->reads.defined()) {
+    LOG(FATAL) << "ValueError: Duplicate read region declaration, previous one is " << frame->reads;
+  }
+  frame->reads = BufferRegionsFromSlices(buffer_slices, "reads");
+}
+
+/*! \brief Set the write regions of the current block; a second declaration is an error. */
+void Writes(Array<ObjectRef> buffer_slices) {
+  BlockFrame frame = FindBlockFrame("T.writes");
+  if (frame->writes.defined()) {
+    LOG(FATAL) << "ValueError: Duplicate write region declaration, previous one is "
+               << frame->writes;
+  }
+  frame->writes = BufferRegionsFromSlices(buffer_slices, "writes");
+}
+
+/*! \brief Set the annotations of the current block; a second declaration is an error. */
+void BlockAttrs(Map<String, ObjectRef> attrs) {
+  BlockFrame frame = FindBlockFrame("T.block_attr");
+  if (frame->annotations.defined()) {
+    LOG(FATAL) << "ValueError: Duplicate block annotations, previous one is " << frame->annotations;
+  }
+  frame->annotations = attrs;
+}
+
+/*!
+ * \brief Declare a buffer and register it as an allocation of the current block frame,
+ *        falling back to the root allocations of the current PrimFunc frame.
+ */
+Buffer AllocBuffer(Array<PrimExpr> shape, DataType dtype, Optional<Var> data,
+                   Array<PrimExpr> strides, PrimExpr elem_offset, String storage_scope, int align,
+                   int offset_factor, String buffer_type_str, Array<IntImm> axis_separators) {
+  Buffer buffer = BufferDecl(shape, dtype, "", data, strides, elem_offset, storage_scope, align,
+                             offset_factor, buffer_type_str, axis_separators);
+  IRBuilder builder = IRBuilder::Current();
+  if (Optional<BlockFrame> block_frame = builder->GetLastFrame<BlockFrame>()) {
+    block_frame.value()->alloc_buffers.push_back(buffer);
+  } else if (Optional<PrimFuncFrame> prim_func_frame = builder->GetLastFrame<PrimFuncFrame>()) {
+    prim_func_frame.value()->root_alloc_buffers.push_back(buffer);
+  } else {
+    LOG(FATAL) << "ValueError: Block frame or PrimFunc frame not found. Please ensure "
+                  "'T.alloc_buffer' is called under T.block() or T.prim_func()";
+  }
+  return buffer;
+}
namespace axis {
IterVar PushBlockVar(IterVar iter_var, PrimExpr binding) {
@@ -383,6 +457,12 @@
TVM_REGISTER_GLOBAL("script.ir_builder.tir.MatchBuffer").set_body_typed(MatchBuf
TVM_REGISTER_GLOBAL("script.ir_builder.tir.PreflattenedBuffer").set_body_typed(PreflattenedBuffer);
TVM_REGISTER_GLOBAL("script.ir_builder.tir.Block").set_body_typed(Block);
+// Block-related builder methods, exposed to Python as script.ir_builder.tir.<Name>.
+TVM_REGISTER_GLOBAL("script.ir_builder.tir.Init").set_body_typed(Init);
+TVM_REGISTER_GLOBAL("script.ir_builder.tir.Where").set_body_typed(Where);
+TVM_REGISTER_GLOBAL("script.ir_builder.tir.Reads").set_body_typed(Reads);
+TVM_REGISTER_GLOBAL("script.ir_builder.tir.Writes").set_body_typed(Writes);
+TVM_REGISTER_GLOBAL("script.ir_builder.tir.BlockAttrs").set_body_typed(BlockAttrs);
+TVM_REGISTER_GLOBAL("script.ir_builder.tir.AllocBuffer").set_body_typed(AllocBuffer);
TVM_REGISTER_GLOBAL("script.ir_builder.tir.AxisSpatial").set_body_typed(axis::Spatial);
TVM_REGISTER_GLOBAL("script.ir_builder.tir.AxisReduce").set_body_typed(axis::Reduce);
diff --git a/tests/python/unittest/test_tvmscript_ir_builder_tir.py
b/tests/python/unittest/test_tvmscript_ir_builder_tir.py
index d893ebc545..a5d8c10680 100644
--- a/tests/python/unittest/test_tvmscript_ir_builder_tir.py
+++ b/tests/python/unittest/test_tvmscript_ir_builder_tir.py
@@ -87,7 +87,7 @@ def test_ir_builder_tir_primfunc_complete():
assert_structural_equal(prim_func_actual, prim_func_expected,
map_free_vars=True)
-def test_ir_builder_tir_block():
+def test_ir_builder_tir_block_base():
with IRBuilder() as ib:
with T.block("block"):
T.evaluate(0)
@@ -114,6 +114,54 @@ def test_ir_builder_tir_block():
assert_structural_equal(block_realize_actual, block_realize_expected,
map_free_vars=True)
+def test_ir_builder_tir_block_complete():
+    """Build a block using every block-level builder method and compare it with a
+    hand-constructed BlockRealize."""
+    with IRBuilder() as ib:
+        a = T.var("int64", "a")
+        b = T.buffer_decl((128, 128), "float32")
+        c = T.buffer_decl((128, 128), "float32")
+        d = T.var("int32", "d")
+        e = T.buffer_decl((128, 128), "float32")
+        f = T.var("int32", "f")
+        with T.block("block"):
+            T.where(a > 1)
+            T.reads(b[0:16, 0:16])
+            T.writes(c[d:128, d:128])
+            T.block_attr({"key": "value"})
+            T.alloc_buffer((128, 128), "float32")
+            T.match_buffer(e[0:32, 0:32], (32, 32), "float32")
+            T.axis.spatial(128, f)
+            T.evaluate(0)
+    # the block generated by IRBuilder
+    block_realize_actual = ib.get()
+
+    # the expected block
+    var_a = tir.Var("a", "int64")
+    buffer_b = tir.decl_buffer((128, 128), "float32", name="b")
+    buffer_c = tir.decl_buffer((128, 128), "float32", name="c")
+    var_d = tir.Var("d", "int32")
+    buffer_e = tir.decl_buffer((128, 128), "float32", name="e")
+    var_f = tir.Var("f", "int32")
+    block_expected = tir.Block(
+        iter_vars=[tir.IterVar((0, 128), tir.Var("", "int32"), iter_type=tir.IterVar.DataPar)],
+        reads=[buffer_b[0:16, 0:16]],
+        writes=[buffer_c[var_d:128, var_d:128]],
+        name_hint="block",
+        body=tir.Evaluate(0),
+        alloc_buffers=[tir.decl_buffer((128, 128), "float32")],
+        match_buffers=[
+            tir.MatchBufferRegion(tir.decl_buffer((32, 32), "float32"), buffer_e[0:32, 0:32])
+        ],
+        annotations={"key": "value"},
+    )
+    block_realize_expected = tir.BlockRealize(
+        iter_values=[var_f],
+        predicate=var_a > 1,
+        block=block_expected,
+    )
+    # Check if the generated ir is expected
+    assert_structural_equal(block_realize_actual, block_realize_expected, map_free_vars=True)
+
+
def test_ir_builder_tir_axis():
with IRBuilder() as ib:
a = T.var("int32", "a")
diff --git a/tests/scripts/task_mypy.sh b/tests/scripts/task_mypy.sh
index f165adfe1b..c3e5d50b3e 100755
--- a/tests/scripts/task_mypy.sh
+++ b/tests/scripts/task_mypy.sh
@@ -47,3 +47,6 @@ mypy --disallow-untyped-defs
python/tvm/relay/op/contrib/tensorrt.py
#TODO(@mikepapadim): This is failing atm
# echo "Checking MyPy Type defs in the tvm.relay.backend.contrib.ethosu
package."
# mypy --check-untyped-defs python/tvm/relay/backend/contrib/ethosu/
+
+# Type-check the TVMScript IRBuilder Python package; --check-untyped-defs makes
+# mypy also check the bodies of functions that have no type annotations.
+echo "Checking MyPy Type defs in the tvmscript IRBuilder package."
+mypy --check-untyped-defs python/tvm/script/ir_builder