This is an automated email from the ASF dual-hosted git repository.

tkonolige pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new d8678a6a9a [TIR] CSE pass : Restrict the equivalence to be decided by a normal form - avoids comparison of terms (#11574)
d8678a6a9a is described below

commit d8678a6a9aa7962b658efb603e27d83ea7737a02
Author: FranckQC <[email protected]>
AuthorDate: Thu Jun 9 11:32:15 2022 -0500

    [TIR] CSE pass : Restrict the equivalence to be decided by a normal form - avoids comparison of terms (#11574)
    
    The CSE pass had been designed to potentially allow comparisons (and commonings) of equivalent terms (like (x+y)+z and x+(y+z)), where **the notion of equivalence was customizable, and no assumption was made about it**. This meant that the implementation of the equivalence test function `EquivalentTerms()`, which so far just called the syntactic equality test `EqualTerms()`, could later be replaced by a smarter equivalence test.
    
    However, such a generic way of comparing elements meant that in the function `SyntacticToSemanticComputations()`, which goes from a hashtable of syntactic entities to what I called a vector of "semantic entities" (canonical representatives of the equivalence classes of terms), **the only option was to compare every pair of terms**.
    That made this function quadratic, and there was no way around it: to merge equivalent entities into their equivalence class, we had to compare them.
    
    **This PR essentially does the following:**
    
    - When computing the equivalence classes of terms (i.e. when transforming a `ComputationTable`, which is a hashtable, into a vector of equivalence classes): **instead of comparing each pair of terms, the pass now relies on a normalization procedure that computes a normal form for each term**, and terms with the same normal form are merged through a hashtable keyed by normal forms.
    This turns a small part of the algorithm from quadratic into n log n. However, improvements are hard to observe in practice, in particular for average-sized programs, as that part went from a "small" quadratic to a "big" n log n (looking things up in a hashtable, copying them into a vector, sorting, etc.).
    The complexity probably went from roughly O((n² - n)/2 + n log n) to roughly O(3n + n log n), where the n log n part comes from the sort needed for determinism, so noticeable gains should only be expected for very large programs.
    
    - Gives the user full control to turn the semantic comparison of terms ON or OFF. It is OFF by default (unsurprisingly, compilation takes noticeably longer with it ON), which means that by default the equivalence coincides with the (syntactic) equality of terms.
        As the pass was written with the possibility of doing these additional commonings (like (x+y)+z and x+(y+z)), it was a good time to plug this in completely, all the way up to the Python user, who can now turn it ON if they want to. But again, it is OFF by default, so nothing changes in practice.
    
    To turn it ON, simply call `build()` inside:
    `with tvm.transform.PassContext(config={'tir.enable_equiv_terms_in_cse_tir': True}):`
    
    - When this boolean is ON, the pass uses a simple implementation of the normalization function based on `arith::Analyzer::Simplify`, as noted in https://github.com/apache/tvm/pull/10544 . Note that this is not a real normalization procedure, as it is incomplete (i.e. it is not guaranteed to converge to the normal form), but it is correct (the simplified term is always equivalent to the original one), and it works well with most properties: associativity of +, distributivity of * over +, etc.
    
    - Clarifies and enhances the test base for the pass. In particular, it adds the tests that were written in https://github.com/apache/tvm/pull/10544 but which did not make it in.
    
    - Also adds the test (https://github.com/AndrewZhaoLuo/TVM-Sandbox/blob/19284ddbd6bb28af61c0c2aa8bb334c5c53731a7/tir/test_inconsistent_tir_lowering.py#L1) demonstrating the (older) non-deterministic lowering and turns it into a proper test, as I found it useful for making sure that this does not happen again. It has been copied from https://github.com/apache/tvm/pull/10663 and only slightly adapted (in particular for doing the comparison of hashes automatically instead of printing them [...]
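The merging step described in the first bullet can be sketched in a few lines of Python. This is a hedged, self-contained illustration, not the pass's C++ code. The tuple-based term representation, `normalize()` (which only handles associativity and commutativity of `+`, unlike `arith::Analyzer::Simplify`) and `syntactic_to_semantic()` are hypothetical stand-ins for `PrimExpr`, `NormalizeTerm()` and `SyntacticToSemanticComputations()`.

```python
from typing import Dict, List, Tuple, Union

# Hypothetical toy term representation (not TVM's PrimExpr):
# a variable is a str, and a sum is a tuple ("+", lhs, rhs).
Term = Union[str, tuple]


def normalize(term: Term) -> Term:
    """Toy normal form for sums: flatten nested '+' nodes (associativity)
    and sort the operands (commutativity)."""
    operands: List[Term] = []

    def flatten(t: Term) -> None:
        if isinstance(t, tuple) and t[0] == "+":
            flatten(t[1])
            flatten(t[2])
        else:
            operands.append(t)

    flatten(term)
    if len(operands) == 1:
        return operands[0]
    return ("+",) + tuple(sorted(operands, key=repr))


def syntactic_to_semantic(
    table: Dict[Term, int], identify_equiv_terms: bool
) -> List[Tuple[Term, int]]:
    """Merge the terms of `table` (term -> count) into equivalence classes.

    Returns one (representative, total count) pair per class.
    """
    if not identify_equiv_terms:
        # Equivalence is plain syntactic equality: just copy the table.
        return list(table.items())

    # Key: normal form. Value: [representative, accumulated count].
    norm_table: Dict[Term, list] = {}
    # Visit the items in a deterministic order (by string form) so that the
    # representative of each class (the first member seen) is always the same.
    for term, count in sorted(table.items(), key=lambda kv: repr(kv[0])):
        key = normalize(term)  # each term is normalized exactly once
        if key not in norm_table:
            norm_table[key] = [term, count]
        else:
            norm_table[key][1] += count
    return [(rep, count) for rep, count in norm_table.values()]
```

Each term is normalized once and merged with a single hash lookup, instead of being compared against every class found so far; the only super-linear step left in this sketch is the deterministic sort.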
---
 include/tvm/tir/transform.h                        |   3 +-
 python/tvm/tir/transform/transform.py              |   4 +-
 src/driver/driver_api.cc                           |   6 +-
 src/tir/transforms/common_subexpr_elim.cc          |  96 +++++---
 src/tir/transforms/common_subexpr_elim.h           |   8 +-
 src/tir/transforms/common_subexpr_elim_tools.cc    | 145 +++++++++---
 src/tir/transforms/common_subexpr_elim_tools.h     |  10 +-
 .../test_tir_transform_common_subexpr_elim.py      | 260 ++++++++++++++++-----
 8 files changed, 409 insertions(+), 123 deletions(-)

diff --git a/include/tvm/tir/transform.h b/include/tvm/tir/transform.h
index 24c3cfa78f..4612d5ad3f 100644
--- a/include/tvm/tir/transform.h
+++ b/include/tvm/tir/transform.h
@@ -470,9 +470,10 @@ TVM_DLL Pass LowerVtcmAlloc();
  * \brief Implements a Common Subexpression Elimination (CSE) for TIR
  *        which introduces let-in bindings for duplicated sub-expressions.
  * \param enable_cse_tir Whether common subexpression elimination is enabled.
+ * \param identify_equiv_terms Whether equivalent terms should be identified.
  * \return The pass.
  */
-TVM_DLL Pass CommonSubexprElimTIR(bool enable_cse_tir = true);
+TVM_DLL Pass CommonSubexprElimTIR(bool enable_cse_tir = true, bool identify_equiv_terms = false);
 
 /*!
  * \brief Unify all the thread bindings for "blockIdx.x/y/z", "threadIdx.x/y/z", and
diff --git a/python/tvm/tir/transform/transform.py b/python/tvm/tir/transform/transform.py
index 802fdc576c..1bed29c560 100644
--- a/python/tvm/tir/transform/transform.py
+++ b/python/tvm/tir/transform/transform.py
@@ -324,7 +324,7 @@ def BF16TypeLowering():
     return _ffi_api.BF16TypeLowering()  # type: ignore
 
 
-def CommonSubexprElimTIR(enable_cse_tir: bool = True):
+def CommonSubexprElimTIR(enable_cse_tir: bool = True, identify_equiv_terms: bool = False):
     """Replace redundant computations by new variables.
 
     Returns
@@ -332,7 +332,7 @@ def CommonSubexprElimTIR(enable_cse_tir: bool = True):
     fpass : tvm.transform.Pass
         The result pass
     """
-    return _ffi_api.CommonSubexprElimTIR(enable_cse_tir)  # type: ignore
+    return _ffi_api.CommonSubexprElimTIR(enable_cse_tir, identify_equiv_terms)  # type: ignore
 
 
 def RewriteUnsafeSelect():
diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc
index 7df1a844ac..7706f229c9 100644
--- a/src/driver/driver_api.cc
+++ b/src/driver/driver_api.cc
@@ -45,6 +45,7 @@ TVM_REGISTER_PASS_CONFIG_OPTION("tir.instrument_bound_checkers", Bool);
 TVM_REGISTER_PASS_CONFIG_OPTION("tir.disable_assert", Bool);
 TVM_REGISTER_PASS_CONFIG_OPTION("tir.disable_vectorize", Bool);
 TVM_REGISTER_PASS_CONFIG_OPTION("tir.disable_cse_tir", Bool);
+TVM_REGISTER_PASS_CONFIG_OPTION("tir.enable_equiv_terms_in_cse_tir", Bool);
 TVM_REGISTER_PASS_CONFIG_OPTION("tir.disable_storage_rewrite", Bool);
 TVM_REGISTER_PASS_CONFIG_OPTION("tir.is_entry_func", Bool);
 TVM_REGISTER_PASS_CONFIG_OPTION("tir.add_lower_pass", Array<Array<ObjectRef>>);
@@ -198,6 +199,8 @@ Array<tvm::transform::Pass> CreatePassList(bool disable_loop_partition) {
   bool instrument_bound_checkers =
       pass_ctx->GetConfig<Bool>("tir.instrument_bound_checkers", Bool(false)).value();
   bool disable_cse_tir = pass_ctx->GetConfig<Bool>("tir.disable_cse_tir", Bool(false)).value();
+  bool enable_equiv_terms_in_cse_tir =
+      pass_ctx->GetConfig<Bool>("tir.enable_equiv_terms_in_cse_tir", Bool(false)).value();
 
   // Get any user-added passes
   Array<Array<ObjectRef>> add_lower_pass =
@@ -289,7 +292,8 @@ Array<tvm::transform::Pass> CreatePassList(bool disable_loop_partition) {
     pass_list.push_back(tir::transform::InstrumentBoundCheckers());
   }
 
-  pass_list.push_back(tir::transform::CommonSubexprElimTIR(!disable_cse_tir));
+  pass_list.push_back(
+      tir::transform::CommonSubexprElimTIR(!disable_cse_tir, enable_equiv_terms_in_cse_tir));
 
   return pass_list;
 }
diff --git a/src/tir/transforms/common_subexpr_elim.cc b/src/tir/transforms/common_subexpr_elim.cc
index d43b30d17b..290f920e3f 100644
--- a/src/tir/transforms/common_subexpr_elim.cc
+++ b/src/tir/transforms/common_subexpr_elim.cc
@@ -60,7 +60,7 @@ namespace tir {
           to collect them for the CSE pass, but we also won't even want to collect computations
           that contain them.
           The reason is that reusing such computations would change the semantics of the program,
-          and therefore before doing any introduction of variable or any reuse of already introduced
+          and therefore before doing any introduction of var or any reuse of already introduced
           variables, we will make sure that the computation being considered is not forbidden, and
           that it does not even contain a forbidden computation.
  * \param expr The expression to check
@@ -120,6 +120,42 @@ bool CommonSubexpressionEliminator::CanContainEligibleComputations(const PrimExp
   return true;
 }
 
+/*!
+ * \brief Implements an order on pairs (expression,frequency). First attempts to compare them
+          using the size of the expression. If it is the same, decides something else still
+          deterministic.
+ * \param a The first pair
+ * \param b The second pair
+ * \return A boolean telling if the first pair `a` comes before the second pair `b`
+ * \note We need this order to be deterministic in order to have a fully deterministic pass,
+ *       as we will deal with elements that are coming from a hashtable, but the order in which
+ *       they appeared in the hashtable was based on some runtime addresses, so it can potentially
+ *       change with every execution.
+ */
+bool CommonSubexpressionEliminator::OrderOnExprAndFrequency(std::pair<PrimExpr, size_t> a,
+                                                            std::pair<PrimExpr, size_t> b) {
+  size_t a_size = CalculateExprComplexity(a.first);
+  size_t b_size = CalculateExprComplexity(b.first);
+
+  // Criteria 1 - Size of the expression comes first
+  // `a` comes before `b` if the size of `a` is bigger
+  if (a_size > b_size) {
+    return true;
+  }
+  // `a` does NOT come before `b` if the size of `b` is bigger
+  if (b_size > a_size) {
+    return false;
+  }
+
+  // Criteria 2 - If they had the same size, use the lexicographic order as a last resort
+  // as we need a deterministic order
+  std::stringstream a_stream;
+  std::stringstream b_stream;
+  a_stream << a.first;
+  b_stream << b.first;
+  return (a_stream.str().compare(b_stream.str()) < 0);
+}
+
 /*!
  * \brief Generates a new fresh variable, whose name will be cse_var_i.
  * \param type_annotation The type of the new variable to generate
@@ -166,10 +202,12 @@ int CommonSubexpressionEliminator::GetNbVarGenerated() { return nb_var_; }
                           of the function being analyzed
  * \return A new statement where CSE has been performed
  */
-Stmt CommonSubexpressionEliminator::PerformCSE(const Stmt& stmt, const Context& context_init) {
+Stmt CommonSubexpressionEliminator::PerformCSE(const Stmt& stmt, const Context& context_init,
+                                               bool identify_equiv_terms) {
   // As this function is being called for each PrimFunc definition, we create a new instance
   // for the one we are having now.
-  CommonSubexpressionEliminator common_subexpression_eliminator(stmt, context_init);
+  CommonSubexpressionEliminator common_subexpression_eliminator(stmt, context_init,
+                                                                identify_equiv_terms);
   return common_subexpression_eliminator.VisitStmt(stmt);
 }
 
@@ -179,8 +217,9 @@ Stmt CommonSubexpressionEliminator::PerformCSE(const Stmt& stmt, const Context&
                         formal parameters of the function that will be analyzed
  */
 CommonSubexpressionEliminator::CommonSubexpressionEliminator(const Stmt& stmt,
-                                                             const Context& context_init)
-    : initial_body_(stmt), context_(context_init) {}
+                                                             const Context& context_init,
+                                                             bool identify_equiv_terms)
+    : initial_body_(stmt), context_(context_init), identify_equiv_terms_(identify_equiv_terms) {}
 
 /*!
  * \brief The method which overrides the generic dispatcher of StmtExprMutator.
@@ -200,28 +239,28 @@ PrimExpr CommonSubexpressionEliminator::VisitExpr(const PrimExpr& expr) {
   // Transform the hashtable of *syntactic* eligible computations into a vector of pairs
   // containing *semantic* entities, i.e. where equivalent computations are merged.
   std::vector<std::pair<PrimExpr, size_t>> semantic_comp_done_by_expr =
-      SyntacticToSemanticComputations(table_syntactic_comp_done_by_expr);
+      SyntacticToSemanticComputations(table_syntactic_comp_done_by_expr, identify_equiv_terms_);
 
   // Sort the vector of semantic entities by decreasing size
   std::sort(semantic_comp_done_by_expr.begin(), semantic_comp_done_by_expr.end(),
-            [](std::pair<PrimExpr, size_t> a, std::pair<PrimExpr, size_t> b) {
-              return (CalculateExprComplexity(a.first) > CalculateExprComplexity(b.first));
-            });
+            OrderOnExprAndFrequency);
 
   // For each computation done (considering them from biggest to smallest)
   for (size_t i = 0; i < semantic_comp_done_by_expr.size(); i++) {
     std::pair<PrimExpr, size_t>& computation_and_nb = semantic_comp_done_by_expr[i];
 
+    bool ident_equiv_terms = identify_equiv_terms_;  // To avoid the capture of "this"
+
     // The predicate later used (when doing replacements) to select expressions that are
     // equivalent to the current computation (`computation_and_nb.first`)
     std::function<bool(const PrimExpr&)> predicate_selector =
-        [computation_and_nb](const PrimExpr& current_expr) {
+        [computation_and_nb, ident_equiv_terms](const PrimExpr& current_expr) {
           // `current_expr` should be equivalent to `computation_and_nb.first`, but we also check
           // that `current_expr` is an eligible computation even if we know that
           // `computation_and_nb.first` is eligible by construction, in case that one day the
           // equivalence relation would not preserve the eligibility any more (even though that
           // would probably be a very weird equivalence).
-          return (EquivalentTerms(current_expr, computation_and_nb.first) &&
+          return (EquivalentTerms(current_expr, computation_and_nb.first, ident_equiv_terms) &&
                   IsEligibleComputation(current_expr));
         };
 
@@ -229,10 +268,11 @@ PrimExpr CommonSubexpressionEliminator::VisitExpr(const PrimExpr& expr) {
     // equivalent to `computation_and_nb.first`
     auto it_on_var = std::find_if(
         context_.begin(), context_.end(),
-        [computation_and_nb](const std::pair<Var, MaybeValue>& var_and_value) {
+        [computation_and_nb, ident_equiv_terms](const std::pair<Var, MaybeValue>& var_and_value) {
           // Note : safe to call value() as we check has_value() just before
           return (var_and_value.second.has_value() &&
-                  EquivalentTerms(var_and_value.second.value(), computation_and_nb.first));
+                  EquivalentTerms(var_and_value.second.value(), computation_and_nb.first,
+                                  ident_equiv_terms));
         });
 
     // Case where we have a perfectly equivalent computation already available in a variable
@@ -298,7 +338,8 @@ PrimExpr CommonSubexpressionEliminator::VisitExpr(const PrimExpr& expr) {
         // The following insertion will maintain `semantic_comp_done_by_expr` sorted (by
         // decreasing size/complexity), and it will only insert at locations > i as the
         // direct subexprs are necessarily smaller than the current computation.
-        InsertVectorToSortedSemanticComputations(&semantic_comp_done_by_expr, direct_subexprs);
+        InsertVectorToSortedSemanticComputations(&semantic_comp_done_by_expr, direct_subexprs,
+                                                 identify_equiv_terms_);
       }
     }
     // Note : we do not remove the current element, as we never look back in the local vector
@@ -378,28 +419,28 @@ Stmt CommonSubexpressionEliminator::VisitStmt(const Stmt& stmt) {
   // Transform the hashtable of *syntactic* eligible computations into a vector of pairs
   // containing *semantic* entities, i.e. where equivalent computations are merged.
   std::vector<std::pair<PrimExpr, size_t>> semantic_comp_done_by_stmt =
-      SyntacticToSemanticComputations(table_syntactic_comp_done_by_stmt);
+      SyntacticToSemanticComputations(table_syntactic_comp_done_by_stmt, identify_equiv_terms_);
 
   // Sort the vector of semantic entities by decreasing size
   std::sort(semantic_comp_done_by_stmt.begin(), semantic_comp_done_by_stmt.end(),
-            [](std::pair<PrimExpr, size_t> a, std::pair<PrimExpr, size_t> b) {
-              return (CalculateExprComplexity(a.first) > CalculateExprComplexity(b.first));
-            });
+            OrderOnExprAndFrequency);
 
   // For each computation done (considering them from biggest to smallest)
   for (size_t i = 0; i < semantic_comp_done_by_stmt.size(); i++) {
     std::pair<PrimExpr, size_t>& computation_and_nb = semantic_comp_done_by_stmt[i];
 
+    bool ident_equiv_terms = identify_equiv_terms_;  // To avoid the capture of "this"
+
     // The predicate later used (when doing replacements) to select expressions that are
     // equivalent to the current computation (`computation_and_nb.first`)
     std::function<bool(const PrimExpr&)> predicate_selector =
-        [computation_and_nb](const PrimExpr& current_expr) {
+        [computation_and_nb, ident_equiv_terms](const PrimExpr& current_expr) {
           // `current_expr` should be equivalent to `computation_and_nb.first`, but we also check
           // that `current_expr` is an eligible computation even if we know that
           // `computation_and_nb.first` is eligible by construction, in case that one day the
           // equivalence relation would not preserve the eligibility any more (even though that
           // would probably be a very weird equivalence).
-          return (EquivalentTerms(current_expr, computation_and_nb.first) &&
+          return (EquivalentTerms(current_expr, computation_and_nb.first, ident_equiv_terms) &&
                   IsEligibleComputation(current_expr));
         };
 
@@ -407,10 +448,11 @@ Stmt CommonSubexpressionEliminator::VisitStmt(const Stmt& stmt) {
     // equivalent to `computation_and_nb.first`
     auto it_on_var = std::find_if(
         context_.begin(), context_.end(),
-        [computation_and_nb](const std::pair<Var, MaybeValue>& var_and_value) {
+        [computation_and_nb, ident_equiv_terms](const std::pair<Var, MaybeValue>& var_and_value) {
           // Note : safe to call value() as we check has_value() just before
           return (var_and_value.second.has_value() &&
-                  EquivalentTerms(var_and_value.second.value(), computation_and_nb.first));
+                  EquivalentTerms(var_and_value.second.value(), computation_and_nb.first,
+                                  ident_equiv_terms));
         });
 
     // Case where we have a perfectly equivalent computation already available in a variable
@@ -477,7 +519,8 @@ Stmt CommonSubexpressionEliminator::VisitStmt(const Stmt& stmt) {
         // The following insertion will maintain `semantic_comp_done_by_stmt` sorted (by
         // decreasing size/complexity), and it will only insert at locations > i as the
         // direct subexprs are necessarily smaller than the current computation.
-        InsertVectorToSortedSemanticComputations(&semantic_comp_done_by_stmt, direct_subexprs);
+        InsertVectorToSortedSemanticComputations(&semantic_comp_done_by_stmt, direct_subexprs,
+                                                 identify_equiv_terms_);
       }
     }
     // Note : we do not remove the current element, as we never look back in the local vector
@@ -587,8 +630,8 @@ namespace transform {
   * \brief The function which returns the pass for the Common Subexpression Elimination.
  * \return The pass for performing CSE.
  */
-Pass CommonSubexprElimTIR(bool enable_cse_tir) {
-  auto pass_func = [enable_cse_tir](PrimFunc f, IRModule m, PassContext ctx) {
+Pass CommonSubexprElimTIR(bool enable_cse_tir, bool identify_equiv_terms) {
+  auto pass_func = [enable_cse_tir, identify_equiv_terms](PrimFunc f, IRModule m, PassContext ctx) {
     if (enable_cse_tir) {
       auto* n = f.CopyOnWrite();
       Context context_init;
@@ -603,7 +646,8 @@ Pass CommonSubexprElimTIR(bool enable_cse_tir) {
 
       // Do the Common Subexpression Elimination on the body of the function, with the initial
       // context that we have prepared
-      n->body = CommonSubexpressionEliminator::PerformCSE(std::move(f->body), context_init);
+      n->body = CommonSubexpressionEliminator::PerformCSE(std::move(f->body), context_init,
+                                                          identify_equiv_terms);
     }
 
     return f;
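The deterministic order implemented by `OrderOnExprAndFrequency()` above (bigger expressions first, then the printed form as a tie-breaker) can be mimicked with a sort key. In this hypothetical Python sketch, tuples stand in for `PrimExpr`, node counting stands in for `CalculateExprComplexity()`, and `repr()` stands in for printing the expression into a `std::stringstream`; it is an illustration under those assumptions, not the pass itself.

```python
from typing import List, Tuple, Union

# Hypothetical toy terms: a str is a leaf, a tuple is (op, operand, ...).
Term = Union[str, tuple]


def expr_size(term: Term) -> int:
    """Stand-in for CalculateExprComplexity(): number of nodes in the term."""
    if isinstance(term, tuple):
        return 1 + sum(expr_size(child) for child in term[1:])
    return 1


def order_key(pair: Tuple[Term, int]) -> Tuple[int, str]:
    """Bigger terms first; ties broken by the string form of the term.

    Like the C++ comparator, the count (second element) is not consulted.
    """
    term, _count = pair
    return (-expr_size(term), repr(term))


def sort_semantic(pairs: List[Tuple[Term, int]]) -> List[Tuple[Term, int]]:
    """Sort (term, count) pairs independently of their input order."""
    return sorted(pairs, key=order_key)
```

With the old size-only comparator, two different terms of equal size could come out in either order depending on the hashtable's iteration order; with the string tie-breaker the result is the same on every run, as long as distinct terms print differently.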
diff --git a/src/tir/transforms/common_subexpr_elim.h b/src/tir/transforms/common_subexpr_elim.h
index 484d93c769..5c14caf1a6 100644
--- a/src/tir/transforms/common_subexpr_elim.h
+++ b/src/tir/transforms/common_subexpr_elim.h
@@ -55,7 +55,7 @@ using Context = std::vector<std::pair<Var, MaybeValue>>;
 class CommonSubexpressionEliminator : public StmtExprMutator {
  public:
   // Toplevel (static) function
-  static Stmt PerformCSE(const Stmt& stmt, const Context& context_init);
+  static Stmt PerformCSE(const Stmt& stmt, const Context& context_init, bool identify_equiv_terms);
 
   PrimExpr VisitExpr(const PrimExpr& expr) override;
   Stmt VisitStmt(const Stmt& stmt) override;
@@ -64,7 +64,8 @@ class CommonSubexpressionEliminator : public StmtExprMutator {
 
  protected:
   // Constructor
-  CommonSubexpressionEliminator(const Stmt& stmt, const Context& context_init);
+  CommonSubexpressionEliminator(const Stmt& stmt, const Context& context_init,
+                                bool identify_equiv_terms);
 
   PrimExpr VisitExpr_(const LetNode* op) override;
 
@@ -77,9 +78,12 @@ class CommonSubexpressionEliminator : public StmtExprMutator {
   int num_last_try_ = 0;  // Number of the last variable tried
   int nb_var_ = 0;        // Number of variables introduced by the CSE pass
 
+  bool identify_equiv_terms_ = false;
+
   static bool ForbiddenComputation(const PrimExpr& expr);
   static bool IsEligibleComputation(const PrimExpr& expr);
   static bool CanContainEligibleComputations(const PrimExpr& expr);
+  static bool OrderOnExprAndFrequency(std::pair<PrimExpr, size_t> a, std::pair<PrimExpr, size_t> b);
   Var GenerateNewVar(DataType type_annotation);
 };
 
diff --git a/src/tir/transforms/common_subexpr_elim_tools.cc b/src/tir/transforms/common_subexpr_elim_tools.cc
index d39d211ba1..b5b1bfccdf 100644
--- a/src/tir/transforms/common_subexpr_elim_tools.cc
+++ b/src/tir/transforms/common_subexpr_elim_tools.cc
@@ -25,7 +25,8 @@
 
 #include "common_subexpr_elim_tools.h"
 
-#include <tvm/ir/transform.h>  // For the class Pass and the class PassContext
+#include <tvm/arith/analyzer.h>  // For the arith::Analyzer::Simplify() method simplifying terms
+#include <tvm/ir/transform.h>    // For the class Pass and the class PassContext
 #include <tvm/runtime/container/string.h>
 #include <tvm/tir/analysis.h>  // For the ExprDeepEqual analysis
 #include <tvm/tir/expr.h>
@@ -720,14 +721,42 @@ bool EqualTerms(const PrimExpr& a, const PrimExpr& b) {
   return deep_equal_(a, b);
 }
 
+/*!
+ * \brief Normalization function of a term, use to decide the equivalence relation of interest
+ * \param expr The expression to normalize
+ * \param do_normalization Whether we want the function to actually do normalization
+ * \note This function can be customized
+ */
+PrimExpr NormalizeTerm(const PrimExpr& expr, bool do_normalization) {
+  if (do_normalization) {
+    // Customize here!
+    // We could decide to normalize terms in a way that identifies them modulo commutativity
+    // (like x+y and y+x), or modulo associativity (like (x+y)+z and x+(y+z)), etc.
+    // For that, a normalization procedure (or an incomplete "pseudo-normalization" like
+    // arith::Analyzer::Simplify) will be used.
+
+    // One possible customization:
+    // Here is just an attempt to do more commonings by using the pseudo-normalization function
+    // offered by arith::Analyzer::Simplify(). "pseudo" because while it is correct (i.e.
+    // the simplification is indeed equivalent to the original term), it is incomplete (i.e.
+    // the returned term is not guaranteed to be a normal form).
+    arith::Analyzer analyzer;
+    return analyzer.Simplify(expr);
+  } else {
+    // If `do_normalization` is false, the equivalence relation just checks the syntactic equality,
+    // so the normalization is just the identity function.
+    return expr;
+  }
+}
+
 /*!
  * \brief Decides if two terms are equivalent semantically
  */
-bool EquivalentTerms(const PrimExpr& a, const PrimExpr& b) {
-  // For now, we just check the syntactic equality, but that could later become a semantic test,
-  // for instance identifying computations modulo commutativity (like x+y and y+x), or modulo
-  // associativity (like (x+y)+z and x+(y+z)), etc.
-  return EqualTerms(a, b);
+bool EquivalentTerms(const PrimExpr& a, const PrimExpr& b, bool identify_equiv_terms) {
+  // We restrict the equivalence to be decidable by a normalization procedure that is used to
+  // normalize both sides, and to then compare the normal forms with the strict syntactical
+  // equality
+  return EqualTerms(NormalizeTerm(a, identify_equiv_terms), NormalizeTerm(b, identify_equiv_terms));
 }
 
 /*!
@@ -739,21 +768,52 @@ bool EquivalentTerms(const PrimExpr& a, const PrimExpr& b) {
    \note This function is needed because the advantage of the hashtable was the constant lookup.
           But in order to have this constant lookup, we could not collapse semantically equivalent
           computations.
+          Attention, the pairs returned are deterministic and will always be the same (as the same
+          canonical representant will always be chosen for a given class of equivalence), but the
+          order in which these pairs appear in the result is not deterministic, as it is based on
+          the order in which we found items in the "normalized hashtable" `norm_table`). The caller
+          is expected to sort the result anyway.
  */
 std::vector<std::pair<PrimExpr, size_t>> SyntacticToSemanticComputations(
-    const ComputationTable& table) {
+    const ComputationTable& table, bool identify_equiv_terms) {
   std::vector<std::pair<PrimExpr, size_t>> result;
 
-  // table.size() is an upper-bound of the number of elements in the resulting vector,
-  // as we might merge semantically equivalent computations.
-  // We do this reservation even if it might reserve slightly more space than is needed in the end
-  result.reserve(table.size());
+  // If we do NOT identify equivalent terms, then we simply need to transform the input hashtable
+  // into a vector, without doing anything else.
+  if (!identify_equiv_terms) {
+    // The result will contain exactly as many elements as the input `table` has
+    result.reserve(table.size());
+    for (const auto& elem : table) {
+      result.push_back(elem);
+    }
 
-  // Traverse through map in a sorted order on keys to maintain deterministic behavior
-  // We do this by comparing the string repr of each PrimExpr to get a determinstic ordering
-  std::vector<std::pair<PrimExpr, size_t>> sorted_map_items(table.begin(), table.end());
+    return result;
+  }
 
-  sort(sorted_map_items.begin(), sorted_map_items.end(),
+  // Otherwise, in order to identify equivalent terms, we will go through a table `norm_table`
+  // where normal forms are the keys., and use it to efficiently merge equivalent terms.
+
+  // In order to produce the result (a vector of semantical entities), the input table will be
+  // normalized. This normalized table will keep the count for each set of equivalent terms
+  // (i.e. each equivalence class), together with a term that did appear in this equivalence class
+  // (in practice, the first term of the equivalence class that was encoutered).
+  std::unordered_map<PrimExpr, std::pair<PrimExpr, size_t>, StructuralHash, ExprDeepEqual>
+      norm_table;
+
+  // In order to avoid frequent rehashing if the norm_table becomes big, we immediately ask for
+  // enough space to store the amount of elements that the input table has, as it's clearly an
+  // upper bound (in the worst case, each element is its own representant, and there is as many
+  // equivalence classes as there are elements)
+  norm_table.reserve(table.size());
+
+  // Transform the input hashtable to a vector and sort it according to some order, as we will be
+  // iterating through its items soon, and the order of appearance will be used to determine the
+  // individual representant for each class of equivalence, which we want to be deterministic
+  // (otherwise {x+y, y+x} could be both replaced by x+y, and on another run by y+x).
+  std::vector<std::pair<PrimExpr, size_t>> sorted_items_of_table(table.begin(), table.end());
+
+  // We do the ordering by comparing the string repr of each expr to get a determinstic ordering
+  sort(sorted_items_of_table.begin(), sorted_items_of_table.end(),
        [](std::pair<PrimExpr, size_t> a, std::pair<PrimExpr, size_t> b) {
          std::stringstream a_stream;
          std::stringstream b_stream;
@@ -762,21 +822,40 @@ std::vector<std::pair<PrimExpr, size_t>> SyntacticToSemanticComputations(
          return a_stream.str().compare(b_stream.str()) < 0;
        });
 
-  // For each element in the hashtable
-  for (auto elem : sorted_map_items) {
-    // We try to see if a semantically equivalent term is already in the resulting vector
-    auto it_found = std::find_if(result.begin(), result.end(),
-                                 [elem](std::pair<PrimExpr, size_t> already_seen) {
-                                   return EquivalentTerms(already_seen.first, elem.first);
-                                 });
-    // And if so, we increase (by `elem.second`) its count
-    if (it_found != result.end()) {
-      it_found->second += elem.second;
+  for (const auto& elem : sorted_items_of_table) {
+    PrimExpr norm_elem = NormalizeTerm(elem.first, identify_equiv_terms);
+    // If the normalized term is not already a key in the normalized table
+    auto it_found = norm_table.find(norm_elem);
+    if (it_found == norm_table.end()) {
+      // Then we add the mapping `norm_elem` -> (`elem`.first, `elem`.second) to the norm table
+      // (i.e. `norm_elem` has been seen `elem`.second many times so far, and the chosen element
+      // to represent the equivalence class will be `elem`.first as it's the first element of the
+      // class that we see)
+      norm_table[norm_elem] = elem;
     } else {
-      // If we could not find a semantically equivalent term in the resulting vector, we add it
-      result.push_back(elem);
+      // Otherwise, it's not the first time we see a term in this equivalence class, so we just
+      // increase the count of this equivalence class as we now have `elem`.second additional items
+      // coming to the equivalence class.
+      it_found->second.second += elem.second;
     }
   }
+
+  // norm_table.size() is the number of equivalence classes that we have built, so it's exactly the
+  // number of items that we will return in the vector of semantical entities
+  result.reserve(norm_table.size());
+
+  // Transform the intermediate hashtable `norm_table` into a vector, forgetting the keys
+  // (which are the normal forms), as they won't be used as the canonical representatives (which are
+  // instead the first element of each class that is effectively seen).
+  // Careful: the pairs will never change (the canonical representatives chosen will always be the
+  // same), but the order in which the pairs are produced can vary as we are iterating through the
+  // hashtable `norm_table`. This is not an issue as the caller sorts the result anyway.
+  std::unordered_map<PrimExpr, std::pair<PrimExpr, size_t>, StructuralHash,
+                     ExprDeepEqual>::const_iterator it_norm_table;
+  for (it_norm_table = norm_table.begin(); it_norm_table != norm_table.end(); ++it_norm_table) {
+    result.push_back(it_norm_table->second);
+  }
+
   return result;
 }
 
@@ -822,17 +901,19 @@ void InsertElemToSortedSemanticComputations(std::vector<std::pair<PrimExpr, size
           decreasing size of the expression) and maintain the vector sorted while doing so.
  */
 void InsertVectorToSortedSemanticComputations(std::vector<std::pair<PrimExpr, size_t>>* sorted_vec,
-                                              const std::vector<PrimExpr>& vec_to_add) {
+                                              const std::vector<PrimExpr>& vec_to_add,
+                                              bool identify_equiv_terms) {
   if (sorted_vec == nullptr) {
     return;
   }
   for (auto elem_to_add : vec_to_add) {
     // See if the current element to add (or an equivalent one) is already present
     // in the sorted vector
-    auto it_found = std::find_if(sorted_vec->begin(), sorted_vec->end(),
-                                 [elem_to_add](std::pair<PrimExpr, size_t> elem) {
-                                   return EquivalentTerms(elem.first, elem_to_add);
-                                 });
+    auto it_found =
+        std::find_if(sorted_vec->begin(), sorted_vec->end(),
+                     [elem_to_add, identify_equiv_terms](std::pair<PrimExpr, size_t> elem) {
+                       return EquivalentTerms(elem.first, elem_to_add, identify_equiv_terms);
+                     });
 
     // If we found `elem_to_add` (or an equivalent expression) already in sorted_vec
     if (it_found != sorted_vec->end()) {
diff --git a/src/tir/transforms/common_subexpr_elim_tools.h b/src/tir/transforms/common_subexpr_elim_tools.h
index a590cde69f..fcd29fddc0 100644
--- a/src/tir/transforms/common_subexpr_elim_tools.h
+++ b/src/tir/transforms/common_subexpr_elim_tools.h
@@ -180,9 +180,12 @@ void PrintComputationTable(const ComputationTable& table);
 using MaybeValue = dmlc::optional<PrimExpr>;
 
 bool EqualTerms(const PrimExpr& a, const PrimExpr& b);
-bool EquivalentTerms(const PrimExpr& a, const PrimExpr& b);
+// Used for deciding the (decidable) equivalence relation
+PrimExpr NormalizeTerm(const PrimExpr& expr, bool do_normalization);
+// The equivalence relation, which is the syntactical equality when `identify_equiv_terms` is false
+bool EquivalentTerms(const PrimExpr& a, const PrimExpr& b, bool identify_equiv_terms);
 std::vector<std::pair<PrimExpr, size_t>> SyntacticToSemanticComputations(
-    const ComputationTable& table);
+    const ComputationTable& table, bool identify_equiv_terms);
 bool PredicateIntroVarForComputation(const PrimExpr& computation, size_t nb_times_seen);
 
 // Polymorphic (functional) map on a vector, which builds a news vector with the same number of
@@ -209,7 +212,8 @@ template std::vector<Var> VectorMap(const std::vector<std::pair<Var, MaybeValue>
 void InsertElemToSortedSemanticComputations(std::vector<std::pair<PrimExpr, size_t>>* sorted_vec,
                                             const std::pair<PrimExpr, size_t>& pair);
 void InsertVectorToSortedSemanticComputations(std::vector<std::pair<PrimExpr, size_t>>* sorted_vec,
-                                              const std::vector<PrimExpr>& vec_to_add);
+                                              const std::vector<PrimExpr>& vec_to_add,
+                                              bool identify_equiv_terms);
 
 }  // namespace tir
 }  // namespace tvm
diff --git a/tests/python/unittest/test_tir_transform_common_subexpr_elim.py b/tests/python/unittest/test_tir_transform_common_subexpr_elim.py
index c12e27a46e..a546c16a64 100644
--- a/tests/python/unittest/test_tir_transform_common_subexpr_elim.py
+++ b/tests/python/unittest/test_tir_transform_common_subexpr_elim.py
@@ -17,12 +17,16 @@
 import hashlib
 
 import tvm
-from tvm import te
+from tvm import auto_scheduler, te, topi
 from tvm.ir.base import save_json
 from tvm.ir.module import IRModule
+from tvm.script import tir as T
 
-
-# A test program which gives the opportunity for the CSE pass to introduce two new variables, at two different levels
+# -----------------------------------------------------
+# Basic test for the expected behavior of the CSE pass
+# -----------------------------------------------------
+# A test program which gives the opportunity for the CSE pass to introduce two new variables,
+# at two different levels
 def test_cse():
     z1 = te.var("z1")
     z2 = te.var("z2")
@@ -70,9 +74,9 @@ def test_cse():
             ),
         ),
     )
-    # This test program gives the opportunity to introduce two new variables, at two different levels
-    # and to perform replacements in the value of "a" and "b", using these new variables
-    # We will check all of that underneath and more, making also sure that nothing else has been changed
+    # This test program gives the opportunity to introduce two new variables, at two different
+    # levels and to perform replacements in the value of "a" and "b", using these new variables.
+    # We will check all of that below and more, also making sure that nothing else has changed
 
     mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([i1, i2, z3], body))
     body = tvm.tir.transform.CommonSubexprElimTIR()(mod)
@@ -138,52 +142,14 @@ def test_cse():
     assert isinstance(body.body, tvm.tir.BufferStore)
 
 
-def test_deterministic_cse():
-    import random
-
-    """Test deterministic allocation of CSE vars
-
-    We expect something like
-
-        result = (x + 1) + (x + 2) + (x + 3) + (x + 1) + (x + 2) + (x + 3)
-            -->
-        cse_var_3 = (x + 1)
-        cse_var_2 = (x + 2)
-        cse_var_1 = (x + 3)
-        result = cse_var_3 + cse_var_2 + cse_var_1 + cse_var_3 + cse_var_2 + 
cse_var_1
-    """
-    NUM_TERMS = 10
-    REPEATS = 10
-
-    x = te.var("x")
-    result = te.var("result")
-
-    offsets = sorted([i + 1 for i in range(NUM_TERMS)])
-    inc1 = [(x + offsets[i]) for i in range(NUM_TERMS)]
-    inc2 = [(x + offsets[i]) for i in range(NUM_TERMS)]
-
-    expression = x
-    for add in inc1 + inc2:
-        expression = expression + add
-    let_stmt = tvm.tir.LetStmt(result, expression, tvm.tir.Evaluate(result))
-    mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([x], let_stmt))
-
-    initial_hash = None
-    for _ in range(REPEATS):
-        body = tvm.tir.transform.CommonSubexprElimTIR()(mod)["main"]
-
-        # Hash and ensure serialize json is the same every time
-        json_val = save_json(body)
-        json_hash = hashlib.sha256(json_val.encode()).hexdigest()
-
-        if initial_hash is None:
-            initial_hash = json_hash
-        assert json_hash == initial_hash
-
-
-# First specific test for if nodes : Some duplicated computations appear only in one branch (here the Then branch), not in both branches.
-# In this case, the CSE pass should introduce the redundant computation at the top if the Then branch, not before the whole If
-# (otherwise that would lead to some computations being computed for nothing when it is the Else branch that is executed).
+# -----------------------------------------------------
+# Tests related to If nodes
+# -----------------------------------------------------
+# First specific test for If nodes: some duplicated computations appear only in one branch (here
+# the Then branch), not in both branches.
+# In this case, the CSE pass should introduce the redundant computation at the top of the Then
+# branch, not before the whole If (otherwise that would lead to some computations being computed
+# for nothing when it is the Else branch that is executed).
 def test_cse_ifNode_1():
     b = te.var("b")
     i1 = te.var("i1")
@@ -237,9 +203,9 @@ def test_cse_ifNode_1():
     assert tvm.ir.structural_equal(body.value, y + z)
 
 
-# Second test for if nodes : Some duplicated computations appear in both the Then and the Else branch.
-# In this case, the CSE pass should introduce the redundant computation before the whole If node, because
-# regardless of the execution path, it is going to be computed.
+# Second test for If nodes: some duplicated computations appear in both the Then and Else branches.
+# In this case, the CSE pass should introduce the redundant computation before the whole If node,
+# because regardless of the execution path, it is going to be computed.
 def test_cse_ifNode_2():
     b = te.var("b")
     i1 = te.var("i1")
@@ -265,7 +231,7 @@ def test_cse_ifNode_2():
             b,
             tvm.tir.SeqStmt(
                 [
-                    tvm.tir.BufferStore(buffer, y + z, [i1]),  # (y+z) is present in the Then branch
+                    tvm.tir.BufferStore(buffer, y + z, [i1]),  # (y+z) is present in Then branch
                     tvm.tir.BufferStore(buffer, y, [i2]),
                 ]
             ),
@@ -288,9 +254,11 @@ def test_cse_ifNode_2():
     assert tvm.ir.structural_equal(body.value, y + z)
 
 
+# -------------------------------------------------------------------------------------------------
 # Test commoning in cascade : after having introduced a big exp ((x+y)+z) into a new variable,
 # it will become possible to do another commoning for (x+y) which appears both in the new variable
 # and in the rest of the program.
+# -------------------------------------------------------------------------------------------------
 def test_cse_cascade():
     i1 = te.var("i1")
     i2 = te.var("i2")
@@ -353,8 +321,188 @@ def test_cse_cascade():
     assert tvm.ir.structural_equal(store3.value, cse_var_2)
 
 
+# -----------------------------------------------------------------------------------------
+# A test which ensures that we don't perform normalizations outside of introduced variables
+# -----------------------------------------------------------------------------------------
+def test_no_normalization_without_commoning():
+    x = te.var("x")
+    y = te.var("y")
+    z = te.var("z")
+    a = te.var("a")
+    # Test prog :
+    # let a = x + (y + z) in a
+    body = tvm.tir.LetStmt(a, x + (y + z), tvm.tir.Evaluate(a))
+
+    mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([x, y, z], body))
+    body = tvm.tir.transform.CommonSubexprElimTIR(identify_equiv_terms=True)(mod)
+
+    tvm.transform.PrintIR()(body)
+
+    body = body["main"].body  # Gets the body of the main, i.e. the full statement
+
+    assert body.var.name == "a"
+    assert tvm.ir.structural_equal(body.value, x + (y + z))
+
+
+# -------------------------------------------------
+# Part for testing the commoning with equivalences
+# -------------------------------------------------
+@T.prim_func
+def func_distributivity(i1: T.int32, i2: T.int32, x: T.int32, y: T.int32, z: T.int32) -> None:
+    B = T.buffer_decl((50,), "int32")
+    B[i1] = x * (y + z)
+    B[i2] = x * y + x * z
+
+
+@T.prim_func
+def func_distributivity_expected(
+    i1: T.int32, i2: T.int32, x: T.int32, y: T.int32, z: T.int32
+) -> None:
+    B = T.buffer_decl((50,), "int32")
+    cse_var_1 = T.var("int32")
+    with T.let(cse_var_1, x * y + x * z):
+        B[i1] = cse_var_1
+        B[i2] = cse_var_1
+
+
+@T.prim_func
+def func_associativity(i1: T.int32, i2: T.int32, x: T.int32, y: T.int32, z: T.int32) -> None:
+    B = T.buffer_decl((50,), "int32")
+    B[i1] = (x + y) + z
+    B[i2] = x + (y + z)
+
+
+@T.prim_func
+def func_associativity_expected(
+    i1: T.int32, i2: T.int32, x: T.int32, y: T.int32, z: T.int32
+) -> None:
+    B = T.buffer_decl((50,), "int32")
+    cse_var_1 = T.var("int32")
+    with T.let(cse_var_1, (x + y) + z):
+        B[i1] = cse_var_1
+        B[i2] = cse_var_1
+
+
+def _check(original, transformed):
+    func = original
+    mod = tvm.IRModule.from_expr(func)
+    body = tvm.tir.transform.CommonSubexprElimTIR(identify_equiv_terms=True)(mod)
+    tvm.transform.PrintIR()(body)
+    tvm.ir.assert_structural_equal(body["main"], transformed)
+
+
+def test_semantic_equiv_distributivity():
+    _check(func_distributivity, func_distributivity_expected)
+
+
+def test_semantic_equiv_associativity():
+    _check(func_associativity, func_associativity_expected)
+
+
+# -----------------------------------------------------
+# Tests that verify the determinism of the pass
+# -----------------------------------------------------
+def test_deterministic_cse():
+    import random
+
+    """Test deterministic allocation of CSE vars
+
+    We expect something like
+
+        result = (x + 1) + (x + 2) + (x + 3) + (x + 1) + (x + 2) + (x + 3)
+            -->
+        cse_var_3 = (x + 1)
+        cse_var_2 = (x + 2)
+        cse_var_1 = (x + 3)
+        result = cse_var_3 + cse_var_2 + cse_var_1 + cse_var_3 + cse_var_2 + cse_var_1
+    """
+    NUM_TERMS = 10
+    REPEATS = 10
+
+    x = te.var("x")
+    result = te.var("result")
+
+    offsets = sorted([i + 1 for i in range(NUM_TERMS)])
+    inc1 = [(x + offsets[i]) for i in range(NUM_TERMS)]
+    inc2 = [(x + offsets[i]) for i in range(NUM_TERMS)]
+
+    expression = x
+    for add in inc1 + inc2:
+        expression = expression + add
+    let_stmt = tvm.tir.LetStmt(result, expression, tvm.tir.Evaluate(result))
+    mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([x], let_stmt))
+
+    initial_hash = None
+    for _ in range(REPEATS):
+        body = tvm.tir.transform.CommonSubexprElimTIR()(mod)
+
+        body = body["main"]
+
+        # Hash and ensure the serialized JSON is the same every time
+        json_val = save_json(body)
+        json_hash = hashlib.sha256(json_val.encode()).hexdigest()
+
+        if initial_hash is None:
+            initial_hash = json_hash
+        assert json_hash == initial_hash
+
+
+# Needed for the second test on determinism
+LOG_LINE = '{"i": [["[\\"conv2d_layer\\", 1, 7, 7, 512, 512, 3, 3, [1, 1], [1, 1]]", \
+            "llvm -keys=cpu -link-params=0 -mcpu=broadwell -num-cores=2", \
+            [8, 64, 64, 0, 0, 0, 0, 0], "", 1, []], [[], [["CI", 5], \
+            ["SP", 3, 0, 1, [1, 1, 1], 1], ["SP", 3, 4, 512, [1, 32, 16], 1], \
+            ["SP", 3, 8, 7, [7, 1, 1], 1], ["SP", 3, 12, 7, [1, 1, 1], 1], \
+            ["SP", 3, 16, 512, [1], 1], ["SP", 3, 18, 3, [1], 1], ["SP", 3, 20, 3, [3], 1], \
+            ["RE", 3, [0, 4, 8, 12, 1, 5, 9, 13, 16, 18, 20, 2, 6, 10, 14, 17, 19, 21, 3, 7, \
+            11, 15]], ["FSP", 6, 0, 1, 2], ["FSP", 6, 3, 2, 2], ["FSP", 6, 6, 3, 2], \
+            ["FSP", 6, 9, 4, 2], ["RE", 6, [0, 3, 6, 9, 1, 4, 7, 10, 2, 5, 8, 11]], \
+            ["CA", 3, 6, 7], ["CA", 1, 6, 5], ["FU", 6, [0, 1, 2, 3, 4, 5]], ["AN", 6, 0, 3], \
+            ["PR", 3, 0, "auto_unroll_max_step$512"], ["AN", 1, 3, 2], ["AN", 3, 21, 2], \
+            ["AN", 6, 6, 2]]]], "r": [[0.0331129], 0, 0.900362, 1647464342], "v": "v0.6"}\n'
+
+# The workload associated with the log
+@auto_scheduler.register_workload
+def conv2d_layer(N, H, W, CO, CI, KH, KW, stride, padding):
+    data = te.placeholder((N, CI, H, W), name="data")
+    kernel = te.placeholder((CO, CI, KH, KW), name="kernel")
+    bias = te.placeholder((1, CO, 1, 1), name="bias")
+    conv = topi.nn.conv2d_nchw(data, kernel, stride, padding, dilation=1, out_dtype="float32")
+    out = topi.nn.relu(conv + bias)
+    return [data, kernel, bias, out]
+
+
+def test_deterministic_cse_2():
+    inp, inr = auto_scheduler.measure_record.load_record_from_string(LOG_LINE)
+    inp = auto_scheduler.measure.recover_measure_input(inp, rebuild_state=True)
+
+    initial_hash = None
+
+    for _ in range(10):
+        sch, args = inp.task.compute_dag.apply_steps_from_state(inp.state)
+        ir_module = tvm.lower(sch, args)
+        primfunc = ir_module["main"]
+        json_str = save_json(primfunc)
+        new_hash = hashlib.sha256(json_str.encode("utf-8")).hexdigest()
+        # Make sure that all the hashes are going to be the same
+        if initial_hash is None:
+            initial_hash = new_hash
+        assert new_hash == initial_hash
+
+
 if __name__ == "__main__":
+    # Basic test:
     test_cse()
+    # Tests related to If nodes:
     test_cse_ifNode_1()
     test_cse_ifNode_2()
+    # Test performing a commoning on a commoning:
     test_cse_cascade()
+    # Test that verifies that the input program itself is not being normalized by the pass:
+    test_no_normalization_without_commoning()
+    # Tests that turn on the equivalence of terms and verify the commoning with equivalences:
+    test_semantic_equiv_distributivity()
+    test_semantic_equiv_associativity()
+    # Tests that verify the determinism of the pass:
+    test_deterministic_cse()
+    test_deterministic_cse_2()
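The central change in `SyntacticToSemanticComputations()` replaces the pairwise `EquivalentTerms()` scan with a single pass over a hashtable keyed by normal forms. The Python sketch below mirrors that logic on a toy expression encoding made of nested tuples such as `("+", "x", "y")`. The names `normalize`, `equivalent_terms` and `syntactic_to_semantic` are hypothetical stand-ins for `NormalizeTerm()`, `EquivalentTerms()` and `SyntacticToSemanticComputations()`. The toy normalizer is an assumption made for illustration and is not TVM's `NormalizeTerm()`. It only flattens nested additions and sorts their operands, so it would not identify the distributivity example `x * (y + z)` / `x * y + x * z` from the tests.

```python
from typing import Any, Dict, List, Tuple

# Toy expression encoding (an assumption for illustration): a variable is a str,
# a constant is an int, and an operation is a tuple (op, operand1, operand2, ...).
Expr = Any


def normalize(expr: Expr, do_normalization: bool = True) -> Expr:
    """Hypothetical stand-in for NormalizeTerm(): flatten nested '+' and sort its operands.

    With do_normalization=False the term is returned unchanged, so equivalence
    degenerates to syntactic equality.
    """
    if not do_normalization or not isinstance(expr, tuple):
        return expr
    op, *args = expr
    normalized_args = [normalize(arg) for arg in args]
    if op != "+":
        return (op, *normalized_args)
    flat = []
    for arg in normalized_args:
        # Associativity: splice the operands of an inner '+' into the outer one.
        if isinstance(arg, tuple) and arg[0] == "+":
            flat.extend(arg[1:])
        else:
            flat.append(arg)
    # Commutativity: put operands in a canonical order (repr keeps mixed types comparable).
    return ("+", *sorted(flat, key=repr))


def equivalent_terms(a: Expr, b: Expr, identify_equiv_terms: bool) -> bool:
    """Two terms are equivalent iff their normal forms are equal."""
    return normalize(a, identify_equiv_terms) == normalize(b, identify_equiv_terms)


def syntactic_to_semantic(
    table: Dict[Expr, int], identify_equiv_terms: bool
) -> List[Tuple[Expr, int]]:
    """Group terms by normal form in one pass instead of comparing every pair."""
    # Deterministic processing order: sort by the string form of each term.
    items = sorted(table.items(), key=lambda item: repr(item[0]))
    # normal form -> [representative, count]; the first term seen represents its class.
    norm_table: Dict[Expr, list] = {}
    for expr, count in items:
        key = normalize(expr, identify_equiv_terms)
        if key not in norm_table:
            norm_table[key] = [expr, count]
        else:
            norm_table[key][1] += count
    # Drop the keys (normal forms); only representatives and counts are returned.
    return [(rep, count) for rep, count in norm_table.values()]
```

Apart from the initial sort and the cost of normalizing each term, this does one hash lookup per table entry rather than comparing each pair. Because Python dictionaries preserve insertion order, the sketch returns classes in a fixed order. The C++ version iterates an `std::unordered_map`, which is why its comment relies on the caller to sort the result.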

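The determinism tests above serialize the lowered function and compare SHA-256 digests across repeated runs. Here is a minimal Python sketch of the same idea. It assumes the only source of nondeterminism is the iteration order of a hash table, which it simulates by shuffling with different seeds. Sorting by the string form of each term, as the patch does with the string repr of each `PrimExpr`, makes the serialized output and its digest identical on every run. The helper names are hypothetical, and `json.dumps` stands in for `save_json`.

```python
import hashlib
import json
import random
from typing import List, Set, Tuple

Item = Tuple[str, int]  # (string form of a term, number of occurrences)


def deterministic_order(items: List[Item]) -> List[Item]:
    # Sort by the string form of the term. Assumes the strings are distinct:
    # ties would keep the (possibly varying) input order, since sorting is stable.
    return sorted(items, key=lambda item: item[0])


def digest(items: List[Item]) -> str:
    # Serialize, then hash, in the spirit of save_json + hashlib.sha256 in the tests.
    payload = json.dumps(items)
    return hashlib.sha256(payload.encode("utf-8")).hexdigest()


def digests_over_runs(items: List[Item], repeats: int = 10) -> Set[str]:
    digests = set()
    for seed in range(repeats):
        shuffled = list(items)
        # Simulate a hash table whose iteration order changes from run to run.
        random.Random(seed).shuffle(shuffled)
        digests.add(digest(deterministic_order(shuffled)))
    return digests
```

For example, `digests_over_runs([(f"x + {i}", 2) for i in range(1, 11)])` yields a set holding a single digest.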