hmz0412 opened a new issue, #17572:
URL: https://github.com/apache/tvm/issues/17572

   I want to create my own prefill TIR function, but when I built this module 
to test it, the error message below appeared.
   I don't know what causes this error.
   
   Actual situation
   `Traceback (most recent call last):
     File 
"/home/octal/mlc-llm/3rdparty/tvm/python/tvm/relax/frontend/nn/llm/test_equal.py",
 line 554, in <module>
       lib_cpu = tvm.build(IR_cpu, target="llvm")
     File "/home/octal/mlc-llm/3rdparty/tvm/python/tvm/driver/build_module.py", 
line 297, in build
       rt_mod_host = _driver_ffi.tir_to_runtime(annotated_mods, target_host)
     File 
"/home/octal/mlc-llm/3rdparty/tvm/python/tvm/_ffi/_ctypes/packed_func.py", line 
245, in __call__
       raise_last_ffi_error()
     File "/home/octal/mlc-llm/3rdparty/tvm/python/tvm/_ffi/base.py", line 481, 
in raise_last_ffi_error
       raise py_err
     File "/home/octal/mlc-llm/3rdparty/tvm/src/driver/driver_api.cc", line 
531, in operator()
       return TIRToRuntime(inputs_arg, host_target);
     File "/home/octal/mlc-llm/3rdparty/tvm/src/driver/driver_api.cc", line 
492, in tvm::TIRToRuntime(tvm::runtime::Map<tvm::Target, tvm::IRModule, void, 
void> const&, tvm::Target const&)
       auto pair = SplitMixedModule(ir_module, target, target_host);
     File "/home/octal/mlc-llm/3rdparty/tvm/src/driver/driver_api.cc", line 
418, in tvm::SplitMixedModule(tvm::IRModule, tvm::Target const&, tvm::Target 
const&)
       mod_mixed = ApplyPasses(mod_mixed, MixedModulePassManager(mod_mixed, 
target));
     File "/home/octal/mlc-llm/3rdparty/tvm/src/driver/driver_api.cc", line 
291, in tvm::ApplyPasses(tvm::IRModule, tvm::transform::Sequential)
       mod = seq(std::move(mod));
     File 
"/home/octal/mlc-llm/3rdparty/tvm/src/tir/transforms/make_packed_api.cc", line 
458, in operator()
       func = MakePackedAPI(std::move(func));
     File 
"/home/octal/mlc-llm/3rdparty/tvm/src/tir/transforms/make_packed_api.cc", line 
420, in tvm::tir::MakePackedAPI(tvm::tir::PrimFunc)
       Array<Var> undefined = UndefinedVars(func_ptr->body, func_ptr->params);
     File 
"/home/octal/mlc-llm/3rdparty/tvm/src/tir/analysis/var_use_def_analysis.cc", 
line 186, in tvm::tir::UndefinedVars(tvm::tir::Stmt const&, 
tvm::runtime::Array<tvm::tir::Var, void> const&)
       m(stmt);
     File "/home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/stmt_functor.cc", line 
119, in tvm::tir::StmtVisitor::VisitStmt_(tvm::tir::SeqStmtNode const*)
       VisitArray(op->seq, [this](const Stmt& s) { this->VisitStmt(s); });
     File "/home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/functor_common.h", line 
35, in VisitArray<tvm::tir::Stmt, tvm::tir::StmtVisitor::VisitStmt_(const 
tvm::tir::SeqStmtNode*)::<lambda(const tvm::tir::Stmt&)> >
       fvisit(arr[i]);
     File "/home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/stmt_functor.cc", line 
119, in operator()
       VisitArray(op->seq, [this](const Stmt& s) { this->VisitStmt(s); });
     File "/home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/stmt_functor.cc", line 
119, in tvm::tir::StmtVisitor::VisitStmt_(tvm::tir::SeqStmtNode const*)
       VisitArray(op->seq, [this](const Stmt& s) { this->VisitStmt(s); });
     File "/home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/functor_common.h", line 
35, in VisitArray<tvm::tir::Stmt, tvm::tir::StmtVisitor::VisitStmt_(const 
tvm::tir::SeqStmtNode*)::<lambda(const tvm::tir::Stmt&)> >
       fvisit(arr[i]);
     File "/home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/stmt_functor.cc", line 
119, in operator()
       VisitArray(op->seq, [this](const Stmt& s) { this->VisitStmt(s); });
     File "/home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/stmt_functor.cc", line 
119, in tvm::tir::StmtVisitor::VisitStmt_(tvm::tir::SeqStmtNode const*)
       VisitArray(op->seq, [this](const Stmt& s) { this->VisitStmt(s); });
     File "/home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/functor_common.h", line 
35, in VisitArray<tvm::tir::Stmt, tvm::tir::StmtVisitor::VisitStmt_(const 
tvm::tir::SeqStmtNode*)::<lambda(const tvm::tir::Stmt&)> >
       fvisit(arr[i]);
     File "/home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/stmt_functor.cc", line 
119, in operator()
       VisitArray(op->seq, [this](const Stmt& s) { this->VisitStmt(s); });
     File "/home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/stmt_functor.cc", line 
119, in tvm::tir::StmtVisitor::VisitStmt_(tvm::tir::SeqStmtNode const*)
       VisitArray(op->seq, [this](const Stmt& s) { this->VisitStmt(s); });
     File "/home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/functor_common.h", line 
35, in VisitArray<tvm::tir::Stmt, tvm::tir::StmtVisitor::VisitStmt_(const 
tvm::tir::SeqStmtNode*)::<lambda(const tvm::tir::Stmt&)> >
       fvisit(arr[i]);
     File "/home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/stmt_functor.cc", line 
119, in operator()
       VisitArray(op->seq, [this](const Stmt& s) { this->VisitStmt(s); });
     File "/home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/stmt_functor.cc", line 
119, in tvm::tir::StmtVisitor::VisitStmt_(tvm::tir::SeqStmtNode const*)
       VisitArray(op->seq, [this](const Stmt& s) { this->VisitStmt(s); });
     File "/home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/functor_common.h", line 
35, in VisitArray<tvm::tir::Stmt, tvm::tir::StmtVisitor::VisitStmt_(const 
tvm::tir::SeqStmtNode*)::<lambda(const tvm::tir::Stmt&)> >
       fvisit(arr[i]);
     File "/home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/stmt_functor.cc", line 
119, in operator()
       VisitArray(op->seq, [this](const Stmt& s) { this->VisitStmt(s); });
     File "/home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/stmt_functor.cc", line 
119, in tvm::tir::StmtVisitor::VisitStmt_(tvm::tir::SeqStmtNode const*)
       VisitArray(op->seq, [this](const Stmt& s) { this->VisitStmt(s); });
     File "/home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/functor_common.h", line 
35, in VisitArray<tvm::tir::Stmt, tvm::tir::StmtVisitor::VisitStmt_(const 
tvm::tir::SeqStmtNode*)::<lambda(const tvm::tir::Stmt&)> >
       fvisit(arr[i]);
     File "/home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/stmt_functor.cc", line 
119, in operator()
       VisitArray(op->seq, [this](const Stmt& s) { this->VisitStmt(s); });
     File "/home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/stmt_functor.cc", line 
119, in tvm::tir::StmtVisitor::VisitStmt_(tvm::tir::SeqStmtNode const*)
       VisitArray(op->seq, [this](const Stmt& s) { this->VisitStmt(s); });
     File "/home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/functor_common.h", line 
35, in VisitArray<tvm::tir::Stmt, tvm::tir::StmtVisitor::VisitStmt_(const 
tvm::tir::SeqStmtNode*)::<lambda(const tvm::tir::Stmt&)> >
       fvisit(arr[i]);
     File "/home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/stmt_functor.cc", line 
119, in operator()
       VisitArray(op->seq, [this](const Stmt& s) { this->VisitStmt(s); });
     File "/home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/stmt_functor.cc", line 
119, in tvm::tir::StmtVisitor::VisitStmt_(tvm::tir::SeqStmtNode const*)
       VisitArray(op->seq, [this](const Stmt& s) { this->VisitStmt(s); });
     File "/home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/functor_common.h", line 
35, in VisitArray<tvm::tir::Stmt, tvm::tir::StmtVisitor::VisitStmt_(const 
tvm::tir::SeqStmtNode*)::<lambda(const tvm::tir::Stmt&)> >
       fvisit(arr[i]);
     File "/home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/stmt_functor.cc", line 
119, in operator()
       VisitArray(op->seq, [this](const Stmt& s) { this->VisitStmt(s); });
     File "/home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/stmt_functor.cc", line 
119, in tvm::tir::StmtVisitor::VisitStmt_(tvm::tir::SeqStmtNode const*)
       VisitArray(op->seq, [this](const Stmt& s) { this->VisitStmt(s); });
     File "/home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/functor_common.h", line 
35, in VisitArray<tvm::tir::Stmt, tvm::tir::StmtVisitor::VisitStmt_(const 
tvm::tir::SeqStmtNode*)::<lambda(const tvm::tir::Stmt&)> >
       fvisit(arr[i]);
     File "/home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/stmt_functor.cc", line 
119, in operator()
       VisitArray(op->seq, [this](const Stmt& s) { this->VisitStmt(s); });
     File "/home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/stmt_functor.cc", line 
119, in tvm::tir::StmtVisitor::VisitStmt_(tvm::tir::SeqStmtNode const*)
       VisitArray(op->seq, [this](const Stmt& s) { this->VisitStmt(s); });
     File "/home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/functor_common.h", line 
35, in VisitArray<tvm::tir::Stmt, tvm::tir::StmtVisitor::VisitStmt_(const 
tvm::tir::SeqStmtNode*)::<lambda(const tvm::tir::Stmt&)> >
       fvisit(arr[i]);
     File "/home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/stmt_functor.cc", line 
119, in operator()
       VisitArray(op->seq, [this](const Stmt& s) { this->VisitStmt(s); });
     File "/home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/stmt_functor.cc", line 
58, in tvm::tir::StmtVisitor::VisitStmt_(tvm::tir::AllocateNode const*)
       this->VisitStmt(op->body);
     File "/home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/stmt_functor.cc", line 
58, in tvm::tir::StmtVisitor::VisitStmt_(tvm::tir::AllocateNode const*)
       this->VisitStmt(op->body);
     File "/home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/stmt_functor.cc", line 
58, in tvm::tir::StmtVisitor::VisitStmt_(tvm::tir::AllocateNode const*)
       this->VisitStmt(op->body);
     File "/home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/stmt_functor.cc", line 
58, in tvm::tir::StmtVisitor::VisitStmt_(tvm::tir::AllocateNode const*)
       this->VisitStmt(op->body);
     File "/home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/stmt_functor.cc", line 
58, in tvm::tir::StmtVisitor::VisitStmt_(tvm::tir::AllocateNode const*)
       this->VisitStmt(op->body);
     File "/home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/stmt_functor.cc", line 
58, in tvm::tir::StmtVisitor::VisitStmt_(tvm::tir::AllocateNode const*)
       this->VisitStmt(op->body);
     File "/home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/stmt_functor.cc", line 
58, in tvm::tir::StmtVisitor::VisitStmt_(tvm::tir::AllocateNode const*)
       this->VisitStmt(op->body);
     File 
"/home/octal/mlc-llm/3rdparty/tvm/src/tir/analysis/var_use_def_analysis.cc", 
line 61, in tvm::tir::VarUseDefAnalyzer::VisitStmt_(tvm::tir::ForNode const*)
       this->HandleDef(op->loop_var);
     File 
"/home/octal/mlc-llm/3rdparty/tvm/src/tir/analysis/var_use_def_analysis.cc", 
line 136, in tvm::tir::VarUseDefAnalyzer::HandleDef(tvm::tir::Var const&)
       ICHECK(!def_count_.count(v)) << "variable " << v->name_hint
   tvm.error.InternalError: Traceback (most recent call last):
     45: operator()
           at /home/octal/mlc-llm/3rdparty/tvm/src/driver/driver_api.cc:531
     44: tvm::TIRToRuntime(tvm::runtime::Map<tvm::Target, tvm::IRModule, void, 
void> const&, tvm::Target const&)
           at /home/octal/mlc-llm/3rdparty/tvm/src/driver/driver_api.cc:492
     43: tvm::SplitMixedModule(tvm::IRModule, tvm::Target const&, tvm::Target 
const&)
           at /home/octal/mlc-llm/3rdparty/tvm/src/driver/driver_api.cc:418
     42: tvm::ApplyPasses(tvm::IRModule, tvm::transform::Sequential)
           at /home/octal/mlc-llm/3rdparty/tvm/src/driver/driver_api.cc:291
     41: operator()
           at 
/home/octal/mlc-llm/3rdparty/tvm/src/tir/transforms/make_packed_api.cc:458
     40: tvm::tir::MakePackedAPI(tvm::tir::PrimFunc)
           at 
/home/octal/mlc-llm/3rdparty/tvm/src/tir/transforms/make_packed_api.cc:420
     39: tvm::tir::UndefinedVars(tvm::tir::Stmt const&, 
tvm::runtime::Array<tvm::tir::Var, void> const&)
           at 
/home/octal/mlc-llm/3rdparty/tvm/src/tir/analysis/var_use_def_analysis.cc:186
     38: tvm::tir::StmtVisitor::VisitStmt_(tvm::tir::SeqStmtNode const*)
           at /home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/stmt_functor.cc:119
     37: VisitArray<tvm::tir::Stmt, tvm::tir::StmtVisitor::VisitStmt_(const 
tvm::tir::SeqStmtNode*)::<lambda(const tvm::tir::Stmt&)> >
           at /home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/functor_common.h:35
     36: operator()
           at /home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/stmt_functor.cc:119
     35: tvm::tir::StmtVisitor::VisitStmt_(tvm::tir::SeqStmtNode const*)
           at /home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/stmt_functor.cc:119
     34: VisitArray<tvm::tir::Stmt, tvm::tir::StmtVisitor::VisitStmt_(const 
tvm::tir::SeqStmtNode*)::<lambda(const tvm::tir::Stmt&)> >
           at /home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/functor_common.h:35
     33: operator()
           at /home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/stmt_functor.cc:119
     32: tvm::tir::StmtVisitor::VisitStmt_(tvm::tir::SeqStmtNode const*)
           at /home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/stmt_functor.cc:119
     31: VisitArray<tvm::tir::Stmt, tvm::tir::StmtVisitor::VisitStmt_(const 
tvm::tir::SeqStmtNode*)::<lambda(const tvm::tir::Stmt&)> >
           at /home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/functor_common.h:35
     30: operator()
           at /home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/stmt_functor.cc:119
     29: tvm::tir::StmtVisitor::VisitStmt_(tvm::tir::SeqStmtNode const*)
           at /home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/stmt_functor.cc:119
     28: VisitArray<tvm::tir::Stmt, tvm::tir::StmtVisitor::VisitStmt_(const 
tvm::tir::SeqStmtNode*)::<lambda(const tvm::tir::Stmt&)> >
           at /home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/functor_common.h:35
     27: operator()
           at /home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/stmt_functor.cc:119
     26: tvm::tir::StmtVisitor::VisitStmt_(tvm::tir::SeqStmtNode const*)
           at /home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/stmt_functor.cc:119
     25: VisitArray<tvm::tir::Stmt, tvm::tir::StmtVisitor::VisitStmt_(const 
tvm::tir::SeqStmtNode*)::<lambda(const tvm::tir::Stmt&)> >
           at /home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/functor_common.h:35
     24: operator()
           at /home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/stmt_functor.cc:119
     23: tvm::tir::StmtVisitor::VisitStmt_(tvm::tir::SeqStmtNode const*)
           at /home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/stmt_functor.cc:119
     22: VisitArray<tvm::tir::Stmt, tvm::tir::StmtVisitor::VisitStmt_(const 
tvm::tir::SeqStmtNode*)::<lambda(const tvm::tir::Stmt&)> >
           at /home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/functor_common.h:35
     21: operator()
           at /home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/stmt_functor.cc:119
     20: tvm::tir::StmtVisitor::VisitStmt_(tvm::tir::SeqStmtNode const*)
           at /home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/stmt_functor.cc:119
     19: VisitArray<tvm::tir::Stmt, tvm::tir::StmtVisitor::VisitStmt_(const 
tvm::tir::SeqStmtNode*)::<lambda(const tvm::tir::Stmt&)> >
           at /home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/functor_common.h:35
     18: operator()
           at /home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/stmt_functor.cc:119
     17: tvm::tir::StmtVisitor::VisitStmt_(tvm::tir::SeqStmtNode const*)
           at /home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/stmt_functor.cc:119
     16: VisitArray<tvm::tir::Stmt, tvm::tir::StmtVisitor::VisitStmt_(const 
tvm::tir::SeqStmtNode*)::<lambda(const tvm::tir::Stmt&)> >
           at /home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/functor_common.h:35
     15: operator()
           at /home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/stmt_functor.cc:119
     14: tvm::tir::StmtVisitor::VisitStmt_(tvm::tir::SeqStmtNode const*)
           at /home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/stmt_functor.cc:119
     13: VisitArray<tvm::tir::Stmt, tvm::tir::StmtVisitor::VisitStmt_(const 
tvm::tir::SeqStmtNode*)::<lambda(const tvm::tir::Stmt&)> >
           at /home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/functor_common.h:35
     12: operator()
           at /home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/stmt_functor.cc:119
     11: tvm::tir::StmtVisitor::VisitStmt_(tvm::tir::SeqStmtNode const*)
           at /home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/stmt_functor.cc:119
     10: VisitArray<tvm::tir::Stmt, tvm::tir::StmtVisitor::VisitStmt_(const 
tvm::tir::SeqStmtNode*)::<lambda(const tvm::tir::Stmt&)> >
           at /home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/functor_common.h:35
     9: operator()
           at /home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/stmt_functor.cc:119
     8: tvm::tir::StmtVisitor::VisitStmt_(tvm::tir::AllocateNode const*)
           at /home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/stmt_functor.cc:58
     7: tvm::tir::StmtVisitor::VisitStmt_(tvm::tir::AllocateNode const*)
           at /home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/stmt_functor.cc:58
     6: tvm::tir::StmtVisitor::VisitStmt_(tvm::tir::AllocateNode const*)
           at /home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/stmt_functor.cc:58
     5: tvm::tir::StmtVisitor::VisitStmt_(tvm::tir::AllocateNode const*)
           at /home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/stmt_functor.cc:58
     4: tvm::tir::StmtVisitor::VisitStmt_(tvm::tir::AllocateNode const*)
           at /home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/stmt_functor.cc:58
     3: tvm::tir::StmtVisitor::VisitStmt_(tvm::tir::AllocateNode const*)
           at /home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/stmt_functor.cc:58
     2: tvm::tir::StmtVisitor::VisitStmt_(tvm::tir::AllocateNode const*)
           at /home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/stmt_functor.cc:58
     1: tvm::tir::VarUseDefAnalyzer::VisitStmt_(tvm::tir::ForNode const*)
           at 
/home/octal/mlc-llm/3rdparty/tvm/src/tir/analysis/var_use_def_analysis.cc:61
     0: tvm::tir::VarUseDefAnalyzer::HandleDef(tvm::tir::Var const&)
           at 
/home/octal/mlc-llm/3rdparty/tvm/src/tir/analysis/var_use_def_analysis.cc:136
     File 
"/home/octal/mlc-llm/3rdparty/tvm/src/tir/analysis/var_use_def_analysis.cc", 
line 138
   InternalError: Check failed: (!use_count_.count(v)) is false: variable b has 
been used before definition!`
   
   
   
   Reproduce
   `
   import math
   from typing import Any, Dict, Tuple
   
   import tvm
   from tvm import relax as rx
   from tvm import tir
   from tvm.relax.frontend.nn import Object, Tensor
   from tvm.runtime import DataType
   from tvm.script import tir as T
   from tvm.script import ir as I
   from tvm.target import Target
   
   def _attention_prefill_ragged_cpu(h_kv, h_q, d, dtype):
   
       group_size = h_q // h_kv
       sm_scale = 1.0 / math.sqrt(float(d)) * math.log2(math.exp(1))
   
       @I.ir_module
       class cpu_module:
   
           @T.prim_func
           def batch_prefill_ragged_kv(  # pylint: disable=too-many-branches
               var_q: T.handle, # [total_len, h_q, d]
               var_q_indptr: T.handle, # [batch_size + 1]
               var_k: T.handle, # [total_len, h_kv, d]
               var_v: T.handle, # [total_len, h_kv, d]
               var_kv_indptr: T.handle, # [batch_size + 1]
               var_q_rope_position: T.handle, # [total_q_len]
               var_k_rope_pos_offset: T.handle, # [b]
               var_output: T.handle, # [total_len, h_q, d]
               var_lse: T.handle, # [total_len, h_q]
               causal: T.int32,
               rotary_mode: T.int32,
               rope_scale: T.float32,
               rope_theta: T.float32,
               attn_score_scaling_factor: T.float32
           ):
               batch_size = T.int32(is_size_var=True)
               qo_len = T.int32(is_size_var=True)
               kv_len = T.int32(is_size_var=True)
               q_indptr_elem_offset = T.int32(is_size_var=True)
               kv_indptr_elem_offset = T.int32(is_size_var=True)
               q_rope_position_elem_offset = T.int32(is_size_var=True)
               k_rope_pos_offset_elem_offset = T.int32(is_size_var=True)
   
               q = T.match_buffer(var_q, (qo_len, h_q, d), dtype)
               q_indptr = T.match_buffer(var_q_indptr, (batch_size + 1,), 
"int32", elem_offset=q_indptr_elem_offset)
               k = T.match_buffer(var_k, (kv_len, h_kv, d), dtype)
               v = T.match_buffer(var_v, (kv_len, h_kv, d), dtype)
               kv_indptr = T.match_buffer(var_kv_indptr, (batch_size + 1,), 
"int32", elem_offset=kv_indptr_elem_offset)
               q_rope_position = T.match_buffer(var_q_rope_position, (qo_len,), 
"int32", elem_offset=q_rope_position_elem_offset)
               k_rope_pos_offset = T.match_buffer(var_k_rope_pos_offset, 
(batch_size,), "int32", elem_offset=k_rope_pos_offset_elem_offset)
               output = T.match_buffer(var_output, (qo_len, h_q, d), dtype)
               lse = T.match_buffer(var_lse, (qo_len, h_q), "float32")  # 
pylint: disable=unused-variable
   
   
               for b in T.serial(batch_size):
                   with T.block("attn"):
                       
                       # q_token_start = T.alloc_buffer([1,], "uint32")
                       # q_num = T.alloc_buffer([1,], "uint32")
                       # k_token_start = T.alloc_buffer([1,], "int32")
                       # k_num = T.alloc_buffer([1,], "int32")
   
                       softmax_sum = T.alloc_buffer([h_q], "float32")
                       m_prev = T.alloc_buffer([h_q], "float32")
                       m_new = T.alloc_buffer([h_q], "float32")
                       d_prev = T.alloc_buffer([h_q], "float32")
                       d_new = T.alloc_buffer([h_q], "float32")
                       sum = T.alloc_buffer([d], "float32")
   
                       max_score = T.alloc_buffer([h_q], "float32")
                       attention_scores = T.alloc_buffer([kv_indptr[b + 1] - 
kv_indptr[b], h_q], "float32")
                       exp_scores = T.alloc_buffer([kv_indptr[b + 1] - 
kv_indptr[b], h_q], "float32")
                       attention_score = T.alloc_buffer([1,], "float32")
                       
                       for q_idx in T.serial(q_indptr[b + 1] - q_indptr[b]):
                       
                           for i in T.serial(h_q):
                               max_score[i] = -5e4
                               m_prev[i] = -5e4
                               d_prev[i] = 1.0
   
                           for k_idx in T.serial(kv_indptr[b + 1] - 
kv_indptr[b]):
                               for h in T.serial(h_q):
                                   h_kv_idx = h // group_size
   
                                   if _causal_mask(causal,
                                                   row=q_idx,
                                                   col=k_idx,
                                                   kv_len=kv_indptr[b + 1] - 
kv_indptr[b],
                                                   qo_len=q_indptr[b + 1] - 
q_indptr[b]):
                                       result = 0.0
                                       for d_idx in T.serial(d):
                                           result += q[q_indptr[b] + q_idx, h, 
d_idx] * k[kv_indptr[b] + k_idx, h_kv_idx, d_idx]
                                       attention_score[0] = result * sm_scale
                                               
                                   else:
                                       attention_score[0] = -5e4 * sm_scale
                                   attention_scores[k_idx, h] = 
attention_score[0]
                                   max_score[h] = T.max(max_score[h], 
attention_score[0])
                                   m_new[h] = T.max(m_prev[h], max_score[h])
                                   
   
                           for h in T.serial(h_q):
                               d_new[h] = d_prev[h] * T.exp2(m_prev[h] - 
m_new[h])                
   
                           for h in T.serial(h_q):
                               softmax_sum[h] = 0.0
                               for k_idx in T.serial(kv_indptr[b + 1] - 
kv_indptr[b]):
                                   exp_scores[k_idx, h] = 
T.exp(attention_scores[k_idx, h] - m_new[h])
                                   softmax_sum[h] += exp_scores[k_idx, h]
                               d_new[h]+=softmax_sum[h]
                                   
                           d_prev = d_new
   
                           for h in T.serial(h_q):
                               h_kv_idx = h // group_size
                               
                               for i in T.serial(d):
                                   sum[i] = 0.0
                               for v_idx in T.serial(kv_indptr[b + 1] - 
kv_indptr[b]):
                                   weight = exp_scores[v_idx, h] / 
softmax_sum[h]
                                   for i in T.serial(d):
                                       sum[i] += v[kv_indptr[b] + v_idx, 
h_kv_idx, i] * weight
                               for i in T.serial(d):
                                   output[q_indptr[b] + q_idx, h, i] = sum[i]
       return cpu_module
   
   IR_cpu = _attention_prefill_ragged_cpu(2, 16, 256, "float32")
   lib_cpu = tvm.build(IR_cpu, target="llvm")
   
   
   `
   
   
   


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to