CyclicBarrier example: a parallel sort algorithm (ctd)

Having seen the general CyclicBarrier pattern for performing a task such as a parallel sort, we'll now fill in the details and actually implement the sort over the next couple of pages.

To implement the sort, I'd suggest first creating a "sorter" object that outside callers create and call into. The sorter object will hold the current state of the sort, including the intermediary "buckets", and parameters such as the number of threads. If the purpose of some of the other variables isn't clear yet, it hopefully will be by the end of this explanation.

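// Imports for the whole class, including the methods added further down
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;
import java.util.RandomAccess;
import java.util.concurrent.BrokenBarrierException;
import java.util.concurrent.CyclicBarrier;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReadWriteLock;

public class ParallelSorter<E extends Comparable<E>> {
  private final int noThreads =
    Runtime.getRuntime().availableProcessors();
  private final int noSamplesPerThread = 16;
  // Seed source for the per-thread Random instances
  private final AtomicLong randSeed =
    new AtomicLong(System.nanoTime());
  // Incremented by sortStageComplete() after each stage
  private volatile int stageNo = 0;
  private final int dataSize;
  private final List<E> data;
  private final List<E> splitPoints =
    new ArrayList<E>(noSamplesPerThread * noThreads);
  // One bucket per thread, indexed on thread number
  private final List<List<E>> bucketsToSort;
  private final ReadWriteLock dataLock;

  private final CyclicBarrier barrier =
    new CyclicBarrier(noThreads + 1, new Runnable() {
      public void run() {
        sortStageComplete();
      }
    });

  public ParallelSorter(List<E> data, ReadWriteLock dataLock) {
    if (!(data instanceof RandomAccess))
      throw new IllegalArgumentException("List must be random access");
    this.data = data;
    this.dataLock = dataLock;
    this.dataSize = data.size();
    List<List<E>> tempList = new ArrayList<List<E>>(noThreads);
    for (int i = 0; i < noThreads; i++) {
      tempList.add(new ArrayList<E>(dataSize / noThreads));
    }
    bucketsToSort = Collections.unmodifiableList(tempList);
  }
}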

Safe data access from multiple threads

An important element to note is the ReadWriteLock passed to the constructor, which we'll use to guard the data list. Unless we want to take a copy of the list (bearing in mind that the whole point of a parallel sort is that the list could be quite large), we need some way of allowing multiple threads, one of which is the calling thread, to access it safely. Using a ReadWriteLock (as opposed to just any old Lock) gives us an advantage because there are several places where we need concurrent reads and we're not expecting concurrent writes. (A regular synchronized block, for example, would force the threads to read one at a time, even though serialising the reads is unnecessary.)
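To illustrate the point about concurrent reads, here is a minimal sketch, separate from the sorter itself. The class and method names (ReadLockSketch, readersOverlap) are made up for this example, and it assumes a ReentrantReadWriteLock as the lock implementation. Each reader takes the read lock and then waits, still holding it, until all the other readers have got in as well. With a mutually exclusive lock the readers could never all be inside at once, so the method would time out and return false.

```java
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReadWriteLock;
import java.util.concurrent.locks.ReentrantReadWriteLock;

// Hypothetical demo class, not part of ParallelSorter
class ReadLockSketch {
  // Returns true if all 'readers' threads held the read lock at the same time
  static boolean readersOverlap(int readers) {
    final ReadWriteLock rw = new ReentrantReadWriteLock();
    final CountDownLatch allInside = new CountDownLatch(readers);
    final CountDownLatch done = new CountDownLatch(readers);
    for (int i = 0; i < readers; i++) {
      new Thread(new Runnable() {
        public void run() {
          Lock l = rw.readLock();
          l.lock();
          try {
            allInside.countDown();
            // Wait, still holding the read lock, for every other reader to get in
            if (allInside.await(5, TimeUnit.SECONDS)) {
              done.countDown();
            }
          } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
          } finally {
            l.unlock();
          }
        }
      }).start();
    }
    try {
      // Only reaches zero if every reader saw all the others inside the lock
      return done.await(10, TimeUnit.SECONDS);
    } catch (InterruptedException e) {
      Thread.currentThread().interrupt();
      return false;
    }
  }
}
```

Calling ReadLockSketch.readersOverlap(4) should return true, since read locks do not exclude one another.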

The bucketsToSort field is a slightly nasty list of lists: we essentially want one list per thread, with the thread number serving as the index of the list to use. (Having a list within a list, rather than an array of lists, makes the syntax a bit easier for dealing with generics.) Once we've constructed a given thread's list, the list object itself won't change, but its contents will. In other words, the outer List will only ever be read once it is constructed, and storing the reference to it in a final field is enough to give thread-safe access. (The Collections.unmodifiableList() wrapper is just to ensure we don't accidentally try to modify this list; it isn't really what adds the thread-safety.)

A given thread's individual list (i.e. one of the elements of bucketsToSort) will be both read and written, and all accesses to one of these "inner" lists require synchronization on the individual list.
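The access pattern just described can be sketched in isolation as follows. This is an illustration only: the names BucketAccessSketch and fillBuckets are hypothetical, and the rule used to pick a bucket (i % n) is arbitrary and is not how the sort assigns items. What it does show is the locking discipline: the outer list is built once and only read afterwards, while every access to an inner list synchronizes on that inner list.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

// Hypothetical demo class, not part of ParallelSorter
class BucketAccessSketch {
  // Each thread adds itemsPerThread items, spread across all the buckets;
  // returns the total number of items that ended up in the buckets
  static int fillBuckets(int noThreads, final int itemsPerThread) {
    List<List<Integer>> temp = new ArrayList<List<Integer>>(noThreads);
    for (int i = 0; i < noThreads; i++) {
      temp.add(new ArrayList<Integer>());
    }
    // The outer list is never modified again, so it can be read without locking
    final List<List<Integer>> buckets = Collections.unmodifiableList(temp);
    final int n = noThreads;

    Thread[] threads = new Thread[noThreads];
    for (int t = 0; t < noThreads; t++) {
      threads[t] = new Thread(new Runnable() {
        public void run() {
          for (int i = 0; i < itemsPerThread; i++) {
            List<Integer> bucket = buckets.get(i % n); // arbitrary choice of bucket
            synchronized (bucket) {                    // lock the inner list only
              bucket.add(i);
            }
          }
        }
      });
      threads[t].start();
    }
    for (Thread th : threads) {
      try {
        th.join();
      } catch (InterruptedException e) {
        Thread.currentThread().interrupt();
      }
    }
    int total = 0;
    for (List<Integer> bucket : buckets) {
      synchronized (bucket) {
        total += bucket.size();
      }
    }
    return total;
  }
}
```

With four threads each adding 1000 items, no additions should be lost, so fillBuckets(4, 1000) should return 4000.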

Sort worker thread

The ParallelSorter will then include the method sortStageComplete() previously mentioned, plus the inner worker class; here, we simply fill in a few details:

  private class SorterThread extends Thread {
    private final int threadNo;
    private volatile Throwable error;
    SorterThread(int no) {
      this.threadNo = no;
    }
    public void run() {
      try {
        double div = (double) dataSize / noThreads;
        int startPos = (int) (div * threadNo),
            endPos = (int) (div * (threadNo + 1));

        gatherSplitPointSample(data, startPos, endPos);
        barrier.await();        
        assignItemsToBuckets(data, threadNo, startPos, endPos);
        barrier.await();
        sortMyBucket();
        barrier.await();
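      } catch (InterruptedException e) {
        // interrupted while waiting at the barrier: abandon the sort
      } catch (BrokenBarrierException e) {
        // the barrier was broken by a failure elsewhere: abandon the sort
      } catch (Throwable t) {
        this.error = t;
        // Interrupt ourselves so that await() throws immediately and breaks
        // the barrier, rather than leaving the other threads waiting
        Thread.currentThread().interrupt();
        try {
          barrier.await();
        } catch (Exception e) {}
      }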
      }
    }
    private void sortMyBucket() {
      List<E> l = bucketsToSort.get(threadNo);
      synchronized (l) {
        Collections.sort(l);
      }
    }
  }

  private void sortStageComplete() {
    try {
      switch (stageNo) {
      case 0 : amalgamateSplitPointData(); break;
      case 1 : clearData(); break;
      case 2 : combineBuckets(); break;
      default :
        throw new RuntimeException("Don't expect to be "
          + "called at stage " + stageNo);
      }
      stageNo++;
    } catch (RuntimeException rte) {
      completionStageError = rte;
      throw rte;
    }
  }
  private volatile RuntimeException completionStageError;

Notice the wrapper to save any RuntimeException occurring during the sortStageComplete() method, as discussed in the section on error handling with CyclicBarrier.
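The effect of that wrapper can be seen in a stripped-down sketch. It relies on the documented behaviour of CyclicBarrier: an exception thrown by the barrier action is propagated in the thread that ran the action, and the barrier is left in the broken state. The class below is hypothetical and uses a single-party barrier purely so that the example runs in one thread; the barrier action always fails, standing in for a stage that goes wrong.

```java
import java.util.concurrent.BrokenBarrierException;
import java.util.concurrent.CyclicBarrier;

// Hypothetical demo class, not part of ParallelSorter
class BarrierActionErrorSketch {
  private volatile RuntimeException completionStageError;

  // A single-party barrier: await() runs the barrier action in the calling thread
  private final CyclicBarrier barrier =
    new CyclicBarrier(1, new Runnable() {
      public void run() {
        try {
          // Stand-in for a stage completion step that fails
          throw new IllegalStateException("stage failed");
        } catch (RuntimeException rte) {
          completionStageError = rte; // save it for later inspection
          throw rte;                  // rethrow so the barrier is broken
        }
      }
    });

  // Returns true if the exception was saved, propagated to the caller of
  // await() and left the barrier in the broken state
  boolean errorIsSavedAndBarrierBroken() {
    RuntimeException seen = null;
    try {
      barrier.await();
    } catch (InterruptedException e) {
      Thread.currentThread().interrupt();
    } catch (BrokenBarrierException e) {
      // not expected with a single party
    } catch (RuntimeException rte) {
      seen = rte;
    }
    return seen != null && seen == completionStageError && barrier.isBroken();
  }
}
```

In the real sorter, the saved exception gives the caller something to inspect when its own await() fails with a BrokenBarrierException, which by itself says nothing about the cause.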

Gathering split points

The first method we need is gatherSplitPointSample() from the SorterThread's run() method. Recall that this must select noSamplesPerThread items at random from the given portion of the data. For this, we'll simply select noSamplesPerThread random numbers within our allocated index range, and accept that we could generate the same index twice. Normally for random sample selection, this possibility of duplicate indexes wouldn't be acceptable, and to select a given number of elements at random, we'd use the correct technique of random sampling. Correct random sampling has the disadvantage that it requires us to generate one random number per item in the list. But if the number of samples is very small compared to the list size, then the chance of a duplicate index is negligible, and given our purpose here the occasional duplicate would not matter.

To avoid too much contention on the shared splitPoints list, we initially add our values to a local list and then add that local list to the shared one at the end. (Arguably, amalgamating the per-thread lists of split points should be carried out in the amalgamateSplitPointData() method; we do it here simply because it's a bit less complicated and probably doesn't make much overall difference performance-wise.)

Each thread has its own Random object, just to avoid contention on a shared generator; we use our own seed, incremented each time, just because we know in this case that we'll create several Random instances from different threads in succession. In most applications, the weak guarantee that the Random class offers of a different seed per instance is good enough, and we wouldn't go to the trouble of managing our own seed generation.
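To make the trade-off between the two sampling approaches concrete, here is a small sketch that works on indexes only. The class and method names are hypothetical. sampleWithReplacement() is the shortcut taken in this article; selectionSample() is one well-known way of sampling without duplicates (selection sampling, in which each index is accepted with probability "number still needed / number still remaining"), and is given here as an assumption about what "correct" sampling might look like rather than as the author's own code.

```java
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Random;

// Hypothetical demo class, not part of ParallelSorter
class IndexSamplingSketch {
  // The shortcut: k random draws, so the same index may come up more than once
  static List<Integer> sampleWithReplacement(Random rand, int startPos,
                                             int endPos, int k) {
    List<Integer> chosen = new ArrayList<Integer>(k);
    for (int i = 0; i < k; i++) {
      chosen.add(rand.nextInt(endPos - startPos) + startPos);
    }
    return chosen;
  }

  // Selection sampling: distinct indexes, but up to one random number per item
  static List<Integer> selectionSample(Random rand, int startPos,
                                       int endPos, int k) {
    List<Integer> chosen = new ArrayList<Integer>(k);
    int remaining = endPos - startPos;    // indexes not yet considered
    int needed = Math.min(k, remaining);  // samples still to be picked
    for (int i = startPos; i < endPos && needed > 0; i++, remaining--) {
      // Accept index i with probability needed / remaining
      if (rand.nextInt(remaining) < needed) {
        chosen.add(i);
        needed--;
      }
    }
    return chosen;
  }

  // True if no index appears twice
  static boolean allDistinct(List<Integer> indexes) {
    return new HashSet<Integer>(indexes).size() == indexes.size();
  }
}
```

As a rough guide, the chance of at least one duplicate among k draws from n indexes is about k(k-1)/(2n) when k is much smaller than n, so with 16 samples from a million-element range it is on the order of one in ten thousand.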

In the end, the gatherSplitPointSample method ends up as follows:

private void gatherSplitPointSample(List<E> data, int startPos, int endPos) {
  // Per-thread generator, seeded from the shared, incrementing seed
  Random rand = new Random(randSeed.getAndAdd(17));
  // Collect into a local list first to limit contention on splitPoints
  List<E> sample = new ArrayList<E>(noSamplesPerThread);
  Lock l = dataLock.readLock();
  l.lock();
  try {
    for (int i = 0; i < noSamplesPerThread; i++) {
      // Random index within [startPos, endPos); duplicates are possible
      int n = rand.nextInt(endPos - startPos) + startPos;
      sample.add(data.get(n));
    }
  } finally {
    l.unlock();
  }
  synchronized (splitPoints) {
    splitPoints.addAll(sample);
  }
}

Notice the call to dataLock.readLock() to fetch the read part of the lock, and then the lock() and unlock() calls. At this stage, we know that all accesses to the data are reads (and so multiple threads holding read locks won't stall one another). So we happily hold on to the lock over the entire loop, rather than just around the call to data.get(), which is the only place where it is strictly necessary.

Amalgamating the split point data (which, you'll recall, is executed in a single thread after each thread has gathered its sample) is a simple question of sorting the sample values, and then taking the values at positions 16, 32 and so on (because noSamplesPerThread is 16 in our case). In the following implementation, we copy the sorted sample into a temporary list, then put just the required samples back into the splitPoints list. The code's pretty much what you'd expect. Remember, we must still synchronize for data visibility reasons, because other threads have been and will be accessing splitPoints:

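private void amalgamateSplitPointData() {
  synchronized (splitPoints) {
    // Sort a copy of all the samples gathered by the threads
    List<E> spl = new ArrayList<E>(splitPoints);
    Collections.sort(spl);
    // Keep only every noSamplesPerThread-th value as a split point
    splitPoints.clear();
    for (int i = 1; i < noThreads; i++) {
      splitPoints.add(spl.get(i * noSamplesPerThread));
    }
  }
}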
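As a quick check of the arithmetic, the same selection logic can be pulled out into a standalone static method and run on known data. The class name and the demo() helper are hypothetical; the example assumes 4 threads and 16 samples per thread, and uses the integers 0 to 63 in shuffled order as the combined sample, so that each value equals its own position once sorted.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;

// Hypothetical demo class, not part of ParallelSorter
class SplitPointSketch {
  // Sorts a copy of the samples and keeps every noSamplesPerThread-th value,
  // giving noThreads - 1 split points
  static <E extends Comparable<E>> List<E> chooseSplitPoints(
      List<E> samples, int noThreads, int noSamplesPerThread) {
    List<E> sorted = new ArrayList<E>(samples);
    Collections.sort(sorted);
    List<E> result = new ArrayList<E>(noThreads - 1);
    for (int i = 1; i < noThreads; i++) {
      result.add(sorted.get(i * noSamplesPerThread));
    }
    return result;
  }

  // 64 samples (4 threads x 16 samples each), in shuffled order
  static List<Integer> demo() {
    List<Integer> samples = new ArrayList<Integer>();
    for (int i = 0; i < 64; i++) {
      samples.add(i);
    }
    Collections.shuffle(samples, new Random(1));
    return chooseSplitPoints(samples, 4, 16); // [16, 32, 48]
  }
}
```

Three split points divide the value range into four parts, one per thread, which is why the loop starts at 1 and stops before noThreads.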

Next: stages 2 and 3 of the sort

On the next page, we look at implementing stages 2 and 3 of the parallel sort.



Editorial page content written by Neil Coffey. Copyright © Javamex UK 2021. All rights reserved.