This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 6832309 [AlterLayout] Respect input layout for dense op if explicitly
specified (#9535)
6832309 is described below
commit 6832309d10d20ef77a4323c1ee677538e858ccfe
Author: masahi <[email protected]>
AuthorDate: Fri Nov 19 13:25:57 2021 +0900
[AlterLayout] Respect input layout for dense op if explicitly specified
(#9535)
---
src/relay/op/nn/nn.cc | 12 ++++++++++++
tests/python/relay/test_pass_alter_op_layout.py | 12 ++++++++----
2 files changed, 20 insertions(+), 4 deletions(-)
diff --git a/src/relay/op/nn/nn.cc b/src/relay/op/nn/nn.cc
index f334361..574ecc0 100644
--- a/src/relay/op/nn/nn.cc
+++ b/src/relay/op/nn/nn.cc
@@ -210,6 +210,10 @@ InferCorrectLayoutOutput DenseInferCorrectLayout(const
Attrs& attrs,
const Array<Layout>&
new_in_layouts,
const Array<Layout>&
old_in_layouts,
const
Array<tvm::relay::Type>& old_in_types) {
+ // Respect input layout, if explicitly specified (for example, "NW").
+ if (new_in_layouts.size() > 0 && new_in_layouts[0].defined()) {
+ return InferCorrectLayoutOutput({new_in_layouts[0], "NC"}, {"NC"}, attrs);
+ }
return InferCorrectLayoutOutput({"NC", "NC"}, {"NC"}, attrs);
}
@@ -279,6 +283,14 @@ InferCorrectLayoutOutput DensePackInferCorrectLayout(const
Attrs& attrs,
const
Array<tvm::relay::Type>& old_in_types) {
auto params = attrs.as<DensePackAttrs>();
ICHECK(params);
+ // Respect input layout, if explicitly specified (for example, "NW").
+ // However, a packed layout such as "NC8c" is not supported by dense_pack
op. For such cases,
+ // we insert a layout transform "NC8c" -> "NC".
+ // We do not expect to get a packed layout like "NW8w", which is not
compatible with "NC",
+ // since packing is always done on the "C" axis.
+ if (new_in_layouts.size() > 0 && new_in_layouts[0].defined() &&
new_in_layouts[0].ndim() == 2) {
+ return InferCorrectLayoutOutput({new_in_layouts[0],
params->weight_layout}, {"NC"}, attrs);
+ }
return InferCorrectLayoutOutput({"NC", params->weight_layout}, {"NC"},
attrs);
}
diff --git a/tests/python/relay/test_pass_alter_op_layout.py
b/tests/python/relay/test_pass_alter_op_layout.py
index ab36f79..7514a93 100644
--- a/tests/python/relay/test_pass_alter_op_layout.py
+++ b/tests/python/relay/test_pass_alter_op_layout.py
@@ -1305,19 +1305,23 @@ def test_alter_op_with_global_var():
def test_alter_op_dense():
def before():
- x = relay.var("x", shape=(32, 64))
+ x = relay.var("x", shape=(32, 1, 128))
weight = relay.var("weight", shape=(48, 64))
- y = relay.nn.dense(x, weight)
+ avg1d = relay.nn.adaptive_avg_pool1d(x, [64])
+ squeeze = relay.squeeze(avg1d, axis=[1])
+ y = relay.nn.dense(squeeze, weight)
y = relay.Function(analysis.free_vars(y), y)
return y
def expected():
- x = relay.var("x", shape=(32, 64))
+ x = relay.var("x", shape=(32, 1, 128))
weight = relay.var("weight", shape=(48, 64))
target_layout = "NC16n"
weight_transform = relay.layout_transform(weight, "NC", target_layout)
+ avg1d = relay.nn.adaptive_avg_pool1d(x, [64])
+ squeeze = relay.squeeze(avg1d, axis=[1])
y = relay.nn.contrib_dense_pack(
- x, weight_transform, target_layout, units=None, out_dtype="float32"
+ squeeze, weight_transform, target_layout, units=None,
out_dtype="float32"
)
y = relay.Function(analysis.free_vars(y), y)
return y