PR libstdc++/100795

libstdc++-v3/ChangeLog:
* include/bits/ranges_algo.h (__sample_fn::operator()): Reimplement the forward_iterator branch directly. * testsuite/25_algorithms/sample/constrained.cc (test02): New test. --- libstdc++-v3/include/bits/ranges_algo.h | 70 +++++++++++++++++-- .../25_algorithms/sample/constrained.cc | 28 ++++++++ 2 files changed, 91 insertions(+), 7 deletions(-) diff --git a/libstdc++-v3/include/bits/ranges_algo.h b/libstdc++-v3/include/bits/ranges_algo.h index b12da2af1263..672a0ebce0de 100644 --- a/libstdc++-v3/include/bits/ranges_algo.h +++ b/libstdc++-v3/include/bits/ranges_algo.h @@ -1839,14 +1839,70 @@ namespace ranges operator()(_Iter __first, _Sent __last, _Out __out, iter_difference_t<_Iter> __n, _Gen&& __g) const { + // FIXME: Correctly handle integer-class difference types. if constexpr (forward_iterator<_Iter>) { - // FIXME: Forwarding to std::sample here requires computing __lasti - // which may take linear time. - auto __lasti = ranges::next(__first, __last); - return _GLIBCXX_STD_A:: - sample(std::move(__first), std::move(__lasti), std::move(__out), - __n, std::forward<_Gen>(__g)); + using _Size = iter_difference_t<_Iter>; + using __distrib_type = uniform_int_distribution<_Size>; + using __param_type = typename __distrib_type::param_type; + using _USize = __detail::__make_unsigned_like_t<_Size>; + using __uc_type + = common_type_t<typename remove_reference_t<_Gen>::result_type, _USize>; + + if (__first == __last) + return __out; + + __distrib_type __d{}; + _Size __unsampled_sz = ranges::distance(__first, __last); + __n = std::min(__n, __unsampled_sz); + + // If possible, we use __gen_two_uniform_ints to efficiently produce + // two random numbers using a single distribution invocation: + + const __uc_type __urngrange = __g.max() - __g.min(); + if (__urngrange / __uc_type(__unsampled_sz) >= __uc_type(__unsampled_sz)) + // I.e. (__urngrange >= __unsampled_sz * __unsampled_sz) but without + // wrapping issues. 
+ { + while (__n != 0 && __unsampled_sz >= 2) + { + const pair<_Size, _Size> __p = + __gen_two_uniform_ints(__unsampled_sz, __unsampled_sz - 1, __g); + + --__unsampled_sz; + if (__p.first < __n) + { + *__out = *__first; + ++__out; + --__n; + } + + ++__first; + + if (__n == 0) break; + + --__unsampled_sz; + if (__p.second < __n) + { + *__out = *__first; + ++__out; + --__n; + } + + ++__first; + } + } + + // The loop above is otherwise equivalent to this one-at-a-time version: + + for (; __n != 0; ++__first) + if (__d(__g, __param_type{0, --__unsampled_sz}) < __n) + { + *__out = *__first; + ++__out; + --__n; + } + return __out; } else { @@ -1867,7 +1923,7 @@ namespace ranges if (__k < __n) __out[__k] = *__first; } - return __out + __sample_sz; + return __out + iter_difference_t<_Out>(__sample_sz); } } diff --git a/libstdc++-v3/testsuite/25_algorithms/sample/constrained.cc b/libstdc++-v3/testsuite/25_algorithms/sample/constrained.cc index b9945b164903..150e2d2036e0 100644 --- a/libstdc++-v3/testsuite/25_algorithms/sample/constrained.cc +++ b/libstdc++-v3/testsuite/25_algorithms/sample/constrained.cc @@ -20,6 +20,7 @@ #include <algorithm> #include <random> +#include <ranges> #include <testsuite_hooks.h> #include <testsuite_iterators.h> @@ -59,9 +60,36 @@ test01() } } +void +test02() +{ + // PR libstdc++/100795 - ranges::sample should not use std::sample +#if 0 // FIXME: ranges::sample rejects integer-class difference types. 
+#if __SIZEOF_INT128__ + auto v = std::views::iota(__int128(0), __int128(20)); +#else + auto v = std::views::iota(0ll, 20ll); +#endif +#else + auto v = std::views::iota(0, 20); +#endif + + int storage[20] = {2,5,4,3,1,6,7,9,10,8,11,14,12,13,15,16,18,0,19,17}; + auto w = v | std::views::transform([&](auto i) -> int& { return storage[i]; }); + using type = decltype(w); + using cat = std::iterator_traits<std::ranges::iterator_t<type>>::iterator_category; + static_assert( std::same_as<cat, std::output_iterator_tag> ); + static_assert( std::ranges::random_access_range<type> ); + + ranges::sample(v, w.begin(), 20, rng); + ranges::sort(w); + VERIFY( ranges::equal(w, v) ); +} + int main() { test01<forward_iterator_wrapper, output_iterator_wrapper>(); test01<input_iterator_wrapper, random_access_iterator_wrapper>(); + test02(); } -- 2.50.0.131.gcf6f63ea6b