hmz0412 opened a new issue, #17572:
URL: https://github.com/apache/tvm/issues/17572
I want to create my own prefill tir function, but when i built this module
to test, the error messages above were appeared.
I don't know the reason why the bug will be created.
Actual situation
`Traceback (most recent call last):
File
"/home/octal/mlc-llm/3rdparty/tvm/python/tvm/relax/frontend/nn/llm/test_equal.py",
line 554, in <module>
lib_cpu = tvm.build(IR_cpu, target="llvm")
File "/home/octal/mlc-llm/3rdparty/tvm/python/tvm/driver/build_module.py",
line 297, in build
rt_mod_host = _driver_ffi.tir_to_runtime(annotated_mods, target_host)
File
"/home/octal/mlc-llm/3rdparty/tvm/python/tvm/_ffi/_ctypes/packed_func.py", line
245, in __call__
raise_last_ffi_error()
File "/home/octal/mlc-llm/3rdparty/tvm/python/tvm/_ffi/base.py", line 481,
in raise_last_ffi_error
raise py_err
File "/home/octal/mlc-llm/3rdparty/tvm/src/driver/driver_api.cc", line
531, in operator()
return TIRToRuntime(inputs_arg, host_target);
File "/home/octal/mlc-llm/3rdparty/tvm/src/driver/driver_api.cc", line
492, in tvm::TIRToRuntime(tvm::runtime::Map<tvm::Target, tvm::IRModule, void,
void> const&, tvm::Target const&)
auto pair = SplitMixedModule(ir_module, target, target_host);
File "/home/octal/mlc-llm/3rdparty/tvm/src/driver/driver_api.cc", line
418, in tvm::SplitMixedModule(tvm::IRModule, tvm::Target const&, tvm::Target
const&)
mod_mixed = ApplyPasses(mod_mixed, MixedModulePassManager(mod_mixed,
target));
File "/home/octal/mlc-llm/3rdparty/tvm/src/driver/driver_api.cc", line
291, in tvm::ApplyPasses(tvm::IRModule, tvm::transform::Sequential)
mod = seq(std::move(mod));
File
"/home/octal/mlc-llm/3rdparty/tvm/src/tir/transforms/make_packed_api.cc", line
458, in operator()
func = MakePackedAPI(std::move(func));
File
"/home/octal/mlc-llm/3rdparty/tvm/src/tir/transforms/make_packed_api.cc", line
420, in tvm::tir::MakePackedAPI(tvm::tir::PrimFunc)
Array<Var> undefined = UndefinedVars(func_ptr->body, func_ptr->params);
File
"/home/octal/mlc-llm/3rdparty/tvm/src/tir/analysis/var_use_def_analysis.cc",
line 186, in tvm::tir::UndefinedVars(tvm::tir::Stmt const&,
tvm::runtime::Array<tvm::tir::Var, void> const&)
m(stmt);
File "/home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/stmt_functor.cc", line
119, in tvm::tir::StmtVisitor::VisitStmt_(tvm::tir::SeqStmtNode const*)
VisitArray(op->seq, [this](const Stmt& s) { this->VisitStmt(s); });
File "/home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/functor_common.h", line
35, in VisitArray<tvm::tir::Stmt, tvm::tir::StmtVisitor::VisitStmt_(const
tvm::tir::SeqStmtNode*)::<lambda(const tvm::tir::Stmt&)> >
fvisit(arr[i]);
File "/home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/stmt_functor.cc", line
119, in operator()
VisitArray(op->seq, [this](const Stmt& s) { this->VisitStmt(s); });
File "/home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/stmt_functor.cc", line
119, in tvm::tir::StmtVisitor::VisitStmt_(tvm::tir::SeqStmtNode const*)
VisitArray(op->seq, [this](const Stmt& s) { this->VisitStmt(s); });
File "/home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/functor_common.h", line
35, in VisitArray<tvm::tir::Stmt, tvm::tir::StmtVisitor::VisitStmt_(const
tvm::tir::SeqStmtNode*)::<lambda(const tvm::tir::Stmt&)> >
fvisit(arr[i]);
File "/home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/stmt_functor.cc", line
119, in operator()
VisitArray(op->seq, [this](const Stmt& s) { this->VisitStmt(s); });
File "/home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/stmt_functor.cc", line
119, in tvm::tir::StmtVisitor::VisitStmt_(tvm::tir::SeqStmtNode const*)
VisitArray(op->seq, [this](const Stmt& s) { this->VisitStmt(s); });
File "/home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/functor_common.h", line
35, in VisitArray<tvm::tir::Stmt, tvm::tir::StmtVisitor::VisitStmt_(const
tvm::tir::SeqStmtNode*)::<lambda(const tvm::tir::Stmt&)> >
fvisit(arr[i]);
File "/home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/stmt_functor.cc", line
119, in operator()
VisitArray(op->seq, [this](const Stmt& s) { this->VisitStmt(s); });
File "/home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/stmt_functor.cc", line
119, in tvm::tir::StmtVisitor::VisitStmt_(tvm::tir::SeqStmtNode const*)
VisitArray(op->seq, [this](const Stmt& s) { this->VisitStmt(s); });
File "/home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/functor_common.h", line
35, in VisitArray<tvm::tir::Stmt, tvm::tir::StmtVisitor::VisitStmt_(const
tvm::tir::SeqStmtNode*)::<lambda(const tvm::tir::Stmt&)> >
fvisit(arr[i]);
File "/home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/stmt_functor.cc", line
119, in operator()
VisitArray(op->seq, [this](const Stmt& s) { this->VisitStmt(s); });
File "/home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/stmt_functor.cc", line
119, in tvm::tir::StmtVisitor::VisitStmt_(tvm::tir::SeqStmtNode const*)
VisitArray(op->seq, [this](const Stmt& s) { this->VisitStmt(s); });
File "/home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/functor_common.h", line
35, in VisitArray<tvm::tir::Stmt, tvm::tir::StmtVisitor::VisitStmt_(const
tvm::tir::SeqStmtNode*)::<lambda(const tvm::tir::Stmt&)> >
fvisit(arr[i]);
File "/home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/stmt_functor.cc", line
119, in operator()
VisitArray(op->seq, [this](const Stmt& s) { this->VisitStmt(s); });
File "/home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/stmt_functor.cc", line
119, in tvm::tir::StmtVisitor::VisitStmt_(tvm::tir::SeqStmtNode const*)
VisitArray(op->seq, [this](const Stmt& s) { this->VisitStmt(s); });
File "/home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/functor_common.h", line
35, in VisitArray<tvm::tir::Stmt, tvm::tir::StmtVisitor::VisitStmt_(const
tvm::tir::SeqStmtNode*)::<lambda(const tvm::tir::Stmt&)> >
fvisit(arr[i]);
File "/home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/stmt_functor.cc", line
119, in operator()
VisitArray(op->seq, [this](const Stmt& s) { this->VisitStmt(s); });
File "/home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/stmt_functor.cc", line
119, in tvm::tir::StmtVisitor::VisitStmt_(tvm::tir::SeqStmtNode const*)
VisitArray(op->seq, [this](const Stmt& s) { this->VisitStmt(s); });
File "/home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/functor_common.h", line
35, in VisitArray<tvm::tir::Stmt, tvm::tir::StmtVisitor::VisitStmt_(const
tvm::tir::SeqStmtNode*)::<lambda(const tvm::tir::Stmt&)> >
fvisit(arr[i]);
File "/home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/stmt_functor.cc", line
119, in operator()
VisitArray(op->seq, [this](const Stmt& s) { this->VisitStmt(s); });
File "/home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/stmt_functor.cc", line
119, in tvm::tir::StmtVisitor::VisitStmt_(tvm::tir::SeqStmtNode const*)
VisitArray(op->seq, [this](const Stmt& s) { this->VisitStmt(s); });
File "/home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/functor_common.h", line
35, in VisitArray<tvm::tir::Stmt, tvm::tir::StmtVisitor::VisitStmt_(const
tvm::tir::SeqStmtNode*)::<lambda(const tvm::tir::Stmt&)> >
fvisit(arr[i]);
File "/home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/stmt_functor.cc", line
119, in operator()
VisitArray(op->seq, [this](const Stmt& s) { this->VisitStmt(s); });
File "/home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/stmt_functor.cc", line
119, in tvm::tir::StmtVisitor::VisitStmt_(tvm::tir::SeqStmtNode const*)
VisitArray(op->seq, [this](const Stmt& s) { this->VisitStmt(s); });
File "/home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/functor_common.h", line
35, in VisitArray<tvm::tir::Stmt, tvm::tir::StmtVisitor::VisitStmt_(const
tvm::tir::SeqStmtNode*)::<lambda(const tvm::tir::Stmt&)> >
fvisit(arr[i]);
File "/home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/stmt_functor.cc", line
119, in operator()
VisitArray(op->seq, [this](const Stmt& s) { this->VisitStmt(s); });
File "/home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/stmt_functor.cc", line
119, in tvm::tir::StmtVisitor::VisitStmt_(tvm::tir::SeqStmtNode const*)
VisitArray(op->seq, [this](const Stmt& s) { this->VisitStmt(s); });
File "/home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/functor_common.h", line
35, in VisitArray<tvm::tir::Stmt, tvm::tir::StmtVisitor::VisitStmt_(const
tvm::tir::SeqStmtNode*)::<lambda(const tvm::tir::Stmt&)> >
fvisit(arr[i]);
File "/home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/stmt_functor.cc", line
119, in operator()
VisitArray(op->seq, [this](const Stmt& s) { this->VisitStmt(s); });
File "/home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/stmt_functor.cc", line
58, in tvm::tir::StmtVisitor::VisitStmt_(tvm::tir::AllocateNode const*)
this->VisitStmt(op->body);
File "/home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/stmt_functor.cc", line
58, in tvm::tir::StmtVisitor::VisitStmt_(tvm::tir::AllocateNode const*)
this->VisitStmt(op->body);
File "/home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/stmt_functor.cc", line
58, in tvm::tir::StmtVisitor::VisitStmt_(tvm::tir::AllocateNode const*)
this->VisitStmt(op->body);
File "/home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/stmt_functor.cc", line
58, in tvm::tir::StmtVisitor::VisitStmt_(tvm::tir::AllocateNode const*)
this->VisitStmt(op->body);
File "/home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/stmt_functor.cc", line
58, in tvm::tir::StmtVisitor::VisitStmt_(tvm::tir::AllocateNode const*)
this->VisitStmt(op->body);
File "/home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/stmt_functor.cc", line
58, in tvm::tir::StmtVisitor::VisitStmt_(tvm::tir::AllocateNode const*)
this->VisitStmt(op->body);
File "/home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/stmt_functor.cc", line
58, in tvm::tir::StmtVisitor::VisitStmt_(tvm::tir::AllocateNode const*)
this->VisitStmt(op->body);
File
"/home/octal/mlc-llm/3rdparty/tvm/src/tir/analysis/var_use_def_analysis.cc",
line 61, in tvm::tir::VarUseDefAnalyzer::VisitStmt_(tvm::tir::ForNode const*)
this->HandleDef(op->loop_var);
File
"/home/octal/mlc-llm/3rdparty/tvm/src/tir/analysis/var_use_def_analysis.cc",
line 136, in tvm::tir::VarUseDefAnalyzer::HandleDef(tvm::tir::Var const&)
ICHECK(!def_count_.count(v)) << "variable " << v->name_hint
tvm.error.InternalError: Traceback (most recent call last):
45: operator()
at /home/octal/mlc-llm/3rdparty/tvm/src/driver/driver_api.cc:531
44: tvm::TIRToRuntime(tvm::runtime::Map<tvm::Target, tvm::IRModule, void,
void> const&, tvm::Target const&)
at /home/octal/mlc-llm/3rdparty/tvm/src/driver/driver_api.cc:492
43: tvm::SplitMixedModule(tvm::IRModule, tvm::Target const&, tvm::Target
const&)
at /home/octal/mlc-llm/3rdparty/tvm/src/driver/driver_api.cc:418
42: tvm::ApplyPasses(tvm::IRModule, tvm::transform::Sequential)
at /home/octal/mlc-llm/3rdparty/tvm/src/driver/driver_api.cc:291
41: operator()
at
/home/octal/mlc-llm/3rdparty/tvm/src/tir/transforms/make_packed_api.cc:458
40: tvm::tir::MakePackedAPI(tvm::tir::PrimFunc)
at
/home/octal/mlc-llm/3rdparty/tvm/src/tir/transforms/make_packed_api.cc:420
39: tvm::tir::UndefinedVars(tvm::tir::Stmt const&,
tvm::runtime::Array<tvm::tir::Var, void> const&)
at
/home/octal/mlc-llm/3rdparty/tvm/src/tir/analysis/var_use_def_analysis.cc:186
38: tvm::tir::StmtVisitor::VisitStmt_(tvm::tir::SeqStmtNode const*)
at /home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/stmt_functor.cc:119
37: VisitArray<tvm::tir::Stmt, tvm::tir::StmtVisitor::VisitStmt_(const
tvm::tir::SeqStmtNode*)::<lambda(const tvm::tir::Stmt&)> >
at /home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/functor_common.h:35
36: operator()
at /home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/stmt_functor.cc:119
35: tvm::tir::StmtVisitor::VisitStmt_(tvm::tir::SeqStmtNode const*)
at /home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/stmt_functor.cc:119
34: VisitArray<tvm::tir::Stmt, tvm::tir::StmtVisitor::VisitStmt_(const
tvm::tir::SeqStmtNode*)::<lambda(const tvm::tir::Stmt&)> >
at /home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/functor_common.h:35
33: operator()
at /home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/stmt_functor.cc:119
32: tvm::tir::StmtVisitor::VisitStmt_(tvm::tir::SeqStmtNode const*)
at /home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/stmt_functor.cc:119
31: VisitArray<tvm::tir::Stmt, tvm::tir::StmtVisitor::VisitStmt_(const
tvm::tir::SeqStmtNode*)::<lambda(const tvm::tir::Stmt&)> >
at /home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/functor_common.h:35
30: operator()
at /home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/stmt_functor.cc:119
29: tvm::tir::StmtVisitor::VisitStmt_(tvm::tir::SeqStmtNode const*)
at /home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/stmt_functor.cc:119
28: VisitArray<tvm::tir::Stmt, tvm::tir::StmtVisitor::VisitStmt_(const
tvm::tir::SeqStmtNode*)::<lambda(const tvm::tir::Stmt&)> >
at /home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/functor_common.h:35
27: operator()
at /home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/stmt_functor.cc:119
26: tvm::tir::StmtVisitor::VisitStmt_(tvm::tir::SeqStmtNode const*)
at /home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/stmt_functor.cc:119
25: VisitArray<tvm::tir::Stmt, tvm::tir::StmtVisitor::VisitStmt_(const
tvm::tir::SeqStmtNode*)::<lambda(const tvm::tir::Stmt&)> >
at /home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/functor_common.h:35
24: operator()
at /home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/stmt_functor.cc:119
23: tvm::tir::StmtVisitor::VisitStmt_(tvm::tir::SeqStmtNode const*)
at /home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/stmt_functor.cc:119
22: VisitArray<tvm::tir::Stmt, tvm::tir::StmtVisitor::VisitStmt_(const
tvm::tir::SeqStmtNode*)::<lambda(const tvm::tir::Stmt&)> >
at /home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/functor_common.h:35
21: operator()
at /home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/stmt_functor.cc:119
20: tvm::tir::StmtVisitor::VisitStmt_(tvm::tir::SeqStmtNode const*)
at /home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/stmt_functor.cc:119
19: VisitArray<tvm::tir::Stmt, tvm::tir::StmtVisitor::VisitStmt_(const
tvm::tir::SeqStmtNode*)::<lambda(const tvm::tir::Stmt&)> >
at /home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/functor_common.h:35
18: operator()
at /home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/stmt_functor.cc:119
17: tvm::tir::StmtVisitor::VisitStmt_(tvm::tir::SeqStmtNode const*)
at /home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/stmt_functor.cc:119
16: VisitArray<tvm::tir::Stmt, tvm::tir::StmtVisitor::VisitStmt_(const
tvm::tir::SeqStmtNode*)::<lambda(const tvm::tir::Stmt&)> >
at /home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/functor_common.h:35
15: operator()
at /home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/stmt_functor.cc:119
14: tvm::tir::StmtVisitor::VisitStmt_(tvm::tir::SeqStmtNode const*)
at /home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/stmt_functor.cc:119
13: VisitArray<tvm::tir::Stmt, tvm::tir::StmtVisitor::VisitStmt_(const
tvm::tir::SeqStmtNode*)::<lambda(const tvm::tir::Stmt&)> >
at /home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/functor_common.h:35
12: operator()
at /home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/stmt_functor.cc:119
11: tvm::tir::StmtVisitor::VisitStmt_(tvm::tir::SeqStmtNode const*)
at /home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/stmt_functor.cc:119
10: VisitArray<tvm::tir::Stmt, tvm::tir::StmtVisitor::VisitStmt_(const
tvm::tir::SeqStmtNode*)::<lambda(const tvm::tir::Stmt&)> >
at /home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/functor_common.h:35
9: operator()
at /home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/stmt_functor.cc:119
8: tvm::tir::StmtVisitor::VisitStmt_(tvm::tir::AllocateNode const*)
at /home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/stmt_functor.cc:58
7: tvm::tir::StmtVisitor::VisitStmt_(tvm::tir::AllocateNode const*)
at /home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/stmt_functor.cc:58
6: tvm::tir::StmtVisitor::VisitStmt_(tvm::tir::AllocateNode const*)
at /home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/stmt_functor.cc:58
5: tvm::tir::StmtVisitor::VisitStmt_(tvm::tir::AllocateNode const*)
at /home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/stmt_functor.cc:58
4: tvm::tir::StmtVisitor::VisitStmt_(tvm::tir::AllocateNode const*)
at /home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/stmt_functor.cc:58
3: tvm::tir::StmtVisitor::VisitStmt_(tvm::tir::AllocateNode const*)
at /home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/stmt_functor.cc:58
2: tvm::tir::StmtVisitor::VisitStmt_(tvm::tir::AllocateNode const*)
at /home/octal/mlc-llm/3rdparty/tvm/src/tir/ir/stmt_functor.cc:58
1: tvm::tir::VarUseDefAnalyzer::VisitStmt_(tvm::tir::ForNode const*)
at
/home/octal/mlc-llm/3rdparty/tvm/src/tir/analysis/var_use_def_analysis.cc:61
0: tvm::tir::VarUseDefAnalyzer::HandleDef(tvm::tir::Var const&)
at
/home/octal/mlc-llm/3rdparty/tvm/src/tir/analysis/var_use_def_analysis.cc:136
File
"/home/octal/mlc-llm/3rdparty/tvm/src/tir/analysis/var_use_def_analysis.cc",
line 138
InternalError: Check failed: (!use_count_.count(v)) is false: variable b has
been used before definition!`
Reproduce
`
import math
from typing import Any, Dict, Tuple
import tvm
from tvm import relax as rx
from tvm import tir
from tvm.relax.frontend.nn import Object, Tensor
from tvm.runtime import DataType
from tvm.script import tir as T
from tvm.script import ir as I
from tvm.target import Target
def _attention_prefill_ragged_cpu(h_kv, h_q, d, dtype):
group_size = h_q // h_kv
sm_scale = 1.0 / math.sqrt(float(d)) * math.log2(math.exp(1))
@I.ir_module
class cpu_module:
@T.prim_func
def batch_prefill_ragged_kv( # pylint: disable=too-many-branches
var_q: T.handle, # [total_len, h_q, d]
var_q_indptr: T.handle, # [batch_size + 1]
var_k: T.handle, # [total_len, h_kv, d]
var_v: T.handle, # [total_len, h_kv, d]
var_kv_indptr: T.handle, # [batch_size + 1]
var_q_rope_position: T.handle, # [total_q_len]
var_k_rope_pos_offset: T.handle, # [b]
var_output: T.handle, # [total_len, h_q, d]
var_lse: T.handle, # [total_len, h_q]
causal: T.int32,
rotary_mode: T.int32,
rope_scale: T.float32,
rope_theta: T.float32,
attn_score_scaling_factor: T.float32
):
batch_size = T.int32(is_size_var=True)
qo_len = T.int32(is_size_var=True)
kv_len = T.int32(is_size_var=True)
q_indptr_elem_offset = T.int32(is_size_var=True)
kv_indptr_elem_offset = T.int32(is_size_var=True)
q_rope_position_elem_offset = T.int32(is_size_var=True)
k_rope_pos_offset_elem_offset = T.int32(is_size_var=True)
q = T.match_buffer(var_q, (qo_len, h_q, d), dtype)
q_indptr = T.match_buffer(var_q_indptr, (batch_size + 1,),
"int32", elem_offset=q_indptr_elem_offset)
k = T.match_buffer(var_k, (kv_len, h_kv, d), dtype)
v = T.match_buffer(var_v, (kv_len, h_kv, d), dtype)
kv_indptr = T.match_buffer(var_kv_indptr, (batch_size + 1,),
"int32", elem_offset=kv_indptr_elem_offset)
q_rope_position = T.match_buffer(var_q_rope_position, (qo_len,),
"int32", elem_offset=q_rope_position_elem_offset)
k_rope_pos_offset = T.match_buffer(var_k_rope_pos_offset,
(batch_size,), "int32", elem_offset=k_rope_pos_offset_elem_offset)
output = T.match_buffer(var_output, (qo_len, h_q, d), dtype)
lse = T.match_buffer(var_lse, (qo_len, h_q), "float32") #
pylint: disable=unused-variable
for b in T.serial(batch_size):
with T.block("attn"):
# q_token_start = T.alloc_buffer([1,], "uint32")
# q_num = T.alloc_buffer([1,], "uint32")
# k_token_start = T.alloc_buffer([1,], "int32")
# k_num = T.alloc_buffer([1,], "int32")
softmax_sum = T.alloc_buffer([h_q], "float32")
m_prev = T.alloc_buffer([h_q], "float32")
m_new = T.alloc_buffer([h_q], "float32")
d_prev = T.alloc_buffer([h_q], "float32")
d_new = T.alloc_buffer([h_q], "float32")
sum = T.alloc_buffer([d], "float32")
max_score = T.alloc_buffer([h_q], "float32")
attention_scores = T.alloc_buffer([kv_indptr[b + 1] -
kv_indptr[b], h_q], "float32")
exp_scores = T.alloc_buffer([kv_indptr[b + 1] -
kv_indptr[b], h_q], "float32")
attention_score = T.alloc_buffer([1,], "float32")
for q_idx in T.serial(q_indptr[b + 1] - q_indptr[b]):
for i in T.serial(h_q):
max_score[i] = -5e4
m_prev[i] = -5e4
d_prev[i] = 1.0
for k_idx in T.serial(kv_indptr[b + 1] -
kv_indptr[b]):
for h in T.serial(h_q):
h_kv_idx = h // group_size
if _causal_mask(causal,
row=q_idx,
col=k_idx,
kv_len=kv_indptr[b + 1] -
kv_indptr[b],
qo_len=q_indptr[b + 1] -
q_indptr[b]):
result = 0.0
for d_idx in T.serial(d):
result += q[q_indptr[b] + q_idx, h,
d_idx] * k[kv_indptr[b] + k_idx, h_kv_idx, d_idx]
attention_score[0] = result * sm_scale
else:
attention_score[0] = -5e4 * sm_scale
attention_scores[k_idx, h] =
attention_score[0]
max_score[h] = T.max(max_score[h],
attention_score[0])
m_new[h] = T.max(m_prev[h], max_score[h])
for h in T.serial(h_q):
d_new[h] = d_prev[h] * T.exp2(m_prev[h] -
m_new[h])
for h in T.serial(h_q):
softmax_sum[h] = 0.0
for k_idx in T.serial(kv_indptr[b + 1] -
kv_indptr[b]):
exp_scores[k_idx, h] =
T.exp(attention_scores[k_idx, h] - m_new[h])
softmax_sum[h] += exp_scores[k_idx, h]
d_new[h]+=softmax_sum[h]
d_prev = d_new
for h in T.serial(h_q):
h_kv_idx = h // group_size
for i in T.serial(d):
sum[i] = 0.0
for v_idx in T.serial(kv_indptr[b + 1] -
kv_indptr[b]):
weight = exp_scores[v_idx, h] /
softmax_sum[h]
for i in T.serial(d):
sum[i] += v[kv_indptr[b] + v_idx,
h_kv_idx, i] * weight
for i in T.serial(d):
output[q_indptr[b] + q_idx, h, i] = sum[i]
return cpu_module
IR_cpu = _attention_prefill_ragged_cpu(2, 16, 256, "float32")
lib_cpu = tvm.build(IR_cpu, target="llvm")
`
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]