TABNET (Attentive Interpretable Tabular Learning)
TABNET (Attentive Interpretable Tabular Learning)
- 정형데이터에서 주로 XGBoost, CatBoost, LightGBM의 높은 성능을 보여주고 있음
- 딥러닝 모델은 위에서 언급한 부스팅 계열의 모델보다 성능이 낮은 경우가 존재
앙상블 모델이 딥러닝 모델보다 선호되는 이유?!
- 정형데이터는 Hyperplane경계를 가지는 Manifold라고 하는데 부스팅 모델은 이러한 Manifold에서 결정할때 더 효율적으로 작동
- Tree기반의 모델은 학습이 빠르고 쉽게 개발 가능
- Tree기반 모델은 높은 해석력을 가지고 있는 장점이 있고, 특성중요도도 구할수 있음
딥러닝 모델을 써야하는 이유
- 성능을 더 높일 수 있음
- 정형데이터와 비정형데이터를 함께 학습에 사용할 수 있음(multi-modal)
- streaming 데이터에 대한 학습이 용이
Tabnet 장점
- TabNet은 Feature의 전처리없이 raw한 데이터를 입력으로 사용할 수 있고, Gradient-descent 기반 최적화를 사용하여 End-to-End learning을 가능하게 하였음.
- 성능과 해석력을 향상시키기 위하여, TabNet은 Sequential attention mechanism을 사용하여 각 의사결정에서 어떤 feature를 사용할지를 선택함. 이러한 Feature selection은 instance-wise하게 입력 각각마다 다르게 수행됨.
- 여러 데이터셋에서 기존의 정형 분류,회귀 모델들보다 성능의 우수성을 가짐. 그리고 해석력의 관점에서 입력 Feature의 중요도와 Feature들이 어떻게 결합되었는지를 시각화한 local한 해석력과, 학습된 모델에서 각 입력 Feature들이 얼마나 자주 결합되었는지의 Global한 해석력을 제시함.
- feature selection : 순차적으로 feature를 선택해가면서 결과를 종합함
- decision making : feature engineering → tabnet encoder를 통해 진행
- unsupervised → 빈 null값에 대한 데이터 생성 → masked learning
- supervised → feature extraction 효과
코드
!pip install torch torchvision pytorch-tabnet
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from pytorch_tabnet.tab_model import TabNetClassifier
from pytorch_tabnet.pretraining import TabNetPretrainer
# 데이터 불러오기
data = pd.read_csv("your_data.csv")
# 특징과 레이블 분리
X = data.drop(columns=["target_column"])
y = data["target_column"]
# 훈련 세트와 테스트 세트로 데이터 분할
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
# 탭넷 사전 훈련을 위한 파라미터 설정
pretrain_params = {"verbose": 1, "max_epochs": 100}
# 사전 훈련 모델 초기화
pretrainer = TabNetPretrainer(**pretrain_params)
# 사전 훈련 실행
pretrainer.fit(X_train.values)
# 탭넷 분류 모델 초기화
clf = TabNetClassifier()
# 탭넷 분류 모델 학습
clf.fit(X_train.values, y_train.values, eval_set=[(X_test.values, y_test.values)], patience=10, max_epochs=100)
# 모델 평가
evals = clf.evaluate(X_test.values, y_test.values)
print("Test accuracy:", evals["accuracy"])
참고자료
https://arxiv.org/abs/1908.07442
https://ffighting.net/deep-learning-paper-review/tabular-model/tabnet/