diff --git a/README.md b/README.md index 18c9dfe5e..6f0d94a92 100644 --- a/README.md +++ b/README.md @@ -44,9 +44,8 @@ and closing the gap between simulation and the real world. Explore Brax easily and quickly through a series of colab notebooks: * [Brax Basics](https://colab.research.google.com/github/google/brax/blob/main/notebooks/basics.ipynb) introduces the Brax API, and shows how to simulate basic physics primitives. -* [Brax Training](https://colab.research.google.com/github/google/brax/blob/main/notebooks/training.ipynb) -introduces the Brax v2 API, and shows how to train a policy with the -generalized backend. +* [Brax Training](https://colab.research.google.com/github/google/brax/blob/main/notebooks/training.ipynb) introduces Brax's training algorithms, and lets you train your own policies directly within the colab. It also demonstrates loading and saving policies. +* [Brax Training with PyTorch on GPU](https://colab.research.google.com/github/google/brax/blob/main/notebooks/training_torch.ipynb) demonstrates how Brax can be used in other ML frameworks for fast training, in this case PyTorch. ## Using Brax Locally diff --git a/brax/training/agents/es/train.py b/brax/training/agents/es/train.py index ef7d4be21..3b9fbeb4a 100644 --- a/brax/training/agents/es/train.py +++ b/brax/training/agents/es/train.py @@ -200,7 +200,7 @@ def compute_delta( Returns: """ - # NOTE - -> len(weights) * perturbation_std" is + # NOTE: The trick "len(weights) -> len(weights) * perturbation_std" is # equivalent to tuning the l2_coef. weights = jnp.reshape(weights, ([population_size] + [1] * (noise.ndim - 1))) delta = jnp.sum(noise * weights, axis=0) / population_size diff --git a/brax/v1/experimental/composer/agent_utils.py b/brax/v1/experimental/composer/agent_utils.py index ebefc2713..2d4ba49bb 100644 --- a/brax/v1/experimental/composer/agent_utils.py +++ b/brax/v1/experimental/composer/agent_utils.py @@ -33,7 +33,7 @@ e.g. 
equivalent to agent1=(..., action_agents=('agent1',), ...) agent_groups currently defines which rewards/actions belong to which agent. -observation is the same among all agents (TODO -. +observation is the same among all agents (TODO: add optionality). """ from collections import OrderedDict as odict diff --git a/notebooks/training_torch.ipynb b/notebooks/training_torch.ipynb new file mode 100644 index 000000000..a7de3c4dd --- /dev/null +++ b/notebooks/training_torch.ipynb @@ -0,0 +1,556 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "provenance": [], + "gpuType": "A100" + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "accelerator": "GPU", + "gpuClass": "standard" + }, + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "trVNqxHmGISS" + }, + "source": [ + "# Training in Brax with PyTorch on GPUs\n", + "\n", + "Brax is ready to integrate into other research toolkits by way of the [OpenAI Gym](https://gym.openai.com/) interface. Brax environments convert to Gym environments using either [GymWrapper](https://github.com/google/brax/blob/main/brax/envs/wrappers/gym.py) for single environments, or [VectorGymWrapper](https://github.com/google/brax/blob/main/brax/envs/wrappers/gym.py) for batched (parallelized) environments." 
+ ] + }, + { + "cell_type": "code", + "metadata": { + "id": "GJhPpM5ZPrpq" + }, + "source": [ + "#@title Import Brax and some helper modules\n", + "from IPython.display import clear_output\n", + "\n", + "import collections\n", + "from datetime import datetime\n", + "import functools\n", + "import math\n", + "import os\n", + "import time\n", + "from typing import Any, Callable, Dict, Optional, Sequence\n", + "\n", + "try:\n", + " import brax\n", + "except ImportError:\n", + " !pip install git+https://github.com/google/brax.git@main\n", + " clear_output()\n", + " import brax\n", + "\n", + "from brax import envs\n", + "from brax.envs.wrappers import gym as gym_wrapper\n", + "from brax.envs.wrappers import torch as torch_wrapper\n", + "from brax.io import metrics\n", + "from brax.training.agents.ppo import train as ppo\n", + "import gym\n", + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "import torch\n", + "from torch import nn\n", + "from torch import optim\n", + "import torch.nn.functional as F" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "vQFCkfu8Qwre" + }, + "source": [ + "Here is a PPO Agent written in PyTorch:" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "fWJE4b5BHeH7" + }, + "source": [ + "class Agent(nn.Module):\n", + " \"\"\"Standard PPO Agent with GAE and observation normalization.\"\"\"\n", + "\n", + " def __init__(self,\n", + " policy_layers: Sequence[int],\n", + " value_layers: Sequence[int],\n", + " entropy_cost: float,\n", + " discounting: float,\n", + " reward_scaling: float,\n", + " device: str):\n", + " super(Agent, self).__init__()\n", + "\n", + " policy = []\n", + " for w1, w2 in zip(policy_layers, policy_layers[1:]):\n", + " policy.append(nn.Linear(w1, w2))\n", + " policy.append(nn.SiLU())\n", + " policy.pop() # drop the final activation\n", + " self.policy = nn.Sequential(*policy)\n", + "\n", + " value = []\n", + " for w1, w2 in zip(value_layers, 
value_layers[1:]):\n", + " value.append(nn.Linear(w1, w2))\n", + " value.append(nn.SiLU())\n", + " value.pop() # drop the final activation\n", + " self.value = nn.Sequential(*value)\n", + "\n", + " self.num_steps = torch.zeros((), device=device)\n", + " self.running_mean = torch.zeros(policy_layers[0], device=device)\n", + " self.running_variance = torch.zeros(policy_layers[0], device=device)\n", + "\n", + " self.entropy_cost = entropy_cost\n", + " self.discounting = discounting\n", + " self.reward_scaling = reward_scaling\n", + " self.lambda_ = 0.95\n", + " self.epsilon = 0.3\n", + " self.device = device\n", + "\n", + " @torch.jit.export\n", + " def dist_create(self, logits):\n", + " \"\"\"Normal followed by tanh.\n", + "\n", + " torch.distribution doesn't work with torch.jit, so we roll our own.\"\"\"\n", + " loc, scale = torch.split(logits, logits.shape[-1] // 2, dim=-1)\n", + " scale = F.softplus(scale) + .001\n", + " return loc, scale\n", + "\n", + " @torch.jit.export\n", + " def dist_sample_no_postprocess(self, loc, scale):\n", + " return torch.normal(loc, scale)\n", + "\n", + " @classmethod\n", + " def dist_postprocess(cls, x):\n", + " return torch.tanh(x)\n", + "\n", + " @torch.jit.export\n", + " def dist_entropy(self, loc, scale):\n", + " log_normalized = 0.5 * math.log(2 * math.pi) + torch.log(scale)\n", + " entropy = 0.5 + log_normalized\n", + " entropy = entropy * torch.ones_like(loc)\n", + " dist = torch.normal(loc, scale)\n", + " log_det_jacobian = 2 * (math.log(2) - dist - F.softplus(-2 * dist))\n", + " entropy = entropy + log_det_jacobian\n", + " return entropy.sum(dim=-1)\n", + "\n", + " @torch.jit.export\n", + " def dist_log_prob(self, loc, scale, dist):\n", + " log_unnormalized = -0.5 * ((dist - loc) / scale).square()\n", + " log_normalized = 0.5 * math.log(2 * math.pi) + torch.log(scale)\n", + " log_det_jacobian = 2 * (math.log(2) - dist - F.softplus(-2 * dist))\n", + " log_prob = log_unnormalized - log_normalized - log_det_jacobian\n", + " 
return log_prob.sum(dim=-1)\n", + "\n", + " @torch.jit.export\n", + " def update_normalization(self, observation):\n", + " self.num_steps += observation.shape[0] * observation.shape[1]\n", + " input_to_old_mean = observation - self.running_mean\n", + " mean_diff = torch.sum(input_to_old_mean / self.num_steps, dim=(0, 1))\n", + " self.running_mean = self.running_mean + mean_diff\n", + " input_to_new_mean = observation - self.running_mean\n", + " var_diff = torch.sum(input_to_new_mean * input_to_old_mean, dim=(0, 1))\n", + " self.running_variance = self.running_variance + var_diff\n", + "\n", + " @torch.jit.export\n", + " def normalize(self, observation):\n", + " variance = self.running_variance / (self.num_steps + 1.0)\n", + " variance = torch.clip(variance, 1e-6, 1e6)\n", + " return ((observation - self.running_mean) / variance.sqrt()).clip(-5, 5)\n", + "\n", + " @torch.jit.export\n", + " def get_logits_action(self, observation):\n", + " observation = self.normalize(observation)\n", + " logits = self.policy(observation)\n", + " loc, scale = self.dist_create(logits)\n", + " action = self.dist_sample_no_postprocess(loc, scale)\n", + " return logits, action\n", + "\n", + " @torch.jit.export\n", + " def compute_gae(self, truncation, termination, reward, values,\n", + " bootstrap_value):\n", + " truncation_mask = 1 - truncation\n", + " # Append bootstrapped value to get [v1, ..., v_t+1]\n", + " values_t_plus_1 = torch.cat(\n", + " [values[1:], torch.unsqueeze(bootstrap_value, 0)], dim=0)\n", + " deltas = reward + self.discounting * (\n", + " 1 - termination) * values_t_plus_1 - values\n", + " deltas *= truncation_mask\n", + "\n", + " acc = torch.zeros_like(bootstrap_value)\n", + " vs_minus_v_xs = torch.zeros_like(truncation_mask)\n", + "\n", + " for ti in range(truncation_mask.shape[0]):\n", + " ti = truncation_mask.shape[0] - ti - 1\n", + " acc = deltas[ti] + self.discounting * (\n", + " 1 - termination[ti]) * truncation_mask[ti] * self.lambda_ * acc\n", + " 
vs_minus_v_xs[ti] = acc\n", + "\n", + " # Add V(x_s) to get v_s.\n", + " vs = vs_minus_v_xs + values\n", + " vs_t_plus_1 = torch.cat([vs[1:], torch.unsqueeze(bootstrap_value, 0)], 0)\n", + " advantages = (reward + self.discounting *\n", + " (1 - termination) * vs_t_plus_1 - values) * truncation_mask\n", + " return vs, advantages\n", + "\n", + " @torch.jit.export\n", + " def loss(self, td: Dict[str, torch.Tensor]):\n", + " observation = self.normalize(td['observation'])\n", + " policy_logits = self.policy(observation[:-1])\n", + " baseline = self.value(observation)\n", + " baseline = torch.squeeze(baseline, dim=-1)\n", + "\n", + " # Use last baseline value (from the value function) to bootstrap.\n", + " bootstrap_value = baseline[-1]\n", + " baseline = baseline[:-1]\n", + " reward = td['reward'] * self.reward_scaling\n", + " termination = td['done'] * (1 - td['truncation'])\n", + "\n", + " loc, scale = self.dist_create(td['logits'])\n", + " behaviour_action_log_probs = self.dist_log_prob(loc, scale, td['action'])\n", + " loc, scale = self.dist_create(policy_logits)\n", + " target_action_log_probs = self.dist_log_prob(loc, scale, td['action'])\n", + "\n", + " with torch.no_grad():\n", + " vs, advantages = self.compute_gae(\n", + " truncation=td['truncation'],\n", + " termination=termination,\n", + " reward=reward,\n", + " values=baseline,\n", + " bootstrap_value=bootstrap_value)\n", + "\n", + " rho_s = torch.exp(target_action_log_probs - behaviour_action_log_probs)\n", + " surrogate_loss1 = rho_s * advantages\n", + " surrogate_loss2 = rho_s.clip(1 - self.epsilon,\n", + " 1 + self.epsilon) * advantages\n", + " policy_loss = -torch.mean(torch.minimum(surrogate_loss1, surrogate_loss2))\n", + "\n", + " # Value function loss\n", + " v_error = vs - baseline\n", + " v_loss = torch.mean(v_error * v_error) * 0.5 * 0.5\n", + "\n", + " # Entropy reward\n", + " entropy = torch.mean(self.dist_entropy(loc, scale))\n", + " entropy_loss = self.entropy_cost * -entropy\n", + "\n", + " 
return policy_loss + v_loss + entropy_loss" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "CWbuk7IAR0SU" + }, + "source": [ + "Finally, some code for unrolling and batching environment data:" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "D3y5o7-oSBm-" + }, + "source": [ + "StepData = collections.namedtuple(\n", + " 'StepData',\n", + " ('observation', 'logits', 'action', 'reward', 'done', 'truncation'))\n", + "\n", + "\n", + "def sd_map(f: Callable[..., torch.Tensor], *sds) -> StepData:\n", + " \"\"\"Map a function over each field in StepData.\"\"\"\n", + " items = {}\n", + " keys = sds[0]._asdict().keys()\n", + " for k in keys:\n", + " items[k] = f(*[sd._asdict()[k] for sd in sds])\n", + " return StepData(**items)\n", + "\n", + "\n", + "def eval_unroll(agent, env, length):\n", + " \"\"\"Return number of episodes and average reward for a single unroll.\"\"\"\n", + " observation = env.reset()\n", + " episodes = torch.zeros((), device=agent.device)\n", + " episode_reward = torch.zeros((), device=agent.device)\n", + " for _ in range(length):\n", + " _, action = agent.get_logits_action(observation)\n", + " observation, reward, done, _ = env.step(Agent.dist_postprocess(action))\n", + " episodes += torch.sum(done)\n", + " episode_reward += torch.sum(reward)\n", + " return episodes, episode_reward / episodes\n", + "\n", + "\n", + "def train_unroll(agent, env, observation, num_unrolls, unroll_length):\n", + " \"\"\"Return step data over multiple unrolls.\"\"\"\n", + " sd = StepData([], [], [], [], [], [])\n", + " for _ in range(num_unrolls):\n", + " one_unroll = StepData([observation], [], [], [], [], [])\n", + " for _ in range(unroll_length):\n", + " logits, action = agent.get_logits_action(observation)\n", + " observation, reward, done, info = env.step(Agent.dist_postprocess(action))\n", + " one_unroll.observation.append(observation)\n", + " one_unroll.logits.append(logits)\n", + "
one_unroll.action.append(action)\n", + " one_unroll.reward.append(reward)\n", + " one_unroll.done.append(done)\n", + " one_unroll.truncation.append(info['truncation'])\n", + " one_unroll = sd_map(torch.stack, one_unroll)\n", + " sd = sd_map(lambda x, y: x + [y], sd, one_unroll)\n", + " td = sd_map(torch.stack, sd)\n", + " return observation, td\n", + "\n", + "\n", + "def train(\n", + " env_name: str = 'ant',\n", + " num_envs: int = 2048,\n", + " episode_length: int = 1000,\n", + " device: str = 'cuda',\n", + " num_timesteps: int = 30_000_000,\n", + " eval_frequency: int = 10,\n", + " unroll_length: int = 5,\n", + " batch_size: int = 1024,\n", + " num_minibatches: int = 32,\n", + " num_update_epochs: int = 4,\n", + " reward_scaling: float = .1,\n", + " entropy_cost: float = 1e-2,\n", + " discounting: float = .97,\n", + " learning_rate: float = 3e-4,\n", + " progress_fn: Optional[Callable[[int, Dict[str, Any]], None]] = None,\n", + "):\n", + " \"\"\"Trains a policy via PPO.\"\"\"\n", + " env = envs.create(env_name, batch_size=num_envs,\n", + " episode_length=episode_length,\n", + " backend='spring')\n", + " env = gym_wrapper.VectorGymWrapper(env)\n", + " # automatically convert between jax ndarrays and torch tensors:\n", + " env = torch_wrapper.TorchWrapper(env, device=device)\n", + "\n", + " # env warmup\n", + " env.reset()\n", + " action = torch.zeros(env.action_space.shape).to(device)\n", + " env.step(action)\n", + "\n", + " # create the agent\n", + " policy_layers = [\n", + " env.observation_space.shape[-1], 64, 64, env.action_space.shape[-1] * 2\n", + " ]\n", + " value_layers = [env.observation_space.shape[-1], 64, 64, 1]\n", + " agent = Agent(policy_layers, value_layers, entropy_cost, discounting,\n", + " reward_scaling, device)\n", + " agent = torch.jit.script(agent.to(device))\n", + " optimizer = optim.Adam(agent.parameters(), lr=learning_rate)\n", + "\n", + " sps = 0\n", + " total_steps = 0\n", + " total_loss = 0\n", + " for eval_i in range(eval_frequency + 
1):\n", + " if progress_fn:\n", + " t = time.time()\n", + " with torch.no_grad():\n", + " episode_count, episode_reward = eval_unroll(agent, env, episode_length)\n", + " duration = time.time() - t\n", + " # TODO: only count stats from completed episodes\n", + " episode_avg_length = env.num_envs * episode_length / episode_count\n", + " eval_sps = env.num_envs * episode_length / duration\n", + " progress = {\n", + " 'eval/episode_reward': episode_reward,\n", + " 'eval/completed_episodes': episode_count,\n", + " 'eval/avg_episode_length': episode_avg_length,\n", + " 'speed/sps': sps,\n", + " 'speed/eval_sps': eval_sps,\n", + " 'losses/total_loss': total_loss,\n", + " }\n", + " progress_fn(total_steps, progress)\n", + "\n", + " if eval_i == eval_frequency:\n", + " break\n", + "\n", + " observation = env.reset()\n", + " num_steps = batch_size * num_minibatches * unroll_length\n", + " num_epochs = num_timesteps // (num_steps * eval_frequency)\n", + " num_unrolls = batch_size * num_minibatches // env.num_envs\n", + " total_loss = 0\n", + " t = time.time()\n", + " for _ in range(num_epochs):\n", + " observation, td = train_unroll(agent, env, observation, num_unrolls,\n", + " unroll_length)\n", + "\n", + " # make unroll first\n", + " def unroll_first(data):\n", + " data = data.swapaxes(0, 1)\n", + " return data.reshape([data.shape[0], -1] + list(data.shape[3:]))\n", + " td = sd_map(unroll_first, td)\n", + "\n", + " # update normalization statistics\n", + " agent.update_normalization(td.observation)\n", + "\n", + " for _ in range(num_update_epochs):\n", + " # shuffle and batch the data\n", + " with torch.no_grad():\n", + " permutation = torch.randperm(td.observation.shape[1], device=device)\n", + " def shuffle_batch(data):\n", + " data = data[:, permutation]\n", + " data = data.reshape([data.shape[0], num_minibatches, -1] +\n", + " list(data.shape[2:]))\n", + " return data.swapaxes(0, 1)\n", + " epoch_td = sd_map(shuffle_batch, td)\n", + "\n", + " for minibatch_i in 
range(num_minibatches):\n", + " td_minibatch = sd_map(lambda d: d[minibatch_i], epoch_td)\n", + " loss = agent.loss(td_minibatch._asdict())\n", + " optimizer.zero_grad()\n", + " loss.backward()\n", + " optimizer.step()\n", + " total_loss += loss.detach()\n", + "\n", + " duration = time.time() - t\n", + " total_steps += num_epochs * num_steps\n", + " total_loss = total_loss / (num_epochs * num_update_epochs * num_minibatches)\n", + " sps = num_epochs * num_steps / duration" + ], + "execution_count": 3, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "R2A9MMlHUajH" + }, + "source": [ + "Let's go!" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "B-lrKHvkUeYM", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 524 + }, + "outputId": "b410554d-2de5-4fa1-9df3-df6f0ba5c980" + }, + "source": [ + "# temporary fix to cuda memory OOM\n", + "os.environ['XLA_PYTHON_CLIENT_ALLOCATOR'] = 'platform'\n", + "\n", + "xdata = []\n", + "ydata = []\n", + "eval_sps = []\n", + "train_sps = []\n", + "times = [datetime.now()]\n", + "\n", + "def progress(num_steps, metrics):\n", + " times.append(datetime.now())\n", + " xdata.append(num_steps)\n", + " ydata.append(metrics['eval/episode_reward'].cpu())\n", + " eval_sps.append(metrics['speed/eval_sps'])\n", + " train_sps.append(metrics['speed/sps'])\n", + " clear_output(wait=True)\n", + " plt.xlim([0, 30_000_000])\n", + " plt.ylim([0, 6000])\n", + " plt.xlabel('# environment steps')\n", + " plt.ylabel('reward per episode')\n", + " plt.plot(xdata, ydata)\n", + " plt.show()\n", + "\n", + "train(progress_fn=progress)\n", + "\n", + "print(f'time to jit: {times[1] - times[0]}')\n", + "print(f'time to train: {times[-1] - times[1]}')\n", + "print(f'eval steps/sec: {np.mean(eval_sps)}')\n", + "print(f'train steps/sec: {np.mean(train_sps)}')" + ], + "execution_count": 4, + "outputs": [ + { + "output_type": "display_data", + "data": { + "text/plain": [ + "
" + ], + "image/png": "\n" + }, + "metadata": {} + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "time to jit: 0:00:47.071182\n", + "time to train: 0:04:52.657978\n", + "eval steps/sec: 251792.92696345953\n", + "train steps/sec: 127184.56138694892\n" + ] + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Y2p-20bCi4iI" + }, + "source": [ + "In this arrangement, we can roll out environment steps much faster than we can train: the speed at which PyTorch can backpropagate the loss and step the optimizer is the bottleneck. This PyTorch code can probably be sped up by adding [automatic mixed precision](https://pytorch.org/docs/stable/notes/amp_examples.html) and by following other recommendations in the [PyTorch performance tuning guide](https://pytorch.org/tutorials/recipes/recipes/tuning_guide.html).\n", + "\n", + "We know we have a fair bit of headroom to improve the PyTorch implementation, as the built-in Brax trainer (which uses [flax.optim](https://flax.readthedocs.io/en/latest/flax.optim.html)) runs at more than three times the steps per second:" + ] + }, + { + "cell_type": "code", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "Xmuz3I21p35H", + "outputId": "020efafe-d940-4943-9ca2-bdd534ca2b4f" + }, + "source": [ + "train_sps = []\n", + "\n", + "def progress(_, metrics):\n", + " if 'training/sps' in metrics:\n", + " train_sps.append(metrics['training/sps'])\n", + "\n", + "ppo.train(\n", + " environment=envs.create(env_name='ant', backend='spring'),\n", + " num_timesteps = 30_000_000, num_evals = 10, reward_scaling = .1,\n", + " episode_length = 1000, normalize_observations = True, action_repeat = 1,\n", + " unroll_length = 5, num_minibatches = 32, num_updates_per_batch = 4,\n", + " discounting = 0.97, learning_rate = 3e-4, entropy_cost = 1e-2,\n", + " num_envs = 2048, batch_size = 1024, progress_fn = progress)\n", + "\n", + "print(f'train steps/sec: {np.mean(train_sps[1:])}')" + ], +
"execution_count": 6, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "train steps/sec: 437059.02080672665\n" + ] + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "eqXKdDwVL6L4" + }, + "source": [ + "tunaalabagana! 👋" + ] + } + ] +} \ No newline at end of file