This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 5535d48 Add test for MergeComposite on a QNN graph (#7080)
5535d48 is described below
commit 5535d48843499aa51f2c8249f2a31565d7ee4d22
Author: Core.Halt <[email protected]>
AuthorDate: Fri Dec 11 16:32:04 2020 +0900
Add test for MergeComposite on a QNN graph (#7080)
---
tests/python/frontend/pytorch/qnn_test.py | 104 ++++++++++++++++++------------
1 file changed, 63 insertions(+), 41 deletions(-)
diff --git a/tests/python/frontend/pytorch/qnn_test.py
b/tests/python/frontend/pytorch/qnn_test.py
index 4b73959..07e52b7 100644
--- a/tests/python/frontend/pytorch/qnn_test.py
+++ b/tests/python/frontend/pytorch/qnn_test.py
@@ -31,10 +31,7 @@ import tvm.testing
from tvm import relay
from tvm.relay.frontend.pytorch_utils import is_version_greater_than
from tvm.contrib.download import download_testdata
-
-from tvm.relay.dataflow_pattern import wildcard, is_op
-from tvm.relay.op.contrib.register import register_pattern_table
-from tvm.relay.op.contrib.register import get_pattern_table
+from tvm.relay.op.contrib.register import register_pattern_table,
get_pattern_table
def torch_version_check():
@@ -43,47 +40,10 @@ def torch_version_check():
return version.parse(torch.__version__) > version.parse("1.4.0")
-def make_qnn_add_pattern():
- lhs = wildcard()
- rhs = wildcard()
- lhs_scale = wildcard()
- lhs_zero_point = wildcard()
- rhs_scale = wildcard()
- rhs_zero_point = wildcard()
- output_scale = wildcard()
- output_zero_point = wildcard()
- qadd = is_op("qnn.add")(
- lhs,
- rhs,
- lhs_scale,
- lhs_zero_point,
- rhs_scale,
- rhs_zero_point,
- output_scale,
- output_zero_point,
- )
- return qadd.optional(is_op("clip"))
-
-
-@register_pattern_table("test_table")
-def pattern_table():
- return [
- ("qnn_add", make_qnn_add_pattern()),
- ]
-
-
def get_tvm_runtime(script_module, input_name, ishape):
input_shapes = [(input_name, ishape)]
mod, params = relay.frontend.from_pytorch(script_module, input_shapes)
- pattern_table = get_pattern_table("test_table")
- with tvm.transform.PassContext(opt_level=3):
- pass_list = [
- tvm.relay.transform.SimplifyInference(),
- tvm.relay.transform.MergeComposite(pattern_table),
- ]
- composite_partition = tvm.transform.Sequential(pass_list)
- partitioned = composite_partition(mod)
with tvm.transform.PassContext(opt_level=3):
# test on only cpu for now, torch cannot run quant models on cuda
@@ -587,3 +547,65 @@ def test_quantize_dynamic():
# Outputs from v1.6 seem reliable. TVM's outputs are always the
same
if is_version_greater_than("1.5.1"):
tvm.testing.assert_allclose(tvm_result, pt_result, rtol=1e-4,
atol=1e-4)
+
+
def make_qnn_add_pattern():
    """Build a dataflow pattern matching ``qnn.add``, optionally wrapped in ``clip``.

    Each of the eight ``qnn.add`` operands is matched by its own independent
    wildcard, so the pattern accepts any operands.

    Returns:
        A DFPattern matching either the bare ``qnn.add`` call or a ``clip``
        applied to it.
    """
    from tvm.relay.dataflow_pattern import wildcard, is_op

    # One wildcard per operand, in qnn.add argument order:
    # lhs, rhs, lhs_scale, lhs_zero_point,
    # rhs_scale, rhs_zero_point, output_scale, output_zero_point.
    operands = [wildcard() for _ in range(8)]
    add_pattern = is_op("qnn.add")(*operands)

    # A fused add+relu lowers to a trailing clip, so allow it but don't require it.
    return add_pattern.optional(is_op("clip"))
+
+
@register_pattern_table("test_table")
def pattern_table():
    """Return the composite patterns registered under the "test_table" name.

    The table has a single ``(name, pattern)`` entry, "qnn_add", built by
    ``make_qnn_add_pattern``.
    """
    qnn_add_entry = ("qnn_add", make_qnn_add_pattern())
    return [qnn_add_entry]
+
+
def run_qnn_mergecomposite(script_module, input_name, ishape):
    """Import a traced quantized model and run MergeComposite over it.

    The module is converted with ``relay.frontend.from_pytorch``. Then
    ``SimplifyInference`` and ``MergeComposite`` are applied, with the patterns
    registered under "test_table".

    Args:
        script_module: traced TorchScript module to import.
        input_name: name given to the single graph input.
        ishape: shape of that input.

    Returns:
        The partitioned IRModule produced by MergeComposite.

    Raises:
        AssertionError: if no function with ``Composite="qnn_add"`` was formed,
            i.e. the pattern table failed to match anything.
    """
    input_shapes = [(input_name, ishape)]
    mod, _ = relay.frontend.from_pytorch(script_module, input_shapes)
    pattern_table = get_pattern_table("test_table")
    with tvm.transform.PassContext(opt_level=3):
        composite_partition = tvm.transform.Sequential(
            [
                tvm.relay.transform.SimplifyInference(),
                tvm.relay.transform.MergeComposite(pattern_table),
            ]
        )
        partitioned = composite_partition(mod)

    # Previously the partitioned module was discarded, so the test only
    # checked that the passes did not crash. Collect every composite name
    # that MergeComposite produced and require the qnn_add pattern to appear.
    composite_names = []

    def _collect_composites(expr):
        if isinstance(expr, relay.Function) and expr.attrs and "Composite" in expr.attrs:
            composite_names.append(str(expr.attrs["Composite"]))

    relay.analysis.post_order_visit(partitioned["main"], _collect_composites)
    assert "qnn_add" in composite_names, "MergeComposite did not form any qnn_add composite"
    return partitioned
+
+
def test_qnn_mergecomposite():
    """Quantize torchvision's ResNet-18 and run MergeComposite on the result.

    The steps are:
    1. Fuse the model and apply the fbgemm qconfig.
    2. Calibrate once on an all-zeros batch.
    3. Convert to a quantized model.
    4. Trace to TorchScript.
    5. Hand it to ``run_qnn_mergecomposite``.
    """
    from torchvision.models.quantization import resnet as qresnet

    dummy_batch = torch.zeros((1, 3, 224, 224))

    qmodel = qresnet.resnet18(pretrained=True)
    qmodel.eval()
    qmodel.fuse_model()
    qmodel.qconfig = torch.quantization.get_default_qconfig("fbgemm")

    # Insert observers, run a single calibration pass, then swap in the
    # quantized modules.
    torch.quantization.prepare(qmodel, inplace=True)
    qmodel(dummy_batch)
    torch.quantization.convert(qmodel, inplace=True)

    traced = torch.jit.trace(qmodel, dummy_batch).eval()
    run_qnn_mergecomposite(traced, "image", dummy_batch.shape)