import importlib
import logging
import os
import pickle
import sys
from abc import ABC, abstractmethod
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor, as_completed
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import numpy as np
import pandas as pd
import xarray as xr
from scipy import constants
from sklearn.metrics import (
explained_variance_score,
mean_absolute_error,
mean_squared_error,
r2_score,
)
from sklearn.preprocessing import StandardScaler
from .logging import get_file_logger
from .operations import (
denormalize,
destandarize,
get_degrees_from_uv,
get_uv_components,
normalize,
standarize,
)
class BlueMathModel(ABC):
"""
Abstract base class for handling default functionalities across the project.
"""
gravity = constants.g
@abstractmethod
def __init__(self) -> None:
        self._logger: Optional[logging.Logger] = None
self._exclude_attributes: List[str] = []
self.num_workers: int = 1
# [UNDER DEVELOPMENT] Below, we try to generalise parallel processing
bluemath_num_workers = os.environ.get("BLUEMATH_NUM_WORKERS", None)
omp_num_threads = os.environ.get("OMP_NUM_THREADS", None)
if bluemath_num_workers is not None:
            self.logger.info(
                f"Setting self.num_workers to {bluemath_num_workers} because BLUEMATH_NUM_WORKERS is set. \n"
                "Change it using the self.set_num_processors_to_use method. \n"
                "Also setting OMP_NUM_THREADS to 1 to avoid conflicts with BlueMath parallel processing."
            )
self.set_num_processors_to_use(num_processors=int(bluemath_num_workers))
self.set_omp_num_threads(num_threads=1)
elif omp_num_threads is not None:
            self.logger.info(
                f"Changing OMP_NUM_THREADS from {omp_num_threads} to 1 and setting "
                f"self.num_workers to {omp_num_threads}, "
                "to avoid conflicts with BlueMath parallel processing."
            )
self.set_omp_num_threads(num_threads=1)
self.set_num_processors_to_use(num_processors=int(omp_num_threads))
else:
self.num_workers = 1 # self.get_num_processors_available()
self.logger.info(
f"Setting self.num_workers to {self.num_workers}. "
"Change it using self.set_num_processors_to_use method."
)
def __getstate__(self):
"""Exclude certain attributes from being pickled."""
state = self.__dict__.copy()
for attr in self._exclude_attributes:
if attr in state:
del state[attr]
# Iterate through the state attributes, warning about xr.Datasets
for key, value in state.items():
            if isinstance(value, (xr.Dataset, xr.DataArray)):
                self.logger.warning(
                    f"Attribute {key} is an xarray Dataset / DataArray and will be pickled!"
                )
return state
@property
def logger(self) -> logging.Logger:
if self._logger is None:
self._logger = get_file_logger(name=self.__class__.__name__)
return self._logger
@logger.setter
def logger(self, value: logging.Logger) -> None:
self._logger = value
def set_logger_name(
self, name: str, level: str = "INFO", console: bool = True
) -> None:
"""Sets the name of the logger."""
self.logger = get_file_logger(name=name, level=level, console=console)
    def save_model(
        self, model_path: str, exclude_attributes: Optional[List[str]] = None
    ) -> None:
"""Saves the model to a file."""
self.logger.info(f"Saving model to {model_path}")
if exclude_attributes is not None:
self._exclude_attributes += exclude_attributes
with open(model_path, "wb") as f:
pickle.dump(self, f)
def load_model(self, model_path: str) -> "BlueMathModel":
"""Loads the model from a file."""
self.logger.info(f"Loading model from {model_path}")
with open(model_path, "rb") as f:
model = pickle.load(f)
return model
def list_class_attributes(self) -> list:
"""
Lists the attributes of the class.
Returns
-------
list
The attributes of the class.
"""
return [
attr
for attr in dir(self)
if not callable(getattr(self, attr)) and not attr.startswith("__")
]
def list_class_methods(self) -> list:
"""
Lists the methods of the class.
Returns
-------
list
The methods of the class.
"""
return [
attr
for attr in dir(self)
if callable(getattr(self, attr)) and not attr.startswith("__")
]
def check_nans(
self,
data: Union[np.ndarray, pd.Series, pd.DataFrame, xr.DataArray, xr.Dataset],
        replace_value: Optional[Union[float, Callable]] = None,
raise_error: bool = False,
) -> Union[np.ndarray, pd.Series, pd.DataFrame, xr.DataArray, xr.Dataset]:
"""
Checks for NaNs in the data and optionally replaces them.
Parameters
----------
data : np.ndarray, pd.Series, pd.DataFrame, xr.DataArray or xr.Dataset
The data to check for NaNs.
replace_value : float or callable, optional
The value to replace NaNs with. If None, NaNs will not be replaced.
If a callable is provided, it will be called and the result will be returned.
Default is None.
raise_error : bool, optional
Whether to raise an error if NaNs are found. Default is False.
Returns
-------
data : np.ndarray, pd.Series, pd.DataFrame, xr.DataArray or xr.Dataset
The data with NaNs optionally replaced.
Raises
------
ValueError
If NaNs are found and raise_error is True.
Notes
-----
- This method is intended to be used in classes that inherit from the BlueMathModel class.
- The method checks for NaNs in the data and optionally replaces them with the specified value.
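        Examples
        --------
        A minimal sketch, assuming a trivial concrete subclass (the `_Demo` class below
        is purely illustrative and not part of the library):
        >>> import numpy as np
        >>> class _Demo(BlueMathModel):
        ...     def __init__(self) -> None:
        ...         super().__init__()
        >>> model = _Demo()
        >>> cleaned = model.check_nans(np.array([1.0, np.nan, 3.0]), replace_value=0.0)
        >>> print(cleaned)
        [1. 0. 3.]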
"""
        # If replace_value is a callable, call it on the data and return the result
if callable(replace_value):
self.logger.debug(f"Replace value is a callable. Calling {replace_value}.")
return replace_value(data)
if isinstance(data, np.ndarray):
if np.isnan(data).any():
if raise_error:
raise ValueError("Data contains NaNs.")
self.logger.warning("Data contains NaNs.")
if replace_value is not None:
data = np.nan_to_num(data, nan=replace_value)
self.logger.info(f"NaNs replaced with {replace_value}.")
        elif isinstance(data, (pd.Series, pd.DataFrame)):
if data.isnull().values.any():
if raise_error:
raise ValueError("Data contains NaNs.")
self.logger.warning("Data contains NaNs.")
if replace_value is not None:
data.fillna(replace_value, inplace=True)
self.logger.info(f"NaNs replaced with {replace_value}.")
        elif isinstance(data, (xr.DataArray, xr.Dataset)):
            # Reduce the NaN check to a plain bool (a Dataset must be checked per variable)
            if isinstance(data, xr.Dataset):
                contains_nans = any(
                    bool(data[var].isnull().any()) for var in data.data_vars
                )
            else:
                contains_nans = bool(data.isnull().any())
            if contains_nans:
                if raise_error:
                    raise ValueError("Data contains NaNs.")
                self.logger.warning("Data contains NaNs.")
                if replace_value is not None:
                    data = data.fillna(replace_value)
                    self.logger.info(f"NaNs replaced with {replace_value}.")
else:
self.logger.warning("Data type not supported for NaN check.")
return data
def normalize(
self, data: Union[pd.DataFrame, xr.Dataset], custom_scale_factor: dict = {}
) -> Tuple[Union[pd.DataFrame, xr.Dataset], dict]:
"""
Normalize data to 0-1 using min max scaler approach.
More info in bluemath_tk.core.operations.normalize.
Parameters
----------
data : pd.DataFrame or xr.Dataset
The data to normalize.
custom_scale_factor : dict, optional
Custom scale factors for normalization.
Returns
-------
normalized_data : pd.DataFrame or xr.Dataset
The normalized data.
scale_factor : dict
The scale factors used for normalization.
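        Examples
        --------
        A minimal sketch of a normalize / denormalize round trip, assuming a trivial
        concrete subclass (the `_Demo` class and the column names are illustrative only):
        >>> import pandas as pd
        >>> class _Demo(BlueMathModel):
        ...     def __init__(self) -> None:
        ...         super().__init__()
        >>> model = _Demo()
        >>> df = pd.DataFrame({"Hs": [1.0, 2.0, 3.0], "Tp": [5.0, 10.0, 15.0]})
        >>> normalized_df, scale_factor = model.normalize(data=df)
        >>> recovered_df = model.denormalize(
        ...     normalized_data=normalized_df, scale_factor=scale_factor
        ... )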
"""
normalized_data, scale_factor = normalize(
data=data, custom_scale_factor=custom_scale_factor, logger=self.logger
)
return normalized_data, scale_factor
def denormalize(
self, normalized_data: pd.DataFrame, scale_factor: dict
) -> pd.DataFrame:
"""
Denormalize data using provided scale_factor.
More info in bluemath_tk.core.operations.denormalize.
Parameters
----------
normalized_data : pd.DataFrame
The normalized data to denormalize.
scale_factor : dict
The scale factors used for denormalization.
Returns
-------
data : pd.DataFrame
The denormalized data.
"""
data = denormalize(normalized_data=normalized_data, scale_factor=scale_factor)
return data
def standarize(
self,
data: Union[np.ndarray, pd.DataFrame, xr.Dataset],
scaler: StandardScaler = None,
transform: bool = False,
) -> Tuple[Union[np.ndarray, pd.DataFrame, xr.Dataset], StandardScaler]:
"""
Standarize data using StandardScaler.
More info in bluemath_tk.core.operations.standarize.
Parameters
----------
data : np.ndarray, pd.DataFrame or xr.Dataset
Input data to be standarized.
scaler : StandardScaler, optional
Scaler object to use for standarization. Default is None.
transform : bool
Whether to just transform the data. Default to False.
Returns
-------
standarized_data : np.ndarray, pd.DataFrame or xr.Dataset
Standarized data.
scaler : StandardScaler
Scaler object used for standarization.
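        Examples
        --------
        A minimal sketch of a standarize / destandarize round trip, assuming a trivial
        concrete subclass (the `_Demo` class and the random data are illustrative only):
        >>> import numpy as np
        >>> class _Demo(BlueMathModel):
        ...     def __init__(self) -> None:
        ...         super().__init__()
        >>> model = _Demo()
        >>> data = np.random.default_rng(42).normal(size=(100, 3))
        >>> standarized_data, scaler = model.standarize(data=data)
        >>> recovered_data = model.destandarize(
        ...     standarized_data=standarized_data, scaler=scaler
        ... )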
"""
standarized_data, scaler = standarize(
data=data, scaler=scaler, transform=transform
)
return standarized_data, scaler
def destandarize(
self,
standarized_data: Union[np.ndarray, pd.DataFrame, xr.Dataset],
scaler: StandardScaler,
) -> Union[np.ndarray, pd.DataFrame, xr.Dataset]:
"""
Destandarize data using provided scaler.
More info in bluemath_tk.core.operations.destandarize.
Parameters
----------
standarized_data : np.ndarray, pd.DataFrame or xr.Dataset
Standarized data to be destandarized.
scaler : StandardScaler
Scaler object used for standarization.
Returns
-------
data : np.ndarray, pd.DataFrame or xr.Dataset
Destandarized data.
"""
data = destandarize(standarized_data=standarized_data, scaler=scaler)
return data
@staticmethod
def get_metrics(
data1: Union[pd.DataFrame, xr.Dataset],
data2: Union[pd.DataFrame, xr.Dataset],
) -> pd.DataFrame:
"""
Gets the metrics of the model.
Parameters
----------
data1 : pd.DataFrame or xr.Dataset
The first dataset.
data2 : pd.DataFrame or xr.Dataset
The second dataset.
Returns
-------
metrics : pd.DataFrame
The metrics of the model.
Raises
------
ValueError
            If the DataFrames have different shapes, or the Datasets have different
            dimensions or variables.
TypeError
If the inputs are not both DataFrames or both xarray Datasets.
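        Examples
        --------
        A minimal sketch with two toy DataFrames (the values are illustrative only):
        >>> import pandas as pd
        >>> observed = pd.DataFrame({"Hs": [1.0, 2.0, 3.0]})
        >>> predicted = pd.DataFrame({"Hs": [1.1, 1.9, 3.2]})
        >>> metrics = BlueMathModel.get_metrics(observed, predicted)
        >>> sorted(metrics.columns)
        ['explained_variance_score', 'mean_absolute_error', 'mean_squared_error', 'r2_score']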
"""
if isinstance(data1, pd.DataFrame) and isinstance(data2, pd.DataFrame):
if data1.shape != data2.shape:
raise ValueError("DataFrames must have the same shape")
variables = data1.columns
elif isinstance(data1, xr.Dataset) and isinstance(data2, xr.Dataset):
if sorted(list(data1.dims)) != sorted(list(data2.dims)) or sorted(
list(data1.data_vars)
) != sorted(list(data2.data_vars)):
                raise ValueError(
                    "Datasets must have the same dimensions and variables"
                )
variables = data1.data_vars
else:
raise TypeError(
"Inputs must be either both DataFrames or both xarray Datasets"
)
metrics = {}
for var in variables:
if isinstance(data1, pd.DataFrame):
y_true = data1[var]
y_pred = data2[var]
else:
y_true = data1[var].values.reshape(-1)
y_pred = data2[var].values.reshape(-1)
metrics[var] = {
"mean_squared_error": mean_squared_error(y_true, y_pred),
"r2_score": r2_score(y_true, y_pred),
"mean_absolute_error": mean_absolute_error(y_true, y_pred),
"explained_variance_score": explained_variance_score(y_true, y_pred),
}
return pd.DataFrame(metrics).T
@staticmethod
def get_uv_components(x_deg: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
"""
        Calculates the u and v components for the given directional data.
        The directional data is assumed to be in degrees,
        with 0° as the North direction and angles increasing clockwise.
                  0° N
                    |
                    |
        270° W <---------> 90° E
                    |
                    |
                 180° S
Parameters
----------
x_deg : np.ndarray
The directional data in degrees.
Returns
-------
Tuple[np.ndarray, np.ndarray]
The u and v components.
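        Examples
        --------
        A minimal sketch (the directions below are illustrative only):
        >>> import numpy as np
        >>> directions = np.array([0.0, 90.0, 180.0, 270.0])
        >>> xu, xv = BlueMathModel.get_uv_components(directions)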
"""
return get_uv_components(x_deg)
@staticmethod
def get_degrees_from_uv(xu: np.ndarray, xv: np.ndarray) -> np.ndarray:
"""
        Calculates the directional data in degrees from the u and v components.
        The returned angles are between 0° and 360°,
        with 0° as the North direction and angles increasing clockwise.
                    (u=0, v=1)
                         |
                         |
        (u=-1, v=0) <---------> (u=1, v=0)
                         |
                         |
                    (u=0, v=-1)
Parameters
----------
xu : np.ndarray
The u component.
xv : np.ndarray
The v component.
Returns
-------
np.ndarray
The degrees.
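        Examples
        --------
        A minimal sketch (the components below are illustrative only):
        >>> import numpy as np
        >>> xu = np.array([0.0, 1.0, 0.0, -1.0])
        >>> xv = np.array([1.0, 0.0, -1.0, 0.0])
        >>> degrees = BlueMathModel.get_degrees_from_uv(xu=xu, xv=xv)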
"""
return get_degrees_from_uv(xu, xv)
def set_omp_num_threads(self, num_threads: int) -> None:
"""
Sets the number of threads for OpenMP.
Parameters
----------
num_threads : int
The number of threads.
        Warnings
        --------
        - This method is under development.
"""
os.environ["OMP_NUM_THREADS"] = str(num_threads)
# Re-import numpy if it is already imported
if "numpy" in sys.modules:
importlib.reload(np)
def get_num_processors_available(self) -> int:
"""
Gets the number of processors available.
Returns
-------
int
The number of processors available.
TODO
----
- Check whether available processors are used or not.
"""
        # Reserve roughly 10% of the CPUs, but always report at least one
        return max(1, int(os.cpu_count() * 0.9))
def set_num_processors_to_use(self, num_processors: int) -> None:
"""
Sets the number of processors to use for parallel processing.
Parameters
----------
num_processors : int
The number of processors to use.
If -1, all available processors will be used.
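        Examples
        --------
        A minimal sketch, assuming a trivial concrete subclass (illustrative only):
        >>> class _Demo(BlueMathModel):
        ...     def __init__(self) -> None:
        ...         super().__init__()
        >>> model = _Demo()
        >>> model.set_num_processors_to_use(num_processors=-1)  # use all available processors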
"""
# Retrieve the number of processors available
num_processors_available = self.get_num_processors_available()
# Check if the number of processors requested is valid
if num_processors == -1:
num_processors = num_processors_available
elif num_processors <= 0:
raise ValueError("Number of processors must be greater than 0")
elif (num_processors_available - num_processors) < 2:
            self.logger.info(
                "Number of processors requested leaves fewer than 2 processors available"
            )
# Set the number of processors to use
self.num_workers = num_processors
def parallel_execute(
self,
func: Callable,
items: List[Any],
num_workers: int,
cpu_intensive: bool = False,
**kwargs,
) -> Dict[int, Any]:
"""
Execute a function in parallel using concurrent.futures.
Parameters
----------
func : Callable
Function to execute for each item.
items : List[Any]
List of items to process.
num_workers : int
Number of parallel workers.
cpu_intensive : bool, optional
Whether the function is CPU intensive. Default is False.
**kwargs : dict
Additional keyword arguments for func.
Returns
-------
Dict[int, Any]
Dictionary with the results of the function execution.
The keys are the indices of the items in the original list.
The values are the results of the function execution.
Warnings
--------
        - When using ThreadPoolExecutor, the function sometimes fails when reading from or
          writing to the same or different files. This might be related to Python's GIL
          (Global Interpreter Lock).
        - cpu_intensive=True does not work with non-picklable objects (under development).
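        Examples
        --------
        A minimal sketch, assuming a trivial concrete subclass (the `_Demo` class and the
        `square` function are illustrative only):
        >>> class _Demo(BlueMathModel):
        ...     def __init__(self) -> None:
        ...         super().__init__()
        >>> model = _Demo()
        >>> def square(x):
        ...     return x ** 2
        >>> results = model.parallel_execute(func=square, items=[1, 2, 3], num_workers=2)
        >>> sorted(results.values())
        [1, 4, 9]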
"""
results = {}
executor_class = ProcessPoolExecutor if cpu_intensive else ThreadPoolExecutor
self.logger.info(f"Using {executor_class.__name__} for parallel execution")
with executor_class(max_workers=num_workers) as executor:
future_to_item = {
executor.submit(func, *item, **kwargs)
if isinstance(item, tuple)
else executor.submit(func, item, **kwargs): i
for i, item in enumerate(items)
}
for future in as_completed(future_to_item):
i = future_to_item[future]
try:
result = future.result()
results[i] = result
except Exception as exc:
self.logger.error(f"Job for {i} generated an exception: {exc}")
return results