Created
February 2, 2019 04:45
-
-
Save ZakTaccardi/aa951827aa8c02c0754c9f4b649edeb3 to your computer and use it in GitHub Desktop.
Some Rx-style operators on ReceiveChannel<T>
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
@file:JvmName("RxChannelExtensions") | |
import kotlinx.coroutines.experimental.Dispatchers | |
import kotlinx.coroutines.experimental.GlobalScope | |
import kotlinx.coroutines.experimental.channels.ReceiveChannel | |
import kotlinx.coroutines.experimental.channels.consumeEach | |
import kotlinx.coroutines.experimental.channels.consumes | |
import kotlinx.coroutines.experimental.channels.consumesAll | |
import kotlinx.coroutines.experimental.channels.produce | |
import kotlinx.coroutines.experimental.launch | |
import kotlinx.coroutines.experimental.sync.Mutex | |
import kotlinx.coroutines.experimental.sync.withLock | |
import java.util.concurrent.atomic.AtomicInteger | |
import java.util.concurrent.atomic.AtomicReference | |
import kotlin.coroutines.experimental.CoroutineContext | |
import kotlin.coroutines.experimental.EmptyCoroutineContext | |
/**
 * Returns a [ReceiveChannel] that forwards every element of this channel unchanged, invoking
 * [sideEffect] on each element immediately before it is sent downstream.
 *
 * This channel is registered with [consumes] as the producer's completion handler, so it is
 * consumed when the returned channel's producer completes.
 *
 * Equivalent to RxJava's `.doOnNext()` operator.
 *
 * @param context context in which the forwarding producer coroutine runs.
 * @param sideEffect action executed once per element, before that element is forwarded.
 */
fun <E> ReceiveChannel<E>.doOnNext(
    context: CoroutineContext = Dispatchers.Unconfined,
    sideEffect: (E) -> Unit
): ReceiveChannel<E> {
    val upstream = this
    return GlobalScope.produce(context, onCompletion = upstream.consumes()) {
        upstream.consumeEach { element ->
            sideEffect(element)
            send(element)
        }
    }
}
/**
 * Merges multiple [ReceiveChannel]s of the same type [T] into a single [ReceiveChannel].
 *
 * Each source is drained by its own child coroutine of the producer, so elements from different
 * sources are interleaved as they are received rather than concatenated source-by-source. Every
 * source in [sources] is consumed (via [consumesAll]) when the merged channel's producer completes.
 *
 * Because [sources] is a `vararg`, [context] can only be supplied as a named argument.
 */
fun <T> merge(
    vararg sources: ReceiveChannel<T>,
    context: CoroutineContext = Dispatchers.Unconfined
): ReceiveChannel<T> {
    val consumeAllSources = consumesAll(*sources)
    return GlobalScope.produce(context, onCompletion = consumeAllSources) {
        for (source in sources) {
            launch {
                source.consumeEach { element -> send(element) }
            }
        }
    }
}
/**
 * Emit [E] only when it is not equal to the previous emission. The first emission is always
 * emitted. Use this to avoid emitting the same value twice in a row.
 *
 * `==` (i.e. `.equals()`) comparison is used; `null` elements are handled like any other value.
 *
 * Equivalent to RxJava's `.distinctUntilChanged()` operator.
 */
fun <E> ReceiveChannel<E>.distinctUntilChanged(
    context: CoroutineContext = Dispatchers.Unconfined
): ReceiveChannel<E> = GlobalScope.produce(context, onCompletion = consumes()) {
    // consumeEach hands elements to this single coroutine one at a time, so plain locals are
    // sufficient here; no atomic container is needed.
    // A separate flag (rather than a null sentinel) keeps nullable element types correct.
    var hasPrevious = false
    var previous: E? = null
    consumeEach { emission ->
        if (!hasPrevious || emission != previous) {
            hasPrevious = true
            previous = emission
            send(emission)
        }
    }
}
/**
 * Suppress the first [skipCount] items emitted by [this] [ReceiveChannel], then forward every
 * subsequent item unchanged.
 *
 * A [skipCount] of zero or less forwards every item.
 *
 * Mirrors RxJava's `.skip(count)` operator for non-negative counts.
 */
