3-D Indexing with PyTorch

Select vectors from a 3-D tensor by index.

2020

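Set indices to a 1-D integer tensor holding one position per sequence in the batch, typically each sequence's length minus one, which selects its last non-padded vector. Positions are zero-based, so indexing with the raw length is off by one and goes out of bounds for the longest sequence. (Normally you can use pack/unpack in PyTorch, but this does not yet work with transformers.)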

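import torch

# output: (batch_size, seq_len, embed_dim); indices: integer tensor of shape (batch_size,)
batch_size, seq_len, embed_dim = output.size()
# Pair batch row i with indices[i] to pick one vector per sequence.
selected = output[
    torch.arange(batch_size),
    indices,
    ...,
]  # selected: (batch_size, embed_dim)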
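
To make the shapes concrete, here is a minimal, self-contained sketch of the same indexing on a small random tensor. The sizes, the lengths values, and the choice of lengths - 1 as the index are assumptions for illustration, not part of the original snippet. The torch.gather version is included only as a cross-check that both approaches pick the same vectors.

```python
import torch

torch.manual_seed(0)

# Hypothetical shapes chosen for illustration only.
batch_size, seq_len, embed_dim = 3, 5, 4
output = torch.randn(batch_size, seq_len, embed_dim)

# Assumed true lengths of each padded sequence (1 <= length <= seq_len).
lengths = torch.tensor([5, 2, 3])
indices = lengths - 1  # zero-based position of the last real token

# Advanced indexing: batch row i is paired with indices[i].
selected = output[torch.arange(batch_size), indices]

# Equivalent formulation with torch.gather, shown as a cross-check.
gather_idx = indices.view(batch_size, 1, 1).expand(batch_size, 1, embed_dim)
gathered = output.gather(1, gather_idx).squeeze(1)

print(selected.shape)  # torch.Size([3, 4])
```

The indexing form is usually shorter to read, while gather can be handy when the index tensor already has the expanded shape.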


Charles Lovering © 2026