Created July 10, 2023 20:06
Save drewgillson/546ef7f308e7d1cf9e4b185f157ffa36 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import math | |
from sklearn.cluster import DBSCAN | |
import numpy as np | |
import statistics | |
from collections import Counter | |
def group_by_visual_row(data, eps=5, min_samples=1, add_line_count=False):
    """Assign each OCR line a shared row coordinate, ``adj_min_y``.

    Lines are clustered on their ``min_y`` value with DBSCAN (Density-Based
    Spatial Clustering of Applications with Noise). Lines whose top edges are
    packed closely together (within ``eps`` of a neighbour) end up in the same
    visual "row". Every line in a row receives the rounded median ``min_y`` of
    that row. Sorting by ``(adj_min_y, min_x)`` then gives a stable
    top-to-bottom, left-to-right order even when a row is slightly skewed.

    Args:
        data: List of dicts, each with at least a numeric ``'min_y'`` key.
            The dicts are updated in place.
        eps: Maximum vertical distance (same units as ``min_y``) for two lines
            to be treated as neighbours by DBSCAN.
        min_samples: DBSCAN ``min_samples``. With the default of 1 every line
            belongs to some cluster. With larger values, lines DBSCAN labels as
            noise keep their own rounded ``min_y`` rather than being merged
            into a single bogus row.
        add_line_count: If True, also set ``'line_count'`` on every item to
            the number of items sharing its ``adj_min_y``. This helps
            post-processing that expects a certain number of items per row.

    Returns:
        The same ``data`` list, with ``'adj_min_y'`` (and optionally
        ``'line_count'``) added to each item.
    """
    if not data:
        # DBSCAN raises on an empty sample array, and there is nothing to group.
        return data

    min_y_values = np.array([item['min_y'] for item in data]).reshape(-1, 1)
    labels = DBSCAN(eps=eps, min_samples=min_samples).fit(min_y_values).labels_

    # Gather each cluster's min_y values in one pass. The previous version
    # rescanned every label for every item, which is O(n^2).
    clusters = {}
    for item, label in zip(data, labels):
        if label != -1:
            clusters.setdefault(label, []).append(item['min_y'])
    row_y = {
        label: round(statistics.median(points))
        for label, points in clusters.items()
    }

    for item, label in zip(data, labels):
        if label == -1:
            # DBSCAN marks noise with -1. Noise lines do not form a row
            # together, so each one keeps its own position.
            item['adj_min_y'] = round(item['min_y'])
        else:
            item['adj_min_y'] = row_y[label]

    if add_line_count:
        min_y_counts = Counter(item['adj_min_y'] for item in data)
        for item in data:
            item['line_count'] = min_y_counts[item['adj_min_y']]

    return data
def sort_lines(page, document_text=None):
    """Return a page's OCR lines as dicts, sorted into reading order.

    Each ``line`` on ``page`` becomes a dict containing:
    - its stripped text;
    - its bounding-box extents (``min_x``, ``min_y``, ``max_x``, ``max_y``);
    - its layout confidence;
    - the start/end offsets of its first text-anchor segment.

    Lines are grouped into visual rows by ``group_by_visual_row``. They are
    then sorted top to bottom (``adj_min_y``) and left to right (``min_x``),
    so callers can step through them in a predictable reading order.

    Args:
        page: A page from a Google Cloud Document AI response
            (e.g. ``document.pages[i]``).
        document_text: Full text of the document that the lines' text anchors
            index into (e.g. ``document.text``). If omitted, this falls back to
            the module-level ``document_response.text`` that earlier versions
            of this function read, so existing ``sort_lines(page)`` calls keep
            working.

    Returns:
        A list of line dicts sorted by ``(adj_min_y, min_x)``.
    """
    if document_text is None:
        # Backward-compatible fallback to the global this function used to
        # rely on. It is not defined in this snippet.
        document_text = document_response.text

    lines = []
    for line in page.lines:
        layout = line.layout
        vertices = layout.bounding_poly.vertices
        x_values = [vertex.x for vertex in vertices]
        y_values = [vertex.y for vertex in vertices]
        # Only the first text segment's offsets are recorded.
        # NOTE(review): confirm that lines never span multiple segments.
        first_segment = layout.text_anchor.text_segments[0]
        lines.append({
            # layout_to_text is defined outside this snippet. It presumably
            # slices the layout's anchored text out of document_text.
            'text': layout_to_text(layout, document_text).strip(),
            'min_x': min(x_values),
            'min_y': min(y_values),
            'max_x': max(x_values),
            'max_y': max(y_values),
            'confidence': layout.confidence,
            'start_index': first_segment.start_index,
            'end_index': first_segment.end_index,
        })

    rows = group_by_visual_row(lines)
    return sorted(rows, key=lambda item: (item['adj_min_y'], item['min_x']))
# ``document`` is assumed to be the Document returned by a Google Cloud
# Document AI processor. It is not defined in this snippet.
# Keep every page's reading-ordered lines. Previously each iteration overwrote
# ``lines``, so only the last page survived. ``lines`` is still bound to the
# most recent page's result so any code that follows the loop keeps working.
all_page_lines = []
for page in document.pages:
    lines = sort_lines(page)
    all_page_lines.append(lines)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.