-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathutils.py
68 lines (59 loc) · 2.68 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import csv
import torch
import os
def save_checkpoint(policy_old, episode, save_directory):
    """
    Save a model checkpoint.

    The state dict of ``policy_old`` is written to
    ``<save_directory>/checkpoints/checkpoint_<episode>.pth``. The
    ``checkpoints`` directory is created if it does not exist yet.

    Parameters:
    - policy_old: Policy model to be saved; must provide ``state_dict()``.
    - episode: Episode number for naming the checkpoint file.
    - save_directory: Directory under which the ``checkpoints`` folder lives.
    """
    checkpoint_dir = os.path.join(save_directory, "checkpoints")
    # torch.save does not create missing parent directories, so make sure
    # the target folder exists before writing (no-op when already present).
    os.makedirs(checkpoint_dir, exist_ok=True)
    filename = os.path.join(checkpoint_dir, f'checkpoint_{episode}.pth')
    torch.save(policy_old.state_dict(), filename)
    print(f'Checkpoint saved to \'{filename}\'')
def _checkpoint_episode(filename):
    """Return the episode number encoded in ``checkpoint_<N>.pth``, else None."""
    prefix, suffix = "checkpoint_", ".pth"
    if not (filename.startswith(prefix) and filename.endswith(suffix)):
        return None
    stem = filename[len(prefix):-len(suffix)]
    # Names such as ``checkpoint_best.pth`` carry no episode number.
    return int(stem) if stem.isdecimal() else None


def _load_into_agent(agent, path):
    """Read the state dict at ``path`` once and load it into both agent models."""
    state_dict = torch.load(path)
    # Both models receive the same weights, so one read from disk is enough.
    agent.model.load_state_dict(state_dict)
    agent.model_old.load_state_dict(state_dict)


def load_checkpoint(saves_directory, agent, start):
    """
    Load a model checkpoint if available.

    Checkpoints are looked up in ``<saves_directory>/checkpoints`` and must be
    named ``checkpoint_<episode>.pth``. When ``start`` is 0 the checkpoint
    with the highest episode number is loaded and ``agent.episode`` is set to
    that number; files whose tag is not a number are ignored for this
    selection. Otherwise ``agent.episode`` is set to ``start`` and
    ``checkpoint_<start>.pth`` is loaded if it exists. The loaded state dict
    is applied to both ``agent.model`` and ``agent.model_old``.

    If the checkpoints directory is missing the agent is returned untouched
    without any message. If it holds no checkpoint files, a message is
    printed and the agent is returned untouched.

    Parameters:
    - saves_directory: Directory containing the ``checkpoints`` folder.
    - agent: Agent object with ``model``, ``model_old`` and ``episode``.
    - start: Episode number to start from (0 means "latest available").

    Returns:
    - agent: The same agent object, updated in place.
    """
    checkpoint_dir = os.path.join(saves_directory, "checkpoints")
    # isdir (not exists) so a stray regular file named "checkpoints" does not
    # make os.listdir raise.
    if not os.path.isdir(checkpoint_dir):
        return agent

    checkpoint_files = [
        filename
        for filename in os.listdir(checkpoint_dir)
        if filename.startswith("checkpoint_") and filename.endswith(".pth")
    ]
    if not checkpoint_files:
        print("No checkpoint files found.")
        return agent

    if start == 0:
        # Map each numbered checkpoint file to its episode so the latest one
        # is chosen numerically (10 > 2), not lexicographically.
        numbered = {}
        for filename in checkpoint_files:
            episode_number = _checkpoint_episode(filename)
            if episode_number is not None:
                numbered[filename] = episode_number
        if not numbered:
            print("No checkpoint files found.")
            return agent
        latest_checkpoint = max(numbered, key=numbered.get)
        agent.episode = numbered[latest_checkpoint]
        _load_into_agent(agent, os.path.join(checkpoint_dir, latest_checkpoint))
    else:
        agent.episode = start
        checkpoint = f"checkpoint_{start}.pth"
        if checkpoint not in checkpoint_files:
            # Nothing was loaded, so do not also claim to be resuming.
            print(f"Checkpoint {checkpoint} not found. Starting from episode {start}.")
            return agent
        _load_into_agent(agent, os.path.join(checkpoint_dir, checkpoint))

    print(f'Resuming training from checkpoint \'{agent.episode}\'.')
    return agent
def write_to_csv(save_dir, filename, episode, reward):
    """
    Append one ``episode, reward`` row to a CSV file.

    The file is opened in append mode, so it is created on first use and
    earlier rows are kept.

    Parameters:
    - save_dir: Directory that holds the CSV file.
    - filename: Name of the CSV file inside ``save_dir``.
    - episode: Episode number (first column).
    - reward: Total reward for the episode (second column).
    """
    csv_path = os.path.join(save_dir, filename)
    row = [episode, reward]
    # newline='' leaves line-ending handling to the csv module.
    with open(csv_path, mode='a', newline='') as handle:
        csv.writer(handle).writerow(row)