This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new a06863a  [TensorIR][M2a] Storage Align (#8693)
a06863a is described below

commit a06863ac9406c027b29f346fb6177268f612912d
Author: Wuwei Lin <[email protected]>
AuthorDate: Fri Aug 13 02:54:34 2021 -0400

    [TensorIR][M2a] Storage Align (#8693)
    
    This PR is part of the TensorIR upstreaming effort (#7527), which adds one
    schedule primitive: storage_align.
    
    Co-authored-by: Siyuan Feng <[email protected]>
    Co-authored-by: Bohan Hou <[email protected]>
    Co-authored-by: Ruihang Lai <[email protected]>
    Co-authored-by: Hongyi Jin <[email protected]>
    Co-authored-by: Junru Shao <[email protected]>
---
 include/tvm/tir/schedule/schedule.h                |  15 +
 python/tvm/tir/schedule/schedule.py                |  73 +++++
 src/tir/schedule/analysis.h                        |  13 +
 src/tir/schedule/analysis/analysis.cc              |  39 +++
 src/tir/schedule/concrete_schedule.cc              |  10 +
 src/tir/schedule/concrete_schedule.h               |   3 +
 src/tir/schedule/primitive.h                       |  20 ++
 src/tir/schedule/primitive/block_annotate.cc       | 308 +++++++++++++++++++++
 src/tir/schedule/schedule.cc                       |   3 +
 src/tir/schedule/traced_schedule.cc                |  13 +
 src/tir/schedule/traced_schedule.h                 |   3 +
 src/tir/schedule/transform.cc                      |  35 +++
 src/tir/schedule/transform.h                       |  41 +++
 src/tir/transforms/compact_buffer_region.cc        |  78 +++++-
 .../unittest/test_tir_schedule_storage_align.py    | 182 ++++++++++++
 .../test_tir_transform_compact_buffer_region.py    |  49 ++++
 16 files changed, 882 insertions(+), 3 deletions(-)

diff --git a/include/tvm/tir/schedule/schedule.h 
b/include/tvm/tir/schedule/schedule.h
index e208377..e5d2c44 100644
--- a/include/tvm/tir/schedule/schedule.h
+++ b/include/tvm/tir/schedule/schedule.h
@@ -264,6 +264,21 @@ class ScheduleNode : public runtime::Object {
    * \return The rfactor block
    */
   virtual BlockRV RFactor(const LoopRV& loop_rv, int factor_axis) = 0;
+  /******** Schedule: Block annotation ********/
+  /*!
+   * \brief Set alignment requirement for specific dimension such that
+   *        stride[axis] == k * factor + offset for some k. This is useful to 
set memory layout for
+   *        more friendly memory access pattern. For example, we can set 
alignment to be factor=2,
+   *        offset=1 to avoid bank conflict for thread access on higher 
dimension in GPU shared
+   *        memory.
+   * \param block_rv The producer block of the buffer
+   * \param buffer_index The index of the buffer in block's write region
+   * \param axis The dimension to be specified for alignment
+   * \param factor The factor multiple of alignment
+   * \param offset The required offset factor
+   */
+  virtual void StorageAlign(const BlockRV& block_rv, int buffer_index, int 
axis, int factor,
+                            int offset) = 0;
   /******** Schedule: Blockize & Tensorize ********/
   /******** Schedule: Annotation ********/
   /******** Schedule: Misc ********/
diff --git a/python/tvm/tir/schedule/schedule.py 
b/python/tvm/tir/schedule/schedule.py
index 4bbb5b9..e8415d2 100644
--- a/python/tvm/tir/schedule/schedule.py
+++ b/python/tvm/tir/schedule/schedule.py
@@ -710,6 +710,79 @@ class Schedule(Object):
         """
         return _ffi_api.ScheduleRFactor(self, loop, factor_axis)  # type: 
ignore # pylint: disable=no-member
 
+    ######## Schedule: Block annotation ########
+
+    def storage_align(  # pylint: disable=too-many-arguments
+        self, block: BlockRV, buffer_index: int, axis: int, factor: int, offset: int
+    ) -> None:
+        """Set alignment requirement for specific dimension such that
+        stride[axis] == k * factor + offset for some k. This is useful to set memory layout for more
+        friendly memory access pattern. For example, we can set alignment to be factor=2, offset=1
+        to avoid bank conflict for thread access on higher dimension in GPU shared memory.
+
+        The requirement is recorded on the producer block as a ``buffer_dim_align`` annotation
+        entry ``[buffer_index, axis, factor, offset]``. Calling this again for the same
+        ``buffer_index`` and ``axis`` overwrites the previous entry.
+
+        Parameters
+        ----------
+        block : BlockRV
+            The producer block of the buffer.
+        buffer_index : int
+            The index of the buffer in block's write region. It must be in
+            ``[0, num_write_regions)``.
+        axis : int
+            The dimension to be specified for alignment. It must be in ``[-ndim, ndim)``;
+            negative values count from the last dimension.
+        factor : int
+            The factor multiple of alignment. It must be positive.
+        offset : int
+            The required offset factor.
+
+        Examples
+        --------
+
+        Before storage_align, in TensorIR, the IR is:
+
+        .. code-block:: python
+
+            @tvm.script.tir
+            def before_storage_align(a: ty.handle, c: ty.handle) -> None:
+                A = tir.match_buffer(a, (128, 128))
+                B = tir.alloc_buffer((128, 128))
+                C = tir.match_buffer(c, (128, 128))
+                with tir.block([128, 128], "B") as [vi, vj]:
+                    B[vi, vj] = A[vi, vj] * 2.0
+                with tir.block([128, 128], "C") as [vi, vj]:
+                    C[vi, vj] = B[vi, vj] + 1.0
+
+        Create the schedule and do storage_align:
+
+        .. code-block:: python
+
+            sch = tir.Schedule(before_storage_align)
+            sch.storage_align(sch.get_block("B"), buffer_index=0, axis=0, factor=128, offset=1)
+            print(tvm.script.asscript(sch.mod["main"]))
+
+        After applying storage_align, the IR becomes:
+
+        .. code-block:: python
+
+            @tvm.script.tir
+            def after_storage_align(a: ty.handle, c: ty.handle) -> None:
+                A = tir.match_buffer(a, (128, 128))
+                B = tir.alloc_buffer((128, 128))
+                C = tir.match_buffer(c, (128, 128))
+                with tir.block([128, 128], "B") as [vi, vj]:
+                    tir.block_attr({"buffer_dim_align": [[0, 0, 128, 1]]})
+                    B[vi, vj] = A[vi, vj] * 2.0
+                with tir.block([128, 128], "C") as [vi, vj]:
+                    C[vi, vj] = B[vi, vj] + 1.0
+
+        After lowering passes, buffer B will have strides as [129, 1].
+
+        Note
+        ----
+        Storage_align requires the buffer to be an intermediate buffer defined via `alloc_buffer`.
+        An error is raised if this or any of the parameter constraints above is violated.
+        """
+        _ffi_api.ScheduleStorageAlign(  # type: ignore # pylint: disable=no-member
+            self, block, buffer_index, axis, factor, offset
+        )
+
     ########## Schedule: Blockize & Tensorize ##########
 
     ########## Schedule: Annotation ##########
diff --git a/src/tir/schedule/analysis.h b/src/tir/schedule/analysis.h
index 9baf4b5..370aa01 100644
--- a/src/tir/schedule/analysis.h
+++ b/src/tir/schedule/analysis.h
@@ -202,6 +202,19 @@ BlockRealize 
CheckGetSingleChildBlockRealizeOnSRefTree(const ScheduleState& self
  */
 BlockRealize GetBlockRealize(const ScheduleState& self, const StmtSRef& 
block_sref);
 
+/******** Block-buffer relation ********/
+
+/*!
+ * \brief Get the buffer of the n-th write region of the given block, or throw an exception if
+ * `n` is out of range
+ * \param self The schedule state
+ * \param block The queried block
+ * \param n The index of the queried buffer
+ * \return The buffer of the n-th write region of the block.
+ * \throw ScheduleError If the buffer index is out of bound.
+ */
+Buffer GetNthWriteBuffer(const ScheduleState& self, const Block& block, int n);
+
 /******** Commutative Reducer ********/
 
 /*!
diff --git a/src/tir/schedule/analysis/analysis.cc 
b/src/tir/schedule/analysis/analysis.cc
index 3ee98ec..8d1913f 100644
--- a/src/tir/schedule/analysis/analysis.cc
+++ b/src/tir/schedule/analysis/analysis.cc
@@ -527,6 +527,45 @@ BlockRealize GetBlockRealize(const ScheduleState& self, 
const StmtSRef& block_sr
   }
 }
 
+/******** Block-buffer relation ********/
+
+Buffer GetNthWriteBuffer(const ScheduleState& self, const Block& block, int n) {
+  /*! \brief Error raised when `n` does not index one of the block's write regions. */
+  class WriteBufferIndexOutOfRangeError : public ScheduleError {
+   public:
+    explicit WriteBufferIndexOutOfRangeError(IRModule mod, Block block, int buffer_index)
+        : ir_module_(std::move(mod)), blk_(std::move(block)), bad_index_(buffer_index) {}
+
+    IRModule mod() const final { return ir_module_; }
+    Array<ObjectRef> LocationsOfInterest() const final { return {blk_}; }
+
+    String FastErrorString() const final {
+      return "ScheduleError: The input `buffer_index` is out of range. It is required to be in "
+             "range [0, num_write_regions) where `num_write_regions` is the number of buffer "
+             "regions written by the block.";
+    }
+
+    String DetailRenderTemplate() const final {
+      const size_t num_writes = blk_->writes.size();
+      std::ostringstream os;
+      os << "The block {0} has " << num_writes
+         << " write regions, so `buffer_index` is required to be in [0, " << num_writes
+         << "). However, the input `buffer_index` is " << bad_index_
+         << ", which is out of the expected range";
+      return os.str();
+    }
+
+   private:
+    IRModule ir_module_;
+    Block blk_;
+    int bad_index_;
+  };
+
+  // Valid indices are exactly [0, writes.size()); the cast is only evaluated once n >= 0.
+  const bool in_range = n >= 0 && static_cast<size_t>(n) < block->writes.size();
+  if (!in_range) {
+    throw WriteBufferIndexOutOfRangeError(self->mod, block, n);
+  }
+  return block->writes[n]->buffer;
+}
+
 /******** Pattern Matcher ********/
 
 /*!
diff --git a/src/tir/schedule/concrete_schedule.cc 
b/src/tir/schedule/concrete_schedule.cc
index 610628c..688ea80 100644
--- a/src/tir/schedule/concrete_schedule.cc
+++ b/src/tir/schedule/concrete_schedule.cc
@@ -362,6 +362,16 @@ void ConcreteScheduleNode::ReverseComputeInline(const 
BlockRV& block_rv) {
 }
 
 /******** Schedule: loop binding/annotation ********/
+/******** Schedule: block annotation ********/
+
+void ConcreteScheduleNode::StorageAlign(const BlockRV& block_rv, int buffer_index, int axis,
+                                        int factor, int offset) {
+  // Resolve the block RV to its sref and delegate to the primitive; failures are handled by
+  // TVM_TIR_SCHEDULE_END with the name "storage-align" and `error_render_level_`.
+  TVM_TIR_SCHEDULE_BEGIN();
+  const StmtSRef block_sref = GetSRef(block_rv);
+  tir::StorageAlign(state_, block_sref, buffer_index, axis, factor, offset);
+  TVM_TIR_SCHEDULE_END("storage-align", error_render_level_);
+  state_->DebugVerify();
+}
+
 /******** Schedule: cache read/write ********/
 /******** Schedule: reduction ********/
 
diff --git a/src/tir/schedule/concrete_schedule.h 
b/src/tir/schedule/concrete_schedule.h
index ec0dd07..cfdd9c8 100644
--- a/src/tir/schedule/concrete_schedule.h
+++ b/src/tir/schedule/concrete_schedule.h
@@ -88,6 +88,9 @@ class ConcreteScheduleNode : public ScheduleNode {
   void ReverseComputeInline(const BlockRV& block) override;
   /******** Schedule: Reduction ********/
   BlockRV RFactor(const LoopRV& loop_rv, int factor_axis) override;
+  /******** Schedule: Block annotation ********/
+  void StorageAlign(const BlockRV& block_rv, int buffer_index, int axis, int 
factor,
+                    int offset) override;
   /******** Schedule: Blockize & Tensorize ********/
   /******** Schedule: Annotation ********/
   /******** Schedule: Misc ********/
diff --git a/src/tir/schedule/primitive.h b/src/tir/schedule/primitive.h
index 22e25f1..4b9c769 100644
--- a/src/tir/schedule/primitive.h
+++ b/src/tir/schedule/primitive.h
@@ -104,6 +104,26 @@ TVM_DLL void ReverseComputeInline(ScheduleState self, 
const StmtSRef& block_sref
  * \return The sref of the rfactor block
  */
 TVM_DLL StmtSRef RFactor(ScheduleState self, const StmtSRef& loop_sref, int 
factor_axis);
+/******** Schedule: Block annotation ********/
+/*!
+ * \brief Set alignment requirement for specific dimension such that
+ *        stride[axis] == k * factor + offset for some k. This is useful to 
set memory layout for
+ *        more friendly memory access pattern. For example, we can set 
alignment to be factor=2,
+ *        offset=1 to avoid bank conflict for thread access on higher 
dimension in GPU shared
+ *        memory.
+ * \param block_sref The producer block of the buffer
+ * \param buffer_index The index of the buffer in block's write region
+ * \param axis The dimension to be specified for alignment
+ * \param factor The factor multiple of alignment
+ * \param offset The required offset factor
+ */
+TVM_DLL void StorageAlign(ScheduleState self, const StmtSRef& block_sref, int 
buffer_index,
+                          int axis, int factor, int offset);
+
+/******** Annotation types for StorageAlign ********/
+using StorageAlignTuple = Array<Integer>;                 // (buffer_idx, 
axis, factor, offset)
+using StorageAlignAnnotation = Array<StorageAlignTuple>;  // unordered array 
of StorageAlignTuple
+
 /******** Schedule: Blockize & Tensorize ********/
 /******** Schedule: Annotation ********/
 /******** Schedule: Misc ********/
diff --git a/src/tir/schedule/primitive/block_annotate.cc 
b/src/tir/schedule/primitive/block_annotate.cc
new file mode 100644
index 0000000..937bc7c
--- /dev/null
+++ b/src/tir/schedule/primitive/block_annotate.cc
@@ -0,0 +1,308 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include "../transform.h"
+#include "../utils.h"
+
+namespace tvm {
+namespace tir {
+
+/*! \brief Raised when `axis` does not name a dimension of the buffer being aligned. */
+class StorageAlignAxisOutOfRangeError : public ScheduleError {
+ public:
+  explicit StorageAlignAxisOutOfRangeError(IRModule mod, Buffer buffer, int axis)
+      : mod_(std::move(mod)), buffer_(std::move(buffer)), axis_(axis) {}
+
+  /*!
+   * \brief Validate `axis` against the rank of `buffer` and normalize it.
+   * \return `axis` itself if non-negative, otherwise `axis + ndim`.
+   * \throw StorageAlignAxisOutOfRangeError if `axis` is outside [-ndim, ndim).
+   */
+  static int CheckAndUpdate(const IRModule& mod, const Buffer& buffer, int axis) {
+    const int ndim = static_cast<int>(buffer->shape.size());
+    const bool in_range = -ndim <= axis && axis < ndim;
+    if (!in_range) {
+      throw StorageAlignAxisOutOfRangeError(mod, buffer, axis);
+    }
+    return axis < 0 ? axis + ndim : axis;
+  }
+
+  IRModule mod() const final { return mod_; }
+  Array<ObjectRef> LocationsOfInterest() const final { return {}; }
+
+  String FastErrorString() const final {
+    return "ScheduleError: The input `axis` is out of range. It is required to be in range "
+           "[-ndim, ndim) where `ndim` is the number of dimensions of the buffer to set "
+           "storage alignment.";
+  }
+
+  String DetailRenderTemplate() const final {
+    const int ndim = static_cast<int>(buffer_->shape.size());
+    std::ostringstream os;
+    os << "The buffer to set storage alignment of, " << buffer_->name << ", has " << ndim
+       << " dimension(s), so `axis` is required to be in [" << -ndim << ", " << ndim
+       << ") for storage_align. However, the input `axis` is " << axis_
+       << ", which is out of the expected range.";
+    return os.str();
+  }
+
+ private:
+  IRModule mod_;
+  Buffer buffer_;
+  int axis_;
+};
+
+/*!
+ * \brief Find the defining site of the buffer in the given block and its ancestors
+ * \param block_sref The block sref
+ * \param buffer The buffer
+ * \return A pair (defining site, is_alloc). The defining site is the sref of the closest block,
+ *         starting at `block_sref` and walking up the sref tree, that either lists `buffer` in
+ *         its `alloc_buffers` (is_alloc = true) or binds it in its `match_buffers`
+ *         (is_alloc = false). If no such block exists, (NullOpt, false) is returned, which
+ *         means the buffer comes from the function's buffer_map.
+ */
+std::pair<Optional<StmtSRef>, bool> GetBufferDefiningSite(const StmtSRef& block_sref,
+                                                          const Buffer& buffer) {
+  // Climb up along the sref tree, and find the block where `buffer` is in alloc_buffers or
+  // match_buffers.
+  const StmtSRefNode* defining_site_sref = block_sref.get();
+  while (defining_site_sref != nullptr) {
+    const auto* block = defining_site_sref->StmtAs<BlockNode>();
+    // If this sref is not a block sref, skip it.
+    if (block == nullptr) {
+      defining_site_sref = defining_site_sref->parent;
+      continue;
+    }
+    // Try to find the buffer in `alloc_buffers`
+    for (const Buffer& alloc_buffer : block->alloc_buffers) {
+      if (buffer.same_as(alloc_buffer)) {
+        return {GetRef<StmtSRef>(defining_site_sref), true};
+      }
+    }
+    // Buffers bound by `match_buffers` are reported with is_alloc = false so callers can reject
+    // them. Compare against the buffer bound by each MatchBufferRegion: the region object itself
+    // is never the same object as a Buffer.
+    for (const MatchBufferRegion& match_buffer : block->match_buffers) {
+      if (buffer.same_as(match_buffer->buffer)) {
+        return {GetRef<StmtSRef>(defining_site_sref), false};
+      }
+    }
+    defining_site_sref = defining_site_sref->parent;
+  }
+  // If we cannot find the defining site block, it means that the buffer must be in the function's
+  // buffer_map, which isn't an intermediate buffer.
+  return {NullOpt, false};
+}
+
+/*! \brief Raised when the buffer to align is not allocated by a block via `alloc_buffer`. */
+class NonAllocatedBufferError : public ScheduleError {
+ public:
+  explicit NonAllocatedBufferError(IRModule mod, Buffer buffer)
+      : mod_(std::move(mod)), buffer_(std::move(buffer)) {}
+
+  String FastErrorString() const final {
+    return "ScheduleError: The input buffer is not allocated by a block. This means the buffer is "
+           "either a function parameter or defined in `match_buffer` of a block.";
+  }
+
+  String DetailRenderTemplate() const final {
+    std::ostringstream os;
+    os << "The input buffer " << buffer_->name
+       << " is not allocated by a block. This means the buffer is either a function parameter or "
+          "defined in `match_buffer` of a block.";
+    return os.str();
+  }
+
+  /*!
+   * \brief Throw unless `buffer` appears in the `alloc_buffers` of `block_sref`'s block or one of
+   *        its ancestor blocks.
+   */
+  static void CheckBufferAllocated(const IRModule& mod, const StmtSRef& block_sref,
+                                   const Buffer& buffer) {
+    Optional<StmtSRef> defining_site_sref;
+    bool is_alloc = false;
+    std::tie(defining_site_sref, is_alloc) = GetBufferDefiningSite(block_sref, buffer);
+    // Reject function parameters (no defining block) and buffers bound via match_buffer.
+    if (!defining_site_sref || !is_alloc) {
+      throw NonAllocatedBufferError(mod, buffer);
+    }
+  }
+
+  Array<ObjectRef> LocationsOfInterest() const final { return {}; }
+  IRModule mod() const final { return mod_; }
+
+ private:
+  IRModule mod_;
+  Buffer buffer_;
+};
+
+/*! \brief Raised when the `factor` passed to storage_align is not positive. */
+class StorageAlignInvalidFactorError : public ScheduleError {
+ public:
+  explicit StorageAlignInvalidFactorError(IRModule mod, int factor)
+      : mod_(std::move(mod)), factor_(factor) {}
+
+  /*! \brief Throw unless `factor` is strictly positive. */
+  static void Check(const IRModule& mod, int factor) {
+    if (factor > 0) {
+      return;
+    }
+    throw StorageAlignInvalidFactorError(mod, factor);
+  }
+
+  IRModule mod() const final { return mod_; }
+  Array<ObjectRef> LocationsOfInterest() const final { return {}; }
+
+  String FastErrorString() const final {
+    return "ScheduleError: The input `factor` of storage_align is expected to be a positive "
+           "number.";
+  }
+
+  String DetailRenderTemplate() const final {
+    std::ostringstream msg;
+    msg << "The input `factor` of storage_align is expected to be a positive number. However, the "
+           "input `factor` is "
+        << factor_ << ", which is out of the expected range.";
+    return msg.str();
+  }
+
+ private:
+  IRModule mod_;
+  int factor_;
+};
+
+/*! \brief Raised when an existing `buffer_dim_align` annotation has an unexpected shape. */
+class StorageAlignInvalidAnnotationError : public ScheduleError {
+ public:
+  explicit StorageAlignInvalidAnnotationError(IRModule mod, Block block)
+      : mod_(std::move(mod)), block_(std::move(block)) {}
+
+  /*!
+   * \brief Return the block's current storage-align annotation, or an empty one if absent.
+   * \throw StorageAlignInvalidAnnotationError if the existing value is not an array of
+   *        4-element integer arrays.
+   */
+  static StorageAlignAnnotation CheckAndGetAnnotation(const IRModule& mod, const Block& block) {
+    auto it = block->annotations.find(attr::buffer_dim_align);
+    // No existing annotation: start from an empty one.
+    if (it == block->annotations.end()) {
+      return StorageAlignAnnotation();
+    }
+    // Held by value: dereferencing the map iterator may yield a temporary pair.
+    ObjectRef existing = (*it).second;
+    if (!IsValidAnnotation(block, existing)) {
+      throw StorageAlignInvalidAnnotationError(mod, block);
+    }
+    return Downcast<StorageAlignAnnotation>(existing);
+  }
+
+  IRModule mod() const final { return mod_; }
+  Array<ObjectRef> LocationsOfInterest() const final { return {block_}; }
+
+  String FastErrorString() const final {
+    return "ScheduleError: The block annotation for storage align is expected to be an array of "
+           "4-integer-tuples (buffer_index, axis, factor, offset).";
+  }
+
+  String DetailRenderTemplate() const final {
+    std::ostringstream os;
+    os << "The block annotation for storage align is expected to be an array of 4-integer-tuples "
+          "(buffer_index, axis, factor, offset). However, the block annotation with key "
+       << attr::buffer_dim_align << " of the block {0} is "
+       << block_->annotations.at(attr::buffer_dim_align) << ", which is unexpected.";
+    return os.str();
+  }
+
+ private:
+  /*! \brief Whether `anno_value` is an Array whose entries are all 4-element IntImm arrays. */
+  static bool IsValidAnnotation(const Block& block, const ObjectRef& anno_value) {
+    if (!anno_value->IsInstance<ArrayNode>()) {
+      return false;
+    }
+    auto is_four_int_tuple = [](const ObjectRef& entry) -> bool {
+      if (!entry->IsInstance<ArrayNode>()) {
+        return false;
+      }
+      Array<ObjectRef> fields = Downcast<Array<ObjectRef>>(entry);
+      if (fields.size() != 4) {
+        return false;
+      }
+      for (const ObjectRef& field : fields) {
+        if (!field->IsInstance<IntImmNode>()) {
+          return false;
+        }
+      }
+      return true;
+    };
+    for (const ObjectRef& entry : Downcast<Array<ObjectRef>>(anno_value)) {
+      if (!is_four_int_tuple(entry)) {
+        return false;
+      }
+    }
+    return true;
+  }
+
+  IRModule mod_;
+  Block block_;
+};
+
+/*!
+ * \brief Record a storage alignment requirement on a write buffer of a block.
+ *
+ * The inputs are validated first: the buffer index, the factor, the axis range, and that the
+ * buffer is allocated by a block via `alloc_buffer`. The requirement is then stored as the
+ * tuple (buffer_index, axis, factor, offset) in the block's `attr::buffer_dim_align`
+ * annotation, with `axis` normalized to be non-negative. An existing tuple with the same
+ * buffer_index and axis is replaced; otherwise the tuple is appended.
+ */
+void StorageAlign(ScheduleState self, const StmtSRef& block_sref, int buffer_index, int axis,
+                  int factor, int offset) {
+  const BlockNode* block_ptr = TVM_SREF_TO_BLOCK(block_ptr, block_sref);
+  Block block = GetRef<Block>(block_ptr);
+  Buffer buffer = GetNthWriteBuffer(self, block, buffer_index);
+  StorageAlignInvalidFactorError::Check(self->mod, factor);
+  axis = StorageAlignAxisOutOfRangeError::CheckAndUpdate(self->mod, buffer, axis);
+  NonAllocatedBufferError::CheckBufferAllocated(self->mod, block_sref, buffer);
+
+  // Step 1: Get existing or create new annotation value.
+  StorageAlignAnnotation storage_align_annotation =
+      StorageAlignInvalidAnnotationError::CheckAndGetAnnotation(self->mod, block);
+
+  // Step 2: Update the annotation value. Entries are keyed by (buffer_index, axis); the new entry
+  // uses the normalized axis, so repeated calls with -1 and ndim - 1 update the same entry.
+  StorageAlignTuple new_storage_align_tuple{Integer(buffer_index), Integer(axis), Integer(factor),
+                                            Integer(offset)};
+  bool found = false;
+  for (size_t j = 0; j < storage_align_annotation.size(); ++j) {
+    const auto& storage_align_tuple = storage_align_annotation[j];
+    ICHECK(storage_align_tuple.size() == 4);
+    if (storage_align_tuple[0] == buffer_index && storage_align_tuple[1] == axis) {
+      storage_align_annotation.Set(j, std::move(new_storage_align_tuple));
+      found = true;
+      break;
+    }
+  }
+  if (!found) {
+    storage_align_annotation.push_back(std::move(new_storage_align_tuple));
+  }
+
+  // Step 3: Replace the block with the new annotation
+  Block new_block = WithAnnotation(block_ptr, attr::buffer_dim_align, storage_align_annotation);
+  self->Replace(block_sref, new_block, {{block, new_block}});
+}
+
+/******** Instruction Registration ********/
+
+/*! \brief Instruction traits that let storage_align be recorded in and replayed from a trace. */
+struct StorageAlignTraits : public UnpackedInstTraits<StorageAlignTraits> {
+  static constexpr const char* kName = "StorageAlign";
+  static constexpr bool kIsPure = false;
+
+ private:
+  // One input (the block RV), four attributes (buffer_index, axis, factor, offset), no decisions.
+  static constexpr size_t kNumInputs = 1;
+  static constexpr size_t kNumAttrs = 4;
+  static constexpr size_t kNumDecisions = 0;
+
+  static void UnpackedApplyToSchedule(Schedule sch, BlockRV block_rv, Integer buffer_index,
+                                      Integer axis, Integer factor, Integer offset) {
+    sch->StorageAlign(block_rv, buffer_index->value, axis->value, factor->value, offset->value);
+  }
+
+  static String UnpackedAsPython(Array<String> outputs, String block_rv, Integer buffer_index,
+                                 Integer axis, Integer factor, Integer offset) {
+    PythonAPICall call("storage_align");
+    call.Input("block", block_rv);
+    call.Input("buffer_index", buffer_index);
+    call.Input("axis", axis);
+    call.Input("factor", factor);
+    call.Input("offset", offset);
+    return call.Str();
+  }
+
+  template <typename>
+  friend struct ::tvm::tir::UnpackedInstTraits;
+};
+
+TVM_REGISTER_INST_KIND_TRAITS(StorageAlignTraits);
+
+}  // namespace tir
+}  // namespace tvm
diff --git a/src/tir/schedule/schedule.cc b/src/tir/schedule/schedule.cc
index 3232a33..d6dc0b4 100644
--- a/src/tir/schedule/schedule.cc
+++ b/src/tir/schedule/schedule.cc
@@ -135,6 +135,9 @@ 
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleReverseComputeInline")
 /******** (FFI) Reduction ********/
 TVM_REGISTER_GLOBAL("tir.schedule.ScheduleRFactor")
     .set_body_method<Schedule>(&ScheduleNode::RFactor);
+/******** (FFI) Block annotation ********/
+TVM_REGISTER_GLOBAL("tir.schedule.ScheduleStorageAlign")
+    .set_body_method<Schedule>(&ScheduleNode::StorageAlign);
 /******** (FFI) Blockize & Tensorize ********/
 /******** (FFI) Annotation ********/
 /******** (FFI) Misc ********/
diff --git a/src/tir/schedule/traced_schedule.cc 
b/src/tir/schedule/traced_schedule.cc
index d664d7f..e0ffdc7 100644
--- a/src/tir/schedule/traced_schedule.cc
+++ b/src/tir/schedule/traced_schedule.cc
@@ -137,6 +137,19 @@ BlockRV TracedScheduleNode::RFactor(const LoopRV& loop_rv, 
int factor_axis) {
   return result;
 }
 
+/******** Schedule: Block annotation ********/
+
+void TracedScheduleNode::StorageAlign(const BlockRV& block_rv, int buffer_index, int axis,
+                                      int factor, int offset) {
+  // Apply the primitive on the concrete schedule first; if it throws, nothing is traced.
+  ConcreteScheduleNode::StorageAlign(block_rv, buffer_index, axis, factor, offset);
+  // Record the call: the block RV is the only input, the four integers are attributes, and the
+  // primitive produces no outputs.
+  static const InstructionKind& kind = InstructionKind::Get("StorageAlign");
+  Array<ObjectRef> inputs{block_rv};
+  Array<ObjectRef> attrs{Integer(buffer_index), Integer(axis), Integer(factor), Integer(offset)};
+  trace_->Append(/*inst=*/Instruction(/*kind=*/kind,
+                                      /*inputs=*/inputs,
+                                      /*attrs=*/attrs,
+                                      /*outputs=*/{}));
+}
+
 /******** Schedule: Blockize & Tensorize ********/
 
 /******** Schedule: Annotation ********/
diff --git a/src/tir/schedule/traced_schedule.h 
b/src/tir/schedule/traced_schedule.h
index b4518cb..4650c44 100644
--- a/src/tir/schedule/traced_schedule.h
+++ b/src/tir/schedule/traced_schedule.h
@@ -61,6 +61,9 @@ class TracedScheduleNode : public ConcreteScheduleNode {
   void ReverseComputeInline(const BlockRV& block_rv) final;
   /******** Schedule: Reduction ********/
   BlockRV RFactor(const LoopRV& loop_rv, int factor_axis) final;
+  /******** Schedule: Block annotation ********/
+  void StorageAlign(const BlockRV& block_rv, int buffer_index, int axis, int 
factor,
+                    int offset) final;
   /******** Schedule: Blockize & Tensorize ********/
   /******** Schedule: Annotation ********/
   /******** Schedule: Misc ********/
diff --git a/src/tir/schedule/transform.cc b/src/tir/schedule/transform.cc
new file mode 100644
index 0000000..f27e0f6
--- /dev/null
+++ b/src/tir/schedule/transform.cc
@@ -0,0 +1,35 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include "./transform.h"
+
+namespace tvm {
+namespace tir {
+
+/******** Annotation ********/
+Block WithAnnotation(const BlockNode* block, const String& attr_key, const ObjectRef& attr_value) {
+  // Copy the node, then insert (or overwrite) the single annotation entry on the copy only.
+  ObjectPtr<BlockNode> copied = make_object<BlockNode>(*block);
+  Map<String, ObjectRef> updated_annotations = copied->annotations;
+  updated_annotations.Set(attr_key, attr_value);
+  copied->annotations = std::move(updated_annotations);
+  return Block(copied);
+}
+
+}  // namespace tir
+}  // namespace tvm
diff --git a/src/tir/schedule/transform.h b/src/tir/schedule/transform.h
new file mode 100644
index 0000000..5348382
--- /dev/null
+++ b/src/tir/schedule/transform.h
@@ -0,0 +1,41 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#ifndef TVM_TIR_SCHEDULE_TRANSFORM_H_
+#define TVM_TIR_SCHEDULE_TRANSFORM_H_
+
+#include <tvm/tir/schedule/state.h>
+
+namespace tvm {
+namespace tir {
+
+/******** Annotation ********/
+
+/*!
+ * \brief Create a new block with the given annotation added
+ * \param block The block with original annotation
+ * \param attr_key The annotation key to be added
+ * \param attr_value The annotation value to be added
+ * \return A copy of `block` whose annotation `attr_key` is set (or overwritten) to `attr_value`
+ */
+Block WithAnnotation(const BlockNode* block, const String& attr_key, const 
ObjectRef& attr_value);
+
+}  // namespace tir
+}  // namespace tvm
+
+#endif  // TVM_TIR_SCHEDULE_TRANSFORM_H_
diff --git a/src/tir/transforms/compact_buffer_region.cc 
b/src/tir/transforms/compact_buffer_region.cc
index b1a4fd4..961ea17 100644
--- a/src/tir/transforms/compact_buffer_region.cc
+++ b/src/tir/transforms/compact_buffer_region.cc
@@ -303,18 +303,61 @@ class BufferAccessRegionCollector : public StmtExprVisitor {
   support::Arena arena_;
 };
 
+/*! \brief Collect the buffer_dim_align annotations of all blocks, keyed by written buffer. */
+class StorageAlignCollector : public StmtVisitor {
+ public:
+  static std::unordered_map<Buffer, StorageAlignAnnotation, ObjectPtrHash, ObjectPtrEqual> Collect(
+      const PrimFunc& f) {
+    StorageAlignCollector collector;
+    collector(f->body);
+    return std::move(collector.storage_align_);
+  }
+
+ private:
+  void VisitStmt_(const BlockNode* op) final {
+    auto it = op->annotations.find(attr::buffer_dim_align);
+    if (it != op->annotations.end()) {  // NOTE(review): tuple contents are not validated here
+      auto storage_align_annotation = Downcast<StorageAlignAnnotation>((*it).second);
+      for (const auto& storage_align_tuple : storage_align_annotation) {
+        int buffer_index = storage_align_tuple[0]->value;  // tuple[0]: index into op->writes
+        const Buffer& buffer = op->writes[buffer_index]->buffer;
+        storage_align_[buffer].push_back(storage_align_tuple);  // later entries override per axis
+      }
+    }
+    StmtVisitor::VisitStmt_(op);
+  }
+
+  /*! \brief The map from Buffer to its storage alignment information. */
+  std::unordered_map<Buffer, StorageAlignAnnotation, ObjectPtrHash, ObjectPtrEqual> storage_align_;
+};
+
 /*! \brief Reallocate the buffers with minimal region. */
 class BufferCompactor : public StmtExprMutator {
  public:
   static Stmt Compact(
       const PrimFunc& f,
-      const std::unordered_map<Buffer, Region, ObjectPtrHash, ObjectPtrEqual>& 
regions) {
+      const std::unordered_map<Buffer, Region, ObjectPtrHash, ObjectPtrEqual>& 
regions,
+      const std::unordered_map<Buffer, StorageAlignAnnotation, ObjectPtrHash, 
ObjectPtrEqual>&
+          storage_align) {
     std::unordered_map<Buffer, BufferAllocInfo, ObjectPtrHash, ObjectPtrEqual> 
buffer_info;
 
     for (const auto& kv : regions) {
       const Buffer& buffer = kv.first;
       Region region = kv.second;
-      buffer_info.emplace(buffer, BufferAllocInfo(std::move(region)));
+      BufferAllocInfo buffer_alloc_info(std::move(region));
+      auto it = storage_align.find(buffer);
+      if (it != storage_align.end()) {
+        std::vector<DimAlignInfo> dim_aligns(buffer->shape.size());
+        for (const StorageAlignTuple& dim_align : (*it).second) {
+          ICHECK(dim_align.size() == 4);
+          int dim = dim_align[1]->value;
+          int factor = dim_align[2]->value;
+          int offset = dim_align[3]->value;
+          dim_aligns.at(dim) = {factor, offset};
+        }
+        buffer_alloc_info.dim_aligns = std::move(dim_aligns);
+      }
+      buffer_info.emplace(buffer, std::move(buffer_alloc_info));
     }
     BufferCompactor compactor(std::move(buffer_info));
     Stmt stmt = compactor(f->body);
@@ -322,9 +365,19 @@ class BufferCompactor : public StmtExprMutator {
   }
 
  private:
+  /*! \brief Storage alignment requested for one buffer dimension (from buffer_dim_align). */
+  struct DimAlignInfo {
+    /*! \brief The alignment factor of the dimension's stride; 0 means "not aligned". */
+    int align_factor{0};
+    /*! \brief The padded stride is congruent to this offset modulo align_factor. */
+    int align_offset{0};
+  };
+
   struct BufferAllocInfo {
     /*! \brief The buffer access region. */
     Region region;
+    /*! \brief The storage alignment information. */
+    std::vector<DimAlignInfo> dim_aligns;
     /*!
      * \brief The reallocated buffer with minimal size.
      * \note The value if NullOpt if the buffer do not need reallocate (e.g 
parameter buffer).
@@ -380,8 +433,25 @@ class BufferCompactor : public StmtExprMutator {
       for (const Range& range : info.region) {
         shape.push_back(range->extent);
       }
+      Array<PrimExpr> strides;
+      if (info.dim_aligns.size()) {
+        ICHECK(info.dim_aligns.size() == shape.size());
+        strides.resize(shape.size());
+        PrimExpr stride = make_const(shape[0].dtype(), 1);
+        for (size_t i = shape.size(); i != 0; --i) {
+          size_t dim = i - 1;
+          if (info.dim_aligns[dim].align_factor != 0) {
+            PrimExpr factor = make_const(stride.dtype(), 
info.dim_aligns[dim].align_factor);
+            PrimExpr offset = make_const(stride.dtype(), 
info.dim_aligns[dim].align_offset);
+            stride = stride + indexmod(factor + offset - indexmod(stride, 
factor), factor);
+          }
+          strides.Set(dim, stride);
+          stride = stride * shape[dim];
+        }
+      }
       ObjectPtr<BufferNode> n = make_object<BufferNode>(*buffer.get());
       n->shape = std::move(shape);
+      n->strides = std::move(strides);
       info.new_buffer = Buffer(std::move(n));
       result.push_back(info.new_buffer);
     }
@@ -458,7 +528,9 @@ PrimFunc CompactBufferAllocation(PrimFunc f) {
     PrimFuncNode* fptr = f.CopyOnWrite();
     std::unordered_map<Buffer, Region, ObjectPtrHash, ObjectPtrEqual> region =
         BufferAccessRegionCollector::Collect(f);
-    fptr->body = BufferCompactor::Compact(f, region);
+    std::unordered_map<Buffer, StorageAlignAnnotation, ObjectPtrHash, 
ObjectPtrEqual>
+        storage_align = StorageAlignCollector::Collect(f);
+    fptr->body = BufferCompactor::Compact(f, region, storage_align);
     return f;
   } else {
     return f;
diff --git a/tests/python/unittest/test_tir_schedule_storage_align.py b/tests/python/unittest/test_tir_schedule_storage_align.py
new file mode 100644
index 0000000..a0a0693
--- /dev/null
+++ b/tests/python/unittest/test_tir_schedule_storage_align.py
@@ -0,0 +1,182 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-function-docstring,missing-module-docstring
+import pytest
+import tvm
+from tvm import tir
+from tvm.script import ty
+from tvm.tir.schedule.testing import verify_trace_roundtrip
+
+# fmt: off
+# pylint: disable=no-member,invalid-name,unused-variable,line-too-long,redefined-outer-name
+
+@tvm.script.tir
+def element_wise(a: ty.handle, c: ty.handle) -> None:
+    C = tir.match_buffer(c, [128, 128], elem_offset=0, align=128, offset_factor=1)
+    A = tir.match_buffer(a, [128, 128], elem_offset=0, align=128, offset_factor=1)
+    # body: B = A * 2 into an intermediate buffer, then C = B + 1
+    with tir.block([], "root"):
+        tir.reads([])
+        tir.writes([])
+        B = tir.alloc_buffer([128, 128], elem_offset=0, align=128, offset_factor=1)
+        for i0 in tir.serial(0, 128):
+            for ax1 in tir.serial(0, 128):
+                with tir.block([128, 128], "B") as [vi, vj]:
+                    tir.bind(vi, i0)
+                    tir.bind(vj, ax1)
+                    tir.reads([A[vi, vj]])
+                    tir.writes([B[vi, vj]])
+                    B[vi, vj] = (A[vi, vj]*tir.float32(2))
+            for i1 in tir.serial(0, 128):
+                with tir.block([128, 128], "C") as [vi_1, vj_1]:
+                    tir.bind(vi_1, i0)
+                    tir.bind(vj_1, i1)
+                    tir.reads([B[vi_1, vj_1]])
+                    tir.writes([C[vi_1, vj_1]])
+                    C[vi_1, vj_1] = (B[vi_1, vj_1] + tir.float32(1))
+
+
+@tvm.script.tir
+def element_wise_storage_align(a: ty.handle, c: ty.handle) -> None:
+    C = tir.match_buffer(c, [128, 128], elem_offset=0, align=128, offset_factor=1)
+    A = tir.match_buffer(a, [128, 128], elem_offset=0, align=128, offset_factor=1)
+    # body: same as element_wise, but block "B" carries a buffer_dim_align annotation
+    with tir.block([], "root"):
+        tir.reads([])
+        tir.writes([])
+        B = tir.alloc_buffer([128, 128], elem_offset=0, align=128, offset_factor=1)
+        for i0 in tir.serial(0, 128):
+            for ax1 in tir.serial(0, 128):
+                with tir.block([128, 128], "B") as [vi, vj]:
+                    tir.bind(vi, i0)
+                    tir.bind(vj, ax1)
+                    tir.reads([A[vi, vj]])
+                    tir.writes([B[vi, vj]])
+                    tir.block_attr({"buffer_dim_align":[[0, 0, 128, 127]]})
+                    B[vi, vj] = (A[vi, vj]*tir.float32(2))
+            for i1 in tir.serial(0, 128):
+                with tir.block([128, 128], "C") as [vi_1, vj_1]:
+                    tir.bind(vi_1, i0)
+                    tir.bind(vj_1, i1)
+                    tir.reads([B[vi_1, vj_1]])
+                    tir.writes([C[vi_1, vj_1]])
+                    C[vi_1, vj_1] = (B[vi_1, vj_1] + tir.float32(1))
+
+
+@tvm.script.tir
+def element_wise_invalid_annotation(a: ty.handle, c: ty.handle) -> None:
+    C = tir.match_buffer(c, [128, 128], elem_offset=0, align=128, offset_factor=1)
+    A = tir.match_buffer(a, [128, 128], elem_offset=0, align=128, offset_factor=1)
+    # body: buffer_dim_align on "B" is malformed (a flat list, not a list of 4-tuples)
+    with tir.block([], "root"):
+        tir.reads([])
+        tir.writes([])
+        B = tir.alloc_buffer([128, 128], elem_offset=0, align=128, offset_factor=1)
+        for i0 in tir.serial(0, 128):
+            for ax1 in tir.serial(0, 128):
+                with tir.block([128, 128], "B") as [vi, vj]:
+                    tir.block_attr({"buffer_dim_align": [0]})
+                    tir.bind(vi, i0)
+                    tir.bind(vj, ax1)
+                    tir.reads([A[vi, vj]])
+                    tir.writes([B[vi, vj]])
+                    B[vi, vj] = (A[vi, vj]*tir.float32(2))
+            for i1 in tir.serial(0, 128):
+                with tir.block([128, 128], "C") as [vi_1, vj_1]:
+                    tir.bind(vi_1, i0)
+                    tir.bind(vj_1, i1)
+                    tir.reads([B[vi_1, vj_1]])
+                    tir.writes([C[vi_1, vj_1]])
+                    C[vi_1, vj_1] = (B[vi_1, vj_1] + tir.float32(1))
+
+
+def test_storage_align():
+    func = element_wise
+    s = tir.Schedule(func, debug_mask='all')
+    B = s.get_block("B")
+    s.storage_align(B, 0, axis=0, factor=128, offset=127)  # -> buffer_dim_align [[0, 0, 128, 127]]
+    tvm.ir.assert_structural_equal(element_wise_storage_align, s.mod["main"])
+    verify_trace_roundtrip(sch=s, mod=func)
+
+
+def test_storage_align_update():
+    func = element_wise
+    s = tir.Schedule(func, debug_mask='all')
+    B = s.get_block("B")
+    s.storage_align(B, 0, axis=0, factor=128, offset=0)
+    s.storage_align(B, 0, axis=0, factor=128, offset=127)  # replaces the (buffer 0, axis 0) entry
+    tvm.ir.assert_structural_equal(element_wise_storage_align, s.mod["main"])
+    verify_trace_roundtrip(sch=s, mod=func)
+
+
+def test_storage_align_invalid_factor1():
+    func = element_wise
+    s = tir.Schedule(func, debug_mask='all')
+    B = s.get_block("B")
+    with pytest.raises(tir.ScheduleError):
+        s.storage_align(B, 0, axis=0, factor=0, offset=127)  # zero factor
+
+
+def test_storage_align_invalid_factor2():
+    func = element_wise
+    s = tir.Schedule(func, debug_mask='all')
+    B = s.get_block("B")
+    with pytest.raises(tir.ScheduleError):
+        s.storage_align(B, 0, axis=0, factor=-1, offset=127)  # negative factor
+
+
+def test_storage_align_invalid_buffer():
+    func = element_wise
+    s = tir.Schedule(func, debug_mask='all')
+    C = s.get_block("C")
+    with pytest.raises(tir.ScheduleError):
+        s.storage_align(C, 0, axis=0, factor=128, offset=127)  # C comes from match_buffer
+
+
+def test_storage_align_invalid_buffer_index():
+    func = element_wise
+    s = tir.Schedule(func, debug_mask='all')
+    B = s.get_block("B")
+    with pytest.raises(tir.ScheduleError):
+        s.storage_align(B, 2, axis=0, factor=128, offset=127)  # block B has one write buffer
+
+
+def test_storage_align_invalid_axis():
+    func = element_wise
+    s = tir.Schedule(func, debug_mask='all')
+    B = s.get_block("B")
+    with pytest.raises(tir.ScheduleError):
+        s.storage_align(B, 0, axis=2, factor=128, offset=127)  # B is 2-D, axis 2 is out of range
+
+
+def test_storage_align_invalid_annotation():
+    func = element_wise_invalid_annotation
+    s = tir.Schedule(func, debug_mask='all')
+    B = s.get_block("B")
+    with pytest.raises(tir.ScheduleError):
+        s.storage_align(B, 0, axis=0, factor=128, offset=127)  # valid args; only the annotation is bad
+
+
+if __name__ == "__main__":  # allow running this file directly, without pytest
+    test_storage_align()
+    test_storage_align_update()
+    test_storage_align_invalid_factor1()
+    test_storage_align_invalid_factor2()
+    test_storage_align_invalid_buffer()
+    test_storage_align_invalid_buffer_index()
+    test_storage_align_invalid_axis()
+    test_storage_align_invalid_annotation()
diff --git a/tests/python/unittest/test_tir_transform_compact_buffer_region.py b/tests/python/unittest/test_tir_transform_compact_buffer_region.py
index fb53b42..15da022 100644
--- a/tests/python/unittest/test_tir_transform_compact_buffer_region.py
+++ b/tests/python/unittest/test_tir_transform_compact_buffer_region.py
@@ -339,6 +339,50 @@ def compacted_match_buffer_func(a: ty.handle, c: ty.handle) -> None:
                     C1[()] = B2[()] * 2.0
 
 
+@tvm.script.tir
+def storage_align_func(a: ty.handle, c: ty.handle) -> None:
+    A = tir.match_buffer(a, (16, 16), "float32")
+    C = tir.match_buffer(c, (16, 16), "float32")
+    for i in range(0, 16):
+        with tir.block([]):
+            tir.reads(A[i, 0:16])
+            tir.writes(C[i, 0:16])
+            B = tir.alloc_buffer((16, 16), "float32")
+            for j in range(0, 16):
+                with tir.block([]) as []:
+                    tir.reads(A[i, j])
+                    tir.writes(B[i, j])
+                    tir.block_attr({"buffer_dim_align": [[0, 0, 16, 15]]})  # [write idx, axis, factor, offset]
+                    B[i, j] = A[i, j] + 1.0
+            for j in range(0, 16):
+                with tir.block([]) as []:
+                    tir.reads(B[i, j])
+                    tir.writes(C[i, j])
+                    C[i, j] = B[i, j] * 2.0
+
+
+@tvm.script.tir
+def compacted_storage_align_func(a: ty.handle, c: ty.handle) -> None:
+    A = tir.match_buffer(a, (16, 16), "float32")
+    C = tir.match_buffer(c, (16, 16), "float32")
+    for i in range(0, 16):
+        with tir.block([]):
+            tir.reads(A[i, 0:16])
+            tir.writes(C[i, 0:16])
+            B = tir.alloc_buffer((1, 16), strides=(31, 1), dtype="float32")  # axis-0 stride 16 padded to 31
+            for j in range(0, 16):
+                with tir.block() as []:
+                    tir.reads(A[i, j])
+                    tir.writes(B[0, j])
+                    tir.block_attr({"buffer_dim_align": [[0, 0, 16, 15]]})
+                    B[0, j] = A[i, j] + 1.0
+            for j in range(0, 16):
+                with tir.block() as []:
+                    tir.reads(B[0, j])
+                    tir.writes(C[i, j])
+                    C[i, j] = B[0, j] * 2.0
+
+
 def test_elementwise():
     _check(elementwise_func, compacted_elementwise_func)
 
@@ -380,6 +424,10 @@ def test_lower_te():
     tvm.ir.assert_structural_equal(mod, orig_mod)  # CompactBufferAllocation 
should do nothing on TE
 
 
+def test_storage_align():
+    _check(storage_align_func, compacted_storage_align_func)  # B -> (1, 16), row stride 16 -> 31
+
+
 if __name__ == "__main__":
     test_elementwise()
     test_unschedulable_block()
@@ -389,3 +437,4 @@ if __name__ == "__main__":
     test_symbolic()
     test_complex()
     test_match_buffer()
+    test_storage_align()

Reply via email to