Last active: July 27, 2017 02:48
-
-
Save Yuffster/a804a7e479c44b103ba3fb5171e61121 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os.path | |
import json | |
import sqlite3 | |
from contextlib import contextmanager | |
# Database file name chosen by init_db(); None means connection() falls back
# to 'dev.sqlite3'.
__data_path = None


def get_data_path(file):
    """Return the path of *file* inside the project's ``data`` directory.

    The data directory is ``../data`` relative to the directory containing
    this module. Because the pieces are combined with ``os.path.join``, an
    absolute *file* replaces the prefix entirely.
    """
    module_dir = os.path.dirname(__file__)
    parts = (module_dir, '..', 'data', file)
    return os.path.join(*parts)
def init_db(data_path=None, env="dev"):
    """Choose the database file for later connections and create the tables.

    Args:
        data_path: Database file to use. When omitted, it is derived from
            *env* as ``"<env>.sqlite3"``.
        env: Environment name used to build the default file name.

    The choice is stored module-wide, so every later ``connection()`` call
    made without an explicit path opens this file. The ``users`` table is
    then created if it does not already exist.
    """
    global __data_path
    chosen = data_path if data_path is not None else '{}.sqlite3'.format(env)
    __data_path = chosen
    with connection() as db:
        db.execute(user_schema)
# DDL for the users table. Because of CREATE TABLE IF NOT EXISTS, edits here
# only affect newly created databases, never existing ones.
# Uniqueness of name/email/pmt_key is declared once each, as named constraints
# (previously `name` was UNIQUE twice and two constraint names were misspelled).
# SQLite allows multiple NULLs under UNIQUE, so pmt_key may be left unset.
user_schema = """
CREATE TABLE IF NOT EXISTS users(
    id INTEGER PRIMARY KEY AUTOINCREMENT,
    name TEXT NOT NULL,
    email TEXT NOT NULL,
    password TEXT NOT NULL,
    created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
    pmt_key TEXT,
    CONSTRAINT name_unique UNIQUE (name),
    CONSTRAINT email_unique UNIQUE (email),
    CONSTRAINT pmt_key_unique UNIQUE (pmt_key)
);
"""
@contextmanager
def connection(db_path=None):
    """Open a SQLite connection for the duration of a ``with`` block.

    Args:
        db_path: Database file to open. Defaults to the file chosen by
            ``init_db`` (or ``'dev.sqlite3'``), resolved with
            ``get_data_path``.

    Yields:
        A ``sqlite3.Connection`` whose rows are returned as ``sqlite3.Row``.

    The transaction is committed only if the block finishes without an
    exception. The connection is closed in every case, and closing without
    a commit discards uncommitted changes.
    """
    if db_path is None:
        db_path = get_data_path(__data_path or 'dev.sqlite3')
    conn = sqlite3.connect(db_path)
    conn.row_factory = sqlite3.Row
    try:
        yield conn
        conn.commit()
    finally:
        # Previously an exception in the with-body skipped close(), leaking
        # the connection (and any lock held by its open transaction).
        conn.close()
class DataModel:
    """Minimal active-record base class backed by SQLite.

    Subclasses set ``_table`` to the table name and ``_write`` to the column
    names that may be assigned. Field values live in two dicts:

    * ``_raw``   -- values as last loaded from, or written to, the database;
    * ``_dirty`` -- values assigned since then and not yet saved.

    Reading a field consults ``_dirty`` first, then ``_raw``, and passes the
    value through an optional ``get_<field>`` hook. Assigning a field runs an
    optional ``set_<field>`` hook before storing the result in ``_dirty``.
    """

    _table = False
    _write = None  # list of assignable column names (set by subclasses)
    _read = None   # list of extra column names accepted by _insert
    _raw = None    # dict of persisted values
    _dirty = None  # dict of pending, unsaved values
    _new_record = False

    def __init__(self, **kwargs):
        """Create a model instance.

        When given the private ``__dbrow`` keyword (see ``fill_model``), that
        dict becomes the persisted state. Otherwise every keyword is assigned
        as a field (through ``__setattr__``) and the instance is marked as a
        new, unsaved record.
        """
        self._raw = {}
        self._dirty = {}
        self._read = self._read or []
        self._write = self._write or []
        row = kwargs.pop("__dbrow", None)
        if row is not None:
            self._raw = row
        else:
            for k, v in kwargs.items():
                setattr(self, k, v)
            self._new_record = True

    def _hook(self, name):
        """Return the bound method *name* if the class defines it, else None.

        Uses ``object.__getattribute__`` so that a missing hook is reported as
        missing instead of falling through to ``__getattr__`` (which would
        return a field value).
        """
        try:
            return object.__getattribute__(self, name)
        except AttributeError:
            return None

    def __getattr__(self, prop):
        # Only reached when normal attribute lookup fails. Private and dunder
        # names are never model fields; report them as missing so hasattr(),
        # copy and pickle probing behave normally.
        if prop.startswith("_"):
            raise AttributeError(prop)
        val = self._dirty.get(prop, self._raw.get(prop))
        transform = self._hook("get_" + prop)
        if transform is None:
            return val
        # Errors raised inside the hook now propagate instead of being
        # silently swallowed and replaced by the raw value.
        return transform(val)

    def __setattr__(self, prop, val):
        # Private names are ordinary Python attributes, not model fields.
        # (Previously `prop[0] is "_"` compared string identity.)
        if prop.startswith("_"):
            object.__setattr__(self, prop, val)
            return
        if prop not in self._write:
            raise AttributeError(prop, "is not writable")
        transform = self._hook("set_" + prop)
        if transform is not None:
            val = transform(val)
        self._dirty[prop] = val

    def save(self):
        """Insert the record if it is new, otherwise update its changed fields."""
        if self._new_record:
            self._insert()
        else:
            self._update()

    def delete(self):
        """Delete this record's row from ``_table``, matched by ``id``.

        Does nothing for a record that has never been saved (no id yet). The
        in-memory instance itself is left unchanged.
        """
        record_id = self._raw.get("id")
        if self._new_record or record_id is None:
            return
        # `id` is the primary key, so no LIMIT is needed; DELETE ... LIMIT is
        # also rejected by SQLite builds without UPDATE/DELETE LIMIT support.
        with connection() as c:
            c.execute(
                "DELETE FROM {} WHERE id=?;".format(self._table), (record_id,)
            )

    def _insert(self):
        """INSERT the pending fields as a new row and record its ``id``.

        Raises:
            AttributeError: if a pending field is in neither ``_read`` nor
                ``_write``.
        """
        for k in self._dirty:
            if k not in self._read and k not in self._write:
                raise AttributeError("Invalid field:", k)
        fields = list(self._dirty)
        if fields:
            ins = "INSERT INTO {} ({}) VALUES ({})".format(
                self._table,
                ",".join(fields),
                ",".join("?" * len(fields)),
            )
        else:
            # "INSERT INTO t () VALUES ()" is a syntax error in SQLite.
            ins = "INSERT INTO {} DEFAULT VALUES".format(self._table)
        vals = [self._dirty[k] for k in fields]
        with connection() as c:
            cur = c.execute(ins, vals)
            self._raw['id'] = cur.lastrowid
        self._move_dirty()
        self._new_record = False

    def _update(self):
        """UPDATE the row identified by ``_raw['id']`` with the pending fields.

        A no-op when nothing has changed ("UPDATE t SET  WHERE ..." would be
        invalid SQL).
        """
        if not self._dirty:
            return
        fields = list(self._dirty)
        up = "UPDATE {} SET {} WHERE id=?".format(
            self._table, ", ".join("{}=?".format(k) for k in fields)
        )
        vals = [self._dirty[k] for k in fields]
        vals.append(self._raw['id'])  # bound, not string-formatted
        with connection() as c:
            c.execute(up, vals)
        self._move_dirty()

    def _move_dirty(self):
        """Promote pending values to persisted state and clear the pending set.

        Clearing matters: otherwise every later save re-writes fields that
        were already stored.
        """
        self._raw.update(self._dirty)
        self._dirty.clear()

    @classmethod
    def find(cls, **kwargs):
        """Return all rows of ``_table`` whose columns equal the given values.

        Keyword arguments are ``column=value`` filters combined with AND; the
        special ``limit`` keyword caps the number of rows. With no filters,
        every row is returned.

        Raises:
            ValueError: if a filter key is not a plain identifier (column
                names are interpolated into the SQL, so they are validated).
        """
        limit = kwargs.pop('limit', None)
        for k in kwargs:
            if not k.isidentifier():
                raise ValueError("Invalid column name: {!r}".format(k))
        # Conditions must be joined with AND; "a=?,b=?" is invalid SQL.
        where = " AND ".join("{}=?".format(k) for k in kwargs) or "1"
        select = "SELECT * FROM {} WHERE {}".format(cls._table, where)
        if limit:
            select += " LIMIT {}".format(int(limit))
        vals = list(kwargs.values())
        with connection() as c:
            rows = c.execute(select, vals).fetchall()
            return [cls.fill_model(row) for row in rows]

    @classmethod
    def first(cls, **kwargs):
        """Return the first row matching the ``find`` filters, or None."""
        kwargs['limit'] = 1
        result = cls.find(**kwargs)
        return result[0] if result else None

    @classmethod
    def fill_model(cls, row):
        """Build a persisted (non-new) instance from a ``sqlite3.Row``."""
        raw = dict(zip(row.keys(), row))
        # Spelled as an explicit dict key so it matches exactly the key that
        # __init__ pops.
        return cls(**{"__dbrow": raw})

    def __repr__(self):
        state = dict(self._raw)
        state.update(self._dirty)
        # default=repr: stored values need not be JSON-serialisable (e.g.
        # bytes password hashes), and repr() should never raise.
        return "{}({})".format(
            self.__class__.__name__,
            json.dumps(state, sort_keys=True, indent=4, default=repr),
        )
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from src import db
import bcrypt


class User(db.DataModel):
    """A row of the ``users`` table; passwords are stored as bcrypt hashes."""

    _table = "users"
    _write = ['name', 'email', 'password', 'pmt_key']

    @staticmethod
    def _to_bytes(value):
        # bcrypt operates on bytes; accept str for convenience (UTF-8).
        if isinstance(value, str):
            return value.encode("utf-8")
        return value

    def set_password(self, password):
        """Hash *password* (str or bytes) with a fresh salt before storing it.

        Invoked by ``DataModel.__setattr__`` whenever ``password`` is assigned.
        """
        return bcrypt.hashpw(self._to_bytes(password), bcrypt.gensalt())

    def check_pass(self, password):
        """Return True if *password* (str or bytes) matches the stored hash.

        Returns False when no hash is stored. Uses ``bcrypt.checkpw``, which
        compares in constant time, unlike a plain ``==`` on a re-hash.
        """
        hashed = self.password
        if hashed is None:
            return False
        return bcrypt.checkpw(self._to_bytes(password), self._to_bytes(hashed))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.