trw.layers.gan
Module Contents
Classes
- Gan: Generic GAN implementation. Supports conditional GANs.
Functions
- trw.layers.gan.process_outputs_and_extract_loss(outputs, batch, is_training)
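process_outputs_and_extract_loss is the default loss_from_outputs_fn of Gan, and this page does not document what it returns. The sketch below is a hypothetical replacement callback with the same (outputs, batch, is_training) signature. It assumes, for illustration only, that outputs maps each output name to a dict that may hold a scalar 'loss' entry, and it sums those entries. The name sum_output_losses and that dict layout are assumptions, not part of trw.

```python
from typing import Any, Mapping, Optional


def sum_output_losses(outputs: Mapping[str, Mapping[str, Any]],
                      batch: Optional[Mapping[str, Any]],
                      is_training: bool) -> float:
    """Hypothetical loss_from_outputs_fn: add up every 'loss' entry.

    Assumes (not confirmed by the trw docs) that each value of `outputs`
    is a dict that may contain a scalar 'loss'. `batch` and `is_training`
    are accepted only to match the callback signature and are unused here.
    """
    total = 0.0
    for output in outputs.values():
        loss = output.get("loss")
        if loss is not None:
            total += loss
    return total


# Usage: only the two outputs carrying a 'loss' contribute.
print(sum_output_losses({"real": {"loss": 1.5}, "fake": {"loss": 0.5}, "image": {}}, None, True))
# prints 2.0
```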
- class trw.layers.gan.GanDataPool(pool_size, replacement_probability=0.5, insertion_probability=0.1)
- get_data(self, batch, images_fake)
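GanDataPool appears to implement an image history pool: fake images are kept across iterations and occasionally shown to the discriminator in place of fresh ones. Its parameter semantics are not documented on this page, so the sketch below is a hypothetical stand-in (ImageHistoryPool is not a trw class). It assumes replacement_probability is the chance that a fake sample is swapped for a stored one, and insertion_probability is the chance that a fresh fake sample is stored. It works on plain Python lists and omits the batch argument of get_data.

```python
import random
from typing import Any, List, Optional


class ImageHistoryPool:
    """Hypothetical history pool of past fake samples (not trw's GanDataPool).

    Assumed semantics:
      - replacement_probability: chance that a fake sample is swapped for a
        randomly chosen stored one before it reaches the discriminator;
      - insertion_probability: chance that a fresh fake sample is stored.
    """

    def __init__(self, pool_size: int, replacement_probability: float = 0.5,
                 insertion_probability: float = 0.1,
                 rng: Optional[random.Random] = None) -> None:
        self.pool_size = pool_size
        self.replacement_probability = replacement_probability
        self.insertion_probability = insertion_probability
        self.rng = rng if rng is not None else random.Random()
        self.pool: List[Any] = []

    def get_data(self, images_fake: List[Any]) -> List[Any]:
        result = []
        for image in images_fake:
            # Possibly hand the discriminator an older fake sample instead.
            if self.pool and self.rng.random() < self.replacement_probability:
                result.append(self.rng.choice(self.pool))
            else:
                result.append(image)
            # Possibly remember the fresh sample for later iterations.
            if self.pool_size > 0 and self.rng.random() < self.insertion_probability:
                if len(self.pool) < self.pool_size:
                    self.pool.append(image)
                else:
                    # Pool is full: overwrite a random slot.
                    self.pool[self.rng.randrange(self.pool_size)] = image
        return result


pool = ImageHistoryPool(pool_size=4, rng=random.Random(0))
for step in range(3):
    batch_fake = [f"fake_{step}_{i}" for i in range(8)]
    mixed = pool.get_data(batch_fake)
    print(step, len(pool.pool), mixed[:3])
```

With a small pool and a low insertion probability, the discriminator mostly sees current samples plus the occasional older one. This is the usual motivation for such pools: it discourages the generator and discriminator from oscillating around each other's latest state.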
- class trw.layers.gan.Gan(discriminator, generator, latent_size, optimizer_discriminator_fn, optimizer_generator_fn, real_image_from_batch_fn, train_split_name='train', loss_from_outputs_fn=process_outputs_and_extract_loss, image_pool=None)
Bases: torch.nn.Module
Generic GAN implementation. Supports conditional GANs.
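For reference, the conditional form of the standard GAN objective is shown below. Here y is the observation the generator and discriminator are conditioned on (an attribute or an image), and z is a latent vector drawn from a prior p_z. Dropping y gives the simple, unconditional GAN. This page does not state which loss variant Gan optimizes, so treat this as background rather than a description of trw's implementation.

```latex
\min_G \max_D \;
  \mathbb{E}_{(x, y) \sim p_{\text{data}}}\big[\log D(x \mid y)\big]
  + \mathbb{E}_{z \sim p_z,\; y \sim p_{\text{data}}}\big[\log\big(1 - D(G(z \mid y) \mid y)\big)\big]
```

In practice the generator is commonly trained with the non-saturating loss -log D(G(z | y) | y) instead of minimizing log(1 - D(G(z | y) | y)), because the latter gives weak gradients early in training.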
Examples
- generator conditioned by concatenating a one-hot attribute to the latent, or conditioned on another image (e.g., using a UNet)
- discriminator conditioned by concatenating a one-hot image sized to the image, or by concatenating a one-hot vector to an intermediate layer
- simple GAN (i.e., no observation)
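The first two conditioning schemes in the list can be sketched with plain Python lists. The generator receives a one-hot attribute appended to its latent vector. The discriminator receives a one-hot "image": one constant H x W plane per class, meant to be stacked with the input image's channels. The helper names are hypothetical, and tensors are modeled as nested lists purely for illustration. In a PyTorch model the same idea is usually a torch.cat along the feature or channel dimension.

```python
from typing import List


def one_hot(class_index: int, nb_classes: int) -> List[float]:
    """One-hot vector with a 1.0 at class_index."""
    if not 0 <= class_index < nb_classes:
        raise ValueError("class_index out of range")
    return [1.0 if i == class_index else 0.0 for i in range(nb_classes)]


def condition_latent(latent: List[float], class_index: int, nb_classes: int) -> List[float]:
    """Hypothetical generator input: latent vector followed by the one-hot attribute."""
    return latent + one_hot(class_index, nb_classes)


def one_hot_planes(class_index: int, nb_classes: int,
                   height: int, width: int) -> List[List[List[float]]]:
    """Hypothetical discriminator conditioning: nb_classes constant height x width planes.

    Plane c is all ones if c == class_index and all zeros otherwise. The planes
    are meant to be concatenated with the image channels (channel-first layout).
    """
    return [[[value] * width for _ in range(height)]
            for value in one_hot(class_index, nb_classes)]


print(condition_latent([0.3, -1.2], class_index=1, nb_classes=3))
# prints [0.3, -1.2, 0.0, 1.0, 0.0]
```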
Notes
This module creates and manages its own optimizers (built from optimizer_discriminator_fn and optimizer_generator_fn), so the trw.train.Trainer should have optimizers_fn set to None.
- _generate_latent(self, nb_samples)
- static _merge_generator_discriminator_outputs(generator_outputs, discriminator_real_outputs, discriminator_fake_outputs)
- forward(self, batch)
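_generate_latent(nb_samples) presumably draws the nb_samples latent vectors of size latent_size that are fed to the generator. The prior is not documented on this page. The sketch below assumes the common choice of an isotropic standard normal, and it uses a hypothetical free function, generate_latent, returning plain lists rather than trw's method.

```python
import random
from typing import List, Optional


def generate_latent(nb_samples: int, latent_size: int,
                    rng: Optional[random.Random] = None) -> List[List[float]]:
    """Hypothetical: nb_samples latent vectors, each of latent_size i.i.d. N(0, 1) draws."""
    rng = rng if rng is not None else random.Random()
    return [[rng.gauss(0.0, 1.0) for _ in range(latent_size)] for _ in range(nb_samples)]


latent = generate_latent(nb_samples=4, latent_size=16, rng=random.Random(42))
print(len(latent), len(latent[0]))
# prints 4 16
```

In PyTorch the equivalent draw is a single call, torch.randn(nb_samples, latent_size), optionally with a device argument so the latent lands on the same device as the generator.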