Hi klimek,

Bound nodes are now stored as a (NodeBaseType, void*) pair, which allows us to
safely determine at runtime whether a given node (stored as void*) can be cast
to the type the user requests when calling getNodeAs. This logic is hidden
inside a new internal class, BoundNodesMap. As a result, support for binding
any other base node type can be added with very little effort. This change
also adds support for binding QualType (in BoundNodes).

http://llvm-reviews.chandlerc.com/D25

Files:
  
google3/third_party/llvm/llvm/tools/clang/include/clang/ASTMatchers/ASTMatchers.h
  
google3/third_party/llvm/llvm/tools/clang/include/clang/ASTMatchers/ASTMatchersInternal.h
  
google3/third_party/llvm/llvm/tools/clang/lib/ASTMatchers/ASTMatchersInternal.cpp
Index: google3/third_party/llvm/llvm/tools/clang/include/clang/ASTMatchers/ASTMatchers.h
===================================================================
--- google3/third_party/llvm/llvm/tools/clang/include/clang/ASTMatchers/ASTMatchers.h
+++ google3/third_party/llvm/llvm/tools/clang/include/clang/ASTMatchers/ASTMatchers.h
@@ -67,35 +67,29 @@
   /// \brief Returns the AST node bound to \c ID.
   /// Returns NULL if there was no node bound to \c ID or if there is a node but
   /// it cannot be converted to the specified type.
-  /// FIXME: We'll need one of those for every base type.
+  template <typename T>
+  const T *getNodeAs(StringRef ID) const {
+    return MyBoundNodes.getNodeAs<T>(ID);
+  }
+
+  /// \brief Deprecated. Please use \c getNodeAs instead.
   /// @{
   template <typename T>
   const T *getDeclAs(StringRef ID) const {
-    return getNodeAs<T>(DeclBindings, ID);
+    return getNodeAs<T>(ID);
   }
   template <typename T>
   const T *getStmtAs(StringRef ID) const {
-    return getNodeAs<T>(StmtBindings, ID);
+    return getNodeAs<T>(ID);
   }
   /// @}
 
 private:
   /// \brief Create BoundNodes from a pre-filled map of bindings.
-  BoundNodes(const std::map<std::string, const Decl*> &DeclBindings,
-             const std::map<std::string, const Stmt*> &StmtBindings)
-      : DeclBindings(DeclBindings), StmtBindings(StmtBindings) {}
-
-  template <typename T, typename MapT>
-  const T *getNodeAs(const MapT &Bindings, StringRef ID) const {
-    typename MapT::const_iterator It = Bindings.find(ID);
-    if (It == Bindings.end()) {
-      return NULL;
-    }
-    return llvm::dyn_cast<T>(It->second);
-  }
+  explicit BoundNodes(const internal::BoundNodesMap &Bindings)
+      : MyBoundNodes(Bindings) {}
 
-  std::map<std::string, const Decl*> DeclBindings;
-  std::map<std::string, const Stmt*> StmtBindings;
+  internal::BoundNodesMap MyBoundNodes;
 
   friend class internal::BoundNodesTree;
 };
Index: google3/third_party/llvm/llvm/tools/clang/include/clang/ASTMatchers/ASTMatchersInternal.h
===================================================================
--- google3/third_party/llvm/llvm/tools/clang/include/clang/ASTMatchers/ASTMatchersInternal.h
+++ google3/third_party/llvm/llvm/tools/clang/include/clang/ASTMatchers/ASTMatchersInternal.h
@@ -39,7 +39,9 @@
 #include "clang/AST/DeclCXX.h"
 #include "clang/AST/ExprCXX.h"
 #include "clang/AST/Stmt.h"
+#include "clang/AST/Type.h"
 #include "llvm/ADT/VariadicFunction.h"
+#include "llvm/Support/type_traits.h"
 #include <map>
 #include <string>
 #include <vector>
@@ -58,6 +60,152 @@
 
 class BoundNodesTreeBuilder;
 
+/// \brief Helper macro for get_base_type.
+/// Expands to the opening half of "BaseType if llvm::is_base_of<BaseType, T>,
+/// otherwise ..."; each use must be closed by a matching ">::type" just
+/// before the final " type;" of the typedef.
+#define GET_BASE_TYPE(BaseType) \
+  typename llvm::conditional<llvm::is_base_of<BaseType, T>::value, BaseType
+
+/// \brief Meta-template to get base class of a node.
+///
+/// \c type is the first of Decl, Stmt and QualType for which
+/// llvm::is_base_of<Base, T> holds, or \c void if none does.
+template <typename T>
+struct get_base_type {
+    typedef GET_BASE_TYPE(Decl), GET_BASE_TYPE(Stmt), GET_BASE_TYPE(QualType),
+      void>::type>::type>::type type;
+};
+
+// GET_BASE_TYPE is only needed by get_base_type above; do not leak it into
+// every file that includes this header.
+#undef GET_BASE_TYPE
+
+/// \brief Indicates the base type of a bound AST node.
+/// Used for storing nodes as void*.
+enum NodeBaseType {
+  NT_Decl,
+  NT_Stmt,
+  NT_QualType,
+  NT_Unknown
+};
+
+/// \brief Utility to manipulate nodes of a given base type.
+/// We use template specialization on the node base type to map it to its
+/// NodeBaseType tag and to do the matching static_cast back from void*.
+/// This primary template covers unsupported base types: they are tagged
+/// NT_Unknown and castNode always returns NULL for them.
+template <typename BaseType>
+struct NodeBaseTypeUtil {
+  /// \brief Returns the NodeBaseType corresponding to \c BaseType.
+  static NodeBaseType getNodeBaseType() {
+    return NT_Unknown;
+  }
+
+  /// \brief Casts \c Node to \c T if \c ActualBaseType matches \c BaseType.
+  /// Otherwise, NULL is returned.
+  template <typename T>
+  static const T* castNode(const NodeBaseType &ActualBaseType,
+                           const void *Node) {
+    return NULL;
+  }
+};
+
+/// \brief Macro for adding a template specialization of NodeBaseTypeUtil.
+#define NODE_BASE_TYPE_UTIL(BaseType) \
+template <>                                                           \
+struct NodeBaseTypeUtil<BaseType> {                                   \
+  static NodeBaseType getNodeBaseType() {                             \
+    return NT_##BaseType;                                             \
+  }                                                                   \
+                                                                      \
+  template <typename T>                                               \
+  static const T *castNode(const NodeBaseType &ActualBaseType,        \
+                           const void *Node) {                        \
+    if (ActualBaseType == NT_##BaseType) {                            \
+      return llvm::dyn_cast<T>(static_cast<const BaseType*>(Node));   \
+    } else {                                                          \
+      return NULL;                                                    \
+    }                                                                 \
+  }                                                                   \
+}
+
+/// \brief Template specialization of NodeBaseTypeUtil. See main definition.
+/// @{
+NODE_BASE_TYPE_UTIL(Decl);
+NODE_BASE_TYPE_UTIL(Stmt);
+NODE_BASE_TYPE_UTIL(QualType);
+/// @}
+
+#undef NODE_BASE_TYPE_UTIL
+
+/// \brief Gets the NodeBaseType of a node with type \c T.
+template <typename T>
+NodeBaseType getBaseType(const T*) {
+  return NodeBaseTypeUtil<typename get_base_type<T>::type>::getNodeBaseType();
+}
+
+/// \brief Internal version of BoundNodes. Holds all the bound nodes.
+///
+/// Each node is stored as a 'const void *' tagged with its NodeBaseType, so
+/// getNodeAs<T>() can static_cast back to the recorded base class before
+/// trying llvm::dyn_cast<T>.
+class BoundNodesMap {
+public:
+  BoundNodesMap();
+
+  BoundNodesMap(const BoundNodesMap &Other);
+
+  /// \brief Adds \c Node to the map with key \c ID, replacing any node
+  /// already bound to \c ID.
+  /// The node's base type should be in NodeBaseType or it will be inaccessible.
+  ///
+  /// NOTE(review): only the pointer is stored. For QualType, which is a value
+  /// type, the caller must keep the pointed-to object alive for as long as
+  /// this map (or any copy of it) is used.
+  template <typename T>
+  void addNode(StringRef ID, const T* Node) {
+    // Convert to the base class *before* erasing the type: castNode()
+    // static_casts the void* back to that base, which is only valid if the
+    // erased pointer had exactly that type (a derived-to-base conversion may
+    // adjust the pointer value).
+    typedef typename get_base_type<T>::type BaseType;
+    const BaseType *BaseNode = Node;
+    NodeMap[ID] = std::make_pair(getBaseType(Node),
+                                 static_cast<const void *>(BaseNode));
+  }
+
+  /// \brief Returns the AST node bound to \c ID.
+  /// Returns NULL if there was no node bound to \c ID or if there is a node but
+  /// it cannot be converted to the specified type.
+  template <typename T>
+  const T *getNodeAs(StringRef ID) const {
+    IDToNodeMap::const_iterator It = NodeMap.find(ID);
+    if (It == NodeMap.end()) {
+      return NULL;
+    }
+
+    return NodeBaseTypeUtil<typename get_base_type<T>::type>::
+        template castNode<T>(It->second.first, It->second.second);
+  }
+
+  /// \brief Copies all ID/Node pairs to BoundNodesTreeBuilder \c Builder.
+  void copyTo(BoundNodesTreeBuilder *Builder) const;
+
+  /// \brief Copies all ID/Node pairs to BoundNodesMap \c Other.
+  /// IDs that are already bound in \c Other keep their existing node.
+  void copyTo(BoundNodesMap *Other) const;
+
+private:
+  /// \brief A node and its base type.
+  typedef std::pair<NodeBaseType, const void*> NodeTypePair;
+
+  /// \brief A map from IDs to node/type pairs.
+  typedef std::map<std::string, NodeTypePair> IDToNodeMap;
+
+  IDToNodeMap NodeMap;
+};
+
 /// \brief A tree of bound nodes in match results.
 ///
 /// If a match can contain multiple matches on the same node with different
@@ -84,11 +202,11 @@
   BoundNodesTree();
 
   /// \brief Create a BoundNodesTree from pre-filled maps of bindings.
-  BoundNodesTree(const std::map<std::string, const Decl*>& DeclBindings,
-                 const std::map<std::string, const Stmt*>& StmtBindings,
+  BoundNodesTree(const BoundNodesMap &Bindings,
                  const std::vector<BoundNodesTree> RecursiveBindings);
 
-  /// \brief Adds all bound nodes to bound_nodes_builder.
+  /// \brief Copies the bindings of this tree into \c Builder and adds each
+  /// recursive subtree to it as a separate match.
   void copyTo(BoundNodesTreeBuilder* Builder) const;
 
   /// \brief Visits all matches that this BoundNodesTree represents.
@@ -99,17 +216,12 @@
 private:
   void visitMatchesRecursively(
       Visitor* ResultVistior,
-      std::map<std::string, const Decl*> DeclBindings,
-      std::map<std::string, const Stmt*> StmtBindings);
-
-  template <typename T>
-  void copyBindingsTo(const T& bindings, BoundNodesTreeBuilder* Builder) const;
+      BoundNodesMap *AggregatedBindings);  // In/out accumulator.
 
   // FIXME: Find out whether we want to use different data structures here -
   // first benchmarks indicate that it doesn't matter though.
 
-  std::map<std::string, const Decl*> DeclBindings;
-  std::map<std::string, const Stmt*> StmtBindings;
+  BoundNodesMap Bindings;  // Nodes bound at this level of the tree.
 
   std::vector<BoundNodesTree> RecursiveBindings;
 };
