Skip to content

Instantly share code, notes, and snippets.

@Streamweaver
Created November 14, 2024 14:52
Show Gist options
  • Save Streamweaver/52b61f4e08364851dd3c6ca4334f13a9 to your computer and use it in GitHub Desktop.
Save Streamweaver/52b61f4e08364851dd3c6ca4334f13a9 to your computer and use it in GitHub Desktop.
import contextlib
import logging
import re
import xml.etree.ElementTree as ET
from typing import Any, AsyncIterator, Callable, Iterator, Literal, Optional, Union
from xml.etree.ElementTree import TreeBuilder

from bs4 import BeautifulSoup
from langchain_core.exceptions import OutputParserException
from langchain_core.messages import BaseMessage
from langchain_core.output_parsers.transform import BaseTransformOutputParser
from langchain_core.runnables.utils import AddableDict
class RobustStreamingParser:
    """Incremental XML parser that yields nested dicts from a text stream.

    Each chunk is stripped of XML-illegal control characters. A trailing
    unterminated ``<...`` tag is held back until the next chunk. Text before
    the first opening tag is discarded. The rest is fed to an
    ``xml.etree.ElementTree.XMLPullParser``. Every element that closes without
    having had child elements is yielded as an ``AddableDict``, nested under
    the tags of its still-open ancestors. Errors raised inside :meth:`parse`
    are logged and passed to ``error_handler`` (if given) instead of being
    raised.
    """

    def __init__(
        self,
        parser: Literal["defusedxml", "xml"] = "defusedxml",
        error_handler: Optional[Callable[..., Any]] = None,
    ) -> None:
        """Initialize the streaming parser.

        Args:
            parser: ``"defusedxml"`` to use defusedxml's hardened parser. If
                defusedxml is not installed, a warning is logged and the
                standard library parser is used. ``"xml"`` selects the
                standard library parser directly.
            error_handler: Optional callback, invoked as
                ``error_handler(exc, context)`` when parsing fails. ``context``
                is the offending chunk or element.
        """
        self.logger = logging.getLogger(__name__)
        self._setup_parser(parser)
        self.error_handler = error_handler
        # First opening tag; everything before it in the buffer is dropped.
        self.xml_start_re = re.compile(r"<[a-zA-Z:_]")
        self.current_path: list[str] = []
        self.current_path_has_children = False
        self.buffer = ""
        self.xml_started = False
        self.partial_tag_buffer = ""
        self.in_cdata = False

    def _setup_parser(self, parser_type: str) -> None:
        """Create ``self.pull_parser``, preferring defusedxml when requested."""
        _parser = None
        if parser_type == "defusedxml":
            try:
                import defusedxml.ElementTree
            except ImportError:
                self.logger.warning("defusedxml not installed, falling back to standard xml")
            else:
                _parser = defusedxml.ElementTree.DefusedXMLParser(target=TreeBuilder())
        # XMLPullParser only accepts a custom parser through its keyword-only,
        # undocumented ``_parser`` argument. ``parser=`` raises TypeError.
        # ``None`` selects the default XMLParser.
        self.pull_parser = ET.XMLPullParser(["start", "end"], _parser=_parser)

    def _clean_chunk(self, chunk: str) -> str:
        """Return *chunk* without control characters that XML 1.0 forbids."""
        return re.sub(r'[\x00-\x08\x0B\x0C\x0E-\x1F\x7F]', '', chunk)

    def _handle_incomplete_tag(self, chunk: str) -> str:
        """Hold back a trailing unterminated tag until more input arrives.

        Any previously held-back fragment is prepended first, so no input is
        lost. If the combined text ends in a ``<`` with no ``>`` after it, that
        tail is stored in ``partial_tag_buffer`` and the text before it is
        returned. Otherwise the combined text is returned whole.
        """
        if self.partial_tag_buffer:
            chunk = self.partial_tag_buffer + chunk
            self.partial_tag_buffer = ""
        last_open = chunk.rfind("<")
        if last_open != -1 and chunk.find(">", last_open) == -1:
            self.partial_tag_buffer = chunk[last_open:]
            return chunk[:last_open]
        return chunk

    def parse(self, chunk: Union[str, BaseMessage]) -> Iterator[AddableDict]:
        """Feed one chunk and yield a nested dict for each leaf element closed.

        Args:
            chunk: Raw text, or a message whose ``content`` is text. Messages
                with non-string content are ignored.

        Yields:
            ``AddableDict`` values, e.g. ``{"root": [{"item": "text"}]}``.
        """
        try:
            if isinstance(chunk, BaseMessage):
                chunk_content = chunk.content
                if not isinstance(chunk_content, str):
                    return
                chunk = chunk_content
            chunk = self._clean_chunk(chunk)

            # Buffer CDATA sections until a chunk containing their terminator.
            if "![CDATA[" in chunk:
                self.in_cdata = True
            if self.in_cdata and "]]>" in chunk:
                self.in_cdata = False
            if self.in_cdata:
                # Flush any held-back tag fragment first so the text stays in order.
                self.buffer += self.partial_tag_buffer + chunk
                self.partial_tag_buffer = ""
                return

            chunk = self._handle_incomplete_tag(chunk)
            if not chunk:
                return
            self.buffer += chunk

            if not self.xml_started:
                match = self.xml_start_re.search(self.buffer)
                if match is None:
                    return
                self.buffer = self.buffer[match.start():]
                self.xml_started = True

            # XMLPullParser.feed() does not raise ParseError. It queues the
            # error, read_events() raises it below, and the outer handler
            # reports it.
            self.pull_parser.feed(self.buffer)
            self.buffer = ""

            for event, elem in self.pull_parser.read_events():
                try:
                    if event == "start":
                        self.current_path.append(elem.tag)
                        self.current_path_has_children = False
                    elif event == "end":
                        self.current_path.pop()
                        if not self.current_path_has_children:
                            yield self._create_nested_element(self.current_path, elem)
                        if self.current_path:
                            self.current_path_has_children = True
                        else:
                            # Root closed. Later text is buffered and ignored
                            # unless another opening tag appears.
                            self.xml_started = False
                except Exception as e:
                    self.logger.error("Error processing XML element: %s", e)
                    if self.error_handler:
                        self.error_handler(e, elem)
        except Exception as e:
            self.logger.error("Error in stream parsing: %s", e)
            if self.error_handler:
                self.error_handler(e, chunk)

    def _create_nested_element(self, path: list[str], elem: ET.Element) -> AddableDict:
        """Wrap ``{elem.tag: elem.text}`` in one single-item list per tag in *path*."""
        try:
            if not path:
                return AddableDict({elem.tag: elem.text or ""})
            return AddableDict({path[0]: [self._create_nested_element(path[1:], elem)]})
        except Exception as e:
            self.logger.error("Error creating nested element: %s", e)
            return AddableDict({path[0] if path else "error": str(e)})

    def close(self) -> None:
        """Flush pending input (including a held-back tag) and close the parser.

        Events produced by this final feed are not read. A ParseError, e.g.
        for a truncated or empty document, is logged at debug level rather
        than raised.
        """
        pending = self.buffer + self.partial_tag_buffer
        self.buffer = ""
        self.partial_tag_buffer = ""
        try:
            if pending:
                self.pull_parser.feed(pending)
            self.pull_parser.close()
        except ET.ParseError as e:
            self.logger.debug("XML parse error while closing stream: %s", e)
