yzh119 commented on issue #15342:
URL: https://github.com/apache/tvm/issues/15342#issuecomment-1641660470

   Hi @MrJungle1 , TE is being deprecated; there is a 
[create_prim_func](https://tvm.apache.org/docs/reference/api/python/te.html?highlight=create_prim_func#tvm.te.create_prim_func)
 function that converts a TE compute into a TIR PrimFunc.
   
   Ansor is no longer the preferred auto-scheduler because it lacks support for 
Tensor Cores. MetaSchedule should work; I tried using MetaSchedule to tune the 
operator, and the script is attached below:
   
   ```python
   import tvm
   import tvm.te as te
   from tvm import meta_schedule as ms
   
   a, b, c, d = 1024, 8, 17, 17
   
   A = te.placeholder((a, b, c, d), name="A", dtype="float32")
   k1 = te.reduce_axis((0, a), name="k1")
   k2 = te.reduce_axis((0, c), name="k2")
   k3 = te.reduce_axis((0, d), name="k3")
   
   C = te.compute(
       (b,),
       lambda i: te.sum(A[k1, i, k2, k3], axis=[k1, k2, k3]),
       name="C",
   )
   
   f = te.create_prim_func([A, C])
   mod = tvm.IRModule.from_expr(f)
   
   target = tvm.target.Target("nvidia/geforce-rtx-3090", host="llvm")
   
   database = ms.tune_tir(
       mod=mod,
       target=target,
       max_trials_global=64,
       num_trials_per_iter=64,
       space="cuda",
       work_dir="./tune_tmp",
   )
   
   sch = ms.tir_integration.compile_tir(database, mod, target)
   f = tvm.build(sch.mod["main"], target=target)
   print(f.imported_modules[0].get_source())
   ```
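   
   For a quick sanity check, the reduction above is just a sum over axes 0, 2 and 3 of a `(1024, 8, 17, 17)` tensor; a NumPy reference (a sketch, independent of TVM) to compare the tuned kernel's output against:
   
   ```python
   import numpy as np
   
   # Shapes from the TE script above.
   a, b, c, d = 1024, 8, 17, 17
   x = np.random.rand(a, b, c, d).astype("float32")
   
   # NumPy reference for C[i] = sum_{k1, k2, k3} A[k1, i, k2, k3]:
   # reduce over axes 0, 2 and 3, leaving a length-b vector.
   ref = x.sum(axis=(0, 2, 3))
   assert ref.shape == (b,)
   
   # Explicit-loop form, matching the te.compute lambda one-to-one.
   loop = np.array([x[:, i, :, :].sum() for i in range(b)], dtype="float32")
   np.testing.assert_allclose(ref, loop, rtol=1e-4)
   ```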
   
   #15327 indeed improves performance. Before that PR, the TVM kernel tuned by 
metaschedule had a latency of `38.7868 us` on my machine; after #15327 was 
merged, the latency dropped to `35.7256 us`. As a reference, the PyTorch kernel 
runs in `23.62 us` on my machine; the script used to profile it is attached 
below for your reference:
   
   ```python
   import torch
   from typing import Callable
   
   def profile_pytorch_ms(f: Callable[[], None]) -> float:
       r"""
       Use Triton's profiler that flushes L2 cache.
       """
       n_warmup = 10
       n_repeat = 100
       # The following code is copied from the Triton profiler.
       cache = torch.empty(int(256e6), dtype=torch.int8, device="cuda")
       start_event = [
           torch.cuda.Event(enable_timing=True) for i in range(n_repeat)
       ]
       end_event = [
           torch.cuda.Event(enable_timing=True) for i in range(n_repeat)
       ]
       # Warm-up
       for _ in range(n_warmup):
           f()
       # Benchmark
       for i in range(n_repeat):
           # we clear the L2 cache before each run
           cache.zero_()
           # record time of `fn`
           start_event[i].record()
           f()
           end_event[i].record()
       # Record clocks
       torch.cuda.synchronize()
       times = torch.tensor(
           [s.elapsed_time(e) for s, e in zip(start_event, end_event)])
       dur = torch.mean(times).item()
       return dur
   
   x = torch.randn(1024, 8, 17, 17).float().to(0)
   
   print("{} ms".format(profile_pytorch_ms(lambda: torch.sum(x, (0, 2, 3)))))
   ```
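   
   One unit note on the numbers above: `torch.cuda.Event.elapsed_time` returns milliseconds, while the latencies are quoted in microseconds. A small sketch of the conversion, plus a median-based aggregate (a hypothetical helper; the median is more robust to occasional stragglers than the mean used above):
   
   ```python
   import statistics
   
   def ms_to_us(ms: float) -> float:
       # torch.cuda.Event.elapsed_time reports milliseconds;
       # the latencies quoted above are in microseconds.
       return ms * 1000.0
   
   def robust_latency_us(times_ms) -> float:
       # Median is less sensitive to occasional stragglers than the mean.
       return ms_to_us(statistics.median(times_ms))
   
   # Example: five samples in ms, one outlier.
   samples_ms = [0.0239, 0.0236, 0.0241, 0.0500, 0.0237]
   print(round(robust_latency_us(samples_ms), 3))  # 23.9
   ```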
   

