Skip to content

Instantly share code, notes, and snippets.

@hrj
Forked from akihiro4chawon/parasort.scala
Last active December 24, 2015 07:09
Show Gist options
  • Save hrj/6761561 to your computer and use it in GitHub Desktop.
Save hrj/6761561 to your computer and use it in GitHub Desktop.
import java.util.Arrays
/**
 * Common interface for the integer-array sorting strategies benchmarked in
 * this file.
 *
 * Implementations may sort the argument in place and return it, or return a
 * freshly allocated array (possibly after modifying the argument), so callers
 * that must keep the original contents should pass a copy.
 */
abstract class Sorter {
  /** Returns the elements of `a` in ascending order. */
  def sorted(a: Array[Int]): Array[Int]
}
/**
 * Baseline sorter: one single-threaded `Arrays.sort` over the whole array.
 * The argument is sorted in place and that same instance is returned.
 */
object SimpleSorter extends Sorter {
  def sorted(a: Array[Int]): Array[Int] = { Arrays.sort(a); a }
}
/**
 * Sorts the lower and upper halves of the input concurrently with
 * `Arrays.sort`, then merges them into a freshly allocated result array.
 *
 * Note: the input array is modified (each half is sorted in place) before the
 * merge; the returned array is a different instance. Inputs with fewer than
 * two elements are returned as a copy.
 */
object DivideAndMergeParallelSorter extends Sorter {
  def sorted(a: Array[Int]): Array[Int] = {
    import scala.annotation.tailrec
    val len = a.length
    if (len < 2) {
      // Nothing to split or merge; copy so the result is never aliased with
      // the input, matching the general case.
      a.clone
    } else {
      val half = len / 2
      // Sort the lower half on a helper thread while this thread sorts the
      // upper half. (Replaces scala.concurrent.ops.par, which was deprecated
      // in Scala 2.10 and removed in 2.11.) join() also guarantees the
      // helper's writes to `a` are visible here before merging.
      val lower = new Thread(new Runnable {
        def run(): Unit = Arrays.sort(a, 0, half)
      })
      lower.start()
      Arrays.sort(a, half, len)
      lower.join()

      val ret = new Array[Int](len)
      // Merges runs a(j until half) and a(k until len) into ret starting at i.
      // Both runs are non-empty on every call; once one run is exhausted the
      // remainder of the other is block-copied.
      @tailrec
      def merge(i: Int, j: Int, k: Int): Unit = {
        if (a(j) <= a(k)) {
          ret(i) = a(j)
          if (j < half - 1) merge(i + 1, j + 1, k)
          else System.arraycopy(a, k, ret, i + 1, len - k)
        } else {
          ret(i) = a(k)
          if (k < len - 1) merge(i + 1, j, k + 1)
          else System.arraycopy(a, j, ret, i + 1, half - j)
        }
      }
      merge(0, 0, half)
      ret
    }
  }
}
/**
 * Copies the input into one contiguous slice per available processor (never
 * more slices than elements), sorts the slices in parallel, and merges them
 * with a parallel `reduce`. The input array is not modified; a new sorted
 * array is returned.
 */
object DivideAndMergeParallelSorter2 extends Sorter {
  def sorted(a: Array[Int]): Array[Int] = {
    import scala.annotation.tailrec
    val len = a.length
    // Clamp the slice count to [1, len]: with nDiv <= len every slice is
    // non-empty (merge below requires that), and keeping at least one slice
    // lets an empty input flow through reduce as a single empty slice.
    val nDiv = math.max(1, math.min(scala.collection.parallel.availableProcessors, len))
    // Slice boundary i; computed in Long because i * len overflows Int for
    // large arrays (e.g. 8 slices of a 400M-element array).
    def bound(i: Int): Int = (i.toLong * len / nDiv).toInt
    val pslices = (0 until nDiv).par map { i => Arrays.copyOfRange(a, bound(i), bound(i + 1)) }
    pslices foreach { slice => Arrays.sort(slice) }

    // Merges two sorted, non-empty arrays into a new sorted array.
    def merge(left: Array[Int], right: Array[Int]): Array[Int] = {
      val llen = left.length
      val rlen = right.length
      val out = new Array[Int](llen + rlen)
      @tailrec def rec(i: Int, j: Int, k: Int): Unit = {
        if (left(j) <= right(k)) {
          out(i) = left(j)
          if (j < llen - 1) rec(i + 1, j + 1, k)
          else System.arraycopy(right, k, out, i + 1, rlen - k)
        } else {
          out(i) = right(k)
          if (k < rlen - 1) rec(i + 1, j, k + 1)
          else System.arraycopy(left, j, out, i + 1, llen - j)
        }
      }
      rec(0, 0, 0)
      out
    }
    // Merging sorted arrays is associative, as a parallel reduce requires.
    pslices reduce merge
  }
}
/**
 * Shell sort over the gap sequence 1, 4, 13, 40, ... (h -> 3h + 1). For each
 * gap h, the h interleaved chains k, k+h, k+2h, ... touch disjoint indices, so
 * they are insertion-sorted in parallel. Sorts `a` in place and returns it.
 */
object ShellParallelSorter extends Sorter {
  def sorted(a: Array[Int]): Array[Int] = {
    // Starting gap: one third of the first sequence element exceeding a.length.
    val firstAbove = Iterator.iterate(1)(_ * 3 + 1).dropWhile(_ <= a.length).next()
    val hInit = firstAbove / 3
    for (h <- Iterator.iterate(hInit)(_ / 3).takeWhile(_ >= 1)) {
      // Each k owns the chain k, k+h, k+2h, ...; chains never overlap.
      for (k <- (0 until h).par) {
        var i = k
        while (i < a.length) {
          val v = a(i)
          var j = i
          while (j >= h && a(j - h) > v) {
            a(j) = a(j - h)
            j -= h
          }
          a(j) = v
          i += h
        }
      }
    }
    a
  }
}
/**
 * Shell sort over the gap sequence 1, 4, 13, 40, ... computed with plain
 * while loops. Each gap pass insertion-sorts its h disjoint chains in
 * parallel. Sorts `a` in place and returns it.
 */
object ShellParallelSorterOpt extends Sorter {
  def sorted(a: Array[Int]): Array[Int] = {
    val len = a.length
    // First sequence element (h -> 3h + 1) that is >= len.
    var h = 1
    while (h < len) h = h * 3 + 1
    // Each pass uses the next smaller gap, down to and including 1.
    h /= 3
    while (h >= 1) {
      // Snapshot the gap so worker closures read a plain val rather than the
      // mutable local `h`.
      val gap = h
      // Chains k, k+gap, k+2*gap, ... for different k are disjoint, so they
      // can be insertion-sorted concurrently.
      for (k <- (0 until gap).par) {
        var i = k
        while (i < len) {
          val v = a(i)
          var j = i
          while (j >= gap && a(j - gap) > v) {
            a(j) = a(j - gap)
            j -= gap
          }
          a(j) = v
          i += gap
        }
      }
      h /= 3
    }
    a
  }
}
/**
 * Shell sort that runs the gap passes with gap > 1 in parallel, then finishes
 * with a sequential gap-1 insertion sort. Sorts `a` in place and returns it.
 */
object ShellParallelSorterOpt2 extends Sorter {
  def sorted(a: Array[Int]): Array[Int] = {
    val len = a.length
    // First sequence element (h -> 3h + 1) that is >= len.
    var h = 1
    while (h < len) h = h * 3 + 1
    // Three divisions by 3 happen before the first pass, so the two largest
    // gaps below that element are skipped (a tuning choice; the final gap-1
    // pass guarantees correctness regardless).
    h /= 3; h /= 3
    h /= 3
    while (h > 1) {
      // Snapshot the gap so worker closures read a plain val.
      val gap = h
      // Chains for different k are disjoint, so they run concurrently.
      for (k <- (0 until gap).par) {
        var i = k
        while (i < len) {
          val v = a(i)
          var j = i
          while (j >= gap && a(j - gap) > v) {
            a(j) = a(j - gap)
            j -= gap
          }
          a(j) = v
          i += gap
        }
      }
      h /= 3
    }
    // Final sequential insertion sort with gap 1. The gap must be 1
    // explicitly: for inputs of 2..13 elements `h` ends at 0 above, and a
    // gap-0 test (a(j) > v where a(j) == v) never fires, which previously
    // returned such arrays unsorted.
    var i = 1
    while (i < len) {
      val v = a(i)
      var j = i
      while (j >= 1 && a(j - 1) > v) {
        a(j) = a(j - 1)
        j -= 1
      }
      a(j) = v
      i += 1
    }
    a
  }
}
/**
 * Shell sort whose gap passes (gap > 1) run on hand-made threads. In each pass
 * the workers and the calling thread claim chain start indices from `count`
 * and insertion-sort those chains. A sequential gap-1 insertion sort follows.
 * Sorts `a` in place and returns it.
 *
 * NOTE(review): all working state (`count`, `h`, `a`) lives on this singleton
 * object, so calls to `sorted` must not overlap across threads.
 */
object ShellParallelSorterThread extends Sorter with Runnable {
  /** Next chain start index to claim; counts down from h - 1 during a pass. */
  val count = new java.util.concurrent.atomic.AtomicInteger
  /** Gap of the pass currently being executed. */
  var h: Int = 1
  /** Array being sorted; set by `sorted` and cleared when it returns normally. */
  var a: Array[Int] = null

  /**
   * Worker loop for one pass: claims chain start indices from `count` until
   * they run out and insertion-sorts each chain k, k+h, k+2h, ... .
   * Chains with different start indices touch disjoint elements.
   */
  def run() {
    var i: Int = 0
    val len = this.a.length
    val a = this.a
    val h = this.h
    while ({i = count.getAndDecrement; i >= 0}) {
      i -= h; while ({i += h; i < len}) {
        val v = a(i)
        var j = i
        while (j >= h && (a(j - h) > v)) {
          a(j) = a(j - h)
          j -= h
        }
        a(j) = v
      }
    }
  }

  def sorted(a: Array[Int]): Array[Int] = {
    this.a = a
    val len = a.length
    // Restart the gap sequence from 1 rather than from the value a previous
    // call left in `h`.
    h = 1
    while (h < len) {h *= 3; h += 1}
    // Two divisions by 3 happen before the first pass, so the largest gap
    // below the first sequence element >= len is skipped.
    h /= 3
    while ({h /= 3; h > 1}) {
      count.set(h - 1)
      // Never start more workers than chains; this thread also works via run().
      val nThreads = (Runtime.getRuntime.availableProcessors - 1) min (h - 1)
      val pool = Array.fill(nThreads)(new Thread(this))
      pool foreach {_.start()}
      run()
      // join() also makes the workers' writes to `a` visible to this thread.
      pool foreach {_.join()}
    }
    // Final sequential insertion sort with gap 1. The gap must be 1
    // explicitly: for inputs of 2..4 elements `h` ends at 0 above, and a gap-0
    // test never moves anything, which previously left such arrays unsorted.
    var i = 0; while ({i += 1; i < len}) {
      val v = a(i)
      var j = i; while (j >= 1 && (a(j - 1) > v)) {
        a(j) = a(j - 1)
        j -= 1
      }
      a(j) = v
    }
    // Drop the reference so this singleton does not keep the array alive.
    this.a = null
    a
  }
}
/**
 * Benchmark driver. It first checks every sorter against the first one on a
 * single shuffled input. It then times each sorter over 10 runs and prints the
 * mean of the middle 6 timings; the 2 fastest and 2 slowest runs are dropped.
 */
object Main extends App {
  import scala.collection.mutable.WrappedArray
  import scala.util.Random

  /** Supplies random permutations of 0 until 1000000. */
  object RandomSource {
    private val src: WrappedArray[Int] = Array.range(0, 1000000)
    /** Returns a newly shuffled copy of the source range. */
    def getShuffled = Random.shuffle(src).toArray
  }

  val impls = Seq(
    SimpleSorter,
    DivideAndMergeParallelSorter,
    DivideAndMergeParallelSorter2,
    ShellParallelSorter,
    ShellParallelSorterOpt,
    ShellParallelSorterOpt2,
    ShellParallelSorterThread)

  // Fail fast if any implementation disagrees with the reference (the first one).
  check(impls.head, impls.tail: _*)

  for (sorter <- impls) {
    System.gc()
    val times = (1 to 10) map { round => benchmark(sorter, round) }
    // Keep the middle 6 of the 10 timings once sorted.
    val trimmed = times.sorted.slice(2, times.size - 2)
    println("avg. " + trimmed.sum / trimmed.size / 1000000 + "[ms]")
  }

  /** Sorts one fresh shuffled input with `sorter`, prints and returns the elapsed nanoseconds. */
  def benchmark(sorter: Sorter, i: Int) = {
    val input = RandomSource.getShuffled
    val start = System.nanoTime
    sorter.sorted(input)
    val elapsed = System.nanoTime - start
    println(sorter.getClass.getName + " #" + i + ": " + (elapsed / 1000000) + "[ms]")
    elapsed
  }

  /** Asserts that every sorter in `sorters` produces the same output as `refSorter` on a copy of one shuffled input. */
  def check(refSorter: Sorter, sorters: Sorter*): Unit = {
    val input = RandomSource.getShuffled
    val expected = refSorter.sorted(input.clone)
    for (s <- sorters) {
      val actual = s.sorted(input.clone)
      assert(Arrays.equals(actual, expected))
    }
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment