This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new c7970ddd79 [TensorIR] New schedule primitive `set_dtype` (#14316)
c7970ddd79 is described below
commit c7970ddd79a1e6bbf5e07ab2c515ec4991242ef7
Author: Zihao Ye <[email protected]>
AuthorDate: Wed Mar 22 00:10:54 2023 -0700
[TensorIR] New schedule primitive `set_dtype` (#14316)
# Motivation
Currently, we lack a schedule primitive to change the data type of an
allocated buffer (e.g. one created via `cache_read`/`cache_write`), and thus we cannot
perform type conversion while loading data from global to shared memory.
This PR adds a new schedule primitive `set_dtype` that follows the
interface of `set_scope` and allows users to customize the allocated buffers'
data type.
# Example
Before running `set_dtype`:
```python
@T.prim_func
def before_set_dtype(
A: T.Buffer((128, 128), "float32"), C: T.Buffer((128, 128), "float32")
) -> None:
B = T.alloc_buffer((128, 128), dtype="float32")
for i, j in T.grid(128, 128):
with T.block("B"):
vi, vj = T.axis.remap("SS", [i, j])
B[vi, vj] = A[vi, vj] * 2.0
for i, j in T.grid(128, 128):
with T.block("C"):
vi, vj = T.axis.remap("SS", [i, j])
C[vi, vj] = B[vi, vj] + 1.0
```
then we perform the `set_dtype` schedule:
```python
sch = tir.Schedule(before_set_dtype)
sch.unsafe_set_dtype("B", buffer_index=0, dtype="float16")
print(sch.mod["main"].script())
```
we get transformed code:
```python
@T.prim_func
def after_set_dtype(
A: T.Buffer((128, 128), "float32"), C: T.Buffer((128, 128), "float32")
) -> None:
B = T.alloc_buffer((128, 128), dtype="float16")
for i, j in T.grid(128, 128):
with T.block("B"):
vi, vj = T.axis.remap("SS", [i, j])
B[vi, vj] = T.cast(A[vi, vj] * 2.0, "float16")
for i, j in T.grid(128, 128):
with T.block("C"):
vi, vj = T.axis.remap("SS", [i, j])
C[vi, vj] = T.cast(B[vi, vj], "float32") + 1.0
```
where data type conversions are inserted automatically.
# Other Usage
Using the combination of `cache_read` + `set_dtype` can help us load data
from the memory hierarchy while converting data to the desired type.
---
include/tvm/tir/schedule/schedule.h | 12 +-
python/tvm/tir/schedule/schedule.py | 79 ++++++++++++-
src/tir/schedule/concrete_schedule.cc | 8 ++
src/tir/schedule/concrete_schedule.h | 1 +
src/tir/schedule/primitive.h | 12 ++
src/tir/schedule/primitive/block_annotate.cc | 117 +++++++++++++++++++
src/tir/schedule/schedule.cc | 2 +
src/tir/schedule/traced_schedule.cc | 11 ++
src/tir/schedule/traced_schedule.h | 1 +
src/tir/schedule/transform.cc | 10 ++
src/tir/schedule/transform.h | 12 +-
.../python/unittest/test_tir_schedule_set_dtype.py | 125 +++++++++++++++++++++
12 files changed, 385 insertions(+), 5 deletions(-)
diff --git a/include/tvm/tir/schedule/schedule.h
b/include/tvm/tir/schedule/schedule.h
index 01255e6e3f..570560c62d 100644
--- a/include/tvm/tir/schedule/schedule.h
+++ b/include/tvm/tir/schedule/schedule.h
@@ -589,13 +589,23 @@ class ScheduleNode : public runtime::Object {
virtual void StorageAlign(const BlockRV& block_rv, int buffer_index, int
axis, int factor,
int offset) = 0;
/*!
- * \brief Set the storage scope of a buffer, where the buffer is specified
by the a block and a
+ * \brief Set the storage scope of a buffer, where the buffer is specified
by a block and a
* write-index
* \param block_rv The producer block of the buffer
* \param buffer_index The index of the buffer in block's write region
* \param storage_scope The storage scope to be set
*/
virtual void SetScope(const BlockRV& block_rv, int buffer_index, const
String& storage_scope) = 0;
+ /*!
+ * \brief Set the data type of a buffer, where the buffer is specified by a
block and a
+ * write-index
+ * \note This schedule primitive is unsafe and may change correctness of
program because of
+ * type conversion, please use with caution.
+ * \param block_rv The producer block of the buffer
+ * \param buffer_index The index of the buffer in block's write region
+ * \param dtype The data type to be set
+ */
+ virtual void UnsafeSetDType(const BlockRV& block_rv, int buffer_index, const
String& dtype) = 0;
/******** Schedule: Blockize & Tensorize ********/
/*!
* \brief Convert the subtree rooted at a specific loop into a block.
diff --git a/python/tvm/tir/schedule/schedule.py
b/python/tvm/tir/schedule/schedule.py
index c27007682a..68f0b9454c 100644
--- a/python/tvm/tir/schedule/schedule.py
+++ b/python/tvm/tir/schedule/schedule.py
@@ -2369,7 +2369,7 @@ class Schedule(Object):
self, block: Union[BlockRV, str], buffer_index: Union[int, str,
Buffer], storage_scope: str
) -> None:
"""Set the storage scope of a buffer, where the buffer is
- specified by the a block and a write-index
+ specified by a block and a write-index.
Parameters
----------
@@ -2431,7 +2431,7 @@ class Schedule(Object):
Note
----
- Set_scope requires the buffer to be an intermediate buffer defined via
`alloc_buffer`.
+ `set_scope` requires the buffer to be an intermediate buffer defined
via `alloc_buffer`.
"""
block = self._normalize_block_arg(block)
if not isinstance(buffer_index, int):
@@ -2442,6 +2442,81 @@ class Schedule(Object):
self, block, buffer_index, storage_scope
)
+ @type_checked
+ def unsafe_set_dtype(self, block: Union[BlockRV, str], buffer_index: int,
dtype: str) -> None:
+ """Set the data type of a buffer, where the buffer is
+ specified by a block and a write-index.
+
+ This schedule primitive is unsafe and may change the correctness of
program because of
+ type conversion, please use with caution.
+
+ Parameters
+ ----------
+ block : Union[BlockRV, str]
+ The producer block of the buffer
+ buffer_index : int
+ The index of the buffer in block's write region
+ dtype : str
+ The data type to be set
+
+ Examples
+ --------
+
+ Before set_dtype, in TensorIR, the IR is:
+
+ .. code-block:: python
+
+ @T.prim_func
+ def before_set_dtype(
+ A: T.Buffer((128, 128), "float32"), C: T.Buffer((128, 128),
"float32")
+ ) -> None:
+ B = T.alloc_buffer((128, 128), dtype="float32")
+
+ for i, j in T.grid(128, 128):
+ with T.block("B"):
+ vi, vj = T.axis.remap("SS", [i, j])
+ B[vi, vj] = A[vi, vj] * 2.0
+ for i, j in T.grid(128, 128):
+ with T.block("C"):
+ vi, vj = T.axis.remap("SS", [i, j])
+ C[vi, vj] = B[vi, vj] + 1.0
+
+ Create the schedule and do set_dtype:
+
+ .. code-block:: python
+
+ sch = tir.Schedule(before_set_dtype)
+ sch.unsafe_set_dtype("B", buffer_index=0, dtype="float16")
+ print(sch.mod["main"].script())
+
+ After applying set_dtype, the IR becomes:
+
+ .. code-block:: python
+
+ @T.prim_func
+ def after_set_dtype(
+ A: T.Buffer((128, 128), "float32"), C: T.Buffer((128, 128),
"float32")
+ ) -> None:
+ B = T.alloc_buffer((128, 128), dtype="float16")
+
+ for i, j in T.grid(128, 128):
+ with T.block("B"):
+ vi, vj = T.axis.remap("SS", [i, j])
+ B[vi, vj] = T.cast(A[vi, vj] * 2.0, "float16")
+ for i, j in T.grid(128, 128):
+ with T.block("C"):
+ vi, vj = T.axis.remap("SS", [i, j])
+ C[vi, vj] = T.cast(B[vi, vj], "float32") + 1.0
+
+ Note
+ ----
+ `set_dtype` requires the buffer to be an intermediate buffer defined
via `alloc_buffer`.
+ """
+ block = self._normalize_block_arg(block)
+ _ffi_api.ScheduleUnsafeSetDType( # type: ignore # pylint:
disable=no-member
+ self, block, buffer_index, dtype
+ )
+
########## Schedule: Blockize & Tensorize ##########
@type_checked
diff --git a/src/tir/schedule/concrete_schedule.cc
b/src/tir/schedule/concrete_schedule.cc
index 6593579725..93ea38169d 100644
--- a/src/tir/schedule/concrete_schedule.cc
+++ b/src/tir/schedule/concrete_schedule.cc
@@ -727,6 +727,14 @@ void ConcreteScheduleNode::SetScope(const BlockRV&
block_rv, int buffer_index,
this->state_->DebugVerify();
}
+void ConcreteScheduleNode::UnsafeSetDType(const BlockRV& block_rv, int
buffer_index,
+ const String& dtype) {
+ TVM_TIR_SCHEDULE_BEGIN();
+ tir::UnsafeSetDType(state_, this->GetSRef(block_rv), buffer_index, dtype);
+ TVM_TIR_SCHEDULE_END("set-dtype", this->error_render_level_);
+ this->state_->DebugVerify();
+}
+
/******** Schedule: Reduction ********/
BlockRV ConcreteScheduleNode::DecomposeReduction(const BlockRV& block_rv,
const LoopRV& loop_rv) {
diff --git a/src/tir/schedule/concrete_schedule.h
b/src/tir/schedule/concrete_schedule.h
index 290b6a4456..227288b232 100644
--- a/src/tir/schedule/concrete_schedule.h
+++ b/src/tir/schedule/concrete_schedule.h
@@ -146,6 +146,7 @@ class ConcreteScheduleNode : public ScheduleNode {
void StorageAlign(const BlockRV& block_rv, int buffer_index, int axis, int
factor,
int offset) override;
void SetScope(const BlockRV& block_rv, int buffer_index, const String&
storage_scope) override;
+ void UnsafeSetDType(const BlockRV& block_rv, int buffer_index, const String&
dtype) override;
/******** Schedule: Blockize & Tensorize ********/
BlockRV Blockize(const LoopRV& loop_rv, bool preserve_unit_iters) override;
void Tensorize(const BlockRV& block_rv, const String& intrin, bool
preserve_unit_iters) override;
diff --git a/src/tir/schedule/primitive.h b/src/tir/schedule/primitive.h
index 89cdf68a45..09185498e1 100644
--- a/src/tir/schedule/primitive.h
+++ b/src/tir/schedule/primitive.h
@@ -479,6 +479,18 @@ TVM_DLL void StorageAlign(ScheduleState self, const
StmtSRef& block_sref, int bu
*/
TVM_DLL void SetScope(ScheduleState self, const StmtSRef& block_sref, int
buffer_index,
const String& storage_scope);
+/*!
+ * \brief Set the data type of a buffer, where the buffer is specified by a
block and a
+ * write-index
+ * \note This schedule primitive is unsafe and may change correctness of
program because of
+ * type conversion, please use with caution.
+ * \param self The state of the schedule
+ * \param block_sref The sref of the producer block of the buffer
+ * \param buffer_index The index of the buffer in block's write region
+ * \param dtype The data type to be set
+ */
+TVM_DLL void UnsafeSetDType(ScheduleState self, const StmtSRef& block_sref,
int buffer_index,
+ const String& dtype);
/*!
* \brief Set the axis separator of a buffer, where the buffer is specified by
a block and a read
* or write index
diff --git a/src/tir/schedule/primitive/block_annotate.cc
b/src/tir/schedule/primitive/block_annotate.cc
index 0912e36836..3f1789b3d6 100644
--- a/src/tir/schedule/primitive/block_annotate.cc
+++ b/src/tir/schedule/primitive/block_annotate.cc
@@ -16,6 +16,8 @@
* specific language governing permissions and limitations
* under the License.
*/
+#include <tvm/tir/expr.h>
+
#include "../utils.h"
namespace tvm {
@@ -297,6 +299,93 @@ void SetScope(ScheduleState self, const StmtSRef&
block_sref, int buffer_index,
self->Replace(alloc_site_sref, new_block, block_reuse_map);
}
+/*!
+ * \brief A helper mutator which recursively mutates the old buffer's data
type, inserts data type
+ * conversions, and collects the block sref reuse information for the
following replacement.
+ */
+class DTypeMutator : private ReplaceBufferMutator {
+ public:
+ /*!
+ * \param allocate_site The block where `old_buffer` was allocated.
+ * \param old_buffer The old buffer
+ * \param dtype The data type to be set
+ * \param block_sref_reuse The block sref reuse map to be updated
+ * \return The new block after the mutation
+ */
+ static Block Mutate(const Block& allocate_site, const Buffer& old_buffer,
const DataType& dtype,
+ Map<Block, Block>* block_sref_reuse) {
+ Buffer new_buffer = WithDType(old_buffer, dtype);
+ DTypeMutator mutator(old_buffer, new_buffer, dtype, block_sref_reuse);
+ Stmt new_block = mutator.VisitStmt(allocate_site);
+ return Downcast<Block>(new_block);
+ }
+
+ private:
+ DTypeMutator(const Buffer& old_buffer, Buffer new_buffer, const DataType&
dtype,
+ Map<Block, Block>* block_sref_reuse)
+ : ReplaceBufferMutator(old_buffer, std::move(new_buffer),
block_sref_reuse),
+ src_dtype_(old_buffer->dtype),
+ tgt_dtype_(dtype) {}
+
+ MatchBufferRegion VisitMatchBufferRegion(const MatchBufferRegion&
match_buffer) final {
+ auto it = buffer_var_map_.find(match_buffer->source->buffer->data.get());
+ if (it != buffer_var_map_.end()) {
+ Buffer new_target_buffer = WithDType(match_buffer->buffer,
it->second->dtype);
+ buffer_var_map_[match_buffer->buffer->data.get()] = new_target_buffer;
+ return MatchBufferRegion(new_target_buffer,
+ BufferRegion(it->second,
match_buffer->source->region));
+ } else {
+ return match_buffer;
+ }
+ }
+
+ Stmt VisitStmt_(const BufferStoreNode* op) final {
+ BufferStore node = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op));
+ auto it = buffer_var_map_.find(node->buffer->data.get());
+ if (it != buffer_var_map_.end()) {
+ node.CopyOnWrite()->buffer = it->second;
+ node.CopyOnWrite()->value = Cast(tgt_dtype_, node->value);
+ }
+ return node;
+ }
+
+ PrimExpr VisitExpr_(const BufferLoadNode* op) final {
+ BufferLoad node = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op));
+ auto it = buffer_var_map_.find(node->buffer->data.get());
+ if (it != buffer_var_map_.end()) {
+ return Cast(src_dtype_, BufferLoad(it->second, node->indices));
+ }
+ return node;
+ }
+
+ DataType src_dtype_, tgt_dtype_;
+};
+
+void UnsafeSetDType(ScheduleState self, const StmtSRef& block_sref, int
buffer_index,
+ const String& dtype) {
+ const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref);
+ Buffer buffer =
+ GetNthAccessBuffer(self, GetRef<Block>(block), buffer_index,
BufferIndexType::kWrite);
+ DataType target_dtype(runtime::String2DLDataType(dtype));
+
+ // Step 1. If `dtype` equals the original data type, just return.
+ if (buffer->dtype == target_dtype) {
+ return;
+ }
+
+ // Step 2. Get the allocation site of the target buffer.
+ StmtSRef alloc_site_sref =
+ NonAllocatedBufferError::CheckAndGetBufferAllocationSite(self->mod,
block_sref, buffer);
+ const BlockNode* alloc_site = TVM_SREF_TO_BLOCK(alloc_site_sref);
+
+ // Step 3. Recursively replace the old buffer with a new buffer, where the new
buffer has the given
+ // dtype, and insert data type conversions.
+ Map<Block, Block> block_reuse_map;
+ Block new_block =
+ DTypeMutator::Mutate(GetRef<Block>(alloc_site), buffer, target_dtype,
&block_reuse_map);
+ self->Replace(alloc_site_sref, new_block, block_reuse_map);
+}
+
/******** InstructionKind Registration ********/
struct StorageAlignTraits : public UnpackedInstTraits<StorageAlignTraits> {
@@ -356,8 +445,36 @@ struct SetScopeTraits : public
UnpackedInstTraits<SetScopeTraits> {
friend struct ::tvm::tir::UnpackedInstTraits;
};
+struct UnsafeSetDTypeTraits : public UnpackedInstTraits<UnsafeSetDTypeTraits> {
+ static constexpr const char* kName = "UnsafeSetDType";
+ static constexpr bool kIsPure = false;
+
+ private:
+ static constexpr size_t kNumInputs = 1;
+ static constexpr size_t kNumAttrs = 2;
+ static constexpr size_t kNumDecisions = 0;
+
+ static void UnpackedApplyToSchedule(Schedule sch, BlockRV block_rv, Integer
buffer_index,
+ String dtype) {
+ return sch->UnsafeSetDType(block_rv, buffer_index->value, dtype);
+ }
+
+ static String UnpackedAsPython(Array<String> outputs, String block_rv,
Integer buffer_index,
+ String dtype) {
+ PythonAPICall py("unsafe_set_dtype");
+ py.Input("block", block_rv);
+ py.Input("buffer_index", buffer_index);
+ py.Input("dtype", dtype);
+ return py.Str();
+ }
+
+ template <typename>
+ friend struct ::tvm::tir::UnpackedInstTraits;
+};
+
TVM_REGISTER_INST_KIND_TRAITS(StorageAlignTraits);
TVM_REGISTER_INST_KIND_TRAITS(SetScopeTraits);
+TVM_REGISTER_INST_KIND_TRAITS(UnsafeSetDTypeTraits);
} // namespace tir
} // namespace tvm
diff --git a/src/tir/schedule/schedule.cc b/src/tir/schedule/schedule.cc
index ad512a5fcb..a0e39b74d3 100644
--- a/src/tir/schedule/schedule.cc
+++ b/src/tir/schedule/schedule.cc
@@ -217,6 +217,8 @@ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleStorageAlign")
.set_body_method<Schedule>(&ScheduleNode::StorageAlign);
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleSetScope")
.set_body_method<Schedule>(&ScheduleNode::SetScope);
+TVM_REGISTER_GLOBAL("tir.schedule.ScheduleUnsafeSetDType")
+ .set_body_method<Schedule>(&ScheduleNode::UnsafeSetDType);
/******** (FFI) Blockize & Tensorize ********/
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleBlockize")
.set_body_method<Schedule>(&ScheduleNode::Blockize);
diff --git a/src/tir/schedule/traced_schedule.cc
b/src/tir/schedule/traced_schedule.cc
index 9b9420302c..2b6a7f71d4 100644
--- a/src/tir/schedule/traced_schedule.cc
+++ b/src/tir/schedule/traced_schedule.cc
@@ -504,6 +504,17 @@ void TracedScheduleNode::SetScope(const BlockRV& block_rv,
int buffer_index,
/*outputs=*/{}));
}
+void TracedScheduleNode::UnsafeSetDType(const BlockRV& block_rv, int
buffer_index,
+ const String& dtype) {
+ ConcreteScheduleNode::UnsafeSetDType(block_rv, buffer_index, dtype);
+ static const InstructionKind& kind = InstructionKind::Get("UnsafeSetDType");
+ trace_->Append(/*inst=*/Instruction(
+ /*kind=*/kind,
+ /*inputs=*/{block_rv},
+ /*attrs=*/{Integer(buffer_index), dtype},
+ /*outputs=*/{}));
+}
+
/******** Schedule: Blockize & Tensorize ********/
BlockRV TracedScheduleNode::Blockize(const LoopRV& loop_rv, bool
preserve_unit_iters) {
diff --git a/src/tir/schedule/traced_schedule.h
b/src/tir/schedule/traced_schedule.h
index 7854adad39..8b9621c749 100644
--- a/src/tir/schedule/traced_schedule.h
+++ b/src/tir/schedule/traced_schedule.h
@@ -105,6 +105,7 @@ class TracedScheduleNode : public ConcreteScheduleNode {
void StorageAlign(const BlockRV& block_rv, int buffer_index, int axis, int
factor,
int offset) final;
void SetScope(const BlockRV& block_rv, int buffer_index, const String&
storage_scope) final;
+ void UnsafeSetDType(const BlockRV& block_rv, int buffer_index, const String&
dtype) final;
/******** Schedule: Blockize & Tensorize ********/
BlockRV Blockize(const LoopRV& loop_rv, bool preserve_unit_iters) final;
void Tensorize(const BlockRV& block_rv, const String& intrin, bool
preserve_unit_iters) final;
diff --git a/src/tir/schedule/transform.cc b/src/tir/schedule/transform.cc
index e91c5d142c..baa7f44bbc 100644
--- a/src/tir/schedule/transform.cc
+++ b/src/tir/schedule/transform.cc
@@ -43,6 +43,16 @@ Buffer WithScope(const Buffer& buffer, const String& scope) {
return Buffer(new_buffer);
}
+Buffer WithDType(const Buffer& buffer, const DataType& dtype) {
+ ObjectPtr<BufferNode> new_buffer = make_object<BufferNode>(*buffer.get());
+ new_buffer->dtype = dtype;
+ const auto* ptr_type = TVM_TYPE_AS(buffer->data->type_annotation,
PointerTypeNode);
+ new_buffer->data =
+ Var(buffer->data->name_hint, PointerType(PrimType(dtype),
ptr_type->storage_scope));
+ new_buffer->name = buffer->name;
+ return Buffer(new_buffer);
+}
+
Array<BufferRegion> ReplaceBuffer(Array<BufferRegion> regions, const Buffer&
source,
const Buffer& target) {
regions.MutateByApply([&source, &target](BufferRegion region) ->
BufferRegion {
diff --git a/src/tir/schedule/transform.h b/src/tir/schedule/transform.h
index 3593d6b9a4..d2412436c7 100644
--- a/src/tir/schedule/transform.h
+++ b/src/tir/schedule/transform.h
@@ -53,6 +53,14 @@ Block WithAnnotation(const BlockNode* block, const String&
attr_key, const Objec
*/
Buffer WithScope(const Buffer& buffer, const String& scope);
+/*!
+ * \brief Create a new buffer by changing the data type.
+ * \param buffer The given buffer.
+ * \param dtype The target data type.
+ * \return The new buffer with target data type.
+ */
+Buffer WithDType(const Buffer& buffer, const DataType& dtype);
+
/*!
* \brief Replaces the buffer within the specific sequence of regions
* \param regions The regions whose buffers are to be replaced
@@ -131,9 +139,9 @@ class ReplaceBufferMutator : public StmtExprMutator {
return node;
}
- Stmt VisitStmt_(const BufferStoreNode* op) final;
+ Stmt VisitStmt_(const BufferStoreNode* op) override;
- PrimExpr VisitExpr_(const BufferLoadNode* op) final;
+ PrimExpr VisitExpr_(const BufferLoadNode* op) override;
virtual MatchBufferRegion VisitMatchBufferRegion(const MatchBufferRegion&
match_buffer);
diff --git a/tests/python/unittest/test_tir_schedule_set_dtype.py
b/tests/python/unittest/test_tir_schedule_set_dtype.py
new file mode 100644
index 0000000000..7f0900619b
--- /dev/null
+++ b/tests/python/unittest/test_tir_schedule_set_dtype.py
@@ -0,0 +1,125 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-function-docstring,missing-module-docstring
+
+import pytest
+import tvm
+import tvm.testing
+from tvm import tir
+from tvm.script import tir as T
+from tvm.tir.schedule.testing import verify_trace_roundtrip
+
+# fmt: off
+# pylint: disable=no-member,invalid-name,unused-variable,unexpected-keyword-arg
+
[email protected]_func
+def element_wise(A: T.Buffer((128, 128), "float32"), C: T.Buffer((128, 128),
"float32")) -> None:
+ B = T.alloc_buffer((128, 128), dtype="float32")
+
+ for i, j in T.grid(128, 128):
+ with T.block("B"):
+ vi, vj = T.axis.remap("SS", [i, j])
+ B[vi, vj] = A[vi, vj] * 2.0
+ for i, j in T.grid(128, 128):
+ with T.block("C"):
+ vi, vj = T.axis.remap("SS", [i, j])
+ C[vi, vj] = B[vi, vj] + 1.0
+
[email protected]_func
+def element_wise_set_dtype(A: T.Buffer((128, 128), "float32"), C:
T.Buffer((128, 128), "float32")):
+ B = T.alloc_buffer((128, 128), "float16")
+ for i, j in T.grid(128, 128):
+ with T.block("B"):
+ vi, vj = T.axis.remap("SS", [i, j])
+ T.reads(A[vi, vj])
+ T.writes(B[vi, vj])
+ B[vi, vj] = T.cast(A[vi, vj] * 2.0, "float16")
+ for i, j in T.grid(128, 128):
+ with T.block("C"):
+ vi, vj = T.axis.remap("SS", [i, j])
+ T.reads(B[vi, vj])
+ T.writes(C[vi, vj])
+ C[vi, vj] = T.cast(B[vi, vj], "float32") + 1.0
+
[email protected]_func
+def element_wise_subregion_match(A: T.Buffer((128, 128), "float32"), C:
T.Buffer((128, 128), "float32")) -> None:
+ B = T.alloc_buffer((128, 128), dtype="float32")
+
+ for i, j in T.grid(128, 128):
+ with T.block("B"):
+ vi, vj = T.axis.remap("SS", [i, j])
+ B_subregion0 = T.match_buffer(B[vi, vj], [], offset_factor=1)
+ B_subregion0[()] = A[vi, vj] * 2.0
+ for i, j in T.grid(128, 128):
+ with T.block("C"):
+ vi, vj = T.axis.remap("SS", [i, j])
+ B_subregion1 = T.match_buffer(B[vi, vj], [], offset_factor=1)
+ C[vi, vj] = B_subregion1[()] + 1.0
+
+
[email protected]_func
+def element_wise_subregion_match_set_dtype(A: T.Buffer((128, 128), "float32"),
C: T.Buffer((128, 128), "float32")) -> None:
+ B = T.alloc_buffer((128, 128), "float16")
+ for i, j in T.grid(128, 128):
+ with T.block("B"):
+ vi, vj = T.axis.remap("SS", [i, j])
+ T.reads(A[vi, vj])
+ T.writes(B[vi, vj])
+ B_subregion0 = T.match_buffer(B[vi, vj], (), "float16",
offset_factor=1)
+ B_subregion0[()] = T.cast(A[vi, vj] * 2.0, "float16")
+ for i, j in T.grid(128, 128):
+ with T.block("C"):
+ vi, vj = T.axis.remap("SS", [i, j])
+ T.reads(B[vi, vj])
+ T.writes(C[vi, vj])
+ B_subregion1 = T.match_buffer(B[vi, vj], (), "float16",
offset_factor=1)
+ C[vi, vj] = T.cast(B_subregion1[()], "float32") + 1.0
+
+
+use_block_name = tvm.testing.parameter(by_dict={"block_obj": False,
"block_name": True})
+
+def test_set_dtype(use_block_name):
+ func = element_wise
+ sch = tir.Schedule(func, debug_mask="all")
+ sch.unsafe_set_dtype("B" if use_block_name else sch.get_block("B"), 0,
"float16")
+ tvm.ir.assert_structural_equal(element_wise_set_dtype, sch.mod["main"])
+ verify_trace_roundtrip(sch=sch, mod=func)
+
+def test_set_dtype_fail_on_output_buffer(use_block_name):
+ func = element_wise
+ sch = tir.Schedule(func, debug_mask='all')
+ with pytest.raises(tvm.tir.ScheduleError):
+ sch.unsafe_set_dtype('C' if use_block_name else sch.get_block("C"), 0,
"float16")
+
+def test_set_dtype_fail_on_index_out_of_bound():
+ func = element_wise
+ sch = tir.Schedule(func, debug_mask='all')
+ with pytest.raises(tvm.tir.ScheduleError):
+ sch.unsafe_set_dtype(sch.get_block("B"), 1, "float64")
+ with pytest.raises(tvm.tir.ScheduleError):
+ sch.unsafe_set_dtype(sch.get_block("B"), -1, "float64")
+
+def test_set_dtype_subregion():
+ func = element_wise_subregion_match
+ sch = tir.Schedule(func, debug_mask='all')
+ sch.unsafe_set_dtype(sch.get_block("B"), 0, "float16")
+ tvm.ir.assert_structural_equal(element_wise_subregion_match_set_dtype,
sch.mod["main"])
+ verify_trace_roundtrip(sch=sch, mod=func)
+
+
+if __name__ == "__main__":
+ tvm.testing.main()