Created
December 20, 2018 16:48
-
-
Save cchacin/af90d76a7e8e2f5db5a9564be60b02d5 to your computer and use it in GitHub Desktop.
Http2Client => OkHttp Call.Factory
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package cronus.core; | |
import java.io.IOException; | |
import java.text.SimpleDateFormat; | |
import java.util.Arrays; | |
import java.util.Date; | |
import okhttp3.Call; | |
import okhttp3.Callback; | |
import okhttp3.Handshake; | |
import okhttp3.Headers; | |
import okhttp3.HttpUrl; | |
import okhttp3.Request; | |
import okhttp3.Response; | |
import okhttp3.WebSocket; | |
import okhttp3.mockwebserver.MockResponse; | |
import okhttp3.mockwebserver.MockWebServer; | |
import okhttp3.mockwebserver.RecordedRequest; | |
import org.assertj.core.api.Assertions; | |
import org.junit.Rule; | |
import org.junit.Test; | |
public final class CallTest extends Assertions { | |
@Rule | |
public final MockWebServer server = new MockWebServer(); | |
@Rule | |
public final MockWebServer server2 = new MockWebServer(); | |
private final Call.Factory client = new Http2Client(); | |
@Test | |
public void get() throws Exception { | |
this.server.enqueue(new MockResponse() | |
.setBody("abc") | |
.clearHeaders() | |
.addHeader("content-type: text/plain") | |
.addHeader("content-length", "3")); | |
final long sentAt = System.currentTimeMillis(); | |
final RecordedResponse recordedResponse = this.executeSynchronously("/", "User-Agent", "SyncApiTest"); | |
final long receivedAt = System.currentTimeMillis(); | |
recordedResponse.assertCode(200) | |
.assertSuccessful() | |
.assertHeaders(new Headers.Builder() | |
.add("content-type", "text/plain") | |
.add("content-length", "3") | |
.build()) | |
.assertBody("abc"); | |
// .assertSentRequestAtMillis(sentAt, receivedAt) | |
// .assertReceivedResponseAtMillis(sentAt, receivedAt); | |
final RecordedRequest recordedRequest = this.server.takeRequest(); | |
assertThat(recordedRequest.getMethod()).isEqualTo("GET"); | |
assertThat(recordedRequest.getHeader("User-Agent")).isEqualTo("SyncApiTest"); | |
assertThat(recordedRequest.getBody().size()).isEqualTo(0); | |
assertThat(recordedRequest.getHeader("Content-Length")).isEqualTo("0"); | |
} | |
private RecordedResponse executeSynchronously(final String path, final String... headers) throws IOException { | |
final Request.Builder builder = new Request.Builder(); | |
builder.url(this.server.url(path)); | |
for (int i = 0, size = headers.length; i < size; i += 2) { | |
builder.addHeader(headers[i], headers[i + 1]); | |
} | |
return this.executeSynchronously(builder.build()); | |
} | |
private RecordedResponse executeSynchronously(final Request request) throws IOException { | |
final Call call = this.client.newCall(request); | |
try { | |
final Response[] r = new Response[1]; | |
call.enqueue(new Callback() { | |
@Override | |
public void onFailure(final Call call, final IOException e) { | |
assertThat(true).isFalse(); | |
} | |
@Override | |
public void onResponse(final Call call, final Response response) throws IOException { | |
r[0] = response; | |
} | |
}); | |
try { | |
Thread.sleep(200); | |
} | |
catch (final InterruptedException e) { | |
e.printStackTrace(); | |
} | |
final String bodyString = r[0].body().string(); | |
return new RecordedResponse(request, r[0], null, bodyString, null); | |
} | |
catch (final IOException e) { | |
return new RecordedResponse(request, null, null, null, e); | |
} | |
} | |
public static final class RecordedResponse extends Assertions { | |
public final Request request; | |
public final Response response; | |
public final WebSocket webSocket; | |
public final String body; | |
public final IOException failure; | |
public RecordedResponse(final Request request, | |
final Response response, | |
final WebSocket webSocket, | |
final String body, | |
final IOException failure) { | |
this.request = request; | |
this.response = response; | |
this.webSocket = webSocket; | |
this.body = body; | |
this.failure = failure; | |
} | |
public RecordedResponse assertRequestUrl(final HttpUrl url) { | |
assertThat(this.request.url()).isEqualTo(url); | |
return this; | |
} | |
public RecordedResponse assertRequestMethod(final String method) { | |
assertThat(this.request.method()).isEqualTo(method); | |
return this; | |
} | |
public RecordedResponse assertRequestHeader(final String name, final String... values) { | |
assertThat(this.request.headers(name)).containsExactly(values); | |
return this; | |
} | |
public RecordedResponse assertCode(final int expectedCode) { | |
assertThat(this.response.code()).isEqualTo(expectedCode); | |
return this; | |
} | |
public RecordedResponse assertSuccessful() { | |
assertThat(this.response.isSuccessful()).isTrue(); | |
return this; | |
} | |
public RecordedResponse assertNotSuccessful() { | |
assertThat(this.response.isSuccessful()).isFalse(); | |
return this; | |
} | |
public RecordedResponse assertHeader(final String name, final String... values) { | |
assertThat(this.response.headers(name)).containsExactly(values); | |
return this; | |
} | |
public RecordedResponse assertHeaders(final Headers headers) { | |
assertThat(this.response.headers().toMultimap()).isEqualTo(headers.toMultimap()); | |
return this; | |
} | |
public RecordedResponse assertBody(final String expectedBody) { | |
assertThat(this.body).isEqualTo(expectedBody); | |
return this; | |
} | |
public RecordedResponse assertHandshake() { | |
final Handshake handshake = this.response.handshake(); | |
assertThat(handshake.tlsVersion()).isNotNull(); | |
assertThat(handshake.cipherSuite()).isNotNull(); | |
assertThat(handshake.peerPrincipal()).isNotNull(); | |
assertThat(handshake.peerCertificates()).hasSize(1); | |
assertThat(handshake.localPrincipal()).isNull(); | |
assertThat(handshake.localCertificates()).hasSize(0); | |
return this; | |
} | |
/** | |
* Asserts that the current response was redirected and returns the prior response. | |
*/ | |
public RecordedResponse priorResponse() { | |
final Response priorResponse = this.response.priorResponse(); | |
assertThat(priorResponse).isNotNull(); | |
assertThat(priorResponse.body()).isNull(); | |
return new RecordedResponse(priorResponse.request(), priorResponse, null, null, null); | |
} | |
/** | |
* Asserts that the current response used the network and returns the network response. | |
*/ | |
public RecordedResponse networkResponse() { | |
final Response networkResponse = this.response.networkResponse(); | |
assertThat(networkResponse).isNotNull(); | |
assertThat(networkResponse.body()).isNull(); | |
return new RecordedResponse(networkResponse.request(), networkResponse, null, null, null); | |
} | |
/** | |
* Asserts that the current response didn't use the network. | |
*/ | |
public RecordedResponse assertNoNetworkResponse() { | |
assertThat(this.response.networkResponse()).isNull(); | |
return this; | |
} | |
/** | |
* Asserts that the current response didn't use the cache. | |
*/ | |
public RecordedResponse assertNoCacheResponse() { | |
assertThat(this.response.cacheResponse()).isNull(); | |
return this; | |
} | |
/** | |
* Asserts that the current response used the cache and returns the cache response. | |
*/ | |
public RecordedResponse cacheResponse() { | |
final Response cacheResponse = this.response.cacheResponse(); | |
assertThat(cacheResponse).isNotNull(); | |
assertThat(cacheResponse.body()).isNull(); | |
return new RecordedResponse(cacheResponse.request(), cacheResponse, null, null, null); | |
} | |
public RecordedResponse assertFailure(final Class<?>... allowedExceptionTypes) { | |
boolean found = false; | |
for (final Class expectedClass : allowedExceptionTypes) { | |
if (expectedClass.isInstance(this.failure)) { | |
found = true; | |
break; | |
} | |
} | |
assertThat(found).as("Expected exception type among " + Arrays.toString(allowedExceptionTypes)).isTrue(); | |
return this; | |
} | |
public RecordedResponse assertFailure(final String... messages) { | |
assertThat(this.failure).isNotNull(); | |
assertThat(Arrays.asList(messages)).contains(this.failure.getMessage()); | |
return this; | |
} | |
public RecordedResponse assertFailureMatches(final String... patterns) { | |
assertThat(this.failure).isNotNull(); | |
for (final String pattern : patterns) { | |
if (this.failure.getMessage().matches(pattern)) { | |
return this; | |
} | |
} | |
throw new AssertionError(this.failure.getMessage()); | |
} | |
public RecordedResponse assertSentRequestAtMillis(final long minimum, final long maximum) { | |
this.assertDateInRange(minimum, this.response.sentRequestAtMillis(), maximum); | |
return this; | |
} | |
public RecordedResponse assertReceivedResponseAtMillis(final long minimum, final long maximum) { | |
this.assertDateInRange(minimum, this.response.receivedResponseAtMillis(), maximum); | |
return this; | |
} | |
private void assertDateInRange(final long minimum, final long actual, final long maximum) { | |
assertThat(actual).isGreaterThanOrEqualTo(minimum); | |
assertThat(actual).isLessThanOrEqualTo(maximum); | |
} | |
private String format(final long time) { | |
return new SimpleDateFormat("HH:mm:ss.SSS").format(new Date(time)); | |
} | |
public String getBody() { | |
return this.body; | |
} | |
} | |
} |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package cronus.core; | |
import java.io.IOException;
import java.io.InterruptedIOException;
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.TreeSet;
import java.util.concurrent.CancellationException;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionException;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Executor;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.TimeUnit;
import java.util.function.Function;
import java.util.stream.Collectors;
import okhttp3.Call;
import okhttp3.Callback;
import okhttp3.Headers;
import okhttp3.Protocol;
import okhttp3.Request;
import okhttp3.Response;
import okhttp3.ResponseBody;
import okio.AsyncTimeout;
import okio.Buffer;
import okio.Timeout;
import static java.net.http.HttpClient.Redirect.ALWAYS;
import static java.net.http.HttpClient.Version.HTTP_2;
public class Http2Client implements Call.Factory { | |
private final HttpClient client; | |
private final ExecutorService executorService; | |
public Http2Client(final HttpClient client, final ExecutorService executorService) { | |
this.client = Objects.requireNonNull(client, "HttpClient can not be null"); | |
this.executorService = Objects.requireNonNull(executorService, "ExecutorService can not be null"); | |
} | |
public Http2Client() { | |
this(HttpClient.newBuilder() | |
.followRedirects(ALWAYS) | |
.version(HTTP_2) | |
.build(), | |
ForkJoinPool.commonPool()); | |
} | |
@Override | |
public Call newCall(final Request request) { | |
return new Http2Call(this.client, request, this.executorService); | |
} | |
static class Http2Call implements Call { | |
private final HttpClient client; | |
private final Request originalRequest; | |
private final Executor executor; | |
private final CancellableFuture cancellableFuture; | |
private final AsyncTimeout timeout; | |
private boolean executed; | |
public Http2Call(final HttpClient client, | |
final Request originalRequest, | |
final Executor executor) { | |
this.client = client; | |
this.originalRequest = originalRequest; | |
this.executor = executor; | |
this.cancellableFuture = new CancellableFuture(this); | |
this.timeout = new AsyncTimeout() { | |
@Override | |
protected void timedOut() { | |
Http2Call.this.cancel(); | |
} | |
}; | |
client.connectTimeout() | |
.ifPresent(duration -> this.timeout.timeout(duration.toMillis(), TimeUnit.MILLISECONDS)); | |
} | |
@Override | |
public Request request() { | |
return this.originalRequest; | |
} | |
@Override | |
public Response execute() throws IOException { | |
synchronized (this) { | |
if (this.executed) { | |
throw new IllegalStateException("Already Executed"); | |
} | |
this.executed = true; | |
} | |
this.timeout.enter(); | |
final HttpResponse<byte[]> httpResponse; | |
try { | |
httpResponse = this.client.send(this.toRequest(this.originalRequest), | |
HttpResponse.BodyHandlers.ofByteArray()); | |
} | |
catch (final InterruptedException e) { | |
throw new RuntimeException("Interrupted", e); | |
} | |
catch (final IOException e) { | |
throw this.timeoutExit(e); | |
} | |
return this.fromResponse(httpResponse); | |
} | |
IOException timeoutExit(final IOException cause) { | |
if (!this.timeout.exit()) { | |
return cause; | |
} | |
final InterruptedIOException e = new InterruptedIOException("timeout"); | |
if (cause != null) { | |
e.initCause(cause); | |
} | |
return e; | |
} | |
@Override | |
public void enqueue(final Callback responseCallback) { | |
synchronized (this) { | |
if (this.executed) { | |
throw new IllegalStateException("Already Executed"); | |
} | |
this.executed = true; | |
} | |
final CompletableFuture<HttpResponse<byte[]>> httpResponse = | |
this.client.sendAsync(this.toRequest(this.originalRequest), | |
HttpResponse.BodyHandlers.ofByteArray()); | |
httpResponse.handleAsync((response, throwable) -> { | |
if (throwable != null) { | |
responseCallback.onFailure(this, new IOException(throwable)); | |
} | |
return this.fromResponse(response); | |
}).thenAcceptAsync(r -> { | |
try { | |
responseCallback.onResponse(this, r); | |
} | |
catch (final IOException e) { | |
responseCallback.onFailure(this, e); | |
} | |
}, this.executor); | |
} | |
@Override | |
public void cancel() { | |
this.cancellableFuture.cancel(true); | |
} | |
@Override | |
public boolean isExecuted() { | |
return this.executed; | |
} | |
@Override | |
public boolean isCanceled() { | |
return this.cancellableFuture.isCancelled(); | |
} | |
@Override | |
public Timeout timeout() { | |
return null; | |
} | |
@Override | |
public Call clone() { | |
return new Http2Call(this.client, this.originalRequest, this.executor); | |
} | |
private HttpRequest toRequest(final Request request) { | |
final HttpRequest.BodyPublisher body; | |
if (request.body() == null) { | |
body = HttpRequest.BodyPublishers.noBody(); | |
} | |
else { | |
body = HttpRequest.BodyPublishers.ofByteArray(request.body().toString().getBytes()); | |
} | |
final HttpRequest.Builder requestBuilder = HttpRequest.newBuilder() | |
.uri(request.url().uri()) | |
.version(HTTP_2); | |
final Map<String, Collection<String>> headers = this.filterRestrictedHeaders(request.headers().toMultimap()); | |
if (!headers.isEmpty()) { | |
requestBuilder.headers(this.asString(headers)); | |
} | |
switch (request.method()) { | |
case "GET": | |
return requestBuilder.GET().build(); | |
case "POST": | |
return requestBuilder.POST(body).build(); | |
case "PUT": | |
return requestBuilder.PUT(body).build(); | |
case "DELETE": | |
return requestBuilder.DELETE().build(); | |
default: | |
// fall back scenario, http implementations may restrict some methods | |
return requestBuilder.method(request.method(), body).build(); | |
} | |
} | |
private Response fromResponse(final HttpResponse<byte[]> httpResponse) { | |
final Response.Builder builder = new Response.Builder(); | |
final Map<String, String> h = new HashMap<>(); | |
httpResponse.headers() | |
.map() | |
.forEach((key, value) -> value.forEach(v -> h.put(key, v))); | |
return builder | |
.body(ResponseBody.create(null, httpResponse.body())) | |
.request(this.originalRequest) | |
.code(httpResponse.statusCode()) | |
.protocol(Protocol.HTTP_2) | |
.headers(Headers.of(h)) | |
// .sentRequestAtMillis(httpResponse.headers()) | |
.message(httpResponse.headers().firstValue("Reason-Phrase").orElse("OK")) | |
.build(); | |
} | |
private static final Set<String> DISALLOWED_HEADERS_SET; | |
static { | |
// A case insensitive TreeSet of strings. | |
final TreeSet<String> treeSet = new TreeSet<>(String.CASE_INSENSITIVE_ORDER); | |
treeSet.addAll(Set.of("connection", "content-length", "date", "expect", "from", "host", | |
"origin", "referer", "upgrade", "via", "warning")); | |
DISALLOWED_HEADERS_SET = Collections.unmodifiableSet(treeSet); | |
} | |
private Map<String, Collection<String>> filterRestrictedHeaders(final Map<String, List<String>> headers) { | |
final Map<String, Collection<String>> filteredHeaders = | |
headers.keySet() | |
.stream() | |
.filter(headerName -> !DISALLOWED_HEADERS_SET.contains( | |
headerName)) | |
.collect(Collectors.toMap( | |
Function.identity(), | |
headers::get)); | |
filteredHeaders.computeIfAbsent("Accept", key -> List.of("*/*")); | |
return filteredHeaders; | |
} | |
private String[] asString(final Map<String, Collection<String>> headers) { | |
return headers.entrySet() | |
.stream() | |
.flatMap(entry -> entry.getValue() | |
.stream() | |
.map(value -> Arrays.asList(entry.getKey(), value)) | |
.flatMap(List::stream)).toArray(String[]::new); | |
} | |
} | |
static class CancellableFuture extends CompletableFuture<Response> { | |
private final Call call; | |
CancellableFuture(final Call call) { | |
this.call = call; | |
} | |
@Override | |
public boolean cancel(final boolean mayInterruptIfRunning) { | |
if (mayInterruptIfRunning && !this.isDone()) { | |
this.call.cancel(); | |
} | |
return super.cancel(mayInterruptIfRunning); | |
} | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment