This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 28e980125b [Quantization]: Update simulated_quantize to infer correct layout (#14875)
28e980125b is described below

commit 28e980125b1c25861b8462e99f340657085066d6
Author: Krishna Bindumadhavan <[email protected]>
AuthorDate: Thu May 18 16:15:00 2023 +0530

    [Quantization]: Update simulated_quantize to infer correct layout (#14875)
---
 src/relay/quantize/realize.cc                     | 21 +++++++++-
 tests/python/relay/test_pass_convert_op_layout.py | 49 +++++++++++++++++++++++
 2 files changed, 69 insertions(+), 1 deletion(-)

diff --git a/src/relay/quantize/realize.cc b/src/relay/quantize/realize.cc
index 3c2f6eb96d..514be1f0a7 100644
--- a/src/relay/quantize/realize.cc
+++ b/src/relay/quantize/realize.cc
@@ -34,6 +34,7 @@
 #include "../op/annotation/annotation.h"
 #include "../qnn/utils.h"
 #include "../transforms/fold_constant.h"
+#include "../transforms/infer_layout_utils.h"
 #include "./quantize.h"
 
 namespace tvm {
@@ -155,8 +156,26 @@ Expr QuantizeRealize(const Call& ref_call, const Array<Expr>& new_args, const Ob
   return QRealizeIntExpr(round_data, dom_scale, DataType::Float(32));
 }
 
+InferCorrectLayoutOutput SimQuantizeLayout(const Attrs& attrs, const Array<Layout>& new_in_layouts,
+                                           const Array<Layout>& old_in_layouts,
+                                           const Array<tvm::relay::Type>& old_in_types) {
+  Layout ret;
+
+  if (new_in_layouts.defined()) {
+    ICHECK_GE(new_in_layouts.size(), 1);
+    ret = new_in_layouts[0];
+  } else {
+    ICHECK_GE(old_in_layouts.size(), 1);
+    ret = old_in_layouts[0];
+  }
+  Layout channel_layout = Layout("C");
+  Array<Layout> input_layouts = {ret, channel_layout, channel_layout, channel_layout};
+  return InferCorrectLayoutOutput(input_layouts, {ret}, attrs);
+}
+
 RELAY_REGISTER_OP("relay.op.annotation.simulated_quantize")
-    .set_attr<FForwardRewrite>("FQRealizeRewrite", QuantizeRealize);
+    .set_attr<FForwardRewrite>("FQRealizeRewrite", QuantizeRealize)
+    .set_attr<FInferCorrectLayout>("FInferCorrectLayout", SimQuantizeLayout);
 
 Expr Conv2dRealize(const Call& ref_call, const Array<Expr>& new_args, const ObjectRef& ctx) {
   const QConfig& cfg = QConfig::Current();
diff --git a/tests/python/relay/test_pass_convert_op_layout.py b/tests/python/relay/test_pass_convert_op_layout.py
index 72d0232100..c3d579186d 100644
--- a/tests/python/relay/test_pass_convert_op_layout.py
+++ b/tests/python/relay/test_pass_convert_op_layout.py
@@ -21,6 +21,10 @@ from tvm import relay, te
 from tvm.relay import analysis, transform
 from tvm.relay.op import op as reg
 from tvm.relay.op import register_alter_op_layout
+from tvm.relay.quantize._annotate import (
+    attach_simulated_quantize,
+    QAnnotateKind,
+)
 from tvm.relay.transform.infer_layout_utils import InferCorrectLayoutOutput
 
 
@@ -2635,6 +2639,51 @@ def test_conv_max_pool_uses_specified_convert_layout():
     assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) + "\n\n Expected = \n" + str(b)
 
 
+def test_simulated_quantize_uses_specified_convert_layout():
+    def before():
+        x = relay.var("x", shape=(1, 64, 56, 56))
+        weight = relay.var("weight", shape=(64, 64, 3, 3))
+        y = relay.nn.conv2d(
+            x,
+            weight,
+            channels=64,
+            kernel_size=(3, 3),
+            padding=(1, 1),
+            data_layout="NCHW",
+            kernel_layout="OIHW",
+        )
+        y = attach_simulated_quantize(y, QAnnotateKind.INPUT)
+        y = relay.nn.relu(y)
+        y = relay.Function(analysis.free_vars(y), y)
+        return y
+
+    def expected():
+        x = relay.var("x", shape=(1, 64, 56, 56))
+        weight = relay.var("weight", shape=(64, 64, 3, 3))
+        x = relay.layout_transform(x, "NCHW", "NHWC")
+        weight = relay.layout_transform(weight, "OIHW", "OHWI")
+        y = relay.nn.conv2d(
+            x,
+            weight,
+            channels=64,
+            kernel_size=(3, 3),
+            padding=(1, 1),
+            data_layout="NHWC",
+            kernel_layout="OHWI",
+        )
+        y = attach_simulated_quantize(y, QAnnotateKind.INPUT)
+        y = relay.nn.relu(y)
+        y = relay.layout_transform(y, "NHWC", "NCHW")
+        y = relay.Function(analysis.free_vars(y), y)
+        return y
+
+    a = before()
+    a = run_opt_pass(a, transform.ConvertLayout({"nn.conv2d": ["NHWC", "OHWI"]}))
+    b = run_opt_pass(expected(), transform.InferType())
+
+    assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) + "\n\n Expected = \n" + str(b)
+
+
 @pytest.mark.parametrize(
     "data_layout, kernel_layout",
     [

Reply via email to