bellmake 2025. 5. 28. 14:25

* 실제 사용은 하단 전체 코드 참고

미래 토큰 차단용 마스크 (causal mask)

self.register_buffer(
    'mask',
    torch.triu(torch.ones(CONTEXT_LENGTH, CONTEXT_LENGTH), diagonal=1)
)
 
  • (L, L) 크기의 상삼각 행렬 생성 (L = CONTEXT_LENGTH).
  • 대각선 위쪽(미래 위치)에 1이 채워져 있고, 나머지는 0.
  • register_buffer로 저장하면 파라미터가 아니지만 .to(device) 등 이동 시 함께 따라갑니다.
  • 이후 forward에서
scores = scores.masked_fill(self.mask == 1, float('-inf'))

처럼 사용해서 “미래 정보”가 보이지 않도록 막음.

 

Transformer의 causal mask(미래 토큰 차단용 마스크)는, 자기회귀(autoregressive) 모델에서 “현재 위치보다 나중(미래)에 있는 토큰”으로부터 정보를 얻지 못하도록 강제하는 역할을 합니다.


1. 마스크 행렬 생성

self.register_buffer(
    'mask',
    torch.triu(torch.ones(CONTEXT_LENGTH, CONTEXT_LENGTH), diagonal=1)
)
  1. torch.ones(L, L) 은 전부 1로 채워진 L×LL \times L 행렬을 만듭니다.
  2. torch.triu(..., diagonal=1) 은 그 상삼각(upper triangular) 중 대각선 바로 위(diagonal=1)부터 1로 채웁니다.
    • 즉, (i, j) 원소가 j > i 일 때만 1이 됩니다.
    • (i == j) 또는 j < i 인 경우에는 0이 되죠.
  3. register_buffer 로 저장하면 이 텐서는 모델 파라미터가 아니면서도 GPU/CPU 이동(.to(device)) 시 함께 이동하며, 체크포인트에 포함됩니다.

결과적으로 mask 텐서는 다음과 같은 형태를 가집니다 (간단히 L=5L=5 예시):

  • 행의 인덱스 i는 쿼리 위치,
  • 열의 인덱스 j는 키 위치를 의미합니다.
  • (i, j)에 1이 있다는 것은 “쿼리 i가 키 j를 보지 못하게(차단) 하라”는 뜻입니다.

2. 어텐션 점수 계산과 마스킹

Self-attention 계산은 일반적으로 다음 순서로 이루어집니다:

 

1. 쿼리, 키, 값

2. 어텐션 스코어

여기서 scores의 shape 은 (\text{batch}, \text{num_heads}, L, L) 입니다.

 

3. 마스킹 적용

# mask: (L, L) → (1, 1, L, L)로 차원 확장 후 브로드캐스트
scores = scores.masked_fill(self.mask[None, None, :, :] == 1, float('-inf'))
  • self.mask == 1인 위치에 -inf를 채워 넣습니다.
  • -inf가 softmax를 통과하면 확률이 0이 되어, 완전히 차단된 셈이 됩니다.

4. softmax & 드롭아웃

 

5. 가중합

이 과정을 통해, 쿼리 위치 i는 오직 자신보다 앞(과거) 혹은 같은 위치의 키만 참조하고, j > i인 미래의 키(토큰)는 완전히 무시하게 됩니다.


3. 왜 -inf인가?

  • softmax 함수에 -inf가 들어가면 해당 위치의 출력 확률이 0이 됩니다.
  • 만약 0 대신 “매우 작은 음수”를 넣으면 완벽하게 차단되지 않아 미세하게나마 정보가 흘러들어갈 수 있기 때문에, 자기회귀 모델에서는 반드시 −∞ 를 사용해 완전 차단합니다.

4. 요약

  • 마스크 행렬은 upper triangular 형태로, 미래 위치를 1로 표시.
  • register_buffer로 저장해 파라미터가 아니면서도 이동·저장 가능.
  • masked_fill에 -inf를 넣어 softmax 결과를 0으로 만들어 미래 토큰을 완전 차단.
  • Softmax → dropout → V 가중합 순으로 self-attention 이 계산됨.

이 구조 덕분에, GPT-계열 같은 자기회귀 언어 모델은 “현재까지 생성된 토큰”만 보고 다음 토큰을 예측할 수 있게 되는 것입니다.

 

 

전체 코드  

import torch.nn as nn

