Creator: Francois Fleuret (original)
Flow diagram of single-head attention illustrating the equation with border colors to indicate tensor dimensions.
% Written by Francois Fleuret <francois@fleuret.org>
% https://twitter.com/francoisfleuret/status/1529744066086424577.
% Original TikZ source: https://fleuret.org/git-extract/tex/single-attention.tex
% Any copyright is dedicated to the Public Domain.
% https://creativecommons.org/publicdomain/zero/1.0
\documentclass[tikz]{standalone}
\usepackage{mathtools}
\def\transpose{^{\top}}
\DeclareMathOperator\softmax{softmax}
\DeclareMathOperator\Attention{Attention}
\usetikzlibrary{positioning, arrows.meta}
\begin{document}
\begin{tikzpicture}[
value/.style = {
font=\scriptsize, rectangle, draw=black!50, fill=white, thick,
inner sep=3pt, inner xsep=2pt, minimum size=10pt, minimum height=20pt
},
parameter/.style = {
font=\scriptsize, rectangle, draw=black!50, fill=lightblue!15, thick,
inner sep=0pt, inner xsep=2pt, minimum size=10pt, minimum height=20pt
},
operation/.style = {
font=\scriptsize, rectangle, draw=black!50, fill=teal!30, thick,
inner sep=3pt, minimum size=10pt, minimum height=20pt
},
flow/.style={->,shorten <= 1pt,shorten >= 1pt, draw=black!50, thick},
f2f/.style={draw=black!50, thick},
v2f/.style={{Bar[width=1.5mm]}-,shorten <= 0.75pt,draw=black!50, thick},
f2v/.style={->,shorten >= 0.75pt,draw=black!50, thick}
]
\node[font=\bfseries] at (3.5, 2.3) {Single-head attention};
\node at (3.5, -2.5) {$\displaystyle \Attention(Q, K, V) = \softmax_\text{row} \left( \frac{Q K\transpose}{\sqrt{d}} \right) V$};
\node[value, minimum height=0.8cm,minimum width=0.7cm] (K) at (0, 0) {$K$};
\node[value, minimum height=1.2cm,minimum width=0.7cm] (Q) [above=0.5cm of K] {$Q$};
\node[value, minimum height=0.8cm,minimum width=1.0cm] (V) [below=0.5cm of K] {$V$};
\node[operation,minimum height=0.4cm,minimum width=0.4cm] (att) [right=0.5cm of K] {$\cdot\transpose$};
\node[operation,minimum height=0.4cm,minimum width=0.4cm] (sm) [right=0.25cm of att] {$\softmax$};
\node[value, minimum height=1.2cm,minimum width=0.8cm] (A) [right=0.5cm of sm] {$A$};
\node[operation,minimum height=0.4cm,minimum width=0.4cm] (prod) [right=0.5cm of A] {$\cdot$};
\node[value, minimum height=1.2cm,minimum width=1.0cm] (Y) [right=0.5cm of prod] {$Y$};
\draw[v2f,rounded corners=1mm] (K) -- (att);
\draw[v2f,rounded corners=1mm] (Q) -| (att);
\draw[f2f,rounded corners=1mm] (att) -- (sm);
\draw[f2v,rounded corners=1mm] (sm) -- ([xshift=-1pt]A.west);
\draw[v2f,rounded corners=1mm] (A) -- (prod);
\draw[v2f,rounded corners=1mm] (V) -| (prod);
\draw[f2v,rounded corners=1mm] (prod) -- ([xshift=-1pt]Y.west);
\draw[very thick,yellow] ([yshift=1pt]Q.north west) -- ([yshift=1pt]Q.north east);
\draw[very thick,yellow] ([yshift=1pt]K.north west) -- ([yshift=1pt]K.north east);
\draw[very thick,orange] ([yshift=1pt]V.north west) -- ([yshift=1pt]V.north east);
\draw[very thick,orange] ([yshift=1pt]Y.north west) -- ([yshift=1pt]Y.north east);
\draw[very thick,red] ([xshift=-1pt]V.north west) -- ([xshift=-1pt]V.south west);
\draw[very thick,red] ([xshift=-1pt]K.north west) -- ([xshift=-1pt]K.south west);
\draw[very thick,cyan] ([xshift=-1pt]Q.north west) -- ([xshift=-1pt]Q.south west);
\draw[very thick,cyan] ([xshift=-1pt]Y.north west) -- ([xshift=-1pt]Y.south west);
\draw[very thick,cyan] ([xshift=-1pt]A.north west) -- ([xshift=-1pt]A.south west);
\draw[very thick,red] ([yshift=1pt]A.north west) -- ([yshift=1pt]A.north east);
\end{tikzpicture}
\end{document}