This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 28e980125b [Quantization]: Update simulated_quantize to infer correct layout (#14875)
28e980125b is described below
commit 28e980125b1c25861b8462e99f340657085066d6
Author: Krishna Bindumadhavan <[email protected]>
AuthorDate: Thu May 18 16:15:00 2023 +0530
[Quantization]: Update simulated_quantize to infer correct layout (#14875)
---
src/relay/quantize/realize.cc | 21 +++++++++-
tests/python/relay/test_pass_convert_op_layout.py | 49 +++++++++++++++++++++++
2 files changed, 69 insertions(+), 1 deletion(-)
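The patch registers an FInferCorrectLayout attribute for relay.op.annotation.simulated_quantize, so the ConvertLayout pass can propagate a requested layout through simulated_quantize nodes instead of treating the op as layout-opaque. A minimal sketch of the behavior the new test exercises, built only from APIs appearing in the diff below (attach_simulated_quantize, QAnnotateKind, ConvertLayout), with the assumption that the passes are run through a standard Sequential with InferType:

    import tvm
    from tvm import relay
    from tvm.relay import analysis, transform
    from tvm.relay.quantize._annotate import attach_simulated_quantize, QAnnotateKind

    # An NCHW conv2d followed by a simulated_quantize annotation.
    x = relay.var("x", shape=(1, 64, 56, 56))
    w = relay.var("weight", shape=(64, 64, 3, 3))
    y = relay.nn.conv2d(x, w, channels=64, kernel_size=(3, 3), padding=(1, 1),
                        data_layout="NCHW", kernel_layout="OIHW")
    y = attach_simulated_quantize(y, QAnnotateKind.INPUT)
    mod = tvm.IRModule.from_expr(relay.Function(analysis.free_vars(y), y))

    # With SimQuantizeLayout registered, ConvertLayout can pick a layout for
    # the simulated_quantize call that follows its data input, rather than
    # leaving layout_transforms stranded around it.
    seq = tvm.transform.Sequential(
        [transform.InferType(), transform.ConvertLayout({"nn.conv2d": ["NHWC", "OHWI"]})]
    )
    print(seq(mod))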
diff --git a/src/relay/quantize/realize.cc b/src/relay/quantize/realize.cc
index 3c2f6eb96d..514be1f0a7 100644
--- a/src/relay/quantize/realize.cc
+++ b/src/relay/quantize/realize.cc
@@ -34,6 +34,7 @@
#include "../op/annotation/annotation.h"
#include "../qnn/utils.h"
#include "../transforms/fold_constant.h"
+#include "../transforms/infer_layout_utils.h"
#include "./quantize.h"
namespace tvm {
@@ -155,8 +156,26 @@ Expr QuantizeRealize(const Call& ref_call, const Array<Expr>& new_args, const Ob
return QRealizeIntExpr(round_data, dom_scale, DataType::Float(32));
}
+InferCorrectLayoutOutput SimQuantizeLayout(const Attrs& attrs, const Array<Layout>& new_in_layouts,
+                                           const Array<Layout>& old_in_layouts,
+                                           const Array<tvm::relay::Type>& old_in_types) {
+  Layout ret;
+
+  if (new_in_layouts.defined()) {
+    ICHECK_GE(new_in_layouts.size(), 1);
+    ret = new_in_layouts[0];
+  } else {
+    ICHECK_GE(old_in_layouts.size(), 1);
+    ret = old_in_layouts[0];
+  }
+  Layout channel_layout = Layout("C");
+  Array<Layout> input_layouts = {ret, channel_layout, channel_layout, channel_layout};
+  return InferCorrectLayoutOutput(input_layouts, {ret}, attrs);
+}
+
RELAY_REGISTER_OP("relay.op.annotation.simulated_quantize")
-    .set_attr<FForwardRewrite>("FQRealizeRewrite", QuantizeRealize);
+    .set_attr<FForwardRewrite>("FQRealizeRewrite", QuantizeRealize)
+    .set_attr<FInferCorrectLayout>("FInferCorrectLayout", SimQuantizeLayout);
Expr Conv2dRealize(const Call& ref_call, const Array<Expr>& new_args, const ObjectRef& ctx) {
const QConfig& cfg = QConfig::Current();
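A note on the contract above: simulated_quantize takes four inputs (data, dom_scale, clip_min, clip_max), so input_layouts carries one entry per input; the three scalar quantization parameters get the trivial channel layout "C", while the data input and the single output follow the layout ConvertLayout selects. For illustration only, the same logic can be sketched from Python, assuming the register_infer_correct_layout helper in tvm.relay.op.op is available in your TVM build; overriding nn.relu at a high level here is a hypothetical stand-in, not part of this commit:

    from tvm.relay.op import op as reg
    from tvm.relay.transform.infer_layout_utils import InferCorrectLayoutOutput

    # Hypothetical override for illustration; the commit registers the real
    # function for simulated_quantize in C++.
    @reg.register_infer_correct_layout("nn.relu", level=11)
    def _relu_layout(attrs, new_in_layouts, old_in_layouts, old_in_types):
        # Prefer the layout chosen upstream by ConvertLayout; otherwise keep
        # the original layout of the (single) data input.
        ret = new_in_layouts[0] if new_in_layouts else old_in_layouts[0]
        # One layout per input and per output; nn.relu has one of each.
        # simulated_quantize would instead return [ret, "C", "C", "C"]
        # for its four inputs.
        return InferCorrectLayoutOutput([ret], [ret], attrs)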
diff --git a/tests/python/relay/test_pass_convert_op_layout.py b/tests/python/relay/test_pass_convert_op_layout.py
index 72d0232100..c3d579186d 100644
--- a/tests/python/relay/test_pass_convert_op_layout.py
+++ b/tests/python/relay/test_pass_convert_op_layout.py
@@ -21,6 +21,10 @@ from tvm import relay, te
from tvm.relay import analysis, transform
from tvm.relay.op import op as reg
from tvm.relay.op import register_alter_op_layout
+from tvm.relay.quantize._annotate import (
+    attach_simulated_quantize,
+    QAnnotateKind,
+)
from tvm.relay.transform.infer_layout_utils import InferCorrectLayoutOutput
@@ -2635,6 +2639,51 @@ def test_conv_max_pool_uses_specified_convert_layout():
    assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) + "\n\nExpected = \n" + str(b)
+def test_simulated_quantize_uses_specified_convert_layout():
+    def before():
+        x = relay.var("x", shape=(1, 64, 56, 56))
+        weight = relay.var("weight", shape=(64, 64, 3, 3))
+        y = relay.nn.conv2d(
+            x,
+            weight,
+            channels=64,
+            kernel_size=(3, 3),
+            padding=(1, 1),
+            data_layout="NCHW",
+            kernel_layout="OIHW",
+        )
+        y = attach_simulated_quantize(y, QAnnotateKind.INPUT)
+        y = relay.nn.relu(y)
+        y = relay.Function(analysis.free_vars(y), y)
+        return y
+
+    def expected():
+        x = relay.var("x", shape=(1, 64, 56, 56))
+        weight = relay.var("weight", shape=(64, 64, 3, 3))
+        x = relay.layout_transform(x, "NCHW", "NHWC")
+        weight = relay.layout_transform(weight, "OIHW", "OHWI")
+        y = relay.nn.conv2d(
+            x,
+            weight,
+            channels=64,
+            kernel_size=(3, 3),
+            padding=(1, 1),
+            data_layout="NHWC",
+            kernel_layout="OHWI",
+        )
+        y = attach_simulated_quantize(y, QAnnotateKind.INPUT)
+        y = relay.nn.relu(y)
+        y = relay.layout_transform(y, "NHWC", "NCHW")
+        y = relay.Function(analysis.free_vars(y), y)
+        return y
+
+    a = before()
+    a = run_opt_pass(a, transform.ConvertLayout({"nn.conv2d": ["NHWC", "OHWI"]}))
+    b = run_opt_pass(expected(), transform.InferType())
+
+    assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) + "\n\nExpected = \n" + str(b)
+
+
@pytest.mark.parametrize(
"data_layout, kernel_layout",
[