This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new b01e3fc  [TIR] CSE-TIR Pass - More deterministic behavior (#10663)
b01e3fc is described below

commit b01e3fc4d21bba898a5ea17d526013c52ea720eb
Author: AndrewZhaoLuo <[email protected]>
AuthorDate: Thu Mar 17 19:17:01 2022 -0700

    [TIR] CSE-TIR Pass - More deterministic behavior (#10663)
    
    * iterate through sorted keys
    
    * masa comments -- simplify iteration
    
    * test
    
    * tests
    
    * simplify vector construction
    
    * jostle ci
---
 src/tir/transforms/common_subexpr_elim_tools.cc    | 17 +++++++-
 .../test_tir_transform_common_subexpr_elim.py      | 48 ++++++++++++++++++++++
 2 files changed, 63 insertions(+), 2 deletions(-)

diff --git a/src/tir/transforms/common_subexpr_elim_tools.cc b/src/tir/transforms/common_subexpr_elim_tools.cc
index 218667c..d39d211 100644
--- a/src/tir/transforms/common_subexpr_elim_tools.cc
+++ b/src/tir/transforms/common_subexpr_elim_tools.cc
@@ -743,13 +743,27 @@ bool EquivalentTerms(const PrimExpr& a, const PrimExpr& b) {
 std::vector<std::pair<PrimExpr, size_t>> SyntacticToSemanticComputations(
     const ComputationTable& table) {
   std::vector<std::pair<PrimExpr, size_t>> result;
+
   // table.size() is an upper-bound of the number of elements in the resulting 
vector,
   // as we might merge semantically equivalent computations.
   // We do this reservation even if it might reserve slightly more space than 
is needed in the end
   result.reserve(table.size());
 
+  // Traverse through map in a sorted order on keys to maintain deterministic 
behavior
+  // We do this by comparing the string repr of each PrimExpr to get a 
determinstic ordering
+  std::vector<std::pair<PrimExpr, size_t>> sorted_map_items(table.begin(), 
table.end());
+
+  sort(sorted_map_items.begin(), sorted_map_items.end(),
+       [](std::pair<PrimExpr, size_t> a, std::pair<PrimExpr, size_t> b) {
+         std::stringstream a_stream;
+         std::stringstream b_stream;
+         a_stream << a.first;
+         b_stream << b.first;
+         return a_stream.str().compare(b_stream.str()) < 0;
+       });
+
   // For each element in the hashtable
-  for (auto elem : table) {
+  for (auto elem : sorted_map_items) {
     // We try to see if a semantically equivalent term is already in the 
resulting vector
     auto it_found = std::find_if(result.begin(), result.end(),
                                  [elem](std::pair<PrimExpr, size_t> 
already_seen) {
@@ -763,7 +777,6 @@ std::vector<std::pair<PrimExpr, size_t>> 
SyntacticToSemanticComputations(
       result.push_back(elem);
     }
   }
-
   return result;
 }
 
diff --git a/tests/python/unittest/test_tir_transform_common_subexpr_elim.py b/tests/python/unittest/test_tir_transform_common_subexpr_elim.py
index 17c0cbd..c12e27a 100644
--- a/tests/python/unittest/test_tir_transform_common_subexpr_elim.py
+++ b/tests/python/unittest/test_tir_transform_common_subexpr_elim.py
@@ -14,8 +14,13 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
+import hashlib
+
 import tvm
 from tvm import te
+from tvm.ir.base import save_json
+from tvm.ir.module import IRModule
+
 
 # A test program which gives the opportunity for the CSE pass to introduce two new variables, at two different levels
 def test_cse():
@@ -133,6 +138,49 @@ def test_cse():
     assert isinstance(body.body, tvm.tir.BufferStore)
 
 
+def test_deterministic_cse():
+    """Test deterministic allocation of CSE vars
+
+    We expect something like
+
+        result = (x + 1) + (x + 2) + (x + 3) + (x + 1) + (x + 2) + (x + 3)
+            -->
+        cse_var_3 = (x + 1)
+        cse_var_2 = (x + 2)
+        cse_var_1 = (x + 3)
+        result = cse_var_3 + cse_var_2 + cse_var_1 + cse_var_3 + cse_var_2 + cse_var_1
+    """
+    NUM_TERMS = 10
+    REPEATS = 10
+
+    x = te.var("x")
+    result = te.var("result")
+
+    # Build every (x + offset) term twice so each one is a common subexpression.
+    offsets = [i + 1 for i in range(NUM_TERMS)]
+    inc1 = [(x + offset) for offset in offsets]
+    inc2 = [(x + offset) for offset in offsets]
+
+    expression = x
+    for add in inc1 + inc2:
+        expression = expression + add
+    let_stmt = tvm.tir.LetStmt(result, expression, tvm.tir.Evaluate(result))
+    mod = IRModule.from_expr(tvm.tir.PrimFunc([x], let_stmt))
+
+    # Repeated runs of the pass on the same module must produce identical output.
+    initial_hash = None
+    for _ in range(REPEATS):
+        func = tvm.tir.transform.CommonSubexprElimTIR()(mod)["main"]
+
+        # Hash the serialized JSON and ensure it is the same every time
+        json_val = save_json(func)
+        json_hash = hashlib.sha256(json_val.encode()).hexdigest()
+
+        if initial_hash is None:
+            initial_hash = json_hash
+        assert json_hash == initial_hash
+
+
 # First specific test for if nodes : Some duplicated computations appear only in one branch (here the Then branch), not in both branches.
 # In this case, the CSE pass should introduce the redundant computation at the top if the Then branch, not before the whole If
 # (otherwise that would lead to some computations being computed for nothing when it is the Else branch that is executed).

Reply via email to