Created
November 14, 2023 01:51
-
-
Save aynik/b9455f8dd512aff8f54b7466cc2c2d3a to your computer and use it in GitHub Desktop.
Save an entire graph of dicts as Django model instances
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from copy import copy
from dataclasses import asdict
from typing import Any, Dict, Optional, cast

from django.conf import settings
from django.db import models
from django.db.models import Q
from inflection import pluralize
def categorize_fields(model):
    """Classify the attributes of a Django model class for graph saving.

    Every entry of ``model.__dict__`` is inspected for a relation object
    (``attr.rel`` or ``attr.related``) and for the remote side of its own
    field (``attr.field.remote_field``).

    Args:
        model: A Django model class.

    Returns:
        A ``(related_fields, direct_fields)`` tuple of dicts keyed by
        attribute name:

        * ``related_fields`` maps to the relation object for one-to-one
          relations (registered under both ``name`` and ``name_id``),
          many-to-one relations and non-symmetrical many-to-many relations.
        * ``direct_fields`` maps to the remote ``ManyToOneRel`` for
          attributes whose own field points at another model, or to
          ``None`` for the remaining names listed in ``model._meta.fields``.
    """
    related_fields = {}
    direct_fields = {}
    # A set keeps the membership test in the loop O(1).
    model_fields = {field.name for field in model._meta.fields}

    for name, attr in model.__dict__.items():
        remote_relation = getattr(getattr(attr, 'field', None), 'remote_field', None)
        relation = getattr(attr, 'rel', None) or getattr(attr, 'related', None)

        # OneToOneRel is tested first: it subclasses ManyToOneRel, so the
        # order of these branches decides which bucket a one-to-one lands in.
        if isinstance(relation, models.OneToOneRel):
            related_fields[name] = relation
            related_fields[f'{name}_id'] = relation
        elif isinstance(relation, models.ManyToOneRel):
            related_fields[name] = relation
        elif isinstance(relation, models.ManyToManyRel) and not relation.symmetrical:
            related_fields[name] = relation
        elif isinstance(remote_relation, models.ManyToOneRel):
            direct_fields[name] = remote_relation
        elif name in model_fields:
            direct_fields[name] = None

    return related_fields, direct_fields
def exclude_field(q: Q, field: str):
    """Return a shallow copy of *q* without its top-level lookups on *field*.

    Only direct ``(lookup, value)`` children whose lookup equals *field* are
    dropped. Children that are not such pairs (for example nested ``Q``
    nodes) are kept untouched instead of being indexed, which would raise
    ``TypeError``.

    Args:
        q: The ``Q`` object to filter; it is not modified.
        field: Lookup name to remove.

    Returns:
        A copy of *q* with a new ``children`` list.
    """
    new_q = copy(q)
    new_q.children = [
        child
        for child in q.children
        if not (isinstance(child, (tuple, list)) and child[0] == field)
    ]
    return new_q
def extract_field(q: Q, field: str):
    """Return the first top-level ``(lookup, value)`` child of *q* for *field*.

    Children that are not ``(lookup, value)`` pairs (for example nested
    ``Q`` nodes) are skipped rather than indexed, which would raise
    ``TypeError``.

    Args:
        q: The ``Q`` object to search.
        field: Lookup name to find.

    Returns:
        The matching child, or ``None`` when *q* has no such child.
    """
    return next(
        (
            child
            for child in q.children
            if isinstance(child, (tuple, list)) and child[0] == field
        ),
        None,
    )
def release_other_child(obj, relation, **input):
    """Detach the row that currently occupies *obj*'s one-to-one slot.

    When *relation* is a ``OneToOneRel``, the parent instance is read from
    ``input[relation.field.name]``. If that parent already has a different
    child on the reverse accessor, the child's link is set to ``None`` and
    the child is saved, freeing the slot for *obj*.

    Args:
        obj: The instance about to be attached to the parent.
        relation: The relation through which *obj* is being saved; anything
            other than a ``OneToOneRel`` (including ``None``) is a no-op.
        **input: Field values being assigned to *obj*.
    """
    if not isinstance(relation, models.OneToOneRel):
        return

    parent_obj = input.get(relation.field.name)
    # ``related_name`` is None when the field declares none, and passing
    # None to hasattr() raises TypeError; the accessor name covers both cases.
    accessor = relation.get_accessor_name()
    # The reverse one-to-one accessor raises an AttributeError subclass when
    # no row exists, so hasattr() doubles as the existence check.
    if accessor is None or not hasattr(parent_obj, accessor):
        return

    child_obj = getattr(parent_obj, accessor)
    if child_obj is not None and child_obj.pk != obj.pk:
        # NOTE(review): assumes the one-to-one column is nullable - confirm.
        setattr(child_obj, relation.field.name, None)
        child_obj.save()
