Last active
March 8, 2019 14:57
-
-
Save nokados/c09eb54a6fad8007aa65cd0f8baafb6f to your computer and use it in GitHub Desktop.
Analogue of sklearn's train_test_split for multilabel classification with stratification and shuffling. It also performs under-/over-sampling to flatten the distribution of class sizes.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import matplotlib.pyplot as plt | |
%matplotlib inline | |
def parallel_shuffle(*arrays):
    """Shuffle several arrays in lockstep with one shared random permutation.

    All arrays must have the same length along axis 0; rows that were aligned
    before the shuffle stay aligned after it (e.g. features and labels).
    Uses the global numpy RNG state.

    Returns a list with one permuted copy per input array.
    """
    n_rows = arrays[0].shape[0]
    for other in arrays[1:]:
        # Every array must align with the first one on axis 0.
        assert other.shape[0] == n_rows
    perm = np.random.permutation(n_rows)
    return [arr[perm] for arr in arrays]
def multi_strat_split(x_train, y_train, test_size=0.2, random_state=None):
    """Stratified train/test split for multilabel data.

    Analogue of sklearn's ``train_test_split`` that, for each class (processed
    from rarest to most common), tries to place roughly ``test_size`` of that
    class's positive samples into the test split. Prints the min/max achieved
    per-class test fraction as a diagnostic.

    Parameters
    ----------
    x_train : np.ndarray
        Feature array, samples along axis 0.
    y_train : array-like
        Binary indicator matrix of shape (n_samples, n_classes).
        # assumes entries are 0/1 — TODO confirm with callers
    test_size : float, default 0.2
        Fraction of samples to place in the test split.
    random_state : int or None
        Seed for the shuffle. NOTE: mutates the *global* numpy RNG state.

    Returns
    -------
    X_train_new, X_test_new, Y_train_new, Y_test_new : np.ndarray
        Shuffled splits; dtypes match the inputs.

    Raises
    ------
    AssertionError
        If the train split cannot be filled exactly (e.g. samples with no
        positive label are never assigned to the train split).
    """
    # Sizes: reuse the name test_size for the absolute count below.
    test_freq = test_size
    size = x_train.shape[0]
    train_size = int((1 - test_freq) * size)
    test_size = size - train_size
    # Shuffle features and labels together before selection.
    y = np.array(y_train)
    np.random.seed(random_state)
    x, y = parallel_shuffle(x_train, y)
    # Result arrays preserve the input dtypes (np.zeros without dtype would
    # silently upcast everything to float64).
    X_train_new = np.zeros((train_size, *x.shape[1:]), dtype=x.dtype)
    Y_train_new = np.zeros((train_size, *y.shape[1:]), dtype=y.dtype)
    X_test_new = np.zeros((test_size, *x.shape[1:]), dtype=x.dtype)
    Y_test_new = np.zeros((test_size, *y.shape[1:]), dtype=y.dtype)
    # Process classes from smallest to largest so rare classes get their
    # test quota before their samples are consumed by other classes.
    class_sizes = y.sum(axis=0)
    class_indices = np.argsort(class_sizes)
    # Choosing samples
    test_index = 0
    train_index = 0
    used_indices = set()
    for cls_id in class_indices:
        cls_size = class_sizes[cls_id]
        cls_train_size = int((1 - test_freq) * cls_size)
        cls_test_size = cls_size - cls_train_size
        # Some of this class's samples may already be in the test split via
        # earlier (rarer) co-occurring classes; only top up the difference.
        current_test_size = Y_test_new[:, cls_id].sum()
        diff = cls_test_size - current_test_size
        cls_samples_indices = np.argwhere(y[:, cls_id] == 1)
        # Iterate to add test samples.
        for ind in cls_samples_indices:
            ind = ind[0]
            if diff <= 0:
                break
            if test_index >= test_size:
                break
            if ind in used_indices:
                continue
            X_test_new[test_index] = x[ind]
            Y_test_new[test_index] = y[ind]
            test_index += 1
            used_indices.add(ind)
            diff -= 1
        # Iterate to add train samples (everything not taken for test).
        for ind in cls_samples_indices:
            ind = ind[0]
            if train_index >= train_size:
                break
            if ind in used_indices:
                continue
            X_train_new[train_index] = x[ind]
            Y_train_new[train_index] = y[ind]
            train_index += 1
            used_indices.add(ind)
    assert train_index == train_size
    # Pad the test split with any leftover samples (e.g. rows whose classes
    # were already fully covered).
    if test_index < test_size:
        unused_indices = set(range(x.shape[0])) - used_indices
        for ind in unused_indices:
            if test_index >= test_size:
                # Guard against overfilling; the earlier loops had this
                # bound check but the pad loop did not.
                break
            X_test_new[test_index] = x[ind]
            Y_test_new[test_index] = y[ind]
            test_index += 1
            used_indices.add(ind)
    assert test_index == test_size
    # Diagnostics: achieved per-class test fraction vs. the requested one.
    test_parts = Y_test_new.sum(axis=0) / class_sizes
    print('Min test_part: ', test_parts.min(), ' at index ', test_parts.argmin())
    print('Max test_part: ', test_parts.max(), ' at index ', test_parts.argmax())
    # Final shuffle so splits are not ordered by class.
    X_train_new, Y_train_new = parallel_shuffle(X_train_new, Y_train_new)
    X_test_new, Y_test_new = parallel_shuffle(X_test_new, Y_test_new)
    return X_train_new, X_test_new, Y_train_new, Y_test_new
def flat_sampling(x_train, y_train, max_quantile = 0.85, max_div_min=5, seed=None):
    """Under/over-sample a multilabel dataset to flatten the class-size distribution.

    Classes are processed from smallest to largest. Classes larger than the
    size found at *max_quantile* of the sorted class sizes are down-sampled to
    that size; classes smaller than ``max_size // max_div_min`` are
    over-sampled with replacement up to it. Plots before/after class-size bar
    charts via matplotlib and prints summary statistics.

    Parameters
    ----------
    x_train : array-like
        Feature array, samples along axis 0.
    y_train : array-like
        Binary indicator matrix of shape (n_samples, n_classes).
        # assumes entries are 0/1 — TODO confirm with callers
    max_quantile : float, default 0.85
        Quantile of the sorted class sizes that defines the cap ``max_size``.
    max_div_min : int, default 5
        Ratio cap: the target minimum class size is ``max_size // max_div_min``.
    seed : int or None
        Seed for sampling. NOTE: mutates the *global* numpy RNG state.

    Returns
    -------
    list of two np.ndarray
        Shuffled [X, y] of the resampled data (float64, since the
        accumulators are built with np.zeros).
    """
    np.random.seed(seed)
    x_train = np.array(x_train); y_train=np.array(y_train)
    class_sizes = y_train.sum(axis=0)
    print(f'BEFORE: Max size {class_sizes.max()}. Min size: {class_sizes.min()}. Total samples: {x_train.shape[0]}')
    class_indices = np.argsort(class_sizes)
    plt.bar(range(len(class_sizes)), class_sizes[class_indices])
    # Accumulators start empty; rows are appended per class via updateXy.
    X = np.zeros((0, *x_train.shape[1:]))
    y = np.zeros((0, *y_train.shape[1:]))
    def updateXy(indices):
        # Append the selected rows of the ORIGINAL data to the accumulators.
        nonlocal X
        nonlocal y
        if len(indices) == 0:
            return
        X = np.concatenate((X, x_train[indices]), axis=0)
        y = np.concatenate((y, y_train[indices]), axis=0)
    # Cap = class size at the max_quantile position of the sorted sizes.
    max_size = int(class_sizes[class_indices[int(len(class_sizes) * max_quantile)]])
    min_size = max(1, max_size // max_div_min)
    print(f'Expected AFTER: Max size {max_size}. Min size: {min_size}')
    used_indices = set()
    for cls_id in class_indices:
        cls_samples_indices = np.argwhere(y_train[:, cls_id] == 1)[:,0]
        # How many positives of this class the accumulator already holds
        # (via rows added for earlier, co-occurring classes).
        actual_size = int(y[:, cls_id].sum())
        unused_indices = np.array([ind for ind in cls_samples_indices if ind not in used_indices])
        if class_sizes[cls_id] < min_size:
            # Rare class: take every sample, then oversample with
            # replacement to reach min_size.
            updateXy(cls_samples_indices)
            add_num = max(0, int(min_size - class_sizes[cls_id] - actual_size))
            additional_indices = np.random.choice(cls_samples_indices, add_num)
            updateXy(additional_indices)
        elif class_sizes[cls_id] > max_size or len(unused_indices) + actual_size > max_size:
            # Over-represented class: down-sample the still-unused rows so
            # the total (already-present + new) does not exceed max_size.
            if max_size <= actual_size:
                continue
            indices = np.random.choice(unused_indices, max_size - actual_size, replace=False)
            updateXy(indices)
        else:
            # Mid-sized class: keep all rows not already added.
            updateXy(unused_indices)
        used_indices |= set(cls_samples_indices)
    assert X.shape[0] == y.shape[0]
    class_sizes = y.sum(axis=0)
    print(f'Actual AFTER: Max size {class_sizes.max()}. Min size: {class_sizes.min()}. Total samples: {X.shape[0]}')
    # NOTE(review): second bar chart reuses the ordering of the ORIGINAL
    # class sizes (class_indices), overlaid on the first plot.
    plt.bar(range(len(class_sizes)), class_sizes[class_indices], alpha=0.5)
    plt.show()
    return parallel_shuffle(X, y)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment