diff --git a/explainableai/core.py b/explainableai/core.py
index 40af264..76532da 100644
--- a/explainableai/core.py
+++ b/explainableai/core.py
@@ -3,6 +3,8 @@ import colorama
 from colorama import Fore, Style
 
+from explainableai.exceptions import ExplainableAIError
+
 # Initialize colorama
 colorama.init(autoreset=True)
 
@@ -27,6 +29,8 @@ from reportlab.platypus import PageBreak
 import logging
 from sklearn.model_selection import cross_val_score
+from .model_interpretability import interpret_model
+from .logging_config import logger
 
 logger=logging.getLogger(__name__)
 
@@ -136,12 +140,12 @@ def _preprocess_data(self):
         except Exception as e:
             logger.error(f"Some error occur while updating...{str(e)}")
 
-    def analyze(self, batch_size=None, parallel=False):
+
+    def analyze(self, batch_size=None, parallel=False, instance_index=0):
         logger.debug("Analysing...")
         results = {}
 
         logger.info("Evaluating model performance...")
-        # Evaluate model performance (batch processing if batch_size is provided)
         if batch_size:
             results['model_performance'] = self._process_in_batches(self._evaluate_model_in_batches, batch_size, parallel)
         else:
@@ -153,18 +157,18 @@ def analyze(self, batch_size=None, parallel=False):
         logger.info("Generating visualizations...")
         self._generate_visualizations(self.feature_importance)
-
-        # Calculate SHAP values (batch processing if batch_size is provided)
+
         logger.info("Calculating SHAP values...")
         if batch_size:
-            results['shap_values'] = self._process_in_batches(self._calculate_shap_in_batches, batch_size, parallel)
+            shap_values = self._process_in_batches(self._calculate_shap_in_batches, batch_size, parallel)
+            results['shap_values'] = shap_values
         else:
             results['shap_values'] = calculate_shap_values(self.model, self.X, self.feature_names)
 
-        # Perform cross-validation (batch processing if batch_size is provided)
         logger.info("Performing cross-validation...")
         if batch_size:
-            results['cv_scores'] = self._process_in_batches(self._cross_validate_in_batches, batch_size, parallel)
+            cv_results = self._process_in_batches(self._cross_validate_in_batches, batch_size, parallel)
+            results['cv_scores'] = (np.mean(cv_results['mean_score']), np.mean(cv_results['std_score']))
         else:
             mean_score, std_score = cross_validate(self.model, self.X, self.y)
             results['cv_scores'] = (mean_score, std_score)
@@ -172,18 +176,22 @@ def analyze(self, batch_size=None, parallel=False):
         logger.info("Model comparison results:")
         results['model_comparison'] = self.model_comparison_results
 
+        logger.info("Performing model interpretation (SHAP and LIME)...")
+        try:
+            interpretation_results = interpret_model(self.model, self.X, self.feature_names, instance_index)
+            results.update(interpretation_results)
+        except ExplainableAIError as e:
+            logger.warning(f"Model interpretation failed: {str(e)}")
+            results['interpretation_error'] = str(e)
+
         self._print_results(results)
 
         logger.info("Generating LLM explanation...")
         results['llm_explanation'] = get_llm_explanation(self.gemini_model, results)
 
-        # Generate XAI report after analysis
-        logger.info("Generating XAI report")
-        self.generate_report()
-
         self.results = results
         return results
-    
+
     def _process_in_batches(self, batch_func, batch_size, parallel=False):
         results = []
         num_batches = (len(self.X) + batch_size - 1) // batch_size # Calculate number of batches
@@ -209,7 +217,7 @@ def _process_in_batches(self, batch_func, batch_size, parallel=False):
         # Aggregate results after batch processing
         return self._aggregate_results(results)
-    
+
     # private helper functions
     def _evaluate_model_in_batches(self, X_batch, y_batch):
         return evaluate_model(self.model, X_batch, y_batch, self.is_classifier)
 
@@ -249,7 +257,6 @@ def _aggregate_results(self, results):
         return aggregated_result
 
-
     def generate_report(self, filename='xai_report.pdf'):
         if self.results is None:
             raise ValueError("No analysis results available. Please run analyze() first.")
@@ -272,9 +279,17 @@ def generate_report(self, filename='xai_report.pdf'):
         for section, section_func in sections.items():
             if input(f"Do you want {section} in xai_report? (y/n) ").lower() in ['y', 'yes']:
                 section_func(report)
+        self._generate_shap_lime_visualizations(report)
 
         report.generate()
 
+    def _generate_shap_lime_visualizations(self, report):
+        report.add_heading("SHAP and LIME Visualizations", level=2)
+        report.add_image('shap_summary.png')
+        report.content.append(PageBreak())
+        report.add_image('lime_explanation.png')
+        report.content.append(PageBreak())
+
     def _generate_model_comparison(self, report):
         report.add_heading("Model Comparison", level=2)
         model_comparison_data = [["Model", "CV Score", "Test Score"]] + [
@@ -286,7 +301,12 @@ def _generate_model_comparison(self, report):
     def _generate_model_performance(self, report):
         report.add_heading("Model Performance", level=2)
         for metric, value in self.results['model_performance'].items():
-            report.add_paragraph(f"**{metric}:** {value:.4f}" if isinstance(value, (int, float, np.float64)) else f"**{metric}:**\n{value}")
+            if isinstance(value, np.ndarray):
+                report.add_paragraph(f"**{metric}:**\n{value}")
+            elif isinstance(value, (int, float, np.float64)):
+                report.add_paragraph(f"**{metric}:** {value:.4f}")
+            else:
+                report.add_paragraph(f"**{metric}:** {value}")
 
     def _generate_feature_importance(self, report):
         report.add_heading("Feature Importance", level=2)
@@ -401,10 +421,14 @@ def _print_results(self, results):
                 logger.info("- ROC Curve: roc_curve.png")
                 logger.info("- Precision-Recall Curve: precision_recall_curve.png")
 
-            if results['shap_values'] is not None:
-                logger.info("\nSHAP values calculated successfully. See 'shap_summary.png' for visualization.")
-            else:
-                logger.info("\nSHAP values calculation failed. Please check the console output for more details.")
+            if 'shap_plot_url' in results:
+                logger.info("\nSHAP summary plot saved as 'shap_summary.png'")
+                logger.info("SHAP plot URL (base64 encoded) available in results['shap_plot_url']")
+
+            if 'lime_plot_url' in results:
+                logger.info("\nLIME explanation plot saved as 'lime_explanation.png'")
+                logger.info("LIME plot URL (base64 encoded) available in results['lime_plot_url']")
+
         except Exception as e:
             logger.error(f"Error occur in printing results...{str(e)}")
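Note (outside the patch): with this change, analyze() merges the dictionary returned by interpret_model() into its results on success, and records the failure message under results['interpretation_error'] when an ExplainableAIError is caught. A minimal sketch of how a caller might branch on those keys; the helper name below is illustrative, not part of the package.

def summarize_interpretation(results: dict) -> None:
    # Set by the except-ExplainableAIError branch added to analyze().
    if 'interpretation_error' in results:
        print(f"SHAP/LIME interpretation was skipped: {results['interpretation_error']}")
        return
    # Base64-encoded PNGs produced by plot_shap_summary() / plot_lime_explanation(),
    # also written to shap_summary.png and lime_explanation.png on disk.
    print(f"SHAP summary plot: {len(results['shap_plot_url'])} base64 characters")
    print(f"LIME explanation plot: {len(results['lime_plot_url'])} base64 characters")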
diff --git a/explainableai/exceptions.py b/explainableai/exceptions.py
new file mode 100644
index 0000000..ff098a5
--- /dev/null
+++ b/explainableai/exceptions.py
@@ -0,0 +1,3 @@
+class ExplainableAIError(Exception):
+    """Base exception class for ExplainableAI package"""
+    pass
\ No newline at end of file
diff --git a/explainableai/logging_config.py b/explainableai/logging_config.py
new file mode 100644
index 0000000..f9f2aa9
--- /dev/null
+++ b/explainableai/logging_config.py
@@ -0,0 +1,22 @@
+import logging
+
+def setup_logging():
+    logger = logging.getLogger('explainableai')
+    logger.setLevel(logging.DEBUG)
+
+    # Create console handler and set level to debug
+    ch = logging.StreamHandler()
+    ch.setLevel(logging.DEBUG)
+
+    # Create formatter
+    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
+
+    # Add formatter to ch
+    ch.setFormatter(formatter)
+
+    # Add ch to logger
+    logger.addHandler(ch)
+
+    return logger
+
+logger = setup_logging()
\ No newline at end of file
diff --git a/explainableai/model_interpretability.py b/explainableai/model_interpretability.py
index 6fdc051..ada8b31 100644
--- a/explainableai/model_interpretability.py
+++ b/explainableai/model_interpretability.py
@@ -1,82 +1,116 @@
-# model_interpretability.py
 import shap
 import lime
 import lime.lime_tabular
 import matplotlib.pyplot as plt
 import numpy as np
-import logging
+import pandas as pd
+import io
+import base64
+from .logging_config import logger
+from .exceptions import ExplainableAIError
 
-logger=logging.getLogger(__name__)
-logger.setLevel(logging.DEBUG)
-
-def calculate_shap_values(model, X):
-    logger.debug("Calculating values...")
+def calculate_shap_values(model, X, feature_names):
+    logger.debug("Calculating SHAP values...")
     try:
-        explainer = shap.Explainer(model, X)
-        shap_values = explainer(X)
-        logger.info("Values caluated...")
+        X_df = pd.DataFrame(X, columns=feature_names)
+        if hasattr(model, 'predict_proba'):
+            explainer = shap.TreeExplainer(model)
+            shap_values = explainer.shap_values(X_df)
+        else:
+            explainer = shap.Explainer(model, X_df)
+            shap_values = explainer(X_df)
+        logger.info("SHAP values calculated successfully.")
         return shap_values
     except Exception as e:
-        logger.error(f"Some error occurred in calculating values...{str(e)}")
+        logger.error(f"Error in calculate_shap_values: {str(e)}")
+        raise ExplainableAIError(f"Error in calculate_shap_values: {str(e)}")
 
-def plot_shap_summary(shap_values, X):
-    logger.debug("Summary...")
+def plot_shap_summary(shap_values, X, feature_names):
+    logger.debug("Plotting SHAP summary...")
     try:
-        plt.figure(figsize=(10, 8))
-        shap.summary_plot(shap_values, X, plot_type="bar", show=False)
+        plt.figure(figsize=(12, 8))
+        shap.summary_plot(shap_values, X, plot_type="bar", feature_names=feature_names, show=False)
         plt.tight_layout()
+
+        # Save plot to file
         plt.savefig('shap_summary.png')
+        logger.info("SHAP summary plot saved as 'shap_summary.png'")
+
+        # Convert plot to base64 for display
+        img = io.BytesIO()
+        plt.savefig(img, format='png')
+        img.seek(0)
+        plot_url = base64.b64encode(img.getvalue()).decode()
         plt.close()
-    except TypeError as e:
-        logger.error(f"Error in generating SHAP summary plot: {str(e)}")
-        logger.error("Attempting alternative SHAP visualization...")
-        try:
-            plt.figure(figsize=(10, 8))
-            shap.summary_plot(shap_values.values, X.values, feature_names=X.columns.tolist(), plot_type="bar", show=False)
-            plt.tight_layout()
-            plt.savefig('shap_summary.png')
-            plt.close()
-        except Exception as e2:
-            logger.error(f"Alternative SHAP visualization also failed: {str(e2)}")
-            logger.error("Skipping SHAP summary plot.")
+
+        return plot_url
+    except Exception as e:
+        logger.error(f"Error in plot_shap_summary: {str(e)}")
+        raise ExplainableAIError(f"Error in plot_shap_summary: {str(e)}")
 
 def get_lime_explanation(model, X, instance, feature_names):
-    logger.debug("Explaining model...")
+    logger.debug("Generating LIME explanation...")
     try:
         explainer = lime.lime_tabular.LimeTabularExplainer(
             X,
             feature_names=feature_names,
             class_names=['Negative', 'Positive'],
-            mode='classification'
+            mode='classification' if hasattr(model, 'predict_proba') else 'regression'
+        )
+        exp = explainer.explain_instance(
+            instance,
+            model.predict_proba if hasattr(model, 'predict_proba') else model.predict
         )
-        exp = explainer.explain_instance(instance, model.predict_proba)
-        logger.info("Model explained...")
+        logger.info("LIME explanation generated successfully.")
        return exp
     except Exception as e:
-        logger.error(f"Some error occurred in explaining model...{str(e)}")
+        logger.error(f"Error in get_lime_explanation: {str(e)}")
+        raise ExplainableAIError(f"Error in get_lime_explanation: {str(e)}")
 
 def plot_lime_explanation(exp):
-    exp.as_pyplot_figure()
-    plt.tight_layout()
-    plt.savefig('lime_explanation.png')
-    plt.close()
+    logger.debug("Plotting LIME explanation...")
+    try:
+        plt.figure(figsize=(12, 8))
+        exp.as_pyplot_figure()
+        plt.tight_layout()
+
+        # Save plot to file
+        plt.savefig('lime_explanation.png')
+        logger.info("LIME explanation plot saved as 'lime_explanation.png'")
+
+        # Convert plot to base64 for display
+        img = io.BytesIO()
+        plt.savefig(img, format='png')
+        img.seek(0)
+        plot_url = base64.b64encode(img.getvalue()).decode()
+        plt.close()
+
+        return plot_url
+    except Exception as e:
+        logger.error(f"Error in plot_lime_explanation: {str(e)}")
+        raise ExplainableAIError(f"Error in plot_lime_explanation: {str(e)}")
 
-def plot_ice_curve(model, X, feature, num_ice_lines=50):
-    ice_data = X.copy()
-    feature_values = np.linspace(X[feature].min(), X[feature].max(), num=100)
-
-    plt.figure(figsize=(10, 6))
-    for _ in range(num_ice_lines):
-        ice_instance = ice_data.sample(n=1, replace=True)
-        predictions = []
-        for value in feature_values:
-            ice_instance[feature] = value
-            predictions.append(model.predict_proba(ice_instance)[0][1])
-        plt.plot(feature_values, predictions, color='blue', alpha=0.1)
-
-    plt.xlabel(feature)
-    plt.ylabel('Predicted Probability')
-    plt.title(f'ICE Plot for {feature}')
-    plt.tight_layout()
-    plt.savefig(f'ice_plot_{feature}.png')
-    plt.close()
\ No newline at end of file
+def interpret_model(model, X, feature_names, instance_index=0):
+    logger.info("Starting model interpretation...")
+    try:
+        # SHAP analysis
+        shap_values = calculate_shap_values(model, X, feature_names)
+        shap_plot_url = plot_shap_summary(shap_values, X, feature_names)
+
+        # LIME analysis
+        instance = X[instance_index]
+        lime_exp = get_lime_explanation(model, X, instance, feature_names)
+        lime_plot_url = plot_lime_explanation(lime_exp)
+
+        interpretation_results = {
+            "shap_values": shap_values,
+            "shap_plot_url": shap_plot_url,
+            "lime_explanation": lime_exp,
+            "lime_plot_url": lime_plot_url
+        }
+
+        logger.info("Model interpretation completed successfully.")
+        return interpretation_results
+    except Exception as e:
+        logger.error(f"Error in interpret_model: {str(e)}")
+        raise ExplainableAIError(f"Error in interpret_model: {str(e)}")
\ No newline at end of file
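
Note (outside the patch): a minimal, self-contained sketch of the new interpret_model() entry point added in model_interpretability.py. The RandomForestClassifier and the synthetic dataset below are assumptions chosen only so the model exposes predict_proba; any fitted estimator should follow the same path.

from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier

from explainableai.model_interpretability import interpret_model

# Small illustrative classification problem (assumed data, not from the package).
X, y = make_classification(n_samples=200, n_features=5, random_state=0)
feature_names = [f"feature_{i}" for i in range(X.shape[1])]
model = RandomForestClassifier(random_state=0).fit(X, y)

# Global SHAP summary plus a local LIME explanation for instance 0;
# writes shap_summary.png and lime_explanation.png to the working directory.
results = interpret_model(model, X, feature_names, instance_index=0)
print(sorted(results.keys()))
# ['lime_explanation', 'lime_plot_url', 'shap_plot_url', 'shap_values']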