This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 9ff2a5e796 add layerNormal infer layout (#11784)
9ff2a5e796 is described below
commit 9ff2a5e796b9cac0bdf05bfd600c50ade9728f1f
Author: ah cheng <[email protected]>
AuthorDate: Tue Jun 21 09:05:38 2022 +0800
add layerNormal infer layout (#11784)
---
src/relay/op/nn/nn.cc | 36 ++++++++++++++++
tests/python/relay/test_pass_convert_op_layout.py | 50 +++++++++++++++++++++++
2 files changed, 86 insertions(+)
diff --git a/src/relay/op/nn/nn.cc b/src/relay/op/nn/nn.cc
index b8d48d9e9e..bf00ee5117 100644
--- a/src/relay/op/nn/nn.cc
+++ b/src/relay/op/nn/nn.cc
@@ -921,6 +921,41 @@ Expr MakeLayerNorm(Expr data, Expr gamma, Expr beta, int
axis, double epsilon, b
return Call(op, {data, gamma, beta}, Attrs(attrs), {});
}
+/*!
+ * \brief Infer the input/output layouts for layer_norm, updating the
+ *        normalization axis so it keeps tracking the same dimension when the
+ *        data layout is changed by the ConvertLayout pass.
+ *
+ * \param attrs          The LayerNormAttrs of the call (carries the axis).
+ * \param new_in_layouts Layouts requested by the pass, or undefined.
+ * \param old_in_layouts Layouts the call currently has, or undefined.
+ * \param old_in_types   Input types before the transformation (used for rank).
+ * \return Input layouts, output layouts and the possibly-updated attrs.
+ */
+InferCorrectLayoutOutput LayerNormInferCorrectLayout(const Attrs& attrs,
+                                                     const Array<Layout>& new_in_layouts,
+                                                     const Array<Layout>& old_in_layouts,
+                                                     const Array<tvm::relay::Type>& old_in_types) {
+  const auto* attrs_ptr = attrs.as<LayerNormAttrs>();
+  ICHECK(attrs_ptr);
+  ObjectPtr<LayerNormAttrs> param = make_object<LayerNormAttrs>(*attrs_ptr);
+
+  Array<Array<IndexExpr>> old_in_shapes;
+  for (auto old_in_t : old_in_types) {
+    ICHECK(old_in_t.as<TensorTypeNode>());
+    old_in_shapes.push_back(old_in_t.as<TensorTypeNode>()->shape);
+  }
+
+  // Canonicalize a negative axis against the rank of the data input.
+  // The casts keep the arithmetic in signed space before converting the
+  // (now non-negative) result, avoiding a signed/unsigned mix.
+  size_t axis = param->axis < 0
+                    ? static_cast<size_t>(param->axis + static_cast<int>(old_in_shapes[0].size()))
+                    : static_cast<size_t>(param->axis);
+
+  Layout ret = Layout::Undef();
+
+  // If new_in_layouts are defined, this code tries to modify the layout.
+  if (new_in_layouts.defined() && old_in_layouts.defined()) {
+    // Get the new C axis: take the dim of the old layout at the normalized
+    // axis and find the index of that dim in the new layout.
+    const auto& ln_dim = old_in_layouts[0][axis];
+    auto new_index = new_in_layouts[0].IndexOf(ln_dim);
+    // IndexOf returns -1 when the dim is absent from the new layout. In that
+    // case the layout cannot be adapted; keep the layout undefined rather
+    // than silently storing an invalid axis.
+    if (new_index >= 0) {
+      param->axis = new_index;
+      ret = new_in_layouts[0];
+    }
+  } else if (old_in_layouts.defined()) {
+    ret = old_in_layouts[0];
+  }
+
+  // layer_norm has 3 inputs and 1 output. gamma and beta are vectors over the
+  // channel dim, so the last 2 inputs always take the "C" layout.
+  Layout c_layout = Layout("C");
+  return InferCorrectLayoutOutput({ret, c_layout, c_layout}, {ret}, Attrs(param));
+}
+
TVM_REGISTER_GLOBAL("relay.op.nn._make.layer_norm").set_body_typed(MakeLayerNorm);
RELAY_REGISTER_OP("nn.layer_norm")
@@ -931,6 +966,7 @@ RELAY_REGISTER_OP("nn.layer_norm")
.add_argument("data", "Tensor", "Input to which layer_norm will be
applied.")
.add_argument("gamma", "Tensor", "The gamma scale factor.")
.add_argument("beta", "Tensor", "The beta offset factor.")
+ .set_attr<FInferCorrectLayout>("FInferCorrectLayout",
LayerNormInferCorrectLayout)
.set_support_level(1)
.add_type_rel("LayerNorm", LayerNormRel);
diff --git a/tests/python/relay/test_pass_convert_op_layout.py
b/tests/python/relay/test_pass_convert_op_layout.py
index b2eb0bae57..7d093d4854 100644
--- a/tests/python/relay/test_pass_convert_op_layout.py
+++ b/tests/python/relay/test_pass_convert_op_layout.py
@@ -974,6 +974,56 @@ def test_scalar_convert_layout():
assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
+def test_conv_ln_convert_layout():
+ """Check that layout transforms are propagated through ln."""
+
+ def before():
+ x = relay.var("x", shape=(1, 56, 56, 64))
+ weight = relay.var("weight", shape=(3, 3, 64, 64))
+ y = relay.nn.conv2d(
+ x,
+ weight,
+ channels=64,
+ kernel_size=(3, 3),
+ padding=(1, 1),
+ data_layout="NHWC",
+ kernel_layout="HWIO",
+ )
+
+ dtype = "float32"
+ beta = relay.var("beta", relay.TensorType((64,), dtype))
+ gamma = relay.var("gamma", relay.TensorType((64,), dtype))
+
+ y = relay.nn.layer_norm(y, gamma, beta, axis=3)
+ y = relay.Function(analysis.free_vars(y), y)
+ return y
+
+ def expected():
+ x = relay.var("x", shape=(1, 56, 56, 64))
+ w = relay.var("weight", shape=(3, 3, 64, 64))
+ x = relay.layout_transform(x, "NHWC", "NCHW")
+ w = relay.layout_transform(w, "HWIO", "OIHW")
+ y = relay.nn.conv2d(x, w, channels=64, kernel_size=(3, 3), padding=(1,
1))
+
+ dtype = "float32"
+ beta = relay.var("beta", relay.TensorType((64,), dtype))
+ gamma = relay.var("gamma", relay.TensorType((64,), dtype))
+
+ y = relay.nn.layer_norm(y, gamma, beta, axis=1)
+ y = relay.layout_transform(y, "NCHW", "NHWC")
+ y = relay.Function(analysis.free_vars(y), y)
+ return y
+
+ a = before()
+ a = run_opt_pass(a, transform.ConvertLayout({"nn.conv2d": ["NCHW",
"default"]}))
+ print(a)
+ b = run_opt_pass(expected(), transform.InferType())
+ print(" ")
+ print(b)
+
+ assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
+
+
def test_conv_bn_convert_layout():
"""Check that layout transforms are propagated through bn."""