/*
 * Created February 12, 2015 07:06.
 * Gist ryantenney/8353eb10e689a3da9849 (save it to your computer and use it in GitHub Desktop).
 * Note: this file may contain hidden or bidirectional Unicode text that can be interpreted or
 * compiled differently than it appears below. To review it, open the file in an editor that
 * reveals hidden Unicode characters. (Learn more about bidirectional Unicode characters.)
 */
package com.enernoc.cost.common.util.concurrent;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.AbstractQueue;
import java.util.ArrayDeque;
import java.util.Collection;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.Objects;
import java.util.Queue;
import java.util.Set;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.Semaphore;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.locks.Condition;
import java.util.concurrent.locks.ReentrantLock;

import javax.annotation.concurrent.GuardedBy;
/**
 * An unbounded {@link BlockingQueue} of tasks, meant as the work queue of a
 * {@link java.util.concurrent.ThreadPoolExecutor}, that shares consumers fairly among the
 * threads that submit work.
 *
 * <p>Each submitting thread gets its own sub-queue (a {@code Handle}, looked up through a
 * {@link ThreadLocal}). Handles with queued work are served round-robin: a dequeue takes the
 * head task of the first handle in rotation order that may start another task, then moves
 * that handle to the back of the rotation. A handle may have at most {@code leasesPerThread}
 * of its tasks running at once. Only when no handle with queued work can take such a
 * baseline lease may a handle borrow one of its
 * {@code maxLeasesPerThread - leasesPerThread} surplus leases.
 *
 * <p>Tasks returned by {@link #poll()}, {@link #poll(long, TimeUnit)} and {@link #take()} are
 * wrappers that give the lease back when they finish running. They must actually be run,
 * otherwise the submitting thread's lease is never returned. Tasks removed by
 * {@link #drainTo(Collection, int)} are the original, unwrapped tasks and hold no lease.
 *
 * <p>{@link #peek()} and {@link #iterator()} are unsupported. The inherited operations built
 * on them ({@code element}, {@code contains}, {@code remove(Object)}, {@code toArray},
 * {@code toString}) therefore throw {@link UnsupportedOperationException}.
 */
public final class CompletelyFairWorkQueue extends AbstractQueue<Runnable> implements BlockingQueue<Runnable> {

    /** Baseline number of one submitting thread's tasks that may run at the same time. */
    private final int leasesPerThread;

    /** Cap on one submitting thread's concurrently running tasks, surplus leases included. */
    private final int maxLeasesPerThread;

    /** Lazily creates one {@link Handle} per submitting thread. */
    private final ThreadLocal<Handle> handles;

    /**
     * Single lock shared by producers and consumers. It guards {@link #subqueue} and
     * {@link #activeHandles}, and every update of {@link #count}.
     */
    private final ReentrantLock queueLock = new ReentrantLock();

    /**
     * Signalled whenever a queued task may have become runnable. That happens on every
     * enqueue, when a lease is returned while work is queued, and when a consumer leaves
     * work behind.
     */
    private final Condition notEmpty = queueLock.newCondition();

    /** Number of queued tasks; changed under {@link #queueLock} but readable without it. */
    private final AtomicInteger count = new AtomicInteger();

    /** Round-robin rotation of the handles that currently have queued tasks. */
    private final Queue<Handle> subqueue = new LinkedList<>();

    /** Exactly the handles in {@link #subqueue}, for constant-time membership checks. */
    private final Set<Handle> activeHandles = new HashSet<>();

    /**
     * Creates a queue that lets each submitting thread run up to {@code leasesPerThread} of
     * its tasks at once, with no surplus leases.
     *
     * @param leasesPerThread per-thread concurrency limit; must be positive
     * @throws IllegalArgumentException if {@code leasesPerThread <= 0}
     */
    public CompletelyFairWorkQueue(int leasesPerThread) {
        this(leasesPerThread, leasesPerThread);
    }

