This is an automated email from the ASF dual-hosted git repository.
sanirudh pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new bd7f1f8de0 [TIR] Validate tir::Buffer axis_separators on construction
(#17219)
bd7f1f8de0 is described below
commit bd7f1f8de046d598bcf15ea6d7dffc596d5119a4
Author: Eric Lunderberg <[email protected]>
AuthorDate: Mon Aug 5 01:27:37 2024 -0500
[TIR] Validate tir::Buffer axis_separators on construction (#17219)
* [TIR] Validate tir::Buffer axis_separators on construction
Prior to this commit, the `axis_separators` field of a TIR buffer
wasn't validated until the `tir.FlattenBuffer` legalization pass.
Delaying the error until this point makes it difficult to determine
where the invalid `axis_separators` were initially defined.
This commit updates the `tir::Buffer` constructor to validate the
`axis_separators` field immediately, allowing these invalid values to
be caught on construction.
Closes https://github.com/apache/tvm/issues/17215
* Update metaschedule primitive to only set axis_separators of alloc
* Allow axis separators to be increasing, rather than strictly increasing
---
src/tir/ir/buffer.cc | 45 ++++++++++++++--------
.../schedule/primitive/layout_transformation.cc | 15 +++++---
tests/python/tir-base/test_tir_buffer.py | 12 ++++--
.../test_tir_schedule_set_axis_separator.py | 4 +-
4 files changed, 51 insertions(+), 25 deletions(-)
diff --git a/src/tir/ir/buffer.cc b/src/tir/ir/buffer.cc
index 0256053331..b7c4eb1d42 100644
--- a/src/tir/ir/buffer.cc
+++ b/src/tir/ir/buffer.cc
@@ -334,24 +334,37 @@ inline Array<PrimExpr> BufferOffset(const BufferNode* n,
Array<PrimExpr> index,
return offsets;
}
-Buffer Buffer::GetFlattenedBuffer() const {
- auto self = operator->();
-
+static void ValidateAxisSeparators(const Array<IntImm>& axis_separators,
size_t buffer_dim) {
// These checks ensure that all output axes contain at least one
// input axis.
- for (size_t i = 0; (i + 1) < self->axis_separators.size(); i++) {
- auto sep = self->axis_separators[i]->value;
- auto next_sep = self->axis_separators[i + 1]->value;
- ICHECK_LT(sep, next_sep) << "Axis separators must be in strictly
increasing order.";
- }
- if (self->axis_separators.size()) {
- auto first_sep = self->axis_separators[0]->value;
- ICHECK_GT(first_sep, 0) << "First axis separator must be strictly greater
than 0, "
- << "so that first output axis contains at least
one input axis";
- auto last_sep = self->axis_separators[self->axis_separators.size() -
1]->value;
- ICHECK_LT(last_sep, self->shape.size())
- << "Last output axis must contain at least one input axis.";
+ for (size_t i = 0; (i + 1) < axis_separators.size(); i++) {
+ auto sep = axis_separators[i]->value;
+ auto next_sep = axis_separators[i + 1]->value;
+ CHECK_LE(sep, next_sep) << "ValueError: "
+ << "Axis separators must be in increasing order, "
+ << "but axis_separators[" << i << "] = " << sep
+ << " is greater than or equal to axis_separators["
<< (i + 1)
+ << "] = " << next_sep << ".";
+ }
+ if (axis_separators.size()) {
+ auto first_sep = axis_separators[0]->value;
+ CHECK_GE(first_sep, 0) << "ValueError: "
+ << "All axis separators must be non-negative. "
+ << "However, the axis_separators[0] = " <<
first_sep;
+ auto last_sep = axis_separators[axis_separators.size() - 1]->value;
+ CHECK_LE(last_sep, buffer_dim)
+ << "ValueError: "
+ << "All axis separators must be within the range "
+ << "0 <= sep <= buffer_dim. "
+ << "However, the last axis_separators[" << (axis_separators.size() - 1)
+ << "] = " << last_sep << " is greater than the buffer's dimensionality
of " << buffer_dim;
}
+}
+
+Buffer Buffer::GetFlattenedBuffer() const {
+ auto self = operator->();
+
+ ValidateAxisSeparators(self->axis_separators, self->shape.size());
Array<PrimExpr> output_shape;
if (self->strides.size()) {
@@ -565,6 +578,8 @@ Buffer::Buffer(Var data, DataType dtype, Array<PrimExpr>
shape, Array<PrimExpr>
ICHECK(data->type_annotation.as<PointerTypeNode>()->element_type.as<PrimTypeNode>())
<< "Variable " << data->name_hint << " does not point to a primitive.";
+ ValidateAxisSeparators(axis_separators, shape.size());
+
auto n = make_object<BufferNode>();
n->data = std::move(data);
n->dtype = dtype;
diff --git a/src/tir/schedule/primitive/layout_transformation.cc
b/src/tir/schedule/primitive/layout_transformation.cc
index f1e9106a63..8b95e0dc62 100644
--- a/src/tir/schedule/primitive/layout_transformation.cc
+++ b/src/tir/schedule/primitive/layout_transformation.cc
@@ -1485,11 +1485,16 @@ class BufferAxisSeparatorMutator : private
ReplaceBufferMutator {
if (it != buffer_var_map_.end()) {
const Buffer& new_source_buffer = it->second;
Buffer new_target_buffer = match_buffer->buffer;
- new_target_buffer.CopyOnWrite()->axis_separators =
new_source_buffer->axis_separators;
- if (new_target_buffer->shape.size() != new_source_buffer->shape.size()) {
- LOG(WARNING)
- << "Target buffer in match_buffer doesn't have the same
dimensionality as its source "
- "buffer. `axis_separators` for the target buffer might be
incorrect.";
+
+ if (new_target_buffer->shape.size() == new_source_buffer->shape.size()) {
+ new_target_buffer.CopyOnWrite()->axis_separators =
new_source_buffer->axis_separators;
+ } else {
+ new_target_buffer.CopyOnWrite()->axis_separators =
+ Array<IntImm>(new_source_buffer->axis_separators.size(),
IntImm(DataType::Int(32), 0));
+ LOG(WARNING) << "Buffer view " << new_target_buffer
+ << " has different dimensionality than backing buffer "
<< new_source_buffer
+ << ". The `axis_separators` for " << new_target_buffer
<< "."
+ << "`axis_separators` for the view might be incorrect.";
}
buffer_var_map_[new_target_buffer->data.get()] = new_target_buffer;
return MatchBufferRegion(new_target_buffer,
diff --git a/tests/python/tir-base/test_tir_buffer.py
b/tests/python/tir-base/test_tir_buffer.py
index 1ab7662b0b..b4b773197b 100644
--- a/tests/python/tir-base/test_tir_buffer.py
+++ b/tests/python/tir-base/test_tir_buffer.py
@@ -109,9 +109,10 @@ def test_buffer_index_merge_mult_mod():
A_stride = tvm.tir.decl_buffer((m, n), "float32", strides=(s, 1))
def assert_simplified_equal(index_simplified, index_direct):
- tvm.ir.assert_structural_equal(
- index_simplified, index_direct
- ), "index_simplified=%s, index_direct=%s" % (index_simplified,
index_direct)
+ (
+ tvm.ir.assert_structural_equal(index_simplified, index_direct),
+ "index_simplified=%s, index_direct=%s" % (index_simplified,
index_direct),
+ )
idxd = tvm.tir.indexdiv
idxm = tvm.tir.indexmod
@@ -276,5 +277,10 @@ def test_buffer_flatten_uses_axis_separators():
tvm.ir.assert_structural_equal(flat.shape, [4 * 16, 32])
+def test_invalid_axis_separators_raises_exception():
+ with pytest.raises(ValueError):
+ tvm.tir.decl_buffer([1], axis_separators=[1, 2])
+
+
if __name__ == "__main__":
tvm.testing.main()
diff --git a/tests/python/tir-schedule/test_tir_schedule_set_axis_separator.py
b/tests/python/tir-schedule/test_tir_schedule_set_axis_separator.py
index 76a6ade42f..788e17e771 100644
--- a/tests/python/tir-schedule/test_tir_schedule_set_axis_separator.py
+++ b/tests/python/tir-schedule/test_tir_schedule_set_axis_separator.py
@@ -94,12 +94,12 @@ def element_wise_subregion_match_set_axis_separator(A:
T.Buffer((128, 128), "flo
for i, j in T.grid(128, 128):
with T.block("B"):
vi, vj = T.axis.remap("SS", [i, j])
- B_subregion0 = T.match_buffer(B[vi, vj], [], dtype="float32",
offset_factor=1, axis_separators=[1])
+ B_subregion0 = T.match_buffer(B[vi, vj], [], dtype="float32",
offset_factor=1, axis_separators=[0])
B_subregion0[()] = A[vi, vj] * T.float32(2)
for i, j in T.grid(128, 128):
with T.block("C"):
vi, vj = T.axis.remap("SS", [i, j])
- B_subregion1 = T.match_buffer(B[vi, vj], [], dtype="float32",
offset_factor=1, axis_separators=[1])
+ B_subregion1 = T.match_buffer(B[vi, vj], [], dtype="float32",
offset_factor=1, axis_separators=[0])
C[vi, vj] = B_subregion1[()] + T.float32(1)