Skip to content

train_val_test_split

Train, validation and test dataset split logic.

kelp.data_prep.train_val_test_split.TrainTestSplitConfig

Bases: ConfigBase

A config for generating train and test splits.

Source code in kelp/data_prep/train_val_test_split.py
16
17
18
19
20
21
22
23
24
25
26
27
28
29
class TrainTestSplitConfig(ConfigBase):
    """A config for generating train and test splits."""

    # Path to the dataset metadata parquet file.
    dataset_metadata_fp: Path
    # Metadata columns used to build the stratification key.
    stratification_columns: List[str]
    # Either stratified cross-validation folds or a single random split.
    split_strategy: Literal["cross_val", "random"] = "cross_val"
    # Fraction of samples kept for training when split_strategy == "random".
    random_split_train_size: float = 0.95
    seed: int = consts.reproducibility.SEED
    # Number of CV folds when split_strategy == "cross_val".
    splits: int = 5
    # Directory where the split parquet file will be written.
    output_dir: Path

    @field_validator("stratification_columns", mode="before")
    def validate_stratification_columns(cls, val) -> List[str]:
        """Coerce a comma-separated string into a list of trimmed column names.

        Previously a list value (e.g. when the config is constructed
        programmatically rather than from the CLI) crashed with
        AttributeError on ``.split``; sequences are now accepted as-is.
        """
        if isinstance(val, str):
            return [s.strip() for s in val.split(",")]
        # Already a sequence of column names — just normalize whitespace.
        return [str(s).strip() for s in val]

kelp.data_prep.train_val_test_split.filter_data

Filters dataset by removing images with high kelp pixel percentage.

Parameters:

Name Type Description Default
df DataFrame

The dataset metadata dataframe.

required
Source code in kelp/data_prep/train_val_test_split.py
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
@timed
def filter_data(df: pd.DataFrame) -> pd.DataFrame:
    """
    Filters dataset by removing images with high kelp pixel percentage.

    Args:
        df: The dataset metadata dataframe.

    Returns: A pandas dataframe with filtered data.

    """
    # Keep rows whose high-kelp flag is explicitly False or missing (None).
    keep = df["high_kelp_pixels_pct"].isin([False, None])
    return df[keep]

kelp.data_prep.train_val_test_split.k_fold_split

Runs Stratified K-Fold Cross Validation split on dataset.

Parameters:

Name Type Description Default
df DataFrame

The dataset metadata dataframe.

required
splits int

The number of splits to perform.

5
seed int

The seed for reproducibility.

SEED
Source code in kelp/data_prep/train_val_test_split.py
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
@timed
def k_fold_split(df: pd.DataFrame, splits: int = 5, seed: int = consts.reproducibility.SEED) -> pd.DataFrame:
    """
    Runs Stratified K-Fold Cross Validation split on dataset.

    Note: mutates ``df`` in place by appending one ``split_{i}`` column per fold.

    Args:
        df: The dataset metadata dataframe.
        splits: The number of splits to perform.
        seed: The seed for reproducibility.

    Returns: A dataframe with extra columns indicating to which splits the record belongs.

    """
    skf = StratifiedKFold(n_splits=splits, shuffle=True, random_state=seed)

    for i in range(splits):
        df[f"split_{i}"] = "train"

    # StratifiedKFold.split yields *positional* indices, while df.loc expects
    # index labels. Map positions back to labels so the assignment is correct
    # even when df has a non-default (e.g. filtered/non-contiguous) index;
    # with the default RangeIndex this is behavior-identical to before.
    for i, (_, val_idx) in enumerate(skf.split(df, df["stratification"])):
        df.loc[df.index[val_idx], f"split_{i}"] = "val"

    return df

kelp.data_prep.train_val_test_split.load_data

Loads dataset metadata parquet file.

Parameters:

Name Type Description Default
fp Path

The path to the metadata parquet file.

required
Source code in kelp/data_prep/train_val_test_split.py
83
84
85
86
87
88
89
90
91
92
93
94
@timed
def load_data(fp: Path) -> pd.DataFrame:
    """
    Loads dataset metadata parquet file.

    Args:
        fp: The path to the metadata parquet file.

    Returns: A pandas dataframe with dataset metadata; the original "split"
        column is preserved under the name "original_split".

    """
    metadata = pd.read_parquet(fp)
    return metadata.rename(columns={"split": "original_split"})

kelp.data_prep.train_val_test_split.main

Main entry point for running train/val/test dataset split.

Source code in kelp/data_prep/train_val_test_split.py
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
@timed
def main() -> None:
    """Main entry point for running train/val/test dataset split."""
    cfg = parse_args()
    metadata = load_data(cfg.dataset_metadata_fp)
    metadata = filter_data(metadata)
    metadata = make_stratification_column(metadata, stratification_columns=cfg.stratification_columns)
    metadata = split_dataset(
        metadata,
        split_strategy=cfg.split_strategy,
        splits=cfg.splits,
        random_split_train_size=cfg.random_split_train_size,
        seed=cfg.seed,
    )
    out_fp = cfg.output_dir / f"train_val_test_dataset_strategy={cfg.split_strategy}.parquet"
    save_data(metadata, output_path=out_fp)

kelp.data_prep.train_val_test_split.make_stratification_column

Creates a stratification column from dataset metadata and specified metadata columns.

Parameters:

Name Type Description Default
df DataFrame

The dataset metadata dataframe.

required
stratification_columns List[str]

The metadata columns to use for the stratification.

required
Source code in kelp/data_prep/train_val_test_split.py
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
@timed
def make_stratification_column(df: pd.DataFrame, stratification_columns: List[str]) -> pd.DataFrame:
    """
    Creates a stratification column from dataset metadata and specified metadata columns.

    Args:
        df: The dataset metadata dataframe.
        stratification_columns: The metadata columns to use for the stratification.

    Returns: The same pandas dataframe with appended stratification column.

    """
    # Build "col1=v1-col2=v2-..." keys row by row; itertuples keeps the
    # original per-cell objects so str() formatting matches exactly.
    subset = df[stratification_columns]
    keys = [
        "-".join(f"{col}={value!s}" for col, value in zip(stratification_columns, row))
        for row in subset.itertuples(index=False, name=None)
    ]
    df["stratification"] = pd.Series(keys, index=df.index).astype("category")

    return df

kelp.data_prep.train_val_test_split.parse_args

Parse command line arguments.

Returns: An instance of TrainTestSplitConfig.

Source code in kelp/data_prep/train_val_test_split.py
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
def parse_args() -> TrainTestSplitConfig:
    """
    Parse command line arguments.

    Returns: An instance of TrainTestSplitConfig.

    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset_metadata_fp", type=str, required=True)
    parser.add_argument(
        "--stratification_columns",
        type=str,
        default="has_kelp,almost_all_water,qa_ok,high_corrupted_pixels_pct",
    )
    parser.add_argument(
        "--split_strategy",
        type=str,
        choices=["cross_val", "random"],
        default="cross_val",
    )
    parser.add_argument("--seed", type=int, default=consts.reproducibility.SEED)
    # NOTE(review): the CLI default (10) differs from TrainTestSplitConfig's
    # field default (5) — confirm this discrepancy is intended.
    parser.add_argument("--splits", type=int, default=10)
    parser.add_argument("--random_split_train_size", type=float, default=0.95)
    parser.add_argument("--output_dir", type=str, required=True)
    cfg = TrainTestSplitConfig(**vars(parser.parse_args()))
    # Make sure the destination directory exists before anything is written.
    cfg.output_dir.mkdir(exist_ok=True, parents=True)
    cfg.log_self()
    return cfg

kelp.data_prep.train_val_test_split.run_cross_val_split

Runs Stratified K-Fold Cross Validation split on training samples. The test samples will be marked as test split.

Parameters:

Name Type Description Default
train_samples DataFrame

The dataframe with training samples.

required
test_samples DataFrame

The dataframe with test samples.

required
splits int

The number of splits to perform.

5
seed int

The seed for reproducibility.

SEED
Source code in kelp/data_prep/train_val_test_split.py
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
@timed
def run_cross_val_split(
    train_samples: pd.DataFrame,
    test_samples: pd.DataFrame,
    splits: int = 5,
    seed: int = consts.reproducibility.SEED,
) -> pd.DataFrame:
    """
    Runs Stratified K-Fold Cross Validation split on training samples. The test samples will be marked as test split.

    Note: mutates ``test_samples`` in place by adding ``split_{i}`` columns set to "test".

    Args:
        train_samples: The dataframe with training samples.
        test_samples: The dataframe with test samples.
        splits: The number of splits to perform.
        seed: The seed for reproducibility.

    Returns: The training samples with per-fold split assignment columns.

    """
    # Stratify at AOI level: each AOI is represented by its most frequent
    # stratification key so every tile of one AOI lands in the same fold.
    results = []
    for aoi_id, frame in train_samples[["aoi_id", "stratification"]].groupby("aoi_id"):
        # idxmax on value_counts returns the most frequent *value* on every
        # pandas version; the previous reset_index().iloc[0]["stratification"]
        # indirection returned the count (not the value) on pandas < 2.0,
        # because reset_index named the counts column "stratification" there.
        results.append((aoi_id, frame["stratification"].value_counts().idxmax()))
    results_df = pd.DataFrame(results, columns=["aoi_id", "stratification"])
    train_val_samples = k_fold_split(results_df, splits=splits, seed=seed)
    # Replace the per-tile stratification with the AOI-level fold assignment.
    train_samples = train_samples.drop("stratification", axis=1)
    train_samples = train_samples.merge(train_val_samples, how="inner", on="aoi_id")
    for split in range(splits):
        test_samples[f"split_{split}"] = "test"
    return train_samples

kelp.data_prep.train_val_test_split.run_random_split

Runs random split on train_samples. The test samples will be marked as test split.

Parameters:

Name Type Description Default
train_samples DataFrame

The dataframe with training samples.

required
test_samples DataFrame

The dataframe with test samples.

required
random_split_train_size float

The size of training split as a fraction of the whole dataset.

0.95
seed int

The seed for reproducibility.

SEED
Source code in kelp/data_prep/train_val_test_split.py
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
@timed
def run_random_split(
    train_samples: pd.DataFrame,
    test_samples: pd.DataFrame,
    random_split_train_size: float = 0.95,
    seed: int = consts.reproducibility.SEED,
) -> pd.DataFrame:
    """
    Runs random split on train_samples. The test samples will be marked as test split.

    Note: mutates ``test_samples`` in place by adding a ``split_0`` column.

    Args:
        train_samples: The dataframe with training samples.
        test_samples: The dataframe with test samples.
        random_split_train_size: The size of training split as a fraction of the whole dataset.
        seed: The seed for reproducibility.

    Returns: A dataframe with training and validation samples.

    """
    train_part, val_part = train_test_split(
        train_samples,
        train_size=random_split_train_size,
        stratify=train_samples["stratification"],
        shuffle=True,
        random_state=seed,
    )
    # assign() returns new frames, so the split labels are attached without
    # mutating the slices returned by train_test_split.
    train_part = train_part.assign(split_0="train")
    val_part = val_part.assign(split_0="val")
    test_samples["split_0"] = "test"
    return pd.concat([train_part, val_part])

kelp.data_prep.train_val_test_split.save_data

Saves the specified dataframe under specified output path.

Parameters:

Name Type Description Default
df DataFrame

The dataframe to save.

required
output_path Path

The path to save the dataframe under.

required
Source code in kelp/data_prep/train_val_test_split.py
263
264
265
266
267
268
269
270
271
272
273
274
275
276
@timed
def save_data(df: pd.DataFrame, output_path: Path) -> pd.DataFrame:
    """
    Saves the specified dataframe under specified output path.

    Args:
        df: The dataframe to save.
        output_path: The path to save the dataframe under.

    Returns: The same dataframe as input (enables pipe-style chaining).

    """
    # index=False: the positional index carries no information worth persisting.
    df.to_parquet(output_path, index=False)
    return df

kelp.data_prep.train_val_test_split.split_dataset

Performs dataset split into training, validation and test sets using specified split strategy.

Parameters:

Name Type Description Default
df DataFrame

The metadata dataframe containing the training and test records.

required
split_strategy Literal['cross_val', 'random']

The strategy to use.

'cross_val'
random_split_train_size float

The size of training split as a fraction of the whole dataset.

0.95
splits int

The number of CV splits.

5
seed int

The seed for reproducibility.

SEED
Source code in kelp/data_prep/train_val_test_split.py
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
@timed
def split_dataset(
    df: pd.DataFrame,
    split_strategy: Literal["cross_val", "random"] = "cross_val",
    random_split_train_size: float = 0.95,
    splits: int = 5,
    seed: int = consts.reproducibility.SEED,
) -> pd.DataFrame:
    """
    Performs dataset split into training, validation and test sets using specified split strategy.

    Args:
        df: The metadata dataframe containing the training and test records.
        split_strategy: The strategy to use.
        random_split_train_size: The size of training split as a fraction of the whole dataset.
        splits: The number of CV splits.
        seed: The seed for reproducibility.

    Returns: A dataframe with training, validation and test splits.

    """
    # Fail fast on an unknown strategy before doing any work.
    if split_strategy not in ("cross_val", "random"):
        raise ValueError(f"{split_strategy=} is not supported")

    train_samples = df[df["original_split"] == "train"].copy()
    test_samples = df[df["original_split"] == "test"].copy()

    # Both strategies mutate test_samples in place to mark it as "test".
    if split_strategy == "cross_val":
        train_samples = run_cross_val_split(
            train_samples=train_samples,
            test_samples=test_samples,
            splits=splits,
            seed=seed,
        )
    else:
        train_samples = run_random_split(
            train_samples=train_samples,
            test_samples=test_samples,
            random_split_train_size=random_split_train_size,
            seed=seed,
        )
    return pd.concat([train_samples, test_samples])