# License: Apache 2.0
# Author: LKouadio <etanoyau@gmail.com>
"""Model comparison plots."""
from __future__ import annotations
import warnings
from numbers import Real
from typing import (
Any,
Callable,
Literal,
)
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from matplotlib import cm
from matplotlib.colors import Normalize
from ..compat.matplotlib import get_cmap
from ..compat.sklearn import StrOptions, type_of_target, validate_params
from ..decorators import check_non_emptiness, isdf
from ..utils.generic_utils import drop_nan_in
from ..utils.handlers import columns_manager
from ..utils.metric_utils import get_scorer
from ..utils.plot import set_axis_grid
from ..utils.validator import _assert_all_types, is_iterable, validate_yy
__all__ = [
"plot_reliability_diagram",
"plot_model_comparison",
"plot_horizon_metrics",
]
# [docs]  (Sphinx "view source" link artifact; kept as a comment so it is not evaluated)
# Only the string-valued options and ``y_true`` are constrained here; the
# remaining parameters are used as passed.
@validate_params(
    {
        "y_true": ["array-like"],
        "strategy": [StrOptions({"uniform", "quantile"})],
        "error_bars": [StrOptions({"wilson", "normal", "none"})],
        "counts_panel": [StrOptions({"none", "bottom"})],
        "counts_norm": [StrOptions({"fraction", "count"})],
    }
)
def plot_reliability_diagram(
    y_true,
    *y_preds,
    names: list[str] | None = None,
    sample_weight: list[float] | np.ndarray | None = None,
    n_bins: int = 10,
    strategy: str = "uniform",
    positive_label: int | float | str = 1,
    class_index: int | None = None,
    clip_probs: tuple[float, float] = (0.0, 1.0),
    normalize_probs: bool = True,
    error_bars: str = "wilson",
    conf_level: float = 0.95,
    show_diagonal: bool = True,
    diagonal_kwargs: dict[str, Any] | None = None,
    show_ece: bool = True,
    show_brier: bool = True,
    counts_panel: str = "bottom",
    counts_norm: Literal["fraction", "count"] = "fraction",
    counts_alpha: float = 0.35,
    figsize: tuple[float, float] | None = (9, 7),
    title: str | None = None,
    xlabel: str | None = "Predicted probability",
    ylabel: str | None = "Observed frequency",
    cmap: str = "tab10",
    color_palette: list[Any] | None = None,
    marker: str = "o",
    s: int = 40,
    linewidth: float = 2.0,
    alpha: float = 0.9,
    connect: bool = True,
    legend: bool = True,
    legend_loc: str = "best",
    show_grid: bool = True,
    grid_props: dict | None = None,
    xlim: tuple[float, float] = (0.0, 1.0),
    ylim: tuple[float, float] = (0.0, 1.0),
    savefig: str | None = None,
    return_data: bool = False,
    **kw,  # accepted but never read in the body; reserved for future extension
):
    # NOTE: the user-facing docstring is attached to ``__doc__`` after the
    # helper definitions further down in this module.
    # -------------- input handling -------------- #
    if len(y_preds) == 0:
        raise ValueError(
            "Provide at least one prediction array via *y_preds."
        )
    # presumably normalizes `names` to a list of strings (project helper);
    # a falsy result is replaced by an empty list so it can be extended.
    names = columns_manager(names, to_string=True) or []
    # Pad missing names with "Model_<k>" placeholders (1-based position).
    if len(names) < len(y_preds):
        names.extend(
            [f"Model_{i+1}" for i in range(len(names), len(y_preds))]
        )
    # Surplus names are dropped, with a warning, rather than treated as an error.
    if len(names) > len(y_preds):
        warnings.warn(
            (
                f"Received {len(names)} names for {len(y_preds)} models. "
                "Extra names ignored."
            ),
            UserWarning,
            stacklevel=2,
        )
        names = names[: len(y_preds)]
    y_true = np.asarray(y_true)
    if type_of_target(y_true) not in ("binary", "multiclass"):
        raise ValueError(
            "y_true must be a classification target. "
            "Binary reliability is expected."
        )
    # One-vs-rest binarization: 1 where the label equals `positive_label`.
    y_bin = (y_true == positive_label).astype(int)
    # Reduce every prediction input to a 1D vector of probabilities
    # (2D inputs are reduced to the column chosen by `class_index`).
    prob_list: list[np.ndarray] = []
    for arr in y_preds:
        arr = np.asarray(arr)
        prob_list.append(_to_prob_vector(arr, class_index))
    # presumably `drop_nan_in` removes samples with NaN jointly across all
    # arrays so they stay aligned — TODO confirm against the helper.
    if sample_weight is None:
        y_bin, *prob_list = drop_nan_in(y_bin, *prob_list, error="raise")
        # Unit weights, created after NaN removal so the length matches.
        w = np.ones_like(y_bin, dtype=float)
    else:
        w = np.asarray(sample_weight, dtype=float)
        # Weights go through the same call so they stay aligned with the data.
        y_bin, *prob_list, w = drop_nan_in(
            y_bin, *prob_list, w, error="raise"
        )
    clip_lo, clip_hi = clip_probs
    # Tracks whether rescaling/clipping changed any vector, so a single
    # warning can be emitted for all models.
    clipped_flag = False
    new_probs = []
    for p in prob_list:
        p0 = p.copy()
        p1 = _prep_probs(p0, clip_lo, clip_hi, normalize_probs)
        if not np.allclose(p0, p1):
            clipped_flag = True
        new_probs.append(p1)
    prob_list = new_probs
    if clipped_flag:
        warnings.warn(
            (
                "Some predicted probabilities were normalized/clipped "
                f"to [{clip_lo}, {clip_hi}]."
            ),
            UserWarning,
            stacklevel=2,
        )
    # Bin edges are shared by all models (quantile edges pool every model's
    # predictions), so the per-model tables are directly comparable.
    edges, centers = _build_bins(
        prob_list, n_bins, strategy, clip_lo, clip_hi
    )
    # Critical value used for the per-bin confidence intervals.
    z = _z_from_conf(conf_level)
    # -------------- colors & layout -------------- #
    colors = _colors(cmap, color_palette, len(prob_list))
    if counts_panel == "bottom":
        # Two stacked axes: calibration curves on top (3/4 of the height),
        # per-bin totals below, sharing the x-axis.
        fig = plt.figure(figsize=figsize)
        gs = fig.add_gridspec(2, 1, height_ratios=(3.0, 1.0), hspace=0.12)
        ax = fig.add_subplot(gs[0, 0])
        axb = fig.add_subplot(gs[1, 0], sharex=ax)
    else:
        fig, ax = plt.subplots(figsize=figsize)
        axb = None
    # -------------- compute & plot -------------- #
    per_model: dict[str, pd.DataFrame] = {}
    for i, (name, p, col) in enumerate(zip(names, prob_list, colors)):
        stats = _bin_stats(p, y_bin, w, edges, error_bars, z)
        # ECE is the sum of the per-bin contributions (NaNs ignored).
        ece = float(np.nansum(stats["ece"]))
        br = _brier(p, y_bin, w)
        # Per-bin table; this is what `return_data=True` hands back.
        df = pd.DataFrame(
            {
                "bin_left": edges[:-1],
                "bin_right": edges[1:],
                "bin_center": centers,
                "n": stats["n"],
                "w_sum": stats["wsum"],
                "p_mean": stats["pmean"],
                "y_rate": stats["yrate"],
                "y_low": stats["ylo"],
                "y_high": stats["yhi"],
                "ece_contrib": stats["ece"],
            }
        )
        per_model[name] = df
        # Only bins with positive total weight are drawn; empty bins stay in
        # the returned table but are not plotted.
        valid = df["w_sum"].to_numpy() > 0
        x = df.loc[valid, "p_mean"].to_numpy()
        y = df.loc[valid, "y_rate"].to_numpy()
        ylo = df.loc[valid, "y_low"].to_numpy()
        yhi = df.loc[valid, "y_high"].to_numpy()
        if error_bars.lower() != "none":
            # Matplotlib expects asymmetric errors as (lower, upper) distances
            # from the point, not as absolute interval bounds.
            yerr = np.vstack([y - ylo, yhi - y])
            ax.errorbar(
                x,
                y,
                yerr=yerr,
                fmt="none",
                ecolor=col,
                elinewidth=1.0,
                capsize=2,
                alpha=alpha * 0.85,
            )
        ax.scatter(x, y, c=[col], s=s, marker=marker, alpha=alpha)
        if connect and len(x) > 1:
            ax.plot(x, y, color=col, linewidth=linewidth, alpha=alpha)
        # Build the legend text: model name plus the requested summary scores.
        label = name
        pieces = []
        if show_ece:
            pieces.append(f"ECE={ece:.3f}")
        if show_brier:
            pieces.append(f"Brier={br:.3f}")
        if pieces:
            label = f"{label} ({', '.join(pieces)})"
        # Empty proxy artist: the drawn artists above carry no label, so this
        # is the single legend entry per model (marker + optional line).
        ax.plot(
            [],
            [],
            color=col,
            marker=marker,
            linestyle="-" if connect else "None",
            linewidth=linewidth,
            label=label,
        )
        if axb is not None:
            # Grouped bars: 80% of each bin's width is divided into one slot
            # per model, and model `i` occupies the i-th slot from the left.
            bw = edges[1:] - edges[:-1]
            slot = bw * 0.8 / max(1, len(prob_list))
            left = edges[:-1] + i * slot
            vals = df["w_sum"].to_numpy()
            if counts_norm == "fraction":
                # Guard against a zero total so the division is always defined.
                denom = vals.sum() if vals.sum() > 0 else 1.0
                vals = vals / denom
            axb.bar(
                left,
                vals,
                width=slot,
                align="edge",
                color=col,
                alpha=counts_alpha,
                label=name,
            )
    # -------------- format axes -------------- #
    ax.set_xlim(*xlim)
    ax.set_ylim(*ylim)
    if xlabel:
        ax.set_xlabel(xlabel)
    if ylabel:
        ax.set_ylabel(ylabel)
    if title:
        ax.set_title(title)
    if show_diagonal:
        # Defaults for the y = x reference line; user kwargs override them.
        diag_kw = {
            "color": "gray",
            "linestyle": "--",
            "linewidth": 1.2,
            "alpha": 0.9,
        }
        if diagonal_kwargs:
            _assert_all_types(
                diagonal_kwargs, dict, objname="'diagonal_kwargs'"
            )
            diag_kw.update(diagonal_kwargs)
        ax.plot((0.0, 1.0), (0.0, 1.0), **diag_kw)
    set_axis_grid(ax, show_grid=show_grid, grid_props=grid_props)
    if legend:
        ax.legend(loc=legend_loc)
    if axb is not None:
        axb.set_xlim(*xlim)
        axb.axhline(0, color="gray", lw=0.8)
        axb.set_xlabel(xlabel or "Predicted probability")
        axb.set_ylabel("Frac." if counts_norm == "fraction" else "Count")
        # The counts panel always gets a faint grid, independent of `show_grid`.
        set_axis_grid(axb, show_grid=True, grid_props={"alpha": 0.25})
        handles, labels = axb.get_legend_handles_labels()
        if handles and labels:
            axb.legend(loc="upper right", fontsize=8)
    plt.tight_layout()
    if savefig:
        # Best effort: a failed save is reported on stdout, not raised.
        try:
            fig.savefig(savefig, bbox_inches="tight", dpi=300)
            print(f"Plot saved to {savefig}")
        except Exception as e:
            print(f"Error saving plot to {savefig}: {e}")
    else:
        plt.show()
    # Only the main axes is returned; the counts axes (if any) is not.
    if return_data:
        return ax, per_model
    return ax
# ------------------ helpers ------------------ #
def _z_from_conf(cf: float) -> float:
table = {
0.80: 1.2815515655,
0.90: 1.6448536269,
0.95: 1.9599639845,
0.975: 2.241402728,
0.99: 2.5758293035,
}
return table.get(round(cf, 3), 1.9599639845)
def _to_prob_vector(arr: np.ndarray, ci: int | None) -> np.ndarray:
if arr.ndim == 1:
return arr.astype(float, copy=False)
if arr.ndim == 2:
idx = arr.shape[1] - 1 if ci is None else ci
if idx < 0 or idx >= arr.shape[1]:
raise ValueError(
"class_index out of bounds for 2D predictions: "
f"{idx} not in [0, {arr.shape[1]-1}]"
)
return arr[:, idx].astype(float, copy=False)
raise ValueError(
"Predictions must be 1D probabilities or "
"(n_samples, n_classes) arrays."
)
def _prep_probs(
p: np.ndarray, clip_lo: float, clip_hi: float, do_norm: bool
) -> np.ndarray:
p = np.asarray(p, dtype=float)
if do_norm:
pmin, pmax = np.nanmin(p), np.nanmax(p)
if (pmin < -1e-9) or (pmax > 1.0 + 1e-9):
rng = pmax - pmin
if rng > 1e-12:
p = (p - pmin) / rng
p = np.clip(p, clip_lo, clip_hi)
return p
def _build_bins(
probs_list: list[np.ndarray], nb: int, strat: str, low: float, high: float
) -> tuple[np.ndarray, np.ndarray]:
if strat == "uniform":
edges = np.linspace(low, high, nb + 1)
else:
allp = np.concatenate(probs_list)
q = np.linspace(0.0, 1.0, nb + 1)
edges = np.quantile(allp, q)
edges = np.unique(edges)
if len(edges) - 1 < nb:
warnings.warn(
(
"Not enough unique quantile edges; "
"falling back to uniform bins."
),
UserWarning,
stacklevel=2,
)
edges = np.linspace(low, high, nb + 1)
centers = 0.5 * (edges[:-1] + edges[1:])
return edges, centers
def _bin_stats(
p: np.ndarray,
y: np.ndarray,
w: np.ndarray,
edges: np.ndarray,
ebars: str,
zval: float,
) -> dict[str, np.ndarray]:
nb = len(edges) - 1
eps = 1e-12
idx = np.digitize(p, edges, right=False) - 1
idx[idx < 0] = 0
idx[idx >= nb] = nb - 1
n = np.zeros(nb, dtype=float)
wsum = np.zeros(nb, dtype=float)
pmean = np.zeros(nb, dtype=float)
yr = np.zeros(nb, dtype=float)
for b in range(nb):
m = idx == b
if not np.any(m):
continue
ww = w[m]
pp = p[m]
yy = y[m]
wsum[b] = ww.sum()
n[b] = float(m.sum())
denom = max(wsum[b], eps)
pmean[b] = float(np.dot(ww, pp) / denom)
yr[b] = float(np.dot(ww, yy) / denom)
if ebars.lower() == "none":
ylo = np.full_like(yr, np.nan)
yhi = np.full_like(yr, np.nan)
elif ebars.lower() == "normal":
neff = np.maximum(wsum, eps)
se = np.sqrt(np.clip(yr * (1.0 - yr) / neff, 0.0, 1.0))
ylo = np.clip(yr - zval * se, 0.0, 1.0)
yhi = np.clip(yr + zval * se, 0.0, 1.0)
else:
neff = np.maximum(wsum, eps)
ylo = np.empty_like(yr)
yhi = np.empty_like(yr)
for i in range(nb):
ph = yr[i]
n_ = neff[i]
if n_ <= eps:
ylo[i] = np.nan
yhi[i] = np.nan
continue
denom = 1.0 + (zval**2) / n_
center = (ph + (zval**2) / (2.0 * n_)) / denom
rad = (
zval
* np.sqrt((ph * (1.0 - ph) + (zval**2) / (4.0 * n_)) / n_)
) / denom
ylo[i] = np.clip(center - rad, 0.0, 1.0)
yhi[i] = np.clip(center + rad, 0.0, 1.0)
totw = max(w.sum(), eps)
wbin = wsum / totw
ece_contrib = wbin * np.abs(yr - pmean)
return {
"n": n,
"wsum": wsum,
"pmean": pmean,
"yrate": yr,
"ylo": ylo,
"yhi": yhi,
"wbin": wbin,
"ece": ece_contrib,
}
def _brier(p: np.ndarray, y: np.ndarray, w: np.ndarray) -> float:
return float(np.average((p - y) ** 2, weights=w))
def _colors(cmap_name: str, palette: list[Any] | None, k: int) -> list[Any]:
if palette is not None:
return [palette[i % len(palette)] for i in range(k)]
try:
cmo = get_cmap(cmap_name, default="tab10", failsafe="discrete")
except ValueError:
warnings.warn(
f"Invalid cmap '{cmap_name}'. Using 'tab10' instead.",
UserWarning,
stacklevel=2,
)
cmo = get_cmap("tab10", default="tab10", failsafe="discrete")
if hasattr(cmo, "colors") and len(cmo.colors) >= k:
return list(cmo.colors[:k])
if k == 1:
return [cmo(0.5)]
return [cmo(i / (k - 1)) for i in range(k)]
plot_reliability_diagram.__doc__ = r"""
Plot a reliability diagram (calibration plot) for one or more
classification models.
This compares **predicted probabilities** to **observed
frequencies** across bins of predicted probability. Perfect
calibration lies on the diagonal :math:`y=x`.
Parameters
----------
y_true : array-like of shape (n_samples,)
Ground truth labels. For binary calibration, values are
compared to ``positive_label`` after validation and
flattening.
*y_preds : array-like(s)
One or more model predictions. Each item may be:
- 1D array of positive-class probabilities in ``[0, 1]``.
- 2D array of shape ``(n_samples, n_classes)``; use
``class_index`` to select a column. If omitted, the
last column is used.
names : list of str, optional
Labels for each model curve. If fewer names are provided
than models, placeholders like ``'Model_1'`` are appended.
sample_weight : array-like of shape (n_samples,), optional
Per-sample weights used for observed frequencies, ECE,
and Brier score. If ``None``, equal weights are used.
n_bins : int, default=10
Number of probability bins.
strategy : {'uniform', 'quantile'}, default='uniform'
Binning strategy.
- ``'uniform'``: equally spaced edges in ``[0, 1]``.
- ``'quantile'``: edges are empirical quantiles of the
pooled predictions. If edges are not unique, the method
falls back to uniform binning with a warning.
positive_label : int or float or str, default=1
Label in ``y_true`` treated as the positive class when
constructing the binary target.
class_index : int, optional
Column index to pick from 2D probability arrays. If
omitted, the last column is used.
clip_probs : tuple of (float, float), default=(0.0, 1.0)
Inclusive clipping range applied to predictions. A warning
is issued if clipping occurs.
normalize_probs : bool, default=True
If ``True`` and any prediction lies outside ``[0, 1]`` (beyond a
small tolerance), the whole vector is min-max rescaled into
``[0, 1]`` before clipping is applied.
error_bars : {'wilson', 'normal', 'none'}, default='wilson'
Per-bin uncertainty for observed frequencies.
- ``'wilson'``: Wilson interval using ``conf_level``.
- ``'normal'``: normal approximation.
- ``'none'``: no error bars.
conf_level : float, default=0.95
Confidence level used for error bars when applicable.
show_diagonal : bool, default=True
Draw the reference diagonal :math:`y=x`.
diagonal_kwargs : dict, optional
Matplotlib keyword arguments for the diagonal reference
line (e.g., ``linestyle``, ``color``).
show_ece : bool, default=True
Compute Expected Calibration Error (ECE) and append a
summary to each model label.
show_brier : bool, default=True
Compute (weighted) Brier score and append a summary to
each model label.
counts_panel : {'none', 'bottom'}, default='bottom'
If not ``'none'``, draw a compact histogram below the main
panel that shows per-bin totals for each model.
counts_norm : {'fraction', 'count'}, default='fraction'
Normalization for the counts panel. ``'fraction'`` divides
by the total weight; ``'count'`` shows raw weighted sums.
counts_alpha : float, default=0.35
Alpha for bars in the counts panel.
figsize : tuple of (float, float), default=(9, 7)
Figure size for the layout. When ``counts_panel='bottom'``,
a two-row gridspec is used.
title : str, optional
Title for the plot. If ``None``, no title is set.
xlabel : str, optional
Label for the x-axis. Defaults to
``'Predicted probability'``.
ylabel : str, optional
Label for the y-axis. Defaults to
``'Observed frequency'``.
cmap : str, default='tab10'
Matplotlib colormap name used to generate model colors.
color_palette : list, optional
Explicit list of colors. When provided, colors are cycled
from this list instead of the colormap.
marker : str, default='o'
Marker used for the bin points.
s : int, default=40
Marker size for the bin points.
linewidth : float, default=2.0
Line width used when connecting bin points.
alpha : float, default=0.9
Alpha for points and lines in the main panel.
connect : bool, default=True
Connect bin points with a line for each model.
legend : bool, default=True
Display a legend. Summary metrics (ECE, Brier) are shown
next to model names when enabled.
legend_loc : str, default='best'
Legend location passed to Matplotlib.
show_grid : bool, default=True
Toggle gridlines via the package helper ``set_axis_grid``.
grid_props : dict, optional
Keyword arguments passed to ``set_axis_grid`` for grid
customization (e.g., ``linestyle``, ``alpha``).
xlim : tuple of (float, float), default=(0.0, 1.0)
X-axis limits.
ylim : tuple of (float, float), default=(0.0, 1.0)
Y-axis limits.
savefig : str, optional
If provided, save the figure to this path; otherwise the
plot is shown interactively.
return_data : bool, default=False
If ``True``, return ``(ax, data_dict)`` where values are
per-model ``pandas.DataFrame`` objects with per-bin stats:
``['bin_left', 'bin_right', 'bin_center', 'n', 'w_sum',
'p_mean', 'y_rate', 'y_low', 'y_high', 'ece_contrib']``.
Otherwise, return only the Matplotlib axes.
Returns
-------
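ax : matplotlib.axes.Axes
Axes of the main calibration plot. When
``counts_panel='bottom'``, the second axes (counts panel)
is not returned.
data : dict of str to pandas.DataFrame
Only returned when ``return_data=True``, as the second element
of ``(ax, data)``. Maps each model name to its per-bin table.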
Notes
-----
Calibration compares *confidence* to *accuracy* within bins.
For bin :math:`b`, let :math:`\hat{p}_i` be predictions and
:math:`y_i\in\{0,1\}` be binary targets with weights
:math:`w_i\ge 0`. Define the weighted bin mean probability
and accuracy as
.. math::
\bar{p}_b \;=\;
\frac{\sum_{i\in b} w_i \hat{p}_i}
{\sum_{i\in b} w_i},
\qquad
\bar{y}_b \;=\;
\frac{\sum_{i\in b} w_i y_i}
{\sum_{i\in b} w_i}.
The Expected Calibration Error (ECE) is
.. math::
\mathrm{ECE}
\;=\;
\sum_b
\left(
\frac{\sum_{i\in b} w_i}{\sum_i w_i}
\right)
\left|
\bar{y}_b - \bar{p}_b
\right|.
The (weighted) Brier score is
.. math::
\mathrm{Brier}
\;=\;
\frac{\sum_i
w_i \left(\hat{p}_i - y_i\right)^2}
{\sum_i w_i}.
Wilson confidence intervals for :math:`\bar{y}_b` use
:math:`z = \Phi^{-1}\!\left(\tfrac{1+\alpha}{2}\right)`, where
:math:`\alpha` denotes ``conf_level``, and the effective count
:math:`n_b=\sum_{i\in b} w_i`:
.. math::
\mathrm{center}
\;=\;
\frac{\bar{y}_b + \frac{z^2}{2 n_b}}
{1 + \frac{z^2}{n_b}},
\qquad
\mathrm{radius}
\;=\;
\frac{z}{1 + \frac{z^2}{n_b}}
\sqrt{\frac{\bar{y}_b(1-\bar{y}_b)}{n_b}
+ \frac{z^2}{4 n_b^2}}.
The interval is
:math:`[\mathrm{center}-\mathrm{radius},
\mathrm{center}+\mathrm{radius}]`,
clipped to ``[0, 1]``. The normal interval is instead
:math:`\bar{y}_b \pm z\,\sqrt{\bar{y}_b(1-\bar{y}_b)/n_b}`,
also clipped to ``[0, 1]``.
When ``strategy='quantile'``, bin edges are the empirical
quantiles of the pooled predictions. If many identical values
exist, edges can collapse; in that case, the function falls
back to uniform edges with a warning.
Examples
--------
Binary example with quantile bins and Wilson intervals.
>>> import numpy as np
>>> from kdiagram.plot.comparison import \
... plot_reliability_diagram
>>> rng = np.random.default_rng(0)
>>> y = (rng.random(1000) < 0.4).astype(int)
>>> p1 = 0.4 * np.ones_like(y) + 0.15 * rng.random(len(y))
>>> p2 = 0.4 * np.ones_like(y) + 0.05 * rng.random(len(y))
>>> ax = plot_reliability_diagram(
... y, p1, p2,
... names=['Wide', 'Tight'],
... n_bins=12,
... strategy='quantile',
... error_bars='wilson',
... counts_panel='bottom',
... show_ece=True,
... show_brier=True,
... title=('Reliability Diagram '
... '(Quantile bins + Wilson CIs)'),
... )
"""
# [docs]  (Sphinx "view source" link artifact; kept as a comment so it is not evaluated)
@validate_params(
    {
        # `Real` is listed so that a single scalar training time, which the
        # body handles explicitly, is not rejected by parameter validation.
        "train_times": ["array-like", Real, None],
        "metrics": [str, "array-like", callable, None],
        "scale": [
            StrOptions(
                {
                    "norm",
                    "min-max",
                    "std",
                    "standard",
                }
            ),
            None,
        ],
        "lower_bound": [Real],
    }
)
def plot_model_comparison(
    y_true,
    *y_preds,
    train_times: float | list[float] | None = None,
    metrics: str | Callable | list[str | Callable] | None = None,
    names: list[str] | None = None,
    title: str | None = None,
    figsize: tuple[float, float] | None = None,
    colors: list[Any] | None = None,
    alpha: float = 0.7,
    legend: bool = True,
    show_grid: bool = True,
    grid_props: dict | None = None,
    scale: str | None = "norm",
    lower_bound: float = 0,
    savefig: str | None = None,
    loc: str = "upper right",
    verbose: int = 0,
):
    r"""Plot multi-metric model performance comparison on a radar chart.

    Generates a radar (spider) chart in which each axis is a
    performance metric (e.g., R2, MAE, accuracy) and each polygon is a
    model, so that strengths and weaknesses can be compared across
    several criteria at once [1]_. Training time can be added as an
    extra axis to expose accuracy/cost trade-offs [2]_.

    Parameters
    ----------
    y_true : array-like of shape (n_samples,)
        The ground truth (correct) target values.
    *y_preds : array-like of shape (n_samples,)
        One prediction array per model. Each must have the same length
        as `y_true`. If none is given, a warning is issued and ``None``
        is returned.
    train_times : float or list of float, optional
        Training time in seconds for each model in `*y_preds`.

        - A single number is used for all models.
        - A list/array must have one entry per model.

        When provided, it is plotted as an additional axis named
        ``'Train Time (s)'``. Default is ``None``.
    metrics : str, callable, list of these, optional
        Metrics to compute and plot. ``None`` (default) selects metrics
        from the target type inferred from `y_true`:

        - **Continuous targets:** ``["r2", "mae", "mape", "rmse"]``.
        - **Otherwise:** ``["accuracy", "precision", "recall", "f1"]``.

        Strings are resolved with the package's `get_scorer`; callables
        must have the signature ``metric(y_true, y_pred)``. Strings and
        callables may be mixed. Entries that cannot be resolved are
        skipped with a warning.
    names : list of str, optional
        Model names for the legend. Missing names are filled with
        ``"Model_1"``, ``"Model_2"``, ...; extra names are ignored with
        a warning. Default is ``None``.
    title : str, optional
        Chart title. If ``None``, ``"Model Performance Comparison"`` is
        used.
    figsize : tuple of (float, float), optional
        Figure size ``(width, height)`` in inches. ``None`` uses
        ``(8, 8)``.
    colors : str or list, optional
        Matplotlib colors for the model polygons, cycled if there are
        fewer colors than models. A single color string is applied to
        every model. ``None`` or an empty list uses the ``'tab10'``
        palette.
    alpha : float, optional
        Transparency of the polygon outlines. The filled area always
        uses an alpha of ``0.1``. Default is ``0.7``.
    legend : bool, optional
        If ``True``, display a legend. Default is ``True``.
    show_grid : bool, optional
        If ``True``, display grid lines. Default is ``True``.
    grid_props : dict, optional
        Grid styling options forwarded to the package helper
        ``set_axis_grid``. Default is ``None``.
    scale : {'norm', 'min-max', 'std', 'standard'} or None, optional
        Per-metric scaling applied across models before plotting.
        Default is ``'norm'``.

        - ``'norm'`` / ``'min-max'``: rescale to ``[0, 1]`` with
          :math:`(X - min) / (max - min)`.
        - ``'std'`` / ``'standard'``: z-score
          :math:`(X - mean) / std`; values can be negative.
        - ``None``: plot raw metric values.

        With either scaling, the error metrics ``'mae'``, ``'mape'``,
        ``'rmse'``, ``'mse'`` (given as strings) and the training time
        are flipped afterwards so that a larger radius always means
        "better".
    lower_bound : float, optional
        Minimum of the radial axis. Also used to replace any NaN left
        in the (scaled) results. Default is ``0``.
    savefig : str, optional
        File path where the figure is saved. If ``None``, the plot is
        shown interactively. Default is ``None``.
    loc : str, optional
        Legend location passed to Matplotlib. Default is
        ``'upper right'``.
    verbose : int, optional
        ``0`` is silent; ``>= 1`` prints the automatically selected
        metrics and the scaling method. Default is ``0``.

    Returns
    -------
    ax : matplotlib.axes.Axes or None
        The polar axes containing the radar chart, or ``None`` when no
        prediction array was supplied.

    Raises
    ------
    TypeError
        If NaN handling or validation of `y_true` / `y_preds` fails;
        the underlying error is chained as the cause.
    ValueError
        If `train_times` is not a scalar or a 1D sequence with one
        entry per model, or if no usable metric remains.

    See Also
    --------
    kdiagram.utils.metric_utils.get_scorer : Resolves metric names.
    sklearn.metrics : Scikit-learn metrics module.

    Notes
    -----
    For each model :math:`k` with predictions :math:`\hat{y}_k` and
    each metric :math:`m`, the score is

    .. math::
        S_{m,k} = \text{Metric}_m(y_{true}, \hat{y}_k)

    Scaling, when requested, is applied column-wise (per metric):

    - Min-Max (``'norm'``):

      .. math::
          S'_{m,k} = \frac{S_{m,k} - \min_j(S_{m,j})}
                          {\max_j(S_{m,j}) - \min_j(S_{m,j})}

    - Standard (``'std'``):

      .. math::
          S'_{m,k} = \frac{S_{m,k} - \text{mean}_j(S_{m,j})}
                          {\text{std}_j(S_{m,j})}

    A metric with (near) zero spread across models is divided by 1
    instead, to avoid division by zero. A metric that fails for a model
    is recorded as NaN with a warning.

    References
    ----------
    .. [1] Wikipedia contributors. (2024). Radar chart. In Wikipedia,
           The Free Encyclopedia. Retrieved April 14, 2025, from
           https://en.wikipedia.org/wiki/Radar_chart
    .. [2] Kenny-Denecke, J. F., Hernandez-Amaro, A.,
           Martin-Gorriz, M. L., & Castejon-Limos, P. (2024).
           Lead-Time Prediction in Wind Tower Manufacturing: A Machine
           Learning-Based Approach. *Mathematics*, 12(15), 2347.
           https://doi.org/10.3390/math12152347

    Examples
    --------
    >>> from kdiagram.plot.comparison import plot_model_comparison
    >>> import numpy as np
    >>>
    >>> # Example 1: Regression task
    >>> y_true_reg = np.array([3, -0.5, 2, 7, 5])
    >>> y_pred_r1 = np.array([2.5, 0.0, 2.1, 7.8, 5.2])
    >>> y_pred_r2 = np.array([3.2, 0.2, 1.8, 6.5, 4.8])
    >>> times = [0.1, 0.5]  # Training times in seconds
    >>> names = ['ModelLin', 'ModelTree']
    >>> ax1 = plot_model_comparison(y_true_reg, y_pred_r1, y_pred_r2,
    ...                             train_times=times, names=names,
    ...                             metrics=['r2', 'mae', 'rmse'],
    ...                             title="Regression Model Comparison",
    ...                             scale='norm')
    >>>
    >>> # Example 2: Classification task
    >>> y_true_clf = np.array([0, 1, 0, 1, 1, 0])
    >>> y_pred_c1 = np.array([0, 1, 0, 1, 0, 0])
    >>> y_pred_c2 = np.array([0, 1, 1, 1, 1, 0])
    >>> ax2 = plot_model_comparison(y_true_clf, y_pred_c1, y_pred_c2,
    ...                             names=["LogReg", "SVM"],
    ...                             title="Classification Model Comparison",
    ...                             scale='norm')
    """
    # Nothing to compare: bail out before validation touches the inputs.
    if not y_preds:
        warnings.warn(
            "No prediction arrays (*y_preds) provided.", stacklevel=2
        )
        return None

    try:
        # Remove NaN values consistently across the target and predictions.
        y_true, *y_preds = drop_nan_in(y_true, *y_preds, error="raise")
        # validate_yy returns a pair; the second element is the prediction.
        y_preds = [
            validate_yy(y_true, pred, expected_type=None, flatten=True)[1]
            for pred in y_preds
        ]
    except Exception as e:
        # Any failure in NaN handling/validation is reported as TypeError.
        raise TypeError(f"Input validation failed: {e}") from e
    n_models = len(y_preds)

    # --- Handle Names ---
    if names is None:
        names = [f"Model_{i+1}" for i in range(n_models)]
    else:
        names = columns_manager(list(names), empty_as_none=False)
        if len(names) < n_models:
            names += [f"Model_{i+1}" for i in range(len(names), n_models)]
        elif len(names) > n_models:
            warnings.warn(
                f"Received {len(names)} names for {n_models}"
                f" models. Extra names ignored.",
                UserWarning,
                stacklevel=2,
            )
            names = names[:n_models]

    # --- Handle Metrics ---
    if metrics is None:
        target_type = type_of_target(y_true)
        if target_type in ["continuous", "continuous-multioutput"]:
            # Default regression metrics
            metrics = ["r2", "mae", "mape", "rmse"]
        else:
            # Default classification metrics
            metrics = ["accuracy", "precision", "recall", "f1"]
        if verbose >= 1:
            print(
                f"[INFO] Auto-selected metrics for target type "
                f"'{target_type}': {metrics}"
            )
    metrics = is_iterable(metrics, exclude_string=True, transform=True)

    # String metrics for which a smaller raw value is better.
    error_metric_names = {"mae", "mape", "rmse", "mse"}
    metric_funcs = []
    metric_names = []
    # Column indices to flip after scaling. Indices are used instead of
    # display names so a callable that happens to share a name with an
    # error metric is not flipped by accident.
    invert_cols = []
    for metric in metrics:
        try:
            if isinstance(metric, str):
                # get_scorer returns a callable scorer object
                metric_funcs.append(get_scorer(metric))
                metric_names.append(metric)
                if metric in error_metric_names:
                    invert_cols.append(len(metric_names) - 1)
            elif callable(metric):
                metric_funcs.append(metric)
                # Direction of a user callable is unknown: never flipped.
                metric_names.append(
                    getattr(metric, "__name__", f"func_{len(metric_names)}")
                )
            else:
                warnings.warn(
                    f"Ignoring invalid metric type: {type(metric)}",
                    stacklevel=2,
                )
        except Exception as e:
            warnings.warn(
                f"Could not retrieve scorer for metric '{metric}': {e}",
                stacklevel=2,
            )
    if not metric_funcs:
        raise ValueError("No valid metrics found or specified.")

    # --- Handle Train Times ---
    # Unique marker standing in for "read the training time" in the loop.
    train_time_marker = object()
    train_time_vals = None
    if train_times is not None:
        if isinstance(train_times, (int, float, np.number)):
            # A single value applies to every model.
            train_time_vals = np.array([float(train_times)] * n_models)
        else:
            train_times = np.asarray(train_times, dtype=float)
            if train_times.ndim != 1 or len(train_times) != n_models:
                raise ValueError(
                    f"train_times must be a single float or a list/array "
                    f"of length n_models ({n_models}). "
                    f"Got shape {train_times.shape}."
                )
            train_time_vals = train_times
        metric_names.append("Train Time (s)")
        metric_funcs.append(train_time_marker)
        invert_cols.append(len(metric_names) - 1)  # shorter time is better

    # --- Calculate Metric Results ---
    results = np.zeros((n_models, len(metric_names)), dtype=float)
    for i, y_pred in enumerate(y_preds):
        for j, metric_func in enumerate(metric_funcs):
            if metric_func is train_time_marker:
                results[i, j] = train_time_vals[i]
                continue
            try:
                results[i, j] = metric_func(y_true, y_pred)
            except Exception as e:
                warnings.warn(
                    f"Could not compute metric "
                    f"'{metric_names[j]}' for model "
                    f"'{names[i]}': {e}. Setting to NaN.",
                    stacklevel=2,
                )
                results[i, j] = np.nan

    # --- Scale Results ---
    # Work on a copy so the raw results stay untouched.
    results_scaled = results.copy()
    if np.isnan(results_scaled).any():
        warnings.warn(
            "NaN values found in metric results. Scaling might "
            "be affected or rows/cols dropped depending on method.",
            stacklevel=2,
        )
    # NaN-aware reductions are used so one failed metric does not poison
    # the whole column.
    if scale in ["norm", "min-max"]:
        if verbose >= 1:
            print("[INFO] Scaling metrics using Min-Max.")
        min_vals = np.nanmin(results_scaled, axis=0)
        max_vals = np.nanmax(results_scaled, axis=0)
        range_vals = max_vals - min_vals
        # Avoid division by zero for metrics with no variance
        range_vals[range_vals < 1e-9] = 1.0
        results_scaled = (results_scaled - min_vals) / range_vals
        # Flip "lower is better" axes so that 1 is always the best value.
        if invert_cols:
            results_scaled[:, invert_cols] = (
                1.0 - results_scaled[:, invert_cols]
            )
    elif scale in ["std", "standard"]:
        if verbose >= 1:
            print("[INFO] Scaling metrics using Standard Scaler.")
        mean_vals = np.nanmean(results_scaled, axis=0)
        std_vals = np.nanstd(results_scaled, axis=0)
        # Avoid division by zero
        std_vals[std_vals < 1e-9] = 1.0
        results_scaled = (results_scaled - mean_vals) / std_vals
        # Negate "lower is better" axes so larger always means better;
        # the range is not [0, 1] here, hence `lower_bound`.
        if invert_cols:
            results_scaled[:, invert_cols] = -results_scaled[:, invert_cols]
    # Replace NaNs left after scaling (e.g., an all-NaN metric column).
    results_scaled = np.nan_to_num(results_scaled, nan=lower_bound)

    # --- Plotting ---
    fig = plt.figure(figsize=figsize or (8, 8))
    ax = fig.add_subplot(111, polar=True)
    # One evenly spaced angle per metric axis.
    num_metrics = len(metric_names)
    angles = np.linspace(0, 2 * np.pi, num_metrics, endpoint=False).tolist()
    angles_closed = angles + angles[:1]  # repeat first angle to close polygon

    # Colors
    if isinstance(colors, str):
        # A bare string would otherwise be indexed character by character.
        colors = [colors]
    if colors is None or len(colors) == 0:
        # An empty list falls back to the default palette instead of failing
        # in the modulo below.
        try:
            cmap_obj = get_cmap("tab10", default="tab10", failsafe="discrete")
            plot_colors = [cmap_obj(i % 10) for i in range(n_models)]
        except ValueError:  # Fallback if tab10 not found (unlikely)
            cmap_obj = get_cmap("viridis")
            plot_colors = [cmap_obj(i / n_models) for i in range(n_models)]
    else:
        plot_colors = colors

    # Plot each model
    for i, row in enumerate(results_scaled):
        values = np.concatenate((row, [row[0]]))  # close the polygon
        color = plot_colors[i % len(plot_colors)]  # cycle colors if needed
        ax.plot(
            angles_closed,
            values,
            label=names[i],
            color=color,
            linewidth=1.5,
            alpha=alpha,
        )
        ax.fill(angles_closed, values, color=color, alpha=0.1)

    # --- Configure Axes ---
    ax.set_xticks(angles)
    ax.set_xticklabels(metric_names)
    if scale in ["norm", "min-max"]:
        # Values live in [0, 1]: leave a small margin above 1.
        ax.set_ylim(bottom=lower_bound, top=1.05)
        ax.set_yticks(np.linspace(lower_bound, 1, 5))
    else:
        # Raw or z-scored values: let Matplotlib pick the upper limit.
        ax.set_ylim(bottom=lower_bound)
    ax.tick_params(axis="y", labelsize=8)  # smaller radial labels
    ax.tick_params(axis="x", pad=10)  # push metric labels outwards

    set_axis_grid(ax, show_grid=show_grid, grid_props=grid_props)
    if legend:
        # Anchored outside the axes so it does not cover the polygons.
        ax.legend(loc=loc, bbox_to_anchor=(1.25, 1.05))
    ax.set_title(title or "Model Performance Comparison", y=1.15, fontsize=14)

    # --- Output ---
    plt.tight_layout(pad=2.0)
    if savefig:
        # Best effort: a failed save is reported on stdout, not raised.
        try:
            fig.savefig(savefig, bbox_inches="tight", dpi=300)
            print(f"Plot saved to {savefig}")
        except Exception as e:
            print(f"Error saving plot to {savefig}: {e}")
    else:
        try:
            plt.show()
        except Exception as e:
            warnings.warn(
                f"Could not display plot interactively ({e})."
                f" Use savefig parameter.",
                UserWarning,
                stacklevel=2,
            )
    return ax
# [docs]  (Sphinx "view source" link artifact; kept as a comment so it is not evaluated)
@check_non_emptiness
@isdf
def plot_horizon_metrics(
    df: pd.DataFrame,
    qlow_cols: list[str],
    qup_cols: list[str],
    *,
    q50_cols: list[str] | None = None,
    xtick_labels: list[str] | None = None,
    normalize_radius: bool = False,
    show_value_labels: bool = True,
    cbar_label: str | None = None,
    r_label: str | None = None,
    cmap: str = "coolwarm",
    acov: str = "default",
    title: str | None = None,
    figsize: tuple[float, float] = (8, 8),
    alpha: float = 0.85,
    show_grid: bool = True,
    grid_props: dict | None = None,
    mask_angle: bool = False,
    savefig: str | None = None,
    dpi: int = 300,
    cbar: bool = True,
):
    # NOTE: the user-facing docstring is attached right after this
    # definition through ``plot_horizon_metrics.__doc__``.
    #
    # Summary: draw one polar bar per row of ``df``. Bar height is the
    # row-wise mean of (upper - lower) quantile columns; bar colour is
    # the row-wise mean of ``q50_cols`` (or the bar height when no
    # ``q50_cols`` are given). Returns the polar Axes.
    #
    # Raises ValueError when the quantile column lists differ in length
    # or when ``xtick_labels`` does not provide one label per row.

    # --- Input Validation (all done before any figure is created) ---
    if len(qlow_cols) != len(qup_cols):
        raise ValueError(
            "Mismatch in length between `qlow_cols` "
            f"({len(qlow_cols)}) and `qup_cols` ({len(qup_cols)})."
        )

    # Explicit None/len checks instead of truthiness so that array-like
    # inputs (np.ndarray, pd.Index) do not raise an ambiguity error.
    has_q50 = q50_cols is not None and len(q50_cols) > 0
    if has_q50 and len(qlow_cols) != len(q50_cols):
        raise ValueError(
            "Mismatch in length: `q50_cols` must match other "
            "quantile column lists."
        )

    num_bars = len(df)
    has_labels = xtick_labels is not None and len(xtick_labels) > 0
    if has_labels and len(xtick_labels) != num_bars:
        # Fail early with a clear message rather than letting Matplotlib
        # complain (or silently mislabel) after the figure exists.
        raise ValueError(
            f"`xtick_labels` has {len(xtick_labels)} entries but `df` "
            f"has {num_bars} rows; one label per row is required."
        )

    angular_map = {
        "default": 2 * np.pi,
        "half_circle": np.pi,
        "quarter_circle": np.pi / 2,
        "eighth_circle": np.pi / 4,
    }
    acov_key = acov.lower()
    if acov_key not in angular_map:
        # Keep the historical fallback to a full circle, but make the
        # fallback visible so typos do not go unnoticed.
        warnings.warn(
            f"Unknown `acov` value {acov!r}; falling back to 'default' "
            f"(full circle). Valid options: {sorted(angular_map)}.",
            UserWarning,
            stacklevel=2,
        )
    span = angular_map.get(acov_key, 2 * np.pi)

    # --- Data Calculation ---
    qlow_data = df[qlow_cols].values
    qup_data = df[qup_cols].values
    interval_widths = qup_data - qlow_data
    # Radial values are the mean width for each category/row
    radial_values = np.mean(interval_widths, axis=1)

    if has_q50:
        color_vals = np.mean(df[q50_cols].values, axis=1)
    else:
        # Default color to the (un-normalized) radial value
        color_vals = radial_values

    if normalize_radius:
        min_r, max_r = radial_values.min(), radial_values.max()
        # Skip scaling for a (near-)constant series to avoid dividing
        # by ~0; the raw values are plotted in that case.
        if (max_r - min_r) > 1e-9:
            radial_values = (radial_values - min_r) / (max_r - min_r)

    # --- Plot Setup ---
    theta = np.linspace(0, span, num_bars, endpoint=False)
    fig, ax = plt.subplots(
        figsize=figsize, subplot_kw={"projection": "polar"}
    )
    # Start at 12 o'clock and proceed clockwise.
    ax.set_theta_offset(np.pi / 2)
    ax.set_theta_direction(-1)
    ax.set_thetamin(0)
    ax.set_thetamax(np.degrees(span))

    # --- Plot Bars ---
    norm = Normalize(vmin=color_vals.min(), vmax=color_vals.max())
    cmap_obj = get_cmap(cmap, default="coolwarm")
    colors = cmap_obj(norm(color_vals))
    # Leave a 10% gap between neighbouring bars.
    bar_width = (span / num_bars) * 0.9
    ax.bar(
        theta,
        radial_values,
        width=bar_width,
        color=colors,
        edgecolor="k",
        alpha=alpha,
        linewidth=0.5,
    )

    # --- Annotations and Labels ---
    if show_value_labels:
        # Loop-invariant offset: 3% of the tallest bar.
        label_offset = 0.03 * radial_values.max()
        for angle, radius in zip(theta, radial_values):
            ax.text(
                angle,
                radius + label_offset,
                f"{radius:.2f}",
                ha="center",
                va="bottom",
                fontsize=8,
            )

    if has_labels:
        ax.set_xticks(theta)
        ax.set_xticklabels(xtick_labels)
    elif mask_angle:
        ax.set_xticklabels([])

    ax.set_yticklabels([])  # Hide radial value ticks
    ax.set_title(title or "Polar Bar Comparison", fontsize=14)
    if r_label:
        ax.set_ylabel(r_label, fontsize=12, labelpad=20)

    set_axis_grid(ax, show_grid, grid_props=grid_props)

    if cbar:
        sm = cm.ScalarMappable(cmap=cmap_obj, norm=norm)
        # Bind to ``fig`` explicitly rather than the implicit current
        # figure.
        cbar_obj = fig.colorbar(sm, ax=ax, pad=0.1, shrink=0.7)
        cbar_obj.set_label(cbar_label or "Color Metric", fontsize=10)

    # --- Output ---
    fig.tight_layout()
    if savefig:
        fig.savefig(savefig, dpi=dpi, bbox_inches="tight")
        plt.close(fig)
    else:
        plt.show()
    return ax
plot_horizon_metrics.__doc__ = r"""
Plot a polar bar chart comparing metrics across different horizons.
This function visualizes a primary metric (typically **mean
interval width**) as the height of bars arranged in a circle.
Each bar represents a distinct category or forecast horizon. A
secondary metric (typically the **mean Q50 value**) can be encoded
as the color of the bars, providing a multi-faceted comparison.
Parameters
----------
df : pd.DataFrame
Input DataFrame where each **row** represents a distinct
horizon or category to be compared.
qlow_cols : list of str
List of column names containing lower quantile samples
(e.g., Q10) for each horizon.
qup_cols : list of str
List of column names containing upper quantile samples
(e.g., Q90). Must have the same length as ``qlow_cols``.
q50_cols : list of str, optional
List of column names for the median quantile (Q50). If
provided, the mean of these values determines the bar color.
If ``None``, bar color is determined by the bar height
(the mean interval width).
xtick_labels : list of str, optional
Custom labels for each bar on the angular axis. The length
must match the number of rows in ``df``. If ``None``,
Matplotlib's default angular tick labels are kept unless
``mask_angle=True``.
normalize_radius : bool, default=False
If ``True``, the radial values (bar heights) are min-max
scaled to the range ``[0, 1]``. Scaling is skipped when all
radial values are (nearly) identical.
show_value_labels : bool, default=True
If ``True``, display the numeric value of the radial metric
on top of each bar.
cbar_label : str, optional
Custom label for the color bar. If ``None``, the label
``'Color Metric'`` is used.
r_label : str, optional
Custom label for the radial axis.
cmap : str, default='coolwarm'
Matplotlib colormap name for coloring the bars.
acov : {'default', 'half_circle', 'quarter_circle', \
'eighth_circle'}, default='default'
Specifies the angular coverage of the plot: ``'default'``
(360°), ``'half_circle'`` (180°), etc.
title : str, optional
Title for the plot. If ``None``, a default title is used.
figsize : tuple of (float, float), default=(8, 8)
Figure size in inches.
alpha : float, default=0.85
Transparency level for the bars.
show_grid : bool, default=True
Toggle gridlines via the package helper ``set_axis_grid``.
grid_props : dict, optional
Keyword arguments passed to ``set_axis_grid`` for grid
customization.
mask_angle : bool, default=False
If ``True`` and ``xtick_labels`` is not provided, this will
hide any default angular tick labels.
savefig : str, optional
If provided, save the figure to this path; otherwise the
plot is shown interactively.
dpi : int, default=300
Resolution for the saved figure.
cbar : bool, default=True
If ``True``, display a color bar.
Returns
-------
ax : matplotlib.axes.Axes
The Matplotlib Axes object containing the polar bar plot.
Notes
-----
The plot summarizes metrics for :math:`N` horizons (rows)
using data from :math:`M` samples (columns). Let
:math:`\mathbf{L}`, :math:`\mathbf{U}`, and :math:`\mathbf{Q50}`
be data matrices of shape :math:`(N, M)` extracted from the
corresponding columns.
1. **Interval Width Calculation**: For each horizon :math:`j`
and sample :math:`i`, the interval width is:
.. math::
W_{j,i} = U_{j,i} - L_{j,i}
2. **Radial Value (Bar Height)**: The radial value :math:`r_j`
for horizon :math:`j` is the mean interval width across
all :math:`M` samples.
.. math::
r_j = \frac{1}{M} \sum_{i=0}^{M-1} W_{j,i}
3. **Color Value**: The color value :math:`c_j` for horizon
:math:`j` is determined by the mean of the ``q50_cols`` values.
.. math::
c_j = \frac{1}{M} \sum_{i=0}^{M-1} Q50_{j,i}
If ``q50_cols`` is not provided, the color defaults to the
radial value, :math:`c_j = r_j`.
4. **Angular Position**: Horizons are spaced evenly around the
circle. For horizon :math:`j`, the angle is:
.. math::
\theta_j = \frac{j}{N} \times S
where :math:`S` is the angular span from ``acov``. The plot
starts at the top (12 o'clock) and proceeds clockwise.
Examples
--------
>>> import numpy as np
>>> import pandas as pd
>>> from kdiagram.plot import plot_horizon_metrics
>>>
>>> # Create synthetic data for 6 horizons with 2 samples each
>>> horizons = ["H+1", "H+2", "H+3", "H+4", "H+5", "H+6"]
>>> df = pd.DataFrame({
... 'q10_s1': [1, 2, 3, 4, 5, 6],
... 'q10_s2': [1.2, 2.3, 3.4, 4.5, 5.6, 6.7],
... 'q90_s1': [3, 4, 5.5, 7, 8, 9.5],
... 'q90_s2': [3.1, 4.2, 5.7, 7.3, 8.4, 9.9],
... 'q50_s1': [2, 3, 4.2, 5.7, 6.5, 8.2],
... 'q50_s2': [2.1, 3.2, 4.4, 5.9, 6.9, 8.8],
... })
>>>
>>> q10_cols = ['q10_s1', 'q10_s2']
>>> q90_cols = ['q90_s1', 'q90_s2']
>>> q50_cols = ['q50_s1', 'q50_s2']
>>>
>>> ax = plot_horizon_metrics(
... df=df,
... qlow_cols=q10_cols,
... qup_cols=q90_cols,
... q50_cols=q50_cols,
... title="Mean Interval Width Across Horizons",
... xtick_labels=horizons,
... show_value_labels=True,
... r_label="Mean Interval Width (Q90-Q10)",
... cbar_label="Mean Q50 Value",
... acov="default"
... )
"""