This is an automated email from the ASF dual-hosted git repository.

cbalint13 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 78a1f80bf2 [CODEGEN] Vector-Codegen support for llvm-pure-intrin (#16985)
78a1f80bf2 is described below

commit 78a1f80bf24f1a1114f2ed7d17563d267bb38cc9
Author: rutkoor <[email protected]>
AuthorDate: Tue Jun 4 14:24:36 2024 +0530

    [CODEGEN] Vector-Codegen support for llvm-pure-intrin (#16985)
    
    * Vector-Codegen support for llvm-pure-intrin
---
 src/tir/op/builtin.cc                              |  3 +-
 src/tir/transforms/vectorize_loop.cc               | 23 ++++++++-
 .../tir-transform/test_tir_transform_vectorize.py  | 58 ++++++++++++++++++++++
 .../python/tvmscript/test_tvmscript_printer_tir.py | 21 ++++++++
 4 files changed, 103 insertions(+), 2 deletions(-)

diff --git a/src/tir/op/builtin.cc b/src/tir/op/builtin.cc
index cf82eb07ed..67d01aa923 100644
--- a/src/tir/op/builtin.cc
+++ b/src/tir/op/builtin.cc
@@ -139,7 +139,8 @@ TIR_DEFINE_BUILTIN_FUNC(call_llvm_intrin)
 TIR_DEFINE_BUILTIN_FUNC(call_llvm_pure_intrin)
     .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kPure))
     .set_attr<TScriptDtypePrintLocation>("TScriptDtypePrintLocation",
-                                         Integer(ScriptDtypePrintLocation::kFirst));
+                                         Integer(ScriptDtypePrintLocation::kFirst))
+    .set_attr<TVectorizable>("TVectorizable", true);
 
 TIR_DEFINE_BUILTIN_FUNC(call_spirv_pure_glsl450)
     .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kPure));
