How to Port Pre-trained Models from TensorFlow to PyTorch
The field of deep learning is one of the fastest-changing areas in computer science. When I started getting into the subject about five years ago, TensorFlow was considered the dominant framework. Today, most researchers have shifted towards PyTorch.
While this fast pace is exciting, it also brings a lot of challenges. Recently, I was confronted with the task of continuing a project from 2018: a colleague had trained a segmentation model on a large dataset of clinical data and reported great performance.
Today, our goal is to reuse that trained model for a similar objective, a process called transfer learning. The intuition is that instead of starting from scratch, initializing the weights of a new model at least partly with pre-trained weights offers a far better starting point.
Collecting TensorFlow 1.x Weights
This sounds easier than it is. In TensorFlow 1.x, the model is saved across four separate files, none of which can be directly translated to PyTorch's state_dict. To work around this, we have to manually build a dictionary and retrieve the weights from the TensorFlow backend.
For this to work, you need to know the naming scheme of the TensorFlow implementation. Each operation can be assigned a name at creation time, and these names will be important later when converting to PyTorch.
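As a minimal sketch, assuming a standard checkpoint written by tf.train.Saver (the path is a placeholder), the checkpoint reader lets us dump every variable into a plain dictionary of numpy arrays and inspect those names:

import tensorflow as tf

# Load the checkpoint and collect every variable into a dict of numpy arrays.
# "model_dir/model.ckpt" is a placeholder for your own checkpoint prefix.
reader = tf.train.load_checkpoint("model_dir/model.ckpt")
weights = {name: reader.get_tensor(name)
           for name in reader.get_variable_to_shape_map()}

# Print a few variable names to inspect the naming scheme.
print(sorted(weights)[:5])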
Recreating the Model
Unfortunately, there is no direct way to convert a TensorFlow model to PyTorch. However, most layers exist in both frameworks, albeit with slightly different syntax. In TF1, for example, a convolutional layer can include its activation function, whereas in PyTorch the activation needs to be added as a separate, sequential module.
The sketch below shows an upconv block of the popular UNet architecture in both a TF1 and a PyTorch implementation (kernel size, stride, and names are illustrative placeholders, not the original model's):
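import tensorflow as tf  # 1.x
import torch.nn as nn

# tf1: the transposed convolution takes its activation as an argument
def upconv_block_tf(x, filters, name):
    return tf.layers.conv2d_transpose(
        x, filters, kernel_size=2, strides=2,
        activation=tf.nn.relu, name=name)

# pytorch: the activation is a separate module, added sequentially
class UpconvBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.block = nn.Sequential(
            nn.ConvTranspose2d(in_channels, out_channels,
                               kernel_size=2, stride=2),
            nn.ReLU(inplace=True))

    def forward(self, x):
        return self.block(x)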
NHWC vs. NCHW
One final important difference between TensorFlow and PyTorch is the convention about axes. In legacy TensorFlow, the data_format attribute of a layer could be set to either channels_last or channels_first, with the former being the default. In PyTorch, however, only channels first is supported. These formats are commonly denoted NHWC and NCHW, for batch size (N), height (H), width (W) and channels (C).
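To make the two layouts concrete, converting a batch of images from NHWC to NCHW is a single transpose (a sketch with a dummy batch):

import numpy as np

batch_nhwc = np.random.rand(8, 256, 256, 3)    # TensorFlow layout: NHWC
batch_nchw = batch_nhwc.transpose(0, 3, 1, 2)  # PyTorch layout: NCHW
print(batch_nchw.shape)                        # (8, 3, 256, 256)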
If a pretrained TensorFlow model was trained with the default channels_last option, the kernel axes need to be permuted before the weights can be used with torch. To compensate for this, a 2d-conv layer weight needs to be adapted like this:
np.transpose(kernel, (3, 2, 0, 1))
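The permutation is easy to verify on a dummy kernel (a sketch; the shapes are illustrative). TF1 stores conv kernels as (H, W, C_in, C_out), while PyTorch's Conv2d expects (C_out, C_in, H, W):

import numpy as np

tf_kernel = np.zeros((3, 3, 64, 128), dtype=np.float32)  # (H, W, C_in, C_out)
pt_kernel = np.transpose(tf_kernel, (3, 2, 0, 1))        # (C_out, C_in, H, W)
print(pt_kernel.shape)  # (128, 64, 3, 3)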
Initializing the PyTorch Model
With the weights transposed to the right format, we can load them into the PyTorch model. To do so, we instantiate the model with random weights and iterate over its named parameter list, modifying each parameter in place with the corresponding weights from TensorFlow.
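A minimal sketch of that loop, assuming model is the re-implemented PyTorch network and weights is the dictionary collected from the checkpoint earlier; tf_name_for() stands in for whatever name mapping your naming scheme requires and is purely hypothetical:

import numpy as np
import torch

with torch.no_grad():
    for name, param in model.named_parameters():
        w = weights[tf_name_for(name)]  # hypothetical tf-to-torch name mapping
        if w.ndim == 4:  # 2d-conv kernel: permute HWIO axes to OIHW
            w = np.transpose(w, (3, 2, 0, 1))
        # Overwrite the randomly initialized parameter in-place.
        param.copy_(torch.from_numpy(np.ascontiguousarray(w)))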
Conclusion
Following these steps, a model trained in TensorFlow 1.x can be extracted and translated into a PyTorch model. I hope this helps someone in a similar situation to mine.