Source code for bluemath_tk.waves.calibration

from typing import Tuple, Union

import cartopy.crs as ccrs
import cartopy.feature as cfeature
import matplotlib as mpl
import numpy as np
import pandas as pd
import statsmodels.api as sm
import xarray as xr
from matplotlib import pyplot as plt
from matplotlib.axes import Axes
from matplotlib.figure import Figure

from ..core.decorators import validate_data_calval
from ..core.models import BlueMathModel
from ..core.plotting.scatter import density_scatter, validation_scatter


def get_matching_times_between_arrays(
    times1: np.ndarray,
    times2: np.ndarray,
    max_time_diff: int,
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Finds matching time indices between two arrays of timestamps.

    For each time in `times1`, finds the closest time in `times2` that is
    within `max_time_diff` hours. Returns the indices of matching times in
    both arrays.

    Parameters
    ----------
    times1 : np.ndarray
        First array of timestamps (reference times, e.g., from model data).
    times2 : np.ndarray
        Second array of timestamps (e.g., from satellite or validation data).
    max_time_diff : int
        Maximum time difference in hours for considering times as matching.

    Returns
    -------
    Tuple[np.ndarray, np.ndarray]
        Two arrays containing the indices of matching times:
        - First array: indices in times1 that have matches
        - Second array: corresponding indices in times2 that match

    Example
    -------
    >>> idx1, idx2 = get_matching_times_between_arrays(
    ...     model_df.index.values,
    ...     sat_df.index.values,
    ...     max_time_diff=2,
    ... )
    """

    indices1 = np.array([], dtype=int)
    indices2 = np.array([], dtype=int)

    for i in range(len(times1)):
        # Find minimum time difference for current time1
        time_diffs = np.abs(times2 - times1[i])
        min_diff = np.min(time_diffs)
        # If minimum difference is within threshold, record the indices
        if min_diff < np.timedelta64(max_time_diff, "h"):
            min_index = np.argmin(time_diffs)
            indices1 = np.append(indices1, i)
            indices2 = np.append(indices2, min_index)

    return indices1, indices2
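
# Illustrative sketch (editorial addition, not part of the original module):
# how the time matching behaves with synthetic hourly model times and two
# sparse "satellite" passes. The variable names below are hypothetical.
#
#     model_times = np.arange("2020-01-01", "2020-01-02", dtype="datetime64[h]")
#     sat_times = model_times[[3, 10]] + np.timedelta64(30, "m")
#     idx1, idx2 = get_matching_times_between_arrays(model_times, sat_times, 2)
#     # idx1 -> indices of model times within 2 h of a satellite pass
#     # idx2 -> the corresponding indices into sat_times
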
def process_imos_satellite_data(
    satellite_df: pd.DataFrame,
    ini_lat: float,
    end_lat: float,
    ini_lon: float,
    end_lon: float,
    depth_threshold: float = -200,
) -> pd.DataFrame:
    """
    Processes IMOS satellite data for calibration.

    This function filters and processes IMOS satellite altimeter data to be
    used as reference data for calibration (e.g., as `data_to_calibrate` in
    CalVal.fit).

    Parameters
    ----------
    satellite_df : pd.DataFrame
        IMOS satellite data. Must contain columns:
        - 'LATITUDE' (float): Latitude in decimal degrees
        - 'LONGITUDE' (float): Longitude in decimal degrees
        - 'SWH_KU_quality_control' (float): Quality control flag for Ku-band
        - 'SWH_KA_quality_control' (float): Quality control flag for Ka-band
        - 'SWH_KU_CAL' (float): Calibrated significant wave height (Ku-band)
        - 'SWH_KA_CAL' (float): Calibrated significant wave height (Ka-band)
        - 'BOT_DEPTH' (float): Bathymetry (negative values for ocean)
    ini_lat : float
        Minimum latitude (southern boundary) for filtering.
    end_lat : float
        Maximum latitude (northern boundary) for filtering.
    ini_lon : float
        Minimum longitude (western boundary) for filtering.
    end_lon : float
        Maximum longitude (eastern boundary) for filtering.
    depth_threshold : float, optional
        Only include points with BOT_DEPTH < depth_threshold. Default is -200.

    Returns
    -------
    pd.DataFrame
        Filtered and processed satellite data, suitable for use as
        `data_to_calibrate` in CalVal.fit. Includes a new column 'Hs_CAL'
        (combination of Ku-band and Ka-band calibrated significant wave
        heights).

    Notes
    -----
    - The returned DataFrame can be used directly as the `data_to_calibrate`
      argument in CalVal.fit.
    """

    # Filter satellite data by coordinates
    satellite_df = satellite_df[
        (satellite_df.LATITUDE > ini_lat)
        & (satellite_df.LATITUDE < end_lat)
        & (satellite_df.LONGITUDE > ini_lon)
        & (satellite_df.LONGITUDE < end_lon)
        & (satellite_df.BOT_DEPTH < depth_threshold)
    ]

    # Process quality control
    wave_height_qlt = np.nansum(
        np.concatenate(
            (
                satellite_df["SWH_KU_quality_control"].values[:, np.newaxis],
                satellite_df["SWH_KA_quality_control"].values[:, np.newaxis],
            ),
            axis=1,
        ),
        axis=1,
    )
    good_qlt = np.where(wave_height_qlt < 1.5)

    # Process wave heights
    satellite_df["Hs_CAL"] = np.nansum(
        np.concatenate(
            (
                satellite_df["SWH_KU_CAL"].values[:, np.newaxis],
                satellite_df["SWH_KA_CAL"].values[:, np.newaxis],
            ),
            axis=1,
        ),
        axis=1,
    )

    return satellite_df.iloc[good_qlt]
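
# Illustrative sketch (editorial addition, not part of the original module):
# typical preparation of IMOS altimeter data before calibration. The file
# name and bounding box are hypothetical; only the columns listed in the
# docstring above are required.
#
#     sat_raw = pd.read_csv(
#         "imos_altimeter_subset.csv", parse_dates=["time"], index_col="time"
#     )
#     sat_cal = process_imos_satellite_data(
#         sat_raw, ini_lat=-45, end_lat=-40, ini_lon=145, end_lon=150,
#         depth_threshold=-200,
#     )
#     # sat_cal keeps only deep-water, good-QC points and adds 'Hs_CAL'
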
class CalVal(BlueMathModel):
    """
    Calibrates wave data using reference data.

    This class provides a framework for calibrating wave model outputs (e.g.,
    hindcast or reanalysis) using reference data (e.g., satellite or buoy
    observations). It supports directionally-dependent calibration for both
    sea and swell components.

    Attributes
    ----------
    direction_bin_size : float
        Size of directional bins in degrees.
    direction_bins : np.ndarray
        Array of bin edges for directions.
    calibration_model : sm.OLS
        The calibration model, more details in `statsmodels.api.OLS`.
    calibrated_data : pd.DataFrame
        DataFrame with columns ['Hs', 'Hs_CORR', 'Hs_CAL'] after calibration.
        The time domain is the same as the model data.
    calibration_params : dict
        Dictionary with 'sea_correction' and 'swell_correction' correction
        coefficients.
    """

    direction_bin_size: float = 22.5
    direction_bins: np.ndarray = np.arange(
        direction_bin_size, 360.5, direction_bin_size
    )

    def __init__(self) -> None:
        """
        Initialize the CalVal class.
        """

        super().__init__()
        self.set_logger_name(name="CalVal", level="INFO", console=True)

        # Save input data
        self._data: pd.DataFrame = None
        self._data_longitude: float = None
        self._data_latitude: float = None
        self._data_to_calibrate: pd.DataFrame = None
        self._max_time_diff: int = None

        # Initialize calibration results
        self._data_to_fit: Tuple[pd.DataFrame, pd.DataFrame] = (None, None)
        self._calibration_model: sm.OLS = None
        self._calibrated_data: pd.DataFrame = None
        self._calibration_params: pd.Series = None

        # Exclude large attributes from model saving
        self._exclude_attributes += [
            "_data",
            "_data_to_calibrate",
            "_data_to_fit",
        ]

    @property
    def calibration_model(self) -> sm.OLS:
        """Returns the calibration model."""

        if self._calibration_model is None:
            raise ValueError(
                "Calibration model is not available. Please run the fit method first."
            )
        return self._calibration_model

    @property
    def calibrated_data(self) -> pd.DataFrame:
        """Returns the calibrated data."""

        if self._calibrated_data is None:
            raise ValueError(
                "Calibrated data is not available. Please run the fit method first."
            )
        return self._calibrated_data

    @property
    def calibration_params(self) -> pd.Series:
        """Returns the calibration parameters."""

        if self._calibration_params is None:
            raise ValueError(
                "Calibration parameters are not available. Please run the fit method first."
            )
        return self._calibration_params

    def _plot_data_domains(self) -> Tuple[Figure, Axes]:
        """
        Plots the domains of the data points.

        Returns
        -------
        Tuple[Figure, Axes]
            A tuple containing the figure and axes objects.
        """

        fig, ax = plt.subplots(
            figsize=(10, 10),
            subplot_kw={
                "projection": ccrs.PlateCarree(central_longitude=self._data_longitude)
            },
        )
        land_10m = cfeature.NaturalEarthFeature(
            "physical",
            "land",
            "10m",
            edgecolor="face",
            facecolor=cfeature.COLORS["land"],
        )

        # Plot calibration data
        ax.scatter(
            self._data_to_calibrate.LONGITUDE,
            self._data_to_calibrate.LATITUDE,
            s=0.01,
            c="k",
            transform=ccrs.PlateCarree(),
        )
        # Plot main data point
        ax.scatter(
            self._data_longitude,
            self._data_latitude,
            s=50,
            c="red",
            zorder=10,
            transform=ccrs.PlateCarree(),
        )

        # Set plot extent
        ax.set_extent(
            [
                self._data_longitude - 2,
                self._data_longitude + 2,
                self._data_latitude - 2,
                self._data_latitude + 2,
            ]
        )
        ax.set_facecolor("lightblue")
        ax.add_feature(land_10m)

        return fig, ax

    def _create_vec_direc(self, waves: np.ndarray, direcs: np.ndarray) -> np.ndarray:
        """
        Creates a vector of wave heights for each directional bin.

        Parameters
        ----------
        waves : np.ndarray
            Wave heights.
        direcs : np.ndarray
            Wave directions in degrees.

        Returns
        -------
        np.ndarray
            Matrix of wave heights for each directional bin.
        """

        data = np.zeros((len(waves), len(self.direction_bins)))
        for i in range(len(waves)):
            if direcs[i] < 0:
                direcs[i] = direcs[i] + 360
            if direcs[i] > 0 and waves[i] > 0:
                # Handle direction = 360° case by mapping to the first bin (0-22.5°)
                if direcs[i] >= 360:
                    bin_idx = 0
                else:
                    bin_idx = int(direcs[i] / self.direction_bin_size)
                data[i, bin_idx] = waves[i]

        return data

    @staticmethod
    def _get_nparts(data: pd.DataFrame) -> int:
        """
        Gets the number of parts in the wave data.

        Parameters
        ----------
        data : pd.DataFrame
            Wave data.

        Returns
        -------
        int
            The number of parts in the wave data.
        """

        return len([col for col in data.columns if col.startswith("Hswell")])

    def _get_joined_sea_swell_data(self, data: pd.DataFrame) -> np.ndarray:
        """
        Joins the sea and swell data.

        Parameters
        ----------
        data : pd.DataFrame
            Wave data.

        Returns
        -------
        np.ndarray
            The joined sea and swell matrix.
        """

        # Process sea waves
        Hsea = self._create_vec_direc(data["Hsea"], data["Dirsea"]) ** 2

        # Process swells
        Hs_swells = np.zeros(Hsea.shape)
        for part in range(1, self._get_nparts(data) + 1):
            Hs_swells += (
                self._create_vec_direc(data[f"Hswell{part}"], data[f"Dirswell{part}"])
            ) ** 2

        # Combine sea and swell matrices
        sea_swell_matrix = np.concatenate([Hsea, Hs_swells], axis=1)

        return sea_swell_matrix
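
    # Note on the directional binning above (editorial illustration, not part
    # of the original source): with direction_bin_size = 22.5 there are 16
    # bins, and a direction d (in degrees) falls into column int(d / 22.5),
    # e.g. d = 100 -> column 4; d >= 360 wraps to column 0. As a result,
    # _get_joined_sea_swell_data returns an (n_times, 32) matrix of squared
    # heights: 16 sea columns followed by 16 aggregated swell columns.
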
    @validate_data_calval
    def fit(
        self,
        data: pd.DataFrame,
        data_longitude: float,
        data_latitude: float,
        data_to_calibrate: pd.DataFrame,
        max_time_diff: int = 2,
    ) -> None:
        """
        Calibrate the model data using reference (calibration) data.

        This method matches the model data and calibration data in time,
        constructs directionally-binned sea and swell matrices, and fits a
        linear regression to obtain correction coefficients for each direction
        bin.

        Parameters
        ----------
        data : pd.DataFrame
            Model data to calibrate. Must contain columns:
            - 'Hs' (float): Significant wave height
            - 'Hsea' (float): Sea component significant wave height
            - 'Dirsea' (float): Sea component mean direction (degrees)
            - 'Hswell1', 'Dirswell1', ... (float): Swell components (at least one required)
            The index must be datetime-like.
        data_longitude : float
            Longitude of the model location (used for plotting and filtering).
        data_latitude : float
            Latitude of the model location (used for plotting and filtering).
        data_to_calibrate : pd.DataFrame
            Reference data for calibration. Must contain column:
            - 'Hs_CAL' (float): Calibrated significant wave height (e.g., from satellite)
            The index must be datetime-like.
        max_time_diff : int, optional
            Maximum time difference (in hours) allowed when matching model and
            calibration data. Default is 2.

        Notes
        -----
        - After calling this method, the calibration parameters are stored in
          `self.calibration_params` and the calibrated data is available in
          `self.calibrated_data`.
        - The calibration is directionally dependent, meaning it uses different
          correction coefficients for different wave directions.
        - Coefficients with p-values greater than 0.05 or negative values are
          set to 1.0, indicating no correction is applied for those directions.
        """

        self.logger.info("Starting calibration fit procedure.")

        # Save input data
        self._data = data.copy()
        self._data_longitude = data_longitude
        self._data_latitude = data_latitude
        self._data_to_calibrate = data_to_calibrate.copy()
        self._max_time_diff = max_time_diff

        # Plot data domains
        self.logger.info("Plotting data domains.")
        self._plot_data_domains()

        # Construct matrices for calibration
        self.logger.info("Matching times and constructing matrices for calibration.")
        # Get matching times
        times_data_to_fit, times_data_to_calibrate = get_matching_times_between_arrays(
            self._data.index.values,
            self._data_to_calibrate.index.values,
            max_time_diff=self._max_time_diff,
        )
        self._data_to_fit = (
            self._data.iloc[times_data_to_fit],
            self._data_to_calibrate.iloc[times_data_to_calibrate],
        )
        # Get joined sea and swell data
        sea_swell_matrix = self._get_joined_sea_swell_data(self._data_to_fit[0])

        # Perform calibration
        self.logger.info("Fitting OLS regression for calibration.")
        X = sm.add_constant(sea_swell_matrix)
        self._calibration_model = sm.OLS(self._data_to_fit[1]["Hs_CAL"] ** 2, X)
        calibrated_model_results = self._calibration_model.fit()

        # Get significant correction coefficients
        significant_model_params = [
            model_param
            if calibrated_model_results.pvalues[imp] < 0.05 and model_param > 0
            else 1.0
            for imp, model_param in enumerate(calibrated_model_results.params)
        ]

        # Save sea and swell correction coefficients
        self._calibration_params = {
            "sea_correction": {
                ip: param
                for ip, param in enumerate(
                    np.sqrt(significant_model_params[: len(self.direction_bins)])
                )
            },
            "swell_correction": {
                ip: param
                for ip, param in enumerate(
                    np.sqrt(significant_model_params[len(self.direction_bins) :])
                )
            },
        }

        # Save calibrated data to be used in plot_calibration_results()
        self._calibrated_data = self.correct(self._data_to_fit[0])
        self._calibrated_data["Hs_CAL"] = self._data_to_fit[1]["Hs_CAL"].values

        self.logger.info("Calibration fit procedure completed.")
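
    # Illustrative fit call (editorial sketch, not part of the original
    # source). `hindcast_df` and `sat_cal` are hypothetical DataFrames with
    # the columns documented above and datetime-like indices:
    #
    #     calval = CalVal()
    #     calval.fit(
    #         data=hindcast_df,
    #         data_longitude=146.0,
    #         data_latitude=-42.5,
    #         data_to_calibrate=sat_cal,
    #         max_time_diff=2,
    #     )
    #     calval.calibration_params["sea_correction"]  # bin index -> factor
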
    def correct(
        self, data: Union[pd.DataFrame, xr.Dataset]
    ) -> Union[pd.DataFrame, xr.Dataset]:
        """
        Apply the calibration correction to new data.

        Parameters
        ----------
        data : pd.DataFrame or xr.Dataset
            Data to correct. If DataFrame, must contain columns:
            - 'Hs', 'Hsea', 'Dirsea', 'Hswell1', 'Dirswell1', ...
            If xarray.Dataset, must have variable 'efth' and dimension 'part'.

        Returns
        -------
        pd.DataFrame or xr.Dataset
            Corrected data. For DataFrame, returns columns ['Hs', 'Hs_CORR']
            (original and corrected SWH). For Dataset, adds variables
            'corr_coeffs' and 'corr_efth'.

        Notes
        -----
        - The correction is directionally dependent and uses the coefficients
          obtained from `fit`.
        """

        if self._calibration_params is None:
            raise ValueError(
                "Calibration parameters are not available. Run fit() first."
            )

        if isinstance(data, xr.Dataset):
            self.logger.info(
                "Input is xarray.Dataset. Applying correction to spectra data."
            )
            corrected_data = data.copy()  # Copy data to avoid modifying original data
            peak_directions = corrected_data.spec.stats(["dp"]).load()
            correction_coeffs = np.ones(peak_directions.dp.shape)
            for n_part in peak_directions.part:
                if n_part == 0:
                    correction_coeffs[n_part, :] = np.array(
                        [
                            self.calibration_params["sea_correction"][
                                int(peak_direction / self.direction_bin_size)
                                if peak_direction < 360
                                else 0  # TODO: Check this with Javi
                            ]
                            for peak_direction in peak_directions.isel(
                                part=n_part
                            ).dp.values
                        ]
                    )
                else:
                    correction_coeffs[n_part, :] = np.array(
                        [
                            self.calibration_params["swell_correction"][
                                int(peak_direction / self.direction_bin_size)
                                if peak_direction < 360
                                else 0  # TODO: Check this with Javi
                            ]
                            for peak_direction in peak_directions.isel(
                                part=n_part
                            ).dp.values
                        ]
                    )
            corrected_data["corr_coeffs"] = (("part", "time"), correction_coeffs)
            corrected_data["corr_efth"] = (
                corrected_data.efth * corrected_data.corr_coeffs
            )
            self.logger.info("Spectra correction complete.")

            return corrected_data

        elif isinstance(data, pd.DataFrame):
            self.logger.info(
                "Input is pandas.DataFrame. Applying correction to wave data."
            )
            corrected_data = data.copy()
            corrected_data["Hsea"] = (
                corrected_data["Hsea"] ** 2
                * np.array(
                    [
                        self.calibration_params["sea_correction"][
                            int(peak_direction / self.direction_bin_size)
                            if peak_direction < 360
                            else 0
                        ]
                        for peak_direction in corrected_data["Dirsea"]
                    ]
                )
                ** 2
            )
            corrected_data["Hs_CORR"] = corrected_data["Hsea"]
            for n_part in range(1, self._get_nparts(corrected_data) + 1):
                corrected_data[f"Hswell{n_part}"] = (
                    corrected_data[f"Hswell{n_part}"] ** 2
                    * np.array(
                        [
                            self.calibration_params["swell_correction"][
                                int(peak_direction / self.direction_bin_size)
                                if peak_direction < 360
                                else 0
                            ]
                            for peak_direction in corrected_data[f"Dirswell{n_part}"]
                        ]
                    )
                    ** 2
                )
                corrected_data["Hs_CORR"] += corrected_data[f"Hswell{n_part}"]
            corrected_data["Hs_CORR"] = np.sqrt(corrected_data["Hs_CORR"])
            self.logger.info("Wave data correction complete.")

            return corrected_data[["Hs", "Hs_CORR"]]
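
    # Illustrative correction call (editorial sketch, not part of the
    # original source). After fit(), model output with the same column layout
    # can be corrected directly; `new_hindcast_df` is hypothetical:
    #
    #     corrected = calval.correct(new_hindcast_df)
    #     corrected[["Hs", "Hs_CORR"]].describe()
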
    def plot_calibration_results(self) -> Tuple[Figure, list]:
        """
        Plot the calibration results, including:

        - Pie charts of correction coefficients for sea and swell
        - Scatter plots of model vs. reference (before and after correction)
        - Polar density plots of sea and swell wave climate

        Returns
        -------
        Tuple[Figure, list]
            The matplotlib Figure and a list of Axes objects for all subplots.
        """

        self.logger.info("Plotting calibration results.")

        fig = plt.figure(figsize=(10, 15))
        gs = fig.add_gridspec(8, 2, wspace=0.4, hspace=0.7)

        # Create subplots with proper projections
        ax1 = fig.add_subplot(gs[:2, 0])  # Sea correction pie
        ax2 = fig.add_subplot(gs[:2, 1])  # Swell correction pie
        ax1_cbar = fig.add_subplot(gs[2, 0])  # Sea correction colorbar
        ax2_cbar = fig.add_subplot(gs[2, 1])  # Swell correction colorbar
        ax3 = fig.add_subplot(gs[3:5, 0])  # No correction scatter
        ax4 = fig.add_subplot(gs[3:5, 1])  # With correction scatter
        ax5 = fig.add_subplot(gs[6:, 0], projection="polar")  # Sea climate
        ax6 = fig.add_subplot(gs[6:, 1], projection="polar")  # Swell climate

        # Plot sea correction pie chart
        sea_norm = 0.35  # Normalization factor for sea correction
        sea_fracs = np.repeat(10, len(self.calibration_params["sea_correction"]))
        sea_norm = mpl.colors.Normalize(1 - sea_norm, 1 + sea_norm)
        sea_cmap = mpl.cm.get_cmap(
            "bwr", len(self.calibration_params["sea_correction"])
        )
        sea_colors = sea_cmap(
            sea_norm(list(self.calibration_params["sea_correction"].values()))
        )
        ax1.pie(
            sea_fracs,
            labels=None,
            colors=sea_colors,
            startangle=90,
            counterclock=False,
            radius=1.2,
        )
        ax1.set_title("SEA $Correction$", fontweight="bold")
        # Add colorbar for sea correction below the pie chart, shrink it
        _sea_cbar = mpl.colorbar.ColorbarBase(
            ax1_cbar,
            cmap=sea_cmap,
            norm=sea_norm,
            orientation="horizontal",
            label="Correction Factor",
        )
        box = ax1_cbar.get_position()
        ax1_cbar.set_position(
            [
                box.x0 + 0.15 * box.width,
                box.y0 + 0.3 * box.height,
                0.7 * box.width,
                0.4 * box.height,
            ]
        )
        ax1_cbar.set_frame_on(False)
        ax1_cbar.tick_params(
            left=False, right=False, labelleft=False, labelbottom=True, bottom=True
        )

        # Plot swell correction pie chart
        swell_norm = 0.35  # Normalization factor for swell correction
        swell_fracs = np.repeat(10, len(self.calibration_params["swell_correction"]))
        swell_norm = mpl.colors.Normalize(1 - swell_norm, 1 + swell_norm)
        swell_cmap = mpl.cm.get_cmap(
            "bwr", len(self.calibration_params["swell_correction"])
        )
        swell_colors = swell_cmap(
            swell_norm(list(self.calibration_params["swell_correction"].values()))
        )
        ax2.pie(
            swell_fracs,
            labels=None,
            colors=swell_colors,
            startangle=90,
            counterclock=False,
            radius=1.2,
        )
        ax2.set_title("SWELL $Correction$", fontweight="bold")
        # Add colorbar for swell correction below the pie chart, shrink it
        _swell_cbar = mpl.colorbar.ColorbarBase(
            ax2_cbar,
            cmap=swell_cmap,
            norm=swell_norm,
            orientation="horizontal",
            label="Correction Factor",
        )
        box = ax2_cbar.get_position()
        ax2_cbar.set_position(
            [
                box.x0 + 0.15 * box.width,
                box.y0 + 0.3 * box.height,
                0.7 * box.width,
                0.4 * box.height,
            ]
        )
        ax2_cbar.set_frame_on(False)
        ax2_cbar.tick_params(
            left=False, right=False, labelleft=False, labelbottom=True, bottom=True
        )

        # Plot no correction scatter
        validation_scatter(
            axs=ax3,
            x=self.calibrated_data["Hs"].values,
            y=self.calibrated_data["Hs_CAL"].values,
            xlabel="Hindcast",
            ylabel="Satellite",
            title="No Correction",
        )
        # Plot with correction scatter
        validation_scatter(
            axs=ax4,
            x=self.calibrated_data["Hs_CORR"].values,
            y=self.calibrated_data["Hs_CAL"].values,
            xlabel="Hindcast",
            ylabel="Satellite",
            title="With Correction",
        )

        # Plot sea wave climate
        sea_dirs = self._data["Dirsea"].iloc[::10] * np.pi / 180
        sea_heights = self._data["Hsea"].iloc[::10]
        # Filter out NaN and infinite values
        valid_mask = np.isfinite(sea_dirs) & np.isfinite(sea_heights)
        sea_dirs_valid = sea_dirs[valid_mask]
        sea_heights_valid = sea_heights[valid_mask]
        if len(sea_dirs_valid) > 0:
            x, y, z = density_scatter(sea_dirs_valid, sea_heights_valid)
            ax5.scatter(x, y, c=z, s=3, cmap="jet")
        ax5.set_theta_zero_location("N", offset=0)
        ax5.set_xticklabels(["N", "NE", "E", "SE", "S", "SW", "W", "NW"])
        ax5.xaxis.grid(True, color="lavender", linestyle="-")
        ax5.yaxis.grid(True, color="lavender", linestyle="-")
        ax5.set_theta_direction(-1)
        ax5.set_xlabel("$\u03b8_{m}$ ($\degree$)")
        ax5.set_ylabel("$H_{s}$ (m)", labelpad=20)
        ax5.set_title("SEA $Wave$ $Climate$", pad=35, fontweight="bold")

        # Plot swell wave climate
        swell_dirs = self._data["Dirswell1"].iloc[::10] * np.pi / 180
        swell_heights = self._data["Hswell1"].iloc[::10]
        # Filter out NaN and infinite values
        valid_mask = np.isfinite(swell_dirs) & np.isfinite(swell_heights)
        swell_dirs_valid = swell_dirs[valid_mask]
        swell_heights_valid = swell_heights[valid_mask]
        if len(swell_dirs_valid) > 0:
            x, y, z = density_scatter(swell_dirs_valid, swell_heights_valid)
            ax6.scatter(x, y, c=z, s=3, cmap="jet")
        ax6.set_theta_zero_location("N", offset=0)
        ax6.set_xticklabels(["N", "NE", "E", "SE", "S", "SW", "W", "NW"])
        ax6.xaxis.grid(True, color="lavender", linestyle="-")
        ax6.yaxis.grid(True, color="lavender", linestyle="-")
        ax6.set_theta_direction(-1)
        ax6.set_xlabel("$\u03b8_{m}$ ($\degree$)")
        ax6.set_ylabel("$H_{s}$ (m)", labelpad=20)
        ax6.set_title("SWELL 1 $Wave$ $Climate$", pad=35, fontweight="bold")

        return fig, [ax1, ax2, ax1_cbar, ax2_cbar, ax3, ax4, ax5, ax6]
    def validate_calibration(
        self, data_to_validate: pd.DataFrame
    ) -> Tuple[Figure, list]:
        """
        Validate the calibration using independent validation data.

        This method compares the original and corrected model data to the
        validation data, both as time series and with scatter plots.

        Parameters
        ----------
        data_to_validate : pd.DataFrame
            Validation data. Must contain column:
            - 'Hs_VAL' (float): Validation significant wave height (e.g., from buoy)
            The index must be datetime-like.

        Returns
        -------
        Tuple[Figure, list]
            The matplotlib Figure and a list of Axes objects:
            [time series axis, scatter (no correction), scatter (corrected)].
        """

        if "Hs_VAL" not in data_to_validate.columns:
            raise ValueError("Validation data is missing required column: 'Hs_VAL'")

        data_corr = self.correct(data=self._data)
        data_times, data_to_validate_times = get_matching_times_between_arrays(
            times1=data_corr.index,
            times2=data_to_validate.index,
            max_time_diff=1,
        )

        # Create figure with a 2-row, 2-column grid, top row spans both columns
        fig = plt.figure(figsize=(12, 8))
        gs = fig.add_gridspec(2, 2, height_ratios=[2, 3], hspace=0.4, wspace=0.3)

        # Top row: time series plot (spans both columns)
        ax_ts = fig.add_subplot(gs[0, :])
        t = data_corr.index[data_times]
        ax_ts.plot(
            t,
            data_to_validate["Hs_VAL"].iloc[data_to_validate_times],
            label="Validation",
            color="k",
            lw=1.5,
        )
        ax_ts.plot(
            t,
            data_corr["Hs"].iloc[data_times],
            label="Model (No Correction)",
            color="tab:blue",
            alpha=0.7,
        )
        ax_ts.plot(
            t,
            data_corr["Hs_CORR"].iloc[data_times],
            label="Model (Corrected)",
            color="tab:orange",
            alpha=0.7,
        )
        ax_ts.set_ylabel("$H_s$ (m)")
        ax_ts.set_xlabel("Time")
        ax_ts.set_title("Time Series Comparison")
        ax_ts.legend(loc="upper right")
        ax_ts.grid(True, linestyle=":", alpha=0.5)

        # Bottom row: scatter plots
        ax_sc1 = fig.add_subplot(gs[1, 0])
        ax_sc2 = fig.add_subplot(gs[1, 1])
        validation_scatter(
            axs=ax_sc1,
            x=data_corr["Hs"].iloc[data_times].values,
            y=data_to_validate["Hs_VAL"].iloc[data_to_validate_times].values,
            xlabel="Model (No Correction)",
            ylabel="Validation",
            title="No Correction",
        )
        validation_scatter(
            axs=ax_sc2,
            x=data_corr["Hs_CORR"].iloc[data_times].values,
            y=data_to_validate["Hs_VAL"].iloc[data_to_validate_times].values,
            xlabel="Model (Corrected)",
            ylabel="Validation",
            title="With Correction",
        )

        return fig, [ax_ts, ax_sc1, ax_sc2]
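
# End-to-end usage sketch (editorial illustration, not part of the original
# source). File names and DataFrames are hypothetical; the required columns
# are those documented in fit(), correct() and validate_calibration():
#
#     hindcast_df = pd.read_parquet("hindcast_point.parquet")  # Hs, Hsea, Dirsea, Hswell1, Dirswell1, ...
#     sat_raw = pd.read_csv("imos_altimeter.csv", parse_dates=["time"], index_col="time")
#     sat_cal = process_imos_satellite_data(sat_raw, -45, -40, 145, 150)
#
#     calval = CalVal()
#     calval.fit(hindcast_df, data_longitude=146.0, data_latitude=-42.5,
#                data_to_calibrate=sat_cal, max_time_diff=2)
#     fig, axes = calval.plot_calibration_results()
#
#     buoy_df = pd.read_csv("buoy.csv", parse_dates=["time"], index_col="time")
#     buoy_df = buoy_df.rename(columns={"hs": "Hs_VAL"})
#     fig, axes = calval.validate_calibration(buoy_df)
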