CLIPVisionTransformer / CLIPTextTransformer forward body
CLIPVisionTransformer takes pixel_values, applies the patch/position embeddings and the pre layer norm, passes the result through the encoder, and then keeps only the 0th (class) token of the last hidden state as the pooled output.

class CLIPVisionTransformer(nn.Module):
    def forward(
        self,
        pixel_values=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        hidden_states = self.embeddings(pixel_values)      # patch + position embeddings
        hidden_states = self.pre_layrnorm(hidden_states)   # sic: "pre_layrnorm" is the attribute name in the transformers source
        encoder_outputs = self.encoder(
            inputs_embeds=hidden_states,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        print('last hidden', encoder_outputs[0].shape)
        last_hidden_state = encoder_outputs[0]
        pooled_output = last_hidden_state[:, 0, :]         # keep only the 0th (class) token
        pooled_output = self.post_layernorm(pooled_output)
        print('pooled', pooled_output.shape)
        # (the return of the output object is omitted in this excerpt)
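To see these shapes in practice, the vision tower can be run on its own through CLIPVisionModel. This is a minimal sketch, not part of the original post; the openai/clip-vit-base-patch32 checkpoint and the random pixel_values tensor are illustrative assumptions standing in for a real preprocessed image.

import torch
from transformers import CLIPVisionModel

model = CLIPVisionModel.from_pretrained("openai/clip-vit-base-patch32")

# random tensor with the expected input shape for this checkpoint (3 x 224 x 224)
pixel_values = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    outputs = model(pixel_values=pixel_values)

print(outputs.last_hidden_state.shape)  # torch.Size([1, 50, 768]) -> 49 patches + 1 class token
print(outputs.pooler_output.shape)      # torch.Size([1, 768])     -> the 0th token after post_layernorm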
CLIPTextTransformer embeds input_ids, builds a causal attention mask (plus the expanded padding mask, if one is given), runs the encoder, applies the final layer norm, and then pools the hidden state at the EOT token position.

class CLIPTextTransformer(nn.Module):
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        position_ids=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        input_shape = input_ids.size()
        input_ids = input_ids.view(-1, input_shape[-1])
        hidden_states = self.embeddings(input_ids=input_ids, position_ids=position_ids)

        bsz, seq_len = input_shape
        # CLIP's text model uses a causal mask; prepare it here.
        causal_attention_mask = self._build_causal_attention_mask(bsz, seq_len).to(hidden_states.device)
        # expand attention_mask: [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
        if attention_mask is not None:
            attention_mask = _expand_mask(attention_mask, hidden_states.dtype)

        encoder_outputs = self.encoder(
            inputs_embeds=hidden_states,
            attention_mask=attention_mask,
            causal_attention_mask=causal_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        print('encoder', encoder_outputs)

        last_hidden_state = encoder_outputs[0]
        last_hidden_state = self.final_layer_norm(last_hidden_state)
        print('last', last_hidden_state.shape)

        # text_embeds.shape = [batch_size, sequence_length, transformer.width]
        # take features from the eot embedding (the eot token has the highest token id in each sequence)
        pooled_output = last_hidden_state[torch.arange(last_hidden_state.shape[0]), input_ids.argmax(dim=-1)]
        # (the return of the output object is omitted in this excerpt)
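The same EOT pooling can be checked end to end with CLIPTextModel. This is a minimal sketch, not part of the original post; the checkpoint name and the example sentence are illustrative assumptions. Because the <|endoftext|> token has the largest id in CLIP's vocabulary, input_ids.argmax(dim=-1) lands on the EOT position, and the vector taken there should match the model's pooler_output.

import torch
from transformers import CLIPTokenizer, CLIPTextModel

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
model = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32")

inputs = tokenizer("a photo of a cat", return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

last_hidden_state = outputs.last_hidden_state            # [1, seq_len, 512] for this checkpoint
eot_positions = inputs.input_ids.argmax(dim=-1)           # index of <|endoftext|> (largest token id)
pooled = last_hidden_state[torch.arange(last_hidden_state.shape[0]), eot_positions]

print(pooled.shape)                                       # torch.Size([1, 512])
print(torch.allclose(pooled, outputs.pooler_output))      # expected True: pooler_output is this EOT slice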