XGBoost & SHAP Value
XGBoost & SHAP Value
Background
对于 SHAP Value,本来是想用 MLP 来算,但没想到还是树模型表现最好,于是又用回了 XGBoost。(相比较来说,MLP 的拟合效果会更好一点,至少方便一点。)
Calculate SHAP Value
python
# Build a tree-model explainer that uses the interventional perturbation
# scheme, with background_data as its reference dataset.
explainer = shap.TreeExplainer(
    model,
    data=background_data,
    feature_perturbation="interventional",
)

# Announce the computation, then explain the same rows used as background.
print(f"[info] 正在计算:{model_name} shap_values")
shap_values = explainer.shap_values(background_data)
计算shap值并保存的全部代码
python
import os
import shap
import joblib
from pathlib import Path
from datetime import datetime
def save_shap_explainer(model, background_data, save_path, model_name):
    """
    Compute SHAP values and SHAP interaction values for a model and save them.

    The results are dumped with joblib to
    ``<save_path>/shap_explainers/<YYYYMMDD>/<model_name>_explainer.pkl`` as the
    tuple ``(explainer, shap_values, explainer_tree, shap_interaction_values)``.
    If that file already exists, nothing is computed and the function returns.

    Args:
        model: Trained model handed to ``shap.TreeExplainer``.
        background_data: Data used as the background set of the interventional
            explainer and as the rows that are explained.
        save_path: Root directory under which the dated folder is created.
        model_name: Identifier used in the output file name and log messages.

    Returns:
        None.
    """
    save_dir = Path(save_path) / f"shap_explainers/{datetime.now().strftime('%Y%m%d')}"
    save_dir.mkdir(parents=True, exist_ok=True)
    explainer_path = save_dir / f"{model_name}_explainer.pkl"
    print(f"[info] 正在计算:{explainer_path}")

    # Never overwrite a result that was already saved for this date and model.
    if explainer_path.exists():
        print(f"[警告] 已存在文件:{explainer_path} 跳过")
        return

    # Interventional explainer (needs background data) for plain SHAP values.
    explainer = shap.TreeExplainer(
        model=model,
        data=background_data,
        feature_perturbation="interventional",
    )
    print(f"[info] 正在计算:{model_name} shap_values")
    shap_values = explainer.shap_values(background_data)

    # Interaction values must come from a path-dependent explainer built
    # without background data, so a second explainer is created for them.
    explainer_tree = shap.TreeExplainer(
        model=model,
        feature_perturbation="tree_path_dependent",
    )
    print(f"[info] 正在计算:{model_name} shap_interaction_values")
    # Fix: use explainer_tree here; the interventional explainer was used before.
    shap_interaction_values = explainer_tree.shap_interaction_values(background_data)

    # Persist both explainers together with their outputs in a single file.
    joblib.dump(
        (explainer, shap_values, explainer_tree, shap_interaction_values),
        explainer_path,
    )
    print(f"[成功] SHAP解释器保存至:{explainer_path}")
# Main driver: draw one shared background sample from the training data, then
# compute and save SHAP results for every model that was actually loaded.
background_data = shap.sample(X_train, nsamples=5000, random_state=42)
for target_name, model in loaded_models.items():
    # Entries whose model is None are skipped.
    if model is None:
        continue
    save_shap_explainer(
        model=model,
        background_data=background_data,
        save_path="../saved_models",
        model_name=target_name,
    )
Calculate SHAP interaction Value
我感觉最重要的是交互值!交互值没问题(而且这样可以把非常多的信息压缩在里面)。
Attention!
虽然哪里都没有强调(也许是因为我不是英语母语者,对措辞没有那么敏感),但是如果想计算交互作用,必须像下面这样创建解释器:
python
# Explainer configuration for interaction values: no background data, and the
# tree_path_dependent perturbation scheme.
explainer = shap.TreeExplainer(
    model,
    data=None,  # rather than background_data
    feature_perturbation="tree_path_dependent",  # rather than interventional
)
python
# Sum absolute interaction values over axis 0 (presumably the sample axis)
# to get one feature-by-feature strength matrix.
interaction_matrix = np.abs(shap_interaction_values).sum(0)
# Zero the diagonal so only cross-feature terms are ranked and drawn.
np.fill_diagonal(interaction_matrix, 0)

# Keep the 12 features with the largest column totals, strongest first, and
# reorder rows and columns of the matrix accordingly.
inds = np.argsort(-interaction_matrix.sum(0))[:12]
sorted_ia_matrix = interaction_matrix[np.ix_(inds, inds)]

# Both axes share the same tick positions and feature-name labels.
tick_positions = range(sorted_ia_matrix.shape[0])
tick_labels = X.columns[inds]

pl.figure(figsize=(12, 12))
pl.imshow(sorted_ia_matrix)
pl.yticks(tick_positions, tick_labels, rotation=50.4, horizontalalignment="right")
pl.xticks(tick_positions, tick_labels, rotation=50.4, horizontalalignment="left")
# Put the x-axis labels on top of the heatmap.
pl.gca().xaxis.tick_top()
pl.show()