AI 일반/모델, 아키텍처, 구현
causal mask
bellmake
2025. 5. 28. 14:25
* 실제 사용은 하단 전체 코드 참고
미래 토큰 차단용 마스크 (causal mask)
self.register_buffer(
'mask',
torch.triu(torch.ones(CONTEXT_LENGTH, CONTEXT_LENGTH), diagonal=1)
)
- (L, L) 크기의 상삼각 행렬 생성 (L = CONTEXT_LENGTH).
- 대각선 위쪽(미래 위치)에 1이 채워져 있고, 나머지는 0.
- register_buffer로 저장하면 파라미터가 아니지만 .to(device) 등 이동 시 함께 따라갑니다.
- 이후 forward에서
scores = scores.masked_fill(self.mask == 1, float('-inf'))
처럼 사용해서 “미래 정보”가 보이지 않도록 막음.
Transformer의 causal mask(미래 토큰 차단용 마스크)는, 자기회귀(autoregressive) 모델에서 “현재 위치보다 나중(미래)에 있는 토큰”으로부터 정보를 얻지 못하도록 강제하는 역할을 합니다.
1. 마스크 행렬 생성
self.register_buffer(
'mask',
torch.triu(torch.ones(CONTEXT_LENGTH, CONTEXT_LENGTH), diagonal=1)
)
- torch.ones(L, L) 은 전부 1로 채워진 L×LL \times L 행렬을 만듭니다.
- torch.triu(..., diagonal=1) 은 그 상삼각(upper triangular) 중 대각선 바로 위(diagonal=1)부터 1로 채웁니다.
- 즉, (i, j) 원소가 j > i 일 때만 1이 됩니다.
- (i == j) 또는 j < i 인 경우에는 0이 되죠.
- register_buffer 로 저장하면 이 텐서는 모델 파라미터가 아니면서도 GPU/CPU 이동(.to(device)) 시 함께 이동하며, 체크포인트에 포함됩니다.
결과적으로 mask 텐서는 다음과 같은 형태를 가집니다 (간단히 L=5L=5 예시):
- 행의 인덱스 i는 쿼리 위치,
- 열의 인덱스 j는 키 위치를 의미합니다.
- (i, j)에 1이 있다는 것은 “쿼리 i가 키 j를 보지 못하게(차단) 하라”는 뜻입니다.
2. 어텐션 점수 계산과 마스킹
Self-attention 계산은 일반적으로 다음 순서로 이루어집니다:
1. 쿼리, 키, 값
2. 어텐션 스코어
여기서 scores의 shape 은 (\text{batch}, \text{num_heads}, L, L) 입니다.
3. 마스킹 적용
# mask: (L, L) → (1, 1, L, L)로 차원 확장 후 브로드캐스트
scores = scores.masked_fill(self.mask[None, None, :, :] == 1, float('-inf'))
- self.mask == 1인 위치에 -inf를 채워 넣습니다.
- -inf가 softmax를 통과하면 확률이 0이 되어, 완전히 차단된 셈이 됩니다.
4. softmax & 드롭아웃
5. 가중합
이 과정을 통해, 쿼리 위치 i는 오직 자신보다 앞(과거) 혹은 같은 위치의 키만 참조하고, j > i인 미래의 키(토큰)는 완전히 무시하게 됩니다.
3. 왜 -inf인가?
- softmax 함수에 -inf가 들어가면 해당 위치의 출력 확률이 0이 됩니다.
- 만약 0 대신 “매우 작은 음수”를 넣으면 완벽하게 차단되지 않아 미세하게나마 정보가 흘러들어갈 수 있기 때문에, 자기회귀 모델에서는 반드시 −∞ 를 사용해 완전 차단합니다.
4. 요약
- 마스크 행렬은 upper triangular 형태로, 미래 위치를 1로 표시.
- register_buffer로 저장해 파라미터가 아니면서도 이동·저장 가능.
- masked_fill에 -inf를 넣어 softmax 결과를 0으로 만들어 미래 토큰을 완전 차단.
- Softmax → dropout → V 가중합 순으로 self-attention 이 계산됨.
이 구조 덕분에, GPT-계열 같은 자기회귀 언어 모델은 “현재까지 생성된 토큰”만 보고 다음 토큰을 예측할 수 있게 되는 것입니다.
전체 코드
import torch.nn as nn
class MultiHeadAttention(nn.Module):
def __init__(self, d_in, d_out):
super().__init__()
assert d_out % NUM_HEADS == 0, "d_out must be divisible by n_heads"
self.d_out = d_out
self.head_dim = d_out // NUM_HEADS
self.W_query = nn.Linear(d_in, d_out, bias=QKV_BIAS)
self.W_key = nn.Linear(d_in, d_out, bias=QKV_BIAS)
self.W_value = nn.Linear(d_in, d_out, bias=QKV_BIAS)
self.out_proj = nn.Linear(d_out, d_out)
self.dropout = nn.Dropout(DROP_RATE)
self.register_buffer('mask', torch.triu(torch.ones(CONTEXT_LENGTH, CONTEXT_LENGTH), diagonal=1))
def forward(self, x):
b, num_tokens, d_in = x.shape
keys = self.W_key(x) # (b, num_tokens, d_out)
queries = self.W_query(x)
values = self.W_value(x)
keys = keys.view(b, num_tokens, NUM_HEADS, self.head_dim)
values = values.view(b, num_tokens, NUM_HEADS, self.head_dim)
queries = queries.view(b, num_tokens, NUM_HEADS, self.head_dim)
keys = keys.transpose(1, 2)
queries = queries.transpose(1, 2)
values = values.transpose(1, 2)
attn_scores = queries @ keys.transpose(2, 3)
mask_bool = self.mask.bool()[:num_tokens, :num_tokens]
attn_scores.masked_fill_(mask_bool, -torch.inf)
attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
attn_weights = self.dropout(attn_weights)
context_vec = (attn_weights @ values).transpose(1, 2)
context_vec = context_vec.reshape(b, num_tokens, self.d_out)
context_vec = self.out_proj(context_vec)
return context_vec
class LayerNorm(nn.Module):
def __init__(self, emb_dim):
super().__init__()
self.eps = 1e-5
self.scale = nn.Parameter(torch.ones(emb_dim))
self.shift = nn.Parameter(torch.zeros(emb_dim))
def forward(self, x):
mean = x.mean(dim=-1, keepdim=True)
var = x.var(dim=-1, keepdim=True, unbiased=False)
norm_x = (x - mean) / torch.sqrt(var + self.eps)
return self.scale * norm_x + self.shift
class GELU(nn.Module):
def __init__(self):
super().__init__()
def forward(self, x):
return 0.5 * x * (1 + torch.tanh(
torch.sqrt(torch.tensor(2.0 / torch.pi)) *
(x + 0.044715 * torch.pow(x, 3))
))
class FeedForward(nn.Module):
def __init__(self):
super().__init__()
self.layers = nn.Sequential(
nn.Linear(EMB_DIM, 4 * EMB_DIM),
GELU(),
nn.Linear(4 * EMB_DIM, EMB_DIM),
)
def forward(self, x):
return self.layers(x)
class TransformerBlock(nn.Module):
def __init__(self):
super().__init__()
self.att = MultiHeadAttention(
d_in=EMB_DIM,
d_out=EMB_DIM)
self.ff = FeedForward()
self.norm1 = LayerNorm(EMB_DIM)
self.norm2 = LayerNorm(EMB_DIM)
self.drop_shortcut = nn.Dropout(DROP_RATE)
def forward(self, x):
shortcut = x
x = self.norm1(x)
x = self.att(x)
x = self.drop_shortcut(x)
x = x + shortcut
shortcut = x
x = self.norm2(x)
x = self.ff(x)
x = self.drop_shortcut(x)
x = x + shortcut
return x
class GPTModel(nn.Module):
def __init__(self):
super().__init__()
self.tok_emb = nn.Embedding(VOCAB_SIZE, EMB_DIM)
self.pos_emb = nn.Embedding(CONTEXT_LENGTH, EMB_DIM)
self.drop_emb = nn.Dropout(DROP_RATE)
self.trf_blocks = nn.Sequential(
*[TransformerBlock() for _ in range(NUM_LAYERS)])
self.final_norm = LayerNorm(EMB_DIM)
self.out_head = nn.Linear(EMB_DIM, VOCAB_SIZE, bias=False)
def forward(self, in_idx):
batch_size, seq_len = in_idx.shape
tok_embeds = self.tok_emb(in_idx)
pos_embeds = self.pos_emb(torch.arange(seq_len, device=in_idx.device))
x = tok_embeds + pos_embeds # Shape [batch_size, num_tokens, emb_size]
x = self.drop_emb(x)
x = self.trf_blocks(x)
x = self.final_norm(x)
logits = self.out_head(x)
return logits