This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 092b54830b add topK FInferCorrectLayout attr (#11849)
092b54830b is described below
commit 092b54830bd2eb77a100d0d7fed0039288d12a57
Author: ah cheng <[email protected]>
AuthorDate: Fri Jun 24 05:23:14 2022 +0800
add topK FInferCorrectLayout attr (#11849)
---
src/relay/op/algorithm/topk.cc | 37 +++++++++++++++++++++
tests/python/relay/test_pass_convert_op_layout.py | 40 +++++++++++++++++++++++
2 files changed, 77 insertions(+)
diff --git a/src/relay/op/algorithm/topk.cc b/src/relay/op/algorithm/topk.cc
index c1d3e54727..c9f0a4396b 100644
--- a/src/relay/op/algorithm/topk.cc
+++ b/src/relay/op/algorithm/topk.cc
@@ -23,13 +23,49 @@
*/
#include <tvm/relay/attrs/algorithm.h>
#include <tvm/relay/op.h>
+#include <tvm/tir/data_layout.h>
#include <tvm/tir/op.h>
+#include "../../transforms/infer_layout_utils.h"
+
namespace tvm {
namespace relay {
TVM_REGISTER_NODE_TYPE(TopKAttrs);
+InferCorrectLayoutOutput TopKInferCorrectLayout(const Attrs& attrs,
+ const Array<Layout>&
new_in_layouts,
+ const Array<Layout>&
old_in_layouts,
+ const Array<tvm::relay::Type>&
old_in_types) {
+ const auto* attrs_ptr = attrs.as<TopKAttrs>();
+ ICHECK(attrs_ptr);
+ ObjectPtr<TopKAttrs> param = make_object<TopKAttrs>(*attrs_ptr);
+
+ Array<Array<IndexExpr>> old_in_shapes;
+ for (auto old_in_t : old_in_types) {
+ ICHECK(old_in_t.as<TensorTypeNode>());
+ old_in_shapes.push_back(old_in_t.as<TensorTypeNode>()->shape);
+ }
+
+ size_t axis =
+ param->axis < 0 ? param->axis + old_in_shapes[0].size() :
static_cast<size_t>(param->axis);
+
+ Layout ret = Layout::Undef();
+
+ // If new_in_layouts are defined, this code tries to modify the layout.
+ if (new_in_layouts.defined() && old_in_layouts.defined()) {
+ const auto& sp_dim = old_in_layouts[0][axis];
+ auto new_index = new_in_layouts[0].IndexOf(sp_dim);
+ param->axis = new_index;
+ ret = new_in_layouts[0];
+ } else if (old_in_layouts.defined()) {
+ ret = old_in_layouts[0];
+ }
+
+ // TopK has 2 outputs, Values and Indices
+ return InferCorrectLayoutOutput({ret}, {ret, ret}, Attrs(param));
+}
+
bool TopKRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
// `types` contains: [data, result]
@@ -89,6 +125,7 @@ RELAY_REGISTER_OP("topk")
.set_num_inputs(1)
.set_attrs_type<TopKAttrs>()
.add_argument("data", "Tensor", "Input data.")
+ .set_attr<FInferCorrectLayout>("FInferCorrectLayout",
TopKInferCorrectLayout)
.set_support_level(6)
.add_type_rel("TopK", TopKRel);
diff --git a/tests/python/relay/test_pass_convert_op_layout.py
b/tests/python/relay/test_pass_convert_op_layout.py
index d35259fb82..894d19a9fc 100644
--- a/tests/python/relay/test_pass_convert_op_layout.py
+++ b/tests/python/relay/test_pass_convert_op_layout.py
@@ -1761,6 +1761,46 @@ def test_conv_strided_slice_axes_convert_layout():
assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
def test_conv_topk_convert_layout():
    """ConvertLayout over conv2d(NHWC) -> topk: the topk axis must be remapped
    (NHWC axis 2 -> NCHW axis 3) and both tuple fields (values, indices) must
    be transformed back to NHWC."""

    def before():
        data = relay.var("x", shape=(1, 56, 56, 64))
        kernel = relay.var("weight", shape=(3, 3, 64, 64))
        conv = relay.nn.conv2d(
            data,
            kernel,
            channels=64,
            kernel_size=(3, 3),
            padding=(1, 1),
            data_layout="NHWC",
            kernel_layout="HWIO",
        )
        result = relay.topk(conv, k=2, axis=2)
        if isinstance(result, relay.expr.TupleWrapper):
            result = result.astuple()
        return relay.Function(analysis.free_vars(result), result)

    def expected():
        data = relay.var("x", shape=(1, 56, 56, 64))
        kernel = relay.var("weight", shape=(3, 3, 64, 64))
        kernel = relay.layout_transform(kernel, "HWIO", "OIHW")
        data = relay.layout_transform(data, "NHWC", "NCHW")
        conv = relay.nn.conv2d(data, kernel, channels=64, kernel_size=(3, 3), padding=(1, 1))
        # Axis 2 of NHWC corresponds to axis 3 of NCHW.
        tup = relay.topk(conv, k=2, axis=3).astuple()
        values = relay.TupleGetItem(tup, 0)
        indices = relay.TupleGetItem(tup, 1)
        values = relay.layout_transform(values, "NCHW", "NHWC")
        indices = relay.layout_transform(indices, "NCHW", "NHWC")
        out = relay.Tuple([values, indices])
        return relay.Function(analysis.free_vars(out), out)

    a = run_opt_pass(before(), transform.ConvertLayout({"nn.conv2d": ["NCHW", "default"]}))
    b = run_opt_pass(expected(), transform.InferType())

    assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
+
+
def test_conv_roi_pool_convert_layout():
def before():
x = relay.var("x", shape=(1, 64, 56, 56))