masahi commented on a change in pull request #7074:
URL: https://github.com/apache/tvm/pull/7074#discussion_r539766653
##########
File path: tests/python/frontend/pytorch/qnn_test.py
##########
@@ -32,17 +32,58 @@
from tvm.relay.frontend.pytorch_utils import is_version_greater_than
from tvm.contrib.download import download_testdata
+from tvm.relay.dataflow_pattern import wildcard, is_op
+from tvm.relay.op.contrib.register import register_pattern_table
+from tvm.relay.op.contrib.register import get_pattern_table
+
def torch_version_check():
from packaging import version
return version.parse(torch.__version__) > version.parse("1.4.0")
+def make_qnn_add_pattern():
+ lhs = wildcard()
+ rhs = wildcard()
+ lhs_scale = wildcard()
+ lhs_zero_point = wildcard()
+ rhs_scale = wildcard()
+ rhs_zero_point = wildcard()
+ output_scale = wildcard()
+ output_zero_point = wildcard()
+ qadd = is_op("qnn.add")(
+ lhs,
+ rhs,
+ lhs_scale,
+ lhs_zero_point,
+ rhs_scale,
+ rhs_zero_point,
+ output_scale,
+ output_zero_point,
+ )
+ return qadd.optional(is_op("clip"))
+
+
+@register_pattern_table("test_table")
+def pattern_table():
+ return [
+ ("qnn_add", make_qnn_add_pattern()),
+ ]
+
+
def get_tvm_runtime(script_module, input_name, ishape):
input_shapes = [(input_name, ishape)]
mod, params = relay.frontend.from_pytorch(script_module, input_shapes)
+ pattern_table = get_pattern_table("test_table")
Review comment:
This is added to make sure the `MergeComposite` pass at L83 doesn't raise an error;
it comes straight from the patch in https://github.com/apache/tvm/issues/7067.
I'll ask the issue reporter to clean this up and write a standalone test for
this issue.
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
users@infra.apache.org