Last active
July 24, 2021 01:21
-
-
Save martijnluinstra/f10ff7f2125d8b618a7df858f834cd66 to your computer and use it in GitHub Desktop.
Python 3 classes for filtering SQLAlchemy queries
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
This set of Python 3 classes was created to make automatic filtering of
SQLAlchemy queries easier. It is designed to have an API similar to
Django-filter (https://github.com/carltongibson/django-filter).
You may use and modify this code however you like for non-commercial purposes.
I would appreciate it if you mentioned my name when you do so.
Copyright (c) 2021 Martijn Luinstra | |
""" | |
import datetime | |
class Filter:
    """
    Base filter class.

    A filter turns a raw compare value (typically taken from request data)
    into a ``query.filter(...)`` call. Subclasses customise behaviour by
    overriding :meth:`prepare_compare_value` and/or by adding operation
    methods that are selected by name through ``operation``.
    """

    def __init__(self, operation='action', field=None, action=None, nullable=False):
        """
        operation: name of the method called as ``method(query, field, compare)``
        field: the field to operate on (the field passed to :meth:`apply`,
            i.e. the one resolved from the property name, is used when this
            is not given)
        action: custom callable ``action(query, compare)`` returning a
            filtered query; when given it is used instead of ``operation``
        nullable: when True, a prepared compare value of None is still applied
            instead of skipping the filter
        """
        self.operation = operation
        self.field = field
        self._action = action
        self.nullable = nullable

    def prepare_compare_value(self, compare):
        """Format the compare value; returning None skips the filter (unless nullable)."""
        return compare

    def apply(self, query, field, compare):
        """
        Apply the filter and return the resulting query.

        query: the query to filter
        field: fallback field, used when no explicit ``field`` was configured
        compare: the raw compare value (passed through prepare_compare_value)
        """
        compare = self.prepare_compare_value(compare)
        if compare is None and not self.nullable:
            return query
        if self._action is not None:
            return self._action(query, compare)
        # Explicit None checks: SQLAlchemy Core column/clause objects may raise
        # TypeError when evaluated in a boolean context.
        if self.field is not None:
            field = self.field
        return getattr(self, self.operation)(query, field, compare)

    def action(self, query, field, compare):
        """Default operation: keep rows where ``field`` equals ``compare``."""
        return query.filter(field == compare)
class CompareFilter(Filter):
    """
    Base filter class for (basic) comparisons.

    Choose the comparison through ``operation`` (for example
    ``operation='greater_than'``). Every operation takes
    ``(query, field, compare)`` and returns ``query`` narrowed by comparing
    ``field`` against ``compare``.
    """

    def equal_to(self, query, field, compare):
        """Keep rows where field == compare."""
        criterion = field == compare
        return query.filter(criterion)

    def not_equal_to(self, query, field, compare):
        """Keep rows where field != compare."""
        criterion = field != compare
        return query.filter(criterion)

    def greater_equal_to(self, query, field, compare):
        """Keep rows where field >= compare."""
        criterion = field >= compare
        return query.filter(criterion)

    def greater_than(self, query, field, compare):
        """Keep rows where field > compare."""
        criterion = field > compare
        return query.filter(criterion)

    def less_equal_to(self, query, field, compare):
        """Keep rows where field <= compare."""
        criterion = field <= compare
        return query.filter(criterion)

    def less_than(self, query, field, compare):
        """Keep rows where field < compare."""
        criterion = field < compare
        return query.filter(criterion)
class ListFilter(Filter):
    """
    Base filter that accepts comma separated lists as compare value.

    The prepared compare value is a list of items (strings are stripped), or
    None, which skips the filter.
    """

    def prepare_compare_value(self, compare):
        """
        Split ``compare`` on commas and strip every string item.

        Lists/tuples (e.g. from a JSON body) are accepted as already split;
        other non-string values are converted with ``str`` first. Returns None
        for empty input or for a single blank item.
        """
        if not compare:
            return None
        if isinstance(compare, (list, tuple)):
            raw_items = compare
        else:
            raw_items = str(compare).split(',')
        # Build a new list so the caller's sequence is never aliased/mutated.
        items = [c.strip() if isinstance(c, str) else c for c in raw_items]
        # Checked after stripping so whitespace-only input is treated as empty.
        if len(items) == 1 and items[0] in ('', None):
            return None
        return items
class RangeFilter(CompareFilter, ListFilter):
    """
    Base filter that filters for ranges.

    The compare value is either a single item (exact match via ``equal_to``)
    or a ``lower,upper`` pair whose boundaries are exclusive
    (``greater_than`` / ``less_than``). An empty boundary leaves that side of
    the range open; more than two items skips the filter.
    """

    def prepare_compare_value(self, compare):
        bounds = super().prepare_compare_value(compare)
        too_many = bool(bounds) and len(bounds) > 2
        return None if too_many else bounds

    def action(self, query, field, compare):
        if len(compare) == 1:
            return self.equal_to(query, field, compare[0])
        lower, upper = compare[0], compare[1]
        if lower:
            query = self.greater_than(query, field, lower)
        if upper:
            query = self.less_than(query, field, upper)
        return query
class InFilter(ListFilter):
    """
    Base filter that keeps rows whose field value is one of the compare values.

    The compare value is a comma separated list (prepared by ListFilter) and
    is applied as ``field.in_(values)``.
    """

    def action(self, query, field, compare):
        membership = field.in_(compare)
        return query.filter(membership)
class BooleanFilter(Filter):
    """
    Filter that compares boolean values.

    Accepts real booleans or textual spellings listed in TRUE_VALUES /
    FALSE_VALUES (case-insensitive, surrounding whitespace ignored).
    Unrecognised values skip the filter.
    """

    # Override in a subclass to accept other spellings.
    TRUE_VALUES = ('true', 'y', 'yes', 't', 'on', '1')
    FALSE_VALUES = ('false', 'n', 'no', 'f', 'off', '0')

    def prepare_compare_value(self, compare):
        if compare is None or isinstance(compare, bool):
            return compare
        # str() so non-string input (e.g. integers from JSON) no longer raises
        # AttributeError on .lower().
        normalized = str(compare).strip().lower()
        if normalized in self.TRUE_VALUES:
            return True
        if normalized in self.FALSE_VALUES:
            return False
        return None
class StringFilter(Filter):
    """
    Filter that compares string values.

    Operations: ``equal_to`` (the default; ``equals`` is an alias) and
    ``like``, which formats the compare value into ``template`` and applies
    ``field.like(...)``, e.g. ``StringFilter(operation='like', template='%{}%')``.
    """

    def __init__(self, template='{}', **kwargs):
        """
        template: format string used by ``like``; ``{}`` is replaced by the
            compare value
        **kwargs: passed on to Filter (``operation`` defaults to 'equal_to')
        """
        self.template = template
        # The previous default, 'equals', named no existing method, so applying
        # the filter with the default operation raised AttributeError.
        kwargs.setdefault('operation', 'equal_to')
        super().__init__(**kwargs)

    def prepare_compare_value(self, compare):
        """Convert to str; empty/falsy values skip the filter."""
        if not compare:
            return None
        return str(compare)

    def equal_to(self, query, field, compare):
        """Keep rows where field == compare."""
        return query.filter(field == compare)

    # Alias so an explicit operation='equals' keeps working.
    equals = equal_to

    def like(self, query, field, compare):
        """Keep rows where field LIKE template.format(compare)."""
        # NOTE(review): '%' and '_' in the compare value are not escaped, so
        # they act as LIKE wildcards.
        return query.filter(field.like(self.template.format(compare)))
class IntegerFilter(CompareFilter):
    """
    Filter that compares integer values.

    The compare value is converted with ``int()``; values it cannot convert
    (None, non-numeric strings, infinities, ...) skip the filter.
    """

    def prepare_compare_value(self, compare):
        try:
            return int(compare)
        except (TypeError, ValueError, OverflowError):
            # Not convertible to int: skip the filter
            return None
class ForeignKeyFilter(IntegerFilter):
    """
    Filter that compares foreign keys (only accepts non-zero compare values).

    Values that are not integers, or that convert to 0, skip the filter.
    """

    def prepare_compare_value(self, compare):
        # IntegerFilter already returns None for unconvertible input; 0 is
        # additionally mapped to None so it skips the filter too.
        key = super().prepare_compare_value(compare)
        return key if key else None
class DateTimeFilter(CompareFilter):
    """
    Filter that compares datetime values.

    The compare value is parsed with ``datetime.datetime.strptime`` using
    ``format``; missing or malformed values skip the filter.
    """

    def __init__(self, format='%Y-%m-%d', **kwargs):
        """
        format: strptime format of the compare value
        **kwargs: passed on to Filter
        """
        self.format = format
        super().__init__(**kwargs)

    def prepare_compare_value(self, compare):
        try:
            return datetime.datetime.strptime(compare, self.format)
        except (TypeError, ValueError):
            # TypeError: not a string (e.g. None); ValueError: wrong format.
            # Either way, skip the filter.
            return None
class DateTimeRangeFilter(RangeFilter):
    """
    Filter that compares datetime ranges.

    Accepts ``"date"`` or ``"from,to"`` (either side may be empty); each part
    is parsed with ``format``. A part that fails to parse becomes None (an
    open boundary); when no part parses, the filter is skipped.
    """

    def __init__(self, format='%Y-%m-%d', **kwargs):
        """
        format: strptime format of each range part
        **kwargs: passed on to Filter
        """
        self.format = format
        super().__init__(**kwargs)

    def prepare_compare_value(self, compare):
        compare = super().prepare_compare_value(compare)
        if not compare:
            return None
        parsed = [self._parse(part) for part in compare]
        # Without this check a single unparseable value became [None], which
        # RangeFilter.action applied as ``field == None`` instead of skipping.
        if all(p is None for p in parsed):
            return None
        return parsed

    def _parse(self, value):
        """Parse one range part; None when it is empty or malformed."""
        try:
            return datetime.datetime.strptime(value, self.format)
        except (TypeError, ValueError):
            return None
class FilteredQuery:
    """
    Base class for filtered queries.

    Subclasses declare Filter instances as attributes and set ``__model__``.
    Each filter receives the ``__model__`` attribute of the same name as its
    fallback field. The constructor applies every filter whose name appears
    in the data. It also records the requested ordering from
    ``data['order_by']``: a list/tuple of names or a comma separated string.
    A leading '-' on a name means descending order.
    """

    def __init__(self, query, data, default=None):
        """
        query: the query to filter
        data: mapping of filter names to compare values
        default: fallback mapping; entries in ``data`` take precedence. It is
            copied, never modified.
        """
        # Copy: updating ``default`` in place would leak ``data`` into a
        # mapping the caller may share between instances.
        self.data = dict(default) if default else {}
        self.data.update(data)
        self.query = self.apply(query)

        order_by = self.data.get('order_by')
        if isinstance(order_by, str):
            order_by = order_by.split(',')
        if isinstance(order_by, (list, tuple)):
            self._order = [o.strip() if isinstance(o, str) else o for o in order_by]
        else:
            self._order = None

    def apply(self, query):
        """
        Applies filters to query.

        Every attribute that is a Filter and whose name is a key in
        ``self.data`` is applied, in ``dir()`` (alphabetical) order.
        """
        for key in dir(self):
            if key not in self.data:
                continue
            value = getattr(self, key)
            compare = self.data[key]
            if isinstance(value, Filter):
                query = value.apply(query, self._get_field(key), compare)
        return query

    def order(self, query):
        """
        Orders query by the recorded ``order_by`` names.

        Names that do not resolve to a model attribute are ignored; ``query``
        is returned unchanged when no valid name remains.
        """
        if not self._order:
            return query
        args = []
        for name in self._order:
            if not isinstance(name, str) or not name:
                continue
            descending = name.startswith('-')
            field = self._get_field(name[1:] if descending else name)
            # Skip unknown names in both directions; previously an unknown
            # ascending name put None into the order_by() arguments.
            if field is None:
                continue
            args.append(field.desc() if descending else field)
        if args:
            return query.order_by(*args)
        return query

    @property
    def ordered_query(self):
        """The filtered query with the requested ordering applied."""
        return self.order(self.query)

    def _get_field(self, fieldname):
        """
        Gets field object based on __model__ property and the filter property name.

        Returns None when there is no ``__model__`` or it lacks the attribute.
        """
        # NOTE(review): fieldname can come from request data via order_by, and
        # any model attribute is returned, not only mapped columns.
        model = getattr(self, '__model__', None)
        if model is None:
            return None
        return getattr(model, fieldname, None)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment