This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 9817338508 [BYOC][DNNL] Enable layer normalization in DNNL byoc. (#11508)
9817338508 is described below

commit 9817338508f3f8cd5a444133b4de99ce577c031b
Author: billishyahao <[email protected]>
AuthorDate: Thu Jun 9 03:12:36 2022 +0800

    [BYOC][DNNL] Enable layer normalization in DNNL byoc. (#11508)
    
    * Enable layer normalization in DNNL byoc.
    
    * Added unittest for layer norm and make code compatible after introducing TensorRequisite(PR-11345)
    
    * Fix lint issue
    
    * Fix clang format issue
---
 python/tvm/relay/op/contrib/dnnl.py           | 70 ++++++++++++++++++++++++++-
 src/runtime/contrib/dnnl/dnnl_json_runtime.cc | 47 ++++++++++++++++++
 tests/python/contrib/test_dnnl.py             | 21 ++++++++
 3 files changed, 137 insertions(+), 1 deletion(-)

diff --git a/python/tvm/relay/op/contrib/dnnl.py b/python/tvm/relay/op/contrib/dnnl.py
index 2e975cf49c..c87a7162b0 100644
--- a/python/tvm/relay/op/contrib/dnnl.py
+++ b/python/tvm/relay/op/contrib/dnnl.py
@@ -41,7 +41,7 @@ from tvm.relay.expr import GlobalVar
 from tvm.relay.expr_functor import ExprMutator, ExprVisitor
 
 from ... import _ffi_api
-from ...dataflow_pattern import wildcard, is_op
+from ...dataflow_pattern import wildcard, is_op, is_expr, rewrite, DFPatternCallback
 from .register import register_pattern_table
 
 logger = logging.getLogger("DNNL")
@@ -92,6 +92,7 @@ _register_external_op_helper("sigmoid")
 _register_external_op_helper("nn.softmax")
 _register_external_op_helper("add")
 _register_external_op_helper("multiply")
+_register_external_op_helper("nn.layer_norm")
 
 
 def make_conv_pattern(conv_name, with_bias=True, with_eltwise=None):
@@ -455,6 +456,7 @@ class IsComputeIntensiveGraph(ExprVisitor):
                 "nn.conv3d",
                 "nn.conv3d_transpose",
                 "nn.dense",
+                "nn.layer_norm",
             ]
         )
         if isinstance(call.op, tvm.tir.op.Op):
@@ -526,3 +528,69 @@ def prune_dnnl_subgraphs(mod):
     new_mod["main"] = SubgraphRemover(subgraphs_to_remove, mod, new_mod).visit(mod["main"])
     new_mod = transform.RemoveUnusedFunctions()(new_mod)
     return new_mod
