
InfoGAN: Interpretable Representation Learning by Information Maximizing Generative Adversarial Nets

InfoGAN augments the standard GAN objectives with a mutual-information term $C_{Loss}$, which measures how well the latent codes can be reconstructed from the generated samples, weighted by $\lambda$:

$$G_{Loss} = G_{GANLoss} + \lambda C_{Loss}$$

$$D_{Loss} = D_{GANLoss} + \lambda C_{Loss}$$
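Before building the generator input below, the noise vector and the two latent codes have to be sampled. A minimal sketch, assuming the common MNIST-style setup (62-dim noise, 2 continuous codes from $U(-1, 1)$, 10 one-hot categories; all dimensions here are assumptions, not taken from the original code):

```python
import torch

batch_size = 64                 # assumed
noise_dim = 62                  # assumed
num_of_continuous_code = 2      # assumed
num_of_class_of_category = 10   # assumed (e.g. MNIST digits)

noise = torch.randn(batch_size, noise_dim)
# continuous codes c ~ Uniform(-1, 1)
continuous_code_input = torch.rand(batch_size, num_of_continuous_code) * 2 - 1
# discrete code: one-hot encoding of a uniformly sampled category
labels = torch.randint(num_of_class_of_category, (batch_size,))
discrete_code_input = torch.eye(num_of_class_of_category)[labels]
```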
```python
input_to_G = torch.cat((noise, continuous_code_input, discrete_code_input), 1)
fake = G(input_to_G)

# Real images on top, fakes below: outputs[:batch_size] is real,
# outputs[batch_size:] is fake.
input_to_D = torch.cat([images, fake])
outputs = D(input_to_D)  # outputs: 1 (real/fake) + cont_code + disc_code
```
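The slicing in the next snippet relies on D emitting its real/fake probability and both code reconstructions from a single shared head. A hypothetical layout (the MLP body and dimensions are assumptions; InfoGAN shares all layers but the last between D and the auxiliary network Q):

```python
import torch
import torch.nn as nn

class Discriminator(nn.Module):
    """Sketch: last layer emits [real/fake prob | cont codes | disc logits]."""
    def __init__(self, in_dim=784, num_cont=2, num_disc=10):
        super().__init__()
        self.body = nn.Sequential(nn.Linear(in_dim, 256), nn.LeakyReLU(0.2))
        self.head = nn.Linear(256, 1 + num_cont + num_disc)

    def forward(self, x):
        out = self.head(self.body(x.view(x.size(0), -1)))
        prob = torch.sigmoid(out[:, :1])  # real/fake probability, fed to BCELoss
        return torch.cat([prob, out[:, 1:]], dim=1)
```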

Train D

```python
D.zero_grad()

# Mutual information loss: Q's reconstruction of the codes from the fake half
continuous_code_output = outputs[batch_size:, 1:1 + num_of_continuous_code]
discrete_code_output = outputs[batch_size:, 1 + num_of_continuous_code:
                               1 + num_of_continuous_code + num_of_class_of_category]
Loss_continuous = MSE(continuous_code_output, Distribution_of_continuous_code_generator)
Loss_discrete = CrossEntropy(discrete_code_output, discrete_code_input)

Loss_D = BCE(outputs[:, 0], real_and_fake_labels) \
         + Lambda1 * Loss_continuous + Lambda2 * Loss_discrete
# retain_graph=True (retain_variables in old PyTorch) keeps the graph alive
# so the same outputs can be reused for the G update below.
Loss_D.backward(retain_graph=True)
D_optimizer.step()
```
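Both training snippets assume the criteria, targets, and optimizers below. A sketch with hypothetical settings (the Adam hyperparameters and loss weights are assumptions; using the sampled continuous code itself as the MSE target is the usual simplification of the Gaussian log-likelihood term in the InfoGAN lower bound):

```python
import torch
import torch.nn as nn

BCE = nn.BCELoss()
MSE = nn.MSELoss()                    # stands in for the continuous-code Gaussian NLL
CrossEntropy = nn.CrossEntropyLoss()

Lambda1 = 0.1  # continuous-code weight (assumed)
Lambda2 = 1.0  # discrete-code weight (assumed)

# Real images come first in input_to_D, then the fakes.
real_and_fake_labels = torch.cat([torch.ones(batch_size), torch.zeros(batch_size)])
real_labels = torch.ones(batch_size)  # G wants its fakes classified as real

# Q regresses the codes that were actually sampled for this batch.
Distribution_of_continuous_code_generator = continuous_code_input

D_optimizer = torch.optim.Adam(D.parameters(), lr=2e-4, betas=(0.5, 0.999))  # assumed
G_optimizer = torch.optim.Adam(G.parameters(), lr=2e-4, betas=(0.5, 0.999))  # assumed
```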

Train G

```python
G.zero_grad()

# G is rewarded when D labels the fake half of the batch as real,
# and it shares the two mutual-information terms with D.
Loss_G = BCE(outputs[batch_size:, 0], real_labels) \
         + Lambda1 * Loss_continuous + Lambda2 * Loss_discrete
Loss_G.backward()
G_optimizer.step()
```
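Note that this version runs D only once per iteration and reuses the same outputs (and the mutual-information losses computed from them) for both updates, which is why `Loss_D.backward()` needs `retain_graph=True`. Many InfoGAN implementations instead recompute D's outputs on a detached fake batch for the D step and on a fresh forward pass for the G step; reusing the pre-update outputs as done here is a simplification.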