Skip to content

Instantly share code, notes, and snippets.

@meitinger
Last active February 1, 2023 11:09
Show Gist options
  • Save meitinger/3fb863005df2750500c168497fb171ca to your computer and use it in GitHub Desktop.
Helper utility to perform OCR and translation on images in the cloud.
#!/usr/bin/python3
#
# Copyright (C) 2023, Manuel Meitinger
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 2 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
# Prerequisites:
# 1. AWS account
# 2. Azure account
# 3. accepted Microsoft's Responsible AI terms and conditions
from __future__ import annotations
import abc
import argparse
import collections
import datetime
import enum
import functools
import hashlib
import inspect
import io
import ipaddress
import json
import math
import os
import re
import socket
import string
import subprocess
import sys
import threading
import time
import types
import typing
import urllib.parse
import urllib.request
# Type-only imports: boto3 stub packages used in annotations below, never
# needed (or installed) at runtime.
if typing.TYPE_CHECKING:
    import mypy_boto3_rekognition.type_defs
    import mypy_boto3_s3.service_resource
# Print interpreter and script fingerprints so a run can be correlated with
# the exact code version that produced it.
print("Python version:", sys.version)
try:
    # Hash the script's raw bytes. Reading in binary mode avoids newline
    # translation, and hashing read() directly avoids the doubled separators
    # that joining readlines() (which keeps each trailing "\n") with "\n"
    # would introduce.
    with open(__file__, "rb") as _file:
        print("Script version:", hashlib.sha1(_file.read()).hexdigest())
except OSError:
    pass  # best effort: __file__ may not be readable (e.g. frozen or zipped)
class Settings(typing.NamedTuple):
    language: str = "en"
    """The target language's short code."""
    threads: int = 10
    """The number of worker threads to spawn."""
    statistics: bool = False
    """If set, statistics about deployment and execution will be written to file."""
    preserve_order: bool = False
    """If set, the order of the regions will not be optimized by geolocation."""
    presign_expiration: int = 300
    """Time in seconds after which a pre-signed URI should expire."""
    connect_timeout: float = 1.5
    """Time in seconds the script will wait to connect to a cloud service."""
    ocr_timeout: float = 4.0
    """Time in seconds the script will wait for the OCR to complete."""
    translation_timeout: float = 2.0
    """Time in seconds the script will wait for the translation to complete."""
    refresh_interval: float = 0.75
    """Time in seconds after which the progress report should be refreshed."""
    retry_interval: float = 0.5
    """Time in seconds after which all regions or resources should be retried."""
    max_attempts: int = 5
    """Number of attempts after which to give up on an operation."""
    azure_resource_group: str = "daecc-project"
    """The name of the Azure resource group where to create cognitive accounts in."""
    azure_cognitive_account_name: str = "{resource_group}-{kind!l}-{location}"
    """The template of cognitive account names."""
    azure_use_free_tier: bool = False
    """If set, use free F0 tier instead of pay-as-you-go S1."""
    # NOTE(review): register_settings() parses this class's source via
    # inspect.getsourcelines() and assumes exactly one field line followed by
    # one docstring line per setting, starting directly after the class
    # header. Keep that layout intact (and keep these comments at the very
    # end of the class) when editing.
def dot() -> None:
    """Emit a single progress dot, flushed immediately, without a newline."""
    sys.stdout.write(".")
    sys.stdout.flush()
# region dependencies
# Import every third-party dependency up front, printing one progress dot per
# module.  If any module is missing, install the full dependency set with pip
# and restart the script.
print("Loading libraries", end="", flush=True)
try:
    import azure.core.exceptions
    dot()
    import azure.identity
    dot()
    import azure.mgmt.cognitiveservices
    dot()
    import azure.mgmt.cognitiveservices.models
    dot()
    import azure.mgmt.resource
    dot()
    import azure.mgmt.resource.resources.models
    dot()
    import azure.mgmt.storage
    dot()
    import azure.mgmt.storage.models
    dot()
    import azure.mgmt.subscription
    dot()
    import azure.mgmt.subscription.models
    dot()
    import azure.storage.blob
    dot()
    import boto3
    dot()
    import botocore
    dot()
    import botocore.client
    dot()
    import botocore.config
    dot()
    import botocore.exceptions
    dot()
    import msgraph.core
    dot()
    import pandas
    dot()
    import requests
    dot()
except ModuleNotFoundError as e:
    # At least one dependency is missing: install the whole set at once.
    print(f": {e.name} missing, installing...")
    subprocess.check_call([
        sys.executable,
        "-m",
        "pip",
        "install",
        "azure-core",
        "azure-identity",
        "azure-mgmt-cognitiveservices",
        "azure-mgmt-resource",
        "azure-mgmt-storage",
        "azure-mgmt-subscription",
        "azure-storage-blob",
        "boto3",
        "botocore",
        "msgraph-core",
        "pandas",
        "requests",
    ])
    # Replace the current process with a fresh interpreter (same argv) so the
    # newly installed modules are picked up.
    os.execl(sys.executable, sys.executable, *sys.argv)
else:
    print(": ready")
# endregion
###############################################################################
# HELPERS #
###############################################################################
# Generic type variables used by the cached() overloads below.
T = typing.TypeVar("T")
U = typing.TypeVar("U")
def full_path(file_name: str) -> str:
    """Return the path of *file_name* located next to this script."""
    script_directory = os.path.dirname(__file__)
    return os.path.join(script_directory, file_name)
@typing.overload
def cached(
    key: str,
    type: typing.Type[T],
    create: typing.Callable[[], T],
) -> T:
    # Overload: the value is cached verbatim (already a JSON-compatible type).
    ...
@typing.overload
def cached(
    key: str,
    type: typing.Type[U],
    create: typing.Callable[[], T],
    *,
    from_json: typing.Callable[[U], T],
    to_json: typing.Callable[[T], U],
) -> T:
    # Overload: the value is converted to/from its cached JSON representation
    # (of type U) via the given converter callables.
    ...
def cached(key: str, type: typing.Type, create: typing.Callable[[], T], **kwargs) -> T:
    """Return the value stored under *key* in the global cache, or create it.

    The cached raw value must be an instance of *type*; otherwise (or on any
    lookup/validation/conversion failure) a fresh value is produced via
    *create* and persisted to the on-disk cache file.

    Keyword-only converters:
        from_json: maps the raw cached value to the returned value.
        to_json: maps a freshly created value to its cached representation.
    """
    from_json = kwargs.get("from_json", lambda v: v)
    to_json = kwargs.get("to_json", lambda v: v)
    try:
        raw_value = cache[key]
        if not isinstance(raw_value, type):
            raise TypeError()
        value = from_json(raw_value)
    except Exception:
        # Broad on purpose: any failure above simply falls back to recomputing.
        value = create()
        cache[key] = to_json(value)
        # Persist next to the script, via the same helper used elsewhere.
        with open(full_path(".cache"), "w") as file:
            json.dump(cache, file)
    return value
def locate_ip(ip: str) -> tuple[float, float]:
    """Return the (latitude, longitude) of *ip*, cached, via ipapi.co."""
    def locate() -> tuple[float, float]:
        # Query the geolocation service and validate the response shape.
        response = requests.get(f"https://ipapi.co/{ip}/json/").json()
        if not isinstance(response, dict):
            raise TypeError(f"invalid response type '{type(response).__name__}'")
        if response.get("error", None):
            raise ValueError(f"cannot locate IP '{ip}': {response.get('reason', None)}")
        latitude = response.get("latitude", None)
        longitude = response.get("longitude", None)
        # Treat a missing coordinate (either one) as an unknown location
        # instead of falling through to a confusing 'NoneType' type error.
        if latitude is None or longitude is None:
            raise ValueError(f"location of IP '{ip}' is unknown")
        if not isinstance(latitude, (int, float)):
            raise TypeError(f"invalid latitude type '{type(latitude).__name__}'")
        if not isinstance(longitude, (int, float)):
            raise TypeError(f"invalid longitude type '{type(longitude).__name__}'")
        return (float(latitude), float(longitude))
    def parse(value: list) -> tuple[float, float]:
        # Cached values round-trip through JSON as [latitude, longitude].
        latitude, longitude = value
        if not (isinstance(latitude, float) and isinstance(longitude, float)):
            raise TypeError()
        return (latitude, longitude)
    return cached(
        f"location:{ip}",
        list,
        locate,
        from_json=parse,
        to_json=lambda v: list(v),
    )
def resolve_host(hostname: str) -> str:
    """Resolve *hostname* to an IP address string (cached), preferring IPv4.

    If *hostname* already is a literal IP address, it is returned normalized
    without touching the resolver or the cache.
    """
    def resolve() -> str:
        try:
            # try IPv4 first
            infos = socket.getaddrinfo(hostname, 0, family=socket.AF_INET)
        except socket.gaierror as ipv4_error:
            try:
                # fall back to IPv6
                infos = socket.getaddrinfo(hostname, 0, family=socket.AF_INET6)
            except socket.gaierror:
                # report the original (IPv4) resolution failure
                raise ipv4_error
        return infos[0][4][0]
    try:
        return str(ipaddress.ip_address(hostname))
    except ValueError:
        return cached(f"ip:{hostname}", str, resolve)
def locate_host(hostname: str) -> tuple[float, float]:
    """Geolocate *hostname*; raises ValueError if the name does not resolve."""
    try:
        return locate_ip(resolve_host(hostname))
    except socket.gaierror as error:
        # translate "unknown host" resolver failures into a friendlier error
        if error.errno not in (socket.EAI_NONAME, socket.EAI_AGAIN):
            raise
        raise ValueError(f"host '{hostname}' doesn't exist")
def locate_uri(uri: str) -> tuple[float, float]:
    """Geolocate the host component of *uri*; raises ValueError if absent."""
    hostname = urllib.parse.urlsplit(uri).hostname
    if not hostname:
        raise ValueError(f"URI '{uri}' does not contain a host name")
    return locate_host(hostname)
def locate_me() -> tuple[float, float]:
    """Geolocate this machine via its (cached) public IP address."""
    def fetch_public_ip() -> str:
        # api.ipify.org answers with the caller's public IP as plain text;
        # ip_address() validates (and normalizes) it before caching
        return str(ipaddress.ip_address(requests.get("https://api.ipify.org").text))
    public_ip = cached(
        "me",
        str,
        fetch_public_ip,
        from_json=lambda v: str(ipaddress.ip_address(v)),
        to_json=lambda v: v,
    )
    return locate_ip(public_ip)
def exists_host(hostname: str) -> bool:
    """Return True if *hostname* resolves to an address (result is cached)."""
    def exists() -> bool:
        try:
            return bool(resolve_host(hostname))
        except socket.gaierror as error:
            # name-not-found (or temporary resolver failure) means "no"
            if error.errno in (socket.EAI_NONAME, socket.EAI_AGAIN):
                return False
            raise
    return cached(f"exists:{hostname}", bool, exists)
def distance(location1: tuple[float, float], location2: tuple[float, float]) -> float:
deg_to_rad = math.pi/180.0
lat1, long1 = location1
lat2, long2 = location2
phi1 = (90.0 - lat1) * deg_to_rad
phi2 = (90.0 - lat2) * deg_to_rad
theta1 = long1 * deg_to_rad
theta2 = long2 * deg_to_rad
return math.acos(
math.sin(phi1) * math.sin(phi2) * math.cos(theta1 - theta2) +
math.cos(phi1) * math.cos(phi2)
)
def cast(type: typing.Type[T], value: typing.Any) -> T:
    """Runtime-checked cast: return *value* if it is an instance of *type*.

    Unlike typing.cast, this actually verifies the type and raises TypeError
    on mismatch.
    """
    if isinstance(value, type):
        return value
    raise TypeError(f"type '{type.__name__}' expected")
def ensure(value: T | None) -> T:
    """Narrow an optional: return *value*, raising ValueError when it is None."""
    if value is not None:
        return value
    raise ValueError("value must not be None")
def register_settings() -> None:
    """Add one command-line option per Settings field to the global parser.

    Help texts are recovered from the Settings class source: each field line
    is expected to be immediately followed by its docstring line (see the
    note at the end of Settings), so the docstring of field *i* is at source
    line (1 + i) * 2.
    """
    try:
        settings_lines, _ = inspect.getsourcelines(Settings)
    except Exception:
        # source unavailable (e.g. frozen script): options get no help text
        settings_lines = None
    for index, setting in enumerate(Settings._fields):
        if Settings._field_defaults[setting] is False:
            # boolean settings defaulting to False become plain flags
            parser.add_argument(
                f"--{setting.replace('_', '-')}",
                action='store_true',
                help=settings_lines[(1 + index) * 2].strip().strip('"') if settings_lines else None,
            )
        else:
            # NOTE(review): with `from __future__ import annotations` the field
            # annotations are forward references; eval() of __forward_arg__
            # turns e.g. "int" into the int constructor used to parse the
            # option value — verify this still holds on newer Python versions.
            parser.add_argument(
                f"--{setting.replace('_', '-')}",
                required=False,
                type=eval(Settings.__annotations__[setting].__forward_arg__),
                metavar=Settings.__annotations__[setting].__forward_arg__.upper(),
                default=Settings._field_defaults[setting],
                help=settings_lines[(1 + index) * 2].strip().strip('".') + ", defaults to '%(default)s'." if settings_lines else None,
            )
class Formatter(string.Formatter):
    """string.Formatter with extra case conversions: !u, !l, !c and !t."""
    def convert_field(self, value: typing.Any, conversion: str) -> typing.Any:
        # map each custom conversion character to the matching str method
        case_conversions = {
            "u": str.upper,
            "l": str.lower,
            "c": str.capitalize,
            "t": str.title,
        }
        convert = case_conversions.get(conversion)
        if convert is not None:
            return convert(str(value))
        # delegate the standard conversions (!s, !r, !a, or none)
        return super().convert_field(value, conversion)
###############################################################################
# ABSTRACT BASE CLASSES #
###############################################################################
class Store(abc.ABC):
    """Abstract base for document stores (local, web, S3 bucket, Azure container)."""
    def __init__(self, location: tuple[float, float]) -> None:
        # (latitude, longitude) of the store, used for distance-based decisions
        self.location = location
    @staticmethod
    def from_uri(uri: str) -> Store:
        """Create the concrete store matching *uri*.

        Supported schemes:
          - ``s3://bucket/prefix``
          - ``https://`` — an S3 virtual-hosted URL, an Azure blob container
            URL, or any other web document
          - ``file://`` — a local directory or a single local file

        Raises ValueError for fragments, unsupported schemes or invalid paths.
        """
        p = urllib.parse.urlsplit(uri)
        if p.fragment:
            raise ValueError("URI must not contain fragments")
        if p.scheme == "s3":
            if p.port is not None:
                raise ValueError("S3 URI must not contain a port")
            if p.query:
                raise ValueError("S3 URI must not contain a query")
            return S3BucketStore(s3_resource.Bucket(p.netloc), filter=p.path)
        elif p.scheme == "https":
            s3_suffix = ".s3.amazonaws.com"
            netloc = p.netloc.lower()
            if netloc.endswith(s3_suffix):
                # virtual-hosted S3 URL: the bucket name precedes the suffix
                if p.port is not None and p.port != 443:
                    raise ValueError("bucket URI must have port 443")
                if p.query:
                    raise ValueError("bucket URI must not contain a query")
                return S3BucketStore(s3_resource.Bucket(p.netloc[: -len(s3_suffix)]), filter=p.path)
            elif netloc.endswith(".blob.core.windows.net"):
                return AzureContainerStore(
                    azure.storage.blob.ContainerClient.from_container_url(uri, credential=azure_credential)
                )
            else:
                # any other HTTPS URI is treated as a single web document
                return WebDocumentStore(uri)
        elif p.scheme == "file":
            path = urllib.request.url2pathname(p.path)
            if os.path.isdir(path):
                return LocalDirectoryStore(path)
            elif os.path.isfile(path):
                # single file: anchor the store at its parent directory
                return LocalDirectoryStore(os.path.dirname(path), file_name=os.path.basename(path))
            else:
                raise ValueError(f"path '{path}' not found")
        else:
            raise ValueError(f"URI scheme {p.scheme} is not supported")
    @abc.abstractmethod
    def list_files(self) -> typing.Iterable[Document]:
        """Yield a Document for every file in the store."""
        raise NotImplementedError()
    @abc.abstractmethod
    def put_file(self, file_name: str, content: bytes) -> None:
        """Write *content* to the store under *file_name*."""
        raise NotImplementedError()
class CloudStore(Store):
    """A Store hosted in a specific cloud region."""
    def __init__(self, region: Region, is_public: bool) -> None:
        super().__init__(region.location)
        # cloud region hosting the store
        self.region = region
        # whether documents can be fetched without authentication
        self.is_public = is_public
# type variable for the concrete store a document belongs to
StoreT = typing.TypeVar("StoreT", bound=Store)
class Document(abc.ABC, typing.Generic[StoreT]):
    """Abstract document within a store, exposing a URL and its raw content."""
    def __init__(self, store: StoreT, name: str) -> None:
        self.store = store
        self.name = name
    # abc.abstractproperty is deprecated since Python 3.3; stacking @property
    # over @abc.abstractmethod is the supported spelling with identical
    # behavior.
    @property
    @abc.abstractmethod
    def url(self) -> str:
        """A (possibly pre-signed) URL under which the document can be fetched."""
        raise NotImplementedError()
    @property
    @abc.abstractmethod
    def content(self) -> bytes:
        """The document's raw bytes."""
        raise NotImplementedError()
class PresignedUrlDocument(Document[StoreT]):
    """Document whose URL must be pre-signed and refreshed before expiring."""
    def __init__(self, store: StoreT, name: str) -> None:
        super().__init__(store, name)
        self.presigned_url: str | None = None  # last signed URL, if any
        self.presign_time = 0.0  # time.time() when the URL was signed
    @abc.abstractmethod
    def presign_url(self) -> str:
        """Produce a fresh pre-signed URL for this document."""
        raise NotImplementedError()
    @property
    def url(self) -> str:
        # Re-sign once 80% of the configured expiration window has elapsed,
        # so handed-out URLs are always comfortably within validity.
        now = time.time()
        if self.presigned_url is None or now - self.presign_time > settings.presign_expiration * 0.8:
            self.presigned_url = self.presign_url()
            self.presign_time = now
        return self.presigned_url
class ServiceInputError(Exception):
    """Raised when a cloud service rejects the supplied input (not transient)."""
    pass
class CloudService(abc.ABC):
    """Abstract base for a per-region cloud service (OCR or translation)."""
    def __init__(self, region: Region) -> None:
        self.region = region
    # abc.abstractproperty is deprecated since Python 3.3; @property over
    # @abc.abstractmethod is the supported equivalent.
    @property
    @abc.abstractmethod
    def timeout(self) -> float:
        """Per-operation timeout in seconds for this service type."""
        raise NotImplementedError()
    @property
    @abc.abstractmethod
    def timed(self) -> TimedService:
        """The region's timing tracker for this service type."""
        raise NotImplementedError()
class OcrImageError(ServiceInputError):
    """Raised when an image cannot be processed by the OCR service."""
    pass
class OcrService(CloudService):
    """Abstract OCR service; concrete subclasses implement run()."""
    @abc.abstractmethod
    def run(self, document: Document) -> str:
        """Extract and return the text contained in *document*."""
        raise NotImplementedError()
    @property
    def timeout(self) -> float:
        # OCR operations use the dedicated OCR timeout setting
        return settings.ocr_timeout
    @property
    def timed(self) -> TimedService:
        # this region's OCR timing tracker
        return self.region.ocr
class TranslationTextError(ServiceInputError):
    """Raised when text cannot be processed by the translation service."""
    pass
class TranslationService(CloudService):
    """Abstract translation service; concrete subclasses implement run()."""
    @abc.abstractmethod
    def run(self, text: str, language: str) -> str:
        """Translate *text* into *language* and return the result."""
        # raise the exception instance, consistent with the other abstract
        # methods in this file
        raise NotImplementedError()
    @property
    def timeout(self) -> float:
        # translation operations use the dedicated translation timeout setting
        return settings.translation_timeout
    @property
    def timed(self) -> TimedService:
        # this region's translation timing tracker
        return self.region.translation
###############################################################################
# THROTTLING #
###############################################################################
class QuotaError(Exception):
    """Raised when an operation would exceed a configured quota."""
    def __init__(self, maximum: int, exceeded_by: int) -> None:
        message = f"maximum of {maximum} exceeded by {exceeded_by}"
        super().__init__(message)
        # the configured limit and the amount by which it would be exceeded
        self.maximum = maximum
        self.exceeded_by = exceeded_by
class ConcurrencyQuota:
    """Thread-safe context manager capping the number of concurrent holders."""
    def __init__(self, maximum: int) -> None:
        if maximum < 0:
            raise ValueError("concurrency quota must be non-negative")
        self.lock = threading.Lock()
        self.count = 0
        self.maximum = maximum
    def __enter__(self) -> None:
        # acquire one slot, failing fast when the quota is saturated
        with self.lock:
            if self.count >= self.maximum:
                raise QuotaError(self.maximum, exceeded_by=1)
            self.count += 1
    def __exit__(
        self,
        exc_type: typing.Type[BaseException] | None,
        exc_value: BaseException | None,
        exc_traceback: types.TracebackType | None,
    ) -> None:
        # release the slot regardless of whether the body raised
        with self.lock:
            self.count -= 1
class SlidingWindowQuota:
    """Thread-safe quota limiting total usage within a sliding time window."""
    def __init__(self, maximum: int, interval: datetime.timedelta) -> None:
        if maximum < 0:
            raise ValueError("sliding window quota must be non-negative")
        if interval <= datetime.timedelta():
            raise ValueError("sliding window interval must be larger than zero")
        self.lock = threading.Lock()
        # (timestamp, amount) records of usage inside the current window
        self.history: collections.deque[tuple[datetime.datetime, int]] = collections.deque()
        self.count = 0
        self.maximum = maximum
        self.interval = interval
    def use(self, count: int = 1) -> None:
        """Consume *count* units; raises QuotaError if the limit would be exceeded."""
        with self.lock:
            now = datetime.datetime.utcnow()
            # drop usage records that have slid out of the window
            while self.history:
                timestamp, used = self.history[0]
                if now - timestamp <= self.interval:
                    break
                self.history.popleft()
                self.count -= used
            overflow = self.count + count - self.maximum
            if overflow > 0:
                raise QuotaError(self.maximum, overflow)
            self.history.append((now, count))
            self.count += count
###############################################################################
# SIMPLE FILE AND WEB DOCUMENTS #
###############################################################################
class LocalDirectoryStore(Store):
    """Store backed by a local directory, optionally pinned to one file."""
    def __init__(
        self,
        path: str,
        *,
        file_name: str | None = None,
        location: tuple[float, float] | None = None,
    ) -> None:
        # without an explicit location, assume the files live where we run
        super().__init__(location if location else locate_me())
        self.path = path
        self.file_name = file_name
    def list_files(self) -> typing.Iterable[Document]:
        """Yield the single configured file, or every regular file in the directory."""
        if self.file_name is not None:
            yield LocalFileDocument(self, os.path.join(self.path, self.file_name))
            return
        for entry in os.listdir(self.path):
            candidate = os.path.join(self.path, entry)
            if os.path.isfile(candidate):
                yield LocalFileDocument(self, candidate)
    def put_file(self, file_name: str, content: bytes) -> None:
        """Write *content* below the store directory, creating parent folders."""
        target = os.path.join(self.path, file_name)
        os.makedirs(os.path.dirname(target), exist_ok=True)
        with open(target, "wb") as handle:
            handle.write(content)
    def __repr__(self) -> str:
        return f"<LocalDirectoryStore path='{self.path}' file_name='{self.file_name}' location={self.location}>"
class LocalFileDocument(Document[LocalDirectoryStore]):
    """A single file on disk; its content is read lazily and cached."""
    def __init__(self, store: LocalDirectoryStore, path: str) -> None:
        super().__init__(store, os.path.basename(path))
        self.path = path
    @property
    def url(self) -> str:
        # local files cannot be addressed by URL
        raise NotImplementedError()
    @functools.cached_property
    def content(self) -> bytes:
        with open(self.path, "rb") as handle:
            return handle.read()
class WebDocumentStore(Store):
    """Store wrapping a single web resource addressed by one URI."""
    def __init__(self, uri: str, *, location: tuple[float, float] | None = None) -> None:
        # geolocate the URI's host unless a location is supplied
        super().__init__(location if location else locate_uri(uri))
        self.uri = uri
    def list_files(self) -> typing.Iterable[Document]:
        """Yield the one document this store represents."""
        yield WebDocument(self, self.uri)
    def put_file(self, file_name: str, content: bytes) -> None:
        """HTTP-PUT *content* to the URI; *file_name* must be '' or the URI itself."""
        if file_name != "" and file_name != self.uri:
            raise ValueError(f"file name '{file_name}' not supported")
        requests.put(self.uri, content).raise_for_status()
    def __repr__(self) -> str:
        return f"<WebDocumentStore uri='{self.uri}' location={self.location}>"
class WebDocument(Document[WebDocumentStore]):
    """Document fetched over HTTP; the body is downloaded once and cached."""
    def __init__(self, store: WebDocumentStore, uri: str) -> None:
        super().__init__(store, uri)
        self.uri = uri
    @property
    def url(self) -> str:
        # web documents are addressed directly by their URI
        return self.uri
    @functools.cached_property
    def content(self) -> bytes:
        response = requests.get(self.uri)
        return response.content
###############################################################################
# AMAZON WEB SERVICES #
###############################################################################
def register_aws_regions() -> None:
    """Register every AWS region, probing OCR/translation endpoint existence."""
    for region_desc in aws_session.client("ec2", "us-east-1").describe_regions()["Regions"]:
        try:
            region = Region(
                provider=Provider.AWS,
                name=region_desc["RegionName"],
                location=locate_host(region_desc["Endpoint"]),
                is_available=region_desc["OptInStatus"] in ("opted-in", "opt-in-not-required"),
                supports_ocr=exists_host(f"rekognition.{region_desc['RegionName']}.amazonaws.com"),
                supports_translation=exists_host(f"translate.{region_desc['RegionName']}.amazonaws.com"),
            )
        except KeyError as e:
            # chain the original KeyError so the missing field stays visible
            raise ValueError("incomplete AWS region description") from e
        Region.register(region)
class S3BucketStore(CloudStore):
    """Store backed by an S3 bucket, optionally restricted to a key prefix."""
    @staticmethod
    def is_public_bucket(bucket_name: str) -> bool:
        """Return True if the bucket is publicly readable via ACL or policy."""
        def has_public_acl() -> bool:
            # check if AllUsers have at least READ access
            grants = s3_client.get_bucket_acl(Bucket=bucket_name)["Grants"]
            return any(
                True
                for grant in grants
                if grant.get("Grantee", {}).get("URI", None)
                == "http://acs.amazonaws.com/groups/global/AllUsers"
                and grant.get("Permission", None) in ("FULL_CONTROL", "READ")
            )
        def has_public_policy() -> bool:
            # check if a public policy exists
            try:
                return s3_client.get_bucket_policy_status(Bucket=bucket_name)["PolicyStatus"]["IsPublic"]
            except botocore.exceptions.ClientError as e:
                if e.response.get("Error", {}).get("Code", None) != "NoSuchBucketPolicy":
                    raise
                return False
        try:
            config = s3_client.get_public_access_block(Bucket=bucket_name)["PublicAccessBlockConfiguration"]
        except botocore.exceptions.ClientError as e:
            if e.response.get("Error", {}).get("Code", None) != "NoSuchPublicAccessBlockConfiguration":
                raise
            config = {}
        # Pair each probe with the Block Public Access flags that actually
        # govern it: BlockPublicAcls/IgnorePublicAcls control ACL-based
        # access, while BlockPublicPolicy/RestrictPublicBuckets control
        # policy-based access. (The previous pairing was swapped.)
        return (
            (
                not config.get("BlockPublicAcls", False)
                and not config.get("IgnorePublicAcls", False)
                and has_public_acl()
            )
            or (
                not config.get("BlockPublicPolicy", False)
                and not config.get("RestrictPublicBuckets", False)
                and has_public_policy()
            )
        )
    def __init__(self, bucket: mypy_boto3_s3.service_resource.Bucket, filter: str) -> None:
        super().__init__(
            # a None return means "us-east-1" (https://github.com/aws/aws-cli/issues/3864)
            region=Region.lookup(
                Provider.AWS,
                s3_client.get_bucket_location(Bucket=bucket.name)["LocationConstraint"] or "us-east-1",
            ),
            is_public=S3BucketStore.is_public_bucket(bucket.name),
        )
        self.bucket = bucket
        # normalize the prefix: strip slashes, then keep one trailing slash
        # when non-empty so keys can be appended directly
        self.filter = filter.strip("/")
        if self.filter:
            self.filter += "/"
    def list_files(self) -> typing.Iterable[Document]:
        """Yield a document for every object under the prefix (skipping folder markers)."""
        for obj in self.bucket.objects.filter(Prefix=self.filter):
            if obj.key[-1] != "/":
                yield S3ObjectDocument(self, obj)
    def put_file(self, file_name: str, content: bytes) -> None:
        """Upload *content* under the store's prefix."""
        self.bucket.put_object(
            Key=f"{self.filter}{file_name}",
            Body=content,
        )
    def __repr__(self) -> str:
        return f"<S3BucketStore bucket='{self.bucket.name}' filter='{self.filter}' region='{self.region.name}'>"
class S3ObjectDocument(PresignedUrlDocument[S3BucketStore]):
    """An S3 object addressed by a (possibly pre-signed) URL."""
    def __init__(self, store: S3BucketStore, obj: mypy_boto3_s3.service_resource.ObjectSummary) -> None:
        # the document name is the object key with the store's prefix removed
        super().__init__(store, obj.key[len(store.filter):])
        self.obj = obj
    def presign_url(self) -> str:
        """Return a URL: unsigned for public buckets, pre-signed otherwise."""
        params = {
            "Bucket": self.obj.bucket_name,
            "Key": self.obj.key,
        }
        if self.store.is_public:
            # public bucket: the unsigned client yields a plain, non-expiring URL
            return s3_client_unsigned.generate_presigned_url(
                ClientMethod="get_object",
                Params=params,
                ExpiresIn=0,
            )
        else:
            return s3_client.generate_presigned_url(
                ClientMethod="get_object",
                Params=params,
                ExpiresIn=settings.presign_expiration,
            )
    @functools.cached_property
    def content(self) -> bytes:
        # download the object once into memory and cache the bytes
        buffer = io.BytesIO()
        s3_client.download_fileobj(
            Bucket=self.obj.bucket_name,
            Key=self.obj.key,
            Fileobj=buffer,
        )
        return buffer.getvalue()
class AWSService(CloudService):
    """Base class for AWS services; resolves the region from the boto3 client."""
    def __init__(self, service: botocore.client.BaseClient) -> None:
        super().__init__(Region.lookup(Provider.AWS, service.meta.region_name))
        # keep the client metadata (endpoint, region) for later inspection
        self.meta = service.meta
class AWSRekognitionService(AWSService, OcrService):
    """OCR backed by AWS Rekognition's DetectText API."""
    def __init__(self, region_name: str) -> None:
        # single attempt with explicit timeouts; retries are handled elsewhere
        self.service = aws_session.client(
            service_name="rekognition",
            region_name=region_name,
            config=botocore.config.Config(
                retries={"total_max_attempts": 1},
                connect_timeout=settings.connect_timeout,
                read_timeout=settings.ocr_timeout,
            )
        )
        super().__init__(self.service)
        # 50 requests/second for the three "fast" regions, 5 elsewhere —
        # presumably mirroring Rekognition's regional TPS quotas (verify
        # against current AWS service quotas)
        is_fast = region_name in ("us-east-1", "us-west-2", "eu-west-1")
        self.quota = SlidingWindowQuota(50 if is_fast else 5, datetime.timedelta(seconds=1))
    def run(self, document: Document) -> str:
        """OCR *document* and return all detected LINE texts joined by newlines.

        Raises OcrImageError for oversized/invalid images and ServiceInputError
        when the image bytes cannot be downloaded.
        """
        # we either use a bucket reference for S3 objects or upload bytes
        image: mypy_boto3_rekognition.type_defs.ImageTypeDef
        if (
            isinstance(document, S3ObjectDocument)
            and document.store.region is self.region
        ):
            # same-region S3 object: let Rekognition fetch it directly
            image = {
                "S3Object": {
                    "Bucket": document.obj.bucket_name,
                    "Name": document.obj.key,
                }
            }
        else:
            try:
                image = {"Bytes": document.content}
            except Exception as e:
                raise ServiceInputError(f"cannot download image ({e})")
        # consume one unit of the per-second quota (raises QuotaError if spent)
        self.quota.use()
        try:
            with self.timed:
                result = self.service.detect_text(Image=image)
        except self.service.exceptions.ImageTooLargeException:
            raise OcrImageError("image is too large")
        except self.service.exceptions.InvalidImageFormatException:
            raise OcrImageError("image format is invalid")
        return "\n".join(
            line.get("DetectedText", "")
            for line in result["TextDetections"]
            if line.get("Type", None) == "LINE"
        )
class AWSTranslateService(AWSService, TranslationService):
    """Translation backed by AWS Translate with automatic source detection."""
    def __init__(self, region_name: str) -> None:
        # single attempt with explicit timeouts; retries are handled elsewhere
        self.service = aws_session.client(
            service_name="translate",
            region_name=region_name,
            config=botocore.config.Config(
                retries={"total_max_attempts": 1},
                connect_timeout=settings.connect_timeout,
                read_timeout=settings.translation_timeout,
            )
        )
        super().__init__(self.service)
    def run(self, text: str, language: str) -> str:
        """Translate *text* into *language*; raises TranslationTextError on bad input."""
        try:
            with self.timed:
                result = self.service.translate_text(
                    Text=text,
                    SourceLanguageCode="auto",
                    TargetLanguageCode=language,
                )
        except self.service.exceptions.DetectedLanguageLowConfidenceException:
            raise TranslationTextError("failed to detect language")
        except self.service.exceptions.TextSizeLimitExceededException:
            raise TranslationTextError("text is too large")
        except self.service.exceptions.UnsupportedLanguagePairException:
            raise TranslationTextError(f"cannot translate to {language} from detected language")
        return result["TranslatedText"]
###############################################################################
# AZURE #
###############################################################################
def get_azure_subscription() -> azure.mgmt.subscription.models.Subscription:
    """Return the Azure subscription to use.

    Order of precedence: the AZURE_SUBSCRIPTION_ID environment variable, then
    the only available subscription, then an interactive prompt.
    """
    subscription_id = os.environ.get("AZURE_SUBSCRIPTION_ID", None)
    if subscription_id is not None:
        return azure_subscription_client.subscriptions.get(subscription_id)
    if not azure_subscriptions:
        raise Exception("no Azure subscription found")
    if len(azure_subscriptions) == 1:
        return azure_subscriptions[0]
    else:
        # several subscriptions available: show a 1-based menu
        print()
        for i, subscription in enumerate(azure_subscriptions):
            print(f"{i+1}. {subscription.subscription_id}: {subscription.display_name} [{subscription.state}]")
        # re-prompt silently until a valid menu number is entered
        while True:
            try:
                i = int(input("> "))
            except ValueError:
                pass
            else:
                if 1 <= i and i <= len(azure_subscriptions):
                    return azure_subscriptions[i - 1]
def register_azure_regions() -> None:
    """Register every Azure location of the active subscription as a Region."""
    for location in azure_subscription_client.subscriptions.list_locations(azure_subscription_id):
        Region.register(Region(
            provider=Provider.Azure,
            name=cast(str, location.name),
            location=(float(cast(str, location.latitude)), float(cast(str, location.longitude))),
            is_available=True,
            # both cognitive services share the same regional endpoint host,
            # so a single DNS existence check covers OCR and translation
            supports_ocr=exists_host(f"{location.name}.api.cognitive.microsoft.com"),
            supports_translation=exists_host(f"{location.name}.api.cognitive.microsoft.com"),
        ))
class AzureContainerStore(CloudStore):
    """Store backed by an Azure blob container."""
    @staticmethod
    def get_storage_account(account_name: str) -> azure.mgmt.storage.models.StorageAccount:
        """Find the storage account named *account_name* across all subscriptions."""
        # apparently there is no way to get the storage account just by name, so we need to enumerate
        for subscription in azure_subscriptions:
            storage_client = azure.mgmt.storage.StorageManagementClient(
                credential=azure_credential,
                subscription_id=cast(str, subscription.subscription_id),
            )
            for entry in storage_client.storage_accounts.list():
                storage_account = typing.cast(azure.mgmt.storage.models.StorageAccount, entry)
                if cast(str, storage_account.name).lower() == account_name.lower():
                    return storage_account
        raise Exception(f"storage account '{account_name}' not found")
    @staticmethod
    def get_access_key(storage_account_id: str) -> str:
        """Return the first shared access key of the account given by ARM id."""
        # we need to get the resource group name from the id
        if not (
            m := re.match(
                r"^/subscriptions/(?P<s>[^/]+)/resourceGroups/(?P<rg>[^/]+)/providers/Microsoft.Storage/storageAccounts/(?P<a>[^/]+)$",
                storage_account_id,
            )
        ):
            raise Exception(f"invalid storage account id '{storage_account_id}'")
        storage_client = azure.mgmt.storage.StorageManagementClient(
            credential=azure_credential,
            subscription_id=m.group("s"),
        )
        key_list = typing.cast(
            azure.mgmt.storage.models.StorageAccountListKeysResult,
            storage_client.storage_accounts.list_keys(
                resource_group_name=m.group("rg"),
                account_name=m.group("a"),
            ),
        )
        keys = cast(list, key_list.keys)
        key = typing.cast(azure.mgmt.storage.models.StorageAccountKey, keys[0])
        return cast(str, key.value)
    def __init__(self, container: azure.storage.blob.ContainerClient) -> None:
        self.container = container
        self.storage_account = AzureContainerStore.get_storage_account(cast(str, self.container.account_name))
        self.access_key: str | None = None
        try:
            policy = self.container.get_container_access_policy()
        except azure.core.exceptions.ResourceNotFoundError:
            # NOTE(review): reading the access policy failed with the AAD
            # credential; retry with the account's shared key instead —
            # confirm ResourceNotFoundError is the error actually raised here
            self.access_key = AzureContainerStore.get_access_key(cast(str, self.storage_account.id))
            self.container = typing.cast(
                azure.storage.blob.ContainerClient,
                azure.storage.blob.ContainerClient.from_container_url(self.container.url, credential=self.access_key),
            )
            policy = self.container.get_container_access_policy()
        endpoints = typing.cast(azure.mgmt.storage.models.Endpoints, self.storage_account.primary_endpoints)
        self.blob_service_client = azure.storage.blob.BlobServiceClient(cast(str, endpoints.blob), credential=azure_credential)
        super().__init__(
            region=Region.lookup(Provider.Azure, self.storage_account.location),
            is_public=policy["public_access"] is not None,
        )
    def list_files(self) -> typing.Iterable[Document]:
        """Yield a document for every blob in the container."""
        for blob in self.container.list_blobs():
            yield AzureBlobDocument(self, blob)
    def put_file(self, file_name: str, content: bytes) -> None:
        """Upload *content* as *file_name*, overwriting any existing blob."""
        self.container.get_blob_client(file_name).upload_blob(content, overwrite=True)
    def __repr__(self) -> str:
        return f"<AzureContainerStore storage_account='{self.container.account_name}' container='{self.container.container_name}' region='{self.region.name}'>"
class AzureBlobDocument(PresignedUrlDocument[AzureContainerStore]):
    """An Azure blob addressed by a (possibly SAS-signed) URL."""
    def __init__(self, store: AzureContainerStore, blob: azure.storage.blob.BlobProperties) -> None:
        super().__init__(store, cast(str, blob.name))
        self.blob = blob
    def presign_url(self) -> str:
        """Return the blob URL, appending a SAS token when the store is private."""
        url = self.store.container.get_blob_client(blob=self.blob).url
        if self.store.is_public or urllib.parse.urlsplit(url).query:
            # the store is public or we already have a SAS, simply return the url
            return url
        else:
            # we need a SAS
            start = datetime.datetime.utcnow()
            expiry = start + datetime.timedelta(seconds=settings.presign_expiration)
            # with shared-key access no user delegation key is needed
            user_delegation_key = (
                self.store.blob_service_client.get_user_delegation_key(
                    key_start_time=start,
                    key_expiry_time=expiry,
                )
                if self.store.access_key is None else
                None
            )
            # generate and append the SAS
            sas = azure.storage.blob.generate_blob_sas(
                account_name=cast(str, self.store.container.account_name),
                container_name=self.store.container.container_name,
                blob_name=cast(str, self.blob.name),
                permission=azure.storage.blob.BlobSasPermissions(read=True),
                account_key=self.store.access_key,
                user_delegation_key=user_delegation_key,
                start=start,
                expiry=expiry,
            )
            return f"{url}?{sas}"
    @functools.cached_property
    def content(self) -> bytes:
        # download once and cache in memory
        return self.store.container.download_blob(blob=self.blob).readall()
class AzureServiceError(Exception):
    """Error reported by an Azure cognitive service response body."""
    def __init__(self, code: str | int, message: str) -> None:
        super().__init__(message)
        # service-specific error code and human-readable message
        self.code = code
        self.message = message
    @staticmethod
    def from_json(value: typing.Any) -> AzureServiceError:
        """Build an AzureServiceError from a parsed JSON error payload.

        Prefers the nested 'innererror' code/message over the outer 'error'
        ones; missing fields fall back to 0 / a generic message. Raises
        ValueError if the payload does not have the expected shape.
        """
        if not isinstance(value, dict):
            raise ValueError("value must be an object")
        error = value.get("error", dict())
        if not isinstance(error, dict):
            # this checks the outer error object (the inner one is below)
            raise ValueError("error must be an object")
        inner_error = error.get("innererror", dict())
        if not isinstance(inner_error, dict):
            raise ValueError("inner error must be an object")
        code = inner_error.get("code", error.get("code", 0))
        if not isinstance(code, (str, int)):
            raise ValueError("error code must be a string or number")
        message = inner_error.get("message", error.get("message", "An unknown error occurred."))
        if not isinstance(message, str):
            raise ValueError("error message must be a string")
        return AzureServiceError(code, message)
class AzureService(CloudService):
    """Base class for Azure cognitive services backed by a cognitive account."""
    def __init__(self, *, kind: str, location: str) -> None:
        super().__init__(Region.lookup(Provider.Azure, location))
        # create (or reuse) the cognitive account and fetch its access key
        self.account, self.key = AzureService.get_or_create_account_and_key(kind, location)
        self.endpoint = cast(str, ensure(self.account.properties).endpoint)
        self.location = ensure(self.account.location)
    @staticmethod
    @functools.lru_cache
    def get_resource_group() -> azure.mgmt.resource.resources.models.ResourceGroup:
        """Get or create the configured resource group (cached per process).

        A newly created group is placed in the Azure region closest to this
        machine's geolocation.
        """
        # pick the Azure region nearest to where this script runs
        region = min(
            Region.index[Provider.Azure].values(),
            key=lambda region: distance(region.location, me),
        )
        with Resource(settings.azure_resource_group, Provider.Azure) as res:
            try:
                result = azure_resource_client.resource_groups.get(settings.azure_resource_group)
                res.existed = True
            except azure.core.exceptions.ResourceNotFoundError:
                # group does not exist yet: create it in the chosen region
                result = azure_resource_client.resource_groups.create_or_update(
                    resource_group_name=settings.azure_resource_group,
                    parameters=azure.mgmt.resource.resources.models.ResourceGroup(
                        location=region.name
                    ), # type: ignore
                )
                res.created = True
            res.region_name = result.location
        return typing.cast(azure.mgmt.resource.resources.models.ResourceGroup, result)
@staticmethod
def get_or_create_account_and_key(kind: str, location: str) -> tuple[azure.mgmt.cognitiveservices.models.Account, str]:
with Resource(kind, Provider.Azure, location) as res:
sku = "F0" if settings.azure_use_free_tier else "S1"
resource_group = AzureService.get_resource_group()
resource_group_name = cast(str, resource_group.name)
account_name = Formatter().format(
settings.azure_cognitive_account_name,
resource_group=resource_group_name,
location=location,
kind=kind,
sku=sku,
)
try:
# get any existing account with the given name
account = azure_cognitiveservices_management_client.accounts.get(
resource_group_name=resource_group_name,
account_name=account_name,
)
res.existed = True
# check if any setting is wrong
if (
ensure(account.kind) != kind
or ensure(account.location) != location
or ensure(account.sku).name != sku
):
# delete and purge the incorrect account (we can't update these settings)
azure_cognitiveservices_management_client.accounts.begin_delete(
resource_group_name=resource_group_name,
account_name=account_name,
).result()
azure_cognitiveservices_management_client.deleted_accounts.begin_purge(
location=ensure(account.location),
resource_group_name=resource_group_name,
account_name=account_name,
)
raise azure.core.exceptions.ResourceNotFoundError()
except azure.core.exceptions.ResourceNotFoundError:
while True:
try:
# create a new account
account = azure_cognitiveservices_management_client.accounts.begin_create(
resource_group_name=resource_group_name,
account_name=account_name,
account=azure.mgmt.cognitiveservices.models.Account(
kind=kind,
sku=azure.mgmt.cognitiveservices.models.Sku(name=sku),
location=location,
),
).result()
res.created = True
break
except azure.core.exceptions.ResourceExistsError:
# purge the old account (but only if it's in our resource group)
for deleted_account in azure_cognitiveservices_management_client.deleted_accounts.list():
if cast(str, deleted_account.name).lower() == account_name.lower():
azure_cognitiveservices_management_client.deleted_accounts.begin_purge(
location=ensure(deleted_account.location),
resource_group_name=resource_group_name,
account_name=account_name,
).result()
break
else:
raise
key_list = azure_cognitiveservices_management_client.accounts.list_keys(
resource_group_name=resource_group_name,
account_name=account_name,
)
return (account, ensure(key_list.key1))
def invoke(
self,
*,
path: str,
parse: typing.Callable[[typing.Any], T],
data: bytes | None = None,
json: typing.Any = None,
params: dict[str, str] | None = None,
) -> T:
assert not (data is not None and json is not None), "binary data and JSON are mutually exclusive"
with self.timed:
headers = {
"Ocp-Apim-Subscription-Key": self.key,
"Ocp-Apim-Subscription-Region": self.location,
}
if data is not None:
headers["Content-Type"] = "application/octet-stream"
response = requests.post(
url=f"{self.endpoint}{path}",
data=data,
json=json,
headers=headers,
params=params,
timeout=(settings.connect_timeout, self.timeout),
)
result = response.json()
if not response.ok:
raise AzureServiceError.from_json(result)
try:
return parse(result)
except Exception:
raise ValueError("invalid response data")
class AzureComputerVisionService(AzureService, OcrService):
    """OCR backed by the Azure Computer Vision v3.2 OCR REST endpoint."""

    def __init__(self, location: str) -> None:
        super().__init__(kind="ComputerVision", location=location)
        # free tier: 20 calls per minute; paid tier: 10 calls per second
        self.quota = (
            SlidingWindowQuota(20, datetime.timedelta(minutes=1))
            if settings.azure_use_free_tier else
            SlidingWindowQuota(10, datetime.timedelta(seconds=1))
        )

    def run(self, document: Document) -> str:
        """Recognize text in the document, returned line by line.

        Prefers passing the image by URL; falls back to uploading the raw
        bytes if the store cannot generate a URI.

        Raises:
            ServiceInputError: if the document cannot be read or the image
                is rejected by the service.
        """
        try:
            data = None
            payload = {"url": document.url}  # renamed from 'json' (shadowed the module)
        except NotImplementedError:
            # the store cannot produce a URL, upload the raw bytes instead
            try:
                data = document.content
                payload = None
            except NotImplementedError:
                raise ServiceInputError("document supports neither generating URI nor downloading")
            except Exception as e:
                raise ServiceInputError(f"cannot download image ({e})") from e
        except Exception as e:
            raise ServiceInputError(f"cannot generate image URI ({e})") from e
        self.quota.use()
        try:
            return self.invoke(
                path="/vision/v3.2/ocr",
                data=data,
                json=payload,
                # concatenate all recognized words: spaces within a line,
                # newlines between lines (across all regions)
                parse=lambda result: "\n".join(
                    " ".join(word["text"] for word in line["words"])
                    for region in result["regions"]
                    for line in region["lines"]
                ),
            )
        except AzureServiceError as e:
            # map well-known service errors to input errors so callers do
            # not retry requests that can never succeed
            if e.code == "InvalidImageFormat":
                raise ServiceInputError("image format is invalid") from e
            elif e.code == "InvalidImageSize":
                raise ServiceInputError("image size is invalid") from e
            elif e.code == "NotSupportedImage":
                raise ServiceInputError("image type is not supported") from e
            elif e.code == "NotSupportedLanguage":
                # removed pointless f-prefix on a placeholder-less string
                raise ServiceInputError("language is not supported") from e
            else:
                raise
class AzureTextTranslationService(AzureService, TranslationService):
    """Translation backed by the Azure Translator v3.0 REST API."""

    def __init__(self, location: str) -> None:
        super().__init__(kind="TextTranslation", location=location)
        # free tier: 2,000,000 characters per hour; paid tier: no client-side quota
        self.quota = (
            SlidingWindowQuota(2000000, datetime.timedelta(hours=1))
            if settings.azure_use_free_tier else
            None
        )

    def run(self, text: str, language: str) -> str:
        """Translate text into the given target language.

        Raises:
            ServiceInputError: if the language is unsupported or the text
                exceeds the service's length limit.
        """
        if self.quota is not None:
            # the free tier quota is metered per character
            self.quota.use(len(text))
        try:
            return self.invoke(
                path="/translate",
                params={
                    "api-version": "3.0",
                    "to": language,
                },
                json=[{"text": text}],
                parse=lambda result: cast(str, result[0]["translations"][0]["text"]),
            )
        except AzureServiceError as e:
            # map Translator error codes to input errors:
            # 400019 -> unsupported language, 400050 -> text too long
            if e.code == 400019:
                raise ServiceInputError(f"language '{language}' is not supported")
            elif e.code == 400050:
                raise ServiceInputError("text is too long")
            else:
                raise
###############################################################################
# REGISTRATION CLASSES #
###############################################################################
class Provider(enum.Enum):
    """Cloud providers with factories for their OCR/translation services."""

    # instance attributes populated in __new__ (declared for type checkers)
    create_ocr_service: typing.Callable[[str], OcrService]
    create_translation_service: typing.Callable[[str], TranslationService]

    AWS = (AWSRekognitionService, AWSTranslateService)
    Azure = (AzureComputerVisionService, AzureTextTranslationService)

    def __new__(
        cls,
        create_ocr_service: typing.Callable[[str], OcrService],
        create_translation_service: typing.Callable[[str], TranslationService],
    ) -> Provider:
        # custom __new__ so each member's tuple becomes two factory
        # attributes instead of remaining the member's value
        obj = object.__new__(cls)
        # NOTE(review): enum.auto() called directly yields an auto() marker
        # object, not an int — members appear to be used by identity only,
        # but confirm this value is never compared or serialized
        obj._value_ = enum.auto()
        obj.create_ocr_service = create_ocr_service
        obj.create_translation_service = create_translation_service
        return obj
# any concrete cloud service type (used to make TimedService generic)
CloudServiceT = typing.TypeVar("CloudServiceT", bound=CloudService)
class TimedService(typing.Generic[CloudServiceT]):
    """Lazily-created cloud service that tracks its average call runtime.

    Instances act as context managers around a single service call; the
    measured runtimes feed a running average used to rank services from
    fastest to slowest.
    """

    def __init__(self, is_supported: bool, region: Region, constructor: typing.Callable[[str], CloudServiceT]) -> None:
        self.is_supported = is_supported
        self.region = region
        self.constructor = constructor
        # start at "infinity" so never-instantiated services rank last
        self.average_time = sys.maxsize
        self.use_count = 0
        self.lock = threading.Lock()
        # per-thread storage for the call start time (calls may overlap)
        self.local = threading.local()

    @functools.cached_property
    def instance(self) -> CloudServiceT:
        """Create the underlying service on first access.

        Resets the average to zero so the now-existing service is
        preferred over never-instantiated ones.
        """
        result = self.constructor(self.region.name)
        self.average_time = 0
        return result

    def __enter__(self) -> None:
        setattr(self.local, "start", time.time())

    def __exit__(
        self,
        type: typing.Type[BaseException] | None,
        value: BaseException | None,
        traceback: types.TracebackType | None,
    ) -> None:
        if value is None:
            runtime = time.time() - getattr(self.local, "start")
        else:
            # penalize failed calls with the worst-case runtime
            runtime = settings.connect_timeout + self.instance.timeout
        with self.lock:
            # incremental running average, guarded against concurrent updates
            self.average_time = (runtime + self.use_count * self.average_time) / (self.use_count + 1)
            self.use_count += 1
class Region:
    """A provider region together with the services it can offer there."""

    # provider -> region name -> Region, filled via register()
    index: typing.ClassVar[dict[Provider, dict[str, Region]]] = {}

    def __init__(
        self,
        provider: Provider,
        name: str,
        location: tuple[float, float],
        is_available: bool,
        supports_ocr: bool,
        supports_translation: bool,
    ) -> None:
        self.provider = provider
        self.name = name
        self.location = location
        self.is_available = is_available
        self.ocr = TimedService(supports_ocr, self, provider.create_ocr_service)
        self.translation = TimedService(supports_translation, self, provider.create_translation_service)
        # help text for the argument parser, e.g. "westeurope (ocr, translation)"
        offered = (
            [
                service
                for service, supported in (("ocr", supports_ocr), ("translation", supports_translation))
                if supported
            ]
            if is_available
            else []
        )
        self.help = f"{name} ({', '.join(offered)})"

    @staticmethod
    def register(region: Region) -> None:
        """Add the region to the global index, refusing duplicates."""
        regions = Region.index.setdefault(region.provider, {})
        if region.name in regions:
            raise ValueError(f"region '{region.name}' already exists in provider '{region.provider}'")
        regions[region.name] = region

    @staticmethod
    def parse(spec: str) -> Region:
        """Resolve a command-line "PROVIDER:REGION" spec to a Region."""
        pieces = spec.split(":", maxsplit=2)
        if len(pieces) != 2:
            raise ValueError("provider and region must be specified")
        return Region.lookup(Provider[pieces[0]], pieces[1])

    @staticmethod
    def lookup(provider: Provider, name: str) -> Region:
        """Return the registered region or raise a ValueError."""
        regions = Region.index.get(provider)
        if regions is None or name not in regions:
            raise ValueError(f"region '{name}' of provider '{provider}' not registered")
        return regions[name]

    def __repr__(self) -> str:
        return f"<Region provider='{self.provider.name}' name='{self.name}' location={self.location}>"
class Resource:
    """Deployment record for one cloud resource, with failure back-off.

    Used as a context manager around resource provisioning: records
    timing and whether the resource already existed or was created, and
    short-circuits repeated attempts that recently failed.
    """

    # every successfully deployed resource, for the statistics CSV
    all: typing.ClassVar[list[Resource]] = []
    # (name, provider, region) -> (failure time, exception) of the last failure
    last_exc: typing.ClassVar[dict[tuple[str, str, str], tuple[float, BaseException]]] = {}

    def __init__(self, name: str, provider: Provider, region_name: str = "") -> None:
        self.name = name
        self.provider_name = provider.name
        self.region_name = region_name
        self.existed = False
        self.created = False
        self.deployment_start = time.time() - epoche
        self.deployment_end: float | None = None

    def to_record(self) -> dict:
        """Return a snapshot of this resource's attributes for statistics."""
        # fixed: return a copy instead of the live __dict__, so consumers
        # cannot mutate the resource (and later attribute changes do not
        # leak into already-taken records)
        return dict(self.__dict__)

    def __enter__(self) -> Resource:
        assert self.deployment_end is None, "resource already ended"
        # if the same deployment failed recently, re-raise the old error
        # instead of hammering the cloud provider again
        if last_exc := Resource.last_exc.get((self.name, self.provider_name, self.region_name), None):
            event, exc = last_exc
            if self.deployment_start < event + settings.refresh_interval:
                raise exc
        return self

    def __exit__(
        self,
        type: typing.Type[BaseException] | None,
        value: BaseException | None,
        traceback: types.TracebackType | None,
    ) -> None:
        self.deployment_end = time.time() - epoche
        if value is None:
            Resource.all.append(self)
        else:
            # remember the failure so __enter__ can back off next time
            Resource.last_exc[(self.name, self.provider_name, self.region_name)] = (self.deployment_end, value)
###############################################################################
# INITIALIZATION #
###############################################################################
# region cache
# best-effort load of the JSON cache next to the script; any failure
# (missing file, invalid JSON, wrong top-level type) yields an empty cache
print("Loading cache: ", end="", flush=True)
try:
    with open(full_path(".cache"), "r") as _file:
        cache = json.load(_file)
        if not isinstance(cache, dict):
            raise TypeError()
except Exception:
    cache = {}
print(f"{len(cache)} entries")
# endregion
# region AWS
print("Signing in to AWS: ", end="", flush=True)
try:
    # try the ambient credentials (env vars, ~/.aws, instance profile)
    aws_session = boto3.session.Session()
    aws_sts = aws_session.client("sts")
    aws_identity = aws_sts.get_caller_identity()
except botocore.exceptions.NoCredentialsError:
    # no credentials configured: prompt for them interactively
    print()
    aws_session = boto3.session.Session(
        **{
            aws: input(f"{aws}: ")
            for aws in ("aws_access_key_id", "aws_secret_access_key", "aws_session_token")
        }
    )
    aws_sts = aws_session.client("sts")
    aws_identity = aws_sts.get_caller_identity()
print(aws_identity["Arn"])
s3_resource = aws_session.resource("s3")
s3_client = aws_session.client("s3")
# NOTE(review): the unsigned client presumably serves anonymous access to
# public buckets — confirm against its call sites
s3_client_unsigned = aws_session.client("s3", config=botocore.config.Config(signature_version=botocore.UNSIGNED))
print("Initializing AWS regions: ", end="", flush=True)
register_aws_regions()
print(len(Region.index[Provider.AWS]))
# endregion
# region Azure
print("Signing in to Azure: ", end="", flush=True)
azure_credential = azure.identity.DefaultAzureCredential(exclude_interactive_browser_credential=False)
azure_graph_client = msgraph.core.GraphClient(credential=azure_credential)
azure_user = cast(dict, azure_graph_client.get("/me").json())
# prefer the user principal name, fall back to the object id
print(azure_user.get("userPrincipalName", azure_user["id"]))
print("Using Azure subscription: ", end="", flush=True)
azure_subscription_client = azure.mgmt.subscription.SubscriptionClient(credential=azure_credential)
azure_subscriptions = list(azure_subscription_client.subscriptions.list())
azure_subscription = get_azure_subscription()
azure_subscription_id = cast(str, azure_subscription.subscription_id)
print(f"'{azure_subscription.display_name}' ({azure_subscription_id})")
# management clients scoped to the chosen subscription
azure_resource_client = azure.mgmt.resource.ResourceManagementClient(
    credential=azure_credential,
    subscription_id=azure_subscription_id,
)
azure_cognitiveservices_management_client = azure.mgmt.cognitiveservices.CognitiveServicesManagementClient(
    credential=azure_credential,
    subscription_id=azure_subscription_id,
)
print("Initializing Azure regions: ", end="", flush=True)
register_azure_regions()
print(len(Region.index[Provider.Azure]))
# endregion
###############################################################################
# MAIN PROGRAM #
###############################################################################
# region build argument parser
parser = argparse.ArgumentParser(
    prog="convert",
    description="Performs OCR on images and translates the recognized text.",
    # the epilog lists every known provider and, per provider, each region
    # together with the services it offers there
    epilog=(
        "Known PROVIDER:\n " +
        "\n ".join(sorted(provider.name for provider in Provider)) +
        "\n\n" +
        "\n\n".join(
            f"Known {provider.name}:REGION: (available services)\n " +
            "\n ".join(region.help for region in sorted(Region.index[provider].values(), key=lambda r: r.name))
            for provider in Provider
        )
    ),
    formatter_class=argparse.RawDescriptionHelpFormatter,
)
parser.add_argument(
    "input",
    type=Store.from_uri,
    help="Input store containing the images.",
)
parser.add_argument(
    "output",
    type=Store.from_uri,
    help="Output store receiving the translated text files.",
)
parser.add_argument(
    "--ocr",
    required=False,
    metavar="PROVIDER:REGION",
    type=Region.parse,
    nargs='+',
    # by default use every available Azure region that supports OCR
    default=[
        region for region in Region.index[Provider.Azure].values()
        if region.is_available and region.ocr.is_supported
    ],
    help="Specifies the regions that should be used for OCR, defaults to all available Azure regions."
)
parser.add_argument(
    "--translation",
    required=False,
    metavar="PROVIDER:REGION",
    type=Region.parse,
    nargs='+',
    # by default use every available AWS region that supports translation
    default=[
        region for region in Region.index[Provider.AWS].values()
        if region.is_available and region.translation.is_supported
    ],
    help="Specifies the regions that should be used for translation, defaults to all available AWS regions."
)
# register_settings() presumably adds one option per Settings field
# (defined elsewhere; the fields are read back below via Settings._fields)
register_settings()
# endregion
# region read the arguments
_args = parser.parse_args()
input_store: Store = _args.input
output_store: Store = _args.output
# collect all Settings fields from the parsed arguments
settings = Settings(**{
    setting: getattr(_args, setting)
    for setting in Settings._fields
})
me = locate_me()
ocr_regions: list[Region] = _args.ocr
translation_regions: list[Region] = _args.translation
if not settings.preserve_order:
    # rank OCR regions by distance to wherever the image bytes come from:
    # Azure regions (which fetch by URL) and the input's own S3 region are
    # measured against the input store's location, all others against this
    # machine's position
    ocr_regions = sorted(
        ocr_regions,
        key=lambda region: distance(
            input_store.location if (
                region.provider is Provider.Azure
                or isinstance(input_store, S3BucketStore) and region is input_store.region
            ) else me,
            region.location
        )
    )
    # translation payloads always travel from this machine
    translation_regions = sorted(
        translation_regions,
        key=lambda region: distance(me, region.location)
    )
ocr_services = [region.ocr for region in ocr_regions]
translation_services = [region.translation for region in translation_regions]
# endregion
class Task:
    """A single document's journey through OCR, translation, and writing.

    Also serves as the global work queue: next() pulls documents from the
    input store under a lock shared by all worker threads.
    """

    all: typing.ClassVar[list[Task]] = []
    lock: typing.ClassVar[threading.Lock] = threading.Lock()
    # shared iterator over the input documents, guarded by lock
    iterator: typing.ClassVar[typing.Iterator[Document]] = iter(input_store.list_files())

    def __init__(self, document: Document) -> None:
        self.document = document
        self.state = TaskState.IDLE
        self.error = ""
        self.ocr = TaskStepStatistics()
        self.translation = TaskStepStatistics()
        self.write = TaskStepStatistics()
        Task.all.append(self)

    def to_record(self) -> dict:
        """Flatten the task into a single row for the statistics CSV."""
        return {
            "name": self.document.name,
            "error": self.error,
            "ocr_start": self.ocr.start,
            "ocr_end": self.ocr.end,
            "ocr_provider_name": self.ocr.provider_name,
            "ocr_region_name": self.ocr.region_name,
            "translation_start": self.translation.start,
            "translation_end": self.translation.end,
            "translation_provider_name": self.translation.provider_name,
            "translation_region_name": self.translation.region_name,
            "write_start": self.write.start,
            "write_end": self.write.end,
        }

    @staticmethod
    def next() -> Task | None:
        """Return a task for the next input document, or None when exhausted."""
        with Task.lock:
            try:
                return Task(next(Task.iterator))
            except StopIteration:
                return None

    def simple_step(self, state: TaskState, run: typing.Callable[[Document], T]) -> tuple[T, TaskStepStatistics]:
        """Run a step with retries, returning its result and timing.

        The task shows the given state while running; failures are retried
        every settings.retry_interval up to settings.max_attempts times.
        """
        prev_state = self.state
        try:
            attempt = 1
            while True:
                self.state = state
                try:
                    start = time.time() - epoche
                    result = run(self.document)
                    end = time.time() - epoche
                    return (result, TaskStepStatistics(start=start, end=end))
                except Exception as e:
                    Terminal.exception(e)
                    attempt += 1
                    if attempt > settings.max_attempts:
                        raise
                self.state = TaskState.IDLE
                time.sleep(settings.retry_interval)
        finally:
            self.state = prev_state

    def timed_step(
        self,
        state: TaskState,
        services: typing.Iterable[TimedService[CloudServiceT]],
        run: typing.Callable[[CloudServiceT, Document], T],
    ) -> tuple[T, TaskStepStatistics]:
        """Run a step against the fastest available service, with retries.

        Services are tried in ascending order of average runtime.  Input
        errors abort immediately, quota errors throttle the task without
        counting as an attempt, and other errors count towards
        settings.max_attempts.
        """
        prev_state = self.state
        try:
            attempt = 1
            while True:
                self.state = state
                exceeded_quota = False
                last_exc: Exception | None = None
                # fastest (lowest average runtime) service first
                for service in sorted(services, key=lambda service: service.average_time):
                    try:
                        start = time.time() - epoche
                        result = run(service.instance, self.document)
                        end = time.time() - epoche
                    except ServiceInputError:
                        # the input itself is bad, no other service can help
                        raise
                    except QuotaError:
                        exceeded_quota = True
                    except Exception as e:
                        Terminal.exception(e)
                        last_exc = e
                    else:
                        return (
                            result,
                            TaskStepStatistics(
                                start=start,
                                end=end,
                                provider_name=service.region.provider.name,
                                region_name=service.region.name,
                            )
                        )
                # only count attempts for real failures, not quota throttling
                if not exceeded_quota and last_exc is not None:
                    attempt += 1
                    if attempt > settings.max_attempts:
                        raise last_exc
                self.state = TaskState.THROTTLED if exceeded_quota else TaskState.IDLE
                time.sleep(settings.retry_interval)
        finally:
            self.state = prev_state
class TaskState(enum.IntEnum):
    """Lifecycle states of a task; the int values index the status counters."""

    IDLE = 0        # waiting to (re)start a step
    THROTTLED = 1   # backing off because the services exceeded their quota
    OCR = 2         # text recognition in progress
    TRANSLATE = 3   # translation in progress
    WRITE = 4       # writing the result to the output store
    FAILED = 5      # terminal: a step ran out of attempts
    DONE = 6        # terminal: all steps succeeded
class TaskStepStatistics(typing.NamedTuple):
    """Timing and placement of one task step (times relative to epoche)."""

    start: float = 0  # step start time, 0 if the step never ran
    end: float = 0  # step end time, 0 if the step never ran
    provider_name: str = ""  # cloud provider that served the step, if any
    region_name: str = ""  # region that served the step, if any
class Terminal:
    """Thread-safe console output helpers (static class, not instantiable)."""

    # unique error strings already printed (for deduplication)
    errors: typing.ClassVar[set[str]] = set()
    # total number of exceptions reported, including duplicates
    error_count: typing.ClassVar[int] = 0
    lock: typing.ClassVar[threading.Lock] = threading.Lock()

    def __init__(self) -> None:
        raise NotImplementedError()

    @staticmethod
    def max_width() -> int:
        """Return the usable terminal width, or 80 when there is no terminal."""
        try:
            return os.get_terminal_size().columns - 1
        except OSError:
            return 80

    @staticmethod
    def status(text: str) -> None:
        """Overwrite the current status line with a timestamped message."""
        width = Terminal.max_width()
        line = f"{datetime.timedelta(seconds=time.time() - epoche)}: {text}"[0:width]
        with Terminal.lock:
            # blank the line first so a shorter status fully replaces it
            print("\r", " " * width, "\r", line, sep="", end="", flush=True)

    @staticmethod
    def exception(exception: Exception) -> None:
        """Count the exception and print it once per unique message."""
        error = f"{type(exception).__name__}: {exception}"
        width = Terminal.max_width()
        with Terminal.lock:
            Terminal.error_count += 1
            if error not in Terminal.errors:  # fixed: idiomatic 'not in'
                Terminal.errors.add(error)
                print("\r", " " * width, "\r", error, sep="", flush=True)
def worker() -> None:
    """Process documents from the shared queue until it is exhausted.

    Each document is run through OCR, translation, and writing; a step
    that exhausts its retries marks the whole task as failed.
    """
    while True:
        task = Task.next()
        if task is None:
            break
        try:
            text, task.ocr = task.timed_step(
                state=TaskState.OCR,
                services=ocr_services,
                run=lambda service, doc: service.run(doc),
            )
            translation, task.translation = task.timed_step(
                state=TaskState.TRANSLATE,
                services=translation_services,
                run=lambda service, _: service.run(text, settings.language),
            )
            _, task.write = task.simple_step(
                state=TaskState.WRITE,
                run=lambda doc: output_store.put_file(os.path.splitext(doc.name)[0] + ".txt", translation.encode()),
            )
        except Exception as e:
            task.state = TaskState.FAILED
            task.error = str(e)
        else:
            task.state = TaskState.DONE
def main() -> None:
    """Spawn the worker threads and report progress until all have finished."""
    threads = [threading.Thread(target=worker, name=f"worker{index}", daemon=True) for index in range(settings.threads)]
    for thread in threads:
        thread.start()
    while threads:
        # count how many tasks are in each state for the status line
        states = [0] * len(TaskState)
        for task in Task.all:
            states[task.state.value] += 1
        summary = " ".join(map(lambda state: f"{state.name}={states[state.value]}", TaskState))
        Terminal.status(f"THREADS={len(threads)} ERRORS={Terminal.error_count} {summary}")
        time.sleep(settings.refresh_interval)
        # drop finished threads; the loop ends when none remain alive
        threads = [thread for thread in threads if thread.is_alive()]
    Terminal.status(f"Processed {len(Task.all)} images, {sum(entry.state is TaskState.FAILED for entry in Task.all)} failed.")
    if settings.statistics:
        # dump per-resource and per-document timing data as CSV files
        pandas.DataFrame.from_records(resource.to_record() for resource in Resource.all).to_csv(full_path("resources.csv"), index=False)
        pandas.DataFrame.from_records(task.to_record() for task in Task.all).to_csv(full_path("documents.csv"), index=False)
# record the program start time; all statistics are relative to it
epoche = time.time()
main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment