mbrookhart commented on a change in pull request #5930:
URL: https://github.com/apache/incubator-tvm/pull/5930#discussion_r445915192
##########
File path: tests/python/relay/test_dataflow_pattern.py
##########
@@ -1133,6 +1133,37 @@ def test_partition_double_batchnorm():
reference = f2(gamma, f1(gamma, x, mean, var, beta), mean, var, beta)
assert tvm.ir.structural_equal(partitioned, reference)
+def test_overlappting_partitions():
+ x = wildcard()
+ gamma = wildcard()
+ beta = wildcard()
+ moving_mean = wildcard()
+ moving_var = wildcard()
+ bn_node = is_op('nn.batch_norm')(x, gamma, beta, moving_mean, moving_var)
+ tuple_get_item_node = TupleGetItemPattern(bn_node, 0)
+
+ x = relay.var('x')
+ var = relay.var('var')
+ mean = relay.var('mean')
+ beta = relay.var('beta')
+ gamma = relay.var('gamma')
+ BN = relay.op.nn.batch_norm(x, gamma, beta, mean, var, epsilon=1e-5)
+ T1 = BN[0]
+ T2 = BN[0]
+ add = T1 + T2
+
+ assert tuple_get_item_node.partition(add) == add
Review comment:
Partitioning either path is invalid, since each side would need
intermediate nodes from the other, so we expect the original expression to come
back unchanged — hence the `==` (same-object) check rather than `structural_equal`.
We treat it as a match, but not as something we can independently rewrite.
##########
File path: tests/python/relay/test_dataflow_pattern.py
##########
@@ -1133,6 +1133,37 @@ def test_partition_double_batchnorm():
reference = f2(gamma, f1(gamma, x, mean, var, beta), mean, var, beta)
assert tvm.ir.structural_equal(partitioned, reference)
+def test_overlappting_partitions():
+ x = wildcard()
+ gamma = wildcard()
+ beta = wildcard()
+ moving_mean = wildcard()
+ moving_var = wildcard()
+ bn_node = is_op('nn.batch_norm')(x, gamma, beta, moving_mean, moving_var)
+ tuple_get_item_node = TupleGetItemPattern(bn_node, 0)
+
+ x = relay.var('x')
+ var = relay.var('var')
+ mean = relay.var('mean')
+ beta = relay.var('beta')
+ gamma = relay.var('gamma')
+ BN = relay.op.nn.batch_norm(x, gamma, beta, mean, var, epsilon=1e-5)
+ T1 = BN[0]
+ T2 = BN[0]
+ add = T1 + T2
+
+ assert tuple_get_item_node.partition(add) == add
+
+def test_partition_overused():
+ pattern = is_op("nn.relu")(is_op("nn.conv2d")(wildcard(), wildcard()))
+
+ x = relay.var('input')
+ w = relay.var('weight')
+ conv2d = relay.op.nn.conv2d(x, w)
+ relu = relay.op.nn.relu(conv2d)
+ out = relu + conv2d
+
+ assert pattern.partition(out) == out
Review comment:
Again, fusing the conv2d and relu would make the rest of the expression invalid,
because the `add` still consumes the intermediate conv2d output, so we expect the
expression to come back unchanged.
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
users@infra.apache.org