Flushes items into a PostgreSQL table with fast upserts
import logging
from typing import Any, Dict, List, Optional, Sequence, Tuple

import simplejson as json
from django.core.serializers.json import DjangoJSONEncoder
from django.db import connection
from django.db.utils import IntegrityError, ProgrammingError

logger = logging.getLogger(__name__)


class PGFlusher(object):
    """
    Flushes items into a PostgreSQL table with fast upserts.

    Usage::

        data = [
            {
                "long_id": 123,
                "created": "2018-01-15",
                "street": "Street 1",
                "burg": "Burg 1",
                "_meta": "stuff1",
            },
            {
                "long_id": 124,
                "created": "2018-01-15",
                "street": "Street 2",
                "burg": "Burg 2",
                "_meta": "stuff2",
            },
        ]
        flusher = PGFlusher(
            db_table="test",
            unique_fields=("long_id", ),
            create_fields=("created", ),
            ignore_fields=("_meta", ),
        )
        pk_collection = flusher.flush(data, return_pk=True)
        for idx, pk in enumerate(pk_collection):
            data[idx]["_pk"] = pk
    """

    def __init__(self, db_table: str, unique_fields: Tuple[str, ...],
                 create_fields: Optional[Tuple[str, ...]] = None,
                 ignore_fields: Optional[Tuple[str, ...]] = None):
        """
        You need to provide the table name and the list of fields that make up the
        unique constraint, e.g. ``db_table="store_items", unique_fields=("shop_id", "original_id")``.

        :param db_table: DB table name.
        :param unique_fields: Fields the unique constraint is built on.
        :param create_fields: Fields used only in the INSERT part, never in the UPDATE part.
        :param ignore_fields: Fields ignored in all operations.
        """
        self.db_table = db_table
        self.unique_fields = unique_fields
        self.unique_fields_string = ", ".join(self.unique_fields)
        self.create_fields = create_fields or tuple()
        self.ignore_fields = ignore_fields or tuple()
        # These are populated lazily from the first data item passed to flush()
        self.all_fields = None
        self.data_fields = None
        self.all_fields_string = None
        self.placeholders_string = None
        self.set_statement = None
        self._encoder = DjangoJSONEncoder()

    def _make_fields(self, names: Sequence[str]):
        """
        Prepare the field collections used to build the SQL statement.
        """
        self.all_fields = tuple(name for name in names if name not in self.ignore_fields)
        self.data_fields = tuple(name for name in self.all_fields if name not in self.unique_fields)
        self.all_fields_string = ", ".join(self.all_fields)
        self.placeholders_string = ", ".join("%s" for _ in self.all_fields)
        self.set_statement = ", ".join(
            "{} = EXCLUDED.{}".format(f, f)
            for f in self.data_fields if f not in self.create_fields
        )
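
    # Worked example: for the data in the class docstring, with
    # unique_fields=("long_id",), create_fields=("created",) and
    # ignore_fields=("_meta",), and Python 3.7+ dict key ordering,
    # _make_fields() produces:
    #   all_fields          -> ("long_id", "created", "street", "burg")
    #   all_fields_string   -> "long_id, created, street, burg"
    #   placeholders_string -> "%s, %s, %s, %s"
    #   set_statement       -> "street = EXCLUDED.street, burg = EXCLUDED.burg"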

    def flush(self, data: List[Dict], return_pk: bool = True):
        """
        Save ``data`` to the database using a PostgreSQL-specific statement that
        inserts and updates records in one pass. Example statement produced here::

            INSERT INTO household_address
                (street_name, city, zip_code, house_number, house_number_extension)
            VALUES
                ('Street 1', 'burg', '1211BB', '22', ''),
                ('Street 2', 'burg', '1211BC', '11', '')
            ON CONFLICT (zip_code, house_number, house_number_extension)
            DO UPDATE SET
                street_name = EXCLUDED.street_name, city = EXCLUDED.city
            RETURNING id;

        The return value is the list of PKs that were either inserted or found during
        the update, in the same order as the passed data; each PK is also written back
        into its item under the ``"_pk"`` key.

        More info:
        https://www.postgresql.org/docs/9.6/static/sql-insert.html#SQL-ON-CONFLICT
        """
        if not data:
            return

        # The first item passed triggers field initialization, so we can be a little bit dynamic
        if self.all_fields is None:
            self._make_fields(tuple(data[0].keys()))
sql = """ | |
INSERT INTO {table_name} | |
({all_fields}) | |
VALUES | |
{values_lists} | |
ON CONFLICT ({unique_fields}) | |
DO UPDATE SET {set_statement} | |
{return_pk}; | |
""".format( | |
table_name=self.db_table, | |
all_fields=self.all_fields_string, | |
values_lists=self.make_placeholder_lists(data), | |
unique_fields=self.unique_fields_string, | |
set_statement=self.set_statement, | |
return_pk='RETURNING id' if return_pk else '', | |
) | |
# Execute query | |
try: | |
with connection.cursor() as cursor: | |
cursor.execute(sql, self.make_values_list(data)) | |
if not return_pk: # Exit early if PKs not needed | |
return | |
pk_list = cursor.fetchall() | |
except (IntegrityError, ProgrammingError): | |
logger.error("SQL error %s", sql, exc_info=True) | |
raise | |
for idx, value in enumerate(data): | |
value['_pk'] = pk_list[idx][0] | |
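
    # Sketch of what flush() would generate for the class-docstring example,
    # assuming a table named "test" with an "id" primary key (with Python 3.7+
    # dict key ordering the columns come out as long_id, created, street, burg):
    #
    #   INSERT INTO test
    #       (long_id, created, street, burg)
    #   VALUES
    #       (%s, %s, %s, %s), (%s, %s, %s, %s)
    #   ON CONFLICT (long_id)
    #   DO UPDATE SET street = EXCLUDED.street, burg = EXCLUDED.burg
    #   RETURNING id;
    #
    # The parameter tuple bound by the driver comes from make_values_list().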

    def make_placeholder_lists(self, data: List[Dict]):
        """
        Prepare the placeholder string to be filled with values by the database driver,
        e.g. ``(%s, %s, %s), (%s, %s, %s)``. The number of blocks must correspond to
        the number of items in ``data``.
        """
        result = '({})'.format(
            '), ('.join(self.placeholders_string for _ in range(len(data)))
        )
        return result

    def make_values_list(self, data: List[Dict]):
        """
        Prepare the values extracted from each item (dict), keeping the same field
        order for every data item. The result is one flat tuple with all values in
        the correct order.
        """
        result = tuple(self._ex_value(v, name) for v in data for name in self.all_fields)
        return result
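
    # For the class-docstring example (with Python 3.7+ dict key ordering) this
    # flattens to:
    #   (123, "2018-01-15", "Street 1", "Burg 1",
    #    124, "2018-01-15", "Street 2", "Burg 2")
    # "_meta" is dropped because it is not part of all_fields.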

    def _ex_value(self, v: Dict, name: str) -> Any:
        """Extract from dict and cast type if necessary."""
        value = v.get(name)
        if isinstance(value, dict):  # Simplest JSON detection
            value = json.dumps(value, default=self._encoder.default)
        return value