Last active
September 17, 2015 20:54
-
-
Save ecederstrand/6748a3496acdc95e40ef to your computer and use it in GitHub Desktop.
Prefetch objects using the original filters rather than the IN clause used by prefetch_related()
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import logging
from collections import defaultdict

from django.db.models import Q
from django.db.models.fields.related import (
    ForeignRelatedObjectsDescriptor,
    ReverseManyRelatedObjectsDescriptor,
    ReverseSingleRelatedObjectDescriptor,
)

# Module-level logger; the prefetch helpers below report progress through log.debug().
log = logging.getLogger(__name__)
def to_tree(prefetch_fields):
    """Build a nested dict from 'foo__bar'-style prefetch paths.

    Each '__'-separated segment becomes a key whose value is the dict of its
    children, e.g. ['a__b', 'a__c', 'd'] -> {'a': {'b': {}, 'c': {}}, 'd': {}}.
    Repeated or overlapping paths are merged into the same branches.
    """
    root = {}
    for path in prefetch_fields:
        node = root
        for segment in path.split('__'):
            # Descend, creating the child level the first time it is seen
            node = node.setdefault(segment, {})
    return root
def flatten(tree):
    """Turn a nested prefetch tree back into a list of 'foo__bar'-style paths.

    Every node contributes its own path, immediately followed by the paths of
    all its descendants, e.g. {'a': {'b': {}}, 'c': {}} -> ['a', 'a__b', 'c'].
    """
    paths = []
    for key in tree:
        paths.append(key)
        paths.extend(['%s__%s' % (key, sub_path) for sub_path in flatten(tree[key])])
    return paths
def prefix_args(related_field, args=(), kwargs=None):
    """Rewrite filter arguments so they can be applied through a relation.

    Every lookup in ``args`` (Q objects, including nested ones) and in ``kwargs``
    is prefixed with ``related_field + '__'``. Filters written against the
    original model can then be used on a model related to it.

    :param related_field: lookup path from the related model back to the original model.
    :param args: Q objects as passed to filter()/exclude() on the original model.
    :param kwargs: keyword lookups as passed to filter()/exclude(); None means no lookups.
    :return: (new_args, new_kwargs), a list of prefixed Q clones and a dict of prefixed lookups.
    """
    # Previously the prefixed clones were never collected, so Q filters were silently dropped
    new_args = [_prefix_q(arg, related_field) for arg in args]
    new_kwargs = {'%s__%s' % (related_field, k): v for k, v in (kwargs or {}).items()}
    log.debug('Prefetch args/kwargs for %s: %s %s', related_field, new_args, new_kwargs)
    return new_args, new_kwargs


def _prefix_q(q, related_field):
    """Return a clone of Q object q with every lookup, recursively, prefixed by related_field."""
    assert isinstance(q, Q)
    clone = q.clone()
    # A Q's children are either (lookup, value) tuples or nested Q objects (from q1 | q2, ~q, ...);
    # the nested ones must be rewritten recursively rather than tuple-unpacked.
    clone.children = [
        _prefix_q(child, related_field) if isinstance(child, Q)
        else ('%s__%s' % (related_field, child[0]), child[1])
        for child in q.children
    ]
    return clone
def prefetch(model, prefetch_fields, filter_args=(), filter_kwargs=None, exclude_args=(), exclude_kwargs=None):
    """Fetch `model` objects, re-using their filters (not an IN clause) to prefetch related objects.

    QuerySet.prefetch_related() fetches related rows with an IN (<pk list>) clause. This function instead
    applies filter_args/filter_kwargs and exclude_args/exclude_kwargs to each related query, prefixed with the
    lookup path back to `model`. It stores the results in each object's _prefetched_objects_cache.

    :param model: model class to query.
    :param prefetch_fields: relation names to prefetch; 'foo__bar' form select_related()s deeper levels.
    :param filter_args: Q objects passed to filter().
    :param filter_kwargs: lookups passed to filter(); None means none.
    :param exclude_args: Q objects passed to exclude().
    :param exclude_kwargs: lookups passed to exclude(); None means none.
    :return: the (already evaluated) QuerySet of `model` objects.
    :raises AssertionError: if a prefetch field is not a supported relation descriptor.

    Per object, the cache maps:
      - reverse FK and M2M names to a set of objects (names ending in '_set' are stored without the suffix);
      - each M2M's through-model name to its set of through rows;
      - forward FK names to the single target object, or None if the FK is null or the target was
        not returned by the related query.

    TODO: For now, only supports recursion when the lower levels are FK relations, not M2M relations
    """
    if not filter_kwargs:
        filter_kwargs = {}
    if not exclude_kwargs:
        exclude_kwargs = {}
    filters = (filter_args, filter_kwargs, exclude_args, exclude_kwargs)
    log.debug('Prefetching %s on %s', prefetch_fields, model.__name__)
    res = model.objects.filter(*filter_args, **filter_kwargs).exclude(*exclude_args, **exclude_kwargs)
    for r in res:
        r._prefetched_objects_cache = {}
    related = {}
    # name -> attname of the FK on `model`, for forward FK entries (looked up by FK value, not by r.pk)
    fk_attnames = {}
    for name, subtree in to_tree(prefetch_fields).items():
        if name in related:
            # Already filled in as a by-product of an earlier reverse-FK prefetch
            continue
        log.debug('Getting related %s on %s', name, model.__name__)
        descriptor = getattr(model, name)
        select_fields = flatten(subtree)
        if isinstance(descriptor, ForeignRelatedObjectsDescriptor):
            related.update(_prefetch_reverse_fk(name, descriptor, select_fields, prefetch_fields, related, filters))
        elif isinstance(descriptor, ReverseManyRelatedObjectsDescriptor):
            related.update(_prefetch_m2m(name, descriptor, select_fields, filters))
        elif isinstance(descriptor, ReverseSingleRelatedObjectDescriptor):
            related.update(_prefetch_fk(name, descriptor, select_fields, filters))
            fk_attnames[name] = descriptor.field.attname
        else:
            # Explicit raise: a bare `assert False` would vanish under python -O
            raise AssertionError('Unsupported prefetch field %s' % descriptor)
    for r in res:
        cache = r._prefetched_objects_cache
        for name, lookup in related.items():
            if name in fk_attnames:
                value = lookup.get(getattr(r, fk_attnames[name]))
            else:
                value = lookup[r.pk]
            # Prefetched through models end with '_set' but go into _prefetched_objects_cache without the '_set'
            cache[name[:-4] if name.endswith('_set') else name] = value
    return res


def _related_queryset(related_model, reverse_field, select_fields, filters):
    """Query related_model for rows linked, via lookup path reverse_field, to objects matching the filters.

    filters is the (filter_args, filter_kwargs, exclude_args, exclude_kwargs) tuple given to prefetch().
    Every lookup in it is prefixed with reverse_field so it is evaluated from related_model's side.
    """
    filter_args, filter_kwargs, exclude_args, exclude_kwargs = filters
    fargs, fkwargs = prefix_args(reverse_field, filter_args, filter_kwargs)
    xargs, xkwargs = prefix_args(reverse_field, exclude_args, exclude_kwargs)
    log.debug('Getting select_related %s for %s', select_fields, related_model)
    return (related_model.objects
            .select_related(*select_fields)
            .filter(*fargs, **fkwargs)
            .exclude(*xargs, **xkwargs))


def _prefetch_reverse_fk(name, descriptor, select_fields, prefetch_fields, already_fetched, filters):
    """Prefetch a reverse FK relation on 'model' (aka. foorel_set), e.g. the through model of an M2M.

    Returns {prefetch name: {model pk: set of objects}}. Besides `name` itself, the result can include
    entries for FK fields on the related model. These are fields whose pluralized name ('<fk>s') was also
    requested in prefetch_fields and is not in already_fetched. They are select_related() in the same query
    instead of being fetched separately.
    """
    rel = descriptor.related
    related_model = rel.related_model
    id_field = rel.field.attname  # FK column on related_model pointing back at 'model'
    # FK fields on the related (through) model
    fk_names = [fld.name for fld in related_model._meta.local_fields if fld.rel]
    # Distinct names from the caller's loop variable: Python 2 list comprehensions leak their variable,
    # which previously clobbered the relation name used as the key for the results.
    extra_fields = [
        fk for fk in fk_names
        if fk + 's' in prefetch_fields and fk + 's' not in already_fetched and fk + 's' != name
    ]
    result = {name: defaultdict(set)}
    for fk in extra_fields:
        result[fk + 's'] = defaultdict(set)
    queryset = _related_queryset(related_model, rel.field.name, list(select_fields) + extra_fields, filters)
    for obj in queryset:
        model_pk = getattr(obj, id_field)
        result[name][model_pk].add(obj)
        for fk in extra_fields:
            target = getattr(obj, fk)
            if target is not None:  # a null FK has no object to prefetch
                result[fk + 's'][model_pk].add(target)
    return result


def _prefetch_m2m(name, descriptor, select_fields, filters):
    """Prefetch a forward M2M relation via its through model.

    Returns entries for the M2M targets (under `name`) and for the through rows
    (under '<throughmodel>_set'), each shaped {model pk: set of objects}.
    """
    m2m = descriptor.field
    through_model = descriptor.through
    reverse_field = m2m.m2m_field_name()  # FK on the through model pointing back at 'model'
    # Use that FK's real attname; deriving it from model.__name__ breaks self-referential/custom-through M2M
    reverse_id_field = through_model._meta.get_field(reverse_field).attname
    select_field = m2m.m2m_reverse_field_name()  # FK on the through model pointing at the M2M target
    through_name = through_model.__name__.lower() + '_set'
    targets = defaultdict(set)
    through_rows = defaultdict(set)
    queryset = _related_queryset(through_model, reverse_field, list(select_fields) + [select_field], filters)
    for obj in queryset:
        model_pk = getattr(obj, reverse_id_field)
        through_rows[model_pk].add(obj)
        targets[model_pk].add(getattr(obj, select_field))
    return {name: targets, through_name: through_rows}


def _prefetch_fk(name, descriptor, select_fields, filters):
    """Prefetch a forward FK relation on 'model'. Returns {name: {target pk: target object}}."""
    fk = descriptor.field
    # rel.to is the FK's target model (in Django 1.8, rel.related_model is the model holding the FK).
    # related_query_name() falls back to the default reverse name when related_name is unset; the bare
    # rel.related_name is None then and produced a bogus 'None__...' lookup.
    # NOTE(review): this filter/exclude spans a multi-valued (reverse) relation, so exclude() may drop
    # targets still referenced by some results; prefetch() caches those as None.
    queryset = _related_queryset(fk.rel.to, fk.related_query_name(), select_fields, filters)
    # NOTE(review): keyed by pk, so FKs declared with a non-pk to_field will not be found — confirm if used.
    return {name: {obj.pk: obj for obj in queryset}}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment