Source code for bluemath_tk.core.plotting.scatter

from typing import List, Optional, Tuple

import numpy as np
import pandas as pd
from matplotlib.axes import Axes
from matplotlib.figure import Figure
from scipy.stats import gaussian_kde, probplot
from sklearn.metrics import mean_squared_error

from .base_plotting import DefaultStaticPlotting
from .colors import default_colors


[docs] def rmse(pred: np.ndarray, tar: np.ndarray) -> float: """ Calculate the Root Mean Square Error between predicted and target values. Parameters ---------- pred : np.ndarray Array of predicted values. tar : np.ndarray Array of target/actual values. Returns ------- float The Root Mean Square Error value. """ if len(pred) != len(tar): raise ValueError("pred and tar must have the same length") return np.sqrt(((pred - tar) ** 2).mean())
[docs] def bias(pred: np.ndarray, tar: np.ndarray) -> float: """ Calculate the bias between predicted and target values. Parameters ---------- pred : np.ndarray Array of predicted values. tar : np.ndarray Array of target/actual values. Returns ------- float The bias value (mean difference between predictions and targets). """ if len(pred) != len(tar): raise ValueError("pred and tar must have the same length") return sum(pred - tar) / len(pred)
[docs] def si(pred: np.ndarray, tar: np.ndarray) -> float: """ Calculate the Scatter Index between predicted and target values. Parameters ---------- pred : np.ndarray Array of predicted values. tar : np.ndarray Array of target/actual values. Returns ------- float The Scatter Index value. """ if len(pred) != len(tar): raise ValueError("pred and tar must have the same length") pred_mean = pred.mean() tar_mean = tar.mean() return np.sqrt(sum(((pred - pred_mean) - (tar - tar_mean)) ** 2) / (sum(tar**2)))
[docs] def density_scatter( x: np.ndarray, y: np.ndarray ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: """ Compute a density scatter for two arrays using gaussian KDE. Parameters ---------- x : np.ndarray X values for the scatter plot. y : np.ndarray Y values for the scatter plot. Returns ------- Tuple[np.ndarray, np.ndarray, np.ndarray] A tuple containing: - Sorted x values - Sorted y values - Density values corresponding to each point """ if len(x) != len(y): raise ValueError("x and y must have the same length") xy = np.vstack([x, y]) z = gaussian_kde(xy)(xy) idx = z.argsort() x1, y1, z = x[idx], y[idx], z[idx] return x1, y1, z
[docs] def validation_scatter( axs: Axes, x: np.ndarray, y: np.ndarray, xlabel: str, ylabel: str, title: str, cmap: str = "rainbow", ) -> None: """ Plot a density scatter and Q-Q plot for validation. Parameters ---------- axs : Axes Matplotlib axes to plot on. x : np.ndarray X values for the scatter plot. y : np.ndarray Y values for the scatter plot. xlabel : str Label for the X-axis. ylabel : str Label for the Y-axis. title : str Title for the plot. cmap : str, optional Colormap to use for the scatter plot. Default is "rainbow". """ x2, y2, z = density_scatter(x, y) # plot axs.scatter(x2, y2, c=z, s=5, cmap=cmap) # labels axs.set_xlabel(xlabel) axs.set_ylabel(ylabel) axs.set_title(title) # axis limits maxt = np.ceil(max(max(x) + 0.1, max(y) + 0.1)) axs.set_xlim(0, maxt) axs.set_ylim(0, maxt) axs.plot([0, maxt], [0, maxt], "-r") axs.set_xticks(np.linspace(0, maxt, 5)) axs.set_yticks(np.linspace(0, maxt, 5)) axs.set_aspect("equal") # qq-plot xq = probplot(x, dist="norm") yq = probplot(y, dist="norm") axs.plot(xq[0][1], yq[0][1], "o", markersize=0.5, color="k", label="Q-Q plot") # diagnostic errors props = dict( boxstyle="round", facecolor="w", edgecolor="grey", linewidth=0.8, alpha=0.5 ) mse = mean_squared_error(x2, y2) rmse_e = rmse(x2, y2) BIAS = bias(x2, y2) SI = si(x2, y2) label = "\n".join( ( r"RMSE = %.2f" % (rmse_e,), r"MSE = %.2f" % (mse,), r"BIAS = %.2f" % (BIAS,), R"SI = %.2f" % (SI,), ) ) axs.text( 0.05, 0.95, label, transform=axs.transAxes, fontsize=9, verticalalignment="top", bbox=props, )
[docs] def plot_scatters_in_triangle( dataframes: List[pd.DataFrame], data_colors: Optional[List[str]] = None, **kwargs, ) -> Tuple[Figure, np.ndarray]: """ Plot scatter plots of the dataframes with axes in a triangle arrangement. Parameters ---------- dataframes : List[pd.DataFrame] List of dataframes to plot. Each dataframe should contain the same columns. data_colors : Optional[List[str]], optional List of colors for the dataframes. If None, uses default_colors. **kwargs : dict, optional Additional keyword arguments for the scatter plot. These will be passed to matplotlib.pyplot.scatter. Common parameters include: - s : float, marker size - alpha : float, transparency - marker : str, marker style Returns ------- Tuple[Figure, np.ndarray] A tuple containing: - Figure object - 2D array of Axes objects Raises ------ ValueError If the variables in the first dataframe are not present in all other dataframes. """ if data_colors is None: data_colors = default_colors # Get the number and names of variables from the first dataframe variables_names = list(dataframes[0].columns) num_variables = len(variables_names) # Check variables names are in all dataframes for df in dataframes: if not all(v in df.columns for v in variables_names): raise ValueError( f"Variables {variables_names} are not in dataframe {df.columns}." ) # Create figure and axes default_static_plot = DefaultStaticPlotting() fig, axes = default_static_plot.get_subplots( nrows=num_variables - 1, ncols=num_variables - 1, sharex=False, sharey=False, ) if isinstance(axes, Axes): axes = np.array([[axes]]) for c1, v1 in enumerate(variables_names[1:]): for c2, v2 in enumerate(variables_names[:-1]): for idf, df in enumerate(dataframes): default_static_plot.plot_scatter( ax=axes[c2, c1], x=df[v1], y=df[v2], c=data_colors[idf], alpha=0.6, **kwargs, ) if c1 == c2: axes[c2, c1].set_xlabel(variables_names[c1 + 1]) axes[c2, c1].set_ylabel(variables_names[c2]) elif c1 > c2: axes[c2, c1].xaxis.set_ticklabels([]) axes[c2, c1].yaxis.set_ticklabels([]) else: fig.delaxes(axes[c2, c1]) return fig, axes