Last active
May 6, 2023 02:04
-
-
Save harry248/69c29348fa22f1c3972f7bee95026f70 to your computer and use it in GitHub Desktop.
Custom Http Cache plugin for Ktor Client
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// Licensed under the Apache License, Version 2.0 (the "License") | |
package dev.haraldhalbig | |
import io.ktor.client.* | |
import io.ktor.client.call.* | |
import io.ktor.client.plugins.* | |
import io.ktor.client.request.* | |
import io.ktor.client.statement.* | |
import io.ktor.client.utils.* | |
import io.ktor.http.* | |
import io.ktor.util.* | |
import io.ktor.util.date.* | |
import io.ktor.utils.io.* | |
import kotlinx.datetime.Clock | |
/**
 * Ktor client plugin that caches `GET` responses through user-supplied storage callbacks.
 *
 * Behaviour by request:
 * - Protocols other than `http`/`https`, and methods other than GET/POST/PUT/PATCH/DELETE, pass through untouched.
 * - A successful POST/PUT/PATCH/DELETE deletes the stored entry for the same URL.
 * - A GET is answered from the stored entry while it is fresh. Otherwise it is revalidated with
 *   `If-None-Match` (when an ETag was stored), or fetched and stored according to the response
 *   `Cache-Control` header.
 *
 * @property findCacheEntry returns the stored entry for a request URL, or `null`.
 * @property storeCacheEntry creates or replaces the stored entry for a request URL.
 * @property deleteCacheEntry removes the stored entry for a request URL.
 */
