This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 60f5568415 [CODEGEN][REFACTOR] tir.call_llvm_intrin to remove nargs 
(#18206)
60f5568415 is described below

commit 60f5568415f176b9695230af3664b62835481b7b
Author: Tianqi Chen <[email protected]>
AuthorDate: Wed Aug 13 13:42:55 2025 -0400

    [CODEGEN][REFACTOR] tir.call_llvm_intrin to remove nargs (#18206)
    
    This PR refactors tir.call_llvm_intrin to drop the leading nargs argument.
    Previously, nargs was introduced because prefetch used a different number of
    arguments to query its signature than it was called with.
    That reason no longer holds, and attaching nargs to call_llvm_intrin is
    unintuitive, since nargs is already implied by the number of arguments passed.
    
    After this update, callers of tir.call_llvm_intrin pass the intrinsic
    arguments directly, as they are.
---
 include/tvm/tir/builtin.h                          |  2 +-
 include/tvm/tir/stmt.h                             |  5 ----
 python/tvm/tir/tensor_intrin/arm_cpu.py            | 18 +--------------
 python/tvm/tir/tensor_intrin/hexagon.py            |  3 ---
 python/tvm/tir/tensor_intrin/rocm.py               |  3 ---
 python/tvm/tir/tensor_intrin/x86.py                |  3 ---
 src/target/llvm/codegen_arm.cc                     |  7 +-----
 src/target/llvm/codegen_llvm.cc                    | 27 ++++------------------
 src/target/llvm/intrin_rule_llvm.cc                |  1 -
 src/target/llvm/intrin_rule_llvm.h                 | 11 +++++++--
 src/target/llvm/llvm_instance.h                    | 15 ++++++++++++
 src/target/llvm/llvm_module.cc                     | 15 +-----------
 src/tir/transforms/vectorize_loop.cc               | 12 ++++------
 tests/python/codegen/test_target_codegen_llvm.py   |  8 +++----
 .../test_hexagon/test_async_dma_pipeline.py        |  4 ----
 .../contrib/test_hexagon/test_meta_schedule.py     |  2 +-
 .../contrib/test_hexagon/test_parallel_hvx.py      |  3 ---
 .../test_hexagon/test_parallel_hvx_load_vtcm.py    |  4 ----
 ...est_meta_schedule_postproc_rewrite_tensorize.py |  1 -
 .../test_meta_schedule_trace_apply.py              |  2 +-
 tests/python/tir-base/test_tir_ops.py              |  4 ++--
 ...t_tir_transform_lower_cross_thread_reduction.py |  2 +-
 .../tir-transform/test_tir_transform_vectorize.py  | 12 ++++------
 .../python/tvmscript/test_tvmscript_printer_tir.py | 14 +++++------
 tests/python/tvmscript/test_tvmscript_roundtrip.py |  5 ----
 25 files changed, 55 insertions(+), 128 deletions(-)

diff --git a/include/tvm/tir/builtin.h b/include/tvm/tir/builtin.h
index b4ed44fbff..d3573c925d 100644
--- a/include/tvm/tir/builtin.h
+++ b/include/tvm/tir/builtin.h
@@ -225,7 +225,7 @@ TVM_DLL const Op& call_spirv_pure_glsl450();
 // TODO(tvm-team) revisit the builtins below
 // some of them can simply become ops with special codegen attr.
 /*!
- * \brief Prefetch a cacheline
+ * \brief same signature as llvm.prefetch
  */
 TVM_DLL const Op& prefetch();
 
diff --git a/include/tvm/tir/stmt.h b/include/tvm/tir/stmt.h
index 37410b1271..bbdb7c272e 100644
--- a/include/tvm/tir/stmt.h
+++ b/include/tvm/tir/stmt.h
@@ -1106,11 +1106,6 @@ constexpr const char* pragma_import_c = 
"pragma_import_c";
 constexpr const char* pragma_import_llvm = "pragma_import_llvm";
 /*! \brief Try to modify the AST to support Tensor Core */
 constexpr const char* pragma_tensor_core = "pragma_tensor_core";
-/*!
- * \brief Mark of prefetch scope, value=offset,
- *  run prefetch of Tensor on the current loop scope
- */
-constexpr const char* prefetch_scope = "prefetch_scope";
 /*!
  * \brief Marks the layout transforms to be used for a tensor.
  *
diff --git a/python/tvm/tir/tensor_intrin/arm_cpu.py 
b/python/tvm/tir/tensor_intrin/arm_cpu.py
index a6f3538846..0a5c0ea3a5 100644
--- a/python/tvm/tir/tensor_intrin/arm_cpu.py
+++ b/python/tvm/tir/tensor_intrin/arm_cpu.py
@@ -74,7 +74,6 @@ def neon_4x4_i8i8i32_impl(
 
         multiply_low = T.call_llvm_pure_intrin(
             T.llvm_lookup_intrinsic_id("llvm.aarch64.neon.smull.v8i16"),
-            T.uint32(2),
             vec_a,
             vec_b_low,
             dtype="int16x8",
@@ -82,7 +81,6 @@ def neon_4x4_i8i8i32_impl(
 
         pairwise_reduction_low = T.call_llvm_pure_intrin(
             T.llvm_lookup_intrinsic_id("llvm.aarch64.neon.saddlp.v4i32.v8i16"),
-            T.uint32(1),
             multiply_low,
             dtype="int32x4",
         )
@@ -91,7 +89,6 @@ def neon_4x4_i8i8i32_impl(
 
         multiply_high = T.call_llvm_pure_intrin(
             T.llvm_lookup_intrinsic_id("llvm.aarch64.neon.smull.v8i16"),
-            T.uint32(2),
             vec_a,
             vec_b_high,
             dtype="int16x8",
@@ -99,14 +96,12 @@ def neon_4x4_i8i8i32_impl(
 
         pairwise_reduction_high = T.call_llvm_pure_intrin(
             T.llvm_lookup_intrinsic_id("llvm.aarch64.neon.saddlp.v4i32.v8i16"),
-            T.uint32(1),
             multiply_high,
             dtype="int32x4",
         )
 
         C[T.ramp(T.int32(0), 1, 4)] += T.call_llvm_pure_intrin(
             T.llvm_lookup_intrinsic_id("llvm.aarch64.neon.addp.v4i32"),
-            T.uint32(2),
             pairwise_reduction_low,
             pairwise_reduction_high,
             dtype="int32x4",
@@ -159,7 +154,6 @@ def get_dotprod_intrin(in_dtype, out_dtype):
 
             C[T.ramp(T.int32(0), 1, 4)] = T.call_llvm_pure_intrin(
                 T.llvm_lookup_intrinsic_id(f"llvm.aarch64.neon.{instr}"),
-                T.uint32(3),
                 vec_c,
                 vec_a,
                 vec_b,
@@ -311,7 +305,6 @@ def 
get_sme_transpose_interleave_2svlx2svl_fp32_intrin(cols, rows):
                                 T.call_llvm_intrin(
                                     "void",
                                     "llvm.aarch64.sme.ld1w.horiz",
-                                    T.uint32(4),
                                     predicate,
                                     input_ptr,
                                     sub_tile,
@@ -335,7 +328,6 @@ def 
get_sme_transpose_interleave_2svlx2svl_fp32_intrin(cols, rows):
                                 T.call_llvm_intrin(
                                     "void",
                                     "llvm.aarch64.sme.st1w.vert",
-                                    T.uint32(4),
                                     predicate,
                                     output_ptr,
                                     sub_tile,
@@ -438,7 +430,6 @@ def get_sme_transpose_interleave_block2_2svl_fp16_intrin():
                                 T.call_llvm_intrin(
                                     "void",
                                     "llvm.aarch64.sme.ld1h.horiz",
-                                    T.uint32(4),
                                     ptrue_fp16,
                                     input_ptr,
                                     sub_tile_idx,
@@ -450,7 +441,6 @@ def get_sme_transpose_interleave_block2_2svl_fp16_intrin():
                                 T.call_llvm_intrin(
                                     "void",
                                     "llvm.aarch64.sme.ld1h.horiz",
-                                    T.uint32(4),
                                     ptrue_fp16,
                                     input_ptr,
                                     sub_tile_idx,
@@ -467,7 +457,6 @@ def get_sme_transpose_interleave_block2_2svl_fp16_intrin():
                                 T.call_llvm_intrin(
                                     "void",
                                     "llvm.aarch64.sme.st1w.vert",
-                                    T.uint32(4),
                                     ptrue_fp32,
                                     output_ptr,
                                     sub_tile_idx,
@@ -479,7 +468,6 @@ def get_sme_transpose_interleave_block2_2svl_fp16_intrin():
                                 T.call_llvm_intrin(
                                     "void",
                                     "llvm.aarch64.sme.st1w.vert",
-                                    T.uint32(4),
                                     ptrue_fp32,
                                     output_ptr,
                                     sub_tile_idx + 2,
@@ -692,7 +680,6 @@ def get_sme_gemm_interleaved_mopa_2svlx2svl_intrin(M, K, 
in_dtype):
                                 T.call_llvm_intrin(
                                     "void",
                                     fmopa_intrin,
-                                    T.uint32(5),
                                     sub_tile,
                                     input_1[1],
                                     input_2[1],
@@ -713,7 +700,6 @@ def get_sme_gemm_interleaved_mopa_2svlx2svl_intrin(M, K, 
in_dtype):
                                 T.call_llvm_intrin(
                                     "void",
                                     "llvm.aarch64.sme.st1w.horiz",
-                                    T.uint32(4),
                                     _create_active_lane_mask(
                                         C, (vert_offset + slice_idx, 
horiz_offset), M
                                     ),
@@ -752,9 +738,7 @@ def get_sme_init_intrin():
             T.reads()
             T.writes(C[0:SVF2, 0:SVF2])
             clear_all_tiles = T.int32(255)
-            T.evaluate(
-                T.call_llvm_intrin("void", "llvm.aarch64.sme.zero", 
T.uint32(1), clear_all_tiles)
-            )
+            T.evaluate(T.call_llvm_intrin("void", "llvm.aarch64.sme.zero", 
clear_all_tiles))
 
     return desc, impl
 
diff --git a/python/tvm/tir/tensor_intrin/hexagon.py 
b/python/tvm/tir/tensor_intrin/hexagon.py
index 22dd9a977c..631d6b3532 100644
--- a/python/tvm/tir/tensor_intrin/hexagon.py
+++ b/python/tvm/tir/tensor_intrin/hexagon.py
@@ -107,7 +107,6 @@ def generate_dot_product_32x4_u8u8i32(mem_scope="global"):
 
             C[T.ramp(T.int32(0), 1, 32)] = T.call_llvm_pure_intrin(
                 T.llvm_lookup_intrinsic_id("llvm.hexagon.V6.vrmpyub.acc.128B"),
-                T.uint32(3),
                 C[T.ramp(T.int32(0), 1, 32)],
                 B_i32x32,
                 A_i32,
@@ -149,7 +148,6 @@ def generate_dot_product_32x4_u8i8i32(mem_scope="global"):
 
             C[T.ramp(T.int32(0), 1, 32)] = T.call_llvm_pure_intrin(
                 
T.llvm_lookup_intrinsic_id("llvm.hexagon.V6.vrmpybusv.acc.128B"),
-                T.uint32(3),
                 C[T.ramp(T.int32(0), 1, 32)],
                 T.broadcast(A_i32, 32),
                 B_i32x32,
@@ -191,7 +189,6 @@ def generate_dot_product_32x2_i16i16i32(mem_scope="global"):
 
             C[T.ramp(T.int32(0), 1, 32)] = T.call_llvm_pure_intrin(
                 
T.llvm_lookup_intrinsic_id("llvm.hexagon.V6.vdmpyhvsat.acc.128B"),
-                T.uint32(3),
                 C[T.ramp(T.int32(0), 1, 32)],
                 T.Broadcast(A_i32, 32),
                 B_i32x32,
diff --git a/python/tvm/tir/tensor_intrin/rocm.py 
b/python/tvm/tir/tensor_intrin/rocm.py
index 12dabfb2cd..bfac2ca1d2 100644
--- a/python/tvm/tir/tensor_intrin/rocm.py
+++ b/python/tvm/tir/tensor_intrin/rocm.py
@@ -39,7 +39,6 @@ def sdot4(
 
         C[0] += T.call_llvm_pure_intrin(
             T.llvm_lookup_intrinsic_id("llvm.amdgcn.sdot4"),
-            T.uint32(4),
             T.reinterpret(A.vload([0], "int8x4"), dtype="int32"),
             T.reinterpret(B.vload([0], "int8x4"), dtype="int32"),
             T.int32(0),
@@ -337,7 +336,6 @@ def get_mfma_intrin(k_dim, in_dtype="float32", 
out_dtype="float32", b_transposed
             T.launch_thread(tx, WARP_SIZE)
             C[tx, 0:local_size_out] = T.call_llvm_pure_intrin(
                 T.llvm_lookup_intrinsic_id(mfma_intrin),
-                T.uint32(6),
                 A[tx, 0:local_size],
                 B[tx, 0:local_size],
                 C[tx, 0:local_size_out],
@@ -365,7 +363,6 @@ def get_mfma_intrin(k_dim, in_dtype="float32", 
out_dtype="float32", b_transposed
 
             C[tx, 0:local_size_out] = T.call_llvm_pure_intrin(
                 T.llvm_lookup_intrinsic_id(mfma_intrin),
-                T.uint32(6),
                 T.call_intrin("int32", "tir.reinterpret", A[tx, 0:local_size]),
                 T.call_intrin("int32", "tir.reinterpret", A[tx, 0:local_size]),
                 C[tx, 0:local_size_out],
diff --git a/python/tvm/tir/tensor_intrin/x86.py 
b/python/tvm/tir/tensor_intrin/x86.py
index b4b6f07cd9..8f9518ce45 100644
--- a/python/tvm/tir/tensor_intrin/x86.py
+++ b/python/tvm/tir/tensor_intrin/x86.py
@@ -59,7 +59,6 @@ def dot_product_16x4_u8i8i32_vnni(
 
         C[T.ramp(T.int32(0), 1, 16)] = T.call_llvm_pure_intrin(
             T.llvm_lookup_intrinsic_id("llvm.x86.avx512.vpdpbusd.512"),
-            T.uint32(3),
             C_i32x16,
             T.broadcast(A_i32, 16),
             B_i32x16,
@@ -86,7 +85,6 @@ def dot_product_16x4_u8i8i32_avx512(
 
         Red = T.call_llvm_pure_intrin(
             T.llvm_lookup_intrinsic_id("llvm.x86.avx512.pmaddubs.w.512"),
-            T.uint32(2),
             A_u8x64,
             B_i8x64,
             dtype="int16x32",
@@ -94,7 +92,6 @@ def dot_product_16x4_u8i8i32_avx512(
 
         C[T.ramp(T.int32(0), 1, 16)] += T.call_llvm_pure_intrin(
             T.llvm_lookup_intrinsic_id("llvm.x86.avx512.pmaddw.d.512"),
-            T.uint32(2),
             Red,
             T.int16x32(1),
             dtype="int32x16",
diff --git a/src/target/llvm/codegen_arm.cc b/src/target/llvm/codegen_arm.cc
index 4abe667107..3adcfc82bb 100644
--- a/src/target/llvm/codegen_arm.cc
+++ b/src/target/llvm/codegen_arm.cc
@@ -67,7 +67,7 @@ llvm::Value* CodeGenARM::CreateIntrinsic(const CallNode* op) {
 
 PrimExpr CodeGenARM::ARMPopcount(const CallNode* call) {
   using namespace tir;
-  const PrimExpr& e = call->args[2];
+  const PrimExpr& e = call->args[1];
   llvm::Intrinsic::ID ctpop_id = llvm::Intrinsic::ctpop;
   llvm::Intrinsic::ID vpaddlu_id = llvm::Intrinsic::arm_neon_vpaddlu;
 
@@ -77,7 +77,6 @@ PrimExpr CodeGenARM::ARMPopcount(const CallNode* call) {
       (total_size != 128 && total_size != 64)) {
     Array<PrimExpr> vcnt_args;
     vcnt_args.push_back(IntImm(DataType::UInt(32), ctpop_id));
-    vcnt_args.push_back(IntImm(DataType::UInt(32), 1));
     vcnt_args.push_back(e);
     return tir::Call(call->dtype, builtin_call_llvm_pure_intrin_, vcnt_args);
   }
@@ -101,14 +100,12 @@ PrimExpr CodeGenARM::ARMPopcount(const CallNode* call) {
   ICHECK(c0 != nullptr);
   Array<PrimExpr> vcnt8_args;
   vcnt8_args.push_back(IntImm(DataType::UInt(32), ctpop_id));
-  vcnt8_args.push_back(IntImm(DataType::UInt(32), 1));
   vcnt8_args.push_back(input8);
   PrimExpr vcnt8 = tir::Call(uint8_type, builtin_call_llvm_pure_intrin_, 
vcnt8_args);
 
   // Accumulation 8->16bit
   Array<PrimExpr> vcnt16_args;
   vcnt16_args.push_back(IntImm(DataType::UInt(32), vpaddlu_id));
-  vcnt16_args.push_back(IntImm(DataType::UInt(32), 1));
   vcnt16_args.push_back(vcnt8);
   PrimExpr vcnt16 = tir::Call(uint16_type, builtin_call_llvm_pure_intrin_, 
vcnt16_args);
   if (call->dtype.bits() == 16) {
@@ -118,7 +115,6 @@ PrimExpr CodeGenARM::ARMPopcount(const CallNode* call) {
   // Accumulation 16->32bit
   Array<PrimExpr> vcnt32_args;
   vcnt32_args.push_back(IntImm(DataType::UInt(32), vpaddlu_id));
-  vcnt32_args.push_back(IntImm(DataType::UInt(32), 1));
   vcnt32_args.push_back(vcnt16);
   PrimExpr vcnt32 = tir::Call(uint32_type, builtin_call_llvm_pure_intrin_, 
vcnt32_args);
   if (call->dtype.bits() == 32) {
@@ -128,7 +124,6 @@ PrimExpr CodeGenARM::ARMPopcount(const CallNode* call) {
   // Accumulation 32->64bit
   Array<PrimExpr> vcnt64_args;
   vcnt64_args.push_back(IntImm(DataType::UInt(32), vpaddlu_id));
-  vcnt64_args.push_back(IntImm(DataType::UInt(32), 1));
   vcnt64_args.push_back(vcnt32);
   return tir::Call(call->dtype, builtin_call_llvm_pure_intrin_, vcnt64_args);
 }
diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc
index 4fe120e945..5b2cb5cc95 100644
--- a/src/target/llvm/codegen_llvm.cc
+++ b/src/target/llvm/codegen_llvm.cc
@@ -1359,34 +1359,18 @@ void CodeGenLLVM::EmitFloat16ConversionBuiltins(bool 
use_float16_abi) {
 
 llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) {
   if (op->op.same_as(builtin_call_llvm_intrin_) || 
op->op.same_as(builtin_call_llvm_pure_intrin_)) {
-    ICHECK_GE(op->args.size(), 2U);
+    ICHECK_GE(op->args.size(), 1U);
     llvm::Intrinsic::ID id = 
static_cast<llvm::Intrinsic::ID>(Downcast<IntImm>(op->args[0])->value);
-    int64_t num_signature = Downcast<IntImm>(op->args[1])->value;
     std::vector<llvm::Value*> arg_value;
     std::vector<llvm::Type*> arg_type;
-    for (size_t i = 2; i < op->args.size(); ++i) {
+    for (size_t i = 1; i < op->args.size(); ++i) {
       arg_value.push_back(MakeValue(op->args[i]));
-      if (i - 2 < static_cast<size_t>(num_signature)) {
-        arg_type.push_back(arg_value.back()->getType());
-      }
+      arg_type.push_back(arg_value.back()->getType());
     }
-    // LLVM's prefetch intrinsic returns "void", while TVM's prefetch
-    // returns int32. This causes problems because prefetch is one of
-    // those intrinsics that is generated automatically via the
-    // tvm.intrin.rule mechanism. Any other intrinsic with a type
-    // mismatch will have to be treated specially here.
-    // TODO(kparzysz-quic): fix this once TVM prefetch uses the same
-    // type as LLVM.
-    llvm::Type* return_type =
-        (id != llvm::Intrinsic::prefetch) ? GetLLVMType(GetRef<PrimExpr>(op)) 
: t_void_;
+    llvm::Type* return_type = GetLLVMType(GetRef<PrimExpr>(op));
     llvm::Function* f = GetIntrinsicDecl(id, return_type, arg_type);
     ICHECK(f) << "Cannot find intrinsic declaration, possible type mismatch: "
-#if TVM_LLVM_VERSION >= 130
-              << llvm::Intrinsic::getBaseName(id).str();
-#else
-              << llvm::Intrinsic::getName(id, {});
-#endif
-
+              << llvmGetIntrinName(id);
     // In earlier versions of LLVM's, the prefetch intrinsic is not
     // overloaded, and always takes the first argument as i8*.  If
     // this is the case, this argument should insert a cast to i8*.
@@ -1399,7 +1383,6 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* 
op) {
             builder_->CreatePointerCast(arg_value[0], 
llvmGetPointerTo(t_char_, addrspace));
       }
     }
-
     return builder_->CreateCall(f, arg_value);
   } else if (op->op.same_as(builtin::bitwise_and())) {
     return builder_->CreateAnd(MakeValue(op->args[0]), MakeValue(op->args[1]));
diff --git a/src/target/llvm/intrin_rule_llvm.cc 
b/src/target/llvm/intrin_rule_llvm.cc
index 15cc445090..17de699e00 100644
--- a/src/target/llvm/intrin_rule_llvm.cc
+++ b/src/target/llvm/intrin_rule_llvm.cc
@@ -266,7 +266,6 @@ 
TVM_REGISTER_OP("tir.clz").set_attr<FLegalize>("llvm.FLegalize", [](const PrimEx
   ICHECK_EQ(call->args.size(), 1);
   Array<PrimExpr> cargs;
   cargs.push_back(IntImm(DataType::UInt(32), ::llvm::Intrinsic::ctlz));
-  cargs.push_back(IntImm(DataType::UInt(32), 2));
   cargs.push_back(call->args[0]);
   cargs.push_back(IntImm(DataType::Int(1), 1));  // is_zero_undef
   // LLVM requires that the return type must match the first argument type
diff --git a/src/target/llvm/intrin_rule_llvm.h 
b/src/target/llvm/intrin_rule_llvm.h
index 4b64e92127..aa4f68d0b0 100644
--- a/src/target/llvm/intrin_rule_llvm.h
+++ b/src/target/llvm/intrin_rule_llvm.h
@@ -26,11 +26,14 @@
 
 #ifdef TVM_LLVM_VERSION
 
+#include <llvm/IR/Intrinsics.h>
 #include <tvm/ffi/function.h>
 #include <tvm/target/codegen.h>
 #include <tvm/tir/builtin.h>
 #include <tvm/tir/expr.h>
 
+#include "llvm_instance.h"
+
 namespace tvm {
 namespace codegen {
 // num_signature means number of arguments used to query signature
@@ -41,7 +44,9 @@ inline PrimExpr DispatchLLVMPureIntrin(const PrimExpr& e) {
   Array<PrimExpr> cargs;
   // intrin id.
   cargs.push_back(IntImm(DataType::UInt(32), id));
-  cargs.push_back(IntImm(DataType::UInt(32), num_signature));
+  ICHECK_EQ(call->args.size(), num_signature)
+      << "llvm.call_llvm_intrin" << llvmGetIntrinName(id) << "expects " << 
num_signature
+      << " arguments, but got " << call->args.size();
 
   for (PrimExpr arg : call->args) {
     cargs.push_back(arg);
@@ -56,7 +61,9 @@ inline PrimExpr DispatchLLVMIntrin(const PrimExpr& e) {
   Array<PrimExpr> cargs;
   // intrin id.
   cargs.push_back(IntImm(DataType::UInt(32), id));
-  cargs.push_back(IntImm(DataType::UInt(32), num_signature));
+  ICHECK_EQ(call->args.size(), num_signature)
+      << "llvm.call_llvm_intrin" << llvmGetIntrinName(id) << "expects " << 
num_signature
+      << " arguments, but got " << call->args.size();
   for (PrimExpr arg : call->args) {
     cargs.push_back(arg);
   }
diff --git a/src/target/llvm/llvm_instance.h b/src/target/llvm/llvm_instance.h
index f2468a8ef9..a68637cc84 100644
--- a/src/target/llvm/llvm_instance.h
+++ b/src/target/llvm/llvm_instance.h
@@ -51,6 +51,21 @@
 #define llvmGetPointerTo(arg, offset) (arg->getPointerTo(offset))
 #endif
 
+#if TVM_LLVM_VERSION >= 130
+#define llvmGetIntrinName(id) \
+  
std::string(llvm::Intrinsic::getBaseName(static_cast<llvm::Intrinsic::ID>(id)))
+#elif TVM_LLVM_VERSION >= 40
+// This is the version of Intrinsic::getName that works for overloaded
+// intrinsics. Helpfully, if we provide no types to this function, it
+// will give us the overloaded name without the types appended. This
+// should be enough information for most uses.
+#define llvmGetIntrinName(id) \
+  std::string(llvm::Intrinsic::getName(static_cast<llvm::Intrinsic::ID>(id), 
{}))
+#else
+// Nothing to do, just return the intrinsic id number
+#define llvmGetIntrinName(id) std::to_string(id)
+#endif
+
 namespace llvm {
 class LLVMContext;
 class MemoryBuffer;
diff --git a/src/target/llvm/llvm_module.cc b/src/target/llvm/llvm_module.cc
index 2daf941edf..dd9622999b 100644
--- a/src/target/llvm/llvm_module.cc
+++ b/src/target/llvm/llvm_module.cc
@@ -670,20 +670,7 @@ static void LLVMReflectionRegister() {
 #endif
            })
       .def("target.llvm_get_intrinsic_name",
-           [](int64_t id) -> String {
-#if TVM_LLVM_VERSION >= 130
-             return 
std::string(llvm::Intrinsic::getBaseName(static_cast<llvm::Intrinsic::ID>(id)));
-#elif TVM_LLVM_VERSION >= 40
-  // This is the version of Intrinsic::getName that works for overloaded
-  // intrinsics. Helpfully, if we provide no types to this function, it
-  // will give us the overloaded name without the types appended. This
-  // should be enough information for most uses.
-  return 
std::string(llvm::Intrinsic::getName(static_cast<llvm::Intrinsic::ID>(id), {}));
-#else
-  // Nothing to do, just return the intrinsic id number
-  return std::to_string(id);
-#endif
-           })
+           [](int64_t id) -> String { return llvmGetIntrinName(id); })
       .def("target.llvm_get_system_x86_vendor",
            []() -> String {
 #if TVM_LLVM_VERSION >= 120
diff --git a/src/tir/transforms/vectorize_loop.cc 
b/src/tir/transforms/vectorize_loop.cc
index 705739d98b..8e35092450 100644
--- a/src/tir/transforms/vectorize_loop.cc
+++ b/src/tir/transforms/vectorize_loop.cc
@@ -558,20 +558,16 @@ class Vectorizer : public StmtMutator, public 
ExprFunctor<PrimExpr(const PrimExp
       Array<PrimExpr> new_args;
       if (op->op.same_as(builtin::call_llvm_pure_intrin())) {
         // op->args[1], will give us total number of arguments to intrinsic
-        int num_signature = Downcast<IntImm>(op->args[1])->value;
         Array<PrimExpr> op_expr_args;
-        for (int i = 0; i < num_signature; i++) {
+        for (size_t i = 1; i < op->args.size(); ++i) {
           // Collect all intrinsic arguments
-          op_expr_args.push_back(op->args[i + 2]);
+          op_expr_args.push_back(op->args[i]);
         }
         // Generate RAMP nodes for intrinsic arguments
         Array<PrimExpr> updated_args = MutateArray(op_expr_args, &lane);
-        // Collect Intrinsic ID and no. of argument
-        for (int i = 0; i < 2; i++) {
-          new_args.push_back(op->args[i]);
-        }
+        new_args.push_back(op->args[0]);
         // Collect updated intrinsic arguments
-        for (int i = 0; i < num_signature; i++) {
+        for (size_t i = 0; i < updated_args.size(); ++i) {
           new_args.push_back(updated_args[i]);
         }
       } else {
diff --git a/tests/python/codegen/test_target_codegen_llvm.py 
b/tests/python/codegen/test_target_codegen_llvm.py
index 2105a2a2c3..15c030aeac 100644
--- a/tests/python/codegen/test_target_codegen_llvm.py
+++ b/tests/python/codegen/test_target_codegen_llvm.py
@@ -35,7 +35,7 @@ def test_llvm_intrin():
     n = tvm.runtime.convert(4)
     A = ib.pointer("float32", name="A")
     args = [tvm.tir.call_intrin("handle", "tir.address_of", A[0]), 0, 3, 1]
-    ib.emit(tvm.tir.Evaluate(tvm.tir.Call("int32", "tir.prefetch", args)))
+    ib.emit(tvm.tir.Evaluate(tvm.tir.Call("void", "tir.prefetch", args)))
     body = ib.get()
 
     mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([A], 
body).with_attr("global_symbol", "prefetch"))
@@ -47,7 +47,7 @@ def test_llvm_void_intrin():
     ib = tvm.tir.ir_builder.create()
     A = ib.pointer("uint8", name="A")
     # Create an intrinsic that returns void.
-    x = tvm.tir.call_llvm_intrin("", "llvm.va_start", tvm.tir.const(1, 
"uint32"), A.asobject().data)
+    x = tvm.tir.call_llvm_intrin("", "llvm.assume", tvm.tir.const(1, "int1"))
     ib.emit(x)
     body = ib.get()
     mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([A], 
body).with_attr("global_symbol", "main"))
@@ -72,9 +72,7 @@ def test_llvm_overloaded_intrin():
     def use_llvm_intrinsic(A, C):
         ib = tvm.tir.ir_builder.create()
         L = A.vload((0, 0))
-        I = tvm.tir.call_llvm_pure_intrin(
-            "int32", "llvm.ctlz", tvm.tir.const(2, "uint32"), L, 
tvm.tir.const(0, "int1")
-        )
+        I = tvm.tir.call_llvm_pure_intrin("int32", "llvm.ctlz", L, 
tvm.tir.const(0, "int1"))
         S = C.vstore((0, 0), I)
         ib.emit(S)
         return ib.get()
diff --git a/tests/python/contrib/test_hexagon/test_async_dma_pipeline.py 
b/tests/python/contrib/test_hexagon/test_async_dma_pipeline.py
index 9461da2277..965795d29e 100644
--- a/tests/python/contrib/test_hexagon/test_async_dma_pipeline.py
+++ b/tests/python/contrib/test_hexagon/test_async_dma_pipeline.py
@@ -206,7 +206,6 @@ def conv2d_async_non_contig(
                         B_i32x32: T.int32x32 = T.reinterpret(B_i8x128, 
dtype="int32x32")
                         C[0:32] = T.call_llvm_pure_intrin(
                             
T.llvm_lookup_intrinsic_id("llvm.hexagon.V6.vrmpyubv.acc.128B"),
-                            T.uint32(3),
                             C[0:32],
                             B_i32x32,
                             A_i32,
@@ -240,7 +239,6 @@ def conv_approximation(size_a, size_w):
                         c_buffer[vn_index, x] = 0
                 c_buffer[vn_index, T.ramp(0, 1, 32)] = T.call_llvm_intrin(
                     
T.llvm_lookup_intrinsic_id("llvm.hexagon.V6.vrmpyubv.acc.128B"),
-                    T.uint32(3),
                     c_buffer[vn_index, T.ramp(0, 1, 32)],
                     T.reinterpret(a_buffer[vn_index, T.ramp(0, 1, 128)], 
dtype="int32x32"),
                     T.reinterpret(w_buffer[vi_index, T.ramp(0, 1, 128)], 
dtype="int32x32"),
@@ -660,7 +658,6 @@ class ModulePipelined:
                         b_i32x32: T.int32x32 = T.reinterpret(b_i8x128, 
dtype="int32x32")
                         c_buffer[0:32] = T.call_llvm_pure_intrin(
                             
T.llvm_lookup_intrinsic_id("llvm.hexagon.V6.vrmpyubv.acc.128B"),
-                            T.uint32(3),
                             c_buffer[0:32],
                             T.broadcast(a_i32, 32),
                             b_i32x32,
@@ -815,7 +812,6 @@ class ModuleBase:
                             b_i32x32: T.int32x32 = T.reinterpret(b_i8x128, 
dtype="int32x32")
                             c_buffer[0:32] = T.call_llvm_pure_intrin(
                                 
T.llvm_lookup_intrinsic_id("llvm.hexagon.V6.vrmpyubv.acc.128B"),
-                                T.uint32(3),
                                 c_buffer[0:32],
                                 T.broadcast(a_i32, 32),
                                 b_i32x32,
diff --git a/tests/python/contrib/test_hexagon/test_meta_schedule.py 
b/tests/python/contrib/test_hexagon/test_meta_schedule.py
index c2a5109aff..c7f9d2a00f 100644
--- a/tests/python/contrib/test_hexagon/test_meta_schedule.py
+++ b/tests/python/contrib/test_hexagon/test_meta_schedule.py
@@ -294,7 +294,7 @@ class ModuleVRMPYAutoTensorize:
                         b_buffer[0, 0:128], dtype="int32x32"
                     )  # type: ignore
                     c_buffer[0:32] = T.call_llvm_pure_intrin(  # type: ignore
-                        4390, T.uint32(3), c_buffer[0:32], b_i32x32, a_i32, 
dtype="int32x32"
+                        4390, c_buffer[0:32], b_i32x32, a_i32, dtype="int32x32"
                     )
 
 
diff --git a/tests/python/contrib/test_hexagon/test_parallel_hvx.py 
b/tests/python/contrib/test_hexagon/test_parallel_hvx.py
index 6822352568..6e1b7db4d5 100644
--- a/tests/python/contrib/test_hexagon/test_parallel_hvx.py
+++ b/tests/python/contrib/test_hexagon/test_parallel_hvx.py
@@ -85,7 +85,6 @@ def get_vmpy_operator(operations):
                 vn_ind = T.axis.remap("S", [n])
                 c_buffer[vn_ind, T.ramp(0, 1, 128)] = T.call_llvm_intrin(
                     
T.llvm_lookup_intrinsic_id("llvm.hexagon.V6.vmpybusv.128B"),
-                    T.uint32(2),
                     T.reinterpret(a_buffer[vn_ind, T.ramp(0, 1, 128)], 
dtype="int32x32"),
                     T.reinterpret(b_buffer[vn_ind, T.ramp(0, 1, 128)], 
dtype="int32x32"),
                     dtype="int16x128",
@@ -108,7 +107,6 @@ def get_vadd_operator(operations):
                 vn_ind = T.axis.remap("S", [n])
                 c_buffer[vn_ind, T.ramp(0, 1, 128)] = T.call_llvm_intrin(
                     T.llvm_lookup_intrinsic_id("llvm.hexagon.V6.vaddubh.128B"),
-                    T.uint32(2),
                     T.reinterpret(a_buffer[vn_ind, T.ramp(0, 1, 128)], 
dtype="int32x32"),
                     T.reinterpret(b_buffer[vn_ind, T.ramp(0, 1, 128)], 
dtype="int32x32"),
                     dtype="int16x128",
@@ -131,7 +129,6 @@ def get_vrmpy_operator(operations):
                 vn_ind = T.axis.remap("S", [n])
                 c_buffer[vn_ind, T.ramp(0, 1, 32)] = T.call_llvm_intrin(
                     
T.llvm_lookup_intrinsic_id("llvm.hexagon.V6.vrmpyubv.128B"),
-                    T.uint32(2),
                     T.reinterpret(a_buffer[vn_ind, T.ramp(0, 1, 128)], 
dtype="int32x32"),
                     T.reinterpret(b_buffer[vn_ind, T.ramp(0, 1, 128)], 
dtype="int32x32"),
                     dtype="int32x32",
diff --git a/tests/python/contrib/test_hexagon/test_parallel_hvx_load_vtcm.py 
b/tests/python/contrib/test_hexagon/test_parallel_hvx_load_vtcm.py
index 17e31af0a7..a0b94d89cf 100644
--- a/tests/python/contrib/test_hexagon/test_parallel_hvx_load_vtcm.py
+++ b/tests/python/contrib/test_hexagon/test_parallel_hvx_load_vtcm.py
@@ -87,7 +87,6 @@ def vrmpy(operations):
                 vn_ind = T.axis.remap("S", [n])
                 c_buffer[vn_ind, T.ramp(0, 1, 32)] = T.call_llvm_intrin(
                     
T.llvm_lookup_intrinsic_id("llvm.hexagon.V6.vrmpyubv.128B"),
-                    T.uint32(2),
                     T.reinterpret(a_buffer[vn_ind, T.ramp(0, 1, 128)], 
dtype="int32x32"),
                     T.reinterpret(b_buffer[vn_ind, T.ramp(0, 1, 128)], 
dtype="int32x32"),
                     dtype="int32x32",
@@ -124,7 +123,6 @@ def preloaded_vrmpy(operations):
                 vn_ind = T.axis.remap("S", [n])
                 c_buffer[T.ramp(T.cast(vn_ind, "int32") * 32, 1, 32)] = 
T.call_llvm_intrin(
                     
T.llvm_lookup_intrinsic_id("llvm.hexagon.V6.vrmpyubv.128B"),
-                    T.uint32(2),
                     T.reinterpret(
                         a_buffer[T.ramp(T.cast(vn_ind, "int32") * 128, 1, 
128)], dtype="int32x32"
                     ),
@@ -168,7 +166,6 @@ def preallocated_vrmpy(operations):
                 vn_ind = T.axis.remap("S", [n])
                 c_global_vtcm[T.ramp(T.cast(vn_ind, "int32") * 32, 1, 32)] = 
T.call_llvm_intrin(
                     
T.llvm_lookup_intrinsic_id("llvm.hexagon.V6.vrmpyubv.128B"),
-                    T.uint32(2),
                     T.reinterpret(
                         a_global_vtcm[T.ramp(T.cast(vn_ind, "int32") * 128, 1, 
128)],
                         dtype="int32x32",
@@ -267,7 +264,6 @@ def preallocated_single_dma_vrmpy(operations):
                 vn_ind = T.axis.remap("S", [n])
                 c_global_vtcm[T.ramp(T.cast(vn_ind, "int32") * 32, 1, 32)] = 
T.call_llvm_intrin(
                     
T.llvm_lookup_intrinsic_id("llvm.hexagon.V6.vrmpyubv.128B"),
-                    T.uint32(2),
                     T.reinterpret(
                         a_global_vtcm[T.ramp(T.cast(vn_ind, "int32") * 128, 1, 
128)],
                         dtype="int32x32",
diff --git 
a/tests/python/meta_schedule/test_meta_schedule_postproc_rewrite_tensorize.py 
b/tests/python/meta_schedule/test_meta_schedule_postproc_rewrite_tensorize.py
index 1272b35451..313657108c 100644
--- 
a/tests/python/meta_schedule/test_meta_schedule_postproc_rewrite_tensorize.py
+++ 
b/tests/python/meta_schedule/test_meta_schedule_postproc_rewrite_tensorize.py
@@ -236,7 +236,6 @@ class Conv2dNCHWcVNNIModuleTensorized:
                     C_i32x16 = C.vload([0], dtype="int32x16")
                     C[T.ramp(0, 1, 16)] = T.call_llvm_pure_intrin(
                         
T.llvm_lookup_intrinsic_id("llvm.x86.avx512.vpdpbusd.512"),
-                        T.uint32(3),
                         C_i32x16,
                         T.broadcast(A_i32, 16),
                         B_i32x16,
diff --git a/tests/python/meta_schedule/test_meta_schedule_trace_apply.py 
b/tests/python/meta_schedule/test_meta_schedule_trace_apply.py
index c3a76e101f..637f3093d8 100644
--- a/tests/python/meta_schedule/test_meta_schedule_trace_apply.py
+++ b/tests/python/meta_schedule/test_meta_schedule_trace_apply.py
@@ -1159,7 +1159,7 @@ def get_conv2d_vnni_mod(intrin_id):
                             B_i8x64: T.int8x64 = B[0, 0:64]
                             B_i32x16: T.int32x16 = T.reinterpret(B_i8x64, 
dtype="int32x16")
                             C_i32x16: T.int32x16 = C[0:16]
-                            C[0:16] = 
T.call_llvm_pure_intrin(T.uint32(intrin_id), T.uint32(3), C_i32x16, 
T.broadcast(A_i32, 16), B_i32x16, dtype="int32x16")
+                            C[0:16] = 
T.call_llvm_pure_intrin(T.uint32(intrin_id), C_i32x16, T.broadcast(A_i32, 16), 
B_i32x16, dtype="int32x16")
                     for ax0, ax1, ax2, ax3 in T.grid(1, 1, 1, 7):
                         for ax4_fused in T.vectorized(16):
                             with T.block("T_cast_8"):
diff --git a/tests/python/tir-base/test_tir_ops.py 
b/tests/python/tir-base/test_tir_ops.py
index f2a18aeae5..dfa5cbab80 100644
--- a/tests/python/tir-base/test_tir_ops.py
+++ b/tests/python/tir-base/test_tir_ops.py
@@ -236,9 +236,9 @@ def test_comm_reducer(num_args):
 
 def test_llvm_intrin():
     with pytest.raises(ValueError, match=r"Unknown llvm intrinsic function 
llvm.dummy"):
-        a = tvm.tir.call_llvm_intrin("int32x4", "llvm.dummy", 0)
+        a = tvm.tir.call_llvm_intrin("int32x4", "llvm.dummy")
     with pytest.raises(ValueError, match=r"Unknown llvm intrinsic function 
llvm.dummy"):
-        a = tvm.tir.call_llvm_pure_intrin("int32x4", "llvm.dummy", 0)
+        a = tvm.tir.call_llvm_pure_intrin("int32x4", "llvm.dummy")
 
 
 if __name__ == "__main__":
diff --git 
a/tests/python/tir-transform/test_tir_transform_lower_cross_thread_reduction.py 
b/tests/python/tir-transform/test_tir_transform_lower_cross_thread_reduction.py
index 63700853b3..18e16513f4 100644
--- 
a/tests/python/tir-transform/test_tir_transform_lower_cross_thread_reduction.py
+++ 
b/tests/python/tir-transform/test_tir_transform_lower_cross_thread_reduction.py
@@ -832,7 +832,7 @@ def single_reduction_loop_with_tensorize(
                 B_i8x128 = B[0, 0:128]
                 B_i32x32: T.int32x32 = T.reinterpret(B_i8x128, 
dtype="int32x32")
                 C[0:32] = T.call_llvm_pure_intrin(
-                    4217, T.uint32(3), C[0:32], T.broadcast(A_i32, 32), 
B_i32x32, dtype="int32x32"
+                    4217, C[0:32], T.broadcast(A_i32, 32), B_i32x32, 
dtype="int32x32"
                 )
 
 
diff --git a/tests/python/tir-transform/test_tir_transform_vectorize.py 
b/tests/python/tir-transform/test_tir_transform_vectorize.py
index 13bb1c60cb..5a4d4ea17d 100644
--- a/tests/python/tir-transform/test_tir_transform_vectorize.py
+++ b/tests/python/tir-transform/test_tir_transform_vectorize.py
@@ -781,16 +781,14 @@ def test_vectorize_llvm_pure_intrin(extent, vec_str, 
target):
         @T.prim_func
         def main(A: T.Buffer((25,), "float32"), B: T.Buffer((25,), "float32")):
             for j in T.vectorized(extent):
-                A[j] = T.call_llvm_pure_intrin(
-                    "float32", "llvm.sqrt", tvm.tir.const(1, "uint"), B[j]
-                )
+                A[j] = T.call_llvm_pure_intrin("float32", "llvm.sqrt", B[j])
 
     @I.ir_module
     class After:
         @T.prim_func
         def main(A: T.Buffer((25,), "float32"), B: T.Buffer((25,), "float32")):
             A[T.Ramp(0, 1, extent)] = T.call_llvm_pure_intrin(
-                vec_str, "llvm.sqrt", tvm.tir.const(1, "uint"), B[T.Ramp(0, 1, 
extent)]
+                vec_str, "llvm.sqrt", B[T.Ramp(0, 1, extent)]
             )
 
     with tvm.target.Target(target):
@@ -809,16 +807,14 @@ def test_vectorize_llvm_pure_intrin_fail(extent, vec_str, 
target):
         @T.prim_func
         def main(A: T.Buffer((25,), "int32"), B: T.Buffer((25,), "float32")):
             for j in T.vectorized(extent):
-                A[j] = T.call_llvm_pure_intrin(
-                    "int32", "llvm.lround", tvm.tir.const(1, "uint"), B[j]
-                )
+                A[j] = T.call_llvm_pure_intrin("int32", "llvm.lround", B[j])
 
     @I.ir_module
     class After:
         @T.prim_func
         def main(A: T.Buffer((25,), "int32"), B: T.Buffer((25,), "float32")):
             A[T.Ramp(0, 1, extent)] = T.call_llvm_pure_intrin(
-                vec_str, "llvm.lround", tvm.tir.const(1, "uint"), B[T.Ramp(0, 
1, extent)]
+                vec_str, "llvm.lround", B[T.Ramp(0, 1, extent)]
             )
 
     with pytest.raises(Exception) as e_info:
diff --git a/tests/python/tvmscript/test_tvmscript_printer_tir.py 
b/tests/python/tvmscript/test_tvmscript_printer_tir.py
index 267fae20ca..be8b03357d 100644
--- a/tests/python/tvmscript/test_tvmscript_printer_tir.py
+++ b/tests/python/tvmscript/test_tvmscript_printer_tir.py
@@ -493,10 +493,10 @@ T.Cast("float64", a)
 
 
 def test_llvm_intrin_imm():
-    a = tir.call_llvm_intrin("int32x4", "llvm.donothing", T.uint32(0))
-    _assert_print(a, 'T.call_llvm_intrin("int32x4", "llvm.donothing", 
T.uint32(0))')
-    a = tir.call_llvm_pure_intrin("int32x4", "llvm.donothing", T.uint32(0))
-    _assert_print(a, 'T.call_llvm_pure_intrin("int32x4", "llvm.donothing", 
T.uint32(0))')
+    a = tir.call_llvm_intrin("int32x4", "llvm.donothing")
+    _assert_print(a, 'T.call_llvm_intrin("int32x4", "llvm.donothing")')
+    a = tir.call_llvm_pure_intrin("int32x4", "llvm.donothing")
+    _assert_print(a, 'T.call_llvm_pure_intrin("int32x4", "llvm.donothing")')
 
 
 def test_binary_arith():
@@ -1034,16 +1034,14 @@ def test_vectorize_llvm_pure_intrin():
     def main(a: T.handle, b: T.handle):
         A = T.match_buffer(a, (4,), "float32")
         B = T.match_buffer(b, (4,), "float32")
-        A[T.Ramp(0, 1, 4)] = T.call_llvm_pure_intrin(
-            "float32x4", "llvm.sqrt", 1, B[T.Ramp(0, 1, 4)]
-        )
+        A[T.Ramp(0, 1, 4)] = T.call_llvm_pure_intrin("float32x4", "llvm.sqrt", 
B[T.Ramp(0, 1, 4)])
 
     expected_output = """
 # from tvm.script import tir as T
 
 @T.prim_func
 def main(A: T.Buffer((4,), "float32"), B: T.Buffer((4,), "float32")):
-    A[0:4] = T.call_llvm_pure_intrin("float32x4", "llvm.sqrt", 1, B[0:4])
+    A[0:4] = T.call_llvm_pure_intrin("float32x4", "llvm.sqrt", B[0:4])
     """
     _assert_print(main, expected_output)
 
diff --git a/tests/python/tvmscript/test_tvmscript_roundtrip.py 
b/tests/python/tvmscript/test_tvmscript_roundtrip.py
index 0e1b328844..2be2e2e98d 100644
--- a/tests/python/tvmscript/test_tvmscript_roundtrip.py
+++ b/tests/python/tvmscript/test_tvmscript_roundtrip.py
@@ -375,7 +375,6 @@ def opt_gemm_mod_host():
                                 for x_c in T.serial(0, 32):
                                     C_global[T.ramp((x_c * 32), 1, 32)] = 
T.call_llvm_pure_intrin(
                                         T.uint32(97),
-                                        T.uint32(3),
                                         T.broadcast(
                                             A[
                                                 (
@@ -393,7 +392,6 @@ def opt_gemm_mod_host():
                                     )
                                     C_global[T.ramp((x_c * 32), 1, 32)] = 
T.call_llvm_pure_intrin(
                                         T.uint32(97),
-                                        T.uint32(3),
                                         T.broadcast(
                                             A[
                                                 (
@@ -416,7 +414,6 @@ def opt_gemm_mod_host():
                                     )
                                     C_global[T.ramp((x_c * 32), 1, 32)] = 
T.call_llvm_pure_intrin(
                                         T.uint32(97),
-                                        T.uint32(3),
                                         T.broadcast(
                                             A[
                                                 (
@@ -439,7 +436,6 @@ def opt_gemm_mod_host():
                                     )
                                     C_global[T.ramp((x_c * 32), 1, 32)] = 
T.call_llvm_pure_intrin(
                                         T.uint32(97),
-                                        T.uint32(3),
                                         T.broadcast(
                                             A[
                                                 (
@@ -3216,7 +3212,6 @@ def llvm_intrin_call():
                 )
                 B[vi] = T.call_llvm_pure_intrin(
                     T.llvm_lookup_intrinsic_id("llvm.ctpop.i8"),
-                    T.uint32(1),
                     A[vi],
                     dtype="uint8",
                 )


Reply via email to