Last active
March 16, 2023 16:59
-
-
Save davidrpmorris/1f742a17553c3a21870cd0b75e876486 to your computer and use it in GitHub Desktop.
WTForms FormField with FieldList with JSON column type
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Skip to Line 56 to see how to set things up with JSON | |
# Normally each FormField is represented by its own row in a database. In this example | |
# each FormField is a row in the Specification table. | |
# models.py | |
class Specification(db.Model):
    """One named specification (min / max / average) of a piece of equipment.

    Each row is the many-to-one child of an ``Equipment`` row.
    """
    # `...` stands for the parts of the model elided from this example.
    ...
    name = db.Column(db.String)
    minimum = db.Column(db.Float)
    maximum = db.Column(db.Float)
    average = db.Column(db.Float)
    # Use db.ForeignKey so the snippet does not depend on a separate
    # `from sqlalchemy import ForeignKey`.
    # NOTE(review): 'equipments.id' presumes the Equipment table is named
    # 'equipments' -- confirm against the elided part of the Equipment model.
    equipment_id = db.Column(db.Integer, db.ForeignKey('equipments.id'))
    equipment = db.relationship('Equipment', back_populates='specifications')
# A FieldList consists of one or more FormFields. | |
# This makes it very simple to create a many-to-one relationship between | |
# the child (Specification) and parent (Equipment): | |
# models.py | |
class Equipment(db.Model):
    """Parent model whose specifications live in the Specification table."""
    # `...` stands for the parts of the model elided from this example.
    ...
    # Mirror side of Specification.equipment.
    specifications = db.relationship(
        'Specification',
        back_populates='equipment',
    )
# forms.py: | |
from flask_wtf import Form | |
from models import Specification | |
from wtforms import FieldList, FormField | |
from wtforms import Form as NoCSRFForm | |
from wtforms.validators import Optional | |
class SpecificationForm(NoCSRFForm):
    """Sub-form for a single specification; every field may be left blank."""
    name = TextField(label='Specification', validators=[Optional()])
    minimum = DecimalField(label='Min Value', validators=[Optional()])
    maximum = DecimalField(label='Max Value', validators=[Optional()])
    average = DecimalField(label='Average', validators=[Optional()])
class EquipmentForm(Form):
    """Equipment form holding between one and five specification sub-forms."""
    specifications = FieldList(
        FormField(SpecificationForm, default=lambda: Specification()),
        min_entries=1,
        max_entries=5,
    )
# Setting the default to lambda: Specification() is required
# if you want to use populate_obj in views.py:
# Refer to https://github.com/sebkouba/dynamic-flask-form and | |
# https://groups.google.com/forum/#!msg/wtforms/5KQvYdLFiKE/TSgHIxmsI8wJ | |
# This works fine, but what if you want specifications to be JSON? | |
# Not only can each column (name, maximum, minimum, average) be condensed | |
# to one column (specifications), but (1) multiple specifications can be | |
# expressed in one row (instead of multiple rows and columns) and (2) | |
# it can be added to the Equipment table, minimizing the number of tables and | |
# relationships: | |
# models.py | |
from sqlalchemy.dialects.postgresql import JSON | |
class Equipment(db.Model):
    """Equipment model that keeps all of its specifications in one JSON column."""
    # `...` stands for the parts of the model elided from this example.
    ...
    # Holds the whole list of specifications for this row.
    specifications = db.Column(JSON)
# specifications = [{'maximum': 30.0, 'minimum': 0.0, 'name': 'Temperature', 'average': 22.0}, | |
# {'maximum': 1.0, 'minimum': 0.0, 'name': 'Pressure'}] | |
# It's a little tricky getting WTForms to play nicely with JSON. | |
# There is even a package to help with this: https://github.com/kvesteri/wtforms-json | |
# Here I demonstrate how to configure a FieldList with multiple FormFields that | |
# are described by a single JSON column. The best part is that populate_obj can be | |
# called in views.py and it works. | |
# The fields are optional, so we also check whether all of the fields are empty:
# if they are, that FormField is skipped. If any are filled in, the filled-in
# values are processed and only the empty fields are left out.
# forms.py | |
import decimal
import json

from flask_wtf import Form
from wtforms import (BooleanField, DecimalField, FieldList,
                     FormField, TextField)
from wtforms import Form as NoCSRFForm
from wtforms.utils import unset_value
from wtforms.validators import Optional
def decimal_default(obj):
    """Convert a ``decimal.Decimal`` to ``float`` for JSON serialization.

    Intended to be passed as the ``default=`` hook of ``json.dumps``.

    Raises:
        TypeError: if ``obj`` is not a ``Decimal``; the message names the
            offending type so serialization failures are easy to diagnose.
    """
    if isinstance(obj, decimal.Decimal):
        return float(obj)
    raise TypeError(
        'Object of type {} is not JSON serializable'.format(type(obj).__name__)
    )
