Skip to content

Instantly share code, notes, and snippets.

@berendgort
Created March 23, 2022 12:45
Show Gist options
  • Save berendgort/0878367a8a1c00916d8bda618ad823c1 to your computer and use it in GitHub Desktop.
Save berendgort/0878367a8a1c00916d8bda618ad823c1 to your computer and use it in GitHub Desktop.
# Colormaps shared by the plotting helper: a qualitative palette for the
# per-sample class/group columns and a diverging one for the per-split
# train/test indicator columns.
cmap_data, cmap_cv = plt.cm.Paired, plt.cm.coolwarm
def plot_cv_indices(cv, X, y, group, ax, n_paths, k, paths, lw=5,
                    pred_times=None, eval_times=None):
    """Create a sample plot for indices of a cross-validation object.

    Draws one column of horizontal markers per split produced by ``cv``.
    Within a column, each sample is colour-coded on ``cmap_cv``: 0 for
    train, 1 for test, and 2 for samples the split does not use. Two more
    columns follow, showing ``y`` and ``group`` on ``cmap_data``.

    Parameters
    ----------
    cv : object
        Splitter whose ``split(X, y, pred_times=..., eval_times=...)``
        yields ``(train_indices, test_indices)`` pairs.
    X : array-like
        Samples. It is passed to ``cv.split``; otherwise only ``len(X)``
        is used.
    y, group : array-like
        Per-sample values colour-mapped in the "class" and "group" columns.
    ax : matplotlib Axes
        Axes to draw on.
    n_paths, k : int
        Used only to compute the expected split count ``C(n_paths + 1, k)``.
        A ``UserWarning`` is emitted if ``cv`` yields a different number.
    paths : object
        Unused; kept for backward compatibility.
    lw : float, default 5
        Marker line width.
    pred_times, eval_times : optional
        Forwarded to ``cv.split``. When ``None``, the module-level globals
        ``prediction_times`` / ``evaluation_times`` are used instead. This
        is the original behaviour.

    Returns
    -------
    The same ``ax`` that was passed in.
    """
    import warnings

    # Older callers relied on module-level globals; keep them as the fallback.
    if pred_times is None:
        pred_times = prediction_times
    if eval_times is None:
        eval_times = evaluation_times

    n_samples = len(X)
    rows = np.arange(n_samples)

    # Count the splits actually drawn rather than deriving the count from
    # n_paths/k. This keeps the class/group columns and the tick labels
    # aligned with the drawn columns, and avoids a NameError on an empty
    # splitter.
    n_splits = 0
    splits = cv.split(X, y, pred_times=pred_times, eval_times=eval_times)
    for ii, (tr, tt) in enumerate(splits):
        # 2 = unused by this split, 1 = test, 0 = train. Train is written
        # last, so a sample listed in both index sets ends up marked as train.
        indices = np.full(n_samples, 2.0)
        indices[tt] = 1
        indices[tr] = 0
        ax.scatter(
            np.full(n_samples, ii + 0.5),
            rows,
            c=[indices],
            marker="_",
            lw=lw,
            cmap=cmap_cv,
            vmin=-0.2,
            vmax=1.2,
        )
        n_splits = ii + 1

    n_expected = sum(1 for _ in itt.combinations(range(n_paths + 1), k))
    if n_splits != n_expected:
        warnings.warn(
            "cv produced {} splits but C(n_paths + 1, k) = {}; the plot "
            "follows the splits actually drawn.".format(n_splits, n_expected),
            stacklevel=2,
        )

    # Data classes and groups occupy the two columns after the last split.
    ax.scatter(
        np.full(n_samples, n_splits + 0.5),
        rows,
        c=y,
        marker="_",
        lw=lw,
        cmap=cmap_data,
    )
    ax.scatter(
        np.full(n_samples, n_splits + 1.5),
        rows,
        c=group,
        marker="_",
        lw=lw,
        cmap=cmap_data,
    )

    # Formatting. Split ii (drawn at x = ii + 0.5) is labelled S(n_splits - ii).
    # The x axis is inverted via xlim, so read left to right it shows
    # group, class, S1, ..., Sn.
    xticklabels = ["S{}".format(s) for s in range(n_splits, 0, -1)]
    xticklabels += ["class", "group"]
    ax.set(
        xticks=np.arange(n_splits + 2) + 0.45,
        xticklabels=xticklabels,
        ylabel="Sample index",
        xlabel="CV iteration",
        xlim=[n_splits + 2.2, -0.2],
        ylim=[0, n_samples],
    )
    ax.set_title(type(cv).__name__, fontsize=5)
    ax.xaxis.tick_top()
    return ax
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment