Understanding SHAP: Making ML Models Explainable

Machine Learning · Explainable AI · December 2024 · Md. Zehadul Islam

In the rapidly evolving field of machine learning, model interpretability has become crucial, especially in sensitive domains like healthcare. SHAP (SHapley Additive exPlanations) provides a unified approach to explaining the output of any machine learning model.

[Figure: SHAP explainability visualization]

What is SHAP?

SHAP is based on Shapley values from cooperative game theory. It assigns each feature an importance value for a particular prediction. The beauty of SHAP lies in its ability to provide both local explanations (for individual predictions) and global insights (across the entire dataset).

Key Principle: SHAP values represent the average marginal contribution of a feature value across all possible coalitions of features.
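The principle above can be made concrete with a tiny brute-force computation. The sketch below is purely illustrative (the feature names and payoff function are made up, not from any real model): it enumerates every coalition of features and averages each feature's marginal contribution using the standard Shapley weights.

```python
from itertools import combinations
from math import factorial

def shapley_values(value_fn, features):
    """Exact Shapley values by enumerating all coalitions.

    value_fn(coalition) -> payoff for a frozenset of feature names.
    Feasible only for a handful of features (2^n coalitions)."""
    n = len(features)
    phi = {}
    for f in features:
        others = [g for g in features if g != f]
        total = 0.0
        for k in range(n):
            for subset in combinations(others, k):
                S = frozenset(subset)
                # Shapley weight: |S|! * (n - |S| - 1)! / n!
                w = factorial(len(S)) * factorial(n - len(S) - 1) / factorial(n)
                total += w * (value_fn(S | {f}) - value_fn(S))
        phi[f] = total
    return phi

# Toy "model": payoff is a sum of per-feature contributions,
# plus a synergy bonus when dose and weight appear together
contrib = {"age": 2.0, "weight": 1.0, "dose": 3.0}

def payoff(coalition):
    base = sum(contrib[name] for name in coalition)
    if {"dose", "weight"} <= coalition:
        base += 1.0
    return base

phi = shapley_values(payoff, list(contrib))
print(phi)  # the synergy bonus is split evenly between weight and dose
```

The cost is exponential in the number of features, which is exactly why SHAP ships specialized explainers (TreeExplainer, KernelExplainer, and friends) that compute or approximate these values efficiently.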

Why SHAP Matters in Healthcare

In clinical settings, a prediction without an explanation is difficult to act on: clinicians need to know which patient characteristics drive a risk score before they can trust it, and interpretability is increasingly expected of models that inform patient care.

Implementing SHAP in Python

Here's a comprehensive example of using SHAP with a Random Forest model for healthcare prediction:

import shap
import numpy as np
import pandas as pd
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split

# Load your healthcare dataset
# X = features (patient data), y = target (adverse drug reaction)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# Train your model
model = RandomForestClassifier(n_estimators=100, random_state=42)
model.fit(X_train, y_train)

# Create SHAP explainer for tree-based models
explainer = shap.TreeExplainer(model)
# Note: for binary classifiers, older SHAP versions return a list of two
# arrays (one per class); shap_values[1] below selects the positive class.
# Newer versions may return a single 3-D array instead.
shap_values = explainer.shap_values(X_test)

# Visualizations
# 1. Summary plot - shows feature importance across all predictions
shap.summary_plot(shap_values[1], X_test, plot_type="bar")

# 2. Detailed summary plot with feature values
shap.summary_plot(shap_values[1], X_test)

# 3. Force plot for individual prediction
# (in a notebook, call shap.initjs() first, or pass matplotlib=True)
shap.force_plot(explainer.expected_value[1], shap_values[1][0], X_test.iloc[0])

# 4. Dependence plot - shows interaction effects
shap.dependence_plot("age", shap_values[1], X_test)

Key SHAP Visualizations Explained

1. Summary Plot (Bar)

Shows the mean absolute SHAP value for each feature, indicating global feature importance. Features are ranked from most to least important.
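Under the hood this ranking is just a column-wise mean of absolute SHAP values. A minimal sketch with a made-up SHAP matrix (the feature names and numbers are illustrative, not from a real model):

```python
import numpy as np

# Hypothetical SHAP values: rows = predictions, columns = features
shap_matrix = np.array([
    [ 0.30, -0.10, 0.05],
    [-0.20,  0.40, 0.00],
    [ 0.10, -0.30, 0.15],
])
feature_names = ["gestational_age", "birth_weight", "drug_count"]

# The bar summary plot ranks features by mean |SHAP value|
importance = np.abs(shap_matrix).mean(axis=0)
ranking = sorted(zip(feature_names, importance), key=lambda t: -t[1])
for name, imp in ranking:
    print(f"{name}: {imp:.3f}")
```

Note that the sign is discarded: a feature that strongly pushes some predictions up and others down still ranks as important.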

2. Summary Plot (Beeswarm)

A more detailed view in which each point is one prediction for one feature:

- Features are ranked top to bottom by overall importance
- Horizontal position shows the SHAP value (impact on that prediction)
- Color encodes the feature's value, from low to high

3. Force Plot

Visualizes how features contribute to pushing a prediction from the base value (average model output) to the actual prediction. Red arrows push the prediction higher, blue arrows push it lower.
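The force plot works because SHAP values satisfy an additivity (local accuracy) property: the base value plus a row's SHAP values reconstructs the model's output for that row. A toy check with hypothetical numbers:

```python
import numpy as np

base_value = 0.12                          # hypothetical average model output
row_shap = np.array([0.25, -0.05, 0.08])   # hypothetical per-feature SHAP values

# Additivity: prediction = base value + sum of the row's SHAP values
prediction = base_value + row_shap.sum()

# Positive values (red arrows) push the prediction above the base value,
# negative values (blue arrows) push it below
print(f"{prediction:.2f}")
```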

4. Dependence Plot

Shows the relationship between a feature's value and its SHAP value, revealing:

- Whether the feature's effect is linear, monotonic, or more complex
- Interaction effects, via the color of a second, automatically selected feature

My Research Application: Neonatal ADR Prediction

In my ongoing research on predicting adverse drug reactions in neonates, SHAP has been instrumental in several ways:

Identifying Critical Risk Factors

SHAP revealed that patient characteristics like gestational age, birth weight, and concurrent medications most strongly influence adverse event predictions. This aligns with clinical knowledge, validating our model.

Detecting Data Biases

We discovered through SHAP that the model was over-relying on certain reporting patterns in the FAERS dataset rather than true clinical signals. This led us to implement better preprocessing strategies.

Clinical Validation

By presenting SHAP explanations to neonatologists, we validated that the model focuses on clinically relevant features rather than spurious patterns in the data.

Implementation Example from My Research

# Analyze SHAP values for neonatal ADR prediction
import shap
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# After training your model on FAERS neonatal data
explainer = shap.TreeExplainer(adr_model)
shap_values = explainer.shap_values(X_neonatal_test)

# Get top 10 most important features
shap_importance = pd.DataFrame({
    'feature': X_neonatal_test.columns,
    'importance': np.abs(shap_values[1]).mean(axis=0)
}).sort_values('importance', ascending=False).head(10)

print("Top 10 Risk Factors for Neonatal ADR:")
print(shap_importance)

# Analyze specific high-risk case
high_risk_idx = np.argmax(adr_model.predict_proba(X_neonatal_test)[:, 1])
print("\nExplanation for highest risk case:")
shap.force_plot(
    explainer.expected_value[1], 
    shap_values[1][high_risk_idx], 
    X_neonatal_test.iloc[high_risk_idx],
    matplotlib=True
)
plt.tight_layout()
plt.savefig('high_risk_explanation.png', dpi=300, bbox_inches='tight')

Advanced SHAP Techniques

Interaction Values

SHAP can also compute interaction effects between features:

# Compute SHAP interaction values
shap_interaction_values = explainer.shap_interaction_values(X_test)

# Visualize interaction between two features
shap.dependence_plot(
    ("drug_dose", "patient_weight"),
    shap_interaction_values[1],
    X_test
)

Waterfall Plots

For detailed individual explanations:

# Waterfall plot for single prediction
shap.plots.waterfall(shap.Explanation(
    values=shap_values[1][0],
    base_values=explainer.expected_value[1],
    data=X_test.iloc[0],
    feature_names=X_test.columns.tolist()
))

SHAP for Different Model Types

Tree-Based Models (Fast)

explainer = shap.TreeExplainer(model)  # RF, XGBoost, LightGBM

Linear Models

explainer = shap.LinearExplainer(model, X_train)

Deep Learning Models

explainer = shap.DeepExplainer(model, X_train[:100])  # Use sample of training data

Any Model (Slower but Universal)

explainer = shap.KernelExplainer(model.predict, X_train[:50])
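To see what a model-agnostic explainer does internally, here is a minimal Monte Carlo sketch of Shapley value estimation. It is a simplification of the sampling ideas behind KernelExplainer, not SHAP's actual implementation; the toy linear model is chosen because the exact answer is known (for a linear model, feature j's SHAP value is w_j * (x_j - E[x_j])).

```python
import numpy as np

rng = np.random.default_rng(42)

def sampled_shap(predict, x, background, n_samples=5000):
    """Monte Carlo approximation of SHAP values for one instance.

    For each sampled feature permutation, "present" features take their
    value from x and "absent" features from a random background row --
    the same masking idea model-agnostic explainers build on."""
    n_features = x.shape[0]
    phi = np.zeros(n_features)
    for _ in range(n_samples):
        perm = rng.permutation(n_features)
        z = background[rng.integers(len(background))].copy()
        prev = predict(z)
        for j in perm:
            z[j] = x[j]           # add feature j to the coalition
            cur = predict(z)
            phi[j] += cur - prev  # its marginal contribution
            prev = cur
    return phi / n_samples

# Toy linear "model" with known exact SHAP values
w = np.array([2.0, -1.0, 0.5])
predict = lambda z: float(w @ z)

background = rng.normal(size=(1000, 3))
x = np.array([1.0, 2.0, -1.0])

phi = sampled_shap(predict, x, background)
exact = w * (x - background.mean(axis=0))
print(phi, exact)  # the estimates converge toward the exact values
```

The estimate improves with more samples, which is why KernelExplainer is slow on large datasets and why a small, representative background sample matters.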

Best Practices

- Prefer TreeExplainer for tree ensembles; it is fast and exact for those models
- Use a representative background sample when an explainer requires one
- Check explanations against domain knowledge before trusting the model

Limitations and Considerations

- KernelExplainer is computationally expensive and only approximate
- Strongly correlated features can split credit in unintuitive ways
- SHAP values describe the model's behavior, not causal effects in the real world

Conclusion

SHAP transforms black-box machine learning models into transparent, trustworthy systems. In healthcare applications like my neonatal ADR prediction research, SHAP provides:

- Transparent attribution of predicted risk to individual patient characteristics
- A way to surface data biases, such as over-reliance on reporting patterns
- Explanations that clinicians can inspect and validate

By combining SHAP with domain expertise, we can build AI systems that are not only accurate but also explainable, fair, and safe for deployment in critical healthcare settings.

Next Steps: Try implementing SHAP in your own ML projects. Start with TreeExplainer on a Random Forest model, then explore different visualizations to understand your model's behavior. The insights you gain will be invaluable for model improvement and stakeholder communication.
