Source code for kdiagram.plot.probabilistic

# License: Apache 2.0
# Author: LKouadio <etanoyau@gmail.com>

"""
Probabilistic Forecast Evaluation Plots
"""
from __future__ import annotations

import warnings
from typing import Any

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from scipy.stats import kstest, uniform

from ..compat.matplotlib import get_cmap
from ..compat.sklearn import validate_params
from ..decorators import check_non_emptiness, isdf
from ..utils.plot import set_axis_grid
from ..utils.validator import exist_features, validate_yy

__all__ = [
    "plot_crps_comparison",
    "plot_pit_histogram",
    "plot_polar_sharpness",
    "plot_credibility_bands",
    "plot_calibration_sharpness",
]


# NOTE: the user-facing docstring is attached to ``__doc__`` right after
# this definition, so only implementation comments live here.
@validate_params(
    {
        "y_true": ["array-like"],
        "y_preds_quantiles": ["array-like"],
        "quantiles": ["array-like"],
    }
)
@check_non_emptiness(params=["y_true", "y_preds_quantiles"])
def plot_pit_histogram(
    y_true: np.ndarray,
    y_preds_quantiles: np.ndarray,
    quantiles: np.ndarray,
    *,
    n_bins: int = 10,
    title: str = "PIT Histogram",
    figsize: tuple[float, float] = (8, 8),
    color: str = "#3498DB",
    edgecolor: str = "black",
    alpha: float = 0.7,
    show_uniform_line: bool = True,
    show_grid: bool = True,
    grid_props: dict[str, Any] | None = None,
    mask_radius: bool = False,
    savefig: str | None = None,
    dpi: int = 300,
):
    # --- Input Validation ---
    y_true, y_preds_quantiles = validate_yy(
        y_true,
        y_preds_quantiles,
        expected_type=None,
        allow_2d_pred=True,
    )
    quantiles = np.asarray(quantiles)
    # Same guard as ``plot_polar_sharpness``: a nested/2D ``quantiles``
    # would make ``len(quantiles)`` count rows instead of levels.
    if quantiles.ndim != 1:
        raise ValueError("`quantiles` must be a 1D array.")
    if y_preds_quantiles.shape[1] != len(quantiles):
        raise ValueError(
            "Shape mismatch: Number of columns in y_preds_quantiles "
            f"({y_preds_quantiles.shape[1]}) must match the number of "
            f"provided quantiles ({len(quantiles)})."
        )

    # --- PIT Calculation ---
    # For each observation, the PIT is the fraction of forecast quantiles
    # that are <= the true value. The row-wise mean does not depend on the
    # column order, so the columns need not be sorted by quantile level.
    pit_values = np.mean(
        y_preds_quantiles <= y_true[:, np.newaxis], axis=1
    )

    # --- Histogram Calculation ---
    hist, bin_edges = np.histogram(pit_values, bins=n_bins, range=(0, 1))
    bin_centers = (bin_edges[:-1] + bin_edges[1:]) / 2

    # --- Plotting ---
    fig, ax = plt.subplots(
        figsize=figsize, subplot_kw={"projection": "polar"}
    )

    # Angles are the PIT bins mapped onto the full circle; the radius of
    # each bar is the number of observations falling in that bin.
    angles = bin_centers * 2 * np.pi
    width = (2 * np.pi) / n_bins
    ax.bar(
        angles,
        hist,
        width=width,
        color=color,
        edgecolor=edgecolor,
        alpha=alpha,
        label="PIT Frequency",
    )

    # Reference circle: the count every bin would hold if the PIT values
    # were perfectly uniform.
    if show_uniform_line:
        expected_count = len(y_true) / n_bins
        ax.plot(
            np.linspace(0, 2 * np.pi, 100),
            [expected_count] * 100,
            color="red",
            linestyle="--",
            lw=2,
            label=f"Uniform ({expected_count:.1f})",
        )

    # --- Formatting ---
    ax.set_title(title, fontsize=14, y=1.1)
    # One angular tick per bin, labelled with the bin's lower PIT edge.
    ax.set_xticks(np.linspace(0, 2 * np.pi, n_bins, endpoint=False))
    ax.set_xticklabels([f"{edge:.1f}" for edge in bin_edges[:-1]])
    ax.set_xlabel("PIT Value Bins")
    ax.set_ylabel("Frequency")
    ax.legend(loc="upper right", bbox_to_anchor=(1.3, 1.1))
    set_axis_grid(ax, show_grid=show_grid, grid_props=grid_props)

    if mask_radius:
        ax.set_yticklabels([])

    plt.tight_layout()
    if savefig:
        # Saving is non-interactive: write the file and release the figure.
        plt.savefig(savefig, dpi=dpi, bbox_inches="tight")
        plt.close(fig)
    else:
        plt.show()

    return ax
plot_pit_histogram.__doc__ = r"""
Plots a Polar Probability Integral Transform (PIT) Histogram.

This function creates a polar bar chart of PIT values to diagnose
the calibration of a probabilistic forecast. For a perfectly
calibrated forecast, the PIT histogram is uniform, which results
in a perfect circle on the polar plot. Deviations from this shape
indicate specific model biases.

Parameters
----------
y_true : np.ndarray
    1D array of observed (true) values.
y_preds_quantiles : np.ndarray
    2D array of quantile forecasts. Each row corresponds to an
    observation in ``y_true``, and each column is a specific
    quantile forecast.
quantiles : np.ndarray
    1D array of the quantile levels corresponding to the columns
    of ``y_preds_quantiles`` (e.g., ``[0.05, 0.1, ..., 0.95]``).
n_bins : int, default=10
    Number of bins for the histogram, which will correspond to
    the angular sectors in the polar plot.
title : str, default="PIT Histogram"
    The title for the plot.
figsize : tuple of (float, float), default=(8, 8)
    The figure size in inches.
color : str, default="#3498DB"
    The fill color for the histogram bars.
edgecolor : str, default="black"
    The edge color for the histogram bars.
alpha : float, default=0.7
    The transparency of the histogram bars.
show_uniform_line : bool, default=True
    If ``True``, draws a reference circle indicating the expected
    frequency for a perfectly uniform (calibrated) distribution.
show_grid : bool, default=True
    Toggle the visibility of the polar grid lines.
grid_props : dict, optional
    Custom keyword arguments passed to the grid for styling.
mask_radius : bool, default=False
    If ``True``, hide the radial tick labels.
savefig : str, optional
    The file path to save the plot. If ``None``, the plot is
    displayed interactively.
dpi : int, default=300
    The resolution (dots per inch) for the saved figure.

Returns
-------
ax : matplotlib.axes.Axes
    The Matplotlib Axes object containing the plot.

Raises
------
ValueError
    If the number of columns in ``y_preds_quantiles`` does not
    match the number of provided ``quantiles``.

Notes
-----
The Probability Integral Transform (PIT) is a fundamental tool for
evaluating the calibration of probabilistic forecasts
:footcite:p:`Gneiting2007b`. For a continuous predictive
distribution with CDF :math:`F`, the PIT value for an observation
:math:`y` is :math:`F(y)`. If the forecast is perfectly calibrated,
the PIT values are uniformly distributed on :math:`[0, 1]`.

When the predictive CDF is represented by a finite set of
:math:`M` quantiles, the PIT value for each observation
:math:`y_i` is approximated as the fraction of forecast quantiles
that are less than or equal to the observation:

.. math::

    \text{PIT}_i = \frac{1}{M} \sum_{j=1}^{M}
    \mathbf{1}\{q_{i,j} \le y_i\}

where :math:`q_{i,j}` is the :math:`j`-th quantile forecast for
observation :math:`i`, and :math:`\mathbf{1}` is the indicator
function.

This count-based estimate depends only on how many quantile
forecasts lie at or below the observation; the quantile *levels*
passed in ``quantiles`` do not enter the computation. Each PIT
value is therefore one of the :math:`M + 1` fractions
:math:`0, 1/M, \dots, 1`, and the approximation is most faithful
when the levels are evenly spaced across :math:`(0, 1)`.

Deviations from a uniform (flat) histogram indicate miscalibration:

- **U-shaped**: The forecast is overconfident (too narrow).
- **Hump-shaped**: The forecast is underconfident (too wide).
- **Sloped**: The forecast is biased.

Examples
--------
>>> import numpy as np
>>> from scipy.stats import norm
>>> from kdiagram.plot.probabilistic import plot_pit_histogram
>>>
>>> # Generate synthetic data
>>> np.random.seed(42)
>>> n_samples = 1000
>>> y_true = np.random.normal(loc=10, scale=5, size=n_samples)
>>> quantiles = np.linspace(0.05, 0.95, 19)
>>>
>>> # A well-calibrated forecast
>>> calibrated_preds = norm.ppf(
...     quantiles, loc=y_true[:, np.newaxis], scale=5
... )
>>>
>>> # Generate the plot
>>> ax = plot_pit_histogram(
...     y_true,
...     calibrated_preds,
...     quantiles,
...     title="PIT Histogram (Well-Calibrated Model)"
... )

References
----------
.. footbibliography::
"""
# The user-facing docstring for this function is attached to ``__doc__``
# further down in the module.
@validate_params(
    {
        "quantiles": ["array-like"],
    }
)
@check_non_emptiness(params=["y_preds_quantiles"])
def plot_polar_sharpness(
    *y_preds_quantiles: np.ndarray,
    quantiles: np.ndarray,
    names: list[str] | None = None,
    title: str = "Forecast Sharpness Comparison",
    figsize: tuple[float, float] = (8, 8),
    cmap: str = "viridis",
    marker: str = "o",
    s: int = 100,
    show_grid: bool = True,
    grid_props: dict[str, Any] | None = None,
    mask_radius: bool = False,
    savefig: str | None = None,
    dpi: int = 300,
):
    # --- Guard clauses on the inputs ---
    if not y_preds_quantiles:
        raise ValueError("At least one prediction array must be provided.")

    quantiles = np.asarray(quantiles)
    if quantiles.ndim != 1:
        raise ValueError("`quantiles` must be a 1D array.")

    n_models = len(y_preds_quantiles)
    if names and len(names) != n_models:
        # A mismatching label list is discarded rather than treated as
        # an error; generic labels are generated below instead.
        warnings.warn(
            "Number of names does not match number of models. Using defaults.",
            stacklevel=2,
        )
        names = None
    if not names:
        names = [f"Model {k}" for k in range(1, n_models + 1)]

    # --- One sharpness score per model ---
    n_levels = len(quantiles)
    mean_widths = []
    for forecast in y_preds_quantiles:
        forecast = np.asarray(forecast)
        if forecast.shape[1] != n_levels:
            raise ValueError(
                "Prediction array shape mismatch with quantiles."
            )
        # The interval is spanned by the columns of the smallest and the
        # largest quantile level (e.g. 5% and 95%).
        lowest = forecast[:, np.argmin(quantiles)]
        highest = forecast[:, np.argmax(quantiles)]
        mean_widths.append(np.mean(highest - lowest))

    # --- Draw one marker per model, spread evenly around the circle ---
    fig, ax = plt.subplots(
        figsize=figsize, subplot_kw={"projection": "polar"}
    )
    thetas = np.linspace(0, 2 * np.pi, n_models, endpoint=False)
    palette = get_cmap(cmap, default="viridis")(
        np.linspace(0, 1, n_models)
    )
    ax.scatter(
        thetas, mean_widths, c=palette, s=s, marker=marker, zorder=3
    )

    # Annotate every marker with the model label and its score.
    for theta, score, label in zip(thetas, mean_widths, names):
        ax.text(
            theta,
            score,
            f" {label}\n ({score:.2f})",
            ha="left",
            va="center",
            fontsize=9,
        )

    # --- Axes cosmetics ---
    ax.set_title(title, fontsize=14, y=1.1)
    ax.set_xticks([])  # the angle carries no information here
    ax.set_ylabel("Average Interval Width (Sharpness)")
    ax.set_ylim(bottom=0)
    set_axis_grid(ax, show_grid=show_grid, grid_props=grid_props)
    if mask_radius:
        ax.set_yticklabels([])

    plt.tight_layout()
    if not savefig:
        plt.show()
    else:
        plt.savefig(savefig, dpi=dpi, bbox_inches="tight")
        plt.close(fig)

    return ax
def _calculate_crps(y_true, y_preds_quantiles, quantiles):
    """
    Approximates the CRPS using the pinball loss averaged over quantiles.

    Parameters
    ----------
    y_true : array-like
        Observed values, indexed along the first axis.
    y_preds_quantiles : array-like
        Quantile forecasts; broadcast against ``y_true[:, np.newaxis]``
        and ``quantiles``.
    quantiles : array-like
        Quantile levels, one per forecast column.

    Returns
    -------
    float
        Mean pinball loss, averaged first over the quantiles of each
        observation and then over all observations.
    """
    # Coerce up front so that plain Python sequences work as well as
    # arrays (``list[:, np.newaxis]`` and ``1 - list`` would both fail).
    y_true = np.asarray(y_true)[:, np.newaxis]  # column for broadcasting
    y_preds_quantiles = np.asarray(y_preds_quantiles)
    quantiles = np.asarray(quantiles)

    # Under-prediction is weighted by tau, over-prediction by (1 - tau).
    pinball_loss = np.where(
        y_true >= y_preds_quantiles,
        (y_true - y_preds_quantiles) * quantiles,
        (y_preds_quantiles - y_true) * (1 - quantiles),
    )
    # Average over quantiles for each observation, then over observations.
    return np.mean(np.mean(pinball_loss, axis=1))


plot_polar_sharpness.__doc__ = r"""
Plots a Polar Sharpness Diagram to compare forecast precision.

This function creates a polar plot to visually compare the
sharpness of one or more probabilistic forecasts. Sharpness is a
measure of the concentration of the predictive distribution,
typically quantified by the average width of the prediction
intervals. Sharper (more precise) forecasts are represented by
points closer to the center of the plot.

Parameters
----------
*y_preds_quantiles : np.ndarray
    One or more 2D arrays of quantile forecasts. Each array
    corresponds to a different model, with shape
    ``(n_samples, n_quantiles)``.
quantiles : np.ndarray
    1D array of the quantile levels corresponding to the columns
    of the prediction arrays.
names : list of str, optional
    Display names for each of the models. If not provided,
    generic names like ``'Model 1'`` will be generated.
title : str, default="Forecast Sharpness Comparison"
    The title for the plot.
figsize : tuple of (float, float), default=(8, 8)
    The figure size in inches.
cmap : str, default='viridis'
    The colormap used to assign a unique color to each model's
    marker.
marker : str, default='o'
    The marker style for the points representing each model.
s : int, default=100
    The size of the markers.
show_grid : bool, default=True
    Toggle the visibility of the polar grid lines.
grid_props : dict, optional
    Custom keyword arguments passed to the grid for styling.
mask_radius : bool, default=False
    If ``True``, hide the radial tick labels.
savefig : str, optional
    The file path to save the plot. If ``None``, the plot is
    displayed interactively.
dpi : int, default=300
    The resolution (dots per inch) for the saved figure.

Returns
-------
ax : matplotlib.axes.Axes
    The Matplotlib Axes object containing the plot.

Raises
------
ValueError
    If no prediction array is provided, if ``quantiles`` is not
    one-dimensional, or if the number of columns of a prediction
    array differs from the number of quantile levels.

Notes
-----
A good probabilistic forecast should be both calibrated (reliable)
and as sharp as possible :footcite:p:`Gneiting2007b`. This diagram
focuses on sharpness, which is independent of the observed outcomes.

1. **Interval Width**: For each model and each observation
   :math:`i`, the width of the central prediction interval is
   calculated using the lowest and highest provided quantiles
   (:math:`q_{min}` and :math:`q_{max}`).

   .. math::

       w_i = y_{i, q_{max}} - y_{i, q_{min}}

2. **Sharpness Score**: The sharpness score :math:`S` for each
   model is the average of these interval widths over all
   :math:`N` observations. This score is used as the radial
   coordinate in the plot. A lower score is better.

   .. math::

       S = \frac{1}{N} \sum_{i=1}^{N} w_i

If ``names`` is given but its length differs from the number of
models, a warning is emitted and the generic names are used.

Examples
--------
>>> import numpy as np
>>> from scipy.stats import norm
>>> from kdiagram.plot.probabilistic import plot_polar_sharpness
>>>
>>> # Generate synthetic data for two models
>>> np.random.seed(0)
>>> n_samples = 500
>>> y_true = np.random.normal(loc=20, scale=5, size=n_samples)
>>> quantiles = np.linspace(0.1, 0.9, 9)  # 80% interval
>>>
>>> # A sharp (precise) forecast
>>> sharp_preds = norm.ppf(
...     quantiles, loc=y_true[:, np.newaxis], scale=2
... )
>>> # A wide (less precise) forecast
>>> wide_preds = norm.ppf(
...     quantiles, loc=y_true[:, np.newaxis], scale=5
... )
>>>
>>> # Generate the plot
>>> ax = plot_polar_sharpness(
...     sharp_preds,
...     wide_preds,
...     quantiles=quantiles,
...     names=["Sharp Model", "Wide Model"]
... )

References
----------
.. footbibliography::
"""
# NOTE: the user-facing docstring is attached to ``__doc__`` right after
# this definition, so only implementation comments live here.
@validate_params(
    {
        "y_true": ["array-like"],
        "quantiles": ["array-like"],
    }
)
@check_non_emptiness(params=["y_true", "y_preds_quantiles"])
def plot_crps_comparison(
    y_true: np.ndarray,
    *y_preds_quantiles: np.ndarray,
    quantiles: np.ndarray,
    names: list[str] | None = None,
    title: str = "Probabilistic Forecast Performance (CRPS)",
    figsize: tuple[float, float] = (8, 8),
    cmap: str = "viridis",
    marker: str = "o",
    s: int = 100,
    show_grid: bool = True,
    grid_props: dict[str, Any] | None = None,
    mask_radius: bool = False,
    savefig: str | None = None,
    dpi: int = 300,
):
    # --- Input Validation ---
    if not y_preds_quantiles:
        raise ValueError("At least one prediction array must be provided.")

    quantiles = np.asarray(quantiles)
    if quantiles.ndim != 1:
        raise ValueError("`quantiles` must be a 1D array.")

    num_models = len(y_preds_quantiles)
    if names and len(names) != num_models:
        warnings.warn(
            "Number of names does not match"
            " number of models. Using defaults.",
            stacklevel=2,
        )
        names = None
    if not names:
        names = [f"Model {i+1}" for i in range(num_models)]

    # --- CRPS Calculation ---
    crps_scores = []
    for preds in y_preds_quantiles:
        y_true_val, preds_val = validate_yy(
            y_true,
            preds,
            expected_type=None,
            allow_2d_pred=True,
        )
        # Without this check NumPy broadcasting would silently accept
        # e.g. an (n, 1) forecast against many quantile levels and
        # produce a meaningless score.
        if preds_val.ndim != 2 or preds_val.shape[1] != len(quantiles):
            raise ValueError(
                "Prediction array shape mismatch with quantiles."
            )
        crps_scores.append(
            _calculate_crps(y_true_val, preds_val, quantiles)
        )

    # --- Plotting ---
    fig, ax = plt.subplots(
        figsize=figsize, subplot_kw={"projection": "polar"}
    )

    # Models are spread evenly around the circle; the radius is the score.
    angles = np.linspace(0, 2 * np.pi, num_models, endpoint=False)
    radii = crps_scores

    cmap_obj = get_cmap(cmap, default="viridis")
    colors = cmap_obj(np.linspace(0, 1, num_models))

    ax.scatter(angles, radii, c=colors, s=s, marker=marker, zorder=3)

    # Add labels next to points
    for angle, radius, name in zip(angles, radii, names):
        ax.text(
            angle,
            radius,
            f" {name}\n ({radius:.3f})",
            ha="left",
            va="center",
            fontsize=9,
        )

    # --- Formatting ---
    ax.set_title(title, fontsize=14, y=1.1)
    ax.set_xticks([])  # the angle only separates the models
    ax.set_ylabel("Average CRPS (Lower is Better)")
    ax.set_ylim(bottom=0)
    set_axis_grid(ax, show_grid=show_grid, grid_props=grid_props)

    if mask_radius:
        ax.set_yticklabels([])

    plt.tight_layout()
    if savefig:
        # Saving is non-interactive: write the file and release the figure.
        plt.savefig(savefig, dpi=dpi, bbox_inches="tight")
        plt.close(fig)
    else:
        plt.show()

    return ax
plot_crps_comparison.__doc__ = r"""
Plots a Polar CRPS Comparison Diagram.

This function visualizes the overall performance of one or more
probabilistic forecasts using the Continuous Ranked Probability
Score (CRPS). The CRPS is a proper scoring rule that assesses
both calibration and sharpness simultaneously. A lower CRPS value
indicates a better forecast. In this plot, models closer to the
center are superior.

Parameters
----------
y_true : np.ndarray
    1D array of observed (true) values.
*y_preds_quantiles : np.ndarray
    One or more 2D arrays of quantile forecasts. Each array
    corresponds to a different model, with shape
    ``(n_samples, n_quantiles)``.
quantiles : np.ndarray
    1D array of the quantile levels corresponding to the columns
    of the prediction arrays.
names : list of str, optional
    Display names for each of the models. If not provided,
    generic names like ``'Model 1'`` will be generated.
title : str, default="Probabilistic Forecast Performance (CRPS)"
    The title for the plot.
figsize : tuple of (float, float), default=(8, 8)
    The figure size in inches.
cmap : str, default='viridis'
    The colormap used to assign a unique color to each model's
    marker.
marker : str, default='o'
    The marker style for the points representing each model.
s : int, default=100
    The size of the markers.
show_grid : bool, default=True
    Toggle the visibility of the polar grid lines.
grid_props : dict, optional
    Custom keyword arguments passed to the grid for styling.
mask_radius : bool, default=False
    If ``True``, hide the radial tick labels.
savefig : str, optional
    The file path to save the plot. If ``None``, the plot is
    displayed interactively.
dpi : int, default=300
    The resolution (dots per inch) for the saved figure.

Returns
-------
ax : matplotlib.axes.Axes
    The Matplotlib Axes object containing the plot.

Notes
-----
The Continuous Ranked Probability Score (CRPS) is a widely used
metric for evaluating probabilistic forecasts
:footcite:p:`Gneiting2007b`. For a single observation :math:`y`
and a predictive CDF :math:`F`, it is defined as:

.. math::

    \text{CRPS}(F, y) = \int_{-\infty}^{\infty}\\
    (F(x) - \mathbf{1}\{x \ge y\})^2 dx

where :math:`\mathbf{1}` is the Heaviside step function.

When the forecast is given as a set of :math:`M` quantiles
:math:`\{q_1, ..., q_M\}`, the CRPS can be approximated by
averaging the pinball loss :math:`\mathcal{L}_{\tau}` over the
quantile levels :math:`\tau \in \{ \tau_1, ..., \tau_M \}`:

.. math::

    \text{CRPS}(F, y) \approx \frac{1}{M} \sum_{j=1}^{M} 2\\
    \mathcal{L}_{\tau_j}(q_j, y)

The pinball loss for a quantile :math:`\tau` is:

.. math::

    \mathcal{L}_{\tau}(q, y) =
    \begin{cases}
    (y - q) \tau & \text{if } y \ge q \\
    (q - y) (1 - \tau) & \text{if } y < q
    \end{cases}

This function calculates the average score over all observations
for each model and plots it as the radial coordinate.

.. note::

    The value that is actually computed and plotted is the plain
    mean pinball loss,
    :math:`\frac{1}{M} \sum_{j} \mathcal{L}_{\tau_j}(q_j, y)`,
    i.e. *without* the constant factor of 2 shown in the
    approximation above. The plotted radii are therefore half of
    that approximation; the ranking of the models is unaffected.

Examples
--------
>>> import numpy as np
>>> from scipy.stats import norm
>>> from kdiagram.plot.probabilistic import plot_crps_comparison
>>>
>>> # Generate synthetic data
>>> np.random.seed(42)
>>> n_samples = 1000
>>> y_true = np.random.normal(loc=10, scale=5, size=n_samples)
>>> quantiles = np.linspace(0.05, 0.95, 19)
>>>
>>> # Create forecasts for three models
>>> good_preds = norm.ppf(
...     quantiles, loc=y_true[:, np.newaxis], scale=5
... )
>>> sharp_biased_preds = norm.ppf(
...     quantiles, loc=y_true[:, np.newaxis] - 2, scale=3
... )
>>> wide_preds = norm.ppf(
...     quantiles, loc=y_true[:, np.newaxis], scale=8
... )
>>>
>>> # Generate the plot
>>> ax = plot_crps_comparison(
...     y_true,
...     good_preds,
...     sharp_biased_preds,
...     wide_preds,
...     quantiles=quantiles,
...     names=["Good", "Sharp/Biased", "Wide"]
... )

References
----------
.. footbibliography::
"""
# The user-facing docstring for this function is attached to ``__doc__``
# further down in the module.
@check_non_emptiness(params=["df"])
@isdf
def plot_credibility_bands(
    df: pd.DataFrame,
    q_cols: tuple[str, str, str],
    theta_col: str,
    *,
    theta_period: float | None = None,
    theta_bins: int = 24,
    title: str = "Forecast Credibility Bands",
    figsize: tuple[float, float] = (8, 8),
    color: str = "#3498DB",
    show_grid: bool = True,
    grid_props: dict[str, Any] | None = None,
    mask_radius: bool = False,
    savefig: str | None = None,
    dpi: int = 300,
    **fill_kws,
):
    # --- Check the column specification and extract the working frame ---
    if len(q_cols) != 3:
        raise ValueError(
            "`q_cols` must be a tuple of three column names: "
            "(lower_q, median_q, upper_q)."
        )
    q_low_col, q_med_col, q_up_col = q_cols
    needed = [q_low_col, q_med_col, q_up_col, theta_col]
    exist_features(df, features=needed)

    # Work on a private copy restricted to complete rows.
    data = df[needed].dropna().copy()
    if data.empty:
        warnings.warn(
            "DataFrame is empty after dropping NaNs.",
            UserWarning,
            stacklevel=2,
        )
        return None

    # --- Map the angular variable onto [0, 2*pi] ---
    if theta_period:
        # Cyclical variable: wrap by the period.
        wrapped = data[theta_col] % theta_period
        data["theta_rad"] = (wrapped / theta_period) * 2 * np.pi
    else:
        # No period given: min-max scale the observed range.
        theta_lo = data[theta_col].min()
        theta_hi = data[theta_col].max()
        if (theta_hi - theta_lo) > 1e-9:
            scaled = (data[theta_col] - theta_lo) / (theta_hi - theta_lo)
            data["theta_rad"] = scaled * 2 * np.pi
        else:
            # (Nearly) constant variable: everything sits at angle 0.
            data["theta_rad"] = 0

    # --- Bin the angles and average each quantile column per bin ---
    edges = np.linspace(0, 2 * np.pi, theta_bins + 1)
    centers = (edges[:-1] + edges[1:]) / 2  # bins are labelled by center
    data["theta_bin"] = pd.cut(
        data["theta_rad"],
        bins=edges,
        labels=centers,
        include_lowest=True,
    )
    per_bin = data.groupby("theta_bin", observed=False)
    stats = per_bin.agg(
        {col: "mean" for col in (q_low_col, q_med_col, q_up_col)}
    ).reset_index()

    # --- Draw the median line and the shaded band ---
    fig, ax = plt.subplots(
        figsize=figsize, subplot_kw={"projection": "polar"}
    )
    bin_angles = stats["theta_bin"]

    ax.plot(
        bin_angles,
        stats[q_med_col],
        color="black",
        lw=2,
        label="Mean Median Forecast",
    )

    # A caller-supplied ``alpha`` overrides the default band opacity; it
    # is removed from ``fill_kws`` so it is not passed twice.
    band_alpha = fill_kws.pop("alpha", 0.3)
    ax.fill_between(
        bin_angles,
        stats[q_low_col],
        stats[q_up_col],
        color=color,
        alpha=band_alpha,
        label="Credibility Band",
        **fill_kws,
    )

    # --- Axes cosmetics ---
    ax.set_title(title, fontsize=16, y=1.1)
    ax.set_xlabel(f"Binned by {theta_col}")
    ax.set_ylabel("Forecast Value", labelpad=25)
    ax.legend(loc="upper right", bbox_to_anchor=(1.35, 1.1))
    set_axis_grid(ax, show_grid=show_grid, grid_props=grid_props)
    if mask_radius:
        ax.set_yticklabels([])

    plt.tight_layout()
    if not savefig:
        plt.show()
    else:
        plt.savefig(savefig, dpi=dpi, bbox_inches="tight")
        plt.close(fig)

    return ax
def _calculate_ks_statistic(pit_values):
    """
    Calculates the Kolmogorov-Smirnov statistic to measure deviation
    of PIT values from a perfect uniform distribution.

    Parameters
    ----------
    pit_values : array-like
        PIT values to compare against the uniform distribution.

    Returns
    -------
    float
        The KS statistic, or ``1.0`` when fewer than two PIT values
        are supplied.
    """
    if len(pit_values) < 2:
        return 1.0  # Max penalty for insufficient data

    # Compare the empirical distribution of the PIT values with the CDF
    # of ``scipy.stats.uniform``; the p-value is not needed.
    ks_statistic, _ = kstest(pit_values, uniform.cdf)
    return ks_statistic


plot_credibility_bands.__doc__ = r"""
Plots Polar Credibility Bands to visualize forecast uncertainty.

This function creates a polar plot that shows how the median
forecast and the prediction interval bounds change as a function
of another binned variable (e.g., month, hour). It is a
descriptive tool for understanding the structure of a model's
predictions and its uncertainty estimates.

Parameters
----------
df : pd.DataFrame
    The input DataFrame containing the forecast data.
q_cols : tuple of (str, str, str)
    A tuple of three column names for the lower quantile, the
    median (Q50), and the upper quantile, in that order.
theta_col : str
    The name of the column to bin against for the angular axis.
theta_period : float, optional
    The period of the cyclical data in ``theta_col`` (e.g., 24
    for hours, 12 for months). This ensures the data wraps
    correctly around the polar plot. If omitted, the values of
    ``theta_col`` are min-max scaled onto the full circle instead.
theta_bins : int, default=24
    The number of angular bins to group the data into.
title : str, default="Forecast Credibility Bands"
    The title for the plot.
figsize : tuple of (float, float), default=(8, 8)
    The figure size in inches.
color : str, default="#3498DB"
    The color for the shaded credibility band.
show_grid : bool, default=True
    Toggle the visibility of the polar grid lines.
grid_props : dict, optional
    Custom keyword arguments passed to the grid for styling.
mask_radius : bool, default=False
    If ``True``, hide the radial tick labels.
savefig : str, optional
    The file path to save the plot. If ``None``, the plot is
    displayed interactively.
dpi : int, default=300
    The resolution (dots per inch) for the saved figure.
**fill_kws
    Additional keyword arguments passed to the ``ax.fill_between``
    call for the shaded band (e.g., ``alpha``, which defaults to
    ``0.3``).

Returns
-------
ax : matplotlib.axes.Axes or None
    The Matplotlib Axes object containing the plot. ``None`` is
    returned, after a ``UserWarning``, when no rows remain once
    the rows with missing values in the required columns have
    been dropped.

Raises
------
ValueError
    If ``q_cols`` does not contain exactly three column names.

Notes
-----
This plot visualizes the conditional expectation of the forecast
quantiles. It is a novel visualization developed as part of the
analytics framework in :footcite:t:`kouadiob2025`.

1. **Binning**: The data is first partitioned into :math:`K` bins,
   :math:`B_k`, based on the values in ``theta_col``.

2. **Conditional Means**: For each bin :math:`B_k`, the mean of
   the lower quantile (:math:`\bar{q}_{low,k}`), median quantile
   (:math:`\bar{q}_{med,k}`), and upper quantile
   (:math:`\bar{q}_{up,k}`) are calculated.

   .. math::

       \bar{q}_{j,k} = \frac{1}{|B_k|} \sum_{i \in B_k} q_{j,i}

   where :math:`j \in \{\text{low, med, up}\}`.

3. **Visualization**: The plot displays:

   - A central line representing the mean median forecast
     (:math:`\bar{q}_{med,k}`).
   - A shaded band between the mean lower and upper bounds
     (:math:`\bar{q}_{low,k}` and :math:`\bar{q}_{up,k}`). The
     width of this band represents the average forecast
     sharpness for that bin.

Examples
--------
>>> import numpy as np
>>> import pandas as pd
>>> from kdiagram.plot.probabilistic import plot_credibility_bands
>>>
>>> # Simulate a forecast with seasonal uncertainty
>>> np.random.seed(0)
>>> n_points = 500
>>> month = np.random.randint(1, 13, n_points)
>>> median = 50 + 20 * np.sin((month - 3) * np.pi / 6)
>>> width = 10 + 8 * np.cos(month * np.pi / 6)**2
>>>
>>> df = pd.DataFrame({
...     'month': month,
...     'q50': median + np.random.randn(n_points),
...     'q10': median - width / 2,
...     'q90': median + width / 2,
... })
>>>
>>> # Generate the plot
>>> ax = plot_credibility_bands(
...     df=df,
...     q_cols=('q10', 'q50', 'q90'),
...     theta_col='month',
...     theta_period=12,
...     theta_bins=12,
...     title="Seasonal Forecast Credibility"
... )

References
----------
.. footbibliography::
"""
# NOTE: the user-facing docstring is attached to ``__doc__`` right after
# this definition, so only implementation comments live here.
@check_non_emptiness(params=["y_true", "y_preds_quantiles"])
def plot_calibration_sharpness(
    y_true: np.ndarray,
    *y_preds_quantiles: np.ndarray,
    quantiles: np.ndarray,
    names: list[str] | None = None,
    title: str = "Calibration vs. Sharpness Trade-off",
    figsize: tuple[float, float] = (8, 8),
    cmap: str = "viridis",
    marker: str = "o",
    s: int = 150,
    show_grid: bool = True,
    grid_props: dict[str, Any] | None = None,
    mask_radius: bool = False,
    savefig: str | None = None,
    dpi: int = 300,
):
    # --- Input Validation ---
    if not y_preds_quantiles:
        raise ValueError("At least one prediction array must be provided.")

    quantiles = np.asarray(quantiles)
    if quantiles.ndim != 1:
        raise ValueError("`quantiles` must be a 1D array.")

    num_models = len(y_preds_quantiles)
    # Same policy as the other comparison plots in this module: a label
    # list of the wrong length is dropped with a warning instead of
    # failing later with an IndexError while labelling the points.
    if names and len(names) != num_models:
        warnings.warn(
            "Number of names does not match number of models. Using defaults.",
            stacklevel=2,
        )
        names = None
    if not names:
        names = [f"Model {i+1}" for i in range(num_models)]

    # --- Score Calculation ---
    sharpness_scores = []
    calibration_scores = []
    for preds in y_preds_quantiles:
        y_true_val, preds_val = validate_yy(y_true, preds, allow_2d_pred=True)
        if preds_val.ndim != 2 or preds_val.shape[1] != len(quantiles):
            raise ValueError(
                "Prediction array shape mismatch with quantiles."
            )

        # 1. Sharpness (radius): mean width between the columns of the
        #    lowest and the highest quantile level.
        lower = preds_val[:, np.argmin(quantiles)]
        upper = preds_val[:, np.argmax(quantiles)]
        sharpness_scores.append(np.mean(upper - lower))

        # 2. Calibration error (angle): KS statistic of the PIT values.
        #    The PIT is the fraction of quantile forecasts <= the
        #    observation; a row-wise mean is independent of the column
        #    order, so no sorting by quantile level is required.
        pit_values = np.mean(
            preds_val <= y_true_val[:, np.newaxis], axis=1
        )
        calibration_scores.append(_calculate_ks_statistic(pit_values))

    # --- Plotting ---
    fig, ax = plt.subplots(
        figsize=figsize, subplot_kw={"projection": "polar"}
    )

    # Angle: 0 for perfect calibration (KS=0), 90 degrees for worst (KS=1)
    angles = np.array(calibration_scores) * (np.pi / 2)
    radii = np.array(sharpness_scores)

    cmap_obj = get_cmap(cmap, default="viridis")
    colors = cmap_obj(np.linspace(0, 1, num_models))

    ax.scatter(
        angles, radii, c=colors, s=s, marker=marker, zorder=3, alpha=0.8
    )

    # Add labels
    for angle, radius, name in zip(angles, radii, names):
        ax.text(
            angle,
            radius,
            f" {name}",
            ha="left",
            va="bottom",
            fontsize=9,
        )

    # --- Formatting ---
    ax.set_title(title, fontsize=16, y=1.1)
    ax.set_thetamin(0)
    ax.set_thetamax(90)  # Use a quarter circle for clarity
    ax.set_ylim(bottom=0)

    # Angular ticks are labelled in units of the calibration error (0..1)
    # rather than in degrees.
    ax.set_xticks(np.linspace(0, np.pi / 2, 5))
    ax.set_xticklabels([f"{val:.2f}" for val in np.linspace(0, 1, 5)])
    ax.set_xlabel("Calibration Error (Lower is Better)")
    ax.set_ylabel("Sharpness (Lower is Better)", labelpad=25)

    # Scatter markers carry no legend entries of their own, so proxy
    # handles are built to map each color to its model name.
    legend_elements = [
        plt.Line2D(
            [0],
            [0],
            marker=marker,
            color=model_color,
            label=model_name,
            linestyle="None",
            markersize=10,
        )
        for model_color, model_name in zip(colors, names)
    ]
    ax.legend(
        handles=legend_elements, loc="upper right", bbox_to_anchor=(1.35, 1.1)
    )

    set_axis_grid(ax, show_grid=show_grid, grid_props=grid_props)

    if mask_radius:
        ax.set_yticklabels([])

    plt.tight_layout()
    if savefig:
        # Saving is non-interactive: write the file and release the figure.
        plt.savefig(savefig, dpi=dpi, bbox_inches="tight")
        plt.close(fig)
    else:
        plt.show()

    return ax
plot_calibration_sharpness.__doc__ = r"""
Plots a Polar Calibration-Sharpness Diagram.

This function creates a polar plot to visualize the fundamental
trade-off between forecast **calibration** (reliability) and
**sharpness** (precision) for one or more models. Each model is
represented by a single point, allowing for a direct and intuitive
comparison of their overall probabilistic performance.

The ideal forecast is located at the center of the plot,
representing perfect calibration and perfect sharpness.

Parameters
----------
y_true : np.ndarray
    1D array of observed (true) values.
*y_preds_quantiles : np.ndarray
    One or more 2D arrays of quantile forecasts. Each array
    corresponds to a different model, with shape
    ``(n_samples, n_quantiles)``.
quantiles : np.ndarray
    1D array of the quantile levels corresponding to the columns
    of the prediction arrays.
names : list of str, optional
    Display names for each of the models. If not provided,
    generic names like ``'Model 1'`` will be generated.
title : str, default="Calibration vs. Sharpness Trade-off"
    The title for the plot.
figsize : tuple of (float, float), default=(8, 8)
    The figure size in inches.
cmap : str, default='viridis'
    The colormap used to assign a unique color to each model's
    marker.
marker : str, default='o'
    The marker style for the points representing each model.
s : int, default=150
    The size of the markers.
show_grid : bool, default=True
    Toggle the visibility of the polar grid lines.
grid_props : dict, optional
    Custom keyword arguments passed to the grid for styling.
mask_radius : bool, default=False
    If ``True``, hide the radial tick labels.
savefig : str, optional
    The file path to save the plot. If ``None``, the plot is
    displayed interactively.
dpi : int, default=300
    The resolution (dots per inch) for the saved figure.

Returns
-------
ax : matplotlib.axes.Axes
    The Matplotlib Axes object containing the plot.

Notes
-----
This plot synthesizes two key aspects of a probabilistic forecast
into a single point for each model. It is a novel visualization
developed as part of the analytics framework in
:footcite:t:`kouadiob2025`.

1. **Sharpness (Radius)**: The radial coordinate represents the
   forecast's sharpness, calculated as the average width of the
   prediction interval between the lowest and highest provided
   quantiles. A smaller radius is better (sharper).

   .. math::

       S = \frac{1}{N} \sum_{i=1}^{N} (y_{i, q_{max}} - y_{i, q_{min}})

2. **Calibration Error (Angle)**: The angular coordinate
   represents the forecast's calibration error. This is quantified
   by first calculating the Probability Integral Transform (PIT)
   values for each observation. The Kolmogorov-Smirnov (KS)
   statistic is then used to measure the maximum distance between
   the empirical CDF of these PIT values and the CDF of a perfect
   uniform distribution.

   .. math::

       E_{calib} = \sup_{x} | F_{PIT}(x) - U(x) |

   An error of 0 indicates perfect calibration. The angle is
   mapped such that :math:`\theta = E_{calib} \cdot \frac{\pi}{2}`,
   so 0° is perfect and 90° is the worst possible calibration.

   Two practical consequences of how the error is computed:

   - The PIT value of an observation is the fraction of the
     :math:`M` quantile forecasts that do not exceed it, so it
     can only take the values :math:`0, 1/M, \dots, 1`. Against
     a continuous uniform distribution the KS statistic of such
     values is at least :math:`1/(2M)`; even a perfectly
     calibrated model is therefore drawn at a small positive
     angle, which shrinks as more quantiles are supplied.
   - With fewer than two observations the error is set to the
     maximum value of 1 (i.e. 90°).

Examples
--------
>>> import numpy as np
>>> from scipy.stats import norm
>>> from kdiagram.plot.probabilistic import plot_calibration_sharpness
>>>
>>> # Generate synthetic data
>>> np.random.seed(42)
>>> n_samples = 1000
>>> y_true = np.random.normal(loc=10, scale=5, size=n_samples)
>>> quantiles = np.linspace(0.05, 0.95, 19)
>>>
>>> # Create forecasts for three models with different trade-offs
>>> model_A = norm.ppf(
...     quantiles, loc=y_true[:, np.newaxis], scale=5
... )  # Balanced
>>> model_B = norm.ppf(
...     quantiles, loc=y_true[:, np.newaxis] - 2, scale=3
... )  # Sharp but biased
>>> model_C = norm.ppf(
...     quantiles, loc=y_true[:, np.newaxis], scale=8
... )  # Calibrated but wide
>>>
>>> # Generate the plot
>>> ax = plot_calibration_sharpness(
...     y_true,
...     model_A, model_B, model_C,
...     quantiles=quantiles,
...     names=["Balanced", "Sharp/Biased", "Calibrated/Wide"]
... )

References
----------
.. footbibliography::
"""