kdiagram.plot.comparison.plot_model_comparison
- kdiagram.plot.comparison.plot_model_comparison(y_true, *y_preds, train_times=None, metrics=None, names=None, title=None, figsize=None, colors=None, alpha=0.7, legend=True, show_grid=True, grid_props=None, scale='norm', lower_bound=0.0, savefig=None, loc='upper right', verbose=0, acov='default', ax=None)
Plot multi-metric model performance comparison on a radar chart.
Generates a radar chart (spider chart) visualizing multiple performance metrics for one or more models simultaneously. Each axis corresponds to a metric (e.g., R2, MAE, accuracy, precision), and each polygon represents a model, allowing for a holistic comparison of their strengths and weaknesses across different evaluation criteria [1].
This function supports model selection by providing a compact overview that goes beyond single-score comparisons. Use it when you need to balance trade-offs between metrics (such as accuracy vs. training time) or to see how models perform relative to each other across several performance indicators. Internally, it relies on helpers to handle potential NaN values and determine data types [2].
- Parameters:
- y_true : array_like of shape (n_samples,)
  The ground truth (correct) target values.
- *y_preds : array_like of shape (n_samples,)
  Variable number of prediction arrays, one for each model to be compared. Each array must have the same length as y_true.
- train_times : float or list of float, optional
  Training time in seconds for each model corresponding to *y_preds. If provided:
  - A single float assumes the same time for all models.
  - A list must match the number of models.
  It will be added as an additional axis/metric on the chart. Default is None.
- metrics : str, callable, list of these, optional
  The performance metrics to calculate and plot. Default is None, which triggers automatic metric selection based on the target type inferred from y_true:
  - Regression: Defaults to ["r2", "mae", "mape", "rmse"].
  - Classification: Defaults to ["accuracy", "precision", "recall"].
  Can be provided as:
  - A list of strings: Names of metrics known by scikit-learn or gofast’s get_scorer (e.g., ['r2', 'rmse']).
  - A list of callables: Functions with the signature metric(y_true, y_pred).
  - A mix of strings and callables.
- names : list of str, optional
  Names for each model corresponding to *y_preds. Used for the legend. If None or too short, defaults like “Model_1”, “Model_2” are generated. Default is None.
- title : str, optional
  Title displayed above the radar chart. If None, a generic title may be used internally or omitted. Default is None.
- figsize : tuple of (float, float), optional
  Figure size (width, height) in inches. If None, uses Matplotlib’s default (often similar to (8, 8) for this type of plot).
- colors : list of str or None, optional
  List of Matplotlib color specifications for each model’s polygon. If None, colors are automatically assigned from the default palette (‘tab10’). If provided, the list length should ideally match n_models.
- alpha : float, optional
  Transparency level (between 0 and 1) for the plotted lines and filled areas. Default is 0.7. (Note: Fill alpha is often hardcoded lower, e.g., 0.1, in implementation).
- legend : bool, optional
  If True, display a legend mapping colors/lines to model names. Default is True.
- show_grid : bool, optional
  If True, display the radial grid lines on the chart. Default is True.
- scale : {'norm', 'min-max', 'std', 'standard'} or None, optional
  Method for scaling metric values before plotting. Scaling is applied independently to each metric (axis) across models. Default is 'norm'.
  - 'norm' or 'min-max': Min-max scaling. Transforms values to the range [0, 1] using \((X - min) / (max - min)\). Useful for comparing relative performance when metrics have different scales.
  - 'std' or 'standard': Standard scaling (Z-score). Transforms values to have zero mean and unit variance using \((X - mean) / std\). Preserves relative spacing better than min-max but results can be negative.
  - None: Plot raw metric values without scaling. Use only if metrics naturally share a comparable, non-negative range.
- lower_bound : float, optional
  Sets the minimum value for the radial axis (innermost circle). Useful when using standard scaling (‘std’) which can produce negative values, or to adjust the plot’s center. Default is 0.
- savefig : str, optional
  If provided, the file path (e.g., ‘radar_comparison.svg’) where the figure will be saved. If None, the plot is displayed interactively. Default is None.
- loc : str, optional
  Location argument passed to matplotlib.pyplot.legend() to position the legend (e.g., ‘upper right’, ‘lower left’, ‘center right’). Default is 'upper right'.
- verbose : int, optional
  Controls the verbosity level. 0 is silent. Higher values may print debugging information during metric calculation or scaling. Default is 0.
- acov : {'default', 'half_circle', 'quarter_circle', 'eighth_circle'}, default='default'
  Angular coverage of the polar sector.
  - 'default': full circle, \(2\pi\) (360°)
  - 'half_circle': \(\pi\) (180°)
  - 'quarter_circle': \(\pi/2\) (90°)
  - 'eighth_circle': \(\pi/4\) (45°)
- Returns:
- ax : matplotlib.axes.Axes
  The Matplotlib Axes object containing the radar chart. Allows for further customization after the function call.
- Raises:
- ValueError
  If the lengths of y_preds, names (if provided), and train_times (if provided) do not match; if an invalid string is provided for scale; or if a metric string name is not recognized by the internal scorer.
- TypeError
  If y_true or y_preds contain non-numeric data.
See also
kdiagram.utils.metric_utils.get_scorer : Function likely used internally to fetch metric callables (verify path).
sklearn.metrics : Scikit-learn metrics module.
matplotlib.pyplot.polar : Function for creating polar plots.
Notes
This function provides a multi-dimensional view of model performance.
Metric Calculation: For each model \(k\) with predictions \(\hat{y}_k\) and each metric \(m\) (from the metrics list), the score \(S_{m,k}\) is calculated:
\[S_{m,k} = \text{Metric}_m(y_{true}, \hat{y}_k) \tag{1}\]
If train_times are provided, they are treated as an additional metric axis.
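The metric calculation in equation (1) can be sketched with plain NumPy. This is only an illustration, not the library's implementation: the helper name score_matrix and the hand-written r2, mae and rmse functions are hypothetical, and the result is laid out with one row per model and one column per metric, with training times appended as a final column.

```python
import numpy as np


# Hand-written metrics following the documented signature metric(y_true, y_pred).
def mae(y_true, y_pred):
    return float(np.mean(np.abs(y_true - y_pred)))


def rmse(y_true, y_pred):
    return float(np.sqrt(np.mean((y_true - y_pred) ** 2)))


def r2(y_true, y_pred):
    ss_res = np.sum((y_true - y_pred) ** 2)
    ss_tot = np.sum((y_true - np.mean(y_true)) ** 2)
    return float(1.0 - ss_res / ss_tot)


def score_matrix(y_true, y_preds, metrics, train_times=None):
    """Hypothetical helper: S[k, m] = metric_m(y_true, y_pred_k)."""
    y_true = np.asarray(y_true, dtype=float)
    scores = np.array(
        [[metric(y_true, np.asarray(pred, dtype=float)) for metric in metrics]
         for pred in y_preds]
    )
    if train_times is not None:
        # A single float is repeated for every model; a list must match
        # the number of models, otherwise broadcast_to raises ValueError.
        times = np.broadcast_to(
            np.asarray(train_times, dtype=float), (len(y_preds),)
        )
        scores = np.column_stack([scores, times])  # extra metric axis
    return scores


y_true = [3, -0.5, 2, 7, 5]
y_pred_1 = [2.5, 0.0, 2.1, 7.8, 5.2]
y_pred_2 = [3.2, 0.2, 1.8, 6.5, 4.8]

S = score_matrix(y_true, [y_pred_1, y_pred_2], [r2, mae, rmse],
                 train_times=[0.1, 0.5])
print(S.shape)  # (2, 4): two models, three metrics plus training time
```

Each row of S then holds the raw scores for one model, which is the input the scaling step operates on.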
Scaling: If scale is specified, scaling is applied column-wise (per metric) across all models before plotting:
Min-Max (‘norm’):
\[S'_{m,k} = \frac{S_{m,k} - \min_j(S_{m,j})}{\max_j(S_{m,j}) - \min_j(S_{m,j})} \tag{2}\]
Standard (‘std’):
\[S'_{m,k} = \frac{S_{m,k} - \text{mean}_j(S_{m,j})}{\text{std}_j(S_{m,j})} \tag{3}\]
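A minimal NumPy sketch of the two scaling rules, applied column-wise to a score matrix with one row per model. The function name scale_scores is hypothetical. Two choices here are assumptions rather than documented behavior: a constant column (zero range or zero standard deviation) is mapped to zeros instead of dividing by zero, and the standard deviation is the population form (ddof=0).

```python
import numpy as np


def scale_scores(scores, scale="norm"):
    """Hypothetical helper: scale each metric column across models."""
    scores = np.asarray(scores, dtype=float)
    if scale is None:
        return scores  # raw values, no scaling
    if scale in ("norm", "min-max"):
        lo = scores.min(axis=0)
        span = scores.max(axis=0) - lo
        # Assumption: avoid 0/0 when every model has the same score.
        span = np.where(span == 0, 1.0, span)
        return (scores - lo) / span
    if scale in ("std", "standard"):
        mean = scores.mean(axis=0)
        std = scores.std(axis=0)  # population std (ddof=0), an assumption
        std = np.where(std == 0, 1.0, std)
        return (scores - mean) / std
    raise ValueError(f"Unknown scale: {scale!r}")


# Rows are models; columns are r2, mae and training time (illustrative values).
raw = np.array([[0.96, 0.42, 0.1],
                [0.97, 0.36, 0.5]])

minmax = scale_scores(raw, "norm")
zscore = scale_scores(raw, "std")
```

With only two models, min-max scaling sends each column to exactly 0 and 1, so the chart shows which model leads on each axis but not by how much; standard scaling yields negative values, which is where lower_bound becomes relevant.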
Plotting: The (scaled) scores \(S'_{m,k}\) for each model \(k\) determine the radial distance along the axis corresponding to metric \(m\). Points are connected to form a polygon for each model.
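The polygon construction can be sketched as follows, assuming a full-circle layout with metrics spaced evenly around the axis. The helpers radar_vertices and draw_radar are hypothetical and only illustrate the geometry with standard Matplotlib polar-axes calls; the fill alpha of 0.1 and the legend position are taken from the parameter descriptions, not from the library's source.

```python
import numpy as np


def radar_vertices(scores):
    """Return (angles, radii) for one model, closed back to the first point."""
    scores = np.asarray(scores, dtype=float)
    # One evenly spaced angle per metric; endpoint=False avoids duplicating 0.
    angles = np.linspace(0.0, 2.0 * np.pi, num=scores.size, endpoint=False)
    # Repeat the first vertex so the plotted line forms a closed polygon.
    return np.append(angles, angles[0]), np.append(scores, scores[0])


def draw_radar(score_rows, metric_labels, model_names,
               lower_bound=0.0, alpha=0.7):
    """Draw one polygon per model on a polar Axes and return the Axes."""
    # Imported lazily so radar_vertices works without Matplotlib installed.
    import matplotlib.pyplot as plt

    fig, ax = plt.subplots(subplot_kw={"projection": "polar"})
    for row, name in zip(score_rows, model_names):
        angles, radii = radar_vertices(row)
        ax.plot(angles, radii, alpha=alpha, label=name)
        ax.fill(angles, radii, alpha=0.1)
    tick_angles, _ = radar_vertices(np.zeros(len(metric_labels)))
    ax.set_xticks(tick_angles[:-1])      # one tick per metric axis
    ax.set_xticklabels(metric_labels)
    ax.set_ylim(bottom=lower_bound)      # innermost circle
    ax.legend(loc="upper right")
    return ax


angles, radii = radar_vertices([1.0, 0.5, 0.0])
```

Calling draw_radar([[1.0, 0.5, 0.0], [0.0, 1.0, 1.0]], ["r2", "mae", "time"], ["A", "B"]) would draw two overlapping triangles, one per model.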
References
[1] Wikipedia contributors. (2024). Radar chart. In Wikipedia, The Free Encyclopedia. Retrieved April 14, 2025, from https://en.wikipedia.org/wiki/Radar_chart (General reference for radar charts)
[2] Kenny-Denecke, J. F., Hernandez-Amaro, A., Martin-Gorriz, M. L., & Castejon-Limos, P. (2024). Lead-Time Prediction in Wind Tower Manufacturing: A Machine Learning-Based Approach. Mathematics, 12(15), 2347. https://doi.org/10.3390/math12152347 (Example application using radar charts for ML comparison)
Examples
>>> from kdiagram.plot.comparison import plot_model_comparison
>>> import numpy as np
>>>
>>> # Example 1: Regression task
>>> y_true_reg = np.array([3, -0.5, 2, 7, 5])
>>> y_pred_r1 = np.array([2.5, 0.0, 2.1, 7.8, 5.2])
>>> y_pred_r2 = np.array([3.2, 0.2, 1.8, 6.5, 4.8])
>>> times = [0.1, 0.5]  # Training times in seconds
>>> names = ['ModelLin', 'ModelTree']
>>> ax1 = plot_model_comparison(y_true_reg, y_pred_r1, y_pred_r2,
...                             train_times=times, names=names,
...                             metrics=['r2', 'mae', 'rmse'],  # Specify metrics
...                             title="Regression Model Comparison",
...                             scale='norm')  # Normalize for comparison
>>>
>>> # Example 2: Classification task (requires appropriate y_true/y_pred)
>>> y_true_clf = np.array([0, 1, 0, 1, 1, 0])
>>> y_pred_c1 = np.array([0, 1, 0, 1, 0, 0])  # Model 1 preds
>>> y_pred_c2 = np.array([0, 1, 1, 1, 1, 0])  # Model 2 preds
>>> ax2 = plot_model_comparison(y_true_clf, y_pred_c1, y_pred_c2,
...                             names=["LogReg", "SVM"],
...                             # Uses default classification metrics
...                             title="Classification Model Comparison",
...                             scale='norm')
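The metrics parameter also accepts callables with the signature metric(y_true, y_pred), optionally mixed with string names. The sketch below defines a hypothetical custom metric, median_abs_error, and wraps the plotting call in a function so the metric can be tried on its own. How the library labels the axis for a callable is not stated here, so treat the resulting chart labels as unknown.

```python
import numpy as np


def median_abs_error(y_true, y_pred):
    """Hypothetical custom metric with the signature metric(y_true, y_pred)."""
    return float(np.median(np.abs(np.asarray(y_true) - np.asarray(y_pred))))


y_true_reg = np.array([3, -0.5, 2, 7, 5])
y_pred_r1 = np.array([2.5, 0.0, 2.1, 7.8, 5.2])
y_pred_r2 = np.array([3.2, 0.2, 1.8, 6.5, 4.8])


def compare_with_custom_metric():
    """Plot a comparison that mixes a named metric with the callable above."""
    # Imported here so the metric can be tested without kdiagram installed.
    from kdiagram.plot.comparison import plot_model_comparison

    return plot_model_comparison(
        y_true_reg, y_pred_r1, y_pred_r2,
        names=["ModelLin", "ModelTree"],
        metrics=["r2", median_abs_error],  # mix of string and callable
        title="Regression Comparison with a Custom Metric",
        scale="norm",
    )
```

Calling compare_with_custom_metric() returns the Axes, which can then be customized further as with the examples above.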