"""
Project: BlueMath_tk
Sub-Module: deeplearning.layers
Author: GeoOcean Research Group, Universidad de Cantabria
Repository: https://github.com/GeoOcean/BlueMath_tk.git
Status: Under development (Working)
Custom PyTorch layers for deep learning models.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
class DoubleConv(nn.Module):
"""
Double convolution block: (convolution => [BN] => activation) * 2
Parameters
----------
in_channels : int
Number of input channels.
out_channels : int
Number of output channels.
mid_channels : int, optional
Number of intermediate channels. If None, uses out_channels.
activation : nn.Module, optional
Activation function. Default is SiLU.
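    Examples
    --------
    A minimal shape sketch (the input sizes are arbitrary):

    >>> import torch
    >>> block = DoubleConv(3, 16)
    >>> block(torch.randn(2, 3, 32, 32)).shape
    torch.Size([2, 16, 32, 32])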
"""
def __init__(self, in_channels, out_channels, mid_channels=None, activation=None):
super().__init__()
if mid_channels is None:
mid_channels = out_channels
        if activation is None:
            activation = nn.SiLU(inplace=True)
        # The same activation instance is reused in both positions of the
        # Sequential below; this is safe for stateless activations like SiLU.
self.double_conv = nn.Sequential(
nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1),
nn.BatchNorm2d(mid_channels),
activation,
nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1),
nn.BatchNorm2d(out_channels),
activation,
)
def forward(self, x):
return self.double_conv(x)
class DoubleConv3D(nn.Module):
"""
Double 3D convolution block: (convolution => [BN] => activation) * 2
Parameters
----------
in_channels : int
Number of input channels.
out_channels : int
Number of output channels.
mid_channels : int, optional
Number of intermediate channels. If None, uses out_channels.
activation : nn.Module, optional
Activation function. Default is SiLU.
"""
def __init__(self, in_channels, out_channels, mid_channels=None, activation=None):
super().__init__()
if mid_channels is None:
mid_channels = out_channels
if activation is None:
activation = nn.SiLU(inplace=True)
self.double_conv = nn.Sequential(
nn.Conv3d(in_channels, mid_channels, kernel_size=3, padding=1),
nn.BatchNorm3d(mid_channels),
activation,
nn.Conv3d(mid_channels, out_channels, kernel_size=3, padding=1),
nn.BatchNorm3d(out_channels),
activation,
)
def forward(self, x):
return self.double_conv(x)
class TripleConv(nn.Module):
"""
Triple convolution with separable spatial convolutions.
Uses (1, kernel_size) and (kernel_size, 1) convolutions then combines.
Parameters
----------
in_channels : int
Number of input channels.
out_channels : int
Number of output channels.
mid_channels : int, optional
Number of intermediate channels. If None, uses out_channels.
kernel_size : int, optional
Kernel size for separable convolutions. Must be 3, 5, or 7. Default is 7.
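    Examples
    --------
    A minimal shape sketch (arbitrary sizes); the two separable branches are
    concatenated before the 3x3 output convolution:

    >>> import torch
    >>> block = TripleConv(8, 32, kernel_size=5)
    >>> block(torch.randn(1, 8, 64, 64)).shape
    torch.Size([1, 32, 64, 64])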
"""
def __init__(self, in_channels, out_channels, mid_channels=None, kernel_size=7):
super().__init__()
if mid_channels is None:
mid_channels = out_channels
assert kernel_size in (3, 5, 7), "kernel size must be 3, 5, or 7"
padding = kernel_size // 2
self.conv1 = nn.Sequential(
nn.Conv2d(
in_channels,
mid_channels,
kernel_size=(1, kernel_size),
padding=(0, padding),
),
nn.BatchNorm2d(mid_channels),
nn.SiLU(inplace=True),
)
self.conv2 = nn.Sequential(
nn.Conv2d(
in_channels,
mid_channels,
kernel_size=(kernel_size, 1),
padding=(padding, 0),
),
nn.BatchNorm2d(mid_channels),
nn.SiLU(inplace=True),
)
self.conv_out = nn.Sequential(
nn.Conv2d(mid_channels * 2, out_channels, kernel_size=3, padding=1),
nn.BatchNorm2d(out_channels),
nn.SiLU(inplace=True),
)
def forward(self, x):
x1 = self.conv1(x)
x2 = self.conv2(x)
x = self.conv_out(torch.cat([x1, x2], dim=1))
return x
class Down(nn.Module):
"""
Downscaling with maxpool then double conv.
Parameters
----------
in_channels : int
Number of input channels.
out_channels : int
Number of output channels.
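    Examples
    --------
    The max-pool halves the spatial dimensions (a sketch):

    >>> import torch
    >>> down = Down(16, 32)
    >>> down(torch.randn(2, 16, 64, 64)).shape
    torch.Size([2, 32, 32, 32])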
"""
def __init__(self, in_channels, out_channels):
super().__init__()
self.maxpool_conv = nn.Sequential(
nn.MaxPool2d(2), DoubleConv(in_channels, out_channels)
)
def forward(self, x):
return self.maxpool_conv(x)
class Down3D(nn.Module):
"""
Downscaling with maxpool then double 3D conv.
Parameters
----------
in_channels : int
Number of input channels.
out_channels : int
Number of output channels.
"""
def __init__(self, in_channels, out_channels):
super().__init__()
self.maxpool_conv = nn.Sequential(
nn.MaxPool3d(2), DoubleConv3D(in_channels, out_channels)
)
def forward(self, x):
return self.maxpool_conv(x)
class Up(nn.Module):
"""
Upscaling then double conv.
Parameters
----------
in_channels : int
Number of input channels.
out_channels : int
Number of output channels.
bilinear : bool, optional
Whether to use bilinear upsampling. Default is True.
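    Examples
    --------
    A decoder step (a sketch): ``x1`` is the deeper feature map to upsample
    (with ``in_channels // 2`` channels in bilinear mode), ``x2`` the skip
    connection it is padded to and concatenated with:

    >>> import torch
    >>> up = Up(64, 32, bilinear=True)
    >>> x1 = torch.randn(2, 32, 16, 16)  # deeper features
    >>> x2 = torch.randn(2, 32, 32, 32)  # skip connection
    >>> up(x1, x2).shape
    torch.Size([2, 32, 32, 32])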
"""
def __init__(self, in_channels, out_channels, bilinear=True):
super().__init__()
if bilinear:
self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
self.conv = DoubleConv(
in_channels, out_channels, mid_channels=in_channels // 2
)
else:
self.up = nn.ConvTranspose2d(
in_channels, in_channels // 2, kernel_size=2, stride=2
)
self.conv = DoubleConv(in_channels, out_channels)
def forward(self, x1, x2):
x1 = self.up(x1)
        # tensors are (B, C, H, W); pad x1 so its spatial size matches x2
diffY = x2.size()[2] - x1.size()[2]
diffX = x2.size()[3] - x1.size()[3]
x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2, diffY // 2, diffY - diffY // 2])
x = torch.cat([x2, x1], dim=1)
return self.conv(x)
class Up3D(nn.Module):
"""
Upscaling then double 3D conv.
Parameters
----------
in_channels : int
Number of input channels.
out_channels : int
Number of output channels.
trilinear : bool, optional
Whether to use trilinear upsampling. Default is True.
"""
def __init__(self, in_channels, out_channels, trilinear=True):
super().__init__()
if trilinear:
self.up = nn.Upsample(scale_factor=2, mode="trilinear", align_corners=True)
self.conv = DoubleConv3D(
in_channels, out_channels, mid_channels=in_channels // 2
)
else:
self.up = nn.ConvTranspose3d(
in_channels, in_channels // 2, kernel_size=2, stride=2
)
self.conv = DoubleConv3D(in_channels, out_channels)
def forward(self, x1, x2):
x1 = self.up(x1)
        # tensors are (B, C, D, H, W); pad x1 so its size matches x2
diffZ = x2.size()[2] - x1.size()[2]
diffY = x2.size()[3] - x1.size()[3]
diffX = x2.size()[4] - x1.size()[4]
x1 = F.pad(
x1,
[
diffX // 2,
diffX - diffX // 2,
diffY // 2,
diffY - diffY // 2,
diffZ // 2,
diffZ - diffZ // 2,
],
)
x = torch.cat([x2, x1], dim=1)
return self.conv(x)
class OutConv(nn.Module):
"""
Output convolution layer.
Parameters
----------
in_channels : int
Number of input channels.
out_channels : int
Number of output channels.
"""
def __init__(self, in_channels, out_channels):
super(OutConv, self).__init__()
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)
def forward(self, x):
return self.conv(x)
class OutConv3D(nn.Module):
"""
Output 3D convolution layer.
Parameters
----------
in_channels : int
Number of input channels.
out_channels : int
Number of output channels.
"""
def __init__(self, in_channels, out_channels):
super(OutConv3D, self).__init__()
self.conv = nn.Conv3d(in_channels, out_channels, kernel_size=1)
def forward(self, x):
return self.conv(x)
class LatentDecorr(nn.Module):
"""
Latent pass-through layer that adds covariance decorrelation loss.
This layer encourages the latent representations to be decorrelated
by penalizing off-diagonal elements of the covariance matrix.
Parameters
----------
strength : float, optional
Strength of the decorrelation penalty, by default 1e-2.
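    Examples
    --------
    The layer is the identity on ``z``; the scalar penalty is stashed in
    ``_loss`` for the training loop to add to the task loss (a sketch):

    >>> import torch
    >>> layer = LatentDecorr(strength=1e-2)
    >>> z = torch.randn(8, 4)
    >>> out = layer(z)
    >>> torch.equal(out, z), layer._loss.ndim
    (True, 0)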
"""
def __init__(self, strength: float = 1e-2):
super().__init__()
self.strength = strength
def forward(self, z):
"""
Forward pass with decorrelation loss.
Parameters
----------
z : torch.Tensor
Latent representations, shape (batch, k).
Returns
-------
torch.Tensor
Unchanged latent representations.
"""
# z: (batch, k)
zc = z - z.mean(dim=0, keepdim=True) # center
B = zc.size(0)
cov = torch.matmul(zc.t(), zc) / (B - 1.0) # (k, k)
diag_mask = torch.eye(cov.size(0), device=cov.device, dtype=cov.dtype)
offdiag = cov * (1 - diag_mask) # zero diag
loss = self.strength * torch.sum(offdiag**2)
# Add loss to computation graph
z = z + 0 * loss # Trick to add loss to graph without changing z
# Store current loss for retrieval during training
self._loss = loss
return z
class PositionalEmbedding(nn.Module):
"""
Learnable positional embedding layer for transformer models.
Parameters
----------
n_tokens : int
Number of tokens/patches.
d_model : int
Model dimension.
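    Examples
    --------
    A minimal sketch; the embedding is zero-initialized and learned:

    >>> import torch
    >>> emb = PositionalEmbedding(n_tokens=16, d_model=48)
    >>> emb(torch.randn(2, 16, 48)).shape
    torch.Size([2, 16, 48])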
"""
def __init__(self, n_tokens: int, d_model: int):
super().__init__()
self.pos = nn.Parameter(torch.zeros(1, n_tokens, d_model))
def forward(self, x):
"""
Forward pass.
Parameters
----------
x : torch.Tensor
Input tokens, shape (B, N, D).
Returns
-------
torch.Tensor
Tokens with positional embeddings added.
"""
return x + self.pos
class Patchify(nn.Module):
"""
Patchify layer that splits images into patches.
Parameters
----------
patch_size : int
Size of each patch (patch_size x patch_size).
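    Examples
    --------
    Token count and dimension (a sketch; H and W must be multiples of p):

    >>> import torch
    >>> patchify = Patchify(patch_size=4)
    >>> tokens = patchify(torch.randn(2, 3, 16, 16))
    >>> tokens.shape  # N = (16//4) * (16//4) = 16 tokens of 3*4*4 = 48
    torch.Size([2, 16, 48])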
"""
def __init__(self, patch_size: int):
super().__init__()
self.p = patch_size
def forward(self, x):
"""
Forward pass.
Parameters
----------
x : torch.Tensor
Input images, shape (B, C, H, W), where H and W are multiples of p.
Returns
-------
torch.Tensor
Patches, shape (B, N, p*p*C).
"""
B, C, H, W = x.shape
p = self.p
Hp, Wp = H // p, W // p
# Unfold into patches
x = x.unfold(2, p, p).unfold(3, p, p) # (B, C, Hp, Wp, p, p)
x = x.contiguous().view(B, C, Hp, Wp, p * p)
x = x.permute(0, 2, 3, 1, 4).contiguous() # (B, Hp, Wp, C, p*p)
x = x.view(B, Hp * Wp, C * p * p) # (B, N, p*p*C)
return x
class Unpatchify(nn.Module):
"""
Unpatchify layer that reconstructs images from patches.
Parameters
----------
patch_size : int
Size of each patch.
Hp : int
Number of patches in height dimension.
Wp : int
Number of patches in width dimension.
C : int
Number of channels.
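    Examples
    --------
    Exact inverse of Patchify for matching parameters (a sketch):

    >>> import torch
    >>> x = torch.randn(2, 3, 16, 16)
    >>> tokens = Patchify(patch_size=4)(x)
    >>> torch.equal(Unpatchify(4, Hp=4, Wp=4, C=3)(tokens), x)
    True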
"""
def __init__(self, patch_size: int, Hp: int, Wp: int, C: int):
super().__init__()
self.p = patch_size
self.Hp = Hp
self.Wp = Wp
self.C = C
def forward(self, tokens):
"""
Forward pass.
Parameters
----------
tokens : torch.Tensor
Patches, shape (B, N=Hp*Wp, p*p*C).
Returns
-------
torch.Tensor
Reconstructed images, shape (B, C, H, W).
"""
p, Hp, Wp, C = self.p, self.Hp, self.Wp, self.C
B = tokens.size(0)
x = tokens.view(B, Hp, Wp, C, p, p)
x = x.permute(0, 3, 1, 4, 2, 5).contiguous() # (B, C, Hp, p, Wp, p)
x = x.view(B, C, Hp * p, Wp * p) # (B, C, H, W)
return x
class TimePositionalEncoding(nn.Module):
"""
Sinusoidal positional encoding for temporal sequences.
Adds 1D sinusoidal time positions to per-timestep embeddings.
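    Examples
    --------
    The encoding is built on the fly from the input shape (a sketch):

    >>> import torch
    >>> pe = TimePositionalEncoding()
    >>> pe(torch.zeros(2, 10, 16)).shape
    torch.Size([2, 10, 16])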
"""
def forward(self, x):
"""
Forward pass.
Parameters
----------
x : torch.Tensor
Input sequences, shape (B, L, D).
Returns
-------
torch.Tensor
Sequences with positional encodings added.
"""
# x: (B, L, D)
B, L, D = x.shape
device = x.device
pos = torch.arange(L, device=device, dtype=torch.float32).unsqueeze(1) # (L, 1)
i = torch.arange(D, device=device, dtype=torch.float32).unsqueeze(0) # (1, D)
angles = pos / (10000.0 ** (2 * (i // 2) / D))
        # build the table in the input's dtype to avoid upcasting under AMP
        pe = torch.zeros(L, D, device=device, dtype=x.dtype)
pe[:, 0::2] = torch.sin(angles[:, 0::2])
pe[:, 1::2] = torch.cos(angles[:, 1::2])
pe = pe.unsqueeze(0) # (1, L, D)
return x + pe
class ConvLSTMCell(nn.Module):
"""
ConvLSTM Cell implementation.
Parameters
----------
input_dim : int
Number of channels of input tensor.
hidden_dim : int
Number of channels of hidden state.
kernel_size : int or tuple, optional
Size of the convolutional kernel. Default is 3.
bias : bool, optional
Whether to add bias. Default is True.
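    Examples
    --------
    One recurrent step with a zero-initialized state (a sketch):

    >>> import torch
    >>> cell = ConvLSTMCell(input_dim=3, hidden_dim=8)
    >>> h, c = cell.init_hidden(batch_size=2, image_size=(16, 16))
    >>> h, c = cell(torch.randn(2, 3, 16, 16), (h, c))
    >>> h.shape
    torch.Size([2, 8, 16, 16])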
"""
def __init__(self, input_dim, hidden_dim, kernel_size=3, bias=True):
super().__init__()
self.input_dim = input_dim
self.hidden_dim = hidden_dim
self.kernel_size = (
kernel_size
if isinstance(kernel_size, tuple)
else (kernel_size, kernel_size)
)
self.padding = (self.kernel_size[0] // 2, self.kernel_size[1] // 2)
self.bias = bias
self.conv = nn.Conv2d(
in_channels=self.input_dim + self.hidden_dim,
out_channels=4 * self.hidden_dim,
kernel_size=self.kernel_size,
padding=self.padding,
bias=self.bias,
)
def forward(self, input_tensor, cur_state):
h_cur, c_cur = cur_state
combined = torch.cat(
[input_tensor, h_cur], dim=1
) # concatenate along channel axis
combined_conv = self.conv(combined)
cc_i, cc_f, cc_o, cc_g = torch.split(combined_conv, self.hidden_dim, dim=1)
i = torch.sigmoid(cc_i)
f = torch.sigmoid(cc_f)
o = torch.sigmoid(cc_o)
g = torch.tanh(cc_g)
c_next = f * c_cur + i * g
h_next = o * torch.tanh(c_next)
return h_next, c_next
def init_hidden(self, batch_size, image_size):
height, width = image_size
return (
torch.zeros(
batch_size,
self.hidden_dim,
height,
width,
device=self.conv.weight.device,
),
torch.zeros(
batch_size,
self.hidden_dim,
height,
width,
device=self.conv.weight.device,
),
)
class ConvLSTM(nn.Module):
"""
ConvLSTM module.
Parameters
----------
input_dim : int
Number of channels of input tensor.
hidden_dim : int or list
Number of channels of hidden state(s).
kernel_size : int or tuple, optional
Size of the convolutional kernel. Default is 3.
num_layers : int, optional
Number of ConvLSTM layers. Default is 1.
batch_first : bool, optional
If True, input and output tensors are provided as (batch, seq, channel, height, width).
Default is False.
bias : bool, optional
Whether to add bias. Default is True.
return_all_layers : bool, optional
If True, returns all layers' outputs. Default is False.
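    Examples
    --------
    Two stacked layers over a short sequence (a sketch; sizes are arbitrary):

    >>> import torch
    >>> lstm = ConvLSTM(input_dim=3, hidden_dim=[16, 32], kernel_size=(3, 3),
    ...                 num_layers=2, batch_first=True)
    >>> x = torch.randn(2, 5, 3, 8, 8)  # (batch, seq, channel, height, width)
    >>> outputs, states = lstm(x)
    >>> outputs[-1].shape  # full output sequence of the last layer
    torch.Size([2, 5, 32, 8, 8])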
"""
def __init__(
self,
input_dim,
hidden_dim,
kernel_size=3,
num_layers=1,
batch_first=False,
bias=True,
return_all_layers=False,
):
super().__init__()
self._check_kernel_size_consistency(kernel_size)
kernel_size = self._extend_for_multilayer(kernel_size, num_layers)
hidden_dim = self._extend_for_multilayer(hidden_dim, num_layers)
if not len(kernel_size) == len(hidden_dim) == num_layers:
raise ValueError("Inconsistent list length.")
self.input_dim = input_dim
self.hidden_dim = hidden_dim
self.kernel_size = kernel_size
self.num_layers = num_layers
self.batch_first = batch_first
self.bias = bias
self.return_all_layers = return_all_layers
cell_list = []
for i in range(0, self.num_layers):
cur_input_dim = self.input_dim if i == 0 else self.hidden_dim[i - 1]
cell_list.append(
ConvLSTMCell(
input_dim=cur_input_dim,
hidden_dim=self.hidden_dim[i],
kernel_size=self.kernel_size[i],
bias=self.bias,
)
)
self.cell_list = nn.ModuleList(cell_list)
def forward(self, input_tensor, hidden_state=None):
if not self.batch_first:
# (t, b, c, h, w) -> (b, t, c, h, w)
input_tensor = input_tensor.permute(1, 0, 2, 3, 4)
b, _, _, h, w = input_tensor.size()
        if hidden_state is not None:
            # Passing an explicit initial hidden state is not supported yet.
            raise NotImplementedError()
else:
hidden_state = self._init_hidden(batch_size=b, image_size=(h, w))
layer_output_list = []
last_state_list = []
seq_len = input_tensor.size(1)
cur_layer_input = input_tensor
for layer_idx in range(self.num_layers):
h, c = hidden_state[layer_idx]
output_inner = []
for t in range(seq_len):
h, c = self.cell_list[layer_idx](
input_tensor=cur_layer_input[:, t, :, :, :], cur_state=[h, c]
)
output_inner.append(h)
layer_output = torch.stack(output_inner, dim=1)
cur_layer_input = layer_output
layer_output_list.append(layer_output)
last_state_list.append([h, c])
if not self.return_all_layers:
layer_output_list = layer_output_list[-1:]
last_state_list = last_state_list[-1:]
return layer_output_list, last_state_list
def _init_hidden(self, batch_size, image_size):
init_states = []
for i in range(self.num_layers):
init_states.append(self.cell_list[i].init_hidden(batch_size, image_size))
return init_states
    @staticmethod
    def _check_kernel_size_consistency(kernel_size):
        # Accept an int or tuple, or a per-layer list of ints/tuples. The
        # documented default (kernel_size=3) and ConvLSTMCell both allow
        # plain ints, so rejecting them here would break the defaults.
        if not (
            isinstance(kernel_size, (int, tuple))
            or (
                isinstance(kernel_size, list)
                and all(isinstance(elem, (int, tuple)) for elem in kernel_size)
            )
        ):
            raise ValueError(
                "`kernel_size` must be an int or tuple, or a list of ints/tuples"
            )
@staticmethod
def _extend_for_multilayer(param, num_layers):
if not isinstance(param, list):
param = [param] * num_layers
return param
class LinearSelfAttention(nn.Module):
"""
Softmax-free, Performer-style linear attention on the time axis.
    Cost is linear in the sequence length, O(B * L * D * Dh) with
    Dh = d_model / num_heads, which makes it well suited to long sequences.
Parameters
----------
d_model : int
Model dimension.
num_heads : int, optional
Number of attention heads, by default 4.
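    Examples
    --------
    Shape-preserving attention over the time axis (a sketch):

    >>> import torch
    >>> attn = LinearSelfAttention(d_model=32, num_heads=4)
    >>> attn(torch.randn(2, 100, 32)).shape
    torch.Size([2, 100, 32])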
"""
def __init__(self, d_model: int, num_heads: int = 4):
super().__init__()
assert d_model % num_heads == 0
self.d_model = d_model
self.num_heads = num_heads
self.d_head = d_model // num_heads
self.Wq = nn.Linear(d_model, d_model, bias=False)
self.Wk = nn.Linear(d_model, d_model, bias=False)
self.Wv = nn.Linear(d_model, d_model, bias=False)
self.Wo = nn.Linear(d_model, d_model, bias=False)
def _phi(self, x):
"""Positive feature map (ELU + 1) as in linear transformers."""
return F.elu(x) + 1.0
def forward(self, x):
"""
Forward pass.
Parameters
----------
x : torch.Tensor
Input sequences, shape (B, L, D).
Returns
-------
torch.Tensor
Output sequences, shape (B, L, D).
"""
# x: (B, L, D)
B, L, D = x.shape
Q = self.Wq(x)
K = self.Wk(x)
V = self.Wv(x) # (B, L, D)
# split heads
Qh = Q.view(B, L, self.num_heads, self.d_head).transpose(1, 2) # (B, H, L, Dh)
Kh = K.view(B, L, self.num_heads, self.d_head).transpose(1, 2) # (B, H, L, Dh)
Vh = V.view(B, L, self.num_heads, self.d_head).transpose(1, 2) # (B, H, L, Dh)
Qf, Kf = self._phi(Qh), self._phi(Kh) # (B, H, L, Dh)
# Precompute Kf^T Vh and Kf^T 1 for normalization (linear time in L)
Kv = torch.einsum("bhlm,bhln->bhmn", Kf, Vh) # (B, H, Dh, Dh)
        K1 = Kf.sum(dim=2)  # (B, H, Dh): Kf^T @ 1 without materializing a ones tensor
# Numerator: Qf @ (Kf^T V)
num = torch.einsum("bhlm,bhmn->bhln", Qf, Kv) # (B, H, L, Dh)
# Denominator: Qf @ (Kf^T 1) (broadcast over Dh)
den = torch.einsum("bhlm,bhm->bhl", Qf, K1) # (B, H, L)
den = (den + 1e-6).unsqueeze(-1) # (B, H, L, 1)
out = num / den # (B, H, L, Dh)
# merge heads
out = out.transpose(1, 2).contiguous().view(B, L, self.d_model) # (B, L, D)
return self.Wo(out)