Merge pull request #1 from opimwue/cuda_support
Cuda support
majoma7 authored Aug 15, 2024
2 parents 45ec4b4 + 49e85f0 commit 6b38918
Showing 11 changed files with 124 additions and 357 deletions.
2 changes: 2 additions & 0 deletions ddopnew/_modidx.py
@@ -180,6 +180,8 @@
                 'ddopnew/agents/newsvendor/erm.py'),
                 'ddopnew.agents.newsvendor.erm.SGDBaseAgent.set_dataloader': ( '41_NV_agents/nv_erm_agents.html#sgdbaseagent.set_dataloader',
                 'ddopnew/agents/newsvendor/erm.py'),
+                'ddopnew.agents.newsvendor.erm.SGDBaseAgent.set_device': ( '41_NV_agents/nv_erm_agents.html#sgdbaseagent.set_device',
+                'ddopnew/agents/newsvendor/erm.py'),
                 'ddopnew.agents.newsvendor.erm.SGDBaseAgent.set_learning_rate_scheduler': ( '41_NV_agents/nv_erm_agents.html#sgdbaseagent.set_learning_rate_scheduler',
                 'ddopnew/agents/newsvendor/erm.py'),
                 'ddopnew.agents.newsvendor.erm.SGDBaseAgent.set_loss_function': ( '41_NV_agents/nv_erm_agents.html#sgdbaseagent.set_loss_function',
20 changes: 19 additions & 1 deletion ddopnew/agents/newsvendor/erm.py
@@ -52,7 +52,7 @@ def __init__(self,
         dataloader_params = dataloader_params or {"batch_size": 32, "shuffle": True}
         self.torch_obsprocessors = torch_obsprocessors or []

-        self.device = device
+        self.device = self.set_device(device)

         self.set_dataloader(dataloader, dataloader_params)
         self.set_model(input_shape, output_shape)
@@ -62,6 +62,24 @@

         super().__init__(environment_info = environment_info, obsprocessors = obsprocessors, agent_name = agent_name)

+        self.to(self.device)
+
+    def set_device(self, device: str):
+
+        """ Set the device for the model """
+
+        if device == "cuda":
+            if torch.cuda.is_available():
+                return "cuda"
+            else:
+                logging.warning("CUDA is not available. Using CPU instead.")
+                return "cpu"
+        elif device == "cpu":
+            return "cpu"
+        else:
+            raise ValueError(f"Device {device} is not currently supported, use 'cuda' or 'cpu'")
+
+
     def set_dataloader(self,
                        dataloader: BaseDataLoader,
                        dataloader_params: dict, # dict with keys: batch_size, shuffle
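Note: the device-fallback logic added to SGDBaseAgent above can be read in isolation. A minimal standalone sketch (a module-level function rather than a method, for illustration only):

import logging
import torch

def set_device(device: str) -> str:
    # Fall back to CPU when CUDA is requested but not available.
    if device == "cuda":
        if torch.cuda.is_available():
            return "cuda"
        logging.warning("CUDA is not available. Using CPU instead.")
        return "cpu"
    elif device == "cpu":
        return "cpu"
    else:
        raise ValueError(f"Device {device} is not currently supported, use 'cuda' or 'cpu'")

print(set_device("cuda"))  # on a machine without CUDA: logs a warning, prints "cpu"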
8 changes: 6 additions & 2 deletions ddopnew/agents/rl/mushroom_rl.py
@@ -128,7 +128,8 @@ def predict_(self, observation: np.ndarray) -> np.ndarray: #
         """ Do one forward pass of the model directly and return the prediction
         Overwrite for agents that have additional steps such as SAC"""

-        observation = torch.tensor(observation, dtype=torch.float32).to(self.device)
+        device = next(self.actor.parameters()).device
+        observation = torch.tensor(observation, dtype=torch.float32).to(device)
         action = self.actor.forward(observation)
         action = action.cpu().detach().numpy()

@@ -137,12 +138,15 @@ def predict_(self, observation: np.ndarray) -> np.ndarray: #
     def train(self):
         """set the internal state of the agent and its model to train"""
         self.mode = "train"
+        for network in self.network_list:
+            network.train()

     def eval(self):
         """set the internal state of the agent and its model to eval"""
         self.mode = "eval"
+        for network in self.network_list:
+            network.eval()


     def to(self, device: str): #
         """Move the model to the specified device"""

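Note: the predict_ change reads the device from the network's own parameters instead of a cached attribute, so the lookup stays correct even if the model is moved between devices after construction. A self-contained sketch of the pattern (the small Sequential network is a hypothetical stand-in for the actor):

import numpy as np
import torch

model = torch.nn.Sequential(torch.nn.Linear(4, 8), torch.nn.ReLU(), torch.nn.Linear(8, 2))

def predict(model: torch.nn.Module, observation: np.ndarray) -> np.ndarray:
    device = next(model.parameters()).device  # infer the device from the model itself
    obs = torch.tensor(observation, dtype=torch.float32).to(device)
    return model(obs).cpu().detach().numpy()

print(predict(model, np.random.rand(4).astype(np.float32)))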
10 changes: 4 additions & 6 deletions ddopnew/agents/rl/sac.py
@@ -156,14 +156,14 @@ def __init__(self,
         logging.info("Actor network (mu network):")
         if logging.getLogger().isEnabledFor(logging.INFO):
             input_size = self.add_batch_dimension_for_shape(actor_mu_params["input_shape"])
-            print(summary(self.actor, input_size=input_size))
+            print(summary(self.actor, input_size=input_size, device=self.device))
             time.sleep(0.2)

         logging.info("################################################################################")
         logging.info("Critic network:")
         if logging.getLogger().isEnabledFor(logging.INFO):
             input_size = self.add_batch_dimension_for_shape(critic_params["input_shape"])
-            print(summary(self.critic, input_size=input_size))
+            print(summary(self.critic, input_size=input_size, device=self.device))

     def get_network_list(self, set_actor_critic_attributes: bool = True):
         """ Get the list of networks in the agent for the save and load functions
@@ -189,14 +189,12 @@ def predict_(self, observation: np.ndarray) -> np.ndarray: #
         Apply tanh as implemented for the SAC actor in mushroom_rl"""

         # make observation torch tensor
+        device = next(self.actor.parameters()).device
+        observation = torch.tensor(observation, dtype=torch.float32).to(device)

-        observation = torch.tensor(observation, dtype=torch.float32).to(self.device)
         action = self.actor.forward(observation)
-        # print("a before tanh: ", action)
         action = torch.tanh(action)
-        # print("a after tanh: ", action)
         action = action * self.agent.policy._delta_a + self.agent.policy._central_a
-        # print("a after scaling: ", action)
         action = action.cpu().detach().numpy()

         return action
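Note: predict_ squashes the unbounded actor output with tanh and then rescales it with the policy's _delta_a and _central_a. Assuming, per the mushroom_rl naming, that these hold the half-width and midpoint of the action interval, the arithmetic works out as in this sketch (the bounds are hypothetical):

import numpy as np

low, high = 0.0, 10.0           # hypothetical action bounds
central_a = (high + low) / 2    # midpoint of the interval: 5.0
delta_a = (high - low) / 2      # half-width of the interval: 5.0

raw_action = 2.3                                     # unbounded actor output
action = np.tanh(raw_action) * delta_a + central_a   # tanh maps to (-1, 1), then rescale
print(action)                                        # ~9.90, inside (0.0, 10.0)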
15 changes: 9 additions & 6 deletions ddopnew/datasets.py
@@ -14,7 +14,7 @@
 import zipfile


-# %% ../nbs/80_datasets/datasets.ipynb 7
+# %% ../nbs/80_datasets/datasets.ipynb 6
 def get_all_release_tags(token=None):
     url = "https://api.github.com/repos/opimwue/ddopnew/releases"
     headers = {'Authorization': f'Bearer {token}'} if token else {}
@@ -89,21 +89,24 @@ def unzip_file(zip_file_path, output_dir, delete_zip_file=True):
         os.remove(zip_file_path)

 def load_data_from_directory(dir):
-    data = list()
+    data = dict()
     for file in os.listdir(dir):
         if file.endswith(".csv"):
-            data.append(pd.read_csv(os.path.join(dir, file)))
+            key = os.path.splitext(file)[0]
+            data[key] = pd.read_csv(os.path.join(dir, file))
         elif file.endswith(".pkl"):
-            data.append(pd.read_pickle(os.path.join(dir, file)))
+            key = os.path.splitext(file)[0]
+            data[key] = pd.read_pickle(os.path.join(dir, file))
         elif file.endswith(".npy"):
-            data.append(np.load(os.path.join(dir, file)))
+            key = os.path.splitext(file)[0]
+            data[key] = np.load(os.path.join(dir, file))
         else:
             raise ValueError(f"File {file} is not a valid file type (csv, pkl, or npy)")

     return data


-# %% ../nbs/80_datasets/datasets.ipynb 9
+# %% ../nbs/80_datasets/datasets.ipynb 8
 class DatasetLoader():

     """
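Note: load_data_from_directory now returns a dict keyed by file stem rather than a list in os.listdir order, so callers can fetch each dataset by name. A usage sketch, assuming a directory containing hypothetical files demand.csv and sku_features.npy:

from ddopnew.datasets import load_data_from_directory

data = load_data_from_directory("data/my_dataset")   # hypothetical path
demand_df = data["demand"]        # DataFrame loaded from demand.csv
features = data["sku_features"]   # array loaded from sku_features.npy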
1 change: 1 addition & 0 deletions ddopnew/experiment_functions.py
@@ -380,6 +380,7 @@ def run_experiment( agent: BaseAgent,
     for epoch in trange(n_epochs):

         env.set_return_truncation(False) # For mushroom Core to work, the step function should not return the truncation flag
+        agent.train()
         core.learn(n_steps=n_steps, n_steps_per_fit=n_steps_per_fit, quiet=True)
         env.set_return_truncation(True) # Set back to standard gymnasium behavior

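Note: agent.train() is now called once per epoch before core.learn, and with the mushroom_rl.py change above it propagates to every network in network_list. This matters because layers such as Dropout and BatchNorm behave differently in train and eval mode, as this minimal sketch shows:

import torch

net = torch.nn.Sequential(torch.nn.Linear(4, 16), torch.nn.Dropout(p=0.5))
x = torch.ones(1, 4)

net.train()
with torch.no_grad():
    print(torch.equal(net(x), net(x)))  # usually False: dropout masks are random

net.eval()
with torch.no_grad():
    print(torch.equal(net(x), net(x)))  # True: dropout is a no-op in eval mode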
1 change: 1 addition & 0 deletions nbs/30_experiment_functions/10_experiment_functions.ipynb
@@ -575,6 +575,7 @@
     "    for epoch in trange(n_epochs):\n",
     "\n",
     "        env.set_return_truncation(False) # For mushroom Core to work, the step function should not return the truncation flag\n",
+    "        agent.train()\n",
     "        core.learn(n_steps=n_steps, n_steps_per_fit=n_steps_per_fit, quiet=True)\n",
     "        env.set_return_truncation(True) # Set back to standard gymnasium behavior\n",
     "\n",