import numpy as np
import pandas as pd
import monai
from fastai.vision.all import *
import segmentation_models_pytorch as smp
Glomerulus Segmentation
Training an Unet++ with fastai to segment glomeruli.
Model Training Notebook
Here we train a Unet++ architecture with a pretrained efficientnet-b4 backbone from the awesome segmentation_models_pytorch library. Data loading, transforming, the actual training and, finally, exporting the newly trained model, we use fastai.
Here’s a running example of the model trained with this notebook.
Goal
Segmenting glomeruli (an intricate structure in the kidney’s cortex in which the blood filtration happen), i.e. turning this
to this
Data
I am using training data from the HuBMAP Challenge hosted on kaggle and a few dozen images downloaded from the Human Protein Atlas I annotated myself. (If you’re interested in how to do this, here a blogpost I wrote)
= "kidney"
ORGAN = 8 # Reduce this, if you run out of cuda memory
TRAIN_BATCH_SIZE = 8 # This is the batch size that will be used for training
EFFECTIVE_BATCH_SIZE = 512
IMAGE_SIZE = 3e-4
LR = 60
EPOCHS = f"smp_{IMAGE_SIZE}_{ORGAN}_added_data"
MODEL_NAME = Path("../data/") DATA_PATH
# Reproducibility
= 93
TESTSET_SEED = 43 TRAIN_VAL_SEED
= pd.read_csv(DATA_PATH/"train.csv") # This is the training set from the competition
df = L([*get_image_files(DATA_PATH/"test_images"), *get_image_files(DATA_PATH/"train_images")]) # List of all competition images
fns = [] # This will be a column in the dataframe, containing the filenames
fn_col for _, r in df.iterrows(): fn_col.append([fn for fn in fns if str(r["id"]) == fn.stem][0])
"fnames"] = fn_col
df["is_organ"] = df.organ.apply(lambda o: o==ORGAN)
df[= df[df.is_organ] # Only keep images with the organ we are interested in
df assert df.organ.unique()[0] == ORGAN
= df.drop(columns="organ data_source is_organ tissue_thickness pixel_size sex age".split()).copy() df
# These are the images I added and that are annotated by me
= get_image_files(DATA_PATH/"add_images/")
add_images = get_image_files(DATA_PATH/"segs/") add_images_masks
# The masks have the same name as the images, but with "_mask" appended
= [p.name[:-9]+".png" for p in add_images_masks]
masks
# Delete images without masks
= [p for p in add_images if p.name not in masks]
images_to_delete for p in images_to_delete: p.unlink()
# This will contain the masks in the same order as the images
= []
sorted_masks for i in add_images:
for p in add_images_masks if i.stem == p.stem[:-5]][0]) sorted_masks.append([p
# Combine the competition data with the added data
= pd.DataFrame({
add_df "fnames": add_images,
"segmentation": sorted_masks,
"is_add": [True]*len(add_images)
})"is_add"] = df.id.apply(lambda p: False)
df[= pd.concat([df, add_df]) combined_df
# Setting aside a random testset
= int(0.1 * len(combined_df))
cut = np.arange(len(combined_df))
ind # Always create the same testset
np.random.seed(TESTSET_SEED)
np.random.shuffle(ind)= ind[:cut]
test_ind = ind[cut:]
train_valid_ind = combined_df.iloc[test_ind,:].copy()
test_df = combined_df.iloc[train_valid_ind,:].copy() train_df
The masks of the competition data are in run-length encoding, that’s why we need the following function. It converts the run-length encoding to a numpy array which we can use for training.
# From: https://www.kaggle.com/code/paulorzp/run-length-encode-and-decode/script
def rle_decode(mask_rle, shape):
'''
mask_rle: run-length as string formated (start length)
shape: (height,width) of array to return
Returns numpy array, 1 - mask, 0 - background
'''
= mask_rle.split()
s = [np.asarray(x, dtype=int) for x in (s[0:][::2], s[1:][::2])]
starts, lengths -= 1
starts = starts + lengths
ends = np.zeros(shape[0]*shape[1], dtype=np.uint8)
img for lo, hi in zip(starts, ends):
= 1
img[lo:hi] return np.reshape(img, shape)
= ["Background", "FTU"] # FTU = functional tissue unit CODES
def x_getter(r): return r["fnames"]
def y_getter(r):
# My additional annotations are saved as pngs, so I need to differ between the two
if r["is_add"]:
= np.array(load_image(r["segmentation"]), dtype=np.uint8)
im = (im.mean(axis=-1) < 125).astype(np.uint8)
im return im
= r["rle"]
rle = (int(r["img_height"]), int(r["img_width"]))
shape return rle_decode(rle, shape).T
= aug_transforms(
btfms =1.2,
mult=True,
do_flip=True,
flip_vert=45.0,
max_rotate=1.,
min_zoom=1.5,
max_zoom=0.3,
max_lighting=0.3,
max_warp=(IMAGE_SIZE, IMAGE_SIZE),
size=0.5
p_affine# Data augmentation
) = DataBlock(blocks=(ImageBlock, MaskBlock(CODES)),
dblock =x_getter,
get_x=y_getter,
get_y=RandomSplitter(seed=TRAIN_VAL_SEED),
splitter=[Resize((IMAGE_SIZE, IMAGE_SIZE))],
item_tfms=btfms)
batch_tfms= dblock.dataloaders(train_df, Path(".."), bs=TRAIN_BATCH_SIZE) dls
dls.train.show_batch()
dls.valid.show_batch()
= [
cbs
GradientAccumulation(EFFECTIVE_BATCH_SIZE),=MODEL_NAME),
SaveModelCallback(fname ]
= smp.UnetPlusPlus(
model ="efficientnet-b4",
encoder_name="imagenet",
encoder_weights=3,
in_channels=2,
classes )
# Splitting model's into 2 groups to use fastai's differential learning rates
def splitter(model):
= L(model.encoder.parameters())
enc_params = L(model.decoder.parameters())
dec_params = L(model.segmentation_head.parameters())
sg_params = L([*dec_params, *sg_params])
untrained_params return L([enc_params, untrained_params])
= Learner(
learn
dls,
model, =cbs,
cbs=splitter,
splitter=[Dice(), JaccardCoeff(), RocAucBinary()]) metrics
learn.fit_flat_cos(EPOCHS, LR)
epoch | train_loss | valid_loss | dice | jaccard_coeff | roc_auc_score | time |
---|---|---|---|---|---|---|
0 | 0.727570 | 0.603148 | 0.085859 | 0.044855 | 0.322773 | 00:12 |
1 | 0.604307 | 0.516570 | 0.096767 | 0.050843 | 0.427407 | 00:11 |
2 | 0.508117 | 0.379248 | 0.522587 | 0.353718 | 0.314975 | 00:11 |
3 | 0.424914 | 0.302621 | 0.584612 | 0.413040 | 0.085511 | 00:11 |
4 | 0.353784 | 0.207627 | 0.712757 | 0.553709 | 0.128719 | 00:11 |
5 | 0.296790 | 0.142182 | 0.830598 | 0.710276 | 0.556589 | 00:11 |
6 | 0.250124 | 0.107157 | 0.849849 | 0.738903 | 0.781829 | 00:11 |
7 | 0.211806 | 0.088498 | 0.857634 | 0.750752 | 0.867223 | 00:11 |
8 | 0.180824 | 0.070990 | 0.859145 | 0.753071 | 0.888895 | 00:11 |
9 | 0.154632 | 0.061644 | 0.856697 | 0.749318 | 0.929199 | 00:11 |
10 | 0.133044 | 0.052796 | 0.868353 | 0.767336 | 0.953055 | 00:11 |
11 | 0.115005 | 0.046287 | 0.871848 | 0.772811 | 0.964127 | 00:11 |
12 | 0.099967 | 0.040762 | 0.883804 | 0.791800 | 0.971401 | 00:11 |
13 | 0.087097 | 0.037157 | 0.889235 | 0.800560 | 0.976527 | 00:11 |
14 | 0.076725 | 0.033152 | 0.899692 | 0.817673 | 0.979948 | 00:12 |
15 | 0.068125 | 0.031111 | 0.902548 | 0.822404 | 0.982294 | 00:11 |
16 | 0.060929 | 0.027760 | 0.910408 | 0.835549 | 0.984009 | 00:11 |
17 | 0.054637 | 0.027974 | 0.894529 | 0.809184 | 0.983233 | 00:11 |
18 | 0.049165 | 0.024334 | 0.911780 | 0.837864 | 0.986022 | 00:11 |
19 | 0.044277 | 0.024281 | 0.905008 | 0.826498 | 0.983716 | 00:11 |
20 | 0.040233 | 0.022132 | 0.913338 | 0.840498 | 0.986142 | 00:11 |
21 | 0.036737 | 0.021514 | 0.913032 | 0.839981 | 0.982759 | 00:12 |
22 | 0.033632 | 0.020441 | 0.919236 | 0.850543 | 0.988929 | 00:11 |
23 | 0.031428 | 0.021564 | 0.900877 | 0.819632 | 0.976978 | 00:11 |
24 | 0.029265 | 0.018416 | 0.922131 | 0.855513 | 0.990599 | 00:11 |
25 | 0.027243 | 0.017509 | 0.924026 | 0.858781 | 0.990082 | 00:11 |
26 | 0.025496 | 0.018210 | 0.916627 | 0.846087 | 0.983598 | 00:11 |
27 | 0.024097 | 0.017901 | 0.917531 | 0.847628 | 0.987033 | 00:11 |
28 | 0.022979 | 0.016894 | 0.920493 | 0.852697 | 0.986602 | 00:11 |
29 | 0.021918 | 0.018380 | 0.916335 | 0.845588 | 0.989573 | 00:11 |
30 | 0.020899 | 0.016375 | 0.920465 | 0.852650 | 0.988758 | 00:11 |
31 | 0.019700 | 0.014662 | 0.929562 | 0.868394 | 0.988694 | 00:11 |
32 | 0.018992 | 0.018643 | 0.906843 | 0.829563 | 0.989722 | 00:11 |
33 | 0.018747 | 0.016715 | 0.918087 | 0.848578 | 0.986600 | 00:11 |
34 | 0.018041 | 0.015872 | 0.918797 | 0.849792 | 0.982712 | 00:11 |
35 | 0.017152 | 0.015625 | 0.922364 | 0.855914 | 0.979253 | 00:11 |
36 | 0.016359 | 0.014718 | 0.925542 | 0.861403 | 0.977279 | 00:11 |
37 | 0.015855 | 0.012931 | 0.932580 | 0.873676 | 0.987032 | 00:11 |
38 | 0.015509 | 0.012442 | 0.932764 | 0.873999 | 0.990109 | 00:11 |
39 | 0.015077 | 0.011627 | 0.936922 | 0.881329 | 0.989619 | 00:11 |
40 | 0.014742 | 0.011964 | 0.934999 | 0.877933 | 0.987285 | 00:11 |
41 | 0.014460 | 0.011784 | 0.937937 | 0.883128 | 0.990435 | 00:11 |
42 | 0.014327 | 0.013164 | 0.928118 | 0.865877 | 0.985723 | 00:11 |
43 | 0.014337 | 0.012003 | 0.934332 | 0.876757 | 0.987117 | 00:11 |
44 | 0.014002 | 0.012138 | 0.935304 | 0.878470 | 0.988768 | 00:11 |
45 | 0.013552 | 0.011391 | 0.936339 | 0.880298 | 0.987371 | 00:11 |
46 | 0.013122 | 0.012408 | 0.929143 | 0.867663 | 0.983431 | 00:11 |
47 | 0.012708 | 0.011765 | 0.932473 | 0.873489 | 0.990839 | 00:12 |
48 | 0.012281 | 0.010892 | 0.938805 | 0.884668 | 0.987427 | 00:11 |
49 | 0.011792 | 0.010427 | 0.941303 | 0.889115 | 0.986761 | 00:11 |
50 | 0.011579 | 0.011140 | 0.938671 | 0.884429 | 0.982990 | 00:11 |
51 | 0.011452 | 0.010431 | 0.945001 | 0.895737 | 0.985848 | 00:11 |
52 | 0.011334 | 0.010210 | 0.943139 | 0.892396 | 0.987449 | 00:11 |
53 | 0.011300 | 0.011088 | 0.937227 | 0.881870 | 0.987982 | 00:11 |
54 | 0.011329 | 0.010670 | 0.937806 | 0.882895 | 0.988177 | 00:11 |
55 | 0.011342 | 0.010610 | 0.937321 | 0.882035 | 0.989792 | 00:11 |
56 | 0.011146 | 0.010468 | 0.938884 | 0.884808 | 0.988132 | 00:11 |
57 | 0.011024 | 0.010291 | 0.940346 | 0.887408 | 0.988785 | 00:11 |
58 | 0.010618 | 0.010369 | 0.939683 | 0.886228 | 0.988485 | 00:11 |
59 | 0.010563 | 0.010375 | 0.939610 | 0.886099 | 0.988654 | 00:11 |
Better model found at epoch 0 with valid_loss value: 0.6031481027603149.
Better model found at epoch 1 with valid_loss value: 0.516569972038269.
Better model found at epoch 2 with valid_loss value: 0.3792479634284973.
Better model found at epoch 3 with valid_loss value: 0.302621066570282.
Better model found at epoch 4 with valid_loss value: 0.20762743055820465.
Better model found at epoch 5 with valid_loss value: 0.14218230545520782.
Better model found at epoch 6 with valid_loss value: 0.10715709626674652.
Better model found at epoch 7 with valid_loss value: 0.0884983241558075.
Better model found at epoch 8 with valid_loss value: 0.07099024951457977.
Better model found at epoch 9 with valid_loss value: 0.06164400279521942.
Better model found at epoch 10 with valid_loss value: 0.0527963824570179.
Better model found at epoch 11 with valid_loss value: 0.046286918222904205.
Better model found at epoch 12 with valid_loss value: 0.04076218977570534.
Better model found at epoch 13 with valid_loss value: 0.037156544625759125.
Better model found at epoch 14 with valid_loss value: 0.0331517718732357.
Better model found at epoch 15 with valid_loss value: 0.031111164018511772.
Better model found at epoch 16 with valid_loss value: 0.027760174125432968.
Better model found at epoch 18 with valid_loss value: 0.024333978071808815.
Better model found at epoch 19 with valid_loss value: 0.024281086400151253.
Better model found at epoch 20 with valid_loss value: 0.022132011130452156.
Better model found at epoch 21 with valid_loss value: 0.021513625979423523.
Better model found at epoch 22 with valid_loss value: 0.020441483706235886.
Better model found at epoch 24 with valid_loss value: 0.018416451290249825.
Better model found at epoch 25 with valid_loss value: 0.017508765682578087.
Better model found at epoch 28 with valid_loss value: 0.016893696039915085.
Better model found at epoch 30 with valid_loss value: 0.016375476494431496.
Better model found at epoch 31 with valid_loss value: 0.0146623644977808.
Better model found at epoch 37 with valid_loss value: 0.012930507771670818.
Better model found at epoch 38 with valid_loss value: 0.01244184747338295.
Better model found at epoch 39 with valid_loss value: 0.01162666454911232.
Better model found at epoch 45 with valid_loss value: 0.011391145177185535.
Better model found at epoch 48 with valid_loss value: 0.010892268270254135.
Better model found at epoch 49 with valid_loss value: 0.010426941327750683.
Better model found at epoch 52 with valid_loss value: 0.010209585539996624.
learn.load(MODEL_NAME)
Saved filed doesn't contain an optimizer state.
<fastai.learner.Learner at 0x7f2dd019eac0>
learn.show_results()
# Create a dataloader from the testset
= dls.test_dl(test_df, with_labels=True)
test_dl = monai.metrics.DiceMetric(include_background=False, reduction="mean") dice_func
# This function steps through the different thresholds and returns the best one
def get_best_threshold(learn, dl, metric_func, n_steps=17):
"""
Tests `n_steps` different thresholds.
Return the best threshold and the corresonding score.
"""
= torch.linspace(0.1, 0.9, n_steps)
thresholds = []
results
= learn.get_preds(dl=dl, with_input=False, with_targs=True, act=partial(F.softmax, dim=1))
res
for t in thresholds:
0][:,1]>t).unsqueeze(1), res[-1].unsqueeze(1))
metric_func((res[= metric_func.aggregate().item()
metric
metric_func.reset()round(t.detach().cpu().item(), ndigits=3), metric))
results.append((
return sorted(results, key=lambda tpl: tpl[1], reverse=True)[0]
= get_best_threshold(learn, dls.valid, dice_func)
best_threshold, _ best_threshold
0.6
And now we test the model on the training, validation and test set with the best threshold.
def test_model(learn, dl, metric_func, threshold=0.5):
= learn.get_preds(dl=dl, with_input=False, with_targs=True, act=partial(F.softmax, dim=1))
res 0][:,1]>threshold).unsqueeze(1), res[-1].unsqueeze(1))
metric_func((res[= metric_func.aggregate().item()
metric
metric_func.reset()return metric
= test_model(learn, dls.train, dice_func, threshold=best_threshold)
train_dice = test_model(learn, dls.valid, dice_func, threshold=best_threshold)
valid_dice = test_model(learn, test_dl, dice_func, threshold=best_threshold)
test_dice
for s, d in zip(("Training Dice:", "Valid Dice", "Test Dice"), (train_dice, valid_dice, test_dice)):
print(s, d)
Training Dice: 0.9252251982688904
Valid Dice 0.9271063208580017
Test Dice 0.9414731860160828
# Save and export the model
= f"unetpp_b4_th{int(best_threshold*100)}_d{str(test_dice)[2:6]}"
BEST_MODEL_NAME
learn.save(BEST_MODEL_NAME)+".pkl") learn.export(BEST_MODEL_NAME
A live version of this model is deployed on a huggingface space.