yzh119 commented on issue #15342: URL: https://github.com/apache/tvm/issues/15342#issuecomment-1641660470
Hi @MrJungle1, TE is being deprecated; there is a [create_prim_func](https://tvm.apache.org/docs/reference/api/python/te.html?highlight=create_prim_func#tvm.te.create_prim_func) function that turns a TE compute into a TIR PrimFunc. Ansor is no longer the preferred auto-scheduler because it lacks support for Tensor Cores. MetaSchedule should work. I tried using MetaSchedule to tune the operator; the script is attached below:

```python
import tvm
import tvm.te as te
from tvm import meta_schedule as ms

# Sum-reduction over axes 0, 2, 3 of a (1024, 8, 17, 17) tensor, producing shape (8,).
a, b, c, d = 1024, 8, 17, 17
A = te.placeholder((a, b, c, d), name="A", dtype="float32")
k1 = te.reduce_axis((0, a), name="k1")
k2 = te.reduce_axis((0, c), name="k2")
k3 = te.reduce_axis((0, d), name="k3")
C = te.compute(
    (b,),
    lambda i: te.sum(A[k1, i, k2, k3], axis=[k1, k2, k3]),
    name="C",
)

f = te.create_prim_func([A, C])
mod = tvm.IRModule.from_expr(f)
target = tvm.target.Target("nvidia/geforce-rtx-3090", host="llvm")

database = ms.tune_tir(
    mod=mod,
    target=target,
    max_trials_global=64,
    num_trials_per_iter=64,
    space="cuda",
    work_dir="./tune_tmp",
)
sch = ms.tir_integration.compile_tir(database, mod, target)
f = tvm.build(sch.mod["main"], target=target)
print(f.imported_modules[0].get_source())
```

#15327 indeed improves performance. Before that PR, the best TVM kernel found by MetaSchedule had a latency of `38.7868 us` on my machine; after #15327 was merged, the MetaSchedule-tuned kernel has a latency of `35.7256 us`. As a reference, the PyTorch kernel runs in `23.62 us` on my machine. The script used to profile the PyTorch kernel is attached below for your reference:

```python
from typing import Callable

import torch


def profile_pytorch_ms(f: Callable[[], None]) -> float:
    r"""Profile ``f`` with CUDA events, flushing the L2 cache before each
    run. The cache-flushing trick is copied from Triton's profiler."""
    n_warmup = 10
    n_repeat = 100
    # A 256 MB buffer; zeroing it between runs evicts the L2 cache.
    cache = torch.empty(int(256e6), dtype=torch.int8, device="cuda")
    start_event = [torch.cuda.Event(enable_timing=True) for _ in range(n_repeat)]
    end_event = [torch.cuda.Event(enable_timing=True) for _ in range(n_repeat)]
    # Warm-up
    for _ in range(n_warmup):
        f()
    # Benchmark
    for i in range(n_repeat):
        # Clear the L2 cache before each run.
        cache.zero_()
        # Record the time of `f`.
        start_event[i].record()
        f()
        end_event[i].record()
    # Wait for all recorded events to complete.
    torch.cuda.synchronize()
    times = torch.tensor(
        [s.elapsed_time(e) for s, e in zip(start_event, end_event)]
    )
    return torch.mean(times).item()


x = torch.randn(1024, 8, 17, 17).float().to(0)
print("{} ms".format(profile_pytorch_ms(lambda: torch.sum(x, (0, 2, 3)))))
```
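For completeness, the reduction being tuned (a sum over axes 0, 2, and 3 of a `(1024, 8, 17, 17)` tensor, yielding a length-8 vector) can be sanity-checked on the CPU against a NumPy reference. This is a minimal sketch, not part of the original scripts:

```python
import numpy as np

# Same shape as the TE placeholder above: (a, b, c, d) = (1024, 8, 17, 17).
x = np.random.randn(1024, 8, 17, 17).astype("float32")

# Reduce over axes 0, 2, 3, matching torch.sum(x, (0, 2, 3)); output shape is (8,).
ref = x.sum(axis=(0, 2, 3))
assert ref.shape == (8,)

# Equivalent per-output-element form, mirroring the TE lambda i: sum over k1, k2, k3.
manual = np.stack([x[:, i, :, :].sum() for i in range(8)])
np.testing.assert_allclose(ref, manual, rtol=1e-4, atol=1e-2)
```

The same reference values can be compared against the output of the tuned TVM kernel to confirm the schedule search did not change the computed result.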
