토니의 연습장

출력 정밀도 변환 클래스 정의 (CastOutputToFloat) 본문

AI 일반/모델, 아키텍처, 구현

출력 정밀도 변환 클래스 정의 (CastOutputToFloat)

bellmake 2024. 11. 14. 15:33
class CastOutputToFloat(nn.Sequential):
    def forward(self, x): return super().forward(x).to(torch.float32)
model.lm_head = CastOutputToFloat(model.lm_head)

 

모델의 출력이 32비트 정밀도로 반환되도록 하기 위해 CastOutputToFloat 클래스를 정의하고, lm_head의 출력을 32비트로 변환하여 안정적인 출력이 가능하도록 합니다.

이 코드를 통해 모델은 메모리 사용량을 최소화하면서 8비트 정밀도로 로드되고, 파라미터를 동결하여 일부 파라미터만 훈련이 가능하도록 설정되었습니다.

위 코드에서 super().forward(x) 호출의 이유는 CastOutputToFloat 클래스가 nn.Sequential 클래스를 상속받고 있기 때문입니다. super().forward(x)는 nn.Sequential 클래스의 forward 메서드를 호출하여, model.lm_head에 정의된 연산이 먼저 수행되도록 합니다. 이렇게 함으로써 원래의 lm_head 연산 결과를 얻고, 그 결과를 torch.float32로 캐스팅하여 반환하게 됩니다.

구체적으로 설명하자면:

  • nn.Sequential은 여러 레이어를 순차적으로 실행하는 모듈이므로, model.lm_head의 기존 연산을 그대로 수행하면서도 그 결과를 추가로 변환할 수 있도록 합니다.
  • super().forward(x)는 nn.Sequential에서 정의된 원래의 forward 메서드를 호출하여, model.lm_head의 모든 연산을 거친 출력값을 반환합니다.
  • 그 다음 .to(torch.float32)를 적용하여, model.lm_head의 출력값을 32비트 부동소수점(float32) 형식으로 변환합니다.

즉, super().forward(x) 호출을 통해 CastOutputToFloat 클래스에서 model.lm_head의 기존 연산을 유지하면서도 최종 결과를 원하는 데이터 형식으로 변환할 수 있게 됩니다.

super()를 호출하지 않고도 forward 메서드를 직접 호출할 수는 있지만, 이 경우에는 nn.Sequential의 기본 동작이 무시되므로, 원래의 model.lm_head의 모든 레이어가 순차적으로 실행되지 않을 수 있습니다. 대신 forward 메서드를 직접 구현해야 하므로 코드가 복잡해질 수 있습니다.

예를 들어 model.lm_head가 하나의 레이어라면 직접 호출이 가능합니다. 하지만 nn.Sequential이 여러 레이어를 포함하고 있다면, super().forward(x)를 호출하지 않으면 Sequential이 각 레이어를 순차적으로 호출하는 원래 기능을 잃게 됩니다.

만약 super()를 사용하지 않고 구현하고 싶다면 다음과 같이 직접 model.lm_head 내부 레이어에 접근해서 호출할 수는 있습니다.