# Source code for bluemath_tk.deeplearning.regularizers

"""
Project: BlueMath_tk
Sub-Module: deeplearning.regularizers
Author: GeoOcean Research Group, Universidad de Cantabria
Repository: https://github.com/GeoOcean/BlueMath_tk.git
Status: Under development (Working)

Regularization functions for PyTorch models.
"""

import torch


def orthogonal_regularizer(W: torch.Tensor, strength: float = 1e-3) -> torch.Tensor:
    """
    Weight orthogonality regularizer.

    Encourages the rows of the weight matrix ``W`` to be orthonormal by
    penalizing deviations of the row Gram matrix ``W @ W.T`` from the
    identity matrix::

        penalty = strength * || W @ W.T - I_k ||_F ** 2,   k = W.shape[0]

    Note that the Gram matrix is (out_features, out_features). When
    ``out_features > in_features`` its rank is at most ``in_features`` and the
    penalty cannot reach zero.

    Parameters
    ----------
    W : torch.Tensor
        2-D weight matrix, shape (out_features, in_features).
    strength : float, optional
        Strength of the orthogonality penalty, by default 1e-3.

    Returns
    -------
    torch.Tensor
        Scalar penalty value (same dtype and device as ``W``).

    Raises
    ------
    ValueError
        If ``W`` is not a 2-D tensor.

    Examples
    --------
    >>> import torch
    >>> from bluemath_tk.deeplearning.regularizers import orthogonal_regularizer
    >>> W = torch.randn(20, 128)
    >>> penalty = orthogonal_regularizer(W, strength=1e-3)
    """
    if W.dim() != 2:
        raise ValueError(
            "orthogonal_regularizer expects a 2-D weight matrix, "
            f"got shape {tuple(W.shape)}"
        )
    # Row Gram matrix: entry (i, j) is the dot product of rows i and j of W.
    gram = torch.matmul(W, W.t())  # (k, k) with k = out_features
    identity = torch.eye(gram.size(0), device=gram.device, dtype=gram.dtype)
    return strength * torch.sum((gram - identity) ** 2)
def l2_regularizer(parameters, strength: float = 1e-4) -> torch.Tensor:
    """
    L2 regularization (weight decay).

    Computes ``strength * sum(p ** 2)`` over every element of every tensor in
    ``parameters``. The gradient with respect to each parameter is
    ``2 * strength * p``.

    Parameters
    ----------
    parameters : iterable of torch.Tensor
        Model parameters to regularize. The iterable is consumed once, so a
        generator such as ``model.parameters()`` is fine. Every tensor given is
        penalized, biases included if they are passed.
    strength : float, optional
        Strength of the L2 penalty, by default 1e-4.

    Returns
    -------
    torch.Tensor
        Scalar penalty value. A zero tensor is returned when ``parameters`` is
        empty.

    Examples
    --------
    >>> import torch.nn as nn
    >>> from bluemath_tk.deeplearning.regularizers import l2_regularizer
    >>> model = nn.Linear(10, 5)
    >>> penalty = l2_regularizer(model.parameters(), strength=1e-4)
    """
    l2_loss = 0.0
    for param in parameters:
        l2_loss += torch.sum(param**2)
    if not torch.is_tensor(l2_loss):
        # No parameters were supplied: keep the documented Tensor return type
        # instead of leaking a Python float (which has no .backward()/.item()).
        l2_loss = torch.tensor(l2_loss)
    return strength * l2_loss
def l1_regularizer(parameters, strength: float = 1e-4) -> torch.Tensor:
    """
    L1 regularization (sparsity).

    Computes ``strength * sum(|p|)`` over every element of every tensor in
    ``parameters``.

    Parameters
    ----------
    parameters : iterable of torch.Tensor
        Model parameters to regularize. The iterable is consumed once, so a
        generator such as ``model.parameters()`` is fine. Every tensor given is
        penalized, biases included if they are passed.
    strength : float, optional
        Strength of the L1 penalty, by default 1e-4.

    Returns
    -------
    torch.Tensor
        Scalar penalty value. A zero tensor is returned when ``parameters`` is
        empty.

    Examples
    --------
    >>> import torch.nn as nn
    >>> from bluemath_tk.deeplearning.regularizers import l1_regularizer
    >>> model = nn.Linear(10, 5)
    >>> penalty = l1_regularizer(model.parameters(), strength=1e-4)
    """
    l1_loss = 0.0
    for param in parameters:
        l1_loss += torch.sum(torch.abs(param))
    if not torch.is_tensor(l1_loss):
        # No parameters were supplied: keep the documented Tensor return type
        # instead of leaking a Python float (which has no .backward()/.item()).
        l1_loss = torch.tensor(l1_loss)
    return strength * l1_loss