What led to all of these improvements? There have been many advances, but the key one is neural networks. Their performance in ordinary supervised learning comes from their ability to efficiently learn good representations, and that ability is just as useful in reinforcement learning. Andrej Karpathy has a fantastic blog post, “Deep Reinforcement Learning: Pong from Pixels”, which explains very clearly how neural networks are used.
Problems of deep learning
Unfortunately, we can’t reap the advantages of neural networks without also suffering from their disadvantages. Chief among them: it’s very hard to interpret deep learning models and understand what they have learned, which makes it hard to trust them. A modern network basically performs a series of matrix multiplications and applies simple functions (such as ReLU, which replaces negative numbers by 0). If we serialize it to disk, it’s just a huge collection of numbers. If you receive 150 million numbers, it’s not exactly trivial to deduce what the network has learned and whether it’s robust.
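To make this concrete, here is a minimal sketch of that point: a network really is just a few matrix multiplications interleaved with a simple nonlinearity. The random weights below stand in for the “huge collection of numbers” a serialized network actually is.

```python
import numpy as np

# Two-layer network: matrix multiply, ReLU, matrix multiply.
# The weights are random stand-ins, not a trained model.
rng = np.random.default_rng(0)
W1 = rng.normal(size=(4, 8))       # first layer weights
W2 = rng.normal(size=(8, 3))       # second layer weights

def relu(x):
    return np.maximum(x, 0.0)      # replace negative numbers by 0

def forward(x):
    return relu(x @ W1) @ W2       # matmul, nonlinearity, matmul

logits = forward(rng.normal(size=(1, 4)))
```

Staring at `W1` and `W2` tells you essentially nothing about what the network computes, which is exactly the interpretability problem.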
Adversarial attacks. Recently, a lot of research has been done on adversarial attacks: tweaked inputs that can fool a network into making the wrong prediction even though a human wouldn’t notice the difference. The most famous example comes from the “Explaining and Harnessing Adversarial Examples” paper, where the authors intelligently added some very small noise to the image of a panda and it got classified as a gibbon. Humans can’t tell the two images apart, yet the classifier was completely fooled.
Interpretability. Because it’s not obvious how a neural network goes from input to output, it may be hard to rely on it. For example, if a medical AI system said “The patient has cancer”, that would not be good enough. We need to understand why it predicted cancer. Were there some anomalies, some strange spots, some known disease patterns? How did it reach that conclusion? Which evidence led to that output?
For some domains, not knowing these answers is unacceptable. This prompted researchers to better understand the predictions of deep learning systems.
Interpretability and robustness in deep RL
In reinforcement learning, these problems are even more complex, since the input is often non-visual or a partial view of the world. Additionally, explaining a single decision (e.g. in chess, “advance the queen”) is almost never enough to understand the strategy of the agent. We need to understand the whole sequence of decisions it took and how it reacted to new information as it came in. This requires creating higher-level explanations that are coherent over a long time span. In other words, this is pretty hard!
A lot of research has been done in the supervised learning case (image classification, for example), but the field is still very new and unexplored in the reinforcement learning setting. I’ll show the main methods that were created and how they were applied to reinforcement learning. If a method so far only exists for supervised learning but is still relevant, I may include it too.
Saliency maps. One way to gain insight is to figure out which part of the input (an X-ray scan, for example) was the most important for the prediction. Several methods to create these saliency maps have been developed, using gradient information (which pixels have a large effect on the predictions), occlusions (seeing which pixels can’t be perturbed without changing the prediction), or patterns of neuron activations (a Distill paper, check it out).
The Grad-CAM method attributes the dog classification to pixels near the dog’s head and the cat classification to pixels around the cat’s body.
This perturbation approach was applied to RL (Greydanus et al. 2018), to see which part of the input was most important for the action distribution (policy) and the value function (Q values). In the Breakout Atari game, the method shows the agent learned to pay attention to the corners so that it could send the ball there and create a tunnel. Additionally, for the games where the agent was very bad, the saliency maps show it focused on wrong or irrelevant parts of the screen. For example, in Enduro, it focused on the distant mountains and failed to focus on the other car and the road (see below).
Blue pixels indicate parts that affect the action distribution (policy) the most and red ones affect the value function (critic) the most
In Enduro, the car has to race and avoid hitting the other cars (in blue, circled). The policy saliency is shown in green. The agent only focuses on itself and the distant mountains, and pays no attention to the blue car, indicating a problem with its strategy.
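The perturbation idea above can be sketched in a few lines. This is a hedged toy version in the spirit of Greydanus et al. (2018): perturb each image region, re-run the policy, and record how much its output changes. The `policy` below is a stand-in for a trained agent (it only looks at the top-left corner), and we mask regions with zeros where the paper uses a Gaussian blur.

```python
import numpy as np

def policy(frame):
    # Toy two-action policy that only depends on the top-left corner
    m = frame[:10, :10].mean()
    return np.array([m, 1.0 - m])

def saliency_map(frame, policy, patch=10):
    base = policy(frame)
    rows, cols = frame.shape[0] // patch, frame.shape[1] // patch
    sal = np.zeros((rows, cols))
    for i in range(rows):
        for j in range(cols):
            perturbed = frame.copy()
            # Mask one patch (the paper blurs instead of zeroing)
            perturbed[i*patch:(i+1)*patch, j*patch:(j+1)*patch] = 0.0
            # Saliency = how much the action distribution changed
            sal[i, j] = np.abs(policy(perturbed) - base).sum()
    return sal

frame = np.random.default_rng(1).random((30, 30))
sal = saliency_map(frame, policy)   # only the top-left patch lights up
```

Because the toy policy ignores everything outside the top-left corner, only that patch gets non-zero saliency, which is exactly the kind of diagnosis the Enduro example relies on.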
Counterfactuals. To understand why the agent took its decision, it would be useful to know which parts of the input, if not seen by the agent, would most affect its output. Perturbation-based saliency maps attempt to answer this question, but they have one big weakness: by adding noise or blurring parts of the input, they create inputs that are not realistic and will not be encountered by the agent. Counterfactuals, on the other hand, are inputs that lead to a different action and that are as close as possible to the original input, but which are still realistic (likely given the data distribution). They answer the question “What’s the closest realistic input that leads to a different classification / action?”
To test the importance of an image region, Chang et al. (2019) use a generative model that realistically fills in the perturbed regions so that the changed image remains realistic.
On the RL side, Olson et al. (2019) use an encoder-generator approach to generate counterfactuals. The generator receives an encoded version of the image and the desired policy distribution, and creates an image that leads the agent to follow that desired action distribution. A discriminator is used to train the whole system.
A known problem is that the latent space in the auto-encoder system might have holes, meaning that images generated from those holes would look unrealistic. To avoid this, a Wasserstein auto-encoder is used to get the original policy distribution and results in better outputs (a more detailed explanation is present in the paper).
Policy distillation. The idea behind policy distillation is to take a trained expert agent and use it to train a student agent. Why would we care about this? Because the student agent can have a different architecture than the teacher: it can be smaller, simpler or even a completely different model with good properties, such as a decision tree (good for interpretability).
Additionally, the student agent may learn from multiple teachers who are experts in different environments, thus learning from their combined experience. If we tried to train the agent to solve 55 Atari games at the same time, it wouldn’t work – the agent would be very bad in all games; however, training one agent per Atari game and then making the student learn from the 55 trained agents will allow the student to play most of the games.
The paper “Policy Distillation” by Rusu et al. (2016) first applied this idea to reinforcement learning and proved its effectiveness. The student network tries to improve its Q-value predictions and at the same time output an action distribution as similar as possible to the teacher’s. Remarkably, the student network could even perform better than the teacher, even in the single-game scenario and even if the student network was smaller!
Online distillation is an interesting variant – instead of training the student in the end, the student learns from the teacher as the teacher is being trained. The original “Policy Distillation” paper did this, but this was considerably improved by Sun and Fazli (2019) using two simple tricks:
Use Reverse KL divergence instead of normal KL divergence to evaluate the difference between the policies of the student and the teacher
In the student’s Q-value update rule, pick the action taken in the next state using the teacher’s policy, instead of the student’s policy. This increases even further how much the agent learns from the teacher.
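The first trick, the direction of the KL divergence, can be sketched concretely. This is a hedged toy example with made-up logits: the point is only that KL(teacher‖student) and KL(student‖teacher) are different quantities, and Sun and Fazli use the reverse one.

```python
import numpy as np

def softmax(logits):
    z = np.exp(logits - logits.max())
    return z / z.sum()

def kl(p, q):
    # KL divergence between two discrete distributions
    return np.sum(p * np.log(p / q))

teacher = softmax(np.array([2.0, 0.5, -1.0]))  # stand-in teacher policy
student = softmax(np.array([1.0, 1.0, 0.0]))   # stand-in student policy

forward_kl = kl(teacher, student)  # standard distillation loss
reverse_kl = kl(student, teacher)  # the variant used by Sun and Fazli
```

KL divergence is asymmetric, so swapping the arguments changes which mistakes are penalized: reverse KL pushes the student to avoid putting probability where the teacher puts almost none.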
Using online distillation allows faster distillation and can allow even very small networks to have good performance. They managed to train a student whose size is 1.7% of the teacher network and still achieve good performance in most games. Additionally, the real-time policy distillation with both tricks performs better than the one in the original “Policy Distillation” paper.
Interpretability from distillation. The previous points explored how distillation could be used to compress the teacher into a much smaller student. This can be useful, but it doesn’t necessarily improve interpretability, since the result is still a fairly large neural network. Another approach is to distill the teacher into a more interpretable model. A balance needs to be attained here:
If the model is too simple, the performance will drop a lot and the student might be too different from the teacher, so we won’t actually be explaining the teacher’s actions
If the model is too complex, the performance can be good but we don’t gain much interpretability
Coppens et al. (2019) distill the teacher into a Soft Decision Tree-based agent. Soft Decision Trees resemble normal decision trees, except that each inner node learns a probability distribution for going left or right, instead of a simple binary rule that deterministically sends the input one way or the other (no probabilities involved).
To pick the action, a series of relatively-simple decisions is taken. If we examine how those decisions are taken, we may gain insight into the policy of the original agent. The decisions are taken by a simple perceptron, which allows us to see which parts of the input are important (by seeing if their associated weight in the perceptron is very high in absolute value).
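A single inner node of such a tree can be sketched as follows. This is a hedged illustration with hand-picked weights (not learned ones): the node is a perceptron whose sigmoid output is the probability of routing the input to the right child, and large absolute weights reveal which input features drive the decision.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def node_prob_right(x, w, b):
    # A soft decision node: probability of going right instead of a
    # hard threshold. |w_i| shows how much feature i matters here.
    return sigmoid(w @ x + b)

x = np.array([0.9, 0.1])    # toy 2-feature input
w = np.array([4.0, 0.05])   # feature 0 dominates this node's decision
p_right = node_prob_right(x, w, 0.0)
```

Reading off `w`, a human can see at a glance that this node routes inputs almost entirely based on the first feature, which is the interpretability gain over an opaque network layer.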
One weakness of this approach is that the soft decision tree can sometimes take very different decisions than the teacher. This raises the question of whether the student is similar enough to the teacher to explain the teacher’s behavior. However, picking a different model or improving training could solve this problem.
Layer-wise relevance propagation.
In contrast to those gradient-based saliency maps, Bach et al. proposed a method that directly uses the activations of the neurons during the forward pass to calculate the relevance of the input pixels. This is computationally efficient compared to gradient-based methods because it can reuse the values of the forward pass. Instead of calculating how much a change in an input pixel would impact the prediction, Bach et al. investigate the contribution of the input pixels to the prediction. For this purpose, they do not describe only a single specific algorithm but introduce a general concept which they call layer-wise relevance propagation (LRP).
This concept has two advantageous properties which gradient-based saliency maps lack. The first is the conservation property, which says that the sum of all relevance values generated by LRP is equal to the value of the prediction. This ensures that the relevance values reflect the certainty of the prediction. The second property is positivity, which states that all relevance values are non-negative. This ensures that the generated saliency maps do not contain contradictory evidence.
However, using the raw relevance propagation leads to saliency maps that highlight a large amount of the Atari image. Empirical studies show that people don’t like explanations that include every single possible cause and influence, and prefer instead explanations that focus on the most important evidence. Therefore, relevance propagation was changed so that only the most important parts of the input would be salient. To do so, each neuron attributes all of its relevance to the single input node that contributed the most to its activation. This introduces a lot of sparsity, which leads to sparser saliency maps and thus makes them easier to understand and interpret.
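The sparse (“winner-take-all”) variant can be sketched for a single linear layer. This is a hedged toy version, not the full LRP machinery: each neuron’s relevance goes entirely to the input whose contribution to its activation was largest.

```python
import numpy as np

def sparse_lrp_linear(x, W, relevance):
    # Contribution z[j, k] of input j to neuron k's activation
    z = x[:, None] * W
    winners = np.argmax(z, axis=0)   # most-contributing input per neuron
    input_rel = np.zeros_like(x)
    for k, r in enumerate(relevance):
        input_rel[winners[k]] += r   # all of neuron k's relevance goes there
    return input_rel

x = np.array([1.0, 0.2, 0.0])        # toy input activations
W = np.array([[2.0, 0.1],
              [0.5, 3.0],
              [1.0, 1.0]])           # toy weights: 3 inputs -> 2 neurons
R = np.array([1.0, 0.5])             # relevance arriving at the 2 neurons
input_R = sparse_lrp_linear(x, W, R)
```

Note that the total relevance is preserved (the conservation property) while most inputs receive exactly zero, which is what makes the resulting saliency maps sparse.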
Understanding the agent’s behavior in critical states. While in most states taking the wrong action is not terrible, in some states taking the right decision is essential, because taking the wrong one will lead to much smaller future rewards. Amir et al. (2018) use a simple approach, called Highlights, to understand the agent’s behavior in those states: they simply record what the agent does before, during and after critical states. Measuring how critical a state is is simple: check the difference between the maximum and the minimum Q-value. They also develop a variant of Highlights that ensures some diversity in the critical states that are recorded (we don’t want too many trajectories around a single critical state).
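The importance measure is simple enough to write down directly. This is a hedged sketch with made-up Q-values: a state is critical when the best and worst actions differ a lot in predicted value.

```python
import numpy as np

def state_importance(q_values):
    # A state is critical when choosing badly is very costly
    return np.max(q_values) - np.min(q_values)

routine_state = np.array([1.0, 1.1, 0.9])    # any action is roughly fine
critical_state = np.array([5.0, -4.0, 0.5])  # one wrong move is costly

i_routine = state_importance(routine_state)
i_critical = state_importance(critical_state)
```

Highlights would then record trajectories around states like `critical_state` and skip ones like `routine_state`.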
Learning to interpret. One interesting paper is “Learn to Interpret Atari Agents” by Zhao et al. (2018) where they incorporated interpretation / focus / visualization as part of the training process itself. The idea is that some pixels are more useful than others and humans tend to play a game by looking at selected parts of the screen rather than always considering the whole screen at once.
The agent would get access to only some parts of the input: the most salient ones, which are given higher importance. The mechanism is similar to attention but simpler and leads to clearer visualizations. Because we know what the agent is focusing on, it’s easy to interpret which parts of the environment led to its decisions. Remarkably, their approach also led to better performance in most games.
Importance maps. But how do we learn the importance of different parts of the screen at different moments / states of the game? Hardcoding or creating rules would be limiting and ineffective, so we can instead learn to do it. The authors added a module to the Rainbow architecture that takes the activations of filters in the middle of the network (i.e. the learned representation) and then creates N different importance maps with the following 3 steps:
Using 1×1 convolutions, it produces N score maps from the M filter activations (N << M, in the example at the end, N = 2)
These N scores maps are then normalized to generate a probability distribution, thus obtaining N normalized score maps
The new state is obtained by multiplying each of the normalized score maps by the input and then summing the results
The mechanism is similar to attention because it puts emphasis on some parts of the input, but it doesn’t have the query-key component that’s part of attention, making it a bit simpler.
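The three steps above can be sketched in numpy. This is a hedged illustration with random values: the shapes and the exact pooling in step 3 are assumptions for the sake of the example, not the paper’s precise implementation.

```python
import numpy as np

rng = np.random.default_rng(2)
M, N, H, Wd = 8, 2, 5, 5                 # M filters in, N score maps out
features = rng.random((M, H, Wd))        # mid-network filter activations
conv1x1 = rng.normal(size=(N, M))        # a 1x1 conv is a per-pixel matmul

# Step 1: N score maps from the M filter activations via 1x1 convolution
scores = np.einsum('nm,mhw->nhw', conv1x1, features)

# Step 2: normalize each score map into a probability distribution
flat = scores.reshape(N, -1)
probs = np.exp(flat - flat.max(axis=1, keepdims=True))
probs = (probs / probs.sum(axis=1, keepdims=True)).reshape(N, H, Wd)

# Step 3: weight the input by each map and sum over spatial positions
new_state = np.einsum('nhw,mhw->nm', probs, features)
```

Each of the N probability maps sums to 1 over the spatial positions, so each one picks out a soft region of the representation, which is what makes the visualizations clean.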
Visualizing the important regions. We obtain importance maps for the learned representation and not for the raw image input. People can’t directly understand the learned representation, so the importance maps aren’t directly useful. Instead, to figure out which parts of the Atari image are important, we need to relate them to the score maps.
To do so, they use gradient-based saliency between the input and the most important element of each score map (again introducing the idea of sparsity, as in layer-wise relevance propagation). Using a threshold, they can then easily keep the important parts and black out the rest. Thus, if there are N score maps, they get N important image regions (one per score map), like in the figure below.
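The thresholding step itself is trivial to sketch. The saliency values below are stand-ins for the gradient-based map described above: keep the pixels above the threshold, black out everything else.

```python
import numpy as np

saliency = np.array([[0.9, 0.1],
                     [0.2, 0.8]])         # stand-in saliency values
image = np.array([[200.0, 150.0],
                  [120.0, 240.0]])        # stand-in pixel intensities

mask = saliency > 0.5                     # keep only the important pixels
region = np.where(mask, image, 0.0)       # black out the rest
```

Applied once per score map, this yields the N important image regions shown in the paper’s figures.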
Unsupervised learning to detect moving objects to learn better policies. Humans don’t think in terms of pixels but in terms of entities, so learning to recognize those might improve what the agent does. However, we don’t want to manually specify what objects exist, because then we would need to do that for every new scenario. Unsurprisingly, the authors’ approach is to learn to detect these moving objects using unsupervised learning. The agent then learns which of those objects matter and slowly improves its policy.
To learn to recognize objects, the network tries to predict the optical flow for K objects. In other words, it tries to predict the movement of K objects, and the summed optical flow must be as close as possible to the actual change in the frame.
There are many possible solutions, so to lead to more realistic ones, they add L1 regularization in the terms that compute the optical flow, which means that in each of the K masks / optical flows we try to minimize the number of pixels that actually change. Why does this make sense? Because if a monster moves, only the pixels for the monster should have a non-zero optical flow, and everything else should be 0. The L1 regularization tries to ensure the “everything else should have 0 optical flow” part. To avoid predicting 0 optical flow for every pixel, the L1 regularization coefficient in the loss starts at 0 and increases linearly during the first 100k steps.
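The loss described above can be sketched as follows. This is a hedged toy version with assumed coefficient values: reconstruct the frame change as the sum of K per-object flows, with an L1 penalty whose coefficient ramps linearly from 0 over the first 100k steps so the model isn’t pushed to predict all-zero flow early on.

```python
import numpy as np

def flow_loss(predicted_flows, true_change, step,
              final_l1=0.1, ramp_steps=100_000):
    # Summed flow over the K objects must match the observed change
    recon = predicted_flows.sum(axis=0)
    recon_err = np.mean((recon - true_change) ** 2)
    # L1 coefficient ramps linearly from 0 to final_l1, then stays flat
    l1_coeff = final_l1 * min(step / ramp_steps, 1.0)
    # L1 term: each object's flow should touch as few pixels as possible
    sparsity = np.abs(predicted_flows).mean()
    return recon_err + l1_coeff * sparsity

K, H, W = 3, 4, 4
flows = np.zeros((K, H, W))
change = np.zeros((H, W))
loss_start = flow_loss(flows, change, step=0)
loss_late = flow_loss(flows + 0.5, change, step=200_000)
```

Early in training only the reconstruction term matters; once the ramp finishes, spreading non-zero flow over many pixels is penalized, which pushes each of the K masks toward a single tight object.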
This technique leads to better performance in many games and is more interpretable: we can see which objects the agent detected, how important they were, and which ones it paid attention to.