
train

The XGBoost training logic.

kelp.xgb.training.train.calculate_metrics

Calculates metrics for a given model and its predictions.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `model` | `XGBClassifier` | The XGBClassifier model. | required |
| `x` | `DataFrame` | The input dataframe. | required |
| `y_true` | `Series` | The ground truth series. | required |
| `y_pred` | `ndarray` | The predicted values. | required |
| `prefix` | `str` | A prefix to use for metrics logging. | required |
Source code in kelp/xgb/training/train.py
@torch.inference_mode()
@timed
def calculate_metrics(
    model: XGBClassifier,
    x: pd.DataFrame,
    y_true: pd.Series,
    y_pred: np.ndarray,  # type: ignore[type-arg]
    prefix: str,
) -> Dict[str, float]:
    """
    Calculates metrics for a given model and its predictions.

    Args:
        model: The XGBClassifier model.
        x: The input dataframe.
        y_true: The ground truth series.
        y_pred: The predicted values.
        prefix: A prefix to use for metrics logging.

    Returns: A dictionary with metric names and the metric values.

    """
    metrics = MetricCollection(
        metrics={
            "dice": Dice(num_classes=2, average="macro"),
            "iou": JaccardIndex(task="binary"),
            "accuracy": Accuracy(task="binary"),
            "recall": Recall(task="binary", average="macro"),
            "precision": Precision(task="binary", average="macro"),
            "f1": F1Score(task="binary", average="macro"),
            "auroc": AUROC(task="binary"),
        },
        prefix=f"{prefix}/",
    ).to(DEVICE)
    metrics(
        torch.tensor(y_pred, device=DEVICE, dtype=torch.int32),
        torch.tensor(y_true.values, device=DEVICE, dtype=torch.int32),
    )
    metrics_dict = metrics.compute()
    for name, value in metrics_dict.items():
        metrics_dict[name] = value.item()
    if hasattr(model, "predict_proba"):
        y_pred_prob = model.predict_proba(x)
        loss = log_loss(y_true, y_pred_prob)
        metrics_dict[f"{prefix}/log_loss"] = loss
    _logger.info(f"{prefix.upper()} metrics: {json.dumps(metrics_dict, indent=4)}")
    return metrics_dict  # type: ignore[no-any-return]
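
A minimal usage sketch with synthetic stand-in data (assumes the `kelp` package is importable; column names are illustrative):

```python
import numpy as np
import pandas as pd
from xgboost import XGBClassifier

from kelp.xgb.training.train import calculate_metrics

# Synthetic stand-in data: 200 rows, 5 features, binary labels
rng = np.random.default_rng(42)
x = pd.DataFrame(rng.random((200, 5)), columns=[f"band_{i}" for i in range(5)])
y_true = pd.Series(rng.integers(0, 2, size=200))

model = XGBClassifier(n_estimators=10).fit(x, y_true)
y_pred = model.predict(x)

metrics = calculate_metrics(model=model, x=x, y_true=y_true, y_pred=y_pred, prefix="val")
# Keys are prefixed, e.g. "val/dice", "val/iou", ..., plus "val/log_loss"
```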

kelp.xgb.training.train.eval_model

Evaluates the XGBoost model.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `model` | `XGBClassifier` | The XGBoost model. | required |
| `x` | `DataFrame` | The validation data. | required |
| `y_true` | `Series` | The ground truth labels. | required |
| `prefix` | `str` | The prefix for the metrics and plots. | required |
| `seed` | `int` | The seed for reproducibility. | `SEED` |
| `explain_model` | `bool` | A flag indicating whether to run model feature importance calculation. | `False` |
Source code in kelp/xgb/training/train.py
@timed
def eval_model(
    model: XGBClassifier,
    x: pd.DataFrame,
    y_true: pd.Series,
    prefix: str,
    seed: int = consts.reproducibility.SEED,
    explain_model: bool = False,
) -> None:
    """
    Evaluates the XGBoost model.

    Args:
        model: The XGBoost model.
        x: The validation data.
        y_true: The ground truth labels.
        prefix: The prefix for the metrics and plots.
        seed: The seed for reproducibility.
        explain_model: A flag indicating whether to run model feature importance calculation.

    """
    _logger.info(f"Running model eval for {prefix} split")
    y_pred = model.predict(x)
    metrics = calculate_metrics(model=model, x=x, y_true=y_true, y_pred=y_pred, prefix=prefix)
    mlflow.log_metrics(metrics)
    log_confusion_matrix(y_true=y_true, y_pred=y_pred, prefix=prefix, normalize=False)
    log_confusion_matrix(y_true=y_true, y_pred=y_pred, prefix=prefix, normalize=True)
    log_precision_recall_curve(y_true=y_true, y_pred=y_pred, prefix=prefix)
    log_roc_curve(y_true=y_true, y_pred=y_pred, prefix=prefix)
    if prefix == "test" and explain_model:  # calculate feature importance only once
        log_model_feature_importance(model=model, feature_names=x.columns.values)
        log_permutation_feature_importance(model=model, x=x, y_true=y_true, seed=seed)
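
`eval_model` logs metrics and figures through the MLflow fluent API, so it should be called inside an active run. A minimal sketch, again with synthetic data (assuming the `kelp` package is importable):

```python
import mlflow
import numpy as np
import pandas as pd
from xgboost import XGBClassifier

from kelp.xgb.training.train import eval_model

rng = np.random.default_rng(0)
x = pd.DataFrame(rng.random((200, 5)), columns=[f"band_{i}" for i in range(5)])
y_true = pd.Series(rng.integers(0, 2, size=200))
model = XGBClassifier(n_estimators=10).fit(x, y_true)

with mlflow.start_run():
    # Logs metrics under the "val/" prefix and figures under images/val/
    eval_model(model=model, x=x, y_true=y_true, prefix="val")
```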

kelp.xgb.training.train.fit_model

Runs the training.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `model` | `XGBClassifier` | The model to be trained. | required |
| `x` | `DataFrame` | The training dataset. | required |
| `y_true` | `Series` | The training labels. | required |
Source code in kelp/xgb/training/train.py
@timed
def fit_model(
    model: XGBClassifier,
    x: pd.DataFrame,
    y_true: pd.Series,
) -> XGBClassifier:
    """
    Runs the training.

    Args:
        model: The model to be trained.
        x: The training dataset.
        y_true: The training labels.

    Returns: Fitted model.

    """
    model.fit(x, y_true)
    return model

kelp.xgb.training.train.load_data

Loads the training data and splits it into train, validation, and test sets.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `df` | `DataFrame` | The input dataframe with pixel-level values. | required |
| `sample_size` | `float` | The random sample size to use for quicker training times. | `1.0` |
| `seed` | `int` | The seed for reproducibility. | `SEED` |
Source code in kelp/xgb/training/train.py
@timed
def load_data(
    df: pd.DataFrame,
    sample_size: float = 1.0,
    seed: int = consts.reproducibility.SEED,
) -> Tuple[pd.DataFrame, pd.Series, pd.DataFrame, pd.Series, pd.DataFrame, pd.Series]:
    """
    Loads the training data and splits it into train, validation, and test sets.

    Args:
        df: The input dataframe with pixel-level values.
        sample_size: The random sample size to use for quicker training times.
        seed: The seed for reproducibility.

    Returns: A tuple containing features and labels in the following order: X_train, y_train, X_val, y_val, X_test, y_test

    """
    X_train = df[df["split"] == "train"]
    X_val = df[df["split"] == "val"]
    X_test = df[df["split"] == "test"]

    if sample_size != 1.0:
        X_train = X_train.sample(frac=sample_size, random_state=seed)

    y_train = X_train["label"]
    y_val = X_val["label"]
    y_test = X_test["label"]

    X_train = X_train.drop(["label", "split"], axis=1)
    X_val = X_val.drop(["label", "split"], axis=1)
    X_test = X_test.drop(["label", "split"], axis=1)

    return X_train, y_train, X_val, y_val, X_test, y_test
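
A sketch of the expected input layout: the frame must carry a `split` column with `train`/`val`/`test` values and a `label` column alongside the feature columns (the feature name here is illustrative):

```python
import pandas as pd

from kelp.xgb.training.train import load_data

df = pd.DataFrame(
    {
        "band_0": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
        "label": [0, 1, 0, 1, 0, 1],
        "split": ["train", "train", "train", "train", "val", "test"],
    }
)
X_train, y_train, X_val, y_val, X_test, y_test = load_data(df, sample_size=0.5, seed=42)
assert list(X_train.columns) == ["band_0"]  # "label" and "split" are dropped
assert len(X_train) == 2  # a 50% random sample of the 4 training rows
```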

kelp.xgb.training.train.log_confusion_matrix

Logs confusion matrix to MLFlow.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `y_true` | `Series` | A pandas Series with ground truth values. | required |
| `y_pred` | `ndarray` | A NumPy array with prediction values. | required |
| `prefix` | `str` | The prefix to use when logging the confusion matrix. | required |
| `normalize` | `bool` | A flag indicating whether to normalize the confusion matrix. | `False` |
Source code in kelp/xgb/training/train.py
@timed
def log_confusion_matrix(
    y_true: pd.Series,
    y_pred: np.ndarray,  # type: ignore[type-arg]
    prefix: str,
    normalize: bool = False,
) -> None:
    """
    Logs confusion matrix to MLFlow.

    Args:
        y_true: A pandas Series with ground truth values.
        y_pred: A NumPy array with prediction values.
        prefix: The prefix to use when logging confusion matrix.
        normalize: A flag indicating whether to normalize the confusion matrix.

    """
    cmd = ConfusionMatrixDisplay.from_predictions(
        y_true=y_true,
        y_pred=y_pred,
        display_labels=consts.data.CLASSES,
        cmap="Blues",
        normalize="true" if normalize else None,
    )
    cmd.ax_.set_title("Normalized confusion matrix" if normalize else "Confusion matrix")
    plt.tight_layout()
    fname = "normalized_confusion_matrix" if normalize else "confusion_matrix"
    mlflow.log_figure(figure=cmd.figure_, artifact_file=f"images/{prefix}/{fname}.png")
    plt.close()

kelp.xgb.training.train.log_model_feature_importance

Logs the feature importance to MLFlow.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `model` | `XGBClassifier` | The XGBClassifier model. | required |
| `feature_names` | `ndarray` | The names of the features. | required |
Source code in kelp/xgb/training/train.py
@timed
def log_model_feature_importance(
    model: XGBClassifier,
    feature_names: np.ndarray,  # type: ignore[type-arg]
) -> None:
    """
    Logs the feature importance to MLFlow.

    Args:
        model: The XGBClassifier model.
        feature_names: The names of the features.

    """
    sorted_idx = model.feature_importances_.argsort()
    fig, ax = plt.subplots(figsize=(8, 0.2 * len(feature_names)))
    ax.barh(feature_names[sorted_idx], model.feature_importances_[sorted_idx])
    ax.set_title("XGB Feature importances")
    ax.set_xlabel("Feature")
    ax.set_ylabel("XGB Feature Importance")
    fig.tight_layout()
    mlflow.log_figure(figure=fig, artifact_file="images/feature_importances_xgb.png")
    plt.close(fig)
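
A usage sketch with synthetic data; note that `eval_model` passes `x.columns.values` for `feature_names`, so an indexable NumPy array is expected:

```python
import mlflow
import numpy as np
import pandas as pd
from xgboost import XGBClassifier

from kelp.xgb.training.train import log_model_feature_importance

rng = np.random.default_rng(0)
x = pd.DataFrame(rng.random((200, 5)), columns=[f"band_{i}" for i in range(5)])
y = pd.Series(rng.integers(0, 2, size=200))
model = XGBClassifier(n_estimators=10).fit(x, y)

with mlflow.start_run():
    # The bar chart is stored as images/feature_importances_xgb.png
    log_model_feature_importance(model=model, feature_names=x.columns.values)
```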

kelp.xgb.training.train.log_permutation_feature_importance

Logs the permutation feature importance to MLFlow.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `model` | `XGBClassifier` | The XGBClassifier model. | required |
| `x` | `DataFrame` | The input data. | required |
| `y_true` | `Series` | The ground truth data. | required |
| `seed` | `int` | The seed to use for reproducibility. | `SEED` |
| `n_repeats` | `int` | The number of repeats. | `10` |
Source code in kelp/xgb/training/train.py
@timed
def log_permutation_feature_importance(
    model: XGBClassifier,
    x: pd.DataFrame,
    y_true: pd.Series,
    seed: int = consts.reproducibility.SEED,
    n_repeats: int = 10,
) -> None:
    """
    Logs the permutation feature importance to MLFlow.

    Args:
        model: The XGBClassifier model.
        x: The input data.
        y_true: The ground truth data.
        seed: The seed to use for reproducibility.
        n_repeats: The number of repeats.

    """
    result = permutation_importance(model, x, y_true, n_repeats=n_repeats, random_state=seed, n_jobs=4)
    forest_importances = pd.Series(result.importances_mean, index=x.columns.tolist())
    fig, ax = plt.subplots(figsize=(0.2 * len(x.columns), 8))
    forest_importances.plot.bar(yerr=result.importances_std, ax=ax)
    ax.set_title("Feature importances using permutation on full model")
    ax.set_xlabel("Feature")
    ax.set_ylabel("Mean accuracy decrease")
    fig.tight_layout()
    mlflow.log_figure(figure=fig, artifact_file="images/feature_importances_pi.png")
    plt.close(fig)
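
Unlike the model's built-in importances above, permutation importance re-scores the model `n_repeats` times per feature, so it can be slow on large pixel-level frames. A small sketch with synthetic data:

```python
import mlflow
import numpy as np
import pandas as pd
from xgboost import XGBClassifier

from kelp.xgb.training.train import log_permutation_feature_importance

rng = np.random.default_rng(0)
x = pd.DataFrame(rng.random((500, 5)), columns=[f"band_{i}" for i in range(5)])
y_true = pd.Series(rng.integers(0, 2, size=500))
model = XGBClassifier(n_estimators=10).fit(x, y_true)

with mlflow.start_run():
    # Stored as images/feature_importances_pi.png; n_repeats trades runtime for stability
    log_permutation_feature_importance(model=model, x=x, y_true=y_true, n_repeats=5)
```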

kelp.xgb.training.train.log_precision_recall_curve

Logs the precision-recall curve plot to MLFlow.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `y_true` | `Series` | The ground truth. | required |
| `y_pred` | `ndarray` | The predicted values. | required |
| `prefix` | `str` | The prefix to use when logging the precision-recall curve. | required |
Source code in kelp/xgb/training/train.py
@timed
def log_precision_recall_curve(
    y_true: pd.Series,
    y_pred: np.ndarray,  # type: ignore[type-arg]
    prefix: str,
) -> None:
    """
    Logs the precision-recall curve plot to MLFlow.

    Args:
        y_true: The ground truth.
        y_pred: The predicted values.
        prefix: The prefix to use when logging the precision-recall curve.

    """
    prd = PrecisionRecallDisplay.from_predictions(
        y_true=y_true,
        y_pred=y_pred,
    )
    prd.ax_.set_title("Precision-recall curve")
    plt.tight_layout()
    mlflow.log_figure(figure=prd.figure_, artifact_file=f"images/{prefix}/precision_recall_curve.png")
    plt.close()

kelp.xgb.training.train.log_roc_curve

Logs the ROC curve to MLFlow.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `y_true` | `Series` | The ground truth. | required |
| `y_pred` | `ndarray` | The predicted values. | required |
| `prefix` | `str` | The prefix to use when logging the ROC curve plot. | required |
Source code in kelp/xgb/training/train.py
@timed
def log_roc_curve(
    y_true: pd.Series,
    y_pred: np.ndarray,  # type: ignore[type-arg]
    prefix: str,
) -> None:
    """
    Logs the ROC curve to MLFlow.

    Args:
        y_true: The ground truth.
        y_pred: The predicted values.
        prefix: The prefix to use when logging ROC curve plot.

    """
    rc = RocCurveDisplay.from_predictions(
        y_true=y_true,
        y_pred=y_pred,
    )
    rc.ax_.set_title("ROC curve")
    plt.tight_layout()
    mlflow.log_figure(figure=rc.figure_, artifact_file=f"images/{prefix}/roc_curve.png")
    plt.close()
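
The three plot-logging helpers share the same `(y_true, y_pred, prefix)` signature and all log through `mlflow.log_figure`, so they need an active run. Note that `eval_model` passes the hard 0/1 predictions from `model.predict`; probability scores from `predict_proba` would produce smoother PR and ROC curves. A combined sketch:

```python
import mlflow
import numpy as np
import pandas as pd

from kelp.xgb.training.train import (
    log_confusion_matrix,
    log_precision_recall_curve,
    log_roc_curve,
)

rng = np.random.default_rng(0)
y_true = pd.Series(rng.integers(0, 2, size=500))
y_pred = rng.integers(0, 2, size=500)  # hard predictions, as model.predict returns

with mlflow.start_run():
    log_confusion_matrix(y_true=y_true, y_pred=y_pred, prefix="val", normalize=True)
    log_precision_recall_curve(y_true=y_true, y_pred=y_pred, prefix="val")
    log_roc_curve(y_true=y_true, y_pred=y_pred, prefix="val")
```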

kelp.xgb.training.train.log_sample_predictions

Logs sample predictions to MLFlow.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `train_data_dir` | `Path` | The training data directory. | required |
| `metadata` | `DataFrame` | The metadata dataframe. | required |
| `model` | `XGBClassifier` | The XGBClassifier model. | required |
| `spectral_indices` | `List[str]` | The spectral indices to append to the input image before prediction. | required |
| `sample_size` | `int` | The number of samples to plot. | `10` |
| `seed` | `int` | The seed for reproducibility. | `SEED` |
Source code in kelp/xgb/training/train.py
@timed
def log_sample_predictions(
    train_data_dir: Path,
    metadata: pd.DataFrame,
    model: XGBClassifier,
    spectral_indices: List[str],
    sample_size: int = 10,
    seed: int = consts.reproducibility.SEED,
) -> None:
    """
    Logs sample predictions to MLFlow.

    Args:
        train_data_dir: The training data directory.
        metadata: The metadata dataframe.
        model: The XGBClassifier model.
        spectral_indices: The spectral indices to append to the input image before prediction.
        sample_size: The number of samples to plot.
        seed: The seed for reproducibility.

    """
    sample_to_plot = metadata.sample(n=sample_size, random_state=seed)
    tile_ids = sample_to_plot["tile_id"].tolist()
    transforms = build_append_index_transforms(spectral_indices)
    for tile in tqdm(tile_ids, desc="Plotting sample predictions"):
        with rasterio.open(train_data_dir / "images" / f"{tile}_satellite.tif") as src:
            input_arr = src.read()
        with rasterio.open(train_data_dir / "masks" / f"{tile}_kelp.tif") as src:
            mask_arr = src.read(1)
        prediction = predict_on_single_image(
            model=model, x=input_arr, transforms=transforms, columns=list(consts.data.ORIGINAL_BANDS) + spectral_indices
        )
        input_arr = min_max_normalize(input_arr)
        fig = plot_sample(input_arr=input_arr, target_arr=mask_arr, predictions_arr=prediction, suptitle=tile)
        mlflow.log_figure(fig, artifact_file=f"images/predictions/{tile}.png")
        plt.close(fig)

kelp.xgb.training.train.main

Main entrypoint for training XGBClassifier.

Source code in kelp/xgb/training/train.py
def main() -> None:
    """Main entrypoint for training XGBClassifier."""
    cfg = parse_args()
    mlflow.xgboost.autolog(model_format="json")
    mlflow.set_experiment(cfg.resolved_experiment_name)
    run = mlflow.start_run(run_id=cfg.run_id_from_context)
    with run:
        mlflow.log_dict(cfg.model_dump(mode="json"), artifact_file="config.yaml")
        mlflow.log_params(cfg.model_dump(mode="json"))
        _ = get_mlflow_run_dir(current_run=run, output_dir=cfg.output_dir)
        model = XGBClassifier(**cfg.xgboost_model_params)
        run_training(
            train_data_dir=cfg.train_data_dir,
            dataset_fp=cfg.dataset_fp,
            columns_to_load=cfg.columns_to_load,
            model=model,
            spectral_indices=cfg.spectral_indices,
            sample_size=cfg.sample_size,
            plot_n_samples=cfg.plot_n_samples,
            seed=cfg.seed,
            explain_model=cfg.explain_model,
        )

kelp.xgb.training.train.min_max_normalize

Runs min-max quantile normalization on the input array.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `x` | `ndarray` | The input array. | required |
Source code in kelp/xgb/training/train.py
def min_max_normalize(x: np.ndarray) -> np.ndarray:  # type: ignore[type-arg]
    """
    Runs min-max quantile normalization on the input array.

    Args:
        x: The input array.

    Returns: Normalized array.

    """
    vmin = np.expand_dims(np.expand_dims(np.quantile(x, q=0.01, axis=(1, 2)), 1), 2)
    vmax = np.expand_dims(np.expand_dims(np.quantile(x, q=0.99, axis=(1, 2)), 1), 2)
    return (x - vmin) / (vmax - vmin + consts.data.EPS)  # type: ignore[no-any-return]
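
A standalone sketch. Normalization is per band (axis 0), using the 1st and 99th percentiles rather than the true min/max, so extreme outliers can still fall outside [0, 1]; no clipping is applied:

```python
import numpy as np

from kelp.xgb.training.train import min_max_normalize

# Fake 3-band image in channels-first layout (bands, height, width), as rasterio reads it
x = np.random.default_rng(0).normal(loc=100.0, scale=25.0, size=(3, 64, 64))
normalized = min_max_normalize(x)
print(normalized.min(), normalized.max())  # roughly 0 and 1, with slight overshoot
```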

kelp.xgb.training.train.run_training

Runs XGBoost model training.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `train_data_dir` | `Path` | The path to the training data. | required |
| `dataset_fp` | `Path` | The path to the training dataset parquet file. | required |
| `columns_to_load` | `List[str]` | The columns to load from the metadata dataset. | required |
| `model` | `XGBClassifier` | The model to train. | required |
| `spectral_indices` | `List[str]` | The spectral indices to append to the input records. | required |
| `sample_size` | `float` | The fraction of samples to use for training. | `1.0` |
| `plot_n_samples` | `int` | The number of samples to plot. | `10` |
| `seed` | `int` | The seed for reproducibility. | `SEED` |
| `explain_model` | `bool` | A flag indicating whether to run model feature importance calculation. | `False` |
Source code in kelp/xgb/training/train.py
@timed
def run_training(
    train_data_dir: Path,
    dataset_fp: Path,
    columns_to_load: List[str],
    model: XGBClassifier,
    spectral_indices: List[str],
    sample_size: float = 1.0,
    plot_n_samples: int = 10,
    seed: int = consts.reproducibility.SEED,
    explain_model: bool = False,
) -> XGBClassifier:
    """
    Runs XGBoost model training.

    Args:
        train_data_dir: The path to the training data.
        dataset_fp: The path to the training dataset parquet file.
        columns_to_load: The columns to load from the metadata dataset.
        model: The model to train.
        spectral_indices: The spectral indices to append to the input records.
        sample_size: The fraction of samples to use for training.
        plot_n_samples: The number of samples to plot.
        seed: The seed for reproducibility.
        explain_model: A flag indicating whether to run model feature importance calculation.

    Returns: A fitted XGBClassifier.

    """
    metadata = pd.read_parquet(dataset_fp, columns=columns_to_load)
    X_train, y_train, X_val, y_val, X_test, y_test = load_data(
        df=metadata.drop(["tile_id"], axis=1, errors="ignore"),
        sample_size=sample_size,
        seed=seed,
    )
    model = fit_model(model, X_train, y_train)
    if plot_n_samples > 0:
        log_sample_predictions(
            train_data_dir=train_data_dir,
            metadata=metadata,
            model=model,
            spectral_indices=spectral_indices,
            sample_size=plot_n_samples,
            seed=seed,
        )
    eval_model(model, X_val, y_val, prefix="val", seed=seed, explain_model=explain_model)
    eval_model(model, X_test, y_test, prefix="test", seed=seed, explain_model=explain_model)
    return model
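
A hypothetical end-to-end invocation; all paths, band columns, and the index name below are illustrative, and the parquet file must contain the feature columns plus `tile_id`, `label`, and `split`. `main()` normally wraps this call in an MLflow run, which the sketch reproduces:

```python
from pathlib import Path

import mlflow
from xgboost import XGBClassifier

from kelp.xgb.training.train import run_training

with mlflow.start_run():
    model = run_training(
        train_data_dir=Path("data/raw/train"),  # must contain images/ and masks/
        dataset_fp=Path("data/processed/train.parquet"),
        columns_to_load=["tile_id", "label", "split", "B02", "B03", "B04"],  # illustrative
        model=XGBClassifier(n_estimators=500),
        spectral_indices=["NDVI"],  # illustrative index name
        sample_size=0.1,  # train on a 10% pixel sample for a quicker run
        plot_n_samples=0,  # skip sample-prediction plotting (no raster reads)
        explain_model=False,
    )
```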