This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git
The following commit(s) were added to refs/heads/master by this push:
new fe74b37 Conditions updated to cover better user scenarios (#4951)
fe74b37 is described below
commit fe74b37ab578e6d3c540b0f6ac187a220ccc028a
Author: Tianqi Chen <[email protected]>
AuthorDate: Wed Mar 4 18:35:38 2020 -0600
Conditions updated to cover better user scenarios (#4951)
* Conditions updated to cover better user scenarios
* [1] New test case added
* [2] New test case added
* [3] Proper variable name used
* [4] Review Comments handled
* [5] Review comments handled
* [6] Review comments handled
---
src/relay/ir/alpha_equal.cc | 10 ++---
tests/cpp/relay_pass_alpha_equal.cc | 67 +++++++++++++++++++++++++++++
tests/python/relay/test_pass_alpha_equal.py | 32 ++++++++++++++
3 files changed, 104 insertions(+), 5 deletions(-)
diff --git a/src/relay/ir/alpha_equal.cc b/src/relay/ir/alpha_equal.cc
index 78688d7..c622599 100644
--- a/src/relay/ir/alpha_equal.cc
+++ b/src/relay/ir/alpha_equal.cc
@@ -50,14 +50,14 @@ class AlphaEqualHandler:
* \return The comparison result.
*/
bool Equal(const ObjectRef& lhs, const ObjectRef& rhs) {
- if (lhs.same_as(rhs)) return true;
if (!lhs.defined() || !rhs.defined()) return false;
- if (lhs->IsInstance<TypeNode>()) {
- if (!rhs->IsInstance<TypeNode>()) return false;
+ if (lhs.same_as(rhs)) return true;
+ if (lhs->IsInstance<TypeNode>() || rhs->IsInstance<TypeNode>()) {
+ if (!rhs->IsInstance<TypeNode>() || !lhs->IsInstance<TypeNode>()) return
false;
return TypeEqual(Downcast<Type>(lhs), Downcast<Type>(rhs));
}
- if (lhs->IsInstance<ExprNode>()) {
- if (!rhs->IsInstance<ExprNode>()) return false;
+ if (lhs->IsInstance<ExprNode>() || rhs->IsInstance<ExprNode>()) {
+ if (!rhs->IsInstance<ExprNode>() || !lhs->IsInstance<ExprNode>()) return
false;
return ExprEqual(Downcast<Expr>(lhs), Downcast<Expr>(rhs));
}
if (const auto lhsm = lhs.as<IRModuleNode>()) {
diff --git a/tests/cpp/relay_pass_alpha_equal.cc b/tests/cpp/relay_pass_alpha_equal.cc
new file mode 100644
index 0000000..0207fca
--- /dev/null
+++ b/tests/cpp/relay_pass_alpha_equal.cc
@@ -0,0 +1,67 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include <gtest/gtest.h>
+#include <tvm/te/operation.h>
+#include <tvm/relay/expr.h>
+#include <tvm/relay/type.h>
+#include <tvm/relay/analysis.h>
+#include <tvm/relay/transform.h>
+#include <tvm/runtime/registry.h>
+
+using namespace tvm;
+
+/*!
+ * \brief Test helper that calls a globally registered PackedFunc, looked up
+ *        by name, with two ObjectRef arguments and returns its result as bool.
+ */
+class TestAlphaEquals {
+ public:
+  /*!
+   * \brief Bind the helper to the global function \p func_name.
+   * \param func_name Name of a registered global function.
+   */
+  explicit TestAlphaEquals(const char* func_name)
+      : packed_func_(GetGlobalFunc(func_name)) {}
+
+  /*!
+   * \brief Rebind the helper to another global function.
+   * \param func_name Name of a registered global function.
+   */
+  void UpdatePackedFunc(const char* func_name) {
+    packed_func_ = GetGlobalFunc(func_name);
+  }
+
+  /*!
+   * \brief Invoke the bound function as f(input_1, input_2).
+   * \param input_1 First argument passed to the function.
+   * \param input_2 Second argument passed to the function.
+   * \return The function's return value converted to bool.
+   */
+  bool operator()(const ObjectRef& input_1, const ObjectRef& input_2) const {
+    TVMRetValue rv = packed_func_(input_1, input_2);
+    return static_cast<bool>(rv);
+  }
+
+ private:
+  /*!
+   * \brief Look up a global function by name and return a copy of it.
+   *
+   * Registry::Get returns a pointer owned by the registry, so copying the
+   * PackedFunc out leaves no handle that has to be freed here. An
+   * unregistered name fails loudly instead of leaving a null function
+   * to be dereferenced later.
+   */
+  static runtime::PackedFunc GetGlobalFunc(const char* func_name) {
+    const runtime::PackedFunc* func = runtime::Registry::Get(func_name);
+    CHECK(func != nullptr) << "Global function '" << func_name
+                           << "' is not registered";
+    return *func;
+  }
+
+  /*! \brief The global function invoked by operator(). */
+  runtime::PackedFunc packed_func_;
+};
+
+// A defined TypeVar compared against a default-constructed (empty) TypeVar
+// must be reported as not alpha-equal in BOTH argument orders, both through
+// relay::AlphaEqual and through the "relay._make._alpha_equal" global.
+TEST(Relay, AlphaTestEmptyTypeNodes) {
+  const auto x = TypeVar("x", kTypeData);
+  const auto y = TypeVar();  // default-constructed: the "empty" operand
+  EXPECT_FALSE(relay::AlphaEqual(x, y));
+  EXPECT_FALSE(relay::AlphaEqual(y, x));
+
+  TestAlphaEquals test_equals("relay._make._alpha_equal");
+  EXPECT_FALSE(test_equals(x, y));
+  EXPECT_FALSE(test_equals(y, x));
+}
+
+// Test-binary entry point: let googletest consume its command-line flags,
+// select the "threadsafe" death-test style, then run every registered test
+// and report the aggregate status as the process exit code.
+int main(int argc, char** argv) {
+  ::testing::InitGoogleTest(&argc, argv);
+  ::testing::FLAGS_gtest_death_test_style = "threadsafe";
+  const int status = RUN_ALL_TESTS();
+  return status;
+}
diff --git a/tests/python/relay/test_pass_alpha_equal.py b/tests/python/relay/test_pass_alpha_equal.py
index 7e34f48..ec026be 100644
--- a/tests/python/relay/test_pass_alpha_equal.py
+++ b/tests/python/relay/test_pass_alpha_equal.py
@@ -28,6 +28,15 @@ def alpha_equal(x, y):
"""
return analysis.alpha_equal(x, y) and analysis.structural_hash(x) ==
analysis.structural_hash(y)
+def alpha_equal_commutative(x, y):
+ """
+ Check for commutative property of equality
+ """
+ xy = analysis.alpha_equal(x, y)
+ yx = analysis.alpha_equal(y, x)
+ assert xy == yx
+ return xy
+
def test_tensor_type_alpha_equal():
t1 = relay.TensorType((3, 4), "float32")
t2 = relay.TensorType((3, 4), "float32")
@@ -219,6 +228,26 @@ def test_constant_alpha_equal():
assert not alpha_equal(x, y)
assert alpha_equal(x, relay.const(1))
+def test_type_node_alpha_equal():
+ v1 = relay.TypeVar('v1', 6)
+ v2 = relay.TypeVar('v2', 6)
+ assert not alpha_equal(v1, v2)
+
+ v1 = relay.TypeVar('v1', 0)
+ v2 = relay.TypeVar('v2', 6)
+ assert not alpha_equal(v1, v2)
+
+ assert alpha_equal_commutative(v1, v1)
+
+def test_type_node_incompatible_alpha_equal():
+ v1 = relay.TypeVar('v1', 6)
+ v2 = relay.Var("v2")
+ assert not alpha_equal_commutative(v1, v2)
+
+def test_expr_node_incompatible_alpha_equal():
+ v1 = relay.Var("v1")
+ v2 = relay.PatternVar(relay.Var("v2"))
+ assert not alpha_equal_commutative(v1, v2)
def test_var_alpha_equal():
v1 = relay.Var("v1")
@@ -676,6 +705,9 @@ if __name__ == "__main__":
test_tensor_type_alpha_equal()
test_incomplete_type_alpha_equal()
test_constant_alpha_equal()
+ test_type_node_alpha_equal()
+ test_type_node_incompatible_alpha_equal()
+ test_expr_node_incompatible_alpha_equal()
test_func_type_alpha_equal()
test_tuple_type_alpha_equal()
test_type_relation_alpha_equal()