Keras shoot-out, part 3: fine-tuning
Using Keras and the CIFAR-10 dataset, we previously compared the training performance of two Deep Learning libraries, Apache MXNet and TensorFlow.
In this article, we’ll continue to explore this theme. I’ll show you how to:
- save and load a trained model,
- build a subset of CIFAR-10 using samples from two classes,
- retrain the model on this subset to optimise prediction (aka “fine-tuning”),
- improve training time by freezing layers.
Preparing our pre-trained models
Using the same script (~/keras/examples/cifar10_resnet50.py), I trained a ResNet-50 model on CIFAR-10 using first MXNet 0.11, then TensorFlow 1.2. Here is the setup I used:
- All 8 GPUs on a p2.8xlarge AWS instance,
- batch size: 256,
- Data augmentation: enabled,
- Number of epochs: 200.
To save the trained model, I only added one line of code: model.save(modelname). That’s all there is to it!
Even with 8 GPUs, training takes time: about 2 hours for MXNet and about 3.5 hours for TensorFlow. Pretty heavy lifting! I uploaded the models to my own personal model zoo: feel free to grab them and run your own tests :)
- MXNet: model (HDF5, 91MB), training log (text, 6MB), test accuracy: 82.12%
- TensorFlow: model (HDF5, 181MB), training log (text, 6MB), test accuracy: 75.48%
Now let’s see how we can load our models.
Loading a model
Depending on the backend configured in ~/.keras/keras.json, we have to load one model or the other. Using keras.backend.backend(), it’s easy to figure out which one to pick. Then, we simply call keras.models.load_model().
What about setting the number of GPUs used for training? For TensorFlow, we’ll use session parameters. For MXNet, we’ll add an extra parameter to model.compile().
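In code, this could look something like the sketch below. The file names and the helper names (load_pretrained, configure_gpus) are mine, not the actual script’s:

```python
def load_pretrained():
    # Keras imports are deferred so this sketch stands alone.
    from keras import backend as K
    from keras.models import load_model
    # Pick the saved model matching the configured backend
    # (file names are illustrative, not the actual zoo file names).
    model_files = {'mxnet': 'cifar10-resnet50-mxnet.h5',
                   'tensorflow': 'cifar10-resnet50-tensorflow.h5'}
    return load_model(model_files[K.backend()])

def configure_gpus(num_gpus=1):
    # TensorFlow only: set the number of visible GPUs through session options.
    import tensorflow as tf
    from keras import backend as K
    config = tf.ConfigProto(device_count={'GPU': num_gpus})
    K.set_session(tf.Session(config=config))
```

For MXNet, there is no session to configure: the fork exposes the device list through an extra parameter to model.compile() instead (context in the keras-mxnet fork, if memory serves).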
Ok, now let’s take care of data.
Extracting the subset
Keras provides convenient functions to load commonly used data sets, including CIFAR-10. Nice!
First, we need to extract a given number of samples from two classes. This is the purpose of the get_samples() function:
- find the relevant number of sample indexes matching a given class,
- extract samples (‘x’) and labels (‘y’) for these indexes,
- normalise pixel values for samples, since we also did it when we trained the initial model.
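The steps above could be sketched as follows; the variable names are my own, not the exact code from the repo:

```python
import numpy as np

def get_samples(x, y, class_id, num_samples):
    # Find the first `num_samples` indexes whose label matches class_id
    indexes = np.where(y[:, 0] == class_id)[0][:num_samples]
    # Extract samples ('x') and labels ('y') for these indexes
    x_class = x[indexes].astype('float32')
    y_class = y[indexes]
    # Normalise pixel values to [0, 1], as in the original training
    x_class /= 255.0
    return x_class, y_class
```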
Building the new data set
The next step is to build the new training and test sets with the prepare_dataset() function:
- Concatenate samples for both categories,
- Concatenate labels for both categories,
- One-hot encode labels, e.g. convert ‘6’ to [0, 0, 0, 0, 0, 0, 1, 0, 0, 0].
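A possible sketch of prepare_dataset(), assuming the raw CIFAR-10 label layout (one integer per row); again, names are mine:

```python
import numpy as np

def to_one_hot(labels, num_classes=10):
    # e.g. convert 6 to [0, 0, 0, 0, 0, 0, 1, 0, 0, 0]
    one_hot = np.zeros((labels.shape[0], num_classes))
    one_hot[np.arange(labels.shape[0]), labels.flatten()] = 1
    return one_hot

def prepare_dataset(x1, y1, x2, y2, num_classes=10):
    # Concatenate samples and labels for both categories, then one-hot encode
    x = np.concatenate([x1, x2])
    y = to_one_hot(np.concatenate([y1, y2]), num_classes)
    return x, y
```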
The last step before actually retraining is to define the optimizer (for now, we’ll use the same one as for the original training) and the number of GPUs. We’ll stick to one, as this is the most likely setup for developers fine-tuning a model on their own machine.
Fine-tuning the model
OK, we’re now ready to train the model. First, let’s run predictions on the test set to establish a baseline. Then, we train the model. Finally, we predict the test set again to see how much the model has improved.
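The predict / train / predict sequence could be sketched like this. Hyperparameters are illustrative and the helper name is mine; the actual script may differ:

```python
def fine_tune(model, x_train, y_train, x_test, y_test, epochs=10):
    # Keras imports deferred so the sketch stands alone.
    from keras.optimizers import SGD
    # Same kind of optimizer as the original training run (values illustrative)
    model.compile(loss='categorical_crossentropy',
                  optimizer=SGD(lr=0.01, momentum=0.9, nesterov=True),
                  metrics=['accuracy'])
    # Baseline accuracy before any retraining
    _, baseline = model.evaluate(x_test, y_test, verbose=0)
    # 'epochs' is the Keras 2 name; the 1.x MXNet fork calls it 'nb_epoch'
    model.fit(x_train, y_train, batch_size=128, epochs=epochs,
              validation_data=(x_test, y_test))
    _, tuned = model.evaluate(x_test, y_test, verbose=0)
    return baseline, tuned
```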
Full code is available on GitHub. Now, let’s run it!
Is it a horse or a car?
Let’s fine-tune the model on cars and horses (classes 1 and 7) for 10 epochs. Here are the results.
- MXNet: 31 seconds per epoch, test accuracy: 98.79% (log).
- TensorFlow: 44 seconds per epoch, test accuracy: 98.55% (log).
This is a MASSIVE accuracy improvement after just a few minutes of training. Just wait, we can do even better :)
The many layers in our model have already learned to recognise cars and horses, so retraining all of them is probably a waste of time. It may be enough to retrain only the last layer, i.e. the one that actually outputs the probabilities for the 10 classes.
Keras includes a very nice feature that lets us decide which layers of a model are trainable and which aren’t. Let’s use it to freeze all layers but the last one and try again.
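Freezing comes down to flipping the trainable flag on each layer. A minimal sketch, assuming model is a standard Keras model:

```python
def freeze_all_but_last(model):
    # Mark every layer except the final classifier as non-trainable
    for layer in model.layers[:-1]:
        layer.trainable = False
    model.layers[-1].trainable = True
    # Note: changes to `trainable` only take effect after re-compiling
    return model
```

Don’t forget to call model.compile() again afterwards, or the frozen flags are ignored.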
Here are the results.
- MXNet: 12 seconds per epoch, test accuracy: 97.29% (log).
- TensorFlow: 13 seconds per epoch, test accuracy: 98.35% (log).
As expected, freezing layers significantly reduces training time.
Training time is almost identical for both libraries. I guess the work boils down to running backpropagation on a single layer, which both libraries can do equally well on a single GPU. Scaling doesn’t seem to come into play here.
Accuracy is hardly impacted for TensorFlow, but there is a slight hit for MXNet. Hmmm. Maybe our optimisation parameters are not optimal. Let’s try one last thing :)
Tweaking the optimiser
Picking the right hyperparameters for SGD (learning rate, momentum, etc.) is tricky. Here, we’re only training for 10 epochs, which probably makes it even more difficult. One way out of this problem is to use the AdaGrad optimizer, which automatically adapts the learning rate.
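Switching optimizers is a one-line change at compile time. A hedged sketch (the helper name is mine):

```python
def compile_with_adagrad(model):
    # Keras import deferred so the sketch stands alone.
    from keras.optimizers import Adagrad
    # Adagrad adapts the learning rate per parameter; Keras' default
    # initial rate (0.01) is usually a reasonable starting point.
    model.compile(loss='categorical_crossentropy',
                  optimizer=Adagrad(),
                  metrics=['accuracy'])
    return model
```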
For an excellent overview of SGD and its variants, please read this post by Sebastian Ruder. It’s by far the best I’ve seen.
Here are the results.
- MXNet: 12 seconds per epoch, test accuracy: 98.90% after 9 epochs (log).
- TensorFlow: 13 seconds per epoch, test accuracy: 98.75% after 8 epochs (log).
AdaGrad works its magic indeed. It helps MXNet improve its accuracy and deliver the best score in this test. TensorFlow improves as well.
Here, we fine-tuned the model for two classes, but we still output 10 probabilities (one for each class). The next step would be to add an extra layer with only two outputs, forcing the model to decide on two categories only (and not ten).
We would need to modify our subset labels, i.e. one-hot encode them with only two values. Let’s keep this for a future article :)
Fine-tuning is a very powerful way to improve the accuracy of a model in a very short time compared to training from scratch. It’s a great technique when you can’t or don’t want to spend time and money (re)training complex models on large data sets. Just make sure you understand how the model was trained (data set, data format, etc.) in order to retrain it appropriately.
That’s it for today. Thank you for reading.
No animals were harmed during the writing of this article, but there sure was a whole lot of creative swearing and keyboard smashing during the coding phase.