XGBoost & SHAP Value

Background

For the SHAP values I originally planned to use an MLP, but surprisingly the tree-based model still performed best here, so I went back to XGBoost. (By comparison an MLP would fit a bit better, or at least be more convenient.)

Calculate SHAP Value

python
# Create a tree-model explainer
explainer = shap.TreeExplainer(
    model=model,
    data=background_data,
    feature_perturbation="interventional"
)
print(f"[info] Computing {model_name} shap_values")
shap_values = explainer.shap_values(background_data)
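
A quick sanity check on the result (a minimal sketch, assuming `model` is a single-output XGBoost regressor with a squared-error objective, so its raw margin equals its prediction): SHAP's local-accuracy property says the base value plus each row's SHAP values reconstructs the prediction.

python
import numpy as np

# One row per sample, one column per feature
print(shap_values.shape)

# Local accuracy: expected_value + per-row SHAP sums ≈ raw predictions
reconstructed = explainer.expected_value + shap_values.sum(axis=1)
print(np.allclose(reconstructed, model.predict(background_data), atol=1e-3))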
The full code for computing and saving the SHAP values:
python
import shap
import joblib
from pathlib import Path
from datetime import datetime

def save_shap_explainer(model, background_data, save_path, model_name):
    """
    Build and save a SHAP explainer.

    Args:
        model: the trained model
        background_data: data to explain
        save_path: root directory for saving
        model_name: identifier for the model
    """

    # Create a date-stamped save directory
    save_dir = Path(save_path) / f"shap_explainers/{datetime.now().strftime('%Y%m%d')}"
    save_dir.mkdir(parents=True, exist_ok=True)
    explainer_path = save_dir / f"{model_name}_explainer.pkl"
    print(f"[info] Target file: {explainer_path}")

    # Skip if this explainer was already saved
    if explainer_path.exists():
        print(f"[warn] File already exists, skipping: {explainer_path}")
        return
    else:
        # Tree explainer for the SHAP values themselves
        explainer = shap.TreeExplainer(
            model=model,
            data=background_data,
            feature_perturbation="interventional"
        )
        print(f"[info] Computing {model_name} shap_values")
        shap_values = explainer.shap_values(background_data)
        # Interaction values need tree_path_dependent perturbation (see Attention below)
        explainer_tree = shap.TreeExplainer(
            model=model,
            feature_perturbation="tree_path_dependent"
        )
        print(f"[info] Computing {model_name} shap_interaction_values")
        shap_interaction_values = explainer_tree.shap_interaction_values(background_data)
        # Serialize everything into one pickle
        joblib.dump((explainer, shap_values, explainer_tree, shap_interaction_values), explainer_path)
        print(f"[success] SHAP explainer saved to: {explainer_path}")
        
# Main driver: sample a background set, then process every loaded model
background_data = shap.sample(X_train, nsamples=5000, random_state=42)
for target_name, model in loaded_models.items():
    if model is not None:
        save_shap_explainer(
            model=model,
            background_data=background_data,
            save_path="../saved_models",
            model_name=target_name
        )
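
To reuse the saved artifacts without recomputing, they can be loaded back like this (a sketch; the date folder and model name in the path are placeholders for whatever save_shap_explainer printed):

python
import joblib

# Placeholder path — substitute the path printed by save_shap_explainer
explainer_path = "../saved_models/shap_explainers/20250101/target_explainer.pkl"
# Tuple order matches the joblib.dump call above
explainer, shap_values, explainer_tree, shap_interaction_values = joblib.load(explainer_path)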

Calculate SHAP interaction Value

The real highlight, to me, is the interaction values! Interaction values are great (and they pack a huge amount of information into one matrix).

(Figure: shap_interaction_matrix heatmap)

Attention!

Although this isn't emphasized anywhere (maybe I'm just not sensitive enough as a non-native English reader), if you want to compute interaction values you must construct the explainer like this:

python
explainer = shap.TreeExplainer(
    model=model,
    data=None,  # rather than background_data
    feature_perturbation="tree_path_dependent"  # rather than "interventional"
)

Ref: Basic SHAP Interaction Value Example in XGBoost

shap issues #568
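
One way to check that the interaction values are consistent (a small sketch using the explainer just defined): summing the interaction tensor over its last axis collapses it back to the ordinary SHAP values, with main effects on the diagonal and each pairwise effect split symmetrically off the diagonal.

python
import numpy as np

shap_interaction_values = explainer.shap_interaction_values(background_data)
shap_values = explainer.shap_values(background_data)

# (n_samples, n_features, n_features) summed over the last axis
# collapses back to (n_samples, n_features)
print(np.allclose(shap_interaction_values.sum(axis=-1), shap_values, atol=1e-3))

The heatmap below (adapted from the NHANES survival example in the SHAP docs) then ranks features by total interaction strength: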

python
import numpy as np
import matplotlib.pyplot as pl

# Total absolute interaction strength per feature pair, summed over samples
interaction_matrix = np.abs(shap_interaction_values).sum(0)
# Zero the diagonal (main effects) so only pairwise interactions are ranked
for i in range(interaction_matrix.shape[0]):
    interaction_matrix[i, i] = 0
# Keep the 12 features with the strongest total interactions
inds = np.argsort(-interaction_matrix.sum(0))[:12]
sorted_ia_matrix = interaction_matrix[inds, :][:, inds]
pl.figure(figsize=(12, 12))
pl.imshow(sorted_ia_matrix)
pl.yticks(
    range(sorted_ia_matrix.shape[0]),
    X.columns[inds],  # X is the feature DataFrame
    rotation=50.4,
    horizontalalignment="right",
)
pl.xticks(
    range(sorted_ia_matrix.shape[0]),
    X.columns[inds],
    rotation=50.4,
    horizontalalignment="left",
)
pl.gca().xaxis.tick_top()
pl.show()

Ref: shap docs: tree_based_models (NHANES I Survival Model)