This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 41b65a3144 [TVMScript] IRBuilder methods for `Block` (#12815)
41b65a3144 is described below

commit 41b65a3144595afb04228be1334dc77c08d11ba7
Author: Yaxing Cai <[email protected]>
AuthorDate: Fri Sep 16 18:11:06 2022 -0700

    [TVMScript] IRBuilder methods for `Block` (#12815)
    
    This PR introduces remaining IRBuilder methods for `Block`.
    
    Co-authored-by: yongwww <[email protected]>
---
 include/tvm/script/ir_builder/tir/frame.h          |  35 +++
 include/tvm/script/ir_builder/tir/ir.h             |  49 +++++
 python/tvm/script/ir_builder/base.py               |  18 +-
 python/tvm/script/ir_builder/ir/ir.py              |   2 +-
 python/tvm/script/ir_builder/tir/frame.py          |   7 +-
 python/tvm/script/ir_builder/tir/ir.py             | 235 +++++++++++++++++----
 src/script/ir_builder/tir/frame.cc                 |  15 ++
 src/script/ir_builder/tir/ir.cc                    |  80 +++++++
 .../unittest/test_tvmscript_ir_builder_tir.py      |  50 ++++-
 tests/scripts/task_mypy.sh                         |   3 +
 10 files changed, 442 insertions(+), 52 deletions(-)

diff --git a/include/tvm/script/ir_builder/tir/frame.h 
b/include/tvm/script/ir_builder/tir/frame.h
index 2902b982d5..c76b400d96 100644
--- a/include/tvm/script/ir_builder/tir/frame.h
+++ b/include/tvm/script/ir_builder/tir/frame.h
@@ -187,6 +187,41 @@ class BlockFrame : public TIRFrame {
   TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(BlockFrame, TIRFrame, 
BlockFrameNode);
 };
 
+/*!
+ * \brief A frame that represents the block initialization statement.
+ *
+ * \sa BlockInitFrame
+ */
+class BlockInitFrameNode : public TIRFrameNode {
+ public:
+  void VisitAttrs(tvm::AttrVisitor* v) { TIRFrameNode::VisitAttrs(v); }
+
+  static constexpr const char* _type_key = 
"script.ir_builder.tir.BlockInitFrame";
+  TVM_DECLARE_FINAL_OBJECT_INFO(BlockInitFrameNode, TIRFrameNode);
+
+ public:
+  /*!
+   * \brief The method called when entering RAII scope.
+   * \sa tvm::support::With
+   */
+  void EnterWithScope() final;
+  /*!
+   * \brief The method called when exiting RAII scope.
+   * \sa tvm::support::With
+   */
+  void ExitWithScope() final;
+};
+
+/*!
+ * \brief Managed reference to BlockInitFrameNode.
+ *
+ * \sa BlockInitFrameNode
+ */
+class BlockInitFrame : public TIRFrame {
+ public:
+  TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(BlockInitFrame, TIRFrame, 
BlockInitFrameNode);
+};
+
 /*!
  * \brief A frame that represents the for loop.
  *
diff --git a/include/tvm/script/ir_builder/tir/ir.h 
b/include/tvm/script/ir_builder/tir/ir.h
index 037606253a..191887648d 100644
--- a/include/tvm/script/ir_builder/tir/ir.h
+++ b/include/tvm/script/ir_builder/tir/ir.h
@@ -141,6 +141,55 @@ void PreflattenedBuffer(Buffer postflattened_buffer, 
Array<PrimExpr> shape,
  */
 BlockFrame Block(String name, bool no_realize = false);
 
+/*!
+ * \brief The block initialization statement.
+ * \return The BlockInitFrame.
+ */
+BlockInitFrame Init();
+
+/*!
+ * \brief The block predicate statement.
+ * \param predicate The predicate condition.
+ */
+void Where(PrimExpr predicate);
+
+/*!
+ * \brief The block buffer region reading statement.
+ * \param buffer_slices The array of buffer regions to read.
+ */
+void Reads(Array<ObjectRef> buffer_slices);
+
+/*!
+ * \brief The block buffer region writing statement.
+ * \param buffer_slices The array of buffer regions to write.
+ */
+void Writes(Array<ObjectRef> buffer_slices);
+
+/*!
+ * \brief The block annotation statement.
+ * \param attrs The annotation of the block.
+ */
+void BlockAttrs(Map<String, ObjectRef> attrs);
+
+/*!
+ * \brief The buffer allocation function.
+ * \param shape The shape of the buffer prior to flattening.
+ * \param dtype The data type in the content of the buffer.
+ * \param data The pointer to the head of the data.
+ * \param strides The strides of each dimension.
+ * \param elem_offset The offset in terms of number of dtype elements 
(including lanes).
+ * \param storage_scope The optional storage scope of buffer data pointer.
+ * \param align The alignment requirement of data pointer in bytes.
+ * \param offset_factor The factor of elem_offset field.
+ * \param buffer_type The buffer type.
+ * \param axis_separators The separators between input axes when generating 
flattened output axes.
+ * \return The allocated buffer.
+ */
+Buffer AllocBuffer(Array<PrimExpr> shape, DataType dtype = DataType::Float(32),
+                   Optional<Var> data = NullOpt, Array<PrimExpr> strides = {},
+                   PrimExpr elem_offset = PrimExpr(), String storage_scope = 
"", int align = -1,
+                   int offset_factor = 0, String buffer_type = "default",
+                   Array<IntImm> axis_separators = {});
 namespace axis {
 
 /*!
diff --git a/python/tvm/script/ir_builder/base.py 
b/python/tvm/script/ir_builder/base.py
index 767fa8bf25..7aa33ee49c 100644
--- a/python/tvm/script/ir_builder/base.py
+++ b/python/tvm/script/ir_builder/base.py
@@ -61,11 +61,11 @@ class IRBuilderFrame(_Object):
     """
 
     def __enter__(self) -> "IRBuilderFrame":
-        _ffi_api.IRBuilderFrameEnter(self)  # pylint: disable=no-member # 
type: ignore
+        _ffi_api.IRBuilderFrameEnter(self)  # type: ignore[attr-defined] # 
pylint: disable=no-member
         return self
 
     def __exit__(self, ptype, value, trace) -> None:  # pylint: 
disable=unused-argument
-        _ffi_api.IRBuilderFrameExit(self)  # pylint: disable=no-member # type: 
ignore
+        _ffi_api.IRBuilderFrameExit(self)  # type: ignore[attr-defined] # 
pylint: disable=no-member
 
     def add_callback(self, callback: Callable[[], None]) -> None:
         """Add a callback method invoked when exiting the with-scope.
@@ -75,7 +75,7 @@ class IRBuilderFrame(_Object):
         callback : Callable[[], None]
             The callback method to be invoked.
         """
-        _ffi_api.IRBuilderFrameAddCallback(  # pylint: disable=no-member # 
type: ignore
+        _ffi_api.IRBuilderFrameAddCallback(  # type: ignore[attr-defined] # 
pylint: disable=no-member
             self, callback
         )
 
@@ -104,7 +104,7 @@ class IRBuilder(_Object):
     def __init__(self) -> None:
         """Construct an IRBuilder."""
         self.__init_handle_by_constructor__(
-            _ffi_api.IRBuilder  # pylint: disable=no-member # type: ignore
+            _ffi_api.IRBuilder  # type: ignore[attr-defined] # pylint: 
disable=no-member
         )
 
     def __enter__(self) -> "IRBuilder":
@@ -119,11 +119,11 @@ class IRBuilder(_Object):
         with IRBuilder() as builder:
             assert IRBuilder.current() == builder
         """
-        _ffi_api.IRBuilderEnter(self)  # pylint: disable=no-member # type: 
ignore
+        _ffi_api.IRBuilderEnter(self)  # type: ignore[attr-defined] # pylint: 
disable=no-member
         return self
 
     def __exit__(self, ptype, value, trace) -> None:  # pylint: 
disable=unused-argument
-        _ffi_api.IRBuilderExit(self)  # pylint: disable=no-member # type: 
ignore
+        _ffi_api.IRBuilderExit(self)  # type: ignore[attr-defined] # pylint: 
disable=no-member
 
     @staticmethod
     def current() -> "IRBuilder":
@@ -134,11 +134,11 @@ class IRBuilder(_Object):
         builder : IRBuilder
             The current IRBuilder.
         """
-        return _ffi_api.IRBuilderCurrent()  # pylint: disable=no-member # 
type: ignore
+        return _ffi_api.IRBuilderCurrent()  # type: ignore[attr-defined] # 
pylint: disable=no-member
 
     def get(self) -> _Object:
         """Get the constructed IR."""
-        return _ffi_api.IRBuilderGet(self)  # pylint: disable=no-member # 
type: ignore
+        return _ffi_api.IRBuilderGet(self)  # type: ignore[attr-defined] # 
pylint: disable=no-member
 
     @staticmethod
     def name(s: str, v: Any) -> Any:
@@ -156,7 +156,7 @@ class IRBuilder(_Object):
         v : Any
             The same object with the name set.
         """
-        return _ffi_api.IRBuilderName(s, v)  # pylint: disable=no-member # 
type: ignore
+        return _ffi_api.IRBuilderName(s, v)  # type: ignore[attr-defined] # 
pylint: disable=no-member
 
     @staticmethod
     def name_many(  # pylint: disable=invalid-name
diff --git a/python/tvm/script/ir_builder/ir/ir.py 
b/python/tvm/script/ir_builder/ir/ir.py
index df92036435..213180463c 100644
--- a/python/tvm/script/ir_builder/ir/ir.py
+++ b/python/tvm/script/ir_builder/ir/ir.py
@@ -21,4 +21,4 @@ from .frame import IRModuleFrame
 
 
 def ir_module() -> IRModuleFrame:
-    return _ffi_api.IRModule()  # pylint: disable=no-member # type: ignore
+    return _ffi_api.IRModule()  # type: ignore[attr-defined] # pylint: 
disable=no-member
diff --git a/python/tvm/script/ir_builder/tir/frame.py 
b/python/tvm/script/ir_builder/tir/frame.py
index 75bb0231ae..2ad08f3516 100644
--- a/python/tvm/script/ir_builder/tir/frame.py
+++ b/python/tvm/script/ir_builder/tir/frame.py
@@ -38,8 +38,13 @@ class BlockFrame(TIRFrame):
     ...
 
 
+@_register_object("script.ir_builder.tir.BlockInitFrame")
+class BlockInitFrame(TIRFrame):
+    ...
+
+
 @_register_object("script.ir_builder.tir.ForFrame")
 class ForFrame(TIRFrame):
-    def __enter__(self) -> Union[Var, List[Var]]:
+    def __enter__(self) -> Union[Var, List[Var]]:  # type: ignore[override]
         super().__enter__()
         return self.vars if len(self.vars) > 1 else self.vars[0]
diff --git a/python/tvm/script/ir_builder/tir/ir.py 
b/python/tvm/script/ir_builder/tir/ir.py
index 40cd99c744..d1dc1c8960 100644
--- a/python/tvm/script/ir_builder/tir/ir.py
+++ b/python/tvm/script/ir_builder/tir/ir.py
@@ -25,6 +25,7 @@ from tvm.tir import (
     Buffer,
     BufferLoad,
     BufferRegion,
+    IntImm,
     PrimExpr,
     StringImm,
     Var,
@@ -85,7 +86,7 @@ def buffer_decl(
         The declared buffer.
     """
     shape = (shape,) if isinstance(shape, (PrimExpr, Integral)) else shape
-    return _ffi_api.BufferDecl(  # pylint: disable=no-member # type: ignore
+    return _ffi_api.BufferDecl(  # type: ignore[attr-defined] # pylint: 
disable=no-member
         shape,
         dtype,
         "",
@@ -108,7 +109,7 @@ def prim_func() -> frame.PrimFuncFrame:
     res : frame.PrimFuncFrame
         The PrimFuncFrame.
     """
-    return _ffi_api.PrimFunc()  # pylint: disable=no-member # type: ignore
+    return _ffi_api.PrimFunc()  # type: ignore[attr-defined] # pylint: 
disable=no-member
 
 
 def arg(name: str, obj: Union[Var, Buffer]) -> Union[Var, Buffer]:
@@ -127,7 +128,7 @@ def arg(name: str, obj: Union[Var, Buffer]) -> Union[Var, 
Buffer]:
     res : Union[Var, Buffer]
         The argument.
     """
-    return _ffi_api.Arg(name, obj)  # pylint: disable=no-member # type: ignore
+    return _ffi_api.Arg(name, obj)  # type: ignore[attr-defined] # pylint: 
disable=no-member
 
 
 def func_name(name: str) -> None:
@@ -138,7 +139,7 @@ def func_name(name: str) -> None:
     name : str
         The name of the PrimFunc.
     """
-    _ffi_api.FuncName(name)  # pylint: disable=no-member # type: ignore
+    _ffi_api.FuncName(name)  # type: ignore[attr-defined] # pylint: 
disable=no-member
 
 
 def func_attr(attrs: Dict[str, Any]) -> None:
@@ -149,7 +150,7 @@ def func_attr(attrs: Dict[str, Any]) -> None:
     attrs : Dict[str, Any]
         The annotations of the PrimFunc.
     """
-    _ffi_api.FuncAttrs(attrs)  # pylint: disable=no-member # type: ignore
+    _ffi_api.FuncAttrs(attrs)  # type: ignore[attr-defined] # pylint: 
disable=no-member
 
 
 def func_ret(ret_type: Type) -> Type:
@@ -165,7 +166,7 @@ def func_ret(ret_type: Type) -> Type:
     res : Type
         The return type.
     """
-    return _ffi_api.FuncRet(ret_type)  # pylint: disable=no-member # type: 
ignore
+    return _ffi_api.FuncRet(ret_type)  # type: ignore[attr-defined] # pylint: 
disable=no-member
 
 
 def match_buffer(
@@ -242,7 +243,7 @@ def match_buffer(
     shape = (shape,) if isinstance(shape, (PrimExpr, Integral)) else shape
     if strides is None:
         strides = []
-    return _ffi_api.MatchBuffer(  # pylint: disable=no-member # type: ignore
+    return _ffi_api.MatchBuffer(  # type: ignore[attr-defined] # pylint: 
disable=no-member
         param,
         shape,
         dtype,
@@ -310,7 +311,7 @@ def preflattened_buffer(
     shape = (shape,) if isinstance(shape, (PrimExpr, Integral)) else shape
     if strides is None:
         strides = []
-    _ffi_api.PreflattenedBuffer(  # pylint: disable=no-member # type: ignore
+    _ffi_api.PreflattenedBuffer(  # type: ignore[attr-defined] # pylint: 
disable=no-member
         postflattened,
         shape,
         dtype,
@@ -341,7 +342,155 @@ def block(name: str = "", no_realize: bool = False) -> 
frame.BlockFrame:
     res : frame.BlockFrame
         The BlockFrame.
     """
-    return _ffi_api.Block(name, no_realize)  # pylint: disable=no-member # 
type: ignore
+    return _ffi_api.Block(name, no_realize)  # type: ignore[attr-defined] # 
pylint: disable=no-member
+
+
+def init() -> frame.BlockInitFrame:
+    """The block initialization statement.
+
+    Returns
+    -------
+    res : frame.BlockInitFrame
+        The BlockInitFrame.
+    """
+    return _ffi_api.Init()  # type: ignore[attr-defined] # pylint: 
disable=no-member
+
+
+def where(predicate: Union[PrimExpr, int]) -> None:
+    """The block predicate statement.
+
+    Parameters
+    ----------
+    predicate : Union[PrimExpr, Literal[0, 1]]
+        The predicate condition.
+    """
+    if isinstance(predicate, bool):
+        predicate = IntImm("bool", predicate)
+    if isinstance(predicate, int):
+        if predicate in [0, 1]:
+            predicate = IntImm("bool", predicate)
+        else:
+            raise ValueError(f"Invalid value for predicate: {predicate}")
+    _ffi_api.Where(predicate)  # type: ignore[attr-defined] # pylint: 
disable=no-member
+
+
+def reads(*buffer_slices: List[Union[BufferRegion, BufferLoad]]) -> None:
+    """The block buffer region reading statement.
+
+    Parameters
+    ----------
+    buffer_slices : List[Union[BufferRegion, BufferLoad]]
+        The array of buffer regions to read.
+    """
+    if len(buffer_slices) == 1:
+        if isinstance(buffer_slices[0], tuple):
+            buffer_slices = list(buffer_slices[0])
+        elif isinstance(buffer_slices[0], list):
+            buffer_slices = buffer_slices[0]  # type: ignore[assignment]
+        else:
+            buffer_slices = [buffer_slices[0]]
+    else:
+        buffer_slices = list(buffer_slices)  # type: ignore[assignment]
+    _ffi_api.Reads(buffer_slices)  # type: ignore[attr-defined] # pylint: 
disable=no-member
+
+
+def writes(*buffer_slices: List[Union[BufferRegion, BufferLoad]]) -> None:
+    """The block buffer region writing statement.
+
+    Parameters
+    ----------
+    buffer_slices : List[Union[BufferRegion, BufferLoad]]
+        The array of buffer regions to write.
+    """
+    if len(buffer_slices) == 1:
+        if isinstance(buffer_slices[0], tuple):
+            buffer_slices = list(buffer_slices[0])
+        elif isinstance(buffer_slices[0], list):
+            buffer_slices = buffer_slices[0]  # type: ignore[assignment]
+        else:
+            buffer_slices = [buffer_slices[0]]
+    else:
+        buffer_slices = list(buffer_slices)  # type: ignore[assignment]
+    _ffi_api.Writes(buffer_slices)  # type: ignore[attr-defined] # pylint: 
disable=no-member
+
+
+def block_attr(attrs: Dict[str, Any]) -> None:
+    """The block annotation statement.
+
+    Parameters
+    ----------
+    attrs : Dict[str, Any]
+        The annotation of the block.
+    """
+    return _ffi_api.BlockAttrs(attrs)  # type: ignore[attr-defined] # pylint: 
disable=no-member
+
+
+def alloc_buffer(
+    shape: Union[List[PrimExpr], Tuple[PrimExpr], PrimExpr, Integral],
+    dtype: str = "float32",
+    data: Var = None,
+    strides: List[PrimExpr] = None,
+    elem_offset: PrimExpr = None,
+    scope: str = "",
+    align: int = -1,
+    offset_factor: int = 0,
+    buffer_type: str = "default",
+    axis_separators: List[int] = None,
+) -> Buffer:
+    """The buffer allocation function.
+
+    Parameters
+    ----------
+    shape : Union[List[PrimExpr], Tuple[PrimExpr], PrimExpr, Integral]
+        The shape of the buffer prior to flattening.
+
+    dtype : str
+        The data type in the content of the buffer.
+
+    data : Var
+        The pointer to the head of the data.
+
+    strides : List[PrimExpr]
+        The strides of each dimension.
+
+    elem_offset : PrimExpr
+        The offset in terms of number of dtype elements (including lanes).
+
+    scope : str
+        The optional storage scope of buffer data pointer.
+
+    align : int
+        The alignment requirement of data pointer in bytes.
+
+    offset_factor : int
+        The factor of elem_offset field.
+
+    buffer_type : str
+        The buffer type.
+
+    axis_separators : List[int]
+        The separators between input axes when generating flattened output 
axes.
+
+    Returns
+    -------
+    res : Buffer
+        The allocated buffer.
+    """
+    shape = (shape,) if isinstance(shape, (PrimExpr, Integral)) else shape
+    if strides is None:
+        strides = []
+    return _ffi_api.AllocBuffer(  # type: ignore[attr-defined] # pylint: 
disable=no-member
+        shape,
+        dtype,
+        data,
+        strides,
+        elem_offset,
+        scope,
+        align,
+        offset_factor,
+        buffer_type,
+        axis_separators,
+    )
 
 
 def _as_range(dom: Union[Range, List[PrimExpr]]) -> Range:
@@ -387,7 +536,7 @@ class axis:  # pylint: disable=invalid-name
         res : Var
             The iteration variable.
         """
-        return _ffi_api.AxisSpatial(  # pylint: disable=no-member # type: 
ignore
+        return _ffi_api.AxisSpatial(  # type: ignore[attr-defined] # pylint: 
disable=no-member
             _as_range(dom), binding, dtype
         )
 
@@ -413,7 +562,7 @@ class axis:  # pylint: disable=invalid-name
         res : Var
             The iteration variable.
         """
-        return _ffi_api.AxisReduce(  # pylint: disable=no-member # type: ignore
+        return _ffi_api.AxisReduce(  # type: ignore[attr-defined] # pylint: 
disable=no-member
             _as_range(dom), binding, dtype
         )
 
@@ -439,7 +588,7 @@ class axis:  # pylint: disable=invalid-name
         res : Var
             The iteration variable.
         """
-        return _ffi_api.AxisScan(  # pylint: disable=no-member # type: ignore
+        return _ffi_api.AxisScan(  # type: ignore[attr-defined] # pylint: 
disable=no-member
             _as_range(dom), binding, dtype
         )
 
@@ -465,7 +614,7 @@ class axis:  # pylint: disable=invalid-name
         res : Var
             The iteration variable.
         """
-        return _ffi_api.AxisOpaque(  # pylint: disable=no-member # type: ignore
+        return _ffi_api.AxisOpaque(  # type: ignore[attr-defined] # pylint: 
disable=no-member
             _as_range(dom), binding, dtype
         )
 
@@ -489,7 +638,7 @@ class axis:  # pylint: disable=invalid-name
         res : Var
             The iteration variables.
         """
-        iter_vars = _ffi_api.AxisRemap(  # pylint: disable=no-member # type: 
ignore
+        iter_vars = _ffi_api.AxisRemap(  # type: ignore[attr-defined] # 
pylint: disable=no-member
             kinds, bindings, dtype
         )
         return iter_vars[0] if len(iter_vars) == 1 else iter_vars
@@ -522,7 +671,7 @@ def serial(
     if stop is None:
         stop = start
         start = 0
-    return _ffi_api.Serial(start, stop, annotations)  # pylint: 
disable=no-member # type: ignore
+    return _ffi_api.Serial(start, stop, annotations)  # type: 
ignore[attr-defined] # pylint: disable=no-member
 
 
 def parallel(
@@ -549,7 +698,7 @@ def parallel(
     if stop is None:
         stop = start
         start = 0
-    return _ffi_api.Parallel(start, stop, annotations)  # pylint: 
disable=no-member # type: ignore
+    return _ffi_api.Parallel(start, stop, annotations)  # type: 
ignore[attr-defined] # pylint: disable=no-member
 
 
 def vectorized(
@@ -576,7 +725,7 @@ def vectorized(
     if stop is None:
         stop = start
         start = 0
-    return _ffi_api.Vectorized(start, stop, annotations)  # pylint: 
disable=no-member # type: ignore
+    return _ffi_api.Vectorized(start, stop, annotations)  # type: 
ignore[attr-defined] # pylint: disable=no-member
 
 
 def unroll(
@@ -603,7 +752,7 @@ def unroll(
     if stop is None:
         stop = start
         start = 0
-    return _ffi_api.Unroll(start, stop, annotations)  # pylint: 
disable=no-member # type: ignore
+    return _ffi_api.Unroll(start, stop, annotations)  # type: 
ignore[attr-defined] # pylint: disable=no-member
 
 
 def thread_binding(
@@ -643,7 +792,7 @@ def thread_binding(
     elif stop is None:
         stop = start
         start = 0
-    return _ffi_api.ThreadBinding(  # pylint: disable=no-member # type: ignore
+    return _ffi_api.ThreadBinding(  # type: ignore[attr-defined] # pylint: 
disable=no-member
         start, stop, thread, annotations
     )
 
@@ -661,7 +810,7 @@ def grid(*extents: PrimExpr) -> frame.ForFrame:
     res : frame.ForFrame
         The ForFrame.
     """
-    return _ffi_api.Grid(extents)  # pylint: disable=no-member # type: ignore
+    return _ffi_api.Grid(extents)  # type: ignore[attr-defined] # pylint: 
disable=no-member
 
 
 def evaluate(value: PrimExpr) -> None:
@@ -674,7 +823,7 @@ def evaluate(value: PrimExpr) -> None:
     """
     if isinstance(value, str):
         value = StringImm(value)
-    return _ffi_api.Evaluate(value)  # pylint: disable=no-member # type: ignore
+    return _ffi_api.Evaluate(value)  # type: ignore[attr-defined] # pylint: 
disable=no-member
 
 
 def int8(expr: Optional[PrimExpr] = None) -> PrimExpr:
@@ -690,7 +839,7 @@ def int8(expr: Optional[PrimExpr] = None) -> PrimExpr:
     res : PrimExpr
         The new tir.Var with type int8 or casted expression with type int8.
     """
-    return _ffi_api.Int8(expr)  # pylint: disable=no-member # type: ignore
+    return _ffi_api.Int8(expr)  # type: ignore[attr-defined] # pylint: 
disable=no-member
 
 
 def int16(expr: Optional[PrimExpr] = None) -> PrimExpr:
@@ -706,7 +855,7 @@ def int16(expr: Optional[PrimExpr] = None) -> PrimExpr:
     res : PrimExpr
         The new tir.Var with type int16 or casted expression with type int16.
     """
-    return _ffi_api.Int16(expr)  # pylint: disable=no-member # type: ignore
+    return _ffi_api.Int16(expr)  # type: ignore[attr-defined] # pylint: 
disable=no-member
 
 
 def int32(expr: Optional[PrimExpr] = None) -> PrimExpr:
@@ -722,7 +871,7 @@ def int32(expr: Optional[PrimExpr] = None) -> PrimExpr:
     res : PrimExpr
         The new tir.Var with type int32 or casted expression with type int32.
     """
-    return _ffi_api.Int32(expr)  # pylint: disable=no-member # type: ignore
+    return _ffi_api.Int32(expr)  # type: ignore[attr-defined] # pylint: 
disable=no-member
 
 
 def int64(expr: Optional[PrimExpr] = None) -> PrimExpr:
@@ -738,7 +887,7 @@ def int64(expr: Optional[PrimExpr] = None) -> PrimExpr:
     res : PrimExpr
         The new tir.Var with type int64 or casted expression with type int64.
     """
-    return _ffi_api.Int64(expr)  # pylint: disable=no-member # type: ignore
+    return _ffi_api.Int64(expr)  # type: ignore[attr-defined] # pylint: 
disable=no-member
 
 
 def uint8(expr: Optional[PrimExpr] = None) -> PrimExpr:
@@ -754,7 +903,7 @@ def uint8(expr: Optional[PrimExpr] = None) -> PrimExpr:
     res : PrimExpr
         The new tir.Var with type uint8 or casted expression with type uint8.
     """
-    return _ffi_api.UInt8(expr)  # pylint: disable=no-member # type: ignore
+    return _ffi_api.UInt8(expr)  # type: ignore[attr-defined] # pylint: 
disable=no-member
 
 
 def uint16(expr: Optional[PrimExpr] = None) -> PrimExpr:
@@ -770,7 +919,7 @@ def uint16(expr: Optional[PrimExpr] = None) -> PrimExpr:
     res : PrimExpr
         The new tir.Var with type uint16 or casted expression with type uint16.
     """
-    return _ffi_api.UInt16(expr)  # pylint: disable=no-member # type: ignore
+    return _ffi_api.UInt16(expr)  # type: ignore[attr-defined] # pylint: 
disable=no-member
 
 
 def uint32(expr: Optional[PrimExpr] = None) -> PrimExpr:
@@ -786,7 +935,7 @@ def uint32(expr: Optional[PrimExpr] = None) -> PrimExpr:
     res : PrimExpr
         The new tir.Var with type uint32 or casted expression with type uint32.
     """
-    return _ffi_api.UInt32(expr)  # pylint: disable=no-member # type: ignore
+    return _ffi_api.UInt32(expr)  # type: ignore[attr-defined] # pylint: 
disable=no-member
 
 
 def uint64(expr: Optional[PrimExpr] = None) -> PrimExpr:
@@ -802,7 +951,7 @@ def uint64(expr: Optional[PrimExpr] = None) -> PrimExpr:
     res : PrimExpr
         The new tir.Var with type uint64 or casted expression with type uint64.
     """
-    return _ffi_api.UInt64(expr)  # pylint: disable=no-member # type: ignore
+    return _ffi_api.UInt64(expr)  # type: ignore[attr-defined] # pylint: 
disable=no-member
 
 
 def float8(expr: Optional[PrimExpr] = None) -> PrimExpr:
@@ -818,7 +967,7 @@ def float8(expr: Optional[PrimExpr] = None) -> PrimExpr:
     res : PrimExpr
         The new tir.Var with type float8 or casted expression with type float8.
     """
-    return _ffi_api.Float8(expr)  # pylint: disable=no-member # type: ignore
+    return _ffi_api.Float8(expr)  # type: ignore[attr-defined] # pylint: 
disable=no-member
 
 
 def float16(expr: Optional[PrimExpr] = None) -> PrimExpr:
@@ -834,7 +983,7 @@ def float16(expr: Optional[PrimExpr] = None) -> PrimExpr:
     res : PrimExpr
         The new tir.Var with type float16 or casted expression with type 
float16.
     """
-    return _ffi_api.Float16(expr)  # pylint: disable=no-member # type: ignore
+    return _ffi_api.Float16(expr)  # type: ignore[attr-defined] # pylint: 
disable=no-member
 
 
 def float32(expr: Optional[PrimExpr] = None) -> PrimExpr:
@@ -850,7 +999,7 @@ def float32(expr: Optional[PrimExpr] = None) -> PrimExpr:
     res : PrimExpr
         The new tir.Var with type float32 or casted expression with type 
float32.
     """
-    return _ffi_api.Float32(expr)  # pylint: disable=no-member # type: ignore
+    return _ffi_api.Float32(expr)  # type: ignore[attr-defined] # pylint: 
disable=no-member
 
 
 def float64(expr: Optional[PrimExpr] = None) -> PrimExpr:
@@ -866,7 +1015,7 @@ def float64(expr: Optional[PrimExpr] = None) -> PrimExpr:
     res : PrimExpr
         The new tir.Var with type float64 or casted expression with type 
float64.
     """
-    return _ffi_api.Float64(expr)  # pylint: disable=no-member # type: ignore
+    return _ffi_api.Float64(expr)  # type: ignore[attr-defined] # pylint: 
disable=no-member
 
 
 def int32x4(expr: Optional[PrimExpr] = None) -> PrimExpr:
@@ -882,7 +1031,7 @@ def int32x4(expr: Optional[PrimExpr] = None) -> PrimExpr:
     res : PrimExpr
         The new tir.Var with type int32x4 or casted expression with type 
int32x4.
     """
-    return _ffi_api.Int32x4(expr)  # pylint: disable=no-member # type: ignore
+    return _ffi_api.Int32x4(expr)  # type: ignore[attr-defined] # pylint: 
disable=no-member
 
 
 def int32x8(expr: Optional[PrimExpr] = None) -> PrimExpr:
@@ -898,7 +1047,7 @@ def int32x8(expr: Optional[PrimExpr] = None) -> PrimExpr:
     res : PrimExpr
         The new tir.Var with type int32x8 or casted expression with type 
int32x8.
     """
-    return _ffi_api.Int32x8(expr)  # pylint: disable=no-member # type: ignore
+    return _ffi_api.Int32x8(expr)  # type: ignore[attr-defined] # pylint: 
disable=no-member
 
 
 def int32x16(expr: Optional[PrimExpr] = None) -> PrimExpr:
@@ -914,7 +1063,7 @@ def int32x16(expr: Optional[PrimExpr] = None) -> PrimExpr:
     res : PrimExpr
         The new tir.Var with type int32x16 or casted expression with type 
int32x16.
     """
-    return _ffi_api.Int32x16(expr)  # pylint: disable=no-member # type: ignore
+    return _ffi_api.Int32x16(expr)  # type: ignore[attr-defined] # pylint: 
disable=no-member
 
 
 def boolean(expr: Optional[PrimExpr] = None) -> PrimExpr:
@@ -930,7 +1079,7 @@ def boolean(expr: Optional[PrimExpr] = None) -> PrimExpr:
     res : PrimExpr
         The new tir.Var with type boolean or casted expression with type 
boolean.
     """
-    return _ffi_api.Boolean(expr)  # pylint: disable=no-member # type: ignore
+    return _ffi_api.Boolean(expr)  # type: ignore[attr-defined] # pylint: 
disable=no-member
 
 
 def handle(expr: Optional[PrimExpr] = None) -> PrimExpr:
@@ -946,7 +1095,7 @@ def handle(expr: Optional[PrimExpr] = None) -> PrimExpr:
     res : PrimExpr
         The new tir.Var with type handle or casted expression with type handle.
     """
-    return _ffi_api.Handle(expr)  # pylint: disable=no-member # type: ignore
+    return _ffi_api.Handle(expr)  # type: ignore[attr-defined] # pylint: 
disable=no-member
 
 
 def void(expr: Optional[PrimExpr] = None) -> PrimExpr:
@@ -962,7 +1111,7 @@ def void(expr: Optional[PrimExpr] = None) -> PrimExpr:
     res : PrimExpr
         The new tir.Var with type void or casted expression with type void.
     """
-    return _ffi_api.Void(expr)  # pylint: disable=no-member # type: ignore
+    return _ffi_api.Void(expr)  # type: ignore[attr-defined] # pylint: 
disable=no-member
 
 
 def var(dtype, name="") -> Var:
@@ -981,7 +1130,7 @@ def var(dtype, name="") -> Var:
     res : Var
         The result tir.Var.
     """
-    return Var(name, dtype)  # pylint: disable=no-member # type: ignore
+    return Var(name, dtype)  # pylint: disable=no-member
 
 
 # pylint: enable=invalid-name
@@ -997,6 +1146,12 @@ __all__ = [
     "match_buffer",
     "preflattened_buffer",
     "block",
+    "init",
+    "where",
+    "reads",
+    "writes",
+    "block_attr",
+    "alloc_buffer",
     "axis",
     "serial",
     "parallel",
diff --git a/src/script/ir_builder/tir/frame.cc 
b/src/script/ir_builder/tir/frame.cc
index e54bf75eef..8b8b2a4d80 100644
--- a/src/script/ir_builder/tir/frame.cc
+++ b/src/script/ir_builder/tir/frame.cc
@@ -73,6 +73,20 @@ void BlockFrameNode::ExitWithScope() {
   }
 }
 
+void BlockInitFrameNode::EnterWithScope() {
+  BlockFrame frame = FindBlockFrame("T.init");
+  if (frame->init.defined()) {
+    LOG(FATAL) << "ValueError: Duplicate block init declaration";
+  }
+  TIRFrameNode::EnterWithScope();
+}
+
+void BlockInitFrameNode::ExitWithScope() {
+  TIRFrameNode::ExitWithScope();
+  BlockFrame frame = FindBlockFrame("T.init");
+  frame->init = AsStmt(stmts);
+}
+
 void ForFrameNode::ExitWithScope() {
   TIRFrameNode::ExitWithScope();
   AddToParent(this->f_make_for_loop(vars, doms, AsStmt(stmts)));
@@ -81,6 +95,7 @@ void ForFrameNode::ExitWithScope() {
 TVM_REGISTER_NODE_TYPE(TIRFrameNode);
 TVM_REGISTER_NODE_TYPE(PrimFuncFrameNode);
 TVM_REGISTER_NODE_TYPE(BlockFrameNode);
+TVM_REGISTER_NODE_TYPE(BlockInitFrameNode);
 TVM_REGISTER_NODE_TYPE(ForFrameNode);
 
 }  // namespace tir
diff --git a/src/script/ir_builder/tir/ir.cc b/src/script/ir_builder/tir/ir.cc
index 5013e32172..75e7592626 100644
--- a/src/script/ir_builder/tir/ir.cc
+++ b/src/script/ir_builder/tir/ir.cc
@@ -173,6 +173,80 @@ BlockFrame Block(String name, bool no_realize) {
   return BlockFrame(n);
 }
 
+// Create a fresh, empty block-init frame. The duplicate-init check and the
+// attachment to a block frame happen in BlockInitFrameNode's scope hooks.
+BlockInitFrame Init() {
+  auto n = make_object<BlockInitFrameNode>();
+  return BlockInitFrame(n);
+}
+
+// Set the predicate of the block frame located by FindBlockFrame("T.where").
+//
+// \param predicate The predicate expression stored on the block frame.
+//
+// Emits a fatal "ValueError" log, including the previously stored predicate,
+// when the frame's predicate is already defined.
+void Where(PrimExpr predicate) {
+  BlockFrame frame = FindBlockFrame("T.where");
+  if (frame->predicate.defined()) {
+    LOG(FATAL) << "ValueError: Duplicate block predicate declaration, previous one is "
+               << frame->predicate;
+  }
+  frame->predicate = predicate;
+}
+
+// Record the read regions of the block frame located by FindBlockFrame("T.reads").
+//
+// \param buffer_slices Objects describing the regions. Each element must be a
+//        BufferRegion (used as is) or a BufferLoad (converted with
+//        BufferRegionFromLoad); any other object type is a fatal error.
+//
+// Emits a fatal "ValueError" log when the frame's `reads` is already defined.
+void Reads(Array<ObjectRef> buffer_slices) {
+  using namespace tvm::tir;
+  BlockFrame frame = FindBlockFrame("T.reads");
+  if (frame->reads.defined()) {
+    LOG(FATAL) << "ValueError: Duplicate read region declaration, previous one is " << frame->reads;
+  }
+  // Normalize every accepted input form to a BufferRegion before storing.
+  Array<BufferRegion> reads;
+  for (const ObjectRef& obj : buffer_slices) {
+    if (const auto* buffer_region = obj.as<BufferRegionNode>()) {
+      reads.push_back(GetRef<BufferRegion>(buffer_region));
+    } else if (const auto* buffer_load = obj.as<BufferLoadNode>()) {
+      reads.push_back(BufferRegionFromLoad(GetRef<BufferLoad>(buffer_load)));
+    } else {
+      LOG(FATAL) << "Invalid type for buffer reads.";
+    }
+  }
+  frame->reads = reads;
+}
+
+// Record the write regions of the block frame located by FindBlockFrame("T.writes").
+//
+// \param buffer_slices Objects describing the regions. Each element must be a
+//        BufferRegion (used as is) or a BufferLoad (converted with
+//        BufferRegionFromLoad); any other object type is a fatal error.
+//
+// Emits a fatal "ValueError" log when the frame's `writes` is already defined.
+void Writes(Array<ObjectRef> buffer_slices) {
+  using namespace tvm::tir;
+  BlockFrame frame = FindBlockFrame("T.writes");
+  if (frame->writes.defined()) {
+    LOG(FATAL) << "ValueError: Duplicate write region declaration, previous one is "
+               << frame->writes;
+  }
+  // Normalize every accepted input form to a BufferRegion before storing.
+  Array<BufferRegion> writes;
+  for (const ObjectRef& obj : buffer_slices) {
+    if (const auto* buffer_region = obj.as<BufferRegionNode>()) {
+      writes.push_back(GetRef<BufferRegion>(buffer_region));
+    } else if (const auto* buffer_load = obj.as<BufferLoadNode>()) {
+      writes.push_back(BufferRegionFromLoad(GetRef<BufferLoad>(buffer_load)));
+    } else {
+      LOG(FATAL) << "Invalid type for buffer writes.";
+    }
+  }
+  frame->writes = writes;
+}
+
+// Set the annotations of the block frame located by FindBlockFrame("T.block_attr").
+//
+// \param attrs The annotation map stored on the block frame.
+//
+// Emits a fatal "ValueError" log, including the previously stored annotations,
+// when the frame's annotations are already defined.
+void BlockAttrs(Map<String, ObjectRef> attrs) {
+  BlockFrame frame = FindBlockFrame("T.block_attr");
+  if (frame->annotations.defined()) {
+    LOG(FATAL) << "ValueError: Duplicate block annotations, previous one is " << frame->annotations;
+  }
+  frame->annotations = attrs;
+}
+
+// Declare a buffer and register it as an allocation of the current frame.
+//
+// The buffer is created by BufferDecl with an empty name; all other arguments
+// are forwarded to BufferDecl unchanged. It is then appended to
+//   - `alloc_buffers` when GetLastFrame<BlockFrame>() yields a frame, otherwise
+//   - `root_alloc_buffers` when GetLastFrame<PrimFuncFrame>() yields a frame.
+// If neither lookup succeeds a fatal "ValueError" is logged.
+//
+// \return The newly declared buffer.
+Buffer AllocBuffer(Array<PrimExpr> shape, DataType dtype, Optional<Var> data,
+                   Array<PrimExpr> strides, PrimExpr elem_offset, String storage_scope, int align,
+                   int offset_factor, String buffer_type_str, Array<IntImm> axis_separators) {
+  Buffer buffer = BufferDecl(shape, dtype, "", data, strides, elem_offset, storage_scope, align,
+                             offset_factor, buffer_type_str, axis_separators);
+  IRBuilder builder = IRBuilder::Current();
+  // A block frame takes precedence over a PrimFunc frame.
+  if (Optional<BlockFrame> frame = builder->GetLastFrame<BlockFrame>()) {
+    frame.value()->alloc_buffers.push_back(buffer);
+  } else if (Optional<PrimFuncFrame> frame = builder->GetLastFrame<PrimFuncFrame>()) {
+    frame.value()->root_alloc_buffers.push_back(buffer);
+  } else {
+    LOG(FATAL) << "ValueError: Block frame or PrimFunc frame not found. Please ensure "
+                  "'T.alloc_buffer' is called under T.block() or T.prim_func()";
+  }
+  return buffer;
+}
 namespace axis {
 
 IterVar PushBlockVar(IterVar iter_var, PrimExpr binding) {
@@ -383,6 +457,12 @@ 
TVM_REGISTER_GLOBAL("script.ir_builder.tir.MatchBuffer").set_body_typed(MatchBuf
 
TVM_REGISTER_GLOBAL("script.ir_builder.tir.PreflattenedBuffer").set_body_typed(PreflattenedBuffer);
 
 TVM_REGISTER_GLOBAL("script.ir_builder.tir.Block").set_body_typed(Block);
+TVM_REGISTER_GLOBAL("script.ir_builder.tir.Init").set_body_typed(Init);  // T.init
+TVM_REGISTER_GLOBAL("script.ir_builder.tir.Where").set_body_typed(Where);  // T.where
+TVM_REGISTER_GLOBAL("script.ir_builder.tir.Reads").set_body_typed(Reads);  // T.reads
+TVM_REGISTER_GLOBAL("script.ir_builder.tir.Writes").set_body_typed(Writes);  // T.writes
+TVM_REGISTER_GLOBAL("script.ir_builder.tir.BlockAttrs").set_body_typed(BlockAttrs);  // T.block_attr
+TVM_REGISTER_GLOBAL("script.ir_builder.tir.AllocBuffer").set_body_typed(AllocBuffer);  // T.alloc_buffer
 
 
TVM_REGISTER_GLOBAL("script.ir_builder.tir.AxisSpatial").set_body_typed(axis::Spatial);
 
TVM_REGISTER_GLOBAL("script.ir_builder.tir.AxisReduce").set_body_typed(axis::Reduce);
diff --git a/tests/python/unittest/test_tvmscript_ir_builder_tir.py 
b/tests/python/unittest/test_tvmscript_ir_builder_tir.py
index d893ebc545..a5d8c10680 100644
--- a/tests/python/unittest/test_tvmscript_ir_builder_tir.py
+++ b/tests/python/unittest/test_tvmscript_ir_builder_tir.py
@@ -87,7 +87,7 @@ def test_ir_builder_tir_primfunc_complete():
     assert_structural_equal(prim_func_actual, prim_func_expected, 
map_free_vars=True)
 
 
-def test_ir_builder_tir_block():
+def test_ir_builder_tir_block_base():
     with IRBuilder() as ib:
         with T.block("block"):
             T.evaluate(0)
@@ -114,6 +114,54 @@ def test_ir_builder_tir_block():
     assert_structural_equal(block_realize_actual, block_realize_expected, 
map_free_vars=True)
 
 
+def test_ir_builder_tir_block_complete():
+    """Build a block that uses every block-level builder method and compare it
+    structurally against a hand-constructed tir.BlockRealize."""
+    with IRBuilder() as ib:
+        a = T.var("int64", "a")
+        b = T.buffer_decl((128, 128), "float32")
+        c = T.buffer_decl((128, 128), "float32")
+        d = T.var("int32", "d")
+        e = T.buffer_decl((128, 128), "float32")
+        f = T.var("int32", "f")
+        with T.block("block"):
+            T.where(a > 1)
+            T.reads(b[0:16, 0:16])
+            T.writes(c[d:128, d:128])
+            T.block_attr({"key": "value"})
+            T.alloc_buffer((128, 128), "float32")
+            T.match_buffer(e[0:32, 0:32], (32, 32), "float32")
+            T.axis.spatial(128, f)
+            T.evaluate(0)
+    # the block generated by IRBuilder
+    block_realize_actual = ib.get()
+
+    # the expected block
+    var_a = tir.Var("a", "int64")
+    buffer_b = tir.decl_buffer((128, 128), "float32", name="b")
+    buffer_c = tir.decl_buffer((128, 128), "float32", name="c")
+    var_d = tir.Var("d", "int32")
+    # Named "e" to mirror the builder-side variable (was a copy-pasted "c").
+    buffer_e = tir.decl_buffer((128, 128), "float32", name="e")
+    var_f = tir.Var("f", "int32")
+    block_expected = tir.Block(
+        iter_vars=[tir.IterVar((0, 128), tir.Var("", "int32"), iter_type=tir.IterVar.DataPar)],
+        reads=[buffer_b[0:16, 0:16]],
+        writes=[buffer_c[var_d:128, var_d:128]],
+        name_hint="block",
+        body=tir.Evaluate(0),
+        alloc_buffers=[tir.decl_buffer((128, 128), "float32")],
+        match_buffers=[
+            tir.MatchBufferRegion(tir.decl_buffer((32, 32), "float32"), buffer_e[0:32, 0:32])
+        ],
+        annotations={"key": "value"},
+    )
+    block_realize_expected = tir.BlockRealize(
+        iter_values=[var_f],
+        predicate=var_a > 1,
+        block=block_expected,
+    )
+    # Check if the generated ir is expected
+    assert_structural_equal(block_realize_actual, block_realize_expected, map_free_vars=True)
+
+
 def test_ir_builder_tir_axis():
     with IRBuilder() as ib:
         a = T.var("int32", "a")
diff --git a/tests/scripts/task_mypy.sh b/tests/scripts/task_mypy.sh
index f165adfe1b..c3e5d50b3e 100755
--- a/tests/scripts/task_mypy.sh
+++ b/tests/scripts/task_mypy.sh
@@ -47,3 +47,6 @@ mypy --disallow-untyped-defs 
python/tvm/relay/op/contrib/tensorrt.py
 #TODO(@mikepapadim): This is failing atm
 # echo "Checking MyPy Type defs in the tvm.relay.backend.contrib.ethosu 
package."
 # mypy  --check-untyped-defs python/tvm/relay/backend/contrib/ethosu/
+
+echo "Checking MyPy Type defs in the tvmscript IRBuilder package."
+mypy  --check-untyped-defs python/tvm/script/ir_builder  # --check-untyped-defs: also type-check bodies of unannotated functions


Reply via email to