This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git

commit c606a069c350e1ca77ce21145b94be4c35591d71
Author: Wuwei Lin <[email protected]>
AuthorDate: Sun Dec 10 21:16:44 2023 +0000

    kernel
---
 src/runtime/contrib/cutlass/moe_compute_rows.cu | 45 ++++++++++++
 src/runtime/contrib/cutlass/moe_gemm.cc         | 98 +++++++++++++++++++++++++
 2 files changed, 143 insertions(+)

diff --git a/src/runtime/contrib/cutlass/moe_compute_rows.cu 
b/src/runtime/contrib/cutlass/moe_compute_rows.cu
new file mode 100644
index 0000000000..07b408baff
--- /dev/null
+++ b/src/runtime/contrib/cutlass/moe_compute_rows.cu
@@ -0,0 +1,45 @@
+// 
https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/kernels/moe_kernels.cu
// Returns the number of entries in `sorted_indices` (ascending) that are
// <= `target` — i.e. the index std::upper_bound would report. Returns 0
// for an empty array or when every entry exceeds `target`.
__device__ inline int find_total_elts_leq_target(const int* sorted_indices,
                                                 const int arr_length,
                                                 const int target)
{
    // Half-open binary search: invariant is sorted_indices[0..lo) <= target
    // and sorted_indices[hi..arr_length) > target.
    int64_t lo = 0;
    int64_t hi = arr_length;
    while (lo < hi) {
        const int64_t mid = lo + (hi - lo) / 2;
        if (sorted_indices[mid] <= target) {
            lo = mid + 1;
        }
        else {
            hi = mid;
        }
    }
    return lo;
}
// One thread per expert. For each expert id e in [0, num_experts), writes
// into total_rows_before_expert[e] the count of entries in sorted_experts
// that are <= e, i.e. the exclusive row offset where expert e's rows end.
// Launch with at least num_experts total threads (1-D grid/block).
__global__ void compute_total_rows_before_expert_kernel(
    const int* sorted_experts, const int sorted_experts_len,
    const int64_t num_experts, int64_t* total_rows_before_expert)
{
    const int expert = blockIdx.x * blockDim.x + threadIdx.x;
    // Tail guard: the grid may carry more threads than experts.
    if (expert < num_experts) {
        total_rows_before_expert[expert] =
            find_total_elts_leq_target(sorted_experts, sorted_experts_len, expert);
    }
}
+
// Host-side launcher for compute_total_rows_before_expert_kernel.
//
// sorted_indices           : device array of expert ids, ascending, length total_indices.
// total_indices            : number of entries in sorted_indices.
// num_experts              : number of experts; one output slot per expert.
// total_rows_before_expert : device array of length num_experts; slot e receives
//                            the count of sorted_indices entries <= e.
// stream                   : CUDA stream for the async launch (nullptr = default stream).
void compute_total_rows_before_expert(const int*   sorted_indices,
                                      const int    total_indices,
                                      const int    num_experts,
                                      int64_t*     total_rows_before_expert,
                                      cudaStream_t stream)
{
    // Guard: num_experts == 0 would compute blocks == 0, which is an invalid
    // launch configuration; there is also nothing to write.
    if (num_experts <= 0) {
        return;
    }

    // One thread per expert, capped at the 1024-threads-per-block limit.
    // Written as a ternary rather than std::min because this file includes
    // no headers, so <algorithm> is not guaranteed to be in scope.
    const int threads = num_experts < 1024 ? num_experts : 1024;
    const int blocks  = (num_experts + threads - 1) / threads;  // ceil-div

    compute_total_rows_before_expert_kernel<<<blocks, threads, 0, stream>>>(
        sorted_indices, total_indices, num_experts, total_rows_before_expert);
}
diff --git a/src/runtime/contrib/cutlass/moe_gemm.cc 
b/src/runtime/contrib/cutlass/moe_gemm.cc
new file mode 100644
index 0000000000..a1796991c3
--- /dev/null
+++ b/src/runtime/contrib/cutlass/moe_gemm.cc
@@ -0,0 +1,98 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include <cuda.h>
+#include <cuda_fp16.h>
+#include <cuda_runtime.h>
+#include <tvm/runtime/ndarray.h>
+#include <tvm/runtime/packed_func.h>
+#include <tvm/runtime/registry.h>
+
+#include <optional>
+#include <string>
+
+#include 
"../../../3rdparty/cutlass_fpA_intB_gemm/cutlass/include/cutlass/half.h"
+// clang-format off
+// these headers can't be reordered
+#include 
"../../../3rdparty/cutlass_fpA_intB_gemm/cutlass/include/cutlass/numeric_types.h"
+#include 
"../../../3rdparty/cutlass_fpA_intB_gemm/cutlass/include/cutlass/integer_subbyte.h"
+// clang-format on
+#include 
"../../../3rdparty/cutlass_fpA_intB_gemm/cutlass_kernels/moe_gemm/moe_gemm_kernels.h"
+
// Defined in moe_compute_rows_before.cu: fills total_rows_before_expert[e]
// with the number of entries in sorted_indices (ascending expert ids) that
// are <= e, launching on `stream`.
void compute_total_rows_before_expert(const int* sorted_indices, const int total_indices,
                                      const int num_experts, int64_t* total_rows_before_expert,
                                      cudaStream_t stream);

namespace fastertransformer {

// Declared here, defined in the 3rdparty cutlass_fpA_intB_gemm library
// (moe_gemm_kernels.h): grouped GEMM over experts. Rows of A are partitioned
// per-expert via total_rows_before_expert; each expert's rows are multiplied
// by that expert's B slice. Presumably supports an optional fused bias and
// activation (nullptr / std::nullopt to disable) — confirm against the
// library's implementation.
template <typename T, typename WeightType>
void moe_gemm_bias_act(const T* A, const WeightType* B, const T* weight_scales, const T* biases,
                       T* C, int64_t* total_rows_before_expert, int64_t total_rows, int64_t gemm_n,
                       int64_t gemm_k, int num_experts, std::optional<std::string> activation,
                       cudaStream_t stream);
}  // namespace fastertransformer
+
+namespace tvm {
+namespace runtime {
+
+TVM_REGISTER_GLOBAL("cutlass.moe_gemm_f16f16")
+    .set_body_typed([](NDArray x, NDArray weight, NDArray 
total_rows_before_expert,
+                       int64_t total_rows, int64_t n, int64_t k, int64_t 
num_experts, NDArray out) {
+      LOG(INFO) << "GEMM MOE F16F16";
+      LOG(INFO) << "x: " << x->data << " weight: " << weight->data
+                << " total_rows_before_expert: " << 
total_rows_before_expert->data
+                << " total_rows: " << total_rows << " n: " << n << " k: " << k
+                << " num_experts: " << num_experts << " out: " << out->data;
+      //   using half = cutlass::half_t;
+      fastertransformer::moe_gemm_bias_act<half, half>(
+          reinterpret_cast<half*>(x->data), 
reinterpret_cast<half*>(weight->data), nullptr, nullptr,
+          reinterpret_cast<half*>(out->data),
+          reinterpret_cast<int64_t*>(total_rows_before_expert->data), 
total_rows, n, k, num_experts,
+          std::nullopt,
+          /*stream=*/nullptr /*FIXME*/);
+      LOG(INFO) << "MOE OK";
+    });
+
+TVM_REGISTER_GLOBAL("cutlass.moe_gemm_s4f16")
+    .set_body_typed([](NDArray x, NDArray weight, NDArray scales, NDArray 
total_rows_before_expert,
+                       int64_t total_rows, int64_t n, int64_t k, int64_t 
num_experts, NDArray out) {
+      fastertransformer::moe_gemm_bias_act<half, cutlass::uint4b_t>(
+          reinterpret_cast<half*>(x->data), 
reinterpret_cast<cutlass::uint4b_t*>(weight->data),
+          reinterpret_cast<half*>(scales->data), nullptr, 
reinterpret_cast<half*>(out->data),
+          reinterpret_cast<int64_t*>(total_rows_before_expert->data), 
total_rows, n, k, num_experts,
+          std::nullopt,
+          /*stream=*/nullptr /*FIXME*/);
+    });
+
+TVM_REGISTER_GLOBAL("moe_compute_rows_before")
+    .set_body_typed([](NDArray sorted_indices, NDArray 
total_rows_before_expert) {
+      CHECK(sorted_indices->dtype.code == kDLInt && sorted_indices->dtype.bits 
== 32);
+      CHECK(total_rows_before_expert->dtype.code == kDLInt &&
+            total_rows_before_expert->dtype.bits == 64);
+      CHECK(sorted_indices->ndim == 1);
+      CHECK(total_rows_before_expert->ndim == 1);
+
+      int num_experts = total_rows_before_expert->shape[0];
+      compute_total_rows_before_expert(
+          reinterpret_cast<int*>(sorted_indices->data), 
sorted_indices->shape[0], num_experts,
+          reinterpret_cast<int64_t*>(total_rows_before_expert->data), nullptr);
+    });
+
+}  // namespace runtime
+}  // namespace tvm

Reply via email to