This is an automated email from the ASF dual-hosted git repository.

andrewzhaoluo pushed a commit to branch aluo/rebase-08312022-autotensorization-fq2i-changes
in repository https://gitbox.apache.org/repos/asf/tvm.git
commit ce7fcbdae3b28698fc37513cb3e3d65bb3c120b0
Author: Andrew Zhao Luo <[email protected]>
AuthorDate: Thu Sep 1 21:46:53 2022 -0700

    dnnl pattern matching
---
 python/tvm/relay/op/contrib/dnnl.py | 64 +++++++++++++++++++++++++++----------
 1 file changed, 47 insertions(+), 17 deletions(-)

diff --git a/python/tvm/relay/op/contrib/dnnl.py b/python/tvm/relay/op/contrib/dnnl.py
index f7752e41b0..e27449ac43 100644
--- a/python/tvm/relay/op/contrib/dnnl.py
+++ b/python/tvm/relay/op/contrib/dnnl.py
@@ -36,22 +36,18 @@ import logging
 from functools import reduce
 
 import tvm.ir
-from tvm.ir import Op
 from tvm import relay
+from tvm.ir import Op
+from tvm.relay import expr as _expr
 from tvm.relay import transform
-from tvm.relay.expr import GlobalVar
-from tvm.relay.expr_functor import ExprMutator, ExprVisitor
-from tvm.relay.expr import const
-
 from tvm.relay.analysis import analysis as _analysis
-from tvm.relay import expr as _expr
+from tvm.relay.expr import Call, GlobalVar, TupleGetItem, const
+from tvm.relay.expr_functor import ExprMutator, ExprVisitor
 
-from tvm.relay.expr import Call, TupleGetItem
 from ... import _ffi_api
-from ...dataflow_pattern import wildcard, is_op, is_constant, is_expr, rewrite, DFPatternCallback
+from ...dataflow_pattern import DFPatternCallback, is_constant, is_expr, is_op, rewrite, wildcard
 from .register import register_pattern_table
 
-
 logger = logging.getLogger("DNNL")
 
 supported_post_elts = ["nn.relu", "tanh", "sigmoid", "clip", "gelu", "swish", "mish", None]
@@ -809,7 +805,7 @@ def prune_dnnl_subgraphs(mod):
     return new_mod
 
 
-class LayerNormRewrite(DFPatternCallback):
+class LayerNormRewritePattern1(DFPatternCallback):
     """
     A callback to rewrite the following operators into a single layer normalization operator.
 
@@ -826,7 +822,42 @@ class LayerNormRewrite(DFPatternCallback):
             /* ty=Tensor[(1, 3136, 64), float32] */;
     10 %13 = add(%12, meta[relay.Constant][3] /* ty=Tensor[(64), float32] */)
             /* ty=Tensor[(1, 3136, 64), float32] */;
+    """
+
+    def __init__(self):
+        super(LayerNormRewritePattern1, self).__init__()
+        self.data = wildcard()
+        self.gamma = wildcard()
+        self.beta = wildcard()
+        mu = is_op("mean")(self.data)
+        diff = is_op("subtract")(self.data, mu)
+        cdiff = is_op("cast")(diff)
+        const_two = (
+            is_expr(relay.const(2))
+            | is_expr(relay.const(2.0))
+            | is_expr(relay.const(2.0, dtype="float16"))
+        )
+        p1 = is_op("power")(cdiff, const_two)
+        mp1 = is_op("mean")(p1)
+        eps = is_constant()  # TODO: check epsilon is something reasonable
+        added_eps = is_op("add")(mp1, eps)
+        deno = is_op("sqrt")(added_eps)
+        div_out = is_op("divide")(diff, deno)
+        div_out2 = diff * is_op("rsqrt")(added_eps)
+        weighted = is_op("multiply")(div_out | div_out2, self.gamma)
+        added_bias = is_op("add")(weighted, self.beta)
+        self.pattern = added_bias
+
+    def callback(self, pre, post, node_map):
+        data = node_map[self.data][0]
+        gamma = node_map[self.gamma][0]
+        beta = node_map[self.beta][0]
+        return relay.op.nn.layer_norm(data=data, gamma=gamma, beta=beta)
+
+class LayerNormRewritePattern2(DFPatternCallback):
+    """
+    A callback to rewrite the following operators into a single layer normalization operator.
 
     Pattern #2:
     1 %0 = mean(%input, axis=[-1], keepdims=True);
     2 %1 = variance(%input, %0, axis=[-1], keepdims=True);
@@ -842,19 +873,16 @@ class LayerNormRewrite(DFPatternCallback):
     """
 
     def __init__(self):
-        super(LayerNormRewrite, self).__init__()
+        super(LayerNormRewritePattern2, self).__init__()
         self.data = wildcard()
         self.gamma = wildcard()
         self.beta = wildcard()
         mu = is_op("mean")(self.data)
-        diff = is_op("subtract")(self.data, mu)
-        cdiff = diff | is_op("cast")(diff)
-        const_two = is_expr(relay.const(2)) | is_expr(relay.const(2.0))
-        p1 = is_op("power")(cdiff, const_two)
-        mp1 = is_op("mean")(p1) | is_op("variance")(self.data, mu)
+        mp1 = is_op("variance")(self.data, mu)
         eps = is_expr(relay.const(1e-5)) | is_expr(relay.const(1e-6))
         added_eps = is_op("add")(mp1, eps)
         deno = is_op("sqrt")(added_eps)
+        diff = is_op("subtract")(self.data, mu)
         div_out = is_op("divide")(diff, deno)
         div_out2 = diff * is_op("rsqrt")(added_eps)
         weighted = is_op("multiply")(div_out | div_out2, self.gamma)
@@ -872,7 +900,9 @@ def rewrite_layer_norm(mod):
     """Rewrite the input graph to replace multiple operators with a TVM native layer normalization
     operator so that we can offload them to dnnl layer normalization byoc part.
     """
-    mod["main"] = rewrite(LayerNormRewrite(), mod["main"])
+    mod["main"] = rewrite(LayerNormRewritePattern1(), mod["main"])
+    mod["main"] = rewrite(LayerNormRewritePattern2(), mod["main"])
+
     return mod
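For reviewers, a minimal sketch of how the new rewrite might be exercised end to end. This snippet is not part of the commit; the input shape, constants, and variable names are made up for illustration, and it assumes a TVM build that includes Relay and this dnnl contrib module.

    import numpy as np
    import tvm
    from tvm import relay
    from tvm.relay.op.contrib.dnnl import rewrite_layer_norm

    # Build a Relay function computing layer norm from decomposed ops, in the
    # shape LayerNormRewritePattern1 matches:
    # mean -> subtract -> cast -> power(2) -> mean -> add(eps) -> sqrt -> divide
    x = relay.var("x", shape=(1, 49, 64), dtype="float32")
    gamma = relay.const(np.ones((64,), dtype="float32"))
    beta = relay.const(np.zeros((64,), dtype="float32"))

    mu = relay.mean(x, axis=[-1], keepdims=True)
    diff = relay.subtract(x, mu)
    cdiff = relay.cast(diff, "float32")  # Pattern1 requires an explicit cast here
    var = relay.mean(relay.power(cdiff, relay.const(2.0)), axis=[-1], keepdims=True)
    denom = relay.sqrt(relay.add(var, relay.const(1e-5)))  # eps matched by is_constant()
    out = relay.add(relay.multiply(relay.divide(diff, denom), gamma), beta)

    mod = tvm.IRModule.from_expr(relay.Function([x], out))
    mod = rewrite_layer_norm(mod)
    print(mod["main"])  # the chain should now collapse to a single nn.layer_norm

Note the asymmetry between the two callbacks: Pattern1 insists on the cast and accepts any constant epsilon via is_constant(), while Pattern2 matches the variance-based decomposition and only the exact epsilon constants 1e-5 and 1e-6.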
