Last active
March 28, 2016 11:49
-
-
Save smahs/f9d68d1f61869f535301 to your computer and use it in GitHub Desktop.
Micro REST Framework for Django: A Prototype
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from django.views.generic import View | |
from django.http import ( | |
QueryDict, HttpResponse, HttpResponseBadRequest, | |
HttpResponseNotFound, HttpResponseNotAllowed, | |
) | |
from django.core.serializers.json import DjangoJSONEncoder | |
from django.core.exceptions import ( | |
ObjectDoesNotExist, ValidationError, | |
) | |
from django.db import transaction | |
from django.db.models import ( | |
Model, ForeignKey, ManyToManyField, Q, FieldDoesNotExist, | |
) | |
from django.db.models.fields.related import RelatedField | |
from json import dumps, loads | |
from collections import defaultdict | |
from functools import partial | |
class Unauthorised(Exception):
    """
    Signals that a request failed authentication.

    Auth classes raise it from ``process_request``; RestBaseView.dispatch
    catches it and answers with an empty HTTP 401 response.
    """
class InputValidation(object):
    """
    A descriptor-based decorator that validates request input.

    It runs after View.dispatch has selected the handler but before the
    handler body executes. The request payload is parsed against
    ``view.model_class``: the query string for GET, a JSON body for every
    other method. Primary-key presence is then checked.

    On success the parsed data is exposed to the handler as
    ``view.params``. Any ValidationError, whether raised here or by the
    handler, is turned into an HTTP 400 response.

    A single decorator instance lives on the view class and is therefore
    shared by every request that reaches the decorated method. Per-request
    state (view, request, data) is kept on a fresh copy created in
    ``__call__``. This stops concurrent requests from overwriting each
    other's state.
    """

    def __init__(self, f):
        self._func = f

    def get_data(self):
        """
        Parse the current request into a params dict.

        Raises:
            ValidationError: if the body is not valid JSON or does not
                match the model's fields.
        """
        if self.request.method == 'GET':
            self.data = self.request.GET
            return self.parse_form(self.view.model_class, self.data.keys())
        try:
            data = loads(self.request.body)
            return self.parse_data(self.view.model_class, data)
        except (TypeError, ValueError):
            raise ValidationError({'input': ['parsing failed']})

    def parse_data(self, klass, hmap):
        """
        Convert a JSON-decoded dict into typed params for ``klass``.

        All HTTP methods except GET send the object data in request.body as
        JSON. That JSON is expected to mirror the model's structure: scalar
        values map to local fields, and nested objects (or lists of objects)
        map to related fields. Related values are parsed recursively against
        the related model. A one-element list collapses to its single item.

        Raises:
            ValidationError: for unknown fields, values the field cannot
                convert, or a non-object where an object is expected.
        """
        if not isinstance(hmap, dict):
            # Non-object JSON (e.g. a bare list, or a list of scalars given
            # for a relation) cannot be matched against model fields.
            raise ValidationError({'input': ['parsing failed']})
        local, rels = self.view.__class__.split_params(hmap)
        params = dict()
        for key in local:
            try:
                field = klass._meta.get_field(key)
                if isinstance(field, RelatedField):
                    raise FieldDoesNotExist()
                params[key] = field.to_python(local[key])
            except (FieldDoesNotExist, ValidationError, ValueError):
                raise ValidationError({key: ['invalid input data']})
        for key in rels:
            try:
                field = klass._meta.get_field(key)
                if not isinstance(field, RelatedField):
                    raise FieldDoesNotExist()
            except FieldDoesNotExist:
                raise ValidationError({key: ['invalid input data']})
            value = rels[key]
            vals = value if isinstance(value, list) else [value]
            store = [self.parse_data(field.related.parent_model, i)
                     for i in vals]
            params[key] = store[0] if len(store) == 1 else store
        return params

    def parse_form(self, klass, keys, sup=None):
        """
        Convert URL-encoded GET keys into typed params for ``klass``.

        Query fields follow Django's '__' notation for model relationships.
        For example, comment=0&user__id=1 selects all comments for user 1.
        ``sup`` is the '__'-joined prefix of the relation being parsed. It
        is used to look the raw values up in the full query string.

        Returns:
            dict: local field -> list of converted values; relation name ->
            nested dict of the same shape.

        Raises:
            ValidationError: for unknown fields or unconvertible values.
        """
        local, rels = self.view.__class__.split_fields(keys)
        rels = self.view.__class__.tokenize(rels)
        params = dict()
        for key in local:
            supkey = sup + '__' + key if sup else key
            vals = self.data.getlist(supkey)
            try:
                field = klass._meta.get_field(key)
                if isinstance(field, RelatedField):
                    raise FieldDoesNotExist()
                params[key] = [field.to_python(i) for i in vals]
            except (FieldDoesNotExist, ValidationError, ValueError):
                raise ValidationError({key: ['Invalid input data']})
        for key in rels:
            try:
                field = klass._meta.get_field(key)
                if not isinstance(field, RelatedField):
                    raise FieldDoesNotExist()
            except FieldDoesNotExist:
                raise ValidationError({key: ['Invalid input data']})
            # Build this relation's prefix in a new name. Reassigning ``sup``
            # would leak the prefix into the next sibling relation.
            prefix = sup + '__' + key if sup else key
            params[key] = self.parse_form(field.related.parent_model,
                                          rels[key], sup=prefix)
        return params

    def check_pk(self, field, data):
        """
        Ensure ``data`` contains the primary-key field ``field``.

        An explicit check replaces ``assert``, which is stripped under -O.

        Raises:
            ValidationError: if the key is missing.
        """
        if field.name not in data:
            raise ValidationError({'pk': ['validation failed']})

    def validate_pk(self, klass, data):
        """
        Validate the presence of primary keys in the input data.

        The object's own pk is required for every method except POST (which
        creates the object). Every related object supplied for a local FK
        or m2m field must carry the related model's pk.
        TODO: Allow a way to skip this validation.
        """
        local_pk = klass._meta.pk
        local_rels = [i for i in klass._meta.local_fields
                      if isinstance(i, RelatedField)]
        m2m_rels = klass._meta.local_many_to_many
        # Compare by value. ``is not`` tests identity and is True for an
        # equal but distinct string such as request.method, which would
        # demand a pk on every POST.
        if self.request.method != 'POST':
            self.check_pk(local_pk, data)
        for field in local_rels + list(m2m_rels):
            kls = field.related.parent_model
            val = data.get(field.name, None)
            if isinstance(val, list):
                for datum in val:
                    self.check_pk(kls._meta.pk, datum)
            elif isinstance(val, dict):
                self.check_pk(kls._meta.pk, val)

    def __call__(self, *args, **kwargs):
        # This instance is a class attribute shared by all requests (and
        # threads). Run each call on its own copy so per-request attributes
        # are never shared.
        handler = self.__class__(self._func)
        return handler._run(*args, **kwargs)

    def _run(self, *args, **kwargs):
        """Validate input for one request, then invoke the wrapped handler."""
        self.view = args[0]
        self.request = args[1]
        try:
            params = self.get_data()
            self.validate_pk(self.view.model_class, params)
            setattr(self.view, 'params', params)
            return self._func(*args, **kwargs)
        except ValidationError as ve:
            message = self.view.cleaning_errors(ve)
            return HttpResponseBadRequest(message)

    def __get__(self, obj, objtype=None):
        # Bind the view instance as the first argument, like a method.
        return partial(self.__call__, obj)
class RestBaseView(View):
    """
    A boilerplate class providing REST-style extensions to Django's View.

    Subclasses are expected to define:

    * ``model_class``: the Django model class the view exposes.
    * ``methods``: the HTTP method names the view accepts.
    * ``return_fields``: names to serialize. Related fields use Django's
      ``__`` notation (e.g. ``'owner__id'``).

    An optional ``auth_class`` is instantiated for each request. Its
    ``process_request(request)`` may raise ``Unauthorised``, which is
    answered with HTTP 401.
    """

    def __init__(self, *args, **kwargs):
        # getattr: a subclass that forgets model_class gets the intended
        # TypeError rather than an AttributeError.
        model_class = getattr(self, 'model_class', None)
        if not (model_class and Model in model_class.mro()):
            raise TypeError('Model class not defined or not supported')
        super(RestBaseView, self).__init__(*args, **kwargs)

    def dispatch(self, request, *args, **kwargs):
        """Enforce ``methods`` (405) and ``auth_class`` (401), then dispatch."""
        if request.method not in self.methods:
            # HttpResponseNotAllowed takes the list of permitted methods and
            # renders it into the Allow header. Passing a message string
            # would list one "method" per character.
            return HttpResponseNotAllowed(self.methods)
        auth_class = getattr(self, 'auth_class', None)
        if auth_class is not None:
            try:
                auth_class().process_request(request)
            except Unauthorised:
                return HttpResponse(status=401)
        return super(RestBaseView, self).dispatch(request, *args, **kwargs)

    # ---- Class methods for generic algorithms ----

    @staticmethod
    def _is_nested(value):
        """Return True for container values (dict, list, ...) but not strings."""
        # Python 2 strings have no __iter__ but Python 3 strings do. Exclude
        # them explicitly so scalars are classified the same on both.
        return (hasattr(value, '__iter__')
                and not isinstance(value, (str, bytes)))

    @classmethod
    def tokenize(cls, arr, sep='__'):
        """
        Group ``head<sep>rest`` strings by their first segment.

        ``['owner__id', 'owner__name']`` -> ``{'owner': ['id', 'name']}``.
        Every item must contain ``sep``.
        """
        d = defaultdict(list)
        for i in arr:
            j, k = i.split(sep, 1)
            d[j].append(k)
        return d

    @classmethod
    def split_fields(cls, fields_list):
        """Split names into (plain names, ``__``-qualified names)."""
        local = [i for i in fields_list if '__' not in i]
        rels = [i for i in fields_list if '__' in i]
        return (local, rels)

    @classmethod
    def split_params(cls, params):
        """Split a dict into (scalar-valued items, container-valued items)."""
        local = {k: v for k, v in params.items() if not cls._is_nested(v)}
        rels = {k: v for k, v in params.items() if cls._is_nested(v)}
        return (local, rels)

    @classmethod
    def flatten(cls, hmap):
        """
        Collapse nested dicts into ``__``-joined keys.

        ``{'owner': {'id': [1]}}`` -> ``{'owner__id': [1]}``. List values
        stay under their key, and container items inside them are
        flattened in turn.
        """
        def process():
            for key, value in hmap.items():
                if isinstance(value, dict):
                    for subkey, subvalue in cls.flatten(value).items():
                        yield key + "__" + subkey, subvalue
                elif isinstance(value, list):
                    yield key, [cls.flatten(v) if cls._is_nested(v) else v
                                for v in value]
                else:
                    yield key, value
        return dict(process())

    # ---- Utility methods for View classes ----

    def send_json(self, data):
        """Return ``data`` as a JSON HttpResponse."""
        return HttpResponse(dumps(data, cls=DjangoJSONEncoder),
                            content_type='application/json')

    def send_error(self, code, message):
        """Return ``{"error": message}`` as JSON with HTTP status ``code``."""
        return HttpResponse(dumps({'error': message}), status=code,
                            content_type='application/json')

    def cleaning_errors(self, exc):
        """
        Render a ValidationError as text.

        Dict-style errors become one ``field: messages`` line per field.
        List-style errors (which have no ``message_dict``) become their
        messages joined by spaces. Returns None for other exception types.
        """
        if isinstance(exc, ValidationError):
            if hasattr(exc, 'error_dict'):
                return '\n'.join([k + ': ' + ' '.join(v) for k, v
                                  in exc.message_dict.items()])
            return ' '.join(exc.messages)

    # ---- Serialization methods ----

    def get_serializable(self, val):
        """Return ``val`` if it is JSON-encodable, otherwise ``str(val)``."""
        try:
            dumps(val, cls=DjangoJSONEncoder)
            return val
        except (ValueError, TypeError):
            return str(val)

    def serialize_local(self, obj, fields):
        """Serialize plain attributes ``fields`` of ``obj``."""
        return {i: self.get_serializable(getattr(obj, i))
                for i in fields}

    def serialize_related(self, obj, fields):
        """Serialize ``__``-qualified FK / m2m fields of ``obj``."""
        attributes = RestBaseView.tokenize(fields)
        out = dict()
        for name, attr in attributes.items():
            field = obj.__class__._meta.get_field(name)
            if isinstance(field, ForeignKey):
                out.update(self.serialize_fk(obj, field, attr))
            elif isinstance(field, ManyToManyField):
                out.update(self.serialize_m2m(obj, field, attr))
        return out

    def serialize_fk(self, obj, field, attr):
        """Serialize FK ``field`` of ``obj`` (None when unset)."""
        val = field.value_from_object(obj)
        if val is None:
            out = None
        elif len(attr) == 1 and 'id' in attr:
            # Only the id was requested: avoid fetching the related row.
            out = {'id': val}
        else:
            sub = getattr(obj, field.name)
            local, rels = RestBaseView.split_fields(attr)
            out = self.serialize_local(sub, local)
            out.update(self.serialize_related(sub, rels))
        return {field.name: out}

    def serialize_m2m(self, obj, field, attr):
        """Serialize m2m ``field`` of ``obj`` as a list of dicts."""
        local, rels = RestBaseView.split_fields(attr)
        # Local values and nested parts come from two separate queries that
        # are zipped together. Order both by pk so the rows line up.
        related = getattr(obj, field.name).order_by('pk')
        out = [dict(zip(local, row)) for row in related.values_list(*local)]
        if rels:
            nested = [self.serialize_related(i, rels) for i in related]
            out = [dict(base, **more) for base, more in zip(out, nested)]
        return {field.name: out}

    def serialize(self, obj):
        """Serialize ``obj`` according to ``self.return_fields``."""
        local, rels = RestBaseView.split_fields(self.return_fields)
        serialized = self.serialize_local(obj, local)
        related = self.serialize_related(obj, rels)
        if related:
            serialized.update(related)
        return serialized

    # ---- Database writes for relations ----

    def update_m2m(self, obj, field, vals):
        """Make the m2m ``field`` of ``obj`` link exactly the pks in ``vals``."""
        attr = getattr(obj, field.name)
        pkname = field.related.parent_model._meta.pk.name
        ids = attr.values_list(pkname, flat=True)
        intersect = set(vals).intersection(set(ids))
        add = set(vals) - intersect
        rem = set(ids) - intersect
        for i in rem:
            attr.remove(i)
        for i in add:
            attr.add(i)

    def update_related(self, obj, rels):
        """
        Apply parsed relation data ``rels`` to ``obj``.

        Raises:
            ValidationError: if an FK value is not a single object.
        """
        for name, val in rels.items():
            field = obj.__class__._meta.get_field(name)
            if isinstance(field, ForeignKey):
                if not isinstance(val, dict):
                    raise ValidationError({name: ['invalid input data']})
                pkname = field.related_field.name
                setattr(obj, field.attname, val.get(pkname))
            elif isinstance(field, ManyToManyField):
                pkname = field.related.parent_model._meta.pk.name
                if isinstance(val, list):
                    vals = [i.get(pkname) for i in val if pkname in i]
                elif isinstance(val, dict):
                    vals = [val.get(pkname)]
                else:
                    # Nothing usable to link.
                    continue
                self.update_m2m(obj, field, vals)

    # ---- DB fetch, override to custom gets ----

    def get_records(self, params):
        """
        Query ``model_class`` with parsed GET-style ``params``.

        A value list containing 0 drops that filter (0 acts as "any").
        Multi-value lists become ``__in`` lookups.
        """
        params = RestBaseView.flatten(params)
        params = {k: v for k, v in params.items() if 0 not in v}
        params = {k + '__in' if len(v) > 1 else k: v
                  for k, v in params.items()}
        params = {k: v[0] if len(v) == 1 else v
                  for k, v in params.items()}
        return self.model_class.objects.filter(Q(**params))

    # ---- HTTP methods ----

    @InputValidation
    def get(self, request, *args, **kwargs):
        try:
            objs = self.get_records(self.params)
            out = [self.serialize(i) for i in objs]
            return self.send_json({self.model_class.__name__.lower(): out})
        except ObjectDoesNotExist:
            return HttpResponseNotFound('Object not found')

    @InputValidation
    def post(self, request, *args, **kwargs):
        try:
            # The atomic block sits inside the try. A ValidationError then
            # leaves the block and rolls back anything written so far (the
            # early save, m2m links) before it becomes a 400. Catching it
            # inside the transaction would commit those writes.
            with transaction.atomic():
                local, rels = RestBaseView.split_params(self.params)
                obj = self.model_class(**local)
                if self.model_class._meta.local_many_to_many:
                    # m2m links need a primary key, so persist first.
                    obj.save()
                self.update_related(obj, rels)
                obj.full_clean()
                obj.save()
            return self.send_json(self.serialize(obj))
        except ValidationError as e:
            return HttpResponseBadRequest(self.cleaning_errors(e))

    @InputValidation
    def put(self, request, *args, **kwargs):
        pkname = self.model_class._meta.pk.name
        try:
            with transaction.atomic():
                local, rels = RestBaseView.split_params(self.params)
                pk = local.pop(pkname)
                # Re-filter on pk: get_records treats 0 as "any", which
                # would otherwise select an arbitrary row to update.
                obj = (self.get_records({pkname: [pk]})
                       .filter(pk=pk).first())
                if obj is None:
                    raise ObjectDoesNotExist()
                for key, val in local.items():
                    setattr(obj, key, val)
                self.update_related(obj, rels)
                obj.full_clean()
                obj.save()
            return self.send_json(self.serialize(obj))
        except ObjectDoesNotExist:
            return HttpResponseNotFound('Object not found')
        except ValidationError as e:
            return HttpResponseBadRequest(self.cleaning_errors(e))

    @InputValidation
    def delete(self, request, *args, **kwargs):
        try:
            pkname = self.model_class._meta.pk.name
            records = self.model_class.objects.filter(
                pk=self.params.get(pkname))
            # filter().delete() never raises, so check for a match explicitly
            # to be able to answer 404.
            if not records.exists():
                raise ObjectDoesNotExist()
            records.delete()
            return self.send_json('Deletion successful')
        except ObjectDoesNotExist:
            return HttpResponseNotFound('Object not found')
# ============== auth.py =========================== | |
class CustomAuth(object):
    """
    Middleware-style authentication hook, invoked before dispatch.

    RestBaseView.dispatch instantiates it and calls ``process_request``.
    Raising Unauthorised rejects the request.
    """

    def process_request(self, request):
        """
        Accept the request only if ``request.META`` has a truthy ``HTTP_AUTH``.

        Raises:
            Unauthorised: when the token is missing or empty.
        """
        token = request.META.get('HTTP_AUTH')
        if token:
            # TODO: resolve the token and set request.user from it.
            return
        raise Unauthorised()
# ============== models.py ========================= | |
class Comment(models.Model):
    """
    Example model class, for the classical blog example.

    ``full_clean`` additionally checks that ``owner_id`` refers to an
    existing User.
    """
    title = models.CharField(max_length=256)
    body = models.TextField(null=True, blank=True)
    owner = models.ForeignKey(User, blank=True)

    def validate_user(self):
        """
        Return the User referenced by ``owner_id``.

        Raises:
            ValidationError: keyed on ``owner_id`` when the id is missing,
                not an integer, or matches no User.
        """
        try:
            return User.objects.get(pk=int(self.owner_id))
        except (ObjectDoesNotExist, TypeError, ValueError):
            # int(None) -> TypeError, int('x') -> ValueError: all mean a bad id.
            raise ValidationError({'owner_id': ['invalid data']})

    def full_clean(self, *args, **kwargs):
        # Store the validated user's primary key (not the User instance)
        # in the FK's id attribute.
        self.owner_id = self.validate_user().pk
        super(Comment, self).full_clean(*args, **kwargs)
# ============== views.py ============================ | |
class CommentView(RestBaseView):
    """REST endpoint for Comment objects, guarded by CustomAuth."""
    auth_class = CustomAuth
    model_class = Comment
    methods = ['GET', 'POST', 'PUT', 'DELETE']
    # Comment has no ``user`` field. Its FK to User is ``owner``, whose id
    # attribute is ``owner_id``.
    return_fields = ['id', 'title', 'body', 'owner_id']
# =============== tests.py =========================== | |
from django.utils import unittest | |
from django.test.client import Client | |
from django.core.urlresolvers import reverse | |
from django.contrib.auth.models import User | |
from json import dumps, loads | |
class CommentViewTests(unittest.TestCase):
    """
    End-to-end tests for CommentView, driven through the Django test client.

    Every test is self-contained. unittest runs test methods in an
    unspecified order on separate instances, so no test may depend on an
    auth token or comment created by another test.
    """

    def setUp(self):
        self.user = User.objects.create_user(
            username='user', email='user@example.com', password='secret')
        self.client = Client()
        self.auth_token = None

    def tearDown(self):
        # unittest.TestCase does not roll the database back between tests.
        # Delete the user, or the next setUp's create_user fails on the
        # duplicate username. Comments owned by the user are presumably
        # removed by the FK's default cascade (TODO: confirm).
        self.user.delete()

    def login(self):
        """
        Log in through the ``login`` URL and remember the returned token.

        NOTE(review): assumes the ``login`` view answers a valid
        username/password POST with the token in an ``AUTH`` header.
        """
        payload = {
            'username': 'user',
            # Must match the password given to create_user in setUp.
            'password': 'secret',
        }
        response = self.client.post(reverse('login'), payload)
        self.assertEqual(response.status_code, 200)
        self.assertTrue(response.has_header('AUTH'))
        self.auth_token = response['AUTH']
        return self.auth_token

    def headers(self, with_json=True):
        """
        Extra keyword arguments for authenticated test-client requests.

        Extra kwargs end up in ``request.META``, and CustomAuth reads the
        token from ``META['HTTP_AUTH']``.
        """
        if self.auth_token is None:
            self.login()
        extra = {'HTTP_AUTH': self.auth_token}
        if with_json:
            extra['content_type'] = 'application/json'
        return extra

    def post_comment(self):
        """Create a comment owned by self.user and return its JSON form."""
        payload = dumps({
            'title': 'A title',
            'body': 'None',
            'owner': {
                'id': self.user.id,
            },
        })
        response = self.client.post(reverse('comment_view'),
                                    payload, **self.headers())
        self.assertEqual(response.status_code, 200)
        # Django responses expose their payload as ``content``.
        return loads(response.content)

    def test_denies_anonymous(self):
        response = self.client.get(reverse('comment_view'))
        self.assertEqual(response.status_code, 401)

    def test_login(self):
        self.assertTrue(self.login())

    def test_post_comment(self):
        comment = self.post_comment()
        self.assertEqual(comment.get('title'), 'A title')
        self.assertEqual(comment.get('body'), 'None')

    def test_post_comment_bad(self):
        # No owner given: expected to be rejected by Comment.full_clean.
        payload = dumps({
            'title': 'A title',
            'body': 'None',
        })
        response = self.client.post(reverse('comment_view'),
                                    payload, **self.headers())
        self.assertEqual(response.status_code, 400)

    def test_put_comment(self):
        comment = self.post_comment()
        # Send the owner as a nested object, the same shape POST uses.
        payload = dumps({
            'id': comment['id'],
            'title': comment['title'],
            'body': 'Not None',
            'owner': {'id': self.user.id},
        })
        response = self.client.put(reverse('comment_view'),
                                   payload, **self.headers())
        self.assertEqual(response.status_code, 200)
        # PUT returns the serialized object itself, not a 'comment' list.
        self.assertEqual(loads(response.content).get('body'), 'Not None')

    def test_get_comments_for_user(self):
        comment = self.post_comment()
        query = {
            # id=0 is dropped by get_records, i.e. "any id".
            'id': 0,
            'owner__id': self.user.id,
        }
        response = self.client.get(reverse('comment_view'),
                                   query, **self.headers(with_json=False))
        self.assertEqual(response.status_code, 200)
        comments = loads(response.content).get('comment')
        self.assertIn(comment['id'], [c.get('id') for c in comments])
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment