This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 95cb0de27a [VULKAN] Fix CLZ support for Vulkan (#16858)
95cb0de27a is described below
commit 95cb0de27a8bcfe0586f38d8b0d2da955cf01432
Author: Siyuan Feng <[email protected]>
AuthorDate: Wed Apr 10 20:21:20 2024 +0800
[VULKAN] Fix CLZ support for Vulkan (#16858)
CLZ (count leading zeros) is used to improve ceil_log2 performance
on Vulkan. However, the current implementation is incorrect when the
argument's dtype is converted. This PR:
1. Constant-folds clz in the rewrite simplifier (needed for index calculation in Vulkan sort)
2. Fixes clz for data type conversion
---
python/tvm/target/detect_target.py | 3 ++-
src/arith/rewrite_simplify.cc | 11 +++++++++++
src/tir/ir/data_type_rewriter.cc | 11 +++++++++++
tests/python/arith/test_arith_rewrite_simplify.py | 20 ++++++++++++++++++--
.../test_tir_transform_force_narrow_index_to_i32.py | 19 +++++++++++++++++++
5 files changed, 61 insertions(+), 3 deletions(-)
diff --git a/python/tvm/target/detect_target.py b/python/tvm/target/detect_target.py
index aada611642..a2fe5e1f8b 100644
--- a/python/tvm/target/detect_target.py
+++ b/python/tvm/target/detect_target.py
@@ -67,8 +67,9 @@ def _detect_vulkan(dev: Device) -> Target:
"max_shared_memory_per_block": dev.max_shared_memory_per_block,
"thread_warp_size": dev.warp_size,
"supports_float16": f_get_target_property(dev, "supports_float16"),
- "supports_int16": f_get_target_property(dev, "supports_int16"),
"supports_int8": f_get_target_property(dev, "supports_int8"),
+ "supports_int16": f_get_target_property(dev, "supports_int16"),
+ "supports_int64": f_get_target_property(dev, "supports_int64"),
"supports_16bit_buffer": f_get_target_property(dev,
"supports_16bit_buffer"),
}
)
diff --git a/src/arith/rewrite_simplify.cc b/src/arith/rewrite_simplify.cc
index e7e58a80fc..a4602bb8b9 100644
--- a/src/arith/rewrite_simplify.cc
+++ b/src/arith/rewrite_simplify.cc
@@ -2250,6 +2250,17 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const CallNode* op) {
}
}
}
+ } else if (op->op.same_as(Op::Get("tir.clz"))) {
+ if (const auto* arg_int = op->args[0].as<IntImmNode>()) {
+ int bits = arg_int->dtype.bits();
+ if (arg_int->value == 0) return make_const(op->dtype, bits);
+ for (int i = bits - 1; i >= 0; --i) {
+ if ((int64_t(1) << i) & arg_int->value) {
+ return IntImm(op->dtype, bits - i - 1);
+ }
+ }
+ LOG(FATAL) << "Should not reach here";
+ }
}
if (op->op.same_as(tir::builtin::likely())) {
diff --git a/src/tir/ir/data_type_rewriter.cc b/src/tir/ir/data_type_rewriter.cc
index 3461597b8e..a613b8d4bb 100644
--- a/src/tir/ir/data_type_rewriter.cc
+++ b/src/tir/ir/data_type_rewriter.cc
@@ -215,6 +215,7 @@ TVM_DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(GENode, operator>=);
#undef TVM_DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH
PrimExpr DataTypeLegalizer::VisitExpr_(const CallNode* op) {
+ Call before = GetRef<Call>(op);
PrimExpr e = StmtExprMutator::VisitExpr_(op);
op = e.as<CallNode>();
static const Op& builtin_pow_ = Op::Get("tir.pow");
@@ -234,6 +235,16 @@ PrimExpr DataTypeLegalizer::VisitExpr_(const CallNode* op) {
return pow(op->args[0], op->args[1]);
} else if (op->op.same_as(builtin::if_then_else())) {
return if_then_else(op->args[0], op->args[1], op->args[2]);
+ } else if (op->op.same_as(Op::Get("tir.clz"))) {
+ DataType before_dtype = before->args[0]->dtype;
+ DataType after_dtype = op->args[0]->dtype;
+ CHECK(before_dtype.is_int() && (before_dtype.bits() == 32 ||
before_dtype.bits() == 64))
+ << "clz only supports 32 or 64 bit integer types, but get type before
legalizing: "
+ << before_dtype;
+ CHECK(after_dtype.is_int() && (after_dtype.bits() == 32 ||
after_dtype.bits() == 64))
+ << "clz only supports 32 or 64 bit integer types, but get type after
legalizing: "
+ << after_dtype;
+ return e - after_dtype.bits() + before_dtype.bits();
}
return e;
}
diff --git a/tests/python/arith/test_arith_rewrite_simplify.py b/tests/python/arith/test_arith_rewrite_simplify.py
index 9cc44aa6a2..6180167555 100644
--- a/tests/python/arith/test_arith_rewrite_simplify.py
+++ b/tests/python/arith/test_arith_rewrite_simplify.py
@@ -20,9 +20,12 @@ import inspect
import pytest
import tvm
+import tvm.testing
from tvm import te, tir
-
-from tvm.tir import truncdiv as tdiv, truncmod as tmod, floordiv as fld,
floormod as flm
+from tvm.tir import floordiv as fld
+from tvm.tir import floormod as flm
+from tvm.tir import truncdiv as tdiv
+from tvm.tir import truncmod as tmod
class TestCase:
@@ -1150,5 +1153,18 @@ class TestIfThenElse(BaseCompare):
)
+class TestCLZ(BaseCompare):
+ test_case = tvm.testing.parameter(
+ TestCase(tvm.tir.call_intrin("int32", "tir.clz", 0), 32),
+ TestCase(tvm.tir.call_intrin("int32", "tir.clz", 1), 31),
+ TestCase(tvm.tir.call_intrin("int32", "tir.clz", 2), 30),
+ TestCase(tvm.tir.call_intrin("int32", "tir.clz", 128), 24),
+ TestCase(tvm.tir.call_intrin("int32", "tir.clz",
tvm.tir.IntImm("int64", 0)), 64),
+ TestCase(tvm.tir.call_intrin("int32", "tir.clz",
tvm.tir.IntImm("int64", 1)), 63),
+ TestCase(tvm.tir.call_intrin("int32", "tir.clz",
tvm.tir.IntImm("int64", 2)), 62),
+ TestCase(tvm.tir.call_intrin("int32", "tir.clz",
tvm.tir.IntImm("int64", 128)), 56),
+ )
+
+
if __name__ == "__main__":
tvm.testing.main()
diff --git a/tests/python/tir-transform/test_tir_transform_force_narrow_index_to_i32.py b/tests/python/tir-transform/test_tir_transform_force_narrow_index_to_i32.py
index c1b81853de..0be0e5fbb5 100644
--- a/tests/python/tir-transform/test_tir_transform_force_narrow_index_to_i32.py
+++ b/tests/python/tir-transform/test_tir_transform_force_narrow_index_to_i32.py
@@ -259,5 +259,24 @@ def test_pod_params_and_select():
tvm.ir.assert_structural_equal(Expected, after)
+def test_clz():
+ @tvm.script.ir_module
+ class Before:
+ @T.prim_func
+ def main(B: T.Buffer((T.int64(4),), "int32")):
+ for i in T.serial(T.int64(4)):
+ B[i] = T.clz(i)
+
+ @tvm.script.ir_module
+ class Expected:
+ @T.prim_func
+ def main(B: T.Buffer((4,), "int32")):
+ for i in range(4):
+ B[i] = T.clz(i) - 32 + 64
+
+ after = tvm.tir.transform.ForceNarrowIndexToInt32()(Before)
+ tvm.ir.assert_structural_equal(Expected, after)
+
+
if __name__ == "__main__":
tvm.testing.main()