Skip to content

kelp.utils

GPU

The GPU utilities.

kelp.utils.gpu.set_gpu_power_limit_if_needed

Helper function that sets the GPU power limit if an RTX 3090 is used.

Parameters:

Name Type Description Default
pw int

The new power limit to set. Defaults to 250W.

250
Source code in kelp/utils/gpu.py
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
def set_gpu_power_limit_if_needed(pw: int = 250) -> None:
    """
    Helper function that sets the GPU power limit if an RTX 3090 is used.

    The GPU names are queried through `nvidia-smi`. When the output contains
    "NVIDIA GeForce RTX 3090", `sudo nvidia-smi -pm 1` and `sudo nvidia-smi -pl <pw>`
    are executed. Otherwise nothing is changed.

    Args:
        pw: The new power limit to set. Defaults to 250W.

    """

    # Use the pipe as a context manager so it is closed deterministically
    # instead of being leaked until garbage collection.
    with os.popen("nvidia-smi --query-gpu=gpu_name --format=csv") as stream:
        gpu_list = stream.read()

    if "NVIDIA GeForce RTX 3090" not in gpu_list:
        return

    # The exit statuses of these commands are not checked.
    os.system("sudo nvidia-smi -pm 1")
    os.system(f"sudo nvidia-smi -pl {pw}")

Logging

The logging utilities.

kelp.utils.logging.get_logger

Builds a Logger instance with provided name and log level.

Parameters:

Name Type Description Default
name str

The name for the logger.

required
log_level Union[int, str]

The default log level.

INFO

Returns:

Type Description
Logger

The logger.

Source code in kelp/utils/logging.py
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
def get_logger(name: str, log_level: Union[int, str] = logging.INFO) -> logging.Logger:
    """
    Creates (or fetches) a `Logger` configured with the given name and log level.

    The logger does not propagate records to its parent. A stream handler using
    the project-wide log format is attached only when the logger has no handlers yet,
    so repeated calls with the same name do not duplicate output.

    Args:
        name: The name for the logger.
        log_level: The default log level.

    Returns:
        The logger.

    """

    named_logger = logging.getLogger(name=name)

    # Keep records from bubbling up to the parent logger
    named_logger.propagate = False
    named_logger.setLevel(log_level)

    # A handler is already attached - nothing more to configure
    if named_logger.handlers:
        return named_logger

    console_handler = logging.StreamHandler()
    console_handler.setFormatter(fmt=logging.Formatter(fmt=consts.logging.FORMAT))
    named_logger.addHandler(console_handler)
    return named_logger

kelp.utils.logging.timed

This decorator logs the execution time for the decorated function.

Parameters:

Name Type Description Default
func Callable[P, T]

The function to wrap.

required

Returns:

Type Description
Callable[P, T]

Wrapper around the function.

Source code in kelp/utils/logging.py
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
def timed(func: Callable[P, T]) -> Callable[P, T]:
    """
    This decorator logs the execution time for the decorated function.

    One message is logged before the call and another one, containing the elapsed
    time in seconds, after the call returns. If the wrapped function raises,
    the exception propagates and the second message is not logged.

    Args:
        func: The function to wrap.

    Returns:
        Wrapper around the function.

    """

    @wraps(func)
    def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
        _timed_logger.info(f"{func.__qualname__} is running...")
        # perf_counter is monotonic, so the measured duration cannot be distorted
        # by system clock adjustments the way a time.time() difference can.
        start = time.perf_counter()
        result = func(*args, **kwargs)
        elapsed = time.perf_counter() - start
        _timed_logger.info(f"{func.__qualname__} ran in {elapsed:.4f}s")
        return result

    return wrapper

MLFlow

The MLFlow utilities.

kelp.utils.mlflow.get_mlflow_run_dir

Gets the MLFlow run directory given the active run and output directory.

Parameters:

current_run: The current active run.
output_dir: The output directory.

Returns: A path to the MLFlow run directory.

Source code in kelp/utils/mlflow.py
 8
 9
10
11
12
13
14
15
16
17
18
def get_mlflow_run_dir(current_run: ActiveRun, output_dir: Path) -> Path:
    """
    Resolves the directory that belongs to the given MLFlow run.

    The path is built as `<output_dir>/<experiment_id>/<run_id>`.

    Args:
        current_run: The current active run.
        output_dir: The output directory.

    Returns:
        A path to the MLFlow run directory.

    """
    experiment_dir = output_dir / str(current_run.info.experiment_id)
    run_dir = experiment_dir / current_run.info.run_id
    return Path(run_dir)

Plotting

The sample plotting utilities.

kelp.utils.plotting.plot_sample

Plot a single sample of the satellite image.

Parameters:

Name Type Description Default
input_arr ndarray

The input image array. Expects all image bands to be provided.

required
target_arr Optional[ndarray]

An optional kelp mask array.

None
predictions_arr Optional[ndarray]

An optional kelp prediction array.

None
figsize Tuple[int, int]

The figure size.

(20, 4)
ndvi_cmap str

The colormap to use for the NDVI.

'RdYlGn'
dem_cmap str

The colormap to use for the DEM band.

'viridis'
qa_mask_cmap str

The colormap to use for the QA band.

'gray'
mask_cmap str

The colormap to use for the kelp mask.

CMAP
show_titles bool

A flag indicating whether the titles should be visible.

True
suptitle Optional[str]

The title for the figure.

None
Source code in kelp/utils/plotting.py
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
def plot_sample(
    input_arr: np.ndarray,  # type: ignore[type-arg]
    target_arr: Optional[np.ndarray] = None,  # type: ignore[type-arg]
    predictions_arr: Optional[np.ndarray] = None,  # type: ignore[type-arg]
    figsize: Tuple[int, int] = (20, 4),
    ndvi_cmap: str = "RdYlGn",
    dem_cmap: str = "viridis",
    qa_mask_cmap: str = "gray",
    mask_cmap: str = consts.data.CMAP,
    show_titles: bool = True,
    suptitle: Optional[str] = None,
) -> plt.Figure:
    """
    Plot a single sample of the satellite image.

    Six panels are always drawn from `input_arr`: three min-max normalized band
    composites (bands 2-3-4, 1-2-3 and 0-1-2), the NDVI computed from bands 1 and 2,
    the min-max normalized band 6 (DEM) and band 5 (QA mask). One extra panel is
    appended for each of `target_arr` and `predictions_arr` that is provided.

    Args:
        input_arr: The input image array. Expects all image bands to be provided.
        target_arr: An optional kelp mask array.
        predictions_arr: An optional kelp prediction array.
        figsize: The figure size.
        ndvi_cmap: The colormap to use for the NDVI.
        dem_cmap: The colormap to use for the DEM band.
        qa_mask_cmap: The colormap to use for the QA band.
        mask_cmap: The colormap to use for the kelp mask.
        show_titles: A flag indicating whether the titles should be visible.
        suptitle: The title for the figure.

    Returns:
        A figure with plotted sample.

    """
    # Six fixed panels plus one per optional array.
    num_panels = 6 + int(target_arr is not None) + int(predictions_arr is not None)

    # The prediction panel follows the ground-truth panel when both are present.
    target_idx = 6
    predictions_idx = 7 if target_arr is not None else 6

    # Band order implied by the panel labels below:
    # 0=SWIR, 1=NIR, 2=R, 3=G, 4=B, 5=QA, 6=DEM - NOTE(review): confirm against the dataset spec.
    # Composites are moved from band-first to band-last layout for imshow.
    tci = np.rollaxis(input_arr[[2, 3, 4]], 0, 3)
    tci = min_max_normalize(tci)
    false_color = np.rollaxis(input_arr[[1, 2, 3]], 0, 3)
    false_color = min_max_normalize(false_color)
    agriculture = np.rollaxis(input_arr[[0, 1, 2]], 0, 3)
    agriculture = min_max_normalize(agriculture)
    qa_mask = input_arr[5]
    dem = input_arr[6]
    # EPS keeps the denominator away from zero.
    ndvi = (input_arr[1] - input_arr[2]) / (input_arr[1] + input_arr[2] + consts.data.EPS)
    dem = min_max_normalize(dem)

    fig, axes = plt.subplots(nrows=1, ncols=num_panels, figsize=figsize, sharey=True)

    axes[0].imshow(tci)
    axes[1].imshow(false_color)
    axes[2].imshow(agriculture)
    axes[3].imshow(ndvi, cmap=ndvi_cmap, vmin=-1, vmax=1)
    axes[4].imshow(dem, cmap=dem_cmap)
    # NOTE(review): interpolation=None falls back to the matplotlib rcParams default;
    # if unsmoothed masks are intended, the string "none" would be needed - confirm.
    axes[5].imshow(qa_mask, cmap=qa_mask_cmap, interpolation=None)

    if target_arr is not None:
        axes[target_idx].imshow(target_arr, cmap=mask_cmap, interpolation=None)

    if predictions_arr is not None:
        axes[predictions_idx].imshow(predictions_arr, cmap=mask_cmap, interpolation=None)

    if show_titles:
        axes[0].set_xlabel("Natural Color (R, G, B)")
        # Bands 1-2-3 are NIR, R, G given the labels of the other composites.
        axes[1].set_xlabel("Color Infrared (NIR, R, G)")
        axes[2].set_xlabel("Short Wave Infrared (SWIR, NIR, R)")
        axes[3].set_xlabel("NDVI")
        axes[4].set_xlabel("DEM")
        axes[5].set_xlabel("QA Mask")

        if target_arr is not None:
            axes[target_idx].set_xlabel("Kelp Mask GT")

        if predictions_arr is not None:
            axes[predictions_idx].set_xlabel("Prediction")

    # Operate on the figure created above rather than on pyplot's "current figure" state.
    if suptitle is not None:
        fig.suptitle(suptitle)

    fig.tight_layout()
    return fig

Serialization

The serialization utils.

kelp.utils.serialization.JsonEncoder

Bases: JSONEncoder

Custom JSON encoder that handles datatypes that are not out-of-the-box supported by the json package.

Source code in kelp/utils/serialization.py
13
14
15
16
17
18
19
20
21
22
23
24
25
class JsonEncoder(JSONEncoder):
    """
    JSON encoder extended with support for values the `json` package cannot serialize on its own.

    Dates and datetimes are written in ISO format and `Path` objects as POSIX-style strings.
    Every other value is delegated to the base `JSONEncoder`.
    """

    def default(self, o: Any) -> str:
        if isinstance(o, Path):
            encoded = o.as_posix()
        elif isinstance(o, (date, datetime)):
            encoded = o.isoformat()
        else:
            # Let the base class handle (or reject) everything else
            encoded = super().default(o)  # type: ignore

        return encoded