Hi there, thanks so much for sharing the code - really amazing work!
When I was trying to train the model on my own, I ran into an error at line 383 in train.py ((w_rec_loss * args.lambda_w_rec_loss).backward()):
RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.
I read the model definition and the implementation looked correct to me, so I don't understand why this error was thrown. Do you have any idea what could have gone wrong?
In the meantime, to unblock myself, I modified the code a bit to run backward() for g_loss and w_rec_loss in one go (g_w_rec_loss in the following example). Does this modification make sense to you? Why did you separate the backward operations in the first place?
adv_loss, w_rec_loss, stylecode = model(None, "G")
adv_loss = adv_loss.mean()
w_rec_loss = w_rec_loss.mean()
g_loss = adv_loss * args.lambda_adv_loss
g_optim.zero_grad()
e_optim.zero_grad()
g_w_rec_loss = g_loss + w_rec_loss * args.lambda_w_rec_loss
g_w_rec_loss.backward()
gather_grad(
    g_module.parameters(), world_size
)  # Explicitly synchronize Generator parameters. There is a gradient sync bug in G.
g_optim.step()
e_optim.step()
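For reference, another way to get past the error without merging the losses would be to keep the two separate backward() calls and retain the graph on the first one - a sketch, assuming both losses come out of the same forward pass (which is what the RuntimeError suggests), reusing the names from the snippet above:
g_optim.zero_grad()
e_optim.zero_grad()
g_loss.backward(retain_graph=True)  # keep intermediate tensors alive for the second backward
(w_rec_loss * args.lambda_w_rec_loss).backward()  # reuses the retained graph
gather_grad(
    g_module.parameters(), world_size
)  # explicit gradient sync for the Generator, as in the snippet above
g_optim.step()
e_optim.step()
This only addresses the RuntimeError (at the cost of holding the graph in memory until the second backward finishes); which parameters w_rec_loss actually reaches still depends on how the graph is built inside model.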
Thanks in advance for your help!
Hi mingo-x, you can combine the two losses (g_loss and w_rec_loss) together.
I think there is no huge difference.
However, you should be aware that in the original version w_rec_loss only affects the encoder, not the generator.
Your modification makes w_rec_loss also affect the update of the generator.
Lastly, if you didn't modify any of the training code, I don't know why the RuntimeError occurs. Please check your torch version.
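To make the difference concrete, here is a minimal, self-contained toy (not the actual StyleMapGAN code; G and E are stand-ins for the generator and encoder, and the reconstruction compares E(G(w)) with w). It shows that with a single summed backward the reconstruction term also reaches the generator, and that detaching the generator output is one way to keep it encoder-only:
import torch
from torch import nn

# Toy stand-ins for the generator and encoder; not the real model.
G = nn.Linear(4, 4)
E = nn.Linear(4, 4)
w = torch.randn(8, 4)

# Reconstruction loss through the full graph: the gradient also reaches G.
w_rec_loss = (E(G(w)) - w).pow(2).mean()
w_rec_loss.backward()
print(G.weight.grad.abs().sum())  # non-zero: the generator would be updated

# Detach the generator output: only the encoder receives gradient from w_rec_loss,
# matching the original behaviour where it affects E but not G.
G.zero_grad(set_to_none=False)
E.zero_grad(set_to_none=False)
w_rec_loss = (E(G(w).detach()) - w).pow(2).mean()
w_rec_loss.backward()
print(G.weight.grad.abs().sum())  # zero: the generator is untouched
So if you want to merge the two backward() calls but keep the original update rule, detaching the generator output in the w_rec_loss path would reproduce it; with the merged g_w_rec_loss above, the generator is additionally pushed by the reconstruction term.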