Multi-GPU model training in Keras
Defining the discriminator
An example discriminator could look like:
from keras.layers import Input
from keras.models import Model

def disc(params):
    inp = Input(shape = (256, 256, 1))
    # ... layers producing the final output tensor `out` ...
    model = Model(inputs = inp, outputs = out)  # out is the output of the final layer
    return model
params denotes any parameters you’d like to pass to the disc function. shape = (256, 256, 1) is just used as an example here; you can set it as per your model.
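For concreteness, here is one way disc could be fleshed out; a minimal sketch using tf.keras, assuming a binary real/fake classifier on 256 x 256 grayscale images. The layer widths and strides below are illustrative assumptions, not prescriptions:

```python
from tensorflow.keras.layers import Input, Conv2D, Flatten, Dense, LeakyReLU
from tensorflow.keras.models import Model

def disc(params=None):
    inp = Input(shape=(256, 256, 1))
    # Downsample with strided convolutions: 256 -> 128 -> 64
    x = Conv2D(32, 4, strides=2, padding="same")(inp)
    x = LeakyReLU(0.2)(x)
    x = Conv2D(64, 4, strides=2, padding="same")(x)
    x = LeakyReLU(0.2)(x)
    x = Flatten()(x)
    out = Dense(1, activation="sigmoid")(x)  # real/fake probability
    return Model(inputs=inp, outputs=out)
```

The sigmoid output pairs with a binary cross-entropy loss when the model is compiled later.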
Defining the generator
An example generator could look like:
def gen(params):
    inp = Input(shape = ())  # set the input shape as per your model
    # ... layers producing the final output tensor `out` ...
    model = Model(inputs = inp, outputs = out)
    return model
params denotes any parameters you’d like to pass to the gen function. The shape could be set as shape = (256, 256, 1) depending on your specific model; out is the output of the final layer.
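As with the discriminator, here is one hedged way gen could be fleshed out; a sketch assuming a 100-dimensional noise vector upsampled to a 256 x 256 single-channel image. The latent size and layer widths are illustrative assumptions:

```python
from tensorflow.keras.layers import Input, Dense, Reshape, Conv2DTranspose
from tensorflow.keras.models import Model

def gen(params=None):
    inp = Input(shape=(100,))  # latent noise vector (size is an assumption)
    x = Dense(64 * 64 * 4, activation="relu")(inp)
    x = Reshape((64, 64, 4))(x)
    # Upsample with transposed convolutions: 64 -> 128 -> 256
    x = Conv2DTranspose(16, 4, strides=2, padding="same", activation="relu")(x)
    out = Conv2DTranspose(1, 4, strides=2, padding="same", activation="tanh")(x)
    return Model(inputs=inp, outputs=out)
```

The tanh output assumes images are scaled to [-1, 1] before training.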
Defining the GAN model
def gan(gen, disc, params):
    disc.trainable = False  # freeze the discriminator while the generator trains
    inp = Input(shape = ())  # set the input shape as per your generator
    out_gen = gen(inp)
    out_disc = disc(out_gen)
    model = Model(inputs = inp, outputs = out_disc)
    return model
The gan model is used for training the generator (when training the discriminator and generator alternately), which is why the discriminator is frozen with disc.trainable = False. Here we assume the generator is trained on a single loss function; outputs can be modified if it is trained on more than one loss function. We’ll give an example of that too.
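One hedged way to train the generator on two losses is to return both the discriminator score and the generated image from the combined model, then attach a loss to each output at compile time. The loss choices and weights below are illustrative assumptions (the L1 weight of 100 follows common image-to-image GAN practice):

```python
from tensorflow.keras.layers import Input
from tensorflow.keras.models import Model

def gan_two_losses(gen, disc, latent_shape):
    disc.trainable = False  # freeze disc while the generator trains
    inp = Input(shape=latent_shape)
    out_gen = gen(inp)
    out_disc = disc(out_gen)
    # Two outputs -> one loss each: adversarial loss on the
    # discriminator score, pixel-wise L1 loss on the image itself.
    model = Model(inputs=inp, outputs=[out_disc, out_gen])
    model.compile(optimizer="adam",
                  loss=["binary_crossentropy", "mae"],
                  loss_weights=[1.0, 100.0])
    return model
```

When training, you would then pass a pair of targets per batch: the "real" labels for the first output and the target images for the second.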
Making instances and defining parallel models
from keras.utils import multi_gpu_model

d_model = disc(params)
d_par = multi_gpu_model(d_model, gpus = <no. of GPUs>, cpu_relocation = True)

d_model is an instance of the discriminator; d_par is the multi-GPU model for d_model.
Similarly for the generator model:
g_model = gen(params)
g_par = multi_gpu_model(g_model, gpus = <no. of GPUs>, cpu_relocation = True)
To define the gan model, g_par and d_par are used:
gan_model = gan(g_par, d_par, params)
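Putting it together, one alternating training step could look like the sketch below. The tiny Dense stand-in models and random data are only there so the step is runnable end-to-end; in a real run you would substitute g_par, d_par, and gan_model from above, and compile once outside the loop rather than per phase:

```python
import numpy as np
from tensorflow.keras.layers import Input, Dense
from tensorflow.keras.models import Model

latent_dim, batch = 8, 4

# Stand-in generator/discriminator (assumptions, not the article's models).
gi = Input(shape=(latent_dim,))
g_model = Model(gi, Dense(16, activation="tanh")(gi))
di = Input(shape=(16,))
d_model = Model(di, Dense(1, activation="sigmoid")(di))

noise = np.random.rand(batch, latent_dim).astype("float32")
real = np.random.rand(batch, 16).astype("float32")  # stand-in real samples

# 1) Discriminator phase: weights trainable, train on real vs fake batches.
d_model.trainable = True
d_model.compile(optimizer="adam", loss="binary_crossentropy")
d_loss_real = d_model.train_on_batch(real, np.ones((batch, 1)))
d_loss_fake = d_model.train_on_batch(g_model.predict(noise, verbose=0),
                                     np.zeros((batch, 1)))

# 2) Generator phase: freeze the discriminator, train the generator
#    through it with "real" labels on the fake samples.
d_model.trainable = False
gan_in = Input(shape=(latent_dim,))
gan_model = Model(gan_in, d_model(g_model(gan_in)))
gan_model.compile(optimizer="adam", loss="binary_crossentropy")
g_loss = gan_model.train_on_batch(noise, np.ones((batch, 1)))
```

Compiling after toggling trainable ensures the freeze actually takes effect, since Keras captures the trainable state at compile time.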