@SHi-ON
Created March 10, 2019 17:08
An experiment to show how the stratify option works
# Experiment to confirm the effect of the stratify option in scikit-learn's train_test_split() method.
# by Shayan Amani
from sklearn.model_selection import train_test_split
import pandas as pd
raw_data = pd.read_csv("codebase/adrel/dataset/train.csv")
cnt = raw_data.groupby('label').count()
''' experiment begins '''
''' Part One: stratify is ON '''
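seed = 0  # assumption: the original never defines seed, so the exact ratios below will not reproduce
train, validate = train_test_split(raw_data, test_size=0.1, random_state=seed, stratify=raw_data['label'])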
tr = train.groupby('label').count()
for i in range(len(cnt)):  # one iteration per label class (9 in this dataset)
    ratio = tr.iloc[i, 0] / cnt.iloc[i, 0]
    print(ratio)
    # assert that every label class keeps ~90% of its raw rows in the training set
    assert 0.89 < ratio < 0.91, 'Ratio is not following the rules {}'.format(i)
''' Output:
0.9000484027105518
0.9000853970964987
0.8999281781182668
0.9000229832222477
0.900049115913556
0.8998682476943346
0.8999274836838289
0.9000227221086117
0.9000738370662072
'''
''' Part Two: stratify is OFF'''
train, validate = train_test_split(raw_data, test_size=0.1, random_state=seed)
tr = train.groupby('label').count()
for i in range(len(cnt)):  # one iteration per label class (9 in this dataset)
    ratio = tr.iloc[i, 0] / cnt.iloc[i, 0]
    print(ratio)
    assert 0.89 < ratio < 0.91, 'Ratio is not following the rules {}'.format(i)
''' Output:
0.9010164569215876
0.8936806148590948
0.8889154895858271
Traceback (most recent call last):
  File "/home/shi-on/anaconda3/envs/PyON36/lib/python3.6/site-packages/IPython/core/interactiveshell.py", line 3267, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-115-5835d3554147>", line 4, in <module>
    assert 0.89 < ratio < 0.91, 'Ratio is not following the rules {}'.format(i)
AssertionError: Ratio is not following the rules 2
'''
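The experiment above depends on the author's local train.csv. Below is a minimal, self-contained sketch of the same comparison on synthetic data; the nine class sizes, the dummy feature column, and random_state=0 are illustrative assumptions, not values from the original dataset. With stratify, each class should keep 90% of its rows in training (to within one sample); without it, the per-class ratios depend on the random draw.

```python
from collections import Counter

import numpy as np
from sklearn.model_selection import train_test_split

# Synthetic, imbalanced labels standing in for the CSV's 'label' column
# (assumption: the real dataset is not available). Nine classes of different sizes.
sizes = [500, 300, 200, 150, 120, 100, 80, 60, 40]
labels = np.repeat(np.arange(len(sizes)), sizes)
X = np.arange(len(labels)).reshape(-1, 1)  # dummy single-feature matrix


def train_ratios(stratify):
    """Return, per class, the fraction of its samples that land in the training split."""
    _, _, y_train, _ = train_test_split(
        X,
        labels,
        test_size=0.1,
        random_state=0,
        stratify=labels if stratify else None,
    )
    total = Counter(labels.tolist())
    kept = Counter(y_train.tolist())
    return {c: kept[c] / total[c] for c in sorted(total)}


with_strat = train_ratios(stratify=True)
without_strat = train_ratios(stratify=False)
for c in with_strat:
    # class id, train ratio with stratify, train ratio without stratify
    print(c, round(with_strat[c], 4), round(without_strat[c], 4))
```

Because stratification allocates train/test counts per class, train_test_split also needs every class to have at least two members, which is the source of the ValueError discussed in the comments below.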
@ShojibDE

Hi, I applied your approach to my rating data: "train_data, test_data = train_test_split(rating_data, test_size=test_size, stratify= rating_data['reviewerID'])", but it gives the following error: "ValueError: The least populated class in y has only 1 member, which is too few. The minimum number of groups for any class cannot be less than 2." Is there any way I can use the same function to split my rating data into training and test sets so that 80% of each user's reviews go to the training set and 20% to the test set? Thank you in advance!

I'm also getting the same error.

@chayanroyc

This error occurs because some classes (here, some reviewerID values) have only one sample, while stratify requires at least two members per class. Changing the grouping helps.
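For the per-user 80/20 split asked about above, one alternative to stratify is to sample within each user's group directly. The sketch below is an assumption-laden illustration, not the gist author's method: it uses pandas' GroupBy.sample (available from pandas 1.1), a made-up toy DataFrame whose reviewerID column mirrors the question and whose rating column is hypothetical, and it assumes the DataFrame index is unique. Since sample rounds frac * group_size per user, a user with a single review contributes zero test rows and stays in training, so small users only approximately follow the 80/20 proportion.

```python
import pandas as pd

# Hypothetical toy rating data: 'reviewerID' comes from the question above,
# 'rating' is an illustrative column. User "C" has a single review, which is
# exactly the case that makes stratify=rating_data['reviewerID'] fail.
rating_data = pd.DataFrame({
    "reviewerID": ["A"] * 10 + ["B"] * 5 + ["C"],
    "rating": [5, 4, 3, 5, 2, 4, 4, 1, 3, 5, 2, 3, 4, 5, 1, 4],
})

# Sample ~20% of each user's rows for the test set. With frac=, pandas draws
# round(frac * group_size) rows per group, so user "C" contributes round(0.2) == 0.
test = rating_data.groupby("reviewerID").sample(frac=0.2, random_state=0)

# Every row not sampled into the test set goes to training (assumes a unique index).
train = rating_data.drop(test.index)

print(train["reviewerID"].value_counts())
print(test["reviewerID"].value_counts())
```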
