This is an automated email from the ASF dual-hosted git repository.

comaniac pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 8910f72  [ConvertLayout] Support transpose (#7214)
8910f72 is described below

commit 8910f72a38288f09d9e12163095f1675a9ccee83
Author: Cody Yu <[email protected]>
AuthorDate: Thu Jan 7 15:02:19 2021 -0800

    [ConvertLayout] Support transpose (#7214)
    
    * [ConvertLayout] Support transpose
    
    * format
    
    * fix ci
    
    * fix axes missing
    
    * fix
    
    * fix NCHW[x]c
    
    * Update src/relay/op/tensor/transform.cc
    
    * fix negative
    
    * fix
---
 src/relay/op/tensor/transform.cc                  | 75 ++++++++++++++++++++
 tests/python/relay/test_pass_convert_op_layout.py | 85 +++++++++++++++++++++++
 2 files changed, 160 insertions(+)

diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc
index 1ff428c..ecfde35 100644
--- a/src/relay/op/tensor/transform.cc
+++ b/src/relay/op/tensor/transform.cc
@@ -419,6 +419,80 @@ bool TransposeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
   return true;
 }
 
+Array<Array<Layout>> TransposeInferCorrectLayout(const Attrs& attrs,
+                                                 const Array<Layout>& new_in_layouts,
+                                                 const Array<Layout>& old_in_layouts,
+                                                 const Array<tvm::relay::Type>& old_in_types) {
+  // Discard "const" qualifier.
+  auto* params = const_cast<TransposeAttrs*>(attrs.as<TransposeAttrs>());
+  ICHECK(params != nullptr);
+
+  std::string in_layout_str = "";
+  std::string out_layout_str = "";
+
+  // Infer the input layout string and update the axes.
+  if (old_in_layouts.defined() && old_in_layouts[0].defined()) {
+    ICHECK_EQ(old_in_layouts.size(), 1);
+    auto old_layout = old_in_layouts[0];
+    Array<Integer> old_axes = params->axes;
+
+    // Deal with default axes and negative axes.
+    if (!old_axes.defined() || old_axes.size() == 0) {
+      for (int i = old_layout.ndim() - 1; i >= 0; --i) {
+        old_axes.push_back(i);
+      }
+    }
+    for (size_t i = 0; i < old_axes.size(); ++i) {
+      int axis = static_cast<int>(old_axes[i]->value);
+      if (axis < 0) {
+        int pos_axis = static_cast<int>(old_layout.ndim()) + axis;
+        old_axes.Set(i, pos_axis);
+      }
+    }
+
+    if (new_in_layouts.defined() && new_in_layouts[0].defined()) {
+      ICHECK_EQ(new_in_layouts.size(), 1);
+      auto new_layout = new_in_layouts[0];
+
+      // Update the axes based on the new layout.
+      Array<Integer> new_axes = Array<Integer>();
+      for (auto axis : old_axes) {
+        auto new_axis = new_layout.IndexOf(old_layout[axis->value]);
+        if (new_axis == -1) {  // Cannot find the target axis in the new layout.
+          new_axes.clear();
+          break;
+        }
+        new_axes.push_back(new_axis);
+      }
+      if (new_axes.defined() && new_axes.size() == new_layout.ndim()) {
+        params->axes = std::move(new_axes);
+        in_layout_str = new_layout.name();
+      }
+    }
+
+    // If the input layout string cannot be determined, propagate the old layout.
+    if (in_layout_str == "") {
+      params->axes = std::move(old_axes);
+      in_layout_str = old_layout.name();
+    }
+  }
+
+  // Infer the output layout string based on the input layout and the axes.
+  if (in_layout_str != "") {
+    for (auto axis : params->axes) {
+      ICHECK_LT(axis->value, in_layout_str.length());
+      out_layout_str += in_layout_str[axis->value];
+    }
+    try {
+      return Array<Array<Layout>>({{Layout(in_layout_str)}, {Layout(out_layout_str)}});
+    } catch (const dmlc::Error& e) {
+      // If the layout string is invalid for any reason, give up.
+      return Array<Array<Layout>>({{Layout::Undef()}, {Layout::Undef()}});
+    }
+  }
+  return Array<Array<Layout>>({{Layout::Undef()}, {Layout::Undef()}});
+}
+
 Array<te::Tensor> TransposeCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
                                    const Type& out_type) {
   const auto* param = attrs.as<TransposeAttrs>();
@@ -449,6 +523,7 @@ RELAY_REGISTER_OP("transpose")
     .set_support_level(3)
     .add_type_rel("Transpose", TransposeRel)
     .set_attr<FTVMCompute>("FTVMCompute", TransposeCompute)
+    .set_attr<FInferCorrectLayout>("FInferCorrectLayout", TransposeInferCorrectLayout)
     .set_attr<TOpPattern>("TOpPattern", kInjective);
 
 /* relay.reshape */
diff --git a/tests/python/relay/test_pass_convert_op_layout.py b/tests/python/relay/test_pass_convert_op_layout.py
index 4c4bb9d..ca2469e 100644
--- a/tests/python/relay/test_pass_convert_op_layout.py
+++ b/tests/python/relay/test_pass_convert_op_layout.py
@@ -568,6 +568,90 @@ def test_slice_like_convert_layout():
     verify_slice_like(after, [1, 2])
 
 
+def test_transpose_convert_layout():
+    def verify_transpose(after, expected_axes, expected_transform_cnt):
+        # Verify if the transpose after the convert layout has the expected axes.
+        has_expected = list()
+        checker = lambda x: has_expected.append(
+            isinstance(x, tvm.relay.expr.Call)
+            and x.op.name == "transpose"
+            and str(x.attrs.axes) == str(expected_axes)
+        )
+        relay.analysis.post_order_visit(after, checker)
+        assert any(has_expected), after
+
+        is_transform = list()
+        checker = lambda x: is_transform.append(
+            1 if isinstance(x, tvm.relay.expr.Call) and x.op.name == "layout_transform" else 0
+        )
+        relay.analysis.post_order_visit(after, checker)
+        assert (
+            sum(is_transform) == expected_transform_cnt
+        ), "Expected %s layout_transform, but get\n%s" % (expected_transform_cnt, after)
+
+    def nhwc_to_nchw():
+        x = relay.var("x", shape=(1, 56, 56, 64))
+        weight1 = relay.var("weight1", shape=(3, 3, 64, 32))
+        y = relay.nn.conv2d(
+            x,
+            weight1,
+            channels=32,
+            kernel_size=(3, 3),
+            padding=(1, 1),
+            data_layout="NHWC",
+            kernel_layout="HWIO",
+        )
+        z = relay.var("z", shape=(56, 56, 32))
+        out = relay.add(y, z)
+        out = relay.transpose(out, axes=[0, 3, 1, 2])
+        out = relay.nn.batch_flatten(out)
+        func = relay.Function(analysis.free_vars(out), out)
+        return run_opt_pass(func, transform.ConvertLayout({"nn.conv2d": ["NCHW", "default"]}))
+
+    verify_transpose(nhwc_to_nchw(), [0, 1, 2, 3], 3)
+
+    def nchw_to_nhwc():
+        x = relay.var("x", shape=(1, 64, 56, 56))
+        weight1 = relay.var("weight1", shape=(32, 64, 3, 3))
+        y = relay.nn.conv2d(
+            x,
+            weight1,
+            channels=32,
+            kernel_size=(3, 3),
+            padding=(1, 1),
+            data_layout="NCHW",
+            kernel_layout="OIHW",
+        )
+        z = relay.var("z", shape=(32, 56, 56))
+        out = relay.add(y, z)
+        out = relay.transpose(out, axes=[0, 2, -1, 1])  # Also test a negative axis.
+        out = relay.nn.batch_flatten(out)
+        func = relay.Function(analysis.free_vars(out), out)
+        return run_opt_pass(func, transform.ConvertLayout({"nn.conv2d": ["NHWC", "default"]}))
+
+    verify_transpose(nchw_to_nhwc(), [0, 1, 2, 3], 3)
+
+    def default_axes():
+        x = relay.var("x", shape=(1, 64, 56, 56))
+        weight1 = relay.var("weight1", shape=(32, 64, 3, 3))
+        y = relay.nn.conv2d(
+            x,
+            weight1,
+            channels=32,
+            kernel_size=(3, 3),
+            padding=(1, 1),
+            data_layout="NCHW",
+            kernel_layout="OIHW",
+        )
+        z = relay.var("z", shape=(32, 56, 56))
+        out = relay.add(y, z)
+        out = relay.transpose(out)  # No axes provided, will use the reversed axes.
+        func = relay.Function(analysis.free_vars(out), out)
+        return run_opt_pass(func, transform.ConvertLayout({"nn.conv2d": ["NHWC", "default"]}))
+
+    verify_transpose(default_axes(), [2, 1, 3, 0], 3)
+
+
 def test_resnet_convert_layout():
     def before():
         x = relay.var("x", shape=(1, 56, 56, 64))
@@ -1482,6 +1566,7 @@ if __name__ == "__main__":
     test_dual_path_convert_layout()
     test_bn_convert_layout()
     test_slice_like_convert_layout()
+    test_transpose_convert_layout()
     test_resnet_convert_layout()
     test_scalar_convert_layout()
     test_conv_bn_convert_layout()

Reply via email to