{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"execution": {}
},
"source": [
"# Tutorial 3: Attention\n",
"\n",
"**Week 1, Day 5: Microcircuits**\n",
"\n",
"**By Neuromatch Academy**\n",
"\n",
"__Content creators:__ Saaketh Medepalli, Aditya Singh, Saeed Salehi, Xaq Pitkow\n",
"\n",
"__Content reviewers:__ Yizhou Chen, RyeongKyung Yoon, Ruiyi Zhang, Lily Chamakura, Hlib Solodzhuk\n",
"\n",
"__Production editors:__ Konstantine Tsafatinos, Ella Batty, Spiros Chavlis, Samuele Bolotta, Hlib Solodzhuk\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"execution": {}
},
"source": [
"---\n",
"# Tutorial Objectives\n",
"\n",
"*Estimated timing of tutorial: 1 hour*\n",
"\n",
"\n",
"By the end of this tutorial, we aim to:\n",
"\n",
"1. Learn how the brain and AI systems implemention attention\n",
"\n",
"2. Understand how multiplicative interactions allow flexible gating of information\n",
"\n",
"3. Demonstrate the inductive bias of the self-attention mechanism towards learning functions of sparse subsets of variables\n",
"\n",
"A key microarchitectural operation in brains and machines is **attention**. The essence of this operation is multiplication. Unlike the vanilla neural network operation -- a nonlinear function of a weighted sum of inputs $f(\\sum_j w_{j}x_j)$ -- the attention operation allows responses to *multiply* inputs. This enable computations like modulating weights by other inputs.\n",
"\n",
" \n",
"\n",
"In brains, the theory of the Spotlight of Attention (Posner et al 1980) posited that gain modulation allowed brain computations to select information for later computation. In machines, Rumelhart, Hinton, and McClelland (1986) described Sigma-Pi networks that included both sums ($\\Sigma$) and products ($\\Pi$) as fundamental operations. The Transformer network (Vaswani et al 2017) used a specific architecture featuring layers of multiplication, sparsification, and normalization. Many machine learning systems since then have fruitfully applied this architecture to language, vision, and many more modalities. In this tutorial we will isolate the central properties and generalization benefits of multiplicative attention shared by all of these applications.\n",
"\n",
" \n",
"\n",
"Exercises include simple attentional modulation of inputs, coding the self-attention mechanism, demonstrating its inductive bias, and interpreting the consequences of attention.\n",
"\n",
" \n",
"\n",
"References:\n",
"- Posner MI, Snyder CR, Davidson BJ (1980). Attention and the detection of signals. *Journal of experimental psychology: General*. 109(2):160.\n",
"\n",
"- Rumelhart DE, Hinton G, McClelland JL, PDP Research Group (1986). *Parallel Distributed Processing*. MIT press.\n",
"\n",
"- Vaswani A, Shazeer N, Parmar N, Uszkoreit J, Jones L, Gomez AN, Kaiser Ł, Polosukhin I (2017). [Attention is all you need.](https://papers.nips.cc/paper_files/paper/2017/hash/3f5ee243547dee91fbd053c1c4a845aa-Abstract.html) *Advances in Neural Information Processing Systems*.\n",
"\n",
"- Thomas Viehmann, [toroidal - a lightweight transformer library for PyTorch](https://github.com/MathInf/toroidal)\n",
"\n",
"- Edelman B, Goel S, Kakade S, Zhang C (2022). [Inductive biases and variable creation in self-attention mechanisms](https://proceedings.mlr.press/v162/edelman22a/edelman22a.pdf). *PMLR*.\n",
"\n",
"- Deep Graph Library Tutorials, [Transformer as a Graph Neural Network](https://docs.dgl.ai/en/0.8.x/tutorials/models/4_old_wines/7_transformer.html)\n",
"\n",
"- Lilian Weng, [Attention? Attention!](https://lilianweng.github.io/posts/2018-06-24-attention/)\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"execution": {},
"tags": [
"remove-input"
]
},
"outputs": [],
"source": [
"# @markdown\n",
"from IPython.display import IFrame\n",
"from ipywidgets import widgets\n",
"out = widgets.Output()\n",
"with out:\n",
" print(f\"If you want to download the slides: https://osf.io/download/eckvr/\")\n",
" display(IFrame(src=f\"https://mfr.ca-1.osf.io/render?url=https://osf.io/eckvr/?direct%26mode=render%26action=download%26mode=render\", width=730, height=410))\n",
"display(out)"
]
},
{
"cell_type": "markdown",
"metadata": {
"execution": {}
},
"source": [
"---\n",
"# Setup\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Install and import feedback gadget\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"execution": {},
"tags": [
"hide-input"
]
},
"outputs": [],
"source": [
"# @title Install and import feedback gadget\n",
"\n",
"!pip install vibecheck datatops --quiet\n",
"\n",
"from vibecheck import DatatopsContentReviewContainer\n",
"def content_review(notebook_section: str):\n",
" return DatatopsContentReviewContainer(\n",
" \"\", # No text prompt\n",
" notebook_section,\n",
" {\n",
" \"url\": \"https://pmyvdlilci.execute-api.us-east-1.amazonaws.com/klab\",\n",
" \"name\": \"neuromatch_neuroai\",\n",
" \"user_key\": \"wb2cxze8\",\n",
" },\n",
" ).render()\n",
"\n",
"feedback_prefix = \"W1D5_T3\""
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Imports\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"execution": {},
"executionInfo": {
"elapsed": 9705,
"status": "ok",
"timestamp": 1718046716707,
"user": {
"displayName": "Xaq Pitkow",
"userId": "09050806329892245378"
},
"user_tz": -180
},
"tags": [
"hide-input"
]
},
"outputs": [],
"source": [
"# @title Imports\n",
"import os\n",
"import sys\n",
"import math\n",
"import torch\n",
"import random\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Figure settings\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"execution": {},
"executionInfo": {
"elapsed": 452,
"status": "ok",
"timestamp": 1718046717157,
"user": {
"displayName": "Xaq Pitkow",
"userId": "09050806329892245378"
},
"user_tz": -180
},
"tags": [
"hide-input"
]
},
"outputs": [],
"source": [
"# @title Figure settings\n",
"import logging\n",
"import matplotlib.cm as cm\n",
"import ipywidgets as widgets # interactive display\n",
"from ipywidgets import interactive, FloatSlider, Layout\n",
"logging.getLogger('matplotlib.font_manager').disabled = True\n",
"%config InlineBackend.figure_format = 'retina'\n",
"plt.style.use(\"https://raw.githubusercontent.com/NeuromatchAcademy/course-content/NMA2020/nma.mplstyle\")\n",
"fig_w, fig_h = plt.rcParams['figure.figsize']"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Plotting functions\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"execution": {},
"executionInfo": {
"elapsed": 3,
"status": "ok",
"timestamp": 1718046717157,
"user": {
"displayName": "Xaq Pitkow",
"userId": "09050806329892245378"
},
"user_tz": -180
},
"tags": [
"hide-input"
]
},
"outputs": [],
"source": [
"# @title Plotting functions\n",
"def plot_loss_accuracy(t_loss, t_acc, v_loss = None, v_acc = None):\n",
" with plt.xkcd():\n",
" plt.figure(figsize=(15, 4))\n",
" plt.suptitle(\"Training and Validation for the Transformer Model\")\n",
" plt.subplot(1, 2, 1)\n",
" plt.plot(t_loss, label=\"Training loss\", color=\"red\")\n",
" if v_loss is not None:\n",
" # plt.plot(v_loss, label=\"Valididation loss\", color=\"blue\")\n",
" plt.scatter(len(t_loss)-1, v_loss, label=\"Validation loss\", color=\"blue\", marker=\"*\")\n",
" # plt.text(len(t_loss)-1, v_loss, f\"{v_loss:.3f}\", va=\"bottom\", ha=\"right\")\n",
" plt.yscale(\"log\")\n",
" plt.xlabel(\"Epoch\")\n",
" plt.ylabel(\"Loss\")\n",
" plt.xticks([])\n",
" plt.legend(loc=\"lower right\")\n",
" plt.subplot(1, 2, 2)\n",
" plt.plot(t_acc, label=\"Training accuracy\", color=\"red\", linestyle=\"dotted\")\n",
" if v_acc is not None:\n",
" # plt.plot(v_acc, label=\"Validation accuracy\", color=\"blue\", linestyle=\"--\")\n",
" plt.scatter(len(t_acc)-1, v_acc, label=\"Validation accuracy\", color=\"blue\", marker=\"*\")\n",
" # plt.text(len(t_acc)-1, v_acc, f\"{v_acc:.3f}\", va=\"bottom\", ha=\"right\")\n",
" plt.xticks([])\n",
" plt.ylim(0, 1)\n",
" plt.xlabel(\"Epoch\")\n",
" plt.ylabel(\"Accuracy\")\n",
" plt.legend(loc=\"lower right\")\n",
" plt.show()\n",
"\n",
"\n",
"def plot_samples(X_plot, y_plot, correct_ids, title=None):\n",
" with plt.xkcd():\n",
" n_samples, seq_length = X_plot.shape\n",
" fig, axs = plt.subplots(1, 2, figsize=(16, 2.5), sharey=True)\n",
" rects = []\n",
" for ri in correct_ids:\n",
" rects.append(plt.Rectangle((ri-0.5, -0.5), 1, n_samples, edgecolor=\"red\", alpha=1.0, fill=False, linewidth=2))\n",
" axs[0].imshow(X_plot, cmap=\"binary\")\n",
" for rect in rects:\n",
" axs[0].add_patch(rect)\n",
" # axs[0].axis(\"off\")\n",
" axs[0].set_yticks([])\n",
" axs[0].set_xticks([])\n",
" axs[0].set_ylabel(\"Samples\")\n",
" axs[0].set_xlabel(\"Context\")\n",
" if title is not None:\n",
" axs[0].set_title(title)\n",
" axs[1].imshow(y_plot, cmap=\"binary\")\n",
" axs[1].add_patch(plt.Rectangle((-0.5, -0.5), 1, n_samples, edgecolor=\"black\", alpha=1.0, fill=False, linewidth=2))\n",
" axs[1].yaxis.set_label_position(\"right\")\n",
" axs[1].set_ylabel(\"Labels\")\n",
" axs[1].set_yticks([])\n",
" axs[1].set_xticks([])\n",
" plt.subplots_adjust(wspace=-1)\n",
" plt.tight_layout()\n",
" plt.show()\n",
"\n",
"\n",
"def plot_attention_weights(att_weights, correct_ids, context_length, labels):\n",
" with plt.xkcd():\n",
" from matplotlib.lines import Line2D\n",
" B = att_weights.size(0)\n",
" aw_flatten = att_weights[:, -1, :-1].view(-1, context_length)\n",
" x_axis = torch.arange(context_length).repeat(B, 1)\n",
" y_axis = labels.view(-1, 1).repeat(1, context_length)\n",
" aw_ravel = aw_flatten.ravel()\n",
" x_ravel = x_axis.ravel()\n",
" y_ravel = y_axis.ravel()\n",
"\n",
" fig, ax = plt.subplots(figsize=(9, 5))\n",
" colors = [\"#1E88E5\", \"#FFC107\"]\n",
" labels_legened = [\"True\", \"False\"]\n",
" ax.scatter(x_ravel, aw_ravel, alpha=0.5, c=[colors[int(y)] for y in y_ravel])\n",
" rects = []\n",
" for ri in correct_ids:\n",
" # rects.append(plt.Rectangle((ri-0.5, 1e-6), 1.0, 2.0, edgecolor=\"blue\", alpha=1.0, fill=False, linewidth=2))\n",
" rects.append(plt.Rectangle((ri-0.5, -0.1), 1.0, 0.8, edgecolor=\"red\", alpha=1.0, fill=False, linewidth=2))\n",
" for rect in rects:\n",
" ax.add_patch(rect)\n",
" # plt.yscale(\"log\")\n",
" # plt.ylim(1e-6, 2)\n",
" plt.ylim(-0.1, 0.7)\n",
" plt.title(\"Attention weights for the whole batch\")\n",
" plt.xlabel(\"Boolean input index t\")\n",
" plt.ylabel(\"Attention weight\")\n",
" legend_elements = [Line2D([0], [0], linestyle='None', marker='o', color='#1E88E5', label='True', markerfacecolor='#1E88E5', markersize=7),\n",
" Line2D([0], [0], linestyle='None', marker='o', color='#FFC107', label='False', markerfacecolor='#FFC107', markersize=7)]\n",
" ax.legend(handles=legend_elements, loc='upper right')\n",
" plt.show()\n",
"\n",
"\n",
"def plot_attention_weights_stats(att_weights, correct_ids, context_length, labels):\n",
" with plt.xkcd():\n",
" aw_flatten = att_weights[:, -1, :-1].view(-1, context_length)\n",
" aw_flatten_mean = aw_flatten.mean(dim=0)\n",
" aw_flatten_std = aw_flatten.std(dim=0)\n",
" fig, ax = plt.subplots(figsize=(9, 5))\n",
" ax.errorbar(torch.arange(context_length), aw_flatten_mean, yerr=aw_flatten_std, fmt='o', color='blue')\n",
" rects = []\n",
" for ri in correct_ids:\n",
" rects.append(plt.Rectangle((ri-0.5, -0.1), 1.0, 0.8, edgecolor=\"red\", alpha=1.0, fill=False, linewidth=2))\n",
" for rect in rects:\n",
" ax.add_patch(rect)\n",
" plt.title(\"Attention weights statistics\")\n",
" plt.xlabel(\"Boolean input index t\")\n",
" plt.ylabel(\"Mean attention weight\")\n",
" plt.ylim(-0.1, 0.7)\n",
" plt.show()\n",
"\n",
"def plot_compare(results_sat_d, results_sat_s, results_mlp_d, results_mlp_s,\n",
" B_t_sat_s, B_t_sat_d, B_t_mlp_s, B_t_mlp_d):\n",
" with plt.xkcd():\n",
" from matplotlib.colors import LinearSegmentedColormap\n",
" cmap=LinearSegmentedColormap.from_list('rg',[\"w\", \"w\", \"r\", \"y\", \"g\"], N=256)\n",
"\n",
" t_loss_sat_d, t_acc_sat_d, v_loss_sat_d, v_acc_sat_d, model_np_sat_d = results_sat_d\n",
" t_loss_sat_s, t_acc_sat_s, v_loss_sat_s, v_acc_sat_s, model_np_sat_s = results_sat_s\n",
" t_loss_mlp_d, t_acc_mlp_d, v_loss_mlp_d, v_acc_mlp_d, model_np_mlp_d = results_mlp_d\n",
" t_loss_mlp_s, t_acc_mlp_s, v_loss_mlp_s, v_acc_mlp_s, model_np_mlp_s = results_mlp_s\n",
"\n",
" plt.figure(figsize=(8, 6))\n",
" plt.subplot(2, 2, 1)\n",
" plt.title(\"sparse\", fontsize=16)\n",
" plt.ylabel(\"Attention\", fontsize=16)\n",
" plt.barh(3, 0, color=\"blue\")\n",
" plt.barh(4, 0, color=\"green\")\n",
" plt.text(0.05, 3.5, f\"# Parameters: {model_np_sat_s}\")\n",
" plt.text(0.05, 3.0, f\"# samples: {B_t_sat_s}\")\n",
" plt.barh(2, t_acc_sat_s, color=cmap(t_acc_sat_s))\n",
" plt.barh(1, v_acc_sat_s, color=cmap(v_acc_sat_s))\n",
" plt.text(0.05, 2, f\"# training acc: {t_acc_sat_s:.0%}\")\n",
" plt.text(0.05, 1, f\"# validation acc: {v_acc_sat_s:.0%}\")\n",
" plt.xticks([])\n",
" plt.yticks([])\n",
"\n",
" plt.subplot(2, 2, 2)\n",
" plt.title(\"dense\", fontsize=16)\n",
" plt.barh(3, 0, color=\"blue\")\n",
" plt.barh(4, 0, color=\"green\")\n",
" plt.text(0.05, 3.5, f\"# Parameters: {model_np_sat_d}\")\n",
" plt.text(0.05, 3.0, f\"# samples: {B_t_sat_d}\")\n",
" plt.barh(2, t_acc_sat_d, color=cmap(t_acc_sat_d))\n",
" plt.barh(1, v_acc_sat_d, color=cmap(v_acc_sat_d))\n",
" plt.text(0.05, 2, f\"# training acc: {t_acc_sat_d:.0%}\")\n",
" plt.text(0.05, 1, f\"# validation acc: {v_acc_sat_d:.0%}\")\n",
" plt.xticks([])\n",
" plt.yticks([])\n",
"\n",
" plt.subplot(2, 2, 3)\n",
" plt.barh(3, 0, color=\"blue\")\n",
" plt.barh(4, 0, color=\"green\")\n",
" plt.text(0.05, 3.5, f\"# Parameters: {model_np_mlp_s}\")\n",
" plt.text(0.05, 3.0, f\"# samples: {B_t_mlp_s}\")\n",
" plt.barh(2, t_acc_mlp_s, color=cmap(t_acc_mlp_s))\n",
" plt.barh(1, v_acc_mlp_s, color=cmap(v_acc_mlp_s))\n",
" plt.text(0.05, 2, f\"# training acc: {t_acc_mlp_s:.0%}\")\n",
" plt.text(0.05, 1, f\"# validation acc: {v_acc_mlp_s:.0%}\")\n",
" plt.xticks([])\n",
" plt.yticks([])\n",
" plt.ylabel(\"MLP\", fontsize=16)\n",
"\n",
" plt.subplot(2, 2, 4)\n",
" plt.barh(3, 0, color=\"blue\")\n",
" plt.barh(4, 0, color=\"green\")\n",
" plt.text(0.05, 3.5, f\"# Parameters: {model_np_mlp_d}\")\n",
" plt.text(0.05, 3.0, f\"# samples: {B_t_mlp_d}\")\n",
" plt.barh(2, t_acc_mlp_d, color=cmap(t_acc_mlp_d))\n",
" plt.barh(1, v_acc_mlp_d, color=cmap(v_acc_mlp_d))\n",
" plt.text(0.05, 2, f\"# training acc: {t_acc_mlp_d:.0%}\")\n",
" plt.text(0.05, 1, f\"# validation acc: {v_acc_mlp_d:.0%}\")\n",
" plt.xticks([])\n",
" plt.yticks([])\n",
" plt.axis(\"off\")\n",
" plt.show()\n",
"\n",
"def gained_dot_product_attention_implemented(x: torch.Tensor, # input vector\n",
" q_1: torch.Tensor, # query vector 1\n",
" q_2: torch.Tensor, # query vector 2\n",
" z_1: float, # query gain 1\n",
" z_2: float, # query gain 2\n",
" ):\n",
" \"\"\"This function computes the gained dot product attention\n",
" Args:\n",
" x (Tensor): input vector\n",
" q_1 (Tensor): query vector 1\n",
" q_2 (Tensor): query vector 2\n",
" z_1 (float): query gain 1\n",
" z_2 (float): query gain 2\n",
" Returns:\n",
" w (Tensor): attention weights\n",
" y (float): gained dot product attention\n",
" \"\"\"\n",
" w = torch.softmax(z_1 * q_1 + z_2 * q_2, dim=0)\n",
" y = torch.dot(w, x)\n",
" return w, y\n",
"\n",
"def plot_weights_and_output(gain_1, gain_2):\n",
" T = 20 # context length\n",
" x = torch.sin(torch.linspace(0, 2*np.pi, T)) + 0.1 * torch.randn(T)\n",
" q_1 = 1.0 - torch.sigmoid(torch.linspace(-3, 7, T))\n",
" q_1 = q_1 / q_1.sum()\n",
" q_2 = torch.sigmoid(torch.linspace(-7, 3, T))\n",
" q_2 = q_2 / q_2.sum()\n",
" w, y = gained_dot_product_attention_implemented(x, q_1, q_2, gain_1, gain_2)\n",
" print(f\"output y: {y}\")\n",
" with plt.xkcd():\n",
" plt.figure(figsize=(8, 6))\n",
" plt.subplot(2, 1, 1)\n",
" plt.plot(q_1, label=\"$\\mathbf{q_1}$\", c=\"m\")\n",
" plt.plot(q_2, label=\"$\\mathbf{q_2}$\", c=\"y\")\n",
" plt.plot(w, label=\"$\\mathbf{w}$\", c=\"r\")\n",
" plt.ylim(-0.1, max(q_1.max(), q_2.max()))\n",
" plt.ylabel(\"Attention weights\")\n",
" plt.xlabel(\"Context dimension\")\n",
" plt.legend()\n",
" plt.subplot(2, 1, 2)\n",
" plt.plot(x, label=\"$\\mathbf{x}$\", c=\"blue\")\n",
" plt.plot(5 * x * w, label=\"$5\\mathbf{w}*\\mathbf{x}$\", c=\"red\")\n",
" plt.ylim(-x.abs().max(), x.abs().max())\n",
" plt.ylabel(\"Amplitude\")\n",
" plt.xlabel(\"Context\")\n",
" plt.legend()\n",
" plt.show()\n",
"\n",
"def plot_attention_graph(model, data_generator, m):\n",
" with plt.xkcd():\n",
" X_valid, y_valid = data_generator.generate(m, verbose=False)\n",
" logits, scores, attention, output = model(X_valid, elaborate=True)\n",
" with torch.no_grad():\n",
" scores_x_linear = torch.einsum(\"bij,jk->bik\", scores, model.linear.weight.T).detach().cpu()\n",
" w_att_trues_0 = scores_x_linear[(y_valid[:, 0] == 1) & (X_valid[:, -1] == 0)]\n",
" w_att_trues_1 = scores_x_linear[(y_valid[:, 0] == 1) & (X_valid[:, -1] == 1)]\n",
" w_att_trues_0 = w_att_trues_0.mean(dim=0).squeeze(-1)\n",
" w_att_trues_1 = w_att_trues_1.mean(dim=0).squeeze(-1)\n",
"\n",
" fig, ax = plt.subplots(2, 1, figsize=(10, 8))\n",
" rules = [\"Rule = 0\", \"Rule = 1\"]\n",
" for r, x in enumerate((w_att_trues_0, w_att_trues_1)):\n",
" #xc = (x - x.min()) / (x.max() - x.min() + 1e-9)\n",
" xc = (x + x.abs().max()) / (2 * x.abs().max()) # color range spans -max|x| to +max|x|\n",
" xw = x.abs() / x.abs().max()\n",
" xa = (xw - xw.min()) / (xw.max() - xw.min() + 1e-9)\n",
" xabsmax = x.abs().max()\n",
" x, xc, xw, xa = (some_x.numpy() for some_x in (x, xc, xw, xa))\n",
" T = x.shape[0]\n",
" m = (T/2) - 0.5\n",
"\n",
" ax[r].text(m-1, 1.0, \" \", c='k', fontsize=16)\n",
" ax[r].text(m-1, 0.9, f\"Graph for {rules[r]}\", c='k', fontsize=16)\n",
" for i in range(T):\n",
" c = 'r' if i in data_generator.f_i_1 else 'b' if i in data_generator.f_i_2 else 'm' if i == T-1 else'grey'\n",
" ax[r].plot([i, m], [0, 0.75], c=cm.plasma_r(xc[i]), linewidth=5, alpha=1)\n",
" ax[r].scatter([i], [0], c=c, zorder=100, s=100)\n",
" ax[r].scatter([m], [0.75], c='k', zorder=100, s=200)\n",
" for i, j in zip(data_generator.f_i_1, data_generator.f_i_2):\n",
" ax[r].text(i-0.3, -0.1, \"tar_0\", c='r', fontsize=12)\n",
" ax[r].text(j-0.3, -0.1, \"tar_1\", c='b', fontsize=12)\n",
" ax[r].text(m-0.3, 0.8, \"Output\", c='k', fontsize=12)\n",
" ax[r].text(T-1-0.1, -0.1, \"Rule\", c='m', fontsize=12)\n",
" cbar = fig.colorbar(cm.ScalarMappable(cmap=cm.plasma_r), ax=ax[r], ticks=[0, 1], label=\"Attention\")\n",
" #cbar.ax.set_yticklabels([f\"{x.min():.1f}\", f\"{x.max():.1f}\"])\n",
" cbar.ax.set_yticklabels([f\"{-xabsmax:.1f}\", f\"{xabsmax:.1f}\"])\n",
" ax[r].axis('off')\n",
" plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Data retrieval\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"execution": {},
"executionInfo": {
"elapsed": 2,
"status": "ok",
"timestamp": 1718046717157,
"user": {
"displayName": "Xaq Pitkow",
"userId": "09050806329892245378"
},
"user_tz": -180
},
"tags": [
"hide-input"
]
},
"outputs": [],
"source": [
"#@title Data retrieval\n",
"class s_Sparse_AND: # 1-Dimensional AND\n",
" def __init__(self, T: int, s: int):\n",
" self.T = T # context length\n",
" self.s = s # sparsity\n",
" self.p = 0.5**(1.0/3.0) # probability chosen for balanced data\n",
" self.f_i = None\n",
"\n",
" def pick_an_f(self):\n",
" self.f_i = sorted(random.sample(range(self.T), self.s))\n",
" self.others = list(i for i in range(self.T) if i not in self.f_i)\n",
"\n",
" def generate(self, m: int, verbose: bool = False):\n",
" if self.f_i is None:\n",
" self.pick_an_f()\n",
" max_try = 100\n",
" i_try = 0\n",
" while i_try < max_try:\n",
" i_try += 1\n",
" X, y = torch.zeros(m, self.T), torch.zeros(m, 1)\n",
" X[torch.rand(m, self.T) < self.p] = 1\n",
" y[X[:, self.f_i].sum(dim=1) == self.s] = 1\n",
" if y.sum()/m < 0.4 or y.sum()/m > 0.6:\n",
" verbose and print(f\"Large imbalance in the training set {y.sum()/m}, retrying...\")\n",
" continue\n",
" else:\n",
" verbose and print(f\"Data-label balance: {y.sum()/m}\")\n",
" bad_batch = False\n",
" for i in self.f_i:\n",
" for o in self.others:\n",
" if (X[:, i] == X[:, o]).all():\n",
" verbose and print(f\"Found at least another compatible hypothesis {i} and {o}\")\n",
" bad_batch = True\n",
" break\n",
" if bad_batch:\n",
" continue\n",
" else:\n",
" break\n",
" else:\n",
" verbose and print(\"Could not find a compatible hypothesis\")\n",
" return X.long(), y.float()\n",
"\n",
"\n",
"class s_Sparse_AND_Query: # 1-Dimensional AND\n",
" def __init__(self, T, s):\n",
" self.T = T - 1\n",
" self.k = 1 # currently only for k = 1\n",
" self.s = s\n",
" self.p = 0.5**(1.0/2.0) # probability chosen for balanced data\n",
" self.f_i_1, self.f_i_2, self.f_i = None, None, None\n",
"\n",
" def pick_an_f(self):\n",
" self.f_i_1 = sorted(random.sample(range(self.T), self.s))\n",
" self.others_1 = list(i for i in range(self.T) if i not in self.f_i_1)\n",
" self.f_i_2 = sorted(random.sample(self.others_1, self.s))\n",
" self.others_2 = list(i for i in self.others_1 if i not in self.f_i_2)\n",
" self.f_i = self.f_i_1 + self.f_i_2\n",
"\n",
" def __call__(self):\n",
" self.pick_an_f()\n",
"\n",
" def generate(self, m: int, verbose: bool = False):\n",
" z = torch.randint(0, 2, (m, self.k))\n",
" if self.f_i_1 is None or self.f_i_2 is None:\n",
" self.pick_an_f()\n",
" max_try = 100\n",
" i_try = 0\n",
" while i_try < max_try:\n",
" i_try += 1\n",
" X, y = torch.zeros(m, self.T), torch.zeros(m, 1)\n",
" X[torch.rand(m, self.T) < self.p] = 1\n",
" # Rule 0\n",
" X_z_0 = (X[:, self.f_i_1].sum(dim=1) == self.s) & (z[:, 0] == 0)\n",
" y[X_z_0] = 1\n",
" # Rule 1\n",
" X_z_1 = (X[:, self.f_i_2].sum(dim=1) == self.s) & (z[:, 0] == 1)\n",
" y[X_z_1] = 1\n",
" if y.sum()/m < 0.45 or y.sum()/m > 0.55:\n",
" verbose and print(f\"Large imbalance in the training set {y.sum()/m}, retrying...\")\n",
" continue\n",
" else:\n",
" verbose and print(f\"Data-label balance: {y.sum()/m}\")\n",
" break\n",
" else:\n",
" verbose and print(\"Could not find a compatible hypothesis\")\n",
" return torch.cat([X, z], dim=1).long(), y.float()\n",
"\n",
"\n",
"# From notebook\n",
"# Load an example image from the imagenet-sample-images repository\n",
"def load_image(path):\n",
" \"\"\"\n",
" Load an image from a given path\n",
"\n",
" Args:\n",
" path: String\n",
" Path to the image\n",
"\n",
" Returns:\n",
" img: PIL Image\n",
" Image loaded from the path\n",
" \"\"\"\n",
" img = Image.open(path)\n",
" return img"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Helper functions\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"execution": {},
"executionInfo": {
"elapsed": 347,
"status": "ok",
"timestamp": 1718046717502,
"user": {
"displayName": "Xaq Pitkow",
"userId": "09050806329892245378"
},
"user_tz": -180
},
"tags": [
"hide-input"
]
},
"outputs": [],
"source": [
"# @title Helper functions\n",
"\n",
"class BinaryMLP(torch.nn.Module):\n",
" def __init__(self, in_dims, h_dims, out_dims, dropout=0.1):\n",
" super().__init__()\n",
" self.in_dims = in_dims\n",
" self.h_dims = h_dims\n",
" self.out_dims = out_dims\n",
"\n",
" self.layers = torch.nn.ModuleList()\n",
" self.layers.append(torch.nn.Linear(in_dims, h_dims[0]))\n",
" torch.nn.init.normal_(self.layers[-1].weight, std=0.02)\n",
" torch.nn.init.zeros_(self.layers[-1].bias)\n",
" self.layers.append(torch.nn.GELU())\n",
" self.layers.append(torch.nn.Dropout(dropout))\n",
" for i in range(len(h_dims) - 1):\n",
" self.layers.append(torch.nn.Linear(h_dims[i], h_dims[i+1]))\n",
" torch.nn.init.normal_(self.layers[-1].weight, std=0.02)\n",
" torch.nn.init.zeros_(self.layers[-1].bias)\n",
" self.layers.append(torch.nn.GELU())\n",
" self.layers.append(torch.nn.Dropout(dropout))\n",
" self.layers.append(torch.nn.Linear(h_dims[-1], out_dims))\n",
" self.layers.append(torch.nn.Sigmoid())\n",
"\n",
" def forward(self, x):\n",
" for layer in self.layers:\n",
" x = layer(x)\n",
" return x\n",
"\n",
"\n",
"\n",
"class SelfAttention(torch.nn.Module):\n",
" \"\"\"Simple Binary Self-Attention\n",
" \"\"\"\n",
" def __init__(self, T: int, d: int):\n",
" \"\"\"\n",
" Args:\n",
" T (int): context length\n",
" d (int): embedding size for K, Q but not V\n",
" Note:\n",
" The embedding size for V is the same as the input size i.e. 1\n",
" \"\"\"\n",
" super().__init__()\n",
" self.T = T # context length\n",
" self.d = d # embedding size\n",
" self.scale = self.d ** -0.5 # scaling factor (1 / sqrt(d_k))\n",
" self.v = 2 # vocabulary size (binary input, 0 or 1)\n",
" init_std = 0.001 # standard deviation for weight initialization\n",
"\n",
" # embedding layers\n",
" self.tokenizer = torch.nn.Embedding(self.v, self.d) # token embedding\n",
" self.positioner = torch.nn.Parameter(torch.rand(1, self.T, self.d)) # positional embedding\n",
"\n",
" # self-attention layers\n",
" self.Wq = torch.nn.Linear(self.d, self.d, bias=False) # query layer\n",
" self.Wk = torch.nn.Linear(self.d, self.d, bias=False) # key layer\n",
" self.Wv = torch.nn.Linear(self.d, 1, bias=False) # value layer\n",
"\n",
" # projection layer\n",
" self.linear = torch.nn.Linear(self.T, 1, bias=False) # projection layer\n",
" torch.nn.init.normal_(self.linear.weight, std=init_std)\n",
" # torch.nn.init.zeros_(self.linear.bias)\n",
"\n",
" # initialize weights and biases (per description in the paper)\n",
" torch.nn.init.normal_(self.tokenizer.weight, std=init_std)\n",
" torch.nn.init.normal_(self.positioner, std=init_std)\n",
" torch.nn.init.normal_(self.Wq.weight, std=init_std)\n",
" # torch.nn.init.zeros_(self.Wq.bias)\n",
" torch.nn.init.normal_(self.Wk.weight, std=init_std)\n",
" # torch.nn.init.zeros_(self.Wk.bias)\n",
" torch.nn.init.normal_(self.Wv.weight, std=init_std)\n",
" # torch.nn.init.zeros_(self.Wv.bias)\n",
"\n",
" def forward(self, x: torch.Tensor, elaborate: bool = False):\n",
" \"\"\"Forward pass\n",
" Args:\n",
" x (torch.Tensor): input tensor of shape (B, T, d)\n",
" \"\"\"\n",
" # Embedding\n",
" x = self.tokenizer(x)\n",
" x = x + self.positioner\n",
"\n",
" # (Scaled Dot-Product Attention)\n",
" Q = self.Wq(x)\n",
" K = self.Wk(x)\n",
" V = self.Wv(x)\n",
" logits = torch.einsum(\"btc,bsc->bts\", Q, K) # query key product\n",
" logits *= self.scale # normalize against staturation\n",
" scores = torch.softmax(logits, dim=-1) # attention scores\n",
" attention = torch.einsum(\"bts,bsc->btc\", scores, V) # scaled dot-product attention\n",
" attention = attention.squeeze(-1) # remove last dimension\n",
" output = self.linear(attention) # linear layer\n",
" output = torch.sigmoid(output) # sigmoid activation\n",
" return (logits, scores, attention, output) if elaborate else output\n",
"\n",
"\n",
"class PositionalEncoding(torch.nn.Module):\n",
" def __init__(self, T: int, d_model: int):\n",
" super().__init__()\n",
" Te = T + T%2\n",
" de = d_model + d_model%2\n",
" position = torch.arange(Te).unsqueeze(1)\n",
" div_term = torch.exp(torch.arange(0, de, 2) * (-math.log(10000.0) / de))\n",
" pe = torch.zeros(Te, de)\n",
" pe[:, 0::2] = torch.sin(position * div_term)\n",
" pe[:, 1::2] = torch.cos(position * div_term)\n",
" pe = pe[:T, :d_model]\n",
" self.register_buffer('pe', pe)\n",
"\n",
" def forward(self):\n",
" return self.pe\n",
"\n",
"\n",
"class SDPA_Solution(torch.nn.Module):\n",
" def __init__(self, T: int, dm: int, dk: int):\n",
" \"\"\"\n",
" Scaled Dot Product Attention\n",
" Args:\n",
" T (int): context length\n",
" dm (int): model dimension\n",
" dk (int): key dimension\n",
" Note:\n",
" we assume dm == dv\n",
" \"\"\"\n",
" super().__init__()\n",
" self.T = T # context length\n",
" self.dm = dm # model dimension\n",
" self.dk = dk # key dimension\n",
" self.scale = 1.0 / math.sqrt(dk)\n",
"\n",
" # positional Encoding\n",
" self.position = PositionalEncoding(T, dm)\n",
"\n",
" # self-attention layers\n",
" self.Wq = torch.nn.Linear(dm, dk, bias=False) # query layer\n",
" self.Wk = torch.nn.Linear(dm, dk, bias=False) # key layer\n",
" self.Wv = torch.nn.Linear(dm, dm, bias=False) # value layer\n",
"\n",
" def forward(self, x: torch.Tensor):\n",
" \"\"\"\n",
" Args:\n",
" x (torch.Tensor): input tensor of shape (T, d)\n",
" \"\"\"\n",
" # Positional Encoding\n",
" x = x + self.position()\n",
"\n",
" # (Scaled Dot-Product Attention)\n",
" Q = self.Wq(x) # Query\n",
" K = self.Wk(x) # Key\n",
" V = self.Wv(x) # Value\n",
" QK = Q @ K.T # Query Key product\n",
" S = QK * self.scale # Scores (scaled against staturation)\n",
" S_softmax = torch.softmax(S, dim=-1) # softmax attention scores (row dimensions)\n",
" A = S_softmax @ V # scaled dot-product attention\n",
" return A\n",
"\n",
"\n",
"def test_sdpa(your_implementation):\n",
" # self-attention layers\n",
" context_len = 7\n",
" model_dim = 5\n",
" key_dim = 3\n",
" toy_x = torch.rand(context_len, model_dim)\n",
" your_model = your_implementation(context_len, model_dim, key_dim)\n",
" our_model = SDPA_Solution(context_len, model_dim, key_dim)\n",
" your_model.Wq = our_model.Wq\n",
" your_model.Wk = our_model.Wk\n",
" your_model.Wv = our_model.Wv\n",
" your_results = your_model(toy_x.clone())\n",
" our_results = our_model(toy_x.clone())\n",
" if (your_results == our_results).all():\n",
" print(\"Very well Done!\")\n",
" else:\n",
" print(\"The two implementations should produce the same result!\")\n",
"\n",
"\n",
"def bin_acc(y_hat, y):\n",
" \"\"\"\n",
" Compute the binary accuracy\n",
" \"\"\"\n",
" y_ = y_hat.round()\n",
" TP_TN = (y_ == y).float().sum().item()\n",
" FP_FN = (y_ != y).float().sum().item()\n",
" assert TP_TN + FP_FN == y.numel(), f\"{TP_TN + FP_FN} != {y.numel()}\"\n",
" return TP_TN / y.numel()\n",
"\n",
"\n",
"def get_n_parameters(model: torch.nn.Module):\n",
" \"\"\"\n",
" Get the number of learnable parameters in a model\n",
" \"\"\"\n",
" i = 0\n",
" for par in model.parameters():\n",
" i += par.numel()\n",
" return i\n",
"\n",
"\n",
"def save_model(model):\n",
" torch.save(model.state_dict(), 'model_states.pt')\n",
"\n",
"\n",
"def load_model(model):\n",
" model_states = torch.load('model_states.pt')\n",
" model.load_state_dict(model_states)\n",
"\n",
"\n",
"class Trainer:\n",
" def __init__(self, model, optimizer, criterion):\n",
" self.model = model\n",
" self.optimizer = optimizer\n",
" self.criterion = criterion\n",
"\n",
" def do_eval(self, X_v, y_v, device=\"cpu\"):\n",
" self.model.to(device)\n",
" self.model.eval()\n",
" X_v, y_v = X_v.to(device), y_v.to(device)\n",
" with torch.no_grad():\n",
" y_hat = self.model(X_v)\n",
" loss = self.criterion(y_hat.squeeze(), y_v.squeeze())\n",
" acc = bin_acc(y_hat, y_v)\n",
" self.model.to(\"cpu\")\n",
" return loss.item(), acc\n",
"\n",
" def do_train(self, n_epochs, X_t, y_t, device=\"cpu\"):\n",
" train_loss, train_acc = [], []\n",
" self.model.to(device)\n",
" self.model.train()\n",
" X_t, y_t = X_t.to(device), y_t.to(device)\n",
" for i in range(n_epochs):\n",
" self.optimizer.zero_grad(set_to_none=True)\n",
" y_hat = self.model(X_t)\n",
" loss_t = self.criterion(y_hat.squeeze(), y_t.squeeze())\n",
" loss_t.backward()\n",
" self.optimizer.step()\n",
" self.optimizer.zero_grad(set_to_none=True)\n",
" if (i + 1) % 10 == 0 or i == 0:\n",
" train_loss.append(loss_t.item())\n",
" train_acc.append(bin_acc(y_hat, y_t))\n",
" self.model.eval()\n",
" self.model.to(\"cpu\")\n",
" return train_loss, train_acc\n",
"\n",
"\n",
"def make_train(model, data_gen, B_t, B_v, n_epochs, device=\"cpu\", kind=\"MLP\", verbose=False, etta=1e-2):\n",
" X_train, y_train = data_gen.generate(B_t, verbose=False)\n",
" X_valid, y_valid = data_gen.generate(B_v, verbose=False)\n",
" optimizer = torch.optim.Adam(model.parameters(), lr=etta)\n",
" criterion = torch.nn.BCELoss()\n",
" train_eval = Trainer(model, optimizer, criterion)\n",
" model_np = get_n_parameters(model) # number of learnable parameters\n",
" verbose and print(f\"Number of model's learnable parameters: {model_np}\")\n",
" if kind == \"MLP\":\n",
" t_loss, t_acc = train_eval.do_train(n_epochs, X_train.float(), y_train, device=device)\n",
" v_loss, v_acc = train_eval.do_eval(X_valid.float(), y_valid, device=device)\n",
" else:\n",
" t_loss, t_acc = train_eval.do_train(n_epochs, X_train, y_train, device=device)\n",
" v_loss, v_acc = train_eval.do_eval(X_valid, y_valid, device=device)\n",
" verbose and print(f\"Training loss: {t_loss[-1]:.3f}, accuracy: {t_acc[-1]:.3f}\")\n",
" verbose and print(f\"Validation loss: {v_loss:.3f}, accuracy: {v_acc:.3f}\")\n",
" return t_loss[-1], t_acc[-1], v_loss, v_acc, model_np\n",
"\n",
"\n",
"def scaled_dot_product_attention_solution(Q, K, V):\n",
" \"\"\" Scaled dot product attention\n",
" Args:\n",
" Q: queries (B, H, d, n)\n",
" K: keys (B, H, d, n)\n",
" V: values (B, H, d, n)\n",
" Returns:\n",
" Attention tensor (B, H, d, n), Scores (B, H, d, d)\n",
" Notes:\n",
" (B, H, d, n): batch size, H: number of heads, d: key-query dim, n: embedding dim\n",
" \"\"\"\n",
"\n",
" assert K.shape == Q.shape and K.shape == V.shape, \"Queries, Keys and Values must have the same shape\"\n",
" B, H, d, n = K.shape # batch_size, num_heads, key-query dim, embedding dim\n",
" scale = math.sqrt(d)\n",
" Q_mm_K = torch.einsum(\"bhdn,bhen->bhde\", Q, K) # dot-product reducing the n dimension\n",
" S = Q_mm_K / scale # score or scaled dot product\n",
" S_sm = torch.softmax(S, dim=-1) # softmax\n",
" A = torch.einsum(\"bhde,bhen->bhdn\", S_sm, V) # Attention\n",
" return A, S\n",
"\n",
"\n",
"def weighted_attention(self, x):\n",
" \"\"\"This function computes the weighted attention as a method for the BinarySAT class\n",
"\n",
" Args:\n",
" x (Tensor): An array of shape (B:batch_size, T:context length) containing the input data\n",
"\n",
" Returns:\n",
" Tensor: weighted attention of shape (B, E, E) where E = T + 1\n",
" \"\"\"\n",
" assert self.n_heads == 1, \"This function is only implemented for a single head!\"\n",
" # Embedding\n",
" B = x.size(0) # batch size\n",
" x = self.toke(x) # token embedding\n",
" x = torch.cat([x, self.cls.expand(B, -1, -1)], dim=1) # concatenate cls token\n",
" x = x + self.pose # positional embedding\n",
" norm_x = self.norm1(x) # normalization\n",
"\n",
" # Scaled Dot-Product Attention (partially implemented)\n",
" q, k, v = self.qkv(norm_x).view(B, self.E, 3, self.d).unbind(dim=2)\n",
" # q, k, v all have shape (B, E, h, d) where h is the number of heads (1 in this case)\n",
" W_qk = q @ k.transpose(-2, -1)\n",
" W_qk = W_qk * self.scale\n",
" W_qk = torch.softmax(W_qk, dim=-1)\n",
" return W_qk\n",
"\n",
"\n",
"def results_dict(results_sat_d, results_sat_s, results_mlp_d, results_mlp_s):\n",
" from collections import OrderedDict\n",
" t_loss_sat_d, t_acc_sat_d, v_loss_sat_d, v_acc_sat_d, model_np_sat_d = results_sat_d\n",
" t_loss_sat_s, t_acc_sat_s, v_loss_sat_s, v_acc_sat_s, model_np_sat_s = results_sat_s\n",
" t_loss_mlp_d, t_acc_mlp_d, v_loss_mlp_d, v_acc_mlp_d, model_np_mlp_d = results_mlp_d\n",
" t_loss_mlp_s, t_acc_mlp_s, v_loss_mlp_s, v_acc_mlp_s, model_np_mlp_s = results_mlp_s\n",
" output = OrderedDict()\n",
" # output[\"Training Loss SAT dense\"] = t_loss_sat_d\n",
" # output[\"Training Accuracy SAT dense\"] = t_acc_sat_d\n",
" output[\"Validation Loss SAT dense\"] = round(v_loss_sat_d, 3)\n",
" output[\"Validation Accuracy SAT dense\"] = v_acc_sat_d\n",
" output[\"Number of Parameters SAT dense\"] = model_np_sat_d\n",
" # output[\"Training Loss SAT sparse\"] = t_loss_sat_s\n",
" # output[\"Training Accuracy SAT sparse\"] = t_acc_sat_s\n",
" output[\"Validation Loss SAT sparse\"] = round(v_loss_sat_s, 3)\n",
" output[\"Validation Accuracy SAT sparse\"] = v_acc_sat_s\n",
" output[\"Number of Parameters SAT sparse\"] = model_np_sat_s\n",
" # output[\"Training Loss MLP dense\"] = t_loss_mlp_d\n",
" # output[\"Training Accuracy MLP dense\"] = t_acc_mlp_d\n",
" output[\"Validation Loss MLP dense\"] = round(v_loss_mlp_d, 3)\n",
" output[\"Validation Accuracy MLP dense\"] = v_acc_mlp_d\n",
" output[\"Number of Parameters MLP dense\"] = model_np_mlp_d\n",
" # output[\"Training Loss MLP sparse\"] = t_loss_mlp_s\n",
" # output[\"Training Accuracy MLP sparse\"] = t_acc_mlp_s\n",
" output[\"Validation Loss MLP sparse\"] = round(v_loss_mlp_s, 3)\n",
" output[\"Validation Accuracy MLP sparse\"] = v_acc_mlp_s\n",
" output[\"Number of Parameters MLP sparse\"] = model_np_mlp_s\n",
" return output"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Set random seed\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"execution": {},
"executionInfo": {
"elapsed": 6,
"status": "ok",
"timestamp": 1718046717502,
"user": {
"displayName": "Xaq Pitkow",
"userId": "09050806329892245378"
},
"user_tz": -180
},
"tags": [
"hide-input"
]
},
"outputs": [],
"source": [
"# @title Set random seed\n",
"\n",
"def set_seed(seed=None, seed_torch=True):\n",
" \"\"\"\n",
" Handles variability by controlling sources of randomness\n",
" through set seed values\n",
"\n",
" Args:\n",
" seed: Integer\n",
" Set the seed value to given integer.\n",
" If no seed, set seed value to random integer in the range 2^32\n",
" seed_torch: Bool\n",
" Seeds the random number generator for all devices to\n",
" offer some guarantees on reproducibility\n",
"\n",
" Returns:\n",
" Nothing\n",
" \"\"\"\n",
" if seed is None:\n",
" seed = np.random.choice(2 ** 32)\n",
" random.seed(seed)\n",
" np.random.seed(seed)\n",
" if seed_torch:\n",
" torch.manual_seed(seed)\n",
" torch.cuda.manual_seed_all(seed)\n",
" torch.cuda.manual_seed(seed)\n",
" torch.backends.cudnn.benchmark = False\n",
" torch.backends.cudnn.deterministic = True"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Set device (GPU or CPU)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"execution": {},
"tags": [
"hide-input"
]
},
"outputs": [],
"source": [
"# @title Set device (GPU or CPU)\n",
"\n",
"def set_device():\n",
" device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
" if device != \"cuda\":\n",
" print(\"GPU is not enabled in this notebook. \\n\"\n",
" \"If you want to enable it, in the menu under `Runtime` -> \\n\"\n",
" \"`Hardware accelerator.` and select `GPU` from the dropdown menu\")\n",
" else:\n",
" print(\"GPU is enabled in this notebook. \\n\"\n",
" \"If you want to disable it, in the menu under `Runtime` -> \\n\"\n",
" \"`Hardware accelerator.` and select `None` from the dropdown menu\")\n",
"\n",
" return device\n",
"\n",
"DEVICE = set_device()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Video 1: Introduction to attention\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"execution": {},
"tags": [
"remove-input"
]
},
"outputs": [],
"source": [
"# @title Video 1: Introduction to attention\n",
"\n",
"from ipywidgets import widgets\n",
"from IPython.display import YouTubeVideo\n",
"from IPython.display import IFrame\n",
"from IPython.display import display\n",
"\n",
"\n",
"class PlayVideo(IFrame):\n",
" def __init__(self, id, source, page=1, width=400, height=300, **kwargs):\n",
" self.id = id\n",
" if source == 'Bilibili':\n",
" src = f'https://player.bilibili.com/player.html?bvid={id}&page={page}'\n",
" elif source == 'Osf':\n",
" src = f'https://mfr.ca-1.osf.io/render?url=https://osf.io/download/{id}/?direct%26mode=render'\n",
" super(PlayVideo, self).__init__(src, width, height, **kwargs)\n",
"\n",
"\n",
"def display_videos(video_ids, W=400, H=300, fs=1):\n",
" tab_contents = []\n",
" for i, video_id in enumerate(video_ids):\n",
" out = widgets.Output()\n",
" with out:\n",
" if video_ids[i][0] == 'Youtube':\n",
" video = YouTubeVideo(id=video_ids[i][1], width=W,\n",
" height=H, fs=fs, rel=0)\n",
" print(f'Video available at https://youtube.com/watch?v={video.id}')\n",
" else:\n",
" video = PlayVideo(id=video_ids[i][1], source=video_ids[i][0], width=W,\n",
" height=H, fs=fs, autoplay=False)\n",
" if video_ids[i][0] == 'Bilibili':\n",
" print(f'Video available at https://www.bilibili.com/video/{video.id}')\n",
" elif video_ids[i][0] == 'Osf':\n",
" print(f'Video available at https://osf.io/{video.id}')\n",
" display(video)\n",
" tab_contents.append(out)\n",
" return tab_contents\n",
"\n",
"video_ids = [('Youtube', 'Ryr04zXhdIw'), ('Bilibili', 'BV1N4421D7NM')]\n",
"tab_contents = display_videos(video_ids, W=730, H=410)\n",
"tabs = widgets.Tab()\n",
"tabs.children = tab_contents\n",
"for i in range(len(tab_contents)):\n",
" tabs.set_title(i, video_ids[i][0])\n",
"display(tabs)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Submit your feedback\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"execution": {},
"tags": [
"hide-input"
]
},
"outputs": [],
"source": [
"# @title Submit your feedback\n",
"content_review(f\"{feedback_prefix}_introduction_to_attention\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"execution": {}
},
"source": [
"---\n",
"# Section 1: Intro to multiplication\n",
"\n",
"In this section, we will show how signals can be used to gate inputs selectively. This is the essence of attention.\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Video 2: Cross-attention\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"execution": {},
"tags": [
"remove-input"
]
},
"outputs": [],
"source": [
"# @title Video 2: Cross-attention\n",
"\n",
"from ipywidgets import widgets\n",
"from IPython.display import YouTubeVideo\n",
"from IPython.display import IFrame\n",
"from IPython.display import display\n",
"\n",
"\n",
"class PlayVideo(IFrame):\n",
" def __init__(self, id, source, page=1, width=400, height=300, **kwargs):\n",
" self.id = id\n",
" if source == 'Bilibili':\n",
" src = f'https://player.bilibili.com/player.html?bvid={id}&page={page}'\n",
" elif source == 'Osf':\n",
" src = f'https://mfr.ca-1.osf.io/render?url=https://osf.io/download/{id}/?direct%26mode=render'\n",
" super(PlayVideo, self).__init__(src, width, height, **kwargs)\n",
"\n",
"\n",
"def display_videos(video_ids, W=400, H=300, fs=1):\n",
" tab_contents = []\n",
" for i, video_id in enumerate(video_ids):\n",
" out = widgets.Output()\n",
" with out:\n",
" if video_ids[i][0] == 'Youtube':\n",
" video = YouTubeVideo(id=video_ids[i][1], width=W,\n",
" height=H, fs=fs, rel=0)\n",
" print(f'Video available at https://youtube.com/watch?v={video.id}')\n",
" else:\n",
" video = PlayVideo(id=video_ids[i][1], source=video_ids[i][0], width=W,\n",
" height=H, fs=fs, autoplay=False)\n",
" if video_ids[i][0] == 'Bilibili':\n",
" print(f'Video available at https://www.bilibili.com/video/{video.id}')\n",
" elif video_ids[i][0] == 'Osf':\n",
" print(f'Video available at https://osf.io/{video.id}')\n",
" display(video)\n",
" tab_contents.append(out)\n",
" return tab_contents\n",
"\n",
"video_ids = [('Youtube', 'Jkt4w98sP1U'), ('Bilibili', 'BV1Rw4m1e7RZ')]\n",
"tab_contents = display_videos(video_ids, W=730, H=410)\n",
"tabs = widgets.Tab()\n",
"tabs.children = tab_contents\n",
"for i in range(len(tab_contents)):\n",
" tabs.set_title(i, video_ids[i][0])\n",
"display(tabs)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Submit your feedback\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"execution": {},
"tags": [
"hide-input"
]
},
"outputs": [],
"source": [
"# @title Submit your feedback\n",
"content_review(f\"{feedback_prefix}_cross_attention\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"execution": {}
},
"source": [
"## Exercise 1: Dot product attention\n",
"\n",
"We'll implement simple dot product attention for input $\\mathbf{x}$ and scalar output $y$ given by a weighted combination of the inputs,\n",
"\n",
"$$y = \\mathbf{w}\\cdot\\mathbf{x}$$\n",
"\n",
"for weights $\\mathbf{w}$ (also called gains). Unlike the fixed weights in an MLP, attention can adjust the weights depending on some other signal. Here we will use a two dimensional attentional gain, $z_1$ and $z_2$, each with a corresponding vector $\\mathbf{q}$ that determines the spatial pattern to be attended:\n",
"\n",
"$$\\mathbf{w}(z_1,z_2) = \\text{softmax}(z_1 \\mathbf{q}_1 + z_2 \\mathbf{q}_2)$$\n",
"\n",
"where $\\text{softmax}(\\mathbf{x})={e^\\mathbf{x}}/{\\sum_i e^{x_i}}$ is a combination of a sparsifying element-wise nonlinearity ($\\exp$) and a normalization.\n",
"\n",
"You should code up this function to return the attentional weights $\\mathbf{w}$ and the weighted output $y$.\n",
"\n",
"