This is an automated email from the ASF dual-hosted git repository.
ekalda pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new c20cdafcbc [SME] Target parser support for SME (#16794)
c20cdafcbc is described below
commit c20cdafcbc17d9a6b72fe324da7c2b295074a081
Author: Luke Hutton <[email protected]>
AuthorDate: Tue Apr 2 13:54:15 2024 +0100
[SME] Target parser support for SME (#16794)
This commit adds support for recognising when the SME architecture
feature is available based on the target string. A Python user can
use `target.features.has_sme` to check availability.
---
src/target/parsers/aprofile.cc | 3 ++-
tests/cpp/target/parsers/aprofile_test.cc | 17 +++++++++++++++++
2 files changed, 19 insertions(+), 1 deletion(-)
diff --git a/src/target/parsers/aprofile.cc b/src/target/parsers/aprofile.cc
index 907e0cae72..f84c7485a0 100644
--- a/src/target/parsers/aprofile.cc
+++ b/src/target/parsers/aprofile.cc
@@ -111,7 +111,8 @@ static TargetFeatures GetFeatures(TargetJSON target) {
{"has_sve", Bool(has_feature("sve"))},
{"has_dotprod", Bool(has_feature("dotprod"))},
{"has_matmul_i8", Bool(has_feature("i8mm"))},
- {"has_fp16_simd", Bool(has_feature("fullfp16"))}};
+ {"has_fp16_simd", Bool(has_feature("fullfp16"))},
+ {"has_sme", Bool(has_feature("sme"))}};
#endif
LOG(WARNING) << "Cannot parse Arm(R)-based target features without LLVM support.";
diff --git a/tests/cpp/target/parsers/aprofile_test.cc b/tests/cpp/target/parsers/aprofile_test.cc
index a134e162fc..d329a9b958 100644
--- a/tests/cpp/target/parsers/aprofile_test.cc
+++ b/tests/cpp/target/parsers/aprofile_test.cc
@@ -38,6 +38,7 @@ static float defaultI8MM = 8.6;
static float optionalI8MM[] = {8.2, 8.3, 8.4, 8.5};
static float defaultDotProd = 8.4;
static float optionalDotProd[] = {8.2, 8.3};
+static float optionalSME[] = {9.2, 9.3};
static bool CheckArchitectureAvailability() {
#if TVM_LLVM_VERSION > 120
@@ -405,6 +406,21 @@ TEST(AProfileParserInvalid, LLVMUnsupportedArchitecture) {
}
}
+using AProfileOptionalSME = AProfileParserTestWithParam;
+TEST_P(AProfileOptionalSME, OptionalSMESupport) {
+ const std::string arch_attr = "+v9a";
+
+ TargetJSON target = ParseTargetWithAttrs("", "aarch64-arm-none-eabi", {arch_attr});
+ TargetFeatures features = Downcast<TargetFeatures>(target.at("features"));
+ ASSERT_TRUE(IsArch(target));
+ ASSERT_FALSE(Downcast<Bool>(features.at("has_sme")));
+
+ target = ParseTargetWithAttrs("", "aarch64-arm-none-eabi", {arch_attr, "+sme"});
+ features = Downcast<TargetFeatures>(target.at("features"));
+ ASSERT_TRUE(IsArch(target));
+ ASSERT_TRUE(Downcast<Bool>(features.at("has_sme")));
+}
+
INSTANTIATE_TEST_SUITE_P(AProfileParser, AProfileOptionalI8MM, ::testing::ValuesIn(optionalI8MM));
INSTANTIATE_TEST_SUITE_P(AProfileParser, AProfileOptionalDotProd,
::testing::ValuesIn(optionalDotProd));
@@ -412,6 +428,7 @@ INSTANTIATE_TEST_SUITE_P(AProfileParser, AProfileOptionalSVE,
::testing::Values(8.0, 8.1, 8.2, 8.3, 8.4, 8.5, 8.6, 8.7, 8.8, 8.9));
INSTANTIATE_TEST_SUITE_P(AProfileParser, AProfileOptionalFP16,
::testing::Values(8.2, 8.3, 8.4, 8.5, 8.6, 8.7, 8.8, 8.9));
+INSTANTIATE_TEST_SUITE_P(AProfileParser, AProfileOptionalSME, ::testing::ValuesIn(optionalSME));
} // namespace aprofile
} // namespace parsers