This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 092b54830b add topK FInferCorrectLayout attr (#11849)
092b54830b is described below

commit 092b54830bd2eb77a100d0d7fed0039288d12a57
Author: ah cheng <[email protected]>
AuthorDate: Fri Jun 24 05:23:14 2022 +0800

    add topK FInferCorrectLayout attr (#11849)
---
 src/relay/op/algorithm/topk.cc                    | 37 +++++++++++++++++++++
 tests/python/relay/test_pass_convert_op_layout.py | 40 +++++++++++++++++++++++
 2 files changed, 77 insertions(+)

diff --git a/src/relay/op/algorithm/topk.cc b/src/relay/op/algorithm/topk.cc
index c1d3e54727..c9f0a4396b 100644
--- a/src/relay/op/algorithm/topk.cc
+++ b/src/relay/op/algorithm/topk.cc
@@ -23,13 +23,49 @@
  */
 #include <tvm/relay/attrs/algorithm.h>
 #include <tvm/relay/op.h>
+#include <tvm/tir/data_layout.h>
 #include <tvm/tir/op.h>
 
+#include "../../transforms/infer_layout_utils.h"
+
 namespace tvm {
 namespace relay {
 
 TVM_REGISTER_NODE_TYPE(TopKAttrs);
 
+InferCorrectLayoutOutput TopKInferCorrectLayout(const Attrs& attrs,
+                                                const Array<Layout>& new_in_layouts,
+                                                const Array<Layout>& old_in_layouts,
+                                                const Array<tvm::relay::Type>& old_in_types) {
+  const auto* attrs_ptr = attrs.as<TopKAttrs>();
+  ICHECK(attrs_ptr);
+  ObjectPtr<TopKAttrs> param = make_object<TopKAttrs>(*attrs_ptr);
+
+  Array<Array<IndexExpr>> old_in_shapes;
+  for (auto old_in_t : old_in_types) {
+    ICHECK(old_in_t.as<TensorTypeNode>());
+    old_in_shapes.push_back(old_in_t.as<TensorTypeNode>()->shape);
+  }
+
+  size_t axis =
+      param->axis < 0 ? param->axis + old_in_shapes[0].size() : static_cast<size_t>(param->axis);
+
+  Layout ret = Layout::Undef();
+
+  // If new_in_layouts are defined, this code tries to modify the layout.
+  if (new_in_layouts.defined() && old_in_layouts.defined()) {
+    const auto& sp_dim = old_in_layouts[0][axis];
+    auto new_index = new_in_layouts[0].IndexOf(sp_dim);
+    param->axis = new_index;
+    ret = new_in_layouts[0];
+  } else if (old_in_layouts.defined()) {
+    ret = old_in_layouts[0];
+  }
+
+  // TopK has 2 outputs, Values and Indices
+  return InferCorrectLayoutOutput({ret}, {ret, ret}, Attrs(param));
+}
+
 bool TopKRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
              const TypeReporter& reporter) {
   // `types` contains: [data, result]
@@ -89,6 +125,7 @@ RELAY_REGISTER_OP("topk")
     .set_num_inputs(1)
     .set_attrs_type<TopKAttrs>()
     .add_argument("data", "Tensor", "Input data.")
+    .set_attr<FInferCorrectLayout>("FInferCorrectLayout", TopKInferCorrectLayout)
     .set_support_level(6)
     .add_type_rel("TopK", TopKRel);
 
diff --git a/tests/python/relay/test_pass_convert_op_layout.py b/tests/python/relay/test_pass_convert_op_layout.py
index d35259fb82..894d19a9fc 100644
--- a/tests/python/relay/test_pass_convert_op_layout.py
+++ b/tests/python/relay/test_pass_convert_op_layout.py
@@ -1761,6 +1761,46 @@ def test_conv_strided_slice_axes_convert_layout():
     assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
 
 
+def test_conv_topk_convert_layout():
+    def before():
+        x = relay.var("x", shape=(1, 56, 56, 64))
+        weight = relay.var("weight", shape=(3, 3, 64, 64))
+        y = relay.nn.conv2d(
+            x,
+            weight,
+            channels=64,
+            kernel_size=(3, 3),
+            padding=(1, 1),
+            data_layout="NHWC",
+            kernel_layout="HWIO",
+        )
+        y = relay.topk(y, k=2, axis=2)
+        if isinstance(y, relay.expr.TupleWrapper):
+            y = y.astuple()
+        y = relay.Function(analysis.free_vars(y), y)
+        return y
+
+    def expected():
+        x = relay.var("x", shape=(1, 56, 56, 64))
+        weight = relay.var("weight", shape=(3, 3, 64, 64))
+        weight = relay.layout_transform(weight, "HWIO", "OIHW")
+        x = relay.layout_transform(x, "NHWC", "NCHW")
+        y = relay.nn.conv2d(x, weight, channels=64, kernel_size=(3, 3), padding=(1, 1))
+        y = relay.topk(y, k=2, axis=3).astuple()
+        a = relay.TupleGetItem(y, 0)
+        b = relay.TupleGetItem(y, 1)
+        a = relay.layout_transform(a, "NCHW", "NHWC")
+        b = relay.layout_transform(b, "NCHW", "NHWC")
+        out = relay.Tuple([a, b])
+        return relay.Function(analysis.free_vars(out), out)
+
+    a = before()
+    a = run_opt_pass(a, transform.ConvertLayout({"nn.conv2d": ["NCHW", "default"]}))
+    b = run_opt_pass(expected(), transform.InferType())
+
+    assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
+
+
 def test_conv_roi_pool_convert_layout():
     def before():
         x = relay.var("x", shape=(1, 64, 56, 56))

Reply via email to