libstdc++: [_Hashtable] Do not reuse untrusted cached hash code
On merge, reuse the merged node's cached hash code only if both containers
use the same type of hash and this hash is stateless. Using a function
pointer or std::function as the hash functor will prevent this optimization.
libstdc++-v3/ChangeLog
* include/bits/hashtable_policy.h
(_Hash_code_base::_M_hash_code(const _Hash&, const
_Hash_node_value<>&)): Remove.
(_Hash_code_base::_M_hash_code<_H2>(const _H2&, const
_Hash_node_value<>&)): Remove.
* include/bits/hashtable.h
(_M_src_hash_code<_H2>(const _H2&, const key_type&, const
__node_value_type&)): New.
(_M_merge_unique<>, _M_merge_multi<>): Use latter.
* testsuite/23_containers/unordered_map/modifiers/merge.cc
(test04, test05, test06): New test cases.
Tested under Linux x86_64. OK to commit?
François
diff --git a/libstdc++-v3/include/bits/hashtable.h b/libstdc++-v3/include/bits/hashtable.h
index 4c12dc895b2..f69acfe5213 100644
--- a/libstdc++-v3/include/bits/hashtable.h
+++ b/libstdc++-v3/include/bits/hashtable.h
@@ -1109,6 +1109,20 @@ _GLIBCXX_BEGIN_NAMESPACE_VERSION
return { __n, this->_M_node_allocator() };
}
+ // Return the hash code to use for __k in this container. Any code cached in
+ // __src_n was computed by _H2, so reuse it only if _H2 is a stateless _Hash.
+ template<typename _H2>
+ __hash_code
+ _M_src_hash_code(const _H2&, const key_type& __k,
+ const __node_value_type& __src_n) const
+ {
+ if constexpr (std::is_same_v<_H2, _Hash>)
+ if constexpr (std::is_empty_v<_Hash>)
+ return this->_M_hash_code(__src_n);  // same empty hash type: node overload
+
+ return this->_M_hash_code(__k);  // otherwise compute the code from the key
+ }
+
public:
// Extract a node.
node_type
@@ -1146,7 +1160,7 @@ _GLIBCXX_BEGIN_NAMESPACE_VERSION
auto __pos = __i++;
const key_type& __k = _ExtractKey{}(*__pos);
__hash_code __code
- = this->_M_hash_code(__src.hash_function(), *__pos._M_cur);
+ = _M_src_hash_code(__src.hash_function(), __k, *__pos._M_cur);
size_type __bkt = _M_bucket_index(__code);
if (_M_find_node(__bkt, __k, __code) == nullptr)
{
@@ -1174,8 +1188,9 @@ _GLIBCXX_BEGIN_NAMESPACE_VERSION
for (auto __i = __src.cbegin(), __end = __src.cend(); __i != __end;)
{
auto __pos = __i++;
+ const key_type& __k = _ExtractKey{}(*__pos);
__hash_code __code
- = this->_M_hash_code(__src.hash_function(), *__pos._M_cur);
+ = _M_src_hash_code(__src.hash_function(), __k, *__pos._M_cur);
auto __nh = __src.extract(__pos);
__hint = _M_insert_multi_node(__hint, __code, __nh._M_ptr)._M_cur;
__nh._M_ptr = nullptr;
diff --git a/libstdc++-v3/include/bits/hashtable_policy.h b/libstdc++-v3/include/bits/hashtable_policy.h
index 86b32fb15f2..5d162463dc3 100644
--- a/libstdc++-v3/include/bits/hashtable_policy.h
+++ b/libstdc++-v3/include/bits/hashtable_policy.h
@@ -1319,19 +1319,6 @@ namespace __detail
return _M_hash()(__k);
}
- __hash_code
- _M_hash_code(const _Hash&,
- const _Hash_node_value<_Value, true>& __n) const
- { return __n._M_hash_code; }
-
- // Compute hash code using _Hash as __n _M_hash_code, if present, was
- // computed using _H2.
- template<typename _H2>
- __hash_code
- _M_hash_code(const _H2&,
- const _Hash_node_value<_Value, __cache_hash_code>& __n) const
- { return _M_hash_code(_ExtractKey{}(__n._M_v())); }
-
__hash_code
_M_hash_code(const _Hash_node_value<_Value, false>& __n) const
{ return _M_hash_code(_ExtractKey{}(__n._M_v())); }
diff --git a/libstdc++-v3/testsuite/23_containers/unordered_map/modifiers/merge.cc b/libstdc++-v3/testsuite/23_containers/unordered_map/modifiers/merge.cc
index b140ce452aa..c051b58137a 100644
--- a/libstdc++-v3/testsuite/23_containers/unordered_map/modifiers/merge.cc
+++ b/libstdc++-v3/testsuite/23_containers/unordered_map/modifiers/merge.cc
@@ -17,15 +17,29 @@
// { dg-do run { target c++17 } }
+#include <string>
+#include <functional>
#include <unordered_map>
#include <algorithm>
#include <testsuite_hooks.h>
using test_type = std::unordered_map<int, int>;
-struct hash {
- auto operator()(int i) const noexcept { return ~std::hash<int>()(i); }
-};
+template<typename T>
+ struct xhash  // hash functor with no data members, distinct from std::hash<T>
+ {
+ auto operator()(const T& i) const noexcept
+ { return ~std::hash<T>()(i); }  // bitwise complement of std::hash<T>'s result
+ };
+
+
+namespace std
+{
+ template<typename T>
+ struct __is_fast_hash<xhash<T>> : __is_fast_hash<std::hash<T>>
+ { };  // xhash<T> inherits the __is_fast_hash setting of std::hash<T>
+}
+
struct equal : std::equal_to<> { };
template<typename C1, typename C2>
@@ -64,7 +78,7 @@ test02()
{
const test_type c0{ {1, 10}, {2, 20}, {3, 30} };
test_type c1 = c0;
- std::unordered_map<int, int, hash, equal> c2( c0.begin(), c0.end() );
+ std::unordered_map<int, int, xhash<int>, equal> c2( c0.begin(), c0.end() );
c1.merge(c2);
VERIFY( c1 == c0 );
@@ -89,7 +103,7 @@ test03()
{
const test_type c0{ {1, 10}, {2, 20}, {3, 30} };
test_type c1 = c0;
- std::unordered_multimap<int, int, hash, equal> c2( c0.begin(), c0.end() );
+ std::unordered_multimap<int, int, xhash<int>, equal> c2( c0.begin(), c0.end() );
c1.merge(c2);
VERIFY( c1 == c0 );
VERIFY( equal_elements(c2, c0) );
@@ -125,10 +139,164 @@ test03()
VERIFY( c2.empty() );
}
+void
+test04()  // std::string keys; map and multimap both use the default hash type
+{
+ const std::unordered_map<std::string, int> c0
+ { {"one", 10}, {"two", 20}, {"three", 30} };
+
+ std::unordered_map<std::string, int> c1 = c0;
+ std::unordered_multimap<std::string, int> c2( c0.begin(), c0.end() );
+ c1.merge(c2);  // every key is already in c1: expect no node to move
+ VERIFY( c1 == c0 );
+ VERIFY( equal_elements(c2, c0) );
+
+ c1.clear();
+ c1.merge(c2);  // c1 is empty: expect all nodes to move out of c2
+ VERIFY( c1 == c0 );
+ VERIFY( c2.empty() );
+
+ c2.merge(c1);  // merge back from the unique map into the multimap
+ VERIFY( c1.empty() );
+ VERIFY( equal_elements(c2, c0) );
+
+ c1 = c0;
+ c2.merge(c1);  // multimap accepts duplicate keys: expect two of each
+ VERIFY( c1.empty() );
+ VERIFY( c2.size() == (2 * c0.size()) );
+ VERIFY( c2.count("one") == 2 );
+ VERIFY( c2.count("two") == 2 );
+ VERIFY( c2.count("three") == 2 );
+
+ c1.merge(c2);  // expect one node per key to move, leaving one behind in c2
+ VERIFY( c1 == c0 );
+ VERIFY( equal_elements(c2, c0) );
+
+ c1.merge(std::move(c2));  // rvalue source; all keys present, nothing moves
+ VERIFY( c1 == c0 );
+ VERIFY( equal_elements(c2, c0) );
+
+ c1.clear();
+ c1.merge(std::move(c2));  // rvalue source into an empty map: all nodes move
+ VERIFY( c1 == c0 );
+ VERIFY( c2.empty() );
+}
+
+void
+test05()  // same scenario, but c2 hashes with xhash<std::string>, not std::hash
+{
+ const std::unordered_map<std::string, int> c0
+ { {"one", 10}, {"two", 20}, {"three", 30} };
+
+ std::unordered_map<std::string, int> c1 = c0;
+ std::unordered_multimap<std::string, int, xhash<std::string>, equal> c2( c0.begin(), c0.end() );
+ c1.merge(c2);  // every key is already in c1: expect no node to move
+ VERIFY( c1 == c0 );
+ VERIFY( equal_elements(c2, c0) );
+
+ c1.clear();
+ c1.merge(c2);  // c1 is empty: expect all nodes to move out of c2
+ VERIFY( c1 == c0 );
+ VERIFY( c2.empty() );
+
+ c2.merge(c1);  // merge back from the unique map into the multimap
+ VERIFY( c1.empty() );
+ VERIFY( equal_elements(c2, c0) );
+
+ c1 = c0;
+ c2.merge(c1);  // multimap accepts duplicate keys: expect two of each
+ VERIFY( c1.empty() );
+ VERIFY( c2.size() == (2 * c0.size()) );
+ VERIFY( c2.count("one") == 2 );
+ VERIFY( c2.count("two") == 2 );
+ VERIFY( c2.count("three") == 2 );
+
+ c1.merge(c2);  // expect one node per key to move, leaving one behind in c2
+ VERIFY( c1 == c0 );
+ VERIFY( equal_elements(c2, c0) );
+
+ c1.merge(std::move(c2));  // rvalue source; all keys present, nothing moves
+ VERIFY( c1 == c0 );
+ VERIFY( equal_elements(c2, c0) );
+
+ c1.clear();
+ c1.merge(std::move(c2));  // rvalue source into an empty map: all nodes move
+ VERIFY( c1 == c0 );
+ VERIFY( c2.empty() );
+}
+
+template<typename T>
+ using hash_f =
+ std::function<std::size_t(const T&)>;  // type-erased hash callable on const T&
+
+std::size_t
+hash_func(const std::string& str)
+{ return std::hash<std::string>{}(str); }  // plain std::hash of the string
+
+std::size_t
+xhash_func(const std::string& str)
+{ return xhash<std::string>{}(str); }  // hashes through xhash<std::string>
+
+namespace std
+{
+ template<typename T>
+ struct __is_fast_hash<hash_f<T>> : __is_fast_hash<std::hash<T>>
+ { };  // hash_f<T> inherits the __is_fast_hash setting of std::hash<T>
+}
+
+void
+test06()  // one hash type (hash_f) everywhere, but different targets per container
+{
+ const std::unordered_map<std::string, int, hash_f<std::string>, equal>
+ c0({ {"one", 10}, {"two", 20}, {"three", 30} }, 3, &hash_func);
+
+ std::unordered_map<std::string, int, hash_f<std::string>, equal>
+ c1(3, &hash_func);
+ c1 = c0;
+ std::unordered_multimap<std::string, int, hash_f<std::string>, equal>
+ c2(c0.begin(), c0.end(), 3, &xhash_func);  // c2 hashes via xhash_func, c1 via hash_func
+ c1.merge(c2);  // every key is already in c1: expect no node to move
+ VERIFY( c1 == c0 );
+ VERIFY( equal_elements(c2, c0) );
+
+ c1.clear();
+ c1.merge(c2);  // c1 is empty: expect all nodes to move out of c2
+ VERIFY( c1 == c0 );
+ VERIFY( c2.empty() );
+
+ c2.merge(c1);  // merge back from the unique map into the multimap
+ VERIFY( c1.empty() );
+ VERIFY( equal_elements(c2, c0) );
+
+ c1 = c0;
+ c2.merge(c1);  // multimap accepts duplicate keys: expect two of each
+ VERIFY( c1.empty() );
+ VERIFY( c2.size() == (2 * c0.size()) );
+ VERIFY( c2.count("one") == 2 );
+ VERIFY( c2.count("two") == 2 );
+ VERIFY( c2.count("three") == 2 );
+
+ c1.merge(c2);  // expect one node per key to move, leaving one behind in c2
+ VERIFY( c1 == c0 );
+ VERIFY( equal_elements(c2, c0) );
+
+ c1.merge(std::move(c2));  // rvalue source; all keys present, nothing moves
+ VERIFY( c1 == c0 );
+ VERIFY( equal_elements(c2, c0) );
+
+ c1.clear();
+ c1.merge(std::move(c2));  // rvalue source into an empty map: all nodes move
+ VERIFY( c1 == c0 );
+ VERIFY( c2.empty() );
+}
+
int
main()
{
test01();
test02();
test03();
+ test04();  // std::string keys, same default hash type on both sides
+ test05();  // std::string keys, std::hash versus xhash
+ test06();  // std::function hashes wrapping different functions
}