Preview:
from sklearn.datasets import load_iris
from sklearn.cluster import KMeans
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
 
# Load the Iris dataset.
data = load_iris()
X = data.data  # Feature matrix: one row per flower, one column per feature

# Cluster into 3 groups (the Iris dataset contains 3 species).
# n_init is pinned explicitly: scikit-learn changed its default from 10 to
# "auto" in 1.4 (and warns about the change in 1.2/1.3), so leaving it
# implicit makes the result depend on the installed version.  random_state
# fixes the centroid seeding so repeated runs give the same clustering.
kmeans = KMeans(n_clusters=3, n_init=10, random_state=42)

# Fit the model and get each training sample's cluster label in one step;
# calling predict() on the same data afterwards would only repeat the
# assignment that fitting already computed.
y_pred = kmeans.fit_predict(X)

# Build a DataFrame for plotting.  Labels are stored as a categorical column
# so seaborn treats them as nominal groups (distinct qualitative colours)
# instead of a numeric scale drawn with a sequential colour ramp, which would
# wrongly suggest that cluster 2 is "more" than cluster 0.
df = pd.DataFrame(X, columns=data.feature_names)
df['Cluster'] = pd.Categorical(y_pred)

# Plot the clusters in the plane of the first two features only.  The model
# was fit on every column of X, so clusters may appear to overlap here even
# when they are well separated in the full feature space.
x_feature, y_feature = data.feature_names[0], data.feature_names[1]
_, ax = plt.subplots(figsize=(8, 6))
sns.scatterplot(data=df, x=x_feature, y=y_feature, hue='Cluster', ax=ax)
ax.set_title('K-Means Clustering on Iris Dataset (2D)')
ax.set_xlabel(x_feature)
ax.set_ylabel(y_feature)
plt.show()

# Print the cluster centers: one row per cluster, one column per feature,
# in the same order as data.feature_names.
print("Cluster Centers (Centroids):")
print(kmeans.cluster_centers_)
Download PNG | Download JPEG | Download SVG

Tip: You can change the style, width & colours of the snippet with the inspect tool before clicking Download!

Click to optimize width for Twitter