Skip to content

Instantly share code, notes, and snippets.

@plebioda
Last active June 3, 2025 09:04
Show Gist options
  • Save plebioda/244b09aa24d474dd36048b4db7c9bd83 to your computer and use it in GitHub Desktop.
Save plebioda/244b09aa24d474dd36048b4db7c9bd83 to your computer and use it in GitHub Desktop.
pymongo with OIDC example
from selenium import webdriver
from selenium.webdriver.chrome.service import Service
from selenium.webdriver.chrome.options import Options
from webdriver_manager.chrome import ChromeDriverManager
from selenium.webdriver.support.ui import WebDriverWait
from selenium.webdriver.support import expected_conditions
from selenium.webdriver.common.by import By
import traceback
class KeycloakAuthenticator:
    """Fill in and submit a Keycloak username/password login form via Selenium."""

    @classmethod
    def check(cls, url, driver):
        """Return True when *url* or the rendered page mentions Keycloak."""
        if "keycloak" in url:
            return True
        # The URL may not name the product, so fall back to sniffing the page.
        return "keycloak" in driver.page_source.lower()

    @classmethod
    def _wait_clickable(cls, driver, wait_step, xpath):
        """Wait up to *wait_step* seconds for *xpath* to become clickable; return it."""
        condition = expected_conditions.element_to_be_clickable((By.XPATH, xpath))
        return WebDriverWait(driver, wait_step).until(condition)

    @classmethod
    def run(cls, driver, username: str, password: str, wait_step: int):
        """Type the credentials into the form and press the login button."""
        fields = (
            ("//input[@name='username']", username),
            ("//input[@name='password']", password),
        )
        for xpath, text in fields:
            cls._wait_clickable(driver, wait_step, xpath).send_keys(text)
        cls._wait_clickable(driver, wait_step, "//button[@name='login']").click()
class OktaAuthenticator:
    """Fill in and submit an Okta username/password sign-in form via Selenium."""

    @classmethod
    def check(cls, url, driver):
        """Return True when the URL points at an Okta host (the page is not inspected)."""
        return "okta" in url

    @classmethod
    def _wait_clickable(cls, driver, wait_step, xpath):
        """Wait up to *wait_step* seconds for *xpath* to become clickable; return it."""
        condition = expected_conditions.element_to_be_clickable((By.XPATH, xpath))
        return WebDriverWait(driver, wait_step).until(condition)

    @classmethod
    def run(cls, driver, username: str, password: str, wait_step: int):
        """Type the credentials into the form and press "Sign In"."""
        fields = (
            ("//input[@name='username']", username),
            ("//input[@name='password']", password),
        )
        for xpath, text in fields:
            cls._wait_clickable(driver, wait_step, xpath).send_keys(text)
        submit_xpath = "//input[@class='button button-primary'][@value='Sign In']"
        cls._wait_clickable(driver, wait_step, submit_xpath).click()
class PingIdentityAuthenticator:
    """Fill in and submit a PingOne (EU region) username/password form via Selenium."""

    @classmethod
    def check(cls, url, driver):
        """Return True when *url* or the rendered page references pingone.eu."""
        if "pingone.eu" in url:
            return True
        return "pingone.eu" in driver.page_source.lower()

    @classmethod
    def _wait_clickable(cls, driver, wait_step, xpath):
        """Wait up to *wait_step* seconds for *xpath* to become clickable; return it."""
        condition = expected_conditions.element_to_be_clickable((By.XPATH, xpath))
        return WebDriverWait(driver, wait_step).until(condition)

    @classmethod
    def run(cls, driver, username: str, password: str, wait_step: int):
        """Type the credentials into the form and press the submit button."""
        fields = (
            ("//input[@name='username']", username),
            ("//input[@name='password']", password),
        )
        for xpath, text in fields:
            cls._wait_clickable(driver, wait_step, xpath).send_keys(text)
        cls._wait_clickable(driver, wait_step, "//button[@type='submit']").click()
class MSEntraAuthenticator:
    """Walk the two-page Microsoft Entra ID sign-in (username, Next, password, Sign in)."""

    @classmethod
    def check(cls, url, driver):
        """Return True when the URL points at login.microsoftonline.* (page not inspected)."""
        return "microsoftonline" in url

    @classmethod
    def _wait_clickable(cls, driver, wait_step, xpath):
        """Wait up to *wait_step* seconds for *xpath* to become clickable; return it."""
        condition = expected_conditions.element_to_be_clickable((By.XPATH, xpath))
        return WebDriverWait(driver, wait_step).until(condition)

    @classmethod
    def run(cls, driver, username: str, password: str, wait_step: int):
        """Perform each sign-in step in order, waiting for every element first."""
        # (xpath, action) pairs; the username and password pages are separate screens.
        steps = (
            ("//input[@name='loginfmt']", lambda element: element.send_keys(username)),
            ("//input[@type='submit'][@value='Next']", lambda element: element.click()),
            ("//input[@name='passwd']", lambda element: element.send_keys(password)),
            ("//input[@type='submit'][@value='Sign in']", lambda element: element.click()),
        )
        for xpath, action in steps:
            action(cls._wait_clickable(driver, wait_step, xpath))
class OIDCAuthenticator:
    """Drive a headless Chrome through an IdP login page to finish an OIDC redirect.

    The IdP is detected from the authorization URL (or the rendered page) and
    the matching provider class fills in the username/password form.  The
    login counts as successful once the browser shows a page whose ``<h1>``
    contains "Authentication Successful".

    The browser is started in ``__init__``; call :meth:`close` (or use the
    instance as a context manager) to shut it down.
    """

    def __init__(self):
        self.driver = webdriver.Chrome(
            service=Service(ChromeDriverManager().install()),
            options=self._get_chrome_options(),
        )
        # Checked in insertion order; the first provider whose check() matches wins.
        self.auth_methods = {
            "okta": OktaAuthenticator,
            "keycloak": KeycloakAuthenticator,
            "ping": PingIdentityAuthenticator,
            "entra": MSEntraAuthenticator,
        }

    def _get_chrome_options(self):
        """Build Chrome options for an unattended, headless run."""
        options = Options()
        for argument in (
            "--headless",
            "--no-sandbox",
            "--disable-dev-shm-usage",
            "--disable-gpu",
            "--ignore-certificate-errors",
        ):
            options.add_argument(argument)
        return options

    def _get_authenticator(self, url, driver):
        """Return the first provider class whose ``check`` matches, else ``None``."""
        return next(
            (
                auth_class
                for auth_class in self.auth_methods.values()
                if auth_class.check(url, driver)
            ),
            None,
        )

    def authenticate(self, url, username, password, wait_step=5, landing_timeout=30):
        """Open *url*, log in with the given credentials, and wait for the landing page.

        Args:
            url: the IdP authorization URL to open.
            username, password: credentials typed into the IdP form.
            wait_step: seconds to wait for each form element to become clickable.
            landing_timeout: seconds to wait for the "Authentication Successful" page.

        Returns:
            True if the landing page appeared, False if any step failed.  Failures
            are printed with a traceback rather than raised; only an error from
            the initial page load propagates.
        """
        self.driver.get(url)
        try:
            authenticator = self._get_authenticator(url, self.driver)
            if authenticator is None:
                raise RuntimeError("No valid authenticator found for the URL.")
            authenticator.run(self.driver, username, password, wait_step)
            # until() raises on timeout, so reaching the next line means success.
            WebDriverWait(self.driver, landing_timeout).until(
                expected_conditions.presence_of_element_located(
                    (By.XPATH, "//h1[contains(text(), 'Authentication Successful')]")
                )
            )
        except Exception as e:
            print("Error: ", e)
            print("Traceback: ", traceback.format_exc())
            return False
        print("Success")
        return True

    def close(self):
        """Quit the browser session started in ``__init__``."""
        self.driver.quit()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, tb):
        self.close()
from pymongo import MongoClient
from pymongo.auth_oidc_shared import (
OIDCCallback,
OIDCCallbackContext,
OIDCCallbackResult,
)
import webbrowser
import secrets
import hashlib
import base64
from urllib.parse import urlencode, urlparse, parse_qs
from http.server import BaseHTTPRequestHandler, HTTPServer
from concurrent.futures import Future
import threading
import requests
from requests.auth import HTTPBasicAuth
import json
import code
import readline
import rlcompleter
import shutil
import subprocess
import getpass
import traceback
from oidc_authenticator import OIDCAuthenticator
def tamper_jwt(jwt_token: str, extra_claims: dict) -> str:
    """Return *jwt_token* with *extra_claims* merged into its payload.

    The header and signature segments are copied verbatim, so the returned
    token's signature no longer matches its payload (presumably so that a
    server's rejection of forged tokens can be exercised).
    """
    head_seg, body_seg, sig_seg = jwt_token.split(".")

    # base64url segments in a JWT are unpadded; restore padding before decoding.
    padded_body = body_seg + "=" * (-len(body_seg) % 4)
    claims = json.loads(base64.urlsafe_b64decode(padded_body).decode())
    claims.update(extra_claims)

    compact_json = json.dumps(claims, separators=(",", ":")).encode()
    new_body_seg = base64.urlsafe_b64encode(compact_json).rstrip(b"=").decode()
    return ".".join((head_seg, new_body_seg, sig_seg))
def HTMLResponse(query_params, full_path, headers):
    """Render the HTML page returned to the browser on the OIDC redirect.

    Args:
        query_params: mapping of parameter name to list of values, as produced
            by ``urllib.parse.parse_qs``.
        full_path: the raw request path, including the query string.
        headers: request headers pre-formatted as one ``Key: value`` per line.

    Returns:
        The complete HTML document as a ``str``.

    Every interpolated value comes straight from the incoming HTTP request,
    which anyone can craft (e.g. a link to the local redirect URL), so each
    one is HTML-escaped to prevent script injection into this page.
    """
    import html

    HTML_TEMPLATE = """
<html>
<head>
<title>Authentication Complete</title>
<style>
body {{
font-family: Arial, sans-serif;
margin: 20px;
background-color: #f9f9f9;
}}
h1 {{
color: #333;
}}
table {{
border-collapse: collapse;
width: 100%;
margin-top: 20px;
}}
th, td {{
border: 1px solid #ddd;
padding: 8px;
text-align: left;
}}
th {{
background-color: #f2f2f2;
color: #333;
}}
tr:nth-child(even) {{
background-color: #f9f9f9;
}}
tr:hover {{
background-color: #f1f1f1;
}}
pre {{
background-color: #f4f4f4;
padding: 10px;
border: 1px solid #ddd;
overflow: auto;
}}
</style>
<script>
let secondsRemaining = 5;
let countdownInterval;
function updateCountdown() {{
const countdownElement = document.getElementById('countdown');
countdownElement.textContent = secondsRemaining;
if (secondsRemaining > 0) {{
secondsRemaining--;
}} else {{
clearInterval(countdownInterval);
window.close();
}}
}}
function stopCountdown() {{
clearInterval(countdownInterval);
const countdownDivElement = document.getElementById('countdown_div');
countdownDivElement.style.display = 'none';
}}
window.onload = () => {{
countdownInterval = setInterval(updateCountdown, 1000);
}};
</script>
</head>
<body>
<h1>Authentication Successful</h1>
<p>Received the following query parameters:</p>
<table>
<tr><th>Key</th><th>Value</th></tr>
{table_rows}
</table>
<h2>Request Details</h2>
<p><strong>Full Path:</strong> {full_path}</p>
<h3>Headers:</h3>
<pre>{headers}</pre>
<div id="countdown_div">
<p>This tab will close automatically in <span id="countdown">5</span> seconds.</p>
<button onclick="stopCountdown()">Stop Countdown</button>
</div>
</body>
</html>
"""
    # Multi-valued parameters are shown comma-separated in a single cell.
    table_rows = "".join(
        f"<tr><td>{html.escape(key)}</td>"
        f"<td>{html.escape(', '.join(value))}</td></tr>"
        for key, value in query_params.items()
    )
    return HTML_TEMPLATE.format(
        table_rows=table_rows,
        full_path=html.escape(full_path),
        headers=html.escape(headers),
    )
class CallbackHandler(BaseHTTPRequestHandler):
    """Handle the OIDC redirect: render a summary page and publish the query once.

    Expects ``self.server`` to expose ``allowed_path`` (the only path served)
    and ``future`` (receives the parsed query of the first accepted request).
    """

    def do_GET(self):
        parsed = urlparse(self.path)
        if parsed.path != self.server.allowed_path:
            self.send_response(404)
            self.send_header("Content-type", "text/plain")
            self.end_headers()
            self.wfile.write(b"404 Not Found")
            return

        query_params = parse_qs(parsed.query)
        headers = "\n".join(f"{key}: {value}" for key, value in self.headers.items())
        # Build the body first so the declared charset and length match what is sent.
        body = HTMLResponse(query_params, self.path, headers).encode("utf-8")

        self.send_response(200)
        self.send_header("Content-type", "text/html; charset=utf-8")
        self.send_header("Content-Length", str(len(body)))
        self.end_headers()
        self.wfile.write(body)

        # Only the first valid callback is published; later ones are just rendered.
        if not self.server.future.done():
            self.server.future.set_result(query_params)
class OneShotHTTPServer(HTTPServer):
    """HTTPServer that carries the callback future and the single accepted path.

    Request handlers reach these as ``self.server.future`` and
    ``self.server.allowed_path``.
    """

    def __init__(self, server_address, RequestHandlerClass, future, allowed_path):
        self.future = future
        self.allowed_path = allowed_path  # requests to any other path are refused
        super().__init__(server_address, RequestHandlerClass)
def start_server(server_address, timeout: int, allowed_path: str) -> Future:
    """Listen for the OIDC redirect on a background daemon thread.

    Requests are served one at a time until a request to *allowed_path*
    resolves the returned future with its parsed query, or until *timeout*
    seconds have elapsed overall (``None`` waits indefinitely).  Requests for
    other paths (e.g. ``/favicon.ico``) get a 404 and no longer consume the
    only chance to receive the callback.

    Returns:
        A Future that resolves to the parsed query dict, or to ``None`` if the
        deadline passed first.
    """
    import time

    future = Future()
    server = OneShotHTTPServer(server_address, CallbackHandler, future, allowed_path)
    server.timeout = timeout
    deadline = None if timeout is None else time.monotonic() + timeout

    def serve():
        try:
            while not future.done():
                if deadline is not None:
                    remaining = deadline - time.monotonic()
                    if remaining <= 0:
                        break
                    # handle_request() waits at most server.timeout for a connection.
                    server.timeout = remaining
                server.handle_request()
        finally:
            server.server_close()
            if not future.done():
                future.set_result(None)

    thread = threading.Thread(target=serve, daemon=True)
    thread.start()
    return future
def decode_jwt(token):
    """Decode a compact JWT into its parts without verifying the signature.

    Returns:
        ``{"header": dict, "payload": dict, "signature": str}`` for a
        three-segment base64url-JSON token.  Anything else (an opaque access
        token, ``None``, malformed input) is returned unchanged.
    """

    def decode_part(part):
        padded = part + "=" * (-len(part) % 4)  # JWT segments are unpadded base64url
        return json.loads(base64.urlsafe_b64decode(padded))

    try:
        header, payload, signature = token.split(".")
        return {
            "header": decode_part(header),
            "payload": decode_part(payload),
            "signature": signature,
        }
    except (AttributeError, TypeError, ValueError):
        # ValueError covers a wrong segment count, binascii.Error, and
        # JSON/Unicode decode errors; Attribute/TypeError cover non-str input.
        return token
def encode_jwt(jwt):
    """Serialize a ``decode_jwt``-style dict back into a compact JWT string.

    Header and payload are written as compact JSON in unpadded base64url; the
    original signature segment is reused unchanged, so it only remains valid
    if neither the header nor the payload was modified.
    """

    def _segment(obj):
        raw = json.dumps(obj, separators=(",", ":")).encode()
        return base64.urlsafe_b64encode(raw).decode().rstrip("=")

    header_seg = _segment(jwt["header"])
    payload_seg = _segment(jwt["payload"])
    return ".".join((header_seg, payload_seg, jwt["signature"]))
def start_authenticator(uri, username, password):
    """Complete the IdP login for *uri* in a headless browser, then shut it down.

    Login failures are reported by ``OIDCAuthenticator.authenticate`` itself;
    the browser session is quit in every case so no Chrome process is left
    behind.
    """
    auth = OIDCAuthenticator()
    try:
        auth.authenticate(uri, username, password)
    finally:
        auth.driver.quit()
def start_browser(uri):
    """Open *uri* for an interactive login after the user presses Enter.

    Prefers ``google-chrome`` (new incognito window, so no existing IdP
    session is reused) and falls back to the system default browser.
    """
    print(
        "You will be redirected to your browser to log in. Please complete the authentication process."
    )
    input("Press Enter to continue...")

    chrome = shutil.which("google-chrome")
    if not chrome:
        print("Chrome not found. Falling back to the default browser.")
        webbrowser.open(uri)
        return
    print("Launching Chrome in incognito mode...")
    subprocess.Popen([chrome, "--new-window", "--incognito", uri])
def read_file(path):
    """Return the UTF-8 text content of *path* with surrounding whitespace stripped.

    Explicit UTF-8 keeps the result independent of the platform locale.
    """
    with open(path, "r", encoding="utf-8") as file:
        return file.read().strip()
def write_file(path, content):
    """Write *content* to *path* as UTF-8, replacing any existing file content.

    Explicit UTF-8 keeps the file independent of the platform locale.
    """
    with open(path, "w", encoding="utf-8") as file:
        file.write(content)
class MyCallback(OIDCCallback):
    """PyMongo OIDC callback that supplies a token for MONGODB-OIDC.

    The token comes from the first applicable source:

    * ``client_auth_config``: a client-credentials grant against the IdP (the
      token is also written to ``jwt_path`` when that is given);
    * ``jwt_path``: a token read from disk;
    * otherwise an authorization-code + PKCE flow, completed in the user's
      browser or, when ``credentials`` is given, by headless Selenium.
    """

    def __init__(
        self,
        redirect_uri: str,
        credentials: tuple = None,
        use_id_token=False,
        tamper_jwt=False,
        jwt_path: str = None,
        client_auth_config: str = None,
    ):
        super().__init__()
        self.credentials = credentials  # (username, password) for headless login, or None
        self.tamper_jwt = tamper_jwt  # inject extra claims (leaves the signature stale)
        self.use_id_token = use_id_token  # send the ID token instead of the access token
        self.uri = urlparse(redirect_uri)
        # A pre-obtained token short-circuits the interactive flow in fetch().
        if client_auth_config:
            self.jwt = self.client_auth(client_auth_config)
            if jwt_path:
                write_file(jwt_path, self.jwt)
        elif jwt_path:
            self.jwt = read_file(jwt_path)
        else:
            self.jwt = None

    def _select_token(self, token_result):
        """Return the raw ID token or access token from a token-endpoint response.

        Raises:
            ValueError: if the response does not contain the requested token.
        """
        key = "id_token" if self.use_id_token else "access_token"
        if key not in token_result:
            raise ValueError(f"Token response does not contain '{key}'")
        return token_result[key]

    def client_auth(self, client_auth_config: str):
        """Run a client-credentials grant described by a JSON config file.

        The file must provide ``issuer``, ``client_id`` and ``client_secret``;
        ``scope`` is optional.

        Returns:
            The raw access token, or the ID token when ``use_id_token`` is set.

        Raises:
            ValueError: if a required config key or the requested token is missing.
            requests.HTTPError: if discovery or the token request fails.
        """
        with open(client_auth_config, "r", encoding="utf-8") as config_file:
            data = json.load(config_file)
        issuer = data.get("issuer")
        client_id = data.get("client_id")
        client_secret = data.get("client_secret")
        scope = data.get("scope")
        if not issuer or not client_id or not client_secret:
            raise ValueError("Invalid client authentication configuration")
        oath_metadata = self.get_oath_metadata(issuer)
        # NOTE(review): discovery skips TLS verification but this POST does not.
        response = requests.post(
            oath_metadata["token_endpoint"],
            data={
                "grant_type": "client_credentials",
                "scope": scope,
            },
            auth=HTTPBasicAuth(client_id, client_secret),
        )
        response.raise_for_status()
        token_result = response.json()
        access_token = decode_jwt(token_result["access_token"])
        print("Access Token:", json.dumps(access_token, indent=1))
        return self._select_token(token_result)

    def get_uri(self, uri, clientId, code_verifier, nonce, scopes, state=None) -> str:
        """Build the authorization-endpoint URL for a PKCE (S256) code request.

        Args:
            uri: the IdP authorization endpoint.
            clientId: OAuth client id.
            code_verifier: PKCE verifier; its SHA-256 becomes the challenge.
            nonce: value the IdP should echo in the ID token.
            scopes: extra scopes appended to ``openid profile offline_access``.
            state: anti-CSRF value echoed on the redirect; random if omitted.
        """
        code_challenge = (
            base64.urlsafe_b64encode(hashlib.sha256(code_verifier.encode()).digest())
            .rstrip(b"=")
            .decode()
        )
        if state is None:
            state = secrets.token_urlsafe(32)
        # Joining the full list avoids a trailing space when no extra scopes are given.
        scope = " ".join(["openid", "profile", "offline_access", *scopes])
        print("Requesting scopes:", scope)
        params = {
            "client_id": clientId,
            "response_type": "code",
            "scope": scope,
            "redirect_uri": self.uri.geturl(),
            "state": state,
            "code_challenge": code_challenge,
            "code_challenge_method": "S256",
            "nonce": nonce,
        }
        return f"{uri}?{urlencode(params)}"

    def run_code_flow(self, auth_uri, context, code_verifier, nonce):
        """Send the user to the IdP and wait for the redirect carrying the code.

        Returns:
            The parsed callback query (values are lists, as from ``parse_qs``).

        Raises:
            Exception: on timeout, an IdP error, a missing code/state, or a
                ``state`` that does not match the one sent (possible CSRF).
        """
        state = secrets.token_urlsafe(32)
        # Listen before sending the user away so the redirect cannot arrive first.
        future = start_server(
            (self.uri.hostname, self.uri.port or 80),
            context.timeout_seconds,
            self.uri.path or "/",
        )
        uri = self.get_uri(
            auth_uri,
            context.idp_info.clientId,
            code_verifier,
            nonce,
            scopes=context.idp_info.requestScopes or [],
            state=state,
        )
        if self.credentials:
            start_authenticator(uri, self.credentials[0], self.credentials[1])
        else:
            start_browser(uri)
        code_result = future.result()
        if code_result is None:
            raise Exception("Timeout waiting for callback")
        if "error" in code_result:
            raise Exception(f"Error in callback: {code_result['error'][0]}")
        if "code" not in code_result:
            raise Exception("Authorization code not found in callback")
        if "state" not in code_result:
            raise Exception("State not found in callback")
        if not secrets.compare_digest(code_result["state"][0], state):
            raise Exception("State mismatch in callback (possible CSRF)")
        return code_result

    def get_oath_metadata(self, issuer):
        """Fetch the issuer's OpenID discovery document (TLS verification disabled)."""
        response = requests.get(
            f"{issuer}/.well-known/openid-configuration", verify=False
        )
        response.raise_for_status()
        return response.json()

    def get_jwks(self, uri):
        """Fetch the issuer's JSON Web Key Set (TLS verification disabled)."""
        response = requests.get(uri, verify=False)
        response.raise_for_status()
        return response.json()

    def get_tokens(self, uri, context, code, code_verifier):
        """Exchange the authorization code for tokens at the token endpoint.

        Returns:
            ``(id_token, access_token, refresh_token, expires_in, token_result)``
            where the first two are ``decode_jwt`` results (``id_token`` is
            ``None`` if the IdP returned none) and ``token_result`` is the raw
            JSON response.
        """
        response = requests.post(
            uri,
            data={
                "grant_type": "authorization_code",
                "code": code,
                "redirect_uri": self.uri.geturl(),
                "client_id": context.idp_info.clientId,
                "code_verifier": code_verifier,
            },
            verify=False,
        )
        response.raise_for_status()
        token_result = response.json()
        print(token_result)
        raw_id_token = token_result.get("id_token")
        id_token = decode_jwt(raw_id_token) if raw_id_token else None
        access_token = decode_jwt(token_result["access_token"])
        refresh_token = token_result.get("refresh_token")
        expires_in = token_result.get("expires_in")
        return (id_token, access_token, refresh_token, expires_in, token_result)

    def fetch(self, context: OIDCCallbackContext) -> OIDCCallbackResult:
        """Return a token for PyMongo, running the interactive flow if needed."""
        if self.jwt is not None:
            return OIDCCallbackResult(
                access_token=self.jwt,
                expires_in_seconds=None,
                refresh_token=None,
            )
        code_verifier = secrets.token_urlsafe(64)
        print("Code verifier:", code_verifier)
        nonce = secrets.token_urlsafe(64)
        print("Nonce:", nonce)
        print(context)
        print(context.idp_info)
        oath_metadata = self.get_oath_metadata(context.idp_info.issuer)
        print("OAuth Metadata:", json.dumps(oath_metadata, indent=1))
        jwks = self.get_jwks(oath_metadata["jwks_uri"])
        print("JWKS:", json.dumps(jwks, indent=1))
        code_result = self.run_code_flow(
            oath_metadata["authorization_endpoint"], context, code_verifier, nonce
        )
        print("Authorization code:", code_result["code"][0])
        print("State:", code_result["state"][0])
        id_token, access_token, refresh_token, expires_in, token_result = (
            self.get_tokens(
                oath_metadata["token_endpoint"],
                context,
                code_result["code"][0],
                code_verifier,
            )
        )
        print("Access Token:", json.dumps(access_token, indent=1))
        print("Refresh Token:", json.dumps(refresh_token, indent=1))
        print("ID Token:", json.dumps(id_token, indent=1))
        print("Expires In:", expires_in)

        # An ID token issued for this request must echo the nonce we sent.
        if isinstance(id_token, dict) and isinstance(id_token.get("payload"), dict):
            returned_nonce = id_token["payload"].get("nonce")
            if returned_nonce is not None and returned_nonce != nonce:
                raise Exception("Nonce mismatch in ID token")

        token_jwt = self._select_token(token_result)
        if self.tamper_jwt:
            # Module-level tamper_jwt(); the attribute of the same name is just a flag.
            token_jwt = tamper_jwt(
                token_jwt,
                {"auth_claim": ["root", "admin"], "MyClaim": ["root", "admin"]},
            )
            print(
                "Tampered JWT: ",
                json.dumps(decode_jwt(token_jwt), indent=1),
            )
        print("Token exchange successful")
        return OIDCCallbackResult(
            access_token=token_jwt,
            expires_in_seconds=expires_in,
            refresh_token=refresh_token,
        )
def build_parser():
    """Build the command-line parser for the MongoDB OIDC example.

    ``--tls`` accepts a bare flag or an explicit boolean (``--tls false``).
    Previously ``type=bool`` turned any non-empty string, including "False",
    into True.
    """
    import argparse

    def _str_to_bool(value):
        """Interpret common true/false spellings; reject anything else."""
        lowered = value.strip().lower()
        if lowered in ("1", "true", "yes", "y", "on"):
            return True
        if lowered in ("0", "false", "no", "n", "off"):
            return False
        raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")

    parser = argparse.ArgumentParser(
        description="MongoDB OIDC Authentication Example",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "--host",
        default="localhost",
        help="MongoDB host",
    )
    parser.add_argument(
        "--port",
        type=int,
        default=27017,
        help="MongoDB port",
    )
    parser.add_argument(
        "--username",
        type=str,
        default=None,
        help="Username for MongoDB authentication mechanism",
    )
    parser.add_argument(
        "--login",
        type=str,
        default=None,
        help="<username>:<password> for IdP login",
    )
    parser.add_argument(
        "--redirect_uri",
        default="http://localhost:27097/redirect",
        help="Redirect URI for OIDC callback",
    )
    parser.add_argument(
        "--tls",
        type=_str_to_bool,
        nargs="?",
        const=True,
        default=False,
        help="Use TLS for MongoDB connection",
    )
    parser.add_argument(
        "-i",
        "--interactive",
        action="store_true",
        help="Run in interactive mode",
    )
    parser.add_argument(
        "--tamper-jwt",
        action="store_true",
        help="Tamper JWT to add some additional claims",
    )
    parser.add_argument(
        "--use-id-token",
        action="store_true",
        help="Use ID Token instead of access token",
    )
    parser.add_argument(
        "--jwt-path",
        type=str,
        help="Path to JWT token for non-human authentication flows",
    )
    parser.add_argument(
        "--client-auth-config",
        type=str,
        help="Path to client authentication configuration file",
    )
    return parser
# Greeting shown by code.interact() when the script runs with --interactive.
SHELL_BANNER = (
    "\n"
    "Running in interactive mode. Type 'exit()' to quit.\n"
    "You can use the 'client' object to interact with MongoDB.\n"
    "Tab completion is enabled.\n"
)
def parse_login(login):
    """Parse the ``--login`` value into a ``(username, password)`` tuple.

    ``"user:pass"`` is split on the first colon only, so passwords may contain
    colons.  A bare ``"user"`` prompts for the password without echo.

    Returns:
        ``None`` when *login* is ``None``, otherwise ``(username, password)``.
    """
    if login is None:
        return None
    username, separator, password = login.partition(":")
    if not separator:
        password = getpass.getpass(f"Enter password for '{username}': ")
    return username, password
def main():
    """Authenticate to MongoDB with MONGODB-OIDC and optionally open a REPL.

    Exits with status 1 if the initial ``connectionStatus`` command fails.
    The client is closed on every exit path.
    """
    import sys

    parser = build_parser()
    args = parser.parse_args()
    print("MongoDB OIDC Authentication Example")
    print("Starting MongoDB client...")
    # A pre-issued token (file or client-credentials config) is a machine flow;
    # otherwise an interactive human login is required.
    use_human_flows = args.jwt_path is None and args.client_auth_config is None
    callback = MyCallback(
        redirect_uri=args.redirect_uri,
        credentials=parse_login(args.login),
        use_id_token=args.use_id_token,
        tamper_jwt=args.tamper_jwt,
        jwt_path=args.jwt_path,
        client_auth_config=args.client_auth_config,
    )
    callback_key = "OIDC_HUMAN_CALLBACK" if use_human_flows else "OIDC_CALLBACK"
    authMechanismProperties = {callback_key: callback}
    client = MongoClient(
        host=args.host,
        port=args.port,
        connect=True,
        tls=args.tls,
        authMechanism="MONGODB-OIDC",
        username=args.username,
        authMechanismProperties=authMechanismProperties,
    )
    try:
        try:
            status = client.admin.command({"connectionStatus": 1})
            print(f"Connection status: {json.dumps(status, indent=1)}")
        except Exception as e:
            print("Connection failed:", e)
            tb = traceback.format_exc()
            print("Traceback:\n", tb)
            sys.exit(1)  # non-zero so scripts can detect the failure
        print("Connection successful")
        if args.interactive:
            variables = globals().copy()
            variables.update(locals())
            # Complete against the REPL's namespace so locals like `client` complete too.
            readline.set_completer(rlcompleter.Completer(variables).complete)
            readline.parse_and_bind("tab: complete")
            code.interact(banner=SHELL_BANNER, local=variables)
        print("Exiting...")
    finally:
        client.close()
# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
pymongo
requests
selenium
webdriver-manager
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment