Skip to content

Instantly share code, notes, and snippets.

@xhanin
Created March 10, 2025 16:04
Show Gist options
  • Save xhanin/644b0ead55d66ed74370aba836fb239a to your computer and use it in GitHub Desktop.
Save xhanin/644b0ead55d66ed74370aba836fb239a to your computer and use it in GitHub Desktop.
package playground
import java.net.InetSocketAddress
import java.nio.ByteBuffer
import java.nio.channels.AsynchronousChannelGroup
import java.nio.channels.AsynchronousServerSocketChannel
import java.nio.channels.AsynchronousSocketChannel
import java.nio.channels.CompletionHandler
import java.nio.charset.StandardCharsets.UTF_8
import java.time.Clock
import java.util.concurrent.Executors
import java.util.concurrent.ThreadFactory
import java.util.concurrent.TimeUnit
import java.util.concurrent.atomic.AtomicInteger
import java.util.function.Consumer
/**
 * Shared scheduler used by [async] to run delayed tasks.
 *
 * Backed by a scheduled pool with a core size of 2 whose threads are named
 * with the "ASYNC-" prefix, so they are easy to spot in the log output.
 */
val executor = Executors.newScheduledThreadPool(
    2,
    NamedThreadFactory("ASYNC-")
)
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// async socket server
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// open and bind an async server
// to test it:
// 1) launch this server
// 2) use telnet to connect
/*
$ telnet localhost 8888
Trying ::1...
Connected to localhost.
Escape character is '^]'.
hello
server responds: hello
Connection closed by foreign host.
*/
/**
 * Asynchronous socket server entry point.
 *
 * Opens a server channel on a dedicated channel group (fixed pool of 2 threads
 * named "NIO-<n>") and hands the first accept over to a [SocketAcceptor].
 */
class AsyncSocketServer {
    /**
     * Binds the server to [port] and starts accepting connections asynchronously.
     *
     * This call returns as soon as the first accept has been registered; the
     * connections themselves are handled through completion handlers.
     *
     * @param port TCP port to listen on
     * @throws java.io.IOException if the channel cannot be opened or bound
     *   (for instance when the port is already in use)
     */
    fun start(port: Int) {
        log("Starting server on port $port")
        val group = AsynchronousChannelGroup.withFixedThreadPool(2, NamedThreadFactory("NIO-"))
        val srv = try {
            AsynchronousServerSocketChannel.open(group).bind(InetSocketAddress(port))
        } catch (e: Exception) {
            // Opening or binding failed: shut the group down so that its pool and
            // any channel already opened in it are released instead of being
            // left behind with nobody holding a reference to them.
            try {
                group.shutdownNow()
            } catch (suppressed: java.io.IOException) {
                e.addSuppressed(suppressed)
            }
            throw e
        }
        srv.accept(null, SocketAcceptor(srv))
    }
}
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// socket acceptor
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// accepts incoming connections
// and delegates handling to a connection handler
/**
 * Completion handler for `accept` operations on [srv].
 *
 * Every accepted connection gets a sequential number, the next accept is
 * registered right away, and the connection is then passed to a
 * [ConnectionHandler].
 */
class SocketAcceptor(private val srv: AsynchronousServerSocketChannel) :
    CompletionHandler<AsynchronousSocketChannel, Any?> {

    /** Number of connections accepted so far. */
    val counter = AtomicInteger(0)

    /** Invoked when an incoming connection has been accepted. */
    override fun completed(result: AsynchronousSocketChannel, attachment: Any?) {
        val id = counter.incrementAndGet()
        log("accepted: $result - $id")

        // Register for the next connection before dealing with this one.
        srv.accept(null, this)

        val handler = ConnectionHandler(id, result)
        handler.handle()
    }

    /** Invoked when an accept fails; the error is only logged. */
    override fun failed(exc: Throwable?, attachment: Any?) = log("error: $exc")
}
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// connection handler
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// handles one connection on a channel
// it basically echoes what the client writes, and closes the connection
/**
 * Handles one client connection: reads a single message, computes a reply,
 * writes it back as UTF-8 and closes the channel.
 *
 * @property connection sequential number of the connection, used in log output
 * @property channel the accepted client channel
 */
class ConnectionHandler(val connection: Int, val channel: AsynchronousSocketChannel) {
    /** Receives the bytes read from the client (at most 1024 per read). */
    val buffer = ByteBuffer.allocateDirect(1024)

    /** Reusable output buffer, used whenever the encoded reply fits into it. */
    val wByteBuffer = ByteBuffer.allocateDirect(1024)

    /** Starts an asynchronous read on the channel; [onRead] is called on completion. */
    fun handle() {
        log("reading from channel ${connection}")
        buffer.clear()
        channel.read(buffer, 1, TimeUnit.MINUTES) {
            onRead(it)
        }
        log("after read call on ${connection}")
    }

    /**
     * Processes the outcome of the read started by [handle].
     *
     * @param result number of bytes read, or a negative value when the client
     *   closed the connection without sending anything
     */
    fun onRead(result: Int?) {
        log("read from channel ${connection} - $result bytes")
        if (result == null || result < 0) {
            // End of stream: there is nothing to answer, just release the socket.
            log("end of stream on channel ${connection}, closing")
            channel.close()
            return
        }
        val txt = UTF_8.decode(buffer.flip()).toString()
        log("server received: " + txt)
        val resp: String = doHeavyWorkAndMydatabase(txt)
        async(1, TimeUnit.SECONDS) {
            val out = encodeResponse(resp)
            log("writing to channel ${connection}")
            writeFully(out, 0)
        }
    }

    /**
     * Encodes [resp] as UTF-8 into a buffer that is ready to be written
     * (position 0, limit = encoded length).
     *
     * The preallocated [wByteBuffer] is reused when the reply fits; otherwise the
     * freshly encoded buffer is returned as is, so a long reply cannot overflow.
     */
    private fun encodeResponse(resp: String): ByteBuffer {
        val encoded = UTF_8.encode(resp)
        if (encoded.remaining() > wByteBuffer.capacity()) {
            return encoded
        }
        wByteBuffer.clear()
        wByteBuffer.put(encoded)
        wByteBuffer.flip()
        return wByteBuffer
    }

    /**
     * Writes [out] until it has no bytes remaining, then reports the total to [onWrite].
     *
     * A single write may transfer only part of the buffer, hence the re-issue.
     *
     * @param written number of bytes already written by previous rounds
     */
    private fun writeFully(out: ByteBuffer, written: Int) {
        channel.write(out, 1, TimeUnit.MINUTES) { count ->
            val total = written + count
            if (out.hasRemaining()) {
                writeFully(out, total)
            } else {
                onWrite(total)
            }
        }
    }

    /**
     * Simulated slow, blocking business logic.
     *
     * Sleeps for 30 seconds on the calling thread when [txt] starts with "hello".
     *
     * @return the reply to send back to the client
     */
    private fun doHeavyWorkAndMydatabase(txt: String): String {
        log("heavy work")
        if (txt.startsWith("hello")) {
            Thread.sleep(30000)
        }
        log("did heavy work")
        return "server responds: $txt"
    }

    /**
     * Called once the reply has been written; schedules the channel close.
     *
     * @param result number of bytes written
     */
    fun onWrite(result: Int?) {
        log("wrote to channel ${connection} - $result bytes")
        async {
            log("closing channel ${connection}")
            channel.close()
        }
    }
}
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// helpers
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
/**
 * Prints [msg] to standard output, prefixed with the current instant and the
 * name of the calling thread.
 */
fun log(msg: String) {
    val timestamp = Clock.systemDefaultZone().instant()
    println("$timestamp - ${Thread.currentThread().name} - $msg")
}
/**
 * Schedules [block] on the shared [executor] after the given delay.
 *
 * A failure inside [block] is logged before being rethrown: the future returned
 * by the scheduler is discarded here, so without the log the exception would
 * never be seen by anyone.
 *
 * @param time delay before [block] runs, expressed in [timeUnit]
 * @param timeUnit unit of [time]
 * @param block the task to run
 */
fun async(time: Int = 1, timeUnit: TimeUnit = TimeUnit.SECONDS, block: Runnable) {
    // Explicit Runnable so the call cannot be mistaken for the Callable overload.
    val guarded = Runnable {
        try {
            block.run()
        } catch (e: Exception) {
            log("error in async task: $e")
            throw e
        }
    }
    executor.schedule(guarded, time.toLong(), timeUnit)
}
/**
 * Callback-style read: reads into [dst] and passes the number of bytes read
 * (-1 at end of stream) to [block].
 *
 * On failure, including a timeout, the error is logged and the channel is
 * closed. [block] is not invoked in that case, so nothing else would get the
 * chance to release the socket.
 *
 * @param dst buffer receiving the bytes
 * @param timeout maximum time to wait for the read to complete
 * @param unit unit of [timeout]
 * @param block invoked with the read result on success
 */
fun AsynchronousSocketChannel.read(
    dst: ByteBuffer,
    timeout: Long,
    unit: TimeUnit,
    block: Consumer<Int>
) = read(dst, timeout, unit, null, object : CompletionHandler<Int, Any?> {
    override fun completed(result: Int, attachment: Any?) {
        block.accept(result)
    }

    override fun failed(exc: Throwable?, attachment: Any?) {
        log("error reading from socket: $exc")
        try {
            this@read.close()
        } catch (e: java.io.IOException) {
            log("error closing socket after failed read: $e")
        }
    }
})
/**
 * Callback-style write: writes from [src] and passes the number of bytes
 * written to [block]. A single call may write only part of [src]; callers that
 * need everything sent should check `src.hasRemaining()` in [block].
 *
 * On failure, including a timeout, the error is logged and the channel is
 * closed. [block] is not invoked in that case, so nothing else would get the
 * chance to release the socket.
 *
 * @param src buffer holding the bytes to send
 * @param timeout maximum time to wait for the write to complete
 * @param unit unit of [timeout]
 * @param block invoked with the write result on success
 */
fun AsynchronousSocketChannel.write(
    src: ByteBuffer,
    timeout: Long,
    unit: TimeUnit,
    block: Consumer<Int>
) = write(src, timeout, unit, null, object : CompletionHandler<Int, Any?> {
    override fun completed(result: Int, attachment: Any?) {
        block.accept(result)
    }

    override fun failed(exc: Throwable?, attachment: Any?) {
        log("error writing to socket: $exc")
        try {
            this@write.close()
        } catch (e: java.io.IOException) {
            log("error closing socket after failed write: $e")
        }
    }
})
/**
 * Thread factory producing threads named `<prefix><n>`, where `n` starts at 1
 * and grows by one for every thread created by this factory.
 */
class NamedThreadFactory(val prefix: String) : ThreadFactory {
    /** Number of threads created so far. */
    val c = AtomicInteger()

    override fun newThread(r: Runnable): Thread {
        val sequence = c.incrementAndGet()
        return Thread(r, prefix + sequence)
    }
}
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// main
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
/** Program entry point: starts the asynchronous server on port 8888. */
fun main() {
    val server = AsyncSocketServer()
    server.start(port = 8888)
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment