Deep learning algorithms are an effective way to solve cognitive problems. Two popular deep learning frameworks are TensorFlow (from Google) and PyTorch (from Facebook), both of which are now part of the IBM PowerAI package.

One decision that data scientists and developers of artificial intelligence (AI) applications must make is which framework best fits their use case. TensorFlow and PyTorch each excel in their own way, and in this blog I'll explain how the two compare, using image training with a convolutional neural network, a ResNet-50 model, as the example.

Google started a proprietary machine learning system called DistBelief that later evolved into TensorFlow. At the beginning, TensorFlow had a Python front end and a large C++ runtime at the back end; over time, more and more of the framework moved into Python. With TensorFlow, you must define the graph statically and then run the model through a session, as in the sketch below. These static graphs are difficult to debug. TensorFlow does provide a debugging tool called tfdbg, but it only helps you analyze the tensors and their operations; to debug the surrounding Python code, you need a separate Python debugger. On the plus side, TensorFlow has a very good visualization tool called TensorBoard that gives a great view of the model, hyperparameters, runtime behavior, and so on.
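
Here is a minimal sketch of that static-graph-plus-session workflow, written against the TensorFlow 1.x API. The toy one-layer model, tensor shapes, and variable names are illustrative, not the benchmark's ResNet-50:

```python
import numpy as np
import tensorflow as tf

# Build the graph statically; nothing is computed at definition time.
x = tf.placeholder(tf.float32, shape=[None, 224 * 224 * 3], name="images")
w = tf.Variable(tf.random_normal([224 * 224 * 3, 10]), name="weights")
logits = tf.matmul(x, w)

# Computation happens only when the graph is run inside a session.
batch = np.zeros((32, 224 * 224 * 3), dtype=np.float32)  # stand-in data
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    out = sess.run(logits, feed_dict={x: batch})
```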

Torch is an open source machine learning library originally based on the Lua programming language. Over time it was reworked into a Python-based library, with some changes, and renamed PyTorch; it is used heavily by Facebook. PyTorch lets you define, change, and run the model dynamically, as in the sketch below, and you can use any Python debugger, such as pdb, to debug PyTorch code. It does not yet have a visualizer like TensorBoard, but as the framework matures, more visualizers should appear for it as well.
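
For contrast, here is a minimal sketch of PyTorch's define-by-run style, again with an illustrative toy model rather than the benchmark's ResNet-50:

```python
import torch

# The graph is built as the code executes, so the model can be changed on
# the fly and stepped through with an ordinary Python debugger.
model = torch.nn.Linear(224 * 224 * 3, 10)
x = torch.randn(32, 224 * 224 * 3)

# import pdb; pdb.set_trace()  # uncomment to drop into pdb before the forward pass
logits = model(x)
loss = logits.sum()
loss.backward()  # gradients flow through the graph built during execution
```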

Now, let's compare the two deep learning frameworks on a standard image recognition model. I used a ResNet-50 model with the ImageNet data set and a batch size of 32 images, and evaluated it on both TensorFlow and PyTorch; the setup is sketched below.
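
On the PyTorch side, that setup looks roughly like the following, built from torchvision; the random tensor stands in for a real ImageNet batch, an assumption made purely for brevity:

```python
import torch
import torchvision.models as models

model = models.resnet50()              # ResNet-50, randomly initialized here
batch = torch.randn(32, 3, 224, 224)   # stand-in for 32 ImageNet images
logits = model(batch)                  # one forward pass over the batch
print(logits.shape)                    # torch.Size([32, 1000])
```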

I found that PyTorch performed much better than TensorFlow. PyTorch will not always outperform TensorFlow, but for deep learning applications like ResNet-50 it should. To understand why, I dissected both applications to see where they spent most of their time and on what.

The per-iteration time to process 32 images was 160 ms on PyTorch compared to 197 ms on TensorFlow. PyTorch's major benefit comes from the kernels it uses for forward and backward propagation, which is evident from the time spent in each phase: PyTorch spent 31 ms on the forward computation and 33 ms on the backward computation, whereas TensorFlow spent 55 ms and 120 ms on the same operations. The gradient reduction in PyTorch is an exclusive operation, with no other computation happening in parallel; in TensorFlow, the reduction is a parallel operation computed alongside the backward-propagation kernels.
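
For anyone who wants to reproduce this kind of phase-by-phase breakdown, here is one way to time the forward and backward passes separately in PyTorch. The torch.cuda.synchronize() calls are needed because GPU kernels run asynchronously; the exact numbers will depend on your hardware, and a CUDA-capable GPU is assumed:

```python
import time
import torch
import torchvision.models as models

model = models.resnet50().cuda()
batch = torch.randn(32, 3, 224, 224, device="cuda")
labels = torch.randint(0, 1000, (32,), device="cuda")
criterion = torch.nn.CrossEntropyLoss()

torch.cuda.synchronize()           # make sure the GPU is idle before timing
t0 = time.time()
logits = model(batch)              # forward propagation
loss = criterion(logits, labels)
torch.cuda.synchronize()           # wait for forward kernels to finish
t1 = time.time()
loss.backward()                    # backward propagation
torch.cuda.synchronize()           # wait for backward kernels to finish
t2 = time.time()

print(f"forward: {(t1 - t0) * 1e3:.1f} ms, backward: {(t2 - t1) * 1e3:.1f} ms")
```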

[Figure: per-iteration timing breakdown for PyTorch and TensorFlow]

I also noted that PyTorch operates on raw input images and consequently spends a lot of time preprocessing the data during training. TensorFlow, by contrast, processes the images to a certain extent ahead of time and stores them as TFRecords before the training phase even starts, as sketched below. As a result, TensorFlow spends only 22 ms per iteration on data preprocessing, compared to 48 ms for PyTorch. This preprocessing advantage does not translate into an overall training win, however, because the kernels PyTorch uses are much superior to TensorFlow's. If TensorFlow could somehow use similar kernels, it should outperform PyTorch for models like ResNet-50.
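
For reference, here is a sketch of the kind of ahead-of-time serialization that TFRecords enables: images are read once, wrapped in tf.train.Example records, and written to a single file that training can later stream from. The file names here are hypothetical:

```python
import tensorflow as tf

# Hypothetical (path, label) pairs; in practice this would be the whole data set.
samples = [("img0.jpg", 0), ("img1.jpg", 1)]

with tf.io.TFRecordWriter("train.tfrecords") as writer:
    for path, label in samples:
        with open(path, "rb") as f:
            encoded = f.read()  # raw JPEG bytes, stored without decoding
        example = tf.train.Example(features=tf.train.Features(feature={
            "image": tf.train.Feature(bytes_list=tf.train.BytesList(value=[encoded])),
            "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[label])),
        }))
        writer.write(example.SerializeToString())
```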

In this blog post, I showed that even when two different deep learning frameworks run the same model, their runtime characteristics can be drastically different, which results in a difference in performance. Viewed another way, it also shows the optimization opportunities available for improving some of these frameworks.