Skip to content

Instantly share code, notes, and snippets.

@alejandrobernardis
Last active May 23, 2022 06:13
Show Gist options
  • Save alejandrobernardis/8572960 to your computer and use it in GitHub Desktop.
Save alejandrobernardis/8572960 to your computer and use it in GitHub Desktop.
Tornado: Role, Permission, Identity (+ IdentityMixin) and AuthContext classes
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright (c) 2013 Asumi Kamikaze Inc.
# Copyright (c) 2013 The Octopus Apps Inc.
# Licensed under the Apache License, Version 2.0 (the "License")
# Author: Alejandro M. Bernardis
# Email: alejandro.bernardis at gmail.com
# Created: 04/Oct/2013 09:37
import urllib
import urlparse
from functools import wraps
from tornado.web import RequestHandler, HTTPError
# Authorization modes understood by AuthContext.
REQUIRED_MODE = 'required'  # identity must hold at least one permission role
STRICT_MODE = 'strict'  # identity must hold every permission role
# HTTP status used for the HTTPError raised when authorization fails.
HTTP_ERROR = 403
def is_permissions(method):
    """Guard a binary method so that its ``other`` argument is a Permission.

    :param method: a callable taking ``(self, other, *args, **kwargs)``.
    :returns: the wrapped callable (metadata preserved by ``functools.wraps``).
    :raises TypeError: at call time, when ``other`` is not a Permission.
    """
    @wraps(method)
    def checked(self, other, *args, **kwargs):
        if isinstance(other, Permission):
            return method(self, other, *args, **kwargs)
        raise TypeError('Value invalid, must be a Permission')
    return checked
def is_role(method):
    """Guard a binary method so that its ``other`` argument is a Role.

    :param method: a callable taking ``(self, other, *args, **kwargs)``.
    :returns: the wrapped callable (metadata preserved by ``functools.wraps``).
    :raises TypeError: at call time, when ``other`` is not a Role.
    """
    @wraps(method)
    def checked(self, other, *args, **kwargs):
        if isinstance(other, Role):
            return method(self, other, *args, **kwargs)
        raise TypeError('Value invalid, must be a Role')
    return checked
try:
    _ROLE_KEY_TYPES = (basestring,)  # Python 2: accept str and unicode keys
except NameError:
    _ROLE_KEY_TYPES = (str,)  # Python 3


class Role(object):
    """A named ``key``/``value`` pair held by an identity or a permission.

    Roles are compared and hashed on the class name, the key and the ``%s``
    rendering of the value. So ``Role('a', 1) == Role('a', '1')``, and
    unhashable values (lists, dicts) are still usable inside sets.
    """

    def __init__(self, key, value=None):
        """
        :param key: the role name; must be a string.
        :param value: any object; only its ``%s`` form is used for equality.
        :raises TypeError: when ``key`` is not a string.
        """
        if not isinstance(key, _ROLE_KEY_TYPES):
            raise TypeError('Key invalid, must be a string')
        self._key = key
        self._value = value

    @property
    def key(self):
        """The role name."""
        return self._key

    @property
    def value(self):
        """The role value as given to the constructor."""
        return self._value

    def _canonical(self):
        # A tuple instead of a '$'-joined string: ('a$b', 'c') and
        # ('a', 'b$c') must not collapse into the same identity.
        return (self.__class__.__name__, self.key, '%s' % (self.value,))

    def _hash(self, other=None):
        """Return the hash of ``other`` (or of ``self`` if it is not a Role)."""
        if not isinstance(other, Role):
            other = self
        return hash(other._canonical())

    def __hash__(self):
        return self._hash()

    def __eq__(self, other):
        # Compare the canonical data itself, not hashes, so a hash collision
        # cannot make two different roles equal.  Non-roles are simply
        # "not equal" instead of raising TypeError.
        if not isinstance(other, Role):
            return NotImplemented
        return self._canonical() == other._canonical()

    def __ne__(self, other):
        # Python 2 does not derive != from __eq__.
        result = self.__eq__(other)
        if result is NotImplemented:
            return result
        return not result

    def __repr__(self):
        return '<%s key="%s" value="%s">' % \
            (self.__class__.__name__, self.key, self.value)

    def __str__(self):
        return '<%s="%s">' % (self.key, self.value)
class Permission(object):
    """A non-empty set of roles that an identity is checked against.

    Use :meth:`require` (any one role suffices) or :meth:`strict` (all roles
    are needed) to obtain an :class:`AuthContext`.
    """

    def __init__(self, *roles):
        """
        :param roles: one or more roles (normally :class:`Role` instances).
        :raises ValueError: when no role is given.
        """
        if not roles:
            raise ValueError('List of roles not defined')
        self._roles = set(roles)

    def _context(self, mode, http_error, options):
        # AuthContext's own default only applies when http_error is omitted;
        # forwarding an explicit None would store None as the status code.
        if http_error is None:
            http_error = HTTP_ERROR
        return AuthContext(self, mode, http_error, **options)

    def require(self, http_error=None, **kwargs):
        """Return an AuthContext granting access if ANY role is held.

        :param http_error: status code for denials; ``None`` -> ``HTTP_ERROR``.
        :param kwargs: options forwarded to :class:`AuthContext`.
        """
        return self._context(REQUIRED_MODE, http_error, kwargs)

    def strict(self, http_error=None, **kwargs):
        """Return an AuthContext granting access only if ALL roles are held.

        :param http_error: status code for denials; ``None`` -> ``HTTP_ERROR``.
        :param kwargs: options forwarded to :class:`AuthContext`.
        """
        return self._context(STRICT_MODE, http_error, kwargs)

    @property
    def roles(self):
        """The (mutable) set of roles of this permission."""
        return self._roles

    # The set operations below return a new Permission.  Because a Permission
    # cannot be empty, an empty result raises ValueError.

    @is_permissions
    def union(self, other):
        """Roles present in either permission."""
        return Permission(*self.roles.union(other.roles))

    @is_permissions
    def intersection(self, other):
        """Roles present in both permissions."""
        return Permission(*self.roles.intersection(other.roles))

    @is_permissions
    def difference(self, other):
        """Roles of ``self`` that are not in ``other``."""
        return Permission(*self.roles.difference(other.roles))

    @is_permissions
    def symmetric_difference(self, other):
        """Roles present in exactly one of the two permissions."""
        return Permission(*self.roles.symmetric_difference(other.roles))

    def __repr__(self):
        return '<%s roles="%s">' % \
            (self.__class__.__name__, self.roles)

    def __str__(self):
        return '%s=Roles(%d)' % \
            (self.__class__.__name__, len(self.roles))

    def to_object(self):
        """Return the roles as a plain ``{key: value}`` dict."""
        return {item.key: item.value for item in self.roles}
class Identity(object):
    """The set of roles held by a user.

    The user's ``_id``, ``username`` and ``email`` become roles, plus one role
    per entry of the optional ``permissions`` mapping, plus any extra roles
    passed positionally.
    """

    def __init__(self, user, *args):
        """
        :param user: the user dict; must provide truthy ``_id``, ``username``
            and ``email`` and may provide ``permissions`` (a mapping),
            ``enabled`` and ``available``.
        :param args: extra :class:`Role` objects to add.
        :raises ValueError: when ``user`` is not a dict or an identity key is
            missing or empty.
        :raises TypeError: when an extra argument is not a Role.
        """
        if not isinstance(user, dict):
            raise ValueError('User invalid, must be a dictionary')
        data = []
        user_data = ('_id', 'username', 'email',)
        # Require each identity key.  The previous loop rejected any key
        # outside this tuple, which made 'permissions', 'enabled' and
        # 'available' impossible to set.
        for key in user_data:
            if key not in user:
                raise ValueError('User key %s, not defined' % key)
            value = user[key]
            if not value:
                raise ValueError('User key-value %s, not defined' % key)
            data.append(Role(key, value))
        # Iterating a dict yields keys only; .items() gives (key, value).
        for key, value in (user.get('permissions') or {}).items():
            data.append(Role(key, value))
        self._user = user
        self._roles = set(data)
        self.add_roles(*args)

    @property
    def user(self):
        """The user dict this identity was built from."""
        return self._user

    @property
    def roles(self):
        """The (mutable) set of roles; read by AuthContext.validate()."""
        return self._roles

    @property
    def role(self):
        """Backward-compatible alias of :attr:`roles`."""
        return self._roles

    @property
    def is_active(self):
        """True when the user dict has truthy ``enabled`` and ``available``."""
        return self._user.get('enabled', False) \
            and self._user.get('available', False)

    @is_role
    def add_role(self, other):
        """Add one role.  Raises TypeError for non-roles (via is_role)."""
        try:
            self._roles.add(other)
        except TypeError:  # only hashing can make set.add() fail
            raise ValueError('Role not supported: %s' % other)

    def add_roles(self, *args):
        """Add several roles, one at a time."""
        for item in args:
            self.add_role(item)

    @is_role
    def remove_role(self, other):
        """Remove one role.  Raises ValueError if it is not held."""
        try:
            self._roles.remove(other)
        except KeyError:
            raise ValueError('Role not found: %s' % other)

    def remove_roles(self, *args):
        """Remove several roles, one at a time."""
        for item in args:
            self.remove_role(item)

    @is_role
    def __contains__(self, item):
        # Best-effort membership test: any hashing failure means "not held".
        try:
            return item in self._roles
        except Exception:
            return False

    def __repr__(self):
        return '<%s roles="%s">' % (self.__class__.__name__, self._roles)

    def __str__(self):
        return '%s=%d' % (self.__class__.__name__, len(self._roles))
class IdentityMixin(object):
    """Mixin for tornado RequestHandlers exposing a lazily built Identity."""

    @property
    def identity(self):
        """The :class:`Identity` for ``current_user``, created on first use.

        :raises TypeError: when the host class is not a RequestHandler.
        :raises ValueError: when no Identity can be built from the user.
        """
        if not isinstance(self, RequestHandler):
            raise TypeError('Class invalid, must be a tornado RequestHandler')
        if not hasattr(self, '_identity'):
            try:
                self._identity = Identity(self.current_user)
            except Exception:
                raise ValueError('Current user not defined')
        return self._identity
class AuthContext(object):
    """Check a :class:`Permission` against a tornado handler's identity.

    Three usage styles are supported:

    * decorator: ``@perm.require()`` on a handler method;
    * context manager: ``with perm.require(handler=self): ...``;
    * explicit: ``if perm.require(handler=self).validate(): ...``.

    In ``REQUIRED_MODE`` the identity must hold at least one of the
    permission's roles; in ``STRICT_MODE`` it must hold all of them.
    Denials raise ``HTTPError(http_error)``.
    """

    def __init__(self, permissions, mode, http_error=HTTP_ERROR, **kwargs):
        """
        :param permissions: the :class:`Permission` to enforce.
        :param mode: ``REQUIRED_MODE`` or ``STRICT_MODE`` (case-insensitive).
        :param http_error: status code for denials; ``None`` -> ``HTTP_ERROR``.
        :param kwargs: stored options (``handler``, ``validator``,
            ``ignore_context_validator``); they override the options
            passed to :meth:`validate`.
        :raises TypeError: on a non-Permission or a non-string mode.
        :raises ValueError: on an unsupported mode.
        """
        if not isinstance(permissions, Permission):
            raise TypeError('Permissions invalids, must be a Permission')
        elif not isinstance(mode, basestring):
            raise TypeError('Mode invalid, must be a string')
        mode = mode.lower()
        if mode not in (REQUIRED_MODE, STRICT_MODE,):
            raise ValueError('Mode not supported: %s' % mode)
        self._permissions = permissions
        self._mode = mode
        # Permission.require()/strict() may pass None explicitly, which
        # would otherwise become the HTTPError status code.
        self._http_error = HTTP_ERROR if http_error is None else http_error
        self._options = dict(kwargs)

    @staticmethod
    def _redirect_to_login(handler):
        """Redirect ``handler`` to its login URL with a ``next`` parameter."""
        url = handler.get_login_url()
        if '?' not in url:
            if urlparse.urlsplit(url).scheme:
                # Absolute login URL: make "next" absolute too.  (Previously
                # the whole URL was replaced by the current page, which
                # redirected the client back to itself.)
                next_url = handler.request.full_url()
            else:
                next_url = handler.request.uri
            url = '%s?%s' % (url, urllib.urlencode({'next': next_url}))
        handler.redirect(url)

    def validate(self, **kwargs):
        """Run the authorization check.

        :param kwargs: call-time options, overridden by the stored ones.
        :returns: True when access is granted; False after redirecting an
            anonymous user to the login page; otherwise the result of the
            configured ``validator``.
        :raises HTTPError: when the handler is missing, the HTTP method is
            not GET/POST, the identity is unavailable or lacks the roles.
        """
        options = dict(kwargs)
        options.update(self._options)
        handler = options.get('handler')
        if not isinstance(handler, RequestHandler):
            raise HTTPError(
                self._http_error,
                'Handler invalid, must be a tornado RequestHandler'
            )
        method = handler.request.method
        if method not in ('GET', 'POST',):
            raise HTTPError(
                self._http_error, 'Method not supported: %s' % method)
        if not handler.current_user:
            self._redirect_to_login(handler)
            return False
        validator = options.get('validator')
        if hasattr(validator, 'context_validator'):
            validator = getattr(validator, 'context_validator')
        if validator and not options.get('ignore_context_validator'):
            return validator(context=self)
        if not hasattr(handler, 'identity'):
            raise HTTPError(self._http_error, 'Identity not supported')
        roles = self._permissions.roles
        auth = getattr(handler, 'identity').roles.intersection(roles)
        if (self._mode == REQUIRED_MODE and not auth) \
                or (self._mode == STRICT_MODE and auth != roles):
            raise HTTPError(self._http_error)
        return True

    def __call__(self, method):
        """Decorate a handler method so it only runs when access is granted."""
        context = self

        @wraps(method)
        def wrapper(ref, *args, **kwargs):
            if context.validate(handler=ref):
                return method(ref, *args, **kwargs)
            # Falsy result: the client was redirected to the login page, or
            # the custom validator refused.  The protected body must not run
            # (it would also write to an already finished response).
            return None
        return wrapper

    def __enter__(self):
        # NOTE: __enter__ cannot skip the ``with`` body.  If validate()
        # redirects or a validator returns a falsy value, the body still
        # runs; only raised HTTPErrors stop it.
        self.validate()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Never suppress exceptions raised inside the ``with`` body.
        return False
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import unittest
from auth import *
from tornado.testing import LogTrapTestCase, AsyncHTTPTestCase
from tornado.web import Application, RequestHandler, HTTPError
# Current-user document shared by the test handlers below.
user_data = {
    'uid': 1 << 8,
    'username': 'sysadmin',
    'rid': 2 << 8,
    'role_name': 'admin',
    'email': 'sysadmin@localhost',
    'enabled': True,
    'available': True,
}
class TestAuthRole(unittest.TestCase):
    """Unit tests for Role construction and accessors."""

    def test_01_role(self):
        username = 'admin'
        email = 'admin@local'
        r_username = Role('username', username)
        r_email = Role('email', email)
        self.assertIsInstance(r_username, Role)
        # Role exposes its name as ``key``; it has no ``name`` attribute.
        self.assertEqual(r_username.key, 'username')
        self.assertEqual(r_username.value, username)
        self.assertIsInstance(r_email, Role)
        self.assertEqual(r_email.key, 'email')
        self.assertEqual(r_email.value, email)
        # Non-string keys are rejected by the constructor.
        self.assertRaises(TypeError, Role, 1, 1)
