Last active
May 6, 2025 07:54
-
-
Save polyvertex/e5dacc97350910f080fc85c61af20192 to your computer and use it in GitHub Desktop.
sqlite.py - sqlite3 with nested transactions for real
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Copyright (c) Jean-Charles Lefebvre
# SPDX-License-Identifier: MIT
import contextlib | |
import importlib | |
import importlib.resources | |
import os | |
import re | |
import sqlite3 | |
import sys | |
import threading | |
import types | |
# Names exported by ``from <module> import *``.
__all__ = ("SqliteConnection", "SqliteCursor")

# Default pattern used to recognize schema resources inside a package:
# ``schema-<version>.sql`` or ``schema-<version>.py``. Group 1 captures the
# version, which may contain a single ``_`` separator between digit runs.
DEFAULT_SCHEMA_RESOURCE_REGEX = re.compile(
    r"^schema\-(\d+(?:_\d+)?)\.(?:sql|py)$",
    flags=re.ASCII,
)

# Name of the callable looked up in a ``schema-<version>.py`` module and
# invoked with the database object in order to apply that schema update.
SCHEMA_UPDATER_CALLABLE_NAME = "sqlitedb_update_schema"
class SqliteCursor(sqlite3.Cursor):
    """
    A wrapper around `sqlite3.Cursor` that is the default Cursor class for
    `SqliteConnection`.

    Most notably it re-implements `executescript`, to honor the support of
    nested transactions offered by `SqliteConnection`. This is because CPython's
    `sqlite3.Connection.executescript` and `sqlite3.Cursor.executescript` do not
    take into account the ``isolation_level`` value and forcefully issue a
    ``COMMIT`` statement before executing the passed SQL script.

    This implies that any current transaction is committed at a lower level
    without any chance for us to be notified about the new internal state of the
    sqlite3 connection. Thus breaking the support of nested transactions.

    So there was no other choice than reimplementing `executescript`. It is done
    here by relying on sqlite3 API to parse the SQL script and extract its
    statements one by one, in order to execute them manually.
    """

    def __del__(self):
        # Best-effort cleanup: a finalizer must never raise.
        with contextlib.suppress(Exception):
            self.close()

    def executescript(self, script, *, source="<memory>"):
        """
        Execute every statement of the SQL *script*, one at a time.

        *source* is only used to label the script in error messages.

        ``SELECT`` statements and transaction-related statements (``BEGIN``,
        ``COMMIT``, ``END``, ``SAVEPOINT``, ``RELEASE``, ``ROLLBACK``) are
        rejected with `sqlite3.ProgrammingError`. Statements located before
        the offending one have already been executed when the error is raised.

        Return the cursor itself, like `sqlite3.Cursor.executescript` does.
        """
        stmt_it = sqlite_iterate_script_statements(
            script, source=source, with_keyword=True)

        for keyword, stmt in stmt_it:
            if keyword == "SELECT":
                raise sqlite3.ProgrammingError(
                    "SELECT statements not permitted in executescript method")

            if keyword in (
                    "BEGIN", "COMMIT", "END",
                    "SAVEPOINT", "RELEASE", "ROLLBACK"):
                # these would change the transaction state behind the back of
                # SqliteConnection's savepoint bookkeeping
                raise sqlite3.ProgrammingError(
                    f"transaction-related statements not permitted in "
                    f"executescript method (got {keyword} statement)")

            self.execute(stmt)

        # match the return value of the overridden sqlite3.Cursor method
        return self
class SqliteConnection:
    """
    A wrapper around `sqlite3.Connection` with a schema updating feature and
    that truly supports nested transactions with context management by using
    sqlite3's ``SAVEPOINT`` feature.

    Every ``with`` block entered on the connection opens a new savepoint. On
    exit, the savepoint is released if the block completed normally, or rolled
    back and then released if it raised.

    Attributes that are not defined by this class are forwarded to the
    underlying `sqlite3.Connection` object (see `__getattr__`).

    .. seealso::
        Python's `issue16958 <https://bugs.python.org/issue16958>`_ about using
        `sqlite3.Connection` as a context manager.
    """

    def __init__(self, database, **kwargs):
        """
        Open *database* (a path-like object or `str`).

        Extra *kwargs* are passed to `sqlite3.connect`, except
        ``isolation_level`` which must be omitted or `None` (`ValueError` is
        raised otherwise) since this wrapper manages transactions itself.
        """
        isolation_level = kwargs.pop("isolation_level", None)
        if isolation_level is not None:
            raise ValueError(
                "isolation_level arg specified and different than None")

        self.uri = os.fspath(database)

        # Set up the whole instance state *before* connecting, so that
        # __getattr__, close() and __del__ behave sanely even if
        # sqlite3.connect() raises.
        self.conn = None
        self._savepoint_lock = threading.RLock()
        self._savepoint_id = 0
        # Names of the savepoints opened by the nested contexts, innermost
        # last. An entry is replaced by None once it has been explicitly
        # committed or rolled back from within its context.
        self._savepoint_stack = []

        self.conn = sqlite3.connect(self.uri, isolation_level=None, **kwargs)
        assert self.conn.isolation_level is None

        self.conn.execute("PRAGMA temp_store = MEMORY")
        self.conn.execute("PRAGMA journal_mode = WAL")

    def __del__(self):
        # Best-effort cleanup: a finalizer must never raise.
        with contextlib.suppress(Exception):
            self.close()

    def __getattr__(self, name):
        # Only called when regular attribute lookup failed. Read "conn"
        # straight from the instance dict: going through self.conn while it
        # is not set yet (e.g. __init__ failed early) would recurse forever.
        try:
            conn = self.__dict__["conn"]
        except KeyError:
            raise AttributeError(name) from None

        if conn is None:
            raise sqlite3.OperationalError(
                f"trying to get or call {name} but database is closed: "
                f"{self.uri}")

        return getattr(conn, name)

    def __bool__(self):
        return self.conn is not None

    def __str__(self):
        return self.uri

    def __repr__(self):
        return f"<{self.__class__.__name__} {self.uri}>"

    def __enter__(self):
        self._push_savepoint()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # commit on normal exit, rollback if the block raised; the exception
        # (if any) is always propagated
        commit = exc_type is None
        self._pop_savepoint(commit=commit, pop=True)

    @property
    def isolation_level(self):
        self.ensure_open()
        return self.conn.isolation_level

    @isolation_level.setter
    def isolation_level(self, value):
        raise sqlite3.NotSupportedError(
            "isolation_level change not supported by this wrapper")

    @property
    def in_transaction(self):
        return self.conn is not None and self.conn.in_transaction

    def close(self):
        """Close the connection. Calling it more than once is harmless."""
        with self._savepoint_lock:
            if self.conn is not None:
                self.conn.close()
                self.conn = None

            if self._savepoint_stack:
                self._savepoint_stack = []

    def ensure_open(self):
        """
        Raise `sqlite3.OperationalError` if `close` has been called already
        """
        if self.conn is None:
            raise sqlite3.OperationalError(
                f"database connection closed: {self.uri}")

        # A None entry in the stack is a savepoint that has already been
        # committed or rolled back explicitly, so only named entries count
        # as an open transaction.
        has_open_savepoint = any(
            name is not None for name in self._savepoint_stack)
        assert has_open_savepoint == bool(self.conn.in_transaction)

    def cursor(self, factory=SqliteCursor):
        """
        The Cursor class factory.

        It is important to use `SqliteCursor` or a derived class as a factory,
        due to the reimplementation of `executescript`.
        """
        self.ensure_open()
        return self.conn.cursor(factory=factory)

    def executescript(self, script, *, source="<memory>"):
        """
        Execute the SQL *script* using a temporary cursor.

        See `SqliteCursor` for the rationale behind this reimplementation.
        """
        # closing() guarantees the cursor is released even if a statement fails
        with contextlib.closing(self.cursor()) as cursor:
            cursor.executescript(script, source=source)

    def commit(self):
        """
        Commit the current transaction if not already committed or released.

        This method must only be called from a context.
        `sqlite3.OperationalError` is raised otherwise.
        """
        with self._savepoint_lock:
            self.ensure_open()
            if not self._savepoint_stack:
                raise sqlite3.OperationalError(
                    "commit() called outside of a transaction context")
            self._pop_savepoint(commit=True, pop=False)

    def rollback(self):
        """
        Rollback the current transaction if not already committed or released.

        This method must only be called from a context.
        `sqlite3.OperationalError` is raised otherwise.
        """
        with self._savepoint_lock:
            self.ensure_open()
            if not self._savepoint_stack:
                raise sqlite3.OperationalError(
                    "rollback() called outside of a transaction context")
            self._pop_savepoint(commit=False, pop=False)

    def fetchone(self, sql, parameters=()):
        """Shorthand for an `execute` call followed by `fetchone`"""
        with contextlib.closing(self.execute(sql, parameters)) as cursor:
            return cursor.fetchone()

    def fetchmany(self, sql, parameters=(), size=None):
        """
        Shorthand for an `execute` call followed by `fetchmany`.

        A false *size* stands for the cursor's ``arraysize``.
        """
        with contextlib.closing(self.execute(sql, parameters)) as cursor:
            if not size:
                size = cursor.arraysize
            return cursor.fetchmany(size)

    def fetchall(self, sql, parameters=()):
        """Shorthand for an `execute` call followed by `fetchall`"""
        with contextlib.closing(self.execute(sql, parameters)) as cursor:
            return cursor.fetchall()

    def create_or_update_schema(
            self, meta_table, meta_column, resource_package, *,
            schema_resource_regex=None):
        """
        Get the current schema version of the database using *meta_table* and
        *meta_column* names, then apply all the schema updates found in
        *resource_package* if any.

        *resource_package* must be a module object.

        *schema_resource_regex* is an optional compiled regex used to select
        the schema resources; see `get_schema_resources_manifest`.

        Return a `tuple` of two `int`: the detected version number before
        applying any update (may be zero if database was not created), and the
        version number of the latest update applied by this method, which may be
        equal to the first value in the tuple.

        Raise `sqlite3.OperationalError` if no schema resource was found.
        """
        self.ensure_open()

        initial_version = self.get_installed_schema_version(
            meta_table, meta_column)

        # get all the available schema updates
        manifest = self.get_schema_resources_manifest(
            resource_package, schema_resource_regex=schema_resource_regex)
        if not manifest:
            raise sqlite3.OperationalError(
                f"empty SQL schema manifest for package: "
                f"{resource_package.__name__}")

        # apply updates
        latest_version = manifest.apply_updates(self, initial_version)

        return (initial_version, latest_version)

    def get_installed_schema_version(self, table_name, column_name):
        """
        Used by `create_or_update_schema` to get the current database schema
        version.

        This method executes a ``SELECT`` statement using *table_name* and
        *column_name* and returns the value of *column_name* in the first row.
        Expected to be an `int`. Zero is returned if the statement fails with
        `sqlite3.OperationalError` (typically: the table does not exist yet).

        Additionally, the requested table is expected to be a one-row table.
        This method raises `sqlite3.DatabaseError` in case the number of rows is
        different than one.

        *table_name* and *column_name* are inserted verbatim in the SQL
        statement, so they must come from a trusted source.
        """
        self.ensure_open()

        try:
            cursor = self.execute(
                f"SELECT {column_name} FROM {table_name} LIMIT 2")
        except sqlite3.OperationalError:  # missing table
            return 0

        with contextlib.closing(cursor):
            rows = cursor.fetchall()

        if not rows:
            raise sqlite3.DatabaseError(
                f"missing {table_name}.{column_name} value in database: "
                f"{self.uri}")
        elif len(rows) > 1:
            raise sqlite3.DatabaseError(
                f"unexpected multiple rows in table {table_name} in database: "
                f"{self.uri}")

        version = rows[0][0]
        assert isinstance(version, int)

        return version

    def get_schema_resources_manifest(
            self, resource_package, *, schema_resource_regex=None):
        """
        Used by `create_or_update_schema` to get a `SqliteSchemasManifest`
        object populated with all the database schema resources found under
        *resource_package*.

        *resource_package* must be a module object.

        *schema_resource_regex* must be a compiled regex whose first group
        captures the schema version. `DEFAULT_SCHEMA_RESOURCE_REGEX` is used
        when it is omitted.

        Raise `ValueError` if a resource carries a version number of zero.
        """
        if not schema_resource_regex:
            schema_resource_regex = DEFAULT_SCHEMA_RESOURCE_REGEX

        manifest = SqliteSchemasManifest(resource_package)

        for res_name in importlib.resources.contents(resource_package):
            if rem := schema_resource_regex.fullmatch(res_name):
                dbver = int(rem.group(1))
                if not dbver:
                    # zero is reserved for "database not created yet"
                    raise ValueError(
                        f"SQL package {resource_package.__name__} contains a "
                        f"resource with a schema version value of zero: "
                        f"{res_name}")

                manifest.register_resource(dbver, res_name)

        return manifest

    def _push_savepoint(self):
        """Open a new uniquely named savepoint and push it onto the stack"""
        with self._savepoint_lock:
            self.ensure_open()

            if __debug__:
                if not self._savepoint_stack:
                    assert not self.conn.in_transaction

            self._savepoint_id += 1
            savepoint = f"SqliteConnTx_{self._savepoint_id}"

            self.conn.execute(f"SAVEPOINT {savepoint}")
            self._savepoint_stack.append(savepoint)

            assert self.conn.in_transaction

    def _pop_savepoint(self, *, commit, pop):
        """
        Terminate the innermost savepoint, by releasing it if *commit* is true,
        or by rolling it back otherwise.

        If *pop* is true the entry is removed from the stack (context exit).
        Otherwise it is kept as a `None` placeholder (explicit `commit` or
        `rollback` call), so that the matching context exit becomes a no-op.

        Return the name of the terminated savepoint, or `None` if there was
        nothing to do.
        """
        with self._savepoint_lock:
            if not self._savepoint_stack:
                return None

            if self.conn is None:
                self._savepoint_stack = []
                return None

            if pop:
                savepoint = self._savepoint_stack.pop(-1)
            else:
                savepoint = self._savepoint_stack[-1]
                if savepoint is not None:
                    self._savepoint_stack[-1] = None

            # reminder: savepoint may be None due to commit() or
            # rollback() methods
            if savepoint:
                assert self.conn.in_transaction

                if not commit:
                    self.conn.execute(f"ROLLBACK TO SAVEPOINT {savepoint}")

                # ROLLBACK TO only rewinds the savepoint: it stays on
                # sqlite's transaction stack and, if it is the outermost one,
                # keeps the transaction open. It must be released in both
                # cases in order to actually end it.
                self.conn.execute(f"RELEASE SAVEPOINT {savepoint}")

            return savepoint
class SqliteSchemaResource:
    """
    Utility class for the schema resource(s) associated with a single version.

    A version may come with a ``.sql`` resource, a ``.py`` resource, or both.

    Created by `SqliteSchemasManifest`. Not meant to be instantiated nor used
    directly.
    """

    def __init__(self, resource_package, version):
        assert isinstance(resource_package, types.ModuleType)
        assert isinstance(version, int)

        self.resource_package = resource_package
        self.version = version
        self.sql_resource_name = None  # name of the ".sql" resource, if any
        self.py_resource_name = None  # name of the ".py" resource, if any

    @property
    def has_sql(self):
        """`True` if a ``.sql`` resource is registered for this version"""
        return bool(self.sql_resource_name)

    @property
    def has_py(self):
        """`True` if a ``.py`` resource is registered for this version"""
        return bool(self.py_resource_name)

    @property
    def py_module_name(self):
        """
        The fully qualified name of the Python module of this version.

        Raise `ValueError` if this version has no ``.py`` resource.
        """
        if not self.has_py:
            raise ValueError(
                f"no Python module for schema version {self.version}")

        assert self.py_resource_name.lower().endswith(".py")

        return "{}.{}".format(
            self.resource_package.__name__,
            self.py_resource_name[0:-len(".py")])

    def apply_update(self, db):
        """
        Apply schema update to the provided Connection or Cursor object *db*.

        The ``.sql`` resource, if any, is executed first with
        ``db.executescript``. Then the ``.py`` module, if any, is imported and
        its updater callable is called with *db* as its only argument.

        Raise `ValueError` if the ``.py`` module does not expose the expected
        callable.
        """
        if self.has_sql:
            db.executescript(self._extract_sql())

        if self.has_py:
            module = self._import_py()
            modname = module.__name__
            func = None  # bound up-front so that cleanup below cannot fail
            try:
                func = getattr(module, SCHEMA_UPDATER_CALLABLE_NAME, None)
                if not callable(func):
                    raise ValueError(
                        f"{module.__name__}.{SCHEMA_UPDATER_CALLABLE_NAME} "
                        f"missing or is not a callable")

                func(db)
            finally:
                # release module
                del func
                del module
                sys.modules.pop(modname, None)

    def _extract_sql(self):
        """Return the content of the ``.sql`` resource as a `str`"""
        if not self.has_sql:
            raise ValueError(
                f"no SQL resource for schema version {self.version}")

        return importlib.resources.read_text(
            self.resource_package, self.sql_resource_name,
            encoding="utf-8", errors="strict")

    def _import_py(self):
        """
        Import and return the ``.py`` module of this version.

        Raise `ImportError` if this version has no ``.py`` resource.
        """
        if not self.has_py:
            raise ImportError(
                f"no Python module for schema version {self.version}")

        return importlib.import_module(self.py_module_name)
class SqliteSchemasManifest:
    """
    A snapshot of the schema resources (``.sql`` and ``.py``) embedded in the
    passed Python package.

    Created by `SqliteConnection.get_schema_resources_manifest`. Not meant to be
    instantiated directly.
    """

    def __init__(self, resource_package):
        assert isinstance(resource_package, types.ModuleType)

        self.resource_package = resource_package
        # True when a version was registered since the last call to _sort()
        self._modified = False
        # version number (int) -> SqliteSchemaResource
        self._resources = {}
        self._oldest_version = None
        self._latest_version = None

    def __len__(self):
        return len(self._resources)

    def __iter__(self):
        """Iterate over the registered resources by ascending version"""
        if self._modified:
            self._sort()

        # guaranteed by _sort() to be ordered by ascending version;
        # __iter__ must return an iterator, a dict view is not one
        return iter(self._resources.values())

    def __contains__(self, version):
        return self.has_schema(version)

    def __getitem__(self, version):
        return self.get_schema(version)

    @property
    def oldest_version(self):
        """The smallest version number registered (`int`)"""
        if self._modified:
            self._sort()
        return self._oldest_version

    @property
    def latest_version(self):
        """The biggest version number registered (`int`)"""
        if self._modified:
            self._sort()
        return self._latest_version

    @property
    def versions(self):
        """
        The `list` of registered versions so far.

        List is ordered by ascending version number.
        """
        if self._modified:
            self._sort()

        # guaranteed by _sort() to be ordered by ascending version
        return list(self._resources)

    def has_schema(self, version):
        """
        Check if a schema of the passed *version* exists and return a `bool`
        value.

        A null *version* value stands for "the oldest version".
        """
        try:
            self.get_schema(version)
        except LookupError:
            # IndexError (empty manifest) or KeyError (unknown version)
            return False
        return True

    def get_schema(self, version):
        """
        Get the `SqliteSchemaResource` object associated to the passed
        *version*.

        A null *version* value stands for "the oldest version".

        Raise `IndexError` if the manifest is empty, or `KeyError` if *version*
        was not found. Both derive from `LookupError`.
        """
        assert isinstance(version, int)

        self._sort()

        if not self._resources:
            raise IndexError("no schema in manifest")

        assert self._oldest_version is not None
        assert self._latest_version is not None

        if not version:
            version = self._oldest_version

        return self._resources[version]  # may raise KeyError

    def apply_updates(self, db, from_version):
        """
        Apply every schema updates available from the specified *from_version*
        number (non-included unless it is zero), up to the latest version
        available in embedded resources.

        *db* must be either a Connection or a Cursor compatible object.

        Return the latest version number applied. This value may be equal to
        *from_version* if database was up-to-date already.

        Raise `ValueError` if *from_version* is not zero and is not a
        registered version.
        """
        assert isinstance(from_version, int)

        self._sort()

        if from_version and from_version not in self._resources:
            raise ValueError(f"unknown schema version {from_version}")

        # _sort() guarantees self._resources to be ordered by ascending version
        latest_version = from_version
        for schema in self._resources.values():
            if not from_version or schema.version > from_version:
                schema.apply_update(db)
                latest_version = schema.version

        return latest_version

    def register_resource(self, schema_version, resource_name):
        """
        Used by `SqliteConnection.get_schema_resources_manifest` to register an
        embedded resource in the manifest.

        *resource_name* must end with ``.sql`` or ``.py`` (case-insensitive).
        `ValueError` is raised otherwise.
        """
        assert isinstance(schema_version, int)

        if schema_version in self._resources:
            schema = self._resources[schema_version]
        else:
            schema = SqliteSchemaResource(self.resource_package, schema_version)
            self._resources[schema_version] = schema
            self._modified = True

        lowered_name = resource_name.lower()
        if lowered_name.endswith(".sql"):
            assert schema.sql_resource_name is None
            schema.sql_resource_name = resource_name
        elif lowered_name.endswith(".py"):
            assert schema.py_resource_name is None
            schema.py_resource_name = resource_name
        else:
            raise ValueError(f"unknown resource type: {resource_name}")

    def _sort(self):
        """Ensure ``self._resources`` is ordered by ascending version"""
        if self._modified:
            versions = sorted(self._resources)

            old_dict = self._resources
            self._oldest_version = versions[0]
            self._latest_version = versions[-1]

            # CAUTION: this assumes Python 3.7+ (i.e. ordered dict)
            assert sys.version_info >= (3, 7)

            self._resources = {}
            for ver in versions:
                assert isinstance(ver, int)
                self._resources[ver] = old_dict[ver]

            self._modified = False
def sqlite_iterate_script_statements(
        script, *, source="<memory>", with_keyword=False):
    """
    Yield SQL `str` statements from a sqlite3-compatible *script*.

    Statements are yielded stripped of their surrounding whitespace. A comment
    located before a statement is yielded as part of that statement.

    If *with_keyword* is true, each yielded value is a `tuple` containing a pair
    of `str` objects: the keyword of the statement and the statement itself.
    The keyword is the upper-cased first word of the statement, found after
    skipping any leading whitespace and SQL comments.

    *source* is only used to label the script in error messages.

    Raise `sqlite3.ProgrammingError` upon first iteration if *script* is not
    made of complete statement(s) according to `sqlite3.complete_statement`.
    """
    def _keyword_of(st):
        # Skip leading whitespace, "-- ..." and "/* ... */" comments so that
        # the keyword of a commented statement is still detected.
        pos = 0
        size = len(st)
        while pos < size:
            if st[pos].isspace():
                pos += 1
            elif st.startswith("--", pos):
                eol = st.find("\n", pos)
                pos = size if eol < 0 else eol + 1
            elif st.startswith("/*", pos):
                close = st.find("*/", pos + 2)
                pos = size if close < 0 else close + 2
            else:
                break

        # The keyword stops at the first non-identifier character, so that
        # "COMMIT;" gives "COMMIT" and "SELECT(1)" gives "SELECT".
        end = pos
        while end < size and (st[end].isalnum() or st[end] == "_"):
            end += 1
        if end > pos:
            return st[pos:end].upper()

        # Not a word (e.g. an empty ";" statement): fall back to the first
        # whitespace-delimited token.
        tokens = st[pos:].split(maxsplit=1)
        return tokens[0].upper() if tokens else ""

    def _prepare_yield(st):
        st = st.strip()
        return (_keyword_of(st), st) if with_keyword else st

    if not sqlite3.complete_statement(script):
        raise sqlite3.ProgrammingError(
            f"not a complete SQL statement or script: {source}")

    stmt = ""  # current statement
    for line in script.splitlines(keepends=True):
        # A ";" does not always terminate a statement (string literals,
        # comments, trigger bodies), so text is accumulated chunk by chunk and
        # sqlite3 decides when the statement is complete. A line without any
        # ";" is a single chunk.
        parts = line.split(";")
        last_idx = len(parts) - 1
        for idx, part in enumerate(parts):
            stmt += part
            if idx < last_idx:
                stmt += ";"  # give back the separator removed by split()

            if sqlite3.complete_statement(stmt):
                yield _prepare_yield(stmt)
                stmt = ""

    # Trailing data, if any, can be safely ignored because the whole script has
    # been validated at the beginning of this function so that any remaining
    # data is likely to be space characters or comment.
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment