This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new a31ce9ee88 [Unity][Layout] Add layout transformation analysis for
PrimFunc (#14066)
a31ce9ee88 is described below
commit a31ce9ee888b562db9b4a1923294eee98049122b
Author: Prakalp Srivastava <[email protected]>
AuthorDate: Thu Feb 23 05:21:20 2023 -0500
[Unity][Layout] Add layout transformation analysis for PrimFunc (#14066)
* [Layout] Add layout transformation analysis for PrimFunc.
This change adds a PrimFunc level analysis to suggest layout
transformations to block and buffers in the PrimFunc based on the layout
transformations to PrimFunc outputs.
* Add support for multiple blocks such as split op.
* Add negative tests and increase coverage.
* fix warning message
* fix lint
* remove unused header
* Address comments.
Moved some utility functions to support/array.h
improve doc
* fix deprecation warn T.var("int64") to T.int64()
* address comments
---
include/tvm/relax/analysis.h | 13 +
python/tvm/relax/analysis/analysis.py | 32 +-
src/relax/analysis/layout_transformation.cc | 621 +++++++++++++++
src/support/array.h | 27 +-
.../test_analysis_suggest_layout_transforms.py | 831 +++++++++++++++++++++
5 files changed, 1522 insertions(+), 2 deletions(-)
diff --git a/include/tvm/relax/analysis.h b/include/tvm/relax/analysis.h
index 39ecfd9e13..2b771b9708 100644
--- a/include/tvm/relax/analysis.h
+++ b/include/tvm/relax/analysis.h
@@ -403,6 +403,19 @@ TVM_DLL bool HasReshapePattern(const tir::PrimFunc& func);
*/
TVM_DLL bool WellFormed(IRModule m, bool check_struct_info = true);
+/*!
+ * \brief Using the layout transforms on the outputs, suggest layout
transformation on the blocks
+ * and buffers for the PrimFunc.
+ *
+ * \param fn The PrimFunc to be analyzed.
+ * \param write_buffer_transformations Array of IndexMap transformations on
PrimFunc outputs.
+ * \return Suggested transforms per block in `fn`. For each block the returned
value is a map
+ * from the object (block or buffer) to its index map transformation.
+ */
+
+TVM_DLL Map<tir::Block, Map<ObjectRef, tir::IndexMap>> SuggestLayoutTransforms(
+ const Function& fn, Array<tir::IndexMap> write_buffer_transformations);
+
} // namespace relax
} // namespace tvm
diff --git a/python/tvm/relax/analysis/analysis.py
b/python/tvm/relax/analysis/analysis.py
index ffcdaceb40..efd1b51f11 100644
--- a/python/tvm/relax/analysis/analysis.py
+++ b/python/tvm/relax/analysis/analysis.py
@@ -21,7 +21,7 @@ This file contains the set of passes for Relax, which exposes
an interface for
configuring the passes and scripting them in Python.
"""
-from typing import Dict, List
+from typing import Dict, List, Union, Callable
from enum import IntEnum
from tvm import tir
@@ -29,6 +29,7 @@ from tvm import IRModule
from tvm.relax.ty import Type
from tvm.relax.struct_info import StructInfo, FuncStructInfo
from tvm.relax.expr import DataflowBlock, Var, Expr, Function, Call, Binding
+from tvm.tir import IndexMap, PrimFunc, Block, Buffer
from . import _ffi_api
@@ -289,3 +290,32 @@ def well_formed(mod: IRModule, check_struct_info: bool =
True) -> bool:
will be well tested and will not be blocked by not having structure info.
"""
return _ffi_api.well_formed(mod, check_struct_info) # type: ignore
+
+
+def suggest_layout_transforms(
+ func: PrimFunc, write_buffer_transforms: List[Union[IndexMap, Callable]]
+) -> Dict[Block, Dict[Union[Block, Buffer], IndexMap]]:
+ """Suggest Layout transformations of blocks and buffers in a PrimFunc.
+
+ Parameters
+ ----------
+ func: PrimFunc
+ PrimFunc on which analysis will be performed and transformations
suggested.
+
+ write_buffer_transforms: List[Union[IndexMap, Callable]
+ List of layout transformations on the output buffers. The number of
layout
+ transformations must match the number of outputs of the PrimFunc.
+
+ Returns
+ -------
+ ret: Dict[Block, Dict[Union[Block, Buffer], IndexMap]]
+ Suggested transforms per block in `func`. For each block the returned
value is a map
+ from the object (block or buffer) to its index map transformation.
+ """
+ write_buffer_index_maps = []
+ for transform in write_buffer_transforms:
+ if callable(transform):
+ transform = IndexMap.from_func(transform)
+ assert isinstance(transform, IndexMap)
+ write_buffer_index_maps.append(transform)
+ return _ffi_api.suggest_layout_transforms(func, write_buffer_index_maps)
# type: ignore
diff --git a/src/relax/analysis/layout_transformation.cc
b/src/relax/analysis/layout_transformation.cc
new file mode 100644
index 0000000000..44538fea98
--- /dev/null
+++ b/src/relax/analysis/layout_transformation.cc
@@ -0,0 +1,621 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file relax/analysis/layout_transformation.cc
+ * \brief Analyze the PrimFunc and suggest layout transformation on its
blocks and buffers based on
+ * the user-provided layout transformations on its outputs.
+ */
+#include <tvm/arith/iter_affine_map.h>
+#include <tvm/relax/analysis.h>
+#include <tvm/tir/analysis.h>
+#include <tvm/tir/stmt_functor.h>
+
+#include "../../support/array.h"
+
+namespace tvm {
+namespace relax {
+
+using namespace tir;
+
+/********** Helper Functions **********/
+
+/*! \brief Checks if a transformation is bijective affine over the given
ranges */
+static bool IsBijectiveAffine(const IndexMap& m, const Array<Range>& ranges) {
+ Map<tir::Var, Range> input_iters;
+ ICHECK_EQ(m->initial_indices.size(), ranges.size());
+ for (size_t i = 0; i < ranges.size(); i++) {
+ input_iters.Set(m->initial_indices[i], ranges[i]);
+ }
+ arith::Analyzer analyzer;
+ auto iter_map_result = DetectIterMap(m->final_indices, input_iters, /*
predicate = */ 1,
+
/*check_level=*/arith::IterMapLevel::Bijective, &analyzer,
+ /*simplify_trivial_iterators=*/true);
+ return !iter_map_result->indices.empty();
+}
+
+/*!
+ * \brief Analyzer to collect iterators from IterSumExpr.
+ * \details Analyzes the indices from DetectIterMap analysis to collect the
spatial iterators that
+ * are used in it. This is important to get which spatial iterators are
accessed in each index
+ * of buffer access.
+ */
+class IndexAnalyzer : public ExprVisitor {
+ public:
+ Array<tir::Var> Analyze(const arith::IterSumExpr& expr) {
+ VisitExpr(expr);
+ return iterators_;
+ }
+
+ private:
+ /*! \brief Override VisitExpr for iter expr type processing */
+ void VisitExpr(const PrimExpr& expr) override {
+ if (const auto* op = expr.as<arith::IterSumExprNode>()) {
+ for (const auto& arg : op->args) VisitExpr(arg);
+ VisitExpr(op->base);
+ return;
+ }
+ if (const auto* op = expr.as<arith::IterSplitExprNode>()) {
+ VisitIterMark(op->source);
+ VisitExpr(op->lower_factor);
+ VisitExpr(op->extent);
+ VisitExpr(op->scale);
+ return;
+ }
+ return ExprVisitor::VisitExpr(expr);
+ }
+
+ void VisitIterMark(const arith::IterMark& op) {
+ if (const auto* var = op->source.as<tir::VarNode>())
+ iterators_.push_back(GetRef<tir::Var>(var));
+ else
+ VisitExpr(op->source);
+ VisitExpr(op->extent);
+ }
+
+ private:
+ Array<tir::Var> iterators_;
+};
+
+/*!
+ * \brief Analyzes IterMapResult to get the Spatial Layout of buffer access.
+ * \details We define Spatial Layout of a buffer access as an array of length
equal to the
+ * dimensions of the buffer. i-th element of Spatial Layout contains spatial
iter var used from the
+ * block iteration domain. For indices, where no spatial iter vars are used,
the spatial layout
+ * element is empty. If any of the buffer access indices use multiple spatial
iter vars, the spatial
+ * layout is undefined.
+ *
+ * Here are a few examples of inferred spatial layout from buffer access. si
denotes i-th spatial
+ * iter var, and ri denotes i-th reduction iter var.
+ *
+ * SpatialLayout(A[s0*constant, s1]) = {s0, s1}
+ * SpatialLayout(A[s0, constant, r0, s1]) = {s0, null, null, s1}
+ * SpatialLayout(A[s0 * c + s1]) = undefined
+ */
+using SpatialLayout = Array<Optional<tir::Var>>;
+static SpatialLayout GetSpatialLayout(const arith::IterMapResult&
iter_map_result) {
+ ICHECK(!iter_map_result->indices.empty());
+ SpatialLayout result;
+ for (const arith::IterSumExpr& index : iter_map_result->indices) {
+ IndexAnalyzer index_analyzer;
+ Array<tir::Var> iter_vars = index_analyzer.Analyze(index);
+ if (iter_vars.size() >= 2) {
+ LOG(WARNING) << "[LayoutInference] Unable to get spatial layout of
access: "
+ << arith::NormalizeIterMapToExpr(index);
+ return {};
+ }
+ if (iter_vars.empty()) {
+ result.push_back({});
+ continue;
+ }
+ result.push_back(iter_vars[0]);
+ }
+ return result;
+}
+
+/*!
+ * \brief Checks if the two spatial layouts are identical. Two empty spatial
layouts are treated as
+ * unequal.
+ */
+static bool AreIdenticalSpatialAccess(const SpatialLayout& s0, const
SpatialLayout& s1) {
+ if (s0.empty() || s1.empty()) return false;
+ if (s0.size() != s1.size()) return false;
+ for (size_t i = 0; i < s0.size(); ++i) {
+ if ((!s0[i].defined() && s1[i].defined()) || (s0[i].defined() &&
!s1[i].defined()))
+ return false;
+ if (!s0[i].same_as(s1[i])) return false;
+ }
+ return true;
+}
+
+/*!
+ * \brief Checks if the block accesses a buffer sequentially in terms of
spatial dimensions
+ * (ignoring reduction dimensions). It checks that the order of spatial iter
vars in spatial layout
+ * of a buffer access is same as the order of spatial iter vars in block
domain.
+ */
+using VarToBlockIndexMap = std::unordered_map<tir::Var, int, ObjectPtrHash,
ObjectPtrEqual>;
+static bool IsSequentialAccess(const SpatialLayout& iterators,
+ const VarToBlockIndexMap& iter_to_block_index) {
+ int last_value = -1;
+ for (const auto& i : iterators) {
+ if (!i.defined()) continue;
+ auto it = iter_to_block_index.find(i.value());
+ ICHECK(it != iter_to_block_index.end());
+ int blk_index = it->second;
+ if (blk_index <= last_value) return false;
+ last_value = blk_index;
+ }
+ return true;
+}
+
+/*! \brief Checks if two IndexMaps represent identical transforms */
+static bool AreIdenticalTransforms(const IndexMap& t0, const IndexMap& t1) {
+ if (t0->initial_indices.size() != t1->initial_indices.size()) return false;
+ if (t0->final_indices.size() != t1->final_indices.size()) return false;
+
+ // Create a new shape expression.
+ Array<PrimExpr> t1_initial_indices =
+ t1->initial_indices.Map([](tir::Var i) -> PrimExpr { return i; });
+ auto t0_output = t0->MapIndices(t1_initial_indices);
+ arith::Analyzer analyzer;
+ for (size_t i = 0; i < t0_output.size(); ++i) {
+ if (!analyzer.CanProveEqual(t0_output[i], t1->final_indices[i])) return
false;
+ }
+ return true;
+}
+
+/*!
+ * \brief Returns the layout transformation for a target spatial layout from
the source spatial
+ * layout and transformation.
+ * \details Given the source buffer spatial layout \p src_spatial_layout and
its transformation \p
+ * src_transformation, this function constructs the transformation for the
target buffer whose
+ * spatial layout is given as \p tgt_spatial_layout.
+ *
+ * The algorithm is explained below using an example:
+ *
+ * Let's say the source transformation is lambda N, C, H, W -> (N, H, W, C //
4, C %
+ * 4), source spatial layout is 'NCHW' and target spatial layout is 'KCHW'.
+ *
+ * Step 1: Copy over the source transformation initial & final indices to
target transformation
+ * initial and final indices.
+ * target transformation = lambda N, C, H, W -> (N, H, W, C // 4, C %4)
+ *
+ * Step 2: Drop any vars from initial indices which do not occur in target
buffer using source and
+ * target spatial layouts.
+ * target transformation = lambda C, H, W -> (N, H, W, C // 4, C %4)
+ *
+ * Step 3: Erase any expression from final indices which is dependent on a var
not present in
+ * initial indices.
+ * target transformation = lambda C, H, W -> (H, W, C // 4, C %4)
+ *
+ * Step 4: Go over the target spatial layout and add any missing dims to both
initial and final
+ * indices. This is done by checking if any iterator in target spatial layout
is not present in
+ * source spatial layout.
+ * target transformation = lambda dim, C, H, W -> (dim, H, W, C // 4, C %4)
+ */
+using VarSet = std::unordered_set<tir::Var, ObjectPtrHash, ObjectPtrEqual>;
+static Optional<IndexMap> InferLayoutTransformation(const SpatialLayout&
src_spatial_layout,
+ const IndexMap&
src_transformation,
+ const SpatialLayout&
tgt_spatial_layout) {
+ // Copy over the src transformation initial and final indices
+ auto initial_indices = support::AsList(src_transformation->initial_indices);
+ auto final_indices = support::AsList(src_transformation->final_indices);
+
+ // Get the iterator var set used in target spatial layout.
+ VarSet tgt_var_set;
+ for (const auto& i : tgt_spatial_layout) {
+ if (i.defined()) tgt_var_set.insert(i.value());
+ }
+
+ // Erase initial indices corresponding to iter vars that do not occur in
target spatial layout.
+ // Also compute the var set of initial indices.
+ auto initial_indices_it = initial_indices.begin();
+ VarSet initial_indices_var_set;
+ for (const auto& i : src_spatial_layout) {
+ ICHECK(i.defined());
+ if (tgt_var_set.count(i.value())) {
+ initial_indices_var_set.insert(*initial_indices_it);
+ initial_indices_it++;
+ continue;
+ }
+ initial_indices_it = initial_indices.erase(initial_indices_it);
+ }
+
+ // Erase any expressions in final indices that have undefined vars
+ auto final_indices_it = final_indices.begin();
+ while (final_indices_it != final_indices.end()) {
+ // Collect all the vars used in this final index.
+ Array<tir::Var> used_vars = tir::UndefinedVars(*final_indices_it);
+ ICHECK(!used_vars.empty())
+ << "IndexMap expression must always contain tir::Var nodes but found
none in: "
+ << *final_indices_it;
+
+ bool has_undefined_vars = std::any_of(used_vars.begin(), used_vars.end(),
+ [&initial_indices_var_set](const
tir::Var& v) {
+ return
initial_indices_var_set.count(v) == 0;
+ });
+
+ // If all vars are from initial indices, nothing to do for this final
index.
+ if (!has_undefined_vars) {
+ final_indices_it++;
+ continue;
+ }
+ // We are about to drop this expr from final indices since it has
undefined vars. Check if it is
+ // dependent on any of the initial indices. If it is dependent, this
cannot be dropped and we
+ // bail by returning null.
+ // This captures the scenario where the source transformation is unpacking
a dimension (e.g,
+ // "H4h" -> "H*4+h" ) and the buffer we are trying to infer the
transformation of has 'h'
+ // dimension, but not 'H'. So, it is dependent on undefined var 'H' and
defined var 'h'.
+ bool depends_on_initial_indices = std::any_of(used_vars.begin(),
used_vars.end(),
+
[&initial_indices_var_set](const tir::Var& v) {
+ return
initial_indices_var_set.count(v) != 0;
+ });
+ if (depends_on_initial_indices) {
+ LOG(WARNING)
+ << "[LayoutInference] Buffer access is dependent on both defined and
undefined vars";
+ return {};
+ }
+ // It is ok to erase this final index expression as it only depends on
undefined vars.
+ final_indices_it = final_indices.erase(final_indices_it);
+ }
+
+ // Go over the target spatial layout and add any missing dims to both
initial and final indices.
+ // This is done by checking if any iterator in target spatial layout is not
present in source
+ // spatial layout.
+ VarSet src_var_set;
+ for (const auto& i : src_spatial_layout) {
+ ICHECK(i.defined());
+ src_var_set.insert(i.value());
+ }
+
+ initial_indices_it = initial_indices.begin();
+ final_indices_it = final_indices.begin();
+ for (const auto& i : tgt_spatial_layout) {
+ if (i.defined() && src_var_set.count(i.value())) {
+ initial_indices_it++;
+ if (final_indices_it != final_indices.end()) final_indices_it++;
+ continue;
+ }
+
+ auto new_dim = tir::Var("d");
+ initial_indices.insert(initial_indices_it, new_dim);
+ final_indices.insert(final_indices_it, new_dim);
+ }
+
+ return IndexMap(support::AsArray(initial_indices),
support::AsArray(final_indices));
+}
+
+/*!
+ * \brief Analyzes the Block and given output buffer transformations to propose
+ * transformations of block and read buffers.
+ * \details It does a best effort analysis to propose transformations which
would preserve
+ * sequential access to buffers (especially output buffers). Since this is
best effort, it is
+ * possible that the Block is too complex for analysis. In such a case, no
transformations are
+ * proposed. Limitations:
+ * 1. Expects exactly one write buffer in the block whose transformation is
given by
+ * `write_transformation`.
+ * 2. Expects write buffer access to be affine and only use spatial iterators
of the block.
+ * 3. Proposes transformations to a read buffer if all access to it are affine.
+ */
+class BlockAnalyzer : public StmtExprVisitor {
+ public:
+ explicit BlockAnalyzer(const Block& block, const Map<Buffer, IndexMap>&
transformation_cache,
+ IndexMap write_transformation)
+ : can_transform_block_(true),
+ write_transformation_(write_transformation),
+ block_(block),
+ buffer_transformation_cache_(transformation_cache) {
+ ICHECK(block_->writes.size() == 1);
+ auto write_buffer = block_->writes[0]->buffer;
+
+ ComputeBlockSpatialDomain();
+
+ // Visit the block body to collect load/store access patterns of different
buffers.
+ VisitStmt(block_->body);
+
+ // While visiting the load/store accesses it is possible we see an
unexpected pattern, such as
+ // nested block or write access to multiple buffers. In such a case, we
can return early as we
+ // would not be making any layout suggestions.
+ if (!can_transform_block_) {
+ LOG(WARNING) << "[LayoutInference] Unable to transform block " <<
block->name_hint;
+ return;
+ }
+
+ // Get iterator ordering and its spatial layout.
+ VarToBlockIndexMap iter_var_to_block_index;
+ SpatialLayout block_spatial_layout;
+ int index = 0;
+ for (const auto& iter_var : block->iter_vars) {
+ auto var = iter_var->var;
+ iter_var_to_block_index[var] = index++;
+ block_spatial_layout.push_back(var);
+ }
+
+ // Helper to get the spatial layout of buffer from buffer access map.
+ auto get_spatial_layout = [&](Buffer b) -> SpatialLayout {
+ auto it = buffer_access_info_.find(b);
+ if (it == buffer_access_info_.end()) {
+ return {};
+ }
+ auto access_info = it->second;
+ return access_info.GetValidSpatialLayout();
+ };
+
+ // Check that write has sequential access within the block.
+ SpatialLayout write_spatial_layout = get_spatial_layout(write_buffer);
+ if (write_spatial_layout.empty()) {
+ can_transform_block_ = false;
+ return;
+ }
+ if (!IsSequentialAccess(write_spatial_layout, iter_var_to_block_index)) {
+ can_transform_block_ = false;
+ return;
+ }
+
+ // Infer Block transformation from write buffer transformation.
+ auto maybe_block_transformation = InferLayoutTransformation(
+ write_spatial_layout, write_transformation_, block_spatial_layout);
+ if (!maybe_block_transformation.defined()) {
+ can_transform_block_ = false;
+ return;
+ }
+ block_transformation_ = maybe_block_transformation.value();
+
+ Array<Range> block_ranges = block_->iter_vars.Map([](const IterVar& i) {
return i->dom; });
+ if (!IsBijectiveAffine(block_transformation_, block_ranges)) {
+ can_transform_block_ = false;
+ LOG(WARNING) << "[LayoutInference] Inferred block transformation is not
bijective affine, "
+ "transformation: ("
+ << block_transformation_ << ") over range (" <<
block_ranges << ")";
+ return;
+ }
+
+ // Infer read buffer transformations from write buffer transformation.
+ for (const auto& r : block->reads) {
+ SpatialLayout read_spatial_layout = get_spatial_layout(r->buffer);
+ if (read_spatial_layout.empty()) continue;
+ if (!IsSequentialAccess(read_spatial_layout, iter_var_to_block_index))
continue;
+
+ auto maybe_read_transformation = InferLayoutTransformation(
+ write_spatial_layout, write_transformation_, read_spatial_layout);
+ if (!maybe_read_transformation.defined()) continue;
+ IndexMap read_transformation = maybe_read_transformation.value();
+ if (buffer_transformation_cache_.count(r->buffer) != 0) {
+ if (!AreIdenticalTransforms(read_transformation,
buffer_transformation_cache_[r->buffer]))
+ LOG(WARNING) << "[LayoutInference] Buffer: " << r->buffer
+ << " has conflicting transform proposals -- (preferred)
"
+ << buffer_transformation_cache_[r->buffer] << " vs. "
<< read_transformation;
+ continue;
+ }
+ read_buffer_transformations_.Set(r->buffer, read_transformation);
+ }
+ }
+
+ private:
+ // Helper class to keep track of spatial layout of buffer as we visit
multiple accesses to this
+ // buffer within the block.
+ class BufferAccessInfo {
+ public:
+ BufferAccessInfo() : is_valid_(true) {}
+ void Update(SpatialLayout s) {
+ if (!IsValid()) return;
+ if (spatial_layout_.empty()) spatial_layout_ = s;
+ if (!AreIdenticalSpatialAccess(s, spatial_layout_)) {
+ Invalidate();
+ return;
+ }
+ }
+ bool IsValid() { return is_valid_; }
+ void Invalidate() { is_valid_ = false; }
+ SpatialLayout GetValidSpatialLayout() {
+ if (!IsValid()) return {};
+ return spatial_layout_;
+ }
+
+ private:
+ bool is_valid_;
+ SpatialLayout spatial_layout_;
+ };
+
+ // Helper to break down the indices of buffer access.
+ SpatialLayout DetectBufferAccessIterMap(Array<PrimExpr> indices) {
+ auto result = arith::DetectIterMap(
+ /*indices=*/indices, /*input_iters*/ spatial_dom_,
+ /*predicate*/ 1, /*check_level*/ arith::IterMapLevel::NoCheck,
&arith_analyzer_);
+ if (result->indices.empty()) {
+ LOG(WARNING) << "[LayoutInference] Failed to analyze indices " << indices
+ << ", error: " << result->errors;
+ return {};
+ }
+ return GetSpatialLayout(result);
+ }
+
+ // Compute the spatial domain map of block
+ void ComputeBlockSpatialDomain() {
+ for (const IterVar& v : block_->iter_vars) {
+ if (v->iter_type == kDataPar) {
+ spatial_dom_.Set(v->var, v->dom);
+ continue;
+ }
+ if (v->iter_type == kCommReduce) continue;
+ LOG(WARNING) << "[LayoutInference] Cannot compute block spatial domain
in presence of "
+ "unknown block iter_type : "
+ << v->iter_type;
+ can_transform_block_ = false;
+ return;
+ }
+ }
+
+ void VisitStmt_(const BlockNode* op) final {
+ // Blocks with nested blocks cannot be handled yet.
+ LOG(WARNING) << "[LayoutInference] Nested blocks are not supported for
layout inference yet";
+ can_transform_block_ = false;
+ }
+ void VisitStmt_(const BufferStoreNode* op) final {
+ StmtExprVisitor::VisitStmt_(op);
+
+ BufferAccessInfo& access_info = buffer_access_info_[op->buffer];
+
+ // Fast path to ignore further analysis if we know that the buffer access
is invalid.
+ if (!access_info.IsValid()) return;
+
+ // Only single write buffer is supported for each block.
+ if (!op->buffer.same_as(block_->writes[0]->buffer)) {
+ access_info.Invalidate();
+ LOG(WARNING) << "[LayoutInference] Exactly one write buffer is supported
for layout "
+ "inference, found two: "
+ << op->buffer << " and " << block_->writes[0]->buffer;
+ can_transform_block_ = false;
+ return;
+ }
+
+ // If the write buffer access cannot be analyzed, no transformation to the
block will be made.
+ auto detected_spatial_layout = DetectBufferAccessIterMap(op->indices);
+ if (detected_spatial_layout.empty()) {
+ access_info.Invalidate();
+ return;
+ }
+
+ // Check if we have access info for this buffer, if present, the two
accesses must be
+ // identical.
+ access_info.Update(detected_spatial_layout);
+ }
+
+ void VisitExpr_(const BufferLoadNode* op) final {
+ Buffer read_buffer = op->buffer;
+ BufferAccessInfo& access_info = buffer_access_info_[op->buffer];
+
+ auto detected_spatial_layout = DetectBufferAccessIterMap(op->indices);
+
+ if (detected_spatial_layout.empty()) {
+ access_info.Invalidate();
+ return;
+ }
+ access_info.Update(detected_spatial_layout);
+ }
+
+ public:
+ bool CanBeTransformed() { return can_transform_block_; }
+ IndexMap GetBlockTransformation() { return block_transformation_; }
+ Map<Buffer, IndexMap> GetReadBufferTransformations() { return
read_buffer_transformations_; }
+
+ private:
+ bool can_transform_block_;
+ IndexMap write_transformation_;
+ Map<tir::Var, Range> spatial_dom_;
+ arith::Analyzer arith_analyzer_;
+
+ Block block_;
+ IndexMap block_transformation_;
+
+ Map<Buffer, IndexMap> read_buffer_transformations_;
+ const Map<Buffer, IndexMap>& buffer_transformation_cache_;
+ std::unordered_map<Buffer, BufferAccessInfo, ObjectPtrHash, ObjectPtrEqual>
buffer_access_info_;
+};
+
+/*!
+ * \brief Analyzes the PrimFunc and user provided output buffer
transformations to propose
+ * transformations of block and buffers within the PrimFunc.
+ * \details It does a best effort analysis to propose transformations which
would preserve
+ * sequential access to buffers (especially output buffers). Since this is
best effort, it is
+ * possible that the PrimFunc is too complex for analysis. In such a case, no
transformations are
+ * proposed.
+ */
+class PrimFuncAnalyzer : public StmtExprVisitor {
+ public:
+ explicit PrimFuncAnalyzer(const PrimFunc& func, Array<IndexMap>
write_transformations) {
+ ICHECK_LE(write_transformations.size(), func->params.size())
+ << "Incompatible PrimFunc and write_transformations";
+
+ size_t first_write_index = func->params.size() -
write_transformations.size();
+ for (size_t i = 0; i < write_transformations.size(); ++i) {
+ auto param = func->params[first_write_index + i];
+ Optional<Buffer> param_buf = func->buffer_map.Get(param);
+ ICHECK(param_buf.defined());
+ ICHECK_EQ(param_buf.value()->shape.size(),
write_transformations[i]->initial_indices.size())
+ << "Mismatch between output buffer shape and index map";
+ buffer_transformation_cache_.Set(param_buf.value(),
write_transformations[i]);
+ }
+ VisitStmt(func->body);
+ }
+ Map<Block, Map<ObjectRef, IndexMap>> GetSuggestedTransforms() {
+ Map<Block, Map<ObjectRef, IndexMap>> result;
+ for (const auto& [block, index_map] : block_transformations_) {
+ Map<ObjectRef, IndexMap> block_transformations;
+ block_transformations.Set(block, index_map);
+ for (const auto& buffer : block_to_buffer_[block]) {
+ block_transformations.Set(buffer,
buffer_transformation_cache_[buffer]);
+ }
+ result.Set(block, block_transformations);
+ }
+ return result;
+ }
+
+ private:
+ void VisitStmt_(const BlockNode* op) final {
+ if (op->name_hint == "root") {
+ // Skip the root block
+ StmtVisitor::VisitStmt_(op);
+ return;
+ }
+
+ Block block = GetRef<Block>(op);
+ // Get block write buffer transformation.
+ if (block->writes.size() != 1) return;
+ auto write_buffer = block->writes[0]->buffer;
+ block_to_buffer_[block].push_back(write_buffer);
+ BlockAnalyzer block_analyzer(block, buffer_transformation_cache_,
+ buffer_transformation_cache_[write_buffer]);
+
+ if (!block_analyzer.CanBeTransformed()) return;
+ // Collect the suggested transformations
+ block_transformations_.Set(block, block_analyzer.GetBlockTransformation());
+
+ for (const auto& [buffer, index_map] :
block_analyzer.GetReadBufferTransformations()) {
+ // BlockAnalyzer makes sure that it does not propose transformation for
a buffer for which a
+ // transformation has already been proposed by other blocks or by
write_transformations which
+ // are input to this analysis.
+ ICHECK_EQ(buffer_transformation_cache_.count(buffer), 0);
+ buffer_transformation_cache_.Set(buffer, index_map);
+ block_to_buffer_[block].push_back(buffer);
+ }
+ }
+
+ private:
+ Map<Buffer, IndexMap> buffer_transformation_cache_;
+ Map<Block, IndexMap> block_transformations_;
+ std::unordered_map<Block, Array<Buffer>, ObjectPtrHash, ObjectPtrEqual>
block_to_buffer_;
+};
+
+Map<tir::Block, Map<ObjectRef, tir::IndexMap>> SuggestLayoutTransforms(
+ const PrimFunc& prim_func, Array<IndexMap> write_buffer_transformations) {
+ // No changes to the PrimFunc are required if no transformations on output
buffers.
+ if (write_buffer_transformations.empty()) return {};
+
+ PrimFuncAnalyzer analyzer(prim_func, write_buffer_transformations);
+ return analyzer.GetSuggestedTransforms();
+}
+
+TVM_REGISTER_GLOBAL(("relax.analysis.suggest_layout_transforms"))
+ .set_body_typed([](PrimFunc fn, Array<tir::IndexMap>
write_buffer_transformations) {
+ return SuggestLayoutTransforms(fn, write_buffer_transformations);
+ });
+
+} // namespace relax
+} // namespace tvm
diff --git a/src/support/array.h b/src/support/array.h
index 218150f9db..0ca57a2410 100644
--- a/src/support/array.h
+++ b/src/support/array.h
@@ -21,6 +21,7 @@
#include <tvm/ir/expr.h>
#include <tvm/runtime/container/array.h>
+#include <list>
#include <vector>
namespace tvm {
@@ -81,11 +82,35 @@ inline std::vector<TDst> AsVector(const Array<TSrc>& vec);
* \brief Convert a std::vector to tvm::runtime::Array
* \tparam TSrc The type of elements in the source vector
* \tparam TDst The type of elements in the result Array
- * \return The result vector
+ * \return The result Array
*/
template <class TSrc, class TDst>
inline Array<TDst> AsArray(const std::vector<TSrc>& vec);
+/*!
+ * \brief Convert a tvm::runtime::Array to std::list
+ * \tparam T The type of elements in the source array
+ * \return The result list
+ */
+template <class T>
+inline std::list<T> AsList(const Array<T>& array) {
+ std::list<T> list;
+ for (const auto& v : array) list.push_back(v);
+ return list;
+}
+
+/*!
+ * \brief Convert a std::list to tvm::runtime::Array
+ * \tparam T The type of elements in the source list
+ * \return The result Array
+ */
+template <class T>
+inline Array<T> AsArray(const std::list<T>& list) {
+ Array<T> array;
+ for (const auto& v : list) array.push_back(v);
+ return array;
+}
+
/*!
* \brief Get the shape tuple as array
* \param shape The shape tuple
diff --git a/tests/python/relax/test_analysis_suggest_layout_transforms.py
b/tests/python/relax/test_analysis_suggest_layout_transforms.py
new file mode 100644
index 0000000000..2850f0ed9f
--- /dev/null
+++ b/tests/python/relax/test_analysis_suggest_layout_transforms.py
@@ -0,0 +1,831 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import pytest
+import tvm.testing
+
+from tvm import relax, tir
+from tvm.script import tir as T
+
+
def apply_transformations(func, suggested_transforms, print_transformation=False):
    """Apply suggested block/buffer layout transformations to ``func``.

    Fixes the misspelled parameter name ``suggested_transfoms`` ->
    ``suggested_transforms``; every call site in this file passes it
    positionally, so the rename is backward-compatible here.

    Parameters
    ----------
    func : tir.PrimFunc
        The PrimFunc to schedule and transform.
    suggested_transforms : dict
        Mapping from block to a per-block dict of {tir.Block or tir.Buffer:
        tir.IndexMap}, as returned by relax.analysis.suggest_layout_transforms.
    print_transformation : bool
        If True, print each transformation before applying it (debug aid).

    Returns
    -------
    tir.PrimFunc
        The transformed "main" PrimFunc from the schedule's module.
    """
    sch = tir.Schedule(func)
    for block, per_block_transformations in suggested_transforms.items():
        blockrv = sch.get_block(block.name_hint)
        for obj, index_map in per_block_transformations.items():
            if isinstance(obj, tir.Block):
                # Block iteration-space transformation.
                block_name = obj.name_hint
                if print_transformation:
                    print("Block transformation: ", block_name, " :: ", index_map)
                sch.transform_block_layout(block_name, index_map)
            else:
                # Otherwise the key must be a buffer: transform its layout.
                assert isinstance(obj, tir.Buffer)
                buffer = obj
                if print_transformation:
                    print("Buffer transformation: ", buffer, " :: ", index_map)
                sch.transform_layout(blockrv, buffer, index_map)
    return sch.mod["main"]
+
+
def test_nested_blocks():
    """Nested blocks are unsupported: the analysis suggests no transforms."""

    @T.prim_func
    def nested_block(
        arg: T.Buffer((32, 64, 224, 224), "float32"),
        relu: T.Buffer((32, 64, 224, 224), "float32"),
    ):
        for i, j in T.grid(32, 64):
            with T.block("outer"):
                v_i, v_j = T.axis.remap("SS", [i, j])
                T.reads(arg[v_i, v_j, 0:224, 0:224])
                T.writes(relu[v_i, v_j, 0:224, 0:224])
                for k, l in T.grid(224, 224):
                    with T.block("inner"):
                        v_k, v_l = T.axis.remap("SS", [k, l])
                        T.reads(arg[v_i, v_j, v_k, v_l])
                        T.writes(relu[v_i, v_j, v_k, v_l])
                        relu[v_i, v_j, v_k, v_l] = T.max(arg[v_i, v_j, v_k, v_l], T.float32(0))

    suggested_transforms = relax.analysis.suggest_layout_transforms(
        func=nested_block, write_buffer_transforms=[lambda n, c, h, w: (n, h, w, c)]
    )
    # no suggestions for nested block.
    assert len(suggested_transforms.items()) == 0
+
+
def test_mismatch_transformations_and_num_params():
    """Passing more write transforms than the PrimFunc has outputs raises TVMError."""

    @T.prim_func
    def elemwise(
        arg: T.Buffer((32, 64, 224, 224), "float32"),
        relu: T.Buffer((32, 64, 224, 224), "float32"),
    ):
        for i0, i1, i2, i3 in T.grid(32, 64, 224, 224):
            with T.block("compute"):
                v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
                T.reads(arg[v_i0, v_i1, v_i2, v_i3])
                T.writes(relu[v_i0, v_i1, v_i2, v_i3])
                relu[v_i0, v_i1, v_i2, v_i3] = T.max(arg[v_i0, v_i1, v_i2, v_i3], T.float32(0))

    # Three transforms for a single-output PrimFunc: expected to be rejected.
    with pytest.raises(tvm.TVMError, match="Incompatible PrimFunc and write_transformations"):
        _ = relax.analysis.suggest_layout_transforms(
            func=elemwise,
            write_buffer_transforms=[
                lambda n, c, h, w: (n, h, w, c),
                lambda n, c, h, w: (n, h, w, c),
                lambda n, c, h, w: (n, h, w, c),
            ],
        )
+
+
def test_empty_write_transformations():
    """An empty write-transform list yields no suggested transforms."""

    @T.prim_func
    def elemwise(
        arg: T.Buffer((32, 64, 224, 224), "float32"),
        relu: T.Buffer((32, 64, 224, 224), "float32"),
    ):
        for i0, i1, i2, i3 in T.grid(32, 64, 224, 224):
            with T.block("compute"):
                v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
                T.reads(arg[v_i0, v_i1, v_i2, v_i3])
                T.writes(relu[v_i0, v_i1, v_i2, v_i3])
                relu[v_i0, v_i1, v_i2, v_i3] = T.max(arg[v_i0, v_i1, v_i2, v_i3], T.float32(0))

    suggested_transforms = relax.analysis.suggest_layout_transforms(
        func=elemwise, write_buffer_transforms=[]
    )
    assert len(suggested_transforms.items()) == 0
+
+
def test_non_bijective_block_transform():
    """A non-bijective output transform (c // 5, c % 5 over 64) yields no suggestions."""

    @T.prim_func
    def before(
        arg: T.Buffer((32, 64), "float32"),
        output: T.Buffer((32, 64), "float32"),
    ):
        for ax0, ax1 in T.grid(32, 64):
            with T.block("compute"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(arg[v_ax0, v_ax1])
                T.writes(output[v_ax0, v_ax1])
                output[v_ax0, v_ax1] = arg[v_ax0, v_ax1]

    suggested_transforms = relax.analysis.suggest_layout_transforms(
        func=before, write_buffer_transforms=[lambda n, c: (n, c // 5, c % 5)]
    )
    assert len(suggested_transforms.items()) == 0
+
+
def test_non_affine_access():
    """A non-affine write access (v_ax0 * v_ax1) yields no suggestions."""

    @T.prim_func
    def before(
        arg: T.Buffer((32, 64), "float32"),
        output: T.Buffer((32 * 64, 10), "float32"),
    ):
        for ax0, ax1, ax2 in T.grid(32, 64, 10):
            with T.block("compute"):
                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
                T.reads(arg[v_ax0, v_ax1])
                T.writes(output[v_ax0 * v_ax1, v_ax2])
                output[v_ax0 * v_ax1, v_ax2] = arg[v_ax0, v_ax1]

    suggested_transforms = relax.analysis.suggest_layout_transforms(
        func=before, write_buffer_transforms=[lambda a, b: (b, a)]
    )
    assert len(suggested_transforms.items()) == 0
+
+
def test_unsupported_write_spatial_layout():
    """A flattened write access (v_ax0 * 4 + v_ax1) yields no suggestions."""

    @T.prim_func
    def before(
        arg: T.Buffer((4, 4), "float32"),
        output: T.Buffer((16), "float32"),
    ):
        for ax0, ax1 in T.grid(4, 4):
            with T.block("flatten"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(arg[v_ax0, v_ax1])
                T.writes(output[v_ax0 * 4 + v_ax1])
                output[v_ax0 * 4 + v_ax1] = arg[v_ax0, v_ax1]

    suggested_transforms = relax.analysis.suggest_layout_transforms(
        func=before, write_buffer_transforms=[lambda a: (a // 4, a % 4)]
    )
    assert len(suggested_transforms.items()) == 0
+
+
def test_unpacked_iter_used_in_read_access():
    """Fusing output dims (a * 8 + b) still transforms reads that use only one iter."""

    @T.prim_func
    def before(
        arg: T.Buffer((8, 4), "float32"),
        output: T.Buffer((4, 8), "float32"),
    ):
        for ax0, ax1, ax2 in T.grid(4, 8, 4):
            with T.block("compute"):
                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
                T.reads(arg[v_ax1, v_ax2])
                T.writes(output[v_ax0, v_ax1])
                output[v_ax0, v_ax1] = arg[v_ax1, v_ax2]

    @T.prim_func
    def expected(
        arg: T.Buffer((8, 4), "float32"),
        output: T.Buffer((32), "float32"),
    ):
        for ax0, ax2 in T.grid(32, 4):
            with T.block("compute"):
                v_ax0, v_ax2 = T.axis.remap("SS", [ax0, ax2])
                T.reads(arg[v_ax0 % 8, v_ax2])
                T.writes(output[v_ax0])
                output[v_ax0] = arg[v_ax0 % 8, v_ax2]

    suggested_transforms = relax.analysis.suggest_layout_transforms(
        func=before, write_buffer_transforms=[lambda a, b: (a * 8 + b)]
    )
    after = apply_transformations(before, suggested_transforms)
    tvm.ir.assert_structural_equal(after, expected)
+
+
def test_invalid_index_map():
    """Rank-mismatched index maps raise TVMError; non-callable entries raise AssertionError."""

    @T.prim_func
    def elemwise(
        arg: T.Buffer((32, 64, 224, 224), "float32"),
        relu: T.Buffer((32, 64, 224, 224), "float32"),
    ):
        for i0, i1, i2, i3 in T.grid(32, 64, 224, 224):
            with T.block("compute"):
                v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
                T.reads(arg[v_i0, v_i1, v_i2, v_i3])
                T.writes(relu[v_i0, v_i1, v_i2, v_i3])
                relu[v_i0, v_i1, v_i2, v_i3] = T.max(arg[v_i0, v_i1, v_i2, v_i3], T.float32(0))

    # 3-arg map against a 4-D output buffer.
    with pytest.raises(tvm.TVMError, match="Mismatch between output buffer shape and index map"):
        _ = relax.analysis.suggest_layout_transforms(
            func=elemwise, write_buffer_transforms=[lambda n, h, w: (n, w, h)]
        )
    # An int is not a valid transformation.
    with pytest.raises(AssertionError):
        _ = relax.analysis.suggest_layout_transforms(func=elemwise, write_buffer_transforms=[2])
+
+
def test_SRSR_block():
    """Interleaved spatial/reduction iters ("SRSR") with a tiling output transform."""

    @T.prim_func
    def before(
        arg: T.Buffer((32, 224, 64, 224), "float32"),
        sum: T.Buffer((32, 64), "float32"),
    ):
        for ax0, k2, ax1, k3 in T.grid(32, 224, 64, 224):
            with T.block("rxplaceholder_red"):
                v_ax0, v_k2, v_ax1, v_k3 = T.axis.remap("SRSR", [ax0, k2, ax1, k3])
                # NOTE(review): the declared read index order differs from the
                # actual access below (arg[v_ax0, v_k2, v_ax1, v_k3]) — confirm intended.
                T.reads(arg[v_ax0, v_ax1, v_k2, v_k3])
                T.writes(sum[v_ax0, v_ax1])
                with T.init():
                    sum[v_ax0, v_ax1] = T.float32(0)
                sum[v_ax0, v_ax1] = sum[v_ax0, v_ax1] + arg[v_ax0, v_k2, v_ax1, v_k3]

    @T.prim_func
    def expected(
        arg: T.Buffer((32, 224, 16, 224, 4), "float32"),
        sum: T.Buffer((32, 16, 4), "float32"),
    ):
        for ax0, ax1, ax2, ax3, ax4 in T.grid(32, 224, 16, 224, 4):
            with T.block("rxplaceholder_red"):
                v0, v1, v2, v3, v4 = T.axis.remap("SRSRS", [ax0, ax1, ax2, ax3, ax4])
                T.reads(arg[v0, v1, v2, v3, v4])
                T.writes(sum[v0, v2, v4])
                with T.init():
                    sum[v0, v2, v4] = T.float32(0)
                sum[v0, v2, v4] = sum[v0, v2, v4] + arg[v0, v1, v2, v3, v4]

    suggested_transforms = relax.analysis.suggest_layout_transforms(
        func=before, write_buffer_transforms=[lambda n, c: (n, c // 4, c % 4)]
    )
    after = apply_transformations(before, suggested_transforms)
    tvm.ir.assert_structural_equal(after, expected)
+
+
def test_op_elemwise_symbolic():
    """Elementwise relu with symbolic (int64) shapes: NCHW -> NHWC propagates to all buffers."""

    @T.prim_func
    def before(arg: T.handle, relu: T.handle):
        N = T.int64()
        C = T.int64()
        H = T.int64()
        W = T.int64()
        Arg = T.match_buffer(arg, (N, C, H, W))
        Relu = T.match_buffer(relu, (N, C, H, W))
        for i0, i1, i2, i3 in T.grid(N, C, H, W):
            with T.block("compute"):
                v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
                T.reads(Arg[v_i0, v_i1, v_i2, v_i3])
                T.writes(Relu[v_i0, v_i1, v_i2, v_i3])
                Relu[v_i0, v_i1, v_i2, v_i3] = T.max(Arg[v_i0, v_i1, v_i2, v_i3], T.float32(0))

    @T.prim_func
    def expected(arg: T.handle, relu: T.handle):
        N = T.int64()
        C = T.int64()
        H = T.int64()
        W = T.int64()
        Arg = T.match_buffer(arg, (N, H, W, C))
        Relu = T.match_buffer(relu, (N, H, W, C))
        # with T.block("root"):
        for ax0, ax1, ax2, ax3 in T.grid(N, H, W, C):
            with T.block("compute"):
                v0, v1, v2, v3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                T.reads(Arg[v0, v1, v2, v3])
                T.writes(Relu[v0, v1, v2, v3])
                Relu[v0, v1, v2, v3] = T.max(Arg[v0, v1, v2, v3], T.float32(0))

    suggested_transforms = relax.analysis.suggest_layout_transforms(
        func=before, write_buffer_transforms=[lambda n, c, h, w: (n, h, w, c)]
    )
    after = apply_transformations(before, suggested_transforms)
    tvm.ir.assert_structural_equal(after, expected)
+
+
def test_op_elemwise():
    """Elementwise relu with static shapes: NCHW -> NHWC propagates to all buffers."""

    @T.prim_func
    def before(
        arg: T.Buffer((32, 64, 224, 224), "float32"),
        relu: T.Buffer((32, 64, 224, 224), "float32"),
    ):
        for i0, i1, i2, i3 in T.grid(32, 64, 224, 224):
            with T.block("compute"):
                v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
                T.reads(arg[v_i0, v_i1, v_i2, v_i3])
                T.writes(relu[v_i0, v_i1, v_i2, v_i3])
                relu[v_i0, v_i1, v_i2, v_i3] = T.max(arg[v_i0, v_i1, v_i2, v_i3], T.float32(0))

    @T.prim_func
    def expected(
        arg: T.Buffer((32, 224, 224, 64), "float32"),
        relu: T.Buffer((32, 224, 224, 64), "float32"),
    ):
        for ax0, ax1, ax2, ax3 in T.grid(32, 224, 224, 64):
            with T.block("compute"):
                v0, v1, v2, v3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                T.reads(arg[v0, v1, v2, v3])
                T.writes(relu[v0, v1, v2, v3])
                relu[v0, v1, v2, v3] = T.max(arg[v0, v1, v2, v3], T.float32(0))

    suggested_transforms = relax.analysis.suggest_layout_transforms(
        func=before, write_buffer_transforms=[lambda n, c, h, w: (n, h, w, c)]
    )
    after = apply_transformations(before, suggested_transforms)
    tvm.ir.assert_structural_equal(after, expected)
+
+
def test_op_pool_nchw_nhwc():
    """Max-pool: NCHW -> NHWC output transform propagates through strided reads."""

    @T.prim_func
    def before(
        arg: T.Buffer((32, 64, 224, 224), "float32"),
        pool_max: T.Buffer((32, 64, 111, 223), "float32"),
    ):
        for ax0, ax1, ax2, ax3, rv0, rv1 in T.grid(32, 64, 111, 223, 2, 2):
            with T.block("pool_max"):
                v_ax0, v_ax1, v_ax2, v_ax3, v_rv0, v_rv1 = T.axis.remap(
                    "SSSSRR", [ax0, ax1, ax2, ax3, rv0, rv1]
                )
                T.reads(
                    arg[
                        v_ax0,
                        v_ax1,
                        v_ax2 * 2 + v_rv0 * 2,
                        v_ax3 + v_rv1,
                    ]
                )
                T.writes(pool_max[v_ax0, v_ax1, v_ax2, v_ax3])
                T.block_attr({"schedule_rule": "meta_schedule.pool_max"})
                with T.init():
                    pool_max[v_ax0, v_ax1, v_ax2, v_ax3] = T.float32(-3.4028234663852886e38)
                pool_max[v_ax0, v_ax1, v_ax2, v_ax3] = T.max(
                    pool_max[v_ax0, v_ax1, v_ax2, v_ax3],
                    arg[
                        v_ax0,
                        v_ax1,
                        v_ax2 * 2 + v_rv0 * 2,
                        v_ax3 + v_rv1,
                    ],
                )

    @T.prim_func
    def expected(
        arg: T.Buffer((32, 224, 224, 64), "float32"),
        pool_max: T.Buffer((32, 111, 223, 64), "float32"),
    ):
        # with T.block("root"):
        for ax0, ax1, ax2, ax3, ax4, ax5 in T.grid(32, 111, 223, 64, 2, 2):
            with T.block("pool_max"):
                v0, v1, v2, v3, v4, v5 = T.axis.remap("SSSSRR", [ax0, ax1, ax2, ax3, ax4, ax5])
                T.reads(arg[v0, v1 * 2 + v4 * 2, v2 + v5, v3])
                T.writes(pool_max[v0, v1, v2, v3])
                T.block_attr({"schedule_rule": "meta_schedule.pool_max"})
                with T.init():
                    pool_max[v0, v1, v2, v3] = T.float32(-3.4028234663852886e38)
                pool_max[v0, v1, v2, v3] = T.max(
                    pool_max[v0, v1, v2, v3],
                    arg[v0, v1 * 2 + v4 * 2, v2 + v5, v3],
                )

    suggested_transforms = relax.analysis.suggest_layout_transforms(
        func=before,
        write_buffer_transforms=[lambda n, c, h, w: (n, h, w, c)],
    )
    after = apply_transformations(before, suggested_transforms)
    tvm.ir.assert_structural_equal(after, expected)
+
+
def test_op_pool_nchw16c_nhwc():
    """Max-pool: NCHW16c -> NHWC output transform (fusing C*16 + c) propagates through."""

    @T.prim_func
    def before(
        arg: T.Buffer(
            (32, 4, 224, 224, 16),
            "float32",
        ),
        pool_max: T.Buffer(
            (32, 4, 110, 220, 16),
            "float32",
        ),
    ):
        for ax0, ax1, ax2, ax3, ax4, rv0, rv1 in T.grid(32, 4, 110, 220, 16, 5, 5):
            with T.block("pool_max"):
                v_ax0, v_ax1, v_ax2, v_ax3, v_ax4, v_rv0, v_rv1 = T.axis.remap(
                    "SSSSSRR", [ax0, ax1, ax2, ax3, ax4, rv0, rv1]
                )
                T.reads(arg[v_ax0, v_ax1, v_ax2 * 2 + v_rv0, v_ax3 + v_rv1, v_ax4])
                T.writes(pool_max[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4])
                T.block_attr({"schedule_rule": "meta_schedule.pool_max"})
                with T.init():
                    pool_max[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] = T.float32(-3.4028234663852886e38)
                pool_max[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] = T.max(
                    pool_max[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4],
                    arg[v_ax0, v_ax1, v_ax2 * 2 + v_rv0, v_ax3 + v_rv1, v_ax4],
                )

    @T.prim_func
    def expected(
        arg: T.Buffer((32, 224, 224, 64), "float32"),
        pool_max: T.Buffer((32, 110, 220, 64), "float32"),
    ):
        for ax0, ax1, ax2, ax3, ax4, ax5 in T.grid(32, 110, 220, 64, 5, 5):
            with T.block("pool_max"):
                v0, v1, v2, v3, v4, v5 = T.axis.remap("SSSSRR", [ax0, ax1, ax2, ax3, ax4, ax5])
                T.reads(arg[v0, v1 * 2 + v4, v2 + v5, v3])
                T.writes(pool_max[v0, v1, v2, v3])
                T.block_attr({"schedule_rule": "meta_schedule.pool_max"})
                with T.init():
                    pool_max[v0, v1, v2, v3] = T.float32(-3.4028234663852886e38)
                pool_max[v0, v1, v2, v3] = T.max(
                    pool_max[v0, v1, v2, v3],
                    arg[v0, v1 * 2 + v4, v2 + v5, v3],
                )

    suggested_transforms = relax.analysis.suggest_layout_transforms(
        func=before,
        write_buffer_transforms=[lambda n, C, h, w, c: (n, h, w, C * 16 + c)],
    )
    after = apply_transformations(before, suggested_transforms)
    tvm.ir.assert_structural_equal(after, expected)
+
+
def test_op_reduce():
    """Reduction over H, W: tiling the output channel (c // 16, c % 16) propagates."""

    @T.prim_func
    def before(
        arg: T.Buffer((32, 64, 224, 224), "float32"),
        sum: T.Buffer((32, 64), "float32"),
    ):
        for ax0, ax1, k2, k3 in T.grid(32, 64, 224, 224):
            with T.block("rxplaceholder_red"):
                v_ax0, v_ax1, v_k2, v_k3 = T.axis.remap("SSRR", [ax0, ax1, k2, k3])
                T.reads(arg[v_ax0, v_ax1, v_k2, v_k3])
                T.writes(sum[v_ax0, v_ax1])
                with T.init():
                    sum[v_ax0, v_ax1] = T.float32(0)
                sum[v_ax0, v_ax1] = sum[v_ax0, v_ax1] + arg[v_ax0, v_ax1, v_k2, v_k3]

    @T.prim_func
    def expected(
        arg: T.Buffer((32, 4, 224, 224, 16), "float32"),
        sum: T.Buffer((32, 4, 16), "float32"),
    ):
        for ax0, ax1, ax2, ax3, ax4 in T.grid(32, 4, 224, 224, 16):
            with T.block("rxplaceholder_red"):
                v0, v1, v2, v3, v4 = T.axis.remap("SSRRS", [ax0, ax1, ax2, ax3, ax4])
                T.reads(arg[v0, v1, v2, v3, v4])
                T.writes(sum[v0, v1, v4])
                with T.init():
                    sum[v0, v1, v4] = T.float32(0)
                sum[v0, v1, v4] = sum[v0, v1, v4] + arg[v0, v1, v2, v3, v4]

    suggested_transforms = relax.analysis.suggest_layout_transforms(
        func=before, write_buffer_transforms=[lambda n, c: (n, c // 16, c % 16)]
    )
    after = apply_transformations(before, suggested_transforms)
    tvm.ir.assert_structural_equal(after, expected)
+
+
def test_op_upsampling():
    """Resize: NCHW -> NHWC transforms the output; the read buffer keeps its layout."""

    # relay materializes the layout if H, W or D dimensions are moved or tiled.
    @T.prim_func
    def before(
        arg: T.Buffer((32, 64, 224, 224), "float32"),
        resize: T.Buffer((32, 64, 202, 246), "float32"),
    ):
        for i0, i1, i2, i3 in T.grid(32, 64, 202, 246):
            with T.block("resize"):
                v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
                T.reads(arg[v_i0, v_i1, 0:224, 0:224])
                T.writes(resize[v_i0, v_i1, v_i2, v_i3])
                resize[v_i0, v_i1, v_i2, v_i3] = arg[
                    v_i0,
                    v_i1,
                    T.max(
                        T.min(
                            T.Cast(
                                "int64",
                                T.floor(
                                    T.float32(1.1089109182357788) * T.Cast("float32", v_i2)
                                    + T.float32(1.0000000000000001e-05)
                                ),
                            ),
                            223,
                        ),
                        0,
                    ),
                    T.max(
                        T.min(
                            T.Cast(
                                "int64",
                                T.floor(
                                    T.float32(0.91056913137435913) * T.Cast("float32", v_i3)
                                    + T.float32(1.0000000000000001e-05)
                                ),
                            ),
                            223,
                        ),
                        0,
                    ),
                ]

    @T.prim_func
    def expected(
        arg: T.Buffer((32, 64, 224, 224), "float32"),
        resize: T.Buffer((32, 202, 246, 64), "float32"),
    ):
        # with T.block("root"):
        for ax0, ax1, ax2, ax3 in T.grid(32, 202, 246, 64):
            with T.block("resize"):
                v0, v1, v2, v3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                T.reads(arg[v0, v3, 0:224, 0:224])
                T.writes(resize[v0, v1, v2, v3])
                resize[v0, v1, v2, v3] = arg[
                    v0,
                    v3,
                    T.max(
                        T.min(
                            T.Cast(
                                "int64",
                                T.floor(
                                    T.float32(1.1089109182357788) * T.Cast("float32", v1)
                                    + T.float32(1.0000000000000001e-05)
                                ),
                            ),
                            T.int64(223),
                        ),
                        T.int64(0),
                    ),
                    T.max(
                        T.min(
                            T.Cast(
                                "int64",
                                T.floor(
                                    T.float32(0.91056913137435913) * T.Cast("float32", v2)
                                    + T.float32(1.0000000000000001e-05)
                                ),
                            ),
                            T.int64(223),
                        ),
                        T.int64(0),
                    ),
                ]

    suggested_transforms = relax.analysis.suggest_layout_transforms(
        func=before, write_buffer_transforms=[lambda n, c, h, w: (n, h, w, c)]
    )
    after = apply_transformations(before, suggested_transforms)
    tvm.ir.assert_structural_equal(after, expected)
+
+
def test_op_strided_slice():
    """Strided slice: NHWC + channel tiling (c // 4, c % 4) propagates to strided reads."""

    @T.prim_func
    def before(
        arg: T.Buffer((32, 64, 224, 224), "float32"),
        T_strided_slice_with_axes: T.Buffer((32, 64, 10, 8), "float32"),
    ):
        for ax0, ax1, ax2, ax3 in T.grid(32, 64, 10, 8):
            with T.block("T_strided_slice_with_axes"):
                v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                T.reads(
                    arg[
                        v_ax0,
                        v_ax1,
                        v_ax2 * 5 + 2,
                        v_ax3 * 7 + 4,
                    ]
                )
                T.writes(T_strided_slice_with_axes[v_ax0, v_ax1, v_ax2, v_ax3])
                T_strided_slice_with_axes[v_ax0, v_ax1, v_ax2, v_ax3] = arg[
                    v_ax0,
                    v_ax1,
                    v_ax2 * 5 + 2,
                    v_ax3 * 7 + 4,
                ]

    @T.prim_func
    def expected(
        arg: T.Buffer((32, 224, 224, 16, 4), "float32"),
        T_strided_slice_with_axes: T.Buffer((32, 10, 8, 16, 4), "float32"),
    ):
        # with T.block("root"):
        for ax0, ax1, ax2, ax3, ax4 in T.grid(32, 10, 8, 16, 4):
            with T.block("T_strided_slice_with_axes"):
                v0, v1, v2, v3, v4 = T.axis.remap("SSSSS", [ax0, ax1, ax2, ax3, ax4])
                T.reads(arg[v0, v1 * 5 + 2, v2 * 7 + 4, v3, v4])
                T.writes(T_strided_slice_with_axes[v0, v1, v2, v3, v4])
                T_strided_slice_with_axes[v0, v1, v2, v3, v4] = arg[
                    v0, v1 * 5 + 2, v2 * 7 + 4, v3, v4
                ]

    suggested_transforms = relax.analysis.suggest_layout_transforms(
        func=before, write_buffer_transforms=[lambda n, c, h, w: (n, h, w, c // 4, c % 4)]
    )
    after = apply_transformations(before, suggested_transforms)
    tvm.ir.assert_structural_equal(after, expected)
+
+
def test_op_binary_broadcast():
    """Broadcast add: the transform also applies to the lower-rank read buffer."""

    @T.prim_func
    def before(
        arg0: T.Buffer((32, 64, 224, 224), "float32"),
        arg1: T.Buffer((64, 224, 224), "float32"),
        T_add: T.Buffer((32, 64, 224, 224), "float32"),
    ):
        T.func_attr({"tir.noalias": True})
        # with T.block("root"):
        for ax0, ax1, ax2, ax3 in T.grid(32, 64, 224, 224):
            with T.block("T_add"):
                v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                T.reads(
                    arg0[v_ax0, v_ax1, v_ax2, v_ax3],
                    arg1[v_ax1, v_ax2, v_ax3],
                )
                T.writes(T_add[v_ax0, v_ax1, v_ax2, v_ax3])
                T_add[v_ax0, v_ax1, v_ax2, v_ax3] = (
                    arg0[v_ax0, v_ax1, v_ax2, v_ax3] + arg1[v_ax1, v_ax2, v_ax3]
                )

    @T.prim_func
    def expected(
        arg0: T.Buffer((32, 224, 224, 16, 4), "float32"),
        arg1: T.Buffer((224, 224, 16, 4), "float32"),
        T_add: T.Buffer((32, 224, 224, 16, 4), "float32"),
    ):
        T.func_attr({"tir.noalias": True})
        # with T.block("root"):
        for ax0, ax1, ax2, ax3, ax4 in T.grid(32, 224, 224, 16, 4):
            with T.block("T_add"):
                v0, v1, v2, v3, v4 = T.axis.remap("SSSSS", [ax0, ax1, ax2, ax3, ax4])
                T.reads(arg0[v0, v1, v2, v3, v4], arg1[v1, v2, v3, v4])
                T.writes(T_add[v0, v1, v2, v3, v4])
                T_add[v0, v1, v2, v3, v4] = arg0[v0, v1, v2, v3, v4] + arg1[v1, v2, v3, v4]

    suggested_transforms = relax.analysis.suggest_layout_transforms(
        func=before, write_buffer_transforms=[lambda n, c, h, w: (n, h, w, c // 4, c % 4)]
    )
    after = apply_transformations(before, suggested_transforms)
    tvm.ir.assert_structural_equal(after, expected)
+
+
def test_op_transpose():
    """Transpose: the output transform composes with the existing permuted read."""

    @T.prim_func
    def before(
        arg: T.Buffer((32, 64, 224, 224), "float32"),
        T_transpose: T.Buffer((32, 224, 224, 64), "float32"),
    ):
        for ax0, ax1, ax2, ax3 in T.grid(32, 224, 224, 64):
            with T.block("T_transpose"):
                v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                T.reads(arg[v_ax0, v_ax3, v_ax1, v_ax2])
                T.writes(T_transpose[v_ax0, v_ax1, v_ax2, v_ax3])
                T_transpose[v_ax0, v_ax1, v_ax2, v_ax3] = arg[v_ax0, v_ax3, v_ax1, v_ax2]

    @T.prim_func
    def expected(
        arg: T.Buffer((32, 64, 224, 224), "float32"),
        T_transpose: T.Buffer((32, 224, 64, 224), "float32"),
    ):
        for ax0, ax1, ax2, ax3 in T.grid(32, 224, 64, 224):
            with T.block("T_transpose"):
                v0, v1, v2, v3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                T.reads(arg[v0, v2, v3, v1])
                T.writes(T_transpose[v0, v1, v2, v3])
                T_transpose[v0, v1, v2, v3] = arg[v0, v2, v3, v1]

    suggested_transforms = relax.analysis.suggest_layout_transforms(
        func=before, write_buffer_transforms=[lambda n, c, h, w: (n, h, w, c)]
    )
    after = apply_transformations(before, suggested_transforms)
    tvm.ir.assert_structural_equal(after, expected)
+
+
def test_op_pad():
    """Pad: the transform propagates through the if_then_else padding predicate."""

    @T.prim_func
    def before(
        arg: T.Buffer((32, 64, 224, 224), "float32"),
        PadInput: T.Buffer((32, 64, 230, 230), "float32"),
    ):
        for i0, i1, i2, i3 in T.grid(32, 64, 230, 230):
            with T.block("PadInput"):
                v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
                T.reads(arg[v_i0, v_i1, v_i2 - 2, v_i3 - 2])
                T.writes(PadInput[v_i0, v_i1, v_i2, v_i3])
                PadInput[v_i0, v_i1, v_i2, v_i3] = T.if_then_else(
                    2 <= v_i2 and v_i2 < 226 and 2 <= v_i3 and v_i3 < 226,
                    arg[v_i0, v_i1, v_i2 - 2, v_i3 - 2],
                    T.float32(2),
                )

    @T.prim_func
    def expected(
        arg: T.Buffer((32, 224, 224, 16, 4), "float32"),
        PadInput: T.Buffer((32, 230, 230, 16, 4), "float32"),
    ):
        for ax0, ax1, ax2, ax3, ax4 in T.grid(32, 230, 230, 16, 4):
            with T.block("PadInput"):
                v0, v1, v2, v3, v4 = T.axis.remap("SSSSS", [ax0, ax1, ax2, ax3, ax4])
                T.reads(arg[v0, v1 - 2, v2 - 2, v3, v4])
                T.writes(PadInput[v0, v1, v2, v3, v4])
                PadInput[v0, v1, v2, v3, v4] = T.if_then_else(
                    2 <= v1 and v1 < 226 and 2 <= v2 and v2 < 226,
                    arg[v0, v1 - 2, v2 - 2, v3, v4],
                    T.float32(2),
                )

    suggested_transforms = relax.analysis.suggest_layout_transforms(
        func=before, write_buffer_transforms=[lambda n, c, h, w: (n, h, w, c // 4, c % 4)]
    )
    after = apply_transformations(before, suggested_transforms)
    tvm.ir.assert_structural_equal(after, expected)
+
+
def test_op_split():
    """Split (two output buffers, two blocks): one transform per output buffer."""

    @T.prim_func
    def before(
        arg: T.Buffer((32, 64, 224, 224), "float32"),
        split0: T.Buffer((32, 32, 224, 224), "float32"),
        split1: T.Buffer((32, 32, 224, 224), "float32"),
    ):
        for ax0, ax1, ax2, ax3 in T.grid(32, 32, 224, 224):
            with T.block("T_split_sections"):
                v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                T.reads(arg[v_ax0, v_ax1, v_ax2, v_ax3])
                T.writes(split0[v_ax0, v_ax1, v_ax2, v_ax3])
                split0[v_ax0, v_ax1, v_ax2, v_ax3] = arg[v_ax0, v_ax1, v_ax2, v_ax3]
        for ax0, ax1, ax2, ax3 in T.grid(32, 32, 224, 224):
            with T.block("T_split_sections_1"):
                v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                T.reads(arg[v_ax0, v_ax1 + 32, v_ax2, v_ax3])
                T.writes(split1[v_ax0, v_ax1, v_ax2, v_ax3])
                split1[v_ax0, v_ax1, v_ax2, v_ax3] = arg[v_ax0, v_ax1 + 32, v_ax2, v_ax3]

    @T.prim_func
    def expected(
        arg: T.Buffer((32, 224, 224, 64), "float32"),
        split0: T.Buffer((32, 224, 224, 32), "float32"),
        split1: T.Buffer((32, 224, 224, 32), "float32"),
    ):
        for ax0, ax1, ax2, ax3 in T.grid(32, 224, 224, 32):
            with T.block("T_split_sections"):
                v0, v1, v2, v3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                T.reads(arg[v0, v1, v2, v3])
                T.writes(split0[v0, v1, v2, v3])
                split0[v0, v1, v2, v3] = arg[v0, v1, v2, v3]
        for ax0, ax1, ax2, ax3 in T.grid(32, 224, 224, 32):
            with T.block("T_split_sections_1"):
                v0, v1, v2, v3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                T.reads(arg[v0, v1, v2, v3 + 32])
                T.writes(split1[v0, v1, v2, v3])
                split1[v0, v1, v2, v3] = arg[v0, v1, v2, v3 + 32]

    suggested_transforms = relax.analysis.suggest_layout_transforms(
        func=before,
        write_buffer_transforms=[lambda n, c, h, w: (n, h, w, c), lambda n, c, h, w: (n, h, w, c)],
    )
    after = apply_transformations(before, suggested_transforms)
    tvm.ir.assert_structural_equal(after, expected)
+
+
def test_op_split_tiling_split_dim():
    """Split with the split dimension also tiled (c // 4, c % 4) in both outputs."""

    @T.prim_func
    def before(
        arg: T.Buffer((32, 64, 224, 224), "float32"),
        split0: T.Buffer((32, 32, 224, 224), "float32"),
        split1: T.Buffer((32, 32, 224, 224), "float32"),
    ):
        for ax0, ax1, ax2, ax3 in T.grid(32, 32, 224, 224):
            with T.block("T_split_sections"):
                v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                T.reads(arg[v_ax0, v_ax1, v_ax2, v_ax3])
                T.writes(split0[v_ax0, v_ax1, v_ax2, v_ax3])
                split0[v_ax0, v_ax1, v_ax2, v_ax3] = arg[v_ax0, v_ax1, v_ax2, v_ax3]
        for ax0, ax1, ax2, ax3 in T.grid(32, 32, 224, 224):
            with T.block("T_split_sections_1"):
                v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                T.reads(arg[v_ax0, v_ax1 + 32, v_ax2, v_ax3])
                T.writes(split1[v_ax0, v_ax1, v_ax2, v_ax3])
                split1[v_ax0, v_ax1, v_ax2, v_ax3] = arg[v_ax0, v_ax1 + 32, v_ax2, v_ax3]

    @T.prim_func
    def expected(
        arg: T.Buffer((32, 224, 224, 16, 4), "float32"),
        split0: T.Buffer((32, 224, 224, 8, 4), "float32"),
        split1: T.Buffer((32, 224, 224, 8, 4), "float32"),
    ):
        # with T.block("root"):
        for ax0, ax1, ax2, ax3, ax4 in T.grid(32, 224, 224, 8, 4):
            with T.block("T_split_sections"):
                v0, v1, v2, v3, v4 = T.axis.remap("SSSSS", [ax0, ax1, ax2, ax3, ax4])
                T.reads(arg[v0, v1, v2, v3, v4])
                T.writes(split0[v0, v1, v2, v3, v4])
                split0[v0, v1, v2, v3, v4] = arg[v0, v1, v2, v3, v4]
        for ax0, ax1, ax2, ax3, ax4 in T.grid(32, 224, 224, 8, 4):
            with T.block("T_split_sections_1"):
                v0, v1, v2, v3, v4 = T.axis.remap("SSSSS", [ax0, ax1, ax2, ax3, ax4])
                T.reads(arg[v0, v1, v2, v3 + 8, v4])
                T.writes(split1[v0, v1, v2, v3, v4])
                split1[v0, v1, v2, v3, v4] = arg[v0, v1, v2, v3 + 8, v4]

    suggested_transforms = relax.analysis.suggest_layout_transforms(
        func=before,
        write_buffer_transforms=[
            lambda n, c, h, w: (n, h, w, c // 4, c % 4),
            lambda n, c, h, w: (n, h, w, c // 4, c % 4),
        ],
    )
    after = apply_transformations(before, suggested_transforms)
    tvm.ir.assert_structural_equal(after, expected)
+
+
# Entry point for running this file's tests directly.
if __name__ == "__main__":
    tvm.testing.main()