머신 러닝 (11) Mean Shift Clustering

 

Wine 데이터셋 Mean Shift 클러스터링

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import load_wine
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA
from sklearn.cluster import MeanShift, estimate_bandwidth

# 와인 데이터셋 로드
wine = load_wine()
X = pd.DataFrame(wine.data, columns=wine.feature_names)

# 데이터 정규화
scaler = StandardScaler()
X_scaled = scaler.fit_transform(X)

print(f"데이터 크기: {X.shape}")
print(f"\n데이터 샘플:\n{X.head()}")

차원 축소 (PCA)


# PCA 적용
pca = PCA(n_components=2)
X_pca = pca.fit_transform(X_scaled)

print(f"설명된 분산 비율: {pca.explained_variance_ratio_}")
print(f"총 설명된 분산: {sum(pca.explained_variance_ratio_):.2%}")

대역폭 추정 및 모델 학습


# 대역폭 자동 추정
bandwidth = estimate_bandwidth(X_pca, quantile=0.2, n_samples=len(X_pca))
print(f"추정된 Bandwidth: {bandwidth:.4f}")

Mean Shift 학습

# Mean Shift 모델 학습
ms = MeanShift(bandwidth=bandwidth, bin_seeding=True)
ms.fit(X_pca)

# 결과 추출
labels = ms.labels_
cluster_centers = ms.cluster_centers_
n_clusters_ = len(np.unique(labels))

print(f"발견된 군집 개수: {n_clusters_}")
print(f"\n군집 중심점:\n{cluster_centers}")

군집별 샘플 개수


# 각 군집별 샘플 개수
for k in range(n_clusters_):
    count = np.sum(labels == k)
    print(f"클러스터 {k}: {count}개")

결과 시각화

  • 원형 마커: 군집별 데이터 포인트
  • 붉은색 P: 밀도 중심점(피크)

plt.figure(figsize=(6, 4))

# 군집별 색상 생성
colors = plt.cm.viridis(np.linspace(0, 1, n_clusters_))

for k, col in zip(range(n_clusters_), colors):
    my_members = (labels == k)
    cluster_center = cluster_centers[k]

    # 군집 데이터 포인트
    plt.plot(X_pca[my_members, 0], X_pca[my_members, 1], 'o',
             markerfacecolor=col, markeredgecolor='k', markersize=8, alpha=0.6)

    # 군집 중심점
    plt.plot(cluster_center[0], cluster_center[1], 'P',
             markerfacecolor='red', markeredgecolor='k', markersize=14)

plt.title(f'Mean Shift Clustering Results (Wine Dataset + PCA)\nDetected Clusters: {n_clusters_}')
plt.xlabel('Principal Component 1')
plt.ylabel('Principal Component 2')
plt.grid(True)
plt.show()

원본 레이블 비교 (선택사항)


fig, axes = plt.subplots(1, 2, figsize=(8, 3))

# Mean Shift 결과
colors_ms = plt.cm.viridis(np.linspace(0, 1, n_clusters_))
for k, col in zip(range(n_clusters_), colors_ms):
    my_members = (labels == k)
    axes[0].scatter(X_pca[my_members, 0], X_pca[my_members, 1],
                   c=[col], label=f'Cluster {k}', s=60, edgecolors='k', alpha=0.6)
axes[0].set_title(f'Mean Shift ({n_clusters_} clusters)')
axes[0].set_xlabel('PC1')
axes[0].set_ylabel('PC2')
axes[0].legend()
axes[0].grid(True)

# 원본 레이블
colors_true = plt.cm.Set1(np.linspace(0, 1, len(np.unique(wine.target))))
for k, col in zip(np.unique(wine.target), colors_true):
    my_members = (wine.target == k)
    axes[1].scatter(X_pca[my_members, 0], X_pca[my_members, 1],
                   c=[col], label=f'Class {k}', s=60, edgecolors='k', alpha=0.6)
axes[1].set_title(f'True Labels ({len(np.unique(wine.target))} classes)')
axes[1].set_xlabel('PC1')
axes[1].set_ylabel('PC2')
axes[1].legend()
axes[1].grid(True)

plt.tight_layout()
plt.show()

댓글

이 블로그의 인기 게시물

베이스 캠프에서 (1)

베이스 캠프에서 (2)

Database 분석 (4)