Created
July 24, 2024 10:31
-
-
Save thomasthaddeus/84c83896827f6ca0f02361e33d0b165e to your computer and use it in GitHub Desktop.
K-Means Clustering Analysis
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# K-Means diagnostic plots: three side-by-side panels built from the example
# data below.
#   1. Elbow method: within-cluster sum of squares (WSS) vs. number of clusters.
#   2. Silhouette score vs. number of dimensions.
#   3. Explained variance ratio vs. number of dimensions.
# Each panel gets a dashed vertical marker and an arrow annotation at its
# point of interest; the figure is then displayed with fig.show().
import plotly.graph_objects as go
import plotly.subplots as sp

# Example data
k_values = list(range(1, 11))
wss_values = [1000, 800, 650, 500, 450, 400, 350, 330, 320, 310]  # Elbow plot data
silhouette_scores = [
    0.1,
    0.3,
    0.45,
    0.52,
    0.51,
    0.50,
    0.48,
    0.47,
    0.46,
    0.45,
]  # Silhouette scores data
dimensions = list(range(1, 11))
explained_variance_ratio = [
    0.1,
    0.18,
    0.25,
    0.30,
    0.35,
    0.38,
    0.40,
    0.42,
    0.43,
    0.44,
]  # Explained variance data

# Points of interest. The elbow and the variance plateau are judgement calls
# read off the curves, so they are set explicitly here (and only here). The
# silhouette point is by definition the highest score, so it is computed;
# max() returns the first maximum on ties.
elbow_k = 4
variance_plateau_dim = 4
best_silhouette_dim = max(
    zip(dimensions, silhouette_scores), key=lambda pair: pair[1]
)[0]

# Anchor each annotation arrow on the plotted curve itself so the label,
# the dashed marker and the arrow always agree with the data.
elbow_wss = wss_values[k_values.index(elbow_k)]
best_silhouette = silhouette_scores[dimensions.index(best_silhouette_dim)]
plateau_variance = explained_variance_ratio[dimensions.index(variance_plateau_dim)]

# Create subplots
fig = sp.make_subplots(
    rows=1,
    cols=3,
    subplot_titles=(
        "Elbow Method for Optimal k",
        "Silhouette Score vs. Number of Dimensions",
        "Variance Explained vs. Dimensions",
    ),
)

# Elbow Method Plot
fig.add_trace(
    go.Scatter(x=k_values, y=wss_values, mode="lines+markers", name="WSS"),
    row=1,
    col=1,
)
fig.add_vline(x=elbow_k, line_dash="dash", line_color="red", row=1, col=1)
fig.add_annotation(
    x=elbow_k,
    y=elbow_wss,
    # The k in the label comes from the same value as the marker position.
    text=(
        f"Elbow Point (k={elbow_k})<br>WSS not reduced"
        "<br>significantly after this point"
    ),
    showarrow=True,
    arrowhead=2,
    ax=100,
    ay=-100,
    row=1,
    col=1,
)

# Silhouette Score vs. Number of Dimensions
fig.add_trace(
    go.Scatter(
        x=dimensions, y=silhouette_scores, mode="lines+markers", name="Silhouette Score"
    ),
    row=1,
    col=2,
)
fig.add_vline(
    x=best_silhouette_dim, line_dash="dash", line_color="green", row=1, col=2
)
fig.add_annotation(
    x=best_silhouette_dim,
    y=best_silhouette,
    text="Highest Silhouette Score<br>Indicates best clustering",
    showarrow=True,
    arrowhead=2,
    ax=-75,
    ay=-75,
    row=1,
    col=2,
)

# Variance vs. Dimensions
fig.add_trace(
    go.Scatter(
        x=dimensions,
        y=explained_variance_ratio,
        mode="lines+markers",
        name="Explained Variance Ratio",
    ),
    row=1,
    col=3,
)
fig.add_vline(
    x=variance_plateau_dim, line_dash="dash", line_color="blue", row=1, col=3
)
fig.add_annotation(
    x=variance_plateau_dim,
    y=plateau_variance,
    text="Variance Levels Off<br>Explained variance<br>increase is marginal",
    showarrow=True,
    arrowhead=2,
    ax=-75,
    ay=-75,
    row=1,
    col=3,
)

# Update layout
fig.update_layout(
    height=600,
    width=1800,
    title_text="Analysis Plots for K-Means Clustering",
    showlegend=False,
)
fig.update_xaxes(title_text="Number of Clusters (k)", row=1, col=1)
fig.update_yaxes(title_text="Within-Cluster Sum of Squares (WSS)", row=1, col=1)
fig.update_xaxes(title_text="Number of Dimensions", row=1, col=2)
fig.update_yaxes(title_text="Silhouette Score", row=1, col=2)
fig.update_xaxes(title_text="Number of Dimensions", row=1, col=3)
fig.update_yaxes(title_text="Explained Variance Ratio", row=1, col=3)

# Show the plot
fig.show()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment