MONet Bundle#

In this notebook, we demonstrate how to create a MONAI Bundle that supports nnUNet experiments for training and inference. In this step-by-step tutorial, we describe how to create all the Python code and YAML configuration files needed to train and evaluate an nnUNet model using the MONAI Bundle format.

The tutorial assumes that the Spleen dataset has already been downloaded and preprocessed as described in the MONet Bundle Tutorial Notebook.

Setup environment#

[ ]:
!python -c "import monai" || pip install -q "monai-weekly[pillow, tqdm]"
!python -c "import nnunetv2" || pip install -q nnunetv2

Setup imports#

[ ]:
import torch
from monai.data import Dataset, DataLoader
from monai.handlers import (
    StatsHandler,
    from_engine,
    MeanDice,
    ValidationHandler,
    LrScheduleHandler,
    CheckpointSaver,
    CheckpointLoader,
    TensorBoardStatsHandler,
    MLFlowHandler,
)
from monai.engines import SupervisedTrainer, SupervisedEvaluator
from monai.transforms import Compose, Lambdad, Activationsd, AsDiscreted, Transposed, SaveImaged, LoadImaged, Decollated

import re
import pathlib
import os
import yaml
import json
from monai.bundle import ConfigParser
import monai
from pathlib import Path

from monai.apps.nnunet import get_nnunet_trainer, get_nnunet_monai_predictor, convert_nnunet_to_monai_bundle, convert_monai_bundle_to_nnunet

from monai.apps.nnunet import nnUNetV2Runner

#from nnunetv2.utilities.dataset_name_id_conversion import maybe_convert_to_dataset_name
#from nnunetv2.training.logging.nnunet_logger import nnUNetLogger
import shutil


[ ]:
os.environ["MONAI_DATA_DIRECTORY"] = "/home/maia-user/Documents/MONAI/Data"

work_dir = os.path.join(os.environ["MONAI_DATA_DIRECTORY"], "nnUNet")

nnunet_raw = os.path.join(work_dir, "nnUNet_raw_data_base")
nnunet_preprocessed = os.path.join(work_dir, "nnUNet_preprocessed")
nnunet_results = os.path.join(work_dir, "nnUNet_trained_models")

# Create the nnUNet folders if they do not exist yet
for folder in (nnunet_raw, nnunet_preprocessed, nnunet_results):
    os.makedirs(folder, exist_ok=True)

# Export the environment variables nnUNet reads to locate its folders
os.environ["nnUNet_raw"] = nnunet_raw
os.environ["nnUNet_preprocessed"] = nnunet_preprocessed
os.environ["nnUNet_results"] = nnunet_results
os.environ["OMP_NUM_THREADS"] = str(1)

nnUNet Trainer#

The core component for the nnUNet MONAI Bundle is the get_nnunet_trainer function. This function is responsible for creating the nnUNet trainer object from the native nnUNetv2 implementation. From the nnUNet trainer object, we can access the training components, such as the data loaders, model, learning rate scheduler, optimizer, and loss function, and perform training and inference tasks.

[ ]:
nnunet_config = {
    "dataset_name_or_id": "009",
    "configuration": "3d_fullres",
    "trainer_class_name": "nnUNetTrainer_10epochs",
    "plans_identifier": "nnUNetPlans",
    "fold": 0,
}


nnunet_trainer = get_nnunet_trainer(**nnunet_config)

The function get_nnunet_trainer accepts the following parameters:

  • dataset_name_or_id: The dataset name or ID to be used for training and evaluation.

  • fold: The fold number for the cross-validation experiment.

  • configuration: The training configuration for the nnUNet trainer, usually 3d_fullres.

  • trainer_class_name: The nnUNet trainer class name to be used for training, e.g. nnUNetTrainer.

  • plans_identifier: The nnUNet plans identifier for the dataset, e.g. nnUNetPlans.
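The same configuration values also determine where nnUNet stores trained models: the bundle's nnunet_trainer.yaml, further below, builds this folder as nnUNet_results/<dataset name>/<trainer>__<plans>__<configuration>. The following is a minimal sketch of that path logic in plain Python. It assumes the nnUNet v2 DatasetXXX_Name folder convention, takes the dataset name ("Spleen") as an explicit argument instead of resolving it with maybe_convert_to_dataset_name, and uses an illustrative results directory; nnunet_model_folder is a hypothetical helper, and the fold_<n> subfolder is an assumption about nnUNet's results layout.

```python
import os


def nnunet_model_folder(results_dir, dataset_id, dataset_name, configuration,
                        trainer_class_name="nnUNetTrainer",
                        plans_identifier="nnUNetPlans"):
    """Hypothetical helper: build the nnUNet v2 results folder for one experiment.

    Assumes zero-padded, three-digit dataset IDs (e.g. "Dataset009_Spleen") and
    trainer, plans and configuration joined with double underscores.
    """
    dataset_folder = f"Dataset{int(dataset_id):03d}_{dataset_name}"
    experiment = f"{trainer_class_name}__{plans_identifier}__{configuration}"
    return os.path.join(results_dir, dataset_folder, experiment)


nnunet_config = {
    "dataset_name_or_id": "009",
    "configuration": "3d_fullres",
    "trainer_class_name": "nnUNetTrainer_10epochs",
    "plans_identifier": "nnUNetPlans",
    "fold": 0,
}

folder = nnunet_model_folder(
    "/data/nnUNet_trained_models",  # illustrative path
    nnunet_config["dataset_name_or_id"],
    "Spleen",  # assumed dataset name
    nnunet_config["configuration"],
    nnunet_config["trainer_class_name"],
    nnunet_config["plans_identifier"],
)
# Assumed: each cross-validation fold gets its own fold_<n> subfolder
fold_folder = os.path.join(folder, f"fold_{nnunet_config['fold']}")
print(fold_folder)
```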

Train and Val Data Loaders#

[ ]:
dataset_key = "case_identifier"
[ ]:
train_dataloader = nnunet_trainer.dataloader_train
train_data = [{dataset_key: k} for k in nnunet_trainer.dataloader_train.generator._data.identifiers]
train_dataset = Dataset(data=train_data)
[ ]:
val_dataloader = nnunet_trainer.dataloader_val
val_data = [{dataset_key: k} for k in nnunet_trainer.dataloader_val.generator._data.identifiers]
val_dataset = Dataset(data=val_data)

Network, Optimizer, and Loss Function#

[ ]:
device = nnunet_trainer.device

network = nnunet_trainer.network
optimizer = nnunet_trainer.optimizer
lr_scheduler = nnunet_trainer.lr_scheduler
loss = nnunet_trainer.loss

Prepare Batch Function#

The nnUNet data loader returns a dictionary with data and target keys, where target is a list of tensors (one per output resolution) when deep supervision is enabled. Since the SupervisedTrainer used in the MONAI Bundle expects the input and the target as separate objects, we create a custom prepare-batch function that extracts them from the dictionary and moves them to the requested device.

[ ]:
def prepare_nnunet_batch(batch, device, non_blocking):
    data = batch["data"].to(device, non_blocking=non_blocking)
    if isinstance(batch["target"], list):
        target = [i.to(device, non_blocking=non_blocking) for i in batch["target"]]
    else:
        target = batch["target"].to(device, non_blocking=non_blocking)
    return data, target
[ ]:
image, label = prepare_nnunet_batch(next(iter(train_dataloader)), device="cpu", non_blocking=True)

MONAI Supervised Trainer#

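The SupervisedTrainer class from MONAI is used to train the nnUNet model. For a minimal setup, we need to provide the network, optimizer, loss function, training data loader, prepare-batch function, number of epochs, and the device on which to run the training. Because the nnUNet data loader samples batches indefinitely and has no length, we also set epoch_length to define how many iterations make up one epoch.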
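Without epoch_length, an engine reading from a never-ending loader would have no way to tell where an epoch ends. nnUNet's own training loop likewise draws a fixed number of batches per epoch (num_iterations_per_epoch, which the bundle's global.yaml later reuses as iterations). The toy loop below illustrates the idea only; it is not Ignite's implementation, and infinite_batches and run_epochs are hypothetical stand-ins for the nnUNet loader and the engine.

```python
import itertools


def infinite_batches():
    """Hypothetical stand-in for the nnUNet loader: yields batches forever."""
    for step in itertools.count():
        yield {"data": f"patch_{step}", "target": f"seg_{step}"}


def run_epochs(loader, max_epochs, epoch_length):
    """Consume epoch_length batches per epoch from a never-ending iterator.

    Returns (epoch, iteration) pairs, mirroring the counters an engine keeps
    in state.epoch and state.iteration.
    """
    iterator = iter(loader)
    iteration = 0
    log = []
    for epoch in range(1, max_epochs + 1):
        # islice stops after epoch_length batches; without it this loop never ends
        for _batch in itertools.islice(iterator, epoch_length):
            iteration += 1
        log.append((epoch, iteration))
    return log


print(run_epochs(infinite_batches(), max_epochs=2, epoch_length=10))
# [(1, 10), (2, 20)]
```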

[ ]:
train_handlers = [StatsHandler(output_transform=from_engine(["loss"], first=True), tag_name="train_loss")]
[ ]:
iterations = 10
epochs = 1
[ ]:
trainer = SupervisedTrainer(
    amp=True,
    device=device,
    epoch_length=iterations,
    loss_function=loss,
    max_epochs=epochs,
    network=network,
    prepare_batch=prepare_nnunet_batch,
    optimizer=optimizer,
    train_data_loader=train_dataloader,
    train_handlers=train_handlers,
)
[ ]:
trainer.run()

Adding Validation and Validation Metrics#

For a complete training setup, we need to add the validation data loader and the validation metrics to the SupervisedTrainer. Using the MONAI class SupervisedEvaluator, we can evaluate the model on the validation data loader and calculate the validation metrics (Dice Score).

[ ]:
val_key_metric = MeanDice(output_transform=from_engine(["pred", "label"]), reduction="mean", include_background=False)

additional_metrics = {
    "Val_Dice_Per_Class": MeanDice(
        output_transform=from_engine(["pred", "label"]),
        reduction="mean_batch",
        include_background=False,
    )
}

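Additionally, to compute the mean Dice score, we need to apply a post-processing transformation to the nnUNet model output. MeanDice expects y_pred and y as one-hot, channel-first tensors (BCHW[D] once collated), whereas the nnUNet network returns a list of predictions (one per deep-supervision level) and the labels contain class indices. We therefore create a custom post-processing transform that keeps only the full-resolution prediction and label, applies softmax, binarizes the prediction, and one-hot encodes the label.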

[ ]:
num_classes = 2

postprocessing = Compose(
    transforms=[
        ## Extract only high-res predictions from Deep Supervision
        Lambdad(keys=["pred", "label"], func=lambda x: x[0]),
        ## Apply Softmax to the predictions
        Activationsd(keys="pred", softmax=True),
        ## Binarize the predictions
        AsDiscreted(keys="pred", threshold=0.5),
        ## Convert the labels to one-hot
        AsDiscreted(keys="label", to_onehot=num_classes),
    ]
)
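For intuition, the NumPy sketch below applies the same four steps to a single decollated sample and then computes a per-class Dice score, 2|P∩G| / (|P| + |G|), with the background channel excluded as with include_background=False. The helpers postprocess and dice_per_class and the toy arrays are hypothetical, and the sketch does not reproduce MONAI's handling of edge cases such as empty ground truth.

```python
import numpy as np


def postprocess(pred_levels, label_levels, num_classes):
    """NumPy sketch of the post-processing Compose above, for one decollated sample.

    pred_levels:  list of logit arrays, one per deep-supervision level, each shaped
                  (C, *spatial); index 0 holds the full-resolution output.
    label_levels: list of integer label arrays shaped (1, *spatial).
    """
    logits = pred_levels[0]  # Lambdad(x[0]): keep the full-resolution prediction
    label = label_levels[0]  # ... and the full-resolution label
    # Activationsd(softmax=True): numerically stable softmax over the channel axis
    exp = np.exp(logits - logits.max(axis=0, keepdims=True))
    probs = exp / exp.sum(axis=0, keepdims=True)
    # AsDiscreted(threshold=0.5): binarize each channel
    pred_onehot = (probs >= 0.5).astype(np.float64)
    # AsDiscreted(to_onehot=num_classes): one channel per class index
    label_onehot = np.stack([label[0] == c for c in range(num_classes)]).astype(np.float64)
    return pred_onehot, label_onehot


def dice_per_class(pred_onehot, label_onehot, include_background=False):
    """Per-channel Dice 2|P∩G| / (|P| + |G|); NaN when both masks are empty."""
    first = 0 if include_background else 1
    scores = []
    for c in range(first, pred_onehot.shape[0]):
        p, g = pred_onehot[c], label_onehot[c]
        denom = float(p.sum() + g.sum())
        scores.append(2.0 * float((p * g).sum()) / denom if denom > 0 else float("nan"))
    return scores


# Toy two-class sample: a 1x4 "image" with two deep-supervision levels
logits = np.array([[[2.0, 2.0, -1.0, -1.0]], [[-2.0, -1.0, 3.0, 1.0]]])  # (C=2, 1, 4)
pred_levels = [logits, logits[:, :, ::2]]
label_levels = [np.array([[[0, 0, 1, 0]]]), np.array([[[0, 1]]])]

pred, label = postprocess(pred_levels, label_levels, num_classes=2)
print(dice_per_class(pred, label))  # foreground Dice: 2 * 1 / (2 + 1), about 0.667
```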
[ ]:
val_handlers = [StatsHandler(iteration_log=False)]
[ ]:
val_iterations = 10
val_interval = 1
[ ]:
evaluator = SupervisedEvaluator(
    amp=True,
    device=device,
    epoch_length=val_iterations,
    network=network,
    key_val_metric={"Val_Dice": val_key_metric},
    prepare_batch=prepare_nnunet_batch,
    val_data_loader=val_dataloader,
    val_handlers=val_handlers,
    postprocessing=postprocessing,
    additional_metrics=additional_metrics,
)

Finally, we attach the evaluator to the training loop with a ValidationHandler, which runs the validation every val_interval epochs.

[ ]:
train_handlers.append(ValidationHandler(epoch_level=True, interval=val_interval, validator=evaluator))

We can also add the MeanDice metric to the SupervisedTrainer to compute the mean Dice score on the training batches, reusing the same post-processing transform.

[ ]:
train_key_metric = MeanDice(output_transform=from_engine(["pred", "label"]), reduction="mean", include_background=False)

additional_metrics = {
    "Train_Dice_Per_Class": MeanDice(
        output_transform=from_engine(["pred", "label"]),
        reduction="mean_batch",
        include_background=False,
    )
}
[ ]:
trainer = SupervisedTrainer(
    amp=True,
    device=device,
    epoch_length=iterations,
    loss_function=loss,
    max_epochs=epochs,
    network=network,
    prepare_batch=prepare_nnunet_batch,
    optimizer=optimizer,
    train_data_loader=train_dataloader,
    train_handlers=train_handlers,
    key_train_metric={"Train_Dice": train_key_metric},
    postprocessing=postprocessing,
    additional_metrics=additional_metrics,
)
[ ]:
trainer.run()

Learning Rate Scheduler#

One last component is needed to replicate the training behaviour of native nnUNet: the learning rate scheduler. We attach it with LrScheduleHandler, which steps the nnUNet scheduler at the end of every epoch.
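Assuming nnUNet's default PolyLRScheduler (polynomial decay with exponent 0.9, stepped once per epoch), the learning rate at epoch e out of E epochs is initial_lr * (1 - e / E) ** 0.9. The sketch below computes that schedule in plain Python; poly_lr is a hypothetical helper, and the initial rate of 1e-2 and the 1000 epochs are assumed nnUNet defaults, not values read from the trainer.

```python
def poly_lr(epoch, max_epochs, initial_lr=1e-2, exponent=0.9):
    """Polynomial decay: initial_lr * (1 - epoch / max_epochs) ** exponent."""
    return initial_lr * (1 - epoch / max_epochs) ** exponent


# Assumed defaults: 1000 epochs and an initial learning rate of 1e-2
for epoch in (0, 1, 500, 999):
    print(epoch, f"{poly_lr(epoch, 1000):.6f}")
```

Because the rate depends only on the step counter and not on the previous rate, resuming training mainly requires moving that counter to the resumed epoch, which is what reload_checkpoint does below via lr_scheduler.ctr.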

[ ]:
train_handlers.append(LrScheduleHandler(lr_scheduler=lr_scheduler, print_lr=True))
[ ]:
trainer = SupervisedTrainer(
    amp=True,
    device=device,
    epoch_length=iterations,
    loss_function=loss,
    max_epochs=epochs,
    network=network,
    prepare_batch=prepare_nnunet_batch,
    optimizer=optimizer,
    train_data_loader=train_dataloader,
    train_handlers=train_handlers,
    key_train_metric={"Train_Dice": train_key_metric},
    postprocessing=postprocessing,
    additional_metrics=additional_metrics,
)
[ ]:
trainer.run()
[ ]:
train_handlers[-1].lr_scheduler.get_last_lr()

Checkpointing#

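To save checkpoints during training, we can use the CheckpointSaver handler from MONAI. Attached to the evaluator, it saves the network weights, optimizer state, and learning rate scheduler state after every validation run (save_interval=1), keeps only the most recent of these checkpoints (n_saved=1), and additionally stores the checkpoint with the best key metric (save_key_metric=True). We can later use the CheckpointLoader to restore these states and perform inference or resume training.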

[ ]:
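ckpt_dir = "MONetBundle/models/fold_0"

# When nnUNet compiles the network with torch.compile, the original module is
# available as _orig_mod; saving it keeps the state_dict keys free of the
# "_orig_mod." prefix. Fall back to the network itself when it is not compiled.
network_to_save = getattr(nnunet_trainer.network, "_orig_mod", nnunet_trainer.network)

val_handlers.append(
    CheckpointSaver(
        save_dir=ckpt_dir,
        save_dict={
            "network_weights": network_to_save,
            "optimizer_state": nnunet_trainer.optimizer,
            "scheduler": nnunet_trainer.lr_scheduler,
        },
        # save_final=True,
        save_interval=1,
        save_key_metric=True,
        # final_filename="model_final.pt",
        # key_metric_filename="model.pt",
        n_saved=1,
    )
)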
[ ]:
trainer = SupervisedTrainer(
    amp=True,
    device=device,
    epoch_length=iterations,
    loss_function=loss,
    max_epochs=epochs+1,
    network=network,
    prepare_batch=prepare_nnunet_batch,
    optimizer=optimizer,
    train_data_loader=train_dataloader,
    train_handlers=train_handlers,
    key_train_metric={"Train_Dice": train_key_metric},
    postprocessing=postprocessing,
    additional_metrics=additional_metrics,
)
[ ]:
trainer.run()

Reload Checkpoint#

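When resuming training from a checkpoint, we also want training to continue from the epoch at which it stopped. To do this, we load the checkpoint and update the trainer.state.epoch and trainer.state.iteration attributes of the SupervisedTrainer. If the nnUNet learning rate scheduler is passed, its step counter is also moved to that epoch so that the learning rate matches the resumed epoch.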

[ ]:
def subfiles(folder, prefix=None, suffix=None, join=True, sort=True):
    files = [f for f in os.listdir(folder) if os.path.isfile(os.path.join(folder, f))]
    if prefix is not None:
        files = [f for f in files if f.startswith(prefix)]
    if suffix is not None:
        files = [f for f in files if f.endswith(suffix)]
    if sort:
        files.sort()
    if join:
        files = [os.path.join(folder, f) for f in files]
    return files


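def get_checkpoint(epoch, ckpt_dir):
    """Return the epoch to resume from; "latest" selects the highest saved epoch (None if none exist)."""
    if epoch == "latest":
        latest_checkpoints = subfiles(ckpt_dir, prefix="checkpoint_epoch", sort=True, join=False)
        # Parse the epoch numbers as integers, e.g. "checkpoint_epoch=12.pt" -> 12,
        # so that epoch 10 sorts after epoch 9
        epochs = sorted(int(name[len("checkpoint_epoch=") : -len(".pt")]) for name in latest_checkpoints)
        if len(epochs) == 0:
            return None
        return epochs[-1]
    return epoch


def reload_checkpoint(trainer, epoch, num_train_batches_per_epoch, ckpt_dir, lr_scheduler=None):
    """Move the trainer (and optionally the nnUNet LR scheduler) to the resumed epoch."""
    epoch_to_load = get_checkpoint(epoch, ckpt_dir)
    if epoch_to_load is None:
        raise FileNotFoundError(f"No checkpoint_epoch=*.pt file found in {ckpt_dir}")
    trainer.state.epoch = epoch_to_load
    trainer.state.iteration = (epoch_to_load * num_train_batches_per_epoch) + 1

    if lr_scheduler is not None:
        # Align the scheduler's internal step counter with the resumed epoch
        lr_scheduler.ctr = epoch_to_load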
        lr_scheduler.step(epoch_to_load)
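# get_checkpoint compares epoch numbers as integers because a plain string sort
# puts "checkpoint_epoch=10.pt" before "checkpoint_epoch=9.pt". A minimal,
# regex-based sketch of the same lookup on a list of file names follows;
# latest_epoch is a hypothetical helper, and the key-metric file name below is
# an assumption about what save_key_metric=True writes.
import re

CHECKPOINT_PATTERN = re.compile(r"^checkpoint_epoch=(\d+)\.pt$")


def latest_epoch(filenames):
    """Highest epoch among 'checkpoint_epoch=<n>.pt' names; None if none match."""
    epochs = [int(m.group(1)) for m in map(CHECKPOINT_PATTERN.match, filenames) if m]
    return max(epochs) if epochs else None


epoch_files = ["checkpoint_epoch=9.pt", "checkpoint_epoch=10.pt"]
files = epoch_files + ["checkpoint_key_metric=0.9312.pt"]  # assumed key-metric name
print(sorted(epoch_files)[-1])  # checkpoint_epoch=9.pt (lexicographic order is wrong)
print(latest_epoch(files))  # 10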
[ ]:
reload_checkpoint_epoch = "latest"

train_handlers.append(
    CheckpointLoader(
        load_path=os.path.join(
            ckpt_dir, "checkpoint_epoch=" + str(get_checkpoint(reload_checkpoint_epoch, ckpt_dir)) + ".pt"
        ),
        load_dict={
            "network_weights": getattr(nnunet_trainer.network, "_orig_mod", nnunet_trainer.network),
            "optimizer_state": nnunet_trainer.optimizer,
            "scheduler": nnunet_trainer.lr_scheduler,
        },
        map_location=device,
    )
)

Initial nnUNet Checkpoint#

To stay compatible with native nnUNet, we also need to save the nnUNet-specific configuration (trainer name, initialization arguments, and the mirroring axes allowed at inference) alongside the regular MONAI checkpoint. This is done only once, before training starts. At the end of training, we therefore have a MONAI checkpoint and an nnUNet checkpoint, which can be combined at any time to convert the MONAI checkpoint into a native nnUNet checkpoint.

[ ]:
checkpoint = {
    "inference_allowed_mirroring_axes": nnunet_trainer.inference_allowed_mirroring_axes,
    "init_args": nnunet_trainer.my_init_kwargs,
    "trainer_name": nnunet_trainer.__class__.__name__,
}
checkpoint_filename = os.path.join(Path(ckpt_dir).parent, "nnunet_checkpoint.pth")

torch.save(checkpoint, checkpoint_filename)
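Conceptually, combining the two files means taking the union of their entries: the weights and optimizer state from the MONAI checkpoint plus the trainer metadata saved above. The dictionary-level sketch below illustrates only that idea with placeholder values; merge_checkpoints is a hypothetical helper, it is not how convert_monai_bundle_to_nnunet is implemented, and a real nnUNet checkpoint may contain further keys that this sketch does not create.

```python
def merge_checkpoints(monai_ckpt, nnunet_meta):
    """Hypothetical helper: union of a MONAI checkpoint dict and the nnUNet metadata dict.

    monai_ckpt:  entries saved by CheckpointSaver, keyed as in its save_dict
                 ("network_weights", "optimizer_state", "scheduler").
    nnunet_meta: entries saved once before training ("init_args",
                 "trainer_name", "inference_allowed_mirroring_axes").
    """
    overlap = set(monai_ckpt) & set(nnunet_meta)
    if overlap:
        # Refuse to silently overwrite one source with the other
        raise KeyError(f"keys present in both checkpoints: {sorted(overlap)}")
    return {**nnunet_meta, **monai_ckpt}


# Placeholder values; in practice both dicts would be loaded with torch.load(...)
monai_ckpt = {"network_weights": "<state_dict>", "optimizer_state": "<optimizer>", "scheduler": "<scheduler>"}
nnunet_meta = {
    "init_args": {"fold": 0},
    "trainer_name": "nnUNetTrainer_10epochs",
    "inference_allowed_mirroring_axes": (0, 1, 2),
}
merged = merge_checkpoints(monai_ckpt, nnunet_meta)
print(sorted(merged))
```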

MLFlow and Tensorboard Monitoring#

To monitor the training process, we can use MLFlow and TensorBoard. We log the training and validation metrics, the experiment parameters, and the training dataset to MLFlow, and visualize the training loss and validation metrics with TensorBoard.

[ ]:
log_dir = "MONetBundle/logs"

train_handlers.append(
    TensorBoardStatsHandler(log_dir=log_dir, output_transform=from_engine(["loss"], first=True), tag_name="train_loss")
)

val_handlers.append(TensorBoardStatsHandler(log_dir=log_dir, iteration_log=False))
[ ]:
def mlflow_transform(state_output):
    return state_output[0]["loss"]


class MLFlownnUNetHandler(MLFlowHandler):
    def __init__(self, label_dict, **kwargs):
        super(MLFlownnUNetHandler, self).__init__(**kwargs)
        self.label_dict = label_dict

    def _default_epoch_log(self, engine) -> None:
        """
        Execute epoch level log operation.
        Default to track the values from Ignite `engine.state.metrics` dict and
        track the values of specified attributes of `engine.state`.

        Args:
            engine: Ignite Engine, it can be a trainer, validator or evaluator.

        """
        log_dict = engine.state.metrics
        if not log_dict:
            return

        current_epoch = self.global_epoch_transform(engine.state.epoch)

        new_log_dict = {}

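        for metric in log_dict:
            if isinstance(log_dict[metric], torch.Tensor):
                # Per-class metric: one value per foreground class, so skip the
                # background entry (first key) of label_dict
                for i, val in enumerate(log_dict[metric]):
                    class_name = self.label_dict[list(self.label_dict.keys())[i + 1]]
                    new_log_dict[f"{metric}_{class_name}"] = float(val)
            else:
                new_log_dict[metric] = log_dict[metric]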
        self._log_metrics(new_log_dict, step=current_epoch)

        if self.state_attributes is not None:
            attrs = {attr: getattr(engine.state, attr, None) for attr in self.state_attributes}
            self._log_metrics(attrs, step=current_epoch)
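The per-class metric (reduction="mean_batch" with include_background=False) yields one value per foreground class, which is why the handler skips the first key of label_dict. The standalone function below reproduces that renaming on plain lists instead of tensors, so it can be checked without an MLFlow server. flatten_metrics is a hypothetical helper; like the handler, it assumes label_dict lists the background first and the classes in channel order.

```python
def flatten_metrics(metrics, label_dict):
    """Hypothetical helper: expand per-class metric vectors into scalar entries.

    Scalars pass through; sequences are assumed to hold one value per foreground
    class, in label_dict order with the background (first key) skipped.
    """
    foreground_names = [label_dict[k] for k in list(label_dict.keys())[1:]]
    flat = {}
    for name, value in metrics.items():
        if isinstance(value, (list, tuple)):
            for class_name, class_value in zip(foreground_names, value):
                flat[f"{name}_{class_name}"] = float(class_value)
        else:
            flat[name] = float(value)
    return flat


label_dict = {0: "background", 1: "Spleen"}
print(flatten_metrics({"Val_Dice": 0.91, "Val_Dice_Per_Class": [0.91]}, label_dict))
# {'Val_Dice': 0.91, 'Val_Dice_Per_Class_Spleen': 0.91}
```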
[ ]:
def create_mlflow_experiment_params(params_file, custom_params=None):
    params_dict = {}
    config_values = monai.config.deviceconfig.get_config_values()
    for k in config_values:
        params_dict[re.sub("[()]", " ", str(k))] = config_values[k]

    optional_config_values = monai.config.deviceconfig.get_optional_config_values()
    for k in optional_config_values:
        params_dict[re.sub("[()]", " ", str(k))] = optional_config_values[k]

    gpu_info = monai.config.deviceconfig.get_gpu_info()
    for k in gpu_info:
        params_dict[re.sub("[()]", " ", str(k))] = str(gpu_info[k])

    yaml_config_files = [params_file]
    monai_config = {}
    for config_file in yaml_config_files:
        with open(config_file, "r") as file:
            monai_config.update(yaml.safe_load(file))

    monai_config["bundle_root"] = str(Path(Path(params_file).parent).parent)

    parser = ConfigParser(monai_config, globals={"os": "os", "pathlib": "pathlib", "json": "json", "ignite": "ignite"})

    parser.parse(True)

    for k in monai_config:
        params_dict[k] = parser.get_parsed_content(k, instantiate=True)

    if custom_params is not None:
        for k in custom_params:
            params_dict[k] = custom_params[k]
    return params_dict
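create_mlflow_experiment_params replaces parentheses in the configuration keys with spaces because, as far as we know, MLflow only accepts parameter names made of letters, digits, underscores, dashes, periods, spaces and slashes (check the rules of your MLflow version). A slightly more general sketch, the hypothetical sanitize_param_keys, replaces every character outside that assumed set; the example key is illustrative.

```python
import re

# Characters outside this set are replaced (assumed MLflow parameter-name rule)
_DISALLOWED = re.compile(r"[^\w\-. /]")


def sanitize_param_keys(params):
    """Hypothetical helper: replace characters MLflow may reject in parameter names."""
    return {_DISALLOWED.sub(" ", str(key)): value for key, value in params.items()}


print(sanitize_param_keys({"GPU 0 Total memory (GB)": 24.0, "Pytorch": "2.3.0"}))
# {'GPU 0 Total memory  GB ': 24.0, 'Pytorch': '2.3.0'}
```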
[ ]:
%%writefile MONetBundle/mlflow_params.yaml

dataset_name_or_id: "009"
nnunet_trainer_class_name: "nnUNetTrainer"
nnunet_plans_identifier: "nnUNetPlans"

num_classes: 2
label_dict:
    0: "background"
    1: "spleen"

tracking_uri: "http://localhost:5000"
mlflow_experiment_name: "MONet_Bundle_Spleen"
mlflow_run_name: "MONet_Bundle_Spleen"



[ ]:
mlflow_experiment_name = "MONet_Bundle_Spleen"
mlflow_run_name = "MONet_Bundle_Spleen"
label_dict = {0: "background", 1: "Spleen"}
tracking_uri = "http://localhost:5000"

params_file = "MONetBundle/mlflow_params.yaml"


train_handlers.append(
    MLFlownnUNetHandler(
        dataset_dict={"train": train_dataset},
        dataset_keys=dataset_key,
        experiment_param=create_mlflow_experiment_params(params_file),
        experiment_name=mlflow_experiment_name,
        label_dict=label_dict,
        output_transform=mlflow_transform,
        run_name=mlflow_run_name,
        state_attributes=["best_metric", "best_metric_epoch"],
        tag_name="Train_Loss",
        tracking_uri=tracking_uri,
    )
)

val_handlers.append(
    MLFlownnUNetHandler(
        experiment_name=mlflow_experiment_name,
        iteration_log=False,
        label_dict=label_dict,
        output_transform=mlflow_transform,
        run_name=mlflow_run_name,
        state_attributes=["best_metric", "best_metric_epoch"],
        tracking_uri=tracking_uri,
    )
)

To start the MLFlow server, run the following command in a terminal (by default the server listens on http://localhost:5000, which matches tracking_uri above):

mkdir -p MLFlow && cd MLFlow && mlflow server

To run TensorBoard, use the following command:

tensorboard --logdir MONetBundle/logs
[ ]:
trainer = SupervisedTrainer(
    amp=True,
    device=device,
    epoch_length=iterations,
    loss_function=loss,
    max_epochs=epochs+2,
    network=network,
    prepare_batch=prepare_nnunet_batch,
    optimizer=optimizer,
    train_data_loader=train_dataloader,
    train_handlers=train_handlers,
    key_train_metric={"Train_Dice": train_key_metric},
    postprocessing=postprocessing,
    additional_metrics=additional_metrics,
)
[ ]:
trainer.run()

Create MONAI Bundle#

[ ]:
%%bash

python -m monai.bundle init_bundle MONetBundle

mkdir -p MONetBundle/nnUNet
mkdir -p MONetBundle/src
mkdir -p MONetBundle/nnUNet/evaluator
which tree && tree MONetBundle || true
[ ]:
%%writefile MONetBundle/configs/logging.conf
[loggers]
keys=root

[handlers]
keys=consoleHandler

[formatters]
keys=fullFormatter

[logger_root]
level=INFO
handlers=consoleHandler

[handler_consoleHandler]
class=StreamHandler
level=INFO
formatter=fullFormatter
args=(sys.stdout,)

[formatter_fullFormatter]
format=%(asctime)s - %(name)s - %(levelname)s - %(message)s

[ ]:
%%writefile MONetBundle/configs/metadata.json

{
    "schema": "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/meta_schema_20220324.json",
    "version": "0.1.0",
    "changelog": {
        "0.1.0": "Initial release"
    },
    "monai_version": "1.4.0",
    "pytorch_version": "2.3.0",
    "numpy_version": "1.21.2",
    "required_packages_version": {"nnunetv2": "2.6.0"},
    "task": "Decathlon spleen segmentation with nnUNet",
    "description": "A pre-trained nnUNet model for volumetric (3D) segmentation of the spleen from CT images",
    "authors": "Simone Bendazzoli",
    "copyright": "Copyright (c) MONAI Consortium",
    "data_source": "Task09_Spleen.tar from http://medicaldecathlon.com/",
    "data_type": "nifti",
    "image_classes": "single channel data, intensity scaled to [0, 1]",
    "label_classes": "single channel data, 1 is spleen, 0 is everything else",
    "pred_classes": "2 channels OneHot data, channel 1 is spleen, channel 0 is background",
    "eval_metrics": {
        "mean_dice": 0.97
    },
    "intended_use": "This is an example, not to be used for diagnostic purposes",
    "references": [
        "Isensee, F., Jaeger, P. F., Kohl, S. A., Petersen, J., & Maier-Hein, K. H. (2021). nnU-Net: a self-configuring method for deep learning-based biomedical image segmentation. Nature methods, 18(2), 203-211."
    ],
    "network_data_format":{
        "inputs": {
            "image": {
                "type": "image",
                "format": "hounsfield",
                "modality": "CT",
                "num_channels": 1,
                "spatial_shape": ["*", "*", "*"],
                "dtype": "float32",
                "value_range": [-1024, 1024],
                "is_patch_data": false,
                "channel_def": {"0": "image"}
            }
        },
        "outputs":{
            "pred": {
                "type": "image",
                "format": "segmentation",
                "num_channels": 2,
                "spatial_shape": ["*", "*", "*"],
                "dtype": "float32",
                "value_range": [0,1],
                "is_patch_data": false,
                "channel_def": {"0": "background", "1": "spleen"}
            }
        }
    }
}
[ ]:
%%writefile MONetBundle/nnUNet/global.yaml

iterations: $@nnunet_trainer.num_iterations_per_epoch
device: $@nnunet_trainer.device
epochs: $@nnunet_trainer.num_epochs

fold_id: 0

bundle_root: .
ckpt_dir: "$@bundle_root + '/models/fold_'+str(@fold_id)"
[ ]:
%%writefile MONetBundle/nnUNet/params.yaml


dataset_name_or_id: "100"
nnunet_trainer_class_name: "nnUNetTrainer"
nnunet_plans_identifier: "nnUNetPlans"
nnunet_configuration: "3d_fullres"

num_classes: 2
label_dict:
    0: "background"
    1: "class1"

tracking_uri: "http://localhost:5000"
mlflow_experiment_name: "nnUNet_Bundle"
mlflow_run_name: "nnUNet_Bundle"
log_dir: "$@bundle_root + '/logs'"
[ ]:
%%writefile MONetBundle/nnUNet/imports.yaml

imports:
- $import glob
- $import os
- $import ignite
- $import torch
- $import shutil
- $import json
- $import src
- $import nnunetv2
- $import src.mlflow
- $import src.trainer
- $from pathlib import Path
[ ]:
%%writefile MONetBundle/nnUNet/run.yaml

run:
- "$torch.save(@checkpoint,@checkpoint_filename)"
- "$shutil.copy(Path(@nnunet_model_folder).joinpath('dataset.json'), @bundle_root+'/models/dataset.json')"
- "$shutil.copy(Path(@nnunet_model_folder).joinpath('plans.json'), @bundle_root+'/models/plans.json')"
- "$@train#pbar.attach(@train#trainer,output_transform=lambda x: {'loss': x[0]['loss']})"
- "$@validate#pbar.attach(@validate#evaluator,metric_names=['Val_Dice'])"
- $@train#trainer.run()

initialize:
- $monai.utils.set_determinism(seed=123)
[ ]:
%%writefile MONetBundle/nnUNet/nnunet_trainer.yaml

nnunet_trainer:
  _target_ : get_nnunet_trainer
  dataset_name_or_id: "@dataset_name_or_id"
  configuration: "@nnunet_configuration"
  fold: '@fold_id'
  trainer_class_name: "@nnunet_trainer_class_name"
  plans_identifier: "@nnunet_plans_identifier"

loss: $@nnunet_trainer.loss
lr_scheduler: $@nnunet_trainer.lr_scheduler

network: $@nnunet_trainer.network

optimizer: $@nnunet_trainer.optimizer

checkpoint:
  init_args: '$@nnunet_trainer.my_init_kwargs'
  trainer_name: '$@nnunet_trainer.__class__.__name__'
  inference_allowed_mirroring_axes: '$@nnunet_trainer.inference_allowed_mirroring_axes'

checkpoint_filename: "$@bundle_root+'/models/nnunet_checkpoint.pth'"

dataset_name: "$nnunetv2.utilities.dataset_name_id_conversion.maybe_convert_to_dataset_name(@dataset_name_or_id)"
nnunet_model_folder: "$os.path.join(os.environ['nnUNet_results'], @dataset_name, @nnunet_trainer_class_name+'__'+@nnunet_plans_identifier+'__'+@nnunet_configuration)"
[ ]:
%%writefile MONetBundle/nnUNet/train_metrics.yaml

train_key_metric:
  Train_Dice:
    _target_: "MeanDice"
    include_background: False
    output_transform: "$monai.handlers.from_engine(['pred', 'label'])"
    reduction: "mean"

train_additional_metrics:
  Train_Dice_per_class:
    _target_: "MeanDice"
    include_background: False
    output_transform: "$monai.handlers.from_engine(['pred', 'label'])"
    reduction: "mean_batch"
[ ]:
%%writefile MONetBundle/nnUNet/train_postprocessing.yaml

train_postprocessing:
  _target_: "Compose"
  transforms:
  - _target_: Lambdad
    keys:
      - "pred"
      - "label"
    func: "$lambda x: x[0]"
  - _target_: Activationsd
    keys:
      - "pred"
    softmax: True
  - _target_: AsDiscreted
    keys:
      - "pred"
    threshold: 0.5
  - _target_: AsDiscreted
    keys:
      - "label"
    to_onehot: "@num_classes"

train_postprocessing_region_based:
  _target_: "Compose"
  transforms:
  - _target_: Lambdad
    keys:
      - "pred"
      - "label"
    func: "$lambda x: x[0]"
  - _target_: Activationsd
    keys:
      - "pred"
    sigmoid: True
  - _target_: AsDiscreted
    keys:
      - "pred"
    threshold: 0.5
[ ]:
%%writefile MONetBundle/nnUNet/train.yaml

dataset_key: "case_identifier"
train:
  pbar:
    _target_: "ignite.contrib.handlers.tqdm_logger.ProgressBar"
  dataloader: $@nnunet_trainer.dataloader_train
  train_data: "$[{'@dataset_key':k} for k in @nnunet_trainer.dataloader_train.generator._data.identifiers]"
  train_dataset:
    _target_: Dataset
    data: "@train#train_data"
  inferer:
    _target_: SimpleInferer
  trainer:
    _target_: SupervisedTrainer
    amp: true
    device: '@device'
    additional_metrics: "@train_additional_metrics"
    epoch_length: "@iterations"
    inferer: '@train#inferer'
    key_train_metric: '@train_key_metric'
    loss_function: '@loss'
    max_epochs: '@epochs'
    network: '@network'
    prepare_batch: "$src.trainer.prepare_nnunet_batch"
    optimizer: '@optimizer'
    postprocessing: '@train_postprocessing'
    train_data_loader: '@train#dataloader'
    train_handlers: '@train_handlers#handlers'
[ ]:
%%writefile MONetBundle/nnUNet/train_handlers.yaml

train_handlers:
  handlers:
  - _target_: "$src.mlflow.MLFlownnUNetHandler"
    label_dict: "@label_dict"
    tracking_uri: "@tracking_uri"
    experiment_name: "@mlflow_experiment_name"
    run_name: "@mlflow_run_name"
    output_transform: "$src.mlflow.mlflow_transform"
    dataset_dict:
        train: "@train#train_dataset"
    dataset_keys: '@dataset_key'
    state_attributes:
    - "iteration"
    - "epoch"
    tag_name: 'Train_Loss'
    experiment_param: "$src.mlflow.create_mlflow_experiment_params( @bundle_root + '/nnUNet/params.yaml')"
    #artifacts=None
    optimizer_param_names: 'lr'
    #close_on_complete: False
  - _target_: LrScheduleHandler
    lr_scheduler: '@lr_scheduler'
    print_lr: true
  - _target_: ValidationHandler
    epoch_level: true
    interval: '@val_interval'
    validator: '@validate#evaluator'
  #- _target_: StatsHandler
  #  output_transform: $monai.handlers.from_engine(['loss'], first=True)
  #  tag_name: train_loss
  - _target_: TensorBoardStatsHandler
    log_dir: '@log_dir'
    output_transform: $monai.handlers.from_engine(['loss'], first=True)
    tag_name: train_loss
[ ]:
%%writefile MONetBundle/configs/train_resume.yaml

run:
- '$src.trainer.reload_checkpoint(@train#trainer,@reload_checkpoint_epoch,@iterations,@ckpt_dir)'
- "$@train#pbar.attach(@train#trainer,output_transform=lambda x: {'loss': x[0]['loss']})"
- "$@validate#pbar.attach(@validate#evaluator,metric_names=['Val_Dice'])"
- $@train#trainer.run()

train_handlers:
  handlers:
  - _target_: "$src.mlflow.MLFlownnUNetHandler"
    label_dict: "@label_dict"
    tracking_uri: "@tracking_uri"
    experiment_name: "@mlflow_experiment_name"
    run_name: "@mlflow_run_name"
    output_transform: "$src.mlflow.mlflow_transform"
    dataset_dict:
        train: "@train#train_dataset"
    dataset_keys: '@dataset_key'
    state_attributes:
    - "iteration"
    - "epoch"
    tag_name: 'Train_Loss'
    experiment_param: "$src.mlflow.create_mlflow_experiment_params( @bundle_root + '/nnUNet/params.yaml')"
    #artifacts=None
    optimizer_param_names: 'lr'
    #close_on_complete: False
  - _target_: LrScheduleHandler
    lr_scheduler: '@lr_scheduler'
    print_lr: true
  - _target_: ValidationHandler
    epoch_level: true
    interval: '@val_interval'
    validator: '@validate#evaluator'
  #- _target_: StatsHandler
  #  output_transform: $monai.handlers.from_engine(['loss'], first=True)
  #  tag_name: train_loss
  - _target_: TensorBoardStatsHandler
    log_dir: '@log_dir'
    output_transform: $monai.handlers.from_engine(['loss'], first=True)
    tag_name: train_loss
  - _target_: CheckpointLoader
    load_dict:
      network_weights: '$@nnunet_trainer.network'
      optimizer_state: '$@nnunet_trainer.optimizer'
      scheduler: '$@nnunet_trainer.lr_scheduler'
    load_path: '$@ckpt_dir+"/checkpoint_epoch="+str(src.trainer.get_checkpoint(@reload_checkpoint_epoch, @ckpt_dir))+".pt"'
    map_location: '@device'
[ ]:
%%writefile MONetBundle/nnUNet/val_metrics.yaml

val_key_metric:
  Val_Dice:
    _target_: "MeanDice"
    output_transform: "$monai.handlers.from_engine(['pred', 'label'])"
    reduction: "mean"
    include_background: False

val_additional_metrics:
  Val_Dice_per_class:
    _target_: "MeanDice"
    output_transform: "$monai.handlers.from_engine(['pred', 'label'])"
    reduction: "mean_batch"
    include_background: False
[ ]:
%%writefile MONetBundle/nnUNet/val_handlers.yaml

val_handlers:
  handlers:
  - _target_: StatsHandler
    iteration_log: false
  - _target_: TensorBoardStatsHandler
    iteration_log: false
    log_dir: '@log_dir'
  - _target_: "$src.mlflow.MLFlownnUNetHandler"
    label_dict: "@label_dict"
    tracking_uri: "@tracking_uri"
    experiment_name: "@mlflow_experiment_name"
    run_name: "@mlflow_run_name"
    output_transform: "$src.mlflow.mlflow_transform"
    iteration_log: False
    state_attributes:
    - "best_metric"
    - "best_metric_epoch"
  - _target_: "CheckpointSaver"
    save_dir: "@ckpt_dir"
    save_interval: 1
    n_saved: 1
    save_key_metric: true
    save_dict:
      network_weights: '$@nnunet_trainer.network'
      optimizer_state: '$@nnunet_trainer.optimizer'
      scheduler: '$@nnunet_trainer.lr_scheduler'
[ ]:
%%writefile MONetBundle/nnUNet/validate.yaml

val_interval: 1
validate:
  pbar:
    _target_: "ignite.contrib.handlers.tqdm_logger.ProgressBar"
  dataloader: $@nnunet_trainer.dataloader_val
  evaluator:
    _target_: SupervisedEvaluator
    additional_metrics: '@val_additional_metrics'
    amp: true
    epoch_length: $@nnunet_trainer.num_val_iterations_per_epoch
    device: '@device'
    inferer: '@validate#inferer'
    key_val_metric: '@val_key_metric'
    network: '@network'
    postprocessing: '@train_postprocessing'
    val_data_loader: '@validate#dataloader'
    val_handlers: '@val_handlers#handlers'
    prepare_batch: "$src.trainer.prepare_nnunet_batch"
  inferer:
    _target_: SimpleInferer

[ ]:
%%writefile MONetBundle/nnUNet/evaluator/evaluator.yaml

run:
- "$@validate#pbar.attach(@validate#evaluator,metric_names=['Val_Dice'])"
- $@validate#evaluator.run()

initialize:
- "$setattr(torch.backends.cudnn, 'benchmark', True)"

Adding Python Utility Scripts#

Finally, we add the MLFlow and training utility modules referenced in the YAML configuration files (src.mlflow and src.trainer) to the MONAI Bundle.

[ ]:
%%writefile MONetBundle/src/__init__.py


[ ]:
%%writefile MONetBundle/src/mlflow.py

import re
from monai.handlers import MLFlowHandler
import yaml
from monai.bundle import ConfigParser
from pathlib import Path
import monai
import torch

def mlflow_transform(state_output):
    """
    Extracts the 'loss' value from the first element of the state_output list.

    Parameters
    ----------
    state_output : list of dict
        A list where each element is a dictionary containing various metrics, including 'loss'.

    Returns
    -------
    float
        The 'loss' value from the first element of the state_output list.
    """
    return state_output[0]['loss']

class MLFlownnUNetHandler(MLFlowHandler):
    """
    A handler for logging nnUNet metrics to MLFlow.
    Parameters
    ----------
    label_dict : dict
        A dictionary mapping label indices to label names.
    **kwargs : dict
        Additional keyword arguments passed to the parent class.
    """
    def __init__(self, label_dict, **kwargs):
        super(MLFlownnUNetHandler, self).__init__(**kwargs)
        self.label_dict = label_dict

    def _default_epoch_log(self, engine) -> None:
        """
        Logs the metrics and state attributes at the end of each epoch.

        Parameters
        ----------
        engine : Engine
            The engine object that contains the state and metrics to be logged.

        Returns
        -------
        None
        """
        log_dict = engine.state.metrics
        if not log_dict:
            return

        current_epoch = self.global_epoch_transform(engine.state.epoch)

        new_log_dict = {}

        for metric in log_dict:
            if isinstance(log_dict[metric], torch.Tensor):
                # Log one scalar per foreground class; entry 0 of label_dict (background) is skipped
                label_names = list(self.label_dict.values())
                for i, val in enumerate(log_dict[metric]):
                    new_log_dict[f"{metric}_{label_names[i + 1]}"] = val
            else:
                new_log_dict[metric] = log_dict[metric]
        self._log_metrics(new_log_dict, step=current_epoch)

        if self.state_attributes is not None:
            attrs = {attr: getattr(engine.state, attr, None) for attr in self.state_attributes}
            self._log_metrics(attrs, step=current_epoch)

def create_mlflow_experiment_params(params_file, custom_params=None):
    """
    Create a dictionary of parameters for an MLflow experiment.

    This function reads configuration values from MONAI, GPU information, and a YAML configuration file,
    and combines them into a single dictionary. Optionally, custom parameters can also be added to the dictionary.

    Parameters
    ----------
    params_file : str
        Path to the YAML configuration file.
    custom_params : dict, optional
        A dictionary of custom parameters to be added to the final parameters dictionary (default is None).

    Returns
    -------
    dict
        A dictionary containing all the combined parameters.
    """
    params_dict = {}
    config_values = monai.config.deviceconfig.get_config_values()
    for k in config_values:
        params_dict[re.sub("[()]"," ",str(k))] = config_values[k]

    optional_config_values = monai.config.deviceconfig.get_optional_config_values()
    for k in optional_config_values:
        params_dict[re.sub("[()]"," ",str(k))] = optional_config_values[k]

    gpu_info = monai.config.deviceconfig.get_gpu_info()
    for k in gpu_info:
        params_dict[re.sub("[()]"," ",str(k))] = str(gpu_info[k])

    yaml_config_files = [params_file]
    monai_config = {}
    for config_file in yaml_config_files:
        with open(config_file, 'r') as file:
            monai_config.update(yaml.safe_load(file))

    monai_config["bundle_root"] = str(Path(params_file).parent.parent)

    parser = ConfigParser(monai_config, globals={"os": "os",
                                                 "pathlib": "pathlib",
                                                 "json": "json",
                                                 "ignite": "ignite"
                                                 })

    parser.parse(True)

    for k in monai_config:
        params_dict[k] = parser.get_parsed_content(k,instantiate=True)

    if custom_params is not None:
        for k in custom_params:
            params_dict[k] = custom_params[k]
    return params_dict
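
The renaming performed in MLFlownnUNetHandler._default_epoch_log can be checked without MLflow or a GPU. The sketch below is a plain-Python re-implementation of that loop, not the handler itself: lists stand in for the per-class torch.Tensor values, and the label_dict and metric values are made up for illustration (Val_Dice is the metric name attached to the progress bar in evaluator.yaml).

```python
def flatten_per_class_metrics(log_dict, label_dict):
    """Expand per-class metric sequences into one scalar entry per foreground label.

    Element i of a per-class metric is paired with the (i + 1)-th label of
    label_dict, so the background label at position 0 is skipped, as in
    MLFlownnUNetHandler._default_epoch_log.
    """
    label_names = list(label_dict.values())
    flat = {}
    for metric, value in log_dict.items():
        if isinstance(value, (list, tuple)):  # stands in for a torch.Tensor of per-class values
            for i, val in enumerate(value):
                flat[f"{metric}_{label_names[i + 1]}"] = val
        else:
            flat[metric] = value
    return flat


label_dict = {0: "background", 1: "spleen"}
metrics = {"Val_Dice": [0.91], "Val_Loss": 0.12}
print(flatten_per_class_metrics(metrics, label_dict))
# {'Val_Dice_spleen': 0.91, 'Val_Loss': 0.12}
```

Skipping index 0 assumes the metric tensor excludes the background channel (as MeanDice does with include_background=False); if it does not, the names shift by one class.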

[ ]:
%%writefile MONetBundle/src/trainer.py

import os

def subfiles(directory, prefix=None, suffix=None, join=True, sort=True):
    """
    List files in a directory with optional filtering by prefix and/or suffix.

    Parameters
    ----------
    directory : str
        The path to the directory to list files from.
    prefix : str, optional
        If specified, only files starting with this prefix will be included.
    suffix : str, optional
        If specified, only files ending with this suffix will be included.
    join : bool, optional
        If True, the directory path will be joined with the filenames. Default is True.
    sort : bool, optional
        If True, the list of files will be sorted. Default is True.

    Returns
    -------
    list of str
        A list of filenames (with full paths if `join` is True) that match the specified criteria.
    """


    files = [f for f in os.listdir(directory) if os.path.isfile(os.path.join(directory, f))]
    if prefix is not None:
        files = [f for f in files if f.startswith(prefix)]
    if suffix is not None:
        files = [f for f in files if f.endswith(suffix)]
    if join:
        files = [os.path.join(directory, f) for f in files]
    if sort:
        files.sort()
    return files

def prepare_nnunet_batch(batch, device, non_blocking):
    """
    Prepares a batch of data and targets for nnU-Net training by transferring them to the specified device.

    Parameters
    ----------
    batch : dict
        A dictionary containing the data and target tensors. The key "data" corresponds to the input data tensor,
        and the key "target" corresponds to the target tensor or a list of target tensors.
    device : torch.device
        The device to which the data and target tensors should be transferred (e.g., 'cuda' or 'cpu').
    non_blocking : bool
        If True, allows non-blocking data transfer to the device.

    Returns
    -------
    tuple
        A tuple containing the data tensor and the target tensor(s) after being transferred to the specified device.
    """
    data = batch["data"].to(device, non_blocking=non_blocking)
    if isinstance(batch["target"], list):
        target = [i.to(device, non_blocking=non_blocking) for i in batch["target"]]
    else:
        target = batch["target"].to(device, non_blocking=non_blocking)
    return data, target

def get_checkpoint(epoch, ckpt_dir):
    """
    Retrieves the checkpoint for a given epoch from the checkpoint directory.

    Parameters
    ----------
    epoch : int or str
        The epoch number to retrieve. If 'latest', the function will return the latest checkpoint.
    ckpt_dir : str
        The directory where checkpoints are stored.

    Returns
    -------
    int
        The epoch number of the checkpoint to be retrieved. If 'latest', returns the latest epoch number.
    """
    if epoch == "latest":

        latest_checkpoints = subfiles(ckpt_dir, prefix="checkpoint_epoch", sort=True,
                                      join=False)
        epochs = []
        for latest_checkpoint in latest_checkpoints:
            epochs.append(int(latest_checkpoint[len("checkpoint_epoch="):-len(".pt")]))

        epochs.sort()
        latest_epoch = epochs[-1]
        return latest_epoch
    else:
        return epoch

def reload_checkpoint(trainer, epoch, num_train_batches_per_epoch, ckpt_dir, lr_scheduler=None):
    """
    Reloads the checkpoint for a given epoch and updates the trainer's state.

    Parameters
    ----------
    trainer : object
        The trainer object whose state needs to be updated.
    epoch : int
        The epoch number to load the checkpoint from.
    num_train_batches_per_epoch : int
        The number of training batches per epoch.
    ckpt_dir : str
        The directory where the checkpoints are stored.
    lr_scheduler : object, optional
        The learning rate scheduler to be updated (default is None).

    Returns
    -------
    None
    """

    epoch_to_load = get_checkpoint(epoch, ckpt_dir)
    trainer.state.epoch = epoch_to_load
    trainer.state.iteration = epoch_to_load * num_train_batches_per_epoch + 1
    if lr_scheduler is not None:
        lr_scheduler.ctr = epoch_to_load
        lr_scheduler.step(epoch_to_load)

[ ]:
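import json
import os

import yaml


def create_config(config_folder, output_file):
    """Merge the top-level YAML fragments of config_folder into one config and write it as .yaml or .json."""
    # os.scandir returns entries in arbitrary order; sorting makes the merge deterministic when
    # two fragments define the same top-level key (the file read last wins)
    config_files = sorted(f.path for f in os.scandir(config_folder) if f.path.endswith(".yaml"))
    config = {}
    for config_file in config_files:
        with open(config_file, "r") as file:
            config.update(yaml.safe_load(file) or {})  # an empty fragment loads as None

    if output_file.endswith(".yaml"):
        with open(output_file, "w") as file:
            yaml.dump(config, file)
    elif output_file.endswith(".json"):
        with open(output_file, "w") as file:
            json.dump(config, file)

    return config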
[ ]:
import os
import yaml
config = create_config("MONetBundle/nnUNet", "MONetBundle/configs/train.yaml")
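
create_config performs a shallow merge: dict.update replaces top-level keys wholesale, so a fragment that redefines a key such as validate drops every nested entry coming from an earlier file. The sketch below reproduces that behavior on in-memory dictionaries (zz_override.yaml is a hypothetical file name) and sorts the file names, since os.scandir alone gives no ordering guarantee.

```python
def merge_fragments(fragments_by_name):
    """Shallow merge of config fragments in sorted file-name order, mirroring create_config's dict.update loop."""
    config = {}
    for name in sorted(fragments_by_name):
        config.update(fragments_by_name[name])  # a later fragment replaces a whole top-level key
    return config


fragments = {
    "validate.yaml": {
        "val_interval": 1,
        "validate": {"dataloader": "$@nnunet_trainer.dataloader_val", "inferer": {"_target_": "SimpleInferer"}},
    },
    # Hypothetical second fragment that only meant to change the inferer
    "zz_override.yaml": {"validate": {"inferer": {"_target_": "SlidingWindowInferer"}}},
}
merged = merge_fragments(fragments)
print(merged["val_interval"], merged["validate"])
# 1 {'inferer': {'_target_': 'SlidingWindowInferer'}}
```

The dataloader entry is lost in the merged result, so overrides of nested settings are better kept in a separate config file passed to `monai.bundle run` alongside train.yaml, as in the commented resume option.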
[ ]:
%%bash

export nnUNet_results=/home/maia-user/Documents/MONAI/Data/nnUNet/nnUNet_trained_models
export nnUNet_raw=/home/maia-user/Documents/MONAI/Data/nnUNet/nnUNet_raw_data_base
export nnUNet_preprocessed=/home/maia-user/Documents/MONAI/Data/nnUNet/nnUNet_preprocessed
export nnUNet_def_n_proc=2
export nnUNet_n_proc_DA=2
export BUNDLE_ROOT=MONetBundle
export PYTHONPATH=$PYTHONPATH:$BUNDLE_ROOT

python -m monai.bundle run \
    --bundle_root $BUNDLE_ROOT \
    --reload_checkpoint_epoch "latest" \
    --iterations 10 \
    --epochs 10 \
    --config_file $BUNDLE_ROOT/configs/train.yaml


# Option to resume training
#--config_file "['$BUNDLE_ROOT/configs/train.yaml','$BUNDLE_ROOT/configs/train_resume.yaml']"
#
# Log to Local MLFlow
#--tracking_uri mlruns

Inference#

After training the nnUNet model, we can perform inference on new data. We use the ModelnnUNetWrapper to wrap the nnUNet predictor so that it can be called from the MONAI Bundle: the nnUNet preprocessing, inference and postprocessing steps are handled inside the wrapper, and the Bundle blocks only need to load the input data, pass it to the nnUNet block and postprocess the nnUNet predictions.

The ModelnnUNetWrapper receives as input the image batch loaded by the DataLoader and returns the model predictions as a MetaTensor.

To get the ModelnnUNetWrapper object, we can use the get_nnunet_monai_predictor function, which receives the following parameters:

  • model_folder: The path to the nnUNet model folder.

  • model_name: [Optional] The name of the checkpoint file to load from model_folder. If not provided, the function loads the checkpoint named model.pt.

[ ]:
# Select the latest checkpoint

from MONetBundle.src.trainer import get_checkpoint

ckpt_epoch = get_checkpoint("latest", "MONetBundle/models/fold_0")
[ ]:
nnunet_config = {
    "model_folder": "MONetBundle/models/fold_0",
}

monai_predictor = get_nnunet_monai_predictor(**nnunet_config, model_name=f"checkpoint_epoch={ckpt_epoch}.pt")

Test Data Preparation#

The Bundle expects the test dataset to contain one subfolder per case, each holding an image named after its subfolder:

Dataset
├── Case1
│   └── Case1.nii.gz
├── Case2
│   └── Case2.nii.gz
└── Case3
    └── Case3.nii.gz
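
To see which data dictionaries this layout produces, it can be recreated in a temporary directory with empty placeholder files. The sketch below re-declares the get_subfolder_dataset helper used later in this notebook, with one change that is our own addition: the case folders are sorted, because os.scandir returns them in arbitrary order.

```python
import os
import tempfile
from pathlib import Path


def get_subfolder_dataset(data_dir, modality_conf):
    """One dict per case subfolder: {modality_key: <data_dir>/<case>/<case><suffix>}."""
    with os.scandir(data_dir) as entries:
        case_dirs = sorted((e for e in entries if e.is_dir()), key=lambda e: e.name)
    return [
        {key: str(Path(d.path) / (d.name + conf["suffix"])) for key, conf in modality_conf.items()}
        for d in case_dirs
    ]


with tempfile.TemporaryDirectory() as root:
    # Placeholder files named like the tree above; the function only builds paths, it does not open them
    for case in ["Case1", "Case2", "Case3"]:
        (Path(root) / case).mkdir()
        (Path(root) / case / f"{case}.nii.gz").touch()
    for item in get_subfolder_dataset(root, {"image": {"suffix": ".nii.gz"}}):
        print(Path(item["image"]).relative_to(root).as_posix())
# Case1/Case1.nii.gz
# Case2/Case2.nii.gz
# Case3/Case3.nii.gz
```

Loose files next to the case folders are ignored, and a missing image only surfaces later, when LoadImaged tries to read it.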
[ ]:
%%bash

mkdir -p MONetBundle/test_input/spleen_1
mkdir -p MONetBundle/test_output

cp /home/maia-user/Documents/MONAI/Data/Task09_Spleen/imagesTs/spleen_1.nii.gz MONetBundle/test_input/spleen_1
[ ]:
%%bash

tree MONetBundle/test_input

Data Loading#

[ ]:
def get_subfolder_dataset(data_dir, modality_conf):
    data_list = []
    for f in os.scandir(data_dir):

        if f.is_dir():
            subject_dict = {
                key: str(pathlib.Path(f.path).joinpath(f.name + modality_conf[key]["suffix"])) for key in modality_conf
            }
            data_list.append(subject_dict)
    return data_list
[ ]:
modalities = {
    "image": {"suffix": ".nii.gz"},
}

data = get_subfolder_dataset("MONetBundle/test_input", modalities)
[ ]:
preprocessing = LoadImaged(keys=["image"], ensure_channel_first=True, image_only=False)


test_dataset = Dataset(data, transform=preprocessing)

test_loader = DataLoader(test_dataset, batch_size=1)

Test ModelnnUNetWrapper#

To test the ModelnnUNetWrapper, we pass it a single test case from the DataLoader and inspect the predictions it returns.

[ ]:
batch = next(iter(test_loader))

pred = monai_predictor(batch["image"])

Postprocessing and Save Predictions#

After obtaining the model predictions, we can apply postprocessing transformations to the predictions and save the results to disk.

nnUNet uses the zyx axis order, while MONAI uses the xyz axis order. The Transposed transform unifies the two conventions. It is commented out in the cell below; enable it if the predictions returned by the wrapper are still in nnUNet's zyx order.
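
Transposed follows numpy.transpose semantics: output axis k is input axis indices[k]. The sketch below applies indices=[0, 3, 2, 1] to a shape tuple only, keeping the channel axis first and reversing the three spatial axes. The volume shape is made up for illustration.

```python
def permute_shape(shape, indices):
    """Shape after a Transposed(indices=...) step: new axis k is old axis indices[k]."""
    return tuple(shape[i] for i in indices)


# A single-channel, channel-first volume in nnUNet's (C, Z, Y, X) order (illustrative sizes)
zyx_shape = (1, 90, 512, 384)
xyz_shape = permute_shape(zyx_shape, [0, 3, 2, 1])
print(xyz_shape)  # (1, 384, 512, 90)

# Reversing the spatial axes is its own inverse, so applying it twice restores the input
print(permute_shape(xyz_shape, [0, 3, 2, 1]) == zyx_shape)  # True
```

Because the permutation is self-inverse, the same indices convert in either direction; applying it one time too many silently swaps the axes back, which is why it should only be enabled when the output is actually in zyx order.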

[ ]:
postprocessing = Compose(
    [
        #Decollated(keys=None, detach=True),
        #Transposed(keys="pred", indices=[0, 3, 2, 1]),
        SaveImaged(
            keys="pred",
            output_dir="MONetBundle/test_output",
            output_postfix="prediction",
            meta_keys="image_meta_dict",
        ),
    ]
)
[ ]:
postprocessing({"pred": pred})

Evaluator#

Combining everything, we can create an Evaluator that encapsulates the data loading, model inference, postprocessing and evaluation steps. The Evaluator can then be run on the whole test dataset.

[ ]:
validator = SupervisedEvaluator(
    val_data_loader=test_loader, device="cuda:0", network=monai_predictor, postprocessing=postprocessing
)
[ ]:
validator.run()
[ ]:
%%writefile MONetBundle/configs/inference.yaml

imports:
  - $import json
  - $from pathlib import Path
  - $import os
  - $from ignite.contrib.handlers.tqdm_logger import ProgressBar
  - $import shutil
  - $import src
  - $import src.dataset


output_dir: "."
bundle_root: "."
data_list_file: "."
data_dir: "."

fold_id: 0
model_name: "model.pt"
prediction_suffix: "prediction"


modality_conf:
  image:
    suffix: ".nii.gz"

test_data_list: "$src.dataset.get_subfolder_dataset(@data_dir,@modality_conf)"
#test_data_list: "$monai.data.load_decathlon_datalist(@data_list_file, is_segmentation=True, data_list_key='testing', base_dir=@data_dir)"
image_modality_keys: "$list(@modality_conf.keys())"
image_key: "image"
image_suffix: "@image_key"

preprocessing:
  _target_: Compose
  transforms:
  - _target_: LoadImaged
    keys: "image"
    ensure_channel_first: True
    image_only: False

test_dataset:
  _target_: Dataset
  data: "$@test_data_list"
  transform: "@preprocessing"

test_loader:
  _target_: DataLoader
  dataset: "@test_dataset"
  batch_size: 1


device: "$torch.device('cuda')"

nnunet_trainer_class_name: nnUNetTrainer
plans:
dataset_json:

nnunet_config_dict:
  model_folder: "$@bundle_root + '/models/fold_'+str(@fold_id)"
  model_name: "@model_name"
  nnunet_config: "@nnunet_config_ckpt"
  plans: "@plans"
  dataset_json: "@dataset_json"

network_def: "$monai.apps.nnunet.nnunet_bundle.get_nnunet_monai_predictor(**@nnunet_config_dict)"

postprocessing:
  _target_: "Compose"
  transforms:
    #- _target_: Transposed
    #  keys: "pred"
    #  indices:
    #  - 0
    #  - 3
    #  - 2
    #  - 1
    - _target_: SaveImaged
      keys: "pred"
      resample: False
      output_postfix: "@prediction_suffix"
      output_dir: "@output_dir"
      meta_keys: "image_meta_dict"


testing:
  dataloader: "$@test_loader"
  pbar:
    _target_: "ignite.contrib.handlers.tqdm_logger.ProgressBar"
  test_inferer: "$@inferer"

inferer:
  _target_: "SimpleInferer"

validator:
  _target_: "SupervisedEvaluator"
  postprocessing: "$@postprocessing"
  device: "$@device"
  inferer: "$@testing#test_inferer"
  val_data_loader: "$@testing#dataloader"
  network: "@network_def"
  val_handlers:
  - _target_: "CheckpointLoader"
    load_path: "$@bundle_root+'/models/fold_'+str(@fold_id)+'/'+@model_name"
    load_dict:
      network_weights: '$@network_def.network_weights'
run:
  - "$@testing#pbar.attach(@validator)"
  - "$@validator.run()"

nnunet_config_ckpt:
  trainer_name: "@nnunet_trainer_class_name"
  inference_allowed_mirroring_axes:
  - 0
  - 1
  - 2
  configuration: "3d_fullres"
[ ]:
%%writefile MONetBundle/src/dataset.py

import pathlib
import os

def get_subfolder_dataset(data_dir, modality_conf):
    data_list = []
    for f in os.scandir(data_dir):
        if f.is_dir():
            subject_dict = {
                key: str(pathlib.Path(f.path).joinpath(f.name + modality_conf[key]["suffix"])) for key in modality_conf
            }
            data_list.append(subject_dict)
    return data_list
[ ]:
%%bash

export BUNDLE_ROOT=MONetBundle
export PYTHONPATH=$PYTHONPATH:$BUNDLE_ROOT

python -m monai.bundle run \
    --config-file $BUNDLE_ROOT/configs/inference.yaml \
    --bundle-root $BUNDLE_ROOT \
    --data-dir $HOME/Data/Samples/NIFTI/Spleen \
    --output-dir $HOME/Data/Samples/NIFTI/Spleen_pred \
    --model-name "checkpoint_epoch=1000.pt" \
    --logging-file $BUNDLE_ROOT/configs/logging.conf

Utilities#

MONAI Bundle to nnUNet Conversion#

To convert a MONAI Bundle to an nnUNet model, we need to combine the MONAI checkpoint with the nnUNet checkpoint: the nnUNet checkpoint is loaded and its model weights are replaced with the weights stored in the MONAI checkpoint.
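
The weight swap can be sketched with plain dictionaries. This is not the implementation of convert_monai_bundle_to_nnunet. It assumes that both checkpoints store the model state dict under "network_weights" (the key used in the CheckpointSaver save_dict of this Bundle and, as far as we know, in nnUNet v2 checkpoints), and that a MONAI checkpoint saved from a torch.compile-wrapped network carries the "_orig_mod." key prefix, which is stripped here. Strings stand in for tensors; with real files the dictionaries would come from torch.load and the result would be written back with torch.save.

```python
def merge_monai_weights_into_nnunet(nnunet_ckpt, monai_ckpt, strip_prefix="_orig_mod."):
    """Copy of nnunet_ckpt whose "network_weights" are taken from monai_ckpt["network_weights"].

    strip_prefix removes the key prefix that torch.compile adds to a wrapped module's
    state_dict (assumption: the MONAI checkpoint may come from a compiled network).
    """
    weights = {}
    for key, value in monai_ckpt["network_weights"].items():
        if strip_prefix and key.startswith(strip_prefix):
            key = key[len(strip_prefix):]
        weights[key] = value
    merged = dict(nnunet_ckpt)  # shallow copy: keep nnUNet-side entries such as trainer_name and init_args
    merged["network_weights"] = weights
    return merged


nnunet_ckpt = {"network_weights": {"conv.weight": "old"}, "trainer_name": "nnUNetTrainer", "init_args": {}}
monai_ckpt = {"network_weights": {"_orig_mod.conv.weight": "new"}, "optimizer_state": {}}
merged = merge_monai_weights_into_nnunet(nnunet_ckpt, monai_ckpt)
print(merged["network_weights"], merged["trainer_name"])
# {'conv.weight': 'new'} nnUNetTrainer
```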

[ ]:
os.environ["nnUNet_results"] = "MONAI/Data/nnUNet/nnUNet_trained_models"
os.environ["nnUNet_raw"] = "MONAI/Data/nnUNet/nnUNet_raw_data_base"
os.environ["nnUNet_preprocessed"] = "MONAI/Data/nnUNet/nnUNet_preprocessed"

nnunet_config = {
    "dataset_name_or_id": "009",
    "nnunet_trainer": "nnUNetTrainer",
}

convert_monai_bundle_to_nnunet(nnunet_config, "MONetBundle")

Testing the nnUNet Model#

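We now test the converted nnUNet model with the nnUNetV2Runner. Passing val="" to train_single_model adds nnUNetv2_train's --val flag, so the existing fold 0 checkpoint is only validated rather than retrained; the runner then selects the best configuration and predicts on the test dataset.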

[ ]:
root_dir = "MONAI/Data"
nnunet_root_dir = os.path.join(root_dir, "nnUNet")

os.makedirs(nnunet_root_dir, exist_ok=True)

data_src_cfg = os.path.join(nnunet_root_dir, "data_src_cfg.yaml")
data_src = {
    "modality": "CT",
    "dataset_name_or_id": "09",
    "datalist": os.path.join(root_dir, "Task09_Spleen/msd_task09_spleen_folds.json"),
    "dataroot": os.path.join(root_dir, "Task09_Spleen"),
}

ConfigParser.export_config_file(data_src, data_src_cfg)

runner = nnUNetV2Runner(input_config=data_src_cfg, trainer_class_name="nnUNetTrainer", work_dir=nnunet_root_dir)
[ ]:
runner.train_single_model(config="3d_fullres", fold=0, val="")
[ ]:
runner.find_best_configuration(configs=["3d_fullres"], folds=[0], allow_ensembling=False, num_processes=1)
[ ]:
runner.predict_ensemble_postprocessing(folds=[0], run_ensemble=False, run_postprocessing=False)

nnUNet to MONAI Bundle Conversion#

To convert an nnUNet-trained model to a MONAI Bundle, we need to separate the MONAI checkpoint from the nnUNet checkpoint: the nnUNet checkpoint is loaded and its model weights are used to update the weights in the MONAI checkpoint.
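
The reverse direction can be sketched in the same plain-dictionary style. The nnUNet settings kept aside are the ones that inference.yaml above expects under nnunet_config_ckpt (trainer_name, inference_allowed_mirroring_axes, configuration). The nnUNet checkpoint keys read here (trainer_name, inference_allowed_mirroring_axes, init_args) are assumptions about the nnUNet v2 checkpoint layout, and this is not the code of convert_nnunet_to_monai_bundle.

```python
def split_nnunet_checkpoint(nnunet_ckpt):
    """Split an nnUNet checkpoint into a MONAI-style weights checkpoint and the nnUNet inference settings."""
    monai_ckpt = {"network_weights": nnunet_ckpt["network_weights"]}
    nnunet_config = {
        "trainer_name": nnunet_ckpt["trainer_name"],
        "inference_allowed_mirroring_axes": nnunet_ckpt["inference_allowed_mirroring_axes"],
        "configuration": nnunet_ckpt["init_args"]["configuration"],
    }
    return monai_ckpt, nnunet_config


# Strings stand in for tensors; a real checkpoint would be read with torch.load
ckpt = {
    "network_weights": {"conv.weight": "w"},
    "optimizer_state": {},
    "trainer_name": "nnUNetTrainer_10epochs",
    "inference_allowed_mirroring_axes": (0, 1, 2),
    "init_args": {"configuration": "3d_fullres", "fold": 0},
}
weights, cfg = split_nnunet_checkpoint(ckpt)
print(weights)  # {'network_weights': {'conv.weight': 'w'}}
print(cfg["configuration"], cfg["trainer_name"])  # 3d_fullres nnUNetTrainer_10epochs
```

Keeping only the weights in the MONAI checkpoint matches the inference.yaml CheckpointLoader, whose load_dict contains just network_weights.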

[ ]:
os.environ["nnUNet_results"] = "MONAI/Data/nnUNet/nnUNet_trained_models"
os.environ["nnUNet_raw"] = "MONAI/Data/nnUNet/nnUNet_raw_data_base"
os.environ["nnUNet_preprocessed"] = "MONAI/Data/nnUNet/nnUNet_preprocessed"

nnunet_config = {
    "dataset_name_or_id": "009",
    "nnunet_trainer": "nnUNetTrainer_10epochs",
}

bundle_root = "MONetBundle"

convert_nnunet_to_monai_bundle(nnunet_config, bundle_root, 0)

Integration with NVFlare#

At the beginning of the NVFlare ScatterAndGather workflow, the server creates the global model and distributes it to the clients for local training. When using nnUNet in federated learning, the global model must match the chosen nnUNet model architecture. For this reason, we adapt the nnUNet MONAI Bundle on the server side so that it can create the global model, and distribute it to the clients, starting from the nnUNet plans and dataset files produced during the nnUNet plan-and-preprocess phase.

In train.yaml:

network: $@nnunet_trainer.network._orig_mod
network_def_fl:
  _target_: $monai.apps.nnunet.nnunet_bundle.get_network_from_nnunet_plans
  plans_file: "$@bundle_root+'/models/plans.json'"
  dataset_file: "$@bundle_root+'/models/dataset.json'"
  configuration: '@nnunet_configuration'

Integration with MONAI Deploy#

When using the nnUNet MONAI Bundle with MONAI Deploy, we need to specify where the checkpoint weights are loaded in the nnUNet network definition.

In inference.yaml:

network_def_predictor: "$@network_def.network_weights"

Integration with MONAI Label#

displayable_configs:
  dataset_name_or_id: '@dataset_name_or_id'
  fold_id: '@fold_id'
  mlflow_run_name: '@mlflow_run_name'
  nnunet_configuration: '@nnunet_configuration'
  nnunet_plans_identifier: '@nnunet_plans_identifier'
  nnunet_trainer_class_name: '@nnunet_trainer_class_name'
  num_classes: '@num_classes'
  region_class_order: ''
  tracking_experiment_name: '@mlflow_experiment_name'
  tracking_uri: '@tracking_uri'
  modality_list: 'CT'
  Label_0: '@label_dict.0'
  iterations: '@iterations'

Prepare the Bundle for Packaging#

To prepare the Bundle for packaging, we need to create a metadata.json file that describes the Bundle and its contents. The metadata.json file should follow the official MONAI Bundle format and include the following fields:

{
    "schema": "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/meta_schema_20220324.json",
    "version": "0.1.0",
    "changelog": {
        "0.1.0": "Initial release"
    },
    "monai_version": "1.4.0",
    "pytorch_version": "2.3.0",
    "numpy_version": "1.21.2",
    "required_packages_version": {"nnunetv2": "2.6.0"},
    "task": "Decathlon spleen segmentation with nnUNet",
    "description": "A pre-trained nnUNet model for volumetric (3D) segmentation of the spleen from CT images",
    "authors": "Simone Bendazzoli",
    "copyright": "Copyright (c) MONAI Consortium",
    "data_source": "Task09_Spleen.tar from http://medicaldecathlon.com/",
    "data_type": "nifti",
    "image_classes": "single channel data, intensity scaled to [0, 1]",
    "label_classes": "single channel data, 1 is spleen, 0 is everything else",
    "pred_classes": "2 channels OneHot data, channel 1 is spleen, channel 0 is background",
    "eval_metrics": {
        "mean_dice": 0.97
    },
    "intended_use": "This is an example, not to be used for diagnostic purposes",
    "references": [
        "Isensee, F., Jaeger, P. F., Kohl, S. A., Petersen, J., & Maier-Hein, K. H. (2021). nnU-Net: a self-configuring method for deep learning-based biomedical image segmentation. Nature methods, 18(2), 203-211."
    ],
    "network_data_format":{
        "inputs": {
            "image": {
                "type": "image",
                "format": "hounsfield",
                "modality": "CT",
                "num_channels": 1,
                "spatial_shape": ["*", "*", "*"],
                "dtype": "float32",
                "value_range": [-1024, 1024],
                "is_patch_data": false,
                "channel_def": {"0": "image"}
            }
        },
        "outputs":{
            "pred": {
                "type": "image",
                "format": "segmentation",
                "num_channels": 1,
                "spatial_shape": ["*", "*", "*"],
                "dtype": "float32",
                "value_range": [0,1],
                "is_patch_data": false,
                "channel_def": {"0": "background", "1": "spleen"}
            }
        }
    }
}

For more details on the MONAI Bundle format, please refer to the MONAI Bundle documentation.
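
Before packaging, a quick sanity check is to parse metadata.json with the standard json module, which rejects common hand-editing mistakes such as trailing commas. The sketch below also reports missing top-level fields; REQUIRED_KEYS is an illustrative subset chosen for this example, not the full list required by the schema.

```python
import json

# Illustrative subset of top-level fields; the full requirements come from the schema referenced in "schema"
REQUIRED_KEYS = ("version", "monai_version", "pytorch_version", "network_data_format")


def check_metadata(text):
    """Parse metadata.json content and return the keys from REQUIRED_KEYS that are missing."""
    meta = json.loads(text)  # raises json.JSONDecodeError on invalid JSON, e.g. a trailing comma
    return [key for key in REQUIRED_KEYS if key not in meta]


good = json.dumps({
    "version": "0.1.0",
    "changelog": {"0.1.0": "Initial release"},
    "monai_version": "1.4.0",
    "pytorch_version": "2.3.0",
    "network_data_format": {"inputs": {}, "outputs": {}},
})
print(check_metadata(good))  # []

try:
    check_metadata('{"changelog": {"0.1.0": "Initial release",}}')
except json.JSONDecodeError as err:
    print("metadata.json is not valid JSON:", err)
```

For a real Bundle, pass the file content instead, e.g. Path("MONetBundle/configs/metadata.json").read_text(), assuming the usual configs/metadata.json location. Validation against the full schema is what MONAI's verify_metadata bundle command is for.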

Generate TorchScript#

To convert the MONAI Bundle checkpoints to the TorchScript format, you can use the convert_ckpt_to_ts.py script. This script takes the MONAI Bundle checkpoint and converts it to the TorchScript format, which can be used for inference in production environments. The script accepts the following parameters:

python convert_ckpt_to_ts.py --bundle_root <path_to_bundle> --checkpoint_name <checkpoint_name> --nnunet_trainer_name <trainer_name> --fold <fold_number>