apeskov commented on code in PR #11513:
URL: https://github.com/apache/tvm/pull/11513#discussion_r893215513
##########
python/tvm/relay/op/contrib/dnnl.py:
##########
@@ -594,3 +611,104 @@ def rewrite_layer_norm(mod):
"""
mod["main"] = rewrite(LayerNormRewrite(), mod["main"])
return mod
+
+
class DenseReshapeBiasGeluRewrite(DFPatternCallback):
    """
    A callback to reorder reshape operators when the patterns are as below:

    Pattern #1:
    1   %62 = nn.dense(%61, meta[relay.Constant][13] /* ty=Tensor[(64, 64), float32] */,
                units=None, out_dtype="float32") /* ty=Tensor[(3136, 64), float32] */;
    2   %63 = reshape(%62, newshape=[1, 3136, 64]) /* ty=Tensor[(1, 3136, 64), float32] */;
    3   %64 = add(meta[relay.Constant][4] /* ty=Tensor[(64), float32] */, %63)
                /* ty=Tensor[(1, 3136, 64), float32] */;

    Pattern #2:
    1   %76 = nn.dense(%75, meta[relay.Constant][18] /* ty=Tensor[(512, 64), float32] */,
                units=None, out_dtype="float32") /* ty=Tensor[(3136, 512), float32] */;
    2   %77 = reshape(%76, newshape=[1, 3136, 512]) /* ty=Tensor[(1, 3136, 512), float32] */;
    3   %78 = add(meta[relay.Constant][15] /* ty=Tensor[(512), float32] */, %77)
                /* ty=Tensor[(1, 3136, 512), float32] */;
    4   %79 = divide(%78, 1.41421f /* ty=float32 */) /* ty=Tensor[(1, 3136, 512), float32] */;
    5   %80 = erf(%79) /* ty=Tensor[(1, 3136, 512), float32] */;
    6   %81 = add(%80, 1f /* ty=float32 */) /* ty=Tensor[(1, 3136, 512), float32] */;
    7   %82 = multiply(%78, %81) /* ty=Tensor[(1, 3136, 512), float32] */;
    8   %83 = multiply(%82, 0.5f /* ty=float32 */) /* ty=Tensor[(1, 3136, 512), float32] */;

    The matched chain is rebuilt with the reshape moved to the very end, so
    dense -> add (-> divide -> erf -> add -> multiply -> multiply) becomes a
    contiguous sequence that can be fused.

    Parameters
    ----------
    has_gelu : bool
        When True, match Pattern #2 (with the GELU tail). When False, match
        Pattern #1 (dense + bias only).
    """

    def __init__(self, has_gelu=True):
        super(DenseReshapeBiasGeluRewrite, self).__init__()
        self.data = wildcard()
        self.weight = wildcard()
        self.bias = wildcard()
        self.const1 = wildcard()
        self.const2 = wildcard()
        self.const3 = wildcard()

        # Filled in by callback() with the attributes of the matched reshape,
        # under the "reshape" key.
        self.attr_map = {}
        self.has_gelu = has_gelu

        # Keep handles on the dense and reshape sub-patterns so callback() can
        # pull the exact matched calls out of node_map.
        self.dense = is_op("nn.dense")(self.data, self.weight)
        self.reshape = is_op("reshape")(self.dense)
        added = is_op("add")(self.bias, self.reshape)
        if self.has_gelu:
            divisor = is_op("divide")(added, self.const1)
            val_erf = is_op("erf")(divisor)
            added_erf = is_op("add")(val_erf, self.const2)
            mul1 = is_op("multiply")(added, added_erf)
            mul2 = is_op("multiply")(mul1, self.const3)
            self.pattern = mul2
        else:
            self.pattern = added

    def get_attr(self, pre):
        """Recursively retrieve attributes from reshape operator.

        Walks the whole of ``pre`` in post-order and stores the attributes of
        the last ``reshape`` call visited into ``self.attr_map["reshape"]``.
        callback() no longer relies on this, because the last reshape visited
        need not be the one matched by the pattern.
        """

        def visit_func(expr):
            if isinstance(expr, _expr.Call) and expr.op == relay.op.get("reshape"):
                new_attrs = {}
                for k in expr.attrs.keys():
                    new_attrs[k] = expr.attrs[k]
                self.attr_map["reshape"] = new_attrs

        _analysis.post_order_visit(pre, visit_func)

    def callback(self, pre, post, node_map):
        """Rebuild the matched chain with the reshape applied last.

        Returns ``reshape(<dense + bias [+ GELU]>, newshape)``. ``newshape`` is
        taken from the reshape call that the pattern actually matched.
        """
        dense_call = node_map[self.dense][0]
        reshape_call = node_map[self.reshape][0]

        # Read the attrs of the matched reshape directly. Walking `pre` could
        # pick up a reshape that lives inside one of the wildcard operands.
        self.attr_map["reshape"] = {k: reshape_call.attrs[k] for k in reshape_call.attrs.keys()}
        newshape = self.attr_map["reshape"]["newshape"]

        data = node_map[self.data][0]
        weight = node_map[self.weight][0]
        bias = node_map[self.bias][0]

        # Preserve the original dense attributes. Dropping out_dtype would
        # change the output dtype whenever it differs from the input dtype.
        den = relay.op.nn.dense(
            data,
            weight,
            units=dense_call.attrs.units,
            out_dtype=dense_call.attrs.out_dtype,
        )
        # NOTE(review): hoisting the reshape past the add/GELU assumes the
        # bias still broadcasts the same way on the 2-D dense output (e.g. a
        # 1-D bias over an unchanged last axis). The pattern does not check
        # this; confirm against the callers.
        added = relay.op.add(bias, den)
        if not self.has_gelu:
            return relay.op.reshape(added, newshape)

        const1 = node_map[self.const1][0]
        const2 = node_map[self.const2][0]
        const3 = node_map[self.const3][0]

        divisor = relay.op.divide(added, const1)
        val_erf = relay.op.erf(divisor)
        added_erf = relay.op.add(val_erf, const2)
        mul1 = relay.op.multiply(added, added_erf)
        mul2 = relay.op.multiply(mul1, const3)
        return relay.op.reshape(mul2, newshape)
+
+
def rewrite_dense_bias_gelu_reshape_last(mod):
    """Push reshapes behind dense + bias + GELU chains in ``mod["main"]``.

    Runs DenseReshapeBiasGeluRewrite (GELU tail enabled) over the module's
    "main" function. After the rewrite, the dense, bias add and GELU ops are no
    longer separated by a reshape, so they can be fused as dense_bias_gelu and
    offloaded through BYOC.

    Parameters
    ----------
    mod : tvm.IRModule
        Module whose "main" function is replaced in place.

    Returns
    -------
    tvm.IRModule
        The same ``mod`` object, with its "main" function rewritten.
    """
    gelu_callback = DenseReshapeBiasGeluRewrite()
    rewritten_main = rewrite(gelu_callback, mod["main"])
    mod["main"] = rewritten_main
    return mod
+
+
+def rewrite_dense_bias_reshape_last(mod):
Review Comment:
This looks like a near-identical duplicate of the previous function, similar to
"rewrite_reshape_for_dense".
You could combine them into a single function by passing a list of DFPatternCallbacks to `rewrite`:
`rewrite([DenseReshapeBiasGeluRewrite(), DenseReshapeBiasGeluRewrite(has_gelu=False)], mod["main"])`
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]