Skip to content

Instantly share code, notes, and snippets.

@Yuffster
Last active July 27, 2017 02:48
Show Gist options
  • Save Yuffster/a804a7e479c44b103ba3fb5171e61121 to your computer and use it in GitHub Desktop.
Save Yuffster/a804a7e479c44b103ba3fb5171e61121 to your computer and use it in GitHub Desktop.
import os.path
import json
import sqlite3
from contextlib import contextmanager
# Database file name (or path) selected by init_db(). It stays None until
# init_db() runs, in which case connection() falls back to 'dev.sqlite3'.
__data_path = None
def get_data_path(file):
    """Return the path of ``file`` inside the project's ``data`` directory.

    The ``data`` directory is addressed as ``<this module's directory>/../data``.
    The result is not normalised, so the ``..`` segment is kept as-is.
    """
    module_dir = os.path.dirname(__file__)
    segments = (module_dir, '..', 'data', file)
    return os.path.join(*segments)
def init_db(data_path=None, env="dev"):
    """Select the database file and make sure the schema exists.

    data_path: file name or path of the SQLite database. When omitted it
        defaults to ``'<env>.sqlite3'``.
    env: environment name, only used to derive the default file name.

    The chosen path is remembered in the module-level ``__data_path`` so that
    later ``connection()`` calls without an explicit path use the same file.
    """
    global __data_path
    __data_path = '{}.sqlite3'.format(env) if data_path is None else data_path
    with connection() as conn:
        conn.execute(user_schema)
# DDL for the ``users`` table, executed by init_db(). ``IF NOT EXISTS`` lets
# it run on every start-up without failing when the table is already there.
# name, email and pmt_key each carry a UNIQUE constraint; ``name`` is declared
# UNIQUE twice (on the column and as the table constraint name_unique).
# NOTE(review): the constraint names "email_unqiue" and "pmt_key_unqiue" are
# misspelled; they are left untouched here because they are part of the SQL
# text that is sent to the database - confirm before renaming.
user_schema = """
CREATE TABLE IF NOT EXISTS users(
id INTEGER PRIMARY KEY AUTOINCREMENT,
name TEXT NOT NULL UNIQUE,
email TEXT NOT NULL,
password TEXT NOT NULL,
created_at datetime default current_timestamp,
pmt_key TEXT,
CONSTRAINT name_unique UNIQUE (name),
CONSTRAINT email_unqiue UNIQUE (email),
CONSTRAINT pmt_key_unqiue UNIQUE (pmt_key)
);
"""
@contextmanager
def connection(db_path=None):
    """Yield a SQLite connection that is committed and closed on exit.

    db_path: path of the database file. When omitted, the file configured by
        init_db() is used (``'dev.sqlite3'`` if init_db() has not run),
        resolved through get_data_path().

    Rows are returned as ``sqlite3.Row`` so columns can be read by name.

    If the ``with`` body raises, pending changes are rolled back and the
    exception propagates. The connection is closed in every case, so a failing
    body no longer leaks an open connection.
    """
    if db_path is None:
        db_path = get_data_path(__data_path or 'dev.sqlite3')
    conn = sqlite3.connect(db_path)
    conn.row_factory = sqlite3.Row
    try:
        yield conn
    except BaseException:
        # Undo the partial work of the failed body, then let the error surface.
        conn.rollback()
        raise
    else:
        conn.commit()
    finally:
        conn.close()
class DataModel():
    """Small active-record style base class stored in SQLite.

    Subclasses set ``_table`` to the table name and ``_write`` to the list of
    column names that may be assigned. Assigned values are collected in
    ``_dirty`` until save() writes them; values loaded from or written to the
    database live in ``_raw``.

    Optional per-column hooks on the subclass:
      * ``set_<column>(value)`` transforms a value when it is assigned.
      * ``get_<column>(value)`` transforms a value when it is read.

    Attribute names starting with an underscore are ordinary Python
    attributes and never map to columns.
    """

    _table = False
    _write = None  # list of column names that may be assigned
    _read = None  # list of further column names accepted by _insert()
    _raw = None  # dict of column values as loaded from / saved to the database
    _dirty = None  # dict of assigned column values that are not saved yet
    _new_record = False

    def __init__(self, **kwargs):
        """Create a new record from column keyword arguments.

        The reserved keyword ``__dbrow`` (used by fill_model()) supplies a dict
        of already-stored values instead; such an instance is treated as an
        existing record and the remaining keyword arguments are ignored.

        Raises AttributeError if a keyword is not listed in ``_write``.
        """
        self._raw = {}
        self._dirty = {}
        self._read = self._read or []
        self._write = self._write or []
        row = kwargs.pop("__dbrow", None)
        if row:
            self._raw = row
        else:
            for k, v in kwargs.items():
                setattr(self, k, v)
            self._new_record = True

    def __getattr__(self, prop):
        """Return the column value for ``prop`` (None if it has no value).

        Only called when normal attribute lookup fails. An unsaved value wins
        over the stored one, and a ``get_<prop>`` hook, if defined, transforms
        the result.
        """
        if prop.startswith("_"):
            # Underscore names are never columns (see __setattr__); raising
            # keeps hasattr() and protocol probes for dunder names honest.
            raise AttributeError(prop)
        val = self._dirty.get(prop, self._raw.get(prop, None))
        try:
            # object.__getattribute__ avoids recursing into __getattr__.
            transform = object.__getattribute__(self, "get_" + prop)
        except AttributeError:
            return val
        # Called outside the try so an AttributeError raised inside the hook
        # is not mistaken for "no hook defined".
        return transform(val)

    def __setattr__(self, prop, val):
        """Record ``val`` as the pending value of column ``prop``.

        Raises AttributeError if ``prop`` is not listed in ``_write``.
        """
        if prop.startswith("_"):
            object.__setattr__(self, prop, val)
            return
        if prop not in self._write:
            raise AttributeError("{} is not writable".format(prop))
        try:
            transform = object.__getattribute__(self, "set_" + prop)
        except AttributeError:
            transform = None
        if transform is not None:
            val = transform(val)
        self._dirty[prop] = val

    def save(self):
        """Insert the record if it is new, otherwise update it."""
        if self._new_record:
            self._insert()
        else:
            self._update()

    def delete(self):
        """Delete this record's row; does nothing if it has no id yet."""
        record_id = self._raw.get('id')
        if record_id is None:
            return
        with connection() as c:
            c.execute(
                "DELETE FROM {} WHERE id=?".format(self._table),
                (record_id,)
            )

    def _insert(self):
        """Insert the pending values as a new row and store its id."""
        fields = []
        vals = []
        for k, v in self._dirty.items():
            if k not in self._read and k not in self._write:
                raise AttributeError("Invalid field: {}".format(k))
            fields.append(k)
            vals.append(v)
        if fields:
            ins = "INSERT INTO {} ({}) VALUES ({})".format(
                self._table,
                ",".join(fields),
                ",".join(["?"] * len(fields))
            )
        else:
            # An empty column list is not valid SQL.
            ins = "INSERT INTO {} DEFAULT VALUES".format(self._table)
        with connection() as c:
            cursor = c.execute(ins, vals)
            self._raw['id'] = cursor.lastrowid
        self._move_dirty()
        self._new_record = False

    def _update(self):
        """Write the pending values to the row identified by ``_raw['id']``."""
        if not self._dirty:
            # Nothing to write; "UPDATE ... SET  WHERE ..." would be invalid.
            return
        assignments = ", ".join("{}=?".format(k) for k in self._dirty)
        vals = list(self._dirty.values())
        vals.append(self._raw['id'])
        up = "UPDATE {} SET {} WHERE id=?".format(self._table, assignments)
        with connection() as c:
            c.execute(up, vals)
        self._move_dirty()

    def _move_dirty(self):
        """Fold the pending values into ``_raw`` and clear the pending set."""
        self._raw.update(self._dirty)
        # Clearing prevents the next save() from rewriting unchanged columns.
        self._dirty.clear()

    @classmethod
    def find(cls, **kwargs):
        """Return a list of records whose columns equal the given values.

        Each keyword is a column name; all conditions must hold (AND). A value
        of None matches rows where the column is NULL. The reserved keyword
        ``limit`` caps the number of rows. With no conditions every row
        matches.

        Column names are placed into the SQL text as-is (only values are bound
        as parameters), so keyword names must come from trusted code.
        """
        limit = kwargs.pop('limit', None)
        conditions = []
        vals = []
        for k, v in kwargs.items():
            if v is None:
                # "col = NULL" is never true in SQL.
                conditions.append(k + ' IS NULL')
            else:
                conditions.append(k + '=?')
                vals.append(v)
        select = "SELECT * FROM {} WHERE {}".format(
            cls._table,
            " AND ".join(conditions) if conditions else "1"
        )
        if limit:
            select += " LIMIT {}".format(int(limit))
        with connection() as c:
            rows = c.execute(select, vals).fetchall()
        return [cls.fill_model(row) for row in rows]

    @classmethod
    def first(cls, **kwargs):
        """Return the first record matching ``kwargs`` (see find()) or None."""
        kwargs['limit'] = 1
        result = cls.find(**kwargs)
        return result[0] if result else None

    @classmethod
    def fill_model(cls, row):
        """Build an existing-record instance from a ``sqlite3.Row``."""
        raw = dict(zip(row.keys(), row))
        # Passed by dict unpacking so the key is exactly the literal string
        # that __init__ pops.
        return cls(**{"__dbrow": raw})

    def __repr__(self):
        """Return ``ClassName(<stored values as indented JSON>)``."""
        return "{}({})".format(
            self.__class__.__name__,
            # default=repr keeps values JSON cannot encode (e.g. bytes) from
            # making repr() itself raise.
            json.dumps(self._raw, sort_keys=True, indent=4, default=repr)
        )
# NOTE(review): from here on this looks like a second module of the same
# project (it imports the code above as ``src.db``) - confirm the file split.
import hmac

import bcrypt

from src import db


class User(db.DataModel):
    """A row of the ``users`` table with a bcrypt-hashed password."""

    _table = "users"
    _write = ['name', 'email', 'password', 'pmt_key']

    @staticmethod
    def _to_bytes(value):
        """Return ``value`` as bytes, UTF-8 encoding it if it is text."""
        if isinstance(value, bytes):
            return value
        return value.encode("utf-8")

    def set_password(self, password):
        """Assignment hook: return the bcrypt hash stored for ``password``.

        Accepts text or bytes; bcrypt itself only accepts bytes.
        """
        return bcrypt.hashpw(self._to_bytes(password), bcrypt.gensalt())

    def check_pass(self, password):
        """Return True if ``password`` matches the stored hash.

        Returns False when no hash is stored.
        """
        stored = self.password
        if not stored:
            return False
        stored = self._to_bytes(stored)
        candidate = bcrypt.hashpw(self._to_bytes(password), stored)
        # Constant-time comparison so the check does not leak, through timing,
        # how much of the hash matched.
        return hmac.compare_digest(candidate, stored)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment