# Source code for kdiagram.plot.feature_based

# -*- coding: utf-8 -*-
# License: Apache 2.0
# Author: LKouadio <etanoyau@gmail.com>

import warnings 
import matplotlib.pyplot as plt 
import matplotlib.cm as cm
from typing import Optional, List, Tuple, Union, Any 

import numpy as np 
import pandas as pd 

from ..decorators import check_non_emptiness 
from..utils.handlers import columns_manager 
from ..utils.validator import ensure_2d 

# Names exported by ``from kdiagram.plot.feature_based import *``.
__all__ = ["plot_feature_fingerprint"]

def _resolve_feature_names(features, n_features_data):
    """Return exactly ``n_features_data`` angular-axis names.

    Parameters
    ----------
    features : list of str or None
        User-supplied feature names. ``None`` yields ``'feature 1'``,
        ``'feature 2'``, ...
    n_features_data : int
        Number of columns in the importance matrix.

    Returns
    -------
    list of str
        The names, padded with generic ``'feature k'`` entries when too
        few were given, or truncated (with a warning) when too many were
        given.
    """
    if features is None:
        return [f'feature {i + 1}' for i in range(n_features_data)]

    # ``columns_manager`` normalises the user input; copy into a real list
    # so the padding below works whatever sequence type it hands back.
    names = columns_manager(features, empty_as_none=False)
    names = [] if names is None else list(names)

    if len(names) < n_features_data:
        names.extend(
            f'feature {ix + 1}' for ix in range(len(names), n_features_data)
        )
    elif len(names) > n_features_data:
        warnings.warn(
            f"More feature names ({len(names)}) provided "
            f"than data columns ({n_features_data}). "
            "Extra names ignored."
        )
        names = names[:n_features_data]
    return names


def _resolve_layer_labels(labels, n_layers):
    """Return exactly ``n_layers`` legend labels.

    Parameters
    ----------
    labels : list of str, str or None
        User-supplied labels. ``None`` yields ``'Layer 1'``, ``'Layer 2'``,
        ... A bare string is treated as the label of a single layer.
    n_layers : int
        Number of rows in the importance matrix.

    Returns
    -------
    list of str
        The labels, padded or truncated to ``n_layers`` (with a warning in
        either case).
    """
    if labels is None:
        return [f"Layer {idx + 1}" for idx in range(n_layers)]

    # ``list("abc")`` would split a single label into characters.
    labels_list = [labels] if isinstance(labels, str) else list(labels)

    if len(labels_list) < n_layers:
        warnings.warn(
            f"Fewer labels ({len(labels_list)}) provided than "
            f"layers ({n_layers}). Using generic names for the rest."
        )
        labels_list.extend(
            f'Layer {ix + 1}' for ix in range(len(labels_list), n_layers)
        )
    elif len(labels_list) > n_layers:
        warnings.warn(
            f"More labels ({len(labels_list)}) provided than "
            f"layers ({n_layers}). Extra labels ignored."
        )
        labels_list = labels_list[:n_layers]
    return labels_list


def _normalize_rows(matrix, eps=1e-9):
    """Divide each row of a 2-D float array by its maximum.

    Rows whose maximum is ``<= eps`` (all-zero, all-negative or NaN rows)
    are returned as zeros instead of being divided.
    """
    max_per_row = matrix.max(axis=1, keepdims=True)  # shape (n_layers, 1)
    # ``where`` broadcasts the (n_layers, 1) mask across each row; masked-out
    # rows keep the zeros pre-filled in ``out``, so there is no division by
    # zero and no sign flip.
    return np.divide(
        matrix,
        max_per_row,
        out=np.zeros_like(matrix, dtype=float),
        where=max_per_row > eps,
    )


def _resolve_layer_colors(cmap, n_layers):
    """Return a non-empty list of colours for ``n_layers`` polygons.

    A non-empty list is used as-is (the caller cycles through it).
    Anything else is looked up as a Matplotlib colormap and sampled at
    ``i / n_layers``. Unknown names or empty lists fall back to 'tab10'
    with a warning.
    """
    if isinstance(cmap, list) and cmap:
        if len(cmap) < n_layers:
            warnings.warn(
                f"Provided color list has fewer colors "
                f"({len(cmap)}) than layers ({n_layers}). "
                f"Colors will repeat."
            )
        return list(cmap)

    try:
        # ``matplotlib.cm.get_cmap`` was removed in Matplotlib 3.9;
        # ``pyplot.get_cmap`` accepts a registered name or a Colormap.
        cmap_obj = plt.get_cmap(cmap)
    except (ValueError, TypeError):
        warnings.warn(
            f"Invalid cmap '{cmap}'. Falling back to 'tab10'.")
        cmap_obj = plt.get_cmap('tab10')
    return [cmap_obj(i / n_layers) for i in range(n_layers)]


@check_non_emptiness(params=["importances"])
def plot_feature_fingerprint(
    importances,
    features: Optional[List[str]] = None,
    labels: Optional[List[str]] = None,
    normalize: bool = True,
    fill: bool = True,
    cmap: Union[str, List[Any]] = 'tab10',
    title: str = "Feature Impact Fingerprint",
    figsize: Optional[Tuple[float, float]] = None,
    show_grid: bool = True,
    savefig: Optional[str] = None,
):
    r"""Create a radar chart visualizing feature importance profiles.

    This function generates a polar (radar) chart to visually compare the
    importance or contribution profiles of a set of features across
    different groups, conditions, or time periods (e.g., geographical
    zones, yearly data, different models). Each group is represented by a
    distinct polygon (layer) on the chart. This makes it easy to identify
    patterns, dominant features, and shifts in feature influence across
    the groups, often referred to as a 'fingerprint'. It is particularly
    useful for model interpretation, allowing a quick comparison of how
    feature rankings change under different circumstances.

    Parameters
    ----------
    importances : array-like of shape (n_layers, n_features)
        The core data containing feature importance values. Each row
        represents a different layer (e.g., a zone, a year, a model) and
        each column corresponds to a feature. Can be a list of lists, a
        NumPy array or a pandas DataFrame. It is converted to a float
        NumPy array internally.
    features : list of str, optional
        Names of the features corresponding to the columns in
        ``importances``. The order must match the columns. If ``None``,
        generic names like 'feature 1', 'feature 2', etc., are generated.
        Missing names are filled with generic ones. Extra names are
        ignored with a warning. Default is ``None``.
    labels : list of str, optional
        Names for each layer (row) in ``importances``, shown in the
        legend. A single string is used as the label of one layer. If
        ``None``, generic names like 'Layer 1', 'Layer 2', etc., are
        generated. Missing labels are filled with generic ones and extra
        labels are ignored, each with a warning. Default is ``None``.
    normalize : bool, default=True
        If ``True``, divide the importance values of each layer (row) by
        the maximum value of that layer, so every layer peaks at 1. This
        is useful for comparing the *shape* of the importance profiles
        independent of their absolute magnitudes. If ``False``, the raw
        importance values are plotted.
    fill : bool, default=True
        If ``True``, the area enclosed by each layer's polygon is filled
        with a semi-transparent color, enhancing visual distinction
        between layers. If ``False``, only the outlines are plotted.
    cmap : str or list, default='tab10'
        Matplotlib colormap name (e.g., 'viridis', 'plasma', 'tab10') or
        a list of valid color specifications (e.g.,
        ``['red', '#00FF00', 'blue']``) used to color the layers.
        - Colormap name: colors are sampled from it.
        - List with fewer colors than layers: colors repeat, with a
          warning.
        - Unknown name or empty list: a warning is emitted and 'tab10' is
          used.
    title : str, default="Feature Impact Fingerprint"
        The title displayed above the radar chart.
    figsize : tuple of (float, float), optional
        The width and height of the figure in inches. ``None`` (default)
        uses ``(8, 8)``.
    show_grid : bool, default=True
        If ``True``, display the polar grid lines (both radial and
        angular), which can aid in reading values. If ``False``, the grid
        is hidden.
    savefig : str, optional
        The file path (including extension, e.g., 'fingerprint.png')
        where the plot is saved at 300 dpi. If saving fails, a warning is
        emitted instead of raising. If ``None``, the plot is displayed
        with ``plt.show()`` instead. Default is ``None``.

    Returns
    -------
    ax : matplotlib.axes.Axes
        The polar Axes containing the radar chart, for further
        customization.

    See Also
    --------
    matplotlib.pyplot.polar : Underlying function for polar plots.
    numpy.linspace : Used for calculating angles.

    Notes
    -----
    - Input validation relies on the package helpers ``ensure_2d`` and
      ``columns_manager``.
    - To create closed polygons, the first data point and the first
      angle are appended to the end of their respective lists before
      each layer is plotted.
    - Normalization scales each layer independently. This highlights
      relative importance patterns but hides absolute magnitude
      differences between layers.
    - The radial axis starts at 0, so negative values are clipped from
      view.

    Let :math:`\mathbf{R}` be the input ``importances`` matrix of shape
    :math:`(M, N)`, where :math:`M` is the number of layers (labels) and
    :math:`N` is the number of features.

    1. **Angle Calculation**: Angles for each feature axis are:

       .. math::
           \theta_j = \frac{2 \pi j}{N}, \quad j = 0, 1, \dots, N-1

    2. **Normalization** (if ``normalize=True``): each row
       :math:`\mathbf{r}_i = (r_{i0}, r_{i1}, \dots, r_{i,N-1})` becomes

       .. math::
           r'_{ij} = \frac{r_{ij}}{\max_{k}(r_{ik})}

       If :math:`\max_{k}(r_{ik}) \le 10^{-9}` (an all-zero or
       all-negative row), :math:`r'_{ij}` is set to 0.

    3. **Plotting**: For each layer :math:`i`, the points
       :math:`(r'_{ij}, \theta_j)` (or :math:`(r_{ij}, \theta_j)` if not
       normalized) are joined by lines. The first point is repeated at
       angle :math:`2\pi` to close the shape, and the enclosed area is
       optionally filled.

    Examples
    --------
    >>> import numpy as np
    >>> from kdiagram.plot.feature_based import plot_feature_fingerprint

    **1. Random Example:**

    >>> np.random.seed(42)  # for reproducibility
    >>> random_importances = np.random.rand(3, 6)  # 3 layers, 6 features
    >>> feature_names = [f'Feature {i+1}' for i in range(6)]
    >>> layer_labels = ['Model A', 'Model B', 'Model C']
    >>> ax = plot_feature_fingerprint(
    ...     importances=random_importances,
    ...     features=feature_names,
    ...     labels=layer_labels,
    ...     title="Random Feature Importance Comparison",
    ...     cmap='Set3',
    ...     normalize=True,
    ...     fill=True
    ... )
    >>> # plt.show() is called internally if savefig is None

    **2. Concrete Example (Yearly Weights):**

    >>> features = ['rainfall', 'GWL', 'seismic', 'density', 'geo']
    >>> weights_per_year = [
    ...     [0.2, 0.4, 0.1, 0.6, 0.3],  # 2023
    ...     [0.3, 0.5, 0.2, 0.4, 0.4],  # 2024
    ...     [0.1, 0.6, 0.2, 0.5, 0.3],  # 2025
    ... ]
    >>> years = ['2023', '2024', '2025']
    >>> ax_yearly = plot_feature_fingerprint(
    ...     importances=weights_per_year,
    ...     features=features,
    ...     labels=years,
    ...     title="Feature Influence Over Years",
    ...     cmap='tab10',
    ...     normalize=False  # Show raw weights
    ... )
    >>> # plt.show() is called internally
    """
    # --- Input validation and preparation ---
    # Coerce to a plain float ndarray up front. ``ensure_2d`` may return a
    # DataFrame, and iterating a DataFrame yields column labels, not rows.
    importance_matrix = np.asarray(ensure_2d(importances), dtype=float)
    n_layers, n_features = importance_matrix.shape

    features_list = _resolve_feature_names(features, n_features)
    labels_list = _resolve_layer_labels(labels, n_layers)

    if normalize:
        importance_matrix = _normalize_rows(importance_matrix)

    # --- Angles for the radar axes ---
    angles = np.linspace(0, 2 * np.pi, n_features, endpoint=False).tolist()
    # Repeat the first angle so each polygon closes on itself.
    angles_closed = angles + angles[:1]

    # --- Plotting setup ---
    fig, ax = plt.subplots(
        figsize=figsize if figsize is not None else (8, 8),
        subplot_kw=dict(polar=True),
    )
    colors = _resolve_layer_colors(cmap, n_layers)

    # --- Plot each layer ---
    for idx, (row, label) in enumerate(zip(importance_matrix, labels_list)):
        values = row.tolist()
        values_closed = values + values[:1]
        color = colors[idx % len(colors)]  # cycle short colour lists
        ax.plot(angles_closed, values_closed, label=label, color=color,
                linewidth=2)
        if fill:
            ax.fill(angles_closed, values_closed, color=color, alpha=0.25)

    # --- Customize appearance ---
    ax.set_title(title, size=16, y=1.1)  # lift title clear of axis labels
    ax.set_xticks(angles)
    ax.set_xticklabels(features_list, fontsize=11)
    ax.set_ylim(bottom=0)
    if normalize:
        # Normalized data peaks at 1, so fixed quarter ticks are meaningful.
        ax.set_yticks([0.25, 0.5, 0.75, 1.0])
        ax.set_yticklabels(["0.25", "0.50", "0.75", "1.00"],
                           fontsize=9, color='gray')
    else:
        # Raw magnitudes: hide radial labels to keep the chart uncluttered.
        ax.set_yticklabels([])

    if show_grid:
        ax.grid(True, linestyle='--', linewidth=0.5, alpha=0.6)
    else:
        ax.grid(False)

    # Legend placed outside the polar area so it does not hide polygons.
    ax.legend(loc='upper right', bbox_to_anchor=(1.3, 1.1), fontsize=10)
    fig.tight_layout(pad=2.0)

    # --- Save or show ---
    if savefig:
        try:
            fig.savefig(savefig, bbox_inches='tight', dpi=300)
            print(f"Plot saved to {savefig}")
        except Exception as e:
            # Saving is best-effort; surface the failure as a warning so it
            # can be filtered or escalated instead of vanishing on stdout.
            warnings.warn(f"Error saving plot to {savefig}: {e}")
    else:
        plt.show()

    return ax