This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new ed5c46aebc [ROCm][WebGPU] Intrin Dispatch: `tanh`, `erf`, `log` (#16441)
ed5c46aebc is described below

commit ed5c46aebcfc89a6d13c33272572f2be5d9575d7
Author: Junru Shao <[email protected]>
AuthorDate: Sun Jan 21 06:43:27 2024 -0800

    [ROCm][WebGPU] Intrin Dispatch: `tanh`, `erf`, `log` (#16441)
    
    This commit fixes a few minor intrinsic dispatch issues in the ROCm
    and WebGPU backends that affect LLM compilation, including Mixtral,
    RedPajama (GPT-NeoX) and GPT-BigCode.
---
 python/tvm/autotvm/tuner/droplet_tuner.py |  6 +++++-
 python/tvm/topi/math.py                   | 11 +++++-----
 src/target/llvm/intrin_rule_rocm.cc       | 34 +++++++++++++++++--------------
 3 files changed, 29 insertions(+), 22 deletions(-)

diff --git a/python/tvm/autotvm/tuner/droplet_tuner.py b/python/tvm/autotvm/tuner/droplet_tuner.py
index d58bfa4989..d115353d77 100644
--- a/python/tvm/autotvm/tuner/droplet_tuner.py
+++ b/python/tvm/autotvm/tuner/droplet_tuner.py
@@ -18,8 +18,9 @@
 
 import logging
 import os
+
 import numpy as np
-from scipy import stats
+
 from .tuner import Tuner
 
 LOGGER = logging.getLogger("autotvm")
@@ -85,6 +86,9 @@ class DropletTuner(Tuner):
     def p_value(self, elem_1, elem_2):
         if len(elem_1) <= 1 or len(elem_2) <= 1:
             return True
+
+        from scipy import stats  # pylint: disable=import-outside-toplevel
+
         return stats.ttest_ind(np.array(elem_1), np.array(elem_2)).pvalue <= self.pvalue
 
     def next_batch(self, batch_size):
diff --git a/python/tvm/topi/math.py b/python/tvm/topi/math.py
index 8b66ca2cc9..63a1e48c2b 100644
--- a/python/tvm/topi/math.py
+++ b/python/tvm/topi/math.py
@@ -20,8 +20,7 @@ import tvm
 from tvm import te
 from tvm.tir import PrimExpr
 
-from . import tag
-from . import cpp
+from . import cpp, tag
 from .utils import get_const_tuple
 
 
@@ -855,17 +854,17 @@ def ceil_log2(x):
     if "float" in x.dtype:
         return tvm.tir.ceil(tvm.tir.log2(x))
 
-    if "vulkan" in tvm.target.Target.current().kind.name:
+    target = tvm.target.Target.current()
+
+    if "vulkan" in target.kind.name:
         clz = tvm.tir.clz(x)
         bits = int(x.dtype[-2:])
         res = tvm.tir.if_then_else(x & (x - 1) == 0, bits - clz - 1, bits - clz)
-
         if res.dtype != x.dtype:
             return cast(res, x.dtype)
-
         return res
 
-    if "adreno" in tvm.target.Target.current().device_name:
+    if "adreno" in target.device_name or target.kind.name in ["metal", "rocm", "webgpu"]:
         return cast(tvm.tir.ceil(tvm.tir.log2(cast(x, "float32"))), x.dtype)
 
     return cast(tvm.tir.ceil(tvm.tir.log2(cast(x, "float64"))), x.dtype)
diff --git a/src/target/llvm/intrin_rule_rocm.cc b/src/target/llvm/intrin_rule_rocm.cc
index 0fbfade335..c80d8388da 100644
--- a/src/target/llvm/intrin_rule_rocm.cc
+++ b/src/target/llvm/intrin_rule_rocm.cc
@@ -31,12 +31,15 @@
 
 #include <sstream>
 
+#include "../intrin_rule.h"
 #include "intrin_rule_llvm.h"
 
 namespace tvm {
 namespace codegen {
 
 inline PrimExpr DispatchPureExternOCML(const PrimExpr& e) {
+  // NOTE: OCML dispatch fails to work properly with vectorization, and thus should be used with
+  // extreme caution.
   using namespace tir;
   const CallNode* call = e.as<CallNode>();
   ICHECK(call != nullptr);
@@ -150,13 +153,6 @@ TVM_REGISTER_OP("tir.exp2")
     .set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic",
                                DispatchLLVMPureIntrin<::llvm::Intrinsic::exp2, 1>);
 
-// TVM_REGISTER_OP("tir.exp10")
-//     .set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic",
-//                                DispatchLLVMPureIntrin<::llvm::Intrinsic::exp10, 1>);
-
-// TVM_REGISTER_OP("tir.erf").set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic",
-//                                                      DispatchPureExternOCML);
-
 TVM_REGISTER_OP("tir.fma").set_attr<FLowerIntrinsic>(
     "rocm.FLowerIntrinsic", DispatchLLVMPureIntrin<::llvm::Intrinsic::fmuladd, 3>);
 
@@ -178,27 +174,35 @@ TVM_REGISTER_OP("tir.sqrt")
 TVM_REGISTER_OP("tir.pow").set_attr<FLowerIntrinsic>(
     "rocm.FLowerIntrinsic", DispatchLLVMPureIntrin<::llvm::Intrinsic::pow, 2>);
 
-// TVM_REGISTER_OP("tir.tanh")
-//     .set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic", DispatchPureExternOCML);
+TVM_REGISTER_OP("tir.cos").set_attr<FLowerIntrinsic>(
+    "rocm.FLowerIntrinsic", DispatchLLVMPureIntrin<::llvm::Intrinsic::cos, 1>);
+
+TVM_REGISTER_OP("tir.sin").set_attr<FLowerIntrinsic>(
+    "rocm.FLowerIntrinsic", DispatchLLVMPureIntrin<::llvm::Intrinsic::sin, 1>);
+
+TVM_REGISTER_OP("tir.tanh")
+    .set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic",
+                               ::tvm::codegen::intrin::DispatchNumericalStableTanh);
+
+TVM_REGISTER_OP("tir.erf").set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic",
+                                                     ::tvm::codegen::intrin::DispatchFastErf);
 
 // TVM_REGISTER_OP("tir.tan").set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic",
 //                                                      DispatchPureExternOCML);
 
-TVM_REGISTER_OP("tir.cos").set_attr<FLowerIntrinsic>(
-    "rocm.FLowerIntrinsic", DispatchLLVMPureIntrin<::llvm::Intrinsic::cos, 1>);
-
 // TVM_REGISTER_OP("tir.cosh")
 //     .set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic", DispatchPureExternOCML);
 
-TVM_REGISTER_OP("tir.sin").set_attr<FLowerIntrinsic>(
-    "rocm.FLowerIntrinsic", DispatchLLVMPureIntrin<::llvm::Intrinsic::sin, 1>);
-
 // TVM_REGISTER_OP("tir.sinh")
 //     .set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic", DispatchPureExternOCML);
 
 // TVM_REGISTER_OP("tir.atan")
 //     .set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic", DispatchPureExternOCML);
 
+// TVM_REGISTER_OP("tir.exp10")
+//     .set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic",
+//                                DispatchLLVMPureIntrin<::llvm::Intrinsic::exp10, 1>);
+
 }  // namespace llvm
 }  // namespace codegen
 }  // namespace tvm

Reply via email to