This is an automated email from the ASF dual-hosted git repository.
wuwei pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new c448c50cd1 [Unity][CUTLASS] Support more residual input shape (#14968)
c448c50cd1 is described below
commit c448c50cd19cc892b77bbcb68dc20e995fd67f40
Author: masahi <[email protected]>
AuthorDate: Sun May 28 11:06:12 2023 +0900
[Unity][CUTLASS] Support more residual input shape (#14968)
* improved cutlass residual fusion
* update cutlass residual test
* update residual check
* fix residual check
* clean
* fix
* minor
---
python/tvm/contrib/cutlass/build.py | 48 ++++++++++-------
python/tvm/contrib/cutlass/conv2d_operation.py | 25 ++++++++-
python/tvm/contrib/cutlass/gen_tensor_op.py | 3 ++
python/tvm/relax/backend/contrib/cutlass.py | 11 +++-
tests/python/relax/test_codegen_cutlass.py | 74 +++++++++++++++++---------
5 files changed, 111 insertions(+), 50 deletions(-)
diff --git a/python/tvm/contrib/cutlass/build.py b/python/tvm/contrib/cutlass/build.py
index 1d105bbe82..07583a4851 100644
--- a/python/tvm/contrib/cutlass/build.py
+++ b/python/tvm/contrib/cutlass/build.py
@@ -668,26 +668,34 @@ class CutlassRelaxFunctionAnnotator(relax.PyExprMutator):
use_multiprocessing=use_multiprocessing,
)
- return f.with_attrs(
- {
- "op_type": op_type,
- "data_arg_idx": arg_idx["lhs"],
- "weight_arg_idx": arg_idx["rhs"],
- "bias_arg_idx": arg_idx.get("bias"),
- "residual_arg_idx": arg_idx.get("residual"),
- "arg0_dtype": data_dtype,
- "arg1_dtype": weight_dtype,
- "ret_dtype": out_dtype,
- "arg0_shape": d_shape,
- "arg1_shape": w_shape,
- "ret_shape": out_shape,
- "strides": strides,
- "padding": padding,
- "dilation": dilation,
- "cutlass_op_name": op_name,
- "cutlass_op_def": op_def,
- }
- )
+ attrs = {
+ "op_type": op_type,
+ "data_arg_idx": arg_idx["lhs"],
+ "weight_arg_idx": arg_idx["rhs"],
+ "bias_arg_idx": arg_idx.get("bias"),
+ "residual_arg_idx": arg_idx.get("residual"),
+ "arg0_dtype": data_dtype,
+ "arg1_dtype": weight_dtype,
+ "ret_dtype": out_dtype,
+ "arg0_shape": d_shape,
+ "arg1_shape": w_shape,
+ "ret_shape": out_shape,
+ "strides": strides,
+ "padding": padding,
+ "dilation": dilation,
+ "cutlass_op_name": op_name,
+ "cutlass_op_def": op_def,
+ }
+
+ residual_arg = arg_idx.get("residual")
+
+ if residual_arg:
+ residual_shape = signature[f"arg{residual_arg}_shape"]
+ attrs["residual_shape"] = residual_shape
+ elif "residual" in op_type:
+ attrs["residual_shape"] = d_shape
+
+ return f.with_attrs(attrs)
def handle_matmul(self, f, op_type):
"""Tune and annotate a dense op."""
diff --git a/python/tvm/contrib/cutlass/conv2d_operation.py b/python/tvm/contrib/cutlass/conv2d_operation.py
index f2d2f01276..8ae9b1414d 100644
--- a/python/tvm/contrib/cutlass/conv2d_operation.py
+++ b/python/tvm/contrib/cutlass/conv2d_operation.py
@@ -394,11 +394,12 @@ def instantiate_conv2d_template(attrs):
auto activation_shape = TensorNHWC::packed(cutlass::make_Coord(N, H, W, C));
auto weight_shape = TensorNHWC::packed(cutlass::make_Coord(K, R, S, C));
auto output_shape = TensorNHWC::packed(cutlass::make_Coord(N, P, Q, K));
+ ${residual_shape_decl}
TensorNHWC layout_A(${A_shape});
TensorNHWC layout_B(${B_shape});
TensorNHWC layout_C(${C_shape});
- TensorNHWC layout_D(${C_shape});
+ TensorNHWC layout_D(${D_shape});
using ElementOutput = ${ElementOutput};
cutlass::TensorRef<ElementOutput, TensorNHWC>
tensor_c{static_cast<ElementOutput*>(${tensor_c}), ${tensor_c_layout}};
@@ -506,18 +507,38 @@ def instantiate_conv2d_template(attrs):
else:
aux_map["additional_args"] = ""
+ aux_map["residual_shape_decl"] = ""
+
if is_wgrad:
aux_map["A_shape"] = "output_shape"
aux_map["B_shape"] = "activation_shape"
aux_map["C_shape"] = "weight_shape"
+ aux_map["D_shape"] = "weight_shape"
elif is_dgrad:
aux_map["A_shape"] = "output_shape"
aux_map["B_shape"] = "weight_shape"
aux_map["C_shape"] = "activation_shape"
+ aux_map["D_shape"] = "activation_shape"
else:
aux_map["A_shape"] = "activation_shape"
aux_map["B_shape"] = "weight_shape"
- aux_map["C_shape"] = "output_shape"
+ aux_map["D_shape"] = "output_shape"
+
+ if has_residual_block:
+ res_shape = list(attrs.pop("residual_shape"))
+ shape_str = f"cutlass::make_Coord({res_shape[0]}, {res_shape[1]}, {res_shape[2]}, K)"
+ aux_map[
+ "residual_shape_decl"
+ ] = f"auto residual_shape = TensorNHWC::packed({shape_str});"
+ aux_map["C_shape"] = "residual_shape"
+
+ if res_shape == [int(attrs[c]) for c in ["N", "H", "W", "K"]]:
+ aux_map["tensor_c_layout"] = "layout_C"
+ else:
+ # bias-like residual input
+ aux_map["tensor_c_layout"] = "cutlass::layout::TensorNHWC::Stride(0)"
+ else:
+ aux_map["C_shape"] = "output_shape"
if use_split_k:
aux_map["ElementOutput"] = "EpilogueOutputOp::ElementOutput"
diff --git a/python/tvm/contrib/cutlass/gen_tensor_op.py b/python/tvm/contrib/cutlass/gen_tensor_op.py
index abc13d3570..0caf62043e 100644
--- a/python/tvm/contrib/cutlass/gen_tensor_op.py
+++ b/python/tvm/contrib/cutlass/gen_tensor_op.py
@@ -689,6 +689,9 @@ def instantiate_template(func_name, annotations, func_args):
attrs["split_k_mode"] = "kSerial"
attrs["split_k_slices"] = 1
+ if "residual_shape" in annotations:
+ attrs["residual_shape"] = annotations["residual_shape"]
+
code = instantiate_conv2d_template(attrs)
return CodegenResult(code, headers)
diff --git a/python/tvm/relax/backend/contrib/cutlass.py b/python/tvm/relax/backend/contrib/cutlass.py
index dd2739fd98..dffd7c401c 100644
--- a/python/tvm/relax/backend/contrib/cutlass.py
+++ b/python/tvm/relax/backend/contrib/cutlass.py
@@ -47,6 +47,10 @@ def _is_supported_dtype(lhs_dtype, rhs_dtype):
)
+def _shape_1d(shape):
+ return reduce(operator.mul, shape, 1)
+
+
def _has_leaking_intermediate_variables(context: PatternCheckContext) -> bool:
"""
Check whether intermediate variables in the region to be fused are used outside
@@ -104,7 +108,10 @@ def _check_residual(root_call: Call, context: PatternCheckContext) -> bool:
shape1 = [int(s) for s in root_var.struct_info.shape]
shape2 = [int(s) for s in residual.struct_info.shape]
- if shape1 != shape2:
+ out_channel = shape1[-1]
+ is_bias_like = lambda shape: (shape[-1] == out_channel and _shape_1d(shape) == out_channel)
+
+ if shape1 != shape2 and not is_bias_like(shape2):
return False
return True
@@ -402,7 +409,7 @@ class WorkspaceAnnotator(PyExprMutator):
if "attention" in f.attrs["Composite"]:
# Workspace is needed only for larger head sizes, but for simplicity we always allocate.
out_dtype = f.ret_struct_info.dtype
- out_size_1d = reduce(operator.mul, f.ret_struct_info.shape, 1)
+ out_size_1d = _shape_1d(f.ret_struct_info.shape)
# This needs to be in sync with the actual value that the kernel expects.
workspace_size_bytes = out_size_1d * {"float16": 2, "float32": 4}[out_dtype]
return f.with_attr("WorkspaceSize", workspace_size_bytes)
diff --git a/tests/python/relax/test_codegen_cutlass.py b/tests/python/relax/test_codegen_cutlass.py
index 339b22ca9b..8a1675ad35 100644
--- a/tests/python/relax/test_codegen_cutlass.py
+++ b/tests/python/relax/test_codegen_cutlass.py
@@ -1011,35 +1011,57 @@ def test_attention_rewrite_offload(attention_rewrite_size):
tvm.testing.assert_allclose(original_out, expected_out, rtol=1e-5, atol=1e-5)
-def test_invalid_residual():
- @tvm.script.ir_module
- class Module:
- @R.function
- def main(
- x: R.Tensor((2, 64, 64, 8), dtype="float16"),
- w: R.Tensor((8, 3, 3, 8), dtype="float16"),
- bias: R.Tensor((1, 1, 8), dtype="float16"),
- residual: R.Tensor((2, 1, 1, 8), dtype="float16"),
- ) -> R.Tensor((1, 256, 64, 64), dtype="float16"):
- with R.dataflow():
- conv = R.nn.conv2d(
- x,
- w,
- padding=[1, 1, 1, 1],
- out_dtype="float16",
- data_layout="NHWC",
- kernel_layout="OHWI",
+def test_conv2d_residual_broadcast():
+ data_shape = (2, 64, 64, 8)
+ weight_shape = (8, 3, 3, 8)
+ dtype = "float16"
+
+ def get_mod(residual_batch):
+ with IRBuilder() as builder:
+ with relax_builder.function():
+ R.func_name("main")
+ data = R.arg("data", R.Tensor(data_shape, dtype))
+ weight = R.arg("weight", R.Tensor(weight_shape, dtype))
+ bias = R.arg("bias", R.Tensor((1, 1, weight_shape[0]), dtype))
+ residual = R.arg(
"residual", R.Tensor((residual_batch, 1, 1, weight_shape[0]), dtype)
)
- bias_out = R.add(conv, bias)
- out = R.add(bias_out, residual)
- R.output(out)
- return out
- rewritten = partition_for_cutlass(Module)
- func_names = [gv.name_hint for gv in rewritten.functions.keys()]
+ with R.dataflow() as frame:
+ output = R.emit(
+ R.nn.conv2d(
+ data,
+ weight,
+ out_dtype=dtype,
+ padding=(1, 1),
+ data_layout="NHWC",
+ kernel_layout="OHWI",
+ )
+ )
+ output = R.emit(output + bias)
+ output = R.emit(R.nn.relu(output))
+ output = R.emit(R.add(output, residual))
+ R.output(output)
+
+ R.func_ret_value(frame.output_vars[0])
+
+ func = builder.get()
+ return tvm.IRModule({"main": func})
- assert "fused_relax_nn_conv2d_relax_add_relax_add_cutlass" not in func_names
- assert "fused_relax_nn_conv2d_relax_add_cutlass" in func_names
+ low = -1
+ high = 1
+
+ residual_batch = 1
+ mod = get_mod(residual_batch)
+ data = np.random.randint(low, high, size=data_shape).astype(dtype)
+ weight = np.random.randint(low, high, size=weight_shape).astype(dtype)
+ bias = np.random.randint(low, high, size=(1, 1, weight_shape[0])).astype(dtype)
+ bias2 = np.random.randint(low, high, size=(residual_batch, 1, 1, weight_shape[0])).astype(dtype)
+
+ args = [data, weight, bias, bias2]
+ out = get_result_with_relax_cutlass_offload(mod, *args)
+ ref = build_and_run(mod, args, "llvm")
+ tvm.testing.assert_allclose(out, ref, rtol=1e-5, atol=1e-5)
@pytest.mark.parametrize(