How is torch.tril used to build a causal attention mask?
import torch
import torch.nn.functional as F
import math
# Toy causal self-attention: batch 1, one head, five tokens, head dim four.
B, H, T, D = 1, 1, 5, 4

# Random queries, keys and values, drawn in that order, each (B, H, T, D).
q, k, v = (torch.randn(B, H, T, D) for _ in range(3))

# Scaled dot-product scores of shape (B, H, T, T).
scores = (q @ k.transpose(-1, -2)) / math.sqrt(D)
print("raw scores:\n", scores[0, 0])

# Lower-triangular boolean mask: True where key position <= query position.
mask = torch.ones(T, T, dtype=torch.bool).tril()

# Positions outside the mask get -inf so softmax assigns them probability 0.
scores = scores.masked_fill(mask.logical_not(), float("-inf"))
probs = scores.softmax(dim=-1)

print("\nmasked scores:\n", scores[0, 0])
print("\nprobs:\n", probs[0, 0])
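This sets the upper-right triangle (the entries above the diagonal) to -inf, so that softmax assigns zero probability to the masked positions, while the lower-left triangle (including the diagonal) keeps its normal score values.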
# Decode step for a single token using cached keys/values.
t = 4  # 0-based position of the token being decoded
q = torch.randn(B, H, 1, D)  # one query: the token at position t

# The cache holds positions 0..t inclusive (t + 1 entries): the current
# token's own key/value is part of what it attends to, just like the
# diagonal entry kept by a lower-triangular causal mask.
k_all = torch.randn(B, H, t + 1, D)  # cached keys 0..t
v_all = torch.randn(B, H, t + 1, D)  # cached values 0..t

# Scores have shape (B, H, 1, t + 1): one row, one column per cached position.
scores = torch.matmul(q, k_all.transpose(-2, -1)) / math.sqrt(D)

# No mask here: every cached position is <= t, so nothing lies in the future.
probs = F.softmax(scores, dim=-1)
print("\ndecode scores:\n", scores[0,0])
print("\ndecode probs:\n", probs[0,0])
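In the decode stage, a causal mask is not required: the single query attends only to keys that are already in the cache, so there are no future positions to mask out.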