ganler commented on code in PR #14312:
URL: https://github.com/apache/tvm/pull/14312#discussion_r1137813441
##########
tests/python/relax/test_dataflow_pattern.py:
##########
@@ -888,5 +888,116 @@ def simple_chain(x: R.Tensor((32, 32), "float32")) ->
R.Tensor:
assert not ctx1.match_dfb(simple_chain.body.blocks[0])
+def test_rewrite_simple():
+ @R.function
+ def main(x: R.Tensor((16, 16), "float32")) -> R.Tensor((16, 16),
"float32"):
+ with R.dataflow():
+ x2 = R.add(x, x)
+ x4 = R.add(x2, x2)
+ R.output(x4)
+ return x4
+
+ @R.function
+ def expected1(x: R.Tensor((16, 16), dtype="float32")) -> R.Tensor((16,
16), dtype="float32"):
+ with R.dataflow():
+ lv: R.Tensor((16, 16), dtype="float32") = R.multiply(x, R.const(2,
"float32"))
+ x4: R.Tensor((16, 16), dtype="float32") = R.multiply(lv,
R.const(2, "float32"))
+ R.output(x4)
+ return x4
+
+ @R.function
+ def expected2(x: R.Tensor((16, 16), dtype="float32")) -> R.Tensor((16,
16), dtype="float32"):
+ with R.dataflow():
+ x4: R.Tensor((16, 16), dtype="float32") = R.multiply(x, R.const(4,
"float32"))
+ R.output(x4)
+ return x4
+
+ x = wildcard()
+ pattern = is_op("relax.add")(x, x)
+
+ def callback(_, matchings):
+ return R.multiply(matchings[x], R.const(2, "float32"))
+
+ rewritten = rewrite(pattern, callback, main)
+ tvm.ir.assert_structural_equal(rewritten, expected1)
+
+ add1 = is_op("relax.add")(x, x)
+ pattern = is_op("relax.add")(add1, add1)
+
+ def callback(_, matchings):
+ return R.multiply(matchings[x], R.const(4, "float32"))
+
+ rewritten = rewrite(pattern, callback, main)
+ tvm.ir.assert_structural_equal(rewritten, expected2)
+
+
+def test_rewrite_attention():
+ @R.function
+ def main(
+ Q: R.Tensor((2, 4096, 8, 40), "float32"),
+ K: R.Tensor((2, 4096, 8, 40), "float32"),
+ V: R.Tensor((2, 4096, 8, 40), "float32"),
+ ) -> R.Tensor((2, 4096, 8, 40), "float32"):
+ with R.dataflow():
+ lv58 = R.permute_dims(Q, axes=[0, 2, 1, 3])
+ lv59 = R.reshape(lv58, R.shape([16, 4096, 40]))
+
+ lv61 = R.permute_dims(K, axes=[0, 2, 1, 3])
+ lv62 = R.reshape(lv61, R.shape([16, 4096, 40]))
+
+ lv64 = R.permute_dims(V, axes=[0, 2, 1, 3])
+ lv65 = R.reshape(lv64, R.shape([16, 4096, 40]))
+
+ lv62_transposed = R.permute_dims(lv62, axes=[0, 2, 1])
+ lv3_1 = R.matmul(lv59, lv62_transposed)
+ lv68 = R.multiply(lv3_1, R.const(0.15811388194561005, "float32"))
+ lv69 = R.nn.softmax(lv68, axis=-1)
+ lv_3 = R.matmul(lv69, lv65)
+
+ lv71 = R.reshape(lv_3, R.shape([2, 8, 4096, 40]))
+ lv72 = R.permute_dims(lv71, axes=[0, 2, 1, 3])
+ R.output(lv72)
+
+ return lv72
+
+ @R.function
+ def expected(
+ Q: R.Tensor((2, 4096, 8, 40), dtype="float32"),
+ K: R.Tensor((2, 4096, 8, 40), dtype="float32"),
+ V: R.Tensor((2, 4096, 8, 40), dtype="float32"),
+ ) -> R.Tensor((2, 4096, 8, 40), dtype="float32"):
+ with R.dataflow():
+ lv72: R.Tensor((2, 4096, 8, 40), dtype="float32") =
R.nn.attention(Q, V, K)
+ R.output(lv72)
+ return lv72
+
+ def BSNH_to_BSH(tensor):
+ return is_op("relax.reshape")(is_op("relax.permute_dims")(tensor),
wildcard())
+
+ def BSH_to_BSNH(tensor):
+ return is_op("relax.permute_dims")(is_op("relax.reshape")(tensor,
wildcard()))
+
+ Q = wildcard()
+ K = wildcard()
+ V = wildcard()
+
+ Q_3D = BSNH_to_BSH(Q)
+ V_3D = BSNH_to_BSH(V)
+ K_3D = BSNH_to_BSH(K)
+
+ matmul1 = is_op("relax.matmul")(Q_3D, is_op("relax.permute_dims")(V_3D))
+ multiply = is_op("relax.multiply")(matmul1, is_const())
+ softmax = is_op("relax.nn.softmax")(multiply)
+ matmul2 = is_op("relax.matmul")(softmax, K_3D)
+
+ pattern = BSH_to_BSNH(matmul2)
+
+ def callback(_, matchings):
+ return R.nn.attention(matchings[Q], matchings[K], matchings[V])
Review Comment:
Can we add a test where the first argument of `callback` (the matched
expression) is actually used? Every callback in these tests binds it to `_`.
##########
python/tvm/relax/dpl/pattern.py:
##########
@@ -1115,3 +1115,30 @@ def make_fused_bias_activation_pattern(op_name,
with_bias=False, activation=None
return is_op(activation)(out)
return out
+
+
+def rewrite(
+ pattern: DFPattern, callback: Callable[[Expr, Dict[DFPattern, Expr]],
Expr], func: Function
+) -> Function:
+ """
+ Rewrite a function with the given pattern and the callback.
+
+ Parameters
+ ----------
+ pattern: DFPattern
+ The pattern to match.
+
+ callback: Callable[[Expr, Dict[DFPattern, Expr]], Expr]
+ The function to be called on a successful matching for rewriting.
Given the matching
Review Comment:
nit: "Given the matching call node" -> "Given the matched call node"
##########
python/tvm/relax/dpl/pattern.py:
##########
@@ -1115,3 +1115,30 @@ def make_fused_bias_activation_pattern(op_name,
with_bias=False, activation=None
return is_op(activation)(out)
return out
+
+
+def rewrite(
+ pattern: DFPattern, callback: Callable[[Expr, Dict[DFPattern, Expr]],
Expr], func: Function
Review Comment:
Maybe naming this parameter `rewriter` instead of `callback` would make the
interface more intuitive?
##########
python/tvm/relax/dpl/pattern.py:
##########
@@ -1115,3 +1115,30 @@ def make_fused_bias_activation_pattern(op_name,
with_bias=False, activation=None
return is_op(activation)(out)
return out
+
+
+def rewrite(
+ pattern: DFPattern, callback: Callable[[Expr, Dict[DFPattern, Expr]],
Expr], func: Function
+) -> Function:
+ """
+ Rewrite a function with the given pattern and the callback.
+
+ Parameters
+ ----------
+ pattern: DFPattern
+ The pattern to match.
+
+ callback: Callable[[Expr, Dict[DFPattern, Expr]], Expr]
+ The function to be called on a successful matching for rewriting.
Given the matching
+ call node and the map of patterns and matched expressions, it should
return a new call node
+ or the original matched call node as is.
Review Comment:
Maybe explicitly saying "returns a new call node ... **to replace the
original given call node**" would make this clearer.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]