Skip to content

predict

Single model prediction logic.

kelp.xgb.inference.predict.PredictConfig

Bases: ConfigBase

XGBoost prediction config

Source code in kelp/xgb/inference/predict.py
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
class PredictConfig(ConfigBase):
    """
    XGBoost prediction config.

    ``model_path`` and ``original_training_config_fp`` are derived from ``run_dir``
    by :meth:`validate_inputs` before field validation runs, so any value passed
    for them explicitly is overwritten.
    """

    # Clear pydantic's protected namespaces so that the `model_path` field
    # (which starts with `model_`) does not trigger a naming-conflict warning.
    model_config = ConfigDict(protected_namespaces=())

    data_dir: Path
    original_training_config_fp: Path
    model_path: Path
    run_dir: Path
    output_dir: Path

    @model_validator(mode="before")
    @classmethod
    def validate_inputs(cls, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Resolves the model and training config locations from `run_dir`.

        The artifacts are looked up directly in `run_dir` (when it contains a `model` entry)
        or in its `artifacts` sub-folder.

        Args:
            data: The raw input data passed to the config.

        Returns: A copy of the input data with `model_path` and `original_training_config_fp` set.

        Raises:
            ValueError: If neither `run_dir / "model"` nor `run_dir / "artifacts"` exists.

        """
        # Leave non-dict input or a missing `run_dir` to regular field validation,
        # which reports a proper validation error instead of a bare KeyError.
        if not isinstance(data, dict) or "run_dir" not in data:
            return data

        run_dir = Path(data["run_dir"])
        if (run_dir / "model").exists():
            artifacts_dir = run_dir
        elif (run_dir / "artifacts").exists():
            artifacts_dir = run_dir / "artifacts"
        else:
            raise ValueError(
                f"Could not find a model dir or an artifacts folder in the specified run_dir: {run_dir}"
            )

        # Build a new mapping so the caller's dict is not mutated as a side effect.
        return {
            **data,
            "model_path": artifacts_dir / "model",
            "original_training_config_fp": artifacts_dir / "config.yaml",
        }

    @property
    def training_config(self) -> TrainConfig:
        """
        Loads the training config stored next to the model.

        The YAML file is re-read and re-parsed on every access.

        Returns: An instance of TrainConfig built from `original_training_config_fp`.

        """
        # Explicit encoding keeps parsing independent of the platform's default locale.
        with open(self.original_training_config_fp, "r", encoding="utf-8") as f:
            cfg = TrainConfig(**yaml.safe_load(f))
        return cfg

kelp.xgb.inference.predict.build_prediction_arg_parser

Builds the base parser for prediction steps.

Returns: An instance of `argparse.ArgumentParser`.

Source code in kelp/xgb/inference/predict.py
58
59
60
61
62
63
64
65
66
67
68
69
def build_prediction_arg_parser() -> argparse.ArgumentParser:
    """
    Creates the command line parser shared by the prediction steps.

    The parser exposes three mandatory string options:
    `--data_dir`, `--output_dir` and `--run_dir`.

    Returns: An instance of :class:`argparse.ArgumentParser`.

    """
    arg_parser = argparse.ArgumentParser()
    # All options share the same settings, so register them in a single pass.
    for flag in ("--data_dir", "--output_dir", "--run_dir"):
        arg_parser.add_argument(flag, type=str, required=True)
    return arg_parser

kelp.xgb.inference.predict.main

Main entrypoint for running XGBoost inference.

Source code in kelp/xgb/inference/predict.py
174
175
176
177
178
179
180
181
182
def main() -> None:
    """Parses the command line arguments and runs XGBoost inference with them."""
    config = parse_args()
    data_dir = config.data_dir
    output_dir = config.output_dir
    model_dir = config.model_path
    # The spectral indices are taken from the training config stored with the model.
    spectral_indices = config.training_config.spectral_indices
    run_prediction(
        data_dir=data_dir,
        output_dir=output_dir,
        model_dir=model_dir,
        spectral_indices=spectral_indices,
    )

kelp.xgb.inference.predict.parse_args

Parse command line arguments.

Returns: An instance of PredictConfig.

Source code in kelp/xgb/inference/predict.py
72
73
74
75
76
77
78
79
80
81
82
83
84
def parse_args() -> PredictConfig:
    """
    Builds the prediction config from the command line arguments.

    Besides parsing, this logs the resulting config and creates the output
    directory (including missing parents) if it does not exist yet.

    Returns: An instance of PredictConfig.

    """
    cli_namespace = build_prediction_arg_parser().parse_args()
    config = PredictConfig(**vars(cli_namespace))
    config.log_self()
    config.output_dir.mkdir(parents=True, exist_ok=True)
    return config

kelp.xgb.inference.predict.predict

Runs XGBoost prediction on files in the specified input directory.

Parameters:

Name Type Description Default
input_dir Path

The input directory.

required
model XGBClassifier

The XGBoost model.

required
spectral_indices List[str]

The list of spectral indices to append to the input image.

required
output_dir Path

The output directory.

required
Source code in kelp/xgb/inference/predict.py
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
def predict(input_dir: Path, model: XGBClassifier, spectral_indices: List[str], output_dir: Path) -> None:
    """
    Runs XGBoost prediction on files in the specified input directory.

    Every `*.tif` file found directly in `input_dir` is read, passed through the model
    and written to `output_dir` as `<tile_id>_kelp.tif`, where the tile ID is the part
    of the file name before the first underscore.

    Args:
        input_dir: The input directory.
        model: The XGBoost model.
        spectral_indices: The list of spectral indices to append to the input image.
        output_dir: The output directory. It is created if it does not exist.

    """
    fps = sorted(input_dir.glob("*.tif"))
    transforms = build_append_index_transforms(spectral_indices)
    # The feature column names are identical for every tile, so build them once.
    columns = list(consts.data.ORIGINAL_BANDS) + spectral_indices
    # Opening a raster for writing fails when the parent directory is missing.
    output_dir.mkdir(parents=True, exist_ok=True)
    for fp in tqdm(fps, "Predicting"):
        tile_id = fp.name.split("_")[0]
        with rasterio.open(fp) as src:
            input_arr = src.read()
        prediction = predict_on_single_image(
            model=model,
            x=input_arr,
            transforms=transforms,
            columns=columns,
        )
        dest: DatasetWriter
        with rasterio.open(output_dir / f"{tile_id}_kelp.tif", "w", **META) as dest:
            # The predicted mask is written to the first band of the output raster.
            dest.write(prediction, 1)

kelp.xgb.inference.predict.predict_on_single_image

Runs inference on a single satellite image using specified XGBoost model.

Parameters:

Name Type Description Default
model XGBClassifier

The XGBoost model.

required
x ndarray

The array representing the satellite image.

required
transforms Callable[[Tensor], Tensor]

A set of transforms to apply to the input image.

required
columns List[str]

The column names for the input array.

required
decision_threshold float

The decision threshold.

0.5
Source code in kelp/xgb/inference/predict.py
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
@torch.inference_mode()
def predict_on_single_image(
    model: XGBClassifier,
    x: np.ndarray,  # type: ignore[type-arg]
    transforms: Callable[[Tensor], Tensor],
    columns: List[str],
    decision_threshold: float = 0.5,
) -> np.ndarray:  # type: ignore[type-arg]
    """
    Runs inference on a single satellite image using specified XGBoost model.

    Args:
        model: The XGBoost model.
        x: The array representing the satellite image. It must have at least three
            dimensions; the second and third are used as the shape of the returned mask
            (presumably bands x height x width - confirm against the caller).
        transforms: A set of transforms to apply to the input image.
        columns: The column names for the input array.
        decision_threshold: The decision threshold.

    Returns: A numpy array with predicted mask.

    """
    # Add a leading batch dimension before applying the transforms.
    tensor = torch.tensor(x, dtype=torch.float32, device=DEVICE).unsqueeze(0)
    # Flatten the spatial dimensions and transpose so that each row is one pixel and each
    # column one feature. Only the batch axis is squeezed: a bare `squeeze()` would also
    # drop a single-channel or single-pixel axis and scramble the row/column layout.
    tensor = torch.flatten(transforms(tensor), start_dim=2).squeeze(0).T
    # NaN values are replaced with -32768.0 before being passed to the model.
    df = pd.DataFrame(tensor.detach().cpu().numpy(), columns=columns).replace({np.nan: -32768.0})
    prediction = model.predict_proba(df)
    # Threshold the second probability column to get a binary 0/1 label per pixel.
    prediction = np.where(prediction[:, 1] >= decision_threshold, 1, 0)
    # Restore the spatial layout of the input image.
    prediction = prediction.reshape(x.shape[1], x.shape[2])
    return prediction  # type: ignore[no-any-return]

kelp.xgb.inference.predict.run_prediction

Runs the XGBoost inference on specified data directory.

Parameters:

Name Type Description Default
data_dir Path

The data directory.

required
output_dir Path

The output directory.

required
model_dir Path

The model directory.

required
spectral_indices List[str]

The spectral indices to append to the input image.

required
Source code in kelp/xgb/inference/predict.py
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
def run_prediction(
    data_dir: Path,
    output_dir: Path,
    model_dir: Path,
    spectral_indices: List[str],
) -> None:
    """
    Loads the XGBoost model and runs inference over the given data directory.

    Args:
        data_dir: The data directory.
        output_dir: The output directory.
        model_dir: The model directory.
        spectral_indices: The spectral indices to append to the input image.

    """
    xgb_model = load_model(model_path=model_dir)
    predict(input_dir=data_dir, model=xgb_model, spectral_indices=spectral_indices, output_dir=output_dir)