This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new c7970ddd79 [TensorIR] New schedule primitive `set_dtype` (#14316)
c7970ddd79 is described below

commit c7970ddd79a1e6bbf5e07ab2c515ec4991242ef7
Author: Zihao Ye <[email protected]>
AuthorDate: Wed Mar 22 00:10:54 2023 -0700

    [TensorIR] New schedule primitive `set_dtype` (#14316)
    
    # Motivation
    Currently, we lack a schedule primitive to change the data type of an 
allocated buffer (e.g. via `cache_read`/`cache_write`), and thus we cannot 
perform type conversion while loading data from global to shared memory.
    
    This PR adds a new schedule primitive `set_dtype` that follows the 
interface of `set_scope` and allows users to customize the allocated buffers' 
data type.
    
    # Example
    Before running `set_dtype`:
    ```python
    @T.prim_func
    def before_set_dtype(
        A: T.Buffer((128, 128), "float32"), C: T.Buffer((128, 128), "float32")
    ) -> None:
        B = T.alloc_buffer((128, 128), dtype="float32")
    
        for i, j in T.grid(128, 128):
            with T.block("B"):
                vi, vj = T.axis.remap("SS", [i, j])
                B[vi, vj] = A[vi, vj] * 2.0
        for i, j in T.grid(128, 128):
            with T.block("C"):
                vi, vj = T.axis.remap("SS", [i, j])
                C[vi, vj] = B[vi, vj] + 1.0
    ```
    then we perform the `set_dtype` schedule:
    ```python
    sch = tir.Schedule(before_set_dtype)
    sch.set_dtype("B", buffer_index=0, dtype="float16")
    print(sch.mod["main"].script())
    ```
    we get transformed code:
    ```python
    @T.prim_func
    def after_set_dtype(
        A: T.Buffer((128, 128), "float32"), C: T.Buffer((128, 128), "float32")
    ) -> None:
        B = T.alloc_buffer((128, 128), dtype="float16")
    
        for i, j in T.grid(128, 128):
            with T.block("B"):
                vi, vj = T.axis.remap("SS", [i, j])
                B[vi, vj] = T.cast(A[vi, vj] * 2.0, "float16")
        for i, j in T.grid(128, 128):
            with T.block("C"):
                vi, vj = T.axis.remap("SS", [i, j])
                C[vi, vj] = T.cast(B[vi, vj], "float32") + 1.0
    ```
    where data type conversions are inserted automatically.
    
    # Other Usage
    Using the combination of `cache_read` + `set_dtype` can help us load data 
from the memory hierarchy while converting data to the desired type.
---
 include/tvm/tir/schedule/schedule.h                |  12 +-
 python/tvm/tir/schedule/schedule.py                |  79 ++++++++++++-
 src/tir/schedule/concrete_schedule.cc              |   8 ++
 src/tir/schedule/concrete_schedule.h               |   1 +
 src/tir/schedule/primitive.h                       |  12 ++
 src/tir/schedule/primitive/block_annotate.cc       | 117 +++++++++++++++++++
 src/tir/schedule/schedule.cc                       |   2 +
 src/tir/schedule/traced_schedule.cc                |  11 ++
 src/tir/schedule/traced_schedule.h                 |   1 +
 src/tir/schedule/transform.cc                      |  10 ++
 src/tir/schedule/transform.h                       |  12 +-
 .../python/unittest/test_tir_schedule_set_dtype.py | 125 +++++++++++++++++++++
 12 files changed, 385 insertions(+), 5 deletions(-)

diff --git a/include/tvm/tir/schedule/schedule.h 
b/include/tvm/tir/schedule/schedule.h
index 01255e6e3f..570560c62d 100644
--- a/include/tvm/tir/schedule/schedule.h
+++ b/include/tvm/tir/schedule/schedule.h
@@ -589,13 +589,23 @@ class ScheduleNode : public runtime::Object {
   virtual void StorageAlign(const BlockRV& block_rv, int buffer_index, int 
axis, int factor,
                             int offset) = 0;
   /*!
-   * \brief Set the storage scope of a buffer, where the buffer is specified 
by the a block and a
+   * \brief Set the storage scope of a buffer, where the buffer is specified 
by a block and a
    * write-index
    * \param block_rv The producer block of the buffer
    * \param buffer_index The index of the buffer in block's write region
    * \param storage_scope The storage scope to be set
    */
   virtual void SetScope(const BlockRV& block_rv, int buffer_index, const 
String& storage_scope) = 0;
+  /*!
+   * \brief Set the data type of a buffer, where the buffer is specified by a 
block and a
+   * write-index
+   * \note This schedule primitive is unsafe and may change correctness of 
program because of
+   *   type conversion, please use with caution.
+   * \param block_rv The producer block of the buffer
+   * \param buffer_index The index of the buffer in block's write region
+   * \param dtype The data type to be set
+   */
+  virtual void UnsafeSetDType(const BlockRV& block_rv, int buffer_index, const 
String& dtype) = 0;
   /******** Schedule: Blockize & Tensorize ********/
   /*!
    * \brief Convert the subtree rooted at a specific loop into a block.
diff --git a/python/tvm/tir/schedule/schedule.py 
b/python/tvm/tir/schedule/schedule.py
index c27007682a..68f0b9454c 100644
--- a/python/tvm/tir/schedule/schedule.py
+++ b/python/tvm/tir/schedule/schedule.py
@@ -2369,7 +2369,7 @@ class Schedule(Object):
         self, block: Union[BlockRV, str], buffer_index: Union[int, str, 
Buffer], storage_scope: str
     ) -> None:
         """Set the storage scope of a buffer, where the buffer is
-        specified by the a block and a write-index
+        specified by a block and a write-index.
 
         Parameters
         ----------
@@ -2431,7 +2431,7 @@ class Schedule(Object):
 
         Note
         ----
-        Set_scope requires the buffer to be an intermediate buffer defined via 
`alloc_buffer`.
+        `set_scope` requires the buffer to be an intermediate buffer defined 
via `alloc_buffer`.
         """
         block = self._normalize_block_arg(block)
         if not isinstance(buffer_index, int):
@@ -2442,6 +2442,81 @@ class Schedule(Object):
             self, block, buffer_index, storage_scope
         )
 
+    @type_checked
+    def unsafe_set_dtype(self, block: Union[BlockRV, str], buffer_index: int, 
dtype: str) -> None:
+        """Set the data type of a buffer, where the buffer is
+        specified by a block and a write-index.
+
+        This schedule primitive is unsafe and may change the correctness of 
the program because of
+        type conversion; please use with caution.
+
+        Parameters
+        ----------
+        block : Union[BlockRV, str]
+            The producer block of the buffer
+        buffer_index : int
+            The index of the buffer in block's write region
+        dtype : str
+            The data type to be set
+
+        Examples
+        --------
+
+        Before set_dtype, in TensorIR, the IR is:
+
+        .. code-block:: python
+
+            @T.prim_func
+            def before_set_dtype(
+                A: T.Buffer((128, 128), "float32"), C: T.Buffer((128, 128), 
"float32")
+            ) -> None:
+                B = T.alloc_buffer((128, 128), dtype="float32")
+
+                for i, j in T.grid(128, 128):
+                    with T.block("B"):
+                        vi, vj = T.axis.remap("SS", [i, j])
+                        B[vi, vj] = A[vi, vj] * 2.0
+                for i, j in T.grid(128, 128):
+                    with T.block("C"):
+                        vi, vj = T.axis.remap("SS", [i, j])
+                        C[vi, vj] = B[vi, vj] + 1.0
+
+        Create the schedule and do set_dtype:
+
+        .. code-block:: python
+
+            sch = tir.Schedule(before_set_dtype)
+            sch.unsafe_set_dtype("B", buffer_index=0, dtype="float16")
+            print(sch.mod["main"].script())
+
+        After applying set_dtype, the IR becomes:
+
+        .. code-block:: python
+
+            @T.prim_func
+            def after_set_dtype(
+                A: T.Buffer((128, 128), "float32"), C: T.Buffer((128, 128), 
"float32")
+            ) -> None:
+                B = T.alloc_buffer((128, 128), dtype="float16")
+
+                for i, j in T.grid(128, 128):
+                    with T.block("B"):
+                        vi, vj = T.axis.remap("SS", [i, j])
+                        B[vi, vj] = T.cast(A[vi, vj] * 2.0, "float16")
+                for i, j in T.grid(128, 128):
+                    with T.block("C"):
+                        vi, vj = T.axis.remap("SS", [i, j])
+                        C[vi, vj] = T.cast(B[vi, vj], "float32") + 1.0
+
+        Note
+        ----
+        `set_dtype` requires the buffer to be an intermediate buffer defined 
via `alloc_buffer`.
+        """
+        block = self._normalize_block_arg(block)
+        _ffi_api.ScheduleUnsafeSetDType(  # type: ignore # pylint: 
disable=no-member
+            self, block, buffer_index, dtype
+        )
+
     ########## Schedule: Blockize & Tensorize ##########
 
     @type_checked
diff --git a/src/tir/schedule/concrete_schedule.cc 
b/src/tir/schedule/concrete_schedule.cc
index 6593579725..93ea38169d 100644
--- a/src/tir/schedule/concrete_schedule.cc
+++ b/src/tir/schedule/concrete_schedule.cc
@@ -727,6 +727,14 @@ void ConcreteScheduleNode::SetScope(const BlockRV& 
block_rv, int buffer_index,
   this->state_->DebugVerify();
 }
 
+void ConcreteScheduleNode::UnsafeSetDType(const BlockRV& block_rv, int 
buffer_index,
+                                          const String& dtype) {
+  TVM_TIR_SCHEDULE_BEGIN();
+  tir::UnsafeSetDType(state_, this->GetSRef(block_rv), buffer_index, dtype);
+  TVM_TIR_SCHEDULE_END("set-dtype", this->error_render_level_);
+  this->state_->DebugVerify();
+}
+
 /******** Schedule: Reduction ********/
 
 BlockRV ConcreteScheduleNode::DecomposeReduction(const BlockRV& block_rv, 
const LoopRV& loop_rv) {
diff --git a/src/tir/schedule/concrete_schedule.h 
b/src/tir/schedule/concrete_schedule.h
index 290b6a4456..227288b232 100644
--- a/src/tir/schedule/concrete_schedule.h
+++ b/src/tir/schedule/concrete_schedule.h
@@ -146,6 +146,7 @@ class ConcreteScheduleNode : public ScheduleNode {
   void StorageAlign(const BlockRV& block_rv, int buffer_index, int axis, int 
factor,
                     int offset) override;
   void SetScope(const BlockRV& block_rv, int buffer_index, const String& 
storage_scope) override;
+  void UnsafeSetDType(const BlockRV& block_rv, int buffer_index, const String& 
dtype) override;
   /******** Schedule: Blockize & Tensorize ********/
   BlockRV Blockize(const LoopRV& loop_rv, bool preserve_unit_iters) override;
   void Tensorize(const BlockRV& block_rv, const String& intrin, bool 
preserve_unit_iters) override;
diff --git a/src/tir/schedule/primitive.h b/src/tir/schedule/primitive.h
index 89cdf68a45..09185498e1 100644
--- a/src/tir/schedule/primitive.h
+++ b/src/tir/schedule/primitive.h
@@ -479,6 +479,18 @@ TVM_DLL void StorageAlign(ScheduleState self, const 
StmtSRef& block_sref, int bu
  */
 TVM_DLL void SetScope(ScheduleState self, const StmtSRef& block_sref, int 
buffer_index,
                       const String& storage_scope);
+/*!
+ * \brief Set the data type of a buffer, where the buffer is specified by a 
block and a
+ * write-index
+ * \note This schedule primitive is unsafe and may change correctness of 
program because of
+ *   type conversion, please use with caution.
+ * \param self The state of the schedule
+ * \param block_sref The sref of the producer block of the buffer
+ * \param buffer_index The index of the buffer in block's write region
+ * \param dtype The data type to be set
+ */
+TVM_DLL void UnsafeSetDType(ScheduleState self, const StmtSRef& block_sref, 
int buffer_index,
+                            const String& dtype);
 /*!
  * \brief Set the axis separator of a buffer, where the buffer is specified by 
a block and a read
  * or write index
diff --git a/src/tir/schedule/primitive/block_annotate.cc 
b/src/tir/schedule/primitive/block_annotate.cc
index 0912e36836..3f1789b3d6 100644
--- a/src/tir/schedule/primitive/block_annotate.cc
+++ b/src/tir/schedule/primitive/block_annotate.cc
@@ -16,6 +16,8 @@
  * specific language governing permissions and limitations
  * under the License.
  */
+#include <tvm/tir/expr.h>
+
 #include "../utils.h"
 
 namespace tvm {
@@ -297,6 +299,93 @@ void SetScope(ScheduleState self, const StmtSRef& 
block_sref, int buffer_index,
   self->Replace(alloc_site_sref, new_block, block_reuse_map);
 }
 
+/*!
+ * \brief A helper mutator which recursively mutates the old buffer's data 
type, inserts data type
+ * conversions, and collects the block sref reuse information for the 
following replacement.
+ */
+class DTypeMutator : private ReplaceBufferMutator {
+ public:
+  /*!
+   * \param allocate_site The block where `old_buffer` was allocated.
+   * \param old_buffer The old buffer
+   * \param dtype The data type to be set
+   * \param block_sref_reuse The block sref reuse map to be updated
+   * \return The new block after the mutation
+   */
+  static Block Mutate(const Block& allocate_site, const Buffer& old_buffer, 
const DataType& dtype,
+                      Map<Block, Block>* block_sref_reuse) {
+    Buffer new_buffer = WithDType(old_buffer, dtype);
+    DTypeMutator mutator(old_buffer, new_buffer, dtype, block_sref_reuse);
+    Stmt new_block = mutator.VisitStmt(allocate_site);
+    return Downcast<Block>(new_block);
+  }
+
+ private:
+  DTypeMutator(const Buffer& old_buffer, Buffer new_buffer, const DataType& 
dtype,
+               Map<Block, Block>* block_sref_reuse)
+      : ReplaceBufferMutator(old_buffer, std::move(new_buffer), 
block_sref_reuse),
+        src_dtype_(old_buffer->dtype),
+        tgt_dtype_(dtype) {}
+
+  MatchBufferRegion VisitMatchBufferRegion(const MatchBufferRegion& 
match_buffer) final {
+    auto it = buffer_var_map_.find(match_buffer->source->buffer->data.get());
+    if (it != buffer_var_map_.end()) {
+      Buffer new_target_buffer = WithDType(match_buffer->buffer, 
it->second->dtype);
+      buffer_var_map_[match_buffer->buffer->data.get()] = new_target_buffer;
+      return MatchBufferRegion(new_target_buffer,
+                               BufferRegion(it->second, 
match_buffer->source->region));
+    } else {
+      return match_buffer;
+    }
+  }
+
+  Stmt VisitStmt_(const BufferStoreNode* op) final {
+    BufferStore node = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op));
+    auto it = buffer_var_map_.find(node->buffer->data.get());
+    if (it != buffer_var_map_.end()) {
+      node.CopyOnWrite()->buffer = it->second;
+      node.CopyOnWrite()->value = Cast(tgt_dtype_, node->value);
+    }
+    return node;
+  }
+
+  PrimExpr VisitExpr_(const BufferLoadNode* op) final {
+    BufferLoad node = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op));
+    auto it = buffer_var_map_.find(node->buffer->data.get());
+    if (it != buffer_var_map_.end()) {
+      return Cast(src_dtype_, BufferLoad(it->second, node->indices));
+    }
+    return node;
+  }
+
+  DataType src_dtype_, tgt_dtype_;
+};
+
+void UnsafeSetDType(ScheduleState self, const StmtSRef& block_sref, int 
buffer_index,
+                    const String& dtype) {
+  const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref);
+  Buffer buffer =
+      GetNthAccessBuffer(self, GetRef<Block>(block), buffer_index, 
BufferIndexType::kWrite);
+  DataType target_dtype(runtime::String2DLDataType(dtype));
+
+  // Step 1. If `dtype` equals the original data type, just return.
+  if (buffer->dtype == target_dtype) {
+    return;
+  }
+
+  // Step 2. Get the allocation site of the target buffer.
+  StmtSRef alloc_site_sref =
+      NonAllocatedBufferError::CheckAndGetBufferAllocationSite(self->mod, 
block_sref, buffer);
+  const BlockNode* alloc_site = TVM_SREF_TO_BLOCK(alloc_site_sref);
+
+  // Step 3. Recursively replace the old buffer with a new buffer, where the new 
buffer has the
+  // given dtype, and insert data type conversions.
+  Map<Block, Block> block_reuse_map;
+  Block new_block =
+      DTypeMutator::Mutate(GetRef<Block>(alloc_site), buffer, target_dtype, 
&block_reuse_map);
+  self->Replace(alloc_site_sref, new_block, block_reuse_map);
+}
+
 /******** InstructionKind Registration ********/
 
 struct StorageAlignTraits : public UnpackedInstTraits<StorageAlignTraits> {
@@ -356,8 +445,36 @@ struct SetScopeTraits : public 
UnpackedInstTraits<SetScopeTraits> {
   friend struct ::tvm::tir::UnpackedInstTraits;
 };
 
+struct UnsafeSetDTypeTraits : public UnpackedInstTraits<UnsafeSetDTypeTraits> {
+  static constexpr const char* kName = "UnsafeSetDType";
+  static constexpr bool kIsPure = false;
+
+ private:
+  static constexpr size_t kNumInputs = 1;
+  static constexpr size_t kNumAttrs = 2;
+  static constexpr size_t kNumDecisions = 0;
+
+  static void UnpackedApplyToSchedule(Schedule sch, BlockRV block_rv, Integer 
buffer_index,
+                                      String dtype) {
+    return sch->UnsafeSetDType(block_rv, buffer_index->value, dtype);
+  }
+
+  static String UnpackedAsPython(Array<String> outputs, String block_rv, 
Integer buffer_index,
+                                 String dtype) {
+    PythonAPICall py("unsafe_set_dtype");
+    py.Input("block", block_rv);
+    py.Input("buffer_index", buffer_index);
+    py.Input("dtype", dtype);
+    return py.Str();
+  }
+
+  template <typename>
+  friend struct ::tvm::tir::UnpackedInstTraits;
+};
+
 TVM_REGISTER_INST_KIND_TRAITS(StorageAlignTraits);
 TVM_REGISTER_INST_KIND_TRAITS(SetScopeTraits);
+TVM_REGISTER_INST_KIND_TRAITS(UnsafeSetDTypeTraits);
 
 }  // namespace tir
 }  // namespace tvm
diff --git a/src/tir/schedule/schedule.cc b/src/tir/schedule/schedule.cc
index ad512a5fcb..a0e39b74d3 100644
--- a/src/tir/schedule/schedule.cc
+++ b/src/tir/schedule/schedule.cc
@@ -217,6 +217,8 @@ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleStorageAlign")
     .set_body_method<Schedule>(&ScheduleNode::StorageAlign);
 TVM_REGISTER_GLOBAL("tir.schedule.ScheduleSetScope")
     .set_body_method<Schedule>(&ScheduleNode::SetScope);
+TVM_REGISTER_GLOBAL("tir.schedule.ScheduleUnsafeSetDType")
+    .set_body_method<Schedule>(&ScheduleNode::UnsafeSetDType);
 /******** (FFI) Blockize & Tensorize ********/
 TVM_REGISTER_GLOBAL("tir.schedule.ScheduleBlockize")
     .set_body_method<Schedule>(&ScheduleNode::Blockize);
diff --git a/src/tir/schedule/traced_schedule.cc 
b/src/tir/schedule/traced_schedule.cc
index 9b9420302c..2b6a7f71d4 100644
--- a/src/tir/schedule/traced_schedule.cc
+++ b/src/tir/schedule/traced_schedule.cc
@@ -504,6 +504,17 @@ void TracedScheduleNode::SetScope(const BlockRV& block_rv, 
int buffer_index,
       /*outputs=*/{}));
 }
 
+void TracedScheduleNode::UnsafeSetDType(const BlockRV& block_rv, int 
buffer_index,
+                                        const String& dtype) {
+  ConcreteScheduleNode::UnsafeSetDType(block_rv, buffer_index, dtype);
+  static const InstructionKind& kind = InstructionKind::Get("UnsafeSetDType");
+  trace_->Append(/*inst=*/Instruction(
+      /*kind=*/kind,
+      /*inputs=*/{block_rv},
+      /*attrs=*/{Integer(buffer_index), dtype},
+      /*outputs=*/{}));
+}
+
 /******** Schedule: Blockize & Tensorize ********/
 
 BlockRV TracedScheduleNode::Blockize(const LoopRV& loop_rv, bool 
preserve_unit_iters) {
diff --git a/src/tir/schedule/traced_schedule.h 
b/src/tir/schedule/traced_schedule.h
index 7854adad39..8b9621c749 100644
--- a/src/tir/schedule/traced_schedule.h
+++ b/src/tir/schedule/traced_schedule.h
@@ -105,6 +105,7 @@ class TracedScheduleNode : public ConcreteScheduleNode {
   void StorageAlign(const BlockRV& block_rv, int buffer_index, int axis, int 
factor,
                     int offset) final;
   void SetScope(const BlockRV& block_rv, int buffer_index, const String& 
storage_scope) final;
+  void UnsafeSetDType(const BlockRV& block_rv, int buffer_index, const String& 
dtype) final;
   /******** Schedule: Blockize & Tensorize ********/
   BlockRV Blockize(const LoopRV& loop_rv, bool preserve_unit_iters) final;
   void Tensorize(const BlockRV& block_rv, const String& intrin, bool 
preserve_unit_iters) final;
diff --git a/src/tir/schedule/transform.cc b/src/tir/schedule/transform.cc
index e91c5d142c..baa7f44bbc 100644
--- a/src/tir/schedule/transform.cc
+++ b/src/tir/schedule/transform.cc
@@ -43,6 +43,16 @@ Buffer WithScope(const Buffer& buffer, const String& scope) {
   return Buffer(new_buffer);
 }
 
+Buffer WithDType(const Buffer& buffer, const DataType& dtype) {
+  ObjectPtr<BufferNode> new_buffer = make_object<BufferNode>(*buffer.get());
+  new_buffer->dtype = dtype;
+  const auto* ptr_type = TVM_TYPE_AS(buffer->data->type_annotation, 
PointerTypeNode);
+  new_buffer->data =
+      Var(buffer->data->name_hint, PointerType(PrimType(dtype), 
ptr_type->storage_scope));
+  new_buffer->name = buffer->name;
+  return Buffer(new_buffer);
+}
+
 Array<BufferRegion> ReplaceBuffer(Array<BufferRegion> regions, const Buffer& 
source,
                                   const Buffer& target) {
   regions.MutateByApply([&source, &target](BufferRegion region) -> 
BufferRegion {
diff --git a/src/tir/schedule/transform.h b/src/tir/schedule/transform.h
index 3593d6b9a4..d2412436c7 100644
--- a/src/tir/schedule/transform.h
+++ b/src/tir/schedule/transform.h
@@ -53,6 +53,14 @@ Block WithAnnotation(const BlockNode* block, const String& 
attr_key, const Objec
  */
 Buffer WithScope(const Buffer& buffer, const String& scope);
 
+/*!
+ * \brief Create a new buffer by changing the data type.
+ * \param buffer The given buffer.
+ * \param dtype The target data type.
+ * \return The new buffer with target data type.
+ */
+Buffer WithDType(const Buffer& buffer, const DataType& dtype);
+
 /*!
  * \brief Replaces the buffer within the specific sequence of regions
  * \param regions The regions whose buffers are to be replaced
@@ -131,9 +139,9 @@ class ReplaceBufferMutator : public StmtExprMutator {
     return node;
   }
 
-  Stmt VisitStmt_(const BufferStoreNode* op) final;
+  Stmt VisitStmt_(const BufferStoreNode* op) override;
 
-  PrimExpr VisitExpr_(const BufferLoadNode* op) final;
+  PrimExpr VisitExpr_(const BufferLoadNode* op) override;
 
   virtual MatchBufferRegion VisitMatchBufferRegion(const MatchBufferRegion& 
match_buffer);
 
diff --git a/tests/python/unittest/test_tir_schedule_set_dtype.py 
b/tests/python/unittest/test_tir_schedule_set_dtype.py
new file mode 100644
index 0000000000..7f0900619b
--- /dev/null
+++ b/tests/python/unittest/test_tir_schedule_set_dtype.py
@@ -0,0 +1,125 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-function-docstring,missing-module-docstring
+
+import pytest
+import tvm
+import tvm.testing
+from tvm import tir
+from tvm.script import tir as T
+from tvm.tir.schedule.testing import verify_trace_roundtrip
+
+# fmt: off
+# pylint: disable=no-member,invalid-name,unused-variable,unexpected-keyword-arg
+
+@T.prim_func
+def element_wise(A: T.Buffer((128, 128), "float32"), C: T.Buffer((128, 128), 
"float32")) -> None:
+    B = T.alloc_buffer((128, 128), dtype="float32")
+
+    for i, j in T.grid(128, 128):
+        with T.block("B"):
+            vi, vj = T.axis.remap("SS", [i, j])
+            B[vi, vj] = A[vi, vj] * 2.0
+    for i, j in T.grid(128, 128):
+        with T.block("C"):
+            vi, vj = T.axis.remap("SS", [i, j])
+            C[vi, vj] = B[vi, vj] + 1.0
+
+@T.prim_func
+def element_wise_set_dtype(A: T.Buffer((128, 128), "float32"), C: 
T.Buffer((128, 128), "float32")):
+    B = T.alloc_buffer((128, 128), "float16")
+    for i, j in T.grid(128, 128):
+        with T.block("B"):
+            vi, vj = T.axis.remap("SS", [i, j])
+            T.reads(A[vi, vj])
+            T.writes(B[vi, vj])
+            B[vi, vj] = T.cast(A[vi, vj] * 2.0, "float16")
+    for i, j in T.grid(128, 128):
+        with T.block("C"):
+            vi, vj = T.axis.remap("SS", [i, j])
+            T.reads(B[vi, vj])
+            T.writes(C[vi, vj])
+            C[vi, vj] = T.cast(B[vi, vj], "float32") + 1.0
+
+@T.prim_func
+def element_wise_subregion_match(A: T.Buffer((128, 128), "float32"), C: 
T.Buffer((128, 128), "float32")) -> None:
+    B = T.alloc_buffer((128, 128), dtype="float32")
+
+    for i, j in T.grid(128, 128):
+        with T.block("B"):
+            vi, vj = T.axis.remap("SS", [i, j])
+            B_subregion0 = T.match_buffer(B[vi, vj], [], offset_factor=1)
+            B_subregion0[()] = A[vi, vj] * 2.0
+    for i, j in T.grid(128, 128):
+        with T.block("C"):
+            vi, vj = T.axis.remap("SS", [i, j])
+            B_subregion1 = T.match_buffer(B[vi, vj], [], offset_factor=1)
+            C[vi, vj] = B_subregion1[()] + 1.0
+
+
+@T.prim_func
+def element_wise_subregion_match_set_dtype(A: T.Buffer((128, 128), "float32"), 
C: T.Buffer((128, 128), "float32")) -> None:
+    B = T.alloc_buffer((128, 128), "float16")
+    for i, j in T.grid(128, 128):
+        with T.block("B"):
+            vi, vj = T.axis.remap("SS", [i, j])
+            T.reads(A[vi, vj])
+            T.writes(B[vi, vj])
+            B_subregion0 = T.match_buffer(B[vi, vj], (), "float16", 
offset_factor=1)
+            B_subregion0[()] = T.cast(A[vi, vj] * 2.0, "float16")
+    for i, j in T.grid(128, 128):
+        with T.block("C"):
+            vi, vj = T.axis.remap("SS", [i, j])
+            T.reads(B[vi, vj])
+            T.writes(C[vi, vj])
+            B_subregion1 = T.match_buffer(B[vi, vj], (), "float16", 
offset_factor=1)
+            C[vi, vj] = T.cast(B_subregion1[()], "float32") + 1.0
+
+
+use_block_name = tvm.testing.parameter(by_dict={"block_obj": False, 
"block_name": True})
+
+def test_set_dtype(use_block_name):
+    func = element_wise
+    sch = tir.Schedule(func, debug_mask="all")
+    sch.unsafe_set_dtype("B" if use_block_name else sch.get_block("B"), 0, 
"float16")
+    tvm.ir.assert_structural_equal(element_wise_set_dtype, sch.mod["main"])
+    verify_trace_roundtrip(sch=sch, mod=func)
+
+def test_set_dtype_fail_on_output_buffer(use_block_name):
+    func = element_wise
+    sch = tir.Schedule(func, debug_mask='all')
+    with pytest.raises(tvm.tir.ScheduleError):
+        sch.unsafe_set_dtype('C' if use_block_name else sch.get_block("C"), 0, 
"float16")
+
+def test_set_dtype_fail_on_index_out_of_bound():
+    func = element_wise
+    sch = tir.Schedule(func, debug_mask='all')
+    with pytest.raises(tvm.tir.ScheduleError):
+        sch.unsafe_set_dtype(sch.get_block("B"), 1, "float64")
+    with pytest.raises(tvm.tir.ScheduleError):
+        sch.unsafe_set_dtype(sch.get_block("B"), -1, "float64")
+
+def test_set_dtype_subregion():
+    func = element_wise_subregion_match
+    sch = tir.Schedule(func, debug_mask='all')
+    sch.unsafe_set_dtype(sch.get_block("B"), 0, "float16")
+    tvm.ir.assert_structural_equal(element_wise_subregion_match_set_dtype, 
sch.mod["main"])
+    verify_trace_roundtrip(sch=sch, mod=func)
+
+
+if __name__ == "__main__":
+    tvm.testing.main()

Reply via email to