토니의 연습장

SAM 코드 분석 본문

비전 AI (VISION)/Segmentation

SAM 코드 분석

bellmake 2025. 4. 7. 10:57
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import torch
from torch import nn
from torch.nn import functional as F

from typing import Any, Dict, List, Tuple

from .image_encoder import ImageEncoderViT
from .mask_decoder import MaskDecoder
from .prompt_encoder import PromptEncoder


class Sam(nn.Module):
    mask_threshold: float = 0.0
    image_format: str = "RGB"

    def __init__(
        self,
        image_encoder: ImageEncoderViT,
        prompt_encoder: PromptEncoder,
        mask_decoder: MaskDecoder,
        pixel_mean: List[float] = [123.675, 116.28, 103.53],
        pixel_std: List[float] = [58.395, 57.12, 57.375],
    ) -> None:
        """
        SAM predicts object masks from an image and input prompts.

        Arguments:
          image_encoder (ImageEncoderViT): The backbone used to encode the
            image into image embeddings that allow for efficient mask prediction.
          prompt_encoder (PromptEncoder): Encodes various types of input prompts.
          mask_decoder (MaskDecoder): Predicts masks from the image embeddings
            and encoded prompts.
          pixel_mean (list(float)): Mean values for normalizing pixels in the input image.
          pixel_std (list(float)): Std values for normalizing pixels in the input image.
        """
        super().__init__()
        self.image_encoder = image_encoder
        self.prompt_encoder = prompt_encoder
        self.mask_decoder = mask_decoder
        self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False)
        self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False)

    @property
    def device(self) -> Any:
        return self.pixel_mean.device

    @torch.no_grad()
    def forward(
        self,
        batched_input: List[Dict[str, Any]],
        multimask_output: bool,
    ) -> List[Dict[str, torch.Tensor]]:
        """
        Predicts masks end-to-end from provided images and prompts.
        If prompts are not known in advance, using SamPredictor is
        recommended over calling the model directly.

        Arguments:
          batched_input (list(dict)): A list over input images, each a
            dictionary with the following keys. A prompt key can be
            excluded if it is not present.
              'image': The image as a torch tensor in 3xHxW format,
                already transformed for input to the model.
              'original_size': (tuple(int, int)) The original size of
                the image before transformation, as (H, W).
              'point_coords': (torch.Tensor) Batched point prompts for
                this image, with shape BxNx2. Already transformed to the
                input frame of the model.
              'point_labels': (torch.Tensor) Batched labels for point prompts,
                with shape BxN.
              'boxes': (torch.Tensor) Batched box inputs, with shape Bx4.
                Already transformed to the input frame of the model.
              'mask_inputs': (torch.Tensor) Batched mask inputs to the model,
                in the form Bx1xHxW.
          multimask_output (bool): Whether the model should predict multiple
            disambiguating masks, or return a single mask.

        Returns:
          (list(dict)): A list over input images, where each element is
            as dictionary with the following keys.
              'masks': (torch.Tensor) Batched binary mask predictions,
                with shape BxCxHxW, where B is the number of input prompts,
                C is determined by multimask_output, and (H, W) is the
                original size of the image.
              'iou_predictions': (torch.Tensor) The model's predictions
                of mask quality, in shape BxC.
              'low_res_logits': (torch.Tensor) Low resolution logits with
                shape BxCxHxW, where H=W=256. Can be passed as mask input
                to subsequent iterations of prediction.
        """
        input_images = torch.stack([self.preprocess(x["image"]) for x in batched_input], dim=0)
        image_embeddings = self.image_encoder(input_images)

        outputs = []
        for image_record, curr_embedding in zip(batched_input, image_embeddings):
            if "point_coords" in image_record:
                points = (image_record["point_coords"], image_record["point_labels"])
            else:
                points = None
            sparse_embeddings, dense_embeddings = self.prompt_encoder(
                points=points,
                boxes=image_record.get("boxes", None),
                masks=image_record.get("mask_inputs", None),
            )
            low_res_masks, iou_predictions = self.mask_decoder(
                image_embeddings=curr_embedding.unsqueeze(0),
                image_pe=self.prompt_encoder.get_dense_pe(),
                sparse_prompt_embeddings=sparse_embeddings,
                dense_prompt_embeddings=dense_embeddings,
                multimask_output=multimask_output,
            )
            masks = self.postprocess_masks(
                low_res_masks,
                input_size=image_record["image"].shape[-2:],
                original_size=image_record["original_size"],
            )
            masks = masks > self.mask_threshold
            outputs.append(
                {
                    "masks": masks,
                    "iou_predictions": iou_predictions,
                    "low_res_logits": low_res_masks,
                }
            )
        return outputs

    def postprocess_masks(
        self,
        masks: torch.Tensor,
        input_size: Tuple[int, ...],
        original_size: Tuple[int, ...],
    ) -> torch.Tensor:
        """
        Remove padding and upscale masks to the original image size.

        Arguments:
          masks (torch.Tensor): Batched masks from the mask_decoder,
            in BxCxHxW format.
          input_size (tuple(int, int)): The size of the image input to the
            model, in (H, W) format. Used to remove padding.
          original_size (tuple(int, int)): The original size of the image
            before resizing for input to the model, in (H, W) format.

        Returns:
          (torch.Tensor): Batched masks in BxCxHxW format, where (H, W)
            is given by original_size.
        """
        masks = F.interpolate(
            masks,
            (self.image_encoder.img_size, self.image_encoder.img_size),
            mode="bilinear",
            align_corners=False,
        )
        masks = masks[..., : input_size[0], : input_size[1]]
        masks = F.interpolate(masks, original_size, mode="bilinear", align_corners=False)
        return masks

    def preprocess(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize pixel values and pad to a square input."""
        # Normalize colors
        x = (x - self.pixel_mean) / self.pixel_std

        # Pad
        h, w = x.shape[-2:]
        padh = self.image_encoder.img_size - h
        padw = self.image_encoder.img_size - w
        x = F.pad(x, (0, padw, 0, padh))
        return x
📌 클래스의 정의부 설명
python
복사
편집
class Sam(nn.Module):
    mask_threshold: float = 0.0
    image_format: str = "RGB"
이 클래스는 PyTorch의 기본 클래스인 nn.Module을 상속받아 구현되었습니다.

mask_threshold는 최종적으로 예측된 마스크를 이진화할 때 기준이 되는 임곗값이며, 기본값은 0.0입니다.

image_format은 입력 이미지의 색상 형식을 나타내며, 기본적으로 "RGB" 포맷을 사용합니다.

📌 클래스의 초기화 (__init__) 메서드 설명
python
복사
편집
def __init__(
    self,
    image_encoder: ImageEncoderViT,
    prompt_encoder: PromptEncoder,
    mask_decoder: MaskDecoder,
    pixel_mean: List[float] = [123.675, 116.28, 103.53],
    pixel_std: List[float] = [58.395, 57.12, 57.375],
) -> None:
클래스가 초기화될 때는 세 가지 주요 구성 요소를 입력받습니다:

image_encoder: 입력된 이미지를 임베딩(특성 벡터) 형태로 변환하는 역할을 합니다.

prompt_encoder: 사용자가 제공한 프롬프트(점, 박스, 마스크 등)를 모델이 이해할 수 있는 임베딩 형태로 변환합니다.

mask_decoder: 이미지 임베딩과 프롬프트 임베딩을 결합하여 최종 마스크를 예측합니다.

pixel_mean과 pixel_std는 입력 이미지의 픽셀값을 정규화할 때 사용하는 평균과 표준편차 값으로, 이미지 전처리 과정에서 사용됩니다.

python
복사
편집
super().__init__()
self.image_encoder = image_encoder
self.prompt_encoder = prompt_encoder
self.mask_decoder = mask_decoder
self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False)
self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False)
부모 클래스(nn.Module)의 초기화 메서드를 호출합니다.

각 구성 요소를 클래스 내부 속성으로 저장하여 이후 모델 연산에서 사용합니다.

정규화에 사용될 pixel_mean과 pixel_std를 버퍼로 등록하여, 모델과 함께 GPU 등 연산 디바이스로 자동으로 이동될 수 있도록 설정합니다.

📌 디바이스 속성 정의
python
복사
편집
@property
def device(self) -> Any:
    return self.pixel_mean.device
이 속성은 현재 모델이 위치한 연산 장치(GPU 또는 CPU)를 간편하게 확인할 수 있도록 돕습니다.

📌 forward 메서드 상세 설명 (마스크 예측)
python
복사
편집
@torch.no_grad()
def forward(
    self,
    batched_input: List[Dict[str, Any]],
    multimask_output: bool,
) -> List[Dict[str, torch.Tensor]]:
이 메서드는 모델을 이용해 실제로 입력 이미지와 프롬프트에서 최종 마스크를 예측하는 역할을 합니다.

추론 과정에서 사용되며, 학습 과정이 아니므로 가중치 업데이트가 없도록 @torch.no_grad()가 붙어 있습니다.

batched_input에는 여러 이미지와 프롬프트들이 리스트 형태로 제공됩니다.

multimask_output은 여러 마스크 후보를 출력할 것인지, 단일 마스크만 출력할 것인지 결정하는 변수입니다.

🔹 이미지 전처리와 임베딩 계산
python
복사
편집
input_images = torch.stack([self.preprocess(x["image"]) for x in batched_input], dim=0)
입력된 이미지 각각을 모델이 요구하는 형태로 전처리(self.preprocess)하여 정규화합니다.

이 전처리된 이미지들을 하나의 배치(batch)로 묶어줍니다.

python
복사
편집
image_embeddings = self.image_encoder(input_images)
전처리된 이미지들을 이미지 인코더(image_encoder)에 전달하여 이미지 임베딩(특성)을 얻습니다.

이 임베딩은 마스크 예측을 위한 중요한 정보를 담고 있습니다.

🔹 개별 이미지별 마스크 예측 루프
python
복사
편집
outputs = []
각 이미지에 대한 결과를 저장하기 위한 빈 리스트를 준비합니다.

python
복사
편집
for image_record, curr_embedding in zip(batched_input, image_embeddings):
각 이미지와 해당 이미지의 임베딩을 함께 반복하여 처리합니다.

python
복사
편집
    if "point_coords" in image_record:
        points = (image_record["point_coords"], image_record["point_labels"])
    else:
        points = None
이미지에 점 프롬프트가 존재하면, 점 좌표와 라벨을 튜플로 묶어 준비합니다.

없다면 점 프롬프트는 None으로 설정합니다.

python
복사
편집
    sparse_embeddings, dense_embeddings = self.prompt_encoder(
        points=points,
        boxes=image_record.get("boxes", None),
        masks=image_record.get("mask_inputs", None),
    )
점, 박스, 마스크 형태의 프롬프트를 프롬프트 인코더를 통해 각각 희소한 형태(sparse_embeddings)와 밀집된 형태(dense_embeddings)로 변환합니다.

🔹 마스크 디코딩 단계
python
복사
편집
    low_res_masks, iou_predictions = self.mask_decoder(
        image_embeddings=curr_embedding.unsqueeze(0),
        image_pe=self.prompt_encoder.get_dense_pe(),
        sparse_prompt_embeddings=sparse_embeddings,
        dense_prompt_embeddings=dense_embeddings,
        multimask_output=multimask_output,
    )
개별 이미지 임베딩과 프롬프트 임베딩을 마스크 디코더에 입력하여, 저해상도 마스크와 해당 마스크의 정확성 점수(IoU)를 예측합니다.

🔹 마스크 후처리 및 이진화 과정
python
복사
편집
    masks = self.postprocess_masks(
        low_res_masks,
        input_size=image_record["image"].shape[-2:],
        original_size=image_record["original_size"],
    )
저해상도 마스크를 원본 이미지 크기에 맞게 확대하여 후처리합니다.

python
복사
편집
    masks = masks > self.mask_threshold
마스크의 픽셀값을 설정한 임곗값(mask_threshold)과 비교하여 객체 영역(True)과 배경 영역(False)으로 이진화합니다.

🔹 결과 저장 및 반환
python
복사
편집
    outputs.append(
        {
            "masks": masks,
            "iou_predictions": iou_predictions,
            "low_res_logits": low_res_masks,
        }
    )
예측된 최종 마스크, 마스크의 정확성 점수, 저해상도의 마스크 예측값을 딕셔너리로 묶어 결과 리스트에 추가합니다.

python
복사
편집
return outputs
모든 이미지에 대한 예측 결과를 담은 리스트를 최종적으로 반환합니다.

'비전 AI (VISION) > Segmentation' 카테고리의 다른 글

SAM3 video segmentation  (0) 2025.12.01