class TestAuthPermission(unittest.TestCase):
    """Unit tests for Permission and its set operations."""

    def test_01_permission(self):
        r_username = Role('username', 'admin')
        r_name = Role('role_name', 'admin')
        r_email = Role('email', 'admin@local')
        r_uid = Role('uid', 1 << 8)
        r_rid = Role('rid', 2 << 8)
        perm_a = Permission(r_username, r_name)
        perm_b = Permission(r_name, r_rid)
        perm_c = Permission(r_email, r_username, r_uid)
        # Permission defines no __eq__: distinct objects are never equal.
        self.assertNotEqual(perm_a, perm_b)
        self.assertNotEqual(perm_b, perm_c)
        self.assertNotEqual(perm_c, perm_a)

        perm_a_b = perm_a.union(perm_b)
        self.assertEqual(perm_a_b.roles, {r_username, r_name, r_rid})
        self.assertNotEqual(perm_a_b.roles, {r_username, r_name})

        perm_b_c = perm_b.union(perm_c)
        self.assertEqual(
            perm_b_c.roles, {r_username, r_name, r_rid, r_email, r_uid})
        self.assertNotEqual(
            perm_b_c.roles, {r_username, r_name, r_rid, r_email})

        perm_c_a = perm_c.union(perm_a)
        self.assertEqual(perm_c_a.roles, {r_username, r_name, r_email, r_uid})
        self.assertNotEqual(perm_c_a.roles, {r_username, r_name, r_email})

        self.assertEqual(
            perm_a_b.intersection(perm_c_a).roles, {r_username, r_name})
        self.assertEqual(perm_a_b.difference(perm_c_a).roles, {r_rid})
        self.assertEqual(
            perm_a_b.symmetric_difference(perm_c_a).roles,
            {r_rid, r_uid, r_email})

        perm_a_copy = Permission(*perm_a.roles)
        self.assertEqual(perm_a_copy.roles, perm_a.roles)
        # Set operations only accept Permission operands (is_permissions).
        self.assertRaises(TypeError, perm_a_copy.union, dict())
        self.assertRaises(TypeError, perm_a_copy.union, 'other_perm')
        perm_a_union = perm_a_copy.union(perm_b)
        self.assertNotEqual(perm_a_copy.roles, perm_a_union.roles)
class TestAuthIdentity(unittest.TestCase):
    """Unit tests for Identity, written against its actual public API.

    Identity offers ``in`` membership, add/remove of roles and ``is_active``.
    It has no has_role/is_superuser-style helpers.
    """

    def test_01_Identity(self):
        # Only the identity keys, so this dict is a valid Identity user.
        user = {
            '_id': 1 << 8,
            'username': 'sysadmin',
            'email': 'sysadmin@localhost',
        }
        self.assertRaises(ValueError, Identity, None)
        self.assertRaises(ValueError, Identity, dict(user, username=''))
        i = Identity(user)
        self.assertEqual(i.user, user)
        self.assertIn(Role('username', 'sysadmin'), i)
        # No 'enabled'/'available' flags in the user dict.
        self.assertFalse(i.is_active)

        i = Identity(user, Role('a', 1), Role('b', 2))
        self.assertIn(Role('a', 1), i)
        self.assertIn(Role('b', 2), i)
        i.add_role(Role('c', 3))
        self.assertIn(Role('c', 3), i)
        self.assertNotIn(Role('d', 4), i)
        i.add_roles(Role('d', 4), Role('e', 5), Role('f', 6), Role('d', 4))
        for role in (Role('d', 4), Role('e', 5), Role('f', 6)):
            self.assertIn(role, i)
        i.remove_role(Role('d', 4))
        self.assertNotIn(Role('d', 4), i)
        i.remove_roles(Role('e', 5), Role('f', 6))
        self.assertNotIn(Role('e', 5), i)
        self.assertNotIn(Role('f', 6), i)
        self.assertRaises(ValueError, i.remove_role, Role('e', 5))
        self.assertRaises(TypeError, i.add_role, 1)
        self.assertRaises(TypeError, i.add_role, 'role')
        self.assertRaises(TypeError, i.add_role, dict())
# Fixtures shared by the HTTP handlers and TestAuthContext.
GENERIC_RESPONSE = 'success'.encode()
LOGIN_RESPONSE = 'login'.encode()
# The sysadmin user identified by numeric ids ...
perm_admin_ids = Permission(Role('uid', 1 << 8), Role('rid', 2 << 8))
# ... and by names.
perm_admin_names = Permission(
    Role('username', 'sysadmin'), Role('role_name', 'admin'))
# All four admin roles; this is what every protected handler checks.
perm_admin_all = perm_admin_ids.union(perm_admin_names)
class BaseHandler(RequestHandler, IdentityMixin):
    """Handler whose current user is a fresh copy of ``user_data``."""

    def initialize(self):
        # Pre-seeds _current_user; presumably picked up by tornado's
        # current_user property instead of get_current_user() — confirm.
        self._current_user = user_data.copy()

    def get(self):
        self.finish(GENERIC_RESPONSE)


class BaseRequireErrorHandler(BaseHandler):
    """Same user document, but with every admin-related value replaced."""

    def initialize(self):
        super().initialize()
        self._current_user.update(
            uid=2 << 8,
            username='moderator',
            rid=3 << 8,
            role_name='moderator',
            email='moderator@local',
        )


class BaseStrictErrorHandler(BaseHandler):
    """Same user document with only ``uid`` changed."""

    def initialize(self):
        super().initialize()
        self._current_user['uid'] = -1
class DecoratorRequireHandler(BaseHandler):
    """GET guarded by the ``require`` decorator."""

    @perm_admin_all.require()
    def get(self):
        super().get()


class DecoratorRequireErrorHandler(
        DecoratorRequireHandler, BaseRequireErrorHandler):
    """Decorator-guarded GET served to the moderator user."""


class DecoratorStrictHandler(BaseHandler):
    """GET guarded by the ``strict`` decorator."""

    @perm_admin_all.strict()
    def get(self):
        super().get()


class DecoratorStrictErrorHandler(
        DecoratorStrictHandler, BaseStrictErrorHandler):
    """Strict decorator-guarded GET served to the user with a changed uid."""
class WithRequireHandler(BaseHandler):
    """GET guarded by a ``require`` context manager."""

    def get(self):
        context = perm_admin_all.require(handler=self)
        with context:
            super().get()


class WithRequireErrorHandler(
        WithRequireHandler, BaseRequireErrorHandler):
    """Context-manager-guarded GET served to the moderator user."""


class WithStrictHandler(BaseHandler):
    """GET guarded by a ``strict`` context manager."""

    def get(self):
        context = perm_admin_all.strict(handler=self)
        with context:
            super().get()


class WithStrictErrorHandler(
        WithStrictHandler, BaseStrictErrorHandler):
    """Strict context-manager-guarded GET served to the changed-uid user."""
class RequireHandler(BaseHandler):
    """GET guarded by an explicit ``require(...).validate()`` call."""

    def get(self):
        if perm_admin_all.require(handler=self).validate():
            super().get()


class RequireErrorHandler(RequireHandler, BaseRequireErrorHandler):
    """Explicitly validated GET served to the moderator user."""


class StrictHandler(BaseHandler):
    """GET guarded by an explicit ``strict(...).validate()`` call."""

    def get(self):
        if perm_admin_all.strict(handler=self).validate():
            super().get()


class StrictErrorHandler(StrictHandler, BaseStrictErrorHandler):
    """Explicitly strict-validated GET served to the changed-uid user."""
# NOTE(review): AuthContext.validate only consults options['validator'] and
# never looks at the handler itself.  None of these handlers passes a
# validator, so their context_validator methods appear to be unused.
# Confirm the intended behaviour.


class ValidatorMethodHandler(DecoratorRequireHandler):
    """Decorated GET on a handler whose context_validator always accepts."""

    def context_validator(self, **kwargs):
        return True


class ValidatorMethodErrorHandler(DecoratorRequireHandler):
    """Decorated GET on a handler whose context_validator always refuses."""

    def context_validator(self, **kwargs):
        raise HTTPError(403)


class ValidatorMethodIgnoreHandler(BaseHandler):
    """Refusing context_validator, with ignore_context_validator=True."""

    @perm_admin_all.require(ignore_context_validator=True)
    def get(self):
        super().get()

    def context_validator(self, **kwargs):
        raise HTTPError(403)


class ValidatorMethodNotIgnoreHandler(BaseHandler):
    """Refusing context_validator, with ignore_context_validator=False."""

    @perm_admin_all.require(ignore_context_validator=False)
    def get(self):
        super().get()

    def context_validator(self, **kwargs):
        raise HTTPError(403)
class NotHandlerErrorHandler(BaseHandler):
    """``with`` context built without ``handler=``; validate() must reject."""

    def get(self):
        with perm_admin_all.require():
            super().get()


class NotHandlerInConditionalErrorHandler(BaseHandler):
    """Same as above, but through an explicit validate() call."""

    def get(self):
        perm = perm_admin_all.require()
        if perm.validate():
            super().get()


class NotIdentityErrorHandler(RequestHandler):
    """Handler without IdentityMixin; validation must fail on the identity.

    The handler is passed explicitly.  Without it, validate() would stop at
    the missing-handler error and never reach the identity check this
    handler is meant to exercise.
    """

    def initialize(self):
        self._current_user = user_data.copy()

    def get(self):
        with perm_admin_all.require(handler=self):
            self.finish(GENERIC_RESPONSE)


class NotCurrentUserErrorHandler(DecoratorRequireHandler):
    """Handler whose identity property cannot build an Identity."""

    @property
    def identity(self):
        # NOTE(review): Identity(None) raises ValueError and AuthContext
        # probes this attribute with hasattr().  Python 3's hasattr only
        # swallows AttributeError, so the expected 403 may depend on the
        # interpreter. Confirm.
        return Identity(None)
class DecoratorNotCurrentUserHandler(RequestHandler):
    """No current user: the ``require`` decorator must redirect to login."""

    def get_login_url(self):
        # Override the handler hook instead of writing
        # self.settings['login_url'] on every request.  self.settings is the
        # Application's settings dict, so that write leaked into every other
        # handler of the test app.
        return '/auth/login'

    @perm_admin_all.require()
    def get(self):
        self.finish(GENERIC_RESPONSE)


class WithNotCurrentUserHandler(DecoratorNotCurrentUserHandler):
    """No current user, guarded by a ``with`` context."""

    def get(self):
        with perm_admin_all.require(handler=self):
            self.finish(GENERIC_RESPONSE)


class ConditionalNotCurrentUserHandler(DecoratorNotCurrentUserHandler):
    """No current user, guarded by an explicit validate() call."""

    def get(self):
        perm = perm_admin_all.require(handler=self)
        if perm.validate():
            self.finish(GENERIC_RESPONSE)
class AuthLoginHandler(RequestHandler):
    """Stand-in login page; the redirect tests check its body."""

    def get(self):
        body = LOGIN_RESPONSE
        self.finish(body)
class TestAuthContext(AsyncHTTPTestCase, LogTrapTestCase):
    """End-to-end checks of AuthContext through real HTTP requests."""

    def get_app(self):
        handlers = [
            ('/decorator-require', DecoratorRequireHandler),
            ('/decorator-require-error', DecoratorRequireErrorHandler),
            ('/decorator-strict', DecoratorStrictHandler),
            ('/decorator-strict-error', DecoratorStrictErrorHandler),
            ('/with-require', WithRequireHandler),
            ('/with-require-error', WithRequireErrorHandler),
            ('/with-strict', WithStrictHandler),
            ('/with-strict-error', WithStrictErrorHandler),
            ('/require', RequireHandler),
            ('/require-error', RequireErrorHandler),
            ('/strict', StrictHandler),
            ('/strict-error', StrictErrorHandler),
            ('/validator-method', ValidatorMethodHandler),
            ('/validator-method-error', ValidatorMethodErrorHandler),
            ('/validator-method-ignore', ValidatorMethodIgnoreHandler),
            ('/validator-method-not-ignore', ValidatorMethodNotIgnoreHandler),
            ('/not-handler-error', NotHandlerErrorHandler),
            ('/not-handler-in-conditional-error',
             NotHandlerInConditionalErrorHandler),
            ('/not-identity-error', NotIdentityErrorHandler),
            ('/not-current-user-error', NotCurrentUserErrorHandler),
            ('/decorator-not-current-user', DecoratorNotCurrentUserHandler),
            ('/with-not-current-user', WithNotCurrentUserHandler),
            ('/conditional-not-current-user', ConditionalNotCurrentUserHandler),
            ('/auth/login', AuthLoginHandler),
        ]
        return Application(handlers)

    def get_url_response(self, url='/', follow_redirects=False):
        """Fetch ``url`` from the test server and return the response."""
        self.http_client.fetch(
            self.get_url(url), self.stop, follow_redirects=follow_redirects)
        return self.wait()

    # response.body can be read repeatedly.  response.buffer.read() returns
    # b'' on a second call, which made some assertNotEqual checks vacuous.

    def _assert_granted(self, url):
        """Expect HTTP 200 carrying the generic body."""
        response = self.get_url_response(url)
        self.assertEqual(response.code, 200)
        self.assertEqual(response.body, GENERIC_RESPONSE)

    def _assert_denied(self, url):
        """Expect HTTP 403 without the generic body."""
        response = self.get_url_response(url)
        self.assertEqual(response.code, 403)
        self.assertNotEqual(response.body, GENERIC_RESPONSE)

    def _assert_login_redirect(self, url):
        """Expect the request to end on the login page after redirects."""
        response = self.get_url_response(url, follow_redirects=True)
        self.assertEqual(response.code, 200)
        self.assertEqual(response.body, LOGIN_RESPONSE)
        self.assertNotEqual(response.body, GENERIC_RESPONSE)

    def test_01_decorator_all(self):
        self._assert_granted('/decorator-require')
        self._assert_denied('/decorator-require-error')
        self._assert_granted('/decorator-strict')
        self._assert_denied('/decorator-strict-error')

    def test_02_with_all(self):
        self._assert_granted('/with-require')
        self._assert_denied('/with-require-error')
        self._assert_granted('/with-strict')
        self._assert_denied('/with-strict-error')

    def test_03_simple_all(self):
        self._assert_granted('/require')
        self._assert_denied('/require-error')
        self._assert_granted('/strict')
        self._assert_denied('/strict-error')

    def test_04_validator_method(self):
        self._assert_granted('/validator-method')
        self._assert_denied('/validator-method-error')
        self._assert_granted('/validator-method-ignore')
        self._assert_denied('/validator-method-not-ignore')

    def test_05_common_errors(self):
        self._assert_denied('/not-handler-error')
        self._assert_denied('/not-handler-in-conditional-error')
        self._assert_denied('/not-identity-error')
        self._assert_denied('/not-current-user-error')

    def test_06_identity_mixin(self):
        class IdentityMixinHandler(IdentityMixin):
            def __init__(self):
                self.current_user = user_data.copy()
        i = IdentityMixinHandler()
        # IdentityMixin.identity only works on tornado RequestHandler
        # subclasses; any other host class gets a TypeError.
        self.assertRaises(TypeError, getattr, i, 'identity')

    def test_07_not_current_user(self):
        self._assert_login_redirect('/decorator-not-current-user')
        self._assert_login_redirect('/with-not-current-user')
        self._assert_login_redirect('/conditional-not-current-user')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment