Skip to content

efficientunet++

The EfficientUnet++.

Decoder

Code credit: https://github.com/jlcsilva/segmentation_models.pytorch

kelp.nn.models.efficientunetplusplus.decoder.EfficientUnetPlusPlusDecoder

Bases: Module

EfficientUnet++ Decoder.

Source code in kelp/nn/models/efficientunetplusplus/decoder.py (lines 104–172)
class EfficientUnetPlusPlusDecoder(nn.Module):
    """
    EfficientUnet++ Decoder.

    Nested (UNet++-style) decoder. Every node is a ``DecoderBlock`` stored in ``self.blocks`` under the key
    ``"x_{depth_idx}_{layer_idx}"``:

    * nodes with ``depth_idx == 0`` form the main decoder path; their output channels come from
      ``decoder_channels``;
    * nodes with ``depth_idx > 0`` are the intermediate dense nodes; their output width equals the encoder
      skip width at that level (``self.skip_channels[layer_idx]``);
    * a final node ``"x_0_{depth}"`` is applied to the last main-path output without any skip input.

    Args:
        encoder_channels: Channel counts of the encoder feature maps in the order the encoder returns them
            (presumably shallowest first — the code reverses them so that the encoder head comes first).
            The first entry is dropped.
        decoder_channels: Output channels of each main-path decoder block; must have ``n_blocks`` entries.
        n_blocks: Expected number of decoder stages.
        squeeze_ratio: Forwarded unchanged to every ``DecoderBlock``.
        expansion_ratio: Forwarded unchanged to every ``DecoderBlock``.

    Raises:
        ValueError: If ``len(decoder_channels) != n_blocks``.
    """

    def __init__(
        self,
        encoder_channels: List[int],
        decoder_channels: List[int],
        n_blocks: int = 5,
        squeeze_ratio: int = 1,
        expansion_ratio: int = 1,
    ) -> None:
        super().__init__()
        if n_blocks != len(decoder_channels):
            raise ValueError(
                "Model depth is {}, but you provide `decoder_channels` for {} blocks.".format(
                    n_blocks, len(decoder_channels)
                )
            )

        encoder_channels = encoder_channels[1:]  # remove first skip with same spatial resolution
        encoder_channels = encoder_channels[::-1]  # reverse channels to start from head of encoder
        # computing blocks input and output channels
        head_channels = encoder_channels[0]
        # main-path input of level i is the head (i == 0) or the previous main-path output
        self.in_channels = [head_channels] + list(decoder_channels[:-1])
        # encoder skip width consumed at level i; the last level has no encoder skip
        self.skip_channels = list(encoder_channels[1:]) + [0]
        self.out_channels = decoder_channels

        # combine decoder keyword arguments
        kwargs = dict(squeeze_ratio=squeeze_ratio, expansion_ratio=expansion_ratio)

        blocks = {}
        # Build every node of the nested grid except the final (skip-less) main-path block.
        for layer_idx in range(len(self.in_channels) - 1):
            for depth_idx in range(layer_idx + 1):
                if depth_idx == 0:
                    # Main-path node: its skip input is the concatenation of `layer_idx` dense outputs
                    # plus one encoder feature, each `skip_channels[layer_idx]` wide (see `forward`).
                    in_ch = self.in_channels[layer_idx]
                    skip_ch = self.skip_channels[layer_idx] * (layer_idx + 1)
                    out_ch = self.out_channels[layer_idx]
                else:
                    # Dense node: input comes from the level above (width skip_channels[layer_idx - 1]);
                    # the skip concatenates (layer_idx - depth_idx) dense outputs plus one encoder feature.
                    out_ch = self.skip_channels[layer_idx]
                    skip_ch = self.skip_channels[layer_idx] * (layer_idx + 1 - depth_idx)
                    in_ch = self.skip_channels[layer_idx - 1]
                blocks[f"x_{depth_idx}_{layer_idx}"] = DecoderBlock(in_ch, skip_ch, out_ch, **kwargs)
        # Final main-path block: no skip connection (skip width 0).
        blocks[f"x_{0}_{len(self.in_channels) - 1}"] = DecoderBlock(
            self.in_channels[-1], 0, self.out_channels[-1], **kwargs
        )
        self.blocks = nn.ModuleDict(blocks)
        self.depth = len(self.in_channels) - 1

    def forward(self, *features: Any) -> Tensor:
        """
        Run the nested decoder over the encoder feature maps.

        Args:
            *features: Encoder outputs in the same order as ``encoder_channels`` passed to ``__init__``.

        Returns:
            The output of the final main-path block ``"x_0_{depth}"``.
        """
        features = features[1:]  # remove first skip with same spatial resolution
        features = features[::-1]  # reverse channels to start from head of encoder
        # start building dense connections
        dense_x = {}
        # Here `layer_idx` is the diagonal offset of the grid and `dense_l_i = depth_idx + layer_idx` is the
        # node's level (the `layer_idx` used as the key suffix in `__init__`).
        for layer_idx in range(len(self.in_channels) - 1):
            for depth_idx in range(self.depth - layer_idx):
                if layer_idx == 0:
                    # First diagonal: node x_d_d combines encoder feature d with encoder feature d + 1.
                    output = self.blocks[f"x_{depth_idx}_{depth_idx}"](features[depth_idx], features[depth_idx + 1])
                    dense_x[f"x_{depth_idx}_{depth_idx}"] = output
                else:
                    dense_l_i = depth_idx + layer_idx
                    # Skip input = all deeper dense nodes already computed at this level + the encoder
                    # feature of this level, concatenated along dim 1 (assumed channel axis, NCHW).
                    cat_features = [dense_x[f"x_{idx}_{dense_l_i}"] for idx in range(depth_idx + 1, dense_l_i + 1)]
                    cat_features = torch.cat(cat_features + [features[dense_l_i + 1]], dim=1)
                    # Main input is the same-depth node from the previous level.
                    dense_x[f"x_{depth_idx}_{dense_l_i}"] = self.blocks[f"x_{depth_idx}_{dense_l_i}"](
                        dense_x[f"x_{depth_idx}_{dense_l_i - 1}"], cat_features
                    )
        # Final block takes only the last main-path output (no skip).
        dense_x[f"x_{0}_{self.depth}"] = self.blocks[f"x_{0}_{self.depth}"](dense_x[f"x_{0}_{self.depth - 1}"])
        return dense_x[f"x_{0}_{self.depth}"]

kelp.nn.models.efficientunetplusplus.decoder.InvertedResidual

Bases: Module

Inverted bottleneck residual block with an scSE block embedded into the residual layer, after the depth-wise convolution. By default, uses batch normalization and Hardswish activation.

Source code in kelp/nn/models/efficientunetplusplus/decoder.py (lines 12–63)
class InvertedResidual(nn.Module):
    """
    Inverted bottleneck residual block with an scSE block embedded into the residual layer, after the
    depth-wise convolution. By default, uses batch normalization and Hardswish activation.

    Residual branch: point-wise expansion -> norm -> activation -> depth-wise conv (optionally strided)
    -> norm -> activation -> scSE -> point-wise projection -> norm. Its output is added to the shortcut,
    which is the identity when input and output shapes match and a (strided) 1x1 conv + norm otherwise.

    Args:
        in_channels: Number of channels of the input feature map.
        out_channels: Number of channels of the output feature map.
        kernel_size: Kernel size of the depth-wise convolution.
        stride: Stride of the depth-wise convolution. When it is not 1, the shortcut is projected with a
            1x1 convolution using the same stride so both branches have matching spatial sizes.
        expansion_ratio: Multiplier applied to ``in_channels`` to obtain the hidden (expanded) width.
        squeeze_ratio: ``reduction`` argument forwarded to the scSE module.
        activation: Activation module; defaults to ``nn.Hardswish(inplace=True)``. The same instance is
            inserted at both activation positions, so any learnable parameters it has are shared.
        normalization: Normalization layer class, called with a channel count; defaults to ``nn.BatchNorm2d``.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int = 3,
        stride: int = 1,
        expansion_ratio: int = 1,
        squeeze_ratio: int = 1,
        activation: Optional[nn.Module] = None,
        normalization: Optional[Type[nn.Module]] = None,
    ) -> None:
        super().__init__()
        # Defaults are created per instance (not as default arguments) so modules are never shared
        # between different blocks.
        if activation is None:
            activation = nn.Hardswish(True)
        if normalization is None:
            normalization = nn.BatchNorm2d

        # The shortcut may be the identity only when the residual branch preserves the input shape:
        # same channel count AND no spatial down-sampling. Previously a stride != 1 block with equal
        # channels added tensors of different spatial sizes and crashed in `forward`.
        self.same_shape = in_channels == out_channels and stride == 1
        self.mid_channels = expansion_ratio * in_channels
        self.block = nn.Sequential(
            PointWiseConv2d(in_channels, self.mid_channels),
            normalization(self.mid_channels),
            activation,
            DepthWiseConv2d(self.mid_channels, kernel_size=kernel_size, stride=stride),
            normalization(self.mid_channels),
            activation,
            SCSEModule(self.mid_channels, reduction=squeeze_ratio),
            PointWiseConv2d(self.mid_channels, out_channels),
            normalization(out_channels),
        )

        if not self.same_shape:
            # 1x1 convolution used to match the channel count (and, via `stride`, the spatial size) of the
            # shortcut with that of the residual branch.
            # NOTE(review): assumes DepthWiseConv2d pads by kernel_size // 2 with an odd kernel, so its
            # output size equals that of a strided 1x1 conv — confirm against its definition.
            self.skip_conv = nn.Sequential(
                nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=stride),
                normalization(out_channels),
            )

    def forward(self, x: Tensor) -> Tensor:
        """Return ``shortcut(x) + block(x)``."""
        residual = self.block(x)
        if not self.same_shape:
            x = self.skip_conv(x)
        return x + residual

Model

Code credit: https://github.com/jlcsilva/segmentation_models.pytorch

kelp.nn.models.efficientunetplusplus.model.EfficientUnetPlusPlus

Bases: SegmentationModel

The EfficientUNet++ is a fully convolutional neural network for ordinary and medical image semantic segmentation. It consists of an encoder and a decoder, connected by skip connections. The encoder extracts features of different spatial resolutions, which are fed to the decoder through skip connections. The decoder combines its own feature maps with the ones from skip connections to produce accurate segmentation masks. The EfficientUNet++ decoder architecture is based on the UNet++, a model composed of nested U-Net-like decoder sub-networks. To increase performance and computational efficiency, the EfficientUNet++ replaces the UNet++'s blocks with inverted residual blocks with depthwise convolutions and embedded spatial and channel attention mechanisms. It synergizes well with EfficientNet encoders: due to their efficient visual representations (i.e., using few channels to represent extracted features), EfficientNet encoders require little computation from the decoder.

Parameters:

Name Type Description Default
encoder_name str

Name of the classification model that will be used as an encoder (a.k.a backbone) to extract features of different spatial resolution

'timm-efficientnet-b0'
encoder_depth int

The number of stages used in the encoder, in range [3, 5]. Each stage generates features with half the spatial resolution of the previous one (e.g. for depth 0 we will have features with shapes [(N, C, H, W),], for depth 1 - [(N, C, H, W), (N, C, H // 2, W // 2)] and so on). Default is 5

5
encoder_weights Optional[str]

One of None (random initialization), "imagenet" (pre-training on ImageNet) and other pretrained weights (see table with available weights for each encoder_name)

'imagenet'
decoder_channels Optional[List[int]]

List of integers which specify in_channels parameter for convolutions used in decoder. Length of the list should be the same as encoder_depth

None
in_channels int

A number of input channels for the model, default is 3 (RGB images)

3
classes int

A number of classes for output mask (or you can think as a number of channels of output mask)

1
activation Optional[Union[str, Callable[[Any], Any]]]

An activation function to apply after the final convolution layer. Available options are "sigmoid", "softmax", "logsoftmax", "tanh", "identity", callable and None. Default is None

None
aux_params Optional[Dict[str, Any]]

Dictionary with parameters of the auxiliary output (classification head). The auxiliary output is built on top of the encoder if aux_params is not None (the default is None, i.e. no auxiliary output). Supported params: - classes (int): A number of classes - pooling (str): One of "max", "avg". Default is "avg" - dropout (float): Dropout factor in [0, 1) - activation (str): An activation function to apply "sigmoid"/"softmax" (could be None to return logits)

None
Reference

Silva et al. 2021

Source code in kelp/nn/models/efficientunetplusplus/model.py (lines 11–99)
class EfficientUnetPlusPlus(SegmentationModel):
    """The EfficientUNet++ is a fully convolutional neural network for ordinary and medical image semantic segmentation.
    It consists of an *encoder* and a *decoder*, connected by *skip connections*. The encoder extracts features of
    different spatial resolutions, which are fed to the decoder through skip connections. The decoder combines its
    own feature maps with the ones from skip connections to produce accurate segmentation masks. The EfficientUNet++
    decoder architecture is based on the UNet++, a model composed of nested U-Net-like decoder sub-networks. To
    increase performance and computational efficiency, the EfficientUNet++ replaces the UNet++'s blocks with
    inverted residual blocks with depthwise convolutions and embedded spatial and channel attention mechanisms.
    It synergizes well with EfficientNet encoders: due to their efficient visual representations (i.e., using few
    channels to represent extracted features), EfficientNet encoders require little computation from the decoder.

    Args:
        encoder_name: Name of the classification model that will be used as an encoder (a.k.a backbone)
            to extract features of different spatial resolution
        encoder_depth: The number of stages used in the encoder, in range [3, 5]. Each stage generates features
            with half the spatial resolution of the previous one (e.g. for depth 0 we will have features
            with shapes [(N, C, H, W),], for depth 1 - [(N, C, H, W), (N, C, H // 2, W // 2)] and so on).
            Default is 5. The default **decoder_channels** has 5 entries, so any other depth requires an
            explicit **decoder_channels** of matching length.
        encoder_weights: One of **None** (random initialization), **"imagenet"** (pre-training on ImageNet) and
            other pretrained weights (see table with available weights for each encoder_name)
        decoder_channels: Output channels of each decoder stage, starting from the stage closest to the encoder
            head. Its length must equal **encoder_depth** (the decoder raises ``ValueError`` otherwise).
            The last entry is also the input width of the segmentation head.
            Default is **[256, 128, 64, 32, 16]**
        squeeze_ratio: Squeeze (reduction) ratio forwarded to the decoder blocks. Default is 1
        expansion_ratio: Expansion ratio forwarded to the decoder blocks. Default is 1
        in_channels: A number of input channels for the model, default is 3 (RGB images)
        classes: A number of classes for output mask (or you can think as a number of channels of output mask)
        activation: An activation function to apply after the final convolution layer.
            Available options are **"sigmoid"**, **"softmax"**, **"logsoftmax"**, **"tanh"**, **"identity"**,
                **callable** and **None**.
            Default is **None**
        aux_params: Dictionary with parameters of the auxiliary output (classification head). The auxiliary output
            is built on top of the encoder if **aux_params** is not **None** (default is **None**, i.e. no
            classification head). Supported params:
                - classes (int): A number of classes
                - pooling (str): One of "max", "avg". Default is "avg"
                - dropout (float): Dropout factor in [0, 1)
                - activation (str): An activation function to apply "sigmoid"/"softmax"
                    (could be **None** to return logits)

    Reference:
        [Silva et al. 2021](https://arxiv.org/abs/2106.11447)

    """

    def __init__(
        self,
        encoder_name: str = "timm-efficientnet-b0",
        encoder_depth: int = 5,
        encoder_weights: Optional[str] = "imagenet",
        decoder_channels: Optional[List[int]] = None,
        squeeze_ratio: int = 1,
        expansion_ratio: int = 1,
        in_channels: int = 3,
        classes: int = 1,
        activation: Optional[Union[str, Callable[[Any], Any]]] = None,
        aux_params: Optional[Dict[str, Any]] = None,
    ) -> None:
        super().__init__()
        # A fresh list per instance avoids a shared mutable default.
        if decoder_channels is None:
            decoder_channels = [256, 128, 64, 32, 16]

        self.classes = classes

        self.encoder = get_encoder(
            encoder_name,
            in_channels=in_channels,
            depth=encoder_depth,
            weights=encoder_weights,
        )

        # The decoder validates that len(decoder_channels) == encoder_depth.
        self.decoder = EfficientUnetPlusPlusDecoder(
            encoder_channels=self.encoder.out_channels,
            decoder_channels=decoder_channels,
            n_blocks=encoder_depth,
            squeeze_ratio=squeeze_ratio,
            expansion_ratio=expansion_ratio,
        )

        # The head consumes the output of the last decoder stage.
        self.segmentation_head = SegmentationHead(
            in_channels=decoder_channels[-1],
            out_channels=classes,
            activation=activation,
            kernel_size=3,
        )

        # Optional classification head on the deepest encoder feature map.
        if aux_params is not None:
            self.classification_head = ClassificationHead(in_channels=self.encoder.out_channels[-1], **aux_params)
        else:
            self.classification_head = None

        self.name = f"EfficientUNet++-{encoder_name}"
        # NOTE(review): `initialize` is inherited from SegmentationModel (not visible here); presumably it
        # initializes decoder/head weights — confirm.
        self.initialize()