This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 9817338508 [BYOC][DNNL] Enable layer normalization in DNNL byoc. (#11508)
9817338508 is described below
commit 9817338508f3f8cd5a444133b4de99ce577c031b
Author: billishyahao <[email protected]>
AuthorDate: Thu Jun 9 03:12:36 2022 +0800
[BYOC][DNNL] Enable layer normalization in DNNL byoc. (#11508)
* Enable layer normalization in DNNL byoc.
* Added unittest for layer norm and made the code compatible after introducing TensorRequisite (PR-11345)
* Fix lint issue
* Fix clang format issue
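For context, a minimal, illustrative sketch of how the new rewrite is meant to be driven; `mod` and `params` are placeholders, and partition_for_dnnl is the test helper from tests/python/contrib/test_dnnl.py touched below:

    from tvm.relay.op.contrib import dnnl

    # Collapse a decomposed layer norm (mean / subtract / power / sqrt /
    # divide / multiply / add) back into a single nn.layer_norm call, then
    # let the DNNL BYOC partitioner offload it.
    mod = dnnl.rewrite_layer_norm(mod)
    mod = partition_for_dnnl(mod, params=params, alter_layout=True)

The new test can be run locally with, e.g., "pytest tests/python/contrib/test_dnnl.py -k test_layer_norm", assuming TVM was built with DNNL enabled.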
---
python/tvm/relay/op/contrib/dnnl.py | 70 ++++++++++++++++++++++++++-
src/runtime/contrib/dnnl/dnnl_json_runtime.cc | 47 ++++++++++++++++++
tests/python/contrib/test_dnnl.py | 21 ++++++++
3 files changed, 137 insertions(+), 1 deletion(-)
diff --git a/python/tvm/relay/op/contrib/dnnl.py b/python/tvm/relay/op/contrib/dnnl.py
index 2e975cf49c..c87a7162b0 100644
--- a/python/tvm/relay/op/contrib/dnnl.py
+++ b/python/tvm/relay/op/contrib/dnnl.py
@@ -41,7 +41,7 @@ from tvm.relay.expr import GlobalVar
from tvm.relay.expr_functor import ExprMutator, ExprVisitor
from ... import _ffi_api
-from ...dataflow_pattern import wildcard, is_op
+from ...dataflow_pattern import wildcard, is_op, is_expr, rewrite, DFPatternCallback
from .register import register_pattern_table
logger = logging.getLogger("DNNL")
@@ -92,6 +92,7 @@ _register_external_op_helper("sigmoid")
_register_external_op_helper("nn.softmax")
_register_external_op_helper("add")
_register_external_op_helper("multiply")
+_register_external_op_helper("nn.layer_norm")
def make_conv_pattern(conv_name, with_bias=True, with_eltwise=None):
@@ -455,6 +456,7 @@ class IsComputeIntensiveGraph(ExprVisitor):
"nn.conv3d",
"nn.conv3d_transpose",
"nn.dense",
+ "nn.layer_norm",
]
)
if isinstance(call.op, tvm.tir.op.Op):
@@ -526,3 +528,69 @@ def prune_dnnl_subgraphs(mod):
new_mod["main"] = SubgraphRemover(subgraphs_to_remove, mod,
new_mod).visit(mod["main"])
new_mod = transform.RemoveUnusedFunctions()(new_mod)
return new_mod
+
+
+class LayerNormRewrite(DFPatternCallback):
+    """
+    A callback to rewrite the following operators into a single layer normalization operator.
+
+    Pattern #1:
+    1 %4 = mean(%3, axis=[-1], keepdims=True) /* ty=Tensor[(1, 3136, 1), float32] */;
+    2 %5 = subtract(%3, %4) /* ty=Tensor[(1, 3136, 64), float32] */;
+    3 %6 = cast(%5, dtype="float32") /* ty=Tensor[(1, 3136, 64), float32] */;
+    4 %7 = power(%6, 2f /* ty=float32 */) /* ty=Tensor[(1, 3136, 64), float32] */;
+    5 %8 = mean(%7, axis=[-1], keepdims=True) /* ty=Tensor[(1, 3136, 1), float32] */;
+    6 %9 = add(%8, 1e-05f /* ty=float32 */) /* ty=Tensor[(1, 3136, 1), float32] */;
+    7 %10 = sqrt(%9) /* ty=Tensor[(1, 3136, 1), float32] */;
+    8 %11 = divide(%5, %10) /* ty=Tensor[(1, 3136, 64), float32] */;
+    9 %12 = multiply(%11, meta[relay.Constant][2] /* ty=Tensor[(64), float32] */)
+        /* ty=Tensor[(1, 3136, 64), float32] */;
+    10 %13 = add(%12, meta[relay.Constant][3] /* ty=Tensor[(64), float32] */)
+        /* ty=Tensor[(1, 3136, 64), float32] */;
+
+    Pattern #2:
+    1 %0 = mean(%input, axis=[-1], keepdims=True);
+    2 %1 = variance(%input, %0, axis=[-1], keepdims=True);
+    3 %2 = add(%1, 1e-05f /* ty=float32 */) /* ty=Tensor[(1, 49, 1), float32] */;
+    4 %3 = subtract(%input, %0);
+    5 %4 = sqrt(%2) /* ty=Tensor[(1, 49, 1), float32] */;
+    6 %5 = divide(%3, %4);
+    7 %6 = multiply(%5, meta[relay.Constant][0] /* ty=Tensor[(64), float32] */)
+        /* ty=Tensor[(1, 49, 64), float32] */;
+    8 %7 = add(%6, meta[relay.Constant][1] /* ty=Tensor[(64), float32] */)
+        /* ty=Tensor[(1, 49, 64), float32] */
+
+    """
+
+    def __init__(self):
+        super(LayerNormRewrite, self).__init__()
+        self.data = wildcard()
+        self.gamma = wildcard()
+        self.beta = wildcard()
+        mu = is_op("mean")(self.data)
+        diff = is_op("subtract")(self.data, mu)
+        cdiff = diff | is_op("cast")(diff)
+        const_two = is_expr(relay.const(2)) | is_expr(relay.const(2.0))
+        p1 = is_op("power")(cdiff, const_two)
+        mp1 = is_op("mean")(p1) | is_op("variance")(self.data, mu)
+        eps = is_expr(relay.const(1e-5))
+        added_eps = is_op("add")(mp1, eps)
+        deno = is_op("sqrt")(added_eps)
+        div_out = is_op("divide")(diff, deno)
+        weighted = is_op("multiply")(div_out, self.gamma)
+        added_bias = is_op("add")(weighted, self.beta)
+        self.pattern = added_bias
+
+    def callback(self, pre, post, node_map):
+        data = node_map[self.data][0]
+        gamma = node_map[self.gamma][0]
+        beta = node_map[self.beta][0]
+        return relay.op.nn.layer_norm(data=data, gamma=gamma, beta=beta)
+
+
+def rewrite_layer_norm(mod):
+    """Rewrite the input graph to replace multiple operators with a TVM native layer normalization
+    operator so that we can offload them to dnnl layer normalization byoc part.
+    """
+    mod["main"] = rewrite(LayerNormRewrite(), mod["main"])
+    return mod
diff --git a/src/runtime/contrib/dnnl/dnnl_json_runtime.cc b/src/runtime/contrib/dnnl/dnnl_json_runtime.cc
index a2417f012e..db8f25e2a6 100644
--- a/src/runtime/contrib/dnnl/dnnl_json_runtime.cc
+++ b/src/runtime/contrib/dnnl/dnnl_json_runtime.cc
@@ -203,6 +203,8 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
        Binary(nid, dnnl::algorithm::binary_add);
      } else if ("multiply" == op_name) {
        Binary(nid, dnnl::algorithm::binary_mul);
+      } else if ("nn.layer_norm" == op_name) {
+        LayerNorm(nid);
      } else {
        LOG(FATAL) << "Unsupported op: " << op_name;
      }
@@ -449,6 +451,51 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
{DNNL_ARG_VARIANCE, var_tr}});
}
+ void LayerNorm(const size_t& nid) {
+ auto node = nodes_[nid];
+
+ auto src_tr = GetInput(nid, 0);
+ auto gamma_tr = GetInput(nid, 1);
+ auto beta_tr = GetInput(nid, 2);
+ auto dst_tr = GetOutput(nid, 0);
+
+ auto axis = GetNodeAttr<int>(node, "axis");
+ auto epsilon = GetNodeAttr<float>(node, "epsilon");
+ auto center = GetNodeAttr<bool>(node, "center");
+ auto scale = GetNodeAttr<bool>(node, "scale");
+
+ ICHECK(axis == -1 && center && scale) << "Unimplemented LayerNorm case";
+
+ // LN description.
+ auto lnorm_desc = dnnl::layer_normalization_forward::desc(
+ dnnl::prop_kind::forward_inference, src_tr.desc(), epsilon,
+ dnnl::normalization_flags::use_scale_shift);
+
+ auto lnorm_prim_desc =
dnnl::layer_normalization_forward::primitive_desc(lnorm_desc, engine_);
+
+ // Concatenate scale and shift tensors
+ auto scale_shift_tr =
TensorRequisite::AsIs(lnorm_prim_desc.weights_desc(), GenUniqueEid());
+ auto sc_sh_dims = scale_shift_tr.dims();
+
+ ICHECK(sc_sh_dims.size() == 2);
+ ICHECK(sc_sh_dims[0] == 2);
+ sc_sh_dims[0] /= 2;
+ auto scale_tr = scale_shift_tr.Crop(sc_sh_dims, {0, 0}).Squeeze();
+ auto shift_tr = scale_shift_tr.Crop(sc_sh_dims, {1, 0}).Squeeze();
+
+ auto register_copy = [this](const TensorRequisite& src, const
TensorRequisite& dst) {
+ dnnl::reorder::primitive_desc copy_pd(engine_, src.desc(), engine_,
dst.desc());
+ Submit(dnnl::reorder(copy_pd), {{DNNL_ARG_SRC, src}, {DNNL_ARG_DST,
dst}});
+ };
+
+ register_copy(gamma_tr, scale_tr);
+ register_copy(beta_tr, shift_tr);
+
+ Submit(
+ dnnl::layer_normalization_forward(lnorm_prim_desc),
+ {{DNNL_ARG_SRC, src_tr}, {DNNL_ARG_DST, dst_tr},
{DNNL_ARG_SCALE_SHIFT, scale_shift_tr}});
+ }
+
void Pooling(const size_t& nid, dnnl::algorithm algo) {
auto node = nodes_[nid];
diff --git a/tests/python/contrib/test_dnnl.py b/tests/python/contrib/test_dnnl.py
index babfad4a0c..3e4e831aa5 100755
--- a/tests/python/contrib/test_dnnl.py
+++ b/tests/python/contrib/test_dnnl.py
@@ -111,6 +111,8 @@ def partition_for_dnnl(mod, params=None, alter_layout=True):
with tvm.transform.PassContext(opt_level=3):
mod = alter_layout_seq(mod)
+ mod = dnnl.rewrite_layer_norm(mod)
+
byoc_seq = tvm.transform.Sequential(
[
transform.MergeComposite(dnnl.pattern_table()),
@@ -454,6 +456,16 @@ def get_conv2d_bias_bn_relu(x_shape=(1, 32, 8, 8), k_shape=(16, 32, 3, 3), dtype
return relay.nn.relu(conv2d_bias_bn), dic, param_lst
+def get_layer_norm(x_shape=(1, 49, 64), dtype="float32"):
+    dic = {"input": x_shape}
+    param_lst = []
+    input = relay.var("input", shape=x_shape)
+    beta = relay.const(np.zeros(x_shape[2]).astype(dtype))
+    gamma = relay.const(np.ones(x_shape[2]).astype(dtype))
+    out = relay.nn.layer_norm(input, gamma=gamma, beta=beta)
+    return out, dic, param_lst
+
+
def get_conv2d_bias_sum_relu(x_shape=(1, 32, 8, 8), k_shape=(16, 32, 3, 3), dtype="float32"):
    conv2d_bias, dic, param_lst = get_conv2d_bias(x_shape, k_shape, dtype=dtype)
    sum_data = relay.const(np.random.randint(x_shape).astype(dtype))
@@ -1032,5 +1044,14 @@ def test_prune_dnnl_subgraph(run_module):
    run_and_verify_func(get_graph(), subgraph_num=1, run_module=run_module, test_bf16=False)
+def test_layer_norm(run_module, dtype="float32"):
+    x_shape = (1, 49, 64)
+
+    ln, dic, param_lst = get_layer_norm(x_shape, dtype=dtype)
+    ln = tvm.IRModule.from_expr(ln)
+    config = ln, dic, param_lst
+    run_and_verify_func(config, run_module=run_module, dtype=dtype)
+
+
if __name__ == "__main__":
tvm.testing.main()