Train#

To make steering decisions, the drone needs to be able to “see” the path in front of it. To achieve this, we train a convolutional neural network that segments the path in camera images.

Imports#

from contextlib import nullcontext
from datetime import datetime
from typing import Optional, Sequence

import matplotlib.pyplot as plt
import numpy as np

import tensorflow as tf
from tensorflow.python.distribute.tpu_strategy import TPUStrategy
import segmentation_models as sm
WARNING:root:Limited tf.compat.v2.summary API due to missing TensorBoard installation.
WARNING:root:Limited tf.compat.v2.summary API due to missing TensorBoard installation.
WARNING:root:Limited tf.compat.v2.summary API due to missing TensorBoard installation.
WARNING:root:Limited tf.summary API due to missing TensorBoard installation.
Segmentation Models: using `tf.keras` framework.
# Convenience methods
from visualise import show_two, show
from tpu import resolve_tpu_strategy, get_tpu_devices
from dataset import load_dataset, split_dataset_paths

Set up TPU#

To speed up model training, we use Google Cloud Platform’s Tensor Processing Units (TPUs): application-specific integrated circuits (ASICs) designed for machine-learning workloads.

# Name of the Cloud TPU system to attach to (see the initialisation log below).
TPU_NAME = 'padnet'
tpu_strategy = resolve_tpu_strategy(TPU_NAME)
INFO:tensorflow:Initializing the TPU system: padnet
INFO:tensorflow:Initializing the TPU system: padnet
INFO:tensorflow:Clearing out eager caches
INFO:tensorflow:Clearing out eager caches
INFO:tensorflow:Finished initializing TPU system.
INFO:tensorflow:Finished initializing TPU system.
INFO:tensorflow:Found TPU system:
INFO:tensorflow:Found TPU system:
INFO:tensorflow:*** Num TPU Cores: 8
INFO:tensorflow:*** Num TPU Cores: 8
INFO:tensorflow:*** Num TPU Workers: 1
INFO:tensorflow:*** Num TPU Workers: 1
INFO:tensorflow:*** Num TPU Cores Per Worker: 8
INFO:tensorflow:*** Num TPU Cores Per Worker: 8
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:localhost/replica:0/task:0/device:CPU:0, CPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:localhost/replica:0/task:0/device:CPU:0, CPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:CPU:0, CPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:CPU:0, CPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:0, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:0, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:1, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:1, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:2, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:2, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:3, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:3, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:4, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:4, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:5, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:5, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:6, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:6, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:7, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:7, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU_SYSTEM:0, TPU_SYSTEM, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU_SYSTEM:0, TPU_SYSTEM, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:XLA_CPU:0, XLA_CPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:XLA_CPU:0, XLA_CPU, 0, 0)
# Count the TPU cores visible to this process and report it.
nbro_tpu_devices = len(get_tpu_devices())
print(nbro_tpu_devices, "TPU devices connected!")
8 TPU devices connected!
# Parallelism setting forwarded to every load_dataset call below.
# None is passed straight through; what load_dataset does with it is defined
# in dataset.py (presumably sequential mapping — confirm there).
# Set to tf.data.experimental.AUTOTUNE to let tf.data pick the parallelism.
PARALLEL_OPS = None

Prepare dataset#

# Input resolution and location of the TFRecord shards.
IMAGE_SIZE = [448, 640]
GCS_PATTERN = 'gs://padnet-data/freiburg/tfr/*.tfr'

# Fractions of the shard list used for training and validation; the test
# split receives whatever split_dataset_paths leaves over.
TRAIN_RATIO = 0.7
VALIDATION_RATIO = 0.15

paths = tf.io.gfile.glob(GCS_PATTERN)
train_paths, validate_paths, test_paths = split_dataset_paths(
    paths, TRAIN_RATIO, VALIDATION_RATIO
)
print(f"Found {len(paths)} tfrs. Splitting into {len(train_paths)} training, {len(validate_paths)} validation, and {len(test_paths)} test tfrs")
Found 230 tfrs. Splitting into 161 training, 34 validation, and 35 test tfrs

Visualise dataset#

# Sanity check: pull the first (image, mask) pair from the training split
# and display them side by side.
visualisation_dataset = load_dataset(train_paths, IMAGE_SIZE, PARALLEL_OPS)
first_pair_iterator = iter(visualisation_dataset)
sample_rgb, sample_gt = next(first_pair_iterator)
show_two(
    "RGB image", sample_rgb,
    "Ground truth mask", sample_gt,
)
../_images/train_15_0.png

Load dataset#

# Global batch size. On a TPU v3-8 it must be divisible by the 8 cores so
# every replica receives an equal share (16 / 8 = 2 examples per core).
BATCH_SIZE = 16
# Dataset generation *must* come after tpu resolution.

# The training stream repeats forever, so drop_remainder=True never discards
# data; it does give every batch a static batch dimension, which XLA
# compilation on TPU prefers over an unknown (None) batch size.
training_dataset = (
    load_dataset(train_paths, IMAGE_SIZE, PARALLEL_OPS)
    .repeat()
    .batch(BATCH_SIZE, drop_remainder=True)
)
# Validation is a single finite pass; batching is left unchanged.
validation_dataset = (
    load_dataset(validate_paths, IMAGE_SIZE, PARALLEL_OPS)
    .batch(BATCH_SIZE)
)

Create model#

def create_model(
    tpu_strategy: Optional[TPUStrategy] = None,
    backbone: str = 'efficientnetb3',
    class_names: Sequence[str] = ('path',),
    learning_rate: float = 0.0001,
) -> tf.keras.Model:
    """Build and compile a segmentation_models U-Net for path segmentation.

    Args:
        tpu_strategy: Optional distribution strategy. When given, the losses,
            optimizer, metrics and model are all created inside
            ``tpu_strategy.scope()``; when None they are built without a scope.
        backbone: Encoder name passed to ``sm.Unet``.
        class_names: Classes to segment. A single class yields one sigmoid
            output channel (binary segmentation); several classes yield
            ``len(class_names) + 1`` softmax channels.
        learning_rate: Learning rate for the Adam optimizer.

    Returns:
        A ``tf.keras.Model`` compiled with dice + focal loss and reporting
        IoU and F1 scores thresholded at 0.5.
    """
    # One class => single sigmoid channel. Several classes => one softmax
    # channel per class plus one extra (presumably background, as in the
    # segmentation_models examples — confirm).
    n_classes = 1 if len(class_names) == 1 else len(class_names) + 1
    activation = 'sigmoid' if n_classes == 1 else 'softmax'

    # Everything that creates variables goes under the strategy scope when a
    # strategy is supplied; nullcontext() is a no-op scope otherwise.
    scope = tpu_strategy.scope() if tpu_strategy is not None else nullcontext()
    with scope:
        dice_loss = sm.losses.DiceLoss()
        focal_loss = (
            sm.losses.BinaryFocalLoss()
            if n_classes == 1
            else sm.losses.CategoricalFocalLoss()
        )
        # Dice and focal terms are weighted equally.
        total_loss = dice_loss + (1 * focal_loss)

        optim = tf.keras.optimizers.Adam(learning_rate)

        metrics = [sm.metrics.IOUScore(threshold=0.5), sm.metrics.FScore(threshold=0.5)]

        model: tf.keras.Model = sm.Unet(backbone, classes=n_classes, activation=activation)
        model.compile(optim, total_loss, metrics)
    return model


model = create_model(tpu_strategy)

Train model#

# NOTE(review): len(train_paths) / len(validate_paths) count TFRecord *files*,
# so these step counts equal whole batches of examples only if each file holds
# one example — TODO confirm against the data.
# training_dataset repeats forever, so the x10 multiplier simply lengthens each
# epoch to ten times the file-count-based batch total. Whether load_dataset
# augments examples is not visible here.
training_batches_per_pass = len(train_paths) // BATCH_SIZE
steps_per_epoch = 10 * training_batches_per_pass
validation_steps = len(validate_paths) // BATCH_SIZE

print(f"With a batch size of {BATCH_SIZE}, there will be {steps_per_epoch} batches per training epoch and {validation_steps} batches per validation run.")
With a batch size of 16, there will be 100 batches per training epoch and 2 batches per validation run.
tensorboard_log_dir = "gs://padnet-data/model/" + datetime.now().strftime("%Y-%m-%dT%H:%M:%S")
callbacks = [
    # Write the weights of the lowest-val_loss epoch to local disk.
    tf.keras.callbacks.ModelCheckpoint(
        './best_model.h5',
        monitor='val_loss',
        save_weights_only=True,
        save_best_only=True,
        mode='min',
    ),
    tf.keras.callbacks.ReduceLROnPlateau(),
    tf.keras.callbacks.TensorBoard(log_dir=tensorboard_log_dir, histogram_freq=1),
    # Stop after 3 epochs without a >=0.001 val_loss improvement. When it
    # triggers, roll the model back to the best epoch's weights, so the model
    # saved below is the best one rather than the last one trained.
    tf.keras.callbacks.EarlyStopping(
        monitor='val_loss',
        min_delta=0.001,
        patience=3,
        restore_best_weights=True,
    ),
]
EPOCHS = 40
history = model.fit(
    x=training_dataset,
    epochs=EPOCHS,
    steps_per_epoch=steps_per_epoch,
    callbacks=callbacks,
    validation_data=validation_dataset,
    validation_steps=validation_steps,
)

# Saving to an .h5 path goes through h5py, which only opens local files: handing
# it a gs:// URL fails with "No such file or directory". Write the HDF5 file
# locally, then copy it to GCS with tf.io.gfile, which understands gs:// paths.
LOCAL_MODEL_PATH = './no_aug.h5'
GCS_MODEL_PATH = 'gs://padnet-data/trained_models/no_aug.h5'
model.save(LOCAL_MODEL_PATH)
tf.io.gfile.copy(LOCAL_MODEL_PATH, GCS_MODEL_PATH, overwrite=True)

model.summary()
Epoch 1/40
100/100 [==============================] - 224s 1s/step - loss: 0.7952 - iou_score: 0.4034 - f1-score: 0.5326 - val_loss: 0.7753 - val_iou_score: 0.5061 - val_f1-score: 0.6616
Epoch 2/40
100/100 [==============================] - 64s 648ms/step - loss: 0.5138 - iou_score: 0.9008 - f1-score: 0.9469 - val_loss: 0.5554 - val_iou_score: 0.6984 - val_f1-score: 0.8169
Epoch 3/40
100/100 [==============================] - 65s 653ms/step - loss: 0.3543 - iou_score: 0.9370 - f1-score: 0.9673 - val_loss: 0.3496 - val_iou_score: 0.8738 - val_f1-score: 0.9315
Epoch 4/40
100/100 [==============================] - 62s 620ms/step - loss: 0.2445 - iou_score: 0.9469 - f1-score: 0.9726 - val_loss: 0.2394 - val_iou_score: 0.8913 - val_f1-score: 0.9417
Epoch 5/40
100/100 [==============================] - 46s 459ms/step - loss: 0.1690 - iou_score: 0.9550 - f1-score: 0.9769 - val_loss: 0.1775 - val_iou_score: 0.8958 - val_f1-score: 0.9442
Epoch 6/40
100/100 [==============================] - 45s 448ms/step - loss: 0.1227 - iou_score: 0.9594 - f1-score: 0.9792 - val_loss: 0.1389 - val_iou_score: 0.8952 - val_f1-score: 0.9436
Epoch 7/40
100/100 [==============================] - 63s 635ms/step - loss: 0.0923 - iou_score: 0.9648 - f1-score: 0.9820 - val_loss: 0.1221 - val_iou_score: 0.8905 - val_f1-score: 0.9410
Epoch 8/40
100/100 [==============================] - 40s 405ms/step - loss: 0.0719 - iou_score: 0.9695 - f1-score: 0.9845 - val_loss: 0.1113 - val_iou_score: 0.8865 - val_f1-score: 0.9385
Epoch 9/40
100/100 [==============================] - 64s 649ms/step - loss: 0.0605 - iou_score: 0.9693 - f1-score: 0.9844 - val_loss: 0.1037 - val_iou_score: 0.8869 - val_f1-score: 0.9387
Epoch 10/40
100/100 [==============================] - 64s 641ms/step - loss: 0.0502 - iou_score: 0.9726 - f1-score: 0.9861 - val_loss: 0.0941 - val_iou_score: 0.8939 - val_f1-score: 0.9429
Epoch 11/40
100/100 [==============================] - 65s 649ms/step - loss: 0.0433 - iou_score: 0.9741 - f1-score: 0.9869 - val_loss: 0.0965 - val_iou_score: 0.8857 - val_f1-score: 0.9381
Epoch 12/40
100/100 [==============================] - 64s 647ms/step - loss: 0.0374 - iou_score: 0.9759 - f1-score: 0.9878 - val_loss: 0.0887 - val_iou_score: 0.8954 - val_f1-score: 0.9437
Epoch 13/40
100/100 [==============================] - 62s 628ms/step - loss: 0.0335 - iou_score: 0.9767 - f1-score: 0.9882 - val_loss: 0.0908 - val_iou_score: 0.8880 - val_f1-score: 0.9394
Epoch 14/40
100/100 [==============================] - 59s 594ms/step - loss: 0.0302 - iou_score: 0.9774 - f1-score: 0.9886 - val_loss: 0.0871 - val_iou_score: 0.8931 - val_f1-score: 0.9424
Epoch 15/40
100/100 [==============================] - 55s 558ms/step - loss: 0.0273 - iou_score: 0.9791 - f1-score: 0.9894 - val_loss: 0.0860 - val_iou_score: 0.8922 - val_f1-score: 0.9419
Epoch 16/40
100/100 [==============================] - 39s 394ms/step - loss: 0.0254 - iou_score: 0.9792 - f1-score: 0.9895 - val_loss: 0.0839 - val_iou_score: 0.8971 - val_f1-score: 0.9447
Epoch 17/40
100/100 [==============================] - 62s 622ms/step - loss: 0.0233 - iou_score: 0.9801 - f1-score: 0.9900 - val_loss: 0.0856 - val_iou_score: 0.8916 - val_f1-score: 0.9415
Epoch 18/40
100/100 [==============================] - 57s 573ms/step - loss: 0.0215 - iou_score: 0.9809 - f1-score: 0.9904 - val_loss: 0.0854 - val_iou_score: 0.8930 - val_f1-score: 0.9423
Epoch 19/40
100/100 [==============================] - 62s 628ms/step - loss: 0.0201 - iou_score: 0.9815 - f1-score: 0.9907 - val_loss: 0.0851 - val_iou_score: 0.8915 - val_f1-score: 0.9415
---------------------------------------------------------------------------
OSError                                   Traceback (most recent call last)
<ipython-input-17-ea75cecef43e> in <module>
      7     validation_steps=validation_steps
      8 )
----> 9 model.save('gs://padnet-data/model/no_aug.h5')
     10 
     11 model.summary()

/opt/padnet/env/lib/python3.8/site-packages/tensorflow/python/keras/engine/training.py in save(self, filepath, overwrite, include_optimizer, save_format, signatures, options, save_traces)
   1999     """
   2000     # pylint: enable=line-too-long
-> 2001     save.save_model(self, filepath, overwrite, include_optimizer, save_format,
   2002                     signatures, options, save_traces)
   2003 

/opt/padnet/env/lib/python3.8/site-packages/tensorflow/python/keras/saving/save.py in save_model(model, filepath, overwrite, include_optimizer, save_format, signatures, options, save_traces)
    151           'to the Tensorflow SavedModel format (by setting save_format="tf") '
    152           'or using `save_weights`.')
--> 153     hdf5_format.save_model_to_hdf5(
    154         model, filepath, overwrite, include_optimizer)
    155   else:

/opt/padnet/env/lib/python3.8/site-packages/tensorflow/python/keras/saving/hdf5_format.py in save_model_to_hdf5(model, filepath, overwrite, include_optimizer)
    106       gfile.MakeDirs(dirpath)
    107 
--> 108     f = h5py.File(filepath, mode='w')
    109     opened_new_file = True
    110   else:

/opt/padnet/env/lib/python3.8/site-packages/h5py/_hl/files.py in __init__(self, name, mode, driver, libver, userblock_size, swmr, rdcc_nslots, rdcc_nbytes, rdcc_w0, track_order, **kwds)
    404             with phil:
    405                 fapl = make_fapl(driver, libver, rdcc_nslots, rdcc_nbytes, rdcc_w0, **kwds)
--> 406                 fid = make_fid(name, mode, userblock_size,
    407                                fapl, fcpl=make_fcpl(track_order=track_order),
    408                                swmr=swmr)

/opt/padnet/env/lib/python3.8/site-packages/h5py/_hl/files.py in make_fid(name, mode, userblock_size, fapl, fcpl, swmr)
    177         fid = h5f.create(name, h5f.ACC_EXCL, fapl=fapl, fcpl=fcpl)
    178     elif mode == 'w':
--> 179         fid = h5f.create(name, h5f.ACC_TRUNC, fapl=fapl, fcpl=fcpl)
    180     elif mode == 'a':
    181         # Open in append mode (read/write).

h5py/_objects.pyx in h5py._objects.with_phil.wrapper()

h5py/_objects.pyx in h5py._objects.with_phil.wrapper()

h5py/h5f.pyx in h5py.h5f.create()

OSError: Unable to create file (unable to open file: name = 'gs://padnet-data/model/no_aug.h5', errno = 2, error message = 'No such file or directory', flags = 13, o_flags = 242)

Metrics#

# Training vs. validation loss per epoch, as recorded by model.fit.
loss_curves = history.history
for curve_key in ('loss', 'val_loss'):
    plt.plot(loss_curves[curve_key])
plt.title('model loss')
plt.xlabel('epoch')
plt.ylabel('loss')
plt.legend(['train', 'validation'], loc='upper left')
plt.show()

Test model#

test_dataset = load_dataset(test_paths, IMAGE_SIZE, PARALLEL_OPS)
NBRO_TEST_CASES = 5
# take() stops after NBRO_TEST_CASES elements, so only the samples displayed
# below are decoded. The full test split is no longer materialised in memory,
# and test_lst now holds just these samples.
test_lst = list(test_dataset.take(NBRO_TEST_CASES))
test_sample_rgbs, test_sample_gts = zip(*test_lst)
# Run the model explicitly in inference mode on the stacked sample images.
predictions = model(np.asarray(test_sample_rgbs), training=False)
for test_sample_rgb, prediction, test_sample_gt in zip(test_sample_rgbs, predictions, test_sample_gts):
    show({
        "RGB image": test_sample_rgb,
        "Prediction": prediction,
        "Ground truth mask": test_sample_gt
    })