This is an automated email from the ASF dual-hosted git repository.

sanirudh pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 6704175fc7 Pass to eliminate redundant branch and overcompute (#17170)
6704175fc7 is described below

commit 6704175fc7d427bded07e7348c230c58bd9ef75f
Author: sdalvi-quic <[email protected]>
AuthorDate: Wed Jul 24 23:26:47 2024 -0500

    Pass to eliminate redundant branch and overcompute (#17170)
    
    * Implementation to eliminate redundant branches introduced by operator padding by overcomputing; this creates more opportunities to vectorize the code
    
    * Fixed lint error in transform.py file
    
    * Fixed lint errors in the file using_assume_to_reduce_branches.cc
    
    * Fixed lint error in transform.py related to line too long
    
    * Fixed lint error related to spacing and sentence length in using_assume_to_reduce_branches.cc
    
    * Fixed lint error: trailing whitespace in using_assume_to_reduce_branches.cc
    
    * Fixed lint error: clang format issue in cpp files
    
    * Fixed pylint errors in python files and used clang-format to format the cpp files
    
    * Ran black format and removed the attr_registry_map.h import, as it caused an unrelated build failure
---
 include/tvm/tir/transform.h                        |   8 +
 python/tvm/tir/transform/transform.py              |  13 +
 .../transforms/using_assume_to_reduce_branches.cc  | 394 +++++++++++++
 ...eliminate_pad_branch_using_buffer_assumption.py | 648 +++++++++++++++++++++
 4 files changed, 1063 insertions(+)

diff --git a/include/tvm/tir/transform.h b/include/tvm/tir/transform.h
index 98edbeaceb..a8d93bf898 100644
--- a/include/tvm/tir/transform.h
+++ b/include/tvm/tir/transform.h
@@ -834,6 +834,14 @@ TVM_DLL Pass InstrumentProfileIntrinsics();
  */
 TVM_DLL Pass DefaultGPUSchedule();
 
+/*!
+ * \brief This pass analyzes a primfunc and eliminates branches introduced by layout padding.
+ *  It leverages the buffer assumptions (T.assume) and uses them to eliminate the branch.
+ * \note This creates more opportunities to vectorize the code.
+ * \return The Pass.
+ */
+TVM_DLL Pass UseAssumeToReduceBranches();
+
 }  // namespace transform
 }  // namespace tir
 }  // namespace tvm
diff --git a/python/tvm/tir/transform/transform.py b/python/tvm/tir/transform/transform.py
index c2022b9186..d8531401d4 100644
--- a/python/tvm/tir/transform/transform.py
+++ b/python/tvm/tir/transform/transform.py
@@ -1199,3 +1199,16 @@ def DefaultGPUSchedule():
     ret: tvm.transform.Pass
     """
     return _ffi_api.DefaultGPUSchedule()  # type: ignore
+
+
+def UseAssumeToReduceBranches():
+    """This pass attempts to eliminates layout specific pad branch by 
overcomputing the values
+    for padded region. Eliminating the branch will help to vectorize code,
+    and improve element wise ops performance.
+
+    Returns
+    -------
+    fpass : tvm.transform.Pass
+        The result pass
+    """
+    return _ffi_api.UseAssumeToReduceBranches()  # type: ignore
diff --git a/src/tir/transforms/using_assume_to_reduce_branches.cc b/src/tir/transforms/using_assume_to_reduce_branches.cc
new file mode 100644
index 0000000000..2e45bb0ff8
--- /dev/null
+++ b/src/tir/transforms/using_assume_to_reduce_branches.cc
@@ -0,0 +1,394 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file using_assume_to_reduce_branches.cc
+ *
+ * \brief Attempt to remove conditional branch statements by introducing
+ * extra computations that do not impact the final results. Mainly
+ * oriented for layout specific padding related branches.
+ *
+ * \note
+ *    1. This pass works if the buffer carrying the assumption is read in the branch statement.
+ *       If the assumed buffer is not read directly in the branch statement because
+ *       there are intermediate buffers, inline the code first.
+ *    2. The assumptions leveraged here should be of the form
+ *       T.assume(condition_on_indices or buffer[indices] == some_value).
+ *    3. Some parts of the code are reused from control_flow_graph.cc, which also
+ *       handles eliminating branches in particular scenarios.
+ *    4. This pass currently works for op_pattern kElemWise and kBroadcast.
+ */
+
+#include <tvm/relax/expr.h>
+#include <tvm/relay/op_attr_types.h>
+#include <tvm/tir/builtin.h>
+#include <tvm/tir/function.h>
+#include <tvm/tir/op.h>
+#include <tvm/tir/stmt_functor.h>
+#include <tvm/tir/transform.h>
+
+#include <optional>
+
+#include "../../arith/constraint_extract.h"
+#include "../../arith/ir_mutator_with_analyzer.h"
+#include "../../arith/unwrap_vector_expr.h"
+#include "simplify.h"
+#include "tvm/ir/expr.h"
+namespace tvm {
+namespace tir {
+
+using namespace arith;
+
+/*!
+ * \brief Detects whether a statement contains a T.assume (builtin::assume) call.
+ *
+ * Used as a cheap pre-check so that ParseAssumeAndOvercompute only runs on
+ * primfuncs that actually contain assumptions.
+ */
+class AssumeChecker : public StmtExprVisitor {
+ public:
+  // Set to true once a builtin::assume() call has been found.
+  bool has_assume = false;
+
+  void VisitStmt(const Stmt& stmt) final {
+    // Stop descending into further statements once an assume has been found.
+    if (has_assume) {
+      return;
+    }
+    StmtVisitor::VisitStmt(stmt);
+  }
+
+  void VisitExpr_(const CallNode* op) override {
+    if (op->op.same_as(builtin::assume())) {
+      has_assume = true;
+      return;
+    }
+    // Also look inside the call's arguments, so an assume nested in another
+    // call expression is not missed.
+    ExprVisitor::VisitExpr_(op);
+  }
+};
+
+class ParseAssumeAndOvercompute : public IRMutatorWithAnalyzer {
+  /* This class analyzes the complete primfunc.
+  It parses the buffer assumptions and eliminates the redundant branch
+  introduced by layout specific padding by leveraging the buffer assumptions.
+  Eliminating the branch creates more opportunities to vectorize the code
+  and improve performance.
+
+  Example:
+  -------------
+  Prim Func Before :
+  for (...)
+    T.assume( assume_condition or A[i] == 0 )
+  for (...)
+    out = T.if_then_else(if_then_else_condition, 0, function(A))
+    # here function(A) is some function on Var A
+
+  Prim Func After :
+  for (...)
+    T.assume( assume_condition or A[i] == 0 )
+  for (...)
+    out = function(A) # here function(A) is some function on the Var A
+  --------------
+  # High-level implementation details :
+    1. The pass parses the assume statement and stores the relevant information.
+    2. The pass evaluates the then_clause and else_clause in both the
+    then_condition_context and the else_condition_context. Whenever a load from an
+    assumed buffer is visited, the current context (loop bounds plus enclosing
+    conditions) and the load indices are compared with those recorded for the assume
+    statement (its loop bounds plus the negation of assume_condition, and its load
+    indices). If they match, the load is replaced by the value from the buffer
+    assumption (A[i] -> 0 in the example above) and the expression is simplified.
+    3. The pass then checks if then_clause and else_clause evaluate to the same value.
+    If they agree in the then_condition_context, the else_clause already produces the
+    correct value everywhere, so the if_then_else is replaced by the else_clause.
+    Symmetrically, if they agree in the else_condition_context, it is replaced by the
+    then_clause.
+  This class handles all these scenarios.*/
+
+ public:
+  using Parent = IRMutatorWithAnalyzer;
+  explicit ParseAssumeAndOvercompute(Analyzer* analyzer) : Parent(analyzer) {}
+
+ private:
+  using Parent::VisitExpr_;
+  using Parent::VisitStmt;
+  using Parent::VisitStmt_;
+
+  // All the relevant data of one assume statement.
+  // Running example: T.assume(i < 14 or A[i] == 0)
+  struct assume_struct {
+    // Scope predicate (loop bounds and enclosing conditions) at the assume statement.
+    PrimExpr buffer_context;
+    // Conjunction of the negated buffer-independent terms, i.e. the condition under
+    // which the buffer constraint must hold (i >= 14 in the example).
+    PrimExpr buffer_predicate;
+    // The constrained buffer load (A[i] in the example).
+    tir::BufferLoad buffer_load;
+    // The value the buffer is assumed to hold (0 in the example).
+    PrimExpr buffer_value;
+    // The simplified indices of the constrained load (i in the example).
+    Array<PrimExpr> buffer_indices;
+  };
+  // List of conditions in the current scope.
+  std::vector<PrimExpr> conditions_;
+
+  // Buffer assumptions collected so far, keyed by the constrained buffer.
+  std::map<tir::Buffer, assume_struct> map_buffer_assumption;
+
+  struct InternalConstraintContext {
+    /* Scoped helper: while alive, the constraint is entered into the analyzer and, if it
+    is pure, appended to self->conditions_. On destruction the appended conditions are
+    erased again. Used to track the bounds of loop variables along with any branch
+    conditions on them. */
+    InternalConstraintContext(ParseAssumeAndOvercompute* self, PrimExpr constraint)
+        : self(self), analyzer_context(self->analyzer_, constraint) {
+      old_num_constraints = self->conditions_.size();
+
+      auto side_effect = tir::SideEffect(constraint);
+      if (side_effect <= tir::CallEffectKind::kPure) {
+        self->conditions_.push_back(constraint);
+      } else if (side_effect <= tir::CallEffectKind::kReadState) {
+        assume = constraint;
+      }
+
+      new_num_constraints = self->conditions_.size();
+    }
+
+    ~InternalConstraintContext() {
+      ICHECK_EQ(self->conditions_.size(), new_num_constraints)
+          << "Internal error: Each condition should only be popped once.";
+      self->conditions_.erase(self->conditions_.begin() + old_num_constraints,
+                              self->conditions_.end());
+    }
+
+    ParseAssumeAndOvercompute* self{nullptr};
+    With<arith::ConstraintContext> analyzer_context;
+    size_t old_num_constraints{0};
+    size_t new_num_constraints{0};
+    Optional<PrimExpr> assume{NullOpt};
+
+    // Disable default-generated copy/move assignment and constructors
+    InternalConstraintContext(const InternalConstraintContext&) = delete;
+    InternalConstraintContext& operator=(const InternalConstraintContext&) = delete;
+    InternalConstraintContext(InternalConstraintContext&&) = delete;
+    InternalConstraintContext& operator=(InternalConstraintContext&&) = delete;
+  };
+
+  PrimExpr CurrentScopePredicate() const {
+    /* Conjunction of all constraints currently in scope, left-folded starting from
+    `true`. The exact shape matters: assumptions are matched by structural equality. */
+    PrimExpr predicate = Bool(true);
+    for (const auto& condition : conditions_) {
+      predicate = predicate && condition;
+    }
+    return predicate;
+  }
+
+  Stmt VisitStmt_(const ForNode* op) final {
+    /* Bind the loop variable's range in the analyzer and, for the duration of the
+    loop body, record its lower and upper bounds in conditions_ via
+    InternalConstraintContext so they become part of CurrentScopePredicate(). */
+    analyzer_->Bind(op->loop_var, Range::FromMinExtent(op->min, op->extent));
+    InternalConstraintContext ctx1(this, op->loop_var >= op->min);
+    InternalConstraintContext ctx2(this, op->loop_var < op->min + op->extent);
+    return Parent::VisitStmt_(op);
+  }
+
+  PrimExpr VisitExpr_(const BufferLoadNode* op) override {
+    /* Replace the load by the assumed buffer value when both of these hold:
+    - the current scope predicate matches the context and predicate recorded for
+      the buffer's assumption;
+    - the load indices match the indices of the assumed load.
+    Both are compared in a single structural-equality check, so free variables
+    (e.g. loop variables of different loop nests) are mapped consistently across
+    them. Checking the indices prevents, e.g., A[i - 1] from being replaced by the
+    value assumed for A[i]. The load's indices are not themselves visited. */
+    auto it = map_buffer_assumption.find(op->buffer);
+    if (it != map_buffer_assumption.end()) {
+      const assume_struct& buffer_assumption = it->second;
+
+      Array<PrimExpr> current_key{CurrentScopePredicate()};
+      for (const PrimExpr& index : op->indices) {
+        current_key.push_back(index);
+      }
+      Array<PrimExpr> assumed_key{buffer_assumption.buffer_context &&
+                                  buffer_assumption.buffer_predicate};
+      for (const PrimExpr& index : buffer_assumption.buffer_load->indices) {
+        assumed_key.push_back(index);
+      }
+
+      bool current_context_and_buffer_constraint_is_same =
+          StructuralEqual()(current_key, assumed_key, /*map_free_vars=*/true);
+      if (current_context_and_buffer_constraint_is_same) {
+        return buffer_assumption.buffer_value;
+      }
+    }
+    return GetRef<PrimExpr>(op);
+  }
+
+  Stmt VisitStmt_(const BufferStoreNode* op) final {
+    // Default mutation of the store. It is returned whenever the branch cannot be
+    // eliminated, instead of visiting the same store a second time.
+    Stmt visited = Parent::VisitStmt_(op);
+
+    // Only stores of the form `buf[...] = T.if_then_else(cond, then, else)` are handled.
+    const auto* call = op->value.as<CallNode>();
+    if (call == nullptr || !call->op.same_as(builtin::if_then_else())) {
+      return visited;
+    }
+
+    PrimExpr cond = call->args[0];
+    PrimExpr then_clause = call->args[1];
+    PrimExpr else_clause = call->args[2];
+
+    PrimExpr then_clause_in_then_context;
+    PrimExpr else_clause_in_then_context;
+    PrimExpr then_clause_in_else_context;
+    PrimExpr else_clause_in_else_context;
+    {
+      // Simplifying expressions in " then context "
+      InternalConstraintContext then_ctx(this, cond);
+      // Visiting applies the buffer-assumption substitution of VisitExpr_(BufferLoadNode*).
+      then_clause_in_then_context = analyzer_->Simplify((*this)(then_clause));
+      else_clause_in_then_context = analyzer_->Simplify((*this)(else_clause));
+    }
+    {
+      // Simplifying expressions in " else context "
+      InternalConstraintContext else_ctx(this, !cond);
+      then_clause_in_else_context = analyzer_->Simplify((*this)(then_clause));
+      else_clause_in_else_context = analyzer_->Simplify((*this)(else_clause));
+    }
+
+    if (StructuralEqual()(then_clause_in_then_context, else_clause_in_then_context)) {
+      // Both clauses agree wherever cond holds, and else_clause is what is selected
+      // wherever it does not, so else_clause is correct for every iteration.
+      auto n = this->CopyOnWrite(op);
+      n->value = analyzer_->Simplify(else_clause);
+      return Stmt(n);
+    }
+    if (StructuralEqual()(then_clause_in_else_context, else_clause_in_else_context)) {
+      // Symmetric case: then_clause is correct for every iteration.
+      auto n = this->CopyOnWrite(op);
+      n->value = analyzer_->Simplify(then_clause);
+      return Stmt(n);
+    }
+    return visited;
+  }
+
+  PrimExpr VisitExpr_(const CallNode* op) override {
+    // Record T.assume(...) constraints when they are encountered; a recorded assumption
+    // only affects buffer loads visited afterwards.
+    if (op->op.same_as(builtin::assume())) {
+      Assume(op->args[0]);
+    }
+    return Parent::VisitExpr_(op);
+  }
+
+  // Splits the assumption with arith::ExtractConstraints and records each resulting
+  // constraint separately.
+  void Assume(PrimExpr assumption) {
+    for (const auto& expr : arith::ExtractConstraints(assumption, false)) {
+      AssumeConstraintComponent(expr);
+    }
+  }
+
+  // Parses one constraint of the form `pred_0 or ... or buffer[indices] == value` and
+  // records it in map_buffer_assumption (replacing any earlier entry for that buffer).
+  void AssumeConstraintComponent(PrimExpr assumption) {
+    PrimExpr additional_predicate = Bool(true);
+    std::vector<PrimExpr> buffer_exprs;
+    for (const auto& expr : arith::ExtractComponents(assumption)) {
+      auto side_effect = tir::SideEffect(expr);
+      if (side_effect <= tir::CallEffectKind::kPure) {
+        // Pulling out portions of the assumption that do not depend
+        // on a buffer value allows the following two forms to be
+        // treated identically.
+        //
+        // Option 1: if i < 3: T.assume(buf[i] == value)
+        // Option 2: T.assume(i>=3 or buf[i] == value)
+        additional_predicate = additional_predicate && logical_not(expr);
+      } else if (side_effect == tir::CallEffectKind::kReadState) {
+        buffer_exprs.push_back(expr);
+      } else {
+        LOG(FATAL) << "Assumption must be pure or read-only, but contained expression " << expr
+                   << " with side-effect \'" << side_effect << "\'";
+      }
+    }
+
+    additional_predicate = analyzer_->Simplify(std::move(additional_predicate));
+    CHECK_EQ(buffer_exprs.size(), 1) << "T.assume must contain only a single buffer expression";
+
+    // Only equality constraints on a buffer value are supported; anything else
+    // (e.g. an inequality on a data-dependent conditional) is rejected here.
+    auto* as_equal_node = buffer_exprs[0].as<tir::EQNode>();
+    CHECK(as_equal_node) << "T.assume buffer constraint must be of the form 'buffer[indices] == "
+                            "value', but received "
+                         << assumption;
+
+    // Parse the statement and store the desired values
+    // Ex: A[i]==0, load = A[i], value = 0
+    tir::BufferLoad load;
+    PrimExpr value;
+    if (auto opt = as_equal_node->a.as<tir::BufferLoad>()) {
+      load = opt.value();
+      value = as_equal_node->b;
+    } else if (auto opt = as_equal_node->b.as<tir::BufferLoad>()) {
+      load = opt.value();
+      value = as_equal_node->a;
+    } else {
+      LOG(FATAL) << "T.assume buffer constraint must be of the form 'buffer[indices] == value'";
+    }
+
+    // Validate the value before recording anything, so a rejected assumption
+    // never ends up in map_buffer_assumption.
+    bool has_side_effect = tir::SideEffect(value) > tir::CallEffectKind::kPure;
+    CHECK(!has_side_effect) << "Buffer value in constraint must be pure expression, but was "
+                            << value;
+
+    // Populating the assume statement predicate, buffer, value
+    // and the context of the assume statement
+    assume_struct buf_data;
+    buf_data.buffer_context = CurrentScopePredicate();
+    buf_data.buffer_predicate = additional_predicate;
+    buf_data.buffer_load = load;
+    buf_data.buffer_value = value;
+    for (const PrimExpr& index : load->indices) {
+      buf_data.buffer_indices.push_back(analyzer_->Simplify(index));
+    }
+    map_buffer_assumption[load->buffer] = buf_data;
+  }
+};
+
+namespace transform {
+
+Pass UseAssumeToReduceBranches() {
+  auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
+    // The pass runs & eliminates pad branch with overcompute only if the primfunc
+    // has op_pattern defined and it is an elementwise or broadcast op.
+    // AnnotateTIROpPattern pass will set op_pattern in op attributes of the primfunc.
+    Optional<Integer> opt_pattern = f->GetAttr<Integer>("op_pattern");
+    if (!opt_pattern.defined()) {
+      return f;
+    }
+    auto pattern = static_cast<relay::OpPatternKind>(opt_pattern.value()->value);
+    if (pattern != relay::OpPatternKind::kElemWise &&
+        pattern != relay::OpPatternKind::kBroadcast) {
+      return f;
+    }
+
+    // Only run the mutator when the primfunc contains an assume statement.
+    AssumeChecker assume_checker;
+    assume_checker(f->body);
+    if (!assume_checker.has_assume) {
+      return f;
+    }
+
+    // Leverage the assumptions and eliminate the branch. Copy-on-write is deferred
+    // until here so functions the pass does not touch are returned without a copy.
+    arith::Analyzer analyzer;
+    ParseAssumeAndOvercompute func_analyzer_mutator(&analyzer);
+    auto* n = f.CopyOnWrite();
+    n->body = func_analyzer_mutator(std::move(n->body));
+    return f;
+  };
+  return CreatePrimFuncPass(pass_func, 0, "tir.UseAssumeToReduceBranches", {});
+}
+
+TVM_REGISTER_GLOBAL("tir.transform.UseAssumeToReduceBranches")
+    .set_body_typed(UseAssumeToReduceBranches);
+
+}  // namespace transform
+
+}  // namespace tir
+}  // namespace tvm
diff --git a/tests/python/relax/test_eliminate_pad_branch_using_buffer_assumption.py b/tests/python/relax/test_eliminate_pad_branch_using_buffer_assumption.py
new file mode 100644
index 0000000000..b8ff2b6c79
--- /dev/null
+++ b/tests/python/relax/test_eliminate_pad_branch_using_buffer_assumption.py
@@ -0,0 +1,648 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-docstring, unused-variable
+
+# The test attempts to eliminate redundant pad branch and overcompute the value for elementwise ops.
+# This helps to expose more opportunities to vectorize the code.
+
+import tvm
+import tvm.testing
+
+import tvm.script
+from tvm.script import tir as T, relax as R
+
+
[email protected]_module
+class AddBefore:
+    @T.prim_func(private=True)
+    def add(
+        a: T.Buffer(
+            (T.int64(1), T.int64(4), T.int64(4), T.int64(16), T.int64(8), 
T.int64(8), T.int64(32)),
+            "uint8",
+        ),
+        b: T.Buffer(
+            (T.int64(1), T.int64(4), T.int64(4), T.int64(16), T.int64(8), 
T.int64(8), T.int64(32)),
+            "uint8",
+        ),
+        compute: T.Buffer(
+            (T.int64(1), T.int64(4), T.int64(4), T.int64(16), T.int64(8), 
T.int64(8), T.int64(32)),
+            "uint8",
+        ),
+    ):
+        T.func_attr(
+            {
+                "op_attrs": {"lhs_axis": 0, "op_name": "qnn.add", "rhs_axis": 
0},
+                "op_pattern": 0,
+                "operator_name": "add",
+                "tir.noalias": T.bool(True),
+            }
+        )
+        # with T.block("root"):
+        for axis0, axis1, axis2, axis3, axis4, axis5, axis6 in T.grid(
+            T.int64(1), T.int64(4), T.int64(4), T.int64(16), T.int64(8), 
T.int64(8), T.int64(32)
+        ):
+            with T.block("buffer_A_assumptions"):
+                v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6 
= T.axis.remap(
+                    "SSSSSSS", [axis0, axis1, axis2, axis3, axis4, axis5, 
axis6]
+                )
+                T.reads(a[v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, 
v_axis5, v_axis6])
+                T.writes()
+                T.assume(
+                    not (
+                        v_axis1 == T.int64(3)
+                        and T.int64(4) <= v_axis4
+                        or v_axis2 == T.int64(3)
+                        and T.int64(4) <= v_axis5
+                    )
+                    or a[v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, 
v_axis6]
+                    == T.uint8(0)
+                )
+
+        for axis0, axis1, axis2, axis3, axis4, axis5, axis6 in T.grid(
+            T.int64(1), T.int64(4), T.int64(4), T.int64(16), T.int64(8), 
T.int64(8), T.int64(32)
+        ):
+            with T.block("buffer_B_assumptions"):
+                v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6 
= T.axis.remap(
+                    "SSSSSSS", [axis0, axis1, axis2, axis3, axis4, axis5, 
axis6]
+                )
+                T.reads(b[v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, 
v_axis5, v_axis6])
+                T.writes()
+                T.assume(
+                    not (
+                        v_axis1 == T.int64(3)
+                        and T.int64(4) <= v_axis4
+                        or v_axis2 == T.int64(3)
+                        and T.int64(4) <= v_axis5
+                    )
+                    or b[v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, 
v_axis6]
+                    == T.uint8(0)
+                )
+
+        for axis0, axis1, axis2, axis3, axis4, axis5, axis6 in T.grid(
+            T.int64(1), T.int64(4), T.int64(4), T.int64(16), T.int64(8), 
T.int64(8), T.int64(32)
+        ):
+            with T.block("compute"):
+                v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6 
= T.axis.remap(
+                    "SSSSSSS", [axis0, axis1, axis2, axis3, axis4, axis5, 
axis6]
+                )
+                T.reads(
+                    a[v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, 
v_axis6],
+                    b[v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, 
v_axis6],
+                )
+                T.writes(compute[v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, 
v_axis5, v_axis6])
+                compute[
+                    v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, 
v_axis6
+                ] = T.if_then_else(
+                    v_axis1 == T.int64(3)
+                    and T.int64(4) <= v_axis4
+                    or v_axis2 == T.int64(3)
+                    and T.int64(4) <= v_axis5,
+                    T.uint8(0),
+                    a[v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, 
v_axis6]
+                    + b[v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, 
v_axis6],
+                )
+
+    @R.function
+    def main(
+        a: R.Tensor((1, 4, 4, 16, 8, 8, 32), "uint8"),
+        b: R.Tensor((1, 4, 4, 16, 8, 8, 32), "uint8"),
+    ) -> R.Tensor((1, 4, 4, 16, 8, 8, 32), "uint8"):
+        out = R.call_tir(
+            AddBefore.add,
+            (a, b),
+            out_sinfo=R.Tensor((1, 4, 4, 16, 8, 8, 32), dtype="uint8"),
+        )
+        return out
+
+
[email protected]_module
+class AddExpected:
+    @T.prim_func(private=True)
+    def add(
+        a: T.Buffer(
+            (T.int64(1), T.int64(4), T.int64(4), T.int64(16), T.int64(8), 
T.int64(8), T.int64(32)),
+            "uint8",
+        ),
+        b: T.Buffer(
+            (T.int64(1), T.int64(4), T.int64(4), T.int64(16), T.int64(8), 
T.int64(8), T.int64(32)),
+            "uint8",
+        ),
+        compute: T.Buffer(
+            (T.int64(1), T.int64(4), T.int64(4), T.int64(16), T.int64(8), 
T.int64(8), T.int64(32)),
+            "uint8",
+        ),
+    ):
+        T.func_attr(
+            {
+                "op_attrs": {"lhs_axis": 0, "op_name": "qnn.add", "rhs_axis": 
0},
+                "op_pattern": 0,
+                "operator_name": "add",
+                "tir.noalias": T.bool(True),
+            }
+        )
+        # with T.block("root"):
+        for axis0, axis1, axis2, axis3, axis4, axis5, axis6 in T.grid(
+            T.int64(1), T.int64(4), T.int64(4), T.int64(16), T.int64(8), 
T.int64(8), T.int64(32)
+        ):
+            with T.block("buffer_A_assumptions"):
+                v_axis0 = T.axis.spatial(T.int64(1), T.int64(0))
+                v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6 = 
T.axis.remap(
+                    "SSSSSS", [axis1, axis2, axis3, axis4, axis5, axis6]
+                )
+                T.reads(a[T.int64(0), v_axis1, v_axis2, v_axis3, v_axis4, 
v_axis5, v_axis6])
+                T.writes()
+                T.assume(
+                    (v_axis1 < T.int64(3) or v_axis4 < T.int64(4))
+                    and (v_axis2 < T.int64(3) or v_axis5 < T.int64(4))
+                    or a[T.int64(0), v_axis1, v_axis2, v_axis3, v_axis4, 
v_axis5, v_axis6]
+                    == T.uint8(0)
+                )
+
+        for axis0, axis1, axis2, axis3, axis4, axis5, axis6 in T.grid(
+            T.int64(1), T.int64(4), T.int64(4), T.int64(16), T.int64(8), 
T.int64(8), T.int64(32)
+        ):
+            with T.block("buffer_B_assumptions"):
+                v_axis0 = T.axis.spatial(T.int64(1), T.int64(0))
+                v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6 = 
T.axis.remap(
+                    "SSSSSS", [axis1, axis2, axis3, axis4, axis5, axis6]
+                )
+                T.reads(b[T.int64(0), v_axis1, v_axis2, v_axis3, v_axis4, 
v_axis5, v_axis6])
+                T.writes()
+                T.assume(
+                    (v_axis1 < T.int64(3) or v_axis4 < T.int64(4))
+                    and (v_axis2 < T.int64(3) or v_axis5 < T.int64(4))
+                    or b[T.int64(0), v_axis1, v_axis2, v_axis3, v_axis4, 
v_axis5, v_axis6]
+                    == T.uint8(0)
+                )
+
+        for axis0, axis1, axis2, axis3, axis4, axis5_0 in T.grid(
+            T.int64(1), T.int64(4), T.int64(4), T.int64(16), T.int64(8), 
T.int64(2)
+        ):
+            for axis5_1_axis6_fused in T.vectorized(T.int64(128)):
+                with T.block("compute"):
+                    v_axis0 = T.axis.spatial(T.int64(1), T.int64(0))
+                    v_axis1, v_axis2, v_axis3, v_axis4 = T.axis.remap(
+                        "SSSS", [axis1, axis2, axis3, axis4]
+                    )
+                    v_axis5 = T.axis.spatial(
+                        T.int64(8), axis5_0 * T.int64(4) + axis5_1_axis6_fused 
// T.int64(32)
+                    )
+                    v_axis6 = T.axis.spatial(T.int64(32), axis5_1_axis6_fused 
% T.int64(32))
+                    T.reads(
+                        a[T.int64(0), v_axis1, v_axis2, v_axis3, v_axis4, 
v_axis5, v_axis6],
+                        b[T.int64(0), v_axis1, v_axis2, v_axis3, v_axis4, 
v_axis5, v_axis6],
+                    )
+                    T.writes(
+                        compute[T.int64(0), v_axis1, v_axis2, v_axis3, 
v_axis4, v_axis5, v_axis6]
+                    )
+                    compute[T.int64(0), v_axis1, v_axis2, v_axis3, v_axis4, 
v_axis5, v_axis6] = (
+                        a[T.int64(0), v_axis1, v_axis2, v_axis3, v_axis4, 
v_axis5, v_axis6]
+                        + b[T.int64(0), v_axis1, v_axis2, v_axis3, v_axis4, 
v_axis5, v_axis6]
+                    )
+
+    # Relax entry point of this module: passes both uint8 inputs to the TIR
+    # kernel `AddExpected.add` through R.call_tir and returns its output tensor.
+    @R.function
+    def main(
+        a: R.Tensor((1, 4, 4, 16, 8, 8, 32), "uint8"),
+        b: R.Tensor((1, 4, 4, 16, 8, 8, 32), "uint8"),
+    ) -> R.Tensor((1, 4, 4, 16, 8, 8, 32), "uint8"):
+        out = R.call_tir(
+            AddExpected.add,
+            (a, b),
+            out_sinfo=R.Tensor((1, 4, 4, 16, 8, 8, 32), dtype="uint8"),
+        )
+        return out
+
+
[email protected]_module
+class SubBefore:
+    @T.prim_func(private=True)
+    def sub(
+        a: T.Buffer(
+            (T.int64(1), T.int64(4), T.int64(4), T.int64(16), T.int64(8), 
T.int64(8), T.int64(32)),
+            "uint8",
+        ),
+        b: T.Buffer(
+            (T.int64(1), T.int64(4), T.int64(4), T.int64(16), T.int64(8), 
T.int64(8), T.int64(32)),
+            "uint8",
+        ),
+        compute: T.Buffer(
+            (T.int64(1), T.int64(4), T.int64(4), T.int64(16), T.int64(8), 
T.int64(8), T.int64(32)),
+            "uint8",
+        ),
+    ):
+        T.func_attr(
+            {
+                "op_attrs": {"lhs_axis": 0, "op_name": "qnn.subtract", 
"rhs_axis": 0},
+                "op_pattern": 0,
+                "operator_name": "sub",
+                "tir.noalias": T.bool(True),
+            }
+        )
+        # with T.block("root"):
+        for axis0, axis1, axis2, axis3, axis4, axis5, axis6 in T.grid(
+            T.int64(1), T.int64(4), T.int64(4), T.int64(16), T.int64(8), 
T.int64(8), T.int64(32)
+        ):
+            with T.block("buffer_A_assumptions"):
+                v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6 
= T.axis.remap(
+                    "SSSSSSS", [axis0, axis1, axis2, axis3, axis4, axis5, 
axis6]
+                )
+                T.reads(a[v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, 
v_axis5, v_axis6])
+                T.writes()
+                T.assume(
+                    not (
+                        v_axis1 == T.int64(3)
+                        and T.int64(4) <= v_axis4
+                        or v_axis2 == T.int64(3)
+                        and T.int64(4) <= v_axis5
+                    )
+                    or a[v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, 
v_axis6]
+                    == T.uint8(0)
+                )
+
+        for axis0, axis1, axis2, axis3, axis4, axis5, axis6 in T.grid(
+            T.int64(1), T.int64(4), T.int64(4), T.int64(16), T.int64(8), 
T.int64(8), T.int64(32)
+        ):
+            with T.block("buffer_B_assumptions"):
+                v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6 
= T.axis.remap(
+                    "SSSSSSS", [axis0, axis1, axis2, axis3, axis4, axis5, 
axis6]
+                )
+                T.reads(b[v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, 
v_axis5, v_axis6])
+                T.writes()
+                T.assume(
+                    not (
+                        v_axis1 == T.int64(3)
+                        and T.int64(4) <= v_axis4
+                        or v_axis2 == T.int64(3)
+                        and T.int64(4) <= v_axis5
+                    )
+                    or b[v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, 
v_axis6]
+                    == T.uint8(0)
+                )
+
+        for axis0, axis1, axis2, axis3, axis4, axis5, axis6 in T.grid(
+            T.int64(1), T.int64(4), T.int64(4), T.int64(16), T.int64(8), 
T.int64(8), T.int64(32)
+        ):
+            with T.block("compute"):
+                v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6 
= T.axis.remap(
+                    "SSSSSSS", [axis0, axis1, axis2, axis3, axis4, axis5, 
axis6]
+                )
+                T.reads(
+                    a[v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, 
v_axis6],
+                    b[v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, 
v_axis6],
+                )
+                T.writes(compute[v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, 
v_axis5, v_axis6])
+                compute[
+                    v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, 
v_axis6
+                ] = T.if_then_else(
+                    v_axis1 == T.int64(3)
+                    and T.int64(4) <= v_axis4
+                    or v_axis2 == T.int64(3)
+                    and T.int64(4) <= v_axis5,
+                    T.uint8(0),
+                    a[v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, 
v_axis6]
+                    - b[v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, 
v_axis6],
+                )
+
+    @R.function
+    def main(
+        a: R.Tensor((1, 4, 4, 16, 8, 8, 32), "uint8"),
+        b: R.Tensor((1, 4, 4, 16, 8, 8, 32), "uint8"),
+    ) -> R.Tensor((1, 4, 4, 16, 8, 8, 32), "uint8"):
+        out = R.call_tir(
+            SubBefore.sub,
+            (a, b),
+            out_sinfo=R.Tensor((1, 4, 4, 16, 8, 8, 32), dtype="uint8"),
+        )
+        return out
+
+
[email protected]_module
+class SubExpected:
+    @T.prim_func(private=True)
+    def sub(
+        a: T.Buffer(
+            (T.int64(1), T.int64(4), T.int64(4), T.int64(16), T.int64(8), 
T.int64(8), T.int64(32)),
+            "uint8",
+        ),
+        b: T.Buffer(
+            (T.int64(1), T.int64(4), T.int64(4), T.int64(16), T.int64(8), 
T.int64(8), T.int64(32)),
+            "uint8",
+        ),
+        compute: T.Buffer(
+            (T.int64(1), T.int64(4), T.int64(4), T.int64(16), T.int64(8), 
T.int64(8), T.int64(32)),
+            "uint8",
+        ),
+    ):
+        T.func_attr(
+            {
+                "op_attrs": {"lhs_axis": 0, "op_name": "qnn.subtract", 
"rhs_axis": 0},
+                "op_pattern": 0,
+                "operator_name": "sub",
+                "tir.noalias": T.bool(True),
+            }
+        )
+        # with T.block("root"):
+        for axis0, axis1, axis2, axis3, axis4, axis5, axis6 in T.grid(
+            T.int64(1), T.int64(4), T.int64(4), T.int64(16), T.int64(8), 
T.int64(8), T.int64(32)
+        ):
+            with T.block("buffer_A_assumptions"):
+                v_axis0 = T.axis.spatial(T.int64(1), T.int64(0))
+                v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6 = 
T.axis.remap(
+                    "SSSSSS", [axis1, axis2, axis3, axis4, axis5, axis6]
+                )
+                T.reads(a[T.int64(0), v_axis1, v_axis2, v_axis3, v_axis4, 
v_axis5, v_axis6])
+                T.writes()
+                T.assume(
+                    (v_axis1 < T.int64(3) or v_axis4 < T.int64(4))
+                    and (v_axis2 < T.int64(3) or v_axis5 < T.int64(4))
+                    or a[T.int64(0), v_axis1, v_axis2, v_axis3, v_axis4, 
v_axis5, v_axis6]
+                    == T.uint8(0)
+                )
+
+        for axis0, axis1, axis2, axis3, axis4, axis5, axis6 in T.grid(
+            T.int64(1), T.int64(4), T.int64(4), T.int64(16), T.int64(8), 
T.int64(8), T.int64(32)
+        ):
+            with T.block("buffer_B_assumptions"):
+                v_axis0 = T.axis.spatial(T.int64(1), T.int64(0))
+                v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6 = 
T.axis.remap(
+                    "SSSSSS", [axis1, axis2, axis3, axis4, axis5, axis6]
+                )
+                T.reads(b[T.int64(0), v_axis1, v_axis2, v_axis3, v_axis4, 
v_axis5, v_axis6])
+                T.writes()
+                T.assume(
+                    (v_axis1 < T.int64(3) or v_axis4 < T.int64(4))
+                    and (v_axis2 < T.int64(3) or v_axis5 < T.int64(4))
+                    or b[T.int64(0), v_axis1, v_axis2, v_axis3, v_axis4, 
v_axis5, v_axis6]
+                    == T.uint8(0)
+                )
+
+        for axis0, axis1, axis2, axis3, axis4, axis5_0 in T.grid(
+            T.int64(1), T.int64(4), T.int64(4), T.int64(16), T.int64(8), 
T.int64(2)
+        ):
+            for axis5_1_axis6_fused in T.vectorized(T.int64(128)):
+                with T.block("compute"):
+                    v_axis0 = T.axis.spatial(T.int64(1), T.int64(0))
+                    v_axis1, v_axis2, v_axis3, v_axis4 = T.axis.remap(
+                        "SSSS", [axis1, axis2, axis3, axis4]
+                    )
+                    v_axis5 = T.axis.spatial(
+                        T.int64(8), axis5_0 * T.int64(4) + axis5_1_axis6_fused 
// T.int64(32)
+                    )
+                    v_axis6 = T.axis.spatial(T.int64(32), axis5_1_axis6_fused 
% T.int64(32))
+                    T.reads(
+                        a[T.int64(0), v_axis1, v_axis2, v_axis3, v_axis4, 
v_axis5, v_axis6],
+                        b[T.int64(0), v_axis1, v_axis2, v_axis3, v_axis4, 
v_axis5, v_axis6],
+                    )
+                    T.writes(
+                        compute[T.int64(0), v_axis1, v_axis2, v_axis3, 
v_axis4, v_axis5, v_axis6]
+                    )
+                    compute[T.int64(0), v_axis1, v_axis2, v_axis3, v_axis4, 
v_axis5, v_axis6] = (
+                        a[T.int64(0), v_axis1, v_axis2, v_axis3, v_axis4, 
v_axis5, v_axis6]
+                        - b[T.int64(0), v_axis1, v_axis2, v_axis3, v_axis4, 
v_axis5, v_axis6]
+                    )
+
+    @R.function
+    def main(
+        a: R.Tensor((1, 4, 4, 16, 8, 8, 32), "uint8"),
+        b: R.Tensor((1, 4, 4, 16, 8, 8, 32), "uint8"),
+    ) -> R.Tensor((1, 4, 4, 16, 8, 8, 32), "uint8"):
+        out = R.call_tir(
+            SubExpected.sub,
+            (a, b),
+            out_sinfo=R.Tensor((1, 4, 4, 16, 8, 8, 32), dtype="uint8"),
+        )
+        return out
+
+
[email protected]_module
+class MulBefore:
+    @T.prim_func(private=True)
+    def mul(
+        a: T.Buffer(
+            (T.int64(1), T.int64(4), T.int64(4), T.int64(16), T.int64(8), 
T.int64(8), T.int64(32)),
+            "uint8",
+        ),
+        b: T.Buffer(
+            (T.int64(1), T.int64(4), T.int64(4), T.int64(16), T.int64(8), 
T.int64(8), T.int64(32)),
+            "uint8",
+        ),
+        compute: T.Buffer(
+            (T.int64(1), T.int64(4), T.int64(4), T.int64(16), T.int64(8), 
T.int64(8), T.int64(32)),
+            "uint8",
+        ),
+    ):
+        T.func_attr(
+            {
+                "op_attrs": {"lhs_axis": 0, "op_name": "qnn.mul", "rhs_axis": 
0},
+                "op_pattern": 0,
+                "operator_name": "mul",
+                "tir.noalias": T.bool(True),
+            }
+        )
+        # with T.block("root"):
+        for axis0, axis1, axis2, axis3, axis4, axis5, axis6 in T.grid(
+            T.int64(1), T.int64(4), T.int64(4), T.int64(16), T.int64(8), 
T.int64(8), T.int64(32)
+        ):
+            with T.block("buffer_A_assumptions"):
+                v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6 
= T.axis.remap(
+                    "SSSSSSS", [axis0, axis1, axis2, axis3, axis4, axis5, 
axis6]
+                )
+                T.reads(a[v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, 
v_axis5, v_axis6])
+                T.writes()
+                T.assume(
+                    not (
+                        v_axis1 == T.int64(3)
+                        and T.int64(4) <= v_axis4
+                        or v_axis2 == T.int64(3)
+                        and T.int64(4) <= v_axis5
+                    )
+                    or a[v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, 
v_axis6]
+                    == T.uint8(0)
+                )
+
+        for axis0, axis1, axis2, axis3, axis4, axis5, axis6 in T.grid(
+            T.int64(1), T.int64(4), T.int64(4), T.int64(16), T.int64(8), 
T.int64(8), T.int64(32)
+        ):
+            with T.block("buffer_B_assumptions"):
+                v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6 
= T.axis.remap(
+                    "SSSSSSS", [axis0, axis1, axis2, axis3, axis4, axis5, 
axis6]
+                )
+                T.reads(b[v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, 
v_axis5, v_axis6])
+                T.writes()
+                T.assume(
+                    not (
+                        v_axis1 == T.int64(3)
+                        and T.int64(4) <= v_axis4
+                        or v_axis2 == T.int64(3)
+                        and T.int64(4) <= v_axis5
+                    )
+                    or b[v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, 
v_axis6]
+                    == T.uint8(0)
+                )
+
+        for axis0, axis1, axis2, axis3, axis4, axis5, axis6 in T.grid(
+            T.int64(1), T.int64(4), T.int64(4), T.int64(16), T.int64(8), 
T.int64(8), T.int64(32)
+        ):
+            with T.block("compute"):
+                v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6 
= T.axis.remap(
+                    "SSSSSSS", [axis0, axis1, axis2, axis3, axis4, axis5, 
axis6]
+                )
+                T.reads(
+                    a[v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, 
v_axis6],
+                    b[v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, 
v_axis6],
+                )
+                T.writes(compute[v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, 
v_axis5, v_axis6])
+                compute[
+                    v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, 
v_axis6
+                ] = T.if_then_else(
+                    v_axis1 == T.int64(3)
+                    and T.int64(4) <= v_axis4
+                    or v_axis2 == T.int64(3)
+                    and T.int64(4) <= v_axis5,
+                    T.uint8(0),
+                    a[v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, 
v_axis6]
+                    * b[v_axis0, v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, 
v_axis6],
+                )
+
+    @R.function
+    def main(
+        a: R.Tensor((1, 4, 4, 16, 8, 8, 32), "uint8"),
+        b: R.Tensor((1, 4, 4, 16, 8, 8, 32), "uint8"),
+    ) -> R.Tensor((1, 4, 4, 16, 8, 8, 32), "uint8"):
+        out = R.call_tir(
+            MulBefore.mul,
+            (a, b),
+            out_sinfo=R.Tensor((1, 4, 4, 16, 8, 8, 32), dtype="uint8"),
+        )
+        return out
+
+
[email protected]_module
+class MulExpected:
+    @T.prim_func(private=True)
+    def mul(
+        a: T.Buffer(
+            (T.int64(1), T.int64(4), T.int64(4), T.int64(16), T.int64(8), 
T.int64(8), T.int64(32)),
+            "uint8",
+        ),
+        b: T.Buffer(
+            (T.int64(1), T.int64(4), T.int64(4), T.int64(16), T.int64(8), 
T.int64(8), T.int64(32)),
+            "uint8",
+        ),
+        compute: T.Buffer(
+            (T.int64(1), T.int64(4), T.int64(4), T.int64(16), T.int64(8), 
T.int64(8), T.int64(32)),
+            "uint8",
+        ),
+    ):
+        T.func_attr(
+            {
+                "op_attrs": {"lhs_axis": 0, "op_name": "qnn.mul", "rhs_axis": 
0},
+                "op_pattern": 0,
+                "operator_name": "mul",
+                "tir.noalias": T.bool(True),
+            }
+        )
+        # with T.block("root"):
+        for axis0, axis1, axis2, axis3, axis4, axis5, axis6 in T.grid(
+            T.int64(1), T.int64(4), T.int64(4), T.int64(16), T.int64(8), 
T.int64(8), T.int64(32)
+        ):
+            with T.block("buffer_A_assumptions"):
+                v_axis0 = T.axis.spatial(T.int64(1), T.int64(0))
+                v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6 = 
T.axis.remap(
+                    "SSSSSS", [axis1, axis2, axis3, axis4, axis5, axis6]
+                )
+                T.reads(a[T.int64(0), v_axis1, v_axis2, v_axis3, v_axis4, 
v_axis5, v_axis6])
+                T.writes()
+                T.assume(
+                    (v_axis1 < T.int64(3) or v_axis4 < T.int64(4))
+                    and (v_axis2 < T.int64(3) or v_axis5 < T.int64(4))
+                    or a[T.int64(0), v_axis1, v_axis2, v_axis3, v_axis4, 
v_axis5, v_axis6]
+                    == T.uint8(0)
+                )
+
+        for axis0, axis1, axis2, axis3, axis4, axis5, axis6 in T.grid(
+            T.int64(1), T.int64(4), T.int64(4), T.int64(16), T.int64(8), 
T.int64(8), T.int64(32)
+        ):
+            with T.block("buffer_B_assumptions"):
+                v_axis0 = T.axis.spatial(T.int64(1), T.int64(0))
+                v_axis1, v_axis2, v_axis3, v_axis4, v_axis5, v_axis6 = 
T.axis.remap(
+                    "SSSSSS", [axis1, axis2, axis3, axis4, axis5, axis6]
+                )
+                T.reads(b[T.int64(0), v_axis1, v_axis2, v_axis3, v_axis4, 
v_axis5, v_axis6])
+                T.writes()
+                T.assume(
+                    (v_axis1 < T.int64(3) or v_axis4 < T.int64(4))
+                    and (v_axis2 < T.int64(3) or v_axis5 < T.int64(4))
+                    or b[T.int64(0), v_axis1, v_axis2, v_axis3, v_axis4, 
v_axis5, v_axis6]
+                    == T.uint8(0)
+                )
+
+        for axis0, axis1, axis2, axis3, axis4, axis5_0 in T.grid(
+            T.int64(1), T.int64(4), T.int64(4), T.int64(16), T.int64(8), 
T.int64(2)
+        ):
+            for axis5_1_axis6_fused in T.vectorized(T.int64(128)):
+                with T.block("compute"):
+                    v_axis0 = T.axis.spatial(T.int64(1), T.int64(0))
+                    v_axis1, v_axis2, v_axis3, v_axis4 = T.axis.remap(
+                        "SSSS", [axis1, axis2, axis3, axis4]
+                    )
+                    v_axis5 = T.axis.spatial(
+                        T.int64(8), axis5_0 * T.int64(4) + axis5_1_axis6_fused 
// T.int64(32)
+                    )
+                    v_axis6 = T.axis.spatial(T.int64(32), axis5_1_axis6_fused 
% T.int64(32))
+                    T.reads(
+                        a[T.int64(0), v_axis1, v_axis2, v_axis3, v_axis4, 
v_axis5, v_axis6],
+                        b[T.int64(0), v_axis1, v_axis2, v_axis3, v_axis4, 
v_axis5, v_axis6],
+                    )
+                    T.writes(
+                        compute[T.int64(0), v_axis1, v_axis2, v_axis3, 
v_axis4, v_axis5, v_axis6]
+                    )
+                    compute[T.int64(0), v_axis1, v_axis2, v_axis3, v_axis4, 
v_axis5, v_axis6] = (
+                        a[T.int64(0), v_axis1, v_axis2, v_axis3, v_axis4, 
v_axis5, v_axis6]
+                        * b[T.int64(0), v_axis1, v_axis2, v_axis3, v_axis4, 
v_axis5, v_axis6]
+                    )
+
+    @R.function
+    def main(
+        a: R.Tensor((1, 4, 4, 16, 8, 8, 32), "uint8"),
+        b: R.Tensor((1, 4, 4, 16, 8, 8, 32), "uint8"),
+    ) -> R.Tensor((1, 4, 4, 16, 8, 8, 32), "uint8"):
+        out = R.call_tir(
+            MulExpected.mul,
+            (a, b),
+            out_sinfo=R.Tensor((1, 4, 4, 16, 8, 8, 32), dtype="uint8"),
+        )
+        return out
+
+
+def test_add_primfunc_overcompute():
+    add_after = tvm.tir.transform.UseAssumeToReduceBranches()(AddBefore)
+    tvm.ir.structural_equal(add_after["add"], AddExpected["add"], 
map_free_vars=True)
+
+
+def test_sub_primfunc_overcompute():
+    sub_after = tvm.tir.transform.UseAssumeToReduceBranches()(SubBefore)
+    tvm.ir.structural_equal(sub_after["sub"], SubExpected["sub"], 
map_free_vars=True)
+
+
+def test_mul_primfunc_overcompute():
+    mul_after = tvm.tir.transform.UseAssumeToReduceBranches()(MulBefore)
+    tvm.ir.structural_equal(mul_after["mul"], MulExpected["mul"], 
map_free_vars=True)
+
+
+# Allow running this test file directly as a script; delegates to tvm.testing.main().
+if __name__ == "__main__":
+    tvm.testing.main()

Reply via email to