class MultiHeadAttention(nn.Module):
    def __init__(self, d_in, d_out):
        super().__init__()
        
        assert d_out % NUM_HEADS == 0, "d_out must be divisible by n_heads"

        self.d_out = d_out
        self.head_dim = d_out // NUM_HEADS

        self.W_query = nn.Linear(d_in, d_out, bias=QKV_BIAS)
        self.W_key = nn.Linear(d_in, d_out, bias=QKV_BIAS)
        self.W_value = nn.Linear(d_in, d_out, bias=QKV_BIAS)
        self.out_proj = nn.Linear(d_out, d_out)
        self.dropout = nn.Dropout(DROP_RATE)
        self.register_buffer('mask', torch.triu(torch.ones(CONTEXT_LENGTH, CONTEXT_LENGTH), diagonal=1))

    def forward(self, x):
        b, num_tokens, d_in = x.shape

        keys = self.W_key(x)  # (b, num_tokens, d_out)
        queries = self.W_query(x)
        values = self.W_value(x)

        keys = keys.view(b, num_tokens, NUM_HEADS, self.head_dim)
        values = values.view(b, num_tokens, NUM_HEADS, self.head_dim)
        queries = queries.view(b, num_tokens, NUM_HEADS, self.head_dim)

        keys = keys.transpose(1, 2)
        queries = queries.transpose(1, 2)
        values = values.transpose(1, 2)

        attn_scores = queries @ keys.transpose(2, 3)

        mask_bool = self.mask.bool()[:num_tokens, :num_tokens]

        attn_scores.masked_fill_(mask_bool, -torch.inf)

        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
        attn_weights = self.dropout(attn_weights)

        context_vec = (attn_weights @ values).transpose(1, 2)

        context_vec = context_vec.reshape(b, num_tokens, self.d_out)
        context_vec = self.out_proj(context_vec)

        return context_vec

class LayerNorm(nn.Module):
    def __init__(self, emb_dim):
        super().__init__()
        self.eps = 1e-5
        self.scale = nn.Parameter(torch.ones(emb_dim))
        self.shift = nn.Parameter(torch.zeros(emb_dim))

    def forward(self, x):
        mean = x.mean(dim=-1, keepdim=True)
        var = x.var(dim=-1, keepdim=True, unbiased=False)
        norm_x = (x - mean) / torch.sqrt(var + self.eps)
        return self.scale * norm_x + self.shift

class GELU(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return 0.5 * x * (1 + torch.tanh(
            torch.sqrt(torch.tensor(2.0 / torch.pi)) *
            (x + 0.044715 * torch.pow(x, 3))
        ))

class FeedForward(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(EMB_DIM, 4 * EMB_DIM),
            GELU(),
            nn.Linear(4 * EMB_DIM, EMB_DIM),
        )

    def forward(self, x):
        return self.layers(x)

class TransformerBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.att = MultiHeadAttention(
            d_in=EMB_DIM,
            d_out=EMB_DIM)
    
        self.ff = FeedForward()
        self.norm1 = LayerNorm(EMB_DIM)
        self.norm2 = LayerNorm(EMB_DIM)
        self.drop_shortcut = nn.Dropout(DROP_RATE)

    def forward(self, x):
        shortcut = x
        x = self.norm1(x)
        x = self.att(x)
        x = self.drop_shortcut(x)
        x = x + shortcut

        shortcut = x
        x = self.norm2(x)
        x = self.ff(x)
        x = self.drop_shortcut(x)
        x = x + shortcut

        return x


class GPTModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.tok_emb = nn.Embedding(VOCAB_SIZE, EMB_DIM)
        self.pos_emb = nn.Embedding(CONTEXT_LENGTH, EMB_DIM)
        self.drop_emb = nn.Dropout(DROP_RATE)

        self.trf_blocks = nn.Sequential(
            *[TransformerBlock() for _ in range(NUM_LAYERS)])

        self.final_norm = LayerNorm(EMB_DIM)
        self.out_head = nn.Linear(EMB_DIM, VOCAB_SIZE, bias=False)

    def forward(self, in_idx):
        batch_size, seq_len = in_idx.shape
        tok_embeds = self.tok_emb(in_idx)
        pos_embeds = self.pos_emb(torch.arange(seq_len, device=in_idx.device))
        x = tok_embeds + pos_embeds  # Shape [batch_size, num_tokens, emb_size]
        x = self.drop_emb(x)
        x = self.trf_blocks(x)
        x = self.final_norm(x)
        logits = self.out_head(x)
        return logits