datamodule

The Kelp Forest DataModule.

kelp.nn.data.datamodule.KelpForestDataModule

Bases: LightningDataModule

A LightningDataModule that handles all data-related setup for the Kelp Forest Segmentation Task.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| dataset_stats | Dict[str, Dict[str, float]] | The per-band statistics dictionary. | required |
| train_images | Optional[List[Path]] | The list of training images. | None |
| train_masks | Optional[List[Path]] | The list of training masks. | None |
| val_images | Optional[List[Path]] | The list of validation images. | None |
| val_masks | Optional[List[Path]] | The list of validation masks. | None |
| test_images | Optional[List[Path]] | The list of test images. | None |
| test_masks | Optional[List[Path]] | The list of test masks. | None |
| predict_images | Optional[List[Path]] | The list of prediction images. | None |
| spectral_indices | Optional[List[str]] | The list of spectral indices to append to the input tensor. | None |
| bands | Optional[List[str]] | The list of band names to use. | None |
| missing_pixels_fill_value | float | The value to fill missing pixels with. | 0.0 |
| batch_size | int | The batch size. | 32 |
| num_workers | int | The number of workers to use for data loading. | 0 |
| sahi | bool | A flag indicating whether the SAHI dataset is used. | False |
| image_size | int | The size of the input image. | 352 |
| interpolation | Literal['nearest', 'nearest-exact', 'bilinear', 'bicubic'] | The interpolation to use when performing the resize operation. | 'nearest' |
| resize_strategy | Literal['pad', 'resize'] | The resize strategy to use. | 'pad' |
| normalization_strategy | Literal['min-max', 'quantile', 'per-sample-min-max', 'per-sample-quantile', 'z-score'] | The normalization strategy to use. | 'quantile' |
| mask_using_qa | bool | A flag indicating whether spectral index bands should be masked with the QA band. | False |
| mask_using_water_mask | bool | A flag indicating whether spectral index bands should be masked with the DEM water mask. | False |
| use_weighted_sampler | bool | A flag indicating whether to use a weighted sampler. | False |
| samples_per_epoch | int | The number of samples per epoch when using the weighted sampler. | 10240 |
| image_weights | Optional[List[float]] | The per-image weights for the weighted sampler. | None |
| **kwargs | Any | Extra keyword arguments. Unused. | {} |
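
Example (a minimal usage sketch): the file paths below are placeholders, and the statistic keys in `dataset_stats` (`mean`, `std`, `min`, `max`, `q01`, `q99`) are assumptions; in practice the dictionary is loaded from the project's precomputed band-statistics file.

```python
from pathlib import Path

from kelp.nn.data.datamodule import KelpForestDataModule

# Placeholder per-band statistics; real values come from a precomputed
# statistics file, and the exact key names may differ (assumption).
dataset_stats = {
    band: {"mean": 0.0, "std": 1.0, "min": 0.0, "max": 1.0, "q01": 0.0, "q99": 1.0}
    for band in ["SWIR", "NIR", "R", "G", "B", "QA", "DEM"]
}

dm = KelpForestDataModule(
    dataset_stats=dataset_stats,
    train_images=sorted(Path("data/train/images").glob("*.tif")),
    train_masks=sorted(Path("data/train/masks").glob("*.tif")),
    val_images=sorted(Path("data/val/images").glob("*.tif")),
    val_masks=sorted(Path("data/val/masks").glob("*.tif")),
    batch_size=32,
    num_workers=4,
)
print(dm.in_channels)  # number of input channels resolved from bands + indices
```
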
Source code in kelp/nn/data/datamodule.py
class KelpForestDataModule(pl.LightningDataModule):
    """
    A LightningDataModule that handles all data-related setup for the Kelp Forest Segmentation Task.

    Args:
        dataset_stats: The per-band statistics dictionary.
        train_images: The list of training images.
        train_masks: The list of training masks.
        val_images: The list of validation images.
        val_masks: The list of validation masks.
        test_images: The list of test images.
        test_masks: The list of test masks.
        predict_images: The list of prediction images.
        spectral_indices: The list of spectral indices to append to the input tensor.
        bands: The list of band names to use.
        missing_pixels_fill_value: The value to fill missing pixels with.
        batch_size: The batch size.
        num_workers: The number of workers to use for data loading.
        sahi: A flag indicating whether the SAHI dataset is used.
        image_size: The size of the input image.
        interpolation: The interpolation to use when performing the resize operation.
        resize_strategy: The resize strategy to use. One of ['pad', 'resize'].
        normalization_strategy: The normalization strategy to use.
        mask_using_qa: A flag indicating whether spectral index bands should be masked with QA band.
        mask_using_water_mask: A flag indicating whether spectral index bands should be masked with DEM Water Mask.
        use_weighted_sampler: A flag indicating whether to use weighted sampler.
        samples_per_epoch: The number of samples per epoch if using weighted sampler.
        image_weights: The weights per input image for weighted sampler if using weighted sampler.
        **kwargs: Extra keywords. Unused.
    """

    base_bands = [
        "SWIR",
        "NIR",
        "R",
        "G",
        "B",
        "QA",
        "DEM",
    ]

    def __init__(
        self,
        dataset_stats: Dict[str, Dict[str, float]],
        train_images: Optional[List[Path]] = None,
        train_masks: Optional[List[Path]] = None,
        val_images: Optional[List[Path]] = None,
        val_masks: Optional[List[Path]] = None,
        test_images: Optional[List[Path]] = None,
        test_masks: Optional[List[Path]] = None,
        predict_images: Optional[List[Path]] = None,
        spectral_indices: Optional[List[str]] = None,
        bands: Optional[List[str]] = None,
        missing_pixels_fill_value: float = 0.0,
        batch_size: int = 32,
        num_workers: int = 0,
        sahi: bool = False,
        image_size: int = 352,
        interpolation: Literal["nearest", "nearest-exact", "bilinear", "bicubic"] = "nearest",
        resize_strategy: Literal["pad", "resize"] = "pad",
        normalization_strategy: Literal[
            "min-max",
            "quantile",
            "per-sample-min-max",
            "per-sample-quantile",
            "z-score",
        ] = "quantile",
        mask_using_qa: bool = False,
        mask_using_water_mask: bool = False,
        use_weighted_sampler: bool = False,
        samples_per_epoch: int = 10240,
        image_weights: Optional[List[float]] = None,
        **kwargs: Any,
    ) -> None:
        super().__init__()  # type: ignore[no-untyped-call]
        bands = self._guard_against_invalid_bands_config(bands)
        spectral_indices = self._guard_against_invalid_spectral_indices_config(
            bands_to_use=bands,
            spectral_indices=spectral_indices,
            mask_using_qa=mask_using_qa,
            mask_using_water_mask=mask_using_water_mask,
        )
        self.dataset_stats = dataset_stats
        self.train_images = train_images or []
        self.train_masks = train_masks or []
        self.val_images = val_images or []
        self.val_masks = val_masks or []
        self.test_images = test_images or []
        self.test_masks = test_masks or []
        self.predict_images = predict_images or []
        self.spectral_indices = spectral_indices
        self.bands = bands
        self.band_order = [self.base_bands.index(band) for band in self.bands]
        self.bands_to_use = self.bands + self.spectral_indices
        self.band_index_lookup = {band: idx for idx, band in enumerate(self.bands_to_use)}
        self.missing_pixels_fill_value = missing_pixels_fill_value
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.sahi = sahi
        self.image_size = image_size
        self.interpolation = interpolation
        self.normalization_strategy = normalization_strategy
        self.mask_using_qa = mask_using_qa
        self.mask_using_water_mask = mask_using_water_mask
        self.use_weighted_sampler = use_weighted_sampler
        self.samples_per_epoch = samples_per_epoch
        self.image_weights = image_weights or [1.0 for _ in self.train_images]
        self.band_stats, self.in_channels = resolve_normalization_stats(
            dataset_stats=dataset_stats,
            bands_to_use=self.bands_to_use,
        )
        self.normalization_transform = resolve_normalization_transform(
            band_stats=self.band_stats,
            normalization_strategy=self.normalization_strategy,
        )
        self.train_augmentations = resolve_transforms(
            spectral_indices=self.spectral_indices,
            band_index_lookup=self.band_index_lookup,
            band_stats=self.band_stats,
            mask_using_qa=self.mask_using_qa,
            mask_using_water_mask=self.mask_using_water_mask,
            normalization_transform=self.normalization_transform,
            stage="train",
        )
        self.val_augmentations = resolve_transforms(
            spectral_indices=self.spectral_indices,
            band_index_lookup=self.band_index_lookup,
            band_stats=self.band_stats,
            mask_using_qa=self.mask_using_qa,
            mask_using_water_mask=self.mask_using_water_mask,
            normalization_transform=self.normalization_transform,
            stage="val",
        )
        self.test_augmentations = resolve_transforms(
            spectral_indices=self.spectral_indices,
            band_index_lookup=self.band_index_lookup,
            band_stats=self.band_stats,
            mask_using_qa=self.mask_using_qa,
            mask_using_water_mask=self.mask_using_water_mask,
            normalization_transform=self.normalization_transform,
            stage="test",
        )
        self.predict_augmentations = resolve_transforms(
            spectral_indices=self.spectral_indices,
            band_index_lookup=self.band_index_lookup,
            band_stats=self.band_stats,
            mask_using_qa=self.mask_using_qa,
            mask_using_water_mask=self.mask_using_water_mask,
            normalization_transform=self.normalization_transform,
            stage="predict",
        )
        self.image_resize_tf = resolve_resize_transform(
            image_or_mask="image",
            resize_strategy=resize_strategy,
            image_size=image_size,
            interpolation=interpolation,
        )
        self.mask_resize_tf = resolve_resize_transform(
            image_or_mask="mask",
            resize_strategy=resize_strategy,
            image_size=image_size,
            interpolation=interpolation,
        )

    def _guard_against_invalid_bands_config(self, bands: Optional[List[str]]) -> List[str]:
        if not bands:
            return self.base_bands

        if set(bands).issubset(set(self.base_bands)):
            return bands

        raise ValueError(f"{bands=} should be a subset of {self.base_bands=}")

    def _guard_against_invalid_spectral_indices_config(
        self,
        bands_to_use: List[str],
        spectral_indices: Optional[List[str]] = None,
        mask_using_qa: bool = False,
        mask_using_water_mask: bool = False,
    ) -> List[str]:
        if not spectral_indices:
            return []

        if "DEM" not in bands_to_use and "DEMWM" in spectral_indices:
            raise ValueError(
                f"You specified 'DEMWM' as one of the spectral indices but 'DEM' is not in {bands_to_use=}"
            )

        if "QA" not in bands_to_use and mask_using_qa:
            raise ValueError(f"You specified {mask_using_qa=} but 'QA' is not in {bands_to_use=}")

        if mask_using_water_mask and "DEMWM" not in spectral_indices:
            raise ValueError(f"You specified {mask_using_water_mask=} but 'DEMWM' is not in {spectral_indices=}")

        return spectral_indices

    def _build_dataset(self, images: List[Path], masks: Optional[List[Path]] = None) -> KelpForestSegmentationDataset:
        ds = KelpForestSegmentationDataset(
            image_fps=images,
            mask_fps=masks,
            transforms=self._common_transforms,
            band_order=self.band_order,
            fill_value=self.missing_pixels_fill_value,
        )
        return ds

    def _apply_transform(
        self,
        transforms: Callable[[Tensor, Tensor], Tuple[Tensor, Tensor]],
        batch: Dict[str, Tensor],
    ) -> Dict[str, Tensor]:
        x = batch["image"]
        # Kornia expects masks to be floats with a channel dimension
        y = batch["mask"].float().unsqueeze(1)
        x, y = transforms(x, y)
        batch["image"] = x
        # torchmetrics expects masks to be longs without a channel dimension
        batch["mask"] = y.squeeze(1).long()
        return batch

    def _apply_predict_transform(
        self,
        transforms: Callable[[Tensor], Tensor],
        batch: Dict[str, Tensor],
    ) -> Dict[str, Tensor]:
        x = batch["image"]
        x = transforms(x)
        batch["image"] = x
        return batch

    def _common_transforms(self, sample: Dict[str, Tensor]) -> Dict[str, Tensor]:
        sample["image"] = self.image_resize_tf(sample["image"])
        if "mask" in sample:
            sample["mask"] = self.mask_resize_tf(sample["mask"].unsqueeze(0)).squeeze()
        return sample

    def on_after_batch_transfer(self, batch: Dict[str, Any], batch_idx: int) -> Dict[str, Any]:
        """Apply batch augmentations after batch is transferred to the device.

        Args:
            batch: mini-batch of data
            batch_idx: batch index

        Returns:
            augmented mini-batch
        """
        if (
            hasattr(self, "trainer")
            and self.trainer is not None
            and hasattr(self.trainer, "training")
            and self.trainer.training
        ):
            batch = self._apply_transform(self.train_augmentations, batch)
        elif (
            hasattr(self, "trainer")
            and self.trainer is not None
            and hasattr(self.trainer, "predicting")
            and self.trainer.predicting
        ):
            batch = self._apply_predict_transform(self.predict_augmentations, batch)
        else:
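            # validation, test and sanity-check batches all fall through to val_augmentations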
            batch = self._apply_transform(self.val_augmentations, batch)

        return batch

    def setup(self, stage: Optional[str] = None) -> None:
        """Initialize the main ``Dataset`` objects.

        This method is called once per GPU per run.

        Args:
            stage: stage to set up
        """
        if self.train_images:
            self.train_dataset = self._build_dataset(self.train_images, self.train_masks)
        if self.val_images:
            self.val_dataset = self._build_dataset(self.val_images, self.val_masks)
        if self.test_images:
            self.test_dataset = self._build_dataset(self.test_images, self.test_masks)
        if self.predict_images:
            self.predict_dataset = self._build_dataset(self.predict_images)

    def train_dataloader(self) -> DataLoader[Any]:
        """Return a DataLoader for training.

        Returns:
            training data loader
        """
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            sampler=WeightedRandomSampler(
                weights=self.image_weights,
                num_samples=self.samples_per_epoch,
            )
            if self.use_weighted_sampler
            else None,
            shuffle=not self.use_weighted_sampler,
        )

    def val_dataloader(self) -> DataLoader[Any]:
        """Return a DataLoader for validation.

        Returns:
            validation data loader
        """
        return DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=False,
        )

    def test_dataloader(self) -> DataLoader[Any]:
        """Return a DataLoader for testing.

        Returns:
            testing data loader
        """
        return DataLoader(
            self.test_dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=False,
        )

    def predict_dataloader(self) -> DataLoader[Any]:
        """Return a DataLoader for prediction.

        Returns:
            prediction data loader
        """
        return DataLoader(
            self.predict_dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=False,
        )

    def plot_sample(self, *args: Any, **kwargs: Any) -> plt.Figure:
        """Run :meth:`kelp.nn.data.dataset.KelpForestSegmentationDataset.plot_sample`."""
        return self.val_dataset.plot_sample(*args, **kwargs)

    def plot_batch(self, *args: Any, **kwargs: Any) -> FigureGrids:
        """Run :meth:`kelp.nn.data.dataset.KelpForestSegmentationDataset.plot_batch`."""
        return self.val_dataset.plot_batch(*args, **kwargs)

    @classmethod
    def resolve_file_paths(
        cls,
        data_dir: Path,
        metadata: pd.DataFrame,
        cv_split: int,
        split: str,
        sahi: bool = False,
    ) -> Tuple[List[Path], List[Path]]:
        """
        Resolves file paths using specified metadata dataframe.

        Args:
            data_dir: The data directory.
            metadata: The metadata dataframe.
            cv_split: The CV fold to use.
            split: The split to use (train, val, test).
            sahi: A flag indicating whether SAHI dataset is used.

        Returns:
            A tuple with input image paths and target (mask) image paths.
        """
        split_data = metadata[metadata[f"split_{cv_split}"] == split]
        img_folder = consts.data.TRAIN if split in [consts.data.TRAIN, consts.data.VAL] else consts.data.TEST
        image_paths = sorted(
            split_data.apply(
                lambda row: data_dir
                / img_folder
                / "images"
                / (
                    f"{row['tile_id']}_satellite_{row['j']}_{row['i']}.tif"
                    if sahi
                    else f"{row['tile_id']}_satellite.tif"
                ),
                axis=1,
            ).tolist()
        )
        mask_paths = sorted(
            split_data.apply(
                lambda row: data_dir
                / img_folder
                / "masks"
                / (f"{row['tile_id']}_kelp_{row['j']}_{row['i']}.tif" if sahi else f"{row['tile_id']}_kelp.tif"),
                axis=1,
            ).tolist()
        )
        return image_paths, mask_paths

    @classmethod
    def _calculate_image_weights(
        cls,
        df: pd.DataFrame,
        has_kelp_importance_factor: float = 1.0,
        kelp_pixels_pct_importance_factor: float = 1.0,
        qa_ok_importance_factor: float = 1.0,
        qa_corrupted_pixels_pct_importance_factor: float = 1.0,
        almost_all_water_importance_factor: float = -1.0,
        dem_nan_pixels_pct_importance_factor: float = -1.0,
        dem_zero_pixels_pct_importance_factor: float = -1.0,
        sahi: bool = False,
    ) -> pd.DataFrame:
        def resolve_weight(row: pd.Series) -> float:
            if row["original_split"] == "test":
                return 0.0

            has_kelp = int(row["kelp_pxls"] > 0) if sahi else int(row["has_kelp"])
            kelp_pixels_pct = row["kelp_pct"] if sahi else row["kelp_pixels_pct"]
            qa_ok = int(row["qa_ok"])
            water_pixels_pct = row["water_pixels_pct"]
            qa_corrupted_pixels_pct = row["qa_corrupted_pixels_pct"]
            dem_nan_pixels_pct = row["dem_nan_pixels_pct"]
            dem_zero_pixels_pct = row["dem_zero_pixels_pct"]

            weight = (
                has_kelp_importance_factor * has_kelp
                + kelp_pixels_pct_importance_factor * (1 - kelp_pixels_pct)
                + qa_ok_importance_factor * qa_ok
                + qa_corrupted_pixels_pct_importance_factor * (1 - qa_corrupted_pixels_pct)
                + almost_all_water_importance_factor * (1 - water_pixels_pct)
                + dem_nan_pixels_pct_importance_factor * (1 - dem_nan_pixels_pct)
                + dem_zero_pixels_pct_importance_factor * (1 - dem_zero_pixels_pct)
            )
            return weight  # type: ignore[no-any-return]

        df["weight"] = df.apply(resolve_weight, axis=1)
        min_val = df["weight"].min()
        max_val = df["weight"].max()
        df["weight"] = (df["weight"] - min_val) / (max_val - min_val + consts.data.EPS)
        return df

    @classmethod
    def _resolve_image_weights(cls, df: pd.DataFrame, image_paths: List[Path]) -> List[float]:
        tile_ids = [fp.stem.split("_")[0] for fp in image_paths]
        weights = df[df["tile_id"].isin(tile_ids)].sort_values("tile_id")["weight"].tolist()
        return weights  # type: ignore[no-any-return]

    @classmethod
    def from_metadata_file(
        cls,
        data_dir: Path,
        metadata_fp: Path,
        dataset_stats: Dict[str, Dict[str, float]],
        cv_split: int,
        has_kelp_importance_factor: float = 1.0,
        kelp_pixels_pct_importance_factor: float = 1.0,
        qa_ok_importance_factor: float = 1.0,
        almost_all_water_importance_factor: float = 1.0,
        qa_corrupted_pixels_pct_importance_factor: float = -1.0,
        dem_nan_pixels_pct_importance_factor: float = -1.0,
        dem_zero_pixels_pct_importance_factor: float = -1.0,
        sahi: bool = False,
        **kwargs: Any,
    ) -> KelpForestDataModule:
        """
        Factory method to create the KelpForestDataModule based on metadata file.

        Args:
            data_dir: The path to the data directory.
            metadata_fp: The path to the metadata file.
            dataset_stats: The per-band dataset statistics.
            cv_split: The CV fold number to use.
            has_kelp_importance_factor: The importance factor for the has_kelp flag.
            kelp_pixels_pct_importance_factor: The importance factor for the kelp_pixels_pct value.
            qa_ok_importance_factor: The importance factor for the qa_ok flag.
            almost_all_water_importance_factor: The importance factor for the almost_all_water flag.
            qa_corrupted_pixels_pct_importance_factor: The importance factor for the qa_corrupted_pixels_pct value.
            dem_nan_pixels_pct_importance_factor: The importance factor for the dem_nan_pixels_pct value.
            dem_zero_pixels_pct_importance_factor: The importance factor for the dem_zero_pixels_pct value.
            sahi: A flag indicating whether SAHI dataset is used.
            **kwargs: Other keyword arguments passed to the KelpForestDataModule constructor.

        Returns:
            An instance of KelpForestDataModule.
        """
        metadata = cls._calculate_image_weights(
            df=pd.read_parquet(metadata_fp),
            has_kelp_importance_factor=has_kelp_importance_factor,
            kelp_pixels_pct_importance_factor=kelp_pixels_pct_importance_factor,
            qa_ok_importance_factor=qa_ok_importance_factor,
            qa_corrupted_pixels_pct_importance_factor=qa_corrupted_pixels_pct_importance_factor,
            almost_all_water_importance_factor=almost_all_water_importance_factor,
            dem_nan_pixels_pct_importance_factor=dem_nan_pixels_pct_importance_factor,
            dem_zero_pixels_pct_importance_factor=dem_zero_pixels_pct_importance_factor,
            sahi=sahi,
        )
        train_images, train_masks = cls.resolve_file_paths(
            data_dir=data_dir, metadata=metadata, cv_split=cv_split, split=consts.data.TRAIN, sahi=sahi
        )
        val_images, val_masks = cls.resolve_file_paths(
            data_dir=data_dir, metadata=metadata, cv_split=cv_split, split=consts.data.VAL, sahi=sahi
        )
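        # NOTE: test paths are resolved from the validation split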
        test_images, test_masks = cls.resolve_file_paths(
            data_dir=data_dir, metadata=metadata, cv_split=cv_split, split=consts.data.VAL, sahi=sahi
        )
        image_weights = cls._resolve_image_weights(df=metadata, image_paths=train_images)
        return cls(
            train_images=train_images,
            train_masks=train_masks,
            val_images=val_images,
            val_masks=val_masks,
            test_images=test_images,
            test_masks=test_masks,
            predict_images=None,
            image_weights=image_weights,
            dataset_stats=dataset_stats,
            **kwargs,
        )

    @classmethod
    def from_folders(
        cls,
        dataset_stats: Dict[str, Dict[str, float]],
        train_data_folder: Optional[Path] = None,
        val_data_folder: Optional[Path] = None,
        test_data_folder: Optional[Path] = None,
        predict_data_folder: Optional[Path] = None,
        **kwargs: Any,
    ) -> KelpForestDataModule:
        """
        Factory method to create the KelpForestDataModule based on folder paths.

        Args:
            dataset_stats: The per-band dataset statistics.
            train_data_folder: The path to the training data folder.
            val_data_folder: The path to the validation data folder.
            test_data_folder: The path to the test data folder.
            predict_data_folder: The path to the prediction data folder.
            **kwargs: Other keyword arguments passed to the KelpForestDataModule constructor.

        Returns:
            An instance of KelpForestDataModule.
        """
        return cls(
            train_images=sorted(list(train_data_folder.glob("images/*.tif")))
            if train_data_folder and train_data_folder.exists()
            else None,
            train_masks=sorted(list(train_data_folder.glob("masks/*.tif")))
            if train_data_folder and train_data_folder.exists()
            else None,
            val_images=sorted(list(val_data_folder.glob("images/*.tif")))
            if val_data_folder and val_data_folder.exists()
            else None,
            val_masks=sorted(list(val_data_folder.glob("masks/*.tif")))
            if val_data_folder and val_data_folder.exists()
            else None,
            test_images=sorted(list(test_data_folder.glob("images/*.tif")))
            if test_data_folder and test_data_folder.exists()
            else None,
            test_masks=sorted(list(test_data_folder.glob("masks/*.tif")))
            if test_data_folder and test_data_folder.exists()
            else None,
            predict_images=sorted(list(predict_data_folder.rglob("*.tif")))
            if predict_data_folder and predict_data_folder.exists()
            else None,
            dataset_stats=dataset_stats,
            **kwargs,
        )

    @classmethod
    def from_file_paths(
        cls,
        dataset_stats: Dict[str, Dict[str, float]],
        train_images: Optional[List[Path]] = None,
        train_masks: Optional[List[Path]] = None,
        val_images: Optional[List[Path]] = None,
        val_masks: Optional[List[Path]] = None,
        test_images: Optional[List[Path]] = None,
        test_masks: Optional[List[Path]] = None,
        predict_images: Optional[List[Path]] = None,
        spectral_indices: Optional[List[str]] = None,
        batch_size: int = 32,
        image_size: int = 352,
        num_workers: int = 0,
        **kwargs: Any,
    ) -> KelpForestDataModule:
        """
        Factory method to create the KelpForestDataModule based on file paths.

        Args:
            dataset_stats: The per-band dataset statistics.
            train_images: The list of training images.
            train_masks: The list of training masks.
            val_images: The list of validation images.
            val_masks: The list of validation masks.
            test_images: The list of test images.
            test_masks: The list of test masks.
            predict_images: The list of prediction images.
            spectral_indices: The list of spectral indices to append to the input tensor.
            batch_size: The batch size.
            num_workers: The number of workers to use for data loading.
            image_size: The size of the input image.
            **kwargs: Other keyword arguments passed to the KelpForestDataModule constructor.

        Returns:
            An instance of KelpForestDataModule.
        """
        return cls(
            train_images=train_images,
            train_masks=train_masks,
            val_images=val_images,
            val_masks=val_masks,
            test_images=test_images,
            test_masks=test_masks,
            predict_images=predict_images,
            dataset_stats=dataset_stats,
            spectral_indices=spectral_indices,
            batch_size=batch_size,
            image_size=image_size,
            num_workers=num_workers,
            **kwargs,
        )

kelp.nn.data.datamodule.KelpForestDataModule.from_file_paths classmethod

Factory method to create the KelpForestDataModule based on file paths.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| dataset_stats | Dict[str, Dict[str, float]] | The per-band dataset statistics. | required |
| train_images | Optional[List[Path]] | The list of training images. | None |
| train_masks | Optional[List[Path]] | The list of training masks. | None |
| val_images | Optional[List[Path]] | The list of validation images. | None |
| val_masks | Optional[List[Path]] | The list of validation masks. | None |
| test_images | Optional[List[Path]] | The list of test images. | None |
| test_masks | Optional[List[Path]] | The list of test masks. | None |
| predict_images | Optional[List[Path]] | The list of prediction images. | None |
| spectral_indices | Optional[List[str]] | The list of spectral indices to append to the input tensor. | None |
| batch_size | int | The batch size. | 32 |
| num_workers | int | The number of workers to use for data loading. | 0 |
| image_size | int | The size of the input image. | 352 |
| **kwargs | Any | Other keyword arguments passed to the KelpForestDataModule constructor. | {} |

Returns:

| Type | Description |
|---|---|
| KelpForestDataModule | An instance of KelpForestDataModule. |

Source code in kelp/nn/data/datamodule.py
@classmethod
def from_file_paths(
    cls,
    dataset_stats: Dict[str, Dict[str, float]],
    train_images: Optional[List[Path]] = None,
    train_masks: Optional[List[Path]] = None,
    val_images: Optional[List[Path]] = None,
    val_masks: Optional[List[Path]] = None,
    test_images: Optional[List[Path]] = None,
    test_masks: Optional[List[Path]] = None,
    predict_images: Optional[List[Path]] = None,
    spectral_indices: Optional[List[str]] = None,
    batch_size: int = 32,
    image_size: int = 352,
    num_workers: int = 0,
    **kwargs: Any,
) -> KelpForestDataModule:
    """
    Factory method to create the KelpForestDataModule based on file paths.

    Args:
        dataset_stats: The per-band dataset statistics.
        train_images: The list of training images.
        train_masks: The list of training masks.
        val_images: The list of validation images.
        val_masks: The list of validation masks.
        test_images: The list of test images.
        test_masks: The list of test masks.
        predict_images: The list of prediction images.
        spectral_indices: The list of spectral indices to append to the input tensor.
        batch_size: The batch size.
        num_workers: The number of workers to use for data loading.
        image_size: The size of the input image.
        **kwargs: Other keyword arguments passed to the KelpForestDataModule constructor.

    Returns:
        An instance of KelpForestDataModule.
    """
    return cls(
        train_images=train_images,
        train_masks=train_masks,
        val_images=val_images,
        val_masks=val_masks,
        test_images=test_images,
        test_masks=test_masks,
        predict_images=predict_images,
        dataset_stats=dataset_stats,
        spectral_indices=spectral_indices,
        batch_size=batch_size,
        image_size=image_size,
        num_workers=num_workers,
        **kwargs,
    )
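
Example (a sketch; the paths are placeholders and `dataset_stats` is the per-band statistics dictionary described for the class constructor):

```python
from pathlib import Path

dm = KelpForestDataModule.from_file_paths(
    dataset_stats=dataset_stats,  # per-band statistics dictionary
    train_images=sorted(Path("data/train/images").glob("*.tif")),
    train_masks=sorted(Path("data/train/masks").glob("*.tif")),
    val_images=sorted(Path("data/val/images").glob("*.tif")),
    val_masks=sorted(Path("data/val/masks").glob("*.tif")),
    batch_size=16,
    num_workers=4,
)
```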

kelp.nn.data.datamodule.KelpForestDataModule.from_folders classmethod

Factory method to create the KelpForestDataModule based on folder paths.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| dataset_stats | Dict[str, Dict[str, float]] | The per-band dataset statistics. | required |
| train_data_folder | Optional[Path] | The path to the training data folder. | None |
| val_data_folder | Optional[Path] | The path to the validation data folder. | None |
| test_data_folder | Optional[Path] | The path to the test data folder. | None |
| predict_data_folder | Optional[Path] | The path to the prediction data folder. | None |
| **kwargs | Any | Other keyword arguments passed to the KelpForestDataModule constructor. | {} |

Returns:

| Type | Description |
|---|---|
| KelpForestDataModule | An instance of KelpForestDataModule. |

Source code in kelp/nn/data/datamodule.py
@classmethod
def from_folders(
    cls,
    dataset_stats: Dict[str, Dict[str, float]],
    train_data_folder: Optional[Path] = None,
    val_data_folder: Optional[Path] = None,
    test_data_folder: Optional[Path] = None,
    predict_data_folder: Optional[Path] = None,
    **kwargs: Any,
) -> KelpForestDataModule:
    """
    Factory method to create the KelpForestDataModule based on folder paths.

    Args:
        dataset_stats: The per-band dataset statistics.
        train_data_folder: The path to the training data folder.
        val_data_folder: The path to the validation data folder.
        test_data_folder: The path to the test data folder.
        predict_data_folder: The path to the prediction data folder.
        **kwargs: Other keyword arguments passed to the KelpForestDataModule constructor.

    Returns:
        An instance of KelpForestDataModule.
    """
    return cls(
        train_images=sorted(list(train_data_folder.glob("images/*.tif")))
        if train_data_folder and train_data_folder.exists()
        else None,
        train_masks=sorted(list(train_data_folder.glob("masks/*.tif")))
        if train_data_folder and train_data_folder.exists()
        else None,
        val_images=sorted(list(val_data_folder.glob("images/*.tif")))
        if val_data_folder and val_data_folder.exists()
        else None,
        val_masks=sorted(list(val_data_folder.glob("masks/*.tif")))
        if val_data_folder and val_data_folder.exists()
        else None,
        test_images=sorted(list(test_data_folder.glob("images/*.tif")))
        if test_data_folder and test_data_folder.exists()
        else None,
        test_masks=sorted(list(test_data_folder.glob("masks/*.tif")))
        if test_data_folder and test_data_folder.exists()
        else None,
        predict_images=sorted(list(predict_data_folder.rglob("*.tif")))
        if predict_data_folder and predict_data_folder.exists()
        else None,
        dataset_stats=dataset_stats,
        **kwargs,
    )
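
Example (a sketch, assuming the layout the globs above imply: each split folder contains `images/` and `masks/` subfolders with `*.tif` files; all paths are placeholders):

```python
from pathlib import Path

dm = KelpForestDataModule.from_folders(
    dataset_stats=dataset_stats,               # per-band statistics dictionary
    train_data_folder=Path("data/train"),      # expects images/ and masks/ inside
    val_data_folder=Path("data/val"),
    test_data_folder=Path("data/test"),
    predict_data_folder=Path("data/predict"),  # searched recursively for *.tif
    batch_size=32,                             # forwarded via **kwargs
)
```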

kelp.nn.data.datamodule.KelpForestDataModule.from_metadata_file classmethod

Factory method to create the KelpForestDataModule based on metadata file.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| data_dir | Path | The path to the data directory. | required |
| metadata_fp | Path | The path to the metadata file. | required |
| dataset_stats | Dict[str, Dict[str, float]] | The per-band dataset statistics. | required |
| cv_split | int | The CV fold number to use. | required |
| has_kelp_importance_factor | float | The importance factor for the has_kelp flag. | 1.0 |
| kelp_pixels_pct_importance_factor | float | The importance factor for the kelp_pixels_pct value. | 1.0 |
| qa_ok_importance_factor | float | The importance factor for the qa_ok flag. | 1.0 |
| almost_all_water_importance_factor | float | The importance factor for the almost_all_water flag. | 1.0 |
| qa_corrupted_pixels_pct_importance_factor | float | The importance factor for the qa_corrupted_pixels_pct value. | -1.0 |
| dem_nan_pixels_pct_importance_factor | float | The importance factor for the dem_nan_pixels_pct value. | -1.0 |
| dem_zero_pixels_pct_importance_factor | float | The importance factor for the dem_zero_pixels_pct value. | -1.0 |
| sahi | bool | A flag indicating whether the SAHI dataset is used. | False |
| **kwargs | Any | Other keyword arguments passed to the KelpForestDataModule constructor. | {} |

Returns:

| Type | Description |
|---|---|
| KelpForestDataModule | An instance of KelpForestDataModule. |

Source code in kelp/nn/data/datamodule.py
@classmethod
def from_metadata_file(
    cls,
    data_dir: Path,
    metadata_fp: Path,
    dataset_stats: Dict[str, Dict[str, float]],
    cv_split: int,
    has_kelp_importance_factor: float = 1.0,
    kelp_pixels_pct_importance_factor: float = 1.0,
    qa_ok_importance_factor: float = 1.0,
    almost_all_water_importance_factor: float = 1.0,
    qa_corrupted_pixels_pct_importance_factor: float = -1.0,
    dem_nan_pixels_pct_importance_factor: float = -1.0,
    dem_zero_pixels_pct_importance_factor: float = -1.0,
    sahi: bool = False,
    **kwargs: Any,
) -> KelpForestDataModule:
    """
    Factory method to create the KelpForestDataModule based on metadata file.

    Args:
        data_dir: The path to the data directory.
        metadata_fp: The path to the metadata file.
        dataset_stats: The per-band dataset statistics.
        cv_split: The CV fold number to use.
        has_kelp_importance_factor: The importance factor for the has_kelp flag.
        kelp_pixels_pct_importance_factor: The importance factor for the kelp_pixels_pct value.
        qa_ok_importance_factor: The importance factor for the qa_ok flag.
        almost_all_water_importance_factor: The importance factor for the almost_all_water flag.
        qa_corrupted_pixels_pct_importance_factor: The importance factor for the qa_corrupted_pixels_pct value.
        dem_nan_pixels_pct_importance_factor: The importance factor for the dem_nan_pixels_pct value.
        dem_zero_pixels_pct_importance_factor: The importance factor for the dem_zero_pixels_pct value.
        sahi: A flag indicating whether SAHI dataset is used.
        **kwargs: Other keyword arguments passed to the KelpForestDataModule constructor.

    Returns:
        An instance of KelpForestDataModule.
    """
    metadata = cls._calculate_image_weights(
        df=pd.read_parquet(metadata_fp),
        has_kelp_importance_factor=has_kelp_importance_factor,
        kelp_pixels_pct_importance_factor=kelp_pixels_pct_importance_factor,
        qa_ok_importance_factor=qa_ok_importance_factor,
        qa_corrupted_pixels_pct_importance_factor=qa_corrupted_pixels_pct_importance_factor,
        almost_all_water_importance_factor=almost_all_water_importance_factor,
        dem_nan_pixels_pct_importance_factor=dem_nan_pixels_pct_importance_factor,
        dem_zero_pixels_pct_importance_factor=dem_zero_pixels_pct_importance_factor,
        sahi=sahi,
    )
    train_images, train_masks = cls.resolve_file_paths(
        data_dir=data_dir, metadata=metadata, cv_split=cv_split, split=consts.data.TRAIN, sahi=sahi
    )
    val_images, val_masks = cls.resolve_file_paths(
        data_dir=data_dir, metadata=metadata, cv_split=cv_split, split=consts.data.VAL, sahi=sahi
    )
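    # NOTE: test paths are resolved from the validation split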
    test_images, test_masks = cls.resolve_file_paths(
        data_dir=data_dir, metadata=metadata, cv_split=cv_split, split=consts.data.VAL, sahi=sahi
    )
    image_weights = cls._resolve_image_weights(df=metadata, image_paths=train_images)
    return cls(
        train_images=train_images,
        train_masks=train_masks,
        val_images=val_images,
        val_masks=val_masks,
        test_images=test_images,
        test_masks=test_masks,
        predict_images=None,
        image_weights=image_weights,
        dataset_stats=dataset_stats,
        **kwargs,
    )
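
Example (a sketch; the parquet path is a placeholder, and the metadata frame is assumed to contain the `split_{cv_split}` column plus the per-tile statistics columns used by the weight calculation):

```python
from pathlib import Path

dm = KelpForestDataModule.from_metadata_file(
    data_dir=Path("data"),
    metadata_fp=Path("data/metadata.parquet"),
    dataset_stats=dataset_stats,     # per-band statistics dictionary
    cv_split=0,
    has_kelp_importance_factor=2.0,  # up-weight tiles that contain kelp
    use_weighted_sampler=True,       # forwarded to the constructor via **kwargs
)
```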

kelp.nn.data.datamodule.KelpForestDataModule.on_after_batch_transfer

Apply batch augmentations after batch is transferred to the device.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| batch | Dict[str, Any] | mini-batch of data | required |
| batch_idx | int | batch index | required |

Returns:

| Type | Description |
|---|---|
| Dict[str, Any] | augmented mini-batch |

Source code in kelp/nn/data/datamodule.py
def on_after_batch_transfer(self, batch: Dict[str, Any], batch_idx: int) -> Dict[str, Any]:
    """Apply batch augmentations after batch is transferred to the device.

    Args:
        batch: mini-batch of data
        batch_idx: batch index

    Returns:
        augmented mini-batch
    """
    if (
        hasattr(self, "trainer")
        and self.trainer is not None
        and hasattr(self.trainer, "training")
        and self.trainer.training
    ):
        batch = self._apply_transform(self.train_augmentations, batch)
    elif (
        hasattr(self, "trainer")
        and self.trainer is not None
        and hasattr(self.trainer, "predicting")
        and self.trainer.predicting
    ):
        batch = self._apply_predict_transform(self.predict_augmentations, batch)
    else:
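        # validation, test and sanity-check batches all fall through to val_augmentations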
        batch = self._apply_transform(self.val_augmentations, batch)

    return batch
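
This hook is invoked by Lightning itself rather than by user code. A rough sketch of when each augmentation pipeline fires (`model` stands for any compatible `LightningModule` and `dm` for a configured `KelpForestDataModule`; both are placeholders):

```python
import pytorch_lightning as pl

trainer = pl.Trainer(max_epochs=1)
trainer.fit(model, datamodule=dm)       # training batches -> train_augmentations
trainer.validate(model, datamodule=dm)  # validation/test batches -> val_augmentations (the else branch)
trainer.predict(model, datamodule=dm)   # prediction batches -> predict_augmentations
```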

kelp.nn.data.datamodule.KelpForestDataModule.plot_batch

Run kelp.nn.data.dataset.KelpForestSegmentationDataset.plot_batch.

Source code in kelp/nn/data/datamodule.py
def plot_batch(self, *args: Any, **kwargs: Any) -> FigureGrids:
    """Run :meth:`kelp.nn.data.dataset.KelpForestSegmentationDataset.plot_batch`."""
    return self.val_dataset.plot_batch(*args, **kwargs)

kelp.nn.data.datamodule.KelpForestDataModule.plot_sample

Run kelp.nn.data.dataset.KelpForestSegmentationDataset.plot_sample.

Source code in kelp/nn/data/datamodule.py
def plot_sample(self, *args: Any, **kwargs: Any) -> plt.Figure:
    """Run :meth:`kelp.nn.data.dataset.KelpForestSegmentationDataset.plot_sample`."""
    return self.val_dataset.plot_sample(*args, **kwargs)

kelp.nn.data.datamodule.KelpForestDataModule.predict_dataloader

Return a DataLoader for prediction.

Returns:

| Type | Description |
|---|---|
| DataLoader[Any] | prediction data loader |

Source code in kelp/nn/data/datamodule.py
def predict_dataloader(self) -> DataLoader[Any]:
    """Return a DataLoader for prediction.

    Returns:
        prediction data loader
    """
    return DataLoader(
        self.predict_dataset,
        batch_size=self.batch_size,
        num_workers=self.num_workers,
        shuffle=False,
    )
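
A sketch of iterating over prediction batches (assumes `dm` was built with `predict_images`; since no masks are passed to the prediction dataset, batches presumably carry only an `image` key):

```python
dm.setup()
for batch in dm.predict_dataloader():
    x = batch["image"]  # (batch_size, in_channels, image_size, image_size)
```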

kelp.nn.data.datamodule.KelpForestDataModule.resolve_file_paths classmethod

Resolves file paths using specified metadata dataframe.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| data_dir | Path | The data directory. | required |
| metadata | DataFrame | The metadata dataframe. | required |
| cv_split | int | The CV fold to use. | required |
| split | str | The split to use (train, val, test). | required |
| sahi | bool | A flag indicating whether the SAHI dataset is used. | False |

Returns:

| Type | Description |
|---|---|
| Tuple[List[Path], List[Path]] | A tuple with input image paths and target (mask) image paths. |

Source code in kelp/nn/data/datamodule.py
@classmethod
def resolve_file_paths(
    cls,
    data_dir: Path,
    metadata: pd.DataFrame,
    cv_split: int,
    split: str,
    sahi: bool = False,
) -> Tuple[List[Path], List[Path]]:
    """
    Resolves file paths using specified metadata dataframe.

    Args:
        data_dir: The data directory.
        metadata: The metadata dataframe.
        cv_split: The CV fold to use.
        split: The split to use (train, val, test).
        sahi: A flag indicating whether SAHI dataset is used.

    Returns:
        A tuple with input image paths and target (mask) image paths.
    """
    split_data = metadata[metadata[f"split_{cv_split}"] == split]
    img_folder = consts.data.TRAIN if split in [consts.data.TRAIN, consts.data.VAL] else consts.data.TEST
    image_paths = sorted(
        split_data.apply(
            lambda row: data_dir
            / img_folder
            / "images"
            / (
                f"{row['tile_id']}_satellite_{row['j']}_{row['i']}.tif"
                if sahi
                else f"{row['tile_id']}_satellite.tif"
            ),
            axis=1,
        ).tolist()
    )
    mask_paths = sorted(
        split_data.apply(
            lambda row: data_dir
            / img_folder
            / "masks"
            / (f"{row['tile_id']}_kelp_{row['j']}_{row['i']}.tif" if sahi else f"{row['tile_id']}_kelp.tif"),
            axis=1,
        ).tolist()
    )
    return image_paths, mask_paths
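
Example (a sketch; the parquet path is a placeholder, and `"train"` is assumed to match `consts.data.TRAIN`):

```python
from pathlib import Path

import pandas as pd

from kelp.nn.data.datamodule import KelpForestDataModule

metadata = pd.read_parquet("data/metadata.parquet")
images, masks = KelpForestDataModule.resolve_file_paths(
    data_dir=Path("data"),
    metadata=metadata,
    cv_split=0,
    split="train",  # assumed to equal consts.data.TRAIN
)
# images -> [Path("data/train/images/<tile_id>_satellite.tif"), ...]
# masks  -> [Path("data/train/masks/<tile_id>_kelp.tif"), ...]
```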

kelp.nn.data.datamodule.KelpForestDataModule.setup

Initialize the main Dataset objects.

This method is called once per GPU per run.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| stage | Optional[str] | stage to set up | None |

Source code in kelp/nn/data/datamodule.py
def setup(self, stage: Optional[str] = None) -> None:
    """Initialize the main ``Dataset`` objects.

    This method is called once per GPU per run.

    Args:
        stage: stage to set up
    """
    if self.train_images:
        self.train_dataset = self._build_dataset(self.train_images, self.train_masks)
    if self.val_images:
        self.val_dataset = self._build_dataset(self.val_images, self.val_masks)
    if self.test_images:
        self.test_dataset = self._build_dataset(self.test_images, self.test_masks)
    if self.predict_images:
        self.predict_dataset = self._build_dataset(self.predict_images)
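
A sketch of the manual flow (Lightning's `Trainer` normally calls `setup` for you; `dm` is a configured datamodule):

```python
dm.setup()
batch = next(iter(dm.train_dataloader()))
# Images are padded/resized to image_size by the per-sample transforms;
# the Kornia batch augmentations run later, in on_after_batch_transfer.
print(batch["image"].shape)  # (batch_size, num_bands, image_size, image_size)
print(batch["mask"].shape)   # (batch_size, image_size, image_size)
```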

kelp.nn.data.datamodule.KelpForestDataModule.test_dataloader

Return a DataLoader for testing.

Returns:

| Type | Description |
|---|---|
| DataLoader[Any] | testing data loader |

Source code in kelp/nn/data/datamodule.py
def test_dataloader(self) -> DataLoader[Any]:
    """Return a DataLoader for testing.

    Returns:
        testing data loader
    """
    return DataLoader(
        self.test_dataset,
        batch_size=self.batch_size,
        num_workers=self.num_workers,
        shuffle=False,
    )

kelp.nn.data.datamodule.KelpForestDataModule.train_dataloader

Return a DataLoader for training.

Returns:

| Type | Description |
|---|---|
| DataLoader[Any] | training data loader |

Source code in kelp/nn/data/datamodule.py
def train_dataloader(self) -> DataLoader[Any]:
    """Return a DataLoader for training.

    Returns:
        training data loader
    """
    return DataLoader(
        self.train_dataset,
        batch_size=self.batch_size,
        num_workers=self.num_workers,
        sampler=WeightedRandomSampler(
            weights=self.image_weights,
            num_samples=self.samples_per_epoch,
        )
        if self.use_weighted_sampler
        else None,
        shuffle=not self.use_weighted_sampler,
    )

kelp.nn.data.datamodule.KelpForestDataModule.val_dataloader

Return a DataLoader for validation.

Returns:

| Type | Description |
|---|---|
| DataLoader[Any] | validation data loader |

Source code in kelp/nn/data/datamodule.py
def val_dataloader(self) -> DataLoader[Any]:
    """Return a DataLoader for validation.

    Returns:
        validation data loader
    """
    return DataLoader(
        self.val_dataset,
        batch_size=self.batch_size,
        num_workers=self.num_workers,
        shuffle=False,
    )