Skip to content

transforms

Augmentation transforms and related helper classes and functions.

kelp.nn.data.transforms.MinMaxNormalize

Bases: Module

Min-Max normalization transform that uses provided min and max per-channel values for image transformation.

Parameters:

Name Type Description Default
min_vals Tensor

A Tensor of min values per-channel.

required
max_vals Tensor

A Tensor of max values per-channel.

required
Source code in kelp/nn/data/transforms.py
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
class MinMaxNormalize(Module):
    """
    Min-Max normalization transform that uses provided min and max per-channel values for image transformation.

    Each value is clamped into its channel's ``[min, max]`` range and then rescaled towards ``[0, 1]``.

    Args:
        min_vals: A Tensor of min values per-channel.
        max_vals: A Tensor of max values per-channel.
    """

    def __init__(self, min_vals: Tensor, max_vals: Tensor) -> None:
        super().__init__()
        # Stats are reshaped to (1, C, 1, 1) so they broadcast along dim 1 of a 4-D batch
        # (assumes channels-first input — confirm against callers).
        broadcast_shape = (1, -1, 1, 1)
        self.mins = min_vals.view(broadcast_shape)
        self.maxs = max_vals.view(broadcast_shape)

    def forward(self, x: Tensor) -> Tensor:
        """
        Normalizes a batch of images with the stored per-channel statistics.

        Args:
            x: The batch of images.

        Returns: A batch of normalized images.

        """
        # Move the stored statistics onto the input's device / dtype on every call.
        lower = torch.as_tensor(self.mins, device=x.device, dtype=x.dtype)
        upper = torch.as_tensor(self.maxs, device=x.device, dtype=x.dtype)
        value_range = upper - lower + consts.data.EPS  # EPS guards against a zero range
        return (x.clamp(lower, upper) - lower) / value_range

kelp.nn.data.transforms.PerSampleMinMaxNormalize

Bases: Module

A per-sample normalization transform that will calculate min and max per-channel on the fly.

Source code in kelp/nn/data/transforms.py
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
class PerSampleMinMaxNormalize(Module):
    """
    A per-sample normalization transform that will calculate min and max per-channel on the fly.

    No statistics are stored; every call derives them from the batch it receives.
    """

    def forward(self, x: Tensor) -> Tensor:
        """
        Runs the normalization transform for specified batch of images.

        Each channel of each sample is rescaled using its own minimum and maximum, reduced over
        dims 2 and 3 of ``x``.

        Args:
            x: The batch of images.

        Returns: A batch of normalized images.

        """
        spatial_dims = (2, 3)
        # keepdim=True leaves size-1 dims in place so the stats broadcast back over x.
        lowest = torch.amin(x, dim=spatial_dims, keepdim=True)
        highest = torch.amax(x, dim=spatial_dims, keepdim=True)
        spread = highest - lowest + consts.data.EPS  # EPS avoids division by zero for flat channels
        return (x - lowest) / spread

kelp.nn.data.transforms.PerSampleMinMaxNormalize.forward

Runs the normalization transform for specified batch of images.

Parameters:

Name Type Description Default
x Tensor

The batch of images.

required
Source code in kelp/nn/data/transforms.py
47
48
49
50
51
52
53
54
55
56
57
58
59
def forward(self, x: Tensor) -> Tensor:
    """
    Runs the normalization transform for specified batch of images.

    Every channel of every sample is rescaled with its own minimum and maximum taken over
    dims 2 and 3 of ``x``.

    Args:
        x: The batch of images.

    Returns: A batch of normalized images.

    """
    reduce_dims = (2, 3)
    # keepdim=True keeps the reduced dims as size 1 so the result broadcasts against x.
    smallest = torch.amin(x, dim=reduce_dims, keepdim=True)
    largest = torch.amax(x, dim=reduce_dims, keepdim=True)
    spread = largest - smallest + consts.data.EPS  # EPS avoids division by zero
    return (x - smallest) / spread

kelp.nn.data.transforms.PerSampleQuantileNormalize

Bases: Module

A per-sample normalization transform that will calculate min and max per-channel on the fly using provided quantile values.

Parameters:

Name Type Description Default
q_low float

The lower quantile value.

required
q_high float

The upper quantile value.

required
Source code in kelp/nn/data/transforms.py
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
class PerSampleQuantileNormalize(Module):
    """
    A per-sample normalization transform that will calculate min and max per-channel on the fly
    using provided quantile values.

    For every sample and channel, the ``q_low`` and ``q_high`` quantiles of the flattened
    trailing dims are computed. Values are clamped into that range and rescaled towards ``[0, 1]``.

    Args:
        q_low: The lower quantile value.
        q_high: The upper quantile value.

    """

    def __init__(self, q_low: float, q_high: float) -> None:
        super().__init__()
        self.q_low = q_low
        self.q_high = q_high

    @staticmethod
    def _quantile(sorted_values: Tensor, q: float) -> Tensor:
        """
        Computes the ``q`` quantile along the last dim of a tensor that is already sorted ascending.

        Uses the same linear interpolation as ``torch.quantile`` (rank ``q * (n - 1)``). Unlike
        ``torch.quantile``, it is not subject to that function's input-size limit.

        Args:
            sorted_values: Tensor sorted in ascending order along its last dim.
            q: The quantile to compute, in ``[0, 1]``.

        Returns: A tensor with the last dim reduced away.

        Raises:
            ValueError: If ``q`` is outside ``[0, 1]`` or the last dim is empty.

        """
        if not 0.0 <= q <= 1.0:
            raise ValueError(f"Quantile must be in the range [0, 1], got {q}")
        num_values = sorted_values.shape[-1]
        if num_values == 0:
            raise ValueError("Cannot compute a quantile of an empty tensor")
        rank = q * (num_values - 1)
        below = int(rank)  # floor, since rank >= 0
        above = below if rank == below else below + 1  # ceil, never exceeds n - 1
        return torch.lerp(sorted_values[..., below], sorted_values[..., above], rank - below)

    def forward(self, x: Tensor) -> Tensor:
        """
        Runs the normalization transform for specified batch of images.

        Inputs are expected to be NaN-free (NaN handling differs from ``torch.quantile``).

        Args:
            x: The batch of images.

        Returns: A batch of normalized images.

        """
        # reshape (not view) so non-contiguous batches (e.g. permuted or sliced tensors) also work.
        flattened_sample = x.reshape(x.shape[0], x.shape[1], -1)
        # Sort once and reuse it for both quantiles instead of sorting twice.
        sorted_values = torch.sort(flattened_sample, dim=2).values
        vmin = self._quantile(sorted_values, self.q_low).unsqueeze(2).unsqueeze(3)
        vmax = self._quantile(sorted_values, self.q_high).unsqueeze(2).unsqueeze(3)
        x = x.clamp(vmin, vmax)
        x = (x - vmin) / (vmax - vmin + consts.data.EPS)
        return x

kelp.nn.data.transforms.PerSampleQuantileNormalize.forward

Runs the normalization transform for specified batch of images.

Parameters:

Name Type Description Default
x Tensor

The batch of images.

required
Source code in kelp/nn/data/transforms.py
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
def forward(self, x: Tensor) -> Tensor:
    """
    Runs the normalization transform for specified batch of images.

    Per sample and channel, the ``self.q_low`` / ``self.q_high`` quantiles (linear interpolation,
    as in ``torch.quantile``) of the flattened trailing dims are used as min / max. Inputs are
    expected to be NaN-free.

    Args:
        x: The batch of images.

    Returns: A batch of normalized images.

    """
    # reshape (not view) so non-contiguous batches also work.
    flattened_sample = x.reshape(x.shape[0], x.shape[1], -1)
    # Sort once and share it between both quantiles; this also avoids torch.quantile's input-size limit.
    sorted_values = torch.sort(flattened_sample, dim=2).values
    num_values = sorted_values.shape[2]

    def _quantile(q: float) -> Tensor:
        # Linear interpolation between the order statistics around rank q * (n - 1).
        if not 0.0 <= q <= 1.0:
            raise ValueError(f"Quantile must be in the range [0, 1], got {q}")
        if num_values == 0:
            raise ValueError("Cannot compute a quantile of an empty tensor")
        rank = q * (num_values - 1)
        below = int(rank)
        above = below if rank == below else below + 1
        return torch.lerp(sorted_values[..., below], sorted_values[..., above], rank - below)

    vmin = _quantile(self.q_low).unsqueeze(2).unsqueeze(3)
    vmax = _quantile(self.q_high).unsqueeze(2).unsqueeze(3)
    x = x.clamp(vmin, vmax)
    x = (x - vmin) / (vmax - vmin + consts.data.EPS)
    return x

kelp.nn.data.transforms.RemoveNaNs

Bases: Module

Replaces NaN and infinite values in the input tensor with per-channel min/max values.

Parameters:

Name Type Description Default
min_vals Tensor

The min values per-channel to use when removing NaNs and neg-Inf.

required
max_vals Tensor

The max values per-channel to use when replacing positive-Inf.

required
Source code in kelp/nn/data/transforms.py
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
class RemoveNaNs(Module):
    """
    Removes NaN values from the input tensor.

    NaN and -Inf entries are replaced with the channel's min value; +Inf entries are replaced with
    the channel's max value.

    Args:
        min_vals: The min values per-channel to use when removing NaNs and neg-Inf.
        max_vals: The max values per-channel to use when removing positive-Inf.

    """

    def __init__(self, min_vals: Tensor, max_vals: Tensor) -> None:
        super().__init__()
        # Reshape to (1, C, 1, 1) so the values broadcast along dim 1 of a 4-D batch.
        self.mins = min_vals.view(1, -1, 1, 1)
        self.maxs = max_vals.view(1, -1, 1, 1)

    def forward(self, x: Tensor) -> Tensor:
        """
        Runs the transform for specified batch of images.

        Args:
            x: The batch of images.

        Returns: The batch with NaN / -Inf replaced by per-channel mins and remaining Inf by per-channel maxs.

        """
        mins = torch.as_tensor(self.mins, device=x.device, dtype=x.dtype)
        maxs = torch.as_tensor(self.maxs, device=x.device, dtype=x.dtype)
        x = torch.where(torch.isnan(x) | torch.isneginf(x), mins, x)
        # After the step above, any infinity left over is replaced with the channel max.
        x = torch.where(torch.isinf(x), maxs, x)
        return x

kelp.nn.data.transforms.RemoveNaNs.forward

Runs the transform for specified batch of images.

Parameters:

Name Type Description Default
x Tensor

The batch of images.

required
Source code in kelp/nn/data/transforms.py
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
def forward(self, x: Tensor) -> Tensor:
    """
    Runs the transform for specified batch of images.

    Args:
        x: The batch of images.

    Returns: The batch with NaN / -Inf replaced by ``self.mins`` and remaining Inf by ``self.maxs``.

    """
    mins = torch.as_tensor(self.mins, device=x.device, dtype=x.dtype)
    maxs = torch.as_tensor(self.maxs, device=x.device, dtype=x.dtype)
    x = torch.where(torch.isnan(x) | torch.isneginf(x), mins, x)
    # After the step above, any infinity left over is replaced with the channel max.
    x = torch.where(torch.isinf(x), maxs, x)
    return x

kelp.nn.data.transforms.RemovePadding

Bases: Module

Removes specified padding from the input tensors.

Parameters:

Name Type Description Default
image_size int

The size of the target image after padding removal.

required
padded_image_size int

The size of the padded image before padding removal.

required
args Any

Arguments passed to super class.

()
kwargs Any

Keyword arguments passed to super class.

{}
Source code in kelp/nn/data/transforms.py
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
class RemovePadding(nn.Module):
    """
    Removes specified padding from the input tensors.

    Keeps an ``image_size`` x ``image_size`` window that starts at offset
    ``(padded_image_size - image_size) // 2`` in the last two dims. Any odd remainder of padding is
    therefore assumed to sit on the trailing (right / bottom) side.

    Args:
        image_size: The size of the target image after padding removal.
        padded_image_size: The size of the padded image before padding removal.
        args: Arguments passed to super class.
        kwargs: Keyword arguments passed to super class.

    """

    def __init__(self, image_size: int, padded_image_size: int, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)
        self.padding_to_trim = (padded_image_size - image_size) // 2
        self.crop_upper_bound = image_size + self.padding_to_trim

    def forward(self, x: Tensor) -> Tensor:
        """
        Runs the transform for specified batch of images.

        All size-1 dims are squeezed first. The crop is then applied to the last two dims, so inputs
        that still have a leading dim after squeezing (e.g. several images) are cropped spatially
        instead of along that leading dim.

        Args:
            x: The batch of images.

        Returns: The squeezed input with the padding removed from its last two dims.

        """
        x = x.squeeze()
        window = slice(self.padding_to_trim, self.crop_upper_bound)
        return x[..., window, window]

kelp.nn.data.transforms.RemovePadding.forward

Runs the transform for specified batch of images.

Parameters:

Name Type Description Default
x Tensor

The batch of images.

required
Source code in kelp/nn/data/transforms.py
146
147
148
149
150
151
152
153
154
155
156
157
158
def forward(self, x: Tensor) -> Tensor:
    """
    Runs the transform for specified batch of images.

    All size-1 dims are squeezed first. The crop is then applied to the last two dims, so inputs
    that still have a leading dim after squeezing are cropped spatially rather than along that dim.

    Args:
        x: The batch of images.

    Returns: The squeezed input with the padding removed from its last two dims.

    """
    x = x.squeeze()
    window = slice(self.padding_to_trim, self.crop_upper_bound)
    return x[..., window, window]

kelp.nn.data.transforms.build_append_index_transforms

Builds a transform that appends the specified spectral indices to the input.

Parameters:

Name Type Description Default
spectral_indices List[str]

A list of spectral indices to use.

required
Source code in kelp/nn/data/transforms.py
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
def build_append_index_transforms(spectral_indices: List[str]) -> Callable[[Tensor], Tensor]:
    """
    Build an append index transforms based on specified spectral indices.

    The resulting sequence always starts with ``AppendDEMWM``. It is followed by one transform per
    requested index, except ``"DEMWM"``, which is skipped. Indices whose name ends with ``"WM"`` are
    built with QA and water-mask masking disabled; all others have both enabled. Masked pixels are
    filled with NaN. The sequence is moved to ``DEVICE``.

    Args:
        spectral_indices: A list of spectral indices to use.

    Returns: A callable that can be used to transform batch of images.

    """
    dem_water_mask = AppendDEMWM(  # type: ignore
        index_dem=BAND_INDEX_LOOKUP["DEM"],
        index_qa=BAND_INDEX_LOOKUP["QA"],
    )

    index_transforms = []
    for index_name in spectral_indices:
        if index_name == "DEMWM":
            # Presumably already covered by AppendDEMWM above — confirm.
            continue
        apply_masks = not index_name.endswith("WM")
        index_transforms.append(
            SPECTRAL_INDEX_LOOKUP[index_name](
                index_swir=BAND_INDEX_LOOKUP["SWIR"],
                index_nir=BAND_INDEX_LOOKUP["NIR"],
                index_red=BAND_INDEX_LOOKUP["R"],
                index_green=BAND_INDEX_LOOKUP["G"],
                index_blue=BAND_INDEX_LOOKUP["B"],
                index_dem=BAND_INDEX_LOOKUP["DEM"],
                index_qa=BAND_INDEX_LOOKUP["QA"],
                index_water_mask=BAND_INDEX_LOOKUP["DEMWM"],
                mask_using_qa=apply_masks,
                mask_using_water_mask=apply_masks,
                fill_val=torch.nan,
            )
        )

    pipeline = K.AugmentationSequential(dem_water_mask, *index_transforms, data_keys=["input"])
    return pipeline.to(DEVICE)  # type: ignore[no-any-return]

kelp.nn.data.transforms.min_max_normalize

Runs min-max normalization on the input array by calculating min and max per-channel values on the fly.

Parameters:

Name Type Description Default
arr ndarray

The array to normalize.

required
Source code in kelp/nn/data/transforms.py
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
def min_max_normalize(arr: np.ndarray) -> np.ndarray:  # type: ignore[type-arg]
    """
    Runs min-max normalization on the input array by calculating min and max per-channel values on the fly.

    Min and max are taken over axes 0 and 1, ignoring NaNs. For a 3-D array this yields one value per
    entry of the last axis (presumably channels-last input — confirm with callers). NaN entries remain
    NaN in the output.

    Args:
        arr: The array to normalize.

    Returns: Normalized array with values in ``[0, 1]``.

    """
    vmin = np.nanmin(arr, axis=(0, 1))
    vmax = np.nanmax(arr, axis=(0, 1))
    # Clip to the per-channel [vmin, vmax] range. A hard-coded lower bound of 0 would collapse all
    # negative values (e.g. spectral indices) onto one non-zero output instead of mapping vmin -> 0.
    arr = arr.clip(vmin, vmax)
    arr = (arr - vmin) / (vmax - vmin + consts.data.EPS)
    arr = arr.clip(0, 1)
    return arr

kelp.nn.data.transforms.quantile_min_max_normalize

Runs min-max quantile normalization on the input array by calculating min and max per-channel values on the fly.

Parameters:

Name Type Description Default
x ndarray

The array to normalize.

required
q_lower float

The lower quantile.

0.01
q_upper float

The upper quantile.

0.99
Source code in kelp/nn/data/transforms.py
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
def quantile_min_max_normalize(
    x: np.ndarray,  # type: ignore[type-arg]
    q_lower: float = 0.01,
    q_upper: float = 0.99,
) -> np.ndarray:  # type: ignore[type-arg]
    """
    Runs min-max quantile normalization on the input array by calculating min and max per-channel values on the fly.

    The quantiles are taken over axes 1 and 2 and kept as size-1 axes so they broadcast back over
    ``x``. Values are not clipped, so entries outside the quantile range map outside ``[0, 1]``.

    Args:
        x: The array to normalize.
        q_lower: The lower quantile.
        q_upper: The upper quantile.

    Returns: Normalized array.

    """
    # One call computes both quantiles; the leading axis of the result indexes q.
    bounds = np.quantile(x, q=[q_lower, q_upper], axis=(1, 2), keepdims=True)
    low, high = bounds[0], bounds[1]
    return (x - low) / (high - low + consts.data.EPS)  # type: ignore[no-any-return]

kelp.nn.data.transforms.resolve_normalization_stats

Resolves normalization stats based on specified bands to use.

Parameters:

Name Type Description Default
dataset_stats Dict[str, Dict[str, float]]

The full per-band dataset statistics.

required
bands_to_use List[str]

The list of band names to use.

required
Source code in kelp/nn/data/transforms.py
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
def resolve_normalization_stats(
    dataset_stats: Dict[str, Dict[str, float]],
    bands_to_use: List[str],
) -> Tuple[BandStats, int]:
    """
    Resolves normalization stats based on specified bands to use.

    For every statistic (mean, std, min, max, q01, q99), one tensor is built holding that value for
    each selected band, in ``bands_to_use`` order.

    Args:
        dataset_stats: The full per-band dataset statistics.
        bands_to_use: The list of band names to use.

    Returns: A tuple of stats and the number of bands to use. The count is the number of distinct
        band names, since duplicates in ``bands_to_use`` collapse into one dict entry.

    """
    selected = {band: dataset_stats[band] for band in bands_to_use}
    stat_names = ("mean", "std", "min", "max", "q01", "q99")
    per_stat = {
        name: torch.tensor([band_values[name] for band_values in selected.values()]) for name in stat_names
    }
    return BandStats(**per_stat), len(selected)

kelp.nn.data.transforms.resolve_normalization_transform

Resolves the normalization transform.

Parameters:

Name Type Description Default
band_stats BandStats

The band statistics.

required
normalization_strategy Literal['min-max', 'quantile', 'per-sample-min-max', 'per-sample-quantile', 'z-score']

The normalization strategy.

'quantile'
Source code in kelp/nn/data/transforms.py
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
def resolve_normalization_transform(
    band_stats: BandStats,
    normalization_strategy: Literal[
        "min-max",
        "quantile",
        "per-sample-min-max",
        "per-sample-quantile",
        "z-score",
    ] = "quantile",
) -> Union[_AugmentationBase, nn.Module]:
    """
    Resolves the normalization transform.

    Strategy mapping:
        - "z-score": ``K.Normalize`` with the band means and stds.
        - "min-max": ``MinMaxNormalize`` with the band min / max.
        - "quantile": ``MinMaxNormalize`` with the band q01 / q99.
        - "per-sample-quantile": ``PerSampleQuantileNormalize`` with quantiles 0.01 / 0.99.
        - "per-sample-min-max": ``PerSampleMinMaxNormalize``.

    Args:
        band_stats: The band statistics.
        normalization_strategy: The normalization strategy.

    Returns: A normalization transform to use for the image batch.

    Raises:
        ValueError: If the strategy is not one of the supported values.

    """
    # Lazily constructed so only the requested transform is built.
    builders = {
        "z-score": lambda: K.Normalize(band_stats.mean, band_stats.std),
        "min-max": lambda: MinMaxNormalize(min_vals=band_stats.min, max_vals=band_stats.max),
        "quantile": lambda: MinMaxNormalize(min_vals=band_stats.q01, max_vals=band_stats.q99),
        "per-sample-quantile": lambda: PerSampleQuantileNormalize(q_low=0.01, q_high=0.99),
        "per-sample-min-max": lambda: PerSampleMinMaxNormalize(),
    }
    build = builders.get(normalization_strategy)
    if build is None:
        raise ValueError(f"{normalization_strategy} is not supported!")
    return build()  # type: ignore[no-any-return]

kelp.nn.data.transforms.resolve_resize_transform

Resolves the input image and mask resize transform.

Parameters:

Name Type Description Default
image_or_mask Literal['image', 'mask']

Indicates if the transform is for an image or a mask.

required
resize_strategy Literal['pad', 'resize']

The resize strategy to use.

'pad'
image_size int

The size of the resized image.

352
interpolation Literal['nearest', 'nearest-exact', 'bilinear', 'bicubic']

The interpolation method to use for the "resize" strategy.

'nearest'
Source code in kelp/nn/data/transforms.py
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
def resolve_resize_transform(
    image_or_mask: Literal["image", "mask"],
    resize_strategy: Literal["pad", "resize"] = "pad",
    image_size: int = 352,
    interpolation: Literal["nearest", "nearest-exact", "bilinear", "bicubic"] = "nearest",
) -> Callable[[Tensor], Tensor]:
    """
    Resolves the input image and mask resize transform.

    Args:
        image_or_mask: Indicates if the transform is for an image or a mask. Masks are always resized
            with nearest-neighbour interpolation.
        resize_strategy: The resize strategy to use. "pad" zero-pads by ``image_size - consts.data.TILE_SIZE``
            pixels in total per spatial dim (presumably inputs are TILE_SIZE tiles — confirm);
            "resize" rescales to ``image_size`` x ``image_size``.
        image_size: The size of the resized image.
        interpolation: The interpolation method to use for the "resize" strategy (images only).

    Returns: A transform that pads or resizes its input tensor.

    Raises:
        ValueError: If ``resize_strategy`` is unsupported, or "pad" is requested with ``image_size`` below 352.

    """
    if resize_strategy == "pad":
        if image_size < 352:
            raise ValueError("Invalid resize strategy. Padding is only applicable when image size is greater than 352.")
        total_padding = image_size - consts.data.TILE_SIZE
        # Symmetric `total // 2` padding loses a pixel when the total is odd. Put the remainder on the
        # right / bottom instead. The leading offset stays `total // 2`, which is the offset
        # RemovePadding trims ((padded - image) // 2).
        leading = total_padding // 2
        trailing = total_padding - leading
        return T.Pad(  # type: ignore[no-any-return]
            padding=[leading, leading, trailing, trailing],  # left, top, right, bottom
            fill=0,
            padding_mode="constant",
        )
    if resize_strategy == "resize":
        interpolation_lookup = {
            "nearest": InterpolationMode.NEAREST,
            "nearest-exact": InterpolationMode.NEAREST_EXACT,
            "bilinear": InterpolationMode.BILINEAR,
            "bicubic": InterpolationMode.BICUBIC,
        }
        # Masks must keep discrete labels, so they ignore `interpolation`.
        mode = interpolation_lookup[interpolation] if image_or_mask == "image" else InterpolationMode.NEAREST
        return T.Resize(  # type: ignore[no-any-return]
            size=(image_size, image_size),
            interpolation=mode,
            antialias=False,
        )
    raise ValueError(f"{resize_strategy=} is not supported!")

kelp.nn.data.transforms.resolve_transforms

Resolves batch augmentation transformations to be used based on specified configuration.

Parameters:

Name Type Description Default
spectral_indices List[str]

The list of spectral indices to use.

required
band_index_lookup Dict[str, int]

The dictionary mapping band name to index in the input tensor.

required
band_stats BandStats

The band statistics to use.

required
mask_using_qa bool

A flag indicating whether to mask spectral indices with QA band.

required
mask_using_water_mask bool

A flag indicating whether to mask spectral indices with DEM Water Mask.

required
normalization_transform Union[_AugmentationBase, Module]

A normalization transformation.

required
stage Literal['train', 'val', 'test', 'predict']

A literal indicating the stage to use. One of ["train", "val", "test", "predict"].

required
Source code in kelp/nn/data/transforms.py
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
def resolve_transforms(
    spectral_indices: List[str],
    band_index_lookup: Dict[str, int],
    band_stats: BandStats,
    mask_using_qa: bool,
    mask_using_water_mask: bool,
    normalization_transform: Union[_AugmentationBase, nn.Module],
    stage: Literal["train", "val", "test", "predict"],
) -> K.AugmentationSequential:
    """
    Resolves batch augmentation transformations to be used based on specified configuration.

    Pipeline order:
        1. One transform per spectral index, with NaN as the fill value.
        2. ``RemoveNaNs`` using the band min / max.
        3. The given normalization transform.
        4. Train stage only: random 90-degree rotation and horizontal / vertical flips (p=0.5 each).

    The "predict" stage uses only the "input" data key; every other stage uses "input" and "mask".

    Args:
        spectral_indices: The list of spectral indices to use.
        band_index_lookup: The dictionary mapping band name to index in the input tensor.
        band_stats: The band statistics to use.
        mask_using_qa: A flag indicating whether to mask spectral indices with QA band.
        mask_using_water_mask: A flag indicating whether to mask spectral indices with DEM Water Mask.
        normalization_transform: A normalization transformation.
        stage: A literal indicating the stage to use. One of ["train", "val", "test", "predict"].

    Returns: An instance of AugmentationSequential.

    """
    # Band positions shared by every index transform. Missing bands fall back to -1.
    # NOTE(review): -1 indexes the last band if it is ever used — confirm index transforms guard against it.
    band_positions = {
        "index_swir": band_index_lookup.get("SWIR", -1),
        "index_nir": band_index_lookup.get("NIR", -1),
        "index_red": band_index_lookup.get("R", -1),
        "index_green": band_index_lookup.get("G", -1),
        "index_blue": band_index_lookup.get("B", -1),
        "index_dem": band_index_lookup.get("DEM", -1),
        "index_qa": band_index_lookup.get("QA", -1),
        "index_water_mask": band_index_lookup.get("DEMWM", -1),
    }

    pipeline = []
    for index_name in spectral_indices:
        # Indices ending in "WM" never get QA / water-mask masking applied.
        is_water_mask_index = index_name.endswith("WM")
        pipeline.append(
            SPECTRAL_INDEX_LOOKUP[index_name](
                **band_positions,
                mask_using_qa=False if is_water_mask_index else mask_using_qa,
                mask_using_water_mask=False if is_water_mask_index else mask_using_water_mask,
                fill_val=torch.nan,
            )
        )
    pipeline.append(RemoveNaNs(min_vals=band_stats.min, max_vals=band_stats.max))
    pipeline.append(normalization_transform)

    if stage != "train":
        data_keys = ["input"] if stage == "predict" else ["input", "mask"]
        return K.AugmentationSequential(*pipeline, data_keys=data_keys)

    return K.AugmentationSequential(
        *pipeline,
        K.RandomRotation(p=0.5, degrees=90),
        K.RandomHorizontalFlip(p=0.5),
        K.RandomVerticalFlip(p=0.5),
        data_keys=["input", "mask"],
    )