internal class CachePlugin(
    val findCacheEntry: suspend (Url) -> CacheEntry?,
    val storeCacheEntry: suspend (Url, CacheEntry) -> Unit,
    val deleteCacheEntry: suspend (Url) -> Unit,
) {
    /**
     * A stored response.
     *
     * @property key string form of the request URL the entry was stored for.
     * @property eTag the response ETag, or `""` when the server sent none.
     * @property maxAge freshness lifetime in milliseconds (`max-age` seconds × 1000).
     * @property requestTime timestamp of the request that produced or last revalidated the entry
     *   (compared against `Clock.System.now().toEpochMilliseconds()`).
     * @property noCache whether the response carried `no-cache`, forcing revalidation on every use.
     * @property headers response headers, keeping only the first value of each name.
     * @property content response body read as text.
     */
    @kotlinx.serialization.Serializable
    data class CacheEntry(
        val key: String,
        val eTag: String,
        val maxAge: Long,
        val requestTime: Long,
        val noCache: Boolean,
        val headers: Map<String, String>,
        val content: String,
    )

    /** Plugin configuration. With the defaults nothing is ever found or stored. */
    class Config {
        /** Returns the stored entry for a URL, or `null` when nothing is cached. */
        var findCacheEntry: suspend (Url) -> CacheEntry? = { null }

        /** Persists the entry for a URL. */
        var storeCacheEntry: suspend (Url, CacheEntry) -> Unit = { _, _ -> }

        /** Removes the entry for a URL. */
        var deleteCacheEntry: suspend (Url) -> Unit = {}
    }

    companion object Plugin : HttpClientPlugin<Config, CachePlugin> {
        override val key: AttributeKey<CachePlugin> = AttributeKey("cache-plugin")

        override fun prepare(block: Config.() -> Unit): CachePlugin {
            val config = Config().apply(block)
            return CachePlugin(
                findCacheEntry = config.findCacheEntry,
                storeCacheEntry = config.storeCacheEntry,
                deleteCacheEntry = config.deleteCacheEntry,
            )
        }

        @OptIn(InternalAPI::class)
        override fun install(plugin: CachePlugin, scope: HttpClient) {
            scope.plugin(HttpSend).intercept { originalRequest ->
                // Work on a copy so the header we may add below does not leak into the caller's builder.
                val request = originalRequest.withNewExecutionContext()
                if (request.url.protocol != URLProtocol.HTTP && request.url.protocol != URLProtocol.HTTPS) {
                    return@intercept execute(request)
                }
                val url = request.url.build()
                val cacheEntry = plugin.findCacheEntry(url)

                when (request.method) {
                    HttpMethod.Post,
                    HttpMethod.Put,
                    HttpMethod.Patch,
                    HttpMethod.Delete -> {
                        // A successful unsafe request invalidates whatever is stored for this URL.
                        val call = execute(request)
                        if (cacheEntry != null && call.response.status.isSuccess()) {
                            plugin.deleteCacheEntry(url)
                        }
                        return@intercept call
                    }
                    HttpMethod.Get -> Unit
                    else -> return@intercept execute(request)
                }

                val requestCacheConfig = request.headers.build().cacheConfig()
                // A request `no-store` forbids storing the response to this request (RFC 9111 §5.2.1.5).
                val mayStore = requestCacheConfig?.noStore != true

                // `only-if-cached`: answer from the cache or with 504, never from the network.
                if (requestCacheConfig?.onlyIfCached == true) {
                    return@intercept if (cacheEntry == null) {
                        createHttpClientCall(scope, request, HttpStatusCode.GatewayTimeout, Headers.Empty, "")
                    } else {
                        createHttpClientCall(
                            scope, request, HttpStatusCode.OK, cacheEntry.createHeaders(), cacheEntry.content,
                        )
                    }
                }

                if (cacheEntry != null) {
                    // Fresh entry and the request does not demand revalidation: serve it directly.
                    if (!cacheEntry.shouldValidate() && requestCacheConfig?.noCache != true) {
                        return@intercept createHttpClientCall(
                            scope, request, HttpStatusCode.OK, cacheEntry.createHeaders(), cacheEntry.content,
                        )
                    }

                    // Conditional request; an empty validator would be meaningless, so only send a real ETag.
                    if (cacheEntry.eTag.isNotBlank()) {
                        request.headers.set(HttpHeaders.IfNoneMatch, cacheEntry.eTag)
                    }
                    val call = execute(request).save()
                    val responseCacheConfig = call.response.headers.cacheConfig()

                    // 304: the stored body is still valid. The caller never issued a conditional request,
                    // so it must always receive the stored response, even when the 304 has no Cache-Control.
                    // Freshness data absent from the 304 is kept from the stored entry (RFC 9111 §4.3.4).
                    if (call.response.status == HttpStatusCode.NotModified) {
                        val updatedCacheEntry = cacheEntry.copy(
                            maxAge = responseCacheConfig?.let { it.maxAge * 1000L } ?: cacheEntry.maxAge,
                            requestTime = call.response.requestTime.timestamp,
                            noCache = responseCacheConfig?.noCache ?: cacheEntry.noCache,
                        )
                        if (responseCacheConfig?.noStore == true) {
                            plugin.deleteCacheEntry(url)
                        } else if (mayStore) {
                            plugin.storeCacheEntry(url, updatedCacheEntry)
                        }
                        return@intercept createHttpClientCall(
                            client = scope,
                            origin = request.withNewExecutionContext(),
                            statusCode = HttpStatusCode.OK,
                            headers = updatedCacheEntry.createHeaders(),
                            body = updatedCacheEntry.content,
                        )
                    }

                    // No Cache-Control, or `no-store`: the stored entry must go.
                    if (responseCacheConfig == null || responseCacheConfig.noStore) {
                        plugin.deleteCacheEntry(url)
                        return@intercept call
                    }
                    // A 2xx replaces the stored entry; other statuses leave it untouched.
                    if (call.response.status.isSuccess() && mayStore) {
                        plugin.storeCacheEntry(url, newCacheEntry(url, call, responseCacheConfig))
                    }
                    return@intercept call
                }

                // Nothing stored yet: fetch, then store when the response allows it.
                val call = execute(request).save()
                val responseCacheConfig = call.response.headers.cacheConfig()
                if (responseCacheConfig == null) {
                    plugin.deleteCacheEntry(url)
                } else if (mayStore && call.response.status.isSuccess() && !responseCacheConfig.noStore) {
                    plugin.storeCacheEntry(url, newCacheEntry(url, call, responseCacheConfig))
                }
                call
            }
        }

        /**
         * Builds a new [CacheEntry] from a saved [call].
         *
         * Reads the whole response body as text.
         */
        private suspend fun newCacheEntry(url: Url, call: HttpClientCall, config: CacheConfig): CacheEntry =
            CacheEntry(
                key = url.toString(),
                eTag = config.eTag,
                maxAge = config.maxAge * 1000L,
                requestTime = call.response.requestTime.timestamp,
                noCache = config.noCache,
                headers = call.response.headers.toMap(),
                content = call.response.bodyAsText(),
            )

        /**
         * Builds a synthetic [HttpClientCall] for [origin] without performing any network I/O.
         *
         * The request side mirrors [origin] with an empty body. The response carries [statusCode],
         * [headers] and [body], is timestamped with the current time, and reports HTTP/1.1.
         */
        @OptIn(InternalAPI::class)
        private fun createHttpClientCall(
            client: HttpClient,
            origin: HttpRequestBuilder,
            statusCode: HttpStatusCode,
            headers: Headers,
            body: String,
        ): HttpClientCall {
            return HttpClientCall(
                client = client,
                requestData = HttpRequestData(
                    url = origin.url.build(),
                    method = origin.method,
                    headers = origin.headers.build(),
                    body = EmptyContent,
                    executionContext = origin.executionContext,
                    attributes = origin.attributes,
                ),
                responseData = HttpResponseData(
                    statusCode = statusCode,
                    requestTime = GMTDate(),
                    headers = headers,
                    version = HttpProtocolVersion.HTTP_1_1,
                    body = ByteReadChannel(body),
                    callContext = origin.executionContext,
                ),
            )
        }
    }
}
/**
 * Caching-related facts extracted from a message's headers.
 *
 * @property noCache `no-cache` directive present.
 * @property noStore `no-store` directive present.
 * @property onlyIfCached `only-if-cached` directive present.
 * @property maxAge `max-age` value in seconds.
 * @property eTag value of the `ETag` header, or `""` when absent.
 */
private data class CacheConfig(
    val noCache: Boolean,
    val noStore: Boolean,
    val onlyIfCached: Boolean,
    val maxAge: Int,
    val eTag: String,
)
/**
 * Parses the `Cache-Control` header (plus `ETag`) of these headers into a [CacheConfig].
 *
 * Parsing rules:
 * - Directive names are matched case-insensitively, and blank directives are ignored.
 * - A value is everything after the first `=`, trimmed and with surrounding double quotes removed.
 * - An absent, empty, or non-numeric `max-age` becomes `0` (treated as stale).
 * - A numeric `max-age` that overflows [Int] is clamped to [Int.MAX_VALUE] (RFC 9111 §1.2.2).
 *
 * @return the parsed configuration, or `null` when no `Cache-Control` header is present.
 */
private fun Headers.cacheConfig(): CacheConfig? {
    val rawCacheControl = get(HttpHeaders.CacheControl) ?: return null
    val directives: Map<String, String?> = rawCacheControl
        .split(",")
        .map { it.trim() }
        .filter { it.isNotEmpty() }
        .associate { directive ->
            val name = directive.substringBefore('=').trim().lowercase()
            // Split on the first '=' only, so values that themselves contain '=' stay intact.
            val value = if ('=' in directive) {
                directive.substringAfter('=').trim().removeSurrounding("\"")
            } else {
                null
            }
            name to value
        }

    val maxAge = directives["max-age"]?.let { raw ->
        when {
            // Invalid delta-seconds: previously this threw NumberFormatException; treat as stale instead.
            raw.isEmpty() || !raw.all { it in '0'..'9' } -> 0
            // All digits but too large for Int: clamp rather than fail.
            else -> raw.toIntOrNull() ?: Int.MAX_VALUE
        }
    } ?: 0

    return CacheConfig(
        noCache = "no-cache" in directives,
        noStore = "no-store" in directives,
        onlyIfCached = "only-if-cached" in directives,
        maxAge = maxAge,
        eTag = get(HttpHeaders.ETag) ?: "",
    )
}
/**
 * Flattens these headers into a name → value map, keeping the header iteration order.
 *
 * Only the first value of each header name is kept, and names with no values are left out.
 */
private fun Headers.toMap(): Map<String, String> {
    val firstValues = LinkedHashMap<String, String>()
    for ((name, values) in entries()) {
        values.firstOrNull()?.let { first -> firstValues[name] = first }
    }
    return firstValues
}
/**
 * Returns a new [HttpRequestBuilder] populated from this one via `takeFrom`.
 *
 * NOTE(review): as the name says, the intent is that the copy gets its own execution context rather than
 * sharing this builder's. This relies on `takeFrom` not copying `executionContext`; verify against the Ktor
 * version in use.
 */
private fun HttpRequestBuilder.withNewExecutionContext(): HttpRequestBuilder =
    HttpRequestBuilder().takeFrom(this)
/** Rebuilds Ktor [Headers] from this entry's stored header map, appending one value per name. */
private fun CachePlugin.CacheEntry.createHeaders(): Headers = Headers.build {
    for ((name, value) in headers) {
        append(name, value)
    }
}
/**
 * Whether this entry has expired.
 *
 * True once the current epoch-millisecond time is strictly past [CachePlugin.CacheEntry.requestTime]
 * plus [CachePlugin.CacheEntry.maxAge].
 */
private fun CachePlugin.CacheEntry.isExpired(): Boolean {
    val expiresAt = requestTime + maxAge
    return expiresAt < Clock.System.now().toEpochMilliseconds()
}
/**
 * Whether this entry must be revalidated before it can be served.
 *
 * That is the case when it was stored with `no-cache` or has expired.
 */
private fun CachePlugin.CacheEntry.shouldValidate(): Boolean =
    if (noCache) true else isExpired()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
This could work really well as a starting point for my purposes. Could you please state which license you're releasing it under?