Kernels, Triton

Profiling, Writing GPU Kernels, and PyTorch compilation #

1. Benchmarking and Profiling #

  1. Benchmarking
    • Measure wall-clock time (end-to-end). Useful for comparing implementations and studying scaling.
    • Example code
import time
from statistics import mean
from typing import Callable

import torch

def benchmark(description: str, run: Callable, num_warmups: int = 1, num_trials: int = 3):
    """Benchmark `run` by calling it `num_trials` times and return the mean time in milliseconds."""
    # Warmup: first times might be slower due to compilation, things not cached.
    # Since we will run the kernel multiple times, the timing that matters is steady state.
    for _ in range(num_warmups):
        run()
    if torch.cuda.is_available():
        torch.cuda.synchronize()  # Wait for CUDA threads to finish (important!)
    # Time it for real now
    times: list[float] = [] # @inspect times, @inspect description
    for trial in range(num_trials):  # Do it multiple times to capture variance
        start_time = time.time()
        run()  # Actually perform computation
        if torch.cuda.is_available():
            torch.cuda.synchronize()  # Wait for CUDA threads to finish (important!)
        end_time = time.time()
        times.append((end_time - start_time) * 1000) # @inspect times
    mean_time = mean(times) # @inspect mean_time
    return mean_time
  • Example: Matrix Multiplication
dims = (1024, 2048, 4096, 8192)
results = []
for dim in dims:
    result = benchmark(f"matmul(dim={dim})",
                       run_operation2(dim=dim, operation=lambda a, b: a @ b))
    results.append((dim, result))
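  • Note: run_operation2 is not defined in these notes. A minimal sketch of what such a helper might look like (an assumption, not the original implementation): it builds the operands once, outside the timed region, and returns a closure for benchmark to call.
def run_operation2(dim: int, operation: Callable) -> Callable:
    # Setup: create the inputs once so that only the operation itself is timed
    device = "cuda" if torch.cuda.is_available() else "cpu"
    a = torch.randn(dim, dim, device=device)
    b = torch.randn(dim, dim, device=device)
    # Return a no-argument closure that runs the operation on the prepared inputs
    return lambda: operation(a, b)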
  2. Profiling
    • Shows where time is spent: which operators run, which CUDA kernels they launch, and how long each takes.
    • PyTorch Profiler: link
    • Example
from torch.profiler import ProfilerActivity

def profile(description: str, run: Callable, num_warmups: int = 1, with_stack: bool = False):
    # Warmup
    for _ in range(num_warmups):
        run()
    if torch.cuda.is_available():
        torch.cuda.synchronize()  # Wait for CUDA threads to finish (important!)
    # Run the code with the profiler
    with torch.profiler.profile(
            activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
            # Output stack trace for visualization
            with_stack=with_stack,
            # Needed to export stack trace for visualization
            experimental_config=torch._C._profiler._ExperimentalConfig(verbose=True)) as prof:
        run()
        if torch.cuda.is_available():
            torch.cuda.synchronize()  # Wait for CUDA threads to finish (important!)
    # Build a table of the top ops, sorted by total CUDA time
    table = prof.key_averages().table(sort_by="cuda_time_total",
                                      max_name_column_width=80,
                                      row_limit=10)
    # Write out stack traces for flame-graph visualization
    if with_stack:
        text_path = f"var/stacks_{description}.txt"
        svg_path = f"var/stacks_{description}.svg"  # converting text_path to an SVG flame graph is not shown here
        prof.export_stacks(text_path, "self_cuda_time_total")
    return table
add_fn = lambda a, b: a + b
profile("add", run_operation2(2048, add_fn))

Profiler shows:

  • Operators invoked (aten::add, aten::matmul, etc.) and the CUDA kernels they launch
  • CPU vs CUDA time
  • Kernel names (e.g., cutlass_80_simt_sgemm_256x128)

Observation:

  • Different input shapes → different CUDA kernels (see the sketch below).
  • Kernel names encode the tiling strategies used.
  • For lower-level detail, check the NVIDIA Nsight profiling tools.
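
For example, profiling the same matmul at two sizes (a sketch reusing the profile and run_operation2 helpers above; kernel names vary with GPU and PyTorch build):

matmul_fn = lambda a, b: a @ b
for dim in (128, 2048):
    table = profile(f"matmul(dim={dim})", run_operation2(dim, matmul_fn))
    print(table)  # the CUDA kernel names typically differ between the two sizes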

2. Kernel Fusion Motivation #

  • Warehouse ↔ DRAM (global memory): large but slow.
  • Factory ↔ SRAM (on-chip shared memory/registers): small but fast.

Fusing several operations into a single kernel avoids repeatedly shipping intermediate results back and forth to DRAM.
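
For concreteness, here is a hypothetical unfused GeLU written as separate PyTorch ops (the name manual_gelu is illustrative): every arithmetic op below launches its own kernel and reads/writes a full x-sized tensor in DRAM.

import torch

def manual_gelu(x: torch.Tensor) -> torch.Tensor:
    # tanh approximation of GeLU, spelled out op by op;
    # each op ships an intermediate tensor to and from DRAM
    return 0.5 * x * (1 + torch.tanh(0.79788456 * (x + 0.044715 * x.pow(3))))

A fused kernel (hand-written CUDA as below, or one produced by a compiler) reads x from DRAM once, keeps the intermediates in registers/shared memory, and writes the result once.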

  1. CUDA Execution Hierarchy
  • Grid: A collection of blocks.

    • Example (2D grid):

      gridDim = (2, 4)   // 2 rows, 4 columns of blocks → 8 blocks total
      
    • Each block is identified by its block index: blockIdx.

  • Block: A collection of threads (like a container).

    • Example (block size):

      blockDim = (1, 8)   // 1 row, 8 columns of threads per block
      
    • Each block knows its position in the grid: blockIdx.x, blockIdx.y.

  • Thread: The smallest unit of execution.

    • Each thread knows its position inside the block: threadIdx.x, threadIdx.y.

    • Example:

      blockIdx = (0, 1)   // 2nd block in the first row
      threadIdx = (0, 3)  // 4th thread inside this block
      
  • Global thread index (convert 2-level indices into one linear index):

    int i = blockIdx.x * blockDim.x + threadIdx.x;
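    For example, reading the column coordinate above as .x (so blockDim.x = 8, blockIdx.x = 1, threadIdx.x = 3), this gives i = 1 * 8 + 3 = 11: that thread handles element 11 of the flattened array.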
    
  2. Example: GeLU in CUDA
    1. To surface kernel errors at the launch site (rather than asynchronously later), force synchronous launches by setting this before CUDA is initialized:

      import os
      os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
      
    2. File Layout

      There are two ways to organize:

      (A) Standard extension workflow (with .cpp + .cu files)

      my_extension/
      ├── gelu.cu   # CUDA kernel + C++ function definition
      ├── gelu.cpp  # C++ declaration + Python binding
      └── setup.py  # build script
      

      (B) Quick prototyping with load_inline

      • No files needed; provide CUDA + C++ source strings directly from Python.

      • The logic is the same — you still need:

      • CUDA source with the definition.

      • C++ source with the declaration.

    3. CUDA Implementation (gelu.cu)

      #include <math.h>
      #include <torch/extension.h>
      #include <c10/cuda/CUDAException.h>
      
      // CUDA kernel
      __global__ void gelu_kernel(float* in, float* out, int num_elements) {
          int i = blockIdx.x * blockDim.x + threadIdx.x;
      
          if (i < num_elements) {
              // tanh approximation of GeLU; 0.79788456 ≈ sqrt(2/pi)
              out[i] = 0.5 * in[i] *
                      (1.0 + tanh(0.79788456 * (in[i] + 0.044715 * in[i] * in[i] * in[i])));
          }
      }
      
      // Utility: ceil division
      inline unsigned int cdiv(unsigned int a, unsigned int b) {
          return (a + b - 1) / b;
      }
      
      // C++ function definition/ wrapper
      torch::Tensor gelu(torch::Tensor x) {
          TORCH_CHECK(x.device().is_cuda());
          TORCH_CHECK(x.is_contiguous());
      
          torch::Tensor y = torch::empty_like(x);
      
          int num_elements = x.numel();
          int block_size = 1024;
          // e.g., 5000 elements -> cdiv(5000, 1024) = 5 blocks; the bounds check in the kernel covers the extra threads in the last block
          int num_blocks = cdiv(num_elements, block_size);
      
          gelu_kernel<<<num_blocks, block_size>>>(x.data_ptr<float>(),
                                                  y.data_ptr<float>(),
                                                  num_elements);
      
          C10_CUDA_KERNEL_LAUNCH_CHECK();
          return y;
      }
      
    4. C++ Declaration + Binding (gelu.cpp)

      #include <torch/extension.h>
      
      // Forward declaration
      torch::Tensor gelu(torch::Tensor x);
      
      // Bindings
      PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
          m.def("gelu", &gelu, "GELU activation (CUDA)");
      }
      
    5. Build Options

      • (A) Standard build (setup.py)

        from setuptools import setup
        from torch.utils.cpp_extension import BuildExtension, CUDAExtension
        
        setup(
            name="gelu_extension",
            ext_modules=[
                CUDAExtension(
                    name="gelu_extension",
                    sources=["gelu.cpp", "gelu.cu"],
                )
            ],
            cmdclass={"build_ext": BuildExtension}
        )
        

        Build and import:

        python setup.py install
        
        import gelu_extension
        y = gelu_extension.gelu(x)
        
      • (B) Inline build (quick prototyping in Python)

        from torch.utils.cpp_extension import load_inline
        import torch, os
        
        def build_cuda_gelu():
            cuda_gelu_src = open("gelu.cu").read()
            cpp_gelu_src = "torch::Tensor gelu(torch::Tensor x);"
        
            os.makedirs("var/cuda_gelu", exist_ok=True)
        
            module = load_inline(
                cuda_sources=[cuda_gelu_src],
                cpp_sources=[cpp_gelu_src],
                functions=["gelu"],
                name="inline_gelu",
                build_directory="var/cuda_gelu",
                extra_cflags=["-O2"],
                verbose=True,
            )
        
            return getattr(module, "gelu")
        
        # Usage
        x = torch.randn(10, device="cuda")
        gelu_fn = build_cuda_gelu()
        y = gelu_fn(x)
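
    6. Correctness Check

      A quick sanity check (a sketch, assuming PyTorch's tanh-approximate GeLU as the reference):

      import torch

      x = torch.randn(1024, device="cuda")
      expected = torch.nn.functional.gelu(x, approximate="tanh")
      actual = gelu_fn(x)
      torch.testing.assert_close(actual, expected, rtol=1e-4, atol=1e-5)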
        

Performance:

  • The hand-written CUDA kernel is faster than the unfused manual PyTorch version, but still slower than PyTorch's built-in fused GeLU (a benchmarking sketch follows below).
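
A sketch of how such a comparison might be run with the benchmark helper and the manual_gelu example from earlier (actual numbers depend on the GPU):

import torch

x = torch.randn(8192, 8192, device="cuda")

manual_ms = benchmark("manual_gelu", lambda: manual_gelu(x))
cuda_ms = benchmark("cuda_gelu", lambda: gelu_fn(x))
pytorch_ms = benchmark("pytorch_gelu", lambda: torch.nn.functional.gelu(x, approximate="tanh"))
print(manual_ms, cuda_ms, pytorch_ms)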

TBD