This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new d7a6285f47 [Unity][BYOC] Add transposed matmul support to Relax CUTLASS BYOC (#14128)
d7a6285f47 is described below
commit d7a6285f473dad912dd90183248f05b07a18e7e4
Author: Lite Ye <[email protected]>
AuthorDate: Mon Feb 27 03:33:50 2023 -0500
[Unity][BYOC] Add transposed matmul support to Relax CUTLASS BYOC (#14128)
Add transposed matmul support for Relax CUTLASS
---
python/tvm/contrib/cutlass/build.py | 88 ++++++++++++++++++--------
python/tvm/contrib/cutlass/gemm_operation.py | 11 ++--
python/tvm/contrib/cutlass/gen_tensor_op.py | 93 +++++++++++++++++++++-------
python/tvm/relax/__init__.py | 1 +
python/tvm/relax/backend/contrib/cutlass.py | 30 +++++++++
python/tvm/relax/dpl/pattern.py | 25 ++++++++
src/relax/ir/dataflow_matcher.cc | 2 +
tests/python/relax/test_codegen_cutlass.py | 68 ++++++++++++++++++--
8 files changed, 259 insertions(+), 59 deletions(-)
diff --git a/python/tvm/contrib/cutlass/build.py b/python/tvm/contrib/cutlass/build.py
index c6e5adacec..954aef60c2 100644
--- a/python/tvm/contrib/cutlass/build.py
+++ b/python/tvm/contrib/cutlass/build.py
@@ -19,6 +19,7 @@
import logging
import multiprocessing
import os
+from typing import Optional
import tvm
from tvm import relax, relay, runtime
@@ -522,7 +523,19 @@ def tune_cutlass_function(
)
-def _extract_relax_function_info(f):
+def _get_call_node(expr: relax.Expr, op_name: str) -> Optional[relax.Call]:
+ node = None
+
+ def fvisit(e):
+ nonlocal node
+ if isinstance(e, relax.Call) and e.op.name == op_name:
+ node = e
+
+ relax.analysis.post_order_visit(expr, fvisit)
+ return node
+
+
+def _extract_relax_function_signature(f):
signature = {}
for i, arg in enumerate(f.params):
@@ -534,16 +547,26 @@ def _extract_relax_function_info(f):
signature["ret_shape"] = list(ret_sinfo.shape)
signature["ret_dtype"] = ret_sinfo.dtype
- op_attrs = {}
+ return signature
- def fvisit(e):
- nonlocal op_attrs
- if isinstance(e, relax.Call) and e.op.name in ["relax.nn.conv2d"]:
- op_attrs = e.attrs
- relax.analysis.post_order_visit(f.body, fvisit)
+def _extract_arg_idx(pattern_name, f):
+ pattern_entry = relax.backend.get_pattern(pattern_name)
+ if pattern_entry is None:
+ raise ValueError(f"Unsupported op_type {pattern_name}")
+ var2val = relax.analysis.get_var2val(f)
+ matched_expr = pattern_entry.pattern.extract_matched_expr(f.body.body, var2val)
- return signature, op_attrs
+ func_args = list(f.params)
+
+ arg_idx = {}
+ for arg_name, arg_pattern in pattern_entry.arg_patterns.items():
+ arg_expr = matched_expr[arg_pattern]
+ if arg_expr not in func_args:
+ raise ValueError(f"Cannot find arg {arg_name} in the fused function parameters")
+ arg_idx[arg_name] = func_args.index(arg_expr)
+
+ return arg_idx
@relax.expr_functor.mutator
@@ -566,7 +589,8 @@ class CutlassRelaxFunctionAnnotator(relax.PyExprMutator):
def handle_conv2d(self, f, op_type):
"""Tune and annotate a conv2d op."""
- signature, op_attrs = _extract_relax_function_info(f)
+ signature = _extract_relax_function_signature(f)
+ op_attrs = _get_call_node(f.body, "relax.nn.conv2d").attrs
d_shape = signature["arg0_shape"]
w_shape = signature["arg1_shape"]
@@ -622,18 +646,29 @@ class CutlassRelaxFunctionAnnotator(relax.PyExprMutator):
def handle_matmul(self, f, op_type):
"""Tune and annotate a dense op."""
- signature, _ = _extract_relax_function_info(f)
+ signature = _extract_relax_function_signature(f)
+ arg_idx = _extract_arg_idx(op_type, f)
+
+ lhs_arg = f"arg{arg_idx['lhs']}"
+ rhs_arg = f"arg{arg_idx['rhs']}"
- arg0_shape = signature["arg0_shape"]
- arg1_shape = signature["arg1_shape"]
+ lhs_shape = signature[f"{lhs_arg}_shape"]
+ rhs_shape = signature[f"{rhs_arg}_shape"]
out_shape = signature["ret_shape"]
- arg0_dtype = signature["arg0_dtype"]
- arg1_dtype = signature["arg1_dtype"]
+ lhs_dtype = signature[f"{lhs_arg}_dtype"]
+ rhs_dtype = signature[f"{rhs_arg}_dtype"]
out_dtype = signature["ret_dtype"]
- MM = arg0_shape[0]
- KK = arg0_shape[1]
- NN = arg1_shape[1]
+ MM = lhs_shape[0]
+ KK = lhs_shape[1]
+ if "transposed" in op_type:
+ NN = rhs_shape[0]
+ ldb = "K"
+ layout_b = LayoutType.ColumnMajor
+ else:
+ NN = rhs_shape[1]
+ ldb = "N"
+ layout_b = LayoutType.RowMajor
use_3xtf32 = self.options.get("use_3xtf32", False)
find_first_valid = self.options.get("find_first_valid", True)
@@ -645,26 +680,29 @@ class CutlassRelaxFunctionAnnotator(relax.PyExprMutator):
NN,
KK,
out_dtype,
- arg0_dtype,
- arg1_dtype,
+ lhs_dtype,
+ rhs_dtype,
use_3xtf32,
batched=False,
find_first_valid=find_first_valid,
use_multiprocessing=use_multiprocessing,
- layout_b=LayoutType.RowMajor,
+ layout_b=layout_b,
)
return f.with_attrs(
{
"op_type": op_type,
- "arg0_dtype": arg0_dtype,
- "arg1_dtype": arg1_dtype,
+ "lhs_arg_idx": arg_idx["lhs"],
+ "rhs_arg_idx": arg_idx["rhs"],
+ "bias_arg_idx": arg_idx.get("bias"),
+ "arg0_dtype": signature["arg0_dtype"],
+ "arg1_dtype": signature["arg1_dtype"],
"ret_dtype": out_dtype,
- "arg0_shape": arg0_shape,
- "arg1_shape": arg1_shape,
+ "arg0_shape": signature["arg0_shape"],
+ "arg1_shape": signature["arg1_shape"],
"ret_shape": out_shape,
"lda": "K",
- "ldb": "N",
+ "ldb": ldb,
"ldc": "N",
"cutlass_op_name": op_name,
"cutlass_op_def": op_def,
diff --git a/python/tvm/contrib/cutlass/gemm_operation.py b/python/tvm/contrib/cutlass/gemm_operation.py
index 58f5de6a9c..3e74cbaec8 100644
--- a/python/tvm/contrib/cutlass/gemm_operation.py
+++ b/python/tvm/contrib/cutlass/gemm_operation.py
@@ -259,7 +259,7 @@ class EmitGemmInstance:
return substitute_template(gemm_template, values)
-def instantiate_gemm_template(attrs, func_args):
+def instantiate_gemm_template(attrs):
"""Return CUTLASS host code for GEMM based on a template and the provided
attribute map."""
template = """
@@ -277,8 +277,8 @@ def instantiate_gemm_template(attrs, func_args):
cutlass::gemm::GemmCoord problem_size(M, N, K);
ElementComputeEpilogue alpha = ElementComputeEpilogue(1);
ElementComputeEpilogue beta = ElementComputeEpilogue(${beta});
- void* ptr_a = (void*)(${arg0}->data);
- void* ptr_b = (void*)(${arg1}->data);
+ void* ptr_a = (void*)(${lhs_arg}->data);
+ void* ptr_b = (void*)(${rhs_arg}->data);
${bias_decl}
void* ptr_out = (void*)(out0->data);
@@ -310,7 +310,7 @@ def instantiate_gemm_template(attrs, func_args):
if has_bias:
aux_map.update(
{
- "bias_decl": "void* ptr_c_bias = (void*)(${arg2}->data);\n",
+ "bias_decl": "void* ptr_c_bias =
(void*)(${bias_arg}->data);\n",
"ptr_c": "ptr_c_bias",
"c_stride": "0",
}
@@ -342,7 +342,4 @@ def instantiate_gemm_template(attrs, func_args):
template = substitute_template(template, aux_map)
- for i, arg in enumerate(func_args):
- attrs["arg{}".format(i)] = arg
-
return substitute_template(template, attrs)
diff --git a/python/tvm/contrib/cutlass/gen_tensor_op.py b/python/tvm/contrib/cutlass/gen_tensor_op.py
index d3ab020839..92bf04e863 100644
--- a/python/tvm/contrib/cutlass/gen_tensor_op.py
+++ b/python/tvm/contrib/cutlass/gen_tensor_op.py
@@ -17,27 +17,28 @@
# pylint: disable=invalid-name
"""Common functions and classes for CUTLASS GEMM and Conv2d geneator."""
import logging
+import multiprocessing
import os
import re
-import tempfile
import subprocess
-import multiprocessing
+import tempfile
+
import tvm._ffi
-from tvm.tir import IntImm
from tvm.runtime import Object
+from tvm.tir import IntImm
+
from . import _ffi_api as ffi
+from .conv2d_operation import instantiate_conv2d_template
+from .gemm_operation import instantiate_gemm_template
from .library import (
- MathInstruction,
DataType,
DataTypeTag,
- OpcodeClass,
+ EpilogueFunctor,
+ MathInstruction,
MathOperation,
+ OpcodeClass,
TileDescription,
- EpilogueFunctor,
)
-from .gemm_operation import instantiate_gemm_template
-from .conv2d_operation import instantiate_conv2d_template
-
logger = logging.getLogger("cutlass")
@@ -371,6 +372,10 @@ EPILOGUE_MAP = {
"cutlass.matmul_bias": (EpilogueFunctor.LinearCombinationBias, True),
"cutlass.matmul_bias_relu": (EpilogueFunctor.LinearCombinationRelu, True),
"cutlass.matmul_bias_gelu": (EpilogueFunctor.LinearCombinationGelu, False),
+ "cutlass.matmul_transposed": (EpilogueFunctor.LinearCombination, False),
+ "cutlass.matmul_transposed_bias": (EpilogueFunctor.LinearCombinationBias,
True),
+ "cutlass.matmul_transposed_bias_relu":
(EpilogueFunctor.LinearCombinationRelu, True),
+ "cutlass.matmul_transposed_bias_gelu":
(EpilogueFunctor.LinearCombinationGelu, False),
"cutlass.batch_matmul": (EpilogueFunctor.LinearCombination, False),
"cutlass.conv2d_bias_hardswish":
(EpilogueFunctor.LinearCombinationHardSwish, False),
"cutlass.conv2d_bias_silu": (EpilogueFunctor.LinearCombinationSilu, False),
@@ -454,6 +459,13 @@ class CodegenResult(Object):
self.__init_handle_by_constructor__(ffi.CodegenResult, code, headers)
+def _get_optional_int_annotation(annotations, key, default=None):
+ value = annotations.get(key, default)
+ if value is not None:
+ return int(value)
+ return value
+
+
@tvm._ffi.register_func("contrib.cutlass.instantiate_template")
def instantiate_template(func_name, annotations, func_args):
"""Return CUTLASS host code based on a template and the provided
annotations.
@@ -519,32 +531,69 @@ def instantiate_template(func_name, annotations, func_args):
if "dense" in func_name or "matmul" in func_name:
batched = "batch_matmul" in func_name
batched_offset = 1 if batched else 0
- attrs["K"] = str(int(arg0_shape[batched_offset + 1]))
- attrs["M"] = get_dim(arg0_shape[batched_offset], func_args[0], 0,
batched_offset)
-
- if annotations["ldb"] == "N":
- attrs["N"] = get_dim(arg1_shape[batched_offset + 1], func_args[1],
1, batched_offset)
+ transposed = "transposed" in func_name
+ lhs_arg_idx = _get_optional_int_annotation(annotations, "lhs_arg_idx", 0)
+ rhs_arg_idx = _get_optional_int_annotation(annotations, "rhs_arg_idx", 1)
+ bias_arg_idx = _get_optional_int_annotation(annotations, "bias_arg_idx", 2)
+ lhs_arg = func_args[lhs_arg_idx]
+ rhs_arg = func_args[rhs_arg_idx]
+ lhs_shape = annotations[f"arg{lhs_arg_idx}_shape"]
+ rhs_shape = annotations[f"arg{rhs_arg_idx}_shape"]
+
+ attrs["lhs_arg"] = lhs_arg
+ attrs["rhs_arg"] = rhs_arg
+ if len(func_args) > 2:
+ attrs["bias_arg"] = func_args[bias_arg_idx]
+ attrs["ElementInputA"] =
DataTypeTag[dtype_map[annotations[f"arg{lhs_arg_idx}_dtype"]]]
+ attrs["ElementInputB"] =
DataTypeTag[dtype_map[annotations[f"arg{rhs_arg_idx}_dtype"]]]
+ attrs["ElementOutput"] =
DataTypeTag[dtype_map[annotations["ret_dtype"]]]
+
+ attrs["K"] = str(int(lhs_shape[batched_offset + 1]))
+ attrs["M"] = get_dim(lhs_shape[batched_offset], lhs_arg, 0,
batched_offset)
+
+ if transposed:
+ attrs["N"] = get_dim(rhs_shape[batched_offset], rhs_arg, 0,
batched_offset)
else:
- attrs["N"] = get_dim(arg1_shape[batched_offset], func_args[1], 0,
batched_offset)
+ attrs["N"] = get_dim(rhs_shape[batched_offset + 1], rhs_arg, 1,
batched_offset)
if batched:
headers.append("cutlass/gemm/device/gemm_batched.h")
- attrs["batch"] = get_dim(arg0_shape[0], func_args[0], 0)
- attrs["batch_stride_A"] =
get_batch_stride(annotations["batch_stride_A"], 0, 0, 1, 2)
- attrs["batch_stride_B"] =
get_batch_stride(annotations["batch_stride_B"], 1, 1, 1, 2)
+ attrs["batch"] = get_dim(lhs_shape[0], lhs_arg, 0)
+ attrs["batch_stride_A"] = get_batch_stride(
+ annotations["batch_stride_A"],
+ lhs_arg_idx,
+ lhs_arg_idx,
+ 1,
+ 2,
+ )
+ attrs["batch_stride_B"] = get_batch_stride(
+ annotations["batch_stride_B"],
+ rhs_arg_idx,
+ rhs_arg_idx,
+ 1,
+ 2,
+ )
- if annotations["ldb"] == "N":
+ if transposed:
attrs["batch_stride_C"] = get_batch_stride(
- annotations["batch_stride_C"], 0, 1, 1, 2
+ annotations["batch_stride_C"],
+ lhs_arg_idx,
+ rhs_arg_idx,
+ 1,
+ 1,
)
else:
attrs["batch_stride_C"] = get_batch_stride(
- annotations["batch_stride_C"], 0, 1, 1, 1
+ annotations["batch_stride_C"],
+ lhs_arg_idx,
+ rhs_arg_idx,
+ 1,
+ 2,
)
else:
headers.append("cutlass/gemm/device/gemm.h")
- code = instantiate_gemm_template(attrs, func_args)
+ code = instantiate_gemm_template(attrs)
return CodegenResult(code, headers)
elif "conv2d" in func_name:
diff --git a/python/tvm/relax/__init__.py b/python/tvm/relax/__init__.py
index d0a1942ebd..e86f8c6074 100644
--- a/python/tvm/relax/__init__.py
+++ b/python/tvm/relax/__init__.py
@@ -87,6 +87,7 @@ from . import transform
from . import block_builder
from . import op
from . import struct_info
+from . import backend
# VM
from .vm_build import build, Executable
diff --git a/python/tvm/relax/backend/contrib/cutlass.py b/python/tvm/relax/backend/contrib/cutlass.py
index 20cf57a40a..51684abb06 100644
--- a/python/tvm/relax/backend/contrib/cutlass.py
+++ b/python/tvm/relax/backend/contrib/cutlass.py
@@ -66,6 +66,36 @@ register_patterns(
activation="relax.nn.gelu",
),
),
+ (
+ "cutlass.matmul_transposed",
+ make_matmul_pattern(
+ with_bias=False,
+ transposed_rhs=True,
+ ),
+ ),
+ (
+ "cutlass.matmul_transposed_bias",
+ make_matmul_pattern(
+ with_bias=True,
+ transposed_rhs=True,
+ ),
+ ),
+ (
+ "cutlass.matmul_transposed_bias_relu",
+ make_matmul_pattern(
+ with_bias=True,
+ activation="relax.nn.relu",
+ transposed_rhs=True,
+ ),
+ ),
+ (
+ "cutlass.matmul_transposed_bias_gelu",
+ make_matmul_pattern(
+ with_bias=True,
+ activation="relax.nn.gelu",
+ transposed_rhs=True,
+ ),
+ ),
]
)
diff --git a/python/tvm/relax/dpl/pattern.py b/python/tvm/relax/dpl/pattern.py
index 9e1963f7ed..300b0af568 100644
--- a/python/tvm/relax/dpl/pattern.py
+++ b/python/tvm/relax/dpl/pattern.py
@@ -204,6 +204,31 @@ class DFPattern(Node):
"""
return ffi.match_expr(self, expr, var2val) # type: ignore
+ def extract_matched_expr(
+ self, expr, var2val: Optional[Dict[Var, Expr]] = None
+ ) -> Optional[Dict["DFPattern", Expr]]:
+ """
+ Match a relax.Expr and return a map from matching patterns to matched expressions.
+
+ Parameters
+ ----------
+ expr : tvm.relax.Expr
+ The expression to match
+ var2val : Optional[Dict[tvm.relax.Var, tvm.relax.Expr]]
+ A mapping from relax.Var to relax.Expr for autojump.
+
+ Returns
+ -------
+ result: Optional[Dict[DFPattern, Expr]]
+ Map from matching patterns to matched expressions.
+ Return None if the pattern does not match expr.
+
+ Note
+ ----
+ Check the note of `match` for the meaning of var2val.
+ """
+ return ffi.extract_matched_expr(self, expr, var2val)
+
def used_by(self, other: Union["DFPattern", "PatternSeq"], index=-1) -> "PatternSeq":
"""
The current pattern being used by another pattern (sequence)
diff --git a/src/relax/ir/dataflow_matcher.cc b/src/relax/ir/dataflow_matcher.cc
index 92eb452a00..da8c6ce2da 100644
--- a/src/relax/ir/dataflow_matcher.cc
+++ b/src/relax/ir/dataflow_matcher.cc
@@ -515,6 +515,8 @@ Optional<Map<DFPattern, Expr>> ExtractMatchedExpr(DFPattern pattern, Expr expr,
return matching;
}
+TVM_REGISTER_GLOBAL("relax.dpl.extract_matched_expr").set_body_typed(ExtractMatchedExpr);
+
bool MatchExpr(DFPattern pattern, Expr expr, Optional<Map<Var, Expr>> bindings_opt) {
return static_cast<bool>(ExtractMatchedExpr(pattern, expr, bindings_opt));
}
diff --git a/tests/python/relax/test_codegen_cutlass.py b/tests/python/relax/test_codegen_cutlass.py
index 673155342c..af3d40d9c4 100644
--- a/tests/python/relax/test_codegen_cutlass.py
+++ b/tests/python/relax/test_codegen_cutlass.py
@@ -255,9 +255,12 @@ def test_conv2d_offload():
tvm.testing.assert_allclose(out, ref, rtol=1e-5, atol=1e-5)
-def get_relax_matmul_module(x, y, with_bias=False, activation=None):
+def get_relax_matmul_module(x, y, transposed_y=False, with_bias=False, activation=None):
m, k = x.shape
- n = y.shape[-1]
+ if transposed_y:
+ n = y.shape[-2]
+ else:
+ n = y.shape[-1]
dtype = str(x.dtype)
from tvm.script.ir_builder import IRBuilder
@@ -266,13 +269,15 @@ def get_relax_matmul_module(x, y, with_bias=False, activation=None):
with IRBuilder() as builder:
with relax_builder.function():
R.func_name("main")
- x = R.arg("x", R.Tensor((m, k), dtype))
- y = R.arg("y", R.Tensor((k, n), dtype))
+ x = R.arg("x", R.Tensor(x.shape, dtype))
+ y = R.arg("y", R.Tensor(y.shape, dtype))
if with_bias:
bias = R.arg("bias", R.Tensor((n,), dtype))
with R.dataflow() as frame:
- result = R.emit(R.matmul(x, y))
+ if transposed_y:
+ y = R.emit(R.permute_dims(y))
+ result = R.emit(R.matmul(x, y, out_dtype=dtype))
if with_bias:
result = R.emit(result + bias)
if activation is not None:
@@ -380,5 +385,58 @@ def test_kernel_sharing():
tvm.testing.assert_allclose(out, ref, rtol=1e-5, atol=1e-5)
+def test_matmul_transposed_offload(matmul_x, matmul_y):
+ x, y = matmul_x, matmul_y
+
+ mod = get_relax_matmul_module(x, y.transpose(), transposed_y=True)
+ out = get_result_with_relax_cutlass_offload(mod, x, y.transpose())
+ ref_relay_expr = get_relay_matmul(x.shape, y.shape[::-1])
+ ref = get_relay_ref(ref_relay_expr, x, y.transpose())
+
+ tvm.testing.assert_allclose(out, ref, rtol=1e-3, atol=1e-4)
+
+
+def test_matmul_transposed_bias_offload(matmul_x, matmul_y, matmul_bias):
+ x, y, bias = matmul_x, matmul_y, matmul_bias
+
+ mod = get_relax_matmul_module(
+ x, y.transpose(), transposed_y=True, with_bias=True, activation=None
+ )
+ out = get_result_with_relax_cutlass_offload(mod, x, y.transpose(), bias)
+
+ ref_relay_expr = get_relay_matmul_bias(x.shape, y.shape[::-1])
+ ref = get_relay_ref(ref_relay_expr, x, y.transpose(), bias)
+
+ tvm.testing.assert_allclose(out, ref, rtol=1e-3, atol=1e-4)
+
+
+def test_matmul_transposed_bias_relu_offload(matmul_x, matmul_y, matmul_bias):
+ x, y, bias = matmul_x, matmul_y, matmul_bias
+
+ mod = get_relax_matmul_module(
+ x, y.transpose(), transposed_y=True, with_bias=True, activation=R.nn.relu
+ )
+ out = get_result_with_relax_cutlass_offload(mod, x, y.transpose(), bias)
+
+ ref_relay_expr = get_relay_matmul_bias_relu(x.shape, y.shape[::-1])
+ ref = get_relay_ref(ref_relay_expr, x, y.transpose(), bias)
+
+ tvm.testing.assert_allclose(out, ref, rtol=1e-3, atol=1e-4)
+
+
+def test_matmul_transposed_bias_gelu_offload(matmul_x, matmul_y, matmul_bias):
+ x, y, bias = matmul_x, matmul_y, matmul_bias
+
+ mod = get_relax_matmul_module(
+ x, y.transpose(), transposed_y=True, with_bias=True, activation=R.nn.gelu
+ )
+ out = get_result_with_relax_cutlass_offload(mod, x, y.transpose(), bias)
+
+ ref_relay_expr = get_relay_matmul_bias_gelu(x.shape, y.shape[::-1])
+ ref = get_relay_ref(ref_relay_expr, x, y.transpose(), bias)
+
+ tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-3)
+
+
if __name__ == "__main__":
tvm.testing.main()