libstdc++: [_Hashtable] Do not reuse untrusted cached hash code

On merge, reuse the merged node's cached hash code only if both containers use
the same hash type and that hash is stateless. Using a function pointer or
std::function as the hash functor prevents this optimization.

libstdc++-v3/ChangeLog

    * include/bits/hashtable_policy.h
    (_Hash_code_base::_M_hash_code(const _Hash&, const _Hash_node_value<>&)): Remove.
    (_Hash_code_base::_M_hash_code<_H2>(const _H2&, const _Hash_node_value<>&)): Remove.
    * include/bits/hashtable.h
    (_M_src_hash_code<_H2>(const _H2&, const key_type&, const __node_value_type&)): New.
    (_M_merge_unique<>, _M_merge_multi<>): Use latter.
    * testsuite/23_containers/unordered_map/modifiers/merge.cc
    (test04, test05, test06): New test cases.

Tested under Linux x86_64, OK to commit?

François

diff --git a/libstdc++-v3/include/bits/hashtable.h b/libstdc++-v3/include/bits/hashtable.h
index 4c12dc895b2..f69acfe5213 100644
--- a/libstdc++-v3/include/bits/hashtable.h
+++ b/libstdc++-v3/include/bits/hashtable.h
@@ -1109,6 +1109,20 @@ _GLIBCXX_BEGIN_NAMESPACE_VERSION
 	return { __n, this->_M_node_allocator() };
       }
 
+      // Return the hash code of a node from a container using hash _H2;
+      // its cached code, if any, is reused only if _H2 is _Hash and is empty.
+      template<typename _H2>
+	__hash_code
+	_M_src_hash_code(const _H2&, const key_type& __k,
+			 const __node_value_type& __src_n) const
+	{
+	  if constexpr (std::is_same_v<_H2, _Hash>)
+	    if constexpr (std::is_empty_v<_Hash>)
+	      return this->_M_hash_code(__src_n);
+
+	  return this->_M_hash_code(__k);
+	}
+
     public:
       // Extract a node.
       node_type
@@ -1146,7 +1160,7 @@ _GLIBCXX_BEGIN_NAMESPACE_VERSION
 	      auto __pos = __i++;
 	      const key_type& __k = _ExtractKey{}(*__pos);
 	      __hash_code __code
-		= this->_M_hash_code(__src.hash_function(), *__pos._M_cur);
+		= _M_src_hash_code(__src.hash_function(), __k, *__pos._M_cur);
 	      size_type __bkt = _M_bucket_index(__code);
 	      if (_M_find_node(__bkt, __k, __code) == nullptr)
 		{
@@ -1174,8 +1188,9 @@ _GLIBCXX_BEGIN_NAMESPACE_VERSION
 	  for (auto __i = __src.cbegin(), __end = __src.cend(); __i != __end;)
 	    {
 	      auto __pos = __i++;
+	      const key_type& __k = _ExtractKey{}(*__pos);
 	      __hash_code __code
-		= this->_M_hash_code(__src.hash_function(), *__pos._M_cur);
+		= _M_src_hash_code(__src.hash_function(), __k, *__pos._M_cur);
 	      auto __nh = __src.extract(__pos);
 	      __hint = _M_insert_multi_node(__hint, __code, __nh._M_ptr)._M_cur;
 	      __nh._M_ptr = nullptr;
diff --git a/libstdc++-v3/include/bits/hashtable_policy.h b/libstdc++-v3/include/bits/hashtable_policy.h
index 86b32fb15f2..5d162463dc3 100644
--- a/libstdc++-v3/include/bits/hashtable_policy.h
+++ b/libstdc++-v3/include/bits/hashtable_policy.h
@@ -1319,19 +1319,6 @@ namespace __detail
 	  return _M_hash()(__k);
 	}
 
-      __hash_code
-      _M_hash_code(const _Hash&,
-		   const _Hash_node_value<_Value, true>& __n) const
-      { return __n._M_hash_code; }
-
-      // Compute hash code using _Hash as __n _M_hash_code, if present, was
-      // computed using _H2.
-      template<typename _H2>
-	__hash_code
-	_M_hash_code(const _H2&,
-		const _Hash_node_value<_Value, __cache_hash_code>& __n) const
-	{ return _M_hash_code(_ExtractKey{}(__n._M_v())); }
-
       __hash_code
       _M_hash_code(const _Hash_node_value<_Value, false>& __n) const
       { return _M_hash_code(_ExtractKey{}(__n._M_v())); }
diff --git a/libstdc++-v3/testsuite/23_containers/unordered_map/modifiers/merge.cc b/libstdc++-v3/testsuite/23_containers/unordered_map/modifiers/merge.cc
index b140ce452aa..c051b58137a 100644
--- a/libstdc++-v3/testsuite/23_containers/unordered_map/modifiers/merge.cc
+++ b/libstdc++-v3/testsuite/23_containers/unordered_map/modifiers/merge.cc
@@ -17,15 +17,29 @@
 
 // { dg-do run { target c++17 } }
 
+#include <string>
+#include <functional>
 #include <unordered_map>
 #include <algorithm>
 #include <testsuite_hooks.h>
 
 using test_type = std::unordered_map<int, int>;
 
-struct hash {
-  auto operator()(int i) const noexcept { return ~std::hash<int>()(i); }
-};
+template<typename T>
+  struct xhash
+  {
+    auto operator()(const T& i) const noexcept
+    { return ~std::hash<T>()(i); }
+  };
+
+
+namespace std
+{
+  template<typename T>
+    struct __is_fast_hash<xhash<T>> : __is_fast_hash<std::hash<T>>
+    { };
+}
+
 struct equal : std::equal_to<> { };
 
 template<typename C1, typename C2>
@@ -64,7 +78,7 @@ test02()
 {
   const test_type c0{ {1, 10}, {2, 20}, {3, 30} };
   test_type c1 = c0;
-  std::unordered_map<int, int, hash, equal> c2( c0.begin(), c0.end() );
+  std::unordered_map<int, int, xhash<int>, equal> c2( c0.begin(), c0.end() );
 
   c1.merge(c2);
   VERIFY( c1 == c0 );
@@ -89,7 +103,7 @@ test03()
 {
   const test_type c0{ {1, 10}, {2, 20}, {3, 30} };
   test_type c1 = c0;
-  std::unordered_multimap<int, int, hash, equal> c2( c0.begin(), c0.end() );
+  std::unordered_multimap<int, int, xhash<int>, equal> c2( c0.begin(), c0.end() );
   c1.merge(c2);
   VERIFY( c1 == c0 );
   VERIFY( equal_elements(c2, c0) );
@@ -125,10 +139,164 @@ test03()
   VERIFY( c2.empty() );
 }
 
+void
+test04()
+{
+  const std::unordered_map<std::string, int> c0
+    { {"one", 10}, {"two", 20}, {"three", 30} };
+
+  std::unordered_map<std::string, int> c1 = c0;
+  std::unordered_multimap<std::string, int> c2( c0.begin(), c0.end() );
+  c1.merge(c2);
+  VERIFY( c1 == c0 );
+  VERIFY( equal_elements(c2, c0) );
+
+  c1.clear();
+  c1.merge(c2);
+  VERIFY( c1 == c0 );
+  VERIFY( c2.empty() );
+
+  c2.merge(c1);
+  VERIFY( c1.empty() );
+  VERIFY( equal_elements(c2, c0) );
+
+  c1 = c0;
+  c2.merge(c1);
+  VERIFY( c1.empty() );
+  VERIFY( c2.size() == (2 * c0.size()) );
+  VERIFY( c2.count("one") == 2 );
+  VERIFY( c2.count("two") == 2 );
+  VERIFY( c2.count("three") == 2 );
+
+  c1.merge(c2);
+  VERIFY( c1 == c0 );
+  VERIFY( equal_elements(c2, c0) );
+
+  c1.merge(std::move(c2));
+  VERIFY( c1 == c0 );
+  VERIFY( equal_elements(c2, c0) );
+
+  c1.clear();
+  c1.merge(std::move(c2));
+  VERIFY( c1 == c0 );
+  VERIFY( c2.empty() );
+}
+
+void
+test05()
+{
+  const std::unordered_map<std::string, int> c0
+    { {"one", 10}, {"two", 20}, {"three", 30} };
+
+  std::unordered_map<std::string, int> c1 = c0;
+  std::unordered_multimap<std::string, int, xhash<std::string>, equal> c2( c0.begin(), c0.end() );
+  c1.merge(c2);
+  VERIFY( c1 == c0 );
+  VERIFY( equal_elements(c2, c0) );
+
+  c1.clear();
+  c1.merge(c2);
+  VERIFY( c1 == c0 );
+  VERIFY( c2.empty() );
+
+  c2.merge(c1);
+  VERIFY( c1.empty() );
+  VERIFY( equal_elements(c2, c0) );
+
+  c1 = c0;
+  c2.merge(c1);
+  VERIFY( c1.empty() );
+  VERIFY( c2.size() == (2 * c0.size()) );
+  VERIFY( c2.count("one") == 2 );
+  VERIFY( c2.count("two") == 2 );
+  VERIFY( c2.count("three") == 2 );
+
+  c1.merge(c2);
+  VERIFY( c1 == c0 );
+  VERIFY( equal_elements(c2, c0) );
+
+  c1.merge(std::move(c2));
+  VERIFY( c1 == c0 );
+  VERIFY( equal_elements(c2, c0) );
+
+  c1.clear();
+  c1.merge(std::move(c2));
+  VERIFY( c1 == c0 );
+  VERIFY( c2.empty() );
+}
+
+template<typename T>
+  using hash_f =
+    std::function<std::size_t(const T&)>;
+
+std::size_t
+hash_func(const std::string& str)
+{ return std::hash<std::string>{}(str);  }
+
+std::size_t
+xhash_func(const std::string& str)
+{ return xhash<std::string>{}(str); }
+
+namespace std
+{
+  template<typename T>
+    struct __is_fast_hash<hash_f<T>> : __is_fast_hash<std::hash<T>>
+    { };
+}
+
+void
+test06()
+{
+  const std::unordered_map<std::string, int, hash_f<std::string>, equal>
+    c0({ {"one", 10}, {"two", 20}, {"three", 30} }, 3, &hash_func);
+
+  std::unordered_map<std::string, int, hash_f<std::string>, equal>
+    c1(3, &hash_func);
+  c1 = c0;
+  std::unordered_multimap<std::string, int, hash_f<std::string>, equal>
+    c2(c0.begin(), c0.end(), 3, &xhash_func);
+  c1.merge(c2);
+  VERIFY( c1 == c0 );
+  VERIFY( equal_elements(c2, c0) );
+
+  c1.clear();
+  c1.merge(c2);
+  VERIFY( c1 == c0 );
+  VERIFY( c2.empty() );
+
+  c2.merge(c1);
+  VERIFY( c1.empty() );
+  VERIFY( equal_elements(c2, c0) );
+
+  c1 = c0;
+  c2.merge(c1);
+  VERIFY( c1.empty() );
+  VERIFY( c2.size() == (2 * c0.size()) );
+  VERIFY( c2.count("one") == 2 );
+  VERIFY( c2.count("two") == 2 );
+  VERIFY( c2.count("three") == 2 );
+
+  c1.merge(c2);
+  VERIFY( c1 == c0 );
+  VERIFY( equal_elements(c2, c0) );
+
+  c1.merge(std::move(c2));
+  VERIFY( c1 == c0 );
+  VERIFY( equal_elements(c2, c0) );
+
+  c1.clear();
+  c1.merge(std::move(c2));
+  VERIFY( c1 == c0 );
+  VERIFY( c2.empty() );
+}
+
 int
 main()
 {
   test01();
   test02();
   test03();
+  test04();
+  test05();
+  test06();
 }

Reply via email to