wllvcxz opened a new issue, #13204:
URL: https://github.com/apache/tvm/issues/13204

   The `DefaultCUDATensorCore` schedule currently supports matmul with dtypes 
`fp16fp16fp16` and `s8s8s32`: 
https://github.com/apache/tvm/blob/main/src/meta_schedule/schedule_rule/schedule_rule.cc#L122-L171
   
   When I use MetaSchedule to tune a matmul with `DefaultCUDATensorCore`, the 
result appears incorrect. However, when I change the input/output dtypes of the 
matmul to a combination that `DefaultCUDATensorCore` does not support (so it falls 
back to CUDA cores), the result is correct.
   
   ### Expected behavior
   
   MetaSchedule produces a correct result when using Tensor Cores.
   
   ### Actual behavior
   
   The result is wrong.
   
   ### Environment
   
   Operating System: Ubuntu 20.04
   TVM version: main branch
   GPU: NVIDIA A100
   
   ### Steps to reproduce
   
   ```python
   import tempfile
   import numpy as np
   import tvm
   from tvm.contrib import nvcc
   from tvm import meta_schedule as ms
   from tvm.meta_schedule import tune_tir
   from tvm.target import Target
   from tvm.meta_schedule.testing import te_workload
   
   def test_tune_tir_matmul_cuda_tensor_core(in_dtype, out_dtype, n, m, k):
       mod = tvm.te.create_prim_func(
           te_workload.matmul(n, m, k, in_dtype=in_dtype, out_dtype=out_dtype)
       )
       target = Target("nvidia/nvidia-a100")
   
       with tempfile.TemporaryDirectory() as work_dir:
           database = tune_tir(
               mod=mod,
               target=target,
               work_dir=work_dir,
               num_trials_per_iter=32,
               max_trials_global=32,
               strategy="replay-trace",
           )
           sch = ms.tir_integration.compile_tir(database, mod, target)
           if sch is None:
               raise RuntimeError("No valid schedule found!")
           ctx = tvm.cuda()
           if nvcc.have_tensorcore(ctx.compute_version):
               with tvm.transform.PassContext():
                   func = tvm.build(sch.mod["main"], [], "cuda")
                   # print(func.imported_modules[0].get_source())
                   # print(sch.mod.script())
                   print(sch.trace)
               a_np = np.random.uniform(-10, 10, size=(n, k)).astype(in_dtype)
               b_np = np.random.uniform(-10, 10, size=(k, m)).astype(in_dtype)
               a = tvm.nd.array(a_np, ctx)
               b = tvm.nd.array(b_np, ctx)
               c = tvm.nd.array(np.zeros((n, m), dtype=out_dtype), ctx)
               func(a, b, c)
                np.testing.assert_allclose(
                    c.numpy(),
                    np.matmul(a_np, b_np, dtype=out_dtype),
                    rtol=1e-6,
                    atol=1e-6,
                )
               print("passed!")
   
   if __name__ == "__main__":
       test_tune_tir_matmul_cuda_tensor_core(in_dtype="float32", 
out_dtype="float32", n=128, m=128, k=128) # cuda core, correct
       test_tune_tir_matmul_cuda_tensor_core(in_dtype="float16", 
out_dtype="float32", n=128, m=128, k=128) # cuda core, correct
       test_tune_tir_matmul_cuda_tensor_core(in_dtype="float16", 
out_dtype="float16", n=128, m=128, k=128) # tensor core, incorrect
       test_tune_tir_matmul_cuda_tensor_core(in_dtype="int8", 
out_dtype="int32", n=128, m=128, k=128)  # tensor core, incorrect
       test_tune_tir_matmul_cuda_tensor_core(in_dtype="int8", 
out_dtype="float32", n=128, m=128, k=128)  # cuda core, correct
   
   ```
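   For context on the comparison the script makes: the device result is checked against a NumPy reference that accumulates with `dtype=out_dtype`. A minimal NumPy-only sketch of that check (the `reference_check` helper is illustrative, not part of the TVM API):
   
   ```python
   import numpy as np
   
   def reference_check(a_np, b_np, c_dev, out_dtype, rtol=1e-6, atol=1e-6):
       """Compare a device result against the NumPy reference used in the repro.
   
       `c_dev` stands in for the array copied back from the GPU; the reference
       is np.matmul with an explicit output dtype, mirroring the assertion
       in the script above.
       """
       expected = np.matmul(a_np, b_np, dtype=out_dtype)
       np.testing.assert_allclose(c_dev, expected, rtol=rtol, atol=atol)
       return expected
   
   # Self-check with a trivially correct "device" result (the reference itself):
   a = np.random.uniform(-10, 10, size=(8, 4)).astype("float16")
   b = np.random.uniform(-10, 10, size=(4, 8)).astype("float16")
   c = np.matmul(a, b, dtype="float16")
   reference_check(a, b, c, "float16")
   ```
   
   Note that for `float16` outputs the tolerances above are very tight; the failures reported here are gross mismatches, not rounding-level differences.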
   
   cc @vinx13 @junrushao 
   

