trw.layers.gan

Module Contents

Classes

GanDataPool

Gan

Generic GAN implementation. Supports conditional GANs.

Functions

process_outputs_and_extract_loss(outputs, batch, is_training)

trw.layers.gan.process_outputs_and_extract_loss(outputs, batch, is_training)
class trw.layers.gan.GanDataPool(pool_size, replacement_probability=0.5, insertion_probability=0.1)
get_data(self, batch, images_fake)
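The GanDataPool parameters suggest a history buffer of previously generated images, similar to the image buffer used in CycleGAN training. Such a buffer shows the discriminator a mix of current and older fakes. The following is a minimal pure-Python sketch that assumes these semantics: replacement_probability is the chance that a current fake is swapped for a stored one, and insertion_probability is the chance that the current fake overwrites a random pool slot. The class name HistoryPoolSketch and the seed argument are hypothetical. The batch argument of get_data is omitted, and trw's actual behavior may differ.

```python
import random


class HistoryPoolSketch:
    """Hypothetical history buffer of generated samples (not trw's GanDataPool)."""

    def __init__(self, pool_size, replacement_probability=0.5,
                 insertion_probability=0.1, seed=None):
        self.pool_size = pool_size
        self.replacement_probability = replacement_probability
        self.insertion_probability = insertion_probability
        self.pool = []
        self.rng = random.Random(seed)

    def get_data(self, images_fake):
        """Return fakes for the discriminator, some possibly swapped for older ones."""
        output = []
        for image in images_fake:
            # Fill the pool first; while filling, samples pass through unchanged.
            if len(self.pool) < self.pool_size:
                self.pool.append(image)
                output.append(image)
                continue
            # Assumed semantics: with replacement_probability, use a stored sample.
            if self.rng.random() < self.replacement_probability:
                output.append(self.rng.choice(self.pool))
            else:
                output.append(image)
            # Assumed semantics: with insertion_probability, store the current sample.
            if self.rng.random() < self.insertion_probability:
                self.pool[self.rng.randrange(self.pool_size)] = image
        return output
```

When replacement_probability is 0, the sketch passes fakes through unchanged. When it is 1, the discriminator only ever sees samples that were stored earlier.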
class trw.layers.gan.Gan(discriminator, generator, latent_size, optimizer_discriminator_fn, optimizer_generator_fn, real_image_from_batch_fn, train_split_name='train', loss_from_outputs_fn=process_outputs_and_extract_loss, image_pool=None)

Bases: torch.nn.Module

Generic GAN implementation. Supports conditional GANs.

Examples

  • generator conditioned by concatenating a one-hot attribute to the latent vector, or conditioned by another image (e.g., using a UNet)

  • discriminator conditioned by concatenating a one-hot encoding expanded to the image size, or by concatenating a one-hot vector to an intermediate layer

  • simple (unconditional) GAN, i.e., no conditioning observation
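The two one-hot conditioning schemes above can be sketched in plain Python, using nested lists in place of tensors. The generator input is the latent vector with a one-hot attribute appended. The discriminator input is the image with one constant H x W plane per class appended as extra channels. In PyTorch, the same operations would typically use torch.cat along the feature or channel dimension. The function names are hypothetical and are not part of trw.layers.gan.

```python
def one_hot(label, num_classes):
    """One-hot vector of length num_classes with a 1.0 at position label."""
    return [1.0 if i == label else 0.0 for i in range(num_classes)]


def condition_latent(latent, label, num_classes):
    """Generator conditioning: append a one-hot attribute to the latent vector."""
    return list(latent) + one_hot(label, num_classes)


def condition_image(image, label, num_classes):
    """Discriminator conditioning: append one constant H x W plane per class.

    image is a list of channels; each channel is a list of H rows of W values.
    """
    height, width = len(image[0]), len(image[0][0])
    # Each class contributes a plane filled with its one-hot value (0.0 or 1.0).
    planes = [
        [[value] * width for _ in range(height)]
        for value in one_hot(label, num_classes)
    ]
    # Copy the input channels so the caller's image is not modified.
    return [[list(row) for row in channel] for channel in image] + planes
```

For a single-channel image and 3 classes, the discriminator receives 4 channels. Only the channel of the true class is filled with ones.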

Notes

This module manages its own optimizers, created from optimizer_discriminator_fn and optimizer_generator_fn. The trw.train.Trainer should therefore have optimizers_fn set to None.
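The following sketch shows why a GAN that owns its optimizers does not need a trainer-level optimizer. It is a generic alternating GAN update on a toy 1-D model, not trw's implementation. The discriminator optimizer updates only the discriminator parameters, and the generator optimizer updates only the generator parameters, using the non-saturating generator loss. Several details are assumptions: the discriminator-then-generator step order, skipping updates outside the split named by train_split_name, and the helper names ScalarSGD and gan_step, which are hypothetical. Gradients are derived by hand rather than with autograd.

```python
import math


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


class ScalarSGD:
    """Plain SGD restricted to a subset of named scalar parameters (hypothetical)."""

    def __init__(self, params, names, lr):
        self.params = params
        self.names = names
        self.lr = lr

    def step(self, grads):
        for name in self.names:
            self.params[name] -= self.lr * grads[name]


def gan_step(params, opt_d, opt_g, real, z, split, train_split_name="train"):
    """One alternating update of a toy GAN.

    Generator G(z) = a * z + b; discriminator D(x) = sigmoid(w * x + c).
    """
    if split != train_split_name:
        return  # assumption: no optimization outside the training split
    # Discriminator step: minimize -log D(real) - log(1 - D(G(z))), G held fixed.
    fake = params["a"] * z + params["b"]
    d_real = sigmoid(params["w"] * real + params["c"])
    d_fake = sigmoid(params["w"] * fake + params["c"])
    opt_d.step({
        "w": (d_real - 1.0) * real + d_fake * fake,
        "c": (d_real - 1.0) + d_fake,
    })
    # Generator step: minimize -log D(G(z)) using the updated discriminator.
    fake = params["a"] * z + params["b"]
    d_fake = sigmoid(params["w"] * fake + params["c"])
    grad_fake = (d_fake - 1.0) * params["w"]  # d loss / d fake via D's logit
    opt_g.step({"a": grad_fake * z, "b": grad_fake})
```

Each optimizer touches only its own parameter names. An additional trainer-level optimizer over all parameters would apply extra, uncoordinated updates to both networks, which is presumably why optimizers_fn should be None.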

_generate_latent(self, nb_samples)
static _merge_generator_discriminator_outputs(generator_outputs, discriminator_real_outputs, discriminator_fake_outputs)
forward(self, batch)