This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 6832309  [AlterLayout] Respect input layout for dense op if explicitly specified (#9535)
6832309 is described below

commit 6832309d10d20ef77a4323c1ee677538e858ccfe
Author: masahi <[email protected]>
AuthorDate: Fri Nov 19 13:25:57 2021 +0900

    [AlterLayout] Respect input layout for dense op if explicitly specified (#9535)
---
 src/relay/op/nn/nn.cc                           | 12 ++++++++++++
 tests/python/relay/test_pass_alter_op_layout.py | 12 ++++++++----
 2 files changed, 20 insertions(+), 4 deletions(-)

diff --git a/src/relay/op/nn/nn.cc b/src/relay/op/nn/nn.cc
index f334361..574ecc0 100644
--- a/src/relay/op/nn/nn.cc
+++ b/src/relay/op/nn/nn.cc
@@ -210,6 +210,10 @@ InferCorrectLayoutOutput DenseInferCorrectLayout(const 
Attrs& attrs,
                                                  const Array<Layout>& 
new_in_layouts,
                                                  const Array<Layout>& 
old_in_layouts,
                                                  const 
Array<tvm::relay::Type>& old_in_types) {
+  // Respect input layout, if explicitly specified (for example, "NW").
+  if (new_in_layouts.size() > 0 && new_in_layouts[0].defined()) {
+    return InferCorrectLayoutOutput({new_in_layouts[0], "NC"}, {"NC"}, attrs);
+  }
   return InferCorrectLayoutOutput({"NC", "NC"}, {"NC"}, attrs);
 }
 
@@ -279,6 +283,14 @@ InferCorrectLayoutOutput DensePackInferCorrectLayout(const 
Attrs& attrs,
                                                      const 
Array<tvm::relay::Type>& old_in_types) {
   auto params = attrs.as<DensePackAttrs>();
   ICHECK(params);
+  // Respect input layout, if explicitly specified (for example, "NW").
+  // However, a packed layout such as "NC8c" is not supported by dense_pack 
op. For such cases,
+  // we insert a layout transform "NC8c" -> "NC".
+  // We do not expect to get a packed layout like "NW8w", which is not 
compatible with "NC",
+  // since packing is always done on the "C" axis.
+  if (new_in_layouts.size() > 0 && new_in_layouts[0].defined() && 
new_in_layouts[0].ndim() == 2) {
+    return InferCorrectLayoutOutput({new_in_layouts[0], 
params->weight_layout}, {"NC"}, attrs);
+  }
   return InferCorrectLayoutOutput({"NC", params->weight_layout}, {"NC"}, 
attrs);
 }
 
diff --git a/tests/python/relay/test_pass_alter_op_layout.py 
b/tests/python/relay/test_pass_alter_op_layout.py
index ab36f79..7514a93 100644
--- a/tests/python/relay/test_pass_alter_op_layout.py
+++ b/tests/python/relay/test_pass_alter_op_layout.py
@@ -1305,19 +1305,23 @@ def test_alter_op_with_global_var():
 
 def test_alter_op_dense():
     def before():
-        x = relay.var("x", shape=(32, 64))
+        x = relay.var("x", shape=(32, 1, 128))
         weight = relay.var("weight", shape=(48, 64))
-        y = relay.nn.dense(x, weight)
+        avg1d = relay.nn.adaptive_avg_pool1d(x, [64])
+        squeeze = relay.squeeze(avg1d, axis=[1])
+        y = relay.nn.dense(squeeze, weight)
         y = relay.Function(analysis.free_vars(y), y)
         return y
 
     def expected():
-        x = relay.var("x", shape=(32, 64))
+        x = relay.var("x", shape=(32, 1, 128))
         weight = relay.var("weight", shape=(48, 64))
         target_layout = "NC16n"
         weight_transform = relay.layout_transform(weight, "NC", target_layout)
+        avg1d = relay.nn.adaptive_avg_pool1d(x, [64])
+        squeeze = relay.squeeze(avg1d, axis=[1])
         y = relay.nn.contrib_dense_pack(
-            x, weight_transform, target_layout, units=None, out_dtype="float32"
+            squeeze, weight_transform, target_layout, units=None, 
out_dtype="float32"
         )
         y = relay.Function(analysis.free_vars(y), y)
         return y

Reply via email to