kdiagram.plot.context.plot_scatter_correlation

kdiagram.plot.context.plot_scatter_correlation(df, actual_col, pred_cols, names=None, title=None, xlabel=None, ylabel=None, figsize=(8, 8), cmap='viridis', s=50, alpha=0.7, show_identity_line=True, show_grid=True, grid_props=None, savefig=None, dpi=300)

Plot true versus predicted values as a scatter plot.

This function creates a classic Cartesian scatter plot to visualize the relationship between true observed values and model predictions. It is an essential tool for assessing linear correlation, identifying systematic bias, and spotting outliers.

For more details, refer to the Scatter Correlation Plot User Guide.

Parameters:
df : pd.DataFrame

The input DataFrame containing the actual and predicted values.

actual_col : str

The name of the column containing the true observed values, which will be plotted on the x-axis.

pred_cols : list of str

A list of one or more column names containing the point forecasts from different models.

names : list of str, optional

Display names for each of the prediction series, to be used in the legend.

title : str, optional

The title for the plot.

xlabel : str, optional

The label for the x-axis.

ylabel : str, optional

The label for the y-axis.

figsize : tuple of (float, float), default=(8, 8)

The figure size in inches.

cmap : str, default='viridis'

The colormap used to assign unique colors to the different prediction series.

s : int, default=50

The size of the scatter plot markers.

alpha : float, default=0.7

The transparency of the markers.

show_identity_line : bool, default=True

If True, draws a dashed y=x line, which represents a perfect forecast.

show_grid : bool, default=True

Toggle the visibility of the plot’s grid lines.

grid_props : dict, optional

Custom keyword arguments passed to the grid for styling.

savefig : str, optional

The file path at which to save the plot. If None, the plot is displayed interactively.

dpi : int, default=300

The resolution (dots per inch) for the saved figure.

Returns:
ax : matplotlib.axes.Axes

The Matplotlib Axes object containing the plot.

See also

plot_relationship

A polar version of this plot.

plot_error_relationship

A plot to diagnose error patterns.

Notes

This plot directly visualizes the relationship between two variables by plotting each observation \(i\) as a point \((y_{true,i}, y_{pred,i})\).

The primary reference is the identity line, defined by the equation:

\[y = x \tag{1}\]

For a perfect forecast, every predicted value would equal its corresponding true value, and all points would fall exactly on this line. Deviations from this line represent prediction errors.
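The deviations described above can also be put into numbers. The following is a minimal sketch, not part of kdiagram: the helper name `identity_line_summary` is hypothetical, and it uses only the standard library. For each point it takes the error \(e_i = y_{pred,i} - y_{true,i}\), then reports the mean error, the root-mean-square error, the Pearson correlation, and the least-squares line through the points. The sample data reuse the biased model from the Examples section (`y_true * 0.8 + 5`) on a coarser grid.

```python
import math


def identity_line_summary(y_true, y_pred):
    """Summarize how far predictions sit from the identity line y = x.

    Hypothetical helper for illustration only; it is not a kdiagram function.
    """
    n = len(y_true)

    # Vertical distance of each point from the identity line.
    errors = [p - t for t, p in zip(y_true, y_pred)]
    mean_error = sum(errors) / n
    rmse = math.sqrt(sum(e * e for e in errors) / n)

    # Pearson correlation and least-squares fit y_pred ~ slope * y_true + intercept.
    mean_t = sum(y_true) / n
    mean_p = sum(y_pred) / n
    cov = sum((t - mean_t) * (p - mean_p) for t, p in zip(y_true, y_pred))
    var_t = sum((t - mean_t) ** 2 for t in y_true)
    var_p = sum((p - mean_p) ** 2 for p in y_pred)
    slope = cov / var_t
    intercept = mean_p - slope * mean_t

    return {
        "mean_error": mean_error,
        "rmse": rmse,
        "pearson_r": cov / math.sqrt(var_t * var_p),
        "slope": slope,
        "intercept": intercept,
    }


# Same construction as the biased model in the Examples section.
y_true = [0.0, 10.0, 20.0, 30.0, 40.0, 50.0]
y_biased = [0.8 * t + 5.0 for t in y_true]

summary = identity_line_summary(y_true, y_biased)
for key, value in summary.items():
    print(f"{key}: {value:.3f}")
```

For this data the correlation is essentially 1 even though no point other than the crossing at 25 lies on the identity line, and the fitted slope and intercept recover the 0.8 and 5 used to build the series. This is why the plot is read against the y = x line rather than judged by correlation alone.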

Examples

>>> import pandas as pd
>>> import numpy as np
>>> from kdiagram.plot.context import plot_scatter_correlation
>>>
>>> # Generate synthetic data
>>> np.random.seed(0)
>>> n_samples = 100
>>> y_true = np.linspace(0, 50, n_samples)
>>> y_pred_good = y_true + np.random.normal(0, 3, n_samples)
>>> y_pred_biased = y_true * 0.8 + 5
>>>
>>> df = pd.DataFrame({
...     'actual': y_true,
...     'good_model': y_pred_good,
...     'biased_model': y_pred_biased,
... })
>>>
>>> # Generate the plot
>>> ax = plot_scatter_correlation(
...     df,
...     actual_col='actual',
...     pred_cols=['good_model', 'biased_model'],
...     names=['Good Model', 'Biased Model'],
...     title="Actual vs. Predicted Correlation"
... )