+
+
+class LayerNormRewrite(DFPatternCallback):
+    """
+    A callback to rewrite the following operators into a single layer normalization operator.
+
+    Pattern #1:
+    1   %4 = mean(%3, axis=[-1], keepdims=True) /* ty=Tensor[(1, 3136, 1), float32] */;
+    2   %5 = subtract(%3, %4) /* ty=Tensor[(1, 3136, 64), float32] */;
+    3   %6 = cast(%5, dtype="float32") /* ty=Tensor[(1, 3136, 64), float32] */;
+    4   %7 = power(%6, 2f /* ty=float32 */) /* ty=Tensor[(1, 3136, 64), float32] */;
+    5   %8 = mean(%7, axis=[-1], keepdims=True) /* ty=Tensor[(1, 3136, 1), float32] */;
+    6   %9 = add(%8, 1e-05f /* ty=float32 */) /* ty=Tensor[(1, 3136, 1), float32] */;
+    7   %10 = sqrt(%9) /* ty=Tensor[(1, 3136, 1), float32] */;
+    8   %11 = divide(%5, %10) /* ty=Tensor[(1, 3136, 64), float32] */;
+    9   %12 = multiply(%11, meta[relay.Constant][2] /* ty=Tensor[(64), float32] */)
+            /* ty=Tensor[(1, 3136, 64), float32] */;
+    10   %13 = add(%12, meta[relay.Constant][3] /* ty=Tensor[(64), float32] */)
+            /* ty=Tensor[(1, 3136, 64), float32] */;
+
+    Pattern #2:
+    1   %0 = mean(%input, axis=[-1], keepdims=True);
+    2   %1 = variance(%input, %0, axis=[-1], keepdims=True);
+    3   %2 = add(%1, 1e-05f /* ty=float32 */) /* ty=Tensor[(1, 49, 1), float32] */;
+    4   %3 = subtract(%input, %0);
+    5   %4 = sqrt(%2) /* ty=Tensor[(1, 49, 1), float32] */;
+    6   %5 = divide(%3, %4);
+    7   %6 = multiply(%5, meta[relay.Constant][0] /* ty=Tensor[(64), float32] */)
+            /* ty=Tensor[(1, 49, 64), float32] */;
+    8   %7 = add(%6, meta[relay.Constant][1] /* ty=Tensor[(64), float32] */)
+            /* ty=Tensor[(1, 49, 64), float32] */
+
+    """
+
+    def __init__(self):
+        super(LayerNormRewrite, self).__init__()
+        self.data = wildcard()
+        self.gamma = wildcard()
+        self.beta = wildcard()
+        mu = is_op("mean")(self.data)
+        diff = is_op("subtract")(self.data, mu)
+        cdiff = diff | is_op("cast")(diff)
+        const_two = is_expr(relay.const(2)) | is_expr(relay.const(2.0))
+        p1 = is_op("power")(cdiff, const_two)
+        mp1 = is_op("mean")(p1) | is_op("variance")(self.data, mu)
+        eps = is_expr(relay.const(1e-5))
+        added_eps = is_op("add")(mp1, eps)
+        deno = is_op("sqrt")(added_eps)
+        div_out = is_op("divide")(diff, deno)
+        weighted = is_op("multiply")(div_out, self.gamma)
+        added_bias = is_op("add")(weighted, self.beta)
+        self.pattern = added_bias
+
+    def callback(self, pre, post, node_map):
+        data = node_map[self.data][0]
+        gamma = node_map[self.gamma][0]
+        beta = node_map[self.beta][0]
+        return relay.op.nn.layer_norm(data=data, gamma=gamma, beta=beta)
+
+
+def rewrite_layer_norm(mod):
+    """Rewrite the input graph to replace multiple operators with a TVM native layer normalization
+    operator so that we can offload them to dnnl layer normalization byoc part.
+    """
+    mod["main"] = rewrite(LayerNormRewrite(), mod["main"])
+    return mod
diff --git a/src/runtime/contrib/dnnl/dnnl_json_runtime.cc b/src/runtime/contrib/dnnl/dnnl_json_runtime.cc
index a2417f012e..db8f25e2a6 100644
--- a/src/runtime/contrib/dnnl/dnnl_json_runtime.cc
+++ b/src/runtime/contrib/dnnl/dnnl_json_runtime.cc
@@ -203,6 +203,8 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
           Binary(nid, dnnl::algorithm::binary_add);
         } else if ("multiply" == op_name) {
           Binary(nid, dnnl::algorithm::binary_mul);
+        } else if ("nn.layer_norm" == op_name) {
+          LayerNorm(nid);
         } else {
           LOG(FATAL) << "Unsupported op: " << op_name;
         }
@@ -449,6 +451,51 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
                                                              
{DNNL_ARG_VARIANCE, var_tr}});
   }
 
+  void LayerNorm(const size_t& nid) {
+    auto node = nodes_[nid];
+
+    auto src_tr = GetInput(nid, 0);
+    auto gamma_tr = GetInput(nid, 1);
+    auto beta_tr = GetInput(nid, 2);
+    auto dst_tr = GetOutput(nid, 0);
+
+    auto axis = GetNodeAttr<int>(node, "axis");
+    auto epsilon = GetNodeAttr<float>(node, "epsilon");
+    auto center = GetNodeAttr<bool>(node, "center");
+    auto scale = GetNodeAttr<bool>(node, "scale");
+
+    ICHECK(axis == -1 && center && scale) << "Unimplemented LayerNorm case";
+
+    // LN description.
+    auto lnorm_desc = dnnl::layer_normalization_forward::desc(
+        dnnl::prop_kind::forward_inference, src_tr.desc(), epsilon,
+        dnnl::normalization_flags::use_scale_shift);
+
+    auto lnorm_prim_desc = dnnl::layer_normalization_forward::primitive_desc(lnorm_desc, engine_);
+
+    // Concatenate scale and shift tensors
+    auto scale_shift_tr = TensorRequisite::AsIs(lnorm_prim_desc.weights_desc(), GenUniqueEid());
+    auto sc_sh_dims = scale_shift_tr.dims();
+
+    ICHECK(sc_sh_dims.size() == 2);
+    ICHECK(sc_sh_dims[0] == 2);
+    sc_sh_dims[0] /= 2;
+    auto scale_tr = scale_shift_tr.Crop(sc_sh_dims, {0, 0}).Squeeze();
+    auto shift_tr = scale_shift_tr.Crop(sc_sh_dims, {1, 0}).Squeeze();
+
+    auto register_copy = [this](const TensorRequisite& src, const TensorRequisite& dst) {
+      dnnl::reorder::primitive_desc copy_pd(engine_, src.desc(), engine_, dst.desc());
+      Submit(dnnl::reorder(copy_pd), {{DNNL_ARG_SRC, src}, {DNNL_ARG_DST, dst}});
+    };
+
+    register_copy(gamma_tr, scale_tr);
+    register_copy(beta_tr, shift_tr);
+
+    Submit(
+        dnnl::layer_normalization_forward(lnorm_prim_desc),
+        {{DNNL_ARG_SRC, src_tr}, {DNNL_ARG_DST, dst_tr}, {DNNL_ARG_SCALE_SHIFT, scale_shift_tr}});
+  }
+
   void Pooling(const size_t& nid, dnnl::algorithm algo) {
     auto node = nodes_[nid];
 
diff --git a/tests/python/contrib/test_dnnl.py b/tests/python/contrib/test_dnnl.py
index babfad4a0c..3e4e831aa5 100755
--- a/tests/python/contrib/test_dnnl.py
+++ b/tests/python/contrib/test_dnnl.py
@@ -111,6 +111,8 @@ def partition_for_dnnl(mod, params=None, alter_layout=True):
                             with tvm.transform.PassContext(opt_level=3):
                                 mod = alter_layout_seq(mod)
 
+    mod = dnnl.rewrite_layer_norm(mod)
+
     byoc_seq = tvm.transform.Sequential(
         [
             transform.MergeComposite(dnnl.pattern_table()),
@@ -454,6 +456,16 @@ def get_conv2d_bias_bn_relu(x_shape=(1, 32, 8, 8), k_shape=(16, 32, 3, 3), dtype
     return relay.nn.relu(conv2d_bias_bn), dic, param_lst
 
 
+def get_layer_norm(x_shape=(1, 49, 64), dtype="float32"):
+    dic = {"input": x_shape}
+    param_lst = []
+    input = relay.var("input", shape=x_shape)
+    beta = relay.const(np.zeros(x_shape[2]).astype(dtype))
+    gamma = relay.const(np.ones(x_shape[2]).astype(dtype))
+    out = relay.nn.layer_norm(input, gamma=gamma, beta=beta)
+    return out, dic, param_lst
+
+
 def get_conv2d_bias_sum_relu(x_shape=(1, 32, 8, 8), k_shape=(16, 32, 3, 3), dtype="float32"):
     conv2d_bias, dic, param_lst = get_conv2d_bias(x_shape, k_shape, dtype=dtype)
     sum_data = relay.const(np.random.randint(x_shape).astype(dtype))
@@ -1032,5 +1044,14 @@ def test_prune_dnnl_subgraph(run_module):
     run_and_verify_func(get_graph(), subgraph_num=1, run_module=run_module, test_bf16=False)
 
 
+def test_layer_norm(run_module, dtype="float32"):
+    x_shape = (1, 49, 64)
+
+    ln, dic, param_lst = get_layer_norm(x_shape, dtype=dtype)
+    ln = tvm.IRModule.from_expr(ln)
+    config = ln, dic, param_lst
+    run_and_verify_func(config, run_module=run_module, dtype=dtype)
+
+
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to