Last active
October 13, 2021 12:48
-
-
Save PaoloLeonard/3e5aa714397147d516778660573de023 to your computer and use it in GitHub Desktop.
Expectation implementation for the GE table expectation tutorial.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from copy import deepcopy | |
from typing import Dict, Tuple, Any, Optional, Callable, List | |
from great_expectations.core import ExpectationConfiguration | |
from great_expectations.execution_engine import ( | |
ExecutionEngine | |
) | |
from great_expectations.expectations.expectation import TableExpectation | |
from great_expectations.exceptions.exceptions import InvalidExpectationKwargsError | |
class ExpectTableRowCountToBeMoreThanOthers(TableExpectation):
    """Compare the row count of the current table to the row count of other table(s).

    Success kwargs:
        other_table_filenames_list: a filename (``str``) or a list of filenames
            identifying the other table(s). Required by ``validate_configuration``.
        comparison_key: name (case-insensitive) of a ``SupportedComparisonEnum``
            member. Defaults to ``"MEAN"``.
        lower_percentage_threshold: strictly positive integer handed to the
            comparison callable. Defaults to ``100``.

    NOTE(review): ``SupportedComparisonEnum`` is not imported in the visible part
    of this file -- confirm it is in scope where this class is defined.
    """

    # Metrics requested through the base class. ``get_validation_dependencies``
    # renames "table.row_count" to "table.row_count.self" and replaces
    # "table.row_count_other" with one entry per other table.
    metric_dependencies = ("table.row_count", "table.row_count_other")
    success_keys = (
        "other_table_filenames_list",
        "comparison_key",
        "lower_percentage_threshold",
    )
    default_kwarg_values = {
        "other_table_filenames_list": None,
        "comparison_key": "MEAN",
        "lower_percentage_threshold": 100,
    }

    @staticmethod
    def _as_filename_list(other_table_filenames: Any) -> Any:
        """Wrap a single filename given as ``str`` in a list; return anything else as is."""
        if isinstance(other_table_filenames, str):
            return [other_table_filenames]
        return other_table_filenames

    @staticmethod
    def _validate_success_key(
        param: str,
        required: bool,
        configuration: Optional[ExpectationConfiguration],
        validation_rules: Dict[Callable, str],
    ) -> None:
        """Apply every validation rule to the value of ``param`` in the configuration.

        Args:
            param: name of the kwarg to validate.
            required: whether a missing ``param`` is an error.
            configuration: configuration whose ``kwargs`` are inspected.
            validation_rules: mapping of predicate -> error message. Predicates are
                evaluated in insertion order, so a later rule may rely on an
                earlier one having passed (e.g. a type check before ``.upper()``).

        Raises:
            InvalidExpectationKwargsError: if ``param`` is required but missing, or
                if any predicate returns a falsy value for it.
        """
        if param not in configuration.kwargs:
            if required:
                raise InvalidExpectationKwargsError(
                    f"Param {param} is required but was not found in configuration."
                )
            return

        param_value = configuration.kwargs[param]
        for rule, error_message in validation_rules.items():
            if not rule(param_value):
                raise InvalidExpectationKwargsError(error_message)

    def validate_configuration(
        self, configuration: Optional[ExpectationConfiguration]
    ) -> bool:
        """Validate the three success kwargs of this expectation.

        Args:
            configuration: configuration to check; ``self.configuration`` is used
                when it is ``None``.

        Returns:
            ``True`` when every check passes.

        Raises:
            InvalidExpectationKwargsError: if a kwarg is missing or invalid.
        """
        super().validate_configuration(configuration=configuration)
        if configuration is None:
            configuration = self.configuration

        self._validate_success_key(
            param="other_table_filenames_list",
            required=True,
            configuration=configuration,
            validation_rules={
                lambda x: isinstance(
                    x, (str, list)
                ): "other_table_filenames_list should either be a list or a string.",
                lambda x: x: "other_table_filenames_list should not be empty",
            },
        )
        self._validate_success_key(
            param="comparison_key",
            required=False,
            configuration=configuration,
            validation_rules={
                lambda x: isinstance(x, str): "comparison_key should be a string.",
                lambda x: x.upper()
                in SupportedComparisonEnum.__members__: "Given comparison_key is not supported.",
            },
        )
        self._validate_success_key(
            param="lower_percentage_threshold",
            required=False,
            configuration=configuration,
            validation_rules={
                lambda x: isinstance(
                    x, int
                ): "lower_percentage_threshold should be an integer.",
                lambda x: x
                > 0: "lower_percentage_threshold should be strictly greater than 0.",
            },
        )
        return True

    def get_validation_dependencies(
        self,
        configuration: Optional[ExpectationConfiguration] = None,
        execution_engine: Optional[ExecutionEngine] = None,
        runtime_configuration: Optional[dict] = None,
    ) -> dict:
        """Build the metric dependencies: one row-count metric per table involved.

        The dependencies returned by the base class are rewritten so that
        ``dependencies["metrics"]`` contains:

        * ``"table.row_count.self"`` -- the base "table.row_count" entry, renamed;
        * ``"table.row_count_other.<filename>"`` -- for each other table, a deep
          copy of the base "table.row_count_other" entry whose
          ``metric_domain_kwargs["table_filename"]`` is set to that filename.

        Args:
            configuration: expectation configuration; ``self.configuration`` is
                used when it is ``None``.
            execution_engine: forwarded to the base implementation.
            runtime_configuration: forwarded to the base implementation.

        Returns:
            The rewritten dependencies dictionary.
        """
        dependencies = super().get_validation_dependencies(
            configuration, execution_engine, runtime_configuration
        )
        # The signature allows ``configuration=None``; fall back the same way
        # ``validate_configuration`` does instead of failing on ``None.kwargs``.
        if configuration is None:
            configuration = self.configuration

        other_table_filenames = self._as_filename_list(
            configuration.kwargs.get("other_table_filenames_list")
        )

        metrics = dependencies["metrics"]
        # The generic entry only serves as a template for the per-file entries.
        other_metric_template = metrics.pop("table.row_count_other")
        for other_table_filename in other_table_filenames:
            other_metric_config = deepcopy(other_metric_template)
            other_metric_config.metric_domain_kwargs[
                "table_filename"
            ] = other_table_filename
            metrics[
                f"table.row_count_other.{other_table_filename}"
            ] = other_metric_config

        metrics["table.row_count.self"] = metrics.pop("table.row_count")
        return dependencies

    def _validate(
        self,
        configuration: ExpectationConfiguration,
        metrics: Dict,
        runtime_configuration: dict = None,
        execution_engine: ExecutionEngine = None,
    ) -> Dict:
        """Compare the current row count to the other tables' row counts.

        Args:
            configuration: expectation configuration providing the success kwargs.
            metrics: resolved metrics, keyed as set up by
                ``get_validation_dependencies``.
            runtime_configuration: unused here.
            execution_engine: unused here.

        Returns:
            A dict with ``"success"`` (the value returned by the comparison
            callable) and ``"result"`` holding the current row count (``"self"``),
            the list of other row counts (``"other"``) and the name of the
            comparison used (``"comparison_key"``).
        """
        success_kwargs = self.get_success_kwargs(configuration)
        comparison_key = success_kwargs["comparison_key"]
        lower_percentage_threshold = success_kwargs["lower_percentage_threshold"]
        # A single filename may be given as a plain string; without this
        # normalisation the loop below would iterate over its characters and the
        # metric lookup would fail with a KeyError.
        other_table_filenames = self._as_filename_list(
            success_kwargs["other_table_filenames_list"]
        )

        current_row_count = metrics["table.row_count.self"]
        previous_row_count_list = [
            metrics[f"table.row_count_other.{other_table_filename}"]
            for other_table_filename in other_table_filenames
        ]

        comparison_key_fn = SupportedComparisonEnum[comparison_key.upper()]
        success_flag = comparison_key_fn(
            current_row_count, previous_row_count_list, lower_percentage_threshold
        )

        return {
            "success": success_flag,
            "result": {
                "self": current_row_count,
                "other": previous_row_count_list,
                "comparison_key": comparison_key_fn.name,
            },
        }
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment