{ "cells": [ { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "# Tutorial 3: Attention\n", "\n", "**Week 1, Day 5: Microcircuits**\n", "\n", "**By Neuromatch Academy**\n", "\n", "__Content creators:__ Saaketh Medepalli, Aditya Singh, Saeed Salehi, Xaq Pitkow\n", "\n", "__Content reviewers:__ Yizhou Chen, RyeongKyung Yoon, Ruiyi Zhang, Lily Chamakura, Hlib Solodzhuk\n", "\n", "__Production editors:__ Konstantine Tsafatinos, Ella Batty, Spiros Chavlis, Samuele Bolotta, Hlib Solodzhuk\n" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "---\n", "# Tutorial Objectives\n", "\n", "*Estimated timing of tutorial: 1 hour*\n", "\n", "\n", "By the end of this tutorial, we aim to:\n", "\n", "1. Learn how the brain and AI systems implement attention\n", "\n", "2. Understand how multiplicative interactions allow flexible gating of information\n", "\n", "3. Demonstrate the inductive bias of the self-attention mechanism towards learning functions of sparse subsets of variables\n", "\n", "A key microarchitectural operation in brains and machines is **attention**. The essence of this operation is multiplication. Unlike the vanilla neural network operation -- a nonlinear function of a weighted sum of inputs $f(\\sum_j w_{j}x_j)$ -- the attention operation allows responses to *multiply* inputs. This enables computations like modulating weights by other inputs.\n", "\n", " \n", "\n", "In brains, the theory of the Spotlight of Attention (Posner et al 1980) posited that gain modulation allowed brain computations to select information for later computation. In machines, Rumelhart, Hinton, and McClelland (1986) described Sigma-Pi networks that included both sums ($\\Sigma$) and products ($\\Pi$) as fundamental operations. The Transformer network (Vaswani et al 2017) used a specific architecture featuring layers of multiplication, sparsification, and normalization. Many machine learning systems since then have fruitfully applied this architecture to language, vision, and many more modalities. In this tutorial we will isolate the central properties and generalization benefits of multiplicative attention shared by all of these applications.\n", "\n", " \n", "\n", "Exercises include simple attentional modulation of inputs, coding the self-attention mechanism, demonstrating its inductive bias, and interpreting the consequences of attention.\n", "\n", " \n", "\n", "References:\n", "- Posner MI, Snyder CR, Davidson BJ (1980). Attention and the detection of signals. *Journal of Experimental Psychology: General*. 109(2):160.\n", "\n", "- Rumelhart DE, Hinton G, McClelland JL, PDP Research Group (1986). *Parallel Distributed Processing*. MIT Press.\n", "\n", "- Vaswani A, Shazeer N, Parmar N, Uszkoreit J, Jones L, Gomez AN, Kaiser Ł, Polosukhin I (2017). [Attention is all you need.](https://papers.nips.cc/paper_files/paper/2017/hash/3f5ee243547dee91fbd053c1c4a845aa-Abstract.html) *Advances in Neural Information Processing Systems*.\n", "\n", "- Thomas Viehmann, [toroidal - a lightweight transformer library for PyTorch](https://github.com/MathInf/toroidal)\n", "\n", "- Edelman B, Goel S, Kakade S, Zhang C (2022). [Inductive biases and variable creation in self-attention mechanisms](https://proceedings.mlr.press/v162/edelman22a/edelman22a.pdf). 
*PMLR*.\n", "\n", "- Deep Graph Library Tutorials, [Transformer as a Graph Neural Network](https://docs.dgl.ai/en/0.8.x/tutorials/models/4_old_wines/7_transformer.html)\n", "\n", "- Lilian Weng, [Attention? Attention!](https://lilianweng.github.io/posts/2018-06-24-attention/)\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "remove-input" ] }, "outputs": [], "source": [ "# @markdown\n", "from IPython.display import IFrame\n", "from ipywidgets import widgets\n", "out = widgets.Output()\n", "with out:\n", " print(f\"If you want to download the slides: https://osf.io/download/eckvr/\")\n", " display(IFrame(src=f\"https://mfr.ca-1.osf.io/render?url=https://osf.io/eckvr/?direct%26mode=render%26action=download%26mode=render\", width=730, height=410))\n", "display(out)" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "---\n", "# Setup\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Install and import feedback gadget\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Install and import feedback gadget\n", "\n", "!pip install vibecheck datatops --quiet\n", "\n", "from vibecheck import DatatopsContentReviewContainer\n", "def content_review(notebook_section: str):\n", " return DatatopsContentReviewContainer(\n", " \"\", # No text prompt\n", " notebook_section,\n", " {\n", " \"url\": \"https://pmyvdlilci.execute-api.us-east-1.amazonaws.com/klab\",\n", " \"name\": \"neuromatch_neuroai\",\n", " \"user_key\": \"wb2cxze8\",\n", " },\n", " ).render()\n", "\n", "feedback_prefix = \"W1D5_T3\"" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Imports\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "executionInfo": { "elapsed": 9705, "status": "ok", "timestamp": 1718046716707, "user": { "displayName": "Xaq Pitkow", "userId": "09050806329892245378" }, "user_tz": -180 }, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Imports\n", "import os\n", "import sys\n", "import math\n", "import torch\n", "import random\n", "import numpy as np\n", "import matplotlib.pyplot as plt" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Figure settings\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "executionInfo": { "elapsed": 452, "status": "ok", "timestamp": 1718046717157, "user": { "displayName": "Xaq Pitkow", "userId": "09050806329892245378" }, "user_tz": -180 }, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Figure settings\n", "import logging\n", "import matplotlib.cm as cm\n", "import ipywidgets as widgets # interactive display\n", "from ipywidgets import interactive, FloatSlider, Layout\n", "logging.getLogger('matplotlib.font_manager').disabled = True\n", "%config InlineBackend.figure_format = 'retina'\n", "plt.style.use(\"https://raw.githubusercontent.com/NeuromatchAcademy/course-content/NMA2020/nma.mplstyle\")\n", "fig_w, fig_h = plt.rcParams['figure.figsize']" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Plotting functions\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "executionInfo": { "elapsed": 3, "status": "ok", "timestamp": 1718046717157, "user": { 
"displayName": "Xaq Pitkow", "userId": "09050806329892245378" }, "user_tz": -180 }, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Plotting functions\n", "def plot_loss_accuracy(t_loss, t_acc, v_loss = None, v_acc = None):\n", " with plt.xkcd():\n", " plt.figure(figsize=(15, 4))\n", " plt.suptitle(\"Training and Validation for the Transformer Model\")\n", " plt.subplot(1, 2, 1)\n", " plt.plot(t_loss, label=\"Training loss\", color=\"red\")\n", " if v_loss is not None:\n", " # plt.plot(v_loss, label=\"Valididation loss\", color=\"blue\")\n", " plt.scatter(len(t_loss)-1, v_loss, label=\"Validation loss\", color=\"blue\", marker=\"*\")\n", " # plt.text(len(t_loss)-1, v_loss, f\"{v_loss:.3f}\", va=\"bottom\", ha=\"right\")\n", " plt.yscale(\"log\")\n", " plt.xlabel(\"Epoch\")\n", " plt.ylabel(\"Loss\")\n", " plt.xticks([])\n", " plt.legend(loc=\"lower right\")\n", " plt.subplot(1, 2, 2)\n", " plt.plot(t_acc, label=\"Training accuracy\", color=\"red\", linestyle=\"dotted\")\n", " if v_acc is not None:\n", " # plt.plot(v_acc, label=\"Validation accuracy\", color=\"blue\", linestyle=\"--\")\n", " plt.scatter(len(t_acc)-1, v_acc, label=\"Validation accuracy\", color=\"blue\", marker=\"*\")\n", " # plt.text(len(t_acc)-1, v_acc, f\"{v_acc:.3f}\", va=\"bottom\", ha=\"right\")\n", " plt.xticks([])\n", " plt.ylim(0, 1)\n", " plt.xlabel(\"Epoch\")\n", " plt.ylabel(\"Accuracy\")\n", " plt.legend(loc=\"lower right\")\n", " plt.show()\n", "\n", "\n", "def plot_samples(X_plot, y_plot, correct_ids, title=None):\n", " with plt.xkcd():\n", " n_samples, seq_length = X_plot.shape\n", " fig, axs = plt.subplots(1, 2, figsize=(16, 2.5), sharey=True)\n", " rects = []\n", " for ri in correct_ids:\n", " rects.append(plt.Rectangle((ri-0.5, -0.5), 1, n_samples, edgecolor=\"red\", alpha=1.0, fill=False, linewidth=2))\n", " axs[0].imshow(X_plot, cmap=\"binary\")\n", " for rect in rects:\n", " axs[0].add_patch(rect)\n", " # axs[0].axis(\"off\")\n", " axs[0].set_yticks([])\n", " axs[0].set_xticks([])\n", " axs[0].set_ylabel(\"Samples\")\n", " axs[0].set_xlabel(\"Context\")\n", " if title is not None:\n", " axs[0].set_title(title)\n", " axs[1].imshow(y_plot, cmap=\"binary\")\n", " axs[1].add_patch(plt.Rectangle((-0.5, -0.5), 1, n_samples, edgecolor=\"black\", alpha=1.0, fill=False, linewidth=2))\n", " axs[1].yaxis.set_label_position(\"right\")\n", " axs[1].set_ylabel(\"Labels\")\n", " axs[1].set_yticks([])\n", " axs[1].set_xticks([])\n", " plt.subplots_adjust(wspace=-1)\n", " plt.tight_layout()\n", " plt.show()\n", "\n", "\n", "def plot_attention_weights(att_weights, correct_ids, context_length, labels):\n", " with plt.xkcd():\n", " from matplotlib.lines import Line2D\n", " B = att_weights.size(0)\n", " aw_flatten = att_weights[:, -1, :-1].view(-1, context_length)\n", " x_axis = torch.arange(context_length).repeat(B, 1)\n", " y_axis = labels.view(-1, 1).repeat(1, context_length)\n", " aw_ravel = aw_flatten.ravel()\n", " x_ravel = x_axis.ravel()\n", " y_ravel = y_axis.ravel()\n", "\n", " fig, ax = plt.subplots(figsize=(9, 5))\n", " colors = [\"#1E88E5\", \"#FFC107\"]\n", " labels_legened = [\"True\", \"False\"]\n", " ax.scatter(x_ravel, aw_ravel, alpha=0.5, c=[colors[int(y)] for y in y_ravel])\n", " rects = []\n", " for ri in correct_ids:\n", " # rects.append(plt.Rectangle((ri-0.5, 1e-6), 1.0, 2.0, edgecolor=\"blue\", alpha=1.0, fill=False, linewidth=2))\n", " rects.append(plt.Rectangle((ri-0.5, -0.1), 1.0, 0.8, edgecolor=\"red\", alpha=1.0, fill=False, linewidth=2))\n", " for rect in rects:\n", " 
ax.add_patch(rect)\n", " # plt.yscale(\"log\")\n", " # plt.ylim(1e-6, 2)\n", " plt.ylim(-0.1, 0.7)\n", " plt.title(\"Attention weights for the whole batch\")\n", " plt.xlabel(\"Boolean input index t\")\n", " plt.ylabel(\"Attention weight\")\n", " legend_elements = [Line2D([0], [0], linestyle='None', marker='o', color='#1E88E5', label='True', markerfacecolor='#1E88E5', markersize=7),\n", " Line2D([0], [0], linestyle='None', marker='o', color='#FFC107', label='False', markerfacecolor='#FFC107', markersize=7)]\n", " ax.legend(handles=legend_elements, loc='upper right')\n", " plt.show()\n", "\n", "\n", "def plot_attention_weights_stats(att_weights, correct_ids, context_length, labels):\n", " with plt.xkcd():\n", " aw_flatten = att_weights[:, -1, :-1].view(-1, context_length)\n", " aw_flatten_mean = aw_flatten.mean(dim=0)\n", " aw_flatten_std = aw_flatten.std(dim=0)\n", " fig, ax = plt.subplots(figsize=(9, 5))\n", " ax.errorbar(torch.arange(context_length), aw_flatten_mean, yerr=aw_flatten_std, fmt='o', color='blue')\n", " rects = []\n", " for ri in correct_ids:\n", " rects.append(plt.Rectangle((ri-0.5, -0.1), 1.0, 0.8, edgecolor=\"red\", alpha=1.0, fill=False, linewidth=2))\n", " for rect in rects:\n", " ax.add_patch(rect)\n", " plt.title(\"Attention weights statistics\")\n", " plt.xlabel(\"Boolean input index t\")\n", " plt.ylabel(\"Mean attention weight\")\n", " plt.ylim(-0.1, 0.7)\n", " plt.show()\n", "\n", "def plot_compare(results_sat_d, results_sat_s, results_mlp_d, results_mlp_s,\n", " B_t_sat_s, B_t_sat_d, B_t_mlp_s, B_t_mlp_d):\n", " with plt.xkcd():\n", " from matplotlib.colors import LinearSegmentedColormap\n", " cmap=LinearSegmentedColormap.from_list('rg',[\"w\", \"w\", \"r\", \"y\", \"g\"], N=256)\n", "\n", " t_loss_sat_d, t_acc_sat_d, v_loss_sat_d, v_acc_sat_d, model_np_sat_d = results_sat_d\n", " t_loss_sat_s, t_acc_sat_s, v_loss_sat_s, v_acc_sat_s, model_np_sat_s = results_sat_s\n", " t_loss_mlp_d, t_acc_mlp_d, v_loss_mlp_d, v_acc_mlp_d, model_np_mlp_d = results_mlp_d\n", " t_loss_mlp_s, t_acc_mlp_s, v_loss_mlp_s, v_acc_mlp_s, model_np_mlp_s = results_mlp_s\n", "\n", " plt.figure(figsize=(8, 6))\n", " plt.subplot(2, 2, 1)\n", " plt.title(\"sparse\", fontsize=16)\n", " plt.ylabel(\"Attention\", fontsize=16)\n", " plt.barh(3, 0, color=\"blue\")\n", " plt.barh(4, 0, color=\"green\")\n", " plt.text(0.05, 3.5, f\"# Parameters: {model_np_sat_s}\")\n", " plt.text(0.05, 3.0, f\"# samples: {B_t_sat_s}\")\n", " plt.barh(2, t_acc_sat_s, color=cmap(t_acc_sat_s))\n", " plt.barh(1, v_acc_sat_s, color=cmap(v_acc_sat_s))\n", " plt.text(0.05, 2, f\"# training acc: {t_acc_sat_s:.0%}\")\n", " plt.text(0.05, 1, f\"# validation acc: {v_acc_sat_s:.0%}\")\n", " plt.xticks([])\n", " plt.yticks([])\n", "\n", " plt.subplot(2, 2, 2)\n", " plt.title(\"dense\", fontsize=16)\n", " plt.barh(3, 0, color=\"blue\")\n", " plt.barh(4, 0, color=\"green\")\n", " plt.text(0.05, 3.5, f\"# Parameters: {model_np_sat_d}\")\n", " plt.text(0.05, 3.0, f\"# samples: {B_t_sat_d}\")\n", " plt.barh(2, t_acc_sat_d, color=cmap(t_acc_sat_d))\n", " plt.barh(1, v_acc_sat_d, color=cmap(v_acc_sat_d))\n", " plt.text(0.05, 2, f\"# training acc: {t_acc_sat_d:.0%}\")\n", " plt.text(0.05, 1, f\"# validation acc: {v_acc_sat_d:.0%}\")\n", " plt.xticks([])\n", " plt.yticks([])\n", "\n", " plt.subplot(2, 2, 3)\n", " plt.barh(3, 0, color=\"blue\")\n", " plt.barh(4, 0, color=\"green\")\n", " plt.text(0.05, 3.5, f\"# Parameters: {model_np_mlp_s}\")\n", " plt.text(0.05, 3.0, f\"# samples: {B_t_mlp_s}\")\n", " plt.barh(2, t_acc_mlp_s, 
color=cmap(t_acc_mlp_s))\n", " plt.barh(1, v_acc_mlp_s, color=cmap(v_acc_mlp_s))\n", " plt.text(0.05, 2, f\"# training acc: {t_acc_mlp_s:.0%}\")\n", " plt.text(0.05, 1, f\"# validation acc: {v_acc_mlp_s:.0%}\")\n", " plt.xticks([])\n", " plt.yticks([])\n", " plt.ylabel(\"MLP\", fontsize=16)\n", "\n", " plt.subplot(2, 2, 4)\n", " plt.barh(3, 0, color=\"blue\")\n", " plt.barh(4, 0, color=\"green\")\n", " plt.text(0.05, 3.5, f\"# Parameters: {model_np_mlp_d}\")\n", " plt.text(0.05, 3.0, f\"# samples: {B_t_mlp_d}\")\n", " plt.barh(2, t_acc_mlp_d, color=cmap(t_acc_mlp_d))\n", " plt.barh(1, v_acc_mlp_d, color=cmap(v_acc_mlp_d))\n", " plt.text(0.05, 2, f\"# training acc: {t_acc_mlp_d:.0%}\")\n", " plt.text(0.05, 1, f\"# validation acc: {v_acc_mlp_d:.0%}\")\n", " plt.xticks([])\n", " plt.yticks([])\n", " plt.axis(\"off\")\n", " plt.show()\n", "\n", "def gained_dot_product_attention_implemented(x: torch.Tensor, # input vector\n", " q_1: torch.Tensor, # query vector 1\n", " q_2: torch.Tensor, # query vector 2\n", " z_1: float, # query gain 1\n", " z_2: float, # query gain 2\n", " ):\n", " \"\"\"This function computes the gained dot product attention\n", " Args:\n", " x (Tensor): input vector\n", " q_1 (Tensor): query vector 1\n", " q_2 (Tensor): query vector 2\n", " z_1 (float): query gain 1\n", " z_2 (float): query gain 2\n", " Returns:\n", " w (Tensor): attention weights\n", " y (float): gained dot product attention\n", " \"\"\"\n", " w = torch.softmax(z_1 * q_1 + z_2 * q_2, dim=0)\n", " y = torch.dot(w, x)\n", " return w, y\n", "\n", "def plot_weights_and_output(gain_1, gain_2):\n", " T = 20 # context length\n", " x = torch.sin(torch.linspace(0, 2*np.pi, T)) + 0.1 * torch.randn(T)\n", " q_1 = 1.0 - torch.sigmoid(torch.linspace(-3, 7, T))\n", " q_1 = q_1 / q_1.sum()\n", " q_2 = torch.sigmoid(torch.linspace(-7, 3, T))\n", " q_2 = q_2 / q_2.sum()\n", " w, y = gained_dot_product_attention_implemented(x, q_1, q_2, gain_1, gain_2)\n", " print(f\"output y: {y}\")\n", " with plt.xkcd():\n", " plt.figure(figsize=(8, 6))\n", " plt.subplot(2, 1, 1)\n", " plt.plot(q_1, label=\"$\\mathbf{q_1}$\", c=\"m\")\n", " plt.plot(q_2, label=\"$\\mathbf{q_2}$\", c=\"y\")\n", " plt.plot(w, label=\"$\\mathbf{w}$\", c=\"r\")\n", " plt.ylim(-0.1, max(q_1.max(), q_2.max()))\n", " plt.ylabel(\"Attention weights\")\n", " plt.xlabel(\"Context dimension\")\n", " plt.legend()\n", " plt.subplot(2, 1, 2)\n", " plt.plot(x, label=\"$\\mathbf{x}$\", c=\"blue\")\n", " plt.plot(5 * x * w, label=\"$5\\mathbf{w}*\\mathbf{x}$\", c=\"red\")\n", " plt.ylim(-x.abs().max(), x.abs().max())\n", " plt.ylabel(\"Amplitude\")\n", " plt.xlabel(\"Context\")\n", " plt.legend()\n", " plt.show()\n", "\n", "def plot_attention_graph(model, data_generator, m):\n", " with plt.xkcd():\n", " X_valid, y_valid = data_generator.generate(m, verbose=False)\n", " logits, scores, attention, output = model(X_valid, elaborate=True)\n", " with torch.no_grad():\n", " scores_x_linear = torch.einsum(\"bij,jk->bik\", scores, model.linear.weight.T).detach().cpu()\n", " w_att_trues_0 = scores_x_linear[(y_valid[:, 0] == 1) & (X_valid[:, -1] == 0)]\n", " w_att_trues_1 = scores_x_linear[(y_valid[:, 0] == 1) & (X_valid[:, -1] == 1)]\n", " w_att_trues_0 = w_att_trues_0.mean(dim=0).squeeze(-1)\n", " w_att_trues_1 = w_att_trues_1.mean(dim=0).squeeze(-1)\n", "\n", " fig, ax = plt.subplots(2, 1, figsize=(10, 8))\n", " rules = [\"Rule = 0\", \"Rule = 1\"]\n", " for r, x in enumerate((w_att_trues_0, w_att_trues_1)):\n", " #xc = (x - x.min()) / (x.max() - x.min() + 1e-9)\n", " xc = (x + 
x.abs().max()) / (2 * x.abs().max()) # color range spans -max|x| to +max|x|\n", " xw = x.abs() / x.abs().max()\n", " xa = (xw - xw.min()) / (xw.max() - xw.min() + 1e-9)\n", " xabsmax = x.abs().max()\n", " x, xc, xw, xa = (some_x.numpy() for some_x in (x, xc, xw, xa))\n", " T = x.shape[0]\n", " m = (T/2) - 0.5\n", "\n", " ax[r].text(m-1, 1.0, \" \", c='k', fontsize=16)\n", " ax[r].text(m-1, 0.9, f\"Graph for {rules[r]}\", c='k', fontsize=16)\n", " for i in range(T):\n", " c = 'r' if i in data_generator.f_i_1 else 'b' if i in data_generator.f_i_2 else 'm' if i == T-1 else'grey'\n", " ax[r].plot([i, m], [0, 0.75], c=cm.plasma_r(xc[i]), linewidth=5, alpha=1)\n", " ax[r].scatter([i], [0], c=c, zorder=100, s=100)\n", " ax[r].scatter([m], [0.75], c='k', zorder=100, s=200)\n", " for i, j in zip(data_generator.f_i_1, data_generator.f_i_2):\n", " ax[r].text(i-0.3, -0.1, \"tar_0\", c='r', fontsize=12)\n", " ax[r].text(j-0.3, -0.1, \"tar_1\", c='b', fontsize=12)\n", " ax[r].text(m-0.3, 0.8, \"Output\", c='k', fontsize=12)\n", " ax[r].text(T-1-0.1, -0.1, \"Rule\", c='m', fontsize=12)\n", " cbar = fig.colorbar(cm.ScalarMappable(cmap=cm.plasma_r), ax=ax[r], ticks=[0, 1], label=\"Attention\")\n", " #cbar.ax.set_yticklabels([f\"{x.min():.1f}\", f\"{x.max():.1f}\"])\n", " cbar.ax.set_yticklabels([f\"{-xabsmax:.1f}\", f\"{xabsmax:.1f}\"])\n", " ax[r].axis('off')\n", " plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Data retrieval\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "executionInfo": { "elapsed": 2, "status": "ok", "timestamp": 1718046717157, "user": { "displayName": "Xaq Pitkow", "userId": "09050806329892245378" }, "user_tz": -180 }, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "#@title Data retrieval\n", "class s_Sparse_AND: # 1-Dimensional AND\n", " def __init__(self, T: int, s: int):\n", " self.T = T # context length\n", " self.s = s # sparsity\n", " self.p = 0.5**(1.0/3.0) # probability chosen for balanced data\n", " self.f_i = None\n", "\n", " def pick_an_f(self):\n", " self.f_i = sorted(random.sample(range(self.T), self.s))\n", " self.others = list(i for i in range(self.T) if i not in self.f_i)\n", "\n", " def generate(self, m: int, verbose: bool = False):\n", " if self.f_i is None:\n", " self.pick_an_f()\n", " max_try = 100\n", " i_try = 0\n", " while i_try < max_try:\n", " i_try += 1\n", " X, y = torch.zeros(m, self.T), torch.zeros(m, 1)\n", " X[torch.rand(m, self.T) < self.p] = 1\n", " y[X[:, self.f_i].sum(dim=1) == self.s] = 1\n", " if y.sum()/m < 0.4 or y.sum()/m > 0.6:\n", " verbose and print(f\"Large imbalance in the training set {y.sum()/m}, retrying...\")\n", " continue\n", " else:\n", " verbose and print(f\"Data-label balance: {y.sum()/m}\")\n", " bad_batch = False\n", " for i in self.f_i:\n", " for o in self.others:\n", " if (X[:, i] == X[:, o]).all():\n", " verbose and print(f\"Found at least another compatible hypothesis {i} and {o}\")\n", " bad_batch = True\n", " break\n", " if bad_batch:\n", " continue\n", " else:\n", " break\n", " else:\n", " verbose and print(\"Could not find a compatible hypothesis\")\n", " return X.long(), y.float()\n", "\n", "\n", "class s_Sparse_AND_Query: # 1-Dimensional AND\n", " def __init__(self, T, s):\n", " self.T = T - 1\n", " self.k = 1 # currently only for k = 1\n", " self.s = s\n", " self.p = 0.5**(1.0/2.0) # probability chosen for balanced data\n", " self.f_i_1, self.f_i_2, self.f_i = None, None, None\n", "\n", " def pick_an_f(self):\n", " 
self.f_i_1 = sorted(random.sample(range(self.T), self.s))\n", " self.others_1 = list(i for i in range(self.T) if i not in self.f_i_1)\n", " self.f_i_2 = sorted(random.sample(self.others_1, self.s))\n", " self.others_2 = list(i for i in self.others_1 if i not in self.f_i_2)\n", " self.f_i = self.f_i_1 + self.f_i_2\n", "\n", " def __call__(self):\n", " self.pick_an_f()\n", "\n", " def generate(self, m: int, verbose: bool = False):\n", " z = torch.randint(0, 2, (m, self.k))\n", " if self.f_i_1 is None or self.f_i_2 is None:\n", " self.pick_an_f()\n", " max_try = 100\n", " i_try = 0\n", " while i_try < max_try:\n", " i_try += 1\n", " X, y = torch.zeros(m, self.T), torch.zeros(m, 1)\n", " X[torch.rand(m, self.T) < self.p] = 1\n", " # Rule 0\n", " X_z_0 = (X[:, self.f_i_1].sum(dim=1) == self.s) & (z[:, 0] == 0)\n", " y[X_z_0] = 1\n", " # Rule 1\n", " X_z_1 = (X[:, self.f_i_2].sum(dim=1) == self.s) & (z[:, 0] == 1)\n", " y[X_z_1] = 1\n", " if y.sum()/m < 0.45 or y.sum()/m > 0.55:\n", " verbose and print(f\"Large imbalance in the training set {y.sum()/m}, retrying...\")\n", " continue\n", " else:\n", " verbose and print(f\"Data-label balance: {y.sum()/m}\")\n", " break\n", " else:\n", " verbose and print(\"Could not find a compatible hypothesis\")\n", " return torch.cat([X, z], dim=1).long(), y.float()\n", "\n", "\n", "# From notebook\n", "# Load an example image from the imagenet-sample-images repository\n", "def load_image(path):\n", " \"\"\"\n", " Load an image from a given path\n", "\n", " Args:\n", " path: String\n", " Path to the image\n", "\n", " Returns:\n", " img: PIL Image\n", " Image loaded from the path\n", " \"\"\"\n", " img = Image.open(path)\n", " return img" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Helper functions\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "executionInfo": { "elapsed": 347, "status": "ok", "timestamp": 1718046717502, "user": { "displayName": "Xaq Pitkow", "userId": "09050806329892245378" }, "user_tz": -180 }, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Helper functions\n", "\n", "class BinaryMLP(torch.nn.Module):\n", " def __init__(self, in_dims, h_dims, out_dims, dropout=0.1):\n", " super().__init__()\n", " self.in_dims = in_dims\n", " self.h_dims = h_dims\n", " self.out_dims = out_dims\n", "\n", " self.layers = torch.nn.ModuleList()\n", " self.layers.append(torch.nn.Linear(in_dims, h_dims[0]))\n", " torch.nn.init.normal_(self.layers[-1].weight, std=0.02)\n", " torch.nn.init.zeros_(self.layers[-1].bias)\n", " self.layers.append(torch.nn.GELU())\n", " self.layers.append(torch.nn.Dropout(dropout))\n", " for i in range(len(h_dims) - 1):\n", " self.layers.append(torch.nn.Linear(h_dims[i], h_dims[i+1]))\n", " torch.nn.init.normal_(self.layers[-1].weight, std=0.02)\n", " torch.nn.init.zeros_(self.layers[-1].bias)\n", " self.layers.append(torch.nn.GELU())\n", " self.layers.append(torch.nn.Dropout(dropout))\n", " self.layers.append(torch.nn.Linear(h_dims[-1], out_dims))\n", " self.layers.append(torch.nn.Sigmoid())\n", "\n", " def forward(self, x):\n", " for layer in self.layers:\n", " x = layer(x)\n", " return x\n", "\n", "\n", "\n", "class SelfAttention(torch.nn.Module):\n", " \"\"\"Simple Binary Self-Attention\n", " \"\"\"\n", " def __init__(self, T: int, d: int):\n", " \"\"\"\n", " Args:\n", " T (int): context length\n", " d (int): embedding size for K, Q but not V\n", " Note:\n", " The embedding size for V is the same as the input size i.e. 
1\n", " \"\"\"\n", " super().__init__()\n", " self.T = T # context length\n", " self.d = d # embedding size\n", " self.scale = self.d ** -0.5 # scaling factor (1 / sqrt(d_k))\n", " self.v = 2 # vocabulary size (binary input, 0 or 1)\n", " init_std = 0.001 # standard deviation for weight initialization\n", "\n", " # embedding layers\n", " self.tokenizer = torch.nn.Embedding(self.v, self.d) # token embedding\n", " self.positioner = torch.nn.Parameter(torch.rand(1, self.T, self.d)) # positional embedding\n", "\n", " # self-attention layers\n", " self.Wq = torch.nn.Linear(self.d, self.d, bias=False) # query layer\n", " self.Wk = torch.nn.Linear(self.d, self.d, bias=False) # key layer\n", " self.Wv = torch.nn.Linear(self.d, 1, bias=False) # value layer\n", "\n", " # projection layer\n", " self.linear = torch.nn.Linear(self.T, 1, bias=False) # projection layer\n", " torch.nn.init.normal_(self.linear.weight, std=init_std)\n", " # torch.nn.init.zeros_(self.linear.bias)\n", "\n", " # initialize weights and biases (per description in the paper)\n", " torch.nn.init.normal_(self.tokenizer.weight, std=init_std)\n", " torch.nn.init.normal_(self.positioner, std=init_std)\n", " torch.nn.init.normal_(self.Wq.weight, std=init_std)\n", " # torch.nn.init.zeros_(self.Wq.bias)\n", " torch.nn.init.normal_(self.Wk.weight, std=init_std)\n", " # torch.nn.init.zeros_(self.Wk.bias)\n", " torch.nn.init.normal_(self.Wv.weight, std=init_std)\n", " # torch.nn.init.zeros_(self.Wv.bias)\n", "\n", " def forward(self, x: torch.Tensor, elaborate: bool = False):\n", " \"\"\"Forward pass\n", " Args:\n", " x (torch.Tensor): input tensor of shape (B, T, d)\n", " \"\"\"\n", " # Embedding\n", " x = self.tokenizer(x)\n", " x = x + self.positioner\n", "\n", " # (Scaled Dot-Product Attention)\n", " Q = self.Wq(x)\n", " K = self.Wk(x)\n", " V = self.Wv(x)\n", " logits = torch.einsum(\"btc,bsc->bts\", Q, K) # query key product\n", " logits *= self.scale # normalize against staturation\n", " scores = torch.softmax(logits, dim=-1) # attention scores\n", " attention = torch.einsum(\"bts,bsc->btc\", scores, V) # scaled dot-product attention\n", " attention = attention.squeeze(-1) # remove last dimension\n", " output = self.linear(attention) # linear layer\n", " output = torch.sigmoid(output) # sigmoid activation\n", " return (logits, scores, attention, output) if elaborate else output\n", "\n", "\n", "class PositionalEncoding(torch.nn.Module):\n", " def __init__(self, T: int, d_model: int):\n", " super().__init__()\n", " Te = T + T%2\n", " de = d_model + d_model%2\n", " position = torch.arange(Te).unsqueeze(1)\n", " div_term = torch.exp(torch.arange(0, de, 2) * (-math.log(10000.0) / de))\n", " pe = torch.zeros(Te, de)\n", " pe[:, 0::2] = torch.sin(position * div_term)\n", " pe[:, 1::2] = torch.cos(position * div_term)\n", " pe = pe[:T, :d_model]\n", " self.register_buffer('pe', pe)\n", "\n", " def forward(self):\n", " return self.pe\n", "\n", "\n", "class SDPA_Solution(torch.nn.Module):\n", " def __init__(self, T: int, dm: int, dk: int):\n", " \"\"\"\n", " Scaled Dot Product Attention\n", " Args:\n", " T (int): context length\n", " dm (int): model dimension\n", " dk (int): key dimension\n", " Note:\n", " we assume dm == dv\n", " \"\"\"\n", " super().__init__()\n", " self.T = T # context length\n", " self.dm = dm # model dimension\n", " self.dk = dk # key dimension\n", " self.scale = 1.0 / math.sqrt(dk)\n", "\n", " # positional Encoding\n", " self.position = PositionalEncoding(T, dm)\n", "\n", " # self-attention layers\n", " self.Wq = 
torch.nn.Linear(dm, dk, bias=False) # query layer\n", " self.Wk = torch.nn.Linear(dm, dk, bias=False) # key layer\n", " self.Wv = torch.nn.Linear(dm, dm, bias=False) # value layer\n", "\n", " def forward(self, x: torch.Tensor):\n", " \"\"\"\n", " Args:\n", " x (torch.Tensor): input tensor of shape (T, d)\n", " \"\"\"\n", " # Positional Encoding\n", " x = x + self.position()\n", "\n", " # (Scaled Dot-Product Attention)\n", " Q = self.Wq(x) # Query\n", " K = self.Wk(x) # Key\n", " V = self.Wv(x) # Value\n", " QK = Q @ K.T # Query Key product\n", " S = QK * self.scale # Scores (scaled against staturation)\n", " S_softmax = torch.softmax(S, dim=-1) # softmax attention scores (row dimensions)\n", " A = S_softmax @ V # scaled dot-product attention\n", " return A\n", "\n", "\n", "def test_sdpa(your_implementation):\n", " # self-attention layers\n", " context_len = 7\n", " model_dim = 5\n", " key_dim = 3\n", " toy_x = torch.rand(context_len, model_dim)\n", " your_model = your_implementation(context_len, model_dim, key_dim)\n", " our_model = SDPA_Solution(context_len, model_dim, key_dim)\n", " your_model.Wq = our_model.Wq\n", " your_model.Wk = our_model.Wk\n", " your_model.Wv = our_model.Wv\n", " your_results = your_model(toy_x.clone())\n", " our_results = our_model(toy_x.clone())\n", " if (your_results == our_results).all():\n", " print(\"Very well Done!\")\n", " else:\n", " print(\"The two implementations should produce the same result!\")\n", "\n", "\n", "def bin_acc(y_hat, y):\n", " \"\"\"\n", " Compute the binary accuracy\n", " \"\"\"\n", " y_ = y_hat.round()\n", " TP_TN = (y_ == y).float().sum().item()\n", " FP_FN = (y_ != y).float().sum().item()\n", " assert TP_TN + FP_FN == y.numel(), f\"{TP_TN + FP_FN} != {y.numel()}\"\n", " return TP_TN / y.numel()\n", "\n", "\n", "def get_n_parameters(model: torch.nn.Module):\n", " \"\"\"\n", " Get the number of learnable parameters in a model\n", " \"\"\"\n", " i = 0\n", " for par in model.parameters():\n", " i += par.numel()\n", " return i\n", "\n", "\n", "def save_model(model):\n", " torch.save(model.state_dict(), 'model_states.pt')\n", "\n", "\n", "def load_model(model):\n", " model_states = torch.load('model_states.pt')\n", " model.load_state_dict(model_states)\n", "\n", "\n", "class Trainer:\n", " def __init__(self, model, optimizer, criterion):\n", " self.model = model\n", " self.optimizer = optimizer\n", " self.criterion = criterion\n", "\n", " def do_eval(self, X_v, y_v, device=\"cpu\"):\n", " self.model.to(device)\n", " self.model.eval()\n", " X_v, y_v = X_v.to(device), y_v.to(device)\n", " with torch.no_grad():\n", " y_hat = self.model(X_v)\n", " loss = self.criterion(y_hat.squeeze(), y_v.squeeze())\n", " acc = bin_acc(y_hat, y_v)\n", " self.model.to(\"cpu\")\n", " return loss.item(), acc\n", "\n", " def do_train(self, n_epochs, X_t, y_t, device=\"cpu\"):\n", " train_loss, train_acc = [], []\n", " self.model.to(device)\n", " self.model.train()\n", " X_t, y_t = X_t.to(device), y_t.to(device)\n", " for i in range(n_epochs):\n", " self.optimizer.zero_grad(set_to_none=True)\n", " y_hat = self.model(X_t)\n", " loss_t = self.criterion(y_hat.squeeze(), y_t.squeeze())\n", " loss_t.backward()\n", " self.optimizer.step()\n", " self.optimizer.zero_grad(set_to_none=True)\n", " if (i + 1) % 10 == 0 or i == 0:\n", " train_loss.append(loss_t.item())\n", " train_acc.append(bin_acc(y_hat, y_t))\n", " self.model.eval()\n", " self.model.to(\"cpu\")\n", " return train_loss, train_acc\n", "\n", "\n", "def make_train(model, data_gen, B_t, B_v, n_epochs, 
device=\"cpu\", kind=\"MLP\", verbose=False, etta=1e-2):\n", " X_train, y_train = data_gen.generate(B_t, verbose=False)\n", " X_valid, y_valid = data_gen.generate(B_v, verbose=False)\n", " optimizer = torch.optim.Adam(model.parameters(), lr=etta)\n", " criterion = torch.nn.BCELoss()\n", " train_eval = Trainer(model, optimizer, criterion)\n", " model_np = get_n_parameters(model) # number of learnable parameters\n", " verbose and print(f\"Number of model's learnable parameters: {model_np}\")\n", " if kind == \"MLP\":\n", " t_loss, t_acc = train_eval.do_train(n_epochs, X_train.float(), y_train, device=device)\n", " v_loss, v_acc = train_eval.do_eval(X_valid.float(), y_valid, device=device)\n", " else:\n", " t_loss, t_acc = train_eval.do_train(n_epochs, X_train, y_train, device=device)\n", " v_loss, v_acc = train_eval.do_eval(X_valid, y_valid, device=device)\n", " verbose and print(f\"Training loss: {t_loss[-1]:.3f}, accuracy: {t_acc[-1]:.3f}\")\n", " verbose and print(f\"Validation loss: {v_loss:.3f}, accuracy: {v_acc:.3f}\")\n", " return t_loss[-1], t_acc[-1], v_loss, v_acc, model_np\n", "\n", "\n", "def scaled_dot_product_attention_solution(Q, K, V):\n", " \"\"\" Scaled dot product attention\n", " Args:\n", " Q: queries (B, H, d, n)\n", " K: keys (B, H, d, n)\n", " V: values (B, H, d, n)\n", " Returns:\n", " Attention tensor (B, H, d, n), Scores (B, H, d, d)\n", " Notes:\n", " (B, H, d, n): batch size, H: number of heads, d: key-query dim, n: embedding dim\n", " \"\"\"\n", "\n", " assert K.shape == Q.shape and K.shape == V.shape, \"Queries, Keys and Values must have the same shape\"\n", " B, H, d, n = K.shape # batch_size, num_heads, key-query dim, embedding dim\n", " scale = math.sqrt(d)\n", " Q_mm_K = torch.einsum(\"bhdn,bhen->bhde\", Q, K) # dot-product reducing the n dimension\n", " S = Q_mm_K / scale # score or scaled dot product\n", " S_sm = torch.softmax(S, dim=-1) # softmax\n", " A = torch.einsum(\"bhde,bhen->bhdn\", S_sm, V) # Attention\n", " return A, S\n", "\n", "\n", "def weighted_attention(self, x):\n", " \"\"\"This function computes the weighted attention as a method for the BinarySAT class\n", "\n", " Args:\n", " x (Tensor): An array of shape (B:batch_size, T:context length) containing the input data\n", "\n", " Returns:\n", " Tensor: weighted attention of shape (B, E, E) where E = T + 1\n", " \"\"\"\n", " assert self.n_heads == 1, \"This function is only implemented for a single head!\"\n", " # Embedding\n", " B = x.size(0) # batch size\n", " x = self.toke(x) # token embedding\n", " x = torch.cat([x, self.cls.expand(B, -1, -1)], dim=1) # concatenate cls token\n", " x = x + self.pose # positional embedding\n", " norm_x = self.norm1(x) # normalization\n", "\n", " # Scaled Dot-Product Attention (partially implemented)\n", " q, k, v = self.qkv(norm_x).view(B, self.E, 3, self.d).unbind(dim=2)\n", " # q, k, v all have shape (B, E, h, d) where h is the number of heads (1 in this case)\n", " W_qk = q @ k.transpose(-2, -1)\n", " W_qk = W_qk * self.scale\n", " W_qk = torch.softmax(W_qk, dim=-1)\n", " return W_qk\n", "\n", "\n", "def results_dict(results_sat_d, results_sat_s, results_mlp_d, results_mlp_s):\n", " from collections import OrderedDict\n", " t_loss_sat_d, t_acc_sat_d, v_loss_sat_d, v_acc_sat_d, model_np_sat_d = results_sat_d\n", " t_loss_sat_s, t_acc_sat_s, v_loss_sat_s, v_acc_sat_s, model_np_sat_s = results_sat_s\n", " t_loss_mlp_d, t_acc_mlp_d, v_loss_mlp_d, v_acc_mlp_d, model_np_mlp_d = results_mlp_d\n", " t_loss_mlp_s, t_acc_mlp_s, v_loss_mlp_s, v_acc_mlp_s, 
model_np_mlp_s = results_mlp_s\n", " output = OrderedDict()\n", " # output[\"Training Loss SAT dense\"] = t_loss_sat_d\n", " # output[\"Training Accuracy SAT dense\"] = t_acc_sat_d\n", " output[\"Validation Loss SAT dense\"] = round(v_loss_sat_d, 3)\n", " output[\"Validation Accuracy SAT dense\"] = v_acc_sat_d\n", " output[\"Number of Parameters SAT dense\"] = model_np_sat_d\n", " # output[\"Training Loss SAT sparse\"] = t_loss_sat_s\n", " # output[\"Training Accuracy SAT sparse\"] = t_acc_sat_s\n", " output[\"Validation Loss SAT sparse\"] = round(v_loss_sat_s, 3)\n", " output[\"Validation Accuracy SAT sparse\"] = v_acc_sat_s\n", " output[\"Number of Parameters SAT sparse\"] = model_np_sat_s\n", " # output[\"Training Loss MLP dense\"] = t_loss_mlp_d\n", " # output[\"Training Accuracy MLP dense\"] = t_acc_mlp_d\n", " output[\"Validation Loss MLP dense\"] = round(v_loss_mlp_d, 3)\n", " output[\"Validation Accuracy MLP dense\"] = v_acc_mlp_d\n", " output[\"Number of Parameters MLP dense\"] = model_np_mlp_d\n", " # output[\"Training Loss MLP sparse\"] = t_loss_mlp_s\n", " # output[\"Training Accuracy MLP sparse\"] = t_acc_mlp_s\n", " output[\"Validation Loss MLP sparse\"] = round(v_loss_mlp_s, 3)\n", " output[\"Validation Accuracy MLP sparse\"] = v_acc_mlp_s\n", " output[\"Number of Parameters MLP sparse\"] = model_np_mlp_s\n", " return output" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Set random seed\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "executionInfo": { "elapsed": 6, "status": "ok", "timestamp": 1718046717502, "user": { "displayName": "Xaq Pitkow", "userId": "09050806329892245378" }, "user_tz": -180 }, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Set random seed\n", "\n", "def set_seed(seed=None, seed_torch=True):\n", " \"\"\"\n", " Handles variability by controlling sources of randomness\n", " through set seed values\n", "\n", " Args:\n", " seed: Integer\n", " Set the seed value to given integer.\n", " If no seed, set seed value to random integer in the range 2^32\n", " seed_torch: Bool\n", " Seeds the random number generator for all devices to\n", " offer some guarantees on reproducibility\n", "\n", " Returns:\n", " Nothing\n", " \"\"\"\n", " if seed is None:\n", " seed = np.random.choice(2 ** 32)\n", " random.seed(seed)\n", " np.random.seed(seed)\n", " if seed_torch:\n", " torch.manual_seed(seed)\n", " torch.cuda.manual_seed_all(seed)\n", " torch.cuda.manual_seed(seed)\n", " torch.backends.cudnn.benchmark = False\n", " torch.backends.cudnn.deterministic = True" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Set device (GPU or CPU)\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Set device (GPU or CPU)\n", "\n", "def set_device():\n", " device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", " if device != \"cuda\":\n", " print(\"GPU is not enabled in this notebook. \\n\"\n", " \"If you want to enable it, in the menu under `Runtime` -> \\n\"\n", " \"`Hardware accelerator.` and select `GPU` from the dropdown menu\")\n", " else:\n", " print(\"GPU is enabled in this notebook. 
\\n\"\n", " \"If you want to disable it, in the menu under `Runtime` -> \\n\"\n", " \"`Hardware accelerator.` and select `None` from the dropdown menu\")\n", "\n", " return device\n", "\n", "DEVICE = set_device()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Video 1: Introduction to attention\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "remove-input" ] }, "outputs": [], "source": [ "# @title Video 1: Introduction to attention\n", "\n", "from ipywidgets import widgets\n", "from IPython.display import YouTubeVideo\n", "from IPython.display import IFrame\n", "from IPython.display import display\n", "\n", "\n", "class PlayVideo(IFrame):\n", " def __init__(self, id, source, page=1, width=400, height=300, **kwargs):\n", " self.id = id\n", " if source == 'Bilibili':\n", " src = f'https://player.bilibili.com/player.html?bvid={id}&page={page}'\n", " elif source == 'Osf':\n", " src = f'https://mfr.ca-1.osf.io/render?url=https://osf.io/download/{id}/?direct%26mode=render'\n", " super(PlayVideo, self).__init__(src, width, height, **kwargs)\n", "\n", "\n", "def display_videos(video_ids, W=400, H=300, fs=1):\n", " tab_contents = []\n", " for i, video_id in enumerate(video_ids):\n", " out = widgets.Output()\n", " with out:\n", " if video_ids[i][0] == 'Youtube':\n", " video = YouTubeVideo(id=video_ids[i][1], width=W,\n", " height=H, fs=fs, rel=0)\n", " print(f'Video available at https://youtube.com/watch?v={video.id}')\n", " else:\n", " video = PlayVideo(id=video_ids[i][1], source=video_ids[i][0], width=W,\n", " height=H, fs=fs, autoplay=False)\n", " if video_ids[i][0] == 'Bilibili':\n", " print(f'Video available at https://www.bilibili.com/video/{video.id}')\n", " elif video_ids[i][0] == 'Osf':\n", " print(f'Video available at https://osf.io/{video.id}')\n", " display(video)\n", " tab_contents.append(out)\n", " return tab_contents\n", "\n", "video_ids = [('Youtube', 'Ryr04zXhdIw'), ('Bilibili', 'BV1N4421D7NM')]\n", "tab_contents = display_videos(video_ids, W=730, H=410)\n", "tabs = widgets.Tab()\n", "tabs.children = tab_contents\n", "for i in range(len(tab_contents)):\n", " tabs.set_title(i, video_ids[i][0])\n", "display(tabs)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Submit your feedback\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Submit your feedback\n", "content_review(f\"{feedback_prefix}_introduction_to_attention\")" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "---\n", "# Section 1: Intro to multiplication\n", "\n", "In this section, we will show how signals can be used to gate inputs selectively. 
This is the essence of attention.\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Video 2: Cross-attention\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "remove-input" ] }, "outputs": [], "source": [ "# @title Video 2: Cross-attention\n", "\n", "from ipywidgets import widgets\n", "from IPython.display import YouTubeVideo\n", "from IPython.display import IFrame\n", "from IPython.display import display\n", "\n", "\n", "class PlayVideo(IFrame):\n", " def __init__(self, id, source, page=1, width=400, height=300, **kwargs):\n", " self.id = id\n", " if source == 'Bilibili':\n", " src = f'https://player.bilibili.com/player.html?bvid={id}&page={page}'\n", " elif source == 'Osf':\n", " src = f'https://mfr.ca-1.osf.io/render?url=https://osf.io/download/{id}/?direct%26mode=render'\n", " super(PlayVideo, self).__init__(src, width, height, **kwargs)\n", "\n", "\n", "def display_videos(video_ids, W=400, H=300, fs=1):\n", " tab_contents = []\n", " for i, video_id in enumerate(video_ids):\n", " out = widgets.Output()\n", " with out:\n", " if video_ids[i][0] == 'Youtube':\n", " video = YouTubeVideo(id=video_ids[i][1], width=W,\n", " height=H, fs=fs, rel=0)\n", " print(f'Video available at https://youtube.com/watch?v={video.id}')\n", " else:\n", " video = PlayVideo(id=video_ids[i][1], source=video_ids[i][0], width=W,\n", " height=H, fs=fs, autoplay=False)\n", " if video_ids[i][0] == 'Bilibili':\n", " print(f'Video available at https://www.bilibili.com/video/{video.id}')\n", " elif video_ids[i][0] == 'Osf':\n", " print(f'Video available at https://osf.io/{video.id}')\n", " display(video)\n", " tab_contents.append(out)\n", " return tab_contents\n", "\n", "video_ids = [('Youtube', 'Jkt4w98sP1U'), ('Bilibili', 'BV1Rw4m1e7RZ')]\n", "tab_contents = display_videos(video_ids, W=730, H=410)\n", "tabs = widgets.Tab()\n", "tabs.children = tab_contents\n", "for i in range(len(tab_contents)):\n", " tabs.set_title(i, video_ids[i][0])\n", "display(tabs)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Submit your feedback\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Submit your feedback\n", "content_review(f\"{feedback_prefix}_cross_attention\")" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "## Exercise 1: Dot product attention\n", "\n", "We'll implement simple dot product attention for input $\\mathbf{x}$ and scalar output $y$ given by a weighted combination of the inputs,\n", "\n", "$$y = \\mathbf{w}\\cdot\\mathbf{x}$$\n", "\n", "for weights $\\mathbf{w}$ (also called gains). Unlike the fixed weights in an MLP, attention can adjust the weights depending on some other signal. Here we will use a two dimensional attentional gain, $z_1$ and $z_2$, each with a corresponding vector $\\mathbf{q}$ that determines the spatial pattern to be attended:\n", "\n", "$$\\mathbf{w}(z_1,z_2) = \\text{softmax}(z_1 \\mathbf{q}_1 + z_2 \\mathbf{q}_2)$$\n", "\n", "where $\\text{softmax}(\\mathbf{x})={e^\\mathbf{x}}/{\\sum_i e^{x_i}}$ is a combination of a sparsifying element-wise nonlinearity ($\\exp$) and a normalization.\n", "\n", "You should code up this function to return the attentional weights $\\mathbf{w}$ and the weighted output $y$.\n", "\n", "
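\n", "As a quick sanity check of the softmax definition above, you can evaluate it on a tiny, made-up vector in a code cell; note how scaling the input by a larger gain makes the normalized weights sparser:\n", "\n", "```python\n", "import torch\n", "\n", "v = torch.tensor([0., 1., 2.])\n", "print(torch.softmax(v, dim=0))      # tensor([0.0900, 0.2447, 0.6652]) -- sums to 1\n", "print(torch.softmax(5 * v, dim=0))  # a larger gain gives sharper, sparser weights\n", "```\n", "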
\n", "\n", "
Figure after Edelman et al. 2022.
\n", "
" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "execution": {}, "executionInfo": { "elapsed": 4, "status": "ok", "timestamp": 1718046717502, "user": { "displayName": "Xaq Pitkow", "userId": "09050806329892245378" }, "user_tz": -180 } }, "source": [ "```python\n", "def gained_dot_product_attention(x: torch.Tensor, # input vector\n", " q_1: torch.Tensor, # query vector 1\n", " q_2: torch.Tensor, # query vector 2\n", " z_1: float, # query gain 1\n", " z_2: float, # query gain 2\n", " ):\n", " \"\"\"This function computes the gained dot product attention\n", " Args:\n", " x (Tensor): input vector\n", " q_1 (Tensor): query vector 1\n", " q_2 (Tensor): query vector 2\n", " z_1 (float): query gain 1\n", " z_2 (float): query gain 2\n", " Returns:\n", " w (Tensor): attention weights\n", " y (float): gained dot product attention\n", " \"\"\"\n", " #################################################\n", " ## TODO Implement the `gained_dot_product_attention`\n", " # Remove the following line of code once you have completed the exercise:\n", " raise NotImplementedError(\"Student exercise: complete calculation of gained dot product attention.\")\n", " #################################################\n", " w = ...\n", " y = ...\n", " return w, y\n", "\n", "```" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "execution": {}, "executionInfo": { "elapsed": 4, "status": "ok", "timestamp": 1718046717502, "user": { "displayName": "Xaq Pitkow", "userId": "09050806329892245378" }, "user_tz": -180 } }, "outputs": [], "source": [ "# to_remove solution\n", "def gained_dot_product_attention(x: torch.Tensor, # input vector\n", " q_1: torch.Tensor, # query vector 1\n", " q_2: torch.Tensor, # query vector 2\n", " z_1: float, # query gain 1\n", " z_2: float, # query gain 2\n", " ):\n", " \"\"\"This function computes the gained dot product attention\n", " Args:\n", " x (Tensor): input vector\n", " q_1 (Tensor): query vector 1\n", " q_2 (Tensor): query vector 2\n", " z_1 (float): query gain 1\n", " z_2 (float): query gain 2\n", " Returns:\n", " w (Tensor): attention weights\n", " y (float): gained dot product attention\n", " \"\"\"\n", " w = torch.softmax(z_1 * q_1 + z_2 * q_2, dim=0)\n", " y = torch.dot(w, x)\n", " return w, y" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "Now we plot the weights as a function of $\\mathbf{z}$. Manipulate the sliders to control the two attention gains $z_k$, which we've pre-assigned to two specific sigmoidal pattern vectors $\\mathbf{q}_k$. (Feel free to change those functions and see how the sliders change what they attend.) Observe how these sliders affect which part of the input is amplified or attenuated. The x-axis represents the input variable to the functions that define $\\mathbf{q}_1$ and $\\mathbf{q}_2$ (observe that these are nonlinear functions of this input)."
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Execute this cell to enable the widget\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "executionInfo": { "elapsed": 3181, "status": "ok", "timestamp": 1718046720680, "user": { "displayName": "Xaq Pitkow", "userId": "09050806329892245378" }, "user_tz": -180 }, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Execute this cell to enable the widget\n", "# style = {'description_width': 'initial'}\n", "gain_1_widget = widgets.FloatSlider(1.0, description='gain_1', min=0.0, max=10.0, step=0.1, continuous_update=False, layout=Layout(width='50%'))\n", "gain_2_widget = widgets.FloatSlider(1.0, description='gain_2', min=0.0, max=10.0, step=0.1, continuous_update=False, layout=Layout(width='50%'))\n", "interactive_plot = interactive(plot_weights_and_output, gain_1=gain_1_widget, gain_2=gain_2_widget)\n", "interactive_plot" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "This plot illustrates the impact of the gain factor on the attention weights for different input dimensions. Observe that:\n", "\n", " - the exponential and normalization from the softmax create sparse weights that can select relevant features.\n", " \n", " - For small $z$, the attention weights are evenly distributed across the input dimensions.\n", "\n", " - As $z$ increases, the attention weights become sparser.\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Submit your feedback\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Submit your feedback\n", "content_review(f\"{feedback_prefix}_dot_product_attention\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Video 3: Exercise conclusions\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "remove-input" ] }, "outputs": [], "source": [ "# @title Video 3: Exercise conclusions\n", "\n", "from ipywidgets import widgets\n", "from IPython.display import YouTubeVideo\n", "from IPython.display import IFrame\n", "from IPython.display import display\n", "\n", "\n", "class PlayVideo(IFrame):\n", " def __init__(self, id, source, page=1, width=400, height=300, **kwargs):\n", " self.id = id\n", " if source == 'Bilibili':\n", " src = f'https://player.bilibili.com/player.html?bvid={id}&page={page}'\n", " elif source == 'Osf':\n", " src = f'https://mfr.ca-1.osf.io/render?url=https://osf.io/download/{id}/?direct%26mode=render'\n", " super(PlayVideo, self).__init__(src, width, height, **kwargs)\n", "\n", "\n", "def display_videos(video_ids, W=400, H=300, fs=1):\n", " tab_contents = []\n", " for i, video_id in enumerate(video_ids):\n", " out = widgets.Output()\n", " with out:\n", " if video_ids[i][0] == 'Youtube':\n", " video = YouTubeVideo(id=video_ids[i][1], width=W,\n", " height=H, fs=fs, rel=0)\n", " print(f'Video available at https://youtube.com/watch?v={video.id}')\n", " else:\n", " video = PlayVideo(id=video_ids[i][1], source=video_ids[i][0], width=W,\n", " height=H, fs=fs, autoplay=False)\n", " if video_ids[i][0] == 'Bilibili':\n", " print(f'Video available at https://www.bilibili.com/video/{video.id}')\n", " elif video_ids[i][0] == 'Osf':\n", " print(f'Video available at https://osf.io/{video.id}')\n", " display(video)\n", " tab_contents.append(out)\n", " return tab_contents\n", "\n", "video_ids = 
[('Youtube', 'YMK9niQNC1Q'), ('Bilibili', 'BV1Li421v7sh')]\n", "tab_contents = display_videos(video_ids, W=730, H=410)\n", "tabs = widgets.Tab()\n", "tabs.children = tab_contents\n", "for i in range(len(tab_contents)):\n", " tabs.set_title(i, video_ids[i][0])\n", "display(tabs)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Submit your feedback\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Submit your feedback\n", "content_review(f\"{feedback_prefix}_first_exercise_conclusions\")" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "---\n", "# Section 2: Self-attention\n", "\n", "*Estimated timing to here from start of tutorial: 20 minutes.*\n", "\n", "Where do these gain factors $\\mathbf{z}$ come from?\n", "\n", "In dot product attention, the weights were $\\mathbf{w}=\\text{softmax}(\\mathbf{z}\\cdot Q)$, a function of an *external* $\\mathbf{z}$ and fixed matrix $Q$.\n", "\n", "In the *self-attention* mechanism used in Transformers, we again see a multiplicative weighting on inputs, but now with weights that are functions of the input $\\mathbf{x}$. Transformers also add a few other components we will describe, too, but the multiplicative weighting is the essential operation we are highlighting in this tutorial.\n", "\n", "The rest of this Section goes through the math you need to understand the self-attention mechanism in Transformers and asks you to implement it in code. If you're short on time or already very familiar with transformers, you may want to jump to Section 3. Even though we had previously used transformers and thought we understood them, there were some important structures that we didn't appreciate until we created this tutorial! So you may still benefit from following along." 
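, "\n", "\n", "Before diving into the details, here is a tiny, illustrative sketch (with made-up sizes and random tensors; it is not part of the tutorial's exercises) contrasting the two ideas -- the same softmax-of-dot-products machinery, driven either by an external gain vector $\\mathbf{z}$ or by the input $\\mathbf{x}$ itself:\n", "\n", "```python\n", "import torch\n", "\n", "T = 6\n", "x = torch.rand(T)             # input (one scalar feature per position)\n", "Q = torch.rand(T, T)          # a fixed pattern matrix\n", "\n", "# dot product attention: weights driven by an *external* signal z\n", "z = torch.rand(T)\n", "w_external = torch.softmax(z @ Q, dim=0)\n", "\n", "# self-attention flavor: the same machinery, driven by the input itself\n", "w_self = torch.softmax(x @ Q, dim=0)\n", "\n", "y_external = w_external @ x\n", "y_self = w_self @ x\n", "```"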
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Video 4: Self-attention\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "remove-input" ] }, "outputs": [], "source": [ "# @title Video 4: Self-attention\n", "\n", "from ipywidgets import widgets\n", "from IPython.display import YouTubeVideo\n", "from IPython.display import IFrame\n", "from IPython.display import display\n", "\n", "\n", "class PlayVideo(IFrame):\n", " def __init__(self, id, source, page=1, width=400, height=300, **kwargs):\n", " self.id = id\n", " if source == 'Bilibili':\n", " src = f'https://player.bilibili.com/player.html?bvid={id}&page={page}'\n", " elif source == 'Osf':\n", " src = f'https://mfr.ca-1.osf.io/render?url=https://osf.io/download/{id}/?direct%26mode=render'\n", " super(PlayVideo, self).__init__(src, width, height, **kwargs)\n", "\n", "\n", "def display_videos(video_ids, W=400, H=300, fs=1):\n", " tab_contents = []\n", " for i, video_id in enumerate(video_ids):\n", " out = widgets.Output()\n", " with out:\n", " if video_ids[i][0] == 'Youtube':\n", " video = YouTubeVideo(id=video_ids[i][1], width=W,\n", " height=H, fs=fs, rel=0)\n", " print(f'Video available at https://youtube.com/watch?v={video.id}')\n", " else:\n", " video = PlayVideo(id=video_ids[i][1], source=video_ids[i][0], width=W,\n", " height=H, fs=fs, autoplay=False)\n", " if video_ids[i][0] == 'Bilibili':\n", " print(f'Video available at https://www.bilibili.com/video/{video.id}')\n", " elif video_ids[i][0] == 'Osf':\n", " print(f'Video available at https://osf.io/{video.id}')\n", " display(video)\n", " tab_contents.append(out)\n", " return tab_contents\n", "\n", "video_ids = [('Youtube', 'V9OmjZhoed4'), ('Bilibili', 'BV1oS411w7Ls')]\n", "tab_contents = display_videos(video_ids, W=730, H=410)\n", "tabs = widgets.Tab()\n", "tabs.children = tab_contents\n", "for i in range(len(tab_contents)):\n", " tabs.set_title(i, video_ids[i][0])\n", "display(tabs)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Submit your feedback\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Submit your feedback\n", "content_review(f\"{feedback_prefix}_self_attention\")" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "## Section 2.1: Tokens\n", "\n", "*Hypothetically*, a very simple version of input-dependent weights would simply set $\\mathbf{z}=\\mathbf{x}$, so\n", "\n", "$$\\mathbf{w}=\\text{softmax}(\\mathbf{x}\\cdot Q)$$\n", "\n", "Practically, it's important to recognize that in Transformers, the input $\\mathbf{x}$ is 'tokenized': the input is divided into separate elements, which could be words or image patches, and each element is assigned its own vector $\\mathbf{x}_i$ through some embedding which could be an arbitrary function (like in the mushroom body of the fly olfactory system, or like a hash function), or could be learned. The dot products in $\\mathbf{x}\\cdot Q$ could apply to the concatenated vector of all tokens or separately to each token. Transformers take that second approach. The details here mean that it can be tricky to specify the sizes and dimensions of the different matrices, so be careful.\n", "\n", "Ultimately, most of the novel machinery of attention in transformers ends up modifying tokens according to their context, leaving the subsequent remixing of tokens to other, more conventional layers. 
That specific way of applying attention is probably not central to the value of attention." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Submit your feedback\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Submit your feedback\n", "content_review(f\"{feedback_prefix}_tokens\")" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "## Section 2.2: More multiplications\n", "\n", "To provide more flexibility than that very simple version above, we allow $Q$ *also* to depend on the input: $Q=W_q\\mathbf{x}$. For even more flexibility, we select only specific aspects of the input tokens to serve as gain modulation,\n", "\n", "$$\\mathbf{z}=W_k\\mathbf{x}$$\n", "\n", "so\n", "\n", "$$\\mathbf{w}=\\text{softmax}(W_k \\mathbf{x}\\,\\mathbf{x}^\\top W_q)$$\n", "\n", "Note that each attention weight is applied to a different token, with $\\mathbf{x}\\mathbf{x}^\\top$ essentially computing representational similarity between tokens.\n", "\n", "The weights themselves are applied to linearly transformed tokens, $V=W_v\\mathbf{x}$, with the final attention output being a weighted sum over different tokens, which is added back to each original token. The result is that every original token is modified by the attended context. Finally, these modified tokens are processed by later layers.\n", "\n", "Let's take an interpretable example. You might have a sentence using the word \"cookies\" and its own embedding. But what *kind* of cookies? It means two very different things in the sentences \"I ate two cookies\" or \"This website uses cookies.\" The self-attention mechanism will use the sentence context to change the representation of the token \"cookies.\"\n", "\n", "This is commonly described as a query-key-value framework. These various learnable projections of the input are given names:\n", "\n", "$$\n", "Q=W_q\\mathbf{x}\\ \\ \\ \\ \\text{query}\\\\\n", "K=W_k\\mathbf{x}\\ \\ \\ \\ \\ \\ \\text{key}\\\\\n", "V=W_v\\mathbf{x}\\ \\ \\ \\ \\text{value}\n", "$$\n", "\n", "The self-attention mechanism is then usually written as:\n", "\n", "$$\\mathbf{y}=\\text{softmax}\\,(\\frac{QK^\\top}{\\sqrt{d}})V$$\n", "\n", "We get a big weight on inputs with a strong match between Query and Key, and that weight is applied to a particular set of Values. 
(The argument of the softmax is also scaled by $1/\\sqrt{d}$, where $d$ is the dimension of the attended features, to keep the scores from saturating.)\n", "\n", "Putting all this together, the self-attention output is then:\n", "\n", "$$\n", "\\mathbf{y}=\\text{softmax}(\\tfrac{1}{\\sqrt{d}}W_q\\mathbf{x}\\mathbf{x}^\\top W_k^\\top)W_v\\mathbf{x}\n", "$$\n", "\n", "This equation shows that **the attention mechanism in transformers is a composition of multiplication, sparsification, normalization, and another multiplication** --- all special microarchitecture operations from this tutorial!\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Submit your feedback\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Submit your feedback\n", "content_review(f\"{feedback_prefix}_more_multiplications\")" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "## Bonus Section 2.3: Dimensions of Self-Attention\n", "\n", "This section's guidance should help you understand and implement self-attention in Bonus Exercise 2.\n", "\n", "Self-attention in transformers can seem intimidating because there are so many elements (i.e., tensors and matrices) to consider. Here, we try to explain the dimensions and provide you with a smoother entry into the architecture. The following figure should help you navigate this text. Note that here, we remove the batch dimension (and multiple \"heads\") for simplicity.\n", "\n", "Starting from the left, the input $\\mathbf{x}'$ is a sequence of length $T$ (also called context length). Each row of $\\mathbf{x}'$ is called a token, and it has dimension $d_m$, also known as the model dimension or embedding dimension.\n", "\n", "Tokens are largely treated separately in Transformers, so crucial information about their spatial or sequential relationships is missing. To fix this, Transformers change the raw token embedding by adding a position-dependent term $\\mathbf{p}$ to the tokens: $\\mathbf{x} = \\mathbf{x}' + \\mathbf{p}$. The positional matrix can be a hard-coded matrix (called \"Positional *Encoding*\") or a trainable (learned) matrix (called \"Positional *Embedding*\").\n", "\n", "Given these tokens, the transformer architecture selects some dimensions for further processing. It computes Keys ($\\mathbf{K}$), Queries ($\\mathbf{Q}$), and Values ($\\mathbf{V}$) by linear projections of the tokens:\n", "\n", "$$\\mathbf{K} = \\mathbf{W}_k \\mathbf{x}$$\n", "$$\\mathbf{Q} = \\mathbf{W}_q \\mathbf{x}$$\n", "$$\\mathbf{V} = \\mathbf{W}_v \\mathbf{x}$$\n", "\n", "where $\\mathbf{W}_k$, $\\mathbf{W}_q$, and $\\mathbf{W}_v$ are learnable weights. As you can see from the figure, the dimensions of these weights do not depend on $T$: these transformations are applied to each token separately.\n", "\n", "We then calculate which Keys are similar to which Queries by the dot-product $\\mathbf{QK^\\top}$ (and scale by $\\frac{1}{\\sqrt{d_k}}$).\n", "\n", "It can be helpful to consider just one row of this similarity matrix $\\mathbf{QK^\\top}$, i.e., the similarities of one token's query (token $i$) to all the keys. We take a softmax operation over this row by applying an element-wise exponential, which sparsifies the entries and then normalizes them to give a probability distribution. This normalized vector is called the attentional weights $\\boldsymbol{\\pi}$, and it determines which other tokens token $i$ attends to. 
This is the self-attention version of the gains we used in the first exercise.\n", "\n", "How are these weights used? The Value $\\mathbf{V}$ selects certain dimensions for each token, and these dimensions are then combined by the weighted sum $\\Delta\\mathbf{x}=\\sum_j \\pi_j\\mathbf{V}_j$ and, typically, added to the original token representation, $\\mathbf{x}^*=\\mathbf{x}+\\Delta\\mathbf{x}$. This leads to a new version of the original token, refined by the other tokens that attention deems relevant.\n", "\n", "This refinement process is repeated for all tokens, and these new tokens can then be processed by other layers in the neural network.\n", "\n", "*(Figure omitted here: the self-attention tensors and their dimensions, as described in this section.)*\n", "\n", "
" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Submit your feedback\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Submit your feedback\n", "content_review(f\"{feedback_prefix}_dimensions_of_self_attention\")" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "## Bonus Exercise 2: Implement self-attention\n", "\n", "In this exercise, you will implement the Attention operation from the transformer architecture and test it against pre-written code.\n", "\n", "(In this exercise, we use Positional Encoding, but in future models and exercises, we will use Positional Embedding. The difference is that an Encoding is fixed, like a hash code, whereas an Embedding is learned.)" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "execution": {}, "executionInfo": { "elapsed": 9, "status": "ok", "timestamp": 1718046720680, "user": { "displayName": "Xaq Pitkow", "userId": "09050806329892245378" }, "user_tz": -180 } }, "source": [ "```python\n", "class ScaledDotProductAttention(torch.nn.Module):\n", " def __init__(self, T: int, dm: int, dk: int):\n", " \"\"\"\n", " Scaled Dot Product Attention\n", " Args:\n", " T (int): context length\n", " dm (int): model dimension\n", " dk (int): key dimension\n", " Note:\n", " we assume dm == dv\n", " \"\"\"\n", " super().__init__()\n", " self.T = T # context length\n", " self.dm = dm # model dimension\n", " self.dk = dk # key dimension\n", " self.scale = 1.0 / math.sqrt(dk)\n", "\n", " # positional Encoding\n", " self.position = PositionalEncoding(T, dm)\n", "\n", " # self-attention layers\n", " self.Wq = torch.nn.Linear(dm, dk, bias=False) # query layer\n", " self.Wk = torch.nn.Linear(dm, dk, bias=False) # key layer\n", " self.Wv = torch.nn.Linear(dm, dm, bias=False) # value layer\n", "\n", " def forward(self, x: torch.Tensor):\n", " \"\"\"\n", " Args:\n", " x (torch.Tensor): input tensor of shape (T, d)\n", " \"\"\"\n", " # Positional Encoding\n", " x = x + self.position()\n", "\n", " #################################################\n", " # # TODO Scaled dot product attention\n", " # # Remove or comment out the lines you don't need\n", " # Fill remove the following line of code once you have completed the exercise:\n", " raise NotImplementedError(\"Student exercise: implement full version of scaled dot product attention.\")\n", " #################################################\n", " # (Scaled Dot-Product Attention)\n", " Q = self.Wq(x) # Query\n", " K = ... # Key\n", " V = ... # Value\n", " QK = ... # Query Key product\n", " S = ... # Scores (scaled against saturation)\n", " S_softmax = ... # softmax attention scores (row dimensions)\n", " A = ... 
# scaled dot-product attention\n", " return A\n", "\n", "```" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "execution": {}, "executionInfo": { "elapsed": 8, "status": "ok", "timestamp": 1718046720680, "user": { "displayName": "Xaq Pitkow", "userId": "09050806329892245378" }, "user_tz": -180 } }, "outputs": [], "source": [ "# to_remove solution\n", "class ScaledDotProductAttention(torch.nn.Module):\n", " def __init__(self, T: int, dm: int, dk: int):\n", " \"\"\"\n", " Scaled Dot Product Attention\n", " Args:\n", " T (int): context length\n", " dm (int): model dimension\n", " dk (int): key dimension\n", " Note:\n", " we assume dm == dv\n", " \"\"\"\n", " super().__init__()\n", " self.T = T # context length\n", " self.dm = dm # model dimension\n", " self.dk = dk # key dimension\n", " self.scale = 1.0 / math.sqrt(dk)\n", "\n", " # positional Encoding\n", " self.position = PositionalEncoding(T, dm)\n", "\n", " # self-attention layers\n", " self.Wq = torch.nn.Linear(dm, dk, bias=False) # query layer\n", " self.Wk = torch.nn.Linear(dm, dk, bias=False) # key layer\n", " self.Wv = torch.nn.Linear(dm, dm, bias=False) # value layer\n", "\n", " def forward(self, x: torch.Tensor):\n", " \"\"\"\n", " Args:\n", " x (torch.Tensor): input tensor of shape (T, d)\n", " \"\"\"\n", " # Positional Encoding\n", " x = x + self.position()\n", "\n", " # (Scaled Dot-Product Attention)\n", " Q = self.Wq(x) # Query\n", " K = self.Wk(x) # Key\n", " V = self.Wv(x) # Value\n", " QK = Q @ K.T # Query Key product\n", " S = QK * self.scale # Scores (scaled against saturation)\n", " S_softmax = torch.softmax(S, dim=-1) # softmax attention scores (row dimensions)\n", " A = S_softmax @ V # scaled dot-product attention\n", " return A" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "Here is a function to test whether your function matches the correct output for self-attention." 
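, "\n", "Optionally, once your implementation runs, you can also sanity-check it against a manual reference computation that reuses the module's own weights (a minimal sketch; the toy sizes and tolerance are illustrative, and it assumes you have completed the exercise and run the helper cells that define `PositionalEncoding`):\n", "\n", "```python\n", "T_toy, dm_toy, dk_toy = 5, 8, 4\n", "sdpa = ScaledDotProductAttention(T_toy, dm_toy, dk_toy)\n", "x_toy = torch.randn(T_toy, dm_toy)\n", "\n", "with torch.no_grad():\n", "    xp = x_toy + sdpa.position()  # same positional encoding the module applies\n", "    # softmax(Q K^T / sqrt(d_k)) V, computed directly from the module's layers\n", "    ref = torch.softmax(sdpa.Wq(xp) @ sdpa.Wk(xp).T / math.sqrt(dk_toy), dim=-1) @ sdpa.Wv(xp)\n", "    print(torch.allclose(sdpa(x_toy), ref, atol=1e-5))  # should print True\n", "```"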
] }, { "cell_type": "code", "execution_count": null, "metadata": { "execution": {}, "executionInfo": { "elapsed": 7, "status": "ok", "timestamp": 1718046720680, "user": { "displayName": "Xaq Pitkow", "userId": "09050806329892245378" }, "user_tz": -180 } }, "outputs": [], "source": [ "# Testing the function scaled_dot_product\n", "test_sdpa(ScaledDotProductAttention)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Submit your feedback\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Submit your feedback\n", "content_review(f\"{feedback_prefix}_self_attention_implementation\")" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "---\n", "# Section 3: Inductive bias of self-attention: Sparse variable creation\n", "\n", "*Estimated timing to here from start of tutorial: 40 minutes.*" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Video 5: Sparse variable\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "remove-input" ] }, "outputs": [], "source": [ "# @title Video 5: Sparse variable\n", "\n", "from ipywidgets import widgets\n", "from IPython.display import YouTubeVideo\n", "from IPython.display import IFrame\n", "from IPython.display import display\n", "\n", "\n", "class PlayVideo(IFrame):\n", " def __init__(self, id, source, page=1, width=400, height=300, **kwargs):\n", " self.id = id\n", " if source == 'Bilibili':\n", " src = f'https://player.bilibili.com/player.html?bvid={id}&page={page}'\n", " elif source == 'Osf':\n", " src = f'https://mfr.ca-1.osf.io/render?url=https://osf.io/download/{id}/?direct%26mode=render'\n", " super(PlayVideo, self).__init__(src, width, height, **kwargs)\n", "\n", "\n", "def display_videos(video_ids, W=400, H=300, fs=1):\n", " tab_contents = []\n", " for i, video_id in enumerate(video_ids):\n", " out = widgets.Output()\n", " with out:\n", " if video_ids[i][0] == 'Youtube':\n", " video = YouTubeVideo(id=video_ids[i][1], width=W,\n", " height=H, fs=fs, rel=0)\n", " print(f'Video available at https://youtube.com/watch?v={video.id}')\n", " else:\n", " video = PlayVideo(id=video_ids[i][1], source=video_ids[i][0], width=W,\n", " height=H, fs=fs, autoplay=False)\n", " if video_ids[i][0] == 'Bilibili':\n", " print(f'Video available at https://www.bilibili.com/video/{video.id}')\n", " elif video_ids[i][0] == 'Osf':\n", " print(f'Video available at https://osf.io/{video.id}')\n", " display(video)\n", " tab_contents.append(out)\n", " return tab_contents\n", "\n", "video_ids = [('Youtube', 'aefbaKUN_cg'), ('Bilibili', 'BV1WE421N7g5')]\n", "tab_contents = display_videos(video_ids, W=730, H=410)\n", "tabs = widgets.Tab()\n", "tabs.children = tab_contents\n", "for i in range(len(tab_contents)):\n", " tabs.set_title(i, video_ids[i][0])\n", "display(tabs)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Submit your feedback\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Submit your feedback\n", "content_review(f\"{feedback_prefix}_sparse_variable\")" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "What is the self-attention mechanism especially good at learning?\n", "\n", "In the next exercise, we will follow the results from the paper [Inductive Biases and 
Variable Creation in Self-Attention Mechanisms](https://arxiv.org/abs/2110.10090). They show that a simple transformer can successfully learn to represent a sparse function with low sample complexity.\n", "\n", "The sparse functions we will use in this tutorial are **s-sparse AND** functions: for a sequence of $T$ binary inputs (the 'context length') and *s* pre-selected unique indices, the sequence is labeled *True* if the input value at *all* of the chosen indices is $1$, and *False* otherwise. This means that the sequence label only depends on *s* elements out of the whole input.\n", "\n", "As an intuitive example, imagine an attribute vector that describes a \"thing\" on earth. Such vectors can be very, very long. If we are interested in classifying a \"thing\" as living or not based on a given vector, only 6 binary features could be enough: can grow, reproduce, metabolize, excrete, respire, and react to changes in the environment. It must do all of these things to be alive. The function that takes the vector and returns whether the \"thing\" is alive or not is a 6-Sparse AND function.\n", "\n", "
\n", "\n", "
\n", "\n", "Here, we will compare two architectures, Multi-Layer Perceptron (MLP) *vs* Self-Attention, in learning this sparse boolean function $f$. We will test how well they perform with few training samples and good generalization error.\n" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "## Dataset\n", "\n", "For this tutorial, we have already defined a s-sparse AND dataset generator class `s_Sparse_AND` for you. It generates $m$ sequences with context length $T$ and sparsity $s$. (Note that Each input element is more likely to be 1 than 0 so that the output labels have equal probability.) Here, we visualize a few input samples and their corresponding labels. The red rectangles show the relevant indices for this dataset." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "execution": {}, "executionInfo": { "elapsed": 569, "status": "ok", "timestamp": 1718046721243, "user": { "displayName": "Xaq Pitkow", "userId": "09050806329892245378" }, "user_tz": -180 } }, "outputs": [], "source": [ "context_length = 13 # T: context length\n", "s_sparse = 2 # s: sparsity (number of function-relevant indices)\n", "n_sequences = 10 # m: number of samples (sequences)\n", "data_gen = s_Sparse_AND(context_length, s_sparse)\n", "X_, y_ = data_gen.generate(n_sequences, verbose=False)\n", "correct_ids = data_gen.f_i\n", "print(f\"Target (function-relevant indices) indices: {correct_ids}\")\n", "\n", "plot_samples(X_, y_, correct_ids, f\"{s_sparse}_Sparse_AND samples\")" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "## Exercise 3: MLP vs. Self-Attention\n", "\n", "Let's see which of the two architectures generalizes better in this problem. We will test both on sparse functions (*s* = 3) and denser functions (*s* = 15) for the context length of *T* = 30. We use a helper-function `make_train` that takes the model and hyper-parameters and returns the trained model and some results.\n", "\n", "This exercise has 5 parts:\n", "\n", "1. Create training and validation datasets.\n", "2. Train an MLP and self-attention model on task.\n", "3. Plot to compare the results for the two models and two datasets.\n", "4. Change the sample complexity of the MLP dataset, to get to 100% accuracy.\n", "5. Plot the attention score (attention weights) for the transformer!\n", "\n", "**ABBREVIATIONS:**\n", "- suffix `_s` = sparse\n", "- suffix `_d` = dense\n", "- prefix `t_` = training\n", "- prefix `v_` = validation\n", "- `model_np` = number of parameters\n", "- `sat` = Self-Attention Transformer" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "execution": {}, "executionInfo": { "elapsed": 3, "status": "ok", "timestamp": 1718046721243, "user": { "displayName": "Xaq Pitkow", "userId": "09050806329892245378" }, "user_tz": -180 } }, "outputs": [], "source": [ "# Problem hyperparameters\n", "context_length = 17 # T: context length\n", "sparse_dense = [2, 4] # s: sparsity (number of function-relevant indices)\n", "B_valid = 500 # batch size for validation" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "### Exercise 3.1: MLP\n", "Training an MLP on the \"s-Sparse AND\" task. How does the model do?" 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "execution": {}, "executionInfo": { "elapsed": 9576, "status": "ok", "timestamp": 1718046730816, "user": { "displayName": "Xaq Pitkow", "userId": "09050806329892245378" }, "user_tz": -180 } }, "outputs": [], "source": [ "# Hyperparameters for sparse MLP\n", "B_t_mlp = 200 # batch size for training (number of training samples)\n", "n_epochs = 500 # number of epochs\n", "s_sparse = sparse_dense[0] # sparse\n", "hidden_layers = [512] # the number of hidden units in each layer [H1, H2, ...]\n", "kind = \"MLP\"\n", "\n", "mlp_model = BinaryMLP(context_length, hidden_layers, 1) # MLP model\n", "data_gen = s_Sparse_AND_Query(context_length, s_sparse)\n", "results_mlp_s = make_train(mlp_model, data_gen, B_t_mlp, B_valid, n_epochs, DEVICE, kind, verbose=True, etta=1e-3)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "execution": {}, "executionInfo": { "elapsed": 2691, "status": "ok", "timestamp": 1718046733506, "user": { "displayName": "Xaq Pitkow", "userId": "09050806329892245378" }, "user_tz": -180 } }, "outputs": [], "source": [ "# Hyperparameters for dense MLP\n", "B_t_mlp = 200 # batch size for training (number of training samples)\n", "n_epochs = 500 # number of epochs\n", "s_sparse = sparse_dense[1] # dense\n", "hidden_layers = [512] # the number of hidden units in each layer [H1, H2, ...]\n", "kind = \"MLP\"\n", "\n", "mlp_model = BinaryMLP(context_length, hidden_layers, 1) # MLP model\n", "data_gen = s_Sparse_AND_Query(context_length, s_sparse)\n", "results_mlp_d = make_train(mlp_model, data_gen, B_t_mlp, B_valid, n_epochs, DEVICE, kind, verbose=True, etta=1e-3)" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "### Exercise 3.2: Self-Attention\n", "\n", "Build a Transformer model and train it on the given dataset. How do the training results compare to MLP?" 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "execution": {}, "executionInfo": { "elapsed": 10625, "status": "ok", "timestamp": 1718046744127, "user": { "displayName": "Xaq Pitkow", "userId": "09050806329892245378" }, "user_tz": -180 } }, "outputs": [], "source": [ "# Hyperparameters for sparse SAT\n", "B_t_sat = 200 # batch size for training (number of training samples)\n", "n_epochs = 2000 # number of epochs\n", "s_sparse = sparse_dense[0] # sparse\n", "embed_dim = 2 # embedding dimension\n", "kind = \"SAT\"\n", "\n", "sat_model_s = SelfAttention(context_length, embed_dim) # selt-attention transformer\n", "data_gen = s_Sparse_AND_Query(context_length, s_sparse)\n", "results_sat_s = make_train(sat_model_s, data_gen, B_t_sat, B_valid, n_epochs, DEVICE, kind, verbose=True, etta=1e-2)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "execution": {}, "executionInfo": { "elapsed": 14546, "status": "ok", "timestamp": 1718046758668, "user": { "displayName": "Xaq Pitkow", "userId": "09050806329892245378" }, "user_tz": -180 } }, "outputs": [], "source": [ "# Hyperparameters for dense SAT\n", "B_t_sat = 200 # batch size for training (number of training samples)\n", "n_epochs = 2000 # number of epochs\n", "s_sparse = sparse_dense[1] # dense\n", "embed_dim = 2 # embedding dimension\n", "kind = \"SAT\"\n", "\n", "sat_model_d = SelfAttention(context_length, embed_dim) # selt-attention transformer\n", "data_gen = s_Sparse_AND_Query(context_length, s_sparse)\n", "results_sat_d = make_train(sat_model_d, data_gen, B_t_sat, B_valid, n_epochs, DEVICE, kind, verbose=True, etta=1e-2)" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "### Exercise 3.3: Coding Exercise Comparing results\n", "\n", "Here, we ask you to plot the results of the trainings above and the hyper-parameters you think are important. The goal is to show in one plot how Self-Attention Transformers and Multi-Layer Perceptons compare for both sparse and dense boolean tasks. You can use any or all of the following information in your plot:\n", "- number of parameters\n", "- training samples\n", "- validation accuracy\n", "- validation loss\n", "\n", "We have provided the results in an ordered dictionary, `ordered_results`, for your convenience.\n", "\n", "Discuss: what components of self-attention may be responsible for this improved performance for sparse functions? What could you add to an MLP that might make it more competitive?\n", "\n", "**Hint:** Use your creativity and information-communication skills :)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "execution": {}, "executionInfo": { "elapsed": 4, "status": "ok", "timestamp": 1718046758669, "user": { "displayName": "Xaq Pitkow", "userId": "09050806329892245378" }, "user_tz": -180 } }, "outputs": [], "source": [ "# Validation loss, accuracy, number of parameters and number of training samples\n", "ordered_results = results_dict(results_sat_d, results_sat_s, results_mlp_d, results_mlp_s)\n", "ordered_results" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "## Exercise 4: Context-dependent Attention Weights\n", "\n", "The *S*-Sparse AND task is not the best task to demonstrate the power of self-attention transformers in learning representations because even a linear network can learn a single projection of the inputs. 
Here we design a new task that would be a bit more challenging but also expose more core structure of attention.\n", "\n", "For a given binary sequence, we include a special \"Rule\" token that specifies which tokens should be included for the AND function. For example: if $x_{\\rm rule} = 0$, then the label of sequence should be $x_3~\\text{AND}~ x_7$, and if $x_{\\rm rule} = 1$, then the label should be $x_2~\\text{AND}~ x_5$. Therefore the self-attention should learn to \"attend\" to $(x_3,x_7)$ OR $(x_2,x_5)$, depending on $x_{\\rm rule}$.\n", "\n", "Here, we will train our model on such a task and then plot how each token affects the output label. Since this is affected by both the attentional weights and the next layer readout, we plot the product $W_\\text{readout}\\cdot W_\\text{attention}$." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "execution": {}, "executionInfo": { "elapsed": 6947, "status": "ok", "timestamp": 1718046765614, "user": { "displayName": "Xaq Pitkow", "userId": "09050806329892245378" }, "user_tz": -180 } }, "outputs": [], "source": [ "# Hyperparameters for sparse SAT\n", "context_length = 11 # T: context length\n", "B_t_sat = 300 # batch size for training (number of training samples)\n", "n_epochs = 1000 # number of epochs\n", "s_sparse = 2 # sparse\n", "embed_dim = 2 # embedding dimension\n", "n_sequences = 200 # number of samples for plotting\n", "kind = \"SAT\"\n", "\n", "sat_model_s = SelfAttention(context_length, embed_dim) # selt-attention transformer\n", "data_gen = s_Sparse_AND_Query(context_length, s_sparse)\n", "_ = make_train(sat_model_s, data_gen, B_t_sat, B_valid, n_epochs, DEVICE, kind, verbose=False, etta=1e-2)\n", "\n", "plot_attention_graph(sat_model_s, data_gen, n_sequences)" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "Under the two different rules, the attention mechanism weighs the input tokens differently for the final output. Plotting the results this way emphasizes how attention can be interpreted as a changing computation graph." 
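, "\n", "To make the rule structure concrete, here is a minimal sketch of the kind of rule-conditioned labelling described above (the index pairs follow the example in the text; the actual `s_Sparse_AND_Query` generator may use different conventions):\n", "\n", "```python\n", "RULE_PAIRS = {0: (3, 7), 1: (2, 5)}  # which two positions the AND applies to, per rule\n", "\n", "def rule_and_label(x, x_rule):\n", "    i, j = RULE_PAIRS[x_rule]\n", "    return int(x[i] == 1 and x[j] == 1)\n", "\n", "x = [0, 1, 0, 1, 0, 1, 0, 1]\n", "print(rule_and_label(x, x_rule=0))  # checks positions 3 and 7 -> 1\n", "print(rule_and_label(x, x_rule=1))  # checks positions 2 and 5 -> 0\n", "```"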
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Submit your feedback\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Submit your feedback\n", "content_review(f\"{feedback_prefix}_mlp_vs_self_attention\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Video 6: Results discussion & GNN\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "remove-input" ] }, "outputs": [], "source": [ "# @title Video 6: Results discussion & GNN\n", "\n", "from ipywidgets import widgets\n", "from IPython.display import YouTubeVideo\n", "from IPython.display import IFrame\n", "from IPython.display import display\n", "\n", "\n", "class PlayVideo(IFrame):\n", " def __init__(self, id, source, page=1, width=400, height=300, **kwargs):\n", " self.id = id\n", " if source == 'Bilibili':\n", " src = f'https://player.bilibili.com/player.html?bvid={id}&page={page}'\n", " elif source == 'Osf':\n", " src = f'https://mfr.ca-1.osf.io/render?url=https://osf.io/download/{id}/?direct%26mode=render'\n", " super(PlayVideo, self).__init__(src, width, height, **kwargs)\n", "\n", "\n", "def display_videos(video_ids, W=400, H=300, fs=1):\n", " tab_contents = []\n", " for i, video_id in enumerate(video_ids):\n", " out = widgets.Output()\n", " with out:\n", " if video_ids[i][0] == 'Youtube':\n", " video = YouTubeVideo(id=video_ids[i][1], width=W,\n", " height=H, fs=fs, rel=0)\n", " print(f'Video available at https://youtube.com/watch?v={video.id}')\n", " else:\n", " video = PlayVideo(id=video_ids[i][1], source=video_ids[i][0], width=W,\n", " height=H, fs=fs, autoplay=False)\n", " if video_ids[i][0] == 'Bilibili':\n", " print(f'Video available at https://www.bilibili.com/video/{video.id}')\n", " elif video_ids[i][0] == 'Osf':\n", " print(f'Video available at https://osf.io/{video.id}')\n", " display(video)\n", " tab_contents.append(out)\n", " return tab_contents\n", "\n", "video_ids = [('Youtube', '8T87j97JO28'), ('Bilibili', 'BV1Ux4y1b7vu')]\n", "tab_contents = display_videos(video_ids, W=730, H=410)\n", "tabs = widgets.Tab()\n", "tabs.children = tab_contents\n", "for i in range(len(tab_contents)):\n", " tabs.set_title(i, video_ids[i][0])\n", "display(tabs)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Submit your feedback\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Submit your feedback\n", "content_review(f\"{feedback_prefix}_results_discussion_gnn\")" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "---\n", "# Summary\n", "\n", "*Estimated timing of tutorial: 1 hour*\n", "\n", "This tutorial introduced multiplicative attention as a canonical operation in brains and machines and showed how it has an inductive bias that can help with context-dependent tasks.\n", "\n", "The self-attention mechanism of Transformers actually includes all of the canonical NeuroAI operations from this tutorial: sparsifying, normalization, and multiplicative attention. These operations make it easier to generalize in natural tasks." 
] } ], "metadata": { "accelerator": "GPU", "colab": { "collapsed_sections": [], "gpuType": "T4", "include_colab_link": true, "name": "W1D5_Tutorial3", "provenance": [ { "file_id": "https://github.com/ssnio/nma_neuroai_d4_t4/blob/main/W1D5_Tutorial3_query_full.ipynb", "timestamp": 1717881344256 } ], "toc_visible": true }, "kernel": { "display_name": "Python 3", "language": "python", "name": "python3" }, "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.19" } }, "nbformat": 4, "nbformat_minor": 4 }