Skip to content

Instantly share code, notes, and snippets.

@dexhunter
Last active March 19, 2025 04:22
Show Gist options
  • Save dexhunter/38c1ce0080717f13f03464cf56511e2c to your computer and use it in GitHub Desktop.
Save dexhunter/38c1ce0080717f13f03464cf56511e2c to your computer and use it in GitHub Desktop.
Full example of FastAPI integration with MCP
"""
Simple example of exposing a FastAPI app's endpoints as tools through an MCP server.
"""
import inspect
import json
import re
from typing import List, Optional
import click
import uvicorn
from fastapi import FastAPI, HTTPException, Query, WebSocket, WebSocketDisconnect
from pydantic import BaseModel
import mcp.types as types
from mcp.server.lowlevel import Server
# The FastAPI application whose routes are later exposed as MCP tools.
app = FastAPI(
    title="Example API",
    description="A simple example API with integrated MCP server",
    version="0.1.0",
)
# Data model for one entry of the in-memory item store.
class Item(BaseModel):
    # Key under which the item is stored; supplied by the client.
    id: int
    name: str
    # Free-text description; may be omitted.
    description: Optional[str] = None
    price: float
    # Labels matched by the search endpoint's tag filter.
    tags: List[str] = []
# In-memory database: maps an item's id to the stored Item.
# Contents live only in this module-level dict.
items_db: dict[int, Item] = {}
# Define some endpoints
@app.get("/items/", response_model=List[Item], tags=["items"])
async def list_items(skip: int = 0, limit: int = 10):
    """
    List all items in the database.
    Returns a list of items, with pagination support.

    `skip` is the number of items to pass over and `limit` the maximum
    number to return. Negative values are treated as 0.
    """
    # Clamp both values: a negative bound would otherwise be interpreted by
    # Python slicing as an offset from the end and return an arbitrary window.
    start = max(skip, 0)
    count = max(limit, 0)
    return list(items_db.values())[start : start + count]
@app.get("/items/{item_id}", response_model=Item, tags=["items"])
async def read_item(item_id: int):
    """
    Fetch a single item by its ID.
    Responds with a 404 error when no item has that ID.
    """
    try:
        return items_db[item_id]
    except KeyError:
        # Translate the missing key into the HTTP-level "not found" error.
        raise HTTPException(status_code=404, detail="Item not found") from None
@app.post("/items/", response_model=Item, tags=["items"])
async def create_item(item: Item):
    """
    Create a new item in the database.
    The item is stored under the ID given in the request body and returned.
    Raises a 409 error if an item with that ID already exists.
    """
    # Creating must not silently replace an existing item; replacing is what
    # the PUT endpoint is for.
    if item.id in items_db:
        raise HTTPException(status_code=409, detail="Item already exists")
    items_db[item.id] = item
    return item
@app.put("/items/{item_id}", response_model=Item, tags=["items"])
async def update_item(item_id: int, item: Item):
    """
    Replace an existing item.
    Responds with a 404 error when no item has that ID.
    """
    if item_id in items_db:
        # The ID in the URL wins over whatever ID the body carried.
        item.id = item_id
        items_db[item_id] = item
        return item
    raise HTTPException(status_code=404, detail="Item not found")
@app.delete("/items/{item_id}", tags=["items"])
async def delete_item(item_id: int):
    """
    Remove an item from the database.
    Responds with a 404 error when no item has that ID.
    """
    try:
        del items_db[item_id]
    except KeyError:
        raise HTTPException(status_code=404, detail="Item not found") from None
    return {"message": "Item deleted successfully"}
@app.get("/items/search/", response_model=List[Item], tags=["search"])
async def search_items(
    q: Optional[str] = Query(None, description="Search query string"),
    min_price: Optional[float] = Query(None, description="Minimum price"),
    max_price: Optional[float] = Query(None, description="Maximum price"),
    tags: List[str] = Query([], description="Filter by tags"),
):
    """
    Search for items using any combination of filters.
    Returns the items that satisfy every filter that was supplied.
    """
    # An empty or missing query string disables the text filter.
    needle = q.lower() if q else None

    def matches(candidate):
        """Return True when the candidate passes all active filters."""
        if needle is not None:
            in_name = needle in candidate.name.lower()
            in_description = bool(candidate.description) and needle in candidate.description.lower()
            if not (in_name or in_description):
                return False
        if min_price is not None and not (candidate.price >= min_price):
            return False
        if max_price is not None and not (candidate.price <= max_price):
            return False
        # Every requested tag must be present on the item.
        if tags and not all(tag in candidate.tags for tag in tags):
            return False
        return True

    return [candidate for candidate in items_db.values() if matches(candidate)]
# Seed the in-memory database with a few example items.
sample_items = [
    Item(
        id=1,
        name="Hammer",
        description="A tool for hammering nails",
        price=9.99,
        tags=["tool", "hardware"],
    ),
    Item(
        id=2,
        name="Screwdriver",
        description="A tool for driving screws",
        price=7.99,
        tags=["tool", "hardware"],
    ),
    Item(
        id=3,
        name="Wrench",
        description="A tool for tightening bolts",
        price=12.99,
        tags=["tool", "hardware"],
    ),
    Item(
        id=4,
        name="Saw",
        description="A tool for cutting wood",
        price=19.99,
        tags=["tool", "hardware", "cutting"],
    ),
    Item(
        id=5,
        name="Drill",
        description="A tool for drilling holes",
        price=49.99,
        tags=["tool", "hardware", "power"],
    ),
]
# Store each sample under its own id.
for item in sample_items:
    items_db[item.id] = item
import anyio
import click
import httpx
import mcp.types as types
from mcp.server.lowlevel import Server
import asyncio
@click.command()
@click.option("--port", default=8000, help="Port to listen on for SSE")
@click.option(
    "--transport",
    type=click.Choice(["stdio", "sse"]),
    default="stdio",
    help="Transport type",
)
def main(port: int, transport: str) -> int:
    """Expose the FastAPI app's endpoints as MCP tools and serve them.

    Args:
        port: TCP port uvicorn listens on; only used by the "sse" transport.
        transport: "stdio" runs the MCP server over stdin/stdout; "sse"
            registers the MCP routes on the FastAPI app under /mcp and
            serves the app with uvicorn.

    Returns:
        0 once the selected server has stopped.
    """
    # Function-scope imports keep this block self-contained.
    import copy
    import logging
    import typing

    from pydantic.fields import FieldInfo

    # Diagnostics go through logging, which writes to stderr by default.
    # Under the stdio transport stdout carries the MCP protocol stream, so
    # printing there from a handler would corrupt it.
    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger("mcp_tool_server")

    path_param_pattern = re.compile(r"\{([^}]+)\}")
    skipped_path_prefixes = ("/docs", "/openapi", "/mcp")
    internal_endpoint_names = {
        "openapi",
        "get_openapi",
        "swagger_ui_html",
        "swagger_ui_redirect",
        "redoc_html",
        "handle_mcp_connection",
    }
    # Fixed preference so the chosen method does not depend on set ordering.
    method_preference = ("GET", "POST", "PUT", "PATCH", "DELETE")
    scalar_json_types = (
        (bool, "boolean"),
        (int, "integer"),
        (float, "number"),
        (str, "string"),
    )
    variadic_kinds = (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD)

    # Create MCP server
    mcp_app = Server("mcp tool server")

    def path_param_names(path):
        """Return the parameter names of a route path, without converters."""
        return [name.split(":", 1)[0] for name in path_param_pattern.findall(path)]

    def endpoint_params(endpoint_func):
        """Yield (name, parameter) pairs an MCP client can supply."""
        for param_name, param in inspect.signature(endpoint_func).parameters.items():
            if param_name == "self" or param.kind in variadic_kinds:
                continue
            yield param_name, param

    def type_schema(annotation):
        """Map a Python annotation to a JSON-schema fragment (no description)."""
        origin = typing.get_origin(annotation)
        args = typing.get_args(annotation)
        if origin is typing.Annotated and args:
            return type_schema(args[0])
        # Optional[X] / X | None: describe X itself.
        if type(None) in args:
            remaining = [arg for arg in args if arg is not type(None)]
            if len(remaining) == 1:
                return type_schema(remaining[0])
        if annotation is list or origin is list:
            item_schema = type_schema(args[0]) if args else {"type": "string"}
            return {"type": "array", "items": item_schema}
        for python_type, json_type in scalar_json_types:
            if annotation is python_type:
                return {"type": json_type}
        # Anything unrecognised is advertised as a string.
        return {"type": "string"}

    def get_field_properties(field_name, field_type, kind="field"):
        """Return the JSON-schema property for one field or parameter."""
        schema = type_schema(field_type)
        schema["description"] = f"{field_name} {kind}"
        return schema

    def select_http_method(route):
        """Pick one HTTP method for a route deterministically, or None."""
        methods = set(getattr(route, "methods", None) or ())
        for candidate in method_preference:
            if candidate in methods:
                return candidate
        return min(methods) if methods else None

    def default_is_required(default):
        """Tell whether a parameter default means "must be supplied"."""
        if default is inspect.Parameter.empty:
            return True
        # Query(...)-style defaults are FieldInfo objects that know this.
        return isinstance(default, FieldInfo) and default.is_required()

    def build_input_schema(endpoint_func, path):
        """Return (properties, required) describing an endpoint's inputs.

        Path parameters come first and are always required. A parameter
        annotated with a Pydantic model is expanded into its fields; every
        other parameter is described as a query parameter.
        """
        sig = inspect.signature(endpoint_func)
        properties = {}
        required = []
        path_names = path_param_names(path)
        for param_name in path_names:
            param = sig.parameters.get(param_name)
            annotation = int  # Default to int for path params if not specified
            if param is not None and param.annotation is not inspect.Parameter.empty:
                annotation = param.annotation
            properties[param_name] = get_field_properties(param_name, annotation, "path parameter")
            required.append(param_name)
        for param_name, param in endpoint_params(endpoint_func):
            if param_name in path_names:
                continue
            model_fields = getattr(param.annotation, "model_fields", None)
            if isinstance(model_fields, dict):
                # This is a Pydantic model parameter: expose its fields.
                for field_name, field_info in model_fields.items():
                    if field_name in properties:
                        continue
                    properties[field_name] = get_field_properties(field_name, field_info.annotation)
                    if field_info.is_required():
                        required.append(field_name)
                continue
            annotation = param.annotation if param.annotation is not inspect.Parameter.empty else str
            properties[param_name] = get_field_properties(param_name, annotation, "query parameter")
            if default_is_required(param.default):
                required.append(param_name)
        return properties, required

    def convert_endpoint_to_tool(fastapi_app):
        """
        Convert all endpoints in a FastAPI app to MCP tools.

        Returns a list of dicts with the keys "name", "description",
        "method", "path", "inputSchema" and "endpoint" (the route function).
        Documentation routes and the MCP routes themselves are skipped.
        """
        tool_list = []
        for route in fastapi_app.routes:
            endpoint_func = getattr(route, "endpoint", None)
            path = getattr(route, "path", None)
            # Skip non-endpoint routes and documentation endpoints
            if endpoint_func is None or path is None:
                continue
            if path.startswith(skipped_path_prefixes):
                continue
            tool_name = getattr(endpoint_func, "__name__", None)
            if tool_name is None or tool_name in internal_endpoint_names:
                continue
            http_method = select_http_method(route)
            if not http_method:
                logger.info("Skipping route %s - could not determine HTTP method", path)
                continue
            properties, required = build_input_schema(endpoint_func, path)
            tool_list.append({
                "name": tool_name,
                # getdoc() strips the docstring's source indentation.
                "description": inspect.getdoc(endpoint_func) or f"API endpoint: {path}",
                "method": http_method,
                "path": path,
                "inputSchema": {
                    "type": "object",
                    "properties": properties,
                    "required": required,
                },
                "endpoint": endpoint_func,
            })
        return tool_list

    def build_call_kwargs(endpoint_func, path, arguments):
        """Translate MCP tool arguments into keyword arguments for the endpoint.

        Raises whatever the Pydantic model constructor raises when the
        supplied arguments do not validate.
        """
        # Work on a copy so the caller's dict is never mutated.
        args = dict(arguments or {})
        path_names = path_param_names(path)
        # Handle case where a value is provided without a key
        if "" in args and len(path_names) == 1:
            args = {path_names[0]: args[""]}
        # Exactly named path parameters always win.
        path_values = {name: args[name] for name in path_names if name in args}
        # Otherwise accept a shortened alias such as "id" for "item_id". The
        # alias stays in args so a body model with that field still gets it.
        for name in path_names:
            if name in path_values:
                continue
            for key, value in args.items():
                if not key or key in path_names:
                    continue
                if name.startswith(key + "_") or name.endswith("_" + key):
                    path_values[name] = value
                    break
        # If we have a single value but no path param set, map it to the path param
        if len(path_names) == 1 and not path_values and len(args) == 1:
            path_values[path_names[0]] = next(iter(args.values()))

        call_kwargs = {}
        for param_name, param in endpoint_params(endpoint_func):
            if param_name in path_values:
                call_kwargs[param_name] = path_values[param_name]
                continue
            model_class = param.annotation
            model_fields = getattr(model_class, "model_fields", None)
            if isinstance(model_fields, dict):
                # Build the Pydantic model from the arguments that are its fields.
                model_data = {k: v for k, v in args.items() if k in model_fields}
                call_kwargs[param_name] = model_class(**model_data)
            elif param_name in args:
                call_kwargs[param_name] = args[param_name]
            elif isinstance(param.default, FieldInfo) and not param.default.is_required():
                # The endpoint is called directly, not through FastAPI, so a
                # Query(...) default must be replaced by the value it wraps.
                call_kwargs[param_name] = copy.deepcopy(param.default.default)
        return call_kwargs

    def json_default(obj):
        """Fallback serializer for objects json.dumps cannot encode itself."""
        if hasattr(obj, "model_dump"):
            return obj.model_dump()
        if hasattr(obj, "__dict__"):
            return obj.__dict__
        return str(obj)

    # Create a function to register all tools from FastAPI to MCP server
    def register_api_tools(mcp_server, fastapi_app):
        """Register all FastAPI endpoints as MCP tools and return the server."""

        @mcp_server.list_tools()
        async def list_tools():
            """List all available API endpoints as tools."""
            try:
                tools = convert_endpoint_to_tool(fastapi_app)
                logger.info("list_tools handler called - found %d tools", len(tools))
                result = [
                    types.Tool(
                        name=tool["name"],
                        description=tool["description"],
                        inputSchema=tool["inputSchema"],
                    )
                    for tool in tools
                ]
                logger.info("Returning %d tools to client", len(result))
                return result
            except Exception:
                logger.exception("ERROR in list_tools handler")
                # Return an empty list to prevent client errors
                return []

        @mcp_server.call_tool()
        async def call_tool(name: str, arguments: dict):
            """Call an API endpoint as a tool.

            Raises ValueError for an unknown tool name. Errors raised while
            calling the endpoint are returned to the client as text.
            """
            logger.info("Tool called: %s with arguments: %s", name, arguments)
            tools = convert_endpoint_to_tool(fastapi_app)
            tool = next((t for t in tools if t["name"] == name), None)
            if tool is None:
                raise ValueError(f"Unknown tool: {name}")
            endpoint_func = tool["endpoint"]
            path = tool["path"]
            try:
                call_kwargs = build_call_kwargs(endpoint_func, path, arguments)
                result = endpoint_func(**call_kwargs)
                # Supports both async and plain endpoint functions.
                if inspect.isawaitable(result):
                    result = await result
                if isinstance(result, str):
                    text = result
                else:
                    text = json.dumps(result, default=json_default)
            except Exception as exc:
                logger.exception("Error calling endpoint %s", path)
                # Return the error as text
                text = f"Error calling endpoint {path}: {exc}"
            return [types.TextContent(type="text", text=text)]

        return mcp_server

    # Register API tools from the FastAPI app to the MCP server
    register_api_tools(mcp_app, app)

    # Add startup event to check tools after app initialization
    # NOTE(review): newer FastAPI releases prefer a lifespan handler over
    # on_event; confirm against the FastAPI version in use.
    @app.on_event("startup")
    async def startup_event():
        """Log the tools and routes found once the app has started."""
        tools = convert_endpoint_to_tool(app)
        logger.info("App startup: Found %d API endpoints to be exposed as tools", len(tools))
        logger.info("DEBUG - All registered FastAPI routes:")
        for route in app.routes:
            if hasattr(route, "endpoint") and hasattr(route, "path"):
                func_name = getattr(route.endpoint, "__name__", "unknown")
                route_path = getattr(route, "path", "unknown")
                methods = getattr(route, "methods", ["unknown"])
                logger.info("  Route: %s | Methods: %s | Function: %s", route_path, methods, func_name)
            else:
                logger.info("  Other route: %s", route)

    if transport == "sse":
        from fastapi import Request
        from mcp.server.sse import SseServerTransport

        sse = SseServerTransport("/mcp/messages/")

        # Define MCP connection handler
        async def handle_mcp_connection(request: Request):
            # connect_sse is given the request's private _send attribute.
            async with sse.connect_sse(request.scope, request.receive, request._send) as streams:
                await mcp_app.run(
                    streams[0],
                    streams[1],
                    mcp_app.create_initialization_options(),
                )

        # Mount the MCP connection handler
        app.get("/mcp")(handle_mcp_connection)
        app.mount("/mcp/messages/", app=sse.handle_post_message)
        # host 0.0.0.0 listens on all network interfaces.
        uvicorn.run(app, host="0.0.0.0", port=port)
    else:
        from mcp.server.stdio import stdio_server

        async def arun():
            async with stdio_server() as streams:
                await mcp_app.run(
                    streams[0], streams[1], mcp_app.create_initialization_options()
                )

        anyio.run(arun)
    return 0
if __name__ == "__main__":
    # main is a click command: called without arguments here.
    main()
@dexhunter
Copy link
Author

dexhunter commented Mar 19, 2025

Run the script with:

python simple_server.py --transport sse --port 8000

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment