Multi-GPU model training in Keras

Defining the discriminator

An example discriminator could look like:

def disc(params):
  inp = Input(shape = (256, 256, 1))
  # ... add your layers here; out is the output of the final layer
  model = Model(inputs = inp, outputs = out)
  return model

params denotes any parameters you’d like to pass to the disc function. shape = (256, 256, 1) is just an example; set it according to your model.
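For concreteness, a minimal runnable discriminator might look like the sketch below. The specific layers (two strided convolutions and a sigmoid output) are illustrative choices, not prescribed by the text:

```python
from tensorflow.keras.layers import Input, Conv2D, Flatten, Dense, LeakyReLU
from tensorflow.keras.models import Model

def disc(params=None):
    # 256x256 grayscale input, as in the example above
    inp = Input(shape=(256, 256, 1))
    x = Conv2D(32, 3, strides=2, padding='same')(inp)   # -> 128x128x32
    x = LeakyReLU(0.2)(x)
    x = Conv2D(64, 3, strides=2, padding='same')(x)     # -> 64x64x64
    x = LeakyReLU(0.2)(x)
    x = Flatten()(x)
    out = Dense(1, activation='sigmoid')(x)             # real/fake probability
    return Model(inputs=inp, outputs=out)
```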

Defining the generator

An example generator could look like:

def gen(params):
  inp = Input(shape = ())
  model = Model(inputs = inp, outputs = out) #out is the output of final layer
  return model

params denotes any parameters you’d like to pass to the gen function. The shape could be set as shape = (256, 256, 1) depending on your specific model; out is the output of the final layer.
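A minimal runnable generator could look like the sketch below. Here we assume a 100-dimensional latent vector upsampled to a 256x256 single-channel image to match the discriminator example; your input shape and layers may differ:

```python
from tensorflow.keras.layers import Input, Dense, Reshape, Conv2DTranspose
from tensorflow.keras.models import Model

def gen(params=None):
    # Assumed latent-vector input; replace with your actual input shape
    inp = Input(shape=(100,))
    x = Dense(64 * 64 * 8, activation='relu')(inp)
    x = Reshape((64, 64, 8))(x)
    x = Conv2DTranspose(16, 3, strides=2, padding='same',
                        activation='relu')(x)            # -> 128x128x16
    out = Conv2DTranspose(1, 3, strides=2, padding='same',
                          activation='tanh')(x)          # -> 256x256x1
    return Model(inputs=inp, outputs=out)
```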

Defining the GAN model

def gan(gen, disc, params):
  disc.trainable = False
  inp = Input(shape = ())
  out_gen = gen(inp)
  out_disc = disc(out_gen)
  model = Model(inputs = inp, outputs = out_disc)
  return model

The gan model is used to train the generator (when the discriminator and generator are trained alternately), which is why the discriminator is frozen with disc.trainable = False. Here we assume the generator is trained on a single loss function; outputs can be modified if it is trained on more than one loss. We’ll give an example of this too.
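As an example of training the generator on more than one loss, the combined model can expose two outputs: the discriminator’s verdict (for an adversarial loss) and the generated image itself (for a pixel-wise loss). The loss choices and weights below are illustrative assumptions, not fixed by the text:

```python
from tensorflow.keras.layers import Input
from tensorflow.keras.models import Model

def gan_multi_loss(gen, disc, params=None):
    # Freeze the discriminator while the generator trains
    disc.trainable = False
    inp = Input(shape=gen.input_shape[1:])
    out_gen = gen(inp)
    out_disc = disc(out_gen)
    # Two outputs -> two losses: adversarial on out_disc,
    # pixel-wise reconstruction (MAE) on out_gen
    model = Model(inputs=inp, outputs=[out_disc, out_gen])
    model.compile(optimizer='adam',
                  loss=['binary_crossentropy', 'mae'],
                  loss_weights=[1.0, 100.0])
    return model
```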

Making instances and defining parallel models

from keras.utils import multi_gpu_model

d_model = disc(params)
d_par = multi_gpu_model(d_model, gpus = <no. of GPUs>, cpu_relocation = True)

d_model gives us an instance of the discriminator; d_par is the multi-GPU model for d_model.

Similarly for the generator model:

g_model = gen(params)
g_par = multi_gpu_model(g_model, gpus = <no. of GPUs>, cpu_relocation = True)

To define the gan model, g_par and d_par are used:

gan_model = gan(g_par, d_par, params)
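Once the parallel models are compiled, a single alternating update can be sketched as below. This assumes a latent-vector generator, a single adversarial loss, and that d_par was compiled before the discriminator was frozen inside the combined model (the usual pattern); names like train_step and latent_dim are illustrative:

```python
import numpy as np

def train_step(g_par, d_par, gan_model, real_batch, latent_dim):
    """One alternating update: discriminator first, then generator."""
    n = real_batch.shape[0]
    noise = np.random.normal(size=(n, latent_dim))
    fake_batch = g_par.predict(noise, verbose=0)

    # Train the discriminator on real (label 1) and fake (label 0) samples
    d_loss_real = d_par.train_on_batch(real_batch, np.ones((n, 1)))
    d_loss_fake = d_par.train_on_batch(fake_batch, np.zeros((n, 1)))

    # Train the generator through the frozen discriminator,
    # asking it to make the discriminator output 1
    g_loss = gan_model.train_on_batch(noise, np.ones((n, 1)))
    return d_loss_real, d_loss_fake, g_loss
```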