sahi_dataset_prep

SAHI (Slicing Aided Hyper Inference) dataset preparation logic.

kelp.data_prep.sahi_dataset_prep.SahiDatasetPrepConfig

Bases: ConfigBase

Config class for creating the SAHI dataset.

Source code in kelp/data_prep/sahi_dataset_prep.py
class SahiDatasetPrepConfig(ConfigBase):
    """Config class for creating SAHI dataset"""

    data_dir: Path
    metadata_fp: Path
    output_dir: Path
    image_size: int = 128
    stride: int = 64
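
For reference, a minimal sketch of constructing the config programmatically; the paths below are placeholders, not part of the module:

from pathlib import Path

cfg = SahiDatasetPrepConfig(
    data_dir=Path("data/raw"),  # placeholder paths, adjust to your layout
    metadata_fp=Path("data/metadata.parquet"),
    output_dir=Path("data/sahi"),
    image_size=128,
    stride=64,
)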

kelp.data_prep.sahi_dataset_prep.generate_tiles_from_image

Generates small tiles from the input image using the specified tile size and stride.

Parameters:

Name        Type             Description                      Default
data_dir    Path             The path to the data directory.  required
tile_id     str              The tile ID.                     required
tile_size   Tuple[int, int]  The tile size in pixels.         required
stride      Tuple[int, int]  The tile stride in pixels.       required
output_dir  Path             The output directory.            required
Source code in kelp/data_prep/sahi_dataset_prep.py
def generate_tiles_from_image(
    data_dir: Path,
    tile_id: str,
    tile_size: Tuple[int, int],
    stride: Tuple[int, int],
    output_dir: Path,
) -> List[Tuple[int, int, float, float]]:
    """
    Generates small tiles from the input image using the specified tile size and stride.

    Args:
        data_dir: The path to the data directory.
        tile_id: The tile ID.
        tile_size: The tile size in pixels.
        stride: The tile stride in pixels.
        output_dir: The output directory.

    Returns: A list of tuples with the tile coordinates (i, j), the kelp pixel count, and the kelp pixel percentage.

    """
    records: List[Tuple[int, int, float, float]] = []

    with rasterio.open(data_dir / "images" / f"{tile_id}_satellite.tif") as src:
        for j in range(0, src.height, stride[1]):
            for i in range(0, src.width, stride[0]):
                window = Window(i, j, *tile_size)
                data = src.read(window=window)

                # Check if the edge tile is smaller than expected; data is (bands, height, width)
                if data.shape[1] < tile_size[1] or data.shape[2] < tile_size[0]:
                    # Pad with the nodata value to match the expected tile size
                    padded_data = np.full((src.count, tile_size[1], tile_size[0]), -32768, dtype=data.dtype)
                    padded_data[:, : data.shape[1], : data.shape[2]] = data
                    data = padded_data

                # Save the tile
                output_tile_path = output_dir / "images" / f"{tile_id}_satellite_{i}_{j}.tif"
                with rasterio.open(
                    output_tile_path,
                    "w",
                    driver="GTiff",
                    height=tile_size[1],
                    width=tile_size[0],
                    count=src.count,
                    dtype=data.dtype,
                    crs=src.crs,
                    transform=src.window_transform(window),
                ) as dst:
                    dst.write(data)

    with rasterio.open(data_dir / "masks" / f"{tile_id}_kelp.tif") as src:
        for j in range(0, src.height, stride[1]):
            for i in range(0, src.width, stride[0]):
                window = Window(i, j, *tile_size)
                data = src.read(window=window)

                # Check if the edge tile is smaller than expected; data is (bands, height, width)
                if data.shape[1] < tile_size[1] or data.shape[2] < tile_size[0]:
                    # Pad the mask with zeros so the kelp pixel stats below are not skewed by fill values
                    padded_data = np.zeros((src.count, tile_size[1], tile_size[0]), dtype=data.dtype)
                    padded_data[:, : data.shape[1], : data.shape[2]] = data
                    data = padded_data

                # Save the tile
                output_tile_path = output_dir / "masks" / f"{tile_id}_kelp_{i}_{j}.tif"
                with rasterio.open(
                    output_tile_path,
                    "w",
                    driver="GTiff",
                    height=tile_size[1],
                    width=tile_size[0],
                    count=src.count,
                    dtype=data.dtype,
                    crs=src.crs,
                    transform=src.window_transform(window),
                ) as dst:
                    dst.write(data)

                kelp_pxls: float = float(data.sum())
                kelp_pct: float = kelp_pxls / (tile_size[0] * tile_size[1])
                records.append((i, j, kelp_pxls, kelp_pct))

    return records
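
For illustration, a hedged usage sketch; the directory layout and tile ID below are assumptions, and output_dir must already contain images/ and masks/ subdirectories (prep_sahi_dataset creates them for you):

from pathlib import Path

records = generate_tiles_from_image(
    data_dir=Path("data/raw"),  # expects images/ and masks/ inside
    tile_id="AA498489",         # hypothetical tile ID
    tile_size=(128, 128),
    stride=(64, 64),
    output_dir=Path("data/sahi"),
)
for i, j, kelp_pxls, kelp_pct in records:
    print(f"tile ({i}, {j}): {kelp_pxls:.0f} kelp pixels ({kelp_pct:.2%})")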

kelp.data_prep.sahi_dataset_prep.main

Main entrypoint for generating SAHI dataset.

Source code in kelp/data_prep/sahi_dataset_prep.py
def main() -> None:
    """Main entrypoint for generating SAHI dataset."""
    cfg = parse_args()
    prep_sahi_dataset(**cfg.model_dump())

kelp.data_prep.sahi_dataset_prep.parse_args

Parse command line arguments.

Returns: An instance of SahiDatasetPrepConfig.

Source code in kelp/data_prep/sahi_dataset_prep.py
def parse_args() -> SahiDatasetPrepConfig:
    """
    Parse command line arguments.

    Returns: An instance of SahiDatasetPrepConfig.

    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--data_dir", type=str, required=True)
    parser.add_argument("--metadata_fp", type=str, required=True)
    parser.add_argument("--output_dir", type=str, required=True)
    parser.add_argument("--image_size", type=int, default=128)
    parser.add_argument("--stride", type=int, default=64)
    args = parser.parse_args()
    cfg = SahiDatasetPrepConfig(**vars(args))
    cfg.log_self()
    return cfg
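
Since the flags map one-to-one onto SahiDatasetPrepConfig fields, parse_args can be exercised in isolation by populating sys.argv first; a minimal sketch with placeholder values:

import sys

sys.argv = [
    "sahi_dataset_prep.py",  # program name; argparse skips it
    "--data_dir", "data/raw",  # placeholder paths
    "--metadata_fp", "data/metadata.parquet",
    "--output_dir", "data/sahi",
    "--image_size", "128",
    "--stride", "64",
]
cfg = parse_args()
assert cfg.image_size == 128 and cfg.stride == 64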

kelp.data_prep.sahi_dataset_prep.prep_sahi_dataset

Runs data preparation for SAHI model training.

Parameters:

Name         Type  Description                                    Default
data_dir     Path  The path to the data directory.                required
metadata_fp  Path  The path to the metadata parquet file.         required
output_dir   Path  The path to the output directory.              required
image_size   int   The image size to use for tiles.               required
stride       int   The stride to use for overlap between tiles.   required
Source code in kelp/data_prep/sahi_dataset_prep.py
def prep_sahi_dataset(data_dir: Path, metadata_fp: Path, output_dir: Path, image_size: int, stride: int) -> None:
    """
    Runs data preparation for SAHI model training.

    Args:
        data_dir: The path to the data directory.
        metadata_fp: The path to the metadata parquet file.
        output_dir: The path to the output directory.
        image_size: The image size to use for tiles.
        stride: The stride to use for overlap between tiles.

    """
    df = pd.read_parquet(metadata_fp)
    df = df[df["original_split"] == "train"]
    records: List[Tuple[Any, ...]] = []
    # Create the output directories once, before iterating over the tiles
    (output_dir / "images").mkdir(exist_ok=True, parents=True)
    (output_dir / "masks").mkdir(exist_ok=True, parents=True)
    for _, row in tqdm(df.iterrows(), total=len(df), desc="Processing files"):
        sub_records = generate_tiles_from_image(
            data_dir=data_dir,
            tile_id=row["tile_id"],
            output_dir=output_dir,
            tile_size=(image_size, image_size),
            stride=(stride, stride),
        )
        # generate_tiles_from_image yields (i, j, kelp_pxls, kelp_pct) per tile
        for i, j, kelp_pxls, kelp_pct in sub_records:
            records.append((row["tile_id"], j, i, kelp_pxls, kelp_pct))
    results_df = pd.DataFrame(records, columns=["tile_id", "j", "i", "kelp_pxls", "kelp_pct"])
    results_df = df.merge(results_df, how="inner", on="tile_id")
    results_df.to_parquet(output_dir / "sahi_train_val_test_dataset.parquet")
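
A hedged end-to-end sketch of calling prep_sahi_dataset directly; the paths are placeholders, and the metadata parquet is assumed to contain the "original_split" and "tile_id" columns the function expects:

from pathlib import Path

prep_sahi_dataset(
    data_dir=Path("data/raw"),  # placeholder paths
    metadata_fp=Path("data/metadata.parquet"),
    output_dir=Path("data/sahi"),
    image_size=128,
    stride=64,
)
# Tiles land in data/sahi/images and data/sahi/masks; per-tile kelp stats
# are written to data/sahi/sahi_train_val_test_dataset.parquet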