Cydia2018 opened a new issue, #15805: URL: https://github.com/apache/tvm/issues/15805
This question is related to https://github.com/apache/tvm/pull/15487. I tried to embed AMD's **attention** operator directly into the Llama model. The model compiles normally, but at runtime I encounter the error `Cannot find PackedFunc attention in either Relax VM kernel library`. The operator file is as follows: `import os import tempfile import pytest import tvm from tvm.script import ir as I, tir as T, relax as R from tvm.relax.frontend import nn from tvm.relax.frontend.nn import spec def _gen_extern_module(mod_dir, file): src = """ #include <tvm/runtime/packed_func.h> #include <dlpack/dlpack.h> #include <iostream> #include <ck/ck.hpp> #include <ck/tensor_operation/gpu/device/gemm_specialization.hpp> #include <ck/tensor_operation/gpu/device/tensor_layout.hpp> #include <ck/tensor_operation/gpu/element/element_wise_operation.hpp> #include <ck/tensor_operation/gpu/device/tensor_specialization.hpp> #include <ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp> void fused_relax_nn_attention_composable_kernel1_(DLTensor* q, DLTensor* k, DLTensor* v, DLTensor* out0) { using F16 = ck::half_t; using F32 = float; using PassThrough = ck::tensor_operation::element_wise::PassThrough; using ADataType = F16; using B0DataType = F16; using B1DataType = F16; using AccDataType = F32; using CShuffleDataType = F32; using CDataType = F16; using D0DataType = F16; using Acc0BiasDataType = ck::Tuple<>; using Acc1BiasDataType = ck::Tuple<>; static constexpr ck::index_t NumDimG = 2; static constexpr ck::index_t NumDimM = 1; static constexpr ck::index_t NumDimN = 1; static constexpr ck::index_t NumDimK = 1; static constexpr ck::index_t NumDimO = 1; using AElementOp = PassThrough; using B0ElementOp = PassThrough; using Acc0ElementOp = ck::tensor_operation::element_wise::Scale; using B1ElementOp = PassThrough; using CElementOp = PassThrough; static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKOPadding; static 
constexpr auto MaskingSpec = ck::tensor_operation::device::MaskingSpecialization::MaskDisabled; static constexpr auto TensorSpecA = ck::tensor_operation::device::TensorSpecialization::Default; static constexpr auto TensorSpecB0 = ck::tensor_operation::device::TensorSpecialization::Default; static constexpr auto TensorSpecB1 = ck::tensor_operation::device::TensorSpecialization::Default; static constexpr auto TensorSpecC = ck::tensor_operation::device::TensorSpecialization::Default; using DeviceGemmInstance = ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc1BiasDataType, AccDataType, CShuffleDataType, AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp, GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1, 256, 128, // MPerBlock 128, // NPerBlock 32, // KPerBlock 64, // Gemm1NPerBlock 32, // Gemm1KPerBlock 8, // AK1 8, // BK1 2, // B1K1 32, // MPerXDL 32, // NPerXDL 1, // MXdlPerWave 4, // NXdlPerWave 2, // Gemm1NXdlPerWave ck::Sequence<4, 64, 1>, // ABlockTransfer ck::Sequence<1, 0, 2>, ck::Sequence<1, 0, 2>, 2, 8, 8, true, ck::Sequence<4, 64, 1>, // BBlockTransfer ck::Sequence<1, 0, 2>, ck::Sequence<1, 0, 2>, 2, 8, 8, true, ck::Sequence<16, 16, 1>, // B1BlockTransfer ck::Sequence<0, 2, 1>, ck::Sequence<0, 2, 1>, 1, 4, 2, false, 1, // CShuffleMXdlPerWavePerShuffle 2, // CShuffleNXdlPerWavePerShuffle ck::Sequence<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock 8, // CShuffleBlockTransferScalarPerVector_NPerBlock MaskingSpec>; // MaskingSpecialization // param ck::index_t M = q->shape[1]; ck::index_t N = k->shape[1]; ck::index_t K = 128; ck::index_t O = 128; ck::index_t G0 = 1; ck::index_t G1 = 40; float scale = 0.088388348; auto a_element_op = AElementOp{}; auto b0_element_op = B0ElementOp{}; auto acc0_element_op = Acc0ElementOp{scale}; auto b1_element_op = 
B1ElementOp{}; auto c_element_op = CElementOp{}; bool input_permute = true; bool output_permute = true; std::vector<ck::index_t> a_gs_ms_ks_lengths{G0, G1, M, K}; std::vector<ck::index_t> a_gs_ms_ks_strides = input_permute ? std::vector<ck::index_t>{M * G1 * K, K, G1 * K, 1} // A layout [G0, M, G1, K] : std::vector<ck::index_t>{G1 * M * K, M * K, K, 1}; // A layout [G0, G1, M, K] std::vector<ck::index_t> b0_gs_ns_ks_lengths{G0, G1, N, K}; std::vector<ck::index_t> b0_gs_ns_ks_strides = input_permute ? std::vector<ck::index_t>{N * G1 * K, K, G1 * K, 1} // B0 layout [G0, N, G1, K] : std::vector<ck::index_t>{G1 * N * K, N * K, K, 1}; // B0 layout [G0, G1, N, K] std::vector<ck::index_t> b1_gs_os_ns_lengths{G0, G1, O, N}; std::vector<ck::index_t> b1_gs_os_ns_strides = input_permute ? std::vector<ck::index_t>{N * G1 * O, O, 1, G1 * O} // B1 layout [G0, N, G1, O] : std::vector<ck::index_t>{G1 * N * O, N * O, 1, O}; // B1 layout [G0, G1, N, O] std::vector<ck::index_t> c_gs_ms_os_lengths{G0, G1, M, O}; std::vector<ck::index_t> c_gs_ms_os_strides = output_permute ? 
std::vector<ck::index_t>{M * G1 * O, O, G1 * O, 1} // C layout [G0, M, G1, O] : std::vector<ck::index_t>{G1 * M * O, M * O, O, 1}; // C layout [G0, G1, M, O] auto gemm = DeviceGemmInstance{}; auto invoker = gemm.MakeInvoker(); auto argument = gemm.MakeArgument( static_cast<ADataType*>(q->data), static_cast<B0DataType*>(k->data), static_cast<B1DataType*>(v->data), static_cast<CDataType*>(out0->data), {}, // std::array<void*, 1> p_acc0_biases; {}, // std::array<void*, 1> p_acc1_biases; a_gs_ms_ks_lengths, a_gs_ms_ks_strides, b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides, b1_gs_os_ns_lengths, b1_gs_os_ns_strides, c_gs_ms_os_lengths, c_gs_ms_os_strides, {}, // std::array<std::vector<ck::index_t>, 1>(acc0_biases_gs_ms_ns_lengths), {}, // std::array<std::vector<ck::index_t>, 1>(acc0_biases_gs_ms_ns_strides), {}, // std::array<std::vector<ck::index_t>, 1>(acc1_biases_gs_ms_os_lengths), {}, // std::array<std::vector<ck::index_t>, 1>(acc1_biases_gs_ms_os_strides), a_element_op, b0_element_op, acc0_element_op, b1_element_op, c_element_op); if(!gemm.IsSupportedArgument(argument)) { std::cout << gemm.GetTypeString() << " does not support this problem" << std::endl; return; } float ave_time = invoker.Run(argument, StreamConfig{nullptr, false}); } int fused_relax_nn_attention_composable_kernel1_wrapper_(DLTensor* arg0, DLTensor* arg1, DLTensor* arg2, DLTensor* out0) { fused_relax_nn_attention_composable_kernel1_(arg0, arg1, arg2, out0); return 0; } #ifdef __cplusplus extern "C" { #endif TVM_DLL int32_t fused_relax_nn_attention_composable_kernel1(DLTensor* arg0, DLTensor* arg1, DLTensor* arg2, DLTensor* out0) { //DLTensor* arg0 = (DLTensor*)(((TVMValue*)args)[0].v_handle); //DLTensor* arg1 = (DLTensor*)(((TVMValue*)args)[1].v_handle); //DLTensor* arg2 = (DLTensor*)(((TVMValue*)args)[2].v_handle); //int arg3 = (int)((TVMValue*)args)[3]; //int arg4 = (int)((TVMValue*)args)[4]; //int arg5 = (int)((TVMValue*)args)[5]; //int arg6 = (int)((TVMValue*)args)[6]; //float arg7 = 
(float)((TVMValue*)args)[7]; //DLTensor* arg8 = (DLTensor*)(((TVMValue*)args)[8].v_handle); //DLTensor* ret9 = (DLTensor*)(((TVMValue*)args)[9].v_handle); // fused_relax_nn_attention_composable_kernel1_wrapper_(arg0,arg1,arg2,arg3,arg4,arg5,arg6,arg7,arg8,ret9); fused_relax_nn_attention_composable_kernel1_wrapper_(arg0, arg1, arg2, out0); return 0; } #ifdef __cplusplus } #endif TVM_DLL_EXPORT_TYPED_FUNC(attention, fused_relax_nn_attention_composable_kernel1) """ with open(f"{mod_dir}/{file}.cc", "w") as cc_file: cc_file.write(src) cur_dir = os.path.dirname(os.path.abspath(__file__)) inc_dir = "/home/xxx/tvm-unity/3rdparty/composable_kernel" os.system( f"hipcc -c {mod_dir}/{file}.cc " f"-o {mod_dir}/{file}.o " f"-I {inc_dir}/include " f"-I /home/xxx/tvm-unity/include " f"-I /home/xxx/tvm-unity/3rdparty/dlpack/include " f"-I /home/xxx/tvm-unity/3rdparty/dmlc-core/include " f"-std=c++17 " ) return f"{mod_dir}/{file}.o" def get_extern_module(shape_q, shape_k, shape_v, shape_out): dtype = "float16" tmp_dir = tempfile.mkdtemp() obj_file = _gen_extern_module(tmp_dir, "test") func_name = "attention" ext_mod = nn.ExternModule( module_spec=spec.ExternModuleSpec( filename=obj_file, functions=[ spec.ExternFunctionSpec( symbol=func_name, args=[ spec.Tensor(shape_q, dtype), spec.Tensor(shape_k, dtype), spec.Tensor(shape_v, dtype), ], ret=spec.Tensor(shape_out, dtype), ) ], ) ) class AttentionModule(nn.Module): def __init__(self) -> None: self.attention = ext_mod def forward(self, q, k, v): # query = nn.Tensor(_expr = q) # key = nn.Tensor(_expr = k) # value = nn.Tensor(_expr = v) # return self.attention.get_extern_func(func_name)(query, key, value) return self.attention.get_extern_func(func_name)(q, k, v) attention_mod = AttentionModule() return attention_mod` @cyx-6 Thanks a lot for your work; looking forward to your reply. -- This is an automated message from the Apache Git Service. 
To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected]
