This is an automated email from the ASF dual-hosted git repository.
cbalint13 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 78a1f80bf2 [CODEGEN] Vector-Codegen support for llvm-pure-intrin (#16985)
78a1f80bf2 is described below
commit 78a1f80bf24f1a1114f2ed7d17563d267bb38cc9
Author: rutkoor <[email protected]>
AuthorDate: Tue Jun 4 14:24:36 2024 +0530
[CODEGEN] Vector-Codegen support for llvm-pure-intrin (#16985)
* Vector-Codegen support for llvm-pure-intrin
---
src/tir/op/builtin.cc | 3 +-
src/tir/transforms/vectorize_loop.cc | 23 ++++++++-
.../tir-transform/test_tir_transform_vectorize.py | 58 ++++++++++++++++++++++
.../python/tvmscript/test_tvmscript_printer_tir.py | 21 ++++++++
4 files changed, 103 insertions(+), 2 deletions(-)
diff --git a/src/tir/op/builtin.cc b/src/tir/op/builtin.cc
index cf82eb07ed..67d01aa923 100644
--- a/src/tir/op/builtin.cc
+++ b/src/tir/op/builtin.cc
@@ -139,7 +139,8 @@ TIR_DEFINE_BUILTIN_FUNC(call_llvm_intrin)
TIR_DEFINE_BUILTIN_FUNC(call_llvm_pure_intrin)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kPure))
.set_attr<TScriptDtypePrintLocation>("TScriptDtypePrintLocation",
- Integer(ScriptDtypePrintLocation::kFirst));
+ Integer(ScriptDtypePrintLocation::kFirst))
+ .set_attr<TVectorizable>("TVectorizable", true);
TIR_DEFINE_BUILTIN_FUNC(call_spirv_pure_glsl450)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kPure));
diff --git a/src/tir/transforms/vectorize_loop.cc b/src/tir/transforms/vectorize_loop.cc
index 63569f342a..b4e3d67e50 100644
--- a/src/tir/transforms/vectorize_loop.cc
+++ b/src/tir/transforms/vectorize_loop.cc
@@ -550,7 +550,28 @@ class Vectorizer : public StmtMutator, public ExprFunctor<PrimExpr(const PrimExp
}
} else {
int lane = 0;
- Array<PrimExpr> new_args = MutateArray(op->args, &lane);
+ Array<PrimExpr> new_args;
+ if (op->op.same_as(builtin::call_llvm_pure_intrin())) {
+ // op->args[1], will give us total number of arguments to intrinsic
+ int num_signature = Downcast<IntImm>(op->args[1])->value;
+ Array<PrimExpr> op_expr_args;
+ for (int i = 0; i < num_signature; i++) {
+ // Collect all intrinsic arguments
+ op_expr_args.push_back(op->args[i + 2]);
+ }
+ // Generate RAMP nodes for intrinsic arguments
+ Array<PrimExpr> updated_args = MutateArray(op_expr_args, &lane);
+ // Collect Intrinsic ID and no. of argument
+ for (int i = 0; i < 2; i++) {
+ new_args.push_back(op->args[i]);
+ }
+ // Collect updated intrinsic arguments
+ for (int i = 0; i < num_signature; i++) {
+ new_args.push_back(updated_args[i]);
+ }
+ } else {
+ new_args = MutateArray(op->args, &lane);
+ }
// normal code path.
if (op->args.same_as(new_args)) {
return GetRef<PrimExpr>(op);
diff --git a/tests/python/tir-transform/test_tir_transform_vectorize.py b/tests/python/tir-transform/test_tir_transform_vectorize.py
index 7523cab549..9659d896ae 100644
--- a/tests/python/tir-transform/test_tir_transform_vectorize.py
+++ b/tests/python/tir-transform/test_tir_transform_vectorize.py
@@ -790,5 +790,63 @@ def test_vectorize_and_predicate_buffer_load_stores_with_sve_attr_scope_target()
tvm.ir.assert_structural_equal(after, expected)
+@pytest.mark.parametrize(
+ "extent, vec_str, target",
+ [(4, "float32x4", simple_target)],
+)
+def test_vectorize_llvm_pure_intrin(extent, vec_str, target):
+ @I.ir_module
+ class Before:
+ @T.prim_func
+ def main(A: T.Buffer((25,), "float32"), B: T.Buffer((25,), "float32")):
+ for j in T.vectorized(extent):
+ A[j] = T.call_llvm_pure_intrin(
+ "float32", "llvm.sqrt", tvm.tir.const(1, "uint"), B[j]
+ )
+
+ @I.ir_module
+ class After:
+ @T.prim_func
+ def main(A: T.Buffer((25,), "float32"), B: T.Buffer((25,), "float32")):
+ A[T.Ramp(0, 1, extent)] = T.call_llvm_pure_intrin(
+ vec_str, "llvm.sqrt", tvm.tir.const(1, "uint"), B[T.Ramp(0, 1, extent)]
+ )
+
+ with tvm.target.Target(target):
+ mod = tvm.tir.transform.VectorizeLoop()(Before)
+ tvm.ir.assert_structural_equal(mod, After)
+ mod = tvm.build(mod, target)
+
+
+@pytest.mark.parametrize(
+ "extent, vec_str, target",
+ [(4, "int32x4", simple_target)],
+)
+def test_vectorize_llvm_pure_intrin_fail(extent, vec_str, target):
+ @I.ir_module
+ class Before:
+ @T.prim_func
+ def main(A: T.Buffer((25,), "int32"), B: T.Buffer((25,), "float32")):
+ for j in T.vectorized(extent):
+ A[j] = T.call_llvm_pure_intrin(
+ "int32", "llvm.lround", tvm.tir.const(1, "uint"), B[j]
+ )
+
+ @I.ir_module
+ class After:
+ @T.prim_func
+ def main(A: T.Buffer((25,), "int32"), B: T.Buffer((25,), "float32")):
+ A[T.Ramp(0, 1, extent)] = T.call_llvm_pure_intrin(
+ vec_str, "llvm.lround", tvm.tir.const(1, "uint"), B[T.Ramp(0, 1, extent)]
+ )
+
+ with pytest.raises(Exception) as e_info:
+ with tvm.target.Target(target):
+ mod = tvm.tir.transform.VectorizeLoop()(Before)
+ ex = tvm.build(mod, target)
+ tvm.ir.assert_structural_equal(mod, After)
+ assert "Intrinsic does not support vectors" in e_info.value.args[0]
+
+
if __name__ == "__main__":
tvm.testing.main()
diff --git a/tests/python/tvmscript/test_tvmscript_printer_tir.py b/tests/python/tvmscript/test_tvmscript_printer_tir.py
index 9e77fa0900..8364e65a41 100644
--- a/tests/python/tvmscript/test_tvmscript_printer_tir.py
+++ b/tests/python/tvmscript/test_tvmscript_printer_tir.py
@@ -1045,5 +1045,26 @@ def main(A: T.Buffer((128,), "float32"), B: T.Buffer((128,), "float32")):
_assert_print(main, expected_output)
+def test_vectorize_llvm_pure_intrin():
+ from tvm.script import tir as T
+
+ @T.prim_func
+ def main(a: T.handle, b: T.handle):
+ A = T.match_buffer(a, (4,), "float32")
+ B = T.match_buffer(b, (4,), "float32")
+ A[T.Ramp(0, 1, 4)] = T.call_llvm_pure_intrin(
+ "float32x4", "llvm.sqrt", 1, B[T.Ramp(0, 1, 4)]
+ )
+
+ expected_output = """
+# from tvm.script import tir as T
+
+@T.prim_func
+def main(A: T.Buffer((4,), "float32"), B: T.Buffer((4,), "float32")):
+ A[0:4] = T.call_llvm_pure_intrin("float32x4", "llvm.sqrt", 1, B[0:4])
+ """
+ _assert_print(main, expected_output)
+
+
if __name__ == "__main__":
tvm.testing.main()