Towards Interpretable Reinforcement Learning

Reimplementation for the NeurIPS 2019 Reproducibility Challenge.

Charles Lovering, Brown University

2019

Our aim here was to reimplement Mott et al. (2019) for the NeurIPS reproducibility challenge. The authors used attention to visualize (and constrain) the agents' focus as they learn and play a range of games. This effort was mostly an exercise in torch indexing! Others seem to have found the repository useful, but I would certainly recommend looking to other sources for a more modern implementation (and other options).

Architecture

The core idea is to split CNN+ConvLSTM vision features into keys K and values V at each spatial location, then compute dot-product attention using learned queries generated from the policy LSTM's hidden state. This produces per-query attention maps over the image grid — you can visualize where the agent is looking at each timestep, which is the interpretability payoff.

# Queries derived from the previous policy hidden state
# (num_keys here is the channel depth of each key vector)
Q_t = self.query(prev_output)  # [B, num_queries, num_keys]

# Dot-product attention: the key at each spatial location is scored against each query
# K_t: [B, h, w, num_keys]; Q_t is transposed and broadcast as [B, 1, num_keys, num_queries]
A = torch.matmul(K_t, Q_t.transpose(2, 1).unsqueeze(1))  # [B, h, w, num_queries]
A = spatial_softmax(A)  # normalize over h×w per query

# Weighted readout of values; V_t: [B, h, w, num_values]
answers = apply_attention(A, V_t)  # [B, num_queries, num_values]

The spatial softmax normalizes over the h × w grid independently per query (not over channels), so each query's attention map sums to 1 across spatial positions — a proper "where to look" distribution.
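The excerpt above calls `self.query`, `spatial_softmax`, and `apply_attention` without showing their bodies. Below is a minimal, self-contained sketch of the whole read step. It assumes channels-last tensors and toy sizes. The helper bodies and the two-layer MLP query head (`QueryNetwork`, a hypothetical name) are my assumptions, not necessarily what the repository does.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def spatial_softmax(A: torch.Tensor) -> torch.Tensor:
    """Softmax over the h*w grid, separately for each query. A: [B, h, w, num_queries]."""
    B, h, w, n_q = A.shape
    flat = F.softmax(A.reshape(B, h * w, n_q), dim=1)  # dim=1 indexes spatial positions
    return flat.reshape(B, h, w, n_q)


def apply_attention(A: torch.Tensor, V: torch.Tensor) -> torch.Tensor:
    """Attention-weighted sum of values over space.

    A: [B, h, w, num_queries], V: [B, h, w, num_values] -> [B, num_queries, num_values]
    """
    return torch.einsum("bhwq,bhwv->bqv", A, V)


class QueryNetwork(nn.Module):
    """Hypothetical query head: policy hidden state -> num_queries query vectors."""

    def __init__(self, hidden_size: int, num_queries: int, num_keys: int):
        super().__init__()
        self.num_queries, self.num_keys = num_queries, num_keys
        self.mlp = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, num_queries * num_keys),
        )

    def forward(self, prev_output: torch.Tensor) -> torch.Tensor:
        return self.mlp(prev_output).reshape(-1, self.num_queries, self.num_keys)


torch.manual_seed(0)
B, h, w = 2, 5, 3  # toy sizes, not the paper's
num_queries, num_keys, num_values, hidden = 4, 8, 6, 16

query = QueryNetwork(hidden, num_queries, num_keys)
K_t = torch.randn(B, h, w, num_keys)
V_t = torch.randn(B, h, w, num_values)
prev_output = torch.randn(B, hidden)

Q_t = query(prev_output)                                 # [B, num_queries, num_keys]
A = torch.matmul(K_t, Q_t.transpose(2, 1).unsqueeze(1))  # [B, h, w, num_queries]
A = spatial_softmax(A)
answers = apply_attention(A, V_t)                        # [B, num_queries, num_values]
print(answers.shape)  # torch.Size([2, 4, 6])
```

Flattening h and w before the softmax is what makes it spatial. After flattening, `dim=1` indexes positions, so each query's map sums to 1 over the grid rather than over channels or queries.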

Rather than learning positional embeddings, the model uses a fixed cosine spatial basis. It is built as an outer product of cosine functions over the height and width coordinates, which gives the attention mechanism access to absolute position without any extra learned parameters.

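# Cosine basis: outer product of cos(position * frequency) over h and w
# p_h, p_w: [h, w] row/column positions; u_basis, v_basis: U and V frequencies
a = torch.mul(p_h.unsqueeze(2), u_basis)   # [h, w, U]: row position × frequency
b = torch.mul(p_w.unsqueeze(2), v_basis)   # [h, w, V]: column position × frequency
out = torch.einsum("hwu,hwv->hwuv", torch.cos(a), torch.cos(b)).reshape(h, w, d)  # requires d = U * V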
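The fragment above assumes that `p_h`, `p_w`, `u_basis`, `v_basis`, and `d` already exist. Here is a self-contained sketch of the same construction. Three things are assumptions, not values stated in this write-up:

- Positions run 1..h and 1..w, scaled by π/h and π/w.
- The frequencies are the integers 1..U and 1..V.
- U = V = 8, which gives d = 64 channels.

The function name `spatial_basis` is hypothetical.

```python
import math
import torch


def spatial_basis(h: int, w: int, U: int = 8, V: int = 8) -> torch.Tensor:
    """Fixed cosine spatial basis of shape [h, w, U * V] (scaling and frequencies assumed)."""
    # Row and column positions, each broadcast to an [h, w] grid.
    rows = torch.arange(1, h + 1, dtype=torch.float32) * (math.pi / h)
    cols = torch.arange(1, w + 1, dtype=torch.float32) * (math.pi / w)
    p_h = rows.unsqueeze(1).expand(h, w)
    p_w = cols.unsqueeze(0).expand(h, w)

    u_basis = torch.arange(1, U + 1, dtype=torch.float32)  # [U] integer frequencies
    v_basis = torch.arange(1, V + 1, dtype=torch.float32)  # [V] integer frequencies

    a = torch.mul(p_h.unsqueeze(2), u_basis)  # [h, w, U]
    b = torch.mul(p_w.unsqueeze(2), v_basis)  # [h, w, V]

    # Outer product at each location gives [h, w, U, V], flattened to U * V channels.
    out = torch.einsum("hwu,hwv->hwuv", torch.cos(a), torch.cos(b))
    return out.reshape(h, w, U * V)


S = spatial_basis(9, 7)  # toy grid size
print(S.shape)  # torch.Size([9, 7, 64])
```

Because every entry is a product of two cosines, the basis is bounded in [-1, 1]. It is also identical for every frame, so it can be computed once and cached.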

This spatial basis is concatenated to both K and V so that queries can attend to specific spatial regions and readouts carry positional signal.
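A minimal sketch of the split-and-concatenate step, assuming channels-last vision features whose first `num_keys` channels become keys and whose remaining channels become values. The helper name `split_keys_values` is hypothetical, and the basis is replaced by random numbers for illustration. One consequence of this layout is that the queries must then have `num_keys + d` channels to be dot-producted with the keys.

```python
import torch


def split_keys_values(features: torch.Tensor, num_keys: int, S: torch.Tensor):
    """Split channels-last features into K and V, then append the spatial basis S.

    features: [B, h, w, num_keys + num_values]
    S:        [h, w, d] fixed spatial basis
    returns   K: [B, h, w, num_keys + d], V: [B, h, w, num_values + d]
    """
    num_values = features.shape[-1] - num_keys
    K, V = torch.split(features, [num_keys, num_values], dim=-1)
    S_b = S.unsqueeze(0).expand(features.shape[0], -1, -1, -1)  # same basis for every batch element
    return torch.cat([K, S_b], dim=-1), torch.cat([V, S_b], dim=-1)


features = torch.randn(2, 9, 7, 18)  # toy vision output: 8 key + 10 value channels
S = torch.randn(9, 7, 64)            # stand-in for the cosine basis
K, V = split_keys_values(features, num_keys=8, S=S)
print(K.shape, V.shape)  # torch.Size([2, 9, 7, 72]) torch.Size([2, 9, 7, 74])
```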

The vision network uses a ConvLSTM (not a standard LSTM), so its temporal state keeps spatial structure: the hidden and cell states are feature maps, not flat vectors. The policy side is a separate standard LSTM that ingests the attention readouts, the previous reward, and the previous action. Both recurrent states are reset at episode boundaries by multiplying them with a notdone mask, which zeroes them.
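A minimal sketch of both recurrent pieces and the reset, under these assumptions:

- The ConvLSTM is a single layer with 3×3 kernels and the standard LSTM gate layout.
- `notdone` is a float mask of shape [B].
- The policy input is concatenated in the order readouts, reward, one-hot action.

`ConvLSTMCell` and `reset` are hypothetical names, and the attention readout is replaced by random numbers.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ConvLSTMCell(nn.Module):
    """Minimal ConvLSTM cell: gates are convolutions, so h and c stay [B, C, H, W]."""

    def __init__(self, in_channels: int, hidden_channels: int, kernel_size: int = 3):
        super().__init__()
        # One convolution over [x, h] produces all four gates at once.
        self.gates = nn.Conv2d(in_channels + hidden_channels, 4 * hidden_channels,
                               kernel_size, padding=kernel_size // 2)

    def forward(self, x, state):
        h, c = state
        i, f, g, o = self.gates(torch.cat([x, h], dim=1)).chunk(4, dim=1)
        c = torch.sigmoid(f) * c + torch.sigmoid(i) * torch.tanh(g)
        h = torch.sigmoid(o) * torch.tanh(c)
        return h, (h, c)


def reset(state, notdone):
    """Zero each state tensor for batch entries whose episode just ended.

    notdone: [B] floats, 1.0 if the episode continues and 0.0 at a boundary.
    """
    # Reshape the mask to [B, 1, ..., 1] so it broadcasts over the non-batch dims.
    return tuple(s * notdone.view(-1, *([1] * (s.dim() - 1))) for s in state)


torch.manual_seed(0)
B, C_in, C_hid, H, W = 2, 3, 8, 9, 7
num_queries, num_values, num_actions, policy_hidden = 4, 10, 5, 32

vision = ConvLSTMCell(C_in, C_hid)
# Policy input: flattened attention readouts + previous reward + one-hot previous action.
policy = nn.LSTMCell(num_queries * num_values + 1 + num_actions, policy_hidden)

vis_state = (torch.randn(B, C_hid, H, W), torch.randn(B, C_hid, H, W))
pol_state = (torch.randn(B, policy_hidden), torch.randn(B, policy_hidden))

notdone = torch.tensor([1.0, 0.0])  # batch entry 1 just started a new episode
vis_state = reset(vis_state, notdone)
pol_state = reset(pol_state, notdone)

frame = torch.randn(B, C_in, H, W)
vis_out, vis_state_next = vision(frame, vis_state)
features = vis_out.permute(0, 2, 3, 1)  # channels-last [B, H, W, C_hid] for the K/V split

answers = torch.randn(B, num_queries, num_values)  # stand-in for the attention readout
prev_reward = torch.randn(B, 1)
prev_action = F.one_hot(torch.tensor([2, 4]), num_actions).float()
policy_in = torch.cat([answers.flatten(1), prev_reward, prev_action], dim=1)
pol_h, pol_c = policy(policy_in, pol_state)
```

Masking by multiplication keeps the whole batch in one tensor, so environments that reset at different steps need no per-entry branching.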

View the code here.

Charles Lovering © 2026