import neurenix as nx
from neurenix.explainable import (
KernelShap,
LimeTabular,
PermutationImportance,
PartialDependence
)
class ExplainablePipeline:
    """Bundle several neurenix explainers around a single model.

    The explainers are built once in ``__init__`` and reused by the
    per-sample (``explain_prediction``), global (``global_explanation``)
    and side-by-side (``compare_samples``) views.

    Args:
        model: Callable model handed to every explainer. ``model(sample)``
            must return an object supporting ``.item()``.
        X_train: Training data. It is passed to ``KernelShap`` (presumably
            as background data -- confirm against the neurenix docs) and is
            the data over which partial dependence is computed.
        feature_names: Sequence of feature names, indexed by feature
            position when reporting results.
    """

    def __init__(self, model, X_train, feature_names):
        self.model = model
        self.X_train = X_train
        self.feature_names = feature_names

        # Build the explainers once so repeated explanations reuse them.
        self.shap_explainer = KernelShap(model, X_train)
        self.lime_explainer = LimeTabular(model, feature_names=feature_names)
        self.perm_importance = PermutationImportance(model)
        self.pd_calculator = PartialDependence(model)

    @staticmethod
    def _top_k_indices(scores, k):
        """Return the positions of the ``k`` largest entries of ``scores``.

        ``scores`` must support ``argsort()`` (NumPy array / tensor).
        Indices are returned as plain ``int`` values, largest score first.
        Returns an empty list when ``k <= 0``.
        """
        if k <= 0:
            # ``order[-0:]`` would select *every* index, not none.
            return []
        order = scores.argsort()
        # Convert to Python ints so they can index a plain list of names.
        # Reverse in Python: negative-step slicing (``[::-1]``) is not
        # supported by every tensor library.
        top = [int(i) for i in order[-k:]]
        top.reverse()
        return top

    def explain_prediction(self, sample):
        """Comprehensive explanation for a single prediction.

        Args:
            sample: One input sample. It is passed to the model and to LIME
                as-is, and to KernelShap as ``sample.unsqueeze(0)``.

        Returns:
            dict with keys ``'prediction'`` (model output converted with
            ``.item()``), ``'shap'`` (KernelShap result) and ``'lime'``
            (LimeTabular result).
        """
        results = {}

        # NOTE(review): the model receives the unbatched sample while
        # KernelShap receives ``sample.unsqueeze(0)``; confirm the model
        # accepts an input without a leading batch dimension.
        prediction = self.model(sample)
        results['prediction'] = prediction.item()

        # KernelShap is given a batch containing just this sample.
        results['shap'] = self.shap_explainer.explain(sample.unsqueeze(0))

        results['lime'] = self.lime_explainer.explain(sample)
        return results

    def global_explanation(self, X_test, y_test, n_top_features=5):
        """Global model explanation.

        Computes permutation importance on ``(X_test, y_test)``. It then
        computes partial dependence, over ``self.X_train``, for the
        ``n_top_features`` features with the highest mean importance.

        Args:
            X_test: Evaluation inputs for permutation importance.
            y_test: Evaluation targets for permutation importance.
            n_top_features: How many of the most important features get a
                partial-dependence computation (default 5). Values ``<= 0``
                skip partial dependence.

        Returns:
            dict with ``'importance'`` and ``'partial_dependence'``.
            ``'importance'`` is the PermutationImportance result; it must
            have a ``'mean'`` entry that supports ``argsort``.
            ``'partial_dependence'`` maps each feature name to its
            PartialDependence result, most important feature first.
        """
        results = {}

        importance = self.perm_importance.compute(X_test, y_test)
        results['importance'] = importance

        top_features = self._top_k_indices(importance['mean'], n_top_features)
        # Dict insertion order keeps the most important feature first.
        results['partial_dependence'] = {
            self.feature_names[idx]: self.pd_calculator.compute(
                self.X_train,
                feature_idx=idx,
            )
            for idx in top_features
        }
        return results

    def compare_samples(self, sample1, sample2, top_k=3):
        """Print the prediction and top SHAP features for two samples.

        Both samples are explained first, then the two summaries are printed.

        Args:
            sample1: First sample, passed to ``explain_prediction``.
            sample2: Second sample, passed to ``explain_prediction``.
            top_k: Number of features to list per sample, ranked by absolute
                SHAP value (default 3).
        """
        exp1 = self.explain_prediction(sample1)
        exp2 = self.explain_prediction(sample2)
        self._print_summary("Sample 1:", exp1, top_k)
        self._print_summary("\nSample 2:", exp2, top_k)

    def _print_summary(self, header, explanation, top_k):
        """Print one sample's prediction and its top-``top_k`` SHAP features."""
        print(header)
        print(f" Prediction: {explanation['prediction']:.4f}")
        print(f" Top {top_k} SHAP features:")
        # Row 0: explain_prediction gives KernelShap a batch of one sample.
        shap_row = explanation['shap']['values'][0]
        for idx in self._top_k_indices(shap_row.abs(), top_k):
            print(f" {self.feature_names[idx]}: {shap_row[idx]:.4f}")
# Usage
# NOTE(review): train_model, X_train, y_train, feature_names, X_test and
# y_test are not defined anywhere in this file as shown; they must be
# provided before this section runs, otherwise it raises NameError.
# These statements are at module top level (no __main__ guard), so they
# execute whenever the module is imported.
model = train_model(X_train, y_train)
pipeline = ExplainablePipeline(model, X_train, feature_names)
# Explain single prediction
exp = pipeline.explain_prediction(X_test[0])
# Global explanations
global_exp = pipeline.global_explanation(X_test, y_test)
# Compare samples
# (explains X_test[0] again; compare_samples calls explain_prediction itself)
pipeline.compare_samples(X_test[0], X_test[1])