Created
September 18, 2020 20:03
-
-
Save NicolasHug/2db607b01482988fa549eb2c8770f79f to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import numpy as np\n", | |
"from sklearn.utils.validation import check_random_state\n", | |
"from sklearn.model_selection import KFold\n", | |
"from sklearn.utils import shuffle" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"# Changes to Estimators and `clone()`:\n", | |
"\n", | |
"- A random seed is drawn in `__init__()` (ints are kept as-is) and stored as `random_state`. `set_params()` is updated accordingly.\n", | |
"- `clone()` can now explicitly create either strict clones or statistical clones.\n", | |
"- `fit()` and `get_params()` are unchanged" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def _sample_seed(random_state):\n", | |
"    # Turn `random_state` into a fixed integer seed, to be stored as the\n", | |
"    # random_state attribute.\n", | |
"    # - ints (Python or numpy integers) are passed through, as Python ints\n", | |
"    # - None or a RandomState instance is used to draw a fresh seed in\n", | |
"    #   [0, 2**32); a RandomState instance passed in is advanced by the draw\n", | |
"    if isinstance(random_state, (int, np.integer)):\n", | |
"        return int(random_state)\n", | |
"    rng = check_random_state(random_state)\n", | |
"    # explicit int64 dtype: the default `int` dtype is 32 bits on some\n", | |
"    # platforms (e.g. Windows with numpy < 2), where 2**32 is out of bounds\n", | |
"    return int(rng.randint(0, 2**32, dtype=np.int64))\n", | |
"\n", | |
"\n", | |
"class Estimator:\n", | |
"    \"\"\"Toy estimator whose only parameter is random_state.\"\"\"\n", | |
"\n", | |
"    def __init__(self, random_state=None):\n", | |
"        # the seed is fixed once here (ints are kept as-is), so every later\n", | |
"        # call to fit() uses the same RNG\n", | |
"        self.random_state = _sample_seed(random_state)\n", | |
"\n", | |
"    def fit(self, X=None, y=None):\n", | |
"        # unchanged: a new RandomState is built from the stored seed on each\n", | |
"        # call, which makes repeated calls to fit() behave identically\n", | |
"        rng = check_random_state(self.random_state)\n", | |
"        print(rng.randint(0, 100))\n", | |
"        return self\n", | |
"\n", | |
"    def get_params(self):\n", | |
"        # unchanged\n", | |
"        return {'random_state': self.random_state}\n", | |
"\n", | |
"    def set_params(self, random_state=None):\n", | |
"        # same seeding rule as __init__; returns self (as scikit-learn's\n", | |
"        # set_params does) so that calls can be chained\n", | |
"        self.random_state = _sample_seed(random_state)\n", | |
"        return self\n", | |
"\n", | |
"    def score(self, X, y):\n", | |
"        return 0  # irrelevant\n", | |
"\n", | |
"\n", | |
"def _check_statistical_clone_possible(est):\n", | |
"    \"\"\"Raise ValueError unless `est` exposes a random_state parameter.\"\"\"\n", | |
"    if 'random_state' not in est.get_params():\n", | |
"        raise ValueError(\"This estimator isn't random and can only have exact clones\")\n", | |
"\n", | |
"\n", | |
"def clone(est, statistical=False):\n", | |
"    \"\"\"Return a strict clone or a statistical clone of `est`.\n", | |
"\n", | |
"    `statistical` can be:\n", | |
"    - False: a strict clone is returned (same random_state as `est`).\n", | |
"    - True: a statistical clone is returned. Its seed is drawn from an RNG\n", | |
"      seeded by `est`'s random_state (deterministic when that is an int).\n", | |
"    - None, int, or RandomState instance: a statistical clone is returned.\n", | |
"      Its seed is drawn from `statistical`. This is useful to create\n", | |
"      multiple statistical clones that don't have the same RNG.\n", | |
"\n", | |
"    Raises ValueError when a statistical clone is requested for an estimator\n", | |
"    without a random_state parameter.\n", | |
"    \"\"\"\n", | |
"    params = est.get_params()\n", | |
"\n", | |
"    if statistical is not False:\n", | |
"        # A statistical clone is a clone with a different random_state attribute\n", | |
"        _check_statistical_clone_possible(est)\n", | |
"        rng = params['random_state'] if statistical is True else statistical\n", | |
"        # check_random_state turns an int into a RandomState first, so that a\n", | |
"        # *new* seed is drawn from it rather than the int being passed through\n", | |
"        params['random_state'] = _sample_seed(check_random_state(rng))\n", | |
"\n", | |
"    return est.__class__(**params)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"# Illustration of estimators behavior" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"25\n", | |
"25\n", | |
"95\n", | |
"95\n" | |
] | |
} | |
], | |
"source": [ | |
"# Multiple calls to fit on the same instance use the same RNG:\n", | |
"# the seed is drawn once in __init__, so fit is truly idempotent\n", | |
"\n", | |
"a = Estimator(random_state=None).fit().fit()\n", | |
"b = Estimator(random_state=None).fit().fit()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"exact clones:\n", | |
"25\n", | |
"95\n", | |
"statistical clones (different RNGs):\n", | |
"24\n", | |
"30\n", | |
"statistical clones with random_state=int: Different RNG can still be obtained\n", | |
"44\n", | |
"63\n" | |
] | |
}, | |
{ | |
"data": { | |
"text/plain": [ | |
"<__main__.Estimator at 0x7f3d17b9fd90>" | |
] | |
}, | |
"execution_count": 4, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# Users can explicitly create exact and statistical clones\n", | |
"# Exact clones can be obtained even if None/instances are passed (this is impossible in master)\n", | |
"\n", | |
"print(\"exact clones:\")\n", | |
"# a strict clone copies the seed, so it prints the same value as the original's fit()\n", | |
"clone(a).fit()\n", | |
"clone(b).fit()\n", | |
"\n", | |
"print(\"statistical clones (different RNGs):\")\n", | |
"# a statistical clone gets a new seed, drawn from an RNG seeded by the original's seed\n", | |
"clone(a, statistical=True).fit()\n", | |
"clone(b, statistical=True).fit()\n", | |
"\n", | |
"# Also, statistical clones can be obtained even if ints are passed.\n", | |
"# In master, None/instances can only give statistical clones, and ints can only give exact clones\n", | |
"\n", | |
"print(\"statistical clones with random_state=int: Different RNG can still be obtained\")\n", | |
"with_int = Estimator(random_state=0)\n", | |
"with_int.fit()\n", | |
"clone(with_int, statistical=True).fit()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"a's RNG is unchanged\n", | |
"25\n", | |
"set a's RNG to that of b\n", | |
"95\n" | |
] | |
}, | |
{ | |
"data": { | |
"text/plain": [ | |
"<__main__.Estimator at 0x7f3d4821ed60>" | |
] | |
}, | |
"execution_count": 5, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# Using set_params and get_params lets an estimator get the exact same RNG as another estimator\n", | |
"# (the int seed returned by get_params is kept as-is by set_params)\n", | |
"\n", | |
"print(\"a's RNG is unchanged\")\n", | |
"a.set_params(random_state=a.get_params()['random_state'])\n", | |
"a.fit()\n", | |
"\n", | |
"print(\"set a's RNG to that of b\")\n", | |
"a.set_params(random_state=b.get_params()['random_state'])\n", | |
"a.fit()" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"# CV routines: Users now have explicit control on the CV strategy" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Example of what CV routines would look like.\n", | |
"# The behaviour of the CV procedure is now explicit, and doesn't depend on the estimator's random_state\n", | |
"# Use-cases C and D are supported by any estimator.\n", | |
"\n", | |
"def cross_val_score(est, X, y, cv, use_exact_clones=True):\n", | |
"    \"\"\"Fit and score a clone of `est` on each train/test split of `cv`.\n", | |
"\n", | |
"    use_exact_clones:\n", | |
"    - if True, each fold uses a strict clone, i.e. the same estimator RNG\n", | |
"      on every fold (use-case C)\n", | |
"    - if False, each fold uses a statistical clone, i.e. a different\n", | |
"      estimator RNG on each fold (use-case D). `est` must then have a\n", | |
"      random_state parameter, otherwise ValueError is raised.\n", | |
"\n", | |
"    TODO: maybe the default should be 'auto': False if the estimator has a\n", | |
"    random_state, True otherwise.\n", | |
"    \"\"\"\n", | |
"    if use_exact_clones:\n", | |
"        statistical = False\n", | |
"    else:\n", | |
"        # fail early, before reading random_state, for non-random estimators\n", | |
"        _check_statistical_clone_possible(est)\n", | |
"        # need a local RNG, shared across folds, so that clones have different\n", | |
"        # random_state attributes; read through get_params() like clone() does\n", | |
"        statistical = check_random_state(est.get_params()['random_state'])\n", | |
"\n", | |
"    return [  # this whole part is unchanged except for the call to clone()\n", | |
"        clone(est, statistical=statistical)\n", | |
"        .fit(X[train], y[train])\n", | |
"        .score(X[test], y[test])\n", | |
"        for train, test in cv.split(X, y)\n", | |
"    ]\n", | |
"\n", | |
"X = y = np.arange(10)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Constant estimator RNG across folds, different estimator RNG across executions\n", | |
"19\n", | |
"19\n", | |
"19\n", | |
"19\n", | |
"19\n" | |
] | |
} | |
], | |
"source": [ | |
"# use_exact_clones=True: every fold fits a strict clone, i.e. the same seed\n", | |
"print(\"Constant estimator RNG across folds, different estimator RNG across executions\")\n", | |
"_ = cross_val_score(Estimator(random_state=None), X, y, cv=KFold(), use_exact_clones=True)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Different estimator RNG across folds, different estimator RNG across executions\n", | |
"76\n", | |
"5\n", | |
"1\n", | |
"49\n", | |
"72\n" | |
] | |
} | |
], | |
"source": [ | |
"# use_exact_clones=False: every fold fits a statistical clone, i.e. a new seed per fold\n", | |
"print(\"Different estimator RNG across folds, different estimator RNG across executions\")\n", | |
"_ = cross_val_score(Estimator(random_state=None), X, y, cv=KFold(), use_exact_clones=False)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 9, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Constant estimator RNG across folds, constant estimator RNG across executions\n", | |
"44\n", | |
"44\n", | |
"44\n", | |
"44\n", | |
"44\n" | |
] | |
} | |
], | |
"source": [ | |
"# random_state=0 is kept as the seed, and the strict clones reuse it on every fold\n", | |
"print(\"Constant estimator RNG across folds, constant estimator RNG across executions\")\n", | |
"_ = cross_val_score(Estimator(random_state=0), X, y, cv=KFold(), use_exact_clones=True)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 10, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Different estimator RNG across folds, constant estimator RNG across executions\n", | |
"63\n", | |
"82\n", | |
"89\n", | |
"93\n", | |
"34\n" | |
] | |
} | |
], | |
"source": [ | |
"# the per-fold seeds are drawn from an RNG seeded with 0, so they repeat across executions\n", | |
"print(\"Different estimator RNG across folds, constant estimator RNG across executions\")\n", | |
"_ = cross_val_score(Estimator(random_state=0), X, y, cv=KFold(), use_exact_clones=False)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"# Changes to CV Splitters\n", | |
"\n", | |
"The same change as for estimators applies: a seed is drawn in `__init__`.\n", | |
"\n", | |
"`split` is unchanged" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 11, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"class TwoKFold:\n", | |
"    \"\"\"Toy CV class that does shuffled 2-fold CV\"\"\"\n", | |
"    def __init__(self, random_state=None):\n", | |
"        # as for estimators: a seed is fixed once here (ints are kept as-is)\n", | |
"        self.random_state = _sample_seed(random_state)\n", | |
"\n", | |
"    def split(self, X, y=None):\n", | |
"        # Unchanged. A new RandomState is built from the stored seed on each\n", | |
"        # call, so repeated calls yield the same splits.\n", | |
"        rng = check_random_state(self.random_state)\n", | |
"\n", | |
"        n_samples = X.shape[0]\n", | |
"        indices = shuffle(np.arange(n_samples), random_state=rng)\n", | |
"        mid = n_samples // 2\n", | |
"\n", | |
"        # each half is used once as the training set and once as the test set\n", | |
"        yield indices[:mid], indices[mid:]\n", | |
"        yield indices[mid:], indices[:mid]\n", | |
"\n", | |
"X = np.arange(10)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"# Illustration of CV Splitters behavior" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 12, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Multiple calls to split yield the same splits:\n", | |
"[(array([2, 1, 0, 8, 9]), array([4, 5, 3, 7, 6])), (array([4, 5, 3, 7, 6]), array([2, 1, 0, 8, 9]))]\n", | |
"[(array([2, 1, 0, 8, 9]), array([4, 5, 3, 7, 6])), (array([4, 5, 3, 7, 6]), array([2, 1, 0, 8, 9]))]\n" | |
] | |
} | |
], | |
"source": [ | |
"# the seed is drawn once in __init__, so every call to split() reuses it\n", | |
"print(\"Multiple calls to split yield the same splits:\")\n", | |
"cv = TwoKFold(random_state=None)\n", | |
"print(list(cv.split(X)))\n", | |
"print(list(cv.split(X)))" | |
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.8.2" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment