Created
November 14, 2024 14:52
-
-
Save Streamweaver/52b61f4e08364851dd3c6ca4334f13a9 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import contextlib
import logging
import re
import xml.etree.ElementTree as ET
from typing import Any, AsyncIterator, Callable, Iterator, Literal, Optional, Union
from xml.etree.ElementTree import TreeBuilder

from bs4 import BeautifulSoup
from langchain_core.exceptions import OutputParserException
from langchain_core.messages import BaseMessage
from langchain_core.output_parsers.transform import BaseTransformOutputParser
from langchain_core.runnables.utils import AddableDict
class RobustStreamingParser:
    """Internal streaming parser with error handling.

    Chunks of text (or messages whose ``content`` is a string) are cleaned,
    buffered until the first XML start tag is seen, and then fed to an
    ``xml.etree.ElementTree.XMLPullParser``.  One ``AddableDict`` is yielded
    for every element that closes without child elements, nested under the
    tags of its ancestors.

    Errors raised while processing a chunk are logged and, when an
    ``error_handler`` was supplied, forwarded to it instead of propagating.
    """

    # C0 control characters (tab, LF and CR are kept) plus DEL; stripped from
    # every chunk before parsing.
    _INVALID_XML_CHARS = re.compile(r"[\x00-\x08\x0B\x0C\x0E-\x1F\x7F]")

    def __init__(
        self,
        parser: Literal["defusedxml", "xml"] = "defusedxml",
        error_handler: Optional[Callable[..., Any]] = None,
    ) -> None:
        """Initialize the streaming parser with error handling.

        Args:
            parser: ``"defusedxml"`` to parse with defusedxml (falls back to
                the standard library parser when it is not installed), or
                ``"xml"`` for the standard library parser.
            error_handler: Optional callable invoked as
                ``error_handler(exception, context)`` where ``context`` is the
                offending element or chunk.
        """
        self.logger = logging.getLogger(__name__)
        self._setup_parser(parser)
        self.error_handler = error_handler
        self.xml_start_re = re.compile(r"<[a-zA-Z:_]")
        self.current_path: list[str] = []
        self.current_path_has_children = False
        self.buffer = ""
        self.xml_started = False
        self.partial_tag_buffer = ""
        self.in_cdata = False

    def _setup_parser(self, parser_type: str) -> None:
        """Create ``self.pull_parser`` for the requested parser backend."""
        _parser = None
        if parser_type == "defusedxml":
            try:
                import defusedxml.ElementTree
            except ImportError:
                self.logger.warning("defusedxml not installed, falling back to standard xml")
            else:
                _parser = defusedxml.ElementTree.DefusedXMLParser(target=TreeBuilder())
        # XMLPullParser takes a custom parser only through the keyword-only
        # ``_parser`` argument; any other keyword raises TypeError.
        self.pull_parser = ET.XMLPullParser(["start", "end"], _parser=_parser)

    def _clean_chunk(self, chunk: str) -> str:
        """Return *chunk* with control characters removed."""
        return self._INVALID_XML_CHARS.sub("", chunk)

    def _handle_incomplete_tag(self, chunk: str) -> str:
        """Hold back a trailing, unterminated tag until it is completed.

        Any fragment held back by a previous call is prepended first, so
        consecutive incomplete chunks accumulate instead of overwriting each
        other.  If the last ``<`` of the combined text has no ``>`` after it,
        everything from that ``<`` on is stored in ``partial_tag_buffer``.

        Returns:
            The part of the text that ends on a tag boundary (possibly empty).
        """
        if self.partial_tag_buffer:
            chunk = self.partial_tag_buffer + chunk
            self.partial_tag_buffer = ""
        last_open = chunk.rfind("<")
        if last_open > chunk.rfind(">"):
            self.partial_tag_buffer = chunk[last_open:]
            return chunk[:last_open]
        return chunk

    def parse(self, chunk: Union[str, BaseMessage]) -> Iterator[AddableDict]:
        """Parse a chunk of XML with error handling.

        Args:
            chunk: Raw text, or a message whose ``content`` is a string.
                Messages with non-string content are ignored.

        Yields:
            One ``AddableDict`` per element that ended without child
            elements, nested under its ancestors' tags.
        """
        try:
            if isinstance(chunk, BaseMessage):
                chunk_content = chunk.content
                if not isinstance(chunk_content, str):
                    return
                chunk = chunk_content

            chunk = self._clean_chunk(chunk)

            # Re-attach a held-back tag fragment before anything else so a
            # "<" split from "![CDATA[" is kept in order.
            if self.partial_tag_buffer:
                chunk = self.partial_tag_buffer + chunk
                self.partial_tag_buffer = ""

            # Track CDATA by the position of the *last* marker.  While inside
            # a section, include the buffer tail so a "]]>" split across two
            # chunks is still recognised.
            probe = self.buffer[-2:] + chunk if self.in_cdata else chunk
            cdata_open = probe.rfind("![CDATA[")
            cdata_close = probe.rfind("]]>")
            if cdata_open != -1 or cdata_close != -1:
                self.in_cdata = cdata_open > cdata_close
            if self.in_cdata:
                # CDATA may contain "<" and ">"; buffer it untouched.
                self.buffer += chunk
                return

            chunk = self._handle_incomplete_tag(chunk)
            if not chunk:
                return
            self.buffer += chunk

            # Discard any preamble before the first XML start tag.
            if not self.xml_started:
                match = self.xml_start_re.search(self.buffer)
                if match is None:
                    return
                self.buffer = self.buffer[match.start():]
                self.xml_started = True

            # XMLPullParser.feed() normally queues parse errors and re-raises
            # them from read_events(); this guard only covers a parser that
            # raises immediately.
            try:
                self.pull_parser.feed(self.buffer)
                self.buffer = ""
            except ET.ParseError as e:
                self.logger.debug("Partial XML parse error (expected in streaming): %s", e)
                return

            for event, elem in self.pull_parser.read_events():
                try:
                    if event == "start":
                        self.current_path.append(elem.tag)
                        self.current_path_has_children = False
                    elif event == "end":
                        self.current_path.pop()
                        # Only leaf elements are emitted.
                        if not self.current_path_has_children:
                            yield self._create_nested_element(self.current_path, elem)
                        if self.current_path:
                            self.current_path_has_children = True
                        else:
                            # Root closed: wait for the next XML start tag.
                            self.xml_started = False
                except Exception as e:
                    self.logger.error("Error processing XML element: %s", e)
                    if self.error_handler:
                        self.error_handler(e, elem)
        except Exception as e:
            self.logger.error("Error in stream parsing: %s", e)
            if self.error_handler:
                self.error_handler(e, chunk)

    def _create_nested_element(self, path: list[str], elem: ET.Element) -> AddableDict:
        """Create nested dictionary structure from XML element.

        Each tag in *path* becomes a key whose value is a one-item list
        holding the next level; the innermost dict maps ``elem.tag`` to its
        text (``""`` when the element has no text).
        """
        try:
            if not path:
                return AddableDict({elem.tag: elem.text or ""})
            return AddableDict({path[0]: [self._create_nested_element(path[1:], elem)]})
        except Exception as e:
            self.logger.error("Error creating nested element: %s", e)
            return AddableDict({path[0] if path else "error": str(e)})

    def close(self) -> None:
        """Close the parser and process any remaining buffer.

        Parse errors raised while flushing or closing are suppressed.
        """
        with contextlib.suppress(ET.ParseError):
            if self.buffer:
                self.pull_parser.feed(self.buffer)
            self.pull_parser.close()
class RobustStreamingXMLOutputParser(BaseTransformOutputParser):
    """XML Output Parser with robust streaming support.

    Streaming input is delegated to ``RobustStreamingParser``; complete
    documents are parsed by ``parse`` and converted to nested dictionaries.
    """

    # The LangChain base class is a pydantic model, which rejects assignment
    # to attributes that are not declared as fields, so the configuration is
    # declared here rather than set ad hoc in __init__.
    tags: Optional[list[str]] = None
    parser: Literal["defusedxml", "xml"] = "defusedxml"
    error_handler: Optional[Callable[..., Any]] = None

    def __init__(
        self,
        tags: Optional[list[str]] = None,
        parser: Literal["defusedxml", "xml"] = "defusedxml",
        error_handler: Optional[Callable[..., Any]] = None,
        **kwargs: Any,
    ) -> None:
        """Create the parser.

        Args:
            tags: Optional list of tag names; stored but not otherwise used
                by this class.
            parser: ``"defusedxml"`` or ``"xml"`` backend.
            error_handler: Optional callable invoked as
                ``error_handler(exception, context)``; in ``parse`` its return
                value is returned to the caller.
            **kwargs: Forwarded to the base class constructor.
        """
        super().__init__(tags=tags, parser=parser, error_handler=error_handler, **kwargs)

    @property
    def logger(self) -> logging.Logger:
        """Module logger (a property, so it is not a model field)."""
        return logging.getLogger(__name__)

    def _transform(self, input: Iterator[Union[str, BaseMessage]]) -> Iterator[AddableDict]:
        """Transform streaming input to parsed XML."""
        streaming_parser = RobustStreamingParser(self.parser, self.error_handler)
        try:
            for chunk in input:
                yield from streaming_parser.parse(chunk)
        finally:
            # Close the parser even if the consumer stops early or an
            # exception propagates.
            streaming_parser.close()

    async def _atransform(
        self, input: AsyncIterator[Union[str, BaseMessage]]
    ) -> AsyncIterator[AddableDict]:
        """Transform async streaming input to parsed XML."""
        streaming_parser = RobustStreamingParser(self.parser, self.error_handler)
        try:
            async for chunk in input:
                for output in streaming_parser.parse(chunk):
                    yield output
        finally:
            streaming_parser.close()

    def parse(self, text: str) -> dict[str, Union[str, list[Any]]]:
        """Parse complete XML text.

        A surrounding Markdown code fence (optionally tagged ``xml``) is
        removed, control characters are stripped, and surrounding whitespace
        is trimmed so that an ``<?xml ...?>`` declaration sits at the start
        of the document.

        Returns:
            The document as a nested dictionary, or the return value of
            ``error_handler`` when parsing fails and a handler is set.

        Raises:
            OutputParserException: If parsing fails and no handler is set.
        """
        match = re.search(r"```(?:xml)?(.*)```", text, re.DOTALL)
        if match:
            text = match.group(1)
        text = re.sub(r'[\x00-\x08\x0B\x0C\x0E-\x1F\x7F]', '', text).strip()
        try:
            if self.parser == "defusedxml":
                from defusedxml import ElementTree as DefusedET
                root = DefusedET.fromstring(text)
            else:
                root = ET.fromstring(text)
            return self._root_to_dict(root)
        except Exception as e:
            self.logger.error("Failed to parse XML: %s", e)
            if self.error_handler:
                return self.error_handler(e, text)
            raise OutputParserException(f"Failed to parse XML: {e}", llm_output=text) from e

    def _root_to_dict(self, root: ET.Element) -> dict[str, Union[str, list[Any]]]:
        """Convert XML tree to dictionary.

        An element with non-whitespace text maps its tag to that text;
        otherwise it maps its tag to a list with one entry per child.
        Conversion errors are logged and reported as an ``"Error: ..."``
        string under the element's tag.
        """
        try:
            if root.text and bool(re.search(r"\S", root.text)):
                return {root.tag: root.text}
            result: dict = {root.tag: []}
            for child in root:
                if len(child) == 0:
                    # Leaf child: record its text directly.
                    result[root.tag].append({child.tag: child.text or ""})
                else:
                    result[root.tag].append(self._root_to_dict(child))
            return result
        except Exception as e:
            self.logger.error("Error converting XML to dict: %s", e)
            return {root.tag: f"Error: {str(e)}"}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment