Maintaining proper tyre quality is essential for road safety. For organisations that own larger fleets of vehicles, taking vehicles off the road for maintenance is costly. Smarter predictions of when maintenance is needed, and on which parts of the vehicle, can therefore significantly reduce expenses.
In this project we collected images of tyres (~300 in total) to test the hypothesis that tyre tread depth can be predicted from single-camera input images using deep learning techniques. Despite the small dataset, we chose a design that can easily be scaled up to millions of samples with virtually no changes to the code.
Although the dataset is very small, this early-stage experiment seems to indicate that it is possible to predict the tread depth of tyres using deep learning techniques. More data should, however, be collected to verify the results with greater confidence. Another, plausibly more important, question is whether taking a photo of sufficient quality of a tyre is actually easier than using the methods currently available on the market.
import os
import re
import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.layers import Dropout, Dense, MaxPooling2D, Conv2D, SeparableConv2D
from tensorflow.keras import layers
import matplotlib.pyplot as plt
import seaborn as sns
sns.set()
# Mount Google Drive to access the files stored there
from google.colab import drive
drive.mount('/content/drive')
Mounted at /content/drive
Start by loading a batch of images and visualising them, to get a feeling for what we are working with.
# Specify image size, batch size and data folder
image_size = (800, 380)
batch_size = 16
data_dir = '/content/drive/MyDrive/Tyres/Data_for_model/'
# Data generator without augmentation
data_gen_wo_aug = tf.keras.preprocessing.image_dataset_from_directory(
data_dir,
batch_size=batch_size,
image_size=image_size,
seed=9119,
)
Found 315 files belonging to 4 classes.
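The mapping from subfolder names to integer labels is inferred from the (alphabetically sorted) directory names, and can be checked directly on the dataset object:
# Inspect which integer label corresponds to which subfolder
print(data_gen_wo_aug.class_names)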
plt.figure(figsize=(18, 7))
for images, labels in data_gen_wo_aug.take(1):
for i in range(16):
plt.subplot(2, 8, i+1)
plt.imshow(images[i].numpy().astype('uint8'))  # matplotlib expects uint8 (or float) RGB data
plt.title(int(labels[i]))
plt.axis('off')
The images look good. Since we have so little data, I will add data augmentation right away, even though it introduces complexity at an early stage.
# Create folders for inspecting the augmented images
augmented_dir = '/content/drive/MyDrive/Tyres/Augmented_example_images/'
train_inspection_dir = augmented_dir + 'train_inspection'
validation_inspection_dir = augmented_dir + 'validation_inspection'
os.makedirs(train_inspection_dir, exist_ok=True)
os.makedirs(validation_inspection_dir, exist_ok=True)
# Create an image data generator with augmentation on the train set
train_gen = tf.keras.preprocessing.image.ImageDataGenerator(
rotation_range=45,
brightness_range=[0.8, 1.2],
zoom_range=0.1,
shear_range=0.1,
fill_mode='constant',
horizontal_flip=True,
vertical_flip=True,
rescale=1./255,
validation_split=0.2, # Use the last 20% of images as validation set.
)
# No augmentation on the validation set
validation_gen = tf.keras.preprocessing.image.ImageDataGenerator(
rescale=1./255,
validation_split=0.2, # Use the last 20% of images as validation set.
)
# Use the generator with augmentation on the train set
train_generator = train_gen.flow_from_directory(
data_dir,
target_size=image_size,
batch_size=batch_size,
class_mode='categorical', # with categorical labels we later have to use categorical_crossentropy loss
seed=124,
#save_to_dir=train_inspection_dir, # Save images for inspection (these should be augmented)
subset='training', # This is the training set
)
# No augmentation on the validation set
validation_generator = validation_gen.flow_from_directory(
data_dir,
target_size=image_size,
batch_size=batch_size,
class_mode='categorical',
seed=124,
#save_to_dir=validation_inspection_dir, # Save images for inspection (these should NOT be augmented)
subset='validation' # validation set
)
Found 252 images belonging to 4 classes.
Found 63 images belonging to 4 classes.
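The two flow_from_directory calls rely on the same validation_split to yield disjoint subsets. A quick sanity check on the filenames (a minimal sketch) confirms there is no overlap:
# The train and validation subsets should not share any files
overlap = set(train_generator.filenames) & set(validation_generator.filenames)
print('Overlapping files:', len(overlap))  # expect 0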
# Take a quick look at what the train_generator outputs
for train_data_batch, train_label_batch in train_generator:
print("Train data batch size:", train_data_batch.shape)
print("Train label batch size:", train_label_batch.shape)
break
# And the validation_generator
for val_data_batch, val_label_batch in validation_generator:
print("Validation data batch size:", val_data_batch.shape)
print("Validation label batch size:", val_label_batch.shape)
break
Train data batch size: (16, 800, 380, 3)
Train label batch size: (16, 4)
Validation data batch size: (16, 800, 380, 3)
Validation label batch size: (16, 4)
They should be equal in terms of shape - and they are.
# Sanity check with augmented images
plt.figure(figsize=(18, 7))
for images, labels in train_generator:
for i in range(16):
plt.subplot(2, 8, i+1)
plt.imshow(images[i])
plt.title(int(np.argmax(labels[i])))
plt.axis('off')
break
Conversely, there should be NO augmentation on the validation set.
# Sanity check with NO augmented images
plt.figure(figsize=(18, 7))
for images, labels in validation_generator:
for i in range(16):
plt.subplot(2, 8, i+1)
plt.imshow(images[i])
plt.title(int(np.argmax(labels[i])))
plt.axis('off')
break
Check class distributions.
# Check the class distribution
ax = sns.countplot(x=train_generator.classes)
ax.set_xticklabels(['very bad (0)', 'bad (1)', 'good (2)', 'very good (3)']);
plt.title('Class counts')
plt.show()
# Print out shares
n = len(train_generator.classes)
shares = {f'n{i}': round(len(train_generator.classes[train_generator.classes == i])/n, 4) for i in range(4)}
print(f"\nShares of total\nClass 0: {shares['n0']}\nClass 1: {shares['n1']}\nClass 2: {shares['n2']}\nClass 3: {shares['n3']}")
Shares of total
Class 0: 0.0476
Class 1: 0.0159
Class 2: 0.1111
Class 3: 0.8254
We have a significant class imbalance problem: 82.5% of the data belongs to the largest class (class 3), while the smallest class (class 1) makes up only about 1.6% of the total.
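One common remedy, not applied in the runs below, is to weight the loss by inverse class frequency so that mistakes on the rare classes cost more. A minimal sketch using the class_weight argument of model.fit:
# Inverse-frequency class weights (sketch only; not used below)
counts = np.bincount(train_generator.classes, minlength=4)
class_weight = {i: len(train_generator.classes) / (4 * c) for i, c in enumerate(counts)}
print(class_weight)
# Passing class_weight=class_weight to model.fit would apply these weights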
Start by training a simple model from scratch, with data augmentation but no Dropout. We will use the share of the largest class (82.5%) as the baseline accuracy for the model to beat.
# Define metrics
METRICS = [
keras.metrics.TruePositives(name='tp'),
keras.metrics.FalsePositives(name='fp'),
keras.metrics.TrueNegatives(name='tn'),
keras.metrics.FalseNegatives(name='fn'),
keras.metrics.CategoricalAccuracy(name='accuracy'),
keras.metrics.AUC(name='auc'),
keras.metrics.AUC(name='prc', curve='PR'), # precision-recall curve
keras.metrics.Precision(name='precision'),
keras.metrics.Recall(name='recall')
]
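Note that with one-hot labels, TruePositives and friends count elementwise over the label matrix, thresholding predictions at the default 0.5. Since each sample has exactly one positive entry, tp + fn per evaluation equals the number of samples; this is why val_tp + val_fn = 63, the size of the validation set, in the logs below. A small standalone example:
# Confusion-style metrics operate elementwise on one-hot labels
m = keras.metrics.TruePositives()
m.update_state([[0., 0., 0., 1.]], [[0.1, 0.1, 0.1, 0.7]])
print(m.result().numpy())  # 1.0: the single positive entry was predicted above 0.5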
# Create a simple CNN architecture
input_tensor = keras.Input(shape=image_size + (3,), name='input')
x = Conv2D(filters=32, kernel_size=(3, 3), activation='relu')(input_tensor)
x = MaxPooling2D(pool_size=(2, 2))(x)
x = Conv2D(64, (3, 3), activation='relu')(x)
x = MaxPooling2D((2, 2))(x)
x = Conv2D(128, (3, 3), activation='relu')(x)
x = MaxPooling2D((2, 2))(x)
x = Conv2D(128, (3, 3), activation='relu')(x)
x = layers.GlobalAveragePooling2D()(x)
x = Dense(512, activation='relu')(x)
output = Dense(4, activation='softmax', name='output')(x)
model = keras.Model(inputs=[input_tensor], outputs=[output], name='simple_model')
model.compile(loss='categorical_crossentropy',
optimizer=keras.optimizers.RMSprop(learning_rate=1e-4),
metrics=METRICS)
model.summary()
Model: "simple_model" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= input (InputLayer) [(None, 800, 380, 3)] 0 _________________________________________________________________ conv2d_26 (Conv2D) (None, 798, 378, 32) 896 _________________________________________________________________ max_pooling2d_21 (MaxPooling (None, 399, 189, 32) 0 _________________________________________________________________ conv2d_27 (Conv2D) (None, 397, 187, 64) 18496 _________________________________________________________________ max_pooling2d_22 (MaxPooling (None, 198, 93, 64) 0 _________________________________________________________________ conv2d_28 (Conv2D) (None, 196, 91, 128) 73856 _________________________________________________________________ max_pooling2d_23 (MaxPooling (None, 98, 45, 128) 0 _________________________________________________________________ conv2d_29 (Conv2D) (None, 96, 43, 128) 147584 _________________________________________________________________ global_average_pooling2d_6 ( (None, 128) 0 _________________________________________________________________ dense_6 (Dense) (None, 512) 66048 _________________________________________________________________ output (Dense) (None, 4) 2052 ================================================================= Total params: 308,932 Trainable params: 308,932 Non-trainable params: 0 _________________________________________________________________
print(input_tensor.shape)
print(output.shape)
(None, 800, 380, 3)
(None, 4)
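The training runs below only use a ModelCheckpoint callback. An EarlyStopping callback could be added alongside it to stop training once the validation loss plateaus; a sketch, not used in the runs below:
# Optional: stop training when val_loss has not improved for five epochs
early_stopping = tf.keras.callbacks.EarlyStopping(
    monitor='val_loss',
    patience=5,
    restore_best_weights=True,
)
# It would then be appended to the callbacks list passed to model.fit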
%%time
# Add callback for storing the best model
model_dir = '/content/drive/MyDrive/Tyres/Models/'
os.makedirs(model_dir, exist_ok=True)
callbacks = [
tf.keras.callbacks.ModelCheckpoint(
filepath=model_dir+'model_simple_no_dropout.hdf5',
monitor='val_accuracy',
save_best_only=True,
mode='max'
)
]
# Train the model
history = model.fit(train_generator,
epochs=30,
callbacks=callbacks,
validation_data=validation_generator,
workers=2)
Epoch 1/30
16/16 [==============================] - 22s 1s/step - loss: 1.2280 - tp: 15.1765 - fp: 4.0000 - tn: 415.2941 - fn: 124.5882 - accuracy: 0.6956 - auc: 0.8364 - prc: 0.7036 - precision: 0.3243 - recall: 0.0653 - val_loss: 0.6662 - val_tp: 52.0000 - val_fp: 11.0000 - val_tn: 178.0000 - val_fn: 11.0000 - val_accuracy: 0.8254 - val_auc: 0.8891 - val_prc: 0.6978 - val_precision: 0.8254 - val_recall: 0.8254
Epoch 2/30
16/16 [==============================] - 19s 1s/step - loss: 0.7154 - tp: 115.2941 - fp: 26.3529 - tn: 398.5882 - fn: 26.3529 - accuracy: 0.8017 - auc: 0.8882 - prc: 0.7383 - precision: 0.8017 - recall: 0.8017 - val_loss: 0.6699 - val_tp: 52.0000 - val_fp: 11.0000 - val_tn: 178.0000 - val_fn: 11.0000 - val_accuracy: 0.8254 - val_auc: 0.8971 - val_prc: 0.6862 - val_precision: 0.8254 - val_recall: 0.8254
Epoch 3/30
16/16 [==============================] - 18s 1s/step - loss: 0.6129 - tp: 116.4706 - fp: 23.5294 - tn: 396.4706 - fn: 23.5294 - accuracy: 0.8441 - auc: 0.9063 - prc: 0.7507 - precision: 0.8441 - recall: 0.8441 - val_loss: 0.6382 - val_tp: 52.0000 - val_fp: 11.0000 - val_tn: 178.0000 - val_fn: 11.0000 - val_accuracy: 0.8254 - val_auc: 0.8973 - val_prc: 0.6914 - val_precision: 0.8254 - val_recall: 0.8254
Epoch 4/30
16/16 [==============================] - 18s 998ms/step - loss: 0.6124 - tp: 117.2353 - fp: 23.9412 - tn: 399.5882 - fn: 23.9412 - accuracy: 0.8437 - auc: 0.9087 - prc: 0.7776 - precision: 0.8437 - recall: 0.8437 - val_loss: 0.6470 - val_tp: 52.0000 - val_fp: 11.0000 - val_tn: 178.0000 - val_fn: 11.0000 - val_accuracy: 0.8254 - val_auc: 0.8953 - val_prc: 0.7052 - val_precision: 0.8254 - val_recall: 0.8254
Epoch 5/30
16/16 [==============================] - 19s 1s/step - loss: 0.5205 - tp: 119.8824 - fp: 21.2941 - tn: 402.2353 - fn: 21.2941 - accuracy: 0.8687 - auc: 0.9322 - prc: 0.8055 - precision: 0.8687 - recall: 0.8687 - val_loss: 0.6244 - val_tp: 52.0000 - val_fp: 11.0000 - val_tn: 178.0000 - val_fn: 11.0000 - val_accuracy: 0.8254 - val_auc: 0.8964 - val_prc: 0.6942 - val_precision: 0.8254 - val_recall: 0.8254
Epoch 6/30
16/16 [==============================] - 19s 1s/step - loss: 0.5385 - tp: 120.5882 - fp: 20.8235 - tn: 403.4118 - fn: 20.8235 - accuracy: 0.8771 - auc: 0.9235 - prc: 0.8209 - precision: 0.8771 - recall: 0.8771 - val_loss: 0.6247 - val_tp: 52.0000 - val_fp: 11.0000 - val_tn: 178.0000 - val_fn: 11.0000 - val_accuracy: 0.8254 - val_auc: 0.8985 - val_prc: 0.7054 - val_precision: 0.8254 - val_recall: 0.8254
Epoch 7/30
16/16 [==============================] - 19s 1s/step - loss: 0.5629 - tp: 115.7647 - fp: 23.2941 - tn: 393.8824 - fn: 23.2941 - accuracy: 0.8579 - auc: 0.9255 - prc: 0.8065 - precision: 0.8579 - recall: 0.8579 - val_loss: 0.6405 - val_tp: 52.0000 - val_fp: 11.0000 - val_tn: 178.0000 - val_fn: 11.0000 - val_accuracy: 0.8254 - val_auc: 0.8993 - val_prc: 0.6962 - val_precision: 0.8254 - val_recall: 0.8254
Epoch 8/30
16/16 [==============================] - 19s 1s/step - loss: 0.6374 - tp: 113.6471 - fp: 26.1176 - tn: 393.1765 - fn: 26.1176 - accuracy: 0.8145 - auc: 0.9148 - prc: 0.7714 - precision: 0.8145 - recall: 0.8145 - val_loss: 0.6390 - val_tp: 52.0000 - val_fp: 11.0000 - val_tn: 178.0000 - val_fn: 11.0000 - val_accuracy: 0.8254 - val_auc: 0.8950 - val_prc: 0.7033 - val_precision: 0.8254 - val_recall: 0.8254
Epoch 9/30
16/16 [==============================] - 19s 1s/step - loss: 0.6358 - tp: 115.4118 - fp: 25.2941 - tn: 396.8235 - fn: 25.2941 - accuracy: 0.8255 - auc: 0.9110 - prc: 0.7763 - precision: 0.8255 - recall: 0.8255 - val_loss: 0.6373 - val_tp: 52.0000 - val_fp: 11.0000 - val_tn: 178.0000 - val_fn: 11.0000 - val_accuracy: 0.8254 - val_auc: 0.8961 - val_prc: 0.6955 - val_precision: 0.8254 - val_recall: 0.8254
Epoch 10/30
16/16 [==============================] - 19s 1s/step - loss: 0.6578 - tp: 117.0000 - fp: 24.8824 - tn: 400.7647 - fn: 24.8824 - accuracy: 0.8136 - auc: 0.9081 - prc: 0.7445 - precision: 0.8136 - recall: 0.8136 - val_loss: 0.6245 - val_tp: 52.0000 - val_fp: 11.0000 - val_tn: 178.0000 - val_fn: 11.0000 - val_accuracy: 0.8254 - val_auc: 0.8972 - val_prc: 0.6967 - val_precision: 0.8254 - val_recall: 0.8254
Epoch 11/30
16/16 [==============================] - 19s 1s/step - loss: 0.6629 - tp: 115.7059 - fp: 25.9412 - tn: 399.0000 - fn: 25.9412 - accuracy: 0.8120 - auc: 0.9061 - prc: 0.7697 - precision: 0.8120 - recall: 0.8120 - val_loss: 0.6516 - val_tp: 52.0000 - val_fp: 11.0000 - val_tn: 178.0000 - val_fn: 11.0000 - val_accuracy: 0.8254 - val_auc: 0.8998 - val_prc: 0.7136 - val_precision: 0.8254 - val_recall: 0.8254
Epoch 12/30
16/16 [==============================] - 19s 1s/step - loss: 0.6963 - tp: 114.2941 - fp: 27.5882 - tn: 398.0588 - fn: 27.5882 - accuracy: 0.7941 - auc: 0.9002 - prc: 0.7211 - precision: 0.7941 - recall: 0.7941 - val_loss: 0.6455 - val_tp: 52.0000 - val_fp: 11.0000 - val_tn: 178.0000 - val_fn: 11.0000 - val_accuracy: 0.8254 - val_auc: 0.8994 - val_prc: 0.7145 - val_precision: 0.8254 - val_recall: 0.8254
Epoch 13/30
16/16 [==============================] - 19s 1s/step - loss: 0.6930 - tp: 115.4118 - fp: 25.5294 - tn: 397.2941 - fn: 25.5294 - accuracy: 0.8083 - auc: 0.8984 - prc: 0.7234 - precision: 0.8083 - recall: 0.8083 - val_loss: 0.6300 - val_tp: 52.0000 - val_fp: 11.0000 - val_tn: 178.0000 - val_fn: 11.0000 - val_accuracy: 0.8254 - val_auc: 0.8978 - val_prc: 0.7142 - val_precision: 0.8254 - val_recall: 0.8254
Epoch 14/30
16/16 [==============================] - 19s 1s/step - loss: 0.6446 - tp: 115.1176 - fp: 23.9412 - tn: 393.2353 - fn: 23.9412 - accuracy: 0.8208 - auc: 0.9074 - prc: 0.7691 - precision: 0.8208 - recall: 0.8208 - val_loss: 0.6224 - val_tp: 52.0000 - val_fp: 11.0000 - val_tn: 178.0000 - val_fn: 11.0000 - val_accuracy: 0.8254 - val_auc: 0.9012 - val_prc: 0.7097 - val_precision: 0.8254 - val_recall: 0.8254
Epoch 15/30
16/16 [==============================] - 19s 1s/step - loss: 0.6331 - tp: 117.0000 - fp: 24.4118 - tn: 399.8235 - fn: 24.4118 - accuracy: 0.8346 - auc: 0.9040 - prc: 0.7451 - precision: 0.8346 - recall: 0.8346 - val_loss: 0.6353 - val_tp: 52.0000 - val_fp: 11.0000 - val_tn: 178.0000 - val_fn: 11.0000 - val_accuracy: 0.8254 - val_auc: 0.9009 - val_prc: 0.7223 - val_precision: 0.8254 - val_recall: 0.8254
Epoch 16/30
16/16 [==============================] - 19s 1s/step - loss: 0.6279 - tp: 115.0588 - fp: 24.2353 - tn: 393.6471 - fn: 24.2353 - accuracy: 0.8287 - auc: 0.9099 - prc: 0.7597 - precision: 0.8287 - recall: 0.8287 - val_loss: 0.6285 - val_tp: 52.0000 - val_fp: 11.0000 - val_tn: 178.0000 - val_fn: 11.0000 - val_accuracy: 0.8254 - val_auc: 0.8984 - val_prc: 0.7067 - val_precision: 0.8254 - val_recall: 0.8254
Epoch 17/30
16/16 [==============================] - 19s 1s/step - loss: 0.6522 - tp: 115.0000 - fp: 25.9412 - tn: 396.8824 - fn: 25.9412 - accuracy: 0.8179 - auc: 0.9116 - prc: 0.7720 - precision: 0.8179 - recall: 0.8179 - val_loss: 0.6390 - val_tp: 52.0000 - val_fp: 11.0000 - val_tn: 178.0000 - val_fn: 11.0000 - val_accuracy: 0.8254 - val_auc: 0.9004 - val_prc: 0.7179 - val_precision: 0.8254 - val_recall: 0.8254
Epoch 18/30
16/16 [==============================] - 19s 1s/step - loss: 0.6255 - tp: 117.8235 - fp: 24.0588 - tn: 401.5882 - fn: 24.0588 - accuracy: 0.8280 - auc: 0.9105 - prc: 0.7501 - precision: 0.8280 - recall: 0.8280 - val_loss: 0.6242 - val_tp: 52.0000 - val_fp: 11.0000 - val_tn: 178.0000 - val_fn: 11.0000 - val_accuracy: 0.8254 - val_auc: 0.9041 - val_prc: 0.7321 - val_precision: 0.8254 - val_recall: 0.8254
Epoch 19/30
16/16 [==============================] - 19s 1s/step - loss: 0.6493 - tp: 113.5882 - fp: 25.9412 - tn: 392.6471 - fn: 25.9412 - accuracy: 0.8171 - auc: 0.9086 - prc: 0.7813 - precision: 0.8171 - recall: 0.8171 - val_loss: 0.6290 - val_tp: 52.0000 - val_fp: 11.0000 - val_tn: 178.0000 - val_fn: 11.0000 - val_accuracy: 0.8254 - val_auc: 0.9005 - val_prc: 0.7330 - val_precision: 0.8254 - val_recall: 0.8254
Epoch 20/30
16/16 [==============================] - 19s 1s/step - loss: 0.5647 - tp: 119.9412 - fp: 22.1765 - tn: 404.1765 - fn: 22.1765 - accuracy: 0.8349 - auc: 0.9363 - prc: 0.8445 - precision: 0.8349 - recall: 0.8349 - val_loss: 0.6316 - val_tp: 52.0000 - val_fp: 11.0000 - val_tn: 178.0000 - val_fn: 11.0000 - val_accuracy: 0.8254 - val_auc: 0.9016 - val_prc: 0.7294 - val_precision: 0.8254 - val_recall: 0.8254
Epoch 21/30
16/16 [==============================] - 19s 1s/step - loss: 0.6287 - tp: 114.4706 - fp: 25.0588 - tn: 393.5294 - fn: 25.0588 - accuracy: 0.8265 - auc: 0.9167 - prc: 0.7837 - precision: 0.8265 - recall: 0.8265 - val_loss: 0.6363 - val_tp: 52.0000 - val_fp: 11.0000 - val_tn: 178.0000 - val_fn: 11.0000 - val_accuracy: 0.8254 - val_auc: 0.9009 - val_prc: 0.7179 - val_precision: 0.8254 - val_recall: 0.8254
Epoch 22/30
16/16 [==============================] - 19s 1s/step - loss: 0.6417 - tp: 114.4706 - fp: 25.2941 - tn: 394.0000 - fn: 25.2941 - accuracy: 0.8117 - auc: 0.9126 - prc: 0.7629 - precision: 0.8117 - recall: 0.8117 - val_loss: 0.6285 - val_tp: 52.0000 - val_fp: 11.0000 - val_tn: 178.0000 - val_fn: 11.0000 - val_accuracy: 0.8254 - val_auc: 0.8990 - val_prc: 0.7263 - val_precision: 0.8254 - val_recall: 0.8254
Epoch 23/30
16/16 [==============================] - 19s 1s/step - loss: 0.5865 - tp: 117.0000 - fp: 23.9412 - tn: 398.8824 - fn: 23.9412 - accuracy: 0.8297 - auc: 0.9253 - prc: 0.7954 - precision: 0.8297 - recall: 0.8297 - val_loss: 0.6228 - val_tp: 52.0000 - val_fp: 11.0000 - val_tn: 178.0000 - val_fn: 11.0000 - val_accuracy: 0.8254 - val_auc: 0.9000 - val_prc: 0.7296 - val_precision: 0.8254 - val_recall: 0.8254
Epoch 24/30
16/16 [==============================] - 19s 1s/step - loss: 0.5975 - tp: 118.1765 - fp: 23.9412 - tn: 402.4118 - fn: 23.9412 - accuracy: 0.8375 - auc: 0.9185 - prc: 0.7802 - precision: 0.8375 - recall: 0.8375 - val_loss: 0.6229 - val_tp: 52.0000 - val_fp: 11.0000 - val_tn: 178.0000 - val_fn: 11.0000 - val_accuracy: 0.8254 - val_auc: 0.9024 - val_prc: 0.7439 - val_precision: 0.8254 - val_recall: 0.8254
Epoch 25/30
16/16 [==============================] - 19s 1s/step - loss: 0.5984 - tp: 116.1765 - fp: 24.7647 - tn: 398.0588 - fn: 24.7647 - accuracy: 0.8422 - auc: 0.9203 - prc: 0.7816 - precision: 0.8422 - recall: 0.8422 - val_loss: 0.6243 - val_tp: 52.0000 - val_fp: 11.0000 - val_tn: 178.0000 - val_fn: 11.0000 - val_accuracy: 0.8254 - val_auc: 0.9032 - val_prc: 0.7428 - val_precision: 0.8254 - val_recall: 0.8254
Epoch 26/30
16/16 [==============================] - 19s 1s/step - loss: 0.6881 - tp: 112.4706 - fp: 26.8235 - tn: 391.0588 - fn: 26.8235 - accuracy: 0.8038 - auc: 0.8924 - prc: 0.6925 - precision: 0.8038 - recall: 0.8038 - val_loss: 0.6641 - val_tp: 52.0000 - val_fp: 11.0000 - val_tn: 178.0000 - val_fn: 11.0000 - val_accuracy: 0.8254 - val_auc: 0.8996 - val_prc: 0.7272 - val_precision: 0.8254 - val_recall: 0.8254
Epoch 27/30
16/16 [==============================] - 19s 1s/step - loss: 0.6151 - tp: 117.7647 - fp: 23.6471 - tn: 400.5882 - fn: 23.6471 - accuracy: 0.8338 - auc: 0.9111 - prc: 0.7620 - precision: 0.8338 - recall: 0.8338 - val_loss: 0.6195 - val_tp: 52.0000 - val_fp: 11.0000 - val_tn: 178.0000 - val_fn: 11.0000 - val_accuracy: 0.8254 - val_auc: 0.9020 - val_prc: 0.7547 - val_precision: 0.8254 - val_recall: 0.8254
Epoch 28/30
16/16 [==============================] - 19s 1s/step - loss: 0.6171 - tp: 115.6471 - fp: 24.1176 - tn: 395.1765 - fn: 24.1176 - accuracy: 0.8308 - auc: 0.9149 - prc: 0.7619 - precision: 0.8308 - recall: 0.8308 - val_loss: 0.6198 - val_tp: 52.0000 - val_fp: 11.0000 - val_tn: 178.0000 - val_fn: 11.0000 - val_accuracy: 0.8254 - val_auc: 0.9066 - val_prc: 0.7564 - val_precision: 0.8254 - val_recall: 0.8254
Epoch 29/30
16/16 [==============================] - 19s 1s/step - loss: 0.6851 - tp: 113.2941 - fp: 26.9412 - tn: 393.7647 - fn: 26.9412 - accuracy: 0.8016 - auc: 0.8997 - prc: 0.7345 - precision: 0.8016 - recall: 0.8016 - val_loss: 0.6353 - val_tp: 52.0000 - val_fp: 11.0000 - val_tn: 178.0000 - val_fn: 11.0000 - val_accuracy: 0.8254 - val_auc: 0.9033 - val_prc: 0.7583 - val_precision: 0.8254 - val_recall: 0.8254
Epoch 30/30
16/16 [==============================] - 19s 1s/step - loss: 0.6465 - tp: 115.5882 - fp: 23.9412 - tn: 394.6471 - fn: 23.9412 - accuracy: 0.8273 - auc: 0.9063 - prc: 0.7568 - precision: 0.8273 - recall: 0.8273 - val_loss: 0.6194 - val_tp: 52.0000 - val_fp: 11.0000 - val_tn: 178.0000 - val_fn: 11.0000 - val_accuracy: 0.8254 - val_auc: 0.9061 - val_prc: 0.7617 - val_precision: 0.8254 - val_recall: 0.8254
CPU times: user 17min 19s, sys: 15.3 s, total: 17min 35s
Wall time: 9min 29s
# Function for smoothing out the loss and accuracy plots
def smooth(points, factor=0.8):
smoothed_points = []
for point in points:
if smoothed_points:
previous = smoothed_points[-1]
smoothed_points.append(previous * factor + point * (1 - factor))
else:
smoothed_points.append(point)
return smoothed_points
def plot_history(hist_dict, soft=False):
epochs = range(1, len(hist_dict['accuracy']) + 1)
plt.figure(figsize=(12, 8))
plt.subplot(2, 2, 1)
plt.plot(epochs, smooth(hist_dict['loss']) if soft else hist_dict['loss'],
'bo', label='Train loss')
plt.plot(epochs, smooth(hist_dict['val_loss']) if soft else hist_dict['val_loss'],
'b', label='Val loss')
plt.title('Loss')
plt.ylabel('Loss')
plt.legend()
plt.subplot(2, 2, 2)
plt.plot(epochs, smooth(hist_dict['accuracy']) if soft else hist_dict['accuracy'],
'bo', label='Train accuracy')
plt.plot(epochs, smooth(hist_dict['val_accuracy']) if soft else hist_dict['val_accuracy'],
'b', label='Val accuracy')
plt.title('Accuracy')
plt.ylabel('Accuracy')
plt.legend()
plt.subplot(2, 2, 3)
plt.plot(epochs, smooth(hist_dict['auc']) if soft else hist_dict['auc'],
'bo', label='Train AUC')
plt.plot(epochs, smooth(hist_dict['val_auc']) if soft else hist_dict['val_auc'],
'b', label='Val AUC')
plt.title('AUC')
plt.xlabel('Epoch')
plt.ylabel('AUC')
plt.legend()
plt.subplot(2, 2, 4)
plt.plot(epochs, smooth(hist_dict['precision']) if soft else hist_dict['precision'],
'bo', label='Train precision')
plt.plot(epochs, smooth(hist_dict['val_precision']) if soft else hist_dict['val_precision'],
'b', label='Val precision')
plt.title('Precision')
plt.xlabel('Epoch')
plt.ylabel('Precision')
plt.legend()
plt.tight_layout()
plt.show()
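smooth is a simple exponential moving average: each smoothed point is 0.8 of the previous smoothed value plus 0.2 of the new observation. A quick check of the behaviour:
# A step from 0 to 1 is approached gradually rather than jumped to
print(smooth([0, 1, 1, 1]))  # approximately [0, 0.2, 0.36, 0.488]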
plot_history(history.history)
Well, the model seems to learn something at least, as the training accuracy increases. However, due to the small network it stagnates very early on. The validation accuracy is constant at 82.5%, which is exactly the share of the largest class; the model is thus seemingly predicting only the largest class. Let us experiment with a larger network.
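To confirm the majority-class suspicion, a confusion matrix on the validation set is more informative than accuracy alone. A sketch, assuming a separate non-shuffled validation iterator so that the predictions line up with the stored labels:
# Sketch: confusion matrix on the validation set (shuffle=False keeps
# eval_generator.classes aligned with the prediction order)
eval_generator = validation_gen.flow_from_directory(
    data_dir, target_size=image_size, batch_size=batch_size,
    class_mode='categorical', subset='validation', shuffle=False)
preds = np.argmax(model.predict(eval_generator), axis=1)
print(tf.math.confusion_matrix(eval_generator.classes, preds))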
# Create a larger CNN architecture
input_tensor = keras.Input(shape=image_size + (3,), name='input')
x = SeparableConv2D(filters=32, kernel_size=(3, 3), activation='relu')(input_tensor)
x = MaxPooling2D(pool_size=(2, 2))(x)
x = SeparableConv2D(64, (3, 3), activation='relu')(x)
x = MaxPooling2D((2, 2))(x)
x = SeparableConv2D(64, (3, 3), activation='relu')(x)
x = MaxPooling2D((2, 2))(x)
x = SeparableConv2D(128, (3, 3), activation='relu')(x)
x = MaxPooling2D((2, 2))(x)
x = SeparableConv2D(256, (3, 3), activation='relu')(x)
x = MaxPooling2D((2, 2))(x)
x = SeparableConv2D(512, (3, 3), activation='relu')(x)
x = layers.GlobalAveragePooling2D()(x)
x = Dropout(0.4)(x)
x = Dense(512, activation='relu')(x)
output = Dense(4, activation='softmax', name='output')(x)
model = keras.Model(inputs=[input_tensor], outputs=[output], name='model_large_1')
model.compile(loss='categorical_crossentropy',
optimizer=keras.optimizers.RMSprop(learning_rate=1e-4),
metrics=METRICS)
model.summary()
Model: "model_large_1" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= input (InputLayer) [(None, 800, 380, 3)] 0 _________________________________________________________________ separable_conv2d_15 (Separab (None, 798, 378, 32) 155 _________________________________________________________________ max_pooling2d_24 (MaxPooling (None, 399, 189, 32) 0 _________________________________________________________________ separable_conv2d_16 (Separab (None, 397, 187, 64) 2400 _________________________________________________________________ max_pooling2d_25 (MaxPooling (None, 198, 93, 64) 0 _________________________________________________________________ separable_conv2d_17 (Separab (None, 196, 91, 64) 4736 _________________________________________________________________ max_pooling2d_26 (MaxPooling (None, 98, 45, 64) 0 _________________________________________________________________ separable_conv2d_18 (Separab (None, 96, 43, 128) 8896 _________________________________________________________________ max_pooling2d_27 (MaxPooling (None, 48, 21, 128) 0 _________________________________________________________________ separable_conv2d_19 (Separab (None, 46, 19, 256) 34176 _________________________________________________________________ max_pooling2d_28 (MaxPooling (None, 23, 9, 256) 0 _________________________________________________________________ separable_conv2d_20 (Separab (None, 21, 7, 512) 133888 _________________________________________________________________ global_average_pooling2d_7 ( (None, 512) 0 _________________________________________________________________ dropout_2 (Dropout) (None, 512) 0 _________________________________________________________________ dense_7 (Dense) (None, 512) 262656 _________________________________________________________________ output (Dense) (None, 4) 2052 ================================================================= Total params: 448,959 Trainable params: 448,959 Non-trainable params: 0 _________________________________________________________________
%%time
# Add callbacks for storing the best model
callbacks = [
tf.keras.callbacks.ModelCheckpoint(
filepath=model_dir+'model_large_1.hdf5',
monitor='val_accuracy',
save_best_only=True,
mode='max'
)
]
# Train the model
history = model.fit(train_generator,
epochs=30,
callbacks=callbacks,
validation_data=validation_generator,
workers=2)
Epoch 1/30
16/16 [==============================] - 25s 1s/step - loss: 1.3717 - tp: 52.0000 - fp: 11.0000 - tn: 605.0588 - fn: 153.3529 - accuracy: 0.7651 - auc: 0.8587 - prc: 0.7152 - precision: 0.8254 - recall: 0.3016 - val_loss: 1.3319 - val_tp: 0.0000e+00 - val_fp: 0.0000e+00 - val_tn: 189.0000 - val_fn: 63.0000 - val_accuracy: 0.8254 - val_auc: 0.8836 - val_prc: 0.7533 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00
Epoch 2/30
16/16 [==============================] - 19s 1s/step - loss: 1.3192 - tp: 0.0000e+00 - fp: 0.0000e+00 - tn: 417.8824 - fn: 139.2941 - accuracy: 0.8191 - auc: 0.8941 - prc: 0.7550 - precision: 0.0000e+00 - recall: 0.0000e+00 - val_loss: 1.2709 - val_tp: 0.0000e+00 - val_fp: 0.0000e+00 - val_tn: 189.0000 - val_fn: 63.0000 - val_accuracy: 0.8254 - val_auc: 0.9048 - val_prc: 0.7635 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00
Epoch 3/30
16/16 [==============================] - 20s 1s/step - loss: 1.2554 - tp: 0.0000e+00 - fp: 0.0000e+00 - tn: 427.7647 - fn: 142.5882 - accuracy: 0.8145 - auc: 0.9005 - prc: 0.7517 - precision: 0.0000e+00 - recall: 0.0000e+00 - val_loss: 1.1924 - val_tp: 0.0000e+00 - val_fp: 0.0000e+00 - val_tn: 189.0000 - val_fn: 63.0000 - val_accuracy: 0.8254 - val_auc: 0.9153 - val_prc: 0.7717 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00
Epoch 4/30
16/16 [==============================] - 20s 1s/step - loss: 1.1750 - tp: 0.0000e+00 - fp: 0.0000e+00 - tn: 425.6471 - fn: 141.8824 - accuracy: 0.8135 - auc: 0.8972 - prc: 0.7493 - precision: 0.0000e+00 - recall: 0.0000e+00 - val_loss: 1.0915 - val_tp: 0.0000e+00 - val_fp: 0.0000e+00 - val_tn: 189.0000 - val_fn: 63.0000 - val_accuracy: 0.8254 - val_auc: 0.9153 - val_prc: 0.7717 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00
Epoch 5/30
16/16 [==============================] - 20s 1s/step - loss: 1.0826 - tp: 0.0000e+00 - fp: 0.0000e+00 - tn: 419.2941 - fn: 139.7647 - accuracy: 0.7945 - auc: 0.8954 - prc: 0.7581 - precision: 0.0000e+00 - recall: 0.0000e+00 - val_loss: 0.9720 - val_tp: 0.0000e+00 - val_fp: 0.0000e+00 - val_tn: 189.0000 - val_fn: 63.0000 - val_accuracy: 0.8254 - val_auc: 0.9153 - val_prc: 0.7717 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00
Epoch 6/30
16/16 [==============================] - 20s 1s/step - loss: 0.9505 - tp: 11.7059 - fp: 2.5882 - tn: 416.0000 - fn: 127.8235 - accuracy: 0.8119 - auc: 0.9123 - prc: 0.7822 - precision: 0.4068 - recall: 0.0503 - val_loss: 0.8458 - val_tp: 52.0000 - val_fp: 11.0000 - val_tn: 178.0000 - val_fn: 11.0000 - val_accuracy: 0.8254 - val_auc: 0.9153 - val_prc: 0.7717 - val_precision: 0.8254 - val_recall: 0.8254
Epoch 7/30
16/16 [==============================] - 19s 1s/step - loss: 0.8224 - tp: 114.1176 - fp: 23.4706 - tn: 397.9412 - fn: 26.3529 - accuracy: 0.8244 - auc: 0.9021 - prc: 0.7713 - precision: 0.8181 - recall: 0.7935 - val_loss: 0.7367 - val_tp: 52.0000 - val_fp: 11.0000 - val_tn: 178.0000 - val_fn: 11.0000 - val_accuracy: 0.8254 - val_auc: 0.9159 - val_prc: 0.7790 - val_precision: 0.8254 - val_recall: 0.8254
Epoch 8/30
16/16 [==============================] - 19s 1s/step - loss: 0.7579 - tp: 115.3529 - fp: 25.8235 - tn: 397.7059 - fn: 25.8235 - accuracy: 0.8004 - auc: 0.8987 - prc: 0.7318 - precision: 0.8004 - recall: 0.8004 - val_loss: 0.6645 - val_tp: 52.0000 - val_fp: 11.0000 - val_tn: 178.0000 - val_fn: 11.0000 - val_accuracy: 0.8254 - val_auc: 0.9129 - val_prc: 0.7543 - val_precision: 0.8254 - val_recall: 0.8254
Epoch 9/30
16/16 [==============================] - 20s 1s/step - loss: 0.6462 - tp: 115.7647 - fp: 24.0000 - tn: 395.2941 - fn: 24.0000 - accuracy: 0.8281 - auc: 0.9219 - prc: 0.8093 - precision: 0.8281 - recall: 0.8281 - val_loss: 0.6367 - val_tp: 52.0000 - val_fp: 11.0000 - val_tn: 178.0000 - val_fn: 11.0000 - val_accuracy: 0.8254 - val_auc: 0.9137 - val_prc: 0.7591 - val_precision: 0.8254 - val_recall: 0.8254
Epoch 10/30
16/16 [==============================] - 20s 1s/step - loss: 0.7539 - tp: 112.8824 - fp: 27.8235 - tn: 394.2941 - fn: 27.8235 - accuracy: 0.7768 - auc: 0.8850 - prc: 0.7245 - precision: 0.7768 - recall: 0.7768 - val_loss: 0.6290 - val_tp: 52.0000 - val_fp: 11.0000 - val_tn: 178.0000 - val_fn: 11.0000 - val_accuracy: 0.8254 - val_auc: 0.9174 - val_prc: 0.7826 - val_precision: 0.8254 - val_recall: 0.8254
Epoch 11/30
16/16 [==============================] - 19s 1s/step - loss: 0.7511 - tp: 114.0588 - fp: 27.5882 - tn: 397.3529 - fn: 27.5882 - accuracy: 0.7778 - auc: 0.8887 - prc: 0.7141 - precision: 0.7778 - recall: 0.7778 - val_loss: 0.6268 - val_tp: 52.0000 - val_fp: 11.0000 - val_tn: 178.0000 - val_fn: 11.0000 - val_accuracy: 0.8254 - val_auc: 0.9118 - val_prc: 0.7675 - val_precision: 0.8254 - val_recall: 0.8254
Epoch 12/30
16/16 [==============================] - 19s 1s/step - loss: 0.6393 - tp: 114.7647 - fp: 25.0000 - tn: 394.2941 - fn: 25.0000 - accuracy: 0.8181 - auc: 0.9162 - prc: 0.7757 - precision: 0.8181 - recall: 0.8181 - val_loss: 0.6243 - val_tp: 52.0000 - val_fp: 11.0000 - val_tn: 178.0000 - val_fn: 11.0000 - val_accuracy: 0.8254 - val_auc: 0.9148 - val_prc: 0.7711 - val_precision: 0.8254 - val_recall: 0.8254
Epoch 13/30
16/16 [==============================] - 20s 1s/step - loss: 0.5817 - tp: 117.0000 - fp: 23.4706 - tn: 397.9412 - fn: 23.4706 - accuracy: 0.8425 - auc: 0.9216 - prc: 0.7784 - precision: 0.8425 - recall: 0.8425 - val_loss: 0.6225 - val_tp: 52.0000 - val_fp: 11.0000 - val_tn: 178.0000 - val_fn: 11.0000 - val_accuracy: 0.8254 - val_auc: 0.9155 - val_prc: 0.7566 - val_precision: 0.8254 - val_recall: 0.8254
Epoch 14/30
16/16 [==============================] - 19s 1s/step - loss: 0.6838 - tp: 114.1765 - fp: 26.5294 - tn: 395.5882 - fn: 26.5294 - accuracy: 0.8088 - auc: 0.8879 - prc: 0.6936 - precision: 0.8088 - recall: 0.8088 - val_loss: 0.6211 - val_tp: 52.0000 - val_fp: 11.0000 - val_tn: 178.0000 - val_fn: 11.0000 - val_accuracy: 0.8254 - val_auc: 0.9098 - val_prc: 0.7447 - val_precision: 0.8254 - val_recall: 0.8254
Epoch 15/30
16/16 [==============================] - 20s 1s/step - loss: 0.5532 - tp: 117.1176 - fp: 22.4118 - tn: 396.1765 - fn: 22.4118 - accuracy: 0.8505 - auc: 0.9280 - prc: 0.8181 - precision: 0.8505 - recall: 0.8505 - val_loss: 0.6198 - val_tp: 52.0000 - val_fp: 11.0000 - val_tn: 178.0000 - val_fn: 11.0000 - val_accuracy: 0.8254 - val_auc: 0.9166 - val_prc: 0.7781 - val_precision: 0.8254 - val_recall: 0.8254
Epoch 16/30
16/16 [==============================] - 19s 1s/step - loss: 0.6091 - tp: 114.2941 - fp: 24.7647 - tn: 392.4118 - fn: 24.7647 - accuracy: 0.8290 - auc: 0.9183 - prc: 0.7834 - precision: 0.8290 - recall: 0.8290 - val_loss: 0.6184 - val_tp: 52.0000 - val_fp: 11.0000 - val_tn: 178.0000 - val_fn: 11.0000 - val_accuracy: 0.8254 - val_auc: 0.9153 - val_prc: 0.7717 - val_precision: 0.8254 - val_recall: 0.8254
Epoch 17/30
16/16 [==============================] - 19s 1s/step - loss: 0.5825 - tp: 118.0000 - fp: 23.8824 - tn: 401.7647 - fn: 23.8824 - accuracy: 0.8392 - auc: 0.9223 - prc: 0.7870 - precision: 0.8392 - recall: 0.8392 - val_loss: 0.6175 - val_tp: 52.0000 - val_fp: 11.0000 - val_tn: 178.0000 - val_fn: 11.0000 - val_accuracy: 0.8254 - val_auc: 0.9149 - val_prc: 0.7694 - val_precision: 0.8254 - val_recall: 0.8254
Epoch 18/30
16/16 [==============================] - 19s 1s/step - loss: 0.6388 - tp: 115.6471 - fp: 25.7647 - tn: 398.4706 - fn: 25.7647 - accuracy: 0.8143 - auc: 0.9077 - prc: 0.7287 - precision: 0.8143 - recall: 0.8143 - val_loss: 0.6168 - val_tp: 52.0000 - val_fp: 11.0000 - val_tn: 178.0000 - val_fn: 11.0000 - val_accuracy: 0.8254 - val_auc: 0.9146 - val_prc: 0.7811 - val_precision: 0.8254 - val_recall: 0.8254
Epoch 19/30
16/16 [==============================] - 19s 1s/step - loss: 0.6573 - tp: 113.2941 - fp: 26.4706 - tn: 392.8235 - fn: 26.4706 - accuracy: 0.8109 - auc: 0.8998 - prc: 0.7006 - precision: 0.8109 - recall: 0.8109 - val_loss: 0.6165 - val_tp: 52.0000 - val_fp: 11.0000 - val_tn: 178.0000 - val_fn: 11.0000 - val_accuracy: 0.8254 - val_auc: 0.9169 - val_prc: 0.7802 - val_precision: 0.8254 - val_recall: 0.8254
Epoch 20/30
16/16 [==============================] - 19s 1s/step - loss: 0.5391 - tp: 118.7647 - fp: 21.4706 - tn: 399.2353 - fn: 21.4706 - accuracy: 0.8605 - auc: 0.9239 - prc: 0.7736 - precision: 0.8605 - recall: 0.8605 - val_loss: 0.6157 - val_tp: 52.0000 - val_fp: 11.0000 - val_tn: 178.0000 - val_fn: 11.0000 - val_accuracy: 0.8254 - val_auc: 0.9166 - val_prc: 0.7830 - val_precision: 0.8254 - val_recall: 0.8254
Epoch 21/30
16/16 [==============================] - 19s 1s/step - loss: 0.6666 - tp: 116.8824 - fp: 25.2353 - tn: 401.1176 - fn: 25.2353 - accuracy: 0.8087 - auc: 0.8981 - prc: 0.7288 - precision: 0.8087 - recall: 0.8087 - val_loss: 0.6152 - val_tp: 52.0000 - val_fp: 11.0000 - val_tn: 178.0000 - val_fn: 11.0000 - val_accuracy: 0.8254 - val_auc: 0.9166 - val_prc: 0.7781 - val_precision: 0.8254 - val_recall: 0.8254
Epoch 22/30
16/16 [==============================] - 19s 1s/step - loss: 0.6131 - tp: 115.2941 - fp: 24.4706 - tn: 394.8235 - fn: 24.4706 - accuracy: 0.8248 - auc: 0.9185 - prc: 0.7650 - precision: 0.8248 - recall: 0.8248 - val_loss: 0.6148 - val_tp: 52.0000 - val_fp: 11.0000 - val_tn: 178.0000 - val_fn: 11.0000 - val_accuracy: 0.8254 - val_auc: 0.9159 - val_prc: 0.7790 - val_precision: 0.8254 - val_recall: 0.8254
Epoch 23/30
16/16 [==============================] - 19s 1s/step - loss: 0.6546 - tp: 116.0588 - fp: 25.8235 - tn: 399.8235 - fn: 25.8235 - accuracy: 0.8137 - auc: 0.9038 - prc: 0.7400 - precision: 0.8137 - recall: 0.8137 - val_loss: 0.6146 - val_tp: 52.0000 - val_fp: 11.0000 - val_tn: 178.0000 - val_fn: 11.0000 - val_accuracy: 0.8254 - val_auc: 0.9203 - val_prc: 0.7742 - val_precision: 0.8254 - val_recall: 0.8254
Epoch 24/30
16/16 [==============================] - 19s 1s/step - loss: 0.6019 - tp: 117.8235 - fp: 23.5882 - tn: 400.6471 - fn: 23.5882 - accuracy: 0.8282 - auc: 0.9176 - prc: 0.7700 - precision: 0.8282 - recall: 0.8282 - val_loss: 0.6143 - val_tp: 52.0000 - val_fp: 11.0000 - val_tn: 178.0000 - val_fn: 11.0000 - val_accuracy: 0.8254 - val_auc: 0.9116 - val_prc: 0.7468 - val_precision: 0.8254 - val_recall: 0.8254
Epoch 25/30
16/16 [==============================] - 19s 1s/step - loss: 0.6617 - tp: 115.7059 - fp: 24.2941 - tn: 395.7059 - fn: 24.2941 - accuracy: 0.8044 - auc: 0.9093 - prc: 0.7657 - precision: 0.8044 - recall: 0.8044 - val_loss: 0.6141 - val_tp: 52.0000 - val_fp: 11.0000 - val_tn: 178.0000 - val_fn: 11.0000 - val_accuracy: 0.8254 - val_auc: 0.9111 - val_prc: 0.7754 - val_precision: 0.8254 - val_recall: 0.8254
Epoch 26/30
16/16 [==============================] - 19s 1s/step - loss: 0.5337 - tp: 120.2353 - fp: 22.1176 - tn: 404.9412 - fn: 22.1176 - accuracy: 0.8602 - auc: 0.9267 - prc: 0.7815 - precision: 0.8602 - recall: 0.8602 - val_loss: 0.6141 - val_tp: 52.0000 - val_fp: 11.0000 - val_tn: 178.0000 - val_fn: 11.0000 - val_accuracy: 0.8254 - val_auc: 0.9137 - val_prc: 0.7719 - val_precision: 0.8254 - val_recall: 0.8254
Epoch 27/30
16/16 [==============================] - 19s 1s/step - loss: 0.5885 - tp: 116.7059 - fp: 24.7059 - tn: 399.5294 - fn: 24.7059 - accuracy: 0.8250 - auc: 0.9300 - prc: 0.8037 - precision: 0.8250 - recall: 0.8250 - val_loss: 0.6138 - val_tp: 52.0000 - val_fp: 11.0000 - val_tn: 178.0000 - val_fn: 11.0000 - val_accuracy: 0.8254 - val_auc: 0.9131 - val_prc: 0.7583 - val_precision: 0.8254 - val_recall: 0.8254
Epoch 28/30
16/16 [==============================] - 19s 1s/step - loss: 0.6289 - tp: 117.2353 - fp: 24.6471 - tn: 401.0000 - fn: 24.6471 - accuracy: 0.8281 - auc: 0.8990 - prc: 0.7127 - precision: 0.8281 - recall: 0.8281 - val_loss: 0.6137 - val_tp: 52.0000 - val_fp: 11.0000 - val_tn: 178.0000 - val_fn: 11.0000 - val_accuracy: 0.8254 - val_auc: 0.9155 - val_prc: 0.7704 - val_precision: 0.8254 - val_recall: 0.8254
Epoch 29/30
16/16 [==============================] - 19s 1s/step - loss: 0.6621 - tp: 113.8824 - fp: 25.1765 - tn: 392.0000 - fn: 25.1765 - accuracy: 0.8193 - auc: 0.8989 - prc: 0.7282 - precision: 0.8193 - recall: 0.8193 - val_loss: 0.6137 - val_tp: 52.0000 - val_fp: 11.0000 - val_tn: 178.0000 - val_fn: 11.0000 - val_accuracy: 0.8254 - val_auc: 0.9119 - val_prc: 0.7418 - val_precision: 0.8254 - val_recall: 0.8254
Epoch 30/30
16/16 [==============================] - 19s 1s/step - loss: 0.7252 - tp: 114.6471 - fp: 27.7059 - tn: 399.3529 - fn: 27.7059 - accuracy: 0.7913 - auc: 0.8838 - prc: 0.7136 - precision: 0.7913 - recall: 0.7913 - val_loss: 0.6141 - val_tp: 52.0000 - val_fp: 11.0000 - val_tn: 178.0000 - val_fn: 11.0000 - val_accuracy: 0.8254 - val_auc: 0.9113 - val_prc: 0.7463 - val_precision: 0.8254 - val_recall: 0.8254
CPU times: user 17min 51s, sys: 18.9 s, total: 18min 10s
Wall time: 9min 47s
plot_history(history.history)
So far, this doesn't look promising. Let's try an even larger model before taking another approach.
# Even larger model architecture
input_tensor = keras.Input(shape=image_size + (3,), name='input')
# Entry block
x = layers.Conv2D(32, 3, strides=2, padding="same")(input_tensor)
x = layers.BatchNormalization()(x)
x = layers.Activation("relu")(x)
x = layers.Conv2D(64, 3, padding="same")(x)
x = layers.BatchNormalization()(x)
x = layers.Activation("relu")(x)
previous_block_activation = x # Set aside residual
for size in [128, 256, 512, 728]:
x = layers.Activation("relu")(x)
x = layers.SeparableConv2D(size, 3, padding="same")(x)
x = layers.BatchNormalization()(x)
x = layers.Activation("relu")(x)
x = layers.SeparableConv2D(size, 3, padding="same")(x)
x = layers.BatchNormalization()(x)
x = layers.MaxPooling2D(3, strides=2, padding="same")(x)
# Project residual
residual = layers.Conv2D(size, 1, strides=2, padding="same")(
previous_block_activation
)
x = layers.add([x, residual]) # Add back residual
previous_block_activation = x # Set aside next residual
x = layers.SeparableConv2D(1024, 3, padding="same")(x)
x = layers.BatchNormalization()(x)
x = layers.Activation("relu")(x)
x = layers.GlobalAveragePooling2D()(x)
x = layers.Dropout(0.5)(x)
output = layers.Dense(4, activation='softmax', name='output')(x)
model = keras.Model(input_tensor, output, name='model_large_2')
model.compile(loss='categorical_crossentropy',
optimizer=keras.optimizers.RMSprop(learning_rate=1e-4),
metrics=METRICS)
model.summary()
Model: "model_large_2" __________________________________________________________________________________________________ Layer (type) Output Shape Param # Connected to ================================================================================================== input (InputLayer) [(None, 800, 380, 3) 0 __________________________________________________________________________________________________ conv2d_30 (Conv2D) (None, 400, 190, 32) 896 input[0][0] __________________________________________________________________________________________________ batch_normalization_15 (BatchNo (None, 400, 190, 32) 128 conv2d_30[0][0] __________________________________________________________________________________________________ activation_11 (Activation) (None, 400, 190, 32) 0 batch_normalization_15[0][0] __________________________________________________________________________________________________ conv2d_31 (Conv2D) (None, 400, 190, 64) 18496 activation_11[0][0] __________________________________________________________________________________________________ batch_normalization_16 (BatchNo (None, 400, 190, 64) 256 conv2d_31[0][0] __________________________________________________________________________________________________ activation_12 (Activation) (None, 400, 190, 64) 0 batch_normalization_16[0][0] __________________________________________________________________________________________________ activation_13 (Activation) (None, 400, 190, 64) 0 activation_12[0][0] __________________________________________________________________________________________________ separable_conv2d_21 (SeparableC (None, 400, 190, 128 8896 activation_13[0][0] __________________________________________________________________________________________________ batch_normalization_17 (BatchNo (None, 400, 190, 128 512 separable_conv2d_21[0][0] __________________________________________________________________________________________________ activation_14 (Activation) (None, 400, 190, 128 0 batch_normalization_17[0][0] __________________________________________________________________________________________________ separable_conv2d_22 (SeparableC (None, 400, 190, 128 17664 activation_14[0][0] __________________________________________________________________________________________________ batch_normalization_18 (BatchNo (None, 400, 190, 128 512 separable_conv2d_22[0][0] __________________________________________________________________________________________________ max_pooling2d_29 (MaxPooling2D) (None, 200, 95, 128) 0 batch_normalization_18[0][0] __________________________________________________________________________________________________ conv2d_32 (Conv2D) (None, 200, 95, 128) 8320 activation_12[0][0] __________________________________________________________________________________________________ add_16 (Add) (None, 200, 95, 128) 0 max_pooling2d_29[0][0] conv2d_32[0][0] __________________________________________________________________________________________________ activation_15 (Activation) (None, 200, 95, 128) 0 add_16[0][0] __________________________________________________________________________________________________ separable_conv2d_23 (SeparableC (None, 200, 95, 256) 34176 activation_15[0][0] __________________________________________________________________________________________________ batch_normalization_19 (BatchNo (None, 200, 95, 256) 1024 separable_conv2d_23[0][0] __________________________________________________________________________________________________ activation_16 
(Activation) (None, 200, 95, 256) 0 batch_normalization_19[0][0] __________________________________________________________________________________________________ separable_conv2d_24 (SeparableC (None, 200, 95, 256) 68096 activation_16[0][0] __________________________________________________________________________________________________ batch_normalization_20 (BatchNo (None, 200, 95, 256) 1024 separable_conv2d_24[0][0] __________________________________________________________________________________________________ max_pooling2d_30 (MaxPooling2D) (None, 100, 48, 256) 0 batch_normalization_20[0][0] __________________________________________________________________________________________________ conv2d_33 (Conv2D) (None, 100, 48, 256) 33024 add_16[0][0] __________________________________________________________________________________________________ add_17 (Add) (None, 100, 48, 256) 0 max_pooling2d_30[0][0] conv2d_33[0][0] __________________________________________________________________________________________________ activation_17 (Activation) (None, 100, 48, 256) 0 add_17[0][0] __________________________________________________________________________________________________ separable_conv2d_25 (SeparableC (None, 100, 48, 512) 133888 activation_17[0][0] __________________________________________________________________________________________________ batch_normalization_21 (BatchNo (None, 100, 48, 512) 2048 separable_conv2d_25[0][0] __________________________________________________________________________________________________ activation_18 (Activation) (None, 100, 48, 512) 0 batch_normalization_21[0][0] __________________________________________________________________________________________________ separable_conv2d_26 (SeparableC (None, 100, 48, 512) 267264 activation_18[0][0] __________________________________________________________________________________________________ batch_normalization_22 (BatchNo (None, 100, 48, 512) 2048 separable_conv2d_26[0][0] __________________________________________________________________________________________________ max_pooling2d_31 (MaxPooling2D) (None, 50, 24, 512) 0 batch_normalization_22[0][0] __________________________________________________________________________________________________ conv2d_34 (Conv2D) (None, 50, 24, 512) 131584 add_17[0][0] __________________________________________________________________________________________________ add_18 (Add) (None, 50, 24, 512) 0 max_pooling2d_31[0][0] conv2d_34[0][0] __________________________________________________________________________________________________ activation_19 (Activation) (None, 50, 24, 512) 0 add_18[0][0] __________________________________________________________________________________________________ separable_conv2d_27 (SeparableC (None, 50, 24, 728) 378072 activation_19[0][0] __________________________________________________________________________________________________ batch_normalization_23 (BatchNo (None, 50, 24, 728) 2912 separable_conv2d_27[0][0] __________________________________________________________________________________________________ activation_20 (Activation) (None, 50, 24, 728) 0 batch_normalization_23[0][0] __________________________________________________________________________________________________ separable_conv2d_28 (SeparableC (None, 50, 24, 728) 537264 activation_20[0][0] __________________________________________________________________________________________________ batch_normalization_24 (BatchNo (None, 50, 24, 728) 2912 
separable_conv2d_28[0][0] __________________________________________________________________________________________________ max_pooling2d_32 (MaxPooling2D) (None, 25, 12, 728) 0 batch_normalization_24[0][0] __________________________________________________________________________________________________ conv2d_35 (Conv2D) (None, 25, 12, 728) 373464 add_18[0][0] __________________________________________________________________________________________________ add_19 (Add) (None, 25, 12, 728) 0 max_pooling2d_32[0][0] conv2d_35[0][0] __________________________________________________________________________________________________ separable_conv2d_29 (SeparableC (None, 25, 12, 1024) 753048 add_19[0][0] __________________________________________________________________________________________________ batch_normalization_25 (BatchNo (None, 25, 12, 1024) 4096 separable_conv2d_29[0][0] __________________________________________________________________________________________________ activation_21 (Activation) (None, 25, 12, 1024) 0 batch_normalization_25[0][0] __________________________________________________________________________________________________ global_average_pooling2d_8 (Glo (None, 1024) 0 activation_21[0][0] __________________________________________________________________________________________________ dropout_3 (Dropout) (None, 1024) 0 global_average_pooling2d_8[0][0] __________________________________________________________________________________________________ output (Dense) (None, 4) 4100 dropout_3[0][0] ================================================================================================== Total params: 2,785,724 Trainable params: 2,776,988 Non-trainable params: 8,736 __________________________________________________________________________________________________
%%time
# Add callbacks for storing the best model
callbacks = [
tf.keras.callbacks.ModelCheckpoint(
filepath=model_dir+'model_large_2.hdf5',
monitor='val_accuracy',
save_best_only=True,
mode='max'
)
]
# Train the model
history = model.fit(train_generator,
epochs=30,
callbacks=callbacks,
validation_data=validation_generator,
workers=2)
Epoch 1/30
16/16 [==============================] - 29s 1s/step - loss: 1.4153 - tp: 79.8824 - fp: 27.1765 - tn: 587.4706 - fn: 125.0000 - accuracy: 0.5680 - auc: 0.7801 - prc: 0.6000 - precision: 0.7479 - recall: 0.4157 - val_loss: 1.3544 - val_tp: 0.0000e+00 - val_fp: 0.0000e+00 - val_tn: 189.0000 - val_fn: 63.0000 - val_accuracy: 0.8254 - val_auc: 0.8783 - val_prc: 0.7501 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00
Epoch 2/30
16/16 [==============================] - 23s 1s/step - loss: 0.8285 - tp: 88.8235 - fp: 18.7059 - tn: 404.1176 - fn: 52.1176 - accuracy: 0.7587 - auc: 0.8878 - prc: 0.7566 - precision: 0.8321 - recall: 0.6180 - val_loss: 1.3216 - val_tp: 0.0000e+00 - val_fp: 0.0000e+00 - val_tn: 189.0000 - val_fn: 63.0000 - val_accuracy: 0.8254 - val_auc: 0.9048 - val_prc: 0.7635 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00
Epoch 3/30
16/16 [==============================] - 24s 1s/step - loss: 0.7183 - tp: 106.8824 - fp: 22.1176 - tn: 404.9412 - fn: 35.4706 - accuracy: 0.7852 - auc: 0.9014 - prc: 0.7830 - precision: 0.8411 - recall: 0.7470 - val_loss: 1.2922 - val_tp: 0.0000e+00 - val_fp: 0.0000e+00 - val_tn: 189.0000 - val_fn: 63.0000 - val_accuracy: 0.8254 - val_auc: 0.9048 - val_prc: 0.7635 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00
Epoch 4/30
16/16 [==============================] - 24s 1s/step - loss: 0.6788 - tp: 109.2941 - fp: 23.7647 - tn: 401.8824 - fn: 32.5882 - accuracy: 0.7680 - auc: 0.9124 - prc: 0.7961 - precision: 0.8114 - recall: 0.7462 - val_loss: 1.2587 - val_tp: 0.0000e+00 - val_fp: 0.0000e+00 - val_tn: 189.0000 - val_fn: 63.0000 - val_accuracy: 0.8254 - val_auc: 0.9048 - val_prc: 0.7635 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00
Epoch 5/30
16/16 [==============================] - 24s 1s/step - loss: 0.5779 - tp: 112.4706 - fp: 18.8235 - tn: 401.1765 - fn: 27.5294 - accuracy: 0.8309 - auc: 0.9367 - prc: 0.8545 - precision: 0.8572 - recall: 0.8058 - val_loss: 1.2260 - val_tp: 0.0000e+00 - val_fp: 0.0000e+00 - val_tn: 189.0000 - val_fn: 63.0000 - val_accuracy: 0.8254 - val_auc: 0.9153 - val_prc: 0.7717 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00
Epoch 6/30
16/16 [==============================] - 24s 1s/step - loss: 0.5694 - tp: 111.0588 - fp: 21.1176 - tn: 397.4706 - fn: 28.4706 - accuracy: 0.8070 - auc: 0.9375 - prc: 0.8640 - precision: 0.8278 - recall: 0.7865 - val_loss: 1.1946 - val_tp: 0.0000e+00 - val_fp: 0.0000e+00 - val_tn: 189.0000 - val_fn: 63.0000 - val_accuracy: 0.8254 - val_auc: 0.9185 - val_prc: 0.7731 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00
Epoch 7/30
16/16 [==============================] - 24s 1s/step - loss: 0.5439 - tp: 115.5882 - fp: 18.2353 - tn: 403.8824 - fn: 25.1176 - accuracy: 0.8361 - auc: 0.9442 - prc: 0.8597 - precision: 0.8571 - recall: 0.8112 - val_loss: 1.1651 - val_tp: 0.0000e+00 - val_fp: 0.0000e+00 - val_tn: 189.0000 - val_fn: 63.0000 - val_accuracy: 0.8254 - val_auc: 0.9153 - val_prc: 0.7717 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00
Epoch 8/30
16/16 [==============================] - 22s 1s/step - loss: 0.5334 - tp: 109.1765 - fp: 19.5882 - tn: 399.7059 - fn: 30.5882 - accuracy: 0.7872 - auc: 0.9465 - prc: 0.8585 - precision: 0.8551 - recall: 0.7766 - val_loss: 1.1299 - val_tp: 0.0000e+00 - val_fp: 0.0000e+00 - val_tn: 189.0000 - val_fn: 63.0000 - val_accuracy: 0.8254 - val_auc: 0.9153 - val_prc: 0.7717 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00
Epoch 9/30
16/16 [==============================] - 23s 1s/step - loss: 0.4351 - tp: 115.6471 - fp: 19.0588 - tn: 402.3529 - fn: 24.8235 - accuracy: 0.8468 - auc: 0.9680 - prc: 0.9147 - precision: 0.8608 - recall: 0.8220 - val_loss: 1.1052 - val_tp: 0.0000e+00 - val_fp: 0.0000e+00 - val_tn: 189.0000 - val_fn: 63.0000 - val_accuracy: 0.8254 - val_auc: 0.9154 - val_prc: 0.7717 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00
Epoch 10/30
16/16 [==============================] - 23s 1s/step - loss: 0.4305 - tp: 116.0000 - fp: 22.2941 - tn: 400.5294 - fn: 24.9412 - accuracy: 0.8480 - auc: 0.9679 - prc: 0.9124 - precision: 0.8484 - recall: 0.8349 - val_loss: 1.0817 - val_tp: 0.0000e+00 - val_fp: 0.0000e+00 - val_tn: 189.0000 - val_fn: 63.0000 - val_accuracy: 0.8254 - val_auc: 0.9153 - val_prc: 0.7717 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00
Epoch 11/30
16/16 [==============================] - 23s 1s/step - loss: 0.4467 - tp: 117.1176 - fp: 19.6471 - tn: 404.5882 - fn: 24.2941 - accuracy: 0.8709 - auc: 0.9626 - prc: 0.8997 - precision: 0.8682 - recall: 0.8450 - val_loss: 1.0466 - val_tp: 0.0000e+00 - val_fp: 0.0000e+00 - val_tn: 189.0000 - val_fn: 63.0000 - val_accuracy: 0.8254 - val_auc: 0.9153 - val_prc: 0.7717 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00
Epoch 12/30
16/16 [==============================] - 24s 1s/step - loss: 0.4571 - tp: 113.4118 - fp: 19.7059 - tn: 403.1176 - fn: 27.5294 - accuracy: 0.8400 - auc: 0.9603 - prc: 0.8989 - precision: 0.8598 - recall: 0.8084 - val_loss: 1.0144 - val_tp: 0.0000e+00 - val_fp: 0.0000e+00 - val_tn: 189.0000 - val_fn: 63.0000 - val_accuracy: 0.8254 - val_auc: 0.9164 - val_prc: 0.7721 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00
Epoch 13/30
16/16 [==============================] - 24s 1s/step - loss: 0.4625 - tp: 111.4706 - fp: 24.1176 - tn: 396.5882 - fn: 28.7647 - accuracy: 0.8070 - auc: 0.9637 - prc: 0.9056 - precision: 0.8160 - recall: 0.7868 - val_loss: 0.9859 - val_tp: 0.0000e+00 - val_fp: 0.0000e+00 - val_tn: 189.0000 - val_fn: 63.0000 - val_accuracy: 0.8254 - val_auc: 0.9266 - val_prc: 0.8108 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00
Epoch 14/30
16/16 [==============================] - 23s 1s/step - loss: 0.4203 - tp: 112.9412 - fp: 21.4118 - tn: 402.1176 - fn: 28.2353 - accuracy: 0.8178 - auc: 0.9694 - prc: 0.9214 - precision: 0.8442 - recall: 0.7995 - val_loss: 0.9678 - val_tp: 0.0000e+00 - val_fp: 0.0000e+00 - val_tn: 189.0000 - val_fn: 63.0000 - val_accuracy: 0.8254 - val_auc: 0.9129 - val_prc: 0.7543 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00
Epoch 15/30
16/16 [==============================] - 24s 1s/step - loss: 0.4087 - tp: 115.4706 - fp: 21.0588 - tn: 405.2941 - fn: 26.6471 - accuracy: 0.8283 - auc: 0.9713 - prc: 0.9217 - precision: 0.8425 - recall: 0.8134 - val_loss: 0.9534 - val_tp: 0.0000e+00 - val_fp: 0.0000e+00 - val_tn: 189.0000 - val_fn: 63.0000 - val_accuracy: 0.8254 - val_auc: 0.9187 - val_prc: 0.7828 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00
Epoch 16/30
16/16 [==============================] - 24s 1s/step - loss: 0.4513 - tp: 112.8235 - fp: 19.2353 - tn: 401.4706 - fn: 27.4118 - accuracy: 0.8305 - auc: 0.9641 - prc: 0.9079 - precision: 0.8574 - recall: 0.7947 - val_loss: 0.9154 - val_tp: 0.0000e+00 - val_fp: 0.0000e+00 - val_tn: 189.0000 - val_fn: 63.0000 - val_accuracy: 0.8254 - val_auc: 0.9229 - val_prc: 0.7899 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00
Epoch 17/30
16/16 [==============================] - 23s 1s/step - loss: 0.3313 - tp: 119.9412 - fp: 13.6471 - tn: 405.6471 - fn: 19.8235 - accuracy: 0.8884 - auc: 0.9794 - prc: 0.9443 - precision: 0.9018 - recall: 0.8649 - val_loss: 0.8903 - val_tp: 0.0000e+00 - val_fp: 0.0000e+00 - val_tn: 189.0000 - val_fn: 63.0000 - val_accuracy: 0.8254 - val_auc: 0.9331 - val_prc: 0.8177 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00
Epoch 18/30
16/16 [==============================] - 25s 1s/step - loss: 0.3282 - tp: 120.2941 - fp: 13.6471 - tn: 410.5882 - fn: 21.1176 - accuracy: 0.8746 - auc: 0.9805 - prc: 0.9478 - precision: 0.9131 - recall: 0.8630 - val_loss: 0.8569 - val_tp: 44.0000 - val_fp: 8.0000 - val_tn: 181.0000 - val_fn: 19.0000 - val_accuracy: 0.8254 - val_auc: 0.9281 - val_prc: 0.7666 - val_precision: 0.8462 - val_recall: 0.6984
Epoch 19/30
16/16 [==============================] - 24s 1s/step - loss: 0.3100 - tp: 121.7647 - fp: 16.3529 - tn: 402.9412 - fn: 18.0000 - accuracy: 0.8897 - auc: 0.9825 - prc: 0.9514 - precision: 0.8924 - recall: 0.8853 - val_loss: 0.8387 - val_tp: 52.0000 - val_fp: 11.0000 - val_tn: 178.0000 - val_fn: 11.0000 - val_accuracy: 0.8254 - val_auc: 0.9320 - val_prc: 0.8209 - val_precision: 0.8254 - val_recall: 0.8254
Epoch 20/30
16/16 [==============================] - 24s 1s/step - loss: 0.3024 - tp: 120.6471 - fp: 14.3529 - tn: 406.3529 - fn: 19.5882 - accuracy: 0.8970 - auc: 0.9830 - prc: 0.9532 - precision: 0.9060 - recall: 0.8710 - val_loss: 0.8100 - val_tp: 52.0000 - val_fp: 11.0000 - val_tn: 178.0000 - val_fn: 11.0000 - val_accuracy: 0.8254 - val_auc: 0.9332 - val_prc: 0.8269 - val_precision: 0.8254 - val_recall: 0.8254
Epoch 21/30
16/16 [==============================] - 23s 1s/step - loss: 0.3891 - tp: 118.8824 - fp: 18.8235 - tn: 407.5294 - fn: 23.2353 - accuracy: 0.8505 - auc: 0.9732 - prc: 0.9299 - precision: 0.8673 - recall: 0.8396 - val_loss: 0.7845 - val_tp: 52.0000 - val_fp: 11.0000 - val_tn: 178.0000 - val_fn: 11.0000 - val_accuracy: 0.8254 - val_auc: 0.9257 - val_prc: 0.7899 - val_precision: 0.8254 - val_recall: 0.8254
Epoch 22/30
16/16 [==============================] - 23s 1s/step - loss: 0.3135 - tp: 118.1176 - fp: 16.6471 - tn: 406.8824 - fn: 23.0588 - accuracy: 0.8658 - auc: 0.9834 - prc: 0.9548 - precision: 0.8885 - recall: 0.8399 - val_loss: 0.7369 - val_tp: 52.0000 - val_fp: 11.0000 - val_tn: 178.0000 - val_fn: 11.0000 - val_accuracy: 0.8254 - val_auc: 0.9272 - val_prc: 0.7711 - val_precision: 0.8254 - val_recall: 0.8254
Epoch 23/30
16/16 [==============================] - 24s 1s/step - loss: 0.2995 - tp: 122.8235 - fp: 15.4706 - tn: 408.0588 - fn: 18.3529 - accuracy: 0.8752 - auc: 0.9845 - prc: 0.9573 - precision: 0.8789 - recall: 0.8635 - val_loss: 0.7126 - val_tp: 52.0000 - val_fp: 11.0000 - val_tn: 178.0000 - val_fn: 11.0000 - val_accuracy: 0.8254 - val_auc: 0.9347 - val_prc: 0.8127 - val_precision: 0.8254 - val_recall: 0.8254
Epoch 24/30
16/16 [==============================] - 24s 1s/step - loss: 0.3292 - tp: 118.1176 - fp: 18.6471 - tn: 399.2353 - fn: 21.1765 - accuracy: 0.8519 - auc: 0.9798 - prc: 0.9438 - precision: 0.8584 - recall: 0.8446 - val_loss: 0.6734 - val_tp: 52.0000 - val_fp: 11.0000 - val_tn: 178.0000 - val_fn: 11.0000 - val_accuracy: 0.8254 - val_auc: 0.9265 - val_prc: 0.7867 - val_precision: 0.8254 - val_recall: 0.8254
Epoch 25/30
16/16 [==============================] - 24s 1s/step - loss: 0.2937 - tp: 124.1176 - fp: 14.1176 - tn: 412.9412 - fn: 18.2353 - accuracy: 0.8800 - auc: 0.9850 - prc: 0.9597 - precision: 0.8857 - recall: 0.8594 - val_loss: 0.6560 - val_tp: 52.0000 - val_fp: 11.0000 - val_tn: 178.0000 - val_fn: 11.0000 - val_accuracy: 0.8254 - val_auc: 0.9209 - val_prc: 0.7666 - val_precision: 0.8254 - val_recall: 0.8254
Epoch 26/30
16/16 [==============================] - 25s 1s/step - loss: 0.2618 - tp: 125.1765 - fp: 10.7059 - tn: 410.7059 - fn: 15.2941 - accuracy: 0.9042 - auc: 0.9884 - prc: 0.9688 - precision: 0.9194 - recall: 0.8969 - val_loss: 0.6371 - val_tp: 52.0000 - val_fp: 11.0000 - val_tn: 178.0000 - val_fn: 11.0000 - val_accuracy: 0.8254 - val_auc: 0.9313 - val_prc: 0.8097 - val_precision: 0.8254 - val_recall: 0.8254
Epoch 27/30
16/16 [==============================] - 23s 1s/step - loss: 0.2541 - tp: 124.4706 - fp: 12.5882 - tn: 407.4118 - fn: 15.5294 - accuracy: 0.9067 - auc: 0.9874 - prc: 0.9660 - precision: 0.9139 - recall: 0.8981 - val_loss: 0.6164 - val_tp: 52.0000 - val_fp: 11.0000 - val_tn: 178.0000 - val_fn: 11.0000 - val_accuracy: 0.8254 - val_auc: 0.9255 - val_prc: 0.7598 - val_precision: 0.8254 - val_recall: 0.8254
Epoch 28/30
16/16 [==============================] - 24s 1s/step - loss: 0.2317 - tp: 126.1765 - fp: 9.6471 - tn: 410.3529 - fn: 13.8235 - accuracy: 0.9296 - auc: 0.9889 - prc: 0.9725 - precision: 0.9416 - recall: 0.9212 - val_loss: 0.6101 - val_tp: 52.0000 - val_fp: 11.0000 - val_tn: 178.0000 - val_fn: 11.0000 - val_accuracy: 0.8254 - val_auc: 0.9326 - val_prc: 0.8148 - val_precision: 0.8254 - val_recall: 0.8254
Epoch 29/30
16/16 [==============================] - 24s 1s/step - loss: 0.2400 - tp: 126.8235 - fp: 9.5294 - tn: 409.7647 - fn: 12.9412 - accuracy: 0.9131 - auc: 0.9898 - prc: 0.9725 - precision: 0.9284 - recall: 0.9081 - val_loss: 0.5974 - val_tp: 52.0000 - val_fp: 11.0000 - val_tn: 178.0000 - val_fn: 11.0000 - val_accuracy: 0.8254 - val_auc: 0.9266 - val_prc: 0.7786 - val_precision: 0.8254 - val_recall: 0.8254
Epoch 30/30
16/16 [==============================] - 25s 1s/step - loss: 0.3008 - tp: 120.4118 - fp: 15.2941 - tn: 404.7059 - fn: 19.5882 - accuracy: 0.8648 - auc: 0.9832 - prc: 0.9551 - precision: 0.8850 - recall: 0.8529 - val_loss: 0.5907 - val_tp: 52.0000 - val_fp: 11.0000 - val_tn: 178.0000 - val_fn: 11.0000 - val_accuracy: 0.8254 - val_auc: 0.9340 - val_prc: 0.8392 - val_precision: 0.8254 - val_recall: 0.8254
CPU times: user 20min 1s, sys: 30.2 s, total: 20min 31s
Wall time: 12min 2s
plot_history(history.history)
!nvidia-smi
Fri Apr 23 10:36:17 2021
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 465.19.01    Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   68C    P0    31W /  70W |  14828MiB / 15109MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+

+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
+-----------------------------------------------------------------------------+
We can confirm that the model consistently predicts everything as belonging to class 3.
def plot_cm(labels, predictions):
    cm = tf.math.confusion_matrix(labels=labels,
                                  predictions=predictions)
    plt.figure(figsize=(4, 4))
    sns.heatmap(cm, annot=True, cbar=False, fmt='d')
    plt.title('Confusion Matrix')
    plt.ylabel('Actual label')
    plt.xlabel('Predicted label')
    plt.show()
# Make predictions on the validation set (which is not ideal! But that's what we have.)
best_model = keras.models.load_model(model_dir+'model_large_2.hdf5') # Load the best performing model
predictions = best_model.predict(validation_generator)
# Plot confusion matrix
plot_cm(validation_generator.classes,
        np.argmax(predictions, axis=1))
print(f"Predictions: \n{np.argmax(predictions, axis=1)}")
print(f"\nTrue labels: \n{validation_generator.classes}")
Predictions: 
[3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3]

True labels: 
[0 0 0 1 2 2 2 2 2 2 2 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3]
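As a sanity check, the 82.5% validation accuracy the model plateaued at is exactly the majority-class baseline. A minimal sketch computing it from the generator's labels, using only names already defined above:
# Accuracy of always predicting the majority class on the validation set
true_labels = validation_generator.classes
majority_class = np.bincount(true_labels).argmax()
baseline_acc = np.mean(true_labels == majority_class)
print(f"Majority class: {majority_class}, baseline accuracy: {baseline_acc:.4f}")
# 52 of the 63 validation images belong to class 3, so 52/63 ≈ 0.8254.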
Use an Xception architecture pre-trained on ImageNet.
To train a randomly initialised classifier on top of a pre-trained model such as Xception, it's necessary to first freeze the convolutional base. Conversely, fine-tuning the top layers of the convolutional base is only advisable once the classifier on top has already been trained: otherwise, the error signal propagating through the network during training will be too large, and the representations previously learned by the layers being fine-tuned will be destroyed. For that reason, the steps moving forward will be:
1. Add a densely connected classifier on top of the Xception convolutional base.
2. Freeze the convolutional base.
3. Train the classifier that was added.
4. Unfreeze some of the top layers of the convolutional base and jointly train them with the classifier.
# Instantiate the Xception convolutional base
conv_base = keras.applications.Xception(weights='imagenet',
                                        include_top=False)
conv_base.summary()
Model: "xception" __________________________________________________________________________________________________ Layer (type) Output Shape Param # Connected to ================================================================================================== input_2 (InputLayer) [(None, None, None, 0 __________________________________________________________________________________________________ block1_conv1 (Conv2D) (None, None, None, 3 864 input_2[0][0] __________________________________________________________________________________________________ block1_conv1_bn (BatchNormaliza (None, None, None, 3 128 block1_conv1[0][0] __________________________________________________________________________________________________ block1_conv1_act (Activation) (None, None, None, 3 0 block1_conv1_bn[0][0] __________________________________________________________________________________________________ block1_conv2 (Conv2D) (None, None, None, 6 18432 block1_conv1_act[0][0] __________________________________________________________________________________________________ block1_conv2_bn (BatchNormaliza (None, None, None, 6 256 block1_conv2[0][0] __________________________________________________________________________________________________ block1_conv2_act (Activation) (None, None, None, 6 0 block1_conv2_bn[0][0] __________________________________________________________________________________________________ block2_sepconv1 (SeparableConv2 (None, None, None, 1 8768 block1_conv2_act[0][0] __________________________________________________________________________________________________ block2_sepconv1_bn (BatchNormal (None, None, None, 1 512 block2_sepconv1[0][0] __________________________________________________________________________________________________ block2_sepconv2_act (Activation (None, None, None, 1 0 block2_sepconv1_bn[0][0] __________________________________________________________________________________________________ block2_sepconv2 (SeparableConv2 (None, None, None, 1 17536 block2_sepconv2_act[0][0] __________________________________________________________________________________________________ block2_sepconv2_bn (BatchNormal (None, None, None, 1 512 block2_sepconv2[0][0] __________________________________________________________________________________________________ conv2d_36 (Conv2D) (None, None, None, 1 8192 block1_conv2_act[0][0] __________________________________________________________________________________________________ block2_pool (MaxPooling2D) (None, None, None, 1 0 block2_sepconv2_bn[0][0] __________________________________________________________________________________________________ batch_normalization_26 (BatchNo (None, None, None, 1 512 conv2d_36[0][0] __________________________________________________________________________________________________ add_20 (Add) (None, None, None, 1 0 block2_pool[0][0] batch_normalization_26[0][0] __________________________________________________________________________________________________ block3_sepconv1_act (Activation (None, None, None, 1 0 add_20[0][0] __________________________________________________________________________________________________ block3_sepconv1 (SeparableConv2 (None, None, None, 2 33920 block3_sepconv1_act[0][0] __________________________________________________________________________________________________ block3_sepconv1_bn (BatchNormal (None, None, None, 2 1024 block3_sepconv1[0][0] 
__________________________________________________________________________________________________ block3_sepconv2_act (Activation (None, None, None, 2 0 block3_sepconv1_bn[0][0] __________________________________________________________________________________________________ block3_sepconv2 (SeparableConv2 (None, None, None, 2 67840 block3_sepconv2_act[0][0] __________________________________________________________________________________________________ block3_sepconv2_bn (BatchNormal (None, None, None, 2 1024 block3_sepconv2[0][0] __________________________________________________________________________________________________ conv2d_37 (Conv2D) (None, None, None, 2 32768 add_20[0][0] __________________________________________________________________________________________________ block3_pool (MaxPooling2D) (None, None, None, 2 0 block3_sepconv2_bn[0][0] __________________________________________________________________________________________________ batch_normalization_27 (BatchNo (None, None, None, 2 1024 conv2d_37[0][0] __________________________________________________________________________________________________ add_21 (Add) (None, None, None, 2 0 block3_pool[0][0] batch_normalization_27[0][0] __________________________________________________________________________________________________ block4_sepconv1_act (Activation (None, None, None, 2 0 add_21[0][0] __________________________________________________________________________________________________ block4_sepconv1 (SeparableConv2 (None, None, None, 7 188672 block4_sepconv1_act[0][0] __________________________________________________________________________________________________ block4_sepconv1_bn (BatchNormal (None, None, None, 7 2912 block4_sepconv1[0][0] __________________________________________________________________________________________________ block4_sepconv2_act (Activation (None, None, None, 7 0 block4_sepconv1_bn[0][0] __________________________________________________________________________________________________ block4_sepconv2 (SeparableConv2 (None, None, None, 7 536536 block4_sepconv2_act[0][0] __________________________________________________________________________________________________ block4_sepconv2_bn (BatchNormal (None, None, None, 7 2912 block4_sepconv2[0][0] __________________________________________________________________________________________________ conv2d_38 (Conv2D) (None, None, None, 7 186368 add_21[0][0] __________________________________________________________________________________________________ block4_pool (MaxPooling2D) (None, None, None, 7 0 block4_sepconv2_bn[0][0] __________________________________________________________________________________________________ batch_normalization_28 (BatchNo (None, None, None, 7 2912 conv2d_38[0][0] __________________________________________________________________________________________________ add_22 (Add) (None, None, None, 7 0 block4_pool[0][0] batch_normalization_28[0][0] __________________________________________________________________________________________________ block5_sepconv1_act (Activation (None, None, None, 7 0 add_22[0][0] __________________________________________________________________________________________________ block5_sepconv1 (SeparableConv2 (None, None, None, 7 536536 block5_sepconv1_act[0][0] __________________________________________________________________________________________________ block5_sepconv1_bn (BatchNormal (None, None, None, 7 2912 block5_sepconv1[0][0] 
__________________________________________________________________________________________________ block5_sepconv2_act (Activation (None, None, None, 7 0 block5_sepconv1_bn[0][0] __________________________________________________________________________________________________ block5_sepconv2 (SeparableConv2 (None, None, None, 7 536536 block5_sepconv2_act[0][0] __________________________________________________________________________________________________ block5_sepconv2_bn (BatchNormal (None, None, None, 7 2912 block5_sepconv2[0][0] __________________________________________________________________________________________________ block5_sepconv3_act (Activation (None, None, None, 7 0 block5_sepconv2_bn[0][0] __________________________________________________________________________________________________ block5_sepconv3 (SeparableConv2 (None, None, None, 7 536536 block5_sepconv3_act[0][0] __________________________________________________________________________________________________ block5_sepconv3_bn (BatchNormal (None, None, None, 7 2912 block5_sepconv3[0][0] __________________________________________________________________________________________________ add_23 (Add) (None, None, None, 7 0 block5_sepconv3_bn[0][0] add_22[0][0] __________________________________________________________________________________________________ block6_sepconv1_act (Activation (None, None, None, 7 0 add_23[0][0] __________________________________________________________________________________________________ block6_sepconv1 (SeparableConv2 (None, None, None, 7 536536 block6_sepconv1_act[0][0] __________________________________________________________________________________________________ block6_sepconv1_bn (BatchNormal (None, None, None, 7 2912 block6_sepconv1[0][0] __________________________________________________________________________________________________ block6_sepconv2_act (Activation (None, None, None, 7 0 block6_sepconv1_bn[0][0] __________________________________________________________________________________________________ block6_sepconv2 (SeparableConv2 (None, None, None, 7 536536 block6_sepconv2_act[0][0] __________________________________________________________________________________________________ block6_sepconv2_bn (BatchNormal (None, None, None, 7 2912 block6_sepconv2[0][0] __________________________________________________________________________________________________ block6_sepconv3_act (Activation (None, None, None, 7 0 block6_sepconv2_bn[0][0] __________________________________________________________________________________________________ block6_sepconv3 (SeparableConv2 (None, None, None, 7 536536 block6_sepconv3_act[0][0] __________________________________________________________________________________________________ block6_sepconv3_bn (BatchNormal (None, None, None, 7 2912 block6_sepconv3[0][0] __________________________________________________________________________________________________ add_24 (Add) (None, None, None, 7 0 block6_sepconv3_bn[0][0] add_23[0][0] __________________________________________________________________________________________________ block7_sepconv1_act (Activation (None, None, None, 7 0 add_24[0][0] __________________________________________________________________________________________________ block7_sepconv1 (SeparableConv2 (None, None, None, 7 536536 block7_sepconv1_act[0][0] __________________________________________________________________________________________________ block7_sepconv1_bn (BatchNormal (None, None, None, 7 
2912 block7_sepconv1[0][0] __________________________________________________________________________________________________ block7_sepconv2_act (Activation (None, None, None, 7 0 block7_sepconv1_bn[0][0] __________________________________________________________________________________________________ block7_sepconv2 (SeparableConv2 (None, None, None, 7 536536 block7_sepconv2_act[0][0] __________________________________________________________________________________________________ block7_sepconv2_bn (BatchNormal (None, None, None, 7 2912 block7_sepconv2[0][0] __________________________________________________________________________________________________ block7_sepconv3_act (Activation (None, None, None, 7 0 block7_sepconv2_bn[0][0] __________________________________________________________________________________________________ block7_sepconv3 (SeparableConv2 (None, None, None, 7 536536 block7_sepconv3_act[0][0] __________________________________________________________________________________________________ block7_sepconv3_bn (BatchNormal (None, None, None, 7 2912 block7_sepconv3[0][0] __________________________________________________________________________________________________ add_25 (Add) (None, None, None, 7 0 block7_sepconv3_bn[0][0] add_24[0][0] __________________________________________________________________________________________________ block8_sepconv1_act (Activation (None, None, None, 7 0 add_25[0][0] __________________________________________________________________________________________________ block8_sepconv1 (SeparableConv2 (None, None, None, 7 536536 block8_sepconv1_act[0][0] __________________________________________________________________________________________________ block8_sepconv1_bn (BatchNormal (None, None, None, 7 2912 block8_sepconv1[0][0] __________________________________________________________________________________________________ block8_sepconv2_act (Activation (None, None, None, 7 0 block8_sepconv1_bn[0][0] __________________________________________________________________________________________________ block8_sepconv2 (SeparableConv2 (None, None, None, 7 536536 block8_sepconv2_act[0][0] __________________________________________________________________________________________________ block8_sepconv2_bn (BatchNormal (None, None, None, 7 2912 block8_sepconv2[0][0] __________________________________________________________________________________________________ block8_sepconv3_act (Activation (None, None, None, 7 0 block8_sepconv2_bn[0][0] __________________________________________________________________________________________________ block8_sepconv3 (SeparableConv2 (None, None, None, 7 536536 block8_sepconv3_act[0][0] __________________________________________________________________________________________________ block8_sepconv3_bn (BatchNormal (None, None, None, 7 2912 block8_sepconv3[0][0] __________________________________________________________________________________________________ add_26 (Add) (None, None, None, 7 0 block8_sepconv3_bn[0][0] add_25[0][0] __________________________________________________________________________________________________ block9_sepconv1_act (Activation (None, None, None, 7 0 add_26[0][0] __________________________________________________________________________________________________ block9_sepconv1 (SeparableConv2 (None, None, None, 7 536536 block9_sepconv1_act[0][0] __________________________________________________________________________________________________ block9_sepconv1_bn 
(BatchNormal (None, None, None, 7 2912 block9_sepconv1[0][0] __________________________________________________________________________________________________ block9_sepconv2_act (Activation (None, None, None, 7 0 block9_sepconv1_bn[0][0] __________________________________________________________________________________________________ block9_sepconv2 (SeparableConv2 (None, None, None, 7 536536 block9_sepconv2_act[0][0] __________________________________________________________________________________________________ block9_sepconv2_bn (BatchNormal (None, None, None, 7 2912 block9_sepconv2[0][0] __________________________________________________________________________________________________ block9_sepconv3_act (Activation (None, None, None, 7 0 block9_sepconv2_bn[0][0] __________________________________________________________________________________________________ block9_sepconv3 (SeparableConv2 (None, None, None, 7 536536 block9_sepconv3_act[0][0] __________________________________________________________________________________________________ block9_sepconv3_bn (BatchNormal (None, None, None, 7 2912 block9_sepconv3[0][0] __________________________________________________________________________________________________ add_27 (Add) (None, None, None, 7 0 block9_sepconv3_bn[0][0] add_26[0][0] __________________________________________________________________________________________________ block10_sepconv1_act (Activatio (None, None, None, 7 0 add_27[0][0] __________________________________________________________________________________________________ block10_sepconv1 (SeparableConv (None, None, None, 7 536536 block10_sepconv1_act[0][0] __________________________________________________________________________________________________ block10_sepconv1_bn (BatchNorma (None, None, None, 7 2912 block10_sepconv1[0][0] __________________________________________________________________________________________________ block10_sepconv2_act (Activatio (None, None, None, 7 0 block10_sepconv1_bn[0][0] __________________________________________________________________________________________________ block10_sepconv2 (SeparableConv (None, None, None, 7 536536 block10_sepconv2_act[0][0] __________________________________________________________________________________________________ block10_sepconv2_bn (BatchNorma (None, None, None, 7 2912 block10_sepconv2[0][0] __________________________________________________________________________________________________ block10_sepconv3_act (Activatio (None, None, None, 7 0 block10_sepconv2_bn[0][0] __________________________________________________________________________________________________ block10_sepconv3 (SeparableConv (None, None, None, 7 536536 block10_sepconv3_act[0][0] __________________________________________________________________________________________________ block10_sepconv3_bn (BatchNorma (None, None, None, 7 2912 block10_sepconv3[0][0] __________________________________________________________________________________________________ add_28 (Add) (None, None, None, 7 0 block10_sepconv3_bn[0][0] add_27[0][0] __________________________________________________________________________________________________ block11_sepconv1_act (Activatio (None, None, None, 7 0 add_28[0][0] __________________________________________________________________________________________________ block11_sepconv1 (SeparableConv (None, None, None, 7 536536 block11_sepconv1_act[0][0] 
__________________________________________________________________________________________________ block11_sepconv1_bn (BatchNorma (None, None, None, 7 2912 block11_sepconv1[0][0] __________________________________________________________________________________________________ block11_sepconv2_act (Activatio (None, None, None, 7 0 block11_sepconv1_bn[0][0] __________________________________________________________________________________________________ block11_sepconv2 (SeparableConv (None, None, None, 7 536536 block11_sepconv2_act[0][0] __________________________________________________________________________________________________ block11_sepconv2_bn (BatchNorma (None, None, None, 7 2912 block11_sepconv2[0][0] __________________________________________________________________________________________________ block11_sepconv3_act (Activatio (None, None, None, 7 0 block11_sepconv2_bn[0][0] __________________________________________________________________________________________________ block11_sepconv3 (SeparableConv (None, None, None, 7 536536 block11_sepconv3_act[0][0] __________________________________________________________________________________________________ block11_sepconv3_bn (BatchNorma (None, None, None, 7 2912 block11_sepconv3[0][0] __________________________________________________________________________________________________ add_29 (Add) (None, None, None, 7 0 block11_sepconv3_bn[0][0] add_28[0][0] __________________________________________________________________________________________________ block12_sepconv1_act (Activatio (None, None, None, 7 0 add_29[0][0] __________________________________________________________________________________________________ block12_sepconv1 (SeparableConv (None, None, None, 7 536536 block12_sepconv1_act[0][0] __________________________________________________________________________________________________ block12_sepconv1_bn (BatchNorma (None, None, None, 7 2912 block12_sepconv1[0][0] __________________________________________________________________________________________________ block12_sepconv2_act (Activatio (None, None, None, 7 0 block12_sepconv1_bn[0][0] __________________________________________________________________________________________________ block12_sepconv2 (SeparableConv (None, None, None, 7 536536 block12_sepconv2_act[0][0] __________________________________________________________________________________________________ block12_sepconv2_bn (BatchNorma (None, None, None, 7 2912 block12_sepconv2[0][0] __________________________________________________________________________________________________ block12_sepconv3_act (Activatio (None, None, None, 7 0 block12_sepconv2_bn[0][0] __________________________________________________________________________________________________ block12_sepconv3 (SeparableConv (None, None, None, 7 536536 block12_sepconv3_act[0][0] __________________________________________________________________________________________________ block12_sepconv3_bn (BatchNorma (None, None, None, 7 2912 block12_sepconv3[0][0] __________________________________________________________________________________________________ add_30 (Add) (None, None, None, 7 0 block12_sepconv3_bn[0][0] add_29[0][0] __________________________________________________________________________________________________ block13_sepconv1_act (Activatio (None, None, None, 7 0 add_30[0][0] __________________________________________________________________________________________________ block13_sepconv1 (SeparableConv (None, 
None, None, 7 536536 block13_sepconv1_act[0][0] __________________________________________________________________________________________________ block13_sepconv1_bn (BatchNorma (None, None, None, 7 2912 block13_sepconv1[0][0] __________________________________________________________________________________________________ block13_sepconv2_act (Activatio (None, None, None, 7 0 block13_sepconv1_bn[0][0] __________________________________________________________________________________________________ block13_sepconv2 (SeparableConv (None, None, None, 1 752024 block13_sepconv2_act[0][0] __________________________________________________________________________________________________ block13_sepconv2_bn (BatchNorma (None, None, None, 1 4096 block13_sepconv2[0][0] __________________________________________________________________________________________________ conv2d_39 (Conv2D) (None, None, None, 1 745472 add_30[0][0] __________________________________________________________________________________________________ block13_pool (MaxPooling2D) (None, None, None, 1 0 block13_sepconv2_bn[0][0] __________________________________________________________________________________________________ batch_normalization_29 (BatchNo (None, None, None, 1 4096 conv2d_39[0][0] __________________________________________________________________________________________________ add_31 (Add) (None, None, None, 1 0 block13_pool[0][0] batch_normalization_29[0][0] __________________________________________________________________________________________________ block14_sepconv1 (SeparableConv (None, None, None, 1 1582080 add_31[0][0] __________________________________________________________________________________________________ block14_sepconv1_bn (BatchNorma (None, None, None, 1 6144 block14_sepconv1[0][0] __________________________________________________________________________________________________ block14_sepconv1_act (Activatio (None, None, None, 1 0 block14_sepconv1_bn[0][0] __________________________________________________________________________________________________ block14_sepconv2 (SeparableConv (None, None, None, 2 3159552 block14_sepconv1_act[0][0] __________________________________________________________________________________________________ block14_sepconv2_bn (BatchNorma (None, None, None, 2 8192 block14_sepconv2[0][0] __________________________________________________________________________________________________ block14_sepconv2_act (Activatio (None, None, None, 2 0 block14_sepconv2_bn[0][0] ================================================================================================== Total params: 20,861,480 Trainable params: 20,806,952 Non-trainable params: 54,528 __________________________________________________________________________________________________
# Add a densely connected classifier on top of the convolutional base
input_tensor = keras.Input(shape=image_size + (3,), name='input')
x = conv_base(input_tensor)
x = layers.Flatten()(x)
x = Dense(128, activation='relu')(x)
output = Dense(4, activation='softmax', name='output')(x)
model = keras.Model(inputs=input_tensor, outputs=output, name='Xception_1')
model.summary()
Model: "Xception_1" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= input (InputLayer) [(None, 800, 380, 3)] 0 _________________________________________________________________ xception (Functional) (None, None, None, 2048) 20861480 _________________________________________________________________ flatten_1 (Flatten) (None, 614400) 0 _________________________________________________________________ dense_8 (Dense) (None, 128) 78643328 _________________________________________________________________ output (Dense) (None, 4) 516 ================================================================= Total params: 99,505,324 Trainable params: 99,450,796 Non-trainable params: 54,528 _________________________________________________________________
Freeze the convolutional base before compiling and training the model. Very important!
print("Number of trainable parameters BEFORE freezing the conv base:", len(model.trainable_weights))
# Freeze conv base
conv_base.trainable = False
print("Number of trainable parameters AFTER freezing the conv base:", len(model.trainable_weights))
Number of trainable parameters BEFORE freezing the conv base: 158 Number of trainable parameters AFTER freezing the conv base: 4
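Note that len(model.trainable_weights) counts weight tensors, not individual parameters. A minimal sketch for counting the actual trainable parameters, using tf.keras.backend.count_params; the result should match the "Trainable params" row in model.summary():
# Sum the sizes of all trainable weight tensors to get the number of
# trainable parameters, rather than the number of weight tensors.
n_trainable = sum(tf.keras.backend.count_params(w) for w in model.trainable_weights)
print(f"Trainable parameters: {n_trainable:,}")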
# Compile the model and make sure there are fewer trainable parameters now
model.compile(loss='categorical_crossentropy',
              optimizer=keras.optimizers.RMSprop(learning_rate=1e-5), # Lower learning rate
              metrics=METRICS)
model.summary()
Model: "Xception_1" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= input (InputLayer) [(None, 800, 380, 3)] 0 _________________________________________________________________ xception (Functional) (None, None, None, 2048) 20861480 _________________________________________________________________ flatten_1 (Flatten) (None, 614400) 0 _________________________________________________________________ dense_8 (Dense) (None, 128) 78643328 _________________________________________________________________ output (Dense) (None, 4) 516 ================================================================= Total params: 99,505,324 Trainable params: 78,643,844 Non-trainable params: 20,861,480 _________________________________________________________________
%%time
# Add callbacks for storing the best model as well as early stopping
callbacks = [
    tf.keras.callbacks.ModelCheckpoint(
        filepath=model_dir+'Xception_1.hdf5',
        monitor='val_accuracy',
        save_best_only=True,
        mode='max'
    ),
    tf.keras.callbacks.EarlyStopping(
        monitor='val_accuracy',
        patience=10, # Stop training if the val_accuracy doesn't improve over 10 epochs.
        mode='max',
        restore_best_weights=True,
    )
]
# Train the model using the same train and validation generators as before
history = model.fit(train_generator,
                    epochs=100,
                    callbacks=callbacks,
                    validation_data=validation_generator,
                    workers=2)
Epoch 1/100 16/16 [==============================] - 27s 1s/step - loss: 1.0027 - tp: 159.9412 - fp: 28.2941 - tn: 586.3529 - fn: 44.9412 - accuracy: 0.7813 - auc: 0.9201 - prc: 0.8182 - precision: 0.8485 - recall: 0.7681 - val_loss: 0.5554 - val_tp: 53.0000 - val_fp: 10.0000 - val_tn: 179.0000 - val_fn: 10.0000 - val_accuracy: 0.8413 - val_auc: 0.9679 - val_prc: 0.9181 - val_precision: 0.8413 - val_recall: 0.8413 Epoch 2/100 16/16 [==============================] - 22s 1s/step - loss: 0.3473 - tp: 119.6471 - fp: 15.8235 - tn: 407.0000 - fn: 21.2941 - accuracy: 0.8465 - auc: 0.9787 - prc: 0.9446 - precision: 0.8662 - recall: 0.8334 - val_loss: 0.6555 - val_tp: 52.0000 - val_fp: 11.0000 - val_tn: 178.0000 - val_fn: 11.0000 - val_accuracy: 0.8254 - val_auc: 0.9641 - val_prc: 0.9124 - val_precision: 0.8254 - val_recall: 0.8254 Epoch 3/100 16/16 [==============================] - 21s 1s/step - loss: 0.2404 - tp: 128.5294 - fp: 10.5294 - tn: 416.5294 - fn: 13.8235 - accuracy: 0.9030 - auc: 0.9897 - prc: 0.9727 - precision: 0.9197 - recall: 0.8967 - val_loss: 0.2860 - val_tp: 57.0000 - val_fp: 5.0000 - val_tn: 184.0000 - val_fn: 6.0000 - val_accuracy: 0.9048 - val_auc: 0.9797 - val_prc: 0.9448 - val_precision: 0.9194 - val_recall: 0.9048 Epoch 4/100 16/16 [==============================] - 22s 1s/step - loss: 0.1956 - tp: 133.0588 - fp: 7.7059 - tn: 419.3529 - fn: 9.2941 - accuracy: 0.9396 - auc: 0.9919 - prc: 0.9781 - precision: 0.9457 - recall: 0.9387 - val_loss: 0.2418 - val_tp: 58.0000 - val_fp: 5.0000 - val_tn: 184.0000 - val_fn: 5.0000 - val_accuracy: 0.9206 - val_auc: 0.9847 - val_prc: 0.9611 - val_precision: 0.9206 - val_recall: 0.9206 Epoch 5/100 16/16 [==============================] - 22s 1s/step - loss: 0.1196 - tp: 135.5882 - fp: 4.4706 - tn: 417.6471 - fn: 5.1176 - accuracy: 0.9628 - auc: 0.9974 - prc: 0.9926 - precision: 0.9654 - recall: 0.9623 - val_loss: 0.2830 - val_tp: 58.0000 - val_fp: 5.0000 - val_tn: 184.0000 - val_fn: 5.0000 - val_accuracy: 0.9206 - val_auc: 0.9846 - val_prc: 0.9599 - val_precision: 0.9206 - val_recall: 0.9206 Epoch 6/100 16/16 [==============================] - 20s 1s/step - loss: 0.1395 - tp: 136.4706 - fp: 3.8235 - tn: 418.2941 - fn: 4.2353 - accuracy: 0.9687 - auc: 0.9931 - prc: 0.9810 - precision: 0.9706 - recall: 0.9687 - val_loss: 0.2343 - val_tp: 61.0000 - val_fp: 2.0000 - val_tn: 187.0000 - val_fn: 2.0000 - val_accuracy: 0.9683 - val_auc: 0.9862 - val_prc: 0.9654 - val_precision: 0.9683 - val_recall: 0.9683 Epoch 7/100 16/16 [==============================] - 22s 1s/step - loss: 0.0728 - tp: 136.5294 - fp: 2.5882 - tn: 416.7059 - fn: 3.2353 - accuracy: 0.9805 - auc: 0.9990 - prc: 0.9972 - precision: 0.9804 - recall: 0.9767 - val_loss: 0.2216 - val_tp: 61.0000 - val_fp: 2.0000 - val_tn: 187.0000 - val_fn: 2.0000 - val_accuracy: 0.9683 - val_auc: 0.9867 - val_prc: 0.9666 - val_precision: 0.9683 - val_recall: 0.9683 Epoch 8/100 16/16 [==============================] - 19s 1s/step - loss: 0.0916 - tp: 138.3529 - fp: 2.8235 - tn: 420.7059 - fn: 2.8235 - accuracy: 0.9821 - auc: 0.9966 - prc: 0.9909 - precision: 0.9821 - recall: 0.9821 - val_loss: 0.3509 - val_tp: 58.0000 - val_fp: 5.0000 - val_tn: 184.0000 - val_fn: 5.0000 - val_accuracy: 0.9206 - val_auc: 0.9858 - val_prc: 0.9647 - val_precision: 0.9206 - val_recall: 0.9206 Epoch 9/100 16/16 [==============================] - 22s 1s/step - loss: 0.0597 - tp: 136.8235 - fp: 4.0588 - tn: 421.5882 - fn: 5.0588 - accuracy: 0.9723 - auc: 0.9993 - prc: 0.9981 - precision: 0.9721 - recall: 0.9597 - val_loss: 
0.2990 - val_tp: 59.0000 - val_fp: 3.0000 - val_tn: 186.0000 - val_fn: 4.0000 - val_accuracy: 0.9365 - val_auc: 0.9866 - val_prc: 0.9668 - val_precision: 0.9516 - val_recall: 0.9365 Epoch 10/100 16/16 [==============================] - 21s 1s/step - loss: 0.1130 - tp: 139.4118 - fp: 2.9412 - tn: 424.8235 - fn: 3.1765 - accuracy: 0.9751 - auc: 0.9968 - prc: 0.9909 - precision: 0.9751 - recall: 0.9741 - val_loss: 0.2463 - val_tp: 60.0000 - val_fp: 3.0000 - val_tn: 186.0000 - val_fn: 3.0000 - val_accuracy: 0.9524 - val_auc: 0.9871 - val_prc: 0.9686 - val_precision: 0.9524 - val_recall: 0.9524 Epoch 11/100 16/16 [==============================] - 19s 1s/step - loss: 0.0190 - tp: 141.1765 - fp: 0.7059 - tn: 424.9412 - fn: 0.7059 - accuracy: 0.9968 - auc: 1.0000 - prc: 0.9999 - precision: 0.9968 - recall: 0.9968 - val_loss: 0.2205 - val_tp: 60.0000 - val_fp: 3.0000 - val_tn: 186.0000 - val_fn: 3.0000 - val_accuracy: 0.9524 - val_auc: 0.9876 - val_prc: 0.9695 - val_precision: 0.9524 - val_recall: 0.9524 Epoch 12/100 16/16 [==============================] - 20s 1s/step - loss: 0.0290 - tp: 138.7059 - fp: 1.2941 - tn: 418.7059 - fn: 1.2941 - accuracy: 0.9859 - auc: 0.9999 - prc: 0.9996 - precision: 0.9859 - recall: 0.9859 - val_loss: 0.1949 - val_tp: 61.0000 - val_fp: 2.0000 - val_tn: 187.0000 - val_fn: 2.0000 - val_accuracy: 0.9683 - val_auc: 0.9880 - val_prc: 0.9712 - val_precision: 0.9683 - val_recall: 0.9683 Epoch 13/100 16/16 [==============================] - 20s 1s/step - loss: 0.0271 - tp: 138.2941 - fp: 1.2353 - tn: 417.3529 - fn: 1.2353 - accuracy: 0.9925 - auc: 0.9998 - prc: 0.9995 - precision: 0.9925 - recall: 0.9925 - val_loss: 0.1976 - val_tp: 62.0000 - val_fp: 1.0000 - val_tn: 188.0000 - val_fn: 1.0000 - val_accuracy: 0.9841 - val_auc: 0.9880 - val_prc: 0.9708 - val_precision: 0.9841 - val_recall: 0.9841 Epoch 14/100 16/16 [==============================] - 22s 1s/step - loss: 0.0841 - tp: 138.2353 - fp: 2.7059 - tn: 420.1176 - fn: 2.7059 - accuracy: 0.9812 - auc: 0.9954 - prc: 0.9882 - precision: 0.9812 - recall: 0.9812 - val_loss: 0.3480 - val_tp: 58.0000 - val_fp: 4.0000 - val_tn: 185.0000 - val_fn: 5.0000 - val_accuracy: 0.9365 - val_auc: 0.9868 - val_prc: 0.9678 - val_precision: 0.9355 - val_recall: 0.9206 Epoch 15/100 16/16 [==============================] - 21s 1s/step - loss: 0.0194 - tp: 141.2353 - fp: 0.8824 - tn: 425.4706 - fn: 0.8824 - accuracy: 0.9957 - auc: 0.9999 - prc: 0.9997 - precision: 0.9957 - recall: 0.9957 - val_loss: 0.1823 - val_tp: 61.0000 - val_fp: 2.0000 - val_tn: 187.0000 - val_fn: 2.0000 - val_accuracy: 0.9683 - val_auc: 0.9881 - val_prc: 0.9711 - val_precision: 0.9683 - val_recall: 0.9683 Epoch 16/100 16/16 [==============================] - 20s 1s/step - loss: 0.0052 - tp: 139.4118 - fp: 0.1176 - tn: 418.4706 - fn: 0.1176 - accuracy: 0.9995 - auc: 1.0000 - prc: 1.0000 - precision: 0.9995 - recall: 0.9995 - val_loss: 0.2354 - val_tp: 61.0000 - val_fp: 2.0000 - val_tn: 187.0000 - val_fn: 2.0000 - val_accuracy: 0.9683 - val_auc: 0.9872 - val_prc: 0.9689 - val_precision: 0.9683 - val_recall: 0.9683 Epoch 17/100 16/16 [==============================] - 20s 1s/step - loss: 0.0179 - tp: 139.5294 - fp: 1.6471 - tn: 421.8824 - fn: 1.6471 - accuracy: 0.9878 - auc: 0.9999 - prc: 0.9998 - precision: 0.9878 - recall: 0.9878 - val_loss: 0.2424 - val_tp: 61.0000 - val_fp: 2.0000 - val_tn: 187.0000 - val_fn: 2.0000 - val_accuracy: 0.9683 - val_auc: 0.9885 - val_prc: 0.9730 - val_precision: 0.9683 - val_recall: 0.9683 Epoch 18/100 16/16 [==============================] 
- 20s 1s/step - loss: 0.0339 - tp: 138.5882 - fp: 0.7059 - tn: 417.1765 - fn: 0.7059 - accuracy: 0.9956 - auc: 0.9969 - prc: 0.9925 - precision: 0.9956 - recall: 0.9956 - val_loss: 0.2728 - val_tp: 61.0000 - val_fp: 1.0000 - val_tn: 188.0000 - val_fn: 2.0000 - val_accuracy: 0.9841 - val_auc: 0.9886 - val_prc: 0.9731 - val_precision: 0.9839 - val_recall: 0.9683 Epoch 19/100 16/16 [==============================] - 20s 1s/step - loss: 0.0210 - tp: 139.5882 - fp: 0.8824 - tn: 420.5294 - fn: 0.8824 - accuracy: 0.9957 - auc: 0.9999 - prc: 0.9997 - precision: 0.9957 - recall: 0.9957 - val_loss: 0.2520 - val_tp: 62.0000 - val_fp: 1.0000 - val_tn: 188.0000 - val_fn: 1.0000 - val_accuracy: 0.9841 - val_auc: 0.9888 - val_prc: 0.9739 - val_precision: 0.9841 - val_recall: 0.9841 Epoch 20/100 16/16 [==============================] - 20s 1s/step - loss: 0.0099 - tp: 141.8824 - fp: 0.7059 - tn: 427.0588 - fn: 0.7059 - accuracy: 0.9957 - auc: 1.0000 - prc: 0.9999 - precision: 0.9957 - recall: 0.9957 - val_loss: 0.2793 - val_tp: 60.0000 - val_fp: 3.0000 - val_tn: 186.0000 - val_fn: 3.0000 - val_accuracy: 0.9524 - val_auc: 0.9881 - val_prc: 0.9715 - val_precision: 0.9524 - val_recall: 0.9524 Epoch 21/100 16/16 [==============================] - 20s 1s/step - loss: 0.0076 - tp: 138.5882 - fp: 0.4706 - tn: 416.7059 - fn: 0.4706 - accuracy: 0.9980 - auc: 1.0000 - prc: 1.0000 - precision: 0.9980 - recall: 0.9980 - val_loss: 0.1892 - val_tp: 62.0000 - val_fp: 1.0000 - val_tn: 188.0000 - val_fn: 1.0000 - val_accuracy: 0.9841 - val_auc: 0.9887 - val_prc: 0.9735 - val_precision: 0.9841 - val_recall: 0.9841 Epoch 22/100 16/16 [==============================] - 20s 1s/step - loss: 0.0120 - tp: 140.0588 - fp: 0.6471 - tn: 421.4706 - fn: 0.6471 - accuracy: 0.9963 - auc: 1.0000 - prc: 0.9999 - precision: 0.9963 - recall: 0.9963 - val_loss: 0.4701 - val_tp: 59.0000 - val_fp: 4.0000 - val_tn: 185.0000 - val_fn: 4.0000 - val_accuracy: 0.9365 - val_auc: 0.9775 - val_prc: 0.9475 - val_precision: 0.9365 - val_recall: 0.9365 Epoch 23/100 16/16 [==============================] - 20s 1s/step - loss: 0.0021 - tp: 142.1176 - fp: 0.0000e+00 - tn: 426.3529 - fn: 0.0000e+00 - accuracy: 1.0000 - auc: 1.0000 - prc: 1.0000 - precision: 1.0000 - recall: 1.0000 - val_loss: 0.3293 - val_tp: 60.0000 - val_fp: 3.0000 - val_tn: 186.0000 - val_fn: 3.0000 - val_accuracy: 0.9524 - val_auc: 0.9880 - val_prc: 0.9716 - val_precision: 0.9524 - val_recall: 0.9524 CPU times: user 14min 36s, sys: 28.5 s, total: 15min 4s Wall time: 8min 29s
plot_history(history.history, soft=True)
This is exciting and above expectations, and it shows how powerful transfer learning can be! The model is no longer merely predicting the largest class all the time: it achieves both precision and recall of around 95% on the validation set, an absolute improvement of roughly 12 percentage points over the 82.5% majority-class baseline.
# Make predictions on the validation set (which is not ideal!) with the best model
best_model = keras.models.load_model(model_dir+'Xception_1.hdf5') # Load the best performing model
predictions = best_model.predict(validation_generator)
# Plot confusion matrix
plot_cm(validation_generator.classes,
        np.argmax(predictions, axis=1))
Move on to the last two steps:
1. Unfreeze the last six layers, i.e. those belonging to "block14" (which layers to unfreeze can later be re-adjusted).
2. Jointly train these unfrozen layers and the classifier, using a very low learning rate.
# Start by setting all layers as trainable
conv_base.trainable = True
# Then unfreeze only the layers from "block14" onward; everything before them stays frozen
set_trainable = False
for layer in conv_base.layers:
    if 'block14_sep' in layer.name:
        set_trainable = True
    if set_trainable:
        layer.trainable = True
    else:
        layer.trainable = False
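If the cut-off is later re-adjusted, it helps to first inspect the layer names near the top of the convolutional base; a quick sketch using only the conv_base defined above:
# Print the last ten layers of the base together with their trainable flags
for layer in conv_base.layers[-10:]:
    print(f"{layer.name:30s} trainable={layer.trainable}")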
# Check what the model looks like now
model.summary()
Model: "Xception_1" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= input (InputLayer) [(None, 800, 380, 3)] 0 _________________________________________________________________ xception (Functional) (None, None, None, 2048) 20861480 _________________________________________________________________ flatten_1 (Flatten) (None, 614400) 0 _________________________________________________________________ dense_8 (Dense) (None, 128) 78643328 _________________________________________________________________ output (Dense) (None, 4) 516 ================================================================= Total params: 99,505,324 Trainable params: 83,392,644 Non-trainable params: 16,112,680 _________________________________________________________________
%%time
# Compile the model
model.compile(loss='categorical_crossentropy',
              optimizer=keras.optimizers.RMSprop(learning_rate=2e-6), # Even lower learning rate for fine-tuning
              metrics=METRICS)
# Add callbacks for storing the best model
callbacks = [
    tf.keras.callbacks.ModelCheckpoint(
        filepath=model_dir+'Xception_2_fine_tuned.hdf5',
        monitor='val_accuracy',
        save_best_only=True,
        mode='max'
    )
]
# Train the model using the same train and validation generators as before
history = model.fit(train_generator,
                    epochs=30,
                    callbacks=callbacks,
                    validation_data=validation_generator,
                    workers=2)
Epoch 1/30 16/16 [==============================] - 27s 1s/step - loss: 0.9652 - tp: 181.9412 - fp: 22.9412 - tn: 591.7059 - fn: 22.9412 - accuracy: 0.8972 - auc: 0.9560 - prc: 0.8999 - precision: 0.8972 - recall: 0.8972 - val_loss: 0.2728 - val_tp: 56.0000 - val_fp: 7.0000 - val_tn: 182.0000 - val_fn: 7.0000 - val_accuracy: 0.8889 - val_auc: 0.9898 - val_prc: 0.9722 - val_precision: 0.8889 - val_recall: 0.8889 Epoch 2/30 16/16 [==============================] - 22s 1s/step - loss: 0.1363 - tp: 129.4706 - fp: 8.9412 - tn: 408.9412 - fn: 9.8235 - accuracy: 0.9386 - auc: 0.9970 - prc: 0.9915 - precision: 0.9457 - recall: 0.9386 - val_loss: 0.2629 - val_tp: 56.0000 - val_fp: 7.0000 - val_tn: 182.0000 - val_fn: 7.0000 - val_accuracy: 0.8889 - val_auc: 0.9837 - val_prc: 0.9592 - val_precision: 0.8889 - val_recall: 0.8889 Epoch 3/30 16/16 [==============================] - 22s 1s/step - loss: 0.0296 - tp: 137.0588 - fp: 1.3529 - tn: 415.8235 - fn: 2.0000 - accuracy: 0.9934 - auc: 0.9999 - prc: 0.9997 - precision: 0.9934 - recall: 0.9896 - val_loss: 0.2572 - val_tp: 61.0000 - val_fp: 2.0000 - val_tn: 187.0000 - val_fn: 2.0000 - val_accuracy: 0.9683 - val_auc: 0.9878 - val_prc: 0.9707 - val_precision: 0.9683 - val_recall: 0.9683 Epoch 4/30 16/16 [==============================] - 22s 1s/step - loss: 0.0702 - tp: 136.1176 - fp: 5.0588 - tn: 418.4706 - fn: 5.0588 - accuracy: 0.9557 - auc: 0.9990 - prc: 0.9971 - precision: 0.9557 - recall: 0.9557 - val_loss: 0.3951 - val_tp: 59.0000 - val_fp: 4.0000 - val_tn: 185.0000 - val_fn: 4.0000 - val_accuracy: 0.9365 - val_auc: 0.9872 - val_prc: 0.9691 - val_precision: 0.9365 - val_recall: 0.9365 Epoch 5/30 16/16 [==============================] - 21s 1s/step - loss: 0.0535 - tp: 138.1765 - fp: 3.0000 - tn: 420.5294 - fn: 3.0000 - accuracy: 0.9832 - auc: 0.9971 - prc: 0.9959 - precision: 0.9832 - recall: 0.9832 - val_loss: 0.5477 - val_tp: 59.0000 - val_fp: 4.0000 - val_tn: 185.0000 - val_fn: 4.0000 - val_accuracy: 0.9365 - val_auc: 0.9664 - val_prc: 0.9225 - val_precision: 0.9365 - val_recall: 0.9365 Epoch 6/30 16/16 [==============================] - 22s 1s/step - loss: 0.0785 - tp: 138.2941 - fp: 2.8824 - tn: 420.6471 - fn: 2.8824 - accuracy: 0.9795 - auc: 0.9981 - prc: 0.9946 - precision: 0.9795 - recall: 0.9795 - val_loss: 0.7641 - val_tp: 57.0000 - val_fp: 5.0000 - val_tn: 184.0000 - val_fn: 6.0000 - val_accuracy: 0.9048 - val_auc: 0.9558 - val_prc: 0.8995 - val_precision: 0.9194 - val_recall: 0.9048 Epoch 7/30 16/16 [==============================] - 21s 1s/step - loss: 0.0868 - tp: 137.5882 - fp: 1.9412 - tn: 416.6471 - fn: 1.9412 - accuracy: 0.9865 - auc: 0.9951 - prc: 0.9876 - precision: 0.9865 - recall: 0.9865 - val_loss: 0.7331 - val_tp: 57.0000 - val_fp: 5.0000 - val_tn: 184.0000 - val_fn: 6.0000 - val_accuracy: 0.9048 - val_auc: 0.9560 - val_prc: 0.9003 - val_precision: 0.9194 - val_recall: 0.9048 Epoch 8/30 16/16 [==============================] - 22s 1s/step - loss: 0.0367 - tp: 138.0000 - fp: 2.9412 - tn: 419.8824 - fn: 2.9412 - accuracy: 0.9835 - auc: 0.9998 - prc: 0.9993 - precision: 0.9835 - recall: 0.9835 - val_loss: 0.6032 - val_tp: 57.0000 - val_fp: 5.0000 - val_tn: 184.0000 - val_fn: 6.0000 - val_accuracy: 0.9048 - val_auc: 0.9660 - val_prc: 0.9218 - val_precision: 0.9194 - val_recall: 0.9048 Epoch 9/30 16/16 [==============================] - 21s 1s/step - loss: 0.0606 - tp: 139.1765 - fp: 2.4706 - tn: 422.4706 - fn: 2.4706 - accuracy: 0.9825 - auc: 0.9994 - prc: 0.9982 - precision: 0.9825 - recall: 0.9825 - val_loss: 0.6552 - val_tp: 
57.0000 - val_fp: 5.0000 - val_tn: 184.0000 - val_fn: 6.0000 - val_accuracy: 0.9048 - val_auc: 0.9561 - val_prc: 0.9005 - val_precision: 0.9194 - val_recall: 0.9048 Epoch 10/30 16/16 [==============================] - 20s 1s/step - loss: 0.0390 - tp: 139.3529 - fp: 2.0588 - tn: 422.1765 - fn: 2.0588 - accuracy: 0.9856 - auc: 0.9997 - prc: 0.9992 - precision: 0.9856 - recall: 0.9856 - val_loss: 0.3850 - val_tp: 58.0000 - val_fp: 5.0000 - val_tn: 184.0000 - val_fn: 5.0000 - val_accuracy: 0.9206 - val_auc: 0.9771 - val_prc: 0.9466 - val_precision: 0.9206 - val_recall: 0.9206 Epoch 11/30 16/16 [==============================] - 22s 1s/step - loss: 0.0346 - tp: 137.4118 - fp: 1.1765 - tn: 416.7059 - fn: 1.8824 - accuracy: 0.9933 - auc: 0.9998 - prc: 0.9995 - precision: 0.9933 - recall: 0.9889 - val_loss: 0.4916 - val_tp: 58.0000 - val_fp: 4.0000 - val_tn: 185.0000 - val_fn: 5.0000 - val_accuracy: 0.9365 - val_auc: 0.9666 - val_prc: 0.9234 - val_precision: 0.9355 - val_recall: 0.9206 Epoch 12/30 16/16 [==============================] - 21s 1s/step - loss: 0.0392 - tp: 139.4706 - fp: 1.2353 - tn: 420.8824 - fn: 1.2353 - accuracy: 0.9921 - auc: 0.9997 - prc: 0.9992 - precision: 0.9921 - recall: 0.9921 - val_loss: 0.4651 - val_tp: 59.0000 - val_fp: 4.0000 - val_tn: 185.0000 - val_fn: 4.0000 - val_accuracy: 0.9365 - val_auc: 0.9666 - val_prc: 0.9234 - val_precision: 0.9365 - val_recall: 0.9365 Epoch 13/30 16/16 [==============================] - 21s 1s/step - loss: 0.0235 - tp: 139.0000 - fp: 0.2941 - tn: 417.5882 - fn: 0.2941 - accuracy: 0.9987 - auc: 0.9999 - prc: 0.9998 - precision: 0.9987 - recall: 0.9987 - val_loss: 0.3223 - val_tp: 58.0000 - val_fp: 5.0000 - val_tn: 184.0000 - val_fn: 5.0000 - val_accuracy: 0.9206 - val_auc: 0.9766 - val_prc: 0.9450 - val_precision: 0.9206 - val_recall: 0.9206 Epoch 14/30 16/16 [==============================] - 21s 1s/step - loss: 0.0352 - tp: 139.8824 - fp: 1.1176 - tn: 424.5294 - fn: 2.0000 - accuracy: 0.9862 - auc: 0.9997 - prc: 0.9992 - precision: 0.9934 - recall: 0.9862 - val_loss: 0.2791 - val_tp: 59.0000 - val_fp: 4.0000 - val_tn: 185.0000 - val_fn: 4.0000 - val_accuracy: 0.9365 - val_auc: 0.9866 - val_prc: 0.9674 - val_precision: 0.9365 - val_recall: 0.9365 Epoch 15/30 16/16 [==============================] - 20s 1s/step - loss: 0.0215 - tp: 140.0000 - fp: 0.9412 - tn: 421.8824 - fn: 0.9412 - accuracy: 0.9954 - auc: 0.9998 - prc: 0.9994 - precision: 0.9954 - recall: 0.9954 - val_loss: 0.2560 - val_tp: 59.0000 - val_fp: 4.0000 - val_tn: 185.0000 - val_fn: 4.0000 - val_accuracy: 0.9365 - val_auc: 0.9864 - val_prc: 0.9670 - val_precision: 0.9365 - val_recall: 0.9365 Epoch 16/30 16/16 [==============================] - 21s 1s/step - loss: 0.0412 - tp: 138.7647 - fp: 2.1765 - tn: 420.6471 - fn: 2.1765 - accuracy: 0.9773 - auc: 0.9996 - prc: 0.9990 - precision: 0.9773 - recall: 0.9773 - val_loss: 0.2965 - val_tp: 59.0000 - val_fp: 4.0000 - val_tn: 185.0000 - val_fn: 4.0000 - val_accuracy: 0.9365 - val_auc: 0.9864 - val_prc: 0.9668 - val_precision: 0.9365 - val_recall: 0.9365 Epoch 17/30 16/16 [==============================] - 21s 1s/step - loss: 0.0501 - tp: 139.8235 - fp: 2.2941 - tn: 424.0588 - fn: 2.2941 - accuracy: 0.9843 - auc: 0.9995 - prc: 0.9984 - precision: 0.9843 - recall: 0.9843 - val_loss: 0.1670 - val_tp: 60.0000 - val_fp: 3.0000 - val_tn: 186.0000 - val_fn: 3.0000 - val_accuracy: 0.9524 - val_auc: 0.9966 - val_prc: 0.9904 - val_precision: 0.9524 - val_recall: 0.9524 Epoch 18/30 16/16 [==============================] - 21s 1s/step - loss: 
0.0419 - tp: 139.7059 - fp: 1.2353 - tn: 421.5882 - fn: 1.2353 - accuracy: 0.9930 - auc: 0.9994 - prc: 0.9983 - precision: 0.9930 - recall: 0.9930 - val_loss: 0.1637 - val_tp: 60.0000 - val_fp: 3.0000 - val_tn: 186.0000 - val_fn: 3.0000 - val_accuracy: 0.9524 - val_auc: 0.9967 - val_prc: 0.9907 - val_precision: 0.9524 - val_recall: 0.9524 Epoch 19/30 16/16 [==============================] - 21s 1s/step - loss: 0.0437 - tp: 139.4118 - fp: 1.2941 - tn: 420.8235 - fn: 1.2941 - accuracy: 0.9922 - auc: 0.9991 - prc: 0.9975 - precision: 0.9922 - recall: 0.9922 - val_loss: 0.2007 - val_tp: 59.0000 - val_fp: 4.0000 - val_tn: 185.0000 - val_fn: 4.0000 - val_accuracy: 0.9365 - val_auc: 0.9961 - val_prc: 0.9891 - val_precision: 0.9365 - val_recall: 0.9365 Epoch 20/30 16/16 [==============================] - 20s 1s/step - loss: 0.0122 - tp: 140.4118 - fp: 0.7647 - tn: 422.7647 - fn: 0.7647 - accuracy: 0.9958 - auc: 1.0000 - prc: 1.0000 - precision: 0.9958 - recall: 0.9958 - val_loss: 0.1792 - val_tp: 60.0000 - val_fp: 3.0000 - val_tn: 186.0000 - val_fn: 3.0000 - val_accuracy: 0.9524 - val_auc: 0.9962 - val_prc: 0.9896 - val_precision: 0.9524 - val_recall: 0.9524 Epoch 21/30 16/16 [==============================] - 21s 1s/step - loss: 0.0228 - tp: 141.2353 - fp: 1.1176 - tn: 425.9412 - fn: 1.1176 - accuracy: 0.9934 - auc: 0.9999 - prc: 0.9997 - precision: 0.9934 - recall: 0.9934 - val_loss: 0.1731 - val_tp: 61.0000 - val_fp: 2.0000 - val_tn: 187.0000 - val_fn: 2.0000 - val_accuracy: 0.9683 - val_auc: 0.9962 - val_prc: 0.9894 - val_precision: 0.9683 - val_recall: 0.9683 Epoch 22/30 16/16 [==============================] - 20s 1s/step - loss: 0.0069 - tp: 141.2941 - fp: 0.1176 - tn: 424.1176 - fn: 0.1176 - accuracy: 0.9995 - auc: 1.0000 - prc: 1.0000 - precision: 0.9995 - recall: 0.9995 - val_loss: 0.1180 - val_tp: 60.0000 - val_fp: 3.0000 - val_tn: 186.0000 - val_fn: 3.0000 - val_accuracy: 0.9524 - val_auc: 0.9974 - val_prc: 0.9925 - val_precision: 0.9524 - val_recall: 0.9524 Epoch 23/30 16/16 [==============================] - 20s 1s/step - loss: 0.0560 - tp: 136.2353 - fp: 2.8824 - tn: 417.1176 - fn: 3.7647 - accuracy: 0.9720 - auc: 0.9992 - prc: 0.9977 - precision: 0.9770 - recall: 0.9720 - val_loss: 0.1450 - val_tp: 60.0000 - val_fp: 2.0000 - val_tn: 187.0000 - val_fn: 3.0000 - val_accuracy: 0.9683 - val_auc: 0.9970 - val_prc: 0.9915 - val_precision: 0.9677 - val_recall: 0.9524 Epoch 24/30 16/16 [==============================] - 21s 1s/step - loss: 0.0390 - tp: 138.2941 - fp: 1.0000 - tn: 416.8824 - fn: 1.0000 - accuracy: 0.9867 - auc: 0.9998 - prc: 0.9993 - precision: 0.9867 - recall: 0.9867 - val_loss: 0.1374 - val_tp: 60.0000 - val_fp: 3.0000 - val_tn: 186.0000 - val_fn: 3.0000 - val_accuracy: 0.9524 - val_auc: 0.9971 - val_prc: 0.9917 - val_precision: 0.9524 - val_recall: 0.9524 Epoch 25/30 16/16 [==============================] - 21s 1s/step - loss: 0.0096 - tp: 140.0588 - fp: 0.2941 - tn: 421.1176 - fn: 0.4118 - accuracy: 0.9983 - auc: 1.0000 - prc: 1.0000 - precision: 0.9987 - recall: 0.9983 - val_loss: 0.1528 - val_tp: 60.0000 - val_fp: 3.0000 - val_tn: 186.0000 - val_fn: 3.0000 - val_accuracy: 0.9524 - val_auc: 0.9968 - val_prc: 0.9909 - val_precision: 0.9524 - val_recall: 0.9524 Epoch 26/30 16/16 [==============================] - 22s 1s/step - loss: 0.0330 - tp: 141.5882 - fp: 1.0000 - tn: 426.7647 - fn: 1.0000 - accuracy: 0.9873 - auc: 0.9998 - prc: 0.9994 - precision: 0.9873 - recall: 0.9873 - val_loss: 0.1513 - val_tp: 61.0000 - val_fp: 2.0000 - val_tn: 187.0000 - val_fn: 2.0000 - 
val_accuracy: 0.9683 - val_auc: 0.9968 - val_prc: 0.9909 - val_precision: 0.9683 - val_recall: 0.9683 Epoch 27/30 16/16 [==============================] - 22s 1s/step - loss: 0.0073 - tp: 138.9412 - fp: 0.3529 - tn: 417.5294 - fn: 0.3529 - accuracy: 0.9984 - auc: 1.0000 - prc: 1.0000 - precision: 0.9984 - recall: 0.9984 - val_loss: 0.1180 - val_tp: 60.0000 - val_fp: 3.0000 - val_tn: 186.0000 - val_fn: 3.0000 - val_accuracy: 0.9524 - val_auc: 0.9979 - val_prc: 0.9938 - val_precision: 0.9524 - val_recall: 0.9524 Epoch 28/30 16/16 [==============================] - 22s 1s/step - loss: 0.0091 - tp: 141.1765 - fp: 0.0000e+00 - tn: 423.5294 - fn: 0.0000e+00 - accuracy: 1.0000 - auc: 1.0000 - prc: 1.0000 - precision: 1.0000 - recall: 1.0000 - val_loss: 0.1224 - val_tp: 60.0000 - val_fp: 3.0000 - val_tn: 186.0000 - val_fn: 3.0000 - val_accuracy: 0.9524 - val_auc: 0.9976 - val_prc: 0.9929 - val_precision: 0.9524 - val_recall: 0.9524 Epoch 29/30 16/16 [==============================] - 20s 1s/step - loss: 0.0216 - tp: 140.4118 - fp: 1.4706 - tn: 424.1765 - fn: 1.4706 - accuracy: 0.9904 - auc: 1.0000 - prc: 0.9999 - precision: 0.9904 - recall: 0.9904 - val_loss: 0.1104 - val_tp: 60.0000 - val_fp: 2.0000 - val_tn: 187.0000 - val_fn: 3.0000 - val_accuracy: 0.9524 - val_auc: 0.9974 - val_prc: 0.9924 - val_precision: 0.9677 - val_recall: 0.9524 Epoch 30/30 16/16 [==============================] - 21s 1s/step - loss: 0.0161 - tp: 140.5294 - fp: 1.3529 - tn: 424.2941 - fn: 1.3529 - accuracy: 0.9918 - auc: 1.0000 - prc: 0.9999 - precision: 0.9918 - recall: 0.9918 - val_loss: 0.1398 - val_tp: 61.0000 - val_fp: 2.0000 - val_tn: 187.0000 - val_fn: 2.0000 - val_accuracy: 0.9683 - val_auc: 0.9971 - val_prc: 0.9917 - val_precision: 0.9683 - val_recall: 0.9683 CPU times: user 19min 21s, sys: 36.2 s, total: 19min 58s Wall time: 10min 54s
plot_history(history.history, soft=True)
# Make predictions with the best model
best_model = keras.models.load_model(model_dir+'Xception_2_fine_tuned.hdf5') # Load the best performing model
predictions = best_model.predict(validation_generator)
# Plot confusion matrix
plot_cm(validation_generator.classes,
        np.argmax(predictions, axis=1))
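For a per-class breakdown beyond the confusion matrix, scikit-learn's classification_report can be used. A minimal sketch, assuming the validation generator yields images in the order of its classes attribute (i.e. it was created with shuffle=False):
from sklearn.metrics import classification_report

# Per-class precision, recall and F1 on the validation set
print(classification_report(validation_generator.classes,
                            np.argmax(predictions, axis=1),
                            digits=3))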
Fine-tuning the pre-trained Xception model boosts both precision and recall to around 97%. Great!
This initial experiment indicates that it is possible to learn to classify the tread depth of tyres by training a deep learning model on single camera input images. Although the more basic architectures trained from scratch failed to yield useful results, transfer learning with an Xception model pre-trained on ImageNet suggests that it can be done, despite a very small dataset of 315 images in total with significant class imbalance. It's also very important to note that the performance has been calculated on the validation set, as the dataset was too small to be split into three parts: train, validation and test sets.
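With more data, a proper three-way split would be straightforward. A minimal sketch, assuming the same directory layout, using scikit-learn's train_test_split on the file paths and flow_from_dataframe (the 70/15/15 fractions are illustrative):
import glob
from sklearn.model_selection import train_test_split

# Collect (filepath, class) pairs from the class subfolders
files = glob.glob(data_dir + '*/*')
df = pd.DataFrame({'filename': files,
                   'class': [f.split('/')[-2] for f in files]})

# Stratified 70/15/15 split into train, validation and test sets
train_df, rest_df = train_test_split(df, test_size=0.3, stratify=df['class'], random_state=9119)
val_df, test_df = train_test_split(rest_df, test_size=0.5, stratify=rest_df['class'], random_state=9119)

# Plain generator for the held-out test set (no augmentation, no shuffling)
test_gen = tf.keras.preprocessing.image.ImageDataGenerator(rescale=1./255)
test_generator = test_gen.flow_from_dataframe(test_df,
                                              target_size=image_size,
                                              batch_size=batch_size,
                                              class_mode='categorical',
                                              shuffle=False)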
The first thing to address moving forward would be to collect more data, which would allow us to assess the model's performance with greater confidence. Another consideration is whether taking a photo of sufficient quality of a tyre is actually easier than using the measurement methods currently available on the market.