class CustomFormField(FormField):
    """FormField that leaves the target object alone when its sub-form is blank."""

    def is_field_empty(self, data):
        """Return True when every value in ``data`` is blank.

        ``None`` (e.g. a DecimalField that received no input) and ``''`` count
        as blank. Values are compared explicitly rather than by truthiness so
        that ``0`` / ``'0'`` remain valid input. Nested dicts are checked
        recursively.
        """
        for value in data.values():
            if isinstance(value, dict):
                if not self.is_field_empty(value):
                    return False
            elif value is not None and value != '':
                return False
        return True

    def populate_obj(self, obj, name):
        """Populate ``obj.<name>`` from the sub-form unless the sub-form is blank.

        Raises:
            TypeError: if ``obj`` has no value under ``name`` and this field
                has no ``_obj`` to fall back on.
        """
        if self.is_field_empty(self.data):
            # Nothing was entered: do not touch the target object.
            return
        candidate = getattr(obj, name, None)
        if candidate is None:
            if self._obj is None:
                raise TypeError('populate_obj: cannot find a value to populate from the provided obj or input data/defaults')
            candidate = self._obj
            setattr(obj, name, candidate)
        self.form.populate_obj(candidate)
class SpecFieldList(FieldList):
    """FieldList whose entries are stored together as one JSON document."""

    def populate_obj(self, obj, name):
        """Serialize the filled-in entries to a JSON string on ``obj.<name>``.

        Values that are ``None`` or ``''`` are dropped from each entry, and an
        entry with no remaining values is skipped entirely.
        """
        specs_list = []
        for spec in self.data:
            # Compare explicitly instead of by truthiness: 0 is a valid value.
            specs_dict = {key: value for key, value in spec.items()
                          if value is not None and value != ''}
            if specs_dict:
                specs_list.append(specs_dict)
        # default=decimal_default converts Decimal values to float.
        setattr(obj, name, json.dumps(specs_list, default=decimal_default))

    def _extract_indices(self, prefix, formdata):
        """Yield the integer entry indices found in ``formdata`` keys.

        Keys are expected to look like ``<prefix>-<index>-<subfield>``; keys
        whose index part is not all digits are ignored.
        """
        offset = len(prefix) + 1
        for key in formdata:
            if key.startswith(prefix):
                index = key[offset:].split('-', 1)[0]
                if index.isdigit():
                    yield int(index)

    def process(self, formdata, data=unset_value):
        """Build the entries from submitted form data and/or stored data.

        ``data`` may be the JSON string written by ``populate_obj`` or an
        already-decoded list/tuple of dicts. It is decoded once up front so
        that both branches below iterate over entries rather than over the
        characters of the raw string.
        """
        self.entries = []
        if data is unset_value or not data:
            try:
                data = self.default()
            except TypeError:
                data = self.default
        self.object_data = data

        if data and not isinstance(data, (list, tuple)):
            data = json.loads(data)

        if formdata:
            indices = sorted(set(self._extract_indices(self.name, formdata)))
            if self.max_entries:
                indices = indices[:self.max_entries]
            idata = iter(data or ())
            for index in indices:
                # Pair each submitted entry with stored data while it lasts.
                obj_data = next(idata, unset_value)
                self._add_entry(formdata, obj_data, index=index)
        elif data:
            for obj_data in data:
                self._add_entry(formdata, obj_data)

        while len(self.entries) < self.min_entries:
            self._add_entry(formdata)
class SpecFormField(CustomFormField):
    """Form field for one entry of a list that is populated as a whole."""

    def populate_obj(self, obj, name):
        """Do nothing: only the entire FieldList is processed by populate_obj."""
        return None
class SpecForm(NoCSRFForm):
    """Sub-form for a single specification; every field may be left blank."""
    name = TextField('Specification', validators=[Optional()])
    minimum = DecimalField('Min Value', validators=[Optional()])
    maximum = DecimalField('Max Value', validators=[Optional()])
    average = DecimalField('Average', validators=[Optional()])

    def populate_obj(self, obj):
        """Ask every field of this form to populate ``obj`` under its own name."""
        # dict.items() replaces the `iteritems` helper, which was never imported.
        for name, field in self._fields.items():
            field.populate_obj(obj, name)
class EquipmentForm(Form):
    """Equipment form with at least one specification sub-form."""
    specs = SpecFieldList(
        SpecFormField(SpecForm),
        min_entries=1,
    )
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment