Source code for bluemath_tk.predictor.xwt

import logging
import warnings
from datetime import datetime, timedelta
from typing import Any, Dict

import cartopy.crs as ccrs
import cartopy.feature as cfeature
import matplotlib.gridspec as gridspec
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import xarray as xr
from matplotlib.axes import Axes
from matplotlib.collections import Collection

from ..core.dask import setup_dask_client
from ..core.decorators import validate_data_xwt
from ..core.models import BlueMathModel
from ..core.pipeline import BlueMathPipeline
from ..core.plotting.colors import get_cluster_colors, get_config_variables
from ..datamining.kma import KMA
from ..datamining.pca import PCA

warnings.filterwarnings("ignore")
config_variables = get_config_variables()


def get_dynamic_estela_predictor(
    data: xr.Dataset,
    estela: xr.Dataset,
    check_interpolation: bool = True,
    verbose: bool = False,
) -> xr.Dataset:
    """
    Transform an xarray dataset with longitude, latitude, and time dimensions into
    one where each (longitude, latitude) value at each time is replaced by its value
    at time - t, where the travel time t is given by the estela dataset.

    Parameters
    ----------
    data : xr.Dataset
        The input dataset with dimensions longitude, latitude, and time.
    estela : xr.Dataset
        The dataset containing the F values with dimensions longitude and latitude.
    check_interpolation : bool, optional
        Whether to check if the data is interpolated. Default is True.
    verbose : bool, optional
        Whether to print verbose output. Default is False.
        If False, Dask logs are suppressed.
        If True, Dask logs are shown.

    Returns
    -------
    xr.Dataset
        The transformed dataset.
    """

    if not verbose:
        # Suppress Dask logs
        logging.getLogger("distributed").setLevel(logging.ERROR)
        logging.getLogger("distributed.client").setLevel(logging.ERROR)
        logging.getLogger("distributed.scheduler").setLevel(logging.ERROR)
        logging.getLogger("distributed.worker").setLevel(logging.ERROR)
        logging.getLogger("distributed.nanny").setLevel(logging.ERROR)
        # Also suppress bokeh and tornado logs that Dask uses
        logging.getLogger("bokeh").setLevel(logging.ERROR)
        logging.getLogger("tornado").setLevel(logging.ERROR)

    # TODO: Add customization for dask client
    _dask_client = setup_dask_client(n_workers=4, memory_limit=0.25)

    if check_interpolation:
        if (
            "longitude" not in data.dims
            or "latitude" not in data.dims
            or "time" not in data.dims
        ):
            raise ValueError("Data must have longitude, latitude, and time dimensions.")
        if "longitude" not in estela.dims or "latitude" not in estela.dims:
            raise ValueError("Estela must have longitude and latitude dimensions.")
        estela = estela.interp_like(data)  # TODO: Check NaNs interpolation
    data = data.chunk({"time": 365}).where(estela.F >= 0.0, np.nan)
    estela_traveltimes = estela.where(estela.F >= 0, np.nan).traveltime.astype(int)
    estela_max_traveltime = estela_traveltimes.max().values
    for traveltime in range(estela_max_traveltime):
        data = data.where(estela_traveltimes != traveltime, data.shift(time=traveltime))

    return data.compute()

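# Usage sketch (illustrative only, not part of the original module): building the
# ESTELA-lagged predictor from pre-loaded datasets. The file names "slp.nc" and
# "estela.nc" are placeholders, not fixed by this module.
#
#     slp = xr.open_dataset("slp.nc")        # dims: time, latitude, longitude
#     estela = xr.open_dataset("estela.nc")  # must contain F and traveltime
#     slp_lagged = get_dynamic_estela_predictor(data=slp, estela=estela)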

def check_model_is_fitted(func):
    """
    Decorator to check if the model is fitted.
    """

    def wrapper(self, *args, **kwargs):
        if self.kma_bmus is None:
            raise XWTError("Fit the model before calling this property.")
        return func(self, *args, **kwargs)

    return wrapper


class XWTError(Exception):
    """Custom exception for XWT class."""

    def __init__(self, message="XWT error occurred."):
        self.message = message
        super().__init__(self.message)


class XWT(BlueMathModel, BlueMathPipeline):
    """
    Xly Weather Types (XWT) class.

    This class implements the XWT method to identify and classify weather patterns
    in a dataset. The XWT method is a combination of Principal Component Analysis
    (PCA) and K-means clustering (KMA).

    Attributes
    ----------
    steps : Dict[str, BlueMathModel]
        The steps of the XWT method.
    num_clusters : int
        The number of clusters.
    kma_bmus : pd.DataFrame
        The KMA best matching units (BMUs).
    """

    def __init__(self, steps: Dict[str, BlueMathModel]) -> None:
        """
        Initialize the XWT.

        Parameters
        ----------
        steps : Dict[str, BlueMathModel]
            The steps of the XWT method.
            The steps must include a PCA and a KMA model.
        """

        super().__init__()
        self.set_logger_name(name=self.__class__.__name__, level="INFO")

        # Save XWT attributes
        if steps:
            if (
                not all(isinstance(step, BlueMathModel) for step in steps.values())
                or "pca" not in steps
                or "kma" not in steps
            ):
                raise XWTError("The steps must include a PCA and a KMA model.")
        self.steps = steps
        self._data: xr.Dataset = None
        self.num_clusters: int = None
        self.kma_bmus: pd.DataFrame = None

        # Exclude attributes from being saved
        self._exclude_attributes = ["_data"]

    @property
    def data(self) -> xr.Dataset:
        return self._data

    @property
    @check_model_is_fitted
    def clusters_probs_df(self) -> pd.DataFrame:
        """
        Calculate the probabilities for each XWT.
        """

        # Calculate probabilities for each cluster
        clusters_probs = (
            self.kma_bmus["kma_bmus"].value_counts(normalize=True).sort_index()
        )

        return clusters_probs

    @property
    @check_model_is_fitted
    def clusters_monthly_probs_df(self) -> pd.DataFrame:
        """
        Calculate the monthly probabilities for each XWT.
        """

        # Calculate probabilities for each month
        monthly_probs = (
            self.kma_bmus.groupby(self.kma_bmus.index.month)["kma_bmus"]
            .value_counts(normalize=True)
            .unstack()
            .fillna(0)
        )

        return monthly_probs

    @property
    @check_model_is_fitted
    def clusters_seasonal_probs_df(self) -> pd.DataFrame:
        """
        Calculate the seasonal probabilities for each XWT.
        """

        # Calculate probabilities for each season
        # Define seasons: DJF (Dec, Jan, Feb), MAM (Mar, Apr, May),
        # JJA (Jun, Jul, Aug), SON (Sep, Oct, Nov)
        seasons = {
            "DJF": [12, 1, 2],
            "MAM": [3, 4, 5],
            "JJA": [6, 7, 8],
            "SON": [9, 10, 11],
        }

        # Add a 'season' column to the DataFrame
        kma_bmus_season = self.kma_bmus.copy()
        kma_bmus_season["season"] = kma_bmus_season.index.month.map(
            lambda x: next(season for season, months in seasons.items() if x in months)
        )

        # Calculate probabilities for each season
        seasonal_probs = (
            kma_bmus_season.groupby("season")["kma_bmus"]
            .value_counts(normalize=True)
            .unstack()
            .fillna(0)
        )

        return seasonal_probs

    @property
    @check_model_is_fitted
    def clusters_annual_probs_df(self) -> pd.DataFrame:
        """
        Calculate the annual probabilities for each XWT.
        """

        # Calculate probabilities for each year
        annual_probs = (
            self.kma_bmus.groupby(self.kma_bmus.index.year)["kma_bmus"]
            .value_counts(normalize=True)
            .unstack()
            .fillna(0)
        )

        return annual_probs

    @property
    @check_model_is_fitted
    def clusters_perpetual_year_probs_df(self) -> pd.DataFrame:
        """
        Calculate the perpetual year probabilities for each XWT.
        """

        # Calculate probabilities for each natural day in the year
        natural_day_probs = (
            self.kma_bmus.groupby(self.kma_bmus.index.dayofyear)["kma_bmus"]
            .value_counts(normalize=True)
            .unstack()
            .fillna(0)
        )
        # Map the day-of-year index to datetimes in a reference (perpetual) year
        natural_day_probs.index = [
            datetime(2000, 1, 1) + timedelta(days=i - 1)
            for i in natural_day_probs.index
        ]

        return natural_day_probs

    @property
    @check_model_is_fitted
    def get_conditioned_probabilities(self) -> pd.DataFrame:
        """
        Calculate conditional probabilities P(X_t = j | X_{t-1} = i).
        """

        # Convert to numpy array if not already
        data = self.kma_bmus.values.flatten()

        # Find unique values in the data
        unique_values = np.unique(data)

        # Create empty matrix for conditional probabilities
        cond_probs = np.zeros((self.num_clusters, self.num_clusters))

        # Count transitions
        for i in range(len(data) - 1):
            prev_idx = np.where(unique_values == data[i])[0][0]
            next_idx = np.where(unique_values == data[i + 1])[0][0]
            cond_probs[prev_idx, next_idx] += 1

        # Normalize to get probabilities
        row_sums = cond_probs.sum(axis=1, keepdims=True)
        # Avoid division by zero
        row_sums[row_sums == 0] = 1
        cond_probs = cond_probs / row_sums

        # Create DataFrame with labels
        df_cond_probs = pd.DataFrame(
            cond_probs,
            index=[f"Cluster {v}" for v in unique_values],
            columns=[f"Cluster {v}" for v in unique_values],
        )

        return df_cond_probs

    @validate_data_xwt
    def fit(
        self,
        data: xr.Dataset,
        fit_params: Dict[str, Dict[str, Any]] = {},
        variable_to_sort_bmus: str = None,
    ) -> None:
        """
        Fit the XWT model to the data.

        Parameters
        ----------
        data : xr.Dataset
            The data to fit the model to. Must be PCA formatted.
        fit_params : Dict[str, Dict[str, Any]], optional
            The fitting parameters for the PCA and KMA models.
            Default is {}.
        variable_to_sort_bmus : str, optional
            The variable to sort the BMUs. Default is None.

        Raises
        ------
        XWTError
            If the data is not PCA formatted.
        """

        # Make a copy of the data to avoid modifying the original dataset
        self._data = data.copy()

        pca: PCA = self.steps.get("pca")
        if pca.pcs is None:
            try:
                _pcs_ds = pca.fit_transform(
                    data=data,
                    **fit_params.get("pca", {}),
                )
            except Exception as e:
                raise XWTError(f"Error during PCA fitting: {e}")
        else:
            self.logger.info("PCA already fitted, skipping PCA fitting.")

        kma: KMA = self.steps.get("kma")
        self.num_clusters = kma.num_clusters
        # TODO: standardize PCs by first PC variance
        # pca.pcs_df / pca.pcs.stds.isel(n_component=0).values
        kma_bmus, _kma_bmus_df = kma.fit_predict(
            data=pca.pcs_df,
            **fit_params.get("kma", {}),
        )
        self.kma_bmus = kma_bmus + 1  # TODO: Check if this is necessary!!!

        # Re-sort kma clusters based on variable if specified
        if variable_to_sort_bmus:
            pca.pcs["kma_bmus"] = (("time"), self.kma_bmus["kma_bmus"].values)
            sorted_bmus = (
                pca.inverse_transform(
                    PCs=pca.pcs.groupby("kma_bmus")
                    .mean()
                    .rename({"kma_bmus": pca.pca_dim_for_rows})
                )
                .mean(dim=pca.coords_to_stack)
                .sortby(variable_to_sort_bmus)[f"{pca.pca_dim_for_rows}"]
                .values
            )
            sorted_bmus_mapping = dict(
                zip(sorted_bmus, range(1, self.num_clusters + 1))
            )
            self.kma_bmus.replace(sorted_bmus_mapping, inplace=True)

        # Add the KMA bmus to the PCs and data
        pca.pcs["kma_bmus"] = (("time"), self.kma_bmus["kma_bmus"].values)
        self.data["kma_bmus"] = (("time"), self.kma_bmus["kma_bmus"].values)

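    # Usage sketch (illustrative only; the PCA and KMA constructor arguments are
    # omitted because they are not defined in this module, and "msl" is a
    # hypothetical variable name):
    #
    #     xwt = XWT(steps={"pca": PCA(...), "kma": KMA(...)})
    #     xwt.fit(
    #         data=slp_lagged,                # e.g. the ESTELA-lagged predictor
    #         fit_params={"pca": {}, "kma": {}},
    #         variable_to_sort_bmus="msl",
    #     )
    #     print(xwt.clusters_probs_df)
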
    def plot_map_features(
        self, ax: Axes, land_color: str = cfeature.COLORS["land"]
    ) -> None:
        """
        Plot map features on an axis.

        Parameters
        ----------
        ax : Axes
            The axis to plot the map features on.
        land_color : str, optional
            The color of the land. Default is cfeature.COLORS["land"].
        """

        ax.add_feature(cfeature.COASTLINE)
        ax.add_feature(cfeature.BORDERS, linestyle=":")
        ax.add_feature(cfeature.LAND, edgecolor="black", color=land_color)
        ax.add_feature(cfeature.OCEAN, color="lightblue", alpha=0.5)

    def plot_xwts(
        self, var_to_plot: str, anomaly: bool = False, map_center: tuple = None
    ) -> Collection:
        """
        Plot the XWTs for a variable.

        Parameters
        ----------
        var_to_plot : str
            The variable to plot.
        anomaly : bool, optional
            Whether to plot the anomaly of the variable. Default is False.
        map_center : tuple, optional
            The center of the map. Default is None.

        Returns
        -------
        Collection
            The collection from the last plotted XWT panel, which is also used
            for the shared colorbar.
        """

        if anomaly:
            data_to_plot = self.data.groupby("kma_bmus").mean()[
                var_to_plot
            ] - self.data[var_to_plot].mean("time")
        else:
            data_to_plot = self.data.groupby("kma_bmus").mean()[var_to_plot]

        if self.num_clusters > 3:
            col_wrap = int(np.ceil(np.sqrt(self.num_clusters)))
        else:
            col_wrap = self.num_clusters

        # Get the configuration for the variable to plot if it exists
        var_to_plot_config = config_variables.get(var_to_plot, {})

        # Get the cluster colors for each XWT
        xwts_colors = get_cluster_colors(num_clusters=self.num_clusters)

        # Create figure with enough space at bottom for colorbar
        fig = plt.figure(figsize=(15, 16))
        gs = gridspec.GridSpec(
            col_wrap,
            col_wrap,
            wspace=0.05,
            hspace=0.05,
        )

        # Plot the XWTs for the variable
        vmin = var_to_plot_config.get("vmin", data_to_plot.min().values)
        vmax = var_to_plot_config.get("vmax", data_to_plot.max().values)
        for i, (bmus, xwt_color) in enumerate(
            zip(data_to_plot.kma_bmus.values, xwts_colors)
        ):
            row = i // col_wrap
            col = i % col_wrap
            if map_center:
                ax = fig.add_subplot(
                    gs[row, col], projection=ccrs.Orthographic(*map_center)
                )
                p = data_to_plot.sel(kma_bmus=bmus).plot(
                    ax=ax,
                    cmap=var_to_plot_config.get("cmap", "RdBu_r"),
                    add_colorbar=False,
                    transform=ccrs.PlateCarree(),
                    vmin=vmin,
                    vmax=vmax,
                )
                self.plot_map_features(ax=ax, land_color=xwt_color)
            else:
                ax = fig.add_subplot(gs[row, col])
                p = data_to_plot.sel(kma_bmus=bmus).plot(
                    ax=ax,
                    cmap=var_to_plot_config.get("cmap", "RdBu_r"),
                    add_colorbar=False,
                    vmin=vmin,
                    vmax=vmax,
                )
            for border in ["top", "bottom", "left", "right"]:
                ax.spines[border].set_color(xwt_color)
            ax.set_title("")
            ax.text(
                0.05,
                0.05,
                int(bmus),
                ha="left",
                va="bottom",
                fontsize=15,
                fontweight="bold",
                color="navy",
                transform=ax.transAxes,
            )

        # Add colorbar in little custom axes at the bottom
        cbar_ax = fig.add_axes([0.3, 0.05, 0.4, 0.02])
        _cb = fig.colorbar(
            p,
            cax=cbar_ax,
            orientation="horizontal",
            label=var_to_plot_config.get("label", var_to_plot),
        )

        return p

    def _axplot_wt_probs(
        self,
        ax: Axes,
        wt_probs: np.ndarray,
        ttl: str = "",
        vmin: float = 0.0,
        vmax: float = 0.1,
        cmap: str = "Blues",
        caxis: str = "black",
        plot_text: bool = False,
    ) -> Collection:
        """
        Axes plot WT cluster probabilities.

        Parameters
        ----------
        ax : Axes
            The axis to plot the WT cluster probabilities on.
        wt_probs : np.ndarray
            The WT cluster probabilities.
        ttl : str, optional
            The title of the plot. Default is "".
        vmin : float, optional
            The minimum value of the colorbar. Default is 0.0.
        vmax : float, optional
            The maximum value of the colorbar. Default is 0.1.
        cmap : str, optional
            The colormap to use. Default is "Blues".
        caxis : str, optional
            The color of the axis. Default is "black".
        plot_text : bool, optional
            Whether to plot the text in each cell. Default is False.

        Returns
        -------
        Collection
            The pcolor collection with the WT cluster probabilities.
        """

        # cluster probabilities plot
        pc = ax.pcolor(
            np.flipud(wt_probs),
            cmap=cmap,
            vmin=vmin,
            vmax=vmax,
            edgecolors="k",
        )

        # plot text in each cell
        if plot_text:
            for i in range(wt_probs.shape[0]):
                for j in range(wt_probs.shape[1]):
                    ax.text(
                        j + 0.5,
                        wt_probs.shape[0] - 0.5 - i,
                        f"{wt_probs[i, j]:.2f}",
                        ha="center",
                        va="center",
                        fontsize=6,
                        fontweight="bold",
                        color="black",
                    )

        # customize axes
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_title(ttl, {"fontsize": 10, "fontweight": "bold"})

        # axis color
        plt.setp(ax.spines.values(), color=caxis)
        plt.setp(
            [ax.get_xticklines(), ax.get_yticklines()],
            color=caxis,
        )

        # axis linewidth
        if caxis != "black":
            plt.setp(ax.spines.values(), linewidth=3)

        return pc

    def _axplot_wt_hist(self, ax: Axes, ttl: str = "") -> Axes:
        """
        Axes plot WT cluster count histogram.

        Parameters
        ----------
        ax : Axes
            The axis to plot the WT cluster count histogram on.
        ttl : str, optional
            The title of the plot. Default is "".

        Returns
        -------
        Axes
            The axis with the WT cluster count histogram.
        """

        # cluster count histogram
        ax.hist(
            self.kma_bmus.values.reshape(-1),
            bins=np.arange(1, self.num_clusters + 2),
            edgecolor="k",
        )

        # customize axes
        # ax.grid('y')
        ax.set_xticks(np.arange(1, self.num_clusters + 1) + 0.5)
        ax.set_xticklabels(np.arange(1, self.num_clusters + 1))
        ax.set_xlim([1, self.num_clusters + 1])
        ax.tick_params(axis="both", which="major", labelsize=6)
        ax.set_title(ttl, {"fontsize": 10, "fontweight": "bold"})

        return ax

    def plot_dwts_probs(
        self,
        vmax: float = 0.15,
        vmax_seasonality: float = 0.15,
        plot_text: bool = False,
    ) -> None:
        """
        Plot Daily Weather Types bmus probabilities.

        Parameters
        ----------
        vmax : float, optional
            The maximum value of the colorbar. Default is 0.15.
        vmax_seasonality : float, optional
            The maximum value of the colorbar for seasonality. Default is 0.15.
        plot_text : bool, optional
            Whether to plot the text in each cell. Default is False.

        Raises
        ------
        ValueError
            If the kma_bmus time sampling is not daily.
        """

        if (self.kma_bmus.index[-1] - self.kma_bmus.index[-2]) != timedelta(days=1):
            raise ValueError("The kma_bmus time sampling must be daily.")

        # Best rows cols combination
        if self.num_clusters > 3:
            n_rows = n_cols = int(np.ceil(np.sqrt(self.num_clusters)))
        else:
            n_cols = self.num_clusters
            n_rows = 1

        # figure
        fig = plt.figure(figsize=(15, 9))
        gs = gridspec.GridSpec(4, 7, wspace=0.10, hspace=0.25)

        # list all plots params
        l_months = [
            (1, "January", gs[1, 3]),
            (2, "February", gs[2, 3]),
            (3, "March", gs[0, 4]),
            (4, "April", gs[1, 4]),
            (5, "May", gs[2, 4]),
            (6, "June", gs[0, 5]),
            (7, "July", gs[1, 5]),
            (8, "August", gs[2, 5]),
            (9, "September", gs[0, 6]),
            (10, "October", gs[1, 6]),
            (11, "November", gs[2, 6]),
            (12, "December", gs[0, 3]),
        ]
        l_3months = [
            ([12, 1, 2], "DJF", gs[3, 3]),
            ([3, 4, 5], "MAM", gs[3, 4]),
            ([6, 7, 8], "JJA", gs[3, 5]),
            ([9, 10, 11], "SON", gs[3, 6]),
        ]

        # plot total probabilities
        c_T = self.clusters_probs_df.values
        C_T = c_T.reshape(n_rows, n_cols)
        ax_probs_T = plt.subplot(gs[:2, :2])
        pc = self._axplot_wt_probs(
            ax_probs_T, C_T, ttl="DWT Probabilities", plot_text=plot_text
        )

        # plot counts histogram
        ax_hist = plt.subplot(gs[2:, :3])
        _ax_hist = self._axplot_wt_hist(ax_hist, ttl="DWT Counts")

        # plot probabilities by month
        for m_ix, m_name, m_gs in l_months:
            try:
                c_M = self.clusters_monthly_probs_df.loc[m_ix, :].values
                C_M = c_M.reshape(n_rows, n_cols)
                ax_M = plt.subplot(m_gs)
                self._axplot_wt_probs(
                    ax_M, C_M, ttl=m_name, vmax=vmax, plot_text=plot_text
                )
            except Exception as e:
                self.logger.error(e)

        # plot probabilities by 3 month sets
        for m_ix, m_name, m_gs in l_3months:
            try:
                c_M = self.clusters_seasonal_probs_df.loc[m_name, :].values
                C_M = c_M.reshape(n_rows, n_cols)
                ax_M = plt.subplot(m_gs)
                self._axplot_wt_probs(
                    ax_M,
                    C_M,
                    ttl=m_name,
                    vmax=vmax_seasonality,
                    cmap="Greens",
                    plot_text=plot_text,
                )
            except Exception as e:
                self.logger.error(e)

        # add custom colorbar
        pp = ax_probs_T.get_position()
        cbar_ax = fig.add_axes([pp.x1 + 0.02, pp.y0, 0.02, pp.y1 - pp.y0])
        cb = fig.colorbar(pc, cax=cbar_ax, cmap="Blues")
        cb.ax.tick_params(labelsize=8)

    def plot_perpetual_year(self) -> Axes:
        """
        Plot perpetual year bmus probabilities.

        Returns
        -------
        Axes
            The plot with the perpetual year bmus probabilities.
        """

        # Get cluster colors for stacked bar plot
        cluster_colors = get_cluster_colors(self.num_clusters)
        cluster_colors_list = [
            tuple(cluster_colors[cluster, :]) for cluster in range(self.num_clusters)
        ]

        # Plot perpetual year bmus
        fig, ax = plt.subplots(1, figsize=(15, 5))
        clusters_perpetual_year_probs_df = self.clusters_perpetual_year_probs_df
        clusters_perpetual_year_probs_df.plot.area(
            ax=ax,
            stacked=True,
            color=cluster_colors_list,
            legend=False,
        )
        ax.set_ylim(0, 1)

        return ax
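
# Plotting sketch (illustrative only, continuing the hypothetical fitted `xwt`
# object from the earlier sketch; the variable name "msl" and the map_center
# coordinates are placeholders):
#
#     xwt.plot_xwts(var_to_plot="msl", anomaly=True, map_center=(-4.0, 43.5))
#     xwt.plot_dwts_probs(vmax=0.15, plot_text=True)  # requires daily BMUs
#     xwt.plot_perpetual_year()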