ResUNet++

The ResUNet++ model.

Decoder

Code credit: https://github.com/jlcsilva/segmentation_models.pytorch

kelp.nn.models.resunetplusplus.decoder.ASPP

Bases: Module

ASPP as described in https://arxiv.org/pdf/1706.05587.pdf, but without concatenating the 1x1 convolution, the original feature maps, and the global average pooling.

Source code in kelp/nn/models/resunetplusplus/decoder.py
class ASPP(nn.Module):
    """
    ASPP as described in https://arxiv.org/pdf/1706.05587.pdf, but without concatenating the 1x1
    convolution, the original feature maps, and the global average pooling.
    """

    def __init__(self, in_channels: int, out_channels: int, rate: Tuple[int, int, int] = (6, 12, 18)) -> None:
        super(ASPP, self).__init__()

        # Parallel 3x3 convolutions with the configured dilation rates (6, 12 and 18 by default)
        # for the Atrous Spatial Pyramid Pooling blocks
        self.aspp_block1 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=rate[0], dilation=rate[0]),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(out_channels),
        )
        self.aspp_block2 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=rate[1], dilation=rate[1]),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(out_channels),
        )
        self.aspp_block3 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=rate[2], dilation=rate[2]),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(out_channels),
        )
        self.aspp_block4 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(out_channels),
        )

        # 1x1 convolution fusing the concatenated outputs of all (len(rate) + 1) branches
        self.output = nn.Conv2d((len(rate) + 1) * out_channels, out_channels, kernel_size=1)
        self._init_weights()

    def forward(self, x: Tensor) -> Tensor:
        x1 = self.aspp_block1(x)
        x2 = self.aspp_block2(x)
        x3 = self.aspp_block3(x)
        x4 = self.aspp_block4(x)
        out = torch.cat([x1, x2, x3, x4], dim=1)

        return self.output(out)

    def _init_weights(self) -> None:
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
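
A minimal usage sketch for the ASPP block (the channel and spatial sizes below are illustrative assumptions, not values prescribed by the library):

import torch

from kelp.nn.models.resunetplusplus.decoder import ASPP

# With kernel_size=3 and padding == dilation, each branch preserves the spatial size,
# so the block maps (N, in_channels, H, W) -> (N, out_channels, H, W).
aspp = ASPP(in_channels=512, out_channels=256)
x = torch.randn(1, 512, 16, 16)
y = aspp(x)
print(y.shape)  # torch.Size([1, 256, 16, 16])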

Model

Code credit: https://github.com/jlcsilva/segmentation_models.pytorch

kelp.nn.models.resunetplusplus.model.ResUnetPlusPlus

Bases: SegmentationModel

ResUnet++ is a fully convolutional neural network for image semantic segmentation. It consists of encoder and decoder parts connected with skip connections. The encoder extracts features of different spatial resolutions (the skip connections), which the decoder uses to define an accurate segmentation mask.

Attention is applied to the skip connection feature maps, based on both the skip connections themselves and the decoder feature maps. The skip connection feature maps are then fused with the decoder feature maps through concatenation. An Atrous Spatial Pyramid Pooling (ASPP) bridge module is used, along with residual connections inside each decoder block.

Parameters:

encoder_name (str, default 'resnet34')

    Name of the classification model that will be used as an encoder (a.k.a. backbone) to extract features of different spatial resolutions.

encoder_depth (int, default 5)

    Number of stages used in the encoder, in range [3, 5]. Each stage generates features two times smaller in spatial dimensions than the previous one (e.g. for depth 0 we will have features with shapes [(N, C, H, W)], for depth 1 - [(N, C, H, W), (N, C, H // 2, W // 2)], and so on).

encoder_weights (Optional[str], default 'imagenet')

    One of None (random initialization), "imagenet" (pre-training on ImageNet), or other pretrained weights (see the table of available weights for each encoder_name).

decoder_channels (Optional[List[int]], default None)

    List of integers which specify the in_channels parameter for the convolutions used in the decoder. The length of the list should equal encoder_depth; if None, it defaults to [256, 128, 64, 32, 16]. See the illustrative snippet after this list.

decoder_use_batchnorm (bool, default True)

    If True, a BatchNorm2d layer is used between the Conv2d and activation layers. If "inplace", InplaceABN will be used, which decreases memory consumption. Available options are True, False, "inplace".

decoder_attention_type (Optional[str], default None)

    Attention module used in the decoder of the model. Available options are None and scse (https://arxiv.org/abs/1808.08127).

in_channels (int, default 3)

    Number of input channels for the model; the default of 3 corresponds to RGB images.

classes (int, default 1)

    Number of classes for the output mask (or, equivalently, the number of channels of the output mask).

activation (Optional[Union[str, Callable[[Any], Any]]], default None)

    An activation function to apply after the final convolution layer. Available options are "sigmoid", "softmax", "logsoftmax", "tanh", "identity", a callable, and None.

aux_params (Optional[Dict[str, Any]], default None)

    Dictionary with parameters of the auxiliary output (classification head). The auxiliary output is built on top of the encoder if aux_params is not None. Supported params:

    - classes (int): number of classes
    - pooling (str): one of "max", "avg"; default is "avg"
    - dropout (float): dropout factor in [0, 1)
    - activation (str): activation function to apply, "sigmoid"/"softmax" (can be None to return logits)
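
A hedged configuration sketch for decoder_channels (the values are illustrative, and this assumes the decoder supports the stated [3, 5] depth range):

from kelp.nn.models.resunetplusplus.model import ResUnetPlusPlus

# decoder_channels must have exactly encoder_depth entries, so a shallower
# encoder needs a matching (shorter) list instead of the 5-entry default.
model = ResUnetPlusPlus(
    encoder_name="resnet34",
    encoder_depth=4,
    decoder_channels=[256, 128, 64, 32],
    encoder_weights=None,  # skip downloading pretrained weights for this sketch
)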

Returns:

torch.nn.Module: ResUnetPlusPlus

Reference

Jha et al. 2019 (https://arxiv.org/abs/1911.07067)

Source code in kelp/nn/models/resunetplusplus/model.py
class ResUnetPlusPlus(SegmentationModel):
    """ResUnet++ is a full-convolutional neural network for image semantic segmentation. Consist of *encoder*
    and *decoder* parts connected with *skip connections*. The encoder extracts features of different spatial
    resolution (skip connections) which are used by decoder to define accurate segmentation mask.

    Applies attention to the skip connection feature maps, based on themselves and the decoder feature maps.
    The skip connection feature maps are then fused with the decoder feature maps through *concatenation*.
    Uses an Atrous Spatial Pyramid Pooling (ASPP) bridge module and residual connections inside each decoder
    blocks.

    Args:
        encoder_name: Name of the classification model that will be used as an encoder (a.k.a. backbone)
                to extract features of different spatial resolution
        encoder_depth: A number of stages used in encoder in range [3, 5]. Each stage generates features
            two times smaller in spatial dimensions than the previous one (e.g. for depth 0 we will have
            features with shapes [(N, C, H, W),], for depth 1 - [(N, C, H, W), (N, C, H // 2, W // 2)]
            and so on). Default is 5
        encoder_weights: One of **None** (random initialization), **"imagenet"** (pre-training on ImageNet) and
            other pretrained weights (see table with available weights for each encoder_name)
        decoder_channels: List of integers which specify **in_channels** parameter for convolutions used in decoder.
            Length of the list should be the same as **encoder_depth**
        decoder_use_batchnorm: If **True**, BatchNorm2d layer between Conv2D and Activation layers
            is used. If **"inplace"**, InplaceABN will be used, which decreases memory consumption.
            Available options are **True, False, "inplace"**
        decoder_attention_type: Attention module used in decoder of the model. Available options are
            **None** and **scse** (https://arxiv.org/abs/1808.08127).
        in_channels: A number of input channels for the model, default is 3 (RGB images)
        classes: A number of classes for output mask (or you can think of it as a number of channels of output mask)
        activation: An activation function to apply after the final convolution layer.
            Available options are **"sigmoid"**, **"softmax"**, **"logsoftmax"**, **"tanh"**, **"identity"**,
                **callable** and **None**.
            Default is **None**
        aux_params: Dictionary with parameters of the auxiliary output (classification head). Auxiliary output is built
            on top of the encoder if **aux_params** is not **None** (default). Supported params:
                - classes (int): A number of classes
                - pooling (str): One of "max", "avg". Default is "avg"
                - dropout (float): Dropout factor in [0, 1)
                - activation (str): An activation function to apply "sigmoid"/"softmax"
                    (could be **None** to return logits)

    Returns:
        ``torch.nn.Module``: ResUnetPlusPlus

    Reference:
        [Jha et al. 2019](https://arxiv.org/abs/1911.07067)
    """

    def __init__(
        self,
        encoder_name: str = "resnet34",
        encoder_depth: int = 5,
        encoder_weights: Optional[str] = "imagenet",
        decoder_use_batchnorm: bool = True,
        decoder_channels: Optional[List[int]] = None,
        decoder_attention_type: Optional[str] = None,
        in_channels: int = 3,
        classes: int = 1,
        activation: Optional[Union[str, Callable[[Any], Any]]] = None,
        aux_params: Optional[Dict[str, Any]] = None,
    ) -> None:
        super().__init__()
        if decoder_channels is None:
            decoder_channels = [256, 128, 64, 32, 16]

        self.encoder = get_encoder(
            encoder_name,
            in_channels=in_channels,
            depth=encoder_depth,
            weights=encoder_weights,
        )

        self.decoder = ResUnetPlusPlusDecoder(
            encoder_channels=self.encoder.out_channels,
            decoder_channels=decoder_channels,
            n_blocks=encoder_depth,
            use_batchnorm=decoder_use_batchnorm,
            attention_type=decoder_attention_type,
        )

        self.segmentation_head = SegmentationHead(
            in_channels=decoder_channels[-1],
            out_channels=classes,
            activation=activation,
            kernel_size=1,
        )

        if aux_params is not None:
            self.classification_head = ClassificationHead(in_channels=self.encoder.out_channels[-1], **aux_params)
        else:
            self.classification_head = None

        self.name = "resunet++-{}".format(encoder_name)
        self.initialize()
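
A minimal end-to-end sketch (the input size and aux_params values are illustrative assumptions; the two-output return follows segmentation_models_pytorch's SegmentationModel behavior when a classification head is configured):

import torch

from kelp.nn.models.resunetplusplus.model import ResUnetPlusPlus

model = ResUnetPlusPlus(
    encoder_name="resnet34",
    encoder_weights=None,  # skip downloading pretrained weights for this sketch
    in_channels=3,
    classes=1,
    aux_params={"classes": 2, "pooling": "avg", "dropout": 0.2, "activation": None},
)
model.eval()

# With encoder_depth=5, H and W should be divisible by 2**5 = 32.
x = torch.randn(2, 3, 256, 256)
with torch.no_grad():
    mask, labels = model(x)  # classification head enabled -> (mask, labels)
print(mask.shape)    # torch.Size([2, 1, 256, 256])
print(labels.shape)  # torch.Size([2, 2])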