Resume training #35

Open
Tomeu7 opened this issue Apr 29, 2021 · 5 comments
Labels: help wanted (Extra attention is needed)

Comments

Tomeu7 commented Apr 29, 2021

Hello, I am trying to use the SAC agent and resume training. To do that I do:

def load_model(self, actor_path, critic_path, optimizer_actor_path, optimizer_critic_path, optimizer_alpha_path):
  # assumes: import torch; from torch.optim import Adam

  policy = torch.load(actor_path)
  self.alpha = policy['alpha'].detach().item()
  self.log_alpha = torch.tensor([policy['log_alpha'].detach().item()], requires_grad=True, device=self.device)
  self.alpha_optim = Adam([self.log_alpha], lr=self.lr)  # I had to recreate alpha optim with the new log_alpha loaded

  self.policy.load_state_dict(policy['model_state_dict'])
  self.policy.train()
  self.critic.load_state_dict(torch.load(critic_path))
  self.critic.train()

  self.policy_optim.load_state_dict(torch.load(optimizer_actor_path))
  self.critic_optim.load_state_dict(torch.load(optimizer_critic_path))
  self.alpha_optim.load_state_dict(torch.load(optimizer_alpha_path))

Is this correct? The loss explodes after resuming, which is very strange.
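
For load_model above to work, the actor checkpoint has to be a dict holding 'alpha', 'log_alpha', and 'model_state_dict'. A minimal save-side counterpart under that assumption might look like the sketch below; the name save_model and the one-file-per-component layout simply mirror the snippet above and are not the repository's own API.

def save_model(self, actor_path, critic_path, optimizer_actor_path, optimizer_critic_path, optimizer_alpha_path):
    # Mirror of the load_model layout: one file per component.
    torch.save({
        'model_state_dict': self.policy.state_dict(),
        'alpha': self.log_alpha.exp().detach(),  # stored as a tensor, read back with .item()
        'log_alpha': self.log_alpha.detach(),
    }, actor_path)
    torch.save(self.critic.state_dict(), critic_path)
    torch.save(self.policy_optim.state_dict(), optimizer_actor_path)
    torch.save(self.critic_optim.state_dict(), optimizer_critic_path)
    torch.save(self.alpha_optim.state_dict(), optimizer_alpha_path)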

pranz24 (Owner) commented May 25, 2021

That shouldn't happen. Will look into it.
I might require more detail on how you resume training.
(Sorry for the late reply.)

BennetLeff commented Dec 11, 2021

Unfortunately I think this is still the case. When I reload saved parameters I get NaN in my loss function. It looks like it's coming from the QNetworks: after just the first layer, the network outputs NaNs. Also, it's not immediate; it takes anywhere from 30 to 500 steps depending on what I do. It always fails deterministically, though: if I don't change anything, it fails every time at the same number of steps.

I've printed out the state_dict for each loaded parameter. The problem seems to be that the QNetworks (and the critic optimizer) have NaN values in their serialized versions. What's strange is that I'm deserializing a checkpoint from a model that is still running, so the problem might not be the QNetworks themselves but rather their serialization.

Edit: I just ran another test, saving and loading models and checking whether they were corrupted, but couldn't find any corruption. That points the finger back at the QNetworks having an exploding-gradient problem or something similar.
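
One way to narrow this down is to scan each loaded state_dict for non-finite values immediately after torch.load, before any gradient step, so checkpoint corruption and exploding gradients can be told apart. A minimal sketch, assuming the checkpoints are separate torch.save files; the file names are placeholders:

import torch

def report_nonfinite(state_dict, name):
    # Print every tensor in a state_dict that contains NaN or Inf values.
    for key, value in state_dict.items():
        if torch.is_tensor(value) and value.is_floating_point() and not torch.isfinite(value).all():
            print(f"{name}: non-finite values in '{key}'")

report_nonfinite(torch.load("critic.pt"), "critic")  # placeholder path

# Optimizer checkpoints nest their tensors under 'state', so flatten them first.
optim_ckpt = torch.load("critic_optim.pt")  # placeholder path
for idx, slots in optim_ckpt["state"].items():
    report_nonfinite({f"param {idx}/{slot}": v for slot, v in slots.items()}, "critic optimizer")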

@BennetLeff

Sorry, I'm back without an answer, but is it possible that one of the issues is that the alpha optimizer is not saved/loaded via checkpoints?

pranz24 added the help wanted label on Dec 24, 2021
@BennetLeff

I have a bit more time to look into this. Interestingly, the critic/Q-networks are what fill up with NaN, not the policy, so it's probably not related to the temperature/alpha parameter.

@typoverflow

(Quoting @Tomeu7's original post above: "Hello, I am trying to use the SAC agent and resume training […] Is this correct? The loss explodes after resuming, which is very strange.")

Did you reload the state dicts of the target networks? If not, this might be the reason for the exploding loss.
What's more, the replay buffer may need to be stored and reloaded as well : )
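
Building on that suggestion, a rough sketch of a single checkpoint that also covers the target critic, the entropy temperature, and the replay buffer is shown below. The attribute names (policy, critic, critic_target, log_alpha, the three optimizers, device) roughly follow the SAC class in this repository with automatic entropy tuning enabled, but save_checkpoint/load_checkpoint, the key names, and the memory.buffer field are assumptions rather than the repository's own API.

import torch

def save_checkpoint(agent, memory, path):
    # Bundle everything needed to resume into one file, including the target critic.
    torch.save({
        'policy': agent.policy.state_dict(),
        'critic': agent.critic.state_dict(),
        'critic_target': agent.critic_target.state_dict(),
        'log_alpha': agent.log_alpha.detach().cpu(),
        'policy_optim': agent.policy_optim.state_dict(),
        'critic_optim': agent.critic_optim.state_dict(),
        'alpha_optim': agent.alpha_optim.state_dict(),
        'replay_buffer': memory.buffer,  # assumes a list-backed replay memory
    }, path)

def load_checkpoint(agent, memory, path):
    ckpt = torch.load(path, map_location=agent.device)
    agent.policy.load_state_dict(ckpt['policy'])
    agent.critic.load_state_dict(ckpt['critic'])
    agent.critic_target.load_state_dict(ckpt['critic_target'])
    # Copy the saved temperature into the existing log_alpha tensor so the
    # alpha optimizer keeps pointing at the same parameter object.
    with torch.no_grad():
        agent.log_alpha.copy_(ckpt['log_alpha'])
    agent.alpha = agent.log_alpha.exp().item()
    agent.policy_optim.load_state_dict(ckpt['policy_optim'])
    agent.critic_optim.load_state_dict(ckpt['critic_optim'])
    agent.alpha_optim.load_state_dict(ckpt['alpha_optim'])
    memory.buffer = ckpt['replay_buffer']

If the target critic is left at its freshly initialized values while the online critic is restored, the bootstrapped Q-targets can be badly mismatched with the restored critic, which fits the exploding loss described above.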
