vinx13 commented on code in PR #14274: URL: https://github.com/apache/tvm/pull/14274#discussion_r1145391451
########## tests/python/relax/test_codegen_tir_cutlass.py: ########## @@ -0,0 +1,754 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations +import tempfile + +from tvm import relax, runtime +import tvm +import tvm.testing +from tvm import relax +import scipy +from scipy.special import erf +import numpy as np +from tvm.relax.vm_build import build as relax_build +from tvm.relax.transform import LegalizeOps +from tvm.script.ir_builder import relax as R +from tvm.script.ir_builder import ir as I +from tvm.script.ir_builder import tir as T +from tvm.script.ir_builder import IRBuilder + +from tvm.relax.backend_tir import get_tir_pattern +from tvm.relax.backend_tir.contrib.cutlass import cutlass_fcodegen + +A_TYPE = "float16" +B_TYPE = "float16" +C_TYPE = "float16" + +target = "cuda" + + +def f_run(rt_mod: runtime.Module, device: runtime.ndarray.Device, *input): + vm = relax.vm.VirtualMachine(rt_mod=rt_mod, device=device) + return vm["main"](*input) + + +def build(mod): + mod = relax.transform.LegalizeOps()(mod) + mod = relax.transform.AnnotateTIROpPattern()(mod) + mod = relax.transform.FuseOps()(mod) + mod = relax.transform.FuseTIR()(mod) + mod = relax.transform.SplitCallTIRByPattern(get_tir_pattern(), 
cutlass_fcodegen())(mod) + mod = relax.transform.DeadCodeElimination()(mod) + print(mod.script()) + f = tempfile.NamedTemporaryFile(suffix=".so", delete=True) + executable = relax_build(mod, target) + executable.mod.export_library(f.name, cc="nvcc") + rt_mod = runtime.load_module(f.name) + f.close() + return rt_mod + + +def constructGEMM(M, N, K): + with IRBuilder() as ib: # pylint: disable=invalid-name + with I.ir_module() as frame: + with R.function(): + R.func_name("main") + A = R.arg( + "A", relax.TensorStructInfo((M, K), A_TYPE) + ) # pylint: disable=invalid-name + B = R.arg( + "B", relax.TensorStructInfo((K, N), B_TYPE) + ) # pylint: disable=invalid-name + with R.dataflow() as df: + C = R.emit(R.matmul(A, B, out_dtype=C_TYPE)) + R.output(C) + (C,) = df.output_vars + R.func_ret_value(C) + relax_mod = ib.get() + return relax_mod + + +@tvm.testing.requires_cutlass +def test_cutlass_dense(): + m, n, k = 128, 64, 256 + executable = build(constructGEMM(m, n, k)) + dev = tvm.cuda() + A = np.random.randn(m, k).astype("float16") + B = np.random.randn(k, n).astype("float16") + A_tvm = tvm.nd.array(A, dev) + B_tvm = tvm.nd.array(B, dev) + result = f_run(executable, dev, A_tvm, B_tvm) + np.testing.assert_allclose(result.numpy(), A @ B, rtol=5e-2, atol=5e-2) + + +def constructGEMM_bias(M, N, K): + with IRBuilder() as ib: # pylint: disable=invalid-name + with I.ir_module() as frame: + with R.function(): + R.func_name("main") + A = R.arg( + "A", relax.TensorStructInfo((M, K), A_TYPE) + ) # pylint: disable=invalid-name + B = R.arg( + "B", relax.TensorStructInfo((K, N), B_TYPE) + ) # pylint: disable=invalid-name + bias = R.arg( + "bias", relax.TensorStructInfo((1, N), A_TYPE) + ) # pylint: disable=invalid-name + with R.dataflow() as df: + C = R.emit(R.matmul(A, B, out_dtype=C_TYPE)) + D = R.emit(R.add(C, bias)) + R.output(D) + (D,) = df.output_vars + R.func_ret_value(D) + relax_mod = ib.get() + return relax_mod + + +def constructGEMM_bias2(M, N, K): + with IRBuilder() as ib: # 
pylint: disable=invalid-name + with I.ir_module() as frame: + with R.function(): + R.func_name("main") + A = R.arg( + "A", relax.TensorStructInfo((M, K), A_TYPE) + ) # pylint: disable=invalid-name + B = R.arg( + "B", relax.TensorStructInfo((K, N), B_TYPE) + ) # pylint: disable=invalid-name + bias = R.arg( + "bias", relax.TensorStructInfo((N,), A_TYPE) + ) # pylint: disable=invalid-name + with R.dataflow() as df: + C = R.emit(R.matmul(A, B, out_dtype=C_TYPE)) + D = R.emit(R.add(C, bias)) + R.output(D) + (D,) = df.output_vars + R.func_ret_value(D) + relax_mod = ib.get() + return relax_mod + + +@tvm.testing.requires_cutlass +def test_cutlass_dense_bias(): + m, n, k = 128, 64, 256 + executable = build(constructGEMM_bias(m, n, k)) + dev = tvm.cuda() + A = np.random.randn(m, k).astype("float16") + B = np.random.randn(k, n).astype("float16") + bias = np.random.randn(1, n).astype("float16") + A_tvm = tvm.nd.array(A, dev) + B_tvm = tvm.nd.array(B, dev) + bias_tvm = tvm.nd.array(bias, dev) + result = f_run(executable, dev, A_tvm, B_tvm, bias_tvm) + np.testing.assert_allclose(result.numpy(), A @ B + bias, rtol=5e-2, atol=5e-2) + + +@tvm.testing.requires_cutlass +def test_cutlass_dense_bias2(): + m, n, k = 128, 64, 256 + executable = build(constructGEMM_bias2(m, n, k)) + dev = tvm.cuda() + A = np.random.randn(m, k).astype("float16") + B = np.random.randn(k, n).astype("float16") + bias = np.random.randn(n).astype("float16") + A_tvm = tvm.nd.array(A, dev) + B_tvm = tvm.nd.array(B, dev) + bias_tvm = tvm.nd.array(bias, dev) + result = f_run(executable, dev, A_tvm, B_tvm, bias_tvm) + np.testing.assert_allclose(result.numpy(), A @ B + bias, rtol=5e-2, atol=5e-2) + + +def constructGEMM_bias_relu(M, N, K): + with IRBuilder() as ib: # pylint: disable=invalid-name + with I.ir_module() as frame: + with R.function(): + R.func_name("main") + A = R.arg( + "A", relax.TensorStructInfo((M, K), A_TYPE) + ) # pylint: disable=invalid-name + B = R.arg( + "B", relax.TensorStructInfo((K, N), B_TYPE) + ) # 
pylint: disable=invalid-name + bias = R.arg( + "bias", relax.TensorStructInfo((1, N), A_TYPE) + ) # pylint: disable=invalid-name + with R.dataflow() as df: + C = R.emit(R.matmul(A, B, out_dtype=C_TYPE)) + D = R.emit(R.add(C, bias)) + E = R.emit(R.nn.relu(D)) + R.output(E) + (E,) = df.output_vars + R.func_ret_value(E) + relax_mod = ib.get() + return relax_mod + + +@tvm.testing.requires_cutlass +def test_cutlass_dense_bias_relu(): + m, n, k = 128, 64, 256 + executable = build(constructGEMM_bias_relu(m, n, k)) + dev = tvm.cuda() + A = np.random.randn(m, k).astype("float16") + B = np.random.randn(k, n).astype("float16") + bias = np.random.randn(1, n).astype("float16") + A_tvm = tvm.nd.array(A, dev) + B_tvm = tvm.nd.array(B, dev) + bias_tvm = tvm.nd.array(bias, dev) + result = f_run(executable, dev, A_tvm, B_tvm, bias_tvm) + np.testing.assert_allclose(result.numpy(), np.maximum(A @ B + bias, 0), rtol=5e-2, atol=5e-2) + + +def constructBatchGEMM(batch, M, N, K): + with IRBuilder() as ib: # pylint: disable=invalid-name + with I.ir_module() as frame: + with R.function(): + R.func_name("main") + A = R.arg( + "A", relax.TensorStructInfo((batch, M, K), A_TYPE) + ) # pylint: disable=invalid-name + B = R.arg( + "B", relax.TensorStructInfo((K, N), B_TYPE) + ) # pylint: disable=invalid-name + with R.dataflow() as df: + C = R.emit(R.matmul(A, B, out_dtype=C_TYPE)) + R.output(C) + (C,) = df.output_vars + R.func_ret_value(C) + relax_mod = ib.get() + return relax_mod + + +@tvm.testing.requires_cutlass +def test_cutlass_batch_dense(): + b, m, n, k = 2, 128, 256, 64 + executable = build(constructBatchGEMM(b, m, n, k)) + dev = tvm.cuda() + A = np.random.randn(b, m, k).astype("float16") + B = np.random.randn(k, n).astype("float16") + A_tvm = tvm.nd.array(A, dev) + B_tvm = tvm.nd.array(B, dev) + result = f_run(executable, dev, A_tvm, B_tvm) + np.testing.assert_allclose(result.numpy(), A @ B, rtol=5e-2, atol=5e-2) + + +def constructBatchGEMM2(batch, M, N, K): + with IRBuilder() as ib: # pylint: 
disable=invalid-name + with I.ir_module() as frame: + with R.function(): + R.func_name("main") + A = R.arg( + "A", relax.TensorStructInfo((batch, M, K), A_TYPE) + ) # pylint: disable=invalid-name + B = R.arg( + "B", relax.TensorStructInfo((batch, K, N), B_TYPE) + ) # pylint: disable=invalid-name + with R.dataflow() as df: + C = R.emit(R.matmul(A, B, out_dtype=C_TYPE)) + R.output(C) + (C,) = df.output_vars + R.func_ret_value(C) + relax_mod = ib.get() + return relax_mod + + +@tvm.testing.requires_cutlass +def test_cutlass_batch_dense2(): + b, m, n, k = 2, 128, 256, 64 + executable = build(constructBatchGEMM2(b, m, n, k)) + dev = tvm.cuda() + A = np.random.randn(b, m, k).astype("float16") + B = np.random.randn(b, k, n).astype("float16") + A_tvm = tvm.nd.array(A, dev) + B_tvm = tvm.nd.array(B, dev) + result = f_run(executable, dev, A_tvm, B_tvm) + np.testing.assert_allclose(result.numpy(), A @ B, rtol=5e-2, atol=5e-2) + + +def constructBatchGEMM_bias(batch, M, N, K): + with IRBuilder() as ib: # pylint: disable=invalid-name + with I.ir_module() as frame: + with R.function(): + R.func_name("main") + A = R.arg( + "A", relax.TensorStructInfo((batch, M, K), A_TYPE) + ) # pylint: disable=invalid-name + B = R.arg( + "B", relax.TensorStructInfo((K, N), B_TYPE) + ) # pylint: disable=invalid-name + bias = R.arg( + "bias", relax.TensorStructInfo((1, N), A_TYPE) + ) # pylint: disable=invalid-name + with R.dataflow() as df: + C = R.emit(R.matmul(A, B, out_dtype=C_TYPE)) + D = R.emit(R.add(C, bias)) + R.output(D) + (D,) = df.output_vars + R.func_ret_value(D) + relax_mod = ib.get() + return relax_mod + + +@tvm.testing.requires_cutlass +def test_cutlass_batch_dense_bias(): + b, m, n, k = 2, 128, 256, 64 + executable = build(constructBatchGEMM_bias(b, m, n, k)) + dev = tvm.cuda() + A = np.random.randn(b, m, k).astype("float16") + B = np.random.randn(k, n).astype("float16") + bias = np.random.randn(1, n).astype("float16") + A_tvm = tvm.nd.array(A, dev) + B_tvm = tvm.nd.array(B, dev) + bias_tvm = 
tvm.nd.array(bias, dev) + result = f_run(executable, dev, A_tvm, B_tvm, bias_tvm) + np.testing.assert_allclose(result.numpy(), A @ B + bias, rtol=5e-2, atol=5e-2) + + +def constructBatchGEMM_bias2(batch, M, N, K): + with IRBuilder() as ib: # pylint: disable=invalid-name + with I.ir_module() as frame: + with R.function(): + R.func_name("main") + A = R.arg( + "A", relax.TensorStructInfo((batch, M, K), A_TYPE) + ) # pylint: disable=invalid-name + B = R.arg( + "B", relax.TensorStructInfo((K, N), B_TYPE) + ) # pylint: disable=invalid-name + bias = R.arg( + "bias", relax.TensorStructInfo((N,), A_TYPE) + ) # pylint: disable=invalid-name + with R.dataflow() as df: + C = R.emit(R.matmul(A, B, out_dtype=C_TYPE)) + D = R.emit(R.add(C, bias)) + R.output(D) + (D,) = df.output_vars + R.func_ret_value(D) + relax_mod = ib.get() + return relax_mod + + +@tvm.testing.requires_cutlass +def test_cutlass_batch_dense_bias2(): + b, m, n, k = 2, 128, 256, 64 + executable = build(constructBatchGEMM_bias2(b, m, n, k)) + dev = tvm.cuda() + A = np.random.randn(b, m, k).astype("float16") + B = np.random.randn(k, n).astype("float16") + bias = np.random.randn(n).astype("float16") + A_tvm = tvm.nd.array(A, dev) + B_tvm = tvm.nd.array(B, dev) + bias_tvm = tvm.nd.array(bias, dev) + result = f_run(executable, dev, A_tvm, B_tvm, bias_tvm) + np.testing.assert_allclose(result.numpy(), A @ B + bias, rtol=5e-2, atol=5e-2) + + +def constructBatchGEMM_bias2_gelu(batch, M, N, K): + with IRBuilder() as ib: # pylint: disable=invalid-name + with I.ir_module() as frame: + with R.function(): + R.func_name("main") + A = R.arg( + "A", relax.TensorStructInfo((batch, M, K), A_TYPE) + ) # pylint: disable=invalid-name + B = R.arg( + "B", relax.TensorStructInfo((K, N), B_TYPE) + ) # pylint: disable=invalid-name + bias = R.arg( + "bias", relax.TensorStructInfo((N,), A_TYPE) + ) # pylint: disable=invalid-name + with R.dataflow() as df: + C = R.emit(R.matmul(A, B, out_dtype=C_TYPE)) + D = R.emit(R.add(C, bias)) + E = 
R.emit(R.nn.gelu(D)) + R.output(E) + (E,) = df.output_vars + R.func_ret_value(E) + relax_mod = ib.get() + return relax_mod + + +@tvm.testing.requires_cutlass +def test_cutlass_batch_dense_bias2_gelu(): + b, m, n, k = 2, 128, 64, 256 + executable = build(constructBatchGEMM_bias2_gelu(b, m, n, k)) + dev = tvm.cuda() + A = np.random.randn(b, m, k).astype("float16") + B = np.random.randn(k, n).astype("float16") + bias = np.random.randn(n).astype("float16") + A_tvm = tvm.nd.array(A, dev) + B_tvm = tvm.nd.array(B, dev) + bias_tvm = tvm.nd.array(bias, dev) + result = f_run(executable, dev, A_tvm, B_tvm, bias_tvm) + C = A @ B + bias + O = 0.5 * C * (1 + erf(C / np.sqrt(2))) + np.testing.assert_allclose(result.numpy(), O, rtol=5e-2, atol=5e-2) + + +# def constructBatchGEMM_bias2_mul(batch, M, N, K): Review Comment: remove this ########## python/tvm/_ffi/libinfo.py: ########## @@ -205,11 +205,18 @@ def find_include_path(name=None, search_path=None, optional=False): tvm_include_path = [os.path.join(p, "include") for p in header_path] dlpack_include_path = [os.path.join(p, "dlpack/include") for p in header_path] dmlc_include_path = [os.path.join(p, "dmlc-core/include") for p in header_path] - + if use_nvcc: + cutlass_include_path = [os.path.join(p, "cutlass/include") for p in header_path] Review Comment: is it possible to make it configurable and move this to some cutlass specific places? 
########## python/tvm/contrib/cc.py: ########## @@ -231,6 +234,21 @@ def _linux_compile(output, objects, options, compile_cmd, compile_shared=False): else: if compile_shared or output.endswith(".so") or output.endswith(".dylib"): cmd += ["--shared"] + + compute_version = "".join( + get_target_compute_version(Target.current(allow_none=True)).split(".") + ) + cmd += ["-gencode", f"arch=compute_{compute_version},code=sm_{compute_version}"] + cmd += ["-O3"] + cmd += ["-std=c++17"] + cmd += ["-Xcompiler=-fPIC"] + cmd += ["-Xcompiler=-fno-strict-aliasing"] + cuda_ver = get_cuda_version() + if cuda_ver >= (11, 2): + cmd += ["-t " + str(multiprocessing.cpu_count())] + cmd += ["-DCUTLASS_ENABLE_TENSOR_CORE_MMA=1"] Review Comment: these should be passed through `options` ########## python/tvm/contrib/cutlass/gemm_operation.py: ########## @@ -369,7 +369,8 @@ def instantiate_gemm_template(attrs): { "bias_decl": "void* ptr_bias = (void*)(${bias_arg}->data);\n", "ptr_c": "ptr_bias", - "c_stride": "${bias_arg}->ndim == 1 ? 0 : " + attrs["ldc"], + # "c_stride": "${bias_arg}->ndim == 1 ? 0 : " + attrs["ldc"], Review Comment: is it needed? -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected]
