masahi commented on code in PR #14312:
URL: https://github.com/apache/tvm/pull/14312#discussion_r1137838295
##########
tests/python/relax/test_dataflow_pattern.py:
##########
@@ -888,5 +888,116 @@ def simple_chain(x: R.Tensor((32, 32), "float32")) -> R.Tensor:
    assert not ctx1.match_dfb(simple_chain.body.blocks[0])
+def test_rewrite_simple():
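+    # Rewrites every x + x into x * 2; a second pass below matches the
+    # nested (x + x) + (x + x) in `main` and rewrites it to x * 4 in one step.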
+    @R.function
+    def main(x: R.Tensor((16, 16), "float32")) -> R.Tensor((16, 16), "float32"):
+        with R.dataflow():
+            x2 = R.add(x, x)
+            x4 = R.add(x2, x2)
+            R.output(x4)
+        return x4
+
+    @R.function
+    def expected1(x: R.Tensor((16, 16), dtype="float32")) -> R.Tensor((16, 16), dtype="float32"):
+        with R.dataflow():
+            lv: R.Tensor((16, 16), dtype="float32") = R.multiply(x, R.const(2, "float32"))
+            x4: R.Tensor((16, 16), dtype="float32") = R.multiply(lv, R.const(2, "float32"))
+            R.output(x4)
+        return x4
+
+    @R.function
+    def expected2(x: R.Tensor((16, 16), dtype="float32")) -> R.Tensor((16, 16), dtype="float32"):
+        with R.dataflow():
+            x4: R.Tensor((16, 16), dtype="float32") = R.multiply(x, R.const(4, "float32"))
+            R.output(x4)
+        return x4
+
+    x = wildcard()
+    pattern = is_op("relax.add")(x, x)
+
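+    # The callback receives the matched call expression and a map from
+    # pattern objects to the Relax expressions they matched.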
+    def callback(_, matchings):
+        return R.multiply(matchings[x], R.const(2, "float32"))
+
+    rewritten = rewrite(pattern, callback, main)
+    tvm.ir.assert_structural_equal(rewritten, expected1)
+
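+    # A nested pattern: both operands of the outer add must match the same
+    # inner add.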
+    add1 = is_op("relax.add")(x, x)
+    pattern = is_op("relax.add")(add1, add1)
+
+    def callback(_, matchings):
+        return R.multiply(matchings[x], R.const(4, "float32"))
+
+    rewritten = rewrite(pattern, callback, main)
+    tvm.ir.assert_structural_equal(rewritten, expected2)
+
+
+def test_rewrite_attention():
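+    # Matches a decomposed multi-head attention subgraph and rewrites it
+    # to a single R.nn.attention call.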
+    @R.function
+    def main(
+        Q: R.Tensor((2, 4096, 8, 40), "float32"),
+        K: R.Tensor((2, 4096, 8, 40), "float32"),
+        V: R.Tensor((2, 4096, 8, 40), "float32"),
+    ) -> R.Tensor((2, 4096, 8, 40), "float32"):
+        with R.dataflow():
+            lv58 = R.permute_dims(Q, axes=[0, 2, 1, 3])
+            lv59 = R.reshape(lv58, R.shape([16, 4096, 40]))
+
+            lv61 = R.permute_dims(K, axes=[0, 2, 1, 3])
+            lv62 = R.reshape(lv61, R.shape([16, 4096, 40]))
+
+            lv64 = R.permute_dims(V, axes=[0, 2, 1, 3])
+            lv65 = R.reshape(lv64, R.shape([16, 4096, 40]))
+
+            lv62_transposed = R.permute_dims(lv62, axes=[0, 2, 1])
+            lv3_1 = R.matmul(lv59, lv62_transposed)
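+            # 0.15811388... == 1 / sqrt(40), the usual scaling for head dim 40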
+            lv68 = R.multiply(lv3_1, R.const(0.15811388194561005, "float32"))
+            lv69 = R.nn.softmax(lv68, axis=-1)
+            lv_3 = R.matmul(lv69, lv65)
+
+            lv71 = R.reshape(lv_3, R.shape([2, 8, 4096, 40]))
+            lv72 = R.permute_dims(lv71, axes=[0, 2, 1, 3])
+            R.output(lv72)
+
+        return lv72
+
+    @R.function
+    def expected(
+        Q: R.Tensor((2, 4096, 8, 40), dtype="float32"),
+        K: R.Tensor((2, 4096, 8, 40), dtype="float32"),
+        V: R.Tensor((2, 4096, 8, 40), dtype="float32"),
+    ) -> R.Tensor((2, 4096, 8, 40), dtype="float32"):
+        with R.dataflow():
+            lv72: R.Tensor((2, 4096, 8, 40), dtype="float32") = R.nn.attention(Q, V, K)
+            R.output(lv72)
+        return lv72
+
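+    # Pattern helpers for the layout shuffles around the attention core:
+    # a permute_dims feeding a reshape (4-D -> 3-D) and its inverse.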
+    def BSNH_to_BSH(tensor):
+        return is_op("relax.reshape")(is_op("relax.permute_dims")(tensor), wildcard())
+
+    def BSH_to_BSNH(tensor):
+        return is_op("relax.permute_dims")(is_op("relax.reshape")(tensor, wildcard()))
+
+    Q = wildcard()
+    K = wildcard()
+    V = wildcard()
+
+    Q_3D = BSNH_to_BSH(Q)
+    V_3D = BSNH_to_BSH(V)
+    K_3D = BSNH_to_BSH(K)
+
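+    # The scaled-dot-product chain: matmul, scale by a constant, softmax, matmul.
+    # Note the pattern's V binds to the transposed operand (main's K) and its K
+    # to the second matmul's operand (main's V), hence attention(Q, V, K) in expected.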
+    matmul1 = is_op("relax.matmul")(Q_3D, is_op("relax.permute_dims")(V_3D))
+    multiply = is_op("relax.multiply")(matmul1, is_const())
+    softmax = is_op("relax.nn.softmax")(multiply)
+    matmul2 = is_op("relax.matmul")(softmax, K_3D)
+
+    pattern = BSH_to_BSNH(matmul2)
+
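+    # Collapse the entire matched subgraph into one fused attention op.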
+    def callback(_, matchings):
+        return R.nn.attention(matchings[Q], matchings[K], matchings[V])
Review Comment:
done in https://github.com/apache/tvm/pull/14312/commits/5bb858aec4cda8dde9e612e09ddc08f239d90564