where is adversarial loss for the generator? #3

Open
srikanthmalla opened this issue Jan 21, 2021 · 1 comment

Comments

@srikanthmalla

Hi @Anjaney1999,
I was looking at your code and trying to find the adversarial loss in the generator training scheme:

def gen_train(imgs, caps, cap_lens, encoder, generator, discriminator, rollout, gen_optimizer, gen_pg_criterion,

Can you let me know if it is used in your code? If not, it is needed for a GAN, right? Please let me know.

Thank you,
Srikanth

@Anjaney1999
Owner

Anjaney1999 commented Jan 22, 2021

Hey Srikanth, unlike regular GANs, this GAN outputs discrete tokens, so a regular adversarial loss will not work: the discriminator's gradients cannot flow back through the sampled tokens. Instead, the generator is trained with policy gradients, with the discriminator providing the reward signal:

rewards = rollout.get_reward(samples=fake_caps, sample_cap_lens=fake_cap_lens, hidden_states=hidden_states,
                             discriminator=discriminator, img_feats=imgs, word_index=word_index,
                             col_shape=caps.shape[1], args=args)
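
Roughly, those rewards then enter a REINFORCE-style loss for the generator. As a minimal sketch (not the exact code in this repo; the names below are placeholders):

import torch

def policy_gradient_loss(log_probs, rewards, cap_lens):
    # log_probs: (batch, seq_len) log-probabilities of the sampled caption tokens
    # rewards:   (batch, seq_len) per-token rewards from the discriminator rollout
    # cap_lens:  (batch,) lengths of the sampled captions
    seq_len = log_probs.size(1)
    # mask out padding positions beyond each caption's length
    mask = (torch.arange(seq_len, device=log_probs.device).unsqueeze(0)
            < cap_lens.unsqueeze(1)).float()
    # REINFORCE: maximising the expected reward == minimising -(reward * log-prob)
    return -(log_probs * rewards * mask).sum() / mask.sum()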

For more information, you can refer to the SeqGAN paper: https://arxiv.org/pdf/1609.05473.pdf

Also, if you have any other questions, I will try my best to explain :)

I keep procrastinating on writing a proper readme for this repo, but I aim to do that soon. Overall, the model takes ages to train and the improvements in performance are not huge; however, it was an excellent learning opportunity for me.
