Created February 23, 2020 15:24
Save bryanyang0528/0ebc86e3180ae029ade704752e6c0e0d to your computer and use it in GitHub Desktop.
assert.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def assertSQLEqual(first, second):
    """Assert that two SQL strings are equal once all whitespace is removed.

    Every whitespace character (spaces, tabs, newlines, ...) is deleted from
    both strings before they are compared. Differences in formatting or
    indentation therefore never cause a failure.

    :param first: the first SQL string.
    :param second: the second SQL string.
    :raises TypeError: if either argument is not a ``str``. The message shows
        the original arguments.
    :raises AssertionError: if the whitespace-free strings differ. The message
        shows the whitespace-free forms.
    """
    both_text = isinstance(first, str) and isinstance(second, str)
    if not both_text:
        raise TypeError('Both {} and {} should be str.'.format(first, second))

    # NOTE: removing *all* whitespace also erases token boundaries, so e.g.
    # "SELECT a FROM t" and "SELECTa FROMt" compare equal.
    lhs, rhs = (''.join(text.split()) for text in (first, second))
    if lhs != rhs:
        raise AssertionError('{} != {}'.format(lhs, rhs))
def assert_sql(self, task_id=None, func=None):
    """Build a test decorator that compares a task's rendered SQL to an expected string.

    Two-step usage: calling ``assert_sql(task_id)`` with ``func`` omitted
    returns ``partial(self.assert_sql, task_id)``. Applying that partial to a
    function ``func`` returns ``wrapper``. When ``wrapper(cls)`` runs, it:

    1. calls ``func(cls)`` to obtain the expected SQL;
    2. builds a ``TaskInstance`` for ``cls.dag.get_task(task_id)`` at
       ``self.default_date`` and calls its ``dry_run()``. NOTE(review): this is
       presumably Airflow's ``TaskInstance.dry_run``, which renders templated
       fields -- confirm;
    3. reads the SQL from ``ti.task.sql``. If that raises ``AttributeError``,
       it tries ``ti.task.templates_dict.get('sql')``, then ``ti.task.query``,
       logging each fallback as a warning;
    4. compares the two with ``self.assertSQLEqual(rendered, expected)``.

    :param task_id: id of the task to look up on ``cls.dag``. NOTE(review):
        defaults to ``None``, which is passed unchanged to ``get_task``.
    :param func: the decorated function. It receives ``cls`` (presumably the
        test case instance -- confirm) and returns the expected SQL.
    :returns: a ``functools.partial`` when ``func`` is ``None``; otherwise
        ``wrapper``, which itself returns ``None``.
    """
    # Only the task id was supplied: hand back a decorator bound to it.
    if func is None:
        return partial(self.assert_sql, task_id)

    def _sql_of(instance):
        """Return the SQL attached to ``instance.task``, trying three sources in order."""
        try:
            return instance.task.sql
        except AttributeError as missing:
            log.warning(missing)
            # Fallbacks are nested inside this handler, so any exception they
            # raise carries the AttributeError above as its __context__.
            try:
                # If templates_dict exists but has no 'sql' key, this returns
                # None and the ``query`` fallback below is NOT tried.
                return instance.task.templates_dict.get('sql')
            except AttributeError as also_missing:
                log.warning(also_missing)
                # NOTE: an AttributeError raised here propagates unlogged,
                # because sibling ``except`` clauses do not catch exceptions
                # raised inside a handler.
                return instance.task.query
            except Exception as unexpected:
                log.error(unexpected)
                raise
        except Exception as unexpected:
            log.error(unexpected)
            raise

    @wraps(func)
    def wrapper(cls):
        expected = func(cls)
        instance = TaskInstance(
            task=cls.dag.get_task(task_id),
            execution_date=self.default_date,
        )
        instance.dry_run()
        rendered = _sql_of(instance)
        self.assertSQLEqual(rendered, expected)

    return wrapper
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.