Skip to content

average_predictions

Inference logic for averaging predictions from multiple folds.

kelp.nn.inference.average_predictions.AveragePredictionsConfig

Bases: ConfigBase

Config for running prediction averaging logic.

Source code in kelp/nn/inference/average_predictions.py (lines 22–37)
class AveragePredictionsConfig(ConfigBase):
    """Config for running prediction averaging logic."""

    # Directories with single-model (per-fold) prediction rasters to average.
    predictions_dirs: List[Path]
    # Root directory; each run writes into a timestamped sub-directory of it.
    output_dir: Path
    # Averaged values >= this threshold become 1, all others 0.
    decision_threshold: float = 0.5
    # One weight per entry in `predictions_dirs`; a weight of exactly 0.0 skips that directory.
    weights: List[float]
    # Whether to render preview images of the submission after it is created.
    preview_submission: bool = False
    # Test data location; required when `preview_submission` is True.
    test_data_dir: Optional[Path] = None
    # Number of samples to render when previewing the submission.
    preview_first_n: int = 10

    @model_validator(mode="before")
    @classmethod
    def validate_cfg(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """
        Ensure `test_data_dir` is set whenever a submission preview is requested.

        Args:
            values: The raw (not yet validated) input values.

        Returns:
            The unchanged input values.

        Raises:
            ValueError: If `preview_submission` is truthy but `test_data_dir` is missing or None.

        """
        # In "before" mode the raw input is not guaranteed to be a dict (e.g. a model instance),
        # and optional fields may be absent entirely, so never index keys directly.
        if not isinstance(values, dict):
            return values
        if values.get("preview_submission", False) and values.get("test_data_dir") is None:
            raise ValueError("Please provide test_data_dir param if running submission preview!")
        return values

kelp.nn.inference.average_predictions.average_predictions

Average predictions given a list of directories with predictions from single models.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `preds_dirs` | `List[Path]` | The list of directories with predictions from single models. | required |
| `output_dir` | `Path` | The output directory. | required |
| `weights` | `List[float]` | The list of weights for each fold (prediction directory). | required |
| `decision_threshold` | `float` | The final decision threshold. | `0.5` |
Source code in kelp/nn/inference/average_predictions.py (lines 62–110)
def average_predictions(
    preds_dirs: List[Path],
    output_dir: Path,
    weights: List[float],
    decision_threshold: float = 0.5,
) -> None:
    """
    Average predictions given a list of directories with predictions from single models.

    For every `*.tif` file found in the prediction directories, band 1 is read, multiplied
    by the directory's weight and summed per file name. The sum is divided by the total
    weight of the directories that contained that file. The result is thresholded with
    `decision_threshold` and written to `output_dir` under the same file name as a 0/1
    `uint8` raster. Directories whose weight is exactly 0.0 are skipped.

    Args:
        preds_dirs: The list of directories with predictions from single models.
        output_dir: The output directory.
        weights: The list of weights for each fold (prediction directory).
        decision_threshold: The final decision threshold.

    Raises:
        ValueError: If the number of weights does not match the number of prediction dirs.

    """
    if len(weights) != len(preds_dirs):
        raise ValueError("Number of weights must match the number prediction dirs!")

    output_dir.mkdir(parents=True, exist_ok=True)
    # Per file name: running weighted sum of predictions, and the sum of weights that
    # actually contributed to it (files may be missing from some directories).
    weighted_sums: Dict[str, np.ndarray] = {}  # type: ignore[type-arg]
    weight_sums: Dict[str, float] = {}

    for preds_dir, weight in zip(preds_dirs, weights):
        if weight == 0.0:
            _logger.info(f"Weight for {preds_dir.name} == 0.0. Skipping this fold.")
            continue
        for pred_file in tqdm(
            sorted(preds_dir.glob("*.tif")),
            desc=f"Processing files for {preds_dir.name}, {weight=}",
        ):
            file_name = pred_file.name
            with rasterio.open(pred_file) as src:
                pred_array = src.read(1) * weight
            if file_name not in weighted_sums:
                # Start both accumulators at zero; the first fold is then added exactly once,
                # like every other fold (previously its weight was counted twice).
                weighted_sums[file_name] = np.zeros_like(pred_array, dtype=np.float32)
                weight_sums[file_name] = 0.0
            weighted_sums[file_name] += pred_array
            weight_sums[file_name] += weight

    for file_name, weighted_sum in tqdm(weighted_sums.items(), desc="Saving predictions"):
        averaged = weighted_sum / weight_sums[file_name]
        mask = np.where(averaged >= decision_threshold, 1, 0).astype(np.uint8)
        with rasterio.open(output_dir / file_name, "w", **META) as dst:
            dst.write(mask, 1)

kelp.nn.inference.average_predictions.main

Main entrypoint for averaging the predictions and creating a submission file.

Source code in kelp/nn/inference/average_predictions.py (lines 113–138)
def main() -> None:
    """
    Main entrypoint for averaging the predictions and creating a submission file.

    Parses the CLI config and creates a run directory `<output_dir>/<UTC ISO timestamp>/`.
    It writes the dumped config to `predict_config.yaml` in that directory and stores the
    averaged masks under its `predictions/` sub-directory. It then calls
    `create_submission_tar` with the run directory as output. Optionally, it renders
    previews into `previews/`.
    """
    from datetime import timezone

    cfg = parse_args()
    # Naive ISO-8601 UTC timestamp, same string format as the deprecated (Python 3.12+)
    # `datetime.utcnow().isoformat()`.
    now = datetime.now(timezone.utc).replace(tzinfo=None).isoformat()
    out_dir = cfg.output_dir / now
    preds_dir = out_dir / "predictions"
    # exist_ok=False: never overwrite the results of another run.
    preds_dir.mkdir(exist_ok=False, parents=True)
    avg_preds_config = cfg.model_dump(mode="json")
    (out_dir / "predict_config.yaml").write_text(yaml.dump(avg_preds_config))
    average_predictions(
        preds_dirs=cfg.predictions_dirs,
        output_dir=preds_dir,
        weights=cfg.weights,
        decision_threshold=cfg.decision_threshold,
    )
    create_submission_tar(
        preds_dir=preds_dir,
        output_dir=out_dir,
    )
    if cfg.preview_submission:
        plot_first_n_samples(
            data_dir=cfg.test_data_dir,  # type: ignore[arg-type]
            submission_dir=out_dir,
            output_dir=out_dir / "previews",
            n=cfg.preview_first_n,
        )

kelp.nn.inference.average_predictions.parse_args

Parse command line arguments.

Returns: An instance of AveragePredictionsConfig.

Source code in kelp/nn/inference/average_predictions.py (lines 40–59)
def parse_args() -> AveragePredictionsConfig:
    """
    Parse command line arguments.

    Also logs the resulting config and creates `output_dir` if it does not exist yet.

    Returns: An instance of AveragePredictionsConfig.

    """
    parser = argparse.ArgumentParser()
    # nargs="+" (not "*"): an empty list would silently produce an empty submission.
    parser.add_argument(
        "--predictions_dirs",
        nargs="+",
        required=True,
        help="Directories with single-model predictions to average.",
    )
    # type=float so malformed weights are rejected by argparse with a clear message.
    parser.add_argument(
        "--weights",
        nargs="+",
        type=float,
        required=True,
        help="One weight per prediction directory; 0.0 skips that directory.",
    )
    parser.add_argument("--output_dir", type=str, required=True)
    parser.add_argument("--decision_threshold", type=float, default=0.5)
    parser.add_argument("--preview_submission", action="store_true")
    parser.add_argument("--test_data_dir", type=str)
    parser.add_argument("--preview_first_n", type=int, default=10)
    args = parser.parse_args()
    cfg = AveragePredictionsConfig(**vars(args))
    cfg.log_self()
    cfg.output_dir.mkdir(exist_ok=True, parents=True)
    return cfg