Skip to content

Instantly share code, notes, and snippets.

@plebioda
Last active April 3, 2025 12:53
Show Gist options
  • Save plebioda/244b09aa24d474dd36048b4db7c9bd83 to your computer and use it in GitHub Desktop.
Save plebioda/244b09aa24d474dd36048b4db7c9bd83 to your computer and use it in GitHub Desktop.
pymongo with OIDC example
from pymongo import MongoClient
from pymongo.auth_oidc_shared import (
OIDCCallback,
OIDCCallbackContext,
OIDCCallbackResult,
)
import webbrowser
import secrets
import hashlib
import base64
from urllib.parse import urlencode, urlparse, parse_qs
from http.server import BaseHTTPRequestHandler, HTTPServer
from concurrent.futures import Future
import threading
import requests
import json
import code
import readline
import rlcompleter
import shutil
import subprocess
def HTMLResponse(query_params, full_path, headers):
    """Render the HTML page shown in the browser after the OIDC redirect.

    Args:
        query_params: Mapping of query-string keys to lists of values, as
            produced by ``urllib.parse.parse_qs``.
        full_path: The raw request path, including the query string.
        headers: Request headers pre-formatted as ``key: value`` lines.

    Returns:
        The complete HTML document as a ``str``. The page auto-closes its tab
        after a 5 second countdown unless the user stops it.

    Every request-derived value is HTML-escaped before interpolation. Anyone
    able to reach the local callback server controls these values, so
    inserting them verbatim would allow script injection into the page.
    """
    from html import escape

    # Doubled braces are literal braces for str.format().
    HTML_TEMPLATE = """
<html>
<head>
  <title>Authentication Complete</title>
  <style>
    body {{
      font-family: Arial, sans-serif;
      margin: 20px;
      background-color: #f9f9f9;
    }}
    h1 {{
      color: #333;
    }}
    table {{
      border-collapse: collapse;
      width: 100%;
      margin-top: 20px;
    }}
    th, td {{
      border: 1px solid #ddd;
      padding: 8px;
      text-align: left;
    }}
    th {{
      background-color: #f2f2f2;
      color: #333;
    }}
    tr:nth-child(even) {{
      background-color: #f9f9f9;
    }}
    tr:hover {{
      background-color: #f1f1f1;
    }}
    pre {{
      background-color: #f4f4f4;
      padding: 10px;
      border: 1px solid #ddd;
      overflow: auto;
    }}
  </style>
  <script>
    let secondsRemaining = 5;
    let countdownInterval;
    function updateCountdown() {{
      const countdownElement = document.getElementById('countdown');
      countdownElement.textContent = secondsRemaining;
      if (secondsRemaining > 0) {{
        secondsRemaining--;
      }} else {{
        clearInterval(countdownInterval);
        window.close();
      }}
    }}
    function stopCountdown() {{
      clearInterval(countdownInterval);
      const countdownDivElement = document.getElementById('countdown_div');
      countdownDivElement.style.display = 'none';
    }}
    window.onload = () => {{
      countdownInterval = setInterval(updateCountdown, 1000);
    }};
  </script>
</head>
<body>
  <h1>Authentication Successful</h1>
  <p>Received the following query parameters:</p>
  <table>
    <tr><th>Key</th><th>Value</th></tr>
    {table_rows}
  </table>
  <h2>Request Details</h2>
  <p><strong>Full Path:</strong> {full_path}</p>
  <h3>Headers:</h3>
  <pre>{headers}</pre>
  <div id="countdown_div">
    <p>This tab will close automatically in <span id="countdown">5</span> seconds.</p>
    <button onclick="stopCountdown()">Stop Countdown</button>
  </div>
</body>
</html>
"""
    table_rows = "".join(
        f"<tr><td>{escape(key)}</td><td>{escape(', '.join(value))}</td></tr>"
        for key, value in query_params.items()
    )
    return HTML_TEMPLATE.format(
        table_rows=table_rows,
        full_path=escape(full_path),
        headers=escape(headers),
    )
