predict

Single model prediction logic.

kelp.nn.inference.predict.PredictConfig

Bases: ConfigBase

The prediction config

Source code in kelp/nn/inference/predict.py
class PredictConfig(ConfigBase):
    """The prediction config"""

    model_config = ConfigDict(protected_namespaces=())

    data_dir: Path
    dataset_stats_dir: Path
    original_training_config_fp: Path
    model_checkpoint: Path
    use_checkpoint: Literal["best", "latest"] = "best"
    run_dir: Path
    output_dir: Path
    tta: bool = False
    soft_labels: bool = False
    tta_merge_mode: str = "max"
    decision_threshold: Optional[float] = None
    precision: Optional[
        Literal[
            "16-true",
            "16-mixed",
            "bf16-true",
            "bf16-mixed",
            "32-true",
        ]
    ] = None
    sahi_tile_size: int = 128
    sahi_overlap: int = 64

    @model_validator(mode="before")
    def validate_inputs(cls, data: Dict[str, Any]) -> Dict[str, Any]:
        run_dir = Path(data["run_dir"])
        if (run_dir / "model").exists():
            artifacts_dir = run_dir
        elif (run_dir / "artifacts").exists():
            artifacts_dir = run_dir / "artifacts"
        else:
            raise ValueError("Could not find a model dir or an artifacts folder in the specified run_dir")

        model_checkpoint = artifacts_dir / "model"

        if (checkpoints_root := (artifacts_dir / "model" / "checkpoints")).exists():
            if data["use_checkpoint"] == "latest":
                model_checkpoint = checkpoints_root / "last" / "last.ckpt"
            else:
                for checkpoint_dir in sorted(list(checkpoints_root.iterdir())):
                    aliases = (checkpoint_dir / "aliases.txt").read_text()
                    if "'best'" in aliases:
                        model_checkpoint = checkpoints_root / checkpoint_dir.name / f"{checkpoint_dir.name}.ckpt"
                        break

        config_fp = artifacts_dir / "config.yaml"
        data["model_checkpoint"] = model_checkpoint
        data["original_training_config_fp"] = config_fp
        return data

    @property
    def training_config(self) -> TrainConfig:
        with open(self.original_training_config_fp, "r") as f:
            cfg = TrainConfig(**yaml.safe_load(f))
        cfg.data_dir = self.data_dir
        cfg.dataset_stats_fp = self.dataset_stats_dir / cfg.dataset_stats_fp.name.replace("%3A", ":")
        cfg.output_dir = self.output_dir
        if self.precision is not None:
            cfg.precision = self.precision
        return cfg

    @property
    def use_mlflow(self) -> bool:
        return self.model_checkpoint.is_dir()
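
A minimal construction sketch (all paths below are hypothetical and assume run_dir points at a finished training run containing either a model or an artifacts folder; use_checkpoint is passed explicitly because the before-mode validator reads it from the raw input):

from pathlib import Path

cfg = PredictConfig(
    data_dir=Path("data/predict"),            # hypothetical directory with input tiles
    dataset_stats_dir=Path("data/stats"),     # hypothetical directory with dataset stats files
    run_dir=Path("runs/2024-01-01-unet"),     # hypothetical training run directory
    output_dir=Path("outputs/predictions"),   # predictions will be written here
    use_checkpoint="best",                    # passed explicitly; validate_inputs reads it from the raw input
    tta=True,
)
# validate_inputs fills in model_checkpoint and original_training_config_fp
print(cfg.model_checkpoint, cfg.use_mlflow)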

kelp.nn.inference.predict.build_prediction_arg_parser

Builds a base prediction argument parser.

Returns: An instance of argparse.ArgumentParser.

Source code in kelp/nn/inference/predict.py
def build_prediction_arg_parser() -> argparse.ArgumentParser:
    """
    Builds a base prediction argument parser.

    Returns: An instance of :argparse.ArgumentParser.

    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--data_dir", type=str, required=True)
    parser.add_argument("--output_dir", type=str, required=True)
    parser.add_argument("--dataset_stats_dir", type=str, required=True)
    parser.add_argument("--run_dir", type=str, required=True)
    parser.add_argument("--use_checkpoint", choices=["latest", "best"], type=str, default="best")
    parser.add_argument("--tta", action="store_true")
    parser.add_argument("--soft_labels", action="store_true")
    parser.add_argument("--tta_merge_mode", type=str, default="max")
    parser.add_argument("--decision_threshold", type=float)
    parser.add_argument("--sahi_tile_size", type=int, default=128)
    parser.add_argument("--sahi_overlap", type=int, default=64)
    parser.add_argument(
        "--precision",
        type=str,
        choices=[
            "16-true",
            "16-mixed",
            "bf16-true",
            "bf16-mixed",
            "32-true",
        ],
    )
    return parser
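
Since this is a base parser, downstream prediction scripts can extend it before parsing. A short sketch (the extra --preview flag is hypothetical; the explicit argument list stands in for a real command line):

parser = build_prediction_arg_parser()
parser.add_argument("--preview", action="store_true")  # hypothetical extra flag for a derived script
args = parser.parse_args([
    "--data_dir", "data/predict",
    "--output_dir", "outputs/predictions",
    "--dataset_stats_dir", "data/stats",
    "--run_dir", "runs/2024-01-01-unet",
])
# omitted options keep their defaults, e.g. args.use_checkpoint == "best" and args.tta is False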

kelp.nn.inference.predict.main

Main entry point for performing model prediction. Will automatically use SAHI if the model was trained with the SAHI flag enabled.

Source code in kelp/nn/inference/predict.py
def main() -> None:
    """
    Main entry point for performing model prediction.
    Will automatically use SAHI if model was trained with this flag.
    """
    cfg = parse_args()
    (cfg.output_dir / "predict_config.yaml").write_text(yaml.dump(cfg.model_dump(mode="json")))
    if cfg.training_config.sahi:
        run_sahi_prediction(
            data_dir=cfg.data_dir,
            output_dir=cfg.output_dir,
            model_checkpoint=cfg.model_checkpoint,
            use_mlflow=cfg.use_mlflow,
            train_cfg=cfg.training_config,
            tta=cfg.tta,
            soft_labels=cfg.soft_labels,
            tta_merge_mode=cfg.tta_merge_mode,
            decision_threshold=cfg.decision_threshold,
            sahi_tile_size=cfg.sahi_tile_size,
            sahi_overlap=cfg.sahi_overlap,
        )
    else:
        run_prediction(
            data_dir=cfg.data_dir,
            output_dir=cfg.output_dir,
            model_checkpoint=cfg.model_checkpoint,
            use_mlflow=cfg.use_mlflow,
            train_cfg=cfg.training_config,
            tta=cfg.tta,
            soft_labels=cfg.soft_labels,
            tta_merge_mode=cfg.tta_merge_mode,
            decision_threshold=cfg.decision_threshold,
        )
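
main takes no arguments and reads its options from the command line through parse_args; a sketch of driving it programmatically by patching sys.argv (paths and the threshold value are hypothetical):

import sys

sys.argv = [
    "predict",
    "--data_dir", "data/predict",
    "--dataset_stats_dir", "data/stats",
    "--run_dir", "runs/2024-01-01-unet",
    "--output_dir", "outputs/predictions",
    "--tta",
    "--decision_threshold", "0.48",
]
main()  # dispatches to run_sahi_prediction or run_prediction based on training_config.sahi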

kelp.nn.inference.predict.parse_args

Parse command line arguments.

Returns: An instance of PredictConfig.

Source code in kelp/nn/inference/predict.py
def parse_args() -> PredictConfig:
    """
    Parse command line arguments.

    Returns: An instance of PredictConfig.

    """
    parser = build_prediction_arg_parser()
    args = parser.parse_args()
    cfg = PredictConfig(**vars(args))
    cfg.log_self()
    cfg.output_dir.mkdir(exist_ok=True, parents=True)
    return cfg

kelp.nn.inference.predict.predict

Runs prediction using the specified datamodule and model.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| dm | LightningDataModule | The datamodule to use for prediction. | required |
| model | LightningModule | The model. | required |
| train_cfg | TrainConfig | The original training configuration. | required |
| output_dir | Path | The output directory. | required |
| resize_tf | Callable[[Tensor], Tensor] | The resize transform for post-prediction adjustment. | required |
Source code in kelp/nn/inference/predict.py
@torch.inference_mode()
def predict(
    dm: pl.LightningDataModule,
    model: pl.LightningModule,
    train_cfg: TrainConfig,
    output_dir: Path,
    resize_tf: Callable[[Tensor], Tensor],
) -> None:
    """
    Runs prediction using specified datamodule and model.

    Args:
        dm: The datamodule to use for prediction.
        model: The model.
        train_cfg: The original training configuration.
        output_dir: The output directory.
        resize_tf: The resize transform for post-prediction adjustment.

    """
    with torch.no_grad():
        trainer = pl.Trainer(**train_cfg.trainer_kwargs, logger=False)
        preds: List[Dict[str, Union[Tensor, str]]] = trainer.predict(model=model, datamodule=dm)
        for prediction_batch in tqdm(preds, "Saving prediction batches"):
            individual_samples = unbind_samples(prediction_batch)
            for sample in individual_samples:
                tile_id = sample["tile_id"]
                prediction = sample["prediction"]
                if model.hyperparams.get("soft_labels", False):
                    META["dtype"] = "float32"
                dest: DatasetWriter
                with rasterio.open(output_dir / f"{tile_id}_kelp.tif", "w", **META) as dest:
                    prediction_arr = resize_tf(prediction.unsqueeze(0)).detach().cpu().numpy().squeeze()
                    dest.write(prediction_arr, 1)
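
predict writes one single-band GeoTIFF per tile, named <tile_id>_kelp.tif, using a float32 dtype when soft_labels is enabled. A quick sketch for inspecting a written prediction (the tile ID is hypothetical):

from pathlib import Path
import rasterio

output_dir = Path("outputs/predictions")  # the directory passed to predict
with rasterio.open(output_dir / "AB123456_kelp.tif") as src:  # hypothetical tile ID
    pred = src.read(1)  # class labels by default, probabilities when soft_labels=True
print(pred.shape, pred.dtype)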

kelp.nn.inference.predict.resolve_post_predict_resize_transform

Resolves the post-predict resize transform.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| resize_strategy | Literal["resize", "pad"] | The resize strategy. | required |
| source_image_size | int | The source image size. | required |
| target_image_size | int | The target image size. | required |

Returns: The transform to be called on predictions.
Source code in kelp/nn/inference/predict.py
def resolve_post_predict_resize_transform(
    resize_strategy: Literal["resize", "pad"],
    source_image_size: int,
    target_image_size: int,
) -> Callable[[Tensor], Tensor]:
    """
    Resolves the post-predict resize transform.

    Args:
        resize_strategy: The resize strategy.
        source_image_size: The source image size.
        target_image_size: The target image size.

    Returns: The transform to be called on predictions.

    """
    if resize_strategy == "resize":
        resize_tf = T.Resize(
            size=(target_image_size, target_image_size),
            interpolation=InterpolationMode.NEAREST,
            antialias=False,
        )
    elif resize_strategy == "pad":
        resize_tf = RemovePadding(image_size=target_image_size, padded_image_size=source_image_size)
    else:
        raise ValueError(f"{resize_strategy=} is not supported")
    return resize_tf  # type: ignore[no-any-return]
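
A minimal sketch of applying the resolved transform to a dummy prediction; the 256 and 350 sizes are hypothetical (target_image_size is normally the original tile size, consts.data.TILE_SIZE):

import torch

tf = resolve_post_predict_resize_transform(
    resize_strategy="resize",   # or "pad" to strip training-time padding instead
    source_image_size=256,      # hypothetical size the model predicted at
    target_image_size=350,      # hypothetical original tile size
)
pred = torch.zeros(1, 256, 256)  # dummy single-channel prediction
restored = tf(pred)              # nearest-neighbour resize to (1, 350, 350)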

kelp.nn.inference.predict.run_prediction

Runs the prediction logic for a single model checkpoint.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| data_dir | Path | The path to the data directory. | required |
| output_dir | Path | The path to the output directory. | required |
| model_checkpoint | Path | The model checkpoint. | required |
| use_mlflow | bool | A flag indicating whether to use MLflow to load the model. | required |
| train_cfg | TrainConfig | The original training config used to train the model. | required |
| tta | bool | A flag indicating whether to use TTA for prediction. | False |
| soft_labels | bool | A flag indicating whether to use soft labels for prediction. | False |
| tta_merge_mode | str | The TTA merge mode. | 'max' |
| decision_threshold | Optional[float] | An optional decision threshold for prediction. torch.argmax will be used by default. | None |
Source code in kelp/nn/inference/predict.py
def run_prediction(
    data_dir: Path,
    output_dir: Path,
    model_checkpoint: Path,
    use_mlflow: bool,
    train_cfg: TrainConfig,
    tta: bool = False,
    soft_labels: bool = False,
    tta_merge_mode: str = "max",
    decision_threshold: Optional[float] = None,
) -> None:
    """
    Runs the prediction logic for a single model checkpoint.

    Args:
        data_dir: The path to the data directory.
        output_dir: The path to the output directory.
        model_checkpoint: The model checkpoint.
        use_mlflow: A flag indicating whether to use MLflow to load the model.
        train_cfg: The original training config used to train the model.
        tta: A flag indicating whether to use TTA for prediction.
        soft_labels: A flag indicating whether to use soft labels for prediction.
        tta_merge_mode: The TTA merge mode.
        decision_threshold: An optional decision threshold for prediction. torch.argmax will be used by default.

    """
    dm = KelpForestDataModule.from_folders(predict_data_folder=data_dir, **train_cfg.data_module_kwargs)
    model = load_model(
        model_path=model_checkpoint,
        use_mlflow=use_mlflow,
        tta=tta,
        soft_labels=soft_labels,
        tta_merge_mode=tta_merge_mode,
        decision_threshold=decision_threshold,
    )
    resize_tf = resolve_post_predict_resize_transform(
        resize_strategy=train_cfg.resize_strategy,
        source_image_size=train_cfg.image_size,
        target_image_size=consts.data.TILE_SIZE,
    )
    predict(
        dm=dm,
        model=model,
        train_cfg=train_cfg,
        output_dir=output_dir,
        resize_tf=resize_tf,
    )
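
This is essentially what main does for the non-SAHI branch; a sketch of calling it directly from other code, where cfg is a PredictConfig built as shown earlier and the threshold value is hypothetical:

run_prediction(
    data_dir=cfg.data_dir,
    output_dir=cfg.output_dir,
    model_checkpoint=cfg.model_checkpoint,
    use_mlflow=cfg.use_mlflow,
    train_cfg=cfg.training_config,
    tta=cfg.tta,
    soft_labels=cfg.soft_labels,
    tta_merge_mode=cfg.tta_merge_mode,
    decision_threshold=0.48,  # hypothetical; None falls back to torch.argmax
)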

kelp.nn.inference.predict.run_sahi_prediction

Runs SAHI (Slicing Aided Hyper Inference) using the specified model checkpoint.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| data_dir | Path | The path to the data directory. | required |
| output_dir | Path | The path to the output directory. | required |
| model_checkpoint | Path | The model checkpoint. | required |
| use_mlflow | bool | A flag indicating whether to use MLflow to load the model. | required |
| train_cfg | TrainConfig | The original training config used to train the model. | required |
| tta | bool | A flag indicating whether to use TTA for prediction. | False |
| soft_labels | bool | A flag indicating whether to use soft labels for prediction. | False |
| tta_merge_mode | str | The TTA merge mode. | 'max' |
| decision_threshold | Optional[float] | An optional decision threshold for prediction. torch.argmax will be used by default. | None |
| sahi_tile_size | int | The size of the tiles to use when performing SAHI. | 128 |
| sahi_overlap | int | The size of the overlap between tiles to use when performing SAHI. | 64 |
Source code in kelp/nn/inference/predict.py
def run_sahi_prediction(
    data_dir: Path,
    output_dir: Path,
    model_checkpoint: Path,
    use_mlflow: bool,
    train_cfg: TrainConfig,
    tta: bool = False,
    soft_labels: bool = False,
    tta_merge_mode: str = "max",
    decision_threshold: Optional[float] = None,
    sahi_tile_size: int = 128,
    sahi_overlap: int = 64,
) -> None:
    """
    Runs SAHI (Slicing Aided Hyper Inference) using specified model checkpoint.

    Args:
        data_dir: The path to the data directory.
        output_dir: The path to the output directory.
        model_checkpoint: The model checkpoint.
        use_mlflow: A flag indicating whether to use MLflow to load the model.
        train_cfg: The original training config used to train the model.
        tta: A flag indicating whether to use TTA for prediction.
        soft_labels: A flag indicating whether to use soft labels for prediction.
        tta_merge_mode: The TTA merge mode.
        decision_threshold: An optional decision threshold for prediction. torch.argmax will be used by default.
        sahi_tile_size: The size of the tiles to use when performing SAHI.
        sahi_overlap: The size of the overlap between tiles to use when performing SAHI

    """
    model = load_model(
        model_path=model_checkpoint,
        use_mlflow=use_mlflow,
        tta=tta,
        soft_labels=soft_labels,
        tta_merge_mode=tta_merge_mode,
        decision_threshold=decision_threshold,
    )
    band_order = [consts.data.ORIGINAL_BANDS.index(band) + 1 for band in train_cfg.bands]
    bands_to_use = train_cfg.bands + train_cfg.spectral_indices
    band_index_lookup = {band: idx for idx, band in enumerate(bands_to_use)}
    band_stats, in_channels = resolve_normalization_stats(
        dataset_stats=train_cfg.dataset_stats,
        bands_to_use=bands_to_use,
    )
    normalization_tf = resolve_normalization_transform(
        band_stats=band_stats,
        normalization_strategy=train_cfg.normalization_strategy,
    )
    predict_sahi(
        file_paths=sorted(list(data_dir.glob("*.tif"))),
        model=model,
        tta=tta,
        soft_labels=soft_labels,
        tta_merge_mode=tta_merge_mode,
        decision_threshold=decision_threshold,
        output_dir=output_dir,
        overlap=sahi_overlap,
        tile_size=(sahi_tile_size, sahi_tile_size),
        band_order=band_order,
        fill_value=train_cfg.fill_value,
        resize_tf=resolve_resize_transform(
            resize_strategy=train_cfg.resize_strategy,
            interpolation=train_cfg.interpolation,
            image_size=train_cfg.image_size,
            image_or_mask="image",
        ),
        input_transforms=resolve_transforms(
            normalization_transform=normalization_tf,
            spectral_indices=train_cfg.spectral_indices,
            band_stats=band_stats,
            band_index_lookup=band_index_lookup,
            mask_using_qa=train_cfg.mask_using_qa,
            mask_using_water_mask=train_cfg.mask_using_water_mask,
            stage="predict",
        ),
        post_predict_transforms=T.Resize(
            size=(sahi_tile_size, sahi_tile_size),
            interpolation=InterpolationMode.NEAREST,
            antialias=False,
        ),
    )
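
The call mirrors run_prediction with two extra tiling knobs; a short sketch (cfg as before, tile and overlap values shown are the defaults):

run_sahi_prediction(
    data_dir=cfg.data_dir,
    output_dir=cfg.output_dir,
    model_checkpoint=cfg.model_checkpoint,
    use_mlflow=cfg.use_mlflow,
    train_cfg=cfg.training_config,
    sahi_tile_size=128,  # size of the tiles fed to the model
    sahi_overlap=64,     # overlap between adjacent tiles
)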