
TensorFlow significantly slower than PyTorch when training a small number of batches

Recently, I have been learning about and experimenting with deep reinforcement learning (DRL). Many DRL algorithms train on a single batch for one epoch at a time. In this setting, I observed that TensorFlow 2 is significantly slower than PyTorch (9 to 22 times slower).
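For reference, the pattern described above (one gradient update on one batch per call) can be written in TF2 without model.fit by using a custom training step compiled with tf.function. model.fit does extra setup on every call (such as building a data iterator and running callbacks), which a hand-written step avoids. This is a minimal sketch, not the code from the linked issue: the network shape, the MSE loss and the random batch are assumptions standing in for a real DRL model and replay sample, and whether this closes the gap reported here is not established.

```python
import numpy as np
import tensorflow as tf

# Hypothetical small network standing in for a DRL value/policy model.
model = tf.keras.Sequential([
    tf.keras.Input(shape=(8,)),
    tf.keras.layers.Dense(64, activation="relu"),
    tf.keras.layers.Dense(4),
])
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)
loss_fn = tf.keras.losses.MeanSquaredError()


@tf.function  # traced into a graph once per input signature, then reused
def train_step(x, y):
    # One forward/backward pass and one optimizer update on a single batch.
    with tf.GradientTape() as tape:
        pred = model(x, training=True)
        loss = loss_fn(y, pred)
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return loss


# Random data standing in for one sampled replay batch (assumption).
# Passing tf tensors with a fixed shape and dtype avoids retracing on each call.
rng = np.random.default_rng(0)
x = tf.constant(rng.normal(size=(32, 8)), dtype=tf.float32)
y = tf.constant(rng.normal(size=(32, 4)), dtype=tf.float32)

# A DRL loop would call this once per environment step or update interval.
for _ in range(3):
    loss = train_step(x, y)
print(float(loss))
```

The same step could also be compared against model.train_on_batch, which performs one update per call without the per-call epoch machinery of model.fit.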

This is the first time I have run into this problem. I previously worked mostly on supervised computer vision tasks, so I suspect the performance issue is caused by the small number of batches per training call. Unlike DRL, common CV tasks train on many batches over many epochs, and there I saw only a minor performance difference between the two frameworks.
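One way to test the suspicion that a fixed cost per training call, rather than the computation itself, dominates when each call sees only one batch is to time a single call in isolation after a few warm-up calls. The helper below is a minimal, framework-agnostic sketch; the name median_call_time and its default warm-up and run counts are hypothetical choices for illustration.

```python
import statistics
import time
from typing import Callable


def median_call_time(fn: Callable[[], object], warmup: int = 3, runs: int = 20) -> float:
    """Return the median wall-clock time, in seconds, of one call to fn.

    Warm-up calls are excluded so that one-time costs (for example
    tf.function tracing or GPU context setup) do not inflate the result.
    fn should force its result to be ready before returning (e.g. call
    .numpy() on a TensorFlow loss, or torch.cuda.synchronize() in PyTorch),
    otherwise asynchronous GPU work may not be counted.
    """
    for _ in range(warmup):
        fn()
    samples = []
    for _ in range(runs):
        start = time.perf_counter()
        fn()
        samples.append(time.perf_counter() - start)
    # The median is less sensitive than the mean to outliers such as GC pauses.
    return statistics.median(samples)
```

For example, comparing median_call_time(lambda: model.fit(x, y, epochs=1, verbose=0)) against the same measurement for a hypothetical tf.function training step, and for the equivalent PyTorch step, would show whether most of the gap comes from per-call overhead rather than from the math.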

However, I have not been able to solve the problem. I asked on Stack Overflow and even opened a GitHub issue, but nobody has answered yet. I personally prefer TensorFlow, so I don't want to move to PyTorch unless I have to. I would like to know whether anyone can explain why this happens, or help me improve performance when training on a small number of batches.

GitHub issue with reproducible code and a more detailed explanation:

https://github.com/tensorflow/tensorflow/issues/48844

Any help would be appreciated, thank you so much!

submitted by /u/seermer
