This is an automated email from the ASF dual-hosted git repository.
lukhut pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 02c4c55eaa [SVE] Add codegen support for `vscale_range()` function
attribute (#16962)
02c4c55eaa is described below
commit 02c4c55eaa2fe81e516bc4741345f8fb82fc0945
Author: Andrei Hutu <[email protected]>
AuthorDate: Wed May 8 09:39:25 2024 +0100
[SVE] Add codegen support for `vscale_range()` function attribute (#16962)
This commit adds support for the `vscale_range()` LLVM function attribute
to be generated for SVE and SME targets.
Some LLVM optimisation passes make use of the `vscale_range()` function
attribute when scalable vectors are present (e.g. BasicAA
llvm/llvm-project/pull/80445), so we include it alongside the "target-cpu" and
"target-features" attributes.
---
src/target/llvm/codegen_aarch64.cc | 13 ++++++++
src/target/llvm/codegen_llvm.h | 2 +-
.../python/codegen/test_target_codegen_aarch64.py | 38 ++++++++++++++++++++++
3 files changed, 52 insertions(+), 1 deletion(-)
diff --git a/src/target/llvm/codegen_aarch64.cc b/src/target/llvm/codegen_aarch64.cc
index 94ad34bbcf..785c45457e 100644
--- a/src/target/llvm/codegen_aarch64.cc
+++ b/src/target/llvm/codegen_aarch64.cc
@@ -27,6 +27,7 @@
#include <llvm/Target/TargetMachine.h>
#include <tvm/runtime/registry.h>
+#include "../../arith/scalable_expression.h"
#include "codegen_cpu.h"
#include "llvm_instance.h"
@@ -40,6 +41,7 @@ class CodeGenAArch64 final : public CodeGenCPU {
void VisitStmt_(const AttrStmtNode* op);
void AddFunction(const GlobalVar& gvar, const PrimFunc& f);
+ void SetTargetAttributes(llvm::Function* func);
bool func_has_pstate_sm = false;
bool func_has_pstate_za = false;
@@ -51,6 +53,17 @@ void CodeGenAArch64::AddFunction(const GlobalVar& gvar, const PrimFunc& f) {
CodeGenCPU::AddFunction(gvar, f);
}
+void CodeGenAArch64::SetTargetAttributes(llvm::Function* func) {
+#if TVM_LLVM_VERSION >= 130
+  // SVE/SME only: emit vscale_range(1, N), N = number of entries in kAArch64VScaleValues.
+  if (llvm_target_->TargetHasCPUFeature("sve") || llvm_target_->TargetHasCPUFeature("sme")) {
+    func->addFnAttr(llvm::Attribute::getWithVScaleRangeArgs(
+        *llvm_target_->GetContext(), 1, tvm::arith::kAArch64VScaleValues.size()));
+  }
+#endif
+  CodeGenCPU::SetTargetAttributes(func);
+}
+
/*!
* \brief Visit and handle AArch64 specific pragmas. To be AArch64 specific,
* the expectation is that they are prepended with "pragma_aarch64".
diff --git a/src/target/llvm/codegen_llvm.h b/src/target/llvm/codegen_llvm.h
index 0f7aa847ec..d46ab7320b 100644
--- a/src/target/llvm/codegen_llvm.h
+++ b/src/target/llvm/codegen_llvm.h
@@ -431,7 +431,7 @@ class CodeGenLLVM : public ExprFunctor<llvm::Value*(const PrimExpr&)>,
*
* \param func The function to set attributes on.
*/
- void SetTargetAttributes(llvm::Function* func);
+ virtual void SetTargetAttributes(llvm::Function* func);
/*!
* \brief Emit LLVM IR for conversion functions __extendhfsf2 and __truncsfhf2
* into the current llvm::Module.
diff --git a/tests/python/codegen/test_target_codegen_aarch64.py b/tests/python/codegen/test_target_codegen_aarch64.py
index 452638beda..9726f79d7a 100644
--- a/tests/python/codegen/test_target_codegen_aarch64.py
+++ b/tests/python/codegen/test_target_codegen_aarch64.py
@@ -537,6 +537,44 @@ def test_scalable_broadcast():
assert re.findall(r" store <vscale x 4 x float>", llvm), "No scalable
store in generated LLVM."
+@pytest.mark.skipif(
+    llvm_version_major() < 13,
+    reason="Function attribute vscale_range() is not supported in earlier versions of LLVM",
+)
+@pytest.mark.parametrize(
+    "mattr,expect_attr",
+    [
+        ("+neon", False),
+        ("+sve", True),
+        ("+v9a", True),
+        ("+sme", True),
+    ],
+)
+def test_vscale_range_function_attribute(mattr, expect_attr):
+    """Check that vscale_range() is present in the generated LLVM IR iff `expect_attr` is set."""
+    target = f"llvm -mtriple=aarch64-linux-gnu -mattr={mattr}"
+    m = te.var("m")
+    A = te.placeholder(m, dtype="float32", name="A")
+    C = te.compute((m), lambda i: A[i] + 1, name="C")
+    s = te.create_schedule([C.op])
+
+    with tvm.target.Target(target) as target:
+        f = tvm.build(s, [A, C], target)
+
+    # Look for a well-formed vscale_range(<min>,<max>) attribute in the textual IR
+    ll = f.get_source("ll")
+    attr = re.findall(r"vscale_range\(\d+,\d+\)", ll)
+
+    if expect_attr:
+        assert (
+            len(attr) > 0
+        ), "Function attribute vscale_range() was not found in generated LLVM IR"
+    else:
+        assert (
+            len(attr) == 0
+        ), "Unexpected function attribute vscale_range() was found in generated LLVM IR"
+
+
@pytest.mark.skipif(
llvm_version_major() < 16, reason="Test requires an LLVM version of at
least 16 to target SME"
)