Last active: April 3, 2025 12:53
-
-
Save plebioda/244b09aa24d474dd36048b4db7c9bd83 to your computer and use it in GitHub Desktop.
pymongo with OIDC example
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
venv/
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from pymongo import MongoClient | |
from pymongo.auth_oidc_shared import ( | |
OIDCCallback, | |
OIDCCallbackContext, | |
OIDCCallbackResult, | |
) | |
import webbrowser | |
import secrets | |
import hashlib | |
import base64 | |
from urllib.parse import urlencode, urlparse, parse_qs | |
from http.server import BaseHTTPRequestHandler, HTTPServer | |
from concurrent.futures import Future | |
import threading | |
import requests | |
import json | |
import code | |
import readline | |
import rlcompleter | |
import shutil | |
import subprocess | |
def HTMLResponse(query_params, full_path, headers):
    """Render the HTML page shown in the browser after the OIDC redirect.

    Args:
        query_params: Mapping of query-parameter name to a list of values, as
            produced by ``urllib.parse.parse_qs``.
        full_path: The raw request path, including the query string.
        headers: Pre-formatted request headers, one ``Key: value`` per line.

    Returns:
        The complete HTML document as a string.

    Every interpolated value comes from the incoming HTTP request, and a
    crafted link to the callback URL can put arbitrary markup there. So all
    values are HTML-escaped before being inserted into the page, which
    prevents script injection.
    """
    import html

    # Literal braces in the CSS/JS are doubled because the template is filled
    # with str.format().
    HTML_TEMPLATE = """
<html>
<head>
    <title>Authentication Complete</title>
    <style>
        body {{
            font-family: Arial, sans-serif;
            margin: 20px;
            background-color: #f9f9f9;
        }}
        h1 {{
            color: #333;
        }}
        table {{
            border-collapse: collapse;
            width: 100%;
            margin-top: 20px;
        }}
        th, td {{
            border: 1px solid #ddd;
            padding: 8px;
            text-align: left;
        }}
        th {{
            background-color: #f2f2f2;
            color: #333;
        }}
        tr:nth-child(even) {{
            background-color: #f9f9f9;
        }}
        tr:hover {{
            background-color: #f1f1f1;
        }}
        pre {{
            background-color: #f4f4f4;
            padding: 10px;
            border: 1px solid #ddd;
            overflow: auto;
        }}
    </style>
    <script>
        let secondsRemaining = 5;
        let countdownInterval;
        function updateCountdown() {{
            const countdownElement = document.getElementById('countdown');
            countdownElement.textContent = secondsRemaining;
            if (secondsRemaining > 0) {{
                secondsRemaining--;
            }} else {{
                clearInterval(countdownInterval);
                window.close();
            }}
        }}
        function stopCountdown() {{
            clearInterval(countdownInterval);
            const countdownDivElement = document.getElementById('countdown_div');
            countdownDivElement.style.display = 'none';
        }}
        window.onload = () => {{
            countdownInterval = setInterval(updateCountdown, 1000);
        }};
    </script>
</head>
<body>
    <h1>Authentication Successful</h1>
    <p>Received the following query parameters:</p>
    <table>
        <tr><th>Key</th><th>Value</th></tr>
        {table_rows}
    </table>
    <h2>Request Details</h2>
    <p><strong>Full Path:</strong> {full_path}</p>
    <h3>Headers:</h3>
    <pre>{headers}</pre>
    <div id="countdown_div">
        <p>This tab will close automatically in <span id="countdown">5</span> seconds.</p>
        <button onclick="stopCountdown()">Stop Countdown</button>
    </div>
</body>
</html>
"""
    table_rows = "".join(
        f"<tr><td>{html.escape(key)}</td>"
        f"<td>{html.escape(', '.join(value))}</td></tr>"
        for key, value in query_params.items()
    )
    return HTML_TEMPLATE.format(
        table_rows=table_rows,
        full_path=html.escape(full_path),
        headers=html.escape(headers),
    )
class CallbackHandler(BaseHTTPRequestHandler):
    """HTTP handler for the local OIDC redirect endpoint.

    Only GET requests whose path equals ``self.server.allowed_path`` are
    treated as the authorization callback; every other path receives a
    plain-text 404. The first matching request publishes its parsed query
    string through ``self.server.future``.
    """

    def do_GET(self):
        """Answer the callback with an HTML summary and resolve the future once."""
        request_url = urlparse(self.path)

        if request_url.path == self.server.allowed_path:
            params = parse_qs(request_url.query)
            header_dump = "\n".join(
                f"{name}: {val}" for name, val in self.headers.items()
            )

            self.send_response(200)
            self.send_header("Content-type", "text/html")
            self.end_headers()
            page = HTMLResponse(params, self.path, header_dump)
            self.wfile.write(page.encode())

            # Later callbacks (e.g. a page reload) must not overwrite the
            # first result, and Future.set_result would raise anyway.
            if not self.server.future.done():
                self.server.future.set_result(parse_qs(request_url.query))
            return

        self.send_response(404)
        self.send_header("Content-type", "text/plain")
        self.end_headers()
        self.wfile.write(b"404 Not Found")
class OneShotHTTPServer(HTTPServer):
    """``HTTPServer`` that carries the state ``CallbackHandler`` reads.

    Attributes (both set after the socket is bound):
        future: object the handler resolves with the callback's query params.
        allowed_path: the only URL path treated as the redirect callback.
    """

    def __init__(self, server_address, RequestHandlerClass, future, allowed_path):
        super().__init__(server_address, RequestHandlerClass)
        # Exposed to handlers via ``self.server.<attr>``.
        self.future, self.allowed_path = future, allowed_path
def start_server(server_address, timeout: int, allowed_path: str) -> Future:
    """Listen for the OIDC redirect on a background daemon thread.

    Requests are served one at a time until ``CallbackHandler`` resolves the
    future on ``allowed_path`` or the overall deadline expires. Stray requests
    (a favicon fetch, a mistyped path) get a 404 without ending the wait.
    Previously they closed the server and surfaced as a bogus "timeout".

    Args:
        server_address: ``(host, port)`` to bind.
        timeout: Total seconds to wait for the callback; ``None`` waits
            indefinitely.
        allowed_path: URL path of the redirect, e.g. ``"/redirect"``.

    Returns:
        A ``Future`` resolved with the callback's ``parse_qs`` dict, or with
        ``None`` if no matching request arrived in time.
    """
    import time

    future = Future()
    server = OneShotHTTPServer(server_address, CallbackHandler, future, allowed_path)
    deadline = None if timeout is None else time.monotonic() + timeout

    def serve():
        try:
            while not future.done():
                if deadline is None:
                    server.timeout = None
                else:
                    remaining = deadline - time.monotonic()
                    if remaining <= 0:
                        break
                    # handle_request() returns after this many seconds even
                    # if no client connects, so the loop re-checks the deadline.
                    server.timeout = remaining
                server.handle_request()
        finally:
            server.server_close()
            # Always resolve the future so the waiting caller never blocks forever.
            if not future.done():
                future.set_result(None)

    thread = threading.Thread(target=serve, daemon=True)
    thread.start()
    return future
def decode_jwt(token):
    """Split a compact JWT and JSON-decode its header and payload.

    The signature is NOT verified; it is returned as the raw base64url text.
    A token that does not have exactly three dot-separated segments, or whose
    segments are not valid base64url JSON, raises ``ValueError`` (or a
    subclass).

    Returns:
        ``{"header": dict, "payload": dict, "signature": str}``
    """
    def _b64url_json(segment):
        # JWT segments drop base64 "=" padding; restore it before decoding.
        missing = -len(segment) % 4
        raw = base64.urlsafe_b64decode(segment + "=" * missing)
        return json.loads(raw)

    head_b64, body_b64, sig = token.split(".")
    return {
        "header": _b64url_json(head_b64),
        "payload": _b64url_json(body_b64),
        "signature": sig,
    }
def encode_jwt(jwt):
    """Serialize a ``{"header", "payload", "signature"}`` dict to compact JWT form.

    Header and payload are dumped as compact JSON (no spaces) and encoded as
    unpadded base64url. The ``signature`` string is copied through as-is, not
    recomputed.
    """
    segments = []
    for part_name in ("header", "payload"):
        compact = json.dumps(jwt[part_name], separators=(",", ":"))
        segments.append(base64.urlsafe_b64encode(compact.encode()).decode().rstrip("="))
    segments.append(jwt["signature"])
    return ".".join(segments)
