<?xml version="1.0" encoding="utf-8"?><feed xmlns="http://www.w3.org/2005/Atom" ><generator uri="https://jekyllrb.com/" version="3.10.0">Jekyll</generator><link href="https://pchen7e2.github.io/feed.xml" rel="self" type="application/atom+xml" /><link href="https://pchen7e2.github.io/" rel="alternate" type="text/html" /><updated>2026-06-09T16:28:56+00:00</updated><id>https://pchen7e2.github.io/feed.xml</id><title type="html">Peng’s Tech Blog</title><subtitle>Notes and logs for GPU, HPC, AI compilers etc.</subtitle><author><name>Peng Chen</name></author><entry><title type="html">The Shared Memory layout of Blackwell MMAv5 operands</title><link href="https://pchen7e2.github.io/2026/06/07/mmav5-smem-layout.html" rel="alternate" type="text/html" title="The Shared Memory layout of Blackwell MMAv5 operands" /><published>2026-06-07T00:00:00+00:00</published><updated>2026-06-07T00:00:00+00:00</updated><id>https://pchen7e2.github.io/2026/06/07/mmav5-smem-layout</id><content type="html" xml:base="https://pchen7e2.github.io/2026/06/07/mmav5-smem-layout.html"><![CDATA[<p>The 5th gen Tensor Core on Blackwell GPU requires MMA’s operand B live in SMEM, and operand A live 
in SMEM or Tensor Memory (TMEM). MMAv3 (Hopper) also supports “SS_GEMM”, where both A and B are in SMEM, and its layout is almost the same as v5’s.</p>

<p>In this note we analyse the SMEM layouts and examine how CUTLASS and Triton represent them. The non-swizzling cases 
need special handling and are omitted for simplicity. We also only cover swizzling with 16B atomicity.</p>

<h2 id="core-matrix-vs-swizzle-atom">Core Matrix vs Swizzle Atom</h2>

<h3 id="core-matrix">Core Matrix</h3>
<p>“Core matrix” is a deprecated term that no longer appears in the official documentation. It was used to help 
define <code>Leading Dimension Byte Offset (LBO)</code> and <code>Strided Dimension Byte Offset (SBO)</code>, two important 
parameters that describe the SMEM layout to the hardware so that the Tensor Core knows where to find its operands.</p>

<p>For example, in this <a href="https://research.colfax-intl.com/cutlass-tutorial-wgmma-hopper/">Colfax tutorial</a>,</p>

<blockquote>
  <p>Each core matrix has a strided direction and a contiguous direction, such that its length is 8 in the strided direction and 16 bytes in the contiguous direction.
LBO (leading dimension byte offset): the distance, in bytes, between two adjacent core matrices in the K dimension.
SBO (stride dimension byte offset): the distance, in bytes, between two adjacent core matrices in the M or N dimension.</p>
</blockquote>

<p>which may be correct for the specific MMA instance in that tutorial, but doesn’t cover MN-major cases.</p>

<p>As shown later in this note, the concept of a core matrix is no longer needed: the Swizzle Atom alone is enough 
to define SBO and LBO.</p>

<h3 id="swizzle-atom">Swizzle Atom</h3>
<p>As of Jun 2026, the official PTX documentation defines the SMEM layout with the concept of <a href="https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#tcgen05-shared-memory-layout-swizzling">“Swizzle Atom”</a>. 
A Swizzle Atom in <code>s</code>-byte swizzling mode is a matrix of <code>8 * s Bytes</code>, where <code>8</code> is the extent along the strided dimension 
(e.g. the M/N dim for K-major) and <code>s Bytes</code> is the extent along the leading dim.</p>

<p>All the elements in a Swizzle Atom are stored compactly in one contiguous segment of physical SMEM. Swizzling therefore never “exchanges” elements between two Swizzle Atoms; it only permutes data inside a single Atom. Also 
note that the basic unit of swizzling is 128 bits, i.e. 16 bytes (in 16B atomicity mode). No “exchange” happens inside a single unit.</p>

<p><strong>SBO</strong> is then defined as the byte offset between two adjacent Swizzle Atoms along the strided dim.</p>

<p><strong>LBO</strong> is defined as the byte offset between two adjacent Swizzle Atoms along the leading dim. Note that for swizzled K-major layouts LBO is ignored: 
the PTX doc describes it as “not used, assumed to be 1”.</p>
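<p>As a rough illustration of these two definitions, the sketch below computes the Atom size and the resulting SBO/LBO. It assumes the Atoms are packed back to back with no padding and laid out along the strided dim first; that packing, and the helper names <code>atom_bytes</code> and <code>packed_offsets</code>, are assumptions made for this example and not something the hardware requires.</p>

```python
ROWS_PER_ATOM = 8   # extent of a Swizzle Atom along the strided dim


def atom_bytes(swizzle_bytes):
    """Size of one Swizzle Atom: 8 rows of swizzle_bytes each."""
    return ROWS_PER_ATOM * swizzle_bytes


def packed_offsets(swizzle_bytes, atoms_along_strided):
    """SBO and LBO (in bytes) for the assumed strided-first, gap-free packing."""
    sbo = atom_bytes(swizzle_bytes)        # next Atom along the strided dim
    lbo = atoms_along_strided * sbo        # next Atom along the leading dim
    return sbo, lbo


for s in (32, 64, 128):
    print(s, atom_bytes(s), packed_offsets(s, atoms_along_strided=4))
```

<p>With a different packing (for example, Atoms laid out along the leading dim first) the two offsets would change, which is why they are passed to the hardware as parameters.</p>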

<h3 id="triton">Triton</h3>
<p>The Triton compiler has a function called <a href="https://github.com/triton-lang/triton/blob/2104a207c0595da7d099dd320967afd0fc41f70d/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp#L153">getCoreMatrixLinearLayout</a>.
Despite its name, it actually returns the Linear Layout tile of a Swizzle Atom. Since a Linear Layout incorporates 
swizzling, the output of this function already encodes the full swizzled layout of such an Atom. For example:</p>

<pre><code>Full tensor shape: 128, 256
Layout encoding: #ttg.nvmma_shared&lt;{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16}&gt;
getCoreMatrixLinearLayout output: 
 - offset=1 -&gt; (0, 1)
   offset=2 -&gt; (0, 2)
   offset=4 -&gt; (0, 4)
   offset=8 -&gt; (0, 8)
   offset=16 -&gt; (0, 16)
   offset=32 -&gt; (0, 32)
   offset=64 -&gt; (1, 8)
   offset=128 -&gt; (2, 16)
   offset=256 -&gt; (4, 32)
where out dims are: [dim0 (size 8), dim1 (size 64)]
</code></pre>
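<p>The basis vectors printed above can be reproduced with a few lines of Python. This is a minimal sketch, not Triton code: <code>atom_coord</code> is a hypothetical helper that assumes 16-bit elements and 128B swizzling, and models the swizzle as an XOR between the 16-byte unit index and the row index.</p>

```python
# Hypothetical helper, not Triton code: map an element offset inside one
# 128B-swizzled Atom (16-bit elements) to its logical (row, col) coordinate.
ELEMS_PER_ROW = 64    # 128 bytes / 2 bytes per element
ELEMS_PER_UNIT = 8    # 16 bytes / 2 bytes per element
ROWS = 8


def atom_coord(offset):
    row = offset // ELEMS_PER_ROW
    unit = (offset % ELEMS_PER_ROW) // ELEMS_PER_UNIT   # 16B unit inside the row
    within = offset % ELEMS_PER_UNIT                    # element inside the unit
    logical_unit = unit ^ (row % ROWS)                  # XOR swizzle with the row index
    return row, logical_unit * ELEMS_PER_UNIT + within


# Images of the power-of-two offsets, i.e. the bases of the linear layout.
bases = {2 ** k: atom_coord(2 ** k) for k in range(9)}
for offset, coord in bases.items():
    print(f"offset={offset} maps to {coord}")
```

<p>Because the map is linear over XOR, the coordinate of any other offset is the XOR of the bases selected by its set bits, which is the property a Linear Layout relies on.</p>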

<h2 id="smem-layouts-for-an-mma-instruction">SMEM Layouts for an MMA instruction</h2>

<p>The PTX <a href="https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#tcgen05-canonical-layouts">documentation</a> 
lists the canonical CuTe layouts for the SMEM tensor of a single MMA instruction:</p>

<table>
  <thead>
    <tr>
      <th>Major-ness</th>
      <th>Swizzling mode</th>
      <th>Canonical Layout without swizzling</th>
      <th>Swizzling on the previous column</th>
    </tr>
  </thead>
  <tbody>
    <tr>
      <td>MN-major</td>
      <td>No-swizzling or Interleaved</td>
      <td>((T,1,m),(8,k)):((1,T,SBO),(1T,LBO))</td>
      <td>Swizzle&lt;0, 4, 3&gt;</td>
    </tr>
    <tr>
      <td> </td>
      <td>32B Swizzling</td>
      <td>((T,2,m),(8,k)):((1,T,LBO),(2T,SBO))</td>
      <td>Swizzle&lt;1, 4, 3&gt;</td>
    </tr>
    <tr>
      <td> </td>
      <td>64B Swizzling</td>
      <td>((T,4,m),(8,k)):((1,T,LBO),(4T,SBO))</td>
      <td>Swizzle&lt;2, 4, 3&gt;</td>
    </tr>
    <tr>
      <td> </td>
      <td>128B Swizzling</td>
      <td>((T,8,m),(8,k)):((1,T,LBO),(8T,SBO))</td>
      <td>Swizzle&lt;3, 4, 3&gt;</td>
    </tr>
    <tr>
      <td>K-major*</td>
      <td>No-swizzling or Interleaved</td>
      <td>((8,m),(T,2k)):((1T,SBO),(1,LBO))</td>
      <td>Swizzle&lt;0, 4, 3&gt;</td>
    </tr>
    <tr>
      <td> </td>
      <td>32B Swizzling</td>
      <td>((8,m),(T,2k)):((2T,SBO),(1,T))</td>
      <td>Swizzle&lt;1, 4, 3&gt;</td>
    </tr>
    <tr>
      <td> </td>
      <td>64B Swizzling</td>
      <td>((8,m),(T,2k)):((4T,SBO),(1,T))</td>
      <td>Swizzle&lt;2, 4, 3&gt;</td>
    </tr>
    <tr>
      <td> </td>
      <td>128B Swizzling</td>
      <td>((8,m),(T,2k)):((8T,SBO),(1,T))</td>
      <td>Swizzle&lt;3, 4, 3&gt;</td>
    </tr>
  </tbody>
</table>

<ul>
  <li>T = 128 / sizeof-elements-in-bits. T is the scale factor that normalizes matrix element types to 128 bits.</li>
  <li>m represents the number of repeating patterns across rows.</li>
  <li>k represents the number of repeating patterns across columns.</li>
</ul>

<p>* As shown later in this note, the factor <code>k</code> in K-major layout is in fact not needed and should be dropped.</p>
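<p>To make the scale factor T concrete, the sketch below evaluates it for a few element widths and checks that the K-major row strides in the table (2T, 4T, 8T) are simply the swizzling byte width expressed in elements. The function names are made up for this example.</p>

```python
def scale_factor(element_bits):
    """T = 128 / sizeof-elements-in-bits."""
    return 128 // element_bits


def k_major_row_stride(swizzle_bytes, element_bits):
    """Row stride in elements: one row of a Swizzle Atom is swizzle_bytes long."""
    return swizzle_bytes * 8 // element_bits


for bits in (8, 16, 32):
    T = scale_factor(bits)
    # Express each row stride as a multiple of T; the table expects 2, 4 and 8.
    multiples = {s: k_major_row_stride(s, bits) // T for s in (32, 64, 128)}
    print(f"{bits}-bit elements: T={T}, row stride / T = {multiples}")
```

<p>The multiples do not depend on the element type, which is the point of normalizing everything to 128-bit units.</p>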

<h3 id="mn-major">MN major</h3>
<p><img src="/assets/img/MMAv5-SMEM-MNMajor.png" alt="MN-major layout" /></p>

<p>(The figure is drawn as M/N x K following CUTLASS convention)</p>

<p>The Canonical CuTe layout in PTX documentation is consistent with CUTLASS, and is shown in the figure.</p>

<p>Inside a Swizzle Atom, there are 8 columns adjacent to each other in physical memory. Each column is a contiguous 
segment of physical memory whose size equals the swizzling byte width.</p>

<p>It’s up to the user how to distribute the Swizzle Atoms (along the MN or K dim first). As long as LBO and 
SBO are provided, the hardware knows where in physical memory to look for the desired Atoms. Note that both LBO and SBO are 
needed because the hardware needs to know where in physical memory to load Swizzle Atoms y, z and others.</p>

<p>Per the <code>getCoreMatrixLinearLayout</code> function above, Triton always distributes the Atoms along the strided dim first, up to a
TMA block.</p>

<h3 id="k-major">K major</h3>
<p><img src="/assets/img/MMAv5-SMEM-KMajor.png" alt="K-major layout" /></p>

<p>The Canonical CuTe layout in the PTX documentation differs from CUTLASS in that CUTLASS drops the factor <code>k</code>. We adopt 
CUTLASS’s layouts as the source of truth, with confirmation from Nvidia.</p>

<p>Inside a Swizzle Atom, there are 8 rows adjacent to each other in physical memory. Each row is a contiguous
segment of physical memory whose size equals the swizzling byte width.</p>

<p>It’s up to the user how to distribute the Swizzle Atoms (along the MN or K dim first). For 64B/128B swizzling, the SMEM that one MMA “unit” 
reads along K is narrower than a Swizzle Atom, so each row in a Swizzle Atom contains elements from 2/4 different MMA instructions. 
This is OK because swizzling deterministically tells the hardware the exact location of each element. For example, for 128B 
swizzling this table shows where in physical SMEM to find the 8*2 16-byte units (each holding T elements) of the first MMA instruction:</p>

<table>
  <thead>
    <tr>
      <th>Physical Offset</th>
      <th>16B</th>
      <th>16B</th>
      <th>16B</th>
      <th>16B</th>
      <th>16B</th>
      <th>16B</th>
      <th>16B</th>
      <th>16B</th>
    </tr>
  </thead>
  <tbody>
    <tr>
      <td>0 ~ 127B</td>
      <td>0</td>
      <td>1</td>
      <td> </td>
      <td> </td>
      <td> </td>
      <td> </td>
      <td> </td>
      <td> </td>
    </tr>
    <tr>
      <td>128 ~ 255B</td>
      <td>3</td>
      <td>2</td>
      <td> </td>
      <td> </td>
      <td> </td>
      <td> </td>
      <td> </td>
      <td> </td>
    </tr>
    <tr>
      <td>256 ~ 383B</td>
      <td> </td>
      <td> </td>
      <td>4</td>
      <td>5</td>
      <td> </td>
      <td> </td>
      <td> </td>
      <td> </td>
    </tr>
    <tr>
      <td>…</td>
      <td> </td>
      <td> </td>
      <td>7</td>
      <td>6</td>
      <td> </td>
      <td> </td>
      <td> </td>
      <td> </td>
    </tr>
    <tr>
      <td> </td>
      <td> </td>
      <td> </td>
      <td> </td>
      <td> </td>
      <td>8</td>
      <td>9</td>
      <td> </td>
      <td> </td>
    </tr>
    <tr>
      <td> </td>
      <td> </td>
      <td> </td>
      <td> </td>
      <td> </td>
      <td>11</td>
      <td>10</td>
      <td> </td>
      <td> </td>
    </tr>
    <tr>
      <td> </td>
      <td> </td>
      <td> </td>
      <td> </td>
      <td> </td>
      <td> </td>
      <td> </td>
      <td>12</td>
      <td>13</td>
    </tr>
    <tr>
      <td> </td>
      <td> </td>
      <td> </td>
      <td> </td>
      <td> </td>
      <td> </td>
      <td> </td>
      <td>15</td>
      <td>14</td>
    </tr>
  </tbody>
</table>
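<p>The placement in the table above can be reproduced with a short sketch. It assumes that, within one 128B-swizzled Atom, the physical position of a 16-byte unit is its logical unit index XORed with the row index (the effect of <code>Swizzle&lt;3, 4, 3&gt;</code> on byte offsets), and that the first MMA instruction reads the first two logical units of each row. The helper names are hypothetical.</p>

```python
ROWS = 8                # rows in one Swizzle Atom
UNITS_PER_ROW = 8       # 128-byte row / 16-byte unit
UNITS_PER_MMA_ROW = 2   # the first MMA instruction reads 2 units (32 bytes) per row


def physical_slot(row, logical_unit):
    """Position of a logical 16B unit inside its 128-byte physical row."""
    return logical_unit ^ (row % ROWS)


def first_mma_table():
    """8 x 8 grid of physical slots; None marks slots the first MMA does not read."""
    table = [[None] * UNITS_PER_ROW for _ in range(ROWS)]
    for row in range(ROWS):
        for u in range(UNITS_PER_MMA_ROW):
            label = row * UNITS_PER_MMA_ROW + u   # numbering used in the table above
            table[row][physical_slot(row, u)] = label
    return table


for row, slots in enumerate(first_mma_table()):
    cells = ["." if s is None else str(s) for s in slots]
    print(str(row * 128).rjust(4) + "B  " + " ".join(c.rjust(2) for c in cells))
```

<p>Changing the range of <code>u</code> to the next pair of logical units would give the slots read by the second MMA instruction along K, which fill in part of the empty cells.</p>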

<p>Each MMA instruction’s SMEM operand spans multiple Swizzle Atoms along the MN dim, but only one Atom (or part of one) 
along the K dim. So, given the location of Atom x, the hardware needs SBO to know where to load Atom z and the others along the MN dim.
However, LBO is not needed (except in the non-swizzling cases, not shown in the figure) because, for example, Atom y is not 
needed by the MMA instruction in the figure.</p>

<p>Per the <code>getCoreMatrixLinearLayout</code> function above, Triton always distributes the Atoms along the strided dim first, up to a 
TMA block.</p>

<h1 id="references">References</h1>
<ul>
  <li><a href="https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#tcgen05-shared-memory-layout-swizzling">PTX ISA Documentation v9.3</a></li>
  <li><a href="https://research.colfax-intl.com/cutlass-tutorial-wgmma-hopper/">Colfax CUTLASS Tutorial: Fast Matrix-Multiplication with WGMMA on NVIDIA® Hopper™ GPUs</a></li>
  <li>CUTLASS source code for <a href="https://github.com/NVIDIA/cutlass/blob/cb37157db50d0528c4aea99feb37946ec278e3d9/include/cute/atom/mma_traits_sm100.hpp#L171">SM100 UMMA descriptors</a></li>
  <li>Triton source code for <a href="https://github.com/triton-lang/triton/blob/2104a207c0595da7d099dd320967afd0fc41f70d/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp#L153">NVMMAShared encoding to Linear Layout conversion</a></li>
</ul>

<p>Special thanks to Bingyi Zhang from Nvidia! A great amount of details and nuances in this note came from
extensive discussions and collaborations with Bingyi.</p>]]></content><author><name>Peng Chen</name></author><category term="Other" /><summary type="html"><![CDATA[The 5th gen Tensor Core on Blackwell GPU requires MMA’s operand B live in SMEM, and operand A live in SMEM or Tensor Memory (TMEM). MMAv3 (Hopper) supports “SS_GEMM” as well where A and B are both in SMEM. The layout is almost the same as v5.]]></summary></entry><entry><title type="html">Deriving formula for Backward Gradient of Matrix Multiplication</title><link href="https://pchen7e2.github.io/2026/04/12/matmul-backward.html" rel="alternate" type="text/html" title="Deriving formula for Backward Gradient of Matrix Multiplication" /><published>2026-04-12T00:00:00+00:00</published><updated>2026-04-12T00:00:00+00:00</updated><id>https://pchen7e2.github.io/2026/04/12/matmul-backward</id><content type="html" xml:base="https://pchen7e2.github.io/2026/04/12/matmul-backward.html"><![CDATA[<p>Reference: https://cs231n.stanford.edu/handouts/derivatives.pdf</p>

<p>In this short note, we derive the backward gradient of a matrix multiplication with minimal math background.</p>

<h1 id="forward-pass">Forward Pass</h1>

\[Y = WX\]

<p>Matrix shapes:</p>

\[Y \in \mathbb{R}^{M \times K}\]

\[W \in \mathbb{R}^{M \times N}\]

\[X \in \mathbb{R}^{N \times K}\]

<hr />

<h1 id="meaning-of-dy">Meaning of dY</h1>

<p>In backward prop, we are given an input with known values and the same shape as $Y$:</p>

\[dY = \frac{\partial L}{\partial Y}\]

<p>where $L$ is a scalar loss.</p>

<p>Thus:</p>

\[dY[i][j] = \frac{\partial L}{\partial Y[i][j]}\]

<p>Intuitively, $dY[i][j]$ measures how much the loss $L$ changes if $Y[i][j]$ increases slightly.</p>

\[dY[i][j] \approx \frac{L(Y[i][j] + h) - L(Y[i][j])}{h}\]

<p>for a very small $h$.</p>

<hr />

<h1 id="goal">Goal</h1>

<p>We want to compute</p>

\[dX = \frac{\partial L}{\partial X}\]

<p>Each element</p>

\[dX[i][j] = \frac{\partial L}{\partial X[i][j]}\]

<p>represents how much the loss changes if $X[i][j]$ increases.</p>

<hr />

<h1 id="chain-rule">Chain Rule</h1>

<p>To compute a single element $dX[i][j]$ in $dX$, we apply the chain rule through $Y$:</p>

\[dX[i][j]= \frac{\partial L}{\partial Y}
\cdot
\frac{\partial Y}{\partial X[i][j]} =
\sum_{a=0}^{M-1} \sum_{b=0}^{K-1}
\frac{\partial L}{\partial Y[a][b]}
\cdot
\frac{\partial Y[a][b]}{\partial X[i][j]}\]

<p>Note that both $\frac{\partial L}{\partial Y}$ and $\frac{\partial Y}{\partial X[i][j]}$ have the same shape as $Y$, and
$dX[i][j]$ is the sum of their elementwise product.</p>

<p>Also note that the element at index $[a][b]$ in $\frac{\partial Y}{\partial X[i][j]}$ is just $\frac{\partial Y[a][b]}{\partial X[i][j]}$.</p>

<p>Using the shorthand $dY[a][b] = \frac{\partial L}{\partial Y[a][b]}$:</p>

\[dX[i][j]=
\sum_{a,b}
dY[a][b]
\cdot
\frac{\partial Y[a][b]}{\partial X[i][j]}\]

<p>Intuitively:</p>

<ul>
  <li>$\frac{\partial Y[a][b]}{\partial X[i][j]}$ measures how $X[i][j]$ affects $Y[a][b]$</li>
  <li>$dY[a][b]$ measures how $Y[a][b]$ affects the loss $L$</li>
</ul>

<p>Multiplying them gives the influence path</p>

\[X[i][j] \rightarrow Y[a][b] \rightarrow L\]

<p>Summing over all $a,b$ aggregates all such paths, giving the total influence of $X[i][j]$ on $L$.</p>

<hr />

<h1 id="expand-the-forward-definition">Expand the Forward Definition</h1>

<p>From matrix multiplication:</p>

\[Y[a][b] =
\sum_{c=0}^{N-1}
W[a][c] \cdot X[c][b]\]

<p>That is:</p>

\[Y[a][b]=
W[a][0]X[0][b] + W[a][1]X[1][b] + \dots + W[a][N-1]X[N-1][b]\]

<hr />

<h1 id="dependency-observation">Dependency Observation</h1>

<p>If $b \ne j$, then $Y[a][b]$ <strong>does not depend on</strong> $X[i][j]$. It only depends on row $a$ of $W$ and column $b$ of $X$.</p>

<p>Therefore:</p>

\[\frac{\partial Y[a][b]}{\partial X[i][j]} = 0\]

<p>Thus only terms where $b = j$ contribute.</p>

<p>The formula simplifies from</p>

\[dX[i][j]=
\sum_{a,b}
dY[a][b]
\cdot
\frac{\partial Y[a][b]}{\partial X[i][j]}\]

<p>to</p>

\[dX[i][j]=
\sum_{a=0}^{M-1}
dY[a][j]
\cdot
\frac{\partial Y[a][j]}{\partial X[i][j]}\]

<hr />

<h1 id="compute-the-partial-derivative">Compute the Partial Derivative</h1>

<p>Consider</p>

\[Y[a][j]=
W[a][0]X[0][j] + W[a][1]X[1][j] + \dots + W[a][i]X[i][j] + \dots\]

<p>The <strong>only term involving $X[i][j]$</strong> is</p>

\[W[a][i]X[i][j]\]

<p>Therefore</p>

\[\frac{\partial Y[a][j]}{\partial X[i][j]} = W[a][i]\]

<p>Then</p>

\[dX[i][j]=
\sum_{a=0}^{M-1}
dY[a][j]
\cdot
\frac{\partial Y[a][j]}{\partial X[i][j]}\]

<p>becomes</p>

\[dX[i][j]=
\sum_{a=0}^{M-1}
dY[a][j] \cdot W[a][i]\]

<p>Then swap the two factors:</p>

\[dX[i][j]=
\sum_{a=0}^{M-1}
W[a][i] \cdot dY[a][j]\]

<hr />

<h1 id="recognizing-the-matrix-form">Recognizing the Matrix Form</h1>

<p>Notice</p>

\[W^T[i][a] = W[a][i]\]

<p>Thus</p>

\[dX[i][j]=
\sum_{a=0}^{M-1}
W^T[i][a] \cdot dY[a][j]\]

<p>This is exactly the definition of matrix multiplication</p>

\[dX = W^T dY\]

<hr />

<h1 id="final-result">Final Result</h1>

<p>For the forward operation</p>

\[Y = WX\]

<p>the backward gradients are</p>

\[\frac{\partial L}{\partial X} = W^T dY\]

<p>A very similar derivation shows that</p>

\[\frac{\partial L}{\partial W} = dY X^T\]
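<p>Both formulas can be sanity-checked numerically. The sketch below is pure Python with hypothetical helper names; it assumes a toy loss $L = \sum_{a,b} G[a][b] \cdot Y[a][b]$ so that $dY = G$ exactly, and compares the closed-form gradients against central finite differences.</p>

```python
import random


def matmul(A, B):
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]


def transpose(A):
    return [list(col) for col in zip(*A)]


def loss(W, X, G):
    """Toy scalar loss L = sum(G * Y) with Y = W X, so that dL/dY = G exactly."""
    Y = matmul(W, X)
    return sum(g * y for g_row, y_row in zip(G, Y) for g, y in zip(g_row, y_row))


def numerical_grad(f, A, h=1e-6):
    """Central finite differences of the scalar f() with respect to each entry of A."""
    grad = [[0.0] * len(A[0]) for _ in A]
    for i in range(len(A)):
        for j in range(len(A[0])):
            old = A[i][j]
            A[i][j] = old + h
            plus = f()
            A[i][j] = old - h
            minus = f()
            A[i][j] = old                      # restore the original entry
            grad[i][j] = (plus - minus) / (2 * h)
    return grad


random.seed(0)
M, N, K = 3, 4, 2
W = [[random.uniform(-1, 1) for _ in range(N)] for _ in range(M)]
X = [[random.uniform(-1, 1) for _ in range(K)] for _ in range(N)]
dY = [[random.uniform(-1, 1) for _ in range(K)] for _ in range(M)]

dX = matmul(transpose(W), dY)      # dX = W^T dY, shape N x K
dW = matmul(dY, transpose(X))      # dW = dY X^T, shape M x N
dX_num = numerical_grad(lambda: loss(W, X, dY), X)
dW_num = numerical_grad(lambda: loss(W, X, dY), W)
print(dX[0], dX_num[0])
print(dW[0], dW_num[0])
```

<p>The shapes also serve as a quick check: $W^T dY$ is $N \times K$ like $X$, and $dY X^T$ is $M \times N$ like $W$.</p>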

<p>How to remember this: for Y=WX or Y=XW, when we want to compute dX given dY, we always just swap positions of X(dX) and Y(dY), and then just transpose W without changing its location in the equation.</p>]]></content><author><name>Peng Chen</name></author><category term="Other" /><summary type="html"><![CDATA[Reference: https://cs231n.stanford.edu/handouts/derivatives.pdf]]></summary></entry><entry><title type="html">Deriving formula for Backward Gradient of Softmax</title><link href="https://pchen7e2.github.io/2026/04/12/softmax-backward.html" rel="alternate" type="text/html" title="Deriving formula for Backward Gradient of Softmax" /><published>2026-04-12T00:00:00+00:00</published><updated>2026-04-12T00:00:00+00:00</updated><id>https://pchen7e2.github.io/2026/04/12/softmax-backward</id><content type="html" xml:base="https://pchen7e2.github.io/2026/04/12/softmax-backward.html"><![CDATA[<p>In this short note, we derive the backward gradient of a softmax calculation (in Flash Attention) with minimal math background.</p>

<h1 id="forward-pass">Forward Pass</h1>

<p>Given a vector of attention scores $S \in \mathbb{R}^{N}$, softmax produces:</p>

\[P[i] = \frac{e^{S[i]}}{\sum_{k=0}^{N-1} e^{S[k]}}\]

<p>Note $P[i] &gt; 0$ and $\sum_i P[i] = 1$.</p>

<p>In flash attention, softmax is applied independently to each row of the attention score matrix $S = QK^T$. Everything below applies per row.</p>

<hr />

<h1 id="meaning-of-dp">Meaning of dP</h1>

<p>In backward prop, we are given an input with known values and the same shape as $P$:</p>

\[dP = \frac{\partial L}{\partial P}\]

<p>where $L$ is a scalar loss. Each element:</p>

\[dP[i] = \frac{\partial L}{\partial P[i]}\]

<p>measures how much the loss $L$ changes if $P[i]$ increases slightly.</p>

<hr />

<h1 id="goal">Goal</h1>

<p>We want to compute</p>

\[dS = \frac{\partial L}{\partial S}\]

<p>Each element</p>

\[dS[i] = \frac{\partial L}{\partial S[i]}\]

<p>represents how much the loss changes if $S[i]$ increases.</p>

<hr />

<h1 id="chain-rule">Chain Rule</h1>

<p>To compute a single element $dS[i]$, we apply the chain rule through $P$:</p>

\[dS[i] = \frac{\partial L}{\partial P}
\cdot
\frac{\partial P}{\partial S[i]} = \sum_{j=0}^{N-1}
\frac{\partial L}{\partial P[j]}
\cdot
\frac{\partial P[j]}{\partial S[i]}
= \sum_{j=0}^{N-1}
dP[j]
\cdot
\frac{\partial P[j]}{\partial S[i]}\]

<p>Note that both $\frac{\partial L}{\partial P}$ and $\frac{\partial P}{\partial S[i]}$ have the same shape as $P$, and $dS[i]$ is the sum of their elementwise product.</p>

<p>Also note that the element at index $[j]$ in $\frac{\partial P}{\partial S[i]}$ is just $\frac{\partial P[j]}{\partial S[i]}$.</p>

<p>Note: unlike matmul, <strong>every</strong> $P[j]$ depends on $S[i]$ because $S[i]$ appears in the denominator $\sum_k e^{S[k]}$. So no terms drop out.</p>

<hr />

<h1 id="compute-the-partial-derivatives">Compute the Partial Derivatives</h1>

<p>We need $\frac{\partial P[j]}{\partial S[i]}$ for two cases.</p>

<p><strong>Case 1: $j = i$</strong></p>

\[P[i] = \frac{e^{S[i]}}{\sum_k e^{S[k]}}\]

<p>Using the quotient rule where the numerator is $e^{S[i]}$ and the denominator is $\sum_k e^{S[k]}$:</p>

\[\frac{\partial P[i]}{\partial S[i]}
= \frac{e^{S[i]} \cdot \sum_k e^{S[k]} - e^{S[i]} \cdot e^{S[i]}}{\left(\sum_k e^{S[k]}\right)^2}\]

\[= \frac{e^{S[i]}}{\sum_k e^{S[k]}} - \frac{e^{S[i]}}{\sum_k e^{S[k]}} \cdot \frac{e^{S[i]}}{\sum_k e^{S[k]}}\]

\[= P[i] - P[i]^2 = P[i](1 - P[i])\]

<p><strong>Case 2: $j \ne i$</strong></p>

\[P[j] = \frac{e^{S[j]}}{\sum_k e^{S[k]}}\]

<p>Here the numerator $e^{S[j]}$ does not depend on $S[i]$, so using the quotient rule:</p>

\[\frac{\partial P[j]}{\partial S[i]}
= \frac{0 - e^{S[j]} \cdot e^{S[i]}}{\left(\sum_k e^{S[k]}\right)^2}
= -P[j] \cdot P[i]\]

<hr />

<h1 id="substitute-back">Substitute Back</h1>

<p>Split the chain rule sum into the $j = i$ term and the $j \ne i$ terms:</p>

\[dS[i] =dP[i] \cdot P[i](1 - P[i])+\sum_{j \ne i}dP[j] \cdot (-P[j] \cdot P[i])\]

<p>Factor out $P[i]$:</p>

\[dS[i] = P[i] \left(
dP[i] \cdot (1 - P[i]) - \sum_{j \ne i} dP[j] \cdot P[j]
\right)\]

<p>Expand the first term:</p>

\[dS[i] = P[i] \left(
dP[i] - dP[i] \cdot P[i]- \sum_{j \ne i} dP[j] \cdot P[j]
\right)\]

<p>Notice that $dP[i] \cdot P[i] + \sum_{j \ne i} dP[j] \cdot P[j] = \sum_{j} dP[j] \cdot P[j]$, so:</p>

\[dS[i] = P[i] \left(
dP[i] - \sum_{j} dP[j] \cdot P[j]
\right)\]

<p>Define the dot product as a single scalar:</p>

\[D = \sum_{j} dP[j] \cdot P[j] = dP \cdot P\]

<p>So:</p>

\[dS[i] = P[i] \cdot (dP[i] - D)\]

<p>Note that $D$ is the same scalar for every $i$.</p>

<hr />

<h1 id="final-result">Final Result</h1>

<p>For the forward operation</p>

\[P = \text{softmax}(S)\]

<p>the backward gradient is</p>

\[dS[i] = P[i] \cdot (dP[i] - D)\]

<p>where</p>

\[D = \sum_{j} dP[j] \cdot P[j]\]

<p>Or in vector form:</p>

\[dS = P \odot (dP - D)\]

<p>where $\odot$ is elementwise multiplication and $D$ is a scalar (per row).</p>
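<p>The vector form is easy to verify numerically for a single row. The sketch below is pure Python with hypothetical helper names; it assumes a toy loss $L = \sum_j dP[j] \cdot P[j]$ with $dP$ held fixed, so that $\partial L / \partial P = dP$, and compares $P \odot (dP - D)$ with central finite differences. The max subtraction inside <code>softmax</code> is a common stability trick that is not part of the derivation and does not change the result.</p>

```python
import math
import random


def softmax(S):
    m = max(S)                                  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in S]
    total = sum(exps)
    return [e / total for e in exps]


def softmax_backward(P, dP):
    D = sum(dp * p for dp, p in zip(dP, P))     # D = dP . P, one scalar for the row
    return [p * (dp - D) for p, dp in zip(P, dP)]


def numerical_dS(S, dP, h=1e-6):
    """Central differences of the toy loss L = sum(dP * softmax(S))."""
    def loss(vec):
        return sum(g * p for g, p in zip(dP, softmax(vec)))
    grad = []
    for i in range(len(S)):
        plus = S[:i] + [S[i] + h] + S[i + 1:]
        minus = S[:i] + [S[i] - h] + S[i + 1:]
        grad.append((loss(plus) - loss(minus)) / (2 * h))
    return grad


random.seed(0)
S = [random.uniform(-2, 2) for _ in range(5)]
dP = [random.uniform(-1, 1) for _ in range(5)]
dS = softmax_backward(softmax(S), dP)
dS_num = numerical_dS(S, dP)
print(dS)
print(dS_num)
```

<p>A useful side check: the entries of $dS$ sum to zero, because $\sum_i P[i] \cdot (dP[i] - D) = D - D \sum_i P[i] = 0$.</p>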

<hr />

<h1 id="why-this-matters-for-flash-attention">Why This Matters for Flash Attention</h1>

<p>In flash attention, softmax is applied row-wise to the attention score matrix $S = QK^T$, and the backward pass must be computed <strong>without materializing the full attention matrix</strong> in HBM.</p>

<p>The formula $dS = P \odot (dP - D)$ is perfectly suited for this because:</p>

<ol>
  <li><strong>$D$ is just a scalar per row.</strong> By definition:</li>
</ol>

\[D[i] = \sum_j dP[i][j] \cdot P[i][j]\]

<p>This looks like it needs both $dP$ and $P$, which are full-sized attention matrices we want to avoid materializing. But recall the forward output of attention is $O = PV$, and its gradient is $dO$. We have $dP = dO \cdot V^T$, so:</p>

\[D[i] = \sum_j dP[i][j] \cdot P[i][j] = \sum_j \left(\sum_l dO[i][l] \cdot V[j][l]\right) \cdot P[i][j]\]

<p>Swapping the order of summation:</p>

\[D[i] = \sum_l dO[i][l] \sum_j P[i][j] \cdot V[j][l] = \sum_l dO[i][l] \cdot O[i][l]\]

<p>since $\sum_j P[i][j] \cdot V[j][l] = O[i][l]$ by the forward definition $O = PV$. So:</p>

\[D[i] = \sum_l dO[i][l] \cdot O[i][l]\]

<p>This is just a row-wise dot product of $dO$ and $O$ — both of which are already available in HBM from the forward pass and the incoming gradient. No need to recompute $P$ or $dP$ for this step.</p>
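<p>The identity $D[i] = \sum_l dO[i][l] \cdot O[i][l]$ can be checked on small random matrices. The sketch below is pure Python with hypothetical helper names; it materializes $P$ and $dP$ only to compare the two ways of computing $D$, which is what a real kernel avoids.</p>

```python
import math
import random


def softmax(row):
    exps = [math.exp(s) for s in row]
    total = sum(exps)
    return [e / total for e in exps]


def matmul(A, B):
    return [[sum(a * b for a, b in zip(r, c)) for c in zip(*B)] for r in A]


def transpose(A):
    return [list(c) for c in zip(*A)]


def row_dots(A, B):
    """Row-wise dot product of two matrices of the same shape."""
    return [sum(a * b for a, b in zip(ra, rb)) for ra, rb in zip(A, B)]


random.seed(1)
N, d = 4, 3                                                    # sequence length, head dim
S = [[random.uniform(-1, 1) for _ in range(N)] for _ in range(N)]
V = [[random.uniform(-1, 1) for _ in range(d)] for _ in range(N)]
dO = [[random.uniform(-1, 1) for _ in range(d)] for _ in range(N)]

P = [softmax(row) for row in S]          # N x N, row-wise softmax
O = matmul(P, V)                         # forward output, N x d
dP = matmul(dO, transpose(V))            # N x N

D_from_P = row_dots(dP, P)               # definition: needs the full N x N matrices
D_from_O = row_dots(dO, O)               # shortcut: only N x d matrices
print(D_from_P)
print(D_from_O)
```

<p>The shortcut reads two $N \times d$ matrices instead of two $N \times N$ ones, which is why it can be computed up front without touching the attention matrix.</p>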

<ol start="2">
  <li><strong>Recomputing $P$ per tile using saved row sum.</strong> During the forward pass, we never have a full row of true $P$ values at once, because each tile only sees a partial denominator. But the forward pass saves the sum of exponentials per row (written $L[i]$ below, not to be confused with the scalar loss $L$):</li>
</ol>

\[L[i] = \sum_{k} e^{S[i][k]}\]

<p>In the backward pass, for a tile covering column block $j$, we recompute the local attention scores $S[i][j] = Q[i] \cdot K[j]^T$ and recover the true softmax values for that tile:</p>

\[P[i][j] = \frac{e^{S[i][j]}}{L[i]}\]

<p>This is exactly the softmax definition. The key insight is that $L[i]$ encodes the full-row denominator, so any tile can produce its correct $P$ values independently.</p>
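<p>A small sketch of this recovery for one row, under simplifying assumptions: the scores are given directly instead of being recomputed from $Q$ and $K$, the tile width is 2, and the row sum is stored without any max subtraction, exactly as in the formula above. All names are made up for the example.</p>

```python
import math

S_row = [0.5, -1.0, 2.0, 0.0, 1.5, -0.5]      # one row of attention scores
TILE = 2                                      # assumed tile width along the columns

# Saved by the forward pass: the full-row sum of exponentials.
row_sum = sum(math.exp(s) for s in S_row)

# Backward pass: each tile recovers its slice of P using only its own scores
# and the saved row sum.
P_tiles = []
for start in range(0, len(S_row), TILE):
    tile_scores = S_row[start:start + TILE]   # recomputed from Q and K in a real kernel
    P_tiles.append([math.exp(s) / row_sum for s in tile_scores])

P_from_tiles = [p for tile in P_tiles for p in tile]
P_reference = [math.exp(s) / row_sum for s in S_row]
print(P_from_tiles)
print(sum(P_from_tiles))
```

<p>No tile needs to see the scores of any other tile, yet the concatenated slices form a valid softmax row that sums to 1.</p>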

<ol start="3">
  <li><strong>Each tile is independent.</strong> With $D$ precomputed (step 1) and $P$ recoverable per tile (step 2), we form $dP$ from $dO$ and $V^T$ for that tile, and apply $P \odot (dP - D)$. The subtraction of $D$ is the only term that couples different columns within a row, and since $D$ is already a known scalar, each tile can be processed independently.</li>
</ol>]]></content><author><name>Peng Chen</name></author><category term="Other" /><summary type="html"><![CDATA[In this short note, we derive the backward gradient of a softmax calculation (in Flash Attention) with minimal math background.]]></summary></entry></feed>