Created
December 25, 2020 14:11
-
-
Save moaminsharifi/d8a706a123fc5691fdb34e74aefad667 to your computer and use it in GitHub Desktop.
Dataset Train and Test split
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def separate(X, y, train_percent=70):
    """Split a dataset into train and test parts, class by class.

    Every class (label) found in ``y`` contributes roughly
    ``train_percent`` percent of its samples to the train part and the
    remainder to the test part, so both parts hold a fair share of each
    class. Samples are NOT shuffled: within each class the first samples
    (in original order) go to the train part and the rest to the test part.

    Parameters
    ----------
    X : numpy array or list
        Features of the dataset, indexed along the first axis by sample.
    y : numpy array or list
        Label of each sample, in the same order as ``X``.
    train_percent : int
        Percentage of each class assigned to the train part, between
        1 and 99 inclusive.

    Returns
    -------
    tuple
        ``(X_train, y_train, X_test, y_test)`` as numpy arrays, where
        ``X_train`` and ``y_train`` hold about ``train_percent``% of the
        dataset and ``X_test`` and ``y_test`` hold about
        ``100 - train_percent``% of it.

    Raises
    ------
    AssertionError
        If ``train_percent`` is outside 1..99, if ``y`` contains no class
        at all, or if some class is too small to give at least one sample
        to both the train and the test part.
    """
    assert 1 <= train_percent <= 99, "at least train_percent must be one and maximum 99"

    # Lists are documented as valid input, but element-wise comparison and
    # index-list selection only work on arrays, so convert up front.
    X = np.asarray(X)
    y = np.asarray(y)

    unique_class = np.unique(y)
    # The condition is kept as in the original (>= 1) for compatibility;
    # only the message was corrected to match what is actually checked.
    assert len(unique_class) >= 1, "must be at least one class"

    # Map each class to the positions of its samples in y.
    indexes_by_class = {
        class_name: np.where(y == class_name)[0] for class_name in unique_class
    }

    train_set_index = []
    test_set_index = []
    for y_class in unique_class:
        class_indexes = indexes_by_class[y_class]
        class_size = len(class_indexes)
        print(f"count of {y_class} label is : {class_size}")

        # Multiply before dividing so integer inputs stay in exact integer
        # arithmetic; (size / 100) * percent can land just below a whole
        # number (e.g. 58 samples at 50% -> 28.999...) and truncate wrongly.
        count_of_train = int((class_size * train_percent) // 100)
        count_of_test = class_size - count_of_train
        assert count_of_train >= 1 and count_of_test >= 1, "one of the sets is zero member"

        train_set_index.extend(class_indexes[:count_of_train].tolist())
        test_set_index.extend(class_indexes[count_of_train:].tolist())

    X_train, y_train = X[train_set_index], y[train_set_index]
    X_test, y_test = X[test_set_index], y[test_set_index]
    print(f"train set is {len(train_set_index)} \ntesting set is {len(test_set_index)}")
    return (X_train, y_train, X_test, y_test)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment