wllvcxz opened a new issue, #13204: URL: https://github.com/apache/tvm/issues/13204
The `DefaultCUDATensorCore` schedule currently supports matmul with dtypes `fp16fp16fp16` and `s8s8s32`: https://github.com/apache/tvm/blob/main/src/meta_schedule/schedule_rule/schedule_rule.cc#L122-L171

When I use MetaSchedule to tune a matmul with `DefaultCUDATensorCore`, the results seem incorrect. However, when I change the input/output dtypes of the matmul to a combination that `DefaultCUDATensorCore` does not support, the result is correct.

### Expected behavior

MetaSchedule produces correct results when using tensor cores.

### Actual behavior

The result is wrong.

### Environment

Operating System: Ubuntu 20.04
TVM version: main branch
GPU: NVIDIA A100

### Steps to reproduce

```python
import tempfile

import numpy as np

import tvm
from tvm import meta_schedule as ms
from tvm.contrib import nvcc
from tvm.meta_schedule import tune_tir
from tvm.meta_schedule.testing import te_workload
from tvm.target import Target


def test_tune_tir_matmul_cuda_tensor_core(in_dtype, out_dtype, n, m, k):
    mod = tvm.te.create_prim_func(
        te_workload.matmul(n, m, k, in_dtype=in_dtype, out_dtype=out_dtype)
    )
    target = Target("nvidia/nvidia-a100")
    with tempfile.TemporaryDirectory() as work_dir:
        database = tune_tir(
            mod=mod,
            target=target,
            work_dir=work_dir,
            num_trials_per_iter=32,
            max_trials_global=32,
            strategy="replay-trace",
        )
        sch = ms.tir_integration.compile_tir(database, mod, target)
        if sch is None:
            raise RuntimeError("No valid schedule found!")

    ctx = tvm.cuda()
    if nvcc.have_tensorcore(ctx.compute_version):
        with tvm.transform.PassContext():
            func = tvm.build(sch.mod["main"], [], "cuda")
        # print(func.imported_modules[0].get_source())
        # print(sch.mod.script())
        print(sch.trace)

        a_np = np.random.uniform(-10, 10, size=(n, k)).astype(in_dtype)
        b_np = np.random.uniform(-10, 10, size=(k, m)).astype(in_dtype)
        a = tvm.nd.array(a_np, ctx)
        b = tvm.nd.array(b_np, ctx)
        c = tvm.nd.array(np.zeros((n, m), dtype=out_dtype), ctx)
        func(a, b, c)
        np.testing.assert_allclose(
            c.asnumpy(),
            np.matmul(a_np, b_np, dtype=out_dtype),
            rtol=1e-6,
            atol=1e-6,
        )
        print("passed!")


if __name__ == "__main__":
    test_tune_tir_matmul_cuda_tensor_core(in_dtype="float32", out_dtype="float32", n=128, m=128, k=128)  # cuda core, correct
    test_tune_tir_matmul_cuda_tensor_core(in_dtype="float16", out_dtype="float32", n=128, m=128, k=128)  # cuda core, correct
    test_tune_tir_matmul_cuda_tensor_core(in_dtype="float16", out_dtype="float16", n=128, m=128, k=128)  # tensor core, incorrect
    test_tune_tir_matmul_cuda_tensor_core(in_dtype="int8", out_dtype="int32", n=128, m=128, k=128)  # tensor core, incorrect
    test_tune_tir_matmul_cuda_tensor_core(in_dtype="int8", out_dtype="float32", n=128, m=128, k=128)  # cuda core, correct
```

cc @vinx13 @junrushao

--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at: [email protected]
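Editor's note, not part of the original report: for the `float16`/`float16` case, the reproducer's `rtol=1e-6`/`atol=1e-6` tolerance is tighter than what fp16 can represent at all, since fp16 has only an ~11-bit significand (relative quantization step around `2**-11 ≈ 4.9e-4`). So for that case a tolerance-driven `assert_allclose` failure would need to be ruled out separately from a genuine miscompile (the int8/int32 failure is not explained by this, since integer matmul is exact). A standalone NumPy sketch, using the same shapes and input distribution as the reproducer:

```python
import numpy as np

# How far does an fp16-stored matmul result deviate from an fp32 one,
# for the same 128x128x128 shapes and uniform(-10, 10) inputs as the
# reproducer? Even with exact internal arithmetic, merely storing the
# output in fp16 quantizes far more coarsely than the 1e-6 tolerance.
rng = np.random.default_rng(0)
n = m = k = 128
a = rng.uniform(-10, 10, size=(n, k)).astype("float16")
b = rng.uniform(-10, 10, size=(k, m)).astype("float16")

ref16 = np.matmul(a, b, dtype="float16")  # result rounded to fp16
ref32 = np.matmul(a.astype("float32"), b.astype("float32"))  # fp32 reference

# Deviation of the fp16 result, relative to the largest output magnitude
# (a per-element relative error would blow up on near-zero outputs).
rel = np.max(np.abs(ref16.astype("float32") - ref32)) / np.max(np.abs(ref32))
print(f"max deviation relative to largest output: {rel:.2e}")  # far above 1e-6
```

This does not show the bug report is wrong, only that `rtol=1e-6` cannot distinguish a correct fp16 tensor-core kernel from an incorrect one; a dtype-aware tolerance (or an fp32 comparison baseline) would make the fp16 failure unambiguous.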