def start_browser(uri):
    """Open *uri* in a new incognito Chrome window, or the default browser.

    Chrome is looked up with ``shutil.which("google-chrome")``. When it is
    not on PATH, ``webbrowser.open`` is used instead. Returns ``None``.
    """
    chrome = shutil.which("google-chrome")
    if not chrome:
        print("Chrome not found. Falling back to the default browser.")
        webbrowser.open(uri)
        return
    print("Launching Chrome in incognito mode...")
    # Argument list (no shell), so the URI is passed verbatim.
    subprocess.Popen([chrome, "--new-window", "--incognito", uri])
class MyCallback(OIDCCallback):
    """pymongo OIDC human callback implementing Authorization Code + PKCE.

    ``fetch`` performs the whole login:
    - discovers the IdP endpoints from ``context.idp_info.issuer``;
    - opens the authorization URL in a browser;
    - captures the redirect on a short-lived local HTTP server bound to
      ``redirect_uri``;
    - exchanges the code for tokens and returns the access token to pymongo.

    Intermediate secrets (PKCE verifier, nonce, tokens) are printed to stdout
    for demonstration purposes.
    """

    # Per-request timeout (seconds) for calls to the IdP. ``requests`` applies
    # no timeout by default and would otherwise wait forever on a stalled server.
    HTTP_TIMEOUT = 30

    def __init__(self, redirect_uri: str):
        """Parse and remember the local redirect URI.

        Args:
            redirect_uri: e.g. ``"http://localhost:27097/redirect"``. Its host,
                port and path decide where the callback server listens.
        """
        super().__init__()
        self.uri = urlparse(redirect_uri)

    def get_uri(self, uri, clientId, code_verifier, nonce, state=None) -> str:
        """Build the authorization-endpoint URL for the PKCE code flow.

        Args:
            uri: IdP authorization endpoint.
            clientId: OAuth client id registered with the IdP.
            code_verifier: PKCE verifier; only its S256 challenge is sent.
            nonce: Value the IdP must echo back inside the ID token.
            state: Anti-CSRF value echoed back on the redirect. Defaults to the
                legacy constant ``"somestate"``; callers should pass a fresh
                random value and verify it on return.

        Returns:
            The full URL to open in the browser.
        """
        code_challenge = (
            base64.urlsafe_b64encode(hashlib.sha256(code_verifier.encode()).digest())
            .rstrip(b"=")
            .decode()
        )
        params = {
            "client_id": clientId,
            "response_type": "code",
            "scope": "openid profile offline_access",
            "redirect_uri": self.uri.geturl(),
            "state": "somestate" if state is None else state,
            "code_challenge": code_challenge,
            "code_challenge_method": "S256",
            "nonce": nonce,
        }
        # Keep any query string the endpoint already carries.
        separator = "&" if "?" in uri else "?"
        return f"{uri}{separator}{urlencode(params)}"

    def run_code_flow(self, auth_uri, context, code_verifier, nonce):
        """Drive the interactive browser login and return the redirect's query.

        Returns:
            The ``parse_qs`` dict of the redirect (param name -> list of values),
            guaranteed to contain ``code`` and a ``state`` matching this attempt.

        Raises:
            Exception: on timeout, an IdP ``error`` response, a missing code,
                or a missing or mismatched ``state``.
        """
        print(
            "You will be redirected to your browser to log in. Please complete the authentication process."
        )
        input("Press Enter to continue...")
        # A fresh, unguessable state per attempt prevents a forged redirect
        # from injecting an attacker's authorization code (login CSRF).
        state = secrets.token_urlsafe(32)
        # urlparse() reports port=None when the URI has no explicit port, and
        # bind() cannot use None; fall back to the HTTP default.
        port = self.uri.port or 80
        future = start_server(
            (self.uri.hostname, port), context.timeout_seconds, self.uri.path
        )
        uri = self.get_uri(
            auth_uri, context.idp_info.clientId, code_verifier, nonce, state
        )
        start_browser(uri)
        code_result = future.result()
        if code_result is None:
            raise Exception("Timeout waiting for callback")
        if "error" in code_result:
            raise Exception(f"Error in callback: {code_result['error'][0]}")
        if "code" not in code_result:
            raise Exception("Authorization code not found in callback")
        if "state" not in code_result:
            raise Exception("State not found in callback")
        # Compare as bytes: compare_digest rejects non-ASCII str, and the
        # returned value is attacker-controllable.
        if not secrets.compare_digest(code_result["state"][0].encode(), state.encode()):
            raise Exception("State mismatch in callback")
        return code_result

    def get_oath_metadata(self, issuer):
        """Fetch the IdP's authorization-server metadata document.

        Tries the RFC 8414 ``oauth-authorization-server`` document first. If
        that request is not successful, falls back to the OpenID Connect
        ``openid-configuration`` document. TLS certificates are verified: this
        document names the endpoints that receive credentials, so it must not
        be accepted from an unauthenticated server. (The ``oath`` spelling is
        kept for compatibility.)

        Raises:
            requests.HTTPError: if neither document can be fetched.
        """
        base = issuer.rstrip("/")
        response = requests.get(
            f"{base}/.well-known/oauth-authorization-server",
            timeout=self.HTTP_TIMEOUT,
        )
        if not response.ok:
            response = requests.get(
                f"{base}/.well-known/openid-configuration",
                timeout=self.HTTP_TIMEOUT,
            )
        response.raise_for_status()
        return response.json()

    def get_jwks(self, uri):
        """Download the IdP's JSON Web Key Set (currently only printed by ``fetch``)."""
        response = requests.get(uri, timeout=self.HTTP_TIMEOUT)
        response.raise_for_status()
        return response.json()

    def get_tokens(self, uri, context, code, code_verifier):
        """Exchange the authorization code for tokens at the token endpoint.

        Returns:
            ``(id_token, access_token, refresh_token, expires_in, token_result)``:
            - ``id_token``: decoded JWT dict (signature not verified).
            - ``access_token``: decoded JWT dict, or the raw string when the IdP
              issues an opaque (non-JWT) access token.
            - ``refresh_token`` / ``expires_in``: ``None`` when absent.
            - ``token_result``: the raw JSON response.

        Raises:
            requests.HTTPError: if the token request fails.
        """
        response = requests.post(
            uri,
            data={
                "grant_type": "authorization_code",
                "code": code,
                "redirect_uri": self.uri.geturl(),
                "client_id": context.idp_info.clientId,
                "code_verifier": code_verifier,
            },
            timeout=self.HTTP_TIMEOUT,
        )
        response.raise_for_status()
        token_result = response.json()
        id_token = decode_jwt(token_result["id_token"])
        try:
            access_token = decode_jwt(token_result["access_token"])
        except ValueError:
            # Opaque access tokens are legal in OAuth 2.0; keep the raw value
            # for display instead of aborting the login.
            access_token = token_result["access_token"]
        refresh_token = token_result.get("refresh_token")
        expires_in = token_result.get("expires_in")
        return (id_token, access_token, refresh_token, expires_in, token_result)

    def fetch(self, context: OIDCCallbackContext) -> OIDCCallbackResult:
        """Run the interactive login and return the resulting tokens to pymongo.

        Args:
            context: pymongo-supplied context. ``idp_info.issuer`` and
                ``idp_info.clientId`` identify the IdP and client, and
                ``timeout_seconds`` bounds the wait for the browser redirect.

        Returns:
            ``OIDCCallbackResult`` with the raw access token, its lifetime and
            the refresh token (``None`` if the IdP issued none).

        Raises:
            Exception: if the browser flow fails or the ID token's nonce does
                not match this attempt.
            requests.HTTPError: if a request to the IdP fails.
        """
        code_verifier = secrets.token_urlsafe(64)
        print("Code verifier:", code_verifier)
        nonce = secrets.token_urlsafe(64)
        print("Nonce:", nonce)
        oath_metadata = self.get_oath_metadata(context.idp_info.issuer)
        print("OAuth Metadata:", json.dumps(oath_metadata, indent=1))
        jwks = self.get_jwks(oath_metadata["jwks_uri"])
        print("JWKS:", json.dumps(jwks, indent=1))
        code_result = self.run_code_flow(
            oath_metadata["authorization_endpoint"], context, code_verifier, nonce
        )
        print("Authorization code:", code_result["code"][0])
        print("State:", code_result["state"][0])
        id_token, access_token, refresh_token, expires_in, token_result = self.get_tokens(
            oath_metadata["token_endpoint"],
            context,
            code_result["code"][0],
            code_verifier,
        )
        print("Access Token:", json.dumps(access_token, indent=1))
        print("Refresh Token:", json.dumps(refresh_token, indent=1))
        print("ID Token:", json.dumps(id_token, indent=1))
        print("Expires In:", expires_in)
        # The ID token must echo the nonce sent in the authorization request.
        # A mismatch means it was not issued for this login attempt.
        if id_token["payload"].get("nonce") != nonce:
            raise Exception("Nonce mismatch in ID token")
        # TODO: verify token signatures against `jwks` and check iss/aud/exp
        # before trusting any claims.
        print("Token exchange successful")
        return OIDCCallbackResult(
            access_token=token_result["access_token"],
            expires_in_seconds=expires_in,
            refresh_token=refresh_token,
        )
def build_parser():
    """Build the command-line parser for this example.

    Returns:
        An ``argparse.ArgumentParser`` accepting ``--host``, ``--port``,
        ``--username``, ``--redirect_uri``, ``--tls`` and ``-i/--interactive``.
    """
    import argparse

    def _parse_bool(text):
        # ``type=bool`` would evaluate bool("False") -> True, so any value
        # given to --tls used to enable TLS. Parse common spellings explicitly.
        normalized = text.strip().lower()
        if normalized in ("1", "true", "yes", "y", "on"):
            return True
        if normalized in ("0", "false", "no", "n", "off"):
            return False
        raise argparse.ArgumentTypeError(f"expected a boolean value, got {text!r}")

    parser = argparse.ArgumentParser(
        description="MongoDB OIDC Authentication Example",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "--host",
        default="localhost",
        help="MongoDB host",
    )
    parser.add_argument(
        "--port",
        type=int,
        default=27017,
        help="MongoDB port",
    )
    parser.add_argument(
        "--username",
        type=str,
        default=None,
        help="Username for MongoDB authentication",
    )
    parser.add_argument(
        "--redirect_uri",
        default="http://localhost:27097/redirect",
        help="Redirect URI for OIDC callback",
    )
    parser.add_argument(
        "--tls",
        type=_parse_bool,
        nargs="?",
        const=True,  # bare ``--tls`` enables TLS
        default=False,
        help="Use TLS for MongoDB connection (--tls, or --tls true/false)",
    )
    parser.add_argument(
        "-i",
        "--interactive",
        action="store_true",
        help="Run in interactive mode",
    )
    return parser
# Banner shown by code.interact() when the optional REPL starts.
SHELL_BANNER = (
    "\n"
    "Running in interactive mode. Type 'exit()' to quit.\n"
    "You can use the 'client' object to interact with MongoDB.\n"
    "Tab completion is enabled.\n"
)
def main():
    """Connect to MongoDB with MONGODB-OIDC and optionally open a REPL.

    Exits with status 1 if the initial ``list_database_names`` call fails.
    The client is always closed, including when the REPL is left via
    ``exit()``.
    """
    import sys

    parser = build_parser()
    args = parser.parse_args()
    print("MongoDB OIDC Authentication Example")
    print("Starting MongoDB client...")
    client = MongoClient(
        host=args.host,
        port=args.port,
        connect=True,
        tls=args.tls,
        authMechanism="MONGODB-OIDC",
        username=args.username,
        authMechanismProperties={
            "OIDC_HUMAN_CALLBACK": MyCallback(args.redirect_uri),
        },
    )
    try:
        try:
            # Forces authentication, which triggers the OIDC callback.
            client.list_database_names()
        except Exception as e:
            print("Connection failed:", e)
            # Non-zero status so scripts can detect the failure.
            sys.exit(1)
        print("Connection successful")
        if args.interactive:
            variables = globals().copy()
            variables.update(locals())
            # Complete against the same namespace the REPL sees, so local
            # names such as ``client`` and ``args`` are tab-completable.
            readline.set_completer(rlcompleter.Completer(variables).complete)
            readline.parse_and_bind("tab: complete")
            code.interact(banner=SHELL_BANNER, local=variables)
        print("Exiting...")
    finally:
        client.close()


if __name__ == "__main__":
    main()
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
pymongo
requests
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.