This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new e54e04d520 [Unity][BYOC] Add cuBLAS backend (#14291)
e54e04d520 is described below
commit e54e04d520f889d9e528337830a56e607b233527
Author: masahi <[email protected]>
AuthorDate: Tue Apr 4 11:00:32 2023 +0900
[Unity][BYOC] Add cuBLAS backend (#14291)
* stub
* fixed build
* test stub
* basic gemm working
* transposed gemm work
* wip
* bias and epilogue work
* support fp16 and transposed bias
* support batched gemm
* clean up
* access arguments properly
* expose ExtractArgIdx to python and use it in cutlass byoc
* put matmul ir into common testing file
* updated for the latest rev
* pylint
---
cmake/modules/CUDA.cmake | 4 +-
python/tvm/contrib/cutlass/build.py | 19 +--
python/tvm/relax/backend/contrib/cublas.py | 154 +++++++++++++++++++++
python/tvm/relax/backend/contrib/cutlass.py | 15 +--
python/tvm/relax/testing/__init__.py | 1 +
python/tvm/relax/testing/matmul.py | 66 +++++++++
src/relax/backend/contrib/cublas/codegen.cc | 110 +++++++++++++++
src/relax/backend/contrib/utils.cc | 68 ++++++++++
src/relax/backend/contrib/utils.h | 13 ++
src/runtime/contrib/cblas/gemm_common.h | 16 ++-
src/runtime/contrib/cublas/cublas.cc | 118 +++++++++++++++-
src/runtime/contrib/cublas/cublas_json_runtime.cc | 118 ++++++++++++++++
src/runtime/contrib/cublas/cublas_utils.h | 6 +
tests/python/relax/test_codegen_cublas.py | 156 ++++++++++++++++++++++
tests/python/relax/test_codegen_cutlass.py | 49 +------
15 files changed, 825 insertions(+), 88 deletions(-)
diff --git a/cmake/modules/CUDA.cmake b/cmake/modules/CUDA.cmake
index 96d5922e84..1502c4f8bc 100644
--- a/cmake/modules/CUDA.cmake
+++ b/cmake/modules/CUDA.cmake
@@ -50,8 +50,8 @@ if(USE_CUDA)
if(USE_CUBLAS)
message(STATUS "Build with cuBLAS support")
- tvm_file_glob(GLOB CUBLAS_RELAY_CONTRIB_SRC
src/relay/backend/contrib/cublas/*.cc)
- list(APPEND COMPILER_SRCS ${CUBLAS_RELAY_CONTRIB_SRC})
+ tvm_file_glob(GLOB CUBLAS_CONTRIB_SRC
src/relay/backend/contrib/cublas/*.cc src/relax/backend/contrib/cublas/*.cc)
+ list(APPEND COMPILER_SRCS ${CUBLAS_CONTRIB_SRC})
tvm_file_glob(GLOB CONTRIB_CUBLAS_SRCS src/runtime/contrib/cublas/*.cc)
list(APPEND RUNTIME_SRCS ${CONTRIB_CUBLAS_SRCS})
list(APPEND TVM_RUNTIME_LINKER_LIBS ${CUDA_CUBLAS_LIBRARY})
diff --git a/python/tvm/contrib/cutlass/build.py
b/python/tvm/contrib/cutlass/build.py
index 93d1331ac4..43494991a0 100644
--- a/python/tvm/contrib/cutlass/build.py
+++ b/python/tvm/contrib/cutlass/build.py
@@ -560,22 +560,9 @@ def _extract_relax_function_signature(f):
def _extract_arg_idx(pattern_name, f):
- pattern_entry = relax.backend.get_pattern(pattern_name)
- if pattern_entry is None:
- raise ValueError(f"Unsupported op_type {pattern_name}")
- var2val = relax.analysis.get_var2val(f)
- matched_expr = pattern_entry.pattern.extract_matched_expr(f.body.body,
var2val)
-
- func_args = list(f.params)
-
- arg_idx = {}
- for name, annotation_pattern in pattern_entry.annotation_patterns.items():
- arg_expr = matched_expr[annotation_pattern]
- if arg_expr not in func_args:
- continue
- arg_idx[name] = func_args.index(arg_expr)
-
- return arg_idx
+ extract_func = tvm.get_global_func("relax.contrib.extract_arg_idx")
+ arg_indices = extract_func(pattern_name, f)
+ return {k: int(v) for k, v in arg_indices.items()}
def is_shape_valid_for_cutlass_matmul(
diff --git a/python/tvm/relax/backend/contrib/cublas.py
b/python/tvm/relax/backend/contrib/cublas.py
new file mode 100644
index 0000000000..627c936993
--- /dev/null
+++ b/python/tvm/relax/backend/contrib/cublas.py
@@ -0,0 +1,154 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Pattern table for cuBLAS backend"""
+import operator
+from functools import reduce
+
+import tvm
+from tvm.relax import transform
+from tvm.relax.transform import PatternCheckContext
+
+from ..pattern_registry import get_patterns_with_prefix, register_patterns
+from ..patterns import make_matmul_pattern
+
+
+def _is_supported_dtype(lhs_dtype, rhs_dtype):
+ """Check if dtypes in the given workload are supported by cuBLAS BYOC."""
+ return (lhs_dtype == "float16" and rhs_dtype == "float16") or (
+ lhs_dtype == "float32" and rhs_dtype == "float32"
+ )
+
+
+def _check_matmul(context: PatternCheckContext) -> bool:
+ lhs = context.annotated_expr["lhs"]
+ rhs = context.annotated_expr["rhs"]
+
+ lhs_dtype = lhs.struct_info.dtype
+ rhs_dtype = rhs.struct_info.dtype
+ if not _is_supported_dtype(lhs_dtype, rhs_dtype):
+ return False
+
+ lhs_shape = lhs.struct_info.shape.values
+ rhs_shape = rhs.struct_info.shape.values
+
+ if not isinstance(lhs_shape[-1], (tvm.tir.expr.IntImm, int)):
+ # Reduction axis must be constant
+ return False
+
+ lhs_batches = reduce(operator.mul, lhs_shape[:-2], 1)
+ rhs_batches = reduce(operator.mul, rhs_shape[:-2], 1)
+
+    # cuBLASLt does not seem to support batched GEMM with one of the matrices having
+    # one batch (with batch_stride 0). So for batched GEMM, the two batch counts
+    # must be equal.
+ return (
+ (lhs_batches == 1 and rhs_batches == 1)
+ or isinstance(lhs_batches, tvm.tir.Var)
+ or isinstance(rhs_batches, tvm.tir.Var)
+ or (int(lhs_batches) == int(rhs_batches))
+ )
+
+
+register_patterns(
+ [
+ (
+ "cublas.matmul",
+ *make_matmul_pattern(
+ with_bias=False,
+ ),
+ _check_matmul,
+ ),
+ (
+ "cublas.matmul_bias",
+ *make_matmul_pattern(
+ with_bias=True,
+ ),
+ _check_matmul,
+ ),
+ (
+ "cublas.matmul_bias_relu",
+ *make_matmul_pattern(
+ with_bias=True,
+ activation="relax.nn.relu",
+ ),
+ _check_matmul,
+ ),
+ (
+ "cublas.matmul_bias_gelu",
+ *make_matmul_pattern(
+ with_bias=True,
+ activation="relax.nn.gelu",
+ ),
+ _check_matmul,
+ ),
+ (
+ "cublas.matmul_transposed",
+ *make_matmul_pattern(
+ with_bias=False,
+ transposed_rhs=True,
+ ),
+ _check_matmul,
+ ),
+ (
+ "cublas.matmul_transposed_bias",
+ *make_matmul_pattern(
+ with_bias=True,
+ transposed_rhs=True,
+ ),
+ _check_matmul,
+ ),
+ (
+ "cublas.matmul_transposed_bias_relu",
+ *make_matmul_pattern(
+ with_bias=True,
+ activation="relax.nn.relu",
+ transposed_rhs=True,
+ ),
+ _check_matmul,
+ ),
+ (
+ "cublas.matmul_transposed_bias_gelu",
+ *make_matmul_pattern(
+ with_bias=True,
+ activation="relax.nn.gelu",
+ transposed_rhs=True,
+ ),
+ _check_matmul,
+ ),
+ ]
+)
+
+
+def partition_for_cublas(mod):
+ """
+ Partition the input module into cuBLAS-supported subgraphs.
+
+ Parameters
+ ----------
+ mod: tvm.IRModule
+ The IRModule to be partitioned.
+
+ Returns
+ -------
+ mod: tvm.IRModule
+ The resulting IRModule, containing partitioned subgraphs to be
+ offloaded to the cuBLAS backend.
+ """
+
+ patterns = get_patterns_with_prefix("cublas")
+ return transform.FuseOpsByPattern(patterns, bind_constants=False,
annotate_codegen=True)(mod)
diff --git a/python/tvm/relax/backend/contrib/cutlass.py
b/python/tvm/relax/backend/contrib/cutlass.py
index c03c913d63..856cd4d787 100644
--- a/python/tvm/relax/backend/contrib/cutlass.py
+++ b/python/tvm/relax/backend/contrib/cutlass.py
@@ -17,11 +17,10 @@
"""Pattern table for CUTLASS backend"""
-from typing import Mapping, Optional, Sequence, Tuple
+from typing import Mapping, Sequence
-import tvm
from tvm.contrib.cutlass.build import is_shape_valid_for_cutlass_matmul
-from tvm.relax import DataflowVar, ShapeExpr, Var, transform
+from tvm.relax import DataflowVar, Var, transform
from tvm.relax.transform import PatternCheckContext
from ..pattern_registry import get_patterns_with_prefix, register_patterns
@@ -33,16 +32,6 @@ from ..patterns import (
)
-def _get_static_shape(shape: ShapeExpr) -> Optional[Tuple[int]]:
- result = []
- for dim in shape.values:
- if isinstance(dim, tvm.tir.expr.IntImm):
- result.append(int(dim))
- else:
- return None
- return result
-
-
def _is_supported_dtype(lhs_dtype, rhs_dtype):
"""Check if dtypes in the given workload are supported by CUTLASS."""
return (
diff --git a/python/tvm/relax/testing/__init__.py
b/python/tvm/relax/testing/__init__.py
index a6e3a94251..4256ebc3be 100644
--- a/python/tvm/relax/testing/__init__.py
+++ b/python/tvm/relax/testing/__init__.py
@@ -20,3 +20,4 @@
from .nn import *
from .relay_translator import *
from .ast_printer import dump_ast
+from .matmul import *
diff --git a/python/tvm/relax/testing/matmul.py
b/python/tvm/relax/testing/matmul.py
new file mode 100644
index 0000000000..bac6fc6c9a
--- /dev/null
+++ b/python/tvm/relax/testing/matmul.py
@@ -0,0 +1,66 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Utilities to construct matmul workloads."""
+import tvm
+from tvm.script import relax as R
+from tvm.script.ir_builder import IRBuilder
+from tvm.script.ir_builder import relax as relax_builder
+
+
+def get_relax_matmul_module(
+ x_shape,
+ y_shape,
+ dtype,
+ transposed_y=False,
+ with_bias=False,
+ activation=None,
+ residual_bin_op=None,
+ residual_activation=None,
+):
+ """Create a matmul op followd by epilogue operations."""
+ if transposed_y:
+ n = y_shape[-2]
+ else:
+ n = y_shape[-1]
+
+ with IRBuilder() as builder:
+ with relax_builder.function():
+ R.func_name("main")
+ x = R.arg("x", R.Tensor(x_shape, dtype))
+ y = R.arg("y", R.Tensor(y_shape, dtype))
+ if with_bias:
+ bias = R.arg("bias", R.Tensor((n,), dtype))
+
+ with R.dataflow() as frame:
+ if transposed_y:
+ axes = list(range(len(y_shape) - 2)) + [-1, -2]
+ y = R.emit(R.permute_dims(y, axes=axes))
+ result = R.emit(R.matmul(x, y, out_dtype=dtype))
+ if with_bias:
+ result = R.emit(result + bias)
+ if activation is not None:
+ result = R.emit(activation(result))
+ if residual_bin_op is not None:
+ result = R.emit(residual_bin_op(result, x))
+ if residual_activation is not None:
+ result = R.emit(residual_activation(result))
+ R.output(result)
+
+ R.func_ret_value(frame.output_vars[0])
+
+ func = builder.get()
+ return tvm.IRModule({"main": func})
diff --git a/src/relax/backend/contrib/cublas/codegen.cc
b/src/relax/backend/contrib/cublas/codegen.cc
new file mode 100644
index 0000000000..e573d9a123
--- /dev/null
+++ b/src/relax/backend/contrib/cublas/codegen.cc
@@ -0,0 +1,110 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relax/backend/contrib/cublas/codegen.cc
+ * \brief Implementation of the CUBLAS JSON serializer.
+ */
+#include <tvm/ir/module.h>
+
+#include <string>
+
+#include "../codegen_json/codegen_json.h"
+#include "../utils.h"
+
+namespace tvm {
+namespace relax {
+namespace contrib {
+
+using JSONGraphNode = tvm::runtime::json::JSONGraphNode;
+using JSONGraphNodeEntry = tvm::runtime::json::JSONGraphNodeEntry;
+using JSONSerializer = backend::contrib::JSONSerializer;
+using backend::contrib::NodeEntries;
+
+class CublasJSONSerializer : public JSONSerializer {
+ public:
+ CublasJSONSerializer(Map<Constant, String> constant_names, Map<Var, Expr>
bindings)
+ : JSONSerializer(constant_names), bindings_(bindings) {}
+
+ using JSONSerializer::VisitExpr_;
+
+ NodeEntries VisitExpr_(const CallNode* call_node) final {
+ const auto* fn_var = call_node->op.as<VarNode>();
+ ICHECK(fn_var);
+ const auto fn = Downcast<Function>(bindings_[GetRef<Var>(fn_var)]);
+ ICHECK(fn.defined()) << "Expects the callee to be a function.";
+
+ auto composite_opt = fn->GetAttr<String>(attr::kComposite);
+ ICHECK(composite_opt.defined()) << "Only composite functions are
supported.";
+
+ std::string composite_name = composite_opt.value();
+
+ NodeEntries inputs_tmp;
+ for (const auto& arg : call_node->args) {
+ auto res = VisitExpr(arg);
+ inputs_tmp.insert(inputs_tmp.end(), res.begin(), res.end());
+ }
+
+ ICHECK(inputs_tmp.size() <= 3);
+ NodeEntries inputs(inputs_tmp.size());
+
+ auto arg_idx = backend::ExtractArgIdx(composite_name, fn);
+ inputs[0] = inputs_tmp[arg_idx["lhs"]->value];
+ inputs[1] = inputs_tmp[arg_idx["rhs"]->value];
+ if (inputs_tmp.size() == 3) {
+ inputs[2] = inputs_tmp[arg_idx["bias"]->value];
+ }
+
+ auto node = std::make_shared<JSONGraphNode>(composite_name, /* name_ */
+ "kernel", /* op_type_ */
+ inputs, 1 /* num_outputs_ */);
+
+ const CallNode* root_call = backend::GetOpInFunction(fn, "relax.matmul");
+ SetCallNodeAttribute(node, root_call);
+ return AddNode(node, GetRef<Expr>(call_node));
+ }
+
+ private:
+ /*! \brief The bindings to look up composite functions. */
+ Map<Var, Expr> bindings_;
+};
+
+Array<runtime::Module> CublasCompiler(Array<Function> functions, Map<String,
ObjectRef> /*unused*/,
+ Map<Constant, String> constant_names) {
+ Array<runtime::Module> compiled_functions;
+
+ for (const auto& func : functions) {
+ CublasJSONSerializer serializer(constant_names, AnalyzeVar2Value(func));
+ serializer.serialize(func);
+ auto graph_json = serializer.GetJSON();
+ auto constant_names = serializer.GetConstantNames();
+ const auto* pf = runtime::Registry::Get("runtime.CublasJSONRuntimeCreate");
+ ICHECK(pf != nullptr) << "Cannot find CUBLAS runtime module create
function.";
+ auto func_name = GetExtSymbol(func);
+ compiled_functions.push_back((*pf)(func_name, graph_json, constant_names));
+ }
+
+ return compiled_functions;
+}
+
+TVM_REGISTER_GLOBAL("relax.ext.cublas").set_body_typed(CublasCompiler);
+
+} // namespace contrib
+} // namespace relax
+} // namespace tvm
diff --git a/src/relax/backend/contrib/utils.cc
b/src/relax/backend/contrib/utils.cc
new file mode 100644
index 0000000000..565b7769f0
--- /dev/null
+++ b/src/relax/backend/contrib/utils.cc
@@ -0,0 +1,68 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include "utils.h"
+
+#include <tvm/relax/analysis.h>
+#include <tvm/relax/dataflow_matcher.h>
+#include <tvm/relax/expr.h>
+
+#include <optional>
+
+#include "../pattern_registry.h"
+
+namespace tvm {
+namespace relax {
+namespace backend {
+
+Map<String, IntImm> ExtractArgIdx(String pattern_name, Function f) {
+ Map<String, IntImm> arg_idx;
+ auto pattern = backend::GetPattern(pattern_name);
+ ICHECK(pattern) << "Unsupported op_type " << pattern_name;
+
+ auto bindings = AnalyzeVar2Value(f);
+ auto inner_body = Downcast<SeqExpr>(f->body)->body;
+ auto matched_expr = relax::ExtractMatchedExpr(pattern.value()->pattern,
inner_body, bindings);
+ ICHECK(matched_expr);
+
+ auto find_index = [](const Array<Var>& params, Var v) ->
std::optional<size_t> {
+ for (size_t i = 0; i < params.size(); ++i) {
+ if (params[i] == v) {
+ return i;
+ }
+ }
+ return std::nullopt;
+ };
+
+ for (const auto& [name, pat] : pattern.value()->annotation_patterns) {
+ auto exp = matched_expr.value()[pat];
+ if (auto arg_var = exp.as<VarNode>()) {
+ if (auto idx = find_index(f->params, GetRef<Var>(arg_var))) {
+ arg_idx.Set(name, IntImm(DataType::Int(64), *idx));
+ }
+ }
+ }
+
+ return arg_idx;
+}
+
+TVM_REGISTER_GLOBAL("relax.contrib.extract_arg_idx").set_body_typed(ExtractArgIdx);
+
+} // namespace backend
+} // namespace relax
+} // namespace tvm
diff --git a/src/relax/backend/contrib/utils.h
b/src/relax/backend/contrib/utils.h
index 4190ad66b6..ee1240aaed 100644
--- a/src/relax/backend/contrib/utils.h
+++ b/src/relax/backend/contrib/utils.h
@@ -120,6 +120,19 @@ inline const CallNode* GetOpInFunction(Function f, const
std::string& op_name) {
return nullptr;
}
+/*!
+ * \brief Extract indices of the argument patterns in the function parameter
list.
+ * Each composite function pattern can register a mapping between variable
names and the
+ * corresponding patterns. This function tells at which index a given parameter
+ * in the function pattern, identified by its name, appears in the partitioned
function parameter
+ * list.
+ * \param pattern_name The name of the composite function pattern.
+ * \param f The function partitioned according to the function pattern.
+ * \return A mapping between variable pattern names and their positions in the
partitioned
+ * function parameter list.
+ */
+Map<String, IntImm> ExtractArgIdx(String pattern_name, Function f);
+
} // namespace backend
} // namespace relax
} // namespace tvm
diff --git a/src/runtime/contrib/cblas/gemm_common.h
b/src/runtime/contrib/cblas/gemm_common.h
index 4724b14bff..af073da9ba 100644
--- a/src/runtime/contrib/cblas/gemm_common.h
+++ b/src/runtime/contrib/cblas/gemm_common.h
@@ -35,7 +35,7 @@ namespace tvm {
namespace contrib {
using namespace runtime;
-inline int ColumnStride(DLTensor* tensor) {
+inline int ColumnStride(const DLTensor* tensor) {
// If the tensor itself is transposed then it will have strides
// backward from what we expect. Regardless, the max of the strides
// (the other stride is 1) is the column stride.
@@ -46,7 +46,7 @@ inline int ColumnStride(DLTensor* tensor) {
}
}
-inline int ElementStride(DLTensor* tensor) {
+inline int ElementStride(const DLTensor* tensor) {
if (tensor->strides) {
return std::min(tensor->strides[0], tensor->strides[1]);
} else {
@@ -55,13 +55,17 @@ inline int ElementStride(DLTensor* tensor) {
}
// Reversed strides indicates an in-place transpose operation.
-inline bool IsInPlaceTransposed(DLTensor* tensor) {
+inline bool IsInPlaceTransposed(const DLTensor* tensor) {
return tensor->strides && (tensor->strides[1] > tensor->strides[0]);
}
-inline int RowCount(DLTensor* tensor, bool trans) { return tensor->shape[trans
? 1 : 0]; }
+inline int RowCount(const DLTensor* tensor, bool trans, int batch_offset = 0) {
+ return tensor->shape[batch_offset + (trans ? 1 : 0)];
+}
-inline int ColumnCount(DLTensor* tensor, bool trans) { return
tensor->shape[trans ? 0 : 1]; }
+inline int ColumnCount(const DLTensor* tensor, bool trans, int batch_offset =
0) {
+ return tensor->shape[batch_offset + (trans ? 0 : 1)];
+}
// Call a column major blas. Note that data is stored in tvm as row
// major, so this we switch the arguments.
@@ -159,7 +163,7 @@ inline int ColumnStride3D(DLTensor* tensor) {
return tensor->shape[2];
}
}
-inline int ElementStride3D(DLTensor* tensor) {
+inline int ElementStride3D(const DLTensor* tensor) {
if (tensor->strides) {
return std::min(tensor->strides[1], tensor->strides[2]);
} else {
diff --git a/src/runtime/contrib/cublas/cublas.cc
b/src/runtime/contrib/cublas/cublas.cc
index ee0f50e349..b49f15008c 100644
--- a/src/runtime/contrib/cublas/cublas.cc
+++ b/src/runtime/contrib/cublas/cublas.cc
@@ -24,6 +24,7 @@
#include <tvm/runtime/logging.h>
#include <tvm/runtime/registry.h>
+#include "../../3rdparty/compiler-rt/builtin_fp16.h"
#include "../cblas/gemm_common.h"
#include "cublas_utils.h"
@@ -133,6 +134,120 @@ bool CheckMixPrecisionType(DLDataType in_dtype,
DLDataType out_dtype, bool int_s
int roundoff(int v, int d) { return (v + d - 1) / d * d; }
#if CUDART_VERSION >= 10010
+
+void CallCublasLt(cublasLtHandle_t hdl, const DLTensor* A, const DLTensor* B,
const DLTensor* bias,
+ const DLTensor* C, bool transa, bool transb,
cublasLtEpilogue_t epilogue) {
+ ICHECK(TypeEqual(A->dtype, B->dtype));
+ // Reversed strides indicates an in-place transpose operation.
+ transa = IsInPlaceTransposed(A) ? !transa : transa;
+ transb = IsInPlaceTransposed(B) ? !transb : transb;
+
+ auto compute_type = CUBLAS_COMPUTE_32F;
+ auto scale_type = CUDA_R_32F;
+ cudaDataType_t ab_type = CUDA_R_32F;
+ cudaDataType_t c_type = CUDA_R_32F;
+ float one_fp32 = 1.0;
+ float zero_fp32 = 0.0;
+ auto one_fp16 = __truncXfYf2__<float, uint32_t, 23, uint16_t, uint16_t,
10>(1.0);
+ auto zero_fp16 = __truncXfYf2__<float, uint32_t, 23, uint16_t, uint16_t,
10>(0.0);
+ void* alpha = &one_fp32;
+ void* beta = &zero_fp32;
+
+ if (A->dtype.bits == 16 && A->dtype.code == kDLFloat) {
+ ab_type = CUDA_R_16F;
+ }
+
+ if (C->dtype.bits == 16 && C->dtype.code == kDLFloat) {
+ c_type = CUDA_R_16F;
+ compute_type = CUBLAS_COMPUTE_16F;
+ scale_type = CUDA_R_16F;
+ alpha = &one_fp16;
+ beta = &zero_fp16;
+ }
+
+ cublasLtMatmulDesc_t op_desc;
+ cublasOperation_t op_transa = CUBLASBooleanToTranspose(transa);
+ cublasOperation_t op_transb = CUBLASBooleanToTranspose(transb);
+
+ CHECK_CUBLAS_ERROR(cublasLtMatmulDescCreate(&op_desc, compute_type,
scale_type));
+ CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(op_desc,
CUBLASLT_MATMUL_DESC_TRANSA,
+ &op_transb,
sizeof(op_transa)));
+ CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(op_desc,
CUBLASLT_MATMUL_DESC_TRANSB,
+ &op_transa,
sizeof(op_transb)));
+
+ if (bias != nullptr) {
+ CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(op_desc,
CUBLASLT_MATMUL_DESC_BIAS_POINTER,
+ &bias->data,
sizeof(float*)));
+ }
+
+ if (epilogue != CUBLASLT_EPILOGUE_DEFAULT) {
+ CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(op_desc,
CUBLASLT_MATMUL_DESC_EPILOGUE,
+ &epilogue,
sizeof(epilogue)));
+ }
+
+ int batch_offset_A = A->ndim - 2;
+ int batch_offset_B = B->ndim - 2;
+
+ int M = ColumnCount(B, transb, batch_offset_B);
+ int N = RowCount(A, transa, batch_offset_A);
+ int K = ColumnCount(A, transa, batch_offset_A);
+
+ int lda = transb ? K : M;
+ int ldb = transa ? N : K;
+ int ldc = M;
+
+ cublasLtMatrixLayout_t A_desc, B_desc, C_desc;
+ CHECK_CUBLAS_ERROR(
+ cublasLtMatrixLayoutCreate(&A_desc, ab_type, !transb ? M : K, !transb ?
K : M, lda));
+ CHECK_CUBLAS_ERROR(
+ cublasLtMatrixLayoutCreate(&B_desc, ab_type, !transa ? K : N, !transa ?
N : K, ldb));
+ CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutCreate(&C_desc, c_type, M, N, ldc));
+
+ if (A->ndim != 2 || B->ndim != 2) {
+ auto get_batch_count = [](int64_t* shape, int batch_offset) {
+ int64_t count = 1;
+ for (int i = 0; i < batch_offset; ++i) {
+ count *= shape[i];
+ }
+ return count;
+ };
+ auto set_batch = [](cublasLtMatrixLayout_t mat_desc, int batch_count,
int64_t batch_stride) {
+ CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute(
+ mat_desc, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch_count,
sizeof(batch_count)));
+ CHECK_CUBLAS_ERROR(
+ cublasLtMatrixLayoutSetAttribute(mat_desc,
CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET,
+ &batch_stride,
sizeof(batch_stride)));
+ };
+
+ int batch_count_A = get_batch_count(A->shape, batch_offset_A);
+ int batch_count_B = get_batch_count(B->shape, batch_offset_B);
+ int batch_count_C = get_batch_count(C->shape, C->ndim - 2);
+ int64_t batch_stride_A = M * K;
+ int64_t batch_stride_B = K * N;
+ int64_t batch_stride_C = M * N;
+
+    // cuBLASLt does not seem to support batched GEMM with one of the matrices having
+    // one batch (with batch_stride 0).
+ ICHECK_EQ(batch_count_A, batch_count_B);
+
+ set_batch(A_desc, batch_count_A, batch_stride_A);
+ set_batch(B_desc, batch_count_B, batch_stride_B);
+ set_batch(C_desc, batch_count_C, batch_stride_C);
+ }
+
+ auto A_data = static_cast<char*>(A->data) + A->byte_offset;
+ auto B_data = static_cast<char*>(B->data) + B->byte_offset;
+ auto C_data = static_cast<char*>(C->data) + C->byte_offset;
+
+ CHECK_CUBLAS_ERROR(cublasLtMatmul(hdl, op_desc, alpha, B_data, A_desc,
A_data, B_desc, beta,
+ C_data, C_desc, C_data, C_desc, nullptr,
nullptr, 0, nullptr));
+
+ cublasLtMatmulDescDestroy(op_desc);
+ cublasLtMatrixLayoutDestroy(A_desc);
+ cublasLtMatrixLayoutDestroy(B_desc);
+ cublasLtMatrixLayoutDestroy(C_desc);
+}
+
inline void CallLtIgemm(TVMArgs args, TVMRetValue* ret, cublasLtHandle_t hdl) {
DLTensor* A = args[0];
DLTensor* B = args[1];
@@ -172,7 +287,6 @@ inline void CallLtIgemm(TVMArgs args, TVMRetValue* ret,
cublasLtHandle_t hdl) {
auto B_data = reinterpret_cast<void*>(static_cast<char*>(B->data) +
B->byte_offset);
auto C_data = reinterpret_cast<void*>(static_cast<char*>(C->data) +
C->byte_offset);
- cublasOperation_t opTranspose = CUBLAS_OP_T;
cublasLtOrder_t order_COL32 = CUBLASLT_ORDER_COL32;
cublasLtOrder_t order_COL4_4R2_8C = CUBLASLT_ORDER_COL4_4R2_8C;
cublasLtMatmulDesc_t operationDesc = nullptr;
@@ -181,8 +295,6 @@ inline void CallLtIgemm(TVMArgs args, TVMRetValue* ret,
cublasLtHandle_t hdl) {
#else
CHECK_CUBLAS_ERROR(cublasLtMatmulDescCreate(&operationDesc, CUDA_R_32I));
#endif
- CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(operationDesc,
CUBLASLT_MATMUL_DESC_TRANSB,
- &opTranspose,
sizeof(opTranspose)));
cublasOperation_t opTransA = CUBLASBooleanToTranspose(transa);
cublasOperation_t opTransB = CUBLASBooleanToTranspose(transb);
CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(operationDesc,
CUBLASLT_MATMUL_DESC_TRANSA,
diff --git a/src/runtime/contrib/cublas/cublas_json_runtime.cc
b/src/runtime/contrib/cublas/cublas_json_runtime.cc
new file mode 100644
index 0000000000..8afccb2730
--- /dev/null
+++ b/src/runtime/contrib/cublas/cublas_json_runtime.cc
@@ -0,0 +1,118 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/runtime/contrib/cublas/cublas_json_runtime.cc
+ * \brief A simple JSON runtime for CUBLAS.
+ */
+
+#include <tvm/runtime/ndarray.h>
+#include <tvm/runtime/registry.h>
+
+#include <cstddef>
+#include <regex>
+#include <string>
+#include <vector>
+
+#include "../json/json_node.h"
+#include "../json/json_runtime.h"
+#include "cublas_utils.h"
+
+namespace tvm {
+namespace runtime {
+namespace contrib {
+
+using namespace tvm::runtime;
+using namespace tvm::runtime::json;
+
+class CublasJSONRuntime : public JSONRuntimeBase {
+ public:
+ CublasJSONRuntime(const std::string& symbol_name, const std::string&
graph_json,
+ const Array<String> const_names)
+ : JSONRuntimeBase(symbol_name, graph_json, const_names) {}
+
+ void Init(const Array<NDArray>& consts) override {}
+
+ void Run() override {
+ // TODO(masahi): Reuse the same handle across different subgraphs
+ cublasLtHandle_t handle;
+ cublasLtCreate(&handle);
+
+ for (size_t i = 0; i < nodes_.size(); ++i) {
+ const auto& node = nodes_[i];
+ if (node.GetOpType() == "kernel") {
+ auto op_name = node.GetOpName();
+ uint32_t output_eid = EntryID(outputs_[0]);
+ auto out_ptr = data_entry_[output_eid];
+ bool transa = false;
+ bool transb = false;
+ cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_DEFAULT;
+
+ if (op_name.find("transposed") != std::string::npos) {
+ transb = true;
+ }
+
+ if (op_name.find("relu") != std::string::npos) {
+ epilogue = CUBLASLT_EPILOGUE_RELU_BIAS;
+ } else if (op_name.find("gelu") != std::string::npos) {
+ epilogue = CUBLASLT_EPILOGUE_GELU_BIAS;
+ } else if (op_name.find("bias") != std::string::npos) {
+ epilogue = CUBLASLT_EPILOGUE_BIAS;
+ }
+
+ auto get_inputs = [this](const JSONGraphNode& node, bool has_bias) {
+ const DLTensor* bias = nullptr;
+ if (has_bias) {
+ bias = GetInput(node, 2);
+ }
+ return std::make_tuple(GetInput(node, 0), GetInput(node, 1), bias);
+ };
+
+ auto [a_ptr, b_ptr, bias_ptr] = get_inputs(node, epilogue !=
CUBLASLT_EPILOGUE_DEFAULT);
+
+ tvm::contrib::CallCublasLt(handle, a_ptr, b_ptr, bias_ptr, out_ptr,
transa, transb,
+ epilogue);
+ }
+ }
+ cublasLtDestroy(handle);
+ }
+
+ private:
+ const DLTensor* GetInput(const JSONGraphNode& node, const int idx) {
+ ICHECK_LT(idx, node.GetInputs().size());
+ auto eid = EntryID(node.GetInputs()[idx]);
+ ICHECK(eid < data_entry_.size());
+ return data_entry_[eid];
+ }
+};
+
+runtime::Module CublasJSONRuntimeCreate(String symbol_name, String graph_json,
+ const Array<String>& const_names) {
+ auto n = make_object<CublasJSONRuntime>(symbol_name, graph_json,
const_names);
+ return runtime::Module(n);
+}
+
+TVM_REGISTER_GLOBAL("runtime.CublasJSONRuntimeCreate").set_body_typed(CublasJSONRuntimeCreate);
+
+TVM_REGISTER_GLOBAL("runtime.module.loadbinary_cublas_json")
+ .set_body_typed(JSONRuntimeBase::LoadFromBinary<CublasJSONRuntime>);
+
+} // namespace contrib
+} // namespace runtime
+} // namespace tvm
diff --git a/src/runtime/contrib/cublas/cublas_utils.h
b/src/runtime/contrib/cublas/cublas_utils.h
index 62863b8f7b..ac03b12103 100644
--- a/src/runtime/contrib/cublas/cublas_utils.h
+++ b/src/runtime/contrib/cublas/cublas_utils.h
@@ -104,6 +104,12 @@ inline cudaDataType_t GetCudaDataType(DLDataType type) {
}
LOG(FATAL) << "Unsupported cuda type";
}
+
+/*! \brief Execute matrix multiply followed by the specified epilogue, using
cuBLASLt. */
+void CallCublasLt(cublasLtHandle_t hdl, const DLTensor* A, const DLTensor* B,
const DLTensor* bias,
+ const DLTensor* C, bool transa, bool transb,
+ cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_DEFAULT);
+
} // namespace contrib
} // namespace tvm
diff --git a/tests/python/relax/test_codegen_cublas.py
b/tests/python/relax/test_codegen_cublas.py
new file mode 100644
index 0000000000..023054256e
--- /dev/null
+++ b/tests/python/relax/test_codegen_cublas.py
@@ -0,0 +1,156 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import numpy as np
+import pytest
+
+import tvm
+import tvm.testing
+import tvm.topi.testing
+from tvm import relax
+from tvm.relax.backend.contrib.cublas import partition_for_cublas
+from tvm.relax.testing import get_relax_matmul_module
+from tvm.script import relax as R
+
+
@pytest.fixture(autouse=True)
def reset_seed():
    """Seed NumPy's RNG before every test so the random inputs are reproducible."""
    np.random.seed(0)
+
+
# The cuBLAS codegen entry point is registered only when TVM was built with
# cuBLAS support; allow_missing=True makes the lookup return None otherwise.
has_cublas = tvm.get_global_func("relax.ext.cublas", True)

cublas_enabled = pytest.mark.skipif(
    not has_cublas,
    reason="CUBLAS not enabled.",
)

# Applies the skip mark to every test in this file.
pytestmark = [cublas_enabled]
+
+
def build_and_run(mod, inputs_np, target, legalize=False):
    """Compile ``mod`` for ``target``, run its "main" function, and return the output.

    Parameters are NumPy arrays; they are copied onto the target device before
    the call. When ``legalize`` is True, LegalizeOps lowers Relax ops to TIR
    first (needed for targets without an external backend, e.g. llvm).
    """
    if legalize:
        mod = relax.transform.LegalizeOps()(mod)

    device = tvm.device(target, 0)
    executable = relax.build(mod, target)
    main_fn = relax.VirtualMachine(executable, device)["main"]
    device_args = [tvm.nd.array(arr, device) for arr in inputs_np]
    return main_fn(*device_args).numpy()
+
+
def get_result_with_relax_cublas_offload(mod, *args):
    """Offload ``mod``'s supported patterns to cuBLAS, run on CUDA, return the output array."""
    partitioned = partition_for_cublas(mod)
    with_codegen = relax.transform.RunCodegen()(partitioned)
    return build_and_run(with_codegen, args, "cuda")
+
+
def _to_concrete_shape(symbolic_shape, var_table):
    """Resolve a possibly-symbolic shape to a tuple of concrete sizes.

    Dimensions that are tir Vars get a random size in [10, 50), memoized in
    ``var_table`` so the same Var resolves to the same size across shapes;
    all other dimensions pass through unchanged.
    """
    concrete = []
    for dim in symbolic_shape:
        if isinstance(dim, tvm.tir.expr.Var):
            if dim not in var_table:
                var_table[dim] = np.random.randint(10, 50)
            concrete.append(var_table[dim])
        else:
            concrete.append(dim)
    return tuple(concrete)
+
+
# Symbolic (dynamic) dimensions shared by the parametrized shapes below; reusing
# the same Var across shapes ties those dimensions together.
_vars = {
    "a": tvm.tir.expr.Var("a", "int64"),
    "b": tvm.tir.expr.Var("b", "int64"),
}


# Maps an epilogue name to (with_bias, activation) for building the test module.
_epilogue_table = {
    "none": (False, None),
    "bias": (True, None),
    "relu": (True, R.nn.relu),
    "gelu": (True, R.nn.gelu),
}
+
+
@pytest.mark.parametrize(
    "x_shape, y_shape, transpose_y, epilogue",
    [
        # Regular
        ((8, 8), (8, 8), False, "none"),
        ((_vars["a"], 6), (6, 16), False, "bias"),
        # Transposed
        ((4, 16), (16, 128), True, "relu"),
        ((35, 8), (8, 8), True, "gelu"),
        # # 3D x 3D
        ((6, 32, 8), (6, 8, 10), False, "bias"),
        ((6, 32, 8), (6, 8, 10), True, "none"),
        ((_vars["a"], 32, 8), (_vars["a"], 8, 10), True, "gelu"),
        # ND x ND
        ((5, 3, 32, 8), (5, 3, 8, 10), True, "relu"),
    ],
)
@pytest.mark.parametrize(
    "dtype",
    [
        "float16",
        "float32",
    ],
)
def test_matmul_offload(
    x_shape,
    y_shape,
    transpose_y,
    epilogue,
    dtype,
):
    """Check that a cuBLAS-offloaded matmul (plus optional bias/activation)
    matches the legalized TVM reference run on CPU."""
    with_bias, activation = _epilogue_table[epilogue]
    # Shared across both shapes so a symbolic dim resolves to a single size.
    var_table = {}
    concrete_x_shape = _to_concrete_shape(x_shape, var_table)
    concrete_y_shape = _to_concrete_shape(y_shape, var_table)
    x = np.random.randn(*concrete_x_shape).astype(dtype)
    y = np.random.randn(*concrete_y_shape).astype(dtype)

    if transpose_y:
        # The module expects y already transposed: swap the data's last two
        # axes and mirror that swap in the (possibly symbolic) y_shape that is
        # handed to the module builder.
        y = np.swapaxes(y, -2, -1)
        y_shape = (*y_shape[:-2], y_shape[-1], y_shape[-2])

    if with_bias:
        # Bias length equals the output's last dim, which is y's last dim
        # before the transpose above — hence concrete_y_shape, not y.shape.
        bias = np.random.randn(concrete_y_shape[-1]).astype(dtype)
        args = (x, y, bias)
    else:
        bias = None
        args = (x, y)

    mod = get_relax_matmul_module(
        x_shape,
        y_shape,
        dtype,
        with_bias=with_bias,
        transposed_y=transpose_y,
        activation=activation,
    )

    out = get_result_with_relax_cublas_offload(mod, *args)
    ref = build_and_run(mod, args, "llvm", legalize=True)

    # Loose tolerance: fp16 runs cannot match the fp32-capable reference exactly.
    tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2)
+
+
# Run this file's tests when executed directly as a script.
if __name__ == "__main__":
    tvm.testing.main()
diff --git a/tests/python/relax/test_codegen_cutlass.py
b/tests/python/relax/test_codegen_cutlass.py
index c8ca44311d..b9ba4f4dc9 100644
--- a/tests/python/relax/test_codegen_cutlass.py
+++ b/tests/python/relax/test_codegen_cutlass.py
@@ -23,8 +23,8 @@ import tvm.topi.testing
from tvm import relax
from tvm.contrib.cutlass.build import is_shape_valid_for_cutlass_matmul
from tvm.contrib.pickle_memoize import memoize
-from tvm.relax.backend import get_patterns_with_prefix
from tvm.relax.backend.contrib.cutlass import partition_for_cutlass
+from tvm.relax.testing import get_relax_matmul_module
from tvm.script import relax as R
from tvm.script.ir_builder import IRBuilder
from tvm.script.ir_builder import relax as relax_builder
@@ -96,9 +96,6 @@ def build_and_run(mod, inputs_np, target, legalize=False):
def get_result_with_relax_cutlass_offload(mod, *args,
assert_all_bindings_fused=True):
- patterns = [(entry.name, entry.pattern) for entry in
get_patterns_with_prefix("cutlass")]
- assert len(patterns) != 0, "Cannot find cutlass patterns"
-
mod = partition_for_cutlass(mod)
if assert_all_bindings_fused:
@@ -168,50 +165,6 @@ def get_relax_conv2d_module(
return tvm.IRModule({"main": func})
-def get_relax_matmul_module(
- x_shape,
- y_shape,
- dtype,
- transposed_y=False,
- with_bias=False,
- activation=None,
- residual_bin_op=None,
- residual_activation=None,
-):
- if transposed_y:
- n = y_shape[-2]
- else:
- n = y_shape[-1]
-
- with IRBuilder() as builder:
- with relax_builder.function():
- R.func_name("main")
- x = R.arg("x", R.Tensor(x_shape, dtype))
- y = R.arg("y", R.Tensor(y_shape, dtype))
- if with_bias:
- bias = R.arg("bias", R.Tensor((n,), dtype))
-
- with R.dataflow() as frame:
- if transposed_y:
- axes = list(range(len(y_shape) - 2)) + [-1, -2]
- y = R.emit(R.permute_dims(y, axes=axes))
- result = R.emit(R.matmul(x, y, out_dtype=dtype))
- if with_bias:
- result = R.emit(result + bias)
- if activation is not None:
- result = R.emit(activation(result))
- if residual_bin_op is not None:
- result = R.emit(residual_bin_op(result, x))
- if residual_activation is not None:
- result = R.emit(residual_activation(result))
- R.output(result)
-
- R.func_ret_value(frame.output_vars[0])
-
- func = builder.get()
- return tvm.IRModule({"main": func})
-
-
def _to_concrete_shape(symbolic_shape, var_table):
result = []
for dim in symbolic_shape: