Last active
March 19, 2025 04:22
-
-
Save dexhunter/38c1ce0080717f13f03464cf56511e2c to your computer and use it in GitHub Desktop.
Full example of FastAPI integration with MCP
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
Simple example of using MCP to add an MCP server to a FastAPI app. | |
""" | |
import inspect
import itertools
import json
import re
import sys
import traceback
import typing
from typing import List, Optional

import click
import uvicorn
from fastapi import FastAPI, HTTPException, Query, WebSocket, WebSocketDisconnect
from pydantic import BaseModel
from pydantic.fields import FieldInfo

import mcp.types as types
from mcp.server.lowlevel import Server
# The FastAPI application. Its routes are also introspected later and
# re-exposed as MCP tools.
app = FastAPI(
    version="0.1.0",
    title="Example API",
    description="A simple example API with integrated MCP server",
)
# Data model shared by the REST endpoints and the generated MCP tool schemas.
class Item(BaseModel):
    # Primary key; items_db is keyed by this value.
    id: int
    name: str
    # Optional free text, also searched by search_items.
    description: Optional[str] = None
    price: float
    # Labels filtered by search_items. Pydantic copies this mutable default
    # per instance, so the [] is not shared between Items.
    tags: List[str] = []
# In-memory "database" mapping each item's id to the Item itself.
# It lives only in this process, so its contents are not persisted.
items_db: dict[int, Item] = dict()
# REST endpoints
@app.get("/items/", response_model=List[Item], tags=["items"])
async def list_items(skip: int = 0, limit: int = 10):
    """
    List all items in the database.
    Returns a list of items, with pagination support.
    Negative skip or limit values are treated as 0.
    """
    # Clamp so negative inputs cannot trigger Python's negative-index slicing,
    # which would silently return items from the end of the collection.
    start = max(skip, 0)
    stop = start + max(limit, 0)
    # islice walks the dict view lazily instead of copying every item first.
    return list(itertools.islice(items_db.values(), start, stop))
@app.get("/items/{item_id}", response_model=Item, tags=["items"])
async def read_item(item_id: int):
    """
    Get a specific item by its ID.
    Raises a 404 error if the item does not exist.
    """
    # Stored values are always Item instances, so None unambiguously means
    # "no such id".
    found = items_db.get(item_id)
    if found is None:
        raise HTTPException(status_code=404, detail="Item not found")
    return found
@app.post("/items/", response_model=Item, tags=["items"])
async def create_item(item: Item):
    """
    Create a new item in the database.
    Returns the created item. Raises a 409 error if an item with the same ID
    already exists; use update_item to replace an existing item.
    """
    # The client supplies the id. Refuse to silently clobber an existing
    # record, since replacing items is update_item's job.
    if item.id in items_db:
        raise HTTPException(status_code=409, detail="Item already exists")
    items_db[item.id] = item
    return item
@app.put("/items/{item_id}", response_model=Item, tags=["items"])
async def update_item(item_id: int, item: Item):
    """
    Update an existing item.
    Raises a 404 error if the item does not exist.
    """
    if item_id in items_db:
        # The path parameter is authoritative: it overrides any id sent in the body.
        item.id = item_id
        items_db[item_id] = item
        return item
    raise HTTPException(status_code=404, detail="Item not found")
@app.delete("/items/{item_id}", tags=["items"])
async def delete_item(item_id: int):
    """
    Delete an item from the database.
    Raises a 404 error if the item does not exist.
    """
    # Stored values are never None, so a None from pop() means the id was absent.
    removed = items_db.pop(item_id, None)
    if removed is None:
        raise HTTPException(status_code=404, detail="Item not found")
    return {"message": "Item deleted successfully"}
# NOTE(review): this route is registered after "/items/{item_id}". With
# first-match routing, a request to "/items/search" (no trailing slash) is
# presumably captured by read_item instead. Confirm, and consider declaring
# this route first.
@app.get("/items/search/", response_model=List[Item], tags=["search"])
async def search_items(
    q: Optional[str] = Query(None, description="Search query string"),
    min_price: Optional[float] = Query(None, description="Minimum price"),
    max_price: Optional[float] = Query(None, description="Maximum price"),
    tags: List[str] = Query([], description="Filter by tags"),
):
    """
    Search for items with various filters.
    Returns a list of items that match the search criteria.
    """
    # Case-insensitive substring search; an empty or missing query disables it.
    needle = q.lower() if q else None

    def keep(candidate):
        # Each filter applies only when its parameter was supplied; all must pass.
        if needle is not None and not (
            needle in candidate.name.lower()
            or (candidate.description and needle in candidate.description.lower())
        ):
            return False
        if min_price is not None and not (candidate.price >= min_price):
            return False
        if max_price is not None and not (candidate.price <= max_price):
            return False
        # Every requested tag must be present on the item.
        if tags and not all(tag in candidate.tags for tag in tags):
            return False
        return True

    return [candidate for candidate in items_db.values() if keep(candidate)]
# Seed data, so the API (and the MCP tools generated from it) has something
# to return as soon as the module is imported.
sample_items = [
    Item(id=1, name="Hammer", description="A tool for hammering nails",
         price=9.99, tags=["tool", "hardware"]),
    Item(id=2, name="Screwdriver", description="A tool for driving screws",
         price=7.99, tags=["tool", "hardware"]),
    Item(id=3, name="Wrench", description="A tool for tightening bolts",
         price=12.99, tags=["tool", "hardware"]),
    Item(id=4, name="Saw", description="A tool for cutting wood",
         price=19.99, tags=["tool", "hardware", "cutting"]),
    Item(id=5, name="Drill", description="A tool for drilling holes",
         price=49.99, tags=["tool", "hardware", "power"]),
]
# Index the seed items by primary key.
for item in sample_items:
    items_db[item.id] = item
import anyio | |
import click | |
import httpx | |
import mcp.types as types | |
from mcp.server.lowlevel import Server | |
import asyncio | |
# HTTP methods in the order we prefer when a route accepts several of them.
# route.methods is a set, so taking its "first" element is not deterministic
# between interpreter runs.
_METHOD_PREFERENCE = ("GET", "POST", "PUT", "PATCH", "DELETE")

# Routes belonging to FastAPI's interactive docs or to the MCP transport
# itself. These are never exposed as tools.
_SKIPPED_PATH_PREFIXES = ("/docs", "/openapi", "/mcp")
_SKIPPED_ENDPOINT_NAMES = frozenset(
    {
        "openapi",
        "get_openapi",
        "swagger_ui_html",
        "swagger_ui_redirect",
        "redoc_html",
        "handle_mcp_connection",
    }
)

# Matches "{name}" placeholders in a route path such as "/items/{item_id}".
_PATH_PARAM_RE = re.compile(r"\{([^}]+)\}")

# JSON-schema type names for the scalar annotations we recognise. Checked by
# identity so that, for example, bool is never mistaken for int.
_JSON_TYPE_NAMES = (
    (bool, "boolean"),
    (int, "integer"),
    (float, "number"),
    (str, "string"),
)


def _log(message):
    """Write a diagnostic line to stderr.

    With the stdio transport, stdout carries the MCP JSON-RPC stream, so any
    diagnostic printed to stdout would corrupt the protocol.
    """
    print(message, file=sys.stderr)


def _unwrap_optional(annotation):
    """Return X for Optional[X] / Union[X, None]; otherwise return *annotation* unchanged."""
    if typing.get_origin(annotation) is typing.Union:
        non_none = [arg for arg in typing.get_args(annotation) if arg is not type(None)]
        if len(non_none) == 1:
            return non_none[0]
    return annotation


def _json_schema_for(annotation):
    """Map a Python annotation to a minimal JSON-schema fragment (no description)."""
    annotation = _unwrap_optional(annotation)
    if annotation is list or typing.get_origin(annotation) is list:
        item_args = typing.get_args(annotation)
        items = _json_schema_for(item_args[0]) if item_args else {"type": "string"}
        return {"type": "array", "items": items}
    for py_type, json_type in _JSON_TYPE_NAMES:
        if annotation is py_type:
            return {"type": json_type}
    # Anything unrecognised (models, enums, forward refs, ...) is advertised as a string.
    return {"type": "string"}


def _field_schema(annotation, description):
    """Return a JSON-schema property for *annotation* carrying *description*."""
    schema = _json_schema_for(annotation)
    schema["description"] = description
    return schema


def _route_method(route):
    """Pick one HTTP method for *route* deterministically, or None if it has none."""
    methods = getattr(route, "methods", None) or ()
    for method in _METHOD_PREFERENCE:
        if method in methods:
            return method
    return min(methods) if methods else None


def _add_path_params(sig, path_param_names, properties, required):
    """Describe every "{name}" path placeholder as a required property."""
    for name in path_param_names:
        param = sig.parameters.get(name)
        # Default to int for path params that are not annotated (ids in this app).
        annotation = int
        if param is not None and param.annotation is not inspect.Parameter.empty:
            annotation = param.annotation
        properties[name] = _field_schema(annotation, f"{name} path parameter")
        if name not in required:
            required.append(name)


def _add_model_fields(sig, properties, required, skip=()):
    """Flatten the fields of every Pydantic-model parameter into *properties*."""
    for param_name, param in sig.parameters.items():
        if param_name == "self" or param_name in skip:
            continue
        model_fields = getattr(param.annotation, "model_fields", None)
        if model_fields is None:
            continue
        for field_name, field_info in model_fields.items():
            properties[field_name] = _field_schema(field_info.annotation, f"{field_name} field")
            if field_info.is_required() and field_name not in required:
                required.append(field_name)


def _add_query_params(sig, properties, required):
    """Describe each function parameter as a query-string property."""
    for name, param in sig.parameters.items():
        if name == "self":
            continue
        annotation = str if param.annotation is inspect.Parameter.empty else param.annotation
        default = param.default
        description = f"{name} query parameter"
        if isinstance(default, FieldInfo):
            # Query(...) markers carry their own description and requiredness.
            description = default.description or description
            is_required = default.is_required()
        else:
            is_required = default is inspect.Parameter.empty
        properties[name] = _field_schema(annotation, description)
        if is_required:
            required.append(name)


def convert_endpoint_to_tool(fastapi_app):
    """
    Convert all endpoints in a FastAPI app to MCP tool descriptions.

    Returns a list of dicts with keys "name", "description", "method",
    "path", "inputSchema" and "endpoint" (the route's endpoint function).
    """
    tool_list = []
    for route in fastapi_app.routes:
        # Skip non-endpoint routes (e.g. mounts) and documentation/MCP endpoints.
        if not hasattr(route, "endpoint") or not hasattr(route, "path"):
            continue
        path = route.path
        if path.startswith(_SKIPPED_PATH_PREFIXES):
            continue
        endpoint_func = route.endpoint
        if endpoint_func.__name__ in _SKIPPED_ENDPOINT_NAMES:
            continue

        http_method = _route_method(route)
        if not http_method:
            _log(f"Skipping route {path} - could not determine HTTP method")
            continue

        sig = inspect.signature(endpoint_func)
        path_param_names = _PATH_PARAM_RE.findall(path)
        properties = {}
        required = []
        if http_method == "POST":
            _add_model_fields(sig, properties, required)
        elif http_method == "PUT":
            _add_path_params(sig, path_param_names, properties, required)
            _add_model_fields(sig, properties, required, skip=path_param_names)
        elif http_method == "DELETE" or (http_method == "GET" and path_param_names):
            _add_path_params(sig, path_param_names, properties, required)
        elif http_method == "GET":
            _add_query_params(sig, properties, required)

        tool_list.append({
            "name": endpoint_func.__name__,
            # getdoc strips the docstring's source indentation.
            "description": inspect.getdoc(endpoint_func) or f"API endpoint: {path}",
            "method": http_method,
            "path": path,
            "inputSchema": {
                "type": "object",
                "properties": properties,
                "required": required,
            },
            "endpoint": endpoint_func,
        })
    return tool_list


def _split_arguments(arguments, path_param_names):
    """Split raw tool arguments into (path_params, query_params, body_params).

    Works on a copy, so the caller's dict is never mutated. A key that merely
    resembles a path parameter (e.g. "id" for "item_id") fills that parameter
    only if the caller did not pass it explicitly, and the key is still
    forwarded as a query/body value, because for PUT it is also a model field.
    """
    arguments = dict(arguments or {})
    # A single unnamed value maps onto the only path parameter.
    if "" in arguments and len(path_param_names) == 1:
        arguments = {path_param_names[0]: arguments[""]}

    path_params, query_params, body_params = {}, {}, {}
    for key, value in arguments.items():
        if key in path_param_names:
            path_params[key] = value
            continue
        alias = next(
            (p for p in path_param_names if p.startswith(key + "_") or p.endswith("_" + key)),
            None,
        )
        if alias is not None and alias not in arguments:
            path_params.setdefault(alias, value)
        query_params[key] = value
        body_params[key] = value

    # A lone value under an unrecognised key still fills a lone path parameter.
    if len(path_param_names) == 1 and not path_params and len(arguments) == 1:
        path_params[path_param_names[0]] = next(iter(arguments.values()))
    return path_params, query_params, body_params


def _fill_param_defaults(endpoint_func, kwargs):
    """Resolve FastAPI Query(...)/Path(...) defaults for omitted parameters.

    Calling an endpoint function directly bypasses FastAPI's dependency
    injection. A parameter declared as ``q: str = Query(None)`` would
    otherwise receive the Query marker object itself instead of None.
    """
    for name, param in inspect.signature(endpoint_func).parameters.items():
        marker = param.default
        if name in kwargs or not isinstance(marker, FieldInfo):
            continue
        if marker.is_required():
            raise TypeError(f"missing required argument: {name}")
        kwargs[name] = marker.get_default(call_default_factory=True)
    return kwargs


def _body_kwargs(endpoint_func, path_params, body_params):
    """Build call kwargs for POST/PUT, constructing the Pydantic body model if any."""
    for name, param in inspect.signature(endpoint_func).parameters.items():
        if name == "self" or name in path_params:
            continue
        model_class = param.annotation
        model_fields = getattr(model_class, "model_fields", None)
        if model_fields is None:
            continue
        # Only keep the arguments that belong to the model.
        model_data = {k: v for k, v in body_params.items() if k in model_fields}
        return {**path_params, name: model_class(**model_data)}
    # No model parameter: pass everything as plain keyword arguments.
    return {**path_params, **body_params}


async def _invoke_endpoint(tool, arguments):
    """Call the tool's endpoint function with *arguments* routed appropriately."""
    endpoint_func = tool["endpoint"]
    method = tool["method"]
    path_param_names = _PATH_PARAM_RE.findall(tool["path"])
    path_params, query_params, body_params = _split_arguments(arguments, path_param_names)

    if method in ("POST", "PUT"):
        kwargs = _body_kwargs(endpoint_func, path_params, body_params)
    elif method == "DELETE" or (method == "GET" and path_param_names):
        # DELETE or GET with a path parameter only needs the path parameters.
        kwargs = dict(path_params)
    else:
        kwargs = dict(query_params)

    result = endpoint_func(**_fill_param_defaults(endpoint_func, kwargs))
    # Support plain `def` endpoints as well as `async def` ones.
    if inspect.isawaitable(result):
        result = await result
    return result


def _to_text_content(result):
    """Serialise an endpoint result into a single MCP text content item."""
    if isinstance(result, str):
        text = result
    elif hasattr(result, "model_dump"):
        # mode="json" turns datetimes, UUIDs, etc. into JSON-safe values.
        text = json.dumps(result.model_dump(mode="json"))
    elif isinstance(result, list) and all(hasattr(entry, "model_dump") for entry in result):
        text = json.dumps([entry.model_dump(mode="json") for entry in result])
    else:
        text = json.dumps(result, default=lambda obj: obj.__dict__)
    return [types.TextContent(type="text", text=text)]


def register_api_tools(mcp_server, fastapi_app):
    """Register all FastAPI endpoints of *fastapi_app* as tools on *mcp_server*."""

    @mcp_server.list_tools()
    async def list_tools():
        """List all available API endpoints as tools."""
        try:
            tools = convert_endpoint_to_tool(fastapi_app)
            _log(f"list_tools handler called - found {len(tools)} tools")
            result = [
                types.Tool(
                    name=tool["name"],
                    description=tool["description"],
                    inputSchema=tool["inputSchema"],
                )
                for tool in tools
            ]
            _log(f"Returning {len(result)} tools to client")
            return result
        except Exception as e:
            _log(f"ERROR in list_tools handler: {e}")
            traceback.print_exc()  # goes to stderr, safe for stdio transport
            # Return an empty list so the client still gets a valid response.
            return []

    @mcp_server.call_tool()
    async def call_tool(name: str, arguments: dict):
        """Call an API endpoint as a tool."""
        _log(f"Tool called: {name} with arguments: {arguments}")
        tool = next((t for t in convert_endpoint_to_tool(fastapi_app) if t["name"] == name), None)
        if tool is None:
            raise ValueError(f"Unknown tool: {name}")
        try:
            return _to_text_content(await _invoke_endpoint(tool, arguments))
        except Exception as e:
            # Report endpoint failures (including HTTPException) back as text.
            return [types.TextContent(
                type="text",
                text=f"Error calling endpoint {tool['path']}: {str(e)}",
            )]

    return mcp_server


@click.command()
@click.option("--port", default=8000, help="Port to listen on for SSE")
@click.option(
    "--transport",
    type=click.Choice(["stdio", "sse"]),
    default="stdio",
    help="Transport type",
)
def main(port: int, transport: str) -> int:
    """Serve the FastAPI app's endpoints as MCP tools over stdio or SSE."""
    mcp_app = Server("mcp tool server")
    # Register API tools from the FastAPI app to the MCP server.
    register_api_tools(mcp_app, app)

    # Diagnostic dump of the discovered tools and routes. Only fires when the
    # app is served by uvicorn (SSE transport).
    @app.on_event("startup")
    async def startup_event():
        tools = convert_endpoint_to_tool(app)
        _log(f"App startup: Found {len(tools)} API endpoints to be exposed as tools")
        _log("\nDEBUG - All registered FastAPI routes:")
        for route in app.routes:
            if hasattr(route, "endpoint") and hasattr(route, "path"):
                func_name = getattr(route.endpoint, "__name__", "unknown")
                path = getattr(route, "path", "unknown")
                methods = getattr(route, "methods", ["unknown"])
                _log(f" Route: {path} | Methods: {methods} | Function: {func_name}")
            else:
                _log(f" Other route: {route}")

    if transport == "sse":
        from fastapi import Request
        from mcp.server.sse import SseServerTransport

        sse = SseServerTransport("/mcp/messages/")

        async def handle_mcp_connection(request: Request):
            # request._send is private, but connect_sse needs the raw ASGI callables.
            async with sse.connect_sse(request.scope, request.receive, request._send) as streams:
                await mcp_app.run(
                    streams[0],
                    streams[1],
                    mcp_app.create_initialization_options(),
                )

        # Mount the MCP connection handler and the client->server message endpoint.
        app.get("/mcp")(handle_mcp_connection)
        app.mount("/mcp/messages/", app=sse.handle_post_message)
        uvicorn.run(app, host="0.0.0.0", port=port)
    else:
        from mcp.server.stdio import stdio_server

        async def arun():
            async with stdio_server() as streams:
                await mcp_app.run(
                    streams[0], streams[1], mcp_app.create_initialization_options()
                )

        anyio.run(arun)
    return 0


if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Run the script with, for example, `python <script>.py --transport sse --port 8000` (the default transport is `stdio`).