Transfer learning in TensorFlow 2 tutorial

In this post, I’m going to cover an important deep learning concept called transfer learning. Transfer learning is the process of using neural network models trained in a related domain to accelerate the development of accurate models in a more specific domain of interest. For instance, a deep learning practitioner can use one of the state-of-the-art image classification models, already trained, as a starting point for their own, more specialized, image classification task. In this tutorial, I’ll be showing you how to perform transfer learning using an advanced, pre-trained image classification model, ResNet50, to improve a more specific image classification task: the cats vs dogs classification problem. In particular, I’ll be showing you how to do this using TensorFlow 2. The code for this tutorial, in a Google Colaboratory notebook format, can be found on this site’s GitHub repository. This code borrows some components from the official TensorFlow tutorial.


What are the benefits of transfer learning?

Transfer learning has several benefits:

  1. It speeds up learning: For state of the art results in deep learning, one often needs to build very deep networks with many layers. In order to train such networks, one needs lots of data, computational power and time. These three things are often not readily available. Starting from a pre-trained model sidesteps most of this cost, because the bulk of the network has already been trained.
  2. It needs less data: As will be shown, transfer learning usually only adds a few extra layers to the pre-trained model, and the weights in the pre-trained model are generally fixed. Therefore, during the fine tuning of the model, only those few extra layers, or a small subset of the total number of layers, are subjected to training. This requires much less data to get good results.
  3. You can leverage the expert tuning of state-of-the-art models: As anyone who has been involved in building deep learning systems can tell you, it requires a lot of patience and tuning of the models to get the best results. By utilizing pre-trained, state-of-the-art models, you can skip a lot of this arduous work and rely on the efforts of experts in the field.

For these reasons, if you are performing some image recognition task, it may be worth using some of the pre-trained, state-of-the-art image classification models, like ResNet, DenseNet, InceptionNet and so on. How does one use these pre-trained models?

How to create a transfer learning model

To create a transfer learning model, all that is required is to take the pre-trained layers and “bolt on” your own network. This could be either at the beginning or end of the pre-trained model. Usually, one freezes the pre-trained layer weights and only trains the “bolted on” layers which have been added. For image classification transfer learning, one usually takes the convolutional neural network (CNN) layers from the pre-trained model and adds one or more densely connected “classification” layers at the end (for more on convolutional neural networks, see this tutorial). The pre-trained CNN layers act as feature extractors, and the classification layer/s at the end can be “taught” to “interpret” these image features. The transfer learning model architecture that will be used in this example is shown below:


Figure: ResNet50 transfer learning architecture

The full ResNet50 model shown in the image above contains, in addition to a Global Average Pooling (GAP) layer, a 1000 node dense / fully connected layer which acts as a “classifier” of the 2048 feature maps output from the ResNet CNN layers (each 4 x 4 for the 100 x 100 input images used in this tutorial). For more on Global Average Pooling, see my tutorial. In this transfer learning task, we’ll be removing these last two layers (GAP and dense layer) and replacing them with our own GAP and dense layer. Because this example is a binary classification task, the output size is only 1. The GAP layer has no trainable parameters, but the dense layer does, and these will be the only parameters trained in this example. All of this is performed quite easily in TensorFlow 2, as will be shown in the next section.
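To make the replacement head concrete, here is a minimal NumPy sketch of what a GAP layer followed by a single sigmoid unit computes. The random arrays are stand-ins for real ResNet50 features and trained weights, and the function name classification_head is my own, not part of any library:

```python
import numpy as np

def classification_head(feature_maps, weights, bias):
    """Global average pooling followed by one sigmoid output unit.

    feature_maps: array of shape (batch, height, width, channels)
    weights:      array of shape (channels,)
    bias:         scalar
    """
    # GAP: average each feature map over its spatial dimensions
    pooled = feature_maps.mean(axis=(1, 2))       # (batch, channels)
    # Dense layer with a single output node
    logits = pooled @ weights + bias              # (batch,)
    # Sigmoid squashes the output into a probability
    return 1.0 / (1.0 + np.exp(-logits))

rng = np.random.default_rng(0)
features = rng.normal(size=(2, 4, 4, 2048))       # stand-in for the ResNet50 output
weights = rng.normal(scale=0.01, size=2048)       # stand-in for trained dense weights
bias = 0.0

probs = classification_head(features, weights, bias)
n_trainable = weights.size + 1                    # 2048 weights + 1 bias
```

Only weights and bias would be learned here; the pooling step itself has nothing to train.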

Transfer learning in TensorFlow 2

In this example, we’ll be using the pre-trained ResNet50 model and transfer learning to perform the cats vs dogs image classification task. I’ll also train a smaller CNN from scratch to show the benefits of transfer learning. To access the image dataset, we’ll be using the tensorflow_datasets package which contains a number of common machine learning datasets. To load the data, the following commands can be run:

import tensorflow as tf
from tensorflow.keras import layers
import tensorflow_datasets as tfds

# 80% / 10% / 10% (training, validation, test) slices of the dataset's single 'train' split
splits = ['train[:80%]', 'train[80%:90%]', 'train[90%:]']

(cat_train, cat_valid, cat_test), info = tfds.load('cats_vs_dogs', split=splits, with_info=True, as_supervised=True)

A few things to note about the code snippet above. First, the splits list divides the dataset into (training, validation, test) portions of 80%, 10% and 10%, using the tensorflow_datasets slicing syntax. (Older versions of this code built the same split with tfds.Split.TRAIN.subsplit(weighted=(80, 10, 10)), which is not available in recent tensorflow_datasets releases.) The list is passed to the tfds.load() function, which tells the dataset loader how to break up the data. The first argument of tfds.load() is a string specifying the dataset name to load. The following arguments specify the splits to return, whether to return an object with information about the dataset (info) and whether the dataset is intended to be used in a supervised learning problem, with labels being included. The variables cat_train, cat_valid and cat_test are TensorFlow Dataset objects. To learn more about these, check out my previous post. In order to examine the images in the data set, the following code can be run:

import matplotlib.pyplot as plt

# Display the first two training images, each in its own figure
for image, label in cat_train.take(2):
  plt.figure()
  plt.imshow(image)

Running this displays two sample images from the training set. As can be observed, the images are of varying sizes. This will need to be rectified so that the images have a consistent size to feed into our model. As usual, the image pixel values (which range from 0 to 255) need to be normalized, in this case to between 0 and 1. The function below performs these tasks:

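IMAGE_SIZE = 100

def pre_process_image(image, label):
  # Cast to float32 and scale pixel values from [0, 255] to [0, 1]
  image = tf.cast(image, tf.float32)
  image = image / 255.0
  # Resize every image to IMAGE_SIZE x IMAGE_SIZE
  image = tf.image.resize(image, (IMAGE_SIZE, IMAGE_SIZE))
  return image, label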

In this example, we’ll be resizing the images to 100 x 100 using tf.image.resize. To get state of the art levels of accuracy, you would probably want a larger image size, say 200 x 200, but in this case I’ve chosen speed over accuracy for demonstration purposes. As can be observed, the image values are also cast into the tf.float32 datatype and normalized by dividing by 255. Next we apply this function to the datasets, and also shuffle and batch where appropriate:

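# SHUFFLE_BUFFER and VALID_BATCH_SIZE are placeholder names; choose values for them and for TRAIN_BATCH_SIZE
cat_train = cat_train.map(pre_process_image).shuffle(SHUFFLE_BUFFER).repeat().batch(TRAIN_BATCH_SIZE)
cat_valid = cat_valid.map(pre_process_image).repeat().batch(VALID_BATCH_SIZE)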
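As a rough illustration of such an input pipeline, the sketch below maps, shuffles, repeats and batches a small synthetic dataset, so it runs without downloading anything. The batch size of 64 and shuffle buffer of 1000 are assumed values rather than figures from this tutorial, and make_fake_dataset and make_train_pipeline are hypothetical helper names:

```python
import numpy as np
import tensorflow as tf

IMAGE_SIZE = 100         # from the text: images are resized to 100 x 100
TRAIN_BATCH_SIZE = 64    # assumed value, not given in the text
SHUFFLE_BUFFER = 1000    # assumed value, not given in the text

def pre_process_image(image, label):
    # Scale pixels to [0, 1] and resize to a fixed square size
    image = tf.cast(image, tf.float32) / 255.0
    image = tf.image.resize(image, (IMAGE_SIZE, IMAGE_SIZE))
    return image, label

def make_fake_dataset(num_images=8):
    """Stand-in for cat_train: uint8 images of varying size with 0/1 labels."""
    rng = np.random.default_rng(0)

    def gen():
        for i in range(num_images):
            h, w = rng.integers(120, 200, size=2)
            yield rng.integers(0, 256, size=(h, w, 3), dtype=np.uint8), i % 2

    return tf.data.Dataset.from_generator(
        gen,
        output_signature=(
            tf.TensorSpec(shape=(None, None, 3), dtype=tf.uint8),
            tf.TensorSpec(shape=(), dtype=tf.int64),
        ),
    )

def make_train_pipeline(dataset, batch_size=TRAIN_BATCH_SIZE):
    # Pre-process each image, shuffle, repeat indefinitely, then batch
    return (dataset.map(pre_process_image)
                   .shuffle(SHUFFLE_BUFFER)
                   .repeat()
                   .batch(batch_size))
```

Because the dataset is repeated indefinitely, the number of batches per epoch has to be given explicitly when training, which is what the steps_per_epoch and validation_steps arguments are for.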

First, we’ll build a smaller CNN image classifier which will be trained from scratch.

A smaller CNN model

In the code below, a head of three convolutional layers, a GAP layer and a final densely connected output layer are created. The Keras API, which is the encouraged approach for TensorFlow 2, is used in the model definition. For more on Keras, see the linked tutorials.

head = tf.keras.Sequential()
head.add(layers.Conv2D(32, (3, 3), input_shape=(IMAGE_SIZE, IMAGE_SIZE, 3)))
head.add(layers.MaxPooling2D(pool_size=(2, 2)))

head.add(layers.Conv2D(32, (3, 3)))
head.add(layers.MaxPooling2D(pool_size=(2, 2)))

head.add(layers.Conv2D(64, (3, 3)))
head.add(layers.MaxPooling2D(pool_size=(2, 2)))

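# GAP layer followed by the sigmoid output layer, as described above
average_pool = tf.keras.Sequential()
average_pool.add(layers.GlobalAveragePooling2D())
average_pool.add(layers.Dense(1, activation='sigmoid'))

# Chain the convolutional head and the pooling / output layers together
standard_model = tf.keras.Sequential([head, average_pool])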
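To get a feel for how small this network is, the following sketch works out the feature map size and parameter count by hand. It assumes only the layers listed above, the Keras defaults for Conv2D ('valid' padding, stride 1) and MaxPooling2D (stride equal to the pool size), and a GAP layer feeding a single output node:

```python
def conv_out(size, kernel=3):
    # 'valid' padding with stride 1 trims kernel - 1 pixels
    return size - kernel + 1

def pool_out(size, pool=2):
    # Non-overlapping pooling with 'valid' padding rounds down
    return size // pool

def conv_params(in_channels, out_channels, kernel=3):
    # One kernel x kernel filter per input/output channel pair, plus a bias per filter
    return kernel * kernel * in_channels * out_channels + out_channels

size, channels, total_params = 100, 3, 0
for filters in (32, 32, 64):
    total_params += conv_params(channels, filters)
    size = pool_out(conv_out(size))
    channels = filters

# After GAP there is one value per channel, feeding a single output node
total_params += channels + 1

print(size, channels, total_params)
```

Under these assumptions the head ends with 10 x 10 feature maps across 64 channels, and the whole model has 28,705 parameters, compared with the roughly 23.6 million in ResNet50.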

To train the model we run:


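standard_model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])  # 'adam' is an assumed optimizer
callbacks = [tf.keras.callbacks.TensorBoard(log_dir='./log/standard_model', update_freq='batch')]
standard_model.fit(cat_train, steps_per_epoch=23262//TRAIN_BATCH_SIZE, epochs=7,
                   validation_data=cat_valid, validation_steps=10, callbacks=callbacks)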

Note that the loss function is binary cross-entropy, because the cats vs dogs image classification task is a binary classification problem (i.e. 0 = cat, 1 = dog or vice-versa). Running the code above, after 7 epochs, gives a training accuracy of around 89% and a validation accuracy of around 85%. Next we’ll see how this compares to the transfer learning case.
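For reference, binary cross-entropy can be written out in a few lines of NumPy. This is only a sketch of the formula, not the Keras implementation; the eps clipping value is an assumption of mine to avoid taking log(0):

```python
import numpy as np

def binary_cross_entropy(y_true, y_pred, eps=1e-7):
    """Mean binary cross-entropy between 0/1 labels and predicted probabilities."""
    y_true = np.asarray(y_true, dtype=float)
    # Clip predictions away from exactly 0 and 1 so the logarithms stay finite
    y_pred = np.clip(np.asarray(y_pred, dtype=float), eps, 1.0 - eps)
    losses = -(y_true * np.log(y_pred) + (1.0 - y_true) * np.log(1.0 - y_pred))
    return float(np.mean(losses))

undecided = binary_cross_entropy([1, 0], [0.5, 0.5])   # model has no idea
confident = binary_cross_entropy([1, 0], [0.9, 0.1])   # model is mostly right
```

A model that always outputs 0.5 scores ln 2 (about 0.693), and the loss falls towards zero as the predicted probabilities approach the true labels.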

ResNet50 transfer learning example

To download the ResNet50 model with its trained parameters in Keras format, you can use the tf.keras.applications module. To do so, run the following code:

IMG_SHAPE = (IMAGE_SIZE, IMAGE_SIZE, 3)
res_net = tf.keras.applications.ResNet50(weights='imagenet', include_top=False, input_shape=IMG_SHAPE)

The weights argument ‘imagenet’ denotes that the weights to be used are those obtained by training on the ImageNet dataset. Setting the include_top argument to False states that we only want the convolutional, feature-extracting part of the ResNet50 model, not its final GAP and densely connected layers. Finally, we need to specify what input shape we want the model to accept. Next, we need to disable the training of the parameters within this Keras model. This is performed really easily:

res_net.trainable = False

Next we create a Global Average Pooling layer, along with a final densely connected output layer with sigmoid activation. Then the model is combined using the Keras sequential framework where Keras models can be chained together:

global_average_layer = layers.GlobalAveragePooling2D()
output_layer = layers.Dense(1, activation='sigmoid')
tl_model = tf.keras.Sequential([res_net, global_average_layer, output_layer])

That’s all that’s required – TensorFlow 2 and Keras make many deep learning tasks quite easy. Running tl_model.summary() gives the following output:

Layer (type)                 Output Shape              Param #   
resnet50 (Model)             (None, 4, 4, 2048)        23587712  
global_average_pooling2d (Gl (None, 2048)              0         
dense_1 (Dense)              (None, 1)                 2049      
Total params: 23,589,761
Trainable params: 2,049
Non-trainable params: 23,587,712

As can be observed, while the total number of parameters is large (roughly 23.6 million), the number of trainable parameters, corresponding to the 2,048 weights and single bias of the final output layer, is only 2,049.
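The trainable parameter count can also be checked programmatically. The sketch below wraps the model construction in a hypothetical build_transfer_model function and counts the parameters in the trainable and non-trainable weight lists. The weights argument is exposed so that weights=None can be passed to build the same architecture without downloading the ImageNet parameters:

```python
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers

def build_transfer_model(image_size=100, weights='imagenet'):
    """Frozen ResNet50 feature extractor with a GAP + sigmoid head."""
    base = tf.keras.applications.ResNet50(
        weights=weights, include_top=False,
        input_shape=(image_size, image_size, 3))
    base.trainable = False    # freeze every pre-trained layer
    return tf.keras.Sequential([
        base,
        layers.GlobalAveragePooling2D(),
        layers.Dense(1, activation='sigmoid'),
    ])

def count_params(variables):
    # Total number of scalar parameters across a list of weight variables
    return int(sum(np.prod(list(v.shape)) for v in variables))

if __name__ == '__main__':
    model = build_transfer_model()               # downloads the ImageNet weights
    model(tf.zeros((1, 100, 100, 3)))            # run once so all weights exist
    print(count_params(model.trainable_weights))
```

Passing weights=None gives randomly initialized features, so it is only useful for checking shapes and parameter counts, not for actual transfer learning.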

To train the model we run:


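tl_model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])  # 'adam' is an assumed optimizer
callbacks = [tf.keras.callbacks.TensorBoard(log_dir='./log/transfer_learning_model', update_freq='batch')]
tl_model.fit(cat_train, steps_per_epoch=23262//TRAIN_BATCH_SIZE, epochs=7,
             validation_data=cat_valid, validation_steps=10, callbacks=callbacks)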

Comparing the models

The graphs below from TensorBoard show the relative performance of the small CNN model trained from scratch and the ResNet50 transfer learning model:

Figure: Transfer learning training accuracy comparison (blue: ResNet50, pink: smaller CNN model)

Figure: Transfer learning validation accuracy comparison (red: ResNet50 model, green: smaller CNN model)

The results above show that the ResNet50 model reaches higher levels of both training and validation accuracy much more quickly than the smaller CNN model that was trained from scratch. This illustrates the benefit of using these powerful pre-trained models as a starting point for your more domain-specific deep learning tasks. I hope this post has been a help and given you a good understanding of the benefits of transfer learning, and also how to implement it easily in TensorFlow 2.


The post Transfer learning in TensorFlow 2 tutorial appeared first on Adventures in Machine Learning.
