# Basic environment demo: play ten CartPole episodes with a uniformly random
# policy while rendering every step in an on-screen window.
env = gym.make('CartPole-v1', render_mode="human")

# The seed is supplied only on this first reset; the per-episode resets below pass none.
observation, info = env.reset(seed=42)

for episode in range(10):
    observation, info = env.reset()
    while True:
        # sample an action uniformly from the action space
        action = env.action_space.sample()
        observation, reward, terminated, truncated, info = env.step(action)
        # an episode is over when it either terminates or is truncated
        done = terminated or truncated
        if done:
            break

env.close()
'r_sum': 25.0, 'r_mean': 1.0, 'r_std': 0.0, 'length_': 34.0, 'r_sum_': 204.0, 'r_mean_': 34.0, 'r_std_': 17.281975195754296}}\n", + "{'recorder': {'idx': [0, 2, 4, 6, 8], 'length': [29, 69, 39, 14, 28], 'r_sum': [29.0, 69.0, 39.0, 14.0, 28.0], 'r_mean': [1.0, 1.0, 1.0, 1.0, 1.0], 'r_std': [0.0, 0.0, 0.0, 0.0, 0.0], 'length_': [29.0, 49.0, 45.666666666666664, 37.75, 35.8], 'r_sum_': [29.0, 98.0, 137.0, 151.0, 179.0], 'r_mean_': [29.0, 49.0, 45.666666666666664, 37.75, 35.8], 'r_std_': [0.0, 20.0, 16.996731711975958, 20.116846174288852, 18.410866356584094]}}\n" + ] + }, + { + "data": { + "image/png": "", + "text/plain": [ + "
# rldurham demo: wrap CartPole in a Recorder, run eleven random-policy episodes,
# and print/track/plot the statistics that arrive through `info`.
env = rld.make('CartPole-v1', render_mode="rgb_array")  # drop-in for gym.make
env = rld.Recorder(                     # record statistics (returned in info) and videos
    env,
    smoothing=10,                       # rolling averages
    video=True,                         # record videos
    video_folder="videos",              # folder for videos
    video_prefix="xxxx00-agent-video",  # prefix for videos
    logs=True,                          # keep logs
)
seed, observation, info = rld.seed_everything(42, env)  # seed everything (python, numpy, pytorch, env)
tracker = rld.InfoTracker()  # track statistics, e.g., for plotting

# per-series plot styling passed to tracker.plot
length_style = {'linestyle': '--', 'marker': 'o'}
r_sum_style = {'linestyle': '', 'marker': 'x'}

for episode in range(11):
    # Both switches are assigned before reset (the video flag is checked on reset).
    env.info = (episode % 2 == 0)   # track every other episode
    env.video = (episode % 4 == 0)  # record every fourth episode

    observation, info = env.reset()
    while True:
        action = env.action_space.sample()  # random action
        observation, reward, terminated, truncated, info = env.step(action)
        done = terminated or truncated
        if done:
            break

    # Episode finished: show the final info, then the tracker's state *before*
    # this episode is added, then track and plot.
    print(info)
    print(tracker.info)
    tracker.track(info)
    tracker.plot(r_mean_=True, r_std_=True, length=length_style, r_sum=r_sum_style)

env.close()  # important (e.g. triggers last video save)
env.write_log(folder="logs", file="xxxx00-agent-log.txt")
class Agent(torch.nn.Module):
    """Baseline agent that acts at random and never learns.

    It exposes the interface the training loop expects (``prob_action``,
    ``sample_action``, ``put_data``, ``train``) so that a learning agent can
    later be swapped in without touching the loop.

    Args:
        env: environment providing ``action_space`` and ``observation_space``.
            The action space counts as discrete when it has an ``n`` attribute;
            otherwise its ``shape[0]`` is used as the action dimension.
    """

    def __init__(self, env):
        super().__init__()
        self.discrete = hasattr(env.action_space, 'n')
        self.obs_dim = env.observation_space.shape[0]
        # Decide from *this* env via self.discrete. Reading a notebook-level
        # `discrete` here raised NameError when no earlier cell defined it and
        # gave the wrong act_dim when that global described a different env.
        self.act_dim = env.action_space.n if self.discrete else env.action_space.shape[0]

    def prob_action(self, obs):
        """Return a uniform probability vector of length ``act_dim`` (``obs`` is ignored)."""
        return np.ones(self.act_dim) / self.act_dim

    def sample_action(self, prob):
        """Sample an action.

        Discrete: draw an index in ``[0, act_dim)`` according to ``prob``.
        Continuous: ``prob`` is ignored and each component is drawn uniformly
        from [-1, 1), independent of the action space's actual bounds.
        """
        if self.discrete:
            return np.random.choice(self.act_dim, p=prob)
        return np.random.uniform(-1.0, 1.0, size=self.act_dim)

    def train(self):
        """Policy update hook; a no-op for the random agent.

        NOTE: this shadows ``torch.nn.Module.train(mode)``, so ``agent.eval()``
        (which calls ``self.train(False)``) will not work on this class.
        """
        return

    def put_data(self, item):
        """Experience hook; the random agent discards ``item``."""
        return
"execution_count": 9, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "", + "text/plain": [ + "
# Train `Agent` on CartPole for 201 episodes: collect one episode at a time,
# hand every transition to the agent, update the agent after each episode, and
# plot the tracked return statistics every 10th episode.

# init environment
env = rld.make('CartPole-v1')
env = rld.Recorder(env, smoothing=10)

# Derive the action-space type from the env created in *this* cell. The loop
# previously read the notebook global `discrete`, which is set in the
# "Different environments" cell and may describe another environment.
is_discrete = hasattr(env.action_space, 'n')

# seed
rld.seed_everything(seed=42, env=env)

# init agent
agent = Agent(env)

# training procedure
tracker = rld.InfoTracker()
for episode in range(201):
    obs, info = env.reset()
    done = False

    # get episode
    while not done:
        # select action
        prob = agent.prob_action(obs)
        action = agent.sample_action(prob)

        # take action in environment
        next_obs, reward, terminated, truncated, info = env.step(action)
        done = terminated or truncated
        # store (reward, probability of the chosen action); continuous agents
        # have no per-action probability, so None is stored instead
        agent.put_data((reward, prob[action] if is_discrete else None))
        obs = next_obs

    # track and plot
    tracker.track(info)
    if episode % 10 == 0:
        tracker.plot(r_mean_=True, r_std_=True)

    # update agent's policy
    agent.train()
env.close()
class Agent(torch.nn.Module):
    """REINFORCE policy-gradient agent for discrete action spaces.

    The policy is a two-layer MLP (hidden width 128) that maps an observation
    vector to a softmax distribution over actions. Transitions are buffered
    with ``put_data`` and consumed by one optimizer step in ``train``.

    Args:
        env: optional environment. When given, the network sizes are read from
            ``env.observation_space.shape[0]`` and ``env.action_space.n``, so
            the class can be built as ``Agent(env)`` like the random baseline.
            When omitted, the notebook-level globals ``obs_dim`` and
            ``act_dim`` are used (the original behaviour of ``Agent()``).
        lr: Adam learning rate (default 0.0002, the original constant).
        gamma: discount factor for the return (default 0.98, the original
            constant).
    """

    def __init__(self, env=None, lr=0.0002, gamma=0.98):
        super().__init__()
        if env is None:
            # backward-compatible path: sizes come from notebook globals
            n_obs, n_act = obs_dim, act_dim
        else:
            n_obs = env.observation_space.shape[0]
            n_act = env.action_space.n
        self.gamma = gamma
        self.data = []  # (reward, prob-of-chosen-action) pairs for the current episode
        # Fully qualified torch names: the notebook only runs `import torch`,
        # so the bare aliases nn / optim / F / Categorical were undefined.
        self.fc1 = torch.nn.Linear(n_obs, 128)
        self.fc2 = torch.nn.Linear(128, n_act)
        self.optimizer = torch.optim.Adam(self.parameters(), lr=lr)

    def sample_action(self, prob):
        """Draw an action index from the categorical distribution ``prob`` and return it as an int."""
        return torch.distributions.Categorical(prob).sample().item()

    def prob_action(self, s):
        """Return the action-probability tensor for observation ``s`` (a NumPy array)."""
        hidden = torch.nn.functional.relu(self.fc1(torch.from_numpy(s).float()))
        return torch.nn.functional.softmax(self.fc2(hidden), dim=0)

    def put_data(self, item):
        """Append one ``(reward, prob_of_chosen_action)`` pair to the episode buffer."""
        self.data.append(item)

    def train(self):
        """Run one REINFORCE update over the buffered episode, then clear the buffer.

        The buffer is walked backwards so the discounted return ``R`` can be
        accumulated incrementally; gradients of ``-log(prob) * R`` are summed
        over all steps before a single optimizer step.

        NOTE: this shadows ``torch.nn.Module.train(mode)``, so ``agent.eval()``
        (which calls ``self.train(False)``) will not work on this class.
        """
        R = 0
        self.optimizer.zero_grad()
        for r, prob in self.data[::-1]:
            R = r + self.gamma * R
            loss = -torch.log(prob) * R
            loss.backward()
        self.optimizer.step()
        self.data = []
# Smoke-test the bandit environment: ten single-pull episodes with random
# actions, printing the chosen arm and the reward it paid out.

# initialise
env = BanditEnv()
env.spec = gym.envs.registration.EnvSpec('BanditEnv-v0', max_episode_steps=5)

# notebook-level settings read by other cells (agent sizes, episode budget)
discrete = True
obs_dim = 1
act_dim = len(env.p_dist)
max_episodes = 1000

for episode in range(10):
    observation, info = env.reset()
    while True:
        action = env.action_space.sample()  # random action
        observation, reward, terminated, truncated, info = env.step(action)
        print(action, reward)
        done = terminated or truncated
        if done:
            break

env.close()
"pygments_lexer": "ipython3", + "version": "3.8.20" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +}