모델의 각 변수(feature)가 결과의 예측에 어떻게 기여했는지를 정량적으로 측정한다.
본 게시글에서 소개하고자 하는 shapviz는 모든 모형에 적용가능한 형태는 아니며,
R의 xgboost, lightgbm, h2o (tree-based regression or binary classification model)에만 적용 가능하다.
0. 패키지
패키지는 R Cran을 통한 설치도 가능하며 (다른 패키지처럼 설치 ok)
가장 최신 버전을 받으려면 devtools 라이브러리를 사용해서 github을 통해 직접 받을 수도 있다.
# From CRAN
install.packages("shapviz")
# Or the newest version from GitHub:
# library(devtools);
devtools::install_github("ModelOriented/shapviz")
1. 주요 함수
shapviz 패키지에서 주요하게 사용할 수 있는 함수는 아래와 같다.
간략하게 설명 후, 상세한 설명은 2와 3에서 다시 진행한다.
- shapviz(): 모형(fit), 예측에 사용할 X 데이터셋(X_pred), 시각화할 데이터셋 (X)를 사용하여 shap value 계산.
- sv_importance(): 적합된 모형에서 각 변수가 얼마나 중요한지를 나타내줌.
- kind='bar' (the default): 각각의 변수의 shap value 절대값의 합을 나타냄. 이는 변수 중요도를 나타냄. (일반적으로 변수 중요도 구할 때 사용되는 permutation feature importance 해당 값과 달리, 각 변수의 예측에 기여한 절대적인 크기로 해석)
- kind='beeswarm': 개인이 갖는 각 변수별 shap value 값을 시각화함. 각 변수의 shapley value와 예측에 어떤 영향을 주었는지를 시각적으로 나타냄.
- sv_dependence()
- 특정 변수의 shaply value와 y값 사이의 관계를 시각적으로 나타낼 수 있음.
- sv_force() & sv_waterfall()
- 특정 개인의 예측에 대하여, 각 변수의 작용을 확인 가능함
2. shap value 계산
아래 예제는 Reference 페이지의 예제를 그대로 사용하여, 설명한다.
먼저, 분석에 필요한 패키지를 불러오고자 한다.
library(shapviz)
library(ggplot2)
library(xgboost)
library(patchwork)
분석에 사용하고자 하는 데이터는 기본 데이터 diamonds이며, 다이아몬드의 여러 속성에 따른 가격을 예측하는 문제의 예시이다. 변수들 중 "price" 예측에 "caret", "cut", "color", "clarity"를 사용하여, xgb.DMatrix를 사용한다.
분석에 사용하고자 하는 변수들은 모두 연속형이거나 순서를 갖는 factor로 구성되어 있다.
- cut: Fair < Good < Very Good < Premium < Ideal
- color: D < E < F < G < H < I < J
- clarity: I1 < SI2 < SI1 < VS2 < VS1 < VVS2 < VVS1 < IF
head(diamonds)
# # A tibble: 6 x 10
# carat cut color clarity depth table price x y z
# <dbl> <ord> <ord> <ord> <dbl> <dbl> <int> <dbl> <dbl> <dbl>
# 1 0.23 Ideal E SI2 61.5 55 326 3.95 3.98 2.43
# 2 0.21 Premium E SI1 59.8 61 326 3.89 3.84 2.31
# 3 0.23 Good E VS1 56.9 65 327 4.05 4.07 2.31
# 4 0.29 Premium I VS2 62.4 58 334 4.2 4.23 2.63
# 5 0.31 Good J SI2 63.3 58 335 4.34 4.35 2.75
# 6 0.24 Very Good J VVS2 62.8 57 336 3.94 3.96 2.48
x <- c("carat", "cut", "color", "clarity")
dtrain <- xgb.DMatrix(data.matrix(diamonds[x]), label = diamonds$price, nthread = 1)
fit <- xgb.train(
params = list(learning_rate = 0.1, nthread = 1), data = dtrain, nrounds = 65
)
shapviz 패키지를 통한 아래의 코드를 입력하면, 간단하게 shap value를 계산할 수 있다.
인풋 값으로 사전에 학습에 사용한 데이터와 모형이 필요하다.
이 때 X_pred에 사용할 데이터셋과 X에 사용할 데이터셋의 변수 이름은 같아야 한다.
# SHAP analysis: X can even contain factors
dia_2000 <- diamonds[sample(nrow(diamonds), 2000), x]
shp <- shapviz(fit, X_pred = data.matrix(dia_2000), X = dia_2000)
shp는 2000 (샘플 사이즈) by 4 (변수의 수)의 shapley value를 갖는 Matrix 형태이다.
shp를 출력했을 때 shp 그 자체로는 가장 위의 줄만 보여주며 개별 값은 'S'에 저장된다.
shap value 전체 출력하고자 할 때 다음과 같이 출력하면 된다.
3. 결과 해석
결과 해석은 연구 목적과 질문에 따라 크게 5가지 그래프를 살펴볼 수 있다.
예측 모형을 설명하려고 할 때 아래 세가지의 질문들이 궁금할 것이다.
- 예측에 가장 크게 기여한 변수는 무엇일까?
- 주요 변수들이 구체적으로 어떻게 y의 예측에 기여했을까?
- 개인에게 변수들이 구체적으로 어떤 기여를 했을까?
연구 목적에 따라 하나의 질문에만 답하는 것이 아닌, 세 가지 질문에 해당하는 그래프를 신중히 해석해야 한다.
1) 예측에 가장 크게 기여한 변수는 무엇일까?
예측에는 caret > clarity > color > cut 변수 순서대로 크게 기여했음을 확인할 수 있다.
해당 그래프는 xai 라는 것이 떠오르기 전에 예측 모형을 설명할 때 일반적으로 많이 보고했던 permutation 변수 중요도를 대체하여 사용할 수 있다.
sv_importance(shp, kind="bar")
2) 주요 변수들이 y의 예측에 어떻게 기여했을까?
x축은 shap value, y축은 각 변수를 의미한다. 각 변수에 대한 그림을 기반으로 아래와 같이 해석 가능하다.
- 적합된 모형은 carat 값이 작을 때, 다이아몬드 가격을 낮게 예측하였다.
- 적합된 모형은 clarity의 등급이 높을 때 (IF일수록), 다이아몬드의 가격을 높게 예측하였다.
sv_importance(shp, kind="beeswarm")
dependence plot은 각 특성을 x축에 y에 shap value를 표시한다.
하나의 변수에 한정하는 것이 아닌 Interaction의 기여도를 확인할 수 있다.
sv_dependence(shp, v = x)
3) 개인에게 변수들이 어떻게 기여했을까?
이때, 사람들에게 설명하기 좋은 "전형적인 경우", 혹은 "예외적인 경우"를 선택하여 개인의 예측에 있었던 변수들의 기여도를 설명할 수 있다. 이때 보고자 하는 개인의 row id를 옵션을 사용하여 알려주어야 한다.
평균적으로 3929의 가격을 갖는데, row id = 1인 경우에 4892의 가격을 갖고 있다.
이때, 3929 -> 4892가 된 각 변수의 기여도를 그림을 통해 나타내는 두 가지 그래프이다.
sv_force(shp, row_id = 1)
sv_waterfall(shp, row_id = 1) # 각 변수별 분해
Reference
- https://cran.r-project.org/web/packages/shapviz/vignettes/basic_use.html
'R Programming > Analysis' 카테고리의 다른 글
[R] Auto Correlation 데이터 생성과 Durbin-Watson 검정 (1) | 2023.12.28 |
---|---|
[R] Cox 분석을 위한 생존 시간 데이터 생성 (시뮬레이션 코드) (2) | 2023.12.22 |
[R] 부트스트랩 신뢰 구간 (Bootstrap Confidence Intervals) 계산 (2) | 2023.11.27 |
[R] 랜덤포레스트 (randomForest)에 대한 모든 것 (1) | 2023.11.16 |
[R] 몬테카를로 실험 기반의 검정 (monte carlo test) (2) | 2023.10.11 |