This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new a06863a [TensorIR][M2a] Storage Align (#8693)
a06863a is described below
commit a06863ac9406c027b29f346fb6177268f612912d
Author: Wuwei Lin <[email protected]>
AuthorDate: Fri Aug 13 02:54:34 2021 -0400
[TensorIR][M2a] Storage Align (#8693)
This PR is part of the TensorIR upstreaming effort (#7527), which adds the
schedule primitive storage_align.
Co-authored-by: Siyuan Feng <[email protected]>
Co-authored-by: Bohan Hou
<[email protected]>
Co-authored-by: Ruihang Lai <[email protected]>
Co-authored-by: Hongyi Jin <[email protected]>
Co-authored-by: Junru Shao <[email protected]>
---
include/tvm/tir/schedule/schedule.h | 15 +
python/tvm/tir/schedule/schedule.py | 73 +++++
src/tir/schedule/analysis.h | 13 +
src/tir/schedule/analysis/analysis.cc | 39 +++
src/tir/schedule/concrete_schedule.cc | 10 +
src/tir/schedule/concrete_schedule.h | 3 +
src/tir/schedule/primitive.h | 20 ++
src/tir/schedule/primitive/block_annotate.cc | 308 +++++++++++++++++++++
src/tir/schedule/schedule.cc | 3 +
src/tir/schedule/traced_schedule.cc | 13 +
src/tir/schedule/traced_schedule.h | 3 +
src/tir/schedule/transform.cc | 35 +++
src/tir/schedule/transform.h | 41 +++
src/tir/transforms/compact_buffer_region.cc | 78 +++++-
.../unittest/test_tir_schedule_storage_align.py | 182 ++++++++++++
.../test_tir_transform_compact_buffer_region.py | 49 ++++
16 files changed, 882 insertions(+), 3 deletions(-)
diff --git a/include/tvm/tir/schedule/schedule.h
b/include/tvm/tir/schedule/schedule.h
index e208377..e5d2c44 100644
--- a/include/tvm/tir/schedule/schedule.h
+++ b/include/tvm/tir/schedule/schedule.h
@@ -264,6 +264,21 @@ class ScheduleNode : public runtime::Object {
* \return The rfactor block
*/
virtual BlockRV RFactor(const LoopRV& loop_rv, int factor_axis) = 0;
+ /******** Schedule: Block annotation ********/
+ /*!
+ * \brief Set alignment requirement for specific dimension such that
+ * stride[axis] == k * factor + offset for some k. This is useful to
set memory layout for
+ * more friendly memory access pattern. For example, we can set
alignment to be factor=2,
+ * offset=1 to avoid bank conflict for thread access on higher
dimension in GPU shared
+ * memory.
+ * \param block_rv The producer block of the buffer
+ * \param buffer_index The index of the buffer in block's write region
+ * \param axis The dimension to be specified for alignment
+ * \param factor The factor multiple of alignment
+ * \param offset The required offset factor
+ */
+ virtual void StorageAlign(const BlockRV& block_rv, int buffer_index, int
axis, int factor,
+ int offset) = 0;
/******** Schedule: Blockize & Tensorize ********/
/******** Schedule: Annotation ********/
/******** Schedule: Misc ********/
diff --git a/python/tvm/tir/schedule/schedule.py
b/python/tvm/tir/schedule/schedule.py
index 4bbb5b9..e8415d2 100644
--- a/python/tvm/tir/schedule/schedule.py
+++ b/python/tvm/tir/schedule/schedule.py
@@ -710,6 +710,79 @@ class Schedule(Object):
"""
return _ffi_api.ScheduleRFactor(self, loop, factor_axis) # type:
ignore # pylint: disable=no-member
+ ######## Schedule: Block annotation ########
+
+ def storage_align( # pylint: disable=too-many-arguments
+ self, block: BlockRV, buffer_index: int, axis: int, factor: int,
offset: int
+ ) -> None:
+ """Set alignment requirement for specific dimension such that
+ stride[axis] == k * factor + offset for some k. This is useful to set
memory layout for more
+ friendly memory access pattern. For example, we can set alignment to
be factor=2, offset=1
+ to avoid bank conflict for thread access on higher dimension in GPU
shared memory.
+
+ Parameters
+ ----------
+ block : BlockRV
+ The producer block of the buffer.
+ buffer_index : int
+ The index of the buffer in block's write region.
+ axis : int
+ The dimension to be specified for alignment.
+ factor : int
+ The factor multiple of alignment.
+ offset : int
+ The required offset factor.
+
+ Examples
+ --------
+
+ Before storage_align, in TensorIR, the IR is:
+
+ .. code-block:: python
+
+ @tvm.script.tir
+ def before_storage_align(a: ty.handle, c: ty.handle) -> None:
+ A = tir.match_buffer(a, (128, 128))
+ B = tir.alloc_buffer((128, 128))
+ C = tir.match_buffer(c, (128, 128))
+ with tir.block([128, 128], "B") as [vi, vj]:
+ B[vi, vj] = A[vi, vj] * 2.0
+ with tir.block([128, 128], "C") as [vi, vj]:
+ C[vi, vj] = B[vi, vj] + 1.0
+
+ Create the schedule and do storage_align:
+
+ .. code-block:: python
+
+ sch = tir.Schedule(before_storage_align)
+ sch.storage_align(sch.get_block("B"), buffer_index=0, axis=0,
factor=128, offset=1)
+ print(tvm.script.asscript(sch.mod["main"]))
+
+ After applying storage_align, the IR becomes:
+
+ .. code-block:: python
+
+ @tvm.script.tir
+ def after_storage_align(a: ty.handle, c: ty.handle) -> None:
+ A = tir.match_buffer(a, (128, 128))
+ B = tir.alloc_buffer((128, 128))
+ C = tir.match_buffer(c, (128, 128))
+ with tir.block([128, 128], "B") as [vi, vj]:
+ tir.block_attr({"buffer_dim_align": [[[0, 128, 1]]]})
+ B[vi, vj] = A[vi, vj] * 2.0
+ with tir.block([128, 128], "C") as [vi, vj]:
+ C[vi, vj] = B[vi, vj] + 1.0
+
+ After lowering passes, buffer B will have strides as [129, 1].
+
+ Note
+ ----
+ Storage_align requires the buffer to be an intermediate buffer defined
via `alloc_buffer`.
+ """
+ _ffi_api.ScheduleStorageAlign( # type: ignore # pylint:
disable=no-member
+ self, block, buffer_index, axis, factor, offset
+ )
+
########## Schedule: Blockize & Tensorize ##########
########## Schedule: Annotation ##########
diff --git a/src/tir/schedule/analysis.h b/src/tir/schedule/analysis.h
index 9baf4b5..370aa01 100644
--- a/src/tir/schedule/analysis.h
+++ b/src/tir/schedule/analysis.h
@@ -202,6 +202,19 @@ BlockRealize
CheckGetSingleChildBlockRealizeOnSRefTree(const ScheduleState& self
*/
BlockRealize GetBlockRealize(const ScheduleState& self, const StmtSRef&
block_sref);
+/******** Block-buffer relation ********/
+
+/*!
+ * \brief Get the buffer written by the block at the n-th index of its write
regions, or throw an exception if the index is out of bound
+ * \param self The schedule state
+ * \param block The queried block
+ * \param n The index of the queried buffer
+ * \return The buffer of the n-th write region of the block.
+ * \throw ScheduleError If the buffer index is out of bound.
+ */
+Buffer GetNthWriteBuffer(const ScheduleState& self, const Block& block, int n);
+
/******** Commutative Reducer ********/
/*!
diff --git a/src/tir/schedule/analysis/analysis.cc
b/src/tir/schedule/analysis/analysis.cc
index 3ee98ec..8d1913f 100644
--- a/src/tir/schedule/analysis/analysis.cc
+++ b/src/tir/schedule/analysis/analysis.cc
@@ -527,6 +527,45 @@ BlockRealize GetBlockRealize(const ScheduleState& self,
const StmtSRef& block_sr
}
}
+/******** Block-buffer relation ********/
+
+Buffer GetNthWriteBuffer(const ScheduleState& self, const Block& block, int n)
{
+ class WriteBufferIndexOutOfRangeError : public ScheduleError {
+ public:
+ explicit WriteBufferIndexOutOfRangeError(IRModule mod, Block block, int
buffer_index)
+ : mod_(std::move(mod)), block_(std::move(block)),
buffer_index_(buffer_index) {}
+
+ String FastErrorString() const final {
+ return "ScheduleError: The input `buffer_index` is out of range. It is
required to be in "
+ "range [0, num_write_regions) where `num_write_regions` is the
number of buffer "
+ "regions written by the block.";
+ }
+
+ String DetailRenderTemplate() const final {
+ std::ostringstream os;
+ size_t num_writes = block_->writes.size();
+ os << "The block {0} has " << num_writes
+ << " write regions, so `buffer_index` is required to be in [0, " <<
num_writes
+ << "). However, the input `buffer_index` is " << buffer_index_
+ << ", which is out of the expected range";
+ return os.str();
+ }
+
+ IRModule mod() const final { return mod_; }
+ Array<ObjectRef> LocationsOfInterest() const final { return {block_}; }
+
+ private:
+ IRModule mod_;
+ Block block_;
+ int buffer_index_;
+ };
+
+ if (n < 0 || static_cast<size_t>(n) >= block->writes.size()) {
+ throw WriteBufferIndexOutOfRangeError(self->mod, block, n);
+ }
+ return block->writes[n]->buffer;
+}
+
/******** Pattern Matcher ********/
/*!
diff --git a/src/tir/schedule/concrete_schedule.cc
b/src/tir/schedule/concrete_schedule.cc
index 610628c..688ea80 100644
--- a/src/tir/schedule/concrete_schedule.cc
+++ b/src/tir/schedule/concrete_schedule.cc
@@ -362,6 +362,16 @@ void ConcreteScheduleNode::ReverseComputeInline(const
BlockRV& block_rv) {
}
/******** Schedule: loop binding/annotation ********/
+/******** Schedule: block annotation ********/
+
+void ConcreteScheduleNode::StorageAlign(const BlockRV& block_rv, int
buffer_index, int axis,
+ int factor, int offset) {
+ TVM_TIR_SCHEDULE_BEGIN();
+ tir::StorageAlign(state_, this->GetSRef(block_rv), buffer_index, axis,
factor, offset);
+ TVM_TIR_SCHEDULE_END("storage-align", this->error_render_level_);
+ this->state_->DebugVerify();
+}
+
/******** Schedule: cache read/write ********/
/******** Schedule: reduction ********/
diff --git a/src/tir/schedule/concrete_schedule.h
b/src/tir/schedule/concrete_schedule.h
index ec0dd07..cfdd9c8 100644
--- a/src/tir/schedule/concrete_schedule.h
+++ b/src/tir/schedule/concrete_schedule.h
@@ -88,6 +88,9 @@ class ConcreteScheduleNode : public ScheduleNode {
void ReverseComputeInline(const BlockRV& block) override;
/******** Schedule: Reduction ********/
BlockRV RFactor(const LoopRV& loop_rv, int factor_axis) override;
+ /******** Schedule: Block annotation ********/
+ void StorageAlign(const BlockRV& block_rv, int buffer_index, int axis, int
factor,
+ int offset) override;
/******** Schedule: Blockize & Tensorize ********/
/******** Schedule: Annotation ********/
/******** Schedule: Misc ********/
diff --git a/src/tir/schedule/primitive.h b/src/tir/schedule/primitive.h
index 22e25f1..4b9c769 100644
--- a/src/tir/schedule/primitive.h
+++ b/src/tir/schedule/primitive.h
@@ -104,6 +104,26 @@ TVM_DLL void ReverseComputeInline(ScheduleState self,
const StmtSRef& block_sref
* \return The sref of the rfactor block
*/
TVM_DLL StmtSRef RFactor(ScheduleState self, const StmtSRef& loop_sref, int
factor_axis);
+/******** Schedule: Block annotation ********/
+/*!
+ * \brief Set alignment requirement for specific dimension such that
+ * stride[axis] == k * factor + offset for some k. This is useful to
set memory layout for
+ * more friendly memory access pattern. For example, we can set
alignment to be factor=2,
+ * offset=1 to avoid bank conflict for thread access on higher
dimension in GPU shared
+ * memory.
+ * \param block_sref The producer block of the buffer
+ * \param buffer_index The index of the buffer in block's write region
+ * \param axis The dimension to be specified for alignment
+ * \param factor The factor multiple of alignment
+ * \param offset The required offset factor
+ */
+TVM_DLL void StorageAlign(ScheduleState self, const StmtSRef& block_sref, int
buffer_index,
+ int axis, int factor, int offset);
+
+/******** Annotation types for StorageAlign ********/
+using StorageAlignTuple = Array<Integer>; // (buffer_idx,
axis, factor, offset)
+using StorageAlignAnnotation = Array<StorageAlignTuple>; // unordered array
of StorageAlignTuple
+
/******** Schedule: Blockize & Tensorize ********/
/******** Schedule: Annotation ********/
/******** Schedule: Misc ********/
diff --git a/src/tir/schedule/primitive/block_annotate.cc
b/src/tir/schedule/primitive/block_annotate.cc
new file mode 100644
index 0000000..937bc7c
--- /dev/null
+++ b/src/tir/schedule/primitive/block_annotate.cc
@@ -0,0 +1,308 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include "../transform.h"
+#include "../utils.h"
+
+namespace tvm {
+namespace tir {
+
+class StorageAlignAxisOutOfRangeError : public ScheduleError {
+ public:
+ explicit StorageAlignAxisOutOfRangeError(IRModule mod, Buffer buffer, int
axis)
+ : mod_(std::move(mod)), buffer_(std::move(buffer)), axis_(axis) {}
+
+ String FastErrorString() const final {
+ return "ScheduleError: The input `axis` is out of range. It is required to
be in range "
+ "[-ndim, ndim) where `ndim` is the number of dimensions of the
buffer to set "
+ "storage alignment.";
+ }
+
+ String DetailRenderTemplate() const final {
+ std::ostringstream os;
+ int ndim = static_cast<int>(buffer_->shape.size());
+ os << "The buffer to set storage alignment of, " << buffer_->name << ",
has " << ndim
+ << " dimension(s), so `axis` is required to be in [" << -(ndim) << ", "
<< ndim
+ << ") for storage_align. However, the input `axis` is " << axis_
+ << ", which is out of the expected range.";
+ return os.str();
+ }
+
+ IRModule mod() const final { return mod_; }
+ Array<ObjectRef> LocationsOfInterest() const final { return {}; }
+
+ static int CheckAndUpdate(const IRModule& mod, const Buffer& buffer, int
axis) {
+ int ndim = static_cast<int>(buffer->shape.size());
+ if (axis < -ndim || axis >= ndim) {
+ throw StorageAlignAxisOutOfRangeError(mod, buffer, axis);
+ }
+ // If axis is negative, convert it to a non-negative one.
+ if (axis < 0) {
+ axis += ndim;
+ }
+ return axis;
+ }
+
+ private:
+ IRModule mod_;
+ Buffer buffer_;
+ int axis_;
+};
+
+/*!
+ * \brief Find the defining site of the buffer in the given block and its
ancestors
+ * \param block_sref The block sref
+ * \param buffer The buffer
+ * \return The defining site of the buffer and whether the buffer is allocated
(otherwise the
+ * buffer is from match_buffer).
+ */
+std::pair<Optional<StmtSRef>, bool> GetBufferDefiningSite(const StmtSRef&
block_sref,
+ const Buffer&
buffer) {
+ // Climb up along the sref tree, and find the block where `buffer` is in
alloc_buffers or
+ // match_buffers.
+ const StmtSRefNode* defining_site_sref = block_sref.get();
+ while (defining_site_sref != nullptr) {
+ const auto* block = defining_site_sref->StmtAs<BlockNode>();
+ // If this sref is not a block sref, skip it.
+ if (block == nullptr) {
+ defining_site_sref = defining_site_sref->parent;
+ continue;
+ }
+ // Try to find the buffer in `alloc_buffers`
+ for (const Buffer& alloc_buffer : block->alloc_buffers) {
+ if (buffer.same_as(alloc_buffer)) {
+ return {GetRef<StmtSRef>(defining_site_sref), true};
+ }
+ }
+ // We do not allow the buffer to be defined in `match_buffer`.
+ for (const MatchBufferRegion match_buffer : block->match_buffers) {
+ if (buffer.same_as(match_buffer)) {
+ return {GetRef<StmtSRef>(defining_site_sref), false};
+ }
+ }
+ defining_site_sref = defining_site_sref->parent;
+ }
+ // If we cannot find the defining site block, it means that the buffer must
be in the function's
+ // buffer_map, which isn't an intermediate buffer.
+ return {NullOpt, false};
+}
+
+class NonAllocatedBufferError : public ScheduleError {
+ public:
+ explicit NonAllocatedBufferError(IRModule mod, Buffer buffer) : mod_(mod),
buffer_(buffer) {}
+
+ String FastErrorString() const final {
+ return "ScheduleError: The input buffer is not allocated by a block. This
means the buffer is "
+ " either a function parameter or defined in `match_buffer` of a
block.";
+ }
+
+ String DetailRenderTemplate() const final {
+ std::ostringstream os;
+ os << "The input buffer " << buffer_->name
+ << " is not allocated by a block. This means the buffer is either a
function parameter or "
+ "defined in `match_buffer` of a block.";
+ return os.str();
+ }
+
+ static void CheckBufferAllocated(const IRModule& mod, const StmtSRef&
block_sref,
+ const Buffer& buffer) {
+ Optional<StmtSRef> defining_site_sref;
+ bool is_alloc;
+ std::tie(defining_site_sref, is_alloc) = GetBufferDefiningSite(block_sref,
buffer);
+ if (!defining_site_sref || !is_alloc) {
+ throw NonAllocatedBufferError(mod, buffer);
+ }
+ }
+
+ Array<ObjectRef> LocationsOfInterest() const final { return {}; }
+ IRModule mod() const final { return mod_; }
+
+ private:
+ IRModule mod_;
+ Buffer buffer_;
+};
+
+class StorageAlignInvalidFactorError : public ScheduleError {
+ public:
+ explicit StorageAlignInvalidFactorError(IRModule mod, int factor)
+ : mod_(std::move(mod)), factor_(factor) {}
+
+ String FastErrorString() const final {
+ return "ScheduleError: The input `factor` of storage_align is expected to
be a positive "
+ "number.";
+ }
+
+ String DetailRenderTemplate() const final {
+ std::ostringstream os;
+ os << "The input `factor` of storage_align is expected to be a positive
number. However, the "
+ "input `factor` is "
+ << factor_ << ", which is out of the expected range.";
+ return os.str();
+ }
+
+ static void Check(const IRModule& mod, int factor) {
+ if (factor <= 0) {
+ throw StorageAlignInvalidFactorError(mod, factor);
+ }
+ }
+
+ Array<ObjectRef> LocationsOfInterest() const final { return {}; }
+ IRModule mod() const final { return mod_; }
+
+ private:
+ IRModule mod_;
+ int factor_;
+};
+
+class StorageAlignInvalidAnnotationError : public ScheduleError {
+ public:
+ explicit StorageAlignInvalidAnnotationError(IRModule mod, Block block)
+ : mod_(std::move(mod)), block_(std::move(block)) {}
+
+ String FastErrorString() const final {
+ return "ScheduleError: The block annotation for storage align is expected
to be an array of "
+ "4-integer-tuples (buffer_index, axis, factor, offset).";
+ }
+
+ String DetailRenderTemplate() const final {
+ std::ostringstream os;
+ os << "The block annotation for storage align is expected to be an array
of 4-integer-tuples "
+ "(buffer_index, axis, factor, offset). However, the block annotation
with key "
+ << attr::buffer_dim_align << " of the block {0} is "
+ << block_->annotations.at(attr::buffer_dim_align) << ", which is
unexpected.";
+ return os.str();
+ }
+
+ static StorageAlignAnnotation CheckAndGetAnnotation(const IRModule& mod,
const Block& block) {
+ // Get existing annotation value.
+ auto it = block->annotations.find(attr::buffer_dim_align);
+ if (it != block->annotations.end()) {
+ if (!IsValidAnnotation(block, (*it).second)) {
+ throw StorageAlignInvalidAnnotationError(mod, block);
+ }
+ return Downcast<StorageAlignAnnotation>((*it).second);
+ }
+
+ // Create new annotation value
+ StorageAlignAnnotation storage_align_annotation;
+ return storage_align_annotation;
+ }
+
+ Array<ObjectRef> LocationsOfInterest() const final { return {block_}; }
+ IRModule mod() const final { return mod_; }
+
+ private:
+ static bool IsValidAnnotation(const Block& block, const ObjectRef&
anno_value) {
+ if (!anno_value->IsInstance<ArrayNode>()) {
+ return false;
+ }
+ auto storage_align_annotations = Downcast<Array<ObjectRef>>(anno_value);
+ for (const ObjectRef& storage_align_annotation :
storage_align_annotations) {
+ if (!storage_align_annotation->IsInstance<ArrayNode>()) {
+ return false;
+ }
+ auto storage_align_tuple =
Downcast<Array<ObjectRef>>(storage_align_annotation);
+ // Check if the annotation is a 4-tuple.
+ if (storage_align_tuple.size() != 4) {
+ return false;
+ }
+ for (const ObjectRef& tuple_element : storage_align_tuple) {
+ if (!tuple_element->IsInstance<IntImmNode>()) {
+ return false;
+ }
+ }
+ }
+ return true;
+ }
+
+ IRModule mod_;
+ Block block_;
+};
+
+void StorageAlign(ScheduleState self, const StmtSRef& block_sref, int
buffer_index, int axis,
+ int factor, int offset) {
+ const BlockNode* block_ptr = TVM_SREF_TO_BLOCK(block_ptr, block_sref);
+ Buffer buffer = GetNthWriteBuffer(self, GetRef<Block>(block_ptr),
buffer_index);
+ StorageAlignInvalidFactorError::Check(self->mod, factor);
+ axis = StorageAlignAxisOutOfRangeError::CheckAndUpdate(self->mod, buffer,
axis);
+ NonAllocatedBufferError::CheckBufferAllocated(self->mod, block_sref, buffer);
+
+ // Step 1: Get existing or create new annotation value.
+ StorageAlignAnnotation storage_align_annotation =
+ StorageAlignInvalidAnnotationError::CheckAndGetAnnotation(self->mod,
+
GetRef<Block>(block_ptr));
+
+ // Step 2: Update the annotation value
+ // Array<Array<Integer>> buffer_storage_align =
storage_align_annotation[buffer_index];
+ bool found = false;
+ StorageAlignTuple new_storage_align_tuple{Integer(buffer_index),
Integer(axis), Integer(factor),
+ Integer(offset)};
+ for (size_t j = 0; j < storage_align_annotation.size(); ++j) {
+ const auto& storage_align_tuple = storage_align_annotation[j];
+ ICHECK(storage_align_tuple.size() == 4);
+ if (storage_align_tuple[0] == buffer_index && storage_align_tuple[1] ==
axis) {
+ storage_align_annotation.Set(j, std::move(new_storage_align_tuple));
+ found = true;
+ break;
+ }
+ }
+ if (!found) {
+ storage_align_annotation.push_back(std::move(new_storage_align_tuple));
+ }
+
+ // Step 3: Replace the block with the new annotation
+ Block new_block = WithAnnotation(block_ptr, attr::buffer_dim_align,
storage_align_annotation);
+ self->Replace(block_sref, new_block, {{GetRef<Block>(block_ptr),
new_block}});
+}
+
+/******** Instruction Registration ********/
+
+struct StorageAlignTraits : public UnpackedInstTraits<StorageAlignTraits> {
+ static constexpr const char* kName = "StorageAlign";
+ static constexpr bool kIsPure = false;
+
+ private:
+ static constexpr size_t kNumInputs = 1;
+ static constexpr size_t kNumAttrs = 4;
+ static constexpr size_t kNumDecisions = 0;
+
+ static void UnpackedApplyToSchedule(Schedule sch, BlockRV block_rv, Integer
buffer_index,
+ Integer axis, Integer factor, Integer
offset) {
+ return sch->StorageAlign(block_rv, buffer_index->value, axis->value,
factor->value,
+ offset->value);
+ }
+
+ static String UnpackedAsPython(Array<String> outputs, String block_rv,
Integer buffer_index,
+ Integer axis, Integer factor, Integer offset)
{
+ PythonAPICall py("storage_align");
+ py.Input("block", block_rv);
+ py.Input("buffer_index", buffer_index);
+ py.Input("axis", axis);
+ py.Input("factor", factor);
+ py.Input("offset", offset);
+ return py.Str();
+ }
+
+ template <typename>
+ friend struct ::tvm::tir::UnpackedInstTraits;
+};
+
+TVM_REGISTER_INST_KIND_TRAITS(StorageAlignTraits);
+
+} // namespace tir
+} // namespace tvm
diff --git a/src/tir/schedule/schedule.cc b/src/tir/schedule/schedule.cc
index 3232a33..d6dc0b4 100644
--- a/src/tir/schedule/schedule.cc
+++ b/src/tir/schedule/schedule.cc
@@ -135,6 +135,9 @@
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleReverseComputeInline")
/******** (FFI) Reduction ********/
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleRFactor")
.set_body_method<Schedule>(&ScheduleNode::RFactor);
+/******** (FFI) Block annotation ********/
+TVM_REGISTER_GLOBAL("tir.schedule.ScheduleStorageAlign")
+ .set_body_method<Schedule>(&ScheduleNode::StorageAlign);
/******** (FFI) Blockize & Tensorize ********/
/******** (FFI) Annotation ********/
/******** (FFI) Misc ********/
diff --git a/src/tir/schedule/traced_schedule.cc
b/src/tir/schedule/traced_schedule.cc
index d664d7f..e0ffdc7 100644
--- a/src/tir/schedule/traced_schedule.cc
+++ b/src/tir/schedule/traced_schedule.cc
@@ -137,6 +137,19 @@ BlockRV TracedScheduleNode::RFactor(const LoopRV& loop_rv,
int factor_axis) {
return result;
}
+/******** Schedule: Block annotation ********/
+
+void TracedScheduleNode::StorageAlign(const BlockRV& block_rv, int
buffer_index, int axis,
+ int factor, int offset) {
+ ConcreteScheduleNode::StorageAlign(block_rv, buffer_index, axis, factor,
offset);
+ static const InstructionKind& kind = InstructionKind::Get("StorageAlign");
+ trace_->Append(/*inst=*/Instruction(
+ /*kind=*/kind,
+ /*inputs=*/{block_rv},
+ /*attrs=*/{Integer(buffer_index), Integer(axis), Integer(factor),
Integer(offset)},
+ /*outputs=*/{}));
+}
+
/******** Schedule: Blockize & Tensorize ********/
/******** Schedule: Annotation ********/
diff --git a/src/tir/schedule/traced_schedule.h
b/src/tir/schedule/traced_schedule.h
index b4518cb..4650c44 100644
--- a/src/tir/schedule/traced_schedule.h
+++ b/src/tir/schedule/traced_schedule.h
@@ -61,6 +61,9 @@ class TracedScheduleNode : public ConcreteScheduleNode {
void ReverseComputeInline(const BlockRV& block_rv) final;
/******** Schedule: Reduction ********/
BlockRV RFactor(const LoopRV& loop_rv, int factor_axis) final;
+ /******** Schedule: Block annotation ********/
+ void StorageAlign(const BlockRV& block_rv, int buffer_index, int axis, int
factor,
+ int offset) final;
/******** Schedule: Blockize & Tensorize ********/
/******** Schedule: Annotation ********/
/******** Schedule: Misc ********/
diff --git a/src/tir/schedule/transform.cc b/src/tir/schedule/transform.cc
new file mode 100644
index 0000000..f27e0f6
--- /dev/null
+++ b/src/tir/schedule/transform.cc
@@ -0,0 +1,35 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include "./transform.h"
+
+namespace tvm {
+namespace tir {
+
+/******** Annotation ********/
+Block WithAnnotation(const BlockNode* block, const String& attr_key, const
ObjectRef& attr_value) {
+ Map<String, ObjectRef> annotations = block->annotations;
+ annotations.Set(attr_key, attr_value);
+ ObjectPtr<BlockNode> new_block = make_object<BlockNode>(*block);
+ new_block->annotations = std::move(annotations);
+ return Block(new_block);
+}
+
+} // namespace tir
+} // namespace tvm
diff --git a/src/tir/schedule/transform.h b/src/tir/schedule/transform.h
new file mode 100644
index 0000000..5348382
--- /dev/null
+++ b/src/tir/schedule/transform.h
@@ -0,0 +1,41 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#ifndef TVM_TIR_SCHEDULE_TRANSFORM_H_
+#define TVM_TIR_SCHEDULE_TRANSFORM_H_
+
+#include <tvm/tir/schedule/state.h>
+
+namespace tvm {
+namespace tir {
+
+/******** Annotation ********/
+
+/*!
+ * \brief Create a new block with the given annotation added
+ * \param block The block with original annotation
+ * \param attr_key The annotation key to be added
+ * \param attr_value The annotation value to be added
+ * \return A new block with the given annotation as its last annotation
+ */
+Block WithAnnotation(const BlockNode* block, const String& attr_key, const
ObjectRef& attr_value);
+
+} // namespace tir
+} // namespace tvm
+
+#endif // TVM_TIR_SCHEDULE_TRANSFORM_H_
diff --git a/src/tir/transforms/compact_buffer_region.cc
b/src/tir/transforms/compact_buffer_region.cc
index b1a4fd4..961ea17 100644
--- a/src/tir/transforms/compact_buffer_region.cc
+++ b/src/tir/transforms/compact_buffer_region.cc
@@ -303,18 +303,61 @@ class BufferAccessRegionCollector : public
StmtExprVisitor {
support::Arena arena_;
};
+/*! \brief Collect storage alignment information from block annotations. */
+class StorageAlignCollector : public StmtVisitor {
+ public:
+ static std::unordered_map<Buffer, StorageAlignAnnotation, ObjectPtrHash,
ObjectPtrEqual> Collect(
+ const PrimFunc& f) {
+ StorageAlignCollector collector;
+ collector(f->body);
+ return std::move(collector.storage_align_);
+ }
+
+ private:
+ void VisitStmt_(const BlockNode* op) final {
+ auto it = op->annotations.find(attr::buffer_dim_align);
+ if (it != op->annotations.end()) {
+ auto storage_align_annotation =
Downcast<StorageAlignAnnotation>((*it).second);
+ for (const auto& storage_align_tuple : storage_align_annotation) {
+ int buffer_index = storage_align_tuple[0]->value;
+ const Buffer& buffer = op->writes[buffer_index]->buffer;
+ storage_align_[buffer].push_back(storage_align_tuple);
+ }
+ }
+ StmtVisitor::VisitStmt_(op);
+ }
+
+ /*! \brief The map from Buffer to its storage alignment information. */
+ std::unordered_map<Buffer, StorageAlignAnnotation, ObjectPtrHash,
ObjectPtrEqual> storage_align_;
+};
+
/*! \brief Reallocate the buffers with minimal region. */
class BufferCompactor : public StmtExprMutator {
public:
static Stmt Compact(
const PrimFunc& f,
- const std::unordered_map<Buffer, Region, ObjectPtrHash, ObjectPtrEqual>&
regions) {
+ const std::unordered_map<Buffer, Region, ObjectPtrHash, ObjectPtrEqual>&
regions,
+ const std::unordered_map<Buffer, StorageAlignAnnotation, ObjectPtrHash,
ObjectPtrEqual>&
+ storage_align) {
std::unordered_map<Buffer, BufferAllocInfo, ObjectPtrHash, ObjectPtrEqual>
buffer_info;
for (const auto& kv : regions) {
const Buffer& buffer = kv.first;
Region region = kv.second;
- buffer_info.emplace(buffer, BufferAllocInfo(std::move(region)));
+ BufferAllocInfo buffer_alloc_info(std::move(region));
+ auto it = storage_align.find(buffer);
+ if (it != storage_align.end()) {
+ std::vector<DimAlignInfo> dim_aligns(buffer->shape.size());
+ for (const StorageAlignTuple& dim_align : (*it).second) {
+ ICHECK(dim_align.size() == 4);
+ int dim = dim_align[1]->value;
+ int factor = dim_align[2]->value;
+ int offset = dim_align[3]->value;
+ dim_aligns.at(dim) = {factor, offset};
+ }
+ buffer_alloc_info.dim_aligns = std::move(dim_aligns);
+ }
+ buffer_info.emplace(buffer, std::move(buffer_alloc_info));
}
BufferCompactor compactor(std::move(buffer_info));
Stmt stmt = compactor(f->body);
@@ -322,9 +365,19 @@ class BufferCompactor : public StmtExprMutator {
}
private:
+ /*! \brief The storage alignment for a dimension */
+ struct DimAlignInfo {
+ /*! \brief The factor of the alignment */
+ int align_factor{0};
+ /*! \brief The offset of the alignment */
+ int align_offset{0};
+ };
+
struct BufferAllocInfo {
/*! \brief The buffer access region. */
Region region;
+ /*! \brief The storage alignment information. */
+ std::vector<DimAlignInfo> dim_aligns;
/*!
* \brief The reallocated buffer with minimal size.
* \note The value is NullOpt if the buffer does not need to be reallocated (e.g.
a parameter buffer).
@@ -380,8 +433,25 @@ class BufferCompactor : public StmtExprMutator {
for (const Range& range : info.region) {
shape.push_back(range->extent);
}
+ Array<PrimExpr> strides;
+ if (info.dim_aligns.size()) {
+ ICHECK(info.dim_aligns.size() == shape.size());
+ strides.resize(shape.size());
+ PrimExpr stride = make_const(shape[0].dtype(), 1);
+ for (size_t i = shape.size(); i != 0; --i) {
+ size_t dim = i - 1;
+ if (info.dim_aligns[dim].align_factor != 0) {
+ PrimExpr factor = make_const(stride.dtype(),
info.dim_aligns[dim].align_factor);
+ PrimExpr offset = make_const(stride.dtype(),
info.dim_aligns[dim].align_offset);
+ stride = stride + indexmod(factor + offset - indexmod(stride,
factor), factor);
+ }
+ strides.Set(dim, stride);
+ stride = stride * shape[dim];
+ }
+ }
ObjectPtr<BufferNode> n = make_object<BufferNode>(*buffer.get());
n->shape = std::move(shape);
+ n->strides = std::move(strides);
info.new_buffer = Buffer(std::move(n));
result.push_back(info.new_buffer);
}
@@ -458,7 +528,9 @@ PrimFunc CompactBufferAllocation(PrimFunc f) {
PrimFuncNode* fptr = f.CopyOnWrite();
std::unordered_map<Buffer, Region, ObjectPtrHash, ObjectPtrEqual> region =
BufferAccessRegionCollector::Collect(f);
- fptr->body = BufferCompactor::Compact(f, region);
+ std::unordered_map<Buffer, StorageAlignAnnotation, ObjectPtrHash,
ObjectPtrEqual>
+ storage_align = StorageAlignCollector::Collect(f);
+ fptr->body = BufferCompactor::Compact(f, region, storage_align);
return f;
} else {
return f;
diff --git a/tests/python/unittest/test_tir_schedule_storage_align.py
b/tests/python/unittest/test_tir_schedule_storage_align.py
new file mode 100644
index 0000000..a0a0693
--- /dev/null
+++ b/tests/python/unittest/test_tir_schedule_storage_align.py
@@ -0,0 +1,182 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-function-docstring,missing-module-docstring
+import pytest
+import tvm
+from tvm import tir
+from tvm.script import ty
+from tvm.tir.schedule.testing import verify_trace_roundtrip
+
+# fmt: off
+# pylint:
disable=no-member,invalid-name,unused-variable,line-too-long,redefined-outer-name
+
+@tvm.script.tir
+def element_wise(a: ty.handle, c: ty.handle) -> None:
+ C = tir.match_buffer(c, [128, 128], elem_offset=0, align=128,
offset_factor=1)
+ A = tir.match_buffer(a, [128, 128], elem_offset=0, align=128,
offset_factor=1)
+ # body
+ with tir.block([], "root"):
+ tir.reads([])
+ tir.writes([])
+ B = tir.alloc_buffer([128, 128], elem_offset=0, align=128,
offset_factor=1)
+ for i0 in tir.serial(0, 128):
+ for ax1 in tir.serial(0, 128):
+ with tir.block([128, 128], "B") as [vi, vj]:
+ tir.bind(vi, i0)
+ tir.bind(vj, ax1)
+ tir.reads([A[vi, vj]])
+ tir.writes([B[vi, vj]])
+ B[vi, vj] = (A[vi, vj]*tir.float32(2))
+ for i1 in tir.serial(0, 128):
+ with tir.block([128, 128], "C") as [vi_1, vj_1]:
+ tir.bind(vi_1, i0)
+ tir.bind(vj_1, i1)
+ tir.reads([B[vi_1, vj_1]])
+ tir.writes([C[vi_1, vj_1]])
+ C[vi_1, vj_1] = (B[vi_1, vj_1] + tir.float32(1))
+
+
+@tvm.script.tir
+def element_wise_storage_align(a: ty.handle, c: ty.handle) -> None:
+ C = tir.match_buffer(c, [128, 128], elem_offset=0, align=128,
offset_factor=1)
+ A = tir.match_buffer(a, [128, 128], elem_offset=0, align=128,
offset_factor=1)
+ # body
+ with tir.block([], "root"):
+ tir.reads([])
+ tir.writes([])
+ B = tir.alloc_buffer([128, 128], elem_offset=0, align=128,
offset_factor=1)
+ for i0 in tir.serial(0, 128):
+ for ax1 in tir.serial(0, 128):
+ with tir.block([128, 128], "B") as [vi, vj]:
+ tir.bind(vi, i0)
+ tir.bind(vj, ax1)
+ tir.reads([A[vi, vj]])
+ tir.writes([B[vi, vj]])
+ tir.block_attr({"buffer_dim_align":[[0, 0, 128, 127]]})
+ B[vi, vj] = (A[vi, vj]*tir.float32(2))
+ for i1 in tir.serial(0, 128):
+ with tir.block([128, 128], "C") as [vi_1, vj_1]:
+ tir.bind(vi_1, i0)
+ tir.bind(vj_1, i1)
+ tir.reads([B[vi_1, vj_1]])
+ tir.writes([C[vi_1, vj_1]])
+ C[vi_1, vj_1] = (B[vi_1, vj_1] + tir.float32(1))
+
+
+@tvm.script.tir
+def element_wise_invalid_annotation(a: ty.handle, c: ty.handle) -> None:
+ C = tir.match_buffer(c, [128, 128], elem_offset=0, align=128,
offset_factor=1)
+ A = tir.match_buffer(a, [128, 128], elem_offset=0, align=128,
offset_factor=1)
+ # body
+ with tir.block([], "root"):
+ tir.reads([])
+ tir.writes([])
+ B = tir.alloc_buffer([128, 128], elem_offset=0, align=128,
offset_factor=1)
+ for i0 in tir.serial(0, 128):
+ for ax1 in tir.serial(0, 128):
+ with tir.block([128, 128], "B") as [vi, vj]:
+ tir.block_attr({"buffer_dim_align": [0]})
+ tir.bind(vi, i0)
+ tir.bind(vj, ax1)
+ tir.reads([A[vi, vj]])
+ tir.writes([B[vi, vj]])
+ B[vi, vj] = (A[vi, vj]*tir.float32(2))
+ for i1 in tir.serial(0, 128):
+ with tir.block([128, 128], "C") as [vi_1, vj_1]:
+ tir.bind(vi_1, i0)
+ tir.bind(vj_1, i1)
+ tir.reads([B[vi_1, vj_1]])
+ tir.writes([C[vi_1, vj_1]])
+ C[vi_1, vj_1] = (B[vi_1, vj_1] + tir.float32(1))
+
+
+def test_storage_align():
+ func = element_wise
+ s = tir.Schedule(func, debug_mask='all')
+ B = s.get_block("B")
+ s.storage_align(B, 0, axis=0, factor=128, offset=127)
+ tvm.ir.assert_structural_equal(element_wise_storage_align, s.mod["main"])
+ verify_trace_roundtrip(sch=s, mod=func)
+
+
+def test_storage_align_update():
+ func = element_wise
+ s = tir.Schedule(func, debug_mask='all')
+ B = s.get_block("B")
+ s.storage_align(B, 0, axis=0, factor=128, offset=0)
+ s.storage_align(B, 0, axis=0, factor=128, offset=127)
+ tvm.ir.assert_structural_equal(element_wise_storage_align, s.mod["main"])
+ verify_trace_roundtrip(sch=s, mod=func)
+
+
+def test_storage_align_invalid_factor1():
+ func = element_wise
+ s = tir.Schedule(func, debug_mask='all')
+ B = s.get_block("B")
+ with pytest.raises(tir.ScheduleError):
+ s.storage_align(B, 0, axis=0, factor=0, offset=127)
+
+
+def test_storage_align_invalid_factor2():
+ func = element_wise
+ s = tir.Schedule(func, debug_mask='all')
+ B = s.get_block("B")
+ with pytest.raises(tir.ScheduleError):
+ s.storage_align(B, 0, axis=0, factor=-1, offset=127)
+
+
+def test_storage_align_invalid_buffer():
+ func = element_wise
+ s = tir.Schedule(func, debug_mask='all')
+ C = s.get_block("C")
+ with pytest.raises(tir.ScheduleError):
+ s.storage_align(C, 0, axis=0, factor=128, offset=127)
+
+
+def test_storage_align_invalid_buffer_index():
+ func = element_wise
+ s = tir.Schedule(func, debug_mask='all')
+ B = s.get_block("B")
+ with pytest.raises(tir.ScheduleError):
+ s.storage_align(B, 2, axis=0, factor=128, offset=127)
+
+
+def test_storage_align_invalid_axis():
+ func = element_wise
+ s = tir.Schedule(func, debug_mask='all')
+ B = s.get_block("B")
+ with pytest.raises(tir.ScheduleError):
+ s.storage_align(B, 0, axis=2, factor=128, offset=127)
+
+
+def test_storage_align_invalid_annotation():
+ func = element_wise_invalid_annotation
+ s = tir.Schedule(func, debug_mask='all')
+ B = s.get_block("B")
+ with pytest.raises(tir.ScheduleError):
+ s.storage_align(B, 0, axis=2, factor=128, offset=127)
+
+
+if __name__ == "__main__":
+ test_storage_align()
+ test_storage_align_update()
+ test_storage_align_invalid_factor1()
+ test_storage_align_invalid_factor2()
+ test_storage_align_invalid_buffer()
+ test_storage_align_invalid_buffer_index()
+ test_storage_align_invalid_axis()
+ test_storage_align_invalid_annotation()
diff --git a/tests/python/unittest/test_tir_transform_compact_buffer_region.py
b/tests/python/unittest/test_tir_transform_compact_buffer_region.py
index fb53b42..15da022 100644
--- a/tests/python/unittest/test_tir_transform_compact_buffer_region.py
+++ b/tests/python/unittest/test_tir_transform_compact_buffer_region.py
@@ -339,6 +339,50 @@ def compacted_match_buffer_func(a: ty.handle, c:
ty.handle) -> None:
C1[()] = B2[()] * 2.0
+@tvm.script.tir
+def storage_align_func(a: ty.handle, c: ty.handle) -> None:
+ A = tir.match_buffer(a, (16, 16), "float32")
+ C = tir.match_buffer(c, (16, 16), "float32")
+ for i in range(0, 16):
+ with tir.block([]):
+ tir.reads(A[i, 0:16])
+ tir.writes(C[i, 0:16])
+ B = tir.alloc_buffer((16, 16), "float32")
+ for j in range(0, 16):
+ with tir.block([]) as []:
+ tir.reads(A[i, j])
+ tir.writes(B[i, j])
+ tir.block_attr({"buffer_dim_align": [[0, 0, 16, 15]]})
+ B[i, j] = A[i, j] + 1.0
+ for j in range(0, 16):
+ with tir.block([]) as []:
+ tir.reads(B[i, j])
+ tir.writes(C[i, j])
+ C[i, j] = B[i, j] * 2.0
+
+
+@tvm.script.tir
+def compacted_storage_align_func(a: ty.handle, c: ty.handle) -> None:
+ A = tir.match_buffer(a, (16, 16), "float32")
+ C = tir.match_buffer(c, (16, 16), "float32")
+ for i in range(0, 16):
+ with tir.block([]):
+ tir.reads(A[i, 0:16])
+ tir.writes(C[i, 0:16])
+ B = tir.alloc_buffer((1, 16), strides=(31, 1), dtypes="float32")
+ for j in range(0, 16):
+ with tir.block() as []:
+ tir.reads(A[i, j])
+ tir.writes(B[0, j])
+ tir.block_attr({"buffer_dim_align": [[0, 0, 16, 15]]})
+ B[0, j] = A[i, j] + 1.0
+ for j in range(0, 16):
+ with tir.block() as []:
+ tir.reads(B[0, j])
+ tir.writes(C[i, j])
+ C[i, j] = B[0, j] * 2.0
+
+
def test_elementwise():
_check(elementwise_func, compacted_elementwise_func)
@@ -380,6 +424,10 @@ def test_lower_te():
tvm.ir.assert_structural_equal(mod, orig_mod) # CompactBufferAllocation
should do nothing on TE
+def test_storage_align():
+ _check(storage_align_func, compacted_storage_align_func)
+
+
if __name__ == "__main__":
test_elementwise()
test_unschedulable_block()
@@ -389,3 +437,4 @@ if __name__ == "__main__":
test_symbolic()
test_complex()
test_match_buffer()
+ test_storage_align()