토니의 연습장

SAM 코드 분석 본문

비전 AI (VISION)/Segmentation

SAM 코드 분석

bellmake 2025. 4. 7. 10:57
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import torch
from torch import nn
from torch.nn import functional as F

from typing import Any, Dict, List, Tuple

from .image_encoder import ImageEncoderViT
from .mask_decoder import MaskDecoder
from .prompt_encoder import PromptEncoder


class Sam(nn.Module):
    mask_threshold: float = 0.0
    image_format: str = "RGB"

    def __init__(
        self,
        image_encoder: ImageEncoderViT,
        prompt_encoder: PromptEncoder,
        mask_decoder: MaskDecoder,
        pixel_mean: List[float] = [123.675, 116.28, 103.53],
        pixel_std: List[float] = [58.395, 57.12, 57.375],
    ) -> None:
        """
        SAM predicts object masks from an image and input prompts.

        Arguments:
          image_encoder (ImageEncoderViT): The backbone used to encode the
            image into image embeddings that allow for efficient mask prediction.
          prompt_encoder (PromptEncoder): Encodes various types of input prompts.
          mask_decoder (MaskDecoder): Predicts masks from the image embeddings
            and encoded prompts.
          pixel_mean (list(float)): Mean values for normalizing pixels in the input image.
          pixel_std (list(float)): Std values for normalizing pixels in the input image.
        """
        super().__init__()
        self.image_encoder = image_encoder
        self.prompt_encoder = prompt_encoder
        self.mask_decoder = mask_decoder
        self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False)
        self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False)

    @property
    def device(self) -> Any:
        return self.pixel_mean.device

    @torch.no_grad()
    def forward(
        self,
        batched_input: List[Dict[str, Any]],
        multimask_output: bool,
    ) -> List[Dict[str, torch.Tensor]]:
        """
        Predicts masks end-to-end from provided images and prompts.
        If prompts are not known in advance, using SamPredictor is
        recommended over calling the model directly.

        Arguments:
          batched_input (list(dict)): A list over input images, each a
            dictionary with the following keys. A prompt key can be
            excluded if it is not present.
              'image': The image as a torch tensor in 3xHxW format,
                already transformed for input to the model.
              'original_size': (tuple(int, int)) The original size of
                the image before transformation, as (H, W).
              'point_coords': (torch.Tensor) Batched point prompts for
                this image, with shape BxNx2. Already transformed to the
                input frame of the model.
              'point_labels': (torch.Tensor) Batched labels for point prompts,
                with shape BxN.
              'boxes': (torch.Tensor) Batched box inputs, with shape Bx4.
                Already transformed to the input frame of the model.
              'mask_inputs': (torch.Tensor) Batched mask inputs to the model,
                in the form Bx1xHxW.
          multimask_output (bool): Whether the model should predict multiple
            disambiguating masks, or return a single mask.

        Returns:
          (list(dict)): A list over input images, where each element is
            as dictionary with the following keys.
              'masks': (torch.Tensor) Batched binary mask predictions,
                with shape BxCxHxW, where B is the number of input prompts,
                C is determined by multimask_output, and (H, W) is the
                original size of the image.
              'iou_predictions': (torch.Tensor) The model's predictions
                of mask quality, in shape BxC.
              'low_res_logits': (torch.Tensor) Low resolution logits with
                shape BxCxHxW, where H=W=256. Can be passed as mask input
                to subsequent iterations of prediction.
        """
        input_images = torch.stack([self.preprocess(x["image"]) for x in batched_input], dim=0)
        image_embeddings = self.image_encoder(input_images)

        outputs = []
        for image_record, curr_embedding in zip(batched_input, image_embeddings):
            if "point_coords" in image_record:
                points = (image_record["point_coords"], image_record["point_labels"])
            else:
                points = None
            sparse_embeddings, dense_embeddings = self.prompt_encoder(
                points=points,
                boxes=image_record.get("boxes", None),
                masks=image_record.get("mask_inputs", None),
            )
            low_res_masks, iou_predictions = self.mask_decoder(
                image_embeddings=curr_embedding.unsqueeze(0),
                image_pe=self.prompt_encoder.get_dense_pe(),
                sparse_prompt_embeddings=sparse_embeddings,
                dense_prompt_embeddings=dense_embeddings,
                multimask_output=multimask_output,
            )
            masks = self.postprocess_masks(
                low_res_masks,
                input_size=image_record["image"].shape[-2:],
                original_size=image_record["original_size"],
            )
            masks = masks > self.mask_threshold
            outputs.append(
                {
                    "masks": masks,
                    "iou_predictions": iou_predictions,
                    "low_res_logits": low_res_masks,
                }
            )
        return outputs

    def postprocess_masks(
        self,
        masks: torch.Tensor,
        input_size: Tuple[int, ...],
        original_size: Tuple[int, ...],
    ) -> torch.Tensor:
        """
        Remove padding and upscale masks to the original image size.

        Arguments:
          masks (torch.Tensor): Batched masks from the mask_decoder,
            in BxCxHxW format.
          input_size (tuple(int, int)): The size of the image input to the
            model, in (H, W) format. Used to remove padding.
          original_size (tuple(int, int)): The original size of the image
            before resizing for input to the model, in (H, W) format.

        Returns:
          (torch.Tensor): Batched masks in BxCxHxW format, where (H, W)
            is given by original_size.
        """
        masks = F.interpolate(
            masks,
            (self.image_encoder.img_size, self.image_encoder.img_size),
            mode="bilinear",
            align_corners=False,
        )
        masks = masks[..., : input_size[0], : input_size[1]]
        masks = F.interpolate(masks, original_size, mode="bilinear", align_corners=False)
        return masks

    def preprocess(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize pixel values and pad to a square input."""
        # Normalize colors
        x = (x - self.pixel_mean) / self.pixel_std

        # Pad
        h, w = x.shape[-2:]
        padh = self.image_encoder.img_size - h
        padw = self.image_encoder.img_size - w
        x = F.pad(x, (0, padw, 0, padh))
        return x
📌 클래스의 정의부 설명
python
복사
편집
class Sam(nn.Module):
    mask_threshold: float = 0.0
    image_format: str = "RGB"
이 클래스는 PyTorch의 기본 클래스인 nn.Module을 상속받아 구현되었습니다.

mask_threshold는 최종적으로 예측된 마스크를 이진화할 때 기준이 되는 임곗값이며, 기본값은 0.0입니다.

image_format은 입력 이미지의 색상 형식을 나타내며, 기본적으로 "RGB" 포맷을 사용합니다.

📌 클래스의 초기화 (__init__) 메서드 설명
python
복사
편집
def __init__(
    self,
    image_encoder: ImageEncoderViT,
    prompt_encoder: PromptEncoder,
    mask_decoder: MaskDecoder,
    pixel_mean: List[float] = [123.675, 116.28, 103.53],
    pixel_std: List[float] = [58.395, 57.12, 57.375],
) -> None:
클래스가 초기화될 때는 세 가지 주요 구성 요소를 입력받습니다:

image_encoder: 입력된 이미지를 임베딩(특성 벡터) 형태로 변환하는 역할을 합니다.

prompt_encoder: 사용자가 제공한 프롬프트(점, 박스, 마스크 등)를 모델이 이해할 수 있는 임베딩 형태로 변환합니다.

mask_decoder: 이미지 임베딩과 프롬프트 임베딩을 결합하여 최종 마스크를 예측합니다.

pixel_mean과 pixel_std는 입력 이미지의 픽셀값을 정규화할 때 사용하는 평균과 표준편차 값으로, 이미지 전처리 과정에서 사용됩니다.

python
복사
편집
super().__init__()
self.image_encoder = image_encoder
self.prompt_encoder = prompt_encoder
self.mask_decoder = mask_decoder
self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False)
self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False)
부모 클래스(nn.Module)의 초기화 메서드를 호출합니다.

각 구성 요소를 클래스 내부 속성으로 저장하여 이후 모델 연산에서 사용합니다.

정규화에 사용될 pixel_mean과 pixel_std를 버퍼로 등록하여, 모델과 함께 GPU 등 연산 디바이스로 자동으로 이동될 수 있도록 설정합니다.

📌 디바이스 속성 정의
python
복사
편집
@property
def device(self) -> Any:
    return self.pixel_mean.device
이 속성은 현재 모델이 위치한 연산 장치(GPU 또는 CPU)를 간편하게 확인할 수 있도록 돕습니다.

📌 forward 메서드 상세 설명 (마스크 예측)
python
복사
편집
@torch.no_grad()
def forward(
    self,
    batched_input: List[Dict[str, Any]],
    multimask_output: bool,
) -> List[Dict[str, torch.Tensor]]:
이 메서드는 모델을 이용해 실제로 입력 이미지와 프롬프트에서 최종 마스크를 예측하는 역할을 합니다.

추론 과정에서 사용되며, 학습 과정이 아니므로 가중치 업데이트가 없도록 @torch.no_grad()가 붙어 있습니다.

batched_input에는 여러 이미지와 프롬프트들이 리스트 형태로 제공됩니다.

multimask_output은 여러 마스크 후보를 출력할 것인지, 단일 마스크만 출력할 것인지 결정하는 변수입니다.

🔹 이미지 전처리와 임베딩 계산
python
복사
편집
input_images = torch.stack([self.preprocess(x["image"]) for x in batched_input], dim=0)
입력된 이미지 각각을 모델이 요구하는 형태로 전처리(self.preprocess)하여 정규화합니다.

이 전처리된 이미지들을 하나의 배치(batch)로 묶어줍니다.

python
복사
편집
image_embeddings = self.image_encoder(input_images)
전처리된 이미지들을 이미지 인코더(image_encoder)에 전달하여 이미지 임베딩(특성)을 얻습니다.

이 임베딩은 마스크 예측을 위한 중요한 정보를 담고 있습니다.

🔹 개별 이미지별 마스크 예측 루프
python
복사
편집
outputs = []
각 이미지에 대한 결과를 저장하기 위한 빈 리스트를 준비합니다.

python
복사
편집
for image_record, curr_embedding in zip(batched_input, image_embeddings):
각 이미지와 해당 이미지의 임베딩을 함께 반복하여 처리합니다.

python
복사
편집
    if "point_coords" in image_record:
        points = (image_record["point_coords"], image_record["point_labels"])
    else:
        points = None
이미지에 점 프롬프트가 존재하면, 점 좌표와 라벨을 튜플로 묶어 준비합니다.

없다면 점 프롬프트는 None으로 설정합니다.

python
복사
편집
    sparse_embeddings, dense_embeddings = self.prompt_encoder(
        points=points,
        boxes=image_record.get("boxes", None),
        masks=image_record.get("mask_inputs", None),
    )
점, 박스, 마스크 형태의 프롬프트를 프롬프트 인코더를 통해 각각 희소한 형태(sparse_embeddings)와 밀집된 형태(dense_embeddings)로 변환합니다.

🔹 마스크 디코딩 단계
python
복사
편집
    low_res_masks, iou_predictions = self.mask_decoder(
        image_embeddings=curr_embedding.unsqueeze(0),
        image_pe=self.prompt_encoder.get_dense_pe(),
        sparse_prompt_embeddings=sparse_embeddings,
        dense_prompt_embeddings=dense_embeddings,
        multimask_output=multimask_output,
    )
개별 이미지 임베딩과 프롬프트 임베딩을 마스크 디코더에 입력하여, 저해상도 마스크와 해당 마스크의 정확성 점수(IoU)를 예측합니다.

🔹 마스크 후처리 및 이진화 과정
python
복사
편집
    masks = self.postprocess_masks(
        low_res_masks,
        input_size=image_record["image"].shape[-2:],
        original_size=image_record["original_size"],
    )
저해상도 마스크를 원본 이미지 크기에 맞게 확대하여 후처리합니다.

python
복사
편집
    masks = masks > self.mask_threshold
마스크의 픽셀값을 설정한 임곗값(mask_threshold)과 비교하여 객체 영역(True)과 배경 영역(False)으로 이진화합니다.

🔹 결과 저장 및 반환
python
복사
편집
    outputs.append(
        {
            "masks": masks,
            "iou_predictions": iou_predictions,
            "low_res_logits": low_res_masks,
        }
    )
예측된 최종 마스크, 마스크의 정확성 점수, 저해상도의 마스크 예측값을 딕셔너리로 묶어 결과 리스트에 추가합니다.

python
복사
편집
return outputs
모든 이미지에 대한 예측 결과를 담은 리스트를 최종적으로 반환합니다.