This is an automated email from the ASF dual-hosted git repository.
syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 5e7438f [TIR][Schedule] Blockize and Tensorize (#9871)
5e7438f is described below
commit 5e7438feaabf5b843fdb0fac9d93934ffb313c5a
Author: Wuwei Lin <[email protected]>
AuthorDate: Wed Jan 26 03:00:07 2022 -0500
[TIR][Schedule] Blockize and Tensorize (#9871)
* WIP
* WIP
* WIP
* test cases
* add examples
* lint
* Amend co-authors information
Co-authored-by: Siyuan Feng <[email protected]>
Co-authored-by: Bohan Hou
<[email protected]>
Co-authored-by: Hongyi Jin <[email protected]>
Co-authored-by: Ruihang Lai <[email protected]>
Co-authored-by: Junru Shao <[email protected]>
Co-authored-by: Xiyou Zhou <[email protected]>
* WIP
* address comments and changed tensorized comparator
* update
* nit
* fix example
* lint
* lint
* lint
* remove unused
* trigger ci
* clang-format
* fix
* rebase
Co-authored-by: Siyuan Feng <[email protected]>
Co-authored-by: Bohan Hou
<[email protected]>
Co-authored-by: Hongyi Jin <[email protected]>
Co-authored-by: Ruihang Lai <[email protected]>
Co-authored-by: Junru Shao <[email protected]>
Co-authored-by: Xiyou Zhou <[email protected]>
---
include/tvm/arith/iter_affine_map.h | 7 +
include/tvm/tir/function.h | 52 ++
include/tvm/tir/schedule/schedule.h | 19 +
python/tvm/tir/__init__.py | 2 +-
python/tvm/tir/function.py | 48 ++
python/tvm/tir/schedule/schedule.py | 229 +++++++
src/arith/int_set.cc | 2 +-
src/tir/ir/function.cc | 53 ++
src/tir/schedule/concrete_schedule.cc | 23 +
src/tir/schedule/concrete_schedule.h | 3 +
src/tir/schedule/ir_comparator.cc | 363 +++++++++++
src/tir/schedule/ir_comparator.h | 116 ++++
src/tir/schedule/primitive.h | 18 +
src/tir/schedule/primitive/blockize_tensorize.cc | 698 +++++++++++++++++++++
src/tir/schedule/schedule.cc | 14 +
src/tir/schedule/state.cc | 4 +-
src/tir/schedule/traced_schedule.cc | 31 +
src/tir/schedule/traced_schedule.h | 3 +
.../python/unittest/test_tir_schedule_blockize.py | 210 +++++++
.../python/unittest/test_tir_schedule_tensorize.py | 431 +++++++++++++
20 files changed, 2321 insertions(+), 5 deletions(-)
diff --git a/include/tvm/arith/iter_affine_map.h
b/include/tvm/arith/iter_affine_map.h
index 22b4cd5..eb69c18 100644
--- a/include/tvm/arith/iter_affine_map.h
+++ b/include/tvm/arith/iter_affine_map.h
@@ -350,6 +350,13 @@ Array<Array<IterMark>> SubspaceDivide(const
Array<PrimExpr>& bindings,
bool require_bijective, arith::Analyzer*
analyzer,
DiagnosticContext diag_ctx);
+/*!
+ * \brief Given an IterMapExpr, transform it to normal PrimExpr.
+ * \param expr The input IterMapExpr.
+ * \return The corresponding normal PrimExpr.
+ */
+PrimExpr NormalizeIterMapToExpr(const IterMapExpr& expr);
+
} // namespace arith
} // namespace tvm
#endif // TVM_ARITH_ITER_AFFINE_MAP_H_
diff --git a/include/tvm/tir/function.h b/include/tvm/tir/function.h
index e482a18..1ab911b 100644
--- a/include/tvm/tir/function.h
+++ b/include/tvm/tir/function.h
@@ -188,6 +188,58 @@ class LinkedParam : public ObjectRef {
};
/*!
+ * \brief Tensor intrinsics for tensorization
+ */
+class TensorIntrinNode : public Object {
+ public:
+ /*! \brief The function to describe the computation. */
+ PrimFunc desc;
+ /*! \brief The function of the implementation for the execution. */
+ PrimFunc impl;
+
+ void VisitAttrs(AttrVisitor* v) {
+ v->Visit("desc", &desc);
+ v->Visit("impl", &impl);
+ }
+
+ static constexpr const char* _type_key = "tir.TensorIntrin";
+ TVM_DECLARE_FINAL_OBJECT_INFO(TensorIntrinNode, Object);
+};
+
+/*!
+ * \brief Managed reference to TensorIntrinNode.
+ */
+class TensorIntrin : public ObjectRef {
+ public:
+ /*!
+ * \brief Constructor
+ * \param desc The function to describe the computation.
+ * \param impl The function of the implementation for the execution.
+ */
+ TVM_DLL explicit TensorIntrin(PrimFunc desc, PrimFunc impl);
+
+ /*!
+ * \brief Create and register a TensorIntrin. After registration, the
TensorIntrin can be looked
+ * up with its name.
+ * \param name The name of the TensorIntrin to register
+ * \param intrin The TensorIntrin to register.
+ * \throws This method throws an exception if the TensorIntrin with the
specified name already
+ * exists.
+ */
+ TVM_DLL static void Register(String name, TensorIntrin intrin);
+
+ /*!
+ * \brief Look up TensorIntrin by name. Raises an exception if not found.
+ * \param name The name of the TensorIntrin.
+ * \return The TensorIntrin with the specified name.
+ * \throws This method throws an exception if the TensorIntrin does not
exist.
+ */
+ TVM_DLL static TensorIntrin Get(String name);
+
+ TVM_DEFINE_OBJECT_REF_METHODS(TensorIntrin, ObjectRef, TensorIntrinNode)
+};
+
+/*!
* \brief Specialize parameters of PrimFunc.
* \param func The PrimFunc to be specialized.
* \param param_map The mapping from function params to the instance.
diff --git a/include/tvm/tir/schedule/schedule.h
b/include/tvm/tir/schedule/schedule.h
index 43f2379..be06b44 100644
--- a/include/tvm/tir/schedule/schedule.h
+++ b/include/tvm/tir/schedule/schedule.h
@@ -473,6 +473,25 @@ class ScheduleNode : public runtime::Object {
*/
virtual void SetScope(const BlockRV& block_rv, int buffer_index, const
String& storage_scope) = 0;
/******** Schedule: Blockize & Tensorize ********/
+ /*!
+ * \brief Convert the subtree rooted at a specific loop into a block.
+ * \param loop_rv the root of the subtree
+ * \return the new block
+ */
+ virtual BlockRV Blockize(const LoopRV& loop_rv) = 0;
+ /*!
+ * \brief Tensorize the computation enclosed by loop with the tensor intrin.
+ * \param loop_rv The loop to be tensorized
+ * \param intrin Name of the tensor intrinsic
+ */
+ virtual void Tensorize(const LoopRV& loop_rv, const String& intrin) = 0;
+ /*!
+ * \brief Tensorize the computation enclosed by loop with the tensor intrin.
+ * \param block_rv The block to be tensorized
+ * \param intrin Name of the tensor intrinsic
+ */
+ virtual void Tensorize(const BlockRV& block_rv, const String& intrin) = 0;
+
/******** Schedule: Annotation ********/
/*!
* \brief Annotate a loop with a key value pair
diff --git a/python/tvm/tir/__init__.py b/python/tvm/tir/__init__.py
index 07ceb29..5854b93 100644
--- a/python/tvm/tir/__init__.py
+++ b/python/tvm/tir/__init__.py
@@ -33,7 +33,7 @@ from .stmt import ProducerRealize, SeqStmt
from .stmt import IfThenElse, Evaluate, Prefetch, stmt_seq, stmt_list
from .stmt import BufferRegion, MatchBufferRegion, Block, BlockRealize
-from .function import PrimFunc
+from .function import PrimFunc, TensorIntrin
from .op import call_packed, call_intrin, call_pure_extern, call_extern
from .op import call_llvm_intrin, call_llvm_pure_intrin, ret, all, any,
min_value, max_value, trace
diff --git a/python/tvm/tir/function.py b/python/tvm/tir/function.py
index ecbcd83..bcebab9 100644
--- a/python/tvm/tir/function.py
+++ b/python/tvm/tir/function.py
@@ -162,3 +162,51 @@ class PrimFunc(BaseFunc):
return tvm._ffi.get_global_func("script.AsTVMScript")(
self, tir_prefix, show_meta
) # type: ignore
+
+
+@tvm._ffi.register_object("tir.TensorIntrin")
+class TensorIntrin(Object):
+ """A tensor intrinsic.
+
+ Parameters
+ ----------
+ desc : PrimFunc
+ The function to describe the computation.
+
+ impl : PrimFunc
+ The function of the implementation for the execution.
+ """
+
+ def __init__(self, desc, impl):
+ self.__init_handle_by_constructor__(_ffi_api.TensorIntrin, desc, impl)
+
+ @staticmethod
+ def register(name: str, desc: PrimFunc, impl: PrimFunc):
+ """Register a tensor intrinsic with its name.
+
+ Parameters
+ ----------
+ name : str
+ The name of the TensorIntrin to register.
+ desc : PrimFunc
+ The function to describe the computation.
+ impl : PrimFunc
+ The function of the implementation for the execution.
+ """
+ return _ffi_api.TensorIntrinRegister(name, TensorIntrin(desc, impl))
# type: ignore
+
+ @staticmethod
+ def get(name: str):
+ """Look up a tensor intrinsic by its name.
+
+ Parameters
+ ----------
+ name : str
+ The name of the TensorIntrin to look up.
+
+ Returns
+ -------
+ result : TensorIntrin
+ The TensorIntrin with the specified name.
+ """
+ return _ffi_api.TensorIntrinGet(name) # type: ignore
diff --git a/python/tvm/tir/schedule/schedule.py
b/python/tvm/tir/schedule/schedule.py
index 7d352f1..96fa21f 100644
--- a/python/tvm/tir/schedule/schedule.py
+++ b/python/tvm/tir/schedule/schedule.py
@@ -1759,6 +1759,235 @@ class Schedule(Object):
########## Schedule: Blockize & Tensorize ##########
+ @type_checked
+ def blockize(self, loop: LoopRV) -> BlockRV:
+ """Convert the subtree rooted at a specific loop into a block.
+
+ Parameters
+ ----------
+ loop : LoopRV
+ The root of the subtree.
+
+ Returns
+ -------
+ result : BlockRV
+ The new block.
+
+ Examples
+ --------
+
+ Before blockize, in TensorIR, the IR is:
+
+ .. code-block:: python
+
+ @T.prim_func
+ def before_blockize(
+ A: T.Buffer[(128, 128), "float32"],
+ B: T.Buffer[(128, 128), "float32"]
+ ) -> None:
+ for i_0, j_0, i_1, j_1 in T.grid(8, 8, 16, 16):
+ with T.block("B"):
+ vi = T.axis.spatial(128, i_0 * 16 + i_1)
+ vj = T.axis.spatial(128, j_0 * 16 + j_1)
+ T.reads(A[vi, vj])
+ T.writes(B[vi, vj])
+ B[vi, vj] = A[vi, vj] * T.float32(2)
+
+ Create the schedule and do blockize:
+
+ .. code-block:: python
+
+ sch = tir.Schedule(before_blockize)
+ B = sch.get_block("B")
+ _, _, i1, _ = sch.get_loops(B)
+ sch.blockize(i1)
+ print(sch.mod["main"].script())
+
+ After applying blockize, the IR becomes:
+
+ .. code-block:: python
+
+ @T.prim_func
+ def after_blockize(
+ A: T.Buffer[(128, 128), "float32"],
+ B: T.Buffer[(128, 128), "float32"]
+ )-> None:
+ for i_0, j_0 in T.grid(8, 8):
+ with T.block("B_o"):
+ vio, vjo = T.axis.remap("SS", [i_0, j_0])
+ T.reads(A[vio * 16 : vio * 16 + 16, vjo * 16 : vjo *
16 + 16])
+ T.writes(B[vio * 16 : vio * 16 + 16, vjo * 16 : vjo *
16 + 16])
+ for i_1, j_1 in T.grid(16, 16):
+ with T.block("B"):
+ vi, vj = T.axis.remap("SS", [i_1, j_1])
+ T.reads(A[vio * 16 + vi, vjo * 16 + vj])
+ T.writes(B[vio * 16 + vi, vjo * 16 + vj])
+ B[vio * 16 + vi, vjo * 16 + vj] = A[vio * 16 +
vi, vjo * 16 + vj] \
+ *
T.float32(2)
+
+ Note
+ ----
+ blockize requires there is exactly one block under the given loop and
the bindings of the
+ block are divisible by the subspace represented by the loops starting
at the given loop.
+ """
+
+ return _ffi_api.ScheduleBlockize(self, loop) # type: ignore # pylint:
disable=no-member
+
+ @type_checked
+ def tensorize(self, block_or_loop: Union[BlockRV, LoopRV], tensor_intrin:
str) -> None:
+ """Tensorize the computation enclosed by loop with the tensor
intrinsic.
+
+ Parameters
+ ----------
+ block_or_loop : Union[BlockRV, LoopRV]
+ The block or loop to be tensorized.
+ tensor_intrin : str
+ The name of the tensor intrin.
+
+ Examples
+ --------
+
+ Before tensorize, in TensorIR, the IR is:
+
+ .. code-block:: python
+
+ @T.prim_func
+ def before_tensorize(
+ A: T.Buffer[(128, 128), "float32"],
+ B: T.Buffer[(128, 128), "float32"],
+ C: T.Buffer[(128, 128), "float32"],
+ ) -> None:
+ # body
+ # with T.block("root")
+ for i_0, j_0, k_0, i_1, j_1, k_1 in T.grid(8, 8, 8, 16, 16,
16):
+ with T.block("update"):
+ vi = T.axis.spatial(128, i_0 * 16 + i_1)
+ vj = T.axis.spatial(128, j_0 * 16 + j_1)
+ vk = T.axis.reduce(128, k_0 * 16 + k_1)
+ T.reads(C[vi, vj], A[vi, vk], B[vj, vk])
+ T.writes(C[vi, vj])
+ with T.init():
+ C[vi, vj] = T.float32(0)
+ C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]
+
+ Declare and register the tensor intrinsic:
+
+ .. code-block:: python
+
+ @T.prim_func
+ def mma_desc(a: T.handle, b: T.handle, c: T.handle) -> None:
+ A = T.match_buffer(a, (16, 16), align=128, offset_factor=1)
+ B = T.match_buffer(b, (16, 16), align=128, offset_factor=1)
+ C = T.match_buffer(c, (16, 16), align=128, offset_factor=1)
+
+ with T.block("root"):
+ T.reads(C[0 : 16, 0 : 16], A[0 : 16, 0 : 16], B[0 : 16, 0
: 16])
+ T.writes(C[0 : 16, 0 : 16])
+ for i, j, k in T.grid(16, 16, 16):
+ with T.block("update"):
+ vi, vj, vk = T.axis.remap("SSR", [i, j, k])
+ C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]
+
+
+ @T.prim_func
+ def mma_intrin(a: T.handle, b: T.handle, c: T.handle) -> None:
+ A = T.match_buffer(a, (16, 16), align=128, offset_factor=1)
+ B = T.match_buffer(b, (16, 16), align=128, offset_factor=1)
+ C = T.match_buffer(c, (16, 16), align=128, offset_factor=1)
+
+ with T.block("root"):
+ T.reads(C[0 : 16, 0 : 16], A[0 : 16, 0 : 16], B[0 : 16, 0
: 16])
+ T.writes(C[0 : 16, 0 : 16])
+ T.evaluate(
+ T.tvm_mma_sync(
+ C.data,
+ C.elem_offset // 256,
+ A.data,
+ A.elem_offset // 256,
+ B.data,
+ B.elem_offset // 256,
+ C.data,
+ C.elem_offset // 256,
+ dtype="handle",
+ )
+ )
+
+ tir.TensorIntrin.register("test_mma_intrin", mma_desc, mma_intrin)
+
+ Create the schedule and do tensorize:
+
+ .. code-block:: python
+
+ sch = tir.Schedule(before_tensorize)
+ update = sch.get_block("update")
+ _, _, _, i1, _, _ = sch.get_loops(update)
+ sch.tensorize(i1, "test_mma_intrin")
+ print(sch.mod["main"].script())
+
+ After applying tensorize, the IR becomes:
+
+ .. code-block:: python
+
+ @T.prim_func
+ def after_tensorize(
+ A: T.Buffer[(128, 128), "float32"],
+ B: T.Buffer[(128, 128), "float32"],
+ C: T.Buffer[(128, 128), "float32"],
+ ) -> None:
+ # body
+ # with T.block("root")
+ for i_0, j_0, k_0 in T.grid(8, 8, 8):
+ with T.block("update_o"):
+ vio, vjo, vko = T.axis.remap("SSR", [i_0, j_0, k_0])
+ T.reads(
+ C[vio * 16 : vio * 16 + 16, vjo * 16 : vjo * 16 +
16],
+ A[vio * 16 : vio * 16 + 16, vko * 16 : vko * 16 +
16],
+ B[vjo * 16 : vjo * 16 + 16, vko * 16 : vko * 16 +
16],
+ )
+ T.writes(C[vio * 16 : vio * 16 + 16, vjo * 16 : vjo *
16 + 16])
+ A_1 = T.match_buffer(
+ A[vio * 16 : vio * 16 + 16, vko * 16 : vko * 16 +
16],
+ [16, 16],
+ dtype="float32",
+ offset_factor=1,
+ )
+ B_1 = T.match_buffer(
+ B[vjo * 16 : vjo * 16 + 16, vko * 16 : vko * 16 +
16],
+ [16, 16],
+ dtype="float32",
+ offset_factor=1,
+ )
+ C_1 = T.match_buffer(
+ C[vio * 16 : vio * 16 + 16, vjo * 16 : vjo * 16 +
16],
+ [16, 16],
+ dtype="float32",
+ offset_factor=1,
+ )
+ with T.init():
+ for i_1, j_1 in T.grid(16, 16):
+ with T.block("update_init"):
+ vi_init, vj_init = T.axis.remap("SS",
[i_1, j_1])
+ T.reads()
+ T.writes(C[vio * 16 + vi_init, vjo * 16 +
vj_init])
+ C[vio * 16 + vi_init, vjo * 16 + vj_init]
= T.float32(0)
+ T.evaluate(
+ T.tvm_mma_sync(
+ C_1.data,
+ C_1.elem_offset // 256,
+ A_1.data,
+ A_1.elem_offset // 256,
+ B_1.data,
+ B_1.elem_offset // 256,
+ C_1.data,
+ C_1.elem_offset // 256,
+ dtype="handle",
+ )
+ )
+ """
+ _ffi_api.ScheduleTensorize( # type: ignore # pylint: disable=no-member
+ self, block_or_loop, tensor_intrin
+ )
+
########## Schedule: Annotation ##########
@type_checked
diff --git a/src/arith/int_set.cc b/src/arith/int_set.cc
index 55a1a5a..3d30eef 100644
--- a/src/arith/int_set.cc
+++ b/src/arith/int_set.cc
@@ -511,7 +511,7 @@ Range IntSet::CoverRange(Range max_range) const {
const IntervalSetNode* s_int = (*this).as<IntervalSetNode>();
ICHECK(s_int != nullptr);
if (s_int->HasUpperBound() && s_int->HasLowerBound()) {
- return Range::FromMinExtent(s_int->min_value,
+ return Range::FromMinExtent(analyzer.Simplify(s_int->min_value),
analyzer.Simplify(s_int->max_value + 1 -
s_int->min_value));
}
return max_range;
diff --git a/src/tir/ir/function.cc b/src/tir/ir/function.cc
index 101d80a..1c34e34 100644
--- a/src/tir/ir/function.cc
+++ b/src/tir/ir/function.cc
@@ -64,6 +64,51 @@ FuncType PrimFuncNode::func_type_annotation() const {
TVM_REGISTER_NODE_TYPE(PrimFuncNode);
+class TensorIntrinManager {
+ public:
+ Map<String, tir::TensorIntrin> reg;
+
+ static TensorIntrinManager* Global() {
+ static TensorIntrinManager* inst = new TensorIntrinManager();
+ return inst;
+ }
+};
+
+TensorIntrin::TensorIntrin(PrimFunc desc, PrimFunc impl) {
+ // Check the number of func var is equal
+ CHECK_EQ(desc->params.size(), impl->params.size())
+ << "ValueError: The number of parameters of the description and the
implementation of the "
+ "tensor intrinsic doesn't match.";
+ for (size_t i = 0; i < desc->params.size(); i++) {
+ CHECK(desc->params[i]->dtype.is_handle()) << "ValueError: Parameters of
the description of the "
+ "tensor intrinsic should be
handle only.";
+ CHECK(impl->params[i]->dtype.is_handle()) << "ValueError: Parameters of
the implementation of "
+ "the tensor intrinsic should
be handle only.";
+ }
+ ICHECK_EQ(desc->buffer_map.size(), impl->buffer_map.size());
+
+ ObjectPtr<TensorIntrinNode> n = make_object<TensorIntrinNode>();
+ n->desc = std::move(desc);
+ n->impl = std::move(impl);
+ data_ = std::move(n);
+}
+
+void TensorIntrin::Register(String name, TensorIntrin intrin) {
+ TensorIntrinManager* manager = TensorIntrinManager::Global();
+ CHECK_EQ(manager->reg.count(name), 0)
+ << "ValueError: TensorIntrin '" << name << "' has already been
registered";
+ manager->reg.Set(name, intrin);
+}
+
+TensorIntrin TensorIntrin::Get(String name) {
+ const TensorIntrinManager* manager = TensorIntrinManager::Global();
+ auto it = manager->reg.find(name);
+ CHECK(it != manager->reg.end()) << "ValueError: TensorIntrin '" << name <<
"' is not registered";
+ return manager->reg.at(name);
+}
+
+TVM_REGISTER_NODE_TYPE(TensorIntrinNode);
+
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<PrimFuncNode>([](const ObjectRef& ref, ReprPrinter* p) {
// TODO(tvm-team) redirect to Text printer once we have a good text
format.
@@ -85,5 +130,13 @@ TVM_REGISTER_GLOBAL("tir.PrimFunc")
return PrimFunc(params, body, ret_type, buffer_map, attrs, span);
});
+TVM_REGISTER_GLOBAL("tir.TensorIntrin")
+ .set_body_typed([](PrimFunc desc_func, PrimFunc intrin_func) {
+ return TensorIntrin(desc_func, intrin_func);
+ });
+
+TVM_REGISTER_GLOBAL("tir.TensorIntrinRegister").set_body_typed(TensorIntrin::Register);
+TVM_REGISTER_GLOBAL("tir.TensorIntrinGet").set_body_typed(TensorIntrin::Get);
+
} // namespace tir
} // namespace tvm
diff --git a/src/tir/schedule/concrete_schedule.cc
b/src/tir/schedule/concrete_schedule.cc
index 9f8dc6d..fc63f30 100644
--- a/src/tir/schedule/concrete_schedule.cc
+++ b/src/tir/schedule/concrete_schedule.cc
@@ -606,6 +606,29 @@ BlockRV ConcreteScheduleNode::RFactor(const LoopRV&
loop_rv, int factor_axis) {
}
/******** Schedule: Blockize & Tensorize ********/
+BlockRV ConcreteScheduleNode::Blockize(const LoopRV& loop_rv) {
+ StmtSRef result{nullptr};
+ TVM_TIR_SCHEDULE_BEGIN();
+ result = tir::Blockize(state_, this->GetSRef(loop_rv));
+ this->state_->DebugVerify();
+ TVM_TIR_SCHEDULE_END("blockize", this->error_render_level_);
+ return CreateRV<BlockRV>(result);
+}
+
+void ConcreteScheduleNode::Tensorize(const LoopRV& loop_rv, const String&
intrin) {
+ TVM_TIR_SCHEDULE_BEGIN();
+ tir::Tensorize(state_, this->GetSRef(loop_rv),
tir::TensorIntrin::Get(intrin));
+ this->state_->DebugVerify();
+ TVM_TIR_SCHEDULE_END("tensorize", this->error_render_level_);
+}
+
+void ConcreteScheduleNode::Tensorize(const BlockRV& block_rv, const String&
intrin) {
+ TVM_TIR_SCHEDULE_BEGIN();
+ tir::Tensorize(state_, this->GetSRef(block_rv),
tir::TensorIntrin::Get(intrin));
+ this->state_->DebugVerify();
+ TVM_TIR_SCHEDULE_END("tensorize", this->error_render_level_);
+}
+
/******** Schedule: Annotation ********/
ObjectRef ConcreteScheduleNode::CheckAndGetAnnotationValue(const ObjectRef&
ann_val) {
diff --git a/src/tir/schedule/concrete_schedule.h
b/src/tir/schedule/concrete_schedule.h
index 96cb0f7..5f10817 100644
--- a/src/tir/schedule/concrete_schedule.h
+++ b/src/tir/schedule/concrete_schedule.h
@@ -123,6 +123,9 @@ class ConcreteScheduleNode : public ScheduleNode {
int offset) override;
void SetScope(const BlockRV& block_rv, int buffer_index, const String&
storage_scope) override;
/******** Schedule: Blockize & Tensorize ********/
+ BlockRV Blockize(const LoopRV& loop_rv) override;
+ void Tensorize(const BlockRV& loop_rv, const String& intrin) override;
+ void Tensorize(const LoopRV& loop_rv, const String& intrin) override;
/******** Schedule: Annotation ********/
void Annotate(const LoopRV& loop_rv, const String& ann_key, const ObjectRef&
ann_val) override;
void Unannotate(const LoopRV& loop_rv, const String& ann_key) override;
diff --git a/src/tir/schedule/ir_comparator.cc
b/src/tir/schedule/ir_comparator.cc
new file mode 100644
index 0000000..3e61e95
--- /dev/null
+++ b/src/tir/schedule/ir_comparator.cc
@@ -0,0 +1,363 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include "./ir_comparator.h"
+
+namespace tvm {
+
+namespace tir {
+
+/******** Tensorize Comparator ********/
+
+class TensorIntrinMismatchError : public ScheduleError {
+ public:
+ explicit TensorIntrinMismatchError(IRModule lhs_mod, Stmt lhs_stmt, Stmt
rhs_stmt,
+ std::vector<std::string> error_messages)
+ : lhs_mod_(std::move(lhs_mod)),
+ lhs_stmt_(std::move(lhs_stmt)),
+ rhs_stmt_(std::move(rhs_stmt)),
+ error_messages_(std::move(error_messages)) {
+ ICHECK(lhs_stmt_->IsInstance<ForNode>() ||
lhs_stmt_->IsInstance<BlockNode>());
+ }
+
+ String FastErrorString() const final {
+ return "ScheduleError: The stmt doesn't match the tensor intrin.";
+ }
+
+ String DetailRenderTemplate() const final {
+ std::ostringstream os;
+ os << "The stmt {0} doesn't match the tensor intrin\n " << rhs_stmt_;
+ for (const auto& msg : error_messages_) {
+ os << msg << std::endl;
+ }
+ return os.str();
+ }
+
+ IRModule mod() const final { return lhs_mod_; }
+
+ Array<ObjectRef> LocationsOfInterest() const final { return {lhs_stmt_}; }
+
+ private:
+ IRModule lhs_mod_;
+ Stmt lhs_stmt_;
+ Stmt rhs_stmt_;
+ std::vector<std::string> error_messages_;
+};
+
+/* Override the dispatcher to make sure RHS is always valid */
+bool TensorizeComparator::VisitStmt(const Stmt& n, const Stmt& other) {
+ bool equal = n.same_as(other) ||
+ ((n->type_index() == other->type_index()) &&
StmtComparator::VisitStmt(n, other));
+ if (!equal && assert_mode_ && (n->IsInstance<ForNode>() ||
n->IsInstance<BlockNode>())) {
+ throw TensorIntrinMismatchError(lhs_mod_, n, other,
std::move(error_messages_));
+ }
+ return equal;
+}
+
+bool TensorizeComparator::VisitExpr(const PrimExpr& n, const PrimExpr& other) {
+ bool equal =
+ n.same_as(other) || ((n->type_index() == other->type_index()) &&
n->dtype == other->dtype &&
+ ExprComparator::VisitExpr(n, other));
+ if (!equal && assert_mode_) {
+ std::ostringstream os;
+ os << "Expression mismatch: " << n << " vs " << other;
+ EmitError(os.str());
+ }
+ return equal;
+}
+
+bool TensorizeComparator::VisitStmt_(const ForNode* op, const Stmt& other) {
+ const auto* rhs = other.as<ForNode>();
+ if (!DefEqual(op->loop_var, rhs->loop_var)) return false;
+ if (!VisitExpr(op->min, rhs->min)) return false;
+ if (!VisitExpr(op->extent, rhs->extent)) return false;
+ if (op->thread_binding.defined() != rhs->thread_binding.defined()) return
false;
+ if (op->thread_binding.defined() &&
+ !VisitExpr(op->thread_binding.value(), rhs->thread_binding.value())) {
+ return false;
+ }
+ if (op->kind != rhs->kind) return false;
+ if (!CompareAnnotationMap(op->annotations, rhs->annotations)) return false;
+ return VisitStmt(op->body, rhs->body);
+}
+
+bool TensorizeComparator::VisitStmt_(const SeqStmtNode* op, const Stmt& other)
{
+ const auto* rhs = other.as<SeqStmtNode>();
+ return CompareArray(op->seq, rhs->seq, &TensorizeComparator::VisitStmt);
+}
+
+bool TensorizeComparator::VisitStmt_(const BufferStoreNode* op, const Stmt&
other) {
+ const auto* rhs = other.as<BufferStoreNode>();
+ return CompareBufferAccess(op, rhs) && VisitExpr(op->value, rhs->value);
+}
+
+bool TensorizeComparator::VisitStmt_(const BlockRealizeNode* op, const Stmt&
other) {
+ const auto* rhs = other.as<BlockRealizeNode>();
+ if (!is_scope_block) {
+ if (!CompareArray(op->iter_values, rhs->iter_values,
&TensorizeComparator::VisitExpr)) {
+ return false;
+ }
+ }
+ return VisitExpr(op->predicate, rhs->predicate) && VisitStmt(op->block,
rhs->block);
+}
+
+bool TensorizeComparator::VisitStmt_(const BlockNode* op, const Stmt& other) {
+ const auto* rhs = other.as<BlockNode>();
+ // Check block equality.
+ // All iter vars and buffer regions including the order should match.
+ // When checking iter vars, DefEqual is used to remap variables.
+ if (!is_scope_block) {
+ if (!CompareArray(op->iter_vars, rhs->iter_vars,
&TensorizeComparator::CompareIterVar)) {
+ return false;
+ }
+ if (!CompareAnnotationMap(op->annotations, rhs->annotations)) {
+ return false;
+ }
+ if (!CompareArray(op->alloc_buffers, rhs->alloc_buffers,
&TensorizeComparator::CompareBuffer)) {
+ return false;
+ }
+ }
+ if (!CompareArray(op->writes, rhs->writes,
&TensorizeComparator::CompareBufferRegion)) {
+ return false;
+ }
+ if (!CompareArray(op->reads, rhs->reads,
&TensorizeComparator::CompareBufferRegion)) {
+ return false;
+ }
+ is_scope_block = false;
+ return VisitStmt(op->body, rhs->body);
+}
+
+// Exprs
+#define TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(OpName)
\
+ bool TensorizeComparator::VisitExpr_(const OpName* op, const PrimExpr&
other) { \
+ const auto* rhs = other.as<OpName>();
\
+ return VisitExpr(op->a, rhs->a) && VisitExpr(op->b, rhs->b);
\
+ }
+
+TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(AddNode);
+TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(SubNode);
+TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(MulNode);
+TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(DivNode);
+TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(ModNode);
+TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(EQNode);
+TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(NENode);
+TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(LTNode);
+TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(LENode);
+TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(GTNode);
+TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(GENode);
+TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(AndNode);
+TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(OrNode);
+TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(MinNode);
+TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(MaxNode);
+TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(FloorDivNode);
+TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(FloorModNode);
+
+bool TensorizeComparator::VisitExpr_(const IntImmNode* op, const PrimExpr&
other) {
+ const auto* rhs = other.as<IntImmNode>();
+ return op->value == rhs->value;
+}
+
+bool TensorizeComparator::VisitExpr_(const FloatImmNode* op, const PrimExpr&
other) {
+ const auto* rhs = other.as<FloatImmNode>();
+ return op->value == rhs->value;
+}
+
+bool TensorizeComparator::VisitExpr_(const CastNode* op, const PrimExpr&
other) {
+ const auto* rhs = other.as<CastNode>();
+ return VisitExpr(op->value, rhs->value);
+}
+
+bool TensorizeComparator::VisitExpr_(const VarNode* op, const PrimExpr& other)
{
+ const auto* rhs = other.as<VarNode>();
+ auto lhs = GetRef<Var>(op);
+ if (lhs.same_as(other)) return true;
+ if (op->dtype != rhs->dtype) return false;
+ auto it = equal_map_.find(lhs);
+ return it != equal_map_.end() && it->second.same_as(other);
+}
+
+bool TensorizeComparator::VisitExpr_(const BufferLoadNode* op, const PrimExpr&
other) {
+ const auto* rhs = other.as<BufferLoadNode>();
+ return CompareBufferAccess(op, rhs);
+}
+
+bool TensorizeComparator::VisitExpr_(const SelectNode* op, const PrimExpr&
other) {
+ const auto* rhs = other.as<SelectNode>();
+ return VisitExpr(op->condition, rhs->condition) && VisitExpr(op->true_value,
rhs->true_value) &&
+ VisitExpr(op->false_value, rhs->false_value);
+}
+
+bool TensorizeComparator::DefEqual(const Var& lhs, const Var& rhs) {
+ if (lhs.same_as(rhs)) return true;
+ auto it = equal_map_.find(lhs);
+ // If there is already a mapping
+ if (it != equal_map_.end()) return it->second.same_as(rhs);
+ // Otherwise remap lhs to rhs
+ equal_map_[lhs] = rhs;
+ analyzer_.Bind(lhs, rhs);
+ return true;
+}
+
+bool TensorizeComparator::CompareAnnotation(const std::pair<String,
ObjectRef>& lhs,
+ const std::pair<String,
ObjectRef>& rhs) {
+ if (lhs.first != rhs.first) return false;
+ if (!lhs.second.same_as(rhs.second)) return false;
+ return VisitExpr(Downcast<PrimExpr>(lhs.second),
Downcast<PrimExpr>(rhs.second));
+}
+
+bool TensorizeComparator::CompareAnnotationMap(const Map<String, ObjectRef>&
lhs,
+ const Map<String, ObjectRef>&
rhs) {
+ if (lhs.same_as(rhs)) return true;
+ if (lhs.size() != rhs.size()) return false;
+
+ auto sort_map =
+ [](const Map<String, ObjectRef>& map) -> std::vector<std::pair<String,
ObjectRef>> {
+ std::vector<std::pair<String, ObjectRef>> ret(map.begin(), map.end());
+ sort(ret.begin(), ret.end());
+ return ret;
+ };
+
+ std::vector<std::pair<String, ObjectRef>> lhs_array = sort_map(lhs);
+ std::vector<std::pair<String, ObjectRef>> rhs_array = sort_map(rhs);
+
+ for (size_t i = 0; i < lhs.size(); ++i) {
+ if (!CompareAnnotation(lhs_array[i], rhs_array[i])) return false;
+ }
+ return true;
+}
+
+bool TensorizeComparator::CompareBuffer(const Buffer& lhs, const Buffer& rhs) {
+ if (lhs.same_as(rhs)) return true;
+ auto it = rhs_buffer_map_.find(rhs);
+ bool equal;
+ if (it != rhs_buffer_map_.end()) {
+ equal = (*it).second.same_as(lhs);
+ } else {
+ // Remap both buffer itself and buffer data, skip buffer shape
+ equal =
+ DefEqual(lhs->data, rhs->data) && lhs->dtype == rhs->dtype &&
lhs.scope() == rhs.scope();
+ if (equal) {
+ rhs_buffer_map_[rhs] = lhs;
+ }
+ }
+ return equal;
+}
+
+bool TensorizeComparator::CompareBufferRegion(const BufferRegion& lhs, const
BufferRegion& rhs) {
+ if (!CompareBuffer(lhs->buffer, rhs->buffer)) {
+ if (assert_mode_) {
+ std::ostringstream os;
+ os << "Buffer mismatch: " << lhs->buffer << " vs " << rhs->buffer;
+ EmitError(os.str());
+ }
+ return false;
+ }
+ int offset = static_cast<int>(lhs->region.size()) -
static_cast<int>(rhs->region.size());
+ // Number of indices in RHS (desc of the tensor intrinsic) must be smaller
than it in LHS
+ if (offset < 0) return false;
+
+ auto it = buffer_indices_.find(lhs->buffer);
+ if (it == buffer_indices_.end()) {
+ // Update base indices for the buffer, this can only happen if it is
visiting the scope block.
+ ICHECK(is_scope_block);
+ std::vector<PrimExpr> indices_base;
+ indices_base.reserve(lhs->region.size());
+ for (int i = 0; i < offset; i++) {
+ // High-dim region must be element-wise
+ if (!is_one(lhs->region[i]->extent)) return false;
+ indices_base.emplace_back(lhs->region[i]->min);
+ }
+ for (size_t i = 0; i < rhs->region.size(); i++) {
+ // save base index
+ indices_base.emplace_back(lhs->region[i + offset]->min);
+ // check extent match
+ if (!analyzer_.CanProveEqual(lhs->region[i + offset]->extent,
rhs->region[i]->extent)) {
+ return false;
+ }
+ }
+ buffer_indices_.emplace(lhs->buffer, std::move(indices_base));
+ } else {
+ // Check the base indices are consistent.
+ const std::vector<PrimExpr>& indices_base = it->second;
+ for (int i = 0; i < offset; i++) {
+ // High-dim region must be element-wise
+ if (!is_one(lhs->region[i]->extent)) return false;
+ if (!analyzer_.CanProveEqual(indices_base[i], lhs->region[i]->min))
return false;
+ }
+ for (size_t i = 0; i < rhs->region.size(); i++) {
+ // check extent match
+ if (!analyzer_.CanProveEqual(lhs->region[i + offset]->extent,
rhs->region[i]->extent)) {
+ return false;
+ }
+ PrimExpr normalized_lhs_min = (lhs->region[i + offset]->min -
indices_base[i + offset]);
+ if (!analyzer_.CanProveEqual(normalized_lhs_min, rhs->region[i]->min)) {
+ return false;
+ }
+ }
+ }
+ return true;
+}
+
+// Comparator for BufferStoreNode and BufferLoadNode
+template <typename T>
+bool TensorizeComparator::CompareBufferAccess(const T* lhs, const T* rhs) {
+ if (!CompareBuffer(lhs->buffer, rhs->buffer)) return false;
+ int offset = static_cast<int>(lhs->indices.size()) -
static_cast<int>(rhs->indices.size());
+ if (offset < 0) return false;
+ auto it = buffer_indices_.find(lhs->buffer);
+ ICHECK(it != buffer_indices_.end());
+ const std::vector<PrimExpr>& indices_base = (*it).second;
+ ICHECK_EQ(indices_base.size(), rhs->indices.size() + offset);
+ for (size_t i = 0; i < rhs->indices.size(); i++) {
+ PrimExpr normalized_lhs_index = lhs->indices[i + offset] - indices_base[i
+ offset];
+ if (!analyzer_.CanProveEqual(normalized_lhs_index, rhs->indices[i])) {
+ if (assert_mode_) {
+ std::ostringstream os;
+ os << "Buffer indices mismatch: " << lhs->indices[i + offset] << " vs
" << rhs->indices[i];
+ EmitError(os.str());
+ }
+ return false;
+ }
+ }
+ return true;
+}
+
+template <typename T, typename F>
+bool TensorizeComparator::CompareArray(const Array<T>& lhs, const Array<T>&
rhs, F cmp) {
+ if (lhs.same_as(rhs)) return true;
+ if (lhs.size() != rhs.size()) return false;
+ for (size_t i = 0; i < lhs.size(); ++i) {
+ if (!(this->*cmp)(lhs[i], rhs[i])) return false;
+ }
+ return true;
+}
+
+bool TensorizeComparator::CompareRange(const Range& lhs, const Range& rhs) {
+ return VisitExpr(lhs->min, rhs->min) && VisitExpr(lhs->extent, rhs->extent);
+}
+
+bool TensorizeComparator::CompareIterVar(const IterVar& lhs, const IterVar&
rhs) {
+ return DefEqual(lhs->var, rhs->var) && lhs->iter_type == rhs->iter_type;
+}
+
+void TensorizeComparator::EmitError(const std::string& error_message) {
+ error_messages_.push_back(error_message);
+}
+
+} // namespace tir
+} // namespace tvm
diff --git a/src/tir/schedule/ir_comparator.h b/src/tir/schedule/ir_comparator.h
new file mode 100644
index 0000000..359677d
--- /dev/null
+++ b/src/tir/schedule/ir_comparator.h
@@ -0,0 +1,116 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#ifndef TVM_TIR_SCHEDULE_IR_COMPARATOR_H_
+#define TVM_TIR_SCHEDULE_IR_COMPARATOR_H_
+
+#include <string>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+#include "./utils.h"
+
+namespace tvm {
+namespace tir {
+
+using ExprComparator = ExprFunctor<bool(const PrimExpr& n, const PrimExpr&
other)>;
+using StmtComparator = StmtFunctor<bool(const Stmt& n, const Stmt& other)>;
+
+/*! \brief Deep comparison to check if two IR ASTs are equivalent for
tensorization*/
+class TensorizeComparator : public ExprComparator, public StmtComparator {
+ public:
+ /*!
+ * \brief Constructor of TensorizeComparator
+ * \param assert_mode Whether to raise an error if the two IR ASTs do not
match.
+ * \param lhs_mod The IRModule of the LHS. This is used for error reporting.
+ */
+ explicit TensorizeComparator(IRModule lhs_mod, bool assert_mode = true)
+ : lhs_mod_(std::move(lhs_mod)), assert_mode_(assert_mode) {}
+
+ bool VisitExpr(const PrimExpr& n, const PrimExpr& other) override;
+ bool VisitStmt(const Stmt& n, const Stmt& other) override;
+
+ bool VisitStmt_(const ForNode* op, const Stmt& other) override;
+ bool VisitStmt_(const SeqStmtNode* op, const Stmt& other) override;
+ bool VisitStmt_(const BufferStoreNode* op, const Stmt& other) override;
+ bool VisitStmt_(const BlockRealizeNode* op, const Stmt& other) override;
+ bool VisitStmt_(const BlockNode* op, const Stmt& other) override;
+
+ bool VisitExpr_(const AddNode* op, const PrimExpr& other) override;
+ bool VisitExpr_(const SubNode* op, const PrimExpr& other) override;
+ bool VisitExpr_(const MulNode* op, const PrimExpr& other) override;
+ bool VisitExpr_(const DivNode* op, const PrimExpr& other) override;
+ bool VisitExpr_(const ModNode* op, const PrimExpr& other) override;
+ bool VisitExpr_(const EQNode* op, const PrimExpr& other) override;
+ bool VisitExpr_(const NENode* op, const PrimExpr& other) override;
+ bool VisitExpr_(const LTNode* op, const PrimExpr& other) override;
+ bool VisitExpr_(const LENode* op, const PrimExpr& other) override;
+ bool VisitExpr_(const GTNode* op, const PrimExpr& other) override;
+ bool VisitExpr_(const GENode* op, const PrimExpr& other) override;
+ bool VisitExpr_(const AndNode* op, const PrimExpr& other) override;
+ bool VisitExpr_(const OrNode* op, const PrimExpr& other) override;
+ bool VisitExpr_(const MinNode* op, const PrimExpr& other) override;
+ bool VisitExpr_(const MaxNode* op, const PrimExpr& other) override;
+ bool VisitExpr_(const FloorDivNode* op, const PrimExpr& other) override;
+ bool VisitExpr_(const FloorModNode* op, const PrimExpr& other) override;
+ bool VisitExpr_(const IntImmNode* op, const PrimExpr& other) override;
+ bool VisitExpr_(const FloatImmNode* op, const PrimExpr& other) override;
+ bool VisitExpr_(const CastNode* op, const PrimExpr& other) override;
+ bool VisitExpr_(const VarNode* op, const PrimExpr& other) override;
+ bool VisitExpr_(const BufferLoadNode* op, const PrimExpr& other) override;
+ bool VisitExpr_(const SelectNode* op, const PrimExpr& other) override;
+
+ /*! \brief Map from RHS buffer to LHS buffer */
+ std::unordered_map<Buffer, Buffer, ObjectHash, ObjectEqual> rhs_buffer_map_;
+ /*! \brief Base indices of the LHS buffer. */
+ std::unordered_map<Buffer, std::vector<PrimExpr>, ObjectPtrHash,
ObjectPtrEqual> buffer_indices_;
+
+ protected:
+ bool DefEqual(const Var& lhs, const Var& rhs);
+ virtual bool CompareBuffer(const Buffer& lhs, const Buffer& rhs);
+ bool CompareBufferRegion(const BufferRegion& lhs, const BufferRegion& rhs);
+ bool CompareAnnotation(const std::pair<String, ObjectRef>& lhs,
+ const std::pair<String, ObjectRef>& rhs);
+ bool CompareAnnotationMap(const Map<String, ObjectRef>& lhs, const
Map<String, ObjectRef>& rhs);
+ template <typename T>
+ bool CompareBufferAccess(const T* lhs, const T* rhs);
+ template <typename T, typename F>
+ bool CompareArray(const Array<T>& lhs, const Array<T>& rhs, F cmp);
+ bool CompareRange(const Range& lhs, const Range& rhs);
+ bool CompareIterVar(const IterVar& lhs, const IterVar& rhs);
+ void EmitError(const std::string& error_message);
+
+ /*! \brief IRModule of the LHS stmt. */
+ IRModule lhs_mod_;
+ /*! \brief Whether assertion mode is enabled. */
+ bool assert_mode_;
+ /*! \brief Whether it is visiting the scope block (the outermost block). */
+ bool is_scope_block = true;
+ /*! \brief The arithmetic analyzer. */
+ arith::Analyzer analyzer_;
+ /*! \brief Additional error messages. Only used when assert_mode is true. */
+ std::vector<std::string> error_messages_;
+ // variable remap if any
+ std::unordered_map<ObjectRef, ObjectRef, ObjectPtrHash, ObjectPtrEqual>
equal_map_;
+};
+
+} // namespace tir
+} // namespace tvm
+
+#endif // TVM_TIR_SCHEDULE_IR_COMPARATOR_H_
diff --git a/src/tir/schedule/primitive.h b/src/tir/schedule/primitive.h
index f0b38af..2368411 100644
--- a/src/tir/schedule/primitive.h
+++ b/src/tir/schedule/primitive.h
@@ -378,6 +378,24 @@ TVM_DLL void SetScope(ScheduleState self, const StmtSRef&
block_sref, int buffer
const String& storage_scope);
/******** Schedule: Blockize & Tensorize ********/
+
+/*!
+ * \brief Convert the subtree rooted at a specific loop into a block.
+ * \param self The state of the schedule
+ * \param loop_sref The root of the subtree
+ * \return The new block
+ */
+TVM_DLL StmtSRef Blockize(ScheduleState self, const StmtSRef& loop_sref);
+
+/*!
+ * \brief Tensorize the computation enclosed by loop with the tensor intrinsic.
+ * \param self The state of the schedule
+ * \param block_or_loop_sref The block or loop to be tensorized.
+ * \param intrin The tensor intrinsic.
+ */
+TVM_DLL void Tensorize(ScheduleState self, const StmtSRef& block_or_loop_sref,
+ const TensorIntrin& intrin);
+
/******** Schedule: Annotation ********/
/*!
* \brief Annotate a block/loop with a key value pair
diff --git a/src/tir/schedule/primitive/blockize_tensorize.cc
b/src/tir/schedule/primitive/blockize_tensorize.cc
new file mode 100644
index 0000000..bbeb9ca
--- /dev/null
+++ b/src/tir/schedule/primitive/blockize_tensorize.cc
@@ -0,0 +1,698 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include <functional>
+
+#include "../ir_comparator.h"
+#include "../utils.h"
+
+namespace tvm {
+namespace tir {
+
+/*!
+ * \brief ScheduleError that the bindings of the inner block are not divisible
by the subspace
+ * represented by the outer loops.
+ */
+class SubspaceNotDivisibleError : public ScheduleError {
+ public:
+ explicit SubspaceNotDivisibleError(IRModule mod, For scope_loop, Block
inner_block)
+ : mod_(std::move(mod)),
+ scope_loop_(std::move(scope_loop)),
+ inner_block_(std::move(inner_block)) {}
+
+ String FastErrorString() const final {
+ return "ScheduleError: The bindings of the inner block can not be
blockized.";
+ }
+
+ String DetailRenderTemplate() const final {
+ return "ScheduleError: The bindings of the inner block {0} can not be
blockized by the loops "
+ "starting at {1}.";
+ }
+
+ IRModule mod() const final { return mod_; }
+
+ Array<ObjectRef> LocationsOfInterest() const final { return {inner_block_,
scope_loop_}; }
+
+ private:
+ IRModule mod_;
+ For scope_loop_;
+ Block inner_block_;
+};
+
+/*!
+ * \brief Detect if bindings are a trivial case of the subspace division where
we can divide the
+ * block iter bindings into two categories:
+ * 1. The binding covers no inner loop vars.
+ * 2. The binding covers only inner loop vars.
+ *
+ * The bindings are not required to be quasi-affine.
+ *
+ * \param iter_vars The input iterators
+ * \param bindings The values of iter_vars
+ * \param outer_iters Iterators outside the subspace.
+ * \param inner_iters Iterators of the subspace.
+ * \param predicate The predicate constraint on the input iterators.
+ * \return The result of the subspace division.
+ */
+Array<Array<arith::IterMark>> TrivialSubspaceDivision(const Array<IterVar>&
iter_vars,
+ const Array<PrimExpr>&
bindings,
+ const Array<Var>&
outer_iters,
+ const Array<Var>&
inner_iters,
+ const PrimExpr&
predicate) {
+ if (!is_one(predicate)) return {};
+ Array<Array<arith::IterMark>> res;
+ std::unordered_set<const VarNode*> outer_loop_vars;
+ std::unordered_set<const VarNode*> inner_loop_vars;
+
+ auto make_uses_var = [](const Array<Var>& vars) -> std::function<bool(const
PrimExpr& expr)> {
+ std::unordered_set<const VarNode*> var_set;
+ var_set.reserve(vars.size());
+ for (const Var& var : vars) {
+ var_set.insert(var.get());
+ }
+ return [var_set = std::move(var_set)](const PrimExpr& expr) -> bool {
+ return UsesVar(expr, [&var_set](const VarNode* var) {
+ return var_set.count(var); //
+ });
+ };
+ };
+ auto use_outer_loop_vars = make_uses_var(outer_iters);
+ auto use_inner_loop_vars = make_uses_var(inner_iters);
+ arith::IterMark unit_iter_mark(arith::IterSumExpr({}, 0), 1);
+
+ for (size_t i = 0; i < bindings.size(); ++i) {
+ bool outer = use_outer_loop_vars(bindings[i]);
+ bool inner = use_inner_loop_vars(bindings[i]);
+ arith::IterMark iter_mark;
+ if (bindings[i]->IsInstance<VarNode>()) {
+ iter_mark = arith::IterMark(
+ arith::IterSplitExpr(arith::IterMark(bindings[i],
iter_vars[i]->dom->extent)),
+ iter_vars[i]->dom->extent);
+ } else {
+ iter_mark = arith::IterMark(arith::IterSumExpr({}, bindings[i]),
iter_vars[i]->dom->extent);
+ }
+ if (outer && !inner) {
+ res.push_back({/*outer_iter=*/iter_mark, /*inner_iter=*/unit_iter_mark});
+ } else if (inner && !outer) {
+ res.push_back({/*outer_iter=*/unit_iter_mark, /*inner_iter=*/iter_mark});
+ } else if (!outer && !inner) {
+ res.push_back({/*outer_iter=*/unit_iter_mark,
/*inner_iter=*/unit_iter_mark});
+ } else {
+ return {};
+ }
+ }
+ res.push_back({arith::IterMark(arith::IterSumExpr({}, 0), Bool(true)),
+ arith::IterMark(arith::IterSumExpr({}, 0), Bool(true))});
+ return res;
+}
+
+/*!
+ * \brief Generate the blockized init block.
+ * \param block The original block with init.
+ * \param inner_block_realize The block realize of the inner block after
blockize.
+ * \param inner_loops The inner loops after blockize.
+ * \return The subtree of the init block and its outer loops.
+ */
+Stmt GenerateBlockizedInit(const Block& block, const BlockRealize&
inner_block_realize,
+ const std::vector<const ForNode*>& inner_loops) {
+ Array<IterVar> init_block_iters;
+ Array<PrimExpr> init_bindings;
+ const Block& inner_block = inner_block_realize->block;
+
+ // Step 1: Collect data-parallel block iters
+ for (size_t i = 0; i < inner_block->iter_vars.size(); i++) {
+ const IterVar& iter_var = inner_block->iter_vars[i];
+ const PrimExpr& binding = inner_block_realize->iter_values[i];
+ if (iter_var->iter_type == IterVarType::kDataPar &&
+ UsesVar(block->init.value(),
+ [tgt_var = iter_var->var.get()](const VarNode* var) { return
var == tgt_var; })) {
+ init_block_iters.push_back(iter_var);
+ init_bindings.push_back(binding);
+ }
+ }
+
+ // Step 2: Collect loops related to iters of the init block
+ std::vector<const ForNode*> init_loops;
+ for (const ForNode* inner_loop : inner_loops) {
+ for (const PrimExpr& init_binding : init_bindings) {
+ if (UsesVar(init_binding, [tgt_var = inner_loop->loop_var.get()](const
VarNode* var) {
+ return var == tgt_var;
+ })) {
+ init_loops.push_back(inner_loop);
+ break;
+ }
+ }
+ }
+
+ // Step 3: Create new block iters for the init block
+ Map<Var, PrimExpr> subst_map;
+ for (size_t i = 0; i < init_block_iters.size(); i++) {
+ IterVar new_iter_var = init_block_iters[i];
+ Var old_var = new_iter_var->var;
+ Var new_var = old_var.copy_with_suffix("_init");
+ new_iter_var.CopyOnWrite()->var = new_var;
+ subst_map.Set(old_var, new_var);
+ init_block_iters.Set(i, std::move(new_iter_var));
+ }
+
+ // Step 4: Generate loop nests and the init block
+ Stmt new_init = BlockRealize(
+ /*iter_values=*/init_bindings,
+ /*predicate=*/inner_block_realize->predicate,
+ /*block=*/
+ Block{/*iter_vars=*/init_block_iters,
+ /*reads=*/{},
+ /*writes=*/block->writes,
+ /*name_hint=*/block->name_hint + "_init",
+ /*body=*/block->init.value(),
+ /*init=*/NullOpt});
+
+ // Step 5: Generate the parent loops for the init block
+ for (const ForNode* init_loop : init_loops) {
+ ObjectPtr<ForNode> new_loop = make_object<ForNode>(*init_loop);
+ new_loop->loop_var = init_loop->loop_var.copy_with_suffix("");
+ subst_map.Set(init_loop->loop_var, new_loop->loop_var);
+ new_loop->body = std::move(new_init);
+ new_init = For(new_loop);
+ }
+
+ // Step 6: Substitute with new loop variables and block iters to prevent
duplication of
+ // variables in the outer block.
+ new_init = Substitute(new_init, subst_map);
+
+ return new_init;
+}
+
+/*!
+ * \brief A helper to collect the parent loops of the block. The loops are
divided into two groups,
+ * 'outer_loops', and 'inner_loops', by a specified loop as the separator.
'outer_loops' are the
+ * ancestor loops of the separator loop. 'inner_loops' include the separator
loop itself, and its
+ * successor loops. It is possible that 'outer_loops' is empty.
+ */
+class LoopSubspaceCollector {
+ public:
+ /*!
+ * \brief Collect the parent loops of the block and store the result in the
corresponding fields.
+ * \param block_sref The sref to the target block.
+ * \param loop_sref The sref to the separator loop. The loop itself is
counted as an inner loop.
+ */
+ void Collect(const StmtSRef& block_sref, const StmtSRef& loop_sref) {
+ bool inner = true;
+ for (StmtSRefNode* current_sref = block_sref->parent;
+ current_sref && current_sref->stmt->IsInstance<ForNode>();
+ current_sref = current_sref->parent) {
+ const auto* current_loop = current_sref->StmtAs<ForNode>();
+ ICHECK(current_loop);
+ if (inner) {
+ inner_loops.push_back(current_loop);
+ inner_loop_vars.push_back(current_loop->loop_var);
+ } else {
+ outer_loops.push_back(current_loop);
+ outer_loop_vars.push_back(current_loop->loop_var);
+ }
+ loop_var_domain.Set(current_loop->loop_var,
+ Range::FromMinExtent(current_loop->min,
current_loop->extent));
+ if (current_sref == loop_sref.get()) inner = false;
+ }
+ }
+ /*! \brief Outer loops which are ancestors of the separator. */
+ std::vector<const ForNode*> outer_loops;
+ /*! \brief Inner loops which are the separator itself or its successors. */
+ std::vector<const ForNode*> inner_loops;
+ /*! \brief Loop variables of the outer loops. */
+ Array<Var> outer_loop_vars;
+ /*! \brief Loop variables of the inner loops. */
+ Array<Var> inner_loop_vars;
+ /*! \brief Domain of the loop variables. */
+ Map<Var, Range> loop_var_domain;
+};
+
+/*!
+ * \brief Check the bindings of the block iters can be divided by a subspace
collected by the
+ * collector.
+ * \param mod The current IR module.
+ * \param block_realize The block realize to be checked.
+ * \param collector The collector which has collected the loops of the block.
+ * \param analyzer The arithmetic analyzer.
+ * \return The result of the subspace division.
+ * \throws ScheduleError If the bindings are not divisible by the subspace.
+ */
+Array<Array<arith::IterMark>> CheckSubspaceDivisible(const IRModule& mod,
+ const BlockRealize&
block_realize,
+ const
LoopSubspaceCollector& collector,
+ arith::Analyzer*
analyzer) {
+ const Block& block = block_realize->block;
+ DiagnosticContext diag_ctx(DiagnosticContext::Default(mod));
+
+ Array<Array<arith::IterMark>> division =
+ arith::SubspaceDivide(block_realize->iter_values,
collector.loop_var_domain,
+ collector.inner_loop_vars,
block_realize->predicate,
+ /*require_bijective=*/false, analyzer, diag_ctx);
+
+ if (division.empty()) {
+ // If we can't do perfect subspace division, check if it is a trivial case
of subspace division.
+ // In this case, we can still blockize.
+ division = TrivialSubspaceDivision(block->iter_vars,
block_realize->iter_values,
+ collector.outer_loop_vars,
collector.inner_loop_vars,
+ block_realize->predicate);
+ }
+ if (division.empty()) {
+ throw SubspaceNotDivisibleError(mod,
GetRef<For>(collector.inner_loops.back()), block);
+ }
+ return division;
+}
+
+/*!
+ * \brief The binding extractor to compute the bindings of the outer and the
inner blocks after
+ * blockize.
+ */
+class BlockizedBindingExtractor {
+ public:
+ /*!
+ * \brief Extract bindings for blockize.
+ * \param iter_vars The iter vars of the original inner block.
+ * \param division The result of the subspace division.
+ */
+ void ExtractBindings(const Array<IterVar>& iter_vars,
+ const Array<Array<arith::IterMark>>& division,
arith::Analyzer* analyzer) {
+ ICHECK_EQ(iter_vars.size() + 1, division.size());
+ for (size_t i = 0; i < iter_vars.size(); ++i) {
+ const IterVar& iter_var = iter_vars[i];
+ arith::IterMark outer_mark = division[i][0];
+ arith::IterMark inner_mark = division[i][1];
+ const auto* outer_binding =
+ TVM_TYPE_AS(outer_binding, outer_mark->source,
arith::IterMapExprNode);
+ const auto* inner_binding =
+ TVM_TYPE_AS(inner_binding, inner_mark->source,
arith::IterMapExprNode);
+
+ // After computing the subspace division, bindings[i] can be written as
+ // outer_binding * inner_binding->extent + inner_binding
+ // The outer block will have binding: iter_outer -> outer_binding
+ // The inner block will have binding: iter_inner -> inner_binding
+ // The iter in the original block will be substituted with base +
iter_inner where
+ // base == iter_outer * iter_inner_extent
+
+ if (is_one(division[i][1]->extent)) { // IsOuter
+ // extract this iter var to outer block directly
+ outer_bindings.push_back(
+
arith::NormalizeIterMapToExpr(GetRef<arith::IterMapExpr>(outer_binding)));
+ outer_iter_vars.push_back(iter_var);
+ } else {
+ // create iter var for the outer block
+ const IterVar outer_var(/*dom=*/Range::FromMinExtent(0,
division[i][0]->extent),
+ /*var=*/iter_var->var.copy_with_suffix("_o"),
+ /*iter_type=*/iter_var->iter_type);
+ outer_bindings.push_back(
+
arith::NormalizeIterMapToExpr(GetRef<arith::IterMapExpr>(outer_binding)));
+ outer_iter_vars.push_back(outer_var);
+ PrimExpr base = is_one(division[i][0]->extent) ? 0 : outer_var *
division[i][1]->extent;
+ // create iter var for the inner block
+ IterVar new_iter = iter_var;
+ auto* new_iter_node = new_iter.CopyOnWrite();
+ new_iter_node->dom = Range::FromMinExtent(0, division[i][1]->extent);
+ inner_iter_dom_map.Set(new_iter->var,
arith::IntSet::FromRange(new_iter->dom));
+ analyzer->Bind(new_iter->var, new_iter->dom);
+ inner_iter_vars.push_back(new_iter);
+ inner_bindings.push_back(
+
arith::NormalizeIterMapToExpr(GetRef<arith::IterMapExpr>(inner_binding)));
+ inner_iter_subst_map.Set(iter_var->var, base + new_iter->var);
+ }
+ }
+ }
+ Map<Var, PrimExpr> inner_iter_subst_map;
+ /*! \brief Iters of the outer block. */
+ Array<IterVar> outer_iter_vars;
+ /*! \brief Iters of the inner block. */
+ Array<IterVar> inner_iter_vars;
+ /*! \brief Binding values of the outer block. */
+ Array<PrimExpr> outer_bindings;
+ /*! \brief Binding values of the inner block. */
+ Array<PrimExpr> inner_bindings;
+ /*! \brief The domain of the inner block iters. */
+ Map<Var, arith::IntSet> inner_iter_dom_map;
+};
+
+/*!
+ * \brief Replacer for the inner block after blockize. Inner block iters will
be replaced with
+ * base + inner_iter and the expressions after substitution will be simplified
if possible.
+ */
+class InnerIterReplacer : public StmtExprMutator {
+ public:
+ /*!
+ * \brief The constructor
+ * \param subst_map The substitution map of the inner block iters.
+ * \param analyzer The arithmetic analyzer.
+ * \param block_sref_reuse The map to save the block reuse information.
+ */
+ InnerIterReplacer(Map<Var, PrimExpr> subst_map, arith::Analyzer* analyzer,
+ Map<Block, Block>* block_sref_reuse)
+ : subst_map_(std::move(subst_map)),
+ analyzer_(analyzer),
+ block_sref_reuse_(block_sref_reuse) {}
+
+ PrimExpr VisitExpr_(const VarNode* op) final {
+ auto it = subst_map_.find(GetRef<Var>(op));
+ if (it != subst_map_.end()) {
+ return (*it).second;
+ }
+ return StmtExprMutator::VisitExpr_(op);
+ }
+
+ PrimExpr VisitExpr(const PrimExpr& op) final {
+ PrimExpr result = StmtExprMutator::VisitExpr(op);
+ if (!result.same_as(op)) {
+ return analyzer_->Simplify(result);
+ }
+ return result;
+ }
+
+ Stmt VisitStmt_(const BlockNode* op) final {
+ Stmt result = StmtExprMutator::VisitStmt_(op);
+ if (!result.same_as(GetRef<Stmt>(op))) {
+ block_sref_reuse_->Set(GetRef<Block>(op), Downcast<Block>(result));
+ }
+ return result;
+ }
+
+ private:
+ Map<Var, PrimExpr> subst_map_;
+ arith::Analyzer* analyzer_;
+ Map<Block, Block>* block_sref_reuse_;
+};
+
+/*!
+ * \brief Compute the access region of the outer block by relaxing the inner
loops.
+ * \param buffer_region The original buffer region.
+ * \param inner_iter_relaxed_range The range of the inner loops.
+ * \return The new buffer region.
+ */
+BufferRegion RelaxBlockizedInnerIters(const BufferRegion& buffer_region,
+ const Map<Var, arith::IntSet>&
inner_iter_relaxed_range) {
+ Array<Range> new_region;
+ new_region.reserve(buffer_region->region.size());
+ Array<arith::IntSet> relaxed_int_set =
+ arith::EvalSet(buffer_region->region, inner_iter_relaxed_range);
+ ICHECK(buffer_region->region.size() == buffer_region->buffer->shape.size());
+ for (size_t i = 0; i < buffer_region->region.size(); i++) {
+ Range max_range = Range::FromMinExtent(0, buffer_region->buffer->shape[i]);
+ new_region.push_back(relaxed_int_set[i].CoverRange(max_range));
+ }
+ return BufferRegion(buffer_region->buffer, std::move(new_region));
+}
+
+/*!
+ * \brief Generate the outer block after blockize.
+ * \param extractor The binding extractor which has extracted the blockized
bindings.
+ * \param block The original inner block.
+ * \param inner_block_realize The block realize of the inner block after
blockize.
+ * \param inner_loops The inner loops after blockize.
+ * \param predicate The outer predicate of the subspace division.
+ * \return The block realize of the outer block after blockize.
+ */
+BlockRealize GenerateBlockizedOuterBlock(const BlockizedBindingExtractor&
extractor,
+ const Block& block, BlockRealize
inner_block_realize,
+ const std::vector<const ForNode*>&
inner_loops,
+ PrimExpr predicate) {
+ // Step 1: Generate the init block if needed
+ Optional<Stmt> new_init = NullOpt;
+ if (block->init.defined()) {
+ new_init = GenerateBlockizedInit(block, inner_block_realize, inner_loops);
+ }
+
+ // Step 2: Compute the access regions of the outer block by relaxing the
inner loops
+ Array<BufferRegion> new_reads = block->reads;
+ Array<BufferRegion> new_writes = block->writes;
+
+ auto f_mutate = [&](const BufferRegion& buffer_region) {
+ return RelaxBlockizedInnerIters(buffer_region,
extractor.inner_iter_dom_map);
+ };
+ new_reads.MutateByApply(f_mutate);
+ new_writes.MutateByApply(f_mutate);
+
+ // Step 3: Generate the body of the outer block. The body of the outer block
is the inner block
+ // realize and its surrounding loops.
+ Stmt outer_block_body = inner_block_realize;
+ for (const ForNode* loop : inner_loops) {
+ ObjectPtr<ForNode> new_loop = make_object<ForNode>(*loop);
+ new_loop->body = std::move(outer_block_body);
+ outer_block_body = For(new_loop);
+ }
+
+ // Step 4: Generate the outer block and block realize.
+ return BlockRealize(/*iter_values=*/std::move(extractor.outer_bindings),
+ /*predicate=*/std::move(predicate),
+ /*block=*/
+
Block(/*iter_vars=*/std::move(extractor.outer_iter_vars), //
+ /*reads=*/std::move(new_reads),
//
+ /*writes=*/std::move(new_writes),
//
+ /*name_hint=*/block->name_hint + "_o",
//
+ /*body=*/std::move(outer_block_body),
//
+ /*init=*/std::move(new_init)));
+}
+
+StmtSRef Blockize(ScheduleState self, const StmtSRef& loop_sref) {
+ const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref);
+ arith::Analyzer analyzer;
+
+ // Step 1: Check the loop has a single child BlockRealize on the sref tree.
+ BlockRealize block_realize = CheckGetSingleChildBlockRealizeOnSRefTree(self,
loop_sref);
+ Block block = block_realize->block;
+ StmtSRef block_sref = self->stmt2ref.at(block.get());
+
+ // Step 2: Collect loops inside and outside loop_sref.
+ LoopSubspaceCollector collector;
+ collector.Collect(block_sref, loop_sref);
+
+ // Step 3: Calculate subspace division for the inner loops.
+ Array<Array<arith::IterMark>> division =
+ CheckSubspaceDivisible(self->mod, block_realize, collector, &analyzer);
+
+ // Step 4: Generate bindings for the outer block and the inner block based
on the result of
+ // the subspace division.
+ BlockizedBindingExtractor extractor;
+ extractor.ExtractBindings(block->iter_vars, division, &analyzer);
+ const PrimExpr& outer_pred = division.back()[0]->extent;
+ const PrimExpr& inner_pred = division.back()[1]->extent;
+
+ // Step 5: Substitute the iter vars in the original block with the inner
iters after the subspace
+ // division
+ Map<Block, Block> block_sref_reuse;
+ InnerIterReplacer replacer(std::move(extractor.inner_iter_subst_map),
&analyzer,
+ &block_sref_reuse);
+ Block new_block = Downcast<Block>(replacer(block));
+
+ // Step 6: Generate the inner block.
+ BlockRealizeNode* inner_block_realize = block_realize.CopyOnWrite();
+ inner_block_realize->iter_values = extractor.inner_bindings;
+ inner_block_realize->predicate = inner_pred;
+ inner_block_realize->block = new_block;
+ BlockNode* inner_block = inner_block_realize->block.CopyOnWrite();
+ inner_block->iter_vars = extractor.inner_iter_vars;
+ inner_block->init = NullOpt;
+ block_sref_reuse.Set(block, inner_block_realize->block);
+
+ // Step 7: Generate the outer block.
+ BlockRealize outer_realize =
+ GenerateBlockizedOuterBlock(extractor, new_block,
GetRef<BlockRealize>(inner_block_realize),
+ collector.inner_loops, outer_pred);
+ // Step 8: Do the actual replacement
+ self->Replace(loop_sref, outer_realize, block_sref_reuse);
+
+ // Step 9: Update the cached flags
+ StmtSRef outer_block_sref = self->stmt2ref.at(outer_realize->block.get());
+ StmtSRef scope_root = tir::GetScopeRoot(self, outer_block_sref,
/*require_stage_pipeline=*/false,
+
/*require_subtree_compact_dataflow=*/false);
+ bool scope_block_affine_binding = self->IsAffineBlockBinding(scope_root);
+ self->UpdateScopeBlockInfo(tir::GetBlockRealize(self, scope_root));
+ self->block_info[scope_root].affine_binding = scope_block_affine_binding;
+ return outer_block_sref;
+}
+
+/*!
+ * \brief Update the map from the buffers in the impl to the buffers in the desc
+ * of the tensor intrinsic.
+ * \param intrinsic The tensor intrinsic.
+ * \param buffer_map The map to be updated.
+ */
+void RemapTensorIntrinBuffers(
+ const TensorIntrin& intrinsic,
+ std::unordered_map<Buffer, Buffer, ObjectPtrHash, ObjectPtrEqual>*
buffer_map) {
+ ICHECK_EQ(intrinsic->desc->params.size(), intrinsic->impl->params.size());
+ for (size_t i = 0; i < intrinsic->desc->params.size(); ++i) {
+ const Var& lhs_var = intrinsic->desc->params[i];
+ const Buffer& lhs_buffer = intrinsic->desc->buffer_map[lhs_var];
+ const Var& rhs_var = intrinsic->impl->params[i];
+ const Buffer& rhs_buffer = intrinsic->impl->buffer_map[rhs_var];
+ (*buffer_map)[rhs_buffer] = lhs_buffer;
+ }
+}
+
+void Tensorize(ScheduleState self, const StmtSRef& block_or_loop_sref,
+ const TensorIntrin& intrinsic) {
+ /*!
+ * Check:
+ * - Check buffer binding, including type, alignment, shape and etc.
+ * - Check the sub AST is equal to the desc function.
+ *
+ * Mutate:
+ * - Blockize the sub AST (please refer blockize for details)
+ * - Bind buffers
+ * - Mutate the impl of the tensor intrinsic by replacing its buffers with
new
+ * buffers created via match buffer region.
+ * - Replace the sub tree with the mutated function.
+ */
+ const BlockRealize& desc_block_realize =
Downcast<BlockRealize>(intrinsic->desc->body);
+ const BlockRealize& impl_block_realize =
Downcast<BlockRealize>(intrinsic->impl->body);
+ Block impl_block = impl_block_realize->block;
+
+ // Step 1: Blockize the subtree rooted at the given loop if needed
+ StmtSRef block_sref{nullptr};
+ if (block_or_loop_sref->StmtAs<ForNode>()) {
+ block_sref = Blockize(self, block_or_loop_sref);
+ } else {
+ ICHECK(block_or_loop_sref->StmtAs<BlockNode>());
+ block_sref = block_or_loop_sref;
+ }
+ const BlockRealize& block_realize = GetBlockRealize(self, block_sref);
+
+ // Step 2: Compare the block with the desc of the tensor intrinsic, find the
correspondence
+ // between buffers in the block and the desc.
+ TensorizeComparator comparator(self->mod, /*assert_mode=*/true);
+ comparator.VisitStmt(block_realize, desc_block_realize);
+
+ // Step 3: Find the correspondence between buffers in the current AST and
the impl of
+ // the tensor intrinsic
+ // Step 3.1: Map from intrinsic func buffer to desc func buffer
+ std::unordered_map<Buffer, Buffer, ObjectPtrHash, ObjectPtrEqual>
intrin_buffer_map;
+ RemapTensorIntrinBuffers(intrinsic, &intrin_buffer_map);
+ // Step 3.2: Map from intrinsic func buffer to current AST buffer
+ std::unordered_map<Buffer, Buffer, ObjectPtrHash, ObjectPtrEqual> buffer_map;
+ for (const auto& pair : intrin_buffer_map) {
+ auto it = comparator.rhs_buffer_map_.find(pair.second);
+ ICHECK(it != comparator.rhs_buffer_map_.end()) << pair.second;
+ buffer_map[pair.first] = it->second;
+ }
+
+ // Step 4: Create MatchBufferRegion for the params of the impl function of
the tensor
+ // intrin to make them subregions of the buffer in the original IR.
+ std::unordered_map<Buffer, Array<Range>, ObjectPtrHash, ObjectPtrEqual>
buffer_region_map;
+ for (const BufferRegion& read : impl_block->reads) {
+ buffer_region_map.emplace(read->buffer, read->region);
+ }
+ for (const BufferRegion& write : impl_block->writes) {
+ buffer_region_map.emplace(write->buffer, write->region);
+ }
+ Array<MatchBufferRegion> match_buffer_regions;
+ match_buffer_regions.reserve(intrinsic->impl->params.size());
+ for (size_t i = 0; i < intrinsic->impl->params.size(); ++i) {
+ const auto& param = intrinsic->impl->params[i];
+ const auto& buffer = intrinsic->impl->buffer_map.at(param);
+ const auto& source = buffer_map.at(buffer);
+ // add the detected base indices to each buffer access region of the
tensor intrinsic
+ Region old_region = buffer_region_map.at(buffer);
+ const auto& indices_base = comparator.buffer_indices_.at(source);
+ int offset = static_cast<int>(indices_base.size()) -
static_cast<int>(old_region.size());
+ ICHECK(offset >= 0);
+ Region new_region;
+ new_region.reserve(source->shape.size());
+ for (int i = 0; i < offset; i++) {
+ new_region.push_back(Range::FromMinExtent(indices_base[i], 1));
+ }
+ for (int i = 0; i < static_cast<int>(old_region.size()); i++) {
+ new_region.push_back(Range::FromMinExtent(indices_base[i + offset],
old_region[i]->extent));
+ }
+ match_buffer_regions.push_back(MatchBufferRegion(buffer,
BufferRegion(source, new_region)));
+ }
+
+ // Step 5: Replace the subtree in the original IR with the tensor intrin
impl.
+ ObjectPtr<BlockNode> new_block_ptr =
make_object<BlockNode>(*block_realize->block.get());
+ new_block_ptr->body = impl_block->body;
+ ICHECK(new_block_ptr->match_buffers.empty());
+ new_block_ptr->match_buffers = std::move(match_buffer_regions);
+ Block new_block(new_block_ptr);
+
+ self->Replace(block_sref, new_block, {{block_realize->block, new_block}});
+
+ // Step 6: Update the cached flags.
+ StmtSRef scope_root = tir::GetScopeRoot(self, block_sref,
/*require_stage_pipeline=*/false,
+
/*require_subtree_compact_dataflow=*/false);
+ self->UpdateScopeBlockInfo(static_cast<const
BlockNode*>(scope_root->stmt)->body);
+}
+
+/******** InstructionKind Registration ********/
+
+struct BlockizeTraits : public UnpackedInstTraits<BlockizeTraits> {
+ static constexpr const char* kName = "Blockize";
+ static constexpr bool kIsPure = false;
+
+ private:
+ static constexpr size_t kNumInputs = 1;
+ static constexpr size_t kNumAttrs = 0;
+ static constexpr size_t kNumDecisions = 0;
+
+ static BlockRV UnpackedApplyToSchedule(Schedule sch, LoopRV loop_rv) {
+ return sch->Blockize(loop_rv);
+ }
+
+ static String UnpackedAsPython(Array<String> outputs, String loop_rv) {
+ PythonAPICall py("blockize");
+ py.Input("loop", loop_rv);
+ py.SingleOutput(outputs);
+ return py.Str();
+ }
+
+ template <typename>
+ friend struct ::tvm::tir::UnpackedInstTraits;
+};
+
+struct TensorizeTraits : public UnpackedInstTraits<TensorizeTraits> {
+ static constexpr const char* kName = "Tensorize";
+ static constexpr bool kIsPure = false;
+
+ private:
+ static constexpr size_t kNumInputs = 1;
+ static constexpr size_t kNumAttrs = 1;
+ static constexpr size_t kNumDecisions = 0;
+
+ static void UnpackedApplyToSchedule(Schedule sch, ObjectRef
block_or_loop_rv, String intrin) {
+ if (const auto* block = block_or_loop_rv.as<BlockRVNode>()) {
+ sch->Tensorize(GetRef<BlockRV>(block), intrin);
+ } else if (const auto* loop = block_or_loop_rv.as<LoopRVNode>()) {
+ sch->Tensorize(GetRef<LoopRV>(loop), intrin);
+ } else {
+ LOG(FATAL) << "TypeError: Expected Block or Loop, but gets: "
+ << block_or_loop_rv->GetTypeKey();
+ }
+ }
+
+ static String UnpackedAsPython(Array<String> outputs, String
block_or_loop_rv, String intrin) {
+ PythonAPICall py("tensorize");
+ py.Input("block_or_loop", block_or_loop_rv);
+ py.Input("intrin", intrin);
+ return py.Str();
+ }
+
+ template <typename>
+ friend struct ::tvm::tir::UnpackedInstTraits;
+};
+
+TVM_REGISTER_INST_KIND_TRAITS(BlockizeTraits);
+TVM_REGISTER_INST_KIND_TRAITS(TensorizeTraits);
+
+} // namespace tir
+} // namespace tvm
diff --git a/src/tir/schedule/schedule.cc b/src/tir/schedule/schedule.cc
index 6e33862..b466843 100644
--- a/src/tir/schedule/schedule.cc
+++ b/src/tir/schedule/schedule.cc
@@ -185,6 +185,20 @@ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleStorageAlign")
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleSetScope")
.set_body_method<Schedule>(&ScheduleNode::SetScope);
/******** (FFI) Blockize & Tensorize ********/
// FFI binding for ScheduleNode::Blockize.
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleBlockize")
    .set_body_method<Schedule>(&ScheduleNode::Blockize);
// FFI binding for ScheduleNode::Tensorize. The packed-func boundary erases the
// RV's static type, so dispatch on BlockRV vs LoopRV here.
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleTensorize")
    .set_body_typed([](Schedule self, ObjectRef rv, String intrin) {
      if (const auto* block_rv = rv.as<BlockRVNode>()) {
        self->Tensorize(GetRef<BlockRV>(block_rv), intrin);
      } else if (const auto* loop_rv = rv.as<LoopRVNode>()) {
        self->Tensorize(GetRef<LoopRV>(loop_rv), intrin);
      } else {
        LOG(FATAL) << "TypeError: Cannot evaluate the random variable of type: " << rv->GetTypeKey()
                   << ". Its value is: " << rv;
      }
    });
+
/******** (FFI) Annotation ********/
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleAnnotate")
.set_body_typed([](Schedule self, ObjectRef rv, const String& ann_key,
diff --git a/src/tir/schedule/state.cc b/src/tir/schedule/state.cc
index 04b7dd5..3a37f81 100644
--- a/src/tir/schedule/state.cc
+++ b/src/tir/schedule/state.cc
@@ -201,9 +201,7 @@ class BlockInfoCollector : private StmtVisitor {
bool is_root_block = srefs_.empty();
// Calculate `BlockInfo::scope`
Array<StmtSRef> child_block_srefs = std::move(block_frames_.back());
- BlockInfo& info =
- self_->block_info.emplace(scope_root,
BlockInfo(BlockScope(child_block_srefs)))
- .first->second;
+ BlockInfo& info = self_->block_info[scope_root] =
BlockInfo(BlockScope(child_block_srefs));
// Set `affine_binding`
if (is_root_block) {
// If the block doesn't have outer loops and BlockRealize,
diff --git a/src/tir/schedule/traced_schedule.cc
b/src/tir/schedule/traced_schedule.cc
index da7a264..1e2e57e 100644
--- a/src/tir/schedule/traced_schedule.cc
+++ b/src/tir/schedule/traced_schedule.cc
@@ -356,6 +356,37 @@ void TracedScheduleNode::SetScope(const BlockRV& block_rv,
int buffer_index,
/******** Schedule: Blockize & Tensorize ********/
// Traced Blockize: perform the concrete transformation, then append a
// "Blockize" instruction (input: loop, output: new block) to the trace.
BlockRV TracedScheduleNode::Blockize(const LoopRV& loop_rv) {
  BlockRV new_block = ConcreteScheduleNode::Blockize(loop_rv);
  static const InstructionKind& kind = InstructionKind::Get("Blockize");
  trace_->Append(/*inst=*/Instruction(
      /*kind=*/kind,
      /*inputs=*/{loop_rv},
      /*attrs=*/{},
      /*outputs=*/{new_block}));
  return new_block;
}
+
// Traced Tensorize (loop overload): apply concretely, then record the
// "Tensorize" instruction with the intrinsic name carried as an attribute.
void TracedScheduleNode::Tensorize(const LoopRV& loop_rv, const String& intrin) {
  ConcreteScheduleNode::Tensorize(loop_rv, intrin);
  static const InstructionKind& kind = InstructionKind::Get("Tensorize");
  trace_->Append(/*inst=*/Instruction(
      /*kind=*/kind,
      /*inputs=*/{loop_rv},
      /*attrs=*/{intrin},
      /*outputs=*/{}));
}
+
// Traced Tensorize (block overload): same recording scheme as the loop
// overload; both replay through TensorizeTraits' RV-type dispatch.
void TracedScheduleNode::Tensorize(const BlockRV& block_rv, const String& intrin) {
  ConcreteScheduleNode::Tensorize(block_rv, intrin);
  static const InstructionKind& kind = InstructionKind::Get("Tensorize");
  trace_->Append(/*inst=*/Instruction(
      /*kind=*/kind,
      /*inputs=*/{block_rv},
      /*attrs=*/{intrin},
      /*outputs=*/{}));
}
+
/******** Schedule: Annotation ********/
void TracedScheduleNode::Annotate(const LoopRV& loop_rv, const String& ann_key,
diff --git a/src/tir/schedule/traced_schedule.h
b/src/tir/schedule/traced_schedule.h
index b35f1b6..3a88e86 100644
--- a/src/tir/schedule/traced_schedule.h
+++ b/src/tir/schedule/traced_schedule.h
@@ -87,6 +87,9 @@ class TracedScheduleNode : public ConcreteScheduleNode {
int offset) final;
void SetScope(const BlockRV& block_rv, int buffer_index, const String&
storage_scope) final;
/******** Schedule: Blockize & Tensorize ********/
+ BlockRV Blockize(const LoopRV& loop_rv) final;
+ void Tensorize(const BlockRV& block_rv, const String& intrin) final;
+ void Tensorize(const LoopRV& loop_rv, const String& intrin) final;
/******** Schedule: Annotation ********/
void Annotate(const LoopRV& loop_rv, const String& ann_key, const ObjectRef&
ann_val) override;
void Unannotate(const LoopRV& loop_rv, const String& ann_key) override;
diff --git a/tests/python/unittest/test_tir_schedule_blockize.py
b/tests/python/unittest/test_tir_schedule_blockize.py
new file mode 100644
index 0000000..b4a16a8
--- /dev/null
+++ b/tests/python/unittest/test_tir_schedule_blockize.py
@@ -0,0 +1,210 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-function-docstring,missing-module-docstring
+import sys
+import pytest
+import tvm
+from tvm.script import tir as T
+from tvm import tir
+from tvm.tir.schedule.testing import verify_trace_roundtrip
+
+# fmt: off
+# pylint:
disable=no-member,invalid-name,unused-variable,line-too-long,redefined-outer-name,unexpected-keyword-arg,too-many-nested-blocks
+
# Input workload: elementwise multiply-by-2 over a 128x128 buffer.
@T.prim_func
def single_elementwise(A: T.Buffer[(128, 128), "float32"], B: T.Buffer[(128, 128), "float32"]):
    for i, j in T.grid(128, 128):
        with T.block("B"):
            vi, vj = T.axis.remap("SS", [i, j])
            B[vi, vj] = A[vi, vj] * 2.0
+
+
# Expected IR after blockizing the OUTER loop of `single_elementwise`:
# both loops move inside one outer block with trivial (extent-1) bindings.
@T.prim_func
def single_elementwise_blockized1(
    A: T.Buffer[(128, 128), "float32"], B: T.Buffer[(128, 128), "float32"]
) -> None:
    with T.block("blockized_B"):
        vio = T.axis.spatial(1, 0)
        vjo = T.axis.spatial(1, 0)
        T.reads(A[0:128, 0:128])
        T.writes(B[0:128, 0:128])
        for i, j in T.grid(128, 128):
            with T.block("B"):
                vi, vj = T.axis.remap("SS", [i, j])
                T.reads(A[vi, vj])
                T.writes(B[vi, vj])
                B[vi, vj] = A[vi, vj] * T.float32(2)
+
+
# Expected IR after blockizing the INNER loop of `single_elementwise`:
# only the j loop is wrapped; i stays outside the new block.
@T.prim_func
def single_elementwise_blockized2(
    A: T.Buffer[(128, 128), "float32"], B: T.Buffer[(128, 128), "float32"]
) -> None:
    for i in T.serial(128):
        with T.block("blockized_B"):
            vi = T.axis.spatial(128, i)
            vjo = T.axis.spatial(1, 0)
            T.reads(A[vi, 0:128])
            T.writes(B[vi, 0:128])
            for j in T.serial(128):
                with T.block("B"):
                    vj = T.axis.remap("S", [j])
                    T.reads(A[vi, vj])
                    T.writes(B[vi, vj])
                    B[vi, vj] = A[vi, vj] * T.float32(2)
+
+
# Input workload: two chained elementwise ops (B = A * 2, C = B + 1)
# with an intermediate buffer B, used for the compute_at tests.
@T.prim_func
def two_elementwise(A: T.Buffer[(128, 128), "float32"], C: T.Buffer[(128, 128), "float32"]) -> None:
    B = T.alloc_buffer([128, 128], dtype="float32")
    for i, j in T.grid(128, 128):
        with T.block("B"):
            vi, vj = T.axis.remap("SS", [i, j])
            T.reads(A[vi, vj])
            T.writes(B[vi, vj])
            B[vi, vj] = A[vi, vj] * T.float32(2)
    for i, j in T.grid(128, 128):
        with T.block("C"):
            vi, vj = T.axis.remap("SS", [i, j])
            T.reads(B[vi, vj])
            T.writes(C[vi, vj])
            C[vi, vj] = B[vi, vj] + T.float32(1)
+
+
# Expected IR: both producer and consumer blockized into 16x16 tiles and
# placed under the same (i_0, j_0) tile loops via (reverse_)compute_at.
@T.prim_func
def two_elementwise_blockized(
    A: T.Buffer[(128, 128), "float32"],
    C: T.Buffer[(128, 128), "float32"]
) -> None:
    B = T.alloc_buffer([128, 128], dtype="float32")
    for i_0, j_0 in T.grid(8, 8):
        with T.block("blockized_B"):
            vio, vjo = T.axis.remap("SS", [i_0, j_0])
            T.reads(A[vio * 16 : vio * 16 + 16, vjo * 16 : vjo * 16 + 16])
            T.writes(B[vio * 16 : vio * 16 + 16, vjo * 16 : vjo * 16 + 16])
            for i_1, j_1 in T.grid(16, 16):
                with T.block("B"):
                    vi, vj = T.axis.remap("SS", [i_1, j_1])
                    T.reads(A[vio * 16 + vi, vjo * 16 + vj])
                    T.writes(B[vio * 16 + vi, vjo * 16 + vj])
                    B[vio * 16 + vi, vjo * 16 + vj] = A[vio * 16 + vi, vjo * 16 + vj] * T.float32(2)
        with T.block("blockized_C"):
            vio, vjo = T.axis.remap("SS", [i_0, j_0])
            T.reads(B[vio * 16 : vio * 16 + 16, vjo * 16 : vjo * 16 + 16])
            T.writes(C[vio * 16 : vio * 16 + 16, vjo * 16 : vjo * 16 + 16])
            for ax0, ax1 in T.grid(16, 16):
                with T.block("C"):
                    vi, vj = T.axis.remap("SS", [ax0, ax1])
                    T.reads(B[vio * 16 + vi, vjo * 16 + vj])
                    T.writes(C[vio * 16 + vi, vjo * 16 + vj])
                    C[vio * 16 + vi, vjo * 16 + vj] = B[vio * 16 + vi, vjo * 16 + vj] + T.float32(1)
+
+
# Input workload: row-wise reduction with a T.init() block (B[i] = sum_k A[i, k]).
@T.prim_func
def rowsum(A: T.Buffer[(128, 128), "float32"], B: T.Buffer[(128,), "float32"]) -> None:
    for k, i in T.grid(128, 128):
        with T.block("B"):
            vk, vi = T.axis.remap("RS", [k, i])
            with T.init():
                B[vi] = 0.0
            B[vi] = B[vi] + A[vi, vk]
+
+
# Expected IR after blockizing the reduction loop of `rowsum`: the init
# statement is hoisted into its own "B_init" block under the outer T.init().
@T.prim_func
def rowsum_blockized(A: T.Buffer[(128, 128), "float32"], B: T.Buffer[(128,), "float32"]) -> None:
    with T.block("blockized_B"):
        vko = T.axis.R(1, 0)
        vio = T.axis.S(1, 0)
        with T.init():
            for i1 in T.serial(0, 128):
                with T.block("B_init"):
                    vi_init = T.axis.S(128, i1)
                    B[vi_init] = T.float32(0)
        for i0, i1_1 in T.grid(128, 128):
            with T.block("B"):
                vk, vi = T.axis.remap("RS", [i0, i1_1])
                B[vi] = B[vi] + A[vi, vk]
+
+
+# fmt: off
+# pylint:
disable=no-member,invalid-name,unused-variable,line-too-long,redefined-outer-name,unexpected-keyword-arg,too-many-nested-blocks
+
def test_blockize_outer():
    """Blockizing the outermost loop wraps both loops into a single new block.

    Fix: removed a leftover debug `print(s.mod['main'].script())` statement
    that spammed test output; also replaced the unused inner-loop variable
    with `_` to reflect that only the outer loop is used.
    """
    func = single_elementwise
    # schedule
    s = tir.Schedule(func, debug_mask="all")
    B = s.get_block("B")
    x, _ = s.get_loops(B)
    s.blockize(x)
    tvm.ir.assert_structural_equal(s.mod["main"], single_elementwise_blockized1)
    # The trace must replay to the same IRModule.
    verify_trace_roundtrip(sch=s, mod=func)
+
+
def test_blockize_inner():
    """Blockizing the innermost loop keeps the outer loop outside the new block."""
    sch = tir.Schedule(single_elementwise, debug_mask="all")
    _, inner = sch.get_loops(sch.get_block("B"))
    sch.blockize(inner)
    tvm.ir.assert_structural_equal(sch.mod["main"], single_elementwise_blockized2)
    verify_trace_roundtrip(sch=sch, mod=single_elementwise)
+
+
def test_two_elementwise_blockize_reverse_compute_at():
    """Blockize the producer, reverse_compute_at the consumer, then blockize it too."""
    sch = tir.Schedule(two_elementwise, debug_mask="all")
    block_b = sch.get_block("B")
    block_c = sch.get_block("C")
    loop_i, loop_j = sch.get_loops(block_b)
    i_outer, i_inner = sch.split(loop_i, factors=[None, 16])
    j_outer, j_inner = sch.split(loop_j, factors=[None, 16])
    sch.reorder(i_outer, j_outer, i_inner, j_inner)
    sch.blockize(i_inner)
    sch.reverse_compute_at(block_c, j_outer)
    # Blockize the consumer at its newly created 16x16 tile loops.
    sch.blockize(sch.get_loops(block_c)[-2])
    tvm.ir.assert_structural_equal(sch.mod["main"], two_elementwise_blockized)
    verify_trace_roundtrip(sch=sch, mod=two_elementwise)
+
+
def test_two_elementwise_blockize_compute_at():
    """Mirror of the reverse_compute_at test: tile the consumer, compute_at the producer."""
    sch = tir.Schedule(two_elementwise, debug_mask="all")
    block_b = sch.get_block("B")
    block_c = sch.get_block("C")
    loop_i, loop_j = sch.get_loops(block_c)
    i_outer, i_inner = sch.split(loop_i, factors=[None, 16])
    j_outer, j_inner = sch.split(loop_j, factors=[None, 16])
    sch.reorder(i_outer, j_outer, i_inner, j_inner)
    sch.blockize(i_inner)
    sch.compute_at(block_b, j_outer)
    # Blockize the producer at its newly created 16x16 tile loops.
    sch.blockize(sch.get_loops(block_b)[-2])
    tvm.ir.assert_structural_equal(sch.mod["main"], two_elementwise_blockized)
    verify_trace_roundtrip(sch=sch, mod=two_elementwise)
+
+
def test_blockize_init_loops():
    """Blockizing a reduction loop must hoist the init statement into its own block."""
    sch = tir.Schedule(rowsum, debug_mask="all")
    loop_k, _ = sch.get_loops(sch.get_block("B"))
    sch.blockize(loop_k)
    tvm.ir.assert_structural_equal(sch.mod["main"], rowsum_blockized)
    verify_trace_roundtrip(sch=sch, mod=rowsum)
+
+
+if __name__ == "__main__":
+ sys.exit(pytest.main([__file__] + sys.argv[1:]))
diff --git a/tests/python/unittest/test_tir_schedule_tensorize.py
b/tests/python/unittest/test_tir_schedule_tensorize.py
new file mode 100644
index 0000000..401a39f
--- /dev/null
+++ b/tests/python/unittest/test_tir_schedule_tensorize.py
@@ -0,0 +1,431 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-function-docstring,missing-module-docstring
+import sys
+import pytest
+import tvm
+import tvm.testing
+from tvm import tir
+from tvm.script import tir as T
+from tvm.tir.schedule.testing import verify_trace_roundtrip
+
+# fmt: off
+# pylint:
disable=no-member,invalid-name,unused-variable,line-too-long,redefined-outer-name,unexpected-keyword-arg,too-many-nested-blocks
+
# Tensor-intrin DESCRIPTION: semantics of a 16x16x16 matmul-accumulate
# (C += A * B^T) that the tensorizer pattern-matches against.
@T.prim_func
def mma_desc(a: T.handle, b: T.handle, c: T.handle) -> None:
    A = T.match_buffer(a, (16, 16), align=128, offset_factor=1)
    B = T.match_buffer(b, (16, 16), align=128, offset_factor=1)
    C = T.match_buffer(c, (16, 16), align=128, offset_factor=1)

    with T.block("root"):
        T.reads(C[0 : 16, 0 : 16], A[0 : 16, 0 : 16], B[0 : 16, 0 : 16])
        T.writes(C[0 : 16, 0 : 16])
        for i, j, k in T.grid(16, 16, 16):
            with T.block("update"):
                vii, vjj, vkk = T.axis.remap("SSR", [i, j, k])
                C[vii, vjj] = C[vii, vjj] + A[vii, vkk] * B[vjj, vkk]
+
+
# Tensor-intrin IMPLEMENTATION: the lowered body substituted for a matched
# region — a single tvm_mma_sync call on 16x16 (256-element) fragments.
@T.prim_func
def mma_intrin(a: T.handle, b: T.handle, c: T.handle) -> None:
    A = T.match_buffer(a, (16, 16), align=128, offset_factor=1)
    B = T.match_buffer(b, (16, 16), align=128, offset_factor=1)
    C = T.match_buffer(c, (16, 16), align=128, offset_factor=1)

    with T.block("root"):
        T.reads(C[0 : 16, 0 : 16], A[0 : 16, 0 : 16], B[0 : 16, 0 : 16])
        T.writes(C[0 : 16, 0 : 16])
        T.evaluate(
            T.tvm_mma_sync(
                C.data,
                C.elem_offset // 256,  # fragment index: offset in units of 16x16 tiles
                A.data,
                A.elem_offset // 256,
                B.data,
                B.elem_offset // 256,
                C.data,
                C.elem_offset // 256,
                dtype="handle",
            )
        )
+
+
# Tensor-intrin DESCRIPTION: 4-element dot product reduced into a scalar C.
@T.prim_func
def dot_product_desc(a: T.handle, b: T.handle, c: T.handle) -> None:
    A = T.match_buffer(a, (4,))
    B = T.match_buffer(b, (4,))
    C = T.match_buffer(c, ())  # zero-rank (scalar) accumulator

    with T.block("root"):
        T.reads(C[()], A[0 : 4], B[0 : 4])
        T.writes(C[()])
        for i in range(0, 4):
            with T.block("update"):
                vi = T.axis.remap("R", [i])
                C[()] = C[()] + A[vi] * B[vi]
+
+
# Tensor-intrin IMPLEMENTATION: replaces the matched region with an external
# "vec4add" call taking raw data pointers and element offsets.
@T.prim_func
def dot_product_intrin(a: T.handle, b: T.handle, c: T.handle) -> None:
    A = T.match_buffer(a, (4,), offset_factor=1)
    B = T.match_buffer(b, (4,), offset_factor=1)
    C = T.match_buffer(c, (), offset_factor=1)

    with T.block("root"):
        T.reads(C[()], A[0 : 4], B[0 : 4])
        T.writes(C[()])
        T.evaluate(
            T.call_extern(
                "vec4add",
                C.data,
                C.elem_offset,
                A.data,
                A.elem_offset,
                B.data,
                B.elem_offset,
                dtype="int32",
            )
        )
+
+
# Tensor-intrin DESCRIPTION: 16x16 outer-product accumulate from two
# 16x1 column vectors (C[i, j] += A[i, 0] * B[j, 0]).
@T.prim_func
def outer_product_desc(a: T.handle, b: T.handle, c: T.handle) -> None:
    A = T.match_buffer(a, (16, 1), offset_factor=1)
    B = T.match_buffer(b, (16, 1), offset_factor=1)
    C = T.match_buffer(c, (16, 16), offset_factor=1)

    with T.block("root"):
        T.reads(
            C[0 : 16, 0 : 16],
            A[0 : 16, 0 : 1],
            B[0 : 16, 0 : 1],
        )
        T.writes(C[0 : 16, 0 : 16])
        for i, j in T.grid(16, 16):
            with T.block("update"):
                vii, vjj = T.axis.remap("SS", [i, j])
                C[vii, vjj] = C[vii, vjj] + A[vii, 0] * B[vjj, 0]
+
+
# Tensor-intrin IMPLEMENTATION: external "outer_product" call with pointers
# and element offsets for the matched 16x16 accumulate region.
@T.prim_func
def outer_product_intrin(a: T.handle, b: T.handle, c: T.handle) -> None:
    A = T.match_buffer(a, (16, 1), offset_factor=1)
    B = T.match_buffer(b, (16, 1), offset_factor=1)
    C = T.match_buffer(c, (16, 16), offset_factor=1)

    with T.block("root"):
        T.reads(
            C[0 : 16, 0 : 16],
            A[0 : 16, 0 : 1],
            B[0 : 16, 0 : 1],
        )
        T.writes(C[0 : 16, 0 : 16])
        T.evaluate(
            T.call_extern(
                "outer_product",
                C.data,
                C.elem_offset,
                A.data,
                A.elem_offset,
                B.data,
                B.elem_offset,
                dtype="int32",
            )
        )
+
+
# Input workload: 128x128x128 matmul (C = A * B^T) with an init via T.init().
@T.prim_func
def matmul(
    A: T.Buffer[(128, 128), "float32"],
    B: T.Buffer[(128, 128), "float32"],
    C: T.Buffer[(128, 128), "float32"],
) -> None:
    for i, j, k in T.grid(128, 128, 128):
        with T.block("update"):
            vi, vj, vk = T.axis.remap("SSR", [i, j, k])
            with T.init():
                C[vi, vj] = T.float32(0)
            C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]
+
+
# Expected IR for test_tensorize_matmul: init decomposed into its own block,
# and the 16x16x16 update tile replaced by the tvm_mma_sync intrinsic body
# (desc buffers rebound to the tile via T.match_buffer).
@T.prim_func
def tensorized_matmul(a: T.handle, b: T.handle, c: T.handle) -> None:
    C = T.match_buffer(c, [128, 128], elem_offset=0, align=128, offset_factor=1)
    B = T.match_buffer(b, [128, 128], elem_offset=0, align=128, offset_factor=1)
    A = T.match_buffer(a, [128, 128], elem_offset=0, align=128, offset_factor=1)

    for i_outer, j_outer in T.grid(8, 8):
        for i_inner_init, j_inner_init in T.grid(16, 16):
            with T.block("init"):
                vi_init = T.axis.S(128, ((i_outer * 16) + i_inner_init))
                vj_init = T.axis.S(128, ((j_outer * 16) + j_inner_init))
                C[vi_init, vj_init] = T.float32(0)
        for k_outer in T.grid(8):
            with T.block("update"):
                vi, vj, vk = T.axis.remap("SSR", [i_outer, j_outer, k_outer])
                T.reads(
                    [
                        C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16],
                        A[vi * 16 : vi * 16 + 16, vk * 16 : vk * 16 + 16],
                        B[vj * 16 : vj * 16 + 16, vk * 16 : vk * 16 + 16],
                    ]
                )
                T.writes(C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16])
                # Symbolic offsets: resolved when the sub-buffers are matched.
                A_elem_offset = T.var("int32")
                B_elem_offset = T.var("int32")
                C_elem_offset = T.var("int32")
                A_sub = T.match_buffer(
                    A[vi * 16 : vi * 16 + 16, vk * 16 : vk * 16 + 16],
                    [16, 16],
                    elem_offset=A_elem_offset,
                )
                B_sub = T.match_buffer(
                    B[vj * 16 : vj * 16 + 16, vk * 16 : vk * 16 + 16],
                    [16, 16],
                    elem_offset=B_elem_offset,
                )
                C_sub = T.match_buffer(
                    C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16],
                    [16, 16],
                    elem_offset=C_elem_offset,
                )
                T.evaluate(
                    T.tvm_mma_sync(
                        C_sub.data,
                        T.floordiv(C_sub.elem_offset, 256),
                        A_sub.data,
                        T.floordiv(A_sub.elem_offset, 256),
                        B_sub.data,
                        T.floordiv(B_sub.elem_offset, 256),
                        C_sub.data,
                        T.floordiv(C_sub.elem_offset, 256),
                        dtype="handle",
                    )
                )
+
+
# Input workload: batched 128x128x128 matmul over 16 batches, with the init
# written as a separate block (no T.init()).
@T.prim_func
def batch_matmul(
    A: T.Buffer[(16, 128, 128), "float32"],
    B: T.Buffer[(16, 128, 128), "float32"],
    C: T.Buffer[(16, 128, 128), "float32"],
) -> None:
    for n, i, j in T.grid(16, 128, 128):
        with T.block("init"):
            vn, vi, vj = T.axis.remap("SSS", [n, i, j])
            C[vn, vi, vj] = T.float32(0)

    for n, i, j, k in T.grid(16, 128, 128, 128):
        with T.block("update"):
            vn, vi, vj, vk = T.axis.remap("SSSR", [n, i, j, k])
            C[vn, vi, vj] = C[vn, vi, vj] + A[vn, vi, vk] * B[vn, vj, vk]
+
+
# Expected IR for test_tensorize_batch_matmul: each 16x16x16 tile of the
# update block is replaced by tvm_mma_sync; the init block is untouched.
@T.prim_func
def tensorized_batch_matmul_mma(
    A: T.Buffer[(16, 128, 128), "float32"],
    B: T.Buffer[(16, 128, 128), "float32"],
    C: T.Buffer[(16, 128, 128), "float32"],
) -> None:
    for n, i, j in T.grid(16, 128, 128):
        with T.block("init"):
            vn, vi, vj = T.axis.remap("SSS", [n, i, j])
            T.reads()
            T.writes(C[vn, vi, vj])
            C[vn, vi, vj] = T.float32(0)
    for n in range(0, 16):
        for i, j, k in T.grid(8, 8, 8):
            with T.block("update"):
                vn, vi, vj, vk = T.axis.remap("SSSR", [n, i, j, k])
                T.reads(
                    C[vn : vn + 1, vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16],
                    A[vn : vn + 1, vi * 16 : vi * 16 + 16, vk * 16 : vk * 16 + 16],
                    B[vn : vn + 1, vj * 16 : vj * 16 + 16, vk * 16 : vk * 16 + 16],
                )
                T.writes(C[vn : vn + 1, vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16])
                # Symbolic offsets: resolved when the sub-buffers are matched.
                A_elem_offset = T.var("int32")
                B_elem_offset = T.var("int32")
                C_elem_offset = T.var("int32")
                A_sub = T.match_buffer(
                    A[vn : vn + 1, vi * 16 : vi * 16 + 16, vk * 16 : vk * 16 + 16],
                    (16, 16),
                    elem_offset=A_elem_offset,
                )
                B_sub = T.match_buffer(
                    B[vn : vn + 1, vj * 16 : vj * 16 + 16, vk * 16 : vk * 16 + 16],
                    (16, 16),
                    elem_offset=B_elem_offset,
                )
                C_sub = T.match_buffer(
                    C[vn : vn + 1, vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16],
                    (16, 16),
                    elem_offset=C_elem_offset,
                )
                T.evaluate(
                    T.tvm_mma_sync(
                        C_sub.data,
                        T.floordiv(C_sub.elem_offset, 256),
                        A_sub.data,
                        T.floordiv(A_sub.elem_offset, 256),
                        B_sub.data,
                        T.floordiv(B_sub.elem_offset, 256),
                        C_sub.data,
                        T.floordiv(C_sub.elem_offset, 256),
                        dtype="handle",
                    )
                )
+
+
# Expected IR for test_tensorize_dot_product: the reduction axis split by 4
# and each 4-element slice replaced by the external "vec4add" call.
@T.prim_func
def tensorized_batch_matmul_dot_product(
    A: T.Buffer[(16, 128, 128), "float32"],
    B: T.Buffer[(16, 128, 128), "float32"],
    C: T.Buffer[(16, 128, 128), "float32"],
) -> None:
    for n, i, j in T.grid(16, 128, 128):
        with T.block("init"):
            vn, vi, vj = T.axis.remap("SSS", [n, i, j])
            T.reads()
            T.writes(C[vn, vi, vj])
            C[vn, vi, vj] = T.float32(0)
    for n, i, j, k_0 in T.grid(16, 128, 128, 32):
        with T.block("blockized_update"):
            vn, vi, vj, vko = T.axis.remap("SSSR", [n, i, j, k_0])
            T.reads(
                C[vn, vi, vj], A[vn, vi, vko * 4 : vko * 4 + 4], B[vn, vj, vko * 4 : vko * 4 + 4]
            )
            T.writes(C[vn, vi, vj])
            A_1 = T.match_buffer(
                A[vn, vi, vko * 4 : vko * 4 + 4], [4], dtype="float32", offset_factor=1
            )
            B_1 = T.match_buffer(
                B[vn, vj, vko * 4 : vko * 4 + 4], [4], dtype="float32", offset_factor=1
            )
            C_1 = T.match_buffer(C[vn, vi, vj], [], dtype="float32", offset_factor=1)
            T.evaluate(
                T.call_extern(
                    "vec4add",
                    C_1.data,
                    C_1.elem_offset,
                    A_1.data,
                    A_1.elem_offset,
                    B_1.data,
                    B_1.elem_offset,
                    dtype="int32",
                )
            )
+
+
# Expected IR for test_tensorize_outer_product: spatial axes tiled 16x16 and
# each tile replaced by the external "outer_product" call; k stays serial.
@T.prim_func
def tensorized_batch_matmul_outer_product(
    A: T.Buffer[(16, 128, 128), "float32"],
    B: T.Buffer[(16, 128, 128), "float32"],
    C: T.Buffer[(16, 128, 128), "float32"],
) -> None:
    for n, i, j in T.grid(16, 128, 128):
        with T.block("init"):
            vn, vi, vj = T.axis.remap("SSS", [n, i, j])
            T.reads()
            T.writes(C[vn, vi, vj])
            C[vn, vi, vj] = T.float32(0)
    for n, i_0, j_0, k in T.grid(16, 8, 8, 128):
        with T.block("blockized_update"):
            vn, vio, vjo, vk = T.axis.remap("SSSR", [n, i_0, j_0, k])
            T.reads(
                C[vn, vio * 16 : vio * 16 + 16, vjo * 16 : vjo * 16 + 16],
                A[vn, vio * 16 : vio * 16 + 16, vk],
                B[vn, vjo * 16 : vjo * 16 + 16, vk],
            )
            T.writes(C[vn, vio * 16 : vio * 16 + 16, vjo * 16 : vjo * 16 + 16])
            A_1 = T.match_buffer(A[vn, vio * 16 : vio * 16 + 16, vk], [16, 1], dtype="float32", offset_factor=1)
            B_1 = T.match_buffer(B[vn, vjo * 16 : vjo * 16 + 16, vk], [16, 1], dtype="float32", offset_factor=1
            )
            C_1 = T.match_buffer(
                C[vn, vio * 16 : vio * 16 + 16, vjo * 16 : vjo * 16 + 16], [16, 16], dtype="float32", offset_factor=1
            )
            T.evaluate(
                T.call_extern("outer_product", C_1.data, C_1.elem_offset, A_1.data, A_1.elem_offset,
                    B_1.data, B_1.elem_offset, dtype="int32"
                )
            )
+
+
+# fmt: off
+# pylint:
disable=no-member,invalid-name,unused-variable,line-too-long,redefined-outer-name,unexpected-keyword-arg,too-many-nested-blocks
+
+tir.TensorIntrin.register("test_mma_intrin", mma_desc, mma_intrin)
+tir.TensorIntrin.register("test_dot_product_intrin", dot_product_desc,
dot_product_intrin)
+tir.TensorIntrin.register("test_outer_product_intrin", outer_product_desc,
outer_product_intrin)
+
+
def test_tensorize_matmul():
    """Tile a 128^3 matmul into 16x16x16 tiles and tensorize with the MMA intrin."""
    sch = tir.Schedule(matmul, debug_mask="all")
    block = sch.get_block("update")
    loop_i, loop_j, loop_k = sch.get_loops(block)
    i_outer, i_inner = sch.split(loop_i, factors=[None, 16])
    j_outer, j_inner = sch.split(loop_j, factors=[None, 16])
    k_outer, k_inner = sch.split(loop_k, factors=[None, 16])
    sch.reorder(i_outer, j_outer, k_outer, i_inner, j_inner, k_inner)
    # Separate the init block out first so only the update region is matched.
    sch.decompose_reduction(block, k_outer)
    sch.tensorize(i_inner, "test_mma_intrin")
    tvm.ir.assert_structural_equal(tensorized_matmul, sch.mod["main"])
    verify_trace_roundtrip(sch=sch, mod=matmul)
+
+
def test_tensorize_batch_matmul():
    """Tensorize the per-batch 16x16x16 tiles of a batched matmul with the MMA intrin."""
    sch = tir.Schedule(batch_matmul, debug_mask="all")
    block = sch.get_block("update")
    _, loop_i, loop_j, loop_k = sch.get_loops(block)
    i_outer, i_inner = sch.split(loop_i, factors=[None, 16])
    j_outer, j_inner = sch.split(loop_j, factors=[None, 16])
    k_outer, k_inner = sch.split(loop_k, factors=[None, 16])
    sch.reorder(i_outer, j_outer, k_outer, i_inner, j_inner, k_inner)
    sch.tensorize(i_inner, "test_mma_intrin")
    tvm.ir.assert_structural_equal(tensorized_batch_matmul_mma, sch.mod["main"])
    verify_trace_roundtrip(sch=sch, mod=batch_matmul)
+
+
def test_tensorize_dot_product():
    """Split the reduction axis by 4 and tensorize with the vec4add dot-product intrin."""
    sch = tir.Schedule(batch_matmul, debug_mask="all")
    update = sch.get_block("update")
    loop_k = sch.get_loops(update)[-1]
    _, k_inner = sch.split(loop_k, factors=[None, 4])
    sch.tensorize(k_inner, "test_dot_product_intrin")
    tvm.ir.assert_structural_equal(tensorized_batch_matmul_dot_product, sch.mod["main"])
    verify_trace_roundtrip(sch=sch, mod=batch_matmul)
+
+
def test_tensorize_outer_product():
    """Tile the spatial axes 16x16 and tensorize with the outer-product intrin."""
    sch = tir.Schedule(batch_matmul, debug_mask="all")
    update = sch.get_block("update")
    _, loop_i, loop_j, loop_k = sch.get_loops(update)
    i_outer, i_inner = sch.split(loop_i, factors=[None, 16])
    j_outer, j_inner = sch.split(loop_j, factors=[None, 16])
    sch.reorder(i_outer, j_outer, loop_k, i_inner, j_inner)
    sch.tensorize(i_inner, "test_outer_product_intrin")
    tvm.ir.assert_structural_equal(tensorized_batch_matmul_outer_product, sch.mod["main"])
    verify_trace_roundtrip(sch=sch, mod=batch_matmul)
+
+
+if __name__ == "__main__":
+ sys.exit(pytest.main([__file__] + sys.argv[1:]))