    /**
     * Creates a queue with a baseline and a maximum per-thread concurrency.
     *
     * @param leasesPerThread baseline number of a thread's tasks that may run at once; must be
     *     positive
     * @param maxLeasesPerThread cap including surplus leases; must be at least
     *     {@code leasesPerThread}
     * @throws IllegalArgumentException if {@code leasesPerThread <= 0} or
     *     {@code maxLeasesPerThread < leasesPerThread}
     */
    public CompletelyFairWorkQueue(int leasesPerThread, int maxLeasesPerThread) {
        if (leasesPerThread <= 0 || maxLeasesPerThread < leasesPerThread) {
            throw new IllegalArgumentException("leasesPerThread=" + leasesPerThread
                    + ", maxLeasesPerThread=" + maxLeasesPerThread
                    + "; require 0 < leasesPerThread <= maxLeasesPerThread");
        }
        this.leasesPerThread = leasesPerThread;
        this.maxLeasesPerThread = maxLeasesPerThread;
        this.handles = ThreadLocal.withInitial(Handle::new);
    }

    /**
     * Removes the next task that may run now, without waiting.
     *
     * @return the wrapped task, or {@code null} if the queue is empty or every thread with
     *     queued work is at its lease limit
     */
    @Override
    public Runnable poll() {
        queueLock.lock();
        try {
            return pollLocked();
        } finally {
            queueLock.unlock();
        }
    }

    /**
     * Waits up to the given time for a task that may run, and removes it.
     *
     * @return the wrapped task, or {@code null} if none became available in time
     * @throws InterruptedException if interrupted while waiting
     */
    @Override
    public Runnable poll(long timeout, TimeUnit unit) throws InterruptedException {
        long nanos = unit.toNanos(timeout);
        queueLock.lockInterruptibly();
        try {
            Runnable runnable;
            // A non-zero count is not enough: every queued task may belong to a thread with
            // no free lease, so keep waiting until a dequeue actually succeeds.
            while ((runnable = pollLocked()) == null) {
                if (nanos <= 0L) {
                    return null;
                }
                nanos = notEmpty.awaitNanos(nanos);
            }
            return runnable;
        } finally {
            queueLock.unlock();
        }
    }

    /**
     * Waits until a task may run, and removes it. Never returns {@code null}.
     *
     * @return the wrapped task
     * @throws InterruptedException if interrupted while waiting
     */
    @Override
    public Runnable take() throws InterruptedException {
        queueLock.lockInterruptibly();
        try {
            Runnable runnable;
            // Waits both while the queue is empty and while all queued work is lease-blocked.
            while ((runnable = pollLocked()) == null) {
                notEmpty.await();
            }
            return runnable;
        } finally {
            queueLock.unlock();
        }
    }

    /**
     * Dequeues one task and keeps {@link #count} in step. Caller must hold {@link #queueLock}.
     *
     * @return the wrapped task, or {@code null} if nothing may run right now
     */
    private Runnable pollLocked() {
        Runnable runnable = dequeue();
        if (runnable != null && count.decrementAndGet() > 0) {
            // Work remains: pass the wake-up on to another waiting consumer.
            notEmpty.signal();
        }
        return runnable;
    }

    /**
     * Takes the next task in rotation order. Baseline leases are tried first; surplus leases
     * are used only when no handle could take a baseline lease. Caller must hold
     * {@link #queueLock}.
     */
    private Runnable dequeue() {
        if (subqueue.isEmpty()) {
            return null;
        }
        Runnable runnable = dequeue(false);
        return (runnable != null) ? runnable : dequeue(true);
    }

