@SHi-ON
Created March 10, 2019 17:08
An experiment to show how the stratify option works
# Experiment to confirm the effect of the stratify option of scikit-learn's train_test_split() function.
# by Shayan Amani
from sklearn.model_selection import train_test_split
import pandas as pd
raw_data = pd.read_csv("codebase/adrel/dataset/train.csv")
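cnt = raw_data.groupby('label').count()
# Assumed value: the original script never defines `seed`, so the outputs
# recorded below came from an unknown seed and will not match exactly.
seed = 42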
''' experiment begins '''
''' Part One: stratify is ON '''
train, validate = train_test_split(raw_data, test_size=0.1, random_state=seed, stratify=raw_data['label'])
tr = train.groupby('label').count()
for i in range(9):
    # share of each label's rows that ended up in the train split
    ratio = tr.iloc[i, 0] / cnt.iloc[i, 0]
    print(ratio)
    # assert that every label class in the train split holds about 90% of the raw data
    assert 0.89 < ratio < 0.91, 'Ratio is not following the rules {}'.format(i)
''' Output:
0.9000484027105518
0.9000853970964987
0.8999281781182668
0.9000229832222477
0.900049115913556
0.8998682476943346
0.8999274836838289
0.9000227221086117
0.9000738370662072
'''
''' Part Two: stratify is OFF'''
train, validate = train_test_split(raw_data, test_size=0.1, random_state=seed)
tr = train.groupby('label').count()
for i in range(9):
    # same check as in Part One; without stratify it is expected to fail
    ratio = tr.iloc[i, 0] / cnt.iloc[i, 0]
    print(ratio)
    assert 0.89 < ratio < 0.91, 'Ratio is not following the rules {}'.format(i)
''' Output:
0.9010164569215876
0.8936806148590948
0.8889154895858271
Traceback (most recent call last):
  File "/home/shi-on/anaconda3/envs/PyON36/lib/python3.6/site-packages/IPython/core/interactiveshell.py", line 3267, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-115-5835d3554147>", line 4, in <module>
    assert 0.89 < ratio < 0.91, 'Ratio is not following the rules {}'.format(i)
AssertionError: Ratio is not following the rules 2
'''
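
The contrast between the two parts can also be reproduced without the CSV file. The following is a minimal sketch using only the standard library and made-up labels. `plain_split`, `stratified_split` and `train_ratios` are hypothetical helpers written for this illustration, not scikit-learn functions, and the per-class rounding is an assumption rather than a description of what `train_test_split` does internally.

```python
import random
from collections import Counter, defaultdict


def plain_split(labels, test_size, seed):
    """Shuffle all row indices together and cut off a test_size share."""
    rng = random.Random(seed)
    idx = list(range(len(labels)))
    rng.shuffle(idx)
    n_test = round(len(idx) * test_size)
    return idx[n_test:], idx[:n_test]


def stratified_split(labels, test_size, seed):
    """Shuffle and cut each label group separately, so every class keeps the same share."""
    rng = random.Random(seed)
    by_label = defaultdict(list)
    for i, lab in enumerate(labels):
        by_label[lab].append(i)
    train, test = [], []
    for lab in sorted(by_label):
        members = by_label[lab]
        rng.shuffle(members)
        n_test = round(len(members) * test_size)
        test.extend(members[:n_test])
        train.extend(members[n_test:])
    return train, test


def train_ratios(labels, train_idx):
    """Per-label fraction of rows that landed in the train split."""
    total = Counter(labels)
    kept = Counter(labels[i] for i in train_idx)
    return {lab: kept[lab] / total[lab] for lab in sorted(total)}


# Synthetic, imbalanced labels (made up for illustration; not the gist's train.csv).
labels = [0] * 1000 + [1] * 200 + [2] * 50

strat_train, _ = stratified_split(labels, test_size=0.1, seed=0)
plain_train, _ = plain_split(labels, test_size=0.1, seed=0)

print("stratified:", train_ratios(labels, strat_train))
print("plain:     ", train_ratios(labels, plain_train))
```

In this sketch the stratified ratios are pinned to the requested train share for every label, because each class is cut on its own. The plain split only fixes the overall share, so individual classes, especially small ones, can drift above or below it, which is the kind of drift the assertion in Part Two trips over.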
@chayanroyc

This error is because some groups might only have one sample. Changing the grouping helps.
