Last active
July 14, 2016 07:22
-
-
Save anti1869/38607fb95736dea98e0850cbe563052c to your computer and use it in GitHub Desktop.
Asynchronous Nano ORM
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
Some kind of Nano-ORM. Just playing along | |
""" | |
import asyncio | |
from collections import OrderedDict | |
import logging | |
import re | |
from motor import motor_asyncio | |
# Module-wide logger.
logger = logging.getLogger(__name__)

# Shared Motor client, created at import time with no arguments, i.e. with
# the driver's default connection settings.
# NOTE(review): connection settings are not configurable yet — TODO.
client = motor_asyncio.AsyncIOMotorClient()

# TODO: Set database and collection somewhere else
db = client["test_database"]
class DoesNotExist(Exception):
    """Raised when a lookup finds no matching document.

    The model metaclass also derives a ``<ModelName>DoesNotExist`` subclass
    of this exception for every model class, so callers may catch either the
    model-specific exception or this generic base.
    """
class FieldDescriptor(object):
    """Data descriptor that routes a model attribute to the instance's ``_data``.

    Fields are exposed as descriptors in order to control access to the
    underlying "raw" data:

    * reading the attribute on a model *instance* returns the stored raw value
      (``None`` when nothing has been stored yet);
    * reading it on the model *class* returns the underlying ``Field``;
    * assigning stores the value in ``instance._data`` and records the
      attribute name in ``instance._dirty``.
    """

    def __init__(self, field):
        self.field = field
        self.att_name = field.name

    def __get__(self, instance, instance_type=None):
        if instance is None:
            # Class-level access: hand back the Field definition itself.
            return self.field
        return instance._data.get(self.att_name)

    def __set__(self, instance, value):
        key = self.att_name
        instance._data[key] = value
        instance._dirty.add(key)
class Field(object):
    """Declarative description of one model attribute.

    ``Field`` instances are placed in a model class body; the model metaclass
    then calls :meth:`add_to_class`, which swaps the attribute for a
    ``FieldDescriptor``.

    :param default: stored as ``self.default``; this class does not apply it
        to model instances.
    :param verbose_name: human-readable label; derived from the attribute
        name in :meth:`add_to_class` when not given.
    :param primary_key: ``True``/``False`` explicitly marks (or un-marks) the
        field as the key field.  When left as ``None``, the field is a key
        field only if it is the very first ``Field`` created in the process.
    """

    # Process-wide creation counter shared by every Field instance.
    _field_counter = 0
    _order = 0
    _data = None

    def __init__(self, default=None, verbose_name=None, primary_key=None):
        self.default = default
        self.verbose_name = verbose_name

        # Used internally for recovering the order in which Fields were
        # defined on the Model class.
        Field._field_counter += 1
        self._order = Field._field_counter

        if primary_key is None:
            # Implicit default: only the first Field ever created is a key.
            self._is_primary_key = self._order == 1
        else:
            # An explicit choice wins, including primary_key=False on the
            # first Field (``primary_key or ...`` used to ignore False).
            self._is_primary_key = bool(primary_key)

    def add_to_class(self, model_class, name):
        """Bind this field to ``model_class`` under attribute ``name``.

        Hook called by the metaclass during construction of the model.  It
        records the name and owning class, derives ``verbose_name`` from the
        attribute name when missing (``first_name`` -> ``"First Name"``),
        and replaces the class attribute with a named ``FieldDescriptor``.
        """
        self.name = name
        self.model_class = model_class
        if not self.verbose_name:
            self.verbose_name = re.sub('_+', ' ', name).title()
        setattr(model_class, name, FieldDescriptor(self))
        self._is_bound = True
class TextField(Field):
    """Field intended for text values.

    Adds no behaviour over ``Field``: values are stored exactly as given,
    with no validation or conversion.
    """
class IntegerField(Field):
    """Field intended for integer values.

    Adds no behaviour over ``Field``: values are stored exactly as given,
    with no validation or conversion.
    """
class FloatField(Field):
    """Field intended for floating-point values.

    Adds no behaviour over ``Field``: values are stored exactly as given,
    with no validation or conversion.
    """
class BooleanField(Field):
    """Field intended for boolean values.

    Adds no behaviour over ``Field``: values are stored exactly as given,
    with no validation or conversion.
    """
class ModelOptions(object):
    """Per-model configuration, stored on each model class as ``_meta``.

    :param cls: the model class these options describe; ``name`` is its
        lower-cased class name.
    :param database: database handle for the model (stored as-is).
    :param collection: collection handle; the object manager and instances
        read it via ``_meta.collection``.
    :param key_field: the ``Field`` used as the lookup key, or ``None``.
    :param kwargs: only ``fields`` is read (a mapping of attribute name to
        ``Field``; a fresh empty dict when absent).  Other keys are ignored.
    """

    def __init__(self, cls, database=None, collection=None, key_field=None, **kwargs):
        self.model_class = cls
        self.name = cls.__name__.lower()
        self.fields = kwargs["fields"] if "fields" in kwargs else {}
        self.database, self.collection, self.key_field = database, collection, key_field
class ObjectManager(object):
    """Query helper attached to each model class as ``Model.objects``.

    Reads documents from ``model._meta.collection`` and wraps each one in a
    model instance via ``model(**document)``.
    """

    # Maximum number of documents fetched by ``list`` (passed to to_list).
    list_limit = 100

    def __init__(self, model):
        self._model = model

    async def list(self, **query):
        """Return a generator of model instances matching ``query``.

        Keyword arguments form the filter document passed to the collection's
        ``find`` (no arguments means every document).  At most
        ``list_limit`` documents are fetched.
        """
        cursor = self._model._meta.collection.find(query)
        documents = await cursor.to_list(length=self.list_limit)
        return (self._model(**document) for document in documents)

    async def get(self, pk=None, **query):
        """Fetch exactly one instance by key value and/or extra filters.

        :param pk: value of the model's key field.  It is added to the filter
            unless it is ``None`` and other filters were given; a bare
            ``get()`` still matches on ``{key_field: None}``.
        :param query: additional filter terms merged into the lookup.
        :raises DoesNotExist: the model-specific subclass when available,
            if no document matches.
        """
        logger.debug("Get key=%s, query=%s", pk, query)
        criteria = dict(query)
        if pk is not None or not criteria:
            criteria[self._model._meta.key_field.name] = pk
        document = await self._model._meta.collection.find_one(criteria)
        if document is None:
            # The metaclass attaches a per-model subclass of DoesNotExist;
            # raising it lets callers catch either form.
            error_cls = getattr(self._model, "DoesNotExist", DoesNotExist)
            raise error_cls(str(criteria))
        return self._model(**document)
class BaseModel(type):
    """Metaclass that turns declarative ``Field`` attributes into a model.

    For every class it creates it:

    * replaces each ``Field`` defined directly in the class body with a
      ``FieldDescriptor`` (via ``Field.add_to_class``), in definition order;
    * chooses the key field: the last field flagged ``_is_primary_key`` in
      definition order, else the first declared field;
    * attaches ``_meta`` (``ModelOptions``) built from defaults overridden by
      the public attributes of an inner ``Meta`` class;
    * attaches ``objects`` (``ObjectManager``);
    * adds a ``__repr__`` when the class defines ``__unicode__``;
    * attaches a model-specific ``<Name>DoesNotExist`` exception.

    Only fields declared directly on the class are collected; fields on base
    classes are not inherited.
    """

    def __new__(mcs, name, bases, attrs):
        # initialize the new class and set the magic attributes
        cls = super(BaseModel, mcs).__new__(mcs, name, bases, attrs)

        # Sort by Field._order (creation order) so field order and key
        # selection do not depend on class __dict__ ordering, which only
        # matches definition order from Python 3.6 on.  sorted() materializes
        # the list before add_to_class mutates the class namespace.
        fields = sorted(
            (
                (attr, attr_name)
                for attr_name, attr in cls.__dict__.items()
                if isinstance(attr, Field)
            ),
            key=lambda pair: pair[0]._order,
        )

        primary_key_field = None
        for field, attr_name in fields:
            field.add_to_class(cls, attr_name)
            if field._is_primary_key:
                primary_key_field = field
        if primary_key_field is None and fields:
            # Field only flags the first Field created in the whole process
            # as an implicit key, so any later model without an explicit
            # primary_key would otherwise get key_field=None and break
            # objects.get()/pk.  Fall back to the first declared field.
            primary_key_field = fields[0][0]

        # Attach meta
        meta_options = {
            "key_field": primary_key_field,
            "database": db,
            "collection": db.test_collection,  # TODO: Properly configure collection somewhere else
            "fields": OrderedDict((attr_name, field) for field, attr_name in fields),
        }
        meta = attrs.pop('Meta', None)
        if meta:
            for option, value in meta.__dict__.items():
                if not option.startswith('_'):
                    meta_options[option] = value
        cls._meta = ModelOptions(cls, **meta_options)

        # Attach object manager
        cls.objects = ObjectManager(cls)

        # create a repr and error class before finalizing
        if hasattr(cls, '__unicode__'):
            cls.__repr__ = lambda self: '<%s: %r>' % (cls.__name__, self.__unicode__())
        exc_name = '%sDoesNotExist' % cls.__name__
        exc_attrs = {'__module__': cls.__module__}
        cls.DoesNotExist = type(exc_name, (DoesNotExist,), exc_attrs)
        return cls
class Persistent(object, metaclass=BaseModel):
    """Base class for models persisted through the model's ``_meta.collection``.

    Field values live in ``self._data``, which is the document handed to the
    collection on :meth:`save`.  Attribute names assigned through field
    descriptors are recorded in ``self._dirty``.
    """

    objects = None

    def __init__(self, *args, **kwargs):
        """Build an instance from keyword arguments (e.g. a fetched document).

        Each keyword is assigned with ``setattr`` (for declared fields this
        goes through the field descriptor, which also marks it dirty) and is
        copied into ``_data``.  Positional arguments are accepted but ignored.
        """
        self._data = {}
        self._dirty = set()
        for key, value in kwargs.items():
            setattr(self, key, value)
            self._data[key] = value

    @property
    def key_field_name(self):
        """Attribute name of the model's key field."""
        return self._meta.key_field.name

    @property
    def pk(self):
        """Stored value of the key field, or ``None`` when unset."""
        return self._data.get(self.key_field_name, None)

    @property
    def collection(self):
        """Collection handle configured for this model."""
        return self._meta.collection

    async def save(self):
        """Write ``_data`` to the collection, then clear the dirty set.

        NOTE(review): ``collection.save`` is the legacy PyMongo/Motor API
        (removed in PyMongo 4 / Motor 3); migrate to ``replace_one`` /
        ``insert_one`` once the supported Motor version is confirmed.
        """
        logger.debug("Saving %s\nkeyfield = %s", self._data, self.key_field_name)
        await self.collection.save(self._data)
        # Only reached when the write succeeded.
        self._dirty.clear()

    async def reload(self):
        """Placeholder: prints a message; does not re-read the document yet."""
        print("Reloading")

    async def delete(self):
        """Placeholder: prints a message; does not remove the document yet."""
        print("DELETE")
class Datacenter(Persistent):
    """Datacenter document; ``name`` is its key field (used by ``objects.get``)."""

    # primary_key=True is explicit because the implicit default only flags
    # the first Field created in the whole process; any Field instantiated
    # earlier (e.g. by another imported model) would silently drop the key.
    name = TextField(primary_key=True)
    title = TextField()
if __name__ == "__main__": | |
async def test(): | |
from rollerdrome.dc import Datacenter | |
# dc1 = Datacenter() | |
# dc1.name = "testname" | |
# dc1.title = "testtitle" | |
# await dc1.save() | |
# print(await Datacenter.objects.list()) | |
# print(await Datacenter.objects.list(title="testtitle", stuff="notexisted")) | |
dc2 = await Datacenter.objects.get("testname") | |
await dc2.save() | |
await Datacenter.objects.list() | |
async def test_insert(): | |
doc = {"test": "tvalue"} | |
result = await db.test_collection.insert(doc) | |
print(repr(result)) | |
async def test_list(): | |
cursor = db.test_collection.find({"test": "tvalue"}) | |
for doc in await cursor.to_list(length=100): | |
print(doc) | |
loop = asyncio.get_event_loop() | |
loop.run_until_complete(test()) |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment