Created July 9, 2023 10:49
Save th-yoo/603e8e85c5ae58a3d2f7127e33bedd32 to your computer and use it in GitHub Desktop.
Bokeh with FastAPI
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from __future__ import annotations | |
from pprint import pprint | |
from var_dump import var_dump | |
from fastapi import FastAPI, WebSocket, Request, HTTPException, WebSocketDisconnect | |
from fastapi.responses import HTMLResponse, PlainTextResponse | |
from fastapi.staticfiles import StaticFiles | |
app = FastAPI()

# BokehJS static assets are served from a local build directory, resolved
# relative to the current working directory.  An alternative location that
# was used before is 'bokeh/server/static'.
app.mount(
    '/static',
    StaticFiles(directory='../bokeh/bokehjs/build'),
    name='static',
)
from bokeh.settings import settings | |
from bokeh.core.types import ID | |
from bokeh.util.token import ( | |
check_token_signature, | |
generate_jwt_token, | |
generate_session_id, | |
get_session_id, | |
get_token_payload | |
) | |
from bokeh.application.application import Application, SessionContext | |
from bokeh.document import Document | |
import weakref | |
class BokashiSessionContext(SessionContext):
    """Session context wrapping a single Bokeh ``Document`` served via FastAPI.

    No Bokeh ``ServerContext`` backs this context (``None`` is handed to the
    base class), server-session tracking is a no-op, and the session is never
    reported as destroyed.
    """

    # Session token for the browser; ``None`` until the creator assigns it.
    _token: str | None

    def __init__(
        self,
        session_id: ID,
        document: Document,
        logout_url: str | None = None,
    ) -> None:
        super().__init__(None, session_id)
        self._doc = document
        self._logout_url = logout_url
        self._request = None
        self._token = None

    def _set_session(self, session: ServerSession) -> None:
        """Ignore *session*: this context does not keep a server session."""

    @property
    def destroyed(self) -> bool:
        """Always ``False``.

        TODO: reflect the state of a server session once one is tracked.
        """
        return False

    async def with_locked_document(
        self, func: Callable[[Document], Awaitable[None]]
    ) -> None:
        """Await ``func`` with the wrapped document.

        NOTE: despite the name, no lock is acquired here.
        """
        document = self._doc
        await func(document)
class HTTPError(Exception):
    """Framework-neutral error carrying an HTTP status code and a message.

    Attributes:
        status: HTTP status code to report (e.g. 403).
        message: human-readable detail; empty by default.
    """

    def __init__(self, status: int, message: str = ''):
        # Forward both values to Exception so ``args``/``str()`` are
        # informative and the exception survives pickling, which rebuilds it
        # as ``HTTPError(*args)``.
        super().__init__(status, message)
        self.status = status
        self.message = message
# All header / cookie / query-string keys are expected to be lower-cased.
def get_session(headers: dict[str, str], cookies: dict[str, str], qs: dict[str, str]) -> SessionContext:
    """Resolve or create a Bokeh session for an incoming HTTP request.

    The session ID comes from the ``bokeh-session-id`` query argument or
    header, or is taken from a ``bokeh-token`` query argument. If none of
    these is present, a new session ID is generated. When no token was
    supplied, one is generated that embeds the request's headers, cookies
    and arguments.

    Args:
        headers: request headers with lower-cased keys; not modified.
        cookies: request cookies.
        qs: query-string arguments.

    Returns:
        A new ``BokashiSessionContext`` wrapping a fresh ``Document``, with
        its ``_token`` set.

    Raises:
        HTTPError: 403 in any of these cases:
            - the session ID is given both as an argument and as a header;
            - both a token and a session ID are given;
            - the token signature does not verify.
    """
    token = qs.get('bokeh-token')
    session_id: ID | None = qs.get('bokeh-session-id')

    if 'bokeh-session-id' in headers:
        if session_id:
            raise HTTPError(403, 'session ID was provided as an argument and header')
        session_id = headers.get('bokeh-session-id')

    if token:
        if session_id:
            raise HTTPError(403, 'Both token and session ID were provided')
        session_id = get_session_id(token)
    elif not session_id:
        session_id = generate_session_id(
            settings.secret_key_bytes(),
            settings.sign_sessions(),
        )

    if not token:
        # Drop the raw Cookie header when cookies are carried separately.
        # Work on a copy so the caller's dict is left untouched.
        token_headers = dict(headers)
        if cookies and 'cookie' in token_headers:
            del token_headers['cookie']
        payload = {'headers': token_headers, 'cookies': cookies, 'arguments': qs}
        token = generate_jwt_token(
            session_id,
            secret_key=settings.secret_key_bytes(),
            signed=settings.sign_sessions(),
            expiration=300,
            extra_payload=payload,
        )

    if not check_token_signature(
        token,
        secret_key=settings.secret_key_bytes(),
        signed=settings.sign_sessions(),
    ):
        raise HTTPError(403, 'Invalid token or session ID')

    doc = Document()
    session_ctx = BokashiSessionContext(session_id, doc)
    session_ctx._token = token
    # The document only holds a weak reference, so a strong reference to the
    # context must be kept elsewhere for it to stay alive.
    doc._session_context = weakref.ref(session_ctx)
    return session_ctx
from main import bkapp
from bokeh.application.handlers.function import FunctionHandler

# ``bkapp`` first names the user's document-building function from ``main``.
# It is then rebound to the Bokeh ``Application`` that runs that function
# through a ``FunctionHandler``.
_bkapp_handler = FunctionHandler(bkapp)
bkapp = Application(_bkapp_handler)
def html_page_for_session(ctx: SessionContext, root_url: str):
    """Render the HTML page that boots BokehJS for ``ctx``'s document.

    Args:
        ctx: session context exposing ``_doc`` and ``_token``.
        root_url: passed to ``Resources(mode='server', ...)``; presumably the
            base URL from which BokehJS assets are fetched.

    Returns:
        Whatever ``html_page_for_render_items`` produces for the document's
        roots, title, template and template variables.
    """
    from bokeh.embed.bundle import bundle_for_objs_and_resources
    from bokeh.embed.elements import html_page_for_render_items
    from bokeh.embed.util import RenderItem
    from bokeh.resources import Resources

    doc = ctx._doc
    item = RenderItem(token=ctx._token, roots=doc.roots, use_for_title=True)
    server_resources = Resources(mode='server', root_url=root_url)
    bundle = bundle_for_objs_and_resources(None, server_resources)
    variables = doc.template_variables or {}
    return html_page_for_render_items(
        bundle,
        {},
        [item],
        doc.title,
        template=doc.template,
        template_variables=variables,
    )
# Context of the most recently served page.  The websocket endpoint reads
# its document, so the latest GET wins.
session_ctx: SessionContext | None = None


@app.get("/")
async def get(req: Request):
    """Create a Bokeh session, run the app on its document and return the page.

    An ``HTTPError`` raised while building the session is re-raised as a
    FastAPI ``HTTPException`` with the same status and message.
    """
    global session_ctx
    url = req.url
    try:
        request_parts = (req.headers, req.cookies, req.query_params)
        session_ctx = get_session(*(dict(part) for part in request_parts))
        await bkapp.on_session_created(session_ctx)
        bkapp.initialize_document(session_ctx._doc)
        # FIXME: derive the public root URL properly (proxies, mount paths).
        root_url = f'{url.scheme}://{url.netloc}'
        page = html_page_for_session(session_ctx, root_url)
    except HTTPError as err:
        raise HTTPException(status_code=err.status, detail=err.message)
    return HTMLResponse(page)
import asyncio
import inspect


class TornadoWSAdapter:
    """Expose a Starlette ``WebSocket`` through a Tornado-style method set.

    ``write_message`` and ``close`` are synchronous.  They schedule the
    underlying coroutine as a task on the running event loop instead of
    awaiting it, presumably matching what the Bokeh connection wrapper
    expects — verify against bokeh.
    """

    def __init__(self, ws: WebSocket):
        self._sock = ws
        # Strong references to in-flight send/close tasks.  The event loop
        # keeps only weak references, so an unreferenced task may be
        # garbage-collected before it finishes.
        self._pending: set[asyncio.Task[None]] = set()

    def _spawn(self, coro) -> None:
        """Schedule *coro* on the running loop and keep it alive until done."""
        task = asyncio.create_task(coro)
        self._pending.add(task)
        task.add_done_callback(self._pending.discard)

    def write_message(self, msg: str | bytes, binary: bool = False) -> None:
        """Schedule sending *msg*.

        A ``str`` is sent as a text frame.  ``bytes``, ``bytearray`` and
        ``memoryview`` are sent as a binary frame.  ``binary`` is accepted
        for API compatibility only; the frame type follows the type of *msg*.

        Raises:
            TypeError: if *msg* is not text or bytes-like.
        """
        if isinstance(msg, str):
            self._spawn(self._sock.send_text(msg))
        elif isinstance(msg, (bytes, bytearray, memoryview)):
            self._spawn(self._sock.send_bytes(bytes(msg)))
        else:
            raise TypeError(f'cannot send {type(msg).__name__!r} over a websocket')

    async def read_message(
        self, callback: Callable[..., None] | None = None
    ) -> None | str | bytes:
        """Receive one message and return its text or bytes payload.

        Returns ``None`` when the message carries neither, such as a
        ``websocket.disconnect`` event.  If *callback* is given, it is called
        with the payload; an awaitable result is awaited.
        """
        message = await self._sock.receive()
        # Test against None: some servers include both keys, one set to None.
        text = message.get('text')
        payload = text if text is not None else message.get('bytes')
        if callback is not None:
            result = callback(payload)
            if inspect.isawaitable(result):
                await result
        return payload

    def close(self, *args) -> None:
        """Schedule closing the socket; *args* (e.g. code, reason) are forwarded."""
        self._spawn(self._sock.close(*args))
class TornadoLockAdapter:
    """``asyncio.Lock`` wrapped in a Tornado-style lock API.

    Intended use (presumably how the Bokeh connection wrapper uses it):
    ``with await lock.acquire(): ...``.  Awaiting ``acquire`` takes the
    lock and returns the adapter; leaving the ``with`` block releases it.
    """

    def __init__(self):
        self._lck = asyncio.Lock()

    async def acquire(self):
        """Wait for the lock, take it, and return this adapter."""
        await self._lck.acquire()
        return self

    def __enter__(self):
        # The lock is already held once ``acquire`` has been awaited, so there
        # is nothing to do; any ``as`` target is bound to None.
        return None

    def __exit__(self, exc_type, exc, tb):
        # Release unconditionally; returning None lets exceptions propagate.
        self._lck.release()
from bokeh.client.websocket import WebSocketClientConnectionWrapper


class WSConnAdapter(WebSocketClientConnectionWrapper):
    """Bokeh connection wrapper backed by a FastAPI/Starlette ``WebSocket``.

    The base ``__init__`` is not called.  Instead, ``_socket`` and
    ``write_lock`` are set to the asyncio-based adapters, presumably the only
    attributes the inherited methods rely on — verify against bokeh.
    """

    def __init__(self, socket: WebSocket) -> None:
        self.write_lock = TornadoLockAdapter()
        self._socket = TornadoWSAdapter(socket)
from bokeh.protocol import Protocol | |
from bokeh.protocol import messages as msg | |
from bokeh.protocol.exceptions import MessageError, ProtocolError, ValidationError | |
from bokeh.protocol.message import Message | |
from bokeh.protocol.receiver import Receiver | |
from bokeh.document.events import DocumentPatchedEvent | |
from typing import ( | |
TYPE_CHECKING, | |
cast, | |
Any, | |
Optional, | |
Dict, | |
Union, | |
List, | |
Awaitable, | |
Callable, | |
Tuple, | |
Type, | |
) | |
# Modelled on bokeh's server/connection.py.
class ServerConnection:
    """Server-side endpoint of one Bokeh protocol connection.

    Builds reply messages with the connection's ``Protocol`` and sends
    document patches over the wrapped socket.
    """

    def __init__(self, proto: Protocol, sock: WSConnAdapter):
        self._protocol = proto
        self._sock = sock

    @property
    def protocol(self) -> Protocol:
        """The ``Protocol`` used to create outgoing messages."""
        return self._protocol

    def ok(self, message: Message[Any]) -> msg.ok:
        """Create an ``OK`` reply to *message*."""
        request_id = message.header['msgid']
        return self.protocol.create('OK', request_id)

    def error(self, message: Message[Any], text: str) -> msg.error:
        """Create an ``ERROR`` reply to *message* carrying *text*."""
        request_id = message.header['msgid']
        return self.protocol.create('ERROR', request_id, text)

    def send_patch_document(self, event: DocumentPatchedEvent) -> Awaitable[None]:
        """Wrap *event* in a ``PATCH-DOC`` message and send it on the socket.

        Returns the awaitable produced by the message's ``send``.
        """
        patch = self.protocol.create('PATCH-DOC', [event])
        return patch.send(self._sock)
# Modelled on bokeh/server/protocol_handler.py.
# TODO: document lock
class ProtocolHandler:
    """Dispatch incoming Bokeh protocol messages to per-type coroutines.

    Handles ``PULL-DOC-REQ``, ``PUSH-DOC``, ``PATCH-DOC`` and
    ``SERVER-INFO-REQ`` against a single document.
    """

    _handlers: dict[str, Callable[..., Any]]

    def __init__(self, doc: Document, ss: ServerSession) -> None:
        self._doc = doc
        # Passed as the setter when applying client patches, so the session
        # can tell which changes originated from its own connection.
        self._ss = ss
        self._handlers = {
            'PULL-DOC-REQ': self.pull,
            'PUSH-DOC': self.push,
            'PATCH-DOC': self.patch,
            'SERVER-INFO-REQ': self.server_info,
        }

    async def pull(self, msg: msg.pull_doc_req, conn: ServerConnection) -> msg.pull_doc_reply:
        """Reply with the full document."""
        return conn.protocol.create('PULL-DOC-REPLY', msg.header['msgid'], self._doc)

    async def push(self, msg: msg.push_doc, conn: ServerConnection) -> msg.ok:
        """Replace the document contents with the pushed document."""
        msg.push_to_document(self._doc)
        return conn.ok(msg)

    async def patch(self, msg: msg.patch_doc, conn: ServerConnection) -> msg.ok:
        """Apply a client patch, tagging this handler's session as the setter."""
        msg.apply_to_document(self._doc, self._ss)
        return conn.ok(msg)

    async def server_info(self, msg: msg.server_info_req, conn: ServerConnection) -> msg.server_info_reply:
        """Reply to a server-info request."""
        return conn.protocol.create('SERVER-INFO-REPLY', msg.header['msgid'])

    async def handle(self, message, conn):
        """Run the coroutine registered for ``message.msgtype``.

        Returns:
            The handler's reply message.  If the handler raised, an ``ERROR``
            reply built with ``conn.error`` is returned instead.

        Raises:
            ProtocolError: if no handler is registered for the message type.
        """
        handler = self._handlers.get(message.msgtype)
        if handler is None:
            raise ProtocolError(f"{message} not expected on server")
        try:
            return await handler(message, conn)
        except Exception as e:
            # Best effort: report the failure to the client rather than
            # dropping the connection, but keep the traceback server-side.
            import logging
            logging.getLogger(__name__).error(
                "error handling message\n message: %r \n error: %r",
                message, e, exc_info=True,
            )
            return conn.error(message, repr(e))
# Modelled on bokeh's server/session.py.
class ServerSession:
    """Forward patches of a (possibly shared) ``Document`` to one connection.

    Registers itself with ``doc.callbacks.on_change_dispatch_to``.
    Presumably bokeh then dispatches ``DocumentPatchedEvent``s to
    ``_document_patched``.
    """

    def __init__(self, session_id: ID, doc: Document, conn: ServerConnection):
        self._id = session_id
        self._doc = doc
        self._conn = conn
        # Keep scheduled send tasks alive until they finish; the event loop
        # holds only weak references to tasks.
        self._pending_sends: set[asyncio.Task[None]] = set()
        self._doc.callbacks.on_change_dispatch_to(self)

    def _document_patched(self, event: DocumentPatchedEvent) -> None:
        """Schedule sending *event* to this connection unless it came from it.

        Client patches are applied with this session as the setter
        (``apply_to_document(doc, ss)``), so ``event.setter is self`` marks a
        change this connection itself sent.  Echoing it back is redundant and
        may race with newer client-side edits.
        """
        if event.setter is self:
            return
        # TODO: broadcast to every connection of a session once more than one
        # connection per session is supported.
        task = asyncio.create_task(self._conn.send_patch_document(event))
        self._pending_sends.add(task)
        task.add_done_callback(self._pending_sends.discard)
import calendar
import datetime as dt
import logging


def _bokeh_token_error(token: str) -> str | None:
    """Return why *token* is unacceptable, or ``None`` if it is acceptable.

    Checks, in order:
      1. the payload carries ``session_expiry``;
      2. that expiry lies in the future;
      3. the signature verifies.
    A token the helpers cannot decode is rejected as malformed.
    """
    try:
        payload = get_token_payload(token)
        if 'session_expiry' not in payload:
            return 'Session expiry has not been provided'
        now = calendar.timegm(dt.datetime.now(tz=dt.timezone.utc).utctimetuple())
        if now >= payload['session_expiry']:
            return 'Token is expired'
        if not check_token_signature(
            token,
            secret_key=settings.secret_key_bytes(),
            signed=settings.sign_sessions(),
        ):
            return 'Invalid token signature'
    except (ValueError, KeyError, TypeError):
        # The token is client-controlled; undecodable input must not crash
        # the endpoint.
        return 'Malformed token'
    return None


@app.websocket('/ws')
async def ws(ws: WebSocket):
    """Serve the Bokeh protocol for the document of the last page served by ``/``.

    Handshake: the client must offer exactly two subprotocols,
    ``['bokeh', <token>]``.  The token must decode, be unexpired and be
    correctly signed.  Otherwise the socket is closed before ``accept()``
    with code 1008.

    After accepting:
      - an ``ACK`` is sent;
      - incoming protocol messages are handled until the client disconnects,
        with replies sent back on the same socket;
      - an error while processing one message is logged, and the loop goes on.

    NOTE(review): every connection shares the module-level ``session_ctx``
    (the document of the most recent ``GET /``).  The session ID in the
    token is not matched against it.
    """
    log = logging.getLogger(__name__)

    async def reject(reason: str) -> None:
        # Closing before accept() makes the ASGI server refuse the handshake
        # (HTTP 403).  An HTTPException has no meaning for a websocket route.
        log.warning('rejecting websocket connection: %s', reason)
        await ws.close(code=1008)

    subprotocols = ws.get('subprotocols')
    if not subprotocols or len(subprotocols) != 2:
        await reject('Invalid subprotocols')
        return
    sub_proto, token = subprotocols
    if sub_proto != 'bokeh' or not token:
        await reject('Invalid subprotocols')
        return
    problem = _bokeh_token_error(token)
    if problem is not None:
        await reject(problem)
        return
    ctx = session_ctx
    if ctx is None:
        # No page has been served yet, so there is no document to attach to.
        await reject('No session has been created')
        return

    session_id = get_session_id(token)
    doc = ctx._doc
    proto = Protocol()
    receiver = Receiver(proto)
    conn = WSConnAdapter(ws)
    sconn = ServerConnection(proto, conn)

    await ws.accept(subprotocol='bokeh')
    # Subscribe to document changes only once the socket can carry them.
    # Always unsubscribe on exit, so a closed connection is not sent patches
    # forever.
    ss = ServerSession(session_id, doc, sconn)
    try:
        handler = ProtocolHandler(doc, ss)
        await proto.create('ACK').send(conn)
        while True:
            data = await ws.receive()
            if data['type'] == 'websocket.disconnect':
                break
            # Compare with None so an empty text frame is not mistaken for a
            # binary one.
            text = data.get('text')
            frag = text if text is not None else data.get('bytes')
            if frag is None:
                continue
            try:
                message = await receiver.consume(frag)
                if message:
                    work = await handler.handle(message, sconn)
                    if isinstance(work, Message):
                        await work.send(conn)
            except Exception:
                # Best effort: one bad message must not kill the connection.
                log.exception('error processing websocket message')
        # TODO: ping/pong
    finally:
        doc.callbacks.remove_on_change(ss)
import uvicorn

if __name__ == '__main__':
    # Run the app directly with uvicorn, listening on all interfaces.
    host, port = '0.0.0.0', 5050
    uvicorn.run(app, host=host, port=port)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.