Cydia2018 opened a new issue, #15805:
URL: https://github.com/apache/tvm/issues/15805

   This question is related to https://github.com/apache/tvm/pull/15487
   
   I tried to embed AMD's **attention** operator directly in the llama model. 
The model compiled without errors, but at runtime I encountered this error: 
"Cannot find PackedFunc attention in either Relax VM kernel library".
   
   The operator file is as follows:
   
   `import os
   import tempfile
   
   import pytest
   
   import tvm
   from tvm.script import ir as I, tir as T, relax as R
   from tvm.relax.frontend import nn
   from tvm.relax.frontend.nn import spec
   
   
   def _gen_extern_module(mod_dir, file):
       """Write the composable-kernel attention source and compile it with hipcc.

       The embedded C++ source is written to ``{mod_dir}/{file}.cc`` and compiled
       to ``{mod_dir}/{file}.o``. The source exports the kernel entry point under
       the symbol ``attention`` via ``TVM_DLL_EXPORT_TYPED_FUNC``.

       Parameters
       ----------
       mod_dir : str
           Directory that receives the generated source and the object file.
       file : str
           Base name (without extension) used for both files.

       Returns
       -------
       str
           Path of the compiled object file.

       Raises
       ------
       subprocess.CalledProcessError
           If hipcc exits with a non-zero status.
       """
       # Local import: keeps this function self-contained without touching the
       # file-level import list.
       import subprocess

       # NOTE: K, O, G0, G1 and the softmax scale are hard-coded in the kernel
       # below; only M and N are read from the tensor shapes.
       src = """
       #include <tvm/runtime/packed_func.h>
       #include <dlpack/dlpack.h>
       #include <iostream>
       #include <ck/ck.hpp>
       #include <ck/tensor_operation/gpu/device/gemm_specialization.hpp>
       #include <ck/tensor_operation/gpu/device/tensor_layout.hpp>
       #include <ck/tensor_operation/gpu/element/element_wise_operation.hpp>
       #include <ck/tensor_operation/gpu/device/tensor_specialization.hpp>
       #include <ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp>

       void fused_relax_nn_attention_composable_kernel1_(DLTensor* q, DLTensor* k, DLTensor* v,
                                                         DLTensor* out0) {
           using F16 = ck::half_t;
           using F32 = float;

           using PassThrough = ck::tensor_operation::element_wise::PassThrough;

           using ADataType        = F16;
           using B0DataType       = F16;
           using B1DataType       = F16;
           using AccDataType      = F32;
           using CShuffleDataType = F32;
           using CDataType        = F16;
           using D0DataType       = F16;
           using Acc0BiasDataType = ck::Tuple<>;
           using Acc1BiasDataType = ck::Tuple<>;

           static constexpr ck::index_t NumDimG = 2;
           static constexpr ck::index_t NumDimM = 1;
           static constexpr ck::index_t NumDimN = 1;
           static constexpr ck::index_t NumDimK = 1;
           static constexpr ck::index_t NumDimO = 1;

           using AElementOp    = PassThrough;
           using B0ElementOp   = PassThrough;
           using Acc0ElementOp = ck::tensor_operation::element_wise::Scale;
           using B1ElementOp   = PassThrough;
           using CElementOp    = PassThrough;

           static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKOPadding;
           static constexpr auto MaskingSpec = ck::tensor_operation::device::MaskingSpecialization::MaskDisabled;

           static constexpr auto TensorSpecA  = ck::tensor_operation::device::TensorSpecialization::Default;
           static constexpr auto TensorSpecB0 = ck::tensor_operation::device::TensorSpecialization::Default;
           static constexpr auto TensorSpecB1 = ck::tensor_operation::device::TensorSpecialization::Default;
           static constexpr auto TensorSpecC  = ck::tensor_operation::device::TensorSpecialization::Default;

           using DeviceGemmInstance =
           ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle<
               NumDimG,
               NumDimM,
               NumDimN,
               NumDimK,
               NumDimO,
               ADataType,
               B0DataType,
               B1DataType,
               CDataType,
               Acc0BiasDataType,
               Acc1BiasDataType,
               AccDataType,
               CShuffleDataType,
               AElementOp,
               B0ElementOp,
               Acc0ElementOp,
               B1ElementOp,
               CElementOp,
               GemmSpec,
               TensorSpecA,
               TensorSpecB0,
               TensorSpecB1,
               TensorSpecC,
               1,
               256,
               128,         // MPerBlock
               128,         // NPerBlock
               32,          // KPerBlock
               64,          // Gemm1NPerBlock
               32,          // Gemm1KPerBlock
               8,           // AK1
               8,           // BK1
               2,           // B1K1
               32,          // MPerXDL
               32,          // NPerXDL
               1,           // MXdlPerWave
               4,           // NXdlPerWave
               2,           // Gemm1NXdlPerWave
               ck::Sequence<4, 64, 1>, // ABlockTransfer
               ck::Sequence<1, 0, 2>,
               ck::Sequence<1, 0, 2>,
               2,
               8,
               8,
               true,
               ck::Sequence<4, 64, 1>, // BBlockTransfer
               ck::Sequence<1, 0, 2>,
               ck::Sequence<1, 0, 2>,
               2,
               8,
               8,
               true,
               ck::Sequence<16, 16, 1>, // B1BlockTransfer
               ck::Sequence<0, 2, 1>,
               ck::Sequence<0, 2, 1>,
               1,
               4,
               2,
               false,
               1,              // CShuffleMXdlPerWavePerShuffle
               2,              // CShuffleNXdlPerWavePerShuffle
               ck::Sequence<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
               8,              // CShuffleBlockTransferScalarPerVector_NPerBlock
               MaskingSpec>;   // MaskingSpecialization

           // param
           ck::index_t M = q->shape[1];
           ck::index_t N = k->shape[1];
           ck::index_t K = 128;
           ck::index_t O = 128;
           ck::index_t G0 = 1;
           ck::index_t G1 = 40;
           float scale = 0.088388348;

           auto a_element_op    = AElementOp{};
           auto b0_element_op   = B0ElementOp{};
           auto acc0_element_op = Acc0ElementOp{scale};
           auto b1_element_op   = B1ElementOp{};
           auto c_element_op    = CElementOp{};

           bool input_permute  = true;
           bool output_permute = true;

           std::vector<ck::index_t> a_gs_ms_ks_lengths{G0, G1, M, K};
           std::vector<ck::index_t> a_gs_ms_ks_strides =
               input_permute
                   ? std::vector<ck::index_t>{M * G1 * K, K, G1 * K, 1} // A layout [G0, M, G1, K]
                   : std::vector<ck::index_t>{G1 * M * K, M * K, K, 1}; // A layout [G0, G1, M, K]

           std::vector<ck::index_t> b0_gs_ns_ks_lengths{G0, G1, N, K};
           std::vector<ck::index_t> b0_gs_ns_ks_strides =
               input_permute
                   ? std::vector<ck::index_t>{N * G1 * K, K, G1 * K, 1} // B0 layout [G0, N, G1, K]
                   : std::vector<ck::index_t>{G1 * N * K, N * K, K, 1}; // B0 layout [G0, G1, N, K]

           std::vector<ck::index_t> b1_gs_os_ns_lengths{G0, G1, O, N};
           std::vector<ck::index_t> b1_gs_os_ns_strides =
               input_permute
                   ? std::vector<ck::index_t>{N * G1 * O, O, 1, G1 * O} // B1 layout [G0, N, G1, O]
                   : std::vector<ck::index_t>{G1 * N * O, N * O, 1, O}; // B1 layout [G0, G1, N, O]

           std::vector<ck::index_t> c_gs_ms_os_lengths{G0, G1, M, O};
           std::vector<ck::index_t> c_gs_ms_os_strides =
               output_permute
                   ? std::vector<ck::index_t>{M * G1 * O, O, G1 * O, 1} // C layout [G0, M, G1, O]
                   : std::vector<ck::index_t>{G1 * M * O, M * O, O, 1}; // C layout [G0, G1, M, O]

           auto gemm     = DeviceGemmInstance{};
           auto invoker  = gemm.MakeInvoker();

           auto argument = gemm.MakeArgument(
               static_cast<ADataType*>(q->data),
               static_cast<B0DataType*>(k->data),
               static_cast<B1DataType*>(v->data),
               static_cast<CDataType*>(out0->data),
               {}, // std::array<void*, 1> p_acc0_biases;
               {}, // std::array<void*, 1> p_acc1_biases;
               a_gs_ms_ks_lengths,
               a_gs_ms_ks_strides,
               b0_gs_ns_ks_lengths,
               b0_gs_ns_ks_strides,
               b1_gs_os_ns_lengths,
               b1_gs_os_ns_strides,
               c_gs_ms_os_lengths,
               c_gs_ms_os_strides,
               {}, // std::array<std::vector<ck::index_t>, 1>(acc0_biases_gs_ms_ns_lengths),
               {}, // std::array<std::vector<ck::index_t>, 1>(acc0_biases_gs_ms_ns_strides),
               {}, // std::array<std::vector<ck::index_t>, 1>(acc1_biases_gs_ms_os_lengths),
               {}, // std::array<std::vector<ck::index_t>, 1>(acc1_biases_gs_ms_os_strides),
               a_element_op,
               b0_element_op,
               acc0_element_op,
               b1_element_op,
               c_element_op);

           if(!gemm.IsSupportedArgument(argument))
           {
               std::cout << gemm.GetTypeString() << " does not support this problem" << std::endl;
               return;
           }

           float ave_time = invoker.Run(argument, StreamConfig{nullptr, false});
       }

       int fused_relax_nn_attention_composable_kernel1_wrapper_(DLTensor* arg0,
                                                               DLTensor* arg1,
                                                               DLTensor* arg2,
                                                               DLTensor* out0) {
           fused_relax_nn_attention_composable_kernel1_(arg0,
               arg1,
               arg2,
               out0);
               return 0;
       }

       #ifdef __cplusplus
       extern "C" {
       #endif
       TVM_DLL int32_t fused_relax_nn_attention_composable_kernel1(DLTensor* arg0,
                                                                   DLTensor* arg1,
                                                                   DLTensor* arg2,
                                                                   DLTensor* out0) {
           //DLTensor* arg0 = (DLTensor*)(((TVMValue*)args)[0].v_handle);
           //DLTensor* arg1 = (DLTensor*)(((TVMValue*)args)[1].v_handle);
           //DLTensor* arg2 = (DLTensor*)(((TVMValue*)args)[2].v_handle);
           //int arg3 = (int)((TVMValue*)args)[3];
           //int arg4 = (int)((TVMValue*)args)[4];
           //int arg5 = (int)((TVMValue*)args)[5];
           //int arg6 = (int)((TVMValue*)args)[6];
           //float arg7 = (float)((TVMValue*)args)[7];
           //DLTensor* arg8 = (DLTensor*)(((TVMValue*)args)[8].v_handle);
           //DLTensor* ret9 = (DLTensor*)(((TVMValue*)args)[9].v_handle);
           // fused_relax_nn_attention_composable_kernel1_wrapper_(arg0,arg1,arg2,arg3,arg4,arg5,arg6,arg7,arg8,ret9);
           fused_relax_nn_attention_composable_kernel1_wrapper_(arg0, arg1, arg2, out0);
           return 0;
       }
       #ifdef __cplusplus
       }
       #endif

       TVM_DLL_EXPORT_TYPED_FUNC(attention, fused_relax_nn_attention_composable_kernel1)
       """
       src_path = f"{mod_dir}/{file}.cc"
       obj_path = f"{mod_dir}/{file}.o"
       with open(src_path, "w", encoding="utf-8") as cc_file:
           cc_file.write(src)

       inc_dir = "/home/xxx/tvm-unity/3rdparty/composable_kernel"
       # Argument list instead of a shell string: no quoting problems with
       # unusual paths. check=True turns a failed compile into an exception
       # instead of silently returning the path of an object file that was
       # never produced.
       subprocess.run(
           [
               "hipcc",
               "-c",
               src_path,
               "-o",
               obj_path,
               "-I",
               f"{inc_dir}/include",
               "-I",
               "/home/xxx/tvm-unity/include",
               "-I",
               "/home/xxx/tvm-unity/3rdparty/dlpack/include",
               "-I",
               "/home/xxx/tvm-unity/3rdparty/dmlc-core/include",
               "-std=c++17",
           ],
           check=True,
       )
       return obj_path
   
   
   def get_extern_module(shape_q, shape_k, shape_v, shape_out):
       """Build an ``nn.Module`` whose forward pass calls the external kernel.

       The kernel object file is generated by ``_gen_extern_module`` in a fresh
       temporary directory and wrapped in an ``nn.ExternModule`` that declares
       one function, ``attention``, taking three float16 tensors and returning
       one float16 tensor.

       Parameters
       ----------
       shape_q, shape_k, shape_v
           Shapes declared for the query, key and value arguments.
       shape_out
           Shape declared for the returned tensor.

       Returns
       -------
       nn.Module
           Module whose ``forward(q, k, v)`` invokes the extern ``attention``
           function.
       """
       dtype = "float16"
       func_name = "attention"

       # The temporary directory is not removed here; presumably the object
       # file has to stay on disk until the model is built and linked.
       tmp_dir = tempfile.mkdtemp()
       obj_file = _gen_extern_module(tmp_dir, "test")

       attention_spec = spec.ExternFunctionSpec(
           symbol=func_name,
           args=[
               spec.Tensor(shape_q, dtype),
               spec.Tensor(shape_k, dtype),
               spec.Tensor(shape_v, dtype),
           ],
           ret=spec.Tensor(shape_out, dtype),
       )
       ext_mod = nn.ExternModule(
           module_spec=spec.ExternModuleSpec(
               filename=obj_file,
               functions=[attention_spec],
           )
       )

       class AttentionModule(nn.Module):
           """Thin wrapper that forwards q, k, v to the extern function."""

           def __init__(self) -> None:
               self.attention = ext_mod

           def forward(self, q, k, v):
               extern_attention = self.attention.get_extern_func(func_name)
               return extern_attention(q, k, v)

       return AttentionModule()
   
   @cyx-6 thanks a lot for your work; I'm looking forward to your reply.


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to