토니의 연습장

StableDiffusionPipeline 본문

비전 AI (VISION)/Stable Diffusion

StableDiffusionPipeline

bellmake 2024. 8. 23. 21:59

[ 참고 ]

sd_txt2img.py

from diffusers import StableDiffusionPipeline
import torch

model_id = "runwayml/stable-diffusion-v1-5"
pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16)
pipe = pipe.to("cuda")

prompt = "a photo of an astronaut riding a horse on mars"
image = pipe(prompt).images[0]

image.save("astronaut_rides_horse.png")

 

 

StableDiffusionPipeline

class StableDiffusionPipeline(
 
def __init__(
self,
vae: AutoencoderKL,
text_encoder: CLIPTextModel,
tokenizer: CLIPTokenizer,
unet: UNet2DConditionModel,
scheduler: KarrasDiffusionSchedulers,
safety_checker: StableDiffusionSafetyChecker,
feature_extractor: CLIPImageProcessor,
image_encoder: CLIPVisionModelWithProjection = None,
requires_safety_checker: bool = True,
):

 

*위 unet 부분이 Diffusion model 입니다.

 

• feature_extractor : (Safety-check) 생성된 이미지를 safety_checker를 통해 확인하기 위해 featu 
re를 얻어내는 모델입니다. 
• safety_checker : (Safety-check) Classification 모델을 활용해 생성된 이미지가 offensive하거나 
harmful한지를 찾아내는 SD의 모듈입니다. 
• scheduler : (Diffusion) encoded된 image latents를 denoise할 때 UNet과 함께 사용되는 Sched 
uler입니다. DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler 등을 활용할 수 있습니다. 
• unet : (Diffusion) encoded된 image latents를 denoise할 때 사용하는 UNet입니다. 
• text_encoder : (Control) Frozen된 text-encoder로, SD에서는 CLIP을 활용해 합니다. 
• tokenizer : (Control) 위 text_encoder에서 활용하는 tokenizer입니다. 
• vae : (Encoder) image를 latent representation으로 encode/decode하기 위한 Variational Aut 
o-Encoder(VAE) 모델입니다. 
• model_index.json : pipeline클래스를 통해 불러올 때, 어떤 클래스를 활용하며 어떤 모듈들을 불러올 
것인지를 대한 정보를 담은 config파일입니다.

 

 

__call__

image = pipe(prompt).images[0]

위의 pipe 부분 실행시에 StableDiffusionPipeline의 아래 __call__ 부분이 실행됩니다.

def __call__(
self,
prompt: Union[str, List[str]] = None,
...
...

 

StableDiffusionPipeline 클래스에서 __call__ 메서드가 실행되는 이유는 Python의 특별한 메서드인 __call__이 클래스 인스턴스를 함수처럼 호출할 수 있게 해주기 때문입니다. 즉, 클래스의 인스턴스가 함수처럼 동작하도록 만들고 싶을 때 __call__ 메서드를 정의합니다.

구체적으로 설명하면, 다음과 같은 코드가 있을 때:

python
코드 복사
pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16) image = pipe(prompt).images[0]

pipe(prompt) 부분이 실행될 때 실제로는 StableDiffusionPipeline 클래스의 __call__ 메서드가 호출됩니다.

왜 __call__이 호출되는가?

  1. 클래스 인스턴스 호출: pipe는 StableDiffusionPipeline 클래스의 인스턴스입니다. pipe(prompt)처럼 인스턴스에 대해 함수 호출 구문을 사용할 때, Python은 자동으로 해당 인스턴스의 __call__ 메서드를 찾고 실행합니다.
  2. 클래스 설계: StableDiffusionPipeline 클래스는 모델의 입력으로 주어진 프롬프트를 받아서, 그에 따른 출력을 생성하기 위해 __call__ 메서드를 오버라이드(재정의)한 것입니다. 이 메서드 내부에서 실제로 텍스트에서 이미지를 생성하는 전체 파이프라인이 실행됩니다.
  3. 유연한 인터페이스 제공: __call__ 메서드를 사용하면, 클래스 인스턴스를 마치 함수처럼 사용할 수 있게 되며, 이는 코드의 가독성과 사용성을 높이는 데 도움을 줍니다. 이렇게 하면 클래스를 사용하는 사람은 파이프라인을 간단히 호출하여 바로 결과를 얻을 수 있습니다.

따라서 pipe(prompt)를 실행할 때, Python은 pipe 객체의 __call__ 메서드를 찾아 실행하고, 이 과정에서 모델이 프롬프트를 기반으로 이미지를 생성하게 되는 것입니다.

 

__call__ 메서드 안에서 아래와 같이 encode_prompt에 넣게 됩니다.

이때, 결국 아래 encode_prompt 부분에서 CLIP이 사용되었다고 볼 수 있습니다.

 

__call__ : encode_prompt


prompt_embeds, negative_prompt_embeds = self.encode_prompt(
prompt,
device,
num_images_per_prompt,
self.do_classifier_free_guidance,
negative_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
lora_scale=lora_scale,
clip_skip=self.clip_skip,
)

 

이 때, negative_prompt_embeds는 이러이러한 것은 생성하지 말라고 알려주는 것과 관련된 부분입니다.

encode_prompt 내부에서는 아래와 같은 과정이 이루어집니다.

def encode_prompt(
 
text_inputs = self.tokenizer(
prompt,
padding="max_length",
max_length=self.tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
if clip_skip is None:
prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)
prompt_embeds = prompt_embeds[0]

 

__call__ : # 5. Prepare latent variables

# 5. Prepare latent variables
num_channels_latents = self.unet.config.in_channels
latents = self.prepare_latents(
batch_size * num_images_per_prompt,
num_channels_latents,
height,
width,
prompt_embeds.dtype,
device,
generator,
latents,
)
print(latents.shape)
from torchvision.utils import save_image
[save_image(latents[0,i], '/home/joseph/study/multimodal/ai_editor/latent_samples/latents_{}.png'.format(i)) for i in range(4)]

 

*위 save_image에서 range(4)로 설정한 이유는 unet의 input으로 들어가는 채널이 4이기 때문입니다.

 

 

[ 생성물 비교 ]

1. latent의 중간 생성물

2. vae의 decoder를 통한 최종 생성물(normalize 이전)

    : 본 모델의 vae는 embedding space 상에서 quantize가 잘 되도록 학습된 것으로서, 여기서는 흑백 구분이 매우 뚜렷합니다.

3. normalize된 image의 최종 생성물

 

'비전 AI (VISION) > Stable Diffusion' 카테고리의 다른 글

FLUX - LoRA  (2) 2025.03.15
LoRA (Low Rank Adaptation)  (1) 2024.08.28
Inpainting  (0) 2024.08.24
Stable Diffusion 이론  (0) 2024.08.23