MLFlow artifacts
Parameters
All parameters from the training config are also logged as run parameters. The training config itself is logged as a YAML file as well. A sample full training config can be found below:
```yaml
accumulate_grad_batches: 1
almost_all_water_importance_factor: 0.5
architecture: unet
bands:
- R
- G
- B
- SWIR
- NIR
- QA
- DEM
batch_size: 32
benchmark: false
ce_class_weights:
- 0.4
- 0.6
ce_smooth_factor: 0.1
compile: false
compile_dynamic: false
compile_mode: default
cosine_T_mult: 2
cosine_eta_min: 1.0e-07
cv_split: 4
cyclic_base_lr: 1.0e-05
cyclic_mode: exp_range
data_dir: data/raw
dataset_stats_fp: data/processed/2023-12-31T20:30:39-stats-fill_value=nan-mask_using_qa=True-mask_using_water_mask=True.json
decision_threshold: 0.48
decoder_attention_type: null
decoder_channels:
- 256
- 128
- 64
- 32
- 16
dem_nan_pixels_pct_importance_factor: 0.25
dem_zero_pixels_pct_importance_factor: -1.0
early_stopping_patience: 10
encoder: tu-efficientnet_b5
encoder_depth: 5
encoder_weights: imagenet
epochs: 10
experiment: nn-train-exp
fast_dev_run: false
fill_missing_pixels_with_torch_nan: true
has_kelp_importance_factor: 3.0
ignore_index: null
image_size: 352
interpolation: nearest
kelp_pixels_pct_importance_factor: 0.2
limit_test_batches: null
limit_train_batches: null
limit_val_batches: null
log_every_n_steps: 50
loss: dice
lr: 0.0003
lr_scheduler: onecycle
mask_using_qa: true
mask_using_water_mask: true
metadata_fp: data/processed/train_val_test_dataset_strategy=cross_val.parquet
monitor_metric: val/dice
monitor_mode: max
normalization_strategy: quantile
num_classes: 2
num_workers: 6
objective: binary
onecycle_div_factor: 2.0
onecycle_final_div_factor: 100.0
onecycle_pct_start: 0.1
optimizer: adamw
ort: false
output_dir: mlruns
plot_n_batches: 3
precision: bf16-mixed
pretrained: true
qa_corrupted_pixels_pct_importance_factor: -1.0
qa_ok_importance_factor: 0.0
reduce_lr_on_plateau_factor: 0.95
reduce_lr_on_plateau_min_lr: 1.0e-06
reduce_lr_on_plateau_patience: 3
reduce_lr_on_plateau_threshold: 0.0001
resize_strategy: pad
sahi: false
samples_per_epoch: 10240
save_top_k: 1
seed: 42
spectral_indices:
- DEMWM
- NDVI
- ATSAVI
- AVI
- CI
- ClGreen
- GBNDVI
- GVMI
- IPVI
- KIVU
- MCARI
- MVI
- NormNIR
- PNDVI
- SABI
- WDRVI
- mCRIG
swa: false
swa_annealing_epochs: 10
swa_epoch_start: 0.5
swa_lr: 3.0e-05
tta: false
tta_merge_mode: max
use_weighted_sampler: true
val_check_interval: null
weight_decay: 0.0001
```
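For illustration, here is a minimal, hypothetical sketch (not the project's exact code) of how a config like this can be logged both as run parameters and as a YAML artifact using plain `mlflow` calls. The abbreviated `config` dict stands in for the full config above:

```python
# Hypothetical sketch: log a training config as MLflow run parameters
# and as a YAML artifact. The dict is an abbreviated stand-in for the
# full config shown above.
import mlflow

config = {
    "architecture": "unet",
    "encoder": "tu-efficientnet_b5",
    "batch_size": 32,
    "lr": 3e-4,
    "loss": "dice",
    "epochs": 10,
}

with mlflow.start_run():
    mlflow.log_params(config)               # each key/value becomes a run parameter
    mlflow.log_dict(config, "config.yaml")  # .yaml extension -> serialized as YAML
```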
Metrics
The optimization metric (selectable via the training config and command line arguments) is set to `val/dice` by default. The same metric is used for early stopping.
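A minimal sketch of how the monitored metric typically wires into PyTorch Lightning callbacks, assuming the `monitor_metric`, `monitor_mode`, `early_stopping_patience` and `save_top_k` values from the sample config above:

```python
# Hypothetical sketch: wire the monitored metric into Lightning callbacks
# (monitor_metric=val/dice, monitor_mode=max, early_stopping_patience=10).
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint

early_stopping = EarlyStopping(monitor="val/dice", mode="max", patience=10)
checkpointing = ModelCheckpoint(monitor="val/dice", mode="max", save_top_k=1, save_last=True)
```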
During the training loop, the following metrics are logged (a minimal logging sketch follows the list):
- `epoch`
- `hp_metric` - logged only once at the end of training - the `val/dice` score of the best model
- `lr-AdamW` - the `AdamW` part depends on the actual optimizer used for training
- `lr-AdamW-momentum` - the `AdamW` part depends on the actual optimizer used for training
- `lr-AdamW-weight_decay` - the `AdamW` part depends on the actual optimizer used for training
- `train/loss`
- `train/dice`
- `val/loss`
- `val/dice`
- `val/iou`
- `val/iou_kelp`
- `val/iou_background`
- `val/accuracy`
- `val/precision`
- `val/f1`
- `test/loss`
- `test/dice`
- `test/iou`
- `test/iou_kelp`
- `test/iou_background`
- `test/accuracy`
- `test/precision`
- `test/f1`
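As an illustration, per-stage metrics with `train/`, `val/` and `test/` prefixes are typically logged from a LightningModule as sketched below (a hypothetical toy module, not the project's actual code; the real runs compute `dice` and the other metrics with proper epoch-level aggregation):

```python
# Hypothetical sketch: logging slash-prefixed metrics from a LightningModule.
# The 1x1-conv "model" is a stand-in for the real UNet; the per-batch Dice
# here is a simplified version of the real metric.
import pytorch_lightning as pl
import torch
from torch import nn

class ToySegmentationModule(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.model = nn.Conv2d(3, 2, kernel_size=1)

    def forward(self, x):
        return self.model(x)

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = nn.functional.cross_entropy(logits, y)
        preds = logits.argmax(dim=1)
        intersection = ((preds == 1) & (y == 1)).sum()
        dice = 2 * intersection / ((preds == 1).sum() + (y == 1).sum() + 1e-8)
        # Slash-prefixed names group metrics by stage in the MLflow UI
        self.log("val/loss", loss, on_epoch=True)
        self.log("val/dice", dice, on_epoch=True)

    def configure_optimizers(self):
        return torch.optim.AdamW(self.parameters(), lr=3e-4, weight_decay=1e-4)
```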
Images
Spectral indices
- ATSAVI
- AVI
- CI
- ClGreen
- DEMWM
- GBNDVI
- GVMI
- IPVI
- KIVU
- MCARI
- mCRIG
- MVI
- NDVI
- NormNIR
- PNDVI
- SABI
- WDRVI
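For illustration, a minimal sketch of how one of the indices listed above (NDVI) can be computed from band arrays; the function name, `eps` guard and array shapes are assumptions:

```python
# Illustrative only: NDVI = (NIR - R) / (NIR + R), computed per pixel.
import numpy as np

def ndvi(nir: np.ndarray, red: np.ndarray, eps: float = 1e-8) -> np.ndarray:
    """Normalized Difference Vegetation Index."""
    return (nir - red) / (nir + red + eps)

rng = np.random.default_rng(42)
nir, red = rng.random((352, 352)), rng.random((352, 352))  # tile-sized dummy bands
print(ndvi(nir, red).shape)  # (352, 352)
```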
Composites
- True color
- Color infrared
- Shortwave infrared
- DEM
- QA
- Ground Truth Mask
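Similarly, a hypothetical sketch of how a true-color composite can be assembled from the R, G and B bands; the percentile stretch is an assumed display choice, not necessarily the project's rendering:

```python
# Illustrative only: stack R, G, B into an RGB image and apply a simple
# 2-98 percentile stretch for display.
import numpy as np

def true_color_composite(r: np.ndarray, g: np.ndarray, b: np.ndarray) -> np.ndarray:
    rgb = np.stack([r, g, b], axis=-1).astype(np.float32)
    lo, hi = np.nanpercentile(rgb, (2, 98))
    return np.clip((rgb - lo) / (hi - lo + 1e-8), 0.0, 1.0)
```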
Predictions
The predictions for the first `plot_n_batches` batches of the validation dataset are logged as an image grid to monitor the model's learning progress. The grid is logged after every epoch; only predictions from a few selected epochs are shown here (a logging sketch follows the list below):
- Epoch #0
- Epoch #1
- Epoch #2
- Epoch #5
- Epoch #10
- Epoch #20
- Epoch #38 (best epoch)
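A hypothetical sketch of this kind of per-epoch grid logging; `mlflow.log_figure` is a real MLflow API, while the figure layout and function name are assumptions:

```python
# Hypothetical sketch: log a grid of inputs vs. predictions as an image
# artifact after a validation epoch. Assumes an active MLflow run.
import matplotlib.pyplot as plt
import mlflow
import numpy as np

def log_prediction_grid(images: np.ndarray, preds: np.ndarray, epoch: int) -> None:
    n = len(images)
    fig, axes = plt.subplots(2, n, figsize=(3 * n, 6), squeeze=False)
    for i in range(n):
        axes[0][i].imshow(images[i])
        axes[0][i].set_title("input")
        axes[1][i].imshow(preds[i])
        axes[1][i].set_title("prediction")
        for ax in (axes[0][i], axes[1][i]):
            ax.axis("off")
    mlflow.log_figure(fig, f"predictions/epoch={epoch:02d}.png")
    plt.close(fig)
```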
Confusion matrix
- Normalized confusion matrix
- Full confusion matrix
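A minimal sketch of how such matrices can be produced and logged, assuming `scikit-learn` for plotting (not necessarily the project's exact code):

```python
# Hypothetical sketch: log normalized and raw confusion matrices as figures.
# Assumes an active MLflow run; toy flattened pixel labels for the binary task.
import mlflow
from sklearn.metrics import ConfusionMatrixDisplay

y_true = [0, 0, 1, 1, 1, 0]
y_pred = [0, 1, 1, 1, 0, 0]
for normalize, fname in [("true", "confusion_matrix_normalized.png"), (None, "confusion_matrix.png")]:
    disp = ConfusionMatrixDisplay.from_predictions(
        y_true, y_pred, display_labels=["background", "kelp"], normalize=normalize
    )
    mlflow.log_figure(disp.figure_, fname)
```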
Checkpoints
The MLFlow logger has been configured to log the `save_top_k` best checkpoints as well as the last one (needed when running SWA). The checkpoints are available under the `checkpoints` and `model` directories in the run artifacts.
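A hypothetical sketch of this checkpointing setup with PyTorch Lightning, using values from the sample config (assumes a Lightning version where `MLFlowLogger(log_model=True)` uploads checkpoints to the run):

```python
# Hypothetical sketch: MLFlowLogger plus a checkpoint callback that keeps
# the save_top_k best checkpoints and the last one (needed for SWA).
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, StochasticWeightAveraging
from pytorch_lightning.loggers import MLFlowLogger

logger = MLFlowLogger(experiment_name="nn-train-exp", log_model=True)
checkpoint_cb = ModelCheckpoint(
    monitor="val/dice",
    mode="max",
    save_top_k=1,    # matches save_top_k in the sample config
    save_last=True,  # the last checkpoint is needed when SWA is enabled
)
trainer = pl.Trainer(
    max_epochs=10,
    precision="bf16-mixed",
    logger=logger,
    callbacks=[checkpoint_cb, StochasticWeightAveraging(swa_lrs=3.0e-5)],
)
```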