fun <E> ReceiveChannel<E>.skip(
    skipCount: Int,
    context: CoroutineContext = Dispatchers.Unconfined
): ReceiveChannel<E> = GlobalScope.produce(context, onCompletion = consumes()) {
    // consumeEach delivers elements one at a time inside this single coroutine, so a plain local
    // counter is safe. No Mutex/atomic is needed, and we avoid holding a lock across send().
    var remainingToSkip = skipCount
    consumeEach { emission: E ->
        if (remainingToSkip > 0) {
            remainingToSkip--
        } else {
            send(emission)
        }
    }
}
/**
 * Combines the most recent values of three channels into [R] using [combineFunction].
 *
 * Nothing is emitted until [sourceA], [sourceB] and [sourceC] have each produced at least one
 * element. After that, every element received from any source triggers one [combineFunction]
 * call with the latest value of each source, and its result is sent downstream. All three sources
 * are consumed when the returned channel's producer completes.
 */
fun <A : Any?, B : Any?, C : Any?, R> combineLatest(
    sourceA: ReceiveChannel<A>,
    sourceB: ReceiveChannel<B>,
    sourceC: ReceiveChannel<C>,
    context: CoroutineContext = Dispatchers.Unconfined,
    combineFunction: suspend (A, B, C) -> R
): ReceiveChannel<R> = GlobalScope.produce(context, onCompletion = consumesAll(sourceA, sourceB, sourceC)) {
    val lock = Mutex()
    val currentA = AtomicReference<A>()
    val currentB = AtomicReference<B>()
    val currentC = AtomicReference<C>()
    var seenA = false
    var seenB = false
    var seenC = false

    // Sends a combined value, but only once every source has produced something.
    suspend fun emitIfReady() {
        val ready = seenA && seenB && seenC
        if (!ready) return
        send(combineFunction(currentA.get(), currentB.get(), currentC.get()))
    }

    // Drains one source in a child coroutine; each element is recorded and combined under [lock].
    fun <T> track(source: ReceiveChannel<T>, record: (T) -> Unit) = launch {
        source.consumeEach { value ->
            lock.withLock {
                record(value)
                emitIfReady()
            }
        }
    }

    track(sourceA) { value ->
        currentA.set(value)
        seenA = true
    }
    track(sourceB) { value ->
        currentB.set(value)
        seenB = true
    }
    track(sourceC) { value ->
        currentC.set(value)
        seenC = true
    }
}
/**
 * A combine latest that takes in 4 sources and runs a [combineFunction] over their latest
 * emissions to emit [R].
 *
 * Nothing is emitted until [sourceA], [sourceB], [sourceC] and [sourceD] have each emitted at
 * least once. After that, every element received from any source triggers one [combineFunction]
 * call with the latest value of each source. All four sources are consumed when the returned
 * channel's producer completes.
 *
 * Equivalent to RxJava's `.combineLatest()` operator.
 *
 * // TODO add tests
 */
fun <A : Any?, B : Any?, C : Any?, D : Any?, R> combineLatest(
    sourceA: ReceiveChannel<A>,
    sourceB: ReceiveChannel<B>,
    sourceC: ReceiveChannel<C>,
    sourceD: ReceiveChannel<D>,
    context: CoroutineContext = Dispatchers.Unconfined,
    combineFunction: suspend (A, B, C, D) -> R
): ReceiveChannel<R> = GlobalScope.produce(
    context, onCompletion = consumesAll(sourceA, sourceB, sourceC, sourceD)
) {
    val latestA = AtomicReference<A>()
    val latestB = AtomicReference<B>()
    val latestC = AtomicReference<C>()
    val latestD = AtomicReference<D>()
    var aInitialized = false
    var bInitialized = false
    var cInitialized = false
    var dInitialized = false
    // Serialises the four collector coroutines so that each update of a latest value and its
    // "initialized" flag is followed by the combine+send step without interleaving.
    val mutex = Mutex()

    suspend fun combineAndSendIfInitialized() {
        if (aInitialized && bInitialized && cInitialized && dInitialized) {
            send(combineFunction(latestA.get(), latestB.get(), latestC.get(), latestD.get()))
        }
    }

    // One child coroutine per source. Lambda parameters use lowercase names so they don't
    // shadow the type parameters A..D.
    launch {
        sourceA.consumeEach { a ->
            mutex.withLock {
                latestA.set(a)
                aInitialized = true
                combineAndSendIfInitialized()
            }
        }
    }
    launch {
        sourceB.consumeEach { b ->
            mutex.withLock {
                latestB.set(b)
                bInitialized = true
                combineAndSendIfInitialized()
            }
        }
    }
    launch {
        sourceC.consumeEach { c ->
            mutex.withLock {
                latestC.set(c)
                cInitialized = true
                combineAndSendIfInitialized()
            }
        }
    }
    launch {
        sourceD.consumeEach { d ->
            mutex.withLock {
                latestD.set(d)
                dInitialized = true
                combineAndSendIfInitialized()
            }
        }
    }
}
/**
 * Combines the most recent values of [sourceA] and [sourceB] into [R] using [combineFunction].
 *
 * Nothing is emitted until both sources have produced at least one element. From then on, every
 * element received from either source triggers one [combineFunction] call with the latest value
 * of each source, and its result is sent downstream. Both sources are consumed when the returned
 * channel's producer completes.
 *
 * Equivalent to RxJava's `.combineLatest()` operator
 */
fun <A : Any?, B : Any?, R> combineLatest(
    sourceA: ReceiveChannel<A>,
    sourceB: ReceiveChannel<B>,
    context: CoroutineContext = Dispatchers.Unconfined,
    combineFunction: suspend (A, B) -> R
): ReceiveChannel<R> = GlobalScope.produce(context, onCompletion = consumesAll(sourceA, sourceB)) {
    val lock = Mutex()
    val currentA = AtomicReference<A>()
    val currentB = AtomicReference<B>()
    var seenA = false
    var seenB = false

    // Sends a combined value, but only once both sources have produced something.
    suspend fun emitIfReady() {
        if (!seenA || !seenB) return
        send(combineFunction(currentA.get(), currentB.get()))
    }

    // Drains one source in a child coroutine; each element is recorded and combined under [lock].
    fun <T> track(source: ReceiveChannel<T>, record: (T) -> Unit) = launch {
        source.consumeEach { value ->
            lock.withLock {
                record(value)
                emitIfReady()
            }
        }
    }

    track(sourceA) { value ->
        currentA.set(value)
        seenA = true
    }
    track(sourceB) { value ->
        currentB.set(value)
        seenB = true
    }
}
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import com.nhaarman.mockitokotlin2.InOrderOnType | |
import com.nhaarman.mockitokotlin2.mock | |
import com.nhaarman.mockitokotlin2.times | |
import com.nhaarman.mockitokotlin2.verify | |
import com.nhaarman.mockitokotlin2.verifyNoMoreInteractions | |
import com.nhaarman.mockitokotlin2.verifyZeroInteractions | |
import kotlinx.coroutines.experimental.CoroutineScope | |
import kotlinx.coroutines.experimental.Dispatchers | |
import kotlinx.coroutines.experimental.channels.ArrayBroadcastChannel | |
import kotlinx.coroutines.experimental.channels.BroadcastChannel | |
import kotlinx.coroutines.experimental.channels.consumeEach | |
import kotlinx.coroutines.experimental.channels.sendBlocking | |
import kotlinx.coroutines.experimental.delay | |
import kotlinx.coroutines.experimental.launch | |
import kotlinx.coroutines.experimental.runBlocking | |
import org.junit.Before | |
import org.junit.Test | |
import kotlin.coroutines.experimental.CoroutineContext | |
/**
 * Tests for the Rx-style `ReceiveChannel` operators (`doOnNext`, `merge`, `distinctUntilChanged`,
 * `skip`, `combineLatest`).
 *
 * Operators are collected inside [scope], which is backed by [Dispatchers.Unconfined]. Several
 * tests verify the mock immediately after a send returns, i.e. they assume the element has already
 * been delivered downstream by then.
 */
class RxChannelExtensionsTest {
    private lateinit var scope: CoroutineScope

    @Before
    fun setUp() {
        scope = CoroutineScope(Dispatchers.Unconfined)
    }

    /**
     * Tests [doOnNext]
     */
    @Test
    fun operator_doOnNext() {
        val mockObserver = mock<(Int) -> Unit> { }
        val logger: (Int) -> Unit = { println(it) }
        val source = BroadcastChannel<Int>(10)
        fun verify(emission: Int, times: Int = 1) {
            verify(mockObserver, times(times)).invoke(emission)
        }
        scope.launch {
            source.openSubscription()
                .doOnNext {
                    logger(it)
                    mockObserver(it)
                }
                .consumeEach {}
        }
        source.sendBlocking(1)
        source.sendBlocking(2)
        source.sendBlocking(3)
        verify(1)
        verify(2)
        verify(3)
        verifyNoMoreInteractions(mockObserver)
    }

    /**
     * Tests [merge]
     */
    @Test
    fun operator_merge() {
        val mockObserver = mock<(Int) -> Unit> { }
        val logger: (Int) -> Unit = { println(it) }
        val source1 = BroadcastChannel<Int>(10)
        val source2 = BroadcastChannel<Int>(10)
        fun verify(emission: Int, times: Int = 1) {
            verify(mockObserver, times(times)).invoke(emission)
        }
        scope.launch {
            merge(
                source1.openSubscription(),
                source2.openSubscription()
            )
                .consumeEach {
                    mockObserver(it)
                    logger(it)
                }
        }
        verifyZeroInteractions(mockObserver)
        source1.sendBlocking(1)
        verify(1)
        source2.sendBlocking(2)
        verify(2)
        source1.sendBlocking(3)
        verify(3)
        source2.sendBlocking(4)
        verify(4)
        verifyNoMoreInteractions(mockObserver)
    }

    /**
     * Test for [distinctUntilChanged]
     */
    @Test
    fun operator_distinctUntilChanged() {
        val mockObserver = mock<(Int) -> Unit> { }
        val logger: (Int) -> Unit = { println(it) }
        val source = BroadcastChannel<Int>(1)
        fun verify(emission: Int, times: Int = 1) {
            verify(mockObserver, times(times)).invoke(emission)
        }
        scope.launch {
            source.openSubscription()
                .distinctUntilChanged()
                .consumeEach {
                    logger(it)
                    mockObserver(it)
                }
        }
        source.sendBlocking(0)
        source.sendBlocking(1)
        source.sendBlocking(1)
        source.sendBlocking(2)
        verify(0)
        verify(1)
        verify(2)
        verifyNoMoreInteractions(mockObserver)
        // 1 differs from the previous emission (2), so it must be delivered again.
        source.sendBlocking(1)
        verify(1, times = 2)
        // The mock has interactions by now; what we assert is that nothing else was delivered.
        verifyNoMoreInteractions(mockObserver)
    }

    /**
     * Tests [skip]: the first element is dropped and the remaining ones arrive in order.
     */
    @Test
    fun operator_skip() {
        val mockObserver = mock<(Int) -> Unit> { }
        val source = BroadcastChannel<Int>(5)
        scope.launch {
            source.openSubscription()
                .skip(1)
                .doOnNext { println(it) }
                .consumeEach(mockObserver)
        }
        runBlocking {
            source.send(1)
            source.send(2)
            source.send(3)
        }
        val inOrder = InOrderOnType(mockObserver)
        // first invocation is skipped
        inOrder.verify(mockObserver, times(1)).invoke(2)
        inOrder.verify(mockObserver, times(1)).invoke(3)
        verifyNoMoreInteractions(mockObserver)
    }

    @Test
    fun operator_combineLatest() {
        fun runTest(context: CoroutineContext) {
            println("Running test for $context")
            val mockObserver = mock<(Pair<String, Int>) -> Unit> { }
            val logger: (Pair<String, Int>) -> Unit = { println(it) }
            val sourceNames = ArrayBroadcastChannel<String>(1000)
            val sourceAges = ArrayBroadcastChannel<Int>(1000)
            val sourceAChannel = sourceNames.openSubscription()
            val sourceBChannel = sourceAges.openSubscription()
            scope.launch {
                // Pass the context under test; previously it was silently ignored.
                combineLatest(
                    sourceAChannel,
                    sourceBChannel,
                    context = context
                ) { name, age -> Pair(name, age) }
                    .consumeEach {
                        logger(it)
                        mockObserver(it)
                    }
            }
            val job = scope.launch {
                sourceNames.send("Zak")
                delay(10)
                sourceNames.send("Grace")
                delay(10)
                sourceAges.send(24)
                delay(10)
                sourceNames.send("Kelly")
                delay(10)
                sourceAges.send(25)
                delay(10)
                sourceNames.send("Jack")
                delay(10)
                sourceAges.send(27)
                delay(10)
                sourceAges.send(28) // happy birthday
                delay(10)
            }
            runBlocking {
                job.join()
            }
            val inOrder = InOrderOnType(mockObserver)
            inOrder.verify(mockObserver, times(1)).invoke(Pair("Grace", 24))
            inOrder.verify(mockObserver, times(1)).invoke(Pair("Kelly", 24))
            inOrder.verify(mockObserver, times(1)).invoke(Pair("Kelly", 25))
            inOrder.verify(mockObserver, times(1)).invoke(Pair("Jack", 25))
            inOrder.verify(mockObserver, times(1)).invoke(Pair("Jack", 27))
            inOrder.verify(mockObserver, times(1)).invoke(Pair("Jack", 28))
            verifyNoMoreInteractions(mockObserver)
            println("Test passed for $context")
        }
        runTest(Dispatchers.Unconfined)
    }

    @Test
    fun operator_combineLatest3() {
        fun runTest(context: CoroutineContext) {
            println("Running test for $context")
            val mockObserver = mock<(Triple<Int, String, Boolean>) -> Unit> { }
            val logger: (Triple<Int, String, Boolean>) -> Unit = { println(it) }
            val sourceInts = ArrayBroadcastChannel<Int>(1000)
            val sourceStrings = ArrayBroadcastChannel<String>(1000)
            val sourceBooleans = ArrayBroadcastChannel<Boolean>(1000)
            val sourceAChannel = sourceInts.openSubscription()
            val sourceBChannel = sourceStrings.openSubscription()
            val sourceCChannel = sourceBooleans.openSubscription()
            scope.launch {
                // Pass the context under test; previously it was silently ignored.
                combineLatest(
                    sourceAChannel,
                    sourceBChannel,
                    sourceCChannel,
                    context = context
                ) { int, string, boolean -> Triple(int, string, boolean) }
                    .consumeEach {
                        logger(it)
                        mockObserver(it)
                    }
            }
            // We are already inside a coroutine here, so use the suspending send() rather than
            // sendBlocking().
            val job = scope.launch {
                sourceInts.send(0)
                delay(10)
                sourceStrings.send("0")
                delay(10)
                sourceBooleans.send(false)
                delay(10)
                sourceInts.send(1)
                delay(10)
                sourceStrings.send("1")
                delay(10)
                sourceBooleans.send(true)
                delay(10)
            }
            runBlocking {
                job.join()
            }
            val inOrder = InOrderOnType(mockObserver)
            fun verify(int: Int, string: String, boolean: Boolean) {
                inOrder.verify(mockObserver, times(1)).invoke(Triple(int, string, boolean))
            }
            verify(0, "0", false)
            verify(1, "0", false)
            verify(1, "1", false)
            verify(1, "1", true)
            verifyNoMoreInteractions(mockObserver)
            println("Test passed for $context")
        }
        runTest(Dispatchers.Unconfined)
    }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment