This is an automated email from the ASF dual-hosted git repository.
cbalint13 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 35fdf8b16c [relay][qnn]: Fix qnn.avg_pool2d layout inference (#17339)
35fdf8b16c is described below
commit 35fdf8b16c3cad396dc2d21efe2bc0fc871a2285
Author: Krishna Bindumadhavan <[email protected]>
AuthorDate: Mon Sep 9 00:33:12 2024 +0530
[relay][qnn]: Fix qnn.avg_pool2d layout inference (#17339)
---
src/relay/qnn/op/avg_pool2d.cc | 8 ++-
tests/python/relay/test_pass_convert_op_layout.py | 79 +++++++++++++++++++++++
2 files changed, 84 insertions(+), 3 deletions(-)
diff --git a/src/relay/qnn/op/avg_pool2d.cc b/src/relay/qnn/op/avg_pool2d.cc
index b2dc08b856..e1a28169cc 100644
--- a/src/relay/qnn/op/avg_pool2d.cc
+++ b/src/relay/qnn/op/avg_pool2d.cc
@@ -132,9 +132,11 @@ InferCorrectLayoutOutput QnnAvgPoolInferCorrectLayout(const Attrs& attrs,
auto avgpool_new_layouts =
PoolInferCorrectLayout<AvgPool2DAttrs>(attrs, new_in_layouts,
old_in_layouts, old_in_types);
- // Scales and zero points are scalars, use the "undef" layout for them.
- Array<Layout> input_layouts = {avgpool_new_layouts->input_layouts[0], Layout::Undef(),
- Layout::Undef(), Layout::Undef(), Layout::Undef()};
+ // Scales and zero points are scalars, the layouts of these tensors can be treated as channel
+ // layout.
+ Layout channel_layout = Layout("C");
+ Array<Layout> input_layouts = {avgpool_new_layouts->input_layouts[0], channel_layout,
+ channel_layout, channel_layout, channel_layout};
Array<Layout> output_layouts = avgpool_new_layouts->output_layouts;
return InferCorrectLayoutOutput(input_layouts, output_layouts, attrs);
}
diff --git a/tests/python/relay/test_pass_convert_op_layout.py b/tests/python/relay/test_pass_convert_op_layout.py
index 49afe492a1..5450f1aa69 100644
--- a/tests/python/relay/test_pass_convert_op_layout.py
+++ b/tests/python/relay/test_pass_convert_op_layout.py
@@ -1542,6 +1542,85 @@ def test_conv_convert_kernel_layout():
tvm.ir.assert_structural_equal(a, b)
+def test_qnn_conv_avgpool_2d_convert_layout():
+ def before():
+ x = relay.var("x", shape=(1, 56, 56, 64), dtype="int8")
+ weight = relay.var("weight", shape=(3, 3, 64, 64), dtype="int8")
+ y = relay.qnn.op.conv2d(
+ x,
+ weight,
+ relay.const(1, "int32"),
+ relay.const(1, "int32"),
+ relay.const(1, "float32"),
+ relay.const(1, "float32"),
+ channels=64,
+ kernel_size=(3, 3),
+ padding=(1, 1),
+ data_layout="NHWC",
+ kernel_layout="HWIO",
+ )
+ y = relay.cast(y, "int8")
+ y = relay.qnn.op.avg_pool2d(
+ y,
+ relay.const(1, "float32"),
+ relay.const(1, "int32"),
+ relay.const(1, "float32"),
+ relay.const(1, "int32"),
+ layout="NHWC",
+ out_layout="NHWC",
+ pool_size=(3, 3),
+ padding=(0, 0),
+ strides=(1, 1),
+ dilation=(1, 1),
+ )
+ y = relay.Function([x, weight], y)
+ return y
+
+ def expected():
+ x = relay.var("x", shape=(1, 56, 56, 64), dtype="int8")
+ weight = relay.var("weight", shape=(3, 3, 64, 64), dtype="int8")
+ x = relay.layout_transform(x, "NHWC", "NCHW")
+ weight = relay.layout_transform(weight, "HWIO", "OIHW")
+ y = relay.qnn.op.conv2d(
+ x,
+ weight,
+ relay.const(1, "int32"),
+ relay.const(1, "int32"),
+ relay.const(1, "float32"),
+ relay.const(1, "float32"),
+ channels=64,
+ kernel_size=(3, 3),
+ padding=(1, 1),
+ data_layout="NCHW",
+ kernel_layout="OIHW",
+ )
+ y = relay.cast(y, "int8")
+ y = relay.qnn.op.avg_pool2d(
+ y,
+ relay.const(1, "float32"),
+ relay.const(1, "int32"),
+ relay.const(1, "float32"),
+ relay.const(1, "int32"),
+ layout="NCHW",
+ out_layout="NCHW",
+ pool_size=(3, 3),
+ padding=(0, 0),
+ strides=(1, 1),
+ dilation=(1, 1),
+ )
+ y = relay.layout_transform(y, "NCHW", "NHWC")
+ y = relay.Function(relay.analysis.free_vars(y), y)
+ return y
+
+ a = before()
+ a = run_opt_pass(
+ a, transform.ConvertLayout({"qnn.conv2d": ["NCHW", "default"], "qnn.avg_pool2d": ["NCHW"]})
+ )
+ b = run_opt_pass(expected(), transform.InferType())
+
+ tvm.ir.assert_structural_equal(a, b)
+
+
def test_conv_roi_align_convert_layout():
def before():
x = relay.var("x", shape=(1, 64, 56, 56))