{
"cells": [
{
"cell_type": "markdown",
"id": "6221f0d3-4a18-416b-b548-3d9a5b82dedf",
"metadata": {
"execution": {}
},
"source": [
"# Tutorial 2: Contrastive learning for object recognition\n",
"\n",
"**Week 1, Day 2: Comparing Tasks**\n",
"\n",
"**By Neuromatch Academy**\n",
"\n",
"__Content creators:__ Andrew F. Luo, Leila Wehbe\n",
"\n",
"__Content reviewers:__ Samuele Bolotta, Yizhou Chen, RyeongKyung Yoon, Ruiyi Zhang, Lily Chamakura, Patrick Mineault\n",
"\n",
"__Production editors:__ Konstantine Tsafatinos, Ella Batty, Spiros Chavlis, Samuele Bolotta, Hlib Solodzhuk, Patrick Mineault\n"
]
},
{
"cell_type": "markdown",
"id": "2734dc52-6dc9-4669-a7c7-99ef87e47b8b",
"metadata": {
"execution": {}
},
"source": [
"___\n",
"\n",
"\n",
"# Tutorial Objectives\n",
"\n",
"*Estimated timing of tutorial: 40 minutes*\n",
"\n",
"By the end of this tutorial, participants will be able to:\n",
"\n",
"1. Understand why we want to do contrastive learning.\n",
"2. Understand the losses used in contrastive learning.\n",
"3. Train a network using contrastive learning on MNIST."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "aceddeef-ea13-4224-8ca8-70563edbcf00",
"metadata": {
"cellView": "form",
"execution": {},
"tags": [
"remove-input"
]
},
"outputs": [],
"source": [
"# @markdown\n",
"from IPython.display import IFrame\n",
"from ipywidgets import widgets\n",
"out = widgets.Output()\n",
"with out:\n",
" print(f\"If you want to download the slides: https://osf.io/download/x4y79/\")\n",
" display(IFrame(src=f\"https://mfr.ca-1.osf.io/render?url=https://osf.io/x4y79/?direct%26mode=render%26action=download%26mode=render\", width=730, height=410))\n",
"display(out)"
]
},
{
"cell_type": "markdown",
"id": "273ae64d-bc44-4e0e-b077-487da8391334",
"metadata": {
"execution": {}
},
"source": [
"---\n",
"# Setup\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Install and import feedback gadget\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f3eeb07c-5de3-428e-a422-3a58f1124b43",
"metadata": {
"cellView": "form",
"execution": {},
"tags": [
"hide-input"
]
},
"outputs": [],
"source": [
"# @title Install and import feedback gadget\n",
"\n",
"!pip install vibecheck numpy matplotlib torch torchvision tqdm ipysankeywidget ipywidgets seaborn --quiet\n",
"\n",
"from vibecheck import DatatopsContentReviewContainer\n",
"def content_review(notebook_section: str):\n",
" return DatatopsContentReviewContainer(\n",
" \"\", # No text prompt\n",
" notebook_section,\n",
" {\n",
" \"url\": \"https://pmyvdlilci.execute-api.us-east-1.amazonaws.com/klab\",\n",
" \"name\": \"neuromatch_neuroai\",\n",
" \"user_key\": \"wb2cxze8\",\n",
" },\n",
" ).render()\n",
"\n",
"feedback_prefix = \"W1D2_T2\""
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Import dependencies\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b56685b9-3f17-4a55-acca-73dad5623992",
"metadata": {
"cellView": "form",
"execution": {},
"tags": [
"hide-input"
]
},
"outputs": [],
"source": [
"# @title Import dependencies\n",
"\n",
"import logging\n",
"import gc\n",
"import contextlib\n",
"import io\n",
"\n",
"# PyTorch and related libraries\n",
"import torch\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"from torch.utils.data import DataLoader\n",
"import torchvision\n",
"\n",
"# Set up PyTorch backend configurations\n",
"torch.backends.cuda.matmul.allow_tf32 = True\n",
"torch.backends.cudnn.allow_tf32 = True\n",
"\n",
"# Numpy for numerical operations\n",
"import numpy as np\n",
"\n",
"# Matplotlib & Seaborn for plotting\n",
"import matplotlib.pyplot as plt\n",
"import seaborn as sns\n",
"\n",
"# Scikit-learn for machine learning utilities\n",
"from sklearn.decomposition import PCA\n",
"from sklearn.manifold import TSNE"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Figure settings\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "cebc5798-a8dd-4985-b0ec-d5facf4ee700",
"metadata": {
"cellView": "form",
"execution": {},
"tags": [
"hide-input"
]
},
"outputs": [],
"source": [
"# @title Figure settings\n",
"\n",
"logging.getLogger('matplotlib.font_manager').disabled = True\n",
"\n",
"%matplotlib inline\n",
"%config InlineBackend.figure_format = 'retina' # perform high definition rendering for images and plots\n",
"plt.style.use(\"https://raw.githubusercontent.com/NeuromatchAcademy/course-content/main/nma.mplstyle\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Helper functions\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "386b49a0-9409-496a-8656-579ca3e4af5f",
"metadata": {
"cellView": "form",
"execution": {},
"tags": [
"hide-input"
]
},
"outputs": [],
"source": [
"# @title Helper functions\n",
"\n",
"# This is code from the pytorch metric learning package\n",
"\n",
"def neg_inf(dtype):\n",
" # Returns the smallest possible value for the given data type\n",
" return torch.finfo(dtype).min\n",
"\n",
"def small_val(dtype):\n",
" # Returns the smallest positive value greater than zero for the given data type\n",
" return torch.finfo(dtype).tiny\n",
"\n",
"def to_dtype(x, tensor=None, dtype=None):\n",
" # Converts tensor `x` to the specified `dtype`, or to the same dtype as `tensor`\n",
" if not torch.is_autocast_enabled():\n",
" dt = dtype if dtype is not None else tensor.dtype\n",
" if x.dtype != dt:\n",
" x = x.type(dt)\n",
" return x\n",
"\n",
"def get_matches_and_diffs(labels, ref_labels=None):\n",
" # Returns tensors indicating matches and differences between pairs of labels\n",
" if ref_labels is None:\n",
" ref_labels = labels\n",
" labels1 = labels.unsqueeze(1) # Expand dimensions for comparison\n",
" labels2 = ref_labels.unsqueeze(0) # Expand dimensions for comparison\n",
" matches = (labels1 == labels2).byte() # Byte tensor of matches\n",
" diffs = matches ^ 1 # Byte tensor of differences (inverse of matches)\n",
" if ref_labels is labels:\n",
" matches.fill_diagonal_(0) # Remove self-matches\n",
" return matches, diffs\n",
"\n",
"def get_all_pairs_indices(labels, ref_labels=None):\n",
" \"\"\"\n",
" Given a tensor of labels, this will return 4 tensors.\n",
" The first 2 tensors are the indices which form all positive pairs\n",
" The second 2 tensors are the indices which form all negative pairs\n",
" \"\"\"\n",
" matches, diffs = get_matches_and_diffs(labels, ref_labels)\n",
" a1_idx, p_idx = torch.where(matches) # Indices for positive pairs\n",
" a2_idx, n_idx = torch.where(diffs) # Indices for negative pairs\n",
" return a1_idx, p_idx, a2_idx, n_idx\n",
"\n",
"def cos_sim(input_embeddings):\n",
" # Computes cosine similarity matrix for input embeddings\n",
" normed_embeddings = torch.nn.functional.normalize(input_embeddings, dim=-1) # Normalize embeddings\n",
" return normed_embeddings @ normed_embeddings.t() # Cosine similarity matrix"
]
},
{
"cell_type": "markdown",
"id": "pIBjtdIDeeVg",
"metadata": {
"execution": {}
},
"source": [
"# Section 1: Building the model"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Video\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a47d3240-0a70-4e3b-af13-2c02528867fa",
"metadata": {
"cellView": "form",
"execution": {},
"tags": [
"remove-input"
]
},
"outputs": [],
"source": [
"# @title Video\n",
"\n",
"from ipywidgets import widgets\n",
"from IPython.display import YouTubeVideo\n",
"from IPython.display import IFrame\n",
"from IPython.display import display\n",
"\n",
"\n",
"class PlayVideo(IFrame):\n",
" def __init__(self, id, source, page=1, width=400, height=300, **kwargs):\n",
" self.id = id\n",
" if source == 'Bilibili':\n",
" src = f'https://player.bilibili.com/player.html?bvid={id}&page={page}'\n",
" elif source == 'Osf':\n",
" src = f'https://mfr.ca-1.osf.io/render?url=https://osf.io/download/{id}/?direct%26mode=render'\n",
" super(PlayVideo, self).__init__(src, width, height, **kwargs)\n",
"\n",
"\n",
"def display_videos(video_ids, W=400, H=300, fs=1):\n",
" tab_contents = []\n",
" for i, video_id in enumerate(video_ids):\n",
" out = widgets.Output()\n",
" with out:\n",
" if video_ids[i][0] == 'Youtube':\n",
" video = YouTubeVideo(id=video_ids[i][1], width=W,\n",
" height=H, fs=fs, rel=0)\n",
" print(f'Video available at https://youtube.com/watch?v={video.id}')\n",
" else:\n",
" video = PlayVideo(id=video_ids[i][1], source=video_ids[i][0], width=W,\n",
" height=H, fs=fs, autoplay=False)\n",
" if video_ids[i][0] == 'Bilibili':\n",
" print(f'Video available at https://www.bilibili.com/video/{video.id}')\n",
" elif video_ids[i][0] == 'Osf':\n",
" print(f'Video available at https://osf.io/{video.id}')\n",
" display(video)\n",
" tab_contents.append(out)\n",
" return tab_contents\n",
"\n",
"\n",
"video_ids = [('Youtube', 'oGs-90Nzzw8'), ('Bilibili', 'BV1pb421H7Tz')]\n",
"tab_contents = display_videos(video_ids, W=730, H=410)\n",
"tabs = widgets.Tab()\n",
"tabs.children = tab_contents\n",
"for i in range(len(tab_contents)):\n",
" tabs.set_title(i, video_ids[i][0])\n",
"display(tabs)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Submit your feedback\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2cb688bf-38c8-40e1-b40b-c4014cdaffba",
"metadata": {
"cellView": "form",
"execution": {},
"tags": [
"hide-input"
]
},
"outputs": [],
"source": [
"# @title Submit your feedback\n",
"content_review(f\"{feedback_prefix}_video\")"
]
},
{
"cell_type": "markdown",
"id": "d55436c6-0750-4dcc-ba2d-526c9a64c44b",
"metadata": {
"execution": {}
},
"source": [
"## What is contrastive learning?\n",
"\n",
"Contrastive learning is a form of self-supervised learning (SSL). Contrastive learning seeks to map inputs to a high-dimensional space, bringing similar examples closer together and pushing dissimilar examples farther apart.\n",
"\n",
"It may not be immediately obvious why you want to engage in contrastive learning. Can't we just use a large 1000-class ImageNet-trained classifier to recognize every image? Contrastive learning proves useful when the number of classes is not known ahead of time. For example, if you wanted a network to recognize human faces, there are approximately 8 billion people on the planet, making it impractical to train a classification network with 8 billion output neurons. Instead, you can train a network to output a high-dimensional embedding for each image. With this approach, given a reference image of a person, the network can determine if a new photo is similar to or different from the reference image.\n",
"\n",
"In this section, we will:\n",
"\n",
"* Construct a model that maps images to a high-dimensional space.\n",
"* Visualize the geometric properties of the embedding prior to model training."
]
},
{
"cell_type": "markdown",
"id": "2f1725ea-5d80-4cb8-8b9b-b39d69f21dcf",
"metadata": {
"execution": {}
},
"source": [
"## Constructing the model\n",
"\n",
"We'll now construct a fully connected artificial neural network for contrastive learning, built from residual blocks in the style of a ResNet. This will look much like a classification network, but without a classification head at the end. Instead, the network maps images to a high-dimensional space:\n",
"\n",
"$$f(\\mathbf{x}) = \\mathbf{z}$$\n",
"\n",
"where $f$ is the network, $\\mathbf{x}$ is an input image and $\\mathbf{z}$ is the embedding. $\\mathbf{z}$ is real a vector with dimension `out_dim`, normalized to have a norm of 1. Later, we will train the network such that similar images have similar embeddings and dissimilar images have dissimilar embeddings.\n",
"\n",
"### Building the model from residual blocks\n",
"\n",
"We first define a residual block. The block contains a prenormalization step and a leaky ReLU activation function to help with vanishing gradients, in addition to linear layers. Residual networks tend to be easier to optimize than corresponding plain networks."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d7c432bf-1de5-43f6-b887-087d17e9ec0f",
"metadata": {
"execution": {}
},
"outputs": [],
"source": [
"class ResidualBlock(nn.Module):\n",
" # Follows \"Identity Mappings in Deep Residual Networks\", uses LayerNorm instead of BatchNorm, and LeakyReLU instead of ReLU\n",
" def __init__(self, feat_in=128, feat_out=128, feat_hidden=256, use_norm=True):\n",
" super().__init__()\n",
" # Define the residual block with or without normalization\n",
" if use_norm:\n",
" self.block = nn.Sequential(\n",
" nn.LayerNorm(feat_in), # Layer normalization on input features\n",
" nn.LeakyReLU(negative_slope=0.1), # LeakyReLU activation\n",
" nn.Linear(feat_in, feat_hidden), # Linear layer transforming input to hidden features\n",
" nn.LayerNorm(feat_hidden), # Layer normalization on hidden features\n",
" nn.LeakyReLU(negative_slope=0.1), # LeakyReLU activation\n",
" nn.Linear(feat_hidden, feat_out) # Linear layer transforming hidden to output features\n",
" )\n",
" else:\n",
" self.block = nn.Sequential(\n",
" nn.LeakyReLU(negative_slope=0.1), # LeakyReLU activation\n",
" nn.Linear(feat_in, feat_hidden), # Linear layer transforming input to hidden features\n",
" nn.LeakyReLU(negative_slope=0.1), # LeakyReLU activation\n",
" nn.Linear(feat_hidden, feat_out) # Linear layer transforming hidden to output features\n",
" )\n",
"\n",
" # Define the bypass connection\n",
" if feat_in != feat_out:\n",
" self.bypass = nn.Linear(feat_in, feat_out) # Linear layer to match dimensions if they differ\n",
" else:\n",
" self.bypass = nn.Identity() # Identity layer if input and output dimensions are the same\n",
"\n",
" def forward(self, input_data):\n",
" # Forward pass: apply the block and add the bypass connection\n",
" return self.block(input_data) + self.bypass(input_data)"
]
},
{
"cell_type": "markdown",
"id": "45862ae1",
"metadata": {
"execution": {}
},
"source": [
"With this in hand, we'll build our network from a series of residual blocks. We use `nn.Sequential` to chain the blocks together."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8cdc7833-aafd-4ee0-83fd-af9922dd5857",
"metadata": {
"execution": {}
},
"outputs": [],
"source": [
"class Model(nn.Module):\n",
" def __init__(self, in_dim, out_dim, hidden_dim, num_blocks=4):\n",
" super().__init__()\n",
" # Initial linear projection from input dimension to hidden dimension\n",
" self.in_proj = nn.Linear(in_dim, hidden_dim)\n",
" # Sequence of residual blocks\n",
" self.hidden = nn.Sequential(\n",
" *[ResidualBlock(feat_in=hidden_dim, feat_out=hidden_dim, feat_hidden=hidden_dim) for i in range(num_blocks)]\n",
" )\n",
" # Output linear projection from hidden dimension to output dimension\n",
" self.out = nn.Linear(hidden_dim, out_dim)\n",
"\n",
" def forward(self, x):\n",
" # Forward pass: input projection, passing through residual blocks, and final output projection\n",
" in_proj_out = self.in_proj(x)\n",
" hidden_out = self.hidden(in_proj_out)\n",
" embedding = self.out(hidden_out)\n",
" return F.normalize(embedding, p=2, dim=-1)"
]
},
{
"cell_type": "markdown",
"id": "96fdc011",
"metadata": {
"execution": {}
},
"source": [
"Notice that the output of the network is normalized using `F.normalize` so that each embedding has unit norm. Let's now look at the geometry of images processed by an untrained network. We first load data from MNIST."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d72d9231-35a5-405e-b2f6-135e134e0ae9",
"metadata": {
"execution": {}
},
"outputs": [],
"source": [
"# Define the transformations for the MNIST dataset\n",
"mnist_transforms = torchvision.transforms.Compose([\n",
" torchvision.transforms.ToTensor(), # Convert images to tensor\n",
" torchvision.transforms.Normalize((0.1307,), (0.3081,)) # Normalize the images with mean and standard deviation\n",
"])\n",
"\n",
"with contextlib.redirect_stdout(io.StringIO()):\n",
" # Load the MNIST test dataset with the defined transformations\n",
" test_dset = torchvision.datasets.MNIST(\"./\", train=False, transform=mnist_transforms, download=True)\n",
"\n",
"# Calculate the height and width of the MNIST images (28x28)\n",
"height = int(784**0.5)\n",
"width = height\n",
"\n",
"# Select the first image from the test dataset\n",
"idx = 0\n",
"data_point = test_dset[idx]\n",
"\n",
"# Display the image using matplotlib\n",
"plt.figure(figsize=(3, 3))\n",
"plt.imshow(data_point[0][0].numpy(), cmap='gray') # Display the image in grayscale\n",
"plt.show()\n",
"\n",
"# Print the label of the selected image\n",
"print(data_point[1])"
]
},
{
"cell_type": "markdown",
"id": "CkhXcVrGhUM5",
"metadata": {
"execution": {}
},
"source": [
"Now we will create the model using the definition we wrote previously and move it to the desired device. "
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ffd3ab89-ffd9-408c-8396-25b88e0930fb",
"metadata": {
"execution": {}
},
"outputs": [],
"source": [
"# Initialize the model with specified input, output, and hidden dimensions\n",
"mynet = Model(in_dim=784, out_dim=128, hidden_dim=256)\n",
"\n",
"# Automatically select the device (GPU if available, otherwise CPU)\n",
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
"\n",
"# Output the device that will be used\n",
"print(f\"Using device: {device}\")\n",
"\n",
"# Move the model to the selected device\n",
"_ = mynet.to(device)"
]
},
{
"cell_type": "markdown",
"id": "2c29c9b2",
"metadata": {
"execution": {}
},
"source": [
"## The geometry of the untrained network\n",
"\n",
"Let's visualize how close different examples are in the embedding space before training. We'll first gather the embeddings from multiple images by passing them through the network."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b2f9122b-ce83-464f-b6c9-5af2feceac62",
"metadata": {
"execution": {}
},
"outputs": [],
"source": [
"# First try with untrained network, find the cosine similarities within a class and across classes\n",
"\n",
"# Create a DataLoader for the test dataset with a batch size of 50\n",
"test_loader = DataLoader(test_dset, batch_size=50, shuffle=False) # enable persistent_workers=True if more than 1 worker to save CPU\n",
"\n",
"# Set the model to evaluation mode\n",
"mynet.eval()\n",
"\n",
"# Initialize lists to store test embeddings and labels\n",
"test_embeddings = []\n",
"test_labels = []\n",
"\n",
"# Initialize a similarity matrix of size 10x10 for 10 classes\n",
"sim_matrix = np.zeros((10, 10))\n",
"\n",
"# Disable gradient computation for inference\n",
"with torch.inference_mode():\n",
" for data_batch in test_loader:\n",
" test_img, test_label = data_batch # Get images and labels from the batch\n",
" batch_size = test_img.shape[0] # Get the batch size\n",
" flat = test_img.reshape(batch_size, -1).to(device, non_blocking=True) # Flatten the images and move to device\n",
" pred_embeddings = mynet(flat).cpu().numpy().tolist() # Get embeddings from the model and move to CPU\n",
" test_embeddings.extend(pred_embeddings) # Store the embeddings\n",
" test_labels.extend(test_label.numpy().tolist()) # Store the labels\n",
"\n",
"# Convert embeddings and labels to numpy arrays\n",
"test_embeddings_untrained = np.array(test_embeddings)\n",
"\n",
"# Convert test labels to numpy array\n",
"test_labels_untrained = np.array(test_labels)"
]
},
{
"cell_type": "markdown",
"id": "GtbiyBRmi-e2",
"metadata": {
"execution": {}
},
"source": [
"### Code exercise 1: Visualizing the cosine similarity of embeddings within and across classes before training\n",
"\n",
"In this exercise, we'll measure the cosine similarity between embeddings of images from the same class and across different classes. We'll visualize the cosine similarity matrix to see if the network has learned to distinguish between different classes. The cosine similarity between two embedding vectors $\\mathbf{z}_1$ and $\\mathbf{z}_2$ with norm 1 is defined as:\n",
"\n",
"$$\\text{sim}(\\mathbf{z}_1, \\mathbf{z}_2) = \\mathbf{z}_1 \\cdot \\mathbf{z}_2$$\n",
"\n",
"where $\\cdot$ denotes the dot product. The cosine similarity ranges from -1 to 1, where:\n",
"\n",
"* 1 indicates the vectors are identical\n",
"* 0 indicates that the vectors are orthogonal\n",
"* -1 indicates that the vectors are diametrically opposed. "
]
},
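{
"cell_type": "markdown",
"id": "cosine-similarity-sanity-check",
"metadata": {
"execution": {}
},
"source": [
"Before completing the exercise, here is a tiny sanity check of the definition above, using three hand-picked 2D unit vectors (the values are chosen purely for illustration):\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"z1 = np.array([1.0, 0.0])   # unit vector\n",
"z2 = np.array([0.0, 1.0])   # orthogonal unit vector\n",
"z3 = np.array([-1.0, 0.0])  # opposite direction\n",
"\n",
"print(np.dot(z1, z1))  # 1.0  -> identical vectors\n",
"print(np.dot(z1, z2))  # 0.0  -> orthogonal vectors\n",
"print(np.dot(z1, z3))  # -1.0 -> diametrically opposed vectors\n",
"```"
]
},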
{
"cell_type": "markdown",
"id": "0JcR1VwliL1f",
"metadata": {
"colab_type": "text",
"execution": {}
},
"source": [
"```python\n",
"# Dictionary to store normalized embeddings for each class\n",
"embeddings = {}\n",
"for i in range(10):\n",
" embeddings[i] = test_embeddings_untrained[test_labels_untrained == i]\n",
"\n",
"############################################################\n",
"# Fill in this code to compute cosine similarity matrix within the class.\n",
"raise NotImplementedError(\"Student exercise: calculate cosine similarity.\")\n",
"############################################################\n",
"\n",
"# Within class cosine similarity:\n",
"for i in range(10):\n",
" sims = ... # Compute cosine similarity matrix within the class\n",
" np.fill_diagonal(sims, np.nan) # Ignore diagonal values (self-similarity)\n",
" cur_sim = np.nanmean(sims) # Calculate the mean similarity excluding diagonal\n",
" sim_matrix[i, i] = cur_sim # Store the within-class similarity in the matrix\n",
"\n",
"# Between class cosine similarity:\n",
"for i in range(10):\n",
" for j in range(10):\n",
" if i == j:\n",
" continue # Skip if same class (already computed)\n",
" elif i > j:\n",
" continue # Skip if already computed (matrix symmetry)\n",
" else:\n",
" sims = embeddings[i] @ embeddings[j].T # Compute cosine similarity between different classes\n",
" cur_sim = np.mean(sims) # Calculate the mean similarity\n",
" sim_matrix[i, j] = cur_sim # Store the similarity in the matrix\n",
" sim_matrix[j, i] = cur_sim # Ensure symmetry in the matrix\n",
"\n",
"plt.figure(figsize=(8, 6))\n",
"sns.heatmap(sim_matrix, vmin=0.0, vmax=1.0, annot=True, fmt=\".2f\", cmap=\"YlGnBu\", linewidths=0.5)\n",
"plt.title(\"Untrained Network Cosine Similarity Matrix\")\n",
"plt.show()\n",
"\n",
"```"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9594e7b0-e401-4792-8798-7f23109bcf94",
"metadata": {
"execution": {}
},
"outputs": [],
"source": [
"# to_remove solution\n",
"\n",
"# Dictionary to store normalized embeddings for each class\n",
"embeddings = {}\n",
"for i in range(10):\n",
" embeddings[i] = test_embeddings_untrained[test_labels_untrained == i]\n",
"\n",
"# Within class cosine similarity:\n",
"for i in range(10):\n",
" sims = embeddings[i] @ embeddings[i].T # Compute cosine similarity matrix within the class\n",
" np.fill_diagonal(sims, np.nan) # Ignore diagonal values (self-similarity)\n",
" cur_sim = np.nanmean(sims) # Calculate the mean similarity excluding diagonal\n",
" sim_matrix[i, i] = cur_sim # Store the within-class similarity in the matrix\n",
"\n",
"# Between class cosine similarity:\n",
"for i in range(10):\n",
" for j in range(10):\n",
" if i == j:\n",
" continue # Skip if same class (already computed)\n",
" elif i > j:\n",
" continue # Skip if already computed (matrix symmetry)\n",
" else:\n",
" sims = embeddings[i] @ embeddings[j].T # Compute cosine similarity between different classes\n",
" cur_sim = np.mean(sims) # Calculate the mean similarity\n",
" sim_matrix[i, j] = cur_sim # Store the similarity in the matrix\n",
" sim_matrix[j, i] = cur_sim # Ensure symmetry in the matrix\n",
"\n",
"plt.figure(figsize=(8, 6))\n",
"sns.heatmap(sim_matrix, vmin=0.0, vmax=1.0, annot=True, fmt=\".2f\", cmap=\"YlGnBu\", linewidths=0.5)\n",
"plt.title(\"Untrained Network Cosine Similarity Matrix\")\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Submit your feedback\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "26dd2a45-38ed-45ea-ad27-dfddfeb1613f",
"metadata": {
"cellView": "form",
"execution": {},
"tags": [
"hide-input"
]
},
"outputs": [],
"source": [
"# @title Submit your feedback\n",
"content_review(f\"{feedback_prefix}_Code_Exercise_1\")"
]
},
{
"cell_type": "markdown",
"id": "adce807f",
"metadata": {
"execution": {}
},
"source": [
"### Reflection\n",
"\n",
"What do you observe in the cosine similarity matrix? Are the embeddings from the same digit class more similar than embeddings of different classes?"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "728682ab",
"metadata": {
"execution": {}
},
"outputs": [],
"source": [
"# to_remove explanation\n",
"\"\"\"\n",
"Since our network is untrained, there isn't much difference in the cosine similarities\n",
"within and across image classes. This lack of clear structure in the similarity matrix\n",
"is expected at this stage because the network has not yet learned to distinguish between\n",
"different classes.\n",
"\n",
"Ideally, we should observe a very high cosine similarity for images within the same\n",
"class (along the diagonal) and very low cosine similarity for images from different\n",
"classes (off-diagonal).\n",
"\"\"\""
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Submit your feedback\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d13d5761-e2ec-4ae1-adb4-edbe57b77ba9",
"metadata": {
"cellView": "form",
"execution": {},
"tags": [
"hide-input"
]
},
"outputs": [],
"source": [
"# @title Submit your feedback\n",
"content_review(f\"{feedback_prefix}_Discussion_Point_1\")"
]
},
{
"cell_type": "markdown",
"id": "bebbe1da-d23c-4367-94d3-5f5d68fd0356",
"metadata": {
"execution": {}
},
"source": [
"# Section 2: Training the model and visualizing feature similarity \n",
"\n",
"Let's train the network to pull elements of the same class together and push elements of different classes apart. We'll use a contrastive loss function to do this. "
]
},
{
"cell_type": "markdown",
"id": "C29eZ1BkfuIr",
"metadata": {
"execution": {}
},
"source": [
"## The contrastive loss function\n",
"\n",
"Our goal is to train the model to put similar examples close together and dissimilar examples far away from each other. We can achieve this by minimizing a contrastive loss function.\n",
"\n",
"\n",
"\n",
"Let's first consider a single anchor image whose embedding is $\\mathbf{z}_a$. We want to compare this against a set of embeddings for different images $\\mathbf{z}_k$, where $k \\in [0 \\ldots K]$. We have a single positive image $\\mathbf{z}_0$ of the same class as the anchor image, while the negative images $\\mathbf{z}_{[1 \\ldots K]}$ are images of other classes. We want to minimize the distance between the anchor image and the positive image while maximizing the distance between the anchor image and the negative images.\n",
"\n",
"A classic way to do this is via the InfoNCE loss function, which is widely used in contrastive learning, for example in OpenAI's CLIP. This loss is defined as:\n",
"\n",
"$$ \\mathcal{L}_a = -\\log \\left( \\frac{\\exp(\\mathbf{z}_a \\cdot \\mathbf{z}_{0} / \\tau)}{\\sum_{k=0}^{K} \\exp(\\mathbf{z}_a \\cdot \\mathbf{z}_k / \\tau)} \\right) $$\n",
"\n",
"Here $\\tau$ is a temperature parameter that controls the sharpness of the distribution. You can think of it as a cross-entropy loss with a single pseudo-class corresponding to similar labels and the negative pairs corresponding to different labels. \n",
"\n",
"### Decoupled constrastive learning\n",
"\n",
"InfoNCE typically requires substantial batch sizes—commonly 128 or larger—to perform optimally. The need for large batch sizes stems from the necessity for diverse negative samples in the batch to effectively learn the contrasts. However, large batch sizes can be impractical in resource-constrained settings or when data availability is limited.\n",
"\n",
"To address this, we will implement a modified version of InfoNCE as described in the [\"Decoupled Contrastive Learning\"](https://link.springer.com/chapter/10.1007/978-3-031-19809-0_38) paper. This variant adapts the loss to be more suitable for smaller batch sizes by modifying the denominator of the InfoNCE formula. Specifically, it removes the positive example from the denominator, which reduces the computational demand and stabilizes training when fewer examples are available. This adjustment not only makes the loss function more flexible, but also maintains robustness in learning discriminative features even with smaller batch sizes.\n",
"\n",
"$$ \\mathcal{L}_a = -\\log \\left( \\frac{\\exp(\\mathbf{z}_a \\cdot \\mathbf{z}_{0} / \\tau)}{\\sum_{k=1}^{K} \\exp(\\mathbf{z}_a \\cdot \\mathbf{z}_k / \\tau)} \\right) $$\n",
"\n",
"### Batching and numerical stability\n",
"\n",
"In practice, we will sample an entire batch of images at a time. We will then compute the average decoupled contrastive loss for every positive pair of images that have the same label. Thus:\n",
"\n",
"$$ \\mathcal{L}_{\\text{batch}} = \\frac{1}{N_{\\text{positive pairs}}} \\sum_{\\text{positive pairs (i, j)}} -\\log \\left( \\frac{\\exp(\\mathbf{z}_i \\cdot \\mathbf{z}_{j} / \\tau)}{\\sum_{k \\in \\text{Negative(i)}} \\exp(\\mathbf{z}_i \\cdot \\mathbf{z}_k / \\tau)} \\right) $$\n",
"\n",
"Here $i$ corresponds to the index of a single image, and $\\text{Negative(i)}$ is the set of indices of all the negative images corresponding to anchor image $i$ (all the images with labels that differ from the anchor image). \n",
"\n",
"To prevent the exponential from overflowing, we'll subtract the maximum value from the dot products before exponentiating. This is a common trick to improve numerical stability."
]
},
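{
"cell_type": "markdown",
"id": "infonce-toy-example",
"metadata": {
"execution": {}
},
"source": [
"To make the formulas above concrete, here is a minimal sketch that evaluates the InfoNCE loss and its decoupled variant for a single anchor with one positive and three negatives. The similarity values are made up purely for illustration, and the batched implementation you will complete below additionally handles many anchors and the positive/negative pair masking at once.\n",
"\n",
"```python\n",
"import torch\n",
"\n",
"tau = 0.07  # temperature\n",
"\n",
"# Made-up cosine similarities for one anchor:\n",
"# index 0 is the positive pair, indices 1..3 are negatives.\n",
"sims = torch.tensor([0.8, 0.1, -0.2, 0.3])\n",
"\n",
"logits = sims / tau\n",
"max_val = logits.max()                    # subtract the max for numerical stability\n",
"exp_logits = torch.exp(logits - max_val)  # the max cancels in the ratios below\n",
"\n",
"# InfoNCE: positive over (positive + negatives)\n",
"loss_infonce = -torch.log(exp_logits[0] / exp_logits.sum())\n",
"\n",
"# Decoupled variant: the positive is removed from the denominator\n",
"loss_dcl = -torch.log(exp_logits[0] / exp_logits[1:].sum())\n",
"\n",
"print(loss_infonce.item(), loss_dcl.item())\n",
"```"
]
},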
{
"cell_type": "markdown",
"id": "60bb1eba-68f6-436b-b019-302c3f27279e",
"metadata": {
"execution": {}
},
"source": [
"## Code exercise 2: The decoupled contrastive learning loss function\n",
"\n",
"Let's first complete the implementation of the Decoupled Contrastive Learning (DCL) loss function to better grasp how it separates positive and negative pairs. The total training time might take around 5 minutes."
]
},
{
"cell_type": "markdown",
"id": "1bfe95da-cd87-4fea-bb22-c1419450f9cb",
"metadata": {
"colab_type": "text",
"execution": {}
},
"source": [
"```python\n",
"def dcl_loss(pos_pairs, neg_pairs, indices_tuple, temperature=0.07):\n",
" ############################################################\n",
" # First question: Think about how you can ensure that non-matching pairs do\n",
" # not contribute to the denominator in the loss calculation. You need to set their\n",
" # values to a large negative number to effectively exclude them from the\n",
" # exponential computation. Second question: Look at the numerator in the formula!\n",
" raise NotImplementedError(\"Student exercise: complete DCL loss function.\")\n",
" ############################################################\n",
" \"\"\"\n",
" Computes the Decoupled Contrastive Learning loss.\n",
"\n",
" Returns:\n",
" torch.Tensor: The computed loss value.\n",
" \"\"\"\n",
" a1, _, a2, _ = indices_tuple # Unpack indices\n",
"\n",
" if len(a1) == 0 or len(a2) == 0:\n",
" return 0\n",
"\n",
" dtype = neg_pairs.dtype\n",
" pos_pairs = pos_pairs.unsqueeze(1) / temperature # Scale positive pairs by temperature\n",
" neg_pairs = neg_pairs / temperature # Scale negative pairs by temperature\n",
" n_per_p = to_dtype(a2.unsqueeze(0) == a1.unsqueeze(1), dtype=dtype) # Indicator matrix for matching pairs\n",
" neg_pairs = neg_pairs * n_per_p # Zero out non-matching pairs\n",
" neg_pairs[n_per_p == 0] = ... # Replace non-matching pairs with negative infinity\n",
"\n",
" # Compute the maximum value for numerical stability\n",
" max_val = torch.max(\n",
" pos_pairs, torch.max(neg_pairs, dim=1, keepdim=True)[0]\n",
" ).detach()\n",
" # Compute numerator and denominator for the loss\n",
" numerator = ...\n",
" denominator = torch.sum(torch.exp(neg_pairs - max_val), dim=1)\n",
" log_exp = torch.log((numerator / denominator) + small_val(dtype))\n",
" return -log_exp # Return the negative log of the exponential\n",
"\n",
"def pair_based_loss(similarities, indices_tuple, lossfunc):\n",
" \"\"\"\n",
" Computes pair-based loss using the provided loss function.\n",
"\n",
" Args:\n",
" similarities : torch.Tensor\n",
" A tensor of pairwise similarities. For n_examples, the shape should be\n",
" (n_examples, n_examples).\n",
" indices_tuple : tuple\n",
" A tuple of indices for positive and negative pairs. The tuple should\n",
" contain 4 tensors: a1, p, a2, n. The tensors a1 and p contain indices\n",
" for positive pairs, while a2 and n contain indices for negative pairs.\n",
" a1 and p should have the same length, and a2 and n should have the same\n",
" length. a1[i] and p[i] should form a positive pair, such that they have the\n",
" same label. Similarly, a2[i] and n[i] should form a negative pair, such that\n",
" they have different labels.\n",
" lossfunc : function\n",
" The loss function to be applied for computing the loss.\n",
" \"\"\"\n",
" # Computes pair-based loss using the provided loss function\n",
" a1, p, a2, n = indices_tuple # Unpack indices\n",
" pos_pair, neg_pair = [], []\n",
" if len(a1) > 0:\n",
" pos_pair = similarities[a1, p] # Extract positive pairs\n",
" if len(a2) > 0:\n",
" neg_pair = similarities[a2, n] # Extract negative pairs\n",
" return lossfunc(pos_pair, neg_pair, indices_tuple) # Apply loss function\n",
"\n",
"# Number of epochs for training\n",
"epochs = 10\n",
"\n",
"# Automatically select the device (GPU if available, otherwise CPU)\n",
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
"\n",
"# Output the device that will be used\n",
"print(f\"Using device: {device}\")\n",
"\n",
"# Load the MNIST training dataset with the defined transformations\n",
"train_dset = torchvision.datasets.MNIST(\"./\", train=True, transform=mnist_transforms)\n",
"train_loader = DataLoader(train_dset, batch_size=50, shuffle=True) # Enable persistent_workers=True if more than 1 worker to save CPU\n",
"\n",
"# Cleanup: delete the network and free up memory if this block is re-run\n",
"try:\n",
" del mynet\n",
" gc.collect()\n",
" torch.cuda.empty_cache()\n",
"except:\n",
" pass\n",
"\n",
"# Initialize the model with specified input, output, and hidden dimensions\n",
"mynet = Model(in_dim=784, out_dim=128, hidden_dim=256)\n",
"_ = mynet.to(device) # Move the model to the selected device\n",
"\n",
"# Enable training mode, which may affect dropout and other layers\n",
"mynet.train(mode=True)\n",
"print(\"Is the network in training mode?\", mynet.training)\n",
"\n",
"# Initial learning rate and decay factor for the optimizer\n",
"init_lr = 3e-4\n",
"lr_decay_factor = 0.5\n",
"\n",
"# Initialize the optimizer with model parameters and learning rate\n",
"optimizer = torch.optim.AdamW(mynet.parameters(), lr=init_lr, weight_decay=1e-2)\n",
"\n",
"# Tracker to keep track of loss values during training\n",
"loss_tracker = []\n",
"\n",
"# Training loop over the specified number of epochs\n",
"for epoch_id in range(1, epochs+1):\n",
" loss_epoch_tracker = 0\n",
" batch_counter = 0\n",
"\n",
" # Adjust learning rate for the current epoch\n",
" new_lrate = init_lr * (lr_decay_factor ** (epoch_id / epochs))\n",
" for param_group in optimizer.param_groups:\n",
" param_group['lr'] = new_lrate\n",
"\n",
" batches_in_epoch = len(train_loader)\n",
" for data_batch in train_loader:\n",
" optimizer.zero_grad() # Zero out gradients\n",
"\n",
" # Get images and labels from the batch\n",
" train_img, train_label = data_batch\n",
" batch_size = train_img.shape[0]\n",
"\n",
" # Flatten images and move data to the selected device\n",
" flat = train_img.reshape(batch_size, -1).to(device, non_blocking=True)\n",
" train_label = train_label.to(device, non_blocking=True)\n",
"\n",
" # Forward pass through the network\n",
" predicted_results = mynet(flat)\n",
"\n",
" # Compute cosine similarity matrix for the batch\n",
" similarities = cos_sim(predicted_results)\n",
"\n",
" # Get pairs of indices for positive and negative pairs\n",
" label_pos_neg = get_all_pairs_indices(train_label)\n",
"\n",
" # Compute the loss using the decoupled contrastive learning loss function\n",
" final_loss = torch.mean(pair_based_loss(similarities, label_pos_neg, dcl_loss))\n",
"\n",
" # Compute gradients from the loss\n",
" final_loss.backward()\n",
"\n",
" # Update the model parameters using the optimizer\n",
" optimizer.step()\n",
"\n",
" # Convert the loss to a single CPU scalar\n",
" loss_cpu_number = final_loss.item()\n",
"\n",
" # Keep track of the losses for visualization\n",
" loss_epoch_tracker += loss_cpu_number\n",
" batch_counter += 1\n",
"\n",
" # Print the current epoch, batch number, and loss every 500 batches\n",
" if batch_counter % 500 == 0:\n",
" print(\"Epoch {}, Batch {}/{}, loss: {}\".format(epoch_id, batch_counter, batches_in_epoch, loss_cpu_number))\n",
"\n",
" # Print the average loss for the epoch\n",
" print(\"Epoch average loss {}\".format(loss_epoch_tracker / batch_counter))\n",
"\n",
"```"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "67db2b80",
"metadata": {
"execution": {}
},
"outputs": [],
"source": [
"# to_remove solution\n",
"\n",
"def dcl_loss(pos_pairs, neg_pairs, indices_tuple, temperature=0.07):\n",
" \"\"\"\n",
" Computes the Decoupled Contrastive Learning loss.\n",
"\n",
" Returns:\n",
" torch.Tensor: The computed loss value.\n",
" \"\"\"\n",
" a1, _, a2, _ = indices_tuple # Unpack indices\n",
"\n",
" if len(a1) == 0 or len(a2) == 0:\n",
" return 0\n",
"\n",
" dtype = neg_pairs.dtype\n",
" pos_pairs = pos_pairs.unsqueeze(1) / temperature # Scale positive pairs by temperature\n",
" neg_pairs = neg_pairs / temperature # Scale negative pairs by temperature\n",
" n_per_p = to_dtype(a2.unsqueeze(0) == a1.unsqueeze(1), dtype=dtype) # Indicator matrix for matching pairs\n",
" neg_pairs = neg_pairs * n_per_p # Zero out non-matching pairs\n",
" neg_pairs[n_per_p == 0] = neg_inf(dtype) # Replace non-matching pairs with negative infinity\n",
"\n",
" # Compute the maximum value for numerical stability\n",
" max_val = torch.max(\n",
" pos_pairs, torch.max(neg_pairs, dim=1, keepdim=True)[0]\n",
" ).detach()\n",
" # Compute numerator and denominator for the loss\n",
" numerator = torch.exp(pos_pairs - max_val).squeeze(1)\n",
" denominator = torch.sum(torch.exp(neg_pairs - max_val), dim=1)\n",
" log_exp = torch.log((numerator / denominator) + small_val(dtype))\n",
" return -log_exp # Return the negative log of the exponential\n",
"\n",
"def pair_based_loss(similarities, indices_tuple, lossfunc):\n",
" \"\"\"\n",
" Computes pair-based loss using the provided loss function.\n",
"\n",
" Args:\n",
" similarities : torch.Tensor\n",
" A tensor of pairwise similarities. For n_examples, the shape should be\n",
" (n_examples, n_examples).\n",
" indices_tuple : tuple\n",
" A tuple of indices for positive and negative pairs. The tuple should\n",
" contain 4 tensors: a1, p, a2, n. The tensors a1 and p contain indices\n",
" for positive pairs, while a2 and n contain indices for negative pairs.\n",
" a1 and p should have the same length, and a2 and n should have the same\n",
" length. a1[i] and p[i] should form a positive pair, such that they have the\n",
" same label. Similarly, a2[i] and n[i] should form a negative pair, such that\n",
" they have different labels.\n",
" lossfunc : function\n",
" The loss function to be applied for computing the loss.\n",
" \"\"\"\n",
" # Computes pair-based loss using the provided loss function\n",
" a1, p, a2, n = indices_tuple # Unpack indices\n",
" pos_pair, neg_pair = [], []\n",
" if len(a1) > 0:\n",
" pos_pair = similarities[a1, p] # Extract positive pairs\n",
" if len(a2) > 0:\n",
" neg_pair = similarities[a2, n] # Extract negative pairs\n",
" return lossfunc(pos_pair, neg_pair, indices_tuple) # Apply loss function\n",
"\n",
"# Number of epochs for training\n",
"epochs = 10\n",
"\n",
"# Automatically select the device (GPU if available, otherwise CPU)\n",
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
"\n",
"# Output the device that will be used\n",
"print(f\"Using device: {device}\")\n",
"\n",
"# Load the MNIST training dataset with the defined transformations\n",
"train_dset = torchvision.datasets.MNIST(\"./\", train=True, transform=mnist_transforms)\n",
"train_loader = DataLoader(train_dset, batch_size=50, shuffle=True) # Enable persistent_workers=True if more than 1 worker to save CPU\n",
"\n",
"# Cleanup: delete the network and free up memory if this block is re-run\n",
"try:\n",
" del mynet\n",
" gc.collect()\n",
" torch.cuda.empty_cache()\n",
"except:\n",
" pass\n",
"\n",
"# Initialize the model with specified input, output, and hidden dimensions\n",
"mynet = Model(in_dim=784, out_dim=128, hidden_dim=256)\n",
"_ = mynet.to(device) # Move the model to the selected device\n",
"\n",
"# Enable training mode, which may affect dropout and other layers\n",
"mynet.train(mode=True)\n",
"print(\"Is the network in training mode?\", mynet.training)\n",
"\n",
"# Initial learning rate and decay factor for the optimizer\n",
"init_lr = 3e-4\n",
"lr_decay_factor = 0.5\n",
"\n",
"# Initialize the optimizer with model parameters and learning rate\n",
"optimizer = torch.optim.AdamW(mynet.parameters(), lr=init_lr, weight_decay=1e-2)\n",
"\n",
"# Tracker to keep track of loss values during training\n",
"loss_tracker = []\n",
"\n",
"# Training loop over the specified number of epochs\n",
"for epoch_id in range(1, epochs+1):\n",
" loss_epoch_tracker = 0\n",
" batch_counter = 0\n",
"\n",
" # Adjust learning rate for the current epoch\n",
" new_lrate = init_lr * (lr_decay_factor ** (epoch_id / epochs))\n",
" for param_group in optimizer.param_groups:\n",
" param_group['lr'] = new_lrate\n",
"\n",
" batches_in_epoch = len(train_loader)\n",
" for data_batch in train_loader:\n",
" optimizer.zero_grad() # Zero out gradients\n",
"\n",
" # Get images and labels from the batch\n",
" train_img, train_label = data_batch\n",
" batch_size = train_img.shape[0]\n",
"\n",
" # Flatten images and move data to the selected device\n",
" flat = train_img.reshape(batch_size, -1).to(device, non_blocking=True)\n",
" train_label = train_label.to(device, non_blocking=True)\n",
"\n",
" # Forward pass through the network\n",
" predicted_results = mynet(flat)\n",
"\n",
" # Compute cosine similarity matrix for the batch\n",
" similarities = cos_sim(predicted_results)\n",
"\n",
" # Get pairs of indices for positive and negative pairs\n",
" label_pos_neg = get_all_pairs_indices(train_label)\n",
"\n",
" # Compute the loss using the decoupled contrastive learning loss function\n",
" final_loss = torch.mean(pair_based_loss(similarities, label_pos_neg, dcl_loss))\n",
"\n",
" # Compute gradients from the loss\n",
" final_loss.backward()\n",
"\n",
" # Update the model parameters using the optimizer\n",
" optimizer.step()\n",
"\n",
" # Convert the loss to a single CPU scalar\n",
" loss_cpu_number = final_loss.item()\n",
"\n",
" # Keep track of the losses for visualization\n",
" loss_epoch_tracker += loss_cpu_number\n",
" batch_counter += 1\n",
"\n",
" # Print the current epoch, batch number, and loss every 500 batches\n",
" if batch_counter % 500 == 0:\n",
" print(\"Epoch {}, Batch {}/{}, loss: {}\".format(epoch_id, batch_counter, batches_in_epoch, loss_cpu_number))\n",
"\n",
" # Print the average loss for the epoch\n",
" print(\"Epoch average loss {}\".format(loss_epoch_tracker / batch_counter))"
]
},
{
"cell_type": "markdown",
"id": "uvQ4MdHKjY1t",
"metadata": {
"execution": {}
},
"source": [
"Great, we have a trained network! Let's collect the embeddings from the model."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4d4f11e3-07e8-4104-9ae4-6a9fd6f037eb",
"metadata": {
"execution": {}
},
"outputs": [],
"source": [
"def get_embeddings_labels(loader, model):\n",
" \"\"\"\n",
" Function to extract embeddings and labels from a given data loader and model\n",
" Args:\n",
" loader (DataLoader): DataLoader object containing the dataset\n",
" model (nn.Module): Model object to extract embeddings\n",
"\n",
" Returns:\n",
" embeddings (np.array): NumPy array of embeddings\n",
" labels (np.array): NumPy array of labels\n",
" \"\"\"\n",
" # Initialize lists to store embeddings and labels\n",
" embeddings = []\n",
" labels = []\n",
"\n",
" # Set the model to evaluation\n",
" model.eval()\n",
"\n",
" # Disable gradient computation for inference\n",
" with torch.inference_mode():\n",
" for data_batch in loader:\n",
" # Get images and labels from the batch\n",
" img, label = data_batch\n",
" batch_size = img.shape[0]\n",
"\n",
" # Flatten images and move data to the selected device\n",
" flat = img.reshape(batch_size, -1).to(device, non_blocking=True)\n",
"\n",
" # Forward pass through the network\n",
" pred_results = model(flat).cpu().numpy().tolist()\n",
"\n",
" # Store the embeddings and labels\n",
" embeddings.extend(pred_results)\n",
" labels.extend(label.numpy().tolist())\n",
" return np.array(embeddings), np.array(labels)\n",
"\n",
"# DataLoader for the test dataset with a batch size of 50\n",
"test_loader = DataLoader(test_dset, batch_size=50, shuffle=False) # Enable persistent_workers=True if more than 1 worker to save CPU\n",
"test_embeddings, test_labels = get_embeddings_labels(test_loader, mynet)\n",
"\n",
"train_loader = DataLoader(train_dset, batch_size=50, shuffle=False) # Enable persistent_workers=True if more than 1 worker to save CPU\n",
"train_embeddings, train_labels = get_embeddings_labels(train_loader, mynet)\n",
"\n",
"# Indicate that feature extraction is complete\n",
"print(\"Feature extraction done!\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Submit your feedback\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "de3b7b2b-4351-4df8-8877-8f81b1d3878b",
"metadata": {
"cellView": "form",
"execution": {},
"tags": [
"hide-input"
]
},
"outputs": [],
"source": [
"# @title Submit your feedback\n",
"content_review(f\"{feedback_prefix}_Code_Exercise_2\")"
]
},
{
"cell_type": "markdown",
"id": "d0KXB1wpmMp1",
"metadata": {
"execution": {}
},
"source": [
"## Visualizing the cosine similarity after training\n",
"\n",
"Let's start by double-checking that the network has learned to distinguish between different classes. We'll measure the cosine similarity between embeddings of images within and across classes. We'll visualize the cosine similarity matrix to see if the network has learned to distinguish between different classes."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e88a8b40-faa7-4071-b070-0ca192e06d55",
"metadata": {
"execution": {}
},
"outputs": [],
"source": [
"# Create DataLoader for the test dataset with a batch size of 50\n",
"test_loader = DataLoader(test_dset, batch_size=50, shuffle=False) # Enable persistent_workers=True if more than 1 worker to save CPU\n",
"\n",
"# Set the model to evaluation mode\n",
"mynet.eval()\n",
"\n",
"# Initialize lists to store test embeddings and labels\n",
"test_embeddings = []\n",
"test_labels = []\n",
"\n",
"# Initialize a similarity matrix of size 10x10 for 10 classes\n",
"sim_matrix = np.zeros((10, 10))\n",
"\n",
"# Disable gradient computation for inference\n",
"with torch.inference_mode():\n",
" for data_batch in test_loader:\n",
" # Get images and labels from the batch\n",
" test_img, test_label = data_batch\n",
" batch_size = test_img.shape[0] # Get the batch size\n",
"\n",
" # Flatten images and move data to the selected device\n",
" flat = test_img.reshape(batch_size, -1).to(device, non_blocking=True)\n",
"\n",
" # Get embeddings from the model and move to CPU\n",
" pred_embeddings = mynet(flat).cpu().numpy().tolist()\n",
"\n",
" # Store the embeddings and labels\n",
" test_embeddings.extend(pred_embeddings)\n",
" test_labels.extend(test_label.numpy().tolist())\n",
"\n",
"# Convert embeddings and labels to numpy arrays for further processing\n",
"test_embeddings = np.array(test_embeddings)\n",
"\n",
"# Normalize the embeddings to unit length by dividing each embedding by its L2 norm\n",
"test_embeddings_normed = test_embeddings / np.linalg.norm(test_embeddings, axis=1, keepdims=True)\n",
"\n",
"# Convert test labels to a numpy array\n",
"test_labels = np.array(test_labels)\n",
"\n",
"# Dictionary to store normalized embeddings for each class\n",
"embeddings = {}\n",
"for i in range(10):\n",
" embeddings[i] = test_embeddings_normed[test_labels == i]\n",
"\n",
"# Calculate within-class cosine similarity\n",
"for i in range(10):\n",
" # Compute cosine similarity matrix within the class\n",
" sims = embeddings[i] @ embeddings[i].T\n",
"\n",
" # Ignore diagonal values (self-similarity)\n",
" np.fill_diagonal(sims, np.nan)\n",
"\n",
" # Calculate the mean similarity excluding diagonal\n",
" cur_sim = np.nanmean(sims)\n",
"\n",
" # Store the within-class similarity in the matrix\n",
" sim_matrix[i, i] = cur_sim\n",
"\n",
"# Calculate between-class cosine similarity\n",
"for i in range(10):\n",
" for j in range(10):\n",
" if i == j:\n",
" pass # Skip if same class (already computed)\n",
" elif i > j:\n",
" pass # Skip if already computed (matrix symmetry)\n",
" else:\n",
" # Compute cosine similarity between different classes\n",
" sims = embeddings[i] @ embeddings[j].T\n",
"\n",
" # Calculate the mean similarity\n",
" cur_sim = np.mean(sims)\n",
"\n",
" # Store the similarity in the matrix\n",
" sim_matrix[i, j] = cur_sim\n",
" sim_matrix[j, i] = cur_sim # Ensure symmetry in the matrix\n",
"\n",
"\n",
"# Plot the similarity matrix using matplotlib\n",
"plt.figure(figsize=(8, 6))\n",
"sns.heatmap(sim_matrix, vmin=0.0, vmax=1.0, annot=True, fmt=\".2f\", cmap=\"YlGnBu\", linewidths=0.5)\n",
"plt.title(\"Trained Network Cosine Similarity Matrix\")\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"id": "d75c3ef3",
"metadata": {
"execution": {}
},
"source": [
"We see that the network has rapidly learned to distinguish between different classes and cluster similar examples together, despite not being trained to classify images directly."
]
},
{
"cell_type": "markdown",
"id": "56affa17",
"metadata": {
"execution": {}
},
"source": [
"## Visualizing the geometry of the embeddings before and after training\n",
"\n",
"Let's use t-SNE to visualize the geometry of the embeddings before and after training. "
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "767e0f93-3f68-4db0-9edf-4891a3858f63",
"metadata": {
"execution": {}
},
"outputs": [],
"source": [
"# Convert list of embeddings to a numpy array\n",
"test_embeddings_untrained = np.array(test_embeddings_untrained)\n",
"\n",
"# Initialize t-SNE with 2 components for dimensionality reduction\n",
"tsne = TSNE(n_components=2)\n",
"\n",
"# Notify that the t-SNE transformation may take some time\n",
"print(\"t-SNE transformation in progress... This may take a minute\")\n",
"\n",
"# Fit t-SNE on the normalized embeddings and transform them to 2D\n",
"tsne_embeddings_untrained = tsne.fit_transform(test_embeddings_untrained)\n",
"\n",
"# Optional: Print the shape of the resulting t-SNE embeddings to verify\n",
"print(\"t-SNE embeddings shape:\", test_embeddings_untrained.shape)"
]
},
{
"cell_type": "markdown",
"id": "90670307",
"metadata": {
"execution": {}
},
"source": [
"Similarly, for the trained network."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d4cdd26a-15c9-4c5f-82f7-4a65ca0bd5f7",
"metadata": {
"execution": {}
},
"outputs": [],
"source": [
"# Convert list of embeddings to a numpy array\n",
"test_embeddings = np.array(test_embeddings)\n",
"\n",
"# Initialize t-SNE with 2 components for dimensionality reduction\n",
"tsne = TSNE(n_components=2)\n",
"\n",
"# Notify that the t-SNE transformation may take some time\n",
"print(\"t-SNE transformation in progress... This may take a minute\")\n",
"\n",
"# Fit t-SNE on the normalized embeddings and transform them to 2D\n",
"tsne_embeddings = tsne.fit_transform(test_embeddings)\n",
"\n",
"# Optional: Print the shape of the resulting t-SNE embeddings to verify\n",
"print(\"t-SNE embeddings shape:\", tsne_embeddings.shape)"
]
},
{
"cell_type": "markdown",
"id": "eqJJWmx7j1kF",
"metadata": {
"execution": {}
},
"source": [
"Now plot the distribution of features before and after training."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2afe3c0d-7eed-4b2b-bdaa-bb258fc92d10",
"metadata": {
"execution": {}
},
"outputs": [],
"source": [
"# Use t-SNE embeddings for visualization\n",
"plt.figure(figsize=(8, 4.5))\n",
"plt.subplot(121)\n",
"for num in range(10):\n",
" plt.scatter(tsne_embeddings_untrained[test_labels_untrained==num, 0],\n",
" tsne_embeddings_untrained[test_labels_untrained==num, 1])\n",
"\n",
"plt.xlabel('t-sne dim 1')\n",
"plt.ylabel('t-sne dim 2')\n",
"plt.legend([f\"Digit {i}\" for i in range(10)])\n",
"plt.title('Before training')\n",
"\n",
"plt.subplot(122)\n",
"for num in range(10):\n",
" plt.scatter(tsne_embeddings[test_labels==num, 0], tsne_embeddings[test_labels==num, 1])\n",
"\n",
"plt.xlabel('t-sne dim 1')\n",
"plt.ylabel('t-sne dim 2')\n",
"plt.legend([f\"Digit {i}\" for i in range(10)])\n",
"plt.title('After training')"
]
},
{
"cell_type": "markdown",
"id": "cb3ac6b2",
"metadata": {
"execution": {}
},
"source": [
"Notice how training has pulled examples with similar labels together and pushed examples with different labels apart."
]
},
{
"cell_type": "markdown",
"id": "a8wRjpsrmva7",
"metadata": {
"execution": {}
},
"source": [
"## Using the network to identify nearest neighbours from the train set\n",
"\n",
"How can we actually use a contrastive learning network for a downstream task like classification? Some options include:\n",
"\n",
"* **Fine-tuning**: Add a linear classification layer on top of the embedding. Fine-tune the model end-to-end for the downstream supervised task.\n",
"* **Nearest neighbour search**: Given a test anchor image, find the most similar image in the train set. Apply the same label to the test image as the most similar image in the train set.\n",
"\n",
"Here, we will use a nearest neighbour search to find the most similar image in the train set."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "013D0NbXowjj",
"metadata": {
"execution": {}
},
"outputs": [],
"source": [
"# Calculate the cosine similarity matrix between all the test images and the train images\n",
"sims_all = test_embeddings @ train_embeddings.T\n",
"\n",
"# Index of the embedding to check for the most similar embedding\n",
"idx_to_check = 4\n",
"\n",
"# Find the index of the most similar embedding (excluding itself)\n",
"best_idx = np.argmax(sims_all[idx_to_check])\n",
"\n",
"# Plot the image corresponding to the index to check\n",
"plt.figure(figsize=(8, 6))\n",
"plt.subplot(121)\n",
"plt.imshow(test_dset[idx_to_check][0][0].cpu().numpy())\n",
"plt.title('Query image')\n",
"\n",
"# Plot the image corresponding to the most similar embedding\n",
"plt.subplot(122)\n",
"plt.imshow(train_dset[best_idx][0][0].cpu().numpy())\n",
"plt.title('Nearest neighbor from train set')"
]
},
{
"cell_type": "markdown",
"id": "e65067ac",
"metadata": {
"execution": {}
},
"source": [
"In this one case, this nearest neighbor scheme works well. Let's measure the overall accuracy over the entire test set."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1ae6f966",
"metadata": {
"execution": {}
},
"outputs": [],
"source": [
"best_idxs = np.argmax(sims_all, axis=1)\n",
"corresponding_labels = train_labels[best_idxs]\n",
"accuracy = np.mean(corresponding_labels == test_labels)\n",
"print(f\"Mean accuracy: {accuracy}\")"
]
},
{
"cell_type": "markdown",
"id": "3fdbc5c9",
"metadata": {
"execution": {}
},
"source": [
"Not too bad for a simple nearest neighbour search! Contrastive learning has allowed us to learn a useful embedding space for recognizing digits. \n",
"\n",
"Keep in mind, however, that a nearest neighbour search can be impractical in real applications, as we'd need to keep the embeddings of all the train examples in memory. It's often more effective to fine-tune the network on the downstream task. This allows the network to learn task-specific features that may not be captured by the contrastive learning objective, and it can be more computationally efficient."
]
},
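{
"cell_type": "markdown",
"id": "linear-probe-sketch",
"metadata": {
"execution": {}
},
"source": [
"As a sketch of that fine-tuning alternative, one common recipe is a linear probe: add a small linear head on top of the (here frozen) embedding network and train it with a standard cross-entropy loss. The layer size, learning rate, and single training pass below are illustrative assumptions rather than values used elsewhere in this tutorial.\n",
"\n",
"```python\n",
"# Minimal linear-probe sketch on top of the trained embedding network (illustrative only)\n",
"probe = nn.Linear(128, 10).to(device)  # 128-d embedding -> 10 digit classes\n",
"probe_optimizer = torch.optim.AdamW(probe.parameters(), lr=1e-3)\n",
"\n",
"mynet.eval()  # keep the embedding network frozen\n",
"for train_img, train_label in train_loader:\n",
"    flat = train_img.reshape(train_img.shape[0], -1).to(device)\n",
"    train_label = train_label.to(device)\n",
"\n",
"    with torch.no_grad():\n",
"        emb = mynet(flat)  # frozen, unit-norm embeddings\n",
"\n",
"    logits = probe(emb)\n",
"    loss = F.cross_entropy(logits, train_label)\n",
"\n",
"    probe_optimizer.zero_grad()\n",
"    loss.backward()\n",
"    probe_optimizer.step()\n",
"```"
]
},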
{
"cell_type": "markdown",
"id": "8444cc4a-fd26-471c-ac34-7e515dbe946a",
"metadata": {
"execution": {}
},
"source": [
"## How is contrastive learning used in practice?\n",
"\n",
"Nearly all vision foundation models, such as DINO, DINOv2, CLIP, and their derivatives (including OpenCLIP and EVA-CLIP), are trained using contrastive losses. DINO and DINOv2 are trained solely on images, while CLIP is trained on a combination of images and text.\n",
"\n",
"When only images are used, the contrastive learning loss is applied to augmentations of the same image. These augmentations can include crops, flips, and rotations, and this approach is referred to as a \"pretext task.\" Typically, augmentations of the same image are treated as instances where the embeddings should be the same. For example, a network should recognize a photo of you and a photo of you flipped, with altered brightness, noise added, or converted to black and white, as representing the same person.\n",
"\n",
"\n",
"\n",
"*Photos by JimboMack66, CC-BY 2.0*\n",
"\n",
"When images and text are used together, as in CLIP, the training data consists of images and their corresponding captions. For example, the caption \"A photo of a dog\" might be paired with a picture of a blue heeler puppy. These captions are typically scraped from online sources and collected into datasets like LAION-2B, COYO-700M, and CommonCrawl. Although these captions are often of varying quality, the sheer volume of data helps to mitigate this issue.\n",
"\n",
"Multimodal contrastive learning typically employs a dual encoder system: one for text and one for images. The network is trained using a loss function that minimizes the distance between the correct text-image pairs while maximizing the distance between incorrect pairs. For example, the caption \"A photo of a dog\" should have embeddings close to the image of the blue heeler puppy and far from the image of a cat. To compute the \"distance\" of the embeddings, methods such as normalized dot-product (cosine similarity), angular distance (Universal Sentence Encoder), Euclidean distance, or squared Euclidean distance are often used."
]
},
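{
"cell_type": "markdown",
"id": "clip-style-loss-sketch",
"metadata": {
"execution": {}
},
"source": [
"To make the dual-encoder setup concrete, here is a hedged sketch of a CLIP-style symmetric contrastive loss over a batch of paired image and text embeddings. The encoder outputs are random stand-ins (this tutorial defines neither an image nor a text encoder), and real CLIP additionally learns the temperature during training.\n",
"\n",
"```python\n",
"import torch\n",
"import torch.nn.functional as F\n",
"\n",
"batch_size, embed_dim = 8, 128\n",
"# Stand-ins for the outputs of an image encoder and a text encoder\n",
"image_emb = F.normalize(torch.randn(batch_size, embed_dim), dim=-1)\n",
"text_emb = F.normalize(torch.randn(batch_size, embed_dim), dim=-1)\n",
"\n",
"tau = 0.07\n",
"logits = image_emb @ text_emb.t() / tau  # cosine similarities of all image-text pairs\n",
"\n",
"# The matching caption for image i sits at index i, so the targets are the diagonal\n",
"targets = torch.arange(batch_size)\n",
"loss_img_to_txt = F.cross_entropy(logits, targets)\n",
"loss_txt_to_img = F.cross_entropy(logits.t(), targets)\n",
"clip_style_loss = (loss_img_to_txt + loss_txt_to_img) / 2\n",
"print(clip_style_loss.item())\n",
"```"
]
},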
{
"cell_type": "markdown",
"id": "f5c89f34-efb0-434f-915f-807886dcfc6e",
"metadata": {
"execution": {}
},
"source": [
"## Discussion point\n",
"\n",
"We have argued that judging similarity is an inherent capability of both human brains and artificial intelligence systems. We've covered many ways to implement contrastive learning in AI; can you speculate if and how contrastive learning might be implemented in the human brain? "
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "fa73993e-bbaa-4d32-a7c1-3903c784883f",
"metadata": {
"execution": {}
},
"outputs": [],
"source": [
"# to_remove explanation\n",
"\"\"\"\n",
"Our brain's ability to perform contrastive learning is linked to the function of the ventral visual stream.\n",
"This area of the brain processes visual information and has been shown to develop hierarchical features that\n",
"capture the structure of visual input through self-supervised learning mechanisms. Evidence suggests that anterior\n",
"regions of the ventral visual stream, particularly the anterior occipito-temporal cortex (aOTC), encode substantial\n",
"information about object categories without requiring explicit category-level supervision (Konkle and Alvarez, 2022).\n",
"Instead, these representations emerge through domain-general learning from natural image structures, where the brain\n",
"differentiates between individual views and categories based on the inherent statistical properties of visual input\n",
"(Livingstone et al., 2019; Arcaro and Livingstone, 2021). This capability supports the notion that the brain's visual\n",
"system can form complex object representations and categorization using self-supervised learning frameworks similar to\n",
"those in artificial neural networks.\n",
"\"\"\""
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Submit your feedback\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "af3f6a4f-2185-41db-a451-09ade489790a",
"metadata": {
"cellView": "form",
"execution": {},
"tags": [
"hide-input"
]
},
"outputs": [],
"source": [
"# @title Submit your feedback\n",
"content_review(f\"{feedback_prefix}_Discussion_Point_2\")"
]
},
{
"cell_type": "markdown",
"id": "87ae5fa0-dc3c-4cd3-961d-b0e3ae7ffd9d",
"metadata": {
"execution": {}
},
"source": [
"---\n",
"# Summary"
]
},
{
"cell_type": "markdown",
"id": "8bf4a13b-43f6-42a2-bbf7-4f78d323ea70",
"metadata": {
"execution": {}
},
"source": [
"In this tutorial, we've covered contrastive learning, a self-supervised technique that works well in situations where the number of classes is large or undefined. This method teaches a model to recognize similarity, not by traditional classification, but by learning to distinguish between 'similar' and 'dissimilar' directly through embeddings. We discussed the significance of generating embeddings that bring data points of the same class closer together while pushing different classes apart, which is particularly valuable in complex recognition tasks like identifying faces among billions of possibilities.\n",
"\n",
"Through practical exercises with the MNIST dataset, we've seen how contrastive learning can be implemented. The session highlighted the intuitive appeal of contrastive learning: learning by comparison, which is a natural way for both humans and machines to understand the world."
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"collapsed_sections": [],
"gpuType": "T4",
"include_colab_link": true,
"name": "W1D2_Tutorial2",
"provenance": [],
"toc_visible": true
},
"kernel": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.19"
}
},
"nbformat": 4,
"nbformat_minor": 5
}