# License: Apache 2.0
# Author: LKouadio <etanoyau@gmail.com>
"""Model comparison plots."""
from __future__ import annotations
import warnings
from numbers import Real
from typing import Any, Callable, Literal
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from matplotlib import cm
from matplotlib.collections import LineCollection
from matplotlib.colors import Normalize
from ..compat.matplotlib import get_cmap
from ..compat.sklearn import StrOptions, type_of_target, validate_params
from ..decorators import check_non_emptiness, isdf
from ..utils.generic_utils import drop_nan_in
from ..utils.handlers import columns_manager
from ..utils.metric_utils import get_scorer
from ..utils.plot import set_axis_grid
from ..utils.validator import _assert_all_types, is_iterable, validate_yy
# Names exported by ``from <this module> import *``.
__all__ = [
"plot_reliability_diagram",
"plot_model_comparison",
"plot_horizon_metrics",
"plot_polar_reliability",
]
[docs]
@validate_params(
    {
        "y_true": ["array-like"],
        "strategy": [StrOptions({"uniform", "quantile"})],
        "error_bars": [StrOptions({"wilson", "normal", "none"})],
        "counts_panel": [StrOptions({"none", "bottom"})],
        "counts_norm": [StrOptions({"fraction", "count"})],
    }
)
def plot_reliability_diagram(
    y_true,
    *y_preds,
    names: list[str] | None = None,
    sample_weight: list[float] | np.ndarray | None = None,
    n_bins: int = 10,
    strategy: str = "uniform",
    positive_label: int | float | str = 1,
    class_index: int | None = None,
    clip_probs: tuple[float, float] = (0.0, 1.0),
    normalize_probs: bool = True,
    error_bars: str = "wilson",
    conf_level: float = 0.95,
    show_diagonal: bool = True,
    diagonal_kwargs: dict[str, Any] | None = None,
    show_ece: bool = True,
    show_brier: bool = True,
    counts_panel: str = "bottom",
    counts_norm: Literal["fraction", "count"] = "fraction",
    counts_alpha: float = 0.35,
    figsize: tuple[float, float] | None = (9, 7),
    title: str | None = None,
    xlabel: str | None = "Predicted probability",
    ylabel: str | None = "Observed frequency",
    cmap: str = "tab10",
    color_palette: list[Any] | None = None,
    marker: str = "o",
    s: int = 40,
    linewidth: float = 2.0,
    alpha: float = 0.9,
    connect: bool = True,
    legend: bool = True,
    legend_loc: str = "best",
    show_grid: bool = True,
    grid_props: dict | None = None,
    xlim: tuple[float, float] = (0.0, 1.0),
    ylim: tuple[float, float] = (0.0, 1.0),
    savefig: str | None = None,
    return_data: bool = False,
    **kw,  # for future extension
):
    # NOTE: the user-facing docstring is assigned to ``__doc__`` further
    # down in this module, after the private helpers.

    # -------------- input handling -------------- #
    if not y_preds:
        raise ValueError(
            "Provide at least one prediction array via *y_preds."
        )

    # Reject a non-positive / non-integral bin count up front; otherwise
    # the binning helpers quietly produce zero bins and an empty figure.
    try:
        bins_ok = int(n_bins) == n_bins and n_bins >= 1
    except (TypeError, ValueError, OverflowError):
        bins_ok = False
    if not bins_ok:
        raise ValueError(
            f"n_bins must be a positive integer; got {n_bins!r}."
        )
    n_bins = int(n_bins)

    # A reversed or empty clip range would collapse every probability
    # onto a single value, so fail loudly instead.
    clip_lo, clip_hi = clip_probs
    if not clip_lo < clip_hi:
        raise ValueError(
            "clip_probs must be an increasing (low, high) pair; "
            f"got {clip_probs!r}."
        )

    # Pad missing names with placeholders and drop surplus ones.
    names = columns_manager(names, to_string=True) or []
    if len(names) < len(y_preds):
        names.extend(
            [f"Model_{i+1}" for i in range(len(names), len(y_preds))]
        )
    if len(names) > len(y_preds):
        warnings.warn(
            (
                f"Received {len(names)} names for {len(y_preds)} models. "
                "Extra names ignored."
            ),
            UserWarning,
            stacklevel=2,
        )
        names = names[: len(y_preds)]

    y_true = np.asarray(y_true)
    if type_of_target(y_true) not in ("binary", "multiclass"):
        raise ValueError(
            "y_true must be a classification target. "
            "Binary reliability is expected."
        )
    # One-vs-rest binarisation against the requested positive label.
    y_bin = (y_true == positive_label).astype(int)

    # Reduce every prediction input to a 1D probability vector.
    prob_list: list[np.ndarray] = []
    for arr in y_preds:
        arr = np.asarray(arr)
        prob_list.append(_to_prob_vector(arr, class_index))

    # ``drop_nan_in`` is a project helper; presumably it removes (or
    # rejects) rows with NaNs jointly across all arrays -- TODO confirm.
    if sample_weight is None:
        y_bin, *prob_list = drop_nan_in(y_bin, *prob_list, error="raise")
        w = np.ones_like(y_bin, dtype=float)
    else:
        w = np.asarray(sample_weight, dtype=float)
        y_bin, *prob_list, w = drop_nan_in(
            y_bin, *prob_list, w, error="raise"
        )

    # Normalise/clip each vector and remember whether anything changed
    # so that a single warning can be emitted afterwards.
    clipped_flag = False
    new_probs = []
    for p in prob_list:
        p0 = p.copy()
        p1 = _prep_probs(p0, clip_lo, clip_hi, normalize_probs)
        if not np.allclose(p0, p1):
            clipped_flag = True
        new_probs.append(p1)
    prob_list = new_probs
    if clipped_flag:
        warnings.warn(
            (
                "Some predicted probabilities were normalized/clipped "
                f"to [{clip_lo}, {clip_hi}]."
            ),
            UserWarning,
            stacklevel=2,
        )

    # All models share the same bin edges so curves are comparable.
    edges, centers = _build_bins(
        prob_list, n_bins, strategy, clip_lo, clip_hi
    )
    z = _z_from_conf(conf_level)

    # -------------- colors & layout -------------- #
    colors = _colors(cmap, color_palette, len(prob_list))
    if counts_panel == "bottom":
        fig = plt.figure(figsize=figsize)
        gs = fig.add_gridspec(2, 1, height_ratios=(3.0, 1.0), hspace=0.12)
        ax = fig.add_subplot(gs[0, 0])
        axb = fig.add_subplot(gs[1, 0], sharex=ax)
    else:
        fig, ax = plt.subplots(figsize=figsize)
        axb = None

    # -------------- compute & plot -------------- #
    per_model: dict[str, pd.DataFrame] = {}
    for i, (name, p, col) in enumerate(zip(names, prob_list, colors)):
        stats = _bin_stats(p, y_bin, w, edges, error_bars, z)
        ece = float(np.nansum(stats["ece"]))
        br = _brier(p, y_bin, w)
        df = pd.DataFrame(
            {
                "bin_left": edges[:-1],
                "bin_right": edges[1:],
                "bin_center": centers,
                "n": stats["n"],
                "w_sum": stats["wsum"],
                "p_mean": stats["pmean"],
                "y_rate": stats["yrate"],
                "y_low": stats["ylo"],
                "y_high": stats["yhi"],
                "ece_contrib": stats["ece"],
            }
        )
        per_model[name] = df

        # Only bins that actually received weight are drawn.
        valid = df["w_sum"].to_numpy() > 0
        x = df.loc[valid, "p_mean"].to_numpy()
        y = df.loc[valid, "y_rate"].to_numpy()
        ylo = df.loc[valid, "y_low"].to_numpy()
        yhi = df.loc[valid, "y_high"].to_numpy()

        if error_bars.lower() != "none":
            # Matplotlib expects (lower, upper) distances from ``y``.
            yerr = np.vstack([y - ylo, yhi - y])
            ax.errorbar(
                x,
                y,
                yerr=yerr,
                fmt="none",
                ecolor=col,
                elinewidth=1.0,
                capsize=2,
                alpha=alpha * 0.85,
            )
        ax.scatter(x, y, c=[col], s=s, marker=marker, alpha=alpha)
        if connect and len(x) > 1:
            ax.plot(x, y, color=col, linewidth=linewidth, alpha=alpha)

        label = name
        pieces = []
        if show_ece:
            pieces.append(f"ECE={ece:.3f}")
        if show_brier:
            pieces.append(f"Brier={br:.3f}")
        if pieces:
            label = f"{label} ({', '.join(pieces)})"
        # Empty proxy artist: gives one legend entry per model that
        # combines the marker and the (optional) connecting line.
        ax.plot(
            [],
            [],
            color=col,
            marker=marker,
            linestyle="-" if connect else "None",
            linewidth=linewidth,
            label=label,
        )

        if axb is not None:
            # Side-by-side bars: each model gets an equal slot inside
            # 80% of every bin's width.
            bw = edges[1:] - edges[:-1]
            slot = bw * 0.8 / max(1, len(prob_list))
            left = edges[:-1] + i * slot
            vals = df["w_sum"].to_numpy()
            if counts_norm == "fraction":
                denom = vals.sum() if vals.sum() > 0 else 1.0
                vals = vals / denom
            axb.bar(
                left,
                vals,
                width=slot,
                align="edge",
                color=col,
                alpha=counts_alpha,
                label=name,
            )

    # -------------- format axes -------------- #
    ax.set_xlim(*xlim)
    ax.set_ylim(*ylim)
    if xlabel:
        ax.set_xlabel(xlabel)
    if ylabel:
        ax.set_ylabel(ylabel)
    if title:
        ax.set_title(title)
    if show_diagonal:
        diag_kw = {
            "color": "gray",
            "linestyle": "--",
            "linewidth": 1.2,
            "alpha": 0.9,
        }
        if diagonal_kwargs:
            _assert_all_types(
                diagonal_kwargs, dict, objname="'diagonal_kwargs'"
            )
            diag_kw.update(diagonal_kwargs)
        ax.plot((0.0, 1.0), (0.0, 1.0), **diag_kw)
    set_axis_grid(ax, show_grid=show_grid, grid_props=grid_props)
    if legend:
        ax.legend(loc=legend_loc)

    if axb is not None:
        axb.set_xlim(*xlim)
        axb.axhline(0, color="gray", lw=0.8)
        axb.set_xlabel(xlabel or "Predicted probability")
        axb.set_ylabel("Frac." if counts_norm == "fraction" else "Count")
        set_axis_grid(axb, show_grid=True, grid_props={"alpha": 0.25})
        handles, labels = axb.get_legend_handles_labels()
        if handles and labels:
            axb.legend(loc="upper right", fontsize=8)

    plt.tight_layout()
    if savefig:
        # Saving is best-effort: a failure is reported but does not
        # prevent the axes (and data) from being returned.
        try:
            fig.savefig(savefig, bbox_inches="tight", dpi=300)
            print(f"Plot saved to {savefig}")
        except Exception as e:
            print(f"Error saving plot to {savefig}: {e}")
    else:
        plt.show()

    if return_data:
        return ax, per_model
    return ax
# ------------------ helpers ------------------ #
def _z_from_conf(cf: float) -> float:
    """Return the two-sided standard-normal critical value for ``cf``.

    Parameters
    ----------
    cf : float
        Confidence level, expected to lie strictly between 0 and 1
        (e.g. ``0.95``).

    Returns
    -------
    float
        ``z`` such that a standard normal variable falls in ``[-z, z]``
        with probability ``cf``. When ``cf`` is outside ``(0, 1)`` (or
        NaN) a warning is issued and the 95% value is returned.
    """
    # Local import keeps this helper self-contained.
    from statistics import NormalDist

    z_95 = 1.9599639845
    level = float(cf)
    if not 0.0 < level < 1.0:
        warnings.warn(
            f"conf_level must lie in (0, 1); got {cf!r}. "
            "Using the 95% critical value instead.",
            UserWarning,
            stacklevel=2,
        )
        return z_95
    # Two-sided interval: put (1 - level) / 2 in each tail.
    return float(NormalDist().inv_cdf(0.5 * (1.0 + level)))
def _to_prob_vector(arr: np.ndarray, ci: int | None) -> np.ndarray:
    """Reduce a prediction input to a 1D float vector.

    Parameters
    ----------
    arr : array-like
        Either a 1D sequence of values or a 2D ``(rows, columns)``
        array. Any array-like is accepted and converted with
        ``np.asarray``.
    ci : int or None
        Column to extract from a 2D input. ``None`` selects the last
        column. Negative indices are rejected.

    Returns
    -------
    np.ndarray
        1D array of dtype float.

    Raises
    ------
    ValueError
        If ``ci`` is out of bounds for a 2D input, or if the input is
        neither 1D nor 2D.
    """
    # Coerce first so plain lists/tuples work (they have no ``.ndim``).
    arr = np.asarray(arr)
    if arr.ndim == 1:
        return arr.astype(float, copy=False)
    if arr.ndim == 2:
        n_cols = arr.shape[1]
        idx = n_cols - 1 if ci is None else ci
        if idx < 0 or idx >= n_cols:
            raise ValueError(
                "class_index out of bounds for 2D predictions: "
                f"{idx} not in [0, {arr.shape[1]-1}]"
            )
        return arr[:, idx].astype(float, copy=False)
    raise ValueError(
        "Predictions must be 1D probabilities or "
        "(n_samples, n_classes) arrays."
    )
def _prep_probs(
    p: np.ndarray, clip_lo: float, clip_hi: float, do_norm: bool
) -> np.ndarray:
    """Optionally min-max rescale ``p`` and clip it to a closed range.

    When ``do_norm`` is true and any value lies outside ``[0, 1]`` (with
    a ``1e-9`` tolerance), the vector is rescaled linearly so its
    minimum maps to 0 and its maximum to 1, provided the spread exceeds
    ``1e-12``. The result is always clipped to ``[clip_lo, clip_hi]``.
    """
    values = np.asarray(p, dtype=float)
    if not do_norm:
        return np.clip(values, clip_lo, clip_hi)

    lowest = np.nanmin(values)
    highest = np.nanmax(values)
    tol = 1e-9
    if lowest < -tol or highest > 1.0 + tol:
        spread = highest - lowest
        # A (near-)constant vector cannot be rescaled; leave it to clip.
        if spread > 1e-12:
            values = (values - lowest) / spread
    return np.clip(values, clip_lo, clip_hi)
def _build_bins(
    probs_list: list[np.ndarray], nb: int, strat: str, low: float, high: float
) -> tuple[np.ndarray, np.ndarray]:
    """Compute shared bin edges and bin midpoints.

    ``strat == "uniform"`` yields ``nb`` equal-width bins on
    ``[low, high]``. Any other value uses the quantiles of all vectors
    in ``probs_list`` pooled together; if duplicate quantiles leave
    fewer than ``nb`` bins, a warning is emitted and uniform bins are
    used instead.

    Returns
    -------
    (edges, centers) : tuple of np.ndarray
        The bin edges and the midpoint of each consecutive edge pair.
    """
    n_edges = nb + 1
    if strat == "uniform":
        edges = np.linspace(low, high, n_edges)
    else:
        pooled = np.concatenate(probs_list)
        levels = np.linspace(0.0, 1.0, n_edges)
        # ``np.unique`` both sorts and removes repeated quantile values.
        edges = np.unique(np.quantile(pooled, levels))
        if len(edges) - 1 < nb:
            warnings.warn(
                "Not enough unique quantile edges; "
                "falling back to uniform bins.",
                UserWarning,
                stacklevel=2,
            )
            edges = np.linspace(low, high, n_edges)

    lefts, rights = edges[:-1], edges[1:]
    centers = 0.5 * (lefts + rights)
    return edges, centers
def _bin_stats(
    p: np.ndarray,
    y: np.ndarray,
    w: np.ndarray,
    edges: np.ndarray,
    ebars: str,
    zval: float,
) -> dict[str, np.ndarray]:
    """Compute weighted per-bin calibration statistics.

    Parameters
    ----------
    p : np.ndarray
        1D predicted probabilities.
    y : np.ndarray
        1D targets aligned with ``p``.
    w : np.ndarray
        1D sample weights aligned with ``p``.
    edges : np.ndarray
        Bin edges; there are ``len(edges) - 1`` bins. Values falling
        outside the edges are assigned to the first/last bin.
    ebars : str
        ``"none"``, ``"normal"`` or anything else (treated as Wilson);
        compared case-insensitively.
    zval : float
        Normal critical value used for the interval half-width.

    Returns
    -------
    dict of str -> np.ndarray
        One entry per bin for each of: ``n`` (sample count), ``wsum``
        (weight total), ``pmean`` (weighted mean prediction), ``yrate``
        (weighted mean target), ``ylo``/``yhi`` (interval bounds, NaN
        for bins without weight or when ``ebars`` is ``"none"``),
        ``wbin`` (share of total weight) and ``ece`` (``wbin`` times
        ``|yrate - pmean|``).
    """
    nb = len(edges) - 1
    eps = 1e-12

    # Half-open bins [edge_k, edge_{k+1}); clamp strays (including
    # p == last edge) into the outermost bins.
    idx = np.digitize(p, edges, right=False) - 1
    idx[idx < 0] = 0
    idx[idx >= nb] = nb - 1

    n = np.zeros(nb, dtype=float)
    wsum = np.zeros(nb, dtype=float)
    pmean = np.zeros(nb, dtype=float)
    yr = np.zeros(nb, dtype=float)
    for b in range(nb):
        members = idx == b
        if not np.any(members):
            continue
        ww = w[members]
        wsum[b] = ww.sum()
        n[b] = float(members.sum())
        denom = max(wsum[b], eps)  # guards against zero total weight
        pmean[b] = float(np.dot(ww, p[members]) / denom)
        yr[b] = float(np.dot(ww, y[members]) / denom)

    mode = ebars.lower()
    # Effective sample size per bin; bins with (numerically) no weight
    # have no meaningful interval and are reported as NaN in every mode.
    neff = np.maximum(wsum, eps)
    no_weight = neff <= eps
    if mode == "none":
        ylo = np.full_like(yr, np.nan)
        yhi = np.full_like(yr, np.nan)
    elif mode == "normal":
        se = np.sqrt(np.clip(yr * (1.0 - yr) / neff, 0.0, 1.0))
        ylo = np.clip(yr - zval * se, 0.0, 1.0)
        yhi = np.clip(yr + zval * se, 0.0, 1.0)
        ylo[no_weight] = np.nan
        yhi[no_weight] = np.nan
    else:
        # Wilson score interval, evaluated for all bins at once.
        z2 = zval**2
        denom = 1.0 + z2 / neff
        center = (yr + z2 / (2.0 * neff)) / denom
        rad = (
            zval * np.sqrt((yr * (1.0 - yr) + z2 / (4.0 * neff)) / neff)
        ) / denom
        ylo = np.clip(center - rad, 0.0, 1.0)
        yhi = np.clip(center + rad, 0.0, 1.0)
        ylo[no_weight] = np.nan
        yhi[no_weight] = np.nan

    totw = max(w.sum(), eps)
    wbin = wsum / totw
    ece_contrib = wbin * np.abs(yr - pmean)
    return {
        "n": n,
        "wsum": wsum,
        "pmean": pmean,
        "yrate": yr,
        "ylo": ylo,
        "yhi": yhi,
        "wbin": wbin,
        "ece": ece_contrib,
    }
def _brier(p: np.ndarray, y: np.ndarray, w: np.ndarray) -> float:
    """Return the weighted mean of squared differences between ``p`` and ``y``."""
    squared_error = np.square(p - y)
    weighted_mean = np.average(squared_error, weights=w)
    return float(weighted_mean)
def _colors(cmap_name: str, palette: list[Any] | None, k: int) -> list[Any]:
    """Return ``k`` colors, from an explicit palette or a colormap.

    Parameters
    ----------
    cmap_name : str
        Colormap name, used only when no usable palette is given.
    palette : sequence or None
        Explicit colors, cycled when shorter than ``k``. ``None`` or an
        empty sequence means "use the colormap".
    k : int
        Number of colors required.

    Returns
    -------
    list
        ``k`` color specifications.
    """
    # An empty palette used to raise ZeroDivisionError in the modulo
    # below; treat it the same as "not provided".
    if palette is not None and len(palette) > 0:
        return [palette[i % len(palette)] for i in range(k)]

    try:
        cmo = get_cmap(cmap_name, default="tab10", failsafe="discrete")
    except ValueError:
        warnings.warn(
            f"Invalid cmap '{cmap_name}'. Using 'tab10' instead.",
            UserWarning,
            stacklevel=2,
        )
        cmo = get_cmap("tab10", default="tab10", failsafe="discrete")

    # Colormaps exposing a ``colors`` list: take the first k entries
    # when there are enough of them.
    if hasattr(cmo, "colors") and len(cmo.colors) >= k:
        return list(cmo.colors[:k])
    # Otherwise sample the map evenly over [0, 1].
    if k == 1:
        return [cmo(0.5)]
    return [cmo(i / (k - 1)) for i in range(k)]
plot_reliability_diagram.__doc__ = r"""
Plot a reliability diagram (calibration plot) for one or more
classification models.
This compares **predicted probabilities** to **observed
frequencies** across bins of predicted probability. Perfect
calibration lies on the diagonal :math:`y=x`.
Parameters
----------
y_true : array-like of shape (n_samples,)
Ground truth labels. For binary calibration, values are
compared to ``positive_label`` after validation and
flattening.
*y_preds : array-like(s)
One or more model predictions. Each item may be:
- 1D array of positive-class probabilities in ``[0, 1]``.
- 2D array of shape ``(n_samples, n_classes)``; use
``class_index`` to select a column. If omitted, the
last column is used.
names : list of str, optional
Labels for each model curve. If fewer names are provided
than models, placeholders like ``'Model_1'`` are appended.
sample_weight : array-like of shape (n_samples,), optional
Per-sample weights used for observed frequencies, ECE,
and Brier score. If ``None``, equal weights are used.
n_bins : int, default=10
Number of probability bins.
strategy : {'uniform', 'quantile'}, default='uniform'
Binning strategy.
- ``'uniform'``: equally spaced edges in ``[0, 1]``.
- ``'quantile'``: edges are empirical quantiles of the
pooled predictions. If edges are not unique, the method
falls back to uniform binning with a warning.
positive_label : int or float or str, default=1
Label in ``y_true`` treated as the positive class when
constructing the binary target.
class_index : int, optional
Column index to pick from 2D probability arrays. If
omitted, the last column is used.
clip_probs : tuple of (float, float), default=(0.0, 1.0)
Inclusive clipping range applied to predictions. A warning
is issued if clipping occurs.
normalize_probs : bool, default=True
If ``True``, each prediction vector is min-max rescaled into
``[0, 1]`` whenever any of its values falls outside that range
(beyond a ``1e-9`` tolerance). Clipping to ``clip_probs`` is
applied afterwards in all cases.
error_bars : {'wilson', 'normal', 'none'}, default='wilson'
Per-bin uncertainty for observed frequencies.
- ``'wilson'``: Wilson interval using ``conf_level``.
- ``'normal'``: normal approximation.
- ``'none'``: no error bars.
conf_level : float, default=0.95
Confidence level used for error bars when applicable.
show_diagonal : bool, default=True
Draw the reference diagonal :math:`y=x`.
diagonal_kwargs : dict, optional
Matplotlib keyword arguments for the diagonal reference
line (e.g., ``linestyle``, ``color``).
show_ece : bool, default=True
Compute Expected Calibration Error (ECE) and append a
summary to each model label.
show_brier : bool, default=True
Compute (weighted) Brier score and append a summary to
each model label.
counts_panel : {'none', 'bottom'}, default='bottom'
If not ``'none'``, draw a compact histogram below the main
panel that shows per-bin totals for each model.
counts_norm : {'fraction', 'count'}, default='fraction'
Normalization for the counts panel. ``'fraction'`` divides
by the total weight; ``'count'`` shows raw weighted sums.
counts_alpha : float, default=0.35
Alpha for bars in the counts panel.
figsize : tuple of (float, float), default=(9, 7)
Figure size for the layout. When ``counts_panel='bottom'``,
a two-row gridspec is used.
title : str, optional
Title for the plot. If ``None``, no title is set.
xlabel : str, optional
Label for the x-axis. Defaults to
``'Predicted probability'``.
ylabel : str, optional
Label for the y-axis. Defaults to
``'Observed frequency'``.
cmap : str, default='tab10'
Matplotlib colormap name used to generate model colors.
color_palette : list, optional
Explicit list of colors. When provided, colors are cycled
from this list instead of the colormap.
marker : str, default='o'
Marker used for the bin points.
s : int, default=40
Marker size for the bin points.
linewidth : float, default=2.0
Line width used when connecting bin points.
alpha : float, default=0.9
Alpha for points and lines in the main panel.
connect : bool, default=True
Connect bin points with a line for each model.
legend : bool, default=True
Display a legend. Summary metrics (ECE, Brier) are shown
next to model names when enabled.
legend_loc : str, default='best'
Legend location passed to Matplotlib.
show_grid : bool, default=True
Toggle gridlines via the package helper ``set_axis_grid``.
grid_props : dict, optional
Keyword arguments passed to ``set_axis_grid`` for grid
customization (e.g., ``linestyle``, ``alpha``).
xlim : tuple of (float, float), default=(0.0, 1.0)
X-axis limits.
ylim : tuple of (float, float), default=(0.0, 1.0)
Y-axis limits.
savefig : str, optional
If provided, save the figure to this path; otherwise the
plot is shown interactively.
return_data : bool, default=False
If ``True``, return ``(ax, data_dict)`` where values are
per-model ``pandas.DataFrame`` objects with per-bin stats:
``['bin_left', 'bin_right', 'bin_center', 'n', 'w_sum',
'p_mean', 'y_rate', 'y_low', 'y_high', 'ece_contrib']``.
Otherwise, return only the Matplotlib axes.
Returns
-------
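ax : matplotlib.axes.Axes
Axes of the main calibration plot. When
``counts_panel='bottom'``, the second axes (counts panel)
is not returned.
data : dict of str to pandas.DataFrame
Only returned when ``return_data=True``, as the second item of
the tuple ``(ax, data)``. Maps each model name to its table of
per-bin statistics.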
Notes
-----
Calibration compares *confidence* to *accuracy* within bins.
For bin :math:`b`, let :math:`\hat{p}_i` be predictions and
:math:`y_i\in\{0,1\}` be binary targets with weights
:math:`w_i\ge 0`. Define the weighted bin mean probability
and accuracy as
.. math::
\bar{p}_b \;=\;
\frac{\sum_{i\in b} w_i \hat{p}_i}
{\sum_{i\in b} w_i},
\qquad
\bar{y}_b \;=\;
\frac{\sum_{i\in b} w_i y_i}
{\sum_{i\in b} w_i}.
The Expected Calibration Error (ECE) is
.. math::
\mathrm{ECE}
\;=\;
\sum_b
\left(
\frac{\sum_{i\in b} w_i}{\sum_i w_i}
\right)
\left|
\bar{y}_b - \bar{p}_b
\right|.
The (weighted) Brier score is
.. math::
\mathrm{Brier}
\;=\;
\frac{\sum_i
w_i \left(\hat{p}_i - y_i\right)^2}
{\sum_i w_i}.
Wilson confidence intervals for :math:`\bar{y}_b` use
:math:`z = \Phi^{-1}\!\left(\tfrac{1+\alpha}{2}\right)` and
effective count :math:`n_b=\sum_{i\in b} w_i`:
.. math::
\mathrm{center}
\;=\;
\frac{\bar{y}_b + \frac{z^2}{2 n_b}}
{1 + \frac{z^2}{n_b}},
\qquad
\mathrm{radius}
\;=\;
\frac{z}{1 + \frac{z^2}{n_b}}
\sqrt{\frac{\bar{y}_b(1-\bar{y}_b)}{n_b}
+ \frac{z^2}{4 n_b^2}}.
The interval is
:math:`[\mathrm{center}-\mathrm{radius},
\mathrm{center}+\mathrm{radius}]`,
clipped to ``[0, 1]``. The normal interval replaces the term
with the usual standard error
:math:`\sqrt{\bar{y}_b(1-\bar{y}_b)/n_b}`.
When ``strategy='quantile'``, bin edges are the empirical
quantiles of the pooled predictions. If many identical values
exist, edges can collapse; in that case, the function falls
back to uniform edges with a warning.
Examples
--------
Binary example with quantile bins and Wilson intervals.
>>> import numpy as np
>>> from kdiagram.plot.comparison import \
... plot_reliability_diagram
>>> rng = np.random.default_rng(0)
>>> y = (rng.random(1000) < 0.4).astype(int)
>>> p1 = 0.4 * np.ones_like(y) + 0.15 * rng.random(len(y))
>>> p2 = 0.4 * np.ones_like(y) + 0.05 * rng.random(len(y))
>>> ax = plot_reliability_diagram(
... y, p1, p2,
... names=['Wide', 'Tight'],
... n_bins=12,
... strategy='quantile',
... error_bars='wilson',
... counts_panel='bottom',
... show_ece=True,
... show_brier=True,
... title=('Reliability Diagram '
... '(Quantile bins + Wilson CIs)'),
... )
"""
[docs]
@check_non_emptiness(params=["y_true", "y_preds"])
def plot_polar_reliability(
    y_true: np.ndarray,
    *y_preds: np.ndarray,
    names: list[str] | None = None,
    n_bins: int = 10,
    strategy: str = "uniform",
    title: str = "Polar Reliability Diagram",
    figsize: tuple[float, float] = (8, 8),
    cmap: str = "coolwarm",
    show_grid: bool = True,
    grid_props: dict[str, Any] | None = None,
    show_cbar: bool = True,
    mask_radius: bool = False,
    savefig: str | None = None,
    dpi: int = 300,
):
    # NOTE: the user-facing docstring is assigned to ``__doc__`` right
    # after this definition.
    if not y_preds:
        raise ValueError("At least one prediction array must be provided.")

    # Pad missing names so that ``zip`` below never drops a model.
    names = list(names) if names else []
    names.extend(
        f"Model {i+1}" for i in range(len(names), len(y_preds))
    )

    y_true = np.asarray(y_true)
    # Coerce to arrays first: ``_to_prob_vector`` reads ``.ndim``.
    prob_list = [_to_prob_vector(np.asarray(p), ci=None) for p in y_preds]
    weights = np.ones_like(y_true, dtype=float)
    edges, _ = _build_bins(prob_list, n_bins, strategy, 0.0, 1.0)
    # Per-model colors, used for the shaded gap and the legend entry.
    colors = _colors(cmap, palette=None, k=len(y_preds))

    fig, ax = plt.subplots(
        figsize=figsize, subplot_kw={"projection": "polar"}
    )

    # --- Plot Perfect Calibration Spiral ---
    quarter_turn = np.pi / 2
    perfect_theta = np.linspace(0, quarter_turn, 100)
    perfect_radius = np.linspace(0, 1, 100)
    ax.plot(
        perfect_theta,
        perfect_radius,
        color="black",
        linestyle="--",
        lw=1.5,
        label="Perfect Calibration",
    )

    # Fixed color scale (- is over-confident, + is under-confident),
    # shared by every model so one colorbar describes them all.
    norm = Normalize(vmin=-0.5, vmax=0.5)
    line_collection_for_cbar = None

    # --- Plot Model Spirals with Diagnostic Coloring ---
    for i, (name, p) in enumerate(zip(names, prob_list)):
        stats = _bin_stats(p, y_true, weights, edges, "none", 0)

        # ``_bin_stats`` reports empty bins as zeros (not NaN), so they
        # must be filtered on weight; otherwise they show up as bogus
        # points at the origin and break the monotonic x-grid that
        # ``np.interp`` needs.
        populated = stats["wsum"] > 0
        p_mean = stats["pmean"][populated]
        y_rate = stats["yrate"][populated]

        model_theta = p_mean * quarter_turn
        model_radius = y_rate
        # Deviation from perfect calibration (observed - predicted).
        calibration_error = model_radius - p_mean

        # Consecutive bin points form the colored line segments.
        points = np.column_stack([model_theta, model_radius]).reshape(
            -1, 1, 2
        )
        segments = np.concatenate([points[:-1], points[1:]], axis=1)
        lc = LineCollection(segments, cmap=cmap, norm=norm)
        # One value per segment: the error at the segment's start point.
        lc.set_array(calibration_error[:-1])
        lc.set_linewidth(3)
        ax.add_collection(lc)
        if line_collection_for_cbar is None:
            line_collection_for_cbar = lc

        # Proxy artist for the legend, in the model's own color so the
        # entries can be told apart and match the shaded region.
        ax.plot([], [], color=colors[i], lw=3, label=name)

        if model_theta.size:
            ax.fill_between(
                perfect_theta,
                np.interp(
                    perfect_theta, model_theta, model_radius, left=0, right=1
                ),
                perfect_radius,
                color=colors[i],
                alpha=0.15,
            )

    # --- Formatting ---
    ax.set_title(title, fontsize=16, y=1.1)
    ax.set_thetamin(0)
    ax.set_thetamax(90)
    ax.set_ylim(0, 1.05)
    # Angular ticks are relabelled as probabilities in [0, 1].
    ax.set_xticks(np.linspace(0, quarter_turn, 6))
    ax.set_xticklabels([f"{val:.1f}" for val in np.linspace(0, 1, 6)])
    # Padding keeps the axis labels clear of the tick labels.
    ax.set_xlabel("Predicted Probability", labelpad=15)
    ax.set_ylabel("Observed Frequency", labelpad=25)
    # Legend is pushed outside the quadrant to avoid overlap.
    ax.legend(loc="upper right", bbox_to_anchor=(1.3, 1.15))
    set_axis_grid(ax, show_grid=show_grid, grid_props=grid_props)

    # Horizontal colorbar below the plot explaining the diagnostic colors.
    if show_cbar and line_collection_for_cbar is not None:
        cbar = fig.colorbar(
            line_collection_for_cbar,
            ax=ax,
            orientation="horizontal",
            shrink=0.75,
            pad=0.08,
        )
        cbar.set_label(
            "Calibration Error (Observed - Predicted)", fontsize=10
        )

    if mask_radius:
        ax.set_yticklabels([])

    plt.tight_layout()
    if savefig:
        plt.savefig(savefig, dpi=dpi, bbox_inches="tight")
        plt.close(fig)
    else:
        plt.show()
    return ax
plot_polar_reliability.__doc__ = r"""
Plot a Polar Reliability Diagram (Calibration Spiral).
This function provides a novel visualization of model calibration by
mapping the traditional reliability diagram onto a polar coordinate
system :footcite:p:`kouadiob2025`. It compares **predicted
probabilities** (mapped to the angle) to **observed frequencies**
(mapped to the radius).
Perfect calibration is represented by a perfect Archimedean spiral.
The plot uses a diverging colormap to diagnostically color the
model's spiral, immediately revealing regions of over- or
under-confidence.
Parameters
----------
y_true : np.ndarray
1D array of true binary labels (0 or 1).
*y_preds : np.ndarray
One or more 1D arrays of predicted probabilities for each model.
names : list of str, optional
Display names for each of the models. If not provided, generic
names like ``'Model 1'`` will be generated.
n_bins : int, default=10
Number of bins to group predicted probabilities into for analysis.
strategy : {'uniform', 'quantile'}, default='uniform'
The strategy for creating bins:
- ``'uniform'``: Bins are of equal width across the [0, 1] range.
- ``'quantile'``: Bins are created based on the quantiles of the
predicted probabilities, ensuring each bin has a similar
number of samples.
title : str, default="Polar Reliability Diagram"
The title for the plot.
figsize : tuple of (float, float), default=(8, 8)
The figure size in inches.
cmap : str, default='coolwarm'
A diverging colormap used to color the model's spiral. The center
of the colormap represents perfect calibration, with one color for
over-confidence and another for under-confidence.
show_cbar : bool, default=True
If ``True``, display a color bar that explains the diagnostic
coloring of the calibration error.
show_grid : bool, default=True
Toggle the visibility of the polar grid lines.
grid_props : dict, optional
Custom keyword arguments passed to the grid for styling (e.g.,
``linestyle``, ``alpha``).
mask_radius : bool, default=False
If ``True``, hide the radial tick labels.
savefig : str, optional
The file path to save the plot. If ``None``, the plot is
displayed interactively.
dpi : int, default=300
The resolution (dots per inch) for the saved figure.
Returns
-------
ax : matplotlib.axes.Axes
The Matplotlib Axes object containing the polar reliability plot.
Notes
-----
This plot is a polar adaptation of the standard reliability diagram,
a key tool in forecast verification :footcite:p:`Jolliffe2012`.
1. **Binning**: Predicted probabilities :math:`p_i` are first
partitioned into :math:`K` bins. For each bin :math:`k`, the mean
predicted probability (:math:`\bar{p}_k`) and the mean observed
frequency (:math:`\bar{y}_k`) are calculated.
2. **Polar Mapping**: These values are then mapped to polar
coordinates:
.. math::
\theta_k &= \bar{p}_k \cdot \frac{\pi}{2} \\
r_k &= \bar{y}_k
The plot is constrained to a 90-degree quadrant where the angle
:math:`\theta` represents the predicted probability from 0 to 1,
and the radius :math:`r` represents the observed frequency from
0 to 1.
3. **Perfect Calibration**: A perfectly calibrated model, where
:math:`\bar{p}_k = \bar{y}_k` for all bins, will form a perfect
Archimedean spiral defined by :math:`r = \frac{2\theta}{\pi}`.
This is drawn as a dashed black reference line.
4. **Diagnostic Coloring**: The calibration error for each bin is
calculated as :math:`e_k = \bar{y}_k - \bar{p}_k`. The line
segments of the model's spiral are colored based on this error:
- :math:`e_k < 0`: The model is **over-confident** (observed
frequency is lower than predicted probability).
- :math:`e_k > 0`: The model is **under-confident** (observed
frequency is higher than predicted probability).
Examples
--------
>>> import numpy as np
>>> from kdiagram.plot.comparison import plot_polar_reliability
>>>
>>> # Generate synthetic data for two models
>>> np.random.seed(0)
>>> n_samples = 2000
>>> y_true = (np.random.rand(n_samples) < 0.4).astype(int)
>>> # A well-calibrated model
>>> calibrated_preds = np.clip(0.4 + np.random.normal(0, 0.15, n_samples), 0, 1)
>>> # An over-confident model
>>> overconfident_preds = np.clip(0.4 + np.random.normal(0, 0.3, n_samples), 0, 1)
>>>
>>> # Generate the plot
>>> ax = plot_polar_reliability(
... y_true,
... calibrated_preds,
... overconfident_preds,
... names=["Well-Calibrated", "Over-Confident"],
... n_bins=15,
... cmap='coolwarm'
... )
References
----------
.. footbibliography::
"""
[docs]
@validate_params(
{
"train_times": ["array-like", None],
"metrics": [str, "array-like", callable, None],
"scale": [
StrOptions(
{
"norm",
"min-max",
"std",
"standard",
}
),
None,
],
"lower_bound": [Real],
}
)
def plot_model_comparison(
y_true,
*y_preds,
train_times: float | list[float] | None = None,
metrics: str | Callable | list[str | Callable] | None = None,
names: list[str] | None = None,
title: str | None = None,
figsize: tuple[float, float] | None = None,
colors: list[Any] | None = None,
alpha: float = 0.7,
legend: bool = True,
show_grid: bool = True,
grid_props: dict = None,
scale: str | None = "norm",
lower_bound: float = 0,
savefig: str | None = None,
loc: str = "upper right",
verbose: int = 0,
):
r"""Plot multi-metric model performance comparison on a radar chart.
Generates a radar chart (spider chart) visualizing multiple
performance metrics for one or more models simultaneously. Each
axis corresponds to a metric (e.g., R2, MAE, accuracy,
precision), and each polygon represents a model, allowing for a
holistic comparison of their strengths and weaknesses across
different evaluation criteria [1]_.
This function is highly valuable for model selection, providing a
compact overview that goes beyond single-score comparisons. Use
it when you need to balance trade-offs between various metrics
(like accuracy vs. training time) or understand how different
models perform relative to each other across a spectrum of
relevant performance indicators. Internally relies on helpers
to handle potential NaN values and determine data types [2]_.
Parameters
----------
y_true : array-like of shape (n_samples,)
The ground truth (correct) target values.
*y_preds : array-like of shape (n_samples,)
Variable number of prediction arrays, one for each model to
be compared. Each array must have the same length as
`y_true`.
train_times : float or list of float, optional
Training time in seconds for each model corresponding to
`*y_preds`. If provided:
- A single float assumes the same time for all models.
- A list must match the number of models.
It will be added as an additional axis/metric on the chart.
Default is ``None``.
metrics : str, callable, list of these, optional
The performance metrics to calculate and plot. Default is
``None``, which triggers automatic metric selection based on
the target type inferred from `y_true`:
- **Regression:** Defaults to ``["r2", "mae", "mape", "rmse"]``.
- **Classification:** Defaults to ``["accuracy", "precision",
"recall"]``.
Can be provided as:
- A list of strings: Names of metrics known by scikit-learn
or gofast's `get_scorer` (e.g., ``['r2', 'rmse']``).
- A list of callables: Functions with the signature
`metric(y_true, y_pred)`.
- A mix of strings and callables.
names : list of str, optional
Names for each model corresponding to `*y_preds`. Used for
the legend. If ``None`` or too short, defaults like
"Model_1", "Model_2" are generated. Default is ``None``.
title : str, optional
Title displayed above the radar chart. If ``None``, a generic
title may be used internally or omitted. Default is ``None``.
figsize : tuple of (float, float), optional
Figure size ``(width, height)`` in inches. If ``None``, the
figure is created with size ``(8, 8)``.
colors : list of str or None, optional
List of Matplotlib color specifications for each model's
polygon. If ``None``, colors are automatically assigned from
the default palette ('tab10'). If provided, the list length
should ideally match `n_models`.
alpha : float, optional
Transparency level (between 0 and 1) for the plotted polygon
outlines. Default is ``0.7``. The filled areas always use a
fixed alpha of ``0.1``, independent of this value.
legend : bool, optional
If ``True``, display a legend mapping colors/lines to model
names. Default is ``True``.
show_grid : bool, optional
If ``True``, display the radial grid lines on the chart.
Default is ``True``.
scale : {'norm', 'min-max', 'std', 'standard'}, optional
Method for scaling metric values before plotting. Scaling is
applied independently to each metric (axis) across models.
Default is ``'norm'``.
- ``'norm'`` or ``'min-max'``: Min-max scaling. Transforms
values to the range [0, 1] using
:math:`(X - min) / (max - min)`. Useful for comparing
relative performance when metrics have different scales.
- ``'std'`` or ``'standard'``: Standard scaling (Z-score).
Transforms values to have zero mean and unit variance using
:math:`(X - mean) / std`. Preserves relative spacing better
than min-max but results can be negative.
- ``None``: Plot raw metric values without scaling. Use only
if metrics naturally share a comparable, non-negative range.
lower_bound : float, optional
Sets the minimum value for the radial axis (innermost circle).
Useful when using standard scaling ('std') which can produce
negative values, or to adjust the plot's center.
Default is ``0``.
savefig : str, optional
If provided, the file path (e.g., 'radar_comparison.svg')
where the figure will be saved. If ``None``, the plot is
displayed interactively. Default is ``None``.
loc : str, optional
Location argument passed to `matplotlib.pyplot.legend()` to
position the legend (e.g., 'upper right', 'lower left',
'center right'). Default is ``'upper right'``.
verbose : int, optional
Controls the verbosity level. ``0`` is silent. Higher values
may print debugging information during metric calculation or
scaling. Default is ``0``.
Returns
-------
ax : matplotlib.axes.Axes
The Matplotlib Axes object containing the radar chart. Allows
for further customization after the function call.
Raises
------
ValueError
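If `train_times` is array-like but is not one-dimensional with
exactly one entry per model. If an invalid string is provided
for `scale`. If none of the requested metrics can be resolved
to a usable scorer (a single unrecognized metric only triggers
a warning and is skipped). A `names` list of the wrong length
does not raise: it is padded with default names or truncated
with a warning.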
TypeError
If dropping NaNs from, or validating, `y_true` and `y_preds`
fails; any error raised during that step is re-raised as
``TypeError``.
See Also
--------
kdiagram.utils.metric_utils.get_scorer : Resolves metric names
to the scorer callables used by this function.
sklearn.metrics : Scikit-learn metrics module.
matplotlib.pyplot.polar : Function for creating polar plots.
Notes
-----
This function provides a multi-dimensional view of model performance.
**Metric Calculation:**
For each model :math:`k` with predictions :math:`\hat{y}_k` and
each metric :math:`m` (from the `metrics` list), the score
:math:`S_{m,k}` is calculated:
.. math::
S_{m,k} = \text{Metric}_m(y_{true}, \hat{y}_k)
If `train_times` are provided, they are treated as an additional
metric axis.
**Scaling:**
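If `scale` is specified, scaling is applied column-wise (per metric)
across all models before plotting. After scaling, string metrics
named ``'mae'``, ``'mape'``, ``'rmse'`` or ``'mse'`` and the
training-time axis are inverted (:math:`1 - S'` for min-max,
:math:`-S'` for standard scaling) so that a larger radius means
better performance on those axes: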
- Min-Max ('norm'):
.. math::
S'_{m,k} = \frac{S_{m,k} - \min_j(S_{m,j})}{\max_j(S_{m,j}) - \min_j(S_{m,j})}
- Standard ('std'):
.. math::
S'_{m,k} = \frac{S_{m,k} - \text{mean}_j(S_{m,j})}{\text{std}_j(S_{m,j})}
**Plotting:**
The (scaled) scores :math:`S'_{m,k}` for each model :math:`k`
determine the radial distance along the axis corresponding to
metric :math:`m`. Points are connected to form a polygon for
each model.
References
----------
.. [1] Wikipedia contributors. (2024). Radar chart. In Wikipedia,
The Free Encyclopedia. Retrieved April 14, 2025, from
https://en.wikipedia.org/wiki/Radar_chart
*(General reference for radar charts)*
.. [2] Kenny-Denecke, J. F., Hernandez-Amaro, A.,
Martin-Gorriz, M. L., & Castejon-Limos, P. (2024).
Lead-Time Prediction in Wind Tower Manufacturing: A Machine
Learning-Based Approach. *Mathematics*, 12(15), 2347.
https://doi.org/10.3390/math12152347
*(Example application using radar charts for ML comparison)*
Examples
--------
>>> from kdiagram.plot.comparison import plot_model_comparison
>>> import numpy as np
>>>
>>> # Example 1: Regression task
>>> y_true_reg = np.array([3, -0.5, 2, 7, 5])
>>> y_pred_r1 = np.array([2.5, 0.0, 2.1, 7.8, 5.2])
>>> y_pred_r2 = np.array([3.2, 0.2, 1.8, 6.5, 4.8])
>>> times = [0.1, 0.5] # Training times in seconds
>>> names = ['ModelLin', 'ModelTree']
>>> ax1 = plot_model_comparison(y_true_reg, y_pred_r1, y_pred_r2,
... train_times=times, names=names,
... metrics=['r2', 'mae', 'rmse'], # Specify metrics
... title="Regression Model Comparison",
... scale='norm') # Normalize for comparison
>>>
>>> # Example 2: Classification task (requires appropriate y_true/y_pred)
>>> y_true_clf = np.array([0, 1, 0, 1, 1, 0])
>>> y_pred_c1 = np.array([0, 1, 0, 1, 0, 0]) # Model 1 preds
>>> y_pred_c2 = np.array([0, 1, 1, 1, 1, 0]) # Model 2 preds
>>> ax2 = plot_model_comparison(y_true_clf, y_pred_c1, y_pred_c2,
... names=["LogReg", "SVM"],
... # Uses default classification metrics
... title="Classification Model Comparison",
... scale='norm')
"""
try:
# Remove NaN values and ensure consistency
y_true, *y_preds = drop_nan_in(y_true, *y_preds, error="raise")
# Validate y_true and each y_pred
temp_preds = []
for _, pred in enumerate(y_preds):
# Validate returns tuple, we need the second element
validated_pred = validate_yy(
y_true, pred, expected_type=None, flatten=True
)[1]
temp_preds.append(validated_pred)
y_preds = temp_preds
except Exception as e:
# Catch potential errors during validation/NaN drop
raise TypeError(f"Input validation failed: {e}") from e
n_models = len(y_preds)
if n_models == 0:
warnings.warn(
"No prediction arrays (*y_preds) provided.", stacklevel=2
)
return None # Cannot plot without predictions
# --- Handle Names ---
if names is None:
names = [f"Model_{i+1}" for i in range(n_models)]
else:
names = columns_manager(
list(names), empty_as_none=False
) # Ensure list
if len(names) < n_models:
names += [f"Model_{i+1}" for i in range(len(names), n_models)]
elif len(names) > n_models:
warnings.warn(
f"Received {len(names)} names for {n_models}"
f" models. Extra names ignored.",
UserWarning,
stacklevel=2,
)
names = names[:n_models]
# --- Handle Metrics ---
if metrics is None:
target_type = type_of_target(y_true)
if target_type in ["continuous", "continuous-multioutput"]:
# Default regression metrics
metrics = ["r2", "mae", "mape", "rmse"]
else:
# Default classification metrics
metrics = ["accuracy", "precision", "recall", "f1"]
if verbose >= 1:
print(
f"[INFO] Auto-selected metrics for target type "
f"'{target_type}': {metrics}"
)
metrics = is_iterable(metrics, exclude_string=True, transform=True)
metric_funcs = []
metric_names = []
error_metrics = [] # Track metrics needing sign inversion
for metric in metrics:
try:
if isinstance(metric, str):
# get_scorer returns a callable scorer object
scorer_func = get_scorer(metric)
metric_funcs.append(scorer_func)
metric_names.append(metric)
# Identify error metrics (lower is better) for potential scaling flip
if metric in [
"mae",
"mape",
"rmse",
"mse",
]: # Add others if needed
error_metrics.append(metric)
elif callable(metric):
metric_funcs.append(metric)
m_name = getattr(
metric, "__name__", f"func_{len(metric_names)}"
)
metric_names.append(m_name)
# Cannot easily determine if callable is error/score metric
else:
warnings.warn(
f"Ignoring invalid metric type: {type(metric)}",
stacklevel=2,
)
except Exception as e:
warnings.warn(
f"Could not retrieve scorer for metric '{metric}': {e}",
stacklevel=2,
)
if not metric_funcs:
raise ValueError("No valid metrics found or specified.")
# --- Handle Train Times ---
train_time_vals = None
if train_times is not None:
if isinstance(
train_times, (int, float, np.number)
): # Handle single value
train_time_vals = np.array([float(train_times)] * n_models)
else:
train_times = np.asarray(train_times, dtype=float)
if train_times.ndim != 1 or len(train_times) != n_models:
raise ValueError(
f"train_times must be a single float or a list/array "
f"of length n_models ({n_models}). "
f"Got shape {train_times.shape}."
)
train_time_vals = train_times
metric_names.append("Train Time (s)") # Use clearer name
# Add a placeholder for calculation loop, will substitute later
metric_funcs.append("train_time_placeholder")
# --- Calculate Metric Results ---
results = np.zeros((n_models, len(metric_names)), dtype=float)
for i, y_pred in enumerate(y_preds):
for j, metric_func in enumerate(metric_funcs):
if metric_func == "train_time_placeholder":
results[i, j] = train_time_vals[i]
elif metric_func is not None:
try:
score = metric_func(y_true, y_pred)
results[i, j] = score
except Exception as e:
warnings.warn(
f"Could not compute metric "
f"'{metric_names[j]}' for model "
f"'{names[i]}': {e}. Setting to NaN.",
stacklevel=2,
)
results[i, j] = np.nan
else:
results[i, j] = (
np.nan
) # Should not happen if logic is correct
# --- Scale Results ---
# Make copy for scaling to preserve original results if needed later
results_scaled = results.copy()
# Handle potential NaNs before scaling
if np.isnan(results_scaled).any():
warnings.warn(
"NaN values found in metric results. Scaling might "
"be affected or rows/cols dropped depending on method.",
stacklevel=2,
)
# Option 1: Impute (e.g., with column mean) - complex
# Option 2: Use nan-aware numpy functions
# Let's use nan-aware functions
# Note: Some metrics are better when *lower* (MAE, RMSE, MAPE, train_time).
# For visualization where larger radius is better, we might invert these
# before scaling, or adjust the interpretation. Let's scale first.
if scale in ["norm", "min-max"]:
if verbose >= 1:
print("[INFO] Scaling metrics using Min-Max.")
min_vals = np.nanmin(results_scaled, axis=0)
max_vals = np.nanmax(results_scaled, axis=0)
range_vals = max_vals - min_vals
# Avoid division by zero for metrics with no variance
range_vals[range_vals < 1e-9] = 1.0
results_scaled = (results_scaled - min_vals) / range_vals
# Now, for error metrics, higher value (closer to 1) is WORSE.
# Invert them so higher value (closer to 1) is BETTER.
for j, name in enumerate(metric_names):
if name in error_metrics or name == "Train Time (s)":
results_scaled[:, j] = 1.0 - results_scaled[:, j]
# Scaled results are now in [0, 1], higher is better.
elif scale in ["std", "standard"]:
if verbose >= 1:
print("[INFO] Scaling metrics using Standard Scaler.")
mean_vals = np.nanmean(results_scaled, axis=0)
std_vals = np.nanstd(results_scaled, axis=0)
# Avoid division by zero
std_vals[std_vals < 1e-9] = 1.0
results_scaled = (results_scaled - mean_vals) / std_vals
# Std scaling preserves relative order but changes range.
# Lower errors become more negative. Higher scores become more positive.
# Maybe invert sign for error metrics?
for j, name in enumerate(metric_names):
if name in error_metrics or name == "Train Time (s)":
results_scaled[:, j] = -results_scaled[:, j]
# Now higher value means better performance (higher score or lower error)
# but range is not [0, 1]. We need to handle lower_bound.
# Replace any potential NaNs resulting from scaling (e.g., if all NaNs)
results_scaled = np.nan_to_num(results_scaled, nan=lower_bound)
# --- Plotting ---
fig = plt.figure(figsize=figsize or (8, 8)) # Default figsize here
ax = fig.add_subplot(111, polar=True)
# Angles for each metric axis
num_metrics = len(metric_names)
angles = np.linspace(0, 2 * np.pi, num_metrics, endpoint=False).tolist()
angles_closed = angles + angles[:1] # Repeat first angle to close plot
# Colors
if colors is None:
# Use a robust colormap like tab10 if available
try:
cmap_obj = get_cmap("tab10", default="tab10", failsafe="discrete")
plot_colors = [cmap_obj(i % 10) for i in range(n_models)]
except ValueError: # Fallback if tab10 not found (unlikely)
cmap_obj = get_cmap("viridis")
plot_colors = [cmap_obj(i / n_models) for i in range(n_models)]
else:
plot_colors = colors # Use user-provided list
# Plot each model
for i, row in enumerate(results_scaled):
values = np.concatenate((row, [row[0]])) # Close the polygon
color = plot_colors[i % len(plot_colors)] # Cycle colors if needed
ax.plot(
angles_closed,
values,
label=names[i],
color=color,
linewidth=1.5,
alpha=alpha,
)
ax.fill(angles_closed, values, color=color, alpha=0.1) # Lighter fill
# --- Configure Axes ---
ax.set_xticks(angles)
ax.set_xticklabels(metric_names)
# Adjust radial limits and labels
# If scaled to [0, 1], set limit slightly above 1
# If std scaled, auto-limit might be better, but respect lower_bound
if scale in ["norm", "min-max"]:
ax.set_ylim(bottom=lower_bound, top=1.05)
# Optional: Add radial ticks for [0, 1] scale
ax.set_yticks(np.linspace(lower_bound, 1, 5))
else: # Raw or std scaled
ax.set_ylim(bottom=lower_bound)
# Let matplotlib auto-determine upper limit and ticks
ax.tick_params(axis="y", labelsize=8) # Smaller radial labels
ax.tick_params(axis="x", pad=10) # Pad angular labels outwards
# Grid
set_axis_grid(ax, show_grid=show_grid, grid_props=grid_props)
# Legend
if legend:
ax.legend(loc=loc, bbox_to_anchor=(1.25, 1.05)) # Adjust position
# Title
ax.set_title(title or "Model Performance Comparison", y=1.15, fontsize=14)
# --- Output ---
plt.tight_layout(pad=2.0) # Adjust layout
if savefig:
try:
plt.savefig(savefig, bbox_inches="tight", dpi=300)
print(f"Plot saved to {savefig}")
except Exception as e:
print(f"Error saving plot to {savefig}: {e}")
else:
try:
plt.show()
except Exception as e:
warnings.warn(
f"Could not display plot interactively ({e})."
f" Use savefig parameter.",
UserWarning,
stacklevel=2,
)
return ax
[docs]
@check_non_emptiness
@isdf
def plot_horizon_metrics(
    df: pd.DataFrame,
    qlow_cols: list[str],
    qup_cols: list[str],
    *,
    q50_cols: list[str] | None = None,
    xtick_labels: list[str] | None = None,
    normalize_radius: bool = False,
    show_value_labels: bool = True,
    cbar_label: str | None = None,
    r_label: str | None = None,
    cmap: str = "coolwarm",
    acov: str = "default",
    title: str | None = None,
    figsize: tuple[float, float] = (8, 8),
    alpha: float = 0.85,
    show_grid: bool = True,
    grid_props: dict | None = None,
    mask_angle: bool = False,
    savefig: str | None = None,
    dpi: int = 300,
    cbar: bool = True,
):
    # Draw one polar bar per row of ``df``. The bar height is the row-wise
    # mean of ``df[qup_cols] - df[qlow_cols]``; the bar colour is the
    # row-wise mean of ``df[q50_cols]`` when given, otherwise the bar
    # height itself. Returns the polar Axes.
    #
    # Raises ValueError when the quantile column lists differ in length or
    # when ``xtick_labels`` does not provide exactly one label per row.
    #
    # NOTE: the user-facing docstring is assigned to
    # ``plot_horizon_metrics.__doc__`` right after this definition.

    # --- Input Validation ---
    if len(qlow_cols) != len(qup_cols):
        raise ValueError(
            "Mismatch in length between `qlow_cols` "
            f"({len(qlow_cols)}) and `qup_cols` ({len(qup_cols)})."
        )
    if q50_cols and len(qlow_cols) != len(q50_cols):
        raise ValueError(
            "Mismatch in length: `q50_cols` must match other "
            "quantile column lists."
        )

    num_bars = len(df)

    # Use an explicit None/length test instead of truthiness so that
    # array-like label containers (ambiguous truth value) are accepted.
    # An empty container still means "no custom labels", as before.
    has_tick_labels = xtick_labels is not None and len(xtick_labels) > 0
    if has_tick_labels and len(xtick_labels) != num_bars:
        # Fail fast, before any figure exists: Matplotlib would otherwise
        # reject the mismatch later with a less helpful message and leave
        # an orphan figure open.
        raise ValueError(
            "`xtick_labels` must contain one label per row of `df`: "
            f"got {len(xtick_labels)} labels for {num_bars} rows."
        )

    # --- Data Calculation ---
    qlow_data = df[qlow_cols].values
    qup_data = df[qup_cols].values
    interval_widths = qup_data - qlow_data

    # Radial values are the mean width for each category/row
    radial_values = np.mean(interval_widths, axis=1)

    if q50_cols:
        color_vals = np.mean(df[q50_cols].values, axis=1)
    else:
        # Default color to the radial value if no q50 provided
        color_vals = radial_values

    if normalize_radius:
        min_r, max_r = radial_values.min(), radial_values.max()
        # Leave the values untouched when they are (almost) constant to
        # avoid dividing by a zero range.
        if (max_r - min_r) > 1e-9:
            radial_values = (radial_values - min_r) / (max_r - min_r)

    # --- Plot Setup ---
    angular_map = {
        "default": 2 * np.pi,
        "half_circle": np.pi,
        "quarter_circle": np.pi / 2,
        "eighth_circle": np.pi / 4,
    }
    # Unknown coverage names silently fall back to the full circle.
    span = angular_map.get(acov.lower(), 2 * np.pi)
    theta = np.linspace(0, span, num_bars, endpoint=False)

    fig, ax = plt.subplots(
        figsize=figsize, subplot_kw={"projection": "polar"}
    )
    # Start at the top (12 o'clock) and proceed clockwise.
    ax.set_theta_offset(np.pi / 2)
    ax.set_theta_direction(-1)
    ax.set_thetamin(0)
    ax.set_thetamax(np.degrees(span))

    # --- Plot Bars ---
    norm = Normalize(vmin=color_vals.min(), vmax=color_vals.max())
    cmap_obj = get_cmap(cmap, default="coolwarm")
    colors = cmap_obj(norm(color_vals))
    bar_width = (span / num_bars) * 0.9  # leave a 10% gap between bars

    ax.bar(
        theta,
        radial_values,
        width=bar_width,
        color=colors,
        edgecolor="k",
        alpha=alpha,
        linewidth=0.5,
    )

    # --- Annotations and Labels ---
    if show_value_labels:
        # The offset is the same for every bar; compute it once.
        label_offset = 0.03 * radial_values.max()
        for angle, radius in zip(theta, radial_values):
            ax.text(
                angle,
                radius + label_offset,
                f"{radius:.2f}",
                ha="center",
                va="bottom",
                fontsize=8,
            )

    if has_tick_labels:
        ax.set_xticks(theta)
        ax.set_xticklabels(xtick_labels)
    elif mask_angle:
        ax.set_xticklabels([])

    ax.set_yticklabels([])  # Hide radial value ticks
    ax.set_title(title or "Polar Bar Comparison", fontsize=14)
    if r_label:
        ax.set_ylabel(r_label, fontsize=12, labelpad=20)

    set_axis_grid(ax, show_grid, grid_props=grid_props)

    if cbar:
        sm = cm.ScalarMappable(cmap=cmap_obj, norm=norm)
        # Operate on this function's own figure rather than on pyplot's
        # notion of the "current" figure.
        cbar_obj = fig.colorbar(sm, ax=ax, pad=0.1, shrink=0.7)
        cbar_obj.set_label(cbar_label or "Color Metric", fontsize=10)

    # --- Output ---
    fig.tight_layout()
    if savefig:
        fig.savefig(savefig, dpi=dpi, bbox_inches="tight")
        plt.close(fig)
    else:
        plt.show()

    return ax
plot_horizon_metrics.__doc__ = r"""
Plot a polar bar chart comparing metrics across different horizons.
This function visualizes a primary metric (typically **mean
interval width**) as the height of bars arranged in a circle.
Each bar represents a distinct category or forecast horizon. A
secondary metric (typically the **mean Q50 value**) can be encoded
as the color of the bars, providing a multi-faceted comparison.
Parameters
----------
df : pd.DataFrame
Input DataFrame where each **row** represents a distinct
horizon or category to be compared.
qlow_cols : list of str
List of column names containing lower quantile samples
(e.g., Q10) for each horizon.
qup_cols : list of str
List of column names containing upper quantile samples
(e.g., Q90). Must have the same length as ``qlow_cols``.
q50_cols : list of str, optional
List of column names for the median quantile (Q50). If
provided, the mean of these values determines the bar color.
If ``None``, bar color is determined by the bar height
(the mean interval width).
xtick_labels : list of str, optional
Custom labels for each bar on the angular axis. The length
must match the number of rows in ``df``. If ``None``,
Matplotlib's default angular tick labels are kept unless
``mask_angle=True``.
normalize_radius : bool, default=False
If ``True``, the radial values (bar heights) are min-max
scaled to the range ``[0, 1]``.
show_value_labels : bool, default=True
If ``True``, display the numeric value of the radial metric
on top of each bar.
cbar_label : str, optional
Custom label for the color bar. If ``None``, the label
``'Color Metric'`` is used.
r_label : str, optional
Custom label for the radial axis.
cmap : str, default='coolwarm'
Matplotlib colormap name for coloring the bars.
acov : {'default', 'half_circle', 'quarter_circle', \
'eighth_circle'}, default='default'
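Specifies the angular coverage of the plot: ``'default'``
(360°), ``'half_circle'`` (180°), ``'quarter_circle'`` (90°)
or ``'eighth_circle'`` (45°). Matching is case-insensitive and
any unrecognized value falls back to the full circle.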
title : str, optional
Title for the plot. If ``None``, ``'Polar Bar Comparison'`` is used.
figsize : tuple of (float, float), default=(8, 8)
Figure size in inches.
alpha : float, default=0.85
Transparency level for the bars.
show_grid : bool, default=True
Toggle gridlines via the package helper ``set_axis_grid``.
grid_props : dict, optional
Keyword arguments passed to ``set_axis_grid`` for grid
customization.
mask_angle : bool, default=False
If ``True`` and ``xtick_labels`` is not provided, this will
hide any default angular tick labels.
savefig : str, optional
If provided, save the figure to this path; otherwise the
plot is shown interactively.
dpi : int, default=300
Resolution for the saved figure.
cbar : bool, default=True
If ``True``, display a color bar.
Returns
-------
ax : matplotlib.axes.Axes
The Matplotlib Axes object containing the polar bar plot.
Notes
-----
The plot summarizes metrics for :math:`N` horizons (rows)
using data from :math:`M` samples (columns). Let
:math:`\mathbf{L}`, :math:`\mathbf{U}`, and :math:`\mathbf{Q50}`
be data matrices of shape :math:`(N, M)` extracted from the
corresponding columns.
1. **Interval Width Calculation**: For each horizon :math:`j`
and sample :math:`i`, the interval width is:
.. math::
W_{j,i} = U_{j,i} - L_{j,i}
2. **Radial Value (Bar Height)**: The radial value :math:`r_j`
for horizon :math:`j` is the mean interval width across
all :math:`M` samples.
.. math::
r_j = \frac{1}{M} \sum_{i=0}^{M-1} W_{j,i}
3. **Color Value**: The color value :math:`c_j` for horizon
:math:`j` is determined by the mean of the ``q50_cols`` values.
.. math::
c_j = \frac{1}{M} \sum_{i=0}^{M-1} Q50_{j,i}
If ``q50_cols`` is not provided, the color defaults to the
radial value, :math:`c_j = r_j`.
4. **Angular Position**: Horizons are spaced evenly around the
circle. For horizon :math:`j`, the angle is:
.. math::
\theta_j = \frac{j}{N} \times S
where :math:`S` is the angular span from ``acov``. The plot
starts at the top (12 o'clock) and proceeds clockwise.
Examples
--------
>>> import numpy as np
>>> import pandas as pd
>>> from kdiagram.plot import plot_horizon_metrics
>>>
>>> # Create synthetic data for 6 horizons with 2 samples each
>>> horizons = ["H+1", "H+2", "H+3", "H+4", "H+5", "H+6"]
>>> df = pd.DataFrame({
... 'q10_s1': [1, 2, 3, 4, 5, 6],
... 'q10_s2': [1.2, 2.3, 3.4, 4.5, 5.6, 6.7],
... 'q90_s1': [3, 4, 5.5, 7, 8, 9.5],
... 'q90_s2': [3.1, 4.2, 5.7, 7.3, 8.4, 9.9],
... 'q50_s1': [2, 3, 4.2, 5.7, 6.5, 8.2],
... 'q50_s2': [2.1, 3.2, 4.4, 5.9, 6.9, 8.8],
... })
>>>
>>> q10_cols = ['q10_s1', 'q10_s2']
>>> q90_cols = ['q90_s1', 'q90_s2']
>>> q50_cols = ['q50_s1', 'q50_s2']
>>>
>>> ax = plot_horizon_metrics(
... df=df,
... qlow_cols=q10_cols,
... qup_cols=q90_cols,
... q50_cols=q50_cols,
... title="Mean Interval Width Across Horizons",
... xtick_labels=horizons,
... show_value_labels=True,
... r_label="Mean Interval Width (Q90-Q10)",
... cbar_label="Mean Q50 Value",
... acov="default"
... )
"""