This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 9ff2a5e796 add layerNormal infer layout (#11784)
9ff2a5e796 is described below

commit 9ff2a5e796b9cac0bdf05bfd600c50ade9728f1f
Author: ah cheng <[email protected]>
AuthorDate: Tue Jun 21 09:05:38 2022 +0800

    add layerNormal infer layout (#11784)
---
 src/relay/op/nn/nn.cc                             | 36 ++++++++++++++++
 tests/python/relay/test_pass_convert_op_layout.py | 50 +++++++++++++++++++++++
 2 files changed, 86 insertions(+)

diff --git a/src/relay/op/nn/nn.cc b/src/relay/op/nn/nn.cc
index b8d48d9e9e..bf00ee5117 100644
--- a/src/relay/op/nn/nn.cc
+++ b/src/relay/op/nn/nn.cc
@@ -921,6 +921,41 @@ Expr MakeLayerNorm(Expr data, Expr gamma, Expr beta, int axis, double epsilon, b
   return Call(op, {data, gamma, beta}, Attrs(attrs), {});
 }
 
+InferCorrectLayoutOutput LayerNormInferCorrectLayout(const Attrs& attrs,
+                                                     const Array<Layout>& new_in_layouts,
+                                                     const Array<Layout>& old_in_layouts,
+                                                     const Array<tvm::relay::Type>& old_in_types) {
+  const auto* attrs_ptr = attrs.as<LayerNormAttrs>();
+  ICHECK(attrs_ptr);
+  ObjectPtr<LayerNormAttrs> param = make_object<LayerNormAttrs>(*attrs_ptr);
+
+  Array<Array<IndexExpr>> old_in_shapes;
+  for (auto old_in_t : old_in_types) {
+    ICHECK(old_in_t.as<TensorTypeNode>());
+    old_in_shapes.push_back(old_in_t.as<TensorTypeNode>()->shape);
+  }
+
+  size_t axis =
+      param->axis < 0 ? param->axis + old_in_shapes[0].size() : static_cast<size_t>(param->axis);
+
+  Layout ret = Layout::Undef();
+
+  // If new_in_layouts are defined, this code tries to modify the layout.
+  if (new_in_layouts.defined() && old_in_layouts.defined()) {
+    // Get the new C axis. Extract the dim in old layout. Find the index of that dim in next layout.
+    const auto& ln_dim = old_in_layouts[0][axis];
+    auto new_index = new_in_layouts[0].IndexOf(ln_dim);
+    param->axis = new_index;
+    ret = new_in_layouts[0];
+  } else if (old_in_layouts.defined()) {
+    ret = old_in_layouts[0];
+  }
+
+  // LN has 3 inputs, 1 outputs. The last 2 inputs have "C" layout.
+  Layout c_layout = Layout("C");
+  return InferCorrectLayoutOutput({ret, c_layout, c_layout}, {ret}, Attrs(param));
+}
+
 TVM_REGISTER_GLOBAL("relay.op.nn._make.layer_norm").set_body_typed(MakeLayerNorm);
 
 RELAY_REGISTER_OP("nn.layer_norm")
@@ -931,6 +966,7 @@ RELAY_REGISTER_OP("nn.layer_norm")
     .add_argument("data", "Tensor", "Input to which layer_norm will be applied.")
     .add_argument("gamma", "Tensor", "The gamma scale factor.")
     .add_argument("beta", "Tensor", "The beta offset factor.")
+    .set_attr<FInferCorrectLayout>("FInferCorrectLayout", LayerNormInferCorrectLayout)
     .set_support_level(1)
     .add_type_rel("LayerNorm", LayerNormRel);
 
diff --git a/tests/python/relay/test_pass_convert_op_layout.py b/tests/python/relay/test_pass_convert_op_layout.py
index b2eb0bae57..7d093d4854 100644
--- a/tests/python/relay/test_pass_convert_op_layout.py
+++ b/tests/python/relay/test_pass_convert_op_layout.py
@@ -974,6 +974,56 @@ def test_scalar_convert_layout():
     assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
 
 
+def test_conv_ln_convert_layout():
+    """Check that layout transforms are propagated through ln."""
+
+    def before():
+        x = relay.var("x", shape=(1, 56, 56, 64))
+        weight = relay.var("weight", shape=(3, 3, 64, 64))
+        y = relay.nn.conv2d(
+            x,
+            weight,
+            channels=64,
+            kernel_size=(3, 3),
+            padding=(1, 1),
+            data_layout="NHWC",
+            kernel_layout="HWIO",
+        )
+
+        dtype = "float32"
+        beta = relay.var("beta", relay.TensorType((64,), dtype))
+        gamma = relay.var("gamma", relay.TensorType((64,), dtype))
+
+        y = relay.nn.layer_norm(y, gamma, beta, axis=3)
+        y = relay.Function(analysis.free_vars(y), y)
+        return y
+
+    def expected():
+        x = relay.var("x", shape=(1, 56, 56, 64))
+        w = relay.var("weight", shape=(3, 3, 64, 64))
+        x = relay.layout_transform(x, "NHWC", "NCHW")
+        w = relay.layout_transform(w, "HWIO", "OIHW")
+        y = relay.nn.conv2d(x, w, channels=64, kernel_size=(3, 3), padding=(1, 1))
+
+        dtype = "float32"
+        beta = relay.var("beta", relay.TensorType((64,), dtype))
+        gamma = relay.var("gamma", relay.TensorType((64,), dtype))
+
+        y = relay.nn.layer_norm(y, gamma, beta, axis=1)
+        y = relay.layout_transform(y, "NCHW", "NHWC")
+        y = relay.Function(analysis.free_vars(y), y)
+        return y
+
+    a = before()
+    a = run_opt_pass(a, transform.ConvertLayout({"nn.conv2d": ["NCHW", "default"]}))
+    print(a)
+    b = run_opt_pass(expected(), transform.InferType())
+    print(" ")
+    print(b)
+
+    assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
+
+
 def test_conv_bn_convert_layout():
     """Check that layout transforms are propagated through bn."""
 

Reply via email to