3. Memory and Compute #
Memory #
- float32: default (4 bytes)
- | sign (1) | exponent (8) | fraction (23) |
- $(-1)^{\text{sign}} \times (1.\text{fraction bits}) \times 2^{\text{exponent}-127}$
- float16: 2 bytes, but small values underflow.
- bfloat16: same memory as float16 but the dynamic range of float32 (more exponent bits, fewer fraction bits); see the sketch after this list.
- fp8: supported by H100 (E4M3, E5M2 variants).
- Mixed-precision training: float32 for precision-sensitive parts (attention, optimizer state), bfloat16 for the feed-forward layers.
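A minimal sketch to check these sizes and the float16 vs. bfloat16 behavior in PyTorch (printed values are approximate; fp8 dtypes require a recent PyTorch build and are omitted here):
import torch

# Bytes per element for the common training dtypes.
for dtype in [torch.float32, torch.bfloat16, torch.float16]:
    print(dtype, torch.tensor([], dtype=dtype).element_size(), "bytes")  # 4, 2, 2

# float16 underflows for very small magnitudes; bfloat16 keeps float32's
# dynamic range at the cost of fraction bits (precision).
x = torch.tensor(1e-8)
print(x.to(torch.float16))   # 0.0 -- underflow
print(x.to(torch.bfloat16))  # ~1e-8 -- representable, just less precise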
Tensor Operations #
Tensors = pointers into allocated memory with metadata describing how to get to any element of the tensor.
Example:
x.stride(0)  # steps in storage to move one index along dim 0 (rows)
x.stride(1)  # steps in storage to move one index along dim 1 (cols)
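Concretely, for a contiguous 2D tensor, element x[i, j] lives at storage offset i * x.stride(0) + j * x.stride(1) (a small sketch; the shape is arbitrary):
import torch

x = torch.arange(12).reshape(3, 4)   # contiguous, so stride = (4, 1)
print(x.stride())                    # (4, 1)

i, j = 1, 2
offset = i * x.stride(0) + j * x.stride(1)
assert x.flatten()[offset] == x[i, j]   # strides map indices to storage offsets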
Slicing & Views #
- Views share storage, mutations reflect across them.
- Non-contiguous views (e.g., transposes) must be made contiguous before reshaping with .view(); see the sketch below.
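A short sketch of both points (shapes are arbitrary):
import torch

x = torch.arange(6).reshape(2, 3)
y = x[0]              # slicing returns a view into x's storage
y[0] = 100
print(x[0, 0])        # tensor(100) -- the mutation is visible through x

t = x.t()             # transpose is a non-contiguous view
# t.view(6)           # would raise an error: .view() needs contiguous memory
z = t.contiguous().view(6)   # copy into fresh contiguous storage first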
Matrix Multiplication #
- FLOPs for a matmul of an $(m \times n)$ matrix with an $(n \times p)$ matrix: $\text{FLOPs} = 2mnp$ (each of the $mp$ output entries is a length-$n$ dot product: $n$ multiplies and $n$ adds)
Einops #
Provides dimension-aware tensor operations.
Examples:
from einops import einsum, rearrange, reduce

z = einsum(x, y, "b s h, b t h -> b s t")
y = reduce(x, "... h -> ...", "sum")  # or mean, max, min
x = rearrange(x, "... (heads hidden1) -> ... heads hidden1", heads=2)  # @inspect x
x = rearrange(x, "... heads hidden2 -> ... (heads hidden2)")  # @inspect x
jaxtyping:
from jaxtyping import Float

# Old way:
x = torch.ones(2, 2, 1, 3)  # batch seq heads hidden @inspect x
# New (jaxtyping) way:
x: Float[torch.Tensor, "batch seq heads hidden"] = torch.ones(2, 2, 1, 3)  # @inspect x
Compute Accounting #
FLOPs #
Definitions:
- FLOPs = number of floating-point ops.
- FLOP/s = FLOPs per second (to measure hardware throughput).
- Achievable FLOP/s depend on hardware (H100 ≫ A100) and data type (bfloat16 ≫ float32)
- An A100 has a peak performance of 312 teraFLOP/s (312e12) for bfloat16
- Elementwise op on an $m \times n$ matrix: $O(mn)$ FLOPs
- Addition of two $m \times n$ matrices: $mn$ FLOPs
- Matmul of $(m \times n)$ and $(n \times p)$: $2mnp$ FLOPs
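A tiny calculator for these counts over some assumed shapes (the shapes are illustrative, roughly a large feed-forward layer):
def add_flops(m, n):
    # Elementwise addition of two (m, n) matrices: one add per element.
    return m * n

def matmul_flops(m, n, p):
    # (m, n) @ (n, p): each of the m*p outputs is a length-n dot product.
    return 2 * m * n * p

print(f"{add_flops(4096, 4096):.1e}")            # ~1.7e7
print(f"{matmul_flops(4096, 4096, 16384):.1e}")  # ~5.5e11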
Training Example #
Shapes: $X \in \mathbb{R}^{B \times D},\ W_1 \in \mathbb{R}^{D \times D},\ W_2 \in \mathbb{R}^{D \times K}$
$H_1 = XW_1 \in \mathbb{R}^{B \times D},\quad H_2 = H_1W_2 \in \mathbb{R}^{B \times K}$
$L = \text{mean}(H_2^2) = \frac{1}{BK}\sum_{i,k} H_{2,ik}^2$
Forward FLOPs
- $H_1 = XW_1$: $(B\times D)\cdot(D\times D)\Rightarrow 2BD^2$
- $H_2 = H_1W_2$: $(B\times D)\cdot(D\times K)\Rightarrow 2BDK$
- $L=\mathrm{mean}(H_2^2)$: elementwise square and reduce
- squares: $BK$ mults; sum: $(BK-1)$ adds (≈ $BK$)
- scale by $1/(BK)$: ~1 mult (negligible). So the loss costs ≈ $2BK$ FLOPs (tiny vs. the matmuls; often omitted).
Forward total (dominant terms):
$$ \boxed{2BD^2 + 2BDK} $$
Backward: derivatives and FLOPs
For $L=\text{mean}(H_2^2)$, $G_2:=\nabla_{H_2}L = \frac{2}{BK}H_2$ (elementwise). Cost: $BK$ mults (tiny).
Layer 2 ($H_2 = H_1 W_2$)
- Gradient w.r.t. weights: $\nabla_{W_2}L = H_1^\top G_2$: $(D\times B)\cdot(B\times K)\Rightarrow \mathbf{2DBK}$
- Backprop to activation: $G_1 = G_2 W_2^\top$: $(B\times K)\cdot(K\times D)\Rightarrow \mathbf{2BDK}$
- Layer-2 backward subtotal: $2DBK + 2BDK = 4BDK$
Layer 1 ($H_1 = X W_1$)
- Gradient w.r.t. weights: $\nabla_{W_1}L = X^\top G_1$: $(D\times B)\cdot(B\times D)\Rightarrow \mathbf{2DBD = 2BD^2}$
- (Optional) backprop to input: $G_X = G_1 W_1^\top$: $(B\times D)\cdot(D\times D)\Rightarrow \mathbf{2BD^2}$
- Layer-1 backward subtotal (including $G_X$): $2BD^2 + 2BD^2 = 4BD^2$
Backward total (dominant terms):
$$ \boxed{4BD^2 + 4BDK}\quad(+\,{\sim}BK\ \text{for the elementwise }G_2) $$
Combined forward + backward
$$ \boxed{\text{Total} \;\approx\; 6BD^2 + 6BDK} $$
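A quick sanity check that the per-term counts above add up to the boxed totals (B, D, K are arbitrary example sizes):
B, D, K = 1024, 4096, 64              # arbitrary example sizes

fwd = 2*B*D*D + 2*B*D*K               # H1 = X W1, H2 = H1 W2
bwd = 2*D*B*K + 2*B*D*K               # layer 2: grad W2, grad H1
bwd += 2*B*D*D + 2*B*D*D              # layer 1: grad W1, grad X
assert fwd + bwd == 6*B*D*D + 6*B*D*K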
In the example, $B$ is the number of tokens and $D^2 + DK$ is the number of parameters, so the counts generalize:
Forward pass (matrix multiplications):
$$ 2 \times (\#\text{tokens}) \times (\#\text{parameters}) $$
Backward pass:
$$ 4 \times (\#\text{tokens}) \times (\#\text{parameters}) $$
Total:
$$ 6 \times (\#\text{tokens}) \times (\#\text{parameters}) $$
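As a worked example, consider a hypothetical 7B-parameter model trained on 1T tokens (numbers chosen purely for illustration):
num_params = 7e9        # hypothetical model size
num_tokens = 1e12       # hypothetical number of training tokens

total_flops = 6 * num_tokens * num_params        # ~4.2e22 FLOPs
a100_peak = 312e12                               # A100 bfloat16 peak FLOP/s
ideal_days = total_flops / a100_peak / 86400     # one GPU at 100% utilization
print(f"{total_flops:.1e} FLOPs, {ideal_days:.0f} A100-days at peak")
In practice throughput is below peak; MFU (next) quantifies by how much.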
Model FLOPs Utilization (MFU) #
Definition:
$$ \text{MFU} = \frac{\text{actual FLOP/s}}{\text{promised FLOP/s}} $$
Good MFU ≈ 0.5 or higher.
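A rough way to estimate MFU for a single matmul (a sketch; assumes a CUDA GPU, and the promised number must match your hardware and dtype, e.g., 312e12 for A100 bfloat16):
import time
import torch

B, D, K = 16384, 8192, 8192
x = torch.randn(B, D, device="cuda", dtype=torch.bfloat16)
w = torch.randn(D, K, device="cuda", dtype=torch.bfloat16)

for _ in range(3):                    # warm up
    x @ w
torch.cuda.synchronize()

num_iters = 10
start = time.time()
for _ in range(num_iters):
    x @ w
torch.cuda.synchronize()
elapsed = time.time() - start

actual_flop_per_s = num_iters * 2 * B * D * K / elapsed
promised_flop_per_s = 312e12          # A100 bfloat16 peak
print("MFU:", actual_flop_per_s / promised_flop_per_s)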
Parameter Initialization #
x = nn.Parameter(torch.randn(input_dim))
w = nn.Parameter(torch.randn(input_dim, output_dim))  # naive init: entries ~ N(0, 1)
output = x @ w  # @inspect output
assert output.size() == torch.Size([output_dim])
# Each element of output scales as sqrt(input_dim): ~18.9
# Large values can cause gradients to blow up and make training unstable.
We want an initialization whose output scale is invariant to input_dim. To do that, we simply rescale by 1/sqrt(input_dim). To be extra safe, we can also truncate the normal distribution to [-3, 3] to avoid any chance of outliers:
w = nn.Parameter(
    nn.init.trunc_normal_(
        torch.empty(input_dim, output_dim), std=1 / np.sqrt(input_dim), a=-3, b=3
    )
)
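With this rescaling, the output magnitude stays roughly constant no matter how large input_dim gets (a quick check; the dimensions are arbitrary):
import numpy as np
import torch
import torch.nn as nn

output_dim = 32
for input_dim in [64, 1024, 16384]:
    w = nn.Parameter(
        nn.init.trunc_normal_(
            torch.empty(input_dim, output_dim), std=1 / np.sqrt(input_dim), a=-3, b=3
        )
    )
    x = torch.randn(input_dim)
    output = x @ w
    print(input_dim, output.std().item())   # stays around 1 for every input_dim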