Skip to content

Instantly share code, notes, and snippets.

@Jagdeep1
Created February 25, 2026 14:38
Show Gist options
  • Select an option

  • Save Jagdeep1/c51b301c5a2c34e6b62d7c1c3815fa6a to your computer and use it in GitHub Desktop.

Select an option

Save Jagdeep1/c51b301c5a2c34e6b62d7c1c3815fa6a to your computer and use it in GitHub Desktop.
anthropic_bedrock_model.py
"""Custom Strands model provider using Anthropic SDK with AWS Bedrock."""
import base64
import json
import logging
import mimetypes
from collections.abc import AsyncGenerator
from typing import Any, TypedDict, TypeVar, cast
import anthropic
from anthropic import AsyncAnthropicBedrock
from pydantic import BaseModel
from typing_extensions import Required, Unpack, override
from strands.event_loop.streaming import process_stream
from strands.models._validation import validate_config_keys
from strands.models.model import Model
from strands.tools.structured_output.structured_output_utils import convert_pydantic_to_tool_spec
from strands.types.content import ContentBlock, Messages
from strands.types.exceptions import ContextWindowOverflowException, ModelThrottledException
from strands.types.streaming import StreamEvent
from strands.types.tools import ToolChoice, ToolChoiceToolDict, ToolSpec
# Module logger; records are emitted under this module's import path.
logger = logging.getLogger(name=__name__)
# Type variable for the Pydantic model class accepted by ``structured_output``.
T = TypeVar("T", bound=BaseModel)
class AnthropicBedrockModel(Model):
"""Anthropic Bedrock model provider using AsyncAnthropicBedrock from the Anthropic SDK."""
EVENT_TYPES = {
"message_start",
"content_block_start",
"content_block_delta",
"content_block_stop",
"message_stop",
}
OVERFLOW_MESSAGES = {
"prompt is too long:",
"input is too long",
"input length exceeds context window",
"input and output tokens exceed your context limit",
}
class BedrockConfig(TypedDict, total=False):
"""Configuration options for Anthropic Bedrock models.
Attributes:
model_id: Bedrock model ID (e.g., "us.anthropic.claude-sonnet-4-20250514-v1:0").
max_tokens: Maximum number of tokens to generate.
params: Additional model parameters (e.g., temperature).
"""
model_id: Required[str]
max_tokens: Required[int]
params: dict[str, Any] | None
def __init__(
self,
*,
aws_region: str | None = None,
aws_profile: str | None = None,
aws_access_key: str | None = None,
aws_secret_key: str | None = None,
aws_session_token: str | None = None,
**model_config: Unpack[BedrockConfig],
):
"""Initialize the Anthropic Bedrock model provider.
Args:
aws_region: AWS region for Bedrock (e.g., "us-east-1").
aws_profile: AWS profile name from ~/.aws/credentials.
aws_access_key: AWS access key ID.
aws_secret_key: AWS secret access key.
aws_session_token: AWS session token for temporary credentials.
**model_config: Configuration options for the model (model_id, max_tokens, params).
"""
validate_config_keys(model_config, self.BedrockConfig)
self.config = AnthropicBedrockModel.BedrockConfig(**model_config)
logger.debug("config=<%s> | initializing", self.config)
client_kwargs: dict[str, Any] = {}
if aws_region:
client_kwargs["aws_region"] = aws_region
if aws_profile:
client_kwargs["aws_profile"] = aws_profile
if aws_access_key:
client_kwargs["aws_access_key"] = aws_access_key
if aws_secret_key:
client_kwargs["aws_secret_key"] = aws_secret_key
if aws_session_token:
client_kwargs["aws_session_token"] = aws_session_token
self.client = AsyncAnthropicBedrock(**client_kwargs)
@override
def update_config(self, **model_config: Unpack[BedrockConfig]) -> None: # type: ignore[override]
"""Update the model configuration.
Args:
**model_config: Configuration overrides.
"""
validate_config_keys(model_config, self.BedrockConfig)
self.config.update(model_config)
@override
def get_config(self) -> BedrockConfig:
"""Get the model configuration.
Returns:
The model configuration.
"""
return self.config
def _format_request_message_content(self, content: ContentBlock) -> dict[str, Any]:
"""Convert a Strands ContentBlock to Anthropic API content format.
Args:
content: Strands message content block.
Returns:
Anthropic-formatted content block.
Raises:
TypeError: If the content block type is unsupported.
"""
if "document" in content:
mime_type = mimetypes.types_map.get(f".{content['document']['format']}", "application/octet-stream")
return {
"source": {
"data": (
content["document"]["source"]["bytes"].decode("utf-8")
if mime_type == "text/plain"
else base64.b64encode(content["document"]["source"]["bytes"]).decode("utf-8")
),
"media_type": mime_type,
"type": "text" if mime_type == "text/plain" else "base64",
},
"title": content["document"]["name"],
"type": "document",
}
if "image" in content:
return {
"source": {
"data": base64.b64encode(content["image"]["source"]["bytes"]).decode("utf-8"),
"media_type": mimetypes.types_map.get(
f".{content['image']['format']}", "application/octet-stream"
),
"type": "base64",
},
"type": "image",
}
if "reasoningContent" in content:
return {
"signature": content["reasoningContent"]["reasoningText"]["signature"],
"thinking": content["reasoningContent"]["reasoningText"]["text"],
"type": "thinking",
}
if "text" in content:
return {"text": content["text"], "type": "text"}
if "toolUse" in content:
return {
"id": content["toolUse"]["toolUseId"],
"input": content["toolUse"]["input"],
"name": content["toolUse"]["name"],
"type": "tool_use",
}
if "toolResult" in content:
return {
"content": [
self._format_request_message_content(
{"text": json.dumps(tool_result_content["json"])}
if "json" in tool_result_content
else cast(ContentBlock, tool_result_content)
)
for tool_result_content in content["toolResult"]["content"]
],
"is_error": content["toolResult"]["status"] == "error",
"tool_use_id": content["toolResult"]["toolUseId"],
"type": "tool_result",
}
raise TypeError(f"content_type=<{next(iter(content))}> | unsupported type")
def _format_request_messages(self, messages: Messages) -> list[dict[str, Any]]:
"""Convert Strands Messages list to Anthropic messages format.
Args:
messages: List of Strands message objects.
Returns:
Anthropic-formatted messages list.
"""
formatted_messages = []
for message in messages:
formatted_contents: list[dict[str, Any]] = []
for content in message["content"]:
if "cachePoint" in content:
formatted_contents[-1]["cache_control"] = {"type": "ephemeral"}
continue
formatted_contents.append(self._format_request_message_content(content))
if formatted_contents:
formatted_messages.append({"content": formatted_contents, "role": message["role"]})
return formatted_messages
def format_request(
self,
messages: Messages,
tool_specs: list[ToolSpec] | None = None,
system_prompt: str | None = None,
tool_choice: ToolChoice | None = None,
) -> dict[str, Any]:
"""Assemble the full Anthropic API request dict.
Args:
messages: List of message objects to be processed by the model.
tool_specs: List of tool specifications to make available to the model.
system_prompt: System prompt to provide context to the model.
tool_choice: Selection strategy for tool invocation.
Returns:
Anthropic API request dictionary.
"""
return {
"max_tokens": self.config["max_tokens"],
"messages": self._format_request_messages(messages),
"model": self.config["model_id"],
"tools": [
{
"name": tool_spec["name"],
"description": tool_spec["description"],
"input_schema": tool_spec["inputSchema"]["json"],
}
for tool_spec in tool_specs or []
],
**(self._format_tool_choice(tool_choice)),
**({"system": system_prompt} if system_prompt else {}),
**(self.config.get("params") or {}),
}
@staticmethod
def _format_tool_choice(tool_choice: ToolChoice | None) -> dict:
if tool_choice is None:
return {}
if "any" in tool_choice:
return {"tool_choice": {"type": "any"}}
elif "auto" in tool_choice:
return {"tool_choice": {"type": "auto"}}
elif "tool" in tool_choice:
return {"tool_choice": {"type": "tool", "name": cast(ToolChoiceToolDict, tool_choice)["tool"]["name"]}}
else:
return {}
def format_chunk(self, event: dict[str, Any]) -> StreamEvent:
"""Convert Anthropic streaming events to Strands StreamEvent dicts.
Args:
event: A response event from the Anthropic API.
Returns:
Strands-formatted stream event.
Raises:
RuntimeError: If the event type is not recognized.
"""
match event["type"]:
case "message_start":
return {"messageStart": {"role": "assistant"}}
case "content_block_start":
content = event["content_block"]
if content["type"] == "tool_use":
return {
"contentBlockStart": {
"contentBlockIndex": event["index"],
"start": {
"toolUse": {
"name": content["name"],
"toolUseId": content["id"],
}
},
}
}
return {"contentBlockStart": {"contentBlockIndex": event["index"], "start": {}}}
case "content_block_delta":
delta = event["delta"]
match delta["type"]:
case "signature_delta":
return {
"contentBlockDelta": {
"contentBlockIndex": event["index"],
"delta": {
"reasoningContent": {
"signature": delta["signature"],
},
},
},
}
case "thinking_delta":
return {
"contentBlockDelta": {
"contentBlockIndex": event["index"],
"delta": {
"reasoningContent": {
"text": delta["thinking"],
},
},
},
}
case "input_json_delta":
return {
"contentBlockDelta": {
"contentBlockIndex": event["index"],
"delta": {
"toolUse": {
"input": delta["partial_json"],
},
},
},
}
case "text_delta":
return {
"contentBlockDelta": {
"contentBlockIndex": event["index"],
"delta": {
"text": delta["text"],
},
},
}
case _:
raise RuntimeError(
f"event_type=<content_block_delta>, delta_type=<{delta['type']}> | unknown type"
)
case "content_block_stop":
return {"contentBlockStop": {"contentBlockIndex": event["index"]}}
case "message_stop":
message = event["message"]
return {"messageStop": {"stopReason": message["stop_reason"]}}
case "metadata":
usage = event["usage"]
return {
"metadata": {
"usage": {
"inputTokens": usage["input_tokens"],
"outputTokens": usage["output_tokens"],
"totalTokens": usage["input_tokens"] + usage["output_tokens"],
},
"metrics": {
"latencyMs": 0,
},
}
}
case _:
raise RuntimeError(f"event_type=<{event['type']} | unknown type")
@override
async def stream(
self,
messages: Messages,
tool_specs: list[ToolSpec] | None = None,
system_prompt: str | None = None,
*,
tool_choice: ToolChoice | None = None,
**kwargs: Any,
) -> AsyncGenerator[StreamEvent, None]:
"""Stream conversation with Claude via Bedrock.
Args:
messages: List of message objects to be processed by the model.
tool_specs: List of tool specifications to make available to the model.
system_prompt: System prompt to provide context to the model.
tool_choice: Selection strategy for tool invocation.
**kwargs: Additional keyword arguments.
Yields:
Formatted stream events from the model.
Raises:
ContextWindowOverflowException: If the input exceeds the model's context window.
ModelThrottledException: If the request is throttled.
"""
logger.debug("formatting request")
request = self.format_request(messages, tool_specs, system_prompt, tool_choice)
logger.debug("request=<%s>", request)
logger.debug("invoking model")
try:
async with self.client.messages.stream(**request) as stream:
logger.debug("got response from model")
async for event in stream:
if event.type in AnthropicBedrockModel.EVENT_TYPES:
yield self.format_chunk(event.model_dump())
usage = event.message.usage # type: ignore
yield self.format_chunk({"type": "metadata", "usage": usage.model_dump()})
except anthropic.RateLimitError as error:
raise ModelThrottledException(str(error)) from error
except anthropic.BadRequestError as error:
if any(
overflow_message in str(error).lower() for overflow_message in AnthropicBedrockModel.OVERFLOW_MESSAGES
):
raise ContextWindowOverflowException(str(error)) from error
raise error
logger.debug("finished streaming response from model")
@override
async def structured_output(
self, output_model: type[T], prompt: Messages, system_prompt: str | None = None, **kwargs: Any
) -> AsyncGenerator[dict[str, T | Any], None]:
"""Get structured output from the model.
Args:
output_model: Pydantic model class for the expected output.
prompt: The prompt messages.
system_prompt: System prompt to provide context to the model.
**kwargs: Additional keyword arguments.
Yields:
Model events with the last being the structured output.
"""
tool_spec = convert_pydantic_to_tool_spec(output_model)
response = self.stream(
messages=prompt,
tool_specs=[tool_spec],
system_prompt=system_prompt,
tool_choice=cast(ToolChoice, {"any": {}}),
**kwargs,
)
async for event in process_stream(response):
yield event
stop_reason, messages, _, _ = event["stop"]
if stop_reason != "tool_use":
raise ValueError(f'Model returned stop_reason: {stop_reason} instead of "tool_use".')
content = messages["content"]
output_response: dict[str, Any] | None = None
for block in content:
if block.get("toolUse") and block["toolUse"]["name"] == tool_spec["name"]:
output_response = block["toolUse"]["input"]
else:
continue
if output_response is None:
raise ValueError("No valid tool use or tool use input was found in the response.")
yield {"output": output_model(**output_response)}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment