Skip to content

config

The XGBoost training config.

kelp.xgb.training.cfg.TrainConfig

Bases: ConfigBase

The XGBoost training configuration.

Source code in kelp/xgb/training/cfg.py
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
class TrainConfig(ConfigBase):
    """The XGBoost training configuration."""

    dataset_fp: Path
    train_data_dir: Path
    output_dir: Path
    spectral_indices: List[str]
    sample_size: float = 1.0
    seed: int = consts.reproducibility.SEED
    plot_n_samples: int = 10
    experiment: str = "train-xgb-clf-exp"
    explain_model: bool = False

    xgb_learning_rate: float = 0.1
    xgb_n_estimators: int = 1000

    @field_validator("spectral_indices", mode="before")
    def validate_spectral_indices(cls, value: Union[str, Optional[List[str]]] = None) -> List[str]:
        if not value:
            return ["DEMWM", "NDVI"]

        if value == "all" or value == ALL_INDICES:
            return ALL_INDICES

        indices = value if isinstance(value, list) else [index.strip() for index in value.split(",")]

        if "DEMWM" in indices:
            _logger.warning("DEMWM is automatically added during training. No need to add it twice.")
            indices.remove("DEMWM")

        if "NDVI" in indices:
            _logger.warning("NDVI is automatically added during training. No need to add it twice.")
            indices.remove("NDVI")

        unknown_indices = set(indices).difference(list(SPECTRAL_INDEX_LOOKUP.keys()))
        if unknown_indices:
            raise ValueError(
                f"Unknown spectral indices were provided: {', '.join(unknown_indices)}. "
                f"Please provide comma separated indices. Valid choices are: {', '.join(SPECTRAL_INDEX_LOOKUP.keys())}."
            )

        return ["DEMWM", "NDVI"] + indices

    @property
    def resolved_experiment_name(self) -> str:
        return os.environ.get("MLFLOW_EXPERIMENT_NAME", self.experiment)

    @property
    def run_id_from_context(self) -> Optional[str]:
        return os.environ.get("MLFLOW_RUN_ID", None)

    @property
    def tags(self) -> Dict[str, Any]:
        return {"trained_at": datetime.utcnow().isoformat()}

    @property
    def columns_to_load(self) -> List[str]:
        return self.model_input_columns + ["label", "tile_id", "split"]

    @property
    def model_input_columns(self) -> List[str]:
        return list(consts.data.ORIGINAL_BANDS) + self.spectral_indices

    @property
    def random_forest_model_params(self) -> Dict[str, Any]:
        return {
            "random_state": self.seed,
        }

    @property
    def gradient_boosting_tree_model_params(self) -> Dict[str, Any]:
        return {
            "random_state": self.seed,
        }

    @property
    def catboost_model_params(self) -> Dict[str, Any]:
        return {
            "random_state": self.seed,
        }

    @property
    def xgboost_model_params(self) -> Dict[str, Any]:
        return {
            "n_estimators": self.xgb_n_estimators,
            "learning_rate": self.xgb_learning_rate,
            "random_state": self.seed,
            "device": "cuda",
            "n_jobs": -1,
        }

    @property
    def lightgbm_model_params(self) -> Dict[str, Any]:
        return {
            "random_state": self.seed,
        }