Last active
February 1, 2023 11:09
-
-
Save meitinger/3fb863005df2750500c168497fb171ca to your computer and use it in GitHub Desktop.
Helper utility to perform OCR and translation on images in the cloud.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/python3 | |
# | |
# Copyright (C) 2023, Manuel Meitinger | |
# | |
# This program is free software: you can redistribute it and/or modify | |
# it under the terms of the GNU General Public License as published by | |
# the Free Software Foundation, either version 2 of the License, or | |
# (at your option) any later version. | |
# | |
# This program is distributed in the hope that it will be useful, | |
# but WITHOUT ANY WARRANTY; without even the implied warranty of | |
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the | |
# GNU General Public License for more details. | |
# | |
# You should have received a copy of the GNU General Public License | |
# along with this program. If not, see <http://www.gnu.org/licenses/>. | |
# Prerequisites: | |
# 1. AWS account | |
# 2. Azure account | |
# 3. accepted Microsoft's Responsible AI terms and conditions | |
from __future__ import annotations | |
import abc | |
import argparse | |
import collections | |
import datetime | |
import enum | |
import functools | |
import hashlib | |
import inspect | |
import io | |
import ipaddress | |
import json | |
import math | |
import os | |
import re | |
import socket | |
import string | |
import subprocess | |
import sys | |
import threading | |
import time | |
import types | |
import typing | |
import urllib.parse | |
import urllib.request | |
if typing.TYPE_CHECKING: | |
import mypy_boto3_rekognition.type_defs | |
import mypy_boto3_s3.service_resource | |
print("Python version:", sys.version)
try:
    # fingerprint the script itself so log output identifies the exact revision;
    # hash the raw file bytes — the previous "\n".join(readlines()) doubled every
    # newline (readlines keeps them), producing a digest matching no real content
    with open(__file__, "rb") as _file:
        print("Script version:", hashlib.sha1(_file.read()).hexdigest())
except OSError:
    # best effort only: a missing/unreadable script file is not fatal
    pass
class Settings(typing.NamedTuple):  # NOTE: register_settings() reads this class's source via inspect.getsourcelines() and indexes it line-by-line — keep exactly one field line followed by one docstring line per setting
    language: str = "en"
    """The target language's short code."""
    threads: int = 10
    """The number of worker threads to spawn."""
    statistics: bool = False
    """If set, statistics about deployment and execution will be written to file."""
    preserve_order: bool = False
    """If set, the order of the regions will not be optimized by geolocation."""
    presign_expiration: int = 300
    """Time in seconds after which a pre-signed URI should expire."""
    connect_timeout: float = 1.5
    """Time in seconds the script will wait to connect to a cloud service."""
    ocr_timeout: float = 4.0
    """Time in seconds the script will wait for the OCR to complete."""
    translation_timeout: float = 2.0
    """Time in seconds the script will wait for the translation to complete."""
    refresh_interval: float = 0.75
    """Time in seconds after which the progress report should be refreshed."""
    retry_interval: float = 0.5
    """Time in seconds after which all regions or resources should be retried."""
    max_attempts: int = 5
    """Number of attempts after which to give up on an operation."""
    azure_resource_group: str = "daecc-project"
    """The name of the Azure resource group where to create cognitive accounts in."""
    azure_cognitive_account_name: str = "{resource_group}-{kind!l}-{location}"  # "!l" is the custom lower-case conversion implemented by Formatter
    """The template of cognitive account names."""
    azure_use_free_tier: bool = False
    """If set, use free F0 tier instead of pay-as-you-go S1."""
def dot() -> None:
    """Emit a single progress dot without a newline and flush immediately."""
    sys.stdout.write(".")
    sys.stdout.flush()
# region dependencies
# Import every third-party SDK up front; if any one is missing, install the whole
# set with pip and restart the script so the fresh packages become importable.
print("Loading libraries", end="", flush=True)
try:
    # each dot() reports progress because these SDK imports are slow
    import azure.core.exceptions
    dot()
    import azure.identity
    dot()
    import azure.mgmt.cognitiveservices
    dot()
    import azure.mgmt.cognitiveservices.models
    dot()
    import azure.mgmt.resource
    dot()
    import azure.mgmt.resource.resources.models
    dot()
    import azure.mgmt.storage
    dot()
    import azure.mgmt.storage.models
    dot()
    import azure.mgmt.subscription
    dot()
    import azure.mgmt.subscription.models
    dot()
    import azure.storage.blob
    dot()
    import boto3
    dot()
    import botocore
    dot()
    import botocore.client
    dot()
    import botocore.config
    dot()
    import botocore.exceptions
    dot()
    import msgraph.core
    dot()
    import pandas
    dot()
    import requests
    dot()
except ModuleNotFoundError as e:
    # any missing package triggers a one-shot pip install of all of them
    print(f": {e.name} missing, installing...")
    subprocess.check_call([
        sys.executable,
        "-m",
        "pip",
        "install",
        "azure-core",
        "azure-identity",
        "azure-mgmt-cognitiveservices",
        "azure-mgmt-resource",
        "azure-mgmt-storage",
        "azure-mgmt-subscription",
        "azure-storage-blob",
        "boto3",
        "botocore",
        "msgraph-core",
        "pandas",
        "requests",
    ])
    # replace the current process with a fresh interpreter running the same script
    os.execl(sys.executable, sys.executable, *sys.argv)
else:
    print(": ready")
# endregion
############################################################################### | |
# HELPERS # | |
############################################################################### | |
T = typing.TypeVar("T")  # generic value type used by cached() and cast()
U = typing.TypeVar("U")  # JSON-representation type used by the cached() overloads
def full_path(file_name: str) -> str:
    """Return *file_name* resolved relative to this script's own directory."""
    script_dir = os.path.dirname(__file__)
    return os.path.join(script_dir, file_name)
@typing.overload
def cached(
    key: str,
    type: typing.Type[T],
    create: typing.Callable[[], T],
) -> T:
    ...
@typing.overload
def cached(
    key: str,
    type: typing.Type[U],
    create: typing.Callable[[], T],
    *,
    from_json: typing.Callable[[U], T],
    to_json: typing.Callable[[T], U],
) -> T:
    ...
def cached(key: str, type: typing.Type, create: typing.Callable[[], T], **kwargs) -> T:
    """Return the value stored under *key* in the global cache, creating it on miss.

    On a miss (missing key, wrong JSON type, or a from_json failure) the value is
    produced by *create*, converted with *to_json* and persisted to the .cache
    file next to this script. *from_json*/*to_json* default to the identity.
    """
    from_json = kwargs.get("from_json", lambda v: v)
    to_json = kwargs.get("to_json", lambda v: v)
    try:
        raw_value = cache[key]
        if not isinstance(raw_value, type):
            raise TypeError()
        value = from_json(raw_value)
    except Exception:
        # any failure reading/validating the cached entry falls back to creating anew
        value = create()
        cache[key] = to_json(value)
        # persist immediately; reuse full_path() instead of re-deriving the directory
        with open(full_path(".cache"), "w") as file:
            json.dump(cache, file)
    return value
def locate_ip(ip: str) -> tuple[float, float]:
    """Return the (latitude, longitude) of *ip* via ipapi.co, cached across runs."""
    def fetch() -> tuple[float, float]:
        # query the geolocation service and validate the response shape strictly
        response = requests.get(f"https://ipapi.co/{ip}/json/").json()
        if not isinstance(response, dict):
            raise TypeError(f"invalid response type '{type(response).__name__}'")
        if response.get("error", None):
            raise ValueError(f"cannot locate IP '{ip}': {response.get('reason', None)}")
        latitude = response.get("latitude", None)
        longitude = response.get("longitude", None)
        if latitude is None and longitude is None:
            raise ValueError(f"location of IP '{ip}' is unknown")
        if not isinstance(latitude, (int, float)):
            raise TypeError(f"invalid latitude type '{type(latitude).__name__}'")
        if not isinstance(longitude, (int, float)):
            raise TypeError(f"invalid longitude type '{type(longitude).__name__}'")
        return (float(latitude), float(longitude))

    def from_list(value: list) -> tuple[float, float]:
        # cached entries are two-element JSON arrays of floats
        latitude, longitude = value
        if not (isinstance(latitude, float) and isinstance(longitude, float)):
            raise TypeError()
        return (latitude, longitude)

    return cached(
        f"location:{ip}",
        list,
        fetch,
        from_json=from_list,
        to_json=list,
    )
def resolve_host(hostname: str) -> str:
    """Return *hostname* as an IP string, resolving (and caching) DNS when needed."""
    def lookup() -> str:
        try:
            # prefer an IPv4 address
            infos = socket.getaddrinfo(hostname, 0, family=socket.AF_INET)
        except socket.gaierror as ipv4_error:
            # fall back to IPv6, but surface the IPv4 error if both fail
            try:
                infos = socket.getaddrinfo(hostname, 0, family=socket.AF_INET6)
            except socket.gaierror:
                raise ipv4_error
        return infos[0][4][0]

    try:
        # IP literals short-circuit: no DNS lookup and no cache entry needed
        return str(ipaddress.ip_address(hostname))
    except ValueError:
        return cached(f"ip:{hostname}", str, lookup)
def locate_host(hostname: str) -> tuple[float, float]:
    """Geolocate *hostname*, translating a failed DNS lookup into ValueError."""
    try:
        return locate_ip(resolve_host(hostname))
    except socket.gaierror as error:
        if error.errno not in (socket.EAI_NONAME, socket.EAI_AGAIN):
            raise
        raise ValueError(f"host '{hostname}' doesn't exist")
def locate_uri(uri: str) -> tuple[float, float]:
    """Geolocate the host component of *uri*."""
    hostname = urllib.parse.urlsplit(uri).hostname
    if not hostname:
        raise ValueError(f"URI '{uri}' does not contain a host name")
    return locate_host(hostname)
def locate_me() -> tuple[float, float]:
    """Geolocate this machine via its (cached) public IP address."""
    def fetch_ip() -> str:
        # ip_address() validates that ipify returned a proper address
        return str(ipaddress.ip_address(requests.get("https://api.ipify.org").text))

    public_ip = cached(
        "me",
        str,
        fetch_ip,
        from_json=lambda v: str(ipaddress.ip_address(v)),
        to_json=lambda v: v
    )
    return locate_ip(public_ip)
def exists_host(hostname: str) -> bool:
    """Return whether *hostname* resolves via DNS, cached across runs."""
    def check() -> bool:
        try:
            return bool(resolve_host(hostname))
        except socket.gaierror as error:
            # "no such name" / "try again" simply mean the host does not exist
            if error.errno in (socket.EAI_NONAME, socket.EAI_AGAIN):
                return False
            raise
    return cached(f"exists:{hostname}", bool, check)
def distance(location1: tuple[float, float], location2: tuple[float, float]) -> float:
    """Return the central angle (in radians) between two (latitude, longitude) points.

    The value is proportional to the great-circle distance, which suffices for
    ranking locations by proximity.
    """
    deg_to_rad = math.pi / 180.0
    lat1, long1 = location1
    lat2, long2 = location2
    phi1 = (90.0 - lat1) * deg_to_rad
    phi2 = (90.0 - lat2) * deg_to_rad
    theta1 = long1 * deg_to_rad
    theta2 = long2 * deg_to_rad
    cosine = (
        math.sin(phi1) * math.sin(phi2) * math.cos(theta1 - theta2) +
        math.cos(phi1) * math.cos(phi2)
    )
    # floating-point error can push the cosine just outside [-1, 1] (e.g. for
    # identical points), which would make math.acos raise ValueError — clamp it
    return math.acos(max(-1.0, min(1.0, cosine)))
def cast(type: typing.Type[T], value: typing.Any) -> T:
    """Assert at runtime that *value* is an instance of *type* and return it."""
    if isinstance(value, type):
        return value
    raise TypeError(f"type '{type.__name__}' expected")
def ensure(value: T | None) -> T:
    """Return *value*, raising ValueError when it is None."""
    if value is not None:
        return value
    raise ValueError("value must not be None")
def register_settings() -> None:
    """Add one command line option per Settings field to the global parser.

    Help texts are scraped from the Settings source: each field occupies exactly
    two lines (field line, then docstring line), so the docstring of field *index*
    sits at source line (1 + index) * 2. Settings must keep that layout.
    """
    try:
        settings_lines, _ = inspect.getsourcelines(Settings)
    except Exception:
        # source unavailable (e.g. frozen/compiled) — register options without help
        settings_lines = None
    for index, setting in enumerate(Settings._fields):
        if Settings._field_defaults[setting] is False:
            # fields defaulting to False become simple store-true flags
            parser.add_argument(
                f"--{setting.replace('_', '-')}",
                action='store_true',
                help=settings_lines[(1 + index) * 2].strip().strip('"') if settings_lines else None,
            )
        else:
            parser.add_argument(
                f"--{setting.replace('_', '-')}",
                required=False,
                # NOTE(review): eval() turns the annotation string (e.g. "int") into
                # a callable type — safe only because Settings' annotations are
                # trusted literals in this file; never feed it external input
                type=eval(Settings.__annotations__[setting].__forward_arg__),
                metavar=Settings.__annotations__[setting].__forward_arg__.upper(),
                default=Settings._field_defaults[setting],
                help=settings_lines[(1 + index) * 2].strip().strip('".') + ", defaults to '%(default)s'." if settings_lines else None,
            )
class Formatter(string.Formatter):
    """string.Formatter with extra case conversions: !u, !l, !c and !t."""
    def convert_field(self, value: typing.Any, conversion: str | None) -> typing.Any:
        """Apply a custom case conversion or defer to the standard !s/!r/!a.

        The annotation is ``str | None`` (fixed): the base class passes None when
        a replacement field carries no "!x" conversion suffix.
        """
        if conversion == "u":
            return str(value).upper()
        if conversion == "l":
            return str(value).lower()
        if conversion == "c":
            return str(value).capitalize()
        if conversion == "t":
            return str(value).title()
        return super().convert_field(value, conversion)
############################################################################### | |
# ABSTRACT BASE CLASSES # | |
############################################################################### | |
class Store(abc.ABC):
    """Abstract source/sink of documents tagged with a geographic location."""
    def __init__(self, location: tuple[float, float]) -> None:
        self.location = location

    @staticmethod
    def from_uri(uri: str) -> Store:
        """Create the matching Store implementation for *uri* (s3/https/file)."""
        parts = urllib.parse.urlsplit(uri)
        if parts.fragment:
            raise ValueError("URI must not contain fragments")
        scheme = parts.scheme
        if scheme == "s3":
            if parts.port is not None:
                raise ValueError("S3 URI must not contain a port")
            if parts.query:
                raise ValueError("S3 URI must not contain a query")
            return S3BucketStore(s3_resource.Bucket(parts.netloc), filter=parts.path)
        if scheme == "https":
            s3_suffix = ".s3.amazonaws.com"
            host = parts.netloc.lower()
            if host.endswith(s3_suffix):
                # virtual-hosted-style S3 URL: strip the suffix to get the bucket
                if parts.port is not None and parts.port != 443:
                    raise ValueError("bucket URI must have port 443")
                if parts.query:
                    raise ValueError("bucket URI must not contain a query")
                return S3BucketStore(s3_resource.Bucket(parts.netloc[: -len(s3_suffix)]), filter=parts.path)
            if host.endswith(".blob.core.windows.net"):
                return AzureContainerStore(
                    azure.storage.blob.ContainerClient.from_container_url(uri, credential=azure_credential)
                )
            # any other https URL is treated as a single web document
            return WebDocumentStore(uri)
        if scheme == "file":
            path = urllib.request.url2pathname(parts.path)
            if os.path.isdir(path):
                return LocalDirectoryStore(path)
            if os.path.isfile(path):
                return LocalDirectoryStore(os.path.dirname(path), file_name=os.path.basename(path))
            raise ValueError(f"path '{path}' not found")
        raise ValueError(f"URI scheme {scheme} is not supported")

    @abc.abstractmethod
    def list_files(self) -> typing.Iterable[Document]:
        raise NotImplementedError()

    @abc.abstractmethod
    def put_file(self, file_name: str, content: bytes) -> None:
        raise NotImplementedError()
class CloudStore(Store):
    """Store hosted in a cloud provider region, possibly publicly readable."""
    def __init__(self, region: Region, is_public: bool) -> None:
        super().__init__(region.location)
        self.is_public = is_public
        self.region = region
StoreT = typing.TypeVar("StoreT", bound=Store)


class Document(abc.ABC, typing.Generic[StoreT]):
    """A named document within a Store that exposes a URL and its raw content."""
    def __init__(self, store: StoreT, name: str) -> None:
        self.store = store
        self.name = name

    # abc.abstractproperty is deprecated since Python 3.3 — the documented
    # replacement is stacking @property on top of @abc.abstractmethod
    @property
    @abc.abstractmethod
    def url(self) -> str:
        raise NotImplementedError()

    @property
    @abc.abstractmethod
    def content(self) -> bytes:
        raise NotImplementedError()
class PresignedUrlDocument(Document[StoreT]):
    """Document whose URL is a pre-signed link that must be refreshed periodically."""
    def __init__(self, store: StoreT, name: str) -> None:
        super().__init__(store, name)
        self.presigned_url: str | None = None
        self.presign_time = 0.0

    @abc.abstractmethod
    def presign_url(self) -> str:
        raise NotImplementedError()

    @property
    def url(self) -> str:
        """Return a pre-signed URL, re-signing after 80% of the expiration window."""
        now = time.time()
        stale = (
            self.presigned_url is None
            or now - self.presign_time > settings.presign_expiration * 0.8
        )
        if stale:
            self.presigned_url = self.presign_url()
            self.presign_time = now
        return self.presigned_url
class ServiceInputError(Exception):
    """Raised when a cloud service rejects the caller-provided input."""
class CloudService(abc.ABC):
    """Base class for a per-region cloud service wrapper."""
    def __init__(self, region: Region) -> None:
        self.region = region

    # abc.abstractproperty is deprecated since Python 3.3 — use @property
    # stacked on @abc.abstractmethod instead
    @property
    @abc.abstractmethod
    def timeout(self) -> float:
        raise NotImplementedError()

    @property
    @abc.abstractmethod
    def timed(self) -> TimedService:
        raise NotImplementedError()
class OcrImageError(ServiceInputError):
    """Raised when an OCR service cannot process the supplied image."""
class OcrService(CloudService):
    """Cloud service that extracts text from an image document."""
    @abc.abstractmethod
    def run(self, document: Document) -> str:
        """Return the text recognized in *document*."""
        raise NotImplementedError()

    @property
    def timeout(self) -> float:
        # OCR calls use the dedicated OCR timeout setting
        return settings.ocr_timeout

    @property
    def timed(self) -> TimedService:
        # statistics are recorded against the region's OCR timer
        return self.region.ocr
class TranslationTextError(ServiceInputError):
    """Raised when a translation service cannot process the supplied text."""
class TranslationService(CloudService):
    """Cloud service that translates text into a target language."""
    @abc.abstractmethod
    def run(self, text: str, language: str) -> str:
        """Return *text* translated into *language*."""
        # consistent with the sibling services: raise an instance, not the class
        raise NotImplementedError()

    @property
    def timeout(self) -> float:
        # translation calls use the dedicated translation timeout setting
        return settings.translation_timeout

    @property
    def timed(self) -> TimedService:
        # statistics are recorded against the region's translation timer
        return self.region.translation
############################################################################### | |
# THROTTLING # | |
############################################################################### | |
class QuotaError(Exception):
    """Raised when an operation would exceed a configured quota."""
    def __init__(self, maximum: int, exceeded_by: int) -> None:
        message = f"maximum of {maximum} exceeded by {exceeded_by}"
        super().__init__(message)
        self.maximum = maximum
        self.exceeded_by = exceeded_by
class ConcurrencyQuota:
    """Context manager limiting how many operations may run concurrently."""
    def __init__(self, maximum: int) -> None:
        if maximum < 0:
            raise ValueError("concurrency quota must be non-negative")
        self.lock = threading.Lock()
        self.count = 0
        self.maximum = maximum

    def __enter__(self) -> None:
        """Claim a slot, raising QuotaError when all slots are in use."""
        self.lock.acquire()
        try:
            if self.count == self.maximum:
                raise QuotaError(self.maximum, exceeded_by=1)
            self.count += 1
        finally:
            self.lock.release()

    def __exit__(
        self,
        type: typing.Type[BaseException] | None,
        value: BaseException | None,
        traceback: types.TracebackType | None,
    ) -> None:
        """Release the slot claimed by __enter__."""
        self.lock.acquire()
        try:
            self.count -= 1
        finally:
            self.lock.release()
class SlidingWindowQuota:
    """Rate limiter allowing at most *maximum* uses within a sliding *interval*."""
    def __init__(self, maximum: int, interval: datetime.timedelta) -> None:
        if maximum < 0:
            raise ValueError("sliding window quota must be non-negative")
        if interval <= datetime.timedelta():
            raise ValueError("sliding window interval must be larger than zero")
        self.lock = threading.Lock()
        self.history: collections.deque[tuple[datetime.datetime, int]] = collections.deque()
        self.count = 0
        self.maximum = maximum
        self.interval = interval

    def use(self, count: int = 1) -> None:
        """Consume *count* units, raising QuotaError when the window is exhausted."""
        with self.lock:
            now = datetime.datetime.utcnow()
            # expire history entries that have slid out of the window
            while self.history:
                timestamp, used = self.history[0]
                if now - timestamp <= self.interval:
                    break
                self.history.popleft()
                self.count -= used
            overflow = self.count + count - self.maximum
            if overflow > 0:
                raise QuotaError(self.maximum, exceeded_by=overflow)
            self.history.append((now, count))
            self.count += count
############################################################################### | |
# SIMPLE FILE AND WEB DOCUMENTS # | |
############################################################################### | |
class LocalDirectoryStore(Store):
    """Store backed by a directory (or a single file) on the local disk."""
    def __init__(
        self,
        path: str,
        *,
        file_name: str | None = None,
        location: tuple[float, float] | None = None,
    ) -> None:
        # local stores sit wherever this machine is located unless told otherwise
        super().__init__(location or locate_me())
        self.path = path
        self.file_name = file_name

    def list_files(self) -> typing.Iterable[Document]:
        """Yield the single configured file, or every regular file in the directory."""
        if self.file_name is not None:
            yield LocalFileDocument(self, os.path.join(self.path, self.file_name))
            return
        for entry in os.listdir(self.path):
            entry_path = os.path.join(self.path, entry)
            if os.path.isfile(entry_path):
                yield LocalFileDocument(self, entry_path)

    def put_file(self, file_name: str, content: bytes) -> None:
        """Write *content* below the directory, creating parent folders as needed."""
        target = os.path.join(self.path, file_name)
        os.makedirs(os.path.dirname(target), exist_ok=True)
        with open(target, "wb") as file:
            file.write(content)

    def __repr__(self) -> str:
        return f"<LocalDirectoryStore path='{self.path}' file_name='{self.file_name}' location={self.location}>"
class LocalFileDocument(Document[LocalDirectoryStore]):
    """Document read from a local file; it has no URL a cloud service could fetch."""
    def __init__(self, store: LocalDirectoryStore, path: str) -> None:
        super().__init__(store, os.path.basename(path))
        self.path = path

    @property
    def url(self) -> str:
        # local files cannot be handed to cloud services by reference
        raise NotImplementedError()

    @functools.cached_property
    def content(self) -> bytes:
        # read once and cache in memory
        with open(self.path, "rb") as file:
            return file.read()
class WebDocumentStore(Store):
    """Store wrapping a single URI that can be fetched and PUT back to."""
    def __init__(self, uri: str, *, location: tuple[float, float] | None = None) -> None:
        super().__init__(location or locate_uri(uri))
        self.uri = uri

    def list_files(self) -> typing.Iterable[Document]:
        yield WebDocument(self, self.uri)

    def put_file(self, file_name: str, content: bytes) -> None:
        """Upload via HTTP PUT; only the store's own URI (or "") is a valid name."""
        if file_name != "" and file_name != self.uri:
            raise ValueError(f"file name '{file_name}' not supported")
        requests.put(self.uri, content).raise_for_status()

    def __repr__(self) -> str:
        return f"<WebDocumentStore uri='{self.uri}' location={self.location}>"
class WebDocument(Document[WebDocumentStore]):
    """Document fetched over HTTP(S); its URL is the URI itself."""
    def __init__(self, store: WebDocumentStore, uri: str) -> None:
        super().__init__(store, uri)
        self.uri = uri

    @property
    def url(self) -> str:
        return self.uri

    @functools.cached_property
    def content(self) -> bytes:
        # fetched once and cached in memory
        return requests.get(self.uri).content
############################################################################### | |
# AMAZON WEB SERVICES # | |
############################################################################### | |
def register_aws_regions() -> None:
    """Query EC2 for all AWS regions and register each with its capabilities."""
    for region_desc in aws_session.client("ec2", "us-east-1").describe_regions()["Regions"]:
        try:
            region = Region(
                provider=Provider.AWS,
                name=region_desc["RegionName"],
                location=locate_host(region_desc["Endpoint"]),
                is_available=region_desc["OptInStatus"] in ("opted-in", "opt-in-not-required"),
                # service availability is probed via DNS for each region endpoint
                supports_ocr=exists_host(f"rekognition.{region_desc['RegionName']}.amazonaws.com"),
                supports_translation=exists_host(f"translate.{region_desc['RegionName']}.amazonaws.com"),
            )
        except KeyError as e:
            # chain the KeyError so the missing response field stays visible
            raise ValueError("incomplete AWS region description") from e
        Region.register(region)
class S3BucketStore(CloudStore):
    """Store backed by an S3 bucket, optionally restricted to a key prefix."""
    @staticmethod
    def is_public_bucket(bucket_name: str) -> bool:
        """Best-effort check whether the bucket is publicly readable via ACL or policy."""
        def has_public_acl() -> bool:
            # check if AllUsers have at least READ access
            grants = s3_client.get_bucket_acl(Bucket=bucket_name)["Grants"]
            return any(
                True
                for grant in grants
                if grant.get("Grantee", {}).get("URI", None)
                == "http://acs.amazonaws.com/groups/global/AllUsers"
                and grant.get("Permission", None) in ("FULL_CONTROL", "READ")
            )

        def has_public_policy() -> bool:
            # check if a public policy exists
            try:
                return s3_client.get_bucket_policy_status(Bucket=bucket_name)["PolicyStatus"]["IsPublic"]
            except botocore.exceptions.ClientError as e:
                if e.response.get("Error", {}).get("Code", None) != "NoSuchBucketPolicy":
                    raise
                return False

        try:
            config = s3_client.get_public_access_block(Bucket=bucket_name)["PublicAccessBlockConfiguration"]
        except botocore.exceptions.ClientError as e:
            if e.response.get("Error", {}).get("Code", None) != "NoSuchPublicAccessBlockConfiguration":
                raise
            # no public-access-block configured means nothing is blocked
            config = {}
        # FIX: the guard flags were swapped — per the S3 "Block Public Access"
        # semantics, BlockPublicAcls/IgnorePublicAcls govern ACL-based access
        # and BlockPublicPolicy/RestrictPublicBuckets govern policy-based access
        return (
            (
                not config.get("BlockPublicAcls", False)
                and not config.get("IgnorePublicAcls", False)
                and has_public_acl()
            )
            or (
                not config.get("BlockPublicPolicy", False)
                and not config.get("RestrictPublicBuckets", False)
                and has_public_policy()
            )
        )

    def __init__(self, bucket: mypy_boto3_s3.service_resource.Bucket, filter: str) -> None:
        super().__init__(
            # return None means "us-east-1" (https://github.com/aws/aws-cli/issues/3864)
            region=Region.lookup(
                Provider.AWS,
                s3_client.get_bucket_location(Bucket=bucket.name)["LocationConstraint"] or "us-east-1",
            ),
            is_public=S3BucketStore.is_public_bucket(bucket.name),
        )
        self.bucket = bucket
        # normalize the key prefix to "path/" form ("" selects the whole bucket)
        self.filter = filter.strip("/")
        if self.filter:
            self.filter += "/"

    def list_files(self) -> typing.Iterable[Document]:
        """Yield every object under the prefix, skipping directory placeholder keys."""
        for obj in self.bucket.objects.filter(Prefix=self.filter):
            if obj.key[-1] != "/":
                yield S3ObjectDocument(self, obj)

    def put_file(self, file_name: str, content: bytes) -> None:
        """Upload *content* under the store's key prefix."""
        self.bucket.put_object(
            Key=f"{self.filter}{file_name}",
            Body=content,
        )

    def __repr__(self) -> str:
        return f"<S3BucketStore bucket='{self.bucket.name}' filter='{self.filter}' region='{self.region.name}'>"
class S3ObjectDocument(PresignedUrlDocument[S3BucketStore]):
    """Document stored as an S3 object; hands out (pre-signed) object URLs."""
    def __init__(self, store: S3BucketStore, obj: mypy_boto3_s3.service_resource.ObjectSummary) -> None:
        # the document name is the object key relative to the store's prefix
        super().__init__(store, obj.key[len(store.filter):])
        self.obj = obj
    def presign_url(self) -> str:
        params = {
            "Bucket": self.obj.bucket_name,
            "Key": self.obj.key,
        }
        if self.store.is_public:
            # public objects need no signature — the unsigned client yields a plain URL
            return s3_client_unsigned.generate_presigned_url(
                ClientMethod="get_object",
                Params=params,
                ExpiresIn=0,
            )
        else:
            return s3_client.generate_presigned_url(
                ClientMethod="get_object",
                Params=params,
                ExpiresIn=settings.presign_expiration,
            )
    @functools.cached_property
    def content(self) -> bytes:
        # downloaded once and cached in memory
        buffer = io.BytesIO()
        s3_client.download_fileobj(
            Bucket=self.obj.bucket_name,
            Key=self.obj.key,
            Fileobj=buffer,
        )
        return buffer.getvalue()
class AWSService(CloudService):
    """CloudService backed by a boto3 client; keeps the client's metadata."""
    def __init__(self, service: botocore.client.BaseClient) -> None:
        self.meta = service.meta
        super().__init__(Region.lookup(Provider.AWS, service.meta.region_name))
class AWSRekognitionService(AWSService, OcrService):
    """OCR via AWS Rekognition's DetectText API in a single region."""
    def __init__(self, region_name: str) -> None:
        self.service = aws_session.client(
            service_name="rekognition",
            region_name=region_name,
            config=botocore.config.Config(
                # retries and timeouts are managed by this script, not botocore
                retries={"total_max_attempts": 1},
                connect_timeout=settings.connect_timeout,
                read_timeout=settings.ocr_timeout,
            )
        )
        super().__init__(self.service)
        # NOTE(review): assumes these regions have the higher DetectText rate
        # quota (50/s) while all others get 5/s — confirm against current quotas
        is_fast = region_name in ("us-east-1", "us-west-2", "eu-west-1")
        self.quota = SlidingWindowQuota(50 if is_fast else 5, datetime.timedelta(seconds=1))
    def run(self, document: Document) -> str:
        """OCR *document* and return the detected text lines joined by newlines."""
        # we either use a bucket reference for S3 objects or upload bytes
        image: mypy_boto3_rekognition.type_defs.ImageTypeDef
        if (
            isinstance(document, S3ObjectDocument)
            and document.store.region is self.region
        ):
            # same-region S3 object: pass by reference instead of uploading bytes
            image = {
                "S3Object": {
                    "Bucket": document.obj.bucket_name,
                    "Name": document.obj.key,
                }
            }
        else:
            try:
                image = {"Bytes": document.content}
            except Exception as e:
                raise ServiceInputError(f"cannot download image ({e})")
        # enforce the per-second rate limit before issuing the request
        self.quota.use()
        try:
            with self.timed:
                result = self.service.detect_text(Image=image)
        except self.service.exceptions.ImageTooLargeException:
            raise OcrImageError("image is too large")
        except self.service.exceptions.InvalidImageFormatException:
            raise OcrImageError("image format is invalid")
        # keep only whole LINE detections (WORD entries duplicate their content)
        return "\n".join(
            line.get("DetectedText", "")
            for line in result["TextDetections"]
            if line.get("Type", None) == "LINE"
        )
class AWSTranslateService(AWSService, TranslationService):
    """Translation via the AWS Translate API in a single region."""
    def __init__(self, region_name: str) -> None:
        self.service = aws_session.client(
            service_name="translate",
            region_name=region_name,
            config=botocore.config.Config(
                # retries and timeouts are managed by this script, not botocore
                retries={"total_max_attempts": 1},
                connect_timeout=settings.connect_timeout,
                read_timeout=settings.translation_timeout,
            )
        )
        super().__init__(self.service)
    def run(self, text: str, language: str) -> str:
        """Translate *text* into *language*, auto-detecting the source language."""
        try:
            with self.timed:
                result = self.service.translate_text(
                    Text=text,
                    SourceLanguageCode="auto",
                    TargetLanguageCode=language,
                )
        # map the SDK's input-related exceptions onto the script's error hierarchy
        except self.service.exceptions.DetectedLanguageLowConfidenceException:
            raise TranslationTextError("failed to detect language")
        except self.service.exceptions.TextSizeLimitExceededException:
            raise TranslationTextError("text is too large")
        except self.service.exceptions.UnsupportedLanguagePairException:
            raise TranslationTextError(f"cannot translate to {language} from detected language")
        return result["TranslatedText"]
############################################################################### | |
# AZURE # | |
############################################################################### | |
def get_azure_subscription() -> azure.mgmt.subscription.models.Subscription:
    """Pick the Azure subscription: env var override, the single one, or ask."""
    subscription_id = os.environ.get("AZURE_SUBSCRIPTION_ID", None)
    if subscription_id is not None:
        return azure_subscription_client.subscriptions.get(subscription_id)
    if not azure_subscriptions:
        raise Exception("no Azure subscription found")
    if len(azure_subscriptions) == 1:
        return azure_subscriptions[0]
    # multiple subscriptions: present a numbered menu and loop until valid input
    print()
    for number, subscription in enumerate(azure_subscriptions, start=1):
        print(f"{number}. {subscription.subscription_id}: {subscription.display_name} [{subscription.state}]")
    while True:
        try:
            choice = int(input("> "))
        except ValueError:
            continue
        if 1 <= choice <= len(azure_subscriptions):
            return azure_subscriptions[choice - 1]
def register_azure_regions() -> None:
    """Register every Azure location of the subscription as a Region."""
    for location in azure_subscription_client.subscriptions.list_locations(azure_subscription_id):
        # both capabilities are probed against the same cognitive endpoint
        Region.register(Region(
            provider=Provider.Azure,
            name=cast(str, location.name),
            location=(float(cast(str, location.latitude)), float(cast(str, location.longitude))),
            is_available=True,
            supports_ocr=exists_host(f"{location.name}.api.cognitive.microsoft.com"),
            supports_translation=exists_host(f"{location.name}.api.cognitive.microsoft.com"),
        ))
class AzureContainerStore(CloudStore):
    """Store backed by an Azure blob storage container."""
    @staticmethod
    def get_storage_account(account_name: str) -> azure.mgmt.storage.models.StorageAccount:
        """Find the storage account named *account_name* across all subscriptions."""
        # apparently there is no way to get the storage account just by name, so we need to enumerate
        for subscription in azure_subscriptions:
            storage_client = azure.mgmt.storage.StorageManagementClient(
                credential=azure_credential,
                subscription_id=cast(str, subscription.subscription_id),
            )
            for entry in storage_client.storage_accounts.list():
                storage_account = typing.cast(azure.mgmt.storage.models.StorageAccount, entry)
                if cast(str, storage_account.name).lower() == account_name.lower():
                    return storage_account
        raise Exception(f"storage account '{account_name}' not found")
    @staticmethod
    def get_access_key(storage_account_id: str) -> str:
        """Return the first access key of the account identified by its ARM id."""
        # we need to get the resource group name from the id
        if not (
            m := re.match(
                r"^/subscriptions/(?P<s>[^/]+)/resourceGroups/(?P<rg>[^/]+)/providers/Microsoft.Storage/storageAccounts/(?P<a>[^/]+)$",
                storage_account_id,
            )
        ):
            raise Exception(f"invalid storage account id '{storage_account_id}'")
        storage_client = azure.mgmt.storage.StorageManagementClient(
            credential=azure_credential,
            subscription_id=m.group("s"),
        )
        key_list = typing.cast(
            azure.mgmt.storage.models.StorageAccountListKeysResult,
            storage_client.storage_accounts.list_keys(
                resource_group_name=m.group("rg"),
                account_name=m.group("a"),
            ),
        )
        keys = cast(list, key_list.keys)
        key = typing.cast(azure.mgmt.storage.models.StorageAccountKey, keys[0])
        return cast(str, key.value)
    def __init__(self, container: azure.storage.blob.ContainerClient) -> None:
        self.container = container
        self.storage_account = AzureContainerStore.get_storage_account(cast(str, self.container.account_name))
        self.access_key: str | None = None
        try:
            policy = self.container.get_container_access_policy()
        except azure.core.exceptions.ResourceNotFoundError:
            # NOTE(review): presumably the token credential cannot read the access
            # policy here, so re-authenticate with the account key and retry — confirm
            self.access_key = AzureContainerStore.get_access_key(cast(str, self.storage_account.id))
            self.container = typing.cast(
                azure.storage.blob.ContainerClient,
                azure.storage.blob.ContainerClient.from_container_url(self.container.url, credential=self.access_key),
            )
            policy = self.container.get_container_access_policy()
        endpoints = typing.cast(azure.mgmt.storage.models.Endpoints, self.storage_account.primary_endpoints)
        self.blob_service_client = azure.storage.blob.BlobServiceClient(cast(str, endpoints.blob), credential=azure_credential)
        super().__init__(
            region=Region.lookup(Provider.Azure, self.storage_account.location),
            # a non-None access policy means blobs (or the container) are public
            is_public=policy["public_access"] is not None,
        )
    def list_files(self) -> typing.Iterable[Document]:
        """Yield a document for every blob in the container."""
        for blob in self.container.list_blobs():
            yield AzureBlobDocument(self, blob)
    def put_file(self, file_name: str, content: bytes) -> None:
        """Upload *content* as blob *file_name*, replacing any existing blob."""
        self.container.get_blob_client(file_name).upload_blob(content, overwrite=True)
    def __repr__(self) -> str:
        return f"<AzureContainerStore storage_account='{self.container.account_name}' container='{self.container.container_name}' region='{self.region.name}'>"
class AzureBlobDocument(PresignedUrlDocument[AzureContainerStore]):
    """Document stored as an Azure blob; URLs carry a SAS when the store is private."""
    def __init__(self, store: AzureContainerStore, blob: azure.storage.blob.BlobProperties) -> None:
        super().__init__(store, cast(str, blob.name))
        self.blob = blob
    def presign_url(self) -> str:
        url = self.store.container.get_blob_client(blob=self.blob).url
        if self.store.is_public or urllib.parse.urlsplit(url).query:
            # the store is public or we already have a SAS, simply return the url
            return url
        else:
            # we need a SAS
            start = datetime.datetime.utcnow()
            expiry = start + datetime.timedelta(seconds=settings.presign_expiration)
            # with an account key a user delegation key is unnecessary — exactly
            # one of the two credentials is passed to generate_blob_sas below
            user_delegation_key = (
                self.store.blob_service_client.get_user_delegation_key(
                    key_start_time=start,
                    key_expiry_time=expiry,
                )
                if self.store.access_key is None else
                None
            )
            # generate and append the SAS
            sas = azure.storage.blob.generate_blob_sas(
                account_name=cast(str, self.store.container.account_name),
                container_name=self.store.container.container_name,
                blob_name=cast(str, self.blob.name),
                permission=azure.storage.blob.BlobSasPermissions(read=True),
                account_key=self.store.access_key,
                user_delegation_key=user_delegation_key,
                start=start,
                expiry=expiry,
            )
            return f"{url}?{sas}"
    @functools.cached_property
    def content(self) -> bytes:
        # downloaded once and cached in memory
        return self.store.container.download_blob(blob=self.blob).readall()
class AzureServiceError(Exception):
    """Error reported by an Azure Cognitive Services REST endpoint."""

    def __init__(self, code: str | int, message: str) -> None:
        super().__init__(message)
        self.code = code  # service-specific error code (string or numeric)
        self.message = message  # human-readable description

    @staticmethod
    def from_json(value: typing.Any) -> AzureServiceError:
        """Build an AzureServiceError from a parsed JSON error response.

        Accepts the common Azure error envelope ``{"error": {"code": ...,
        "message": ..., "innererror": {...}}}``. Fields from ``innererror``
        take precedence; missing fields fall back to defaults.

        Raises:
            ValueError: if the envelope or any nested error is malformed.
        """
        if not isinstance(value, dict):
            raise ValueError("value must be an object")
        error = value.get("error", dict())
        if not isinstance(error, dict):
            # fix: this branch checks the outer error object, but was
            # previously mislabelled as "inner error"
            raise ValueError("error must be an object")
        inner_error = error.get("innererror", dict())
        if not isinstance(inner_error, dict):
            raise ValueError("inner error must be an object")
        code = inner_error.get("code", error.get("code", 0))
        if not isinstance(code, (str, int)):
            raise ValueError("error code must be a string or number")
        message = inner_error.get("message", error.get("message", "An unknown error occurred."))
        if not isinstance(message, str):
            raise ValueError("error message must be a string")
        return AzureServiceError(code, message)
class AzureService(CloudService):
    """Base class for Azure Cognitive Services accounts (OCR, translation).

    Instantiating a subclass provisions (or reuses) a cognitive services
    account of the given kind in the given location and stores its endpoint
    and subscription key for REST calls via ``invoke``.
    """

    def __init__(self, *, kind: str, location: str) -> None:
        super().__init__(Region.lookup(Provider.Azure, location))
        self.account, self.key = AzureService.get_or_create_account_and_key(kind, location)
        self.endpoint = cast(str, ensure(self.account.properties).endpoint)
        self.location = ensure(self.account.location)

    @staticmethod
    @functools.lru_cache
    def get_resource_group() -> azure.mgmt.resource.resources.models.ResourceGroup:
        """Get or create the shared resource group (cached after first call)."""
        # Place a newly created group in the Azure region closest to us.
        region = min(
            Region.index[Provider.Azure].values(),
            key=lambda region: distance(region.location, me),
        )
        with Resource(settings.azure_resource_group, Provider.Azure) as res:
            try:
                result = azure_resource_client.resource_groups.get(settings.azure_resource_group)
                res.existed = True
            except azure.core.exceptions.ResourceNotFoundError:
                result = azure_resource_client.resource_groups.create_or_update(
                    resource_group_name=settings.azure_resource_group,
                    parameters=azure.mgmt.resource.resources.models.ResourceGroup(
                        location=region.name
                    ),  # type: ignore
                )
                res.created = True
            res.region_name = result.location
            return typing.cast(azure.mgmt.resource.resources.models.ResourceGroup, result)

    @staticmethod
    def get_or_create_account_and_key(kind: str, location: str) -> tuple[azure.mgmt.cognitiveservices.models.Account, str]:
        """Return a cognitive services account of the given kind and its key.

        Reuses an existing account when its kind, location and SKU match;
        otherwise the stale account is deleted/purged and a new one created.
        """
        with Resource(kind, Provider.Azure, location) as res:
            sku = "F0" if settings.azure_use_free_tier else "S1"
            resource_group = AzureService.get_resource_group()
            resource_group_name = cast(str, resource_group.name)
            account_name = Formatter().format(
                settings.azure_cognitive_account_name,
                resource_group=resource_group_name,
                location=location,
                kind=kind,
                sku=sku,
            )
            try:
                # get any existing account with the given name
                account = azure_cognitiveservices_management_client.accounts.get(
                    resource_group_name=resource_group_name,
                    account_name=account_name,
                )
                res.existed = True
                # check if any setting is wrong
                if (
                    ensure(account.kind) != kind
                    or ensure(account.location) != location
                    or ensure(account.sku).name != sku
                ):
                    # delete and purge the incorrect account (we can't update these settings)
                    azure_cognitiveservices_management_client.accounts.begin_delete(
                        resource_group_name=resource_group_name,
                        account_name=account_name,
                    ).result()
                    # NOTE(review): the purge is not awaited (no .result());
                    # the ResourceExistsError handler below copes with a
                    # purge that is still in flight.
                    azure_cognitiveservices_management_client.deleted_accounts.begin_purge(
                        location=ensure(account.location),
                        resource_group_name=resource_group_name,
                        account_name=account_name,
                    )
                    # force re-creation below
                    raise azure.core.exceptions.ResourceNotFoundError()
            except azure.core.exceptions.ResourceNotFoundError:
                while True:
                    try:
                        # create a new account
                        account = azure_cognitiveservices_management_client.accounts.begin_create(
                            resource_group_name=resource_group_name,
                            account_name=account_name,
                            account=azure.mgmt.cognitiveservices.models.Account(
                                kind=kind,
                                sku=azure.mgmt.cognitiveservices.models.Sku(name=sku),
                                location=location,
                            ),
                        ).result()
                        res.created = True
                        break
                    except azure.core.exceptions.ResourceExistsError:
                        # purge the old account (but only if it's in our resource group)
                        for deleted_account in azure_cognitiveservices_management_client.deleted_accounts.list():
                            if cast(str, deleted_account.name).lower() == account_name.lower():
                                azure_cognitiveservices_management_client.deleted_accounts.begin_purge(
                                    location=ensure(deleted_account.location),
                                    resource_group_name=resource_group_name,
                                    account_name=account_name,
                                ).result()
                                break
                        else:
                            raise
            key_list = azure_cognitiveservices_management_client.accounts.list_keys(
                resource_group_name=resource_group_name,
                account_name=account_name,
            )
            return (account, ensure(key_list.key1))

    def invoke(
        self,
        *,
        path: str,
        parse: typing.Callable[[typing.Any], T],
        data: bytes | None = None,
        json: typing.Any = None,
        params: dict[str, str] | None = None,
    ) -> T:
        """POST to the service endpoint and parse the JSON response.

        Args:
            path: URL path appended to the account endpoint.
            parse: callable converting the decoded JSON into the result.
            data: raw request body (mutually exclusive with ``json``).
            json: JSON request body (mutually exclusive with ``data``).
            params: optional query string parameters.

        Raises:
            AzureServiceError: when the service reports an error.
            ValueError: when a success response cannot be parsed.
        """
        assert not (data is not None and json is not None), "binary data and JSON are mutually exclusive"
        with self.timed:
            headers = {
                "Ocp-Apim-Subscription-Key": self.key,
                "Ocp-Apim-Subscription-Region": self.location,
            }
            if data is not None:
                headers["Content-Type"] = "application/octet-stream"
            response = requests.post(
                url=f"{self.endpoint}{path}",
                data=data,
                json=json,
                headers=headers,
                params=params,
                timeout=(settings.connect_timeout, self.timeout),
            )
            result = response.json()
            if not response.ok:
                raise AzureServiceError.from_json(result)
            try:
                return parse(result)
            except Exception as exc:
                # fix: chain the original exception so the root cause of a
                # malformed response is not lost
                raise ValueError("invalid response data") from exc
class AzureComputerVisionService(AzureService, OcrService):
    """OCR via the Azure Computer Vision v3.2 ``/ocr`` endpoint."""

    def __init__(self, location: str) -> None:
        super().__init__(kind="ComputerVision", location=location)
        # Free tier: 20 calls per minute; paid tier: 10 calls per second.
        self.quota = (
            SlidingWindowQuota(20, datetime.timedelta(minutes=1))
            if settings.azure_use_free_tier else
            SlidingWindowQuota(10, datetime.timedelta(seconds=1))
        )

    def run(self, document: Document) -> str:
        """Recognize text in the document; lines are joined with newlines.

        Prefers passing the document by URL, falling back to uploading the
        raw bytes when the document cannot produce one.

        Raises:
            ServiceInputError: when the document or image is unusable.
        """
        try:
            data = None
            payload = {"url": document.url}  # renamed: no longer shadows the json module
        except NotImplementedError:
            # no URL available, upload the image content instead
            try:
                data = document.content
                payload = None
            except NotImplementedError:
                raise ServiceInputError("document supports neither generating URI nor downloading")
            except Exception as e:
                raise ServiceInputError(f"cannot download image ({e})")
        except Exception as e:
            raise ServiceInputError(f"cannot generate image URI ({e})")
        self.quota.use()
        try:
            return self.invoke(
                path="/vision/v3.2/ocr",
                data=data,
                json=payload,
                parse=lambda result: "\n".join(
                    " ".join(word["text"] for word in line["words"])
                    for region in result["regions"]
                    for line in region["lines"]
                ),
            )
        except AzureServiceError as e:
            if e.code == "InvalidImageFormat":
                raise ServiceInputError("image format is invalid")
            elif e.code == "InvalidImageSize":
                raise ServiceInputError("image size is invalid")
            elif e.code == "NotSupportedImage":
                raise ServiceInputError("image type is not supported")
            elif e.code == "NotSupportedLanguage":
                # fix: was an f-string with no placeholders
                raise ServiceInputError("language is not supported")
            else:
                raise
class AzureTextTranslationService(AzureService, TranslationService):
    """Text translation via the Azure Translator v3.0 ``/translate`` endpoint."""

    def __init__(self, location: str) -> None:
        super().__init__(kind="TextTranslation", location=location)
        # The free tier is capped at 2M characters per sliding hour; the
        # paid tier is not throttled client-side.
        if settings.azure_use_free_tier:
            self.quota = SlidingWindowQuota(2000000, datetime.timedelta(hours=1))
        else:
            self.quota = None

    def run(self, text: str, language: str) -> str:
        """Translate ``text`` into ``language`` and return the result."""
        quota = self.quota
        if quota is not None:
            quota.use(len(text))
        try:
            translated = self.invoke(
                path="/translate",
                params={"api-version": "3.0", "to": language},
                json=[{"text": text}],
                parse=lambda result: cast(str, result[0]["translations"][0]["text"]),
            )
        except AzureServiceError as e:
            if e.code == 400019:
                raise ServiceInputError(f"language '{language}' is not supported")
            if e.code == 400050:
                raise ServiceInputError("text is too long")
            raise
        return translated
############################################################################### | |
# REGISTRATION CLASSES # | |
############################################################################### | |
class Provider(enum.Enum):
    """Cloud providers, each bundling its OCR and translation factories."""

    # Per-member factories, populated in __new__ from the member tuple.
    create_ocr_service: typing.Callable[[str], OcrService]
    create_translation_service: typing.Callable[[str], TranslationService]

    AWS = (AWSRekognitionService, AWSTranslateService)
    Azure = (AzureComputerVisionService, AzureTextTranslationService)

    def __new__(
        cls,
        create_ocr_service: typing.Callable[[str], OcrService],
        create_translation_service: typing.Callable[[str], TranslationService],
    ) -> Provider:
        obj = object.__new__(cls)
        # NOTE(review): enum.auto() outside a class body is stored as-is (it
        # is not resolved to an int); members are only ever looked up by name
        # in this file, so this works but is unusual — confirm if values are
        # ever compared or pickled.
        obj._value_ = enum.auto()
        obj.create_ocr_service = create_ocr_service
        obj.create_translation_service = create_translation_service
        return obj
# Type variable bound to CloudService, used to parameterize TimedService.
CloudServiceT = typing.TypeVar("CloudServiceT", bound=CloudService)
class TimedService(typing.Generic[CloudServiceT]):
    """Lazily constructed cloud service that tracks its average call time.

    Used as a context manager around each invocation; the measured runtime
    (or a worst-case penalty on failure) feeds a running average that is
    used elsewhere to try the fastest service first.
    """

    def __init__(self, is_supported: bool, region: Region, constructor: typing.Callable[[str], CloudServiceT]) -> None:
        self.is_supported = is_supported
        self.region = region
        self.constructor = constructor
        # sys.maxsize makes never-used services sort last in speed ordering
        self.average_time = sys.maxsize
        self.use_count = 0
        self.lock = threading.Lock()
        # per-thread storage for the start timestamp of the current call
        self.local = threading.local()

    @functools.cached_property
    def instance(self) -> CloudServiceT:
        # Construction may provision cloud resources, so it is deferred
        # until the service is actually used; the average is reset so the
        # freshly created service gets tried.
        result = self.constructor(self.region.name)
        self.average_time = 0
        return result

    def __enter__(self) -> None:
        setattr(self.local, "start", time.time())

    def __exit__(
        self,
        type: typing.Type[BaseException] | None,
        value: BaseException | None,
        traceback: types.TracebackType | None,
    ) -> None:
        if value is None:
            runtime = time.time() - getattr(self.local, "start")
        else:
            # penalize failed calls with the worst-case (timeout) runtime
            runtime = settings.connect_timeout + self.instance.timeout
        with self.lock:
            # incremental running average over all uses
            self.average_time = (runtime + self.use_count * self.average_time) / (self.use_count + 1)
            self.use_count += 1
class Region:
    """A provider region with its geo-location and timed service wrappers."""

    # provider -> region name -> Region, populated via register()
    index: typing.ClassVar[dict[Provider, dict[str, Region]]] = {}

    def __init__(
        self,
        provider: Provider,
        name: str,
        location: tuple[float, float],
        is_available: bool,
        supports_ocr: bool,
        supports_translation: bool,
    ) -> None:
        self.provider = provider
        self.name = name
        # coordinate pair used for distance-based service ordering
        self.location = location
        self.is_available = is_available
        self.ocr = TimedService(supports_ocr, self, provider.create_ocr_service)
        self.translation = TimedService(supports_translation, self, provider.create_translation_service)
        available_services: list[str] = []
        if is_available:
            if supports_ocr:
                available_services.append("ocr")
            if supports_translation:
                available_services.append("translation")
        # help text shown in the argument parser's epilog
        self.help = f"{name} ({', '.join(available_services)})"

    @staticmethod
    def register(region: Region) -> None:
        """Add a region to the global index, rejecting duplicate names."""
        provider = Region.index.setdefault(region.provider, {})
        if region.name in provider:
            raise ValueError(f"region '{region.name}' already exists in provider '{region.provider}'")
        provider[region.name] = region

    @staticmethod
    def parse(spec: str) -> Region:
        """Parse a 'PROVIDER:REGION' command-line argument into a Region."""
        parts = spec.split(":", maxsplit=2)
        if len(parts) != 2:
            raise ValueError("provider and region must be specified")
        provider, region_name = parts
        return Region.lookup(Provider[provider], region_name)

    @staticmethod
    def lookup(provider: Provider, name: str) -> Region:
        """Return the registered region or raise ValueError if unknown."""
        try:
            return Region.index[provider][name]
        except KeyError:
            raise ValueError(f"region '{name}' of provider '{provider}' not registered")

    def __repr__(self) -> str:
        return f"<Region provider='{self.provider.name}' name='{self.name}' location={self.location}>"
class Resource:
    """Records deployment timing for a cloud resource and throttles retries.

    Used as a context manager around get-or-create operations: successful
    deployments are collected in ``Resource.all`` for the statistics CSV,
    while failures are remembered in ``Resource.last_exc`` so the same
    failing deployment is re-raised immediately until
    ``settings.refresh_interval`` has elapsed.
    """

    all: typing.ClassVar[list[Resource]] = []
    # (name, provider, region) -> (failure time, exception)
    last_exc: typing.ClassVar[dict[tuple[str, str, str], tuple[float, BaseException]]] = {}

    def __init__(self, name: str, provider: Provider, region_name: str = "") -> None:
        self.name = name
        self.provider_name = provider.name
        self.region_name = region_name
        self.existed = False
        self.created = False
        self.deployment_start = time.time() - epoche
        self.deployment_end: float | None = None

    def to_record(self) -> dict:
        """Return a snapshot of this resource's attributes for reporting."""
        # fix: return a copy instead of the live __dict__ so record
        # consumers cannot mutate the resource (and vice versa)
        return dict(self.__dict__)

    def __enter__(self) -> Resource:
        assert self.deployment_end is None, "resource already ended"
        # re-raise a recent failure for this same resource without retrying
        if last_exc := Resource.last_exc.get((self.name, self.provider_name, self.region_name), None):
            event, exc = last_exc
            if self.deployment_start < event + settings.refresh_interval:
                raise exc
        return self

    def __exit__(
        self,
        type: typing.Type[BaseException] | None,
        value: BaseException | None,
        traceback: types.TracebackType | None,
    ) -> None:
        self.deployment_end = time.time() - epoche
        if value is None:
            Resource.all.append(self)
        else:
            Resource.last_exc[(self.name, self.provider_name, self.region_name)] = (self.deployment_end, value)
###############################################################################
#                               INITIALIZATION                                #
###############################################################################

# region cache
# Best-effort load of the on-disk cache; any problem resets it to empty.
print("Loading cache: ", end="", flush=True)
try:
    with open(full_path(".cache"), "r") as _file:
        cache = json.load(_file)
        if not isinstance(cache, dict):
            raise TypeError()
except Exception:
    cache = {}
print(f"{len(cache)} entries")
# endregion

# region AWS
print("Signing in to AWS: ", end="", flush=True)
try:
    aws_session = boto3.session.Session()
    aws_sts = aws_session.client("sts")
    aws_identity = aws_sts.get_caller_identity()
except botocore.exceptions.NoCredentialsError:
    # no configured credentials; prompt for them interactively
    print()
    aws_session = boto3.session.Session(
        **{
            aws: input(f"{aws}: ")
            for aws in ("aws_access_key_id", "aws_secret_access_key", "aws_session_token")
        }
    )
    aws_sts = aws_session.client("sts")
    aws_identity = aws_sts.get_caller_identity()
print(aws_identity["Arn"])
s3_resource = aws_session.resource("s3")
s3_client = aws_session.client("s3")
# unsigned client for accessing public buckets without credentials
s3_client_unsigned = aws_session.client("s3", config=botocore.config.Config(signature_version=botocore.UNSIGNED))
print("Initializing AWS regions: ", end="", flush=True)
register_aws_regions()
print(len(Region.index[Provider.AWS]))
# endregion

# region Azure
print("Signing in to Azure: ", end="", flush=True)
azure_credential = azure.identity.DefaultAzureCredential(exclude_interactive_browser_credential=False)
azure_graph_client = msgraph.core.GraphClient(credential=azure_credential)
azure_user = cast(dict, azure_graph_client.get("/me").json())
print(azure_user.get("userPrincipalName", azure_user["id"]))
print("Using Azure subscription: ", end="", flush=True)
azure_subscription_client = azure.mgmt.subscription.SubscriptionClient(credential=azure_credential)
azure_subscriptions = list(azure_subscription_client.subscriptions.list())
azure_subscription = get_azure_subscription()
azure_subscription_id = cast(str, azure_subscription.subscription_id)
print(f"'{azure_subscription.display_name}' ({azure_subscription_id})")
azure_resource_client = azure.mgmt.resource.ResourceManagementClient(
    credential=azure_credential,
    subscription_id=azure_subscription_id,
)
azure_cognitiveservices_management_client = azure.mgmt.cognitiveservices.CognitiveServicesManagementClient(
    credential=azure_credential,
    subscription_id=azure_subscription_id,
)
print("Initializing Azure regions: ", end="", flush=True)
register_azure_regions()
print(len(Region.index[Provider.Azure]))
# endregion
###############################################################################
#                                MAIN PROGRAM                                 #
###############################################################################

# region build argument parser
parser = argparse.ArgumentParser(
    prog="convert",
    description="Performs OCR on images and translates the recognized text.",
    epilog=(
        # list every provider and each provider's regions with their services
        "Known PROVIDER:\n " +
        "\n ".join(sorted(provider.name for provider in Provider)) +
        "\n\n" +
        "\n\n".join(
            f"Known {provider.name}:REGION: (available services)\n " +
            "\n ".join(region.help for region in sorted(Region.index[provider].values(), key=lambda r: r.name))
            for provider in Provider
        )
    ),
    formatter_class=argparse.RawDescriptionHelpFormatter,
)
parser.add_argument(
    "input",
    type=Store.from_uri,
    help="Input store containing the images.",
)
parser.add_argument(
    "output",
    type=Store.from_uri,
    help="Output store receiving the translated text files.",
)
parser.add_argument(
    "--ocr",
    required=False,
    metavar="PROVIDER:REGION",
    type=Region.parse,
    nargs='+',
    default=[
        region for region in Region.index[Provider.Azure].values()
        if region.is_available and region.ocr.is_supported
    ],
    help="Specifies the regions that should be used for OCR, defaults to all available Azure regions."
)
parser.add_argument(
    "--translation",
    required=False,
    metavar="PROVIDER:REGION",
    type=Region.parse,
    nargs='+',
    default=[
        region for region in Region.index[Provider.AWS].values()
        if region.is_available and region.translation.is_supported
    ],
    help="Specifies the regions that should be used for translation, defaults to all available AWS regions."
)
register_settings()
# endregion

# region read the arguments
_args = parser.parse_args()
input_store: Store = _args.input
output_store: Store = _args.output
settings = Settings(**{
    setting: getattr(_args, setting)
    for setting in Settings._fields
})
me = locate_me()
ocr_regions: list[Region] = _args.ocr
translation_regions: list[Region] = _args.translation
if not settings.preserve_order:
    # Sort regions by distance: OCR relative to the input store when the
    # store itself lives in a cloud region, otherwise relative to this
    # machine; translation always relative to this machine.
    ocr_regions = sorted(
        ocr_regions,
        key=lambda region: distance(
            input_store.location if (
                region.provider is Provider.Azure
                or isinstance(input_store, S3BucketStore) and region is input_store.region
            ) else me,
            region.location
        )
    )
    translation_regions = sorted(
        translation_regions,
        key=lambda region: distance(me, region.location)
    )
ocr_services = [region.ocr for region in ocr_regions]
translation_services = [region.translation for region in translation_regions]
# endregion
class Task:
    """One document flowing through OCR -> translation -> write.

    Tasks are created on demand from the shared input iterator and record
    per-step timing statistics for the final CSV report.
    """

    all: typing.ClassVar[list[Task]] = []
    lock: typing.ClassVar[threading.Lock] = threading.Lock()
    # single shared source of documents, drained by all worker threads
    iterator: typing.ClassVar[typing.Iterator[Document]] = iter(input_store.list_files())

    def __init__(self, document: Document) -> None:
        self.document = document
        self.state = TaskState.IDLE
        self.error = ""
        self.ocr = TaskStepStatistics()
        self.translation = TaskStepStatistics()
        self.write = TaskStepStatistics()
        Task.all.append(self)

    def to_record(self) -> dict:
        """Flatten name, error and per-step statistics into one CSV row."""
        return {
            "name": self.document.name,
            "error": self.error,
            "ocr_start": self.ocr.start,
            "ocr_end": self.ocr.end,
            "ocr_provider_name": self.ocr.provider_name,
            "ocr_region_name": self.ocr.region_name,
            "translation_start": self.translation.start,
            "translation_end": self.translation.end,
            "translation_provider_name": self.translation.provider_name,
            "translation_region_name": self.translation.region_name,
            "write_start": self.write.start,
            "write_end": self.write.end,
        }

    @staticmethod
    def next() -> Task | None:
        """Thread-safely take the next document, or None when exhausted."""
        with Task.lock:
            try:
                return Task(next(Task.iterator))
            except StopIteration:
                return None

    def simple_step(self, state: TaskState, run: typing.Callable[[Document], T]) -> tuple[T, TaskStepStatistics]:
        """Run a local step with retries and return (result, timing).

        The task shows ``state`` while running; failed attempts are logged
        and retried up to ``settings.max_attempts`` with a
        ``settings.retry_interval`` pause in between. The previous state is
        always restored afterwards.
        """
        prev_state = self.state
        try:
            attempt = 1
            while True:
                self.state = state
                try:
                    start = time.time() - epoche
                    result = run(self.document)
                    end = time.time() - epoche
                    return (result, TaskStepStatistics(start=start, end=end))
                except Exception as e:
                    Terminal.exception(e)
                    attempt += 1
                    if attempt > settings.max_attempts:
                        raise
                self.state = TaskState.IDLE
                time.sleep(settings.retry_interval)
        finally:
            self.state = prev_state

    def timed_step(
        self,
        state: TaskState,
        services: typing.Iterable[TimedService[CloudServiceT]],
        run: typing.Callable[[CloudServiceT, Document], T],
    ) -> tuple[T, TaskStepStatistics]:
        """Run a cloud step, trying services in order of average speed.

        Services raising QuotaError mark the task as throttled for this
        pass; ServiceInputError aborts immediately; any other error is
        retried up to ``settings.max_attempts`` full passes over the
        services. The previous state is always restored afterwards.
        """
        prev_state = self.state
        try:
            attempt = 1
            while True:
                self.state = state
                exceeded_quota = False
                last_exc: Exception | None = None
                # fastest (lowest average runtime) services first
                for service in sorted(services, key=lambda service: service.average_time):
                    try:
                        start = time.time() - epoche
                        result = run(service.instance, self.document)
                        end = time.time() - epoche
                    except ServiceInputError:
                        raise
                    except QuotaError:
                        exceeded_quota = True
                    except Exception as e:
                        Terminal.exception(e)
                        last_exc = e
                    else:
                        return (
                            result,
                            TaskStepStatistics(
                                start=start,
                                end=end,
                                provider_name=service.region.provider.name,
                                region_name=service.region.name,
                            )
                        )
                # a pass only counts as a failed attempt when it failed for
                # a reason other than quota exhaustion
                if not exceeded_quota and last_exc is not None:
                    attempt += 1
                    if attempt > settings.max_attempts:
                        raise last_exc
                self.state = TaskState.THROTTLED if exceeded_quota else TaskState.IDLE
                time.sleep(settings.retry_interval)
        finally:
            self.state = prev_state
class TaskState(enum.IntEnum):
    """Lifecycle states of a task; also the order used in the status line."""
    IDLE = 0        # waiting between attempts
    THROTTLED = 1   # waiting because service quotas were exceeded
    OCR = 2         # text recognition in progress
    TRANSLATE = 3   # translation in progress
    WRITE = 4       # writing the output file
    FAILED = 5      # gave up after the maximum number of attempts
    DONE = 6        # finished successfully
class TaskStepStatistics(typing.NamedTuple):
    """Timing (seconds relative to program start) and origin of one step."""
    start: float = 0
    end: float = 0
    provider_name: str = ""
    region_name: str = ""
class Terminal:
    """Static helpers for thread-safe, single-line console status output."""

    # unique error strings already printed (for de-duplication)
    errors: typing.ClassVar[set[str]] = set()
    # total number of reported exceptions, including duplicates
    error_count: typing.ClassVar[int] = 0
    lock: typing.ClassVar[threading.Lock] = threading.Lock()

    def __init__(self) -> None:
        # static-only class, not meant to be instantiated
        raise NotImplementedError()

    @staticmethod
    def max_width() -> int:
        """Return the printable terminal width, defaulting to 80 without a TTY."""
        try:
            return os.get_terminal_size().columns - 1
        except OSError:
            return 80

    @staticmethod
    def status(text: str) -> None:
        """Overwrite the current console line with a timestamped status."""
        width = Terminal.max_width()
        line = f"{datetime.timedelta(seconds=time.time() - epoche)}: {text}"[0:width]
        with Terminal.lock:
            # blank out the previous status before printing the new one
            print("\r", " " * width, "\r", line, sep="", end="", flush=True)

    @staticmethod
    def exception(exception: Exception) -> None:
        """Count an exception and print it once per distinct message."""
        error = f"{type(exception).__name__}: {exception}"
        width = Terminal.max_width()
        with Terminal.lock:
            Terminal.error_count += 1
            if error not in Terminal.errors:  # fix: idiomatic "not in"
                Terminal.errors.add(error)
                print("\r", " " * width, "\r", error, sep="", flush=True)
def worker() -> None:
    # Drains the shared task queue until the input iterator is exhausted:
    # OCR, then translation, then writing the text file to the output store.
    while task := Task.next():
        try:
            recognized, task.ocr = task.timed_step(
                state=TaskState.OCR,
                services=ocr_services,
                run=lambda service, doc: service.run(doc),
            )
            translated, task.translation = task.timed_step(
                state=TaskState.TRANSLATE,
                services=translation_services,
                # the translation input is the OCR result, not the document
                run=lambda service, _: service.run(recognized, settings.language),
            )
            _, task.write = task.simple_step(
                state=TaskState.WRITE,
                run=lambda doc: output_store.put_file(os.path.splitext(doc.name)[0] + ".txt", translated.encode()),
            )
        except Exception as e:
            # any step that exhausted its attempts marks the task as failed
            task.state = TaskState.FAILED
            task.error = str(e)
        else:
            task.state = TaskState.DONE
def main() -> None:
    # Spawns the worker threads and refreshes a status line until all of
    # them finish, then prints a summary and optionally writes CSV reports.
    threads = [threading.Thread(target=worker, name=f"worker{index}", daemon=True) for index in range(settings.threads)]
    for thread in threads:
        thread.start()
    while threads:
        # tally how many tasks are in each state for the status line
        states = [0] * len(TaskState)
        for task in Task.all:
            states[task.state.value] += 1
        summary = " ".join(map(lambda state: f"{state.name}={states[state.value]}", TaskState))
        Terminal.status(f"THREADS={len(threads)} ERRORS={Terminal.error_count} {summary}")
        time.sleep(settings.refresh_interval)
        threads = [thread for thread in threads if thread.is_alive()]
    Terminal.status(f"Processed {len(Task.all)} images, {sum(entry.state is TaskState.FAILED for entry in Task.all)} failed.")
    if settings.statistics:
        pandas.DataFrame.from_records(resource.to_record() for resource in Resource.all).to_csv(full_path("resources.csv"), index=False)
        pandas.DataFrame.from_records(task.to_record() for task in Task.all).to_csv(full_path("documents.csv"), index=False)
# All recorded timestamps in this program are relative to this start time.
epoche = time.time()
main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment