Source code for bluemath_tk.core.plotting.base_plotting

from abc import ABC, abstractmethod
from typing import Tuple

import cartopy.crs as ccrs
import matplotlib.pyplot as plt
import plotly.graph_objects as go
import xarray as xr

from ...config.paths import PATHS
from .colors import hex_colors_land, hex_colors_water
from .satellite import get_satellite_image
from .utils import join_colormaps


class BasePlotting(ABC):
    """
    Abstract base class defining the default plotting interface used across the project.

    Concrete subclasses must provide :meth:`plot_line` and :meth:`plot_scatter`;
    the class cannot be instantiated until both are overridden.
    """

    def __init__(self):
        # No shared state to set up; present so subclasses can call super().__init__().
        return None

    @abstractmethod
    def plot_line(self, x, y):
        """Plot a line. Concrete subclasses must override this method."""
        ...

    @abstractmethod
    def plot_scatter(self, x, y):
        """Plot a scatter plot. Concrete subclasses must override this method."""
        ...
class DefaultStaticPlotting(BasePlotting):
    """
    Concrete implementation of BasePlotting with static (matplotlib) plotting behaviors.

    Default styling comes from the class-level ``templates`` dictionary. When an
    instance is created, every section of the selected template (``"line"``,
    ``"scatter"``, ``"bathymetry"``, ...) is stored on the instance as
    ``<section>_defaults`` (``line_defaults``, ``scatter_defaults``, ...).
    """

    # Class-level dictionary for default settings templates.
    templates = {
        "default": {
            "line": {
                "color": "blue",
                "line_style": "-",
            },
            "scatter": {
                "color": "red",
                "size": 10,
                "marker": "o",
            },
            "bathymetry": {
                "cmap": "albita_ocean",
            },
        }
    }

    def __init__(self, template: str = "default") -> None:
        """
        Initialize an instance of the DefaultStaticPlotting class.

        Parameters
        ----------
        template : str
            The template to use for the plotting settings. Default is "default".
            Names not present in ``templates`` fall back to the "default" template.

        Notes
        -----
        - Each section of the chosen template is stored as ``<section>_defaults``.
        - The section dictionaries are copied, so modifying them on one instance
          does not alter the class-level templates or any other instance.
        """
        super().__init__()
        selected = self.templates.get(template)
        if selected is None:
            # Fall back to the default template itself (a dict), not the string
            # "default", which has no ``.items()``.
            selected = self.templates["default"]
        for key, value in selected.items():
            # Shallow copy so per-instance tweaks don't leak into the shared class dict.
            setattr(self, f"{key}_defaults", dict(value))

    def get_subplots(self, **kwargs):
        """
        Create a figure and axes via ``plt.subplots``.

        Parameters
        ----------
        **kwargs
            Keyword arguments forwarded to ``plt.subplots``.

        Returns
        -------
        tuple
            The ``(fig, ax)`` pair returned by ``plt.subplots``.
        """
        fig, ax = plt.subplots(**kwargs)
        return fig, ax

    def get_subplot(self, figsize, **kwargs):
        """
        Create a figure of the given size and add a single subplot to it.

        Parameters
        ----------
        figsize
            Figure size forwarded to ``plt.figure``.
        **kwargs
            Keyword arguments forwarded to ``fig.add_subplot`` (e.g. ``projection``).

        Returns
        -------
        tuple
            The ``(fig, ax)`` pair.
        """
        fig = plt.figure(figsize=figsize)
        ax = fig.add_subplot(**kwargs)
        return fig, ax

    def plot_line(self, ax: plt.Axes, *args, **kwargs):
        """
        Plot a line on ``ax`` using the template's line defaults.

        Parameters
        ----------
        ax : plt.Axes
            The axes on which to plot the data.
        *args
            Positional data forwarded to ``ax.plot`` (e.g. ``x, y``).
        **kwargs
            Keyword arguments forwarded to ``ax.plot``. ``x`` and ``y`` may be given
            as keywords; they are passed to ``ax.plot`` positionally because
            matplotlib does not accept them as keywords. ``c``/``color`` and
            ``ls``/``linestyle`` override the template defaults.

        Raises
        ------
        ValueError
            If ``x`` is given as a keyword without ``y``.
        """
        x = kwargs.pop("x", None)
        y = kwargs.pop("y", None)
        if y is not None:
            args = ((y,) if x is None else (x, y)) + args
        elif x is not None:
            raise ValueError("plot_line received 'x' without 'y'.")
        # Only inject a default when neither alias was supplied: matplotlib rejects
        # calls that specify both 'c' and 'color' (or both 'ls' and 'linestyle').
        if "c" not in kwargs and "color" not in kwargs:
            kwargs["c"] = self.line_defaults.get("color")
        if "ls" not in kwargs and "linestyle" not in kwargs:
            kwargs["ls"] = self.line_defaults.get("line_style")
        ax.plot(*args, **kwargs)

    def plot_scatter(self, ax: plt.Axes, *args, **kwargs):
        """
        Draw a scatter plot on ``ax`` using the template's scatter defaults.

        Parameters
        ----------
        ax : plt.Axes
            The axes on which to plot the data.
        *args
            Positional data forwarded to ``ax.scatter`` (e.g. ``x, y``).
        **kwargs
            Keyword arguments forwarded to ``ax.scatter`` (``x``/``y`` may be given
            as keywords). ``c``/``color``, ``s`` and ``marker`` override the
            template defaults.
        """
        # matplotlib raises if both 'c' and 'color' are given, so only default 'c'
        # when the caller specified neither.
        if "c" not in kwargs and "color" not in kwargs:
            kwargs["c"] = self.scatter_defaults.get("color")
        kwargs.setdefault("s", self.scatter_defaults.get("size"))
        kwargs.setdefault("marker", self.scatter_defaults.get("marker"))
        ax.scatter(*args, **kwargs)

    def plot_bathymetry(
        self,
        ax: plt.Axes,
        source: str,
        area: Tuple[float, float, float, float],
        **kwargs,
    ) -> None:
        """
        Plot a bathymetry map from a bathymetry dataset stored in the PATHS dictionary.

        Parameters
        ----------
        ax: plt.Axes
            The axes on which to plot the data.
        source: str
            The source of the bathymetry data. Must be a key in the PATHS dictionary.
        area: Tuple[float, float, float, float]
            The area of the bathymetry data in the format
            (lon_min, lon_max, lat_min, lat_max).
        **kwargs
            Additional keyword arguments passed to the ``xr.DataArray.plot()``
            call on the ``elevation`` variable. ``cmap`` overrides the template
            default; the special value "albita_ocean" builds a joined water/land
            colormap split at 0.

        Raises
        ------
        ValueError
            If ``source`` is not a key in PATHS.
        """
        if source not in PATHS:
            raise ValueError(f"Source {source} not found in PATHS")

        # Load the requested subset into memory inside ``with`` so the underlying
        # file handle is released instead of leaking.
        with xr.open_dataset(PATHS[source]) as dataset:
            bathymetry_ds = (
                dataset.sel(lon=slice(area[0], area[1]), lat=slice(area[2], area[3]))
                .elevation.load()
            )

        cmap = kwargs.pop("cmap", self.bathymetry_defaults.get("cmap"))
        if cmap != "albita_ocean":
            bathymetry_ds.plot(ax=ax, cmap=cmap, **kwargs)
            return

        # Water colors cover [min, 0], land colors cover [0, max].
        cmap, norm = join_colormaps(
            cmap1=hex_colors_water,
            cmap2=hex_colors_land,
            value_range1=(bathymetry_ds.min(), 0.0),
            value_range2=(0.0, bathymetry_ds.max()),
        )
        p = bathymetry_ds.plot(ax=ax, cmap=cmap, norm=norm, **kwargs)
        # Hide minor ticks on colorbar
        if getattr(p, "colorbar", None) is not None:
            p.colorbar.minorticks_off()

    def plot_satellite(
        self,
        ax: plt.Axes,
        area: Tuple[float, float, float, float],
        source: str = "arcgis",
        **kwargs,
    ) -> None:
        """
        Download and display a satellite/raster map for the given bounding box.

        Parameters
        ----------
        ax: plt.Axes
            The axes on which to plot the data. Must support ``set_extent`` and the
            ``transform`` keyword (i.e. a cartopy GeoAxes).
        area: Tuple[float, float, float, float]
            The bounding box, passed to both ``get_satellite_image`` and
            ``ax.set_extent`` (which expects (x0, x1, y0, y1)).
        source: str
            The source of the satellite data. Default is "arcgis".
        **kwargs
            Additional keyword arguments passed to ``ax.imshow``.
        """
        map_img, extent = get_satellite_image(
            source=source,
            area=area,
        )
        ax.set_extent(area)
        ax.imshow(
            map_img,
            extent=extent,
            transform=ccrs.Mercator.GOOGLE,
            **kwargs,
        )
class DefaultInteractivePlotting(BasePlotting):
    """
    Concrete implementation of BasePlotting with interactive (plotly) plotting behaviors.

    Styling is controlled by the ``default_*`` class attributes below; override them
    on a subclass or an instance to customise the plots.
    """

    # Defaults read by the plotting methods. The colors mirror the static
    # "default" template ("blue" lines, "red" scatter markers).
    default_line_color = "blue"
    default_scatter_color = "red"
    # Map center as (lat, lon) and initial mapbox zoom level.
    default_map_center = (0.0, 0.0)
    default_map_zoom_start = 1

    def __init__(self):
        super().__init__()

    def plot_line(self, x, y):
        """
        Show an interactive line plot of ``y`` against ``x``.

        Parameters
        ----------
        x, y
            Data forwarded to ``go.Scatter``.
        """
        fig = go.Figure()
        fig.add_trace(
            go.Scatter(x=x, y=y, mode="lines", line=dict(color=self.default_line_color))
        )
        fig.update_layout(
            title="Interactive Line Plot", xaxis_title="X-axis", yaxis_title="Y-axis"
        )
        fig.show()

    def plot_scatter(self, x, y):
        """
        Show an interactive scatter plot of ``y`` against ``x``.

        Parameters
        ----------
        x, y
            Data forwarded to ``go.Scatter``.
        """
        fig = go.Figure()
        fig.add_trace(
            go.Scatter(
                x=x, y=y, mode="markers", marker=dict(color=self.default_scatter_color)
            )
        )
        fig.update_layout(
            title="Interactive Scatter Plot", xaxis_title="X-axis", yaxis_title="Y-axis"
        )
        fig.show()

    def plot_map(self, markers=None):
        """
        Show an interactive OpenStreetMap map with optional point markers.

        Parameters
        ----------
        markers : iterable of (lat, lon) pairs, optional
            Points to draw. Element ``[0]`` is used as latitude and ``[1]`` as
            longitude. ``None`` or empty draws no markers.
        """
        # Materialise once so one-shot iterables (e.g. generators) are not exhausted
        # by the latitude pass before the longitude pass.
        points = list(markers) if markers else []
        lats = [point[0] for point in points]
        lons = [point[1] for point in points]
        fig = go.Figure(
            go.Scattermapbox(
                lat=lats,
                lon=lons,
                mode="markers",
                marker=go.scattermapbox.Marker(
                    size=10, color=self.default_scatter_color
                ),
            )
        )
        fig.update_layout(
            mapbox=dict(
                style="open-street-map",
                center=dict(
                    lat=self.default_map_center[0], lon=self.default_map_center[1]
                ),
                zoom=self.default_map_zoom_start,
            ),
            title="Interactive Map with Plotly",
        )
        fig.show()