This is an automated email from the ASF dual-hosted git repository.
wuwei pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new df429c58d8 [TIR] Allow TransformLayout with non-inversible index map
(#14095)
df429c58d8 is described below
commit df429c58d833bbe02ceafb69af1c29c7896218b7
Author: Wuwei Lin <[email protected]>
AuthorDate: Fri Mar 3 15:00:15 2023 -0800
[TIR] Allow TransformLayout with non-inversible index map (#14095)
* [TIR] Allow TransformLayout with non-inversible index map
TransformLayout requires the index map to have an inverse map that can be
calculated by the analyzer in order to check whether padding is added.
However, such a check doesn't work in all cases because of
limitations of the affine analysis, which can only handle a set of
supported patterns. In some cases, even if the index map doesn't
introduce padding, the schedule primitive throws
`TransformationIntroducesPaddingError` because it
fails to calculate the inverse index map.
It is safe to allow the buffer to be padded without providing pad_value
because the original loop extent is not changed and the padded region is
not accessed.
This PR changes the behavior of `TransformLayout` to allow
a non-invertible index map.
Previous discussion:
https://discuss.tvm.apache.org/t/conflict-free-shared-memory-permutation-in-tensorir/13959/9
* add assume_injective_transform option
* Apply suggestions from code review
Co-authored-by: Siyuan Feng <[email protected]>
---------
Co-authored-by: Siyuan Feng <[email protected]>
---
include/tvm/tir/schedule/schedule.h | 9 ++++-
python/tvm/tir/schedule/schedule.py | 18 ++++++++-
src/tir/schedule/concrete_schedule.cc | 5 ++-
src/tir/schedule/concrete_schedule.h | 3 +-
src/tir/schedule/primitive.h | 7 +++-
.../schedule/primitive/layout_transformation.cc | 44 +++++++++++++---------
src/tir/schedule/schedule.cc | 4 +-
src/tir/schedule/traced_schedule.cc | 9 +++--
src/tir/schedule/traced_schedule.h | 3 +-
.../unittest/test_tir_schedule_transform_layout.py | 34 ++++++++++++++++-
10 files changed, 104 insertions(+), 32 deletions(-)
diff --git a/include/tvm/tir/schedule/schedule.h
b/include/tvm/tir/schedule/schedule.h
index 288601d1cc..7f2bdf6b4e 100644
--- a/include/tvm/tir/schedule/schedule.h
+++ b/include/tvm/tir/schedule/schedule.h
@@ -642,10 +642,17 @@ class ScheduleNode : public runtime::Object {
* Algebraic symplifications, branch elimination, and other
* optimizations may assume that this precondition is met, and
* may result in incorrect results being returned.
+ *
+ * \param assume_injective_transform If set to true, the schedule primitive
will assume the
+ * index_map is injective and skip checking overlapping of the mapped
indices. This can be useful
+ * for complicated index_map that the analysis does not cover. It is the
callers' responsibility
+ * to ensure the index map is injective, otherwise, the correctness of the
schedule is not
+ * guaranteed.
*/
virtual void TransformLayout(const BlockRV& block_rv, int buffer_index,
BufferIndexType buffer_index_type, const
IndexMap& index_map,
- const Optional<IndexMap>& pad_value = NullOpt)
= 0;
+ const Optional<IndexMap>& pad_value = NullOpt,
+ bool assume_injective_transform = false) = 0;
/*!
* \brief Apply a transformation represented by IndexMap to block
diff --git a/python/tvm/tir/schedule/schedule.py
b/python/tvm/tir/schedule/schedule.py
index 04355eb16e..b63353bcb3 100644
--- a/python/tvm/tir/schedule/schedule.py
+++ b/python/tvm/tir/schedule/schedule.py
@@ -2575,7 +2575,6 @@ class Schedule(Object):
buffer: Union[Tuple[str, int], int, str, Buffer],
required_buffer_type=None,
) -> Tuple[str, int, Buffer]:
-
block_obj: Block = self.get(block)
block_name = block_obj.name_hint
@@ -2645,6 +2644,8 @@ class Schedule(Object):
buffer: Union[Tuple[str, int], str, Buffer],
index_map: Union[IndexMap, Callable],
pad_value: Optional[Union[int, float, PrimExpr, IndexMap, Callable]] =
None,
+ *,
+ assume_injective_transform: bool = False,
) -> None:
"""Apply a transformation represented by IndexMap to buffer
@@ -2711,6 +2712,13 @@ class Schedule(Object):
value to be present in the padding in terms of the
transformed index.
+ assume_injective_transform : bool
+
+ If set to true, the schedule primitive will assume the index_map
is injective and skip
+ checking overlapping of the mapped indices. This can be useful for
complicated index_map
+ that the analysis does not cover. It is the callers'
responsibility to ensure the
+ index map is injective, otherwise, the correctness of the schedule
is not guaranteed.
+
Examples
--------
Before transform_layout, in TensorIR, the IR is:
@@ -2787,7 +2795,13 @@ class Schedule(Object):
buffer_index_type_enum = 0 if buffer_index_type == "read" else 1
_ffi_api.ScheduleTransformLayout( # type: ignore # pylint:
disable=no-member
- self, block, buffer_index, buffer_index_type_enum, index_map,
pad_value
+ self,
+ block,
+ buffer_index,
+ buffer_index_type_enum,
+ index_map,
+ pad_value,
+ assume_injective_transform,
)
if axis_separators:
_ffi_api.ScheduleSetAxisSeparator( # type: ignore # pylint:
disable=no-member
diff --git a/src/tir/schedule/concrete_schedule.cc
b/src/tir/schedule/concrete_schedule.cc
index b6af22263e..8af39b24fd 100644
--- a/src/tir/schedule/concrete_schedule.cc
+++ b/src/tir/schedule/concrete_schedule.cc
@@ -800,14 +800,15 @@ void ConcreteScheduleNode::Unannotate(const BlockRV&
block_rv, const String& ann
void ConcreteScheduleNode::TransformLayout(const BlockRV& block_rv, int
buffer_index,
BufferIndexType buffer_index_type,
const IndexMap& index_map,
- const Optional<IndexMap>&
pad_value) {
+ const Optional<IndexMap>& pad_value,
+ bool assume_injective_transform) {
TVM_TIR_SCHEDULE_BEGIN();
auto f_subst = [&](const Var& var) -> Optional<PrimExpr> {
return Downcast<Optional<PrimExpr>>(symbol_table_.Get(var));
};
auto new_index_map = Substitute(index_map, f_subst);
tir::TransformLayout(state_, this->GetSRef(block_rv), buffer_index,
buffer_index_type,
- new_index_map, pad_value);
+ new_index_map, pad_value, assume_injective_transform);
this->state_->DebugVerify();
TVM_TIR_SCHEDULE_END("transform_layout", this->error_render_level_);
}
diff --git a/src/tir/schedule/concrete_schedule.h
b/src/tir/schedule/concrete_schedule.h
index 44d9e9b69c..41168fb016 100644
--- a/src/tir/schedule/concrete_schedule.h
+++ b/src/tir/schedule/concrete_schedule.h
@@ -148,7 +148,8 @@ class ConcreteScheduleNode : public ScheduleNode {
void Unannotate(const BlockRV& block_rv, const String& ann_key) override;
/******** Schedule: Layout transformation ********/
void TransformLayout(const BlockRV& block_rv, int buffer_index,
BufferIndexType buffer_index_type,
- const IndexMap& index_map, const Optional<IndexMap>&
pad_value) override;
+ const IndexMap& index_map, const Optional<IndexMap>&
pad_value,
+ bool assume_injective_transform = false) override;
void TransformBlockLayout(const BlockRV& block_rv, const IndexMap&
index_map) override;
void SetAxisSeparator(const BlockRV& block_rv, int buffer_index,
BufferIndexType buffer_index_type,
diff --git a/src/tir/schedule/primitive.h b/src/tir/schedule/primitive.h
index dbc4e23596..0b7a4f6280 100644
--- a/src/tir/schedule/primitive.h
+++ b/src/tir/schedule/primitive.h
@@ -501,10 +501,15 @@ TVM_DLL void Unannotate(ScheduleState self, const
StmtSRef& sref, const String&
* \param buffer_index_type The type of the buffer index, kRead or kWrite.
* \param index_map The transformation to apply.
* \param pad_value The value to write into padding introduced by the
transformation.
+ * \param assume_injective_transform If set to true, the schedule primitive
will assume the
+ * index_map is injective and skip checking overlapping of the mapped indices.
This can be useful
+ * for complicated index_map that the analysis does not cover. It is the
callers' responsibility
+ * to ensure the index map is injective, otherwise, the correctness of the
schedule is not
+ * guaranteed.
*/
TVM_DLL void TransformLayout(ScheduleState self, const StmtSRef& block_sref,
int buffer_index,
BufferIndexType buffer_index_type, const
IndexMap& index_map,
- const Optional<IndexMap>& pad_value);
+ const Optional<IndexMap>& pad_value, bool
assume_injective_transform);
/*!
* \brief Apply a transformation represented by IndexMap to block
diff --git a/src/tir/schedule/primitive/layout_transformation.cc
b/src/tir/schedule/primitive/layout_transformation.cc
index 0e993d06dc..7eaca74100 100644
--- a/src/tir/schedule/primitive/layout_transformation.cc
+++ b/src/tir/schedule/primitive/layout_transformation.cc
@@ -753,10 +753,12 @@ class TransformLayoutRewriter : private
arith::IRMutatorWithAnalyzer {
*/
static std::pair<Stmt, Map<Block, Block>> Rewrite(
const Block& scope_stmt, const Buffer& old_buffer, const Buffer&
new_buffer,
- const IndexMap& index_map, const IndexMap& inverse, const PrimExpr&
padding_predicate,
- const Optional<IndexMap>& pad_value) {
- auto plan = TransformLayoutPlanner::Plan(scope_stmt, old_buffer,
new_buffer, index_map, inverse,
- padding_predicate, pad_value);
+ const IndexMap& index_map, const Optional<IndexMap>& opt_inverse,
+ const PrimExpr& padding_predicate, const Optional<IndexMap>& pad_value) {
+ auto plan = pad_value.defined() ? TransformLayoutPlanner::Plan(
+ scope_stmt, old_buffer, new_buffer,
index_map,
+ opt_inverse.value(),
padding_predicate, pad_value)
+ :
TransformLayoutPlanner::NoPaddingRequired();
arith::Analyzer analyzer;
TransformLayoutRewriter rewriter(old_buffer, new_buffer, index_map, plan,
&analyzer);
@@ -1119,7 +1121,7 @@ IndexMap LegalizeIndexMapDType(const IndexMap& index_map,
const Array<PrimExpr>&
void TransformLayout(ScheduleState self, const StmtSRef& block_sref, int
buffer_index,
BufferIndexType buffer_index_type, const IndexMap&
index_map_orig,
- const Optional<IndexMap>& pad_value) {
+ const Optional<IndexMap>& pad_value, bool
assume_injective_transform) {
// Step 1: Input handling and error checking
const BlockNode* block_ptr = TVM_SREF_TO_BLOCK(block_sref);
Buffer old_buffer =
@@ -1147,13 +1149,17 @@ void TransformLayout(ScheduleState self, const
StmtSRef& block_sref, int buffer_
: GetScopeRoot(self, block_sref,
/*require_stage_pipeline=*/false);
const BlockNode* scope_block = TVM_SREF_TO_BLOCK(scope_sref);
- auto [inverse, padding_predicate] = [&]() {
- Array<Range> region;
- for (const auto& dim : old_buffer->shape) {
- region.push_back(Range::FromMinExtent(make_zero(dim.dtype()), dim));
- }
- return index_map.NonSurjectiveInverse(region);
- }();
+ Optional<IndexMap> opt_inverse = NullOpt;
+ PrimExpr padding_predicate = Bool(false);
+ if (!assume_injective_transform) {
+ std::tie(opt_inverse, padding_predicate) = [&]() {
+ Array<Range> region;
+ for (const auto& dim : old_buffer->shape) {
+ region.push_back(Range::FromMinExtent(make_zero(dim.dtype()), dim));
+ }
+ return index_map.NonSurjectiveInverse(region);
+ }();
+ }
bool has_padding = !is_zero(padding_predicate);
if (has_padding && !pad_value.defined()) {
@@ -1168,7 +1174,7 @@ void TransformLayout(ScheduleState self, const StmtSRef&
block_sref, int buffer_
// alloc_buffers.
auto [new_stmt, block_sref_reuse] =
TransformLayoutRewriter::Rewrite(GetRef<Block>(scope_block), old_buffer,
new_buffer,
- index_map, inverse, padding_predicate,
pad_value);
+ index_map, opt_inverse,
padding_predicate, pad_value);
Block new_scope_block = Downcast<Block>(new_stmt);
// Step 4: Rewrite buffer_map of the PrimFunc if necessary.
@@ -1511,20 +1517,21 @@ struct TransformLayoutTraits : public
UnpackedInstTraits<TransformLayoutTraits>
private:
static constexpr size_t kNumInputs = 2;
- static constexpr size_t kNumAttrs = 3;
+ static constexpr size_t kNumAttrs = 4;
static constexpr size_t kNumDecisions = 0;
static void UnpackedApplyToSchedule(Schedule sch, BlockRV block_rv, IndexMap
index_map,
Integer buffer_index, Integer
buffer_index_type,
- Optional<IndexMap> pad_value) {
+ Optional<IndexMap> pad_value,
+ Bool assume_injective_transform) {
return sch->TransformLayout(block_rv, buffer_index.IntValue(),
static_cast<BufferIndexType>(buffer_index_type->value), index_map,
- pad_value);
+ pad_value, assume_injective_transform.operator
bool());
}
static String UnpackedAsPython(Array<String> outputs, String block_rv,
IndexMap index_map,
Integer buffer_index, Integer
buffer_index_type,
- Optional<IndexMap> pad_value) {
+ Optional<IndexMap> pad_value, Bool
assume_injective_transform) {
PythonAPICall py("transform_layout");
py.Input("block", block_rv);
@@ -1534,6 +1541,7 @@ struct TransformLayoutTraits : public
UnpackedInstTraits<TransformLayoutTraits>
py.Input("buffer", os.str());
py.Input("index_map", index_map->ToPythonString());
py.Input("pad_value", pad_value ? pad_value.value()->ToPythonString() :
"None");
+ py.Input("assume_injective_transform", assume_injective_transform.operator
bool());
return py.Str();
}
@@ -1549,6 +1557,7 @@ struct TransformLayoutTraits : public
UnpackedInstTraits<TransformLayoutTraits>
} else {
attrs_record.push_back(attrs[2]);
}
+ attrs_record.push_back(attrs[3]);
return std::move(attrs_record);
}
@@ -1562,6 +1571,7 @@ struct TransformLayoutTraits : public
UnpackedInstTraits<TransformLayoutTraits>
} else {
attrs.push_back(attrs_record[2]);
}
+ attrs.push_back(attrs_record[3]);
return attrs;
}
diff --git a/src/tir/schedule/schedule.cc b/src/tir/schedule/schedule.cc
index d008f3639c..4177d91648 100644
--- a/src/tir/schedule/schedule.cc
+++ b/src/tir/schedule/schedule.cc
@@ -253,10 +253,10 @@ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleUnannotate")
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleTransformLayout")
.set_body_typed([](Schedule self, const BlockRV& block_rv, int
buffer_index,
int buffer_index_type, const IndexMap& index_map,
- const Optional<IndexMap>& pad_value) {
+ const Optional<IndexMap>& pad_value, bool
assume_injective_transform) {
return self->TransformLayout(block_rv, buffer_index,
static_cast<BufferIndexType>(buffer_index_type), index_map,
- pad_value);
+ pad_value, assume_injective_transform);
});
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleTransformBlockLayout")
.set_body_method<Schedule>(&ScheduleNode::TransformBlockLayout);
diff --git a/src/tir/schedule/traced_schedule.cc
b/src/tir/schedule/traced_schedule.cc
index 8852fccf88..dba34c2ca3 100644
--- a/src/tir/schedule/traced_schedule.cc
+++ b/src/tir/schedule/traced_schedule.cc
@@ -523,15 +523,18 @@ void TracedScheduleNode::Unannotate(const BlockRV&
block_rv, const String& ann_k
void TracedScheduleNode::TransformLayout(const BlockRV& block_rv, int
buffer_index,
BufferIndexType buffer_index_type,
const IndexMap& index_map,
- const Optional<IndexMap>& pad_value) {
+ const Optional<IndexMap>& pad_value,
+ bool assume_injective_transform) {
ConcreteScheduleNode::TransformLayout(block_rv, buffer_index,
buffer_index_type, index_map,
- pad_value);
+ pad_value, assume_injective_transform);
static const InstructionKind& kind = InstructionKind::Get("TransformLayout");
trace_->Append(
/*inst=*/Instruction(
/*kind=*/kind,
/*inputs=*/{block_rv, index_map},
- /*attrs=*/{Integer(buffer_index), Integer(buffer_index_type),
pad_value},
+ /*attrs=*/
+ {Integer(buffer_index), Integer(buffer_index_type), pad_value,
+ Bool(assume_injective_transform)},
/*outputs=*/{}));
}
diff --git a/src/tir/schedule/traced_schedule.h
b/src/tir/schedule/traced_schedule.h
index ee65c721ad..7bd8385555 100644
--- a/src/tir/schedule/traced_schedule.h
+++ b/src/tir/schedule/traced_schedule.h
@@ -107,7 +107,8 @@ class TracedScheduleNode : public ConcreteScheduleNode {
void Unannotate(const BlockRV& block_rv, const String& ann_key) override;
/******** Schedule: Layout transformation ********/
void TransformLayout(const BlockRV& block_rv, int buffer_index,
BufferIndexType buffer_index_type,
- const IndexMap& index_map, const Optional<IndexMap>&
pad_value) override;
+ const IndexMap& index_map, const Optional<IndexMap>&
pad_value,
+ bool assume_injective_transform) override;
void TransformBlockLayout(const BlockRV& block_rv, const IndexMap&
index_map) override;
void SetAxisSeparator(const BlockRV& block_rv, int buffer_index,
BufferIndexType buffer_index_type,
diff --git a/tests/python/unittest/test_tir_schedule_transform_layout.py
b/tests/python/unittest/test_tir_schedule_transform_layout.py
index d866de33f1..c9a8f70ef7 100644
--- a/tests/python/unittest/test_tir_schedule_transform_layout.py
+++ b/tests/python/unittest/test_tir_schedule_transform_layout.py
@@ -477,11 +477,19 @@ class BasePaddingCompare(tvm.testing.CompareBeforeAfter):
index_map = tvm.testing.parameter(lambda i: [i // 4, i % 4])
+ assume_injective_transform = tvm.testing.parameter(False)
+
@pytest.fixture
- def transform(self, pad_value, transformed_buffer, index_map):
+ def transform(self, pad_value, transformed_buffer, index_map,
assume_injective_transform):
def transform(mod):
sch = tir.Schedule(mod)
- sch.transform_layout("block", transformed_buffer, index_map,
pad_value=pad_value)
+ sch.transform_layout(
+ "block",
+ transformed_buffer,
+ index_map,
+ pad_value=pad_value,
+ assume_injective_transform=assume_injective_transform,
+ )
return sch.mod
return transform
@@ -578,6 +586,28 @@ class TestErrorIfPaddingForbidden(BasePaddingCompare):
expected = tvm.tir.schedule.schedule.ScheduleError
+class TestImplicitPaddingAssumeInjective(BasePaddingCompare):
+ """When pad_value is None and assume_injective_transform is set, the
buffer can be implicitly
+ padded. The padded region is not accessed because the original loop extent
is not changed.
+ """
+
+ assume_injective_transform = tvm.testing.parameter(True)
+
+ def before():
+ A = T.alloc_buffer(14, "int32")
+ for i in T.serial(14):
+ with T.block("block"):
+ vi = T.axis.remap("S", [i])
+ A[vi] = 0
+
+ def expected():
+ A = T.alloc_buffer([4, 4], "int32")
+ for i in T.serial(14):
+ with T.block("block"):
+ vi = T.axis.remap("S", [i])
+ A[vi // 4, vi % 4] = 0
+
+
class TestErrorOnWrongPaddingType(BasePaddingCompare):
"""The padding must have the same dtype as the buffer"""