Created
February 25, 2026 14:38
-
-
Save Jagdeep1/c51b301c5a2c34e6b62d7c1c3815fa6a to your computer and use it in GitHub Desktop.
anthropic_bedrock_model.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| """Custom Strands model provider using Anthropic SDK with AWS Bedrock.""" | |
| import base64 | |
| import json | |
| import logging | |
| import mimetypes | |
| from collections.abc import AsyncGenerator | |
| from typing import Any, TypedDict, TypeVar, cast | |
| import anthropic | |
| from anthropic import AsyncAnthropicBedrock | |
| from pydantic import BaseModel | |
| from typing_extensions import Required, Unpack, override | |
| from strands.event_loop.streaming import process_stream | |
| from strands.models._validation import validate_config_keys | |
| from strands.models.model import Model | |
| from strands.tools.structured_output.structured_output_utils import convert_pydantic_to_tool_spec | |
| from strands.types.content import ContentBlock, Messages | |
| from strands.types.exceptions import ContextWindowOverflowException, ModelThrottledException | |
| from strands.types.streaming import StreamEvent | |
| from strands.types.tools import ToolChoice, ToolChoiceToolDict, ToolSpec | |
# Module-level logger, following the stdlib one-logger-per-module convention.
logger = logging.getLogger(__name__)

# Result type for structured_output: bound to pydantic BaseModel so the parsed
# tool input can be returned as an instance of the caller's model class.
T = TypeVar("T", bound=BaseModel)
class AnthropicBedrockModel(Model):
    """Anthropic Bedrock model provider using AsyncAnthropicBedrock from the Anthropic SDK."""

    # Streaming event types that format_chunk can translate into Strands
    # StreamEvents; stream() silently skips any other SDK event type.
    EVENT_TYPES = {
        "message_start",
        "content_block_start",
        "content_block_delta",
        "content_block_stop",
        "message_stop",
    }

    # Substrings matched (case-insensitively, via str(error).lower()) against a
    # BadRequestError to classify it as a context-window overflow.
    OVERFLOW_MESSAGES = {
        "prompt is too long:",
        "input is too long",
        "input length exceeds context window",
        "input and output tokens exceed your context limit",
    }

    class BedrockConfig(TypedDict, total=False):
        """Configuration options for Anthropic Bedrock models.

        Attributes:
            model_id: Bedrock model ID (e.g., "us.anthropic.claude-sonnet-4-20250514-v1:0").
            max_tokens: Maximum number of tokens to generate.
            params: Additional model parameters (e.g., temperature), merged into
                the request last so they can override other request fields.
        """

        model_id: Required[str]
        max_tokens: Required[int]
        params: dict[str, Any] | None
| def __init__( | |
| self, | |
| *, | |
| aws_region: str | None = None, | |
| aws_profile: str | None = None, | |
| aws_access_key: str | None = None, | |
| aws_secret_key: str | None = None, | |
| aws_session_token: str | None = None, | |
| **model_config: Unpack[BedrockConfig], | |
| ): | |
| """Initialize the Anthropic Bedrock model provider. | |
| Args: | |
| aws_region: AWS region for Bedrock (e.g., "us-east-1"). | |
| aws_profile: AWS profile name from ~/.aws/credentials. | |
| aws_access_key: AWS access key ID. | |
| aws_secret_key: AWS secret access key. | |
| aws_session_token: AWS session token for temporary credentials. | |
| **model_config: Configuration options for the model (model_id, max_tokens, params). | |
| """ | |
| validate_config_keys(model_config, self.BedrockConfig) | |
| self.config = AnthropicBedrockModel.BedrockConfig(**model_config) | |
| logger.debug("config=<%s> | initializing", self.config) | |
| client_kwargs: dict[str, Any] = {} | |
| if aws_region: | |
| client_kwargs["aws_region"] = aws_region | |
| if aws_profile: | |
| client_kwargs["aws_profile"] = aws_profile | |
| if aws_access_key: | |
| client_kwargs["aws_access_key"] = aws_access_key | |
| if aws_secret_key: | |
| client_kwargs["aws_secret_key"] = aws_secret_key | |
| if aws_session_token: | |
| client_kwargs["aws_session_token"] = aws_session_token | |
| self.client = AsyncAnthropicBedrock(**client_kwargs) | |
    @override
    def update_config(self, **model_config: Unpack[BedrockConfig]) -> None:  # type: ignore[override]
        """Update the model configuration.

        Only the supplied keys are overwritten; existing entries are preserved
        (dict.update semantics).

        Args:
            **model_config: Configuration overrides (any subset of BedrockConfig keys).
        """
        validate_config_keys(model_config, self.BedrockConfig)
        self.config.update(model_config)
    @override
    def get_config(self) -> BedrockConfig:
        """Get the model configuration.

        Returns:
            The live model configuration dict (not a copy); mutations affect
            subsequent requests.
        """
        return self.config
| def _format_request_message_content(self, content: ContentBlock) -> dict[str, Any]: | |
| """Convert a Strands ContentBlock to Anthropic API content format. | |
| Args: | |
| content: Strands message content block. | |
| Returns: | |
| Anthropic-formatted content block. | |
| Raises: | |
| TypeError: If the content block type is unsupported. | |
| """ | |
| if "document" in content: | |
| mime_type = mimetypes.types_map.get(f".{content['document']['format']}", "application/octet-stream") | |
| return { | |
| "source": { | |
| "data": ( | |
| content["document"]["source"]["bytes"].decode("utf-8") | |
| if mime_type == "text/plain" | |
| else base64.b64encode(content["document"]["source"]["bytes"]).decode("utf-8") | |
| ), | |
| "media_type": mime_type, | |
| "type": "text" if mime_type == "text/plain" else "base64", | |
| }, | |
| "title": content["document"]["name"], | |
| "type": "document", | |
| } | |
| if "image" in content: | |
| return { | |
| "source": { | |
| "data": base64.b64encode(content["image"]["source"]["bytes"]).decode("utf-8"), | |
| "media_type": mimetypes.types_map.get( | |
| f".{content['image']['format']}", "application/octet-stream" | |
| ), | |
| "type": "base64", | |
| }, | |
| "type": "image", | |
| } | |
| if "reasoningContent" in content: | |
| return { | |
| "signature": content["reasoningContent"]["reasoningText"]["signature"], | |
| "thinking": content["reasoningContent"]["reasoningText"]["text"], | |
| "type": "thinking", | |
| } | |
| if "text" in content: | |
| return {"text": content["text"], "type": "text"} | |
| if "toolUse" in content: | |
| return { | |
| "id": content["toolUse"]["toolUseId"], | |
| "input": content["toolUse"]["input"], | |
| "name": content["toolUse"]["name"], | |
| "type": "tool_use", | |
| } | |
| if "toolResult" in content: | |
| return { | |
| "content": [ | |
| self._format_request_message_content( | |
| {"text": json.dumps(tool_result_content["json"])} | |
| if "json" in tool_result_content | |
| else cast(ContentBlock, tool_result_content) | |
| ) | |
| for tool_result_content in content["toolResult"]["content"] | |
| ], | |
| "is_error": content["toolResult"]["status"] == "error", | |
| "tool_use_id": content["toolResult"]["toolUseId"], | |
| "type": "tool_result", | |
| } | |
| raise TypeError(f"content_type=<{next(iter(content))}> | unsupported type") | |
| def _format_request_messages(self, messages: Messages) -> list[dict[str, Any]]: | |
| """Convert Strands Messages list to Anthropic messages format. | |
| Args: | |
| messages: List of Strands message objects. | |
| Returns: | |
| Anthropic-formatted messages list. | |
| """ | |
| formatted_messages = [] | |
| for message in messages: | |
| formatted_contents: list[dict[str, Any]] = [] | |
| for content in message["content"]: | |
| if "cachePoint" in content: | |
| formatted_contents[-1]["cache_control"] = {"type": "ephemeral"} | |
| continue | |
| formatted_contents.append(self._format_request_message_content(content)) | |
| if formatted_contents: | |
| formatted_messages.append({"content": formatted_contents, "role": message["role"]}) | |
| return formatted_messages | |
| def format_request( | |
| self, | |
| messages: Messages, | |
| tool_specs: list[ToolSpec] | None = None, | |
| system_prompt: str | None = None, | |
| tool_choice: ToolChoice | None = None, | |
| ) -> dict[str, Any]: | |
| """Assemble the full Anthropic API request dict. | |
| Args: | |
| messages: List of message objects to be processed by the model. | |
| tool_specs: List of tool specifications to make available to the model. | |
| system_prompt: System prompt to provide context to the model. | |
| tool_choice: Selection strategy for tool invocation. | |
| Returns: | |
| Anthropic API request dictionary. | |
| """ | |
| return { | |
| "max_tokens": self.config["max_tokens"], | |
| "messages": self._format_request_messages(messages), | |
| "model": self.config["model_id"], | |
| "tools": [ | |
| { | |
| "name": tool_spec["name"], | |
| "description": tool_spec["description"], | |
| "input_schema": tool_spec["inputSchema"]["json"], | |
| } | |
| for tool_spec in tool_specs or [] | |
| ], | |
| **(self._format_tool_choice(tool_choice)), | |
| **({"system": system_prompt} if system_prompt else {}), | |
| **(self.config.get("params") or {}), | |
| } | |
| @staticmethod | |
| def _format_tool_choice(tool_choice: ToolChoice | None) -> dict: | |
| if tool_choice is None: | |
| return {} | |
| if "any" in tool_choice: | |
| return {"tool_choice": {"type": "any"}} | |
| elif "auto" in tool_choice: | |
| return {"tool_choice": {"type": "auto"}} | |
| elif "tool" in tool_choice: | |
| return {"tool_choice": {"type": "tool", "name": cast(ToolChoiceToolDict, tool_choice)["tool"]["name"]}} | |
| else: | |
| return {} | |
| def format_chunk(self, event: dict[str, Any]) -> StreamEvent: | |
| """Convert Anthropic streaming events to Strands StreamEvent dicts. | |
| Args: | |
| event: A response event from the Anthropic API. | |
| Returns: | |
| Strands-formatted stream event. | |
| Raises: | |
| RuntimeError: If the event type is not recognized. | |
| """ | |
| match event["type"]: | |
| case "message_start": | |
| return {"messageStart": {"role": "assistant"}} | |
| case "content_block_start": | |
| content = event["content_block"] | |
| if content["type"] == "tool_use": | |
| return { | |
| "contentBlockStart": { | |
| "contentBlockIndex": event["index"], | |
| "start": { | |
| "toolUse": { | |
| "name": content["name"], | |
| "toolUseId": content["id"], | |
| } | |
| }, | |
| } | |
| } | |
| return {"contentBlockStart": {"contentBlockIndex": event["index"], "start": {}}} | |
| case "content_block_delta": | |
| delta = event["delta"] | |
| match delta["type"]: | |
| case "signature_delta": | |
| return { | |
| "contentBlockDelta": { | |
| "contentBlockIndex": event["index"], | |
| "delta": { | |
| "reasoningContent": { | |
| "signature": delta["signature"], | |
| }, | |
| }, | |
| }, | |
| } | |
| case "thinking_delta": | |
| return { | |
| "contentBlockDelta": { | |
| "contentBlockIndex": event["index"], | |
| "delta": { | |
| "reasoningContent": { | |
| "text": delta["thinking"], | |
| }, | |
| }, | |
| }, | |
| } | |
| case "input_json_delta": | |
| return { | |
| "contentBlockDelta": { | |
| "contentBlockIndex": event["index"], | |
| "delta": { | |
| "toolUse": { | |
| "input": delta["partial_json"], | |
| }, | |
| }, | |
| }, | |
| } | |
| case "text_delta": | |
| return { | |
| "contentBlockDelta": { | |
| "contentBlockIndex": event["index"], | |
| "delta": { | |
| "text": delta["text"], | |
| }, | |
| }, | |
| } | |
| case _: | |
| raise RuntimeError( | |
| f"event_type=<content_block_delta>, delta_type=<{delta['type']}> | unknown type" | |
| ) | |
| case "content_block_stop": | |
| return {"contentBlockStop": {"contentBlockIndex": event["index"]}} | |
| case "message_stop": | |
| message = event["message"] | |
| return {"messageStop": {"stopReason": message["stop_reason"]}} | |
| case "metadata": | |
| usage = event["usage"] | |
| return { | |
| "metadata": { | |
| "usage": { | |
| "inputTokens": usage["input_tokens"], | |
| "outputTokens": usage["output_tokens"], | |
| "totalTokens": usage["input_tokens"] + usage["output_tokens"], | |
| }, | |
| "metrics": { | |
| "latencyMs": 0, | |
| }, | |
| } | |
| } | |
| case _: | |
| raise RuntimeError(f"event_type=<{event['type']} | unknown type") | |
    @override
    async def stream(
        self,
        messages: Messages,
        tool_specs: list[ToolSpec] | None = None,
        system_prompt: str | None = None,
        *,
        tool_choice: ToolChoice | None = None,
        **kwargs: Any,
    ) -> AsyncGenerator[StreamEvent, None]:
        """Stream conversation with Claude via Bedrock.

        Args:
            messages: List of message objects to be processed by the model.
            tool_specs: List of tool specifications to make available to the model.
            system_prompt: System prompt to provide context to the model.
            tool_choice: Selection strategy for tool invocation.
            **kwargs: Additional keyword arguments (currently unused).

        Yields:
            Formatted stream events from the model, followed by one synthetic
            "metadata" event carrying token usage.

        Raises:
            ContextWindowOverflowException: If the input exceeds the model's context window.
            ModelThrottledException: If the request is throttled.
        """
        logger.debug("formatting request")
        request = self.format_request(messages, tool_specs, system_prompt, tool_choice)
        logger.debug("request=<%s>", request)

        logger.debug("invoking model")
        try:
            async with self.client.messages.stream(**request) as stream:
                logger.debug("got response from model")
                async for event in stream:
                    # Forward only event types format_chunk understands; other
                    # SDK event types are dropped.
                    if event.type in AnthropicBedrockModel.EVENT_TYPES:
                        yield self.format_chunk(event.model_dump())

                # Usage is read from the last event seen in the loop.
                # NOTE(review): assumes the stream yielded at least one event and
                # that the final event exposes .message.usage — TODO confirm
                # against the Anthropic SDK stream contract.
                usage = event.message.usage  # type: ignore
                yield self.format_chunk({"type": "metadata", "usage": usage.model_dump()})
        except anthropic.RateLimitError as error:
            # Map SDK throttling to the Strands throttling exception.
            raise ModelThrottledException(str(error)) from error
        except anthropic.BadRequestError as error:
            # Bad requests whose message matches a known overflow phrase are
            # surfaced as context-window overflows; everything else re-raises.
            if any(
                overflow_message in str(error).lower() for overflow_message in AnthropicBedrockModel.OVERFLOW_MESSAGES
            ):
                raise ContextWindowOverflowException(str(error)) from error
            raise error

        logger.debug("finished streaming response from model")
| @override | |
| async def structured_output( | |
| self, output_model: type[T], prompt: Messages, system_prompt: str | None = None, **kwargs: Any | |
| ) -> AsyncGenerator[dict[str, T | Any], None]: | |
| """Get structured output from the model. | |
| Args: | |
| output_model: Pydantic model class for the expected output. | |
| prompt: The prompt messages. | |
| system_prompt: System prompt to provide context to the model. | |
| **kwargs: Additional keyword arguments. | |
| Yields: | |
| Model events with the last being the structured output. | |
| """ | |
| tool_spec = convert_pydantic_to_tool_spec(output_model) | |
| response = self.stream( | |
| messages=prompt, | |
| tool_specs=[tool_spec], | |
| system_prompt=system_prompt, | |
| tool_choice=cast(ToolChoice, {"any": {}}), | |
| **kwargs, | |
| ) | |
| async for event in process_stream(response): | |
| yield event | |
| stop_reason, messages, _, _ = event["stop"] | |
| if stop_reason != "tool_use": | |
| raise ValueError(f'Model returned stop_reason: {stop_reason} instead of "tool_use".') | |
| content = messages["content"] | |
| output_response: dict[str, Any] | None = None | |
| for block in content: | |
| if block.get("toolUse") and block["toolUse"]["name"] == tool_spec["name"]: | |
| output_response = block["toolUse"]["input"] | |
| else: | |
| continue | |
| if output_response is None: | |
| raise ValueError("No valid tool use or tool use input was found in the response.") | |
| yield {"output": output_model(**output_response)} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment