This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new ed5c46aebc [ROCm][WebGPU] Intrin Dispatch: `tanh`, `erf`, `log`
(#16441)
ed5c46aebc is described below
commit ed5c46aebcfc89a6d13c33272572f2be5d9575d7
Author: Junru Shao <[email protected]>
AuthorDate: Sun Jan 21 06:43:27 2024 -0800
[ROCm][WebGPU] Intrin Dispatch: `tanh`, `erf`, `log` (#16441)
This commit fixes a few minor intrinsic dispatch issues in the ROCm
and WebGPU backends that affect LLM compilation, including Mixtral,
RedPajama (GPT-NeoX) and GPT-BigCode.
---
python/tvm/autotvm/tuner/droplet_tuner.py | 6 +++++-
python/tvm/topi/math.py | 11 +++++-----
src/target/llvm/intrin_rule_rocm.cc | 34 +++++++++++++++++--------------
3 files changed, 29 insertions(+), 22 deletions(-)
diff --git a/python/tvm/autotvm/tuner/droplet_tuner.py
b/python/tvm/autotvm/tuner/droplet_tuner.py
index d58bfa4989..d115353d77 100644
--- a/python/tvm/autotvm/tuner/droplet_tuner.py
+++ b/python/tvm/autotvm/tuner/droplet_tuner.py
@@ -18,8 +18,9 @@
import logging
import os
+
import numpy as np
-from scipy import stats
+
from .tuner import Tuner
LOGGER = logging.getLogger("autotvm")
@@ -85,6 +86,9 @@ class DropletTuner(Tuner):
def p_value(self, elem_1, elem_2):
if len(elem_1) <= 1 or len(elem_2) <= 1:
return True
+
+ from scipy import stats # pylint: disable=import-outside-toplevel
+
return stats.ttest_ind(np.array(elem_1), np.array(elem_2)).pvalue <=
self.pvalue
def next_batch(self, batch_size):
diff --git a/python/tvm/topi/math.py b/python/tvm/topi/math.py
index 8b66ca2cc9..63a1e48c2b 100644
--- a/python/tvm/topi/math.py
+++ b/python/tvm/topi/math.py
@@ -20,8 +20,7 @@ import tvm
from tvm import te
from tvm.tir import PrimExpr
-from . import tag
-from . import cpp
+from . import cpp, tag
from .utils import get_const_tuple
@@ -855,17 +854,17 @@ def ceil_log2(x):
if "float" in x.dtype:
return tvm.tir.ceil(tvm.tir.log2(x))
- if "vulkan" in tvm.target.Target.current().kind.name:
+ target = tvm.target.Target.current()
+
+ if "vulkan" in target.kind.name:
clz = tvm.tir.clz(x)
bits = int(x.dtype[-2:])
res = tvm.tir.if_then_else(x & (x - 1) == 0, bits - clz - 1, bits -
clz)
-
if res.dtype != x.dtype:
return cast(res, x.dtype)
-
return res
- if "adreno" in tvm.target.Target.current().device_name:
+ if "adreno" in target.device_name or target.kind.name in ["metal", "rocm",
"webgpu"]:
return cast(tvm.tir.ceil(tvm.tir.log2(cast(x, "float32"))), x.dtype)
return cast(tvm.tir.ceil(tvm.tir.log2(cast(x, "float64"))), x.dtype)
diff --git a/src/target/llvm/intrin_rule_rocm.cc
b/src/target/llvm/intrin_rule_rocm.cc
index 0fbfade335..c80d8388da 100644
--- a/src/target/llvm/intrin_rule_rocm.cc
+++ b/src/target/llvm/intrin_rule_rocm.cc
@@ -31,12 +31,15 @@
#include <sstream>
+#include "../intrin_rule.h"
#include "intrin_rule_llvm.h"
namespace tvm {
namespace codegen {
inline PrimExpr DispatchPureExternOCML(const PrimExpr& e) {
+ // NOTE: OCML dispatch fails to work properly with vectorization, and thus should be used with
+ // extreme caution.
using namespace tir;
const CallNode* call = e.as<CallNode>();
ICHECK(call != nullptr);
@@ -150,13 +153,6 @@ TVM_REGISTER_OP("tir.exp2")
.set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic",
DispatchLLVMPureIntrin<::llvm::Intrinsic::exp2,
1>);
-// TVM_REGISTER_OP("tir.exp10")
-// .set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic",
-//
DispatchLLVMPureIntrin<::llvm::Intrinsic::exp10, 1>);
-
-// TVM_REGISTER_OP("tir.erf").set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic",
-//
DispatchPureExternOCML);
-
TVM_REGISTER_OP("tir.fma").set_attr<FLowerIntrinsic>(
"rocm.FLowerIntrinsic", DispatchLLVMPureIntrin<::llvm::Intrinsic::fmuladd,
3>);
@@ -178,27 +174,35 @@ TVM_REGISTER_OP("tir.sqrt")
TVM_REGISTER_OP("tir.pow").set_attr<FLowerIntrinsic>(
"rocm.FLowerIntrinsic", DispatchLLVMPureIntrin<::llvm::Intrinsic::pow, 2>);
-// TVM_REGISTER_OP("tir.tanh")
-// .set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic",
DispatchPureExternOCML);
+TVM_REGISTER_OP("tir.cos").set_attr<FLowerIntrinsic>(
+ "rocm.FLowerIntrinsic", DispatchLLVMPureIntrin<::llvm::Intrinsic::cos, 1>);
+
+TVM_REGISTER_OP("tir.sin").set_attr<FLowerIntrinsic>(
+ "rocm.FLowerIntrinsic", DispatchLLVMPureIntrin<::llvm::Intrinsic::sin, 1>);
+
+TVM_REGISTER_OP("tir.tanh")
+ .set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic",
+
::tvm::codegen::intrin::DispatchNumericalStableTanh);
+
+TVM_REGISTER_OP("tir.erf").set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic",
+
::tvm::codegen::intrin::DispatchFastErf);
// TVM_REGISTER_OP("tir.tan").set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic",
//
DispatchPureExternOCML);
-TVM_REGISTER_OP("tir.cos").set_attr<FLowerIntrinsic>(
- "rocm.FLowerIntrinsic", DispatchLLVMPureIntrin<::llvm::Intrinsic::cos, 1>);
-
// TVM_REGISTER_OP("tir.cosh")
// .set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic",
DispatchPureExternOCML);
-TVM_REGISTER_OP("tir.sin").set_attr<FLowerIntrinsic>(
- "rocm.FLowerIntrinsic", DispatchLLVMPureIntrin<::llvm::Intrinsic::sin, 1>);
-
// TVM_REGISTER_OP("tir.sinh")
// .set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic",
DispatchPureExternOCML);
// TVM_REGISTER_OP("tir.atan")
// .set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic",
DispatchPureExternOCML);
+// TVM_REGISTER_OP("tir.exp10")
+// .set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic",
+//
DispatchLLVMPureIntrin<::llvm::Intrinsic::exp10, 1>);
+
} // namespace llvm
} // namespace codegen
} // namespace tvm