Source code for kdiagram.plot.context

import warnings
from typing import Any, Optional

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from pandas.plotting import autocorrelation_plot
from scipy.stats import probplot

from ..compat.matplotlib import get_cmap
from ..decorators import check_non_emptiness, isdf
from ..utils.deps import ensure_pkg
from ..utils.generic_utils import get_valid_kwargs
from ..utils.handlers import columns_manager
from ..utils.plot import set_axis_grid
from ..utils.validator import exist_features

# Public API of this module: the names exported by a star-import.
__all__ = [
    "plot_time_series",
    "plot_scatter_correlation",
    "plot_error_autocorrelation",
    "plot_qq",
    "plot_error_pacf",
    "plot_error_distribution",
]


@isdf
@check_non_emptiness(params=["df"])
def plot_time_series(
    df: pd.DataFrame,
    x_col: Optional[str] = None,
    actual_col: Optional[str] = None,
    pred_cols: Optional[list[str]] = None,
    names: Optional[list[str]] = None,
    q_lower_col: Optional[str] = None,
    q_upper_col: Optional[str] = None,
    title: Optional[str] = None,
    xlabel: Optional[str] = None,
    ylabel: Optional[str] = None,
    figsize: tuple[float, float] = (12, 6),
    cmap: str = "viridis",
    show_grid: bool = True,
    grid_props: Optional[dict[str, Any]] = None,
    savefig: Optional[str] = None,
    dpi: int = 300,
):
    # The user-facing docstring is assigned to ``__doc__`` after the definition.

    # --- Column validation ---
    pred_cols = columns_manager(pred_cols, empty_as_none=False)
    if not (actual_col or pred_cols):
        raise ValueError(
            "At least one of `actual_col` or `pred_cols` must be provided."
        )

    # Every column named by the caller must exist in ``df``.
    needed = [
        col for col in (x_col, actual_col, q_lower_col, q_upper_col) if col
    ]
    if pred_cols:
        needed.extend(pred_cols)
    exist_features(df, features=needed)

    # Fall back to the DataFrame index when no x column is named.
    if x_col is None:
        x_values = df.index
    else:
        x_values = df[x_col]

    # --- Legend labels ---
    # A ``names`` list of the wrong length is discarded (with a warning) and
    # the prediction column names are used instead.
    n_series = len(pred_cols)
    if names and len(names) != n_series:
        warnings.warn(
            "Length of `names` does not match `pred_cols`. Using defaults.",
            stacklevel=2,
        )
        names = None
    labels = names if names else list(pred_cols)

    # --- Figure and colours: one colour per prediction series ---
    fig, ax = plt.subplots(figsize=figsize)
    colormap = get_cmap(cmap, default="viridis")
    if n_series > 0:
        palette = colormap(np.linspace(0, 1, n_series))
    else:
        palette = []

    # --- Uncertainty band: drawn only when both bounds are given ---
    if q_lower_col and q_upper_col:
        ax.fill_between(
            x_values,
            df[q_lower_col],
            df[q_upper_col],
            color="gray",
            alpha=0.2,
            label="Uncertainty Interval",
        )

    # --- Observed series ---
    if actual_col:
        ax.plot(
            x_values,
            df[actual_col],
            color="black",
            linewidth=2,
            label="Actual",
        )

    # --- Forecast series ---
    for column, colour, label in zip(pred_cols, palette, labels):
        ax.plot(
            x_values,
            df[column],
            color=colour,
            linestyle="--",
            linewidth=1.5,
            label=label,
        )

    # --- Labels, legend, grid ---
    ax.set_title(title or "Time Series Forecast", fontsize=16)
    ax.set_xlabel(xlabel or x_col or "Index", fontsize=12)
    ax.set_ylabel(ylabel or "Value", fontsize=12)
    ax.legend()
    set_axis_grid(ax, show_grid=show_grid, grid_props=grid_props)
    fig.tight_layout()

    # --- Output: show interactively unless a path was given ---
    if not savefig:
        plt.show()
    else:
        plt.savefig(savefig, dpi=dpi, bbox_inches="tight")
        plt.close(fig)

    return ax
# NumPy-style docstring attached to ``plot_time_series`` after its definition; a raw string so the backslash in the Examples section stays literal.
plot_time_series.__doc__ = r""" Plots one or more time series from a DataFrame. This function creates a standard time series plot, which is a fundamental tool for visualizing and comparing actual observed values against one or more model forecasts over time. It serves as an essential first-look diagnostic in any forecasting workflow. More details in :ref:`Time Series Plot User Guide <ug_plot_time_series>` Parameters ---------- df : pd.DataFrame The input DataFrame containing the time series data. x_col : str, optional The name of the column to use for the x-axis (e.g., a datetime column). If ``None``, the DataFrame's index is used. actual_col : str, optional The name of the column containing the true observed values. This is typically plotted as a solid reference line. pred_cols : list of str, optional A list of one or more column names containing the point forecasts from different models. names : list of str, optional Display names for each of the prediction series, to be used in the legend. q_lower_col : str, optional The name of the column for the lower bound of a prediction interval. If provided with ``q_upper_col``, a shaded uncertainty band will be drawn. q_upper_col : str, optional The name of the column for the upper bound of a prediction interval. title : str, optional The title for the plot. xlabel : str, optional The label for the x-axis. ylabel : str, optional The label for the y-axis. figsize : tuple of (float, float), default=(12, 6) The figure size in inches. cmap : str, default='viridis' The colormap used to assign unique colors to the different prediction series. show_grid : bool, default=True Toggle the visibility of the plot's grid lines. grid_props : dict, optional Custom keyword arguments passed to the grid for styling. savefig : str, optional The file path to save the plot. If ``None``, the plot is displayed interactively. dpi : int, default=300 The resolution (dots per inch) for the saved figure. 
Returns ------- ax : matplotlib.axes.Axes The Matplotlib Axes object containing the plot. See Also -------- plot_scatter_correlation : A Cartesian plot for correlation. plot_actual_vs_predicted : A polar plot for comparing true vs. predicted. Notes ----- This function provides a direct visualization of time-dependent variables by mapping a time-like variable to the x-axis and the series values to the y-axis. It is a foundational plot for assessing a model's ability to track trends and seasonality. Examples -------- >>> import pandas as pd >>> import numpy as np >>> from kdiagram.plot.context import plot_time_series >>> >>> # Generate synthetic time series data >>> np.random.seed(0) >>> n_samples = 100 >>> time = pd.date_range("2024-01-01", periods=n_samples, freq='D') >>> y_true = 50 + np.linspace(0, 10, n_samples) + \ ... 5 * np.sin(np.arange(n_samples) * 2 * np.pi / 15) >>> >>> y_pred = y_true + np.random.normal(0, 1.5, n_samples) >>> df = pd.DataFrame({ ... 'time': time, ... 'actual': y_true, ... 'forecast': y_pred, ... 'q10': y_pred - 3, ... 'q90': y_pred + 3, ... }) >>> >>> # Generate the plot >>> ax = plot_time_series( ... df, ... x_col='time', ... actual_col='actual', ... pred_cols=['forecast'], ... q_lower_col='q10', ... q_upper_col='q90', ... title="Forecast vs. Actuals with 80% Uncertainty" ... ) """
@isdf
@check_non_emptiness(params=["df"])
def plot_scatter_correlation(
    df: pd.DataFrame,
    actual_col: str,
    pred_cols: list[str],
    names: Optional[list[str]] = None,
    title: Optional[str] = None,
    xlabel: Optional[str] = None,
    ylabel: Optional[str] = None,
    figsize: tuple[float, float] = (8, 8),
    cmap: str = "viridis",
    s: int = 50,
    alpha: float = 0.7,
    show_identity_line: bool = True,
    show_grid: bool = True,
    grid_props: Optional[dict[str, Any]] = None,
    savefig: Optional[str] = None,
    dpi: int = 300,
):
    # The user-facing docstring is assigned to ``__doc__`` after the definition.

    # --- Column validation ---
    pred_cols = columns_manager(pred_cols, empty_as_none=False)
    if not pred_cols:
        raise ValueError(
            "At least one prediction column (`pred_cols`) must be provided."
        )

    needed = [actual_col, *pred_cols]
    exist_features(df, features=needed)

    # Only rows that are complete across the actual and all prediction
    # columns are plotted.
    clean = df[needed].dropna()
    observed = clean[actual_col]

    # --- Legend labels ---
    # A ``names`` list of the wrong length is discarded (with a warning) and
    # the prediction column names are used instead.
    n_series = len(pred_cols)
    if names and len(names) != n_series:
        warnings.warn(
            "Length of `names` does not match `pred_cols`. Using defaults.",
            stacklevel=2,
        )
        names = None
    labels = names if names else list(pred_cols)

    # --- Figure and colours: one colour per prediction series ---
    fig, ax = plt.subplots(figsize=figsize)
    colormap = get_cmap(cmap, default="viridis")
    if n_series > 0:
        palette = colormap(np.linspace(0, 1, n_series))
    else:
        palette = []

    # --- Identity line (y = x) spanning the full range of plotted values ---
    if show_identity_line:
        predicted = clean[list(pred_cols)]
        lower = min(observed.min(), predicted.min().min())
        upper = max(observed.max(), predicted.max().max())
        ax.plot(
            [lower, upper],
            [lower, upper],
            "k--",
            alpha=0.7,
            label="Identity Line",
        )

    # --- One scatter layer per prediction column ---
    for column, colour, label in zip(pred_cols, palette, labels):
        ax.scatter(
            observed,
            clean[column],
            color=colour,
            s=s,
            alpha=alpha,
            label=label,
        )

    # --- Labels, legend, grid ---
    ax.set_title(title or "Actual vs. Predicted", fontsize=16)
    ax.set_xlabel(xlabel or f"True Values ({actual_col})", fontsize=12)
    ax.set_ylabel(ylabel or "Predicted Values", fontsize=12)
    ax.legend()
    ax.axis("equal")  # same data scale on both axes
    set_axis_grid(ax, show_grid=show_grid, grid_props=grid_props)
    fig.tight_layout()

    # --- Output: show interactively unless a path was given ---
    if not savefig:
        plt.show()
    else:
        plt.savefig(savefig, dpi=dpi, bbox_inches="tight")
        plt.close(fig)

    return ax
# NumPy-style docstring attached to ``plot_scatter_correlation`` after its definition.
plot_scatter_correlation.__doc__ = r""" Plots a scatter plot of true vs predicted values. This function creates a classic Cartesian scatter plot to visualize the relationship between true observed values and model predictions. It is an essential tool for assessing linear correlation, identifying systemic bias, and spotting outliers. For more details, refer to :ref:`Scatter Correlation Plot User Guide <ug_plot_scatter_correlation>` Parameters ---------- df : pd.DataFrame The input DataFrame containing the actual and predicted values. actual_col : str The name of the column containing the true observed values, which will be plotted on the x-axis. pred_cols : list of str A list of one or more column names containing the point forecasts from different models. names : list of str, optional Display names for each of the prediction series, to be used in the legend. title : str, optional The title for the plot. xlabel : str, optional The label for the x-axis. ylabel : str, optional The label for the y-axis. figsize : tuple of (float, float), default=(8, 8) The figure size in inches. cmap : str, default='viridis' The colormap used to assign unique colors to the different prediction series. s : int, default=50 The size of the scatter plot markers. alpha : float, default=0.7 The transparency of the markers. show_identity_line : bool, default=True If ``True``, draws a dashed y=x line, which represents a perfect forecast. show_grid : bool, default=True Toggle the visibility of the plot's grid lines. grid_props : dict, optional Custom keyword arguments passed to the grid for styling. savefig : str, optional The file path to save the plot. If ``None``, the plot is displayed interactively. dpi : int, default=300 The resolution (dots per inch) for the saved figure. Returns ------- ax : matplotlib.axes.Axes The Matplotlib Axes object containing the plot. See Also -------- plot_relationship : A polar version of this plot. plot_error_relationship : A plot to diagnose error patterns. 
Notes ----- This plot directly visualizes the relationship between two variables by plotting each observation :math:`i` as a point :math:`(y_{true,i}, y_{pred,i})`. The primary reference is the **identity line**, defined by the equation: .. math:: y = x For a perfect forecast, every predicted value would equal its corresponding true value, and all points would fall exactly on this line. Deviations from this line represent prediction errors. Examples -------- >>> import pandas as pd >>> import numpy as np >>> from kdiagram.plot.context import plot_scatter_correlation >>> >>> # Generate synthetic data >>> np.random.seed(0) >>> n_samples = 100 >>> y_true = np.linspace(0, 50, n_samples) >>> y_pred_good = y_true + np.random.normal(0, 3, n_samples) >>> y_pred_biased = y_true * 0.8 + 5 >>> >>> df = pd.DataFrame({ ... 'actual': y_true, ... 'good_model': y_pred_good, ... 'biased_model': y_pred_biased, ... }) >>> >>> # Generate the plot >>> ax = plot_scatter_correlation( ... df, ... actual_col='actual', ... pred_cols=['good_model', 'biased_model'], ... names=['Good Model', 'Biased Model'], ... title="Actual vs. Predicted Correlation" ... ) """
@isdf
@check_non_emptiness(params=["df"])
def plot_error_autocorrelation(
    df: pd.DataFrame,
    actual_col: str,
    pred_col: str,
    title: Optional[str] = None,
    xlabel: Optional[str] = None,
    ylabel: Optional[str] = None,
    figsize: tuple[float, float] = (10, 5),
    show_grid: bool = True,
    grid_props: Optional[dict[str, Any]] = None,
    savefig: Optional[str] = None,
    dpi: int = 300,
    **acf_kwargs,
):
    # The user-facing docstring is assigned to ``__doc__`` after the definition.

    # --- Forecast errors from complete rows only ---
    columns = [actual_col, pred_col]
    exist_features(df, features=columns)
    clean = df[columns].dropna()
    residuals = clean[actual_col] - clean[pred_col]

    # Bail out before creating a figure when there is nothing to correlate.
    if len(residuals) < 2:
        warnings.warn(
            "Not enough data points to plot autocorrelation.", stacklevel=2
        )
        return None

    # --- ACF plot drawn by pandas onto our own axes ---
    fig, ax = plt.subplots(figsize=figsize)
    forwarded = get_valid_kwargs(autocorrelation_plot, acf_kwargs)
    autocorrelation_plot(residuals, ax=ax, **forwarded)

    # --- Labels and grid ---
    ax.set_title(title or "Autocorrelation of Forecast Errors", fontsize=16)
    ax.set_xlabel(xlabel or "Lag", fontsize=12)
    ax.set_ylabel(ylabel or "Autocorrelation", fontsize=12)
    set_axis_grid(ax, show_grid=show_grid, grid_props=grid_props)
    fig.tight_layout()

    # --- Output: show interactively unless a path was given ---
    if not savefig:
        plt.show()
    else:
        plt.savefig(savefig, dpi=dpi, bbox_inches="tight")
        plt.close(fig)

    return ax
# NumPy-style docstring attached after the definition; a raw string so the
# LaTeX backslashes in the Notes section stay literal.
plot_error_autocorrelation.__doc__ = r"""
Plots the autocorrelation of forecast errors.

This function creates an Autocorrelation Function (ACF) plot of the
forecast errors (residuals). It is a critical diagnostic for time
series models, used to check if there is any remaining temporal
structure (i.e., patterns) in the residuals. A well-specified model
should have errors that are uncorrelated over time, behaving like
random noise.

Additional details can be found in
:ref:`Error Autocorrelation (ACF) Plot User Guide <ug_plot_error_autocorrelation>`

Parameters
----------
df : pd.DataFrame
    The input DataFrame containing the actual and predicted values.
actual_col : str
    The name of the column containing the true observed values.
pred_col : str
    The name of the column containing the point forecast values.
title : str, optional
    The title for the plot.
xlabel : str, optional
    The label for the x-axis.
ylabel : str, optional
    The label for the y-axis.
figsize : tuple of (float, float), default=(10, 5)
    The figure size in inches.
show_grid : bool, default=True
    Toggle the visibility of the plot's grid lines.
grid_props : dict, optional
    Custom keyword arguments passed to the grid for styling.
savefig : str, optional
    The file path to save the plot. If ``None``, the plot is
    displayed interactively.
dpi : int, default=300
    The resolution (dots per inch) for the saved figure.
**acf_kwargs
    Additional keyword arguments forwarded to the underlying
    ``pandas.plotting.autocorrelation_plot`` function.

Returns
-------
ax : matplotlib.axes.Axes or None
    The Matplotlib Axes object containing the plot. ``None`` is
    returned (after a warning) when fewer than two rows remain once
    rows with missing values are dropped.

See Also
--------
plot_error_pacf : The companion plot for partial autocorrelation.

Notes
-----
The Autocorrelation Function (ACF) at lag :math:`k` measures the
correlation between a time series and its own past values. For a
series of errors :math:`e_t`, the ACF is defined as:

.. math::

   \rho_k = \frac{\text{Cov}(e_t, e_{t-k})}{\text{Var}(e_t)}

This plot displays the values of :math:`\rho_k` for a range of
different lags :math:`k`. The plot also includes significance bands
(typically at 95% and 99% confidence), which provide a threshold for
determining if a correlation is statistically significant or likely
due to random chance.

Examples
--------
>>> import pandas as pd
>>> import numpy as np
>>> from kdiagram.plot.context import plot_error_autocorrelation
>>>
>>> # Generate synthetic data with autocorrelated errors
>>> np.random.seed(0)
>>> n_samples = 200
>>> y_true = np.linspace(0, 50, n_samples)
>>> # Errors have an AR(1) structure
>>> errors = [0]
>>> for _ in range(n_samples - 1):
...     errors.append(0.7 * errors[-1] + np.random.normal(0, 1))
>>> y_pred = y_true + np.array(errors)
>>>
>>> df = pd.DataFrame({'actual': y_true, 'predicted': y_pred})
>>>
>>> # Generate the plot
>>> ax = plot_error_autocorrelation(
...     df,
...     actual_col='actual',
...     pred_col='predicted',
...     title="Autocorrelation of Dependent Errors"
... )
"""
@isdf
@check_non_emptiness(params=["df"])
def plot_qq(
    df: pd.DataFrame,
    actual_col: str,
    pred_col: str,
    title: Optional[str] = None,
    xlabel: Optional[str] = None,
    ylabel: Optional[str] = None,
    figsize: tuple[float, float] = (7, 7),
    show_grid: bool = True,
    grid_props: Optional[dict[str, Any]] = None,
    savefig: Optional[str] = None,
    dpi: int = 300,
    **scatter_kwargs,
):
    # --- Input Validation and Preparation ---
    required_cols = [actual_col, pred_col]
    exist_features(df, features=required_cols)

    # Errors are computed on rows where both columns are present.
    data_to_plot = df[required_cols].dropna()
    errors = data_to_plot[actual_col] - data_to_plot[pred_col]

    # Bail out before creating a figure when there is nothing to plot.
    if len(errors) < 2:
        warnings.warn(
            "Not enough data points to generate a Q-Q plot.", stacklevel=2
        )
        return None

    # --- Plotting Setup ---
    fig, ax = plt.subplots(figsize=figsize)

    # --- Generate Q-Q Plot ---
    # With ``plot=ax`` SciPy draws two Line2D artists on the axes: first the
    # ordered errors as markers, then the least-squares reference line.
    probplot(errors, dist="norm", plot=ax)
    points, fit_line = ax.get_lines()[:2]

    # --- Formatting ---
    points.set_markerfacecolor("#E74C3C")
    points.set_markeredgecolor("#E74C3C")
    fit_line.set_color("#2980B9")

    # Caller-supplied styling for the data points. The markers are a Line2D,
    # so only keys that have a matching ``set_<key>`` method can be honoured;
    # anything else is reported rather than silently dropped.
    if scatter_kwargs:
        supported = {
            key: value
            for key, value in scatter_kwargs.items()
            if callable(getattr(points, f"set_{key}", None))
        }
        unsupported = sorted(set(scatter_kwargs) - set(supported))
        if unsupported:
            warnings.warn(
                "Ignoring keyword arguments that are not Line2D "
                f"properties: {unsupported}.",
                stacklevel=2,
            )
        # Applied after the defaults so explicit user choices win.
        points.set(**supported)

    ax.set_title(title or "Q-Q Plot of Forecast Errors", fontsize=16)
    ax.set_xlabel(xlabel or "Theoretical Quantiles (Normal)", fontsize=12)
    ax.set_ylabel(ylabel or "Ordered Error Values", fontsize=12)

    set_axis_grid(ax, show_grid=show_grid, grid_props=grid_props)
    fig.tight_layout()

    if savefig:
        plt.savefig(savefig, dpi=dpi, bbox_inches="tight")
        plt.close(fig)
    else:
        plt.show()

    return ax


# NumPy-style docstring attached after the definition; a raw string so the
# LaTeX backslashes in the Notes section stay literal.
plot_qq.__doc__ = r"""
Generates a Quantile-Quantile (Q-Q) plot of forecast errors.

This function creates a Q-Q plot, a standard graphical method for
comparing a dataset's distribution to a theoretical distribution
(in this case, the normal distribution). It is an essential tool
for visually checking if the forecast errors are normally
distributed, a key assumption for many statistical methods.

More details in :ref:`Q-Q Plot User Guide <ug_plot_qq>`.

Parameters
----------
df : pd.DataFrame
    The input DataFrame containing the actual and predicted values.
actual_col : str
    The name of the column containing the true observed values.
pred_col : str
    The name of the column containing the point forecast values.
title : str, optional
    The title for the plot.
xlabel : str, optional
    The label for the x-axis.
ylabel : str, optional
    The label for the y-axis.
figsize : tuple of (float, float), default=(7, 7)
    The figure size in inches.
show_grid : bool, default=True
    Toggle the visibility of the plot's grid lines.
grid_props : dict, optional
    Custom keyword arguments passed to the grid for styling.
savefig : str, optional
    The file path to save the plot. If ``None``, the plot is
    displayed interactively.
dpi : int, default=300
    The resolution (dots per inch) for the saved figure.
**scatter_kwargs
    Additional keyword arguments used to style the data-point
    markers. The markers are drawn as a Matplotlib ``Line2D``, so
    the accepted keys are ``Line2D`` properties (e.g.,
    ``markersize``, ``alpha``, ``markerfacecolor``). They are applied
    after the default colours. Keys that are not ``Line2D``
    properties are ignored with a warning.

Returns
-------
ax : matplotlib.axes.Axes or None
    The Matplotlib Axes object containing the plot. ``None`` is
    returned (after a warning) when fewer than two rows remain once
    rows with missing values are dropped.

See Also
--------
plot_error_distribution : A histogram/KDE plot of the same errors.
scipy.stats.probplot : The underlying SciPy function used.

Notes
-----
A Q-Q plot is constructed by plotting the quantiles of two
distributions against each other. This function compares the
quantiles of the empirical distribution of the forecast errors,
:math:`e_i = y_{true,i} - y_{pred,i}`, against the theoretical
quantiles of a standard normal distribution,
:math:`\mathcal{N}(0, 1)`.

The errors are not standardised before plotting. If they are
normally distributed with mean :math:`\mu` and standard deviation
:math:`\sigma`, the points fall close to the straight line
:math:`y = \mu + \sigma x`. The reference line drawn on the plot is
the least-squares fit to the points, not the identity line.
Systematic deviations from a straight line indicate a departure
from normality.

Examples
--------
>>> import pandas as pd
>>> import numpy as np
>>> from kdiagram.plot.context import plot_qq
>>>
>>> # Generate synthetic data with normally distributed errors
>>> np.random.seed(0)
>>> n_samples = 200
>>> y_true = np.linspace(0, 50, n_samples)
>>> errors = np.random.normal(0, 5, n_samples) # Normal errors
>>> y_pred = y_true + errors
>>>
>>> df = pd.DataFrame({'actual': y_true, 'predicted': y_pred})
>>>
>>> # Generate the Q-Q plot
>>> ax = plot_qq(
...     df,
...     actual_col='actual',
...     pred_col='predicted',
...     title="Q-Q Plot of Normally Distributed Errors"
... )
"""
@isdf
@check_non_emptiness(params=["df"])
def plot_error_distribution(
    df: pd.DataFrame,
    actual_col: str,
    pred_col: str,
    title: Optional[str] = None,
    xlabel: Optional[str] = None,
    **hist_kwargs,
):
    """Plot a histogram and KDE of the forecast errors (actual - predicted)."""
    # Imported here rather than at module level, as in the original code.
    from ..utils.hist import plot_hist_kde

    # --- Forecast errors from complete rows only ---
    columns = [actual_col, pred_col]
    exist_features(df, features=columns)
    clean = df[columns].dropna()
    residuals = clean[actual_col] - clean[pred_col]
    residuals.name = "Forecast Error"  # series name handed to the plot helper

    if len(residuals) < 2:
        warnings.warn(
            "Not enough data points to plot a distribution.", stacklevel=2
        )
        return None

    # --- Delegate the drawing to the general histogram/KDE helper ---
    # Any extra keyword arguments are forwarded unchanged.
    plot_title = title or "Distribution of Forecast Errors"
    axis_label = xlabel or "Error (Actual - Predicted)"
    return plot_hist_kde(
        data=residuals,
        title=plot_title,
        x_label=axis_label,
        return_ax=True,
        **hist_kwargs,
    )
# NumPy-style docstring attached after the definition; it replaces the short
# inline docstring of ``plot_error_distribution``.
plot_error_distribution.__doc__ = r"""
Plots a histogram and KDE of the forecast errors.

This function creates a distribution plot of the forecast errors
(residuals), combining a histogram with a smooth Kernel Density
Estimate (KDE) curve. It is a fundamental diagnostic for checking if
a model's errors are unbiased (centered at zero) and normally
distributed.

For more details, refer to
:ref:`Error Distribution Plot User Guide <ug_plot_error_distribution>`

Parameters
----------
df : pd.DataFrame
    The input DataFrame containing the actual and predicted values.
actual_col : str
    The name of the column containing the true observed values.
pred_col : str
    The name of the column containing the point forecast values.
title : str, optional
    The title for the plot. If ``None``, a default is generated.
xlabel : str, optional
    The label for the x-axis. If ``None``, a default is generated.
**hist_kwargs
    Additional keyword arguments passed directly to the underlying
    :func:`~kdiagram.utils.hist.plot_hist_kde` function (e.g.,
    `bins`, `kde_color`, `figsize`).

Returns
-------
ax : matplotlib.axes.Axes or None
    The Matplotlib Axes object containing the plot. ``None`` is
    returned (after a warning) when fewer than two rows remain once
    rows with missing values are dropped.

See Also
--------
plot_qq : A complementary plot for checking error normality.
plot_hist_kde : The general-purpose histogram utility this function wraps.
:ref:`userguide_context` : The user guide for contextual plots.

Notes
-----
This function first calculates the forecast errors (or residuals),
:math:`e_i = y_{true,i} - y_{pred,i}`. It then visualizes the
distribution of these errors using two standard non-parametric
methods:

1. **Histogram**: The range of errors is divided into bins, and the
   height of each bar represents the frequency (or density) of
   errors in that bin.
2. **Kernel Density Estimate (KDE)**: This provides a smooth,
   continuous estimate of the error's probability density function,
   based on foundational work in density estimation
   :footcite:p:`Silverman1986`.

A well-behaved model should ideally produce errors that are normally
distributed and centered around zero.

Examples
--------
>>> import pandas as pd
>>> import numpy as np
>>> from kdiagram.plot.context import plot_error_distribution
>>>
>>> # Generate synthetic data with normally distributed errors
>>> np.random.seed(0)
>>> n_samples = 500
>>> y_true = np.linspace(0, 50, n_samples)
>>> errors = np.random.normal(0, 5, n_samples) # Normal errors
>>> y_pred = y_true + errors
>>>
>>> df = pd.DataFrame({'actual': y_true, 'predicted': y_pred})
>>>
>>> # Generate the plot
>>> ax = plot_error_distribution(
...     df,
...     actual_col='actual',
...     pred_col='predicted',
...     title="Distribution of Normally-Distributed Errors",
...     bins=40
... )

References
----------
.. footbibliography::
"""
@isdf
@check_non_emptiness(params=["df"])
@ensure_pkg(
    "statsmodels",
    extra=(
        "To use PACF plots, please install"
        " statsmodels (`pip install statsmodels`)"
    ),
)
def plot_error_pacf(
    df: pd.DataFrame,
    actual_col: str,
    pred_col: str,
    title: Optional[str] = None,
    xlabel: Optional[str] = None,
    ylabel: Optional[str] = None,
    figsize: tuple[float, float] = (10, 5),
    show_grid: bool = True,
    grid_props: Optional[dict[str, Any]] = None,
    savefig: Optional[str] = None,
    dpi: int = 300,
    **pacf_kwargs,
):
    # Imported lazily: statsmodels is optional and guarded by ``ensure_pkg``.
    from statsmodels.graphics.tsaplots import plot_pacf

    # --- Input Validation and Preparation ---
    required_cols = [actual_col, pred_col]
    exist_features(df, features=required_cols)

    # Errors are computed on rows where both columns are present.
    data_to_plot = df[required_cols].dropna()
    errors = data_to_plot[actual_col] - data_to_plot[pred_col]

    n = len(errors)
    if n < 2:
        warnings.warn(
            "Not enough data points to plot partial autocorrelation.",
            stacklevel=2,
        )
        return None

    # --- Plotting Setup ---
    fig, ax = plt.subplots(figsize=figsize)

    # --- Generate PACF Plot ---
    pacf_kwargs = get_valid_kwargs(plot_pacf, pacf_kwargs)
    # Use "ywm" as the estimation method unless the caller chose one.
    pacf_kwargs.setdefault("method", "ywm")

    # Ensure lags respects the statsmodels constraint: largest lag < n // 2.
    # NOTE(review): for n < 4, max_lags (1) is not below n // 2 (1), so
    # statsmodels may still reject it. Confirm the intended minimum sample size.
    half = n // 2
    max_lags = max(1, half - 1)
    requested = pacf_kwargs.get("lags")
    if requested is None:
        needs_cap = True
    else:
        # ``lags`` may be a single int or an array-like of lag indices.
        # Comparing the sequence itself with ``>=`` raises for lists and is
        # ambiguous for arrays, so compare the largest requested lag instead.
        lag_values = np.atleast_1d(np.asarray(requested))
        needs_cap = lag_values.size == 0 or bool(lag_values.max() >= half)
    if needs_cap:
        pacf_kwargs["lags"] = max_lags

    try:
        plot_pacf(errors, ax=ax, **pacf_kwargs)
    except ValueError:
        # Fallback once with a safer lags if a user forced something too large
        pacf_kwargs["lags"] = max_lags
        plot_pacf(errors, ax=ax, **pacf_kwargs)

    # --- Formatting ---
    ax.set_title(
        title or "Partial Autocorrelation of Forecast Errors", fontsize=16
    )
    ax.set_xlabel(xlabel or "Lag", fontsize=12)
    ax.set_ylabel(ylabel or "Partial Autocorrelation", fontsize=12)

    set_axis_grid(ax, show_grid=show_grid, grid_props=grid_props)
    fig.tight_layout()

    if savefig:
        plt.savefig(savefig, dpi=dpi, bbox_inches="tight")
        plt.close(fig)
    else:
        plt.show()

    return ax
# NumPy-style docstring attached after the definition; a raw string so the
# ``:math:`` markup is kept verbatim.
plot_error_pacf.__doc__ = r"""
Plots the partial autocorrelation of forecast errors.

This function creates a Partial Autocorrelation Function (PACF) plot
of the forecast errors. It is a critical companion to the ACF plot,
used to identify the *direct* relationship between an error and its
past values, after removing the effects of the intervening lags.

This plot requires the ``statsmodels`` package.

Additional details can be found in
:ref:`Error Partial Autocorrelation (PACF) Plot User Guide <ug_plot_error_pacf>`

Parameters
----------
df : pd.DataFrame
    The input DataFrame containing the actual and predicted values.
actual_col : str
    The name of the column containing the true observed values.
pred_col : str
    The name of the column containing the point forecast values.
title : str, optional
    The title for the plot.
xlabel : str, optional
    The label for the x-axis.
ylabel : str, optional
    The label for the y-axis.
figsize : tuple of (float, float), default=(10, 5)
    The figure size in inches.
show_grid : bool, default=True
    Toggle the visibility of the plot's grid lines.
grid_props : dict, optional
    Custom keyword arguments passed to the grid for styling.
savefig : str, optional
    The file path to save the plot. If ``None``, the plot is
    displayed interactively.
dpi : int, default=300
    The resolution (dots per inch) for the saved figure.
**pacf_kwargs
    Additional keyword arguments forwarded to the underlying
    ``statsmodels.graphics.tsaplots.plot_pacf`` function. Two of
    them receive special treatment:

    - ``method`` defaults to ``"ywm"`` when it is not given.
    - ``lags``, when omitted, ``None``, or not smaller than half the
      number of valid samples ``n``, is replaced by
      ``max(1, n // 2 - 1)``.

Returns
-------
ax : matplotlib.axes.Axes or None
    The Matplotlib Axes object containing the plot. ``None`` is
    returned (after a warning) when fewer than two rows remain once
    rows with missing values are dropped.

See Also
--------
plot_error_autocorrelation : The companion plot for autocorrelation.

Notes
-----
While the ACF at lag :math:`k` shows the total correlation between
:math:`e_t` and :math:`e_{t-k}`, the PACF shows the **partial
correlation**. It measures the correlation between :math:`e_t` and
:math:`e_{t-k}` after removing the linear dependence on the
intermediate observations :math:`e_{t-1}, e_{t-2}, ..., e_{t-k+1}`.

This helps to isolate the direct relationship at a specific lag,
making it a key tool for identifying the order of autoregressive
(AR) processes in the residuals.

Examples
--------
>>> import pandas as pd
>>> import numpy as np
>>> from kdiagram.plot.context import plot_error_pacf
>>>
>>> # Generate synthetic data where errors have an AR(2) structure
>>> np.random.seed(0)
>>> n_samples = 200
>>> y_true = np.linspace(0, 50, n_samples)
>>> errors = np.zeros(n_samples)
>>> errors[0] = np.random.normal(0, 1)
>>> errors[1] = 0.6 * errors[0] + np.random.normal(0, 1)
>>> for t in range(2, n_samples):
...     errors[t] = 0.6 * errors[t-1] - 0.3 * errors[t-2] + np.random.normal(0, 1)
>>> y_pred = y_true - errors # Subtracting so error = actual - pred
>>>
>>> df = pd.DataFrame({'actual': y_true, 'predicted': y_pred})
>>>
>>> # Generate the PACF plot
>>> # The plot should show significant spikes at lags 1 and 2
>>> try:
...     ax = plot_error_pacf(
...         df,
...         actual_col='actual',
...         pred_col='predicted',
...         title="PACF of AR(2) Errors"
...     )
... except ImportError:
...     print("Skipping PACF plot: statsmodels is not installed.")
"""