def save_graph(manager, **input) -> Optional[models.Model]:
    """Create or update one instance and, recursively, its related rows.

    Keys of *input* are sorted with ``categorize_fields(manager.model)``
    into values set directly on the instance and values describing related
    rows; keys in neither category are ignored.

    Special handling of *input*:

    * A value of ``None`` is skipped; a value equal to
      ``settings.GRAPH_SET_NONE`` is stored as ``None``.
    * ``None`` items are removed from list values.
    * A key ending in ``_ids`` is renamed to the plural of its stem
      (``tag_ids`` -> ``tags``) and merged with list values for that key.
    * ``pk``: a ``Q`` identifying the row to update; defaults to
      ``Q(id=<the 'id' value>)``.
    * ``_clone``: when true, a matched row is saved as a copy (its pk is
      cleared before ``save()``) and no one-to-one slot is released.
    * ``_relation``: the relation through which this call was reached; set
      by the recursive calls.

    For reverse one-to-one and many-to-one relations each value may be an
    int (pk of an existing row), a dict, or a dataclass instance, and is
    saved through a recursive call. Afterwards, rows still pointing at the
    instance that were not part of this call are deleted. For many-to-many
    relations the current links are cleared and every int value is added.

    Args:
        manager: Manager of the model to save.
        **input: Field values and control keys as described above.

    Returns:
        The saved instance, or ``None`` when a pk was supplied but no row
        matched it.
    """
    related_fields, direct_fields = categorize_fields(manager.model)
    related_input: Dict[Any, Any] = {}
    direct_input: Dict[Any, Any] = {}

    for field, value in input.items():
        if value is None:
            continue
        if value == settings.GRAPH_SET_NONE:
            value = None
        if isinstance(value, list):
            value = [item for item in value if item is not None]
        if field.endswith('_ids'):
            # Drop the "_ids" suffix and pluralize the stem.
            field = pluralize(field[:-4])
        if field in related_fields:
            if isinstance(value, list):
                related_input.setdefault(field, []).extend(value)
            else:
                related_input[field] = value
        elif field in direct_fields:
            direct_input[field] = value

    _clone = input.get('_clone', False)
    # NOTE(review): the default lookup hard-codes "id" while the check below
    # uses the model's real pk name - confirm for models with a custom pk.
    pk = input.get('pk', Q(id=direct_input.get('id')))
    pk_field = extract_field(pk, manager.model._meta.pk.name)

    if pk_field and pk_field[1] is not None:
        relation = cast(models.ForeignObjectRel, input.get('_relation'))
        obj = manager.filter(pk).first()
        if obj is None:
            # Nothing to update. Stopping here prevents the relation loop
            # below from running against a missing parent, where it would
            # fail on ``getattr(None, ...)`` or filter children on
            # ``<fk>=None`` and delete unrelated orphan rows.
            return None
        if not _clone:
            release_other_child(obj, relation, **direct_input)
        for field, value in direct_input.items():
            setattr(obj, field, value)
        if _clone:
            # Clearing the pk and flagging the instance as new makes save()
            # write a copy instead of updating the matched row.
            obj.pk = None
            obj._state.adding = True
        obj.save()
    else:
        obj = manager.create(**direct_input)

    for field, values in related_input.items():
        pk = Q()
        relation = cast(models.ForeignObjectRel, related_fields[field])
        if isinstance(relation, models.ManyToManyRel):
            getattr(obj, relation.field.name).clear()

        processed = []
        if not isinstance(values, list):
            values = [values]

        for value in values:
            if value is None:
                continue
            if isinstance(relation, (models.OneToOneRel, models.ManyToOneRel)):
                if isinstance(value, int):
                    related_pk = relation.related_model._meta.pk
                    if not related_pk:
                        continue
                    # A bare int refers to an existing related row by pk.
                    value = dict(pk=Q(**{related_pk.name: value}))
                if not isinstance(value, dict):
                    value = asdict(value)
                # The last child's pk Q is reused after the loop to scope
                # the clean-up query.
                pk = value.get('pk', Q())
                processed.append(
                    save_graph(
                        relation.related_model.objects,
                        **{
                            '_clone': _clone,
                            '_relation': relation,
                            **value,
                            relation.field.name: obj,
                        },
                    )
                )
            else:
                if not isinstance(value, int):
                    continue
                getattr(obj, relation.field.name).add(value)

        if isinstance(relation, models.ManyToOneRel):
            related_pk = relation.related_model._meta.pk
            if related_pk:
                pk_filter = exclude_field(pk, related_pk.name)
            else:
                pk_filter = pk
            # Delete rows still attached to obj that this call did not save.
            relation.related_model.objects.filter(
                Q(**{relation.field.name: obj}) & pk_filter
            ).exclude(
                pk__in=[child.pk for child in processed if child is not None]
            ).delete()

    return obj
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment