This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 636463d16c Consider pad value and input zero point in FoldExplicitPadding (#11127)
636463d16c is described below

commit 636463d16c8f1713a3d93793b60d21dde9b6a6f7
Author: ibsidorenko <[email protected]>
AuthorDate: Fri May 13 22:58:14 2022 +0300

    Consider pad value and input zero point in FoldExplicitPadding (#11127)
    
    This commit adds the following change:
    Do not fold `nn.pad` and `qnn.conv2d` if the padding value is not
    equal to the input zero point of the qnn operation. A unit test
    was added to check this behaviour.
---
 src/relay/transforms/fold_explicit_padding.cc      | 15 +++++++----
 .../relay/test_pass_fold_explicit_padding.py       | 29 ++++++++++++++++++++++
 2 files changed, 39 insertions(+), 5 deletions(-)

diff --git a/src/relay/transforms/fold_explicit_padding.cc 
b/src/relay/transforms/fold_explicit_padding.cc
index 6aac995e35..60b52c170a 100644
--- a/src/relay/transforms/fold_explicit_padding.cc
+++ b/src/relay/transforms/fold_explicit_padding.cc
@@ -129,22 +129,27 @@ class SimplifyConvPad {
     ICHECK(pad_node);
     const PadAttrs* param = pad_node->attrs.as<PadAttrs>();
     ICHECK(param);
-    Array<Expr> args = pad_node->args;
 
     auto x = node_map[x_][0];
     auto w = node_map[w_][0];
 
     // Possibly perform more optimizations if the pad_value is 0
-    const ConstantNode* pad_value = args[1].as<ConstantNode>();
+    const Expr& pv = pad_node->args[1];
+    const ConstantNode* pad_value = pv.as<ConstantNode>();
     if (node_map.find(qconv2d_) != node_map.end()) {
       Attrs attrs = GetAttrs(param, call_node->attrs.as<Conv2DAttrs>());
       auto input_zero_point = node_map[input_zero_point_][0];
       auto kernel_zero_point = node_map[kernel_zero_point_][0];
       auto input_scale = node_map[input_scale_][0];
       auto kernel_scale = node_map[kernel_scale_][0];
-      return Call(call_node->op,
-                  {x, w, input_zero_point, kernel_zero_point, input_scale, 
kernel_scale}, attrs,
-                  call_node->type_args, call_node->span);
+      // Fold Padding and QNN Convolution only if pad value == input zero 
point.
+      if (IsEqualScalar(input_zero_point, pv)) {
+        return Call(call_node->op,
+                    {x, w, input_zero_point, kernel_zero_point, input_scale, 
kernel_scale}, attrs,
+                    call_node->type_args, call_node->span);
+      } else {
+        return post;
+      }
     } else if (param->pad_mode == "constant" && pad_value && 
ToScalar(pad_value->data) == 0.0) {
       Attrs attrs;
       if (node_map.count(conv1d_)) {
diff --git a/tests/python/relay/test_pass_fold_explicit_padding.py 
b/tests/python/relay/test_pass_fold_explicit_padding.py
index 48b5e510d0..2887c0774b 100644
--- a/tests/python/relay/test_pass_fold_explicit_padding.py
+++ b/tests/python/relay/test_pass_fold_explicit_padding.py
@@ -144,6 +144,35 @@ def fold_pad_qconv2d():
     assert tvm.ir.structural_equal(a, b, map_free_vars=True), "Actual = \n" + 
str(a)
 
 
+def test_pad_qconv2d_no_fold():
+    def get_expr():
+        x = relay.var("x", shape=(1, 1, 2, 2), dtype="int8")
+        weight = relay.var("weight", shape=(1, 1, 2, 2), dtype="int8")
+        # Pad value and input zp are not equal
+        pad_value = 1
+        input_zero_point = 0
+        pad = relay.nn.pad(x, [[0, 0], [0, 0], [1, 1], [1, 1]], 
pad_value=pad_value)
+        return relay.qnn.op.conv2d(
+            pad,
+            weight,
+            relay.const(input_zero_point, "int32"),
+            relay.const(0, "int32"),
+            relay.const(1, "float32"),
+            relay.const(1, "float32"),
+            channels=1,
+            kernel_size=(2, 2),
+            padding=(0, 0),
+        )
+
+    a = run_opt_pass(get_expr(), relay.transform.FoldExplicitPadding())
+    b = run_opt_pass(get_expr(), transform.InferType())
+
+    assert tvm.ir.structural_equal(a, b, map_free_vars=True), (
+        "\nActual = \n" + str(a) + "\nExpected = \n" + str(b)
+    )
+
+
 if __name__ == "__main__":
     test_simplify_conv_pad()
     fold_pad_qconv2d()
+    test_pad_qconv2d_no_fold()

Reply via email to