Last active
March 27, 2021 21:51
-
-
Save ExFed/6f0ffd139d2474697c0910b7176c1808 to your computer and use it in GitHub Desktop.
Variant Types in Java (using Lombok and Vavr)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package com.columnzero.util.function; | |
import io.vavr.Function1; | |
import io.vavr.PartialFunction; | |
import io.vavr.collection.Seq; | |
import io.vavr.collection.Stream; | |
import lombok.AllArgsConstructor; | |
import lombok.NonNull; | |
import java.util.function.Predicate; | |
/** | |
* Describes a sequential variant. Used to replace if/else ladders with a type-safe, declarative | |
* functional expression. Note that, just like with an if/else ladder, each is checked sequentially, | |
* so it will ultimately evaluate the first expression branch whose condition evaluates to {@code | |
* true}. | |
* | |
* @param <U> type of the input variant | |
* @param <V> type of the resulting value | |
*/ | |
@AllArgsConstructor | |
public class SeqVariant<U, V> implements PartialFunction<U, V> { | |
public static <U, V> PartialFunction<U, V> branch(Predicate<U> condition, | |
Function1<U, V> expression) { | |
return expression.partial(condition); | |
} | |
@SafeVarargs | |
public static <U, V> SeqVariant<U, V> of(PartialFunction<U, V>... branches) { | |
return new SeqVariant<>(Stream.of(branches)); | |
} | |
private final @NonNull Seq<PartialFunction<U, V>> branches; | |
@Override | |
public V apply(U value) { | |
return branches.filter(e -> e.isDefinedAt(value)) | |
.map(e -> e.apply(value)) | |
.getOrElseThrow(IllegalArgumentException::new); | |
} | |
@Override | |
public boolean isDefinedAt(U value) { | |
return branches.exists(e -> e.isDefinedAt(value)); | |
} | |
} |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package com.columnzero.util.function; | |
import org.junit.jupiter.api.BeforeEach; | |
import org.junit.jupiter.api.Test; | |
import static com.columnzero.util.function.SeqVariant.branch; | |
import static com.google.common.truth.Truth.assertThat; | |
class SeqVariantTest { | |
SeqVariant<String, Integer> ordinals; | |
@BeforeEach | |
void setUp() { | |
ordinals = SeqVariant.of( | |
branch(s -> s.startsWith("one"), s -> 1), | |
branch(s -> s.startsWith("two"), s -> 2), | |
branch(s -> s.startsWith("three"), s -> 3), | |
branch(s -> s.startsWith("four"), s -> 4) | |
); | |
} | |
@Test | |
void apply() { | |
assertThat(ordinals.apply("one")).isEqualTo(1); | |
assertThat(ordinals.apply("one!")).isEqualTo(1); | |
assertThat(ordinals.apply("two")).isEqualTo(2); | |
assertThat(ordinals.apply("three")).isEqualTo(3); | |
assertThat(ordinals.apply("four")).isEqualTo(4); | |
} | |
@Test | |
void isDefinedAt() { | |
assertThat(ordinals.isDefinedAt("zero")).isEqualTo(false); | |
assertThat(ordinals.isDefinedAt("one")).isEqualTo(true); | |
assertThat(ordinals.isDefinedAt("one!")).isEqualTo(true); | |
assertThat(ordinals.isDefinedAt("two")).isEqualTo(true); | |
assertThat(ordinals.isDefinedAt("three")).isEqualTo(true); | |
assertThat(ordinals.isDefinedAt("four")).isEqualTo(true); | |
assertThat(ordinals.isDefinedAt("five")).isEqualTo(false); | |
} | |
} |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package com.columnzero.util.function; | |
import io.vavr.Function1;
import io.vavr.PartialFunction;
import io.vavr.collection.Stream;
import lombok.AccessLevel;
import lombok.AllArgsConstructor;
import lombok.Getter;
import lombok.NonNull;

import java.util.HashSet;
import java.util.Set;
/** | |
* Describes a tagged union. Used to replace switch statements and if/else ladders with a type-safe, | |
* declarative functional expression. | |
* | |
* @param <U> Type of the input value | |
* @param <V> Type of the resulting value | |
*/ | |
public class TaggedUnion<U, V> implements PartialFunction<U, V> { | |
public static <T, U, V> Branch<T, U, V> branch(T tag, | |
Function1<? super U, ? extends V> expression) { | |
return new Branch<>(tag, expression); | |
} | |
@SafeVarargs | |
public static <T, U, V> TaggedUnion<U, V> of(Function1<U, T> discriminator, | |
Branch<T, U, V>... branches) { | |
return new TaggedUnion<>(discriminator, | |
Stream.of(branches).toMap(Branch::getTag, Branch::getExpression)); | |
} | |
@SuppressWarnings("unchecked") | |
public <T> TaggedUnion( | |
@NonNull Function1<U, ? extends T> discriminator, | |
@NonNull PartialFunction<? super T, Function1<? super U, ? extends V>> expressionMap) { | |
this.discriminator = discriminator; | |
this.expressionMap = | |
(PartialFunction<Object, Function1<? super U, ? extends V>>) expressionMap; | |
} | |
// we can elide tag types because they are enforced during construction | |
private final @NonNull Function1<U, ?> discriminator; | |
private final @NonNull PartialFunction<Object, Function1<? super U, ? extends V>> expressionMap; | |
@Override | |
public V apply(U unionValue) { | |
return expressionMap.apply(discriminator.apply(unionValue)).apply(unionValue); | |
} | |
@Override | |
public boolean isDefinedAt(U unionValue) { | |
return expressionMap.isDefinedAt(discriminator.apply(unionValue)); | |
} | |
@AllArgsConstructor(access = AccessLevel.PRIVATE) | |
@Getter | |
public static final class Branch<T, U, V> { | |
private final @NonNull T tag; | |
private final @NonNull Function1<? super U, ? extends V> expression; | |
} | |
} |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package com.columnzero.util.function; | |
import org.junit.jupiter.api.BeforeEach; | |
import org.junit.jupiter.api.Test; | |
import static com.columnzero.util.function.TaggedUnion.branch; | |
import static com.google.common.truth.Truth.assertThat; | |
class TaggedUnionTest { | |
TaggedUnion<String, Integer> ordinals; | |
@BeforeEach | |
void setUp() { | |
ordinals = TaggedUnion.of( | |
s -> s.substring(0, 3), | |
branch("one", s -> 1), | |
branch("two", s -> 2), | |
branch("thr", s -> 3), | |
branch("fou", s -> 4) | |
); | |
} | |
@Test | |
void apply() { | |
assertThat(ordinals.apply("one")).isEqualTo(1); | |
assertThat(ordinals.apply("one!")).isEqualTo(1); | |
assertThat(ordinals.apply("two")).isEqualTo(2); | |
assertThat(ordinals.apply("three")).isEqualTo(3); | |
assertThat(ordinals.apply("four")).isEqualTo(4); | |
} | |
@Test | |
void isDefinedAt() { | |
assertThat(ordinals.isDefinedAt("zero")).isEqualTo(false); | |
assertThat(ordinals.isDefinedAt("one")).isEqualTo(true); | |
assertThat(ordinals.isDefinedAt("one!")).isEqualTo(true); | |
assertThat(ordinals.isDefinedAt("two")).isEqualTo(true); | |
assertThat(ordinals.isDefinedAt("three")).isEqualTo(true); | |
assertThat(ordinals.isDefinedAt("four")).isEqualTo(true); | |
assertThat(ordinals.isDefinedAt("five")).isEqualTo(false); | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment