
Saving multi-GPU models and resuming training in Keras

Unlike PyTorch, Keras does not make it straightforward to train models on multiple GPUs, especially GANs. We covered that part in our previous post. In this post, we show how to save such models and how to load them to resume training. Figuring this out took some effort, as we could not find any documentation on it. Without further ado, here are the important steps.

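The first step is to define the generator and the discriminator as functional models. We then create multi-GPU instances of these models and use them to build the GAN model, as described in our previous post. Here, we list some important notation:

gen: the generator, a functional model
disc: the discriminator, a functional model
g_par: the multi-GPU instance of gen
d_par: the multi-GPU instance of disc
gan: the GAN model built from the multi-GPU instances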

Saving models

Saving g_par weights

A multi-GPU model contains the single-device (template) model as one of its layers. This can be confirmed from the summary of the parallel model, e.g. g_par.summary() for the generator. To save the weights, we retrieve the template model from that layer by its name and save its weights to an .h5 file, as usual. For inference, only the generator weights are needed. They are saved as follows:

g_save = g_par.get_layer(<layer_name>)
g_save.save_weights(<file_name>)
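If you would rather not read the layer name off the summary, the wrapped model can also be located by type. The sketch below is a hypothetical helper, not a Keras API; it assumes (Keras 2.x) that the wrapped template is the only layer of the parallel model that is itself a model, i.e. the only one exposing a `layers` attribute.

```python
def find_template_model(parallel_model):
    """Return the single-device model wrapped inside a multi-GPU model.

    Assumption: the wrapped template is the only layer that is itself a
    model, i.e. the only layer that has a `layers` attribute.
    """
    nested = [layer for layer in parallel_model.layers
              if hasattr(layer, 'layers')]
    if len(nested) != 1:
        raise ValueError('expected exactly one nested model, found %d'
                         % len(nested))
    return nested[0]

# Usage with the names from this post:
# g_save = find_template_model(g_par)
# g_save.save_weights(<file_name>)
```

Raising an error when zero or several nested models are found keeps the helper from silently saving the wrong sub-model.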

Loading gen weights

For inference, the weights can be loaded as follows:

gen.load_weights(<file_name>)

where gen is an instance of the generator model.

Saving d_par weights

To be able to resume training, we also need to save the discriminator weights and the optimizer states. The discriminator weights are saved in the same way as the generator weights:

d_save = d_par.get_layer(<layer_name>)
d_save.save_weights(<file_name>)
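When checkpointing periodically, it can help to save both template models together under per-epoch file names. The sketch below is a hypothetical helper; the `gen_epochNNN.h5` / `disc_epochNNN.h5` naming scheme and the `ckpt_dir` argument are assumptions for illustration, and the layer names are the same ones you would pass to `get_layer` above.

```python
import os

def save_gan_weights(g_par, d_par, g_layer_name, d_layer_name, ckpt_dir, epoch):
    """Save generator and discriminator template weights for one epoch.

    The '<tag>_epochNNN.h5' naming scheme is a hypothetical convention.
    Returns the file paths so they can be logged and reused when resuming.
    """
    os.makedirs(ckpt_dir, exist_ok=True)
    paths = {}
    for tag, par_model, layer_name in (('gen', g_par, g_layer_name),
                                       ('disc', d_par, d_layer_name)):
        path = os.path.join(ckpt_dir, '%s_epoch%03d.h5' % (tag, epoch))
        # Retrieve the wrapped template model and save its weights.
        par_model.get_layer(layer_name).save_weights(path)
        paths[tag] = path
    return paths
```

Keeping the epoch in the file name means an older checkpoint is still available if a later one turns out to be bad.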

Saving optimizer states

The optimizer states are saved with pickle. Since both d_par and the GAN model are compiled with their own optimizer, we save the state of each, in separate files:

import pickle
from keras import backend as K

# Discriminator optimizer state
symbolic_weights = d_par.optimizer.weights
weight_values = K.batch_get_value(symbolic_weights)
with open(<file_name>, 'wb') as f:
    pickle.dump(weight_values, f)

# GAN optimizer state (use a different file name)
symbolic_weights = gan.optimizer.weights
weight_values = K.batch_get_value(symbolic_weights)
with open(<file_name>, 'wb') as f:
    pickle.dump(weight_values, f)

del symbolic_weights
del weight_values

where K is keras.backend. The del statements drop the references to the fetched values so that their memory can be reclaimed.
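The same step can be packaged as a pair of helpers. This is a minimal sketch with hypothetical function names; it relies on the fact that in Keras 2.x `optimizer.get_weights()` returns `K.batch_get_value(optimizer.weights)`. As an added safeguard not in the steps above, the state is first written to a temporary file and then moved into place, so an interrupted save does not corrupt an earlier checkpoint with the same name.

```python
import os
import pickle
import tempfile

def save_optimizer_state(optimizer, file_name):
    """Pickle an optimizer's state (a list of NumPy arrays in Keras 2.x)."""
    weight_values = optimizer.get_weights()
    target_dir = os.path.dirname(os.path.abspath(file_name))
    # Write to a temporary file in the same directory first.
    fd, tmp_path = tempfile.mkstemp(dir=target_dir, suffix='.tmp')
    try:
        with os.fdopen(fd, 'wb') as f:
            pickle.dump(weight_values, f)
        # Move into place; this rename is atomic on POSIX systems.
        os.replace(tmp_path, file_name)
    except BaseException:
        os.remove(tmp_path)
        raise

def load_optimizer_state(file_name):
    """Return the list of optimizer weight values stored by save_optimizer_state."""
    with open(file_name, 'rb') as f:
        return pickle.load(f)

# Usage with the names from this post (one file per optimizer):
# save_optimizer_state(d_par.optimizer, <file_name>)
# save_optimizer_state(gan.optimizer, <file_name>)
```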

Resuming Training

Loading d_par weights

The most crucial part is resuming training. The first step is to load the weights into the parallel models, which differs from the usual way of loading weights. For the discriminator, this is done as follows:

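from keras.utils import multi_gpu_model

disc.load_weights(<file_name>)
d_par = multi_gpu_model(disc, gpus=<no_of_gpus>, cpu_relocation=True)
# layers[-2] is the wrapped copy of disc (the last layer merges the GPU outputs)
d_par.layers[-2].set_weights(disc.get_weights())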

where disc is an instance of the discriminator model. The weights are first loaded into this model and then copied into the parallel model. This is needed because, with cpu_relocation=True, multi_gpu_model builds the parallel model around a clone of disc, and the clone starts with freshly initialized weights, so the loaded weights have to be set again. After this step, we can compile d_par.
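The load, wrap, and copy steps can be combined into one helper that also checks that the copy took effect. A minimal sketch: `restore_parallel` is a hypothetical name, `make_parallel` is any callable that wraps the template (for example the multi_gpu_model call above), and, like the code above, it assumes a single-output model so that the wrapped copy sits at `layers[-2]`. NumPy is used to compare the weight arrays.

```python
import numpy as np

def restore_parallel(template, file_name, make_parallel):
    """Load weights into a template model, wrap it, and copy the weights in."""
    # Load the saved weights into the single-device template model.
    template.load_weights(file_name)
    # Wrap it, e.g. make_parallel = lambda m: multi_gpu_model(
    #     m, gpus=<no_of_gpus>, cpu_relocation=True)
    parallel = make_parallel(template)
    # Assumes a single-output model: the wrapped copy sits at layers[-2].
    wrapped = parallel.layers[-2]
    expected = template.get_weights()
    wrapped.set_weights(expected)
    # Sanity check that the wrapped copy now holds exactly the loaded weights.
    actual = wrapped.get_weights()
    if len(actual) != len(expected) or not all(
            np.array_equal(a, b) for a, b in zip(actual, expected)):
        raise RuntimeError('weights were not copied into the parallel model')
    return parallel
```

The same helper applies unchanged to gen and g_par.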

Loading g_par weights

The steps followed to load the weights of the generator are the same:

gen.load_weights(<file_name>)
g_par = multi_gpu_model(gen, gpus=<no_of_gpus>, cpu_relocation=True)
# layers[-2] is the wrapped copy of gen
g_par.layers[-2].set_weights(gen.get_weights())

After this, we define the GAN model and compile it.

Loading optimizer states

The final step is to load the optimizer states of d_par and gan. This is done as follows:

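import pickle

# Build the training function first: in Keras 2.x this creates the
# optimizer weights that set_weights assigns to (private API)
d_par._make_train_function()
with open(<file_name>, 'rb') as f:
    weight_values = pickle.load(f)
d_par.optimizer.set_weights(weight_values)

gan._make_train_function()
with open(<file_name>, 'rb') as f:
    weight_values = pickle.load(f)
gan.optimizer.set_weights(weight_values)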
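Since the same three steps are repeated for each compiled model, they can be wrapped in a small helper. This is a sketch with a hypothetical name; it relies on the private Keras 2.x method `_make_train_function()`, because the optimizer's weight variables are only created when the training function is built, so right after `compile()` there is nothing for `set_weights()` to fill. It must therefore be called after the model is compiled.

```python
import pickle

def restore_optimizer_state(model, file_name):
    """Load a pickled optimizer state into an already compiled model.

    Relies on the private Keras 2.x method _make_train_function(), which
    creates the optimizer weight variables that set_weights() assigns to.
    Returns the number of restored weight arrays.
    """
    model._make_train_function()
    with open(file_name, 'rb') as f:
        weight_values = pickle.load(f)
    model.optimizer.set_weights(weight_values)
    return len(weight_values)

# Usage with the names from this post:
# restore_optimizer_state(d_par, <file_name>)
# restore_optimizer_state(gan, <file_name>)
```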

After this, we can train the model as usual. Training resumes from where it stopped, and the losses should be close to those observed before the model was saved.
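A rough way to check this is to compare the losses logged just before saving with the first ones after resuming. The sketch below is a hypothetical helper; the window size and relative tolerance are arbitrary defaults, and since GAN losses oscillate, a failed check is a hint to investigate rather than proof that restoring went wrong.

```python
import math
import statistics

def losses_look_resumed(losses_before, losses_after, window=5, rel_tol=0.25):
    """Compare mean loss over the last `window` steps before saving with
    the mean over the first `window` steps after resuming."""
    before = statistics.mean(losses_before[-window:])
    after = statistics.mean(losses_after[:window])
    return math.isclose(before, after, rel_tol=rel_tol)
```

If the optimizer state or the parallel weights were not restored, the first losses after resuming typically look like those of a fresh run, and the check fails.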