factories

Factory methods for resolving the loss function, model, optimizer, and learning rate scheduler from configuration.

kelp.nn.models.factories.resolve_loss

Resolves the loss function based on provided arguments.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `loss_fn` | `str` | The loss function name. | *required* |
| `objective` | `str` | The objective (e.g. `binary` or `multiclass`). | *required* |
| `device` | `torch.device` | The device. | *required* |
| `num_classes` | `int` | The number of classes. | `NUM_CLASSES` |
| `ce_smooth_factor` | `float` | The smoothing factor for Cross Entropy Loss. | `0.0` |
| `ce_class_weights` | `Optional[List[float]]` | The class weights for Cross Entropy Loss. | `None` |
| `ignore_index` | `Optional[int]` | The index to ignore. | `None` |
Source code in kelp/nn/models/factories.py
def resolve_loss(
    loss_fn: str,
    objective: str,
    device: torch.device,
    num_classes: int = consts.data.NUM_CLASSES,
    ce_smooth_factor: float = 0.0,
    ce_class_weights: Optional[List[float]] = None,
    ignore_index: Optional[int] = None,
) -> nn.Module:
    """
    Resolves the loss function based on provided arguments.

    Args:
        loss_fn: The loss function name.
        objective: The objective.
        device: The device.
        num_classes: The number of classes.
        ce_smooth_factor: The smoothing factor for Cross Entropy Loss.
        ce_class_weights: The class weights for Cross Entropy Loss.
        ignore_index: The index to ignore.

    Returns: Resolved Loss Function module.

    """
    if loss_fn not in LOSS_REGISTRY:
        raise ValueError(f"{loss_fn=} is not supported.")

    loss_kwargs: Dict[str, Any]
    if loss_fn in ["jaccard", "dice"]:
        loss_kwargs = {
            "mode": "multiclass",
            "classes": list(range(num_classes)) if objective != "binary" else None,
        }
    elif loss_fn == "ce":
        loss_kwargs = {
            "weight": torch.tensor(ce_class_weights, device=device),
            "ignore_index": ignore_index or -100,
        }
    elif loss_fn == "soft_ce":
        loss_kwargs = {
            "ignore_index": ignore_index,
            "smooth_factor": ce_smooth_factor,
        }
    elif loss_fn == "xedice":
        loss_kwargs = {
            "mode": "multiclass",
            "ce_class_weights": torch.tensor(ce_class_weights, device=device),
        }
    elif loss_fn in [
        "focal_tversky",
        "log_cosh_dice",
        "hausdorff",
        "combo",
        "soft_dice",
        "batch_soft_dice",
        "sens_spec_loss",
    ]:
        loss_kwargs = {}
    elif loss_fn == "t_loss":
        loss_kwargs = {
            "device": device,
        }
    elif loss_fn == "exp_log_loss":
        loss_kwargs = {
            "class_weights": torch.tensor(ce_class_weights, device=device),
        }
    else:
        loss_kwargs = {
            "mode": "multiclass",
            "ignore_index": ignore_index,
        }

    return LOSS_REGISTRY[loss_fn](**loss_kwargs)
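
A minimal usage sketch (not part of the source above), assuming `"dice"` is one of the keys registered in `LOSS_REGISTRY`:

```python
import torch

from kelp.nn.models.factories import resolve_loss

# Resolve a Dice loss for a binary kelp segmentation objective.
# Unknown loss names raise a ValueError.
criterion = resolve_loss(
    loss_fn="dice",
    objective="binary",
    device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
    num_classes=2,
)

# The returned object is a plain nn.Module operating on logits and integer masks.
logits = torch.randn(4, 2, 64, 64)          # (batch, classes, height, width)
targets = torch.randint(0, 2, (4, 64, 64))  # (batch, height, width)
loss = criterion(logits, targets)
```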

kelp.nn.models.factories.resolve_lr_scheduler

Resolves the learning rate scheduler.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `optimizer` | `Optimizer` | The optimizer. | *required* |
| `num_training_steps` | `int` | The total number of training steps. | *required* |
| `steps_per_epoch` | `int` | The number of training steps per epoch. | *required* |
| `hyperparams` | `Dict[str, Any]` | The hyperparameters. | *required* |
Source code in kelp/nn/models/factories.py
def resolve_lr_scheduler(
    optimizer: torch.optim.Optimizer,
    num_training_steps: int,
    steps_per_epoch: int,
    hyperparams: Dict[str, Any],
) -> Optional[torch.optim.lr_scheduler.LRScheduler]:
    """
    Resolves the learning rate scheduler.

    Args:
        optimizer: The optimizer.
        num_training_steps: The number of training steps.
        steps_per_epoch: The number of training steps per epoch.
        hyperparams: The hyperparameters.

    Returns: Resolved LR scheduler if requested, None otherwise.

    """

    if (lr_scheduler := hyperparams["lr_scheduler"]) is None:
        return None
    elif lr_scheduler == "onecycle":
        scheduler = OneCycleLR(
            optimizer,
            max_lr=hyperparams["lr"],
            total_steps=num_training_steps,
            pct_start=hyperparams["onecycle_pct_start"],
            div_factor=hyperparams["onecycle_div_factor"],
            final_div_factor=hyperparams["onecycle_final_div_factor"],
        )
    elif lr_scheduler == "cosine":
        scheduler = CosineAnnealingLR(
            optimizer=optimizer,
            T_max=hyperparams["epochs"],
            eta_min=hyperparams["cosine_eta_min"],
        )
    elif lr_scheduler == "cyclic":
        scheduler = CyclicLR(
            optimizer=optimizer,
            max_lr=hyperparams["lr"],
            base_lr=hyperparams["cyclic_base_lr"],
            step_size_up=steps_per_epoch,
            step_size_down=steps_per_epoch,
            mode=hyperparams["cyclic_mode"],
        )
    elif lr_scheduler == "cosine_with_warm_restarts":
        scheduler = CosineAnnealingWarmRestarts(
            optimizer=optimizer,
            T_0=steps_per_epoch,
            T_mult=hyperparams["cosine_T_mult"],
            eta_min=hyperparams["cosine_eta_min"],
        )
    elif lr_scheduler == "reduce_lr_on_plateau":
        scheduler = ReduceLROnPlateau(
            optimizer,
            mode="min",
            factor=hyperparams["reduce_lr_on_plateau_factor"],
            patience=hyperparams["reduce_lr_on_plateau_patience"],
            threshold=hyperparams["reduce_lr_on_plateau_threshold"],
            min_lr=hyperparams["reduce_lr_on_plateau_min_lr"],
            verbose=True,
        )
    else:
        raise ValueError(f"LR Scheduler: {lr_scheduler} is not supported.")
    return scheduler
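
A minimal usage sketch requesting the `"onecycle"` branch; the hyperparameter values below are hypothetical, and only the keys read by the selected branch need to be present:

```python
import torch
from torch.optim import AdamW

from kelp.nn.models.factories import resolve_lr_scheduler

model = torch.nn.Linear(10, 2)
optimizer = AdamW(model.parameters(), lr=3e-4, weight_decay=1e-4)

# "lr_scheduler" selects the branch; the remaining keys are forwarded to OneCycleLR.
hyperparams = {
    "lr_scheduler": "onecycle",
    "lr": 3e-4,
    "onecycle_pct_start": 0.1,
    "onecycle_div_factor": 2,
    "onecycle_final_div_factor": 100,
}

scheduler = resolve_lr_scheduler(
    optimizer=optimizer,
    num_training_steps=10_000,  # total optimizer steps across all epochs
    steps_per_epoch=1_000,
    hyperparams=hyperparams,
)
```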

kelp.nn.models.factories.resolve_model

Resolves the model based on provided parameters.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `architecture` | `str` | The architecture. | *required* |
| `encoder` | `str` | The encoder. | *required* |
| `classes` | `int` | The number of classes. | *required* |
| `in_channels` | `int` | The number of input channels. | *required* |
| `encoder_weights` | `Optional[str]` | Optional pre-trained encoder weights. | `None` |
| `decoder_channels` | `Optional[List[int]]` | Optional decoder channels. | `None` |
| `decoder_attention_type` | `Optional[str]` | Optional decoder attention type. | `None` |
| `pretrained` | `bool` | A flag indicating whether to use pre-trained model weights. | `False` |
| `compile` | `bool` | A flag indicating whether to compile the model using `torch.compile`. | `False` |
| `compile_mode` | `str` | The compile mode. | `'default'` |
| `compile_dynamic` | `Optional[bool]` | A flag indicating whether to use dynamic compilation. | `None` |
| `ort` | `bool` | A flag indicating whether to use Torch ORT compilation. | `False` |
Source code in kelp/nn/models/factories.py
def resolve_model(
    architecture: str,
    encoder: str,
    classes: int,
    in_channels: int,
    encoder_weights: Optional[str] = None,
    decoder_channels: Optional[List[int]] = None,
    decoder_attention_type: Optional[str] = None,
    pretrained: bool = False,
    compile: bool = False,
    compile_mode: str = "default",
    compile_dynamic: Optional[bool] = None,
    ort: bool = False,
) -> nn.Module:
    """
    Resolves the model based on provided parameters.

    Args:
        architecture: The architecture.
        encoder: The encoder.
        classes: The number of classes.
        in_channels: The number of input channels.
        encoder_weights: Optional pre-trained encoder weights.
        decoder_channels: Optional decoder channels.
        decoder_attention_type: Optional decoder attention type.
        pretrained: A flag indicating whether to use pre-trained model weights.
        compile: A flag indicating whether to compile the model using torch.compile.
        compile_mode: The compile mode.
        compile_dynamic: A flag indicating whether to use dynamic compile.
        ort: A flag indicating whether to use torch ORT compilation.

    Returns: Resolved model.

    """
    if decoder_channels is None:
        decoder_channels = [256, 128, 64, 32, 16]

    if architecture in _MODEL_LOOKUP:
        model_kwargs = {
            "encoder_name": encoder,
            "encoder_weights": encoder_weights if pretrained else None,
            "in_channels": in_channels,
            "classes": classes,
            "encoder_depth": len(decoder_channels),
            "decoder_channels": decoder_channels,
            "decoder_attention_type": decoder_attention_type,
        }
        if "unet" not in architecture or architecture == "efficientunet++":
            model_kwargs.pop("decoder_attention_type")
        if architecture == "fcn":
            model_kwargs.pop("encoder_name")
            model_kwargs.pop("encoder_weights")
    if architecture not in ["efficientunet++", "manet", "resunet", "resunet++", "unet", "unet++"]:
            model_kwargs.pop("decoder_channels")
            model_kwargs.pop("encoder_depth")
        model = _MODEL_LOOKUP[architecture](**model_kwargs)
    else:
        raise ValueError(f"{architecture=} is not supported.")

    if compile:
        model = torch.compile(
            model,
            mode=compile_mode,
            dynamic=compile_dynamic,
        )

    if ort:
        if module_available("torch_ort"):
            from torch_ort import ORTModule  # noqa

            model = ORTModule(model)
        else:
            raise MisconfigurationException(
                "Torch ORT is required to use ORT. See here for installation: https://github.com/pytorch/ort"
            )

    return model
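
A minimal usage sketch, assuming `"unet"` and the `resnet50` encoder are valid keys for `_MODEL_LOOKUP` and the underlying segmentation-models-pytorch registry:

```python
import torch

from kelp.nn.models.factories import resolve_model

# Build a U-Net with a ResNet-50 encoder pre-trained on ImageNet,
# adapted to 7 input channels and 2 output classes.
model = resolve_model(
    architecture="unet",
    encoder="resnet50",
    classes=2,
    in_channels=7,
    encoder_weights="imagenet",
    pretrained=True,
)

x = torch.randn(2, 7, 352, 352)  # spatial size must be divisible by 2**encoder_depth
logits = model(x)                # -> (2, 2, 352, 352)
```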

kelp.nn.models.factories.resolve_optimizer

Resolves the optimizer.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `params` | `Iterator[Parameter]` | The model parameters. | *required* |
| `hyperparams` | `Dict[str, Any]` | A dictionary of hyperparameters. | *required* |
Source code in kelp/nn/models/factories.py
def resolve_optimizer(params: Iterator[Parameter], hyperparams: Dict[str, Any]) -> torch.optim.Optimizer:
    """
    Resolves the optimizer.

    Args:
        params: The model parameters.
        hyperparams: A dictionary of hyperparameters.

    Returns: Resolved optimizer.

    """
    if (optimizer := hyperparams["optimizer"]) == "adam":
        optimizer = Adam(params, lr=hyperparams["lr"], weight_decay=hyperparams["weight_decay"])
    elif optimizer == "adamw":
        optimizer = AdamW(params, lr=hyperparams["lr"], weight_decay=hyperparams["weight_decay"])
    elif optimizer == "sgd":
        optimizer = SGD(params, lr=hyperparams["lr"], weight_decay=hyperparams["weight_decay"])
    else:
        raise ValueError(f"Optimizer: {optimizer} is not supported.")
    return optimizer
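
A minimal usage sketch (hyperparameter values are hypothetical); `"optimizer"`, `"lr"`, and `"weight_decay"` are the only keys this factory reads:

```python
import torch

from kelp.nn.models.factories import resolve_optimizer

model = torch.nn.Linear(10, 2)

# "adam", "adamw" and "sgd" are the supported optimizer names.
hyperparams = {"optimizer": "adamw", "lr": 3e-4, "weight_decay": 1e-4}

optimizer = resolve_optimizer(params=model.parameters(), hyperparams=hyperparams)
```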