This is an automated email from the ASF dual-hosted git repository.

tkonolige pushed a commit to branch tkonolige/relax_pad_etc_new
in repository https://gitbox.apache.org/repos/asf/tvm.git

commit c342da80525124a18b231ac9fb1c04068e2cbac0
Author: Tristan Konolige <[email protected]>
AuthorDate: Thu May 18 23:47:07 2023 +0000

    Fix RewriteDataflowReshape to not rewrite when buffer shapes are not the same
    
    Previously, only the read and written indices were compared to decide whether
    an op could be rewritten to a reshape. Now the buffer shapes are also checked:
    the total element counts of the source and destination buffers must be
    provably equal.
---
 src/relax/analysis/tir_op_pattern_kind.cc | 19 +++++++++++++++++--
 1 file changed, 17 insertions(+), 2 deletions(-)

diff --git a/src/relax/analysis/tir_op_pattern_kind.cc b/src/relax/analysis/tir_op_pattern_kind.cc
index 950e6a10e0..38f411ae50 100644
--- a/src/relax/analysis/tir_op_pattern_kind.cc
+++ b/src/relax/analysis/tir_op_pattern_kind.cc
@@ -23,7 +23,18 @@
 #include <tvm/tir/function.h>
 #include <tvm/tir/stmt_functor.h>

 namespace tvm {
+namespace {
+// Returns the product of all extents in `shape_values` (1 for an empty shape). Used below to
+// compare the total number of elements of two buffers.
+PrimExpr ComputeShapeProduct(const Array<PrimExpr>& shape_values) {
+  PrimExpr shape_prod = IntImm(DataType::Int(64), 1);
+  for (const PrimExpr& value : shape_values) {
+    shape_prod *= value;
+  }
+  return shape_prod;
+}
+}  // namespace
 namespace relax {

 using namespace tir;
@@ -433,8 +444,12 @@ bool HasReshapePattern(const PrimFunc& func) {
       PrimExpr src_idx = f_calc_flattened_idx(src_buffer_, buffer_load->indices);
       PrimExpr dst_idx = f_calc_flattened_idx(dst_buffer_, buffer_store->indices);

-      // Step 4. Check if we can prove the equality of flattened indices.
-      if (ana_.CanProveEqual(src_idx, dst_idx)) {
+      // Step 4. Check if we can prove the equality of flattened indices and of the total number
+      // of elements in the source and destination buffers. Equal flattened indices alone would
+      // also accept an op that copies only part of the source buffer.
+      if (ana_.CanProveEqual(src_idx, dst_idx) &&
+          ana_.CanProveEqual(ComputeShapeProduct(src_buffer_->shape),
+                             ComputeShapeProduct(dst_buffer_->shape))) {
         this->is_reshape_ = true;
       }
     }

Reply via email to