
sahi

SAHI (Slicing Aided Hyper Inference) logic.

kelp.nn.inference.sahi.inference_model

Runs inference on a batch of image tiles.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `x` | `Tensor` | The batch of image tiles. | *required* |
| `model` | `Module` | The model to use for inference. | *required* |
| `soft_labels` | `bool` | A flag indicating whether to use soft labels. | `False` |
| `tta` | `bool` | A flag indicating whether to use test-time augmentation (TTA). | `False` |
| `tta_merge_mode` | `str` | The TTA merge mode. | `'mean'` |
| `decision_threshold` | `Optional[float]` | An optional decision threshold to use; `torch.argmax` is used by default. | `None` |

Source code in kelp/nn/inference/sahi.py

```python
@torch.inference_mode()
def inference_model(
    x: Tensor,
    model: nn.Module,
    soft_labels: bool = False,
    tta: bool = False,
    tta_merge_mode: str = "mean",
    decision_threshold: Optional[float] = None,
) -> Tensor:
    """
    Runs inference on a batch of image tiles.

    Args:
        x: The batch of image tiles.
        model: The model to use for inference.
        soft_labels: A flag indicating whether to use soft-labels.
        tta: A flag indicating whether to use TTA.
        tta_merge_mode: The TTA merge mode.
        decision_threshold: An optional decision threshold to use. Will use torch.argmax by default.

    Returns: A tensor with predictions.

    """
    x = x.to(model.device)
    with torch.no_grad():
        if tta:
            tta_model = ttach.SegmentationTTAWrapper(
                model=model,
                transforms=_test_time_transforms,
                merge_mode=tta_merge_mode,
            )
            y_hat = tta_model(x)
        else:
            y_hat = model(x)
        if soft_labels:
            y_hat = y_hat.sigmoid()[:, 1, :, :].float()
        elif decision_threshold is not None:
            y_hat = (y_hat.sigmoid()[:, 1, :, :] >= decision_threshold).long()  # type: ignore[attr-defined]
        else:
            y_hat = y_hat.argmax(dim=1)
    return y_hat
```
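
A minimal usage sketch. `TinyModel` is a hypothetical stand-in defined only for illustration; `inference_model` moves the batch to `model.device`, so the model must expose a `device` attribute (as a LightningModule does, for example):

```python
import torch
from torch import nn

from kelp.nn.inference.sahi import inference_model


class TinyModel(nn.Module):
    """Hypothetical two-class segmentation model used only for illustration."""

    def __init__(self) -> None:
        super().__init__()
        self.conv = nn.Conv2d(3, 2, kernel_size=1)

    @property
    def device(self) -> torch.device:
        # inference_model moves the input batch to this device
        return next(self.parameters()).device

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.conv(x)


x = torch.rand(4, 3, 128, 128)  # a batch of four 3-band tiles
hard = inference_model(x=x, model=TinyModel().eval())  # argmax labels, shape (4, 128, 128)
soft = inference_model(x=x, model=TinyModel().eval(), soft_labels=True)  # class-1 probabilities
```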

kelp.nn.inference.sahi.load_image

Helper function to load a satellite image and fill in missing pixels.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `image_path` | `Path` | The path to the image. | *required* |
| `band_order` | `List[int]` | The band order to load. | *required* |
| `fill_value` | `float` | The fill value for missing pixels. | *required* |

Source code in kelp/nn/inference/sahi.py

```python
def load_image(
    image_path: Path,
    band_order: List[int],
    fill_value: float,
) -> np.ndarray:  # type: ignore[type-arg]
    """
    Helper function to load a satellite image and fill in missing pixels.

    Args:
        image_path: The path to the image.
        band_order: The band order to load.
        fill_value: The fill value for missing pixels.

    Returns: An array with the image.

    """
    with rasterio.open(image_path) as src:
        img = src.read(band_order).astype(np.float32)
        img = np.where(img == -32768.0, fill_value, img)
    return img  # type: ignore[no-any-return]
```
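
A short usage sketch. The file path is hypothetical; band indices are 1-based, as `rasterio`'s `read` expects:

```python
from pathlib import Path

import numpy as np

from kelp.nn.inference.sahi import load_image

# Hypothetical 7-band GeoTIFF; nodata pixels (-32768.0) are replaced with NaN.
img = load_image(
    image_path=Path("data/XYZ_satellite.tif"),
    band_order=[1, 2, 3, 4, 5, 6, 7],
    fill_value=np.nan,
)
print(img.shape, img.dtype)  # e.g. (7, 350, 350) float32
```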

kelp.nn.inference.sahi.merge_predictions

Merges the prediction tiles into a single image by averaging the predictions in the overlapping sections.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `tiles` | `List[ndarray]` | A list of tiles to merge back into one image. | *required* |
| `original_shape` | `Tuple[int, int]` | The (height, width) of the original image. | *required* |
| `tile_size` | `Tuple[int, int]` | The tile size used to generate the crops. | *required* |
| `overlap` | `int` | The overlap between the tiles. | *required* |
| `decision_threshold` | `Optional[float]` | An optional decision threshold. | `None` |

Source code in kelp/nn/inference/sahi.py

```python
def merge_predictions(
    tiles: List[np.ndarray],  # type: ignore[type-arg]
    original_shape: Tuple[int, int],
    tile_size: Tuple[int, int],
    overlap: int,
    decision_threshold: Optional[float] = None,
) -> np.ndarray:  # type: ignore[type-arg]
    """
    Merges the prediction tiles into a single image by averaging the predictions in the overlapping sections.

    Args:
        tiles: A list of tiles to merge back into one image.
        original_shape: The (height, width) of the original image.
        tile_size: The tile size used to generate crops.
        overlap: The overlap between the tiles.
        decision_threshold: An optional decision threshold.

    Returns: A numpy array representing merged tiles.

    """
    step = tile_size[0] - overlap
    prediction = np.zeros(original_shape, dtype=np.float32)
    counts = np.zeros(original_shape, dtype=np.float32)

    idx = 0
    for y in range(0, original_shape[0], step):
        for x in range(0, original_shape[1], step):
            h, w = prediction[y : y + tile_size[1], x : x + tile_size[0]].shape
            prediction[y : y + tile_size[1], x : x + tile_size[0]] += tiles[idx][:h, :w].astype(np.float32)
            counts[y : y + tile_size[1], x : x + tile_size[0]] += 1
            idx += 1

    # Avoid division by zero
    counts[counts == 0] = 1
    prediction /= counts

    if decision_threshold is not None:
        prediction = np.where(prediction > decision_threshold, 1, 0)

    return prediction.astype(np.int64)
```
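
A small round-trip sketch: slicing a binary mask with `slice_image` and merging the slices back reproduces the mask, since overlapping regions average identical values (shapes chosen arbitrarily):

```python
import numpy as np

from kelp.nn.inference.sahi import merge_predictions, slice_image

mask = (np.random.rand(1, 350, 350) > 0.5).astype(np.float32)
# slice_image returns (C, H, W) tiles; merge_predictions expects 2D tiles
tiles = [tile[0] for tile in slice_image(mask, tile_size=(128, 128), overlap=64)]
merged = merge_predictions(
    tiles=tiles,
    original_shape=(350, 350),
    tile_size=(128, 128),
    overlap=64,
)
assert (merged == mask[0].astype(np.int64)).all()
```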

kelp.nn.inference.sahi.predict_sahi

Runs SAHI inference on the specified list of images.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `model` | `Module` | The model to use for prediction. | *required* |
| `file_paths` | `List[Path]` | The input image paths. | *required* |
| `output_dir` | `Path` | The path to the output directory. | *required* |
| `tile_size` | `Tuple[int, int]` | The tile size to use for SAHI. | *required* |
| `overlap` | `int` | The overlap between tiles. | *required* |
| `band_order` | `List[int]` | The band order. | *required* |
| `resize_tf` | `Callable[[Tensor], Tensor]` | The resize transform applied to the tiles. | *required* |
| `input_transforms` | `Callable[[Tensor], Tensor]` | The transforms applied to the input image before it is passed to the model. | *required* |
| `post_predict_transforms` | `Callable[[Tensor], Tensor]` | The transforms applied to the predictions. | *required* |
| `fill_value` | `float` | The fill value for missing pixels. | `0.0` |
| `soft_labels` | `bool` | A flag indicating whether to use soft labels. | `False` |
| `tta` | `bool` | A flag indicating whether to use test-time augmentation (TTA). | `False` |
| `tta_merge_mode` | `str` | The TTA merge mode. | `'mean'` |
| `decision_threshold` | `Optional[float]` | An optional decision threshold. | `None` |

Source code in kelp/nn/inference/sahi.py

```python
def predict_sahi(
    model: nn.Module,
    file_paths: List[Path],
    output_dir: Path,
    tile_size: Tuple[int, int],
    overlap: int,
    band_order: List[int],
    resize_tf: Callable[[Tensor], Tensor],
    input_transforms: Callable[[Tensor], Tensor],
    post_predict_transforms: Callable[[Tensor], Tensor],
    soft_labels: bool = False,
    fill_value: float = 0.0,
    tta: bool = False,
    tta_merge_mode: str = "mean",
    decision_threshold: Optional[float] = None,
) -> None:
    """
    Runs SAHI inference on the specified list of images.

    Args:
        model: The model to use for prediction.
        file_paths: The input image paths.
        output_dir: The path to the output directory.
        tile_size: The tile size to use for SAHI.
        overlap: The overlap between tiles.
        band_order: The band order.
        resize_tf: The resize transform applied to the tiles.
        input_transforms: The transforms applied to the input image before it is passed to the model.
        post_predict_transforms: The transforms applied to the predictions.
        fill_value: The fill value for missing pixels.
        soft_labels: A flag indicating whether to use soft-labels.
        tta: A flag indicating whether to use TTA.
        tta_merge_mode: The TTA merge mode.
        decision_threshold: An optional decision threshold.

    """
    model.eval()
    for file_path in tqdm(file_paths, desc="Processing files"):
        tile_id = file_path.name.split("_")[0]
        pred = process_image(
            image_path=file_path,
            model=model,
            tile_size=tile_size,
            overlap=overlap,
            input_transforms=input_transforms,
            post_predict_transforms=post_predict_transforms,
            soft_labels=soft_labels,
            tta=tta,
            tta_merge_mode=tta_merge_mode,
            decision_threshold=decision_threshold,
            band_order=band_order,
            resize_tf=resize_tf,
            fill_value=fill_value,
        )
        if soft_labels and decision_threshold is None:
            META["dtype"] = "float32"
        dest: DatasetWriter
        with rasterio.open(output_dir / f"{tile_id}_kelp.tif", "w", **META) as dest:
        dest.write(pred, 1)
```
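
An end-to-end sketch of how the pieces might be wired together. The paths, band order, and transforms are hypothetical, and `TinyModel` is the illustrative stand-in from the `inference_model` sketch above; in practice the model would be a trained checkpoint and `input_transforms` would apply the normalization used during training:

```python
from pathlib import Path

from torchvision.transforms import Resize

from kelp.nn.inference.sahi import predict_sahi

file_paths = sorted(Path("data/test_images").glob("*_satellite.tif"))  # hypothetical layout
predict_sahi(
    model=TinyModel().eval(),
    file_paths=file_paths,
    output_dir=Path("predictions"),
    tile_size=(128, 128),
    overlap=64,
    band_order=[1, 2, 3],
    resize_tf=Resize((128, 128)),
    input_transforms=lambda x: x,         # placeholder: plug normalization in here
    post_predict_transforms=lambda y: y,  # placeholder: e.g. resize predictions back
    tta=True,
    tta_merge_mode="mean",
)
```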

kelp.nn.inference.sahi.process_image

Runs SAHI on a single image.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `image_path` | `Path` | The path to the image. | *required* |
| `model` | `Module` | The model to use for prediction. | *required* |
| `tile_size` | `Tuple[int, int]` | The tile size to use for SAHI. | *required* |
| `overlap` | `int` | The overlap between tiles. | *required* |
| `band_order` | `List[int]` | The band order. | *required* |
| `resize_tf` | `Callable[[Tensor], Tensor]` | The resize transform applied to the tiles. | *required* |
| `input_transforms` | `Callable[[Tensor], Tensor]` | The transforms applied to the input image before it is passed to the model. | *required* |
| `post_predict_transforms` | `Callable[[Tensor], Tensor]` | The transforms applied to the predictions. | *required* |
| `fill_value` | `float` | The fill value for missing pixels. | `0.0` |
| `soft_labels` | `bool` | A flag indicating whether to use soft labels. | `False` |
| `tta` | `bool` | A flag indicating whether to use test-time augmentation (TTA). | `False` |
| `tta_merge_mode` | `str` | The TTA merge mode. | `'mean'` |
| `decision_threshold` | `Optional[float]` | An optional decision threshold. | `None` |

Source code in kelp/nn/inference/sahi.py

```python
def process_image(
    image_path: Path,
    model: nn.Module,
    tile_size: Tuple[int, int],
    overlap: int,
    band_order: List[int],
    resize_tf: Callable[[Tensor], Tensor],
    input_transforms: Callable[[Tensor], Tensor],
    post_predict_transforms: Callable[[Tensor], Tensor],
    fill_value: float = 0.0,
    soft_labels: bool = False,
    tta: bool = False,
    tta_merge_mode: str = "mean",
    decision_threshold: Optional[float] = None,
) -> np.ndarray:  # type: ignore[type-arg]
    """
    Runs SAHI on a single image.

    Args:
        image_path: The path to the image.
        model: The model to use for prediction.
        tile_size: The tile size to use for SAHI.
        overlap: The overlap between tiles.
        band_order: The band order.
        resize_tf: The resize transform applied to the tiles.
        input_transforms: The transforms applied to the input image before it is passed to the model.
        post_predict_transforms: The transforms applied to the predictions.
        fill_value: The fill value for missing pixels.
        soft_labels: A flag indicating whether to use soft-labels.
        tta: A flag indicating whether to use TTA.
        tta_merge_mode: The TTA merge mode.
        decision_threshold: An optional decision threshold.

    Returns: An array with post-processed and merged tiles as final prediction.

    """
    image = load_image(
        image_path=image_path,
        fill_value=fill_value,
        band_order=band_order,
    )
    tiles = slice_image(image, tile_size, overlap)
    predictions = []
    img_batch = []
    for tile in tiles:
        x = resize_tf(torch.from_numpy(tile)).unsqueeze(0)
        img_batch.append(x)

    x = torch.cat(img_batch, dim=0).to(DEVICE)
    x = input_transforms(x)
    y_hat = inference_model(
        x=x,
        model=model,
        soft_labels=soft_labels,
        tta=tta,
        tta_merge_mode=tta_merge_mode,
        decision_threshold=decision_threshold,
    )
    prediction = post_predict_transforms(y_hat).detach().cpu().numpy()
    predictions.extend([tensor for tensor in prediction])

    merged_prediction = merge_predictions(
        tiles=predictions,
        original_shape=image.shape[1:],  # type: ignore[arg-type]
        tile_size=tile_size,
        overlap=overlap,
        decision_threshold=decision_threshold,
    )
    return merged_prediction
```
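
The same flow for a single image, returning the merged prediction as an in-memory array instead of writing GeoTIFFs (hypothetical path; `TinyModel` is again the stand-in from the `inference_model` sketch):

```python
from pathlib import Path

from torchvision.transforms import Resize

from kelp.nn.inference.sahi import process_image

pred = process_image(
    image_path=Path("data/XYZ_satellite.tif"),  # hypothetical path
    model=TinyModel().eval(),
    tile_size=(128, 128),
    overlap=64,
    band_order=[1, 2, 3],
    resize_tf=Resize((128, 128)),
    input_transforms=lambda x: x,
    post_predict_transforms=lambda y: y,
)
print(pred.shape, pred.dtype)  # (350, 350) int64 for a 350x350 input
```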

kelp.nn.inference.sahi.slice_image

Helper function to slice an image into smaller tiles with a given overlap.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `image` | `ndarray` | The image to slice. | *required* |
| `tile_size` | `Tuple[int, int]` | The size of each tile. | *required* |
| `overlap` | `int` | The overlap between tiles. | *required* |

Source code in kelp/nn/inference/sahi.py

```python
def slice_image(
    image: np.ndarray,  # type: ignore[type-arg]
    tile_size: Tuple[int, int],
    overlap: int,
) -> List[np.ndarray]:  # type: ignore[type-arg]
    """
    Helper function to slice an image into smaller tiles with a given overlap.

    Args:
        image: The image to slice.
        tile_size: The size of the tile.
        overlap: The overlap between tiles.

    Returns: A list of sliced images.

    """
    tiles = []
    height, width = image.shape[1], image.shape[2]
    step = tile_size[0] - overlap

    for y in range(0, height, step):
        for x in range(0, width, step):
            tile = image[:, y : y + tile_size[1], x : x + tile_size[0]]
            # Padding the tile if it's smaller than the expected size (at edges)
            if tile.shape[1] < tile_size[1] or tile.shape[2] < tile_size[0]:
                tile = np.pad(
                    tile,
                    ((0, 0), (0, max(0, tile_size[1] - tile.shape[1])), (0, max(0, tile_size[0] - tile.shape[2]))),
                    mode="constant",
                    constant_values=0,
                )
            tiles.append(tile)
    return tiles
```
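
A quick sketch of the tiling arithmetic: with tile size `t` and overlap `o` the stride is `t - o`, so each axis yields `ceil(dim / (t - o))` tile positions, and edge tiles are zero-padded to the full tile size (shapes chosen arbitrarily):

```python
import numpy as np

from kelp.nn.inference.sahi import slice_image

image = np.random.rand(7, 350, 350).astype(np.float32)
tiles = slice_image(image, tile_size=(128, 128), overlap=64)  # stride of 64 px
print(len(tiles))       # 36 == ceil(350 / 64) ** 2
print(tiles[0].shape)   # (7, 128, 128)
print(tiles[-1].shape)  # (7, 128, 128) -- bottom-right edge tile, zero-padded
```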