Source code for bluemath_tk.deeplearning.unet

"""
Unet module.

This module provides generalized versions of U-Net models like the ones in:
https://github.com/oaeen/wind2iwp
"""

import torch.nn as nn

from .layers import DoubleConv, DoubleConv3D, Down, Down3D, OutConv, OutConv3D, Up, Up3D


class UNet(nn.Module):
    """
    U-Net architecture for 2D image segmentation/regression.

    The network has an input block, four downsampling stages, four
    upsampling stages fed by skip connections, and a final output block.
    Encoder widths double at every stage, starting from ``base_channels``.

    Parameters
    ----------
    n_channels : int, optional
        Number of input channels. Default is 1.
    n_classes : int, optional
        Number of output channels/classes. Default is 1.
    base_channels : int, optional
        Width of the first convolution block; the downsampling stages use
        2x, 4x, 8x and 16x this value. Default is 64.
    bilinear : bool, optional
        Whether to use bilinear upsampling in the decoder. When True the
        bottleneck width and the outputs of the first three ``Up`` stages
        are halved. Default is True.
    """

    def __init__(self, n_channels=1, n_classes=1, base_channels=64, bilinear=True):
        super().__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear

        # Encoder widths for each resolution level.
        w1, w2, w4, w8, w16 = (base_channels * k for k in (1, 2, 4, 8, 16))

        # NOTE: modules are created in the same order as the reference layout
        # (inc, down1..down4, up1..up4, outc) so parameter initialization and
        # state_dict ordering are unchanged.
        self.inc = DoubleConv(n_channels, w1)
        self.down1 = Down(w1, w2)
        self.down2 = Down(w2, w4)
        self.down3 = Down(w4, w8)
        # Halving factor applied to the bottleneck/decoder widths when
        # bilinear upsampling is enabled.
        shrink = 2 if bilinear else 1
        self.down4 = Down(w8, w16 // shrink)

        self.up1 = Up(w16, w8 // shrink, bilinear)
        self.up2 = Up(w8, w4 // shrink, bilinear)
        self.up3 = Up(w4, w2 // shrink, bilinear)
        self.up4 = Up(w2, w1, bilinear)
        self.outc = OutConv(w1, n_classes)

    def forward(self, x):
        """
        Run the network on ``x``.

        Parameters
        ----------
        x : torch.Tensor
            Input with ``n_channels`` channels; presumably shaped
            (batch, n_channels, height, width) — TODO confirm.

        Returns
        -------
        torch.Tensor
            Output of the final ``OutConv`` layer; no further activation is
            applied in this method.
        """
        # Encoder: remember every feature map for the skip connections.
        skips = [self.inc(x)]
        for down in (self.down1, self.down2, self.down3):
            skips.append(down(skips[-1]))
        out = self.down4(skips[-1])

        # Decoder: pair each Up stage with the skip of matching resolution,
        # deepest first.
        decoder = (self.up1, self.up2, self.up3, self.up4)
        for up, skip in zip(decoder, reversed(skips)):
            out = up(out, skip)
        return self.outc(out)
class UNet3D(nn.Module):
    """
    U-Net architecture for 3D volumetric data.

    The network has an input block, four downsampling stages, four
    upsampling stages fed by skip connections, and a final output block.
    Encoder widths double at every stage, starting from ``base_channels``.

    Parameters
    ----------
    n_channels : int, optional
        Number of input channels. Default is 1.
    n_classes : int, optional
        Number of output channels/classes. Default is 1.
    base_channels : int, optional
        Width of the first convolution block; the downsampling stages use
        2x, 4x, 8x and 16x this value. Default is 16.
    trilinear : bool, optional
        Whether to use trilinear upsampling in the decoder. When True the
        bottleneck width and the outputs of the first three ``Up3D`` stages
        are halved. Default is True.
    """

    def __init__(self, n_channels=1, n_classes=1, base_channels=16, trilinear=True):
        super().__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.trilinear = trilinear

        # Encoder widths for each resolution level.
        w1, w2, w4, w8, w16 = (base_channels * k for k in (1, 2, 4, 8, 16))

        # NOTE: modules are created in the same order as the reference layout
        # (inc, down1..down4, up1..up4, outc) so parameter initialization and
        # state_dict ordering are unchanged.
        self.inc = DoubleConv3D(n_channels, w1)
        self.down1 = Down3D(w1, w2)
        self.down2 = Down3D(w2, w4)
        self.down3 = Down3D(w4, w8)
        # Halving factor applied to the bottleneck/decoder widths when
        # trilinear upsampling is enabled.
        shrink = 2 if trilinear else 1
        self.down4 = Down3D(w8, w16 // shrink)

        self.up1 = Up3D(w16, w8 // shrink, trilinear)
        self.up2 = Up3D(w8, w4 // shrink, trilinear)
        self.up3 = Up3D(w4, w2 // shrink, trilinear)
        self.up4 = Up3D(w2, w1, trilinear)
        self.outc = OutConv3D(w1, n_classes)

    def forward(self, x):
        """
        Run the network on ``x``.

        Parameters
        ----------
        x : torch.Tensor
            Input with ``n_channels`` channels; presumably shaped
            (batch, n_channels, depth, height, width) — TODO confirm.

        Returns
        -------
        torch.Tensor
            Output of the final ``OutConv3D`` layer; no further activation is
            applied in this method.
        """
        # Encoder: remember every feature map for the skip connections.
        skips = [self.inc(x)]
        for down in (self.down1, self.down2, self.down3):
            skips.append(down(skips[-1]))
        out = self.down4(skips[-1])

        # Decoder: pair each Up3D stage with the skip of matching resolution,
        # deepest first.
        decoder = (self.up1, self.up2, self.up3, self.up4)
        for up, skip in zip(decoder, reversed(skips)):
            out = up(out, skip)
        return self.outc(out)