@@ -123,12 +235,12 @@
   BoundNodesTreeBuilder();
 
   /// \brief Add a binding from an id to a node.
-  ///
-  /// FIXME: Add overloads for all AST base types.
-  /// @{
-  void setBinding(const std::string &Id, const Decl *Node);
-  void setBinding(const std::string &Id, const Stmt *Node);
-  /// @}
+  /// Any node pointer type is accepted; see BoundNodesMap::addNode for which
+  /// types can later be retrieved through getNodeAs.
+  template <typename NodeT>
+  void setBinding(const std::string &Id, const NodeT *Node) {
+    Bindings.addNode(Id, Node);
+  }
 
   /// \brief Adds a branch in the tree.
   void addMatch(const BoundNodesTree& Bindings);
@@ -140,8 +250,7 @@
   BoundNodesTreeBuilder(const BoundNodesTreeBuilder&);  // DO NOT IMPLEMENT
   void operator=(const BoundNodesTreeBuilder&);  // DO NOT IMPLEMENT
 
-  std::map<std::string, const Decl*> DeclBindings;
-  std::map<std::string, const Stmt*> StmtBindings;
+  BoundNodesMap Bindings;  // Filled by setBinding(), consumed by build().
 
   std::vector<BoundNodesTree> RecursiveBindings;
 };
