This is an automated email from the ASF dual-hosted git repository. tkonolige pushed a commit to branch tkonolige/relax_pad_etc_new in repository https://gitbox.apache.org/repos/asf/tvm.git
commit c342da80525124a18b231ac9fb1c04068e2cbac0 Author: Tristan Konolige <[email protected]> AuthorDate: Thu May 18 23:47:07 2023 +0000 Fix RewriteDataflowReshape to not rewrite when buffer shapes are not the same Before, only the read and written indices were compared to see if an op could be rewritten to a reshape. Now the buffer shapes are also checked for compatibility: the total element counts of the source and destination buffers (the products of their shapes) must be provably equal. --- src/relax/analysis/tir_op_pattern_kind.cc | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/src/relax/analysis/tir_op_pattern_kind.cc b/src/relax/analysis/tir_op_pattern_kind.cc index 950e6a10e0..38f411ae50 100644 --- a/src/relax/analysis/tir_op_pattern_kind.cc +++ b/src/relax/analysis/tir_op_pattern_kind.cc @@ -23,7 +23,18 @@ #include <tvm/tir/function.h> #include <tvm/tir/stmt_functor.h> + namespace tvm { +namespace { +// Helper function for flatten and reshape. +PrimExpr ComputeShapeProduct(const Array<PrimExpr>& shape_values) { + PrimExpr shape_prod = IntImm(DataType::Int(64), 1); + for (PrimExpr value : shape_values) { + shape_prod *= value; + } + return shape_prod; +} +} namespace relax { using namespace tir; @@ -433,8 +444,8 @@ bool HasReshapePattern(const PrimFunc& func) { PrimExpr src_idx = f_calc_flattened_idx(src_buffer_, buffer_load->indices); PrimExpr dst_idx = f_calc_flattened_idx(dst_buffer_, buffer_store->indices); - // Step 4. Check if we can prove the equality of flattened indices. - if (ana_.CanProveEqual(src_idx, dst_idx)) { + // Step 4. Check if we can prove the equality of flattened indices and that the output buffer shape is equivalent to the input. + if (ana_.CanProveEqual(src_idx, dst_idx) && ana_.CanProveEqual(ComputeShapeProduct(src_buffer_->shape), ComputeShapeProduct(dst_buffer_->shape))) { this->is_reshape_ = true; } }
