This is an automated email from the ASF dual-hosted git repository. masahi pushed a commit to branch unity in repository https://gitbox.apache.org/repos/asf/tvm.git
commit c606a069c350e1ca77ce21145b94be4c35591d71
Author: Wuwei Lin <[email protected]>
AuthorDate: Sun Dec 10 21:16:44 2023 +0000

    kernel
---
 src/runtime/contrib/cutlass/moe_compute_rows.cu |  69 +++++++++++++
 src/runtime/contrib/cutlass/moe_gemm.cc         | 129 +++++++++++++++++++++++++
 2 files changed, 198 insertions(+)

diff --git a/src/runtime/contrib/cutlass/moe_compute_rows.cu b/src/runtime/contrib/cutlass/moe_compute_rows.cu
new file mode 100644
index 0000000000..07b408baff
--- /dev/null
+++ b/src/runtime/contrib/cutlass/moe_compute_rows.cu
@@ -0,0 +1,69 @@
+// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/kernels/moe_kernels.cu
+#include <cuda_runtime.h>
+
+#include <algorithm>
+#include <cstdint>
+
+/*!
+ * \brief Count the entries of sorted_indices[0, arr_length) that are <= target.
+ *
+ * Binary search; the array is assumed to be sorted in non-decreasing order.
+ * Returns 0 when arr_length is 0 or every entry is greater than target.
+ */
+__device__ inline int find_total_elts_leq_target(const int* sorted_indices, const int arr_length,
+                                                 const int target) {
+  int64_t low = 0, high = arr_length - 1, target_location = -1;
+  while (low <= high) {
+    int64_t mid = (low + high) / 2;
+
+    if (sorted_indices[mid] > target) {
+      high = mid - 1;
+    } else {
+      // sorted_indices[mid] <= target: remember it and keep looking to the right.
+      low = mid + 1;
+      target_location = mid;
+    }
+  }
+  return target_location + 1;
+}
+
+/*!
+ * \brief For every expert e in [0, num_experts), write the number of entries of sorted_experts
+ *        that are <= e to total_rows_before_expert[e].
+ *
+ * Expects a 1-D launch with at least num_experts threads in total; each thread handles one
+ * expert and surplus threads return immediately.
+ */
+__global__ void compute_total_rows_before_expert_kernel(const int* sorted_experts,
+                                                        const int sorted_experts_len,
+                                                        const int64_t num_experts,
+                                                        int64_t* total_rows_before_expert) {
+  // First, compute the global tid. We only need 1 thread per expert.
+  const int expert = blockIdx.x * blockDim.x + threadIdx.x;
+  if (expert >= num_experts) return;
+
+  // Inclusive running total: the count of entries <= expert, which is one past the last index
+  // at which this expert can occur in the sorted array.
+  total_rows_before_expert[expert] =
+      find_total_elts_leq_target(sorted_experts, sorted_experts_len, expert);
+}
+
+/*!
+ * \brief Launch compute_total_rows_before_expert_kernel on the given stream.
+ *
+ * The launch is asynchronous and its status is not checked here; callers that need to detect
+ * launch failures should query cudaGetLastError() afterwards.
+ */
+void compute_total_rows_before_expert(const int* sorted_indices, const int total_indices,
+                                      const int num_experts, int64_t* total_rows_before_expert,
+                                      cudaStream_t stream) {
+  // Nothing to do without experts. This also keeps the launch configuration below valid:
+  // num_experts == 0 would give threads == 0 and a division by zero when computing blocks.
+  if (num_experts <= 0) return;
+
+  const int threads = std::min(1024, num_experts);
+  const int blocks = (num_experts + threads - 1) / threads;
+
+  compute_total_rows_before_expert_kernel<<<blocks, threads, 0, stream>>>(
+      sorted_indices, total_indices, num_experts, total_rows_before_expert);
+}
diff --git a/src/runtime/contrib/cutlass/moe_gemm.cc b/src/runtime/contrib/cutlass/moe_gemm.cc
new file mode 100644
index 0000000000..a1796991c3
--- /dev/null
+++ b/src/runtime/contrib/cutlass/moe_gemm.cc
@@ -0,0 +1,129 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include <cuda.h>
+#include <cuda_fp16.h>
+#include <cuda_runtime.h>
+#include <tvm/runtime/ndarray.h>
+#include <tvm/runtime/packed_func.h>
+#include <tvm/runtime/registry.h>
+
+#include <limits>
+#include <optional>
+#include <string>
+
+#include "../../../3rdparty/cutlass_fpA_intB_gemm/cutlass/include/cutlass/half.h"
+// clang-format off
+// these headers can't be reordered
+#include "../../../3rdparty/cutlass_fpA_intB_gemm/cutlass/include/cutlass/numeric_types.h"
+#include "../../../3rdparty/cutlass_fpA_intB_gemm/cutlass/include/cutlass/integer_subbyte.h"
+// clang-format on
+#include "../../../3rdparty/cutlass_fpA_intB_gemm/cutlass_kernels/moe_gemm/moe_gemm_kernels.h"
+
+// Defined in moe_compute_rows.cu.
+void compute_total_rows_before_expert(const int* sorted_indices, const int total_indices,
+                                      const int num_experts, int64_t* total_rows_before_expert,
+                                      cudaStream_t stream);
+
+namespace fastertransformer {
+
+template <typename T, typename WeightType>
+void moe_gemm_bias_act(const T* A, const WeightType* B, const T* weight_scales, const T* biases,
+                       T* C, int64_t* total_rows_before_expert, int64_t total_rows, int64_t gemm_n,
+                       int64_t gemm_k, int num_experts, std::optional<std::string> activation,
+                       cudaStream_t stream);
+}  // namespace fastertransformer
+
+namespace tvm {
+namespace runtime {
+
+namespace {
+
+/*!
+ * \brief Return the stream reported by the "runtime.get_cuda_stream" global.
+ *
+ * Falls back to the default stream (nullptr) when that global is not registered, so the
+ * kernels still run in builds that do not provide it.
+ */
+cudaStream_t GetCurrentCUDAStream() {
+  const PackedFunc* get_stream = Registry::Get("runtime.get_cuda_stream");
+  if (get_stream == nullptr) return nullptr;
+  return static_cast<cudaStream_t>((*get_stream)().operator void*());
+}
+
+/*! \brief Whether the array holds 64-bit signed integers. */
+bool IsInt64(const NDArray& arr) {
+  return arr->dtype.code == kDLInt && arr->dtype.bits == 64;
+}
+
+/*! \brief Convert to int, failing loudly instead of silently truncating or going negative. */
+int CheckedIntCast(int64_t value, const char* what) {
+  CHECK(value >= 0 && value <= std::numeric_limits<int>::max())
+      << what << " = " << value << " is out of range for int";
+  return static_cast<int>(value);
+}
+
+}  // namespace
+
+TVM_REGISTER_GLOBAL("cutlass.moe_gemm_f16f16")
+    .set_body_typed([](NDArray x, NDArray weight, NDArray total_rows_before_expert,
+                       int64_t total_rows, int64_t n, int64_t k, int64_t num_experts, NDArray out) {
+      // The buffer is reinterpreted as int64_t below, so reject any other element type.
+      CHECK(IsInt64(total_rows_before_expert)) << "total_rows_before_expert must be int64";
+      fastertransformer::moe_gemm_bias_act<half, half>(
+          reinterpret_cast<half*>(x->data), reinterpret_cast<half*>(weight->data), nullptr, nullptr,
+          reinterpret_cast<half*>(out->data),
+          reinterpret_cast<int64_t*>(total_rows_before_expert->data), total_rows, n, k,
+          CheckedIntCast(num_experts, "num_experts"), std::nullopt, GetCurrentCUDAStream());
+    });
+
+TVM_REGISTER_GLOBAL("cutlass.moe_gemm_s4f16")
+    .set_body_typed([](NDArray x, NDArray weight, NDArray scales, NDArray total_rows_before_expert,
+                       int64_t total_rows, int64_t n, int64_t k, int64_t num_experts, NDArray out) {
+      // The buffer is reinterpreted as int64_t below, so reject any other element type.
+      CHECK(IsInt64(total_rows_before_expert)) << "total_rows_before_expert must be int64";
+      fastertransformer::moe_gemm_bias_act<half, cutlass::uint4b_t>(
+          reinterpret_cast<half*>(x->data), reinterpret_cast<cutlass::uint4b_t*>(weight->data),
+          reinterpret_cast<half*>(scales->data), nullptr, reinterpret_cast<half*>(out->data),
+          reinterpret_cast<int64_t*>(total_rows_before_expert->data), total_rows, n, k,
+          CheckedIntCast(num_experts, "num_experts"), std::nullopt, GetCurrentCUDAStream());
+    });
+
+TVM_REGISTER_GLOBAL("moe_compute_rows_before")
+    .set_body_typed([](NDArray sorted_indices, NDArray total_rows_before_expert) {
+      CHECK(sorted_indices->dtype.code == kDLInt && sorted_indices->dtype.bits == 32);
+      CHECK(IsInt64(total_rows_before_expert));
+      CHECK(sorted_indices->ndim == 1);
+      CHECK(total_rows_before_expert->ndim == 1);
+
+      // The kernel wrapper takes int sizes; the NDArray extents are int64_t.
+      const int num_indices = CheckedIntCast(sorted_indices->shape[0], "number of sorted indices");
+      const int num_experts = CheckedIntCast(total_rows_before_expert->shape[0], "num_experts");
+      compute_total_rows_before_expert(
+          reinterpret_cast<int*>(sorted_indices->data), num_indices, num_experts,
+          reinterpret_cast<int64_t*>(total_rows_before_expert->data), GetCurrentCUDAStream());
+
+      // A kernel launch reports configuration errors only through cudaGetLastError().
+      const cudaError_t err = cudaGetLastError();
+      CHECK(err == cudaSuccess) << "compute_total_rows_before_expert failed to launch: "
+                                << cudaGetErrorString(err);
+    });
+
+}  // namespace runtime
+}  // namespace tvm
