# Source code for kdiagram.plot.relationship

# License: Apache 2.0
# Author: LKouadio <etanoyau@gmail.com>
from __future__ import annotations

import warnings
from typing import Any

import matplotlib.pyplot as plt
import numpy as np

from ..compat.matplotlib import get_cmap
from ..compat.sklearn import StrOptions, validate_params
from ..utils.generic_utils import drop_nan_in
from ..utils.plot import set_axis_grid
from ..utils.validator import validate_yy

# Public API of this module: the polar relationship / diagnostic plotting
# functions defined below. Only these names are exported by
# ``from kdiagram.plot.relationship import *``.
__all__ = [
    "plot_relationship",
    "plot_conditional_quantiles",
    "plot_residual_relationship",
    "plot_error_relationship",
]


@validate_params(
    {
        "y_true": ["array-like"],
    }
)
def plot_residual_relationship(
    y_true: np.ndarray,
    *y_preds: np.ndarray,
    names: list[str] | None = None,
    title: str = "Residual vs. Predicted Relationship",
    figsize: tuple[float, float] = (8, 8),
    cmap: str = "viridis",
    s: int = 50,
    alpha: float = 0.7,
    show_zero_line: bool = True,
    show_grid: bool = True,
    grid_props: dict[str, Any] | None = None,
    savefig: str | None = None,
    dpi: int = 300,
):
    # --- Input Validation and Preparation ---
    if not y_preds:
        raise ValueError("At least one prediction array must be provided.")

    y_true, *y_preds = drop_nan_in(y_true, *y_preds, error="raise")

    # Validate *every* prediction against y_true (not only the first one)
    # so that a mis-sized series fails here instead of silently
    # broadcasting in the subtraction below.
    y_true_val = None
    preds = []
    for yp in y_preds:
        checked_true, checked_pred = validate_yy(y_true, yp)
        if y_true_val is None:
            # Plain ndarray so that the indexing below is positional.
            y_true_val = np.asarray(checked_true)
        preds.append(np.asarray(checked_pred))

    # Pad missing names with generic ones so a short ``names`` list cannot
    # raise IndexError when the legend labels are looked up.
    n_models = len(preds)
    labels = list(names) if names is not None else []
    labels += [f"Model {i+1}" for i in range(len(labels), n_models)]

    # --- Error and Coordinate Calculation ---
    errors_list = [y_true_val - pred for pred in preds]
    lowest_error = np.min(np.concatenate(errors_list))
    # Shift the origin to handle negative error values on the radial axis
    r_offset = np.abs(lowest_error) if lowest_error < 0 else 0

    # --- Plotting Setup ---
    fig, ax = plt.subplots(
        figsize=figsize, subplot_kw={"projection": "polar"}
    )
    cmap_obj = get_cmap(cmap, default="viridis")
    colors = cmap_obj(np.linspace(0, 1, n_models))

    # --- Plot Zero-Error Line ---
    if show_zero_line:
        ax.plot(
            np.linspace(0, 2 * np.pi, 100),
            [r_offset] * 100,
            color="black",
            linestyle="--",
            lw=1.5,
            label="Zero Error",
        )

    # --- Plot Error Points for Each Model ---
    for i, (pred, errors) in enumerate(zip(preds, errors_list)):
        # Sort by the predicted value for a smooth spiral
        order = np.argsort(pred)
        pred_sorted = pred[order]
        errors_sorted = errors[order]

        # Map sorted predicted value to angle. A constant series has zero
        # range; dividing by it would produce NaN angles and an empty plot.
        pred_span = np.ptp(pred_sorted)
        if pred_span > 0:
            theta = (pred_sorted - pred_sorted.min()) / pred_span * 2 * np.pi
        else:
            warnings.warn(
                f"Prediction series '{labels[i]}' has zero range."
                " Mapping all its points to angle 0.",
                UserWarning,
                stacklevel=2,
            )
            theta = np.zeros(pred_sorted.shape, dtype=float)

        radii = errors_sorted + r_offset
        ax.scatter(
            theta, radii, color=colors[i], s=s, alpha=alpha, label=labels[i]
        )

    # --- Formatting ---
    ax.set_title(title, fontsize=16, y=1.1)
    ax.set_xlabel("Based on Predicted Value")
    ax.set_ylabel("Forecast Error (Actual - Predicted)", labelpad=25)
    ax.legend(loc="upper right", bbox_to_anchor=(1.35, 1.1))
    set_axis_grid(ax, show_grid=show_grid, grid_props=grid_props)

    plt.tight_layout()
    if savefig:
        plt.savefig(savefig, dpi=dpi, bbox_inches="tight")
        plt.close(fig)
    else:
        plt.show()

    return ax


plot_residual_relationship.__doc__ = r"""
Plots the relationship between forecast error and predicted value.

This function creates a polar scatter plot, a polar version of a
classic residual plot, to diagnose model performance. The angle is
proportional to the **predicted value**, and the radius represents
the **forecast error**. It is a powerful tool for identifying
conditional biases and heteroscedasticity related to the model's
own output magnitude.

Parameters
----------
y_true : np.ndarray
    1D array of true observed values.
*y_preds : np.ndarray
    One or more 1D arrays of predicted values from different
    models. Each array is validated against ``y_true``.
names : list of str, optional
    Display names for each of the models. If not provided, generic
    names like ``'Model 1'`` will be generated. If fewer names than
    prediction arrays are given, the missing ones are filled with
    generic names; extra names are ignored.
title : str, default="Residual vs. Predicted Relationship"
    The title for the plot.
figsize : tuple of (float, float), default=(8, 8)
    The figure size in inches.
cmap : str, default='viridis'
    The colormap used to assign a unique color to each model's
    markers.
s : int, default=50
    The size of the scatter plot markers.
alpha : float, default=0.7
    The transparency of the markers.
show_zero_line : bool, default=True
    If ``True``, draws a reference circle representing zero error.
show_grid : bool, default=True
    Toggle the visibility of the polar grid lines.
grid_props : dict, optional
    Custom keyword arguments passed to the grid for styling.
savefig : str, optional
    The file path to save the plot. If ``None``, the plot is
    displayed interactively. When a path is given, the figure is
    closed after it has been written.
dpi : int, default=300
    The resolution (dots per inch) for the saved figure.

Returns
-------
ax : matplotlib.axes.Axes
    The Matplotlib Axes object containing the plot.

Raises
------
ValueError
    If no prediction array is provided.

Warns
-----
UserWarning
    If a prediction series is constant (zero range). Its points are
    then all drawn at angle 0.

See Also
--------
plot_error_relationship : Plot error vs. the true value.
plot_conditional_quantiles : Visualize full conditional quantile bands.

Notes
-----
This plot is a novel visualization developed as part of the
analytics framework in :footcite:p:`kouadiob2025`. It helps
diagnose if the model's error is correlated with its own
predictions.

1. **Error (Residual) Calculation**: For each observation
   :math:`i`, the error is the difference between the true and
   predicted value.

   .. math::
      :label: eq:error_calc

      e_i = y_{true,i} - y_{pred,i}

2. **Angular Mapping**: The angle :math:`\theta_i` is made
   proportional to the predicted value :math:`y_{pred,i}`, after
   sorting, to create a continuous spiral.

   .. math::

      \theta_i \propto y_{pred,i}

   The predictions of each model are rescaled to
   :math:`[0, 2\pi]` using that model's own minimum and maximum.

3. **Radial Mapping**: The radius :math:`r_i` represents the
   error :math:`e_i`. To handle negative error values on a polar
   plot, an offset is added to all radii so that the zero-error
   line becomes a reference circle.

Examples
--------
>>> import numpy as np
>>> from kdiagram.plot.relationship import plot_residual_relationship
>>>
>>> # Generate synthetic data with known flaws
>>> np.random.seed(0)
>>> n_samples = 200
>>> y_true = np.linspace(0, 20, n_samples)**1.5
>>> # Model has errors that increase with the prediction magnitude
>>> noise = np.random.normal(0, 1, n_samples) * (y_true / 20)
>>> y_pred = y_true + noise
>>>
>>> # Generate the plot
>>> ax = plot_residual_relationship(
...     y_true,
...     y_pred,
...     names=["My Model"],
...     title="Residual vs. Predicted Value (Heteroscedasticity)"
... )

References
----------
.. footbibliography::
"""
@validate_params(
    {
        "y_true": ["array-like"],
    }
)
def plot_error_relationship(
    y_true: np.ndarray,
    *y_preds: np.ndarray,
    names: list[str] | None = None,
    title: str = "Error vs. True Value Relationship",
    figsize: tuple[float, float] = (8, 8),
    cmap: str = "viridis",
    s: int = 50,
    alpha: float = 0.7,
    show_zero_line: bool = True,
    show_grid: bool = True,
    grid_props: dict[str, Any] | None = None,
    mask_radius: bool = False,
    savefig: str | None = None,
    dpi: int = 300,
):
    # --- Input Validation and Preparation ---
    if not y_preds:
        raise ValueError("At least one prediction array must be provided.")

    # Remember a possible ``.name`` (e.g. a pandas Series) before the
    # helpers below convert the input; it is used for the angular label.
    axis_name = getattr(y_true, "name", None)

    y_true, *y_preds = drop_nan_in(y_true, *y_preds, error="raise")

    # Validate *every* prediction against y_true (not only the first one)
    # so that a mis-sized series fails here instead of silently
    # broadcasting in the subtraction below.
    true_vals = None
    preds = []
    for yp in y_preds:
        checked_true, checked_pred = validate_yy(y_true, yp)
        if true_vals is None:
            if axis_name is None:
                axis_name = getattr(checked_true, "name", None)
            # Plain ndarray so that the indexing below is positional.
            true_vals = np.asarray(checked_true)
        preds.append(np.asarray(checked_pred))
    if axis_name is None:
        axis_name = "True Value"

    # Pad missing names with generic ones so a short ``names`` list cannot
    # raise IndexError when the legend labels are looked up.
    n_models = len(preds)
    labels = list(names) if names is not None else []
    labels += [f"Model {i+1}" for i in range(len(labels), n_models)]

    # --- Error and Coordinate Calculation ---
    errors_list = [true_vals - pred for pred in preds]
    lowest_error = np.min(np.concatenate(errors_list))
    # To handle negative errors on a polar plot, we shift the origin.
    # The zero-error line will be a circle.
    r_offset = np.abs(lowest_error) if lowest_error < 0 else 0

    # Sort by true value to create a smooth spiral effect
    sort_idx = np.argsort(true_vals)
    true_sorted = true_vals[sort_idx]

    # Map sorted true value to angle. A constant y_true has zero range;
    # dividing by it would produce NaN angles and an empty plot.
    true_span = np.ptp(true_sorted)
    if true_span > 0:
        theta = (true_sorted - true_sorted.min()) / true_span * 2 * np.pi
    else:
        warnings.warn(
            "y_true has zero range. Mapping all points to angle 0.",
            UserWarning,
            stacklevel=2,
        )
        theta = np.zeros(true_sorted.shape, dtype=float)

    # --- Plotting Setup ---
    fig, ax = plt.subplots(
        figsize=figsize, subplot_kw={"projection": "polar"}
    )
    cmap_obj = get_cmap(cmap, default="viridis")
    colors = cmap_obj(np.linspace(0, 1, n_models))

    # --- Plot Zero-Error Line ---
    if show_zero_line:
        ax.plot(
            np.linspace(0, 2 * np.pi, 100),
            [r_offset] * 100,
            color="black",
            linestyle="--",
            lw=1.5,
            label="Zero Error",
        )

    # --- Plot Error Points for Each Model ---
    for i, errors in enumerate(errors_list):
        radii = errors[sort_idx] + r_offset
        ax.scatter(
            theta, radii, color=colors[i], s=s, alpha=alpha, label=labels[i]
        )

    # --- Formatting ---
    ax.set_title(title, fontsize=16, y=1.1)
    ax.set_xlabel(f"Based on {axis_name}")
    ax.set_ylabel("Forecast Error", labelpad=25)
    ax.legend(loc="upper right", bbox_to_anchor=(1.35, 1.1))
    set_axis_grid(ax, show_grid=show_grid, grid_props=grid_props)

    if mask_radius:
        ax.set_yticklabels([])

    plt.tight_layout()
    if savefig:
        plt.savefig(savefig, dpi=dpi, bbox_inches="tight")
        plt.close(fig)
    else:
        plt.show()

    return ax


plot_error_relationship.__doc__ = r"""
Plots the relationship between forecast error and the true value.

This function creates a polar scatter plot to diagnose model
performance by visualizing the structure of its errors. The angle
is proportional to the **true value**, and the radius represents
the **forecast error**. It is a powerful tool for identifying
conditional biases and heteroscedasticity.

Parameters
----------
y_true : np.ndarray
    1D array of true observed values. If the object carries a
    ``name`` attribute (e.g. a pandas Series), that name is used in
    the angular axis label.
*y_preds : np.ndarray
    One or more 1D arrays of predicted values from different
    models. Each array is validated against ``y_true``.
names : list of str, optional
    Display names for each of the models. If not provided, generic
    names like ``'Model 1'`` will be generated. If fewer names than
    prediction arrays are given, the missing ones are filled with
    generic names; extra names are ignored.
title : str, default="Error vs. True Value Relationship"
    The title for the plot.
figsize : tuple of (float, float), default=(8, 8)
    The figure size in inches.
cmap : str, default='viridis'
    The colormap used to assign a unique color to each model's
    markers.
s : int, default=50
    The size of the scatter plot markers.
alpha : float, default=0.7
    The transparency of the markers.
show_zero_line : bool, default=True
    If ``True``, draws a reference circle representing zero error.
show_grid : bool, default=True
    Toggle the visibility of the polar grid lines.
grid_props : dict, optional
    Custom keyword arguments passed to the grid for styling.
mask_radius : bool, default=False
    If ``True``, hide the radial tick labels.
savefig : str, optional
    The file path to save the plot. If ``None``, the plot is
    displayed interactively. When a path is given, the figure is
    closed after it has been written.
dpi : int, default=300
    The resolution (dots per inch) for the saved figure.

Returns
-------
ax : matplotlib.axes.Axes
    The Matplotlib Axes object containing the plot.

Raises
------
ValueError
    If no prediction array is provided.

Warns
-----
UserWarning
    If ``y_true`` is constant (zero range). All points are then
    drawn at angle 0.

See Also
--------
plot_residual_relationship : Plot error vs. the predicted value.
plot_conditional_quantiles : Visualize full conditional quantile bands.

Notes
-----
This plot is a novel visualization developed as part of the
analytics framework in :footcite:p:`kouadiob2025`. It helps
diagnose whether the model's error is correlated with the true
value; many statistical models assume that it is not.

1. **Error (Residual) Calculation**: For each observation
   :math:`i`, the error is the difference between the true and
   predicted value.

   .. math::

      e_i = y_{true,i} - y_{pred,i}

2. **Angular Mapping**: The angle :math:`\theta_i` is made
   proportional to the true value :math:`y_{true,i}`, after
   sorting, to create a continuous spiral.

   .. math::

      \theta_i \propto y_{true,i}

3. **Radial Mapping**: The radius :math:`r_i` represents the
   error :math:`e_i`. To handle negative error values on a polar
   plot, an offset is added to all radii so that the zero-error
   line becomes a reference circle.

Examples
--------
>>> import numpy as np
>>> from kdiagram.plot.relationship import plot_error_relationship
>>>
>>> # Generate synthetic data with known flaws
>>> np.random.seed(0)
>>> n_samples = 200
>>> y_true = np.linspace(0, 20, n_samples)**1.5
>>> # Model has a bias that depends on the true value
>>> bias = -0.1 * y_true
>>> y_pred = y_true + bias + np.random.normal(0, 2, n_samples)
>>>
>>> # Generate the plot
>>> ax = plot_error_relationship(
...     y_true,
...     y_pred,
...     names=["My Model"],
...     title="Error vs. True Value (Conditional Bias)"
... )

References
----------
.. footbibliography::
"""
@validate_params(
    {
        "y_true": ["array-like"],
        "y_preds_quantiles": ["array-like"],
        "quantiles": ["array-like"],
    }
)
def plot_conditional_quantiles(
    y_true: np.ndarray,
    y_preds_quantiles: np.ndarray,
    quantiles: np.ndarray,
    *,
    bands: list[int] | None = None,
    title: str = "Conditional Quantile Plot",
    figsize: tuple[float, float] = (8, 8),
    cmap: str = "viridis",
    alpha_min: float = 0.2,
    alpha_max: float = 0.5,
    show_grid: bool = True,
    grid_props: dict[str, Any] | None = None,
    mask_radius: bool = False,
    savefig: str | None = None,
    dpi: int = 300,
):
    # --- Input Validation ---
    # Remember a possible ``.name`` (e.g. a pandas Series) before the
    # validation helper converts the input; used for the angular label.
    axis_name = getattr(y_true, "name", None)

    y_true, y_preds_quantiles = validate_yy(
        y_true, y_preds_quantiles, allow_2d_pred=True
    )
    if axis_name is None:
        axis_name = getattr(y_true, "name", None)
    if axis_name is None:
        axis_name = "True Value"

    # Plain ndarrays so that the indexing below is positional.
    true_vals = np.asarray(y_true)
    q_preds = np.asarray(y_preds_quantiles)
    quantiles = np.asarray(quantiles, dtype=float).ravel()

    # Checking ``ndim`` first turns a 1D prediction array into the same
    # ValueError instead of an IndexError on ``shape[1]``.
    if q_preds.ndim != 2 or q_preds.shape[1] != quantiles.size:
        raise ValueError("Shape mismatch between predictions and quantiles.")

    # Sort data by y_true to ensure a smooth spiral plot
    sort_idx = np.argsort(true_vals)
    true_sorted = true_vals[sort_idx]
    preds_sorted = q_preds[sort_idx, :]

    # Map y_true to the angular coordinate. A constant y_true has zero
    # range; dividing by it would produce NaN angles and an empty plot.
    true_span = np.ptp(true_sorted)
    if true_span > 0:
        theta = (true_sorted - true_sorted.min()) / true_span * 2 * np.pi
    else:
        warnings.warn(
            "y_true has zero range. Mapping all points to angle 0.",
            UserWarning,
            stacklevel=2,
        )
        theta = np.zeros(true_sorted.shape, dtype=float)

    # --- Identify Median and Bands ---
    # Use the same tolerance-based comparison for detecting the median as
    # for locating it (an exact ``in`` test can miss 0.5 by rounding).
    median_hits = np.flatnonzero(np.isclose(quantiles, 0.5))
    if median_hits.size:
        median_idx = int(median_hits[0])
    else:
        warnings.warn(
            "Median (0.5) not found in quantiles."
            " No central line will be plotted.",
            stacklevel=2,
        )
        median_idx = None

    if bands is None:
        # Default to the widest possible interval. Round rather than
        # truncate: (0.95 - 0.05) * 100 evaluates to 89.99999999999999,
        # which ``int`` alone would turn into a non-existent 89% band.
        widest_pct = (np.max(quantiles) - np.min(quantiles)) * 100
        bands = [int(round(widest_pct))]

    bands = sorted(bands, reverse=True)  # Plot widest band first

    # --- Plotting Setup ---
    fig, ax = plt.subplots(
        figsize=figsize, subplot_kw={"projection": "polar"}
    )
    cmap_obj = get_cmap(cmap, default="viridis")
    alphas = np.linspace(alpha_min, alpha_max, len(bands))
    colors = cmap_obj(np.linspace(0.3, 0.9, len(bands)))

    # --- Plot Bands ---
    for i, band_pct in enumerate(bands):
        # A central X% interval spans the (100-X)/2 and 100-(100-X)/2
        # percentiles.
        lower_q = (100 - band_pct) / 200.0
        upper_q = 1 - lower_q

        lower_hits = np.flatnonzero(np.isclose(quantiles, lower_q))
        upper_hits = np.flatnonzero(np.isclose(quantiles, upper_q))
        if lower_hits.size == 0 or upper_hits.size == 0:
            warnings.warn(
                f"Quantiles for {band_pct}% interval not found. Skipping.",
                stacklevel=2,
            )
            continue

        ax.fill_between(
            theta,
            preds_sorted[:, lower_hits[0]],
            preds_sorted[:, upper_hits[0]],
            color=colors[i],
            alpha=alphas[i],
            label=f"{band_pct}% Interval",
        )

    # --- Plot Median Line ---
    if median_idx is not None:
        ax.plot(
            theta,
            preds_sorted[:, median_idx],
            color="black",
            lw=1.5,
            label="Median (Q50)",
        )

    # --- Formatting ---
    ax.set_title(title, fontsize=16, y=1.1)
    ax.set_xlabel(f"Based on {axis_name}")
    ax.set_ylabel("Predicted Value", labelpad=25)
    ax.legend(loc="upper right", bbox_to_anchor=(1.35, 1.1))
    set_axis_grid(ax, show_grid=show_grid, grid_props=grid_props)

    if mask_radius:
        ax.set_yticklabels([])

    plt.tight_layout()
    if savefig:
        plt.savefig(savefig, dpi=dpi, bbox_inches="tight")
        plt.close(fig)
    else:
        plt.show()

    return ax


plot_conditional_quantiles.__doc__ = r"""
Plots polar conditional quantile bands.

This function visualizes how the predicted conditional
distribution (represented by quantiles) changes as a function of
the true observed value. It is a powerful tool for diagnosing
heteroscedasticity, i.e., whether the forecast uncertainty is
constant or changes with the magnitude of the target variable.

Parameters
----------
y_true : np.ndarray
    1D array of true observed values, which will be mapped to the
    angular coordinate. If the object carries a ``name`` attribute
    (e.g. a pandas Series), that name is used in the angular axis
    label.
y_preds_quantiles : np.ndarray
    2D array of quantile forecasts, with shape
    ``(n_samples, n_quantiles)``.
quantiles : np.ndarray
    1D array of the quantile levels corresponding to the columns
    of ``y_preds_quantiles``.
bands : list of int, optional
    A list of the desired interval percentages to plot as shaded
    bands (e.g., ``[90, 50]`` for the 90% and 50% prediction
    intervals). Defaults to the widest interval available from the
    provided quantiles, rounded to the nearest integer percentage.
    A band whose lower or upper quantile is not present in
    ``quantiles`` is skipped with a warning.
title : str, default="Conditional Quantile Plot"
    The title for the plot.
figsize : tuple of (float, float), default=(8, 8)
    The figure size in inches.
cmap : str, default='viridis'
    The colormap for the shaded uncertainty bands.
alpha_min : float, default=0.2
    The minimum alpha (transparency) for the outermost band.
alpha_max : float, default=0.5
    The maximum alpha for the innermost band.
show_grid : bool, default=True
    Toggle the visibility of the polar grid lines.
grid_props : dict, optional
    Custom keyword arguments passed to the grid for styling.
mask_radius : bool, default=False
    If ``True``, hide the radial tick labels.
savefig : str, optional
    The file path to save the plot. If ``None``, the plot is
    displayed interactively. When a path is given, the figure is
    closed after it has been written.
dpi : int, default=300
    The resolution (dots per inch) for the saved figure.

Returns
-------
ax : matplotlib.axes.Axes
    The Matplotlib Axes object containing the plot.

Raises
------
ValueError
    If ``y_preds_quantiles`` is not 2D or its number of columns
    differs from the number of ``quantiles``.

Warns
-----
UserWarning
    If the median (0.5) is not among ``quantiles``, if the
    quantiles required for a requested band are missing, or if
    ``y_true`` is constant (all points are then drawn at angle 0).

Notes
-----
This plot is a novel visualization developed as part of the
analytics framework in :footcite:p:`kouadiob2025`. It provides an
intuitive view of the conditional predictive distribution.

1. **Coordinate Mapping**: The plot first sorts the data based on
   the true values :math:`y_{true}` to ensure a continuous spiral.
   The sorted true values are then mapped to the angular
   coordinate :math:`\theta` in the range :math:`[0, 2\pi]`.

   .. math::

      \theta_i \propto y_{true,i}^{\text{(sorted)}}

   The predicted quantiles :math:`q_{i, \tau}` for each
   observation :math:`i` and quantile level :math:`\tau` are
   mapped directly to the radial coordinate :math:`r`.

2. **Band Construction**: For a given prediction interval, for
   example 80%, the corresponding lower (:math:`\tau=0.1`) and
   upper (:math:`\tau=0.9`) quantile forecasts are used to define
   the boundaries of a shaded band. The function can plot
   multiple, nested bands (e.g., 80% and 50%) to give a more
   complete picture of the distribution's shape. The median
   forecast (:math:`\tau=0.5`) is drawn as a solid central line.
   Quantile levels are matched with a floating-point tolerance.

Examples
--------
>>> import numpy as np
>>> from kdiagram.plot.relationship import plot_conditional_quantiles
>>>
>>> # Generate synthetic data with heteroscedasticity
>>> np.random.seed(0)
>>> n_samples = 200
>>> y_true = np.linspace(0, 20, n_samples)**1.5
>>> quantiles = np.array([0.1, 0.25, 0.5, 0.75, 0.9])
>>>
>>> # Uncertainty (interval width) increases with the true value
>>> interval_width = 5 + (y_true / y_true.max()) * 15
>>> y_preds = np.zeros((n_samples, len(quantiles)))
>>> y_preds[:, 2] = y_true  # Median
>>> y_preds[:, 1] = y_true - interval_width * 0.25  # Q25
>>> y_preds[:, 3] = y_true + interval_width * 0.25  # Q75
>>> y_preds[:, 0] = y_true - interval_width * 0.5  # Q10
>>> y_preds[:, 4] = y_true + interval_width * 0.5  # Q90
>>>
>>> # Generate the plot
>>> ax = plot_conditional_quantiles(
...     y_true,
...     y_preds,
...     quantiles,
...     bands=[80, 50],  # Show 80% and 50% intervals
...     title="Conditional Uncertainty (Heteroscedasticity)"
... )

References
----------
.. footbibliography::
"""
@validate_params(
    {
        "y_true": ["array-like"],
        "y_pred": ["array-like"],
        "theta_scale": [StrOptions({"proportional", "uniform"})],
        "acov": [
            StrOptions(
                {"default", "half_circle", "quarter_circle", "eighth_circle"}
            )
        ],
    }
)
def plot_relationship(
    y_true: np.ndarray,
    *y_preds: np.ndarray,
    names: list[str] | None = None,
    title: str | None = None,
    theta_offset: float = 0,
    theta_scale: str = "proportional",
    acov: str = "default",
    figsize: tuple[float, float] | None = None,
    cmap: str = "tab10",
    s: float = 50,
    alpha: float = 0.7,
    legend: bool = True,
    show_grid: bool = True,
    grid_props: dict[str, Any] | None = None,
    color_palette: list[Any] | None = None,
    xlabel: str | None = None,
    ylabel: str | None = None,
    z_values: np.ndarray | None = None,
    z_label: str | None = None,
    savefig: str | None = None,
):
    # Remove NaN values from y_true and all y_pred arrays
    y_true, *y_preds = drop_nan_in(y_true, *y_preds, error="raise")

    # Validate y_true and each y_pred to ensure consistency and continuity
    try:
        y_preds = [
            validate_yy(
                y_true, pred, expected_type="continuous", flatten=True
            )[1]
            for pred in y_preds
        ]
    except Exception as err:
        raise ValueError(
            "Validation failed. Please check your y_pred"
        ) from err

    # Work on a float ndarray: this keeps the angle arithmetic in floating
    # point (so adding a float ``theta_offset`` cannot fail on an integer
    # array) and makes the tick lookup below positional even when the
    # caller passed a pandas Series.
    true_vals = np.asarray(y_true, dtype=float).ravel()

    # Check z_values before any figure is created so that a bad argument
    # does not leave an orphaned figure behind.
    if z_values is not None:
        z_values = np.asarray(z_values)  # Ensure numpy array
        if len(z_values) != len(true_vals):
            raise ValueError(
                "Length of `z_values` must match the length of `y_true`."
            )

    # Generate default model names if none are provided
    num_preds = len(y_preds)
    if names is None:
        names = [f"Model_{i+1}" for i in range(num_preds)]
    else:
        names = list(names)
        # Ensure the length of names matches y_preds
        if len(names) < num_preds:
            names += [f"Model_{i+1}" for i in range(len(names), num_preds)]
        elif len(names) > num_preds:
            warnings.warn(
                f"Received {len(names)} names for {num_preds}"
                f" predictions. Extra names ignored.",
                UserWarning,
                stacklevel=2,
            )
            names = names[:num_preds]

    # --- Color Handling ---
    # An empty palette is treated like a missing one; cycling over it
    # below would otherwise be a modulo-by-zero.
    if color_palette is None or len(color_palette) == 0:
        # Generate colors from cmap if palette not given
        try:
            cmap_obj = get_cmap(cmap, default="tab10", failsafe="discrete")
            discrete = getattr(cmap_obj, "colors", None)
            if discrete is not None and len(discrete) >= num_preds:
                # Use colors directly from discrete map if enough
                color_palette = discrete[:num_preds]
            else:
                # Otherwise sample the colormap evenly over [0, 1]
                color_palette = [
                    (
                        cmap_obj(i / max(1, num_preds - 1))
                        if num_preds > 1
                        else cmap_obj(0.5)
                    )
                    for i in range(num_preds)
                ]
        except ValueError:
            warnings.warn(
                f"Invalid cmap '{cmap}'. Falling back to 'tab10'.",
                stacklevel=2,
            )
            color_palette = plt.cm.tab10.colors  # Default palette

    # Ensure palette has enough colors, repeat if necessary
    final_colors = [
        color_palette[i % len(color_palette)] for i in range(num_preds)
    ]

    # Determine the angular range based on `acov`
    coverage = {
        "default": 2 * np.pi,
        "half_circle": np.pi,
        "quarter_circle": np.pi / 2,
        "eighth_circle": np.pi / 4,
    }
    if acov not in coverage:
        # This case should be caught by @validate_params,
        # but keep as safeguard
        raise ValueError(
            "Invalid value for `acov`. Choose from 'default',"
            " 'half_circle', 'quarter_circle', or 'eighth_circle'."
        )
    angular_range = coverage[acov]

    # Map `y_true` to angular coordinates (theta)
    if theta_scale == "proportional":
        y_true_range = np.ptp(true_vals)  # Peak-to-peak range
        if y_true_range > 1e-9:  # Avoid division by zero
            theta = (
                angular_range * (true_vals - np.min(true_vals)) / y_true_range
            )
        else:
            # Constant y_true: there is no spread to map, use the start
            # angle for every point.
            theta = np.zeros_like(true_vals)
            warnings.warn(
                "y_true has zero range. Mapping all points to angle 0"
                " with 'proportional' scaling.",
                UserWarning,
                stacklevel=2,
            )
    elif theta_scale == "uniform":
        # linspace handles len=1 case correctly
        theta = np.linspace(0, angular_range, len(true_vals), endpoint=False)
    else:
        # This case should be caught by @validate_params
        raise ValueError(
            "`theta_scale` must be either 'proportional' or 'uniform'."
        )

    # Apply theta offset (out of place, so the dtype can widen if needed)
    theta = theta + theta_offset

    # Create the polar plot
    fig, ax = plt.subplots(
        figsize=figsize or (8, 8),  # Provide default here
        subplot_kw={"projection": "polar"},
    )

    # Limit the visible angular range
    ax.set_thetamin(0)  # Start angle (in degrees)
    ax.set_thetamax(np.degrees(angular_range))  # End angle (in degrees)

    # Plot each model's predictions
    for i, y_pred in enumerate(y_preds):
        y_pred = np.asarray(y_pred, dtype=float)

        # Normalize `y_pred` to [0, 1] for the radial coordinate
        y_pred_range = np.ptp(y_pred)
        if y_pred_range > 1e-9:
            r = (y_pred - np.min(y_pred)) / y_pred_range
        else:
            # Constant series: draw it on the mid-radius circle
            r = np.full_like(y_pred, 0.5)
            warnings.warn(
                f"Prediction series '{names[i]}' has zero range."
                f" Plotting all its points at normalized radius 0.5.",
                UserWarning,
                stacklevel=2,
            )

        # Plot on the polar axis
        ax.scatter(
            theta,
            r,
            label=names[i],
            color=final_colors[i],
            s=s,
            alpha=alpha,
            edgecolor="black",
        )

    # If z_values are provided, replace angle labels with z_values
    if z_values is not None:
        # At most 8 ticks, taken at evenly spaced sample positions
        num_z_ticks = min(len(z_values), 8)
        tick_indices = np.linspace(
            0, len(z_values) - 1, num_z_ticks, dtype=int, endpoint=True
        )
        ax.set_xticks(theta[tick_indices])
        ax.set_xticklabels([f"{z_values[ix]:.2g}" for ix in tick_indices])

        # Set label for z-axis if z_label is provided
        if z_label:
            ax.text(
                1.1,
                0.5,
                z_label,
                transform=ax.transAxes,
                rotation=90,
                va="center",
                ha="left",
            )
    else:
        # Axis labels are only added when z_values are not used for angles
        ax.set_ylabel(ylabel or "Angular Mapping (θ)", labelpad=15)
        ax.set_xlabel(xlabel or "Normalized Predictions (r)", labelpad=15)

    # Position radial tick labels away from the data
    ax.set_rlabel_position(22.5)

    ax.set_title(title or "Relationship Visualization", va="bottom", pad=20)

    set_axis_grid(ax, show_grid, grid_props=grid_props)

    if legend:
        ax.legend(loc="upper right", bbox_to_anchor=(1.25, 1.1))

    plt.tight_layout()  # Adjust layout to prevent overlap

    # --- Save or Show ---
    if savefig:
        # Saving stays best-effort (a failure does not abort the call),
        # but the failure is reported as a warning so callers can see
        # and filter it.
        try:
            plt.savefig(savefig, bbox_inches="tight", dpi=300)
        except Exception as e:
            warnings.warn(
                f"Error saving plot to {savefig}: {e}",
                UserWarning,
                stacklevel=2,
            )
        else:
            print(f"Plot saved to {savefig}")
        finally:
            # Release the figure, as the other plot functions here do.
            plt.close(fig)
    else:
        # Warning for non-GUI backend is expected here in test envs
        plt.show()

    return ax


plot_relationship.__doc__ = r"""
Visualize the relationship between true values and one or more
prediction series on a polar (circular) scatter plot.

Each point uses an angular position derived from ``y_true`` and a
radial position derived from the corresponding prediction. This
compact view lets you compare multiple prediction series against
the same truth—useful for spotting systematic deviations and
patterns over a cyclic or ordered domain (e.g., phase,
time-of-year).

Parameters
----------
y_true : array-like of shape (n_samples,)
    Ground-truth (observed) values. Must be numeric, 1D, and the
    same length as every array in ``y_preds``.
*y_preds : array-like(s)
    One or more prediction arrays, each with shape
    ``(n_samples,)`` and aligned to ``y_true``.
names : list of str, optional
    Labels for each prediction series. If fewer names than series
    are provided, placeholders like ``'Model_3'`` are appended.
    Extra names are ignored with a warning.
title : str, optional
    Figure title. If ``None``, uses
    ``'Relationship Visualization'``.
theta_offset : float, default=0
    Constant angular shift (radians) applied after the angle
    mapping.
theta_scale : {'proportional', 'uniform'}, default='proportional'
    Strategy for mapping ``y_true`` to angles:

    - ``'proportional'``: angle proportional to the scaled value
      of ``y_true`` within its range over the selected angular
      span.
    - ``'uniform'``: angles evenly spaced over the selected span,
      ignoring the numerical spacing in ``y_true``.
acov : {'default', 'half_circle', 'quarter_circle', 'eighth_circle'}, \
default='default'
    Angular coverage (span) of the plot:

    - ``'default'``: :math:`2\pi` (full circle)
    - ``'half_circle'``: :math:`\pi`
    - ``'quarter_circle'``: :math:`\tfrac{\pi}{2}`
    - ``'eighth_circle'``: :math:`\tfrac{\pi}{4}`
figsize : tuple of (float, float), optional
    Figure size in inches. If ``None``, ``(8, 8)`` is used.
cmap : str, default='tab10'
    Matplotlib colormap name used to generate distinct series
    colors.
s : float, default=50
    Marker size for scatter points.
alpha : float, default=0.7
    Alpha (transparency) for scatter points in ``[0, 1]``.
legend : bool, default=True
    If ``True``, show a legend for the prediction series.
show_grid : bool, default=True
    Toggle polar grid lines (delegated to ``set_axis_grid``).
grid_props : dict, optional
    Keyword arguments forwarded to the grid helper (e.g.,
    ``linestyle``, ``alpha``).
color_palette : list of color-like, optional
    Explicit list of colors. If omitted (or empty), colors are
    derived from ``cmap``. If provided with fewer colors than
    series, they repeat.
xlabel : str, optional
    Label for the radial axis. Defaults to
    ``'Normalized Predictions (r)'``. Only applied when
    ``z_values`` is not used.
ylabel : str, optional
    Label for the angular axis. Defaults to
    ``'Angular Mapping (θ)'`` when ``z_values`` is not used.
z_values : array-like of shape (n_samples,), optional
    Optional values used to label angular ticks (e.g., time,
    phase). If provided, a subset of at most 8 positions is
    selected and tick labels are replaced by formatted entries
    from ``z_values``.
z_label : str, optional
    Axis/legend label describing ``z_values`` (shown as text next
    to the angular tick labels region).
savefig : str, optional
    Path to save the figure (with extension). If ``None``, the
    figure is shown instead. The figure is written at 300 dpi and
    closed afterwards; a failure to write it is reported as a
    ``UserWarning`` rather than raised.

Returns
-------
ax : matplotlib.axes.Axes
    The polar axes containing the visualization.

Raises
------
ValueError
    If validation of a prediction array fails, or if the length
    of ``z_values`` does not match the length of ``y_true``.

Notes
-----
**Angular span.** Let :math:`\Delta\theta` be the selected span:
:math:`2\pi` (full), :math:`\pi`, :math:`\pi/2`, or
:math:`\pi/4` depending on ``acov``. Angles are then limited to
:math:`[0,\,\Delta\theta]` and shifted by ``theta_offset``.

**Angle mapping.** For :math:`N=\text{len}(y_{\text{true}})` and
:math:`i=0,\dots,N-1`:

- Proportional mapping (range-aware):

  .. math::

     \theta_i \;=\;
     \begin{cases}
       \dfrac{y_i - y_{\min}}{y_{\max}-y_{\min}}\,\Delta\theta,
         & \text{if } y_{\max}>y_{\min},\\[6pt]
       0, & \text{otherwise,}
     \end{cases}

  where :math:`y_{\min}=\min_i y_i` and
  :math:`y_{\max}=\max_i y_i`.

- Uniform mapping (index-based):

  .. math::

     \theta_i \;=\; \frac{i}{N}\,\Delta\theta.

**Radial normalization.** Each prediction series :math:`p` is
scaled to :math:`[0,1]` by

.. math::

   r_i \;=\;
   \begin{cases}
     \dfrac{p_i - p_{\min}}{p_{\max}-p_{\min}},
       & p_{\max}>p_{\min},\\[6pt]
     0.5, & \text{otherwise,}
   \end{cases}

to give comparable radii across heterogeneous series
:footcite:p:`Hunter:2007`.

**Data preparation.** The function first removes joint NaNs via
``drop_nan_in`` and validates each pair ``(y_true, y_pred)``
through ``validate_yy`` (continuous expectations, 1D arrays).
Colors are drawn from ``cmap`` unless ``color_palette`` is
supplied. Grid appearance is managed by ``set_axis_grid``.

**Interpretation.** When ``theta_scale='proportional'``, nearby
angles reflect similar truth values; with ``'uniform'``, angles
reflect order only. Clustering by color (series) indicates
systematic agreement or disagreement versus truth across the
domain :footcite:p:`kouadiob2025`.

Examples
--------
Basic comparison over a full circle:

>>> import numpy as np
>>> from kdiagram.plot.relationship import plot_relationship
>>> rng = np.random.default_rng(0)
>>> y = rng.random(200)
>>> p1 = y + rng.normal(0, 0.10, size=len(y))
>>> p2 = y + rng.normal(0, 0.20, size=len(y))
>>> ax = plot_relationship(
...     y, p1, p2,
...     names=["Model A", "Model B"],
...     acov="default",
...     title="Truth–Prediction (Full Circle)"
... )

Half-circle with custom angular tick labels (e.g., months):

>>> months = np.linspace(1, 12, len(y))
>>> ax = plot_relationship(
...     y, p1,
...     names=["Model A"],
...     theta_scale="uniform",
...     acov="half_circle",
...     z_values=months,
...     z_label="Month",
...     xlabel="Normalized Predictions (r)"
... )

See Also
--------
kdiagram.plot.uncertainty.plot_temporal_uncertainty :
    General polar series visualization (e.g., quantiles).
kdiagram.plot.uncertainty.plot_actual_vs_predicted :
    Side-by-side truth vs. point prediction comparison.

References
----------
.. footbibliography::
"""