train

The segmentation neural network training logic.

kelp.nn.training.train.main

Main entrypoint for model training.

Source code in kelp/nn/training/train.py
def main() -> None:
    """Main entrypoint for model training."""
    cfg = parse_args()
    set_gpu_power_limit_if_needed()

    mlflow.set_experiment(cfg.resolved_experiment_name)
    mlflow.pytorch.autolog()
    run = mlflow.start_run(run_id=cfg.run_id_from_context)

    with run:
        pl.seed_everything(cfg.seed, workers=True)
        mlflow.log_dict(cfg.model_dump(mode="json"), artifact_file="config.yaml")
        mlflow.log_params(cfg.model_dump(mode="json"))
        mlflow_run_dir = get_mlflow_run_dir(current_run=run, output_dir=cfg.output_dir)
        datamodule = KelpForestDataModule.from_metadata_file(**cfg.data_module_kwargs)
        segmentation_task = KelpForestSegmentationTask(in_channels=datamodule.in_channels, **cfg.model_kwargs)
        trainer = pl.Trainer(
            logger=make_loggers(
                experiment=cfg.resolved_experiment_name,
                tags=cfg.tags,
            ),
            callbacks=make_callbacks(
                output_dir=mlflow_run_dir / "artifacts" / "checkpoints",
                **cfg.callbacks_kwargs,
            ),
            **cfg.trainer_kwargs,
        )
        trainer.fit(model=segmentation_task, datamodule=datamodule)

        # Don't log hp_metric if debugging
        if not cfg.fast_dev_run:
            best_score = (
                trainer.checkpoint_callback.best_model_score.detach().cpu().item()  # type: ignore[attr-defined]
            )
            trainer.logger.log_metrics(metrics={"hp_metric": best_score})

        trainer.test(model=segmentation_task, datamodule=datamodule)
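
Once main() finishes, the run configuration logged via mlflow.log_dict() can be retrieved from the MLflow run. A minimal sketch, assuming a reachable MLflow tracking store and a known run ID (both placeholders, not part of this module):

import mlflow

# Hypothetical run ID of a finished training run.
run_id = "<your-run-id>"

# Download the config.yaml artifact that main() logs at the start of the run.
local_path = mlflow.artifacts.download_artifacts(run_id=run_id, artifact_path="config.yaml")
print(f"Config downloaded to: {local_path}")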

kelp.nn.training.train.make_callbacks

A factory method for creating lightning callbacks.

Parameters:

output_dir (Path, required): The output directory for model checkpoints.
early_stopping_patience (int, default 3): The early stopping patience in epochs.
save_top_k (int, default 1): The number of top model checkpoints to save.
monitor_metric (str, default 'val/dice'): The metric to monitor for early stopping and checkpointing.
monitor_mode (str, default 'max'): The monitoring mode ('min' or 'max') for early stopping and checkpointing.
swa (bool, default False): A flag indicating whether to use SWA (Stochastic Weight Averaging).
swa_lr (float, default 3e-05): The final learning rate for SWA annealing.
swa_epoch_start (float, default 0.5): The fraction of total training epochs after which SWA starts.
swa_annealing_epochs (int, default 10): The number of epochs over which the SWA learning rate is annealed.
Source code in kelp/nn/training/train.py
def make_callbacks(
    output_dir: Path,
    early_stopping_patience: int = 3,
    save_top_k: int = 1,
    monitor_metric: str = "val/dice",
    monitor_mode: str = "max",
    swa: bool = False,
    swa_lr: float = 3e-5,
    swa_epoch_start: float = 0.5,
    swa_annealing_epochs: int = 10,
) -> List[Callback]:
    """
    A factory method for creating lightning callbacks.

    Args:
        output_dir: The output directory for model checkpoints.
        early_stopping_patience: The early stopping patience in epochs.
        save_top_k: The number of top model checkpoints to save.
        monitor_metric: The metric to monitor for early stopping and checkpointing.
        monitor_mode: The monitoring mode ('min' or 'max') for early stopping and checkpointing.
        swa: A flag indicating whether to use SWA (Stochastic Weight Averaging).
        swa_lr: The final learning rate for SWA annealing.
        swa_epoch_start: The fraction of total training epochs after which SWA starts.
        swa_annealing_epochs: The number of epochs over which the SWA learning rate is annealed.

    Returns: A list of lightning callbacks.

    """
    early_stopping = EarlyStopping(
        monitor=monitor_metric,
        patience=early_stopping_patience,
        verbose=True,
        mode=monitor_mode,
    )
    lr_monitor = LearningRateMonitor(logging_interval="step", log_momentum=True, log_weight_decay=True)
    sanitized_monitor_metric = monitor_metric.replace("/", "_")
    filename_str = "kelp-epoch={epoch:02d}-" f"{sanitized_monitor_metric}=" f"{{{monitor_metric}:.3f}}"
    checkpoint = ModelCheckpoint(
        monitor=monitor_metric,
        mode=monitor_mode,
        verbose=True,
        save_top_k=save_top_k,
        dirpath=output_dir,
        auto_insert_metric_name=False,
        filename=filename_str,
        save_last=True,
    )
    callbacks = [early_stopping, lr_monitor, checkpoint]
    if swa:
        callbacks.append(
            StochasticWeightAveraging(
                swa_lrs=swa_lr,
                swa_epoch_start=swa_epoch_start,
                annealing_epochs=swa_annealing_epochs,
            ),
        )
    return callbacks
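
With the defaults above, ModelCheckpoint writes files such as kelp-epoch=07-val_dice=0.842.ckpt into output_dir. A minimal usage sketch with SWA enabled; the checkpoint directory, patience, and epoch count are illustrative assumptions:

from pathlib import Path

import pytorch_lightning as pl

from kelp.nn.training.train import make_callbacks

# Build the callback list with Stochastic Weight Averaging turned on.
callbacks = make_callbacks(
    output_dir=Path("outputs/checkpoints"),  # hypothetical checkpoint directory
    early_stopping_patience=5,
    swa=True,
    swa_lr=1e-4,
)

# The callbacks plug straight into a Lightning Trainer, as done in main().
trainer = pl.Trainer(max_epochs=20, callbacks=callbacks)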

kelp.nn.training.train.make_loggers

Factory method for creating lightning loggers.

Parameters:

experiment (str, required): The experiment name.
tags (Dict[str, Any], required): The experiment tags.
Source code in kelp/nn/training/train.py
def make_loggers(
    experiment: str,
    tags: Dict[str, Any],
) -> List[Logger]:
    """
    Factory method for creating lightning loggers.

    Args:
        experiment: The experiment name.
        tags: The experiment tags.

    Returns: List of lightning loggers.

    """
    mlflow_logger = MLFlowLogger(
        experiment_name=experiment,
        run_id=mlflow.active_run().info.run_id,
        log_model=True,
        tags=tags,
    )
    return [mlflow_logger]
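
Because make_loggers reads the run ID from mlflow.active_run(), it has to be called inside an active MLflow run, exactly as main() does. A minimal sketch, assuming placeholder experiment name and tags:

import mlflow

from kelp.nn.training.train import make_loggers

mlflow.set_experiment("kelp-segmentation")  # hypothetical experiment name
with mlflow.start_run():
    loggers = make_loggers(
        experiment="kelp-segmentation",
        tags={"owner": "dev", "stage": "experiments"},  # hypothetical tags
    )
    # Pass the result to pl.Trainer(logger=loggers), as main() does.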