Last active
August 28, 2024 17:25
-
-
Save jpmallette/8cdfda4f1734ad8c7174b3ffe6c0d416 to your computer and use it in GitHub Desktop.
Execute Cross Validation and Performance Loop
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def execute_cross_validation_and_performance_loop(cross_valid_params, metric='mse'):
    """Execute cross-validation and compute performance metrics per configuration.

    Complements the Medium blog post "Prophet Auto Selection with
    Cross-Validation":
    https://medium.com/@jeanphilippemallette/prophet-auto-selection-with-cross-validation-7ba2c0a3beef

    Parameters
    ----------
    cross_valid_params : list of dict
        Each dict is passed as keyword arguments to ``cross_validation``
        (for example ``model``, ``horizon``, ``period`` and ``initial``).
        ``initial`` and ``period`` may be omitted; they are then recorded
        as ``None`` in the result.
    metric : str, default 'mse'
        Performance metric used to sort the result in ascending order.
        Must be one of 'mse', 'rmse', 'mae' or 'mape'.

    Returns
    -------
    pd.DataFrame
        One row per configuration, with columns ``initial``, ``horizon``,
        ``period``, ``mse``, ``rmse``, ``mae``, ``mape`` and ``coverage``,
        sorted in ascending order by ``metric``. An empty
        ``cross_valid_params`` yields an empty DataFrame with those columns.

    Raises
    ------
    ValueError
        If ``metric`` is not one of the supported metrics.

    Examples
    --------
    >>> m = Prophet()
    >>> df = pd.read_csv('/examples/example_wp_log_peyton_manning.csv')
    >>> m.fit(df)
    >>> cross_valid_params = [{'model': m,
    ...                        'initial': '730 days',
    ...                        'period': '180 days',
    ...                        'horizon': '365 days'},
    ...                       {'model': m,
    ...                        'initial': '500 days',
    ...                        'period': '180 days',
    ...                        'horizon': '365 days'}]
    >>> execute_cross_validation_and_performance_loop(cross_valid_params)
    index initial horizon period mse rmse mae mape coverage
    4332 500 days 365 days 180 days 0.663628 0.814634 0.627102 0.075824 0.572352
    3987 730 days 365 days 180 days 0.670460 0.818816 0.628407 0.075577 0.589017
    """
    valid_metrics = ('mse', 'rmse', 'mae', 'mape')
    # Explicit check instead of ``assert`` so the validation still runs
    # when Python is started with ``-O``.
    if metric not in valid_metrics:
        raise ValueError('metric must be either mse, rmse, mae or mape')

    columns = ['initial', 'horizon', 'period', 'mse',
               'rmse', 'mae', 'mape', 'coverage']

    frames = []
    for cross_valid_param in cross_valid_params:
        df_cv = cross_validation(**cross_valid_param)
        # rolling_window=1 computes each metric over the full horizon,
        # which yields one summary row per configuration.
        df_p = performance_metrics(df_cv, rolling_window=1)
        # ``initial`` and ``period`` are optional for cross_validation, so
        # record None rather than failing when the caller omitted them.
        df_p['initial'] = cross_valid_param.get('initial')
        df_p['period'] = cross_valid_param.get('period')
        frames.append(df_p)

    if not frames:
        # pd.concat([]) raises, so return an empty, correctly-shaped frame.
        return pd.DataFrame(columns=columns)

    # DataFrame.append was removed in pandas 2.0. Concatenating once also
    # avoids re-copying the accumulated frame on every loop iteration.
    df_ps = pd.concat(frames)
    return df_ps[columns].sort_values(metric)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment