This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 4247433e33 [WebGPU] Add `tir.dp4a` (#17124)
4247433e33 is described below
commit 4247433e33dfeff9bc82521ed4c7e85605d94893
Author: Jiawei Shao <[email protected]>
AuthorDate: Mon Jul 1 20:36:14 2024 +0800
[WebGPU] Add `tir.dp4a` (#17124)
* [WebGPU] Add `tir.dp4a`
This patch adds `tir.dp4a` as a new TIR built-in operator in
preparation for supporting int8 computation with `dot4I8Packed`
in the WebGPU backend.
* Fix format issues
* Fix format issue
* Replace `accumulation` with `accumulator`
---
include/tvm/tir/builtin.h | 5 +++++
python/tvm/script/ir_builder/tir/ir.py | 2 ++
python/tvm/tir/__init__.py | 1 +
python/tvm/tir/op.py | 25 +++++++++++++++++++++++++
src/tir/op/builtin.cc | 5 +++++
tests/python/tir-base/test_tir_op_types.py | 8 ++++++++
6 files changed, 46 insertions(+)
diff --git a/include/tvm/tir/builtin.h b/include/tvm/tir/builtin.h
index 120c1b71be..ea2d07903e 100644
--- a/include/tvm/tir/builtin.h
+++ b/include/tvm/tir/builtin.h
@@ -816,6 +816,11 @@ TVM_DLL const Op& vectorlow();
*/
TVM_DLL const Op& vectorcombine();
+/*!
+ * \brief Dot product of two int8x4 vectors plus an optional int32 accumulator
+ */
+TVM_DLL const Op& dp4a();
+
/*!
* \brief atomic add instruction, corresponding e.g. to atomicAdd in CUDA
*/
diff --git a/python/tvm/script/ir_builder/tir/ir.py b/python/tvm/script/ir_builder/tir/ir.py
index caefc6a6bc..bdbd6e2cda 100644
--- a/python/tvm/script/ir_builder/tir/ir.py
+++ b/python/tvm/script/ir_builder/tir/ir.py
@@ -1932,6 +1932,7 @@ vectorlow = _dtype_forward(_tir_op.vectorlow)
vectorhigh = _dtype_forward(_tir_op.vectorhigh)
vectorcombine = _dtype_forward(_tir_op.vectorcombine)
get_active_lane_mask = _dtype_forward(_tir_op.get_active_lane_mask)
+dp4a = _dtype_forward(_tir_op.dp4a)  # NOTE(review): tir.op.dp4a takes no dtype arg; confirm T.dp4a round-trips with kFirst dtype printing
broadcast = Broadcast
@@ -2191,6 +2192,7 @@ __all__ = [
"vectorlow",
"vectorhigh",
"vectorcombine",
+ "dp4a",
"assume",
"undef",
"tvm_call_packed",
diff --git a/python/tvm/tir/__init__.py b/python/tvm/tir/__init__.py
index 5360ab2b96..bcfbe6575d 100644
--- a/python/tvm/tir/__init__.py
+++ b/python/tvm/tir/__init__.py
@@ -95,6 +95,7 @@ from .op import q_multiply_shift, q_multiply_shift_per_axis, shift_left, shift_r
from .op import TVMBackendAllocWorkspace, TVMBackendFreeWorkspace
from .op import start_profile_intrinsic, end_profile_intrinsic
from .op import vscale, get_active_lane_mask, get_vscale_expr
+from .op import dp4a
from .generic import add, subtract, multiply
from .schedule import StmtSRef, BlockScope, ScheduleState, Schedule,
ScheduleError
diff --git a/python/tvm/tir/op.py b/python/tvm/tir/op.py
index 81d6604259..0bc299e403 100644
--- a/python/tvm/tir/op.py
+++ b/python/tvm/tir/op.py
@@ -1813,6 +1813,31 @@ def vectorcombine(dtype, vec1, vec2):
return call_intrin(dtype, "tir.vectorcombine", vec1, vec2)
+def dp4a(vec1, vec2, acc=0):
+ """Dot product of two int8x4 vectors plus an optional int32 accumulator
+
+ Parameters
+ ----------
+ vec1 : int8x4
+ The first input vector.
+
+ vec2 : int8x4
+ The second input vector.
+
+ acc : int32
+ The accumulator added to the dot product. Defaults to 0.
+
+ Returns
+ -------
+ call : PrimExpr
+ The int32 call expression of ``tir.dp4a``.
+ """
+ vec1 = convert(vec1)
+ vec2 = convert(vec2)
+ acc = convert(acc)
+ return call_intrin("int32", "tir.dp4a", vec1, vec2, acc)
+
+
def ret(val):
"""Create a tir return expression
diff --git a/src/tir/op/builtin.cc b/src/tir/op/builtin.cc
index 0404fd2823..0d4a213a23 100644
--- a/src/tir/op/builtin.cc
+++ b/src/tir/op/builtin.cc
@@ -355,6 +355,11 @@ TIR_DEFINE_BUILTIN_FUNC(vectorcombine)
.set_attr<TScriptDtypePrintLocation>("TScriptDtypePrintLocation",
Integer(ScriptDtypePrintLocation::kFirst));
+TIR_DEFINE_BUILTIN_FUNC(dp4a)
+ .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kPure))
+ .set_attr<TScriptDtypePrintLocation>("TScriptDtypePrintLocation",
+ Integer(ScriptDtypePrintLocation::kFirst));
+
TIR_DEFINE_BUILTIN_FUNC(atomic_add)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));
diff --git a/tests/python/tir-base/test_tir_op_types.py b/tests/python/tir-base/test_tir_op_types.py
index 7398ee781b..aefab62559 100644
--- a/tests/python/tir-base/test_tir_op_types.py
+++ b/tests/python/tir-base/test_tir_op_types.py
@@ -295,6 +295,14 @@ def test_tir_op_vectorhigh():
assert expr.op.name == "tir.vectorhigh"
+def test_tir_op_dp4a():
+ vec1 = tir.Var("vec1", dtype="int8x4")
+ vec2 = tir.Var("vec2", dtype="int8x4")
+ acc = tir.Var("acc", dtype="int32")
+ expr = tir.dp4a(vec1, vec2, acc)
+ assert expr.op.name == "tir.dp4a"  # NOTE(review): result dtype ("int32") and operands are not asserted
+
+
def test_tir_op_vectorcombine():
buffer = tir.decl_buffer((4, 4), "int8", offset_factor=1)
vec = buffer.vload([0, 0], dtype="int8x16")