Last active
June 3, 2025 09:04
-
-
Save plebioda/244b09aa24d474dd36048b4db7c9bd83 to your computer and use it in GitHub Desktop.
pymongo with OIDC example
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
venv/ |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from selenium import webdriver | |
from selenium.webdriver.chrome.service import Service | |
from selenium.webdriver.chrome.options import Options | |
from webdriver_manager.chrome import ChromeDriverManager | |
from selenium.webdriver.support.ui import WebDriverWait | |
from selenium.webdriver.support import expected_conditions | |
from selenium.webdriver.common.by import By | |
import traceback | |
class KeycloakAuthenticator:
    """Drives a Keycloak login form through Selenium."""

    @classmethod
    def check(cls, url, driver):
        """Return True if "keycloak" appears in ``url`` or in the lower-cased page source."""
        if "keycloak" in url:
            return True
        return "keycloak" in driver.page_source.lower()

    @classmethod
    def run(cls, driver, username: str, password: str, wait_step: int):
        """Type the credentials into the form and click the login button.

        Each element is awaited (up to ``wait_step`` seconds) until it is
        clickable before it is used.
        """
        text_fields = (
            ("//input[@name='username']", username),
            ("//input[@name='password']", password),
        )
        for xpath, text in text_fields:
            field = WebDriverWait(driver, wait_step).until(
                expected_conditions.element_to_be_clickable((By.XPATH, xpath))
            )
            field.send_keys(text)

        login_locator = (By.XPATH, "//button[@name='login']")
        WebDriverWait(driver, wait_step).until(
            expected_conditions.element_to_be_clickable(login_locator)
        ).click()
class OktaAuthenticator:
    """Drives an Okta sign-in form through Selenium."""

    @classmethod
    def check(cls, url, driver):
        """Return True if "okta" appears in ``url``; ``driver`` is not consulted."""
        is_okta = "okta" in url
        return is_okta

    @classmethod
    def run(cls, driver, username: str, password: str, wait_step: int):
        """Enter username and password, then click the "Sign In" button.

        Every element is awaited for up to ``wait_step`` seconds until clickable.
        """
        waiter = WebDriverWait

        user_locator = (By.XPATH, "//input[@name='username']")
        user_field = waiter(driver, wait_step).until(
            expected_conditions.element_to_be_clickable(user_locator)
        )
        user_field.send_keys(username)

        secret_locator = (By.XPATH, "//input[@name='password']")
        secret_field = waiter(driver, wait_step).until(
            expected_conditions.element_to_be_clickable(secret_locator)
        )
        secret_field.send_keys(password)

        submit_locator = (
            By.XPATH,
            "//input[@class='button button-primary'][@value='Sign In']",
        )
        submit = waiter(driver, wait_step).until(
            expected_conditions.element_to_be_clickable(submit_locator)
        )
        submit.click()
class PingIdentityAuthenticator:
    """Drives a PingOne (pingone.eu) sign-in form through Selenium."""

    @classmethod
    def check(cls, url, driver):
        """Return True if "pingone.eu" is in ``url`` or in the lower-cased page source."""
        marker = "pingone.eu"
        return marker in url or marker in driver.page_source.lower()

    @classmethod
    def run(cls, driver, username: str, password: str, wait_step: int):
        """Fill in both credential fields and press the submit button.

        Steps run strictly in order; each element is awaited for up to
        ``wait_step`` seconds until it is clickable.
        """
        steps = [
            ("//input[@name='username']", lambda element: element.send_keys(username)),
            ("//input[@name='password']", lambda element: element.send_keys(password)),
            ("//button[@type='submit']", lambda element: element.click()),
        ]
        for xpath, action in steps:
            condition = expected_conditions.element_to_be_clickable((By.XPATH, xpath))
            action(WebDriverWait(driver, wait_step).until(condition))
class MSEntraAuthenticator:
    """Drives the two-step Microsoft Entra (microsoftonline) login through Selenium."""

    @classmethod
    def check(cls, url, driver):
        """Return True if "microsoftonline" appears in ``url``; ``driver`` is not consulted."""
        return "microsoftonline" in url

    @classmethod
    def run(cls, driver, username: str, password: str, wait_step: int):
        """Submit the username page, then the password page.

        Every element is awaited for up to ``wait_step`` seconds until clickable.
        """

        def clickable(xpath):
            # Block until the element located by `xpath` can be interacted with.
            return WebDriverWait(driver, wait_step).until(
                expected_conditions.element_to_be_clickable((By.XPATH, xpath))
            )

        # Step 1: username page.
        clickable("//input[@name='loginfmt']").send_keys(username)
        clickable("//input[@type='submit'][@value='Next']").click()

        # Step 2: password page.
        clickable("//input[@name='passwd']").send_keys(password)
        clickable("//input[@type='submit'][@value='Sign in']").click()
class OIDCAuthenticator:
    """Automates an IdP login page in headless Chrome.

    The constructor starts a Chrome WebDriver and registers the known
    provider-specific authenticators. The driver is exposed as ``self.driver``;
    this class does not quit it itself.
    """

    def __init__(self):
        self.driver = webdriver.Chrome(
            service=Service(ChromeDriverManager().install()),
            options=self._get_chrome_options(),
        )
        # Checked in insertion order by _get_authenticator().
        self.auth_methods = {
            "okta": OktaAuthenticator,
            "keycloak": KeycloakAuthenticator,
            "ping": PingIdentityAuthenticator,
            "entra": MSEntraAuthenticator,
        }

    def _get_chrome_options(self):
        """Build the Chrome options used for the headless browser session."""
        options = Options()
        options.add_argument("--headless")
        options.add_argument("--no-sandbox")
        options.add_argument("--disable-dev-shm-usage")
        options.add_argument("--disable-gpu")
        options.add_argument("--ignore-certificate-errors")
        return options

    def _get_authenticator(self, url, driver):
        """Return the first registered authenticator whose ``check`` accepts the page.

        Returns ``None`` when no registered authenticator matches.
        """
        for auth_class in self.auth_methods.values():
            if auth_class.check(url, driver):
                return auth_class
        # Only reached after every candidate has been tried.
        return None

    def authenticate(self, url, username, password):
        """Open ``url``, log in with the matching authenticator and await the success page.

        Errors are reported on stdout (message plus traceback) instead of
        being raised. "Success" is printed only when the landing page with
        the "Authentication Successful" heading was reached.
        """
        self.driver.get(url)
        try:
            authenticator = self._get_authenticator(url, self.driver)
            if authenticator is None:
                raise Exception("No valid authenticator found for the URL.")
            authenticator.run(self.driver, username, password, 5)
            # `until` raises on timeout, so no extra assertion on the
            # returned element is needed.
            WebDriverWait(self.driver, 30).until(
                expected_conditions.presence_of_element_located(
                    (By.XPATH, "//h1[contains(text(), 'Authentication Successful')]")
                )
            )
        except Exception as e:
            print("Error: ", e)
            print("Traceback: ", traceback.format_exc())
        else:
            print("Success")
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from pymongo import MongoClient | |
from pymongo.auth_oidc_shared import ( | |
OIDCCallback, | |
OIDCCallbackContext, | |
OIDCCallbackResult, | |
) | |
import webbrowser | |
import secrets | |
import hashlib | |
import base64 | |
from urllib.parse import urlencode, urlparse, parse_qs | |
from http.server import BaseHTTPRequestHandler, HTTPServer | |
from concurrent.futures import Future | |
import threading | |
import requests | |
from requests.auth import HTTPBasicAuth | |
import json | |
import code | |
import readline | |
import rlcompleter | |
import shutil | |
import subprocess | |
import getpass | |
import traceback | |
from oidc_authenticator import OIDCAuthenticator | |
def tamper_jwt(jwt_token: str, extra_claims: dict) -> str:
    """Return ``jwt_token`` with ``extra_claims`` merged into its payload.

    The header and signature segments are copied verbatim, so the signature
    no longer matches the modified payload (the token is deliberately invalid).
    """
    header_part, payload_part, signature_part = jwt_token.split(".")

    # base64url payloads come without padding; restore it before decoding.
    padding = "=" * (-len(payload_part) % 4)
    claims = json.loads(base64.urlsafe_b64decode(payload_part + padding).decode())
    claims.update(extra_claims)

    # Re-serialize compactly and strip the padding again.
    serialized = json.dumps(claims, separators=(",", ":")).encode()
    forged_payload = base64.urlsafe_b64encode(serialized).rstrip(b"=").decode()

    return ".".join((header_part, forged_payload, signature_part))
def HTMLResponse(query_params, full_path, headers):
    """Render the HTML page shown after the OIDC redirect reaches the callback.

    Args:
        query_params: Mapping of query key to list of values (as produced by
            ``parse_qs``); rendered as table rows.
        full_path: The request path, shown verbatim on the page.
        headers: Request headers already formatted as one text block.

    Returns:
        The complete HTML document as a string.

    All three inputs come straight from the incoming HTTP request, so they
    are HTML-escaped before being placed into the page.
    """
    from html import escape

    HTML_TEMPLATE = """
<html>
<head>
    <title>Authentication Complete</title>
    <style>
        body {{
            font-family: Arial, sans-serif;
            margin: 20px;
            background-color: #f9f9f9;
        }}
        h1 {{
            color: #333;
        }}
        table {{
            border-collapse: collapse;
            width: 100%;
            margin-top: 20px;
        }}
        th, td {{
            border: 1px solid #ddd;
            padding: 8px;
            text-align: left;
        }}
        th {{
            background-color: #f2f2f2;
            color: #333;
        }}
        tr:nth-child(even) {{
            background-color: #f9f9f9;
        }}
        tr:hover {{
            background-color: #f1f1f1;
        }}
        pre {{
            background-color: #f4f4f4;
            padding: 10px;
            border: 1px solid #ddd;
            overflow: auto;
        }}
    </style>
    <script>
        let secondsRemaining = 5;
        let countdownInterval;
        function updateCountdown() {{
            const countdownElement = document.getElementById('countdown');
            countdownElement.textContent = secondsRemaining;
            if (secondsRemaining > 0) {{
                secondsRemaining--;
            }} else {{
                clearInterval(countdownInterval);
                window.close();
            }}
        }}
        function stopCountdown() {{
            clearInterval(countdownInterval);
            const countdownDivElement = document.getElementById('countdown_div');
            countdownDivElement.style.display = 'none';
        }}
        window.onload = () => {{
            countdownInterval = setInterval(updateCountdown, 1000);
        }};
    </script>
</head>
<body>
    <h1>Authentication Successful</h1>
    <p>Received the following query parameters:</p>
    <table>
        <tr><th>Key</th><th>Value</th></tr>
        {table_rows}
    </table>
    <h2>Request Details</h2>
    <p><strong>Full Path:</strong> {full_path}</p>
    <h3>Headers:</h3>
    <pre>{headers}</pre>
    <div id="countdown_div">
        <p>This tab will close automatically in <span id="countdown">5</span> seconds.</p>
        <button onclick="stopCountdown()">Stop Countdown</button>
    </div>
</body>
</html>
"""
    # Escape every request-derived value so it cannot inject markup or script.
    table_rows = "".join(
        f"<tr><td>{escape(str(key))}</td>"
        f"<td>{escape(', '.join(value))}</td></tr>"
        for key, value in query_params.items()
    )
    return HTML_TEMPLATE.format(
        table_rows=table_rows,
        full_path=escape(str(full_path)),
        headers=escape(str(headers)),
    )
class CallbackHandler(BaseHTTPRequestHandler):
    """Handles the redirect request and hands its query parameters to the server's future."""

    def do_GET(self):
        """Answer 404 for any path other than the allowed one; otherwise show
        the result page and resolve the server's future with the parsed query."""
        target = urlparse(self.path)

        on_callback_path = target.path == self.server.allowed_path
        if not on_callback_path:
            self.send_response(404)
            self.send_header("Content-type", "text/plain")
            self.end_headers()
            self.wfile.write(b"404 Not Found")
            return

        received_params = parse_qs(target.query)
        header_dump = "\n".join(
            f"{key}: {value}" for key, value in self.headers.items()
        )

        self.send_response(200)
        self.send_header("Content-type", "text/html")
        self.end_headers()
        page = HTMLResponse(received_params, self.path, header_dump)
        self.wfile.write(page.encode())

        # Only the first callback resolves the future; later ones are ignored.
        pending = self.server.future
        if not pending.done():
            pending.set_result(parse_qs(target.query))
class OneShotHTTPServer(HTTPServer):
    """HTTPServer that carries the result future and the single accepted path."""

    def __init__(self, server_address, RequestHandlerClass, future, allowed_path):
        """Bind the server and remember ``future`` and ``allowed_path`` for the handler."""
        HTTPServer.__init__(self, server_address, RequestHandlerClass)
        # Read by the request handler through `self.server`.
        self.future, self.allowed_path = future, allowed_path
def start_server(server_address, timeout: int, allowed_path: str) -> Future:
    """Start the local callback server in a daemon thread.

    Args:
        server_address: ``(host, port)`` tuple to bind to.
        timeout: Overall number of seconds to wait for the callback; ``None``
            waits indefinitely.
        allowed_path: The only request path accepted as the callback.

    Returns:
        A Future resolved with the parsed query parameters of the callback
        request, or with ``None`` if no callback arrived before the timeout.

    Requests that do not resolve the future (for example a browser asking for
    a different path, which the handler answers with 404) no longer use up
    the single chance to receive the callback: the server keeps handling
    requests until the future is resolved or the overall deadline passes.
    """
    import time

    future = Future()
    server = OneShotHTTPServer(server_address, CallbackHandler, future, allowed_path)
    deadline = None if timeout is None else time.monotonic() + timeout

    def serve():
        try:
            while not future.done():
                if deadline is not None:
                    remaining = deadline - time.monotonic()
                    if remaining <= 0:
                        break
                    # Each wait is capped by the time left overall.
                    server.timeout = remaining
                server.handle_request()
        finally:
            # Always release the socket and unblock the waiter.
            server.server_close()
            if not future.done():
                future.set_result(None)

    thread = threading.Thread(target=serve, daemon=True)
    thread.start()
    return future
def decode_jwt(token):
    """Decode a compact JWT into its header and payload for display.

    Returns:
        ``{"header": dict, "payload": dict, "signature": str}`` when ``token``
        is a three-segment JWT with JSON header and payload; otherwise the
        ``token`` argument is returned unchanged (best-effort behavior, e.g.
        for opaque access tokens).

    The signature is not verified.
    """

    def decode_part(part):
        padded = part + "=" * (-len(part) % 4)  # Fix padding
        decoded_bytes = base64.urlsafe_b64decode(padded)
        return json.loads(decoded_bytes)

    try:
        header, payload, signature = token.split(".")
        return {
            "header": decode_part(header),
            "payload": decode_part(payload),
            "signature": signature,
        }
    # ValueError covers wrong segment count, bad base64, bad UTF-8 and bad
    # JSON; TypeError/AttributeError cover non-string input. A bare except
    # would also swallow KeyboardInterrupt and SystemExit.
    except (ValueError, TypeError, AttributeError):
        return token
def encode_jwt(jwt):
    """Serialize a decoded JWT dict back into compact ``header.payload.signature`` form.

    ``jwt["header"]`` and ``jwt["payload"]`` are JSON-encoded compactly and
    base64url-encoded without padding; ``jwt["signature"]`` is appended as is.
    """
    segments = []
    for section in ("header", "payload"):
        raw = json.dumps(jwt[section], separators=(",", ":")).encode()
        segments.append(base64.urlsafe_b64encode(raw).decode().rstrip("="))
    # The signature is reused untouched, not recomputed.
    segments.append(jwt["signature"])
    return ".".join(segments)
def start_authenticator(uri, username, password):
    """Log in at ``uri`` with the automated browser, then shut that browser down.

    A new OIDCAuthenticator (and with it a Chrome WebDriver) is created per
    call; the driver is quit afterwards so no browser process is left behind,
    even when authentication raises.
    """
    auth = OIDCAuthenticator()
    try:
        auth.authenticate(uri, username, password)
    finally:
        auth.driver.quit()
    return
def start_browser(uri):
    """Ask the user to confirm, then open ``uri`` in a browser for interactive login.

    Chrome (``google-chrome`` on PATH) is launched in an incognito window when
    available; otherwise the system default browser is used.
    """
    print(
        "You will be redirected to your browser to log in. Please complete the authentication process."
    )
    input("Press Enter to continue...")

    chrome_path = shutil.which("google-chrome")
    if not chrome_path:
        print("Chrome not found. Falling back to the default browser.")
        webbrowser.open(uri)
        return

    print("Launching Chrome in incognito mode...")
    command = [chrome_path, "--new-window", "--incognito", uri]
    subprocess.Popen(command)
def read_file(path):
    """Return the text content of ``path`` with surrounding whitespace stripped."""
    with open(path, "r") as handle:
        contents = handle.read()
    return contents.strip()
def write_file(path, content):
    """Write ``content`` to ``path`` as text, replacing any existing file."""
    handle = open(path, "w")
    with handle:
        handle.write(content)
class MyCallback(OIDCCallback):
    """pymongo OIDC callback supporting three ways of obtaining a token.

    * ``client_auth_config`` given: a token is fetched at construction time
      with the client-credentials grant (and written to ``jwt_path`` if set).
    * only ``jwt_path`` given: the token is read from that file.
    * neither given: ``fetch`` runs the authorization-code flow with PKCE,
      using either the automated browser (``credentials``) or the user's
      browser.
    """

    def __init__(
        self,
        redirect_uri: str,
        credentials: tuple = None,
        use_id_token=False,
        tamper_jwt=False,
        jwt_path: str = None,
        client_auth_config: str = None,
    ):
        super().__init__()
        self.credentials = credentials
        self.tamper_jwt = tamper_jwt
        # Must be set before client_auth() runs, which reads it.
        self.use_id_token = use_id_token
        self.uri = urlparse(redirect_uri)
        if client_auth_config:
            self.jwt = self.client_auth(client_auth_config)
            if jwt_path:
                write_file(jwt_path, self.jwt)
        elif jwt_path:
            self.jwt = read_file(jwt_path)
        else:
            self.jwt = None

    def client_auth(self, client_auth_config: str):
        """Obtain a token with the client-credentials grant.

        Args:
            client_auth_config: Path to a JSON file with ``issuer``,
                ``client_id``, ``client_secret`` and optionally ``scope``.

        Returns:
            The raw access token, or the ID token when ``use_id_token`` is set.

        Raises:
            ValueError: If issuer, client id or client secret is missing.
            requests.HTTPError: If the metadata or token request fails.
        """
        # Context manager so the config file handle is closed promptly.
        with open(client_auth_config, "r") as config_file:
            data = json.load(config_file)
        issuer = data.get("issuer")
        client_id = data.get("client_id")
        client_secret = data.get("client_secret")
        scope = data.get("scope")
        if not issuer or not client_id or not client_secret:
            raise ValueError("Invalid client authentication configuration")
        # Shared helper: same request as before, plus an HTTP status check.
        oath_metadata = self.get_oath_metadata(issuer)
        response = requests.post(
            oath_metadata["token_endpoint"],
            data={
                "grant_type": "client_credentials",
                "scope": scope,
            },
            auth=HTTPBasicAuth(client_id, client_secret),
        )
        response.raise_for_status()  # Raise an error if the request fails
        token_result = response.json()
        access_token = decode_jwt(token_result["access_token"])
        print("Access Token:", json.dumps(access_token, indent=1))
        return (
            token_result["access_token"]
            if not self.use_id_token
            else token_result["id_token"]
        )

    def get_uri(
        self, uri, clientId, code_verifier, nonce, scopes, state="somestate"
    ) -> str:
        """Build the authorization request URL (code flow with PKCE, S256).

        Args:
            uri: Authorization endpoint.
            clientId: OAuth client id.
            code_verifier: PKCE verifier; its SHA-256 digest is sent as challenge.
            nonce: Value for the ``nonce`` parameter.
            scopes: Extra scopes appended to ``openid profile offline_access``.
            state: Value for the ``state`` parameter. The default keeps the
                previous fixed value for existing callers; ``run_code_flow``
                passes a random one.

        Returns:
            The endpoint URL with the encoded query string.
        """
        code_challenge = (
            base64.urlsafe_b64encode(hashlib.sha256(code_verifier.encode()).digest())
            .rstrip(b"=")
            .decode()
        )
        scope = "openid profile offline_access " + " ".join(scopes)
        print("Requesting scopes:", scope)
        params = {
            "client_id": clientId,
            "response_type": "code",
            "scope": scope,
            "redirect_uri": self.uri.geturl(),
            "state": state,
            "code_challenge": code_challenge,
            "code_challenge_method": "S256",
            "nonce": nonce,
        }
        return f"{uri}?{urlencode(params)}"

    def run_code_flow(self, auth_uri, context, code_verifier, nonce):
        """Run the browser part of the flow and return the callback parameters.

        Starts the local callback server, opens the authorization URL (with
        the automated authenticator when credentials were given, otherwise in
        the user's browser) and waits for the redirect.

        Returns:
            The parsed callback query (dict of key to list of values).

        Raises:
            Exception: On timeout, on an ``error`` parameter, on a missing
                ``code``/``state``, or when the returned ``state`` differs
                from the one sent.
        """
        future = start_server(
            (self.uri.hostname, self.uri.port), context.timeout_seconds, self.uri.path
        )
        # Fresh unguessable state per attempt so a forged callback is rejected.
        state = secrets.token_urlsafe(32)
        uri = self.get_uri(
            auth_uri,
            context.idp_info.clientId,
            code_verifier,
            nonce,
            scopes=context.idp_info.requestScopes or [],
            state=state,
        )
        if self.credentials:
            start_authenticator(uri, self.credentials[0], self.credentials[1])
        else:
            start_browser(uri)
        code_result = future.result()
        if code_result is None:
            raise Exception("Timeout waiting for callback")
        if "error" in code_result:
            raise Exception(f"Error in callback: {code_result['error'][0]}")
        if "code" not in code_result:
            raise Exception("Authorization code not found in callback")
        if "state" not in code_result:
            raise Exception("State not found in callback")
        if code_result["state"][0] != state:
            raise Exception("State mismatch in callback")
        return code_result

    def get_oath_metadata(self, issuer):
        """Fetch and return the issuer's OpenID discovery document.

        NOTE: TLS certificate verification is disabled for this request.
        """
        response = requests.get(
            f"{issuer}/.well-known/openid-configuration", verify=False
        )
        response.raise_for_status()
        return response.json()

    def get_jwks(self, uri):
        """Fetch and return the JSON document at ``uri`` (the JWKS endpoint).

        NOTE: TLS certificate verification is disabled for this request.
        """
        response = requests.get(uri, verify=False)
        response.raise_for_status()
        return response.json()

    def get_tokens(self, uri, context, code, code_verifier):
        """Exchange the authorization code for tokens at the token endpoint.

        Returns:
            ``(id_token, access_token, refresh_token, expires_in, token_result)``
            where the first two are decoded for display (see ``decode_jwt``)
            and ``token_result`` is the raw JSON response.

        NOTE: TLS certificate verification is disabled for this request, and
        the raw token response is printed.
        """
        response = requests.post(
            uri,
            data={
                "grant_type": "authorization_code",
                "code": code,
                "redirect_uri": self.uri.geturl(),
                "client_id": context.idp_info.clientId,
                "code_verifier": code_verifier,
            },
            verify=False,
        )
        response.raise_for_status()  # Raise an error if the request fails
        token_result = response.json()
        print(token_result)
        id_token = decode_jwt(token_result["id_token"])
        access_token = decode_jwt(token_result["access_token"])
        refresh_token = token_result.get("refresh_token")
        expires_in = token_result.get("expires_in")
        return (id_token, access_token, refresh_token, expires_in, token_result)

    def fetch(self, context: OIDCCallbackContext) -> OIDCCallbackResult:
        """Return the token for the driver.

        When a token was obtained at construction time it is returned
        directly. Otherwise the full authorization-code flow is run, with
        every intermediate value printed, and the resulting access or ID
        token (optionally tampered with) is returned.
        """
        if self.jwt is not None:
            return OIDCCallbackResult(
                access_token=self.jwt,
                expires_in_seconds=None,
                refresh_token=None,
            )
        code_verifier = secrets.token_urlsafe(64)
        print("Code verifier:", code_verifier)
        nonce = secrets.token_urlsafe(64)
        print("Nonce:", nonce)
        print(context)
        print(context.idp_info)
        oath_metadata = self.get_oath_metadata(context.idp_info.issuer)
        print("OAuth Metadata:", json.dumps(oath_metadata, indent=1))
        jwks = self.get_jwks(oath_metadata["jwks_uri"])
        print("JWKS:", json.dumps(jwks, indent=1))
        code_result = self.run_code_flow(
            oath_metadata["authorization_endpoint"], context, code_verifier, nonce
        )
        print("Authorization code:", code_result["code"][0])
        print("State:", code_result["state"][0])
        id_token, access_token, refresh_token, expires_in, token_result = (
            self.get_tokens(
                oath_metadata["token_endpoint"],
                context,
                code_result["code"][0],
                code_verifier,
            )
        )
        print("Access Token:", json.dumps(access_token, indent=1))
        print("Refresh Token:", json.dumps(refresh_token, indent=1))
        print("ID Token:", json.dumps(id_token, indent=1))
        print("Expires In:", expires_in)
        token_jwt = (
            token_result["access_token"]
            if not self.use_id_token
            else token_result["id_token"]
        )
        if self.tamper_jwt:
            # Resolves to the module-level tamper_jwt() function; the flag
            # of the same name lives on the instance.
            token_jwt = tamper_jwt(
                token_jwt,
                {"auth_claim": ["root", "admin"], "MyClaim": ["root", "admin"]},
            )
            print(
                "Tampered JWT: ",
                json.dumps(decode_jwt(token_jwt), indent=1),
            )
        print("Token exchange successful")
        return OIDCCallbackResult(
            access_token=token_jwt,
            expires_in_seconds=expires_in,
            refresh_token=refresh_token,
        )
def build_parser():
    """Create the command-line parser for the example script.

    ``--tls`` accepts an explicit boolean word (``true``/``false``, ``yes``/
    ``no``, ``1``/``0``, ``on``/``off``) or no value at all (meaning true).
    With ``type=bool`` any non-empty string, including ``False``, was
    parsed as True.
    """
    import argparse

    def _str_to_bool(value):
        # Convert a command-line word into a real boolean.
        normalized = value.strip().lower()
        if normalized in ("1", "true", "t", "yes", "y", "on"):
            return True
        if normalized in ("", "0", "false", "f", "no", "n", "off"):
            return False
        raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")

    parser = argparse.ArgumentParser(
        description="MongoDB OIDC Authentication Example",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "--host",
        default="localhost",
        help="MongoDB host",
    )
    parser.add_argument(
        "--port",
        type=int,
        default=27017,
        help="MongoDB port",
    )
    parser.add_argument(
        "--username",
        type=str,
        default=None,
        help="Username for MongoDB authentication mechanism",
    )
    parser.add_argument(
        "--login",
        type=str,
        default=None,
        help="<username>:<password> for IdP login",
    )
    parser.add_argument(
        "--redirect_uri",
        default="http://localhost:27097/redirect",
        help="Redirect URI for OIDC callback",
    )
    parser.add_argument(
        "--tls",
        type=_str_to_bool,
        nargs="?",
        const=True,
        default=False,
        help="Use TLS for MongoDB connection",
    )
    parser.add_argument(
        "-i",
        "--interactive",
        action="store_true",
        help="Run in interactive mode",
    )
    parser.add_argument(
        "--tamper-jwt",
        action="store_true",
        help="Tamper JWT to add some additional claims",
    )
    parser.add_argument(
        "--use-id-token",
        action="store_true",
        help="Use ID Token instead of access token",
    )
    parser.add_argument(
        "--jwt-path",
        type=str,
        help="Path to JWT token for non-human authentication flows",
    )
    parser.add_argument(
        "--client-auth-config",
        type=str,
        help="Path to client authentication configuration file",
    )
    return parser
# Banner passed to code.interact() when the script runs in interactive mode.
SHELL_BANNER = """
Running in interactive mode. Type 'exit()' to quit.
You can use the 'client' object to interact with MongoDB.
Tab completion is enabled.
"""
def parse_login(login):
    """Split a ``<username>:<password>`` argument into its two parts.

    Args:
        login: ``None``, a bare username, or ``username:password``.

    Returns:
        ``None`` when ``login`` is ``None``; otherwise ``(username, password)``.
        A bare username triggers an interactive password prompt.

    Only the first ":" separates the fields, so passwords may themselves
    contain ":" (previously such input was rejected with ValueError).
    """
    if login is None:
        return None
    username, separator, password = login.partition(":")
    if not separator:
        password = getpass.getpass(f"Enter password for '{username}': ")
    return username, password
def main():
    """Parse the arguments, connect to MongoDB with MONGODB-OIDC and report the result.

    The human callback is used unless a JWT path or client-auth config was
    given. With ``--interactive`` a Python shell with the ``client`` object
    is opened after a successful connection.

    The client is always closed, including on failure, and a failed
    connection exits with status 1 (previously it exited with status 0
    and left the client open).
    """
    import sys

    parser = build_parser()
    args = parser.parse_args()
    print("MongoDB OIDC Authentication Example")
    print("Starting MongoDB client...")
    use_human_flows = args.jwt_path is None and args.client_auth_config is None
    callback = MyCallback(
        redirect_uri=args.redirect_uri,
        credentials=parse_login(args.login),
        use_id_token=args.use_id_token,
        tamper_jwt=args.tamper_jwt,
        jwt_path=args.jwt_path,
        client_auth_config=args.client_auth_config,
    )
    authMechanismProperties = {}
    if use_human_flows:
        authMechanismProperties["OIDC_HUMAN_CALLBACK"] = callback
    else:
        authMechanismProperties["OIDC_CALLBACK"] = callback
    client = MongoClient(
        host=args.host,
        port=args.port,
        connect=True,
        tls=args.tls,
        authMechanism="MONGODB-OIDC",
        username=args.username,
        authMechanismProperties=authMechanismProperties,
    )
    try:
        try:
            status = client.admin.command({"connectionStatus": 1})
            print(f"Connection status: {json.dumps(status, indent=1)}")
        except Exception as e:
            print("Connection failed:", e)
            tb = traceback.format_exc()
            print("Traceback:\n", tb)
            # sys.exit rather than the site-provided exit(); non-zero
            # status because the connection did not succeed.
            sys.exit(1)
        print("Connection successful")
        if args.interactive:
            readline.set_completer(rlcompleter.Completer(globals()).complete)
            readline.parse_and_bind("tab: complete")
            # Expose module globals and this function's locals (notably
            # `client`) inside the interactive shell.
            variables = globals().copy()
            variables.update(locals())
            code.interact(banner=SHELL_BANNER, local=variables)
        print("Exiting...")
    finally:
        client.close()
# Run the example only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
pymongo | |
requests | |
selenium | |
webdriver-manager |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment