Source code for kdiagram.plot.comparison

# -*- coding: utf-8 -*-
# License: Apache 2.0
# Author: LKouadio <etanoyau@gmail.com>

from numbers import Real 
import warnings
import matplotlib.pyplot as plt 
import numpy as np 
from typing import Optional, Tuple, List, Union, Any, Callable 

from ..compat.sklearn import validate_params, StrOptions, type_of_target
from ..utils.generic_utils import drop_nan_in 
from ..utils.handlers import columns_manager
from ..utils.metric_utils import get_scorer 
from ..utils.plot import set_axis_grid  
from ..utils.validator import validate_yy, is_iterable 

__all__=[ 'plot_model_comparison']

[docs] @validate_params({ 'train_times': ['array-like', None], 'metrics': [str, 'array-like', callable, None], 'scale': [StrOptions({"norm", "min-max", 'std', 'standard',}), None], "lower_bound": [Real], }) def plot_model_comparison( y_true, *y_preds, train_times: Optional[Union[float, List[float]]] = None, metrics: Optional[Union[str, Callable, List[Union[str, Callable]]]] = None, names: Optional[List[str]] = None, title: Optional[str] = None, figsize: Optional[Tuple[float, float]] = None, colors: Optional[List[Any]] = None, alpha: float = 0.7, legend: bool = True, show_grid: bool = True, grid_props: dict =None, scale: Optional[str] = 'norm', lower_bound: float = 0, savefig: Optional[str] = None, loc: str = 'upper right', verbose: int = 0, ): r"""Plot multi-metric model performance comparison on a radar chart. Generates a radar chart (spider chart) visualizing multiple performance metrics for one or more models simultaneously. Each axis corresponds to a metric (e.g., R2, MAE, accuracy, precision), and each polygon represents a model, allowing for a holistic comparison of their strengths and weaknesses across different evaluation criteria [1]_. This function is highly valuable for model selection, providing a compact overview that goes beyond single-score comparisons. Use it when you need to balance trade-offs between various metrics (like accuracy vs. training time) or understand how different models perform relative to each other across a spectrum of relevant performance indicators. Internally relies on helpers to handle potential NaN values and determine data types [2]_. Parameters ---------- y_true : array-like of shape (n_samples,) The ground truth (correct) target values. *y_preds : array-like of shape (n_samples,) Variable number of prediction arrays, one for each model to be compared. Each array must have the same length as `y_true`. train_times : float or list of float, optional Training time in seconds for each model corresponding to `*y_preds`. If provided: - A single float assumes the same time for all models. - A list must match the number of models. It will be added as an additional axis/metric on the chart. Default is ``None``. metrics : str, callable, list of these, optional The performance metrics to calculate and plot. Default is ``None``, which triggers automatic metric selection based on the target type inferred from `y_true`: - **Regression:** Defaults to ``["r2", "mae", "mape", "rmse"]``. - **Classification:** Defaults to ``["accuracy", "precision", "recall"]``. Can be provided as: - A list of strings: Names of metrics known by scikit-learn or gofast's `get_scorer` (e.g., ``['r2', 'rmse']``). - A list of callables: Functions with the signature `metric(y_true, y_pred)`. - A mix of strings and callables. names : list of str, optional Names for each model corresponding to `*y_preds`. Used for the legend. If ``None`` or too short, defaults like "Model_1", "Model_2" are generated. Default is ``None``. title : str, optional Title displayed above the radar chart. If ``None``, a generic title may be used internally or omitted. Default is ``None``. figsize : tuple of (float, float), optional Figure size ``(width, height)`` in inches. If ``None``, uses Matplotlib's default (often similar to ``(8, 8)`` for this type of plot). colors : list of str or None, optional List of Matplotlib color specifications for each model's polygon. If ``None``, colors are automatically assigned from the default palette ('tab10'). If provided, the list length should ideally match `n_models`. alpha : float, optional Transparency level (between 0 and 1) for the plotted lines and filled areas. Default is ``0.7``. (Note: Fill alpha is often hardcoded lower, e.g., 0.1, in implementation). legend : bool, optional If ``True``, display a legend mapping colors/lines to model names. Default is ``True``. show_grid : bool, optional If ``True``, display the radial grid lines on the chart. Default is ``True``. scale : {'norm', 'min-max', 'std', 'standard'}, optional Method for scaling metric values before plotting. Scaling is applied independently to each metric (axis) across models. Default is ``'norm'``. - ``'norm'`` or ``'min-max'``: Min-max scaling. Transforms values to the range [0, 1] using :math:`(X - min) / (max - min)`. Useful for comparing relative performance when metrics have different scales. - ``'std'`` or ``'standard'``: Standard scaling (Z-score). Transforms values to have zero mean and unit variance using :math:`(X - mean) / std`. Preserves relative spacing better than min-max but results can be negative. - ``None``: Plot raw metric values without scaling. Use only if metrics naturally share a comparable, non-negative range. lower_bound : float, optional Sets the minimum value for the radial axis (innermost circle). Useful when using standard scaling ('std') which can produce negative values, or to adjust the plot's center. Default is ``0``. savefig : str, optional If provided, the file path (e.g., 'radar_comparison.svg') where the figure will be saved. If ``None``, the plot is displayed interactively. Default is ``None``. loc : str, optional Location argument passed to `matplotlib.pyplot.legend()` to position the legend (e.g., 'upper right', 'lower left', 'center right'). Default is ``'upper right'``. verbose : int, optional Controls the verbosity level. ``0`` is silent. Higher values may print debugging information during metric calculation or scaling. Default is ``0``. Returns ------- ax : matplotlib.axes.Axes The Matplotlib Axes object containing the radar chart. Allows for further customization after the function call. Raises ------ ValueError If lengths of `y_preds`, `names` (if provided), and `train_times` (if provided) do not match. If an invalid string is provided for `scale`. If a metric string name is not recognized by the internal scorer. TypeError If `y_true` or `y_preds` contain non-numeric data. See Also -------- kdiagram.utils.metric_utils.get_scorer : Function likely used internally to fetch metric callables (verify path). sklearn.metrics : Scikit-learn metrics module. matplotlib.pyplot.polar : Function for creating polar plots. Notes ----- This function provides a multi-dimensional view of model performance. **Metric Calculation:** For each model :math:`k` with predictions :math:`\hat{y}_k` and each metric :math:`m` (from the `metrics` list), the score :math:`S_{m,k}` is calculated: .. math:: S_{m,k} = \text{Metric}_m(y_{true}, \hat{y}_k) If `train_times` are provided, they are treated as an additional metric axis. **Scaling:** If `scale` is specified, scaling is applied column-wise (per metric) across all models before plotting: - Min-Max ('norm'): .. math:: S'_{m,k} = \frac{S_{m,k} - \min_j(S_{m,j})}{\max_j(S_{m,j}) - \min_j(S_{m,j})} - Standard ('std'): .. math:: S'_{m,k} = \frac{S_{m,k} - \text{mean}_j(S_{m,j})}{\text{std}_j(S_{m,j})} **Plotting:** The (scaled) scores :math:`S'_{m,k}` for each model :math:`k` determine the radial distance along the axis corresponding to metric :math:`m`. Points are connected to form a polygon for each model. References ---------- .. [1] Wikipedia contributors. (2024). Radar chart. In Wikipedia, The Free Encyclopedia. Retrieved April 14, 2025, from https://en.wikipedia.org/wiki/Radar_chart *(General reference for radar charts)* .. [2] Kenny-Denecke, J. F., Hernandez-Amaro, A., Martin-Gorriz, M. L., & Castejon-Limos, P. (2024). Lead-Time Prediction in Wind Tower Manufacturing: A Machine Learning-Based Approach. *Mathematics*, 12(15), 2347. https://doi.org/10.3390/math12152347 *(Example application using radar charts for ML comparison)* Examples -------- >>> from kdiagram.plot.relationship import plot_model_comparison >>> import numpy as np >>> # Example 1: Regression task >>> y_true_reg = np.array([3, -0.5, 2, 7, 5]) >>> y_pred_r1 = np.array([2.5, 0.0, 2.1, 7.8, 5.2]) >>> y_pred_r2 = np.array([3.2, 0.2, 1.8, 6.5, 4.8]) >>> times = [0.1, 0.5] # Training times in seconds >>> names = ['ModelLin', 'ModelTree'] >>> ax1 = plot_factory_ops(y_true_reg, y_pred_r1, y_pred_r2, ... train_times=times, names=names, ... metrics=['r2', 'mae', 'rmse'], # Specify metrics ... title="Regression Model Comparison", ... scale='norm') # Normalize for comparison >>> # Example 2: Classification task (requires appropriate y_true/y_pred) >>> y_true_clf = np.array([0, 1, 0, 1, 1, 0]) >>> y_pred_c1 = np.array([0, 1, 0, 1, 0, 0]) # Model 1 preds >>> y_pred_c2 = np.array([0, 1, 1, 1, 1, 0]) # Model 2 preds >>> ax2 = plot_factory_ops(y_true_clf, y_pred_c1, y_pred_c2, ... names=["LogReg", "SVM"], ... # Uses default classification metrics ... title="Classification Model Comparison", ... scale='norm') """ # Docstring omitted as requested # --- Input Validation and Preparation --- try: # Remove NaN values and ensure consistency y_true, *y_preds = drop_nan_in(y_true, *y_preds, error='raise') # Validate y_true and each y_pred temp_preds = [] for i, pred in enumerate(y_preds): # Validate returns tuple, we need the second element validated_pred = validate_yy( y_true, pred, expected_type=None, flatten=True )[1] temp_preds.append(validated_pred) y_preds = temp_preds except Exception as e: # Catch potential errors during validation/NaN drop raise TypeError(f"Input validation failed: {e}") from e n_models = len(y_preds) if n_models == 0: warnings.warn("No prediction arrays (*y_preds) provided.") return None # Cannot plot without predictions # --- Handle Names --- if names is None: names = [f"Model_{i+1}" for i in range(n_models)] else: names = columns_manager(list(names), empty_as_none=False) # Ensure list if len(names) < n_models: names += [f"Model_{i+1}" for i in range(len(names), n_models)] elif len(names) > n_models: warnings.warn(f"Received {len(names)} names for {n_models}" f" models. Extra names ignored.", UserWarning) names = names[:n_models] # --- Handle Metrics --- if metrics is None: target_type = type_of_target(y_true) if target_type in ['continuous', 'continuous-multioutput']: # Default regression metrics metrics = ["r2", "mae", "mape", "rmse"] else: # Default classification metrics metrics = ["accuracy", "precision", "recall", "f1"] if verbose >= 1: print(f"[INFO] Auto-selected metrics for target type " f"'{target_type}': {metrics}") metrics = is_iterable(metrics, exclude_string=True, transform=True) metric_funcs = [] metric_names = [] error_metrics = [] # Track metrics needing sign inversion for metric in metrics: try: if isinstance(metric, str): # get_scorer returns a callable scorer object scorer_func = get_scorer(metric) metric_funcs.append(scorer_func) metric_names.append(metric) # Identify error metrics (lower is better) for potential scaling flip if metric in ['mae', 'mape', 'rmse', 'mse']: # Add others if needed error_metrics.append(metric) elif callable(metric): metric_funcs.append(metric) m_name = getattr(metric, '__name__', f'func_{len(metric_names)}') metric_names.append(m_name) # Cannot easily determine if callable is error/score metric else: warnings.warn( f"Ignoring invalid metric type: {type(metric)}") except Exception as e: warnings.warn( f"Could not retrieve scorer for metric '{metric}': {e}") if not metric_funcs: raise ValueError("No valid metrics found or specified.") # --- Handle Train Times --- train_time_vals = None if train_times is not None: if isinstance(train_times, (int, float, np.number)): # Handle single value train_time_vals = np.array([float(train_times)] * n_models) else: train_times = np.asarray(train_times, dtype=float) if train_times.ndim != 1 or len(train_times) != n_models: raise ValueError( f"train_times must be a single float or a list/array " f"of length n_models ({n_models}). " f"Got shape {train_times.shape}." ) train_time_vals = train_times metric_names.append("Train Time (s)") # Use clearer name # Add a placeholder for calculation loop, will substitute later metric_funcs.append("train_time_placeholder") # --- Calculate Metric Results --- results = np.zeros((n_models, len(metric_names)), dtype=float) for i, y_pred in enumerate(y_preds): for j, metric_func in enumerate(metric_funcs): if metric_func == "train_time_placeholder": results[i, j] = train_time_vals[i] elif metric_func is not None: try: score = metric_func(y_true, y_pred) results[i, j] = score except Exception as e: warnings.warn(f"Could not compute metric " f"'{metric_names[j]}' for model " f"'{names[i]}': {e}. Setting to NaN.") results[i, j] = np.nan else: results[i, j] = np.nan # Should not happen if logic is correct # --- Scale Results --- # Make copy for scaling to preserve original results if needed later results_scaled = results.copy() # Handle potential NaNs before scaling if np.isnan(results_scaled).any(): warnings.warn("NaN values found in metric results. Scaling might " "be affected or rows/cols dropped depending on method.") # Option 1: Impute (e.g., with column mean) - complex # Option 2: Use nan-aware numpy functions # Let's use nan-aware functions # Note: Some metrics are better when *lower* (MAE, RMSE, MAPE, train_time). # For visualization where larger radius is better, we might invert these # before scaling, or adjust the interpretation. Let's scale first. if scale in ['norm', 'min-max']: if verbose >= 1: print("[INFO] Scaling metrics using Min-Max.") min_vals = np.nanmin(results_scaled, axis=0) max_vals = np.nanmax(results_scaled, axis=0) range_vals = max_vals - min_vals # Avoid division by zero for metrics with no variance range_vals[range_vals < 1e-9] = 1.0 results_scaled = (results_scaled - min_vals) / range_vals # Now, for error metrics, higher value (closer to 1) is WORSE. # Invert them so higher value (closer to 1) is BETTER. for j, name in enumerate(metric_names): if name in error_metrics or name == "Train Time (s)": results_scaled[:, j] = 1.0 - results_scaled[:, j] # Scaled results are now in [0, 1], higher is better. elif scale in ['std', 'standard']: if verbose >= 1: print("[INFO] Scaling metrics using Standard Scaler.") mean_vals = np.nanmean(results_scaled, axis=0) std_vals = np.nanstd(results_scaled, axis=0) # Avoid division by zero std_vals[std_vals < 1e-9] = 1.0 results_scaled = (results_scaled - mean_vals) / std_vals # Std scaling preserves relative order but changes range. # Lower errors become more negative. Higher scores become more positive. # Maybe invert sign for error metrics? for j, name in enumerate(metric_names): if name in error_metrics or name == "Train Time (s)": results_scaled[:, j] = -results_scaled[:, j] # Now higher value means better performance (higher score or lower error) # but range is not [0, 1]. We need to handle lower_bound. # Replace any potential NaNs resulting from scaling (e.g., if all NaNs) results_scaled = np.nan_to_num(results_scaled, nan=lower_bound) # --- Plotting --- fig = plt.figure(figsize=figsize or (8, 8)) # Default figsize here ax = fig.add_subplot(111, polar=True) # Angles for each metric axis num_metrics = len(metric_names) angles = np.linspace(0, 2 * np.pi, num_metrics, endpoint=False).tolist() angles_closed = angles + angles[:1] # Repeat first angle to close plot # Colors if colors is None: # Use a robust colormap like tab10 if available try: cmap_obj = plt.get_cmap("tab10") plot_colors = [cmap_obj(i % 10) for i in range(n_models)] except ValueError: # Fallback if tab10 not found (unlikely) cmap_obj = plt.get_cmap("viridis") plot_colors = [cmap_obj(i / n_models) for i in range(n_models)] else: plot_colors = colors # Use user-provided list # Plot each model for i, row in enumerate(results_scaled): values = np.concatenate((row, [row[0]])) # Close the polygon color = plot_colors[i % len(plot_colors)] # Cycle colors if needed ax.plot(angles_closed, values, label=names[i], color=color, linewidth=1.5, alpha=alpha) ax.fill(angles_closed, values, color=color, alpha=0.1) # Lighter fill # --- Configure Axes --- ax.set_xticks(angles) ax.set_xticklabels(metric_names) # Adjust radial limits and labels # If scaled to [0, 1], set limit slightly above 1 # If std scaled, auto-limit might be better, but respect lower_bound if scale in ['norm', 'min-max']: ax.set_ylim(bottom=lower_bound, top=1.05) # Optional: Add radial ticks for [0, 1] scale ax.set_yticks(np.linspace(lower_bound, 1, 5)) else: # Raw or std scaled ax.set_ylim(bottom=lower_bound) # Let matplotlib auto-determine upper limit and ticks ax.tick_params(axis='y', labelsize=8) # Smaller radial labels ax.tick_params(axis='x', pad=10) # Pad angular labels outwards # Grid set_axis_grid(ax, show_grid=show_grid, grid_props=grid_props ) # Legend if legend: ax.legend(loc=loc, bbox_to_anchor=(1.25, 1.05)) # Adjust position # Title ax.set_title(title or "Model Performance Comparison", y=1.15, fontsize=14) # --- Output --- plt.tight_layout(pad=2.0) # Adjust layout if savefig: try: plt.savefig(savefig, bbox_inches='tight', dpi=300) print(f"Plot saved to {savefig}") except Exception as e: print(f"Error saving plot to {savefig}: {e}") else: try: plt.show() except Exception as e: warnings.warn(f"Could not display plot interactively ({e})." f" Use savefig parameter.", UserWarning) return ax