Skip to content

eda

EDA logic.

kelp.data_prep.eda.SatelliteImageStats

Bases: BaseModel

A data class for holding stats for a single satellite image.

Source code in kelp/data_prep/eda.py
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
class SatelliteImageStats(BaseModel):
    """
    A data class for holding stats for a single satellite image.

    Instances are created by `calculate_stats`, one per tile. All `*_pixels_pct` float fields
    are fractions (pixel count divided by the total pixel count of a single band), they are not
    multiplied by 100. The kelp related fields are `None` for tiles from the "test" split,
    because `calculate_stats` does not read a mask for that split.
    """

    # Tile identification. `aoi_id` is None when the AOI ID passed to `calculate_stats` is NaN.
    tile_id: str
    aoi_id: Optional[int] = None
    split: str

    # Kelp mask statistics:
    #   has_kelp             - True when the mask sums to more than zero
    #   non_kelp_pixels      - number of mask elements minus `kelp_pixels`
    #   kelp_pixels          - sum of the mask values
    #   kelp_pixels_pct      - `kelp_pixels` divided by the total pixel count
    #   high_kelp_pixels_pct - boolean flag (despite the `_pct` suffix): True when
    #                          `kelp_pixels_pct` exceeds HIGH_KELP_PCT_THRESHOLD
    has_kelp: Optional[bool] = None
    non_kelp_pixels: Optional[int] = None
    kelp_pixels: Optional[int] = None
    kelp_pixels_pct: Optional[float] = None
    high_kelp_pixels_pct: Optional[bool] = None

    # DEM band statistics. "NaN" pixels are the ones with DEM value below zero.
    dem_nan_pixels: int
    dem_has_nans: bool
    dem_nan_pixels_pct: Optional[float] = None

    # Pixels with DEM value exactly equal to zero.
    dem_zero_pixels: int
    dem_zero_pixels_pct: Optional[float] = None

    # Water pixels are the ones with DEM value lower than or equal to zero.
    # `almost_all_water` is True when `water_pixels_pct` exceeds HIGH_DEM_ZERO_OR_NAN_PCT_THRESHOLD.
    water_pixels: Optional[int] = None
    water_pixels_pct: Optional[float] = None
    almost_all_water: bool

    # QA band statistics:
    #   qa_corrupted_pixels       - sum of the QA band values
    #   qa_ok                     - True when that sum is equal to zero
    #   qa_corrupted_pixels_pct   - `qa_corrupted_pixels` divided by the total pixel count
    #   high_corrupted_pixels_pct - boolean flag (despite the `_pct` suffix): True when
    #                               `qa_corrupted_pixels_pct` exceeds HIGH_CORRUPTION_PCT_THRESHOLD
    qa_corrupted_pixels: Optional[int] = None
    qa_ok: bool
    qa_corrupted_pixels_pct: Optional[float] = None
    high_corrupted_pixels_pct: Optional[bool] = None

kelp.data_prep.eda.build_tile_id_aoi_id_and_split_tuples

Builds a list of (tile ID, AOI ID, split) tuples from the specified metadata dataframe.

Parameters:

Name Type Description Default
metadata DataFrame

The metadata dataframe.

required
Source code in kelp/data_prep/eda.py
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
def build_tile_id_aoi_id_and_split_tuples(metadata: pd.DataFrame) -> List[Tuple[str, int, str]]:
    """
    Builds a list of tile ID, AOI ID and split tuples from the specified metadata dataframe.

    Rows whose `type` equals "kelp" are skipped. The split is "train" when the `in_train`
    value of a row is truthy and "test" otherwise. The passed dataframe is not modified.

    Args:
        metadata: The metadata dataframe. It must contain the `tile_id`, `aoi_id`,
            `in_train` and `type` columns.

    Returns: A list of tile ID, AOI ID and split tuples, in the dataframe row order.
        The AOI ID is passed through unchanged, so it is NaN for rows without an AOI.

    """
    # Filtering produces a new frame, so the caller's dataframe is never written to.
    image_rows = metadata[metadata["type"] != "kelp"]
    # A plain truthiness check per value keeps the semantics of `"train" if x else "test"`.
    splits = ["train" if in_train else "test" for in_train in image_rows["in_train"]]
    return list(zip(image_rows["tile_id"], image_rows["aoi_id"], splits))

kelp.data_prep.eda.calculate_stats

Calculates statistics for a single tile.

Parameters:

Name Type Description Default
tile_id_aoi_id_split_tuple Tuple[str, int, str]

A tuple with tile ID, AOI ID and split name.

required
data_dir Path

The path to the data directory.

required
Source code in kelp/data_prep/eda.py
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
def calculate_stats(tile_id_aoi_id_split_tuple: Tuple[str, int, str], data_dir: Path) -> SatelliteImageStats:
    """
    Computes DEM, water, QA and kelp mask statistics for a single tile.

    The satellite image is read from `<data_dir>/<split>/images/<tile_id>_satellite.tif`.
    For every split other than "test" the kelp mask is read from
    `<data_dir>/<split>/masks/<tile_id>_kelp.tif` as well. For the "test" split
    all kelp related stats are set to `None`.

    Args:
        tile_id_aoi_id_split_tuple: A tuple with tile ID, AOI ID and split name.
        data_dir: The path to the data directory.

    Returns: An instance of SatelliteImageStats class.

    """
    tile_id, aoi_id, split = tile_id_aoi_id_split_tuple
    split_dir = data_dir / split

    dataset: rasterio.DatasetReader
    with rasterio.open(split_dir / "images" / f"{tile_id}_satellite.tif") as dataset:
        bands = dataset.read()

    dem = bands[6]
    qa = bands[5]
    # Every percentage below uses the pixel count of a single band as the denominator
    total_pixels = np.prod(qa.shape).item()

    # DEM band: values below zero are counted as NaNs, values at or below zero as water
    negative_dem_count = (dem < 0).sum()
    zero_dem_count = (dem == 0).sum()
    non_positive_dem_count = (dem <= 0).sum()
    water_fraction = non_positive_dem_count / total_pixels

    # QA band: the sum of the band is used as the corrupted pixel count
    qa_sum = qa.sum()
    corrupted_pixels = qa_sum.item()
    corrupted_fraction = corrupted_pixels / total_pixels

    # Kelp stats stay empty unless a mask is read for the tile
    has_kelp = None
    kelp_count = None
    background_count = None
    kelp_fraction = None
    kelp_fraction_is_high = None
    if split != "test":
        with rasterio.open(split_dir / "masks" / f"{tile_id}_kelp.tif") as dataset:
            mask: np.ndarray = dataset.read()  # type: ignore[type-arg]
        kelp_count = mask.sum()
        background_count = np.prod(mask.shape) - kelp_count
        has_kelp = kelp_count > 0
        kelp_fraction = kelp_count.item() / total_pixels
        kelp_fraction_is_high = kelp_fraction > HIGH_KELP_PCT_THRESHOLD

    return SatelliteImageStats(
        tile_id=tile_id,
        aoi_id=None if np.isnan(aoi_id) else aoi_id,
        split=split,
        # kelp mask
        has_kelp=has_kelp,
        non_kelp_pixels=background_count,
        kelp_pixels=kelp_count,
        kelp_pixels_pct=kelp_fraction,
        high_kelp_pixels_pct=kelp_fraction_is_high,
        # DEM band
        dem_nan_pixels=negative_dem_count,
        dem_has_nans=negative_dem_count > 0,
        dem_nan_pixels_pct=negative_dem_count / total_pixels,
        dem_zero_pixels=zero_dem_count,
        dem_zero_pixels_pct=zero_dem_count / total_pixels,
        water_pixels=non_positive_dem_count,
        water_pixels_pct=water_fraction,
        almost_all_water=water_fraction > HIGH_DEM_ZERO_OR_NAN_PCT_THRESHOLD,
        # QA band
        qa_corrupted_pixels=corrupted_pixels,
        qa_ok=qa_sum == 0,
        qa_corrupted_pixels_pct=corrupted_fraction,
        high_corrupted_pixels_pct=corrupted_fraction > HIGH_CORRUPTION_PCT_THRESHOLD,
    )

kelp.data_prep.eda.extract_stats

Runs stats extraction for images in the specified directory in parallel using Dask.

Parameters:

Name Type Description Default
data_dir Path

The path to the directory.

required
records List[Tuple[str, int, str]]

The list of tuples with tile ID, AOI ID and split per image.

required
Source code in kelp/data_prep/eda.py
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
@timed
def extract_stats(data_dir: Path, records: List[Tuple[str, int, str]]) -> List[SatelliteImageStats]:
    """
    Computes stats for all specified tiles in parallel using a Dask bag.

    Args:
        data_dir: The path to the data directory, forwarded to `calculate_stats`.
        records: The list of tuples with tile ID, AOI ID and split per image.

    Returns: A list of SatelliteImageStats instances, one per record.

    """
    tiles_bag = dask.bag.from_sequence(records)
    # Each record becomes the first positional argument of `calculate_stats`
    stats_bag = tiles_bag.map(calculate_stats, data_dir=data_dir)
    return stats_bag.compute()  # type: ignore[no-any-return]

kelp.data_prep.eda.main

Main entry point for performing EDA.

Source code in kelp/data_prep/eda.py
325
326
327
328
329
330
331
332
333
334
335
336
def main() -> None:
    """
    Main entry point for performing EDA.

    Reads the metadata parquet file, computes per-tile stats on a local Dask cluster
    and passes the resulting dataframe to `plot_stats`.
    """
    args = parse_args()
    metadata_df = pd.read_parquet(args.metadata_fp)
    args.output_dir.mkdir(exist_ok=True, parents=True)
    tile_records = build_tile_id_aoi_id_and_split_tuples(metadata_df)

    with distributed.LocalCluster(n_workers=8, threads_per_worker=1) as cluster:
        with distributed.Client(cluster) as client:
            _logger.info(f"Running dask cluster dashboard on {client.dashboard_link}")
            tile_stats = extract_stats(args.data_dir, tile_records)
            rows = [stats.model_dump() for stats in tile_stats]
            plot_stats(pd.DataFrame(rows), output_dir=args.output_dir)

kelp.data_prep.eda.plot_stats

Plots statistics about the training dataset.

Parameters:

Name Type Description Default
df DataFrame

The dataframe with image statistics.

required
output_dir Path

The output directory, where plots will be saved.

required
Source code in kelp/data_prep/eda.py
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
def _save_count_plot(df: pd.DataFrame, column: str, title: str, xlabel: str, output_fp: Path) -> None:
    """Saves a plot with the number of rows per unique value of the specified column."""
    plt.figure(figsize=(8, 6))
    sns.countplot(x=column, data=df)
    plt.title(title)
    plt.xlabel(xlabel)
    plt.ylabel("Count")
    plt.savefig(output_fp)
    plt.close()


def _save_hist_plot(values: pd.Series, bins: int, title: str, xlabel: str, output_fp: Path) -> None:  # type: ignore[type-arg]
    """Saves a histogram with a KDE curve for the specified values."""
    plt.figure(figsize=(10, 6))
    sns.histplot(values, bins=bins, kde=True)
    plt.title(title)
    plt.xlabel(xlabel)
    plt.ylabel("Frequency")
    plt.savefig(output_fp)
    plt.close()


@timed
def plot_stats(df: pd.DataFrame, output_dir: Path) -> None:
    """
    Saves and plots statistics about the dataset.

    All artifacts are written to the `stats` sub-folder of the output directory:
    the full stats table, descriptive statistics, split / QA / kelp presence proportions
    (as parquet files) and a set of PNG plots. The summaries are also logged.

    Args:
        df: The dataframe with image statistics.
        output_dir: The output directory, where plots will be saved.

    """
    out_dir = output_dir / "stats"
    out_dir.mkdir(exist_ok=True, parents=True)

    df = df.replace({None: np.nan})
    df.to_parquet(out_dir / "dataset_stats.parquet", index=False)

    # Descriptive statistics for numerical columns
    desc_stats = df.describe()
    desc_stats.reset_index(names="stats").to_parquet(out_dir / "desc_stats.parquet")

    # Distribution of data in the train and test splits
    split_distribution = df["split"].value_counts()
    split_distribution.to_frame(name="value_count").to_parquet(out_dir / "split_distribution.parquet")

    # Summary
    _logger.info("desc_stats:")
    _logger.info(desc_stats)
    _logger.info("split_distribution")
    _logger.info(split_distribution)

    # Quality Analysis - Proportion of tiles with QA issues
    qa_issues_proportion = df["qa_ok"].value_counts(normalize=True)
    qa_issues_proportion.to_frame(name="value_count").to_parquet(out_dir / "qa_issues_proportion.parquet")

    # Kelp Presence Analysis - Balance between kelp and non-kelp tiles
    kelp_presence_proportion = df["has_kelp"].value_counts(normalize=True)
    kelp_presence_proportion.to_frame(name="value_count").to_parquet(out_dir / "kelp_presence_proportion.parquet")

    # Results
    _logger.info("qa_issues_proportion:")
    _logger.info(qa_issues_proportion)
    _logger.info("kelp_presence_proportion:")
    _logger.info(kelp_presence_proportion)

    # Correlation analysis
    correlation_matrix = df[["kelp_pixels", "non_kelp_pixels", "qa_corrupted_pixels", "dem_nan_pixels"]].corr()

    # Visualization of the correlation matrix
    plt.figure(figsize=(10, 8))
    sns.heatmap(correlation_matrix, annot=True, cmap="coolwarm", fmt=".2f")
    plt.title("Correlation Matrix")
    plt.savefig(out_dir / "corr_matrix.png")
    plt.close()

    # Distribution of kelp pixels
    _save_hist_plot(
        df["kelp_pixels"],
        bins=50,
        title="Distribution of Kelp pixels",
        xlabel="Number of Kelp pixels",
        output_fp=out_dir / "kelp_pixels_distribution.png",
    )

    # Distribution of dem_nan_pixels
    _save_hist_plot(
        df["dem_nan_pixels"],
        bins=50,
        title="Distribution of DEM NaN pixels",
        xlabel="Number of DEM NaN pixels",
        output_fp=out_dir / "dem_nan_pixels_distribution.png",
    )

    # Image count per split (train/test)
    _save_count_plot(
        df,
        column="split",
        title="Image count per split",
        xlabel="Split",
        output_fp=out_dir / "splits.png",
    )

    # Image count with and without kelp forest class
    _save_count_plot(
        df,
        column="has_kelp",
        title="Image count with and without Kelp Forest",
        xlabel="Has Kelp",
        output_fp=out_dir / "has_kelp.png",
    )

    # Image count with and without QA issues
    _save_count_plot(
        df,
        column="qa_ok",
        title="Image count with and without QA issues",
        xlabel="QA OK",
        output_fp=out_dir / "qa_ok.png",
    )

    # Image count with and without high percentage of corrupted pixels
    _save_count_plot(
        df,
        column="high_corrupted_pixels_pct",
        title="Image count with and without high corrupted pixel percentage",
        xlabel="High corrupted pixel percentage",
        output_fp=out_dir / "qa_corrupted_pixels_pct.png",
    )

    # Image count with and without NaN values in DEM band
    _save_count_plot(
        df,
        column="dem_has_nans",
        title="Image Count with and Without NaN Values in DEM Band",
        xlabel="DEM Has NaNs",
        output_fp=out_dir / "dem_has_nans.png",
    )

    # Image count with and without high percentage of kelp pixels in the mask
    _save_count_plot(
        df,
        column="high_kelp_pixels_pct",
        title="Image count with and without high percent of Kelp pixels",
        xlabel="Mask high kelp pixel percentage",
        output_fp=out_dir / "high_kelp_pixels_pct.png",
    )

    # Image count per AOI ID. The title and the label used to be copied from the high kelp
    # percentage plot - they now describe the plotted column. The misleading file name
    # is kept, so that anything consuming the existing artifact keeps working.
    _save_count_plot(
        df,
        column="aoi_id",
        title="Image count per AOI",
        xlabel="AOI ID",
        output_fp=out_dir / "kelp_high_pct.png",
    )

    # Distribution of the number of images per AOI
    counts = df.groupby("aoi_id").size().reset_index().rename(columns={0: "count"})
    _save_hist_plot(
        counts["count"],
        bins=35,
        title="Distribution of images per AOI",
        xlabel="Number of images per AOI",
        output_fp=out_dir / "aoi_images_distribution.png",
    )

    # Images per AOI without AOIs that have single image
    counts = counts[counts["count"] > 1]
    _save_hist_plot(
        counts["count"],
        bins=35,
        title="Distribution of images per AOI (without singles)",
        xlabel="Number of images per AOI",
        output_fp=out_dir / "aoi_images_distribution_filtered.png",
    )