balanced_quicksort.h

Go to the documentation of this file.
00001 // -*- C++ -*-
00002 
00003 // Copyright (C) 2007, 2008, 2009 Free Software Foundation, Inc.
00004 //
00005 // This file is part of the GNU ISO C++ Library.  This library is free
00006 // software; you can redistribute it and/or modify it under the terms
00007 // of the GNU General Public License as published by the Free Software
00008 // Foundation; either version 3, or (at your option) any later
00009 // version.
00010 
00011 // This library is distributed in the hope that it will be useful, but
00012 // WITHOUT ANY WARRANTY; without even the implied warranty of
00013 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
00014 // General Public License for more details.
00015 
00016 // Under Section 7 of GPL version 3, you are granted additional
00017 // permissions described in the GCC Runtime Library Exception, version
00018 // 3.1, as published by the Free Software Foundation.
00019 
00020 // You should have received a copy of the GNU General Public License and
00021 // a copy of the GCC Runtime Library Exception along with this program;
00022 // see the files COPYING3 and COPYING.RUNTIME respectively.  If not, see
00023 // <http://www.gnu.org/licenses/>.
00024 
00025 /** @file parallel/balanced_quicksort.h
00026  *  @brief Implementation of a dynamically load-balanced parallel quicksort.
00027  *
00028  *  It works in-place and needs only logarithmic extra memory.
00029  *  The algorithm is similar to the one proposed in
00030  *
00031  *  P. Tsigas and Y. Zhang.
00032  *  A simple, fast parallel implementation of quicksort and
00033  *  its performance evaluation on SUN enterprise 10000.
00034  *  In 11th Euromicro Conference on Parallel, Distributed and
00035  *  Network-Based Processing, page 372, 2003.
00036  *
00037  *  This file is a GNU parallel extension to the Standard C++ Library.
00038  */
00039 
00040 // Written by Johannes Singler.
00041 
00042 #ifndef _GLIBCXX_PARALLEL_BALANCED_QUICKSORT_H
00043 #define _GLIBCXX_PARALLEL_BALANCED_QUICKSORT_H 1
00044 
00045 #include <parallel/basic_iterator.h>
00046 #include <bits/stl_algo.h>
00047 
00048 #include <parallel/settings.h>
00049 #include <parallel/partition.h>
00050 #include <parallel/random_number.h>
00051 #include <parallel/queue.h>
00052 #include <functional>
00053 
00054 #if _GLIBCXX_ASSERTIONS
00055 #include <parallel/checkers.h>
00056 #endif
00057 
00058 namespace __gnu_parallel
00059 {
00060 /** @brief Information local to one thread in the parallel quicksort run. */
00061 template<typename RandomAccessIterator>
00062   struct QSBThreadLocal
00063   {
00064     typedef std::iterator_traits<RandomAccessIterator> traits_type;
00065     typedef typename traits_type::difference_type difference_type;
00066 
00067     /** @brief Continuous part of the sequence, described by an
00068     iterator pair. */
00069     typedef std::pair<RandomAccessIterator, RandomAccessIterator> Piece;
00070 
00071     /** @brief Initial piece to work on. */
00072     Piece initial;
00073 
00074     /** @brief Work-stealing queue. */
00075     RestrictedBoundedConcurrentQueue<Piece> leftover_parts;
00076 
00077     /** @brief Number of threads involved in this algorithm. */
00078     thread_index_t num_threads;
00079 
00080     /** @brief Pointer to a counter of elements left over to sort. */
00081     volatile difference_type* elements_leftover;
00082 
00083     /** @brief The complete sequence to sort. */
00084     Piece global;
00085 
00086     /** @brief Constructor.
00087      *  @param queue_size Size of the work-stealing queue. */
00088     QSBThreadLocal(int queue_size) : leftover_parts(queue_size) { }
00089   };
00090 
00091 /** @brief Balanced quicksort divide step.
00092   *  @param begin Begin iterator of subsequence.
00093   *  @param end End iterator of subsequence.
00094   *  @param comp Comparator.
00095   *  @param num_threads Number of threads that are allowed to work on
00096   *  this part.
00097   *  @pre @c (end-begin)>=1 */
00098 template<typename RandomAccessIterator, typename Comparator>
00099   typename std::iterator_traits<RandomAccessIterator>::difference_type
00100   qsb_divide(RandomAccessIterator begin, RandomAccessIterator end,
00101              Comparator comp, thread_index_t num_threads)
00102   {
00103     _GLIBCXX_PARALLEL_ASSERT(num_threads > 0);
00104 
00105     typedef std::iterator_traits<RandomAccessIterator> traits_type;
00106     typedef typename traits_type::value_type value_type;
00107     typedef typename traits_type::difference_type difference_type;
00108 
00109     RandomAccessIterator pivot_pos =
00110       median_of_three_iterators(begin, begin + (end - begin) / 2,
00111                 end  - 1, comp);
00112 
00113 #if defined(_GLIBCXX_ASSERTIONS)
00114     // Must be in between somewhere.
00115     difference_type n = end - begin;
00116 
00117     _GLIBCXX_PARALLEL_ASSERT(
00118            (!comp(*pivot_pos, *begin) && !comp(*(begin + n / 2), *pivot_pos))
00119         || (!comp(*pivot_pos, *begin) && !comp(*(end - 1), *pivot_pos))
00120         || (!comp(*pivot_pos, *(begin + n / 2)) && !comp(*begin, *pivot_pos))
00121         || (!comp(*pivot_pos, *(begin + n / 2)) && !comp(*(end - 1), *pivot_pos))
00122         || (!comp(*pivot_pos, *(end - 1)) && !comp(*begin, *pivot_pos))
00123         || (!comp(*pivot_pos, *(end - 1)) && !comp(*(begin + n / 2), *pivot_pos)));
00124 #endif
00125 
00126     // Swap pivot value to end.
00127     if (pivot_pos != (end - 1))
00128       std::swap(*pivot_pos, *(end - 1));
00129     pivot_pos = end - 1;
00130 
00131     __gnu_parallel::binder2nd<Comparator, value_type, value_type, bool>
00132         pred(comp, *pivot_pos);
00133 
00134     // Divide, returning end - begin - 1 in the worst case.
00135     difference_type split_pos = parallel_partition(
00136         begin, end - 1, pred, num_threads);
00137 
00138     // Swap back pivot to middle.
00139     std::swap(*(begin + split_pos), *pivot_pos);
00140     pivot_pos = begin + split_pos;
00141 
00142 #if _GLIBCXX_ASSERTIONS
00143     RandomAccessIterator r;
00144     for (r = begin; r != pivot_pos; ++r)
00145       _GLIBCXX_PARALLEL_ASSERT(comp(*r, *pivot_pos));
00146     for (; r != end; ++r)
00147       _GLIBCXX_PARALLEL_ASSERT(!comp(*r, *pivot_pos));
00148 #endif
00149 
00150     return split_pos;
00151   }
00152 
00153 /** @brief Quicksort conquer step.
00154   *  @param tls Array of thread-local storages.
00155   *  @param begin Begin iterator of subsequence.
00156   *  @param end End iterator of subsequence.
00157   *  @param comp Comparator.
00158   *  @param iam Number of the thread processing this function.
00159   *  @param num_threads
00160   *          Number of threads that are allowed to work on this part. */
00161 template<typename RandomAccessIterator, typename Comparator>
00162   void
00163   qsb_conquer(QSBThreadLocal<RandomAccessIterator>** tls,
00164               RandomAccessIterator begin, RandomAccessIterator end,
00165               Comparator comp,
00166               thread_index_t iam, thread_index_t num_threads,
00167               bool parent_wait)
00168   {
00169     typedef std::iterator_traits<RandomAccessIterator> traits_type;
00170     typedef typename traits_type::value_type value_type;
00171     typedef typename traits_type::difference_type difference_type;
00172 
00173     difference_type n = end - begin;
00174 
00175     if (num_threads <= 1 || n <= 1)
00176       {
00177         tls[iam]->initial.first  = begin;
00178         tls[iam]->initial.second = end;
00179 
00180         qsb_local_sort_with_helping(tls, comp, iam, parent_wait);
00181 
00182         return;
00183       }
00184 
00185     // Divide step.
00186     difference_type split_pos = qsb_divide(begin, end, comp, num_threads);
00187 
00188 #if _GLIBCXX_ASSERTIONS
00189     _GLIBCXX_PARALLEL_ASSERT(0 <= split_pos && split_pos < (end - begin));
00190 #endif
00191 
00192     thread_index_t num_threads_leftside =
00193         std::max<thread_index_t>(1, std::min<thread_index_t>(
00194                           num_threads - 1, split_pos * num_threads / n));
00195 
00196 #   pragma omp atomic
00197     *tls[iam]->elements_leftover -= (difference_type)1;
00198 
00199     // Conquer step.
00200 #   pragma omp parallel num_threads(2)
00201     {
00202       bool wait;
00203       if(omp_get_num_threads() < 2)
00204         wait = false;
00205       else
00206         wait = parent_wait;
00207 
00208 #     pragma omp sections
00209         {
00210 #         pragma omp section
00211             {
00212               qsb_conquer(tls, begin, begin + split_pos, comp,
00213                           iam,
00214                           num_threads_leftside,
00215                           wait);
00216               wait = parent_wait;
00217             }
00218           // The pivot_pos is left in place, to ensure termination.
00219 #         pragma omp section
00220             {
00221               qsb_conquer(tls, begin + split_pos + 1, end, comp,
00222                           iam + num_threads_leftside,
00223                           num_threads - num_threads_leftside,
00224                           wait);
00225               wait = parent_wait;
00226             }
00227         }
00228     }
00229   }
00230 
00231 /**
00232   *  @brief Quicksort step doing load-balanced local sort.
00233   *  @param tls Array of thread-local storages.
00234   *  @param comp Comparator.
00235   *  @param iam Number of the thread processing this function.
00236   */
00237 template<typename RandomAccessIterator, typename Comparator>
00238   void
00239   qsb_local_sort_with_helping(QSBThreadLocal<RandomAccessIterator>** tls,
00240                               Comparator& comp, int iam, bool wait)
00241   {
00242     typedef std::iterator_traits<RandomAccessIterator> traits_type;
00243     typedef typename traits_type::value_type value_type;
00244     typedef typename traits_type::difference_type difference_type;
00245     typedef std::pair<RandomAccessIterator, RandomAccessIterator> Piece;
00246 
00247     QSBThreadLocal<RandomAccessIterator>& tl = *tls[iam];
00248 
00249     difference_type base_case_n =
00250         _Settings::get().sort_qsb_base_case_maximal_n;
00251     if (base_case_n < 2)
00252       base_case_n = 2;
00253     thread_index_t num_threads = tl.num_threads;
00254 
00255     // Every thread has its own random number generator.
00256     random_number rng(iam + 1);
00257 
00258     Piece current = tl.initial;
00259 
00260     difference_type elements_done = 0;
00261 #if _GLIBCXX_ASSERTIONS
00262     difference_type total_elements_done = 0;
00263 #endif
00264 
00265     for (;;)
00266       {
00267         // Invariant: current must be a valid (maybe empty) range.
00268         RandomAccessIterator begin = current.first, end = current.second;
00269         difference_type n = end - begin;
00270 
00271         if (n > base_case_n)
00272           {
00273             // Divide.
00274             RandomAccessIterator pivot_pos = begin +  rng(n);
00275 
00276             // Swap pivot_pos value to end.
00277             if (pivot_pos != (end - 1))
00278               std::swap(*pivot_pos, *(end - 1));
00279             pivot_pos = end - 1;
00280 
00281             __gnu_parallel::binder2nd
00282                 <Comparator, value_type, value_type, bool>
00283                 pred(comp, *pivot_pos);
00284 
00285             // Divide, leave pivot unchanged in last place.
00286             RandomAccessIterator split_pos1, split_pos2;
00287             split_pos1 = __gnu_sequential::partition(begin, end - 1, pred);
00288 
00289             // Left side: < pivot_pos; right side: >= pivot_pos.
00290 #if _GLIBCXX_ASSERTIONS
00291             _GLIBCXX_PARALLEL_ASSERT(begin <= split_pos1 && split_pos1 < end);
00292 #endif
00293             // Swap pivot back to middle.
00294             if (split_pos1 != pivot_pos)
00295               std::swap(*split_pos1, *pivot_pos);
00296             pivot_pos = split_pos1;
00297 
00298             // In case all elements are equal, split_pos1 == 0.
00299             if ((split_pos1 + 1 - begin) < (n >> 7)
00300             || (end - split_pos1) < (n >> 7))
00301               {
00302                 // Very unequal split, one part smaller than one 128th
00303                 // elements not strictly larger than the pivot.
00304                 __gnu_parallel::unary_negate<__gnu_parallel::binder1st
00305           <Comparator, value_type, value_type, bool>, value_type>
00306           pred(__gnu_parallel::binder1st
00307                <Comparator, value_type, value_type, bool>(comp,
00308                                   *pivot_pos));
00309 
00310                 // Find other end of pivot-equal range.
00311                 split_pos2 = __gnu_sequential::partition(split_pos1 + 1,
00312                              end, pred);
00313               }
00314             else
00315               // Only skip the pivot.
00316               split_pos2 = split_pos1 + 1;
00317 
00318             // Elements equal to pivot are done.
00319             elements_done += (split_pos2 - split_pos1);
00320 #if _GLIBCXX_ASSERTIONS
00321             total_elements_done += (split_pos2 - split_pos1);
00322 #endif
00323             // Always push larger part onto stack.
00324             if (((split_pos1 + 1) - begin) < (end - (split_pos2)))
00325               {
00326                 // Right side larger.
00327                 if ((split_pos2) != end)
00328                   tl.leftover_parts.push_front(std::make_pair(split_pos2,
00329                                   end));
00330 
00331                 //current.first = begin;    //already set anyway
00332                 current.second = split_pos1;
00333                 continue;
00334               }
00335             else
00336               {
00337                 // Left side larger.
00338                 if (begin != split_pos1)
00339                   tl.leftover_parts.push_front(std::make_pair(begin,
00340                                   split_pos1));
00341 
00342                 current.first = split_pos2;
00343                 //current.second = end; //already set anyway
00344                 continue;
00345               }
00346           }
00347         else
00348           {
00349             __gnu_sequential::sort(begin, end, comp);
00350             elements_done += n;
00351 #if _GLIBCXX_ASSERTIONS
00352             total_elements_done += n;
00353 #endif
00354 
00355             // Prefer own stack, small pieces.
00356             if (tl.leftover_parts.pop_front(current))
00357               continue;
00358 
00359 #           pragma omp atomic
00360             *tl.elements_leftover -= elements_done;
00361 
00362             elements_done = 0;
00363 
00364 #if _GLIBCXX_ASSERTIONS
00365             double search_start = omp_get_wtime();
00366 #endif
00367 
00368             // Look for new work.
00369             bool successfully_stolen = false;
00370             while (wait && *tl.elements_leftover > 0 && !successfully_stolen
00371 #if _GLIBCXX_ASSERTIONS
00372               // Possible dead-lock.
00373               && (omp_get_wtime() < (search_start + 1.0))
00374 #endif
00375               )
00376               {
00377                 thread_index_t victim;
00378                 victim = rng(num_threads);
00379 
00380                 // Large pieces.
00381                 successfully_stolen = (victim != iam)
00382                     && tls[victim]->leftover_parts.pop_back(current);
00383                 if (!successfully_stolen)
00384                   yield();
00385 #if !defined(__ICC) && !defined(__ECC)
00386 #               pragma omp flush
00387 #endif
00388               }
00389 
00390 #if _GLIBCXX_ASSERTIONS
00391             if (omp_get_wtime() >= (search_start + 1.0))
00392               {
00393                 sleep(1);
00394                 _GLIBCXX_PARALLEL_ASSERT(omp_get_wtime()
00395                      < (search_start + 1.0));
00396               }
00397 #endif
00398             if (!successfully_stolen)
00399               {
00400 #if _GLIBCXX_ASSERTIONS
00401                 _GLIBCXX_PARALLEL_ASSERT(*tl.elements_leftover == 0);
00402 #endif
00403                 return;
00404               }
00405           }
00406       }
00407   }
00408 
00409 /** @brief Top-level quicksort routine.
00410   *  @param begin Begin iterator of sequence.
00411   *  @param end End iterator of sequence.
00412   *  @param comp Comparator.
00413   *  @param num_threads Number of threads that are allowed to work on
00414   *  this part.
00415   */
00416 template<typename RandomAccessIterator, typename Comparator>
00417   void
00418   parallel_sort_qsb(RandomAccessIterator begin, RandomAccessIterator end,
00419                     Comparator comp,
00420                     thread_index_t num_threads)
00421   {
00422     _GLIBCXX_CALL(end - begin)
00423 
00424     typedef std::iterator_traits<RandomAccessIterator> traits_type;
00425     typedef typename traits_type::value_type value_type;
00426     typedef typename traits_type::difference_type difference_type;
00427     typedef std::pair<RandomAccessIterator, RandomAccessIterator> Piece;
00428 
00429     typedef QSBThreadLocal<RandomAccessIterator> tls_type;
00430 
00431     difference_type n = end - begin;
00432 
00433     if (n <= 1)
00434       return;
00435 
00436     // At least one element per processor.
00437     if (num_threads > n)
00438       num_threads = static_cast<thread_index_t>(n);
00439 
00440     // Initialize thread local storage
00441     tls_type** tls = new tls_type*[num_threads];
00442     difference_type queue_size = num_threads * (thread_index_t)(log2(n) + 1);
00443     for (thread_index_t t = 0; t < num_threads; ++t)
00444       tls[t] = new QSBThreadLocal<RandomAccessIterator>(queue_size);
00445 
00446     // There can never be more than ceil(log2(n)) ranges on the stack, because
00447     // 1. Only one processor pushes onto the stack
00448     // 2. The largest range has at most length n
00449     // 3. Each range is larger than half of the range remaining
00450     volatile difference_type elements_leftover = n;
00451     for (int i = 0; i < num_threads; ++i)
00452       {
00453         tls[i]->elements_leftover = &elements_leftover;
00454         tls[i]->num_threads = num_threads;
00455         tls[i]->global = std::make_pair(begin, end);
00456 
00457         // Just in case nothing is left to assign.
00458         tls[i]->initial = std::make_pair(end, end);
00459       }
00460 
00461     // Main recursion call.
00462     qsb_conquer(tls, begin, begin + n, comp, 0, num_threads, true);
00463 
00464 #if _GLIBCXX_ASSERTIONS
00465     // All stack must be empty.
00466     Piece dummy;
00467     for (int i = 1; i < num_threads; ++i)
00468       _GLIBCXX_PARALLEL_ASSERT(!tls[i]->leftover_parts.pop_back(dummy));
00469 #endif
00470 
00471     for (int i = 0; i < num_threads; ++i)
00472       delete tls[i];
00473     delete[] tls;
00474   }
00475 } // namespace __gnu_parallel
00476 
00477 #endif /* _GLIBCXX_PARALLEL_BALANCED_QUICKSORT_H */

Generated on 19 Jun 2018 for libstdc++ by  doxygen 1.6.1