class RobustStreamingXMLOutputParser(BaseTransformOutputParser):
    """XML output parser with robust streaming support.

    Streaming input (``_transform`` / ``_atransform``) is delegated to
    ``RobustStreamingParser``. Complete text is handled by :meth:`parse`.
    """

    # LangChain output parsers are pydantic models, which reject assignment to
    # undeclared attributes. Configuration must therefore be declared as fields.
    tags: Optional[list[str]] = None  # stored for API compatibility; not read by this class
    parser: Literal["defusedxml", "xml"] = "defusedxml"
    error_handler: Optional[Callable[..., Any]] = None

    def __init__(
        self,
        tags: Optional[list[str]] = None,
        parser: Literal["defusedxml", "xml"] = "defusedxml",
        error_handler: Optional[Callable[..., Any]] = None,
        **kwargs: Any,
    ) -> None:
        """Create the parser, keeping the original positional signature.

        Args:
            tags: Optional list of expected tag names (stored, not used here).
            parser: ``"defusedxml"`` (falls back to the standard library with
                a warning if not installed) or ``"xml"``.
            error_handler: Optional callback ``error_handler(exc, context)``.
                Its return value is returned by :meth:`parse` on failure.
            **kwargs: Forwarded to the base model.
        """
        super().__init__(tags=tags, parser=parser, error_handler=error_handler, **kwargs)

    @property
    def logger(self) -> logging.Logger:
        """Module logger. A property, since the model cannot hold extra attributes."""
        return logging.getLogger(__name__)

    def _transform(self, input: Iterator[Union[str, BaseMessage]]) -> Iterator[AddableDict]:
        """Parse a synchronous chunk stream, yielding partial results."""
        streaming_parser = RobustStreamingParser(self.parser, self.error_handler)
        try:
            for chunk in input:
                yield from streaming_parser.parse(chunk)
        finally:
            # Runs even if the consumer stops iterating early.
            streaming_parser.close()

    async def _atransform(self, input: AsyncIterator[Union[str, BaseMessage]]) -> AsyncIterator[AddableDict]:
        """Parse an asynchronous chunk stream, yielding partial results."""
        streaming_parser = RobustStreamingParser(self.parser, self.error_handler)
        try:
            async for chunk in input:
                for output in streaming_parser.parse(chunk):
                    yield output
        finally:
            streaming_parser.close()

    def parse(self, text: str) -> dict[str, Union[str, list[Any]]]:
        """Parse complete XML text, optionally wrapped in a ``` fenced block.

        Returns:
            The document converted by :meth:`_root_to_dict`, or the
            ``error_handler`` result when parsing fails and a handler is set.

        Raises:
            OutputParserException: If parsing fails and no handler is set.
        """
        match = re.search(r"```(?:xml)?(.*)```", text, re.DOTALL)
        if match:
            text = match.group(1)
        text = re.sub(r'[\x00-\x08\x0B\x0C\x0E-\x1F\x7F]', '', text)
        # An XML declaration is only legal at the very start of the document,
        # so drop the newline that usually follows the opening fence.
        text = text.strip()
        try:
            root = self._fromstring(text)
            return self._root_to_dict(root)
        except Exception as e:
            self.logger.error(f"Failed to parse XML: {e}")
            if self.error_handler:
                return self.error_handler(e, text)
            raise OutputParserException(f"Failed to parse XML: {e}", llm_output=text) from e

    def _fromstring(self, text: str) -> ET.Element:
        """Parse *text* with defusedxml when configured and installed, else stdlib."""
        if self.parser == "defusedxml":
            try:
                from defusedxml import ElementTree as DefusedET
            except ImportError:
                # Same fallback as the streaming parser.
                self.logger.warning("defusedxml not installed, falling back to standard xml")
            else:
                return DefusedET.fromstring(text)
        return ET.fromstring(text)

    def _root_to_dict(self, root: ET.Element) -> dict[str, Union[str, list[Any]]]:
        """Convert an element into ``{tag: text}`` or ``{tag: [child dicts]}``.

        An element with non-whitespace text maps to that text. Otherwise it
        maps to a list with one dict per child, recursing into children that
        have children of their own.
        """
        try:
            if root.text and bool(re.search(r"\S", root.text)):
                return {root.tag: root.text}
            result: dict = {root.tag: []}
            for child in root:
                if len(child) == 0:
                    result[root.tag].append({child.tag: child.text or ""})
                else:
                    result[root.tag].append(self._root_to_dict(child))
            return result
        except Exception as e:
            self.logger.error(f"Error converting XML to dict: {e}")
            return {root.tag: f"Error: {str(e)}"}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment