This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new b01e3fc [TIR] CSE-TIR Pass - More deterministic behavior (#10663)
b01e3fc is described below
commit b01e3fc4d21bba898a5ea17d526013c52ea720eb
Author: AndrewZhaoLuo <[email protected]>
AuthorDate: Thu Mar 17 19:17:01 2022 -0700
[TIR] CSE-TIR Pass - More deterministic behavior (#10663)
* iterate through sorted keys
* masa comments -- simplify iteration
* test
* tests
* simplify vector construction
* jostle ci
---
src/tir/transforms/common_subexpr_elim_tools.cc | 17 +++++++-
.../test_tir_transform_common_subexpr_elim.py | 48 ++++++++++++++++++++++
2 files changed, 63 insertions(+), 2 deletions(-)
diff --git a/src/tir/transforms/common_subexpr_elim_tools.cc
b/src/tir/transforms/common_subexpr_elim_tools.cc
index 218667c..d39d211 100644
--- a/src/tir/transforms/common_subexpr_elim_tools.cc
+++ b/src/tir/transforms/common_subexpr_elim_tools.cc
@@ -743,13 +743,27 @@ bool EquivalentTerms(const PrimExpr& a, const PrimExpr&
b) {
std::vector<std::pair<PrimExpr, size_t>> SyntacticToSemanticComputations(
const ComputationTable& table) {
std::vector<std::pair<PrimExpr, size_t>> result;
+
// table.size() is an upper-bound of the number of elements in the resulting
vector,
// as we might merge semantically equivalent computations.
// We do this reservation even if it might reserve slightly more space than
is needed in the end
result.reserve(table.size());
+ // Traverse through map in a sorted order on keys to maintain deterministic
behavior
+ // We do this by comparing the string repr of each PrimExpr to get a
determinstic ordering
+ std::vector<std::pair<PrimExpr, size_t>> sorted_map_items(table.begin(),
table.end());
+
+ sort(sorted_map_items.begin(), sorted_map_items.end(),
+ [](std::pair<PrimExpr, size_t> a, std::pair<PrimExpr, size_t> b) {
+ std::stringstream a_stream;
+ std::stringstream b_stream;
+ a_stream << a.first;
+ b_stream << b.first;
+ return a_stream.str().compare(b_stream.str()) < 0;
+ });
+
// For each element in the hashtable
- for (auto elem : table) {
+ for (auto elem : sorted_map_items) {
// We try to see if a semantically equivalent term is already in the
resulting vector
auto it_found = std::find_if(result.begin(), result.end(),
[elem](std::pair<PrimExpr, size_t>
already_seen) {
@@ -763,7 +777,6 @@ std::vector<std::pair<PrimExpr, size_t>>
SyntacticToSemanticComputations(
result.push_back(elem);
}
}
-
return result;
}
diff --git a/tests/python/unittest/test_tir_transform_common_subexpr_elim.py
b/tests/python/unittest/test_tir_transform_common_subexpr_elim.py
index 17c0cbd..c12e27a 100644
--- a/tests/python/unittest/test_tir_transform_common_subexpr_elim.py
+++ b/tests/python/unittest/test_tir_transform_common_subexpr_elim.py
@@ -14,8 +14,13 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+import hashlib
+
import tvm
from tvm import te
+from tvm.ir.base import save_json
+from tvm.ir.module import IRModule
+
# A test program which gives the opportunity for the CSE pass to introduce two
new variables, at two different levels
def test_cse():
@@ -133,6 +138,49 @@ def test_cse():
assert isinstance(body.body, tvm.tir.BufferStore)
+def test_deterministic_cse():
+ import random
+
+ """Test deterministic allocation of CSE vars
+
+ We expect something like
+
+ result = (x + 1) + (x + 2) + (x + 3) + (x + 1) + (x + 2) + (x + 3)
+ -->
+ cse_var_3 = (x + 1)
+ cse_var_2 = (x + 2)
+ cse_var_1 = (x + 3)
+ result = cse_var_3 + cse_var_2 + cse_var_1 + cse_var_3 + cse_var_2 +
cse_var_1
+ """
+ NUM_TERMS = 10
+ REPEATS = 10
+
+ x = te.var("x")
+ result = te.var("result")
+
+ offsets = sorted([i + 1 for i in range(NUM_TERMS)])
+ inc1 = [(x + offsets[i]) for i in range(NUM_TERMS)]
+ inc2 = [(x + offsets[i]) for i in range(NUM_TERMS)]
+
+ expression = x
+ for add in inc1 + inc2:
+ expression = expression + add
+ let_stmt = tvm.tir.LetStmt(result, expression, tvm.tir.Evaluate(result))
+ mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([x], let_stmt))
+
+ initial_hash = None
+ for _ in range(REPEATS):
+ body = tvm.tir.transform.CommonSubexprElimTIR()(mod)["main"]
+
+ # Hash and ensure serialize json is the same every time
+ json_val = save_json(body)
+ json_hash = hashlib.sha256(json_val.encode()).hexdigest()
+
+ if initial_hash is None:
+ initial_hash = json_hash
+ assert json_hash == initial_hash
+
+
# First specific test for if nodes : Some duplicated computations appear only
in one branch (here the Then branch), not in both branches.
# In this case, the CSE pass should introduce the redundant computation at the
top if the Then branch, not before the whole If
# (otherwise that would lead to some computations being computed for nothing
when it is the Else branch that is executed).