Last active
October 28, 2022 11:06
-
-
Save PaoloLeonard/6b2fe9b6e2241a24f8a82f86a0d4eaf6 to your computer and use it in GitHub Desktop.
Full implementation of a custom table expectation that compares the row count of the dataset under validation with the row counts of other datasets, with a choice of comparison keys.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
Custom table expectation which checks whether the row count is at least a given percentage (100% by default) of the row
count of other tables. There are different ways to compare the row counts:
* With absolute values: if any of the other tables' row counts, scaled by the percentage threshold, is greater than the
current row count, then the validation fails;
* With the mean value: if the mean of the other tables' row counts, scaled by the percentage threshold, is greater than
the current row count, then the validation fails.
""" | |
from copy import deepcopy | |
from enum import Enum, auto | |
from statistics import mean | |
from typing import Dict, Tuple, Any, Optional, Callable, List | |
from great_expectations.core import ExpectationConfiguration | |
from great_expectations.core.batch_spec import PathBatchSpec | |
from great_expectations.execution_engine import ( | |
SparkDFExecutionEngine, | |
PandasExecutionEngine, | |
ExecutionEngine, | |
) | |
from great_expectations.expectations.metrics.metric_provider import metric_value | |
from great_expectations.expectations.metrics.table_metric_provider import ( | |
TableMetricProvider, | |
) | |
from great_expectations.expectations.expectation import TableExpectation | |
from great_expectations.exceptions.exceptions import InvalidExpectationKwargsError | |
class SupportedComparisonEnum(Enum):
    """Enum class with the currently supported comparison types.

    Calling a member returns whether the current row count passes the
    comparison against the other row counts:

    * ``ABSOLUTE``: the current row count must be at least
      ``lower_percentage_threshold`` percent of *every* other row count.
    * ``MEAN``: the current row count must be at least
      ``lower_percentage_threshold`` percent of the *mean* of the other row
      counts.
    """

    ABSOLUTE = auto()
    MEAN = auto()

    def __call__(
        self,
        current_row_count: float,
        other_row_counts: List[float],
        lower_percentage_threshold: float = 100,
    ) -> bool:
        """Return whether ``current_row_count`` passes this comparison.

        Args:
            current_row_count: Row count of the dataset under validation.
            other_row_counts: Row counts of the datasets to compare against
                (any iterable; it is materialized once).
            lower_percentage_threshold: Percentage of the reference value(s)
                that ``current_row_count`` must reach; ``100`` means
                "at least equal".

        Returns:
            ``True`` if the comparison passes, ``False`` otherwise.

        Raises:
            ValueError: If ``other_row_counts`` is empty. Comparing against no
                data is treated as a misconfiguration for every member.
            NotImplementedError: If the member has no comparison implemented.
        """
        others = list(other_row_counts)
        if not others:
            raise ValueError(
                "At least one other row count is required for the comparison."
            )
        # Keep the ``value * pct / 100`` evaluation order: precomputing
        # ``pct / 100`` would change float rounding at exact boundaries
        # (e.g. 10 * 1.1 != 10 * 110 / 100).
        if self is SupportedComparisonEnum.ABSOLUTE:
            return all(
                current_row_count >= count * lower_percentage_threshold / 100
                for count in others
            )
        if self is SupportedComparisonEnum.MEAN:
            return current_row_count >= mean(others) * lower_percentage_threshold / 100
        raise NotImplementedError("Comparison key is not supported.")
def _load_other_table_dataframe(
    execution_engine: "ExecutionEngine",
    metric_domain_kwargs: Dict,
    reader_method: str,
):
    """Load the table named by ``metric_domain_kwargs["table_filename"]``.

    Args:
        execution_engine: Engine whose ``get_batch_data`` reads the file.
        metric_domain_kwargs: Metric domain kwargs; must contain
            ``"table_filename"``.
        reader_method: Engine-specific reader name placed in the
            ``PathBatchSpec`` (e.g. ``"read_csv"`` or ``"csv"``).

    Returns:
        The ``dataframe`` attribute of the batch data returned by the engine.

    Raises:
        ValueError: If ``table_filename`` is missing or empty.
    """
    other_table_filename = metric_domain_kwargs.get("table_filename")
    if not other_table_filename:
        # Fail early with a clear message instead of handing a None path to
        # the execution engine.
        raise ValueError(
            "Metric 'table.row_count_other' requires a non-empty "
            "'table_filename' domain kwarg."
        )
    batch_spec = PathBatchSpec(
        {"path": other_table_filename, "reader_method": reader_method}
    )
    batch_data = execution_engine.get_batch_data(batch_spec=batch_spec)
    return batch_data.dataframe


class OtherTableRowCount(TableMetricProvider):
    """MetricProvider class to get the row count of a table other than the current one.

    The other table is read from ``metric_domain_kwargs["table_filename"]``,
    which the expectation sets once per compared table.
    """

    metric_name = "table.row_count_other"

    @metric_value(engine=PandasExecutionEngine)
    def _pandas(
        cls,
        execution_engine: "PandasExecutionEngine",
        metric_domain_kwargs: Dict,
        metric_value_kwargs: Dict,
        metrics: Dict[Tuple, Any],
        runtime_configuration: Dict,
    ) -> int:
        """Return the number of rows in the other CSV table (pandas engine)."""
        df = _load_other_table_dataframe(
            execution_engine, metric_domain_kwargs, reader_method="read_csv"
        )
        return df.shape[0]

    @metric_value(engine=SparkDFExecutionEngine)
    def _spark(
        cls,
        execution_engine: "SparkDFExecutionEngine",
        metric_domain_kwargs: Dict,
        metric_value_kwargs: Dict,
        metrics: Dict[Tuple, Any],
        runtime_configuration: Dict,
    ) -> int:
        """Return the number of rows in the other CSV table (Spark engine)."""
        df = _load_other_table_dataframe(
            execution_engine, metric_domain_kwargs, reader_method="csv"
        )
        return df.count()
class ExpectTableRowCountToBeMoreThanOthers(TableExpectation):
    """TableExpectation class to compare the row count of the current dataset to other dataset(s).

    Success keys:

    * ``other_table_filenames_list`` (required): a filename, or a non-empty
      list of filenames, whose row counts are computed with the
      ``table.row_count_other`` metric.
    * ``comparison_key`` (default ``"MEAN"``): case-insensitive name of a
      ``SupportedComparisonEnum`` member that selects how row counts are
      compared.
    * ``lower_percentage_threshold`` (default ``100``): strictly positive
      integer; the current row count must reach this percentage of the
      reference row count(s).
    """

    metric_dependencies = ("table.row_count", "table.row_count_other")
    success_keys = (
        "other_table_filenames_list",
        "comparison_key",
        "lower_percentage_threshold",
    )
    default_kwarg_values = {
        "other_table_filenames_list": None,
        "comparison_key": "MEAN",
        "lower_percentage_threshold": 100,
    }

    @staticmethod
    def _as_filename_list(other_table_filenames: Any) -> List[str]:
        """Return ``other_table_filenames`` as a list, wrapping a single string.

        The configuration accepts one filename or a list of filenames;
        iterating a bare string would yield its characters, so it is wrapped.
        """
        if isinstance(other_table_filenames, str):
            return [other_table_filenames]
        return list(other_table_filenames)

    @staticmethod
    def _other_row_count_metric_name(other_table_filename: str) -> str:
        """Return the metric key used for one other table's row count."""
        return f"table.row_count_other.{other_table_filename}"

    @staticmethod
    def _validate_success_key(
        param: str,
        required: bool,
        configuration: Optional[ExpectationConfiguration],
        validation_rules: Dict[Callable, str],
    ) -> None:
        """Apply ``validation_rules`` to ``configuration.kwargs[param]``.

        Args:
            param: Name of the kwarg to check.
            required: Whether a missing ``param`` is an error; a missing
                optional param is skipped.
            configuration: Configuration whose ``kwargs`` hold ``param``.
            validation_rules: Mapping of predicate -> error message. Rules run
                in insertion order, so type checks can guard later rules.

        Raises:
            InvalidExpectationKwargsError: If a required param is missing or a
                predicate returns a falsy value.
        """
        if param not in configuration.kwargs:
            if required:
                raise InvalidExpectationKwargsError(
                    f"Param {param} is required but was not found in configuration."
                )
            return
        param_value = configuration.kwargs[param]
        for rule, error_message in validation_rules.items():
            if not rule(param_value):
                raise InvalidExpectationKwargsError(error_message)

    def validate_configuration(
        self, configuration: Optional[ExpectationConfiguration]
    ) -> bool:
        """Validate the success kwargs of ``configuration`` (or ``self.configuration``).

        Returns:
            ``True`` when every check passes.

        Raises:
            InvalidExpectationKwargsError: On the first failed check.
        """
        super().validate_configuration(configuration=configuration)
        if configuration is None:
            configuration = self.configuration
        self._validate_success_key(
            param="other_table_filenames_list",
            required=True,
            configuration=configuration,
            validation_rules={
                lambda x: isinstance(
                    x, (str, list)
                ): "other_table_filenames_list should either be a list or a string.",
                lambda x: x: "other_table_filenames_list should not be empty",
            },
        )
        self._validate_success_key(
            param="comparison_key",
            required=False,
            configuration=configuration,
            validation_rules={
                lambda x: isinstance(x, str): "comparison_key should be a string.",
                lambda x: x.upper()
                in SupportedComparisonEnum.__members__: "Given comparison_key is not supported.",
            },
        )
        self._validate_success_key(
            param="lower_percentage_threshold",
            required=False,
            configuration=configuration,
            validation_rules={
                lambda x: isinstance(
                    x, int
                ): "lower_percentage_threshold should be an integer.",
                lambda x: x
                > 0: "lower_percentage_threshold should be strictly greater than 0.",
            },
        )
        return True

    def get_validation_dependencies(
        self,
        configuration: Optional[ExpectationConfiguration] = None,
        execution_engine: Optional[ExecutionEngine] = None,
        runtime_configuration: Optional[dict] = None,
    ) -> dict:
        """Declare one ``table.row_count_other`` metric per other table.

        The generic ``table.row_count_other`` dependency (which has no
        ``table_filename``) is used as a template: it is copied once per
        filename, with the filename set in its domain kwargs, and then
        dropped. ``table.row_count`` is re-keyed as ``table.row_count.self``.
        """
        # Mirror validate_configuration: fall back to the stored configuration.
        if configuration is None:
            configuration = self.configuration
        dependencies = super().get_validation_dependencies(
            configuration, execution_engine, runtime_configuration
        )
        metrics = dependencies["metrics"]
        template_metric_config = metrics.pop("table.row_count_other")
        other_table_filenames = self._as_filename_list(
            configuration.kwargs.get("other_table_filenames_list")
        )
        for other_table_filename in other_table_filenames:
            metric_config = deepcopy(template_metric_config)
            metric_config.metric_domain_kwargs["table_filename"] = other_table_filename
            metrics[self._other_row_count_metric_name(other_table_filename)] = metric_config
        metrics["table.row_count.self"] = metrics.pop("table.row_count")
        return dependencies

    def _validate(
        self,
        configuration: ExpectationConfiguration,
        metrics: Dict,
        runtime_configuration: dict = None,
        execution_engine: ExecutionEngine = None,
    ) -> Dict:
        """Compare the current row count with the other tables' row counts.

        Returns:
            A dict with ``success`` and a ``result`` holding the current row
            count (``self``), the other row counts in configuration order
            (``other``) and the comparison member name (``comparison_key``).
        """
        success_kwargs = self.get_success_kwargs(configuration)
        comparison_key = success_kwargs["comparison_key"]
        lower_percentage_threshold = success_kwargs["lower_percentage_threshold"]
        # A single filename string is allowed by validate_configuration; wrap
        # it so it is not iterated character by character.
        other_table_filenames = self._as_filename_list(
            success_kwargs["other_table_filenames_list"]
        )

        current_row_count = metrics["table.row_count.self"]
        other_row_counts = [
            metrics[self._other_row_count_metric_name(other_table_filename)]
            for other_table_filename in other_table_filenames
        ]

        comparison_key_fn = SupportedComparisonEnum[comparison_key.upper()]
        success_flag = comparison_key_fn(
            current_row_count, other_row_counts, lower_percentage_threshold
        )
        return {
            "success": success_flag,
            "result": {
                "self": current_row_count,
                "other": other_row_counts,
                "comparison_key": comparison_key_fn.name,
            },
        }
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment