https://github.com/kparzysz created https://github.com/llvm/llvm-project/pull/87247
Emit a special leaf construct table in DirectiveEmitter.cpp, which will allow both decomposition of a construct into leaves, and composition of constituent constructs into a single compound construct (if possible). >From 2fec99813013adf1ab6b262132ddebe4356ce643 Mon Sep 17 00:00:00 2001 From: Krzysztof Parzyszek <krzysztof.parzys...@amd.com> Date: Mon, 1 Apr 2024 10:07:45 -0500 Subject: [PATCH] [Frontend][OpenMP] Refactor getLeafConstructs, add getCompoundConstruct Emit a special leaf construct table in DirectiveEmitter.cpp, which will allow both decomposition of a construct into leaves, and composition of constituent constructs into a single compound construct (if possible). --- llvm/include/llvm/Frontend/OpenMP/OMP.h | 7 + llvm/lib/Frontend/OpenMP/OMP.cpp | 64 +++++- llvm/unittests/Frontend/CMakeLists.txt | 1 + llvm/unittests/Frontend/OpenMPComposeTest.cpp | 40 ++++ llvm/utils/TableGen/DirectiveEmitter.cpp | 194 +++++++++++------- 5 files changed, 235 insertions(+), 71 deletions(-) create mode 100644 llvm/unittests/Frontend/OpenMPComposeTest.cpp diff --git a/llvm/include/llvm/Frontend/OpenMP/OMP.h b/llvm/include/llvm/Frontend/OpenMP/OMP.h index a85cd9d344c6d7..4ed47f15dfe59e 100644 --- a/llvm/include/llvm/Frontend/OpenMP/OMP.h +++ b/llvm/include/llvm/Frontend/OpenMP/OMP.h @@ -15,4 +15,11 @@ #include "llvm/Frontend/OpenMP/OMP.h.inc" +#include "llvm/ADT/ArrayRef.h" + +namespace llvm::omp { +ArrayRef<Directive> getLeafConstructs(Directive D); +Directive getCompoundConstruct(ArrayRef<Directive> Parts); +} // namespace llvm::omp + #endif // LLVM_FRONTEND_OPENMP_OMP_H diff --git a/llvm/lib/Frontend/OpenMP/OMP.cpp b/llvm/lib/Frontend/OpenMP/OMP.cpp index 4f2f95392648b3..dd99d3d074fd1e 100644 --- a/llvm/lib/Frontend/OpenMP/OMP.cpp +++ b/llvm/lib/Frontend/OpenMP/OMP.cpp @@ -8,12 +8,74 @@ #include "llvm/Frontend/OpenMP/OMP.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringRef.h" #include 
"llvm/ADT/StringSwitch.h" #include "llvm/Support/ErrorHandling.h" +#include <algorithm> +#include <iterator> +#include <type_traits> + using namespace llvm; -using namespace omp; +using namespace llvm::omp; #define GEN_DIRECTIVES_IMPL #include "llvm/Frontend/OpenMP/OMP.inc" + +namespace llvm::omp { +ArrayRef<Directive> getLeafConstructs(Directive D) { + auto Idx = static_cast<int>(D); + if (Idx < 0 || Idx >= static_cast<int>(Directive_enumSize)) + return {}; + const auto *Row = LeafConstructTable[LeafConstructTableOrdering[Idx]]; + return ArrayRef(&Row[2], &Row[2] + static_cast<int>(Row[1])); +} + +Directive getCompoundConstruct(ArrayRef<Directive> Parts) { + if (Parts.empty()) + return OMPD_unknown; + + // Parts don't have to be leafs, so expand them into leafs first. + // Store the expanded leafs in the same format as rows in the leaf + // table (generated by tablegen). + SmallVector<Directive> RawLeafs(2); + for (Directive P : Parts) { + ArrayRef<Directive> Ls = getLeafConstructs(P); + if (!Ls.empty()) + RawLeafs.append(Ls.begin(), Ls.end()); + else + RawLeafs.push_back(P); + } + + auto GivenLeafs{ArrayRef<Directive>(RawLeafs).drop_front(2)}; + if (GivenLeafs.size() == 1) + return GivenLeafs.front(); + RawLeafs[1] = static_cast<Directive>(GivenLeafs.size()); + + auto Iter = llvm::lower_bound( + LeafConstructTable, + static_cast<std::decay_t<decltype(*LeafConstructTable)>>(RawLeafs.data()), + [](const auto *RowA, const auto *RowB) { + const auto *BeginA = &RowA[2]; + const auto *EndA = BeginA + static_cast<int>(RowA[1]); + const auto *BeginB = &RowB[2]; + const auto *EndB = BeginB + static_cast<int>(RowB[1]); + if (BeginA == EndA && BeginB == EndB) + return static_cast<int>(RowA[0]) < static_cast<int>(RowB[0]); + return std::lexicographical_compare(BeginA, EndA, BeginB, EndB); + }); + + if (Iter == std::end(LeafConstructTable)) + return OMPD_unknown; + + // Verify that we got a match. 
+ Directive Found = (*Iter)[0]; + ArrayRef<Directive> FoundLeafs = getLeafConstructs(Found); + if (FoundLeafs == GivenLeafs) + return Found; + return OMPD_unknown; +} +} // namespace llvm::omp diff --git a/llvm/unittests/Frontend/CMakeLists.txt b/llvm/unittests/Frontend/CMakeLists.txt index c6f60142d6276a..ddb6a16cbb984e 100644 --- a/llvm/unittests/Frontend/CMakeLists.txt +++ b/llvm/unittests/Frontend/CMakeLists.txt @@ -14,6 +14,7 @@ add_llvm_unittest(LLVMFrontendTests OpenMPContextTest.cpp OpenMPIRBuilderTest.cpp OpenMPParsingTest.cpp + OpenMPComposeTest.cpp DEPENDS acc_gen diff --git a/llvm/unittests/Frontend/OpenMPComposeTest.cpp b/llvm/unittests/Frontend/OpenMPComposeTest.cpp new file mode 100644 index 00000000000000..2dc35aca8842e9 --- /dev/null +++ b/llvm/unittests/Frontend/OpenMPComposeTest.cpp @@ -0,0 +1,40 @@ +//===- llvm/unittests/Frontend/OpenMPComposeTest.cpp ----------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "llvm/ADT/ArrayRef.h" +#include "llvm/Frontend/OpenMP/OMP.h" +#include "gtest/gtest.h" + +using namespace llvm; +using namespace llvm::omp; + +TEST(Composition, GetLeafConstructs) { + ArrayRef<Directive> L1 = getLeafConstructs(OMPD_loop); + ASSERT_EQ(L1, (ArrayRef<Directive>{})); + ArrayRef<Directive> L2 = getLeafConstructs(OMPD_parallel_for); + ASSERT_EQ(L2, (ArrayRef<Directive>{OMPD_parallel, OMPD_for})); + ArrayRef<Directive> L3 = getLeafConstructs(OMPD_parallel_for_simd); + ASSERT_EQ(L3, (ArrayRef<Directive>{OMPD_parallel, OMPD_for, OMPD_simd})); +} + +TEST(Composition, GetCompoundConstruct) { + Directive C1 = getCompoundConstruct({OMPD_target, OMPD_teams, OMPD_distribute}); + ASSERT_EQ(C1, OMPD_target_teams_distribute); + Directive C2 = getCompoundConstruct({OMPD_target}); + ASSERT_EQ(C2, OMPD_target); + Directive C3 = getCompoundConstruct({OMPD_target, OMPD_masked}); + ASSERT_EQ(C3, OMPD_unknown); + Directive C4 = getCompoundConstruct({OMPD_target, OMPD_teams_distribute}); + ASSERT_EQ(C4, OMPD_target_teams_distribute); + Directive C5 = getCompoundConstruct({OMPD_target, OMPD_teams_distribute}); + ASSERT_EQ(C5, OMPD_target_teams_distribute); + Directive C6 = getCompoundConstruct({}); + ASSERT_EQ(C6, OMPD_unknown); + Directive C7 = getCompoundConstruct({OMPD_parallel_for, OMPD_simd}); + ASSERT_EQ(C7, OMPD_parallel_for_simd); +} diff --git a/llvm/utils/TableGen/DirectiveEmitter.cpp b/llvm/utils/TableGen/DirectiveEmitter.cpp index e0edf1720f8ac5..34b517e816a243 100644 --- a/llvm/utils/TableGen/DirectiveEmitter.cpp +++ b/llvm/utils/TableGen/DirectiveEmitter.cpp @@ -20,6 +20,9 @@ #include "llvm/TableGen/Record.h" #include "llvm/TableGen/TableGenBackend.h" +#include <numeric> +#include <vector> + using namespace llvm; namespace { @@ -39,7 +42,8 @@ class IfDefScope { }; } // namespace -// Generate enum class +// 
Generate enum class. Entries are emitted in the order in which they appear +// in the `Records` vector. static void GenerateEnumClass(const std::vector<Record *> &Records, raw_ostream &OS, StringRef Enum, StringRef Prefix, const DirectiveLanguage &DirLang, @@ -175,6 +179,16 @@ bool DirectiveLanguage::HasValidityErrors() const { return HasDuplicateClausesInDirectives(getDirectives()); } +// Count the maximum number of leaf constituents per construct. +static size_t GetMaxLeafCount(const DirectiveLanguage &DirLang) { + size_t MaxCount = 0; + for (Record *R : DirLang.getDirectives()) { + size_t Count = Directive{R}.getLeafConstructs().size(); + MaxCount = std::max(MaxCount, Count); + } + return MaxCount; +} + // Generate the declaration section for the enumeration in the directive // language static void EmitDirectivesDecl(RecordKeeper &Records, raw_ostream &OS) { @@ -189,6 +203,7 @@ static void EmitDirectivesDecl(RecordKeeper &Records, raw_ostream &OS) { if (DirLang.hasEnableBitmaskEnumInNamespace()) OS << "#include \"llvm/ADT/BitmaskEnum.h\"\n"; + OS << "#include <cstddef>\n"; // for size_t OS << "\n"; OS << "namespace llvm {\n"; OS << "class StringRef;\n"; @@ -244,7 +259,8 @@ static void EmitDirectivesDecl(RecordKeeper &Records, raw_ostream &OS) { OS << "bool isAllowedClauseForDirective(Directive D, " << "Clause C, unsigned Version);\n"; OS << "\n"; - OS << "llvm::ArrayRef<Directive> getLeafConstructs(Directive D);\n"; + OS << "constexpr std::size_t getMaxLeafCount() { return " + << GetMaxLeafCount(DirLang) << "; }\n"; OS << "Association getDirectiveAssociation(Directive D);\n"; if (EnumHelperFuncs.length() > 0) { OS << EnumHelperFuncs; @@ -396,6 +412,19 @@ GenerateCaseForVersionedClauses(const std::vector<Record *> &Clauses, } } +static std::string GetDirectiveName(const DirectiveLanguage &DirLang, + const Record *Rec) { + Directive Dir{Rec}; + return (llvm::Twine("llvm::") + DirLang.getCppNamespace() + "::" + + DirLang.getDirectivePrefix() + 
Dir.getFormattedName()) + .str(); +} + +static std::string GetDirectiveType(const DirectiveLanguage &DirLang) { + return (llvm::Twine("llvm::") + DirLang.getCppNamespace() + "::Directive") + .str(); +} + // Generate the isAllowedClauseForDirective function implementation. static void GenerateIsAllowedClause(const DirectiveLanguage &DirLang, raw_ostream &OS) { @@ -450,77 +479,102 @@ static void GenerateIsAllowedClause(const DirectiveLanguage &DirLang, OS << "}\n"; // End of function isAllowedClauseForDirective } -// Generate the getLeafConstructs function implementation. -static void GenerateGetLeafConstructs(const DirectiveLanguage &DirLang, - raw_ostream &OS) { - auto getQualifiedName = [&](StringRef Formatted) -> std::string { - return (llvm::Twine("llvm::") + DirLang.getCppNamespace() + - "::Directive::" + DirLang.getDirectivePrefix() + Formatted) - .str(); - }; - - // For each list of leaves, generate a static local object, then - // return a reference to that object for a given directive, e.g. +static void EmitLeafTable(const DirectiveLanguage &DirLang, raw_ostream &OS, + StringRef TableName) { + // The leaf constructs are emitted in a form of a 2D table, where each + // row corresponds to a directive (and there is a row for each directive). // - // static ListTy leafConstructs_A_B = { A, B }; - // static ListTy leafConstructs_C_D_E = { C, D, E }; - // switch (Dir) { - // case A_B: - // return leafConstructs_A_B; - // case C_D_E: - // return leafConstructs_C_D_E; - // } - - // Map from a record that defines a directive to the name of the - // local object with the list of its leaves. 
- DenseMap<Record *, std::string> ListNames; - - std::string DirectiveTypeName = - std::string("llvm::") + DirLang.getCppNamespace().str() + "::Directive"; - - OS << '\n'; - - // ArrayRef<...> llvm::<ns>::GetLeafConstructs(llvm::<ns>::Directive Dir) - OS << "llvm::ArrayRef<" << DirectiveTypeName - << "> llvm::" << DirLang.getCppNamespace() << "::getLeafConstructs(" - << DirectiveTypeName << " Dir) "; - OS << "{\n"; - - // Generate the locals. - for (Record *R : DirLang.getDirectives()) { - Directive Dir{R}; + // Each row consists of + // - the id of the directive itself, + // - number of leaf constructs that will follow (0 for leafs), + // - ids of the leaf constructs (none if the directive is itself a leaf). + // The total number of these entries is at most MaxLeafCount+2. If this + // number is less than that, it is padded to occupy exactly MaxLeafCount+2 + // entries in memory. + // + // The rows are stored in the table in the lexicographical order. This + // is intended to enable binary search when mapping a sequence of leafs + // back to the compound directive. + // The consequence of that is that in order to find a row corresponding + // to the given directive, we'd need to scan the first element of each + // row. To avoid this, an auxiliary ordering table is created, such that + // row for Dir_A = table[auxiliary[Dir_A]]. + + std::vector<Record *> Directives = DirLang.getDirectives(); + DenseMap<Record *, size_t> DirId; // Record * -> llvm::omp::Directive + + for (auto [Idx, Rec] : llvm::enumerate(Directives)) + DirId.insert(std::make_pair(Rec, Idx)); + + using LeafList = std::vector<int>; + int MaxLeafCount = GetMaxLeafCount(DirLang); + + // The initial leaf table, rows order is same as directive order. 
+ std::vector<LeafList> LeafTable(Directives.size()); + for (auto [Idx, Rec] : llvm::enumerate(Directives)) { + Directive Dir{Rec}; + std::vector<Record *> Leaves = Dir.getLeafConstructs(); + + auto &List = LeafTable[Idx]; + List.resize(MaxLeafCount + 2); + List[0] = Idx; // The id of the directive itself. + List[1] = Leaves.size(); // The number of leaves to follow. + + for (int I = 0; I != MaxLeafCount; ++I) + List[I + 2] = + static_cast<size_t>(I) < Leaves.size() ? DirId.at(Leaves[I]) : -1; + } - std::vector<Record *> LeafConstructs = Dir.getLeafConstructs(); - if (LeafConstructs.empty()) - continue; + // Avoid sorting the vector<vector> array, instead sort an index array. + // It will also be useful later to create the auxiliary indexing array. + std::vector<int> Ordering(Directives.size()); + std::iota(Ordering.begin(), Ordering.end(), 0); + + llvm::sort(Ordering, [&](int A, int B) { + auto &LeavesA = LeafTable[A]; + auto &LeavesB = LeafTable[B]; + if (LeavesA[1] == 0 && LeavesB[1] == 0) + return LeavesA[0] < LeavesB[0]; + return std::lexicographical_compare(&LeavesA[2], &LeavesA[2] + LeavesA[1], + &LeavesB[2], &LeavesB[2] + LeavesB[1]); + }); - std::string ListName = "leafConstructs_" + Dir.getFormattedName(); - OS << " static const " << DirectiveTypeName << ' ' << ListName - << "[] = {\n"; - for (Record *L : LeafConstructs) { - Directive LeafDir{L}; - OS << " " << getQualifiedName(LeafDir.getFormattedName()) << ",\n"; + // Emit the table + + // The directives are emitted into a scoped enum, for which the underlying + // type is `int` (by default). The code above uses `int` to store directive + // ids, so make sure that we catch it when something changes in the + // underlying type. 
+ std::string DirectiveType = GetDirectiveType(DirLang); + OS << "static_assert(sizeof(" << DirectiveType << ") == sizeof(int));\n"; + + OS << "[[maybe_unused]] static const " << DirectiveType << ' ' << TableName + << "[][" << MaxLeafCount + 2 << "] = {\n"; + for (size_t I = 0, E = Directives.size(); I != E; ++I) { + auto &Leaves = LeafTable[Ordering[I]]; + OS << " " << GetDirectiveName(DirLang, Directives[Leaves[0]]); + OS << ", static_cast<" << DirectiveType << ">(" << Leaves[1] << "),"; + for (size_t I = 2, E = Leaves.size(); I != E; ++I) { + int Idx = Leaves[I]; + if (Idx >= 0) + OS << ' ' << GetDirectiveName(DirLang, Directives[Leaves[I]]) << ','; + else + OS << " static_cast<" << DirectiveType << ">(-1),"; } - OS << " };\n"; - ListNames.insert(std::make_pair(R, std::move(ListName))); - } - - if (!ListNames.empty()) OS << '\n'; - OS << " switch (Dir) {\n"; - for (Record *R : DirLang.getDirectives()) { - auto F = ListNames.find(R); - if (F == ListNames.end()) - continue; - - Directive Dir{R}; - OS << " case " << getQualifiedName(Dir.getFormattedName()) << ":\n"; - OS << " return " << F->second << ";\n"; } - OS << " default:\n"; - OS << " return ArrayRef<" << DirectiveTypeName << ">{};\n"; - OS << " } // switch (Dir)\n"; - OS << "}\n"; + OS << "};\n\n"; + + // Emit the auxiliary index table: it's the inverse of the `Ordering` + // table above. 
+ OS << "[[maybe_unused]] static const int " << TableName << "Ordering[] = {\n"; + OS << " "; + std::vector<int> Reverse(Ordering.size()); + for (int I = 0, E = Ordering.size(); I != E; ++I) + Reverse[Ordering[I]] = I; + for (int Idx : Reverse) + OS << ' ' << Idx << ','; + OS << "\n};\n"; } static void GenerateGetDirectiveAssociation(const DirectiveLanguage &DirLang, @@ -1105,11 +1159,11 @@ void EmitDirectivesBasicImpl(const DirectiveLanguage &DirLang, // isAllowedClauseForDirective(Directive D, Clause C, unsigned Version) GenerateIsAllowedClause(DirLang, OS); - // getLeafConstructs(Directive D) - GenerateGetLeafConstructs(DirLang, OS); - // getDirectiveAssociation(Directive D) GenerateGetDirectiveAssociation(DirLang, OS); + + // Leaf table for getLeafConstructs, etc. + EmitLeafTable(DirLang, OS, "LeafConstructTable"); } // Generate the implemenation section for the enumeration in the directive _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits