3. Memory and Compute #

Memory #

  • float32: default (4 bytes)
    • | sign (1) | exponent (8) | fraction (23) |
    • $(-1)^{\text{sign}} \times (1.\text{fraction}) \times 2^{\text{exponent}-127}$
  • float16: 2 bytes, but small values underflow (see the snippet after this list).
  • bfloat16: same memory as float16 but the dynamic range of float32: more exponent bits, fewer fraction bits.
  • fp8: supported by H100 (E4M3, E5M2 variants).
  • Mixed precision training: float32 for sensitive parts (attention, optimizer state), bfloat16 for the feed-forward computation.
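
A minimal sketch tying these together: decoding the float32 bit layout for one value, and checking where float16 underflows but bfloat16 does not (the example values are illustrative):

    import struct
    import torch

    # float32 layout: (-1)^sign * (1.fraction) * 2^(exponent - 127), checked for 1.5
    bits = struct.unpack(">I", struct.pack(">f", 1.5))[0]
    sign, exponent, fraction = bits >> 31, (bits >> 23) & 0xFF, bits & 0x7FFFFF
    assert (-1) ** sign * (1 + fraction / 2**23) * 2 ** (exponent - 127) == 1.5

    # float16 vs bfloat16: same 2 bytes, very different dynamic range
    x = torch.tensor(1e-8)
    print(x.to(torch.float16))   # 0.0 -- underflows (float16's smallest subnormal is ~6e-8)
    print(x.to(torch.bfloat16))  # ~1e-8 -- bfloat16 keeps float32's exponent range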

Tensor Operations #

  • Tensors = pointers into allocated memory with metadata describing how to get to any element of the tensor.

  • Example:

    x = torch.tensor([[1., 2., 3.], [4., 5., 6.]])
    x.stride(0)  # elements to skip in storage to move to the next row (here 3)
    x.stride(1)  # elements to skip to move to the next column (here 1)
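    # The flat storage offset of x[i, j] is i * x.stride(0) + j * x.stride(1);
    # a quick check of that index arithmetic (for the x defined above):
    i, j = 1, 2
    assert x.flatten()[i * x.stride(0) + j * x.stride(1)] == x[i, j]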
    

Slicing & Views #

  • Views share storage, mutations reflect across them.
  • Non-contiguous views (e.g., after a transpose) must be made contiguous with .contiguous() before reshaping with .view() (see the sketch below).
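
A minimal sketch of both points (hypothetical tensors, not from the original notes):

    import torch

    x = torch.arange(6).reshape(2, 3)
    y = x[:, :2]                    # a view: shares storage with x
    x[0, 0] = 100
    assert y[0, 0] == 100           # the mutation is visible through the view
    z = x.transpose(0, 1)           # non-contiguous view
    flat = z.contiguous().view(-1)  # calling z.view(-1) directly would raise an error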

Matrix Multiplication #

  • FLOPs for multiplying an $m\times n$ matrix by an $n\times p$ matrix: $\text{FLOPs} = 2 \times m \times n \times p$ (each of the $mp$ output entries takes $n$ multiplies and about $n$ adds); a quick check follows.
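
A sanity check of the count (the sizes are arbitrary):

    import torch

    m, n, p = 4, 16, 32
    A, B = torch.randn(m, n), torch.randn(n, p)
    C = A @ B                 # shape (m, p)
    flops = 2 * m * n * p     # n multiplies + n adds for each of the m*p entries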

Einops #

  • Provides dimension-aware tensor operations.

  • Examples:

    from einops import einsum, rearrange, reduce

    z = einsum(x, y, "b s h, b t h -> b s t")  # contract over the shared hidden dim h
    y = reduce(x, "... h -> ...", "sum")  # or mean, max, min
    x = rearrange(x, "... (heads hidden1) -> ... heads hidden1", heads=2)  # @inspect x
    x = rearrange(x, "... heads hidden2 -> ... (heads hidden2)")  # @inspect x
    
  • jaxtyping:

      # Old way:
      x = torch.ones(2, 2, 1, 3)  # batch seq heads hidden  @inspect x
      # New (jaxtyping) way:
      from jaxtyping import Float
      x: Float[torch.Tensor, "batch seq heads hidden"] = torch.ones(2, 2, 1, 3)  # @inspect x
    

Compute Accounting #

FLOPs #

  • Definitions:

    • FLOPs = number of floating-point ops.
    • FLOP/s = FLOPs per second (to measure hardware throughput).
      • depends on the hardware (H100 ≫ A100) and the data type (bfloat16 ≫ float32)
    • A100 has a peak performance of 312 teraFLOP/s (312e12); see the back-of-envelope below.
  • Elementwise op on an $m\times n$ matrix: $O(mn)$; adding two $m\times n$ matrices: $mn$; matmul: $2mnp$.
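
A back-of-envelope helper using the A100 figure above (a sketch; substitute your hardware's promised peak):

    peak_flops = 312e12                    # A100 peak, from above
    flops_per_day = peak_flops * 24 * 3600
    print(f"{flops_per_day:.1e}")          # ~2.7e19 FLOPs per GPU-day at 100% utilization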

Training Example #

  • Shapes: $X\in\mathbb{R}^{B\times D},\; W_1\in\mathbb{R}^{D\times D},\; W_2\in\mathbb{R}^{D\times K}$

    $H_1 = XW_1\in\mathbb{R}^{B\times D},\; H_2 = H_1W_2\in\mathbb{R}^{B\times K}$
    $L=\text{mean}(H_2^2)=\frac{1}{BK}\sum_{i,k}H_{2,ik}^2$

  • Forward FLOPs

    1. $H_1 = XW_1$: $(B\times D)\cdot(D\times D)\Rightarrow 2BD^2$
    2. $H_2 = H_1W_2$: $(B\times D)\cdot(D\times K)\Rightarrow 2BDK$
    3. $L=\mathrm{mean}(H_2^2)$: elementwise square and reduce
    • squares: $BK$ mults; sum: $(BK-1)$ adds (≈ $BK$)
    • scale by $1/(BK)$: ~1 mult (negligible). So ≈ $2BK$ FLOPs total (tiny compared to the matmuls; often omitted).

    Forward total (dominant terms):

    $$ \boxed{2BD^2 + 2BDK} $$
  • Backward—derivatives and FLOPs

    1. For $L=\text{mean}(H_2^2)$, $G_2:=\nabla_{H_2}L = \frac{2}{BK}H_2$ (elementwise). Cost: $BK$ mults (tiny).

    2. Layer 2 ($H_2 = H_1 W_2$)

      • Gradient w.r.t. weights: $\nabla_{W_2}L = H_1^\top G_2$: $(D\times B)\cdot(B\times K)\Rightarrow \mathbf{2DBK}$
      • Backprop to activation: $G_1 = G_2 W_2^\top$: $(B\times K)\cdot(K\times D)\Rightarrow \mathbf{2BDK}$
      • Layer-2 backward subtotal: $2DBK + 2BDK = 4BDK$
    3. Layer 1 ($H_1 = X W_1$)

      • Gradient w.r.t. weights: $\nabla_{W_1}L = X^\top G_1:$ $(D\times B)\cdot(B\times D)\Rightarrow \mathbf{2DBD = 2BD^2}$

      • (Optional) backprop to input: $G_X = G_1 W_1^\top$ $(B\times D)\cdot(D\times D)\Rightarrow \mathbf{2BD^2}$

      • Layer-1 backward subtotal (including $G_X$): $2BD^2 + 2BD^2 = 4BD^2$

    Backward total (dominant terms):

    $$ \boxed{4BD^2 + 4BDK}\quad(+\,BK\ \text{for the }G_2\text{ elementwise step}) $$
  • Combined forward + backward

    $$ \boxed{\text{Total} \;\approx\; 6BD^2 + 6BDK} $$
  • Forward pass (matrix multiplications):

    $$ 2 \times (\#\text{tokens}) \times (\#\text{parameters}) $$
  • Backward pass:

    $$ 4 \times (\#\text{tokens}) \times (\#\text{parameters}) $$
  • Total (checked in the snippet below):

    $$ 6 \times (\#\text{tokens}) \times (\#\text{parameters}) $$
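
A small check that the worked example reduces to this rule (the sizes are arbitrary; here #tokens $= B$ and #parameters $= D^2 + DK$):

    B, D, K = 8, 64, 32                # tokens, hidden dim, output dim
    params = D * D + D * K             # W1 and W2
    fwd = 2 * B * D**2 + 2 * B * D * K
    bwd = 4 * B * D**2 + 4 * B * D * K
    assert fwd == 2 * B * params
    assert bwd == 4 * B * params
    assert fwd + bwd == 6 * B * params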

Model FLOPs Utilization (MFU) #

  • Definition:

    $$ \text{MFU} = \frac{\text{actual FLOP/s}}{\text{promised FLOP/s}} $$
  • Good MFU is ≈ 0.5 or higher (a rough measurement sketch follows).
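
A rough way to estimate it from a single bfloat16 matmul (a sketch; measure_mfu is not from the notes, and promised_flops should be your hardware's spec-sheet peak, e.g. 312e12 for an A100):

    import time
    import torch

    def measure_mfu(n: int = 4096, steps: int = 20, promised_flops: float = 312e12) -> float:
        a = torch.randn(n, n, device="cuda", dtype=torch.bfloat16)
        b = torch.randn(n, n, device="cuda", dtype=torch.bfloat16)
        torch.cuda.synchronize()
        start = time.time()
        for _ in range(steps):
            a @ b
        torch.cuda.synchronize()
        actual_flops_per_s = 2 * n**3 * steps / (time.time() - start)
        return actual_flops_per_s / promised_flops

In practice MFU is computed from the model's theoretical training FLOPs (6 × tokens × parameters) divided by measured wall-clock time, but the ratio is the same idea.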

Parameter Initialization #

w = nn.Parameter(torch.randn(input_dim, output_dim))
x = nn.Parameter(torch.randn(input_dim))
output = x @ w  # @inspect output
assert output.size() == torch.Size([output_dim])
# Each element of output scales as sqrt(input_dim): ~18.9 here.
# Large values can cause gradients to blow up and make training unstable.

We want an initialization whose output scale is invariant to input_dim. To do that, we simply rescale by 1/sqrt(input_dim). To be extra safe, we also truncate the normal distribution to [-3, 3] to avoid any chance of outliers:

w = nn.Parameter(
    nn.init.trunc_normal_(
        torch.empty(input_dim, output_dim), std=1 / np.sqrt(input_dim), a=-3, b=3
    )
)
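
A quick check (with hypothetical sizes) that the rescaled initialization keeps the output scale roughly constant as input_dim grows:

import numpy as np
import torch
import torch.nn as nn

for input_dim in (256, 4096, 65536):
    output_dim = 64
    w = nn.Parameter(
        nn.init.trunc_normal_(
            torch.empty(input_dim, output_dim), std=1 / np.sqrt(input_dim), a=-3, b=3
        )
    )
    x = torch.randn(input_dim)
    print(input_dim, (x @ w).std().item())  # stays around 1 regardless of input_dim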