This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new e54e04d520 [Unity][BYOC] Add cuBLAS backend (#14291)
e54e04d520 is described below

commit e54e04d520f889d9e528337830a56e607b233527
Author: masahi <[email protected]>
AuthorDate: Tue Apr 4 11:00:32 2023 +0900

    [Unity][BYOC] Add cuBLAS backend (#14291)
    
    * stub
    
    * fixed build
    
    * test stub
    
    * basic gemm working
    
    * transposed gemm work
    
    * wip
    
    * bias and epilogue work
    
    * support fp16 and transposed bias
    
    * support batched gemm
    
    * clean up
    
    * access arguments properly
    
    * expose ExtractArgIdx to python and use it in cutlass byoc
    
    * put matmul ir into common testing file
    
    * updated for the latest rev
    
    * pylint
---
 cmake/modules/CUDA.cmake                          |   4 +-
 python/tvm/contrib/cutlass/build.py               |  19 +--
 python/tvm/relax/backend/contrib/cublas.py        | 154 +++++++++++++++++++++
 python/tvm/relax/backend/contrib/cutlass.py       |  15 +--
 python/tvm/relax/testing/__init__.py              |   1 +
 python/tvm/relax/testing/matmul.py                |  66 +++++++++
 src/relax/backend/contrib/cublas/codegen.cc       | 110 +++++++++++++++
 src/relax/backend/contrib/utils.cc                |  68 ++++++++++
 src/relax/backend/contrib/utils.h                 |  13 ++
 src/runtime/contrib/cblas/gemm_common.h           |  16 ++-
 src/runtime/contrib/cublas/cublas.cc              | 118 +++++++++++++++-
 src/runtime/contrib/cublas/cublas_json_runtime.cc | 118 ++++++++++++++++
 src/runtime/contrib/cublas/cublas_utils.h         |   6 +
 tests/python/relax/test_codegen_cublas.py         | 156 ++++++++++++++++++++++
 tests/python/relax/test_codegen_cutlass.py        |  49 +------
 15 files changed, 825 insertions(+), 88 deletions(-)

diff --git a/cmake/modules/CUDA.cmake b/cmake/modules/CUDA.cmake
index 96d5922e84..1502c4f8bc 100644
--- a/cmake/modules/CUDA.cmake
+++ b/cmake/modules/CUDA.cmake
@@ -50,8 +50,8 @@ if(USE_CUDA)
 
   if(USE_CUBLAS)
     message(STATUS "Build with cuBLAS support")
-    tvm_file_glob(GLOB CUBLAS_RELAY_CONTRIB_SRC 
src/relay/backend/contrib/cublas/*.cc)
-    list(APPEND COMPILER_SRCS ${CUBLAS_RELAY_CONTRIB_SRC})
+    tvm_file_glob(GLOB CUBLAS_CONTRIB_SRC 
src/relay/backend/contrib/cublas/*.cc src/relax/backend/contrib/cublas/*.cc)
+    list(APPEND COMPILER_SRCS ${CUBLAS_CONTRIB_SRC})
     tvm_file_glob(GLOB CONTRIB_CUBLAS_SRCS src/runtime/contrib/cublas/*.cc)
     list(APPEND RUNTIME_SRCS ${CONTRIB_CUBLAS_SRCS})
     list(APPEND TVM_RUNTIME_LINKER_LIBS ${CUDA_CUBLAS_LIBRARY})
diff --git a/python/tvm/contrib/cutlass/build.py 
b/python/tvm/contrib/cutlass/build.py
index 93d1331ac4..43494991a0 100644
--- a/python/tvm/contrib/cutlass/build.py
+++ b/python/tvm/contrib/cutlass/build.py
@@ -560,22 +560,9 @@ def _extract_relax_function_signature(f):
 
 
 def _extract_arg_idx(pattern_name, f):
-    pattern_entry = relax.backend.get_pattern(pattern_name)
-    if pattern_entry is None:
-        raise ValueError(f"Unsupported op_type {pattern_name}")
-    var2val = relax.analysis.get_var2val(f)
-    matched_expr = pattern_entry.pattern.extract_matched_expr(f.body.body, 
var2val)
-
-    func_args = list(f.params)
-
-    arg_idx = {}
-    for name, annotation_pattern in pattern_entry.annotation_patterns.items():
-        arg_expr = matched_expr[annotation_pattern]
-        if arg_expr not in func_args:
-            continue
-        arg_idx[name] = func_args.index(arg_expr)
-
-    return arg_idx
+    extract_func = tvm.get_global_func("relax.contrib.extract_arg_idx")
+    arg_indices = extract_func(pattern_name, f)
+    return {k: int(v) for k, v in arg_indices.items()}
 
 
 def is_shape_valid_for_cutlass_matmul(
diff --git a/python/tvm/relax/backend/contrib/cublas.py 
b/python/tvm/relax/backend/contrib/cublas.py
new file mode 100644
index 0000000000..627c936993
--- /dev/null
+++ b/python/tvm/relax/backend/contrib/cublas.py
@@ -0,0 +1,154 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Pattern table for cuBLAS backend"""
+import operator
+from functools import reduce
+
+import tvm
+from tvm.relax import transform
+from tvm.relax.transform import PatternCheckContext
+
+from ..pattern_registry import get_patterns_with_prefix, register_patterns
+from ..patterns import make_matmul_pattern
+
+
+def _is_supported_dtype(lhs_dtype, rhs_dtype):
+    """Check if dtypes in the given workload are supported by cuBLAS BYOC."""
+    return (lhs_dtype == "float16" and rhs_dtype == "float16") or (
+        lhs_dtype == "float32" and rhs_dtype == "float32"
+    )
+
+
+def _check_matmul(context: PatternCheckContext) -> bool:
+    lhs = context.annotated_expr["lhs"]
+    rhs = context.annotated_expr["rhs"]
+
+    lhs_dtype = lhs.struct_info.dtype
+    rhs_dtype = rhs.struct_info.dtype
+    if not _is_supported_dtype(lhs_dtype, rhs_dtype):
+        return False
+
+    lhs_shape = lhs.struct_info.shape.values
+    rhs_shape = rhs.struct_info.shape.values
+
+    if not isinstance(lhs_shape[-1], (tvm.tir.expr.IntImm, int)):
+        # Reduction axis must be constant
+        return False
+
+    lhs_batches = reduce(operator.mul, lhs_shape[:-2], 1)
+    rhs_batches = reduce(operator.mul, rhs_shape[:-2], 1)
+
+    # cuBLASLt does not seem to support batched GEMM with one of matrices 
having
+    # one batch (with batch_stride 0). So for batched GEMM, the two batch 
counts
+    # must be equal.
+    return (
+        (lhs_batches == 1 and rhs_batches == 1)
+        or isinstance(lhs_batches, tvm.tir.Var)
+        or isinstance(rhs_batches, tvm.tir.Var)
+        or (int(lhs_batches) == int(rhs_batches))
+    )
+
+
+register_patterns(
+    [
+        (
+            "cublas.matmul",
+            *make_matmul_pattern(
+                with_bias=False,
+            ),
+            _check_matmul,
+        ),
+        (
+            "cublas.matmul_bias",
+            *make_matmul_pattern(
+                with_bias=True,
+            ),
+            _check_matmul,
+        ),
+        (
+            "cublas.matmul_bias_relu",
+            *make_matmul_pattern(
+                with_bias=True,
+                activation="relax.nn.relu",
+            ),
+            _check_matmul,
+        ),
+        (
+            "cublas.matmul_bias_gelu",
+            *make_matmul_pattern(
+                with_bias=True,
+                activation="relax.nn.gelu",
+            ),
+            _check_matmul,
+        ),
+        (
+            "cublas.matmul_transposed",
+            *make_matmul_pattern(
+                with_bias=False,
+                transposed_rhs=True,
+            ),
+            _check_matmul,
+        ),
+        (
+            "cublas.matmul_transposed_bias",
+            *make_matmul_pattern(
+                with_bias=True,
+                transposed_rhs=True,
+            ),
+            _check_matmul,
+        ),
+        (
+            "cublas.matmul_transposed_bias_relu",
+            *make_matmul_pattern(
+                with_bias=True,
+                activation="relax.nn.relu",
+                transposed_rhs=True,
+            ),
+            _check_matmul,
+        ),
+        (
+            "cublas.matmul_transposed_bias_gelu",
+            *make_matmul_pattern(
+                with_bias=True,
+                activation="relax.nn.gelu",
+                transposed_rhs=True,
+            ),
+            _check_matmul,
+        ),
+    ]
+)
+
+
+def partition_for_cublas(mod):
+    """
+    Partition the input module into cuBLAS-supported subgraphs.
+
+    Parameters
+    ----------
+    mod: tvm.IRModule
+        The IRModule to be partitioned.
+
+    Returns
+    -------
+    mod: tvm.IRModule
+        The resulting IRModule, containing partitioned subgraphs to be
+        offloaded to the cuBLAS backend.
+    """
+
+    patterns = get_patterns_with_prefix("cublas")
+    return transform.FuseOpsByPattern(patterns, bind_constants=False, 
annotate_codegen=True)(mod)
diff --git a/python/tvm/relax/backend/contrib/cutlass.py 
b/python/tvm/relax/backend/contrib/cutlass.py
index c03c913d63..856cd4d787 100644
--- a/python/tvm/relax/backend/contrib/cutlass.py
+++ b/python/tvm/relax/backend/contrib/cutlass.py
@@ -17,11 +17,10 @@
 
 """Pattern table for CUTLASS backend"""
 
-from typing import Mapping, Optional, Sequence, Tuple
+from typing import Mapping, Sequence
 
-import tvm
 from tvm.contrib.cutlass.build import is_shape_valid_for_cutlass_matmul
-from tvm.relax import DataflowVar, ShapeExpr, Var, transform
+from tvm.relax import DataflowVar, Var, transform
 from tvm.relax.transform import PatternCheckContext
 
 from ..pattern_registry import get_patterns_with_prefix, register_patterns
@@ -33,16 +32,6 @@ from ..patterns import (
 )
 
 
-def _get_static_shape(shape: ShapeExpr) -> Optional[Tuple[int]]:
-    result = []
-    for dim in shape.values:
-        if isinstance(dim, tvm.tir.expr.IntImm):
-            result.append(int(dim))
-        else:
-            return None
-    return result
-
-
 def _is_supported_dtype(lhs_dtype, rhs_dtype):
     """Check if dtypes in the given workload are supported by CUTLASS."""
     return (
diff --git a/python/tvm/relax/testing/__init__.py 
b/python/tvm/relax/testing/__init__.py
index a6e3a94251..4256ebc3be 100644
--- a/python/tvm/relax/testing/__init__.py
+++ b/python/tvm/relax/testing/__init__.py
@@ -20,3 +20,4 @@
 from .nn import *
 from .relay_translator import *
 from .ast_printer import dump_ast
+from .matmul import *
diff --git a/python/tvm/relax/testing/matmul.py 
b/python/tvm/relax/testing/matmul.py
new file mode 100644
index 0000000000..bac6fc6c9a
--- /dev/null
+++ b/python/tvm/relax/testing/matmul.py
@@ -0,0 +1,66 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Utilities to construct matmul workloads."""
+import tvm
+from tvm.script import relax as R
+from tvm.script.ir_builder import IRBuilder
+from tvm.script.ir_builder import relax as relax_builder
+
+
+def get_relax_matmul_module(
+    x_shape,
+    y_shape,
+    dtype,
+    transposed_y=False,
+    with_bias=False,
+    activation=None,
+    residual_bin_op=None,
+    residual_activation=None,
+):
+    """Create a matmul op followd by epilogue operations."""
+    if transposed_y:
+        n = y_shape[-2]
+    else:
+        n = y_shape[-1]
+
+    with IRBuilder() as builder:
+        with relax_builder.function():
+            R.func_name("main")
+            x = R.arg("x", R.Tensor(x_shape, dtype))
+            y = R.arg("y", R.Tensor(y_shape, dtype))
+            if with_bias:
+                bias = R.arg("bias", R.Tensor((n,), dtype))
+
+            with R.dataflow() as frame:
+                if transposed_y:
+                    axes = list(range(len(y_shape) - 2)) + [-1, -2]
+                    y = R.emit(R.permute_dims(y, axes=axes))
+                result = R.emit(R.matmul(x, y, out_dtype=dtype))
+                if with_bias:
+                    result = R.emit(result + bias)
+                if activation is not None:
+                    result = R.emit(activation(result))
+                if residual_bin_op is not None:
+                    result = R.emit(residual_bin_op(result, x))
+                    if residual_activation is not None:
+                        result = R.emit(residual_activation(result))
+                R.output(result)
+
+            R.func_ret_value(frame.output_vars[0])
+
+    func = builder.get()
+    return tvm.IRModule({"main": func})
diff --git a/src/relax/backend/contrib/cublas/codegen.cc 
b/src/relax/backend/contrib/cublas/codegen.cc
new file mode 100644
index 0000000000..e573d9a123
--- /dev/null
+++ b/src/relax/backend/contrib/cublas/codegen.cc
@@ -0,0 +1,110 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relax/backend/contrib/cublas/codegen.cc
+ * \brief Implementation of the CUBLAS JSON serializer.
+ */
+#include <tvm/ir/module.h>
+
+#include <string>
+
+#include "../codegen_json/codegen_json.h"
+#include "../utils.h"
+
+namespace tvm {
+namespace relax {
+namespace contrib {
+
+using JSONGraphNode = tvm::runtime::json::JSONGraphNode;
+using JSONGraphNodeEntry = tvm::runtime::json::JSONGraphNodeEntry;
+using JSONSerializer = backend::contrib::JSONSerializer;
+using backend::contrib::NodeEntries;
+
+class CublasJSONSerializer : public JSONSerializer {
+ public:
+  CublasJSONSerializer(Map<Constant, String> constant_names, Map<Var, Expr> 
bindings)
+      : JSONSerializer(constant_names), bindings_(bindings) {}
+
+  using JSONSerializer::VisitExpr_;
+
+  NodeEntries VisitExpr_(const CallNode* call_node) final {
+    const auto* fn_var = call_node->op.as<VarNode>();
+    ICHECK(fn_var);
+    const auto fn = Downcast<Function>(bindings_[GetRef<Var>(fn_var)]);
+    ICHECK(fn.defined()) << "Expects the callee to be a function.";
+
+    auto composite_opt = fn->GetAttr<String>(attr::kComposite);
+    ICHECK(composite_opt.defined()) << "Only composite functions are 
supported.";
+
+    std::string composite_name = composite_opt.value();
+
+    NodeEntries inputs_tmp;
+    for (const auto& arg : call_node->args) {
+      auto res = VisitExpr(arg);
+      inputs_tmp.insert(inputs_tmp.end(), res.begin(), res.end());
+    }
+
+    ICHECK(inputs_tmp.size() <= 3);
+    NodeEntries inputs(inputs_tmp.size());
+
+    auto arg_idx = backend::ExtractArgIdx(composite_name, fn);
+    inputs[0] = inputs_tmp[arg_idx["lhs"]->value];
+    inputs[1] = inputs_tmp[arg_idx["rhs"]->value];
+    if (inputs_tmp.size() == 3) {
+      inputs[2] = inputs_tmp[arg_idx["bias"]->value];
+    }
+
+    auto node = std::make_shared<JSONGraphNode>(composite_name, /* name_ */
+                                                "kernel",       /* op_type_ */
+                                                inputs, 1 /* num_outputs_ */);
+
+    const CallNode* root_call = backend::GetOpInFunction(fn, "relax.matmul");
+    SetCallNodeAttribute(node, root_call);
+    return AddNode(node, GetRef<Expr>(call_node));
+  }
+
+ private:
+  /*! \brief The bindings to look up composite functions. */
+  Map<Var, Expr> bindings_;
+};
+
+Array<runtime::Module> CublasCompiler(Array<Function> functions, Map<String, 
ObjectRef> /*unused*/,
+                                      Map<Constant, String> constant_names) {
+  Array<runtime::Module> compiled_functions;
+
+  for (const auto& func : functions) {
+    CublasJSONSerializer serializer(constant_names, AnalyzeVar2Value(func));
+    serializer.serialize(func);
+    auto graph_json = serializer.GetJSON();
+    auto constant_names = serializer.GetConstantNames();
+    const auto* pf = runtime::Registry::Get("runtime.CublasJSONRuntimeCreate");
+    ICHECK(pf != nullptr) << "Cannot find CUBLAS runtime module create 
function.";
+    auto func_name = GetExtSymbol(func);
+    compiled_functions.push_back((*pf)(func_name, graph_json, constant_names));
+  }
+
+  return compiled_functions;
+}
+
+TVM_REGISTER_GLOBAL("relax.ext.cublas").set_body_typed(CublasCompiler);
+
+}  // namespace contrib
+}  // namespace relax
+}  // namespace tvm
diff --git a/src/relax/backend/contrib/utils.cc 
b/src/relax/backend/contrib/utils.cc
new file mode 100644
index 0000000000..565b7769f0
--- /dev/null
+++ b/src/relax/backend/contrib/utils.cc
@@ -0,0 +1,68 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include "utils.h"
+
+#include <tvm/relax/analysis.h>
+#include <tvm/relax/dataflow_matcher.h>
+#include <tvm/relax/expr.h>
+
+#include <optional>
+
+#include "../pattern_registry.h"
+
+namespace tvm {
+namespace relax {
+namespace backend {
+
+Map<String, IntImm> ExtractArgIdx(String pattern_name, Function f) {
+  Map<String, IntImm> arg_idx;
+  auto pattern = backend::GetPattern(pattern_name);
+  ICHECK(pattern) << "Unsupported op_type " << pattern_name;
+
+  auto bindings = AnalyzeVar2Value(f);
+  auto inner_body = Downcast<SeqExpr>(f->body)->body;
+  auto matched_expr = relax::ExtractMatchedExpr(pattern.value()->pattern, 
inner_body, bindings);
+  ICHECK(matched_expr);
+
+  auto find_index = [](const Array<Var>& params, Var v) -> 
std::optional<size_t> {
+    for (size_t i = 0; i < params.size(); ++i) {
+      if (params[i] == v) {
+        return i;
+      }
+    }
+    return std::nullopt;
+  };
+
+  for (const auto& [name, pat] : pattern.value()->annotation_patterns) {
+    auto exp = matched_expr.value()[pat];
+    if (auto arg_var = exp.as<VarNode>()) {
+      if (auto idx = find_index(f->params, GetRef<Var>(arg_var))) {
+        arg_idx.Set(name, IntImm(DataType::Int(64), *idx));
+      }
+    }
+  }
+
+  return arg_idx;
+}
+
+TVM_REGISTER_GLOBAL("relax.contrib.extract_arg_idx").set_body_typed(ExtractArgIdx);
+
+}  // namespace backend
+}  // namespace relax
+}  // namespace tvm
diff --git a/src/relax/backend/contrib/utils.h 
b/src/relax/backend/contrib/utils.h
index 4190ad66b6..ee1240aaed 100644
--- a/src/relax/backend/contrib/utils.h
+++ b/src/relax/backend/contrib/utils.h
@@ -120,6 +120,19 @@ inline const CallNode* GetOpInFunction(Function f, const 
std::string& op_name) {
   return nullptr;
 }
 
+/*!
+ * \brief Extract indices of the argument patterns in the function parameter 
list.
+ * Each composite function pattern can register a mapping between variable 
names and the
+ * corresponding patterns. This function tells at which index a given parameter
+ * in the function pattern, identified by its name, appears in the partitioned 
function parameter
+ * list.
+ * \param pattern_name The name of the composite function pattern.
+ * \param f The function partitioned according to the function pattern.
+ * \return A mapping between variable pattern names and their positions in the 
partitioned
+ * function parameter list.
+ */
+Map<String, IntImm> ExtractArgIdx(String pattern_name, Function f);
+
 }  // namespace backend
 }  // namespace relax
 }  // namespace tvm
diff --git a/src/runtime/contrib/cblas/gemm_common.h 
b/src/runtime/contrib/cblas/gemm_common.h
index 4724b14bff..af073da9ba 100644
--- a/src/runtime/contrib/cblas/gemm_common.h
+++ b/src/runtime/contrib/cblas/gemm_common.h
@@ -35,7 +35,7 @@ namespace tvm {
 namespace contrib {
 
 using namespace runtime;
-inline int ColumnStride(DLTensor* tensor) {
+inline int ColumnStride(const DLTensor* tensor) {
   // If the tensor itself is transposed then it will have strides
   // backward from what we expect.  Regardless, the max of the strides
   // (the other stride is 1) is the column stride.
@@ -46,7 +46,7 @@ inline int ColumnStride(DLTensor* tensor) {
   }
 }
 
-inline int ElementStride(DLTensor* tensor) {
+inline int ElementStride(const DLTensor* tensor) {
   if (tensor->strides) {
     return std::min(tensor->strides[0], tensor->strides[1]);
   } else {
@@ -55,13 +55,17 @@ inline int ElementStride(DLTensor* tensor) {
 }
 
 // Reversed strides indicates an in-place transpose operation.
-inline bool IsInPlaceTransposed(DLTensor* tensor) {
+inline bool IsInPlaceTransposed(const DLTensor* tensor) {
   return tensor->strides && (tensor->strides[1] > tensor->strides[0]);
 }
 
-inline int RowCount(DLTensor* tensor, bool trans) { return tensor->shape[trans 
? 1 : 0]; }
+inline int RowCount(const DLTensor* tensor, bool trans, int batch_offset = 0) {
+  return tensor->shape[batch_offset + (trans ? 1 : 0)];
+}
 
-inline int ColumnCount(DLTensor* tensor, bool trans) { return 
tensor->shape[trans ? 0 : 1]; }
+inline int ColumnCount(const DLTensor* tensor, bool trans, int batch_offset = 
0) {
+  return tensor->shape[batch_offset + (trans ? 0 : 1)];
+}
 
 // Call a column major blas.  Note that data is stored in tvm as row
 // major, so this we switch the arguments.
@@ -159,7 +163,7 @@ inline int ColumnStride3D(DLTensor* tensor) {
     return tensor->shape[2];
   }
 }
-inline int ElementStride3D(DLTensor* tensor) {
+inline int ElementStride3D(const DLTensor* tensor) {
   if (tensor->strides) {
     return std::min(tensor->strides[1], tensor->strides[2]);
   } else {
diff --git a/src/runtime/contrib/cublas/cublas.cc 
b/src/runtime/contrib/cublas/cublas.cc
index ee0f50e349..b49f15008c 100644
--- a/src/runtime/contrib/cublas/cublas.cc
+++ b/src/runtime/contrib/cublas/cublas.cc
@@ -24,6 +24,7 @@
 #include <tvm/runtime/logging.h>
 #include <tvm/runtime/registry.h>
 
+#include "../../3rdparty/compiler-rt/builtin_fp16.h"
 #include "../cblas/gemm_common.h"
 #include "cublas_utils.h"
 
@@ -133,6 +134,120 @@ bool CheckMixPrecisionType(DLDataType in_dtype, 
DLDataType out_dtype, bool int_s
 int roundoff(int v, int d) { return (v + d - 1) / d * d; }
 
 #if CUDART_VERSION >= 10010
+
+void CallCublasLt(cublasLtHandle_t hdl, const DLTensor* A, const DLTensor* B, 
const DLTensor* bias,
+                  const DLTensor* C, bool transa, bool transb, 
cublasLtEpilogue_t epilogue) {
+  ICHECK(TypeEqual(A->dtype, B->dtype));
+  // Reversed strides indicates an in-place transpose operation.
+  transa = IsInPlaceTransposed(A) ? !transa : transa;
+  transb = IsInPlaceTransposed(B) ? !transb : transb;
+
+  auto compute_type = CUBLAS_COMPUTE_32F;
+  auto scale_type = CUDA_R_32F;
+  cudaDataType_t ab_type = CUDA_R_32F;
+  cudaDataType_t c_type = CUDA_R_32F;
+  float one_fp32 = 1.0;
+  float zero_fp32 = 0.0;
+  auto one_fp16 = __truncXfYf2__<float, uint32_t, 23, uint16_t, uint16_t, 
10>(1.0);
+  auto zero_fp16 = __truncXfYf2__<float, uint32_t, 23, uint16_t, uint16_t, 
10>(0.0);
+  void* alpha = &one_fp32;
+  void* beta = &zero_fp32;
+
+  if (A->dtype.bits == 16 && A->dtype.code == kDLFloat) {
+    ab_type = CUDA_R_16F;
+  }
+
+  if (C->dtype.bits == 16 && C->dtype.code == kDLFloat) {
+    c_type = CUDA_R_16F;
+    compute_type = CUBLAS_COMPUTE_16F;
+    scale_type = CUDA_R_16F;
+    alpha = &one_fp16;
+    beta = &zero_fp16;
+  }
+
+  cublasLtMatmulDesc_t op_desc;
+  cublasOperation_t op_transa = CUBLASBooleanToTranspose(transa);
+  cublasOperation_t op_transb = CUBLASBooleanToTranspose(transb);
+
+  CHECK_CUBLAS_ERROR(cublasLtMatmulDescCreate(&op_desc, compute_type, 
scale_type));
+  CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(op_desc, 
CUBLASLT_MATMUL_DESC_TRANSA,
+                                                    &op_transb, 
sizeof(op_transa)));
+  CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(op_desc, 
CUBLASLT_MATMUL_DESC_TRANSB,
+                                                    &op_transa, 
sizeof(op_transb)));
+
+  if (bias != nullptr) {
+    CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(op_desc, 
CUBLASLT_MATMUL_DESC_BIAS_POINTER,
+                                                      &bias->data, 
sizeof(float*)));
+  }
+
+  if (epilogue != CUBLASLT_EPILOGUE_DEFAULT) {
+    CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(op_desc, 
CUBLASLT_MATMUL_DESC_EPILOGUE,
+                                                      &epilogue, 
sizeof(epilogue)));
+  }
+
+  int batch_offset_A = A->ndim - 2;
+  int batch_offset_B = B->ndim - 2;
+
+  int M = ColumnCount(B, transb, batch_offset_B);
+  int N = RowCount(A, transa, batch_offset_A);
+  int K = ColumnCount(A, transa, batch_offset_A);
+
+  int lda = transb ? K : M;
+  int ldb = transa ? N : K;
+  int ldc = M;
+
+  cublasLtMatrixLayout_t A_desc, B_desc, C_desc;
+  CHECK_CUBLAS_ERROR(
+      cublasLtMatrixLayoutCreate(&A_desc, ab_type, !transb ? M : K, !transb ? 
K : M, lda));
+  CHECK_CUBLAS_ERROR(
+      cublasLtMatrixLayoutCreate(&B_desc, ab_type, !transa ? K : N, !transa ? 
N : K, ldb));
+  CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutCreate(&C_desc, c_type, M, N, ldc));
+
+  if (A->ndim != 2 || B->ndim != 2) {
+    auto get_batch_count = [](int64_t* shape, int batch_offset) {
+      int64_t count = 1;
+      for (int i = 0; i < batch_offset; ++i) {
+        count *= shape[i];
+      }
+      return count;
+    };
+    auto set_batch = [](cublasLtMatrixLayout_t mat_desc, int batch_count, 
int64_t batch_stride) {
+      CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute(
+          mat_desc, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch_count, 
sizeof(batch_count)));
+      CHECK_CUBLAS_ERROR(
+          cublasLtMatrixLayoutSetAttribute(mat_desc, 
CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET,
+                                           &batch_stride, 
sizeof(batch_stride)));
+    };
+
+    int batch_count_A = get_batch_count(A->shape, batch_offset_A);
+    int batch_count_B = get_batch_count(B->shape, batch_offset_B);
+    int batch_count_C = get_batch_count(C->shape, C->ndim - 2);
+    int64_t batch_stride_A = M * K;
+    int64_t batch_stride_B = K * N;
+    int64_t batch_stride_C = M * N;
+
+    // cuBLASLt does not seem to support batched GEMM with one of matrices 
having
+    // one batch (with batch_stride 0).
+    ICHECK_EQ(batch_count_A, batch_count_B);
+
+    set_batch(A_desc, batch_count_A, batch_stride_A);
+    set_batch(B_desc, batch_count_B, batch_stride_B);
+    set_batch(C_desc, batch_count_C, batch_stride_C);
+  }
+
+  auto A_data = static_cast<char*>(A->data) + A->byte_offset;
+  auto B_data = static_cast<char*>(B->data) + B->byte_offset;
+  auto C_data = static_cast<char*>(C->data) + C->byte_offset;
+
+  CHECK_CUBLAS_ERROR(cublasLtMatmul(hdl, op_desc, alpha, B_data, A_desc, 
A_data, B_desc, beta,
+                                    C_data, C_desc, C_data, C_desc, nullptr, 
nullptr, 0, nullptr));
+
+  cublasLtMatmulDescDestroy(op_desc);
+  cublasLtMatrixLayoutDestroy(A_desc);
+  cublasLtMatrixLayoutDestroy(B_desc);
+  cublasLtMatrixLayoutDestroy(C_desc);
+}
+
 inline void CallLtIgemm(TVMArgs args, TVMRetValue* ret, cublasLtHandle_t hdl) {
   DLTensor* A = args[0];
   DLTensor* B = args[1];
@@ -172,7 +287,6 @@ inline void CallLtIgemm(TVMArgs args, TVMRetValue* ret, 
cublasLtHandle_t hdl) {
   auto B_data = reinterpret_cast<void*>(static_cast<char*>(B->data) + 
B->byte_offset);
   auto C_data = reinterpret_cast<void*>(static_cast<char*>(C->data) + 
C->byte_offset);
 
-  cublasOperation_t opTranspose = CUBLAS_OP_T;
   cublasLtOrder_t order_COL32 = CUBLASLT_ORDER_COL32;
   cublasLtOrder_t order_COL4_4R2_8C = CUBLASLT_ORDER_COL4_4R2_8C;
   cublasLtMatmulDesc_t operationDesc = nullptr;
@@ -181,8 +295,6 @@ inline void CallLtIgemm(TVMArgs args, TVMRetValue* ret, 
cublasLtHandle_t hdl) {
 #else
   CHECK_CUBLAS_ERROR(cublasLtMatmulDescCreate(&operationDesc, CUDA_R_32I));
 #endif
-  CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(operationDesc, 
CUBLASLT_MATMUL_DESC_TRANSB,
-                                                    &opTranspose, 
sizeof(opTranspose)));
   cublasOperation_t opTransA = CUBLASBooleanToTranspose(transa);
   cublasOperation_t opTransB = CUBLASBooleanToTranspose(transb);
   CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(operationDesc, 
CUBLASLT_MATMUL_DESC_TRANSA,
diff --git a/src/runtime/contrib/cublas/cublas_json_runtime.cc 
b/src/runtime/contrib/cublas/cublas_json_runtime.cc
new file mode 100644
index 0000000000..8afccb2730
--- /dev/null
+++ b/src/runtime/contrib/cublas/cublas_json_runtime.cc
@@ -0,0 +1,118 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/runtime/contrib/cublas/cublas_json_runtime.cc
+ * \brief A simple JSON runtime for CUBLAS.
+ */
+
+#include <tvm/runtime/ndarray.h>
+#include <tvm/runtime/registry.h>
+
+#include <cstddef>
+#include <regex>
+#include <string>
+#include <vector>
+
+#include "../json/json_node.h"
+#include "../json/json_runtime.h"
+#include "cublas_utils.h"
+
+namespace tvm {
+namespace runtime {
+namespace contrib {
+
+using namespace tvm::runtime;
+using namespace tvm::runtime::json;
+
+/*!
+ * \brief JSON runtime that runs the cuBLASLt matmul kernels emitted by the
+ *        relax cuBLAS codegen.
+ *
+ * The op name of each "kernel" node selects the variant: a name containing
+ * "transposed" sets transb, and "relu" / "gelu" / "bias" pick the cuBLASLt
+ * epilogue (the relu and gelu epilogues used here are the *_BIAS variants).
+ */
+class CublasJSONRuntime : public JSONRuntimeBase {
+ public:
+  CublasJSONRuntime(const std::string& symbol_name, const std::string& graph_json,
+                    const Array<String> const_names)
+      : JSONRuntimeBase(symbol_name, graph_json, const_names) {}
+
+  /*!
+   * \brief Module type key. Must match the suffix of the deserializer registered as
+   * "runtime.module.loadbinary_cublas_json"; the inherited key would be "json".
+   */
+  const char* type_key() const override { return "cublas_json"; }
+
+  /*! \brief Nothing to prepare ahead of time; all tensors are read in Run(). */
+  void Init(const Array<NDArray>& consts) override {}
+
+  /*! \brief Execute every "kernel" node of the subgraph with cuBLASLt. */
+  void Run() override {
+    // TODO(masahi): Reuse the same handle across different subgraphs
+    cublasLtHandle_t handle;
+    cublasStatus_t status = cublasLtCreate(&handle);
+    ICHECK(status == CUBLAS_STATUS_SUCCESS)
+        << "cublasLtCreate failed with status " << static_cast<int>(status);
+    // Release the handle on every exit path, including when a kernel call throws.
+    struct HandleGuard {
+      cublasLtHandle_t handle;
+      ~HandleGuard() { cublasLtDestroy(handle); }
+    } guard{handle};
+
+    for (size_t i = 0; i < nodes_.size(); ++i) {
+      const auto& node = nodes_[i];
+      if (node.GetOpType() != "kernel") {
+        continue;
+      }
+      auto op_name = node.GetOpName();
+      // Use this node's own output entry rather than the subgraph's first output,
+      // so each kernel writes its own result when a subgraph has several kernels.
+      auto out_ptr = data_entry_[EntryID(i, 0)];
+
+      bool transa = false;
+      bool transb = op_name.find("transposed") != std::string::npos;
+
+      cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_DEFAULT;
+      if (op_name.find("relu") != std::string::npos) {
+        epilogue = CUBLASLT_EPILOGUE_RELU_BIAS;
+      } else if (op_name.find("gelu") != std::string::npos) {
+        epilogue = CUBLASLT_EPILOGUE_GELU_BIAS;
+      } else if (op_name.find("bias") != std::string::npos) {
+        epilogue = CUBLASLT_EPILOGUE_BIAS;
+      }
+
+      // Every non-default epilogue chosen above consumes a bias, read from input 2.
+      const bool has_bias = epilogue != CUBLASLT_EPILOGUE_DEFAULT;
+      const DLTensor* a_ptr = GetInput(node, 0);
+      const DLTensor* b_ptr = GetInput(node, 1);
+      const DLTensor* bias_ptr = has_bias ? GetInput(node, 2) : nullptr;
+
+      tvm::contrib::CallCublasLt(handle, a_ptr, b_ptr, bias_ptr, out_ptr, transa, transb,
+                                 epilogue);
+    }
+  }
+
+ private:
+  /*! \brief Return the tensor bound to input \p idx of \p node. */
+  const DLTensor* GetInput(const JSONGraphNode& node, const int idx) {
+    ICHECK_LT(static_cast<size_t>(idx), node.GetInputs().size());
+    auto eid = EntryID(node.GetInputs()[idx]);
+    ICHECK_LT(eid, data_entry_.size());
+    return data_entry_[eid];
+  }
+};
+
+/*! \brief Build a cuBLAS JSON runtime module from a symbol name, graph JSON and constant names. */
+runtime::Module CublasJSONRuntimeCreate(String symbol_name, String graph_json,
+                                        const Array<String>& const_names) {
+  return runtime::Module(make_object<CublasJSONRuntime>(symbol_name, graph_json, const_names));
+}
+
+// Factory exposed as a global packed function.
+TVM_REGISTER_GLOBAL("runtime.CublasJSONRuntimeCreate").set_body_typed(CublasJSONRuntimeCreate);
+
+// Binary deserializer for this runtime (name suffix is the module's type key; TODO confirm).
+TVM_REGISTER_GLOBAL("runtime.module.loadbinary_cublas_json")
+    .set_body_typed(JSONRuntimeBase::LoadFromBinary<CublasJSONRuntime>);
+
+}  // namespace contrib
+}  // namespace runtime
+}  // namespace tvm
diff --git a/src/runtime/contrib/cublas/cublas_utils.h 
b/src/runtime/contrib/cublas/cublas_utils.h
index 62863b8f7b..ac03b12103 100644
--- a/src/runtime/contrib/cublas/cublas_utils.h
+++ b/src/runtime/contrib/cublas/cublas_utils.h
@@ -104,6 +104,12 @@ inline cudaDataType_t GetCudaDataType(DLDataType type) {
   }
   LOG(FATAL) << "Unsupported cuda type";
 }
+
+/*!
+ * \brief Execute matrix multiply followed by the specified epilogue, using cuBLASLt.
+ * \param hdl The cuBLASLt handle used to issue the call.
+ * \param A The first matmul operand.
+ * \param B The second matmul operand.
+ * \param bias The bias tensor consumed by bias epilogues. The cuBLAS JSON runtime
+ *        passes nullptr when the epilogue is CUBLASLT_EPILOGUE_DEFAULT.
+ * \param C The output tensor.
+ * \param transa Whether A is treated as transposed.
+ * \param transb Whether B is treated as transposed.
+ * \param epilogue The cuBLASLt epilogue fused after the matmul.
+ */
+void CallCublasLt(cublasLtHandle_t hdl, const DLTensor* A, const DLTensor* B, const DLTensor* bias,
+                  const DLTensor* C, bool transa, bool transb,
+                  cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_DEFAULT);
+
 }  // namespace contrib
 }  // namespace tvm
 
diff --git a/tests/python/relax/test_codegen_cublas.py 
b/tests/python/relax/test_codegen_cublas.py
new file mode 100644
index 0000000000..023054256e
--- /dev/null
+++ b/tests/python/relax/test_codegen_cublas.py
@@ -0,0 +1,156 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import numpy as np
+import pytest
+
+import tvm
+import tvm.testing
+import tvm.topi.testing
+from tvm import relax
+from tvm.relax.backend.contrib.cublas import partition_for_cublas
+from tvm.relax.testing import get_relax_matmul_module
+from tvm.script import relax as R
+
+
[email protected](autouse=True)
+def reset_seed():
+    np.random.seed(0)
+
+
+has_cublas = tvm.get_global_func("relax.ext.cublas", True)
+
+cublas_enabled = pytest.mark.skipif(
+    not has_cublas,
+    reason="CUBLAS not enabled.",
+)
+
+pytestmark = [cublas_enabled]
+
+
+def build_and_run(mod, inputs_np, target, legalize=False):
+    if legalize:
+        mod = relax.transform.LegalizeOps()(mod)
+
+    dev = tvm.device(target, 0)
+    ex = relax.build(mod, target)
+    vm = relax.VirtualMachine(ex, dev)
+    f = vm["main"]
+    inputs = [tvm.nd.array(inp, dev) for inp in inputs_np]
+    return f(*inputs).numpy()
+
+
+def get_result_with_relax_cublas_offload(mod, *args):
+    mod = partition_for_cublas(mod)
+    mod = relax.transform.RunCodegen()(mod)
+
+    return build_and_run(mod, args, "cuda")
+
+
+def _to_concrete_shape(symbolic_shape, var_table):
+    result = []
+    for dim in symbolic_shape:
+        if not isinstance(dim, tvm.tir.expr.Var):
+            result.append(dim)
+            continue
+
+        if dim not in var_table:
+            var_table[dim] = np.random.randint(10, 50)
+        result.append(var_table[dim])
+
+    return tuple(result)
+
+
+_vars = {
+    "a": tvm.tir.expr.Var("a", "int64"),
+    "b": tvm.tir.expr.Var("b", "int64"),
+}
+
+
+_epilogue_table = {
+    "none": (False, None),
+    "bias": (True, None),
+    "relu": (True, R.nn.relu),
+    "gelu": (True, R.nn.gelu),
+}
+
+
[email protected](
+    "x_shape, y_shape, transpose_y, epilogue",
+    [
+        # Regular
+        ((8, 8), (8, 8), False, "none"),
+        ((_vars["a"], 6), (6, 16), False, "bias"),
+        # Transposed
+        ((4, 16), (16, 128), True, "relu"),
+        ((35, 8), (8, 8), True, "gelu"),
+        # # 3D x 3D
+        ((6, 32, 8), (6, 8, 10), False, "bias"),
+        ((6, 32, 8), (6, 8, 10), True, "none"),
+        ((_vars["a"], 32, 8), (_vars["a"], 8, 10), True, "gelu"),
+        # ND x ND
+        ((5, 3, 32, 8), (5, 3, 8, 10), True, "relu"),
+    ],
+)
[email protected](
+    "dtype",
+    [
+        "float16",
+        "float32",
+    ],
+)
+def test_matmul_offload(
+    x_shape,
+    y_shape,
+    transpose_y,
+    epilogue,
+    dtype,
+):
+    with_bias, activation = _epilogue_table[epilogue]
+    var_table = {}
+    concrete_x_shape = _to_concrete_shape(x_shape, var_table)
+    concrete_y_shape = _to_concrete_shape(y_shape, var_table)
+    x = np.random.randn(*concrete_x_shape).astype(dtype)
+    y = np.random.randn(*concrete_y_shape).astype(dtype)
+
+    if transpose_y:
+        y = np.swapaxes(y, -2, -1)
+        y_shape = (*y_shape[:-2], y_shape[-1], y_shape[-2])
+
+    if with_bias:
+        bias = np.random.randn(concrete_y_shape[-1]).astype(dtype)
+        args = (x, y, bias)
+    else:
+        bias = None
+        args = (x, y)
+
+    mod = get_relax_matmul_module(
+        x_shape,
+        y_shape,
+        dtype,
+        with_bias=with_bias,
+        transposed_y=transpose_y,
+        activation=activation,
+    )
+
+    out = get_result_with_relax_cublas_offload(mod, *args)
+    ref = build_and_run(mod, args, "llvm", legalize=True)
+
+    tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2)
+
+
+if __name__ == "__main__":
+    tvm.testing.main()
diff --git a/tests/python/relax/test_codegen_cutlass.py 
b/tests/python/relax/test_codegen_cutlass.py
index c8ca44311d..b9ba4f4dc9 100644
--- a/tests/python/relax/test_codegen_cutlass.py
+++ b/tests/python/relax/test_codegen_cutlass.py
@@ -23,8 +23,8 @@ import tvm.topi.testing
 from tvm import relax
 from tvm.contrib.cutlass.build import is_shape_valid_for_cutlass_matmul
 from tvm.contrib.pickle_memoize import memoize
-from tvm.relax.backend import get_patterns_with_prefix
 from tvm.relax.backend.contrib.cutlass import partition_for_cutlass
+from tvm.relax.testing import get_relax_matmul_module
 from tvm.script import relax as R
 from tvm.script.ir_builder import IRBuilder
 from tvm.script.ir_builder import relax as relax_builder
@@ -96,9 +96,6 @@ def build_and_run(mod, inputs_np, target, legalize=False):
 
 
 def get_result_with_relax_cutlass_offload(mod, *args, 
assert_all_bindings_fused=True):
-    patterns = [(entry.name, entry.pattern) for entry in 
get_patterns_with_prefix("cutlass")]
-    assert len(patterns) != 0, "Cannot find cutlass patterns"
-
     mod = partition_for_cutlass(mod)
 
     if assert_all_bindings_fused:
@@ -168,50 +165,6 @@ def get_relax_conv2d_module(
     return tvm.IRModule({"main": func})
 
 
-def get_relax_matmul_module(
-    x_shape,
-    y_shape,
-    dtype,
-    transposed_y=False,
-    with_bias=False,
-    activation=None,
-    residual_bin_op=None,
-    residual_activation=None,
-):
-    if transposed_y:
-        n = y_shape[-2]
-    else:
-        n = y_shape[-1]
-
-    with IRBuilder() as builder:
-        with relax_builder.function():
-            R.func_name("main")
-            x = R.arg("x", R.Tensor(x_shape, dtype))
-            y = R.arg("y", R.Tensor(y_shape, dtype))
-            if with_bias:
-                bias = R.arg("bias", R.Tensor((n,), dtype))
-
-            with R.dataflow() as frame:
-                if transposed_y:
-                    axes = list(range(len(y_shape) - 2)) + [-1, -2]
-                    y = R.emit(R.permute_dims(y, axes=axes))
-                result = R.emit(R.matmul(x, y, out_dtype=dtype))
-                if with_bias:
-                    result = R.emit(result + bias)
-                if activation is not None:
-                    result = R.emit(activation(result))
-                if residual_bin_op is not None:
-                    result = R.emit(residual_bin_op(result, x))
-                    if residual_activation is not None:
-                        result = R.emit(residual_activation(result))
-                R.output(result)
-
-            R.func_ret_value(frame.output_vars[0])
-
-    func = builder.get()
-    return tvm.IRModule({"main": func})
-
-
 def _to_concrete_shape(symbolic_shape, var_table):
     result = []
     for dim in symbolic_shape:


Reply via email to