comaniac commented on a change in pull request #9261: URL: https://github.com/apache/tvm/pull/9261#discussion_r737937439
########## File path: tests/python/contrib/test_cutlass.py ########## @@ -0,0 +1,121 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import math +import pytest +import tvm +from tvm import relay +import numpy as np +from tvm.contrib.cutlass import profile_and_build + + +def get_ref_rt_mod(mod, params): + with tvm.transform.PassContext(opt_level=3): + lib = relay.build(mod, target="cuda", params=params) + dev = tvm.device("cuda", 0) + rt_mod = tvm.contrib.graph_executor.GraphModule(lib["default"](dev)) + return rt_mod, dev + + +def get_output(rt_mod, x): + rt_mod.set_input("data", x) + rt_mod.run() + return rt_mod.get_output(0).asnumpy() + + +def get_dense(M, N, K, out_dtype="float16"): + data = relay.var("data", shape=(M, K), dtype="float16") + weight = relay.var("weight", shape=(N, K), dtype="float16") + return relay.nn.dense(data, weight, out_dtype=out_dtype) + + +def get_dense_bias(M, N, K, out_dtype="float16"): + dense = get_dense(M, N, K, out_dtype=out_dtype) + bias = relay.var("bias", shape=(N,), dtype=out_dtype) + return relay.nn.bias_add(dense, bias) + + +def get_dense_bias_relu(M, N, K, out_dtype="float16"): + return relay.nn.relu(get_dense_bias(M, N, K, out_dtype="float16")) + + +def get_dense_bias_gelu(M, N, K, 
out_dtype="float16"): + bias_add = get_dense_bias(M, N, K, out_dtype) + mul = bias_add * relay.const((1.0 / math.sqrt(2.0)), dtype=out_dtype) + if out_dtype == "float16": + erf = relay.cast(relay.op.erf(relay.cast(mul, "float32")), "float16") + else: + erf = relay.op.erf(mul) + mul_half = erf * relay.const(0.5, dtype=out_dtype) + add = mul_half + relay.const(0.5, dtype=out_dtype) + return add * bias_add + + +def verify(func, M, N, K, sm=80, atol=1e-5, rtol=1e-5, run_benchmark=False): + if not tvm.get_global_func("relay.ext.cutlass", True): + return + mod = tvm.IRModule.from_expr(func) + typ = relay.transform.InferType()(mod) + out_dtype = typ["main"].body.checked_type.dtype + np_data = np.random.uniform(-1, 1, (M, K)).astype("float16") + np_weight = np.random.uniform(-1, 1, (N, K)).astype("float16") + np_bias = np.random.uniform(-1, 1, (N,)).astype(out_dtype) + + params = {"weight": np_weight, "bias": np_bias} + + rt_mod_ref, dev = get_ref_rt_mod(mod, params) + rt_mod, dev, num_partition = profile_and_build(mod, params, sm, tmp_dir="tmp") Review comment: Just curious, how long would this profile take? ########## File path: python/tvm/contrib/cutlass/gen_gemm.py ########## @@ -0,0 +1,335 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name +"""Kernel generator and profiler for CUTLASS.""" +import os +import re +import tempfile +import subprocess +from .gemm_operation import GemmOperation, EmitGemmInstance +from .gemm_profiler import GemmProfilerEmitter +from .library import ( + EpilogueFunctor, + SwizzlingFunctor, + TensorDescription, + DataTypeTag, + LayoutType, + MathInstruction, + DataType, + OpcodeClass, + MathOperation, + TileDescription, +) + + +def create_gemm_operator( + layouts, + tile_descriptions, + data_type, + alignment_constraints, + epilogue_functor=EpilogueFunctor.LinearCombination, + swizzling_functor=SwizzlingFunctor.Identity8, +): + """Exhaustively instantiate all kernels from a given configuration.""" + ret = [] + kernel_emitter = EmitGemmInstance() + profiler_emitter = GemmProfilerEmitter() + + element_a, element_b, element_c, element_epilogue = data_type + + for layout in layouts: + for tile_description in tile_descriptions: + for alignment in alignment_constraints: + alignment_c = min(8, alignment) + + A = TensorDescription(element_a, layout[0], alignment) + B = TensorDescription(element_b, layout[1], alignment) + C = TensorDescription(element_c, layout[2], alignment_c) + + op_entry = {} + op = GemmOperation( + tile_description.minimum_compute_capability, + tile_description, + A, + B, + C, + element_epilogue, + epilogue_functor, + swizzling_functor, + ) + op_bias = GemmOperation( + tile_description.minimum_compute_capability, + tile_description, + A, + B, + C, + element_epilogue, + EpilogueFunctor.LinearCombinationBias, + swizzling_functor, + ) + op_bias_relu = GemmOperation( + tile_description.minimum_compute_capability, + tile_description, + A, + B, + C, + element_epilogue, + EpilogueFunctor.LinearCombinationRelu, + swizzling_functor, + ) + op_bias_gelu = GemmOperation( + tile_description.minimum_compute_capability, + tile_description, + A, 
+ B, + C, + element_epilogue, + EpilogueFunctor.LinearCombinationGelu, + swizzling_functor, + ) + + kernel_emitter = EmitGemmInstance() + op_entry["op"] = op + op_entry["name"] = op.procedural_name() + op_entry["opdef"] = kernel_emitter.emit(op) + op_entry["opdef_bias"] = kernel_emitter.emit(op_bias, no_beta_scaling=True) + op_entry["opdef_bias_relu"] = kernel_emitter.emit( + op_bias_relu, no_beta_scaling=True + ) + op_entry["opdef_bias_gelu"] = kernel_emitter.emit(op_bias_gelu) + op_entry["src"] = profiler_emitter.emit( + op.procedural_name(), + op_entry["opdef"], + DataTypeTag[element_a], + DataTypeTag[element_b], + DataTypeTag[element_c], + op.leading_dim(), + ) + op_entry["runtime"] = 9999999 + ret.append(op_entry) + return ret + + +def generate_tensor_op_common(math_instructions, alignment_constraints, get_tile_descriptions): + """Common kernel generator to be used by archtecture specific generators.""" + ops = [] + layouts = [ + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.RowMajor), + ] + for math_inst in math_instructions: + tile_descriptions = get_tile_descriptions(math_inst) + data_type = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_accumulator, + math_inst.element_accumulator, + ] + + out = create_gemm_operator(layouts, tile_descriptions, data_type, alignment_constraints) + + ops.extend(out) + + return ops + + +def generate_sm75_tensor_op_1688(out_dtype): + """Generate GEMM kernels for Turing.""" + assert out_dtype in ["float32", "float16"] + math_instructions = { + "float32": [ + MathInstruction( + [16, 8, 8], + DataType.f16, + DataType.f16, + DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add, + ) + ], + "float16": [ + MathInstruction( + [16, 8, 8], + DataType.f16, + DataType.f16, + DataType.f16, + OpcodeClass.TensorOp, + MathOperation.multiply_add, + ) + ], + }[out_dtype] Review comment: IMHO, we should not select them simply based on the performance. My considerations are: 1. 
Currently TVM doesn't really support a configurable accumulation dtype, so if we want to align with TVM semantics, we should always use the FP16 one. 2. However, TVM doesn't know which accumulation dtype we are using here, so choosing between FP16 and FP32 may not be an issue for us. 3. If neither dtype hurts inference accuracy, always using FP16 is the obvious choice, as I believe its performance will never be worse than FP32. ########## File path: python/tvm/contrib/cutlass/build.py ########## @@ -0,0 +1,155 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License.
+# pylint: disable=invalid-name +"""Driver for partitioning and building a Relay module for CUTLASS offload.""" +import tvm +from tvm import runtime, relay +from tvm.relay.op.contrib.cutlass import partition_for_cutlass +from .gen_gemm import CutlassGemmProfiler + + +class GemmAnnotator(tvm.relay.ExprVisitor): + """Annotates partitioned functions with shape and dtype information.""" + + def __init__(self): + super().__init__() + self.signature = {} + + def visit_call(self, call): + op = call.op + if isinstance(op, relay.Function) and "PartitionedFromPattern" in op.attrs: + self.signature["op_type"] = op.attrs["Composite"] + for i, arg in enumerate(op.params): + self.signature["arg%d_shape" % i] = arg.checked_type.shape + self.signature["arg%d_dtype" % i] = arg.checked_type.dtype + self.signature["ret_shape"] = op.ret_type.shape + self.signature["ret_dtype"] = op.ret_type.dtype + + +def profile_and_build(mod, params, sm, tmp_dir="./tmp", lib_path="compile.so"): Review comment: I am concerned about this API as it wraps too many things and hurts the flexibility. Specifically: 1. It runs partitioning, which means we cannot have the flow like `partition_for_cutlass -> partition_rest_for_trt`. 2. It builds the module, which means we cannot directly access/control the build API (e.g., pass context, tuning the rest ops, etc). Since we still require users to explicitly run `partition_for_xxx` APIs for each BYOC backend, I would suggest keeping them separate: ```python mod = partition_for_cutlass(mod) # users may do something else for the unpartitioned parts here. mod = tune_for_cutlass(mod) # we could improve this later with tvm.transform.PassContext(opt_level=3): # Maybe we still need to wrap relay.build to perform exporting (compile) and load for now... 
lib = relay.build(mod, target="cuda", params=params) ``` ########## File path: python/tvm/contrib/cutlass/gemm_operation.py ########## @@ -0,0 +1,262 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name, unused-wildcard-import, wildcard-import +"""Generator for CUTLASS GEMM kernels.""" +from .library import * + + +class GemmOperation: + """Describes various attributes for instantiating GEMM kernels.""" + + def __init__( + self, + arch, + tile_description, + A, + B, + C, + element_epilogue, + epilogue_functor=EpilogueFunctor.LinearCombination, + swizzling_functor=SwizzlingFunctor.Identity8, + ): + self.operation_kind = OperationKind.Gemm + self.arch = arch + self.tile_description = tile_description + self.A = A + self.B = B + self.C = C + self.element_epilogue = element_epilogue + self.epilogue_functor = epilogue_functor + self.swizzling_functor = swizzling_functor + + def accumulator_type(self): + return self.tile_description.math_instruction.element_accumulator + + def short_math_name(self): + return ShortDataTypeNames[self.accumulator_type()] + + def core_name(self): + """ The basic operation kind is prefixed with a letter indicating the accumulation type. 
""" + inst_shape = "" + intermediate_type = "" + + if ( + self.tile_description.math_instruction.opcode_class == OpcodeClass.TensorOp + or self.tile_description.math_instruction.opcode_class == OpcodeClass.WmmaTensorOp + ): + inst_shape = "%d%d%d" % tuple(self.tile_description.math_instruction.instruction_shape) + if ( + self.tile_description.math_instruction.element_a != self.A.element + and self.tile_description.math_instruction.element_a + != self.tile_description.math_instruction.element_accumulator + ): + intermediate_type = DataTypeNames[self.tile_description.math_instruction.element_a] + + return "%s%s%s%s" % ( + self.short_math_name(), + inst_shape, + intermediate_type, + "gemm", + ) + + def extended_name(self): + """ Append data types if they differ from compute type. """ + if ( + self.C.element != self.tile_description.math_instruction.element_accumulator + and self.A.element != self.tile_description.math_instruction.element_accumulator + ): + extended_name = "${element_c}_${core_name}_${element_a}" + elif ( + self.C.element == self.tile_description.math_instruction.element_accumulator + and self.A.element != self.tile_description.math_instruction.element_accumulator + ): + extended_name = "${core_name}_${element_a}" + else: + extended_name = "${core_name}" + + extended_name = substitute_template( + extended_name, + { + "element_a": DataTypeNames[self.A.element], + "element_c": DataTypeNames[self.C.element], + "core_name": self.core_name(), + }, + ) + + return extended_name + + def layout_name(self): + return "%s%s" % (ShortLayoutTypeNames[self.A.layout], ShortLayoutTypeNames[self.B.layout]) + + def procedural_name(self): + """The full procedural name indicates architecture, extended name, tile size, + and layout. 
+ """ + threadblock = self.tile_description.procedural_name() + opcode_class_name = OpcodeClassNames[self.tile_description.math_instruction.opcode_class] + + return substitute_template( + "cutlass_${opcode_class}_${extended_name}_${threadblock}_${layout}_align${alignment}", + { + "opcode_class": opcode_class_name, + "extended_name": self.extended_name(), + "threadblock": threadblock, + "layout": self.layout_name(), + "alignment": "%d" % self.A.alignment, + }, + ) + + def leading_dim(self): + """ lda, ldb, ldc, according to the leading dimension. """ + if self.A.layout == LayoutType.RowMajor: + lda = "K" + elif self.A.layout == LayoutType.ColumnMajor: + lda = "M" + else: + ValueError("The layout of A is not implemented.") + + if self.B.layout == LayoutType.RowMajor: + ldb = "N" + elif self.B.layout == LayoutType.ColumnMajor: + ldb = "K" + else: + ValueError("The layout of B is not implemented.") + + if self.C.layout == LayoutType.RowMajor: + ldc = "N" + elif self.C.layout == LayoutType.ColumnMajor: + ldc = "M" + else: + ValueError("The layout of B is not implemented.") + + return substitute_template( + "int lda = ${lda_val};\n\tint ldb = ${ldb_val};\n\tint ldc = ${ldc_val};\n", + { + "lda_val": lda, + "ldb_val": ldb, + "ldc_val": ldc, + }, + ) + + +class EmitGemmInstance: + """ Responsible for emitting a CUTLASS template definition.""" + + def __init__(self): + self.epilogue_default = """ + ${epilogue_functor}< + ${element_c}, + ${epilogue_vector_length}, + ${element_accumulator}, + ${element_epilogue} + >""" + self.epilogue_no_beta_scaling = """ + ${epilogue_functor}< + ${element_c}, + ${epilogue_vector_length}, + ${element_accumulator}, + ${element_epilogue}, + cutlass::epilogue::thread::ScaleType::NoBetaScaling + >""" + self.gemm_template = """ + // Gemm operator ${operation_name} + using Operation_${operation_name} = cutlass::gemm::device::Gemm< + ${element_a}, ${layout_a}, + ${element_b}, ${layout_b}, + ${element_c}, ${layout_c}, + ${element_accumulator}, + 
${opcode_class}, + ${arch}, + cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>, + cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>, + cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>, + ${epilogue}, + ${swizzling_functor}, + ${stages}, + ${align_a}, + ${align_b}, + false, + ${math_operation} + ${residual} + >; +""" + + def emit(self, operation, no_beta_scaling=False): + """Instantiate a GEMM kernel from given `operation`.""" + warp_shape = [ + operation.tile_description.threadblock_shape[idx] + // operation.tile_description.warp_count[idx] + for idx in range(3) + ] + epilogue_vector_length = int( + min(operation.C.alignment * DataTypeSize[operation.C.element], 128) + / DataTypeSize[operation.C.element] + ) Review comment: nit: You might be able to use `//` to get the integer directly. ########## File path: python/tvm/contrib/cutlass/gen_gemm.py ########## @@ -0,0 +1,335 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# pylint: disable=invalid-name +"""Kernel generator and profiler for CUTLASS.""" +import os +import re +import tempfile +import subprocess +from .gemm_operation import GemmOperation, EmitGemmInstance +from .gemm_profiler import GemmProfilerEmitter +from .library import ( + EpilogueFunctor, + SwizzlingFunctor, + TensorDescription, + DataTypeTag, + LayoutType, + MathInstruction, + DataType, + OpcodeClass, + MathOperation, + TileDescription, +) + + +def create_gemm_operator( + layouts, + tile_descriptions, + data_type, + alignment_constraints, + epilogue_functor=EpilogueFunctor.LinearCombination, + swizzling_functor=SwizzlingFunctor.Identity8, +): + """Exhaustively instantiate all kernels from a given configuration.""" + ret = [] + kernel_emitter = EmitGemmInstance() + profiler_emitter = GemmProfilerEmitter() + + element_a, element_b, element_c, element_epilogue = data_type + + for layout in layouts: + for tile_description in tile_descriptions: + for alignment in alignment_constraints: + alignment_c = min(8, alignment) + + A = TensorDescription(element_a, layout[0], alignment) + B = TensorDescription(element_b, layout[1], alignment) + C = TensorDescription(element_c, layout[2], alignment_c) + + op_entry = {} + op = GemmOperation( + tile_description.minimum_compute_capability, + tile_description, + A, + B, + C, + element_epilogue, + epilogue_functor, + swizzling_functor, + ) + op_bias = GemmOperation( + tile_description.minimum_compute_capability, + tile_description, + A, + B, + C, + element_epilogue, + EpilogueFunctor.LinearCombinationBias, + swizzling_functor, + ) + op_bias_relu = GemmOperation( + tile_description.minimum_compute_capability, + tile_description, + A, + B, + C, + element_epilogue, + EpilogueFunctor.LinearCombinationRelu, + swizzling_functor, + ) + op_bias_gelu = GemmOperation( + tile_description.minimum_compute_capability, + tile_description, + A, + B, + C, + element_epilogue, + EpilogueFunctor.LinearCombinationGelu, + swizzling_functor, + ) + + 
kernel_emitter = EmitGemmInstance() + op_entry["op"] = op + op_entry["name"] = op.procedural_name() + op_entry["opdef"] = kernel_emitter.emit(op) + op_entry["opdef_bias"] = kernel_emitter.emit(op_bias, no_beta_scaling=True) + op_entry["opdef_bias_relu"] = kernel_emitter.emit( + op_bias_relu, no_beta_scaling=True + ) + op_entry["opdef_bias_gelu"] = kernel_emitter.emit(op_bias_gelu) + op_entry["src"] = profiler_emitter.emit( + op.procedural_name(), + op_entry["opdef"], + DataTypeTag[element_a], + DataTypeTag[element_b], + DataTypeTag[element_c], + op.leading_dim(), + ) + op_entry["runtime"] = 9999999 + ret.append(op_entry) + return ret + + +def generate_tensor_op_common(math_instructions, alignment_constraints, get_tile_descriptions): + """Common kernel generator to be used by archtecture specific generators.""" + ops = [] + layouts = [ + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.RowMajor), + ] + for math_inst in math_instructions: + tile_descriptions = get_tile_descriptions(math_inst) + data_type = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_accumulator, + math_inst.element_accumulator, + ] + + out = create_gemm_operator(layouts, tile_descriptions, data_type, alignment_constraints) Review comment: As the comment indicates, the logic you removed emits the kernel whose output dtype is the same as the first input's. I believe this is to generate the kernel that takes two FP16 tensors and outputs one FP16 tensor, but uses FP32 for accumulation. Specifically, with two math instructions: Inst1: (FP16, FP16, FP32) Inst2: (FP16, FP16, FP16) Without the removed logic, we will only have two kernels: 1. (from Inst1): (FP16, FP16) -> (acc: FP32) -> FP32 2. (from Inst2): (FP16, FP16) -> (acc: FP16) -> FP16 The logic generates a 3rd kernel from the first instruction: 3. (from Inst1): (FP16, FP16) -> (acc: FP32) -> FP16 However, we need the if-guard to prevent a duplicated kernel from the second instruction: 4.
(from Inst2): (FP16, FP16) -> (acc: FP16) -> FP16 ... same as 2 -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected]
