I thought it would be an optimization to use _M_find_tr(k) != end()
for the unique associative containers, but, as the PR points out, the
heterogeneous version of count() can find multiple matches even in a
unique container (a transparent comparator that merely partitions the
keys can treat several distinct keys as equivalent to the lookup key).
We need to use _M_count_tr(k) to find all matches.

        PR libstdc++/78273
        * include/bits/stl_map.h (map::count<_Kt>(const _Kt&)): Don't assume
        the heterogeneous comparison can only find one match.
        * include/bits/stl_set.h (set::count<_Kt>(const _Kt&)): Likewise.
        * testsuite/23_containers/map/operations/2.cc: Test count works with
        comparison function that just partitions rather than sorting.
        * testsuite/23_containers/set/operations/2.cc: Likewise.

Tested on powerpc64le-linux; committed to trunk.

I'll also backport this to the branches.

commit 8e5f512a435d4ccc7b592a7c4872418943a5c9c7
Author: Jonathan Wakely <jwak...@redhat.com>
Date:   Wed Jan 11 13:49:02 2017 +0000

    PR78273 fix count to work with partitioning function
    
        PR libstdc++/78273
        * include/bits/stl_map.h (map::count<_Kt>(const _Kt&)): Don't assume
        the heterogeneous comparison can only find one match.
        * include/bits/stl_set.h (set::count<_Kt>(const _Kt&)): Likewise.
        * testsuite/23_containers/map/operations/2.cc: Test count works with
        comparison function that just partitions rather than sorting.
        * testsuite/23_containers/set/operations/2.cc: Likewise.

diff --git a/libstdc++-v3/include/bits/stl_map.h b/libstdc++-v3/include/bits/stl_map.h
index f2a0ffa..91b80d9 100644
--- a/libstdc++-v3/include/bits/stl_map.h
+++ b/libstdc++-v3/include/bits/stl_map.h
@@ -1194,7 +1194,7 @@ _GLIBCXX_BEGIN_NAMESPACE_CONTAINER
       template<typename _Kt>
        auto
        count(const _Kt& __x) const -> decltype(_M_t._M_count_tr(__x))
-       { return _M_t._M_find_tr(__x) == _M_t.end() ? 0 : 1; }
+       { return _M_t._M_count_tr(__x); }
 #endif
       //@}
 
diff --git a/libstdc++-v3/include/bits/stl_set.h b/libstdc++-v3/include/bits/stl_set.h
index 66560a7..ab960f1 100644
--- a/libstdc++-v3/include/bits/stl_set.h
+++ b/libstdc++-v3/include/bits/stl_set.h
@@ -739,7 +739,7 @@ _GLIBCXX_BEGIN_NAMESPACE_CONTAINER
        auto
        count(const _Kt& __x) const
        -> decltype(_M_t._M_count_tr(__x))
-       { return _M_t._M_find_tr(__x) == _M_t.end() ? 0 : 1; }
+       { return _M_t._M_count_tr(__x); }
 #endif
       //@}
 
diff --git a/libstdc++-v3/testsuite/23_containers/map/operations/2.cc b/libstdc++-v3/testsuite/23_containers/map/operations/2.cc
index 6509084..ef4e76b 100644
--- a/libstdc++-v3/testsuite/23_containers/map/operations/2.cc
+++ b/libstdc++-v3/testsuite/23_containers/map/operations/2.cc
@@ -133,6 +133,27 @@ test05()
   VERIFY( Cmp::count == 0);
 }
 
+void
+test06()
+{
+  // PR libstdc++/78273
+
+  struct C {
+    bool operator()(int l, int r) const { return l < r; }
+
+    struct Partition { };
+
+    bool operator()(int l, Partition) const { return l < 2; }
+    bool operator()(Partition, int r) const { return 4 < r; }
+
+    using is_transparent = void;
+  };
+
+  std::map<int, int, C> m{ {1,0}, {2,0}, {3,0}, {4, 0}, {5, 0} };
+
+  auto n = m.count(C::Partition{});
+  VERIFY( n == 3 );
+}
 
 int
 main()
@@ -142,4 +163,5 @@ main()
   test03();
   test04();
   test05();
+  test06();
 }
diff --git a/libstdc++-v3/testsuite/23_containers/set/operations/2.cc b/libstdc++-v3/testsuite/23_containers/set/operations/2.cc
index aa71ae5..aef808d 100644
--- a/libstdc++-v3/testsuite/23_containers/set/operations/2.cc
+++ b/libstdc++-v3/testsuite/23_containers/set/operations/2.cc
@@ -150,6 +150,28 @@ test06()
   s.find(i);
 }
 
+void
+test07()
+{
+  // PR libstdc++/78273
+
+  struct C {
+    bool operator()(int l, int r) const { return l < r; }
+
+    struct Partition { };
+
+    bool operator()(int l, Partition) const { return l < 2; }
+    bool operator()(Partition, int r) const { return 4 < r; }
+
+    using is_transparent = void;
+  };
+
+  std::set<int, C> s{ 1, 2, 3, 4, 5 };
+
+  auto n = s.count(C::Partition{});
+  VERIFY( n == 3 );
+}
+
 int
 main()
 {
@@ -159,4 +181,5 @@ main()
   test04();
   test05();
   test06();
+  test07();
 }

Reply via email to