    /**
     * One pass over the rotation. The first handle that can acquire a lease gives up its head
     * task, and the handle either leaves the rotation (now empty) or moves to its back.
     * Caller must hold {@link #queueLock}.
     */
    private Runnable dequeue(boolean allowSurplus) {
        Iterator<Handle> iterator = subqueue.iterator();
        while (iterator.hasNext()) {
            Handle handle = iterator.next();
            handle.lock();
            try {
                if (handle.isEmpty()) {
                    // Not expected, but an idle handle must also leave activeHandles,
                    // otherwise enqueue() would never put it back into the rotation.
                    iterator.remove();
                    activeHandles.remove(handle);
                    continue;
                }
                if (!handle.tryAcquire(allowSurplus)) {
                    continue;
                }
                Runnable runnable = handle.poll();
                iterator.remove();
                if (handle.isEmpty()) {
                    activeHandles.remove(handle);
                } else {
                    // Still has work: rejoin at the back of the rotation. The iterator is
                    // not used after this point, so modifying the list here is safe.
                    subqueue.offer(handle);
                }
                return runnable;
            } finally {
                handle.unlock();
            }
        }
        return null;
    }

    @Override
    public Runnable peek() {
        throw new UnsupportedOperationException("peek");
    }

    /**
     * Appends a task to the calling thread's sub-queue. The queue is unbounded, so this
     * always succeeds.
     *
     * @throws NullPointerException if {@code runnable} is null
     */
    @Override
    public boolean offer(Runnable runnable) {
        Objects.requireNonNull(runnable);
        queueLock.lock();
        try {
            enqueue(runnable);
        } finally {
            queueLock.unlock();
        }
        return true;
    }

    /**
     * Same as {@link #put(Runnable)}: the queue is unbounded, so the timeout never applies.
     */
    @Override
    public boolean offer(Runnable runnable, long timeout, TimeUnit unit) throws InterruptedException {
        put(runnable);
        return true;
    }

    /**
     * Appends a task to the calling thread's sub-queue, acquiring the lock interruptibly.
     *
     * @throws NullPointerException if {@code runnable} is null
     * @throws InterruptedException if interrupted while acquiring the lock
     */
    @Override
    public void put(Runnable runnable) throws InterruptedException {
        Objects.requireNonNull(runnable);
        queueLock.lockInterruptibly();
        try {
            enqueue(runnable);
        } finally {
            queueLock.unlock();
        }
    }

    /**
     * Adds a task to the calling thread's handle and puts the handle into the rotation if it
     * was idle. Caller must hold {@link #queueLock}.
     */
    private void enqueue(Runnable runnable) {
        Handle handle = handles.get();
        handle.lock();
        try {
            if (activeHandles.add(handle)) {
                subqueue.offer(handle);
            }
            handle.offer(runnable);
        } finally {
            handle.unlock();
        }
        count.incrementAndGet();
        // Signal on every enqueue, not only on empty -> non-empty. Consumers may be waiting
        // with count > 0 because all queued work is lease-blocked, while this task can run.
        notEmpty.signal();
    }

    @Override
    public int size() {
        return count.get();
    }

    @Override
    public int remainingCapacity() {
        return Integer.MAX_VALUE;
    }

    /**
     * Removes every queued task, ignoring leases; see {@link #drainTo(Collection, int)}.
     */
    @Override
    public int drainTo(Collection<? super Runnable> c) {
        return drainTo(c, Integer.MAX_VALUE);
    }

    /**
     * Removes up to {@code maxElements} queued tasks, ignoring leases, and adds the original
     * (unwrapped) tasks to {@code c}. Handles are visited in rotation order and each one is
     * emptied in submission order. Drained tasks hold no lease. This is what
     * {@link java.util.concurrent.ThreadPoolExecutor#shutdownNow()} relies on.
     *
     * @return the number of tasks removed
     * @throws NullPointerException if {@code c} is null
     * @throws IllegalArgumentException if {@code c} is this queue
     */
    @Override
    public int drainTo(Collection<? super Runnable> c, int maxElements) {
        Objects.requireNonNull(c);
        if (c == this) {
            throw new IllegalArgumentException();
        }
        if (maxElements <= 0) {
            return 0;
        }
        queueLock.lock();
        int drained = 0;
        try {
            Iterator<Handle> iterator = subqueue.iterator();
            while (drained < maxElements && iterator.hasNext()) {
                Handle handle = iterator.next();
                handle.lock();
                try {
                    Runnable task;
                    while (drained < maxElements && (task = handle.pollUnleased()) != null) {
                        // Count before add(): if c rejects the task it is already removed.
                        drained++;
                        c.add(task);
                    }
                    if (handle.isEmpty()) {
                        iterator.remove();
                        activeHandles.remove(handle);
                    }
                } finally {
                    handle.unlock();
                }
            }
        } finally {
            count.addAndGet(-drained);
            queueLock.unlock();
        }
        return drained;
    }

    @Override
    public Iterator<Runnable> iterator() {
        throw new UnsupportedOperationException();
    }

    /**
     * Called after a lease has been returned. The freed lease may let a queued task run, so
     * one waiting consumer is woken. When nothing is queued nobody can use the lease, and the
     * next enqueue signals anyway.
     */
    private void leaseReleased() {
        if (count.get() > 0) {
            queueLock.lock();
            try {
                notEmpty.signal();
            } finally {
                queueLock.unlock();
            }
        }
    }

    /**
     * One submitting thread's pending tasks, together with the leases that cap how many of
     * them may run at the same time.
     */
    private final class Handle {
        /** Pending tasks in submission order; accessed only while holding {@link #lock}. */
        private final Queue<Runnable> queue = new ArrayDeque<>();

        /**
         * Guards {@link #queue}. Every current caller also holds the queue lock, so this lock
         * is redundant today. It keeps the handle safe if producers and consumers ever get
         * separate locks.
         */
        private final ReentrantLock lock = new ReentrantLock();

        /** Baseline leases. */
        private final Semaphore leaseSemaphore = new Semaphore(leasesPerThread);

        /** Surplus leases, used only when no handle with queued work can take a baseline lease. */
        private final Semaphore surplusLeaseSemaphore = new Semaphore(maxLeasesPerThread - leasesPerThread);

        /** Surplus leases currently held; a returned lease refills the surplus pool first. */
        private final AtomicInteger surplusLeasesCount = new AtomicInteger();

        boolean isEmpty() {
            return queue.isEmpty();
        }

        void lock() {
            lock.lock();
        }

        void unlock() {
            lock.unlock();
        }

        /**
         * Tries to take a lease for running this handle's head task: a baseline lease first,
         * then, if allowed, a surplus one. The untimed {@code Semaphore.tryAcquire()} barges;
         * fairness between threads comes from the rotation, not from the semaphores.
         */
        boolean tryAcquire(boolean allowSurplus) {
            if (leaseSemaphore.tryAcquire()) {
                return true;
            }
            if (allowSurplus && surplusLeaseSemaphore.tryAcquire()) {
                surplusLeasesCount.incrementAndGet();
                return true;
            }
            return false;
        }

        /**
         * Returns one lease: a surplus one if any is held, otherwise a baseline one. Then it
         * wakes a consumer that may be waiting for it.
         */
        private void release() {
            Semaphore pool = leaseSemaphore;
            // Re-check "> 0" on every retry. A concurrent release may have taken the last
            // surplus lease, and a blind CAS could then drive the counter negative.
            for (int surplus = surplusLeasesCount.get(); surplus > 0; surplus = surplusLeasesCount.get()) {
                if (surplusLeasesCount.compareAndSet(surplus, surplus - 1)) {
                    pool = surplusLeaseSemaphore;
                    break;
                }
            }
            pool.release();
            leaseReleased();
        }

        /**
         * Removes the head task, for which a lease has just been acquired. The task is wrapped
         * so the lease is returned when it finishes. Caller must hold {@link #lock}.
         */
        Runnable poll() {
            final Runnable task = queue.poll();
            return () -> {
                try {
                    task.run();
                } finally {
                    release();
                }
            };
        }

        /**
         * Removes the head task without a lease or wrapper, or returns {@code null} if none.
         * Caller must hold {@link #lock}.
         */
        Runnable pollUnleased() {
            return queue.poll();
        }

        /** Appends a task. Caller must hold {@link #lock}. */
        void offer(Runnable runnable) {
            queue.offer(runnable);
        }
    }
}
// Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.