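Set `indices` to the lengths of the sequences in the batch, as a 1-D integer tensor of shape `(batch_size,)`. Positions are zero-based, so if the goal is the last real token of a sequence of length n, the index is n - 1, and an index equal to `seq_len` is out of range. (Normally you can use the pack/unpack utilities in PyTorch, `pack_padded_sequence` and `pad_packed_sequence`, but this does not yet work with transformers.)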
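# output has shape (batch_size, seq_len, embed_dim)
batch_size, seq_len, embed_dim = output.size()
# Select one position per sequence; the result has shape (batch_size, embed_dim)
selected = output[
    torch.arange(batch_size, device=output.device),  # row b of the batch ...
    indices,                                         # ... paired with position indices[b]
    ...,                                             # keep the whole embedding dimension
]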
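The indexing above can be checked on a small toy batch. The following is a minimal sketch, not a definitive implementation: the tensor sizes, the `lengths` values, and the choice of `lengths - 1` (to pick the last real token) are assumptions made for illustration and do not come from the original text.

```python
import torch

torch.manual_seed(0)

# Hypothetical toy batch: 3 sequences padded to seq_len = 4, embed_dim = 2.
batch_size, seq_len, embed_dim = 3, 4, 2
output = torch.randn(batch_size, seq_len, embed_dim)

# Assumed true (unpadded) length of each sequence.
lengths = torch.tensor([4, 2, 3])

# Assumption: we want the last real token, so use the zero-based index length - 1.
indices = lengths - 1

# Pair batch row b with position indices[b]; keep the full embedding dimension.
selected = output[torch.arange(batch_size), indices, ...]

# Equivalent explicit loop, for comparison.
expected = torch.stack(
    [output[b, int(lengths[b]) - 1] for b in range(batch_size)]
)
```

Here `selected` has shape `(batch_size, embed_dim)` and matches the explicit loop. Because the two index tensors are broadcast together, they select one element per batch row instead of every combination of row and position.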