GPU & ML

Making ML Workloads Fast on a GPU

How Do We Make GPUs Go Fast?

1. Control Divergence (not a memory bottleneck)
  • GPUs use a SIMT (single instruction, multiple threads) model: every thread in a warp executes the same instruction.
  • Divergent branches (if/else) within a warp serialize execution → overhead.
  • Example:
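__global__ void branchy(float *out) {
    if (threadIdx.x < 4) {
        out[threadIdx.x] = 1.0f;   // A: taken by threads 0-3
    } else {
        out[threadIdx.x] = 2.0f;   // X: taken by threads 4-31 of the first warp
    }
    out[threadIdx.x] += 1.0f;      // Z: all threads reconverge here
}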
  • CUDA threads are grouped into warps (typically 32 threads).

  • All threads in a warp execute the same instruction at a time.

  • If some threads take the if branch while others take the else branch, the warp must serialize:

    • First execute the if branch for threads that satisfy threadIdx.x < 4, while masking out the others.
    • Then execute the else branch for the remaining threads, while masking out the first group.

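This is called warp divergence (control divergence), and it leads to inefficiency: while one path runs, the threads on the other path sit idle. To maximize efficiency, keep all threads of a warp on the same control path, e.g., make branch conditions uniform within each group of 32 threads.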
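A toy Python model of the serialization cost described above. It assumes each distinct path taken inside a warp costs one serialized pass (real costs also depend on how long each branch body is); the function name `warp_passes` is hypothetical.

```python
WARP_SIZE = 32  # threads per warp on NVIDIA GPUs

def warp_passes(num_threads, takes_if):
    """Count serialized branch passes: each warp runs once per distinct path its threads take."""
    passes = 0
    for start in range(0, num_threads, WARP_SIZE):
        warp = range(start, min(start + WARP_SIZE, num_threads))
        paths = {takes_if(t) for t in warp}  # {True}, {False}, or {True, False}
        passes += len(paths)
    return passes

# Divergent: threads 0-3 take `if`, threads 4-31 take `else` -> warp 0 runs both paths
divergent = warp_passes(64, lambda t: t < 4)
# Warp-aligned: the condition is uniform within each warp of 32 threads
aligned = warp_passes(64, lambda t: (t // WARP_SIZE) % 2 == 0)
print(divergent, aligned)  # 3 2
```

Both versions send half of the threads down each path overall, but only the divergent one forces a warp to execute both sides.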

Reduce memory accesses

2. Operator Fusion
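  • Think of the GPU as a factory: inputs come from memory (the warehouse), get processed inside (compute), and the results are shipped back.
  • Compute throughput has grown much faster than memory bandwidth → minimize round trips to memory.
  • Example:
    • Naïve: sin²x + cos²x takes ~5 kernel launches (sin, square, cos, square, add), each reading its inputs from and writing its result to global memory.
    • Fused: all ops combined into one kernel → x is read once and the result is written once.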
  • Compilers (e.g., torch.compile) can automatically fuse pointwise ops into a single kernel:
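import torch

torch._dynamo.config.suppress_errors = True  # optional: fall back to eager execution if compilation fails
x = torch.randn(1024, 1024, device="cuda")   # requires a CUDA GPU

@torch.compile
def fused_fn(x):
    # sin, exp, and add are all pointwise, so torch.compile can fuse them into one kernel
    return torch.exp(torch.sin(x)) + 1

out = fused_fn(x)  # the first call compiles; later calls with matching inputs reuse the compiled code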
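
A back-of-the-envelope count of global-memory traffic for the sin²x + cos²x example, assuming each unfused kernel reads its input tensors from and writes its output tensor to global memory, and that the fused kernel keeps all intermediates in registers. This is a counting sketch, not a measurement; `traffic_bytes` is a hypothetical helper.

```python
BYTES_PER_ELEM = 4  # float32

def traffic_bytes(kernels, n):
    """Global-memory bytes moved by a sequence of kernels, each given as (tensor reads, tensor writes) of n elements."""
    return sum((reads + writes) * n * BYTES_PER_ELEM for reads, writes in kernels)

n = 1024 * 1024
# Unfused: sin(x), square, cos(x), square, add; every intermediate round-trips through global memory
unfused = [(1, 1), (1, 1), (1, 1), (1, 1), (2, 1)]
# Fused: read x once, keep intermediates in registers, write the result once
fused = [(1, 1)]
print(traffic_bytes(unfused, n) // 2**20, traffic_bytes(fused, n) // 2**20)  # 44 8 (MiB)
```

Under these assumptions fusion cuts memory traffic by 11/2 = 5.5×, with exactly the same arithmetic.
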
3. Memory Coalescing
  1. DRAM Burst Mode
Index:   0   1   2   3   | 4   5   6   7   | 8   9  10  11   | 12  13  14  15
Bytes:   [ Burst 0      ] [ Burst 1      ] [ Burst 2      ] [ Burst 3      ]
  • Global memory (DRAM) is not read byte by byte; it is read in bursts.
  • The address space is divided into burst sections:
    • If you read one location, all the other bytes in the same section are delivered to the processor “for free.” (DRAM has slow per-row access. Once a row is activated, streaming nearby bytes out in burst mode is fast.)
    • Example: In a 16-byte address space with 4-byte burst sections, accessing address 5 will automatically fetch bytes 4–7.
  • In practice, burst sizes are 128 bytes or more (for GPUs).
  2. Memory Coalescing
  • A warp = 32 consecutive threads executing together.
  • Coalesced access = when all threads in a warp access addresses that fall in the same burst section.
    • Result: 1 DRAM request serves all 32 threads (e.g., 32 × 4-byte floats = 128 bytes).
  • Uncoalesced access = threads spread across multiple burst sections.
    • Result: multiple DRAM requests → higher latency and wasted bandwidth.
  • Rule of thumb: effective bandwidth is maximized when consecutive threads access consecutive memory locations.
  3. Coalescing in Matrix Multiplication
  • For row-major matrices:
    • If each thread walks along its own row, accesses are not coalesced: at any given step, neighboring threads read elements a full row apart, so they land in different bursts.
    • If each thread walks down its own column, accesses are coalesced: at each step, neighboring threads read neighboring elements of the same row (contiguous memory).

Illustration (row-major $M$; "step $k$" is the $k$-th iteration of each thread's loop):

  • Case (A) Not coalesced:

    • Thread 0 reads $M_{0,0}, M_{0,1}, M_{0,2}, \dots$
    • Thread 1 reads $M_{1,0}, M_{1,1}, M_{1,2}, \dots$
    • At step $k$, thread $i$ reads $M_{i,k}$: neighboring threads are a full row apart in memory → multiple bursts (up to one per thread).
  • Case (B) Coalesced:

    • Thread 0 reads $M_{0,0}, M_{1,0}, M_{2,0}, \dots$
    • Thread 1 reads $M_{0,1}, M_{1,1}, M_{2,1}, \dots$
    • At step $k$, thread $j$ reads $M_{k,j}$: neighboring threads read neighboring elements of row $k$ → as few as one burst for the whole warp.
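
To make the two cases concrete, the sketch below counts how many distinct bursts one warp touches at each loop step. It assumes a 64 × 64 float32 matrix starting on a burst boundary and 128-byte bursts; the names are hypothetical.

```python
N = 64             # 64 x 64 float32 matrix, row-major
ELEM_BYTES = 4
BURST_BYTES = 128  # assumed burst size
WARP = 32

def address(row, col):
    """Byte offset of M[row][col] in a row-major layout that starts on a burst boundary."""
    return (row * N + col) * ELEM_BYTES

def bursts_per_step(access):
    """For each loop step k, count the distinct bursts one warp touches; access(t, k) -> (row, col)."""
    return [len({address(*access(t, k)) // BURST_BYTES for t in range(WARP)}) for k in range(N)]

row_walk = bursts_per_step(lambda t, k: (t, k))  # Case (A): thread t reads M[t][k]
col_walk = bursts_per_step(lambda t, k: (k, t))  # Case (B): thread t reads M[k][t]
print(max(row_walk), max(col_walk))  # 32 1
```

Walking along rows costs 32 bursts per step (one per thread), while walking down columns costs a single burst per step.
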

Trade memory for compute/accuracy

4. Low Precision Computation
  • Using fewer bits reduces memory traffic, improving arithmetic intensity (FLOPs per byte).
  • Modern GPUs provide specialized tensor cores that accelerate low-precision (FP16, BF16, INT8) matmul.

Example: elementwise ReLU

$$ x \mapsto \max(0, x), \quad \text{for a vector of size } n $$

| Precision | Memory access (per element) | Operations | Arithmetic intensity |
|---|---|---|---|
| Float32 | 1 read (4 B) + 1 write (4 B) = 8 bytes | 1 comparison = 1 FLOP | 1/8 FLOP per byte (8 bytes per FLOP) |
| Float16 | 1 read (2 B) + 1 write (2 B) = 4 bytes | 1 comparison = 1 FLOP | 1/4 FLOP per byte (4 bytes per FLOP) |

  • Lower precision doubles arithmetic intensity.
  • Practical guidelines (from FP16/BF16 usage in ML)
    1. Operations that can use 16-bit storage (FP16/BF16):

      • Matrix multiplications (core of deep learning workloads).
      • Most pointwise operations (ReLU, tanh, add, sub, mul).
    2. Operations that need more precision (FP32/FP16; FP16 has a 10-bit mantissa vs. 7 bits for BF16):

      • Adding small values to large sums (susceptible to rounding errors).
      • Reduction ops (sum, softmax, normalization).
    3. Operations that need more range (FP32/BF16; BF16 keeps FP32's 8-bit exponent):

      • Pointwise ops with large magnitude changes, e.g. $|f(x)| \gg |x|$ (exp, log, pow).
      • Loss functions (where gradients may explode/vanish).
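
Both failure modes can be reproduced on a CPU by rounding to FP16 with Python's `struct` half-precision format (`'e'`). This is a sketch under stated assumptions: the standard library has no BF16, so only FP16 is shown, and Python floats (FP64) stand in for a higher-precision accumulator. `to_fp16` is a hypothetical helper.

```python
import math
import struct

def to_fp16(x):
    """Round a Python float to IEEE half precision (struct format 'e'); overflow becomes +/-inf."""
    try:
        return struct.unpack("e", struct.pack("e", x))[0]
    except (OverflowError, struct.error):
        return math.copysign(math.inf, x)

# 1) Precision: add 0.0001 ten thousand times (true answer: 1.0)
step = to_fp16(1e-4)
acc16 = 0.0
for _ in range(10_000):
    acc16 = to_fp16(acc16 + step)   # FP16 accumulator: every add is rounded to FP16
acc64 = sum([1e-4] * 10_000)        # high-precision accumulator for comparison

# 2) Range: exp(12) is about 162755, above the FP16 maximum of 65504
big16 = to_fp16(math.exp(12.0))

print(acc16, round(acc64, 6), big16)  # 0.25 1.0 inf
```

The FP16 sum stalls at 0.25 because, from there on, half the spacing between representable values exceeds 0.0001, so every addition rounds back down.
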
5. Activation Recomputation (a.k.a. Checkpointing)

When we train neural networks, the forward pass computes intermediate activations (e.g., the output of each layer), and the backward pass needs these activations to compute gradients. By default, frameworks therefore store all activations in memory during the forward pass and read them back during the backward pass, which is expensive memory traffic. Example: three stacked sigmoids.

Forward:

$$ s_1 = \sigma(x), \quad s_2 = \sigma(s_1), \quad \text{out} = \sigma(s_2) $$

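Backward (chain rule). Since $\sigma'(z) = \sigma(z)\,(1 - \sigma(z))$, the factors below equal $\text{out}(1-\text{out})$, $s_2(1-s_2)$, and $s_1(1-s_1)$, so computing this gradient needs $s_1, s_2, \text{out}$: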

$$ \frac{\partial L}{\partial x} = \frac{\partial L}{\partial \text{out}} \cdot \sigma'(s_2) \cdot \sigma'(s_1) \cdot \sigma'(x) $$
  1. Memory cost, naïve (store everything; the upstream gradient $\partial L / \partial \text{out}$ is read in both versions and not counted):
    • Forward: 1 read (load $x$), 3 writes (store $s_1, s_2, \text{out}$).
    • Backward: 3 reads (load $s_1, s_2, \text{out}$), 1 write (store the gradient $\mathrm{d}x$).
    • Total: 8 memory ops. Arithmetic intensity is very low → lots of memory traffic just to hold onto values.
  2. Solution: throw away activations, recompute them
    • Forward: store only the final output ($\text{out}$); discard $s_1, s_2$.
      • 1 read (load $x$), 1 write (store $\text{out}$).
    • Backward: recompute $s_1 = \sigma(x)$ and $s_2 = \sigma(s_1)$ from $x$, then use them in the chain rule.
      • 2 reads (load $x, \text{out}$), 1 write (store $\mathrm{d}x$), plus two recomputed sigmoids.
    • Total: 5 memory ops instead of 8, i.e., 5/8 of the original memory traffic, at the cost of a little extra compute.
  • GPUs are often memory-bandwidth bound, not compute-bound.
  • Trading a few extra FLOPs (cheap on modern GPUs) for much lower memory traffic is a win.
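
The memory-op counts above can be replayed in a scalar Python sketch. `CountingMemory` is a hypothetical stand-in for global memory that counts every read and write; as in the text, the upstream gradient is passed in as an argument and not counted.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

class CountingMemory:
    """Hypothetical stand-in for GPU global memory that counts every read and write."""
    def __init__(self, **initial):
        self.data, self.ops = dict(initial), 0

    def read(self, name):
        self.ops += 1
        return self.data[name]

    def write(self, name, value):
        self.ops += 1
        self.data[name] = value

def grad_x(out, s2, s1, grad_out):
    # chain rule with sigma'(z) = sigma(z) * (1 - sigma(z))
    return grad_out * out * (1 - out) * s2 * (1 - s2) * s1 * (1 - s1)

def store_everything(x, grad_out=1.0):
    mem = CountingMemory(x=x)
    s1 = sigmoid(mem.read("x"))                  # forward: 1 read
    s2 = sigmoid(s1)
    out = sigmoid(s2)
    for name, value in (("s1", s1), ("s2", s2), ("out", out)):
        mem.write(name, value)                   # forward: 3 writes
    dx = grad_x(mem.read("out"), mem.read("s2"), mem.read("s1"), grad_out)  # backward: 3 reads
    mem.write("dx", dx)                          # backward: 1 write
    return dx, mem.ops

def recompute(x, grad_out=1.0):
    mem = CountingMemory(x=x)
    mem.write("out", sigmoid(sigmoid(sigmoid(mem.read("x")))))  # forward: 1 read, 1 write
    s1 = sigmoid(mem.read("x"))                  # backward: re-read x, recompute s1 and s2
    s2 = sigmoid(s1)
    dx = grad_x(mem.read("out"), s2, s1, grad_out)  # backward: second read (out)
    mem.write("dx", dx)                          # backward: 1 write
    return dx, mem.ops

print(store_everything(0.5)[1], recompute(0.5)[1])  # 8 5
```

Both functions return the same gradient; only the number of global-memory operations differs.
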

Move memory to shared memory

6. Tiling
  1. Tiling = grouping and reordering threads to minimize global memory accesses.

  2. Steps

    • Cut the computation into tiles (blocks of data).
    • Load tiles into shared memory.
    • Reuse them multiple times before fetching new data from global memory.
    • This reduces redundant global memory reads and improves coalescing.
  3. Calculation

    We compute $C = A \times B, \quad C_{i,j} = \sum_{k=0}^{N-1} A_{i,k}\,B_{k,j}$.

    Without Tiling

    • Each $C_{i,j}$ needs $N$ elements from row $i$ of $A$ and $N$ elements from column $j$ of $B$.
    • $N^2$ outputs →$\text{Reads} = N^2(N+N)=2N^3$

    With Tiling (tile size $T$)

    • Now divide matrices into tiles of size $T \times T$. Each thread block computes a tile of $C$.
    • Per phase: load one tile of $A$ and one tile of $B$ into shared memory → $2T^2$ reads. Each element inside the tile is reused $T$ times.
    • Tiles in $C$: $(N/T)^2$, and each tile of $C$requires $\frac{N}{T}$ phases.
    • $\text{Reads} = \Big(\frac{N}{T}\Big)^2 \cdot \frac{N}{T} \cdot (2T^2) = \frac{2N^3}{T}$

    Reduction: $\frac{2N^3}{2N^3/T} = T$, i.e., $T\times$ fewer global reads.

  4. Advantages

    • Repeated reads are served from shared memory (fast), not global memory.
    • Memory accesses can be arranged to be coalesced.
    • Global memory traffic is reduced significantly.
  5. Complexities with Tiling - Tile size divisibility

    • Tile sizes may not divide the matrix dimensions exactly → low utilization.

    • Example:

      • If the matrix is $256 \times 256$ and the tile is $128 \times 128$, tiling is perfect: $2 \times 2 = 4$ full tiles.

      • If the matrix is $257 \times 256$, a third row of tiles is needed ($3 \times 2 = 6$ tiles), and the two extra tiles each hold a single valid row, so their thread blocks mostly do wasted work.

    • Factors affecting tile size:

      • Coalesced memory access.
      • Shared memory size (hardware constraint).
      • Divisibility of matrix dimensions.
  6. Complexities with Tiling - Memory alignment

    • DRAM is read in bursts.
    • Aligned layout: tile fits neatly within bursts → efficient load.
    • Unaligned layout: tile straddles burst boundaries → requires extra transactions.
    • Fix: often need padding to enforce alignment.
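
The read counts in the Calculation step can be checked with a pure-Python sketch: both versions compute the same product, but the tiled version counts a global read only when it stages a $T \times T$ tile. Python lists stand in for global and shared memory, the function names are hypothetical, and the sketch assumes $T$ divides $N$, sidestepping the divisibility issue above.

```python
def matmul_naive(A, B):
    """Every multiply-add loads one element of A and one of B from global memory."""
    N = len(A)
    C = [[0] * N for _ in range(N)]
    global_reads = 0
    for i in range(N):
        for j in range(N):
            for k in range(N):
                C[i][j] += A[i][k] * B[k][j]
                global_reads += 2
    return C, global_reads

def matmul_tiled(A, B, T):
    """Each (bi, bj) pair plays the role of one thread block computing a T x T tile of C."""
    N = len(A)
    assert N % T == 0, "this sketch assumes the tile size divides N"
    C = [[0] * N for _ in range(N)]
    global_reads = 0
    for bi in range(0, N, T):
        for bj in range(0, N, T):
            for bk in range(0, N, T):  # one phase: stage one tile of A and one tile of B
                a_tile = [row[bk:bk + T] for row in A[bi:bi + T]]
                b_tile = [row[bj:bj + T] for row in B[bk:bk + T]]
                global_reads += 2 * T * T
                for i in range(T):     # every access below hits the staged ("shared memory") copies
                    for j in range(T):
                        C[bi + i][bj + j] += sum(a_tile[i][k] * b_tile[k][j] for k in range(T))
    return C, global_reads

N, T = 8, 4
A = [[i + j for j in range(N)] for i in range(N)]
B = [[i - j for j in range(N)] for i in range(N)]
C1, naive_reads = matmul_naive(A, B)
C2, tiled_reads = matmul_tiled(A, B, T)
print(naive_reads, tiled_reads, C1 == C2)  # 1024 256 True
```

The counts match $2N^3 = 1024$ and $2N^3/T = 256$. Integer matrices are used so that both summation orders give exactly the same result.
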

Scaling

[Figure: roofline model]

1. The Roofline Model
  • X-axis (operational intensity): $\text{OI} = \frac{\text{FLOPs}}{\text{Bytes moved to/from memory}}$
    • Low OI → memory-bound (little reuse).
    • High OI → compute-bound (high reuse).
  • Y-axis (Throughput): FLOPs per second (GFLOPS/TFLOPS).
  • Limits:
    • Sloped lines = memory bandwidth ceilings (DRAM, HBM, shared memory).
    • Flat lines = compute ceilings (ALUs, tensor cores).
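
Together, the two ceilings give the usual roofline formula: attainable throughput = min(peak compute, OI × bandwidth). The sketch below uses illustrative round numbers (300 TFLOP/s, 2 TB/s) that are not the spec of any particular GPU. The ReLU intensities come from the low-precision section (1 FLOP per 8 bytes in FP32, 1 per 4 bytes in FP16), and the matmul assumes ideal FP16 traffic (A and B read once, C written once).

```python
def attainable_tflops(oi, peak_tflops, bandwidth_tb_per_s):
    """Roofline: min(compute ceiling, bandwidth ceiling); oi in FLOPs/byte, bandwidth in TB/s -> TFLOP/s."""
    return min(peak_tflops, oi * bandwidth_tb_per_s)

PEAK_TFLOPS = 300.0  # illustrative compute ceiling
BANDWIDTH = 2.0      # illustrative memory bandwidth, TB/s
ridge = PEAK_TFLOPS / BANDWIDTH  # OI where the sloped and flat lines meet (FLOPs/byte)

relu_fp32 = attainable_tflops(1 / 8, PEAK_TFLOPS, BANDWIDTH)  # 1 FLOP per 8 bytes
relu_fp16 = attainable_tflops(1 / 4, PEAK_TFLOPS, BANDWIDTH)  # 1 FLOP per 4 bytes
n = 4096  # square FP16 matmul: 2n^3 FLOPs over 3 * n^2 * 2 bytes
matmul = attainable_tflops(2 * n**3 / (3 * n * n * 2), PEAK_TFLOPS, BANDWIDTH)
print(ridge, relu_fp32, relu_fp16, matmul)  # 150.0 0.25 0.5 300.0
```

ReLU sits far left of the ridge point (memory-bound, and FP16 doubles its ceiling), while a large matmul has OI ≈ n/3 and hits the flat compute ceiling.
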
2. Why Matmul Performance is Tricky (Second picture)

Square matrix multiplies don’t yield smooth performance curves — experiments show:

  • Rising performance with higher OI (expected).

  • Sudden jumps from tiling.

    • If $N$ is not divisible by $T$, some thread blocks are under-utilized.
    • Tiles must align with DRAM burst boundaries. Misaligned tiles → wasted transactions.
  • Periodic drops due to wave quantization.

    • GPUs execute tiles in waves across SMs.
    • If tile count ≤ #SMs → all tiles run in parallel.
    • If tile count > #SMs → multiple waves needed, lowering throughput.
  • Example of wave quantization (A100, 108 SMs):

    • $1792 \times 1792$ with tile $256 \times 128$: $\frac{1792}{256} \times \frac{1792}{128} = 7 \times 14 = 98$ tiles. Fits in one wave (fast).

    • $1793 \times 1793$: $\left\lceil \frac{1793}{256} \right\rceil \times \left\lceil \frac{1793}{128} \right\rceil = 8 \times 15 = 120$ tiles.

      120 > 108 SMs → the work spills into 2 waves, and the second wave runs only 12 tiles while the other 96 SMs sit idle → performance dip.
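
The tile and wave arithmetic can be scripted. The sketch below follows the same simplification as the text (one tile per SM at a time) and rounds partial tiles up, which is where the divisibility effect enters. `tiles_and_waves` is a hypothetical helper.

```python
import math

def tiles_and_waves(M, N, tile_m, tile_n, num_sms):
    """Tile count (partial tiles rounded up), number of waves, and fraction of SM slots doing work."""
    tiles = math.ceil(M / tile_m) * math.ceil(N / tile_n)
    waves = math.ceil(tiles / num_sms)
    return tiles, waves, tiles / (waves * num_sms)

A100_SMS = 108
print(tiles_and_waves(1792, 1792, 256, 128, A100_SMS))  # (98, 1, 0.907...)
print(tiles_and_waves(1793, 1793, 256, 128, A100_SMS))  # (120, 2, 0.555...)
```

Adding a single row and column drops SM-slot utilization from about 91% to about 56% under this model.
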

Understand FlashAttention

[Figure: FlashAttention]

SRAM (on-chip shared memory) is tiny but ultra-fast, HBM (High Bandwidth Memory, the GPU's main memory) is large but slower, and CPU DRAM is larger still but much slower. FlashAttention therefore uses tiling to keep data in SRAM as much as possible.

  1. Tiled approach:
    • Break Q, K, V into tiles.
    • Load tiles into SRAM, perform matmul locally.
    • Write results back to HBM.
    • Outer loop: Copy blocks between HBM and SRAM; Inner loop: Multiply sub-blocks in SRAM.
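  2. Fusion of the exponential operator
    • The score matmul, the softmax (including its exponentials), and the multiplication by V run in one kernel, so the full (sequence length × sequence length) attention matrix is never written to HBM.
  3. Online (incremental) softmax
    • Maintain a running max and a running normalization constant, so softmax can be computed one tile at a time. When a new tile raises the max from $m_{\text{old}}$ to $m_{\text{new}}$, the partial sums accumulated so far are rescaled by $e^{m_{\text{old}} - m_{\text{new}}}$.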
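
A minimal pure-Python sketch of the online softmax for a single query. It assumes that query's attention scores are already computed and that each value is a scalar; real FlashAttention computes scores tile by tile from Q and K, and its values are vectors. The function names are hypothetical.

```python
import math

def attend_naive(scores, values):
    """Reference: softmax over all scores at once, then the weighted sum of values."""
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    return sum(w * v for w, v in zip(weights, values)) / sum(weights)

def attend_online(scores, values, tile=4):
    """Visit scores/values one tile at a time, keeping only running max m, normalizer l, and output acc."""
    m, l, acc = -math.inf, 0.0, 0.0
    for start in range(0, len(scores), tile):
        s_tile = scores[start:start + tile]
        v_tile = values[start:start + tile]
        m_new = max(m, max(s_tile))
        scale = math.exp(m - m_new)  # rescales earlier partial sums; exp(-inf) = 0 on the first tile
        p = [math.exp(s - m_new) for s in s_tile]
        l = l * scale + sum(p)
        acc = acc * scale + sum(pi * vi for pi, vi in zip(p, v_tile))
        m = m_new
    return acc / l

scores = [3.0 * math.sin(i) for i in range(10)]
values = [float(i) for i in range(10)]
print(abs(attend_naive(scores, values) - attend_online(scores, values)) < 1e-9)  # True
```

The online version never holds more than one tile of scores at a time, yet it matches the full-softmax result up to floating-point rounding.
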