Paper reviews: how can we understand reinforcement learning agents

Gregory Bonaert (@gbonaert)

Update: added Unsupervised Video Object Segmentation for Deep Reinforcement Learning

RL has made incredible progress over the last 5 years. We’re now able to train superhuman agents that play Atari games from raw pixels and beat Go and StarCraft world champions; robots have learned how to handle a huge number of objects in warehouses (including many they have never seen before) and how to walk from scratch.

On the left, we can see AlphaStar’s gameplay. The raw pixel observations are fed into the neural network, producing the visualized activations. From these, the agent predicts the expected outcome (win, draw or loss) and the next action it will take. On the right, we can see MaNa (one of the best players in the world) playing. Of course, the AlphaStar agent can’t see what MaNa is doing.
Source: DeepMind Blog – AlphaStar: Mastering the Real-Time Strategy Game StarCraft II

What led to all of these improvements? There are many advances, but the key one is neural networks. Their success in supervised learning came from their ability to efficiently learn good representations, and that ability is just as useful in reinforcement learning. Andrej Karpathy has a fantastic blog post, “Deep Reinforcement Learning: Pong from Pixels”, which explains very clearly how neural networks are used there.

Problems of deep learning

Unfortunately, we can’t reap the advantages of neural networks without also suffering from their disadvantages. Chief among them: it’s very hard to interpret deep learning models and understand what they learned, which makes them hard to trust. A modern network basically performs a series of matrix multiplications and applies simple functions (such as ReLU, which replaces negative numbers with 0). Serialized to disk, it’s just a huge collection of numbers. If you receive 150 million numbers, it’s not exactly trivial to deduce what the network has learned and whether it’s robust.
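To make the “series of matrix multiplications” concrete, here is a minimal NumPy sketch of a fully connected network’s forward pass. The layer sizes are arbitrary and the weights are random stand-ins for trained ones; real Atari or image models use convolutions and are vastly larger, but the point is the same: everything the network knows is stored in these arrays of numbers.

```python
import numpy as np

def relu(x):
    # ReLU replaces negative numbers with 0
    return np.maximum(x, 0.0)

def forward(x, weights, biases):
    """Fully connected network: one matrix multiplication plus bias per layer,
    with ReLU between layers (none after the last layer)."""
    h = x
    for i, (W, b) in enumerate(zip(weights, biases)):
        h = h @ W + b
        if i < len(weights) - 1:
            h = relu(h)
    return h

rng = np.random.default_rng(0)
sizes = [8, 16, 4]  # hypothetical: 8 inputs, 16 hidden units, 4 outputs
weights = [rng.normal(size=(m, n)) for m, n in zip(sizes[:-1], sizes[1:])]
biases = [np.zeros(n) for n in sizes[1:]]

print(forward(rng.normal(size=8), weights, biases))  # 4 output scores
n_params = sum(W.size + b.size for W, b in zip(weights, biases))
print(n_params)  # 8*16 + 16 + 16*4 + 4 = 212
```

Even this toy network is 212 opaque numbers; scale that up by six orders of magnitude and the interpretability problem becomes obvious.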

Adversarial attacks. Recently, a lot of research has been done on adversarial attacks, which craft slightly tweaked inputs (adversarial examples) that fool a network into making the wrong prediction even though a human wouldn’t notice the difference. The most famous example comes from the “Explaining and Harnessing Adversarial Examples” paper, where the authors added a carefully chosen, very small amount of noise to the image of a panda and it got classified as a gibbon. Humans can’t see the difference between the two images, yet the classifier was completely fooled.

The right image is created by adding a very small amount of noise to the image on the left (the noise, shown in the middle, is greatly amplified to make it visible to the human eye). While the left image is classified correctly, the right one is misclassified (and worse, with very high confidence, higher than the original confidence for panda!)
Source: “Explaining and Harnessing Adversarial Examples” by Goodfellow et al. (2015)
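The attack introduced in that paper is the fast gradient sign method (FGSM): move every input value by a small ε in the direction of the sign of the loss gradient, x_adv = x + ε · sign(∇x J). Below is a minimal sketch against a logistic-regression classifier, chosen only because its input gradient can be written by hand; the weights, dimensionality and ε are made-up illustration values. It also illustrates the paper’s linearity argument: a change of at most 0.05 per input adds up, over 1,000 inputs, to a huge change in the output.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def fgsm(x, y, w, b, eps):
    """Fast gradient sign method against logistic regression p = sigmoid(w.x + b).
    For the cross-entropy loss, the gradient w.r.t. the input is (p - y) * w."""
    p = sigmoid(x @ w + b)
    grad_x = (p - y) * w
    return x + eps * np.sign(grad_x)

rng = np.random.default_rng(0)
d = 1000                   # many input dimensions, like the pixels of an image
w = rng.normal(size=d)     # hypothetical "trained" weights
x = rng.normal(size=d)
b = 5.0 - x @ w            # bias chosen so the clean logit is +5 (confidently class 1)

x_adv = fgsm(x, 1.0, w, b, eps=0.05)
print(sigmoid(x @ w + b), sigmoid(x_adv @ w + b))  # confident class 1, then confident class 0
```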

An arms race of finding new attacks and countering them with new defenses erupted between researchers. Dozens of new attacks were created, some with nice properties such as remaining effective after image compression or fooling many different classifiers. On the defense side, many new techniques were developed, such as adversarial training (using adversarial examples during training), training certifiably robust networks (training the network so that it’s easier to prove that it will be correct for some class of perturbations) or deflecting adversarial attacks (forcing the attacker to create inputs that actually look like the target class, thus defeating the point).
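Adversarial training can be sketched in a few lines. This hypothetical example uses logistic regression with hand-derived gradients (not a neural network as in the literature): at every step it regenerates FGSM examples against the current model and trains on an even mix of clean and adversarial inputs. The mixing ratio, ε, learning rate and toy data are all assumptions chosen for illustration.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def fgsm(X, y, w, b, eps):
    # Input gradient of the logistic loss for each example: (p - y) * w
    p = sigmoid(X @ w + b)
    return X + eps * np.sign((p - y)[:, None] * w[None, :])

def adversarial_training(X, y, eps=0.1, lr=0.1, steps=200):
    """Gradient descent on clean + FGSM examples, regenerated every step."""
    w, b = np.zeros(X.shape[1]), 0.0
    for _ in range(steps):
        X_adv = fgsm(X, y, w, b, eps)          # attack the current model
        X_all = np.vstack([X, X_adv])
        y_all = np.concatenate([y, y])
        err = sigmoid(X_all @ w + b) - y_all   # gradient of the loss w.r.t. the logit
        w -= lr * X_all.T @ err / len(y_all)
        b -= lr * err.mean()
    return w, b

# Toy data: two Gaussian blobs
rng = np.random.default_rng(0)
X = np.vstack([rng.normal(2.0, 1.0, size=(100, 2)), rng.normal(-2.0, 1.0, size=(100, 2))])
y = np.concatenate([np.ones(100), np.zeros(100)])
w, b = adversarial_training(X, y)
acc = np.mean((sigmoid(X @ w + b) > 0.5) == (y == 1))
print(acc)
```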

Interpretability. Because it’s not obvious how a neural network goes from input to output, it may be hard to rely on it. For example, if a medical AI system said “The patient has cancer”, that would not be good enough. We need to understand why it predicted cancer. Were there some anomalies, some strange spots, some known disease patterns? How did it reach that conclusion? Which evidence led to that output?

For some domains, not knowing these answers is unacceptable. This prompted researchers to better understand the predictions of deep learning systems.

Interpretability and robustness in deep RL

In reinforcement learning, these problems are even more complex, since the input is often non-visual or only a partial view of the world. Additionally, explaining a single decision (e.g. in chess, “advance the queen”) is almost never enough to understand the agent’s strategy. We need to understand the whole sequence of decisions it took and how it reacted to new information as it came in. This requires higher-level explanations that remain coherent over a long time span. In other words, this is pretty hard!

A lot of research has been done in the supervised learning setting (image classification, for example), but the field is still very new and unexplored in reinforcement learning. I’ll show the main methods that were created and how they were applied to reinforcement learning. When methods from supervised learning are relevant, I may also include them.

Saliency maps. One way to gain insight is to figure out which parts of the input (an X-ray scan, for example) were the most important for the prediction. Several methods have been developed to create these saliency maps, using gradient information (which pixels have a large effect on the prediction), occlusions (which regions change the prediction the most when they are hidden or perturbed) or patterns of neuron activations (described in a Distill article, worth checking out).
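For the gradient-based family, the simplest variant (“vanilla gradient” saliency) takes the absolute value of the gradient of a class score with respect to each input pixel. Here is a minimal sketch for a small bias-free two-layer ReLU network, with the backward pass written by hand; the network and its random weights are hypothetical stand-ins for a trained model.

```python
import numpy as np

def mlp_score_and_input_grad(x, W1, W2, cls):
    """Class score of a bias-free 2-layer ReLU network and its gradient w.r.t. x."""
    z1 = W1 @ x
    h = np.maximum(z1, 0.0)
    score = (W2 @ h)[cls]
    # Backward pass by hand
    dh = W2[cls]              # d score / d h
    dz1 = dh * (z1 > 0)       # ReLU only passes gradient where it was active
    dx = W1.T @ dz1           # d score / d x
    return score, dx

def gradient_saliency(x, W1, W2, cls):
    # Vanilla gradient saliency: how strongly each input pixel affects the score
    _, dx = mlp_score_and_input_grad(x, W1, W2, cls)
    return np.abs(dx)

rng = np.random.default_rng(0)
x = rng.random(64)                 # a flattened 8x8 "image"
W1 = rng.normal(size=(32, 64))     # hypothetical weights standing in for a trained model
W2 = rng.normal(size=(10, 32))
saliency = gradient_saliency(x, W1, W2, cls=3).reshape(8, 8)
```

In practice a deep learning framework computes the same gradient with automatic differentiation; writing it out only shows that nothing more than one backward pass is involved.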

The Grad-CAM method attributes the dog classification to pixels near the dog’s head and the cat classification to pixels around the cat’s body.

Source: “Grad-CAM: Visual Explanations from Deep Networks via Gradient-based Localization” by Selvaraju et al. (2019)

Which minimal set of pixels do we have to perturb to reduce the confidence in the ‘flute’ classification?

Source: “Interpretable Explanations of Black Boxes by Meaningful Perturbation” by Fong et al. (2018)

This perturbation approach was applied to RL by Greydanus et al. (2018) to see which parts of the input were most important for the action distribution (policy) and the value function (critic). In the Breakout Atari game, the method shows that the agent learned to pay attention to the corners so that it could send the ball there and dig a tunnel. Additionally, for the games where the agent performed very badly, the saliency maps show that it focused on wrong or irrelevant parts of the screen. For example, in Enduro, it focused on the distant mountains and failed to focus on the other cars and the road (see below).

Blue pixels indicate parts that affect the action distribution (policy) the most and red ones affect the value function (critic) the most

Source: “Visualizing and Understanding Atari Agents” by Greydanus et al. (2018)

In Enduro, the car has to race and avoid hitting the other cars (in blue, circled). The policy saliency is shown in green. The agent only focuses on itself and the distant mountains, and pays no attention to the blue car, indicating a problem with its strategy.

Source: “Visualizing and Understanding Atari Agents” by Greydanus et al. (2018)
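As I understand it, the perturbation used by Greydanus et al. blends a blurred copy of the frame into the image around a location (i, j) using a Gaussian mask, and scores that location by how much the policy’s outputs move: S(i, j) = ½‖π(I) − π(I′)‖². The sketch below follows that recipe on a toy problem, with two simplifications to stay dependency-free: a box blur stands in for the Gaussian blur, and `policy_logits` is a hypothetical function returning the policy’s output vector. The `stride` and `sigma` values are illustrative.

```python
import numpy as np

def gaussian_mask(shape, center, sigma):
    # Mask in [0, 1] that peaks at `center` and fades with distance
    rows, cols = np.indices(shape)
    d2 = (rows - center[0]) ** 2 + (cols - center[1]) ** 2
    return np.exp(-d2 / (2.0 * sigma ** 2))

def box_blur(img, k=5):
    # Simple k x k mean filter (a stand-in for a Gaussian blur)
    pad = k // 2
    padded = np.pad(img, pad, mode="edge")
    out = np.zeros(img.shape)
    for dr in range(k):
        for dc in range(k):
            out += padded[dr:dr + img.shape[0], dc:dc + img.shape[1]]
    return out / (k * k)

def perturbation_saliency(img, policy_logits, stride=5, sigma=5.0):
    """S(i, j) = 0.5 * ||pi(I) - pi(I')||^2, where I' blends the blurred
    image into the neighborhood of (i, j). Evaluated every `stride` pixels."""
    base = policy_logits(img)
    blurred = box_blur(img)
    rows = range(0, img.shape[0], stride)
    cols = range(0, img.shape[1], stride)
    sal = np.zeros((len(rows), len(cols)))
    for a, i in enumerate(rows):
        for c, j in enumerate(cols):
            M = gaussian_mask(img.shape, (i, j), sigma)
            perturbed = img * (1.0 - M) + blurred * M
            sal[a, c] = 0.5 * np.sum((base - policy_logits(perturbed)) ** 2)
    return sal

img = np.zeros((40, 40))
img[2:6, 2:6] = 1.0            # a bright object near the top-left corner

def policy_logits(im):
    # Toy "policy" whose single output only depends on the top-left 8x8 patch
    return np.array([im[:8, :8].max()])

sal = perturbation_saliency(img, policy_logits)
print(np.round(sal, 3))        # large values only near the object
```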

Counterfactuals. To understand why the agent made its decision, it would be useful to know which parts of the input, if hidden from the agent, would most affect its output. Perturbation-based saliency maps attempt to answer this question, but they have one big weakness: by adding noise or blurring parts of the input, they create unrealistic inputs that the agent will never encounter. Counterfactuals, on the other hand, are inputs that lead to a different action and are as close as possible to the original input, while still being realistic (likely under the data distribution). They answer the question “What’s the closest realistic input that leads to a different classification / action?”
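As a very crude illustration of “closest realistic input”, one can restrict the search to states the agent has actually visited (so realism holds by construction) and return the nearest one for which the policy picks the desired action. This is not the generative approach used in the papers discussed here, only a sketch of the question counterfactuals answer; `policy`, the toy states and the L2 distance are assumptions.

```python
import numpy as np

def nearest_observed_counterfactual(x, observed_states, policy, target_action):
    """Closest (L2) observed state for which `policy` picks `target_action`,
    or None if no observed state leads to that action."""
    actions = np.array([policy(s) for s in observed_states])
    candidates = observed_states[actions == target_action]
    if len(candidates) == 0:
        return None
    dists = np.linalg.norm((candidates - x).reshape(len(candidates), -1), axis=1)
    return candidates[np.argmin(dists)]

observed = np.array([[0.0], [1.0], [2.0], [3.0]])   # toy 1-D "states"
policy = lambda s: int(s[0] > 1.5)                  # hypothetical policy: action 1 if s > 1.5
print(nearest_observed_counterfactual(np.array([0.5]), observed, policy, target_action=1))  # [2.]
```

The obvious limitation is that a replay buffer rarely contains a state that differs from the query in only one meaningful detail, which is why the papers below learn a generative model instead.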

To test the importance of an image region, Chang et al. (2019) use a generative model that realistically fills in the perturbed regions so that the changed image remains realistic.

These are the saliency maps obtained when answering the question “What is the smallest input region that could be substituted into a fixed reference input in order to maximize the classification score?” The generative approach produces much better results and tends to focus the saliency maps on more relevant parts of the image than the simple heuristic in-filling methods.
Source: “Explaining Image Classifiers by Counterfactual Generation” by Chang et al. (2019)

On the RL side, Olson et al. (2019) use an encoder-generator approach to generate counterfactuals. The generator receives an encoded version of the image and the desired policy distribution, and creates an image that leads the agent to take the desired action. A discriminator is used to train the whole system.

A known problem is that the latent space in the auto-encoder system might have holes, meaning that images generated from those holes would look unrealistic. To avoid this, a Wasserstein auto-encoder is used to get the original policy distribution and results in better outputs (a more detailed explanation is present in the paper).

Left: The ablated version uses a simple generator approach, where the generator receives the raw image (instead of an encoding of it) and the new action and tries to modify the raw image as little as possible to ensure the agent’s policy is to perform the new action. The created image is very unrealistic.

Middle and Right: the authors attempt to use the Contrastive Explanation Method, but the results are either filled with artifacts or almost identical to the input image (note: the latter suggests the agent is vulnerable to adversarial attacks).
Source: “Counterfactual States for Atari Agents via Generative Deep Learning” by Olson et al. (2019)
The authors’ method creates more realistic inputs than the two previous methods. The left part of each pair shows the original input and the action the agent would take; the right part shows the action we want the agent to take and the counterfactual input that would lead it to take that action. The quality is better, but not always perfect: in the bottom right we see two submarines, which is impossible in the real game. Additionally, some other examples in the paper also show artifacts.
Source: “Counterfactual States for Atari Agents via Generative Deep Learning” by Olson et al. (2019)

Policy distillation. The idea behind policy distillation is to take a trained expert agent and use it to train a student agent. Why would we care about this? Because the student agent can have a different architecture than the teacher: it can be smaller, simpler or even a completely different model with good properties, such as a decision tree (good for interpretability).

Additionally, the student agent may learn from multiple teachers who are experts in different environments, thus learning from their combined experience. If we tried to train the agent to solve 55 Atari games at the same time, it wouldn’t work – the agent would be very bad in all games; however, training one agent per Atari game and then making the student learn from the 55 trained agents will allow the student to play most of the games.

The paper “Policy Distillation” by Rusu et al. (2016) first applied this idea to reinforcement learning and demonstrated its effectiveness. The student network tries to improve its Q-value predictions while outputting an action distribution as similar as possible to the teacher’s. Remarkably, the student network could even perform better than the teacher, even in the single-game scenario and even when the student network was smaller!

Source: “Policy Distillation” by Rusu et al. (2016)
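One of the distillation losses considered by Rusu et al., as I understand it, is a KL divergence between the teacher’s Q-values, turned into a sharpened distribution by a softmax at temperature τ, and the student’s softmax output. A minimal NumPy sketch follows; the τ value is an arbitrary illustration, not a recommended setting, and a real implementation would compute this in a deep learning framework so the student can be trained by backpropagation.

```python
import numpy as np

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)      # numerical stability
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def distillation_kl_loss(teacher_q, student_out, tau=0.1, eps=1e-12):
    """Mean over a batch of KL(softmax(teacher_q / tau) || softmax(student_out)).
    A small tau sharpens the teacher's Q-values into a near-greedy distribution."""
    p_t = softmax(teacher_q / tau)
    p_s = softmax(student_out)
    kl = np.sum(p_t * (np.log(p_t + eps) - np.log(p_s + eps)), axis=-1)
    return kl.mean()

teacher_q = np.array([[1.0, 2.0, 0.5], [0.0, -1.0, 3.0]])  # toy teacher Q-values
student_out = np.zeros_like(teacher_q)                     # an untrained student
print(distillation_kl_loss(teacher_q, student_out))
```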

Online distillation is an interesting variant: instead of training the student at the end, the student learns from the teacher while the teacher is being trained. The original “Policy Distillation” paper did this, but Sun and Fazli (2019) considerably improved it using two simple tricks:

  1. Use the reverse KL divergence instead of the usual (forward) KL divergence to measure the difference between the student’s and the teacher’s policies
  2. In the student’s Q-value update rule, pick the action in the next state using the teacher’s policy instead of the student’s policy. This further increases how much the student learns from the teacher.
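Both tricks can be sketched on plain NumPy arrays. The reverse KL simply swaps the arguments of the divergence. For the second trick, I assume the teacher picks the next action greedily (argmax of its action distribution) and that the student’s own Q-values are evaluated at that action; the function names and the greedy choice are assumptions, not taken from the paper.

```python
import numpy as np

def kl(p, q, eps=1e-12):
    # KL(p || q) for each row of two arrays of probability distributions
    return np.sum(p * (np.log(p + eps) - np.log(q + eps)), axis=-1)

def forward_kl(p_teacher, p_student):
    return kl(p_teacher, p_student)     # KL(teacher || student)

def reverse_kl(p_teacher, p_student):
    return kl(p_student, p_teacher)     # KL(student || teacher)

def student_q_target(reward, done, gamma, q_student_next, p_teacher_next):
    """One-step Q-learning target for the student, where the next action is
    the teacher's greedy action rather than the student's."""
    a_next = np.argmax(p_teacher_next, axis=-1)
    q_next = q_student_next[np.arange(len(a_next)), a_next]
    return reward + gamma * (1.0 - done) * q_next

p_t = np.array([[0.7, 0.2, 0.1]])
p_s = np.array([[0.5, 0.3, 0.2]])
print(forward_kl(p_t, p_s), reverse_kl(p_t, p_s))   # the two directions differ

target = student_q_target(
    reward=np.array([1.0]), done=np.array([0.0]), gamma=0.9,
    q_student_next=np.array([[0.0, 5.0, 2.0]]),
    p_teacher_next=np.array([[0.1, 0.1, 0.8]]),
)
print(target)  # the teacher picks action 2, so the target is 1.0 + 0.9 * 2.0
```

Note that the student’s own greedy action here would be action 1 (Q = 5.0); following the teacher instead is what ties the student’s value estimates to the teacher’s behavior.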

Online distillation is faster and allows even very small networks to perform well: the authors trained a student only 1.7% the size of the teacher network that still performed well in most games. Additionally, real-time policy distillation with both tricks performs better than the version in the original “Policy Distillation” paper.

Interpretability from distillation. The previous points explored how distillation could be used to compress the teacher into a much smaller student. This can be useful, but it doesn’t necessarily improve interpretability, since the result is still a fairly large neural network. Another approach is to distill the teacher into a more interpretable model. A balance needs to be struck here:

  • If the model is too simple, the performance will drop a lot and the student might be too different from the teacher, so we won’t actually be explaining the teacher’s actions
  • If the model is too complex, the performance can be good but we don’t gain much interpretability

Coppens et al. (2019) distill the teacher into a Soft Decision Tree-based agent. Soft Decision Trees resemble normal decision trees, except that each node learns a probability of going left or right, instead of a hard binary rule that simply sends the input left or right (no probabilities involved).

To pick an action, the tree makes a series of relatively simple decisions. If we examine how those decisions are made, we may gain insight into the policy of the original agent. Each decision is made by a simple perceptron, which lets us see which parts of the input are important (those whose associated weight in the perceptron is large in absolute value).

Visualization of the importance of each part of the input in the Mario AI benchmark. We can see that different leaves place emphasis on different parts of the input. In the top row, the soft decision tree outputs an action distribution close to the original, but in the second row the output is very different.
Source: “Distilling Deep Reinforcement Learning Policies in Soft Decision Trees” by Coppens et al. (2019)
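A soft decision tree’s forward pass is short enough to sketch. In this hypothetical NumPy version the tree is complete and stored in breadth-first order; each inner node is a logistic unit that sends the input right with probability σ(w·x + b), the probability of reaching a leaf is the product of the branch probabilities along its path, and each leaf holds a learned action distribution. Returning the probability-weighted mixture of leaf distributions is one choice; using only the most probable leaf is another.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def softmax(z):
    e = np.exp(z - z.max())
    return e / e.sum()

def soft_tree_forward(x, inner_w, inner_b, leaf_logits):
    """Complete soft decision tree stored breadth-first: inner nodes 0..n-1,
    leaves n..2n; the children of node i are 2i+1 (left) and 2i+2 (right)."""
    n_inner = len(inner_b)
    reach = np.zeros(2 * n_inner + 1)     # probability of reaching each node
    reach[0] = 1.0
    for i in range(n_inner):
        p_right = sigmoid(inner_w[i] @ x + inner_b[i])   # a logistic "perceptron"
        reach[2 * i + 1] = reach[i] * (1.0 - p_right)
        reach[2 * i + 2] = reach[i] * p_right
    leaf_reach = reach[n_inner:]
    leaf_dists = np.array([softmax(l) for l in leaf_logits])  # one action distribution per leaf
    return leaf_reach, leaf_reach @ leaf_dists

rng = np.random.default_rng(0)
x = rng.random(4)                                        # toy 4-feature observation
inner_w, inner_b = rng.normal(size=(3, 4)), np.zeros(3)  # depth 2: 3 inner nodes
leaf_logits = rng.normal(size=(4, 2))                    # 4 leaves, 2 actions
leaf_reach, action_dist = soft_tree_forward(x, inner_w, inner_b, leaf_logits)
print(leaf_reach.round(3), action_dist.round(3))
```

Reshaping a node’s weight vector `inner_w[i]` to the shape of the input gives a per-node importance picture of the kind shown in the figure above.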

One weakness of this approach is that the soft decision tree sometimes makes very different decisions than the teacher. This raises the question of whether the student is similar enough to the teacher to explain the teacher’s behavior. However, picking a different model or improving training could solve this problem.

Layer-wise relevance propagation.

In contrast to those gradient-based saliency maps, Bach et al. proposed a method that directly uses the activations of the neurons during the forward pass to calculate the relevance of the input pixels. This is computationally efficient compared to gradient-based methods because they can reuse the values of the forward pass. Instead of calculating how much a change in an input pixel would impact the prediction, Bach et al. investigate the contribution of the input pixels to prediction. For this purpose, they do not only describe a single specific algorithm but introduce a general concept which they call layer-wise relevance propagation (LRP).

This concept has two advantageous properties which gradient-based saliency maps lack. The first is the conservation property which says that the sum of all relevance values, generated by LRP, is equal to the value of the prediction. This ascertains that the relevance values reflect the certainty of the prediction. The second property is positivity which states that all relevance values are non-negative. This ascertains that the generated saliency maps do not contain contradictory evidence

Source: “Enhancing Explainability of Deep Reinforcement Learning Through Selective Layer-Wise Relevance Propagation” by Huber et al. (2019)
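The conservation and positivity properties can be checked on a toy example. The sketch below implements one common LRP rule (the z⁺ rule, which only propagates relevance through positive weights) for a hypothetical bias-free two-layer ReLU network with non-negative inputs. Under those assumptions, and when the output is positive, the input relevances are non-negative and sum to the output; other LRP rules, or networks with biases, may break one or both properties.

```python
import numpy as np

def zplus_backward(a, W, R, eps=1e-12):
    # z+ rule: R_j = a_j * sum_k max(w_jk, 0) * R_k / z_k,  z_k = sum_j a_j * max(w_jk, 0)
    Wp = np.maximum(W, 0.0)
    z = a @ Wp + eps            # eps avoids division by zero for nodes that receive nothing
    return a * (Wp @ (R / z))

def lrp_zplus(x, W1, W2):
    """Relevance of each input for a bias-free 2-layer ReLU net with one output."""
    a1 = np.maximum(x @ W1, 0.0)                   # hidden activations
    out = (a1 @ W2).item()                         # the prediction
    R1 = zplus_backward(a1, W2, np.array([out]))   # start from the prediction value
    R0 = zplus_backward(x, W1, R1)
    return out, R0

x = np.array([1.0, 0.5, 0.0, 2.0])     # non-negative inputs, like pixel intensities
W1 = np.array([[0.6, -1.0, 0.2],
               [1.0, 0.5, -0.3],
               [-0.5, 0.3, 0.8],
               [0.2, 0.4, -0.1]])
W2 = np.array([[1.0], [-2.0], [0.5]])
out, R0 = lrp_zplus(x, W1, W2)
print(out, R0, R0.sum())   # R0 is non-negative and sums to out
```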

However, raw relevance propagation leads to saliency maps that highlight a large part of the Atari image. Empirical studies show that people don’t like explanations that include every single possible cause and influence, and instead prefer explanations that focus on the most important evidence. Therefore, relevance propagation was modified so that only the most important parts of the input would be salient: each neuron A attributes all of its relevance to the input node that contributed the most to its activation. This introduces a lot of sparsity, which leads to sparser saliency maps that are easier to understand and interpret.

Two saliency maps are created for the input on the left, using two methods. In the middle, a simple method with no sparsity constraints leads to a large and diffuse saliency map, which doesn’t help much in understanding the policy. On the right, thanks to sparsity constraints, the saliency map focuses on the most important elements and is much sparser.
Source: “Enhancing Explainability of Deep Reinforcement Learning Through Selective Layer-Wise Relevance Propagation” by Huber et al. (2019)
The original approach distributes relevance to the previous nodes according to their contribution to the activation. To increase sparsity, the authors introduce the argmax approach: when attributing relevance from the nodes of layer L to the nodes of layer L – 1, all of the relevance of a node A in layer L is assigned to the single input activation in layer L – 1 that contributed the most to the activation of node A.
Source: “Enhancing Explainability of Deep Reinforcement Learning Through Selective Layer-Wise Relevance Propagation” by Huber et al. (2019)
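The argmax variant only changes the redistribution step. Here is a sketch on a hypothetical bias-free two-layer ReLU network: every node hands its entire relevance to the single input with the largest contribution a_j · w_jk. Total relevance is still conserved, but it ends up concentrated on very few inputs. Huber et al. apply the idea to convolutional Atari agents; this toy version is only meant to show the redistribution rule.

```python
import numpy as np

def argmax_backward(a, W, R):
    # Each node k gives ALL of its relevance R_k to the input j maximizing a_j * w_jk
    contrib = a[:, None] * W
    R_prev = np.zeros(len(a))
    for k in range(W.shape[1]):
        R_prev[np.argmax(contrib[:, k])] += R[k]
    return R_prev

def lrp_argmax(x, W1, W2):
    a1 = np.maximum(x @ W1, 0.0)                   # hidden activations
    out = (a1 @ W2).item()                         # the prediction
    R1 = argmax_backward(a1, W2, np.array([out]))
    R0 = argmax_backward(x, W1, R1)
    return out, R0

x = np.array([1.0, 0.5, 0.0, 2.0])
W1 = np.array([[0.6, -1.0, 0.2],
               [1.0, 0.5, -0.3],
               [-0.5, 0.3, 0.8],
               [0.2, 0.4, -0.1]])
W2 = np.array([[1.0], [-2.0], [0.5]])
out, R0 = lrp_argmax(x, W1, W2)
print(out, R0)   # all relevance lands on a single input
```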

Understanding the agent’s behavior in critical states. In most states, taking the wrong action is not terrible, but in some states taking the right action is essential, because the wrong one leads to much smaller future rewards. Amir et al. (2018) use a simple approach called Highlights to understand the agent’s behavior in those states: they record what the agent does before, during and after critical states. Identifying critical states is easy: we can check the difference between the maximum and the minimum Q-value in a state. They also develop a variant of Highlights that ensures some diversity in the recorded critical states (we don’t want too many trajectories around a single critical state).
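The importance measure and the selection step are easy to sketch. In this hypothetical version, importance is max Q minus min Q per state (as described above); the k most important time steps are kept, each with a window of surrounding steps, and a minimum temporal gap between picked steps stands in for the paper’s diversity variant (which, as far as I know, compares states rather than time steps). The parameter values are arbitrary.

```python
import numpy as np

def state_importance(q_values):
    # Importance of each state: max Q-value minus min Q-value
    q = np.asarray(q_values)
    return q.max(axis=1) - q.min(axis=1)

def highlights(q_values, k=3, context=2, min_gap=5):
    """Return (start, end) windows around the k most important time steps.
    `min_gap` is a simplified diversity rule: picked steps must be at least
    that many steps apart."""
    imp = state_importance(q_values)
    chosen = []
    for t in np.argsort(-imp):                 # most important first
        if all(abs(t - c) >= min_gap for c in chosen):
            chosen.append(int(t))
        if len(chosen) == k:
            break
    T = len(imp)
    return [(max(0, t - context), min(T, t + context + 1)) for t in sorted(chosen)]

q = np.zeros((20, 2))
q[3], q[4], q[15] = [10.0, 0.0], [9.0, 0.0], [5.0, 0.0]   # toy Q-values
print(highlights(q, k=2, context=1))   # [(2, 5), (14, 17)]
```

Step 4 is skipped even though it is the second most important, because it sits right next to step 3: without some diversity rule, the summary would show the same moment twice.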

Learning to interpret. One interesting paper is “Learn to Interpret Atari Agents” by Zhao et al. (2018), where the authors incorporate interpretation / focus / visualization into the training process itself. The idea is that some pixels are more useful than others: humans tend to play a game by looking at selected parts of the screen rather than always considering the whole screen at once.

The agent would get access to only some parts of the input: the most salient ones, which are given higher importance. The mechanism is similar to attention but simpler and leads to clearer visualizations. Because we know what the agent is focusing on, it’s easy to interpret which parts of the environment led to its decisions. Remarkably, their approach also led to better performance in most games.

Source: “Learn to Interpret Atari Agents” by Zhao et al. (2018)

Importance maps. But how do we learn the importance of different parts of the screen at different moments / states of the game? Hardcoding or creating rules would be limiting and ineffective, so we can instead learn it. The authors add a module to the Rainbow architecture that takes the activations of filters in the middle of the network (i.e. the learned representation) and creates N different importance maps in the following 3 steps:

  1. Using 1×1 convolutions, it produces N score maps from the M filter activations (N << M; in the Enduro example below, N = 2)
  2. These N score maps are then normalized into probability distributions, giving N normalized score maps
  3. The new state is obtained by multiplying each of the normalized score maps by the input (the filter activations) and then summing the results
Source: “Learn to Interpret Atari Agents” by Zhao et al. (2018)

The mechanism is similar to attention because it puts emphasis on some parts of the input, but it doesn’t have the query-key component that’s part of attention, making it a bit simpler.
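The three steps can be sketched with NumPy. Shapes follow the description above: M filter maps of size H × W in, N score maps out. Two details are my assumptions: the normalization is a softmax over spatial positions, and “multiplying by the input” means an elementwise product with the filter activations. The weights are random placeholders for what the module would learn inside Rainbow.

```python
import numpy as np

def importance_module(features, conv_w, conv_b):
    """features: (M, H, W) filter activations; conv_w: (N, M); conv_b: (N,)."""
    M, H, W = features.shape
    N = conv_w.shape[0]
    # 1. A 1x1 convolution is a linear map over channels applied at every position
    scores = np.einsum("nm,mhw->nhw", conv_w, features) + conv_b[:, None, None]
    # 2. Softmax over the H*W positions turns each score map into a distribution
    flat = scores.reshape(N, H * W)
    flat = np.exp(flat - flat.max(axis=1, keepdims=True))
    norm_maps = (flat / flat.sum(axis=1, keepdims=True)).reshape(N, H, W)
    # 3. Weight the features by each normalized map and sum the N results
    new_state = np.einsum("nhw,mhw->mhw", norm_maps, features)
    return norm_maps, new_state

rng = np.random.default_rng(0)
features = rng.random((8, 5, 5))                        # M = 8 filters on a 5x5 grid
conv_w, conv_b = rng.normal(size=(2, 8)), np.zeros(2)   # N = 2 score maps
norm_maps, new_state = importance_module(features, conv_w, conv_b)
print(norm_maps.shape, new_state.shape)                 # (2, 5, 5) (8, 5, 5)
```

Unlike attention, there is no query-key comparison: each score map is computed directly from the features at each position, which is what keeps the module simple.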

Visualizing the important regions. We obtain importance maps for the learned representation, not for the raw image input. People can’t directly understand the learned representation, so these importance maps aren’t directly useful. Instead, to figure out which parts of the Atari image are important, we need to relate the image pixels to the score maps.

To do so, they compute gradient-based saliency between the input and the most important element of each score map (again introducing sparsity, as in layer-wise relevance propagation). Using a threshold, they can then easily keep the important parts and black out the rest. Thus, with N score maps, they get N important image regions (one per score map), as in the figure below.

In this Enduro example, there are two score maps, from which we obtain two maps of important regions. The right one follows the agent, while the left one focuses on the other cars and thus on different parts of the road. In the paper, they show that more advanced behavior can be discovered using these important-region maps. For example, between two races, the agent focuses on the sky instead of the road. Why? Because when it detects the sunrise, it knows that a new race is about to start and that it should start accelerating again.
Source: “Learn to Interpret Atari Agents” by Zhao et al. (2018)
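The region extraction can be sketched as follows. To stay framework-free, this hypothetical version estimates the gradient of the strongest element of one score map with central finite differences (an autodiff framework would get it from a single backward pass), then keeps only pixels whose saliency is at least a fraction `threshold` of the maximum. `score_map_fn` and the threshold value are assumptions.

```python
import numpy as np

def region_for_score_map(image, score_map_fn, threshold=0.5, h=1e-4):
    """Keep the pixels that most affect the strongest element of one score map.
    score_map_fn (hypothetical) maps an image to a 2-D score map."""
    base = score_map_fn(image)
    idx = np.unravel_index(np.argmax(base), base.shape)   # most important element
    grad = np.zeros(image.shape)
    for p in np.ndindex(image.shape):
        up, down = image.astype(float), image.astype(float)   # astype returns copies
        up[p] += h
        down[p] -= h
        grad[p] = (score_map_fn(up)[idx] - score_map_fn(down)[idx]) / (2 * h)
    saliency = np.abs(grad)
    keep = saliency >= threshold * saliency.max()
    return np.where(keep, image, 0.0), keep                # black out the rest

image = np.ones((4, 4))

def score_map_fn(im):
    # Toy 1x2 score map: its strongest element only depends on pixel (0, 0)
    return np.array([[3.0 * im[0, 0], im[1, 1]]])

masked, keep = region_for_score_map(image, score_map_fn)
print(keep.astype(int))   # only pixel (0, 0) survives
```

With N score maps, running this once per map yields the N important image regions.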

Unsupervised learning to detect moving objects to learn better policies. Humans don’t think in terms of pixels but in terms of entities, so learning to recognize those might improve what the agent does. However, we don’t want to manually specify which objects exist, because then we would need to do that for every new scenario. Unsurprisingly, the approach of the authors of “Unsupervised Video Object Segmentation for Deep Reinforcement Learning” is to learn to detect these moving objects using unsupervised learning. The agent then learns which of those objects matter and gradually improves its policy.

To learn to recognize objects, the network tries to predict the optical flow of K objects. In other words, it tries to predict the movement of K objects, and their summed optical flow must be as close as possible to the actual change between frames.

There are many possible solutions, so to favor more realistic ones, they add L1 regularization to the terms that compute the optical flow: in each of the K masks / optical flows, we try to minimize the number of pixels that actually change. Why does this make sense? Because if a monster moves, only the monster’s pixels should have a non-zero optical flow, and everything else should be 0. The L1 regularization enforces the “everything else should have 0 optical flow” part. To avoid predicting 0 optical flow for every pixel, the L1 regularization coefficient in the loss starts at 0 and increases linearly during the first 100k steps.
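A loss with this shape can be sketched as follows. The representation is my assumption to keep it short: K soft masks, each moved by a single per-object (dx, dy) translation, whose summed flow is compared to a given target flow by squared error. The L1 coefficient warms up linearly from 0 over the first 100k steps, as described above; `max_coef` is a made-up value.

```python
import numpy as np

def l1_coefficient(step, max_coef, warmup_steps=100_000):
    # 0 at the start, growing linearly to max_coef over the first warmup_steps
    return max_coef * min(step / warmup_steps, 1.0)

def segmentation_loss(masks, object_flows, target_flow, step, max_coef=1e-2):
    """masks: (K, H, W) soft object masks; object_flows: (K, 2), one (dx, dy) per object;
    target_flow: (2, H, W) observed motion of every pixel."""
    total_flow = np.einsum("khw,kc->chw", masks, object_flows)  # summed per-object flow
    reconstruction = np.mean((total_flow - target_flow) ** 2)
    sparsity = np.mean(np.abs(masks))   # pushes each mask to zero on most pixels
    return reconstruction + l1_coefficient(step, max_coef) * sparsity

masks = np.zeros((1, 2, 2))
masks[0, 0, 0] = 1.0                   # one object covering a single pixel
flows = np.array([[1.0, 0.0]])         # moving one pixel along x
target = np.einsum("khw,kc->chw", masks, flows)
print(segmentation_loss(masks, flows, target, step=0))        # no L1 penalty yet
print(segmentation_loss(masks, flows, target, step=100_000))  # full L1 penalty
```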

This technique leads to better performance in many games and is more interpretable, because we can see which objects the agent detected and how important they were, and therefore which objects it paid attention to.
