머신 러닝 (16) Decision Tree
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import koreanize_matplotlib
import seaborn as sns
from sklearn.datasets import load_diabetes
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier, plot_tree
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
import warnings
warnings.filterwarnings('ignore')
pima_columns = ['pregnancies', 'glucose', 'blood_pressure', 'skin_thickness', 'insulin', 'bmi', 'diabetes_pedigree_function', 'age', 'outcome']
pima_data_url = 'https://raw.githubusercontent.com/jbrownlee/Datasets/master/pima-indians-diabetes.data.csv'
df = pd.read_csv(pima_data_url, names=pima_columns)
# Feature(X)와 Target(y) 분리
X = df.drop('outcome', axis=1)
y = df['outcome']
# 학습용/테스트용 데이터 분할 (8:2)
# stratify=y: 타겟 클래스 비율 유지 (강의자료 강조 사항)
X_train, X_test, y_train, y_test = train_test_split(
X, y, test_size=0.2, random_state=42, stratify=y
)
print(f"학습 데이터: {X_train.shape}")
print(f"테스트 데이터: {X_test.shape}")
# max_depth=4: 트리의 최대 깊이 제한 (과적합 방지)
# min_samples_leaf=10: 리프 노드가 되기 위한 최소 샘플 수
dt_clf = DecisionTreeClassifier(
criterion='gini', # 분할 기준: 지니 지수 (Page 4)
max_depth=4, # 사전 가지치기: 최대 깊이
min_samples_leaf=10, # 사전 가지치기: 리프 노드 최소 샘플
random_state=42
)
# 모델 학습
dt_clf.fit(X_train, y_train)
plt.figure(figsize=(20, 10))
plot_tree(
dt_clf,
feature_names=X.columns,
class_names=['No Diabetes (0)', 'Diabetes (1)'],
filled=True, # 노드 색칠 (불순도에 따라 색상 진하기 변경)
rounded=True, # 노드 모서리 둥글게
fontsize=10
)
plt.show()
# 테스트 데이터 예측
y_pred = dt_clf.predict(X_test)
# 정확도(Accuracy) 출력
accuracy = accuracy_score(y_test, y_pred)
print(f"Accuracy (정확도): {accuracy:.4f}")
# 상세 리포트 출력 (Precision, Recall, F1-score)
print("\n[Classification Report]")
print(classification_report(y_test, y_pred, target_names=['No Diabetes (0)', 'Diabetes (1)']))
# Feature Importance 확인
for name, value in zip(X.columns, dt_clf.feature_importances_):
print(f"{name}: {value:.4f}")
# 특성 중요도 시각화
plt.figure(figsize=(6, 4))
sns.barplot(x=dt_clf.feature_importances_, y=X.columns, palette='pastel')
plt.title('Feature Importances (Decision Tree)')
plt.show()
댓글
댓글 쓰기