Keras shoot-out: TensorFlow vs MXNet

A few months ago, we took an early look at running Keras with Apache MXNet as its backend. Things were pretty beta at the time, but a lot of progress has since been made. It’s time to reevaluate… and to benchmark MXNet against TensorFlow.

In this world, there’s two kinds of people, my friend. Those with GPUs and those who wait for days. You wait.

The story so far

In addition to the Keras and MXNet codebases, here’s what we’re going to use today:

Let’s ride.

Installing MXNet and Keras
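On a GPU instance, installing MXNet boils down to a pip install. Here’s a minimal sketch, assuming a CUDA 8 setup; pick the package variant matching your CUDA version, or plain mxnet for CPU-only.

```bash
# GPU-enabled MXNet build for CUDA 8 (use plain 'mxnet' for CPU-only)
pip install --upgrade mxnet-cu80
```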

Updating Keras is quite simple too.
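For reference, here’s roughly what that looks like. This is a sketch: at the time of writing, the MXNet backend lives in a fork of Keras rather than in the main package, so the source URL below (the dmlc fork on GitHub) is an assumption to adapt to your setup.

```bash
# Install / upgrade the Keras fork that ships the MXNet backend
# (assumed source: the dmlc/keras fork on GitHub)
pip install --upgrade git+https://github.com/dmlc/keras.git
```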

Let’s check that we have the correct versions.
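One quick way to do that from Python is simply to print the version strings:

```python
import keras
import mxnet as mx

# Print both versions to confirm the packages we expect are installed
print("Keras version: %s" % keras.__version__)
print("MXNet version: %s" % mx.__version__)
```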

Ok, looks good. Let’s move on to training.

Keras backends

All it takes is one line in the ~/.keras/keras.json file.
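For example, once the backend line is switched to mxnet, the file could look like this (the other entries are simply the stock Keras defaults, shown for completeness):

```json
{
    "image_dim_ordering": "tf",
    "epsilon": 1e-07,
    "floatx": "float32",
    "backend": "mxnet"
}
```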

Learning CIFAR-10 with TensorFlow

Time to train.
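Nothing fancy here: assuming you’re in the Keras examples directory, launching the stock script is a one-liner.

```bash
# Train the ResNet example on CIFAR-10 with the current Keras backend
python cifar10_resnet.py
```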

Here’s what memory usage looks like, as reported by nvidia-smi.
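If you want to reproduce this, one simple way is to keep an eye on the GPU from a second terminal:

```bash
# Refresh GPU utilisation and memory usage every second
watch -n 1 nvidia-smi
```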

As we can see, TensorFlow is a bit of a memory hog, pretty much eating up 100% of available GPU memory. Not really a problem here, but I’m wondering whether a much more complex model would still fit in memory. To be tested in a future post, I suppose :)

After a while, here’s the result (full log here).

All right. Now let’s move on to MXNet.

Learning CIFAR-10 with MXNet

Just replace the call to model.compile() in cifar10_resnet.py with this snippet.
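The exact snippet depends on the version of the fork you’re running; the idea is to tell MXNet which device(s) to train on by passing an extra context parameter to compile(). Here’s a sketch, assuming a single GPU and keeping whatever loss, optimizer and metrics the original script already uses (the values below are placeholders):

```python
# Compile the model on GPU 0 with the MXNet backend.
# 'context' is the MXNet-specific addition; loss, optimizer and metrics
# stay whatever the original cifar10_resnet.py already passes.
model.compile(loss='categorical_crossentropy',
              optimizer='sgd',
              metrics=['accuracy'],
              context=["gpu(0)"])
```

If your version of the fork supports it, multi-GPU training should just be a matter of listing several devices, e.g. context=["gpu(0)", "gpu(1)"].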

Time to train.

Holy moly! MXNet is about 60% faster: 25 seconds per epoch instead of 61. Very nice. In the same amount of time, this would let us try many more things, like different model architectures or different hyperparameters. Definitely an advantage when you’re experimenting.

What about memory usage? As we can see, MXNet uses over 90% less GPU memory, leaving plenty for other jobs.

Here’s the result after 100 epochs (full log here): 43 minutes, 99.4% training accuracy, 62% test accuracy.

Conclusion

It seems to me that every Deep Learning practitioner ought to check out MXNet, especially now that it’s properly integrated with Keras: changing a single line of configuration is all it takes :)

If you’d like to dive a bit more into MXNet, may I recommend the following resources?

In part 2, I’m taking a deeper look at memory usage in TensorFlow and how to optimise it.

In part 3, we’ll learn how to fine-tune the models for improved accuracy.

Thank you for reading :)

This post was written while blasting classics by Whitesnake, Rainbow and Dio. Fortunately, no neighbour was injured in the process.

