Profiling, Writing GPU Kernels, and PyTorch Compilation #
1. Benchmarking and Profiling #
- Benchmarking
  - Measures end-to-end wall-clock time. Useful for comparing implementations and studying how performance scales.
  - Example code:
```python
import time
from statistics import mean
from typing import Callable

import torch

def benchmark(description: str, run: Callable, num_warmups: int = 1, num_trials: int = 3):
    """Benchmark `run` by calling it `num_trials` times; return the mean time in milliseconds."""
    # Warmup: the first runs might be slower due to compilation, things not cached.
    # Since we will run the kernel multiple times, the timing that matters is steady state.
    for _ in range(num_warmups):
        run()
    if torch.cuda.is_available():
        torch.cuda.synchronize()  # Wait for CUDA threads to finish (important!)

    # Time it for real now
    times: list[float] = []  # @inspect times, @inspect description
    for trial in range(num_trials):  # Do it multiple times to capture variance
        start_time = time.time()
        run()  # Actually perform computation
        if torch.cuda.is_available():
            torch.cuda.synchronize()  # Wait for CUDA threads to finish (important!)
        end_time = time.time()
        times.append((end_time - start_time) * 1000)  # @inspect times

    mean_time = mean(times)  # @inspect mean_time
    return mean_time
```
- Example: matrix multiplication at several sizes (a sketch of the `run_operation2` helper follows this block).

```python
dims = (1024, 2048, 4096, 8192)
results = []
for dim in dims:
    result = benchmark(f"matmul(dim={dim})",
                       run_operation2(dim=dim, operation=lambda a, b: a @ b))
    results.append((dim, result))
```
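`run_operation2` is used above but never defined in these notes. A minimal sketch of what it is assumed to do: allocate two random `dim x dim` tensors once, and return a zero-argument closure that applies `operation` to them (so only the operation itself is timed):

```python
def run_operation2(dim: int, operation: Callable) -> Callable:
    # Hypothetical helper: inputs are created outside the timed region.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    x = torch.randn(dim, dim, device=device)
    y = torch.randn(dim, dim, device=device)
    return lambda: operation(x, y)
```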
- Profiling
  - Shows where time is spent (which operators and kernels, on CPU vs. GPU).
  - Tool: the PyTorch Profiler (`torch.profiler`).
  - Example:
```python
import torch
from torch.profiler import ProfilerActivity

def profile(description: str, run: Callable, num_warmups: int = 1, with_stack: bool = False):
    # Warmup
    for _ in range(num_warmups):
        run()
    if torch.cuda.is_available():
        torch.cuda.synchronize()  # Wait for CUDA threads to finish (important!)

    # Run the code with the profiler
    with torch.profiler.profile(
            activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
            # Output stack trace for visualization
            with_stack=with_stack,
            # Needed to export stack trace for visualization
            experimental_config=torch._C._profiler._ExperimentalConfig(verbose=True)) as prof:
        run()
        if torch.cuda.is_available():
            torch.cuda.synchronize()  # Wait for CUDA threads to finish (important!)

    # Print out the table
    table = prof.key_averages().table(sort_by="cuda_time_total",
                                      max_name_column_width=80,
                                      row_limit=10)
    #text(f"## {description}")
    #text(table, verbatim=True)

    # Write stack trace visualization
    if with_stack:
        text_path = f"var/stacks_{description}.txt"
        svg_path = f"var/stacks_{description}.svg"
        prof.export_stacks(text_path, "self_cuda_time_total")

    return table
```
```python
add_fn = lambda a, b: a + b
profile("add", run_operation2(2048, add_fn))
```
The profiler shows:
- Which CUDA kernels were launched (`aten::add`, `aten::matmul`, etc.)
- CPU vs. CUDA time
- Kernel names (e.g., `cutlass_80_simt_sgemm_256x128`)

Observations:
- Different input shapes → different CUDA kernels (see the sketch below).
- Kernel names encode tiling strategies.
- For lower-level detail, check NVIDIA Nsight profiling tools.
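A sketch of how to observe the shape dependence (assuming the `profile` helper and `run_operation2` from above, on a CUDA machine; the chosen sizes are illustrative):

```python
# Profile the same matmul at two sizes and compare the CUDA kernel
# names reported in each table.
matmul_fn = lambda a, b: a @ b
for dim in (128, 2048):
    table = profile(f"matmul_dim{dim}", run_operation2(dim=dim, operation=matmul_fn))
    print(table)
```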
2. Kernel Fusion: CUDA Version #
- Warehouse (DRAM) ↔ slow global memory.
- Factory (SRAM) ↔ fast on-chip shared memory.
- Fusing several operations into one kernel avoids repeated reads and writes to DRAM between them: data is loaded once, transformed in fast memory, and written back once.
- CUDA Execution Hierarchy
  - Grid: a collection of blocks.
    Example (2D grid): `gridDim = (2, 4)` // 2 rows, 4 columns of blocks → 8 blocks total.
    Each block is identified by its block index: `blockIdx`.
  - Block: a collection of threads (like a container).
    Example (block size): `blockDim = (1, 8)` // 1 row, 8 columns of threads per block.
    Each block knows its position in the grid: `blockIdx.x`, `blockIdx.y`.
  - Thread: the smallest unit of execution.
    Each thread knows its position inside the block: `threadIdx.x`, `threadIdx.y`.
    Example: `blockIdx = (0, 1)` // 2nd block in the first row; `threadIdx = (0, 3)` // 4th thread inside this block.
  - Global thread index (convert the two-level indices into one linear index; see the Python sketch after this list):

```c
int i = blockIdx.x * blockDim.x + threadIdx.x;
```
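To make the indexing concrete, here is a small Python sketch (not from the original notes) that enumerates which element each (block, thread) pair handles for a 1D launch, including the bounds guard needed when the element count is not a multiple of the block size:

```python
# Illustration only: mimic i = blockIdx.x * blockDim.x + threadIdx.x.
num_elements, block_size = 10, 4
num_blocks = (num_elements + block_size - 1) // block_size  # ceiling division
for block_idx in range(num_blocks):
    for thread_idx in range(block_size):
        i = block_idx * block_size + thread_idx
        if i < num_elements:  # guard: the last block may have idle threads
            print(f"block {block_idx}, thread {thread_idx} -> element {i}")
```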
- Example: GeLU in CUDA
To catch errors immediately, run kernel launches synchronously:

```python
import os
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
```

File layout. There are two ways to organize the extension:

(A) Standard extension workflow (with `.cpp` + `.cu` files):

```text
my_extension/
├── gelu.cu    # CUDA kernel + C++ function definition
├── gelu.cpp   # C++ declaration + Python binding
└── setup.py   # build script
```

(B) Quick prototyping with `load_inline`: no files needed; provide the CUDA and C++ source strings directly from Python.

The logic is the same in both cases; you still need:
- CUDA source with the definition.
- C++ source with the declaration.
CUDA implementation (`gelu.cu`):

```cpp
#include <math.h>
#include <torch/extension.h>
#include <c10/cuda/CUDAException.h>

// CUDA kernel
__global__ void gelu_kernel(float* in, float* out, int num_elements) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < num_elements) {
        out[i] = 0.5 * in[i] * (1.0 + tanh(0.79788456 * (in[i] + 0.044715 * in[i] * in[i] * in[i])));
    }
}

// Utility: ceiling division
inline unsigned int cdiv(unsigned int a, unsigned int b) {
    return (a + b - 1) / b;
}

// C++ function definition / wrapper
torch::Tensor gelu(torch::Tensor x) {
    TORCH_CHECK(x.device().is_cuda());
    TORCH_CHECK(x.is_contiguous());
    torch::Tensor y = torch::empty_like(x);
    int num_elements = x.numel();
    int block_size = 1024;
    int num_blocks = cdiv(num_elements, block_size);
    gelu_kernel<<<num_blocks, block_size>>>(x.data_ptr<float>(), y.data_ptr<float>(), num_elements);
    C10_CUDA_KERNEL_LAUNCH_CHECK();
    return y;
}
```

C++ declaration + binding (`gelu.cpp`):

```cpp
#include <torch/extension.h>

// Forward declaration
torch::Tensor gelu(torch::Tensor x);

// Bindings
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("gelu", &gelu, "GELU activation (CUDA)");
}
```

Build options:
(A) Standard build (`setup.py`):

```python
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

setup(
    name="gelu_extension",
    ext_modules=[
        CUDAExtension(
            name="gelu_extension",
            sources=["gelu.cpp", "gelu.cu"],
        )
    ],
    cmdclass={"build_ext": BuildExtension},
)
```

Build with `python setup.py install`, then import:

```python
import gelu_extension
y = gelu_extension.gelu(x)
```

(B) Inline build (quick prototyping in Python):

```python
import os
import torch
from torch.utils.cpp_extension import load_inline

def build_cuda_gelu():
    cuda_gelu_src = open("gelu.cu").read()
    cpp_gelu_src = "torch::Tensor gelu(torch::Tensor x);"
    os.makedirs("var/cuda_gelu", exist_ok=True)
    module = load_inline(
        cuda_sources=[cuda_gelu_src],
        cpp_sources=[cpp_gelu_src],
        functions=["gelu"],
        name="inline_gelu",
        build_directory="var/cuda_gelu",
        extra_cflags=["-O2"],
        verbose=True,
    )
    return getattr(module, "gelu")

# Usage
x = torch.randn(10, device="cuda")
gelu_fn = build_cuda_gelu()
y = gelu_fn(x)
```
Performance: the custom CUDA GELU is faster than the manual (unfused) PyTorch version, but slower than PyTorch's fused built-in. A measurement sketch follows below.
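A minimal measurement sketch, assuming the `benchmark` helper from section 1, `build_cuda_gelu` from above, and a CUDA device (the tensor size is illustrative):

```python
import torch
import torch.nn.functional as F

x = torch.randn(16 * 1024 * 1024, device="cuda")
cuda_gelu = build_cuda_gelu()

cuda_ms = benchmark("cuda_gelu", lambda: cuda_gelu(x))
pytorch_ms = benchmark("pytorch_gelu", lambda: F.gelu(x, approximate="tanh"))
print(f"custom CUDA: {cuda_ms:.3f} ms, PyTorch fused: {pytorch_ms:.3f} ms")
```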
3. Triton Kernels (in Python) #
- High-level GPU programming (in Python).
- Handles memory coalescing, shared memory, scheduling automatically.
GeLU example:
- Kernel function:

```python
import triton
import triton.language as tl

@triton.jit
def triton_gelu_kernel(x_ptr, y_ptr, num_elements, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(0)                            # program (block) id along axis 0
    block_start = pid * BLOCK_SIZE                    # first element handled by this block
    offsets = block_start + tl.arange(0, BLOCK_SIZE)  # vector of element indices
    mask = offsets < num_elements                     # mask to avoid out-of-bounds accesses
    x = tl.load(x_ptr + offsets, mask=mask)
    a = 0.79788456 * (x + 0.044715 * x * x * x)
    e = tl.exp(2 * a)
    tanh_approx = (e - 1) / (e + 1)
    y = 0.5 * x * (1 + tanh_approx)
    tl.store(y_ptr + offsets, y, mask=mask)
```

- Python wrapper:
```python
def triton_gelu(x: torch.Tensor, block_size: int = 1024) -> torch.Tensor:
    assert x.is_cuda
    assert x.is_contiguous()
    y = torch.empty_like(x)
    n = x.numel()
    grid = (triton.cdiv(n, block_size),)
    triton_gelu_kernel[grid](x, y, n, BLOCK_SIZE=block_size)
    return y
```

- PTX inspection:
PTX stands for Parallel Thread Execution; it is NVIDIA's intermediate assembly language for CUDA GPUs.

```python
def print_ptx_main(block_size: int = 1024):
    compiled = triton.compile(
        triton_gelu_kernel,
        signature="*fp32,*fp32,i32",
        device="cuda",
        constants={"BLOCK_SIZE": block_size},
        num_warps=4,
    )
    print(compiled.asm["ptx"])
```

- Correctness check:
```python
def check_equal(f_triton, f_ref, atol=1e-6, rtol=1e-4):
    x = torch.randn(8192, device="cuda", dtype=torch.float32)
    y_triton = f_triton(x)
    y_ref = f_ref(x)
    print(torch.allclose(y_triton, y_ref, atol=atol, rtol=rtol))
```
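For example, one could compare against PyTorch's tanh-approximate GELU as a reference (an assumed choice; `F.gelu(..., approximate="tanh")` uses the same approximation as the kernel above):

```python
import torch.nn.functional as F

# Should print True if the two implementations agree within tolerance.
check_equal(triton_gelu, lambda x: F.gelu(x, approximate="tanh"))
```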
4. PyTorch Compilation (torch.compile) #
- You can just write Python and let the compiler generate optimized kernels.
- Example (a sketch of `manual_gelu` follows below):

```python
compiled_gelu = torch.compile(manual_gelu)
```

The profiler shows that the compiled version fuses the elementwise operations into a single generated Triton kernel.
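`manual_gelu` is referenced but not defined in these notes; a plausible sketch (plain, unfused PyTorch ops using the same tanh approximation as the CUDA and Triton kernels above):

```python
import torch

def manual_gelu(x: torch.Tensor) -> torch.Tensor:
    # Run eagerly, each elementwise op launches its own CUDA kernel;
    # torch.compile can fuse them into one generated kernel.
    return 0.5 * x * (1 + torch.tanh(0.79788456 * (x + 0.044715 * x * x * x)))
```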
5. Summary #
- Explored five ways to write the same function:
  - Manual Python
  - PyTorch built-in
  - PyTorch `torch.compile`
  - Custom CUDA
  - Triton
| Method | Time (ms) |
|---|---|
| Manual Python (`manual_time`) | 8.099 |
| PyTorch built-in (`pytorch_time`) | 1.109 |
| Custom CUDA (`cuda_time`) | 1.816 |
| Triton (`triton_time`) | 1.727 |
| `torch.compile` (`compiled_time`) | 1.470 |