Mixed-Precision Neural Network Training with APEX

TLDR: Just make these changes:

from apex import amp
# add this after net and optimizer are defined:
net, optimizer = amp.initialize(net, optimizer, opt_level='O1')
# replace 'loss.backward()' with this:
with amp.scale_loss(loss, optimizer) as scaled_loss:
    scaled_loss.backward()


I have a Turing GPU, which includes hardware optimized for efficient FP16 (half-precision floating point) processing. Half precision is attractive because GPU memory is often the bottleneck in deep learning: being able to double the size of a network or the batch size can make a sizable difference. It has been shown that reducing the precision of neural network operations often has minimal impact on model accuracy, so in theory switching to half precision is a free upgrade. As a baseline, a small test training session at the default FP32 uses ~5 GB of GPU memory. Training for 1 epoch takes 160 seconds and reaches a training loss of 0.02.
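As a rough back-of-the-envelope check, FP16 stores each value in 2 bytes instead of FP32's 4, so any tensor kept in half precision needs half the memory. The sketch below uses only Python's standard struct module (format code "f" is single precision, "e" is half precision); the 25-million-parameter size is a hypothetical example, not the network used in this post.

```python
import struct

# Bytes per value: IEEE 754 single precision ("f") vs half precision ("e").
FP32_BYTES = struct.calcsize("f")  # 4
FP16_BYTES = struct.calcsize("e")  # 2


def tensor_mib(num_values: int, bytes_per_value: int) -> float:
    """Memory in MiB needed to store num_values numbers at the given width."""
    return num_values * bytes_per_value / 2**20


# Hypothetical model size; counts weights only. Gradients, optimizer state
# and activations all add more memory on top of this.
num_params = 25_000_000
print(f"FP32 weights: {tensor_mib(num_params, FP32_BYTES):.1f} MiB")  # FP32 weights: 95.4 MiB
print(f"FP16 weights: {tensor_mib(num_params, FP16_BYTES):.1f} MiB")  # FP16 weights: 47.7 MiB
```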

In PyTorch, switching to half-precision is as simple as

net = net.half()             # cast the model's parameters to FP16
half_tensor = tensor.half()  # cast inputs to FP16 as needed before feeding them to the network

And indeed, with these changes, memory usage is now ~3 GB. But…

Epoch [1/1], Step[10/255], Loss: nan
Epoch [1/1], Step[20/255], Loss: nan

As it turns out, while the network itself may not need much precision, the training process does. In this case, some computation within the loss function or the optimizer becomes numerically unstable in FP16, leading to divide-by-zeros. (FP16 can only represent magnitudes from roughly 6e-8 up to 65504, so small values round to zero and large ones overflow to infinity.) Some Stack Overflow searching suggested that modifying the epsilon values used by optimizers and batch norm layers could help, but I had no luck there. Instead, let's consider mixed precision: using higher precision for the computations that need it, and lower precision everywhere else.
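To see how FP16's narrow range can break training, here is a small sketch that rounds ordinary Python floats to half precision using the standard library's struct module (format code "e"). The helper to_fp16 is hypothetical and only emulates FP16 storage, not GPU arithmetic. The 1e-8 value is the default eps of torch.optim.Adam; this illustrates one plausible mechanism, not a diagnosis of which computation actually produced the NaNs above.

```python
import math
import struct


def to_fp16(x: float) -> float:
    """Round a Python float to the nearest IEEE 754 half-precision value.

    struct raises OverflowError for finite values too large for FP16,
    whereas real FP16 arithmetic produces +/-inf, so mimic that here.
    """
    try:
        return struct.unpack("<e", struct.pack("<e", x))[0]
    except OverflowError:
        return math.copysign(math.inf, x)


# Adam's default epsilon is below the smallest FP16 subnormal (~6e-8),
# so the value meant to keep a denominator away from zero becomes zero.
print(to_fp16(1e-8))         # 0.0
# A squared small gradient (as in Adam's second-moment estimate) underflows too.
print(to_fp16((1e-5) ** 2))  # 0.0
# Anything above 65504, the largest finite FP16 value, overflows to infinity.
print(to_fp16(70000.0))      # inf
# Near 1.0, FP16 values are about 0.001 apart, so a tiny weight update is lost.
print(to_fp16(1.0 + 1e-4))   # 1.0
```

If both the second-moment estimate and epsilon round to zero, a denominator like sqrt(v) + eps is exactly zero, which on a GPU yields inf and then NaN in later arithmetic.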


Enter APEX: this library from NVIDIA does all the under-the-hood work needed to train a network with mixed-precision operations. In other words, it knows which operations can get away with FP16 and which should stay in FP32, and handles the data conversions accordingly. It does this fairly seamlessly by monkey-patching PyTorch functions as needed.
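Besides choosing a precision per operation, APEX applies loss scaling through the amp.scale_loss context manager from the TLDR: APEX's documentation describes it as multiplying the loss by a scale factor before backward(), so small gradients stay representable in FP16, then checking for inf/NaN and unscaling the gradients before the optimizer step, with the scale adjusted dynamically under O1. The toy below illustrates that idea in plain Python. DynamicLossScaler, its step method, and the init_scale/growth_interval values are hypothetical choices for illustration, not APEX's actual classes or defaults.

```python
import math
import struct
from typing import List, Optional


def to_fp16(x: float) -> float:
    """Round to the nearest FP16 value; overflow becomes +/-inf."""
    try:
        return struct.unpack("<e", struct.pack("<e", x))[0]
    except OverflowError:
        return math.copysign(math.inf, x)


class DynamicLossScaler:
    """Toy dynamic loss scaler (hypothetical, not APEX's implementation)."""

    def __init__(self, init_scale: float = 2.0**16, growth_interval: int = 2000):
        self.scale = init_scale
        self.growth_interval = growth_interval
        self._clean_steps = 0

    def step(self, true_grads: List[float]) -> Optional[List[float]]:
        # Backprop of (loss * scale) multiplies every gradient by scale;
        # model that, then round to FP16 as the gradient storage would.
        fp16_grads = [to_fp16(g * self.scale) for g in true_grads]
        if not all(math.isfinite(g) for g in fp16_grads):
            # Overflow: skip this optimizer step and halve the scale.
            self.scale /= 2.0
            self._clean_steps = 0
            return None
        # Unscale in full precision before handing gradients to the optimizer.
        unscaled = [g / self.scale for g in fp16_grads]
        self._clean_steps += 1
        if self._clean_steps >= self.growth_interval:
            # A long run without overflow: try a larger scale.
            self.scale *= 2.0
            self._clean_steps = 0
        return unscaled


scaler = DynamicLossScaler()
print(to_fp16(1e-8))        # 0.0: this gradient is lost in plain FP16
print(scaler.step([1e-8]))  # approximately 1e-8: recovered thanks to scaling
print(scaler.step([1.0]))   # None: 1.0 * 2**16 overflows FP16, step skipped
print(scaler.scale)         # 32768.0
```

Halving on overflow and growing only after many clean steps keeps the scale near the largest value that does not overflow, which preserves as many small gradients as possible.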

APEX advertises itself as needing only 3 lines of code to set up. I found there was one additional step: building it requires a locally installed CUDA toolkit whose version exactly matches the CUDA version PyTorch was built with, and my local CUDA was a little out of date. Once I remedied that, though, I really did just make the changes above.
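Before building, it can help to compare the two versions directly: torch.version.cuda reports the CUDA version PyTorch was built with (None on CPU-only builds), and nvcc --version reports the local toolkit. The helpers below are hypothetical, not part of APEX or PyTorch, and compare only major.minor; tighten the check if your build needs an exact patch-level match. The sample nvcc text in the usage is illustrative.

```python
import re
import subprocess
from typing import Optional, Tuple


def parse_nvcc_release(nvcc_output: str) -> Tuple[int, int]:
    """Extract (major, minor) from `nvcc --version` output, e.g. 'release 10.1'."""
    match = re.search(r"release (\d+)\.(\d+)", nvcc_output)
    if match is None:
        raise ValueError("no 'release X.Y' found in nvcc output")
    return int(match.group(1)), int(match.group(2))


def cuda_versions_match(torch_cuda: Optional[str], nvcc_output: str) -> bool:
    """True if PyTorch's CUDA version and the local toolkit agree on major.minor."""
    if torch_cuda is None:  # CPU-only PyTorch builds report no CUDA version
        return False
    major, minor = (int(part) for part in torch_cuda.split(".")[:2])
    return (major, minor) == parse_nvcc_release(nvcc_output)


if __name__ == "__main__":
    try:
        import torch

        nvcc = subprocess.run(
            ["nvcc", "--version"], capture_output=True, text=True, check=True
        )
        print("PyTorch built with CUDA:", torch.version.cuda)
        print("local toolkit:", parse_nvcc_release(nvcc.stdout))
        print("versions match:", cuda_versions_match(torch.version.cuda, nvcc.stdout))
    except (ImportError, OSError, subprocess.CalledProcessError, ValueError) as exc:
        print("could not run the live check:", exc)
```

For example, with nvcc output containing "Cuda compilation tools, release 10.1, V10.1.243", cuda_versions_match("10.1", output) is True and cuda_versions_match("10.2", output) is False.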

By the way, O1 is the recommended (and default) mixed-precision level. O0 is plain FP32 training, O2 is a more aggressive "almost FP16" mixed-precision mode, and O3 is essentially pure FP16.

After making the above changes and kicking off a new training run, memory usage is about the same as with pure FP16 (~3 GB). Training for 1 epoch takes a little longer, at 170 seconds, and still reaches a loss of 0.02. Perhaps the extra runtime would wash out over a larger or longer training session. Either way, freeing up ~2 GB (about 40% of the FP32 footprint) is quite nice, and opens up more possibilities for local training on my own hardware.
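For reference, here is the arithmetic on the approximate numbers reported in this post, showing how the memory savings and the slowdown compare. Note that "FP32 uses 67% more than AMP" and "AMP saves 40% of FP32" describe the same gap from different baselines.

```python
# Approximate figures from this post: GPU memory in GB, seconds per epoch.
fp32_mem_gb, amp_mem_gb = 5.0, 3.0
fp32_epoch_s, amp_epoch_s = 160.0, 170.0

memory_saved = (fp32_mem_gb - amp_mem_gb) / fp32_mem_gb  # fraction of FP32 usage freed
fp32_overhead = fp32_mem_gb / amp_mem_gb - 1             # how much more FP32 needs than AMP
slowdown = amp_epoch_s / fp32_epoch_s - 1                # extra time per epoch with AMP

print(f"memory saved: {memory_saved:.0%}")         # memory saved: 40%
print(f"FP32 uses {fp32_overhead:.0%} more")       # FP32 uses 67% more
print(f"epoch slowdown: {slowdown:.2%}")           # epoch slowdown: 6.25%
```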