토니의 연습장
출력 정밀도 변환 클래스 정의 (CastOutputToFloat) 본문
class CastOutputToFloat(nn.Sequential):
def forward(self, x): return super().forward(x).to(torch.float32)
model.lm_head = CastOutputToFloat(model.lm_head)
모델의 출력이 32비트 정밀도로 반환되도록 하기 위해 CastOutputToFloat 클래스를 정의하고, lm_head의 출력을 32비트로 변환하여 안정적인 출력이 가능하도록 합니다.
이 코드를 통해 모델은 메모리 사용량을 최소화하면서 8비트 정밀도로 로드되고, 파라미터를 동결하여 일부 파라미터만 훈련이 가능하도록 설정되었습니다.
위 코드에서 super().forward(x) 호출의 이유는 CastOutputToFloat 클래스가 nn.Sequential 클래스를 상속받고 있기 때문입니다. super().forward(x)는 nn.Sequential 클래스의 forward 메서드를 호출하여, model.lm_head에 정의된 연산이 먼저 수행되도록 합니다. 이렇게 함으로써 원래의 lm_head 연산 결과를 얻고, 그 결과를 torch.float32로 캐스팅하여 반환하게 됩니다.
구체적으로 설명하자면:
- nn.Sequential은 여러 레이어를 순차적으로 실행하는 모듈이므로, model.lm_head의 기존 연산을 그대로 수행하면서도 그 결과를 추가로 변환할 수 있도록 합니다.
- super().forward(x)는 nn.Sequential에서 정의된 원래의 forward 메서드를 호출하여, model.lm_head의 모든 연산을 거친 출력값을 반환합니다.
- 그 다음 .to(torch.float32)를 적용하여, model.lm_head의 출력값을 32비트 부동소수점(float32) 형식으로 변환합니다.
즉, super().forward(x) 호출을 통해 CastOutputToFloat 클래스에서 model.lm_head의 기존 연산을 유지하면서도 최종 결과를 원하는 데이터 형식으로 변환할 수 있게 됩니다.
super()를 호출하지 않고도 forward 메서드를 직접 호출할 수는 있지만, 이 경우에는 nn.Sequential의 기본 동작이 무시되므로, 원래의 model.lm_head의 모든 레이어가 순차적으로 실행되지 않을 수 있습니다. 대신 forward 메서드를 직접 구현해야 하므로 코드가 복잡해질 수 있습니다.
예를 들어 model.lm_head가 하나의 레이어라면 직접 호출이 가능합니다. 하지만 nn.Sequential이 여러 레이어를 포함하고 있다면, super().forward(x)를 호출하지 않으면 Sequential이 각 레이어를 순차적으로 호출하는 원래 기능을 잃게 됩니다.
만약 super()를 사용하지 않고 구현하고 싶다면 다음과 같이 직접 model.lm_head 내부 레이어에 접근해서 호출할 수는 있습니다.
'AI 일반 > 모델, 아키텍처, 구현' 카테고리의 다른 글
Model workflow (1) | 2025.01.14 |
---|---|
loss 값 저장 및 출력 (5) | 2025.01.14 |
Conv2d / ConvTranspose2d (0) | 2024.09.19 |
Fine Tuning (Transfer Learning의 한 종류) (1) | 2024.08.28 |
Feature Extraction (Transfer Learning의 한 종류) (0) | 2024.08.28 |