Skip to content

calculate_band_stats

Band stats calculation logic.

kelp.data_prep.calculate_band_stats.StatisticsCalculationConfig

Bases: ConfigBase

A config class for running statistics calculations on the training dataset.

Source code in kelp/data_prep/calculate_band_stats.py
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
class StatisticsCalculationConfig(ConfigBase):
    """Configuration for running band statistics calculation on the training dataset."""

    # Directory searched recursively for "*_satellite.tif" files.
    data_dir: Path
    # Directory handed to the statistics calculation as its output location.
    output_dir: Path
    mask_using_qa: bool = False
    mask_using_water_mask: bool = False
    # When set, `fill_value` resolves to NaN instead of 0.0.
    fill_missing_pixels_with_torch_nan: bool = False

    @property
    def file_paths(self) -> List[Path]:
        """Sorted paths of all ``*_satellite.tif`` files found under ``data_dir``."""
        return sorted(self.data_dir.rglob("*_satellite.tif"))

    @property
    def fill_value(self) -> float:
        """Fill value for corrupted pixels: NaN when requested, otherwise 0.0."""
        if self.fill_missing_pixels_with_torch_nan:
            return torch.nan  # type: ignore[no-any-return]
        return 0.0

kelp.data_prep.calculate_band_stats.StatisticsCalculationConfig.file_paths: List[Path] property

List of file paths with satellite images.

kelp.data_prep.calculate_band_stats.StatisticsCalculationConfig.fill_value: float property

Resolved fill value for masking corrupted pixels.

kelp.data_prep.calculate_band_stats.calculate_band_statistics

Runs statistics calculation for specified images.

Parameters:

Name Type Description Default
image_paths List[Path]

The input image paths.

required
output_dir Path

The output directory.

required
mask_using_qa bool

A flag indicating whether the corrupted pixels should be masked using the QA band.

False
mask_using_water_mask bool

A flag indicating whether the corrupted pixels should be masked using the water mask.

False
fill_value float

The fill value to use for corrupted pixels.

0
Source code in kelp/data_prep/calculate_band_stats.py
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
@torch.inference_mode()
def calculate_band_statistics(
    image_paths: List[Path],
    output_dir: Path,
    mask_using_qa: bool = False,
    mask_using_water_mask: bool = False,
    fill_value: float = 0,
) -> Dict[str, Dict[str, float]]:
    """
    Runs statistics calculation for specified images.

    Every image is read with rasterio, pixels equal to -32768 are replaced with ``fill_value``,
    the spectral index bands are appended and per band running statistics are accumulated.
    NaN pixels are ignored by every statistic: they do not influence min/max, sums, the pixel
    count used for mean/std, or the quantiles. The resulting statistics are logged and written
    as a JSON file to ``output_dir``.

    Args:
        image_paths: The input image paths.
        output_dir: The output directory. Created if it does not exist.
        mask_using_qa: A flag indicating whether the corrupted pixels should be masked using QA band.
        mask_using_water_mask: A flag indicating whether the corrupted pixels should be masked using Water Mask.
        fill_value: The fill value to use for corrupted pixels.

    Returns: A dictionary with per band statistics. Each band maps to a dictionary with
        ``mean``, ``std``, ``min``, ``max``, ``q01`` and ``q99`` keys. ``q01`` and ``q99`` are
        the averages of the per-image 1st and 99th percentiles.

    Raises:
        ValueError: If ``image_paths`` is empty or an image does not yield the expected number of bands.

    """
    # Fail early with a clear message instead of an opaque torch.stack error later on
    if not image_paths:
        raise ValueError("Cannot calculate band statistics: no image paths were provided")

    # Build the transform that appends DEMWM and all spectral index bands, placed on DEVICE
    transform = K.AugmentationSequential(
        AppendDEMWM(  # type: ignore
            index_dem=BAND_INDEX_LOOKUP["DEM"],
            index_qa=BAND_INDEX_LOOKUP["QA"],
        ),
        *[
            append_index_transform(
                index_swir=BAND_INDEX_LOOKUP["SWIR"],
                index_nir=BAND_INDEX_LOOKUP["NIR"],
                index_red=BAND_INDEX_LOOKUP["R"],
                index_green=BAND_INDEX_LOOKUP["G"],
                index_blue=BAND_INDEX_LOOKUP["B"],
                index_dem=BAND_INDEX_LOOKUP["DEM"],
                index_qa=BAND_INDEX_LOOKUP["QA"],
                index_water_mask=BAND_INDEX_LOOKUP["DEMWM"],
                # Masking is disabled for indices whose name ends with "WM"
                mask_using_qa=False if index_name.endswith("WM") else mask_using_qa,
                mask_using_water_mask=False if index_name.endswith("WM") else mask_using_water_mask,
                fill_val=torch.nan,
            )
            for index_name, append_index_transform in SPECTRAL_INDEX_LOOKUP.items()
            if index_name != "DEMWM"
        ],
        data_keys=["input"],
    ).to(DEVICE)

    # Initialize statistics arrays
    band_names = BASE_BANDS + [index_name for index_name in SPECTRAL_INDEX_LOOKUP.keys() if index_name != "DEMWM"]
    num_bands = len(band_names)
    min_per_band = torch.full((num_bands,), float("inf")).to(DEVICE)
    max_per_band = torch.full((num_bands,), float("-inf")).to(DEVICE)
    sum_per_band = torch.zeros(num_bands).to(DEVICE)
    sum_sq_per_band = torch.zeros(num_bands).to(DEVICE)
    # Per band count of non-NaN pixels; nansum skips NaNs, so the divisor has to skip them too
    valid_pixels_per_band = torch.zeros(num_bands, dtype=torch.long).to(DEVICE)
    q01_items = []
    q99_items = []

    for image_path in tqdm(image_paths, desc="Calculating band statistics"):
        # Open the image and convert to numpy array
        src: rasterio.DatasetReader
        with rasterio.open(image_path) as src:
            image_arr = src.read()

        # Convert image to PyTorch tensor and add the batch dimension expected by the transform
        image = torch.from_numpy(image_arr).float().to(DEVICE).unsqueeze(0)
        # Mask missing pixels
        image = torch.where(image == -32768.0, fill_value, image)

        # Drop only the batch dimension so that images with height or width of 1 keep their shape
        image = transform(image).squeeze(0)

        # Assuming the image has shape (num_bands, height, width)
        if image.shape[0] != num_bands:
            raise ValueError(f"Image at {image_path} does not have {num_bands} bands")

        nan_mask = torch.isnan(image)

        # Update min and max. amin/amax propagate NaN, so NaN pixels are replaced with the
        # neutral element of each reduction; a band that is entirely NaN leaves the running
        # value untouched.
        current_image_min = torch.amin(image.masked_fill(nan_mask, float("inf")), dim=(1, 2))
        current_image_max = torch.amax(image.masked_fill(nan_mask, float("-inf")), dim=(1, 2))
        min_per_band = torch.minimum(min_per_band, current_image_min)
        max_per_band = torch.maximum(max_per_band, current_image_max)

        # Update sum and sum of squares for mean and std calculation
        sum_per_band += torch.nansum(image, dim=(1, 2))
        sum_sq_per_band += torch.nansum(image**2, dim=(1, 2))

        # Update per band count of valid (non-NaN) pixels
        valid_pixels_per_band += (~nan_mask).sum(dim=(1, 2))

        # Append quantile values
        flat_image = image.reshape(num_bands, -1)
        q01_items.append(torch.nanquantile(flat_image, 0.01, dim=1))
        q99_items.append(torch.nanquantile(flat_image, 0.99, dim=1))

    # Calculate mean and standard deviation over valid pixels only
    mean_per_band = sum_per_band / valid_pixels_per_band
    variance_per_band = sum_sq_per_band / valid_pixels_per_band - mean_per_band**2
    # Rounding errors can push the variance slightly below zero, which would turn std into NaN
    std_per_band = torch.sqrt(torch.clamp(variance_per_band, min=0.0))
    mean_q01_per_band = torch.nanmean(torch.stack(q01_items), dim=0)
    mean_q99_per_band = torch.nanmean(torch.stack(q99_items), dim=0)

    stats = {
        band_name: {
            "mean": mean_per_band[idx].item(),
            "std": std_per_band[idx].item(),
            "min": min_per_band[idx].item(),
            "max": max_per_band[idx].item(),
            "q01": mean_q01_per_band[idx].item(),
            "q99": mean_q99_per_band[idx].item(),
        }
        for idx, band_name in enumerate(band_names)
    }

    # Adjust stats for binary band
    for band, band_stats in stats.items():
        if band.endswith("WM") or band == "QA":
            band_stats["min"] = 0.0
            band_stats["max"] = 1.0
            band_stats["mean"] = 0.0
            band_stats["std"] = 1.0
            band_stats["q01"] = 0.0
            band_stats["q99"] = 1.0

    stats_str = json.dumps(stats, indent=4)
    _logger.info("Per band statistics calculated. Review and adjust!")
    _logger.info(stats_str)
    now = datetime.utcnow().strftime("%Y-%m-%dT%H:%M:%S")
    # Make sure the destination exists before writing the stats file
    output_dir.mkdir(parents=True, exist_ok=True)
    (output_dir / f"{now}-stats-{fill_value=}-{mask_using_qa=}-{mask_using_water_mask=}.json").write_text(stats_str)

    return stats

kelp.data_prep.calculate_band_stats.main

Main entry point for band statistics calculation.

Source code in kelp/data_prep/calculate_band_stats.py
218
219
220
221
222
223
224
225
226
227
228
229
def main() -> None:
    """
    Entry point: parse the command line arguments and run band statistics calculation.
    """
    config = parse_args()
    calculate_band_statistics(
        config.file_paths,
        config.output_dir,
        mask_using_qa=config.mask_using_qa,
        mask_using_water_mask=config.mask_using_water_mask,
        fill_value=config.fill_value,
    )

kelp.data_prep.calculate_band_stats.parse_args

Parse command line arguments.

Returns: An instance of StatisticsCalculationConfig.

Source code in kelp/data_prep/calculate_band_stats.py
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
def parse_args() -> StatisticsCalculationConfig:
    """
    Build the statistics calculation config from command line arguments.

    Returns: An instance of StatisticsCalculationConfig.

    """
    required_options = ("--data_dir", "--output_dir")
    boolean_flags = (
        "--mask_using_qa",
        "--mask_using_water_mask",
        "--fill_missing_pixels_with_torch_nan",
    )

    parser = argparse.ArgumentParser()
    # Mandatory string options
    for option in required_options:
        parser.add_argument(option, type=str, required=True)
    # Optional switches that default to False
    for flag in boolean_flags:
        parser.add_argument(flag, action="store_true")

    namespace = parser.parse_args()
    cfg = StatisticsCalculationConfig(**vars(namespace))
    cfg.log_self()
    return cfg