Index: google3/third_party/llvm/llvm/tools/clang/lib/ASTMatchers/ASTMatchersInternal.cpp
===================================================================
--- google3/third_party/llvm/llvm/tools/clang/lib/ASTMatchers/ASTMatchersInternal.cpp
+++ google3/third_party/llvm/llvm/tools/clang/lib/ASTMatchers/ASTMatchersInternal.cpp
@@ -18,83 +18,92 @@
 namespace ast_matchers {
 namespace internal {
 
+BoundNodesMap::BoundNodesMap() {}
+
+BoundNodesMap::BoundNodesMap(const BoundNodesMap &Other)
+    : NodeMap(Other.NodeMap) {}
+
+void BoundNodesMap::copyTo(BoundNodesTreeBuilder *Builder) const {
+  // Re-dispatch on the recorded base type so that the builder's templated
+  // setBinding() sees the real node type. Passing the stored 'const void *'
+  // directly would instantiate setBinding<void>() and re-tag every node as
+  // NT_Unknown, making it unreachable through getNodeAs<T>().
+  for (IDToNodeMap::const_iterator It = NodeMap.begin(), End = NodeMap.end();
+       It != End; ++It) {
+    const std::string &ID = It->first;
+    const void *Node = It->second.second;
+    switch (It->second.first) {
+    case NT_Decl:
+      Builder->setBinding(ID, static_cast<const Decl *>(Node));
+      break;
+    case NT_Stmt:
+      Builder->setBinding(ID, static_cast<const Stmt *>(Node));
+      break;
+    case NT_QualType:
+      Builder->setBinding(ID, static_cast<const QualType *>(Node));
+      break;
+    case NT_Unknown:
+      Builder->setBinding(ID, Node);
+      break;
+    }
+  }
+}
+
+void BoundNodesMap::copyTo(BoundNodesMap *Other) const {
+  // std::map::insert never overwrites an existing key, so IDs already bound
+  // in *Other keep their existing node.
+  Other->NodeMap.insert(NodeMap.begin(), NodeMap.end());
+}
+
 BoundNodesTree::BoundNodesTree() {}
 
 BoundNodesTree::BoundNodesTree(
-  const std::map<std::string, const Decl*>& DeclBindings,
-  const std::map<std::string, const Stmt*>& StmtBindings,
+  const BoundNodesMap& Bindings,
   const std::vector<BoundNodesTree> RecursiveBindings)
-  : DeclBindings(DeclBindings), StmtBindings(StmtBindings),
+  : Bindings(Bindings),
     RecursiveBindings(RecursiveBindings) {}
 
 void BoundNodesTree::copyTo(BoundNodesTreeBuilder* Builder) const {
-  copyBindingsTo(DeclBindings, Builder);
-  copyBindingsTo(StmtBindings, Builder);
+  Bindings.copyTo(Builder);
   for (std::vector<BoundNodesTree>::const_iterator
          I = RecursiveBindings.begin(),
          E = RecursiveBindings.end();
        I != E; ++I) {
     Builder->addMatch(*I);
   }
 }
 
-template <typename T>
-void BoundNodesTree::copyBindingsTo(
-    const T& Bindings, BoundNodesTreeBuilder* Builder) const {
-  for (typename T::const_iterator I = Bindings.begin(),
-                                  E = Bindings.end();
-       I != E; ++I) {
-    Builder->setBinding(I->first, I->second);
-  }
-}
-
 void BoundNodesTree::visitMatches(Visitor* ResultVisitor) {
-  std::map<std::string, const Decl*> AggregatedDeclBindings;
-  std::map<std::string, const Stmt*> AggregatedStmtBindings;
-  visitMatchesRecursively(ResultVisitor, AggregatedDeclBindings,
-                          AggregatedStmtBindings);
+  BoundNodesMap AggregatedBindings;
+  visitMatchesRecursively(ResultVisitor, &AggregatedBindings);
 }
 
 void BoundNodesTree::
 visitMatchesRecursively(Visitor* ResultVisitor,
-                        std::map<std::string, const Decl*>
-                          AggregatedDeclBindings,
-                        std::map<std::string, const Stmt*>
-                          AggregatedStmtBindings) {
-  copy(DeclBindings.begin(), DeclBindings.end(),
-       inserter(AggregatedDeclBindings, AggregatedDeclBindings.begin()));
-  copy(StmtBindings.begin(), StmtBindings.end(),
-       inserter(AggregatedStmtBindings, AggregatedStmtBindings.begin()));
+                        BoundNodesMap* AggregatedBindings) {
+  Bindings.copyTo(AggregatedBindings);
   if (RecursiveBindings.empty()) {
-    ResultVisitor->visitMatch(BoundNodes(AggregatedDeclBindings,
-                                         AggregatedStmtBindings));
+    ResultVisitor->visitMatch(BoundNodes(*AggregatedBindings));
   } else {
     for (unsigned I = 0; I < RecursiveBindings.size(); ++I) {
+      // Each branch is an alternative match, so it gets its own copy of the
+      // bindings aggregated so far; sharing one map would let bindings added
+      // in one branch leak into the sibling branches visited after it.
+      BoundNodesMap BranchBindings(*AggregatedBindings);
       RecursiveBindings[I].visitMatchesRecursively(ResultVisitor,
-                                                   AggregatedDeclBindings,
-                                                   AggregatedStmtBindings);
+                                                   &BranchBindings);
     }
   }
 }
 
 BoundNodesTreeBuilder::BoundNodesTreeBuilder() {}
 
-void BoundNodesTreeBuilder::setBinding(const std::string &Id,
-                                       const Decl *Node) {
-  DeclBindings[Id] = Node;
-}
-
-void BoundNodesTreeBuilder::setBinding(const std::string &Id,
-                                       const Stmt *Node) {
-  StmtBindings[Id] = Node;
-}
-
 void BoundNodesTreeBuilder::addMatch(const BoundNodesTree& Bindings) {
   RecursiveBindings.push_back(Bindings);
 }
 
 BoundNodesTree BoundNodesTreeBuilder::build() const {
-  return BoundNodesTree(DeclBindings, StmtBindings, RecursiveBindings);
+  return BoundNodesTree(Bindings, RecursiveBindings);
 }
 
 } // end namespace internal
_______________________________________________
cfe-commits mailing list
[email protected]
http://lists.cs.uiuc.edu/mailman/listinfo/cfe-commits

Reply via email to