diff --git a/src/tir/transforms/vectorize_loop.cc b/src/tir/transforms/vectorize_loop.cc
index 63569f342a..b4e3d67e50 100644
--- a/src/tir/transforms/vectorize_loop.cc
+++ b/src/tir/transforms/vectorize_loop.cc
@@ -550,7 +550,28 @@ class Vectorizer : public StmtMutator, public ExprFunctor<PrimExpr(const PrimExp
       }
     } else {
       int lane = 0;
-      Array<PrimExpr> new_args = MutateArray(op->args, &lane);
+      Array<PrimExpr> new_args;
+      if (op->op.same_as(builtin::call_llvm_pure_intrin())) {
+        // op->args[1], will give us total number of arguments to intrinsic
+        int num_signature = Downcast<IntImm>(op->args[1])->value;
+        Array<PrimExpr> op_expr_args;
+        for (int i = 0; i < num_signature; i++) {
+          // Collect all intrinsic arguments
+          op_expr_args.push_back(op->args[i + 2]);
+        }
+        // Generate RAMP nodes for intrinsic arguments
+        Array<PrimExpr> updated_args = MutateArray(op_expr_args, &lane);
+        // Collect Intrinsic ID and no. of argument
+        for (int i = 0; i < 2; i++) {
+          new_args.push_back(op->args[i]);
+        }
+        // Collect updated intrinsic arguments
+        for (int i = 0; i < num_signature; i++) {
+          new_args.push_back(updated_args[i]);
+        }
+      } else {
+        new_args = MutateArray(op->args, &lane);
+      }
       // normal code path.
       if (op->args.same_as(new_args)) {
         return GetRef<PrimExpr>(op);
diff --git a/tests/python/tir-transform/test_tir_transform_vectorize.py b/tests/python/tir-transform/test_tir_transform_vectorize.py
index 7523cab549..9659d896ae 100644
--- a/tests/python/tir-transform/test_tir_transform_vectorize.py
+++ b/tests/python/tir-transform/test_tir_transform_vectorize.py
@@ -790,5 +790,63 @@ def test_vectorize_and_predicate_buffer_load_stores_with_sve_attr_scope_target()
     tvm.ir.assert_structural_equal(after, expected)
 
 
+@pytest.mark.parametrize(
+    "extent, vec_str, target",
+    [(4, "float32x4", simple_target)],
+)
+def test_vectorize_llvm_pure_intrin(extent, vec_str, target):
+    @I.ir_module
+    class Before:
+        @T.prim_func
+        def main(A: T.Buffer((25,), "float32"), B: T.Buffer((25,), "float32")):
+            for j in T.vectorized(extent):
+                A[j] = T.call_llvm_pure_intrin(
+                    "float32", "llvm.sqrt", tvm.tir.const(1, "uint"), B[j]
+                )
+
+    @I.ir_module
+    class After:
+        @T.prim_func
+        def main(A: T.Buffer((25,), "float32"), B: T.Buffer((25,), "float32")):
+            A[T.Ramp(0, 1, extent)] = T.call_llvm_pure_intrin(
+                vec_str, "llvm.sqrt", tvm.tir.const(1, "uint"), B[T.Ramp(0, 1, extent)]
+            )
+
+    with tvm.target.Target(target):
+        mod = tvm.tir.transform.VectorizeLoop()(Before)
+        tvm.ir.assert_structural_equal(mod, After)
+        mod = tvm.build(mod, target)
+
+
+@pytest.mark.parametrize(
+    "extent, vec_str, target",
+    [(4, "int32x4", simple_target)],
+)
+def test_vectorize_llvm_pure_intrin_fail(extent, vec_str, target):
+    @I.ir_module
+    class Before:
+        @T.prim_func
+        def main(A: T.Buffer((25,), "int32"), B: T.Buffer((25,), "float32")):
+            for j in T.vectorized(extent):
+                A[j] = T.call_llvm_pure_intrin(
+                    "int32", "llvm.lround", tvm.tir.const(1, "uint"), B[j]
+                )
+
+    @I.ir_module
+    class After:
+        @T.prim_func
+        def main(A: T.Buffer((25,), "int32"), B: T.Buffer((25,), "float32")):
+            A[T.Ramp(0, 1, extent)] = T.call_llvm_pure_intrin(
+                vec_str, "llvm.lround", tvm.tir.const(1, "uint"), B[T.Ramp(0, 1, extent)]
+            )
+
+    with pytest.raises(Exception) as e_info:
+        with tvm.target.Target(target):
+            mod = tvm.tir.transform.VectorizeLoop()(Before)
+            ex = tvm.build(mod, target)
+    tvm.ir.assert_structural_equal(mod, After)
+    assert "Intrinsic does not support vectors" in e_info.value.args[0]
+
+
 if __name__ == "__main__":
     tvm.testing.main()
diff --git a/tests/python/tvmscript/test_tvmscript_printer_tir.py b/tests/python/tvmscript/test_tvmscript_printer_tir.py
index 9e77fa0900..8364e65a41 100644
--- a/tests/python/tvmscript/test_tvmscript_printer_tir.py
+++ b/tests/python/tvmscript/test_tvmscript_printer_tir.py
@@ -1045,5 +1045,26 @@ def main(A: T.Buffer((128,), "float32"), B: T.Buffer((128,), "float32")):
     _assert_print(main, expected_output)
 
 
+def test_vectorize_llvm_pure_intrin():
+    from tvm.script import tir as T
+
+    @T.prim_func
+    def main(a: T.handle, b: T.handle):
+        A = T.match_buffer(a, (4,), "float32")
+        B = T.match_buffer(b, (4,), "float32")
+        A[T.Ramp(0, 1, 4)] = T.call_llvm_pure_intrin(
+            "float32x4", "llvm.sqrt", 1, B[T.Ramp(0, 1, 4)]
+        )
+
+    expected_output = """
+# from tvm.script import tir as T
+
+@T.prim_func
+def main(A: T.Buffer((4,), "float32"), B: T.Buffer((4,), "float32")):
+    A[0:4] = T.call_llvm_pure_intrin("float32x4", "llvm.sqrt", 1, B[0:4])
+    """
+    _assert_print(main, expected_output)
+
+
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to