class CallbackHandler(BaseHTTPRequestHandler):
    """Request handler for the identity provider's browser redirect.

    Expects ``self.server`` to expose ``allowed_path`` (the only path served
    with 200) and ``future`` (a ``concurrent.futures.Future``), as
    ``OneShotHTTPServer`` does.
    """

    def do_GET(self):
        """Answer a GET and publish the callback's query parameters.

        Any path other than ``self.server.allowed_path`` gets a plain-text
        404. The allowed path gets an HTML summary page. Its parsed query
        parameters (``parse_qs`` dict of lists) become the future's result if
        the future is not already resolved.
        """
        parsed = urlparse(self.path)
        if parsed.path != self.server.allowed_path:
            self.send_response(404)
            self.send_header("Content-type", "text/plain")
            self.end_headers()
            self.wfile.write(b"404 Not Found")
            return

        query_params = parse_qs(parsed.query)
        headers = "\n".join(f"{key}: {value}" for key, value in self.headers.items())
        # Render before sending the status line, so a rendering failure does
        # not leave a half-written 200 response.
        body = HTMLResponse(query_params, self.path, headers).encode("utf-8")

        self.send_response(200)
        # The body is UTF-8 and query values may be non-ASCII; declare it so
        # the browser does not guess a different charset.
        self.send_header("Content-type", "text/html; charset=utf-8")
        self.send_header("Content-Length", str(len(body)))
        self.end_headers()
        self.wfile.write(body)

        # Only the first callback wins; later ones are ignored.
        if not self.server.future.done():
            self.server.future.set_result(query_params)
class OneShotHTTPServer(HTTPServer):
    """``HTTPServer`` that also carries state shared with its request handler.

    ``CallbackHandler`` reads ``self.server.allowed_path`` to decide which
    path to serve, and resolves ``self.server.future`` with the callback's
    query parameters.
    """

    def __init__(self, server_address, RequestHandlerClass, future, allowed_path):
        super().__init__(server_address, RequestHandlerClass)
        # Shared with the handler through ``self.server``.
        self.future, self.allowed_path = future, allowed_path
def start_server(server_address, timeout: int, allowed_path: str) -> Future:
    """Serve the OIDC redirect on a background thread and return a Future for it.

    A daemon thread handles requests on ``server_address`` until either:
    - a request for ``allowed_path`` resolves the future (done by
      ``CallbackHandler``), or
    - ``timeout`` seconds have elapsed overall.

    Requests for other paths get the handler's 404 without ending the wait.
    The original version stopped after the first request of any kind, so one
    stray hit (a probe, a favicon, a typo'd URL) was misreported as a timeout.

    Args:
        server_address: ``(host, port)`` to bind.
        timeout: Overall seconds to wait for the callback; ``None`` waits
            indefinitely.
        allowed_path: URL path on which the callback is expected.

    Returns:
        A Future that resolves to one of:
        - the parsed query parameters of the callback,
        - ``None`` on timeout,
        - the exception raised while serving.
    """
    import time

    future = Future()
    server = OneShotHTTPServer(server_address, CallbackHandler, future, allowed_path)
    deadline = None if timeout is None else time.monotonic() + timeout

    def serve():
        try:
            while not future.done():
                if deadline is None:
                    server.timeout = None
                else:
                    remaining = deadline - time.monotonic()
                    if remaining <= 0:
                        break
                    # handle_request() waits at most server.timeout seconds.
                    server.timeout = remaining
                server.handle_request()
        except Exception as exc:
            # Surface server failures to the waiter instead of hanging it.
            if not future.done():
                future.set_exception(exc)
        finally:
            server.server_close()
            if not future.done():
                future.set_result(None)

    thread = threading.Thread(target=serve, daemon=True)
    thread.start()
    return future
def decode_jwt(token):
    """Split a compact ``header.payload.signature`` token and decode its JSON parts.

    The signature segment is returned verbatim and is NOT verified.

    Returns:
        ``{"header": <dict>, "payload": <dict>, "signature": <str>}``.

    Raises:
        ValueError: if the token does not have exactly three dot-separated
            segments, or a segment is not valid base64url-encoded JSON.
    """
    def _b64url_json(segment):
        # JWT segments drop the trailing '=' padding; restore it first.
        missing_padding = -len(segment) % 4
        raw = base64.urlsafe_b64decode(segment + "=" * missing_padding)
        return json.loads(raw)

    header_seg, payload_seg, sig_seg = token.split(".")
    return dict(
        header=_b64url_json(header_seg),
        payload=_b64url_json(payload_seg),
        signature=sig_seg,
    )
def encode_jwt(jwt):
    """Serialize a decoded JWT dict back into compact ``header.payload.signature`` form.

    ``header`` and ``payload`` are re-encoded as compact JSON (no spaces) in
    unpadded base64url. ``signature`` is copied unchanged and not recomputed,
    so it only remains valid if the re-encoded segments are byte-identical to
    the ones originally signed.
    """
    segments = []
    for key in ("header", "payload"):
        compact_json = json.dumps(jwt[key], separators=(",", ":"))
        encoded = base64.urlsafe_b64encode(compact_json.encode()).decode()
        segments.append(encoded.rstrip("="))
    segments.append(jwt["signature"])
    return ".".join(segments)
def start_browser(uri):
    """Open *uri* for the user, preferring a new incognito Chrome window.

    If ``google-chrome`` is on PATH it is launched without waiting for it to
    exit. Otherwise the platform default browser is used via ``webbrowser``.
    """
    chrome = shutil.which("google-chrome")
    if not chrome:
        print("Chrome not found. Falling back to the default browser.")
        webbrowser.open(uri)
        return
    # Incognito presumably avoids reusing an existing IdP login session — confirm intent.
    print("Launching Chrome in incognito mode...")
    subprocess.Popen([chrome, "--new-window", "--incognito", uri])
class MyCallback(OIDCCallback):
    """PyMongo OIDC human callback implementing authorization code + PKCE.

    When the driver needs a token it calls ``fetch``, which:
    1. discovers the IdP endpoints from ``context.idp_info.issuer``;
    2. sends the user to the authorization endpoint in a browser;
    3. captures the redirect on a local HTTP server bound to the host, port
       and path of ``redirect_uri``;
    4. exchanges the authorization code for tokens and returns them.

    This is a debugging example: it prints the PKCE verifier, nonce and
    decoded tokens to stdout. Token signatures are not verified.
    """

    # Per-request timeout (seconds) for HTTP calls to the identity provider.
    # Without one, ``requests`` can block indefinitely.
    HTTP_TIMEOUT = 30

    def __init__(self, redirect_uri: str, verify_tls: bool = True):
        """Create the callback.

        Args:
            redirect_uri: Redirect URI registered with the IdP. The local
                callback server listens on its hostname, port and path.
            verify_tls: Verify the IdP's TLS certificates on every request.
                Disable only for a local test IdP with a self-signed cert.
        """
        super().__init__()
        self.uri = urlparse(redirect_uri)
        self.verify_tls = verify_tls

    def get_uri(self, uri, clientId, code_verifier, nonce, state="somestate") -> str:
        """Build the authorization-endpoint URL (code flow, PKCE S256).

        Args:
            uri: Authorization endpoint.
            clientId: OAuth client id.
            code_verifier: PKCE verifier. Its unpadded base64url SHA-256 is
                sent as ``code_challenge``.
            nonce: Value the IdP must echo back in the ID token.
            state: Anti-CSRF value echoed back on the redirect. The default is
                the historical fixed value; ``run_code_flow`` passes a random
                one.

        Returns:
            The full authorization URL.
        """
        code_challenge = (
            base64.urlsafe_b64encode(hashlib.sha256(code_verifier.encode()).digest())
            .rstrip(b"=")
            .decode()
        )
        params = {
            "client_id": clientId,
            "response_type": "code",
            "scope": "openid profile offline_access",
            "redirect_uri": self.uri.geturl(),
            "state": state,
            "code_challenge": code_challenge,
            "code_challenge_method": "S256",
            "nonce": nonce,
        }
        return f"{uri}?{urlencode(params)}"

    def run_code_flow(self, auth_uri, context, code_verifier, nonce, state=None):
        """Drive the browser login and return the redirect's query parameters.

        Args:
            auth_uri: Authorization endpoint.
            context: Driver callback context; ``timeout_seconds`` bounds the
                wait for the redirect.
            code_verifier: PKCE verifier.
            nonce: OIDC nonce.
            state: Anti-CSRF value. A random one is generated when omitted.

        Returns:
            The ``parse_qs`` dict of the redirect (values are lists).

        Raises:
            Exception: on timeout, an IdP ``error`` response, a missing
                ``code``/``state``, or a ``state`` that does not match the
                one sent.
        """
        if state is None:
            state = secrets.token_urlsafe(32)
        print(
            "You will be redirected to your browser to log in. Please complete the authentication process."
        )
        input("Press Enter to continue...")
        # Listen before opening the browser so the redirect cannot be missed.
        future = start_server(
            (self.uri.hostname, self.uri.port), context.timeout_seconds, self.uri.path
        )
        uri = self.get_uri(
            auth_uri, context.idp_info.clientId, code_verifier, nonce, state
        )
        start_browser(uri)
        code_result = future.result()
        if code_result is None:
            raise Exception("Timeout waiting for callback")
        if "error" in code_result:
            raise Exception(f"Error in callback: {code_result['error'][0]}")
        if "code" not in code_result:
            raise Exception("Authorization code not found in callback")
        if "state" not in code_result:
            raise Exception("State not found in callback")
        # Reject redirects not initiated by this flow (CSRF / code injection).
        # Compare as bytes: compare_digest rejects non-ASCII str.
        if not secrets.compare_digest(
            code_result["state"][0].encode("utf-8"), state.encode("utf-8")
        ):
            raise Exception("State mismatch in callback")
        return code_result

    def get_oath_metadata(self, issuer):
        """Fetch the IdP's authorization-server metadata document.

        Tries ``/.well-known/oauth-authorization-server`` (RFC 8414) first.
        If that returns an HTTP error, falls back to
        ``/.well-known/openid-configuration`` (OIDC Discovery), which some
        IdPs serve exclusively.

        Raises:
            requests.HTTPError: if both documents fail with an HTTP error.
        """
        base = issuer.rstrip("/")
        last_error = None
        for document in ("oauth-authorization-server", "openid-configuration"):
            response = requests.get(
                f"{base}/.well-known/{document}",
                verify=self.verify_tls,
                timeout=self.HTTP_TIMEOUT,
            )
            try:
                response.raise_for_status()
            except requests.HTTPError as exc:
                last_error = exc
                continue
            return response.json()
        raise last_error

    def get_jwks(self, uri):
        """Fetch the IdP's JSON Web Key Set from *uri*."""
        response = requests.get(uri, verify=self.verify_tls, timeout=self.HTTP_TIMEOUT)
        response.raise_for_status()
        return response.json()

    def get_tokens(self, uri, context, code, code_verifier):
        """Exchange the authorization code for tokens at the token endpoint.

        Returns:
            ``(id_token, access_token, refresh_token, expires_in, token_result)``:
            - ``id_token`` is decoded by ``decode_jwt``;
            - ``access_token`` is decoded when it is a JWT, else the raw
              string;
            - ``refresh_token`` / ``expires_in`` may be ``None``;
            - ``token_result`` is the raw JSON response.
        """
        response = requests.post(
            uri,
            data={
                "grant_type": "authorization_code",
                "code": code,
                "redirect_uri": self.uri.geturl(),
                "client_id": context.idp_info.clientId,
                "code_verifier": code_verifier,
            },
            verify=self.verify_tls,
            timeout=self.HTTP_TIMEOUT,
        )
        response.raise_for_status()  # Raise an error if the request fails
        token_result = response.json()
        id_token = decode_jwt(token_result["id_token"])
        # Access tokens are opaque to OAuth clients; some IdPs issue non-JWT
        # ones, which decode_jwt rejects with a ValueError subclass.
        try:
            access_token = decode_jwt(token_result["access_token"])
        except ValueError:
            access_token = token_result["access_token"]
        refresh_token = token_result.get("refresh_token")
        expires_in = token_result.get("expires_in")
        return (id_token, access_token, refresh_token, expires_in, token_result)

    def fetch(self, context: OIDCCallbackContext) -> OIDCCallbackResult:
        """Run the full interactive login and return the tokens to the driver.

        Raises:
            Exception: if the code flow fails or the ID token's ``nonce``
                does not match the one sent.
        """
        code_verifier = secrets.token_urlsafe(64)
        print("Code verifier:", code_verifier)
        nonce = secrets.token_urlsafe(64)
        print("Nonce:", nonce)
        state = secrets.token_urlsafe(32)
        oath_metadata = self.get_oath_metadata(context.idp_info.issuer)
        print("OAuth Metadata:", json.dumps(oath_metadata, indent=1))
        jwks = self.get_jwks(oath_metadata["jwks_uri"])
        print("JWKS:", json.dumps(jwks, indent=1))
        code_result = self.run_code_flow(
            oath_metadata["authorization_endpoint"], context, code_verifier, nonce, state
        )
        print("Authorization code:", code_result["code"][0])
        print("State:", code_result["state"][0])
        id_token, access_token, refresh_token, expires_in, token_result = self.get_tokens(
            oath_metadata["token_endpoint"],
            context, code_result["code"][0], code_verifier
        )
        print("Access Token:", json.dumps(access_token, indent=1))
        print("Refresh Token:", json.dumps(refresh_token, indent=1))
        print("ID Token:", json.dumps(id_token, indent=1))
        print("Expires In:", expires_in)
        # OIDC Core: an ID token issued for a request carrying a nonce must
        # echo it; a mismatch indicates a replayed or injected token.
        if id_token["payload"].get("nonce") != nonce:
            raise Exception("ID token nonce does not match the request nonce")
        # TODO validate token signatures against `jwks`, plus iss/aud/exp.
        print("Token exchange successful")
        return OIDCCallbackResult(
            access_token=token_result["access_token"],
            expires_in_seconds=expires_in,
            refresh_token=refresh_token,
        )
def build_parser():
    """Build the command-line parser for this example.

    Returns:
        An ``argparse.ArgumentParser`` with ``--host``, ``--port``,
        ``--username``, ``--redirect_uri``, ``--tls`` and
        ``-i/--interactive``.
    """
    import argparse

    def parse_bool(value):
        # ``type=bool`` would turn any non-empty string, including "False",
        # into True, so parse the usual spellings explicitly.
        normalized = str(value).strip().lower()
        if normalized in ("1", "true", "yes", "y", "on"):
            return True
        if normalized in ("0", "false", "no", "n", "off"):
            return False
        raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")

    parser = argparse.ArgumentParser(
        description="MongoDB OIDC Authentication Example",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "--host",
        default="localhost",
        help="MongoDB host",
    )
    parser.add_argument(
        "--port",
        type=int,
        default=27017,
        help="MongoDB port",
    )
    parser.add_argument(
        "--username",
        type=str,
        default=None,
        help="Username for MongoDB authentication",
    )
    parser.add_argument(
        "--redirect_uri",
        default="http://localhost:27097/redirect",
        help="Redirect URI for OIDC callback",
    )
    # Accepts both `--tls` (means True) and `--tls true|false`.
    parser.add_argument(
        "--tls",
        type=parse_bool,
        nargs="?",
        const=True,
        default=False,
        help="Use TLS for MongoDB connection",
    )
    parser.add_argument(
        "-i",
        "--interactive",
        action="store_true",
        help="Run in interactive mode",
    )
    return parser
# Banner printed by code.interact() when main() runs with -i/--interactive.
SHELL_BANNER = """
Running in interactive mode. Type 'exit()' to quit.
You can use the 'client' object to interact with MongoDB.
Tab completion is enabled.
"""
def main():
    """Connect to MongoDB with MONGODB-OIDC and optionally open a REPL.

    Steps:
    1. Parse the command line.
    2. Build a ``MongoClient`` whose OIDC human callback is ``MyCallback``.
    3. Force authentication by listing database names.
    4. With ``-i``, open an interactive console that has ``client`` (and the
       other locals) in scope, with tab completion over the same namespace.

    The client is always closed before returning. Exits with status 1 if the
    initial connection or authentication fails.
    """
    import sys

    parser = build_parser()
    args = parser.parse_args()
    print("MongoDB OIDC Authentication Example")
    print("Starting MongoDB client...")
    client = MongoClient(
        host=args.host,
        port=args.port,
        connect=True,
        tls=args.tls,
        authMechanism="MONGODB-OIDC",
        username=args.username,
        authMechanismProperties={
            "OIDC_HUMAN_CALLBACK": MyCallback(args.redirect_uri),
        },
    )
    try:
        try:
            client.list_database_names()
        except Exception as e:
            print("Connection failed:", e)
            sys.exit(1)
        print("Connection successful")
        if args.interactive:
            variables = globals().copy()
            variables.update(locals())
            # Complete against the console's own namespace; completing against
            # globals() alone would miss `client`, `args`, etc.
            readline.set_completer(rlcompleter.Completer(variables).complete)
            readline.parse_and_bind("tab: complete")
            code.interact(banner=SHELL_BANNER, local=variables)
    finally:
        # Also runs when exit() in the REPL or a failed connection raises
        # SystemExit, so the client is never leaked.
        print("Exiting...")
        client.close()
# Run the example only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
pymongo
requests
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment