AI 일반/논문, 구현
ViT (Vision Transformer) 논문 구현하기
bellmake
2024. 8. 28. 14:57
[ 참고 ] helper_functions.py
"""
A series of helper functions used throughout the course.
If a function gets defined once and could be used over and over, it'll go in here.
"""
import torch
import matplotlib.pyplot as plt
import numpy as np
from torch import nn
import os
import zipfile
from pathlib import Path
import requests
# Walk through an image classification directory and find out how many files (images)
# are in each subdirectory.
import os
def walk_through_dir(dir_path):
"""
Walks through dir_path returning its contents.
Args:
dir_path (str): target directory
Returns:
A print out of:
number of subdiretories in dir_path
number of images (files) in each subdirectory
name of each subdirectory
"""
for dirpath, dirnames, filenames in os.walk(dir_path):
print(f"There are {len(dirnames)} directories and {len(filenames)} images in '{dirpath}'.")
def plot_decision_boundary(model: torch.nn.Module, X: torch.Tensor, y: torch.Tensor):
"""Plots decision boundaries of model predicting on X in comparison to y.
Source - https://madewithml.com/courses/foundations/neural-networks/ (with modifications)
"""
# Put everything to CPU (works better with NumPy + Matplotlib)
model.to("cpu")
X, y = X.to("cpu"), y.to("cpu")
# Setup prediction boundaries and grid
x_min, x_max = X[:, 0].min() - 0.1, X[:, 0].max() + 0.1
y_min, y_max = X[:, 1].min() - 0.1, X[:, 1].max() + 0.1
xx, yy = np.meshgrid(np.linspace(x_min, x_max, 101), np.linspace(y_min, y_max, 101))
# Make features
X_to_pred_on = torch.from_numpy(np.column_stack((xx.ravel(), yy.ravel()))).float()
# Make predictions
model.eval()
with torch.inference_mode():
y_logits = model(X_to_pred_on)
# Test for multi-class or binary and adjust logits to prediction labels
if len(torch.unique(y)) > 2:
y_pred = torch.softmax(y_logits, dim=1).argmax(dim=1) # mutli-class
else:
y_pred = torch.round(torch.sigmoid(y_logits)) # binary
# Reshape preds and plot
y_pred = y_pred.reshape(xx.shape).detach().numpy()
plt.contourf(xx, yy, y_pred, cmap=plt.cm.RdYlBu, alpha=0.7)
plt.scatter(X[:, 0], X[:, 1], c=y, s=40, cmap=plt.cm.RdYlBu)
plt.xlim(xx.min(), xx.max())
plt.ylim(yy.min(), yy.max())
# Plot linear data or training and test and predictions (optional)
def plot_predictions(
train_data, train_labels, test_data, test_labels, predictions=None
):
"""
Plots linear training data and test data and compares predictions.
"""
plt.figure(figsize=(10, 7))
# Plot training data in blue
plt.scatter(train_data, train_labels, c="b", s=4, label="Training data")
# Plot test data in green
plt.scatter(test_data, test_labels, c="g", s=4, label="Testing data")
if predictions is not None:
# Plot the predictions in red (predictions were made on the test data)
plt.scatter(test_data, predictions, c="r", s=4, label="Predictions")
# Show the legend
plt.legend(prop={"size": 14})
# Calculate accuracy (a classification metric)
def accuracy_fn(y_true, y_pred):
"""Calculates accuracy between truth labels and predictions.
Args:
y_true (torch.Tensor): Truth labels for predictions.
y_pred (torch.Tensor): Predictions to be compared to predictions.
Returns:
[torch.float]: Accuracy value between y_true and y_pred, e.g. 78.45
"""
correct = torch.eq(y_true, y_pred).sum().item()
acc = (correct / len(y_pred)) * 100
return acc
def print_train_time(start, end, device=None):
"""Prints difference between start and end time.
Args:
start (float): Start time of computation (preferred in timeit format).
end (float): End time of computation.
device ([type], optional): Device that compute is running on. Defaults to None.
Returns:
float: time between start and end in seconds (higher is longer).
"""
total_time = end - start
print(f"\nTrain time on {device}: {total_time:.3f} seconds")
return total_time
# Plot loss curves of a model
def plot_loss_curves(results):
"""Plots training curves of a results dictionary.
Args:
results (dict): dictionary containing list of values, e.g.
{"train_loss": [...],
"train_acc": [...],
"test_loss": [...],
"test_acc": [...]}
"""
loss = results["train_loss"]
test_loss = results["test_loss"]
accuracy = results["train_acc"]
test_accuracy = results["test_acc"]
epochs = range(len(results["train_loss"]))
plt.figure(figsize=(15, 7))
# Plot loss
plt.subplot(1, 2, 1)
plt.plot(epochs, loss, label="train_loss")
plt.plot(epochs, test_loss, label="test_loss")
plt.title("Loss")
plt.xlabel("Epochs")
plt.legend()
# Plot accuracy
plt.subplot(1, 2, 2)
plt.plot(epochs, accuracy, label="train_accuracy")
plt.plot(epochs, test_accuracy, label="test_accuracy")
plt.title("Accuracy")
plt.xlabel("Epochs")
plt.legend()
# Pred and plot image function from notebook 04
# See creation: https://www.learnpytorch.io/04_pytorch_custom_datasets/#113-putting-custom-image-prediction-together-building-a-function
from typing import List
import torchvision
def pred_and_plot_image(
model: torch.nn.Module,
image_path: str,
class_names: List[str] = None,
transform=None,
device: torch.device = "cuda" if torch.cuda.is_available() else "cpu",
):
"""Makes a prediction on a target image with a trained model and plots the image.
Args:
model (torch.nn.Module): trained PyTorch image classification model.
image_path (str): filepath to target image.
class_names (List[str], optional): different class names for target image. Defaults to None.
transform (_type_, optional): transform of target image. Defaults to None.
device (torch.device, optional): target device to compute on. Defaults to "cuda" if torch.cuda.is_available() else "cpu".
Returns:
Matplotlib plot of target image and model prediction as title.
Example usage:
pred_and_plot_image(model=model,
image="some_image.jpeg",
class_names=["class_1", "class_2", "class_3"],
transform=torchvision.transforms.ToTensor(),
device=device)
"""
# 1. Load in image and convert the tensor values to float32
target_image = torchvision.io.read_image(str(image_path)).type(torch.float32)
# 2. Divide the image pixel values by 255 to get them between [0, 1]
target_image = target_image / 255.0
# 3. Transform if necessary
if transform:
target_image = transform(target_image)
# 4. Make sure the model is on the target device
model.to(device)
# 5. Turn on model evaluation mode and inference mode
model.eval()
with torch.inference_mode():
# Add an extra dimension to the image
target_image = target_image.unsqueeze(dim=0)
# Make a prediction on image with an extra dimension and send it to the target device
target_image_pred = model(target_image.to(device))
# 6. Convert logits -> prediction probabilities (using torch.softmax() for multi-class classification)
target_image_pred_probs = torch.softmax(target_image_pred, dim=1)
# 7. Convert prediction probabilities -> prediction labels
target_image_pred_label = torch.argmax(target_image_pred_probs, dim=1)
# 8. Plot the image alongside the prediction and prediction probability
plt.imshow(
target_image.squeeze().permute(1, 2, 0)
) # make sure it's the right size for matplotlib
if class_names:
title = f"Pred: {class_names[target_image_pred_label.cpu()]} | Prob: {target_image_pred_probs.max().cpu():.3f}"
else:
title = f"Pred: {target_image_pred_label} | Prob: {target_image_pred_probs.max().cpu():.3f}"
plt.title(title)
plt.axis(False)
def set_seeds(seed: int=42):
"""Sets random sets for torch operations.
Args:
seed (int, optional): Random seed to set. Defaults to 42.
"""
# Set the seed for general torch operations
torch.manual_seed(seed)
# Set the seed for CUDA torch operations (ones that happen on the GPU)
torch.cuda.manual_seed(seed)
def download_data(source: str,
destination: str,
remove_source: bool = True) -> Path:
"""Downloads a zipped dataset from source and unzips to destination.
Args:
source (str): A link to a zipped file containing data.
destination (str): A target directory to unzip data to.
remove_source (bool): Whether to remove the source after downloading and extracting.
Returns:
pathlib.Path to downloaded data.
Example usage:
download_data(source="https://github.com/mrdbourke/pytorch-deep-learning/raw/main/data/pizza_steak_sushi.zip",
destination="pizza_steak_sushi")
"""
# Setup path to data folder
data_path = Path("data/")
image_path = data_path / destination
# If the image folder doesn't exist, download it and prepare it...
if image_path.is_dir():
print(f"[INFO] {image_path} directory exists, skipping download.")
else:
print(f"[INFO] Did not find {image_path} directory, creating one...")
image_path.mkdir(parents=True, exist_ok=True)
# Download pizza, steak, sushi data
target_file = Path(source).name
with open(data_path / target_file, "wb") as f:
request = requests.get(source)
print(f"[INFO] Downloading {target_file} from {source}...")
f.write(request.content)
# Unzip pizza, steak, sushi data
with zipfile.ZipFile(data_path / target_file, "r") as zip_ref:
print(f"[INFO] Unzipping {target_file} data...")
zip_ref.extractall(image_path)
# Remove .zip file
if remove_source:
os.remove(data_path / target_file)
return image_path
[구현 코드 ] vit_replication.ipynb
In [ ]:
try:
import torch
import torchvision
print(f"torch version: {torch.__version__}")
print(f"torchvision version: {torchvision.__version__}")
assert int(torch.__version__.split(".")[0])>1 or (int(torch.__version__.split(".")[0])>2 and int(torch.__version__.split(".")[1]) >= 12), "torch version should be 1.12+"
assert int(torchvision.__version__.split(".")[1]) >= 13, "torchvision version should be 0.13+"
# print(f"torch version: {torch.__version__}")
# print(f"torchvision version: {torchvision.__version__}")
except:
print(f"[INFO] torch/torchvision versions not as required, installing nightly versions.")
!pip3 install -U torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu113
import torch
import torchvision
print(f"torch version: {torch.__version__}")
print(f"torchvision version: {torchvision.__version__}")
torch version: 2.3.0
torchvision version: 0.18.0
In [ ]:
# Continue with regular imports
import matplotlib.pyplot as plt
import torch
import torchvision
from torch import nn
from torchvision import transforms
# Try to get torchinfo, install it if it doesn't work
try:
from torchinfo import summary
except:
print("[INFO] Couldn't find torchinfo... installing it.")
!pip install -q torchinfo
from torchinfo import summary
# Try to import the going_modular directory, download it from GitHub if it doesn't work
try:
from going_modular.going_modular import data_setup, engine
from helper_functions import download_data, set_seeds, plot_loss_curves
except:
# Get the going_modular scripts
print("[INFO] Couldn't find going_modular or helper_functions scripts... downloading them from GitHub.")
!git clone https://github.com/mrdbourke/pytorch-deep-learning
!mv pytorch-deep-learning/going_modular .
!mv pytorch-deep-learning/helper_functions.py . # get the helper_functions.py script
!rm -rf pytorch-deep-learning
from going_modular.going_modular import data_setup, engine
from helper_functions import download_data, set_seeds, plot_loss_curves
/home/joseph/miniconda3/envs/clip/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html
from .autonotebook import tqdm as notebook_tqdm
In [ ]:
device = "cuda" if torch.cuda.is_available() else "cpu"
device
Out[ ]:
'cuda'
In [ ]:
# Download pizza, steak, sushi images from GitHub
image_path = download_data(source="https://github.com/mrdbourke/pytorch-deep-learning/raw/main/data/pizza_steak_sushi.zip",
destination="pizza_steak_sushi")
image_path
[INFO] data/pizza_steak_sushi directory exists, skipping download.
Out[ ]:
PosixPath('data/pizza_steak_sushi')
In [ ]:
# Setup directory paths to train and test images
train_dir = image_path / "train"
test_dir = image_path / "test"
In [ ]:
# Create image size (from Table 3 in the ViT paper)
IMG_SIZE = 224 # comes from Table 3 of the ViT paper
# Create transform pipeline manually
manual_transforms = transforms.Compose([
transforms.Resize((IMG_SIZE, IMG_SIZE)),
transforms.ToTensor(),
])
print(f"Manually created transforms: {manual_transforms}")
Manually created transforms: Compose(
Resize(size=(224, 224), interpolation=bilinear, max_size=None, antialias=True)
ToTensor()
)
In [ ]:
# Set the batch size
BATCH_SIZE = 32 # this is lower than the ViT paper (4096) but it's because we're starting small # we can always scale up later
# Create data loaders
train_dataloader, test_dataloader, class_names = data_setup.create_dataloaders(
train_dir=train_dir,
test_dir=test_dir,
transform=manual_transforms,
batch_size=BATCH_SIZE
)
len(train_dataloader), len(test_dataloader), class_names
Out[ ]:
(8, 3, ['pizza', 'steak', 'sushi'])
In [ ]:
# Get a batch of images
image_batch, label_batch = next(iter(train_dataloader))
# Get a single image from the batch
image, label = image_batch[0], label_batch[0]
# View the batch shapes
image.shape, label
Out[ ]:
(torch.Size([3, 224, 224]), tensor(1))
In [ ]:
# Plot image with matplotlib
plt.imshow(image.permute(1, 2, 0)) # rearrange image dimensions to suit matplotlib [color_channels, height, width] -> [height, width, color_channels]
plt.title(class_names[label])
plt.axis(False);
In [ ]:
# Create example values
height = 224 # H ("The training resolution is 224.")
width = 224 # W
color_channels = 3 # C
patch_size = 16 # P
# Calculate N (number of patches)
number_of_patches = int((height*width)/patch_size**2)
number_of_patches
Out[ ]:
196
In [ ]:
# Input shape (this is the size of a single image)
embedding_layer_input_shape = (height, width, color_channels)
# Output shape
embedding_layer_output_shape = (number_of_patches, patch_size**2 * color_channels)
print(f"Input shape (single 2D image): {embedding_layer_input_shape}")
print(f"Output shape (single 1D sequence of patches): {embedding_layer_output_shape} -> (number_of_patches, embedding_dimension)") # 2D image flattened into patches = 1D sequence of patches
Input shape (single 2D image): (224, 224, 3)
Output shape (single 1D sequence of patches): (196, 768) -> (number_of_patches, embedding_dimension)
In [ ]:
# Change image shape to be compatible with matplotlib (color_channels, height, width) -> (height, width, color_channels)
image_permuted = image.permute(1, 2, 0)
# Index to plot the top row of patched pixels
patch_size = 16
plt.figure(figsize=(patch_size, patch_size))
plt.imshow(image_permuted[:patch_size,:,:])
Out[ ]:
<matplotlib.image.AxesImage at 0x78d724efaf10>
In [ ]:
# Setup hyperparameters and make sure img_size and patch_size are compatible
img_size = 224
patch_size = 16
num_patches = img_size/patch_size
assert img_size % patch_size == 0, "Image size must be divisible by patch size"
# Create a series of subplots
fig, axs = plt.subplots(nrows=1,
ncols=img_size//patch_size, # one column for each patch
sharex=True, sharey=True,
figsize=(patch_size, patch_size))
# Iterate through number of patches in the top row
for i, patch in enumerate(range(0, img_size, patch_size)):
axs[i].imshow(image_permuted[:patch_size, patch:patch+patch_size, :]);
axs[i].set_xlabel(i+1) # set the patch label
axs[i].set_xticks([])
# axs[i].set_yticks([])
# keep height index constant, alter the width index
In [ ]:
# Setup hyperparameters and make sure img_size and patch_size are compatible
img_size = 224
patch_size = 16
num_patches = img_size/patch_size
assert img_size % patch_size == 0, "Image size must be divisible by patch size"
print(f"Number of patches per row: {num_patches}\
\nNumber of patches per column: {num_patches}\
\nTotal patches : {num_patches*num_patches}\
\nPatch size: {patch_size} pixels x {patch_size} pixels")
# Create a series of subplots
fig, axs = plt.subplots(nrows=img_size//patch_size,
ncols=img_size//patch_size,
figsize=(num_patches, num_patches),
sharex=True,
sharey=True)
# need int not float
# Loop through height and width of image
for i, patch_height in enumerate(range(0, img_size, patch_size)): # iterate through height
for j, patch_width in enumerate(range(0, img_size, patch_size)): # iterate through width
# Plot the permuted image patch (image_permuted -> (Height, Width, Color Channels))
axs[i, j].imshow(image_permuted[patch_height:patch_height+patch_size, # iterate through height
patch_width:patch_width+patch_size, # iterate through width
:])# get all color channels
# Set up label information, remove the ticks for clarity and set labels to outside
axs[i, j].set_ylabel(i+1,
rotation="horizontal",
horizontalalignment="right",
verticalalignment="center")
axs[i, j].set_xlabel(j+1)
axs[i, j].set_xticks([])
axs[i, j].set_yticks([])
axs[i, j].label_outer()
# Set a super title
fig.suptitle(f"{class_names[label]} -> Patchified", fontsize=14)
plt.show()
Number of patches per row: 14.0
Number of patches per column: 14.0
Total patches : 196.0
Patch size: 16 pixels x 16 pixels
In [ ]:
# Creating image patches with torch.nn.Conv2d()
from torch import nn
patch_size = 16
conv2d = nn.Conv2d(in_channels=3,
out_channels=768,
kernel_size=patch_size,
stride=patch_size,
padding=0)
conv2d
Out[ ]:
Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
In [ ]:
plt.imshow(image.permute(1,2,0))
plt.title(class_names[label])
plt.axis(False);
In [ ]:
print(image.shape)
image_out_of_conv = conv2d(image.unsqueeze(0))
print(image_out_of_conv.shape)
torch.Size([3, 224, 224])
torch.Size([1, 768, 14, 14])
In [ ]:
image_out_of_conv.requires_grad
Out[ ]:
True
In [ ]:
import random
random_indexes = random.sample(range(0,768), k=5)
print(f"Showing random convlutional feature maps from indexes: {random_indexes}") # pick 5 numbers between 0 and the embedding size
fig, axs = plt.subplots(nrows=1,
ncols=5,
figsize=(12,12))
for i, idx in enumerate(random_indexes):
image_conv_feature_map = image_out_of_conv[:,idx, :, :]
axs[i].imshow(image_conv_feature_map.squeeze().detach().numpy())
axs[i].set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
Showing random convlutional feature maps from indexes: [234, 285, 767, 124, 470]
In [ ]:
plt.xticks([]), plt.yticks([])
# single feature map (in tensor form)
single_feature_map = image_out_of_conv[:,0,:,:]
single_feature_map, single_feature_map.requires_grad, plt.imshow(image_conv_feature_map.squeeze().detach().numpy())
Out[ ]:
(tensor([[[ 0.4720, 0.4569, 0.3627, 0.3329, 0.3450, 0.1972, 0.1450,
0.2124, 0.1681, 0.1995, 0.0030, 0.0150, 0.0408, -0.0208],
[ 0.3480, 0.2649, 0.3930, 0.3091, 0.1293, 0.1407, 0.1049,
0.0949, 0.0869, 0.0535, 0.0806, 0.2279, -0.0084, 0.0353],
[ 0.2417, 0.3728, 0.2824, 0.2928, 0.2926, 0.0080, -0.0288,
0.0794, 0.0931, 0.0670, 0.0492, 0.0297, 0.0258, 0.0994],
[ 0.1512, 0.1639, 0.2057, 0.3360, 0.1675, 0.1915, 0.3541,
0.3075, 0.1526, 0.0591, 0.0985, 0.1104, 0.0998, 0.1399],
[ 0.0700, 0.0468, 0.0447, 0.0440, 0.1784, 0.0653, 0.1721,
0.2511, 0.1725, 0.0268, 0.0608, -0.0088, 0.1059, 0.1353],
[ 0.0704, 0.1071, 0.0472, -0.0375, 0.1636, 0.1010, 0.1597,
0.1485, 0.1692, 0.2489, 0.0979, 0.1014, 0.3265, 0.0729],
[ 0.0567, 0.0700, 0.0094, 0.1024, 0.1629, 0.0403, 0.1480,
0.1586, 0.0038, 0.2564, 0.0658, 0.1980, 0.2527, 0.1586],
[ 0.0587, 0.1058, 0.0719, -0.0503, 0.0370, 0.1696, 0.1577,
0.0563, 0.0141, 0.0112, 0.1327, 0.0717, 0.2524, 0.1221],
[ 0.0951, 0.0835, 0.0319, 0.0269, 0.0532, -0.0198, 0.0477,
0.1058, 0.0566, 0.0901, 0.2299, 0.1150, 0.1273, 0.2273],
[ 0.0891, 0.0427, 0.0592, 0.0691, 0.0634, -0.0264, -0.0023,
0.0232, 0.2197, 0.0818, 0.2860, 0.0940, 0.0953, 0.0649],
[ 0.0685, 0.0777, 0.0261, 0.0447, 0.0460, 0.1362, 0.2436,
-0.0026, 0.0223, 0.0761, 0.2044, 0.1443, 0.0535, 0.0100],
[ 0.0753, 0.0885, 0.1101, 0.0365, 0.0654, 0.0582, 0.0585,
0.0168, 0.1001, 0.0454, 0.1100, 0.0717, 0.0976, 0.0365],
[ 0.0231, 0.1516, 0.1113, 0.1454, 0.1218, 0.1493, 0.0615,
0.0303, 0.0728, 0.0806, 0.1288, 0.0632, 0.0545, 0.0579],
[ 0.1339, 0.0176, 0.2339, 0.2434, 0.1646, 0.1614, 0.1995,
0.2260, 0.1002, 0.1222, 0.1036, 0.0652, 0.0572, 0.0574]]],
grad_fn=<SliceBackward0>),
True,
<matplotlib.image.AxesImage at 0x78d714d3bc50>)
In [ ]:
# Create flatten layer
from torch import nn
flatten_layer = nn.Flatten(start_dim=2, # flatten feature_map_height (dimension 2)
end_dim=3) # flatten feature_map_width (dimension 3)
flattened_image_out_of_conv = flatten_layer(image_out_of_conv)
flattened_image_out_of_conv.shape
Out[ ]:
torch.Size([1, 768, 196])
In [ ]:
# 1. View single image
plt.imshow(image.permute(1, 2, 0)) # adjust for matplotlib
plt.title(class_names[label])
plt.axis(False);
print(f"Original image shape: {image.shape}")
# 2. Turn image into feature maps
image_out_of_conv = conv2d(image.unsqueeze(0)) # add batch dimension to avoid shape errors
print(f"Image feature map shape: {image_out_of_conv.shape}")
# 3. Flatten the feature maps
image_out_of_conv_flattened = flatten_layer(image_out_of_conv)
print(f"Flattened image feature map shape: {image_out_of_conv_flattened.shape}")
Original image shape: torch.Size([3, 224, 224])
Image feature map shape: torch.Size([1, 768, 14, 14])
Flattened image feature map shape: torch.Size([1, 768, 196])
In [ ]:
# Get flattened image patch embeddings in right shape
image_out_of_conv_flattened_reshaped = image_out_of_conv_flattened.permute(0, 2, 1) # [batch_size, P^2•C, N] -> [batch_size, N, P^2•C]
print(f"Patch embedding sequence shape: {image_out_of_conv_flattened_reshaped.shape} -> [batch_size, num_patches, embedding_size]")
Patch embedding sequence shape: torch.Size([1, 196, 768]) -> [batch_size, num_patches, embedding_size]
In [ ]:
# Get a single flattened feature map
single_flattened_feature_map = image_out_of_conv_flattened_reshaped[:, :, 0] # index: (batch_size, number_of_patches, embedding_dimension)
# Plot the flattened feature map visually
plt.figure(figsize=(22, 22))
plt.imshow(single_flattened_feature_map.detach().numpy())
plt.title(f"Flattened feature map shape: {single_flattened_feature_map.shape}")
plt.axis(False);
In [ ]:
# See the flattened feature map as a tensor
single_flattened_feature_map, single_flattened_feature_map.requires_grad, single_flattened_feature_map.shape
Out[ ]:
(tensor([[ 0.4720, 0.4569, 0.3627, 0.3329, 0.3450, 0.1972, 0.1450, 0.2124,
0.1681, 0.1995, 0.0030, 0.0150, 0.0408, -0.0208, 0.3480, 0.2649,
0.3930, 0.3091, 0.1293, 0.1407, 0.1049, 0.0949, 0.0869, 0.0535,
0.0806, 0.2279, -0.0084, 0.0353, 0.2417, 0.3728, 0.2824, 0.2928,
0.2926, 0.0080, -0.0288, 0.0794, 0.0931, 0.0670, 0.0492, 0.0297,
0.0258, 0.0994, 0.1512, 0.1639, 0.2057, 0.3360, 0.1675, 0.1915,
0.3541, 0.3075, 0.1526, 0.0591, 0.0985, 0.1104, 0.0998, 0.1399,
0.0700, 0.0468, 0.0447, 0.0440, 0.1784, 0.0653, 0.1721, 0.2511,
0.1725, 0.0268, 0.0608, -0.0088, 0.1059, 0.1353, 0.0704, 0.1071,
0.0472, -0.0375, 0.1636, 0.1010, 0.1597, 0.1485, 0.1692, 0.2489,
0.0979, 0.1014, 0.3265, 0.0729, 0.0567, 0.0700, 0.0094, 0.1024,
0.1629, 0.0403, 0.1480, 0.1586, 0.0038, 0.2564, 0.0658, 0.1980,
0.2527, 0.1586, 0.0587, 0.1058, 0.0719, -0.0503, 0.0370, 0.1696,
0.1577, 0.0563, 0.0141, 0.0112, 0.1327, 0.0717, 0.2524, 0.1221,
0.0951, 0.0835, 0.0319, 0.0269, 0.0532, -0.0198, 0.0477, 0.1058,
0.0566, 0.0901, 0.2299, 0.1150, 0.1273, 0.2273, 0.0891, 0.0427,
0.0592, 0.0691, 0.0634, -0.0264, -0.0023, 0.0232, 0.2197, 0.0818,
0.2860, 0.0940, 0.0953, 0.0649, 0.0685, 0.0777, 0.0261, 0.0447,
0.0460, 0.1362, 0.2436, -0.0026, 0.0223, 0.0761, 0.2044, 0.1443,
0.0535, 0.0100, 0.0753, 0.0885, 0.1101, 0.0365, 0.0654, 0.0582,
0.0585, 0.0168, 0.1001, 0.0454, 0.1100, 0.0717, 0.0976, 0.0365,
0.0231, 0.1516, 0.1113, 0.1454, 0.1218, 0.1493, 0.0615, 0.0303,
0.0728, 0.0806, 0.1288, 0.0632, 0.0545, 0.0579, 0.1339, 0.0176,
0.2339, 0.2434, 0.1646, 0.1614, 0.1995, 0.2260, 0.1002, 0.1222,
0.1036, 0.0652, 0.0572, 0.0574]], grad_fn=<SelectBackward0>),
True,
torch.Size([1, 196]))
In [ ]:
class PatchEmbedding(nn.Module):
def __init__(self,
in_channels=3,
patch_size=16,
embedding_dim=768):
super().__init__()
self.patch_size = patch_size
# Create a layer to turn an image into patches
self.patcher = nn.Conv2d(in_channels=in_channels,
out_channels=embedding_dim,
kernel_size=patch_size,
stride=patch_size,
padding=0)
# Create a layer to flatten the patches
self.flatten = nn.Flatten(2,3) # only flatten the feature map dimensions into a single vector
# Forward Method
def forward(self, x):
image_resolution = x.shape[-1] # image resolution check
assert image_resolution % patch_size == 0, f"Image resolution {image_resolution} must be divisible by patch size {self.patch_size}" # 나뉘어야 함, 아닐 경우 print
# forward pass
x_flattened = self.flatten(self.patcher(x))
return x_flattened.permute(0,2,1) # adjust so the embedding is on the final dimension [batch_size, P^2•C, N] -> [batch_size, N (number of patches), P^2•C (embedding_dimension)]
In [ ]:
set_seeds()
patchify = PatchEmbedding(in_channels=3,
patch_size=16,
embedding_dim=768)
# Pass a single image through PatchEmbedding
print(image.unsqueeze(0).shape)
patch_embedded_image = patchify(image.unsqueeze(0))
print(patch_embedded_image.shape)
torch.Size([1, 3, 224, 224])
torch.Size([1, 196, 768])
In [ ]:
random_input_image = (1,3,224,224)
summary(PatchEmbedding(),
input_size=random_input_image,
col_names=["input_size", "output_size", "num_params", "trainable"],
col_width=20,
row_settings=["var_names"])
Out[ ]:
========================================================================================================================
Layer (type (var_name)) Input Shape Output Shape Param # Trainable
========================================================================================================================
PatchEmbedding (PatchEmbedding) [1, 3, 224, 224] [1, 196, 768] -- True
├─Conv2d (patcher) [1, 3, 224, 224] [1, 768, 14, 14] 590,592 True
├─Flatten (flatten) [1, 768, 14, 14] [1, 768, 196] -- --
========================================================================================================================
Total params: 590,592
Trainable params: 590,592
Non-trainable params: 0
Total mult-adds (Units.MEGABYTES): 115.76
========================================================================================================================
Input size (MB): 0.60
Forward/backward pass size (MB): 1.20
Params size (MB): 2.36
Estimated Total Size (MB): 4.17
========================================================================================================================
In [ ]:
patch_embedded_image, patch_embedded_image.shape
Out[ ]:
(tensor([[[-0.9774, 0.3165, -0.2863, ..., 0.7592, -0.4357, 0.2859],
[-0.9457, 0.2263, -0.1818, ..., 0.6788, -0.4486, 0.3397],
[-0.9073, 0.2230, -0.0745, ..., 0.6548, -0.5242, 0.2882],
...,
[-0.3089, 0.0881, -0.1674, ..., 0.1896, -0.0919, 0.1633],
[-0.2465, 0.0679, -0.1307, ..., 0.2124, -0.0784, 0.1908],
[-0.2418, 0.0056, -0.0883, ..., 0.1460, -0.1009, 0.1555]]],
grad_fn=<PermuteBackward0>),
torch.Size([1, 196, 768]))
In [ ]:
# Make learnable class token, add it to the number_of_patches dimension
batch_size = patch_embedded_image.shape[0]
embedding_dimension = patch_embedded_image.shape[-1]
batch_size, embedding_dimension
class_token = nn.Parameter(torch.randn(batch_size, 1, embedding_dimension),
requires_grad=True)
class_token, class_token.shape
Out[ ]:
(Parameter containing:
tensor([[[-2.4663e+00, 4.8666e-01, 8.1275e-01, 3.2980e-01, -1.6121e+00,
-5.7590e-01, -2.2805e-01, 2.9429e-01, 1.1967e-01, 1.5855e+00,
-2.9088e-01, -1.1353e+00, 1.4488e+00, -1.8383e+00, 8.0495e-01,
-6.1171e-01, -2.0441e+00, 6.1023e-01, 8.8660e-01, 2.9294e+00,
1.0633e+00, 1.9676e+00, -2.5913e-01, 4.3707e-01, 2.1142e+00,
1.0478e+00, -3.8597e-01, -1.0752e+00, -8.8747e-01, -5.0138e-01,
1.3277e+00, 1.1156e+00, -1.0734e+00, -1.5635e-01, -1.0746e-01,
7.7948e-01, -1.6764e+00, -9.8822e-01, 5.4542e-01, -1.2481e-01,
-2.7308e-01, -1.0960e+00, -2.7923e-01, 3.1009e-01, -7.9536e-03,
7.0091e-01, -9.8159e-02, 1.5224e+00, -3.0794e-01, -8.6819e-01,
-4.4327e-01, 1.8111e+00, 1.5676e+00, -7.6208e-01, 2.2796e-01,
-2.7188e-01, -4.5235e-01, -7.5575e-01, -1.0747e+00, -8.1149e-01,
-1.0482e+00, -3.1524e-01, 3.0454e+00, -4.0238e-01, -2.8119e+00,
6.4365e-01, -1.5131e-01, -2.4506e-02, 2.0009e+00, -1.5997e-01,
-2.3831e+00, 1.5303e-01, 1.4485e+00, -6.1641e-01, -1.9251e-01,
5.7108e-01, 8.5249e-02, -9.1584e-01, 1.3751e+00, -8.0936e-01,
-6.4958e-01, 4.5701e-01, -2.6298e-01, 5.7950e-01, -6.2483e-01,
5.4139e-01, -5.2634e-01, 1.4987e+00, -2.6469e+00, 3.3835e-01,
4.8949e-02, 7.0244e-01, 7.3006e-01, 1.1651e+00, -9.1065e-01,
-9.8773e-01, 1.4321e+00, -2.3764e+00, 3.6966e-01, 7.1336e-01,
5.6285e-01, 1.4557e+00, 2.3149e+00, -1.3751e+00, 3.0908e-02,
-1.3450e+00, -1.2066e+00, -5.8235e-01, 1.7946e+00, 1.2021e+00,
-7.2061e-01, 8.1186e-01, 7.1330e-01, -4.6015e-01, -2.2334e-01,
8.5654e-01, 5.5488e-01, -4.2007e-01, -5.3695e-01, -6.7887e-01,
1.5869e+00, 7.7618e-01, 7.9760e-01, -2.1267e-01, 1.3630e+00,
-1.0513e+00, -3.5729e-01, -5.4563e-01, -1.4671e+00, 2.1267e-01,
-1.0036e+00, -4.5995e-01, 4.6842e-01, -1.6592e+00, 3.2974e-01,
1.5333e+00, -4.0662e-01, -5.0779e-01, 3.5816e-01, 2.5230e+00,
-1.6344e+00, 1.4451e+00, -4.1123e-01, -1.9125e+00, -1.2149e+00,
9.4946e-01, 4.0797e-01, 1.0470e+00, -3.6169e-01, 2.5102e-01,
6.5731e-01, -1.2497e-01, 5.9888e-01, 9.5168e-01, -1.8938e-01,
2.4328e+00, 8.5947e-01, 1.9068e-01, 4.0164e-01, 1.5136e+00,
1.2289e+00, -2.4048e-01, 4.9195e-01, 8.6668e-01, -8.8110e-01,
-1.3052e+00, -1.4150e-01, -4.6848e-01, 8.2842e-01, -7.4216e-01,
6.0980e-01, -4.0975e-01, -2.3762e-01, 7.9649e-01, 2.6647e-01,
1.3864e+00, -8.6840e-02, 1.4389e-01, 7.1538e-01, 3.7460e-01,
1.8901e+00, 5.8136e-01, 7.8046e-01, 2.1103e+00, 1.1667e-01,
-1.7420e+00, -1.2490e+00, 1.0989e+00, 2.5245e+00, -6.4193e-01,
5.9591e-01, -1.8392e+00, 9.5644e-01, -8.1351e-01, -1.3015e+00,
-6.7723e-01, -8.7792e-01, -7.2687e-01, 2.2748e-01, -2.3358e-01,
7.2080e-01, 1.2446e+00, 3.5584e-01, 4.3763e-02, 8.1047e-02,
-6.1666e-01, -7.1315e-01, -8.4045e-01, -6.8671e-01, -7.8864e-01,
3.2227e-01, -1.5264e+00, 8.8559e-01, 2.7902e-01, 1.1361e+00,
3.0836e-01, 6.2971e-01, 9.5257e-01, 1.1608e-01, -1.0034e+00,
-9.2401e-01, 9.3538e-01, 9.0322e-01, -8.5360e-01, -1.0055e+00,
-4.9312e-01, 2.4158e-01, 7.1069e-01, 1.1498e+00, -1.3105e+00,
9.6594e-01, -9.9835e-02, 2.2057e+00, -2.8795e-01, 2.4050e-02,
3.8941e-01, -9.9020e-02, 1.9942e+00, -2.8806e+00, -2.5648e+00,
9.2936e-01, 3.0717e+00, -7.4013e-01, 8.2029e-02, -1.1248e+00,
6.2929e-01, -3.1160e-01, -6.9379e-01, -1.2745e+00, 4.5585e-01,
-1.3939e+00, 6.5389e-01, -1.0455e+00, -3.7537e-01, -1.5477e-01,
4.4222e-01, -4.2000e-01, -4.3622e-01, 2.0585e-01, -4.1429e-03,
4.1578e-01, 2.1204e+00, -2.5788e-01, -2.0871e-01, -1.4869e-01,
7.6406e-01, -3.0639e-01, 5.0144e-01, -6.2882e-01, -3.3618e-01,
-5.1086e-01, 3.1610e-01, -6.7897e-01, -4.2431e-01, 9.6366e-02,
-1.2339e+00, -1.5003e+00, 1.1918e+00, -1.4613e+00, -3.9560e-01,
1.0379e+00, 1.1900e+00, 3.6565e-01, -2.6782e+00, 1.1432e+00,
7.4210e-01, -4.6126e-01, -2.1124e+00, 1.3408e+00, 1.6650e+00,
-4.3767e-01, -6.6343e-01, -2.4454e+00, -4.7083e-01, 1.7589e+00,
-9.6669e-03, -8.5231e-01, 4.4874e-01, 7.6007e-01, -8.1169e-01,
1.8447e+00, -7.6986e-01, -5.2545e-01, -1.4023e+00, 2.1701e-01,
2.3150e+00, -1.4652e+00, 1.9786e-01, 2.7592e-02, -2.9146e-01,
1.3744e+00, 1.4111e+00, 1.6206e-01, -2.8410e-02, -6.5128e-01,
-1.5313e+00, -2.1181e+00, 2.0185e-01, -1.2793e+00, 8.3367e-02,
-1.5131e+00, 1.1757e+00, 3.0661e-01, -4.9523e-01, 2.7490e-01,
1.1675e+00, -1.6608e-01, 5.3262e-01, 1.3098e+00, -1.0378e-01,
2.0932e-01, -4.4942e-01, 7.9521e-01, 1.2609e+00, 7.3724e-01,
-1.1676e+00, 6.5645e-01, 9.1575e-02, 1.1367e+00, -4.1655e-01,
-2.3307e-01, 1.2473e+00, -1.6037e+00, -5.8572e-01, -1.0999e+00,
1.8825e+00, 1.2295e+00, 2.0194e+00, -2.1164e+00, 7.3806e-01,
7.0929e-01, 1.2969e+00, 6.9091e-02, 1.0360e+00, 1.0298e+00,
-1.1492e+00, -4.9803e-01, 1.2814e+00, 2.2602e+00, 3.2201e-01,
7.5473e-01, 7.2348e-01, 2.7511e-01, 1.3005e-01, -2.3962e+00,
4.1528e-01, -6.1748e-01, -3.1966e-02, -3.8636e-01, 2.1751e-01,
-4.1240e-01, -1.2413e+00, 3.2481e-01, -9.1328e-01, -8.8716e-01,
9.6750e-01, -1.9257e+00, 1.6614e+00, 1.8839e+00, 8.0188e-01,
5.8801e-01, -3.7526e-01, 1.2135e+00, -2.0062e+00, -1.4055e+00,
1.6474e-01, -5.9838e-01, -1.1895e+00, -6.5246e-01, 1.8573e+00,
1.5198e+00, 4.1307e-02, -7.3637e-01, -3.0526e-01, 7.3175e-01,
-5.2235e-01, -3.4287e-01, 1.2410e+00, -1.3739e+00, 7.6886e-01,
1.9612e+00, -4.5910e-01, 5.2693e-01, -1.2945e-02, -4.7565e-01,
1.8319e-01, 1.3472e-01, 3.1519e-01, 7.6589e-01, -1.7012e-01,
-1.3584e+00, 8.5766e-01, 3.5076e-01, 5.2551e-01, 1.2906e-01,
-1.6290e+00, -1.3662e-02, 1.2275e+00, -2.7422e-01, 2.3920e-01,
-1.8072e+00, -6.8737e-01, -6.4154e-01, -3.8959e-01, 1.2829e+00,
-1.9543e+00, 2.0442e-02, -7.1100e-01, -2.9450e-01, 1.8402e-01,
5.7565e-02, -1.2126e+00, -6.9801e-01, 3.9269e-01, 1.2207e-01,
1.8093e+00, 6.8223e-01, -1.4010e-01, -1.2075e+00, 5.0409e-01,
-8.7561e-01, 1.1326e+00, -1.0453e+00, 1.1767e+00, -5.8584e-01,
4.4282e-01, 9.0208e-01, -1.1131e+00, 1.1575e+00, -1.3338e+00,
-1.2511e+00, 1.5547e+00, -1.1667e-01, 1.0422e+00, -1.4096e+00,
3.5995e-01, 1.1345e+00, -6.5438e-01, -8.1134e-02, 5.9245e-01,
-2.4482e-01, 8.8250e-01, 5.3979e-01, -1.2274e+00, -7.8686e-01,
-8.0274e-01, 6.6507e-01, -1.2384e+00, 9.6129e-01, -6.9682e-01,
6.6707e-02, 7.6631e-02, -1.1248e+00, 2.9703e-01, 3.0625e-01,
-1.2247e+00, 5.0392e-01, 2.1826e+00, 3.3741e-01, 1.2371e-01,
1.4832e+00, 3.3272e-01, 4.3121e-01, 4.2052e-02, 7.1321e-01,
1.5862e+00, -7.4614e-02, -6.6995e-01, 4.1975e-01, -2.2703e-01,
-1.0765e+00, 2.1119e-01, -1.5079e+00, 2.3137e-01, 1.2206e+00,
-2.4318e-01, -8.9700e-01, 6.7004e-01, -3.9546e-01, -6.5133e-01,
-4.4584e-01, 1.2954e+00, 6.0750e-01, -1.5378e+00, -8.1290e-01,
2.6481e-01, -4.8638e-01, 4.5494e-01, -1.9351e-01, 3.7955e-01,
-2.2265e-01, 5.4630e-01, -1.3879e+00, -1.5831e+00, 6.9914e-01,
-8.2490e-01, 4.8831e-01, 1.1783e+00, 2.3862e+00, -4.6005e-01,
-1.3280e-02, -5.3504e-01, -8.8975e-01, 2.9375e-01, -9.4672e-01,
-1.9645e+00, 8.7722e-01, 1.8751e+00, -6.8002e-01, -1.2197e-01,
-2.4844e-01, 2.6299e-01, 8.1069e-01, -1.3609e+00, -2.5422e-01,
-1.7741e-01, -6.8981e-02, -1.0476e-01, 1.3590e+00, 7.1255e-01,
-9.5029e-01, 6.7370e-01, 1.7393e+00, -8.2074e-01, -7.0101e-01,
6.3490e-01, 1.4178e+00, -1.1659e+00, 3.4812e-01, -2.7621e-01,
-4.3620e-01, 6.8950e-01, 2.1557e-01, 5.8649e-01, -4.6764e-01,
5.8723e-01, -5.5215e-01, 4.1394e-02, 2.1737e-01, 1.4294e+00,
-1.3630e+00, -4.3399e-01, 2.0615e-01, -1.1187e+00, -5.0444e-01,
7.3180e-01, 1.4964e+00, -7.0826e-01, 1.0332e-01, 8.6451e-01,
5.3216e-01, 4.6641e-01, -4.7921e-01, -2.5699e-01, 4.2708e-03,
-1.8098e-01, 1.1033e+00, -4.7951e-01, 4.2702e-01, -2.2927e+00,
-1.1406e+00, 7.0293e-01, 1.1128e+00, 4.8508e-01, -5.6911e-01,
-3.6554e-01, 4.7901e-01, -1.9588e-01, 1.6835e-01, 8.3843e-01,
-2.3870e-01, 1.0694e+00, 6.9603e-02, 1.3219e+00, -1.1589e+00,
6.1499e-01, -1.3969e+00, 6.8176e-01, -9.5809e-01, -9.6423e-02,
-7.8831e-01, 1.9547e+00, 5.6770e-02, -9.7516e-01, 1.7369e-01,
9.6280e-01, -1.0836e-01, 1.5905e-01, -1.5878e-01, -1.9887e+00,
9.7657e-01, -1.8908e-01, 7.4765e-02, 7.5543e-01, 1.3980e-01,
2.9325e-01, -3.9821e-01, 1.4709e+00, 3.2677e+00, 6.3319e-01,
-1.2042e+00, -1.8490e+00, 6.3890e-01, 4.5159e-01, -9.0168e-01,
4.5903e-01, -6.7692e-01, 2.3308e-01, 9.0799e-01, -2.9374e-01,
-1.6671e+00, 1.2336e+00, 1.4648e+00, 5.2688e-01, 1.1205e+00,
-6.3617e-01, 7.3315e-01, 8.1748e-01, -8.5610e-02, -3.0190e-01,
5.3348e-01, -1.6074e+00, 1.2343e+00, 6.2067e-01, 4.7330e-01,
9.8530e-01, 5.0630e-01, -1.8883e-01, 2.9485e+00, 6.9200e-01,
-1.0066e+00, -1.3622e+00, -1.8472e-01, -1.6898e+00, -3.3291e-01,
-9.9397e-02, -4.3042e-01, -3.8661e-01, 9.5914e-02, -1.1725e-01,
-5.2036e-01, 1.6859e-02, 1.5356e+00, -1.7651e-01, -5.3866e-01,
-1.6446e+00, 1.5086e+00, 7.5897e-01, 9.3864e-01, -4.2343e-01,
6.8639e-01, -9.0262e-01, 2.1827e-01, 4.3948e-01, 5.5954e-01,
1.1438e+00, 1.5549e-01, -8.4575e-01, 2.1083e-02, 3.8149e-01,
9.8379e-01, 1.7375e+00, -6.4812e-01, 4.6343e-01, 3.9832e-01,
-4.2474e-01, 4.0585e-01, 6.7247e-02, 3.8344e-01, 9.3451e-01,
-1.0447e+00, 7.9257e-01, 7.8476e-01, 1.2462e+00, -3.1305e-01,
4.7145e-01, -2.2445e-01, 2.2467e-02, 5.9276e-01, 2.2825e+00,
-9.0850e-02, -1.2864e+00, -1.2162e+00, -8.9073e-01, 5.4260e-02,
2.0824e+00, -1.4491e+00, -4.2042e-02, 1.9448e+00, -6.6641e-01,
-9.7422e-03, 4.4986e-01, 4.3885e-01, 1.5645e+00, -1.0725e+00,
-3.5566e-02, -5.8029e-01, -7.8112e-01, 1.0443e+00, 4.3737e-02,
-6.3965e-01, 9.6874e-02, -9.1256e-01, 1.0238e+00, 8.7988e-01,
-9.2681e-01, 9.0425e-02, -8.0759e-01, -5.5904e-01, -1.8672e+00,
-7.6924e-01, -5.3611e-01, 1.7620e-01, -2.6473e-01, 1.7035e+00,
-9.3756e-01, 8.6468e-01, -7.6115e-01, 1.0636e-01, 9.9937e-01,
-4.5102e-01, 5.9307e-01, 4.2880e-01, 1.7275e+00, -1.7822e+00,
-1.9426e+00, -3.8476e-01, -2.3721e+00, 5.8993e-01, -4.3341e-01,
-8.5441e-01, -1.2678e+00, 2.2409e-01, -2.9557e+00, 5.5826e-01,
4.8283e-01, -2.1659e-01, -1.0768e+00, 5.8981e-01, 8.1402e-01,
3.3150e-01, 6.3759e-01, 4.0372e-04, 1.0251e-01, -2.2858e+00,
9.2033e-01, 2.3699e-01, 1.4758e+00]]], requires_grad=True),
torch.Size([1, 1, 768]))
In [ ]:
patch_embedded_image_with_class_embedding = torch.cat((class_token, patch_embedded_image), dim=1)
print(f'patch_embedded_image_with_class_embedding : {patch_embedded_image_with_class_embedding.shape} -> [batch_size, number_of_patches, embedding_dimension]')
patch_embedded_image_with_class_embedding : torch.Size([1, 197, 768]) -> [batch_size, number_of_patches, embedding_dimension]
In [ ]:
number_of_patches = int((height*width)/patch_size**2)
# Get embedding dimension
embedding_dimension = patch_embedded_image_with_class_embedding.shape[-1]
# Create learnable 1D position embedding
position_embedding = nn.Parameter(torch.randn(batch_size,
number_of_patches+1,
embedding_dimension,
requires_grad=True))
# Show first 10 sequences and 10 position embedding values, check shape of position embedding
print(position_embedding[:10])
print(f"Position embedding sequence shape: {position_embedding.shape} -> [batch_size, number_of_patches+1, embedding_dimension]")
tensor([[[-0.7411, 0.8154, -0.3632, ..., 1.1716, -0.6841, -1.8584],
[-0.6329, 0.6017, 0.7737, ..., -0.4305, -1.1761, 0.5040],
[ 0.4181, -0.8031, 0.6838, ..., -0.1055, 1.2080, -1.1675],
...,
[-0.2004, -2.7290, -0.4892, ..., -1.7061, 0.1665, 1.2587],
[ 0.4338, 0.2027, 0.7711, ..., -0.2499, -0.6855, -0.0338],
[-1.7389, 0.7569, -1.0380, ..., -1.5983, -0.1549, 0.4573]]],
grad_fn=<SliceBackward0>)
Position embedding sequence shape: torch.Size([1, 197, 768]) -> [batch_size, number_of_patches+1, embedding_dimension]
In [ ]:
patch_and_position_embedding = patch_embedded_image_with_class_embedding + position_embedding
patch_and_position_embedding[:1], patch_and_position_embedding.shape
Out[ ]:
(tensor([[[-3.2074, 1.3021, 0.4496, ..., 2.0919, -0.4471, -0.3826],
[-1.6103, 0.9182, 0.4874, ..., 0.3287, -1.6118, 0.7899],
[-0.5276, -0.5769, 0.5020, ..., 0.5733, 0.7593, -0.8278],
...,
[-0.5093, -2.6409, -0.6566, ..., -1.5164, 0.0747, 1.4219],
[ 0.1873, 0.2706, 0.6404, ..., -0.0375, -0.7639, 0.1570],
[-1.9807, 0.7626, -1.1263, ..., -1.4523, -0.2558, 0.6128]]],
grad_fn=<SliceBackward0>),
torch.Size([1, 197, 768]))
- Set patch size : 16
- Single image shape, height, width
- Add batch dimension to the single image (PatchEmbedding layer compatibility)
- Create PatchEmbedding layer
- Passing single image through PatchEmbedding layer (Create patch embeddings)
- Create class token embedidng
- Prepend class token embedding to the patch embeddings
- Create position embedding (series of 1D learnable tokens)
- Add position embedding to class token + patch embeddings
In [ ]:
set_seeds()
#1
patch_size = 16
#2
print(f'Image tensor shape:{image.shape}')
height, width = image.shape[1], image.shape[2]
#3
x = image.unsqueeze(0)
print(f'Input image shape: {x.shape}')
#4
patch_embedding_layer = PatchEmbedding(in_channels=3,
patch_size=patch_size,
embedding_dim=768)
#5
patch_embedding = patch_embedding_layer(x)
print(f'Patch embedding shape:{patch_embedding.shape}')
#6
batch_size = patch_embedding.shape[0]
embedding_dimension = patch_embedding.shape[-1]
class_token = nn.Parameter(torch.randn(batch_size, 1, embedding_dimension),
requires_grad=True)
print(f'Class token shape: {class_token.shape}')
#7
patch_embedding_class_token = torch.cat((class_token, patch_embedding), dim=1)
print(f'Patch embedding with class token shape: {patch_embedding_class_token.shape}')
#8
token_shape = patch_embedding_class_token.shape
position_embedding = torch.randn(token_shape, requires_grad=True)
#9
patch_and_position_embedding = patch_embedding_class_token + position_embedding
print(f'Patch and position embedding shape: {patch_and_position_embedding.shape}')
Image tensor shape:torch.Size([3, 224, 224])
Input image shape: torch.Size([1, 3, 224, 224])
Patch embedding shape:torch.Size([1, 196, 768])
Class token shape: torch.Size([1, 1, 768])
Patch embedding with class token shape: torch.Size([1, 197, 768])
Patch and position embedding shape: torch.Size([1, 197, 768])
- Create class - MultiheadSelfAttentionBlock
- Initialize with hyperparameters
- Create LN layer using nn.LayerNorm with normalized_shape as embedding dimension
- Create MSA layer using nn.MultiheadAttention with above parameters
- Create forward() for class passing
In [ ]:
#1
class MultiheadSelfAttentionBlock(nn.Module):
#2
def __init__(self,
embedding_dimension=768,
num_heads=12,
attn_dropout=0):
super().__init__()
#3
self.layer_norm = nn.LayerNorm(normalized_shape=embedding_dimension)
#4
self.multihead_attn = nn.MultiheadAttention(embed_dim=embedding_dimension,
num_heads=num_heads,
dropout=attn_dropout,
batch_first=True)
#5
def forward(self, x):
#(1)
x = self.layer_norm(x)
#(2)
attn_output, _ = self.multihead_attn(query=x, key=x, value=x, need_weights=False)
return attn_output
In [ ]:
multihead_self_attention_block = MultiheadSelfAttentionBlock(embedding_dimension=embedding_dimension,
num_heads=12,
attn_dropout=0)
patched_image_through_msa_block = multihead_self_attention_block(patch_and_position_embedding)
print(f'Input shape of MSA block:{patch_and_position_embedding.shape}')
print(f'Output shape of MSA block:{patched_image_through_msa_block.shape}')
Input shape of MSA block:torch.Size([1, 197, 768])
Output shape of MSA block:torch.Size([1, 197, 768])
- Create MLPBlock class
- Initialize with hyperparameters from Table 1,3
- Create LN layer with nn.LayerNorm() with normalized_shape as embedding_dimension
- Create sequential series of MLP layers with nn.Linear(), nn.Dropout(), nn.GELU()
- Create forward() for passing in the inputs
zℓ=MLP(LN(z′ℓ))+z′ℓ,ℓ=1…L
In [ ]:
#1
class MLPBlock(nn.Module):
#2
def __init__(self,
embedding_dimension=768,
mlp_size=3072,
dropout=0.1):
super().__init__()
#3
self.layer_norm = nn.LayerNorm(normalized_shape=embedding_dimension)
#4
self.mlp = nn.Sequential(
#(1) linearity
nn.Linear(in_features=embedding_dimension,
out_features=mlp_size),
#(2) non-linearity
nn.GELU(),
#(3) dropout
nn.Dropout(p=dropout),
#(4) linearity
nn.Linear(in_features=mlp_size,
out_features=embedding_dimension),
#(5) dropout
nn.Dropout(p=dropout)
)
#5
def forward(self, x):
x = self.layer_norm(x)
x = self.mlp(x)
return x
In [ ]:
#Create instance of MLPBlock
mlp_block = MLPBlock(embedding_dimension=768,
mlp_size=3072,
dropout=0.1)
#Pass output of MSABlock through MLPBlock
patched_image_through_msa_mlp_block = mlp_block(patched_image_through_msa_block)
print(patched_image_through_msa_block.shape)
print(patched_image_through_msa_mlp_block.shape)
torch.Size([1, 197, 768])
torch.Size([1, 197, 768])
In [ ]:
patched_image_through_msa_block
Out[ ]:
tensor([[[ 0.0363, 0.0558, -0.0045, ..., 0.0497, 0.0657, -0.1147],
[ 0.0336, 0.0563, 0.0247, ..., 0.0433, 0.0438, -0.1162],
[ 0.0398, 0.0611, 0.0194, ..., 0.0197, 0.0523, -0.1503],
...,
[ 0.0574, 0.0398, -0.0308, ..., 0.0256, 0.0590, -0.1622],
[ 0.0578, 0.0208, 0.0023, ..., 0.0165, 0.0307, -0.1606],
[ 0.0466, 0.0369, -0.0139, ..., 0.0197, 0.0523, -0.1488]]],
grad_fn=<TransposeBackward0>)
In [ ]:
patched_image_through_msa_mlp_block
Out[ ]:
tensor([[[ 0.3337, 0.0167, -0.2575, ..., 0.0319, 0.0000, 0.0110],
[ 0.0000, 0.1198, -0.1656, ..., 0.0005, 0.0207, -0.0089],
[ 0.4262, 0.1166, 0.0343, ..., 0.0278, 0.1854, 0.1963],
...,
[ 0.4494, 0.1222, -0.2378, ..., -0.1426, 0.2395, 0.1712],
[ 0.2270, 0.0375, -0.1651, ..., 0.1885, 0.0214, 0.0778],
[ 0.3566, 0.1415, -0.3244, ..., -0.1056, 0.0885, -0.0283]]],
grad_fn=<MulBackward0>)
Create the Transformer Encoder¶
x_input -> MSA_block -> [MSA_block_output + x_input] -> MLP_block -> [MLP_block_output + MSA_block_output + x_input] -> ...
- Create TransformerEncoderBlock class
- Initialize with hyperparameters from Table 1,3
- Instantiate MSA block for equation 2 using MultiheadSelfAttentionBlock
- Instantiate MLP block for equation 3 using MLPBlock
- Create forward()
- Create residual connection for MSA block
- Create residual connection for MLP block
In [ ]:
#1
class TransformerEncoderBlock(nn.Module):
#2
def __init__(self,
embedding_dimension=768,
num_heads=12,
mlp_size=3072,
mlp_dropout=0.1,
attn_dropout=0):
super().__init__()
#3
self.msa_block = MultiheadSelfAttentionBlock(embedding_dimension=embedding_dimension,
num_heads=num_heads,
attn_dropout=attn_dropout)
#4
self.mlp_block = MLPBlock(embedding_dimension=embedding_dimension,
mlp_size=mlp_size,
dropout=mlp_dropout)
#5
def forward(self, x):
x = self.msa_block(x)+x
x = self.mlp_block(x)+x
return x
In [ ]:
#Create instance of TransformerEncoderBlock
transformer_encoder_block = TransformerEncoderBlock(embedding_dimension=768,
num_heads=12,
mlp_size=3072,
mlp_dropout=0.1,
attn_dropout=0)
#input/output summary of Transformer Encoder
summary(model=transformer_encoder_block,
input_size=(1,197,768),# (batch_size, num_patches, embedding_dimension)
col_names=["input_size", "output_size", "num_params", "trainable"],
col_width=20,
row_settings=["var_names"])
Out[ ]:
==================================================================================================================================
Layer (type (var_name)) Input Shape Output Shape Param # Trainable
==================================================================================================================================
TransformerEncoderBlock (TransformerEncoderBlock) [1, 197, 768] [1, 197, 768] -- True
├─MultiheadSelfAttentionBlock (msa_block) [1, 197, 768] [1, 197, 768] -- True
│ └─LayerNorm (layer_norm) [1, 197, 768] [1, 197, 768] 1,536 True
│ └─MultiheadAttention (multihead_attn) -- [1, 197, 768] 2,362,368 True
├─MLPBlock (mlp_block) [1, 197, 768] [1, 197, 768] -- True
│ └─LayerNorm (layer_norm) [1, 197, 768] [1, 197, 768] 1,536 True
│ └─Sequential (mlp) [1, 197, 768] [1, 197, 768] -- True
│ │ └─Linear (0) [1, 197, 768] [1, 197, 3072] 2,362,368 True
│ │ └─GELU (1) [1, 197, 3072] [1, 197, 3072] -- --
│ │ └─Dropout (2) [1, 197, 3072] [1, 197, 3072] -- --
│ │ └─Linear (3) [1, 197, 3072] [1, 197, 768] 2,360,064 True
│ │ └─Dropout (4) [1, 197, 768] [1, 197, 768] -- --
==================================================================================================================================
Total params: 7,087,872
Trainable params: 7,087,872
Non-trainable params: 0
Total mult-adds (Units.MEGABYTES): 4.73
==================================================================================================================================
Input size (MB): 0.61
Forward/backward pass size (MB): 8.47
Params size (MB): 18.90
Estimated Total Size (MB): 27.98
==================================================================================================================================
Putting it altogether to create ViT¶
- Create ViT class
- Initialize with hyperparameters from Table 1,3
- Check image size divisible by patch size
- Calcuate number of patches
- Create learnable class embedding token
- Create learnable postion embedding vector
- Setup embedding dropout layer
- Create patch embedding layer using PatchEmbedding
- Create series of Transformer Encoder blocks using TranformerEncoderBlock and nn.Sequential()
- Create MLP head (classifier head) using nn.LayerNorm and nn.Linear, nn.Sequential
- Create forward()
- Get the batch size of the input (first dimension of the shape)
- Create patching embedding (Equation 1)
- Create class token embedding and expand it across the number of patches using torch.Tensor.expand()
- Concat class token embedding to the first dimension of the patch embedding
- Add postion embedding to the patch and class token embedding (Equation 1)
- Pass patch and postion embedding through (embedding) dropout layer
- Pass patch and postion embedding through the stack of Tranformer Encoder layers (Equation 2,3)
- Pass index 0 of the output of the stack of Transformer Encoder layers through classifier head (Equation 4)
In [ ]:
#1
class ViT(nn.Module):
#2
def __init__(self,
img_size=224,
in_channels=3,
patch_size=16,
num_transformer_layers=12,
embedding_dim=768,
mlp_size=3072,
num_heads=12,
attn_dropout=0,
mlp_dropout=0.1,
embedding_dropout=0.1,
num_classes=1000):
super().__init__()
#3
assert img_size & patch_size == 0, "Image size must be divisible by patch size"
#4
self.num_patches = patch_size**2 // img_size**2
#5
self.class_embedding = nn.Parameter(torch.randn(1,1,embedding_dim),
requires_grad=True)
#6
self.postion_embedding = nn.Parameter(torch.randn(1,self.num_patches+1,embedding_dim),
requires_grad=True)
#7
self.embedding_dropout = nn.Dropout(p=embedding_dropout)
#8
self.patch_embedding = PatchEmbedding(in_channels=3,
patch_size=16,
embedding_dim=768)
#9
self.transformer_encoder = nn.Sequential( # 12 TransformerEncoderBlocks
*[TransformerEncoderBlock(embedding_dimension=embedding_dim,
num_heads=num_heads,
mlp_size=mlp_size,
mlp_dropout=mlp_dropout) for _ in range(num_transformer_layers)]
)
#10
self.classifer = nn.Sequential(
nn.LayerNorm(normalized_shape=embedding_dim),
nn.Linear(in_features=embedding_dim,
out_features=num_classes)
)
#11
def forward(self, x):
#12
batch_size = x.shape[0]
#13
x = self.patch_embedding(x)
#14
class_token = self.class_embedding.expand(batch_size,-1,-1)
#15
x = torch.cat((class_token, x), dim=1)
#16
x = self.postion_embedding + x
#17
x = self.embedding_dropout(x)
#18
x = self.transformer_encoder(x)
#19
# print(x.shape)
x = self.classifer(x[:,0])
# print(x.shape)
return x
In [ ]:
vit = ViT()
vit
Out[ ]:
ViT(
(embedding_dropout): Dropout(p=0.1, inplace=False)
(patch_embedding): PatchEmbedding(
(patcher): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
(flatten): Flatten(start_dim=2, end_dim=3)
)
(transformer_encoder): Sequential(
(0): TransformerEncoderBlock(
(msa_block): MultiheadSelfAttentionBlock(
(layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
(multihead_attn): MultiheadAttention(
(out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
)
)
(mlp_block): MLPBlock(
(layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
(mlp): Sequential(
(0): Linear(in_features=768, out_features=3072, bias=True)
(1): GELU(approximate='none')
(2): Dropout(p=0.1, inplace=False)
(3): Linear(in_features=3072, out_features=768, bias=True)
(4): Dropout(p=0.1, inplace=False)
)
)
)
(1): TransformerEncoderBlock(
(msa_block): MultiheadSelfAttentionBlock(
(layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
(multihead_attn): MultiheadAttention(
(out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
)
)
(mlp_block): MLPBlock(
(layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
(mlp): Sequential(
(0): Linear(in_features=768, out_features=3072, bias=True)
(1): GELU(approximate='none')
(2): Dropout(p=0.1, inplace=False)
(3): Linear(in_features=3072, out_features=768, bias=True)
(4): Dropout(p=0.1, inplace=False)
)
)
)
(2): TransformerEncoderBlock(
(msa_block): MultiheadSelfAttentionBlock(
(layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
(multihead_attn): MultiheadAttention(
(out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
)
)
(mlp_block): MLPBlock(
(layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
(mlp): Sequential(
(0): Linear(in_features=768, out_features=3072, bias=True)
(1): GELU(approximate='none')
(2): Dropout(p=0.1, inplace=False)
(3): Linear(in_features=3072, out_features=768, bias=True)
(4): Dropout(p=0.1, inplace=False)
)
)
)
(3): TransformerEncoderBlock(
(msa_block): MultiheadSelfAttentionBlock(
(layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
(multihead_attn): MultiheadAttention(
(out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
)
)
(mlp_block): MLPBlock(
(layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
(mlp): Sequential(
(0): Linear(in_features=768, out_features=3072, bias=True)
(1): GELU(approximate='none')
(2): Dropout(p=0.1, inplace=False)
(3): Linear(in_features=3072, out_features=768, bias=True)
(4): Dropout(p=0.1, inplace=False)
)
)
)
(4): TransformerEncoderBlock(
(msa_block): MultiheadSelfAttentionBlock(
(layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
(multihead_attn): MultiheadAttention(
(out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
)
)
(mlp_block): MLPBlock(
(layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
(mlp): Sequential(
(0): Linear(in_features=768, out_features=3072, bias=True)
(1): GELU(approximate='none')
(2): Dropout(p=0.1, inplace=False)
(3): Linear(in_features=3072, out_features=768, bias=True)
(4): Dropout(p=0.1, inplace=False)
)
)
)
(5): TransformerEncoderBlock(
(msa_block): MultiheadSelfAttentionBlock(
(layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
(multihead_attn): MultiheadAttention(
(out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
)
)
(mlp_block): MLPBlock(
(layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
(mlp): Sequential(
(0): Linear(in_features=768, out_features=3072, bias=True)
(1): GELU(approximate='none')
(2): Dropout(p=0.1, inplace=False)
(3): Linear(in_features=3072, out_features=768, bias=True)
(4): Dropout(p=0.1, inplace=False)
)
)
)
(6): TransformerEncoderBlock(
(msa_block): MultiheadSelfAttentionBlock(
(layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
(multihead_attn): MultiheadAttention(
(out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
)
)
(mlp_block): MLPBlock(
(layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
(mlp): Sequential(
(0): Linear(in_features=768, out_features=3072, bias=True)
(1): GELU(approximate='none')
(2): Dropout(p=0.1, inplace=False)
(3): Linear(in_features=3072, out_features=768, bias=True)
(4): Dropout(p=0.1, inplace=False)
)
)
)
(7): TransformerEncoderBlock(
(msa_block): MultiheadSelfAttentionBlock(
(layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
(multihead_attn): MultiheadAttention(
(out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
)
)
(mlp_block): MLPBlock(
(layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
(mlp): Sequential(
(0): Linear(in_features=768, out_features=3072, bias=True)
(1): GELU(approximate='none')
(2): Dropout(p=0.1, inplace=False)
(3): Linear(in_features=3072, out_features=768, bias=True)
(4): Dropout(p=0.1, inplace=False)
)
)
)
(8): TransformerEncoderBlock(
(msa_block): MultiheadSelfAttentionBlock(
(layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
(multihead_attn): MultiheadAttention(
(out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
)
)
(mlp_block): MLPBlock(
(layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
(mlp): Sequential(
(0): Linear(in_features=768, out_features=3072, bias=True)
(1): GELU(approximate='none')
(2): Dropout(p=0.1, inplace=False)
(3): Linear(in_features=3072, out_features=768, bias=True)
(4): Dropout(p=0.1, inplace=False)
)
)
)
(9): TransformerEncoderBlock(
(msa_block): MultiheadSelfAttentionBlock(
(layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
(multihead_attn): MultiheadAttention(
(out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
)
)
(mlp_block): MLPBlock(
(layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
(mlp): Sequential(
(0): Linear(in_features=768, out_features=3072, bias=True)
(1): GELU(approximate='none')
(2): Dropout(p=0.1, inplace=False)
(3): Linear(in_features=3072, out_features=768, bias=True)
(4): Dropout(p=0.1, inplace=False)
)
)
)
(10): TransformerEncoderBlock(
(msa_block): MultiheadSelfAttentionBlock(
(layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
(multihead_attn): MultiheadAttention(
(out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
)
)
(mlp_block): MLPBlock(
(layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
(mlp): Sequential(
(0): Linear(in_features=768, out_features=3072, bias=True)
(1): GELU(approximate='none')
(2): Dropout(p=0.1, inplace=False)
(3): Linear(in_features=3072, out_features=768, bias=True)
(4): Dropout(p=0.1, inplace=False)
)
)
)
(11): TransformerEncoderBlock(
(msa_block): MultiheadSelfAttentionBlock(
(layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
(multihead_attn): MultiheadAttention(
(out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
)
)
(mlp_block): MLPBlock(
(layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
(mlp): Sequential(
(0): Linear(in_features=768, out_features=3072, bias=True)
(1): GELU(approximate='none')
(2): Dropout(p=0.1, inplace=False)
(3): Linear(in_features=3072, out_features=768, bias=True)
(4): Dropout(p=0.1, inplace=False)
)
)
)
)
(classifer): Sequential(
(0): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
(1): Linear(in_features=768, out_features=1000, bias=True)
)
)
In [ ]:
set_seeds()
random_image_tensor = torch.randn(1,3,224,224)
vit = ViT(num_classes=len(class_names))
vit(random_image_tensor)
Out[ ]:
tensor([[ 1.0469, 0.9204, -0.4547]], grad_fn=<AddmmBackward0>)
Create optimizer, Loss function, Training¶
In [ ]:
from going_modular.going_modular import engine
set_seeds()
optimizer = torch.optim.Adam(params=vit.parameters(),
lr=1e-3,
betas=(0.9,0.999),
weight_decay=0.1)
loss_fn = torch.nn.CrossEntropyLoss()
results = engine.train(model=vit,
train_dataloader=train_dataloader,
test_dataloader=test_dataloader,
epochs=10,
optimizer=optimizer,
loss_fn=loss_fn,
device=device)
10%|█ | 1/10 [00:02<00:26, 2.96s/it]
Epoch: 1 | train_loss: 1.2664 | train_acc: 0.3750 | test_loss: 2.3642 | test_acc: 0.2604
20%|██ | 2/10 [00:05<00:23, 2.99s/it]
Epoch: 2 | train_loss: 1.3274 | train_acc: 0.4297 | test_loss: 1.5884 | test_acc: 0.2604
30%|███ | 3/10 [00:08<00:20, 2.98s/it]
Epoch: 3 | train_loss: 1.1972 | train_acc: 0.3203 | test_loss: 1.2175 | test_acc: 0.2604
40%|████ | 4/10 [00:11<00:17, 2.96s/it]
Epoch: 4 | train_loss: 1.1743 | train_acc: 0.2695 | test_loss: 1.3095 | test_acc: 0.2604
50%|█████ | 5/10 [00:14<00:14, 2.97s/it]
Epoch: 5 | train_loss: 1.1066 | train_acc: 0.3203 | test_loss: 1.3343 | test_acc: 0.2604
60%|██████ | 6/10 [00:17<00:11, 2.98s/it]
Epoch: 6 | train_loss: 1.1584 | train_acc: 0.2617 | test_loss: 1.2096 | test_acc: 0.1979
70%|███████ | 7/10 [00:20<00:08, 2.98s/it]
Epoch: 7 | train_loss: 1.0910 | train_acc: 0.4336 | test_loss: 1.4791 | test_acc: 0.1979
80%|████████ | 8/10 [00:23<00:05, 2.97s/it]
Epoch: 8 | train_loss: 1.1700 | train_acc: 0.2695 | test_loss: 1.2251 | test_acc: 0.2604
90%|█████████ | 9/10 [00:26<00:02, 2.97s/it]
Epoch: 9 | train_loss: 1.1564 | train_acc: 0.2773 | test_loss: 1.0779 | test_acc: 0.5218
100%|██████████| 10/10 [00:29<00:00, 2.97s/it]
Epoch: 10 | train_loss: 1.1326 | train_acc: 0.2812 | test_loss: 1.1620 | test_acc: 0.2604
In [ ]:
from helper_functions import plot_loss_curves
plot_loss_curves(results)
Using pretrained ViT from torchvision.models on same dataset¶
In [ ]:
#1. Get pretrained weight for ViT-Base
pretrained_vit_weights = torchvision.models.ViT_B_16_Weights.DEFAULT
#2. Setup ViT model instance with weights
pretrained_vit = torchvision.models.vit_b_16(weights=pretrained_vit_weights).to(device)
#3. Freeze base parameters
for parameter in pretrained_vit.parameters():
parameter.requires_grad = False
#4. Change classifier head
set_seeds()
pretrained_vit.heads = nn.Linear(in_features=768, out_features=len(class_names)).to(device)
pretrained_vit
Downloading: "https://download.pytorch.org/models/vit_b_16-c867db91.pth" to /home/joseph/.cache/torch/hub/checkpoints/vit_b_16-c867db91.pth
100%|██████████| 330M/330M [00:05<00:00, 58.1MB/s]
Out[ ]:
VisionTransformer(
(conv_proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
(encoder): Encoder(
(dropout): Dropout(p=0.0, inplace=False)
(layers): Sequential(
(encoder_layer_0): EncoderBlock(
(ln_1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
(self_attention): MultiheadAttention(
(out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
)
(dropout): Dropout(p=0.0, inplace=False)
(ln_2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
(mlp): MLPBlock(
(0): Linear(in_features=768, out_features=3072, bias=True)
(1): GELU(approximate='none')
(2): Dropout(p=0.0, inplace=False)
(3): Linear(in_features=3072, out_features=768, bias=True)
(4): Dropout(p=0.0, inplace=False)
)
)
(encoder_layer_1): EncoderBlock(
(ln_1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
(self_attention): MultiheadAttention(
(out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
)
(dropout): Dropout(p=0.0, inplace=False)
(ln_2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
(mlp): MLPBlock(
(0): Linear(in_features=768, out_features=3072, bias=True)
(1): GELU(approximate='none')
(2): Dropout(p=0.0, inplace=False)
(3): Linear(in_features=3072, out_features=768, bias=True)
(4): Dropout(p=0.0, inplace=False)
)
)
(encoder_layer_2): EncoderBlock(
(ln_1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
(self_attention): MultiheadAttention(
(out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
)
(dropout): Dropout(p=0.0, inplace=False)
(ln_2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
(mlp): MLPBlock(
(0): Linear(in_features=768, out_features=3072, bias=True)
(1): GELU(approximate='none')
(2): Dropout(p=0.0, inplace=False)
(3): Linear(in_features=3072, out_features=768, bias=True)
(4): Dropout(p=0.0, inplace=False)
)
)
(encoder_layer_3): EncoderBlock(
(ln_1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
(self_attention): MultiheadAttention(
(out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
)
(dropout): Dropout(p=0.0, inplace=False)
(ln_2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
(mlp): MLPBlock(
(0): Linear(in_features=768, out_features=3072, bias=True)
(1): GELU(approximate='none')
(2): Dropout(p=0.0, inplace=False)
(3): Linear(in_features=3072, out_features=768, bias=True)
(4): Dropout(p=0.0, inplace=False)
)
)
(encoder_layer_4): EncoderBlock(
(ln_1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
(self_attention): MultiheadAttention(
(out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
)
(dropout): Dropout(p=0.0, inplace=False)
(ln_2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
(mlp): MLPBlock(
(0): Linear(in_features=768, out_features=3072, bias=True)
(1): GELU(approximate='none')
(2): Dropout(p=0.0, inplace=False)
(3): Linear(in_features=3072, out_features=768, bias=True)
(4): Dropout(p=0.0, inplace=False)
)
)
(encoder_layer_5): EncoderBlock(
(ln_1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
(self_attention): MultiheadAttention(
(out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
)
(dropout): Dropout(p=0.0, inplace=False)
(ln_2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
(mlp): MLPBlock(
(0): Linear(in_features=768, out_features=3072, bias=True)
(1): GELU(approximate='none')
(2): Dropout(p=0.0, inplace=False)
(3): Linear(in_features=3072, out_features=768, bias=True)
(4): Dropout(p=0.0, inplace=False)
)
)
(encoder_layer_6): EncoderBlock(
(ln_1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
(self_attention): MultiheadAttention(
(out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
)
(dropout): Dropout(p=0.0, inplace=False)
(ln_2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
(mlp): MLPBlock(
(0): Linear(in_features=768, out_features=3072, bias=True)
(1): GELU(approximate='none')
(2): Dropout(p=0.0, inplace=False)
(3): Linear(in_features=3072, out_features=768, bias=True)
(4): Dropout(p=0.0, inplace=False)
)
)
(encoder_layer_7): EncoderBlock(
(ln_1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
(self_attention): MultiheadAttention(
(out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
)
(dropout): Dropout(p=0.0, inplace=False)
(ln_2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
(mlp): MLPBlock(
(0): Linear(in_features=768, out_features=3072, bias=True)
(1): GELU(approximate='none')
(2): Dropout(p=0.0, inplace=False)
(3): Linear(in_features=3072, out_features=768, bias=True)
(4): Dropout(p=0.0, inplace=False)
)
)
(encoder_layer_8): EncoderBlock(
(ln_1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
(self_attention): MultiheadAttention(
(out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
)
(dropout): Dropout(p=0.0, inplace=False)
(ln_2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
(mlp): MLPBlock(
(0): Linear(in_features=768, out_features=3072, bias=True)
(1): GELU(approximate='none')
(2): Dropout(p=0.0, inplace=False)
(3): Linear(in_features=3072, out_features=768, bias=True)
(4): Dropout(p=0.0, inplace=False)
)
)
(encoder_layer_9): EncoderBlock(
(ln_1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
(self_attention): MultiheadAttention(
(out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
)
(dropout): Dropout(p=0.0, inplace=False)
(ln_2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
(mlp): MLPBlock(
(0): Linear(in_features=768, out_features=3072, bias=True)
(1): GELU(approximate='none')
(2): Dropout(p=0.0, inplace=False)
(3): Linear(in_features=3072, out_features=768, bias=True)
(4): Dropout(p=0.0, inplace=False)
)
)
(encoder_layer_10): EncoderBlock(
(ln_1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
(self_attention): MultiheadAttention(
(out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
)
(dropout): Dropout(p=0.0, inplace=False)
(ln_2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
(mlp): MLPBlock(
(0): Linear(in_features=768, out_features=3072, bias=True)
(1): GELU(approximate='none')
(2): Dropout(p=0.0, inplace=False)
(3): Linear(in_features=3072, out_features=768, bias=True)
(4): Dropout(p=0.0, inplace=False)
)
)
(encoder_layer_11): EncoderBlock(
(ln_1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
(self_attention): MultiheadAttention(
(out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
)
(dropout): Dropout(p=0.0, inplace=False)
(ln_2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
(mlp): MLPBlock(
(0): Linear(in_features=768, out_features=3072, bias=True)
(1): GELU(approximate='none')
(2): Dropout(p=0.0, inplace=False)
(3): Linear(in_features=3072, out_features=768, bias=True)
(4): Dropout(p=0.0, inplace=False)
)
)
)
(ln): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
)
(heads): Linear(in_features=768, out_features=3, bias=True)
)
In [ ]:
summary(model=pretrained_vit,
input_size=(32,3,224,224),
col_names=["input_size", "output_size", "num_params", "trainable"],
col_width=20,
row_settings=["var_names"])
Out[ ]:
============================================================================================================================================
Layer (type (var_name)) Input Shape Output Shape Param # Trainable
============================================================================================================================================
VisionTransformer (VisionTransformer) [32, 3, 224, 224] [32, 3] 768 Partial
├─Conv2d (conv_proj) [32, 3, 224, 224] [32, 768, 14, 14] (590,592) False
├─Encoder (encoder) [32, 197, 768] [32, 197, 768] 151,296 False
│ └─Dropout (dropout) [32, 197, 768] [32, 197, 768] -- --
│ └─Sequential (layers) [32, 197, 768] [32, 197, 768] -- False
│ │ └─EncoderBlock (encoder_layer_0) [32, 197, 768] [32, 197, 768] (7,087,872) False
│ │ └─EncoderBlock (encoder_layer_1) [32, 197, 768] [32, 197, 768] (7,087,872) False
│ │ └─EncoderBlock (encoder_layer_2) [32, 197, 768] [32, 197, 768] (7,087,872) False
│ │ └─EncoderBlock (encoder_layer_3) [32, 197, 768] [32, 197, 768] (7,087,872) False
│ │ └─EncoderBlock (encoder_layer_4) [32, 197, 768] [32, 197, 768] (7,087,872) False
│ │ └─EncoderBlock (encoder_layer_5) [32, 197, 768] [32, 197, 768] (7,087,872) False
│ │ └─EncoderBlock (encoder_layer_6) [32, 197, 768] [32, 197, 768] (7,087,872) False
│ │ └─EncoderBlock (encoder_layer_7) [32, 197, 768] [32, 197, 768] (7,087,872) False
│ │ └─EncoderBlock (encoder_layer_8) [32, 197, 768] [32, 197, 768] (7,087,872) False
│ │ └─EncoderBlock (encoder_layer_9) [32, 197, 768] [32, 197, 768] (7,087,872) False
│ │ └─EncoderBlock (encoder_layer_10) [32, 197, 768] [32, 197, 768] (7,087,872) False
│ │ └─EncoderBlock (encoder_layer_11) [32, 197, 768] [32, 197, 768] (7,087,872) False
│ └─LayerNorm (ln) [32, 197, 768] [32, 197, 768] (1,536) False
├─Linear (heads) [32, 768] [32, 3] 2,307 True
============================================================================================================================================
Total params: 85,800,963
Trainable params: 2,307
Non-trainable params: 85,798,656
Total mult-adds (Units.GIGABYTES): 5.52
============================================================================================================================================
Input size (MB): 19.27
Forward/backward pass size (MB): 3330.74
Params size (MB): 229.20
Estimated Total Size (MB): 3579.21
============================================================================================================================================
In [ ]:
print(image_path)
train_dir = image_path / "train"
test_dir = image_path / "test"
train_dir, test_dir
data/pizza_steak_sushi
Out[ ]:
(PosixPath('data/pizza_steak_sushi/train'),
PosixPath('data/pizza_steak_sushi/test'))
In [ ]:
vit_transforms = pretrained_vit_weights.transforms()
vit_transforms
Out[ ]:
ImageClassification(
crop_size=[224]
resize_size=[256]
mean=[0.485, 0.456, 0.406]
std=[0.229, 0.224, 0.225]
interpolation=InterpolationMode.BILINEAR
)
In [ ]:
#Setup dataloaders
train_dataloader_pretrained, test_dataloader_pretrained, class_names = data_setup.create_dataloaders(
train_dir=train_dir,
test_dir=test_dir,
transform=vit_transforms,
batch_size=32
)
Train feature extractor of ViT model
In [ ]:
from going_modular.going_modular import engine
set_seeds()
optimizer = torch.optim.Adam(params=pretrained_vit.parameters(),
lr=1e-3)
loss_fn = torch.nn.CrossEntropyLoss()
#Train classifier head of pretrained ViT feature extractor model
pretrained_vit_results = engine.train(
model=pretrained_vit,
train_dataloader=train_dataloader_pretrained,
test_dataloader=test_dataloader_pretrained,
epochs=10,
optimizer=optimizer,
loss_fn=loss_fn,
device=device
)
10%|█ | 1/10 [00:02<00:22, 2.48s/it]
Epoch: 1 | train_loss: 0.7466 | train_acc: 0.7500 | test_loss: 0.5321 | test_acc: 0.8665
20%|██ | 2/10 [00:04<00:19, 2.47s/it]
Epoch: 2 | train_loss: 0.3356 | train_acc: 0.9453 | test_loss: 0.3219 | test_acc: 0.8977
30%|███ | 3/10 [00:07<00:17, 2.48s/it]
Epoch: 3 | train_loss: 0.2032 | train_acc: 0.9531 | test_loss: 0.2685 | test_acc: 0.9081
40%|████ | 4/10 [00:09<00:14, 2.48s/it]
Epoch: 4 | train_loss: 0.1536 | train_acc: 0.9570 | test_loss: 0.2411 | test_acc: 0.9081
50%|█████ | 5/10 [00:12<00:12, 2.49s/it]
Epoch: 5 | train_loss: 0.1231 | train_acc: 0.9727 | test_loss: 0.2271 | test_acc: 0.8977
60%|██████ | 6/10 [00:14<00:09, 2.50s/it]
Epoch: 6 | train_loss: 0.1201 | train_acc: 0.9766 | test_loss: 0.2123 | test_acc: 0.9280
70%|███████ | 7/10 [00:17<00:07, 2.49s/it]
Epoch: 7 | train_loss: 0.0925 | train_acc: 0.9766 | test_loss: 0.2348 | test_acc: 0.8883
80%|████████ | 8/10 [00:19<00:04, 2.49s/it]
Epoch: 8 | train_loss: 0.0786 | train_acc: 0.9844 | test_loss: 0.2277 | test_acc: 0.8778
90%|█████████ | 9/10 [00:22<00:02, 2.48s/it]
Epoch: 9 | train_loss: 0.1080 | train_acc: 0.9883 | test_loss: 0.2071 | test_acc: 0.9384
100%|██████████| 10/10 [00:24<00:00, 2.49s/it]
Epoch: 10 | train_loss: 0.0640 | train_acc: 0.9922 | test_loss: 0.1804 | test_acc: 0.9176
In [ ]:
from helper_functions import plot_loss_curves
# Plot the loss curves
plot_loss_curves(pretrained_vit_results)
In [ ]:
# Save model
from going_modular.going_modular import utils
utils.save_model(model=pretrained_vit,
target_dir="models",
model_name="pretrained_vit_feature_extractor_pizza_steak_sushi.pth")
[INFO] Saving model to: models/pretrained_vit_feature_extractor_pizza_steak_sushi.pth
In [ ]:
from pathlib import Path
pretrained_vit_model_size = Path("models/pretrained_vit_feature_extractor_pizza_steak_sushi.pth").stat().st_size // (1024**2)
print(f"Model size: {pretrained_vit_model_size} MB")
Model size: 327 MB
Predictions on custom image¶
In [ ]:
import requests
from going_modular.going_modular.predictions import pred_and_plot_image
custom_image_path = image_path / "04-pizza-dad.jepg"
if not custom_image_path.is_file():
with open(custom_image_path, "wb") as f:
# When downloading from GitHub, need to use the "raw" file link
request = requests.get("https://raw.githubusercontent.com/mrdbourke/pytorch-deep-learning/main/images/04-pizza-dad.jpeg")
print(f"Downloading {custom_image_path}...")
f.write(request.content)
else:
print(f"{custom_image_path} already exists, skipping download.")
pred_and_plot_image(model=pretrained_vit,
image_path=custom_image_path,
class_names=class_names)
Downloading data/pizza_steak_sushi/04-pizza-dad.jepg...